
Scala Binary Search Tree

When I wrote about a Scala linked list implementation a couple of years ago, I also did some quick groundwork for implementing binary search trees (BST). Because I was occupied with other R&D projects at the time, it was put aside and has since been patiently awaiting its turn to see the light of day. As much of the code is already there, I’m going to put it up in this blog post along with some narrative remarks.

First, we come up with an ADT (algebraic data type). Let’s call it BSTree: a base trait with generic type A for the data elements stored in the tree, extended by a case class BSBranch for tree branches and a case object BSLeaf for “null” tree nodes. The ADT’s overall structure resembles that of the one used in the linked list implementation described in the old post.

ADT BSTree

sealed trait BSTree[+A] { self =>
  // ...

  override def toString: String = self match {
    case BSLeaf => "-"
    case BSBranch(e, c, l, r) => (l, r) match {
      case (BSLeaf, BSLeaf) => s"TN($e[$c], -, -)"
      case (BSBranch(le, lc, _, _), BSLeaf) => s"TN($e[$c], $le[$lc], -)"
      case (BSLeaf, BSBranch(re, rc, _, _)) => s"TN($e[$c], -, $re[$rc])"
      case (BSBranch(le, lc, _, _), BSBranch(re, rc, _, _)) => s"TN($e[$c], $le[$lc], $re[$rc])"
    }
  }
}

case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A]

case object BSLeaf extends BSTree[Nothing]

A few notes:

  • Data elements (of generic type A) are stored only in branches. Duplicate values go to the same branch, whose count field records how many times the value has been inserted.
  • BSTree is covariant so that BSLeaf, a subtype of BSTree[Nothing], is also a subtype of BSTree[A] for every A and can therefore serve as the empty node of a tree of any element type.
  • A toString method is created for simplified string output of a tree instance.

Populating a BSTree

One of the first things we need is a method to insert tree nodes into an existing BSTree. We start expanding the base trait with method insert(). That’s all great for adding a node one at a time, but we also need a way to create a BSTree and populate it from a readily available collection of data elements. It makes sense to delegate such a factory method to the companion object BSTree as its method apply().

sealed trait BSTree[+A] { self =>

  def insert[B >: A : Ordering](elem: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    self match {
      case BSLeaf => BSBranch(elem, 1, BSLeaf, BSLeaf)
      case t @ BSBranch(e, c, l, r) =>
        if (e > elem)
          t.copy(left = l.insert(elem))
        else if (e < elem)
          t.copy(right = r.insert(elem))
        else
          t.copy(count = c + 1)
    }
  }

  // ...
}

object BSTree {
  def apply[A : Ordering](elems: Vector[A]): BSTree[A] = {
    val ord = implicitly[Ordering[A]]
    import ord._
    elems.foldLeft[BSTree[A]](BSLeaf)(_.insert(_))
  }
}

Note that insert() takes a type parameter B that must be a supertype of A: since BSTree is covariant in A, A can’t appear in a method parameter position, which is contravariant (just as Function1 is contravariant over its parameter type). In addition, the context bound "B : Ordering" requires an implicit Ordering[B] to be available, i.e. values of type B must be comparable, which is necessary for traversing a binary search tree.
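To see what the lower bound buys us, here is a stripped-down sketch that re-declares only insert() plus a hypothetical in-order helper toList. The Animal/Dog/Cat hierarchy and the name-based orderings are made-up assumptions for illustration. It shows that BSLeaf fits any BSTree[A] thanks to covariance, and that a BSTree[Dog] accepts a Cat by widening to BSTree[Animal], provided an Ordering for the wider type is in scope.

```scala
object VarianceDemo {
  // Stripped-down BSTree: only insert() and a hypothetical in-order toList helper
  sealed trait BSTree[+A] { self =>
    // B >: A lets a tree of A accept an element of any supertype B, widening the result
    def insert[B >: A](elem: B)(implicit ord: Ordering[B]): BSTree[B] = self match {
      case BSLeaf => BSBranch(elem, 1, BSLeaf, BSLeaf)
      case BSBranch(e, c, l, r) =>
        if (ord.gt(e, elem)) BSBranch(e, c, l.insert(elem), r)
        else if (ord.lt(e, elem)) BSBranch(e, c, l, r.insert(elem))
        else BSBranch(e, c + 1, l, r)
    }

    // In-order list of distinct elements (counts omitted)
    def toList: List[A] = self match {
      case BSLeaf => Nil
      case BSBranch(e, _, l, r) => l.toList ++ (e :: r.toList)
    }
  }
  final case class BSBranch[+A](elem: A, count: Int, left: BSTree[A], right: BSTree[A]) extends BSTree[A]
  case object BSLeaf extends BSTree[Nothing]

  // Hypothetical element types, ordered by name
  sealed trait Animal { def name: String }
  final case class Dog(name: String) extends Animal
  final case class Cat(name: String) extends Animal

  implicit val dogOrdering: Ordering[Dog] = Ordering.by[Dog, String](_.name)
  implicit val animalOrdering: Ordering[Animal] = Ordering.by[Animal, String](_.name)

  // Covariance: BSLeaf (a BSTree[Nothing]) is a valid empty tree of any element type
  val noWords: BSTree[String] = BSLeaf

  val dogs: BSTree[Dog] = BSLeaf.insert(Dog("rex")).insert(Dog("ace"))

  // B is given explicitly: the type inferred for Dog and Cat together may have no Ordering
  val pets: BSTree[Animal] = dogs.insert[Animal](Cat("tom"))
}
```

Leaving out the explicit `[Animal]` may not compile, because the type inferred for a Dog and a Cat together (in Scala 2, roughly Product with Serializable with Animal) has no Ordering instance in scope.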

Testing BSTree.apply():

val tree = BSTree(Vector(30, 10, 50, 20, 40, 70, 60, 80, 20, 50, 60))
// tree: BSTree[Int] = TN(30[1], 10[1], 50[2])

/*
tree:
           30
        /       \
    10            50
       \        /    \
        20     40     70
                    /    \
                   60    80
*/

Tree traversal and finding tree nodes

Next, we need methods for tree traversal and search. For brevity, we only include in-order (depth-first) traversal, plus a simple level-by-level (breadth-first) traversal.

  def traverseInOrder: Unit = self match {
    case BSLeaf => ()
    case BSBranch(_, _, l, r) =>
      l.traverseInOrder
      println(self)
      r.traverseInOrder
  }

  def traverseByLevel: Unit = {
    // Prints the branches exactly `level` steps below `tree`; empty leaves print nothing
    def loop(tree: BSTree[A], level: Int): Unit = tree match {
      case BSLeaf =>
      case BSBranch(_, _, l, r) =>
        if (level > 0) {
          loop(l, level - 1)
          loop(r, level - 1)
        }
        else
          print(s"$tree  ")
    }
    for (l <- 0 until self.height) {
      loop(self, l)
      println()
    }
  }

  def find[B >: A : Ordering](elem: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    self match {
      case BSLeaf => BSLeaf
      case t @ BSBranch(e, c, l, r) =>
        if (e < elem)
          r.find(elem)
        else if (e > elem)
          l.find(elem)
        else
          t
    }
  }

  def height: Int = self match {
    case BSLeaf => 0
    case BSBranch(_, _, l, r) => (l.height max r.height) + 1
  }

Using the tree created above:

/*
tree:
           30
        /       \
    10            50
       \        /    \
        20     40     70
                    /    \
                   60    80
*/

tree.height  // 4

tree.traverseInOrder
/*
TN(10[1], -, 20[2])
TN(20[2], -, -)
TN(30[1], 10[1], 50[2])
TN(40[1], -, -)
TN(50[2], 40[1], 70[1])
TN(60[2], -, -)
TN(70[1], 60[2], 80[1])
TN(80[1], -, -)
*/

tree.traverseByLevel
/*
TN(30[1], 10[1], 50[2])
TN(10[1], -, 20[2])  TN(50[2], 40[1], 70[1])
TN(20[2], -, -)  TN(40[1], -, -)  TN(70[1], 60[2], 80[1])
TN(60[2], -, -)  TN(80[1], -, -)
*/

val tree50 = tree.find(50)  // TN(50[2], 40[1], 70[1])
/*
tree50:
           50
        /       \
    40            70
                /    \
               60     80
*/
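traverseByLevel re-descends from the root once per level, so for a tree of n nodes and height h it can touch up to about n·h nodes. A common alternative is a breadth-first traversal that carries the current frontier of subtrees from one level to the next, visiting each node once. The sketch below re-declares a bare version of the ADT plus a plain-function insert; `levels` is a hypothetical helper that returns the levels as data instead of printing them.

```scala
object LevelOrder {
  // Bare re-declaration of the post's ADT (no methods), for a self-contained sketch
  sealed trait BSTree[+A]
  final case class BSBranch[+A](elem: A, count: Int, left: BSTree[A], right: BSTree[A]) extends BSTree[A]
  case object BSLeaf extends BSTree[Nothing]

  // Plain-function insert mirroring the post's insert() logic
  def insert[A](tree: BSTree[A], x: A)(implicit ord: Ordering[A]): BSTree[A] = tree match {
    case BSLeaf => BSBranch(x, 1, BSLeaf, BSLeaf)
    case BSBranch(e, c, l, r) =>
      if (ord.lt(x, e)) BSBranch(e, c, insert(l, x), r)
      else if (ord.gt(x, e)) BSBranch(e, c, l, insert(r, x))
      else BSBranch(e, c + 1, l, r)
  }

  // Breadth-first traversal: each pass turns the current frontier of subtrees
  // into one level of (elem, count) pairs plus the next frontier of children
  def levels[A](root: BSTree[A]): List[List[(A, Int)]] = {
    @annotation.tailrec
    def loop(frontier: List[BSTree[A]], acc: List[List[(A, Int)]]): List[List[(A, Int)]] = {
      val branches = frontier.collect { case b @ BSBranch(_, _, _, _) => b }
      if (branches.isEmpty) acc.reverse
      else loop(branches.flatMap(b => List(b.left, b.right)), branches.map(b => (b.elem, b.count)) :: acc)
    }
    loop(List(root), Nil)
  }

  val tree: BSTree[Int] =
    Vector(30, 10, 50, 20, 40, 70, 60, 80, 20, 50, 60).foldLeft[BSTree[Int]](BSLeaf)((t, x) => insert(t, x))
}
```

For the example tree, `LevelOrder.levels(LevelOrder.tree)` groups the same nodes per level as the traverseByLevel output above, as (elem, count) pairs.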

Removing tree nodes

To remove tree nodes holding a specific element value or a range of values, we also include the following few methods in the base trait.

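Note that delete() may involve a little shuffling of the tree nodes. Once the tree node to be removed is located, if both of its subtrees are non-empty, it gets filled with the element & count values of the node holding the next-bigger element, i.e. the left-most node of its right subtree (or, equivalently, the node holding the next-smaller element, i.e. the right-most node of its left subtree), which is then deleted from that subtree. If one of its subtrees is empty, the node is simply replaced by the other subtree. Note also that delete() removes a node as a whole, regardless of its count.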

  def delete[B >: A : Ordering](elem: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    def leftMost(tree: BSTree[B], prev: BSTree[B]): BSTree[B] = tree match {
      case BSLeaf => prev
      case t @ BSBranch(_, _, l, _) => leftMost(l, t)
    }
    self match {
      case BSLeaf => BSLeaf
      case t @ BSBranch(e, c, l, r) =>
        if (e < elem)
          t.copy(right = r.delete(elem))
        else if (e > elem)
          t.copy(left = l.delete(elem))
        else {
          if (l == BSLeaf)
            r
          else if (r == BSLeaf)
            l
          else {
            val nextBigger = leftMost(r, r)
            nextBigger match { case BSBranch(enb, cnb, _, _) =>
              t.copy(elem = enb, count = cnb, right = r.delete(enb))
            }
          }
        }
    }
  }

  def trim[B >: A : Ordering](lower: B, upper: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    self match {
      case BSLeaf => BSLeaf
      case t @ BSBranch(e, _, l, r) =>
        val tt = t.copy(left = l.trim(lower, upper), right = r.trim(lower, upper)) 
        if (e < lower)
          tt match { case BSBranch(_, _, _, r) => r }
        else if (e > upper)
          tt match { case BSBranch(_, _, l, _) => l }
        else
          tt
    }
  }

  def cutOut[B >: A : Ordering](lower: B, upper: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    self match {
      case BSLeaf => BSLeaf
      case t @ BSBranch(e, _, l, r) =>
        val tt = t.copy(left = l.cutOut(lower, upper), right = r.cutOut(lower, upper)) 
        if (e >= lower && e <= upper)
          tt.delete(e)
        else
          tt
    }
  }

Method trim() removes tree nodes with element values below or above the provided range. Meanwhile, method cutOut() does the opposite by cutting out tree nodes with values within the given range. It involves slightly more work than trim(), requiring the use of delete() for individual tree nodes.

Example:

val treeD50 = tree.delete(50)  // TN(30[1], 10[1], 60[2])
/*
treeD50:
           30
        /       \
    10            60
       \        /    \
        20     40     70
                         \
                          80
*/

val treeT15to65 = tree.trim(15, 65)  // TN(30[1], 20[2], 50[2])
/*
treeT15to65:
           30
        /       \
    20            50
                /    \
               40     60
*/

val treeC25to55 = tree.cutOut(25, 55)  // TN(60[2], 10[1], 70[1])
/*
treeC25to55:
           60
        /       \
    10            70
       \             \
        20            80
*/
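The equivalent variant mentioned for delete(), which refills a two-child node from its in-order predecessor (the right-most node of its left subtree), can be sketched as follows. This is a standalone-function version over a bare re-declaration of the ADT; rightMost and the other helper names are hypothetical.

```scala
import scala.annotation.tailrec

object PredecessorDelete {
  // Bare re-declaration of the post's ADT
  sealed trait BSTree[+A]
  final case class BSBranch[+A](elem: A, count: Int, left: BSTree[A], right: BSTree[A]) extends BSTree[A]
  case object BSLeaf extends BSTree[Nothing]

  def insert[A](tree: BSTree[A], x: A)(implicit ord: Ordering[A]): BSTree[A] = tree match {
    case BSLeaf => BSBranch(x, 1, BSLeaf, BSLeaf)
    case BSBranch(e, c, l, r) =>
      if (ord.lt(x, e)) BSBranch(e, c, insert(l, x), r)
      else if (ord.gt(x, e)) BSBranch(e, c, l, insert(r, x))
      else BSBranch(e, c + 1, l, r)
  }

  // Right-most branch of a non-empty subtree, i.e. the node holding its largest element
  @tailrec
  def rightMost[A](b: BSBranch[A]): BSBranch[A] = b.right match {
    case BSLeaf => b
    case r @ BSBranch(_, _, _, _) => rightMost(r)
  }

  // Like the post's delete(), but a two-child node is refilled from its in-order predecessor
  def delete[A](tree: BSTree[A], x: A)(implicit ord: Ordering[A]): BSTree[A] = tree match {
    case BSLeaf => BSLeaf
    case BSBranch(e, c, l, r) =>
      if (ord.lt(x, e)) BSBranch(e, c, delete(l, x), r)
      else if (ord.gt(x, e)) BSBranch(e, c, l, delete(r, x))
      else (l, r) match {
        case (BSLeaf, _) => r
        case (_, BSLeaf) => l
        case (lb @ BSBranch(_, _, _, _), _) =>
          val pred = rightMost(lb)
          // copy the predecessor's elem & count here, then remove it from the left subtree
          BSBranch(pred.elem, pred.count, delete(l, pred.elem), r)
      }
  }

  def toList[A](tree: BSTree[A]): List[A] = tree match {
    case BSLeaf => Nil
    case BSBranch(e, _, l, r) => toList(l) ++ (e :: toList(r))
  }

  val tree: BSTree[Int] =
    Vector(30, 10, 50, 20, 40, 70, 60, 80, 20, 50, 60).foldLeft[BSTree[Int]](BSLeaf)((t, x) => insert(t, x))
  val treeD50: BSTree[Int] = delete(tree, 50)
}
```

With this variant, deleting 50 puts 40 (the next-smaller element) where the post’s version puts 60; both results are valid binary search trees holding the same remaining elements.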

Rebuilding a binary search tree

A highly unbalanced binary search tree defeats the purpose of using such a data structure. One of the most straightforward ways to rebuild a binary search tree is to “unpack” the individual tree nodes of the existing tree by traversing it in-order into a sequence (e.g. a Vector or List) of elements, and then to reconstruct a new tree by recursively halving that in-order sequence, with the middle element of each half becoming the root of the corresponding subtree.

  def rebuild: BSTree[A] = {
    def loop(vs: Vector[(A, Int)], lower: Int, upper: Int): BSTree[A] = {
      if (upper >= lower) {
        val m = (lower + upper) / 2
        val (e, c) = vs(m)
        BSBranch(e, c, loop(vs, lower, m-1), loop(vs, m+1, upper))
      }
      else
        BSLeaf
    }
    self match {
      case BSLeaf =>
        BSLeaf
      case _ =>
        val vs = self.toElemVector
        loop(vs, 0, vs.size - 1)
    }
  }

  def isBalanced: Boolean = self match {
    case BSLeaf => true
    case BSBranch(_, _, l, r) =>
      if (math.abs(l.height - r.height) > 1)
        false
      else
        l.isBalanced && r.isBalanced
  }

  def toElemVector: Vector[(A, Int)] = self match {  // in-order
    case BSLeaf =>
      Vector.empty[(A, Int)]
    case BSBranch(e, c, l, r) =>
      (l.toElemVector :+ ((e, c))) ++ r.toElemVector
  }

  def toVector: Vector[BSTree[A]] = self match {  // in-order
    case BSLeaf => Vector.empty[BSTree[A]]
    case BSBranch(_, _, l, r) =>
      (l.toVector :+ self) ++ r.toVector
  }

Example:

val unbalancedTree = BSTree(Vector(20, 10, 30, 30, 40, 50, 60, 60, 50))
// unbalancedTree: BSTree[Int] = TN(20[1], 10[1], 30[2])

/*
unbalancedTree:
           20
        /       \
    10            30
                     \
                     40
                        \
                         50
                            \
                             60
*/

unbalancedTree.isBalanced  // false

val rebuiltTree = unbalancedTree.rebuild  // TN(30[2], 10[1], 50[2])
/*
rebuiltTree:

           30
        /       \
    10            50
       \        /    \
        20     40     60
*/

rebuiltTree.isBalanced  // true
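isBalanced calls height, itself a full traversal, at every node, so on a badly skewed tree the work grows roughly quadratically. A common alternative (not the post’s method) computes height and balance together in one post-order pass and skips the remaining work once a subtree is found unbalanced. A minimal sketch over a bare re-declaration of the ADT; balancedHeight and build are hypothetical helper names.

```scala
object BalanceCheck {
  // Bare re-declaration of the post's ADT
  sealed trait BSTree[+A]
  final case class BSBranch[+A](elem: A, count: Int, left: BSTree[A], right: BSTree[A]) extends BSTree[A]
  case object BSLeaf extends BSTree[Nothing]

  def insert[A](tree: BSTree[A], x: A)(implicit ord: Ordering[A]): BSTree[A] = tree match {
    case BSLeaf => BSBranch(x, 1, BSLeaf, BSLeaf)
    case BSBranch(e, c, l, r) =>
      if (ord.lt(x, e)) BSBranch(e, c, insert(l, x), r)
      else if (ord.gt(x, e)) BSBranch(e, c, l, insert(r, x))
      else BSBranch(e, c + 1, l, r)
  }

  // Some(height) if the subtree is balanced, None otherwise; the right side is
  // never computed when the left side is already unbalanced
  def balancedHeight[A](tree: BSTree[A]): Option[Int] = tree match {
    case BSLeaf => Some(0)
    case BSBranch(_, _, l, r) =>
      for {
        lh <- balancedHeight(l)
        rh <- balancedHeight(r)
        if math.abs(lh - rh) <= 1
      } yield (lh max rh) + 1
  }

  def isBalanced[A](tree: BSTree[A]): Boolean = balancedHeight(tree).isDefined

  def build(xs: Int*): BSTree[Int] = xs.foldLeft[BSTree[Int]](BSLeaf)((t, x) => insert(t, x))
}
```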

Thoughts on the ADT

An alternative ADT design is to declare the class fields and methods in the BSTree base trait and put their specific implementations in the subclasses BSBranch and BSLeaf, thus eliminating the need for the boilerplate pattern matching on the subclasses. It also has the benefit of making class fields like left & right referenceable from the base trait, though they would need to be wrapped in Options, with value None for BSLeaf.

With the existing ADT, on the other hand, an advantage is that all the binary tree functions are defined once and for all within the base trait. If left and right need to be referenceable from the BSTree base trait, one can define something like the following within the trait.

  def leftIfAny: Option[BSTree[A]] = self match {
    case BSLeaf => None
    case BSBranch(_, _, l, _) => Some(l)
  }

  def rightIfAny: Option[BSTree[A]] = self match {
    case BSLeaf => None
    case BSBranch(_, _, _, r) => Some(r)
  }

Example:

tree.leftIfAny
// Some(TN(10[1], -, 20[2]))

tree.rightIfAny.flatMap(_.rightIfAny)
// Some(TN(70[1], 60[2], 80[1]))
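For comparison, here is a rough sketch of the alternative design described earlier, with abstract members declared in the base trait, implemented in each subclass, and left/right exposed as Options. The names Tree, Branch and Leaf and the choice of members are assumptions made for illustration; note how each method body lives in the subclass instead of in a pattern match.

```scala
object MemberBasedTree {
  // Hypothetical alternative: members declared in the base trait, implemented per subclass
  sealed trait Tree[+A] {
    def left: Option[Tree[A]]
    def right: Option[Tree[A]]
    def height: Int
    def insert[B >: A](x: B)(implicit ord: Ordering[B]): Tree[B]
  }

  final case class Branch[+A](elem: A, count: Int, l: Tree[A], r: Tree[A]) extends Tree[A] {
    def left: Option[Tree[A]] = Some(l)
    def right: Option[Tree[A]] = Some(r)
    def height: Int = (l.height max r.height) + 1
    def insert[B >: A](x: B)(implicit ord: Ordering[B]): Tree[B] =
      if (ord.lt(x, elem)) Branch(elem, count, l.insert(x), r)
      else if (ord.gt(x, elem)) Branch(elem, count, l, r.insert(x))
      else Branch(elem, count + 1, l, r)
  }

  case object Leaf extends Tree[Nothing] {
    def left: Option[Tree[Nothing]] = None
    def right: Option[Tree[Nothing]] = None
    def height: Int = 0
    def insert[B >: Nothing](x: B)(implicit ord: Ordering[B]): Tree[B] = Branch(x, 1, Leaf, Leaf)
  }

  val demo: Tree[Int] = Vector(30, 10, 50, 20).foldLeft[Tree[Int]](Leaf)((acc, x) => acc.insert(x))
}
```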

Then there is also the non-idiomatic approach, commonly seen in Java implementations, of using mutable class fields in a single tree class, like below:

object SimpleBST {
  class TreeNode[A](_elem: A, _left: TreeNode[A] = null, _right: TreeNode[A] = null) {
    var elem: A = _elem
    var count: Int = 1
    var left: TreeNode[A] = _left
    var right: TreeNode[A] = _right
    override def toString: String = (left, right) match {
      case (null, null) => s"TN($elem[$count], -, -)"
      case (l,    null) => s"TN($elem[$count], ${l.elem}[${l.count}], -)"
      case (null, r)    => s"TN($elem[$count], -, ${r.elem}[${r.count}])"
      case (l,    r)    => s"TN($elem[$count], ${l.elem}[${l.count}], ${r.elem}[${r.count}])"
    }
  }

  // Methods for node addition, deletion, etc ...
}
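To give a flavor of this style, here is a hypothetical iterative insert over a node class shaped like TreeNode above: it walks down with a while loop and mutates nodes in place. It assumes an implicit Ordering[A] and, for brevity, declares the fields as constructor vars; the names insert, inOrder and demoRoot are made up for illustration and are not the author’s code.

```scala
object MutableBST {
  // Same shape as the post's TreeNode, with constructor vars for brevity
  class TreeNode[A](var elem: A, var left: TreeNode[A] = null, var right: TreeNode[A] = null) {
    var count: Int = 1
  }

  // Hypothetical iterative insert: walks down with a while loop and mutates in place;
  // returns the root (a new node when the tree is empty)
  def insert[A](root: TreeNode[A], x: A)(implicit ord: Ordering[A]): TreeNode[A] =
    if (root == null) new TreeNode(x)
    else {
      var node = root
      var done = false
      while (!done) {
        val cmp = ord.compare(x, node.elem)
        if (cmp < 0) {
          if (node.left == null) { node.left = new TreeNode(x); done = true }
          else node = node.left
        } else if (cmp > 0) {
          if (node.right == null) { node.right = new TreeNode(x); done = true }
          else node = node.right
        } else {
          node.count += 1
          done = true
        }
      }
      root
    }

  // In-order (elem, count) pairs, for inspection
  def inOrder[A](node: TreeNode[A]): List[(A, Int)] =
    if (node == null) Nil
    else inOrder(node.left) ++ ((node.elem, node.count) :: inOrder(node.right))

  val demoRoot: TreeNode[Int] =
    List(30, 10, 50, 20, 20).foldLeft(null: TreeNode[Int])((r, x) => insert(r, x))
}
```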

Addendum: Complete source code of the BSTree ADT

sealed trait BSTree[+A] { self =>

  def traverseInOrder: Unit = self match {
    case BSLeaf => ()
    case BSBranch(_, _, l, r) =>
      l.traverseInOrder
      println(self)
      r.traverseInOrder
  }

  def traverseByLevel: Unit = {
    // Prints the branches exactly `level` steps below `tree`; empty leaves print nothing
    def loop(tree: BSTree[A], level: Int): Unit = tree match {
      case BSLeaf =>
      case BSBranch(_, _, l, r) =>
        if (level > 0) {
          loop(l, level - 1)
          loop(r, level - 1)
        }
        else
          print(s"$tree  ")
    }
    for (l <- 0 until self.height) {
      loop(self, l)
      println()
    }
  }

  def find[B >: A : Ordering](elem: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    self match {
      case BSLeaf => BSLeaf
      case t @ BSBranch(e, c, l, r) =>
        if (e < elem)
          r.find(elem)
        else if (e > elem)
          l.find(elem)
        else
          t
    }
  }

  def height: Int = self match {
    case BSLeaf => 0
    case BSBranch(_, _, l, r) => (l.height max r.height) + 1
  }

  def insert[B >: A : Ordering](elem: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    self match {
      case BSLeaf => BSBranch(elem, 1, BSLeaf, BSLeaf)
      case t @ BSBranch(e, c, l, r) =>
        if (e > elem)
          t.copy(left = l.insert(elem))
        else if (e < elem)
          t.copy(right = r.insert(elem))
        else
          t.copy(count = c + 1)
    }
  }

  def delete[B >: A : Ordering](elem: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    def leftMost(tree: BSTree[B], prev: BSTree[B]): BSTree[B] = tree match {
      case BSLeaf => prev
      case t @ BSBranch(_, _, l, _) => leftMost(l, t)
    }
    self match {
      case BSLeaf => BSLeaf
      case t @ BSBranch(e, c, l, r) =>
        if (e < elem)
          t.copy(right = r.delete(elem))
        else if (e > elem)
          t.copy(left = l.delete(elem))
        else {
          if (l == BSLeaf)
            r
          else if (r == BSLeaf)
            l
          else {
            val nextBigger = leftMost(r, r)
            nextBigger match { case BSBranch(enb, cnb, _, _) =>
              t.copy(elem = enb, count = cnb, right = r.delete(enb))
            }
          }
        }
    }
  }

  def trim[B >: A : Ordering](lower: B, upper: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    self match {
      case BSLeaf => BSLeaf
      case t @ BSBranch(e, _, l, r) =>
        val tt = t.copy(left = l.trim(lower, upper), right = r.trim(lower, upper)) 
        if (e < lower)
          tt match { case BSBranch(_, _, _, r) => r }
        else if (e > upper)
          tt match { case BSBranch(_, _, l, _) => l }
        else
          tt
    }
  }

  def cutOut[B >: A : Ordering](lower: B, upper: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    self match {
      case BSLeaf => BSLeaf
      case t @ BSBranch(e, _, l, r) =>
        val tt = t.copy(left = l.cutOut(lower, upper), right = r.cutOut(lower, upper)) 
        if (e >= lower && e <= upper)
          tt.delete(e)
        else
          tt
    }
  }

  def rebuild: BSTree[A] = {
    // Builds a balanced tree from the in-order (elem, count) pairs in vs(lower..upper)
    def loop(vs: Vector[(A, Int)], lower: Int, upper: Int): BSTree[A] = {
      if (upper >= lower) {
        val m = (lower + upper) / 2
        val (e, c) = vs(m)
        BSBranch(e, c, loop(vs, lower, m-1), loop(vs, m+1, upper))
      }
      else
        BSLeaf
    }
    self match {
      case BSLeaf =>
        BSLeaf
      case _ =>
        val vs = self.toElemVector
        loop(vs, 0, vs.size - 1)
    }
  }

  def isBalanced: Boolean = self match {
    case BSLeaf => true
    case BSBranch(_, _, l, r) =>
      if (math.abs(l.height - r.height) > 1)
        false
      else
        l.isBalanced && r.isBalanced
  }

  def toElemVector: Vector[(A, Int)] = self match {  // in-order
    case BSLeaf =>
      Vector.empty[(A, Int)]
    case BSBranch(e, c, l, r) =>
      (l.toElemVector :+ ((e, c))) ++ r.toElemVector
  }

  def toVector: Vector[BSTree[A]] = self match {  // in-order
    case BSLeaf => Vector.empty[BSTree[A]]
    case BSBranch(_, _, l, r) =>
      (l.toVector :+ self) ++ r.toVector
  }

  override def toString: String = self match {
    case BSLeaf => "-"
    case BSBranch(e, c, l, r) => (l, r) match {
      case (BSLeaf, BSLeaf) => s"TN($e[$c], -, -)"
      case (BSBranch(le, lc, _, _), BSLeaf) => s"TN($e[$c], $le[$lc], -)"
      case (BSLeaf, BSBranch(re, rc, _, _)) => s"TN($e[$c], -, $re[$rc])"
      case (BSBranch(le, lc, _, _), BSBranch(re, rc, _, _)) => s"TN($e[$c], $le[$lc], $re[$rc])"
    }
  }
}

object BSTree {
  def apply[A : Ordering](elems: Vector[A]): BSTree[A] = {
    val ord = implicitly[Ordering[A]]
    import ord._
    elems.foldLeft[BSTree[A]](BSLeaf)(_.insert(_))
  }
}

case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A]

case object BSLeaf extends BSTree[Nothing]

Implementing Linked List In Scala

In Scala, if you wonder why the standard library doesn’t come with a data structure called `LinkedList`, you may have overlooked the obvious: the collection List is in fact a (singly) linked list, although it’s commonly used just like any other Seq collection (e.g. a Vector) rather than as the generally “mysterious” linked list that exposes its “head” and reveals its hidden “tail” only iteratively.
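A quick way to see that List is a linked list is to build one from its cons cells and walk it node by node. The helper name `length` below is just for illustration (List already has its own length method).

```scala
import scala.annotation.tailrec

object ListIsLinked {
  // Built from cons cells: 1 :: (2 :: (3 :: Nil))
  val xs: List[Int] = 1 :: 2 :: 3 :: Nil

  // Walks the list node by node, just as one would walk a linked list
  @tailrec
  def length[A](ls: List[A], acc: Int = 0): Int = ls match {
    case Nil       => acc
    case _ :: tail => length(tail, acc + 1)
  }

  // Prepending allocates one new node whose tail is the existing list (no copying)
  val ys: List[Int] = 0 :: xs
}
```

The last line illustrates the structural sharing that makes prepending to a List cheap: `ys.tail` is the very same object as `xs`.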

Our ADT: LinkedNode

Perhaps because of its simplicity and dynamic nature as a data structure, implementing a linked list remains a popular coding exercise. To implement our own linked list, let’s start with a bare-bones ADT (algebraic data type) as follows:

sealed trait LinkedNode[+A]
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

If you’re familiar with Scala List, you’ll probably notice that our ADT resembles `List` and its subclasses `::` (i.e. “cons”) and `Nil`, as shown in this simplified excerpt of the standard library source code:

sealed abstract class List[+A] { ... }
final case class ::[+A](override val head: A, var next: List[A]) extends List[A] { ... }
case object Nil extends List[Nothing] { ... }

Expanding LinkedNode

Let’s expand trait `LinkedNode` to create class methods `insertNode`/`deleteNode` at a given index for inserting/deleting a node, `toList` for extracting contained elements into a `List` collection, and `toString` for display:

sealed trait LinkedNode[+A] { self =>

  def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
      case EmptyNode =>
        if (i < idx)
          EmptyNode
        else
          Node(x, EmptyNode)
      case Node(e, nx) =>
        if (i < idx)
          Node(e, loop(nx, i + 1))
        else
          Node(x, ln)
    }
    loop(self, 0)
  }

  def deleteNode(idx: Int): LinkedNode[A] = {
    def loop(ln: LinkedNode[A], i: Int): LinkedNode[A] = ln match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        if (idx == 0)
          nx
        else {
          if (i < idx - 1)
            Node(e, loop(nx, i + 1))
          else
            nx match {
              case EmptyNode => Node(e, EmptyNode)
              case Node(_, nx2) => Node(e, nx2)
            }
        }
    }
    loop(self, 0)
  }

  def toList: List[A] = {
    def loop(ln: LinkedNode[A], acc: List[A]): List[A] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, e :: acc)
    }
    loop(self, List.empty[A]).reverse
  }

  override def toString: String = {
    def loop(ln: LinkedNode[A], acc: String): String = ln match {
      case EmptyNode =>
        acc + "()"
      case Node(e, nx) =>
        loop(nx, acc + e + " -> " )
    }
    loop(self, "")
  }
}

case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]

case object EmptyNode extends LinkedNode[Nothing]

// Test running ...

val ln = List(1, 2, 3, 3, 4).foldRight[LinkedNode[Int]](EmptyNode)(Node(_, _))
// ln: LinkedNode[Int] = 1 -> 2 -> 3 -> 3 -> 4 -> ()

ln.insertNode(9, 2)
// res1: LinkedNode[Int] = 1 -> 2 -> 9 -> 3 -> 3 -> 4 -> ()

ln.deleteNode(1)
// res2: LinkedNode[Int] = 1 -> 3 -> 3 -> 4 -> ()

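Note that `LinkedNode` is made covariant. In addition, method `insertNode` has type parameter `B` with `A` as its lower bound, because a covariant type parameter can’t appear in a method parameter position, which is contravariant (just as Function1 is `contravariant` over its parameter type).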

Recursion and pattern matching

A couple of notes on the approach we use to implement our class methods:

  1. We use recursive functions to avoid mutable variables. Ideally they should be tail-recursive, both for performance and to avoid stack overflows on long lists, but that isn’t the focus of this implementation. If performance is a priority, using conventional while-loops with mutable class fields `elem`/`next` would be a more practical option.
  2. Pattern matching is routinely used for handling cases of `Node` versus `EmptyNode`. An alternative approach would be to define fields `elem` and `next` in the base trait and implement class methods accordingly within `Node` and `EmptyNode`.
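As a sketch of what note 1 might look like in practice, here is a tail-recursive variant of insertNode, written as a standalone function over a bare re-declaration of the ADT; the helper fromList and the reversed-prefix accumulator are assumptions, not the post’s code. It walks to the insertion point while collecting the passed-over elements in reverse, then re-attaches them with List’s foldLeft, so the stack depth no longer grows with idx.

```scala
import scala.annotation.tailrec

object TailRecLinked {
  sealed trait LinkedNode[+A]
  final case class Node[+A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
  case object EmptyNode extends LinkedNode[Nothing]

  def fromList[A](xs: List[A]): LinkedNode[A] =
    xs.foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))

  // Same results as the post's insertNode, but the recursion is in tail position
  def insertNode[A](ln: LinkedNode[A], x: A, idx: Int): LinkedNode[A] = {
    @tailrec
    def loop(rest: LinkedNode[A], i: Int, revPrefix: List[A]): LinkedNode[A] = rest match {
      case Node(e, nx) if i < idx => loop(nx, i + 1, e :: revPrefix)
      case _ if i < idx           => ln  // idx is past the end: list unchanged
      case _                      =>
        // re-attach the passed-over elements in front of the new node
        revPrefix.foldLeft[LinkedNode[A]](Node(x, rest))((acc, e) => Node(e, acc))
    }
    loop(ln, 0, Nil)
  }
}
```

The trade-off is an extra List holding up to idx elements, in exchange for not growing the call stack.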

Finding first/last matching nodes

Next, we add a couple of class methods for finding first/last matching nodes.

  def find[B >: A](x: B): LinkedNode[B] = {
    def loop(ln: LinkedNode[B]): LinkedNode[B] = ln match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        if (e == x) ln else loop(nx)
    }
    loop(self)
  }

  def findLast[B >: A](x: B): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], acc: List[LinkedNode[B]]): List[LinkedNode[B]] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        if (e == x) loop(nx, ln :: acc) else loop(nx, acc)
    }
    loop(self, List.empty[LinkedNode[B]]).headOption match {
      case None => EmptyNode
      case Some(node) => node
    }
  }

Reversing a linked list by groups of nodes

Reversing a given LinkedNode can be accomplished via recursion by cumulatively wrapping the element of each `Node` in a new `Node` with its `next` pointer set to the `Node` created in the previous iteration.

  def reverse: LinkedNode[A] = {
    def loop(ln: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, Node(e, acc))
    }
    loop(self, EmptyNode)
  }

  def reverseK(k: Int): LinkedNode[A] = {
    if (k == 1)
      self
    else
      self.toList.grouped(k).flatMap(_.reverse).
        foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
  }

Method `reverseK` reverses a LinkedNode in groups of `k` elements using a different approach: it extracts the elements into a List, splits them into groups of `k`, reverses each group, and re-wraps the flattened elements in `Node`s.
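For comparison, reversing in groups can also be done without going through a List: take up to k nodes off the front (reversing them as they are taken), then attach the reversed group in front of the recursively processed remainder. A sketch over a bare re-declaration of the ADT; takeReversed and prepend are hypothetical helper names, and unlike the post’s version (where grouped rejects a k below 1) it simply returns the list unchanged for k <= 1.

```scala
import scala.annotation.tailrec

object ReverseK {
  sealed trait LinkedNode[+A]
  final case class Node[+A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
  case object EmptyNode extends LinkedNode[Nothing]

  def fromList[A](xs: List[A]): LinkedNode[A] =
    xs.foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))

  // Moves up to k leading nodes onto acc (thereby reversing them); returns (group, rest)
  @tailrec
  def takeReversed[A](ln: LinkedNode[A], k: Int, acc: LinkedNode[A]): (LinkedNode[A], LinkedNode[A]) =
    ln match {
      case Node(e, nx) if k > 0 => takeReversed(nx, k - 1, Node(e, acc))
      case _                    => (acc, ln)
    }

  // Puts all nodes of `front` (at most k of them here) in front of `back`
  def prepend[A](front: LinkedNode[A], back: LinkedNode[A]): LinkedNode[A] = front match {
    case EmptyNode   => back
    case Node(e, nx) => Node(e, prepend(nx, back))
  }

  def reverseK[A](ln: LinkedNode[A], k: Int): LinkedNode[A] =
    if (k <= 1) ln
    else ln match {
      case EmptyNode => EmptyNode
      case _ =>
        val (group, rest) = takeReversed(ln, k, EmptyNode)
        prepend(group, reverseK(rest, k))
    }
}
```

Like the post’s version, a trailing group shorter than k is reversed as well.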

Using LinkedNode as a Stack

For `LinkedNode` to serve as a Stack, we include simple methods `push` and `pop` as follows:

  def push[B >: A](x: B): LinkedNode[B] = Node(x, self)

  def pop: (Option[A], LinkedNode[A]) = self match {
    case EmptyNode => (None, self)
    case Node(e, nx) => (Some(e), nx)
  }

Addendum: implementing map, flatMap, fold

From a different perspective, if we view `LinkedNode` as a collection like a Scala List or Vector, we might want methods like `map`, `flatMap` and `fold`. Using the same approach of recursion along with pattern matching, it’s rather straightforward to crank them out.

  def map[B](f: A => B): LinkedNode[B] = self match {
    case EmptyNode =>
      EmptyNode
    case Node(e, nx) =>
      Node(f(e), nx.map(f))
  }

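  def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = {
    // Prepends all nodes of `front` to `back`, so every element produced by f is kept
    def concat(front: LinkedNode[B], back: LinkedNode[B]): LinkedNode[B] = front match {
      case EmptyNode =>
        back
      case Node(e, nx) =>
        Node(e, concat(nx, back))
    }
    self match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        concat(f(e), nx.flatMap(f))
    }
  }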

  def foldLeft[B](z: B)(f: (B, A) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, f(acc, e))
    }
    loop(self, z)
  }

  def foldRight[B](z: B)(f: (A, B) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, f(e, acc))
    }
    loop(self.reverse, z)
  }
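One simple way to sanity-check these methods is to compare their results with the corresponding List methods. A minimal sketch, assuming standalone-function versions over a bare re-declaration of the ADT (fromList, toList and concat are hypothetical helpers); using an f that returns more than one node checks that flatMap keeps every node f produces.

```scala
import scala.annotation.tailrec

object LinkedOpsCheck {
  sealed trait LinkedNode[+A]
  final case class Node[+A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
  case object EmptyNode extends LinkedNode[Nothing]

  def fromList[A](xs: List[A]): LinkedNode[A] =
    xs.foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))

  @tailrec
  def foldLeft[A, B](ln: LinkedNode[A], z: B)(f: (B, A) => B): B = ln match {
    case EmptyNode   => z
    case Node(e, nx) => foldLeft(nx, f(z, e))(f)
  }

  def toList[A](ln: LinkedNode[A]): List[A] =
    foldLeft(ln, List.empty[A])((acc, e) => e :: acc).reverse

  // Keeps every node of f(e) by splicing it in front of the rest of the result
  def flatMap[A, B](ln: LinkedNode[A])(f: A => LinkedNode[B]): LinkedNode[B] = {
    def concat(front: LinkedNode[B], back: LinkedNode[B]): LinkedNode[B] = front match {
      case EmptyNode   => back
      case Node(e, nx) => Node(e, concat(nx, back))
    }
    ln match {
      case EmptyNode   => EmptyNode
      case Node(e, nx) => concat(f(e), flatMap(nx)(f))
    }
  }
}
```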

Putting everything together

Along with a few additional simple class methods and a factory method wrapped in `LinkedNode`’s companion object, below is the final `LinkedNode` ADT that includes everything described above.

sealed trait LinkedNode[+A] { self =>

  def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
      case EmptyNode =>
        if (i < idx)
          EmptyNode
        else
          Node(x, EmptyNode)
      case Node(e, nx) =>
        if (i < idx)
          Node(e, loop(nx, i + 1))
        else
          Node(x, ln)
    }
    loop(self, 0)
  }

  def deleteNode(idx: Int): LinkedNode[A] = {
    def loop(ln: LinkedNode[A], i: Int): LinkedNode[A] = ln match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        if (idx == 0)
          nx
        else {
          if (i < idx - 1)
            Node(e, loop(nx, i + 1))
          else
            nx match {
              case EmptyNode => Node(e, EmptyNode)
              case Node(_, nx2) => Node(e, nx2)
            }
        }
    }
    loop(self, 0)
  }

  def get(idx: Int): LinkedNode[A] = {
    def loop(ln: LinkedNode[A], count: Int): LinkedNode[A] = {
      if (count < idx) {
        ln match {
          case EmptyNode => EmptyNode
          case Node(_, next) => loop(next, count + 1)
        }
      }
      else
        ln
    }
    loop(self, 0)
  }

  def find[B >: A](x: B): LinkedNode[B] = {
    def loop(ln: LinkedNode[B]): LinkedNode[B] = ln match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        if (e == x) ln else loop(nx)
    }
    loop(self)
  }

  def findLast[B >: A](x: B): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], acc: List[LinkedNode[B]]): List[LinkedNode[B]] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        if (e == x) loop(nx, ln :: acc) else loop(nx, acc)
    }
    loop(self, List.empty[LinkedNode[B]]).headOption match {
      case None => EmptyNode
      case Some(node) => node
    }
  }

  def indexOf[B >: A](x: B): Int = {
    def loop(ln: LinkedNode[B], idx: Int): Int = ln match {
      case EmptyNode =>
        -1
      case Node(e, nx) =>
        if (e == x) idx else loop(nx, idx + 1)
    }
    loop(self, 0)
  }

  def indexLast[B >: A](x: B): Int = {
    def loop(ln: LinkedNode[B], idx: Int, acc: List[Int]): List[Int] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, idx + 1, if (e == x) idx :: acc else acc)
    }
    loop(self, 0, List(-1)).head
  }

  def reverse: LinkedNode[A] = {
    def loop(ln: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, Node(e, acc))
    }
    loop(self, EmptyNode)
  }

  def reverseK(k: Int): LinkedNode[A] = {
    if (k == 1)
      self
    else
      self.toList.grouped(k).flatMap(_.reverse).
        foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
  }

  def toList: List[A] = {
    def loop(ln: LinkedNode[A], acc: List[A]): List[A] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, e :: acc)
    }
    loop(self, List.empty[A]).reverse
  }

  override def toString: String = {
    def loop(ln: LinkedNode[A], acc: String): String = ln match {
      case EmptyNode =>
        acc + "()"
      case Node(e, nx) =>
        loop(nx, acc + e + " -> " )
    }
    loop(self, "")
  }

  // ---- push / pop ----

  def push[B >: A](x: B): LinkedNode[B] = Node(x, self)

  def pop: (Option[A], LinkedNode[A]) = self match {
    case EmptyNode => (None, self)
    case Node(e, nx) => (Some(e), nx)
  }

  // ---- map / flatMap / fold ----

  def map[B](f: A => B): LinkedNode[B] = self match {
    case EmptyNode =>
      EmptyNode
    case Node(e, nx) =>
      Node(f(e), nx.map(f))
  }

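  def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = {
    // Prepends all nodes of `front` to `back`, so every element produced by f is kept
    def concat(front: LinkedNode[B], back: LinkedNode[B]): LinkedNode[B] = front match {
      case EmptyNode =>
        back
      case Node(e, nx) =>
        Node(e, concat(nx, back))
    }
    self match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        concat(f(e), nx.flatMap(f))
    }
  }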

  def foldLeft[B](z: B)(f: (B, A) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, f(acc, e))
    }
    loop(self, z)
  }

  def foldRight[B](z: B)(f: (A, B) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, f(e, acc))
    }
    loop(self.reverse, z)
  }
}

object LinkedNode {
  def apply[A](ls: List[A]): LinkedNode[A] =
    ls.foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
}

case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]

case object EmptyNode extends LinkedNode[Nothing]

A cursory test-run …

val ln = LinkedNode(List(1, 2, 2, 3, 4, 4, 5))
// ln: LinkedNode[Int] = 1 -> 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ()

val ln2 = ln.insertNode(9, 3)
// ln2: LinkedNode[Int] = 1 -> 2 -> 2 -> 9 -> 3 -> 4 -> 4 -> 5 -> ()

val ln3 = ln2.deleteNode(4)
// ln3: LinkedNode[Int] = 1 -> 2 -> 2 -> 9 -> 4 -> 4 -> 5 -> ()

ln.reverseK(3)
// res1: LinkedNode[Int] = 2 -> 2 -> 1 -> 4 -> 4 -> 3 -> 5 -> ()

ln.find(4)
// res2: LinkedNode[Int] = 4 -> 4 -> 5 -> ()

ln.indexLast(4)
// res3: Int = 5

// ---- Using LinkedNode like a `stack` ----

ln.push(9)
// res1: LinkedNode[Int] = 9 -> 1 -> 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ()

ln.pop
// res2: (Option[Int], LinkedNode[Int]) = (Some(1), 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ())

// ---- Using LinkedNode like a `Vector` collection ----

ln.map(_ + 10)
// res1: LinkedNode[Int] = 11 -> 12 -> 12 -> 13 -> 14 -> 14 -> 15 -> ()

ln.flatMap(i => Node(s"$i!", EmptyNode))
// res2: LinkedNode[String] = 1! -> 2! -> 2! -> 3! -> 4! -> 4! -> 5! -> ()

ln.foldLeft(0)(_ + _)
// res3: Int = 21

ln.foldRight("")((x, acc) => if (acc.isEmpty) s"$x" else s"$acc|$x")
// res4: String = 5|4|4|3|2|2|1

Traversing A Scala Collection

When we have a custom class in the form of an ADT, say, `Container[A]` and are required to process a collection of the derived objects like `List[Container[A]]`, there might be times we want to flip the collection “inside out” to become a single “container” of collection `Container[List[A]]`, and maybe further transform the inner collection with a function.

For those familiar with Scala Futures, such a transformation is analogous to what method Future.sequence does. If the traversal also involves applying a function, say `f: A => Container[B]`, to the individual elements to transform the collection into `Container[List[B]]`, it’ll be more like how Future.traverse works.
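For reference, here is the Future analogy in code, using the standard library’s Future.sequence and Future.traverse. The object name FutureAnalogy is just for illustration, and the global ExecutionContext is assumed to be acceptable for a quick demo.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object FutureAnalogy {
  // List[Future[Int]] flipped into Future[List[Int]]: the shape sequence aims for
  def sequenced: Future[List[Int]] =
    Future.sequence(List(Future(1), Future(2), Future(3)))

  // Applies f: String => Future[Int] to each element, then flips into Future[List[Int]]
  def traversed: Future[List[Int]] =
    Future.traverse(List("a", "bb", "ccc"))(s => Future(s.length))

  // Blocks only for the demo; defs (not vals) avoid waiting during object initialization
  def results(timeout: FiniteDuration = 5.seconds): (List[Int], List[Int]) =
    (Await.result(sequenced, timeout), Await.result(traversed, timeout))
}
```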

To illustrate how we can come up with methods `sequence` and `traverse` for the collection of our own ADTs, let’s assemble a simple ADT `Fillable[A]`. Our goal is to create the following two methods:

  def sequence[A](listFA: List[Fillable[A]]): Fillable[List[A]]
  def traverse[A, B](list: List[A])(f: A => Fillable[B]): Fillable[List[B]]

For simplicity, rather than a generic collection like IterableOnce, we fix the collection type to `List`.

A simple ADT

sealed trait Fillable[A]
case class Filled[A](a: A) extends Fillable[A]
case object Emptied extends Fillable[Nothing]

It looks a little like a home-made version of Scala `Option`, but is certainly not very useful yet. Let’s equip it with a companion object and a couple of methods for transforming the element within a `Fillable`:

sealed trait Fillable[+A] { self =>
  def map[B](f: A => B): Fillable[B] = self match {
    case Filled(a) => Filled(f(a))
    case Emptied   => Emptied
  }

  def flatMap[B](f: A => Fillable[B]): Fillable[B] = self match {
    case Filled(a) => f(a)
    case Emptied   => Emptied
  }
}

object Fillable {
  def apply[A](a: A): Fillable[A] = Filled(a)
}

case class Filled[A](a: A) extends Fillable[A]

case object Emptied extends Fillable[Nothing]

With slightly different signatures, methods `map` and `flatMap` are now available for transforming the element “contained” within a `Fillable`.

A couple of quick notes:

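  • Fillable is now made covariant (`Fillable[+A]`) so that `Emptied`, being a `Fillable[Nothing]`, is a subtype of every `Fillable[A]`. This is what lets `map/flatMap` return `Emptied` as a `Fillable[B]`, and lets a `List[Fillable[String]]` contain `Emptied`.
  • Using the `self =>` alias isn’t necessary here (`this` would do), but is rather a personal coding style preference.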

Testing the ADT:

val listF: List[Fillable[String]] = List(Filled("a"), Filled("bb"), Emptied, Filled("dddd"))

Filled("bb").map(_.length)
// res1: Fillable[Int] = Filled(2)

Filled("bb").flatMap(s => Fillable(s.length))
// res2: Fillable[Int] = Filled(2)

listF.map(_.map(_.length))
// res3: List[Fillable[Int]] = List(Filled(1), Filled(2), Emptied, Filled(4))
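Since `Fillable` now has both `map` and `flatMap`, it should also work in a Scala for-comprehension, which the compiler desugars into `flatMap`/`map` calls. A minimal sketch, repeating the definitions above so it runs on its own; `both` and `withEmpty` are illustrative names:

```scala
sealed trait Fillable[+A] { self =>
  def map[B](f: A => B): Fillable[B] = self match {
    case Filled(a) => Filled(f(a))
    case Emptied   => Emptied
  }

  def flatMap[B](f: A => Fillable[B]): Fillable[B] = self match {
    case Filled(a) => f(a)
    case Emptied   => Emptied
  }
}

object Fillable {
  def apply[A](a: A): Fillable[A] = Filled(a)
}

case class Filled[A](a: A) extends Fillable[A]

case object Emptied extends Fillable[Nothing]

// Desugars to: Filled("a").flatMap(x => Fillable("bb").map(y => x + y))
val both: Fillable[String] = for {
  x <- Filled("a")
  y <- Fillable("bb")
} yield x + y

// An Emptied generator short-circuits the whole expression to Emptied
val withEmpty: Fillable[String] = for {
  x <- Filled("a")
  y <- (Emptied: Fillable[String])
} yield x + y
```

Here `both` is `Filled("abb")` and `withEmpty` is `Emptied`. No `withFilter` is needed because the generators use plain identifier patterns.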

Sequencing a collection of Fillables

Let’s assemble method `sequence`, which will reside within the companion object. Looking at the signature of the method to be defined:

  def sequence[A](listFA: List[Fillable[A]]): Fillable[List[A]]

it seems logical to aggregate a `List` from scratch within a `Fillable` using Scala `fold`. However, iteratively aggregating a list out of elements from within their individual “containers” isn’t as trivial as it may seem. Had there been methods like `get/getOrElse` that unwrap a `Fillable` to obtain the contained element, it would have been straightforward, although an implementation leveraging a getter method would require a default value for the contained element to handle the `Emptied` case.
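For contrast, an “unwrap” route doesn’t strictly need a getter if we pattern-match on each `Fillable` directly, which also avoids the default value. The sketch below is a swapped-in technique, not the approach taken in this post; `sequenceByMatch` is a hypothetical name, and like the version developed next it skips `Emptied` elements:

```scala
sealed trait Fillable[+A]
case class Filled[A](a: A) extends Fillable[A]
case object Emptied extends Fillable[Nothing]

// Hypothetical alternative: unwrap each Fillable by pattern matching instead of
// a getter, so no default value is needed. Emptied elements are simply skipped.
def sequenceByMatch[A](listFA: List[Fillable[A]]): Fillable[List[A]] =
  Filled(listFA.foldRight(List.empty[A]) {
    case (Filled(a), acc) => a :: acc
    case (Emptied, acc)   => acc
  })

val listF: List[Fillable[String]] = List(Filled("a"), Filled("bb"), Emptied, Filled("dddd"))

val sequencedByMatch: Fillable[List[String]] = sequenceByMatch(listF)
```

This relies on knowing the concrete subtypes of `Fillable`, whereas the `map/flatMap` approach only relies on those two methods.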

One approach to implementing `sequence` using only `map/flatMap` is, within the `fold` operation, to first `map` each `Fillable` element of the input `List` into a list-push function for the element’s contained value, and then `flatMap` the result to aggregate the `List` by applying the list-push functions within the `Fillable` container one element at a time:

  def pushToList[A](a: A)(la: List[A]): List[A] = a :: la

  def sequence[A](listFA: List[Fillable[A]]): Fillable[List[A]] =
    listFA.foldRight(Fillable(List.empty[A])){ (fa, acc) =>
      fa match {
        case Emptied => acc
        case _       => fa.map(pushToList).flatMap(acc.map)
      }
    }

Note that `pushToList` passed to `map` is eta-expanded into a curried function that takes an element of type `A` and returns a `List[A] => List[A]` function. The expression `fa.map(pushToList).flatMap(acc.map)` is just shorthand for:

  fa.map(a => (la: List[A]) => pushToList(a)(la)).flatMap(fn => acc.map(fn))

In essence, the first `map` transforms the element within each `Fillable` of the input list into a list-push function for that specific element, and the `flatMap` passes each list-push function to the inner `acc.map`, which applies it to the list aggregated so far inside the resulting `Fillable`.
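To make the mechanics concrete, here is a minimal sketch that unrolls the `foldRight` by hand for the two-element list `List(Filled("a"), Filled("bb"))`, repeating the post’s definitions so it is self-contained. The intermediate names (`pushA`, `acc0`, `pushBB`, `acc1`, `acc2`) are illustrative only:

```scala
sealed trait Fillable[+A] { self =>
  def map[B](f: A => B): Fillable[B] = self match {
    case Filled(a) => Filled(f(a))
    case Emptied   => Emptied
  }

  def flatMap[B](f: A => Fillable[B]): Fillable[B] = self match {
    case Filled(a) => f(a)
    case Emptied   => Emptied
  }
}

object Fillable {
  def apply[A](a: A): Fillable[A] = Filled(a)

  def pushToList[A](a: A)(la: List[A]): List[A] = a :: la
}

case class Filled[A](a: A) extends Fillable[A]

case object Emptied extends Fillable[Nothing]

// With only its first argument list applied, pushToList("a") is a List[String] => List[String]
val pushA: List[String] => List[String] = Fillable.pushToList("a")

// foldRight over List(Filled("a"), Filled("bb")), unrolled by hand.
// Initial accumulator:
val acc0: Fillable[List[String]] = Fillable(List.empty[String])

// Rightmost element first: map wraps the "push bb" function inside a Fillable ...
val pushBB: Fillable[List[String] => List[String]] = Filled("bb").map(Fillable.pushToList)

// ... then flatMap hands that function to acc0.map, which pushes "bb" onto the inner list
val acc1: Fillable[List[String]] = pushBB.flatMap(acc0.map)

// Next element to the left, "a", using the same expression the fold uses
val acc2: Fillable[List[String]] = Filled("a").map(Fillable.pushToList).flatMap(acc1.map)
```

`acc1` is `Filled(List("bb"))` and `acc2` is `Filled(List("a", "bb"))`, which is exactly what `foldRight` accumulates step by step.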

Traversing the Fillable collection

Next, we’re going to define method `traverse` with the following signature within the companion object:

  def traverse[A, B](list: List[A])(f: A => Fillable[B]): Fillable[List[B]]

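In case it doesn’t seem obvious from the method signatures, `sequence` is really just a special case of `traverse` in which the input elements are themselves of type `Fillable[A]` and `f` is the identity function, i.e. `sequence(listFA)` is equivalent to `traverse(listFA)(fa => fa)`.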

Similar to the way `sequence` is implemented, we’ll also use `fold` for iteratively aggregating the resulting list. Since applying the provided function `f` to an element of type `A` yields a `Fillable[B]`, we’re essentially dealing with the same problem we did for `sequence`, except that type `A` is now replaced with type `B`.

  def traverse[A, B](list: List[A])(f: A => Fillable[B]): Fillable[List[B]] =
    list.foldRight(Fillable(List.empty[B])) { (a, acc) =>
      val fb = f(a)
      fb match {
        case Emptied => acc
        case _ => fb.map(pushToList).flatMap(acc.map)
      }
    }
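With `traverse` in place, the claim that `sequence` is a special case of it can be checked with a sketch like the following, which reuses the `traverse` above and adds a hypothetical `sequenceViaTraverse` helper to the companion object:

```scala
sealed trait Fillable[+A] { self =>
  def map[B](f: A => B): Fillable[B] = self match {
    case Filled(a) => Filled(f(a))
    case Emptied   => Emptied
  }

  def flatMap[B](f: A => Fillable[B]): Fillable[B] = self match {
    case Filled(a) => f(a)
    case Emptied   => Emptied
  }
}

object Fillable {
  def apply[A](a: A): Fillable[A] = Filled(a)

  def pushToList[A](a: A)(la: List[A]): List[A] = a :: la

  // traverse as implemented in this post
  def traverse[A, B](list: List[A])(f: A => Fillable[B]): Fillable[List[B]] =
    list.foldRight(Fillable(List.empty[B])) { (a, acc) =>
      val fb = f(a)
      fb match {
        case Emptied => acc
        case _       => fb.map(pushToList).flatMap(acc.map)
      }
    }

  // Hypothetical helper: sequence as traverse over elements that are already
  // Fillables, with the identity function as f
  def sequenceViaTraverse[A](listFA: List[Fillable[A]]): Fillable[List[A]] =
    traverse(listFA)(fa => fa)
}

case class Filled[A](a: A) extends Fillable[A]

case object Emptied extends Fillable[Nothing]

val listF: List[Fillable[String]] = List(Filled("a"), Filled("bb"), Emptied, Filled("dddd"))

val viaTraverse: Fillable[List[String]] = Fillable.sequenceViaTraverse(listF)
```

`viaTraverse` comes out as `Filled(List("a", "bb", "dddd"))`, the same result `sequence` produces for this input.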

Putting everything together:

sealed trait Fillable[+A] { self =>
  def map[B](f: A => B): Fillable[B] = self match {
    case Filled(a) => Filled(f(a))
    case Emptied   => Emptied
  }

  def flatMap[B](f: A => Fillable[B]): Fillable[B] = self match {
    case Filled(a) => f(a)
    case Emptied   => Emptied
  }
}

object Fillable {
  def apply[A](a: A): Fillable[A] = Filled(a)

  def pushToList[A](a: A)(la: List[A]): List[A] = a :: la

  def sequence[A](listFA: List[Fillable[A]]): Fillable[List[A]] =
    listFA.foldRight(Fillable(List.empty[A])){ (fa, acc) =>
      fa match {
        case Emptied => acc
        case _       => fa.map(pushToList).flatMap(acc.map)
      }
    }

  def traverse[A, B](list: List[A])(f: A => Fillable[B]): Fillable[List[B]] =
    list.foldRight(Fillable(List.empty[B])) { (a, acc) =>
      val fb = f(a)
      fb match {
        case Emptied => acc
        case _ => fb.map(pushToList).flatMap(acc.map)
      }
    }
}

case class Filled[A](a: A) extends Fillable[A]

case object Emptied extends Fillable[Nothing]

Testing with the newly created methods:

val listF: List[Fillable[String]] = List(Filled("a"), Filled("bb"), Emptied, Filled("dddd"))

Fillable.sequence(listF)
// res1: Fillable[List[String]] = Filled(List(a, bb, dddd))

val list: List[String] = List("a", "bb", "ccc", "dddd")

Fillable.traverse[String, Int](list)(a => Fillable(a.length))
// res2: Fillable[List[Int]] = Filled(List(1, 2, 3, 4))

Fillable.traverse[String, Int](list)(_ => Emptied)
// res3: Fillable[List[Int]] = Filled(List())
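As the last result shows, this `sequence`/`traverse` silently drops `Emptied` elements rather than failing as a whole, unlike `Future.sequence`, whose resulting future fails if any of the input futures fails. A minimal sketch of an all-or-nothing variant, under the hypothetical name `traverseAll`, simply omits the `case Emptied => acc` branch so that a single `Emptied` propagates through `map/flatMap`. The `parse` function is likewise made up for illustration:

```scala
sealed trait Fillable[+A] { self =>
  def map[B](f: A => B): Fillable[B] = self match {
    case Filled(a) => Filled(f(a))
    case Emptied   => Emptied
  }

  def flatMap[B](f: A => Fillable[B]): Fillable[B] = self match {
    case Filled(a) => f(a)
    case Emptied   => Emptied
  }
}

object Fillable {
  def apply[A](a: A): Fillable[A] = Filled(a)

  def pushToList[A](a: A)(la: List[A]): List[A] = a :: la

  // traverse as implemented in this post: Emptied results are skipped
  def traverse[A, B](list: List[A])(f: A => Fillable[B]): Fillable[List[B]] =
    list.foldRight(Fillable(List.empty[B])) { (a, acc) =>
      val fb = f(a)
      fb match {
        case Emptied => acc
        case _       => fb.map(pushToList).flatMap(acc.map)
      }
    }

  // Hypothetical all-or-nothing variant: an Emptied from f turns the accumulator
  // into Emptied, and map/flatMap keep it Emptied for all remaining elements
  def traverseAll[A, B](list: List[A])(f: A => Fillable[B]): Fillable[List[B]] =
    list.foldRight(Fillable(List.empty[B])) { (a, acc) =>
      f(a).map(pushToList).flatMap(acc.map)
    }
}

case class Filled[A](a: A) extends Fillable[A]

case object Emptied extends Fillable[Nothing]

// Hypothetical parser: Filled for non-negative integer strings, Emptied otherwise
def parse(s: String): Fillable[Int] =
  if (s.nonEmpty && s.forall(_.isDigit)) Filled(s.toInt) else Emptied

val skipping: Fillable[List[Int]] = Fillable.traverse(List("1", "x", "3"))(parse)
val allOrNothing: Fillable[List[Int]] = Fillable.traverseAll(List("1", "x", "3"))(parse)
val allValid: Fillable[List[Int]] = Fillable.traverseAll(List("1", "2"))(parse)
```

Here `skipping` is `Filled(List(1, 3))`, `allOrNothing` is `Emptied`, and `allValid` is `Filled(List(1, 2))`. Which behavior is preferable depends on whether an `Emptied` element should be treated as “absent” or as a failure.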