Monthly Archives: February 2021

Spark Higher-order Functions

Apache Spark’s DataFrame API provides comprehensive functions for transforming or aggregating data in a row-wise fashion. As in many popular relational database systems such as PostgreSQL, these functions are internally optimized to process large numbers of rows efficiently. Better yet, Spark runs on distributed platforms and, if configured to fully utilize available processing cores and memory, can handle data at very large scale.

That’s all great, but what about transforming or aggregating data of the same type column-wise? Starting with Spark 2.4, a number of methods for ArrayType (and MapType) columns have been added. But users can still feel hamstrung when none of the available methods can handle something as simple as, say, summing the integer elements of an array.

User-provided lambda functions

A higher-order function processes a collection of elements (of the same data type) according to a user-provided lambda function that specifies how the collection content should be transformed or aggregated. Because the lambda function is part of the function signature, relatively complex processing logic can be applied to the collection.

Coupled with the array method, higher-order functions are particularly useful when transformation or aggregation across a list of columns (of the same data type) is needed. Below are a few such functions:

  • filter()
  • exists()
  • transform()
  • aggregate()

The lambda function could either be a unary or binary operator. As will be shown in examples below, function aggregate() requires a binary operator whereas the other functions expect a unary operator.
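To make the unary/binary distinction concrete, here is a minimal sketch in plain Scala, with Seq standing in for an ArrayType column. The names filterSeq, existsSeq, transformSeq and aggregateSeq are hypothetical stand-ins, not Spark functions; they only mirror the shape of the lambda each function expects.

```scala
// Hypothetical plain-Scala stand-ins; Seq plays the role of an ArrayType column.
def filterSeq[T](xs: Seq[T], f: T => Boolean): Seq[T] = xs.filter(f)    // unary lambda
def existsSeq[T](xs: Seq[T], f: T => Boolean): Boolean = xs.exists(f)   // unary lambda
def transformSeq[T, S](xs: Seq[T], f: T => S): Seq[S] = xs.map(f)       // unary lambda
def aggregateSeq[T, S](xs: Seq[T], init: S, f: (S, T) => S): S =        // binary lambda
  xs.foldLeft(init)(f)

val fruits = Seq("raspberry", "cherry", "blueberry")

// Lambda parameter types are spelled out so the calls type-check on Scala 2 as well.
val berries   = filterSeq(fruits, (x: String) => x.endsWith("berry"))
val hasCherry = existsSeq(fruits, (x: String) => x == "cherry")
val lengths   = transformSeq(fruits, (x: String) => x.length)
val totalLen  = aggregateSeq(fruits, 0, (acc: Int, x: String) => acc + x.length)
```

Note that the binary lambda takes the accumulator first and the current element second, and that the aggregation returns a single value rather than a collection.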

A caveat

Higher-order functions were added to the built-in DataFrame API in Spark 3.0. On Spark 2.4, they can only be expressed in Spark SQL syntax along with a lambda function, passed in as a String via expr(). Hence, to use these functions, one needs to temporarily “exit” the Scala world and assemble proper SQL expressions in the SQL arena. The examples below use the expr() approach, which works on both versions.
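As a rough sketch of what that looks like in practice, here is the array-summing case mentioned earlier, written both ways. The nums DataFrame and the local SparkSession are made up for illustration, and the second form assumes Spark 3.0 or later, where aggregate is available in org.apache.spark.sql.functions.

```scala
import org.apache.spark.sql.SparkSession
// Importing `aggregate` assumes Spark 3.0+; on 2.4 only the expr() form is available.
import org.apache.spark.sql.functions.{aggregate, col, expr, lit}

// Local session for illustration only; spark-shell already provides `spark`.
val spark = SparkSession.builder().master("local[*]").appName("hof-sketch").getOrCreate()
import spark.implicits._

// Hypothetical single-column DataFrame of integer arrays
val nums = Seq(Seq(1, 2, 3), Seq(10, 20)).toDF("nums")

// SQL expression with a lambda, passed in as a String
val viaExpr = nums.withColumn("sum", expr("aggregate(nums, 0, (acc, x) -> acc + x)"))

// Same computation with a Scala lambda over Column values
val viaApi = nums.withColumn("sum", aggregate(col("nums"), lit(0), (acc, x) => acc + x))
```

In the String form the lambda uses the SQL arrow `->`, while the Scala form uses `=>` and operates on Column objects rather than on the element values themselves.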

Let’s create a simple DataFrame for illustrating how these higher-order functions work.

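// Needed outside spark-shell, which pre-imports both (assumes a SparkSession named `spark`)
import org.apache.spark.sql.functions._
import spark.implicits._

case class Order(price: Double, qty: Int)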

val df = Seq(
  (101, 10, Order(1.2, 5), Order(1.0, 3), Order(1.5, 4), Seq("strawberry", "currant")),
  (102, 15, Order(1.5, 6), Order(0.8, 5), Order(1.0, 7), Seq("raspberry", "cherry", "blueberry"))
).toDF("id", "discount", "order1", "order2", "order3", "fruits")

df.show(false)
// +---+--------+--------+--------+--------+------------------------------+
// |id |discount|order1  |order2  |order3  |fruits                        |
// +---+--------+--------+--------+--------+------------------------------+
// |101|10      |[1.2, 5]|[1.0, 3]|[1.5, 4]|[strawberry, currant]         |
// |102|15      |[1.5, 6]|[0.8, 5]|[1.0, 7]|[raspberry, cherry, blueberry]|
// +---+--------+--------+--------+--------+------------------------------+

Function filter()

Here’s an “unofficial” method signature of filter():

// Scala-style signature of `filter()`
def filter[T](arrayCol: ArrayType[T], fcn: T => Boolean): ArrayType[T]

The following snippet uses filter to extract any fruit item that ends with “berry”.

df.
  withColumn("berries", expr("filter(fruits, x -> x rlike 'berry$')")).
  select("id", "fruits", "berries").
  show(false)
// +---+------------------------------+----------------------+
// |id |fruits                        |berries               |
// +---+------------------------------+----------------------+
// |101|[strawberry, currant]         |[strawberry]          |
// |102|[raspberry, cherry, blueberry]|[raspberry, blueberry]|
// +---+------------------------------+----------------------+

Function transform()

Method signature (unofficial) of transform():

// Scala-style signature of `transform()`
def transform[T, S](arrayCol: ArrayType[T], fcn: T => S): ArrayType[S]

Here’s an example of using transform() to flag any fruit not ending with “berry” with an ‘*’.

df.withColumn(
    "non-berries",
    expr("transform(fruits, x -> case when x rlike 'berry$' then x else concat(x, '*') end)")
  ).
  select("id", "fruits", "non-berries").
  show(false)
// +---+------------------------------+-------------------------------+
// |id |fruits                        |non-berries                    |
// +---+------------------------------+-------------------------------+
// |101|[strawberry, currant]         |[strawberry, currant*]         |
// |102|[raspberry, cherry, blueberry]|[raspberry, cherry*, blueberry]|
// +---+------------------------------+-------------------------------+

So far, we’ve seen how higher-order functions transform data in an ArrayType collection. In the following examples, we’ll apply higher-order functions to individual columns (of the same data type) by first combining the selected columns into a single ArrayType column.

Let’s assemble an array of the individual columns we would like to process across:

val orderCols = df.columns.filter{
  c => "^order\\d+$".r.findFirstIn(c).nonEmpty
}
// orderCols: Array[String] = Array(order1, order2, order3)

Function exists()

Method signature (unofficial) of exists():

// Scala-style signature of `exists()`
def exists[T](arrayCol: ArrayType[T], fcn: T => Boolean): Boolean

Here’s an example using exists() to check whether any of the individual orders in a row has an item price below $1.

df.
  withColumn("orders", array(orderCols.map(col): _*)).
  withColumn("sub$-prices", expr("exists(orders, x -> x.price < 1)")).
  select("id", "orders", "sub$-prices").
  show(false)
// +---+------------------------------+-----------+
// |id |orders                        |sub$-prices|
// +---+------------------------------+-----------+
// |101|[[1.2, 5], [1.0, 3], [1.5, 4]]|false      |
// |102|[[1.5, 6], [0.8, 5], [1.0, 7]]|true       |
// +---+------------------------------+-----------+

Function aggregate()

Method signature (unofficial) of aggregate():

// Scala-style signature of `aggregate()`
def aggregate[T, S](arrayCol: ArrayType[T], init: S, fcn: (S, T) => S): S

The example below shows how to compute the discounted total of all the orders per row using aggregate().

df.
  withColumn("orders", array(orderCols.map(col): _*)).
  withColumn("total", expr("aggregate(orders, 0d, (acc, x) -> acc + x.price * x.qty)")).
  withColumn("discounted", $"total" * (lit(1.0) - $"discount"/100.0)).
  select("id", "discount", "orders", "total", "discounted").
  show(false)
// +---+--------+------------------------------+-----+----------+
// |id |discount|orders                        |total|discounted|
// +---+--------+------------------------------+-----+----------+
// |101|10      |[[1.2, 5], [1.0, 3], [1.5, 4]]|15.0 |13.5      |
// |102|15      |[[1.5, 6], [0.8, 5], [1.0, 7]]|20.0 |17.0      |
// +---+--------+------------------------------+-----+----------+
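
For readers on Spark 3.0 or later, the exists() and aggregate() examples can also be written without expr(). The following is a sketch under that assumption, not a drop-in replacement: the local SparkSession and the column name hasSubDollar are chosen for illustration, and struct fields are read inside the lambda with getField.

```scala
import org.apache.spark.sql.SparkSession
// `aggregate` and `exists` are only in functions as of Spark 3.0.
import org.apache.spark.sql.functions.{aggregate, array, col, exists, lit}

// Local session for illustration only; spark-shell already provides `spark`.
val spark = SparkSession.builder().master("local[*]").appName("hof-column-api").getOrCreate()
import spark.implicits._

case class Order(price: Double, qty: Int)

// Same order data as above, minus the fruits column
val df = Seq(
  (101, 10, Order(1.2, 5), Order(1.0, 3), Order(1.5, 4)),
  (102, 15, Order(1.5, 6), Order(0.8, 5), Order(1.0, 7))
).toDF("id", "discount", "order1", "order2", "order3")

val orderCols = df.columns.filter(_.matches("order\\d+"))

val result = df.
  withColumn("orders", array(orderCols.map(col): _*)).
  // Unary lambda: x is a Column standing for one struct element
  withColumn("hasSubDollar", exists(col("orders"), x => x.getField("price") < 1)).
  // Binary lambda: acc is the running total, x the current struct element
  withColumn("total",
    aggregate(col("orders"), lit(0.0), (acc, x) => acc + x.getField("price") * x.getField("qty"))).
  withColumn("discounted", col("total") * (lit(1.0) - col("discount") / 100.0)).
  select("id", "hasSubDollar", "total", "discounted")
```

Calling result.show(false) should give the same flags and totals as the expr() versions, since both forms describe the same lambda expressions.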