Monthly Archives: August 2016

Generic Merge Sort In Scala

Many software engineers may not need to deal explicitly with type parameterization or generic types in their day-to-day job, but the libraries and frameworks they rely on heavily have very likely already done the work of ensuring static type-safety via parametric polymorphism.

In a statically typed functional programming language like Scala, this feature often needs to be used first-hand in order to create useful functions that ensure type-safety while keeping the code lean and versatile. Generics are evidently taken seriously in Scala’s language design. That, coupled with Scala’s implicit conversion, constitutes one of Scala’s signature features. Given Scala’s love of “smileys”, a few of them are designated for the relevant functionalities.

Merge Sort

Merge Sort is a popular textbook sorting algorithm that I think also serves as a great brain-teasing programming exercise. I have an old blog post about implementing Merge Sort using Java Generics. In this post, I’m going to use Merge Sort again to illustrate Scala’s type parameterization.

With a nested merge function that merges two sorted lists, and recursive calls that merge-sort the left and right halves of the partitioned list, a basic Merge Sort function for integer sorting might look something like the following:

  def mergeSort(ls: List[Int]): List[Int] = {
    def merge(l: List[Int], r: List[Int]): List[Int] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

A quick test …

scala> val li = List(9, 5, 16, 3, 4, 11, 8, 12)
li: List[Int] = List(9, 5, 16, 3, 4, 11, 8, 12)

scala> mergeSort(li)
res1: List[Int] = List(3, 4, 5, 8, 9, 11, 12, 16)

Contrary to Java Generics’ MyClass<T> notation, Scala’s generic types are in the form of MyClass[T]. Let’s generalize the integer Merge Sort as follows:

  def mergeSort[T](ls: List[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

<console>:15: error: value < is not a member of type parameter T
               if (lHead < rHead)
                         ^

The compiler immediately complains about the ‘<’ comparison, since T might not be a type that supports ordering, in which case ‘<’ makes no sense. To generalize the Merge Sort function for any list type that supports ordering, we can supply a math.Ordering[T] as a second parameter in curried form as follows:

  // Generic Merge Sort using math.Ordering
  import math.Ordering

  def mergeSort[T](ls: List[T])(order: Ordering[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead))
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a)(order), mergeSort(b)(order))
    }
  }

Another quick test …

scala> val li = List(9, 5, 16, 3, 4, 11, 8, 12)
li: List[Int] = List(9, 5, 16, 3, 4, 11, 8, 12)

scala> mergeSort(li)(Ordering[Int])
res2: List[Int] = List(3, 4, 5, 8, 9, 11, 12, 16)

scala> val ls = List("banana", "pear", "orange", "grape", "apple", "strawberry", "guava", "peach")
ls: List[String] = List(banana, pear, orange, grape, apple, strawberry, guava, peach)

scala> mergeSort(ls)(Ordering[String])
res3: List[String] = List(apple, banana, grape, guava, orange, peach, pear, strawberry)
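
Since the ordering is an ordinary argument at this point, the caller is free to pass something other than the default one. Below is a minimal sketch of that idea. The wrapper object ExplicitOrderingDemo and the two sample values are names I made up for illustration; Ordering[Int].reverse and Ordering.by come from the standard library.

```scala
object ExplicitOrderingDemo {
  // Same curried mergeSort as above: the Ordering is passed explicitly
  def mergeSort[T](ls: List[T])(order: Ordering[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead)) lHead :: merge(lTail, r)
        else rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0) ls
    else {
      val (a, b) = ls.splitAt(n)
      merge(mergeSort(a)(order), mergeSort(b)(order))
    }
  }

  // Descending integers, using the reversed default ordering
  val descending: List[Int] =
    mergeSort(List(9, 5, 16, 3))(Ordering[Int].reverse)

  // Strings ordered by length instead of alphabetically
  val byLength: List[String] =
    mergeSort(List("banana", "fig", "pear"))(Ordering.by((s: String) => s.length))
}
```

Here descending evaluates to List(16, 9, 5, 3) and byLength to List(fig, pear, banana).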

That works well, but it’s cumbersome that one needs to supply the corresponding Ordering[T] for the list type. That’s where an implicit parameter can help:

  // Generic Merge Sort using implicit math.Ordering parameter
  import math.Ordering

  def mergeSort[T](ls: List[T])(implicit order: Ordering[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead))
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

Testing again …

scala> mergeSort(li)
res3: List[Int] = List(3, 4, 5, 8, 9, 11, 12, 16)

scala> mergeSort(ls)
res4: List[String] = List(apple, banana, grape, guava, orange, peach, pear, strawberry)

Note that the ‘if (lHead < rHead)’ condition is now replaced with ‘if (order.lt(lHead, rHead))’. That’s because math.Ordering defines its own less-than method for generic types. Let’s dig a little deeper into how it works. Scala’s math.Ordering extends Java’s Comparator interface, and its companion object provides implicit instances that implement method compare(x: T, y: T) for all the common types: Int, Long, Float, Double, String, etc. The trait then derives methods such as lt(x: T, y: T) and gt(x: T, y: T) from compare, so the less-than and greater-than comparisons work for any type that has an Ordering.

The following are highlights of math.Ordering’s partial source code:

// Scala's math.Ordering source code highlight
import java.util.Comparator
...

trait Ordering[T] extends Comparator[T] with PartialOrdering[T] with Serializable {
  ...
  def compare(x: T, y: T): Int
  ...
  override def lt(x: T, y: T): Boolean = compare(x, y) < 0
  override def gt(x: T, y: T): Boolean = compare(x, y) > 0
  ...
  class Ops(lhs: T) {
    def <(rhs: T) = lt(lhs, rhs)
    def >(rhs: T) = gt(lhs, rhs)
    ...
  }
  implicit def mkOrderingOps(lhs: T): Ops = new Ops(lhs)
  ...
}

...

object Ordering extends LowPriorityOrderingImplicits {
  def apply[T](implicit ord: Ordering[T]) = ord
  ...
  trait IntOrdering extends Ordering[Int] {
    def compare(x: Int, y: Int) =
      if (x < y) -1
      else if (x == y) 0
      else 1
  }
  implicit object Int extends IntOrdering
  ...
  trait StringOrdering extends Ordering[String] {
    def compare(x: String, y: String) = x.compareTo(y)
  }
  implicit object String extends StringOrdering
  ...
}
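
The same mechanism extends to our own types: implement compare once, mark the instance implicit, and the implicit-parameter version of mergeSort picks it up. The following is only a sketch under my own assumptions. Fruit, FruitByWeight and the wrapper object CustomOrderingDemo are hypothetical names, not anything from the library.

```scala
object CustomOrderingDemo {
  // Hypothetical type used only for illustration
  final case class Fruit(name: String, weight: Int)

  // Only compare needs to be implemented; lt, gt, etc. are derived by the Ordering trait
  implicit object FruitByWeight extends Ordering[Fruit] {
    def compare(x: Fruit, y: Fruit): Int = Integer.compare(x.weight, y.weight)
  }

  // The implicit-parameter version of mergeSort from above
  def mergeSort[T](ls: List[T])(implicit order: Ordering[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead)) lHead :: merge(lTail, r)
        else rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0) ls
    else {
      val (a, b) = ls.splitAt(n)
      merge(mergeSort(a), mergeSort(b))
    }
  }

  // No Ordering is passed here; the compiler finds FruitByWeight in scope
  val sorted: List[Fruit] =
    mergeSort(List(Fruit("pear", 180), Fruit("grape", 5), Fruit("apple", 150)))
}
```

With these weights, sorted lists grape first, then apple, then pear.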

Context Bound

Scala provides syntactic sugar called Context Bound for the typeclass pattern, which captures this common pattern of passing in an implicit value:

  // Implicit value passed in as implicit parameter
  def someFunction[T](x: SomeClass[T])(implicit imp: AnotherClass[T]): WhateverClass[T] = {
    ...
  }

With the context bound syntactic sugar, it becomes:

  // Context Bound
  def someFunction[T : AnotherClass](x: SomeClass[T]): WhateverClass[T] = {
    ...
  }

The mergeSort function using context bound looks as follows:

  // Generic Merge Sort using Context Bound
  def mergeSort[T : Ordering](ls: List[T]): List[T] = {
    val order = implicitly[Ordering[T]]
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead))
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

Note that ‘implicitly[Ordering[T]]’ is there to retrieve the implicit Ordering[T] instance, which is no longer passed in with a parameter name, so that its methods can be called.
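
The context bound plus ‘implicitly’ combination is not specific to sorting. As a small sketch, here is a function that finds the largest element of a list for any type with an Ordering. The names ContextBoundDemo and largest are my own, chosen just for this illustration.

```scala
object ContextBoundDemo {
  // [T : Ordering] asks the compiler for an implicit Ordering[T] without naming it
  def largest[T : Ordering](xs: List[T]): Option[T] = {
    val order = implicitly[Ordering[T]]  // retrieve the unnamed implicit value
    xs match {
      case Nil => None
      case head :: tail =>
        Some(tail.foldLeft(head)((a, b) => if (order.gteq(a, b)) a else b))
    }
  }

  val maxInt: Option[Int] = largest(List(9, 5, 16, 3))
  val maxStr: Option[String] = largest(List("pear", "apple"))
  val none: Option[Int] = largest(List.empty[Int])
}
```

Here maxInt is Some(16), maxStr is Some(pear) and none is None.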

Scala’s math.Ordered versus math.Ordering

One noteworthy thing about math.Ordering is that it does not define the comparison operators ‘<’, ‘>’, etc. directly on type T, which is why method lt(x: T, y: T) is used in mergeSort in place of the ‘<’ operator. To use comparison operators like ‘<’, one would need to import order.mkOrderingOps (or order._) within the mergeSort function. That’s because in math.Ordering, the comparison operators ‘<’, ‘>’, etc. are all defined in the inner class Ops, which is instantiated by the implicit method mkOrderingOps.
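
For completeness, here is a sketch of what that import looks like in practice, based on the implicit-parameter version of mergeSort. The wrapper object OrderingOpsDemo is a name I picked for illustration; mkOrderingOps is the method shown in the source highlight above.

```scala
object OrderingOpsDemo {
  def mergeSort[T](ls: List[T])(implicit order: Ordering[T]): List[T] = {
    // Bring this Ordering's implicit conversion from T to its operator wrapper into scope
    import order.mkOrderingOps

    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)  // resolved as mkOrderingOps(lHead).<(rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0) ls
    else {
      val (a, b) = ls.splitAt(n)
      merge(mergeSort(a), mergeSort(b))
    }
  }
}
```

Whether this reads better than order.lt(lHead, rHead) is a matter of taste; the behavior is the same.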

Scala’s math.Ordered extends Java’s Comparable interface (instead of Comparator) and implements method compareTo(that: A) in terms of its abstract method compare(that: A). For a type that only has a math.Ordering, an Ordered view is derived from Ordering’s compare(x: T, y: T) via an implicit parameter. One nice thing about math.Ordered is that it defines the comparison operators directly.

The following highlights partial source code of math.Ordered:

// Scala's math.Ordered source code highlight
trait Ordered[A] extends Any with java.lang.Comparable[A] {
  ...
  def compare(that: A): Int
  def <  (that: A): Boolean = (this compare that) <  0
  def >  (that: A): Boolean = (this compare that) >  0
  def <= (that: A): Boolean = (this compare that) <= 0
  def >= (that: A): Boolean = (this compare that) >= 0
  def compareTo(that: A): Int = compare(that)
}

object Ordered {
  implicit def orderingToOrdered[T](x: T)(implicit ord: Ordering[T]): Ordered[T] = new Ordered[T] {
      def compare(that: T): Int = ord.compare(x, that)
  }
}
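
To make the relationship between compare and the derived operators concrete, here is a minimal sketch of a class mixing in math.Ordered. Temperature and the wrapper object OrderedDemo are hypothetical names of mine, not part of the library.

```scala
object OrderedDemo {
  // Hypothetical type for illustration only
  final case class Temperature(degrees: Double) extends Ordered[Temperature] {
    // The only abstract member; <, >, <=, >= and compareTo are all derived from it
    def compare(that: Temperature): Int =
      java.lang.Double.compare(degrees, that.degrees)
  }

  val cold = Temperature(-5.0)
  val warm = Temperature(21.5)

  val isColder: Boolean = cold < warm      // true
  val notColder: Boolean = cold >= warm    // false
  val viaJava: Int = cold.compareTo(warm)  // negative, as cold sorts before warm
}
```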

With math.Ordered, an implicit conversion function, implicit orderer: T => Ordered[T] (as opposed to an implicit value when using math.Ordering), is passed to the mergeSort function as a curried parameter. As illustrated in a previous blog post, it’s an implicit conversion rule for the compiler to fall back to when it encounters a problem associated with type T.

Below is a version of generic Merge Sort using math.Ordered:

  // Generic Merge Sort using implicit math.Ordered conversion
  import math.Ordered

  def mergeSort[T](ls: List[T])(implicit orderer: T => Ordered[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

scala> mergeSort(li)
res5: List[Int] = List(3, 4, 5, 8, 9, 11, 12, 16)

scala> mergeSort(ls)
res6: List[String] = List(apple, banana, grape, guava, orange, peach, pear, strawberry)

View Bound

A couple of notes:

  1. The implicit method ‘implicit orderer: T => Ordered[T]’ is passed into the mergeSort function also as an implicit parameter.
  2. Function mergeSort has a signature of the following common form:
  // Implicit method passed in as implicit parameter
  def someFunction[T](x: SomeClass[T])(implicit imp: T => AnotherClass[T]): WhateverClass[T] = {
    ...
  }

This pattern of an implicit method passed in as an implicit parameter is so common that it’s given the name View Bound and awarded a designated smiley, ‘<%’. Using a view bound, it can be expressed as:

  // View Bound
  def someFunction[T <% AnotherClass[T]](x: SomeClass[T]): WhateverClass[T] = {
    ...
  }

Applied to the mergeSort function, it gives a slightly leaner and meaner look:

  // Generic Merge Sort using view bound
  import math.Ordered

  def mergeSort[T <% Ordered[T]](ls: List[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

As a side note, while the view bound looks like the other smiley ‘<:’ (Upper Bound), they represent very different things. An upper bound is commonly seen in the following form:

  // Upper Bound
  def someFunction[T <: S](x: T): R = {
    ...
  }

This means someFunction only takes an input parameter of a type T that is a sub-type of (or the same as) type S. While we’re at it, a Lower Bound, represented by the ‘>:’ smiley in the form of [T >: S], means type T can only be a super-type of (or the same as) type S.
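
A small sketch may help tell the two bounds apart. Everything named here (BoundsDemo, Version, smaller, prepend) is hypothetical and only serves the illustration.

```scala
object BoundsDemo {
  // Hypothetical type; it mixes in Ordered, so it satisfies T <: Ordered[T]
  final case class Version(major: Int, minor: Int) extends Ordered[Version] {
    def compare(that: Version): Int =
      if (major != that.major) Integer.compare(major, that.major)
      else Integer.compare(minor, that.minor)
  }

  // Upper bound: T must itself be a sub-type of Ordered[T].
  // smaller(3, 5) would not compile, since Int is not a sub-type of Ordered[Int].
  def smaller[T <: Ordered[T]](a: T, b: T): T = if (a < b) a else b

  // Lower bound: U may be T or any super-type of T
  def prepend[T, U >: T](x: U, xs: List[T]): List[U] = x :: xs

  val older: Version = smaller(Version(2, 11), Version(2, 9))
  val widened: List[Any] = prepend("zero", List(1, 2))
}
```

Here older is Version(2, 9), and widened is a List[Any] holding zero, 1 and 2. Unlike the view bound, the upper bound accepts no implicit conversion, which is why plain Int does not qualify.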

Implicit Conversion In Scala

These days, software engineers with knowledge of robust frameworks/libraries are abundant, but those who fully command the core basics of a language platform remain scarce. When required to come up with coding solutions that perform and scale, or to resolve tricky bugs, a good understanding of the programming language’s core features is often the real deal.

Scala’s signature strengths

Having been immersed in a couple of R&D projects using Scala (along with Akka actors) over the past 6 months, I’ve come to appreciate quite a few things it offers. Aside from the obvious signature strength of being a good hybrid of functional programming and object-oriented programming, others include implicit conversion, type parameterization and futures/promises. In addition, Akka actors coupled with Scala make a highly scalable concurrency solution applicable to many distributed systems, including IoT systems.

In this blog post, I’m going to talk about Scala’s implicit conversion, which I think is part of the language’s core basics. For illustration purposes, simple arithmetic on complex numbers will be implemented using this very feature.

A basic complex-number class would probably look something like the following:

class Complex(r: Double, i: Double) {
  val re: Double = r
  val im: Double = i
  override def toString =
    ...
  // Methods for arithmetic operations +|-|*|/
  ...
}

Since a complex number can have zero imaginary component leaving only the real component, it’s handy to have an auxiliary constructor for those real-only cases as follows:

  // Auxiliary constructor
  def this(x: Double) = this(x, 0)

Just a side note, an auxiliary constructor must invoke another constructor of the class as its first action and cannot invoke a superclass constructor.

Next, let’s override method toString to cover the various cases of how an x + yi complex number would look:

  // Override method toString
  override def toString =
    if (re != 0)
      if (im > 0)
        re + " + " + im + "i"
      else if (im < 0)
        re + " - " + -im + "i"
      else  // im == 0
        re.toString
    else  // re == 0
      if (im != 0)
        im + "i"
      else
        re.toString

Let’s also fill out the section for the basic arithmetic operations:

  // Overload methods for arithmetic operations
  def + (that: Complex): Complex =
    new Complex(re + that.re, im + that.im)
  def - (that: Complex): Complex =
    new Complex(re - that.re, im - that.im)
  def * (that: Complex): Complex =
    new Complex(re * that.re - im * that.im, re * that.im + im * that.re)
  def / (that: Complex): Complex = {
    require(that.re != 0 || that.im != 0)
    val den = that.re * that.re + that.im * that.im
    new Complex((re * that.re + im * that.im) / den, (- re * that.im + im * that.re) / den)
  }

Testing it out …

scala> val a = new Complex(1.0, 2.0)
a: Complex = 1.0 + 2.0i

scala> val b = new Complex(4.0, -3.0)
b: Complex = 4.0 - 3.0i

scala> val c = new Complex(2.0)
c: Complex = 2.0

scala> a + b
res0: Complex = 5.0 - 1.0i

scala> a - b
res1: Complex = -3.0 + 5.0i

scala> a * b
res2: Complex = 10.0 + 5.0i

scala> a / b
res3: Complex = -0.08 + 0.44i

So far so good. But what about this?

scala> a + 1.0
<console>:3: error: type mismatch;
 found   : Double(1.0)
 required: Complex
       a + 1.0
           ^

The compiler complains because it does not know how to handle arithmetic operations between a Complex and a Double. With the auxiliary constructor, ‘a + new Complex(1.0)’ will compile fine, but it’s cumbersome to have to represent every real-only complex number that way. We could resolve the problem by adding methods like the following for the ‘+’ method:

  def + (that: Complex): Complex =
    new Complex(re + that.re, im + that.im)
  def + (x: Double): Complex =
    new Complex(re + x, im)

But then what about this?

scala> 2.0 + b
<console>:4: error: overloaded method value + with alternatives:
  (x: Double)Double 
  (x: Float)Double 
  (x: Long)Double 
  (x: Int)Double 
  (x: Char)Double 
  (x: Short)Double 
  (x: Byte)Double 
  (x: String)String
 cannot be applied to (Complex)
       2.0 + b
           ^

The compiler interprets ‘a + 1.0’ as a.+(1.0). Since a is a Complex, the proposed new ‘+’ method in the Complex class can handle it. But ‘2.0 + b’ will fail because there isn’t a ‘+’ method in Double that can handle Complex. This is where implicit conversion shines.

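  // Implicit conversion (explicit result type; the import silences the feature warning)
  import scala.language.implicitConversions
  implicit def realToComplex(x: Double): Complex = new Complex(x, 0)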

The implicit method realToComplex tells the compiler to fall back on it when it encounters a type problem involving a Double where a Complex is needed. In many cases, implicit methods are never called explicitly, so their names can be pretty much arbitrary. For instance, renaming realToComplex to foobar in this case would get the same job done.
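
To see what that fallback amounts to, the sketch below puts the implicit form next to a hand-written equivalent. It uses a trimmed-down Complex with only the ‘+’ method so that it stands on its own; the wrapper object ImplicitRewriteDemo and the value names are assumptions of mine.

```scala
import scala.language.implicitConversions

object ImplicitRewriteDemo {
  // Trimmed-down Complex with only what this sketch needs
  class Complex(val re: Double, val im: Double) {
    def + (that: Complex): Complex = new Complex(re + that.re, im + that.im)
  }

  implicit def realToComplex(x: Double): Complex = new Complex(x, 0)

  val b = new Complex(4.0, -3.0)

  val implicitSum: Complex = 2.0 + b                 // what we write
  val explicitSum: Complex = realToComplex(2.0) + b  // roughly what the compiler falls back to
}
```

Both values come out as 6.0 - 3.0i, i.e. re is 6.0 and im is -3.0.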

As a bonus, arithmetic operations between Complex and Int (or Long, Float) work too. That’s because Scala already has, for instance, integer-to-double covered internally using implicit conversions in object Int and, in version 2.9.x or older, object Predef:

// Scala implicit conversions in object Int (the companion of class Int)
object Int extends AnyValCompanion {
  ...
  import scala.language.implicitConversions
  implicit def int2long(x: Int): Long = x.toLong
  implicit def int2float(x: Int): Float = x.toFloat
  implicit def int2double(x: Int): Double = x.toDouble
}

Testing again …

scala> val a = new Complex(1.0, 2.0)
a: Complex = 1.0 + 2.0i

scala> val b = new Complex(4.0, -3.0)
b: Complex = 4.0 - 3.0i

scala> a + 1.5
res11: Complex = 2.5 + 2.0i

scala> 3 - b
res12: Complex = -1.0 + 3.0i

scala> a * 2
res13: Complex = 2.0 + 4.0i

scala> val x: Long = 50
x: Long = 50

scala> a + x
res14: Complex = 51.0 + 2.0i

scala> val y: Float = 6.6f
y: Float = 6.6

scala> val y: Float = 4.0f
y: Float = 4.0

scala> y / b
res15: Complex = 0.64 + 0.48i

Implicit conversion scope

To ensure the implicit conversion rule is in effect wherever the Complex class is used, we need to keep it in scope. Defining the implicit method in the current scope, or importing it there, certainly works. An alternative is to define it in the companion object as follows:

// Instantiation by constructor
object Complex {
  implicit def realToComplex(x: Double) = new Complex(x, 0)
}

class Complex(r: Double, i: Double) {
  val re: Double = r
  val im: Double = i
  def this(x: Double) = this(x, 0)  // Auxiliary constructor
  override def toString =
    if (re != 0)
      if (im > 0)
        re + " + " + im + "i"
      else if (im < 0)
        re + " - " + -im + "i"
      else  // im == 0
        re.toString
    else  // re == 0
      if (im != 0)
        im + "i"
      else
        re.toString
  // Methods for arithmetic operations
  def + (that: Complex): Complex =
    new Complex(re + that.re, im + that.im)
  def - (that: Complex): Complex =
    new Complex(re - that.re, im - that.im)
  def * (that: Complex): Complex =
    new Complex(re * that.re - im * that.im, re * that.im + im * that.re)
  def / (that: Complex): Complex = {
    require(that.re != 0 || that.im != 0)
    val den = that.re * that.re + that.im * that.im
    new Complex((re * that.re + im * that.im) / den, (- re * that.im + im * that.re) / den)
  }
}
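
The point of the companion object is that client code no longer has to import the conversion by name. Here is a sketch of that, again with a trimmed-down Complex; ImplicitScopeDemo and ImplicitScopeClient are hypothetical names used only to separate the definition from its use.

```scala
import scala.language.implicitConversions

object ImplicitScopeDemo {
  // Trimmed-down Complex with only what this sketch needs
  class Complex(val re: Double, val im: Double) {
    def + (that: Complex): Complex = new Complex(re + that.re, im + that.im)
  }

  // The companion object is searched automatically when a Complex is required
  object Complex {
    implicit def realToComplex(x: Double): Complex = new Complex(x, 0)
  }
}

object ImplicitScopeClient {
  // Only Complex itself is imported; realToComplex is never imported by name
  import ImplicitScopeDemo.Complex

  val a = new Complex(1.0, 2.0)
  val sum: Complex = a + 1.5      // the Double argument is converted
  val flipped: Complex = 1.5 + a  // the Double receiver is converted
}
```

Both sum and flipped come out as 2.5 + 2.0i.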

As a final note, in case a factory method is preferred, thus removing the need for the ‘new’ keyword in instantiation, we could slightly modify the companion object/class as follows:

// Instantiation by factory
object Complex {
  implicit def realToComplex(x: Double) = new Complex(x, 0)
  def apply(r: Double, i: Double) = new Complex(r, i)
  def apply(r: Double) = new Complex(r, 0)
}

class Complex private (r: Double, i: Double) {
  val re: Double = r
  val im: Double = i
  override def toString =
    if (re != 0)
      if (im > 0)
        re + " + " + im + "i"
      else if (im < 0)
        re + " - " + -im + "i"
      else  // im == 0
        re.toString
    else  // re == 0
      if (im != 0)
        im + "i"
      else
        re.toString
  // Methods for arithmetic operations
  def + (that: Complex): Complex =
    new Complex(re + that.re, im + that.im)
  def - (that: Complex): Complex =
    new Complex(re - that.re, im - that.im)
  def * (that: Complex): Complex =
    new Complex(re * that.re - im * that.im, re * that.im + im * that.re)
  def / (that: Complex): Complex = {
    require(that.re != 0 || that.im != 0)
    val den = that.re * that.re + that.im * that.im
    new Complex((re * that.re + im * that.im) / den, (- re * that.im + im * that.re) / den)
  }
}

Another quick test …

scala> val a = Complex(1, 2)
a: Complex = 1.0 + 2.0i

scala> val b = Complex(4.0, -3.0)
b: Complex = 4.0 - 3.0i

scala> a + 2
res16: Complex = 3.0 + 2.0i

scala> 3 * b
res17: Complex = 12.0 - 9.0i