Implement Merkle Tree operations.

Implements the operations on Merkle Trees as defined in RFC 9162, and
includes the extensions that work with subtrees as defined in
draft-davidben-tls-merkle-tree-certs.

Change-Id: I0a9fca045146d951c804e8f30f2fa3bd561ddb15
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/82567
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: Nick Harper <nharper@chromium.org>
diff --git a/.clang-format-ignore b/.clang-format-ignore
new file mode 100644
index 0000000..85b2aef
--- /dev/null
+++ b/.clang-format-ignore
@@ -0,0 +1 @@
+build.json
diff --git a/build.json b/build.json
index 8f77c42..f8ac562 100644
--- a/build.json
+++ b/build.json
@@ -613,6 +613,7 @@
             "pki/general_names.cc",
             "pki/input.cc",
             "pki/ip_util.cc",
+            "pki/merkle_tree.cc",
             "pki/name_constraints.cc",
             "pki/ocsp.cc",
             "pki/parse_certificate.cc",
@@ -658,6 +659,7 @@
             "pki/general_names.h",
             "pki/input.h",
             "pki/ip_util.h",
+            "pki/merkle_tree.h",
             "pki/mock_signature_verify_cache.h",
             "pki/name_constraints.h",
             "pki/nist_pkits_unittest.h",
@@ -963,6 +965,7 @@
             "pki/general_names_unittest.cc",
             "pki/input_unittest.cc",
             "pki/ip_util_unittest.cc",
+            "pki/merkle_tree_unittest.cc",
             "pki/mock_signature_verify_cache.cc",
             "pki/name_constraints_unittest.cc",
             "pki/nist_pkits_unittest.cc",
diff --git a/gen/sources.bzl b/gen/sources.bzl
index f7f8f3f..b1a7ffc 100644
--- a/gen/sources.bzl
+++ b/gen/sources.bzl
@@ -1221,6 +1221,7 @@
     "pki/general_names.cc",
     "pki/input.cc",
     "pki/ip_util.cc",
+    "pki/merkle_tree.cc",
     "pki/name_constraints.cc",
     "pki/ocsp.cc",
     "pki/parse_certificate.cc",
@@ -1267,6 +1268,7 @@
     "pki/general_names.h",
     "pki/input.h",
     "pki/ip_util.h",
+    "pki/merkle_tree.h",
     "pki/mock_signature_verify_cache.h",
     "pki/name_constraints.h",
     "pki/nist_pkits_unittest.h",
@@ -1304,6 +1306,7 @@
     "pki/general_names_unittest.cc",
     "pki/input_unittest.cc",
     "pki/ip_util_unittest.cc",
+    "pki/merkle_tree_unittest.cc",
     "pki/mock_signature_verify_cache.cc",
     "pki/name_constraints_unittest.cc",
     "pki/nist_pkits_unittest.cc",
diff --git a/gen/sources.cmake b/gen/sources.cmake
index dbac7a5..7f23f4e 100644
--- a/gen/sources.cmake
+++ b/gen/sources.cmake
@@ -1265,6 +1265,7 @@
   pki/general_names.cc
   pki/input.cc
   pki/ip_util.cc
+  pki/merkle_tree.cc
   pki/name_constraints.cc
   pki/ocsp.cc
   pki/parse_certificate.cc
@@ -1315,6 +1316,7 @@
   pki/general_names.h
   pki/input.h
   pki/ip_util.h
+  pki/merkle_tree.h
   pki/mock_signature_verify_cache.h
   pki/name_constraints.h
   pki/nist_pkits_unittest.h
@@ -1354,6 +1356,7 @@
   pki/general_names_unittest.cc
   pki/input_unittest.cc
   pki/ip_util_unittest.cc
+  pki/merkle_tree_unittest.cc
   pki/mock_signature_verify_cache.cc
   pki/name_constraints_unittest.cc
   pki/nist_pkits_unittest.cc
diff --git a/gen/sources.gni b/gen/sources.gni
index a753926..e38d9a7 100644
--- a/gen/sources.gni
+++ b/gen/sources.gni
@@ -1221,6 +1221,7 @@
   "pki/general_names.cc",
   "pki/input.cc",
   "pki/ip_util.cc",
+  "pki/merkle_tree.cc",
   "pki/name_constraints.cc",
   "pki/ocsp.cc",
   "pki/parse_certificate.cc",
@@ -1267,6 +1268,7 @@
   "pki/general_names.h",
   "pki/input.h",
   "pki/ip_util.h",
+  "pki/merkle_tree.h",
   "pki/mock_signature_verify_cache.h",
   "pki/name_constraints.h",
   "pki/nist_pkits_unittest.h",
@@ -1304,6 +1306,7 @@
   "pki/general_names_unittest.cc",
   "pki/input_unittest.cc",
   "pki/ip_util_unittest.cc",
+  "pki/merkle_tree_unittest.cc",
   "pki/mock_signature_verify_cache.cc",
   "pki/name_constraints_unittest.cc",
   "pki/nist_pkits_unittest.cc",
diff --git a/gen/sources.json b/gen/sources.json
index beec447..011adb2 100644
--- a/gen/sources.json
+++ b/gen/sources.json
@@ -1204,6 +1204,7 @@
       "pki/general_names.cc",
       "pki/input.cc",
       "pki/ip_util.cc",
