
Spark Higher-order Functions

Apache Spark’s DataFrame API provides comprehensive functions for transforming or aggregating data in a row-wise fashion. Like many popular relational database systems such as PostgreSQL, these functions are internally optimized to efficiently process large numbers of rows. Better yet, Spark runs on distributed platforms, and if configured to fully utilize the available processing cores and memory, it can handle data at a really large scale.

That’s all great, but what about transforming or aggregating data of the same type column-wise? Starting from Spark 2.4, a number of methods for ArrayType (and MapType) columns have been added. Even so, users’ hands are still tied when none of the available methods can handle something as simple as, say, summing the integer elements of an array.
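
As a quick preview of where this is heading, below is a minimal sketch of summing the integer elements of an array with the aggregate() higher-order function covered later in this post (assuming a spark-shell session and a hypothetical integer-array column `nums`; the expr() mechanics are explained in the caveat section):

import org.apache.spark.sql.functions._
import spark.implicits._

// Sum the integer elements of each array in the hypothetical `nums` column
Seq(Seq(1, 2, 3), Seq(4, 5)).toDF("nums").
  withColumn("sum", expr("aggregate(nums, 0, (acc, x) -> acc + x)")).
  show()
// yields sums of 6 and 9, respectively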

User-provided lambda functions

A higher-order function processes a collection of elements (of the same data type) according to a user-provided lambda function that specifies how the collection’s content should be transformed or aggregated. Making the lambda function part of the function signature allows the collection of elements to be processed with relatively complex logic.

Coupled with the use of the array method, higher-order functions are particularly useful when a transformation or aggregation across a list of columns (of the same data type) is needed. Below are a few such functions:

  • filter()
  • exists()
  • transform()
  • aggregate()

The lambda function can be either a unary or a binary operator. As shown in the examples below, aggregate() requires a binary operator whereas the other functions expect a unary operator.

A caveat

Higher-order functions aren’t part of the built-in DataFrame API until Spark 3.x. In Spark 2.4, they are expressed in standard SQL syntax along with a lambda function and need to be passed in as a String via expr(). Hence, to use these functions, one needs to temporarily “exit” the Scala world and assemble proper SQL expressions in the SQL arena.

Let’s create a simple DataFrame for illustrating how these higher-order functions work.

case class Order(price: Double, qty: Int)

val df = Seq(
  (101, 10, Order(1.2, 5), Order(1.0, 3), Order(1.5, 4), Seq("strawberry", "currant")),
  (102, 15, Order(1.5, 6), Order(0.8, 5), Order(1.0, 7), Seq("raspberry", "cherry", "blueberry"))
).toDF("id", "discount", "order1", "order2", "order3", "fruits")

df.show(false)
// +---+--------+--------+--------+--------+------------------------------+
// |id |discount|order1  |order2  |order3  |fruits                        |
// +---+--------+--------+--------+--------+------------------------------+
// |101|10      |[1.2, 5]|[1.0, 3]|[1.5, 4]|[strawberry, currant]         |
// |102|15      |[1.5, 6]|[0.8, 5]|[1.0, 7]|[raspberry, cherry, blueberry]|
// +---+--------+--------+--------+--------+------------------------------+

Function filter()

Here’s an “unofficial” method signature of filter():

// Scala-style signature of `filter()`
def filter[T](arrayCol: ArrayType[T], fcn: T => Boolean): ArrayType[T]

The following snippet uses filter to extract any fruit item that ends with “berry”.

df.
  withColumn("berries", expr("filter(fruits, x -> x rlike '.*berry')")).
  select("id", "fruits", "berries").
  show(false)
// +---+------------------------------+----------------------+
// |id |fruits                        |berries               |
// +---+------------------------------+----------------------+
// |101|[strawberry, currant]         |[strawberry]          |
// |102|[raspberry, cherry, blueberry]|[raspberry, blueberry]|
// +---+------------------------------+----------------------+
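
For reference, if you’re on Spark 3.x, the same filtering can be expressed natively in the DataFrame API with a Scala lambda over Column, without going through expr(). Here’s a sketch (assuming Spark 3.0+) that should produce the same `berries` column as above:

// Spark 3.x only: native higher-order filter() taking a Column => Column lambda
import org.apache.spark.sql.functions.filter

df.
  withColumn("berries", filter($"fruits", x => x.rlike(".*berry"))).
  select("id", "fruits", "berries").
  show(false)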

Function transform()

Method signature (unofficial) of transform():

// Scala-style signature of `transform()`
def transform[T, S](arrayCol: ArrayType[T], fcn: T => S): ArrayType[S]

Here’s an example of using transform() to flag any fruit not ending with “berry” with an ‘*’.

df.withColumn(
    "non-berries",
    expr("transform(fruits, x -> case when x rlike '.*berry' then x else concat(x, '*') end)")
  ).
  select("id", "fruits", "non-berries").
  show(false)
// +---+------------------------------+-------------------------------+
// |id |fruits                        |non-berries                    |
// +---+------------------------------+-------------------------------+
// |101|[strawberry, currant]         |[strawberry, currant*]         |
// |102|[raspberry, cherry, blueberry]|[raspberry, cherry*, blueberry]|
// +---+------------------------------+-------------------------------+

So far, we’ve seen how higher-order functions transform data in an ArrayType collection. For the following examples, we’ll illustrate applying the higher-order functions to individual columns (of the same data type) by first turning selected columns into a single ArrayType column.

Let’s first gather the names of the individual columns we would like to process across:

val orderCols = df.columns.filter{
  c => "^order\\d+$".r.findFirstIn(c).nonEmpty
}
// orderCols: Array[String] = Array(order1, order2, order3)

Function exists()

Method signature (unofficial) of exists():

// Scala-style signature of `exists()`
def exists[T](arrayCol: ArrayType[T], fcn: T => Boolean): Boolean

Here’s an example using exists() to check whether any of the individual orders in a row has an item price below $1.

df.
  withColumn("orders", array(orderCols.map(col): _*)).
  withColumn("sub$-prices", expr("exists(orders, x -> x.price < 1)")).
  select("id", "orders", "sub$-prices").
  show(false)
// +---+------------------------------+-----------+
// |id |orders                        |sub$-prices|
// +---+------------------------------+-----------+
// |101|[[1.2, 5], [1.0, 3], [1.5, 4]]|false      |
// |102|[[1.5, 6], [0.8, 5], [1.0, 7]]|true       |
// +---+------------------------------+-----------+

Function aggregate()

Method signature (unofficial) of aggregate():

// Scala-style signature of `aggregate()`
def aggregate[T, S](arrayCol: ArrayType[T], init: S, fcn: (S, T) => S): S

The example below shows how to compute the discounted total of all the orders per row using aggregate().

df.
  withColumn("orders", array(orderCols.map(col): _*)).
  withColumn("total", expr("aggregate(orders, 0d, (acc, x) -> acc + x.price * x.qty)")).
  withColumn("discounted", $"total" * (lit(1.0) - $"discount"/100.0)).
  select("id", "discount", "orders", "total", "discounted").
  show(false)
// +---+--------+------------------------------+-----+----------+
// |id |discount|orders                        |total|discounted|
// +---+--------+------------------------------+-----+----------+
// |101|10      |[[1.2, 5], [1.0, 3], [1.5, 4]]|15.0 |13.5      |
// |102|15      |[[1.5, 6], [0.8, 5], [1.0, 7]]|20.0 |17.0      |
// +---+--------+------------------------------+-----+----------+
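
Incidentally, the SQL aggregate() also accepts an optional finish lambda that is applied to the final accumulator. Here’s a sketch that folds the discount directly into the aggregation; it assumes outer columns such as discount can be referenced inside the lambda (which Spark’s SQL higher-order functions generally permit):

df.
  withColumn("orders", array(orderCols.map(col): _*)).
  withColumn("discounted",
    expr("aggregate(orders, 0d, (acc, x) -> acc + x.price * x.qty, acc -> acc * (1 - discount/100))")
  ).
  select("id", "discount", "orders", "discounted").
  show(false)
// should yield 13.5 and 17.0, matching the `discounted` column computed above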

Spark – Schema With Nested Columns

Extracting columns based on certain criteria from a DataFrame (or Dataset) with a flat schema of only top-level columns is simple. It gets slightly less trivial, though, if the schema consists of hierarchically nested columns.

Recursive traversal

In functional programming, a common tactic for traversing arbitrarily nested collections of elements is recursion. It’s generally preferred over while-loops with mutable counters. For performance at scale, making the traversal tail-recursive may be necessary, although that’s less of a concern in this case, given that a DataFrame typically consists of no more than a few hundred columns and a few levels of nesting.

We’re going to illustrate in a couple of simple examples how recursion can be used to effectively process a DataFrame with a schema of nested columns.

Example #1:  Get all nested columns of a given data type

Consider the following snippet:

import org.apache.spark.sql.types._

def getColsByType(schema: StructType, dType: DataType) = {
  def recurPull(sType: StructType, prefix: String): Seq[String] = sType.fields.flatMap {
    case StructField(name, dt: StructType, _, _) ⇒
      recurPull(dt, s"$prefix$name.")
    case StructField(name, dt: ArrayType, _, _) if dt.elementType.isInstanceOf[StructType] ⇒
      recurPull(dt.elementType.asInstanceOf[StructType], s"$prefix$name[].")
    case field @ StructField(name, _, _, _) if field.dataType == dType ⇒
      Seq(s"$prefix$name")
    case _ ⇒
      Seq.empty[String]
  }
  recurPull(schema, "")
}

By means of a simple recursive method, the data type of each column in the DataFrame is examined and, in the case of a `StructType` (or an array of `StructType`), the method recurses into its child columns. A string `prefix` is assembled during the traversal to express the hierarchy of the individual nested columns, and it gets prepended to the names of columns whose data type matches the requested one.

Testing the method:

import org.apache.spark.sql.functions._
import spark.implicits._

case class Spec(sid: Int, desc: String)
case class Prod(pid: String, spec: Spec)

val df = Seq(
  (101, "jenn", Seq(1, 2), Seq(Spec(1, "A"), Spec(2, "B")), Prod("X11", Spec(11, "X")), 1100.0),
  (202, "mike", Seq(3), Seq(Spec(3, "C")), Prod("Y22", Spec(22, "Y")), 2200.0)
).toDF("uid", "user", "ids", "specs", "product", "amount")

getColsByType(df.schema, DoubleType)
// res1: Seq[String] = ArraySeq(amount)

getColsByType(df.schema, IntegerType)
// res2: Seq[String] = ArraySeq(uid, specs[].sid, product.spec.sid)

getColsByType(df.schema, StringType)
// res3: Seq[String] = ArraySeq(user, specs[].desc, product.pid, product.spec.desc)

// getColsByType(df.schema, ArrayType)
// ERROR: type mismatch;
//  found   : org.apache.spark.sql.types.ArrayType.type
//  required: org.apache.spark.sql.types.DataType

getColsByType(df.schema, ArrayType(IntegerType, false))
// res4: Seq[String] = ArraySeq(ids)
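
As a hypothetical usage example, the dotted paths returned by getColsByType (minus the `[]`-marked array paths, which aren’t directly selectable this way) can be fed straight into select:

val stringCols = getColsByType(df.schema, StringType).filterNot(_.contains("[]"))
// stringCols: Seq[String] = ArraySeq(user, product.pid, product.spec.desc)

df.select(stringCols.map(col): _*).show(false)
// resulting columns: user, pid, desc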

Example #2:  Rename all nested columns via a provided function

In this example, we’re going to rename columns in a DataFrame with a nested schema based on a provided `rename` function. The required logic for recursively traversing the nested columns is pretty much the same as in the previous example.

import org.apache.spark.sql.types._

def renameAllCols(schema: StructType, rename: String ⇒ String): StructType = {
  def recurRename(sType: StructType): Seq[StructField] = sType.fields.map{
    case StructField(name, dt: StructType, nu, meta) ⇒
      StructField(rename(name), StructType(recurRename(dt)), nu, meta)
    case StructField(name, dt: ArrayType, nu, meta) if dt.elementType.isInstanceOf[StructType] ⇒
      StructField(rename(name), ArrayType(StructType(recurRename(dt.elementType.asInstanceOf[StructType])), true), nu, meta)
    case StructField(name, dt, nu, meta) ⇒
      StructField(rename(name), dt, nu, meta)
  }
  StructType(recurRename(schema))
}

Testing the method (with the same DataFrame used in the previous example):

import org.apache.spark.sql.functions._
import spark.implicits._

val renameFcn = (s: String) ⇒
  if (s.endsWith("id")) s.replace("id", "_id") else s

val newDF = spark.createDataFrame(df.rdd, renameAllCols(df.schema, renameFcn))

newDF.printSchema
// root
//  |-- u_id: integer (nullable = false)
//  |-- user: string (nullable = true)
//  |-- ids: array (nullable = true)
//  |    |-- element: integer (containsNull = false)
//  |-- specs: array (nullable = true)
//  |    |-- element: struct (containsNull = true)
//  |    |    |-- s_id: integer (nullable = false)
//  |    |    |-- desc: string (nullable = true)
//  |-- product: struct (nullable = true)
//  |    |-- p_id: string (nullable = true)
//  |    |-- spec: struct (nullable = true)
//  |    |    |-- s_id: integer (nullable = false)
//  |    |    |-- desc: string (nullable = true)
//  |-- amount: double (nullable = false)

In case it isn’t obvious, in traversing a given StructType‘s child columns, we use map (as opposed to flatMap in the previous example) to preserve the hierarchical column structure.
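
As a hypothetical variation, any String ⇒ String function can be plugged in, e.g. upper-casing every nested column name:

val upperFcn = (s: String) ⇒ s.toUpperCase

val upperDF = spark.createDataFrame(df.rdd, renameAllCols(df.schema, upperFcn))
// upperDF.printSchema would show the same structure with UID, USER, IDS, SPECS (SID, DESC),
// PRODUCT (PID, SPEC.SID, SPEC.DESC) and AMOUNT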

Spark – Custom Timeout Sessions

In the previous blog post, we saw how one could partition a time series log of web activities into web page-based sessions. Operating on the same original dataset, we’re going to generate sessions based on a different set of rules.

Rather than being web page-based, sessions are this time defined by the following rules:

  1. A session expires after a period of inactivity longer than a timeout (say `tmo1`), and,
  2. An active session expires once its duration exceeds a timeout (say `tmo2`).

First, we assemble the original sample dataset used in the previous blog:

val df = Seq(
  (101, "2018-06-01 10:30:15", "home", "redirect"),
  (101, "2018-06-01 10:32:00", "home", "info-click"),
  (101, "2018-06-01 10:35:00", "home", "info-click"),
  (101, "2018-06-01 11:00:45", "products", "visit"),
  (101, "2018-06-01 11:12:00", "products", "info-click"),
  (101, "2018-06-01 11:25:30", "products", "info-click"),
  (101, "2018-06-01 11:38:15", "about-us", "info-click"),
  (101, "2018-06-01 11:50:00", "about-us", "info-click"),
  (101, "2018-06-01 12:01:45", "home", "visit"),
  (101, "2018-06-01 12:04:00", "home", "info-click"),
  (101, "2018-06-01 20:02:45", "home", "visit"),
  (101, "2018-06-01 20:40:00", "products", "info-click"),
  (101, "2018-06-01 20:46:30", "products", "info-click"),
  (101, "2018-06-01 20:50:15", "products", "add-to-cart"),
  (220, "2018-06-01 18:15:30", "home", "redirect"),
  (220, "2018-06-01 18:17:00", "home", "info-click"),
  (220, "2018-06-01 18:40:45", "home", "info-click"),
  (220, "2018-06-01 18:52:30", "home", "info-click"),
  (220, "2018-06-01 19:04:45", "products", "info-click"),
  (220, "2018-06-01 19:17:00", "products", "info-click"),
  (220, "2018-06-01 19:30:30", "products", "info-click"),
  (220, "2018-06-01 19:42:30", "products", "info-click"),
  (220, "2018-06-01 19:45:30", "products", "add-to-cart")
).toDF("user", "timestamp", "page", "action")

Let’s set the first timeout `tmo1` to 15 minutes, and the second timeout `tmo2` to 60 minutes.

The end result should look something like below:

+----+-------------------+-------+
|user|          timestamp|sess_id|
+----+-------------------+-------+
| 101|2018-06-01 10:30:15|  101-1|
| 101|2018-06-01 10:32:00|  101-1|
| 101|2018-06-01 10:35:00|  101-1|
| 101|2018-06-01 11:00:45|  101-2|  <-- timeout rule #1
| 101|2018-06-01 11:12:00|  101-2|
| 101|2018-06-01 11:25:30|  101-2|
| 101|2018-06-01 11:38:15|  101-2|
| 101|2018-06-01 11:50:00|  101-2|
| 101|2018-06-01 12:01:45|  101-3|  <-- timeout rule #2
| 101|2018-06-01 12:04:00|  101-3|
| ...|           ...     |    ...|
+----+-------------------+-------+

Given the above session creation rules, it’s obvious that all programming logic is going to be centered around the timestamp alone, hence the omission of columns like `page` in the expected final result.

Generating sessions based on rule #1 is rather straightforward, as computing the timestamp difference between consecutive rows is easy with Spark’s built-in Window functions. Session creation rule #2, on the other hand, requires dynamically identifying the start of the next session, which depends on where the current session ends. Hence, even robust Window functions over, say, `partitionBy(user).orderBy(timestamp).rangeBetween(0, tmo2)` wouldn’t cut it.
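
To make that concrete, here’s a hypothetical sketch of what a rule-#1-only sessionization would look like using Window functions alone (a running sum over “new session” flags); it’s rule #2 that this approach cannot express:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

val win = Window.partitionBy("user").orderBy("timestamp")

val rule1Only = df.
  withColumn("gap",
    unix_timestamp($"timestamp") - unix_timestamp(lag($"timestamp", 1).over(win))
  ).
  withColumn("new_sess", when($"gap".isNull || $"gap" >= 15 * 60, 1).otherwise(0)).  // tmo1 = 900s
  withColumn("sess_no", sum($"new_sess").over(win))  // running count of rule-#1 session starts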

The suggested solution involves using a UDF (user-defined function) to leverage Scala’s feature-rich set of collection functions:

def tmoSessList(tmo: Long) = udf{ (uid: String, tsList: Seq[String], tsDiffs: Seq[Long]) =>
  def sid(n: Long) = s"$uid-$n"

  val sessList = tsDiffs.foldLeft( (List[String](), 0L, 0L) ){ case ((ls, j, k), i) =>
    if (i == 0 || j + i >= tmo)
      (sid(k + 1) :: ls, 0L, k + 1)
    else
      (sid(k) :: ls, j + i, k)
  }._1.reverse

  tsList zip sessList
}

Note that the timestamp diff list `tsDiffs` is the main input being processed for generating sessions based on the `tmo2` value (session creation rule #2). The timestamp list `tsList` is “passed through” merely so that each timestamp can be paired with its corresponding session ID in the output.

Also note that the accumulator for `foldLeft` in the UDF is a tuple `(ls, j, k)`, where:

  • `ls` is the list of formatted session IDs to be returned
  • `j` and `k` carry over to the next iteration, respectively, the running in-session time (the sum of diffs since the current session started) and the session ID number
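
To see how the fold plays out, here’s a hypothetical dry run of the same logic outside Spark, using user 101’s first nine timestamp diffs (as computed in Step 1 below) with uid = "101" and tmo = 3600:

val diffs = Seq(0L, 105L, 180L, 0L, 675L, 810L, 765L, 705L, 705L)

val ids = diffs.foldLeft( (List[String](), 0L, 0L) ){ case ((ls, j, k), i) =>
  if (i == 0 || j + i >= 3600L) (s"101-${k + 1}" :: ls, 0L, k + 1)
  else (s"101-$k" :: ls, j + i, k)
}._1.reverse
// ids: List[String] = List(101-1, 101-1, 101-1, 101-2, 101-2, 101-2, 101-2, 101-2, 101-3)
// the last element flips to 101-3 because the running in-session time reaches 3660s >= tmo2 (rule #2)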

Now, let’s lay out the steps for carrying out the necessary transformations to generate the sessions:

  1. Identify sessions (with ts_diff = 0 marking the start of a session) per user based on session creation rule #1
  2. Group the dataset to assemble the timestamp diff list per user
  3. Process the timestamp diff list via the above UDF to identify sessions based on rule #2 and generate all session IDs per user
  4. Explode the processed dataset of (timestamp, session ID) pairs back into one row per timestamp

Step 1:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
import spark.implicits._

val tmo1: Long = 15 * 60
val tmo2: Long = 60 * 60

val win1 = Window.partitionBy("user").orderBy("timestamp")

val df1 = df.
  withColumn("ts_diff",
    unix_timestamp($"timestamp") - unix_timestamp(lag($"timestamp", 1).over(win1))
  ).
  withColumn("ts_diff", when(row_number.over(win1) === 1 || $"ts_diff" >= tmo1, 0L).
    otherwise($"ts_diff")
  )

df1.show(50)
+----+-------------------+--------+-----------+-------+
|user|          timestamp|    page|     action|ts_diff|
+----+-------------------+--------+-----------+-------+
| 101|2018-06-01 10:30:15|    home|   redirect|      0|
| 101|2018-06-01 10:32:00|    home| info-click|    105|
| 101|2018-06-01 10:35:00|    home| info-click|    180|
| 101|2018-06-01 11:00:45|products|      visit|      0|
| 101|2018-06-01 11:12:00|products| info-click|    675|
| 101|2018-06-01 11:25:30|products| info-click|    810|
| 101|2018-06-01 11:38:15|about-us| info-click|    765|
| 101|2018-06-01 11:50:00|about-us| info-click|    705|
| 101|2018-06-01 12:01:45|    home|      visit|    705|
| 101|2018-06-01 12:04:00|    home| info-click|    135|
| 101|2018-06-01 20:02:45|    home|      visit|      0|
| 101|2018-06-01 20:40:00|products| info-click|      0|
| 101|2018-06-01 20:46:30|products| info-click|    390|
| 101|2018-06-01 20:50:15|products|add-to-cart|    225|
| 220|2018-06-01 18:15:30|    home|   redirect|      0|
| 220|2018-06-01 18:17:00|    home| info-click|     90|
| 220|2018-06-01 18:40:45|    home| info-click|      0|
| 220|2018-06-01 18:52:30|    home| info-click|    705|
| 220|2018-06-01 19:04:45|products| info-click|    735|
| 220|2018-06-01 19:17:00|products| info-click|    735|
| 220|2018-06-01 19:30:30|products| info-click|    810|
| 220|2018-06-01 19:42:30|products| info-click|    720|
| 220|2018-06-01 19:45:30|products|add-to-cart|    180|
+----+-------------------+--------+-----------+-------+

Steps 2-4:

val df2 = df1.
  groupBy("user").agg(
    collect_list($"timestamp").as("ts_list"), collect_list($"ts_diff").as("ts_diffs")
  ).
  withColumn("tmo_sess_id",
    explode(tmoSessList(tmo2)($"user", $"ts_list", $"ts_diffs"))
  ).
  select($"user", $"tmo_sess_id._1".as("timestamp"), $"tmo_sess_id._2".as("sess_id"))

df2.show(50)
+----+-------------------+-------+
|user|          timestamp|sess_id|
+----+-------------------+-------+
| 101|2018-06-01 10:30:15|  101-1|  User 101
| 101|2018-06-01 10:32:00|  101-1|
| 101|2018-06-01 10:35:00|  101-1|
| 101|2018-06-01 11:00:45|  101-2|  <-- timeout rule #1
| 101|2018-06-01 11:12:00|  101-2|
| 101|2018-06-01 11:25:30|  101-2|
| 101|2018-06-01 11:38:15|  101-2|
| 101|2018-06-01 11:50:00|  101-2|
| 101|2018-06-01 12:01:45|  101-3|  <-- timeout rule #2
| 101|2018-06-01 12:04:00|  101-3|
| 101|2018-06-01 20:02:45|  101-4|  <-- timeout rule #1
| 101|2018-06-01 20:40:00|  101-5|  <-- timeout rule #1
| 101|2018-06-01 20:46:30|  101-5|
| 101|2018-06-01 20:50:15|  101-5|
| 220|2018-06-01 18:15:30|  220-1|  User 220
| 220|2018-06-01 18:17:00|  220-1|
| 220|2018-06-01 18:40:45|  220-2|  <-- timeout rule #1
| 220|2018-06-01 18:52:30|  220-2|
| 220|2018-06-01 19:04:45|  220-2|
| 220|2018-06-01 19:17:00|  220-2|
| 220|2018-06-01 19:30:30|  220-2|
| 220|2018-06-01 19:42:30|  220-3|  <-- timeout rule #2
| 220|2018-06-01 19:45:30|  220-3|
+----+-------------------+-------+