Tag Archives: scala

Scala Unfold

In Scala 2.13, the method `unfold` was added to the standard library without the big bang it probably deserves. A first glance at the method signature might make one wonder how it could possibly be useful. Admittedly, it’s not intuitive to reason about how to make use of it. Although it’s new in the Scala standard library, a couple of Akka Stream operators with similar method signatures and functionality, such as `Source.unfold` and `Source.unfoldAsync`, have already been available for a while.

While the method is available for a number of Scala collections, I’ll illustrate its use with the `Iterator` collection. One reason `Iterator` is chosen is that its “laziness” allows the method to be used for generating infinite sequences. Here’s method `unfold`’s signature:

def unfold[A, S](init: S)(f: (S) => Option[(A, S)]): Iterator[A]

Method fold versus unfold

Looking at a method by the name `unfold`, one might begin to ponder its correlation to the method `fold`. The contrast between `fold` and `unfold` is in some way analogous to that between `apply` and `unapply`, except that it’s a little more intuitive to “reverse” the logic from `apply` to `unapply` than from `fold` to `unfold`.
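As a quick refresher on that analogy, `apply` constructs a value while `unapply` deconstructs it. A minimal sketch with a hypothetical `Point` case class (for which the compiler generates both methods):

```scala
// The compiler generates Point.apply (construction)
// and Point.unapply (deconstruction) for us.
case class Point(x: Int, y: Int)

val p = Point(1, 2)       // apply: (1, 2) => Point
val Point(a, b) = p       // unapply via pattern matching: Point => (1, 2)
// a: Int = 1
// b: Int = 2
```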

Let’s take a look at the method signature of `fold`:

def fold[A1 >: A](z: A1)(op: (A1, A1) => A1): A1

Given a collection (in this case, an Iterator), method `fold` allows one to iteratively transform the elements of the collection into an aggregated element of similar type (a supertype of the elements to be precise) by means of a binary operator. For example:

Iterator(1 to 10: _*).fold(1000)(_ + _)
// res1: Int = 1055

Reversing method fold

In the above example, the binary operator `_ + _`, which is shorthand for `(acc, x) => acc + x`, iteratively adds numbers from a sequence, and `fold` applies the operator to the given Iterator’s content starting with an initial value of 1000. It’s in essence doing this:

1000 + 1 + 2 + ... + 10

To interpret the “reverse” logic in a loose fashion, let’s hypothesize a problem with the following requirement:

Given the number 1055 (the “folded” sum), iteratively assemble a monotonically increasing sequence starting from 1, appending each element only while 1055 minus the cumulative sum of the elements taken so far remains larger than 1000.

Here’s one way of doing it using `unfold`:

val iter = Iterator.unfold((1, 1055)){
  case (i, n) if n > 1000 => Some((i, (i+1, n-i)))
  case _ => None
}

iter.toList
// res1: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

How does unfold work?

Recall that Iterator’s `unfold` has the following method signature:

def unfold[A, S](init: S)(f: (S) => Option[(A, S)]): Iterator[A]

As can be seen from the signature, starting from a given “folded” initial state value, elements of a yet-to-be-generated sequence are iteratively “unfolded” by means of the function `f`. In each iteration, the returned value of type `Option[(A, S)]` determines a few things:

  1. the 1st tuple element of type `A` is the new element to be added to the resulting sequence
  2. the 2nd tuple element of type `S` is the next `state` value, revealing how the state is being iteratively mutated
  3. a returned `Some((elem, state))` signals a new element being generated whereas a returned `None` signals the “termination” for the sequence generation operation

In the above example, the `state` is itself a tuple with initial state value `(1, 1055)` and next state value `(i+1, n-i)`. The current state `(i, n)` is then iteratively transformed into an `Option` of tuple with:

  • the element value incrementing from `i` to `i+1`
  • the state value decrementing from `n` to `n-i`, which will be iteratively checked against the `n > 1000` condition
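To see the state threading in isolation, here’s a minimal sketch that unfolds an integer into its decimal digits (least significant first), with the remaining quotient as the state:

```scala
// State: the not-yet-consumed quotient; element: the extracted digit.
val digits = Iterator.unfold(1055) { n =>
  if (n > 0) Some((n % 10, n / 10))  // emit the lowest digit, shrink the state
  else None                          // state exhausted: stop
}

digits.toList
// res: List[Int] = List(5, 5, 0, 1)
```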

Modified examples from Akka Stream API doc

Let’s look at a couple of examples modified from the Akka Stream API doc for `Source.unfold`. The modification is minor but necessary due to differences in the method signatures.

Example 1:

def countDown(n: Int): Iterator[Int] =
  Iterator.unfold(n) { count =>
    if (count > 0) Some((count, count - 1))
    else None
  }

countDown(10).toList
// res1: List[Int] = List(10, 9, 8, 7, 6, 5, 4, 3, 2, 1)

This is a nice “Hello World” example of `unfold`. Following the above how-it-works bullet points, #1 and #2 tell us the resulting sequence starts at `count` and is iteratively decremented by 1, and #3 says that when `count` is no longer larger than 0 (i.e. has been decremented down to 0), the sequence generation stops.

Example 2:

// A def rather than a val: an Iterator is single-use, so each call below gets a fresh one
def fibonacci: Iterator[BigInt] =
  Iterator.unfold((BigInt(0), BigInt(1))) {
    case (x, y) => Some((x, (y, x + y)))
  }

fibonacci.take(10).toList
// res1: List[BigInt] = List(0, 1, 1, 2, 3, 5, 8, 13, 21, 34)

fibonacci.drop(90).next
// res2: BigInt = 2880067194370816120

This example showcases a slick way of generating a Fibonacci sequence. Here, we use a tuple as the initial `state` value, resulting in the operator returning a value with a nested tuple. Tuples are used for the `state` because each number in a Fibonacci sequence depends on two preceding numbers, `Fib(n) = Fib(n-1) + Fib(n-2)`, hence in composing the sequence content we want to carry over more than one number in every iteration.

Applying the logic of how-it-works #1 and #2, if `x` and `y` represent the current and next elements, respectively, generated for the resulting sequence, `x + y` would be the value of the following element in accordance with the definition of Fibonacci numbers. In essence, the tuple `state` carries the next two elements to be generated. What about how-it-works #3? The absence of a `None` case in the return value of the unfold function indicates that there is no terminating condition, hence we have an infinite sequence here.

Another example:

Here’s one more example which illustrates how one could generate a Factorial sequence using `unfold`.

val factorial: Iterator[BigInt] =
  Iterator.unfold((BigInt(1), 0)){
    case (n, i) => Some((n, (n*(i+1), i+1)))
  }

factorial.take(10).toList
// res1: List[BigInt] = List(1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880)

In this example, we also use a tuple to represent the `state`, although there is a critical difference in what the tuple elements represent compared with the previous example. By definition, the next number in a Factorial sequence only depends on the immediately preceding number, `Fact(i+1) = Fact(i) * (i+1)`, thus the first tuple element, `n * (i+1)`, takes care of that, defining what the next element of the resulting sequence will be. But there is also a need to carry over the next index value, and that’s what the second tuple element is for. Again, without a `None` case in the return value of the unfold function, the resulting sequence will be infinite.

As a side note, we could also use the method `iterate` that comes with the Scala `Iterator` collection to express similar iteration logic:

def factorial(n: Int): Iterator[BigInt] =
  Iterator.iterate((BigInt(1), 0)) {
    case (acc, i) => (acc * (i + 1), i + 1)  // avoid shadowing the parameter n
  }.take(n).map(_._1)

factorial(10).toList
// res2: List[BigInt] = List(1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880)
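As mentioned earlier, `unfold` isn’t exclusive to `Iterator` in Scala 2.13; strict collections like `List` come with it as well, and the lazy `LazyList` likewise supports infinite sequences. A quick sketch:

```scala
// Same countdown logic as Example 1, on a strict List
val countdown = List.unfold(5)(n => if (n > 0) Some((n, n - 1)) else None)
// countdown: List[Int] = List(5, 4, 3, 2, 1)

// LazyList's laziness allows an infinite sequence, e.g. powers of 2
val powersOf2 = LazyList.unfold(BigInt(1))(n => Some((n, n * 2)))

powersOf2.take(5).toList
// res: List[BigInt] = List(1, 2, 4, 8, 16)
```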

Merkle Tree Implementation In Scala

A Merkle tree, a.k.a. hash tree, is a tree in which every leaf node contains a cryptographic hash of a data block, and every branch node contains a hash of the concatenation of the hashes of its child nodes. Typical usage is for efficient verification of the content stored in the tree nodes.

Blockchain and Merkle tree

As cryptocurrency (or more generally, blockchain system) has become popular, so has its underlying authentication-oriented data structure, Merkle tree. In the cryptocurrency world, a blockchain can be viewed as a distributed ledger consisting of immutable but chain-able blocks, each of which hosts a set of transactions in the form of a Merkle tree. In order to chain a new block to an existing blockchain, part of the tamper-proof requirement is to guarantee the integrity of the enclosed transactions by composing their hashes in a specific way and storing them in a Merkle tree.

In case the above sounds like gibberish, here’s a great introductory article about blockchain. To delve slightly deeper with a focus on cryptocurrency, this blockchain guide from the Bitcoin Project website might be of interest. Just to be clear, even though blockchain helps popularize Merkle trees, implementing a flavor of the data structure does not require knowledge of blockchain or cryptocurrency.

In this blog post, we will assemble a barebone Merkle tree using Scala. While a Merkle tree is most often a binary tree, it’s certainly not confined to be one, although that’s what we’re going to implement.

A barebone Merkle tree class

// MerkleTree class
class MerkleTree(
    val hash: Array[Byte],
    val left: Option[MerkleTree] = None,
    val right: Option[MerkleTree] = None
  )

Note that when both the class fields `left` and `right` are `None`, it represents a leaf node.

To build a Merkle tree from a collection of byte-arrays (which might represent a transaction dataset), we will use a companion object to perform the task via its `apply` method. To create a `hash` within each of the tree nodes, we will also need a hash function, `hashFcn` of type `Array[Byte] => Array[Byte]`.

// MerkleTree companion object
object MerkleTree {
  def apply(data: Array[Array[Byte]], hashFcn: Array[Byte] => Array[Byte]): MerkleTree = {
    val nodes = data.map(byteArr => new MerkleTree(hashFcn(byteArr)))
    buildTree(nodes, hashFcn)(0)  // Return root of the tree
  }

  private def buildTree(...): Array[MerkleTree] = ???
}

Building a Merkle tree

As shown in the code, function `buildTree` needs to recursively pair up the nodes to form a tree in which each branch node holds the combined hash of its child nodes. The recursive pairing eventually ends with a single top-level node, called the Merkle root. Below is an implementation of such a function:

// Building a Merkle tree
  @scala.annotation.tailrec
  private def buildTree(
      nodes: Array[MerkleTree],
      hashFcn: Array[Byte] => Array[Byte]): Array[MerkleTree] = nodes match {

    case ns if ns.size <= 1 =>
      ns
    case ns =>
      val pairedNodes = ns.grouped(2).map{
          case Array(a, b) => new MerkleTree(hashFcn(a.hash ++ b.hash), Some(a), Some(b))
          case Array(a)    => new MerkleTree(hashFcn(a.hash), Some(a), None)
        }.toArray
      buildTree(pairedNodes, hashFcn)
  }

Now, back to class `MerkleTree`; let’s add a simple function for computing the height of the tree:

// Computing the height of a Merkle tree
  def height: Int = {
    def loop(node: MerkleTree): Int = {
      if (node.left.nonEmpty && node.right.nonEmpty)
        math.max(loop(node.left.get), loop(node.right.get)) + 1
      else if (node.left.nonEmpty)
        loop(node.left.get) + 1
      else if (node.right.nonEmpty)
        loop(node.right.get) + 1
      else 1
    }
    loop(this)
  }

Putting all the pieces together

For illustration purposes, we’ll add a side-effecting function `printNodes` along with a couple of for-display utility functions so we can see what our Merkle tree can do. Putting everything together, we have:

// Merkle tree class and companion object
class MerkleTree(
  val hash: Array[Byte],
  val left: Option[MerkleTree] = None,
  val right: Option[MerkleTree] = None) {

  def height: Int = {
    def loop(node: MerkleTree): Int = {
      if (node.left.nonEmpty && node.right.nonEmpty)
        math.max(loop(node.left.get), loop(node.right.get)) + 1
      else if (node.left.nonEmpty)
        loop(node.left.get) + 1
      else if (node.right.nonEmpty)
        loop(node.right.get) + 1
      else 1
    }
    loop(this)
  }

  override def toString: String = s"MerkleTree(hash = ${bytesToBase64(hash)})"
  private def toShortString: String = s"MT(${bytesToBase64(hash).substring(0, 4)})"
  private def bytesToBase64(bytes: Array[Byte]): String =
    java.util.Base64.getEncoder.encodeToString(bytes)

  def printNodes: Unit = {
    def printlnByLevel(t: MerkleTree): Unit = {
      for (l <- 1 to t.height) {
        loopByLevel(t, l)
        println
      }
    }
    def loopByLevel(node: MerkleTree, level: Int): Unit = {
      if (level <= 1)
        print(s"${node.toShortString} ")
      else {
        node.left.foreach(loopByLevel(_, level - 1))
        node.right.foreach(loopByLevel(_, level - 1))
      }
    }
    printlnByLevel(this)
  }
}

object MerkleTree {
  def apply(data: Array[Array[Byte]], hashFcn: Array[Byte] => Array[Byte]): MerkleTree = {
    @scala.annotation.tailrec
    def buildTree(nodes: Array[MerkleTree]): Array[MerkleTree] = nodes match {
      case ns if ns.size <= 1 =>
        ns
      case ns =>
        val pairedNodes = ns.grouped(2).map{
            case Array(a, b) => new MerkleTree(hashFcn(a.hash ++ b.hash), Some(a), Some(b))
            case Array(a)    => new MerkleTree(hashFcn(a.hash), Some(a), None)
          }.toArray
        buildTree(pairedNodes)
    }

    if (data.isEmpty)
      new MerkleTree(hashFcn(Array.empty[Byte]))
    else {
      val nodes = data.map(byteArr => new MerkleTree(hashFcn(byteArr)))
      buildTree(nodes)(0)  // Return root of the tree
    }
  }
}

Test building the Merkle tree with a hash function

By providing the required arguments for MerkleTree’s `apply` factory method, let’s create a Merkle tree with, say, 5 dummy byte-arrays using a popular hash function `SHA-256`. The created Merkle tree will be represented by its tree root, a.k.a. Merkle Root:

// Test building a Merkle tree
val sha256: Array[Byte] => Array[Byte] =
  byteArr => java.security.MessageDigest.getInstance("SHA-256").digest(byteArr)

val data = Array(
    Array[Byte](0, 1, 2),
    Array[Byte](3, 4, 5),
    Array[Byte](6, 7, 8),
    Array[Byte](9, 10, 11),
    Array[Byte](12, 13, 14)
  )

val mRoot = MerkleTree(data, sha256)
// mRoot: MerkleTree = MerkleTree(hash = C6OoSW1rymkivkJPBrhf9necQAbzPsq7RyZKd4ZU8hM=)

mRoot.hash
// res1: Array[Byte] = Array(11, -93, -88, 73, 109, 107, -54, 105, ...)

mRoot.printNodes
// MT(C6Oo) 
// MT(QKRB) MT(Hfri) 
// MT(J3JM) MT(d7VY) MT(9ypu) 
// MT(rksy) MT(KEhp) MT(Q4f2) MT(45qR) MT(BSpd) 

As can be seen from the output, the 5 dummy data blocks get hashed into the 5 leaf nodes, and each leaf’s hash value is combined with its sibling’s (if any) into another hash placed in the parent node.

For better clarity, below is an edited output in a tree structure:

//                                        MT(C6Oo)
//                                       /        \
//                            ----------           -----------
//                          /                                  \
//                  MT(QKRB)                                    MT(Hfri)
//                /          \                                /
//               /            \                              /
//       MT(J3JM)              MT(d7VY)              MT(9ypu)
//       /     \               /     \               /
//      /       \             /       \             /
// MT(rksy)   MT(KEhp)   MT(Q4f2)   MT(45qR)   MT(BSpd)
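As a sanity check on the pairing logic, we can recompute a parent hash by hand from its two child hashes, mirroring `buildTree`’s `Array(a, b)` case with plain byte-arrays; if the printed output above is faithful, the Base64 prefix of the result should correspond to the MT(J3JM) node:

```scala
val sha256: Array[Byte] => Array[Byte] =
  byteArr => java.security.MessageDigest.getInstance("SHA-256").digest(byteArr)

// Leaf hashes of the first two dummy data blocks ...
val leaf1 = sha256(Array[Byte](0, 1, 2))
val leaf2 = sha256(Array[Byte](3, 4, 5))

// ... combined the same way buildTree combines a paired node:
// the hash of the concatenation of the two child hashes
val parent = sha256(leaf1 ++ leaf2)

java.util.Base64.getEncoder.encodeToString(parent).substring(0, 4)
```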

Building a Merkle tree from blockchain transactions

To apply Merkle trees in the blockchain world, we’ll substitute the data blocks with a sequence of transactions from a blockchain.

First, we define a trivialized `Transaction` class with the transaction ID being the hash value of the combined class fields using the same hash function `sha256`:

// A trivialized Transaction class
import java.time.{Instant, LocalDateTime, ZoneId}
import java.time.format.DateTimeFormatter

case class Transaction(id: String, from: String, to: String, amount: Long, timestamp: Long) {
  override def toString: String = {
    val ts = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss").format(
        LocalDateTime.ofInstant(Instant.ofEpochMilli(timestamp), ZoneId.of("UTC"))
      )
    s"T(${id.substring(0, 4)}, ${from} -> ${to}, ${amount}, ${ts})"
  }
}

object Transaction {
  def apply(
      from: String, to: String, amount: Long, timestamp: Long, hashFcn: Array[Byte] => Array[Byte]
    ): Transaction = {
    val bytes = from.getBytes ++ to.getBytes ++ longToBytes(amount) ++ longToBytes(timestamp)
    new Transaction(bytesToBase64(hashFcn(bytes)), from, to, amount, timestamp)
  }

  private def bytesToBase64(bytes: Array[Byte]): String =
    java.util.Base64.getEncoder.encodeToString(bytes)

  private def longToBytes(num: Long) =
    java.nio.ByteBuffer.allocate(8).putLong(num).array
}

Next, we create an array of transactions:

// An array of transactions
val transactions = Array(
  Transaction("Alice", "Bob", 2500L, 1582587819175L, sha256),
  Transaction("Bob", "Carl", 4000L, 1582588700350L, sha256),
  Transaction("Carl", "Dana", 4000L, 1582591774502L, sha256)
)
// transactions: Array[Transaction] = Array(
//   T(ikSk, Alice -> Bob, 2500, 2020-02-24 23:43:39),
//   T(+EvZ, Bob -> Carl, 4000, 2020-02-24 23:58:20),
//   T(Ke8m, Carl -> Dana, 4000, 2020-02-25 00:49:34)
// )

Again, using MerkleTree’s `apply` factory method, we build a Merkle tree consisting of hash values of the individual transaction IDs, which in turn are hashes of their corresponding transaction content:

// Creating the Merkle root
val mRoot = MerkleTree(transactions.map(_.id.getBytes), sha256)
// mRoot: MerkleTree = MerkleTree(hash = CobcOH899Hq91uk3cXRR5As5J+ThqLacivYEZifhhfM=)

mRoot.printNodes
// MT(Cobc) 
// MT(VG1C) MT(g7oF) 
// MT(nhJS) MT(nt1U) MT(pslD) 

The Merkle root, along with the associated transactions, is kept in an immutable block. It’s also an integral part of the elements that are collectively hashed into the block-identifying hash value. The block hash will serve as the linking block ID for the next block that manages to successfully append to it. All the cross-hashing operations, coupled with the immutable block structure, make any attempt to tamper with the blockchain content highly expensive.
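To make the chaining concrete, here’s a deliberately simplified sketch; the `Block` field set and hashing scheme are illustrative assumptions, not any particular blockchain’s actual block format:

```scala
val sha256: Array[Byte] => Array[Byte] =
  byteArr => java.security.MessageDigest.getInstance("SHA-256").digest(byteArr)

// Hypothetical minimal block: real block headers carry more fields
// (timestamp, nonce, difficulty target, etc.)
case class Block(prevBlockHash: Array[Byte], merkleRootHash: Array[Byte]) {
  // The block-identifying hash covers both the link to the previous
  // block and the Merkle root of the enclosed transactions.
  val blockHash: Array[Byte] = sha256(prevBlockHash ++ merkleRootHash)
}

val genesis = Block(Array.empty[Byte], sha256("genesis".getBytes))
val next    = Block(genesis.blockHash, sha256("txs".getBytes))
// Tampering with genesis's content would change genesis.blockHash,
// invalidating next.prevBlockHash -- the chain property.
```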

Composing Partial Functions In Scala

Just like partial functions in mathematics, a partial function in Scala is a function whose domain doesn’t cover all values of its input type. For example:

val f: Function[Int, Int] = x => 100 / x

f(1)
// res1: Int = 100

f(2)
// res2: Int = 50

f(0)
// java.lang.ArithmeticException: / by zero ...

It’s a function defined for all non-zero integers, but `f(0)` throws a `java.lang.ArithmeticException`.

By defining it as a partial function like below:

val pf: PartialFunction[Int, Int] = { case x if x != 0 => 100 / x }
// pf: PartialFunction[Int,Int] = <function1>

we will be able to leverage PartialFunction’s methods like isDefinedAt to check on a given element before applying the function to it.

pf.isDefinedAt(1)
// res1: Boolean = true

pf.isDefinedAt(0)
// res2: Boolean = false

Methods lift and unlift

Scala provides a method `lift` for “lifting” a partial function into a total function that returns an Option type. Using the above partial function as an example:

val pf: PartialFunction[Int, Int] = { case x if x != 0 => 100 / x }

val f = pf.lift
// f: Int => Option[Int] = <function1>

f(1)
// res1: Option[Int] = Some(100)

f(0)
// res2: Option[Int] = None

Simple enough. Conversely, an Option-typed total function can be “unlifted” to a partial function. Applying `unlift` to the above function `f` creates a new partial function equivalent to `pf`:

val pf2 = f.unlift
// pf2: PartialFunction[Int,Int] = <function1>

pf2.isDefinedAt(1)
// res3: Boolean = true

pf2.isDefinedAt(0)
// res4: Boolean = false
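Besides the `unlift` method on Option-returning function instances, the standard library also offers an equivalent `Function.unlift`, which we’ll put to use later. A quick sketch:

```scala
// An Option-typed total function ...
val g: Int => Option[Int] = x => if (x != 0) Some(100 / x) else None

// ... unlifted into a partial function via the Function companion object
val pf3 = Function.unlift(g)

pf3.isDefinedAt(1)
// res: Boolean = true

pf3.isDefinedAt(0)
// res: Boolean = false

pf3(2)
// res: Int = 50
```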

Function compositions

For simplicity, we’ll look only at functions with arity 1 (i.e. `Function1`, which takes a single argument). It’s trivial to apply the same concept to `FunctionN`.

Methods like `andThen` and `compose` enable compositions of Scala functions. Since both methods are quite similar, I’m going to talk about `andThen` only. Readers who would like to extend to `compose` may try it as a programming exercise.

Method `andThen` for `Function1[T1, R]` has the following signature:

def andThen[A](g: (R) => A): (T1) => A

A trivial example:

val double: Int => Int = _ * 2
val add1: Int => Int = _ + 1

val doubleThenAdd1 = double andThen add1
// doubleThenAdd1: Int => Int = scala.Function1$Lambda$...

doubleThenAdd1(10)
// res1: Int = 21

Now, let’s replace the 2nd function `add1` with a partial function `inverse`:

val double: Int => Int = _ * 2
val inverse: PartialFunction[Int, Double] = { case x if x != 0 => 1.0 / x }

val doubleThenInverse = double andThen inverse
// doubleThenInverse: Int => Double = scala.Function1$Lambda$...

doubleThenInverse(10)
// res2: Double = 0.05

doubleThenInverse(0)
// scala.MatchError: 0 (of class java.lang.Integer) ...

doubleThenInverse.isDefinedAt(0)
// error: value isDefinedAt is not a member of Int => Double

Note that `doubleThenInverse` is still a total function even though the composing function is partial. That’s because `PartialFunction` is a subclass of `Function1`:

trait PartialFunction[-A, +B] extends (A) => B

hence method `andThen` rightfully returns a total function as advertised.

Unfortunately, that’s undesirable, as the resulting function has lost the `inverse` partial function’s domain information.

Partial function compositions

Method `andThen` for `PartialFunction[A, B]` has the following signature:

def andThen[C](k: (B) => C): PartialFunction[A, C]

Example:

val inverse: PartialFunction[Int, Double] = { case x if x != 0 => 1.0 / x }
val pfMap: PartialFunction[Double, String] = Map(0.1 -> "a",  0.2 -> "b")

val inverseThenPfMap = inverse andThen pfMap
// inverseThenPfMap: PartialFunction[Int,String] = <function1>

inverseThenPfMap(10)
// res1: String = a

inverseThenPfMap(5)
// res2: String = b

inverseThenPfMap.isDefinedAt(10)
// res3: Boolean = true

inverseThenPfMap.isDefinedAt(5)
// res4: Boolean = true

inverseThenPfMap.isDefinedAt(0)
// res5: Boolean = false

// So far so good ... Now, let's try:

inverseThenPfMap(2)
// java.util.NoSuchElementException: key not found: 0.5

inverseThenPfMap.isDefinedAt(2)
// res6: Boolean = false

That works perfectly: any element not in the domain of one of the partial functions being composed should have its corresponding element(s) eliminated from the domain of the composed function. In this case, 0.5 is not in the domain of `pfMap`, hence its corresponding element, 2 (which would have been `inverse`-d to 0.5), is not in `inverseThenPfMap`’s domain.

Unfortunately, I neglected to mention that I’m on Scala 2.13.x. On Scala 2.12 or below, `inverseThenPfMap.isDefinedAt(2)` would be `true`.

Turning composed functions into a proper partial function

Summarizing what we’ve looked at, there are two issues at hand:

  1. If the first function among the functions being composed is a total function, the composed function is a total function, discarding domain information of any subsequent partial functions being composed.
  2. Unless you’re on Scala 2.13+, with the first function being a partial function, the resulting composed function is a partial function, but its domain would not embody domain information of any subsequent partial functions being composed.

To tackle the issues, one approach is to leverage implicit conversion by defining a couple of implicit methods that handle composing a partial function on a total function and on a partial function, respectively.

object ComposeFcnOps {
  // Implicit conversion for total function
  implicit class TotalCompose[A, B](f: Function[A, B]) {
    def andThenPF[C](that: PartialFunction[B, C]): PartialFunction[A, C] =
      Function.unlift(x => Option(f(x)).flatMap(that.lift))
  }

  // Implicit conversion for partial function (Not needed on Scala 2.13+)
  implicit class PartialCompose[A, B](pf: PartialFunction[A, B]) {
    def andThenPF[C](that: PartialFunction[B, C]): PartialFunction[A, C] =
      Function.unlift(x => pf.lift(x).flatMap(that.lift))
  }
}

Note that the implicit methods are defined within `implicit class` wrappers, a common practice that lets the implicit conversion kick in when the methods are invoked as if they were instance methods.

In the first implicit class, function `f` (i.e. the total function to be implicitly converted) is first transformed to return an `Option`, chained using `flatMap` to the lifted partial function (i.e. the partial function to be composed), followed by an `unlift` to return a partial function.

Similarly, in the second implicit class, function `pf` (i.e. the partial function to be implicitly converted) is first lifted, chained to the lifted partial function (i.e. the partial function to be composed), followed by an `unlift`.

In both cases, `andThenPF` returns a partial function that incorporates the partial domains of the functions involved in the function composition.

Let’s reuse the `double` and `inverse` functions from a previous example:

val double: Int => Int = _ * 2
val inverse: PartialFunction[Int, Double] = { case x if x != 0 => 1.0 / x }

val doubleThenInverse = double andThen inverse
// doubleThenInverse: Int => Double = scala.Function1$Lambda$...

Recall from that example that `doubleThenInverse` is a total function. Now, let’s replace `andThen` with our custom `andThenPF`:

import ComposeFcnOps._

val doubleThenPFInverse: PartialFunction[Int, Double] = double andThenPF inverse
// doubleThenPFInverse: PartialFunction[Int,Double] = <function1>

doubleThenPFInverse(10)
// res1: Double = 0.05

doubleThenPFInverse(0)
// scala.MatchError: 0 (of class java.lang.Integer) ...

doubleThenPFInverse.isDefinedAt(10)
// res2: Boolean = true

doubleThenPFInverse.isDefinedAt(0)
// res3: Boolean = false

The resulting function is now a partial function with the composing function’s partial domain as its own domain. I’ll leave testing the cases in which the function to be composed is a partial function to the readers.
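For completeness, here’s what that exercise might look like: reusing `ComposeFcnOps` together with the `inverse` and `pfMap` functions from the earlier example, the first function is now itself partial, so the `PartialCompose` implicit class is the one that kicks in:

```scala
object ComposeFcnOps {
  // Implicit conversion for total function
  implicit class TotalCompose[A, B](f: Function[A, B]) {
    def andThenPF[C](that: PartialFunction[B, C]): PartialFunction[A, C] =
      Function.unlift(x => Option(f(x)).flatMap(that.lift))
  }

  // Implicit conversion for partial function
  implicit class PartialCompose[A, B](pf: PartialFunction[A, B]) {
    def andThenPF[C](that: PartialFunction[B, C]): PartialFunction[A, C] =
      Function.unlift(x => pf.lift(x).flatMap(that.lift))
  }
}

import ComposeFcnOps._

val inverse: PartialFunction[Int, Double] = { case x if x != 0 => 1.0 / x }
val pfMap: PartialFunction[Double, String] = Map(0.1 -> "a", 0.2 -> "b")

// inverse is a PartialFunction, so PartialCompose (the more specific
// implicit) applies
val inverseThenPFMap = inverse andThenPF pfMap

inverseThenPFMap.isDefinedAt(10)  // true:  1.0/10 = 0.1 is in pfMap's domain
inverseThenPFMap.isDefinedAt(2)   // false: 1.0/2  = 0.5 is not
inverseThenPFMap.isDefinedAt(0)   // false: 0 is outside inverse's domain
```

As with the built-in Scala 2.13 behavior, the composed domain excludes 2, whose `inverse`-d value 0.5 falls outside `pfMap`’s domain.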