+      "pki/merkle_tree.cc",
       "pki/name_constraints.cc",
       "pki/ocsp.cc",
       "pki/parse_certificate.cc",
@@ -1248,6 +1249,7 @@
       "pki/general_names.h",
       "pki/input.h",
       "pki/ip_util.h",
+      "pki/merkle_tree.h",
       "pki/mock_signature_verify_cache.h",
       "pki/name_constraints.h",
       "pki/nist_pkits_unittest.h",
@@ -1286,6 +1288,7 @@
       "pki/general_names_unittest.cc",
       "pki/input_unittest.cc",
       "pki/ip_util_unittest.cc",
+      "pki/merkle_tree_unittest.cc",
       "pki/mock_signature_verify_cache.cc",
       "pki/name_constraints_unittest.cc",
       "pki/nist_pkits_unittest.cc",
diff --git a/gen/sources.mk b/gen/sources.mk
index 622d932..9a9b350 100644
--- a/gen/sources.mk
+++ b/gen/sources.mk
@@ -1200,6 +1200,7 @@
   pki/general_names.cc \
   pki/input.cc \
   pki/ip_util.cc \
+  pki/merkle_tree.cc \
   pki/name_constraints.cc \
   pki/ocsp.cc \
   pki/parse_certificate.cc \
@@ -1244,6 +1245,7 @@
   pki/general_names.h \
   pki/input.h \
   pki/ip_util.h \
+  pki/merkle_tree.h \
   pki/mock_signature_verify_cache.h \
   pki/name_constraints.h \
   pki/nist_pkits_unittest.h \
@@ -1280,6 +1282,7 @@
   pki/general_names_unittest.cc \
   pki/input_unittest.cc \
   pki/ip_util_unittest.cc \
+  pki/merkle_tree_unittest.cc \
   pki/mock_signature_verify_cache.cc \
   pki/name_constraints_unittest.cc \
   pki/nist_pkits_unittest.cc \
diff --git a/include/openssl/span.h b/include/openssl/span.h
index 4f01247..11798a7 100644
--- a/include/openssl/span.h
+++ b/include/openssl/span.h
@@ -168,6 +168,21 @@
   // NOLINTNEXTLINE(google-explicit-constructor): same as std::span.
   constexpr Span(T (&array)[NA]) : internal::SpanStorage<T, N>(array, NA) {}
 
+  // TODO(crbug.com/457351017): Add tests for these c'tors.
+  template <size_t NA, typename U,
+            typename = internal::EnableIfContainer<std::array<U, NA>, T>,
+            typename = std::enable_if_t<N == NA || N == dynamic_extent>>
+  // NOLINTNEXTLINE(google-explicit-constructor): same as std::span.
+  constexpr Span(std::array<U, NA> &array)
+      : internal::SpanStorage<T, N>(array.data(), NA) {}
+
+  template <size_t NA, typename U,
+            typename = internal::EnableIfContainer<const std::array<U, NA>, T>,
+            typename = std::enable_if_t<N == NA || N == dynamic_extent>>
+  // NOLINTNEXTLINE(google-explicit-constructor): same as std::span.
+  constexpr Span(const std::array<U, NA> &array)
+      : internal::SpanStorage<T, N>(array.data(), NA) {}
+
   template <
       size_t NA, typename U,
       typename = std::enable_if_t<std::is_convertible_v<U (*)[], T (*)[]>>,
@@ -310,6 +325,10 @@
 Span(T *, size_t) -> Span<T>;
 template <typename T, size_t size>
 Span(T (&array)[size]) -> Span<T, size>;
+template <typename T, size_t size>
+Span(std::array<T, size> &array) -> Span<T, size>;
+template <typename T, size_t size>
+Span(const std::array<T, size> &array) -> Span<const T, size>;
 template <
     typename C,
     typename T = std::remove_pointer_t<decltype(std::declval<C>().data())>,
