If you’ve ever wondered why Scala’s standard library doesn’t come with a data structure called LinkedList, you may have overlooked the obvious: the collection List is in fact a linked list. It just tends to be used like any other sequence (a Seq, or a Vector) rather than as the generally “mysterious” linked list that exposes its “head” with a hidden “tail” to be revealed only iteratively.
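To make the head/tail structure concrete, here is a minimal sketch that uses only the standard List API. The value names and the helper cellCount are made up for illustration.

```scala
val xs = List(1, 2, 3)

// A non-empty List is a head element plus a tail List
val first = xs.head    // 1
val others = xs.tail   // List(2, 3)

// Prepending allocates one new cell whose tail is the existing list
val ys = 0 :: xs

// Walking the list cell by cell with the :: and Nil patterns
// (plain recursion, fine for short lists)
def cellCount[A](list: List[A]): Int = list match {
  case Nil       => 0
  case _ :: rest => 1 + cellCount(rest)
}
```

Because ys simply points at xs as its tail, prepending is a constant-time operation, which is the same property our own linked list will rely on.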
Our ADT: LinkedNode
Perhaps because of its simplicity and dynamic nature as a data structure, implementing a linked list remains a popular coding exercise. To implement our own linked list, let’s start with a bare-bones ADT (algebraic data type) as follows:
    sealed trait LinkedNode[+A]
    case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
    case object EmptyNode extends LinkedNode[Nothing]
If you’re familiar with Scala List, you’ll probably notice that our ADT resembles List and its subclasses Cons (i.e. ::) and Nil (see source code):
    sealed abstract class List[+A] { ... }
    final case class ::[+A](override val head: A, var next: List[A]) extends List[A] { ... }
    case object Nil extends List[Nothing] { ... }
Expanding LinkedNode
Let’s expand trait LinkedNode to create class methods insertNode/deleteNode at a given index for inserting/deleting a node, toList for extracting contained elements into a List collection, and toString for display:
    sealed trait LinkedNode[+A] { self =>

      // Insert `x` so that it becomes the node at index `idx`;
      // the list is returned unchanged if `idx` lies beyond its end
      def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
        def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
          case EmptyNode =>
            if (i < idx) EmptyNode else Node(x, EmptyNode)
          case Node(e, nx) =>
            if (i < idx) Node(e, loop(nx, i + 1)) else Node(x, ln)
        }
        loop(self, 0)
      }

      // Delete the node at index `idx` (assumed non-negative), if there is one
      def deleteNode(idx: Int): LinkedNode[A] = {
        def loop(ln: LinkedNode[A], i: Int): LinkedNode[A] = ln match {
          case EmptyNode => EmptyNode
          case Node(e, nx) =>
            if (idx == 0) nx
            else {
              if (i < idx - 1) Node(e, loop(nx, i + 1))
              else nx match {
                case EmptyNode => Node(e, EmptyNode)
                case Node(_, nx2) => Node(e, nx2)
              }
            }
        }
        loop(self, 0)
      }

      // Collect the elements in order; the accumulator is built
      // back to front, hence the final reverse
      def toList: List[A] = {
        def loop(ln: LinkedNode[A], acc: List[A]): List[A] = ln match {
          case EmptyNode => acc
          case Node(e, nx) => loop(nx, e :: acc)
        }
        loop(self, List.empty[A]).reverse
      }

      override def toString: String = {
        def loop(ln: LinkedNode[A], acc: String): String = ln match {
          case EmptyNode => acc + "()"
          case Node(e, nx) => loop(nx, acc + e + " -> ")
        }
        loop(self, "")
      }
    }

    case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
    case object EmptyNode extends LinkedNode[Nothing]

    // Test running ...

    val ln = List(1, 2, 3, 3, 4).foldRight[LinkedNode[Int]](EmptyNode)(Node(_, _))
    // ln: LinkedNode[Int] = 1 -> 2 -> 3 -> 3 -> 4 -> ()

    ln.insertNode(9, 2)
    // res1: LinkedNode[Int] = 1 -> 2 -> 9 -> 3 -> 3 -> 4 -> ()

    ln.deleteNode(1)
    // res2: LinkedNode[Int] = 1 -> 3 -> 3 -> 4 -> ()
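Note that LinkedNode is made covariant. In addition, method insertNode takes a type parameter B with type A as its lower type bound (i.e. B >: A). A covariant type parameter cannot be used directly as the type of a method parameter, because Function1 is contravariant over its parameter type.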
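As a small illustration of how the lower type bound plays out, the sketch below uses a stripped-down copy of the ADT with a single method named prepend (a name chosen here for the example, not part of the implementation above).

```scala
sealed trait LinkedNode[+A] { self =>
  // `def prepend(x: A)` would be rejected by the compiler, because the
  // covariant type A would occur in a contravariant (parameter) position.
  // A lower-bounded type parameter B >: A sidesteps that.
  def prepend[B >: A](x: B): LinkedNode[B] = Node(x, self)
}
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

val ints: LinkedNode[Int] = Node(1, EmptyNode)

// Same element type: B is Int and the result stays a LinkedNode[Int]
val moreInts: LinkedNode[Int] = ints.prepend(0)

// Wider element type: B widens to Any, so the result is a LinkedNode[Any]
val mixed: LinkedNode[Any] = ints.prepend("zero")
```

The original list is untouched in both cases; only the static element type of the new list changes.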
Recursion and pattern matching
A couple of notes on the approach we implement our class methods with:
- We use recursive functions to avoid using mutable variables. Some of them (e.g. the loops in insertNode and deleteNode) are not tail-recursive and could therefore overflow the stack on very long lists, but that isn’t the focus of this implementation. If performance is a priority, using conventional while-loops with mutable class fields elem/next would be a more practical option.
- Pattern matching is routinely used for handling cases of Node versus EmptyNode. An alternative approach would be to define fields elem and next in the base trait and implement class methods accordingly within Node and EmptyNode.
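For readers curious what a tail-recursive variant might look like, here is one possible sketch of index-based insertion written as a standalone function. The names insertAt, walk and rebuild are hypothetical, and this is an alternative technique (collect the visited prefix, then rebuild it), not the method used in this post.

```scala
import scala.annotation.tailrec

sealed trait LinkedNode[+A]
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

def insertAt[A](ln: LinkedNode[A], x: A, idx: Int): LinkedNode[A] = {
  // Re-attach the visited prefix (held in reverse order) in front of `acc`
  @tailrec
  def rebuild(seen: List[A], acc: LinkedNode[A]): LinkedNode[A] = seen match {
    case Nil       => acc
    case e :: more => rebuild(more, Node(e, acc))
  }

  // Walk forward, remembering visited elements, until index `idx` is reached
  @tailrec
  def walk(rest: LinkedNode[A], i: Int, seen: List[A]): LinkedNode[A] = rest match {
    case Node(e, nx) if i < idx => walk(nx, i + 1, e :: seen)
    case EmptyNode if i < idx   => ln // index beyond the end: no change
    case _                      => rebuild(seen, Node(x, rest))
  }

  walk(ln, 0, Nil)
}

val sample: LinkedNode[Int] = Node(1, Node(2, Node(3, EmptyNode)))
val inserted = insertAt(sample, 9, 1)
```

The trade-off is an extra pass over the prefix in exchange for constant stack depth; the @tailrec annotation makes the compiler verify that both loops really are tail calls.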
Finding first/last matching nodes
Next, we add a couple of class methods for finding first/last matching nodes.
    def find[B >: A](x: B): LinkedNode[B] = {
      def loop(ln: LinkedNode[B]): LinkedNode[B] = ln match {
        case EmptyNode => EmptyNode
        case Node(e, nx) =>
          if (e == x) ln else loop(nx)
      }
      loop(self)
    }

    def findLast[B >: A](x: B): LinkedNode[B] = {
      def loop(ln: LinkedNode[B], acc: List[LinkedNode[B]]): List[LinkedNode[B]] = ln match {
        case EmptyNode => acc
        case Node(e, nx) =>
          if (e == x) loop(nx, ln :: acc) else loop(nx, acc)
      }
      loop(self, List.empty[LinkedNode[B]]).headOption match {
        case None => EmptyNode
        case Some(node) => node
      }
    }
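Since only the most recent match is ever returned, a variation on findLast could carry just that one node instead of a list of all matches. The sketch below is a standalone function with the hypothetical name findLastNode, shown purely as an alternative to the version above.

```scala
import scala.annotation.tailrec

sealed trait LinkedNode[+A]
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

def findLastNode[A](ln: LinkedNode[A], x: A): LinkedNode[A] = {
  // `last` holds the latest matching node seen so far (EmptyNode if none)
  @tailrec
  def loop(rest: LinkedNode[A], last: LinkedNode[A]): LinkedNode[A] = rest match {
    case EmptyNode   => last
    case Node(e, nx) => loop(nx, if (e == x) rest else last)
  }
  loop(ln, EmptyNode)
}

val nums: LinkedNode[Int] =
  List(1, 2, 2, 3, 4, 4, 5).foldRight[LinkedNode[Int]](EmptyNode)(Node(_, _))

val lastFour = findLastNode(nums, 4)  // the sub-list starting at the second 4
val missing  = findLastNode(nums, 9)  // EmptyNode
```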
Reversing a linked list by groups of nodes
Reversing a given LinkedNode can be accomplished via recursion by cumulatively wrapping the element of each Node in a new Node with its next pointer set to the Node created in the previous iteration.
    def reverse: LinkedNode[A] = {
      def loop(ln: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] = ln match {
        case EmptyNode => acc
        case Node(e, nx) => loop(nx, Node(e, acc))
      }
      loop(self, EmptyNode)
    }

    def reverseK(k: Int): LinkedNode[A] = {
      if (k == 1)
        self
      else
        self.toList.grouped(k).flatMap(_.reverse).
          foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
    }
Method reverseK reverses a LinkedNode by groups of k elements using a different approach: it extracts the elements into groups of k elements, reverses the elements in each of the groups, and re-wraps each of the flattened elements in a Node.
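The intermediate steps of that pipeline can be traced on a plain List with standard library calls only. The value names below are made up for this walkthrough.

```scala
val elems = List(1, 2, 2, 3, 4, 4, 5)

// Step 1: split into groups of k = 3 (the last group may be shorter)
val groups = elems.grouped(3).toList
// List(List(1, 2, 2), List(3, 4, 4), List(5))

// Step 2: reverse each group and flatten the result
val flipped = groups.flatMap(_.reverse)
// List(2, 2, 1, 4, 4, 3, 5)
```

Folding flipped from the right into Node(_, _) then yields the re-linked list.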
Using LinkedNode as a Stack
For LinkedNode to serve as a Stack, we include simple methods push and pop as follows:
    def push[B >: A](x: B): LinkedNode[B] = Node(x, self)

    def pop: (Option[A], LinkedNode[A]) = self match {
      case EmptyNode => (None, self)
      case Node(e, nx) => (Some(e), nx)
    }
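A short usage sketch may help show the last-in, first-out behavior. It repeats only the two methods above on a stripped-down copy of the ADT; the value names are illustrative.

```scala
sealed trait LinkedNode[+A] { self =>
  def push[B >: A](x: B): LinkedNode[B] = Node(x, self)

  def pop: (Option[A], LinkedNode[A]) = self match {
    case EmptyNode   => (None, self)
    case Node(e, nx) => (Some(e), nx)
  }
}
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

val empty: LinkedNode[Int] = EmptyNode

// Each push returns a new stack; the previous one is left untouched
val stack = empty.push(1).push(2).push(3)

// pop hands back the top element (if any) together with the remaining stack
val popped = stack.pop   // top is Some(3), the rest holds 2 then 1
val nothing = empty.pop  // (None, EmptyNode)
```

Returning the remaining stack alongside the popped element is what keeps the structure immutable: callers thread the new stack forward instead of mutating a shared one.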
Addendum: implementing map, flatMap, fold
From a different perspective, if we view LinkedNode as a collection like a Scala List or Vector, we might be craving methods like map, flatMap, fold. Using the same approach of recursion along with pattern matching, it’s rather straightforward to crank them out.
      def map[B](f: A => B): LinkedNode[B] = self match {
        case EmptyNode => EmptyNode
        case Node(e, nx) => Node(f(e), nx.map(f))
      }

      def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = self match {
        case EmptyNode => EmptyNode
        case Node(e, nx) =>
          // Prepend every node produced by f(e), not just its first one,
          // onto the flat-mapped remainder
          f(e).foldRight[LinkedNode[B]](nx.flatMap(f))(Node(_, _))
      }

      def foldLeft[B](z: B)(f: (B, A) => B): B = {
        def loop(ln: LinkedNode[A], acc: B): B = ln match {
          case EmptyNode => acc
          case Node(e, nx) => loop(nx, f(acc, e))
        }
        loop(self, z)
      }

      // Folding from the right is done as a left-to-right pass
      // over the reversed list
      def foldRight[B](z: B)(f: (A, B) => B): B = {
        def loop(ln: LinkedNode[A], acc: B): B = ln match {
          case EmptyNode => acc
          case Node(e, nx) => loop(nx, f(e, acc))
        }
        loop(self.reverse, z)
      }
    }
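One consequence of having map and flatMap is that a LinkedNode can be used in a for-comprehension. The sketch below assumes a flatMap that keeps every element produced by the function, and implements it with a helper named append that is hypothetical and not part of the code above.

```scala
sealed trait LinkedNode[+A] { self =>
  // Hypothetical helper: concatenate `that` onto the end of this list
  def append[B >: A](that: LinkedNode[B]): LinkedNode[B] = self match {
    case EmptyNode   => that
    case Node(e, nx) => Node(e, nx.append(that))
  }

  def map[B](f: A => B): LinkedNode[B] = self match {
    case EmptyNode   => EmptyNode
    case Node(e, nx) => Node(f(e), nx.map(f))
  }

  // Keeps every element of f(e), followed by the flat-mapped remainder
  def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = self match {
    case EmptyNode   => EmptyNode
    case Node(e, nx) => f(e).append(nx.flatMap(f))
  }
}
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

val digits: LinkedNode[Int] = Node(1, Node(2, EmptyNode))
val letters: LinkedNode[String] = Node("x", Node("y", EmptyNode))

// Desugars to digits.flatMap(d => letters.map(l => s"$d$l"))
val pairs = for {
  d <- digits
  l <- letters
} yield s"$d$l"
```

Note that these plain recursive versions are not stack-safe on very long lists; they are meant to show the shape of the operations rather than a production implementation.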
Putting everything together
Along with a few additional simple class methods and a factory method wrapped in LinkedNode’s companion object, below is the final LinkedNode ADT that includes everything described above.
    sealed trait LinkedNode[+A] { self =>

      // Insert `x` so that it becomes the node at index `idx`;
      // the list is returned unchanged if `idx` lies beyond its end
      def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
        def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
          case EmptyNode =>
            if (i < idx) EmptyNode else Node(x, EmptyNode)
          case Node(e, nx) =>
            if (i < idx) Node(e, loop(nx, i + 1)) else Node(x, ln)
        }
        loop(self, 0)
      }

      // Delete the node at index `idx` (assumed non-negative), if there is one
      def deleteNode(idx: Int): LinkedNode[A] = {
        def loop(ln: LinkedNode[A], i: Int): LinkedNode[A] = ln match {
          case EmptyNode => EmptyNode
          case Node(e, nx) =>
            if (idx == 0) nx
            else {
              if (i < idx - 1) Node(e, loop(nx, i + 1))
              else nx match {
                case EmptyNode => Node(e, EmptyNode)
                case Node(_, nx2) => Node(e, nx2)
              }
            }
        }
        loop(self, 0)
      }

      // Sub-list starting at index `idx`, or EmptyNode if out of range
      def get(idx: Int): LinkedNode[A] = {
        def loop(ln: LinkedNode[A], count: Int): LinkedNode[A] = {
          if (count < idx) {
            ln match {
              case EmptyNode => EmptyNode
              case Node(_, next) => loop(next, count + 1)
            }
          }
          else
            ln
        }
        loop(self, 0)
      }

      // Sub-list starting at the first node holding `x`
      def find[B >: A](x: B): LinkedNode[B] = {
        def loop(ln: LinkedNode[B]): LinkedNode[B] = ln match {
          case EmptyNode => EmptyNode
          case Node(e, nx) =>
            if (e == x) ln else loop(nx)
        }
        loop(self)
      }

      // Sub-list starting at the last node holding `x`
      def findLast[B >: A](x: B): LinkedNode[B] = {
        def loop(ln: LinkedNode[B], acc: List[LinkedNode[B]]): List[LinkedNode[B]] = ln match {
          case EmptyNode => acc
          case Node(e, nx) =>
            if (e == x) loop(nx, ln :: acc) else loop(nx, acc)
        }
        loop(self, List.empty[LinkedNode[B]]).headOption match {
          case None => EmptyNode
          case Some(node) => node
        }
      }

      // Index of the first node holding `x`, or -1 if absent
      def indexOf[B >: A](x: B): Int = {
        def loop(ln: LinkedNode[B], idx: Int): Int = ln match {
          case EmptyNode => -1
          case Node(e, nx) =>
            if (e == x) idx else loop(nx, idx + 1)
        }
        loop(self, 0)
      }

      // Index of the last node holding `x`, or -1 if absent
      def indexLast[B >: A](x: B): Int = {
        def loop(ln: LinkedNode[B], idx: Int, acc: List[Int]): List[Int] = ln match {
          case EmptyNode => acc
          case Node(e, nx) =>
            loop(nx, idx + 1, if (e == x) idx :: acc else acc)
        }
        loop(self, 0, List(-1)).head
      }

      def reverse: LinkedNode[A] = {
        def loop(ln: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] = ln match {
          case EmptyNode => acc
          case Node(e, nx) => loop(nx, Node(e, acc))
        }
        loop(self, EmptyNode)
      }

      def reverseK(k: Int): LinkedNode[A] = {
        if (k == 1)
          self
        else
          self.toList.grouped(k).flatMap(_.reverse).
            foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
      }

      def toList: List[A] = {
        def loop(ln: LinkedNode[A], acc: List[A]): List[A] = ln match {
          case EmptyNode => acc
          case Node(e, nx) => loop(nx, e :: acc)
        }
        loop(self, List.empty[A]).reverse
      }

      override def toString: String = {
        def loop(ln: LinkedNode[A], acc: String): String = ln match {
          case EmptyNode => acc + "()"
          case Node(e, nx) => loop(nx, acc + e + " -> ")
        }
        loop(self, "")
      }

      // ---- push / pop ----

      def push[B >: A](x: B): LinkedNode[B] = Node(x, self)

      def pop: (Option[A], LinkedNode[A]) = self match {
        case EmptyNode => (None, self)
        case Node(e, nx) => (Some(e), nx)
      }

      // ---- map / flatMap / fold ----

      def map[B](f: A => B): LinkedNode[B] = self match {
        case EmptyNode => EmptyNode
        case Node(e, nx) => Node(f(e), nx.map(f))
      }

      def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = self match {
        case EmptyNode => EmptyNode
        case Node(e, nx) =>
          // Prepend every node produced by f(e), not just its first one,
          // onto the flat-mapped remainder
          f(e).foldRight[LinkedNode[B]](nx.flatMap(f))(Node(_, _))
      }

      def foldLeft[B](z: B)(f: (B, A) => B): B = {
        def loop(ln: LinkedNode[A], acc: B): B = ln match {
          case EmptyNode => acc
          case Node(e, nx) => loop(nx, f(acc, e))
        }
        loop(self, z)
      }

      // Folding from the right is done as a left-to-right pass
      // over the reversed list
      def foldRight[B](z: B)(f: (A, B) => B): B = {
        def loop(ln: LinkedNode[A], acc: B): B = ln match {
          case EmptyNode => acc
          case Node(e, nx) => loop(nx, f(e, acc))
        }
        loop(self.reverse, z)
      }
    }

    object LinkedNode {
      def apply[A](ls: List[A]): LinkedNode[A] =
        ls.foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
    }

    case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
    case object EmptyNode extends LinkedNode[Nothing]
A cursory test-run …
    val ln = LinkedNode(List(1, 2, 2, 3, 4, 4, 5))
    // ln: LinkedNode[Int] = 1 -> 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ()

    val ln2 = ln.insertNode(9, 3)
    // ln2: LinkedNode[Int] = 1 -> 2 -> 2 -> 9 -> 3 -> 4 -> 4 -> 5 -> ()

    val ln3 = ln2.deleteNode(4)
    // ln3: LinkedNode[Int] = 1 -> 2 -> 2 -> 9 -> 4 -> 4 -> 5 -> ()

    ln.reverseK(3)
    // res1: LinkedNode[Int] = 2 -> 2 -> 1 -> 4 -> 4 -> 3 -> 5 -> ()

    ln.find(4)
    // res2: LinkedNode[Int] = 4 -> 4 -> 5 -> ()

    ln.indexLast(4)
    // res3: Int = 5

    // ---- Using LinkedNode like a `stack` ----

    ln.push(9)
    // res1: LinkedNode[Int] = 9 -> 1 -> 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ()

    ln.pop
    // res2: (Option[Int], LinkedNode[Int]) = (Some(1), 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ())

    // ---- Using LinkedNode like a `Vector` collection ----

    ln.map(_ + 10)
    // res1: LinkedNode[Int] = 11 -> 12 -> 12 -> 13 -> 14 -> 14 -> 15 -> ()

    ln.flatMap(i => Node(s"$i!", EmptyNode))
    // res2: LinkedNode[String] = 1! -> 2! -> 2! -> 3! -> 4! -> 4! -> 5! -> ()

    ln.foldLeft(0)(_ + _)
    // res3: Int = 21

    ln.foldRight("")((x, acc) => if (acc.isEmpty) s"$x" else s"$acc|$x")
    // res4: String = 5|4|4|3|2|2|1
In the implementation of the insertNode method for the LinkedNode class, what is the purpose of the condition if (i < idx) and why does it return EmptyNode in that case?
Thanks for the comment and sorry for the late reply. Calling method insertNode(x, idx) represents a request to insert a node with value x as the (idx+1)-th node. For instance, insertNode('a', 4) means the intention to insert a node with value 'a' as the 5th node. Now, if the original linked list has fewer than 4 nodes, it’s impossible to insert a 5th node. Thus, as the inner loop reaches the tail end (i.e. EmptyNode), i would still be less than idx, and therefore the loop should exit early with no change (i.e. EmptyNode).
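The behavior described in this reply can be checked with a small sketch. It repeats insertNode from the post on a stripped-down copy of the ADT; the value names are illustrative.

```scala
sealed trait LinkedNode[+A] { self =>
  def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
      case EmptyNode =>
        // Reached the end: insert only if we arrived exactly at `idx`
        if (i < idx) EmptyNode else Node(x, EmptyNode)
      case Node(e, nx) =>
        if (i < idx) Node(e, loop(nx, i + 1)) else Node(x, ln)
    }
    loop(self, 0)
  }
}
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

val short: LinkedNode[Int] = Node(1, Node(2, Node(3, EmptyNode)))

// Only 3 nodes exist, so there is no way to create a 5th: nothing changes
val unchanged = short.insertNode(9, 4)

// Index 3 is exactly one past the last node, so the value is appended
val appended = short.insertNode(9, 3)
```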