In Scala, if you wonder why the standard library doesn’t come with a data structure called `LinkedList`, you may have overlooked the obvious: the `List` collection is in fact a linked list. It is just that `List` is usually used like any other `Seq` collection, such as `Vector`, rather than as the somewhat “mysterious” linked list that exposes only its “head”, with a hidden “tail” to be revealed iteratively.
Our ADT: LinkedNode
Perhaps because of its simplicity and dynamic nature as a data structure, implementing a linked list remains a popular coding exercise. To implement our own linked list, let’s start with a bare-bones ADT (algebraic data type) as follows:

    sealed trait LinkedNode[+A]

    case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]

    case object EmptyNode extends LinkedNode[Nothing]
If you’re familiar with Scala List, you’ll probably notice that our ADT resembles `List` and its subclasses `Cons` (i.e. `::`) and `Nil` (see source code):
    sealed abstract class List[+A] { ... }

    final case class ::[+A](override val head: A, var next: List[A]) extends List[A] { ... }

    case object Nil extends List[Nothing] { ... }
Expanding LinkedNode
Let’s expand trait `LinkedNode` to create class methods `insertNode`/`deleteNode` at a given index for inserting/deleting a node, `toList` for extracting contained elements into a `List` collection, and `toString` for display:
    sealed trait LinkedNode[+A] { self =>

      def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
        def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
          case EmptyNode =>
            if (i < idx) EmptyNode else Node(x, EmptyNode)
          case Node(e, nx) =>
            if (i < idx) Node(e, loop(nx, i + 1)) else Node(x, ln)
        }
        loop(self, 0)
      }

      def deleteNode(idx: Int): LinkedNode[A] = {
        def loop(ln: LinkedNode[A], i: Int): LinkedNode[A] = ln match {
          case EmptyNode => EmptyNode
          case Node(e, nx) =>
            if (idx == 0) nx
            else {
              if (i < idx - 1) Node(e, loop(nx, i + 1))
              else nx match {
                case EmptyNode => Node(e, EmptyNode)
                case Node(_, nx2) => Node(e, nx2)
              }
            }
        }
        loop(self, 0)
      }

      def toList: List[A] = {
        def loop(ln: LinkedNode[A], acc: List[A]): List[A] = ln match {
          case EmptyNode => acc
          case Node(e, nx) => loop(nx, e :: acc)
        }
        loop(self, List.empty[A]).reverse
      }

      override def toString: String = {
        def loop(ln: LinkedNode[A], acc: String): String = ln match {
          case EmptyNode => acc + "()"
          case Node(e, nx) => loop(nx, acc + e + " -> ")
        }
        loop(self, "")
      }
    }

    case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]

    case object EmptyNode extends LinkedNode[Nothing]

    // Test running ...

    val ln = List(1, 2, 3, 3, 4).foldRight[LinkedNode[Int]](EmptyNode)(Node(_, _))
    // ln: LinkedNode[Int] = 1 -> 2 -> 3 -> 3 -> 4 -> ()

    ln.insertNode(9, 2)
    // res1: LinkedNode[Int] = 1 -> 2 -> 9 -> 3 -> 3 -> 4 -> ()

    ln.deleteNode(1)
    // res2: LinkedNode[Int] = 1 -> 3 -> 3 -> 4 -> ()
Note that `LinkedNode` is made covariant in `A`. A covariant type parameter cannot appear in a method parameter position (Function1 is contravariant in its parameter type), so `insertNode` introduces its own type parameter `B` with `A` as its lower type bound (`B >: A`).
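To see the lower bound at work, here is a minimal sketch on a stripped-down copy of the ADT. The method `prepend`, the `Shape` hierarchy and the `VarianceDemo` object are hypothetical names made up for this illustration; they are not part of the ADT above.

```scala
sealed trait LinkedNode[+A] {
  // `def prepend(x: A)` would be rejected by the compiler: the covariant
  // type parameter A may not appear in a method-parameter position.
  // Hypothetical method, used here only to illustrate the lower bound.
  def prepend[B >: A](x: B): LinkedNode[B] = Node(x, this)
}
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

// Hypothetical element types for the demo
sealed trait Shape
case object Circle extends Shape
case object Square extends Shape

object VarianceDemo {
  // Covariance: a list of circles can be used where a list of shapes is expected.
  val circles: LinkedNode[Circle.type] = Node(Circle, EmptyNode)
  val shapes: LinkedNode[Shape] = circles

  // Lower bound: prepending a Square widens the element type to a common supertype.
  val mixed: LinkedNode[Shape] = circles.prepend(Square)
}
```

The same widening should apply to `insertNode`: inserting an element of a broader type yields a `LinkedNode` of that broader type instead of a compile error.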
Recursion and pattern matching
A couple of notes on how we implement the class methods:
- We use recursive functions to avoid mutable variables. They should be made tail-recursive for optimal performance (and to stay stack-safe on long lists), but that isn’t the focus of this implementation. If performance is a priority, conventional while-loops with mutable class fields `elem`/`next` would be a more practical option.
- Pattern matching is routinely used for handling cases of `Node` versus `EmptyNode`. An alternative approach would be to define fields `elem` and `next` in the base trait and implement class methods accordingly within `Node` and `EmptyNode`.
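As a rough idea of what a tail-recursive variant could look like, the sketch below rewrites insertion as a standalone function over a stripped-down copy of the ADT. The names `TailRecDemo`, `insertTailRec` and `split` are hypothetical. One assumption to note: unlike `insertNode`, this version appends the element when the index lies past the end of the list.

```scala
import scala.annotation.tailrec

sealed trait LinkedNode[+A]
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

object TailRecDemo {
  // Hypothetical helper, not part of the article's ADT.
  def insertTailRec[A](ln: LinkedNode[A], x: A, idx: Int): LinkedNode[A] = {
    // Phase 1: walk to position idx, stacking the passed-over elements in reverse.
    // @tailrec makes the compiler verify that the recursive call is in tail position.
    @tailrec
    def split(rest: LinkedNode[A], i: Int, seen: List[A]): (List[A], LinkedNode[A]) =
      rest match {
        case Node(e, nx) if i < idx => split(nx, i + 1, e :: seen)
        case _                      => (seen, rest)
      }

    val (seen, rest) = split(ln, 0, Nil)

    // Phase 2: put x in front of the remainder, then re-attach the stacked
    // elements one by one (foldLeft on List is a loop, not deep recursion).
    seen.foldLeft[LinkedNode[A]](Node(x, rest))((acc, e) => Node(e, acc))
  }
}
```

The trade-off is an intermediate `List` holding the elements in front of the insertion point, in exchange for constant stack depth.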
Finding first/last matching nodes
Next, we add a couple of class methods for finding first/last matching nodes.
    def find[B >: A](x: B): LinkedNode[B] = {
      def loop(ln: LinkedNode[B]): LinkedNode[B] = ln match {
        case EmptyNode => EmptyNode
        case Node(e, nx) => if (e == x) ln else loop(nx)
      }
      loop(self)
    }

    def findLast[B >: A](x: B): LinkedNode[B] = {
      def loop(ln: LinkedNode[B], acc: List[LinkedNode[B]]): List[LinkedNode[B]] = ln match {
        case EmptyNode => acc
        case Node(e, nx) => if (e == x) loop(nx, ln :: acc) else loop(nx, acc)
      }
      loop(self, List.empty[LinkedNode[B]]).headOption match {
        case None => EmptyNode
        case Some(node) => node
      }
    }
Reversing a linked list by groups of nodes
Reversing a given LinkedNode can be accomplished via recursion by cumulatively wrapping the element of each `Node` in a new `Node` with its `next` pointer set to the `Node` created in the previous iteration.
    def reverse: LinkedNode[A] = {
      def loop(ln: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] = ln match {
        case EmptyNode => acc
        case Node(e, nx) => loop(nx, Node(e, acc))
      }
      loop(self, EmptyNode)
    }

    def reverseK(k: Int): LinkedNode[A] = {
      if (k == 1) self
      else
        self.toList.grouped(k).flatMap(_.reverse).
          foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
    }
Method `reverseK` reverses a LinkedNode by groups of `k` elements using a different approach that extracts the elements into groups of `k` elements, reverses the elements in each of the groups and re-wraps each of the flattened elements in a `Node`.
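The intermediate steps are easier to follow on a plain `List`. The sketch below traces them for `k = 3`; the object name `ReverseKTrace` and its fields are made up for this illustration.

```scala
object ReverseKTrace {
  val elems: List[Int] = List(1, 2, 2, 3, 4, 4, 5)

  // Step 1: chunk the elements into groups of k = 3; the last group may be shorter.
  val groups: List[List[Int]] = elems.grouped(3).toList
  // List(List(1, 2, 2), List(3, 4, 4), List(5))

  // Step 2: reverse each group and flatten the result.
  val flipped: List[Int] = groups.flatMap(_.reverse)
  // List(2, 2, 1, 4, 4, 3, 5)
}
```

The final `foldRight` then wraps the flattened elements back into `Node`s from right to left. Keep in mind that `grouped` requires a positive group size, so a caller passing `k <= 0` should expect an exception unless a guard is added.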
Using LinkedNode as a Stack
For `LinkedNode` to serve as a Stack, we include simple methods `push` and `pop` as follows:
    def push[B >: A](x: B): LinkedNode[B] = Node(x, self)

    def pop: (Option[A], LinkedNode[A]) = self match {
      case EmptyNode => (None, self)
      case Node(e, nx) => (Some(e), nx)
    }
Addendum: implementing map, flatMap, fold
From a different perspective, if we view `LinkedNode` as a collection like a Scala List or Vector, we might crave methods like `map`, `flatMap` and `fold`. Using the same approach of recursion along with pattern matching, it’s rather straightforward to crank them out.

    def map[B](f: A => B): LinkedNode[B] = self match {
      case EmptyNode => EmptyNode
      case Node(e, nx) => Node(f(e), nx.map(f))
    }

    def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = self match {
      case EmptyNode => EmptyNode
      case Node(e, nx) =>
        // Prepend every element of f(e), not only the first one, to the rest
        f(e).foldRight[LinkedNode[B]](nx.flatMap(f))(Node(_, _))
    }

    def foldLeft[B](z: B)(f: (B, A) => B): B = {
      def loop(ln: LinkedNode[A], acc: B): B = ln match {
        case EmptyNode => acc
        case Node(e, nx) => loop(nx, f(acc, e))
      }
      loop(self, z)
    }

    def foldRight[B](z: B)(f: (A, B) => B): B = {
      def loop(ln: LinkedNode[A], acc: B): B = ln match {
        case EmptyNode => acc
        case Node(e, nx) => loop(nx, f(e, acc))
      }
      // Folding the reversed list from the left is equivalent to a right fold
      loop(self.reverse, z)
    }
    } // closes trait LinkedNode
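Having `map` and `flatMap` also means the type can be used in a `for` comprehension, which the compiler rewrites into calls to those two methods. The sketch below shows this on a stripped-down copy of the ADT, assuming a `flatMap` that keeps every element produced by `f`. The `append` helper and the `ForDemo` object are hypothetical names introduced only for this example.

```scala
sealed trait LinkedNode[+A] { self =>
  def map[B](f: A => B): LinkedNode[B] = self match {
    case EmptyNode   => EmptyNode
    case Node(e, nx) => Node(f(e), nx.map(f))
  }

  // Hypothetical helper, not in the article: concatenates two linked lists.
  def append[B >: A](that: LinkedNode[B]): LinkedNode[B] = self match {
    case EmptyNode   => that
    case Node(e, nx) => Node(e, nx.append(that))
  }

  // Keeps all elements of f(e) by appending the flat-mapped rest to them.
  def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = self match {
    case EmptyNode   => EmptyNode
    case Node(e, nx) => f(e).append(nx.flatMap(f))
  }
}
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

object ForDemo {
  val nums: LinkedNode[Int] = Node(1, Node(2, EmptyNode))
  val tags: LinkedNode[String] = Node("a", Node("b", EmptyNode))

  // Desugars to: nums.flatMap(n => tags.map(t => s"$n$t"))
  val pairs: LinkedNode[String] = for { n <- nums; t <- tags } yield s"$n$t"
}
```

Guards (`if` clauses) inside the comprehension would additionally need a `withFilter` method, which is left out of this sketch.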
Putting everything together
Along with a few additional simple class methods and a factory method wrapped in `LinkedNode`’s companion object, below is the final `LinkedNode` ADT that includes everything described above.
    sealed trait LinkedNode[+A] { self =>

      def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
        def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
          case EmptyNode =>
            if (i < idx) EmptyNode else Node(x, EmptyNode)
          case Node(e, nx) =>
            if (i < idx) Node(e, loop(nx, i + 1)) else Node(x, ln)
        }
        loop(self, 0)
      }

      def deleteNode(idx: Int): LinkedNode[A] = {
        def loop(ln: LinkedNode[A], i: Int): LinkedNode[A] = ln match {
          case EmptyNode => EmptyNode
          case Node(e, nx) =>
            if (idx == 0) nx
            else {
              if (i < idx - 1) Node(e, loop(nx, i + 1))
              else nx match {
                case EmptyNode => Node(e, EmptyNode)
                case Node(_, nx2) => Node(e, nx2)
              }
            }
        }
        loop(self, 0)
      }

      def get(idx: Int): LinkedNode[A] = {
        def loop(ln: LinkedNode[A], count: Int): LinkedNode[A] = {
          if (count < idx) {
            ln match {
              case EmptyNode => EmptyNode
              case Node(_, next) => loop(next, count + 1)
            }
          }
          else ln
        }
        loop(self, 0)
      }

      def find[B >: A](x: B): LinkedNode[B] = {
        def loop(ln: LinkedNode[B]): LinkedNode[B] = ln match {
          case EmptyNode => EmptyNode
          case Node(e, nx) => if (e == x) ln else loop(nx)
        }
        loop(self)
      }

      def findLast[B >: A](x: B): LinkedNode[B] = {
        def loop(ln: LinkedNode[B], acc: List[LinkedNode[B]]): List[LinkedNode[B]] = ln match {
          case EmptyNode => acc
          case Node(e, nx) => if (e == x) loop(nx, ln :: acc) else loop(nx, acc)
        }
        loop(self, List.empty[LinkedNode[B]]).headOption match {
          case None => EmptyNode
          case Some(node) => node
        }
      }

      def indexOf[B >: A](x: B): Int = {
        def loop(ln: LinkedNode[B], idx: Int): Int = ln match {
          case EmptyNode => -1
          case Node(e, nx) => if (e == x) idx else loop(nx, idx + 1)
        }
        loop(self, 0)
      }

      def indexLast[B >: A](x: B): Int = {
        def loop(ln: LinkedNode[B], idx: Int, acc: List[Int]): List[Int] = ln match {
          case EmptyNode => acc
          case Node(e, nx) => loop(nx, idx + 1, if (e == x) idx :: acc else acc)
        }
        loop(self, 0, List(-1)).head
      }

      def reverse: LinkedNode[A] = {
        def loop(ln: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] = ln match {
          case EmptyNode => acc
          case Node(e, nx) => loop(nx, Node(e, acc))
        }
        loop(self, EmptyNode)
      }

      def reverseK(k: Int): LinkedNode[A] = {
        if (k == 1) self
        else
          self.toList.grouped(k).flatMap(_.reverse).
            foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
      }

      def toList: List[A] = {
        def loop(ln: LinkedNode[A], acc: List[A]): List[A] = ln match {
          case EmptyNode => acc
          case Node(e, nx) => loop(nx, e :: acc)
        }
        loop(self, List.empty[A]).reverse
      }

      override def toString: String = {
        def loop(ln: LinkedNode[A], acc: String): String = ln match {
          case EmptyNode => acc + "()"
          case Node(e, nx) => loop(nx, acc + e + " -> ")
        }
        loop(self, "")
      }

      // ---- push / pop ----

      def push[B >: A](x: B): LinkedNode[B] = Node(x, self)

      def pop: (Option[A], LinkedNode[A]) = self match {
        case EmptyNode => (None, self)
        case Node(e, nx) => (Some(e), nx)
      }

      // ---- map / flatMap / fold ----

      def map[B](f: A => B): LinkedNode[B] = self match {
        case EmptyNode => EmptyNode
        case Node(e, nx) => Node(f(e), nx.map(f))
      }

      def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = self match {
        case EmptyNode => EmptyNode
        case Node(e, nx) =>
          // Prepend every element of f(e), not only the first one, to the rest
          f(e).foldRight[LinkedNode[B]](nx.flatMap(f))(Node(_, _))
      }

      def foldLeft[B](z: B)(f: (B, A) => B): B = {
        def loop(ln: LinkedNode[A], acc: B): B = ln match {
          case EmptyNode => acc
          case Node(e, nx) => loop(nx, f(acc, e))
        }
        loop(self, z)
      }

      def foldRight[B](z: B)(f: (A, B) => B): B = {
        def loop(ln: LinkedNode[A], acc: B): B = ln match {
          case EmptyNode => acc
          case Node(e, nx) => loop(nx, f(e, acc))
        }
        // Folding the reversed list from the left is equivalent to a right fold
        loop(self.reverse, z)
      }
    }

    object LinkedNode {
      def apply[A](ls: List[A]): LinkedNode[A] =
        ls.foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
    }

    case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]

    case object EmptyNode extends LinkedNode[Nothing]
A cursory test-run …
    val ln = LinkedNode(List(1, 2, 2, 3, 4, 4, 5))
    // ln: LinkedNode[Int] = 1 -> 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ()

    val ln2 = ln.insertNode(9, 3)
    // ln2: LinkedNode[Int] = 1 -> 2 -> 2 -> 9 -> 3 -> 4 -> 4 -> 5 -> ()

    val ln3 = ln2.deleteNode(4)
    // ln3: LinkedNode[Int] = 1 -> 2 -> 2 -> 9 -> 4 -> 4 -> 5 -> ()

    ln.reverseK(3)
    // res1: LinkedNode[Int] = 2 -> 2 -> 1 -> 4 -> 4 -> 3 -> 5 -> ()

    ln.find(4)
    // res2: LinkedNode[Int] = 4 -> 4 -> 5 -> ()

    ln.indexLast(4)
    // res3: Int = 5

    // ---- Using LinkedNode like a `stack` ----

    ln.push(9)
    // res1: LinkedNode[Int] = 9 -> 1 -> 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ()

    ln.pop
    // res2: (Option[Int], LinkedNode[Int]) = (Some(1), 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ())

    // ---- Using LinkedNode like a `Vector` collection ----

    ln.map(_ + 10)
    // res1: LinkedNode[Int] = 11 -> 12 -> 12 -> 13 -> 14 -> 14 -> 15 -> ()

    ln.flatMap(i => Node(s"$i!", EmptyNode))
    // res2: LinkedNode[String] = 1! -> 2! -> 2! -> 3! -> 4! -> 4! -> 5! -> ()

    ln.foldLeft(0)(_ + _)
    // res3: Int = 21

    ln.foldRight("")((x, acc) => if (acc.isEmpty) s"$x" else s"$acc|$x")
    // res4: String = 5|4|4|3|2|2|1