
Scala Cats Typeclasses At A Glance

Scala Cats comes with a rich set of typeclasses, each of which “owns” a well-defined problem space. Many of these typeclasses are correlated, and some extend others.

In this blog post, we’re going to give an at-a-glance hierarchical view of some of the most common Cats typeclasses. For brevity, we’ll skip discussion of their corresponding mathematical laws, which can be found in many related tech docs. Our focus will be on highlighting the correlations among these typeclasses.

Common typeclass hierarchy

For the impatient, below is a diagram highlighting the hierarchical correlation.

Scala Cats Common Typeclass Hierarchy

Semigroup and Monoid

Let’s start with the simplest ones, Semigroup and Monoid.

Semigroup comes with the abstract method combine to be implemented with the specific “combine” computational logic such as the addition of integers, union of sets, etc.

trait Semigroup[A] {
  def combine(x: A, y: A): A
}

trait Monoid[A] extends Semigroup[A] {
  def combine(x: A, y: A): A
  def empty: A
}

Note that Monoid simply supplements Semigroup with empty as the “zero” or “identity” element, allowing aggregation over arbitrarily many elements (e.g. summing numbers from an initial 0).

Example:

implicit def setSemigroup[A]: Semigroup[Set[A]] =
  new Semigroup[Set[A]] {
    def combine(s1: Set[A], s2: Set[A]) = s1 union s2
  }

// Or, using SAM for brevity
// implicit def setSemigroup[A]: Semigroup[Set[A]] = _ union _

val setSG = implicitly[Semigroup[Set[Char]]]
setSG.combine(Set('a', 'b'), Set('c'))
// Set('a', 'b', 'c')

implicit def setMonoid[A](implicit sg: Semigroup[Set[A]]): Monoid[Set[A]] =
  new Monoid[Set[A]] {
    def combine(s1: Set[A], s2: Set[A]) = sg.combine(s1, s2)
    def empty = Set()
  }

val setM = implicitly[Monoid[Set[Char]]]
List(Set('a','b'),Set('c'),Set('d','e')).
  foldLeft(setM.empty)(setM.combine(_, _))
// HashSet('e', 'a', 'b', 'c', 'd')
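As a side note, Cats itself ships with instances and syntax for these typeclasses. Below is a rough sketch of the same combination using Cats’ union-based Set instance and the |+| (combine) syntax, assuming the usual cats.implicits._ import is in scope:

import cats.implicits._

Set('a', 'b') |+| Set('c')
// Set('a', 'b', 'c')

cats.Monoid[Set[Char]].empty
// Set()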

SemigroupK and MonoidK

With a similar correlation, SemigroupK and MonoidK are the higher-kinded versions of Semigroup and Monoid, respectively. SemigroupK combines values within a given context, and MonoidK ensures the existence of an “empty” context.

trait SemigroupK[F[_]] {
  def combineK[A](x: F[A], y: F[A]): F[A]
}

trait MonoidK[F[_]] extends SemigroupK[F] {
  def combineK[A](x: F[A], y: F[A]): F[A]
  def empty[A]: F[A]
}

Example:

implicit val listSemigroupK: SemigroupK[List] =
  new SemigroupK[List] {
    def combineK[A](ls1: List[A], ls2: List[A]) = ls1 ::: ls2
  }

val listSGK = implicitly[SemigroupK[List]]
listSGK.combineK(List(1,2), List(3))
// List(1, 2, 3)

implicit def listMonoidK(implicit sgk: SemigroupK[List]): MonoidK[List] =
  new MonoidK[List] {
    def combineK[A](ls1: List[A], ls2: List[A]) = sgk.combineK(ls1, ls2)
    def empty[A] = List.empty[A]
  }

val listMK = implicitly[MonoidK[List]]
List(List(1,2),List(3),List(4,5)).foldLeft[List[Int]](listMK.empty)(listMK.combineK(_, _))
// List(1, 2, 3, 4, 5)

Functor

Functor is a higher-kinded typeclass characterized by its method map which transforms some value within a given context F via a function.

trait Functor[F[_]] {
  def map[A, B](fa: F[A])(f: A => B): F[B]
}

Example:

implicit val listFunctor: Functor[List] =
  new Functor[List] {
    def map[A, B](ls: List[A])(f: A => B) = ls.map(f)
  }

val listF = implicitly[Functor[List]]
listF.map(List(1,2,3))(i => s"#$i!")
// List("#1!", "#2!", "#3!")

Monad

Monad enables sequencing of operations in which resulting values from an operation can be utilized in the subsequent one.

But first, let’s look at typeclass FlatMap.

trait FlatMap[F[_]] extends Apply[F] {
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  def map[A, B](fa: F[A])(f: A => B): F[B]
  def tailRecM[A, B](init: A)(f: A => F[Either[A, B]]): F[B]
}

FlatMap extends Apply whose key methods aren’t what we would like to focus on at the moment. Rather, we’re more interested in method flatMap which enables sequential chaining of operations.

In addition, method tailRecM is a required implementation for stack-safe recursions on the JVM (which doesn’t natively support tail call optimization).

Monad inherits almost all its signature methods from FlatMap.

trait Monad[F[_]] extends FlatMap[F] with Applicative[F] {
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  def pure[A](a: A): F[A]
  def map[A, B](fa: F[A])(f: A => B): F[B]
  def tailRecM[A, B](init: A)(f: A => F[Either[A, B]]): F[B]
}

Monad also extends Applicative which we’ll get to (along with Apply) in a bit. For now, it suffices to note that Monad inherits pure from Applicative.

Even without realizing that Monad extends Functor (indirectly through FlatMap and Apply), one could conclude that Monads are inherently Functors by implementing map using flatMap and pure.

def map[A, B](fa: F[A])(f: A => B): F[B] = flatMap(fa)(a => pure(f(a)))

Example:

trait Monad[F[_]] {  // Skipping dependent classes
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  def pure[A](a: A): F[A]
  def map[A, B](fa: F[A])(f: A => B): F[B]
  def tailRecM[A, B](init: A)(f: A => F[Either[A, B]]): F[B]
}

implicit val optionMonad: Monad[Option] =
  new Monad[Option] {
    def flatMap[A, B](opt: Option[A])(f: A => Option[B]) = opt.flatMap(f)
    def pure[A](a: A) = Option(a)
    def map[A, B](opt: Option[A])(f: A => B): Option[B] = flatMap(opt)(a => pure(f(a)))
    @scala.annotation.tailrec
    def tailRecM[A, B](a: A)(f: A => Option[Either[A, B]]): Option[B] = f(a) match {
      case None => None
      case Some(leftOrRight) => leftOrRight match {
        case Left(a1) => tailRecM(a1)(f)
        case Right(b1) => Option(b1)
      }
    }
  }

val optMonad = implicitly[Monad[Option]]
optMonad.flatMap(Option(3))(i => if (i > 0) Some(s"#$i!") else None)
// Some("#3!")
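To make the sequencing aspect concrete, here’s a quick sketch chaining two dependent steps with the Option Monad above, where half is just a made-up helper that only succeeds on even numbers:

def half(i: Int): Option[Int] = if (i % 2 == 0) Some(i / 2) else None

optMonad.flatMap(optMonad.flatMap(Option(8))(half))(half)
// Some(2)

optMonad.flatMap(optMonad.flatMap(Option(6))(half))(half)
// None  (half(6) == Some(3); 3 is odd, so the second step yields None)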

Semigroupal and Apply

A higher-kinded typeclass, Semigroupal shifts from Semigroup’s combining of values to joining independent contexts into a tupled “product”.

trait Semigroupal[F[_]] {
  def product[A, B](fa: F[A], fb: F[B]): F[(A, B)]
}

Despite the simplicity of method product (which is its only class method), Semigroupal lays out the skeletal foundation for the problem space of independent (potentially concurrent) operations, as opposed to Monad’s sequential chaining.

Next, Apply brings together the goodies of Semigroupal and Functor. Its main method ap has a rather peculiar signature that doesn’t look intuitively meaningful.

trait Apply[F[_]] extends Semigroupal[F] with Functor[F] {
  def ap[A, B](ff: F[A => B])(fa: F[A]): F[B]
  def map[A, B](fa: F[A])(f: A => B): F[B]
  def product[A, B](fa: F[A], fb: F[B]): F[(A, B)]
}

Conceptually, it can be viewed as a specialized map in which the transformation function is “wrapped” in the context.

By restructuring the type parameters in ap[A, B] and map[A, B], method product can be implemented in terms of ap and map.

// 1: Substitute `B` with `B => (A, B)`
def map[A, B](fa: F[A])(f: A => B => (A, B)): F[B => (A, B)]

// 2: Substitute `A` with `B` and `B` with `(A, B)`
def ap[A, B](ff: F[B => (A, B)])(fb: F[B]): F[(A, B)]

// Applying 1 and 2:
def product[A, B](fa: F[A], fb: F[B]): F[(A, B)] =
  ap(map(fa)(a => (b: B) => (a, b)))(fb)
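To make this concrete, below is a minimal sketch of an Apply instance for Option, written against the simplified traits above, with product derived from ap and map exactly as shown:

implicit val optionApply: Apply[Option] =
  new Apply[Option] {
    def ap[A, B](ff: Option[A => B])(fa: Option[A]): Option[B] =
      ff.flatMap(f => fa.map(f))
    def map[A, B](fa: Option[A])(f: A => B): Option[B] = fa.map(f)
    def product[A, B](fa: Option[A], fb: Option[B]): Option[(A, B)] =
      ap(map(fa)(a => (b: B) => (a, b)))(fb)
  }

val optApply = implicitly[Apply[Option]]
optApply.product(Option(1), Option("a"))
// Some((1, "a"))
optApply.product(Option(1), Option.empty[String])
// None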

Applicative

trait Applicative[F[_]] extends Apply[F] {
  def pure[A](a: A): F[A]
}

Like how Monoid supplements Semigroup with the empty element to form a more “self-contained” typeclass, Applicative extends Apply and adds method pure, which wraps a plain value in the context. That seemingly insignificant addition makes Applicative a typeclass capable of addressing a problem space of its own.

Similarly, Monad takes pure from Applicative along with the core methods from FlatMap to become another “self-contained” typeclass to master a different computational problem space.

In contrast to Monad’s chaining of dependent operations, Applicative embodies independent operations, allowing such computations to be combined and, where the context supports it, run in parallel.

We’ll defer examples for Applicative to a later section.

Foldable

Foldable offers fold methods that go over (from left to right, or vice versa) some contextual value (oftentimes a collection) and aggregate via a binary function, starting from an initial value. It also provides method foldMap, which maps each element to a Monoid via a unary function and combines the results.

trait Foldable[F[_]] {
  def foldLeft[A, B](fa: F[A], b: B)(f: (B, A) => B): B
  def foldRight[A, B](fa: F[A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B]
  def foldMap[A, B: Monoid](fa: F[A])(f: A => B): B
}

Note that the well-known foldRight method in some Scala collections may not be stack-safe (especially in older versions). Cats uses the Eval data type in its foldRight method to ensure stack safety.
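As a rough sketch against the simplified signature above (and reusing the Set Semigroup/Monoid instances from the earlier section for foldMap), a Foldable instance for List might look like the following, with cats.Eval and Eval.defer keeping foldRight lazy:

import cats.Eval

implicit val listFoldable: Foldable[List] =
  new Foldable[List] {
    def foldLeft[A, B](fa: List[A], b: B)(f: (B, A) => B): B =
      fa.foldLeft(b)(f)
    def foldRight[A, B](fa: List[A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] =
      fa match {
        case Nil => lb
        case h :: t => f(h, Eval.defer(foldRight(t, lb)(f)))
      }
    def foldMap[A, B: Monoid](fa: List[A])(f: A => B): B = {
      val m = implicitly[Monoid[B]]
      foldLeft(fa, m.empty)((b, a) => m.combine(b, f(a)))
    }
  }

val listFld = implicitly[Foldable[List]]
listFld.foldLeft(List(1, 2, 3), 0)(_ + _)
// 6
listFld.foldMap(List('a', 'b', 'c'))(Set(_))
// Set('a', 'b', 'c')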

Traverse

Traverse extends Functor and Foldable and provides method traverse. The method traverses some contextual value, transforming each element with a function that wraps the result in a destination context, which, as a requirement, is bound to an Applicative.

trait Traverse[F[_]] extends Functor[F] with Foldable[F] {
  def traverse[G[_]: Applicative, A, B](fa: F[A])(ff: A => G[B]): G[F[B]]
}

If you’ve used Scala Futures, method traverse (and the sequence method) might look familiar.

def sequence[G[_]: Applicative, A](fg: F[G[A]]): G[F[A]] =
  traverse(fg)(identity)

Method sequence has the effect of turning a nested context “inside out” and is just a special case of traverse by substituting A with G[B] (i.e. making ff an identity function).
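For instance, using Cats’ own Traverse[List] instance and syntax (via cats.implicits._), sequence flips a List of Options into an Option of List, yielding None if any element is None:

import cats.implicits._

List(Option(1), Option(2), Option(3)).sequence
// Some(List(1, 2, 3))

List(Option(1), None, Option(3)).sequence
// None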

Example: Applicative and Traverse

To avoid going into a full-on implementation of Traverse in its general form that would, in turn, require laborious implementations of all the dependent typeclasses, we’ll trivialize our example to cover only the case for Futures (i.e. type G = Future).

First, we come up with a specialized Traverse as follows:

import scala.concurrent.{ExecutionContext, Future}

trait FutureTraverse[F[_]] {  // Skipping dependent classes
  def traverse[A, B](fa: F[A])(ff: A => Future[B]): Future[F[B]]
}

For similar reasons, let’s also “repurpose” Applicative to include only the methods we need. In particular, we include method map2 which will prove handy for implementing the traverse method for FutureTraverse.

trait Applicative[F[_]] {  // Skipping dependent classes
  def map[A, B](fa: F[A])(f: A => B): F[B]
  def map2[A, B, Z](fa: F[A], fb: F[B])(f: (A, B) => Z): F[Z]
  def pure[A](a: A): F[A]
}

implicit val futureApplicative: Applicative[Future] =
  new Applicative[Future] {
    implicit val ec = ExecutionContext.Implicits.global
    def map[A, B](fa: Future[A])(f: A => B): Future[B] = fa.map(f)
    def map2[A, B, Z](fa: Future[A], fb: Future[B])(f: (A, B) => Z): Future[Z] =
      (fa zip fb).map(f.tupled)
    def pure[A](a: A): Future[A] = Future.successful(a)
  }

We implement map2 by zipping the two Futures via zip and applying the tupled binary function via tupled. With the implicit Applicative[Future] in place, we’re ready to implement FutureTraverse[List].

implicit val listFutureTraverse: FutureTraverse[List] =
  new FutureTraverse[List] {
    implicit val ec = ExecutionContext.Implicits.global
    implicit val appF = implicitly[Applicative[Future]]
    def traverse[A, B](ls: List[A])(ff: A => Future[B]): Future[List[B]] = {
      ls.foldRight[Future[List[B]]](Future.successful(List.empty[B])){ (a, acc) =>
        appF.map2(ff(a), acc)(_ :: _)
      }
    }
  }

import scala.concurrent.ExecutionContext.Implicits.global

val lsFutTraverse = implicitly[FutureTraverse[List]]
lsFutTraverse.traverse(List(1,2,3)){ i =>
  if (i > 0) Future.successful(s"#$i!") else Future.failed(new Exception())
}
// Future(Success(List("#1!", "#2!", "#3!")))
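And since map2 zips the underlying Futures, a single failed element fails the entire traversal:

lsFutTraverse.traverse(List(1, -2, 3)){ i =>
  if (i > 0) Future.successful(s"#$i!") else Future.failed(new Exception("non-positive"))
}
// Future(Failure(java.lang.Exception: non-positive))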

As a side note, we could implement traverse without using Applicative. Below is an implementation leveraging Future’s flatMap method along with a helper function (as demonstrated in a previous blog post about Scala collection traversal).

implicit val listFutureTraverse: FutureTraverse[List] =
  new FutureTraverse[List] {
    implicit val ec = ExecutionContext.Implicits.global
    def pushToList[A](a: A)(as: List[A]): List[A] = a :: as
    def traverse[A, B](ls: List[A])(ff: A => Future[B]): Future[List[B]] = {
      ls.foldRight[Future[List[B]]](Future.successful(List.empty[B])){ (a, acc) =>
        ff(a).map(pushToList).flatMap(acc.map)
      }
    }
  }

NIO-based Reactor In Scala

For high concurrency at scale, event-driven server design with non-blocking I/O operations has been one of the most popular server architectures. Nginx and Node.js, both leading server platforms in their own spaces, adopt this very technology. Among the various event-driven server implementations, the Reactor pattern remains a prominent design that leverages an event loop equipped with a demultiplexer to efficiently select events that are ready to be processed by a set of event handlers.

Back in 2013, I wrote a blog post about building a barebone server using Java NIO API to implement the Reactor pattern with non-blocking I/O in Java. The goal here is to rewrite the NIO-based Reactor server in Scala.

Java NIO and Reactor pattern

A quick recap of Java NIO, which consists of the following key components:

  • Buffer – a container of primitive typed data (e.g. Byte, Int) that can be optimized for native I/O operations with memory alignment and paging functionality
  • Channel – a connector associated with an I/O entity (e.g. files, sockets) that supports non-blocking I/O operations
  • Selector – a demultiplexer on an event loop that selects events which are ready for carrying out pre-registered I/O operations (e.g. read, write)

Note that selectable NIO channels (e.g. SocketChannel, ServerSocketChannel) extend SelectableChannel and can be registered with a Selector, producing a SelectionKey for the I/O operations of interest. To optimally handle a high volume of client connections to the server, channels can be configured via method configureBlocking(false) to support non-blocking I/O.

With NIO Buffers enabling optimal memory access and native I/O operations, Channels programmatically connecting I/O entities, and Selector serving as the demultiplexer on an event loop selecting ready-to-go I/O events to execute in a non-blocking fashion, the Java NIO API is a great fit for implementing an effective Reactor server.

Reactor event loop

This Scala version of the NIO Reactor server consists of two main classes NioReactor and Handler, along with a trait SelKeyAttm which is the base class for objects that are to be coupled with individual selection-keys as their attachments (more on this later).

Central to the NioReactor class is the “perpetual” event loop performed by the class method selectorLoop(). It’s a recursive function that never returns (hence the return type Nothing), equivalent to the conventional infinite while(true){} loop. All it does is repeatedly check for the selection-keys whose corresponding channels are ready for the registered I/O operations and iterate through those keys to carry out the work defined in the passed-in function iterFn().

  import java.util.{Iterator => JavaIter}

  @scala.annotation.tailrec
  final def selectorLoop(iterFn: (JavaIter[SelectionKey]) => Unit): Nothing = {
    selector.select()
    val it = selector.selectedKeys().iterator()
    iterFn(it)
    selectorLoop(iterFn)
  }

Function iterateSelKeys, which is passed in as the parameter of the event loop function, holds the selection-key iteration logic. While it’s tempting to convert the Java Iterator used in the original Java application into a Scala Iterator, the idea was scrapped due to the need for the timely removal of each iterated selection-key via remove(): the selector never removes keys from its selected-key set on its own, so a key left in place would be picked up again on subsequent selects. Scala’s Iterator (and Iterable) has no such method or equivalent.

  @scala.annotation.tailrec
  final def iterateSelKeys(it: JavaIter[SelectionKey]): Unit = {
    if (it.hasNext()) {
      val sk = it.next()
      it.remove()
      val attm: SelKeyAttm = sk.attachment().asInstanceOf[SelKeyAttm]
      if (attm != null)
        attm.run()
      iterateSelKeys(it)
    }
    else ()
  }

Whereas the selection-key attachments in the original version were of type Runnable, they are now subtypes of SelKeyAttm, each implementing a run() method that gets called once selected by the Selector. With Scala Futures in use, Runnable is no longer needed as the attachment type. Making SelKeyAttm the base type for objects attached to the selection-keys sets up a slightly more specific “contract” (in the form of method specifications) for those objects to adhere to.
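The SelKeyAttm trait itself isn’t listed in this post; judging from the run() signatures of Acceptor and Handler shown below, it is presumably little more than:

import scala.util.Try

trait SelKeyAttm {
  def run(): Try[Unit]  // invoked by the event loop once the owning selection-key is selected
}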

Acceptor

The Acceptor, associated with the NIO ServerSocketChannel for the listener socket, is a subtype of SelKeyAttm. It’s responsible for accepting incoming connection requests.

  class Acceptor extends SelKeyAttm {
    def run(): Try[Unit] = Try {
        val channel: SocketChannel = serverChannel.accept()
        if (channel != null)
          new Handler(selector, channel)
        ()
      }
      .recover {
        case e: IOException => println(s"Acceptor: $e")
      }
  }

Part of class NioReactor’s constructor routine is to bind the ServerSocketChannel to the specified port number. It’s also where the ServerSocketChannel is configured to be non-blocking and registered with the selector for the accept operation (OP_ACCEPT), creating a selection-key with an Acceptor instance as its attachment.

class NioReactor(port: Int) {
  implicit val ec: ExecutionContext = NioReactor.ec

  val selector: Selector = Selector.open()
  val serverChannel: ServerSocketChannel = ServerSocketChannel.open()

  serverChannel.socket().bind(new InetSocketAddress(port))
  serverChannel.configureBlocking(false)

  val sk: SelectionKey = serverChannel.register(selector, SelectionKey.OP_ACCEPT)
  sk.attach(new Acceptor())

  // ...
}
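The loop() method invoked from the companion object below isn’t shown here; given selectorLoop and iterateSelKeys above, it presumably just ties the two together, along the lines of:

  def loop(): Nothing = selectorLoop(iterateSelKeys)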

The companion object of the NioReactor class is set up with a thread pool to run the Reactor server at a provided port number in a Scala Future.

object NioReactor {
  val poolSize: Int = 10
  val workerPool = Executors.newFixedThreadPool(poolSize)
  implicit val ec = ExecutionContext.fromExecutorService(workerPool)

  def apply(port: Int = 9090): Future[Unit] = Future {
      (new NioReactor(port)).loop()
    }
    .recover {
      case e: IOException => println(s"Reactor($port): $e")
    }

  // ...
}

Event handlers

As shown in the snippet of the Acceptor class, upon acceptance of a server connection, an instance of Handler is spawned. All events (in our case, reading requests from and writing responses to client sockets) are processed by these handlers, which belong to another subtype of SelKeyAttm.

The Handler class takes a Selector and a SocketChannel as parameters, initializes a couple of ByteBuffers for read/write, configures the SocketChannel to be non-blocking, registers with the selector for I/O operation OP_READ, and creates a selection-key with the handler instance itself as its attachment, followed by nudging the selector (via wakeup()) so that a blocked select() returns immediately.

Method run() is responsible for, upon being called, carrying out the main read/write handling logic in accordance with the selection-key the passed-in SocketChannel is associated with and the corresponding I/O operation of interest.

object Handler {
  val readBufSize: Int = 1024
  val writeBufSize: Int = 1024
}

class Handler(sel: Selector, channel: SocketChannel)(implicit ec: ExecutionContext) extends SelKeyAttm {
  import Handler._

  var selKey: SelectionKey = null
  val readBuf = ByteBuffer.allocate(readBufSize)
  var writeBuf = ByteBuffer.allocate(writeBufSize)

  channel.configureBlocking(false)

  selKey = channel.register(sel, SelectionKey.OP_READ)
  selKey.attach(this)
  sel.wakeup()

  def run(): Try[Unit] = Try {
      if (selKey.isReadable())
        read()
      else if (selKey.isWritable())
        write()
    }
    .recover {
      case e: IOException => println(s"Handler run(): $e")
    }

  def process(): Unit = ???

  def read(): Unit = ???

  def write(): Unit = ???

}

Processing read/write buffers

Method read() calls channel.read(readBuf), which reads up to a preset number of bytes from the channel into the readBuf ByteBuffer and returns the number of bytes read. If the channel has reached “end-of-stream”, in which case channel.read() returns -1, the corresponding selection-key will be cancelled and the channel closed; otherwise, processing work will commence.

  def read(): Unit = synchronized {
      Try {
          val numBytes: Int = channel.read(readBuf)
          println("Handler read(): #bytes read into 'readBuf' buffer = " + numBytes)
  
          if (numBytes == -1) {
            selKey.cancel()
            channel.close()
            println("Handler read(): client connection might have been dropped!")
          }
          else {
            Future {
                process()
              }
              .recover {
                case e: IOException => println(s"Handler process(): $e")
              }
          }
        }
        .recover {
          case e: IOException => println(s"Handler read(): $e")
        }
    }

Method process() does the actual post-read processing work. It’s supposed to do the heavy lifting (thus being wrapped in a Scala Future), although in this trivial server example all it does is echo: it reads whatever is in the readBuf ByteBuffer using the NIO Buffer API, writes it into the writeBuf ByteBuffer, then switches the selection-key’s I/O operation of interest to OP_WRITE.

  def process(): Unit = synchronized {
      readBuf.flip()
      val bytes: Array[Byte] = Array.ofDim[Byte](readBuf.remaining())
      readBuf.get(bytes, 0, bytes.length)
      print("Handler process(): " + new String(bytes, Charset.forName("ISO-8859-1")))

      writeBuf = ByteBuffer.wrap(bytes)

      selKey.interestOps(SelectionKey.OP_WRITE)
      selKey.selector().wakeup()
    }

Method write() calls channel.write(writeBuf) to write from the writeBuf ByteBuffer into the calling channel, followed by clearing both the read/write ByteBuffers and switching the selection-key’s I/O operation of interest back to OP_READ.

  def write(): Unit = {
    Try {
        val numBytes: Int = channel.write(writeBuf)
        println("Handler write(): #bytes read from 'writeBuf' buffer = " + numBytes)

        if (numBytes > 0) {
          readBuf.clear()
          writeBuf.clear()

          selKey.interestOps(SelectionKey.OP_READ)
          selKey.selector().wakeup()
        }
      }
      .recover {
        case e: IOException => println(s"Handler write(): $e")
      }
  }

Final thoughts

In this code rewrite in Scala, the main changes include the replacement of:

  • Java Runnable with Scala Future along with the base type SelKeyAttm for the Acceptor and Handler objects that are to be attached to selection-keys
  • while-loop with recursive functions
  • try-catch with Try-recover

While Java NIO is a great API for building efficient I/O-heavy applications, its underlying design clearly favors the imperative programming style. Rewriting the NIO-based Reactor server in a functional language like Scala doesn’t necessarily make the code easier to read or maintain: many of the API calls return void (i.e. Scala Unit) and mutate objects passed in as parameters, making a thorough idiomatic rewrite difficult.

Full source code of the Scala NIO Reactor server application is available at this GitHub repo.

To compile and run the Reactor server, git-clone the repo and run sbt from the project-root at a terminal on the server host:

$ sbt compile
$ sbt "runMain reactor.NioReactor `port`"

Skipping the port number will bind the server to the default port 9090.

To connect to the Reactor server, use telnet from one or more client host terminals:

telnet `server-host` `port`
#  e.g. telnet reactor.example.com 8080, telnet localhost 9090

Any text input from the client host(s) will be echoed back by the Reactor server, which itself will also report what has been processed. Below are sample input/output from a couple of client host terminals and the server terminal:

## Client host terminal #1:

$ telnet 192.168.1.100 9090
Trying ::1...
Connected to 192.168.1.100.
Escape character is '^]'.
blah blah blah from term #1
blah blah blah from term #1
^]
telnet> quit
Connection closed.

## Client host terminal #2:

$ telnet 192.168.1.100 9090
Trying ::1...
Connected to 192.168.1.100.
Escape character is '^]'.
foo bar from term #2
foo bar from term #2
^]
telnet> quit
Connection closed.

## Server host terminal:

$ sbt "runMain reactor.NioReactor"
[info] ...
[info] running reactor.NioReactor 
Handler read(): #bytes read into 'readBuf' buffer = 29
Handler write(): #bytes read from 'writeBuf' buffer = 29
Handler read(): #bytes read into 'readBuf' buffer = 22
Handler write(): #bytes read from 'writeBuf' buffer = 22
Handler read(): #bytes read into 'readBuf' buffer = -1
Handler read(): client connection might have been dropped!
Handler read(): #bytes read into 'readBuf' buffer = -1
Handler read(): client connection might have been dropped!
^C
[warn] Canceling execution...
Cancelled: runMain reactor.NioReactor

As a side note, the output from method Handler.process() which is wrapped in a Scala Future will be reported if the server is being run from within an IDE like IntelliJ.

Trampolining with Scala TailCalls

It’s pretty safe to assert that all software engineers have come across stack overflow problems at some point in their career. For a computational task that involves recursion with complex programming logic, keeping stack usage under control can be challenging.

What is trampolining?

As outlined in the Wikipedia entry on trampolines, the term can mean different things in different programming paradigms. In the Scala world, trampolining is a means of making a function with recursive programming logic stack-safe by reformulating the recursive flow into tail calls among functional components wrapped in objects, so that the computation runs on the JVM heap rather than consuming stack frames.

A classic example for illustrating stack overflow issues is the good old Factorial.

def naiveFactorial(n: Int): BigInt = {
  require(n >= 0, "N must be a non-negative integer!")
  if (n == 0)
    1
  else
    n * naiveFactorial(n-1)
}

naiveFactorial(10000)  // java.lang.StackOverflowError!

Tail call versus tail recursion

It should be noted that a tail call can, in a way, be viewed as a generalization of tail recursion. A tail-recursive function is a tail-calling function that calls only itself, whereas a general tail-calling function may perform tail calls of other functions. Tail-recursive functions are not only stack-safe; they can often operate with an efficient O(n) time complexity and O(1) space complexity.

On the other hand, trampolining a function provides a stack-safe solution but does not speed up computation. On the contrary, it often takes longer than the stack-based approach to produce a result (that is, for inputs small enough that the stack-based approach doesn’t overflow). That’s largely due to the necessary “bounces” between functions and the object wrapping.
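For comparison, here’s a sketch of a tail-recursive factorial: the only recursive call is the helper calling itself in tail position, so the compiler (with @tailrec enforcing it) rewrites the recursion into a loop.

def tailRecFactorial(n: Int): BigInt = {
  require(n >= 0, "N must be a non-negative integer!")
  @scala.annotation.tailrec
  def loop(i: Int, acc: BigInt): BigInt =
    if (i == 0) acc
    else loop(i - 1, i * acc)
  loop(n, 1)
}

tailRecFactorial(10000)  // evaluates without a stack overflow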

Factorial with Scala TailCalls

Scala’s standard library provides TailCalls for writing stack-safe recursive functions via trampolining.

The stack-safe factorial function using Scala TailCalls will look like below:

import scala.util.control.TailCalls._

def trampFactorial(n: Int): TailRec[BigInt] = {
  require(n >= 0, "N must be a non-negative integer!")
  if (n == 0)
    done(1)
  else
    tailcall(trampFactorial(n-1)).map(n * _)
}

trampFactorial(10000).result  // 28462596809170545189064132121198...

Scala TailCalls under the hood

To understand how things work under the hood, let’s extract the skeletal components from object TailCalls‘ source code.

object TailCalls {

  sealed abstract class TailRec[+A] {
    final def map[B](f: A => B): TailRec[B] =
      flatMap(a => Call(() => Done(f(a))))

    final def flatMap[B](f: A => TailRec[B]): TailRec[B] = this match {
      case Done(a) => Call(() => f(a))
      case c @ Call(_) => Cont(c, f)
      case c: Cont[a1, _] => Cont(c.a, (x: a1) => c.f(x) flatMap f)
    }

    @scala.annotation.tailrec
    final def result: A = this match {
      case Done(a) => a
      case Call(t) => t().result
      case Cont(a, f) => a match {
        case Done(v) => f(v).result
        case Call(t) => t().flatMap(f).result
        case Cont(b, g) => b.flatMap(x => g(x) flatMap f).result
      }
    }
  }

  protected case class Done[A](value: A) extends TailRec[A]

  protected case class Call[A](rest: () => TailRec[A]) extends TailRec[A]

  protected case class Cont[A, B](a: TailRec[A], f: A => TailRec[B]) extends TailRec[B]

  def tailcall[A](rest: => TailRec[A]): TailRec[A] = Call(() => rest)

  def done[A](result: A): TailRec[A] = Done(result)
}

Object TailCalls contains the base class TailRec, which provides the transformation methods flatMap and map along with a tail-recursive method result that performs the actual evaluation. Also included are the subclasses Done, Call, and Cont, which encapsulate expressions representing, respectively, the following:

  • literal returning values
  • recursive calls
  • continual transformations

In addition, methods tailcall() and done() are exposed to the users so they don’t need to directly meddle with the subclasses.

In general, to enable trampolining for a given function, make the function’s return value of type TailRec, followed by restructuring any literal function returns, recursive calls and continual transformations within the function body into corresponding subclasses of TailRec — all in a tail-calling fashion.

Methods done() and tailcall()

As can be seen from the TailCalls source code, all the subclasses (i.e. Done, Call, Cont) are shielded by the protected access modifier. Users are expected to use the interfacing methods done() and tailcall(), along with the class methods map and flatMap, to formulate a tail-call version of the target function.

  def tailcall[A](rest: => TailRec[A]): TailRec[A] = Call(() => rest)

  def done[A](result: A): TailRec[A] = Done(result)

Method done(value) is just equivalent to Done(value), whereas tailcall(r: => TailRec) represents Call(() => r). It should be noted that the by-name parameter of tailcall() is critical to ensuring laziness, corresponding to the Function0 parameter of class Call(). It’s an integral part of the stack-safe mechanics of trampolining.

Methods map and flatMap

Back to the factorial function. In the n == 0 case, the return value is a constant, so we wrap it in a Done() via done(). The remaining case involves a recursive call followed by a transformation, so we wrap the recursive call in a Call() via tailcall(). Obviously, we can’t simply do tailcall(n * trampFactorial(n-1)) since trampFactorial() now returns a TailRec object rather than a number. Instead, we transform via map with the function t => n * t, similar to how we transform the value inside an Option or Future.

But then tailcall(trampFactorial(n-1)).map(n * _) doesn’t look like a tail call. Why is it able to accomplish stack safety? To find out why, let’s look at how map and flatMap are implemented in the Scala TailCalls source code.

Map and flatMap are “trampolines”

    final def map[B](f: A => B): TailRec[B] =
      flatMap(a => Call(() => Done(f(a))))

    final def flatMap[B](f: A => TailRec[B]): TailRec[B] =
      this match {
        case Done(a) => Call(() => f(a))
        case c @ Call(_) => Cont(c, f)
        case c: Cont[a1, b1] => Cont(c.a, (x: a1) => c.f(x) flatMap f)
      }

From the source code, one can see that the implementation of the class method flatMap follows the same underlying principle of trampolining: regardless of which subclass the current TailRec object belongs to, the method returns a tail-calling Call() or Cont(). That makes flatMap itself a trampolining transformation.

As for method map, it’s implemented as a special case of flatMap, transforming with a => Call(() => Done(f(a))), which is a trampolining tail call as well. Thus, both map and flatMap are trampolining transformations. Consequently, a tail expression consisting of an arbitrary sequence of transformations with the two methods preserves trampolining, which gives users great flexibility in formulating a tail-calling function.
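As a small worked illustration of the definitions above, consider what tailcall(trampFactorial(2)).map(3 * _) actually builds (the val name step is used here just for illustration). Nothing is computed at this point; the expression merely assembles nested heap objects:

val step: TailRec[BigInt] = tailcall(trampFactorial(2)).map(3 * _)
// tailcall(trampFactorial(2))  is          Call(() => trampFactorial(2))
// .map(3 * _)                  expands to  .flatMap(t => Call(() => Done(3 * t)))
// flatMap on a Call yields                 Cont(Call(() => trampFactorial(2)), t => Call(() => Done(3 * t)))
// Only step.result triggers the actual evaluation.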

Evaluating the tail-call function

The tail-calling function returns a TailRec object, but all that sits in heap memory is an object “wired” for the trampolining mechanism. It won’t get evaluated until the class method result is called.

    @scala.annotation.tailrec
    final def result: A = this match {
      case Done(a) => a
      case Call(t) => t().result
      case Cont(a, f) => a match {
        case Done(v) => f(v).result
        case Call(t) => t().flatMap(f).result
        case Cont(b, g) => b.flatMap(x => g(x) flatMap f).result
      }
    }

Constructed as an efficient tail-recursive function, method result performs the evaluation by matching the current TailRec[A] object against each of the subclasses and carrying out the corresponding logic to return a value of the type given by the type parameter A.

If the current TailRec object is a Cont(a, f), which represents a transformation with function f on a TailRec object a, the transformation is carried out according to what a is (hence another level of subclass matching). The class method flatMap comes in handy for carrying out the necessary composable transformation, as f’s signature conforms to that of the function taken by flatMap.

Trampolining Fibonacci

As a side note, the naive Fibonacci generally will not incur a stack overflow: its recursion depth grows only linearly with n (while its run time grows exponentially), so the stack is unlikely to be exhausted before the computation becomes impractical. There is thus essentially no reason to apply trampolining here, but it still serves as a good exercise in using TailCalls. A tail-recursive version of Fibonacci, on the other hand, is highly efficient (see the example in this previous post).

def naiveFibonacci(n: Int): BigInt = {
  require(n >= 0, "N must be a non-negative integer!")
  if (n <= 1)
    n
  else
    naiveFibonacci(n-2) + naiveFibonacci(n-1)
}

import scala.util.control.TailCalls._

def trampFibonacci(n: Int): TailRec[BigInt] = {
  require(n >= 0, "N must be a non-negative integer!")
  if (n <= 1)
    done(n)
  else
    tailcall(trampFibonacci(n-2)).flatMap(t => tailcall(trampFibonacci(n-1).map(t + _)))
}

In case it isn’t obvious, the else-case expression sums the by-definition values of F(n-2) and F(n-1), wrapped in tailcall(trampFibonacci(n-2)) and tailcall(trampFibonacci(n-1)) respectively, via flatMap and map:

    tailcall(trampFibonacci(n-2)).flatMap{ t2 =>
      tailcall(trampFibonacci(n-1).map(t1 =>
        t2 + t1))
    }

which could also be achieved using for-comprehension:

    for {
      t2 <- tailcall(trampFibonacci(n-2))
      t1 <- tailcall(trampFibonacci(n-1))
    }
    yield t2 + t1
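As a quick check (keeping n modest, since this by-definition formulation still runs in exponential time):

trampFibonacci(30).result
// 832040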

Height of binary search tree

Let's look at one more trampolining example, which computes the height of a binary search tree. The following is derived from a barebones version of the binary search tree defined in a previous blog post:

import scala.util.control.TailCalls._

sealed trait BSTree[+A] { self =>

  def height: Int = self match {
    case BSLeaf => 0
    case BSBranch(_, _, l, r) => (l.height max r.height) + 1
  }

  def heightTC: Int = self match {
    case BSLeaf => 0
    case BSBranch(_, _, ltree, rtree) => {
      def loop(l: BSTree[A], r: BSTree[A], ht: Int): TailRec[Int] = (l, r) match {
        case (BSLeaf, BSLeaf) => done(ht)
        case (BSBranch(_, _, ll, lr), BSLeaf) => tailcall(loop(ll, lr, ht+1))
        case (BSLeaf, BSBranch(_, _, rl, rr)) => tailcall(loop(rl, rr, ht+1))
        case (BSBranch(_, _, ll, lr), BSBranch(_, _, rl, rr)) =>
          tailcall(loop(ll, lr, ht+1)).flatMap(lHt =>
            tailcall(loop(rl, rr, ht+1)).map(rHt =>
              lHt max rHt
            )
          )
      }
      loop(ltree, rtree, 1).result
    }
  }
}

case class BSBranch[+A](
    elem: A, count: Int = 1, left: BSTree[A] = BSLeaf, right: BSTree[A] = BSLeaf
  ) extends BSTree[A]

case object BSLeaf extends BSTree[Nothing]

The original height method is kept in trait BSTree as a reference. For the trampolined version heightTC, a helper function loop with an accumulator (i.e. ht) is employed to tally the tree height level by level. Using flatMap and map (or, equivalently, a for-comprehension), the main recursive tracing of the tree height follows a tactic similar to the trampolining Fibonacci function.

Test running heightTC:

BSLeaf.heightTC  // 0

BSBranch(30, 1).heightTC  // 1

BSBranch(30, 1,
    BSBranch(10, 1,
      BSLeaf,
      BSBranch(20, 2)
    ),
    BSLeaf
  ).
  heightTC  // 3

/*
           30
        /
    10 
       \
        20
*/

BSBranch(30,1,
    BSBranch(10, 1,
      BSLeaf,
      BSBranch(20, 2)
    ),
    BSBranch(50, 2,
      BSBranch(40, 1),
      BSBranch(70, 1,
        BSBranch(60, 2),
        BSBranch(80, 1)
      )
    )
  ).
  heightTC  // 4

While at it, a tail-recursive version of the tree-height method can be created with a slightly different approach. To achieve tail recursion, a recursive function is run with a Scala List of tree-node/height tuples along with a running maximum height as its parameters, as shown below.

sealed trait BSTree[+A] { self =>
  // ...

  def heightTR: Int = self match {
    case BSLeaf => 0
    case BSBranch(_, _, lt, rt) => {
      @scala.annotation.tailrec
      def loop(list: List[(BSTree[A], Int)], maxHt: Int): Int = list match {
        case Nil =>
          maxHt
        case head :: tail =>
          val (tree, ht) = head
          tree match {
            case BSLeaf =>
              loop(tail, maxHt)
            case BSBranch(_, _, l, r) =>
              loop((l, ht+1) :: (r, ht+1) :: tail, ht max maxHt)
          }
      }
      loop((self, 1) :: Nil, 0)
    }
  }

  // ...
}
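As a quick sanity check, heightTR should agree with heightTC on the earlier trees:

BSLeaf.heightTR           // 0

BSBranch(30, 1).heightTR  // 1

BSBranch(30, 1,
    BSBranch(10, 1,
      BSLeaf,
      BSBranch(20, 2)
    ),
    BSLeaf
  ).
  heightTR  // 3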