Author Archives: Leo Cheung

Implementing Linked List In Scala

In Scala, if you wonder why the standard library doesn’t come with a data structure called `LinkedList`, you may have overlooked the fact that the `List` collection is itself a linked list. It’s just that it’s usually handled like any other `Seq` collection (much as one would handle a `Vector`) rather than as the generally “mysterious” linked list that exposes its “head” with a hidden “tail” to be revealed only iteratively.

Our ADT: LinkedNode

Perhaps because of its simplicity and dynamic nature as a data structure, implementing a linked list remains a popular coding exercise. To implement our own linked list, let’s start with a bare-bones ADT (algebraic data type) as follows:

sealed trait LinkedNode[+A]
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

If you’re familiar with Scala List, you’ll probably notice that our ADT resembles `List` and its subclasses `Cons` (i.e. `::`) and `Nil` (see source code):

sealed abstract class List[+A] { ... }
final case class ::[+A](override val head: A, var next: List[A]) extends List[A] { ... }
case object Nil extends List[Nothing] { ... }

Expanding LinkedNode

Let’s expand trait `LinkedNode` to create class methods `insertNode`/`deleteNode` at a given index for inserting/deleting a node, `toList` for extracting contained elements into a `List` collection, and `toString` for display:

sealed trait LinkedNode[+A] { self =>

  def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
      case EmptyNode =>
        if (i < idx)
          EmptyNode
        else
          Node(x, EmptyNode)
      case Node(e, nx) =>
        if (i < idx)
          Node(e, loop(nx, i + 1))
        else
          Node(x, ln)
    }
    loop(self, 0)
  }

  def deleteNode(idx: Int): LinkedNode[A] = {
    def loop(ln: LinkedNode[A], i: Int): LinkedNode[A] = ln match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        // Drop the node at position idx; an out-of-range (or negative) idx leaves the list unchanged
        if (i == idx) nx else Node(e, loop(nx, i + 1))
    }
    loop(self, 0)
  }

  def toList: List[A] = {
    def loop(ln: LinkedNode[A], acc: List[A]): List[A] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, e :: acc)
    }
    loop(self, List.empty[A]).reverse
  }

  override def toString: String = {
    def loop(ln: LinkedNode[A], acc: String): String = ln match {
      case EmptyNode =>
        acc + "()"
      case Node(e, nx) =>
        loop(nx, acc + e + " -> " )
    }
    loop(self, "")
  }
}

case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]

case object EmptyNode extends LinkedNode[Nothing]

// Test running ...

val ln = List(1, 2, 3, 3, 4).foldRight[LinkedNode[Int]](EmptyNode)(Node(_, _))
// ln: LinkedNode[Int] = 1 -> 2 -> 3 -> 3 -> 4 -> ()

ln.insertNode(9, 2)
// res1: LinkedNode[Int] = 1 -> 2 -> 9 -> 3 -> 3 -> 4 -> ()

ln.deleteNode(1)
// res2: LinkedNode[Int] = 1 -> 3 -> 3 -> 4 -> ()

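Note that `LinkedNode` is made covariant in `A`. As a consequence, method `insertNode` needs its own type parameter `B` with `A` as the lower bound (`B >: A`): a covariant type parameter isn’t allowed in a method parameter position, because `Function1` is contravariant over its parameter type.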
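To see why the lower bound is needed, here’s a minimal sketch using a hypothetical stripped-down `Chain` type along with made-up `Animal`/`Dog`/`Cat` classes (none of them are part of the `LinkedNode` code in this post). Declaring `prepend(x: A)` on the covariant trait would be rejected by the compiler, whereas the `B >: A` version compiles and lets the element type widen:

```scala
// Hypothetical types, for illustration only
sealed trait Animal
case class Dog(name: String) extends Animal
case class Cat(name: String) extends Animal

sealed trait Chain[+A] {
  // def prepend(x: A): Chain[A] = Link(x, this)
  // ^ would not compile: covariant type A occurs in contravariant position

  // With a lower bound, the result type widens to a common supertype instead
  def prepend[B >: A](x: B): Chain[B] = Link(x, this)
}
case class Link[A](head: A, tail: Chain[A]) extends Chain[A]
case object End extends Chain[Nothing]

val dogs: Chain[Dog] = Link(Dog("Rex"), End)

// Prepending a Cat to a Chain[Dog] yields a Chain[Animal]
val animals: Chain[Animal] = dogs.prepend(Cat("Tom"))
```

The same mechanism is what allows `insertNode` to accept an element of a supertype of `A` and return a correspondingly wider `LinkedNode[B]`.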

Recursion and pattern matching

A couple of notes on the approach we take in implementing our class methods:

  1. We use recursive functions to avoid using mutable variables. Some of them (e.g. the loops in `insertNode` and `deleteNode`) aren’t tail-recursive and should be made so for stack safety on long lists, but that isn’t the focus of this implementation. If performance is a priority, using conventional while-loops with mutable class fields `elem`/`next` would be a more practical option.
  2. Pattern matching is routinely used for handling cases of `Node` versus `EmptyNode`. An alternative approach would be to define fields `elem` and `next` in the base trait and implement class methods accordingly within `Node` and `EmptyNode`.
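As a rough idea of what a tail-recursive variant could look like, below is a sketch of an insert operation written as a standalone function. The name `insertNodeTR` and the accumulator-based approach are assumptions made for illustration, and the ADT is repeated in its bare-bones form just to keep the snippet self-contained:

```scala
import scala.annotation.tailrec

sealed trait LinkedNode[+A]
case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]
case object EmptyNode extends LinkedNode[Nothing]

// Hypothetical tail-recursive counterpart of insertNode
def insertNodeTR[A](ln: LinkedNode[A], x: A, idx: Int): LinkedNode[A] = {
  // Walk forward, stacking the passed-over elements (in reverse order) in `seen`
  @tailrec
  def walk(rest: LinkedNode[A], i: Int, seen: List[A]): (List[A], LinkedNode[A]) =
    rest match {
      case Node(e, nx) if i < idx => walk(nx, i + 1, e :: seen)
      case EmptyNode if i < idx   => (seen, EmptyNode)      // idx past the end: nothing inserted
      case _                      => (seen, Node(x, rest))  // insertion point reached
    }

  val (seen, tail) = walk(ln, 0, Nil)
  // Re-attach the passed-over elements in front of the new tail
  seen.foldLeft(tail)((acc, e) => Node(e, acc))
}

val sample: LinkedNode[Int] = Node(1, Node(2, Node(3, EmptyNode)))
val inserted: LinkedNode[Int] = insertNodeTR(sample, 9, 2)
// Node(1, Node(2, Node(9, Node(3, EmptyNode))))
```

The trade-off is an intermediate `List` holding the elements in front of the insertion point, in exchange for a constant stack depth.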

Finding first/last matching nodes

Next, we add a couple of class methods for finding first/last matching nodes.

  def find[B >: A](x: B): LinkedNode[B] = {
    def loop(ln: LinkedNode[B]): LinkedNode[B] = ln match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        if (e == x) ln else loop(nx)
    }
    loop(self)
  }

  def findLast[B >: A](x: B): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], acc: List[LinkedNode[B]]): List[LinkedNode[B]] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        if (e == x) loop(nx, ln :: acc) else loop(nx, acc)
    }
    loop(self, List.empty[LinkedNode[B]]).headOption match {
      case None => EmptyNode
      case Some(node) => node
    }
  }

Reversing a linked list by groups of nodes

Reversing a given LinkedNode can be accomplished via recursion by cumulatively wrapping the element of each `Node` in a new `Node` with its `next` pointer set to the `Node` created in the previous iteration.

  def reverse: LinkedNode[A] = {
    def loop(ln: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, Node(e, acc))
    }
    loop(self, EmptyNode)
  }

  def reverseK(k: Int): LinkedNode[A] = {
    if (k == 1)
      self
    else
      self.toList.grouped(k).flatMap(_.reverse).
        foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
  }

Method `reverseK` reverses a LinkedNode by groups of `k` elements using a different approach that extracts the elements into groups of `k` elements, reverses the elements in each of the groups and re-wraps each of the flattened elements in a `Node`.
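The intermediate steps are perhaps easier to follow on a plain `List`. The snippet below is only a sketch of the group-and-reverse part with `k = 3`, leaving out the final re-wrapping into `Node`s:

```scala
val xs = List(1, 2, 2, 3, 4, 4, 5)

// Step 1: split into groups of k = 3 (the last group may be shorter)
val groups: List[List[Int]] = xs.grouped(3).toList
// List(List(1, 2, 2), List(3, 4, 4), List(5))

// Step 2: reverse each group and flatten the result
val flipped: List[Int] = groups.flatMap(_.reverse)
// List(2, 2, 1, 4, 4, 3, 5)
```

Note that `grouped` expects a positive group size, so `reverseK` is presumably meant to be called with `k >= 1`.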

Using LinkedNode as a Stack

For `LinkedNode` to serve as a Stack, we include simple methods `push` and `pop` as follows:

  def push[B >: A](x: B): LinkedNode[B] = Node(x, self)

  def pop: (Option[A], LinkedNode[A]) = self match {
    case EmptyNode => (None, self)
    case Node(e, nx) => (Some(e), nx)
  }

Addendum: implementing map, flatMap, fold

From a different perspective, if we view `LinkedNode` as a collection like a Scala List or Vector, we might be craving methods like `map`, `flatMap` and `fold`. Using the same approach of recursion along with pattern matching, it’s rather straightforward to crank them out.

  def map[B](f: A => B): LinkedNode[B] = self match {
    case EmptyNode =>
      EmptyNode
    case Node(e, nx) =>
      Node(f(e), nx.map(f))
  }

  def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = self match {
    case EmptyNode =>
      EmptyNode
    case Node(e, nx) =>
      // Prepend every element of f(e), not just the first one, to the flatMapped rest
      f(e).foldRight[LinkedNode[B]](nx.flatMap(f))(Node(_, _))
  }

  def foldLeft[B](z: B)(f: (B, A) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, f(acc, e))
    }
    loop(self, z)
  }

  def foldRight[B](z: B)(f: (A, B) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, f(e, acc))
    }
    loop(self.reverse, z)
  }
}

Putting everything together

Along with a few additional simple class methods and a factory method wrapped in `LinkedNode`’s companion object, below is the final `LinkedNode` ADT that includes everything described above.

sealed trait LinkedNode[+A] { self =>

  def insertNode[B >: A](x: B, idx: Int): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], i: Int): LinkedNode[B] = ln match {
      case EmptyNode =>
        if (i < idx)
          EmptyNode
        else
          Node(x, EmptyNode)
      case Node(e, nx) =>
        if (i < idx)
          Node(e, loop(nx, i + 1))
        else
          Node(x, ln)
    }
    loop(self, 0)
  }

  def deleteNode(idx: Int): LinkedNode[A] = {
    def loop(ln: LinkedNode[A], i: Int): LinkedNode[A] = ln match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        // Drop the node at position idx; an out-of-range (or negative) idx leaves the list unchanged
        if (i == idx) nx else Node(e, loop(nx, i + 1))
    }
    loop(self, 0)
  }

  def get(idx: Int): LinkedNode[A] = {
    def loop(ln: LinkedNode[A], count: Int): LinkedNode[A] = {
      if (count < idx) {
        ln match {
          case EmptyNode => EmptyNode
          case Node(_, next) => loop(next, count + 1)
        }
      }
      else
        ln
    }
    loop(self, 0)
  }

  def find[B >: A](x: B): LinkedNode[B] = {
    def loop(ln: LinkedNode[B]): LinkedNode[B] = ln match {
      case EmptyNode =>
        EmptyNode
      case Node(e, nx) =>
        if (e == x) ln else loop(nx)
    }
    loop(self)
  }

  def findLast[B >: A](x: B): LinkedNode[B] = {
    def loop(ln: LinkedNode[B], acc: List[LinkedNode[B]]): List[LinkedNode[B]] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        if (e == x) loop(nx, ln :: acc) else loop(nx, acc)
    }
    loop(self, List.empty[LinkedNode[B]]).headOption match {
      case None => EmptyNode
      case Some(node) => node
    }
  }

  def indexOf[B >: A](x: B): Int = {
    def loop(ln: LinkedNode[B], idx: Int): Int = ln match {
      case EmptyNode =>
        -1
      case Node(e, nx) =>
        if (e == x) idx else loop(nx, idx + 1)
    }
    loop(self, 0)
  }

  def indexLast[B >: A](x: B): Int = {
    def loop(ln: LinkedNode[B], idx: Int, acc: List[Int]): List[Int] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, idx + 1, if (e == x) idx :: acc else acc)
    }
    loop(self, 0, List(-1)).head
  }

  def reverse: LinkedNode[A] = {
    def loop(ln: LinkedNode[A], acc: LinkedNode[A]): LinkedNode[A] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, Node(e, acc))
    }
    loop(self, EmptyNode)
  }

  def reverseK(k: Int): LinkedNode[A] = {
    if (k == 1)
      self
    else
      self.toList.grouped(k).flatMap(_.reverse).
        foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
  }

  def toList: List[A] = {
    def loop(ln: LinkedNode[A], acc: List[A]): List[A] = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, e :: acc)
    }
    loop(self, List.empty[A]).reverse
  }

  override def toString: String = {
    def loop(ln: LinkedNode[A], acc: String): String = ln match {
      case EmptyNode =>
        acc + "()"
      case Node(e, nx) =>
        loop(nx, acc + e + " -> " )
    }
    loop(self, "")
  }

  // ---- push / pop ----

  def push[B >: A](x: B): LinkedNode[B] = Node(x, self)

  def pop: (Option[A], LinkedNode[A]) = self match {
    case EmptyNode => (None, self)
    case Node(e, nx) => (Some(e), nx)
  }

  // ---- map / flatMap / fold ----

  def map[B](f: A => B): LinkedNode[B] = self match {
    case EmptyNode =>
      EmptyNode
    case Node(e, nx) =>
      Node(f(e), nx.map(f))
  }

  def flatMap[B](f: A => LinkedNode[B]): LinkedNode[B] = self match {
    case EmptyNode =>
      EmptyNode
    case Node(e, nx) =>
      // Prepend every element of f(e), not just the first one, to the flatMapped rest
      f(e).foldRight[LinkedNode[B]](nx.flatMap(f))(Node(_, _))
  }

  def foldLeft[B](z: B)(f: (B, A) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, f(acc, e))
    }
    loop(self, z)
  }

  def foldRight[B](z: B)(f: (A, B) => B): B = {
    def loop(ln: LinkedNode[A], acc: B): B = ln match {
      case EmptyNode =>
        acc
      case Node(e, nx) =>
        loop(nx, f(e, acc))
    }
    loop(self.reverse, z)
  }
}

object LinkedNode {
  def apply[A](ls: List[A]): LinkedNode[A] =
    ls.foldRight[LinkedNode[A]](EmptyNode)(Node(_, _))
}

case class Node[A](elem: A, next: LinkedNode[A]) extends LinkedNode[A]

case object EmptyNode extends LinkedNode[Nothing]

A cursory test-run …

val ln = LinkedNode(List(1, 2, 2, 3, 4, 4, 5))
// ln: LinkedNode[Int] = 1 -> 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ()

val ln2 = ln.insertNode(9, 3)
// ln2: LinkedNode[Int] = 1 -> 2 -> 2 -> 9 -> 3 -> 4 -> 4 -> 5 -> ()

val ln3 = ln2.deleteNode(4)
// ln3: LinkedNode[Int] = 1 -> 2 -> 2 -> 9 -> 4 -> 4 -> 5 -> ()

ln.reverseK(3)
// res1: LinkedNode[Int] = 2 -> 2 -> 1 -> 4 -> 4 -> 3 -> 5 -> ()

ln.find(4)
// res2: LinkedNode[Int] = 4 -> 4 -> 5 -> ()

ln.indexLast(4)
// res3: Int = 5

// ---- Using LinkedNode like a `stack` ----

ln.push(9)
// res1: LinkedNode[Int] = 9 -> 1 -> 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ()

ln.pop
// res2: (Option[Int], LinkedNode[Int]) = (Some(1), 2 -> 2 -> 3 -> 4 -> 4 -> 5 -> ())

// ---- Using LinkedNode like a `Vector` collection ----

ln.map(_ + 10)
// res1: LinkedNode[Int] = 11 -> 12 -> 12 -> 13 -> 14 -> 14 -> 15 -> ()

ln.flatMap(i => Node(s"$i!", EmptyNode))
// res2: LinkedNode[String] = 1! -> 2! -> 2! -> 3! -> 4! -> 4! -> 5! -> ()

ln.foldLeft(0)(_ + _)
// res3: Int = 21

ln.foldRight("")((x, acc) => if (acc.isEmpty) s"$x" else s"$acc|$x")
// res4: String = 5|4|4|3|2|2|1

Traversing A Scala Collection

When we have a custom class in the form of an ADT, say, `Container[A]` and are required to process a collection of the derived objects like `List[Container[A]]`, there might be times we want to flip the collection “inside out” to become a single “container” of collection `Container[List[A]]`, and maybe further transform the inner collection with a function.

For those who are familiar with Scala Futures, the nature of such a transformation is analogous to what method `Future.sequence` does. In case the traversal also involves applying a function, say, `f: A => Container[B]`, to the individual elements to transform the collection into `Container[List[B]]`, it’ll be more like how `Future.traverse` works.
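For reference, here’s a small sketch of the two standard library methods in action. The sample values and the 5-second timeout are arbitrary choices, and the global execution context is used only for brevity:

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

val futures: List[Future[Int]] = List(Future(1), Future(2), Future(3))

// List[Future[Int]] => Future[List[Int]]
val sequenced: Future[List[Int]] = Future.sequence(futures)

// List[String] => Future[List[Int]], applying a Future-returning function along the way
val traversed: Future[List[Int]] =
  Future.traverse(List("a", "bb", "ccc"))(s => Future(s.length))

val sequencedResult = Await.result(sequenced, 5.seconds) // List(1, 2, 3)
val traversedResult = Await.result(traversed, 5.seconds) // List(1, 2, 3)
```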

To illustrate how we can come up with methods `sequence` and `traverse` for the collection of our own ADTs, let’s assemble a simple ADT `Fillable[A]`. Our goal is to create the following two methods:

  def sequence[A](listFA: List[Fillable[A]]): Fillable[List[A]]
  def traverse[A, B](list: List[A])(f: A => Fillable[B]): Fillable[List[B]]

For simplicity, rather than a generic collection like IterableOnce, we fix the collection type to `List`.

A simple ADT

sealed trait Fillable[A]
case class Filled[A](a: A) extends Fillable[A]
case object Emptied extends Fillable[Nothing]

It looks a little like a home-made version of Scala `Option`, but is certainly not very useful yet. Let’s equip it with a companion object and a couple of methods for transforming the element within a `Fillable`:

sealed trait Fillable[+A] { self =>
  def map[B](f: A => B): Fillable[B] = self match {
    case Filled(a) => Filled(f(a))
    case Emptied   => Emptied
  }

  def flatMap[B](f: A => Fillable[B]): Fillable[B] = self match {
    case Filled(a) => f(a)
    case Emptied   => Emptied
  }
}

object Fillable {
  def apply[A](a: A): Fillable[A] = Filled(a)
}

case class Filled[A](a: A) extends Fillable[A]

case object Emptied extends Fillable[Nothing]

With slightly different signatures, methods `map` and `flatMap` are now available for transforming the element “contained” within a `Fillable`.

A couple of quick notes:

  • `Fillable[A]` is made covariant so that `Emptied`, being a `Fillable[Nothing]`, counts as a `Fillable[B]` for any `B` and can therefore be returned by `map`/`flatMap`.
  • The `self =>` alias isn’t necessary here (`this` would do), but is rather a personal coding style preference.

Testing the ADT:

val listF: List[Fillable[String]] = List(Filled("a"), Filled("bb"), Emptied, Filled("dddd"))

Filled("bb").map(_.length)
// res1: Fillable[Int] = Filled(2)

Filled("bb").flatMap(s => Fillable(s.length))
// res2: Fillable[Int] = Filled(2)

listF.map(_.map(_.length))
// res3: List[Fillable[Int]] = List(Filled(1), Filled(2), Emptied, Filled(4))

Sequencing a collection of Fillables

Let’s assemble method `sequence` which will reside within the companion object. Looking at the signature of the method to be defined:

  def sequence[A](listFA: List[Fillable[A]]): Fillable[List[A]]

it seems logical to consider aggregating a `List` from scratch within `Fillable` using Scala `fold`. However, trying to iteratively aggregate a list out of elements from within their individual “containers” isn’t as trivial as it may seem. Had there been methods like `get/getOrElse` that unwrap a `Fillable` to obtain the contained element, it would’ve been straightforward, although an implementation leveraging a getter method would require a default value for the contained element to handle the `Emptied` case.

One approach to implement `sequence` using only `map/flatMap` would be to first `map` within the `fold` operation each `Fillable` element of the input `List` into a list-push function for the element’s contained value, followed by a `flatMap` to aggregate the resulting `List` by iteratively applying the list-push functions within the `Fillable` container:

  def pushToList[A](a: A)(la: List[A]): List[A] = a :: la

  def sequence[A](listFA: List[Fillable[A]]): Fillable[List[A]] =
    listFA.foldRight(Fillable(List.empty[A])){ (fa, acc) =>
      fa match {
        case Emptied => acc
        case _       => fa.map(pushToList).flatMap(acc.map)
      }
    }

Note that `pushToList` within `map` is now regarded as a function that takes an element of type `A` and returns a `List[A] => List[A]` function. The expression `fa.map(pushToList).flatMap(acc.map)` is just a short-hand for:

  fa.map(a => pushToList(a)).flatMap(fn => acc.map(fn))

In essence, the first `map` transforms element within each `Fillable` in the input list into a corresponding list-push function for the specific element, and the `flatMap` uses the individual list-push functions for the inner `map` to iteratively aggregate the list inside the resulting `Fillable`.
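To make a single `fold` step more tangible, here’s a sketch that spells out the two stages with explicit intermediate values. The ADT is repeated (using `this` in place of `self`) to keep the snippet self-contained, and the sample values `fa` and `acc` are made up for illustration:

```scala
sealed trait Fillable[+A] {
  def map[B](f: A => B): Fillable[B] = this match {
    case Filled(a) => Filled(f(a))
    case Emptied   => Emptied
  }
  def flatMap[B](f: A => Fillable[B]): Fillable[B] = this match {
    case Filled(a) => f(a)
    case Emptied   => Emptied
  }
}
case class Filled[A](a: A) extends Fillable[A]
case object Emptied extends Fillable[Nothing]

val fa: Fillable[String] = Filled("a")
val acc: Fillable[List[String]] = Filled(List("bb", "dddd"))

// Step 1: wrap the contained value in a "prepend me" function
val pushFn: Fillable[List[String] => List[String]] =
  fa.map(a => (la: List[String]) => a :: la)

// Step 2: run that function against the list inside the accumulator
val step: Fillable[List[String]] = pushFn.flatMap(fn => acc.map(fn))
// Filled(List(a, bb, dddd))

// The same step written as a for-comprehension
val viaFor: Fillable[List[String]] = for { a <- fa; la <- acc } yield a :: la
```

The for-comprehension form desugars to `fa.flatMap(a => acc.map(la => a :: la))`, which some readers might find easier to follow than the function-passing version.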

Traversing the Fillable collection

Next, we’re going to define method `traverse` with the following signature within the companion object:

 def traverse[A, B](list: List[A])(f: A => Fillable[B]): Fillable[List[B]]

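In case it doesn’t seem obvious, based on the method signatures, `sequence` is really just a special case of `traverse` in which the list elements are already `Fillable`s and `f` is the identity function, i.e. `sequence(listFA)` is equivalent to `traverse(listFA)(fa => fa)`.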

Similar to the way `sequence` is implemented, we’ll also use `fold` for iteratively aggregating the resulting list. Since applying the provided function `f` to an element of type `A` yields a `Fillable[B]`, we’re essentially dealing with the same problem we did for `sequence`, except that type `A` is now replaced with type `B`.

  def traverse[A, B](list: List[A])(f: A => Fillable[B]): Fillable[List[B]] =
    list.foldRight(Fillable(List.empty[B])) { (a, acc) =>
      val fb = f(a)
      fb match {
        case Emptied => acc
        case _ => fb.map(pushToList).flatMap(acc.map)
      }
    }
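As a sanity check on how the two methods relate, the sketch below expresses `sequence` in terms of `traverse` by passing in a function that returns each `Fillable` as-is. The names `sequenceViaTraverse` and `sampleF` are made up, the methods are written as standalone functions, and explicit lambdas are used in place of `pushToList` purely to keep the snippet short:

```scala
sealed trait Fillable[+A] {
  def map[B](f: A => B): Fillable[B] = this match {
    case Filled(a) => Filled(f(a))
    case Emptied   => Emptied
  }
  def flatMap[B](f: A => Fillable[B]): Fillable[B] = this match {
    case Filled(a) => f(a)
    case Emptied   => Emptied
  }
}
case class Filled[A](a: A) extends Fillable[A]
case object Emptied extends Fillable[Nothing]

def traverse[A, B](list: List[A])(f: A => Fillable[B]): Fillable[List[B]] =
  list.foldRight[Fillable[List[B]]](Filled(List.empty[B])) { (a, acc) =>
    f(a) match {
      case Emptied => acc                                  // skip empty containers
      case fb      => fb.flatMap(b => acc.map(b :: _))     // prepend the contained value
    }
  }

// Hypothetical name: `sequence` as `traverse` with a do-nothing function
def sequenceViaTraverse[A](listFA: List[Fillable[A]]): Fillable[List[A]] =
  traverse[Fillable[A], A](listFA)(fa => fa)

val sampleF: List[Fillable[String]] =
  List(Filled("a"), Filled("bb"), Emptied, Filled("dddd"))

val sequencedF: Fillable[List[String]] = sequenceViaTraverse(sampleF)
// Filled(List(a, bb, dddd))
```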

Putting everything together:

sealed trait Fillable[+A] { self =>
  def map[B](f: A => B): Fillable[B] = self match {
    case Filled(a) => Filled(f(a))
    case Emptied   => Emptied
  }

  def flatMap[B](f: A => Fillable[B]): Fillable[B] = self match {
    case Filled(a) => f(a)
    case Emptied   => Emptied
  }
}

object Fillable {
  def apply[A](a: A): Fillable[A] = Filled(a)

  def pushToList[A](a: A)(la: List[A]): List[A] = a :: la

  def sequence[A](listFA: List[Fillable[A]]): Fillable[List[A]] =
    listFA.foldRight(Fillable(List.empty[A])){ (fa, acc) =>
      fa match {
        case Emptied => acc
        case _       => fa.map(pushToList).flatMap(acc.map)
      }
    }

  def traverse[A, B](list: List[A])(f: A => Fillable[B]): Fillable[List[B]] =
    list.foldRight(Fillable(List.empty[B])) { (a, acc) =>
      val fb = f(a)
      fb match {
        case Emptied => acc
        case _ => fb.map(pushToList).flatMap(acc.map)
      }
    }
}

case class Filled[A](a: A) extends Fillable[A]

case object Emptied extends Fillable[Nothing]

Testing with the newly created methods:

val listF: List[Fillable[String]] = List(Filled("a"), Filled("bb"), Emptied, Filled("dddd"))

Fillable.sequence(listF)
// res1: Fillable[List[String]] = Filled(List(a, bb, dddd))

val list: List[String] = List("a", "bb", "ccc", "dddd")

Fillable.traverse[String, Int](list)(a => Fillable(a.length))
// res2: Fillable[List[Int]] = Filled(List(1, 2, 3, 4))

Fillable.traverse[String, Int](list)(_ => Emptied)
// res3: Fillable[List[Int]] = Filled(List())

Scala Unfold

In Scala 2.13, method `unfold` was added to the standard library without the fanfare it probably deserves. A first glance at the method signature might make one wonder how it could possibly be useful. Admittedly, it’s not intuitive to reason about how to make use of it. Although it’s new in the Scala standard library, a couple of Akka Stream operators like `Source.unfold` and `Source.unfoldAsync` with similar method signatures and functionality have already been available for a while.

While the method is available for a number of Scala collections, I’ll illustrate its use with the `Iterator` collection. One reason `Iterator` is chosen is that its “laziness” allows the method to be used for generating infinite sequences. Here’s the signature of method `unfold`:

def unfold[A, S](init: S)(f: (S) => Option[(A, S)]): Iterator[A]

Method fold versus unfold

Looking at a method by the name `unfold`, one might begin to ponder its correlation to method `fold`. The contrast between `fold` and `unfold` is in some way analogous to that between `apply` and `unapply`, except that it’s a little more intuitive to “reverse” the logic from `apply` to `unapply` than from `fold` to `unfold`.

Let’s take a look at the method signature of `fold`:

def fold[A1 >: A](z: A1)(op: (A1, A1) => A1): A1

Given a collection (in this case, an Iterator), method `fold` allows one to iteratively transform the elements of the collection into an aggregated element of similar type (a supertype of the elements to be precise) by means of a binary operator. For example:

Iterator(1 to 10: _*).fold(1000)(_ + _)
// res1: Int = 1055

Reversing method fold

In the above example, the binary operator `_ + _`, which is a shorthand for `(acc, x) => acc + x`, iteratively adds a number from a sequence of numbers to the accumulated sum, and `fold` applies the operator against the given Iterator’s content starting with an initial number 1000. It’s in essence doing this:

1000 + 1 + 2 + ... + 10

To interpret the “reverse” logic in a loose fashion, let’s hypothesize a problem with the following requirement:

Given the number 1055 (the “folded” sum), iteratively assemble a monotonically increasing sequence starting from 1, adding the next element for as long as 1055 minus the cumulative sum of the elements generated so far remains larger than 1000.

Here’s one way of doing it using `unfold`:

val iter = Iterator.unfold((1, 1055)){
  case (i, n) if n > 1000 => Some((i, (i+1, n-i)))
  case _ => None
}

iter.toList
// res1: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

How does unfold work?

Recall that Iterator’s `unfold` has the following method signature:

def unfold[A, S](init: S)(f: (S) => Option[(A, S)]): Iterator[A]

As can be seen from the signature, starting from a given “folded” initial state value, elements of a yet-to-be-generated sequence are iteratively “unfolded” by means of the function `f`. In each iteration, the returned tuple of type Option[(A, S)] determines a few things:

  1. the 1st tuple element of type `A` is the new element to be added to the resulting sequence
  2. the 2nd tuple element of type `S` is the next `state` value, revealing how the state is being iteratively mutated
  3. a returned `Some((elem, state))` signals a new element being generated whereas a returned `None` signals the “termination” for the sequence generation operation

In the above example, the `state` is itself a tuple with initial state value `(1, 1055)` and next state value `(i+1, n-i)`. In each iteration, the current state `(i, n)` is transformed into an `Option` of a tuple in which:

  • the emitted element is `i`, with the counter incrementing from `i` to `i+1` in the next state
  • the remaining value decrements from `n` to `n-i` in the next state, which will be checked against the `n > 1000` condition in the following iteration
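To observe how the state evolves, one option is to emit the whole `(i, n)` state as the element instead of just `i`. The following is merely an illustrative tweak of the example, with `trace` being a made-up name:

```scala
val trace: List[(Int, Int)] =
  Iterator.unfold((1, 1055)) {
    // Emit the current state itself so each iteration becomes visible
    case (i, n) if n > 1000 => Some(((i, n), (i + 1, n - i)))
    case _ => None
  }.toList
// List((1,1055), (2,1054), (3,1052), (4,1049), (5,1045),
//      (6,1040), (7,1034), (8,1027), (9,1019), (10,1010))
```

After the element 10 is emitted, the next state becomes `(11, 1000)`, which fails the `n > 1000` check and ends the sequence.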

Modified examples from Akka Stream API doc

Let’s look at a couple of examples modified from the Akka Stream API doc for Source.unfold. The modification is minor but necessary due to difference in the method signatures.

Example 1:

def countDown(n: Int): Iterator[Int] =
  Iterator.unfold(n) { count =>
    if (count > 0) Some((count, count - 1))
    else None
  }

countDown(10).toList
// res1: List[Int] = List(10, 9, 8, 7, 6, 5, 4, 3, 2, 1)

This is a nice “Hello World” example of `unfold`. Following the above bullet points of how-it-works, #1 and #2 tell us the resulting sequence has a starting element `count` iteratively decremented by 1 and how-it-works #3 says when `count` is not larger than 0 (i.e. decremented down to 0) the sequence generation operation stops.
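For those who prefer seeing the mechanics spelled out, below is a hand-rolled sketch of an `unfold` for `List`. The helper `unfoldList` is hypothetical and eager, and is not meant to reflect how the standard library implements the method; it merely maps the three how-it-works points onto a recursive loop:

```scala
import scala.annotation.tailrec

// Hypothetical helper: an eager, List-based take on unfold
def unfoldList[A, S](init: S)(f: S => Option[(A, S)]): List[A] = {
  @tailrec
  def loop(state: S, acc: List[A]): List[A] = f(state) match {
    case Some((elem, next)) => loop(next, elem :: acc) // #1 keep elem, #2 move on to the next state
    case None               => acc.reverse             // #3 stop generating
  }
  loop(init, Nil)
}

val countedDown: List[Int] =
  unfoldList(5)(count => if (count > 0) Some((count, count - 1)) else None)
// List(5, 4, 3, 2, 1)
```

Being eager, this version would never return for a function that doesn’t eventually yield `None`, which is one reason a lazy collection like `Iterator` is a better fit for infinite sequences.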

Example 2:

// Defined as a `def` so that each use gets a fresh iterator (an Iterator can only be consumed once)
def fibonacci: Iterator[BigInt] =
  Iterator.unfold((BigInt(0), BigInt(1))) {
    case (x, y) => Some((x, (y, x + y)))
  }

fibonacci.take(10).toList
// res1: List[BigInt] = List(0, 1, 1, 2, 3, 5, 8, 13, 21, 34)

fibonacci.drop(90).next()
// res2: BigInt = 2880067194370816120

This example showcases a slick way of generating a Fibonacci sequence. Here, we use a tuple as the initial `state` value, resulting in the operator returning a value with a nested tuple. Tuples are used for the `state` because each number in a Fibonacci sequence depends on two preceding numbers, `Fib(n) = Fib(n-1) + Fib(n-2)`, hence in composing the sequence content we want to carry over more than one number in every iteration.

Applying the logic of how-it-works #1 and #2, if `x` and `y` represent the current and next elements, respectively, generated for the resulting sequence, `x + y` would be the value of the following element in accordance with the definition of Fibonacci numbers. In essence, the tuple `state` represents the next two values of elements to be generated. What about how-it-works #3? The absence of a `None` case in the return value of function `f` indicates that there is no terminating condition, hence we have an infinite sequence here.
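Should a finite sequence be preferred, a terminating condition can be introduced with a guard and a `None` case. Here’s a sketch that stops once the numbers reach a given limit; `fibonacciBelow` is a made-up name:

```scala
// Fibonacci numbers strictly below `limit`
def fibonacciBelow(limit: BigInt): Iterator[BigInt] =
  Iterator.unfold((BigInt(0), BigInt(1))) {
    case (x, y) if x < limit => Some((x, (y, x + y)))
    case _ => None // how-it-works #3: end of the sequence
  }

val below100: List[BigInt] = fibonacciBelow(100).toList
// List(0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89)
```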

Another example:

Here’s one more example which illustrates how one could generate a Factorial sequence using `unfold`.

val factorial: Iterator[BigInt] =
  Iterator.unfold((BigInt(1), 0)){
    case (n, i) => Some((n, (n*(i+1), i+1)))
  }

factorial.take(10).toList
// res1: List[BigInt] = List(1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880)

In this example, we also use a tuple to represent the `state`, although there is a critical difference between what the tuple elements represent when compared with the previous example. By definition, the next number in a Factorial sequence only depends on the immediately preceding number, `Fact(i+1) = Fact(i) * (i+1)`, thus the first tuple element, `n * (i+1)`, takes care of that, defining what the next element of the resulting sequence will be. But there is also a need to carry over the next index value and that’s what the second tuple element is for. Again, without a `None` case in the return value of function `f`, the resulting sequence will be infinite.

As a side note, we could also use method `iterate`, which comes with the Scala `Iterator` collection, with similar iteration logic as shown below:

def factorial(count: Int): Iterator[BigInt] =
  Iterator.iterate((BigInt(1), 0)) {
    case (n, i) => (n * (i + 1), i + 1)
  }.take(count).map(_._1)

factorial(10).toList
// res2: List[BigInt] = List(1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880)