diff --git a/pki/merkle_tree.cc b/pki/merkle_tree.cc
new file mode 100644
index 0000000..8247ce3
--- /dev/null
+++ b/pki/merkle_tree.cc
@@ -0,0 +1,206 @@
+// Copyright 2025 The BoringSSL Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "merkle_tree.h"
+
+#include <algorithm>
+#include <optional>
+
+#include <openssl/mem.h>
+#include <openssl/span.h>
+#include "openssl/sha2.h"
+
+BSSL_NAMESPACE_BEGIN
+
+// Computes HASH(0x01 || left || right) and saves the result to |out|.
+void HashNode(TreeHashConstSpan left, TreeHashConstSpan right,
+              TreeHashSpan out) {
+  static const uint8_t header = 0x01;
+  SHA256_CTX ctx;
+  SHA256_Init(&ctx);
+  SHA256_Update(&ctx, &header, 1);
+  SHA256_Update(&ctx, left.data(), left.size());
+  SHA256_Update(&ctx, right.data(), right.size());
+  SHA256_Final(out.data(), &ctx);
+}
+
+namespace {
+
+std::optional<TreeHashConstSpan> NextProofHash(Span<const uint8_t> *proof) {
+  if (proof->size() < SHA256_DIGEST_LENGTH) {
+    return std::nullopt;
+  }
+  auto ret = proof->first<SHA256_DIGEST_LENGTH>();
+  *proof = proof->subspan(SHA256_DIGEST_LENGTH);
+  return ret;
+}
+
+}  // namespace
+
+std::optional<TreeHash> EvaluateMerkleSubtreeConsistencyProof(
+    uint64_t n, const Subtree &subtree, Span<const uint8_t> proof,
+    TreeHashConstSpan node_hash) {
+  // For more detail on how subtree consistency proofs work, see appendix B
+  // of draft-davidben-tls-merkle-tree-certs-08.
+
+  // Check that inputs are valid. (Step 1)
+  if (!subtree.IsValid() || n < subtree.end) {
+    return std::nullopt;
+  }
+
+  // Initialize fn (first number), sn (second number), and tn (third number).
+  // Each number is the path from the root of the tree to 1) the leftmost child
+  // of the subtree, 2) the rightmost child of the subtree, and 3) the rightmost
+  // child of the full tree. (Step 2)
+  uint64_t fn = subtree.start;
+  uint64_t sn = subtree.end - 1;
+  uint64_t tn = n - 1;
+
+  // The bit patterns of these numbers indicates whether the path goes left or
+  // right (or in some cases on the rightmost edge of a (sub)treee, that a level
+  // of the tree is skipped).
+  //
+  // When consuming the proof, we work up the tree from the bottom level (level
+  // 0) up to the root, and moving up one level in the tree corresponds to
+  // consuming the least significant bit from each of fn, sn, and tn. The proof
+  // can start at a higher level than level 0 if the node on the right edge of
+  // the subtree at that level is also a node of the overall tree.
+  //
+  // Remove bits equally from the right of fn, sn, and tn to skip to the level
+  // of the tree where the proof starts. (Steps 3 and 4)
+  if (sn == tn) {
+    // If sn == tn, then the rightmost child of the subtree and the rightmost
+    // child of the full tree are the same node, meaning that the subtree is
+    // directly contained in the full tree. The proof starts at the same level
+    // as the top of the subtree. That level is identified by advancing
+    // bit-by-bit through fn and sn until they are the same, meaning that we've
+    // moved up the levels of the tree to where the left and right edge of the
+    // subtree reach the same point (the root of the subtree).
+    //
+    // (Step 3: right shift until fn is sn)
+    while (fn != sn) {
+      fn >>= 1;
+      sn >>= 1;
+      tn >>= 1;
+    }
+  } else {
+    // Find the largest full (rather than partial) subtree that the rightmost
+    // edge of the input subtree is still the rightmost edge of, without going
+    // beyond the bounds of the input subtree. Any such full subtree is directly
+    // contained in the tree of hash operations for the full tree, meaning that
+    // the proof can start at that level rather than level 0.
+    //
+    // As long as the path to the rightmost edge of the input subtree is
+    // indicating it's on the right side of its parent node (its LSB is 1), it's
+    // still on the rightmost edge of a full (rather than partial) subtree.
+    // Iteration stops when it's no longer on the rightmost edge of a full
+    // subtree (LSB(sn) is not set) or when we reach the top of the input
+    // subtree (fn is sn).
+    //
+    // (Step 4: right-shift until fn is sn or LSB(sn) is not set)
+    while (fn != sn && (sn & 1) == 1) {
+      fn >>= 1;
+      sn >>= 1;
+      tn >>= 1;
+    }
+  }
+
+  // The proof array starts with the highest node from the subtree's right edge
+  // that is also in the overall tree, and subsequent values from the proof
+  // array are hashed in to compute the values that should be node_hash and
+  // root_hash if the proof is valid. As an optimization, if node_hash is the
+  // hash of the highest node from the subtree's right edge (i.e. the whole
+  // subtree is directly contained in the overall tree), that value is omitted
+  // from the proof.
+  //
+  // In this code, computed_node_hash and computed_root_hash are the values fr
+  // and sr from draft-davidben-tls-merkle-tree-certs-08.
+  // (Steps 5 and 6)
+  TreeHash computed_node_hash, computed_root_hash;
+  if (fn == sn) {
+    // The optimization mentioned above. (Step 5)
+    std::copy(node_hash.begin(), node_hash.end(), computed_node_hash.data());
+    std::copy(node_hash.begin(), node_hash.end(), computed_root_hash.data());
+  } else {
+    // The hashes start from the first value of the proof (Step 6)
+    std::optional<TreeHashConstSpan> first_hash = NextProofHash(&proof);
+    if (!first_hash) {
+      return std::nullopt;
+    }
+    std::copy(first_hash->begin(), first_hash->end(),
+              computed_node_hash.data());
+    std::copy(first_hash->begin(), first_hash->end(),
+              computed_root_hash.data());
+  }
+
+  // Iterate over the (remaining) elements of the proof array and traverse up
+  // the fn/sn/tn paths until we reach the root of the tree. Each step should
+  // consume one element from the proof array and move one level up the tree,
+  // and if the proof is valid both iterators end at the same time. (Step 7)
+  while (!proof.empty()) {
+    auto p = NextProofHash(&proof);
+    if (!p) {
+      return std::nullopt;
+    }
+    // (Step 7.1)
+    if (tn == 0) {
+      // We reached the root of the tree before running out of elements in the
+      // proof.
+      return std::nullopt;
+    }
+    // Update the computed root_hash, and if applicable the computed node_hash.
+    // We stop updating computed_node_hash when we've reached the level of the
+    // root of the subtree, which occurs when the paths to the leftmost and
+    // rightmost nodes of the subtree are the same, i.e. fn == sn.
+    //
+    // (Step 7.2)
+    if ((sn & 1) == 1 || sn == tn) {
+      // (Step 7.2.1)
+      if (fn < sn) {
+        HashNode(*p, computed_node_hash, computed_node_hash);
+      }
+      // (Step 7.2.2)
+      HashNode(*p, computed_root_hash, computed_root_hash);
+      // Until LSB(sn) is set, right-shift fn, sn, and tn equally.
+      // (Step 7.2.3)
+      while ((sn & 1) == 0) {
+        fn >>= 1;
+        sn >>= 1;
+        tn >>= 1;
+      }
+    } else {
+      // (Step 7.3.1)
+      HashNode(computed_root_hash, *p, computed_root_hash);
+    }
+    // Advance the iterators: (Step 7.4)
+    fn >>= 1;
+    sn >>= 1;
+    tn >>= 1;
+  }
+
+  // Check that the iterators ended together: (Step 8)
+  if (tn != 0) {
+    return std::nullopt;
+  }
+
+  // Check that the computed values match the expected values: (Step 8)
+  if (CRYPTO_memcmp(computed_node_hash.data(), node_hash.data(),
+                    computed_node_hash.size()) != 0) {
+    return std::nullopt;
+  }
+  // Return the computed root_hash for the caller to compare:
+  return computed_root_hash;
+}
+
+BSSL_NAMESPACE_END
diff --git a/pki/merkle_tree.h b/pki/merkle_tree.h
new file mode 100644
index 0000000..4523579
--- /dev/null
+++ b/pki/merkle_tree.h
@@ -0,0 +1,156 @@
+// Copyright 2025 The BoringSSL Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef BSSL_PKI_MERKLE_TREE_H_
+#define BSSL_PKI_MERKLE_TREE_H_
+
+#include <assert.h>
+
+#include <array>
+#include <optional>
+
+#include <openssl/sha2.h>
+#include <openssl/span.h>
+
+BSSL_NAMESPACE_BEGIN
+
+// A Subtree represents a range of elements in a Merkle Tree, identified by the
+// half-open interval [start, end) of tree indexes. A Subtree with start == end
+// represents a range of zero elements.
+struct Subtree {
+  uint64_t start;
+  uint64_t end;
+
+  constexpr bool operator==(const Subtree &other) const {
+    return start == other.start && end == other.end;
+  }
+
+  // Returns the number of elements in the Subtree.
+  constexpr uint64_t Size() const { return end - start; }
+
+  // Returns a value k such that Subtrees left = [start, k) and right = [k, end)
+  // are valid and share no interior nodes. Further, neither left nor right is
+  // empty unless the input Subtree has fewer than 2 elements.
+  constexpr uint64_t Split() const {
+    uint64_t n = Size();
+    if (n < 2) {
+      return end;
+    }
+    // find the largest power of 2 smaller than n
+    uint64_t k = Pow2Smaller(n);
+    return start + k;
+  }
+
+  // Returns the left subtree of this Subtree. If this subtree has fewer than 2
+  // elements, returns itself.
+  constexpr Subtree Left() const { return {start, Split()}; }
+
+  // Returns the right subtree of this Subtree. If this subtree has fewer than 2
+  // elements, returns an empty subtree.
+  constexpr Subtree Right() const { return {Split(), end}; }
+
+  // Returns whether [start, end) specifies a valid Subtree.
+  constexpr bool IsValid() const {
+    // A Subtree's half-open interval must have start <= end, otherwise the
+    // interval is improperly defined.
+    if (start > end) {
+      return false;
+    }
+    uint64_t n = Size();
+    // A Subtree must not have a ragged left edge, i.e. if k is the largest
+    // power of 2 that divides start, n must be less than or equal to k.
+    uint64_t k = start & (~start + 1);
+    return (start == 0 || n <= k);
+  }
+
+  // Returns whether this Subtree contains a leaf node at index.
+  constexpr bool Contains(uint64_t index) const {
+    return start <= index && index < end;
+  }
+  constexpr bool Contains(const Subtree &subtree) const {
+    return start <= subtree.start && subtree.end <= end;
+  }
+
+ private:
+  // compute the largest power of 2 smaller than n. Assumes n >= 2.
+  constexpr static uint64_t Pow2Smaller(uint64_t n) {
+    // TODO(crbug.com/404286922): replace the entirety of this function with
+    // std::bit_floor(n-1) once we can use C++20.
+    assert(n >= 2);
+    // The bitwise OR ladder here (`n |= n >> 1; n |= n >> 2;` etc) takes a
+    // number and copies any 1 bits to all positions to the right, resulting in
+    // a number that looks like 0b00...00111...1, where the number of 1 bits
+    // matches the bit position of the most significant bit in the input number.
+    // Assuming the input number m has the most significant bit in position k,
+    // it produces the value 2^(k+1)-1 >= m. Because both 2^(k+1)-1 and m have
+    // their MSB in the same place, right shifting 2^(k+1)-1 1 bit produces a
+    // number that is strictly smaller than both; that number is 2^k-1. Thus, we
+    // have 2^k-1 < m <= 2^(k+1)-1.
+    //
+    // Using that bit trick and right shifting 1 give us 2^k-1, the largest "one
+    // less than a power of 2" smaller than the input number. (It is the largest
+    // such value because the next largest is 2^(k+1)-1, which we showed to be
+    // greater than or equal to the input.) However, we want a power of 2, not
+    // one less than a power of 2. Finishing the procedure by adding 1 gives us
+    // a power of 2, but it does not guarantee that it is smaller than the
+    // input. If we consider the inequality 2^k-1 < m <= 2^(k+1)-1, we can add 1
+    // to all sides of the inequality to get 2^k < m+1 <= 2^(k+1). Given an
+    // input value n, we want to find the value 2^k such that
+    // 2^k < n <= 2^(k+1).
+    //
+    // Substituting n = m+1 and solving for m = n-1 means that running this
+    // procedure with n-1 will give us the power of 2 we're looking for.
+    n -= 1;
+    n |= n >> 1;
+    n |= n >> 2;
+    n |= n >> 4;
+    n |= n >> 8;
+    n |= n >> 16;
+    n |= n >> 32;
+    return (n >> 1) + 1;
+  }
+};
+
+using TreeHash = std::array<uint8_t, SHA256_DIGEST_LENGTH>;
+using TreeHashSpan = Span<uint8_t, SHA256_DIGEST_LENGTH>;
+using TreeHashConstSpan = Span<const uint8_t, SHA256_DIGEST_LENGTH>;
+
+// Performs the procedure defined in section 4.4.3 of
+// draft-davidben-tls-merkle-tree-certs-08, Verifying a Subtree Consistency
+// Proof:
+//
+//   Given a Merkle Tree over `n` elements, a subtree defined by `[start, end)`,
+//   a consistency proof `proof`, a subtree hash `node_hash`, and a root hash
+//   `root_hash`
+//
+// The one difference between this function and the routine described in
+// draft-davidben-tls-merkle-tree-certs-08 is that instead of taking `root_hash`
+// as an input, this function returns the computed root hash and it is the
+// caller's responsibility to verify that the computed root hash matches the
+// expected root hash. This function returns std::nullopt if other steps of
+// proof verification failed.
+OPENSSL_EXPORT std::optional<TreeHash> EvaluateMerkleSubtreeConsistencyProof(
+    uint64_t n, const Subtree &subtree, Span<const uint8_t> proof,
+    TreeHashConstSpan node_hash);
+
+// Helper function to compute the hash value of an interior node in a Merkle
+// tree, i.e. HASH(0x01 || left || right). 32 bytes of output are written to
+// |out|. This function is intended for internal use only and only exists here
+// for the convenience of merkle_tree_unittest.cc.
+OPENSSL_EXPORT void HashNode(TreeHashConstSpan left, TreeHashConstSpan right,
+                             TreeHashSpan out);
+
+BSSL_NAMESPACE_END
+
+#endif  // BSSL_PKI_MERKLE_TREE_H_
diff --git a/pki/merkle_tree_unittest.cc b/pki/merkle_tree_unittest.cc
new file mode 100644
index 0000000..2447598
--- /dev/null
+++ b/pki/merkle_tree_unittest.cc
@@ -0,0 +1,451 @@
+// Copyright 2025 The BoringSSL Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "merkle_tree.h"
+
+#include <cassert>
+#include <cstdint>
+#include <limits>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <openssl/sha2.h>
+
+BSSL_NAMESPACE_BEGIN
+
+namespace {
+
+class MerkleTree {
+ public:
+  class Data {
+   public:
+    virtual ~Data() = default;
+    virtual std::vector<uint8_t> At(uint64_t index) = 0;
+    // Caching functions are only called for full subtrees (size is a power of
+    // 2) that have size at least 2.
+    virtual std::optional<TreeHash> NodeHash(Subtree node) = 0;
+    virtual void CacheNodeHash(Subtree node, TreeHash hash) = 0;
+  };
+
+  explicit MerkleTree(Data *data) : data_(data) {}
+  MerkleTree(const MerkleTree &) = delete;
+  MerkleTree(MerkleTree &&) = default;
+  MerkleTree &operator=(const MerkleTree &) = delete;
+  MerkleTree &operator=(MerkleTree &&) = default;
+
+  // MTH computes the Merkle Tree Hash (MTH; RFC 9162, section 2.1.1) of a
+  // tree containing `end - start` elements from D_n, starting at d[start]. If
+  // start > end or there is an internal error, this function returns nullopt.
+  //
+  // Note that the MTH function defined in RFC 9162 takes an ordered list of
+  // inputs D_n. This function takes start and end indicies to identify the
+  // inputs.
+  std::optional<TreeHash> MTH(const Subtree &subtree) {
+    if (!subtree.IsValid()) {
+      return std::nullopt;
+    }
+    SHA256_CTX ctx;
+    SHA256_Init(&ctx);
+    TreeHash out;
+    uint64_t n = subtree.Size();
+    if (n == 0) {
+      // The hash of an empty list is the hash of an empty string.
+      SHA256_Final(out.data(), &ctx);
+      return out;
+    }
+    if (n == 1) {
+      // One element in the list: return a leaf hash.
+      static const uint8_t header = 0x00;
+      auto leaf = data_->At(subtree.start);
+      SHA256_Update(&ctx, &header, 1);
+      SHA256_Update(&ctx, leaf.data(), leaf.size());
+      SHA256_Final(out.data(), &ctx);
+      return out;
+    }
+    // Only use the cache for subtrees with a size that is a power of 2.
+    uint64_t s = subtree.end - subtree.start;
+    bool use_cache = (s & (s - 1)) == 0;
+    if (use_cache) {
+      if (auto hash_opt = data_->NodeHash(subtree); hash_opt.has_value()) {
+        return *hash_opt;
+      }
+    }
+    // n elements in the list: MTH() is defined recursively.
+    auto left = MTH(subtree.Left());
+    auto right = MTH(subtree.Right());
+    if (!left.has_value() || !right.has_value()) {
+      return std::nullopt;
+    }
+    HashNode(*left, *right, out);
+    if (use_cache) {
+      data_->CacheNodeHash(subtree, out);
+    }
+    return out;
+  }
+
+  // Computes an inclusion proof to the element at index in subtree from start
+  // to end.
+  std::optional<std::vector<TreeHash>> InclusionProof(uint64_t index,
+                                                      const Subtree &subtree) {
+    return SubtreeSubproof({index, index + 1}, subtree, true);
+  }
+
+  // Computes a consistency proof that |subtree| is contained in |tree|.
+  std::optional<std::vector<TreeHash>> ConsistencyProof(const Subtree &subtree,
+                                                        const Subtree &tree) {
+    return SubtreeSubproof(subtree, tree, true);
+  }
+
+ private:
+  // Computes a SUBTREE_SUBPROOF from subtree to tree, where subtree is
+  // contained within tree.
+  std::optional<std::vector<TreeHash>> SubtreeSubproof(Subtree subtree,
+                                                       const Subtree &tree,
+                                                       bool known_hash) {
+    if (!subtree.IsValid() || !tree.IsValid() || !tree.Contains(subtree)) {
+      // Invalid inputs
+      return std::nullopt;
+    }
+    uint64_t n = tree.Size();
+    if (n == 0) {
+      // There must be a tree with contents for there to be a proof that
+      // something is in said tree.
+      return std::nullopt;
+    }
+    if (subtree == tree) {
+      if (known_hash) {
+        return std::vector<TreeHash>();
+      }
+      auto mth = MTH(tree);
+      if (!mth.has_value()) {
+        return std::nullopt;
+      }
+      return {{*mth}};
+    }
+
+    uint64_t k = tree.Split();
+    Subtree subproof_tree, mth_tree;
+    if (subtree.end <= k) {
+      subproof_tree = tree.Left();
+      mth_tree = tree.Right();
+    } else if (subtree.start >= k) {
+      mth_tree = tree.Left();
+      subproof_tree = tree.Right();
+    } else {
+      subtree.start = k;
+      mth_tree = tree.Left();
+      subproof_tree = tree.Right();
+      known_hash = false;
+    }
+    auto subproof = SubtreeSubproof(subtree, subproof_tree, known_hash);
+    auto mth = MTH(mth_tree);
+    if (!subproof.has_value() || !mth.has_value()) {
+      return std::nullopt;
+    }
+    subproof->push_back(*mth);
+    return subproof;
+  }
+
+  Data *data_;
+};
+
+size_t countr_zero(uint64_t n) {
+  if (n == 0) {
+    return 8 * sizeof(n);
+  }
+  size_t count = 0;
+  while ((n & 1) == 0) {
+    n >>= 1;
+    count++;
+  }
+  return count;
+}
+
+class ConcatData : public MerkleTree::Data {
+ public:
+  explicit ConcatData(Span<const uint8_t> label)
+      : label_(label.begin(), label.end()) {}
+
+  std::vector<uint8_t> At(uint64_t index) override {
+    std::vector<uint8_t> out(label_.size() + sizeof(index));
+    memcpy(out.data(), label_.data(), label_.size());
+    memcpy(out.data() + label_.size(), &index, sizeof(index));
+    return out;
+  }
+
+  std::optional<TreeHash> NodeHash(Subtree node) override {
+    size_t level_index = countr_zero(node.Size()) - 1;
+    if (node_cache_.size() <= level_index) {
+      return std::nullopt;
+    }
+    size_t position_index = node.start / node.Size();
+    if (node_cache_[level_index].size() <= position_index) {
+      return std::nullopt;
+    }
+    return node_cache_[level_index][position_index];
+  }
+
+  void CacheNodeHash(Subtree node, TreeHash hash) override {
+    size_t level_index = countr_zero(node.Size()) - 1;
+    size_t position_index = node.start / node.Size();
+    if (level_index >= node_cache_.size()) {
+      BSSL_CHECK(level_index == node_cache_.size());
+      node_cache_.push_back({});
+    }
+    if (position_index < node_cache_[level_index].size()) {
+      node_cache_[level_index][position_index] = hash;
+      return;
+    }
+    node_cache_[level_index].resize(position_index);
+    node_cache_[level_index].push_back(hash);
+  }
+
+ private:
+  std::vector<uint8_t> label_;
+  std::vector<std::vector<std::optional<TreeHash>>> node_cache_;
+};
+
+
+TEST(MerkleTreeTest, SubtreeIsValid) {
+  // An empty subtree is valid.
+  EXPECT_TRUE((Subtree{0, 0}.IsValid()));
+  // But if the end is before start, it's invalid.
+  EXPECT_FALSE((Subtree{1, 0}.IsValid()));
+  // A subtree of the maximum expressible size is valid.
+  EXPECT_TRUE((Subtree{0, std::numeric_limits<uint64_t>::max()}.IsValid()));
+
+  // Subtrees don't have to start at 0.
+  EXPECT_TRUE((Subtree{4, 8}.IsValid()));
+  // But if they don't start at 0, there's a limit to how big they can be.
+  EXPECT_FALSE((Subtree{4, 9}.IsValid()));
+  // Subtrees can have a ragged right edge.
+  EXPECT_TRUE((Subtree{4, 6}.IsValid()));
+  EXPECT_TRUE((Subtree{0, 6}.IsValid()));
+}
+
+TEST(MerkleTreeTest, SubtreeSplit) {
+  // Empty subtree.
+  EXPECT_EQ((Subtree{24601, 24601}).Split(), 24601ul);
+  // Single-item subtree.
+  EXPECT_EQ((Subtree{1336, 1337}).Split(), 1337ul);
+  // Two items in subtree.
+  EXPECT_EQ((Subtree{42, 44}).Split(), 43ul);
+  // Subtree size is 1 less than a power of 2.
+  EXPECT_EQ((Subtree{0, 31}).Split(), 16ul);
+  // Subtree size is a power of 2.
+  EXPECT_EQ((Subtree{64, 128}).Split(), 96ul);
+  /// Subtree size is 1 more than a power of 2.
+  EXPECT_EQ((Subtree{0, 257}).Split(), 256ul);
+
+  static const uint64_t u64_max = std::numeric_limits<uint64_t>::max();
+  // Maximum size tree.
+  EXPECT_EQ((Subtree{0, u64_max}).Split(), 1ull << 63);
+  // Small tree, with end at maximum value.
+  EXPECT_EQ((Subtree{u64_max - 3, u64_max}).Split(), u64_max - 1);
+}
+
+std::vector<uint8_t> ConcatProof(const std::vector<TreeHash> &proof) {
+  std::vector<uint8_t> out;
+  for (const auto &p : proof) {
+    out.insert(out.end(), p.begin(), p.end());
+  }
+  return out;
+}
+
+TEST(MerkleTreeTest, VerifySubtreeConsistencyProof) {
+  ConcatData tree_data(StringAsBytes("label"));
+  MerkleTree tree(&tree_data);
+
+  uint64_t index = 0;
+  auto node_hash = tree.MTH({index, index + 1});
+  Subtree subtree{0, 16};
+  auto proof = tree.InclusionProof(index, subtree);
+  ASSERT_TRUE(proof.has_value());
+  auto root_hash = EvaluateMerkleSubtreeConsistencyProof(
+      subtree.end, {index, index + 1}, ConcatProof(*proof), *node_hash);
+  ASSERT_TRUE(root_hash.has_value());
+  EXPECT_EQ(root_hash, tree.MTH(subtree));
+}
+
+// Test that the computed consistency proofs match the examples given in RFC
+// 9162 section 2.1.5.
+TEST(MerkleTreeTest, SubtreeConsistencyProofRFC9162) {
+  ConcatData tree_data(StringAsBytes("label"));
+  MerkleTree tree(&tree_data);
+
+  // The example from section 2.1.5 has a final tree with 7 leaves.
+  Subtree final_tree{0, 7};
+
+  // The examples refer to letters representing the MTH of various subtrees
+  // within that tree.
+  // a isn't used in any of the examples.
+  auto b = tree.MTH({1, 2});
+  auto c = tree.MTH({2, 3});
+  auto d = tree.MTH({3, 4});
+  // e isn't used in any of the examples.
+  auto f = tree.MTH({5, 6});
+  auto g = tree.MTH({0, 2});
+  auto h = tree.MTH({2, 4});
+  auto i = tree.MTH({4, 6});
+  auto j = tree.MTH({6, 7});
+  auto k = tree.MTH({0, 4});
+  auto l = tree.MTH({4, 7});
+
+  // Inclusion proofs:
+
+  // Section 2.1.5: "The inclusion proof for `d0` is `[b, h, l]`."
+  auto d0_proof = tree.InclusionProof(0, final_tree);
+  EXPECT_THAT(*d0_proof, testing::ElementsAre(*b, *h, *l));
+
+  // Section 2.1.5: "The inclusion proof for `d3` is `[c, g, l]`."
+  auto d3_proof = tree.InclusionProof(3, final_tree);
+  EXPECT_THAT(*d3_proof, testing::ElementsAre(*c, *g, *l));
+
+  // Section 2.1.5: "The inclusion proof for `d4` is `[f, j, k]`."
+  auto d4_proof = tree.InclusionProof(4, final_tree);
+  EXPECT_THAT(*d4_proof, testing::ElementsAre(*f, *j, *k));
+
+  // Section 2.1.5: "The inclusion proof for `d6` is `[i, k]`."
+  auto d6_proof = tree.InclusionProof(6, final_tree);
+  EXPECT_THAT(*d6_proof, testing::ElementsAre(*i, *k));
+
+  // Consistency proofs:
+
+  // The consistency proofs refer to the lettered MTHs above, as well as some
+  // MTHs representing the tree as it was incrementally built.
+  Subtree hash0_subtree = {0, 3};
+  Subtree hash1_subtree = {0, 4};
+  auto hash1 = tree.MTH(hash1_subtree);
+  ASSERT_EQ(hash1, k);
+  Subtree hash2_subtree = {0, 6};
+
+  // "The consistency proof between hash0 and hash is [c, d, g, l]."
+  auto hash0_proof = tree.ConsistencyProof(hash0_subtree, final_tree);
+  EXPECT_THAT(*hash0_proof, testing::ElementsAre(*c, *d, *g, *l));
+
+  // "The consistency proof beween hash1 and hash is [l]."
+  auto hash1_proof = tree.ConsistencyProof(hash1_subtree, final_tree);
+  EXPECT_THAT(*hash1_proof, testing::ElementsAre(*l));
+
+  // "The consistency proof between hash2 and hash is [i, j, k]."
+  auto hash2_proof = tree.ConsistencyProof(hash2_subtree, final_tree);
+  EXPECT_THAT(*hash2_proof, testing::ElementsAre(*i, *j, *k));
+}
+
+TEST(MerkleTreeTest, ValidProofsTest) {
+  ConcatData tree_data(StringAsBytes("label"));
+  MerkleTree tree(&tree_data);
+
+  uint64_t n = 4, start = 0, end = 3;
+  Subtree full_tree{0, n};
+  auto tree_hash = tree.MTH(full_tree);
+  Subtree subtree{start, end};
+  ASSERT_TRUE(subtree.IsValid());
+  auto subtree_hash = tree.MTH(subtree);
+
+  auto proof = tree.ConsistencyProof(subtree, full_tree);
+  ASSERT_TRUE(proof.has_value());
+  auto computed_hash = EvaluateMerkleSubtreeConsistencyProof(
+      n, subtree, ConcatProof(*proof), *subtree_hash);
+  EXPECT_EQ(computed_hash, tree_hash);
+}
+
+TEST(MerkleTreeTest, ValidProofs) {
+  ConcatData tree_data(StringAsBytes("label"));
+  MerkleTree tree(&tree_data);
+
+  // As of the time of writing this test, a run was performed with limit=257 and
+  // the test passed (but it took 1.7 seconds to run). This value is set to 129
+  // to balance how much of the space to explore with test execution time.
+  uint64_t limit = 129;
+  for (uint64_t n = 0; n < limit; n++) {
+    Subtree full_tree{0, n};
+    auto tree_hash = tree.MTH(full_tree);
+    for (uint64_t end = 0; end <= n; end++) {
+      for (uint64_t start = 0; start < end; start++) {
+        Subtree subtree{start, end};
+        if (!subtree.IsValid()) {
+          continue;
+        }
+        SCOPED_TRACE(testing::Message() << "Tree n=" << n << ", start: "
+                                        << start << ", end: " << end);
+        auto subtree_hash = tree.MTH(subtree);
+
+        auto proof = tree.ConsistencyProof(subtree, full_tree);
+        ASSERT_TRUE(proof.has_value());
+        auto computed_hash = EvaluateMerkleSubtreeConsistencyProof(
+            n, subtree, ConcatProof(*proof), *subtree_hash);
+        EXPECT_EQ(computed_hash, tree_hash);
+      }
+    }
+  }
+}
+
+class ConstData : public MerkleTree::Data {
+ public:
+  // A tree that uses ConstData has the same data at every index in the tree.
+  std::vector<uint8_t> At(uint64_t index) override { return {}; }
+
+  std::optional<TreeHash> NodeHash(Subtree node) override {
+    size_t index = countr_zero(node.Size()) - 1;
+    if (node_cache_.size() <= index) {
+      return std::nullopt;
+    }
+    return node_cache_[index];
+  }
+
+  void CacheNodeHash(Subtree node, TreeHash hash) override {
+    size_t index = countr_zero(node.Size()) - 1;
+    // Cache entries are only inserted if not present, and always inserted from
+    // lowest level of the tree to highest.
+    BSSL_CHECK(index == node_cache_.size());
+    node_cache_.push_back(hash);
+  }
+
+ private:
+  std::vector<TreeHash> node_cache_;
+};
+
+TEST(MerkleTreeTest, VeryLargeProofs) {
+  ConstData tree_data;
+  MerkleTree tree(&tree_data);
+
+  Subtree fullest_tree = {0, std::numeric_limits<uint64_t>::max()};
+  auto root_hash = tree.MTH(fullest_tree);
+  ASSERT_TRUE(root_hash.has_value());
+
+  Subtree test_subtrees[] = {
+      fullest_tree,
+      {0, 1},
+      {0, 1ull << 63},
+      {1ull << 63, std::numeric_limits<uint64_t>::max()},
+      {std::numeric_limits<uint64_t>::max() - 1,
+       std::numeric_limits<uint64_t>::max()},
+  };
+  for (auto subtree : test_subtrees) {
+    SCOPED_TRACE(testing::Message() << "Subtree start: " << subtree.start
+                                    << ", end: " << subtree.end);
+
+    auto proof = tree.ConsistencyProof(subtree, fullest_tree);
+    ASSERT_TRUE(proof.has_value());
+    auto computed_root_hash = EvaluateMerkleSubtreeConsistencyProof(
+        fullest_tree.end, subtree, ConcatProof(*proof), *tree.MTH(subtree));
+    EXPECT_EQ(computed_root_hash, root_hash);
+  }
+}
+
+}  // namespace
+
+BSSL_NAMESPACE_END