mldsa: abstract over the matrix size.
Given the likelihood that we'll need other ML-DSA configurations,
convert ML-DSA to use templates so that other matrix sizes can
be supported.
Change-Id: Iadbdde0a9b36414a256c6103ad549c91d85c0fe7
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/71768
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: Adam Langley <agl@google.com>
diff --git a/crypto/mldsa/internal.h b/crypto/mldsa/internal.h
index 08ff329..92f4ae1 100644
--- a/crypto/mldsa/internal.h
+++ b/crypto/mldsa/internal.h
@@ -27,6 +27,9 @@
// random entropy necessary to generate a signature in randomized mode.
#define MLDSA_SIGNATURE_RANDOMIZER_BYTES 32
+
+// ML-DSA-65
+
// MLDSA65_generate_key_external_entropy generates a public/private key pair
// using the given seed, writes the encoded public key to
// |out_encoded_public_key| and sets |out_private_key| to the private key.
diff --git a/crypto/mldsa/mldsa.cc b/crypto/mldsa/mldsa.cc
index a5ada3d..bfb00b7 100644
--- a/crypto/mldsa/mldsa.cc
+++ b/crypto/mldsa/mldsa.cc
@@ -25,49 +25,121 @@
#include "../keccak/internal.h"
#include "./internal.h"
-#define DEGREE 256
-#define K 6
-#define L 5
-#define ETA 4
-#define TAU 49
-#define BETA 196
-#define OMEGA 55
+namespace {
-#define RHO_BYTES 32
-#define SIGMA_BYTES 64
-#define K_BYTES 32
-#define TR_BYTES 64
-#define MU_BYTES 64
-#define RHO_PRIME_BYTES 64
-#define LAMBDA_BITS 192
-#define LAMBDA_BYTES (LAMBDA_BITS / 8)
+constexpr int kDegree = 256;
+constexpr int kRhoBytes = 32;
+constexpr int kSigmaBytes = 64;
+constexpr int kKBytes = 32;
+constexpr int kTrBytes = 64;
+constexpr int kMuBytes = 64;
+constexpr int kRhoPrimeBytes = 64;
// 2^23 - 2^13 + 1
-static const uint32_t kPrime = 8380417;
+constexpr uint32_t kPrime = 8380417;
// Inverse of -kPrime modulo 2^32
-static const uint32_t kPrimeNegInverse = 4236238847;
-static const int kDroppedBits = 13;
-static const uint32_t kHalfPrime = (8380417 - 1) / 2;
-static const uint32_t kGamma1 = 1 << 19;
-static const uint32_t kGamma2 = (8380417 - 1) / 32;
+constexpr uint32_t kPrimeNegInverse = 4236238847;
+constexpr int kDroppedBits = 13;
+constexpr uint32_t kHalfPrime = (kPrime - 1) / 2;
+constexpr uint32_t kGamma2 = (kPrime - 1) / 32;
// 256^-1 mod kPrime, in Montgomery form.
-static const uint32_t kInverseDegreeMontgomery = 41978;
+constexpr uint32_t kInverseDegreeMontgomery = 41978;
+
+// Constants that vary depending on ML-DSA size.
+//
+// These are implemented as templates which take the K parameter to distinguish
+// the ML-DSA sizes. (At the time of writing, `if constexpr` was not available.)
+//
+// TODO(crbug.com/42290600): Switch this to `if constexpr` when C++17 is
+// available.
+
+template <int K>
+constexpr size_t public_key_bytes();
+
+template <>
+constexpr size_t public_key_bytes<6>() {
+ return MLDSA65_PUBLIC_KEY_BYTES;
+}
+
+template <int K>
+constexpr size_t signature_bytes();
+
+template <>
+constexpr size_t signature_bytes<6>() {
+ return MLDSA65_SIGNATURE_BYTES;
+}
+
+template <int K>
+constexpr int tau();
+
+template <>
+constexpr int tau<6>() {
+ return 49;
+}
+
+template <int K>
+constexpr int lambda_bytes();
+
+template <>
+constexpr int lambda_bytes<6>() {
+ return 192 / 8;
+}
+
+template <int K>
+constexpr int gamma1();
+
+template <>
+constexpr int gamma1<6>() {
+ return 1 << 19;
+}
+
+template <int K>
+constexpr int beta();
+
+template <>
+constexpr int beta<6>() {
+ return 196;
+}
+
+template <int K>
+constexpr int omega();
+
+template <>
+constexpr int omega<6>() {
+ return 55;
+}
+
+template <int K>
+constexpr int eta();
+
+template <>
+constexpr int eta<6>() {
+ return 4;
+}
+
+template <int K>
+constexpr int plus_minus_eta_bitlen();
+
+template <>
+constexpr int plus_minus_eta_bitlen<6>() {
+ return 4;
+}
+
+// Fundamental types.
typedef struct scalar {
- uint32_t c[DEGREE];
+ uint32_t c[kDegree];
} scalar;
-typedef struct vectork {
+template <int K>
+struct vector {
scalar v[K];
-} vectork;
+};
-typedef struct vectorl {
- scalar v[L];
-} vectorl;
-
-typedef struct matrix {
+template <int K, int L>
+struct matrix {
scalar v[K][L];
-} matrix;
+};
/* Arithmetic */
@@ -173,13 +245,13 @@
}
static void scalar_add(scalar *out, const scalar *lhs, const scalar *rhs) {
- for (int i = 0; i < DEGREE; i++) {
+ for (int i = 0; i < kDegree; i++) {
out->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
}
}
static void scalar_sub(scalar *out, const scalar *lhs, const scalar *rhs) {
- for (int i = 0; i < DEGREE; i++) {
+ for (int i = 0; i < kDegree; i++) {
out->c[i] = mod_sub(lhs->c[i], rhs->c[i]);
}
}
@@ -195,7 +267,7 @@
// Multiply two scalars in the number theoretically transformed state.
static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) {
- for (int i = 0; i < DEGREE; i++) {
+ for (int i = 0; i < kDegree; i++) {
out->c[i] = reduce_montgomery((uint64_t)lhs->c[i] * (uint64_t)rhs->c[i]);
}
}
@@ -206,8 +278,8 @@
static void scalar_ntt(scalar *s) {
// Step: 1, 2, 4, 8, ..., 128
// Offset: 128, 64, 32, 16, ..., 1
- int offset = DEGREE;
- for (int step = 1; step < DEGREE; step <<= 1) {
+ int offset = kDegree;
+ for (int step = 1; step < kDegree; step <<= 1) {
offset >>= 1;
int k = 0;
for (int i = 0; i < step; i++) {
@@ -234,8 +306,8 @@
static void scalar_inverse_ntt(scalar *s) {
// Step: 128, 64, 32, 16, ..., 1
// Offset: 1, 2, 4, 8, ..., 128
- int step = DEGREE;
- for (int offset = 1; offset < DEGREE; offset <<= 1) {
+ int step = kDegree;
+ for (int offset = 1; offset < kDegree; offset <<= 1) {
step >>= 1;
int k = 0;
for (int i = 0; i < step; i++) {
@@ -258,72 +330,59 @@
k += 2 * offset;
}
}
- for (int i = 0; i < DEGREE; i++) {
+ for (int i = 0; i < kDegree; i++) {
s->c[i] = reduce_montgomery((uint64_t)s->c[i] *
(uint64_t)kInverseDegreeMontgomery);
}
}
-static void vectork_zero(vectork *out) { OPENSSL_memset(out, 0, sizeof(*out)); }
+template <int X>
+static void vector_zero(vector<X> *out) {
+ OPENSSL_memset(out, 0, sizeof(*out));
+}
-static void vectork_add(vectork *out, const vectork *lhs, const vectork *rhs) {
- for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_add(vector<X> *out, const vector<X> *lhs,
+ const vector<X> *rhs) {
+ for (int i = 0; i < X; i++) {
scalar_add(&out->v[i], &lhs->v[i], &rhs->v[i]);
}
}
-static void vectork_sub(vectork *out, const vectork *lhs, const vectork *rhs) {
- for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_sub(vector<X> *out, const vector<X> *lhs,
+ const vector<X> *rhs) {
+ for (int i = 0; i < X; i++) {
scalar_sub(&out->v[i], &lhs->v[i], &rhs->v[i]);
}
}
-static void vectork_mult_scalar(vectork *out, const vectork *lhs,
- const scalar *rhs) {
- for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_mult_scalar(vector<X> *out, const vector<X> *lhs,
+ const scalar *rhs) {
+ for (int i = 0; i < X; i++) {
scalar_mult(&out->v[i], &lhs->v[i], rhs);
}
}
-static void vectork_ntt(vectork *a) {
- for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_ntt(vector<X> *a) {
+ for (int i = 0; i < X; i++) {
scalar_ntt(&a->v[i]);
}
}
-static void vectork_inverse_ntt(vectork *a) {
- for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_inverse_ntt(vector<X> *a) {
+ for (int i = 0; i < X; i++) {
scalar_inverse_ntt(&a->v[i]);
}
}
-static void vectorl_add(vectorl *out, const vectorl *lhs, const vectorl *rhs) {
- for (int i = 0; i < L; i++) {
- scalar_add(&out->v[i], &lhs->v[i], &rhs->v[i]);
- }
-}
-
-static void vectorl_mult_scalar(vectorl *out, const vectorl *lhs,
- const scalar *rhs) {
- for (int i = 0; i < L; i++) {
- scalar_mult(&out->v[i], &lhs->v[i], rhs);
- }
-}
-
-static void vectorl_ntt(vectorl *a) {
- for (int i = 0; i < L; i++) {
- scalar_ntt(&a->v[i]);
- }
-}
-
-static void vectorl_inverse_ntt(vectorl *a) {
- for (int i = 0; i < L; i++) {
- scalar_inverse_ntt(&a->v[i]);
- }
-}
-
-static void matrix_mult(vectork *out, const matrix *m, const vectorl *a) {
- vectork_zero(out);
+template <int K, int L>
+static void matrix_mult(vector<K> *out, const matrix<K, L> *m,
+ const vector<L> *a) {
+ vector_zero(out);
for (int i = 0; i < K; i++) {
for (int j = 0; j < L; j++) {
scalar product;
@@ -435,38 +494,38 @@
}
static void scalar_power2_round(scalar *s1, scalar *s0, const scalar *s) {
- for (int i = 0; i < DEGREE; i++) {
+ for (int i = 0; i < kDegree; i++) {
power2_round(&s1->c[i], &s0->c[i], s->c[i]);
}
}
static void scalar_scale_power2_round(scalar *out, const scalar *in) {
- for (int i = 0; i < DEGREE; i++) {
+ for (int i = 0; i < kDegree; i++) {
scale_power2_round(&out->c[i], in->c[i]);
}
}
static void scalar_high_bits(scalar *out, const scalar *in) {
- for (int i = 0; i < DEGREE; i++) {
+ for (int i = 0; i < kDegree; i++) {
out->c[i] = high_bits(in->c[i]);
}
}
static void scalar_low_bits(scalar *out, const scalar *in) {
- for (int i = 0; i < DEGREE; i++) {
+ for (int i = 0; i < kDegree; i++) {
out->c[i] = low_bits(in->c[i]);
}
}
static void scalar_max(uint32_t *max, const scalar *s) {
- for (int i = 0; i < DEGREE; i++) {
+ for (int i = 0; i < kDegree; i++) {
uint32_t abs = abs_mod_prime(s->c[i]);
*max = maximum(*max, abs);
}
}
static void scalar_max_signed(uint32_t *max, const scalar *s) {
- for (int i = 0; i < DEGREE; i++) {
+ for (int i = 0; i < kDegree; i++) {
uint32_t abs = abs_signed(s->c[i]);
*max = maximum(*max, abs);
}
@@ -474,98 +533,100 @@
static void scalar_make_hint(scalar *out, const scalar *ct0, const scalar *cs2,
const scalar *w) {
- for (int i = 0; i < DEGREE; i++) {
+ for (int i = 0; i < kDegree; i++) {
out->c[i] = make_hint(ct0->c[i], cs2->c[i], w->c[i]);
}
}
static void scalar_use_hint_vartime(scalar *out, const scalar *h,
const scalar *r) {
- for (int i = 0; i < DEGREE; i++) {
+ for (int i = 0; i < kDegree; i++) {
out->c[i] = use_hint_vartime(h->c[i], r->c[i]);
}
}
-static void vectork_power2_round(vectork *t1, vectork *t0, const vectork *t) {
- for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_power2_round(vector<X> *t1, vector<X> *t0,
+ const vector<X> *t) {
+ for (int i = 0; i < X; i++) {
scalar_power2_round(&t1->v[i], &t0->v[i], &t->v[i]);
}
}
-static void vectork_scale_power2_round(vectork *out, const vectork *in) {
- for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_scale_power2_round(vector<X> *out, const vector<X> *in) {
+ for (int i = 0; i < X; i++) {
scalar_scale_power2_round(&out->v[i], &in->v[i]);
}
}
-static void vectork_high_bits(vectork *out, const vectork *in) {
- for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_high_bits(vector<X> *out, const vector<X> *in) {
+ for (int i = 0; i < X; i++) {
scalar_high_bits(&out->v[i], &in->v[i]);
}
}
-static void vectork_low_bits(vectork *out, const vectork *in) {
- for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_low_bits(vector<X> *out, const vector<X> *in) {
+ for (int i = 0; i < X; i++) {
scalar_low_bits(&out->v[i], &in->v[i]);
}
}
-static uint32_t vectork_max(const vectork *a) {
+template <int X>
+static uint32_t vector_max(const vector<X> *a) {
uint32_t max = 0;
- for (int i = 0; i < K; i++) {
+ for (int i = 0; i < X; i++) {
scalar_max(&max, &a->v[i]);
}
return max;
}
-static uint32_t vectork_max_signed(const vectork *a) {
+template <int X>
+static uint32_t vector_max_signed(const vector<X> *a) {
uint32_t max = 0;
- for (int i = 0; i < K; i++) {
+ for (int i = 0; i < X; i++) {
scalar_max_signed(&max, &a->v[i]);
}
return max;
}
// The input vector contains only zeroes and ones.
-static size_t vectork_count_ones(const vectork *a) {
+template <int X>
+static size_t vector_count_ones(const vector<X> *a) {
size_t count = 0;
- for (int i = 0; i < K; i++) {
- for (int j = 0; j < DEGREE; j++) {
+ for (int i = 0; i < X; i++) {
+ for (int j = 0; j < kDegree; j++) {
count += a->v[i].c[j];
}
}
return count;
}
-static void vectork_make_hint(vectork *out, const vectork *ct0,
- const vectork *cs2, const vectork *w) {
- for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_make_hint(vector<X> *out, const vector<X> *ct0,
+ const vector<X> *cs2, const vector<X> *w) {
+ for (int i = 0; i < X; i++) {
scalar_make_hint(&out->v[i], &ct0->v[i], &cs2->v[i], &w->v[i]);
}
}
-static void vectork_use_hint_vartime(vectork *out, const vectork *h,
- const vectork *r) {
- for (int i = 0; i < K; i++) {
+template <int X>
+static void vector_use_hint_vartime(vector<X> *out, const vector<X> *h,
+ const vector<X> *r) {
+ for (int i = 0; i < X; i++) {
scalar_use_hint_vartime(&out->v[i], &h->v[i], &r->v[i]);
}
}
-static uint32_t vectorl_max(const vectorl *a) {
- uint32_t max = 0;
- for (int i = 0; i < L; i++) {
- scalar_max(&max, &a->v[i]);
- }
- return max;
-}
-
/* Bit packing */
// FIPS 204, Algorithm 16 (`SimpleBitPack`). Specialized to bitlen(b) = 4.
static void scalar_encode_4(uint8_t out[128], const scalar *s) {
// Every two elements lands on a byte boundary.
- static_assert(DEGREE % 2 == 0, "DEGREE must be a multiple of 2");
- for (int i = 0; i < DEGREE / 2; i++) {
+ static_assert(kDegree % 2 == 0, "kDegree must be a multiple of 2");
+ for (int i = 0; i < kDegree / 2; i++) {
uint32_t a = s->c[2 * i];
uint32_t b = s->c[2 * i + 1];
declassify_assert(a < 16);
@@ -577,8 +638,8 @@
// FIPS 204, Algorithm 16 (`SimpleBitPack`). Specialized to bitlen(b) = 10.
static void scalar_encode_10(uint8_t out[320], const scalar *s) {
// Every four elements lands on a byte boundary.
- static_assert(DEGREE % 4 == 0, "DEGREE must be a multiple of 4");
- for (int i = 0; i < DEGREE / 4; i++) {
+ static_assert(kDegree % 4 == 0, "kDegree must be a multiple of 4");
+ for (int i = 0; i < kDegree / 4; i++) {
uint32_t a = s->c[4 * i];
uint32_t b = s->c[4 * i + 1];
uint32_t c = s->c[4 * i + 2];
@@ -595,14 +656,13 @@
}
}
-// FIPS 204, Algorithm 17 (`BitPack`). Specialized to bitlen(b) = 4 and b =
-// 2^19.
-static void scalar_encode_signed_4_eta(uint8_t out[128], const scalar *s) {
+// FIPS 204, Algorithm 17 (`BitPack`). Specialized to bitlen(a+b) = 4 and b = 4.
+static void scalar_encode_signed_4_4(uint8_t out[128], const scalar *s) {
// Every two elements lands on a byte boundary.
- static_assert(DEGREE % 2 == 0, "DEGREE must be a multiple of 2");
- for (int i = 0; i < DEGREE / 2; i++) {
- uint32_t a = mod_sub(ETA, s->c[2 * i]);
- uint32_t b = mod_sub(ETA, s->c[2 * i + 1]);
+ static_assert(kDegree % 2 == 0, "kDegree must be a multiple of 2");
+ for (int i = 0; i < kDegree / 2; i++) {
+ uint32_t a = mod_sub(4, s->c[2 * i]);
+ uint32_t b = mod_sub(4, s->c[2 * i + 1]);
declassify_assert(a < 16);
declassify_assert(b < 16);
out[i] = a | (b << 4);
@@ -614,8 +674,8 @@
static void scalar_encode_signed_13_12(uint8_t out[416], const scalar *s) {
static const uint32_t kMax = 1u << 12;
// Every eight elements land on a byte boundary.
- static_assert(DEGREE % 8 == 0, "DEGREE must be a multiple of 8");
- for (int i = 0; i < DEGREE / 8; i++) {
+ static_assert(kDegree % 8 == 0, "kDegree must be a multiple of 8");
+ for (int i = 0; i < kDegree / 8; i++) {
uint32_t a = mod_sub(kMax, s->c[8 * i]);
uint32_t b = mod_sub(kMax, s->c[8 * i + 1]);
uint32_t c = mod_sub(kMax, s->c[8 * i + 2]);
@@ -654,8 +714,8 @@
static void scalar_encode_signed_20_19(uint8_t out[640], const scalar *s) {
static const uint32_t kMax = 1u << 19;
// Every four elements land on a byte boundary.
- static_assert(DEGREE % 4 == 0, "DEGREE must be a multiple of 4");
- for (int i = 0; i < DEGREE / 4; i++) {
+ static_assert(kDegree % 4 == 0, "kDegree must be a multiple of 4");
+ for (int i = 0; i < kDegree / 4; i++) {
uint32_t a = mod_sub(kMax, s->c[4 * i]);
uint32_t b = mod_sub(kMax, s->c[4 * i + 1]);
uint32_t c = mod_sub(kMax, s->c[4 * i + 2]);
@@ -679,8 +739,8 @@
static void scalar_encode_signed(uint8_t *out, const scalar *s, int bits,
uint32_t max) {
if (bits == 4) {
- assert(max == ETA);
- scalar_encode_signed_4_eta(out, s);
+ assert(max == 4);
+ scalar_encode_signed_4_4(out, s);
} else if (bits == 20) {
assert(max == 1u << 19);
scalar_encode_signed_20_19(out, s);
@@ -694,8 +754,8 @@
// FIPS 204, Algorithm 18 (`SimpleBitUnpack`). Specialized for bitlen(b) == 10.
static void scalar_decode_10(scalar *out, const uint8_t in[320]) {
uint32_t v;
- static_assert(DEGREE % 4 == 0, "DEGREE must be a multiple of 4");
- for (int i = 0; i < DEGREE / 4; i++) {
+ static_assert(kDegree % 4 == 0, "kDegree must be a multiple of 4");
+ for (int i = 0; i < kDegree / 4; i++) {
OPENSSL_memcpy(&v, &in[5 * i], sizeof(v));
out->c[4 * i] = v & 0x3ff;
out->c[4 * i + 1] = (v >> 10) & 0x3ff;
@@ -705,13 +765,12 @@
}
// FIPS 204, Algorithm 19 (`BitUnpack`). Specialized to bitlen(a+b) = 4 and b =
-// eta.
-static int scalar_decode_signed_4_eta(scalar *out, const uint8_t in[128]) {
+// 4.
+static int scalar_decode_signed_4_4(scalar *out, const uint8_t in[128]) {
uint32_t v;
- static_assert(DEGREE % 8 == 0, "DEGREE must be a multiple of 8");
- for (int i = 0; i < DEGREE / 8; i++) {
+ static_assert(kDegree % 8 == 0, "kDegree must be a multiple of 8");
+ for (int i = 0; i < kDegree / 8; i++) {
OPENSSL_memcpy(&v, &in[4 * i], sizeof(v));
- static_assert(ETA == 4, "ETA must be 4");
// None of the nibbles may be >= 9. So if the MSB of any nibble is set, none
// of the other bits may be set. First, select all the MSBs.
const uint32_t msbs = v & 0x88888888u;
@@ -723,14 +782,14 @@
return 0;
}
- out->c[i * 8] = mod_sub(ETA, v & 15);
- out->c[i * 8 + 1] = mod_sub(ETA, (v >> 4) & 15);
- out->c[i * 8 + 2] = mod_sub(ETA, (v >> 8) & 15);
- out->c[i * 8 + 3] = mod_sub(ETA, (v >> 12) & 15);
- out->c[i * 8 + 4] = mod_sub(ETA, (v >> 16) & 15);
- out->c[i * 8 + 5] = mod_sub(ETA, (v >> 20) & 15);
- out->c[i * 8 + 6] = mod_sub(ETA, (v >> 24) & 15);
- out->c[i * 8 + 7] = mod_sub(ETA, v >> 28);
+ out->c[i * 8] = mod_sub(4, v & 15);
+ out->c[i * 8 + 1] = mod_sub(4, (v >> 4) & 15);
+ out->c[i * 8 + 2] = mod_sub(4, (v >> 8) & 15);
+ out->c[i * 8 + 3] = mod_sub(4, (v >> 12) & 15);
+ out->c[i * 8 + 4] = mod_sub(4, (v >> 16) & 15);
+ out->c[i * 8 + 5] = mod_sub(4, (v >> 20) & 15);
+ out->c[i * 8 + 6] = mod_sub(4, (v >> 24) & 15);
+ out->c[i * 8 + 7] = mod_sub(4, v >> 28);
}
return 1;
}
@@ -744,8 +803,8 @@
uint32_t a, b, c;
uint8_t d;
- static_assert(DEGREE % 8 == 0, "DEGREE must be a multiple of 8");
- for (int i = 0; i < DEGREE / 8; i++) {
+ static_assert(kDegree % 8 == 0, "kDegree must be a multiple of 8");
+ for (int i = 0; i < kDegree / 8; i++) {
OPENSSL_memcpy(&a, &in[13 * i], sizeof(a));
OPENSSL_memcpy(&b, &in[13 * i + 4], sizeof(b));
OPENSSL_memcpy(&c, &in[13 * i + 8], sizeof(c));
@@ -772,8 +831,8 @@
uint32_t a, b;
uint16_t c;
- static_assert(DEGREE % 4 == 0, "DEGREE must be a multiple of 4");
- for (int i = 0; i < DEGREE / 4; i++) {
+ static_assert(kDegree % 4 == 0, "kDegree must be a multiple of 4");
+ for (int i = 0; i < kDegree / 4; i++) {
OPENSSL_memcpy(&a, &in[10 * i], sizeof(a));
OPENSSL_memcpy(&b, &in[10 * i + 4], sizeof(b));
OPENSSL_memcpy(&c, &in[10 * i + 8], sizeof(c));
@@ -791,8 +850,8 @@
static int scalar_decode_signed(scalar *out, const uint8_t *in, int bits,
uint32_t max) {
if (bits == 4) {
- assert(max == ETA);
- return scalar_decode_signed_4_eta(out, in);
+ assert(max == 4);
+ return scalar_decode_signed_4_4(out, in);
} else if (bits == 13) {
assert(max == (1u << 12));
scalar_decode_signed_13_12(out, in);
@@ -813,19 +872,19 @@
// Rejection samples a Keccak stream to get uniformly distributed elements. This
// is used for matrix expansion and only operates on public inputs.
static void scalar_from_keccak_vartime(
- scalar *out, const uint8_t derived_seed[RHO_BYTES + 2]) {
+ scalar *out, const uint8_t derived_seed[kRhoBytes + 2]) {
struct BORINGSSL_keccak_st keccak_ctx;
BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake128);
- BORINGSSL_keccak_absorb(&keccak_ctx, derived_seed, RHO_BYTES + 2);
+ BORINGSSL_keccak_absorb(&keccak_ctx, derived_seed, kRhoBytes + 2);
assert(keccak_ctx.squeeze_offset == 0);
assert(keccak_ctx.rate_bytes == 168);
static_assert(168 % 3 == 0, "block and coefficient boundaries do not align");
int done = 0;
- while (done < DEGREE) {
+ while (done < kDegree) {
uint8_t block[168];
BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
- for (size_t i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
+ for (size_t i = 0; i < sizeof(block) && done < kDegree; i += 3) {
// FIPS 204, Algorithm 14 (`CoeffFromThreeBytes`).
uint32_t value = (uint32_t)block[i] | ((uint32_t)block[i + 1] << 8) |
(((uint32_t)block[i + 2] & 0x7f) << 16);
@@ -836,22 +895,33 @@
}
}
-// FIPS 204, Algorithm 31 (`RejBoundedPoly`).
-static void scalar_uniform_eta_4(scalar *out,
- const uint8_t derived_seed[SIGMA_BYTES + 2]) {
- static_assert(ETA == 4, "This implementation is specialized for ETA == 4");
+template <int ETA>
+static bool coefficient_from_nibble(uint32_t nibble, uint32_t *result);
+template <>
+bool coefficient_from_nibble<4>(uint32_t nibble, uint32_t *result) {
+ if (constant_time_declassify_int(nibble < 9)) {
+ *result = mod_sub(4, nibble);
+ return true;
+ }
+ return false;
+}
+
+// FIPS 204, Algorithm 31 (`RejBoundedPoly`).
+template <int ETA>
+static void scalar_uniform(scalar *out,
+ const uint8_t derived_seed[kSigmaBytes + 2]) {
struct BORINGSSL_keccak_st keccak_ctx;
BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
- BORINGSSL_keccak_absorb(&keccak_ctx, derived_seed, SIGMA_BYTES + 2);
+ BORINGSSL_keccak_absorb(&keccak_ctx, derived_seed, kSigmaBytes + 2);
assert(keccak_ctx.squeeze_offset == 0);
assert(keccak_ctx.rate_bytes == 136);
int done = 0;
- while (done < DEGREE) {
+ while (done < kDegree) {
uint8_t block[136];
BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
- for (size_t i = 0; i < sizeof(block) && done < DEGREE; ++i) {
+ for (size_t i = 0; i < sizeof(block) && done < kDegree; ++i) {
uint32_t t0 = block[i] & 0x0F;
uint32_t t1 = block[i] >> 4;
// FIPS 204, Algorithm 15 (`CoefFromHalfByte`). Although both the input
@@ -859,21 +929,22 @@
// Individual bytes of the SHAKE-256 stream are (indistinguishable from)
// independent of each other and the original seed, so leaking information
// about the rejected bytes does not reveal the input or output.
- if (constant_time_declassify_int(t0 < 9)) {
- out->c[done++] = mod_sub(ETA, t0);
+ uint32_t v;
+ if (coefficient_from_nibble<ETA>(t0, &v)) {
+ out->c[done++] = v;
}
- if (done < DEGREE && constant_time_declassify_int(t1 < 9)) {
- out->c[done++] = mod_sub(ETA, t1);
+ if (done < kDegree && coefficient_from_nibble<ETA>(t1, &v)) {
+ out->c[done++] = v;
}
}
}
}
// FIPS 204, Algorithm 34 (`ExpandMask`), but just a single step.
-static void scalar_sample_mask(
- scalar *out, const uint8_t derived_seed[RHO_PRIME_BYTES + 2]) {
+static void scalar_sample_mask(scalar *out,
+ const uint8_t derived_seed[kRhoPrimeBytes + 2]) {
uint8_t buf[640];
- BORINGSSL_keccak(buf, sizeof(buf), derived_seed, RHO_PRIME_BYTES + 2,
+ BORINGSSL_keccak(buf, sizeof(buf), derived_seed, kRhoPrimeBytes + 2,
boringssl_shake256);
scalar_decode_signed_20_19(out, buf);
@@ -881,9 +952,7 @@
// FIPS 204, Algorithm 29 (`SampleInBall`).
static void scalar_sample_in_ball_vartime(scalar *out, const uint8_t *seed,
- int len) {
- assert(len == 2 * LAMBDA_BYTES);
-
+ int len, int tau) {
struct BORINGSSL_keccak_st keccak_ctx;
BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
BORINGSSL_keccak_absorb(&keccak_ctx, seed, len);
@@ -902,7 +971,7 @@
CONSTTIME_DECLASSIFY(block + offset, sizeof(block) - offset);
OPENSSL_memset(out, 0, sizeof(*out));
- for (size_t i = DEGREE - TAU; i < DEGREE; i++) {
+ for (size_t i = kDegree - tau; i < kDegree; i++) {
size_t byte;
for (;;) {
if (offset == 136) {
@@ -925,54 +994,57 @@
}
// FIPS 204, Algorithm 32 (`ExpandA`).
-static void matrix_expand(matrix *out, const uint8_t rho[RHO_BYTES]) {
+template <int K, int L>
+static void matrix_expand(matrix<K, L> *out, const uint8_t rho[kRhoBytes]) {
static_assert(K <= 0x100, "K must fit in 8 bits");
static_assert(L <= 0x100, "L must fit in 8 bits");
- uint8_t derived_seed[RHO_BYTES + 2];
- OPENSSL_memcpy(derived_seed, rho, RHO_BYTES);
+ uint8_t derived_seed[kRhoBytes + 2];
+ OPENSSL_memcpy(derived_seed, rho, kRhoBytes);
for (int i = 0; i < K; i++) {
for (int j = 0; j < L; j++) {
- derived_seed[RHO_BYTES + 1] = (uint8_t)i;
- derived_seed[RHO_BYTES] = (uint8_t)j;
+ derived_seed[kRhoBytes + 1] = (uint8_t)i;
+ derived_seed[kRhoBytes] = (uint8_t)j;
scalar_from_keccak_vartime(&out->v[i][j], derived_seed);
}
}
}
// FIPS 204, Algorithm 33 (`ExpandS`).
-static void vector_expand_short(vectorl *s1, vectork *s2,
- const uint8_t sigma[SIGMA_BYTES]) {
+template <int K, int L>
+static void vector_expand_short(vector<L> *s1, vector<K> *s2,
+ const uint8_t sigma[kSigmaBytes]) {
static_assert(K <= 0x100, "K must fit in 8 bits");
static_assert(L <= 0x100, "L must fit in 8 bits");
static_assert(K + L <= 0x100, "K+L must fit in 8 bits");
- uint8_t derived_seed[SIGMA_BYTES + 2];
- OPENSSL_memcpy(derived_seed, sigma, SIGMA_BYTES);
- derived_seed[SIGMA_BYTES] = 0;
- derived_seed[SIGMA_BYTES + 1] = 0;
+ uint8_t derived_seed[kSigmaBytes + 2];
+ OPENSSL_memcpy(derived_seed, sigma, kSigmaBytes);
+ derived_seed[kSigmaBytes] = 0;
+ derived_seed[kSigmaBytes + 1] = 0;
for (int i = 0; i < L; i++) {
- scalar_uniform_eta_4(&s1->v[i], derived_seed);
- ++derived_seed[SIGMA_BYTES];
+ scalar_uniform<eta<K>()>(&s1->v[i], derived_seed);
+ ++derived_seed[kSigmaBytes];
}
for (int i = 0; i < K; i++) {
- scalar_uniform_eta_4(&s2->v[i], derived_seed);
- ++derived_seed[SIGMA_BYTES];
+ scalar_uniform<eta<K>()>(&s2->v[i], derived_seed);
+ ++derived_seed[kSigmaBytes];
}
}
// FIPS 204, Algorithm 34 (`ExpandMask`).
-static void vectorl_expand_mask(vectorl *out,
- const uint8_t seed[RHO_PRIME_BYTES],
- size_t kappa) {
+template <int L>
+static void vector_expand_mask(vector<L> *out,
+ const uint8_t seed[kRhoPrimeBytes],
+ size_t kappa) {
assert(kappa + L <= 0x10000);
- uint8_t derived_seed[RHO_PRIME_BYTES + 2];
- OPENSSL_memcpy(derived_seed, seed, RHO_PRIME_BYTES);
+ uint8_t derived_seed[kRhoPrimeBytes + 2];
+ OPENSSL_memcpy(derived_seed, seed, kRhoPrimeBytes);
for (int i = 0; i < L; i++) {
size_t index = kappa + i;
- derived_seed[RHO_PRIME_BYTES] = index & 0xFF;
- derived_seed[RHO_PRIME_BYTES + 1] = (index >> 8) & 0xFF;
+ derived_seed[kRhoPrimeBytes] = index & 0xFF;
+ derived_seed[kRhoPrimeBytes + 1] = (index >> 8) & 0xFF;
scalar_sample_mask(&out->v[i], derived_seed);
}
}
@@ -981,63 +1053,49 @@
// FIPS 204, Algorithm 16 (`SimpleBitPack`).
//
-// Encodes an entire vector into 32*K*|bits| bytes. Note that since 256 (DEGREE)
-// is divisible by 8, the individual vector entries will always fill a whole
-// number of bytes, so we do not need to worry about bit packing here.
-static void vectork_encode(uint8_t *out, const vectork *a, int bits) {
+// Encodes an entire vector into 32*K*|bits| bytes. Note that since 256
+// (kDegree) is divisible by 8, the individual vector entries will always fill a
+// whole number of bytes, so we do not need to worry about bit packing here.
+template <int K>
+static void vector_encode(uint8_t *out, const vector<K> *a, int bits) {
if (bits == 4) {
for (int i = 0; i < K; i++) {
- scalar_encode_4(out + i * bits * DEGREE / 8, &a->v[i]);
+ scalar_encode_4(out + i * bits * kDegree / 8, &a->v[i]);
}
} else {
assert(bits == 10);
for (int i = 0; i < K; i++) {
- scalar_encode_10(out + i * bits * DEGREE / 8, &a->v[i]);
+ scalar_encode_10(out + i * bits * kDegree / 8, &a->v[i]);
}
}
}
// FIPS 204, Algorithm 18 (`SimpleBitUnpack`).
-static void vectork_decode_10(vectork *out, const uint8_t *in) {
+template <int K>
+static void vector_decode_10(vector<K> *out, const uint8_t *in) {
for (int i = 0; i < K; i++) {
- scalar_decode_10(&out->v[i], in + i * 10 * DEGREE / 8);
+ scalar_decode_10(&out->v[i], in + i * 10 * kDegree / 8);
}
}
-static void vectork_encode_signed(uint8_t *out, const vectork *a, int bits,
- uint32_t max) {
- for (int i = 0; i < K; i++) {
- scalar_encode_signed(out + i * bits * DEGREE / 8, &a->v[i], bits, max);
- }
-}
-
-static int vectork_decode_signed(vectork *out, const uint8_t *in, int bits,
- uint32_t max) {
- for (int i = 0; i < K; i++) {
- if (!scalar_decode_signed(&out->v[i], in + i * bits * DEGREE / 8, bits,
- max)) {
- return 0;
- }
- }
- return 1;
-}
-
// FIPS 204, Algorithm 17 (`BitPack`).
//
-// Encodes an entire vector into 32*L*|bits| bytes. Note that since 256 (DEGREE)
-// is divisible by 8, the individual vector entries will always fill a whole
-// number of bytes, so we do not need to worry about bit packing here.
-static void vectorl_encode_signed(uint8_t *out, const vectorl *a, int bits,
- uint32_t max) {
- for (int i = 0; i < L; i++) {
- scalar_encode_signed(out + i * bits * DEGREE / 8, &a->v[i], bits, max);
+// Encodes an entire vector into 32*X*|bits| bytes. Note that since 256
+// (kDegree) is divisible by 8, the individual vector entries will always fill a
+// whole number of bytes, so we do not need to worry about bit packing here.
+template <int X>
+static void vector_encode_signed(uint8_t *out, const vector<X> *a, int bits,
+ uint32_t max) {
+ for (int i = 0; i < X; i++) {
+ scalar_encode_signed(out + i * bits * kDegree / 8, &a->v[i], bits, max);
}
}
-static int vectorl_decode_signed(vectorl *out, const uint8_t *in, int bits,
- uint32_t max) {
- for (int i = 0; i < L; i++) {
- if (!scalar_decode_signed(&out->v[i], in + i * bits * DEGREE / 8, bits,
+template <int X>
+static int vector_decode_signed(vector<X> *out, const uint8_t *in, int bits,
+ uint32_t max) {
+ for (int i = 0; i < X; i++) {
+ if (!scalar_decode_signed(&out->v[i], in + i * bits * kDegree / 8, bits,
max)) {
return 0;
}
@@ -1046,33 +1104,36 @@
}
// FIPS 204, Algorithm 28 (`w1Encode`).
-static void w1_encode(uint8_t out[128 * K], const vectork *w1) {
- vectork_encode(out, w1, 4);
+template <int K>
+static void w1_encode(uint8_t out[128 * K], const vector<K> *w1) {
+ vector_encode(out, w1, 4);
}
// FIPS 204, Algorithm 20 (`HintBitPack`).
-static void hint_bit_pack(uint8_t out[OMEGA + K], const vectork *h) {
- OPENSSL_memset(out, 0, OMEGA + K);
+template <int K>
+static void hint_bit_pack(uint8_t out[omega<K>() + K], const vector<K> *h) {
+ OPENSSL_memset(out, 0, omega<K>() + K);
int index = 0;
for (int i = 0; i < K; i++) {
- for (int j = 0; j < DEGREE; j++) {
+ for (int j = 0; j < kDegree; j++) {
if (h->v[i].c[j]) {
- // h must have at most OMEGA non-zero coefficients.
- BSSL_CHECK(index < OMEGA);
+ // h must have at most omega<K>() non-zero coefficients.
+ BSSL_CHECK(index < omega<K>());
out[index++] = j;
}
}
- out[OMEGA + i] = index;
+ out[omega<K>() + i] = index;
}
}
// FIPS 204, Algorithm 21 (`HintBitUnpack`).
-static int hint_bit_unpack(vectork *h, const uint8_t in[OMEGA + K]) {
- vectork_zero(h);
+template <int K>
+static int hint_bit_unpack(vector<K> *h, const uint8_t in[omega<K>() + K]) {
+ vector_zero(h);
int index = 0;
for (int i = 0; i < K; i++) {
- const int limit = in[OMEGA + i];
- if (limit < index || limit > OMEGA) {
+ const int limit = in[omega<K>() + i];
+ if (limit < index || limit > omega<K>()) {
return 0;
}
@@ -1083,12 +1144,12 @@
return 0;
}
last = byte;
- static_assert(DEGREE == 256,
- "DEGREE must be 256 for this write to be in bounds");
+ static_assert(kDegree == 256,
+ "kDegree must be 256 for this write to be in bounds");
h->v[i].c[byte] = 1;
}
}
- for (; index < OMEGA; index++) {
+ for (; index < omega<K>(); index++) {
if (in[index] != 0) {
return 0;
}
@@ -1096,30 +1157,34 @@
return 1;
}
+template <int K>
struct public_key {
- uint8_t rho[RHO_BYTES];
- vectork t1;
+ uint8_t rho[kRhoBytes];
+ vector<K> t1;
// Pre-cached value(s).
- uint8_t public_key_hash[TR_BYTES];
+ uint8_t public_key_hash[kTrBytes];
};
+template <int K, int L>
struct private_key {
- uint8_t rho[RHO_BYTES];
- uint8_t k[K_BYTES];
- uint8_t public_key_hash[TR_BYTES];
- vectorl s1;
- vectork s2;
- vectork t0;
+ uint8_t rho[kRhoBytes];
+ uint8_t k[kKBytes];
+ uint8_t public_key_hash[kTrBytes];
+ vector<L> s1;
+ vector<K> s2;
+ vector<K> t0;
};
+template <int K, int L>
struct signature {
- uint8_t c_tilde[2 * LAMBDA_BYTES];
- vectorl z;
- vectork h;
+ uint8_t c_tilde[2 * lambda_bytes<K>()];
+ vector<L> z;
+ vector<K> h;
};
// FIPS 204, Algorithm 22 (`pkEncode`).
-static int mldsa_marshal_public_key(CBB *out, const struct public_key *pub) {
+template <int K>
+static int mldsa_marshal_public_key(CBB *out, const struct public_key<K> *pub) {
if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
return 0;
}
@@ -1128,28 +1193,37 @@
if (!CBB_add_space(out, &vectork_output, 320 * K)) {
return 0;
}
- vectork_encode(vectork_output, &pub->t1, 10);
+ vector_encode(vectork_output, &pub->t1, 10);
return 1;
}
// FIPS 204, Algorithm 23 (`pkDecode`).
-static int mldsa_parse_public_key(struct public_key *pub, CBS *in) {
+template <int K>
+static int mldsa_parse_public_key(struct public_key<K> *pub, CBS *in) {
+ const CBS orig_in = *in;
+
if (!CBS_copy_bytes(in, pub->rho, sizeof(pub->rho))) {
return 0;
}
CBS t1_bytes;
- if (!CBS_get_bytes(in, &t1_bytes, 320 * K)) {
+ if (!CBS_get_bytes(in, &t1_bytes, 320 * K) || CBS_len(in) != 0) {
return 0;
}
- vectork_decode_10(&pub->t1, CBS_data(&t1_bytes));
+ vector_decode_10(&pub->t1, CBS_data(&t1_bytes));
+
+ // Compute pre-cached values.
+ BORINGSSL_keccak(pub->public_key_hash, sizeof(pub->public_key_hash),
+ CBS_data(&orig_in), CBS_len(&orig_in), boringssl_shake256);
return 1;
}
// FIPS 204, Algorithm 24 (`skEncode`).
-static int mldsa_marshal_private_key(CBB *out, const struct private_key *priv) {
+template <int K, int L>
+static int mldsa_marshal_private_key(CBB *out,
+ const struct private_key<K, L> *priv) {
if (!CBB_add_bytes(out, priv->rho, sizeof(priv->rho)) ||
!CBB_add_bytes(out, priv->k, sizeof(priv->k)) ||
!CBB_add_bytes(out, priv->public_key_hash,
@@ -1157,42 +1231,52 @@
return 0;
}
+ constexpr size_t scalar_bytes =
+ (kDegree * plus_minus_eta_bitlen<K>() + 7) / 8;
uint8_t *vectorl_output;
- if (!CBB_add_space(out, &vectorl_output, 128 * L)) {
+ if (!CBB_add_space(out, &vectorl_output, scalar_bytes * L)) {
return 0;
}
- vectorl_encode_signed(vectorl_output, &priv->s1, 4, ETA);
+ vector_encode_signed(vectorl_output, &priv->s1, plus_minus_eta_bitlen<K>(),
+ eta<K>());
- uint8_t *vectork_output;
- if (!CBB_add_space(out, &vectork_output, 128 * K)) {
+ uint8_t *s2_output;
+ if (!CBB_add_space(out, &s2_output, scalar_bytes * K)) {
return 0;
}
- vectork_encode_signed(vectork_output, &priv->s2, 4, ETA);
+ vector_encode_signed(s2_output, &priv->s2, plus_minus_eta_bitlen<K>(),
+ eta<K>());
- if (!CBB_add_space(out, &vectork_output, 416 * K)) {
+ uint8_t *t0_output;
+ if (!CBB_add_space(out, &t0_output, 416 * K)) {
return 0;
}
- vectork_encode_signed(vectork_output, &priv->t0, 13, 1 << 12);
+ vector_encode_signed(t0_output, &priv->t0, 13, 1 << 12);
return 1;
}
// FIPS 204, Algorithm 25 (`skDecode`).
-static int mldsa_parse_private_key(struct private_key *priv, CBS *in) {
+template <int K, int L>
+static int mldsa_parse_private_key(struct private_key<K, L> *priv, CBS *in) {
CBS s1_bytes;
CBS s2_bytes;
CBS t0_bytes;
+ constexpr size_t scalar_bytes =
+ (kDegree * plus_minus_eta_bitlen<K>() + 7) / 8;
if (!CBS_copy_bytes(in, priv->rho, sizeof(priv->rho)) ||
!CBS_copy_bytes(in, priv->k, sizeof(priv->k)) ||
!CBS_copy_bytes(in, priv->public_key_hash,
sizeof(priv->public_key_hash)) ||
- !CBS_get_bytes(in, &s1_bytes, 128 * L) ||
- !vectorl_decode_signed(&priv->s1, CBS_data(&s1_bytes), 4, ETA) ||
- !CBS_get_bytes(in, &s2_bytes, 128 * K) ||
- !vectork_decode_signed(&priv->s2, CBS_data(&s2_bytes), 4, ETA) ||
+ !CBS_get_bytes(in, &s1_bytes, scalar_bytes * L) ||
+ !vector_decode_signed(&priv->s1, CBS_data(&s1_bytes),
+ plus_minus_eta_bitlen<K>(), eta<K>()) ||
+ !CBS_get_bytes(in, &s2_bytes, scalar_bytes * K) ||
+ !vector_decode_signed(&priv->s2, CBS_data(&s2_bytes),
+ plus_minus_eta_bitlen<K>(), eta<K>()) ||
!CBS_get_bytes(in, &t0_bytes, 416 * K) ||
// Note: Decoding 13 bits into (-2^12, 2^12] cannot fail.
- !vectork_decode_signed(&priv->t0, CBS_data(&t0_bytes), 13, 1 << 12)) {
+ !vector_decode_signed(&priv->t0, CBS_data(&t0_bytes), 13, 1 << 12)) {
return 0;
}
@@ -1200,7 +1284,9 @@
}
// FIPS 204, Algorithm 26 (`sigEncode`).
-static int mldsa_marshal_signature(CBB *out, const struct signature *sign) {
+template <int K, int L>
+static int mldsa_marshal_signature(CBB *out,
+ const struct signature<K, L> *sign) {
if (!CBB_add_bytes(out, sign->c_tilde, sizeof(sign->c_tilde))) {
return 0;
}
@@ -1209,10 +1295,10 @@
if (!CBB_add_space(out, &vectorl_output, 640 * L)) {
return 0;
}
- vectorl_encode_signed(vectorl_output, &sign->z, 20, 1 << 19);
+ vector_encode_signed(vectorl_output, &sign->z, 20, 1 << 19);
uint8_t *hint_output;
- if (!CBB_add_space(out, &hint_output, OMEGA + K)) {
+ if (!CBB_add_space(out, &hint_output, omega<K>() + K)) {
return 0;
}
hint_bit_pack(hint_output, &sign->h);
@@ -1221,14 +1307,15 @@
}
// FIPS 204, Algorithm 27 (`sigDecode`).
-static int mldsa_parse_signature(struct signature *sign, CBS *in) {
+template <int K, int L>
+static int mldsa_parse_signature(struct signature<K, L> *sign, CBS *in) {
CBS z_bytes;
CBS hint_bytes;
if (!CBS_copy_bytes(in, sign->c_tilde, sizeof(sign->c_tilde)) ||
!CBS_get_bytes(in, &z_bytes, 640 * L) ||
// Note: Decoding 20 bits into (-2^19, 2^19] cannot fail.
- !vectorl_decode_signed(&sign->z, CBS_data(&z_bytes), 20, 1 << 19) ||
- !CBS_get_bytes(in, &hint_bytes, OMEGA + K) ||
+ !vector_decode_signed(&sign->z, CBS_data(&z_bytes), 20, 1 << 19) ||
+ !CBS_get_bytes(in, &hint_bytes, omega<K>() + K) ||
!hint_bit_unpack(&sign->h, CBS_data(&hint_bytes))) {
return 0;
};
@@ -1236,28 +1323,370 @@
return 1;
}
-static struct private_key *private_key_from_external(
+template <typename T>
+struct DeleterFree {
+ void operator()(T *ptr) { OPENSSL_free(ptr); }
+};
+
+// FIPS 204, Algorithm 6 (`ML-DSA.KeyGen_internal`). Returns 1 on success and 0
+// on failure.
+template <int K, int L>
+static int mldsa_generate_key_external_entropy(
+ uint8_t out_encoded_public_key[public_key_bytes<K>()],
+ struct private_key<K, L> *priv, const uint8_t entropy[MLDSA_SEED_BYTES]) {
+ // Intermediate values, allocated on the heap to allow use when there is a
+ // limited amount of stack.
+ struct values_st {
+ struct public_key<K> pub;
+ matrix<K, L> a_ntt;
+ vector<L> s1_ntt;
+ vector<K> t;
+ };
+ std::unique_ptr<values_st, DeleterFree<values_st>> values(
+ reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
+ if (values == NULL) {
+ return 0;
+ }
+
+ uint8_t augmented_entropy[MLDSA_SEED_BYTES + 2];
+ OPENSSL_memcpy(augmented_entropy, entropy, MLDSA_SEED_BYTES);
+ // The k and l parameters are appended to the seed.
+ augmented_entropy[MLDSA_SEED_BYTES] = K;
+ augmented_entropy[MLDSA_SEED_BYTES + 1] = L;
+ uint8_t expanded_seed[kRhoBytes + kSigmaBytes + kKBytes];
+ BORINGSSL_keccak(expanded_seed, sizeof(expanded_seed), augmented_entropy,
+ sizeof(augmented_entropy), boringssl_shake256);
+ const uint8_t *const rho = expanded_seed;
+ const uint8_t *const sigma = expanded_seed + kRhoBytes;
+ const uint8_t *const k = expanded_seed + kRhoBytes + kSigmaBytes;
+ // rho is public.
+ CONSTTIME_DECLASSIFY(rho, kRhoBytes);
+ OPENSSL_memcpy(values->pub.rho, rho, sizeof(values->pub.rho));
+ OPENSSL_memcpy(priv->rho, rho, sizeof(priv->rho));
+ OPENSSL_memcpy(priv->k, k, sizeof(priv->k));
+
+ matrix_expand(&values->a_ntt, rho);
+ vector_expand_short(&priv->s1, &priv->s2, sigma);
+
+ OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
+ vector_ntt(&values->s1_ntt);
+
+ matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
+ vector_inverse_ntt(&values->t);
+ vector_add(&values->t, &values->t, &priv->s2);
+
+ vector_power2_round(&values->pub.t1, &priv->t0, &values->t);
+ // t1 is public.
+ CONSTTIME_DECLASSIFY(&values->pub.t1, sizeof(values->pub.t1));
+
+ CBB cbb;
+ CBB_init_fixed(&cbb, out_encoded_public_key, public_key_bytes<K>());
+ if (!mldsa_marshal_public_key(&cbb, &values->pub)) {
+ return 0;
+ }
+ assert(CBB_len(&cbb) == public_key_bytes<K>());
+
+ BORINGSSL_keccak(priv->public_key_hash, sizeof(priv->public_key_hash),
+ out_encoded_public_key, public_key_bytes<K>(),
+ boringssl_shake256);
+
+ return 1;
+}
+
+template <int K, int L>
+static int mldsa_public_from_private(struct public_key<K> *pub,
+ const struct private_key<K, L> *priv) {
+ // Intermediate values, allocated on the heap to allow use when there is a
+ // limited amount of stack.
+ struct values_st {
+ matrix<K, L> a_ntt;
+ vector<L> s1_ntt;
+ vector<K> t;
+ vector<K> t0;
+ };
+ std::unique_ptr<values_st, DeleterFree<values_st>> values(
+ reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
+ if (values == NULL) {
+ return 0;
+ }
+
+
+ OPENSSL_memcpy(pub->rho, priv->rho, sizeof(pub->rho));
+ OPENSSL_memcpy(pub->public_key_hash, priv->public_key_hash,
+ sizeof(pub->public_key_hash));
+
+ matrix_expand(&values->a_ntt, priv->rho);
+
+ OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
+ vector_ntt(&values->s1_ntt);
+
+ matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
+ vector_inverse_ntt(&values->t);
+ vector_add(&values->t, &values->t, &priv->s2);
+
+ vector_power2_round(&pub->t1, &values->t0, &values->t);
+ return 1;
+}
+
+// FIPS 204, Algorithm 7 (`ML-DSA.Sign_internal`). Returns 1 on success and 0
+// on failure.
+template <int K, int L>
+static int mldsa_sign_internal(
+ uint8_t out_encoded_signature[signature_bytes<K>()],
+ const struct private_key<K, L> *priv, const uint8_t *msg, size_t msg_len,
+ const uint8_t *context_prefix, size_t context_prefix_len,
+ const uint8_t *context, size_t context_len,
+ const uint8_t randomizer[MLDSA_SIGNATURE_RANDOMIZER_BYTES]) {
+ uint8_t mu[kMuBytes];
+ struct BORINGSSL_keccak_st keccak_ctx;
+ BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
+ BORINGSSL_keccak_absorb(&keccak_ctx, priv->public_key_hash,
+ sizeof(priv->public_key_hash));
+ BORINGSSL_keccak_absorb(&keccak_ctx, context_prefix, context_prefix_len);
+ BORINGSSL_keccak_absorb(&keccak_ctx, context, context_len);
+ BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
+ BORINGSSL_keccak_squeeze(&keccak_ctx, mu, kMuBytes);
+
+ uint8_t rho_prime[kRhoPrimeBytes];
+ BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
+ BORINGSSL_keccak_absorb(&keccak_ctx, priv->k, sizeof(priv->k));
+ BORINGSSL_keccak_absorb(&keccak_ctx, randomizer,
+ MLDSA_SIGNATURE_RANDOMIZER_BYTES);
+ BORINGSSL_keccak_absorb(&keccak_ctx, mu, kMuBytes);
+ BORINGSSL_keccak_squeeze(&keccak_ctx, rho_prime, kRhoPrimeBytes);
+
+ // Intermediate values, allocated on the heap to allow use when there is a
+ // limited amount of stack.
+ struct values_st {
+ struct signature<K, L> sign;
+ vector<L> s1_ntt;
+ vector<K> s2_ntt;
+ vector<K> t0_ntt;
+ matrix<K, L> a_ntt;
+ vector<L> y;
+ vector<K> w;
+ vector<K> w1;
+ vector<L> cs1;
+ vector<K> cs2;
+ };
+ std::unique_ptr<values_st, DeleterFree<values_st>> values(
+ reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
+ if (values == NULL) {
+ return 0;
+ }
+ OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
+ vector_ntt(&values->s1_ntt);
+
+ OPENSSL_memcpy(&values->s2_ntt, &priv->s2, sizeof(values->s2_ntt));
+ vector_ntt(&values->s2_ntt);
+
+ OPENSSL_memcpy(&values->t0_ntt, &priv->t0, sizeof(values->t0_ntt));
+ vector_ntt(&values->t0_ntt);
+
+ matrix_expand(&values->a_ntt, priv->rho);
+
+ // kappa must not exceed 2**16/L (13107 when L = 5). But the probability of it
+ // exceeding even 1000 iterations is vanishingly small.
+ for (size_t kappa = 0;; kappa += L) {
+ vector_expand_mask(&values->y, rho_prime, kappa);
+
+ vector<L> *y_ntt = &values->cs1;
+ OPENSSL_memcpy(y_ntt, &values->y, sizeof(*y_ntt));
+ vector_ntt(y_ntt);
+
+ matrix_mult(&values->w, &values->a_ntt, y_ntt);
+ vector_inverse_ntt(&values->w);
+
+ vector_high_bits(&values->w1, &values->w);
+ uint8_t w1_encoded[128 * K];
+ w1_encode(w1_encoded, &values->w1);
+
+ BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
+ BORINGSSL_keccak_absorb(&keccak_ctx, mu, kMuBytes);
+ BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
+ BORINGSSL_keccak_squeeze(&keccak_ctx, values->sign.c_tilde,
+ 2 * lambda_bytes<K>());
+
+ scalar c_ntt;
+ scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde,
+ sizeof(values->sign.c_tilde), tau<K>());
+ scalar_ntt(&c_ntt);
+
+ vector_mult_scalar(&values->cs1, &values->s1_ntt, &c_ntt);
+ vector_inverse_ntt(&values->cs1);
+ vector_mult_scalar(&values->cs2, &values->s2_ntt, &c_ntt);
+ vector_inverse_ntt(&values->cs2);
+
+ vector_add(&values->sign.z, &values->y, &values->cs1);
+
+ vector<K> *r0 = &values->w1;
+ vector_sub(r0, &values->w, &values->cs2);
+ vector_low_bits(r0, r0);
+
+ // Leaking the fact that a signature was rejected is fine as the next
+ // attempt at a signature will be (indistinguishable from) independent of
+ // this one. Note, however, that we additionally leak which of the two
+ // branches rejected the signature. Section 5.5 of
+ // https://pq-crystals.org/dilithium/data/dilithium-specification-round3.pdf
+ // describes this leak as OK. Note we leak less than what is described by
+ // the paper; we do not reveal which coefficient violated the bound, and
+ // we hide which of the |z_max| or |r0_max| bound failed. See also
+ // https://boringssl-review.googlesource.com/c/boringssl/+/67747/comment/2bbab0fa_d241d35a/
+ uint32_t z_max = vector_max(&values->sign.z);
+ uint32_t r0_max = vector_max_signed(r0);
+ if (constant_time_declassify_w(
+ constant_time_ge_w(z_max, gamma1<K>() - beta<K>()) |
+ constant_time_ge_w(r0_max, kGamma2 - beta<K>()))) {
+ continue;
+ }
+
+ vector<K> *ct0 = &values->w1;
+ vector_mult_scalar(ct0, &values->t0_ntt, &c_ntt);
+ vector_inverse_ntt(ct0);
+ vector_make_hint(&values->sign.h, ct0, &values->cs2, &values->w);
+
+ // See above.
+ uint32_t ct0_max = vector_max(ct0);
+ size_t h_ones = vector_count_ones(&values->sign.h);
+ if (constant_time_declassify_w(constant_time_ge_w(ct0_max, kGamma2) |
+ constant_time_lt_w(omega<K>(), h_ones))) {
+ continue;
+ }
+
+ // Although computed with the private key, the signature is public.
+ CONSTTIME_DECLASSIFY(values->sign.c_tilde, sizeof(values->sign.c_tilde));
+ CONSTTIME_DECLASSIFY(&values->sign.z, sizeof(values->sign.z));
+ CONSTTIME_DECLASSIFY(&values->sign.h, sizeof(values->sign.h));
+
+ CBB cbb;
+ CBB_init_fixed(&cbb, out_encoded_signature, signature_bytes<K>());
+ if (!mldsa_marshal_signature(&cbb, &values->sign)) {
+ return 0;
+ }
+
+ BSSL_CHECK(CBB_len(&cbb) == signature_bytes<K>());
+ return 1;
+ }
+}
+
+// FIPS 204, Algorithm 8 (`ML-DSA.Verify_internal`). Returns 1 if valid, else 0.
+template <int K, int L>
+static int mldsa_verify_internal(
+ const struct public_key<K> *pub,
+ const uint8_t encoded_signature[signature_bytes<K>()], const uint8_t *msg,
+ size_t msg_len, const uint8_t *context_prefix, size_t context_prefix_len,
+ const uint8_t *context, size_t context_len) {
+ // Intermediate values, allocated on the heap to allow use when there is a
+ // limited amount of stack.
+ struct values_st {
+ struct signature<K, L> sign;
+ matrix<K, L> a_ntt;
+ vector<L> z_ntt;
+ vector<K> az_ntt;
+ vector<K> ct1_ntt;
+ };
+ std::unique_ptr<values_st, DeleterFree<values_st>> values(
+ reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
+ if (values == NULL) {
+ return 0;
+ }
+
+ CBS cbs;
+ CBS_init(&cbs, encoded_signature, signature_bytes<K>());
+ if (!mldsa_parse_signature(&values->sign, &cbs)) {
+ return 0;
+ }
+
+ matrix_expand(&values->a_ntt, pub->rho);
+
+ uint8_t mu[kMuBytes];
+ struct BORINGSSL_keccak_st keccak_ctx;
+ BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
+ BORINGSSL_keccak_absorb(&keccak_ctx, pub->public_key_hash,
+ sizeof(pub->public_key_hash));
+ BORINGSSL_keccak_absorb(&keccak_ctx, context_prefix, context_prefix_len);
+ BORINGSSL_keccak_absorb(&keccak_ctx, context, context_len);
+ BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
+ BORINGSSL_keccak_squeeze(&keccak_ctx, mu, kMuBytes);
+
+ scalar c_ntt;
+ scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde,
+ sizeof(values->sign.c_tilde), tau<K>());
+ scalar_ntt(&c_ntt);
+
+ OPENSSL_memcpy(&values->z_ntt, &values->sign.z, sizeof(values->z_ntt));
+ vector_ntt(&values->z_ntt);
+
+ matrix_mult(&values->az_ntt, &values->a_ntt, &values->z_ntt);
+
+ vector_scale_power2_round(&values->ct1_ntt, &pub->t1);
+ vector_ntt(&values->ct1_ntt);
+
+ vector_mult_scalar(&values->ct1_ntt, &values->ct1_ntt, &c_ntt);
+
+ vector<K> *const w1 = &values->az_ntt;
+ vector_sub(w1, &values->az_ntt, &values->ct1_ntt);
+ vector_inverse_ntt(w1);
+
+ vector_use_hint_vartime(w1, &values->sign.h, w1);
+ uint8_t w1_encoded[128 * K];
+ w1_encode(w1_encoded, w1);
+
+ uint8_t c_tilde[2 * lambda_bytes<K>()];
+ BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
+ BORINGSSL_keccak_absorb(&keccak_ctx, mu, kMuBytes);
+ BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
+ BORINGSSL_keccak_squeeze(&keccak_ctx, c_tilde, 2 * lambda_bytes<K>());
+
+ uint32_t z_max = vector_max(&values->sign.z);
+ return z_max < static_cast<uint32_t>(gamma1<K>() - beta<K>()) &&
+ OPENSSL_memcmp(c_tilde, values->sign.c_tilde, 2 * lambda_bytes<K>()) ==
+ 0;
+}
+
+} // namespace
+
+// ML-DSA-65 specific wrappers.
+
+static struct private_key<6, 5> *mldsa65_private_key_from_external(
const struct MLDSA65_private_key *external) {
- static_assert(
- sizeof(struct MLDSA65_private_key) == sizeof(struct private_key),
- "Kyber private key size incorrect");
- static_assert(
- alignof(struct MLDSA65_private_key) == alignof(struct private_key),
- "Kyber private key align incorrect");
- return (struct private_key *)external;
+ static_assert(sizeof(struct MLDSA65_private_key) ==
+ sizeof(struct private_key<6, 5>),
+ "MLDSA65 private key size incorrect");
+ static_assert(alignof(struct MLDSA65_private_key) ==
+ alignof(struct private_key<6, 5>),
+ "MLDSA65 private key align incorrect");
+ return (struct private_key<6, 5> *)external;
}
-static struct public_key *public_key_from_external(
- const struct MLDSA65_public_key *external) {
- static_assert(sizeof(struct MLDSA65_public_key) == sizeof(struct public_key),
- "mldsa public key size incorrect");
- static_assert(
- alignof(struct MLDSA65_public_key) == alignof(struct public_key),
- "mldsa public key align incorrect");
- return (struct public_key *)external;
+static struct public_key<6> *
+mldsa65_public_key_from_external(const struct MLDSA65_public_key *external) {
+ static_assert(sizeof(struct MLDSA65_public_key) ==
+ sizeof(struct public_key<6>),
+ "MLDSA65 public key size incorrect");
+ static_assert(alignof(struct MLDSA65_public_key) ==
+ alignof(struct public_key<6>),
+ "MLDSA65 public key align incorrect");
+ return (struct public_key<6> *)external;
}
-/* API */
+int MLDSA65_parse_public_key(struct MLDSA65_public_key *public_key, CBS *in) {
+ return mldsa_parse_public_key(mldsa65_public_key_from_external(public_key),
+ in);
+}
+
+int MLDSA65_marshal_private_key(CBB *out,
+ const struct MLDSA65_private_key *private_key) {
+ return mldsa_marshal_private_key(
+ out, mldsa65_private_key_from_external(private_key));
+}
+
+int MLDSA65_parse_private_key(struct MLDSA65_private_key *private_key,
+ CBS *in) {
+ return mldsa_parse_private_key(mldsa65_private_key_from_external(private_key),
+ in) &&
+ CBS_len(in) == 0;
+}
// Calls |MLDSA_generate_key_external_entropy| with random bytes from
// |RAND_bytes|. Returns 1 on success and 0 on failure.
@@ -1280,257 +1709,35 @@
seed);
}
-template <typename T>
-struct DeleterFree {
- void operator()(T *ptr) { OPENSSL_free(ptr); }
-};
-
-// FIPS 204, Algorithm 6 (`ML-DSA.KeyGen_internal`). Returns 1 on success and 0
-// on failure.
int MLDSA65_generate_key_external_entropy(
uint8_t out_encoded_public_key[MLDSA65_PUBLIC_KEY_BYTES],
struct MLDSA65_private_key *out_private_key,
const uint8_t entropy[MLDSA_SEED_BYTES]) {
- // Intermediate values, allocated on the heap to allow use when there is a
- // limited amount of stack.
- struct values_st {
- struct public_key pub;
- matrix a_ntt;
- vectorl s1_ntt;
- vectork t;
- };
- std::unique_ptr<values_st, DeleterFree<values_st>> values(
- reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
- if (values == NULL) {
- return 0;
- }
-
- struct private_key *priv = private_key_from_external(out_private_key);
-
- uint8_t augmented_entropy[MLDSA_SEED_BYTES + 2];
- OPENSSL_memcpy(augmented_entropy, entropy, MLDSA_SEED_BYTES);
- // The k and l parameters are appended to the seed.
- augmented_entropy[MLDSA_SEED_BYTES] = K;
- augmented_entropy[MLDSA_SEED_BYTES + 1] = L;
- uint8_t expanded_seed[RHO_BYTES + SIGMA_BYTES + K_BYTES];
- BORINGSSL_keccak(expanded_seed, sizeof(expanded_seed), augmented_entropy,
- sizeof(augmented_entropy), boringssl_shake256);
- const uint8_t *const rho = expanded_seed;
- const uint8_t *const sigma = expanded_seed + RHO_BYTES;
- const uint8_t *const k = expanded_seed + RHO_BYTES + SIGMA_BYTES;
- // rho is public.
- CONSTTIME_DECLASSIFY(rho, RHO_BYTES);
- OPENSSL_memcpy(values->pub.rho, rho, sizeof(values->pub.rho));
- OPENSSL_memcpy(priv->rho, rho, sizeof(priv->rho));
- OPENSSL_memcpy(priv->k, k, sizeof(priv->k));
-
- matrix_expand(&values->a_ntt, rho);
- vector_expand_short(&priv->s1, &priv->s2, sigma);
-
- OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
- vectorl_ntt(&values->s1_ntt);
-
- matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
- vectork_inverse_ntt(&values->t);
- vectork_add(&values->t, &values->t, &priv->s2);
-
- vectork_power2_round(&values->pub.t1, &priv->t0, &values->t);
- // t1 is public.
- CONSTTIME_DECLASSIFY(&values->pub.t1, sizeof(values->pub.t1));
-
- CBB cbb;
- CBB_init_fixed(&cbb, out_encoded_public_key, MLDSA65_PUBLIC_KEY_BYTES);
- if (!mldsa_marshal_public_key(&cbb, &values->pub)) {
- return 0;
- }
- assert(CBB_len(&cbb) == MLDSA65_PUBLIC_KEY_BYTES);
-
- BORINGSSL_keccak(priv->public_key_hash, sizeof(priv->public_key_hash),
- out_encoded_public_key, MLDSA65_PUBLIC_KEY_BYTES,
- boringssl_shake256);
-
- return 1;
+ return mldsa_generate_key_external_entropy(
+ out_encoded_public_key,
+ mldsa65_private_key_from_external(out_private_key), entropy);
}
int MLDSA65_public_from_private(struct MLDSA65_public_key *out_public_key,
const struct MLDSA65_private_key *private_key) {
- // Intermediate values, allocated on the heap to allow use when there is a
- // limited amount of stack.
- struct values_st {
- matrix a_ntt;
- vectorl s1_ntt;
- vectork t;
- vectork t0;
- };
- std::unique_ptr<values_st, DeleterFree<values_st>> values(
- reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
- if (values == NULL) {
- return 0;
- }
-
- const struct private_key *priv = private_key_from_external(private_key);
- struct public_key *pub = public_key_from_external(out_public_key);
-
- OPENSSL_memcpy(pub->rho, priv->rho, sizeof(pub->rho));
- OPENSSL_memcpy(pub->public_key_hash, priv->public_key_hash,
- sizeof(pub->public_key_hash));
-
- matrix_expand(&values->a_ntt, priv->rho);
-
- OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
- vectorl_ntt(&values->s1_ntt);
-
- matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
- vectork_inverse_ntt(&values->t);
- vectork_add(&values->t, &values->t, &priv->s2);
-
- vectork_power2_round(&pub->t1, &values->t0, &values->t);
- return 1;
+ return mldsa_public_from_private(
+ mldsa65_public_key_from_external(out_public_key),
+ mldsa65_private_key_from_external(private_key));
}
-// FIPS 204, Algorithm 7 (`ML-DSA.Sign_internal`). Returns 1 on success and 0 on
-// failure.
int MLDSA65_sign_internal(
uint8_t out_encoded_signature[MLDSA65_SIGNATURE_BYTES],
const struct MLDSA65_private_key *private_key, const uint8_t *msg,
size_t msg_len, const uint8_t *context_prefix, size_t context_prefix_len,
const uint8_t *context, size_t context_len,
const uint8_t randomizer[MLDSA_SIGNATURE_RANDOMIZER_BYTES]) {
- const struct private_key *priv = private_key_from_external(private_key);
-
- uint8_t mu[MU_BYTES];
- struct BORINGSSL_keccak_st keccak_ctx;
- BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
- BORINGSSL_keccak_absorb(&keccak_ctx, priv->public_key_hash,
- sizeof(priv->public_key_hash));
- BORINGSSL_keccak_absorb(&keccak_ctx, context_prefix, context_prefix_len);
- BORINGSSL_keccak_absorb(&keccak_ctx, context, context_len);
- BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
- BORINGSSL_keccak_squeeze(&keccak_ctx, mu, MU_BYTES);
-
- uint8_t rho_prime[RHO_PRIME_BYTES];
- BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
- BORINGSSL_keccak_absorb(&keccak_ctx, priv->k, sizeof(priv->k));
- BORINGSSL_keccak_absorb(&keccak_ctx, randomizer,
- MLDSA_SIGNATURE_RANDOMIZER_BYTES);
- BORINGSSL_keccak_absorb(&keccak_ctx, mu, MU_BYTES);
- BORINGSSL_keccak_squeeze(&keccak_ctx, rho_prime, RHO_PRIME_BYTES);
-
- // Intermediate values, allocated on the heap to allow use when there is a
- // limited amount of stack.
- struct values_st {
- struct signature sign;
- vectorl s1_ntt;
- vectork s2_ntt;
- vectork t0_ntt;
- matrix a_ntt;
- vectorl y;
- vectork w;
- vectork w1;
- vectorl cs1;
- vectork cs2;
- };
- std::unique_ptr<values_st, DeleterFree<values_st>> values(
- reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
- if (values == NULL) {
- return 0;
- }
- OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
- vectorl_ntt(&values->s1_ntt);
-
- OPENSSL_memcpy(&values->s2_ntt, &priv->s2, sizeof(values->s2_ntt));
- vectork_ntt(&values->s2_ntt);
-
- OPENSSL_memcpy(&values->t0_ntt, &priv->t0, sizeof(values->t0_ntt));
- vectork_ntt(&values->t0_ntt);
-
- matrix_expand(&values->a_ntt, priv->rho);
-
- // kappa must not exceed 2**16/L = 13107. But the probability of it exceeding
- // even 1000 iterations is vanishingly small.
- for (size_t kappa = 0;; kappa += L) {
- vectorl_expand_mask(&values->y, rho_prime, kappa);
-
- vectorl *y_ntt = &values->cs1;
- OPENSSL_memcpy(y_ntt, &values->y, sizeof(*y_ntt));
- vectorl_ntt(y_ntt);
-
- matrix_mult(&values->w, &values->a_ntt, y_ntt);
- vectork_inverse_ntt(&values->w);
-
- vectork_high_bits(&values->w1, &values->w);
- uint8_t w1_encoded[128 * K];
- w1_encode(w1_encoded, &values->w1);
-
- BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
- BORINGSSL_keccak_absorb(&keccak_ctx, mu, MU_BYTES);
- BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
- BORINGSSL_keccak_squeeze(&keccak_ctx, values->sign.c_tilde,
- 2 * LAMBDA_BYTES);
-
- scalar c_ntt;
- scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde,
- sizeof(values->sign.c_tilde));
- scalar_ntt(&c_ntt);
-
- vectorl_mult_scalar(&values->cs1, &values->s1_ntt, &c_ntt);
- vectorl_inverse_ntt(&values->cs1);
- vectork_mult_scalar(&values->cs2, &values->s2_ntt, &c_ntt);
- vectork_inverse_ntt(&values->cs2);
-
- vectorl_add(&values->sign.z, &values->y, &values->cs1);
-
- vectork *r0 = &values->w1;
- vectork_sub(r0, &values->w, &values->cs2);
- vectork_low_bits(r0, r0);
-
- // Leaking the fact that a signature was rejected is fine as the next
- // attempt at a signature will be (indistinguishable from) independent of
- // this one. Note, however, that we additionally leak which of the two
- // branches rejected the signature. Section 5.5 of
- // https://pq-crystals.org/dilithium/data/dilithium-specification-round3.pdf
- // describes this leak as OK. Note we leak less than what is described by
- // the paper; we do not reveal which coefficient violated the bound, and we
- // hide which of the |z_max| or |r0_max| bound failed. See also
- // https://boringssl-review.googlesource.com/c/boringssl/+/67747/comment/2bbab0fa_d241d35a/
- uint32_t z_max = vectorl_max(&values->sign.z);
- uint32_t r0_max = vectork_max_signed(r0);
- if (constant_time_declassify_w(
- constant_time_ge_w(z_max, kGamma1 - BETA) |
- constant_time_ge_w(r0_max, kGamma2 - BETA))) {
- continue;
- }
-
- vectork *ct0 = &values->w1;
- vectork_mult_scalar(ct0, &values->t0_ntt, &c_ntt);
- vectork_inverse_ntt(ct0);
- vectork_make_hint(&values->sign.h, ct0, &values->cs2, &values->w);
-
- // See above.
- uint32_t ct0_max = vectork_max(ct0);
- size_t h_ones = vectork_count_ones(&values->sign.h);
- if (constant_time_declassify_w(constant_time_ge_w(ct0_max, kGamma2) |
- constant_time_lt_w(OMEGA, h_ones))) {
- continue;
- }
-
- // Although computed with the private key, the signature is public.
- CONSTTIME_DECLASSIFY(values->sign.c_tilde, sizeof(values->sign.c_tilde));
- CONSTTIME_DECLASSIFY(&values->sign.z, sizeof(values->sign.z));
- CONSTTIME_DECLASSIFY(&values->sign.h, sizeof(values->sign.h));
-
- CBB cbb;
- CBB_init_fixed(&cbb, out_encoded_signature, MLDSA65_SIGNATURE_BYTES);
- if (!mldsa_marshal_signature(&cbb, &values->sign)) {
- return 0;
- }
-
- BSSL_CHECK(CBB_len(&cbb) == MLDSA65_SIGNATURE_BYTES);
- return 1;
- }
+ return mldsa_sign_internal(out_encoded_signature,
+ mldsa65_private_key_from_external(private_key),
+ msg, msg_len, context_prefix, context_prefix_len,
+ context, context_len, randomizer);
}
-// mldsa signature in randomized mode, filling the random bytes with
+// ML-DSA signature in randomized mode, filling the random bytes with
// |RAND_bytes|. Returns 1 on success and 0 on failure.
int MLDSA65_sign(uint8_t out_encoded_signature[MLDSA65_SIGNATURE_BYTES],
const struct MLDSA65_private_key *private_key,
@@ -1564,108 +1771,18 @@
context, context_len);
}
-// FIPS 204, Algorithm 8 (`ML-DSA.Verify_internal`).
int MLDSA65_verify_internal(
const struct MLDSA65_public_key *public_key,
const uint8_t encoded_signature[MLDSA65_SIGNATURE_BYTES],
const uint8_t *msg, size_t msg_len, const uint8_t *context_prefix,
size_t context_prefix_len, const uint8_t *context, size_t context_len) {
- // Intermediate values, allocated on the heap to allow use when there is a
- // limited amount of stack.
- struct values_st {
- struct signature sign;
- matrix a_ntt;
- vectorl z_ntt;
- vectork az_ntt;
- vectork ct1_ntt;
- };
- std::unique_ptr<values_st, DeleterFree<values_st>> values(
- reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
- if (values == NULL) {
- return 0;
- }
-
- const struct public_key *pub = public_key_from_external(public_key);
-
- CBS cbs;
- CBS_init(&cbs, encoded_signature, MLDSA65_SIGNATURE_BYTES);
- if (!mldsa_parse_signature(&values->sign, &cbs)) {
- return 0;
- }
-
- matrix_expand(&values->a_ntt, pub->rho);
-
- uint8_t mu[MU_BYTES];
- struct BORINGSSL_keccak_st keccak_ctx;
- BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
- BORINGSSL_keccak_absorb(&keccak_ctx, pub->public_key_hash,
- sizeof(pub->public_key_hash));
- BORINGSSL_keccak_absorb(&keccak_ctx, context_prefix, context_prefix_len);
- BORINGSSL_keccak_absorb(&keccak_ctx, context, context_len);
- BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
- BORINGSSL_keccak_squeeze(&keccak_ctx, mu, MU_BYTES);
-
- scalar c_ntt;
- scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde,
- sizeof(values->sign.c_tilde));
- scalar_ntt(&c_ntt);
-
- OPENSSL_memcpy(&values->z_ntt, &values->sign.z, sizeof(values->z_ntt));
- vectorl_ntt(&values->z_ntt);
-
- matrix_mult(&values->az_ntt, &values->a_ntt, &values->z_ntt);
-
- vectork_scale_power2_round(&values->ct1_ntt, &pub->t1);
- vectork_ntt(&values->ct1_ntt);
-
- vectork_mult_scalar(&values->ct1_ntt, &values->ct1_ntt, &c_ntt);
-
- vectork *const w1 = &values->az_ntt;
- vectork_sub(w1, &values->az_ntt, &values->ct1_ntt);
- vectork_inverse_ntt(w1);
-
- vectork_use_hint_vartime(w1, &values->sign.h, w1);
- uint8_t w1_encoded[128 * K];
- w1_encode(w1_encoded, w1);
-
- uint8_t c_tilde[2 * LAMBDA_BYTES];
- BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
- BORINGSSL_keccak_absorb(&keccak_ctx, mu, MU_BYTES);
- BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
- BORINGSSL_keccak_squeeze(&keccak_ctx, c_tilde, 2 * LAMBDA_BYTES);
-
- uint32_t z_max = vectorl_max(&values->sign.z);
- return z_max < kGamma1 - BETA &&
- OPENSSL_memcmp(c_tilde, values->sign.c_tilde, 2 * LAMBDA_BYTES) == 0;
+ return mldsa_verify_internal<6, 5>(
+ mldsa65_public_key_from_external(public_key), encoded_signature, msg,
+ msg_len, context_prefix, context_prefix_len, context, context_len);
}
-/* Serialization of keys. */
-
int MLDSA65_marshal_public_key(CBB *out,
const struct MLDSA65_public_key *public_key) {
- return mldsa_marshal_public_key(out, public_key_from_external(public_key));
-}
-
-int MLDSA65_parse_public_key(struct MLDSA65_public_key *public_key, CBS *in) {
- struct public_key *pub = public_key_from_external(public_key);
- CBS orig_in = *in;
- if (!mldsa_parse_public_key(pub, in) || CBS_len(in) != 0) {
- return 0;
- }
-
- // Compute pre-cached values.
- BORINGSSL_keccak(pub->public_key_hash, sizeof(pub->public_key_hash),
- CBS_data(&orig_in), CBS_len(&orig_in), boringssl_shake256);
- return 1;
-}
-
-int MLDSA65_marshal_private_key(CBB *out,
- const struct MLDSA65_private_key *private_key) {
- return mldsa_marshal_private_key(out, private_key_from_external(private_key));
-}
-
-int MLDSA65_parse_private_key(struct MLDSA65_private_key *private_key,
- CBS *in) {
- struct private_key *priv = private_key_from_external(private_key);
- return mldsa_parse_private_key(priv, in) && CBS_len(in) == 0;
+ return mldsa_marshal_public_key(out,
+ mldsa65_public_key_from_external(public_key));
}
diff --git a/crypto/mldsa/mldsa_test.cc b/crypto/mldsa/mldsa_test.cc
index c9b0828..56035ef 100644
--- a/crypto/mldsa/mldsa_test.cc
+++ b/crypto/mldsa/mldsa_test.cc
@@ -264,6 +264,11 @@
FileTestGTest("crypto/mldsa/mldsa_nist_keygen_tests.txt", MLDSAKeyGenTest);
}
+template <typename PrivateKey, int (*ParsePrivateKey)(PrivateKey *, CBS *),
+ size_t SignatureBytes,
+ int (*SignInternal)(uint8_t *, const PrivateKey *, const uint8_t *,
+ size_t, const uint8_t *, size_t, const uint8_t *,
+ size_t, const uint8_t *)>
static void MLDSAWycheproofSignTest(FileTest *t) {
std::vector<uint8_t> private_key_bytes, msg, expected_signature, context;
ASSERT_TRUE(t->GetInstructionBytes(&private_key_bytes, "privateKey"));
@@ -278,8 +283,8 @@
CBS cbs;
CBS_init(&cbs, private_key_bytes.data(), private_key_bytes.size());
- auto priv = std::make_unique<MLDSA65_private_key>();
- const int priv_ok = MLDSA65_parse_private_key(priv.get(), &cbs);
+ auto priv = std::make_unique<PrivateKey>();
+ const int priv_ok = ParsePrivateKey(priv.get(), &cbs);
if (!priv_ok) {
ASSERT_TRUE(result != "valid");
@@ -295,21 +300,25 @@
}
const uint8_t zero_randomizer[MLDSA_SIGNATURE_RANDOMIZER_BYTES] = {0};
- std::vector<uint8_t> signature(MLDSA65_SIGNATURE_BYTES);
+ std::vector<uint8_t> signature(SignatureBytes);
const uint8_t context_prefix[2] = {0, static_cast<uint8_t>(context.size())};
- EXPECT_TRUE(MLDSA65_sign_internal(
- signature.data(), priv.get(), msg.data(), msg.size(), context_prefix,
- sizeof(context_prefix), context.data(), context.size(), zero_randomizer));
+ EXPECT_TRUE(SignInternal(signature.data(), priv.get(), msg.data(), msg.size(),
+ context_prefix, sizeof(context_prefix),
+ context.data(), context.size(), zero_randomizer));
EXPECT_EQ(Bytes(signature), Bytes(expected_signature));
}
-TEST(MLDSATest, WycheproofSignTests) {
+TEST(MLDSATest, WycheproofSignTests65) {
FileTestGTest(
"third_party/wycheproof_testvectors/mldsa_65_standard_sign_test.txt",
- MLDSAWycheproofSignTest);
+ MLDSAWycheproofSignTest<MLDSA65_private_key, MLDSA65_parse_private_key,
+ MLDSA65_SIGNATURE_BYTES, MLDSA65_sign_internal>);
}
+template <typename PublicKey, int (*ParsePublicKey)(PublicKey *, CBS *),
+ int (*Verify)(const PublicKey *, const uint8_t *, size_t,
+ const uint8_t *, size_t, const uint8_t *, size_t)>
static void MLDSAWycheproofVerifyTest(FileTest *t) {
std::vector<uint8_t> public_key_bytes, msg, signature, context;
ASSERT_TRUE(t->GetInstructionBytes(&public_key_bytes, "publicKey"));
@@ -324,8 +333,8 @@
CBS cbs;
CBS_init(&cbs, public_key_bytes.data(), public_key_bytes.size());
- auto pub = std::make_unique<MLDSA65_public_key>();
- const int pub_ok = MLDSA65_parse_public_key(pub.get(), &cbs);
+ auto pub = std::make_unique<PublicKey>();
+ const int pub_ok = ParsePublicKey(pub.get(), &cbs);
if (!pub_ok) {
EXPECT_EQ(flags, "IncorrectPublicKeyLength");
@@ -333,8 +342,8 @@
}
const int sig_ok =
- MLDSA65_verify(pub.get(), signature.data(), signature.size(), msg.data(),
- msg.size(), context.data(), context.size());
+ Verify(pub.get(), signature.data(), signature.size(), msg.data(),
+ msg.size(), context.data(), context.size());
if (!sig_ok) {
EXPECT_EQ(result, "invalid");
} else {
@@ -342,10 +351,12 @@
}
}
-TEST(MLDSATest, WycheproofVerifyTests) {
+
+TEST(MLDSATest, WycheproofVerifyTests65) {
FileTestGTest(
"third_party/wycheproof_testvectors/mldsa_65_standard_verify_test.txt",
- MLDSAWycheproofVerifyTest);
+ MLDSAWycheproofVerifyTest<MLDSA65_public_key, MLDSA65_parse_public_key,
+ MLDSA65_verify>);
}
} // namespace
diff --git a/include/openssl/mldsa.h b/include/openssl/mldsa.h
index a0a7560..3be22f3 100644
--- a/include/openssl/mldsa.h
+++ b/include/openssl/mldsa.h
@@ -22,12 +22,18 @@
#endif
-// ML-DSA-65.
+// ML-DSA.
//
// This implements the Module-Lattice-Based Digital Signature Standard from
// https://csrc.nist.gov/pubs/fips/204/final
+// MLDSA_SEED_BYTES is the number of bytes in an ML-DSA seed value.
+#define MLDSA_SEED_BYTES 32
+
+
+// ML-DSA-65.
+
// MLDSA65_private_key contains an ML-DSA-65 private key. The contents of this
// object should never leave the address space since the format is unstable.
struct MLDSA65_private_key {
@@ -58,9 +64,6 @@
// signature.
#define MLDSA65_SIGNATURE_BYTES 3309
-// MLDSA_SEED_BYTES is the number of bytes in an ML-DSA seed value.
-#define MLDSA_SEED_BYTES 32
-
// MLDSA65_generate_key generates a random public/private key pair, writes the
// encoded public key to |out_encoded_public_key|, writes the seed to
// |out_seed|, and sets |out_private_key| to the private key. Returns 1 on
@@ -106,9 +109,6 @@
size_t msg_len, const uint8_t *context,
size_t context_len);
-
-// Serialisation of keys.
-
// MLDSA65_marshal_public_key serializes |public_key| to |out| in the standard
// format for ML-DSA-65 public keys. It returns 1 on success or 0 on
// allocation error.