Monthly Archives: October 2019

Spark – Schema With Nested Columns

Extracting columns based on certain criteria from a DataFrame (or Dataset) with a flat schema of only top-level columns is simple. It gets slightly less trivial, though, if the schema consists of hierarchical nested columns.

Recursive traversal

In functional programming, a common tactic for traversing arbitrarily nested collections of elements is recursion. It’s generally preferred over while-loops with mutable counters. For performance at scale, making the traversal tail-recursive may be necessary – although it’s less of a concern in this case, given that a DataFrame typically consists of no more than a few hundred columns and a few levels of nesting.
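
For reference, a tail-recursive traversal might look like the sketch below. It is only an illustration, not the approach used in the examples that follow: the helper name `leafPaths` is hypothetical, and it handles only `StructType` nesting (arrays are treated as leaves). An explicit work list replaces the call stack, which lets the compiler verify the `@tailrec` annotation.

```scala
import scala.annotation.tailrec
import org.apache.spark.sql.types._

// Hypothetical helper: lists the dotted paths of all non-struct columns.
def leafPaths(schema: StructType): Seq[String] = {
  @tailrec
  def loop(pending: List[(String, StructField)], acc: Vector[String]): Vector[String] =
    pending match {
      case Nil =>
        acc
      case (prefix, StructField(name, st: StructType, _, _)) :: rest =>
        // Put the children at the front of the work list to keep depth-first order.
        val children = st.fields.toList.map(f => (s"$prefix$name.", f))
        loop(children ::: rest, acc)
      case (prefix, StructField(name, _, _, _)) :: rest =>
        loop(rest, acc :+ s"$prefix$name")
    }
  loop(schema.fields.toList.map(f => ("", f)), Vector.empty)
}
```

Only the schema types are needed here, so the helper can be tried on a hand-built `StructType` without a running SparkSession.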

We’re going to illustrate in a couple of simple examples how recursion can be used to effectively process a DataFrame with a schema of nested columns.

Example #1:  Get all nested columns of a given data type

Consider the following snippet:

import org.apache.spark.sql.types._

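// Returns the qualified names of all columns (nested ones included) whose data type equals `dType`.
def getColsByType(schema: StructType, dType: DataType): Seq[String] = {
  def recurPull(sType: StructType, prefix: String): Seq[String] = sType.fields.toSeq.flatMap {
    // Struct column: descend into its children, extending the prefix with "name."
    case StructField(name, dt: StructType, _, _) =>
      recurPull(dt, s"$prefix$name.")
    // Array-of-struct column: descend into the element type, marking the array level with "[]"
    case StructField(name, ArrayType(et: StructType, _), _, _) =>
      recurPull(et, s"$prefix$name[].")
    // Any other column: keep it only if its data type matches
    case StructField(name, dt, _, _) if dt == dType =>
      Seq(s"$prefix$name")
    case _ =>
      Seq.empty[String]
  }
  recurPull(schema, "")
}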

The recursive helper `recurPull` walks the fields of the schema. When it encounters a `StructType` column, or an `ArrayType` column whose elements are of `StructType`, it recurs into the child columns. During the traversal, a string `prefix` is assembled to express the position of each nested column in the hierarchy (with `[]` marking an array level), and it gets prepended to the name of every column with the matching data type.

Testing the method:

import org.apache.spark.sql.functions._
import spark.implicits._

case class Spec(sid: Int, desc: String)
case class Prod(pid: String, spec: Spec)

val df = Seq(
  (101, "jenn", Seq(1, 2), Seq(Spec(1, "A"), Spec(2, "B")), Prod("X11", Spec(11, "X")), 1100.0),
  (202, "mike", Seq(3), Seq(Spec(3, "C")), Prod("Y22", Spec(22, "Y")), 2200.0)
).toDF("uid", "user", "ids", "specs", "product", "amount")

getColsByType(df.schema, DoubleType)
// res1: Seq[String] = ArraySeq(amount)

getColsByType(df.schema, IntegerType)
// res2: Seq[String] = ArraySeq(uid, specs[].sid, product.spec.sid)

getColsByType(df.schema, StringType)
// res3: Seq[String] = ArraySeq(user, specs[].desc, product.pid, product.spec.desc)

// getColsByType(df.schema, ArrayType)
// ERROR: type mismatch;
//  found   : org.apache.spark.sql.types.ArrayType.type
//  required: org.apache.spark.sql.types.DataType

getColsByType(df.schema, ArrayType(IntegerType, false))
// res4: Seq[String] = ArraySeq(ids)
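
The type mismatch above arises because `ArrayType` on its own refers to the companion object rather than a `DataType` instance, so an exact element type has to be supplied. If the goal is to match, say, every array column regardless of element type, one possible variation is to pass a predicate instead of a data type. The sketch below assumes a hypothetical helper named `getColsWhere`; unlike `getColsByType`, it tests each column first and then descends, so struct and array-of-struct columns themselves can be reported too.

```scala
import org.apache.spark.sql.types._

// Hypothetical variant of getColsByType: selects columns via a predicate on DataType.
def getColsWhere(schema: StructType)(p: DataType => Boolean): Seq[String] = {
  def recur(sType: StructType, prefix: String): Seq[String] =
    sType.fields.toSeq.flatMap { field =>
      val path = s"$prefix${field.name}"
      // Report the column itself if it satisfies the predicate ...
      val self = if (p(field.dataType)) Seq(path) else Seq.empty[String]
      // ... then descend into nested structs, if any.
      val nested = field.dataType match {
        case st: StructType               => recur(st, s"$path.")
        case ArrayType(st: StructType, _) => recur(st, s"$path[].")
        case _                            => Seq.empty[String]
      }
      self ++ nested
    }
  recur(schema, "")
}
```

With the DataFrame above, `getColsWhere(df.schema)(_.isInstanceOf[ArrayType])` would be expected to pick up both `ids` and `specs`.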

Example #2:  Rename all nested columns via a provided function

In this example, we’re going to rename columns in a DataFrame with a nested schema based on a provided `rename` function. The required logic for recursively traversing the nested columns is pretty much the same as in the previous example.

import org.apache.spark.sql.types._

// Returns a copy of `schema` in which every column name, at every nesting level, is passed through `rename`.
def renameAllCols(schema: StructType, rename: String => String): StructType = {
  def recurRename(sType: StructType): Seq[StructField] = sType.fields.toSeq.map {
    // Struct column: rename it and rebuild its children
    case StructField(name, dt: StructType, nu, meta) =>
      StructField(rename(name), StructType(recurRename(dt)), nu, meta)
    // Array-of-struct column: rename it and rebuild the element struct, preserving containsNull
    case StructField(name, ArrayType(et: StructType, containsNull), nu, meta) =>
      StructField(rename(name), ArrayType(StructType(recurRename(et)), containsNull), nu, meta)
    // Any other column: rename only
    case StructField(name, dt, nu, meta) =>
      StructField(rename(name), dt, nu, meta)
  }
  StructType(recurRename(schema))
}

Testing the method (with the same DataFrame used in the previous example):

import org.apache.spark.sql.functions._
import spark.implicits._

// Insert an underscore before a trailing "id" (only the suffix is touched)
val renameFcn = (s: String) =>
  if (s.endsWith("id")) s.stripSuffix("id") + "_id" else s

val newDF = spark.createDataFrame(df.rdd, renameAllCols(df.schema, renameFcn))

newDF.printSchema
// root
//  |-- u_id: integer (nullable = false)
//  |-- user: string (nullable = true)
//  |-- ids: array (nullable = true)
//  |    |-- element: integer (containsNull = false)
//  |-- specs: array (nullable = true)
//  |    |-- element: struct (containsNull = true)
//  |    |    |-- s_id: integer (nullable = false)
//  |    |    |-- desc: string (nullable = true)
//  |-- product: struct (nullable = true)
//  |    |-- p_id: string (nullable = true)
//  |    |-- spec: struct (nullable = true)
//  |    |    |-- s_id: integer (nullable = false)
//  |    |    |-- desc: string (nullable = true)
//  |-- amount: double (nullable = false)
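
Any `String => String` function can serve as the `rename` argument. As a further illustration that is not part of the original example, here is a hypothetical `toSnakeCase` function that converts camelCase column names to snake_case using a plain regular expression:

```scala
// Hypothetical rename function: inserts "_" at each lower-to-upper case boundary, then lowercases.
val toSnakeCase: String => String =
  _.replaceAll("([a-z0-9])([A-Z])", "$1_$2").toLowerCase

Seq("userId", "productSpecId", "amount").map(toSnakeCase)
// List(user_id, product_spec_id, amount)
```

It could then be passed along in the same way as `renameFcn`, e.g. `renameAllCols(df.schema, toSnakeCase)`.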

In case it isn’t obvious, in traversing a given `StructType`’s child columns, we use `map` (as opposed to `flatMap` in the previous example) to preserve the hierarchical column structure.
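
The difference can be seen in isolation with a toy tree. The sketch below uses hypothetical `Node`, `Leaf` and `Branch` types as a stand-in for a nested schema (they are not Spark classes): the `flatMap`-based function collapses the tree into a flat list of paths, whereas the `map`-based one returns a tree of the same shape.

```scala
// Toy stand-in for a nested schema (hypothetical types, not Spark's).
sealed trait Node { def name: String }
case class Leaf(name: String) extends Node
case class Branch(name: String, children: Seq[Node]) extends Node

// flatMap: a branch contributes the paths of its descendants, so nesting collapses into one list.
def leafNames(nodes: Seq[Node], prefix: String = ""): Seq[String] = nodes.flatMap {
  case Leaf(n)       => Seq(s"$prefix$n")
  case Branch(n, cs) => leafNames(cs, s"$prefix$n.")
}

// map: each node is replaced by exactly one node, so the shape of the tree is unchanged.
def renameAll(nodes: Seq[Node], f: String => String): Seq[Node] = nodes.map {
  case Leaf(n)       => Leaf(f(n))
  case Branch(n, cs) => Branch(f(n), renameAll(cs, f))
}

val tree: Seq[Node] =
  Seq(Leaf("uid"), Branch("product", Seq(Leaf("pid"), Branch("spec", Seq(Leaf("sid"))))))

leafNames(tree)
// List(uid, product.pid, product.spec.sid)

renameAll(tree, _.toUpperCase)
// List(Leaf(UID), Branch(PRODUCT,List(Leaf(PID), Branch(SPEC,List(Leaf(SID))))))
```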