It’s pretty safe to say that most software engineers have run into stack overflow problems at some point in their careers. For a computational task that involves recursion with complex programming logic, keeping the stack depth under control can be challenging.
What is trampolining?
As outlined in the Wikipedia article about trampolines, the term can mean different things in different programming paradigms. In the Scala world, trampolining is a means of making a function with recursive programming logic stack-safe by recasting the recursive logic flow as tail calls among functional components wrapped in objects. The pending work then lives in objects on the JVM heap rather than in an ever-deepening chain of stack frames.
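To make the idea concrete before turning to the standard library, here is a minimal hand-rolled trampoline. It is only a sketch: the names Bounce, Finished, More and run are made up for illustration and are not part of any Scala API.

```scala
import scala.annotation.tailrec

// Hypothetical types for illustration only; this is not the TailCalls API.
sealed trait Bounce[A]
final case class Finished[A](value: A) extends Bounce[A]
final case class More[A](next: () => Bounce[A]) extends Bounce[A]

// The "trampoline": a loop that keeps bouncing until a final value shows up.
// Every step returns to this loop, so the stack depth stays constant.
@tailrec
def run[A](bounce: Bounce[A]): A = bounce match {
  case Finished(value) => value
  case More(next)      => run(next())
}

// Mutually recursive functions. Each one returns a description of the next
// call (held on the heap) instead of making that call directly.
def isEven(n: Int): Bounce[Boolean] =
  if (n == 0) Finished(true) else More(() => isOdd(n - 1))

def isOdd(n: Int): Bounce[Boolean] =
  if (n == 0) Finished(false) else More(() => isEven(n - 1))

val evenCheck: Boolean = run(isEven(1000000))
```

Written as plain mutual recursion, a million nested calls would be expected to overflow the stack on default JVM settings. Here each call only allocates a small More object that the loop in run consumes right away.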
A classic example for illustrating stack overflow issues is the good old Factorial.
    def naiveFactorial(n: Int): BigInt = {
      require(n >= 0, "N must be a non-negative integer!")
      if (n == 0) 1 else n * naiveFactorial(n-1)
    }

    naiveFactorial(10000)  // java.lang.StackOverflowError!
Tail call versus tail recursion
It should be noted that a tail call can be viewed as a generalization of tail recursion. A tail-recursive function is a tail-calling function that calls only itself, whereas a general tail-calling function may make tail calls to other functions. Tail-recursive functions are not only stack-safe; they can often run in O(n) time with O(1) stack space.
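For comparison with the naive version, a tail-recursive factorial might look like the sketch below. The name factorialTR and the accumulator-based helper loop are our own choices for illustration.

```scala
import scala.annotation.tailrec

def factorialTR(n: Int): BigInt = {
  require(n >= 0, "N must be a non-negative integer!")

  // The recursive call is the very last action, so the compiler can turn it
  // into a loop. @tailrec makes compilation fail if that is not possible.
  @tailrec
  def loop(i: Int, acc: BigInt): BigInt =
    if (i <= 1) acc else loop(i - 1, acc * i)

  loop(n, BigInt(1))
}
```

Since the running product is carried in the accumulator, no work is left pending after the recursive call, and factorialTR(10000) completes without growing the stack.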
On the other hand, trampolining a function provides a stack-safe solution but does not speed up computation. On the contrary, it often takes longer than the stack-based method to produce a result (that is, until the stack overflow problem surfaces). That’s largely due to the necessary “bounces” between functions and the object-wrapping.
Factorial with Scala TailCalls
Scala’s standard library provides TailCalls for writing stack-safe recursive functions via trampolining.
The stack-safe factorial function using Scala TailCalls looks like this:

    import scala.util.control.TailCalls._

    def trampFactorial(n: Int): TailRec[BigInt] = {
      require(n >= 0, "N must be a non-negative integer!")
      if (n == 0) done(1) else tailcall(trampFactorial(n-1)).map(n * _)
    }

    trampFactorial(10000).result  // 28462596809170545189064132121198...
Scala TailCalls under the hood
To understand how things work under the hood, let’s extract the skeletal components from the source code of object TailCalls.
    object TailCalls {

      sealed abstract class TailRec[+A] {

        final def map[B](f: A => B): TailRec[B] =
          flatMap(a => Call(() => Done(f(a))))

        final def flatMap[B](f: A => TailRec[B]): TailRec[B] = this match {
          case Done(a) => Call(() => f(a))
          case c @ Call(_) => Cont(c, f)
          case c: Cont[a1, _] => Cont(c.a, (x: a1) => c.f(x) flatMap f)
        }

        @scala.annotation.tailrec
        final def result: A = this match {
          case Done(a) => a
          case Call(t) => t().result
          case Cont(a, f) => a match {
            case Done(v) => f(v).result
            case Call(t) => t().flatMap(f).result
            case Cont(b, g) => b.flatMap(x => g(x) flatMap f).result
          }
        }
      }

      protected case class Done[A](value: A) extends TailRec[A]

      protected case class Call[A](rest: () => TailRec[A]) extends TailRec[A]

      protected case class Cont[A, B](a: TailRec[A], f: A => TailRec[B]) extends TailRec[B]

      def tailcall[A](rest: => TailRec[A]): TailRec[A] = Call(() => rest)

      def done[A](result: A): TailRec[A] = Done(result)
    }
Object TailCalls consists of the base class TailRec, which provides the transformation methods flatMap and map as well as a tail-recursive method result that performs the actual evaluation. Also included are the subclasses Done, Call, and Cont, which encapsulate expressions that represent, respectively, the following:
- literal returning values
- recursive calls
- continual transformations
In addition, methods tailcall() and done() are exposed to users so they don’t need to deal with the subclasses directly.
In general, to enable trampolining for a given function, change the function’s return type to TailRec, then restructure any literal return values, recursive calls and continual transformations within the function body into the corresponding subclasses of TailRec, all in a tail-calling fashion.
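As a quick illustration of that recipe, here is a sketch that sums a list. The function sumList is a made-up example, not something from the library. Each of the three kinds of expression listed above shows up once.

```scala
import scala.util.control.TailCalls._

// Hypothetical example function: sums a list without growing the stack.
def sumList(xs: List[Int]): TailRec[Long] = xs match {
  case Nil =>
    done(0L)                  // literal return value
  case head :: tail =>
    tailcall(sumList(tail))   // recursive call, deferred by tailcall
      .map(_ + head)          // continual transformation of its result
}

val total: Long = sumList(List.range(1, 100001)).result
```

The non-trampolined equivalent, head + sumList(tail), would need one stack frame per element. With 100,000 elements that is deep enough that an overflow would be expected on default JVM settings.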
Methods done() and tailcall()
As can be seen from the TailCalls source code, all the subclasses (i.e. Done, Call, Cont) are shielded by the protected access modifier. Users are expected to use the interfacing methods done() and tailcall() along with class methods map and flatMap to formulate a tail-call version of the target function.
    def tailcall[A](rest: => TailRec[A]): TailRec[A] = Call(() => rest)

    def done[A](result: A): TailRec[A] = Done(result)
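Method done(value) is simply equivalent to Done(value), whereas tailcall(r: => TailRec) represents Call(() => r). It should be noted that the by-name parameter of tailcall() is critical to ensuring laziness, and it corresponds to the Function0 parameter of class Call. Because the argument is not evaluated at the time tailcall() is invoked, the recursive call inside it does not run (and does not take up a stack frame) until the trampoline asks for it. It’s an integral part of the stack-safe mechanics of trampolining.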
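The laziness is easy to observe with a small experiment. In the sketch below, the evaluated flag is a throwaway variable added purely to watch when the block passed to tailcall() actually runs.

```scala
import scala.util.control.TailCalls._

var evaluated = false // hypothetical flag, used only to observe evaluation

val deferred: TailRec[Int] = tailcall {
  evaluated = true // runs only when the trampoline forces this call
  done(42)
}

val evaluatedBeforeResult: Boolean = evaluated // false: the block has not run
val answer: Int = deferred.result              // forces the deferred block
val evaluatedAfterResult: Boolean = evaluated  // true: the block has run once
```

Had tailcall() taken its parameter by value, the block (and any recursive call inside it) would have run right at the call site, which is exactly what trampolining tries to avoid.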
Methods map and flatMap
Back to the factorial function. In the n == 0 case, the return value is a constant, so we wrap it in a Done via done(). The remaining case involves a recursive call followed by a further operation, so we wrap the recursive call in a Call via tailcall(). Obviously, we can’t simply write tailcall(n * trampFactorial(n-1)), since trampFactorial() now returns a TailRec object rather than a BigInt. Rather, we transform the result via map with a function t => n * t, similar to how we transform the internal value of an Option or Future.
But then tailcall(trampFactorial(n-1)).map(n * _) doesn’t look like a tail call. Why does it still achieve stack safety? To find out, let’s look at how map and flatMap are implemented in the Scala TailCalls source code.
Map and flatMap are “trampolines”
    final def map[B](f: A => B): TailRec[B] =
      flatMap(a => Call(() => Done(f(a))))

    final def flatMap[B](f: A => TailRec[B]): TailRec[B] = this match {
      case Done(a) => Call(() => f(a))
      case c @ Call(_) => Cont(c, f)
      case c: Cont[a1, b1] => Cont(c.a, (x: a1) => c.f(x) flatMap f)
    }
From the source code, one can see that the implementation of class method flatMap follows the same underlying principle of trampolining: regardless of which subclass the current TailRec object belongs to, the method returns a Call() or Cont() that defers the remaining work rather than performing it on the spot. That makes flatMap itself a trampolining transformation.
As for method map, it’s implemented as the special case of flatMap with the function a => Call(() => Done(f(a))), which is a trampolining tail call as well. Thus, both map and flatMap are trampolining transformations. Consequently, a tail expression made up of an arbitrary sequence of transformations with the two methods will preserve trampolining. That gives users great flexibility in formulating a tail-calling function.
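A tiny sketch can confirm that the two methods are interchangeable in this sense. The values viaMap and viaFlatMap are hypothetical names, and the second one spells out each map step as a flatMap that returns done(...).

```scala
import scala.util.control.TailCalls._

// Two map steps chained on a deferred value.
val viaMap: TailRec[Int] =
  tailcall(done(20)).map(_ + 1).map(_ * 2)

// The same pipeline with each map written as a flatMap returning done(...).
val viaFlatMap: TailRec[Int] =
  tailcall(done(20)).flatMap(a => done(a + 1)).flatMap(b => done(b * 2))
```

Both expressions are still just TailRec objects at this point. Calling result on either one yields 42.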
Evaluating the tail-call function
The tail-calling function returns a TailRec object, but at that point all that exists in heap memory is an object “wired” for the trampolining mechanism. It won’t get evaluated until class method result is called.
    @scala.annotation.tailrec
    final def result: A = this match {
      case Done(a) => a
      case Call(t) => t().result
      case Cont(a, f) => a match {
        case Done(v) => f(v).result
        case Call(t) => t().flatMap(f).result
        case Cont(b, g) => b.flatMap(x => g(x) flatMap f).result
      }
    }
Written as a tail-recursive method (as enforced by the @tailrec annotation), result evaluates the function by matching the current TailRec[A] object against each of the subclasses and carrying out the corresponding logic, eventually returning a value of the type given by the type parameter A.
If the current TailRec object is a Cont(a, f), which represents a transformation with function f on TailRec object a, the transformation is carried out according to what a is (hence another level of subclass matching). The class method flatMap comes in handy for applying the composable transformation f, as the type of f conforms to that of the function taken by flatMap.
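To see the matching in action, it may help to trace the smallest non-trivial case by hand. The comments in the sketch below walk through trampFactorial(1).result according to the skeletal source shown earlier. The symbol g is our own shorthand for the function that map builds, and the factorial is restated compactly (with _ * n in place of n * _) so the block stands on its own.

```scala
import scala.util.control.TailCalls._

def trampFactorial(n: Int): TailRec[BigInt] =
  if (n == 0) done(BigInt(1)) else tailcall(trampFactorial(n - 1)).map(_ * n)

// Hand trace, with g = a => Call(() => Done(a * 1)) as built by map:
//
//   trampFactorial(1)
//     == Call(() => trampFactorial(0)).flatMap(g)
//     == Cont(Call(() => trampFactorial(0)), g)       // Call case of flatMap
//
//   result then loops as follows, one iteration per line:
//     Cont(Call(t), g)     => t().flatMap(g).result   // t() == Done(1)
//                                                     // Done(1).flatMap(g) == Call(() => g(1))
//     Call(() => g(1))     => g(1).result             // g(1) == Call(() => Done(1))
//     Call(() => Done(1))  => Done(1).result
//     Done(1)              => 1
val traced: BigInt = trampFactorial(1).result
```

At no point does one pending computation sit on top of another on the stack. Each iteration produces the next TailRec object and hands it back to the loop.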
Trampolining Fibonacci
As a side note, the naive Fibonacci function generally will not incur a stack overflow: its recursion depth grows only linearly with n while its running time grows exponentially, so it becomes impractically slow long before the stack runs out. There is thus essentially no reason to apply trampolining to it. Nevertheless, it still serves as a good exercise in how to use TailCalls. On the other hand, a tail-recursive version of Fibonacci is highly efficient (see example in this previous post).
    def naiveFibonacci(n: Int): BigInt = {
      require(n >= 0, "N must be a non-negative integer!")
      if (n <= 1) n else naiveFibonacci(n-2) + naiveFibonacci(n-1)
    }

    import scala.util.control.TailCalls._

    def trampFibonacci(n: Int): TailRec[BigInt] = {
      require(n >= 0, "N must be a non-negative integer!")
      if (n <= 1)
        done(n)
      else
        tailcall(trampFibonacci(n-2)).flatMap(t => tailcall(trampFibonacci(n-1).map(t + _)))
    }
In case it isn’t obvious, the else case expression sums the by-definition values of F(n-2) and F(n-1), wrapped in tailcall(F(n-2)) and tailcall(F(n-1)) respectively, via flatMap and map:
    tailcall(trampFibonacci(n-2)).flatMap{ t2 =>
      tailcall(trampFibonacci(n-1).map(t1 => t2 + t1))
    }
which could also be achieved using a for-comprehension:
    for {
      t2 <- tailcall(trampFibonacci(n-2))
      t1 <- tailcall(trampFibonacci(n-1))
    } yield t2 + t1
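For comparison, the tail-recursive alternative mentioned earlier could be sketched as follows. This is our own minimal version, with the hypothetical name fibonacciTR, and is not necessarily the one from the previous post.

```scala
import scala.annotation.tailrec

def fibonacciTR(n: Int): BigInt = {
  require(n >= 0, "N must be a non-negative integer!")

  // prev and curr hold two consecutive Fibonacci numbers. Each step slides
  // the pair one position along the sequence, so after n steps prev is F(n).
  @tailrec
  def loop(i: Int, prev: BigInt, curr: BigInt): BigInt =
    if (i == 0) prev else loop(i - 1, curr, prev + curr)

  loop(n, BigInt(0), BigInt(1))
}
```

This version needs only n additions, whereas both naiveFibonacci and trampFibonacci above make an exponential number of recursive calls.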
Height of binary search tree
Let's look at one more trampoline example, which computes the height of a binary search tree. The following is derived from a bare-bones version of the binary search tree defined in a previous blog post:
    import scala.util.control.TailCalls._

    sealed trait BSTree[+A] { self =>

      def height: Int = self match {
        case BSLeaf => 0
        case BSBranch(_, _, l, r) => (l.height max r.height) + 1
      }

      def heightTC: Int = self match {
        case BSLeaf => 0
        case BSBranch(_, _, ltree, rtree) => {
          def loop(l: BSTree[A], r: BSTree[A], ht: Int): TailRec[Int] = (l, r) match {
            case (BSLeaf, BSLeaf) => done(ht)
            case (BSBranch(_, _, ll, lr), BSLeaf) => tailcall(loop(ll, lr, ht+1))
            case (BSLeaf, BSBranch(_, _, rl, rr)) => tailcall(loop(rl, rr, ht+1))
            case (BSBranch(_, _, ll, lr), BSBranch(_, _, rl, rr)) =>
              tailcall(loop(ll, lr, ht+1)).flatMap(lHt =>
                tailcall(loop(rl, rr, ht+1)).map(rHt =>
                  lHt max rHt
                )
              )
          }
          loop(ltree, rtree, 1).result
        }
      }
    }

    case class BSBranch[+A](
        elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
      ) extends BSTree[A]

    case object BSLeaf extends BSTree[Nothing]
The original height method is kept in trait BSTree as a reference. For the trampoline version heightTC, a helper function loop with an accumulator (i.e. ht) is employed to tally the tree height level by level. Using flatMap and map (or equivalently a for-comprehension), the main recursive traversal of the tree follows a tactic similar to the one used by the trampolining Fibonacci function.
Test running heightTC:
    BSLeaf.heightTC  // 0

    BSBranch(30, 1).heightTC  // 1

    BSBranch(30, 1,
        BSBranch(10, 1,
          BSLeaf,
          BSBranch(20, 2)
        ),
        BSLeaf
      ).heightTC  // 3

    /*
           30
          /
        10
          \
           20
    */

    BSBranch(30, 1,
        BSBranch(10, 1,
          BSLeaf,
          BSBranch(20, 2)
        ),
        BSBranch(50, 2,
          BSBranch(40, 1),
          BSBranch(70, 1,
            BSBranch(60, 2),
            BSBranch(80, 1)
          )
        )
      ).heightTC  // 4
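The small trees above only check correctness. To see the stack safety, one could try a degenerate tree that is really just a long chain. The sketch below restates a trimmed-down BSTree (without the plain height method) so that it stands alone. The helper leftChain is a made-up function that builds such a chain iteratively.

```scala
import scala.util.control.TailCalls._

sealed trait BSTree[+A] { self =>
  // Same trampolined logic as heightTC in the listing above.
  def heightTC: Int = self match {
    case BSLeaf => 0
    case BSBranch(_, _, ltree, rtree) =>
      def loop(l: BSTree[A], r: BSTree[A], ht: Int): TailRec[Int] = (l, r) match {
        case (BSLeaf, BSLeaf) => done(ht)
        case (BSBranch(_, _, ll, lr), BSLeaf) => tailcall(loop(ll, lr, ht + 1))
        case (BSLeaf, BSBranch(_, _, rl, rr)) => tailcall(loop(rl, rr, ht + 1))
        case (BSBranch(_, _, ll, lr), BSBranch(_, _, rl, rr)) =>
          for {
            lHt <- tailcall(loop(ll, lr, ht + 1))
            rHt <- tailcall(loop(rl, rr, ht + 1))
          } yield lHt max rHt
      }
      loop(ltree, rtree, 1).result
  }
}

case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A]

case object BSLeaf extends BSTree[Nothing]

// Hypothetical helper: builds a left-leaning chain of `depth` nodes with a
// fold, so that constructing the tree does not recurse either.
def leftChain(depth: Int): BSTree[Int] =
  (1 to depth).foldLeft(BSLeaf: BSTree[Int])((sub, i) => BSBranch(i, 1, sub, BSLeaf))

// The tree is not bound to a val on purpose: printing a chain this deep
// (e.g. in the REPL) would recurse through toString.
val deepHeight: Int = leftChain(100000).heightTC
```

On a chain of this depth, the plain recursive height method would need one stack frame per level and would likely overflow with default JVM settings, whereas heightTC keeps its pending calls on the heap.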
While at it, a tail-recursive version of the tree height method can be created with a slightly different approach. To achieve tail recursion, a recursive helper function takes a Scala List of (tree node, height) tuples along with the maximum tree height seen so far as its parameters, as shown below.
    sealed trait BSTree[+A] { self =>
      // ...

      def heightTR: Int = self match {
        case BSLeaf => 0
        case BSBranch(_, _, lt, rt) => {
          @scala.annotation.tailrec
          def loop(list: List[(BSTree[A], Int)], maxHt: Int): Int = list match {
            case Nil => maxHt
            case head :: tail =>
              val (tree, ht) = head
              tree match {
                case BSLeaf => loop(tail, maxHt)
                case BSBranch(_, _, l, r) => loop((l, ht+1) :: (r, ht+1) :: tail, ht max maxHt)
              }
          }
          loop((self, 1) :: Nil, 0)
        }
      }

      // ...
    }