Scala Binary Search Tree

When I wrote about a Scala linked list implementation a couple of years ago, I also did some quick groundwork for implementing binary search trees (BST). As I was occupied by other R&D projects at the time, it was put aside and has since been patiently awaiting its turn to see the light of day. As much of the code is already there, I’m going to put it up in this blog post along with some narrative remarks.

First, we come up with an ADT (algebraic data type). Let’s call it BSTree: a base trait with a generic type A for the data elements stored in the tree, extended by a case class BSBranch for tree branches and a case object BSLeaf for “null” tree nodes. The ADT’s overall structure resembles the one used in the linked list implementation described in the old post.

ADT BSTree

sealed trait BSTree[+A] { self =>
  // ...

  override def toString: String = self match {
    case BSLeaf => "-"
    case BSBranch(e, c, l, r) => (l, r) match {
      case (BSLeaf, BSLeaf) => s"TN($e[$c], -, -)"
      case (BSBranch(le, lc, _, _), BSLeaf) => s"TN($e[$c], $le[$lc], -)"
      case (BSLeaf, BSBranch(re, rc, _, _)) => s"TN($e[$c], -, $re[$rc])"
      case (BSBranch(le, lc, _, _), BSBranch(re, rc, _, _)) => s"TN($e[$c], $le[$lc], $re[$rc])"
    }
  }
}

case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A]

case object BSLeaf extends BSTree[Nothing]

A few notes:

  • Data elements (of generic type A) are stored only in branches. Elements of the same value go to the same branch, whose count field records how many times the value has been inserted.
  • BSTree is covariant in A; otherwise BSLeaf, a BSTree[Nothing], could not be used as a BSTree[A] for an arbitrary A (e.g. as the default left and right subtrees of BSBranch).
  • A toString method is created for simplified string output of a tree instance.
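To see the first two notes in action, here is a minimal sketch that builds a small tree by hand using BSBranch's default arguments. Thanks to covariance, BSLeaf is accepted wherever a BSTree[Int] is expected. The helper totalCount, which sums the count fields, is a hypothetical addition for illustration and not part of the post's ADT.

```scala
// Minimal copy of the post's ADT (methods omitted) so the sketch compiles on its own.
sealed trait BSTree[+A]
case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A]
case object BSLeaf extends BSTree[Nothing]

// Hypothetical helper: number of stored elements, duplicates included.
def totalCount[A](tree: BSTree[A]): Int = tree match {
  case BSLeaf => 0
  case BSBranch(_, c, l, r) => c + totalCount(l) + totalCount(r)
}

// BSLeaf is a BSTree[Nothing], which conforms to BSTree[Int] because BSTree is covariant.
val empty: BSTree[Int] = BSLeaf

// The value 20 is stored once, with count = 2 standing for two insertions.
val small: BSTree[Int] = BSBranch(30, left = BSBranch(10, right = BSBranch(20, count = 2)))

val n = totalCount(small)  // 4: 30, 10 and two 20s
```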

Populating a BSTree

One of the first things we need is a method to insert tree nodes into an existing BSTree, so we start expanding the base trait with a method insert(). That works for adding one node at a time, but we also need a way to create a BSTree and populate it from a readily available collection of data elements. It makes sense to delegate such a factory method to the companion object BSTree as its apply() method.

sealed trait BSTree[+A] { self =>

  def insert[B >: A : Ordering](elem: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    self match {
      case BSLeaf => BSBranch(elem, 1, BSLeaf, BSLeaf)
      case t @ BSBranch(e, c, l, r) =>
        if (e > elem)
          t.copy(left = l.insert(elem))
        else if (e < elem)
          t.copy(right = r.insert(elem))
        else
          t.copy(count = c + 1)
    }
  }

  // ...
}

object BSTree {
  def apply[A : Ordering](elems: Vector[A]): BSTree[A] =
    elems.foldLeft[BSTree[A]](BSLeaf)(_.insert(_))
}

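Note that type parameter B for insert() needs to be a supertype of A because BSTree is covariant in A: a method parameter is a contravariant position (just as Function1 is contravariant over its parameter type), so A itself cannot appear there, and the lower bound B >: A works around that. In addition, the context bound "B : Ordering" requires an implicit Ordering[B] to be available, so that elements can be compared while traversing the binary search tree.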
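Because the Ordering is resolved implicitly, the same insertion logic works with whatever ordering is in scope. The minimal sketch below is a trimmed copy of the post's insert() (written with an explicit implicit parameter instead of implicitly), plus a hypothetical in-order helper elems and a hypothetical build function. It builds the same data under the default Ordering[Int] and under Ordering[Int].reverse, which mirrors the tree's left/right layout.

```scala
sealed trait BSTree[+A] { self =>
  // Same logic as the post's insert(): B >: A keeps A out of the parameter position.
  def insert[B >: A](elem: B)(implicit ord: Ordering[B]): BSTree[B] = self match {
    case BSLeaf => BSBranch(elem, 1, BSLeaf, BSLeaf)
    case t @ BSBranch(e, c, l, r) =>
      if (ord.gt(e, elem)) t.copy(left = l.insert(elem))
      else if (ord.lt(e, elem)) t.copy(right = r.insert(elem))
      else t.copy(count = c + 1)
  }

  // Hypothetical helper: in-order list of (element, count) pairs.
  def elems: List[(A, Int)] = self match {
    case BSLeaf => Nil
    case BSBranch(e, c, l, r) => l.elems ++ ((e, c) :: r.elems)
  }
}
case class BSBranch[+A](elem: A, count: Int, left: BSTree[A], right: BSTree[A]) extends BSTree[A]
case object BSLeaf extends BSTree[Nothing]

// Hypothetical factory mirroring BSTree.apply(); the Ordering passed in drives every comparison.
def build[A](xs: Seq[A])(implicit ord: Ordering[A]): BSTree[A] =
  xs.foldLeft[BSTree[A]](BSLeaf)((t, x) => t.insert(x))

val data = Seq(30, 10, 50, 20, 20)
val ascending = build(data).elems                          // List((10,1), (20,2), (30,1), (50,1))
val descending = build(data)(Ordering[Int].reverse).elems  // List((50,1), (30,1), (20,2), (10,1))
```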

Testing BSTree.apply():

val tree = BSTree(Vector(30, 10, 50, 20, 40, 70, 60, 80, 20, 50, 60))
// tree: BSTree[Int] = TN(30[1], 10[1], 50[2])

/*
tree:
           30
        /       \
    10            50
       \        /    \
        20     40     70
                    /    \
                   60    80
*/

Tree traversal and finding tree nodes

Next, we need methods for tree traversal and search. For brevity, we include only in-order and level-by-level traversals, along with find() and a height helper.

  def traverseInOrder: Unit = self match {
    case BSLeaf => ()
    case BSBranch(_, _, l, r) =>
      l.traverseInOrder
      println(self)
      r.traverseInOrder
  }

  def traverseByLevel: Unit = {
    // Print the branches sitting exactly `level` steps below `tree`; leaves are skipped.
    def loop(tree: BSTree[A], level: Int): Unit = tree match {
      case BSLeaf =>
      case BSBranch(_, _, l, r) =>
        if (level > 0) {
          loop(l, level - 1)
          loop(r, level - 1)
        }
        else
          print(s"$tree  ")
    }
    for (l <- 0 until self.height) {
      loop(self, l)
      println()
    }
  }

  def find[B >: A : Ordering](elem: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    self match {
      case BSLeaf => BSLeaf
      case t @ BSBranch(e, c, l, r) =>
        if (e < elem)
          r.find(elem)
        else if (e > elem)
          l.find(elem)
        else
          t
    }
  }

  def height: Int = self match {
    case BSLeaf => 0
    case BSBranch(_, _, l, r) => (l.height max r.height) + 1
  }

Using the tree created above:

/*
tree:
           30
        /       \
    10            50
       \        /    \
        20     40     70
                    /    \
                   60    80
*/

tree.height  // 4

tree.traverseInOrder
/*
TN(10[1], -, 20[2])
TN(20[2], -, -)
TN(30[1], 10[1], 50[2])
TN(40[1], -, -)
TN(50[2], 40[1], 70[1])
TN(60[2], -, -)
TN(70[1], 60[2], 80[1])
TN(80[1], -, -)
*/

tree.traverseByLevel
/*
TN(30[1], 10[1], 50[2])
TN(10[1], -, 20[2])  TN(50[2], 40[1], 70[1])
TN(20[2], -, -)  TN(40[1], -, -)  TN(70[1], 60[2], 80[1])
TN(60[2], -, -)  TN(80[1], -, -)
*/
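traverseByLevel re-walks the tree from the root once per level, so the upper levels are visited many times. A common alternative, swapped in here and not the post's method, is a queue-based breadth-first traversal that visits each node once. The minimal sketch below uses scala.collection.immutable.Queue and a hypothetical levels helper that groups the elements by depth instead of printing them.

```scala
import scala.annotation.tailrec
import scala.collection.immutable.Queue

// Minimal copy of the post's ADT (methods omitted) so the sketch compiles on its own.
sealed trait BSTree[+A]
case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A]
case object BSLeaf extends BSTree[Nothing]

// Queue-based breadth-first traversal: every branch is dequeued exactly once and its
// non-leaf children are enqueued one level deeper. Returns the elements level by level.
def levels[A](tree: BSTree[A]): Vector[Vector[A]] = {
  @tailrec
  def loop(queue: Queue[(BSBranch[A], Int)], acc: Vector[Vector[A]]): Vector[Vector[A]] =
    queue.dequeueOption match {
      case None => acc
      case Some(((node, depth), rest)) =>
        // Nodes arrive in non-decreasing depth order: either extend the last level or start a new one.
        val acc2 =
          if (depth < acc.size) acc.updated(depth, acc(depth) :+ node.elem)
          else acc :+ Vector(node.elem)
        val children = List(node.left, node.right).collect {
          case b @ BSBranch(_, _, _, _) => (b, depth + 1)
        }
        loop(rest.enqueueAll(children), acc2)
    }
  tree match {
    case BSLeaf => Vector.empty
    case b @ BSBranch(_, _, _, _) => loop(Queue((b, 0)), Vector.empty)
  }
}

// The post's sample tree, built by hand with BSBranch's default arguments.
val tree: BSTree[Int] =
  BSBranch(30,
    left = BSBranch(10, right = BSBranch(20, 2)),
    right = BSBranch(50, 2,
      left = BSBranch(40),
      right = BSBranch(70, left = BSBranch(60, 2), right = BSBranch(80))))

levels(tree).foreach(level => println(level.mkString("  ")))
// 30
// 10  50
// 20  40  70
// 60  80
```

The grouping matches traverseByLevel's output for the sample tree; the difference is that each node is touched once rather than once per level above it.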

val tree50 = tree.find(50)  // TN(50[2], 40[1], 70[1])
/*
tree50:
           50
        /       \
    40            70
                /    \
               60     80
*/
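find() recurses through r.find(elem) and l.find(elem), which are calls on a different receiver, so scalac does not compile them into a loop and each tree level costs a stack frame. That is harmless for a reasonably balanced tree, but a degenerate tree (e.g. built from already-sorted data) could exhaust the stack if it is deep enough. Here is a minimal sketch of a stack-safe variant, assuming a hypothetical standalone helper findTR with a local @tailrec loop doing the same comparisons, and a hypothetical chain helper that builds such a degenerate tree.

```scala
import scala.annotation.tailrec

// Minimal copy of the post's ADT (methods omitted) so the sketch compiles on its own.
sealed trait BSTree[+A]
case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A]
case object BSLeaf extends BSTree[Nothing]

// Same search as find(), but the recursion is a local @tailrec loop, which the compiler
// turns into a jump, so the depth of the tree does not grow the call stack.
def findTR[A](tree: BSTree[A], elem: A)(implicit ord: Ordering[A]): BSTree[A] = {
  @tailrec
  def loop(t: BSTree[A]): BSTree[A] = t match {
    case BSLeaf => BSLeaf
    case BSBranch(e, _, l, r) =>
      if (ord.lt(e, elem)) loop(r)
      else if (ord.gt(e, elem)) loop(l)
      else t
  }
  loop(tree)
}

// Hypothetical helper: the degenerate tree 1 -> 2 -> ... -> n made only of right children.
def chain(n: Int): BSTree[Int] =
  (n to 1 by -1).foldLeft[BSTree[Int]](BSLeaf)((acc, x) => BSBranch(x, right = acc))

val found = findTR(chain(100000), 100000)  // BSBranch(100000,1,BSLeaf,BSLeaf)
```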

Removing tree nodes

To remove tree nodes holding a specific element value, or a range of element values, we also include the following methods in the base trait.

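Note that delete() may involve a little shuffling of the tree nodes. Once the tree node to be removed is located, a node with at most one child is simply replaced by that child. A node with two children is instead refilled with the element and count of the next-bigger node, i.e. the leftmost node of its right subtree, and that node is then deleted from the right subtree (equivalently, the next-smaller node, i.e. the rightmost node of the left subtree, could be used). Either way, delete() removes a node as a whole, regardless of its count.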

  def delete[B >: A : Ordering](elem: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    def leftMost(tree: BSTree[B], prev: BSTree[B]): BSTree[B] = tree match {
      case BSLeaf => prev
      case t @ BSBranch(_, _, l, _) => leftMost(l, t)
    }
    self match {
      case BSLeaf => BSLeaf
      case t @ BSBranch(e, c, l, r) =>
        if (e < elem)
          t.copy(right = r.delete(elem))
        else if (e > elem)
          t.copy(left = l.delete(elem))
        else {
          if (l == BSLeaf)
            r
          else if (r == BSLeaf)
            l
          else {
            val nextBigger = leftMost(r, r)
            nextBigger match { case BSBranch(enb, cnb, _, _) =>
              t.copy(elem = enb, count = cnb, right = r.delete(enb))
            }
          }
        }
    }
  }

  def trim[B >: A : Ordering](lower: B, upper: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    self match {
      case BSLeaf => BSLeaf
      case t @ BSBranch(e, _, l, r) =>
        val tt = t.copy(left = l.trim(lower, upper), right = r.trim(lower, upper)) 
        if (e < lower)
          tt match { case BSBranch(_, _, _, r) => r }
        else if (e > upper)
          tt match { case BSBranch(_, _, l, _) => l }
        else
          tt
    }
  }

  def cutOut[B >: A : Ordering](lower: B, upper: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    self match {
      case BSLeaf => BSLeaf
      case t @ BSBranch(e, _, l, r) =>
        val tt = t.copy(left = l.cutOut(lower, upper), right = r.cutOut(lower, upper)) 
        if (e >= lower && e <= upper)
          tt.delete(e)  // tt's root still holds e
        else
          tt
    }
  }

Method trim() removes tree nodes whose element values lie below or above the provided range, keeping those within it (both bounds inclusive). Method cutOut() does the opposite by cutting out tree nodes whose values lie within the given range. It involves slightly more work than trim(), as it relies on delete() to remove individual tree nodes.
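trim() as written trims both children before looking at the current element, so it also walks subtrees that are about to be discarded. Since everything in the left subtree of e is smaller than e, a node below lower can only contribute survivors from its right subtree, and symmetrically for a node above upper. The minimal sketch below, a hypothetical trimPruned, recurses only into the side that can survive and should return the same tree as trim().

```scala
// Minimal copy of the post's ADT (methods omitted) so the sketch compiles on its own.
sealed trait BSTree[+A]
case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A]
case object BSLeaf extends BSTree[Nothing]

// Keep only elements within [lower, upper], skipping subtrees that lie entirely outside it.
def trimPruned[A](tree: BSTree[A], lower: A, upper: A)(implicit ord: Ordering[A]): BSTree[A] =
  tree match {
    case BSLeaf => BSLeaf
    case t @ BSBranch(e, _, l, r) =>
      if (ord.lt(e, lower)) trimPruned(r, lower, upper)       // e and its left subtree are too small
      else if (ord.gt(e, upper)) trimPruned(l, lower, upper)  // e and its right subtree are too big
      else t.copy(left = trimPruned(l, lower, upper), right = trimPruned(r, lower, upper))
  }

// The post's sample tree, built by hand.
val tree: BSTree[Int] =
  BSBranch(30,
    left = BSBranch(10, right = BSBranch(20, 2)),
    right = BSBranch(50, 2,
      left = BSBranch(40),
      right = BSBranch(70, left = BSBranch(60, 2), right = BSBranch(80))))

val trimmed = trimPruned(tree, 15, 65)  // root 30 with children 20[2] and 50[2](40, 60[2])
```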

Example:

val treeD50 = tree.delete(50)  // TN(30[1], 10[1], 60[2])
/*
treeD50:
           30
        /       \
    10            60
       \        /    \
        20     40     70
                         \
                          80
*/
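The delete() explanation notes that the next-smaller node, the rightmost node of the left subtree, works just as well as the next-bigger one. Here is a minimal sketch of that variant, assuming a hypothetical deletePred helper (not the post's code); like delete(), it removes the whole node regardless of its count.

```scala
import scala.annotation.tailrec

// Minimal copy of the post's ADT (methods omitted) so the sketch compiles on its own.
sealed trait BSTree[+A]
case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A]
case object BSLeaf extends BSTree[Nothing]

def deletePred[A](tree: BSTree[A], elem: A)(implicit ord: Ordering[A]): BSTree[A] = {
  // Largest element of a non-empty subtree: keep following right children.
  @tailrec
  def rightMost(b: BSBranch[A]): BSBranch[A] = b.right match {
    case BSLeaf => b
    case r @ BSBranch(_, _, _, _) => rightMost(r)
  }
  tree match {
    case BSLeaf => BSLeaf
    case t @ BSBranch(e, _, l, r) =>
      if (ord.lt(e, elem)) t.copy(right = deletePred(r, elem))
      else if (ord.gt(e, elem)) t.copy(left = deletePred(l, elem))
      else (l, r) match {
        case (BSLeaf, _) => r  // at most one child: replace the node with it
        case (_, BSLeaf) => l
        case (lb @ BSBranch(_, _, _, _), _) =>
          // Two children: pull up the in-order predecessor, then remove it from the left subtree.
          val pred = rightMost(lb)
          t.copy(elem = pred.elem, count = pred.count, left = deletePred(l, pred.elem))
      }
  }
}

// The post's sample tree, built by hand.
val tree: BSTree[Int] =
  BSBranch(30,
    left = BSBranch(10, right = BSBranch(20, 2)),
    right = BSBranch(50, 2,
      left = BSBranch(40),
      right = BSBranch(70, left = BSBranch(60, 2), right = BSBranch(80))))

val treeP50 = deletePred(tree, 50)  // 40 takes 50's place, with 70(60, 80) as its right child
```

Where delete(50) promotes 60[2], this variant promotes 40[1]; both results are valid binary search trees holding the same remaining elements.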

val treeT15to65 = tree.trim(15, 65)  // TN(30[1], 20[2], 50[2])
/*
treeT15to65:
           30
        /       \
    20            50
                /    \
               40     60
*/

val treeC25to55 = tree.cutOut(25, 55)  // TN(60[2], 10[1], 70[1])
/*
treeC25to55:
           60
        /       \
    10            70
       \             \
        20            80
*/
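Since delete(), trim() and cutOut() all rebuild parts of the tree, it can be handy to confirm that a result is still a valid binary search tree. A minimal sketch, assuming a hypothetical isValidBST helper: it passes optional lower and upper bounds down the recursion, which catches violations that comparing each node only with its direct children would miss.

```scala
// Minimal copy of the post's ADT (methods omitted) so the sketch compiles on its own.
sealed trait BSTree[+A]
case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A]
case object BSLeaf extends BSTree[Nothing]

// Every element must lie strictly between the bounds inherited from its ancestors
// (duplicates are folded into count, so equal elements also count as a violation),
// and every count must be positive.
def isValidBST[A](tree: BSTree[A], lo: Option[A] = None, hi: Option[A] = None)(
    implicit ord: Ordering[A]): Boolean =
  tree match {
    case BSLeaf => true
    case BSBranch(e, c, l, r) =>
      c > 0 &&
        lo.forall(ord.lt(_, e)) && hi.forall(ord.gt(_, e)) &&
        isValidBST(l, lo, Some(e)) && isValidBST(r, Some(e), hi)
  }

// The cutOut(25, 55) result shown above.
val cut: BSTree[Int] =
  BSBranch(60, 2, left = BSBranch(10, right = BSBranch(20, 2)), right = BSBranch(70, right = BSBranch(80)))

// 40 sits in the left subtree of 30 although 40 > 30; a parent/child-only check would accept it.
val broken: BSTree[Int] = BSBranch(30, left = BSBranch(10, right = BSBranch(40)))
```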

Rebuilding a binary search tree

A highly unbalanced binary search tree defeats the purpose of using such a data structure. One of the most straightforward ways to rebuild a binary search tree is to "unpack" the individual tree nodes of the existing tree by traversing it in order into a sequence (e.g. a Vector or List) of elements, and then to reconstruct a new tree by recursively picking the middle element of each half of that sorted sequence as the subtree root.

  def rebuild: BSTree[A] = {
    def loop(vs: Vector[(A, Int)], lower: Int, upper: Int): BSTree[A] = {
      if (upper >= lower) {
        val m = (lower + upper) / 2
        val (e, c) = vs(m)
        BSBranch(e, c, loop(vs, lower, m-1), loop(vs, m+1, upper))
      }
      else
        BSLeaf
    }
    self match {
      case BSLeaf =>
        BSLeaf
      case _ =>
        val vs = self.toElemVector
        loop(vs, 0, vs.size - 1)
    }
  }

  def isBalanced: Boolean = self match {
    case BSLeaf => true
    case BSBranch(_, _, l, r) =>
      if (math.abs(l.height - r.height) > 1)
        false
      else
        l.isBalanced && r.isBalanced
  }

  def toElemVector: Vector[(A, Int)] = self match {  // in-order
    case BSLeaf =>
      Vector.empty[(A, Int)]
    case BSBranch(e, c, l, r) =>
      (l.toElemVector :+ ((e, c))) ++ r.toElemVector
  }

  def toVector: Vector[BSTree[A]] = self match {  // in-order
    case BSLeaf => Vector.empty[BSTree[A]]
    case BSBranch(_, _, l, r) =>
      (l.toVector :+ self) ++ r.toVector
  }
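toElemVector allocates intermediate vectors at every node via :+ and ++. An accumulator-based variant avoids those intermediate results: it walks the right subtree first and prepends each (element, count) pair to a list exactly once, converting to a Vector at the end. This is a minimal sketch with a hypothetical name, toElemVectorAcc, not a drop-in replacement verified against the post's code.

```scala
// Minimal copy of the post's ADT (methods omitted) so the sketch compiles on its own.
sealed trait BSTree[+A]
case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A]
case object BSLeaf extends BSTree[Nothing]

def toElemVectorAcc[A](tree: BSTree[A]): Vector[(A, Int)] = {
  // Prepend the right subtree's pairs first, then this node, then the left subtree's,
  // so the accumulated list comes out in ascending (in-order) order.
  def go(t: BSTree[A], acc: List[(A, Int)]): List[(A, Int)] = t match {
    case BSLeaf => acc
    case BSBranch(e, c, l, r) => go(l, (e, c) :: go(r, acc))
  }
  go(tree, Nil).toVector
}

// The unbalancedTree example below, built by hand.
val unbalanced: BSTree[Int] =
  BSBranch(20, left = BSBranch(10),
    right = BSBranch(30, 2, right = BSBranch(40, right = BSBranch(50, 2, right = BSBranch(60, 2)))))

val pairs = toElemVectorAcc(unbalanced)  // Vector((10,1), (20,1), (30,2), (40,1), (50,2), (60,2))
```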

Example:

val unbalancedTree = BSTree(Vector(20, 10, 30, 30, 40, 50, 60, 60, 50))
// unbalancedTree: BSTree[Int] = TN(20[1], 10[1], 30[2])

/*
unbalancedTree:
           20
        /       \
    10            30
                     \
                     40
                        \
                         50
                           \
                            60
*/

unbalancedTree.isBalanced  // false

val rebuiltTree = unbalancedTree.rebuild  // TN(30[2], 10[1], 50[2])
/*
rebuiltTree:

           30
        /       \
    10            50
       \        /    \
        20     40     60
*/

rebuiltTree.isBalanced  // true
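isBalanced calls height on both children of every node, and height walks the entire subtree each time, so the same subtrees get measured repeatedly. A common alternative, swapped in here and not the post's code, computes height and balance together in a single post-order pass and stops as soon as an imbalance is found. The helper names balancedHeight and isBalancedOnePass are hypothetical.

```scala
// Minimal copy of the post's ADT (methods omitted) so the sketch compiles on its own.
sealed trait BSTree[+A]
case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A]
case object BSLeaf extends BSTree[Nothing]

// Some(height) if the subtree is balanced, None otherwise. Each node is visited at most once,
// and the for-comprehension skips the right subtree once the left one is already unbalanced.
def balancedHeight[A](tree: BSTree[A]): Option[Int] = tree match {
  case BSLeaf => Some(0)
  case BSBranch(_, _, l, r) =>
    for {
      lh <- balancedHeight(l)
      rh <- balancedHeight(r)
      if math.abs(lh - rh) <= 1
    } yield (lh max rh) + 1
}

def isBalancedOnePass[A](tree: BSTree[A]): Boolean = balancedHeight(tree).isDefined

// unbalancedTree and rebuiltTree from the example above, built by hand.
val unbalanced: BSTree[Int] =
  BSBranch(20, left = BSBranch(10),
    right = BSBranch(30, 2, right = BSBranch(40, right = BSBranch(50, 2, right = BSBranch(60, 2)))))
val rebuilt: BSTree[Int] =
  BSBranch(30, 2, left = BSBranch(10, right = BSBranch(20)),
    right = BSBranch(50, 2, left = BSBranch(40), right = BSBranch(60, 2)))

balancedHeight(rebuilt)     // Some(3)
balancedHeight(unbalanced)  // None
```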

Thoughts on the ADT

An alternative design is to declare the class fields and methods in the BSTree base trait and put their specific implementations in the subclasses BSBranch and BSLeaf, thus eliminating the need for the boilerplate pattern matching on the subclasses. There is also the benefit of making class fields like left and right referenceable from the base trait, though they would need to be wrapped in Options, with value None for BSLeaf.
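The alternative design just described can be sketched as follows: the base trait only declares members, and BSBranch and BSLeaf each implement them, so dynamic dispatch replaces the pattern matching. This is a minimal, hypothetical sketch with just a few members (leftOpt, rightOpt, height, insert); the names and signatures are illustrative and not part of the post's ADT.

```scala
sealed trait BSTree[+A] {
  def leftOpt: Option[BSTree[A]]
  def rightOpt: Option[BSTree[A]]
  def height: Int
  def insert[B >: A](x: B)(implicit ord: Ordering[B]): BSTree[B]
}

case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A] {
  def leftOpt: Option[BSTree[A]] = Some(left)
  def rightOpt: Option[BSTree[A]] = Some(right)
  def height: Int = (left.height max right.height) + 1
  // Branch-specific insert logic: no match on the subclasses is needed.
  def insert[B >: A](x: B)(implicit ord: Ordering[B]): BSTree[B] =
    if (ord.lt(x, elem)) copy(left = left.insert(x))
    else if (ord.gt(x, elem)) copy(right = right.insert(x))
    else copy(count = count + 1)
}

case object BSLeaf extends BSTree[Nothing] {
  def leftOpt: Option[BSTree[Nothing]] = None
  def rightOpt: Option[BSTree[Nothing]] = None
  def height: Int = 0
  def insert[B](x: B)(implicit ord: Ordering[B]): BSTree[B] = BSBranch(x)
}

val sample = List(30, 10, 50, 20, 20).foldLeft[BSTree[Int]](BSLeaf)((acc, x) => acc.insert(x))
sample.leftOpt.flatMap(_.rightOpt)  // Some(BSBranch(20,2,BSLeaf,BSLeaf))
```

The trade-off is that the logic of each operation is now split across two classes instead of living in one place.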

The existing ADT, on the other hand, has the advantage of defining all the binary tree functions within the base trait once and for all. If left and right need to be referenceable from the BSTree base trait, one can define something like below within the trait.

  def leftIfAny: Option[BSTree[A]] = self match {
    case BSLeaf => None
    case BSBranch(_, _, l, _) => Some(l)
  }

  def rightIfAny: Option[BSTree[A]] = self match {
    case BSLeaf => None
    case BSBranch(_, _, _, r) => Some(r)
  }

Example:

tree.leftIfAny
// Some(TN(10[1], -, 20[2]))

tree.rightIfAny.flatMap(_.rightIfAny)
// Some(TN(70[1], 60[2], 80[1]))

Then there is also the non-idiomatic approach, commonly seen in Java implementations, of using mutable class fields in a single tree class, like below:

object SimpleBST {
  class TreeNode[A](_elem: A, _left: TreeNode[A] = null, _right: TreeNode[A] = null) {
    var elem: A = _elem
    var count: Int = 1
    var left: TreeNode[A] = _left
    var right: TreeNode[A] = _right
    // `this` is never null inside its own method, so only the children need null checks.
    override def toString: String = (left, right) match {
      case (null, null) => s"TN($elem[$count], -, -)"
      case (l,    null) => s"TN($elem[$count], ${l.elem}[${l.count}], -)"
      case (null, r)    => s"TN($elem[$count], -, ${r.elem}[${r.count}])"
      case (l,    r)    => s"TN($elem[$count], ${l.elem}[${l.count}], ${r.elem}[${r.count}])"
    }
  }

  // Methods for node addition, deletion, etc ...
}

Addendum: Complete source code of the BSTree ADT

sealed trait BSTree[+A] { self =>

  def traverseInOrder: Unit = self match {
    case BSLeaf => ()
    case BSBranch(_, _, l, r) =>
      l.traverseInOrder
      println(self)
      r.traverseInOrder
  }

  def traverseByLevel: Unit = {
    // Print the branches sitting exactly `level` steps below `tree`; leaves are skipped.
    def loop(tree: BSTree[A], level: Int): Unit = tree match {
      case BSLeaf =>
      case BSBranch(_, _, l, r) =>
        if (level > 0) {
          loop(l, level - 1)
          loop(r, level - 1)
        }
        else
          print(s"$tree  ")
    }
    for (l <- 0 until self.height) {
      loop(self, l)
      println()
    }
  }

  def find[B >: A : Ordering](elem: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    self match {
      case BSLeaf => BSLeaf
      case t @ BSBranch(e, c, l, r) =>
        if (e < elem)
          r.find(elem)
        else if (e > elem)
          l.find(elem)
        else
          t
    }
  }

  def height: Int = self match {
    case BSLeaf => 0
    case BSBranch(_, _, l, r) => (l.height max r.height) + 1
  }

  def insert[B >: A : Ordering](elem: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    self match {
      case BSLeaf => BSBranch(elem, 1, BSLeaf, BSLeaf)
      case t @ BSBranch(e, c, l, r) =>
        if (e > elem)
          t.copy(left = l.insert(elem))
        else if (e < elem)
          t.copy(right = r.insert(elem))
        else
          t.copy(count = c + 1)
    }
  }

  def delete[B >: A : Ordering](elem: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    def leftMost(tree: BSTree[B], prev: BSTree[B]): BSTree[B] = tree match {
      case BSLeaf => prev
      case t @ BSBranch(_, _, l, _) => leftMost(l, t)
    }
    self match {
      case BSLeaf => BSLeaf
      case t @ BSBranch(e, c, l, r) =>
        if (e < elem)
          t.copy(right = r.delete(elem))
        else if (e > elem)
          t.copy(left = l.delete(elem))
        else {
          if (l == BSLeaf)
            r
          else if (r == BSLeaf)
            l
          else {
            val nextBigger = leftMost(r, r)
            nextBigger match { case BSBranch(enb, cnb, _, _) =>
              t.copy(elem = enb, count = cnb, right = r.delete(enb))
            }
          }
        }
    }
  }

  def trim[B >: A : Ordering](lower: B, upper: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    self match {
      case BSLeaf => BSLeaf
      case t @ BSBranch(e, _, l, r) =>
        val tt = t.copy(left = l.trim(lower, upper), right = r.trim(lower, upper)) 
        if (e < lower)
          tt match { case BSBranch(_, _, _, r) => r }
        else if (e > upper)
          tt match { case BSBranch(_, _, l, _) => l }
        else
          tt
    }
  }

  def cutOut[B >: A : Ordering](lower: B, upper: B): BSTree[B] = {
    val ord = implicitly[Ordering[B]]
    import ord._
    self match {
      case BSLeaf => BSLeaf
      case t @ BSBranch(e, _, l, r) =>
        val tt = t.copy(left = l.cutOut(lower, upper), right = r.cutOut(lower, upper)) 
        if (e >= lower && e <= upper)
          tt.delete(e)  // tt's root still holds e
        else
          tt
    }
  }

  def rebuild: BSTree[A] = {
    // Recursively make the middle (element, count) pair of each range the subtree root.
    def loop(vs: Vector[(A, Int)], lower: Int, upper: Int): BSTree[A] = {
      if (upper >= lower) {
        val m = (lower + upper) / 2
        val (e, c) = vs(m)
        BSBranch(e, c, loop(vs, lower, m-1), loop(vs, m+1, upper))
      }
      else
        BSLeaf
    }
    self match {
      case BSLeaf =>
        BSLeaf
      case _ =>
        val vs = self.toElemVector
        loop(vs, 0, vs.size - 1)
    }
  }

  def isBalanced: Boolean = self match {
    case BSLeaf => true
    case BSBranch(_, _, l, r) =>
      if (math.abs(l.height - r.height) > 1)
        false
      else
        l.isBalanced && r.isBalanced
  }

  def toElemVector: Vector[(A, Int)] = self match {  // in-order, with counts
    case BSLeaf =>
      Vector.empty[(A, Int)]
    case BSBranch(e, c, l, r) =>
      (l.toElemVector :+ ((e, c))) ++ r.toElemVector
  }

  def toVector: Vector[BSTree[A]] = self match {  // in-order
    case BSLeaf => Vector.empty[BSTree[A]]
    case BSBranch(_, _, l, r) =>
      (l.toVector :+ self) ++ r.toVector
  }

  override def toString: String = self match {
    case BSLeaf => "-"
    case BSBranch(e, c, l, r) => (l, r) match {
      case (BSLeaf, BSLeaf) => s"TN($e[$c], -, -)"
      case (BSBranch(le, lc, _, _), BSLeaf) => s"TN($e[$c], $le[$lc], -)"
      case (BSLeaf, BSBranch(re, rc, _, _)) => s"TN($e[$c], -, $re[$rc])"
      case (BSBranch(le, lc, _, _), BSBranch(re, rc, _, _)) => s"TN($e[$c], $le[$lc], $re[$rc])"
    }
  }
}

object BSTree {
  def apply[A : Ordering](elems: Vector[A]): BSTree[A] =
    elems.foldLeft[BSTree[A]](BSLeaf)(_.insert(_))
}

case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A]

case object BSLeaf extends BSTree[Nothing]
