In Scala, if you wonder why its standard library doesn’t come with a data structure called `LinkedList`, you may have overlooked that the `List` collection is in fact a linked list. It just usually appears in the guise of a `Seq`- or `Vector`-like collection rather than as the generally “mysterious” linked list that exposes its “head” and reveals its hidden “tail” only iteratively.
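Here is a quick illustration of `List` behaving as a singly linked list. It uses the standard library only and runs as a standalone script or in the REPL. Pattern matching exposes the `head :: tail` shape, and prepending with `::` allocates one new cell that shares the existing list as its tail. The names `xs`, `ys` and `describe` are just illustrative.

```scala
// A Scala List is built from cons cells (::) terminated by Nil
val xs: List[Int] = 1 :: 2 :: 3 :: Nil

// Prepending allocates one new cell and reuses xs as its tail (structural sharing)
val ys = 0 :: xs

println(xs.head)        // 1
println(xs.tail)        // List(2, 3)
println(ys.tail eq xs)  // true: the very same object, nothing was copied

// Pattern matching exposes the same head/tail shape as a hand-rolled linked list
def describe(l: List[Int]): String = l match {
  case h :: t => s"head $h, then ${t.length} more"
  case Nil    => "empty"
}
println(describe(xs))   // head 1, then 2 more
```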
Our ADT: LinkedNode
Perhaps because of its simplicity and dynamic nature as a data structure, implementing a linked list remains a popular coding exercise. To implement our own linked list, let’s start with a bare-bones ADT (algebraic data type) as follows:
sealed trait LinkedNode[+A]
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]
If you’re familiar with Scala’s `List`, you’ll probably notice that our ADT resembles `List` and its subclasses `::` (“cons”) and `Nil` (excerpted from the Scala 2.13 standard library source):
sealed abstract class List[+A] { ... }
final case class ::[+A](override val head: A, private[scala] var next: List[A @uncheckedVariance]) extends List[A] { ... }
case object Nil extends List[Nothing] { ... }
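To make the resemblance concrete, the standalone script below maps each `List` constructor one-to-one onto its `LinkedNode` counterpart: `::` becomes `Node` and `Nil` becomes `EmptyNode`. The helper names `fromList` and `toScalaList` are hypothetical and not part of the ADT. Both helpers are plain structural recursion, so they are not tail-recursive.

```scala
sealed trait LinkedNode[+A]
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

// Hypothetical helper: List constructors map onto LinkedNode constructors
def fromList[A](l: List[A]): LinkedNode[A] = l match {
  case h :: t => Node(h, fromList(t)) // :: becomes Node
  case Nil    => EmptyNode            // Nil becomes EmptyNode
}

// Hypothetical helper: the inverse mapping
def toScalaList[A](ln: LinkedNode[A]): List[A] = ln match {
  case Node(e, nx) => e :: toScalaList(nx) // Node becomes ::
  case EmptyNode   => Nil                  // EmptyNode becomes Nil
}

// The same three-element list, built both ways
val fromCons = fromList(1 :: 2 :: 3 :: Nil)
val byHand   = Node(1, Node(2, Node(3, EmptyNode)))

println(fromCons == byHand)  // true: case classes compare structurally
println(toScalaList(byHand)) // List(1, 2, 3)
```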
Expanding LinkedNode
Let’s expand trait `LinkedNode` with methods `insertNode`/`deleteNode` for inserting/deleting a node at a given index, `toList` for extracting the contained elements into a `List` collection, and `toString` for display:
sealed trait LinkedNode[+A] { self =>
def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
case EmptyNode =>
if (i < idx)
EmptyNode
else
Node(x, EmptyNode)
case Node(e, nx) =>
if (i < idx)
Node(e, loop(nx, i + 1))
else
Node(x, ln)
}
loop(self, 0)
}
def deleteNode(idx: Int): LinkedNode[A] = {
def loop(ln: LinkedNode[A], i: Int): LinkedNode[A] = ln match {
case EmptyNode =>
EmptyNode
case Node(e, nx) =>
if (idx == 0)
nx
else {
if (i < idx - 1)
Node(e, loop(nx, i + 1))
else
nx match {
case EmptyNode => Node(e, EmptyNode)
case Node(_, nx2) => Node(e, nx2)
}
}
}
loop(self, 0)
}
def toList: List[A] = {
def loop(ln: LinkedNode[A], acc: List[A]): List[A] = ln match {
case EmptyNode =>
acc
case Node(e, nx) =>
loop(nx, e :: acc)
}
loop(self, List.empty[A]).reverse
}
override def toString: String = {
def loop(ln: LinkedNode[A], acc: String): String = ln match {
case EmptyNode =>
acc + "()"
case Node(e, nx) =>
loop(nx, acc + e + " -> " )
}
loop(self, "")
}
}
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]
// Test run ...
val ln = List(1, 2, 3, 3, 4).foldRight[LinkedNode[Int]](EmptyNode)(Node(_, _))
// ln: LinkedNode[Int] = 1 -> 2 -> 3 -> 3 -> 4 -> ()
ln.insertNode(9, 2)
// res1: LinkedNode[Int] = 1 -> 2 -> 9 -> 3 -> 3 -> 4 -> ()
ln.deleteNode(1)
// res2: LinkedNode[Int] = 1 -> 3 -> 3 -> 4 -> ()
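Note that `LinkedNode` is covariant in `A`. This is what lets the single `EmptyNode` (a `LinkedNode[Nothing]`) terminate a `LinkedNode[A]` of any element type. A method parameter sits in a contravariant position, just as `Function1` is contravariant in its parameter type. So a covariant `A` cannot appear directly as the type of `insertNode`’s parameter `x`. Instead, `insertNode` takes a type parameter `B` with `A` as its lower bound (`B >: A`) and returns a `LinkedNode[B]`.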
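This minimal standalone script shows what the lower bound buys in practice. It uses a copy of `insertNode` on a stripped-down ADT. The values `ints`, `mixed` and `stillInts` are just illustrative. Inserting an element of a different type widens the result type instead of failing to compile.

```scala
sealed trait LinkedNode[+A] { self =>
  // Copy of insertNode: B >: A keeps the covariant A out of the parameter list
  def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
      case EmptyNode   => if (i < idx) EmptyNode else Node(x, EmptyNode)
      case Node(e, nx) => if (i < idx) Node(e, loop(nx, i + 1)) else Node(x, ln)
    }
    loop(self, 0)
  }
  // By contrast, a signature such as
  //   def insertBad(x: A, idx: Int): LinkedNode[A]
  // is rejected by the compiler: covariant A would occur in a contravariant position
}
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

// Covariance: EmptyNode (a LinkedNode[Nothing]) fits wherever a LinkedNode[Int] is expected
val ints: LinkedNode[Int] = Node(1, Node(2, EmptyNode))

// Inserting a String widens the result to LinkedNode[Any] (B is inferred as Any)
val mixed: LinkedNode[Any] = ints.insertNode("two", 1)

// Inserting another Int keeps the result at LinkedNode[Int]
val stillInts: LinkedNode[Int] = ints.insertNode(9, 0)
```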
Recursion and pattern matching
A couple of notes on the approach used to implement the class methods:
- We use recursive functions to avoid mutable variables. Some of them (e.g. in `toList` and `toString`) are tail-recursive. Others (e.g. in `insertNode` and `deleteNode`) are not, and they could overflow the stack on very long lists. Making them tail-recursive would improve performance, but that isn’t the focus of this implementation. If performance is a priority, conventional while-loops with mutable class fields `elem`/`next` would be a more practical option.
- Pattern matching is used throughout to handle the `Node` versus `EmptyNode` cases. An alternative approach would be to declare fields `elem` and `next` in the base trait and implement the class methods separately within `Node` and `EmptyNode`.
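Here is a sketch of what a tail-recursive variant could look like. It is not the version used in this article, and the name `insertNodeTR` is hypothetical. The function walks the list once while pushing the visited prefix onto a reversed `List` accumulator. It then rebuilds that prefix on top of the new node. `@tailrec` asks the compiler to verify that both helper loops really are tail calls. It is written as a standalone function over a minimal copy of the ADT.

```scala
import scala.annotation.tailrec

sealed trait LinkedNode[+A]
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

def insertNodeTR[A](ln: LinkedNode[A], x: A, idx: Int): LinkedNode[A] = {
  // Prepend the elements of `prefix` (stored in reverse order) back onto `rest`
  @tailrec
  def rebuild(prefix: List[A], rest: LinkedNode[A]): LinkedNode[A] = prefix match {
    case h :: t => rebuild(t, Node(h, rest))
    case Nil    => rest
  }
  // Walk to position idx, remembering the visited elements in reverse order
  @tailrec
  def loop(cur: LinkedNode[A], i: Int, prefix: List[A]): LinkedNode[A] = cur match {
    case Node(e, nx) if i < idx => loop(nx, i + 1, e :: prefix)
    case EmptyNode if i < idx   => ln                            // idx past the end: unchanged
    case _                      => rebuild(prefix, Node(x, cur)) // splice x in at position i
  }
  loop(ln, 0, Nil)
}
```

Like the recursive `insertNode`, this appends when `idx` equals the length and leaves the list unchanged when `idx` is beyond it. Because both helpers are tail calls, it should cope with lists long enough to exhaust the stack in the non-tail-recursive version.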
Finding first/last matching nodes
Next, we add a couple of class methods for finding the first/last matching node. Each returns the sub-list starting at the matching node, or `EmptyNode` if there is no match.
def find[B >: A](x: B): LinkedNode[B] = {
def loop(ln: LinkedNode[B]): LinkedNode[B] = ln match {
case EmptyNode =>
EmptyNode
case Node(e, nx) =>
if (e == x) ln else loop(nx)
}
loop(self)
}
def findLast[B >: A](x: B): LinkedNode[B] = {
def loop(ln: LinkedNode[B], acc: List[LinkedNode[B]]): List[LinkedNode[B]] = ln match {
case EmptyNode =>
acc
case Node(e, nx) =>
if (e == x) loop(nx, ln :: acc) else loop(nx, acc)
}
loop(self, List.empty[LinkedNode[B]]).headOption match {
case None => EmptyNode
case Some(node) => node
}
}
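`findLast` above collects every matching node in a `List` and then takes its head. As an alternative, a function can remember only the most recent match, which keeps the extra memory constant. The sketch below does this with a hypothetical standalone function named `findLastConst`. Because it is not a method on the covariant trait, it can take `x: A` directly without the `B >: A` bound.

```scala
import scala.annotation.tailrec

sealed trait LinkedNode[+A]
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

// Keep only the latest matching node instead of accumulating all of them
def findLastConst[A](ln: LinkedNode[A], x: A): LinkedNode[A] = {
  @tailrec
  def loop(cur: LinkedNode[A], last: LinkedNode[A]): LinkedNode[A] = cur match {
    case EmptyNode   => last
    case Node(e, nx) => loop(nx, if (e == x) cur else last)
  }
  loop(ln, EmptyNode)
}

val ln: LinkedNode[Int] = Node(1, Node(4, Node(2, Node(4, Node(5, EmptyNode)))))
val lastFour = findLastConst(ln, 4) // Node(4,Node(5,EmptyNode)): starts at the last 4
val noMatch  = findLastConst(ln, 7) // EmptyNode
```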
Reversing a linked list by groups of nodes
Reversing a given `LinkedNode` can be accomplished recursively. Traverse it node by node, wrapping each element in a new `Node` whose `next` points to the `Node` created in the previous step (starting from `EmptyNode`).
def reverse: LinkedNode[A] = {
def loop(ln: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] = ln match {
case EmptyNode =>
acc
case Node(e, nx) =>
loop(nx, Node(e, acc))
}
loop(self, EmptyNode)
}
def reverseK(k: Int): LinkedNode[A] = {
if (k == 1)
self
else
self.toList.grouped(k).flatMap(_.reverse).
foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
}
Method `reverseK` reverses a `LinkedNode` in groups of `k` elements using a different approach. It extracts the elements into a `List` and splits them into groups of `k` with `grouped`. It then reverses each group, flattens the groups, and re-wraps each element in a `Node`.
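For comparison, here is a sketch of a group-wise reversal that stays within `LinkedNode` and never converts to a `List`. This is not the article's method. It is written as a standalone function over a minimal ADT, and `fromList`, `drop` and `revOnto` are hypothetical helper names. Each step skips ahead `k` nodes and recursively processes the remainder. It then prepends the current group’s elements onto that result one by one, which reverses them in place.

```scala
import scala.annotation.tailrec

sealed trait LinkedNode[+A]
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

// Hypothetical helper mirroring the companion object's apply
def fromList[A](xs: List[A]): LinkedNode[A] =
  xs.foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))

def reverseK[A](ln: LinkedNode[A], k: Int): LinkedNode[A] = {
  require(k > 0, "k must be positive")

  // Skip up to n nodes and return whatever follows them
  @tailrec
  def drop(l: LinkedNode[A], n: Int): LinkedNode[A] = l match {
    case Node(_, nx) if n > 0 => drop(nx, n - 1)
    case _                    => l
  }

  // Prepend the first n elements of l onto acc one at a time, which reverses them
  @tailrec
  def revOnto(l: LinkedNode[A], n: Int, acc: LinkedNode[A]): LinkedNode[A] = l match {
    case Node(e, nx) if n > 0 => revOnto(nx, n - 1, Node(e, acc))
    case _                    => acc
  }

  ln match {
    case EmptyNode => EmptyNode
    // Reverse the first group onto the already-processed remainder of the list
    case _         => revOnto(ln, k, reverseK(drop(ln, k), k))
  }
}

// Same elements as the grouped-based result: 2, 2, 1, 4, 4, 3, 5
val grouped3 = reverseK(fromList(List(1, 2, 2, 3, 4, 4, 5)), 3)
```

`drop` and `revOnto` are tail-recursive. `reverseK` itself recurses once per group, so its stack depth grows with the number of groups rather than staying constant.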
Using LinkedNode as a Stack
For `LinkedNode` to serve as a Stack, we include simple methods `push` and `pop` as follows:
def push[B >: A](x: B): LinkedNode[B] = Node(x, self)
def pop: (Option[A], LinkedNode[A]) = self match {
case EmptyNode => (None, self)
case Node(e, nx) => (Some(e), nx)
}
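The hypothetical function `isBalanced` below puts `push` and `pop` to work. It checks whether the brackets in a string are balanced, using a `LinkedNode[Char]` purely as an immutable stack. It is a standalone script with a minimal copy of the ADT, and only the `push`/`pop` methods come from the text above.

```scala
import scala.annotation.tailrec

sealed trait LinkedNode[+A] { self =>
  def push[B >: A](x: B): LinkedNode[B] = Node(x, self)
  def pop: (Option[A], LinkedNode[A]) = self match {
    case EmptyNode   => (None, self)
    case Node(e, nx) => (Some(e), nx)
  }
}
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

// Hypothetical example: bracket matching with LinkedNode as a stack
def isBalanced(s: String): Boolean = {
  val closerToOpener = Map(')' -> '(', ']' -> '[', '}' -> '{')
  val openers = closerToOpener.values.toSet

  @tailrec
  def loop(i: Int, stack: LinkedNode[Char]): Boolean =
    if (i == s.length) stack == EmptyNode        // balanced only if nothing is left open
    else {
      val c = s.charAt(i)
      if (openers(c)) loop(i + 1, stack.push(c)) // opener: push it
      else if (closerToOpener.contains(c))       // closer: must match the top of the stack
        stack.pop match {
          case (Some(top), rest) if top == closerToOpener(c) => loop(i + 1, rest)
          case _                                              => false
        }
      else loop(i + 1, stack)                    // any other character is ignored
    }

  loop(0, EmptyNode)
}

println(isBalanced("{[()]}")) // true
println(isBalanced("([)]"))   // false: ')' arrives while '[' is on top
```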
Addendum: implementing map, flatMap, fold
From a different perspective, if we view `LinkedNode` as a collection like a Scala `List` or `Vector`, we might crave methods like `map`, `flatMap` and `fold`. Using the same approach of recursion along with pattern matching, it’s rather straightforward to crank them out.
def map[B](f: A => B): LinkedNode[B] = self match {
case EmptyNode =>
EmptyNode
case Node(e, nx) =>
Node(f(e), nx.map(f))
}
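def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = {
// Prepend every node of `front` onto `back`, preserving order, so no element of f(e) is dropped
def concat(front: LinkedNode[B], back: LinkedNode[B]): LinkedNode[B] = front match {
case EmptyNode => back
case Node(e, nx) => Node(e, concat(nx, back))
}
self match {
case EmptyNode =>
EmptyNode
case Node(e, nx) =>
concat(f(e), nx.flatMap(f))
}
}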
def foldLeft[B](z: B)(f: (B, A) => B): B = {
def loop(ln: LinkedNode[A], acc: B): B = ln match {
case EmptyNode =>
acc
case Node(e, nx) =>
loop(nx, f(acc, e))
}
loop(self, z)
}
def foldRight[B](z: B)(f: (A, B) => B): B = {
def loop(ln: LinkedNode[A], acc: B): B = ln match {
case EmptyNode =>
acc
case Node(e, nx) =>
loop(nx, f(e, acc))
}
loop(self.reverse, z)
}
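As a variation on the pattern-matching versions, `map` and `flatMap` can also be derived from a stack-safe `foldRight`. This standalone sketch does that with a minimal ADT, and the `++` operator it uses is a hypothetical helper. It demonstrates two things: `flatMap` keeps every element that `f` produces, and having `map` plus `flatMap` is enough for a guard-free `for`-comprehension.

```scala
import scala.annotation.tailrec

sealed trait LinkedNode[+A] { self =>
  // Right fold made stack-safe by reversing first, as in the article's foldRight
  def foldRight[B](z: B)(f: (A, B) => B): B = {
    @tailrec
    def rev(ln: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] = ln match {
      case EmptyNode   => acc
      case Node(e, nx) => rev(nx, Node(e, acc))
    }
    @tailrec
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode   => acc
      case Node(e, nx) => loop(nx, f(e, acc))
    }
    loop(rev(self, EmptyNode), z)
  }

  // Hypothetical helper: rebuilds this list's nodes in front of `that`
  def ++[B >: A](that: LinkedNode[B]): LinkedNode[B] =
    foldRight(that)((e, acc) => Node(e, acc))

  def map[B](f: A => B): LinkedNode[B] =
    foldRight[LinkedNode[B]](EmptyNode)((e, acc) => Node(f(e), acc))

  // Each element contributes every node of f(e), not just the first one
  def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] =
    foldRight[LinkedNode[B]](EmptyNode)((e, acc) => f(e) ++ acc)
}
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

val xs: LinkedNode[Int] = Node(1, Node(2, Node(3, EmptyNode)))

// Each element expands into two; all six survive: 1, 10, 2, 20, 3, 30
val expanded = xs.flatMap(i => Node(i, Node(i * 10, EmptyNode)))

// map + flatMap are all a for-comprehension without guards needs
val tagged = for {
  i <- xs
  s <- Node("a", Node("b", EmptyNode))
} yield s"$s$i"
// tagged holds "a1", "b1", "a2", "b2", "a3", "b3"
```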
Putting everything together
Along with a few additional simple class methods (`get`, `indexOf`, `indexLast`) and a factory method in `LinkedNode`’s companion object, below is the final `LinkedNode` ADT that includes everything described above.
sealed trait LinkedNode[+A] { self =>
def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
case EmptyNode =>
if (i < idx)
EmptyNode
else
Node(x, EmptyNode)
case Node(e, nx) =>
if (i < idx)
Node(e, loop(nx, i + 1))
else
Node(x, ln)
}
loop(self, 0)
}
def deleteNode(idx: Int): LinkedNode[A] = {
def loop(ln: LinkedNode[A], i: Int): LinkedNode[A] = ln match {
case EmptyNode =>
EmptyNode
case Node(e, nx) =>
if (idx == 0)
nx
else {
if (i < idx - 1)
Node(e, loop(nx, i + 1))
else
nx match {
case EmptyNode => Node(e, EmptyNode)
case Node(_, nx2) => Node(e, nx2)
}
}
}
loop(self, 0)
}
def get(idx: Int): LinkedNode[A] = {
def loop(ln: LinkedNode[A], count: Int): LinkedNode[A] = {
if (count < idx) {
ln match {
case EmptyNode => EmptyNode
case Node(_, next) => loop(next, count + 1)
}
}
else
ln
}
loop(self, 0)
}
def find[B >: A](x: B): LinkedNode[B] = {
def loop(ln: LinkedNode[B]): LinkedNode[B] = ln match {
case EmptyNode =>
EmptyNode
case Node(e, nx) =>
if (e == x) ln else loop(nx)
}
loop(self)
}
def findLast[B >: A](x: B): LinkedNode[B] = {
def loop(ln: LinkedNode[B], acc: List[LinkedNode[B]]): List[LinkedNode[B]] = ln match {
case EmptyNode =>
acc
case Node(e, nx) =>
if (e == x) loop(nx, ln :: acc) else loop(nx, acc)
}
loop(self, List.empty[LinkedNode[B]]).headOption match {
case None => EmptyNode
case Some(node) => node
}
}
def indexOf[B >: A](x: B): Int = {
def loop(ln: LinkedNode[B], idx: Int): Int = ln match {
case EmptyNode =>
-1
case Node(e, nx) =>
if (e == x) idx else loop(nx, idx + 1)
}
loop(self, 0)
}
def indexLast[B >: A](x: B): Int = {
def loop(ln: LinkedNode[B], idx: Int, acc: List[Int]): List[Int] = ln match {
case EmptyNode =>
acc
case Node(e, nx) =>
loop(nx, idx + 1, if (e == x) idx :: acc else acc)
}
loop(self, 0, List(-1)).head
}
def reverse: LinkedNode[A] = {
def loop(ln: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] = ln match {
case EmptyNode =>
acc
case Node(e, nx) =>
loop(nx, Node(e, acc))
}
loop(self, EmptyNode)
}
def reverseK(k: Int): LinkedNode[A] = {
if (k == 1)
self
else
self.toList.grouped(k).flatMap(_.reverse).
foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
}
def toList: List[A] = {
def loop(ln: LinkedNode[A], acc: List[A]): List[A] = ln match {
case EmptyNode =>
acc
case Node(e, nx) =>
loop(nx, e :: acc)
}
loop(self, List.empty[A]).reverse
}
override def toString: String = {
def loop(ln: LinkedNode[A], acc: String): String = ln match {
case EmptyNode =>
acc + "()"
case Node(e, nx) =>
loop(nx, acc + e + " -> " )
}
loop(self, "")
}
// ---- push / pop ----
def push[B >: A](x: B): LinkedNode[B] = Node(x, self)
def pop: (Option[A], LinkedNode[A]) = self match {
case EmptyNode => (None, self)
case Node(e, nx) => (Some(e), nx)
}
// ---- map / flatMap / fold ----
def map[B](f: A => B): LinkedNode[B] = self match {
case EmptyNode =>
EmptyNode
case Node(e, nx) =>
Node(f(e), nx.map(f))
}
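def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = {
// Prepend every node of `front` onto `back`, preserving order, so no element of f(e) is dropped
def concat(front: LinkedNode[B], back: LinkedNode[B]): LinkedNode[B] = front match {
case EmptyNode => back
case Node(e, nx) => Node(e, concat(nx, back))
}
self match {
case EmptyNode =>
EmptyNode
case Node(e, nx) =>
concat(f(e), nx.flatMap(f))
}
}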
def foldLeft[B](z: B)(f: (B, A) => B): B = {
def loop(ln: LinkedNode[A], acc: B): B = ln match {
case EmptyNode =>
acc
case Node(e, nx) =>
loop(nx, f(acc, e))
}
loop(self, z)
}
def foldRight[B](z: B)(f: (A, B) => B): B = {
def loop(ln: LinkedNode[A], acc: B): B = ln match {
case EmptyNode =>
acc
case Node(e, nx) =>
loop(nx, f(e, acc))
}
loop(self.reverse, z)
}
}
object LinkedNode {
def apply[A](ls: List[A]): LinkedNode[A] =
ls.foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
}
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]
A cursory test run …
val ln = LinkedNode(List(1, 2, 2, 3, 4, 4, 5))
// ln: LinkedNode[Int] = 1 -> 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ()
val ln2 = ln.insertNode(9, 3)
// ln2: LinkedNode[Int] = 1 -> 2 -> 2 -> 9 -> 3 -> 4 -> 4 -> 5 -> ()
val ln3 = ln2.deleteNode(4)
// ln3: LinkedNode[Int] = 1 -> 2 -> 2 -> 9 -> 4 -> 4 -> 5 -> ()
ln.reverseK(3)
// res1: LinkedNode[Int] = 2 -> 2 -> 1 -> 4 -> 4 -> 3 -> 5 -> ()
ln.find(4)
// res2: LinkedNode[Int] = 4 -> 4 -> 5 -> ()
ln.indexLast(4)
// res3: Int = 5
// ---- Using LinkedNode like a `stack` ----
ln.push(9)
// res1: LinkedNode[Int] = 9 -> 1 -> 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ()
ln.pop
// res2: (Option[Int], LinkedNode[Int]) = (Some(1), 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ())
// ---- Using LinkedNode like a `Vector` collection ----
ln.map(_ + 10)
// res1: LinkedNode[Int] = 11 -> 12 -> 12 -> 13 -> 14 -> 14 -> 15 -> ()
ln.flatMap(i => Node(s"$i!", EmptyNode))
// res2: LinkedNode[String] = 1! -> 2! -> 2! -> 3! -> 4! -> 4! -> 5! -> ()
ln.foldLeft(0)(_ + _)
// res3: Int = 21
ln.foldRight("")((x, acc) => if (acc.isEmpty) s"$x" else s"$acc|$x")
// res4: String = 5|4|4|3|2|2|1
