Implementing Linked List In Scala

In Scala, if you wonder why its standard library doesn’t come with a data structure called `LinkedList`, you may have overlooked the fact that the collection `List` is in fact a linked list. It is usually used like a `Seq` or `Vector` collection, rather than as the generally “mysterious” linked list that exposes its “head” while hiding its “tail”, to be revealed only iteratively.
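To make the linked structure of `List` visible, here is a small sketch that uses only the standard library (`sumAll` is a hypothetical helper written for this illustration). Prepending with `::` allocates a single new cell whose tail is the existing list, and pattern matching reveals the list one head at a time:

```scala
// A List is a chain of `::` cells terminated by `Nil`
val xs: List[Int] = 1 :: 2 :: 3 :: Nil

// Prepending creates one new cell that points at the existing list (no copying)
val ys: List[Int] = 0 :: xs

// Hypothetical helper: walk the list cell by cell via pattern matching
def sumAll(list: List[Int]): Int = list match {
  case Nil          => 0
  case head :: tail => head + sumAll(tail)
}
```

Because `ys.tail` is the very same object as `xs`, prepending costs constant time regardless of the list's length.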

Our ADT: LinkedNode

Perhaps because of its simplicity and dynamic nature as a data structure, implementing a linked list remains a popular coding exercise. To implement our own linked list, let’s start with a bare-bones ADT (algebraic data type) as follows:

sealed trait LinkedNode[+A]
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

If you’re familiar with Scala’s `List`, you’ll probably notice that our ADT resembles `List` and its subtypes `::` (commonly called “cons”) and `Nil`, as shown in this abridged excerpt from the standard library source code:

sealed abstract class List[+A] { ... }
final case class ::[+A](override val head: A, private[scala] var next: List[A @uncheckedVariance]) extends List[A] { ... }
case object Nil extends List[Nothing] { ... }
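The resemblance can be made concrete by building the same three-element chain with both types. The sketch below repeats the bare-bones ADT so it runs on its own; `fromList` is a hypothetical helper (not part of the article’s code) that maps each `::` cell to a `Node` and `Nil` to `EmptyNode`:

```scala
sealed trait LinkedNode[+A]
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

// The same chain built explicitly with the standard library's `::` and with our ADT
val viaList: List[Int] = ::(1, ::(2, ::(3, Nil)))
val viaNode: LinkedNode[Int] = Node(1, Node(2, Node(3, EmptyNode)))

// Hypothetical helper: translate cell by cell, `::` to Node and Nil to EmptyNode
def fromList[A](ls: List[A]): LinkedNode[A] = ls match {
  case Nil    => EmptyNode
  case h :: t => Node(h, fromList(t))
}
```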

Expanding LinkedNode

Let’s expand trait `LinkedNode` with methods `insertNode`/`deleteNode` for inserting/deleting a node at a given index, `toList` for extracting the contained elements into a `List`, and `toString` for display:

sealed trait LinkedNode[+A] { self =>

  def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
      case EmptyNode =>
        if (i < idx)
          EmptyNode
        else
          Node(x, EmptyNode)
      case Node(e, nx) =>
        if (i < idx)
          Node(e, loop(nx, i + 1))
        else
          Node(x, ln)
    }
    loop(self, 0)
  }

  def deleteNode(idx: Int): LinkedNode[A] = {
    def loop(ln: LinkedNode[A], i: Int): LinkedNode[A] = ln match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        if (idx == 0)
          nx
        else {
          if (i < idx - 1)
            Node(e, loop(nx, i + 1))
          else
            nx match {
              case EmptyNode => Node(e, EmptyNode)
              case Node(_, nx2) => Node(e, nx2)
            }
        }
    }
    if (idx < 0) self else loop(self, 0)  // a negative index matches no node, so nothing is deleted
  }

  def toList: List[A] = {
    def loop(ln: LinkedNode[A], acc: List[A]): List[A] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, e :: acc)
    }
    loop(self, List.empty[A]).reverse
  }

  override def toString: String = {
    def loop(ln: LinkedNode[A], acc: String): String = ln match {
      case EmptyNode =>
        acc + "()"
      case Node(e, nx) =>
        loop(nx, acc + e + " -> " )
    }
    loop(self, "")
  }
}

case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]

case object EmptyNode extends LinkedNode[Nothing]

// Test running ...

val ln = List(1, 2, 3, 3, 4).foldRight[LinkedNode[Int]](EmptyNode)(Node(_, _))
// ln: LinkedNode[Int] = 1 -> 2 -> 3 -> 3 -> 4 -> ()

ln.insertNode(9, 2)
// res1: LinkedNode[Int] = 1 -> 2 -> 9 -> 3 -> 3 -> 4 -> ()

ln.deleteNode(1)
// res2: LinkedNode[Int] = 1 -> 3 -> 3 -> 4 -> ()
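The inner `loop` of `insertNode` is not tail-recursive: each recursive call is wrapped in `Node(e, ...)`, so inserting deep into a very long list could overflow the stack. A sketch of a tail-recursive alternative, written as a hypothetical stand-alone function `insertNodeTR`, collects the visited prefix in a `List` and rebuilds it once the insertion point is reached. It is meant to keep `insertNode`’s behavior, including leaving the list unchanged when `idx` is beyond its end:

```scala
import scala.annotation.tailrec

sealed trait LinkedNode[+A]
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

// Hypothetical stand-alone variant of insertNode whose inner loop is tail-recursive
def insertNodeTR[A](ln: LinkedNode[A], x: A, idx: Int): LinkedNode[A] = {
  @tailrec
  def loop(rest: LinkedNode[A], i: Int, visited: List[A]): LinkedNode[A] = rest match {
    // Before the insertion point: remember the element (most recent first) and move on
    case Node(e, nx) if i < idx => loop(nx, i + 1, e :: visited)
    // Reached the end before idx: the index is out of range, return the list unchanged
    case EmptyNode if i < idx => ln
    // At the insertion point: put x in front of the rest, then re-prepend the visited prefix
    case _ => visited.foldLeft[LinkedNode[A]](Node(x, rest))((acc, e) => Node(e, acc))
  }
  loop(ln, 0, Nil)
}
```

The trade-off is an intermediate `List` holding up to `idx` elements, in exchange for constant stack depth.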

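Note that `LinkedNode` is made covariant. Since a covariant type parameter may not appear as the type of a method parameter (a contravariant position, just as `Function1` is contravariant in its parameter type), method `insertNode` takes its new element as a type `B` with `A` as its lower bound (`B >: A`) rather than as an `A`.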
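A minimal sketch of what the variance annotations buy us, with `insertNode` condensed from above and a hypothetical `Animal`/`Dog`/`Cat` hierarchy introduced only for illustration: a `LinkedNode[Dog]` can be used as a `LinkedNode[Animal]`, and inserting a `Cat` into it is allowed because `B` is inferred as a common supertype of `Dog` and `Cat`:

```scala
sealed trait LinkedNode[+A] { self =>
  // B >: A: the new element may be of any supertype of A, and the result widens to LinkedNode[B].
  // Declaring `def insertNode(x: A, idx: Int)` instead would not compile, because the
  // covariant A would then appear in a (contravariant) method-parameter position.
  def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
      case EmptyNode   => if (i < idx) EmptyNode else Node(x, EmptyNode)
      case Node(e, nx) => if (i < idx) Node(e, loop(nx, i + 1)) else Node(x, ln)
    }
    loop(self, 0)
  }
}
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

// Hypothetical class hierarchy, for illustration only
sealed trait Animal
case class Dog(name: String) extends Animal
case class Cat(name: String) extends Animal

val dogs: LinkedNode[Dog] = Node(Dog("Rex"), Node(Dog("Fido"), EmptyNode))

// Covariance: a LinkedNode[Dog] is accepted where a LinkedNode[Animal] is expected
val animals: LinkedNode[Animal] = dogs

// Lower bound: inserting a Cat into a LinkedNode[Dog] yields a list of a common supertype
val mixed: LinkedNode[Animal] = dogs.insertNode(Cat("Tom"), 1)
```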

Recursion and pattern matching

A couple of notes on the approach used to implement the methods:

  1. We use recursive functions to avoid mutable variables. Ideally they should be tail-recursive for better performance and stack safety (the inner loops of `toList` and `toString` already are, while those of `insertNode` and `deleteNode` are not), but that isn’t the focus of this implementation. If performance is a priority, conventional while-loops with mutable class fields `elem`/`next` would be a more practical option.
  2. Pattern matching is routinely used for handling cases of `Node` versus `EmptyNode`. An alternative approach would be to define fields `elem` and `next` in the base trait and implement class methods accordingly within `Node` and `EmptyNode`.
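For comparison, here is a rough sketch of the alternative mentioned in note 2, where `elem` and `next` are declared in the base trait and each subtype supplies its own behavior through dynamic dispatch instead of pattern matching. Having `EmptyNode.elem`/`EmptyNode.next` throw is an assumption made for this sketch (mirroring `Nil.head` and `Nil.tail`, which also throw), not something the article prescribes:

```scala
// Hypothetical OO-style variant: members declared in the base trait,
// implemented separately in each subtype (no pattern matching needed)
sealed trait LinkedNode[+A] {
  def isEmpty: Boolean
  def elem: A
  def next: LinkedNode[A]
  def size: Int
  def toList: List[A]
}

case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A] {
  def isEmpty: Boolean = false
  def size: Int = 1 + next.size
  def toList: List[A] = elem :: next.toList
}

case object EmptyNode extends LinkedNode[Nothing] {
  def isEmpty: Boolean = true
  // Assumption: mirror Nil.head / Nil.tail and fail loudly on access
  def elem: Nothing = throw new NoSuchElementException("elem of EmptyNode")
  def next: LinkedNode[Nothing] = throw new UnsupportedOperationException("next of EmptyNode")
  def size: Int = 0
  def toList: List[Nothing] = Nil
}

val ln: LinkedNode[Int] = Node(1, Node(2, Node(3, EmptyNode)))
```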

Finding first/last matching nodes

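Next, we add a couple of methods for finding the first/last node whose element matches a given value. Each returns the sub-list starting at the matching node, or `EmptyNode` if there is no match.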

  def find[B >: A](x: B): LinkedNode[B] = {
    def loop(ln: LinkedNode[B]): LinkedNode[B] = ln match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        if (e == x) ln else loop(nx)
    }
    loop(self)
  }

  def findLast[B >: A](x: B): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], acc: List[LinkedNode[B]]): List[LinkedNode[B]] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        if (e == x) loop(nx, ln :: acc) else loop(nx, acc)
    }
    loop(self, List.empty[LinkedNode[B]]).headOption match {
      case None => EmptyNode
      case Some(node) => node
    }
  }
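`findLast` above collects every matching node in a `List` only to keep its head. A possible variant, sketched as a hypothetical stand-alone function `findLastTR`, remembers only the most recent match, so it needs constant extra memory; `@tailrec` is added to have the compiler verify that the loop is tail-recursive:

```scala
import scala.annotation.tailrec

sealed trait LinkedNode[+A]
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

// Hypothetical variant of findLast: keep only the latest matching node while walking the list
def findLastTR[A](ln: LinkedNode[A], x: A): LinkedNode[A] = {
  @tailrec
  def loop(rest: LinkedNode[A], last: LinkedNode[A]): LinkedNode[A] = rest match {
    case EmptyNode   => last
    case Node(e, nx) => loop(nx, if (e == x) rest else last)
  }
  loop(ln, EmptyNode)  // EmptyNode doubles as "no match found"
}
```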

Reversing a linked list by groups of nodes

Reversing a `LinkedNode` can be done recursively: walk the list and wrap the element of each `Node` in a new `Node` whose `next` points to the `Node` created in the previous iteration (starting from `EmptyNode`).

  def reverse: LinkedNode[A] = {
    def loop(ln: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, Node(e, acc))
    }
    loop(self, EmptyNode)
  }

  def reverseK(k: Int): LinkedNode[A] = {
    if (k <= 1)  // grouped(k) requires k >= 1; a group size of 1 (or less) leaves the list as is
      self
    else
      self.toList.grouped(k).flatMap(_.reverse).
        foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
  }

Method `reverseK` reverses a `LinkedNode` in groups of `k` elements using a different approach: it extracts the elements into a `List`, splits them into groups of `k`, reverses the elements within each group, and re-wraps each of the flattened elements in a `Node`.
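To see what each step of `reverseK` does, here are the intermediate values for `k = 3` on a plain `List`, using only standard-library calls; the final step (re-wrapping each element in a `Node` via `foldRight`) is omitted:

```scala
// Intermediate steps of the reverseK approach, shown on a plain List with k = 3
val elems = List(1, 2, 2, 3, 4, 4, 5)

// Step 1: split into groups of k (the last group may be shorter);
// grouped requires a group size of at least 1
val groups: List[List[Int]] = elems.grouped(3).toList

// Step 2: reverse each group and flatten the groups back into one sequence
val reversedByGroup: List[Int] = groups.flatMap(_.reverse)
```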

Using LinkedNode as a Stack

For `LinkedNode` to serve as a Stack, we include simple methods `push` and `pop` as follows:

  def push[B >: A](x: B): LinkedNode[B] = Node(x, self)

  def pop: (Option[A], LinkedNode[A]) = self match {
    case EmptyNode => (None, self)
    case Node(e, nx) => (Some(e), nx)
  }
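A short sketch of the stack behavior, with the ADT reduced to `push` and `pop` as defined above: elements come back in last-in-first-out order, and popping an empty stack yields `None` rather than throwing:

```scala
sealed trait LinkedNode[+A] { self =>
  // Push: the new element becomes the head, in O(1)
  def push[B >: A](x: B): LinkedNode[B] = Node(x, self)

  // Pop: return the head (if any) together with the rest of the stack, in O(1)
  def pop: (Option[A], LinkedNode[A]) = self match {
    case EmptyNode   => (None, self)
    case Node(e, nx) => (Some(e), nx)
  }
}
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

// The last element pushed ends up on top of the stack
val stack: LinkedNode[Int] = EmptyNode.push(1).push(2).push(3)

// Pops come back in reverse order of pushes (LIFO)
val (top1, rest1) = stack.pop
val (top2, rest2) = rest1.pop
```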

Addendum: implementing map, flatMap, fold

From a different perspective, if we view `LinkedNode` as a collection like a Scala `List` or `Vector`, we might want methods like `map`, `flatMap` and `fold`. Using the same approach of recursion along with pattern matching, it’s rather straightforward to crank them out.

  def map[B](f: A => B): LinkedNode[B] = self match {
    case EmptyNode =>
      EmptyNode
    case Node(e, nx) =>
      Node(f(e), nx.map(f))
  }

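  def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = {
    // Prepend all nodes of `front`, in order, to `back`
    def concat(front: LinkedNode[B], back: LinkedNode[B]): LinkedNode[B] = front match {
      case EmptyNode =>
        back
      case Node(e2, nx2) =>
        Node(e2, concat(nx2, back))
    }
    self match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        // Splice in every node produced by f(e), not just the first one
        concat(f(e), nx.flatMap(f))
    }
  }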

  def foldLeft[B](z: B)(f: (B, A) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, f(acc, e))
    }
    loop(self, z)
  }

  def foldRight[B](z: B)(f: (A, B) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, f(e, acc))
    }
    loop(self.reverse, z)
  }
}
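One way to see the difference between the two folds is to fold with `Node` itself. In the sketch below (the ADT is reduced to `reverse`, `foldLeft` and `foldRight` as defined above), `foldRight` combines the elements from last to first and therefore rebuilds the list in its original order, while `foldLeft` combines them from first to last and produces the reversed list:

```scala
sealed trait LinkedNode[+A] { self =>
  def reverse: LinkedNode[A] = {
    def loop(ln: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] = ln match {
      case EmptyNode   => acc
      case Node(e, nx) => loop(nx, Node(e, acc))
    }
    loop(self, EmptyNode)
  }

  def foldLeft[B](z: B)(f: (B, A) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode   => acc
      case Node(e, nx) => loop(nx, f(acc, e))
    }
    loop(self, z)
  }

  // Walk the reversed list so that the last element is combined first
  def foldRight[B](z: B)(f: (A, B) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode   => acc
      case Node(e, nx) => loop(nx, f(e, acc))
    }
    loop(self.reverse, z)
  }
}
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

val ln: LinkedNode[Int] = Node(1, Node(2, Node(3, EmptyNode)))

// foldRight combines 3, then 2, then 1: prepending with Node restores the original order
val rebuilt = ln.foldRight[LinkedNode[Int]](EmptyNode)(Node(_, _))

// foldLeft combines 1, then 2, then 3: prepending with Node yields the reversed list
val reversed = ln.foldLeft[LinkedNode[Int]](EmptyNode)((acc, e) => Node(e, acc))
```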

Putting everything together

Along with a few additional simple methods (`get`, `indexOf` and `indexLast`) and a factory method in `LinkedNode`’s companion object, below is the final `LinkedNode` ADT that includes everything described above.

sealed trait LinkedNode[+A] { self =>

  def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
      case EmptyNode =>
        if (i < idx)
          EmptyNode
        else
          Node(x, EmptyNode)
      case Node(e, nx) =>
        if (i < idx)
          Node(e, loop(nx, i + 1))
        else
          Node(x, ln)
    }
    loop(self, 0)
  }

  def deleteNode(idx: Int): LinkedNode[A] = {
    def loop(ln: LinkedNode[A], i: Int): LinkedNode[A] = ln match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        if (idx == 0)
          nx
        else {
          if (i < idx - 1)
            Node(e, loop(nx, i + 1))
          else
            nx match {
              case EmptyNode => Node(e, EmptyNode)
              case Node(_, nx2) => Node(e, nx2)
            }
        }
    }
    if (idx < 0) self else loop(self, 0)  // a negative index matches no node, so nothing is deleted
  }

  def get(idx: Int): LinkedNode[A] = {
    def loop(ln: LinkedNode[A], count: Int): LinkedNode[A] = {
      if (count < idx) {
        ln match {
          case EmptyNode => EmptyNode
          case Node(_, next) => loop(next, count + 1)
        }
      }
      else
        ln
    }
    loop(self, 0)
  }

  def find[B >: A](x: B): LinkedNode[B] = {
    def loop(ln: LinkedNode[B]): LinkedNode[B] = ln match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        if (e == x) ln else loop(nx)
    }
    loop(self)
  }

  def findLast[B >: A](x: B): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], acc: List[LinkedNode[B]]): List[LinkedNode[B]] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        if (e == x) loop(nx, ln :: acc) else loop(nx, acc)
    }
    loop(self, List.empty[LinkedNode[B]]).headOption match {
      case None => EmptyNode
      case Some(node) => node
    }
  }

  def indexOf[B >: A](x: B): Int = {
    def loop(ln: LinkedNode[B], idx: Int): Int = ln match {
      case EmptyNode =>
        -1
      case Node(e, nx) =>
        if (e == x) idx else loop(nx, idx + 1)
    }
    loop(self, 0)
  }

  def indexLast[B >: A](x: B): Int = {
    def loop(ln: LinkedNode[B], idx: Int, acc: List[Int]): List[Int] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, idx + 1, if (e == x) idx :: acc else acc)
    }
    loop(self, 0, List(-1)).head
  }

  def reverse: LinkedNode[A] = {
    def loop(ln: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, Node(e, acc))
    }
    loop(self, EmptyNode)
  }

  def reverseK(k: Int): LinkedNode[A] = {
    if (k <= 1)  // grouped(k) requires k >= 1; a group size of 1 (or less) leaves the list as is
      self
    else
      self.toList.grouped(k).flatMap(_.reverse).
        foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
  }

  def toList: List[A] = {
    def loop(ln: LinkedNode[A], acc: List[A]): List[A] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, e :: acc)
    }
    loop(self, List.empty[A]).reverse
  }

  override def toString: String = {
    def loop(ln: LinkedNode[A], acc: String): String = ln match {
      case EmptyNode =>
        acc + "()"
      case Node(e, nx) =>
        loop(nx, acc + e + " -> " )
    }
    loop(self, "")
  }

  // ---- push / pop ----

  def push[B >: A](x: B): LinkedNode[B] = Node(x, self)

  def pop: (Option[A], LinkedNode[A]) = self match {
    case EmptyNode => (None, self)
    case Node(e, nx) => (Some(e), nx)
  }

  // ---- map / flatMap / fold ----

  def map[B](f: A => B): LinkedNode[B] = self match {
    case EmptyNode =>
      EmptyNode
    case Node(e, nx) =>
      Node(f(e), nx.map(f))
  }

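  def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = {
    // Prepend all nodes of `front`, in order, to `back`
    def concat(front: LinkedNode[B], back: LinkedNode[B]): LinkedNode[B] = front match {
      case EmptyNode =>
        back
      case Node(e2, nx2) =>
        Node(e2, concat(nx2, back))
    }
    self match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        // Splice in every node produced by f(e), not just the first one
        concat(f(e), nx.flatMap(f))
    }
  }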

  def foldLeft[B](z: B)(f: (B, A) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, f(acc, e))
    }
    loop(self, z)
  }

  def foldRight[B](z: B)(f: (A, B) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, f(e, acc))
    }
    loop(self.reverse, z)
  }
}

object LinkedNode {
  def apply[A](ls: List[A]): LinkedNode[A] =
    ls.foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
}

case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]

case object EmptyNode extends LinkedNode[Nothing]

A cursory test-run …

val ln = LinkedNode(List(1, 2, 2, 3, 4, 4, 5))
// ln: LinkedNode[Int] = 1 -> 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ()

val ln2 = ln.insertNode(9, 3)
// ln2: LinkedNode[Int] = 1 -> 2 -> 2 -> 9 -> 3 -> 4 -> 4 -> 5 -> ()

val ln3 = ln2.deleteNode(4)
// ln3: LinkedNode[Int] = 1 -> 2 -> 2 -> 9 -> 4 -> 4 -> 5 -> ()

ln.reverseK(3)
// res1: LinkedNode[Int] = 2 -> 2 -> 1 -> 4 -> 4 -> 3 -> 5 -> ()

ln.find(4)
// res2: LinkedNode[Int] = 4 -> 4 -> 5 -> ()

ln.indexLast(4)
// res3: Int = 5

// ---- Using LinkedNode like a `stack` ----

ln.push(9)
// res1: LinkedNode[Int] = 9 -> 1 -> 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ()

ln.pop
// res2: (Option[Int], LinkedNode[Int]) = (Some(1), 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ())

// ---- Using LinkedNode like a `Vector` collection ----

ln.map(_ + 10)
// res1: LinkedNode[Int] = 11 -> 12 -> 12 -> 13 -> 14 -> 14 -> 15 -> ()

ln.flatMap(i => Node(s"$i!", EmptyNode))
// res2: LinkedNode[String] = 1! -> 2! -> 2! -> 3! -> 4! -> 4! -> 5! -> ()

ln.foldLeft(0)(_ + _)
// res3: Int = 21

ln.foldRight("")((x, acc) => if (acc.isEmpty) s"$x" else s"$acc|$x")
// res4: String = 5|4|4|3|2|2|1
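The test run above leaves `get`, `indexOf` and `indexLast` unexercised. A self-contained check, with those three methods and the companion’s factory copied from the final ADT, could look like this:

```scala
sealed trait LinkedNode[+A] { self =>
  // Sub-list starting at position idx (EmptyNode if idx is past the end)
  def get(idx: Int): LinkedNode[A] = {
    def loop(ln: LinkedNode[A], count: Int): LinkedNode[A] = {
      if (count < idx) {
        ln match {
          case EmptyNode     => EmptyNode
          case Node(_, next) => loop(next, count + 1)
        }
      }
      else
        ln
    }
    loop(self, 0)
  }

  // Index of the first element equal to x, or -1
  def indexOf[B >: A](x: B): Int = {
    def loop(ln: LinkedNode[B], idx: Int): Int = ln match {
      case EmptyNode   => -1
      case Node(e, nx) => if (e == x) idx else loop(nx, idx + 1)
    }
    loop(self, 0)
  }

  // Index of the last element equal to x; the seed -1 stays at the head when nothing matches
  def indexLast[B >: A](x: B): Int = {
    def loop(ln: LinkedNode[B], idx: Int, acc: List[Int]): List[Int] = ln match {
      case EmptyNode   => acc
      case Node(e, nx) => loop(nx, idx + 1, if (e == x) idx :: acc else acc)
    }
    loop(self, 0, List(-1)).head
  }
}
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

object LinkedNode {
  def apply[A](ls: List[A]): LinkedNode[A] =
    ls.foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
}

val ln = LinkedNode(List(1, 2, 2, 3, 4, 4, 5))
```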

3 thoughts on “Implementing Linked List In Scala”

  1. luqmaan s

    In the implementation of the insertNode method for the LinkedNode class, what is the purpose of the condition if (i < idx) and why does it return EmptyNode in that case?

    1. Leo Cheung Post author

      Thanks for the comment and sorry for the late reply. Calling method insertNode(x, idx) represents a request to insert a node with value `x` as the (`idx+1`)-th node. For instance, insertNode(‘a’, 4) means the intention to insert a node with value ‘a’ as the 5th node. Now, if the original linked list has fewer than 4 nodes, it’s impossible to insert a 5th node. Thus, when the inner loop reaches the tail end (i.e. EmptyNode), `i` is still less than `idx`, so the loop returns EmptyNode at that point and the list is rebuilt without any change.
