Tag Archives: tail recursion

Trampolining with Scala TailCalls

It’s pretty safe to assert that virtually all software engineers come across stack overflow problems at some point in their career. For a computational task that involves recursion with complex programming logic, keeping the growth of stack frames under control can be challenging.

What is trampolining?

As outlined in the Wikipedia article about trampolines, the term means different things in different programming paradigms. In the Scala world, trampolining is a means of making a recursive function stack-safe by reformulating its recursive logic as tail calls among functional components wrapped in objects, so that the computation runs in the JVM heap memory instead of consuming stack frames.

A classic example for illustrating stack overflow issues is the good old Factorial.

def naiveFactorial(n: Int): BigInt = {
  require(n >= 0, "N must be a non-negative integer!")
  if (n == 0)
    1
  else
    n * naiveFactorial(n-1)
}

naiveFactorial(10000)  // java.lang.StackOverflowError!

Tail call versus tail recursion

It should be noted that the tail call can, in a way, be viewed as a generalization of tail recursion. A tail-recursive function is a tail-calling function that calls only itself, whereas a general tail-calling function may also perform tail calls of other functions. Tail-recursive functions are not only stack-safe, they oftentimes operate with an efficient O(n) time complexity and O(1) space complexity.
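For contrast with the naive factorial above, here is a tail-recursive sketch (the name factTR and the accumulator shape are ours, mirroring the accumulator style used later in this post). The @scala.annotation.tailrec annotation makes the compiler verify that the recursive call is in tail position, so the function runs in constant stack space:

```scala
def factTR(n: Int): BigInt = {
  require(n >= 0, "N must be a non-negative integer!")
  @scala.annotation.tailrec
  def loop(i: Int, acc: BigInt): BigInt =
    if (i == 0) acc            // the accumulator holds the final product
    else loop(i - 1, acc * i)  // tail position: compiled into a loop
  loop(n, 1)
}

factTR(10000)  // a 35000+ digit number; no StackOverflowError this time
```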

On the other hand, trampolining a function provides a stack-safe solution but does not speed up computation. On the contrary, it often takes longer than the stack-based method to produce a result (until, that is, the stack overflow problem surfaces). That’s largely due to the necessary “bounces” between functions and the object-wrapping overhead.

Factorial with Scala TailCalls

Scala’s standard library provides TailCalls for writing stack-safe recursive functions via trampolining.

The stack-safe factorial function using Scala TailCalls looks like the following:

import scala.util.control.TailCalls._

def trampFactorial(n: Int): TailRec[BigInt] = {
  require(n >= 0, "N must be a non-negative integer!")
  if (n == 0)
    done(1)
  else
    tailcall(trampFactorial(n-1)).map(n * _)
}

trampFactorial(10000).result  // 28462596809170545189064132121198...

Scala TailCalls under the hood

To understand how things work under the hood, let’s extract the skeletal components from the source code of object TailCalls.

object TailCalls {

  sealed abstract class TailRec[+A] {
    final def map[B](f: A => B): TailRec[B] =
      flatMap(a => Call(() => Done(f(a))))

    final def flatMap[B](f: A => TailRec[B]): TailRec[B] = this match {
      case Done(a) => Call(() => f(a))
      case c @ Call(_) => Cont(c, f)
      case c: Cont[a1, _] => Cont(c.a, (x: a1) => c.f(x) flatMap f)
    }

    @scala.annotation.tailrec
    final def result: A = this match {
      case Done(a) => a
      case Call(t) => t().result
      case Cont(a, f) => a match {
        case Done(v) => f(v).result
        case Call(t) => t().flatMap(f).result
        case Cont(b, g) => b.flatMap(x => g(x) flatMap f).result
      }
    }
  }

  protected case class Done[A](value: A) extends TailRec[A]

  protected case class Call[A](rest: () => TailRec[A]) extends TailRec[A]

  protected case class Cont[A, B](a: TailRec[A], f: A => TailRec[B]) extends TailRec[B]

  def tailcall[A](rest: => TailRec[A]): TailRec[A] = Call(() => rest)

  def done[A](result: A): TailRec[A] = Done(result)
}

Object TailCalls consists of the base class TailRec, which provides the transformation methods flatMap and map along with a tail-recursive method result that performs the actual evaluation. Also included are subclasses Done, Call, and Cont, which encapsulate expressions that represent, respectively, the following:

  • literal returning values
  • recursive calls
  • continual transformations

In addition, methods tailcall() and done() are exposed to the users so they don’t need to directly meddle with the subclasses.

In general, to enable trampolining for a given function, make the function’s return type TailRec, then restructure any literal returns, recursive calls and continual transformations within the function body into the corresponding subclasses of TailRec, all in a tail-calling fashion.

Methods done() and tailcall()

As can be seen from the TailCalls source code, all the subclasses (i.e. Done, Call, Cont) are shielded by the protected access modifier. Users are expected to use the interfacing methods done() and tailcall() along with the class methods map and flatMap to formulate a tail-call version of the target function.

  def tailcall[A](rest: => TailRec[A]): TailRec[A] = Call(() => rest)

  def done[A](result: A): TailRec[A] = Done(result)

Method done(value) is simply equivalent to Done(value), whereas tailcall(r: => TailRec) represents Call(() => r). It should be noted that the by-name parameter of tailcall() is critical to ensuring laziness, corresponding to the Function0 parameter of class Call. It’s an integral part of the stack-safe mechanics of trampolining.
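Since tail calls generalize beyond self-recursion, done() and tailcall() also handle mutually recursive functions, which @tailrec cannot help with. Below is a sketch of the classic even/odd pair, essentially the example from the TailCalls scaladoc (the wrapping object Parity is ours):

```scala
import scala.util.control.TailCalls._

object Parity {
  // Neither function calls itself directly, so @tailrec is inapplicable,
  // yet the trampoline keeps the mutual recursion stack-safe.
  def isEven(n: Int): TailRec[Boolean] =
    if (n == 0) done(true) else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailRec[Boolean] =
    if (n == 0) done(false) else tailcall(isEven(n - 1))
}

Parity.isEven(1000000).result  // true, after a million "bounces"
```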

Methods map and flatMap

Back to the factorial function. In the n == 0 case, the return value is a constant, so we wrap it in a Done(). The remaining case involves continual operations, so we wrap the recursive call in a Call() via tailcall(). Obviously, we can’t simply write tailcall(n * trampFactorial(n-1)), since trampFactorial(n-1) is now a TailRec object rather than a number. Instead, we transform it via map with the function t => n * t, similar to how we transform the internal value of an Option or a Future.

But then tailcall(trampFactorial(n-1)).map(n * _) doesn’t look like a tail call. Why is it able to accomplish stack safety? To find out why, let’s look at how map and flatMap are implemented in the Scala TailCalls source code.

Map and flatMap are “trampolines”

    final def map[B](f: A => B): TailRec[B] =
      flatMap(a => Call(() => Done(f(a))))

    final def flatMap[B](f: A => TailRec[B]): TailRec[B] =
      this match {
        case Done(a) => Call(() => f(a))
        case c @ Call(_) => Cont(c, f)
        case c: Cont[a1, b1] => Cont(c.a, (x: a1) => c.f(x) flatMap f)
      }

From the source code, one can see that the implementation of the class method flatMap follows the same underlying principle of trampolining: regardless of which subclass the current TailRec object belongs to, the method returns a tail-calling Call() or Cont(). That makes flatMap itself a trampolining transformation.

As for method map, it’s implemented as a special case of flatMap with the transformation a => Call(() => Done(f(a))), which is a trampolining tail call as well. Thus, both map and flatMap are trampolining transformations. Consequently, a tail expression composed of an arbitrary sequence of transformations with the two methods preserves trampolining. That gives users great flexibility in formulating a tail-calling function.

Evaluating the tail-call function

The tail-calling function returns a TailRec object, but all that sits in the heap memory is an object “wired” for the trampolining mechanism. Nothing gets evaluated until the class method result is called.

    @scala.annotation.tailrec
    final def result: A = this match {
      case Done(a) => a
      case Call(t) => t().result
      case Cont(a, f) => a match {
        case Done(v) => f(v).result
        case Call(t) => t().flatMap(f).result
        case Cont(b, g) => b.flatMap(x => g(x) flatMap f).result
      }
    }

Implemented as a tail-recursive function, method result performs the evaluation by matching the current TailRec[A] object against each of the subclasses and carrying out the corresponding logic to return the resulting value of the type specified as the type parameter A.

If the current TailRec object is a Cont(a, f), which represents a transformation with function f on TailRec object a, the transformation is carried out in accordance with what a is (hence another level of subclass matching). The class method flatMap comes in handy for carrying out the necessary composable transformation, as f’s signature conforms to that of the function taken by flatMap.
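To watch result do its work on a tiny hand-built chain (a minimal sketch):

```scala
import scala.util.control.TailCalls._

// tailcall(done(2)) builds a Call wrapping a Done; map then turns it
// into a Cont under the hood. Nothing runs until result is called.
val tr = tailcall(done(2)).map(_ * 3)

tr.result  // 6
```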

Trampolining Fibonacci

As a side note, Fibonacci generally will not incur stack overflow thanks to its relatively small space complexity, so there is essentially no practical reason to apply trampolining. Nevertheless, it still serves as a good exercise in how to use TailCalls. On the other hand, a tail-recursive version of Fibonacci is highly efficient (see the example in this previous post).

def naiveFibonacci(n: Int): BigInt = {
  require(n >= 0, "N must be a non-negative integer!")
  if (n <= 1)
    n
  else
    naiveFibonacci(n-2) + naiveFibonacci(n-1)
}

import scala.util.control.TailCalls._

def trampFibonacci(n: Int): TailRec[BigInt] = {
  require(n >= 0, "N must be a non-negative integer!")
  if (n <= 1)
    done(n)
  else
    tailcall(trampFibonacci(n-2)).flatMap(t => tailcall(trampFibonacci(n-1).map(t + _)))
}

In case it isn’t obvious, the else-case expression sums the by-definition values of F(n-2) and F(n-1), wrapped in tailcall(F(n-2)) and tailcall(F(n-1)) respectively, via flatMap and map:

    tailcall(trampFibonacci(n-2)).flatMap{ t2 =>
      tailcall(trampFibonacci(n-1).map(t1 =>
        t2 + t1))
    }

which could also be achieved using for-comprehension:

    for {
      t2 <- tailcall(trampFibonacci(n-2))
      t1 <- tailcall(trampFibonacci(n-1))
    }
    yield t2 + t1
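For completeness, here is the for-comprehension version assembled into a standalone whole (the same function as above, merely restated so it can be run as-is):

```scala
import scala.util.control.TailCalls._

def trampFibonacci(n: Int): TailRec[BigInt] = {
  require(n >= 0, "N must be a non-negative integer!")
  if (n <= 1)
    done(n)
  else
    for {
      t2 <- tailcall(trampFibonacci(n - 2))  // F(n-2)
      t1 <- tailcall(trampFibonacci(n - 1))  // F(n-1)
    } yield t2 + t1
}

trampFibonacci(30).result  // 832040
```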

Height of binary search tree

Let’s look at one more trampolining example, this one computing the height of a binary search tree. The following is derived from a bare-bones version of the binary search tree defined in a previous blog post:

import scala.util.control.TailCalls._

sealed trait BSTree[+A] { self =>

  def height: Int = self match {
    case BSLeaf => 0
    case BSBranch(_, _, l, r) => (l.height max r.height) + 1
  }

  def heightTC: Int = self match {
    case BSLeaf => 0
    case BSBranch(_, _, ltree, rtree) => {
      def loop(l: BSTree[A], r: BSTree[A], ht: Int): TailRec[Int] = (l, r) match {
        case (BSLeaf, BSLeaf) => done(ht)
        case (BSBranch(_, _, ll, lr), BSLeaf) => tailcall(loop(ll, lr, ht+1))
        case (BSLeaf, BSBranch(_, _, rl, rr)) => tailcall(loop(rl, rr, ht+1))
        case (BSBranch(_, _, ll, lr), BSBranch(_, _, rl, rr)) =>
          tailcall(loop(ll, lr, ht+1)).flatMap(lHt =>
            tailcall(loop(rl, rr, ht+1)).map(rHt =>
              lHt max rHt
            )
          )
      }
      loop(ltree, rtree, 1).result
    }
  }
}

case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A]

case object BSLeaf extends BSTree[Nothing]

The original height method is kept in trait BSTree as a reference. For the trampolined version heightTC, a helper function loop with an accumulator (i.e. ht) is employed to tally the levels of tree height. Using flatMap and map (or, equivalently, a for-comprehension), the main recursive tracing of the tree height follows a tactic similar to that of the trampolined Fibonacci function.

Test running heightTC:

BSLeaf.heightTC  // 0

BSBranch(30, 1).heightTC  // 1

BSBranch(30, 1,
    BSBranch(10, 1,
      BSLeaf,
      BSBranch(20, 2)
    ),
    BSLeaf
  ).
  heightTC  // 3

/*
           30
        /
    10 
       \
        20
*/

BSBranch(30,1,
    BSBranch(10, 1,
      BSLeaf,
      BSBranch(20, 2)
    ),
    BSBranch(50, 2,
      BSBranch(40, 1),
      BSBranch(70, 1,
        BSBranch(60, 2),
        BSBranch(80, 1)
      )
    )
  ).
  heightTC  // 4

While at it, a tail-recursive version of the tree-height method can be created with a slightly different approach. To achieve tail recursion, a recursive function threads a Scala List of tree-node/height tuples, along with the maximum tree height seen so far, through its parameters, as shown below.

sealed trait BSTree[+A] { self =>
  // ...

  def heightTR: Int = self match {
    case BSLeaf => 0
    case BSBranch(_, _, lt, rt) => {
      @scala.annotation.tailrec
      def loop(list: List[(BSTree[A], Int)], maxHt: Int): Int = list match {
        case Nil =>
          maxHt
        case head :: tail =>
          val (tree, ht) = head
          tree match {
            case BSLeaf =>
              loop(tail, maxHt)
            case BSBranch(_, _, l, r) =>
              loop((l, ht+1) :: (r, ht+1) :: tail, ht max maxHt)
          }
      }
      loop((self, 1) :: Nil, 0)
    }
  }

  // ...
}
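A quick way to see the payoff of heightTR is a degenerate (left-skewed) tree too deep for the plain height method. The sketch below restates just heightTR and the ADT from above so it runs standalone; the deep-tree construction is our own:

```scala
sealed trait BSTree[+A] { self =>
  def heightTR: Int = self match {
    case BSLeaf => 0
    case BSBranch(_, _, _, _) =>
      // Worklist of (subtree, its depth) pairs; maxHt tracks the deepest
      // branch seen so far. @tailrec guarantees constant stack usage.
      @scala.annotation.tailrec
      def loop(list: List[(BSTree[A], Int)], maxHt: Int): Int = list match {
        case Nil => maxHt
        case (tree, ht) :: tail =>
          tree match {
            case BSLeaf => loop(tail, maxHt)
            case BSBranch(_, _, l, r) =>
              loop((l, ht + 1) :: (r, ht + 1) :: tail, ht max maxHt)
          }
      }
      loop((self, 1) :: Nil, 0)
  }
}

case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A]

case object BSLeaf extends BSTree[Nothing]

// Build a 100,000-node left-leaning chain iteratively (no recursion needed)
val deep = (1 to 100000).foldLeft[BSTree[Int]](BSLeaf)((t, i) => BSBranch(i, 1, t))

deep.heightTR  // 100000; the plain recursive height would overflow the stack here
```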

Implementing Linked List In Scala

In Scala, if you wonder why the standard library doesn’t come with a data structure called `LinkedList`, you may have overlooked the obvious: the collection List is in fact a linked list, although it usually presents itself as a general Seq collection rather than the “mysterious” linked list that exposes its “head” with a hidden “tail” to be revealed only iteratively.

Our ADT: LinkedNode

Perhaps because of its simplicity and dynamicity as a data structure, implementing a linked list remains a popular coding exercise. To implement our own linked list, let’s start with a barebone ADT (algebraic data type) as follows:

sealed trait LinkedNode[+A]
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

If you’re familiar with Scala List, you’ll probably notice that our ADT resembles `List` and its subclasses `Cons` (i.e. `::`) and `Nil` (see source code):

sealed abstract class List[+A] { ... }
final case class ::[+A](override val head: A, var next: List[A]) extends List[A] { ... }
case object Nil extends List[Nothing] { ... }

Expanding LinkedNode

Let’s expand trait `LinkedNode` to create class methods `insertNode`/`deleteNode` at a given index for inserting/deleting a node, `toList` for extracting contained elements into a `List` collection, and `toString` for display:

sealed trait LinkedNode[+A] { self =>

  def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
      case EmptyNode =>
        if (i < idx)
          EmptyNode
        else
          Node(x, EmptyNode)
      case Node(e, nx) =>
        if (i < idx)
          Node(e, loop(nx, i + 1))
        else
          Node(x, ln)
    }
    loop(self, 0)
  }

  def deleteNode(idx: Int): LinkedNode[A] = {
    def loop(ln: LinkedNode[A], i: Int): LinkedNode[A] = ln match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        if (idx == 0)
          nx
        else {
          if (i < idx - 1)
            Node(e, loop(nx, i + 1))
          else
            nx match {
              case EmptyNode => Node(e, EmptyNode)
              case Node(_, nx2) => Node(e, nx2)
            }
        }
    }
    loop(self, 0)
  }

  def toList: List[A] = {
    def loop(ln: LinkedNode[A], acc: List[A]): List[A] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, e :: acc)
    }
    loop(self, List.empty[A]).reverse
  }

  override def toString: String = {
    def loop(ln: LinkedNode[A], acc: String): String = ln match {
      case EmptyNode =>
        acc + "()"
      case Node(e, nx) =>
        loop(nx, acc + e + " -> " )
    }
    loop(self, "")
  }
}

case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]

case object EmptyNode extends LinkedNode[Nothing]

// Test running ...

val ln = List(1, 2, 3, 3, 4).foldRight[LinkedNode[Int]](EmptyNode)(Node(_, _))
// ln: LinkedNode[Int] = 1 -> 2 -> 3 -> 3 -> 4 -> ()

ln.insertNode(9, 2)
// res1: LinkedNode[Int] = 1 -> 2 -> 9 -> 3 -> 3 -> 4 -> ()

ln.deleteNode(1)
// res2: LinkedNode[Int] = 1 -> 3 -> 3 -> 4 -> ()

Note that `LinkedNode` is made covariant in its type parameter. In addition, method `insertNode` takes a type parameter `B` with `A` as its lower bound, because a covariant type parameter cannot appear in a method parameter position, which is contravariant (just as Function1 is contravariant over its parameter type).
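To see the variance mechanics in action, here is a minimal sketch (restating just the ADT and insertNode from above, with example values of our own). Inserting a Double into a LinkedNode[Int] compiles fine, with the result type widened to the least upper bound of the two element types:

```scala
sealed trait LinkedNode[+A] { self =>
  def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
      case EmptyNode   => if (i < idx) EmptyNode else Node(x, EmptyNode)
      case Node(e, nx) => if (i < idx) Node(e, loop(nx, i + 1)) else Node(x, ln)
    }
    loop(self, 0)
  }
}
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

val nums: LinkedNode[Int] = Node(1, Node(2, EmptyNode))

// B is inferred as AnyVal, the least upper bound of Int and Double
val widened: LinkedNode[AnyVal] = nums.insertNode(1.5, 1)
// widened: 1 -> 1.5 -> 2 -> ()
```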

Recursion and pattern matching

A couple of notes on the approach we implement our class methods with:

  1. We use recursive functions to avoid mutable variables. They should be made tail-recursive for optimal performance, but that isn’t the focus of this implementation. If performance is a priority, conventional while-loops with mutable class fields `elem`/`next` would be a more practical option.
  2. Pattern matching is routinely used for handling cases of `Node` versus `EmptyNode`. An alternative approach would be to define fields `elem` and `next` in the base trait and implement class methods accordingly within `Node` and `EmptyNode`.
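As a taste of the mutable alternative mentioned in note 1, here is a hypothetical sketch (MNode and insertM are made-up names, not part of the ADT in this post) that inserts with a plain while-loop and a reassignable next field:

```scala
// A mutable node: `next` is a var, so links can be rewired in place
final class MNode[A](var elem: A, var next: MNode[A] = null)

// Insert x at index idx by walking the nodes iteratively; unlike the
// immutable insertNode, no existing nodes are copied.
def insertM[A](head: MNode[A], x: A, idx: Int): MNode[A] =
  if (idx == 0) new MNode(x, head)
  else {
    var cur = head
    var i = 0
    while (i < idx - 1 && cur.next != null) { cur = cur.next; i += 1 }
    cur.next = new MNode(x, cur.next)
    head
  }

val h = new MNode(1, new MNode(2, new MNode(3)))
insertM(h, 9, 1)  // h is now 1 -> 9 -> 2 -> 3
```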

Finding first/last matching nodes

Next, we add a couple of class methods for finding first/last matching nodes.

  def find[B >: A](x: B): LinkedNode[B] = {
    def loop(ln: LinkedNode[B]): LinkedNode[B] = ln match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        if (e == x) ln else loop(nx)
    }
    loop(self)
  }

  def findLast[B >: A](x: B): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], acc: List[LinkedNode[B]]): List[LinkedNode[B]] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        if (e == x) loop(nx, ln :: acc) else loop(nx, acc)
    }
    loop(self, List.empty[LinkedNode[B]]).headOption match {
      case None => EmptyNode
      case Some(node) => node
    }
  }

Reversing a linked list by groups of nodes

Reversing a given LinkedNode can be accomplished via recursion by cumulatively wrapping the element of each `Node` in a new `Node` with its `next` pointer set to the `Node` created in the previous iteration.

  def reverse: LinkedNode[A] = {
    def loop(ln: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, Node(e, acc))
    }
    loop(self, EmptyNode)
  }

  def reverseK(k: Int): LinkedNode[A] = {
    if (k == 1)
      self
    else
      self.toList.grouped(k).flatMap(_.reverse).
        foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
  }

Method `reverseK` reverses a LinkedNode by groups of `k` elements using a different approach: it extracts the elements into groups of `k`, reverses the elements within each group, and re-wraps each of the flattened elements in a `Node`.

Using LinkedNode as a Stack

For `LinkedNode` to serve as a Stack, we include simple methods `push` and `pop` as follows:

  def push[B >: A](x: B): LinkedNode[B] = Node(x, self)

  def pop: (Option[A], LinkedNode[A]) = self match {
    case EmptyNode => (None, self)
    case Node(e, nx) => (Some(e), nx)
  }

Addendum: implementing map, flatMap, fold

From a different perspective, if we view `LinkedNode` as a collection like a Scala List or Vector, we might crave methods like `map`, `flatMap` and `fold`. Using the same approach of recursion along with pattern matching, it’s rather straightforward to crank them out.

  def map[B](f: A => B): LinkedNode[B] = self match {
    case EmptyNode =>
      EmptyNode
    case Node(e, nx) =>
      Node(f(e), nx.map(f))
  }

  def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = self match {
    case EmptyNode =>
      EmptyNode
    case Node(e, nx) =>
      f(e) match {
        case EmptyNode => nx.flatMap(f)
        case Node(e2, _) => Node(e2, nx.flatMap(f))
      }
  }

  def foldLeft[B](z: B)(f: (B, A) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, f(acc, e))
    }
    loop(self, z)
  }

  def foldRight[B](z: B)(f: (A, B) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, f(e, acc))
    }
    loop(self.reverse, z)
  }
}

Putting everything together

Along with a few additional simple class methods and a factory method wrapped in `LinkedNode`’s companion object, below is the final `LinkedNode` ADT that includes everything described above.

sealed trait LinkedNode[+A] { self =>

  def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
      case EmptyNode =>
        if (i < idx)
          EmptyNode
        else
          Node(x, EmptyNode)
      case Node(e, nx) =>
        if (i < idx)
          Node(e, loop(nx, i + 1))
        else
          Node(x, ln)
    }
    loop(self, 0)
  }

  def deleteNode(idx: Int): LinkedNode[A] = {
    def loop(ln: LinkedNode[A], i: Int): LinkedNode[A] = ln match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        if (idx == 0)
          nx
        else {
          if (i < idx - 1)
            Node(e, loop(nx, i + 1))
          else
            nx match {
              case EmptyNode => Node(e, EmptyNode)
              case Node(_, nx2) => Node(e, nx2)
            }
        }
    }
    loop(self, 0)
  }

  def get(idx: Int): LinkedNode[A] = {
    def loop(ln: LinkedNode[A], count: Int): LinkedNode[A] = {
      if (count < idx) {
        ln match {
          case EmptyNode => EmptyNode
          case Node(_, next) => loop(next, count + 1)
        }
      }
      else
        ln
    }
    loop(self, 0)
  }

  def find[B >: A](x: B): LinkedNode[B] = {
    def loop(ln: LinkedNode[B]): LinkedNode[B] = ln match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        if (e == x) ln else loop(nx)
    }
    loop(self)
  }

  def findLast[B >: A](x: B): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], acc: List[LinkedNode[B]]): List[LinkedNode[B]] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        if (e == x) loop(nx, ln :: acc) else loop(nx, acc)
    }
    loop(self, List.empty[LinkedNode[B]]).headOption match {
      case None => EmptyNode
      case Some(node) => node
    }
  }

  def indexOf[B >: A](x: B): Int = {
    def loop(ln: LinkedNode[B], idx: Int): Int = ln match {
      case EmptyNode =>
        -1
      case Node(e, nx) =>
        if (e == x) idx else loop(nx, idx + 1)
    }
    loop(self, 0)
  }

  def indexLast[B >: A](x: B): Int = {
    def loop(ln: LinkedNode[B], idx: Int, acc: List[Int]): List[Int] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, idx + 1, if (e == x) idx :: acc else acc)
    }
    loop(self, 0, List(-1)).head
  }

  def reverse: LinkedNode[A] = {
    def loop(ln: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, Node(e, acc))
    }
    loop(self, EmptyNode)
  }

  def reverseK(k: Int): LinkedNode[A] = {
    if (k == 1)
      self
    else
      self.toList.grouped(k).flatMap(_.reverse).
        foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
  }

  def toList: List[A] = {
    def loop(ln: LinkedNode[A], acc: List[A]): List[A] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, e :: acc)
    }
    loop(self, List.empty[A]).reverse
  }

  override def toString: String = {
    def loop(ln: LinkedNode[A], acc: String): String = ln match {
      case EmptyNode =>
        acc + "()"
      case Node(e, nx) =>
        loop(nx, acc + e + " -> " )
    }
    loop(self, "")
  }

  // ---- push / pop ----

  def push[B >: A](x: B): LinkedNode[B] = Node(x, self)

  def pop: (Option[A], LinkedNode[A]) = self match {
    case EmptyNode => (None, self)
    case Node(e, nx) => (Some(e), nx)
  }

  // ---- map / flatMap / fold ----

  def map[B](f: A => B): LinkedNode[B] = self match {
    case EmptyNode =>
      EmptyNode
    case Node(e, nx) =>
      Node(f(e), nx.map(f))
  }

  def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = self match {
    case EmptyNode =>
      EmptyNode
    case Node(e, nx) =>
      f(e) match {
        case EmptyNode => nx.flatMap(f)
        case Node(e2, _) => Node(e2, nx.flatMap(f))
      }
  }

  def foldLeft[B](z: B)(f: (B, A) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, f(acc, e))
    }
    loop(self, z)
  }

  def foldRight[B](z: B)(f: (A, B) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, f(e, acc))
    }
    loop(self.reverse, z)
  }
}

object LinkedNode {
  def apply[A](ls: List[A]): LinkedNode[A] =
    ls.foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
}

case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]

case object EmptyNode extends LinkedNode[Nothing]

A cursory test-run …

val ln = LinkedNode(List(1, 2, 2, 3, 4, 4, 5))
// ln: LinkedNode[Int] = 1 -> 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ()

val ln2 = ln.insertNode(9, 3)
// ln2: LinkedNode[Int] = 1 -> 2 -> 2 -> 9 -> 3 -> 4 -> 4 -> 5 -> ()

val ln3 = ln2.deleteNode(4)
// ln3: LinkedNode[Int] = 1 -> 2 -> 2 -> 9 -> 4 -> 4 -> 5 -> ()

ln.reverseK(3)
// res1: LinkedNode[Int] = 2 -> 2 -> 1 -> 4 -> 4 -> 3 -> 5 -> ()

ln.find(4)
// res2: LinkedNode[Int] = 4 -> 4 -> 5 -> ()

ln.indexLast(4)
// res3: Int = 5

// ---- Using LinkedNode like a `stack` ----

ln.push(9)
// res1: LinkedNode[Int] = 9 -> 1 -> 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ()

ln.pop
// res2: (Option[Int], LinkedNode[Int]) = (Some(1), 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ())

// ---- Using LinkedNode like a `Vector` collection ----

ln.map(_ + 10)
// res1: LinkedNode[Int] = 11 -> 12 -> 12 -> 13 -> 14 -> 14 -> 15 -> ()

ln.flatMap(i => Node(s"$i!", EmptyNode))
// res2: LinkedNode[String] = 1! -> 2! -> 2! -> 3! -> 4! -> 4! -> 5! -> ()

ln.foldLeft(0)(_ + _)
// res3: Int = 21

ln.foldRight("")((x, acc) => if (acc.isEmpty) s"$x" else s"$acc|$x")
// res4: String = 5|4|4|3|2|2|1

Fibonacci In Scala: Tailrec, Memoized

One of the most popular number series being used as a programming exercise is undoubtedly the Fibonacci numbers:

F(0) = 0
F(1) = 1
F(n) = F(n-1) + F(n-2)

Perhaps a prominent reason why the Fibonacci sequence is of vast interest in mathematics is the associated Golden Ratio, but I think what makes it a great programming exercise is that despite its simple definition, the sequence’s exponential growth rate presents challenges for implementations with space/time efficiency in mind.

Having seen various ways of implementing methods for the Fibonacci numbers, I thought it might be worth putting them together, from a naive implementation to something more space/time efficient. But first, let’s take a quick look at the computational complexity of Fibonacci.

Fibonacci complexity

If we denote T(n) as the time required to compute F(n), by definition:

T(n) = T(n-1) + T(n-2) + K

where K is the time taken by some simple arithmetic to arrive at F(n) from F(n-1) and F(n-2).

With some approximate mathematical analysis (see this post), it can be shown that the lower and upper bounds of T(n) are O(2^(n/2)) and O(2^n), respectively. For better precision, one can derive a more exact time complexity by solving the associated characteristic equation, `x^2 = x + 1`, which yields x ≈ 1.618, to deduce that:

Time complexity for computing F(n) = O(R^n)

where R ≈ 1.618 is the Golden Ratio.

As for space complexity, if one looks at the recursion tree for computing F(n), it’s pretty clear that its depth is one more than the depth of the tree for F(n-1). Thus, the space required for computing F(n) is proportional to n. In other words:

Space complexity for computing F(n) = O(n)

The relatively small space complexity compared with the exponential time complexity explains why computing a Fibonacci number too far out in the sequence would generally lead to a seemingly endless run rather than an out-of-memory or stack overflow problem.

It’s worth noting, though, that if F(n) is computed via conventional iterations (e.g. a while-loop, or tail recursion, which Scala translates into iterations under the hood), the time complexity is reduced to O(n), proportional to the number of loop cycles, and the space complexity becomes O(1), since no n-dependent extra space is needed beyond the couple of numbers carried across iterations.
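As a concrete illustration of that point, an iterative Fibonacci (a sketch of our own, fibIter) needs only two rolling BigInt accumulators and a counter:

```scala
// O(n) time, O(1) space: only the last two numbers are kept around
def fibIter(n: Int): BigInt = {
  require(n >= 0, "N must be a non-negative integer!")
  var (a, b) = (BigInt(0), BigInt(1))  // F(i) and F(i+1)
  var i = 0
  while (i < n) {
    val t = a + b
    a = b
    b = t
    i += 1
  }
  a
}

fibIter(90)  // 2880067194370816120
```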

Naive Fibonacci

To generate Fibonacci numbers, the most straightforward approach is via a basic recursive function like below:

def fib(n: Int): BigInt = n match {
  case 0 => 0
  case 1 => 1
  case _ => fib(n-2) + fib(n-1)
}

(0 to 10).foreach(n => print(fib(n) + " "))
// 0 1 1 2 3 5 8 13 21 34 55

fib(50)
// res1: BigInt = 12586269025

With such a `naive` recursive function, computing the 50th number, i.e. fib(50), would take minutes on a typical laptop, and an attempt to compute any number much higher up, like fib(90), would take practically forever.

Tail recursive Fibonacci

So, let’s come up with a tail recursive method:

def fibTR(num: Int): BigInt = {
  @scala.annotation.tailrec
  def fibFcn(n: Int, acc1: BigInt, acc2: BigInt): BigInt = n match {
    case 0 => acc1
    case 1 => acc2
    case _ => fibFcn(n - 1, acc2, acc1 + acc2)
  }

  fibFcn(num, 0, 1)
}

As shown above, tail recursion is accomplished by means of a couple of accumulators as parameters for the inner method to recursively carry over the two numbers that precede the current number.

With this tail-recursive version, computing, say, the 90th number finishes instantaneously.

fibTR(90)
// res2: BigInt = 2880067194370816120

Fibonacci in a Scala Stream

Another way of implementing Fibonacci is to define the sequence to be stored in a “lazy” collection, such as a Scala Stream:

val fibS: Stream[BigInt] = 0 #:: fibS.scan(BigInt(1))(_ + _)

fibS(90)
// res3: BigInt = 2880067194370816120

Using method scan, `fibS.scan(BigInt(1))(_ + _)` generates a Stream whose k-th element is 1 plus the sum of the first k elements of fibS. By the well-known Fibonacci identity F(0) + F(1) + … + F(k-1) = F(k+1) - 1, that element is precisely F(k+1), so prepending 0 reproduces the Fibonacci sequence. Since Streams are “lazy”, none of the element values in `fibS` will be evaluated until the element is requested.
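For intuition on what scan does, a quick check on a plain List (the numbers here are arbitrary): the seed comes first, followed by the running combinations:

```scala
// scan emits the seed, then the running sums of the elements
List(1, 2, 3, 4).scan(0)(_ + _)  // List(0, 1, 3, 6, 10)
```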

While at it, there are a couple of other commonly seen Fibonacci implementation variants with Scala Stream:

val fibS: Stream[BigInt] = 0 #:: 1 #:: (fibS zip fibS.tail).map(n => n._1 + n._2)

val fibS: Stream[BigInt] = {
  def fs(prev: BigInt, curr: BigInt): Stream[BigInt] = prev #:: fs(curr, prev + curr)
  fs(0, 1)
}

Scala Stream memoizes by design

These Stream-based Fibonacci implementations perform reasonably well, somewhat comparable to the tail-recursive Fibonacci. But while these Stream implementations all involve recursion, none is tail-recursive. So why don’t they suffer the same performance issue the `naive` Fibonacci implementation does? The short answer is memoization.

Digging into the source code of Scala Stream would reveal that method `#::` (which is wrapped in class ConsWrapper) is defined as:

def #::[B >: A](hd: B): Stream[B] = cons(hd, tl) 

Tracing method `cons` further reveals that the Stream tail is a by-name parameter to class `Cons`, thus ensuring that the concatenation is performed lazily:

final class Cons[+A](hd: A, tl: => Stream[A]) extends Stream[A]

But lazy evaluation via a by-name parameter does not by itself provide memoization. Digging deeper into the source code, one would see that Stream content is iterated through a StreamIterator class defined as follows:

final class StreamIterator[+A] private() extends AbstractIterator[A] with Iterator[A] {
  def this(self: Stream[A]) {
    this()
    these = new LazyCell(self)
  }

  class LazyCell(st: => Stream[A]) {
    lazy val v = st
  }

  private var these: LazyCell = _

  def hasNext: Boolean = these.v.nonEmpty

  def next(): A =
    if (isEmpty) Iterator.empty.next()
    else {
      val cur    = these.v
      val result = cur.head
      these = new LazyCell(cur.tail)
      result
    }

  ...
}

The inner class `LazyCell` not only takes a by-name parameter but, more importantly, holds the Stream represented by the StreamIterator instance in a `lazy val`, which by nature enables memoization by caching the value upon the first (and only the first) evaluation.
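The caching behavior of `lazy val` is easy to demonstrate in isolation (a tiny sketch of our own):

```scala
var evals = 0
lazy val cell = { evals += 1; "computed" }  // not evaluated yet; evals == 0

val first  = cell  // first access runs the block and caches the result
val second = cell  // cached value returned; the block does not run again

evals  // 1
```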

Memoized Fibonacci using a mutable Map

While using a Scala Stream to implement Fibonacci automatically leverages memoization, one could also employ the very feature explicitly without Streams. For instance, by leveraging method getOrElseUpdate of a mutable Map, a `memoize` function can be defined as follows:

// Memoization using mutable Map

def memoize[K, V](f: K => V): K => V = {
  val cache = scala.collection.mutable.Map.empty[K, V]
  k => cache.getOrElseUpdate(k, f(k))
}

For example, the `naive` Fibonacci equipped with memoization via this `memoize` function would instantly become a much more efficient implementation:

val fibM: Int => BigInt = memoize(n => n match {
  case 0 => 0
  case 1 => 1
  case _ => fibM(n-2) + fibM(n-1)
})

fibM(90)
// res4: BigInt = 2880067194370816120

For the tail-recursive Fibonacci `fibTR`, this `memoize` function wouldn’t be applicable, as its inner function `fibFcn` takes accumulators as additional parameters. As for the Stream-based `fibS`, which is already equipped with Stream’s memoization, applying `memoize` wouldn’t produce any significant performance gain.