A Stateful Calculator In Scala

In one of the best books about Cats, Scala with Cats by Welsh & Gurnell, there is an interesting example illustrating how to build a stateful integer calculator using Cats State.

Cats State

Cats State is a Scala object whose defining method is apply:

def apply[S, A](f: S => (S, A)): State[S, A]

that takes a state-transforming function f: S => (S, A) where S represents the state type and A the result type. It returns a State[S, A] which is a type alias of StateT[Eval, S, A] (or equivalently IndexedStateT[Eval, S, S, A]).

StateT[F, S, A] takes an S state and produces an updated state and an A result wrapped in the F context. Here the context is Eval, which provides stack safety.

Other methods in the State object include the following:

def empty[S, A](implicit A: Monoid[A]): State[S, A]
def pure[S, A](a: A): State[S, A]
def get[S]: State[S, S]
def set[S](s: S): State[S, Unit]
def inspect[S, T](f: (S) => T): State[S, T]
def modify[S](f: (S) => S): State[S, Unit]

along with class methods such as run, runS and runA provided by the IndexedStateT class.
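As a quick illustration of how these methods compose, here is a minimal sketch. It assumes cats-core is on the classpath; the names next and program are made up for this example and do not come from the book.

```scala
import cats.data.State

// A counter whose state is an Int; each step returns the value it saw.
val next: State[Int, Int] = State[Int, Int](n => (n + 1, n))

val program: State[Int, String] = for {
  a <- next                            // a = 0, state becomes 1
  _ <- State.modify[Int](_ * 10)       // state becomes 10
  b <- State.inspect[Int, Int](_ + 1)  // b = 11, state unchanged
  _ <- State.set[Int](42)              // state replaced by 42
  c <- State.get[Int]                  // c = 42
} yield s"$a $b $c"

val both       = program.run(0).value   // (42,0 11 42): final state and result
val onlyState  = program.runS(0).value  // 42
val onlyResult = program.runA(0).value  // 0 11 42
```

Each of run, runS and runA returns an Eval, so .value is what actually executes the computation.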

A stateful post-order calculator

The simple calculator processes a sequence of integer arithmetic operations in “post-order” (postfix) form and returns the computed result. Each operator applies to the two most recent values before it, which may be operands or intermediate results. For example, 1 2 + 3 * is interpreted as (1 + 2) * 3.

Implementation is straightforward. The input string, consisting of integer operands and operators (+|-|*|/), is parsed token by token. Operands are pushed onto a stack; when an operator is encountered, the top two operands are popped off the stack and the corresponding arithmetic is carried out.
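Before bringing in State, the push/pop procedure can be sketched with an explicit stack threaded through a fold. This is only a baseline for comparison, using the standard library alone; evalPostfix is an illustrative name, not part of the book's example.

```scala
// Evaluate a postfix expression with an explicit immutable stack (List[Int]).
// `evalPostfix` is a hypothetical helper used only for illustration.
def evalPostfix(input: String): Int = {
  val tokens = input.trim.split("\\s+").toList
  val finalStack = tokens.foldLeft(List.empty[Int]) { (stack, sym) =>
    (sym, stack) match {
      // The head of the list is the most recently pushed operand,
      // so the earlier operand `y` goes on the left-hand side.
      case ("+", x :: y :: rest) => (y + x) :: rest
      case ("-", x :: y :: rest) => (y - x) :: rest
      case ("*", x :: y :: rest) => (y * x) :: rest
      case ("/", x :: y :: rest) => (y / x) :: rest
      case ("+" | "-" | "*" | "/", _) =>
        sys.error(s"Missing operands for '$sym'")
      case (num, _) =>
        num.toIntOption match {
          case Some(v) => v :: stack
          case None    => sys.error(s"Operand $num type error!")
        }
    }
  }
  finalStack.headOption.getOrElse(sys.error("Empty expression"))
}

evalPostfix("1 2 + 3 *")  // 9
```

The State-based versions express the same stack threading, but hide the stack parameter inside the State value.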

Implementation using Scala Cats

// A post-order integer calculator implemented using Scala Cats
// (Source: Scala with Cats by Welsh and Gurnell)

import cats.data.State

def operator(op: (Int, Int) => Int): State[List[Int], Int] = {
  State[List[Int], Int] {
    case x :: y :: ls =>
      val res = op(y, x)
      (res :: ls, res)
    case _ =>
      sys.error("Missing operands error!")
  }
}

def operand(value: String): State[List[Int], Int] = {
  value.toIntOption match {
    case Some(v) =>
      State[List[Int], Int] { ls => (v :: ls, v) }
    case None =>
      sys.error(s"Operand $value type error!")
  }
}

def evalOne(sym: String): State[List[Int], Int] = {
  if (sym == "+") operator(_ + _)
  else if (sym == "-") operator(_ - _)
  else if (sym == "*") operator(_ * _)
  else if (sym == "/") operator(_ / _)
  else operand(sym)
}

def evalAll(instructions: String): State[List[Int], Int] = {
  val list = instructions.split("\\s+").toList
  list.foldLeft(State.pure[List[Int], Int](0)){ (acc, sym) =>
    acc.flatMap(_ => evalOne(sym))
  }
}

// Test running ...

evalOne("30").run(Nil).value
// (List(30),30)

evalAll("10 20 + 15 - 5 / 4 *").run(Nil).value
// (List(12),12)

With State[List[Int], Int], the operands are kept in a stack (a List[Int]) inside the State structure and are popped off to carry out the integer arithmetic. Method operand() parses a String as an integer and pushes it onto the stack, while method operator() takes a binary function (Int, Int) => Int, applies it to the two most recently pushed integers, and pushes the result back.

Using the two helper methods, evalOne() transforms a given operand or operator into a State[List[Int], Int]. Finally, evalAll() takes an input String holding a sequence of post-order arithmetic operations, splits it into tokens and computes the result by applying evalOne to each token in a fold.

Implementing with a plain Scala class

Now, what if one wants to stick to Scala’s standard library? Since the Cats State approach has proved effective, we can write a simple Scala class that mimics what Cats State[S, A] does.

At a minimum, we need a class that wraps an S => (S, A) state transformation function and provides an equivalent of Cats State’s flatMap for chaining operations.

case class State[S, A](r: S => (S, A)) {
  def result(s: S): (S, A) = r(s)
  def map[B](f: A => B): State[S, B] =
    State(r andThen { case (s, a) => (s, f(a)) })
  def flatMap[B](g: A => State[S, B]): State[S, B] =
    State(r andThen { case (s, a) => g(a).r(s) })
}

As shown in the above snippet, method flatMap is built by composing function r, via andThen, with a pattern-matching function that feeds the updated state s into the next State produced by g(a). Though not needed for this particular calculator implementation, method map is included for completeness. Method result simply runs the transformation on an initial state and returns the resulting (S, A) tuple.
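One practical reason to keep map around: together with flatMap it is what lets the class be used in a for-comprehension with yield. The sketch below repeats the class and adds a few helpers modelled on the Cats signatures listed earlier. The StateOps object and its methods are hypothetical additions for illustration, not something the calculator requires.

```scala
case class State[S, A](r: S => (S, A)) {
  def result(s: S): (S, A) = r(s)
  def map[B](f: A => B): State[S, B] =
    State(r andThen { case (s, a) => (s, f(a)) })
  def flatMap[B](g: A => State[S, B]): State[S, B] =
    State(r andThen { case (s, a) => g(a).r(s) })
}

// Hypothetical helpers mirroring pure, get, set and modify from Cats State.
object StateOps {
  def pure[S, A](a: A): State[S, A]        = State[S, A](s => (s, a))
  def get[S]: State[S, S]                  = State[S, S](s => (s, s))
  def set[S](s: S): State[S, Unit]         = State[S, Unit](_ => (s, ()))
  def modify[S](f: S => S): State[S, Unit] = State[S, Unit](s => (f(s), ()))
}

// flatMap handles each `<-`; map handles the final `yield`.
val program: State[Int, String] = for {
  a <- StateOps.get[Int]            // a = 0
  _ <- StateOps.modify[Int](_ + 1)  // state becomes 1
  b <- StateOps.get[Int]            // b = 1
  _ <- StateOps.set[Int](b * 10)    // state becomes 10
} yield s"$a then $b"

program.result(0)  // (10,0 then 1)
```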

With the Scala State class, we can implement the calculator’s parsing and arithmetic operations just like how it was done using Cats State.

// A stateful post-order calculator implemented using plain Scala class

case class State[S, A](r: S => (S, A)) {
  def result(s: S): (S, A) = r(s)
  def map[B](f: A => B): State[S, B] =
    State(r andThen { case (s, a) => (s, f(a)) })
  def flatMap[B](g: A => State[S, B]): State[S, B] =
    State(r andThen { case (s, a) => g(a).r(s) })
}

def operator(op: (Int, Int) => Int): State[List[Int], Int] = {
  State[List[Int], Int] {
    case x :: y :: ls =>
      val res = op(y, x)
      (res :: ls, res)
    case _ =>
      throw new Exception("Missing operands error!")
  }
}

def operand(value: String): State[List[Int], Int] = {
  value.toIntOption match {
    case Some(v) =>
      State[List[Int], Int] { ls => (v :: ls, v) }
    case None =>
      throw new Exception(s"Operand $value type error!")
  }
}

def evalOne(sym: String): State[List[Int], Int] = {
  if (sym == "+") operator(_ + _)
  else if (sym == "-") operator(_ - _)
  else if (sym == "*") operator(_ * _)
  else if (sym == "/") operator(_ / _)
  else operand(sym)
}

def evalAll(ops: String): State[List[Int], Int] = {
  val list = ops.split("\\s+").toList
  // Start from a no-op state that keeps the incoming stack, like Cats State.pure
  list.foldLeft(State[List[Int], Int](s => (s, 0))){ (acc, op) =>
    acc.flatMap(_ => evalOne(op))
  }
}

// Test running ...

evalOne("30").result(Nil)
// (List(30),30)

evalAll("10 20 + 15 - 5 / 4 *").result(Nil)  // (List(12),12)
// evalAll("1 2 + 3 4 + *").result(Nil)  // (List(21),21)
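To see what the fold in evalAll builds, it can help to unroll the chain by hand for a short input. The sketch below does this for 1 2 + 3 * using a trimmed copy of the plain State class; push is a hypothetical stand-in for operand() that skips the string parsing.

```scala
// Trimmed copy of the plain State class (flatMap only).
case class State[S, A](r: S => (S, A)) {
  def result(s: S): (S, A) = r(s)
  def flatMap[B](g: A => State[S, B]): State[S, B] =
    State(r andThen { case (s, a) => g(a).r(s) })
}

// Hypothetical helper: push an already-parsed operand onto the stack.
def push(v: Int): State[List[Int], Int] =
  State[List[Int], Int](ls => (v :: ls, v))

def operator(op: (Int, Int) => Int): State[List[Int], Int] =
  State[List[Int], Int] {
    case x :: y :: ls =>
      val res = op(y, x)
      (res :: ls, res)
    case _ =>
      sys.error("Missing operands error!")
  }

// "1 2 + 3 *" unrolled: every step ignores the previous result value
// and relies only on the stack that flatMap threads through.
val unrolled: State[List[Int], Int] =
  push(1)
    .flatMap(_ => push(2))
    .flatMap(_ => operator(_ + _))
    .flatMap(_ => push(3))
    .flatMap(_ => operator(_ * _))

unrolled.result(Nil)      // (List(9),9)
unrolled.result(List(7))  // (List(9, 7),9)
```

The second call shows that whatever was already on the stack stays underneath the computed result.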
