
Trampolining with Scala TailCalls

It’s pretty safe to assert that virtually all software engineers have come across stack overflow problems at some point in their career. For a computational task that involves recursion with complex programming logic, keeping the number of stack frames under control can be challenging.

What is trampolining?

As outlined in the Wikipedia entry on trampolines, the term means different things in different programming paradigms. In the Scala world, trampolining is a means of making a function with recursive programming logic stack-safe by reformulating the recursive logic flow into tail calls among functional components wrapped in objects, so that the computation runs on the JVM heap without growing the call stack.

A classic example for illustrating stack overflow issues is the good old Factorial.
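For instance, a minimal sketch of the naive recursive version, which blows the stack for sufficiently large input:

```scala
// naive recursive factorial: every call adds a stack frame
def factorial(n: BigInt): BigInt =
  if (n == 0) 1
  else n * factorial(n - 1)

// factorial(100000)  // java.lang.StackOverflowError
```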

Tail call versus tail recursion

It should be noted that a tail call can be viewed, in a way, as a generalization of tail recursion. A tail recursive function is a tail-calling function that calls only itself, whereas a general tail-calling function may perform tail calls of other functions. Tail recursive functions are not only stack-safe; they can often operate with an efficient O(n) time complexity and O(1) space complexity.
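For instance, here’s a quick sketch of the contrast (the even/odd pair is an illustrative assumption, not from the original example):

```scala
import scala.annotation.tailrec

// Tail recursion: the function tail-calls only itself, so the Scala
// compiler rewrites it into a loop (@tailrec verifies this at compile time)
def factorialTR(n: BigInt): BigInt = {
  @tailrec
  def loop(i: BigInt, acc: BigInt): BigInt =
    if (i == 0) acc else loop(i - 1, i * acc)
  loop(n, 1)
}

// General tail calls: isEven and isOdd tail-call each other; neither the
// Scala compiler nor the JVM optimizes mutual tail calls, so deep inputs
// can still overflow the stack
def isEven(n: Int): Boolean = if (n == 0) true else isOdd(n - 1)
def isOdd(n: Int): Boolean = if (n == 0) false else isEven(n - 1)
```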

On the other hand, trampolining a function provides a stack-safe solution, but it does not speed up computation. On the contrary, it often takes longer than the stack-based method to produce a result (that is, until the stack overflow problem surfaces). That’s largely due to the necessary “bounces” between functions and the object-wrapping involved.

Factorial with Scala TailCalls

Scala’s standard library provides TailCalls for writing stack-safe recursive functions via trampolining.

A stack-safe factorial function using Scala TailCalls would look something like the following sketch:
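```scala
import scala.util.control.TailCalls._

// each recursive step is suspended in a TailRec object on the heap
// rather than in a new stack frame
def trampFactorial(n: BigInt): TailRec[BigInt] =
  if (n == 0) done(1)
  else tailcall(trampFactorial(n - 1)).map(n * _)

trampFactorial(10000).result  // evaluates without stack overflow
```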

Scala TailCalls under the hood

To understand how things work under the hood, let’s extract the skeletal components from object TailCalls‘ source code.
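A trimmed outline (the map, flatMap and result bodies are examined in the sections below):

```scala
object TailCalls {

  abstract class TailRec[+A] {
    final def map[B](f: A => B): TailRec[B] = ???              // shown below
    final def flatMap[B](f: A => TailRec[B]): TailRec[B] = ??? // shown below
    final def result: A = ???                                  // tail-recursive; shown below
  }

  /** a literal returning value */
  protected case class Done[A](value: A) extends TailRec[A]

  /** a recursive call, suspended in a Function0 */
  protected case class Call[A](rest: () => TailRec[A]) extends TailRec[A]

  /** a continual transformation of TailRec a via function f */
  protected case class Cont[A, B](a: TailRec[A], f: A => TailRec[B]) extends TailRec[B]

  /** performs a tailcall */
  def tailcall[A](rest: => TailRec[A]): TailRec[A] = Call(() => rest)

  /** returns the final result from a tailcalling computation */
  def done[A](result: A): TailRec[A] = Done(result)
}
```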

Object TailCalls contains the base class TailRec, which provides the transformation methods flatMap and map as well as a tail-recursive method result that performs the actual evaluation. Also included are the subclasses Done, Call, and Cont, which encapsulate expressions that represent, respectively, the following:

  • literal returning values
  • recursive calls
  • continual transformations

In addition, methods tailcall() and done() are exposed to the users so they don’t need to directly meddle with the subclasses.

In general, to enable trampolining for a given function, make the function return a value of type TailRec, then restructure any literal returns, recursive calls and continual transformations within the function body into the corresponding subclasses of TailRec, all in a tail-calling fashion.

Methods done() and tailcall()

As can be seen from the TailCalls source code, all the subclasses (i.e. Done, Call, Cont) are shielded by the protected access modifier. Users are expected to use the interfacing methods done() and tailcall(), along with the class methods map and flatMap, to formulate a tail-call version of the target function.

Method done(value) is simply equivalent to Done(value), whereas tailcall(r: => TailRec) represents Call(() => r). It should be noted that the by-name parameter of tailcall() is critical to ensuring laziness, matching the Function0 parameter of class Call(). It’s an integral part of the stack-safe mechanics of trampolining.

Methods map and flatMap

Back to the factorial function. In the n == 0 case, the return value is a constant, so we wrap it in a Done(). The remaining case involves continual transformations, so we wrap it in a Call(). Obviously, we can’t simply do tailcall(n * trampFactorial(n-1)), since trampFactorial() now returns a TailRec object rather than a number. Instead, we transform via map with the function t => n * t, similar to how we transform the internal value of an Option or Future.

But then tailcall(trampFactorial(n-1)).map(n * _) doesn’t look like a tail call. Why is it able to accomplish stack safety? To find out, let’s look at how map and flatMap are implemented in the Scala TailCalls source code.
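Here are the two methods, essentially as they appear in the Scala 2.12 source (comments added):

```scala
// members of abstract class TailRec[+A]

final def flatMap[B](f: A => TailRec[B]): TailRec[B] =
  this match {
    case Done(a)     => Call(() => f(a))  // defer f(a) inside a tail call
    case c @ Call(_) => Cont(c, f)        // chain f as a continuation
    // the monad associative law keeps the continuation chain flat
    case c: Cont[a1, b1] => Cont(c.a, (x: a1) => c.f(x) flatMap f)
  }

final def map[B](f: A => B): TailRec[B] =
  flatMap(a => Call(() => Done(f(a))))
```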

Map and flatMap are “trampolines”

From the source code, one can see that the implementation of the class method flatMap follows the same underlying principle of trampolining: regardless of which subclass the current TailRec object belongs to, the method returns a tail-calling Call() or Cont(). That makes flatMap itself a trampolining transformation.

As for method map, it’s implemented as a special case of flatMap with the transformation a => Call(() => Done(f(a))), which is a trampolining tail call as well. Thus, both map and flatMap are trampolining transformations. Consequently, a tail expression consisting of an arbitrary sequence of transformations with the two methods will preserve trampolining. That gives users great flexibility in formulating a tail-calling function.

Evaluating the tail-call function

The tail-calling function will return a TailRec object, but all that resides in heap memory is just an object “wired” for the trampolining mechanism. It won’t get evaluated until the class method result is called.

Constructed as an optimally efficient tail-recursive function, method result evaluates the function by matching the current TailRec[A] object against each of the subclasses and carrying out the programming logic accordingly, ultimately returning the resulting value of type A.

If the current TailRec object is a Cont(a, f), which represents a transformation with function f on TailRec object a, the transformation is carried out in accordance with what a is (hence another level of subclass matching). The class method flatMap comes in handy for carrying out the necessary composable transformation, as f’s signature conforms to that of the function taken by flatMap.
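For reference, method result as it appears in the Scala 2.12 source (comments added):

```scala
// member of abstract class TailRec[+A]
@annotation.tailrec final def result: A = this match {
  case Done(a) => a            // literal value: evaluation is complete
  case Call(t) => t().result   // bounce: run the suspended call
  case Cont(a, f) => a match { // transformation f over TailRec object a
    case Done(v) => f(v).result
    case Call(t) => t().flatMap(f).result
    case Cont(b, g) => b.flatMap(x => g(x) flatMap f).result
  }
}
```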

Trampolining Fibonacci

As a side note, Fibonacci generally will not incur stack overflow due to its relatively small space complexity, so there is essentially no practical reason to apply trampolining. Nevertheless, it still serves as a good exercise in using TailCalls. On the other hand, a tail-recursive version of Fibonacci is highly efficient (see the example in this previous post).

In case it isn’t obvious, the else-case expression sums the by-definition values of F(n-2) and F(n-1), wrapped in tailcall(F(n-2)) and tailcall(F(n-1)) respectively, via flatMap and map:
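```scala
import scala.util.control.TailCalls._

// a sketch; the function name fibT is assumed
def fibT(n: Int): TailRec[BigInt] =
  if (n < 2) done(n)
  else
    tailcall(fibT(n - 2)).flatMap(a =>  // F(n-2)
      tailcall(fibT(n - 1)).map(b =>    // F(n-1)
        a + b))

fibT(30).result  // 832040
```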

which could also be achieved using for-comprehension:
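```scala
// the same function expressed with a for-comprehension, which desugars
// into the flatMap/map calls above
def fibT(n: Int): TailRec[BigInt] =
  if (n < 2) done(n)
  else for {
    a <- tailcall(fibT(n - 2))
    b <- tailcall(fibT(n - 1))
  } yield a + b
```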

Height of binary search tree

Let’s look at one more trampoline example that involves computing the height of a binary search tree. The following is derived from a barebone version of the binary search tree defined in a previous blog post:
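```scala
import scala.util.control.TailCalls._

// a sketch of the barebone tree; the names BSNode/BSLeaf are assumed
sealed trait BSTree[+A] {

  // original stack-based height, kept for reference
  def height: Int = this match {
    case BSLeaf => 0
    case BSNode(l, _, r) => 1 + (l.height max r.height)
  }

  // trampolined version: helper `loop` carries accumulator `ht`,
  // the level of the node currently being visited
  def heightTC: Int = {
    def loop(node: BSTree[A], ht: Int): TailRec[Int] = node match {
      case BSLeaf => done(ht)
      case BSNode(l, _, r) => for {
        lh <- tailcall(loop(l, ht + 1))
        rh <- tailcall(loop(r, ht + 1))
      } yield lh max rh
    }
    loop(this, 0).result
  }
}

case object BSLeaf extends BSTree[Nothing]
case class BSNode[A](left: BSTree[A], value: A, right: BSTree[A]) extends BSTree[A]
```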

The original height method is kept in trait BSTree as a reference. The trampoline version heightTC employs a helper function loop with an accumulator (i.e. ht) to tally the tree height level by level. Using flatMap and map (or, equivalently, a for-comprehension), the main recursive tracing of the tree height follows a tactic similar to that of the trampolining Fibonacci function.

Test running heightTC:
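```scala
// a cursory check on a small hand-built tree (construction assumed)
val tree = BSNode(BSNode(BSLeaf, 1, BSLeaf), 2,
                  BSNode(BSLeaf, 3, BSNode(BSLeaf, 4, BSLeaf)))

tree.heightTC  // 3
tree.height    // 3, agreeing with the original method
```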

While at it, a tail-recursive version of the tree height method can be created with a slightly different approach. To achieve tail recursion, a recursive function is run with a Scala List of tree node-height tuples, along with a max tree height value, as its parameters, as shown below.
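```scala
// added inside trait BSTree[+A]; the method name heightTR is assumed
def heightTR: Int = {
  @annotation.tailrec
  def loop(nodes: List[(BSTree[A], Int)], maxHt: Int): Int = nodes match {
    case Nil => maxHt
    case (BSLeaf, ht) :: rest => loop(rest, maxHt max ht)
    case (BSNode(l, _, r), ht) :: rest =>
      loop((l, ht + 1) :: (r, ht + 1) :: rest, maxHt)
  }
  loop(List((this, 0)), 0)
}
```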

Implementing Linked List In Scala

In Scala, if you wonder why the standard library doesn’t come with a data structure called LinkedList, you may have overlooked the obvious: the collection List is in fact a linked list, although it often appears in the form of a Seq or Vector collection rather than the generally “mysterious” linked list that exposes its “head” with a hidden “tail” to be revealed only iteratively.

Our ADT: LinkedNode

Perhaps because of its simplicity and dynamicity as a data structure, implementing a linked list remains a popular coding exercise. To implement our own linked list, let’s start with a barebone ADT (algebraic data type) as follows:
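```scala
// a minimal sketch; field names elem/next per the discussion below
sealed trait LinkedNode[+A]

case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]
```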

If you’re familiar with Scala List, you’ll probably notice that our ADT resembles List and its subclasses Cons (i.e. ::) and Nil (see source code):
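```scala
// trimmed/paraphrased from the Scala 2.12 standard library source
sealed abstract class List[+A]                               // extends clauses omitted
final case class ::[B](head: B, tl: List[B]) extends List[B] // the "cons" cell
case object Nil extends List[Nothing]                        // the empty list
```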

Expanding LinkedNode

Let’s expand trait LinkedNode to add class methods insertNode/deleteNode for inserting/deleting a node at a given index, toList for extracting the contained elements into a List collection, and toString for display:
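```scala
// a sketch of the expanded trait (Node and EmptyNode as defined above);
// the out-of-range index handling is an assumption
sealed trait LinkedNode[+A] {

  // insert `elem` at 0-based index `idx`, returning a new LinkedNode
  def insertNode[B >: A](elem: B, idx: Int): LinkedNode[B] = this match {
    case _ if idx == 0 => Node(elem, this)
    case Node(e, next) if idx > 0 => Node(e, next.insertNode(elem, idx - 1))
    case _ => this  // out-of-range index: return unchanged
  }

  // delete the node at 0-based index `idx`, returning a new LinkedNode
  def deleteNode(idx: Int): LinkedNode[A] = this match {
    case Node(_, next) if idx == 0 => next
    case Node(e, next) if idx > 0 => Node(e, next.deleteNode(idx - 1))
    case _ => this  // out-of-range index: return unchanged
  }

  // extract the contained elements into a standard List
  def toList: List[A] = this match {
    case EmptyNode => Nil
    case Node(e, next) => e :: next.toList
  }

  override def toString: String = toList.mkString(" -> ")
}
```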

Note that LinkedNode is made covariant in its type parameter. In addition, method insertNode takes type A as the lower bound of its own type parameter, because a method parameter is a contravariant position (just as Function1 is contravariant over its parameter type).

Recursion and pattern matching

A couple of notes on the approach we take in implementing the class methods:

  1. We use recursive functions to avoid using mutable variables. They should be made tail-recursive for optimal performance, but that isn’t the focus of this implementation. If performance were a priority, using conventional while-loops with mutable class fields elem/next would be a more practical option.
  2. Pattern matching is routinely used for handling cases of Node versus EmptyNode. An alternative approach would be to define fields elem and next in the base trait and implement class methods accordingly within Node and EmptyNode.

Finding first/last matching nodes

Next, we add a couple of class methods for finding first/last matching nodes.
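A sketch, with the method names findFirst and findLast assumed; both return the matching node, or EmptyNode when no element satisfies the predicate:

```scala
// added inside trait LinkedNode[+A]
def findFirst(p: A => Boolean): LinkedNode[A] = this match {
  case EmptyNode => EmptyNode
  case n @ Node(e, next) => if (p(e)) n else next.findFirst(p)
}

def findLast(p: A => Boolean): LinkedNode[A] = this match {
  case EmptyNode => EmptyNode
  case n @ Node(e, next) =>
    val last = next.findLast(p)
    if (last != EmptyNode) last
    else if (p(e)) n
    else EmptyNode
}
```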

Reversing a linked list by groups of nodes

Reversing a given LinkedNode can be accomplished via recursion by cumulatively wrapping the element of each Node in a new Node with its next pointer set to the Node created in the previous iteration.
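A sketch along those lines, with the recursion made tail-recursive via an accumulator:

```scala
// added inside trait LinkedNode[+A]
def reverse: LinkedNode[A] = {
  @annotation.tailrec
  def loop(remaining: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] =
    remaining match {
      case EmptyNode => acc
      case Node(e, next) => loop(next, Node(e, acc))  // re-wrap e, pointing at acc
    }
  loop(this, EmptyNode)
}
```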

Method reverseK reverses a LinkedNode by groups of k elements using a different approach that extracts the elements into groups of k elements, reverses the elements in each of the groups and re-wraps each of the flattened elements in a Node.
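A sketch of that approach:

```scala
// added inside trait LinkedNode[+A]: extract into groups of k, reverse
// each group, then re-wrap the flattened elements in Nodes
def reverseK(k: Int): LinkedNode[A] =
  toList.grouped(k).flatMap(_.reverse).toList
    .foldRight(EmptyNode: LinkedNode[A])(Node(_, _))
```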

Using LinkedNode as a Stack

For LinkedNode to serve as a Stack, we include simple methods push and pop as follows:
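```scala
// added inside trait LinkedNode[+A]; `pop` returning the optional element
// together with the remaining nodes is a design assumption
def push[B >: A](elem: B): LinkedNode[B] = Node(elem, this)

def pop: (Option[A], LinkedNode[A]) = this match {
  case EmptyNode => (None, this)
  case Node(e, next) => (Some(e), next)
}
```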

Addendum: implementing map, flatMap, fold

From a different perspective, if we view LinkedNode as a collection like a Scala List or Vector, we might crave methods like map, flatMap and fold. Using the same approach of recursion along with pattern matching, it’s rather straightforward to crank them out.
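A sketch (concatenating via toList/foldRight in flatMap is a shortcut of convenience):

```scala
// added inside trait LinkedNode[+A]
def map[B](f: A => B): LinkedNode[B] = this match {
  case EmptyNode => EmptyNode
  case Node(e, next) => Node(f(e), next.map(f))
}

def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = this match {
  case EmptyNode => EmptyNode
  case Node(e, next) => f(e).toList.foldRight(next.flatMap(f))(Node(_, _))
}

def fold[B](z: B)(op: (B, A) => B): B = this match {
  case EmptyNode => z
  case Node(e, next) => next.fold(op(z, e))(op)
}
```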

Putting everything together

Along with a few additional simple class methods and a factory method wrapped in LinkedNode‘s companion object, below is the final LinkedNode ADT that includes everything described above.
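A condensed sketch (the class methods from the previous sections are elided for brevity; the extra members isEmpty and size, as well as the varargs factory, are assumptions):

```scala
sealed trait LinkedNode[+A] {
  // ... insertNode, deleteNode, toList, findFirst, findLast, reverse,
  // reverseK, push, pop, map, flatMap, fold as shown in previous sections ...

  def isEmpty: Boolean = this == EmptyNode

  def size: Int = this match {
    case EmptyNode => 0
    case Node(_, next) => 1 + next.size
  }
}

case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

object LinkedNode {
  // factory: LinkedNode(1, 2, 3) == Node(1, Node(2, Node(3, EmptyNode)))
  def apply[A](elems: A*): LinkedNode[A] =
    elems.foldRight(EmptyNode: LinkedNode[A])(Node(_, _))
}
```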

A cursory test-run …
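```scala
// a hypothetical REPL session against the sketch above
val ln = LinkedNode(1, 2, 3, 4, 5)

ln.toList                   // List(1, 2, 3, 4, 5)
ln.insertNode(9, 2).toList  // List(1, 2, 9, 3, 4, 5)
ln.deleteNode(0).toList     // List(2, 3, 4, 5)
ln.reverse.toList           // List(5, 4, 3, 2, 1)
ln.reverseK(2).toList       // List(2, 1, 4, 3, 5)
ln.map(_ * 10).toList       // List(10, 20, 30, 40, 50)
ln.fold(0)(_ + _)           // 15
```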

Fibonacci In Scala: Tailrec, Memoized

One of the most popular number series used as a programming exercise is undoubtedly the Fibonacci numbers:
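0, 1, 1, 2, 3, 5, 8, 13, 21, 34, …

F(0) = 0, F(1) = 1, and F(n) = F(n-1) + F(n-2) for n > 1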

Perhaps a prominent reason why the Fibonacci sequence is of vast interest in math is the associated Golden Ratio, but I think what makes it a great programming exercise is that despite its simplistic definition, the sequence’s exponential growth rate presents challenges for implementations with space/time efficiency in mind.

Having seen various ways of implementing methods for the Fibonacci numbers, I thought it might be worth putting them together, from a naive implementation to something more space/time efficient. But first, let’s take a quick look at the computational complexity of Fibonacci.

Fibonacci complexity

If we denote T(n) as the time required to compute F(n), by definition:
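T(n) = T(n-1) + T(n-2) + K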

where K is the time taken by some simple arithmetic to arrive at F(n) from F(n-1) and F(n-2).

With some approximate math analysis (see this post), it can be shown that the lower and upper bounds of T(n) are O(2^(n/2)) and O(2^n), respectively. For better precision, one can derive a more exact time complexity by solving the associated characteristic equation, x^2 = x + 1, which yields x ≈ 1.618, to deduce that:
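T(n) = O(R^n)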

where R ≈ 1.618 is the Golden Ratio.

As for space complexity, if one looks at the recursion tree for computing F(n), it’s pretty clear that its depth is F(n-1)’s tree depth plus one. Thus, the required space for F(n) is proportional to n. In other words:
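S(n) = O(n)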

The relatively small space complexity, compared with the exponential time complexity, explains why computing a Fibonacci number too large for a computer would generally lead to an infinite run rather than an out-of-memory or stack overflow problem.

It’s worth noting, though, that if F(n) is computed via conventional iterations (e.g. a while-loop, or tail recursion which gets translated into iterations by Scala under the hood), the time complexity would be reduced to O(n), proportional to the number of loop cycles. And the space complexity would be O(1), since no n-dependent extra space is needed beyond that for carrying the latest Fibonacci values.

Naive Fibonacci

To generate Fibonacci numbers, the most straightforward approach is via a basic recursive function like the one below:
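```scala
// naive recursion straight from the definition
def fib(n: Int): BigInt =
  if (n < 2) n
  else fib(n - 1) + fib(n - 2)
```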

With such a naive recursive function, computing the 50th number, i.e. fib(50), would take minutes on a typical laptop, and attempts to compute any number higher up like fib(90) would most certainly lead to an infinite run.

Tail recursive Fibonacci

So, let’s come up with a tail recursive method:
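```scala
// the names fibTR/fibFcn follow the discussion below; the accumulator
// names prev/curr are assumptions
def fibTR(n: Int): BigInt = {
  @scala.annotation.tailrec
  def fibFcn(i: Int, prev: BigInt, curr: BigInt): BigInt =
    if (i == 0) prev
    else fibFcn(i - 1, curr, prev + curr)

  fibFcn(n, 0, 1)
}
```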

As shown above, tail recursion is accomplished by means of a couple of accumulators as parameters for the inner method to recursively carry over the two numbers that precede the current number.

With the Fibonacci TailRec version, computing, say, the 90th number would finish instantaneously.

Fibonacci in a Scala Stream

Another way of implementing Fibonacci is to define the sequence to be stored in a “lazy” collection, such as a Scala Stream:
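```scala
// a sketch; the names fibStream/fibS are assumed
val fibStream: Stream[BigInt] = BigInt(0) #:: fibStream.scan(BigInt(1))(_ + _)

def fibS(n: Int): BigInt = fibStream(n)

fibS(90)  // 2880067194370816120
```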

Method scan is the trick here: scan(1)(_ + _) generates a Stream in which each element is successively assigned the sum of the previous two elements. Since Streams are “lazy”, none of the element values in the defined fibStream will be evaluated until requested.

While at it, there are a couple of other commonly seen Fibonacci implementation variants with Scala Stream:
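```scala
// two other common variants (sketches; the val names are assumed)

// zip the stream with its own tail
val fibZip: Stream[BigInt] =
  BigInt(0) #:: BigInt(1) #:: (fibZip zip fibZip.tail).map { case (a, b) => a + b }

// iterate over (current, next) pairs, then project the first component
val fibIter: Stream[BigInt] =
  Stream.iterate((BigInt(0), BigInt(1))) { case (a, b) => (b, a + b) }.map(_._1)
```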

Scala Stream memoizes by design

These Stream-based Fibonacci implementations perform reasonably well, somewhat comparable to the tail recursive Fibonacci. But while these Stream implementations all involve recursion, none is tail recursive. So, why don’t they suffer the same performance issue as the naive Fibonacci implementation does? The short answer is memoization.

Digging into the source code of Scala Stream would reveal that method #:: (which is wrapped in class ConsWrapper) is defined as:
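```scala
// from the Scala 2.12 Stream source (trimmed)
class ConsWrapper[A](tl: => Stream[A]) {
  def #::(hd: A): Stream[A] = cons(hd, tl)
  def #:::(prefix: Stream[A]): Stream[A] = prefix append tl
}
```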

Tracing method cons further reveals that the Stream tail is a by-name parameter to class Cons, thus ensuring that the concatenation is performed lazily:
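```scala
// from the Scala 2.12 Stream source (trimmed)
object cons {
  def apply[A](hd: A, tl: => Stream[A]) = new Cons(hd, tl)
  def unapply[A](xs: Stream[A]): Option[(A, Stream[A])] = #::.unapply(xs)
}

final class Cons[+A](hd: A, tl: => Stream[A]) extends Stream[A] {
  override def isEmpty = false
  override def head = hd
  // ... tail members omitted ...
}
```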

But lazy evaluation via a by-name parameter does nothing for memoization. Digging deeper into the source code, one would see that Stream content is iterated through a StreamIterator class defined as follows:
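```scala
// from the Scala 2.12 Stream source (trimmed)
final class StreamIterator[+A] private () extends AbstractIterator[A] with Iterator[A] {
  def this(self: Stream[A]) {
    this()
    these = new LazyCell(self)
  }

  // a call-by-need cell: the by-name Stream is cached in a lazy val
  class LazyCell(st: => Stream[A]) {
    lazy val v = st
  }

  private var these: LazyCell = _

  def hasNext: Boolean = these.v.nonEmpty
  def next(): A =
    if (isEmpty) Iterator.empty.next()
    else {
      val cur    = these.v
      val result = cur.head
      these = new LazyCell(cur.tail)
      result
    }
  // ... toStream/toList overrides omitted ...
}
```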

The inner class LazyCell not only has a by-name parameter but, more importantly, holds the Stream represented by the StreamIterator instance in a lazy val, which by nature enables memoization by caching the value upon the first (and only the first) evaluation.

Memoized Fibonacci using a mutable Map

While using a Scala Stream to implement Fibonacci would automatically leverage memoization, one could also explicitly employ the very feature without Streams. For instance, by leveraging method getOrElseUpdate in a mutable Map, a memoize function can be defined as follows:
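```scala
import scala.collection.mutable

// a sketch of the memoize combinator
def memoize[K, V](f: K => V): K => V = {
  val cache = mutable.Map.empty[K, V]
  key => cache.getOrElseUpdate(key, f(key))
}
```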

For example, the naive Fibonacci equipped with memoization via this memoize function would instantly become a much more efficient implementation:
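```scala
// the val name fibM is assumed; `lazy val` permits the self-reference
// inside the function literal
lazy val fibM: Int => BigInt =
  memoize(n => if (n < 2) BigInt(n) else fibM(n - 1) + fibM(n - 2))

fibM(90)  // 2880067194370816120, instantaneously
```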

For the tail recursive Fibonacci fibTR, this memoize function wouldn’t be applicable, as its inner function fibFcn takes accumulators as additional parameters. As for the Stream-based fibS, which is already equipped with Stream’s memoization, applying memoize wouldn’t produce any significant performance gain.