A Merkle tree, a.k.a. hash tree, is a tree in which every leaf node contains a cryptographic hash of a dataset, and every branch node contains a hash of the concatenation of the corresponding hashes of its child nodes. Typical usage is for efficient verification of the content stored in the tree nodes.
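As a rough illustration of that definition, the sketch below computes a leaf hash and a branch hash. The object name `MerkleNodeSketch` and its helpers are hypothetical, and SHA-256 is only an assumed choice of hash function; any `Array[Byte] => Array[Byte]` hash would fit the same shape.

```scala
import java.security.MessageDigest

object MerkleNodeSketch {
  // Assumption: SHA-256 stands in for "a cryptographic hash"
  def sha256(bytes: Array[Byte]): Array[Byte] =
    MessageDigest.getInstance("SHA-256").digest(bytes)

  // A leaf stores the hash of its data block
  def leafHash(block: Array[Byte]): Array[Byte] = sha256(block)

  // A branch stores the hash of its children's hashes, concatenated left to right
  def branchHash(left: Array[Byte], right: Array[Byte]): Array[Byte] =
    sha256(left ++ right)
}
```

Because the child hashes are concatenated in order, swapping the two children yields a different branch hash, so the tree commits to the position of each data block as well as its content.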
Blockchain and Merkle tree
As cryptocurrencies (or more generally, blockchain systems) have become popular, so has their underlying authentication-oriented data structure, the Merkle tree. In the cryptocurrency world, a blockchain can be viewed as a distributed ledger consisting of immutable but chain-able blocks, each of which hosts a set of transactions in the form of a Merkle tree. To chain a new block to an existing blockchain, part of the tamper-proof requirement is to guarantee the integrity of the enclosed transactions by composing their hashes in a specific way and storing them in a Merkle tree.
In case the above sounds like gibberish, here’s a great introductory article about blockchain. To delve slightly deeper into it with a focus on cryptocurrency, this blockchain guide from the Bitcoin Project website might be of interest. Just to be clear, even though blockchain has helped popularize the Merkle tree, implementing a flavor of the data structure does not require knowledge of blockchain or cryptocurrency.
In this blog post, we will assemble a bare-bones Merkle tree using Scala. While a Merkle tree is most often a binary tree, it does not have to be one, although a binary tree is what we’re going to implement.
A bare-bones Merkle tree class
    // MerkleTree class
    class MerkleTree(
      val hash: Array[Byte],
      val left: Option[MerkleTree] = None,
      val right: Option[MerkleTree] = None
    )
Note that a node whose `left` and `right` fields are both `None` is a leaf node.
To build a Merkle tree from a collection of byte-arrays (which might represent a transaction dataset), we will use a companion object to perform the task via its `apply` method. To create a `hash` within each of the tree nodes, we will also need a hash function, `hashFcn` of type `Array[Byte] => Array[Byte]`.
    // MerkleTree companion object
    object MerkleTree {
      def apply(data: Array[Array[Byte]], hashFcn: Array[Byte] => Array[Byte]): MerkleTree = {
        val nodes = data.map(byteArr => new MerkleTree(hashFcn(byteArr)))
        buildTree(nodes, hashFcn)(0)  // Return root of the tree
      }

      private def buildTree(...): Array[MerkleTree] = ???
    }
Building a Merkle tree
As shown in the code, function `buildTree` needs to pair up the nodes recursively to form a tree in which each node contains the combined hash of its child nodes. When a level has an odd number of nodes, the unpaired last node is hashed on its own into a parent that has only a left child. The recursive pairing eventually ends with a single top-level node called the Merkle root. Below is an implementation of such a function:
    // Building a Merkle tree
    @scala.annotation.tailrec
    private def buildTree(
        nodes: Array[MerkleTree],
        hashFcn: Array[Byte] => Array[Byte]): Array[MerkleTree] = nodes match {
      case ns if ns.size <= 1 =>
        ns
      case ns =>
        val pairedNodes = ns.grouped(2).map{
          case Array(a, b) => new MerkleTree(hashFcn(a.hash ++ b.hash), Some(a), Some(b))
          case Array(a)    => new MerkleTree(hashFcn(a.hash), Some(a), None)
        }.toArray
        buildTree(pairedNodes, hashFcn)
    }
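To see how the pairing shrinks each level, here is a small sketch that only counts nodes per level under the same rule (pairs are merged, and an unpaired trailing node gets a parent of its own). The object `LevelSizesSketch` and the function `levelSizes` are hypothetical helpers for illustration and are not part of the tree implementation.

```scala
object LevelSizesSketch {
  // Number of nodes on each level, from the leaves up to the root,
  // when nodes are grouped in twos as in buildTree
  def levelSizes(leafCount: Int): List[Int] = {
    require(leafCount >= 1, "expects at least one leaf")

    @scala.annotation.tailrec
    def loop(n: Int, acc: List[Int]): List[Int] =
      if (n <= 1) (n :: acc).reverse
      else loop((n + 1) / 2, n :: acc) // ceil(n / 2) parents for n nodes

    loop(leafCount, Nil)
  }
}

// LevelSizesSketch.levelSizes(5) == List(5, 3, 2, 1)
```

The length of the resulting list is the number of levels in the tree, so 5 leaves produce a tree of 4 levels.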
Now, back to class `MerkleTree`: let’s add a simple function for computing the height of the tree:
    // Computing the height of a Merkle tree
    def height: Int = {
      def loop(node: MerkleTree): Int = {
        if (!node.left.isEmpty && !node.right.isEmpty)
          math.max(loop(node.left.get), loop(node.right.get)) + 1
        else if (!node.left.isEmpty)
          loop(node.left.get) + 1
        else if (!node.right.isEmpty)
          loop(node.right.get) + 1
        else
          1
      }
      loop(this)
    }
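The same computation can be written more compactly by treating the two optional children as a collection. The sketch below uses a hypothetical stand-in class `Node`, reduced to its child links, so that it runs on its own; it is meant as an alternative formulation, not a replacement for the method above.

```scala
object HeightSketch {
  // Hypothetical stand-in for MerkleTree, reduced to its two child links
  final case class Node(left: Option[Node] = None, right: Option[Node] = None)

  // A leaf has height 1; otherwise the height is 1 + the height of the taller child
  def height(node: Node): Int =
    1 + List(node.left, node.right).flatten.map(height).foldLeft(0)(_ max _)
}
```

Flattening the two `Option` values drops any missing child, so the four-way branching on `left` and `right` collapses into a single expression.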
Putting all the pieces together
For illustration purposes, we’ll add a side-effecting function `printNodes` along with a couple of display-only utility functions so we can see what our Merkle tree looks like. Putting everything together, we have:
    // Merkle tree class and companion object
    class MerkleTree(
      val hash: Array[Byte],
      val left: Option[MerkleTree] = None,
      val right: Option[MerkleTree] = None) {

      def height: Int = {
        def loop(node: MerkleTree): Int = {
          if (!node.left.isEmpty && !node.right.isEmpty)
            math.max(loop(node.left.get), loop(node.right.get)) + 1
          else if (!node.left.isEmpty)
            loop(node.left.get) + 1
          else if (!node.right.isEmpty)
            loop(node.right.get) + 1
          else
            1
        }
        loop(this)
      }

      override def toString: String = s"MerkleTree(hash = ${bytesToBase64(hash)})"

      private def toShortString: String = s"MT(${bytesToBase64(hash).substring(0, 4)})"

      private def bytesToBase64(bytes: Array[Byte]): String =
        java.util.Base64.getEncoder.encodeToString(bytes)

      def printNodes: Unit = {
        // Print one line per tree level, from the root down
        def printlnByLevel(t: MerkleTree): Unit = {
          for (l <- 1 to t.height) {
            loopByLevel(t, l)
            println()
          }
        }

        // Print the nodes found at the given level below `node`
        def loopByLevel(node: MerkleTree, level: Int): Unit = {
          if (level <= 1)
            print(s"${node.toShortString} ")
          else {
            if (!node.left.isEmpty) loopByLevel(node.left.get, level - 1) else ()
            if (!node.right.isEmpty) loopByLevel(node.right.get, level - 1) else ()
          }
        }

        printlnByLevel(this)
      }
    }

    object MerkleTree {
      def apply(data: Array[Array[Byte]], hashFcn: Array[Byte] => Array[Byte]): MerkleTree = {
        @scala.annotation.tailrec
        def buildTree(nodes: Array[MerkleTree]): Array[MerkleTree] = nodes match {
          case ns if ns.size <= 1 =>
            ns
          case ns =>
            val pairedNodes = ns.grouped(2).map{
              case Array(a, b) => new MerkleTree(hashFcn(a.hash ++ b.hash), Some(a), Some(b))
              case Array(a)    => new MerkleTree(hashFcn(a.hash), Some(a), None)
            }.toArray
            buildTree(pairedNodes)
        }

        if (data.isEmpty)
          new MerkleTree(hashFcn(Array.empty[Byte]))
        else {
          val nodes = data.map(byteArr => new MerkleTree(hashFcn(byteArr)))
          buildTree(nodes)(0)  // Return root of the tree
        }
      }
    }
Test building the Merkle tree with a hash function
Let’s supply the required arguments to MerkleTree’s `apply` factory method and create a Merkle tree from, say, 5 dummy byte-arrays using the popular hash function `SHA-256`. The resulting Merkle tree is represented by its root node, a.k.a. the Merkle root:
    // Test building a Merkle tree
    val sha256: Array[Byte] => Array[Byte] =
      byteArr => java.security.MessageDigest.getInstance("SHA-256").digest(byteArr)

    val data = Array(
      Array[Byte](0, 1, 2),
      Array[Byte](3, 4, 5),
      Array[Byte](6, 7, 8),
      Array[Byte](9, 10, 11),
      Array[Byte](12, 13, 14)
    )

    val mRoot = MerkleTree(data, sha256)
    // mRoot: MerkleTree = MerkleTree(hash = C6OoSW1rymkivkJPBrhf9necQAbzPsq7RyZKd4ZU8hM=)

    mRoot.hash
    // res1: Array[Byte] = Array(11, -93, -88, 73, 109, 107, -54, 105, ...)

    mRoot.printNodes
    // MT(C6Oo)
    // MT(QKRB) MT(Hfri)
    // MT(J3JM) MT(d7VY) MT(9ypu)
    // MT(rksy) MT(KEhp) MT(Q4f2) MT(45qR) MT(BSpd)
As the output shows, the 5 dummy data blocks are hashed and placed in the 5 leaf nodes. Each leaf’s hash is then combined with its sibling’s (if any), hashed again, and placed in the parent node.
For clarity, below is the same output laid out as a tree:
    //                        MT(C6Oo)
    //                       /        \
    //             ----------          -----------
    //            /                               \
    //        MT(QKRB)                          MT(Hfri)
    //        /      \                           /
    //       /        \                         /
    //   MT(J3JM)    MT(d7VY)               MT(9ypu)
    //    /    \      /    \                  /
    //   /      \    /      \                /
    // MT(rksy) MT(KEhp) MT(Q4f2) MT(45qR) MT(BSpd)
Building a Merkle tree from blockchain transactions
To apply the Merkle tree in a blockchain setting, we’ll replace the dummy data blocks with a sequence of transactions from a blockchain.
First, we define a trivialized `Transaction` class with the transaction ID being the hash value of the combined class fields using the same hash function `sha256`:
    // A trivialized Transaction class
    import java.nio.file.{Files, Paths}
    import java.time.{Instant, LocalDateTime, ZoneId}
    import java.time.format.DateTimeFormatter

    case class Transaction(id: String, from: String, to: String, amount: Long, timestamp: Long) {
      override def toString: String = {
        val ts = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss").format(
          LocalDateTime.ofInstant(Instant.ofEpochMilli(timestamp), ZoneId.of("UTC"))
        )
        s"T(${id.substring(0, 4)}, ${from} -> ${to}, ${amount}, ${ts})"
      }
    }

    object Transaction {
      def apply(
          from: String,
          to: String,
          amount: Long,
          timestamp: Long,
          hashFcn: Array[Byte] => Array[Byte]
        ): Transaction = {
        // The transaction ID is the Base64-encoded hash of the combined field bytes
        val bytes = from.getBytes ++ to.getBytes ++ longToBytes(amount) ++ longToBytes(timestamp)
        new Transaction(bytesToBase64(hashFcn(bytes)), from, to, amount, timestamp)
      }

      private def bytesToBase64(bytes: Array[Byte]): String =
        java.util.Base64.getEncoder.encodeToString(bytes)

      private def longToBytes(num: Long) =
        java.nio.ByteBuffer.allocate(8).putLong(num).array
    }
Next, we create an array of transactions:
    // An array of transactions
    val transactions = Array(
      Transaction("Alice", "Bob", 2500L, 1582587819175L, sha256),
      Transaction("Bob", "Carl", 4000L, 1582588700350L, sha256),
      Transaction("Carl", "Dana", 4000L, 1582591774502L, sha256)
    )
    // transactions: Array[Transaction] = Array(
    //   T(ikSk, Alice -> Bob, 2500, 2020-02-24 23:43:39),
    //   T(+EvZ, Bob -> Carl, 4000, 2020-02-24 23:58:20),
    //   T(Ke8m, Carl -> Dana, 4000, 2020-02-25 00:49:34)
    // )
Again, using MerkleTree’s `apply` factory method, we build a Merkle tree consisting of hash values of the individual transaction IDs, which in turn are hashes of their corresponding transaction content:
    // Creating the Merkle root
    val mRoot = MerkleTree(transactions.map(_.id.getBytes), sha256)
    // mRoot: MerkleTree = MerkleTree(hash = CobcOH899Hq91uk3cXRR5As5J+ThqLacivYEZifhhfM=)

    mRoot.printNodes
    // MT(Cobc)
    // MT(VG1C) MT(g7oF)
    // MT(nhJS) MT(nt1U) MT(pslD)
The Merkle root, along with the associated transactions, is kept in an immutable block. It is also one of the elements that are collectively hashed into the block-identifying hash value. The block hash then serves as the linking block ID for the next block that successfully appends to it. All the cross-hashing operations, coupled with the immutable block structure, make any attempt to tamper with the blockchain content highly expensive.
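To make the tamper-evidence point a little more concrete, here is a minimal, self-contained sketch that recomputes a root hash after one data block has been altered. The object `TamperCheckSketch` and its helpers `merkleRoot` and `sameRoot` are hypothetical; `merkleRoot` computes only the root hash, assuming SHA-256 and the same pairing rule used in this post (a lone trailing node is re-hashed on its own).

```scala
import java.security.MessageDigest
import java.util.Arrays

object TamperCheckSketch {
  def sha256(bytes: Array[Byte]): Array[Byte] =
    MessageDigest.getInstance("SHA-256").digest(bytes)

  // Hypothetical helper: hash the blocks into leaves, then fold levels up to the root
  def merkleRoot(blocks: Seq[Array[Byte]]): Array[Byte] = {
    @scala.annotation.tailrec
    def reduce(level: Vector[Array[Byte]]): Array[Byte] =
      if (level.size == 1) level.head
      else reduce(level.grouped(2).map {
        case Seq(a, b) => sha256(a ++ b) // a pair of siblings
        case Seq(a)    => sha256(a)      // an unpaired trailing node
      }.toVector)

    if (blocks.isEmpty) sha256(Array.empty[Byte])
    else reduce(blocks.map(b => sha256(b)).toVector)
  }

  // Two datasets are considered identical if their root hashes match
  def sameRoot(a: Seq[Array[Byte]], b: Seq[Array[Byte]]): Boolean =
    Arrays.equals(merkleRoot(a), merkleRoot(b))
}

val original = Seq("tx1", "tx2", "tx3").map(_.getBytes("UTF-8"))
val tampered = Seq("tx1", "txX", "tx3").map(_.getBytes("UTF-8"))

TamperCheckSketch.sameRoot(original, original) // true
TamperCheckSketch.sameRoot(original, tampered) // false
```

Changing any single block changes its leaf hash and therefore every hash on the path up to the root, so comparing the two root hashes alone is enough to detect the modification.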