
Generic Merge Sort In Scala

Many software engineers may not need to deal with type parameterization or generic types explicitly in their day-to-day work, but the libraries and frameworks they rely on heavily have very likely already done their part to ensure static type safety via parametric polymorphism.

In a statically typed functional programming language like Scala, this feature often needs to be used first-hand to create useful functions that ensure type safety while keeping the code lean and versatile. Generics are evidently taken seriously in Scala’s language design. That, coupled with Scala’s implicit conversions, constitutes one of Scala’s signature features. Given Scala’s love of “smileys”, a few of them are designated for the relevant functionality.

Merge Sort

Merge Sort is a popular textbook sorting algorithm that I think also serves as a great brain-teasing programming exercise. I have an old blog post about implementing Merge Sort using Java Generics. In this post, I’m going to use Merge Sort again to illustrate Scala’s type parameterization.

Using an inner merge function that merges two already-sorted lists, and having mergeSort recursively sort the left and right halves of the split list, a basic Merge Sort function for integers might look something like the following:

  def mergeSort(ls: List[Int]): List[Int] = {
    def merge(l: List[Int], r: List[Int]): List[Int] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

A quick test …

scala> val li = List(9, 5, 16, 3, 4, 11, 8, 12)
li: List[Int] = List(9, 5, 16, 3, 4, 11, 8, 12)

scala> mergeSort(li)
res1: List[Int] = List(3, 4, 5, 8, 9, 11, 12, 16)

Unlike the MyClass<T> notation of Java Generics, Scala’s generic types take the form MyClass[T]. Let’s try to generalize the integer Merge Sort as follows:

  def mergeSort[T](ls: List[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

<console>:15: error: value < is not a member of type parameter T
               if (lHead < rHead)
                         ^

The compiler immediately complains about the ‘<’ comparison, since T might be a type that doesn’t support ordering, in which case ‘<’ wouldn’t make any sense. To generalize the Merge Sort function to any list element type that supports ordering, we can supply a math.Ordering[T] as an additional parameter in curried form:

  // Generic Merge Sort using math.Ordering
  import math.Ordering

  def mergeSort[T](ls: List[T])(order: Ordering[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead))
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a)(order), mergeSort(b)(order))
    }
  }

Another quick test …

scala> val li = List(9, 5, 16, 3, 4, 11, 8, 12)
li: List[Int] = List(9, 5, 16, 3, 4, 11, 8, 12)

scala> mergeSort(li)(Ordering[Int])
res2: List[Int] = List(3, 4, 5, 8, 9, 11, 12, 16)

scala> val ls = List("banana", "pear", "orange", "grape", "apple", "strawberry", "guava", "peach")
ls: List[String] = List(banana, pear, orange, grape, apple, strawberry, guava, peach)

scala> mergeSort(ls)(Ordering[String])
res3: List[String] = List(apple, banana, grape, guava, orange, peach, pear, strawberry)

That works well, but it’s cumbersome to have to supply the corresponding Ordering[T] explicitly for each list type. That’s where an implicit parameter can help:

  // Generic Merge Sort using implicit math.Ordering parameter
  import math.Ordering

  def mergeSort[T](ls: List[T])(implicit order: Ordering[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead))
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

Testing again …

scala> mergeSort(li)
res3: List[Int] = List(3, 4, 5, 8, 9, 11, 12, 16)

scala> mergeSort(ls)
res4: List[String] = List(apple, banana, grape, guava, orange, peach, pear, strawberry)

Note that the ‘if (lHead < rHead)’ condition is now replaced with ‘if (order.lt(lHead, rHead))’. That’s because math.Ordering defines its own less-than method for generic types. Let’s dig a little deeper into how it works. Scala’s math.Ordering extends Java’s Comparator interface and declares the abstract method compare(x: T, y: T), which its companion object implements for all the common types: Int, Long, Float, Double, String, etc. Based on compare, the trait then provides the methods lt(x: T, y: T), gt(x: T, y: T), …, which perform the less-than, greater-than, … comparisons for any type that has an Ordering.
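To see how the derived methods hang off compare, here is a minimal sketch, assuming a hypothetical Version type (not from any library). The Ordering below implements only compare, yet it gets lt, gt and max from the trait.

```scala
import scala.math.Ordering

// Hypothetical example type, used only for illustration.
case class Version(major: Int, minor: Int)

// Only compare is implemented here; lt, gt, lteq, gteq, max, min, etc.
// are inherited from Ordering and are all defined in terms of compare.
object VersionOrdering extends Ordering[Version] {
  def compare(x: Version, y: Version): Int =
    if (x.major != y.major) Integer.compare(x.major, y.major)
    else Integer.compare(x.minor, y.minor)
}

object VersionOrderingDemo {
  val older = Version(1, 9)
  val newer = Version(2, 0)

  val isLess: Boolean = VersionOrdering.lt(older, newer)    // true
  val isGreater: Boolean = VersionOrdering.gt(older, newer) // false
  val latest: Version = VersionOrdering.max(older, newer)   // Version(2,0)
}
```

Since VersionOrdering is a regular Ordering[Version], it could also be passed as the second argument of the curried mergeSort above, just like Ordering[Int].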

Below are highlights from math.Ordering’s source code:

// Scala's math.Ordering source code highlight
import java.util.Comparator
...

trait Ordering[T] extends Comparator[T] with PartialOrdering[T] with Serializable {
  ...
  def compare(x: T, y: T): Int
  ...
  override def lt(x: T, y: T): Boolean = compare(x, y) < 0
  override def gt(x: T, y: T): Boolean = compare(x, y) > 0
  ...
  class Ops(lhs: T) {
    def <(rhs: T) = lt(lhs, rhs)
    def >(rhs: T) = gt(lhs, rhs)
    ...
  }
  implicit def mkOrderingOps(lhs: T): Ops = new Ops(lhs)
  ...
}

...

object Ordering extends LowPriorityOrderingImplicits {
  def apply[T](implicit ord: Ordering[T]) = ord
  ...
  trait IntOrdering extends Ordering[Int] {
    def compare(x: Int, y: Int) =
      if (x < y) -1
      else if (x == y) 0
      else 1
  }
  implicit object Int extends IntOrdering
  ...
  trait StringOrdering extends Ordering[String] {
    def compare(x: String, y: String) = x.compareTo(y)
  }
  implicit object String extends StringOrdering
  ...
}
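The implicit objects in the companion object are what Ordering[Int], Ordering[String] and the implicit parameter of mergeSort end up picking up. A user-defined type can follow the same pattern by putting an implicit instance in its own companion object, which belongs to the implicit scope of that type. Below is a sketch under that assumption, using a hypothetical Task type and a hypothetical instance named byPriority, built with the standard Ordering.by:

```scala
import scala.math.Ordering

// Hypothetical user-defined type, for illustration only.
case class Task(priority: Int, name: String)

object Task {
  // An implicit instance in the companion object is found by implicit search
  // for Ordering[Task], much like Ordering's own implicit object Int and
  // implicit object String are found for Ordering[Int] and Ordering[String].
  implicit val byPriority: Ordering[Task] = Ordering.by((t: Task) => t.priority)
}

object ImplicitOrderingDemo {
  // Same shape as the implicit-parameter mergeSort in this post.
  def mergeSort[T](ls: List[T])(implicit order: Ordering[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead)) lHead :: merge(lTail, r)
        else rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0) ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

  val tasks = List(Task(3, "deploy"), Task(1, "test"), Task(2, "review"))
  // No Ordering argument needed: the compiler finds Task.byPriority.
  val sortedTasks: List[Task] = mergeSort(tasks)
  // Ordering.apply simply hands back whichever implicit instance is found.
  val summoned: Ordering[Task] = Ordering[Task]
}
```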

Context Bound

Scala provides syntactic sugar called a Context Bound for this common typeclass pattern of passing in an implicit value:

  // Implicit value passed in as implicit parameter
  def someFunction[T](x: SomeClass[T])(implicit imp: AnotherClass[T]): WhateverClass[T] = {
    ...
  }

With the context bound syntactic sugar, it becomes:

  // Context Bound
  def someFunction[T : AnotherClass](x: SomeClass[T]): WhateverClass[T] = {
    ...
  }

The mergeSort function using context bound looks as follows:

  // Generic Merge Sort using Context Bound
  def mergeSort[T : Ordering](ls: List[T]): List[T] = {
    val order = implicitly[Ordering[T]]
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead))
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

Note that ‘implicitly[Ordering[T]]’ is needed to access the math.Ordering methods, since the Ordering instance is no longer passed in under a parameter name.
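Since the context bound is only syntactic sugar, the two hypothetical helpers below (largest1 and largest2, illustrative names only) should behave the same way. A third one, largest3, hints at why the recursive mergeSort(a) calls above need no extra argument: the anonymous Ordering[T] is itself in implicit scope and gets forwarded automatically.

```scala
import scala.math.Ordering

object ContextBoundDemo {
  // Explicitly named implicit parameter. xs is assumed to be non-empty.
  def largest1[T](xs: List[T])(implicit ord: Ordering[T]): T =
    xs.reduceLeft((a, b) => if (ord.gteq(a, b)) a else b)

  // Context bound: the Ordering[T] parameter still exists but has no name,
  // so implicitly[Ordering[T]] is used to retrieve it.
  def largest2[T: Ordering](xs: List[T]): T = {
    val ord = implicitly[Ordering[T]]
    xs.reduceLeft((a, b) => if (ord.gteq(a, b)) a else b)
  }

  // The anonymous evidence is forwarded to largest1's implicit parameter.
  def largest3[T: Ordering](xs: List[T]): T = largest1(xs)
}
```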

Scala’s math.Ordered versus math.Ordering

One noteworthy thing about math.Ordering is that it does not make comparison operators such as ‘<’ and ‘>’ directly available on values of type T, which is why the method lt(x: T, y: T) is used in mergeSort in place of the ‘<’ operator. To use comparison operators like ‘<’, one would need to import order.mkOrderingOps (or order._) within the mergeSort function. That’s because in math.Ordering, the comparison operators ‘<’, ‘>’, etc, are all defined in the inner class Ops, which is instantiated by the implicit conversion method mkOrderingOps.
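For instance, here is a sketch of the implicit-parameter mergeSort with ‘<’ restored by importing mkOrderingOps from the Ordering instance. The object name OrderingOpsDemo is just a hypothetical wrapper. The snippet assumes a Scala 2.12/2.13 standard library, where mkOrderingOps is an implicit method of Ordering; the name of the wrapper class it returns differs between versions, but the snippet doesn’t depend on it.

```scala
import scala.math.Ordering

object OrderingOpsDemo {
  def mergeSort[T](ls: List[T])(implicit order: Ordering[T]): List[T] = {
    // Brings the implicit conversion T => (wrapper with <, >, <=, >=) into scope.
    // `import order._` would also work, but imports every member of order.
    import order.mkOrderingOps
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        // lHead is wrapped by mkOrderingOps, whose '<' delegates to order.lt
        if (lHead < rHead) lHead :: merge(lTail, r)
        else rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0) ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }
}
```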

Scala’s math.Ordered extends Java’s Comparable interface (instead of Comparator) and implements the method compareTo(that: A) in terms of its own abstract compare(that: A). Its companion object can also derive an Ordered[T] from math.Ordering’s compare(x: T, y: T) via an implicit parameter. One nice thing about math.Ordered is that it defines the comparison operators ‘<’, ‘>’, ‘<=’ and ‘>=’ directly.

The following highlights partial source code of math.Ordered:

// Scala's math.Ordered source code highlight
trait Ordered[A] extends Any with java.lang.Comparable[A] {
  ...
  def compare(that: A): Int
  def <  (that: A): Boolean = (this compare that) <  0
  def >  (that: A): Boolean = (this compare that) >  0
  def <= (that: A): Boolean = (this compare that) <= 0
  def >= (that: A): Boolean = (this compare that) >= 0
  def compareTo(that: A): Int = compare(that)
}

object Ordered {
  implicit def orderingToOrdered[T](x: T)(implicit ord: Ordering[T]): Ordered[T] = new Ordered[T] {
      def compare(that: T): Int = ord.compare(x, that)
  }
}
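The following is a short sketch of both routes to the Ordered operators. It assumes a hypothetical Money type that mixes in Ordered directly by implementing compare. It also shows a plain String being wrapped by orderingToOrdered, which uses the implicit Ordering[String].

```scala
import scala.math.Ordered

// Hypothetical type that mixes in Ordered: implementing compare is enough
// to get <, >, <=, >= and Comparable.compareTo.
case class Money(cents: Long) extends Ordered[Money] {
  def compare(that: Money): Int = java.lang.Long.compare(cents, that.cents)
}

object OrderedDemo {
  val cheaper: Boolean = Money(199) < Money(250)  // true
  val notMore: Boolean = Money(250) <= Money(250) // true

  // orderingToOrdered wraps any value that has an implicit Ordering,
  // here the implicit Ordering[String], giving it the Ordered operators.
  val viaOrdering: Boolean = Ordered.orderingToOrdered("apple") < "banana" // true
}
```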

With math.Ordered, an implicit function, implicit orderer: T => Ordered[T] (as opposed to an implicit value when using math.Ordering), is passed to the mergeSort function as a curried parameter. As illustrated in a previous blog post, it’s an implicit conversion that the compiler falls back on when it encounters a problem associated with type T, such as ‘<’ not being a member of T.

Below is a version of generic Merge Sort using math.Ordered:

  // Generic Merge Sort using implicit math.Ordered conversion
  import math.Ordered

  def mergeSort[T](ls: List[T])(implicit orderer: T => Ordered[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

scala> mergeSort(li)
res5: List[Int] = List(3, 4, 5, 8, 9, 11, 12, 16)

scala> mergeSort(ls)
res6: List[String] = List(apple, banana, grape, guava, orange, peach, pear, strawberry)

View Bound

A couple of notes:

  1. The implicit function ‘implicit orderer: T => Ordered[T]’ is passed into the mergeSort function as an implicit parameter, just like the Ordering[T] value was.
  2. Function mergeSort has a signature of the following common form:
  // Implicit method passed in as implicit parameter
  def someFunction[T](x: SomeClass[T])(implicit imp: T => AnotherClass[T]): WhateverClass[T] = {
    ...
  }

This pattern of an implicit function passed in as an implicit parameter is so common that it’s given its own name, View Bound, and awarded a designated smiley, ‘<%’. Using a view bound, it can be expressed as:

  // View Bound
  def someFunction[T <% AnotherClass[T]](x: SomeClass[T]): WhateverClass[T] = {
    ...
  }

Applied to the mergeSort function, it gives a slightly leaner and meaner look:

  // Generic Merge Sort using view bound
  import math.Ordered

  def mergeSort[T <% Ordered[T]](ls: List[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

As a side note, while the view bound looks like the other smiley ‘<:’ (Upper Bound), the two represent very different things. An upper bound is commonly seen in the following form:

  // Upper Bound
  def someFunction[T <: S](x: T): R = {
    ...
  }

This means someFunction accepts only an argument whose type T is a subtype of (or the same as) type S. While we’re at it, a Lower Bound, represented by the ‘>:’ smiley in the form [T >: S], means that the type parameter T must be a supertype of (or the same as) type S.
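A small sketch can make the difference between the two bounds concrete. The Animal, Dog and Cat classes and the helper names below are hypothetical. The lower bound constrains the type parameter, not the argument value. Passing a Cat to prependTo still compiles, because the compiler widens B to a common supertype of Cat and Dog.

```scala
// Hypothetical class hierarchy for illustrating type bounds.
abstract class Animal { def name: String }
case class Dog(name: String) extends Animal
case class Cat(name: String) extends Animal

object BoundsDemo {
  // Upper bound: T must be Animal or a subtype, so .name is available,
  // and the precise type T (e.g. Dog) is preserved in the result.
  def longestNamed[T <: Animal](xs: List[T]): T = xs.maxBy(_.name.length)

  // Lower bound: B must be Dog or a supertype of Dog; adding a Cat to a
  // List[Dog] therefore yields a list of a wider element type.
  def prependTo[B >: Dog](x: B, dogs: List[Dog]): List[B] = x :: dogs

  val dogs: List[Dog] = List(Dog("Rex"), Dog("Fido"))
  val topDog: Dog = longestNamed(dogs)                   // result type is Dog
  val mixed: List[Animal] = prependTo(Cat("Tom"), dogs)  // B widened to Animal
}
```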