Unify the HPKE implementation for ML-KEM.

ML-KEM-768 and ML-KEM-1024 HPKE implementations are roughly copy-pastes
with "768" changed to "1024". C++ templates are able to do that sort of
thing and remove the code duplication.

Change-Id: I6133d1fe5017539c479c60482a465ef94a036ff4
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/82487
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: Adam Langley <agl@google.com>
diff --git a/crypto/hpke/hpke.cc b/crypto/hpke/hpke.cc
index d49882a..38fba75 100644
--- a/crypto/hpke/hpke.cc
+++ b/crypto/hpke/hpke.cc
@@ -720,239 +720,161 @@
   return &kKEM;
 }
 
-#define MLKEM768_PRIVATE_KEY_LEN MLKEM_SEED_BYTES
-#define MLKEM768_PUBLIC_KEY_LEN MLKEM768_PUBLIC_KEY_BYTES
-#define MLKEM768_PUBLIC_VALUE_LEN MLKEM768_CIPHERTEXT_BYTES
-#define MLKEM768_SEED_LEN BCM_MLKEM_ENCAP_ENTROPY
-#define MLKEM768_SHARED_KEY_LEN MLKEM_SHARED_SECRET_BYTES
+namespace {
 
-static int mlkem768_init_key(EVP_HPKE_KEY *key, const uint8_t *priv_key,
-                             size_t priv_key_len) {
-  MLKEM768_private_key expanded_private_key;
-  if (!MLKEM768_private_key_from_seed(&expanded_private_key, priv_key,
-                                      priv_key_len)) {
-    OPENSSL_PUT_ERROR(EVP, EVP_R_DECODE_ERROR);
-    return 0;
-  }
-  MLKEM768_public_key public_key;
-  MLKEM768_public_from_private(&public_key, &expanded_private_key);
-  CBB cbb;
-  static_assert(sizeof(key->public_key) >= MLKEM768_PUBLIC_KEY_LEN,
-                "EVP_HPKE_KEY public_key is too small for ML-KEM-768.");
-  if (!CBB_init_fixed(&cbb, key->public_key, MLKEM768_PUBLIC_KEY_LEN) ||
-      !MLKEM768_marshal_public_key(&cbb, &public_key)) {
-    return 0;
+template <uint16_t KEM_ID, size_t PUBLIC_KEY_BYTES, size_t CIPHERTEXT_BYTES,
+          size_t ENCAP_ENTROPY_BYTES,
+
+          typename PrivateKey, typename PublicKey, typename BCMPublicKey,
+
+          int (*PrivateKeyFromSeed)(PrivateKey *, const uint8_t *, size_t),
+          void (*PublicFromPrivate)(PublicKey *, const PrivateKey *),
+          int (*MarshalPublicKey)(CBB *, const PublicKey *),
+          void (*GenerateKey)(uint8_t *, uint8_t *, PrivateKey *),
+          bcm_status (*BCMParsePublicKey)(BCMPublicKey *, CBS *),
+          bcm_infallible (*BCMEncapExternalEntropy)(
+              uint8_t *, uint8_t *, const BCMPublicKey *, const uint8_t *),
+          int (*Decap)(uint8_t *, const uint8_t *, size_t, const PrivateKey *)>
+struct MLKEMHPKE {
+  // These sizes are common across both ML-KEM-768 and ML-KEM-1024.
+  static constexpr size_t PRIVATE_KEY_LEN = MLKEM_SEED_BYTES;
+  static constexpr size_t SHARED_KEY_LEN = MLKEM_SHARED_SECRET_BYTES;
+
+  static constexpr uint16_t ID = KEM_ID;
+  static constexpr size_t PUBLIC_KEY_LEN = PUBLIC_KEY_BYTES;
+  static constexpr size_t SEED_LEN = ENCAP_ENTROPY_BYTES;
+  static constexpr size_t ENC_LEN = CIPHERTEXT_BYTES;
+
+  static int InitKey(EVP_HPKE_KEY *key, const uint8_t *priv_key,
+                     size_t priv_key_len) {
+    PrivateKey expanded_private_key;
+    if (!PrivateKeyFromSeed(&expanded_private_key, priv_key, priv_key_len)) {
+      OPENSSL_PUT_ERROR(EVP, EVP_R_DECODE_ERROR);
+      return 0;
+    }
+    PublicKey public_key;
+    PublicFromPrivate(&public_key, &expanded_private_key);
+    CBB cbb;
+    static_assert(sizeof(key->public_key) >= PUBLIC_KEY_BYTES,
+                  "EVP_HPKE_KEY public_key is too small for ML-KEM.");
+    if (!CBB_init_fixed(&cbb, key->public_key, PUBLIC_KEY_BYTES) ||
+        !MarshalPublicKey(&cbb, &public_key)) {
+      return 0;
+    }
+
+    static_assert(sizeof(key->private_key) >= PRIVATE_KEY_LEN,
+                  "EVP_HPKE_KEY private_key is too small for ML-KEM");
+    OPENSSL_memcpy(key->private_key, priv_key, priv_key_len);
+    return 1;
   }
 
-  static_assert(sizeof(key->private_key) >= MLKEM768_PRIVATE_KEY_LEN,
-                "EVP_HPKE_KEY private_key is too small for ML-KEM-768");
-  OPENSSL_memcpy(key->private_key, priv_key, priv_key_len);
-  return 1;
-}
+  static int HpkeGenerateKey(EVP_HPKE_KEY *key) {
+    static_assert(sizeof(key->public_key) >= PUBLIC_KEY_BYTES,
+                  "EVP_HPKE_KEY public_key is too small for ML-KEM.");
+    static_assert(sizeof(key->private_key) >= PRIVATE_KEY_LEN,
+                  "EVP_HPKE_KEY private_key is too small for ML-KEM");
+    PrivateKey expanded_private_key;
+    GenerateKey(key->public_key, key->private_key, &expanded_private_key);
 
-static int mlkem768_generate_key(EVP_HPKE_KEY *key) {
-  static_assert(sizeof(key->public_key) >= MLKEM768_PUBLIC_KEY_LEN,
-                "EVP_HPKE_KEY public_key is too small for ML-KEM-768.");
-  static_assert(sizeof(key->private_key) >= MLKEM768_PRIVATE_KEY_LEN,
-                "EVP_HPKE_KEY private_key is too small for ML-KEM-768");
-  MLKEM768_private_key expanded_private_key;
-  MLKEM768_generate_key(key->public_key, key->private_key,
-                        &expanded_private_key);
-
-  return 1;
-}
-
-static int mlkem768_encap_with_seed(
-    const EVP_HPKE_KEM *kem, uint8_t *out_shared_secret,
-    size_t *out_shared_secret_len, uint8_t *out_enc, size_t *out_enc_len,
-    size_t max_enc, const uint8_t *peer_public_key, size_t peer_public_key_len,
-    const uint8_t *seed, size_t seed_len) {
-  if (max_enc < MLKEM768_PUBLIC_VALUE_LEN) {
-    OPENSSL_PUT_ERROR(EVP, EVP_R_INVALID_BUFFER_SIZE);
-    return 0;
-  }
-  if (peer_public_key_len != MLKEM768_PUBLIC_KEY_LEN ||
-      seed_len != MLKEM768_SEED_LEN) {
-    OPENSSL_PUT_ERROR(EVP, EVP_R_DECODE_ERROR);
-    return 0;
+    return 1;
   }
 
-  CBS cbs;
-  CBS_init(&cbs, peer_public_key, peer_public_key_len);
-  BCM_mlkem768_public_key public_key;
-  if (!bcm_success(BCM_mlkem768_parse_public_key(&public_key, &cbs))) {
-    OPENSSL_PUT_ERROR(EVP, EVP_R_DECODE_ERROR);
-    return 0;
-  }
-  // The public ML-KEM interface doesn't support providing the encap entropy so
-  // the BCM function is used here.
-  BCM_mlkem768_encap_external_entropy(out_enc, out_shared_secret, &public_key,
-                                      seed);
+  static int EncapWithSeed(const EVP_HPKE_KEM *kem, uint8_t *out_shared_secret,
+                           size_t *out_shared_secret_len, uint8_t *out_enc,
+                           size_t *out_enc_len, size_t max_enc,
+                           const uint8_t *peer_public_key,
+                           size_t peer_public_key_len, const uint8_t *seed,
+                           size_t seed_len) {
+    if (max_enc < CIPHERTEXT_BYTES) {
+      OPENSSL_PUT_ERROR(EVP, EVP_R_INVALID_BUFFER_SIZE);
+      return 0;
+    }
+    if (peer_public_key_len != PUBLIC_KEY_BYTES ||
+        seed_len != ENCAP_ENTROPY_BYTES) {
+      OPENSSL_PUT_ERROR(EVP, EVP_R_DECODE_ERROR);
+      return 0;
+    }
 
-  *out_enc_len = MLKEM768_PUBLIC_VALUE_LEN;
-  *out_shared_secret_len = MLKEM768_SHARED_KEY_LEN;
-  return 1;
-}
+    CBS cbs;
+    CBS_init(&cbs, peer_public_key, peer_public_key_len);
+    BCMPublicKey public_key;
+    if (!bcm_success(BCMParsePublicKey(&public_key, &cbs))) {
+      OPENSSL_PUT_ERROR(EVP, EVP_R_DECODE_ERROR);
+      return 0;
+    }
+    // The public ML-KEM interface doesn't support providing the encap entropy
+    // so the BCM function is used here.
+    BCMEncapExternalEntropy(out_enc, out_shared_secret, &public_key, seed);
 
-static int mlkem768_decap(const EVP_HPKE_KEY *key, uint8_t *out_shared_secret,
-                          size_t *out_shared_secret_len, const uint8_t *enc,
-                          size_t enc_len) {
-  if (enc_len != MLKEM768_PUBLIC_VALUE_LEN) {
-    OPENSSL_PUT_ERROR(EVP, EVP_R_DECODE_ERROR);
-    return 0;
+    *out_enc_len = CIPHERTEXT_BYTES;
+    *out_shared_secret_len = SHARED_KEY_LEN;
+    return 1;
   }
 
-  MLKEM768_private_key private_key;
-  if (!MLKEM768_private_key_from_seed(&private_key, key->private_key,
-                                      MLKEM768_PRIVATE_KEY_LEN)) {
-    OPENSSL_PUT_ERROR(EVP, EVP_R_DECODE_ERROR);
-    return 0;
+  static int HpkeDecap(const EVP_HPKE_KEY *key, uint8_t *out_shared_secret,
+                       size_t *out_shared_secret_len, const uint8_t *enc,
+                       size_t enc_len) {
+    PrivateKey private_key;
+    if (!PrivateKeyFromSeed(&private_key, key->private_key, PRIVATE_KEY_LEN)) {
+      OPENSSL_PUT_ERROR(EVP, EVP_R_DECODE_ERROR);
+      return 0;
+    }
+
+    if (!Decap(out_shared_secret, enc, enc_len, &private_key)) {
+      OPENSSL_PUT_ERROR(EVP, ERR_R_INTERNAL_ERROR);
+      return 0;
+    }
+
+    *out_shared_secret_len = SHARED_KEY_LEN;
+    return 1;
   }
+};
 
-  if (!MLKEM768_decap(out_shared_secret, enc, enc_len, &private_key)) {
-    OPENSSL_PUT_ERROR(EVP, ERR_R_INTERNAL_ERROR);
-    return 0;
-  }
+using MLKEM768HPKE =
+    MLKEMHPKE<EVP_HPKE_MLKEM768, MLKEM768_PUBLIC_KEY_BYTES,
+              MLKEM768_CIPHERTEXT_BYTES, BCM_MLKEM_ENCAP_ENTROPY,
 
-  *out_shared_secret_len = MLKEM768_SHARED_KEY_LEN;
-  return 1;
-}
+              MLKEM768_private_key, MLKEM768_public_key,
+              BCM_mlkem768_public_key,
 
-const EVP_HPKE_KEM *EVP_hpke_mlkem768(void) {
-  static const EVP_HPKE_KEM kKEM = {
-      /*id=*/EVP_HPKE_MLKEM768,
-      /*public_key_len=*/MLKEM768_PUBLIC_KEY_LEN,
-      /*private_key_len=*/MLKEM768_PRIVATE_KEY_LEN,
-      /*seed_len=*/MLKEM768_SEED_LEN,
-      /*enc_len=*/MLKEM768_PUBLIC_VALUE_LEN,
-      mlkem768_init_key,
-      mlkem768_generate_key,
-      mlkem768_encap_with_seed,
-      mlkem768_decap,
-      // MLKEM768 doesn't support authenticated encapsulation/decapsulation:
-      // https://datatracker.ietf.org/doc/draft-ietf-hpke-pq/01/
-      /* auth_encap_with_seed= */ nullptr,
-      /* auth_decap= */ nullptr,
-  };
-  return &kKEM;
-}
+              MLKEM768_private_key_from_seed, MLKEM768_public_from_private,
+              MLKEM768_marshal_public_key, MLKEM768_generate_key,
+              BCM_mlkem768_parse_public_key,
+              BCM_mlkem768_encap_external_entropy, MLKEM768_decap>;
 
-#define MLKEM1024_PRIVATE_KEY_LEN MLKEM_SEED_BYTES
-#define MLKEM1024_PUBLIC_KEY_LEN MLKEM1024_PUBLIC_KEY_BYTES
-#define MLKEM1024_PUBLIC_VALUE_LEN MLKEM1024_CIPHERTEXT_BYTES
-#define MLKEM1024_SEED_LEN BCM_MLKEM_ENCAP_ENTROPY
-#define MLKEM1024_SHARED_KEY_LEN MLKEM_SHARED_SECRET_BYTES
+using MLKEM1024HPKE =
+    MLKEMHPKE<EVP_HPKE_MLKEM1024, MLKEM1024_PUBLIC_KEY_BYTES,
+              MLKEM1024_CIPHERTEXT_BYTES, BCM_MLKEM_ENCAP_ENTROPY,
 
-static int mlkem1024_init_key(EVP_HPKE_KEY *key, const uint8_t *priv_key,
-                              size_t priv_key_len) {
-  MLKEM1024_private_key expanded_private_key;
-  if (!MLKEM1024_private_key_from_seed(&expanded_private_key, priv_key,
-                                       priv_key_len)) {
-    OPENSSL_PUT_ERROR(EVP, EVP_R_DECODE_ERROR);
-    return 0;
-  }
-  MLKEM1024_public_key public_key;
-  MLKEM1024_public_from_private(&public_key, &expanded_private_key);
-  CBB cbb;
-  static_assert(sizeof(key->public_key) >= MLKEM1024_PUBLIC_KEY_LEN,
-                "EVP_HPKE_KEY public_key is too small for ML-KEM-1024.");
-  if (!CBB_init_fixed(&cbb, key->public_key, MLKEM1024_PUBLIC_KEY_LEN) ||
-      !MLKEM1024_marshal_public_key(&cbb, &public_key)) {
-    return 0;
-  }
+              MLKEM1024_private_key, MLKEM1024_public_key,
+              BCM_mlkem1024_public_key,
 
-  static_assert(sizeof(key->private_key) >= MLKEM1024_PRIVATE_KEY_LEN,
-                "EVP_HPKE_KEY private_key is too small for ML-KEM-1024");
-  OPENSSL_memcpy(key->private_key, priv_key, priv_key_len);
-  return 1;
-}
+              MLKEM1024_private_key_from_seed, MLKEM1024_public_from_private,
+              MLKEM1024_marshal_public_key, MLKEM1024_generate_key,
+              BCM_mlkem1024_parse_public_key,
+              BCM_mlkem1024_encap_external_entropy, MLKEM1024_decap>;
 
-static int mlkem1024_generate_key(EVP_HPKE_KEY *key) {
-  static_assert(sizeof(key->public_key) >= MLKEM1024_PUBLIC_KEY_LEN,
-                "EVP_HPKE_KEY public_key is too small for ML-KEM-1024.");
-  static_assert(sizeof(key->private_key) >= MLKEM1024_PRIVATE_KEY_LEN,
-                "EVP_HPKE_KEY private_key is too small for ML-KEM-1024");
-  MLKEM1024_private_key expanded_private_key;
-  MLKEM1024_generate_key(key->public_key, key->private_key,
-                         &expanded_private_key);
+template <typename MLKEM>
+static const EVP_HPKE_KEM kMLKEM = {
+    /*id=*/MLKEM::ID,
+    /*public_key_len=*/MLKEM::PUBLIC_KEY_LEN,
+    /*private_key_len=*/MLKEM::PRIVATE_KEY_LEN,
+    /*seed_len=*/MLKEM::SEED_LEN,
+    /*enc_len=*/MLKEM::ENC_LEN,
+    MLKEM::InitKey,
+    MLKEM::HpkeGenerateKey,
+    MLKEM::EncapWithSeed,
+    MLKEM::HpkeDecap,
+    // ML-KEM doesn't support authenticated encapsulation/decapsulation:
+    // https://datatracker.ietf.org/doc/draft-ietf-hpke-pq/01/
+    /*auth_encap_with_seed=*/nullptr,
+    /*auth_decap=*/nullptr,
+};
 
-  return 1;
-}
+}  // namespace
 
-static int mlkem1024_encap_with_seed(
-    const EVP_HPKE_KEM *kem, uint8_t *out_shared_secret,
-    size_t *out_shared_secret_len, uint8_t *out_enc, size_t *out_enc_len,
-    size_t max_enc, const uint8_t *peer_public_key, size_t peer_public_key_len,
-    const uint8_t *seed, size_t seed_len) {
-  if (max_enc < MLKEM1024_PUBLIC_VALUE_LEN) {
-    OPENSSL_PUT_ERROR(EVP, EVP_R_INVALID_BUFFER_SIZE);
-    return 0;
-  }
-  if (peer_public_key_len != MLKEM1024_PUBLIC_KEY_LEN ||
-      seed_len != MLKEM1024_SEED_LEN) {
-    OPENSSL_PUT_ERROR(EVP, EVP_R_DECODE_ERROR);
-    return 0;
-  }
-
-  CBS cbs;
-  CBS_init(&cbs, peer_public_key, peer_public_key_len);
-  BCM_mlkem1024_public_key public_key;
-  if (!bcm_success(BCM_mlkem1024_parse_public_key(&public_key, &cbs))) {
-    OPENSSL_PUT_ERROR(EVP, EVP_R_DECODE_ERROR);
-    return 0;
-  }
-  // The public ML-KEM interface doesn't support providing the encap entropy so
-  // the BCM function is used here.
-  BCM_mlkem1024_encap_external_entropy(out_enc, out_shared_secret, &public_key,
-                                       seed);
-
-  *out_enc_len = MLKEM1024_PUBLIC_VALUE_LEN;
-  *out_shared_secret_len = MLKEM1024_SHARED_KEY_LEN;
-  return 1;
-}
-
-static int mlkem1024_decap(const EVP_HPKE_KEY *key, uint8_t *out_shared_secret,
-                           size_t *out_shared_secret_len, const uint8_t *enc,
-                           size_t enc_len) {
-  if (enc_len != MLKEM1024_PUBLIC_VALUE_LEN) {
-    OPENSSL_PUT_ERROR(EVP, EVP_R_DECODE_ERROR);
-    return 0;
-  }
-
-  MLKEM1024_private_key private_key;
-  if (!MLKEM1024_private_key_from_seed(&private_key, key->private_key,
-                                       MLKEM1024_PRIVATE_KEY_LEN)) {
-    OPENSSL_PUT_ERROR(EVP, EVP_R_DECODE_ERROR);
-    return 0;
-  }
-
-  if (!MLKEM1024_decap(out_shared_secret, enc, enc_len, &private_key)) {
-    OPENSSL_PUT_ERROR(EVP, ERR_R_INTERNAL_ERROR);
-    return 0;
-  }
-
-  *out_shared_secret_len = MLKEM1024_SHARED_KEY_LEN;
-  return 1;
-}
-
-const EVP_HPKE_KEM *EVP_hpke_mlkem1024(void) {
-  static const EVP_HPKE_KEM kKEM = {
-      /*id=*/EVP_HPKE_MLKEM1024,
-      /*public_key_len=*/MLKEM1024_PUBLIC_KEY_LEN,
-      /*private_key_len=*/MLKEM1024_PRIVATE_KEY_LEN,
-      /*seed_len=*/MLKEM1024_SEED_LEN,
-      /*enc_len=*/MLKEM1024_PUBLIC_VALUE_LEN,
-      mlkem1024_init_key,
-      mlkem1024_generate_key,
-      mlkem1024_encap_with_seed,
-      mlkem1024_decap,
-      // MLKEM1024 doesn't support authenticated encapsulation/decapsulation:
-      // https://datatracker.ietf.org/doc/draft-ietf-hpke-pq/01/
-      /* auth_encap_with_seed= */ nullptr,
-      /* auth_decap= */ nullptr,
-  };
-  return &kKEM;
-}
+const EVP_HPKE_KEM *EVP_hpke_mlkem768(void) { return &kMLKEM<MLKEM768HPKE>; }
+const EVP_HPKE_KEM *EVP_hpke_mlkem1024(void) { return &kMLKEM<MLKEM1024HPKE>; }
 
 uint16_t EVP_HPKE_KEM_id(const EVP_HPKE_KEM *kem) { return kem->id; }