mldsa: abstract over the matrix size.

Given the likelihood that we'll need other ML-DSA configurations,
convert ML-DSA to use templates so that other matrix sizes can
be supported.

Change-Id: Iadbdde0a9b36414a256c6103ad549c91d85c0fe7
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/71768
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: Adam Langley <agl@google.com>
diff --git a/crypto/mldsa/internal.h b/crypto/mldsa/internal.h
index 08ff329..92f4ae1 100644
--- a/crypto/mldsa/internal.h
+++ b/crypto/mldsa/internal.h
@@ -27,6 +27,9 @@
 // random entropy necessary to generate a signature in randomized mode.
 #define MLDSA_SIGNATURE_RANDOMIZER_BYTES 32
 
+
+// ML-DSA-65
+
 // MLDSA65_generate_key_external_entropy generates a public/private key pair
 // using the given seed, writes the encoded public key to
 // |out_encoded_public_key| and sets |out_private_key| to the private key.
diff --git a/crypto/mldsa/mldsa.cc b/crypto/mldsa/mldsa.cc
index a5ada3d..bfb00b7 100644
--- a/crypto/mldsa/mldsa.cc
+++ b/crypto/mldsa/mldsa.cc
@@ -25,49 +25,121 @@
 #include "../keccak/internal.h"
 #include "./internal.h"
 
