
Scala On Spark – Sum Over Periods

This is another programming example in my Scala-on-Spark blog series. While it uses the same minuscule weather data created in the first example of the blog series, it can be viewed as an independent programming exercise.

In this example, we want a table of total precipitation over custom past periods by weather station. The specific periods in this example are the previous month, the previous 3 months, and all previous months. We have data from July through December, and let’s say it’s now January, hence the previous month is December.

The result should be like this:

+-------+---------------+---------------+---------------+
|station|precip_prev_1mo|precip_prev_3mo|precip_prev_all|
+-------+---------------+---------------+---------------+
|    100|            3.5|            5.0|            7.5|
|    115|            4.5|            9.0|           10.0|
+-------+---------------+---------------+---------------+

User-defined functions (UDFs) will be used in this example. Spark’s UDF supplements its API by allowing the vast library of Scala (or any of the other supported languages) functions to be used. That said, a method from Spark’s API should be picked over a UDF of the same functionality, as the former would likely perform better: Spark’s Catalyst optimizer treats a UDF as a black box and cannot optimize inside it.

First, let’s load up the said weather data.

import java.sql.Date

// DataFrame columns:
//   Weather Station ID
//   Start Date of a half-month period
//   Temperature High (in Fahrenheit) over the period
//   Temperature Low (in Fahrenheit) over the period
//   Total Precipitation (in inches) over the period
val weatherDataDF = Seq(
  (100, Date.valueOf("2017-07-01"), 75, 59, 0.0),
  (100, Date.valueOf("2017-07-16"), 77, 59, 0.5),
  (100, Date.valueOf("2017-08-01"), 80, 63, 1.0),
  (100, Date.valueOf("2017-08-16"), 78, 62, 1.0),
  (100, Date.valueOf("2017-09-01"), 74, 59, 0.0),
  (100, Date.valueOf("2017-09-16"), 72, 57, 0.0),
  (100, Date.valueOf("2017-10-01"), 68, 54, 0.0),
  (100, Date.valueOf("2017-10-16"), 66, 54, 0.0),
  (100, Date.valueOf("2017-11-01"), 64, 50, 0.5),
  (100, Date.valueOf("2017-11-16"), 61, 48, 1.0),
  (100, Date.valueOf("2017-12-01"), 59, 46, 2.0),
  (100, Date.valueOf("2017-12-16"), 57, 45, 1.5),
  (115, Date.valueOf("2017-07-01"), 76, 57, 0.0),
  (115, Date.valueOf("2017-07-16"), 76, 56, 1.0),
  (115, Date.valueOf("2017-08-01"), 78, 57, 0.0),
  (115, Date.valueOf("2017-08-16"), 81, 57, 0.0),
  (115, Date.valueOf("2017-09-01"), 77, 54, 0.0),
  (115, Date.valueOf("2017-09-16"), 72, 50, 0.0),
  (115, Date.valueOf("2017-10-01"), 65, 45, 0.0),
  (115, Date.valueOf("2017-10-16"), 59, 40, 1.5),
  (115, Date.valueOf("2017-11-01"), 55, 37, 1.0),
  (115, Date.valueOf("2017-11-16"), 52, 35, 2.0),
  (115, Date.valueOf("2017-12-01"), 45, 30, 3.0),
  (115, Date.valueOf("2017-12-16"), 41, 28, 1.5)
).toDF("station", "start_date", "temp_high", "temp_low", "total_precip")

We first create a DataFrame of precipitation by weather station and month, along with the number of months by which each month lags the current month.

import org.apache.spark.sql.functions._

val monthlyPrecipDF = weatherDataDF.groupBy($"station", last_day($"start_date").as("mo_end")).
  agg(sum($"total_precip").as("monthly_precip")).
  withColumn("mo_lag", months_between(last_day(current_date()), $"mo_end")).
  orderBy($"station", $"mo_end".desc)  // latest month first, as in the output below

monthlyPrecipDF.show
// +-------+----------+--------------+------+
// |station|    mo_end|monthly_precip|mo_lag|
// +-------+----------+--------------+------+
// |    100|2017-12-31|           3.5|   1.0|
// |    100|2017-11-30|           1.5|   2.0|
// |    100|2017-10-31|           0.0|   3.0|
// |    100|2017-09-30|           0.0|   4.0|
// |    100|2017-08-31|           2.0|   5.0|
// |    100|2017-07-31|           0.5|   6.0|
// |    115|2017-12-31|           4.5|   1.0|
// |    115|2017-11-30|           3.0|   2.0|
// |    115|2017-10-31|           1.5|   3.0|
// |    115|2017-09-30|           0.0|   4.0|
// |    115|2017-08-31|           0.0|   5.0|
// |    115|2017-07-31|           1.0|   6.0|
// +-------+----------+--------------+------+

Next, we combine the list of months-lagged values with the list of monthly precipitation values by means of a UDF to create a map column. To do that, we use Scala’s zip method within the UDF to create a list of tuples from the two input lists and convert the resulting list into a map.

// UDF to combine 2 array-type columns to map
def arraysToMap = udf(
  (a: Seq[Double], b: Seq[Double]) => (a zip b).toMap
)

val precipMapDF = monthlyPrecipDF.groupBy("station").agg(
  collect_list($"mo_lag").as("mo_lag_list"),
  collect_list($"monthly_precip").as("precip_list")
).select(
  $"station", arraysToMap($"mo_lag_list", $"precip_list").as("mo_precip_map")
)

precipMapDF.show(truncate=false)
// +-------+---------------------------------------------------------------------------+
// |station|mo_precip_map                                                              |
// +-------+---------------------------------------------------------------------------+
// |115    |Map(5.0 -> 0.0, 1.0 -> 4.5, 6.0 -> 1.0, 2.0 -> 3.0, 3.0 -> 1.5, 4.0 -> 0.0)|
// |100    |Map(5.0 -> 2.0, 1.0 -> 3.5, 6.0 -> 0.5, 2.0 -> 1.5, 3.0 -> 0.0, 4.0 -> 0.0)|
// +-------+---------------------------------------------------------------------------+
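Here is a plain-Scala sketch of the function that the arraysToMap UDF wraps. It shows what the UDF computes for a single station and needs no Spark. The sample lists are assumed to be station 100’s collect_list results from the output above, in aligned order. The object name ZipToMapSketch is hypothetical.

```scala
// Plain-Scala sketch (no Spark) of the function wrapped by the arraysToMap UDF.
// The sample lists are assumed to be station 100's aligned collect_list results.
object ZipToMapSketch {
  // Pair up keys and values positionally, then build a Map from the pairs
  def arraysToMap(a: Seq[Double], b: Seq[Double]): Map[Double, Double] =
    (a zip b).toMap

  val moLags: Seq[Double]  = Seq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
  val precips: Seq[Double] = Seq(3.5, 1.5, 0.0, 0.0, 2.0, 0.5)

  // Months-lagged -> monthly precipitation (Map iteration order is unspecified)
  val station100Map: Map[Double, Double] = arraysToMap(moLags, precips)
}
```

Note that zip stops at the shorter list, so a length mismatch would silently drop entries. The approach relies on the two collect_list results being aligned. In practice this holds because both lists are collected in the same aggregation. Collecting a struct of the two columns could make the pairing explicit.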

Note that the map content might look different depending on when it is generated, as the months-lagged values are relative to the current month when the application is run.

Using another UDF, which sums the precipitation of all months lagging the current month by no more than a given number of months, we create the result DataFrame.

// UDF to aggregate map values for keys less than or equal to x (0 for all)
def aggMapValues = udf(
  (m: Map[Double, Double], x: Double) =>
    if (x > 0) m.map{ case (k, v) => if (k <= x) v else 0 }.sum else
      m.map{ case (k, v) => v }.sum
)

val customPrecipDF = precipMapDF.
  withColumn( "precip_prev_1mo", aggMapValues($"mo_precip_map", lit(1)) ).
  withColumn( "precip_prev_3mo", aggMapValues($"mo_precip_map", lit(3)) ).
  withColumn( "precip_prev_all", aggMapValues($"mo_precip_map", lit(0)) ).
  select( $"station", $"precip_prev_1mo", $"precip_prev_3mo", $"precip_prev_all" )

customPrecipDF.show
// +-------+---------------+---------------+---------------+
// |station|precip_prev_1mo|precip_prev_3mo|precip_prev_all|
// +-------+---------------+---------------+---------------+
// |    115|            4.5|            9.0|           10.0|
// |    100|            3.5|            5.0|            7.5|
// +-------+---------------+---------------+---------------+
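The logic inside the aggMapValues UDF can be checked without Spark. The sketch below is a plain-Scala version of that function. It uses filter and values instead of mapping excluded entries to 0. The name sumUpTo and the hard-coded station 100 map (taken from the output above) are assumptions for illustration.

```scala
// Plain-Scala sketch of the aggMapValues (sum) logic; sumUpTo is a hypothetical name.
object AggMapSketch {
  // Sum the values whose key (months lagged) is <= x; x <= 0 means "all months"
  def sumUpTo(m: Map[Double, Double], x: Double): Double =
    if (x > 0) m.filter { case (k, _) => k <= x }.values.sum
    else m.values.sum

  // Station 100's months-lagged -> monthly precipitation (from the output above)
  val station100: Map[Double, Double] =
    Map(1.0 -> 3.5, 2.0 -> 1.5, 3.0 -> 0.0, 4.0 -> 0.0, 5.0 -> 2.0, 6.0 -> 0.5)
}
```

With x set to 1, 3 and 0, sumUpTo(station100, x) gives 3.5, 5.0 and 7.5. These match station 100’s row in the result table.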

Again, note that the months-lagged values are relative to the current month when the application is executed, hence the months-lagged parameters for the aggMapValues UDF should be adjusted accordingly.

We can use a similar approach to come up with a table of temperature highs/lows over the custom periods. Below are the steps for creating the result table for temperature high.

val monthlyHighDF = weatherDataDF.groupBy($"station", last_day($"start_date").as("mo_end")).
  agg(max($"temp_high").as("monthly_high")).
  withColumn("mo_lag", months_between(last_day(current_date()), $"mo_end"))

monthlyHighDF.orderBy($"station", $"mo_end").show
// +-------+----------+------------+------+
// |station|    mo_end|monthly_high|mo_lag|
// +-------+----------+------------+------+
// |    100|2017-07-31|          77|   6.0|
// |    100|2017-08-31|          80|   5.0|
// |    100|2017-09-30|          74|   4.0|
// |    100|2017-10-31|          68|   3.0|
// |    100|2017-11-30|          64|   2.0|
// |    100|2017-12-31|          59|   1.0|
// |    115|2017-07-31|          76|   6.0|
// |    115|2017-08-31|          81|   5.0|
// |    115|2017-09-30|          77|   4.0|
// |    115|2017-10-31|          65|   3.0|
// |    115|2017-11-30|          55|   2.0|
// |    115|2017-12-31|          45|   1.0|
// +-------+----------+------------+------+

import org.apache.spark.sql.types.DoubleType

val tempHighMapDF = monthlyHighDF.groupBy("station").agg(
  collect_list($"mo_lag").as("mo_lag_list"),
  collect_list($"monthly_high".cast(DoubleType)).as("temp_high_list")
).select(
  $"station", arraysToMap($"mo_lag_list", $"temp_high_list").as("mo_high_map")
)

tempHighMapDF.show(truncate=false)

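// UDF to find the max of map values for keys less than or equal to x (0 for all);
// this redefines aggMapValues, replacing the sum version above.
// Filtering out keys > x (rather than mapping their values to 0) keeps the
// result correct even when all temperatures in range are below 0.
def aggMapValues = udf(
  (m: Map[Double, Double], x: Double) =>
    if (x > 0) m.filter{ case (k, _) => k <= x }.values.max else
      m.values.max
)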

val customTempHighDF = tempHighMapDF.
  withColumn( "high_prev_1mo", aggMapValues($"mo_high_map", lit(1)) ).
  withColumn( "high_prev_3mo", aggMapValues($"mo_high_map", lit(3)) ).
  withColumn( "high_prev_all", aggMapValues($"mo_high_map", lit(0)) ).
  select( $"station", $"high_prev_1mo", $"high_prev_3mo", $"high_prev_all" )

customTempHighDF.show
// +-------+-------------+-------------+-------------+
// |station|high_prev_1mo|high_prev_3mo|high_prev_all|
// +-------+-------------+-------------+-------------+
// |    115|         45.0|         65.0|         81.0|
// |    100|         59.0|         68.0|         80.0|
// +-------+-------------+-------------+-------------+
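Filtering excluded months out, rather than mapping their values to 0, matters once the aggregation is not a sum. A 0 would win a max over sub-zero temperatures, and it would always win a min over positive ones. Below is a generic plain-Scala sketch of the map aggregation with the combining operator passed in. The helper name aggUpTo is hypothetical. Returning an Option avoids the exception that .max throws on an empty collection.

```scala
// Generic sketch of the map aggregation behind the aggMapValues UDFs.
// aggUpTo is a hypothetical helper name, not a Spark API.
object GenericAggSketch {
  // Keep entries with key (months lagged) <= x (all entries if x <= 0),
  // then combine the remaining values with op; None if nothing qualifies.
  def aggUpTo(m: Map[Double, Double], x: Double)(op: (Double, Double) => Double): Option[Double] = {
    val selected = if (x > 0) m.filter { case (k, _) => k <= x } else m
    selected.values.reduceOption(op)
  }

  // Station 115's monthly highs keyed by months lagged (from the output above)
  val station115High: Map[Double, Double] =
    Map(1.0 -> 45.0, 2.0 -> 55.0, 3.0 -> 65.0, 4.0 -> 77.0, 5.0 -> 81.0, 6.0 -> 76.0)
}
```

Combining with math.max at x = 1.0, 3.0 and 0.0 gives Some(45.0), Some(65.0) and Some(81.0), as in station 115’s row above.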

I’ll leave creating the temperature low result table as a programming exercise for the readers. Note that rather than calculating temperature high and low separately, one could aggregate both of them together in some of the steps with little code change. For those who are up for a slightly more challenging exercise, both temperature high and low data can in fact be transformed together in every step of the way.

Scala On Spark – Cumulative Pivot Sum

In a couple of recent R&D projects, I was using Apache Spark rather extensively to address some data processing needs on Hadoop clusters. Although there is an abundance of big data processing platforms these days, it didn’t take long for me to settle on Spark. One of the main reasons is that the programming language for the R&D is Scala, which is what Spark itself is written in. In particular, Spark’s inherent support for functional programming and compositional transformations on immutable data enables high performance at scale as well as readability. Other main reasons are very much in line with some of the key factors contributing to Spark’s rising popularity.

I’m starting a mini blog series on Scala-on-Spark (SoS), with each blog post demonstrating a Scala programming example on Apache Spark. In the blog series, I’m going to illustrate how the functionality-rich SoS is able to resolve some non-trivial data processing problems with seemingly little effort. If nothing else, they are good brain-teasing programming exercises in Scala on Spark.

As the source data for the example, let’s consider a minuscule set of weather data stored in a DataFrame, which consists of the following columns:

  • Weather Station ID
  • Start Date of a half-month period
  • Temperature High (in Fahrenheit) over the period
  • Temperature Low (in Fahrenheit) over the period
  • Total Precipitation (in inches) over the period

Note that with a properly configured Spark cluster, the methods illustrated in the following example can be readily adapted to handle much more granular data at scale – e.g. down to sub-hourly weather data from tens of thousands of weather stations. It’s also worth mentioning that there can be other ways to solve the problems presented in the examples.

For illustration purposes, the following code snippets are executed in a Spark Shell. The first thing is to generate a DataFrame with the said columns of sample data, which will be used as the source data for this example and a couple of following ones.

import java.sql.Date

val weatherDataDF = Seq(
  (100, Date.valueOf("2017-07-01"), 75, 59, 0.0),
  (100, Date.valueOf("2017-07-16"), 77, 59, 0.5),
  (100, Date.valueOf("2017-08-01"), 80, 63, 1.0),
  (100, Date.valueOf("2017-08-16"), 78, 62, 1.0),
  (100, Date.valueOf("2017-09-01"), 74, 59, 0.0),
  (100, Date.valueOf("2017-09-16"), 72, 57, 0.0),
  (100, Date.valueOf("2017-10-01"), 68, 54, 0.0),
  (100, Date.valueOf("2017-10-16"), 66, 54, 0.0),
  (100, Date.valueOf("2017-11-01"), 64, 50, 0.5),
  (100, Date.valueOf("2017-11-16"), 61, 48, 1.0),
  (100, Date.valueOf("2017-12-01"), 59, 46, 2.0),
  (100, Date.valueOf("2017-12-16"), 57, 45, 1.5),
  (115, Date.valueOf("2017-07-01"), 76, 57, 0.0),
  (115, Date.valueOf("2017-07-16"), 76, 56, 1.0),
  (115, Date.valueOf("2017-08-01"), 78, 57, 0.0),
  (115, Date.valueOf("2017-08-16"), 81, 57, 0.0),
  (115, Date.valueOf("2017-09-01"), 77, 54, 0.0),
  (115, Date.valueOf("2017-09-16"), 72, 50, 0.0),
  (115, Date.valueOf("2017-10-01"), 65, 45, 0.0),
  (115, Date.valueOf("2017-10-16"), 59, 40, 1.5),
  (115, Date.valueOf("2017-11-01"), 55, 37, 1.0),
  (115, Date.valueOf("2017-11-16"), 52, 35, 2.0),
  (115, Date.valueOf("2017-12-01"), 45, 30, 3.0),
  (115, Date.valueOf("2017-12-16"), 41, 28, 1.5)
).toDF("station", "start_date", "temp_high", "temp_low", "total_precip")

In this first example, the goal is to generate a table of cumulative precipitation by weather station in month-by-month columns. By ‘cumulative sum’, we mean that the monthly precipitation is carried over from one month to the next (i.e. a running total). In other words, if July’s precipitation is 2 inches and August’s is 1 inch, the figure for August will be 3 inches. The result should look like the following table:

+-------+-------+-------+-------+-------+-------+-------+
|station|2017-07|2017-08|2017-09|2017-10|2017-11|2017-12|
+-------+-------+-------+-------+-------+-------+-------+
|    100|    0.5|    2.5|    2.5|    2.5|    4.0|    7.5|
|    115|    1.0|    1.0|    1.0|    2.5|    5.5|   10.0|
+-------+-------+-------+-------+-------+-------+-------+
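Outside Spark, the same running total is a one-liner with Scala’s scanLeft. This plain-Scala sketch takes station 100’s monthly precipitation as input: 0.5, 2.0, 0.0, 0.0, 1.5, 3.5 for July through December, per the pivot output further down. The object name is hypothetical.

```scala
// Plain-Scala sketch of a cumulative (running) sum using scanLeft.
object RunningTotalSketch {
  // scanLeft emits every intermediate sum; tail drops the 0.0 seed
  def runningTotal(xs: Seq[Double]): Seq[Double] =
    xs.scanLeft(0.0)(_ + _).tail

  // Station 100's monthly precipitation, 2017-07 through 2017-12
  val station100Monthly: Seq[Double] = Seq(0.5, 2.0, 0.0, 0.0, 1.5, 3.5)
}
```

runningTotal(station100Monthly) yields 0.5, 2.5, 2.5, 2.5, 4.0, 7.5, matching station 100’s row in the table above. In Spark, a window function over an ordered partition could produce similar running totals in rows. The pivot-then-fold approach here keeps the month-by-month column layout.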

First, we transform the original DataFrame to include an additional year-month column, followed by using Spark’s groupBy, pivot and agg methods to generate the pivot table.

import org.apache.spark.sql.functions._

val monthlyData = weatherDataDF.
  withColumn("year_mo", concat(
    year($"start_date"), lit("-"), lpad(month($"start_date"), 2, "0")
  )).
  groupBy("station").pivot("year_mo")
  
val monthlyPrecipDF = monthlyData.agg(sum($"total_precip"))

monthlyPrecipDF.show
// +-------+-------+-------+-------+-------+-------+-------+
// |station|2017-07|2017-08|2017-09|2017-10|2017-11|2017-12|
// +-------+-------+-------+-------+-------+-------+-------+
// |    115|    1.0|    0.0|    0.0|    1.5|    3.0|    4.5|
// |    100|    0.5|    2.0|    0.0|    0.0|    1.5|    3.5|
// +-------+-------+-------+-------+-------+-------+-------+

Next, we assemble a list of the year-month columns and traverse the list using method foldLeft, which is one of the most versatile Scala functions for custom iterative transformations. In this particular case, the data to be transformed by foldLeft is a tuple of (DataFrame, String). Normally, transforming the DataFrame alone would suffice, but in this case we need an additional value to address the running-total requirement.

The tuple’s first DataFrame-type element, with monthlyPrecipDF as its initial value, will be transformed using the binary operator function specified as foldLeft’s second argument (i.e. (acc, c) => …). The tuple’s second String-type element starts with the first year-month as its initial value. It carries the current month’s column name over to the next iteration, where that already cumulative column is added to the next month’s column. The end result is a (DataFrame, String) tuple successively transformed month by month, from which ._1 extracts the DataFrame.

val yearMonths = monthlyPrecipDF.columns.filter(_ != "station")

val cumulativePrecipDF = yearMonths.drop(1).
  foldLeft((monthlyPrecipDF, yearMonths.head))( (acc, c) =>
    ( acc._1.withColumn(c, col(acc._2) + col(c)), c )
)._1

cumulativePrecipDF.show
// +-------+-------+-------+-------+-------+-------+-------+
// |station|2017-07|2017-08|2017-09|2017-10|2017-11|2017-12|
// +-------+-------+-------+-------+-------+-------+-------+
// |    115|    1.0|    1.0|    1.0|    2.5|    5.5|   10.0|
// |    100|    0.5|    2.5|    2.5|    2.5|    4.0|    7.5|
// +-------+-------+-------+-------+-------+-------+-------+
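The (DataFrame, String) accumulator is easier to follow in a plain-Scala sketch of the same foldLeft. Here a single row is modeled as a Map from column name to value, an assumption standing in for the DataFrame. The second tuple element carries the previous month’s column name, just as yearMonths.head does initially above.

```scala
// Plain-Scala sketch of the foldLeft above; a Map stands in for one DataFrame row.
object FoldLeftSketch {
  def cumulate(row: Map[String, Double], cols: Seq[String]): Map[String, Double] =
    cols.drop(1).foldLeft((row, cols.head)) { case ((acc, prev), c) =>
      // Column c becomes the previous (already cumulative) value plus its own value;
      // c is then carried forward as the "previous" column name
      (acc.updated(c, acc(prev) + acc(c)), c)
    }._1

  val yearMonths: Seq[String] =
    Seq("2017-07", "2017-08", "2017-09", "2017-10", "2017-11", "2017-12")

  // Station 115's monthly precipitation (from monthlyPrecipDF above)
  val station115: Map[String, Double] =
    yearMonths.zip(Seq(1.0, 0.0, 0.0, 1.5, 3.0, 4.5)).toMap
}
```

For station 115, reading the result in yearMonths order gives 1.0, 1.0, 1.0, 2.5, 5.5, 10.0. This is the same as the cumulativePrecipDF row above.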

Similar pivot aggregations can be applied to temperature highs/lows as well, with method sum replaced with method max/min.

val monthlyHighDF = monthlyData.agg(max($"temp_high").as("high"))

monthlyHighDF.show
// +-------+-------+-------+-------+-------+-------+-------+
// |station|2017-07|2017-08|2017-09|2017-10|2017-11|2017-12|
// +-------+-------+-------+-------+-------+-------+-------+
// |    115|     76|     81|     77|     65|     55|     45|
// |    100|     77|     80|     74|     68|     64|     59|
// +-------+-------+-------+-------+-------+-------+-------+

val monthlyLowDF = monthlyData.agg(min($"temp_low").as("low"))

monthlyLowDF.show
// +-------+-------+-------+-------+-------+-------+-------+
// |station|2017-07|2017-08|2017-09|2017-10|2017-11|2017-12|
// +-------+-------+-------+-------+-------+-------+-------+
// |    115|     56|     57|     50|     40|     35|     28|
// |    100|     59|     62|     57|     54|     48|     45|
// +-------+-------+-------+-------+-------+-------+-------+

Finally, we compute cumulative temperature highs/lows like cumulative precipitation, by replacing the sum with an iterative max/min using Spark’s when-otherwise method.

val cumulativeHighDF = yearMonths.drop(1).
  foldLeft((monthlyHighDF, yearMonths.head))( (acc, c) =>
    ( acc._1.withColumn(c, when(col(acc._2) > col(c), col(acc._2)).
      otherwise(col(c))), c )
)._1

cumulativeHighDF.show
// +-------+-------+-------+-------+-------+-------+-------+
// |station|2017-07|2017-08|2017-09|2017-10|2017-11|2017-12|
// +-------+-------+-------+-------+-------+-------+-------+
// |    115|     76|     81|     81|     81|     81|     81|
// |    100|     77|     80|     80|     80|     80|     80|
// +-------+-------+-------+-------+-------+-------+-------+

val cumulativeLowDF = yearMonths.drop(1).
  foldLeft((monthlyLowDF, yearMonths.head))( (acc, c) =>
    ( acc._1.withColumn(c, when(col(acc._2) < col(c), col(acc._2)).
      otherwise(col(c))), c )
)._1

cumulativeLowDF.show
// +-------+-------+-------+-------+-------+-------+-------+
// |station|2017-07|2017-08|2017-09|2017-10|2017-11|2017-12|
// +-------+-------+-------+-------+-------+-------+-------+
// |    115|     56|     56|     50|     40|     35|     28|
// |    100|     59|     59|     57|     54|     48|     45|
// +-------+-------+-------+-------+-------+-------+-------+
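The when-otherwise expressions above simply keep the larger (or smaller) of the previous cumulative value and the current month’s value. Spark also ships built-in greatest and least column functions that could likely replace them, e.g. greatest(col(acc._2), col(c)). Their null handling may differ from when-otherwise, though. The plain-Scala sketch below shows the same running max/min logic on station 115’s monthly values. It assumes a non-empty input, and the object name is hypothetical.

```scala
// Plain-Scala sketch of running max/min: each month keeps the larger (smaller)
// of the previous cumulative value and its own value, like the when-otherwise logic.
object RunningExtremaSketch {
  // Seed with the first month; assumes xs is non-empty
  def running(xs: Seq[Int])(op: (Int, Int) => Int): Seq[Int] =
    xs.tail.scanLeft(xs.head)(op)

  // Station 115's monthly highs and lows, 2017-07 through 2017-12
  val station115Highs: Seq[Int] = Seq(76, 81, 77, 65, 55, 45)
  val station115Lows: Seq[Int]  = Seq(56, 57, 50, 40, 35, 28)
}
```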

Generic Merge Sort In Scala

Many software engineers may not need to explicitly deal with type parameterization or generic types in their day-to-day job, but it’s very likely that the libraries and frameworks they heavily use have already done their duty to ensure static type-safety via this parametric polymorphism feature.

In a statically typed functional programming language like Scala, this feature often needs to be used first-hand in order to create useful functions that ensure type-safety while keeping the code lean and versatile. Generics are clearly taken seriously in Scala’s language design. That, coupled with Scala’s implicit conversion, constitutes a signature feature of Scala. Given Scala’s love of “smileys”, a few of them are designated for the relevant functionalities.

Merge Sort

Merge Sort is a popular textbook sorting algorithm that I think also serves as a great brain-teasing programming exercise. I have an old blog post about implementing Merge Sort using Java Generics. In this post, I’m going to use Merge Sort again to illustrate Scala’s type parameterization.

A basic Merge Sort function for integer sorting recursively sorts the left and right halves of a partitioned list. It then combines them with a merge function, which merges two sorted lists into one sorted list. It might look something like the following:

  def mergeSort(ls: List[Int]): List[Int] = {
    def merge(l: List[Int], r: List[Int]): List[Int] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

A quick test …

scala> val li = List(9, 5, 16, 3, 4, 11, 8, 12)
li: List[Int] = List(9, 5, 16, 3, 4, 11, 8, 12)

scala> mergeSort(li)
res1: List[Int] = List(3, 4, 5, 8, 9, 11, 12, 16)
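One caveat with the merge above is that it is not tail-recursive. Each call must wait for the recursive merge to return before prepending the head, so merging very long lists could overflow the stack. Below is a sketch of a stack-safe variant. It accumulates merged elements in reverse and reverses once at the end, so merge becomes tail-recursive, which the @tailrec annotation checks at compile time. The object name TailRecMergeSort is hypothetical.

```scala
import scala.annotation.tailrec

// Sketch of a stack-safe integer Merge Sort with a tail-recursive merge.
object TailRecMergeSort {
  def mergeSort(ls: List[Int]): List[Int] = {
    @tailrec
    def merge(l: List[Int], r: List[Int], acc: List[Int]): List[Int] = (l, r) match {
      case (Nil, _) => acc reverse_::: r // acc.reverse ::: r
      case (_, Nil) => acc reverse_::: l // acc.reverse ::: l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead <= rHead) merge(lTail, r, lHead :: acc)
        else merge(l, rTail, rHead :: acc)
    }
    val n = ls.length / 2
    if (n == 0) ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b), Nil)
    }
  }
}
```

TailRecMergeSort.mergeSort(li) returns the same List(3, 4, 5, 8, 9, 11, 12, 16). The recursion depth of mergeSort itself only grows logarithmically with the list length.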

Contrary to Java Generics’ MyClass<T> notation, Scala’s generic types are in the form of MyClass[T]. Let’s generalize the integer Merge Sort as follows:

  def mergeSort[T](ls: List[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }
<console>:15: error: value < is not a member of type parameter T
               if (lHead < rHead)
                         ^

The compiler immediately complains about the ‘<’ comparison, since T might not be a type that supports ordering, in which case ‘<’ would make no sense. To generalize the Merge Sort function for any list type that supports ordering, we can supply an Ordering[T] parameter in a curried form (i.e. in a separate parameter list) as follows:

  // Generic Merge Sort using math.Ordering
  import math.Ordering

  def mergeSort[T](ls: List[T])(order: Ordering[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead))
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a)(order), mergeSort(b)(order))
    }
  }

Another quick test …

scala> val li = List(9, 5, 16, 3, 4, 11, 8, 12)
li: List[Int] = List(9, 5, 16, 3, 4, 11, 8, 12)

scala> mergeSort(li)(Ordering[Int])
res2: List[Int] = List(3, 4, 5, 8, 9, 11, 12, 16)

scala> val ls = List("banana", "pear", "orange", "grape", "apple", "strawberry", "guava", "peach")
ls: List[String] = List(banana, pear, orange, grape, apple, strawberry, guava, peach)

scala> mergeSort(ls)(Ordering[String])
res3: List[String] = List(apple, banana, grape, guava, orange, peach, pear, strawberry)
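Passing the Ordering explicitly also means the same function can sort by any criterion, e.g. descending integers or strings by length. The sketch below uses order.lteq instead of order.lt, so equal elements keep their input order (a stable sort). The version above takes the right-hand element on ties. The object name ExplicitOrderingSort is hypothetical.

```scala
import scala.math.Ordering

// Sketch: Merge Sort with an explicit Ordering parameter, stable on ties.
object ExplicitOrderingSort {
  def mergeSort[T](ls: List[T])(order: Ordering[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        // lteq: on ties, take the left (earlier) element first
        if (order.lteq(lHead, rHead)) lHead :: merge(lTail, r)
        else rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0) ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a)(order), mergeSort(b)(order))
    }
  }
}
```

mergeSort(li)(Ordering[Int].reverse) sorts in descending order. mergeSort(ls)(Ordering.by((s: String) => s.length)) sorts by string length, keeping ties such as grape, apple, guava and peach in their input order.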

That works well, but it’s cumbersome to have to supply the corresponding Ordering[T] for the list type. That’s where an implicit parameter can help:

  // Generic Merge Sort using implicit math.Ordering parameter
  import math.Ordering

  def mergeSort[T](ls: List[T])(implicit order: Ordering[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead))
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

Testing again …

scala> mergeSort(li)
res3: List[Int] = List(3, 4, 5, 8, 9, 11, 12, 16)

scala> mergeSort(ls)
res4: List[String] = List(apple, banana, grape, guava, orange, peach, pear, strawberry)
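The implicit parameter also picks up orderings for user-defined types. This works as long as an implicit Ordering is in implicit scope, e.g. in the type’s companion object. In this sketch, WeatherStation and its byId ordering are hypothetical, for illustration only.

```scala
// Sketch: an implicit Ordering in a companion object is found automatically.
object ImplicitOrderingDemo {
  def mergeSort[T](ls: List[T])(implicit order: Ordering[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead)) lHead :: merge(lTail, r)
        else rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0) ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

  // Hypothetical type for illustration only
  case class WeatherStation(id: Int, name: String)

  object WeatherStation {
    // In implicit scope wherever an Ordering[WeatherStation] is required
    implicit val byId: Ordering[WeatherStation] =
      Ordering.by((s: WeatherStation) => s.id)
  }
}
```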

Note that the ‘if (lHead < rHead)’ condition has been replaced with ‘if (order.lt(lHead, rHead))’. That’s because math.Ordering defines its own less-than method for generic types. Let’s dig a little deeper into how it works. Scala’s math.Ordering extends Java’s Comparator interface and declares an abstract method compare(x: T, y: T). Its companion object provides implicit implementations for all the common types: Int, Long, Float, Double, String, etc. Based on compare, it then provides the lt(x: T, y: T), gt(x: T, y: T), …, methods that perform the less-than, greater-than, etc. comparisons for any type with an Ordering.

The following are highlights of math.Ordering’s partial source code:

// Scala's math.Ordering source code highlight
import java.util.Comparator
...

trait Ordering[T] extends Comparator[T] with PartialOrdering[T] with Serializable {
  ...
  def compare(x: T, y: T): Int
  ...
  override def lt(x: T, y: T): Boolean = compare(x, y) < 0
  override def gt(x: T, y: T): Boolean = compare(x, y) > 0
  ...
  class Ops(lhs: T) {
    def <(rhs: T) = lt(lhs, rhs)
    def >(rhs: T) = gt(lhs, rhs)
    ...
  }
  implicit def mkOrderingOps(lhs: T): Ops = new Ops(lhs)
  ...
}

...

object Ordering extends LowPriorityOrderingImplicits {
  def apply[T](implicit ord: Ordering[T]) = ord
  ...
  trait IntOrdering extends Ordering[Int] {
    def compare(x: Int, y: Int) =
      if (x < y) -1
      else if (x == y) 0
      else 1
  }
  implicit object Int extends IntOrdering
  ...
  trait StringOrdering extends Ordering[String] {
    def compare(x: String, y: String) = x.compareTo(y)
  }
  implicit object String extends StringOrdering
  ...
}

Context Bound

Scala provides syntactic sugar called Context Bound for this common type-class pattern of passing in an implicit value:

  // Implicit value passed in as implicit parameter
  def someFunction[T](x: SomeClass[T])(implicit imp: AnotherClass[T]): WhateverClass[T] = {
    ...
  }

With the context bound syntactic sugar, it becomes:

  // Context Bound
  def someFunction[T : AnotherClass](x: SomeClass[T]): WhateverClass[T] = {
    ...
  }

The mergeSort function using context bound looks as follows:

  // Generic Merge Sort using Context Bound
  def mergeSort[T : Ordering](ls: List[T]): List[T] = {
    val order = implicitly[Ordering[T]]
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead))
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

Note that since the Ordering instance is no longer passed in with a parameter name, ‘implicitly[Ordering[T]]’ is used to retrieve it for access to its methods, such as lt.

Scala’s math.Ordered versus math.Ordering

One noteworthy thing about math.Ordering is that it does not directly provide the comparison operators ‘<’, ‘>’, etc. on values of type T. That’s why method lt(x: T, y: T) is used instead of the ‘<’ operator in mergeSort. To use comparison operators like ‘<’, one would need to import order.mkOrderingOps (or order._) within the mergeSort function. In math.Ordering, the comparison operators ‘<’, ‘>’, etc. are all defined in the inner class Ops, which is instantiated by the implicit conversion method mkOrderingOps.
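Here is a sketch of that approach. It combines the context bound with import order.mkOrderingOps so the merge can use ‘<’ again. It relies on the mkOrderingOps member shown in the source highlights above. The object name OrderingOpsSort is hypothetical.

```scala
// Sketch: context-bound Merge Sort that brings Ordering's operators into scope.
object OrderingOpsSort {
  def mergeSort[T: Ordering](ls: List[T]): List[T] = {
    val order = implicitly[Ordering[T]]
    import order.mkOrderingOps // implicit T => Ops, which defines <, >, <=, >=

    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead) lHead :: merge(lTail, r) // resolved via mkOrderingOps
        else rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0) ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }
}
```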

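Scala’s math.Ordered extends Java’s Comparable interface (instead of Comparator) and implements method compareTo(that: A) in terms of its own abstract method compare(that: A). Its companion object provides an implicit conversion, orderingToOrdered. It derives an Ordered[T] from math.Ordering’s compare(x: T, y: T) via an implicit Ordering[T] parameter. One nice thing about math.Ordered is that it defines the comparison operators ‘<’, ‘>’, ‘<=’ and ‘>=’ directly.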

The following highlights partial source code of math.Ordered:

// Scala's math.Ordered source code highlight
trait Ordered[A] extends Any with java.lang.Comparable[A] {
  ...
  def compare(that: A): Int
  def <  (that: A): Boolean = (this compare that) <  0
  def >  (that: A): Boolean = (this compare that) >  0
  def <= (that: A): Boolean = (this compare that) <= 0
  def >= (that: A): Boolean = (this compare that) >= 0
  def compareTo(that: A): Int = compare(that)
}

object Ordered {
  implicit def orderingToOrdered[T](x: T)(implicit ord: Ordering[T]): Ordered[T] = new Ordered[T] {
      def compare(that: T): Int = ord.compare(x, that)
  }
}
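Besides the implicit conversion route, a class can mix in math.Ordered directly by implementing compare. It then gets ‘<’, ‘>’, ‘<=’ and ‘>=’ from the trait. Here is a small sketch with a hypothetical Version class:

```scala
// Sketch: a hypothetical Version class mixing in math.Ordered directly.
object OrderedDemo {
  case class Version(major: Int, minor: Int) extends Ordered[Version] {
    // Compare by major number first, then by minor number
    def compare(that: Version): Int =
      if (major != that.major) Integer.compare(major, that.major)
      else Integer.compare(minor, that.minor)
  }
}
```

With this, Version(1, 2) < Version(1, 10) is true without any implicit Ordering in sight, since ‘<’ comes from Ordered itself.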

Using math.Ordered, an implicit function, implicit orderer: T => Ordered[T] (as opposed to an implicit value when using math.Ordering), is passed to the mergeSort function as a curried parameter. As illustrated in a previous blog post, it’s an implicit conversion rule for the compiler to fall back to when it encounters a problem associated with type T. In this case, the problem is the ‘<’ method that T itself lacks.

Below is a version of generic Merge Sort using math.Ordered:

  // Generic Merge Sort using implicit math.Ordered conversion
  import math.Ordered

  def mergeSort[T](ls: List[T])(implicit orderer: T => Ordered[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

scala> mergeSort(li)
res5: List[Int] = List(3, 4, 5, 8, 9, 11, 12, 16)

scala> mergeSort(ls)
res6: List[String] = List(apple, banana, grape, guava, orange, peach, pear, strawberry)

View Bound

A couple of notes:

  1. The implicit method ‘implicit orderer: T => Ordered[T]’ is passed into the mergeSort function also as an implicit parameter.
  2. Function mergeSort has a signature of the following common form:
  // Implicit method passed in as implicit parameter
  def someFunction[T](x: SomeClass[T])(implicit imp: T => AnotherClass[T]): WhateverClass[T] = {
    ...
  }

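This pattern of passing an implicit method in as an implicit parameter is so common that it’s given the name View Bound and awarded a designated smiley, ‘<%’. (Note that view bounds have since been deprecated in Scala 2 and are not supported in Scala 3, where the equivalent implicit-parameter form above can be used instead.) Using a view bound, it can be expressed as: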

  // View Bound
  def someFunction[T <% AnotherClass[T]](x: SomeClass[T]): WhateverClass[T] = {
    ...
  }

Applied to the mergeSort function, it gives a slightly leaner look:

  // Generic Merge Sort using view bound
  import math.Ordered

  def mergeSort[T <% Ordered[T]](ls: List[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

As a side note, while the view bound looks like the other smiley ‘<:’ (Upper Bound), they represent very different things. An upper bound is commonly seen in the following form:

  // Upper Bound
  def someFunction[T <: S](x: T): R = {
    ...
  }

This means someFunction only accepts an input parameter of a type T that is a subtype of (or the same as) type S. While we’re at it, a Lower Bound represented by the ‘>:’ smiley in the form of [T >: S] means T can only be a supertype of (or the same as) type S.
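Here is a small sketch of both bounds, using a hypothetical Fruit hierarchy (Fruit, Apple and Pear are illustrative names only):

```scala
// Sketch of upper and lower type bounds with a hypothetical Fruit hierarchy.
object BoundsSketch {
  class Fruit(val name: String)
  class Apple extends Fruit("apple")
  class Pear extends Fruit("pear")

  // Upper bound: T must be Fruit or a subtype of Fruit,
  // so Fruit's members (such as name) are available on x
  def nameOf[T <: Fruit](x: T): String = x.name

  // Lower bound: B must be Apple or a supertype of Apple,
  // so an Apple can always be prepended to a List[B]
  def prependApple[B >: Apple](xs: List[B]): List[B] = new Apple :: xs
}
```

prependApple(List(new Pear)) compiles because B can be inferred as Fruit, a supertype of both Apple and Pear.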