Generic Merge Sort In Scala

Many software engineers rarely deal with type parameterization or generic types explicitly in their day-to-day work, but the libraries and frameworks they rely on heavily have very likely used parametric polymorphism to ensure static type-safety.

In a statically typed functional programming language like Scala, this feature often needs to be used first-hand to create useful functions that ensure type-safety while keeping the code lean and versatile. Generics are clearly taken seriously in Scala’s language design. That, coupled with Scala’s implicit conversions, constitutes one of Scala’s signature features. Given Scala’s love of “smileys”, a few of them are designated for the relevant functionality.

Merge Sort

Merge Sort is a popular textbook sorting algorithm that I think also serves as a great brain-teasing programming exercise. I have an old blog post about implementing Merge Sort using Java Generics. In this post, I’m going to use Merge Sort again to illustrate Scala’s type parameterization.

A basic Merge Sort function for integer sorting splits the list into two halves, recursively sorts each half, and combines the results with a merge function that interleaves two sorted lists. It might look similar to the following:

  def mergeSort(ls: List[Int]): List[Int] = {
    def merge(l: List[Int], r: List[Int]): List[Int] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

A quick test …

scala> val li = List(9, 5, 16, 3, 4, 11, 8, 12)
li: List[Int] = List(9, 5, 16, 3, 4, 11, 8, 12)

scala> mergeSort(li)
res1: List[Int] = List(3, 4, 5, 8, 9, 11, 12, 16)
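
One caveat with the nested merge above: it is not tail-recursive, so its recursion depth grows with the combined length of the two lists, and very long inputs may overflow the stack. Below is a minimal sketch of an accumulator-based alternative. The names MergeSketch, mergeTailRec and loop are hypothetical and only there for illustration; they are not part of the original function.

```scala
import scala.annotation.tailrec

object MergeSketch {
  // Merges two already-sorted lists; the accumulator keeps the recursive call in tail position
  def mergeTailRec(l: List[Int], r: List[Int]): List[Int] = {
    @tailrec
    def loop(l: List[Int], r: List[Int], acc: List[Int]): List[Int] = (l, r) match {
      case (Nil, _) => acc.reverse ::: r   // left side exhausted: append what remains of the right
      case (_, Nil) => acc.reverse ::: l   // right side exhausted: append what remains of the left
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead) loop(lTail, r, lHead :: acc)
        else loop(l, rTail, rHead :: acc)
    }
    loop(l, r, Nil)
  }
}
```

The accumulator is built in reverse and flipped once at the end, which trades a little extra work for constant stack usage in the merge step.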

Unlike Java Generics’ MyClass<T> notation, Scala’s generic types take the form MyClass[T]. Let’s generalize the integer Merge Sort as follows:

  def mergeSort[T](ls: List[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

<console>:15: error: value < is not a member of type parameter T
               if (lHead < rHead)
                         ^

The compiler immediately complains about the ‘<’ comparison, since T might not be a type that supports ordering for ‘<’ to make any sense. To generalize the Merge Sort function for any list type that supports ordering, we can supply an Ordering[T] as a second, curried parameter as follows:

  // Generic Merge Sort using math.Ordering
  import math.Ordering

  def mergeSort[T](ls: List[T])(order: Ordering[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead))
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a)(order), mergeSort(b)(order))
    }
  }

Another quick test …

scala> val li = List(9, 5, 16, 3, 4, 11, 8, 12)
li: List[Int] = List(9, 5, 16, 3, 4, 11, 8, 12)

scala> mergeSort(li)(Ordering[Int])
res2: List[Int] = List(3, 4, 5, 8, 9, 11, 12, 16)

scala> val ls = List("banana", "pear", "orange", "grape", "apple", "strawberry", "guava", "peach")
ls: List[String] = List(banana, pear, orange, grape, apple, strawberry, guava, peach)

scala> mergeSort(ls)(Ordering[String])
res3: List[String] = List(apple, banana, grape, guava, orange, peach, pear, strawberry)

That works well, but it’s cumbersome that one needs to supply the corresponding Ordering[T] for the list type on every call. That’s where an implicit parameter can help:

  // Generic Merge Sort using implicit math.Ordering parameter
  import math.Ordering

  def mergeSort[T](ls: List[T])(implicit order: Ordering[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead))
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

Testing again …

scala> mergeSort(li)
res3: List[Int] = List(3, 4, 5, 8, 9, 11, 12, 16)

scala> mergeSort(ls)
res4: List[String] = List(apple, banana, grape, guava, orange, peach, pear, strawberry)
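
An implicit parameter can still be passed explicitly, which is handy when the default ordering isn’t the one you want. The sketch below repeats the implicit-parameter mergeSort so it is self-contained, then sorts in descending order and by string length. The wrapper object OrderingOverrideSketch and the names descending, byLength and shortestFirst are hypothetical, chosen for illustration only.

```scala
object OrderingOverrideSketch {
  // Same implicit-parameter mergeSort as above, repeated so the snippet stands alone
  def mergeSort[T](ls: List[T])(implicit order: Ordering[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead)) lHead :: merge(lTail, r)
        else rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0) ls
    else {
      val (a, b) = ls.splitAt(n)
      merge(mergeSort(a), mergeSort(b))  // the explicitly passed order is reused implicitly here
    }
  }

  val li = List(9, 5, 16, 3, 4, 11, 8, 12)
  val ls = List("banana", "pear", "grape", "strawberry", "fig")

  // Descending order: pass the reversed Int ordering explicitly
  val descending = mergeSort(li)(Ordering[Int].reverse)

  // Order strings by their length instead of alphabetically
  val byLength: Ordering[String] = Ordering.by((s: String) => s.length)
  val shortestFirst = mergeSort(ls)(byLength)
}
```

Because the explicitly supplied ordering becomes the implicit `order` inside the function body, the recursive calls pick it up without any further plumbing.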

Note that the ‘if (lHead < rHead)’ condition is now replaced with ‘if (order.lt(lHead, rHead))’. That’s because a generic T has no ‘<’ method of its own, whereas math.Ordering defines a less-than method, lt, for any type it is instantiated with. Let’s dig a little deeper into how it works. Scala’s math.Ordering extends Java’s Comparator interface and declares the abstract method compare(x: T, y: T); its companion object supplies implicit instances that implement compare for the common types (Int, Long, Float, Double, String, etc.). The trait then builds lt(x: T, y: T), gt(x: T, y: T) and the other comparison methods on top of compare.

The following are highlights of math.Ordering’s partial source code:

// Scala's math.Ordering source code highlight
import java.util.Comparator
...

trait Ordering[T] extends Comparator[T] with PartialOrdering[T] with Serializable {
  ...
  def compare(x: T, y: T): Int
  ...
  override def lt(x: T, y: T): Boolean = compare(x, y) < 0
  override def gt(x: T, y: T): Boolean = compare(x, y) > 0
  ...
  class Ops(lhs: T) {
    def <(rhs: T) = lt(lhs, rhs)
    def >(rhs: T) = gt(lhs, rhs)
    ...
  }
  implicit def mkOrderingOps(lhs: T): Ops = new Ops(lhs)
  ...
}

...

object Ordering extends LowPriorityOrderingImplicits {
  def apply[T](implicit ord: Ordering[T]) = ord
  ...
  trait IntOrdering extends Ordering[Int] {
    def compare(x: Int, y: Int) =
      if (x < y) -1
      else if (x == y) 0
      else 1
  }
  implicit object Int extends IntOrdering
  ...
  trait StringOrdering extends Ordering[String] {
    def compare(x: String, y: String) = x.compareTo(y)
  }
  implicit object String extends StringOrdering
  ...
}
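
The same mechanism extends to your own types: implement compare once and the derived methods come along with it. Here is a minimal sketch, assuming a made-up Fruit case class and an ordering by weight. CustomOrderingSketch, Fruit and byWeight are hypothetical names, not anything from the Scala library.

```scala
object CustomOrderingSketch {
  // Hypothetical domain type used only for illustration
  case class Fruit(name: String, weight: Int)

  object Fruit {
    // An implicit in the companion object is found automatically when an Ordering[Fruit] is needed
    implicit val byWeight: Ordering[Fruit] = new Ordering[Fruit] {
      def compare(x: Fruit, y: Fruit): Int = Integer.compare(x.weight, y.weight)
    }
  }

  val apple = Fruit("apple", 180)
  val grape = Fruit("grape", 5)
  val pear  = Fruit("pear", 170)

  // Ordering.apply simply returns the implicit instance in scope
  val ord = Ordering[Fruit]

  // lt and gt are derived from compare; nothing else needs to be implemented
  val grapeIsLighter = ord.lt(grape, apple)
  val appleIsHeavier = ord.gt(apple, pear)

  // Library methods that take an implicit Ordering pick it up as well
  val sorted = List(apple, grape, pear).sorted
}
```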

Context Bound

Passing in an implicit value like this is the basis of the typeclass pattern, and it’s so common that Scala provides a shorthand for it called Context Bound. The pattern looks like this:

  // Implicit value passed in as implicit parameter
  def someFunction[T](x: SomeClass[T])(implicit imp: AnotherClass[T]): WhateverClass[T] = {
    ...
  }

With the context bound syntactic sugar, it becomes:

  // Context Bound
  def someFunction[T : AnotherClass](x: SomeClass[T]): WhateverClass[T] = {
    ...
  }

The mergeSort function using context bound looks as follows:

  // Generic Merge Sort using Context Bound
  def mergeSort[T : Ordering](ls: List[T]): List[T] = {
    val order = implicitly[Ordering[T]]
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (order.lt(lHead, rHead))
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

Note that ‘implicitly[Ordering[T]]’ is there to retrieve the Ordering[T] instance, which is no longer passed in under a parameter name, so that its methods can be called.
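
To see that the two notations are interchangeable, here is a small sketch with the same function written both ways. ContextBoundSketch, largest and largestExplicit are hypothetical names made up for this comparison.

```scala
object ContextBoundSketch {
  // Context-bound form: the Ordering[T] instance has no name, so it is summoned with implicitly
  def largest[T : Ordering](xs: List[T]): Option[T] = {
    val order = implicitly[Ordering[T]]
    xs.reduceOption((a, b) => if (order.gt(a, b)) a else b)
  }

  // Implicit-parameter form: the same instance arrives under the name order
  def largestExplicit[T](xs: List[T])(implicit order: Ordering[T]): Option[T] =
    xs.reduceOption((a, b) => if (order.gt(a, b)) a else b)
}
```

Both versions are called the same way, e.g. ContextBoundSketch.largest(List(9, 5, 16, 3)), and both return None for an empty list.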

Scala’s math.Ordered versus math.Ordering

One noteworthy thing about math.Ordering is that it does not put the comparison operators ‘<’, ‘>’, etc., on T directly, which is why method lt(x: T, y: T) is used in mergeSort in place of the ‘<’ operator. To use comparison operators like ‘<’, one would need to import order.mkOrderingOps (or order._) within the mergeSort function. That’s because in math.Ordering the comparison operators are all defined in the inner class Ops, which is instantiated through the implicit method mkOrderingOps.
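
A minimal sketch of that import in action, using the implicit-parameter version of mergeSort; the wrapper object OrderingOpsSketch is a hypothetical name added only to keep the snippet self-contained:

```scala
object OrderingOpsSketch {
  def mergeSort[T](ls: List[T])(implicit order: Ordering[T]): List[T] = {
    // Brings the implicit T => order.Ops conversion into scope, enabling <, >, etc. on values of T
    import order.mkOrderingOps

    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead) lHead :: merge(lTail, r)   // expands to order.lt(lHead, rHead) via Ops
        else rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0) ls
    else {
      val (a, b) = ls.splitAt(n)
      merge(mergeSort(a), mergeSort(b))
    }
  }
}
```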

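Scala’s math.Ordered extends Java’s Comparable interface (instead of Comparator) and implements method compareTo(that: A) by delegating to its abstract compare(that: A). Its companion object provides an implicit conversion, orderingToOrdered, that builds an Ordered[T] from math.Ordering’s compare(x: T, y: T) via an implicit parameter. One nice thing about math.Ordered is that it defines the comparison operators directly.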

The following highlights partial source code of math.Ordered:

// Scala's math.Ordered source code highlight
trait Ordered[A] extends Any with java.lang.Comparable[A] {
  ...
  def compare(that: A): Int
  def <  (that: A): Boolean = (this compare that) <  0
  def >  (that: A): Boolean = (this compare that) >  0
  def <= (that: A): Boolean = (this compare that) <= 0
  def >= (that: A): Boolean = (this compare that) >= 0
  def compareTo(that: A): Int = compare(that)
}

object Ordered {
  implicit def orderingToOrdered[T](x: T)(implicit ord: Ordering[T]): Ordered[T] = new Ordered[T] {
      def compare(that: T): Int = ord.compare(x, that)
  }
}

When using math.Ordered, an implicit conversion function, implicit orderer: T => Ordered[T] (as opposed to an implicit value when using math.Ordering), is passed to the mergeSort function as a curried parameter. As illustrated in a previous blog post, it’s an implicit conversion rule that the compiler falls back to when an expression on type T, such as lHead < rHead, doesn’t type-check on its own.

Below is a version of generic Merge Sort using math.Ordered:

  // Generic Merge Sort using implicit math.Ordered conversion
  import math.Ordered

  def mergeSort[T](ls: List[T])(implicit orderer: T => Ordered[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

scala> mergeSort(li)
res5: List[Int] = List(3, 4, 5, 8, 9, 11, 12, 16)

scala> mergeSort(ls)
res6: List[String] = List(apple, banana, grape, guava, orange, peach, pear, strawberry)
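
For types you control, another option is to mix in math.Ordered directly, in which case the comparison operators are available on the type itself and no conversion is needed. Below is a small sketch with a made-up Version class. OrderedSketch, Version and the val names are hypothetical and only for illustration.

```scala
object OrderedSketch {
  // Hypothetical type that mixes in Ordered, so <, >, <=, >= are defined on it directly
  case class Version(major: Int, minor: Int) extends Ordered[Version] {
    def compare(that: Version): Int =
      if (major != that.major) Integer.compare(major, that.major)
      else Integer.compare(minor, that.minor)
  }

  val older = Version(2, 12)
  val newer = Version(2, 13)

  // Operators come from Ordered and are expressed in terms of compare
  val isOlder = older < newer

  // compareTo is the java.lang.Comparable entry point and delegates to compare
  val viaComparable = older.compareTo(newer)

  // The operators can be used anywhere a comparison function is expected
  val ascending = List(newer, Version(1, 0), older).sortWith(_ < _)
}
```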

View Bound

A couple of notes:

  1. The implicit method ‘implicit orderer: T => Ordered[T]’ is passed into the mergeSort function also as an implicit parameter.
  2. Function mergeSort has a signature of the following common form:
  // Implicit method passed in as implicit parameter
  def someFunction[T](x: SomeClass[T])(implicit imp: T => AnotherClass[T]): WhateverClass[T] = {
    ...
  }

This pattern of an implicit method passed in as an implicit parameter is so common that it’s given the name View Bound and awarded a designated smiley ‘<%’. Using view bound, it can be expressed as:

  // View Bound
  def someFunction[T <% AnotherClass[T]](x: SomeClass[T]): WhateverClass[T] = {
    ...
  }

Applied to the mergeSort function, it gives a slightly leaner and meaner look (note that Scala 3 has dropped view bounds, so this form applies to Scala 2 only):

  // Generic Merge Sort using view bound
  import math.Ordered

  def mergeSort[T <% Ordered[T]](ls: List[T]): List[T] = {
    def merge(l: List[T], r: List[T]): List[T] = (l, r) match {
      case (Nil, _) => r
      case (_, Nil) => l
      case (lHead :: lTail, rHead :: rTail) =>
        if (lHead < rHead)
          lHead :: merge(lTail, r)
        else
          rHead :: merge(l, rTail)
    }
    val n = ls.length / 2
    if (n == 0)
      ls
    else {
      val (a, b) = ls splitAt n
      merge(mergeSort(a), mergeSort(b))
    }
  }

As a side note, while the view bound looks like the other smiley ‘<:’ (Upper Bound), they represent very different things. An upper bound is commonly seen in the following form:

  // Upper Bound
  def someFunction[T <: S](x: T): R = {
    ...
  }

This means someFunction only accepts an input parameter whose type T is a sub-type of (or the same as) type S. While at it, a Lower Bound, represented by the ‘>:’ smiley in the form of [T >: S], means T can only be a super-type of (or the same as) type S.
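
To make the contrast with the view bound concrete, here is a short sketch of both kinds of bound. BoundsSketch, Money, maxOf and prepend are hypothetical names invented for this example.

```scala
object BoundsSketch {
  // Hypothetical type that extends Ordered, so it satisfies the upper bound below
  case class Money(cents: Long) extends Ordered[Money] {
    def compare(that: Money): Int = java.lang.Long.compare(cents, that.cents)
  }

  // Upper bound: T must itself be a subtype of Ordered[T]; no implicit conversion is consulted
  def maxOf[T <: Ordered[T]](a: T, b: T): T = if (a > b) a else b

  // Lower bound: U must be a supertype of T, so the resulting list widens to List[U]
  def prepend[T, U >: T](x: U, xs: List[T]): List[U] = x :: xs

  val richer = maxOf(Money(250), Money(199))

  // Prepending a String to a List[Int] widens the element type to Any
  val widened: List[Any] = prepend("zero", List(1, 2, 3))

  // maxOf(9, 5) would not compile: Int is not a subtype of Ordered[Int],
  // which is exactly the case the view bound was meant to cover
}
```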

Programming Exercise – Sorting Algorithm

Unless algorithm development is part of the job, many software engineers use readily available algorithmic methods as needed and rarely need to develop algorithms themselves. I’m not talking about complicated statistical or machine learning algorithms, just simple mundane ones such as a sorting algorithm. Even if you don’t need to code algorithms, going back to writing a sorting program can still be a good exercise to review certain basic skills that might not be frequently used in your current day job. It’s a good-sized programming exercise that is neither too trivial nor too time-consuming. It also reminds you of some clever (or dumb) tricks for sorting by means of recursive divide-and-merge, pivot partitioning, etc. And if nothing else, it might help you in your next technical job interview.

If you’re up for such an exercise, first look up a sorting algorithm of your choice (e.g. Merge Sort, Quick Sort) on Wikipedia or any other suitable source to re-familiarize yourself with its underlying mechanism. Next, decide on the scope of the application. For example, do you want an integer-sorting application or something more generic? Then write some pseudo code, pick the programming language of your choice, and go for it.

Appended is a simple implementation of both Merge Sort and Quick Sort in Java. For the convenience of making method calls with varying data types and sorting algorithms, an interface (SimpleSort) and a wrapper (Sorter) are used. Java Generics is used to generalize the sorter for different data types. Adding Generics to a sorting application requires using either the Comparable or Comparator interface, as ordering is necessary in sorting. In this example application, the Comparable interface is used since the default ordering is good enough for basic sorting. The overall implementation code is pretty self-explanatory.

SimpleSort.java

package sorting;

public interface SimpleSort<E extends Comparable<E>> {

    public void sort(E[] list);
}

Sorter.java

package sorting;

public class Sorter<E extends Comparable<E>> {

    public void performSorting(E[] list, SimpleSort<E> sortingAlgo) {
        long startTime, elapsedTime;

        System.out.println("\nOriginal " + list.getClass().getSimpleName() + " list:  ...");
        for (int i = 0; i < list.length; i++) {
            System.out.print(list[i] + "  ");
        }
        System.out.println();

        startTime = System.currentTimeMillis();
        sortingAlgo.sort(list);
        elapsedTime = System.currentTimeMillis() - startTime;

        System.out.println("\nResult: " + sortingAlgo.getClass().getSimpleName() + " ...");
        for (int i = 0; i < list.length; i++) {
            System.out.print(list[i] + "  ");
        }
        System.out.println("\n\nTime taken: " + elapsedTime + "ms");
    }
}

MergeSort.java

package sorting;

public class MergeSort<E extends Comparable<E>> implements SimpleSort<E> {

    private E[] list;
    private E[] holder;

    @SuppressWarnings({"unchecked"})
    public void sort(E[] list) {
        this.list = list;
        int size = list.length;
        // Type erasure: first bound class of <E extends Comparable<E>>
        holder = (E[]) new Comparable[size];

        mergeSort(0, size - 1);
    }

    private void mergeSort(int left, int right) {
        if (left < right) {
            int center = left + (right - left) / 2;
            mergeSort(left, center);
            mergeSort(center + 1, right);
            merge(left, center, right);
        }
    }

    private void merge(int left, int center, int right) {

        for (int i = left; i <= right; i++) {
            holder[i] = list[i];
        }

        int i = left;
        int j = center + 1;
        int k = left;

        while (i <= center && j <= right) {

            if (holder[i].compareTo(holder[j]) <= 0) {
                list[k] = holder[i];  // Overwrite list[k] with element from the left list
                i++;
            }
            else {
                list[k] = holder[j];  // Overwrite list[k] with element from the right list
                j++;
            }
            k++;
        }

        // Overwrite remaining list[k] if the left list isn't exhausted
        // No need to do the same for the right list, as its elements are already correctly placed
        while (i <= center) {
            list[k] = holder[i];
            k++;
            i++;
        }
    }
}

QuickSort.java

package sorting;

import java.util.Random;

public class QuickSort<E extends Comparable<E>> implements SimpleSort<E> {

    private E[] list;

    @SuppressWarnings({"unchecked"})
    public void sort(E[] list) {
        this.list = list;
        int size = list.length;

        quickSort(list, 0, size-1);
    }

    private void swapListElements(E[] list, int index1, int index2) {
        E tempValue = list[index1];
        list[index1] = list[index2];
        list[index2] = tempValue;
    }

    private int partitionIndex(E[] list, int left, int right, int pivot) {
        int pivotIndex;
        E pivotValue;

        pivotValue = list[pivot];
        swapListElements(list, pivot, right);  // Swap pivot element to rightmost index

        pivotIndex = left;  // Calculate pivot index starting from leftmost index
        for (int i = left; i < right; i++) {
            if (list[i].compareTo(pivotValue) < 0) {
                swapListElements(list, i, pivotIndex);
                pivotIndex++;
            }
        }
        swapListElements(list, pivotIndex, right);  // Swap pivot element back to calculated pivot index

        return pivotIndex;
    }

    private void quickSort(E[] list, int left, int right) {
        int randomIndex, pivotIndex;

        if (left < right) {
            // Pick random index between left and right (inclusive)
            Random randomNum = new Random();
            randomIndex = randomNum.nextInt(right-left+1) + left;

            pivotIndex = partitionIndex(list, left, right, randomIndex);
            quickSort(list, left, pivotIndex - 1);
            quickSort(list, pivotIndex + 1, right);
        }
    }
}

SortingMain.java

package sorting;

public class SortingMain {

    public static void main(String[] args) {

        Integer[] integerList = {8, 15, 6, 2, 11, 4, 7, 12, 3};
        Double[]  doubleList = {2.5, 8.0, 7.2, 4.9, 6.6, 1.8, 3.0, 7.5};
        Character[] characterList = {'M', 'X', 'B', 'C', 'T', 'R', 'F', 'S'};
        String[] stringList = {"lion", "tiger", "snow leopard", "jaguar", "cheetah", "cougar", "leopard"};

        // A typed Sorter per element type avoids raw types and unchecked warnings
        new Sorter<Integer>().performSorting(integerList, new MergeSort<Integer>());
        new Sorter<Double>().performSorting(doubleList, new QuickSort<Double>());
        new Sorter<Character>().performSorting(characterList, new MergeSort<Character>());
        new Sorter<String>().performSorting(stringList, new QuickSort<String>());
    }
}