-#define DEGREE 256
-#define K 6
-#define L 5
-#define ETA 4
-#define TAU 49
-#define BETA 196
-#define OMEGA 55
+namespace {
 
-#define RHO_BYTES 32
-#define SIGMA_BYTES 64
-#define K_BYTES 32
-#define TR_BYTES 64
-#define MU_BYTES 64
-#define RHO_PRIME_BYTES 64
-#define LAMBDA_BITS 192
-#define LAMBDA_BYTES (LAMBDA_BITS / 8)
+constexpr int kDegree = 256;
+constexpr int kRhoBytes = 32;
+constexpr int kSigmaBytes = 64;
+constexpr int kKBytes = 32;
+constexpr int kTrBytes = 64;
+constexpr int kMuBytes = 64;
+constexpr int kRhoPrimeBytes = 64;
 
 // 2^23 - 2^13 + 1
-static const uint32_t kPrime = 8380417;
+constexpr uint32_t kPrime = 8380417;
 // Inverse of -kPrime modulo 2^32
-static const uint32_t kPrimeNegInverse = 4236238847;
-static const int kDroppedBits = 13;
-static const uint32_t kHalfPrime = (8380417 - 1) / 2;
-static const uint32_t kGamma1 = 1 << 19;
-static const uint32_t kGamma2 = (8380417 - 1) / 32;
+constexpr uint32_t kPrimeNegInverse = 4236238847;
+constexpr int kDroppedBits = 13;
+constexpr uint32_t kHalfPrime = (kPrime - 1) / 2;
+constexpr uint32_t kGamma2 = (kPrime - 1) / 32;
 // 256^-1 mod kPrime, in Montgomery form.
-static const uint32_t kInverseDegreeMontgomery = 41978;
+constexpr uint32_t kInverseDegreeMontgomery = 41978;
+
+// Constants that vary depending on ML-DSA size.
+//
+// These are implemented as templates which take the K parameter to distinguish
+// the ML-DSA sizes. (At the time of writing, `if constexpr` was not available.)
+//
+// TODO(crbug.com/42290600): Switch this to `if constexpr` when C++17 is
+// available.
+
+template <int K>
+constexpr size_t public_key_bytes();
+
+template <>
+constexpr size_t public_key_bytes<6>() {
+  return MLDSA65_PUBLIC_KEY_BYTES;
+}
+
+template <int K>
+constexpr size_t signature_bytes();
+
+template <>
+constexpr size_t signature_bytes<6>() {
+  return MLDSA65_SIGNATURE_BYTES;
+}
+
+template <int K>
+constexpr int tau();
+
+template <>
+constexpr int tau<6>() {
+  return 49;
+}
+
+template <int K>
+constexpr int lambda_bytes();
+
+template <>
+constexpr int lambda_bytes<6>() {
+  return 192 / 8;
+}
+
+template <int K>
+constexpr int gamma1();
+
+template <>
+constexpr int gamma1<6>() {
+  return 1 << 19;
+}
+
+template <int K>
+constexpr int beta();
+
+template <>
+constexpr int beta<6>() {
+  return 196;
+}
+
+template <int K>
+constexpr int omega();
+
+template <>
+constexpr int omega<6>() {
+  return 55;
+}
+
+template <int K>
+constexpr int eta();
+
+template <>
+constexpr int eta<6>() {
+  return 4;
+}
+
+template <int K>
+constexpr int plus_minus_eta_bitlen();
+
+template <>
+constexpr int plus_minus_eta_bitlen<6>() {
+  return 4;
+}
+
+// Fundamental types.
 
 typedef struct scalar {
-  uint32_t c[DEGREE];
+  uint32_t c[kDegree];
 } scalar;
 
-typedef struct vectork {
+template <int K>
+struct vector {
   scalar v[K];
-} vectork;
+};
 
-typedef struct vectorl {
-  scalar v[L];
-} vectorl;
-
-typedef struct matrix {
+template <int K, int L>
+struct matrix {
   scalar v[K][L];
-} matrix;
+};
 
 /* Arithmetic */
 
@@ -173,13 +245,13 @@
 }
 
 static void scalar_add(scalar *out, const scalar *lhs, const scalar *rhs) {
-  for (int i = 0; i < DEGREE; i++) {
+  for (int i = 0; i < kDegree; i++) {
     out->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
   }
 }
 
 static void scalar_sub(scalar *out, const scalar *lhs, const scalar *rhs) {
-  for (int i = 0; i < DEGREE; i++) {
+  for (int i = 0; i < kDegree; i++) {
     out->c[i] = mod_sub(lhs->c[i], rhs->c[i]);
   }
 }
@@ -195,7 +267,7 @@
 
 // Multiply two scalars in the number theoretically transformed state.
 static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) {
-  for (int i = 0; i < DEGREE; i++) {
+  for (int i = 0; i < kDegree; i++) {
     out->c[i] = reduce_montgomery((uint64_t)lhs->c[i] * (uint64_t)rhs->c[i]);
   }
 }
@@ -206,8 +278,8 @@
 static void scalar_ntt(scalar *s) {
   // Step: 1, 2, 4, 8, ..., 128
   // Offset: 128, 64, 32, 16, ..., 1
-  int offset = DEGREE;
-  for (int step = 1; step < DEGREE; step <<= 1) {
+  int offset = kDegree;
+  for (int step = 1; step < kDegree; step <<= 1) {
     offset >>= 1;
     int k = 0;
     for (int i = 0; i < step; i++) {
@@ -234,8 +306,8 @@
 static void scalar_inverse_ntt(scalar *s) {
   // Step: 128, 64, 32, 16, ..., 1
   // Offset: 1, 2, 4, 8, ..., 128
-  int step = DEGREE;
-  for (int offset = 1; offset < DEGREE; offset <<= 1) {
+  int step = kDegree;
+  for (int offset = 1; offset < kDegree; offset <<= 1) {
     step >>= 1;
     int k = 0;
     for (int i = 0; i < step; i++) {
@@ -258,72 +330,59 @@
       k += 2 * offset;
     }
   }
-  for (int i = 0; i < DEGREE; i++) {
+  for (int i = 0; i < kDegree; i++) {
     s->c[i] = reduce_montgomery((uint64_t)s->c[i] *
                                 (uint64_t)kInverseDegreeMontgomery);
   }
 }
 
-static void vectork_zero(vectork *out) { OPENSSL_memset(out, 0, sizeof(*out)); }
+template <int X>
+static void vector_zero(vector<X> *out) {
+  OPENSSL_memset(out, 0, sizeof(*out));
+}
 
-static void vectork_add(vectork *out, const vectork *lhs, const vectork *rhs) {
-  for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_add(vector<X> *out, const vector<X> *lhs,
+                       const vector<X> *rhs) {
+  for (int i = 0; i < X; i++) {
     scalar_add(&out->v[i], &lhs->v[i], &rhs->v[i]);
   }
 }
 
-static void vectork_sub(vectork *out, const vectork *lhs, const vectork *rhs) {
-  for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_sub(vector<X> *out, const vector<X> *lhs,
+                       const vector<X> *rhs) {
+  for (int i = 0; i < X; i++) {
     scalar_sub(&out->v[i], &lhs->v[i], &rhs->v[i]);
   }
 }
 
-static void vectork_mult_scalar(vectork *out, const vectork *lhs,
-                                const scalar *rhs) {
-  for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_mult_scalar(vector<X> *out, const vector<X> *lhs,
+                               const scalar *rhs) {
+  for (int i = 0; i < X; i++) {
     scalar_mult(&out->v[i], &lhs->v[i], rhs);
   }
 }
 
-static void vectork_ntt(vectork *a) {
-  for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_ntt(vector<X> *a) {
+  for (int i = 0; i < X; i++) {
     scalar_ntt(&a->v[i]);
   }
 }
 
-static void vectork_inverse_ntt(vectork *a) {
-  for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_inverse_ntt(vector<X> *a) {
+  for (int i = 0; i < X; i++) {
     scalar_inverse_ntt(&a->v[i]);
   }
 }
 
-static void vectorl_add(vectorl *out, const vectorl *lhs, const vectorl *rhs) {
-  for (int i = 0; i < L; i++) {
-    scalar_add(&out->v[i], &lhs->v[i], &rhs->v[i]);
-  }
-}
-
-static void vectorl_mult_scalar(vectorl *out, const vectorl *lhs,
-                                const scalar *rhs) {
-  for (int i = 0; i < L; i++) {
-    scalar_mult(&out->v[i], &lhs->v[i], rhs);
-  }
-}
-
-static void vectorl_ntt(vectorl *a) {
-  for (int i = 0; i < L; i++) {
-    scalar_ntt(&a->v[i]);
-  }
-}
-
-static void vectorl_inverse_ntt(vectorl *a) {
-  for (int i = 0; i < L; i++) {
-    scalar_inverse_ntt(&a->v[i]);
-  }
-}
-
-static void matrix_mult(vectork *out, const matrix *m, const vectorl *a) {
-  vectork_zero(out);
+template <int K, int L>
+static void matrix_mult(vector<K> *out, const matrix<K, L> *m,
+                        const vector<L> *a) {
+  vector_zero(out);
   for (int i = 0; i < K; i++) {
     for (int j = 0; j < L; j++) {
       scalar product;
@@ -435,38 +494,38 @@
 }
 
 static void scalar_power2_round(scalar *s1, scalar *s0, const scalar *s) {
-  for (int i = 0; i < DEGREE; i++) {
+  for (int i = 0; i < kDegree; i++) {
     power2_round(&s1->c[i], &s0->c[i], s->c[i]);
   }
 }
 
 static void scalar_scale_power2_round(scalar *out, const scalar *in) {
-  for (int i = 0; i < DEGREE; i++) {
+  for (int i = 0; i < kDegree; i++) {
     scale_power2_round(&out->c[i], in->c[i]);
   }
 }
 
 static void scalar_high_bits(scalar *out, const scalar *in) {
-  for (int i = 0; i < DEGREE; i++) {
+  for (int i = 0; i < kDegree; i++) {
     out->c[i] = high_bits(in->c[i]);
   }
 }
 
 static void scalar_low_bits(scalar *out, const scalar *in) {
-  for (int i = 0; i < DEGREE; i++) {
+  for (int i = 0; i < kDegree; i++) {
     out->c[i] = low_bits(in->c[i]);
   }
 }
 
 static void scalar_max(uint32_t *max, const scalar *s) {
-  for (int i = 0; i < DEGREE; i++) {
+  for (int i = 0; i < kDegree; i++) {
     uint32_t abs = abs_mod_prime(s->c[i]);
     *max = maximum(*max, abs);
   }
 }
 
 static void scalar_max_signed(uint32_t *max, const scalar *s) {
-  for (int i = 0; i < DEGREE; i++) {
+  for (int i = 0; i < kDegree; i++) {
     uint32_t abs = abs_signed(s->c[i]);
     *max = maximum(*max, abs);
   }
@@ -474,98 +533,100 @@
 
 static void scalar_make_hint(scalar *out, const scalar *ct0, const scalar *cs2,
                              const scalar *w) {
-  for (int i = 0; i < DEGREE; i++) {
+  for (int i = 0; i < kDegree; i++) {
     out->c[i] = make_hint(ct0->c[i], cs2->c[i], w->c[i]);
   }
 }
 
 static void scalar_use_hint_vartime(scalar *out, const scalar *h,
                                     const scalar *r) {
-  for (int i = 0; i < DEGREE; i++) {
+  for (int i = 0; i < kDegree; i++) {
     out->c[i] = use_hint_vartime(h->c[i], r->c[i]);
   }
 }
 
-static void vectork_power2_round(vectork *t1, vectork *t0, const vectork *t) {
-  for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_power2_round(vector<X> *t1, vector<X> *t0,
+                                const vector<X> *t) {
+  for (int i = 0; i < X; i++) {
     scalar_power2_round(&t1->v[i], &t0->v[i], &t->v[i]);
   }
 }
 
-static void vectork_scale_power2_round(vectork *out, const vectork *in) {
-  for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_scale_power2_round(vector<X> *out, const vector<X> *in) {
+  for (int i = 0; i < X; i++) {
     scalar_scale_power2_round(&out->v[i], &in->v[i]);
   }
 }
 
-static void vectork_high_bits(vectork *out, const vectork *in) {
-  for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_high_bits(vector<X> *out, const vector<X> *in) {
+  for (int i = 0; i < X; i++) {
     scalar_high_bits(&out->v[i], &in->v[i]);
   }
 }
 
-static void vectork_low_bits(vectork *out, const vectork *in) {
-  for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_low_bits(vector<X> *out, const vector<X> *in) {
+  for (int i = 0; i < X; i++) {
     scalar_low_bits(&out->v[i], &in->v[i]);
   }
 }
 
-static uint32_t vectork_max(const vectork *a) {
+template <int X>
+static uint32_t vector_max(const vector<X> *a) {
   uint32_t max = 0;
-  for (int i = 0; i < K; i++) {
+  for (int i = 0; i < X; i++) {
     scalar_max(&max, &a->v[i]);
   }
   return max;
 }
 
-static uint32_t vectork_max_signed(const vectork *a) {
+template <int X>
+static uint32_t vector_max_signed(const vector<X> *a) {
   uint32_t max = 0;
-  for (int i = 0; i < K; i++) {
+  for (int i = 0; i < X; i++) {
     scalar_max_signed(&max, &a->v[i]);
   }
   return max;
 }
 
 // The input vector contains only zeroes and ones.
-static size_t vectork_count_ones(const vectork *a) {
+template <int X>
+static size_t vector_count_ones(const vector<X> *a) {
   size_t count = 0;
-  for (int i = 0; i < K; i++) {
-    for (int j = 0; j < DEGREE; j++) {
+  for (int i = 0; i < X; i++) {
+    for (int j = 0; j < kDegree; j++) {
       count += a->v[i].c[j];
     }
   }
   return count;
 }
 
-static void vectork_make_hint(vectork *out, const vectork *ct0,
-                              const vectork *cs2, const vectork *w) {
-  for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_make_hint(vector<X> *out, const vector<X> *ct0,
+                             const vector<X> *cs2, const vector<X> *w) {
+  for (int i = 0; i < X; i++) {
     scalar_make_hint(&out->v[i], &ct0->v[i], &cs2->v[i], &w->v[i]);
   }
 }
 
-static void vectork_use_hint_vartime(vectork *out, const vectork *h,
-                                     const vectork *r) {
-  for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_use_hint_vartime(vector<X> *out, const vector<X> *h,
+                                    const vector<X> *r) {
+  for (int i = 0; i < X; i++) {
     scalar_use_hint_vartime(&out->v[i], &h->v[i], &r->v[i]);
   }
 }
 
-static uint32_t vectorl_max(const vectorl *a) {
-  uint32_t max = 0;
-  for (int i = 0; i < L; i++) {
-    scalar_max(&max, &a->v[i]);
-  }
-  return max;
-}
-
 /* Bit packing */
 
 // FIPS 204, Algorithm 16 (`SimpleBitPack`). Specialized to bitlen(b) = 4.
 static void scalar_encode_4(uint8_t out[128], const scalar *s) {
   // Every two elements lands on a byte boundary.
-  static_assert(DEGREE % 2 == 0, "DEGREE must be a multiple of 2");
-  for (int i = 0; i < DEGREE / 2; i++) {
+  static_assert(kDegree % 2 == 0, "kDegree must be a multiple of 2");
+  for (int i = 0; i < kDegree / 2; i++) {
     uint32_t a = s->c[2 * i];
     uint32_t b = s->c[2 * i + 1];
     declassify_assert(a < 16);
@@ -577,8 +638,8 @@
 // FIPS 204, Algorithm 16 (`SimpleBitPack`). Specialized to bitlen(b) = 10.
 static void scalar_encode_10(uint8_t out[320], const scalar *s) {
   // Every four elements lands on a byte boundary.
-  static_assert(DEGREE % 4 == 0, "DEGREE must be a multiple of 4");
-  for (int i = 0; i < DEGREE / 4; i++) {
+  static_assert(kDegree % 4 == 0, "kDegree must be a multiple of 4");
+  for (int i = 0; i < kDegree / 4; i++) {
     uint32_t a = s->c[4 * i];
     uint32_t b = s->c[4 * i + 1];
     uint32_t c = s->c[4 * i + 2];
@@ -595,14 +656,13 @@
   }
 }
 
-// FIPS 204, Algorithm 17 (`BitPack`). Specialized to bitlen(b) = 4 and b =
-// 2^19.
-static void scalar_encode_signed_4_eta(uint8_t out[128], const scalar *s) {
+// FIPS 204, Algorithm 17 (`BitPack`). Specialized to bitlen(b) = 4 and b = 4.
+static void scalar_encode_signed_4_4(uint8_t out[128], const scalar *s) {
   // Every two elements lands on a byte boundary.
-  static_assert(DEGREE % 2 == 0, "DEGREE must be a multiple of 2");
-  for (int i = 0; i < DEGREE / 2; i++) {
-    uint32_t a = mod_sub(ETA, s->c[2 * i]);
-    uint32_t b = mod_sub(ETA, s->c[2 * i + 1]);
+  static_assert(kDegree % 2 == 0, "kDegree must be a multiple of 2");
+  for (int i = 0; i < kDegree / 2; i++) {
+    uint32_t a = mod_sub(4, s->c[2 * i]);
+    uint32_t b = mod_sub(4, s->c[2 * i + 1]);
     declassify_assert(a < 16);
     declassify_assert(b < 16);
     out[i] = a | (b << 4);
@@ -614,8 +674,8 @@
 static void scalar_encode_signed_13_12(uint8_t out[416], const scalar *s) {
   static const uint32_t kMax = 1u << 12;
   // Every two elements lands on a byte boundary.
-  static_assert(DEGREE % 8 == 0, "DEGREE must be a multiple of 8");
-  for (int i = 0; i < DEGREE / 8; i++) {
+  static_assert(kDegree % 8 == 0, "kDegree must be a multiple of 8");
+  for (int i = 0; i < kDegree / 8; i++) {
     uint32_t a = mod_sub(kMax, s->c[8 * i]);
     uint32_t b = mod_sub(kMax, s->c[8 * i + 1]);
     uint32_t c = mod_sub(kMax, s->c[8 * i + 2]);
@@ -654,8 +714,8 @@
 static void scalar_encode_signed_20_19(uint8_t out[640], const scalar *s) {
   static const uint32_t kMax = 1u << 19;
   // Every two elements lands on a byte boundary.
-  static_assert(DEGREE % 4 == 0, "DEGREE must be a multiple of 4");
-  for (int i = 0; i < DEGREE / 4; i++) {
+  static_assert(kDegree % 4 == 0, "kDegree must be a multiple of 4");
+  for (int i = 0; i < kDegree / 4; i++) {
     uint32_t a = mod_sub(kMax, s->c[4 * i]);
     uint32_t b = mod_sub(kMax, s->c[4 * i + 1]);
     uint32_t c = mod_sub(kMax, s->c[4 * i + 2]);
@@ -679,8 +739,8 @@
 static void scalar_encode_signed(uint8_t *out, const scalar *s, int bits,
                                  uint32_t max) {
   if (bits == 4) {
-    assert(max == ETA);
-    scalar_encode_signed_4_eta(out, s);
+    assert(max == 4);
+    scalar_encode_signed_4_4(out, s);
   } else if (bits == 20) {
     assert(max == 1u << 19);
     scalar_encode_signed_20_19(out, s);
@@ -694,8 +754,8 @@
 // FIPS 204, Algorithm 18 (`SimpleBitUnpack`). Specialized for bitlen(b) == 10.
 static void scalar_decode_10(scalar *out, const uint8_t in[320]) {
   uint32_t v;
-  static_assert(DEGREE % 4 == 0, "DEGREE must be a multiple of 4");
-  for (int i = 0; i < DEGREE / 4; i++) {
+  static_assert(kDegree % 4 == 0, "kDegree must be a multiple of 4");
+  for (int i = 0; i < kDegree / 4; i++) {
     OPENSSL_memcpy(&v, &in[5 * i], sizeof(v));
     out->c[4 * i] = v & 0x3ff;
     out->c[4 * i + 1] = (v >> 10) & 0x3ff;
@@ -705,13 +765,12 @@
 }
 
 // FIPS 204, Algorithm 19 (`BitUnpack`). Specialized to bitlen(a+b) = 4 and b =
-// eta.
-static int scalar_decode_signed_4_eta(scalar *out, const uint8_t in[128]) {
+// 4.
+static int scalar_decode_signed_4_4(scalar *out, const uint8_t in[128]) {
   uint32_t v;
-  static_assert(DEGREE % 8 == 0, "DEGREE must be a multiple of 8");
-  for (int i = 0; i < DEGREE / 8; i++) {
+  static_assert(kDegree % 8 == 0, "kDegree must be a multiple of 8");
+  for (int i = 0; i < kDegree / 8; i++) {
     OPENSSL_memcpy(&v, &in[4 * i], sizeof(v));
-    static_assert(ETA == 4, "ETA must be 4");
     // None of the nibbles may be >= 9. So if the MSB of any nibble is set, none
     // of the other bits may be set. First, select all the MSBs.
     const uint32_t msbs = v & 0x88888888u;
@@ -723,14 +782,14 @@
       return 0;
     }
 
-    out->c[i * 8] = mod_sub(ETA, v & 15);
-    out->c[i * 8 + 1] = mod_sub(ETA, (v >> 4) & 15);
-    out->c[i * 8 + 2] = mod_sub(ETA, (v >> 8) & 15);
-    out->c[i * 8 + 3] = mod_sub(ETA, (v >> 12) & 15);
-    out->c[i * 8 + 4] = mod_sub(ETA, (v >> 16) & 15);
-    out->c[i * 8 + 5] = mod_sub(ETA, (v >> 20) & 15);
-    out->c[i * 8 + 6] = mod_sub(ETA, (v >> 24) & 15);
-    out->c[i * 8 + 7] = mod_sub(ETA, v >> 28);
+    out->c[i * 8] = mod_sub(4, v & 15);
+    out->c[i * 8 + 1] = mod_sub(4, (v >> 4) & 15);
+    out->c[i * 8 + 2] = mod_sub(4, (v >> 8) & 15);
+    out->c[i * 8 + 3] = mod_sub(4, (v >> 12) & 15);
+    out->c[i * 8 + 4] = mod_sub(4, (v >> 16) & 15);
+    out->c[i * 8 + 5] = mod_sub(4, (v >> 20) & 15);
+    out->c[i * 8 + 6] = mod_sub(4, (v >> 24) & 15);
+    out->c[i * 8 + 7] = mod_sub(4, v >> 28);
   }
   return 1;
 }
@@ -744,8 +803,8 @@
 
   uint32_t a, b, c;
   uint8_t d;
-  static_assert(DEGREE % 8 == 0, "DEGREE must be a multiple of 8");
-  for (int i = 0; i < DEGREE / 8; i++) {
+  static_assert(kDegree % 8 == 0, "kDegree must be a multiple of 8");
+  for (int i = 0; i < kDegree / 8; i++) {
     OPENSSL_memcpy(&a, &in[13 * i], sizeof(a));
     OPENSSL_memcpy(&b, &in[13 * i + 4], sizeof(b));
     OPENSSL_memcpy(&c, &in[13 * i + 8], sizeof(c));
@@ -772,8 +831,8 @@
 
   uint32_t a, b;
   uint16_t c;
-  static_assert(DEGREE % 4 == 0, "DEGREE must be a multiple of 4");
-  for (int i = 0; i < DEGREE / 4; i++) {
+  static_assert(kDegree % 4 == 0, "kDegree must be a multiple of 4");
+  for (int i = 0; i < kDegree / 4; i++) {
     OPENSSL_memcpy(&a, &in[10 * i], sizeof(a));
     OPENSSL_memcpy(&b, &in[10 * i + 4], sizeof(b));
     OPENSSL_memcpy(&c, &in[10 * i + 8], sizeof(c));
@@ -791,8 +850,8 @@
 static int scalar_decode_signed(scalar *out, const uint8_t *in, int bits,
                                 uint32_t max) {
   if (bits == 4) {
-    assert(max == ETA);
-    return scalar_decode_signed_4_eta(out, in);
+    assert(max == 4);
+    return scalar_decode_signed_4_4(out, in);
   } else if (bits == 13) {
     assert(max == (1u << 12));
     scalar_decode_signed_13_12(out, in);
@@ -813,19 +872,19 @@
 // Rejection samples a Keccak stream to get uniformly distributed elements. This
 // is used for matrix expansion and only operates on public inputs.
 static void scalar_from_keccak_vartime(
-    scalar *out, const uint8_t derived_seed[RHO_BYTES + 2]) {
+    scalar *out, const uint8_t derived_seed[kRhoBytes + 2]) {
   struct BORINGSSL_keccak_st keccak_ctx;
   BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake128);
-  BORINGSSL_keccak_absorb(&keccak_ctx, derived_seed, RHO_BYTES + 2);
+  BORINGSSL_keccak_absorb(&keccak_ctx, derived_seed, kRhoBytes + 2);
   assert(keccak_ctx.squeeze_offset == 0);
   assert(keccak_ctx.rate_bytes == 168);
   static_assert(168 % 3 == 0, "block and coefficient boundaries do not align");
 
   int done = 0;
-  while (done < DEGREE) {
+  while (done < kDegree) {
     uint8_t block[168];
     BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
-    for (size_t i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
+    for (size_t i = 0; i < sizeof(block) && done < kDegree; i += 3) {
       // FIPS 204, Algorithm 14 (`CoeffFromThreeBytes`).
       uint32_t value = (uint32_t)block[i] | ((uint32_t)block[i + 1] << 8) |
                        (((uint32_t)block[i + 2] & 0x7f) << 16);
@@ -836,22 +895,33 @@
   }
 }
 
-// FIPS 204, Algorithm 31 (`RejBoundedPoly`).
-static void scalar_uniform_eta_4(scalar *out,
-                                 const uint8_t derived_seed[SIGMA_BYTES + 2]) {
-  static_assert(ETA == 4, "This implementation is specialized for ETA == 4");
+template <int ETA>
+static bool coefficient_from_nibble(uint32_t nibble, uint32_t *result);
 
+template <>
+bool coefficient_from_nibble<4>(uint32_t nibble, uint32_t *result) {
+  if (constant_time_declassify_int(nibble < 9)) {
+    *result = mod_sub(4, nibble);
+    return true;
+  }
+  return false;
+}
+
+// FIPS 204, Algorithm 31 (`RejBoundedPoly`).
+template <int ETA>
+static void scalar_uniform(scalar *out,
+                           const uint8_t derived_seed[kSigmaBytes + 2]) {
   struct BORINGSSL_keccak_st keccak_ctx;
   BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
-  BORINGSSL_keccak_absorb(&keccak_ctx, derived_seed, SIGMA_BYTES + 2);
+  BORINGSSL_keccak_absorb(&keccak_ctx, derived_seed, kSigmaBytes + 2);
   assert(keccak_ctx.squeeze_offset == 0);
   assert(keccak_ctx.rate_bytes == 136);
 
   int done = 0;
-  while (done < DEGREE) {
+  while (done < kDegree) {
     uint8_t block[136];
     BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
-    for (size_t i = 0; i < sizeof(block) && done < DEGREE; ++i) {
+    for (size_t i = 0; i < sizeof(block) && done < kDegree; ++i) {
       uint32_t t0 = block[i] & 0x0F;
       uint32_t t1 = block[i] >> 4;
       // FIPS 204, Algorithm 15 (`CoefFromHalfByte`). Although both the input
@@ -859,21 +929,22 @@
       // Individual bytes of the SHAKE-256 stream are (indistiguishable from)
       // independent of each other and the original seed, so leaking information
       // about the rejected bytes does not reveal the input or output.
-      if (constant_time_declassify_int(t0 < 9)) {
-        out->c[done++] = mod_sub(ETA, t0);
+      uint32_t v;
+      if (coefficient_from_nibble<ETA>(t0, &v)) {
+        out->c[done++] = v;
       }
-      if (done < DEGREE && constant_time_declassify_int(t1 < 9)) {
-        out->c[done++] = mod_sub(ETA, t1);
+      if (done < kDegree && coefficient_from_nibble<ETA>(t1, &v)) {
+        out->c[done++] = v;
       }
     }
   }
 }
 
 // FIPS 204, Algorithm 34 (`ExpandMask`), but just a single step.
-static void scalar_sample_mask(
-    scalar *out, const uint8_t derived_seed[RHO_PRIME_BYTES + 2]) {
+static void scalar_sample_mask(scalar *out,
+                               const uint8_t derived_seed[kRhoPrimeBytes + 2]) {
   uint8_t buf[640];
-  BORINGSSL_keccak(buf, sizeof(buf), derived_seed, RHO_PRIME_BYTES + 2,
+  BORINGSSL_keccak(buf, sizeof(buf), derived_seed, kRhoPrimeBytes + 2,
                    boringssl_shake256);
 
   scalar_decode_signed_20_19(out, buf);
@@ -881,9 +952,7 @@
 
 // FIPS 204, Algorithm 29 (`SampleInBall`).
 static void scalar_sample_in_ball_vartime(scalar *out, const uint8_t *seed,
-                                          int len) {
-  assert(len == 2 * LAMBDA_BYTES);
-
+                                          int len, int tau) {
   struct BORINGSSL_keccak_st keccak_ctx;
   BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
   BORINGSSL_keccak_absorb(&keccak_ctx, seed, len);
@@ -902,7 +971,7 @@
   CONSTTIME_DECLASSIFY(block + offset, sizeof(block) - offset);
 
   OPENSSL_memset(out, 0, sizeof(*out));
-  for (size_t i = DEGREE - TAU; i < DEGREE; i++) {
+  for (size_t i = kDegree - tau; i < kDegree; i++) {
     size_t byte;
     for (;;) {
       if (offset == 136) {
@@ -925,54 +994,57 @@
 }
 
 // FIPS 204, Algorithm 32 (`ExpandA`).
-static void matrix_expand(matrix *out, const uint8_t rho[RHO_BYTES]) {
+template <int K, int L>
+static void matrix_expand(matrix<K, L> *out, const uint8_t rho[kRhoBytes]) {
   static_assert(K <= 0x100, "K must fit in 8 bits");
   static_assert(L <= 0x100, "L must fit in 8 bits");
 
-  uint8_t derived_seed[RHO_BYTES + 2];
-  OPENSSL_memcpy(derived_seed, rho, RHO_BYTES);
+  uint8_t derived_seed[kRhoBytes + 2];
+  OPENSSL_memcpy(derived_seed, rho, kRhoBytes);
   for (int i = 0; i < K; i++) {
     for (int j = 0; j < L; j++) {
-      derived_seed[RHO_BYTES + 1] = (uint8_t)i;
-      derived_seed[RHO_BYTES] = (uint8_t)j;
+      derived_seed[kRhoBytes + 1] = (uint8_t)i;
+      derived_seed[kRhoBytes] = (uint8_t)j;
       scalar_from_keccak_vartime(&out->v[i][j], derived_seed);
     }
   }
 }
 
 // FIPS 204, Algorithm 33 (`ExpandS`).
-static void vector_expand_short(vectorl *s1, vectork *s2,
-                                const uint8_t sigma[SIGMA_BYTES]) {
+template <int K, int L>
+static void vector_expand_short(vector<L> *s1, vector<K> *s2,
+                                const uint8_t sigma[kSigmaBytes]) {
   static_assert(K <= 0x100, "K must fit in 8 bits");
   static_assert(L <= 0x100, "L must fit in 8 bits");
   static_assert(K + L <= 0x100, "K+L must fit in 8 bits");
 
-  uint8_t derived_seed[SIGMA_BYTES + 2];
-  OPENSSL_memcpy(derived_seed, sigma, SIGMA_BYTES);
-  derived_seed[SIGMA_BYTES] = 0;
-  derived_seed[SIGMA_BYTES + 1] = 0;
+  uint8_t derived_seed[kSigmaBytes + 2];
+  OPENSSL_memcpy(derived_seed, sigma, kSigmaBytes);
+  derived_seed[kSigmaBytes] = 0;
+  derived_seed[kSigmaBytes + 1] = 0;
   for (int i = 0; i < L; i++) {
-    scalar_uniform_eta_4(&s1->v[i], derived_seed);
-    ++derived_seed[SIGMA_BYTES];
+    scalar_uniform<eta<K>()>(&s1->v[i], derived_seed);
+    ++derived_seed[kSigmaBytes];
   }
   for (int i = 0; i < K; i++) {
-    scalar_uniform_eta_4(&s2->v[i], derived_seed);
-    ++derived_seed[SIGMA_BYTES];
+    scalar_uniform<eta<K>()>(&s2->v[i], derived_seed);
+    ++derived_seed[kSigmaBytes];
   }
 }
 
 // FIPS 204, Algorithm 34 (`ExpandMask`).
-static void vectorl_expand_mask(vectorl *out,
-                                const uint8_t seed[RHO_PRIME_BYTES],
-                                size_t kappa) {
+template <int L>
+static void vector_expand_mask(vector<L> *out,
+                               const uint8_t seed[kRhoPrimeBytes],
+                               size_t kappa) {
   assert(kappa + L <= 0x10000);
 
-  uint8_t derived_seed[RHO_PRIME_BYTES + 2];
-  OPENSSL_memcpy(derived_seed, seed, RHO_PRIME_BYTES);
+  uint8_t derived_seed[kRhoPrimeBytes + 2];
+  OPENSSL_memcpy(derived_seed, seed, kRhoPrimeBytes);
   for (int i = 0; i < L; i++) {
     size_t index = kappa + i;
-    derived_seed[RHO_PRIME_BYTES] = index & 0xFF;
-    derived_seed[RHO_PRIME_BYTES + 1] = (index >> 8) & 0xFF;
+    derived_seed[kRhoPrimeBytes] = index & 0xFF;
+    derived_seed[kRhoPrimeBytes + 1] = (index >> 8) & 0xFF;
     scalar_sample_mask(&out->v[i], derived_seed);
   }
 }
@@ -981,63 +1053,49 @@
 
 // FIPS 204, Algorithm 16 (`SimpleBitPack`).
 //
-// Encodes an entire vector into 32*K*|bits| bytes. Note that since 256 (DEGREE)
-// is divisible by 8, the individual vector entries will always fill a whole
-// number of bytes, so we do not need to worry about bit packing here.
-static void vectork_encode(uint8_t *out, const vectork *a, int bits) {
+// Encodes an entire vector into 32*K*|bits| bytes. Note that since 256
+// (kDegree) is divisible by 8, the individual vector entries will always fill a
+// whole number of bytes, so we do not need to worry about bit packing here.
+template <int K>
+static void vector_encode(uint8_t *out, const vector<K> *a, int bits) {
   if (bits == 4) {
     for (int i = 0; i < K; i++) {
-      scalar_encode_4(out + i * bits * DEGREE / 8, &a->v[i]);
+      scalar_encode_4(out + i * bits * kDegree / 8, &a->v[i]);
     }
   } else {
     assert(bits == 10);
     for (int i = 0; i < K; i++) {
-      scalar_encode_10(out + i * bits * DEGREE / 8, &a->v[i]);
+      scalar_encode_10(out + i * bits * kDegree / 8, &a->v[i]);
     }
   }
 }
 
 // FIPS 204, Algorithm 18 (`SimpleBitUnpack`).
-static void vectork_decode_10(vectork *out, const uint8_t *in) {
+template <int K>
+static void vector_decode_10(vector<K> *out, const uint8_t *in) {
   for (int i = 0; i < K; i++) {
-    scalar_decode_10(&out->v[i], in + i * 10 * DEGREE / 8);
+    scalar_decode_10(&out->v[i], in + i * 10 * kDegree / 8);
   }
 }
 
-static void vectork_encode_signed(uint8_t *out, const vectork *a, int bits,
-                                  uint32_t max) {
-  for (int i = 0; i < K; i++) {
-    scalar_encode_signed(out + i * bits * DEGREE / 8, &a->v[i], bits, max);
-  }
-}
-
-static int vectork_decode_signed(vectork *out, const uint8_t *in, int bits,
-                                 uint32_t max) {
-  for (int i = 0; i < K; i++) {
-    if (!scalar_decode_signed(&out->v[i], in + i * bits * DEGREE / 8, bits,
-                              max)) {
-      return 0;
-    }
-  }
-  return 1;
-}
-
 // FIPS 204, Algorithm 17 (`BitPack`).
 //
-// Encodes an entire vector into 32*L*|bits| bytes. Note that since 256 (DEGREE)
-// is divisible by 8, the individual vector entries will always fill a whole
-// number of bytes, so we do not need to worry about bit packing here.
-static void vectorl_encode_signed(uint8_t *out, const vectorl *a, int bits,
-                                  uint32_t max) {
-  for (int i = 0; i < L; i++) {
-    scalar_encode_signed(out + i * bits * DEGREE / 8, &a->v[i], bits, max);
+// Encodes an entire vector into 32*L*|bits| bytes. Note that since 256
+// (kDegree) is divisible by 8, the individual vector entries will always fill a
+// whole number of bytes, so we do not need to worry about bit packing here.
+template <int X>
+static void vector_encode_signed(uint8_t *out, const vector<X> *a, int bits,
+                                 uint32_t max) {
+  for (int i = 0; i < X; i++) {
+    scalar_encode_signed(out + i * bits * kDegree / 8, &a->v[i], bits, max);
   }
 }
 
-static int vectorl_decode_signed(vectorl *out, const uint8_t *in, int bits,
-                                 uint32_t max) {
-  for (int i = 0; i < L; i++) {
-    if (!scalar_decode_signed(&out->v[i], in + i * bits * DEGREE / 8, bits,
+template <int X>
+static int vector_decode_signed(vector<X> *out, const uint8_t *in, int bits,
+                                uint32_t max) {
+  for (int i = 0; i < X; i++) {
+    if (!scalar_decode_signed(&out->v[i], in + i * bits * kDegree / 8, bits,
                               max)) {
       return 0;
     }
@@ -1046,33 +1104,36 @@
 }
 
 // FIPS 204, Algorithm 28 (`w1Encode`).
-static void w1_encode(uint8_t out[128 * K], const vectork *w1) {
-  vectork_encode(out, w1, 4);
+template <int K>
+static void w1_encode(uint8_t out[128 * K], const vector<K> *w1) {
+  vector_encode(out, w1, 4);
 }
 
 // FIPS 204, Algorithm 20 (`HintBitPack`).
-static void hint_bit_pack(uint8_t out[OMEGA + K], const vectork *h) {
-  OPENSSL_memset(out, 0, OMEGA + K);
+template <int K>
+static void hint_bit_pack(uint8_t out[omega<K>() + K], const vector<K> *h) {
+  OPENSSL_memset(out, 0, omega<K>() + K);
   int index = 0;
   for (int i = 0; i < K; i++) {
-    for (int j = 0; j < DEGREE; j++) {
+    for (int j = 0; j < kDegree; j++) {
       if (h->v[i].c[j]) {
-        // h must have at most OMEGA non-zero coefficients.
-        BSSL_CHECK(index < OMEGA);
+        // h must have at most omega<K>() non-zero coefficients.
+        BSSL_CHECK(index < omega<K>());
         out[index++] = j;
       }
     }
-    out[OMEGA + i] = index;
+    out[omega<K>() + i] = index;
   }
 }
 
 // FIPS 204, Algorithm 21 (`HintBitUnpack`).
-static int hint_bit_unpack(vectork *h, const uint8_t in[OMEGA + K]) {
-  vectork_zero(h);
+template <int K>
+static int hint_bit_unpack(vector<K> *h, const uint8_t in[omega<K>() + K]) {
+  vector_zero(h);
   int index = 0;
   for (int i = 0; i < K; i++) {
-    const int limit = in[OMEGA + i];
-    if (limit < index || limit > OMEGA) {
+    const int limit = in[omega<K>() + i];
+    if (limit < index || limit > omega<K>()) {
       return 0;
     }
 
@@ -1083,12 +1144,12 @@
         return 0;
       }
       last = byte;
-      static_assert(DEGREE == 256,
-                    "DEGREE must be 256 for this write to be in bounds");
+      static_assert(kDegree == 256,
+                    "kDegree must be 256 for this write to be in bounds");
       h->v[i].c[byte] = 1;
     }
   }
-  for (; index < OMEGA; index++) {
+  for (; index < omega<K>(); index++) {
     if (in[index] != 0) {
       return 0;
     }
@@ -1096,30 +1157,34 @@
   return 1;
 }
 
+template <int K>
 struct public_key {
-  uint8_t rho[RHO_BYTES];
-  vectork t1;
+  uint8_t rho[kRhoBytes];
+  vector<K> t1;
   // Pre-cached value(s).
-  uint8_t public_key_hash[TR_BYTES];
+  uint8_t public_key_hash[kTrBytes];
 };
 
+template <int K, int L>
 struct private_key {
-  uint8_t rho[RHO_BYTES];
-  uint8_t k[K_BYTES];
-  uint8_t public_key_hash[TR_BYTES];
-  vectorl s1;
-  vectork s2;
-  vectork t0;
+  uint8_t rho[kRhoBytes];
+  uint8_t k[kKBytes];
+  uint8_t public_key_hash[kTrBytes];
+  vector<L> s1;
+  vector<K> s2;
+  vector<K> t0;
 };
 
+template <int K, int L>
 struct signature {
-  uint8_t c_tilde[2 * LAMBDA_BYTES];
-  vectorl z;
-  vectork h;
+  uint8_t c_tilde[2 * lambda_bytes<K>()];
+  vector<L> z;
+  vector<K> h;
 };
 
 // FIPS 204, Algorithm 22 (`pkEncode`).
-static int mldsa_marshal_public_key(CBB *out, const struct public_key *pub) {
+template <int K>
+static int mldsa_marshal_public_key(CBB *out, const struct public_key<K> *pub) {
   if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
     return 0;
   }
@@ -1128,28 +1193,37 @@
   if (!CBB_add_space(out, &vectork_output, 320 * K)) {
     return 0;
   }
-  vectork_encode(vectork_output, &pub->t1, 10);
+  vector_encode(vectork_output, &pub->t1, 10);
 
   return 1;
 }
 
 // FIPS 204, Algorithm 23 (`pkDecode`).
-static int mldsa_parse_public_key(struct public_key *pub, CBS *in) {
+template <int K>
+static int mldsa_parse_public_key(struct public_key<K> *pub, CBS *in) {
+  const CBS orig_in = *in;
+
   if (!CBS_copy_bytes(in, pub->rho, sizeof(pub->rho))) {
     return 0;
   }
 
   CBS t1_bytes;
-  if (!CBS_get_bytes(in, &t1_bytes, 320 * K)) {
+  if (!CBS_get_bytes(in, &t1_bytes, 320 * K) || CBS_len(in) != 0) {
     return 0;
   }
-  vectork_decode_10(&pub->t1, CBS_data(&t1_bytes));
+  vector_decode_10(&pub->t1, CBS_data(&t1_bytes));
+
+  // Compute pre-cached values.
+  BORINGSSL_keccak(pub->public_key_hash, sizeof(pub->public_key_hash),
+                   CBS_data(&orig_in), CBS_len(&orig_in), boringssl_shake256);
 
   return 1;
 }
 
 // FIPS 204, Algorithm 24 (`skEncode`).
-static int mldsa_marshal_private_key(CBB *out, const struct private_key *priv) {
+template <int K, int L>
+static int mldsa_marshal_private_key(CBB *out,
+                                     const struct private_key<K, L> *priv) {
   if (!CBB_add_bytes(out, priv->rho, sizeof(priv->rho)) ||
       !CBB_add_bytes(out, priv->k, sizeof(priv->k)) ||
       !CBB_add_bytes(out, priv->public_key_hash,
@@ -1157,42 +1231,52 @@
     return 0;
   }
 
+  constexpr size_t scalar_bytes =
+      (kDegree * plus_minus_eta_bitlen<K>() + 7) / 8;
   uint8_t *vectorl_output;
-  if (!CBB_add_space(out, &vectorl_output, 128 * L)) {
+  if (!CBB_add_space(out, &vectorl_output, scalar_bytes * L)) {
     return 0;
   }
-  vectorl_encode_signed(vectorl_output, &priv->s1, 4, ETA);
+  vector_encode_signed(vectorl_output, &priv->s1, plus_minus_eta_bitlen<K>(),
+                       eta<K>());
 
-  uint8_t *vectork_output;
-  if (!CBB_add_space(out, &vectork_output, 128 * K)) {
+  uint8_t *s2_output;
+  if (!CBB_add_space(out, &s2_output, scalar_bytes * K)) {
     return 0;
   }
-  vectork_encode_signed(vectork_output, &priv->s2, 4, ETA);
+  vector_encode_signed(s2_output, &priv->s2, plus_minus_eta_bitlen<K>(),
+                       eta<K>());
 
-  if (!CBB_add_space(out, &vectork_output, 416 * K)) {
+  uint8_t *t0_output;
+  if (!CBB_add_space(out, &t0_output, 416 * K)) {
     return 0;
   }
-  vectork_encode_signed(vectork_output, &priv->t0, 13, 1 << 12);
+  vector_encode_signed(t0_output, &priv->t0, 13, 1 << 12);
 
   return 1;
 }
 
 // FIPS 204, Algorithm 25 (`skDecode`).
-static int mldsa_parse_private_key(struct private_key *priv, CBS *in) {
+template <int K, int L>
+static int mldsa_parse_private_key(struct private_key<K, L> *priv, CBS *in) {
   CBS s1_bytes;
   CBS s2_bytes;
   CBS t0_bytes;
+  constexpr size_t scalar_bytes =
+      (kDegree * plus_minus_eta_bitlen<K>() + 7) / 8;
   if (!CBS_copy_bytes(in, priv->rho, sizeof(priv->rho)) ||
       !CBS_copy_bytes(in, priv->k, sizeof(priv->k)) ||
       !CBS_copy_bytes(in, priv->public_key_hash,
                       sizeof(priv->public_key_hash)) ||
-      !CBS_get_bytes(in, &s1_bytes, 128 * L) ||
-      !vectorl_decode_signed(&priv->s1, CBS_data(&s1_bytes), 4, ETA) ||
-      !CBS_get_bytes(in, &s2_bytes, 128 * K) ||
-      !vectork_decode_signed(&priv->s2, CBS_data(&s2_bytes), 4, ETA) ||
+      !CBS_get_bytes(in, &s1_bytes, scalar_bytes * L) ||
+      !vector_decode_signed(&priv->s1, CBS_data(&s1_bytes),
+                            plus_minus_eta_bitlen<K>(), eta<K>()) ||
+      !CBS_get_bytes(in, &s2_bytes, scalar_bytes * K) ||
+      !vector_decode_signed(&priv->s2, CBS_data(&s2_bytes),
+                            plus_minus_eta_bitlen<K>(), eta<K>()) ||
       !CBS_get_bytes(in, &t0_bytes, 416 * K) ||
       // Note: Decoding 13 bits into (-2^12, 2^12] cannot fail.
-      !vectork_decode_signed(&priv->t0, CBS_data(&t0_bytes), 13, 1 << 12)) {
+      !vector_decode_signed(&priv->t0, CBS_data(&t0_bytes), 13, 1 << 12)) {
     return 0;
   }
 
@@ -1200,7 +1284,9 @@
 }
 
 // FIPS 204, Algorithm 26 (`sigEncode`).
-static int mldsa_marshal_signature(CBB *out, const struct signature *sign) {
+template <int K, int L>
+static int mldsa_marshal_signature(CBB *out,
+                                   const struct signature<K, L> *sign) {
   if (!CBB_add_bytes(out, sign->c_tilde, sizeof(sign->c_tilde))) {
     return 0;
   }
@@ -1209,10 +1295,10 @@
   if (!CBB_add_space(out, &vectorl_output, 640 * L)) {
     return 0;
   }
-  vectorl_encode_signed(vectorl_output, &sign->z, 20, 1 << 19);
+  vector_encode_signed(vectorl_output, &sign->z, 20, 1 << 19);
 
   uint8_t *hint_output;
-  if (!CBB_add_space(out, &hint_output, OMEGA + K)) {
+  if (!CBB_add_space(out, &hint_output, omega<K>() + K)) {
     return 0;
   }
   hint_bit_pack(hint_output, &sign->h);
@@ -1221,14 +1307,15 @@
 }
 
 // FIPS 204, Algorithm 27 (`sigDecode`).
-static int mldsa_parse_signature(struct signature *sign, CBS *in) {
+template <int K, int L>
+static int mldsa_parse_signature(struct signature<K, L> *sign, CBS *in) {
   CBS z_bytes;
   CBS hint_bytes;
   if (!CBS_copy_bytes(in, sign->c_tilde, sizeof(sign->c_tilde)) ||
       !CBS_get_bytes(in, &z_bytes, 640 * L) ||
       // Note: Decoding 20 bits into (-2^19, 2^19] cannot fail.
-      !vectorl_decode_signed(&sign->z, CBS_data(&z_bytes), 20, 1 << 19) ||
-      !CBS_get_bytes(in, &hint_bytes, OMEGA + K) ||
+      !vector_decode_signed(&sign->z, CBS_data(&z_bytes), 20, 1 << 19) ||
+      !CBS_get_bytes(in, &hint_bytes, omega<K>() + K) ||
       !hint_bit_unpack(&sign->h, CBS_data(&hint_bytes))) {
     return 0;
   };
@@ -1236,28 +1323,370 @@
   return 1;
 }
 
-static struct private_key *private_key_from_external(
+template <typename T>
+struct DeleterFree {
+  void operator()(T *ptr) { OPENSSL_free(ptr); }
+};
+
+// FIPS 204, Algorithm 6 (`ML-DSA.KeyGen_internal`). Returns 1 on success and 0
+// on failure.
+template <int K, int L>
+static int mldsa_generate_key_external_entropy(
+    uint8_t out_encoded_public_key[public_key_bytes<K>()],
+    struct private_key<K, L> *priv, const uint8_t entropy[MLDSA_SEED_BYTES]) {
+  // Intermediate values, allocated on the heap to allow use when there is a
+  // limited amount of stack.
+  struct values_st {
+    struct public_key<K> pub;
+    matrix<K, L> a_ntt;
+    vector<L> s1_ntt;
+    vector<K> t;
+  };
+  std::unique_ptr<values_st, DeleterFree<values_st>> values(
+      reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
+  if (values == NULL) {
+    return 0;
+  }
+
+  uint8_t augmented_entropy[MLDSA_SEED_BYTES + 2];
+  OPENSSL_memcpy(augmented_entropy, entropy, MLDSA_SEED_BYTES);
+  // The k and l parameters are appended to the seed.
+  augmented_entropy[MLDSA_SEED_BYTES] = K;
+  augmented_entropy[MLDSA_SEED_BYTES + 1] = L;
+  uint8_t expanded_seed[kRhoBytes + kSigmaBytes + kKBytes];
+  BORINGSSL_keccak(expanded_seed, sizeof(expanded_seed), augmented_entropy,
+                   sizeof(augmented_entropy), boringssl_shake256);
+  const uint8_t *const rho = expanded_seed;
+  const uint8_t *const sigma = expanded_seed + kRhoBytes;
+  const uint8_t *const k = expanded_seed + kRhoBytes + kSigmaBytes;
+  // rho is public.
+  CONSTTIME_DECLASSIFY(rho, kRhoBytes);
+  OPENSSL_memcpy(values->pub.rho, rho, sizeof(values->pub.rho));
+  OPENSSL_memcpy(priv->rho, rho, sizeof(priv->rho));
+  OPENSSL_memcpy(priv->k, k, sizeof(priv->k));
+
+  matrix_expand(&values->a_ntt, rho);
+  vector_expand_short(&priv->s1, &priv->s2, sigma);
+
+  OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
+  vector_ntt(&values->s1_ntt);
+
+  matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
+  vector_inverse_ntt(&values->t);
+  vector_add(&values->t, &values->t, &priv->s2);
+
+  vector_power2_round(&values->pub.t1, &priv->t0, &values->t);
+  // t1 is public.
+  CONSTTIME_DECLASSIFY(&values->pub.t1, sizeof(values->pub.t1));
+
+  CBB cbb;
+  CBB_init_fixed(&cbb, out_encoded_public_key, public_key_bytes<K>());
+  if (!mldsa_marshal_public_key(&cbb, &values->pub)) {
+    return 0;
+  }
+  assert(CBB_len(&cbb) == public_key_bytes<K>());
+
+  BORINGSSL_keccak(priv->public_key_hash, sizeof(priv->public_key_hash),
+                   out_encoded_public_key, public_key_bytes<K>(),
+                   boringssl_shake256);
+
+  return 1;
+}
+
+template <int K, int L>
+static int mldsa_public_from_private(struct public_key<K> *pub,
+                                     const struct private_key<K, L> *priv) {
+  // Intermediate values, allocated on the heap to allow use when there is a
+  // limited amount of stack.
+  struct values_st {
+    matrix<K, L> a_ntt;
+    vector<L> s1_ntt;
+    vector<K> t;
+    vector<K> t0;
+  };
+  std::unique_ptr<values_st, DeleterFree<values_st>> values(
+      reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
+  if (values == NULL) {
+    return 0;
+  }
+
+
+  OPENSSL_memcpy(pub->rho, priv->rho, sizeof(pub->rho));
+  OPENSSL_memcpy(pub->public_key_hash, priv->public_key_hash,
+                 sizeof(pub->public_key_hash));
+
+  matrix_expand(&values->a_ntt, priv->rho);
+
+  OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
+  vector_ntt(&values->s1_ntt);
+
+  matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
+  vector_inverse_ntt(&values->t);
+  vector_add(&values->t, &values->t, &priv->s2);
+
+  vector_power2_round(&pub->t1, &values->t0, &values->t);
+  return 1;
+}
+
+// FIPS 204, Algorithm 7 (`ML-DSA.Sign_internal`). Returns 1 on success and 0
+// on failure.
+template <int K, int L>
+static int mldsa_sign_internal(
+    uint8_t out_encoded_signature[signature_bytes<K>()],
+    const struct private_key<K, L> *priv, const uint8_t *msg, size_t msg_len,
+    const uint8_t *context_prefix, size_t context_prefix_len,
+    const uint8_t *context, size_t context_len,
+    const uint8_t randomizer[MLDSA_SIGNATURE_RANDOMIZER_BYTES]) {
+  uint8_t mu[kMuBytes];
+  struct BORINGSSL_keccak_st keccak_ctx;
+  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
+  BORINGSSL_keccak_absorb(&keccak_ctx, priv->public_key_hash,
+                          sizeof(priv->public_key_hash));
+  BORINGSSL_keccak_absorb(&keccak_ctx, context_prefix, context_prefix_len);
+  BORINGSSL_keccak_absorb(&keccak_ctx, context, context_len);
+  BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
+  BORINGSSL_keccak_squeeze(&keccak_ctx, mu, kMuBytes);
+
+  uint8_t rho_prime[kRhoPrimeBytes];
+  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
+  BORINGSSL_keccak_absorb(&keccak_ctx, priv->k, sizeof(priv->k));
+  BORINGSSL_keccak_absorb(&keccak_ctx, randomizer,
+                          MLDSA_SIGNATURE_RANDOMIZER_BYTES);
+  BORINGSSL_keccak_absorb(&keccak_ctx, mu, kMuBytes);
+  BORINGSSL_keccak_squeeze(&keccak_ctx, rho_prime, kRhoPrimeBytes);
+
+  // Intermediate values, allocated on the heap to allow use when there is a
+  // limited amount of stack.
+  struct values_st {
+    struct signature<K, L> sign;
+    vector<L> s1_ntt;
+    vector<K> s2_ntt;
+    vector<K> t0_ntt;
+    matrix<K, L> a_ntt;
+    vector<L> y;
+    vector<K> w;
+    vector<K> w1;
+    vector<L> cs1;
+    vector<K> cs2;
+  };
+  std::unique_ptr<values_st, DeleterFree<values_st>> values(
+      reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
+  if (values == NULL) {
+    return 0;
+  }
+  OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
+  vector_ntt(&values->s1_ntt);
+
+  OPENSSL_memcpy(&values->s2_ntt, &priv->s2, sizeof(values->s2_ntt));
+  vector_ntt(&values->s2_ntt);
+
+  OPENSSL_memcpy(&values->t0_ntt, &priv->t0, sizeof(values->t0_ntt));
+  vector_ntt(&values->t0_ntt);
+
+  matrix_expand(&values->a_ntt, priv->rho);
+
+  // kappa must not exceed 2**16/L = 13107. But the probability of it
+  // exceeding even 1000 iterations is vanishingly small.
+  for (size_t kappa = 0;; kappa += L) {
+    vector_expand_mask(&values->y, rho_prime, kappa);
+
+    vector<L> *y_ntt = &values->cs1;
+    OPENSSL_memcpy(y_ntt, &values->y, sizeof(*y_ntt));
+    vector_ntt(y_ntt);
+
+    matrix_mult(&values->w, &values->a_ntt, y_ntt);
+    vector_inverse_ntt(&values->w);
+
+    vector_high_bits(&values->w1, &values->w);
+    uint8_t w1_encoded[128 * K];
+    w1_encode(w1_encoded, &values->w1);
+
+    BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
+    BORINGSSL_keccak_absorb(&keccak_ctx, mu, kMuBytes);
+    BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
+    BORINGSSL_keccak_squeeze(&keccak_ctx, values->sign.c_tilde,
+                             2 * lambda_bytes<K>());
+
+    scalar c_ntt;
+    scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde,
+                                  sizeof(values->sign.c_tilde), tau<K>());
+    scalar_ntt(&c_ntt);
+
+    vector_mult_scalar(&values->cs1, &values->s1_ntt, &c_ntt);
+    vector_inverse_ntt(&values->cs1);
+    vector_mult_scalar(&values->cs2, &values->s2_ntt, &c_ntt);
+    vector_inverse_ntt(&values->cs2);
+
+    vector_add(&values->sign.z, &values->y, &values->cs1);
+
+    vector<K> *r0 = &values->w1;
+    vector_sub(r0, &values->w, &values->cs2);
+    vector_low_bits(r0, r0);
+
+    // Leaking the fact that a signature was rejected is fine as the next
+    // attempt at a signature will be (indistinguishable from) independent of
+    // this one. Note, however, that we additionally leak which of the two
+    // branches rejected the signature. Section 5.5 of
+    // https://pq-crystals.org/dilithium/data/dilithium-specification-round3.pdf
+    // describes this leak as OK. Note we leak less than what is described by
+    // the paper; we do not reveal which coefficient violated the bound, and
+    // we hide which of the |z_max| or |r0_max| bound failed. See also
+    // https://boringssl-review.googlesource.com/c/boringssl/+/67747/comment/2bbab0fa_d241d35a/
+    uint32_t z_max = vector_max(&values->sign.z);
+    uint32_t r0_max = vector_max_signed(r0);
+    if (constant_time_declassify_w(
+            constant_time_ge_w(z_max, gamma1<K>() - beta<K>()) |
+            constant_time_ge_w(r0_max, kGamma2 - beta<K>()))) {
+      continue;
+    }
+
+    vector<K> *ct0 = &values->w1;
+    vector_mult_scalar(ct0, &values->t0_ntt, &c_ntt);
+    vector_inverse_ntt(ct0);
+    vector_make_hint(&values->sign.h, ct0, &values->cs2, &values->w);
+
+    // See above.
+    uint32_t ct0_max = vector_max(ct0);
+    size_t h_ones = vector_count_ones(&values->sign.h);
+    if (constant_time_declassify_w(constant_time_ge_w(ct0_max, kGamma2) |
+                                   constant_time_lt_w(omega<K>(), h_ones))) {
+      continue;
+    }
+
+    // Although computed with the private key, the signature is public.
+    CONSTTIME_DECLASSIFY(values->sign.c_tilde, sizeof(values->sign.c_tilde));
+    CONSTTIME_DECLASSIFY(&values->sign.z, sizeof(values->sign.z));
+    CONSTTIME_DECLASSIFY(&values->sign.h, sizeof(values->sign.h));
+
+    CBB cbb;
+    CBB_init_fixed(&cbb, out_encoded_signature, signature_bytes<K>());
+    if (!mldsa_marshal_signature(&cbb, &values->sign)) {
+      return 0;
+    }
+
+    BSSL_CHECK(CBB_len(&cbb) == signature_bytes<K>());
+    return 1;
+  }
+}
+
+// FIPS 204, Algorithm 8 (`ML-DSA.Verify_internal`).
+template <int K, int L>
+static int mldsa_verify_internal(
+    const struct public_key<K> *pub,
+    const uint8_t encoded_signature[signature_bytes<K>()], const uint8_t *msg,
+    size_t msg_len, const uint8_t *context_prefix, size_t context_prefix_len,
+    const uint8_t *context, size_t context_len) {
+  // Intermediate values, allocated on the heap to allow use when there is a
+  // limited amount of stack.
+  struct values_st {
+    struct signature<K, L> sign;
+    matrix<K, L> a_ntt;
+    vector<L> z_ntt;
+    vector<K> az_ntt;
+    vector<K> ct1_ntt;
+  };
+  std::unique_ptr<values_st, DeleterFree<values_st>> values(
+      reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
+  if (values == NULL) {
+    return 0;
+  }
+
+  CBS cbs;
+  CBS_init(&cbs, encoded_signature, signature_bytes<K>());
+  if (!mldsa_parse_signature(&values->sign, &cbs)) {
+    return 0;
+  }
+
+  matrix_expand(&values->a_ntt, pub->rho);
+
+  uint8_t mu[kMuBytes];
+  struct BORINGSSL_keccak_st keccak_ctx;
+  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
+  BORINGSSL_keccak_absorb(&keccak_ctx, pub->public_key_hash,
+                          sizeof(pub->public_key_hash));
+  BORINGSSL_keccak_absorb(&keccak_ctx, context_prefix, context_prefix_len);
+  BORINGSSL_keccak_absorb(&keccak_ctx, context, context_len);
+  BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
+  BORINGSSL_keccak_squeeze(&keccak_ctx, mu, kMuBytes);
+
+  scalar c_ntt;
+  scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde,
+                                sizeof(values->sign.c_tilde), tau<K>());
+  scalar_ntt(&c_ntt);
+
+  OPENSSL_memcpy(&values->z_ntt, &values->sign.z, sizeof(values->z_ntt));
+  vector_ntt(&values->z_ntt);
+
+  matrix_mult(&values->az_ntt, &values->a_ntt, &values->z_ntt);
+
+  vector_scale_power2_round(&values->ct1_ntt, &pub->t1);
+  vector_ntt(&values->ct1_ntt);
+
+  vector_mult_scalar(&values->ct1_ntt, &values->ct1_ntt, &c_ntt);
+
+  vector<K> *const w1 = &values->az_ntt;
+  vector_sub(w1, &values->az_ntt, &values->ct1_ntt);
+  vector_inverse_ntt(w1);
+
+  vector_use_hint_vartime(w1, &values->sign.h, w1);
+  uint8_t w1_encoded[128 * K];
+  w1_encode(w1_encoded, w1);
+
+  uint8_t c_tilde[2 * lambda_bytes<K>()];
+  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
+  BORINGSSL_keccak_absorb(&keccak_ctx, mu, kMuBytes);
+  BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
+  BORINGSSL_keccak_squeeze(&keccak_ctx, c_tilde, 2 * lambda_bytes<K>());
+
+  uint32_t z_max = vector_max(&values->sign.z);
+  return z_max < static_cast<uint32_t>(gamma1<K>() - beta<K>()) &&
+         OPENSSL_memcmp(c_tilde, values->sign.c_tilde, 2 * lambda_bytes<K>()) ==
+             0;
+}
+
+}  // namespace
+
+// ML-DSA-65 specific wrappers.
+
+static struct private_key<6, 5> *mldsa65_private_key_from_external(
     const struct MLDSA65_private_key *external) {
-  static_assert(
-      sizeof(struct MLDSA65_private_key) == sizeof(struct private_key),
-      "Kyber private key size incorrect");
-  static_assert(
-      alignof(struct MLDSA65_private_key) == alignof(struct private_key),
-      "Kyber private key align incorrect");
-  return (struct private_key *)external;
+  static_assert(sizeof(struct MLDSA65_private_key) ==
+                    sizeof(struct private_key<6, 5>),
+                "MLDSA65 private key size incorrect");
+  static_assert(alignof(struct MLDSA65_private_key) ==
+                    alignof(struct private_key<6, 5>),
+                "MLDSA65 private key align incorrect");
+  return (struct private_key<6, 5> *)external;
 }
 
-static struct public_key *public_key_from_external(
-    const struct MLDSA65_public_key *external) {
-  static_assert(sizeof(struct MLDSA65_public_key) == sizeof(struct public_key),
-                "mldsa public key size incorrect");
-  static_assert(
-      alignof(struct MLDSA65_public_key) == alignof(struct public_key),
-      "mldsa public key align incorrect");
-  return (struct public_key *)external;
+static struct public_key<6> *
+mldsa65_public_key_from_external(const struct MLDSA65_public_key *external) {
+  static_assert(sizeof(struct MLDSA65_public_key) ==
+                    sizeof(struct public_key<6>),
+                "MLDSA65 public key size incorrect");
+  static_assert(alignof(struct MLDSA65_public_key) ==
+                    alignof(struct public_key<6>),
+                "MLDSA65 public key align incorrect");
+  return (struct public_key<6> *)external;
 }
 
-/* API */
+int MLDSA65_parse_public_key(struct MLDSA65_public_key *public_key, CBS *in) {
+  return mldsa_parse_public_key(mldsa65_public_key_from_external(public_key),
+                                in);
+}
+
+int MLDSA65_marshal_private_key(CBB *out,
+                                const struct MLDSA65_private_key *private_key) {
+  return mldsa_marshal_private_key(
+      out, mldsa65_private_key_from_external(private_key));
+}
+
+int MLDSA65_parse_private_key(struct MLDSA65_private_key *private_key,
+                              CBS *in) {
+  return mldsa_parse_private_key(mldsa65_private_key_from_external(private_key),
+                                 in) &&
+         CBS_len(in) == 0;
+}
 
 // Calls |MLDSA_generate_key_external_entropy| with random bytes from
 // |RAND_bytes|. Returns 1 on success and 0 on failure.
@@ -1280,257 +1709,35 @@
                                                seed);
 }
 
-template <typename T>
-struct DeleterFree {
-  void operator()(T *ptr) { OPENSSL_free(ptr); }
-};
-
-// FIPS 204, Algorithm 6 (`ML-DSA.KeyGen_internal`). Returns 1 on success and 0
-// on failure.
 int MLDSA65_generate_key_external_entropy(
     uint8_t out_encoded_public_key[MLDSA65_PUBLIC_KEY_BYTES],
     struct MLDSA65_private_key *out_private_key,
     const uint8_t entropy[MLDSA_SEED_BYTES]) {
-  // Intermediate values, allocated on the heap to allow use when there is a
-  // limited amount of stack.
-  struct values_st {
-    struct public_key pub;
-    matrix a_ntt;
-    vectorl s1_ntt;
-    vectork t;
-  };
-  std::unique_ptr<values_st, DeleterFree<values_st>> values(
-      reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
-  if (values == NULL) {
-    return 0;
-  }
-
-  struct private_key *priv = private_key_from_external(out_private_key);
-
-  uint8_t augmented_entropy[MLDSA_SEED_BYTES + 2];
-  OPENSSL_memcpy(augmented_entropy, entropy, MLDSA_SEED_BYTES);
-  // The k and l parameters are appended to the seed.
-  augmented_entropy[MLDSA_SEED_BYTES] = K;
-  augmented_entropy[MLDSA_SEED_BYTES + 1] = L;
-  uint8_t expanded_seed[RHO_BYTES + SIGMA_BYTES + K_BYTES];
-  BORINGSSL_keccak(expanded_seed, sizeof(expanded_seed), augmented_entropy,
-                   sizeof(augmented_entropy), boringssl_shake256);
-  const uint8_t *const rho = expanded_seed;
-  const uint8_t *const sigma = expanded_seed + RHO_BYTES;
-  const uint8_t *const k = expanded_seed + RHO_BYTES + SIGMA_BYTES;
-  // rho is public.
-  CONSTTIME_DECLASSIFY(rho, RHO_BYTES);
-  OPENSSL_memcpy(values->pub.rho, rho, sizeof(values->pub.rho));
-  OPENSSL_memcpy(priv->rho, rho, sizeof(priv->rho));
-  OPENSSL_memcpy(priv->k, k, sizeof(priv->k));
-
-  matrix_expand(&values->a_ntt, rho);
-  vector_expand_short(&priv->s1, &priv->s2, sigma);
-
-  OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
-  vectorl_ntt(&values->s1_ntt);
-
-  matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
-  vectork_inverse_ntt(&values->t);
-  vectork_add(&values->t, &values->t, &priv->s2);
-
-  vectork_power2_round(&values->pub.t1, &priv->t0, &values->t);
-  // t1 is public.
-  CONSTTIME_DECLASSIFY(&values->pub.t1, sizeof(values->pub.t1));
-
-  CBB cbb;
-  CBB_init_fixed(&cbb, out_encoded_public_key, MLDSA65_PUBLIC_KEY_BYTES);
-  if (!mldsa_marshal_public_key(&cbb, &values->pub)) {
-    return 0;
-  }
-  assert(CBB_len(&cbb) == MLDSA65_PUBLIC_KEY_BYTES);
-
-  BORINGSSL_keccak(priv->public_key_hash, sizeof(priv->public_key_hash),
-                   out_encoded_public_key, MLDSA65_PUBLIC_KEY_BYTES,
-                   boringssl_shake256);
-
-  return 1;
+  return mldsa_generate_key_external_entropy(
+      out_encoded_public_key,
+      mldsa65_private_key_from_external(out_private_key), entropy);
 }
 
 int MLDSA65_public_from_private(struct MLDSA65_public_key *out_public_key,
                                 const struct MLDSA65_private_key *private_key) {
-  // Intermediate values, allocated on the heap to allow use when there is a
-  // limited amount of stack.
-  struct values_st {
-    matrix a_ntt;
-    vectorl s1_ntt;
-    vectork t;
-    vectork t0;
-  };
-  std::unique_ptr<values_st, DeleterFree<values_st>> values(
-      reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
-  if (values == NULL) {
-    return 0;
-  }
-
-  const struct private_key *priv = private_key_from_external(private_key);
-  struct public_key *pub = public_key_from_external(out_public_key);
-
-  OPENSSL_memcpy(pub->rho, priv->rho, sizeof(pub->rho));
-  OPENSSL_memcpy(pub->public_key_hash, priv->public_key_hash,
-                 sizeof(pub->public_key_hash));
-
-  matrix_expand(&values->a_ntt, priv->rho);
-
-  OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
-  vectorl_ntt(&values->s1_ntt);
-
-  matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
-  vectork_inverse_ntt(&values->t);
-  vectork_add(&values->t, &values->t, &priv->s2);
-
-  vectork_power2_round(&pub->t1, &values->t0, &values->t);
-  return 1;
+  return mldsa_public_from_private(
+      mldsa65_public_key_from_external(out_public_key),
+      mldsa65_private_key_from_external(private_key));
 }
 
-// FIPS 204, Algorithm 7 (`ML-DSA.Sign_internal`). Returns 1 on success and 0 on
-// failure.
 int MLDSA65_sign_internal(
     uint8_t out_encoded_signature[MLDSA65_SIGNATURE_BYTES],
     const struct MLDSA65_private_key *private_key, const uint8_t *msg,
     size_t msg_len, const uint8_t *context_prefix, size_t context_prefix_len,
     const uint8_t *context, size_t context_len,
     const uint8_t randomizer[MLDSA_SIGNATURE_RANDOMIZER_BYTES]) {
-  const struct private_key *priv = private_key_from_external(private_key);
-
-  uint8_t mu[MU_BYTES];
-  struct BORINGSSL_keccak_st keccak_ctx;
-  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
-  BORINGSSL_keccak_absorb(&keccak_ctx, priv->public_key_hash,
-                          sizeof(priv->public_key_hash));
-  BORINGSSL_keccak_absorb(&keccak_ctx, context_prefix, context_prefix_len);
-  BORINGSSL_keccak_absorb(&keccak_ctx, context, context_len);
-  BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
-  BORINGSSL_keccak_squeeze(&keccak_ctx, mu, MU_BYTES);
-
-  uint8_t rho_prime[RHO_PRIME_BYTES];
-  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
-  BORINGSSL_keccak_absorb(&keccak_ctx, priv->k, sizeof(priv->k));
-  BORINGSSL_keccak_absorb(&keccak_ctx, randomizer,
-                          MLDSA_SIGNATURE_RANDOMIZER_BYTES);
-  BORINGSSL_keccak_absorb(&keccak_ctx, mu, MU_BYTES);
-  BORINGSSL_keccak_squeeze(&keccak_ctx, rho_prime, RHO_PRIME_BYTES);
-
-  // Intermediate values, allocated on the heap to allow use when there is a
-  // limited amount of stack.
-  struct values_st {
-    struct signature sign;
-    vectorl s1_ntt;
-    vectork s2_ntt;
-    vectork t0_ntt;
-    matrix a_ntt;
-    vectorl y;
-    vectork w;
-    vectork w1;
-    vectorl cs1;
-    vectork cs2;
-  };
-  std::unique_ptr<values_st, DeleterFree<values_st>> values(
-      reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
-  if (values == NULL) {
-    return 0;
-  }
-  OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
-  vectorl_ntt(&values->s1_ntt);
-
-  OPENSSL_memcpy(&values->s2_ntt, &priv->s2, sizeof(values->s2_ntt));
-  vectork_ntt(&values->s2_ntt);
-
-  OPENSSL_memcpy(&values->t0_ntt, &priv->t0, sizeof(values->t0_ntt));
-  vectork_ntt(&values->t0_ntt);
-
-  matrix_expand(&values->a_ntt, priv->rho);
-
-  // kappa must not exceed 2**16/L = 13107. But the probability of it exceeding
-  // even 1000 iterations is vanishingly small.
-  for (size_t kappa = 0;; kappa += L) {
-    vectorl_expand_mask(&values->y, rho_prime, kappa);
-
-    vectorl *y_ntt = &values->cs1;
-    OPENSSL_memcpy(y_ntt, &values->y, sizeof(*y_ntt));
-    vectorl_ntt(y_ntt);
-
-    matrix_mult(&values->w, &values->a_ntt, y_ntt);
-    vectork_inverse_ntt(&values->w);
-
-    vectork_high_bits(&values->w1, &values->w);
-    uint8_t w1_encoded[128 * K];
-    w1_encode(w1_encoded, &values->w1);
-
-    BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
-    BORINGSSL_keccak_absorb(&keccak_ctx, mu, MU_BYTES);
-    BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
-    BORINGSSL_keccak_squeeze(&keccak_ctx, values->sign.c_tilde,
-                             2 * LAMBDA_BYTES);
-
-    scalar c_ntt;
-    scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde,
-                                  sizeof(values->sign.c_tilde));
-    scalar_ntt(&c_ntt);
-
-    vectorl_mult_scalar(&values->cs1, &values->s1_ntt, &c_ntt);
-    vectorl_inverse_ntt(&values->cs1);
-    vectork_mult_scalar(&values->cs2, &values->s2_ntt, &c_ntt);
-    vectork_inverse_ntt(&values->cs2);
-
-    vectorl_add(&values->sign.z, &values->y, &values->cs1);
-
-    vectork *r0 = &values->w1;
-    vectork_sub(r0, &values->w, &values->cs2);
-    vectork_low_bits(r0, r0);
-
-    // Leaking the fact that a signature was rejected is fine as the next
-    // attempt at a signature will be (indistinguishable from) independent of
-    // this one. Note, however, that we additionally leak which of the two
-    // branches rejected the signature. Section 5.5 of
-    // https://pq-crystals.org/dilithium/data/dilithium-specification-round3.pdf
-    // describes this leak as OK. Note we leak less than what is described by
-    // the paper; we do not reveal which coefficient violated the bound, and we
-    // hide which of the |z_max| or |r0_max| bound failed. See also
-    // https://boringssl-review.googlesource.com/c/boringssl/+/67747/comment/2bbab0fa_d241d35a/
-    uint32_t z_max = vectorl_max(&values->sign.z);
-    uint32_t r0_max = vectork_max_signed(r0);
-    if (constant_time_declassify_w(
-            constant_time_ge_w(z_max, kGamma1 - BETA) |
-            constant_time_ge_w(r0_max, kGamma2 - BETA))) {
-      continue;
-    }
-
-    vectork *ct0 = &values->w1;
-    vectork_mult_scalar(ct0, &values->t0_ntt, &c_ntt);
-    vectork_inverse_ntt(ct0);
-    vectork_make_hint(&values->sign.h, ct0, &values->cs2, &values->w);
-
-    // See above.
-    uint32_t ct0_max = vectork_max(ct0);
-    size_t h_ones = vectork_count_ones(&values->sign.h);
-    if (constant_time_declassify_w(constant_time_ge_w(ct0_max, kGamma2) |
-                                   constant_time_lt_w(OMEGA, h_ones))) {
-      continue;
-    }
-
-    // Although computed with the private key, the signature is public.
-    CONSTTIME_DECLASSIFY(values->sign.c_tilde, sizeof(values->sign.c_tilde));
-    CONSTTIME_DECLASSIFY(&values->sign.z, sizeof(values->sign.z));
-    CONSTTIME_DECLASSIFY(&values->sign.h, sizeof(values->sign.h));
-
-    CBB cbb;
-    CBB_init_fixed(&cbb, out_encoded_signature, MLDSA65_SIGNATURE_BYTES);
-    if (!mldsa_marshal_signature(&cbb, &values->sign)) {
-      return 0;
-    }
-
-    BSSL_CHECK(CBB_len(&cbb) == MLDSA65_SIGNATURE_BYTES);
-    return 1;
-  }
+  return mldsa_sign_internal(out_encoded_signature,
+                             mldsa65_private_key_from_external(private_key),
+                             msg, msg_len, context_prefix, context_prefix_len,
+                             context, context_len, randomizer);
 }
 
-// mldsa signature in randomized mode, filling the random bytes with
+// ML-DSA signature in randomized mode, filling the random bytes with
 // |RAND_bytes|. Returns 1 on success and 0 on failure.
 int MLDSA65_sign(uint8_t out_encoded_signature[MLDSA65_SIGNATURE_BYTES],
                  const struct MLDSA65_private_key *private_key,
@@ -1564,108 +1771,18 @@
                                  context, context_len);
 }
 
-// FIPS 204, Algorithm 8 (`ML-DSA.Verify_internal`).
 int MLDSA65_verify_internal(
     const struct MLDSA65_public_key *public_key,
     const uint8_t encoded_signature[MLDSA65_SIGNATURE_BYTES],
     const uint8_t *msg, size_t msg_len, const uint8_t *context_prefix,
     size_t context_prefix_len, const uint8_t *context, size_t context_len) {
-  // Intermediate values, allocated on the heap to allow use when there is a
-  // limited amount of stack.
-  struct values_st {
-    struct signature sign;
-    matrix a_ntt;
-    vectorl z_ntt;
-    vectork az_ntt;
-    vectork ct1_ntt;
-  };
-  std::unique_ptr<values_st, DeleterFree<values_st>> values(
-      reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
-  if (values == NULL) {
-    return 0;
-  }
-
-  const struct public_key *pub = public_key_from_external(public_key);
-
-  CBS cbs;
-  CBS_init(&cbs, encoded_signature, MLDSA65_SIGNATURE_BYTES);
-  if (!mldsa_parse_signature(&values->sign, &cbs)) {
-    return 0;
-  }
-
-  matrix_expand(&values->a_ntt, pub->rho);
-
-  uint8_t mu[MU_BYTES];
-  struct BORINGSSL_keccak_st keccak_ctx;
-  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
-  BORINGSSL_keccak_absorb(&keccak_ctx, pub->public_key_hash,
-                          sizeof(pub->public_key_hash));
-  BORINGSSL_keccak_absorb(&keccak_ctx, context_prefix, context_prefix_len);
-  BORINGSSL_keccak_absorb(&keccak_ctx, context, context_len);
-  BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
-  BORINGSSL_keccak_squeeze(&keccak_ctx, mu, MU_BYTES);
-
-  scalar c_ntt;
-  scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde,
-                                sizeof(values->sign.c_tilde));
-  scalar_ntt(&c_ntt);
-
-  OPENSSL_memcpy(&values->z_ntt, &values->sign.z, sizeof(values->z_ntt));
-  vectorl_ntt(&values->z_ntt);
-
-  matrix_mult(&values->az_ntt, &values->a_ntt, &values->z_ntt);
-
-  vectork_scale_power2_round(&values->ct1_ntt, &pub->t1);
-  vectork_ntt(&values->ct1_ntt);
-
-  vectork_mult_scalar(&values->ct1_ntt, &values->ct1_ntt, &c_ntt);
-
-  vectork *const w1 = &values->az_ntt;
-  vectork_sub(w1, &values->az_ntt, &values->ct1_ntt);
-  vectork_inverse_ntt(w1);
-
-  vectork_use_hint_vartime(w1, &values->sign.h, w1);
-  uint8_t w1_encoded[128 * K];
-  w1_encode(w1_encoded, w1);
-
-  uint8_t c_tilde[2 * LAMBDA_BYTES];
-  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
-  BORINGSSL_keccak_absorb(&keccak_ctx, mu, MU_BYTES);
-  BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
-  BORINGSSL_keccak_squeeze(&keccak_ctx, c_tilde, 2 * LAMBDA_BYTES);
-
-  uint32_t z_max = vectorl_max(&values->sign.z);
-  return z_max < kGamma1 - BETA &&
-         OPENSSL_memcmp(c_tilde, values->sign.c_tilde, 2 * LAMBDA_BYTES) == 0;
+  return mldsa_verify_internal<6, 5>(
+      mldsa65_public_key_from_external(public_key), encoded_signature, msg,
+      msg_len, context_prefix, context_prefix_len, context, context_len);
 }
 
-/* Serialization of keys. */
-
 int MLDSA65_marshal_public_key(CBB *out,
                                const struct MLDSA65_public_key *public_key) {
-  return mldsa_marshal_public_key(out, public_key_from_external(public_key));
-}
-
-int MLDSA65_parse_public_key(struct MLDSA65_public_key *public_key, CBS *in) {
-  struct public_key *pub = public_key_from_external(public_key);
-  CBS orig_in = *in;
-  if (!mldsa_parse_public_key(pub, in) || CBS_len(in) != 0) {
-    return 0;
-  }
-
-  // Compute pre-cached values.
-  BORINGSSL_keccak(pub->public_key_hash, sizeof(pub->public_key_hash),
-                   CBS_data(&orig_in), CBS_len(&orig_in), boringssl_shake256);
-  return 1;
-}
-
-int MLDSA65_marshal_private_key(CBB *out,
-                                const struct MLDSA65_private_key *private_key) {
-  return mldsa_marshal_private_key(out, private_key_from_external(private_key));
-}
-
-int MLDSA65_parse_private_key(struct MLDSA65_private_key *private_key,
-                              CBS *in) {
-  struct private_key *priv = private_key_from_external(private_key);
-  return mldsa_parse_private_key(priv, in) && CBS_len(in) == 0;
+  return mldsa_marshal_public_key(out,
+                                  mldsa65_public_key_from_external(public_key));
 }
diff --git a/crypto/mldsa/mldsa_test.cc b/crypto/mldsa/mldsa_test.cc
index c9b0828..56035ef 100644
--- a/crypto/mldsa/mldsa_test.cc
+++ b/crypto/mldsa/mldsa_test.cc
@@ -264,6 +264,11 @@
   FileTestGTest("crypto/mldsa/mldsa_nist_keygen_tests.txt", MLDSAKeyGenTest);
 }
 
+template <typename PrivateKey, int (*ParsePrivateKey)(PrivateKey *, CBS *),
+          size_t SignatureBytes,
+          int (*SignInternal)(uint8_t *, const PrivateKey *, const uint8_t *,
+                              size_t, const uint8_t *, size_t, const uint8_t *,
+                              size_t, const uint8_t *)>
 static void MLDSAWycheproofSignTest(FileTest *t) {
   std::vector<uint8_t> private_key_bytes, msg, expected_signature, context;
   ASSERT_TRUE(t->GetInstructionBytes(&private_key_bytes, "privateKey"));
@@ -278,8 +283,8 @@
 
   CBS cbs;
   CBS_init(&cbs, private_key_bytes.data(), private_key_bytes.size());
-  auto priv = std::make_unique<MLDSA65_private_key>();
-  const int priv_ok = MLDSA65_parse_private_key(priv.get(), &cbs);
+  auto priv = std::make_unique<PrivateKey>();
+  const int priv_ok = ParsePrivateKey(priv.get(), &cbs);
 
   if (!priv_ok) {
     ASSERT_TRUE(result != "valid");
@@ -295,21 +300,25 @@
   }
 
   const uint8_t zero_randomizer[MLDSA_SIGNATURE_RANDOMIZER_BYTES] = {0};
-  std::vector<uint8_t> signature(MLDSA65_SIGNATURE_BYTES);
+  std::vector<uint8_t> signature(SignatureBytes);
   const uint8_t context_prefix[2] = {0, static_cast<uint8_t>(context.size())};
-  EXPECT_TRUE(MLDSA65_sign_internal(
-      signature.data(), priv.get(), msg.data(), msg.size(), context_prefix,
-      sizeof(context_prefix), context.data(), context.size(), zero_randomizer));
+  EXPECT_TRUE(SignInternal(signature.data(), priv.get(), msg.data(), msg.size(),
+                           context_prefix, sizeof(context_prefix),
+                           context.data(), context.size(), zero_randomizer));
 
   EXPECT_EQ(Bytes(signature), Bytes(expected_signature));
 }
 
-TEST(MLDSATest, WycheproofSignTests) {
+TEST(MLDSATest, WycheproofSignTests65) {
   FileTestGTest(
       "third_party/wycheproof_testvectors/mldsa_65_standard_sign_test.txt",
-      MLDSAWycheproofSignTest);
+      MLDSAWycheproofSignTest<MLDSA65_private_key, MLDSA65_parse_private_key,
+                              MLDSA65_SIGNATURE_BYTES, MLDSA65_sign_internal>);
 }
 
+template <typename PublicKey, int (*ParsePublicKey)(PublicKey *, CBS *),
+          int (*Verify)(const PublicKey *, const uint8_t *, size_t,
+                        const uint8_t *, size_t, const uint8_t *, size_t)>
 static void MLDSAWycheproofVerifyTest(FileTest *t) {
   std::vector<uint8_t> public_key_bytes, msg, signature, context;
   ASSERT_TRUE(t->GetInstructionBytes(&public_key_bytes, "publicKey"));
@@ -324,8 +333,8 @@
 
   CBS cbs;
   CBS_init(&cbs, public_key_bytes.data(), public_key_bytes.size());
-  auto pub = std::make_unique<MLDSA65_public_key>();
-  const int pub_ok = MLDSA65_parse_public_key(pub.get(), &cbs);
+  auto pub = std::make_unique<PublicKey>();
+  const int pub_ok = ParsePublicKey(pub.get(), &cbs);
 
   if (!pub_ok) {
     EXPECT_EQ(flags, "IncorrectPublicKeyLength");
@@ -333,8 +342,8 @@
   }
 
   const int sig_ok =
-      MLDSA65_verify(pub.get(), signature.data(), signature.size(), msg.data(),
-                     msg.size(), context.data(), context.size());
+      Verify(pub.get(), signature.data(), signature.size(), msg.data(),
+             msg.size(), context.data(), context.size());
   if (!sig_ok) {
     EXPECT_EQ(result, "invalid");
   } else {
@@ -342,10 +351,12 @@
   }
 }
 
-TEST(MLDSATest, WycheproofVerifyTests) {
+
+TEST(MLDSATest, WycheproofVerifyTests65) {
   FileTestGTest(
       "third_party/wycheproof_testvectors/mldsa_65_standard_verify_test.txt",
-      MLDSAWycheproofVerifyTest);
+      MLDSAWycheproofVerifyTest<MLDSA65_public_key, MLDSA65_parse_public_key,
+                                MLDSA65_verify>);
 }
 
 }  // namespace
diff --git a/include/openssl/mldsa.h b/include/openssl/mldsa.h
index a0a7560..3be22f3 100644
--- a/include/openssl/mldsa.h
+++ b/include/openssl/mldsa.h
@@ -22,12 +22,18 @@
 #endif
 
 
-// ML-DSA-65.
+// ML-DSA.
 //
 // This implements the Module-Lattice-Based Digital Signature Standard from
 // https://csrc.nist.gov/pubs/fips/204/final
 
 
+// MLDSA_SEED_BYTES is the number of bytes in an ML-DSA seed value.
+#define MLDSA_SEED_BYTES 32
+
+
+// ML-DSA-65.
+
 // MLDSA65_private_key contains an ML-DSA-65 private key. The contents of this
 // object should never leave the address space since the format is unstable.
 struct MLDSA65_private_key {
@@ -58,9 +64,6 @@
 // signature.
 #define MLDSA65_SIGNATURE_BYTES 3309
 
-// MLDSA_SEED_BYTES is the number of bytes in an ML-DSA seed value.
-#define MLDSA_SEED_BYTES 32
-
 // MLDSA65_generate_key generates a random public/private key pair, writes the
 // encoded public key to |out_encoded_public_key|, writes the seed to
 // |out_seed|, and sets |out_private_key| to the private key. Returns 1 on
@@ -106,9 +109,6 @@
                                   size_t msg_len, const uint8_t *context,
                                   size_t context_len);
 
-
-// Serialisation of keys.
-
 // MLDSA65_marshal_public_key serializes |public_key| to |out| in the standard
 // format for ML-DSA-65 public keys. It returns 1 on success or 0 on
 // allocation error.