Extracting columns that meet certain criteria from a DataFrame (or Dataset) with a flat schema of only top-level columns is simple. It gets slightly less trivial, though, when the schema consists of hierarchically nested columns.
Recursive traversal
In functional programming, a common tactic for traversing arbitrarily nested collections of elements is recursion. It's generally preferred over while-loops with mutable counters. For performance at scale, making the traversal tail-recursive may be necessary, although that's less of a concern in this case given that a DataFrame typically consists of no more than a few hundred columns and a few levels of nesting.
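To make the idea concrete, here is a minimal sketch of a stack-safe, tail-recursive traversal that merely counts the leaf columns of a nested schema. It isn't part of the examples that follow, and the helper name `countLeafCols` is made up for illustration.

```scala
import scala.annotation.tailrec
import org.apache.spark.sql.types._

// Counts the leaf columns of a (possibly nested) schema using an explicit
// work list, keeping the recursion tail-recursive and hence stack-safe.
def countLeafCols(schema: StructType): Int = {
  @tailrec
  def loop(pending: List[DataType], acc: Int): Int = pending match {
    case Nil                                    ⇒ acc
    case StructType(fields) :: rest             ⇒ loop(fields.map(_.dataType).toList ::: rest, acc)
    case ArrayType(elem: StructType, _) :: rest ⇒ loop(elem :: rest, acc)
    case _ :: rest                              ⇒ loop(rest, acc + 1)
  }
  loop(List(schema), 0)
}
```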
In a couple of simple examples, we're going to illustrate how recursion can be used to process a DataFrame with a schema of nested columns.
Example #1: Get all nested columns of a given data type
Consider the following snippet:
```scala
import org.apache.spark.sql.types._

def getColsByType(schema: StructType, dType: DataType) = {
  def recurPull(sType: StructType, prefix: String): Seq[String] = sType.fields.flatMap {
    // Nested struct: recurse into its children, extending the dotted prefix
    case StructField(name, dt: StructType, _, _) ⇒
      recurPull(dt, s"$prefix$name.")
    // Array of structs: recurse into the element type, marking the array with "[]"
    case StructField(name, dt: ArrayType, _, _) if dt.elementType.isInstanceOf[StructType] ⇒
      recurPull(dt.elementType.asInstanceOf[StructType], s"$prefix$name[].")
    // Leaf column of the requested data type: emit its fully qualified name
    case field @ StructField(name, _, _, _) if field.dataType == dType ⇒
      Seq(s"$prefix$name")
    // Anything else: nothing to collect
    case _ ⇒
      Seq.empty[String]
  }
  recurPull(schema, "")
}
```
By means of a simple recursive method, the data type of each column in the DataFrame is inspected; whenever a `StructType` (or an array of `StructType`) is encountered, the method recurses into its child columns. Along the way, a string `prefix` is assembled to express the hierarchy of the nested columns and is prepended to every column whose data type matches. For instance, the integer field `sid` inside the `spec` struct of `product` surfaces as `product.spec.sid`, while the same field inside the `specs` array of structs surfaces as `specs[].sid`.
Testing the method:
```scala
import org.apache.spark.sql.functions._
import spark.implicits._

case class Spec(sid: Int, desc: String)
case class Prod(pid: String, spec: Spec)

val df = Seq(
  (101, "jenn", Seq(1, 2), Seq(Spec(1, "A"), Spec(2, "B")), Prod("X11", Spec(11, "X")), 1100.0),
  (202, "mike", Seq(3), Seq(Spec(3, "C")), Prod("Y22", Spec(22, "Y")), 2200.0)
).toDF("uid", "user", "ids", "specs", "product", "amount")

getColsByType(df.schema, DoubleType)
// res1: Seq[String] = ArraySeq(amount)

getColsByType(df.schema, IntegerType)
// res2: Seq[String] = ArraySeq(uid, specs[].sid, product.spec.sid)

getColsByType(df.schema, StringType)
// res3: Seq[String] = ArraySeq(user, specs[].desc, product.pid, product.spec.desc)

// getColsByType(df.schema, ArrayType)
// ERROR: type mismatch;
//  found   : org.apache.spark.sql.types.ArrayType.type
//  required: org.apache.spark.sql.types.DataType

getColsByType(df.schema, ArrayType(IntegerType, false))
// res4: Seq[String] = ArraySeq(ids)
```
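As a possible follow-up (not part of the original snippet), the returned paths can be fed straight back into `select` to actually extract the matching columns. Paths carrying the `[]` array marker aren't directly addressable column references, so the sketch below simply filters them out first.

```scala
import org.apache.spark.sql.functions.col

// Paths like "specs[].desc" point inside an array of structs and cannot be
// selected as-is, so keep only the directly addressable ones.
val selectableStringCols = getColsByType(df.schema, StringType).filterNot(_.contains("[]"))
// e.g. Seq(user, product.pid, product.spec.desc)

df.select(selectableStringCols.map(col): _*).show()
```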
Example #2: Rename all nested columns via a provided function
In this example, we’re going to rename columns in a DataFrame with a nested schema based on a provided `rename` function. The required logic for recursively traversing the nested columns is pretty much the same as in the previous example.
```scala
import org.apache.spark.sql.types._

def renameAllCols(schema: StructType, rename: String ⇒ String): StructType = {
  def recurRename(sType: StructType): Seq[StructField] = sType.fields.map {
    // Nested struct: rename the field and rebuild its children recursively
    case StructField(name, dt: StructType, nu, meta) ⇒
      StructField(rename(name), StructType(recurRename(dt)), nu, meta)
    // Array of structs: rename the field and rebuild the element struct recursively
    case StructField(name, dt: ArrayType, nu, meta) if dt.elementType.isInstanceOf[StructType] ⇒
      StructField(rename(name), ArrayType(StructType(recurRename(dt.elementType.asInstanceOf[StructType])), true), nu, meta)
    // Leaf column: just rename it
    case StructField(name, dt, nu, meta) ⇒
      StructField(rename(name), dt, nu, meta)
  }
  StructType(recurRename(schema))
}
```
Testing the method (with the same DataFrame used in the previous example):
```scala
import org.apache.spark.sql.functions._
import spark.implicits._

val renameFcn = (s: String) ⇒ if (s.endsWith("id")) s.replace("id", "_id") else s

val newDF = spark.createDataFrame(df.rdd, renameAllCols(df.schema, renameFcn))

newDF.printSchema
// root
//  |-- u_id: integer (nullable = false)
//  |-- user: string (nullable = true)
//  |-- ids: array (nullable = true)
//  |    |-- element: integer (containsNull = false)
//  |-- specs: array (nullable = true)
//  |    |-- element: struct (containsNull = true)
//  |    |    |-- s_id: integer (nullable = false)
//  |    |    |-- desc: string (nullable = true)
//  |-- product: struct (nullable = true)
//  |    |-- p_id: string (nullable = true)
//  |    |-- spec: struct (nullable = true)
//  |    |    |-- s_id: integer (nullable = false)
//  |    |    |-- desc: string (nullable = true)
//  |-- amount: double (nullable = false)
```
In case it isn't obvious: in traversing a given `StructType`'s child columns, we use `map` (as opposed to `flatMap` in the previous example) to preserve the hierarchical column structure.
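Since `renameAllCols` accepts an arbitrary `String ⇒ String` function, any renaming policy can be plugged in. As a quick sketch (reusing `df` from the earlier examples, with an illustrative rename function not taken from the original post), here's the same approach with every column name upper-cased at every nesting level.

```scala
// Swap in a different rename function: upper-case every column name,
// including the nested ones.
val upperDF = spark.createDataFrame(df.rdd, renameAllCols(df.schema, _.toUpperCase))

upperDF.printSchema
// root
//  |-- UID: integer (nullable = false)
//  |-- USER: string (nullable = true)
//  ...
```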