
Composing Partial Functions In Scala

Just like partial functions in mathematics, a partial function in Scala is a function that isn’t defined for every element of its input type. For example:

val f: Function[Int, Int] = x => 100 / x

f(1)
// res1: Int = 100

f(2)
// res2: Int = 50

f(0)
// java.lang.ArithmeticException: / by zero ...

The function is defined for all non-zero integers, but `f(0)` throws a `java.lang.ArithmeticException`.

By defining it as a partial function like below:

val pf: PartialFunction[Int, Int] = { case x if x != 0 => 100 / x }
// pf: PartialFunction[Int,Int] = 

we will be able to leverage `PartialFunction` methods such as `isDefinedAt` to check whether the function is defined at a given element before applying it.

pf.isDefinedAt(1)
// res1: Boolean = true

pf.isDefinedAt(0)
// res2: Boolean = false

Methods lift and unlift

Scala provides a method `lift` for “lifting” a partial function into a total function that returns an Option type. Using the above partial function as an example:

val pf: PartialFunction[Int, Int] = { case x if x != 0 => 100 / x }

val f = pf.lift
// f: Int => Option[Int] = 

f(1)
// res1: Option[Int] = Some(100)

f(0)
// res2: Option[Int] = None

Simple enough. Conversely, a total function that returns an `Option` can be “unlifted” to a partial function. Applying `unlift` to the above function `f` creates a new partial function equivalent to `pf` (on Scala versions prior to 2.13, `Function.unlift(f)` achieves the same):

val pf2 = f.unlift
// pf2: PartialFunction[Int,Int] = 

pf2.isDefinedAt(1)
// res3: Boolean = true

pf2.isDefinedAt(0)
// res4: Boolean = false

Function compositions

For simplicity, we’ll look only at functions with arity 1 (i.e. `Function1`, which takes a single argument). It’s straightforward to apply the same concept to `FunctionN`.

Methods like `andThen` and `compose` enable compositions of Scala functions. Since both methods are quite similar, I’m going to talk about `andThen` only. Readers who would like to extend to `compose` may try it as a programming exercise.

Method andThen for `Function1[T1, R]` has the following signature:

def andThen[A](g: (R) => A): (T1) => A

A trivial example:

val double: Int => Int = _ * 2
val add1: Int => Int = _ + 1

val doubleThenAdd1 = double andThen add1
// doubleThenAdd1: Int => Int = scala.Function1$Lambda$...

doubleThenAdd1(10)
// res1: Int = 21
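
For comparison, `compose` chains in the opposite direction, applying the right-hand function first; a quick sketch:

val add1ThenDouble = double compose add1

add1ThenDouble(10)
// res2: Int = 22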

Now, let’s replace the 2nd function `add1` with a partial function `inverse`:

val double: Int => Int = _ * 2
val inverse: PartialFunction[Int, Double] = { case x if x != 0 => 1.0 / x }

val doubleThenInverse = double andThen inverse
// doubleThenInverse: Int => Double = scala.Function1$Lambda$...

doubleThenInverse(10)
// res2: Double = 0.05

doubleThenInverse(0)
// scala.MatchError: 0 (of class java.lang.Integer) ...

doubleThenInverse.isDefinedAt(0)
// error: value isDefinedAt is not a member of Int => Double

Note that `doubleThenInverse` is still a total function even though the function being composed (`inverse`) is partial. That’s because PartialFunction is a subclass of Function:

trait PartialFunction[-A, +B] extends (A) => B

hence method `andThen` rightfully returns a total function as advertised.

Unfortunately, that’s undesirable here, as the resulting function loses the domain information of the partial function `inverse`.

Partial function compositions

Method andThen for `PartialFunction[A, B]` has the following signature:

def andThen[C](k: (B) => C): PartialFunction[A, C]

Example:

val inverse: PartialFunction[Int, Double] = { case x if x != 0 => 1.0 / x }
val pfMap: PartialFunction[Double, String] = Map(0.1 -> "a",  0.2 -> "b")

val inverseThenPfMap = inverse andThen pfMap
// inverseThenPfMap: PartialFunction[Int,String] = 

inverseThenPfMap(10)
// res1: String = a

inverseThenPfMap(5)
// res2: String = b

inverseThenPfMap.isDefinedAt(10)
// res3: Boolean = true

inverseThenPfMap.isDefinedAt(5)
// res4: Boolean = true

inverseThenPfMap.isDefinedAt(0)
// res5: Boolean = false

// So far so good ... Now, let's try:

inverseThenPfMap(2)
// java.util.NoSuchElementException: key not found: 0.5

inverseThenPfMap.isDefinedAt(2)
// res6: Boolean = false

That works perfectly: any element not in the domain of one of the partial functions being composed should have its corresponding element(s) eliminated from the domain of the composed function. In this case, 0.5 is not in the domain of `pfMap`, hence its corresponding element 2 (which `inverse` maps to 0.5) should not be in `inverseThenPfMap`’s domain.

Unfortunately, I neglected to mention that I’m on Scala 2.13.x. On Scala 2.12 or below, `inverseThenPfMap.isDefinedAt(2)` would return `true`.

Turning composed functions into a proper partial function

Summarizing what we’ve looked at, there are two issues at hand:

  1. If the first function among the functions being composed is a total function, the composed function is total and discards the domain information of any subsequent partial functions being composed.
  2. If the first function is a partial function, the composed function is partial, but unless you’re on Scala 2.13+ its domain does not incorporate the domain information of any subsequent partial functions being composed.

To tackle the issues, one approach is to leverage implicit conversion by defining a couple of implicit methods that handle composing a partial function on a total function and on a partial function, respectively.

object ComposeFcnOps {
  // Implicit conversion for total function
  implicit class TotalCompose[A, B](f: Function[A, B]) {
    def andThenPF[C](that: PartialFunction[B, C]): PartialFunction[A, C] =
      Function.unlift(x => Option(f(x)).flatMap(that.lift))
  }

  // Implicit conversion for partial function (Not needed on Scala 2.13+)
  implicit class PartialCompose[A, B](pf: PartialFunction[A, B]) {
    def andThenPF[C](that: PartialFunction[B, C]): PartialFunction[A, C] =
      Function.unlift(x => pf.lift(x).flatMap(that.lift))
  }
}

Note that the methods are defined within `implicit class` wrappers, a common practice that lets the implicit conversion kick in simply by invoking the methods as if they were defined on the function types themselves.

In the first implicit class, function `f` (i.e. the total function to be implicitly converted) is first transformed to return an `Option`, chained using `flatMap` to the lifted partial function (i.e. the partial function to be composed), followed by an `unlift` to return a partial function.

Similarly, in the second implicit class, function `pf` (i.e. the partial function to be implicitly converted) is first lifted, chained to the lifted partial function (i.e. the partial function to be composed), followed by an `unlift`.

In both cases, `andThenPF` returns a partial function that incorporates the partial domains of the functions involved in the function composition.

Let’s reuse the `double` and `inverse` functions from a previous example:

val double: Int => Int = _ * 2
val inverse: PartialFunction[Int, Double] = { case x if x != 0 => 1.0 / x }

val doubleThenInverse = double andThen inverse
// doubleThenInverse: Int => Double = scala.Function1$Lambda$...

Recall from that example that `doubleThenInverse` is a total function. Now, let’s replace `andThen` with our custom `andThenPF`:

import ComposeFcnOps._

val doubleThenPFInverse: PartialFunction[Int, Double] = double andThenPF inverse
// doubleThenPFInverse: PartialFunction[Int,Double] = 

doubleThenPFInverse(10)
// res1: Double = 0.05

doubleThenPFInverse(0)
// scala.MatchError: 0 (of class java.lang.Integer) ...

doubleThenPFInverse.isDefinedAt(10)
// res2: Boolean = true

doubleThenPFInverse.isDefinedAt(0)
// res3: Boolean = false

The resulting function is now a partial function whose domain reflects the partial domain of the function being composed. I’ll leave fuller testing of the cases in which the first function is itself a partial function to the readers.
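
As a starting point, here’s an untested sketch reusing `inverse` and `pfMap` from the earlier example (assuming the more specific `PartialCompose` conversion is the one picked up for the partial function `inverse`):

import ComposeFcnOps._

val inverseThenPFMap = inverse andThenPF pfMap
// expected type: PartialFunction[Int,String]

inverseThenPFMap.isDefinedAt(10)
// expected: true  (1.0/10 = 0.1 is a key of pfMap)

inverseThenPFMap.isDefinedAt(2)
// expected: false (1.0/2 = 0.5 is not a key of pfMap), regardless of Scala version

inverseThenPFMap.isDefinedAt(0)
// expected: false (inverse itself is undefined at 0)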

Scala’s groupMap And groupMapReduce

For grouping elements in a Scala collection by a provided key, the de facto method of choice has been groupBy, which has the following signature for an `Iterable`:

// Method groupBy
def groupBy[K](f: (A) => K): immutable.Map[K, Iterable[A]]

It returns an immutable Map in which each entry consists of a key and a collection of values of the original element type. To process these value collections in the resulting Map, Scala provides a method mapValues with the following signature:

// Method mapValues
def mapValues[W](f: (V) => W): Map[K, W]

This `groupBy/mapValues` combo proves handy for processing the values of the Map generated by the grouping. However, as of Scala 2.13, `mapValues` on a strict `Map` is deprecated; it now lives on `MapView` and is reached via `.view.mapValues`.
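
For instance, a minimal sketch of the combo for counting occurrences (the sample data is mine, not from the examples below):

// Pre-2.13 style (deprecated in 2.13, where it returns a lazy MapView)
List("a", "b", "a").groupBy(identity).mapValues(_.size)
// Map("a" -> 2, "b" -> 1)

// Scala 2.13 non-deprecated equivalent
List("a", "b", "a").groupBy(identity).view.mapValues(_.size).toMap
// Map("a" -> 2, "b" -> 1)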

groupMap

Scala 2.13 introduces a new method, groupMap, for grouping a collection based on provided functions that define the keys and values of the resulting Map. Here’s the signature of method groupMap for an `Iterable`:

// Method groupMap
def groupMap[K, B](key: (A) => K)(f: (A) => B): immutable.Map[K, Iterable[B]]

Let’s start with a simple example grouping via the good old `groupBy` method:

// Example 1: groupBy
val fruits = List("apple", "apple", "orange", "pear", "pear", "pear")

fruits.groupBy(identity)
// res1: Map[String, List[String]] = Map(
//   "orange" -> List("orange"),
//   "apple" -> List("apple", "apple"),
//   "pear" -> List("pear", "pear", "pear")
// )

We can replace `groupBy` with `groupMap` like below:

// Example 1: groupMap
fruits.groupMap(identity)(identity)

In this particular case, the new method doesn’t offer any benefit over the old one.

Let’s look at another example that involves a collection of class objects:

// Example 2
case class Pet(species: String, name: String, age: Int)

val pets = List(
  Pet("cat", "sassy", 2), Pet("cat", "bella", 3), 
  Pet("dog", "poppy", 3), Pet("dog", "bodie", 4), Pet("dog", "poppy", 2), 
  Pet("bird", "coco", 2), Pet("bird", "kiwi", 1)
)

If we want to list all pet names per species, a `groupBy` coupled with `mapValues` will do:

// Example 2: groupBy
pets.groupBy(_.species).mapValues(_.map(_.name))
// res2: Map[String, List[String]] = Map(
//   "cat" -> List("sassy", "bella"),
//   "bird" -> List("coco", "kiwi"),
//   "dog" -> List("poppy", "bodie", "poppy")
// )

But in this case, `groupMap` does the same with better readability, since the functions defining the keys and values of the resulting Map sit nicely side by side as parameters:

// Example 2: groupMap
pets.groupMap(_.species)(_.name)

groupMapReduce

At times, we need to perform a reduction on the Map values after grouping a collection. This is when the other new method, groupMapReduce, comes in handy:

// Method groupMapReduce
def groupMapReduce[K, B](key: (A) => K)(f: (A) => B)(reduce: (B, B) => B): immutable.Map[K, B]

In addition to the parameters that define the keys and values of the resulting Map, as in `groupMap`, `groupMapReduce` expects a third parameter: a binary operation for the reduction.

Using the same pets example, if we want to compute the count of pets per species, a `groupBy/mapValues` approach will look like below:

// Example 3: groupBy/mapValues
pets.groupBy(_.species).mapValues(_.size)
// res1: Map[String, Int] = Map("cat" -> 2, "bird" -> 2, "dog" -> 3)

With `groupMapReduce`, we can “compartmentalize” the functions for the keys, values and reduction operation separately as follows:

// Example 3: groupMapReduce
pets.groupMapReduce(_.species)(_ => 1)(_ + _)
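
For reference, this yields the same count-per-species result as the `groupBy/mapValues` version above:

// Map("cat" -> 2, "bird" -> 2, "dog" -> 3)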

One more example:

// Example 4
import java.time.LocalDate
case class Product(id: String, saleDate: LocalDate, listPrice: Double, discPrice: Double)

val products = List(
  Product("p001", LocalDate.of(2019, 9, 11), 10, 8.5),
  Product("p002", LocalDate.of(2019, 9, 18), 12, 10),
  Product("p003", LocalDate.of(2019, 9, 27), 10, 9),
  Product("p004", LocalDate.of(2019, 10, 6), 15, 12.5),
  Product("p005", LocalDate.of(2019, 10, 20), 12, 8),
  Product("p006", LocalDate.of(2019, 11, 8), 15, 12),
  Product("p007", LocalDate.of(2019, 11, 16), 10, 8.5),
  Product("p008", LocalDate.of(2019, 11, 25), 10, 9)
)

Let’s say we want to compute the monthly totals of the list prices and discounted prices across the product list. In the `groupBy/mapValues` way:

// Example 4: groupBy/mapValues
products.groupBy(_.saleDate.getMonth).mapValues(
  _.map(p => (p.listPrice, p.discPrice)).reduce(
    (total, prc) => (total._1 + prc._1, total._2 + prc._2))
)
// res2: scala.collection.immutable.Map[java.time.Month,(Double, Double)] =
//   Map(OCTOBER -> (27.0,20.5), SEPTEMBER -> (32.0,27.5), NOVEMBER -> (35.0,29.5))

Using `groupMapReduce`:

// Example 4: groupMapReduce
products.groupMapReduce(_.saleDate.getMonth)(p => (p.listPrice, p.discPrice))(
  (total, prc) => (total._1 + prc._1, total._2 + prc._2)
)

Spark – Schema With Nested Columns

Extracting columns based on certain criteria from a DataFrame (or Dataset) with a flat schema of only top-level columns is simple. It gets slightly less trivial, though, if the schema consists of hierarchical nested columns.

Recursive traversal

In functional programming, a common tactic for traversing arbitrarily nested collections of elements is recursion. It’s generally preferred over while-loops with mutable counters. For performance at scale, making the traversal tail-recursive may be necessary – although it’s less of a concern in this case, given that a DataFrame schema typically consists of no more than a few hundred columns and a few levels of nesting.

We’re going to illustrate in a couple of simple examples how recursion can be used to effectively process a DataFrame with a schema of nested columns.

Example #1: Get all nested columns of a given data type

Consider the following snippet:

import org.apache.spark.sql.types._

def getColsByType(schema: StructType, dType: DataType) = {
  def recurPull(sType: StructType, prefix: String): Seq[String] = sType.fields.flatMap {
    // Nested struct: recurse into its child fields
    case StructField(name, dt: StructType, _, _) ⇒
      recurPull(dt, s"$prefix$name.")
    // Array of structs: recurse into the element struct's child fields
    case StructField(name, dt: ArrayType, _, _) if dt.elementType.isInstanceOf[StructType] ⇒
      recurPull(dt.elementType.asInstanceOf[StructType], s"$prefix$name[].")
    // Leaf column of the requested data type: emit its fully qualified name
    case field @ StructField(name, _, _, _) if field.dataType == dType ⇒
      Seq(s"$prefix$name")
    // Any other column: skip
    case _ ⇒
      Seq.empty[String]
  }
  recurPull(schema, "")
}

By means of a simple recursive method, the data type of each column in the DataFrame is inspected and, in the case of a `StructType`, the method recurses to traverse its child columns. A string `prefix` is assembled during the traversal to express the hierarchy of the individual nested columns and gets prepended to the names of columns with the matching data type.

Testing the method:

import org.apache.spark.sql.functions._
import spark.implicits._

case class Spec(sid: Int, desc: String)
case class Prod(pid: String, spec: Spec)

val df = Seq(
  (101, "jenn", Seq(1, 2), Seq(Spec(1, "A"), Spec(2, "B")), Prod("X11", Spec(11, "X")), 1100.0),
  (202, "mike", Seq(3), Seq(Spec(3, "C")), Prod("Y22", Spec(22, "Y")), 2200.0)
).toDF("uid", "user", "ids", "specs", "product", "amount")

getColsByType(df.schema, DoubleType)
// res1: Seq[String] = ArraySeq(amount)

getColsByType(df.schema, IntegerType)
// res2: Seq[String] = ArraySeq(uid, specs[].sid, product.spec.sid)

getColsByType(df.schema, StringType)
// res3: Seq[String] = ArraySeq(user, specs[].desc, product.pid, product.spec.desc)

// getColsByType(df.schema, ArrayType)
// ERROR: type mismatch;
//  found   : org.apache.spark.sql.types.ArrayType.type
//  required: org.apache.spark.sql.types.DataType

getColsByType(df.schema, ArrayType(IntegerType, false))
// res4: Seq[String] = ArraySeq(ids)
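
The returned names can be fed back into a `select`, bearing in mind that the `[]` markers are informational only and need stripping first. A sketch, assuming the same `df` as above:

val stringCols = getColsByType(df.schema, StringType).map(_.replace("[]", ""))

df.select(stringCols.map(col): _*).show(false)
// pulls user, the desc values inside specs (as an array), product.pid and product.spec.desc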

Example #2: Rename all nested columns via a provided function

In this example, we’re going to rename columns in a DataFrame with a nested schema based on a provided `rename` function. The required logic for recursively traversing the nested columns is pretty much the same as in the previous example.

import org.apache.spark.sql.types._

def renameAllCols(schema: StructType, rename: String ⇒ String): StructType = {
  def recurRename(sType: StructType): Seq[StructField] = sType.fields.map{
    // Nested struct: rename the field and recurse into its children
    case StructField(name, dt: StructType, nu, meta) ⇒
      StructField(rename(name), StructType(recurRename(dt)), nu, meta)
    // Array of structs: rename the field and recurse into the element struct
    case StructField(name, dt: ArrayType, nu, meta) if dt.elementType.isInstanceOf[StructType] ⇒
      StructField(rename(name), ArrayType(StructType(recurRename(dt.elementType.asInstanceOf[StructType])), true), nu, meta)
    // Any other column: just rename it
    case StructField(name, dt, nu, meta) ⇒
      StructField(rename(name), dt, nu, meta)
  }
  StructType(recurRename(schema))
}

Testing the method (with the same DataFrame used in the previous example):

import org.apache.spark.sql.functions._
import spark.implicits._

val renameFcn = (s: String) ⇒
  if (s.endsWith("id")) s.replace("id", "_id") else s

val newDF = spark.createDataFrame(df.rdd, renameAllCols(df.schema, renameFcn))

newDF.printSchema
// root
//  |-- u_id: integer (nullable = false)
//  |-- user: string (nullable = true)
//  |-- ids: array (nullable = true)
//  |    |-- element: integer (containsNull = false)
//  |-- specs: array (nullable = true)
//  |    |-- element: struct (containsNull = true)
//  |    |    |-- s_id: integer (nullable = false)
//  |    |    |-- desc: string (nullable = true)
//  |-- product: struct (nullable = true)
//  |    |-- p_id: string (nullable = true)
//  |    |-- spec: struct (nullable = true)
//  |    |    |-- s_id: integer (nullable = false)
//  |    |    |-- desc: string (nullable = true)
//  |-- amount: double (nullable = false)

In case it isn’t obvious, in traversing a given `StructType`’s child columns, we use `map` (as opposed to `flatMap` in the previous example) to preserve the hierarchical column structure.
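
Since `renameAllCols` takes an arbitrary `String => String`, other renaming schemes drop in without touching the traversal logic. For example, a sketch (assuming the same `spark` session and `df`) that prefixes every column name at every nesting level:

val prefixAll = (s: String) ⇒ "col_" + s

val prefixedDF = spark.createDataFrame(df.rdd, renameAllCols(df.schema, prefixAll))

prefixedDF.printSchema
// expected top-level columns: col_uid, col_user, col_ids, col_specs, col_product, col_amount
// nested fields (e.g. sid, desc, pid, spec) get the col_ prefix as well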