
Spark Higher-order Functions

Apache Spark’s DataFrame API provides comprehensive functions for transforming or aggregating data in a row-wise fashion. As in many popular relational database systems such as PostgreSQL, these functions are internally optimized to process large numbers of rows efficiently. Better yet, Spark runs on distributed platforms and, if configured to fully utilize the available processing cores and memory, it can handle data at a really large scale.

That’s all great, but what about transforming or aggregating data of the same type column-wise? Starting from Spark 2.4, a number of methods for ArrayType (and MapType) columns have been added. But users may still find their hands tied when none of the available methods can deal with something as simple as, say, summing the integer elements of an array.

User-provided lambda functions

A higher-order function allows one to process a collection of elements (of the same data type) in accordance with a user-provided lambda function that specifies how the collection content should be transformed or aggregated. Having the lambda function as part of the function signature makes it possible to process the collection of elements with relatively complex logic.

Coupled with the use of method array, higher-order functions are particularly useful when transformation or aggregation across a list of columns (of the same data type) is needed. Below are a few such functions:

  • filter()
  • exists()
  • transform()
  • aggregate()

The lambda function can be either a unary or a binary operator. As will be shown in the examples below, function aggregate() requires a binary operator whereas the other functions expect a unary operator.

A caveat

Unless you’re on Spark 3.x, higher-order functions aren’t part of the built-in DataFrame API. On Spark 2.4, they are expressed in standard SQL syntax along with a lambda function and need to be passed in as a String via expr(). Hence, to use these functions, one needs to temporarily “exit” the Scala world to assemble proper SQL expressions in the SQL arena.

Let’s create a simple DataFrame for illustrating how these higher-order functions work.

// Assumes a SparkSession in scope as `spark` (e.g. in spark-shell)
import org.apache.spark.sql.functions._
import spark.implicits._

case class Order(price: Double, qty: Int)

val df = Seq(
  (101, 10, Order(1.2, 5), Order(1.0, 3), Order(1.5, 4), Seq("strawberry", "currant")),
  (102, 15, Order(1.5, 6), Order(0.8, 5), Order(1.0, 7), Seq("raspberry", "cherry", "blueberry"))
).toDF("id", "discount", "order1", "order2", "order3", "fruits")

df.show(false)
// +---+--------+--------+--------+--------+------------------------------+
// |id |discount|order1  |order2  |order3  |fruits                        |
// +---+--------+--------+--------+--------+------------------------------+
// |101|10      |[1.2, 5]|[1.0, 3]|[1.5, 4]|[strawberry, currant]         |
// |102|15      |[1.5, 6]|[0.8, 5]|[1.0, 7]|[raspberry, cherry, blueberry]|
// +---+--------+--------+--------+--------+------------------------------+

Function filter()

Here’s an “unofficial” method signature of filter():

// Scala-style signature of `filter()`
def filter[T](arrayCol: ArrayType[T], fcn: T => Boolean): ArrayType[T]

The following snippet uses filter to extract any fruit item that ends with “berry”.

df.
  withColumn("berries", expr("filter(fruits, x -> x rlike '.*berry')")).
  select("id", "fruits", "berries").
  show(false)
// +---+------------------------------+----------------------+
// |id |fruits                        |berries               |
// +---+------------------------------+----------------------+
// |101|[strawberry, currant]         |[strawberry]          |
// |102|[raspberry, cherry, blueberry]|[raspberry, blueberry]|
// +---+------------------------------+----------------------+
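
For reference, on Spark 3.x these higher-order functions are also exposed directly in the DataFrame API as functions taking Column => Column lambdas, so no expr() round-trip is needed. Here’s a sketch of the same filter, assuming Spark 3.0+ where org.apache.spark.sql.functions provides filter(Column, Column => Column):

import org.apache.spark.sql.functions.{col, filter => filterArray}

// Spark 3.0+ only; renamed on import to avoid clashing with Dataset.filter
df.
  withColumn("berries", filterArray(col("fruits"), x => x.rlike(".*berry"))).
  select("id", "fruits", "berries").
  show(false)
// Same output as above.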

Function transform()

Method signature (unofficial) of transform():

// Scala-style signature of `transform()`
def transform[T, S](arrayCol: ArrayType[T], fcn: T => S): ArrayType[S]

Here’s an example of using transform() to flag any fruit not ending with “berry” with an ‘*’.

df.withColumn(
    "non-berries",
    expr("transform(fruits, x -> case when x rlike '.*berry' then x else concat(x, '*') end)")
  ).
  select("id", "fruits", "non-berries").
  show(false)
// +---+------------------------------+-------------------------------+
// |id |fruits                        |non-berries                    |
// +---+------------------------------+-------------------------------+
// |101|[strawberry, currant]         |[strawberry, currant*]         |
// |102|[raspberry, cherry, blueberry]|[raspberry, cherry*, blueberry]|
// +---+------------------------------+-------------------------------+
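
On Spark 2.4+, transform() also accepts a two-argument lambda whose second argument is the 0-based element index, which is handy when position matters. A small sketch (not from the original post):

df.
  withColumn("indexed", expr("transform(fruits, (x, i) -> concat(cast(i as string), ':', x))")).
  select("id", "indexed").
  show(false)
// id 101 -> [0:strawberry, 1:currant]
// id 102 -> [0:raspberry, 1:cherry, 2:blueberry]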

So far, we’ve seen how higher-order functions transform data in an ArrayType collection. In the following examples, we’ll illustrate applying the higher-order functions to individual columns (of the same data type) by first turning the selected columns into a single ArrayType column.

Let’s assemble an array of the individual columns we would like to process across:

val orderCols = df.columns.filter{
  c => "^order\\d+$".r.findFirstIn(c).nonEmpty
}
// orderCols: Array[String] = Array(order1, order2, order3)
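
Packing these columns into a single ArrayType column with array() yields an array of structs, which is what the higher-order functions below operate on. A quick sanity check (a minimal sketch):

// The assembled column has Spark SQL type array<struct<price:double,qty:int>>
df.withColumn("orders", array(orderCols.map(col): _*)).
  select("orders").
  printSchema()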

Function exists()

Method signature (unofficial) of exists():

// Scala-style signature of `exists()`
def exists[T](arrayCol: ArrayType[T], fcn: T => Boolean): Boolean

An example using exists() to check whether any of the individual orders in a row has an item price below $1:

df.
  withColumn("orders", array(orderCols.map(col): _*)).
  withColumn("sub$-prices", expr("exists(orders, x -> x.price < 1)")).
  select("id", "orders", "sub$-prices").
  show(false)
// +---+------------------------------+-----------+
// |id |orders                        |sub$-prices|
// +---+------------------------------+-----------+
// |101|[[1.2, 5], [1.0, 3], [1.5, 4]]|false      |
// |102|[[1.5, 6], [0.8, 5], [1.0, 7]]|true       |
// +---+------------------------------+-----------+
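
Spark 2.4 has no forall(), but an “all elements satisfy” check can be emulated by negating exists() on the complementary condition. A small sketch (not in the original post):

// True when every order in the row has price >= $1, i.e. no sub-$1 item exists
df.
  withColumn("orders", array(orderCols.map(col): _*)).
  withColumn("no-sub$-prices", expr("not exists(orders, x -> x.price < 1)")).
  select("id", "no-sub$-prices").
  show(false)
// id 101 -> true
// id 102 -> false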

Function aggregate()

Method signature (unofficial) of aggregate():

// Scala-style signature of `aggregate()`
def aggregate[T, S](arrayCol: ArrayType[T], init: S, fcn: (S, T) => S): S

The example below shows how to compute the discounted total of all the orders per row using aggregate().

df.
  withColumn("orders", array(orderCols.map(col): _*)).
  withColumn("total", expr("aggregate(orders, 0d, (acc, x) -> acc + x.price * x.qty)")).
  withColumn("discounted", $"total" * (lit(1.0) - $"discount"/100.0)).
  select("id", "discount", "orders", "total", "discounted").
  show(false)
// +---+--------+------------------------------+-----+----------+
// |id |discount|orders                        |total|discounted|
// +---+--------+------------------------------+-----+----------+
// |101|10      |[[1.2, 5], [1.0, 3], [1.5, 4]]|15.0 |13.5      |
// |102|15      |[[1.5, 6], [0.8, 5], [1.0, 7]]|20.0 |17.0      |
// +---+--------+------------------------------+-----+----------+
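
Coming back to the introduction’s example of summing the integer elements of an array, aggregate() handles it directly. A minimal, self-contained sketch (the DataFrame below is illustrative, not part of the running example):

val qtyDF = Seq((1, Seq(5, 3, 4)), (2, Seq(6, 5, 7))).toDF("id", "qtys")

qtyDF.
  withColumn("totalQty", expr("aggregate(qtys, 0, (acc, x) -> acc + x)")).
  show(false)
// id 1 -> totalQty 12
// id 2 -> totalQty 18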

Merging Akka Streams With MergeLatest

Akka Stream comes with a comprehensive set of fan-in/fan-out features for stream processing. It’s worth noting that fan-in/fan-out operations take regular streams as input and generate regular streams as output. That’s different from substreaming, in which operators like groupBy produce nested SubSource or SubFlow instances that, in turn, can be merged back into a regular stream via functions like mergeSubstreams.

Fan-in: Zip versus Merge

Fan-in functionality primarily falls into two types of operations: Zip and Merge. One of the main differences between them is that Zip may combine streams of different element types to generate a stream of tuple-typed elements, whereas Merge takes streams of the same type and generates a stream of elements (or a stream of collections of elements). Another difference is that for Zip the resulting stream emits only when each of the input streams has an element, whereas Merge emits as soon as any one of the input streams has an element.

Starting from v2.6, Akka Stream provides a few additional flavors of Merge functions, such as mergeLatest, mergePreferred and mergePrioritized. In this blog post, we’re going to focus on Merge, in particular mergeLatest, which, unlike most other Merge functions, generates a list of elements for each element emitted from any of the input streams.

MergeLatest

Function mergeLatest takes a couple of parameters: inputPorts, which is the number of input streams, and eagerClose, which specifies whether the stream completes as soon as one upstream completes (true) or only when all upstreams complete (false).

Let’s try it out using Source.combine, which takes two or more Sources and applies the provided uniform fan-in operator (in this case, MergeLatest):

import akka.stream.scaladsl._
import akka.actor.ActorSystem
import scala.concurrent.duration._

implicit val system = ActorSystem("system")

val s1 = Source(1 to 3)
val s2 = Source(11 to 13).throttle(1, 50.millis)
val s3 = Source(101 to 103).throttle(1, 100.millis)

// Source.combine(s1, s2, s3)(Merge[Int](_)).runForeach(println)  // Ordinary Merge
Source.combine(s1, s2, s3)(MergeLatest[Int](_, 0)).runForeach(println)

// Output: 
//
// List(1, 11, 101)
// List(2, 11, 101)
// List(2, 12, 101)
// List(3, 12, 101)
// List(3, 13, 101)
// List(3, 13, 102)
// List(3, 13, 103)

For comparison, had MergeLatest been replaced with the ordinary Merge, the output would look something like this (the exact interleaving may vary from run to run):

// Output:
//
// 1
// 11
// 101
// 2
// 12
// 3
// 13
// 102
// 103
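
For just two input streams, the same behavior is also available directly as an operator on Source/Flow. Below is a sketch assuming Akka Streams 2.6+, reusing s1 and s2 from above; each emitted element is a Seq holding the latest value seen from each side:

// Pairwise form of the same fan-in; eagerComplete defaults to false
s1.mergeLatest(s2).runForeach(println)
// e.g. List(1, 11), List(2, 11), List(2, 12), ... (exact interleaving may vary)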

As can be seen from Akka Stream’s Flow source code, mergeLatest uses the stream processing operator MergeLatest for the special case of 2 input streams:

def mergeLatest[U >: Out, M](that: Graph[SourceShape[U], M], eagerComplete: Boolean = false): Repr[immutable.Seq[U]] =
  via(mergeLatestGraph(that, eagerComplete))

protected def mergeLatestGraph[U >: Out, M](
    that: Graph[SourceShape[U], M],
    eagerComplete: Boolean): Graph[FlowShape[Out @uncheckedVariance, immutable.Seq[U]], M] =
  GraphDSL.create(that) { implicit b => r =>
    val merge = b.add(MergeLatest[U](2, eagerComplete))
    r ~> merge.in(1)
    FlowShape(merge.in(0), merge.out)
  }

And below is how the MergeLatest operator is implemented:

object MergeLatest {
  def apply[T](inputPorts: Int, eagerComplete: Boolean = false): GraphStage[UniformFanInShape[T, List[T]]] =
    new MergeLatest[T, List[T]](inputPorts, eagerComplete)(_.toList)
}

final class MergeLatest[T, M](val inputPorts: Int, val eagerClose: Boolean)(buildElem: Array[T] => M)
    extends GraphStage[UniformFanInShape[T, M]] {
  require(inputPorts >= 1, "input ports must be >= 1")

  val in: immutable.IndexedSeq[Inlet[T]] = Vector.tabulate(inputPorts)(i => Inlet[T]("MergeLatest.in" + i))
  val out: Outlet[M] = Outlet[M]("MergeLatest.out")
  override val shape: UniformFanInShape[T, M] = UniformFanInShape(out, in: _*)

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
    new GraphStageLogic(shape) with OutHandler {
      private val activeStreams: java.util.HashSet[Int] = new java.util.HashSet[Int]()
      private var runningUpstreams: Int = inputPorts
      private def upstreamsClosed: Boolean = runningUpstreams == 0
      private def allMessagesReady: Boolean = activeStreams.size == inputPorts
      private val messages: Array[Any] = new Array[Any](inputPorts)

      override def preStart(): Unit = in.foreach(tryPull)

      in.zipWithIndex.foreach {
        case (input, index) =>
          setHandler(
            input,
            new InHandler {
              override def onPush(): Unit = {
                messages.update(index, grab(input))
                activeStreams.add(index)
                if (allMessagesReady) emit(out, buildElem(messages.asInstanceOf[Array[T]]))
                tryPull(input)
              }

              override def onUpstreamFinish(): Unit = {
                if (!eagerClose) {
                  runningUpstreams -= 1
                  if (upstreamsClosed) completeStage()
                } else completeStage()
              }
            })
      }

      override def onPull(): Unit = {
        var i = 0
        while (i < inputPorts) {
          if (!hasBeenPulled(in(i))) tryPull(in(i))
          i += 1
        }
      }

      setHandler(out, this)
    }

  override def toString = "MergeLatest"
}

As shown in the source code, it’s implemented as a standard GraphStage of UniformFanInShape. The implementation is so modular that repurposing it to do something a little differently can be rather easy.

Repurposing MergeLatest

There was a relevant use case inquiry on Stack Overflow to which I offered a solution for changing the initial stream emission behavior. By design, MergeLatest starts emitting the output stream only after each input stream has emitted an initial element, which is somewhat of an exception to the typical Merge behavior mentioned earlier. The suggested solution revises the operator so that its emission behavior resembles the other Merge operators, i.e. it starts emitting as soon as one of the input streams has an element, filling in the remaining slots with a user-provided default element.

Below is the repurposed code:

import akka.stream.scaladsl._
import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }
import akka.stream.{ Attributes, Inlet, Outlet, UniformFanInShape }
import scala.collection.immutable

object MergeLatestWithDefault {
  def apply[T](inputPorts: Int, default: T, eagerComplete: Boolean = false): GraphStage[UniformFanInShape[T, List[T]]] =
    new MergeLatestWithDefault[T, List[T]](inputPorts, default, eagerComplete)(_.toList)
}

final class MergeLatestWithDefault[T, M](val inputPorts: Int, val default: T, val eagerClose: Boolean)(buildElem: Array[T] => M)
    extends GraphStage[UniformFanInShape[T, M]] {
  require(inputPorts >= 1, "input ports must be >= 1")

  val in: immutable.IndexedSeq[Inlet[T]] = Vector.tabulate(inputPorts)(i => Inlet[T]("MergeLatestWithDefault.in" + i))
  val out: Outlet[M] = Outlet[M]("MergeLatestWithDefault.out")
  override val shape: UniformFanInShape[T, M] = UniformFanInShape(out, in: _*)

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
    new GraphStageLogic(shape) with OutHandler {
      private val activeStreams: java.util.HashSet[Int] = new java.util.HashSet[Int]()
      private var runningUpstreams: Int = inputPorts
      private def upstreamsClosed: Boolean = runningUpstreams == 0
      // Pre-fill the latest-element slots with the user-provided default
      private val messages: Array[Any] = Array.fill[Any](inputPorts)(default)

      override def preStart(): Unit = in.foreach(tryPull)

      in.zipWithIndex.foreach {
        case (input, index) =>
          setHandler(
            input,
            new InHandler {
              override def onPush(): Unit = {
                messages.update(index, grab(input))
                activeStreams.add(index)
                // Emit unconditionally (no allMessagesReady check as in MergeLatest)
                emit(out, buildElem(messages.asInstanceOf[Array[T]]))
                tryPull(input)
              }

              override def onUpstreamFinish(): Unit = {
                if (!eagerClose) {
                  runningUpstreams -= 1
                  if (upstreamsClosed) completeStage()
                } else completeStage()
              }
            })
      }

      override def onPull(): Unit = {
        var i = 0
        while (i < inputPorts) {
          if (!hasBeenPulled(in(i))) tryPull(in(i))
          i += 1
        }
      }

      setHandler(out, this)
    }

  override def toString = "MergeLatestWithDefault"
}

Little code change is needed in this case. Besides the additional parameter for the default element value, which is pre-filled into the internal messages array, the only change is that the emit call in onPush of the InHandler is no longer conditional (the allMessagesReady check is gone).

Testing it out:

import akka.stream.scaladsl._
import akka.actor.ActorSystem
import scala.concurrent.duration._

implicit val system = ActorSystem("system")

val s1 = Source(1 to 3)
val s2 = Source(11 to 13).throttle(1, 50.millis)
val s3 = Source(101 to 103).throttle(1, 100.millis)

Source.combine(s1, s2, s3)(MergeLatestWithDefault[Int](_, 0)).runForeach(println)

// Output: 
//
// List(1, 0, 0)
// List(1, 11, 0)
// List(1, 11, 101)
// List(2, 11, 101)
// List(2, 12, 101)
// List(3, 12, 101)
// List(3, 13, 101)
// List(3, 13, 102)
// List(3, 13, 103)

Akka Stream Stateful MapConcat

If you’ve been building applications with Akka Stream in Scala, you’ve probably used mapConcat (and perhaps flatMapConcat as well). It’s a handy method for expanding and flattening the content of a stream, much like how `flatMap` operates on an ordinary Scala collection. The method has the following signature:

def mapConcat[T](f: (Out) => Iterable[T]): Repr[T]

Here’s a trivial example using mapConcat:

import akka.actor.ActorSystem
import akka.stream.scaladsl._

implicit val system = ActorSystem("system")

Source(List("alice", "bob", "charle")).
  mapConcat(name => List(s"Hi $name", s"Bye $name")).
  runForeach(println)
// Hi alice
// Bye alice
// Hi bob
// Bye bob
// Hi charle
// Bye charle

A mapConcat with an internal state

A relatively less well-known method that allows one to expand and flatten stream elements while iteratively processing some internal state is statefulMapConcat, with the following method signature:

def statefulMapConcat[T](f: () => (Out) => Iterable[T]): Repr[T]

Interestingly, method `mapConcat` is just a stateless special case of `statefulMapConcat`. Here’s how `mapConcat[T]` is implemented in Akka Stream’s Flow:

def mapConcat[T](f: Out => immutable.Iterable[T]): Repr[T] = statefulMapConcat(() => f)

Example 1: Extracting sections of elements

Let’s look at a simple example that illustrates how `statefulMapConcat` can be used to extract sections of a given Source, delimited by special elements designated as section start/stop markers.

val source = Source(List(
    "a", ">>", "b", "c", "<<", "d", "e", ">>", "f", "g", "h", "<<", "i", ">>", "j", "<<", "k"
  ))

val extractFlow = Flow[String].statefulMapConcat { () =>
  val start = ">>"
  val stop = "<<"
  var discard = true
  elem =>
    if (discard) {
      if (elem == start)
        discard = false
      Nil
    }
    else {
      if (elem == stop) {
        discard = true
        Nil
      }
      else
        elem :: Nil
    }
}

source.via(extractFlow).runForeach(x => print(s"$x "))
// b c f g h j 

The internal state in the above example is the mutable Boolean variable `discard`, which is toggled in accordance with the designated start/stop elements so that each iteration returns either an empty Iterable (in this case, `Nil`) or an Iterable consisting of the current element.
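
As another minimal illustration of carrying per-materialization state (a sketch, not from the original post), statefulMapConcat can tag each element with a running index:

val indexedFlow = Flow[String].statefulMapConcat { () =>
  var idx = -1L  // mutable state, created once per stream materialization
  elem => {
    idx += 1
    (idx, elem) :: Nil
  }
}

Source(List("a", "b", "c")).via(indexedFlow).runForeach(println)
// (0,a)
// (1,b)
// (2,c)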

Example 2: Conditional element-wise pairing of streams

Next, we look at a slightly more complex example. Say we have two Sources of integer elements and we would like to pair up the elements from the two Sources based on some condition provided as an `(Int, Int) => Boolean` function.

def popFirstMatch(ls: List[Int], condF: Int => Boolean): (Option[Int], List[Int]) = {
  ls.find(condF) match {
    case None =>
      (None, ls)
    case Some(e) =>
      val idx = ls.indexOf(e)
      if (idx < 0)
        (None, ls)
      else {
        val (l, r) = ls.splitAt(idx)
        (r.headOption, l ++ r.tail)
      }
  }
}

import akka.NotUsed

def conditionalZip( first: Source[Int, NotUsed],
                    second: Source[Int, NotUsed],
                    filler: Int,
                    condFcn: (Int, Int) => Boolean ): Source[(Int, Int), NotUsed] = {
  first.zipAll(second, filler, filler).statefulMapConcat { () =>
    var prevList1 = List.empty[Int]
    var prevList2 = List.empty[Int]
    tuple => tuple match { case (e1, e2) =>
      if (e2 != filler) {
        if (e1 != filler && condFcn(e1, e2))
          (e1, e2) :: Nil
        else {
          if (e1 != filler)
            prevList1 :+= e1
          prevList2 :+= e2
          val (opElem1, rest1) = popFirstMatch(prevList1, condFcn(_, e2))
          opElem1 match {
            case None =>
              if (e1 != filler) {
                val (opElem2, rest2) = popFirstMatch(prevList2, condFcn(e1, _))
                opElem2 match {
                  case None =>
                    Nil
                  case Some(e) =>
                    prevList2 = rest2
                    (e1, e) :: Nil
                }
              }
              else
                Nil
            case Some(e) =>
              prevList1 = rest1
              (e, e2) :: Nil
          }
        }
      }
      else
        Nil
    }
  }
}

In the main method `conditionalZip`, a couple of Lists are maintained for the two stream Sources to keep track of elements held off in previous iterations, to be conditionally consumed in subsequent iterations based on the provided condition function.

Utility method `popFirstMatch` extracts the first element in a List that satisfies the condition derived from the condition function, and also returns the resulting List of the remaining elements.
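
For instance, a quick check of popFirstMatch on its own (illustrative values, not from the original post):

popFirstMatch(List(3, 9, 5), _ > 4)
// (Some(9), List(3, 5))

popFirstMatch(List(3, 9, 5), _ > 10)
// (None, List(3, 9, 5))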

Note that the `filler` elements are for method `zipAll` (available since Akka Stream 2.6) to cover all elements of the “longer” of the two stream Sources. The provided `filler` value should be distinguishable from the actual stream elements (`Int.MinValue` in this example) so that the condition logic can be applied accordingly.

Test running `conditionalZip`:

//// Case 1:
val first = Source(1 :: 2 :: 4 :: 6 :: Nil)
val second = Source(1 :: 2 :: 3 :: 4 :: 5 :: 6 :: 7 :: Nil)

conditionalZip(first, second, Int.MinValue, _ == _).runForeach(println) 
// (1,1)
// (2,2)
// (4,4)
// (6,6)

conditionalZip(first, second, Int.MinValue, _ > _).runForeach(println) 
// (2,1)
// (4,3)
// (6,4)

conditionalZip(first, second, Int.MinValue, _ < _).runForeach(println) 
// (1,2)
// (2,3)
// (4,5)
// (6,7)

//// Case 2:
val first = Source(3 :: 9 :: 5 :: 5 :: 6 :: Nil)
val second = Source(1 :: 3 :: 5 :: 2 :: 5 :: 6 :: Nil)

conditionalZip(first, second, Int.MinValue, _ == _).runForeach(println)
// (3,3)
// (5,5)
// (5,5)
// (6,6)

conditionalZip(first, second, Int.MinValue, _ > _).runForeach(println)
// (3,1)
// (9,3)
// (5,2)
// (6,5)

conditionalZip(first, second, Int.MinValue, _ < _).runForeach(println)
// (3,5)
// (5,6)