Tidy up the PSK binder logic.
Computing the binders on ClientHelloInner is a little interesting. While
I'm in the area, tidy this up a bit. The exploded parameters may as well
be an SSL_SESSION, and hash_transcript_and_truncated_client_hello can
just get folded in.
Change-Id: I9d3a7e0ae9f391d6b9a23b51b5d7198e15569b11
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/47997
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/internal.h b/ssl/internal.h
index 72e1fba..6a1a650 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -700,9 +700,9 @@
// the transcript. It returns true on success and false on failure. If the
// handshake buffer is still present, |digest| may be any supported digest.
// Otherwise, |digest| must match the transcript hash.
- bool CopyToHashContext(EVP_MD_CTX *ctx, const EVP_MD *digest);
+ bool CopyToHashContext(EVP_MD_CTX *ctx, const EVP_MD *digest) const;
- Span<const uint8_t> buffer() {
+ Span<const uint8_t> buffer() const {
return MakeConstSpan(reinterpret_cast<const uint8_t *>(buffer_->data),
buffer_->length);
}
@@ -725,14 +725,14 @@
// GetHash writes the handshake hash to |out| which must have room for at
// least |DigestLen| bytes. On success, it returns true and sets |*out_len| to
// the number of bytes written. Otherwise, it returns false.
- bool GetHash(uint8_t *out, size_t *out_len);
+ bool GetHash(uint8_t *out, size_t *out_len) const;
// GetFinishedMAC computes the MAC for the Finished message into the bytes
// pointed by |out| and writes the number of bytes to |*out_len|. |out| must
// have room for |EVP_MAX_MD_SIZE| bytes. It returns true on success and false
// on failure.
bool GetFinishedMAC(uint8_t *out, size_t *out_len, const SSL_SESSION *session,
- bool from_server);
+ bool from_server) const;
private:
// buffer_, if non-null, contains the handshake transcript.
@@ -1418,13 +1418,14 @@
// tls13_write_psk_binder calculates the PSK binder value and replaces the last
// bytes of |msg| with the resulting value. It returns true on success, and
// false on failure.
-bool tls13_write_psk_binder(SSL_HANDSHAKE *hs, Span<uint8_t> msg);
+bool tls13_write_psk_binder(const SSL_HANDSHAKE *hs, Span<uint8_t> msg);
// tls13_verify_psk_binder verifies that the handshake transcript, truncated up
// to the binders has a valid signature using the value of |session|'s
// resumption secret. It returns true on success, and false on failure.
-bool tls13_verify_psk_binder(SSL_HANDSHAKE *hs, SSL_SESSION *session,
- const SSLMessage &msg, CBS *binders);
+bool tls13_verify_psk_binder(const SSL_HANDSHAKE *hs,
+ const SSL_SESSION *session, const SSLMessage &msg,
+ CBS *binders);
// Encrypted ClientHello.
diff --git a/ssl/ssl_transcript.cc b/ssl/ssl_transcript.cc
index 0bc13b9..1599c80 100644
--- a/ssl/ssl_transcript.cc
+++ b/ssl/ssl_transcript.cc
@@ -206,7 +206,8 @@
return true;
}
-bool SSLTranscript::CopyToHashContext(EVP_MD_CTX *ctx, const EVP_MD *digest) {
+bool SSLTranscript::CopyToHashContext(EVP_MD_CTX *ctx,
+ const EVP_MD *digest) const {
const EVP_MD *transcript_digest = Digest();
if (transcript_digest != nullptr &&
EVP_MD_type(transcript_digest) == EVP_MD_type(digest)) {
@@ -237,7 +238,7 @@
return true;
}
-bool SSLTranscript::GetHash(uint8_t *out, size_t *out_len) {
+bool SSLTranscript::GetHash(uint8_t *out, size_t *out_len) const {
ScopedEVP_MD_CTX ctx;
unsigned len;
if (!EVP_MD_CTX_copy_ex(ctx.get(), hash_.get()) ||
@@ -250,7 +251,7 @@
bool SSLTranscript::GetFinishedMAC(uint8_t *out, size_t *out_len,
const SSL_SESSION *session,
- bool from_server) {
+ bool from_server) const {
static const char kClientLabel[] = "client finished";
static const char kServerLabel[] = "server finished";
auto label = from_server
diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc
index 9c3063f..2ac9985 100644
--- a/ssl/tls13_enc.cc
+++ b/ssl/tls13_enc.cc
@@ -395,29 +395,52 @@
static const char kTLS13LabelPSKBinder[] = "res binder";
-static bool tls13_psk_binder(uint8_t *out, size_t *out_len, uint16_t version,
- const EVP_MD *digest, Span<const uint8_t> psk,
- Span<const uint8_t> context) {
+static bool tls13_psk_binder(uint8_t *out, size_t *out_len,
+ const SSL_SESSION *session,
+ const SSLTranscript &transcript,
+ Span<const uint8_t> client_hello,
+ size_t binders_len) {
+ const EVP_MD *digest = ssl_session_get_digest(session);
+
+ // Compute the binder key.
+ //
+ // TODO(davidben): Ideally we wouldn't recompute early secret and the binder
+ // key each time.
uint8_t binder_context[EVP_MAX_MD_SIZE];
unsigned binder_context_len;
- if (!EVP_Digest(NULL, 0, binder_context, &binder_context_len, digest, NULL)) {
- return false;
- }
-
uint8_t early_secret[EVP_MAX_MD_SIZE] = {0};
size_t early_secret_len;
- if (!HKDF_extract(early_secret, &early_secret_len, digest, psk.data(),
- psk.size(), NULL, 0)) {
+ uint8_t binder_key_buf[EVP_MAX_MD_SIZE] = {0};
+ auto binder_key = MakeSpan(binder_key_buf, EVP_MD_size(digest));
+ if (!EVP_Digest(nullptr, 0, binder_context, &binder_context_len, digest,
+ nullptr) ||
+ !HKDF_extract(early_secret, &early_secret_len, digest, session->secret,
+ session->secret_length, nullptr, 0) ||
+ !hkdf_expand_label(binder_key, digest,
+ MakeConstSpan(early_secret, early_secret_len),
+ label_to_span(kTLS13LabelPSKBinder),
+ MakeConstSpan(binder_context, binder_context_len))) {
return false;
}
- uint8_t binder_key_buf[EVP_MAX_MD_SIZE] = {0};
- auto binder_key = MakeSpan(binder_key_buf, EVP_MD_size(digest));
- if (!hkdf_expand_label(binder_key, digest,
- MakeConstSpan(early_secret, early_secret_len),
- label_to_span(kTLS13LabelPSKBinder),
- MakeConstSpan(binder_context, binder_context_len)) ||
- !tls13_verify_data(out, out_len, digest, version, binder_key, context)) {
+ // Hash the transcript and truncated ClientHello.
+ if (client_hello.size() < binders_len) {
+ OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+ return false;
+ }
+ auto truncated = client_hello.subspan(0, client_hello.size() - binders_len);
+ uint8_t context[EVP_MAX_MD_SIZE];
+ unsigned context_len;
+ ScopedEVP_MD_CTX ctx;
+ if (!transcript.CopyToHashContext(ctx.get(), digest) ||
+ !EVP_DigestUpdate(ctx.get(), truncated.data(),
+ truncated.size()) ||
+ !EVP_DigestFinal_ex(ctx.get(), context, &context_len)) {
+ return false;
+ }
+
+ if (!tls13_verify_data(out, out_len, digest, session->ssl_version, binder_key,
+ MakeConstSpan(context, context_len))) {
return false;
}
@@ -425,44 +448,19 @@
return true;
}
-static bool hash_transcript_and_truncated_client_hello(
- SSL_HANDSHAKE *hs, uint8_t *out, size_t *out_len, const EVP_MD *digest,
- Span<const uint8_t> client_hello, size_t binders_len) {
- // Truncate the ClientHello.
- if (binders_len + 2 < binders_len || client_hello.size() < binders_len + 2) {
- return false;
- }
- client_hello = client_hello.subspan(0, client_hello.size() - binders_len - 2);
-
- ScopedEVP_MD_CTX ctx;
- unsigned len;
- if (!hs->transcript.CopyToHashContext(ctx.get(), digest) ||
- !EVP_DigestUpdate(ctx.get(), client_hello.data(), client_hello.size()) ||
- !EVP_DigestFinal_ex(ctx.get(), out, &len)) {
- return false;
- }
-
- *out_len = len;
- return true;
-}
-
-bool tls13_write_psk_binder(SSL_HANDSHAKE *hs, Span<uint8_t> msg) {
- SSL *const ssl = hs->ssl;
+bool tls13_write_psk_binder(const SSL_HANDSHAKE *hs,
+ Span<uint8_t> msg) {
+ const SSL *const ssl = hs->ssl;
const EVP_MD *digest = ssl_session_get_digest(ssl->session.get());
- size_t hash_len = EVP_MD_size(digest);
-
- ScopedEVP_MD_CTX ctx;
- uint8_t context[EVP_MAX_MD_SIZE];
- size_t context_len;
+ const size_t hash_len = EVP_MD_size(digest);
+ // We only offer one PSK, so the binders are a u16 and u8 length
+ // prefix, followed by the binder. The caller is assumed to have constructed
+ // |msg| with placeholder binders.
+ const size_t binders_len = 3 + hash_len;
uint8_t verify_data[EVP_MAX_MD_SIZE];
size_t verify_data_len;
- if (!hash_transcript_and_truncated_client_hello(
- hs, context, &context_len, digest, msg,
- 1 /* length prefix */ + hash_len) ||
- !tls13_psk_binder(
- verify_data, &verify_data_len, ssl->session->ssl_version, digest,
- MakeConstSpan(ssl->session->secret, ssl->session->secret_length),
- MakeConstSpan(context, context_len)) ||
+ if (!tls13_psk_binder(verify_data, &verify_data_len, ssl->session.get(),
+ hs->transcript, msg, binders_len) ||
verify_data_len != hash_len) {
OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
return false;
@@ -473,20 +471,17 @@
return true;
}
-bool tls13_verify_psk_binder(SSL_HANDSHAKE *hs, SSL_SESSION *session,
- const SSLMessage &msg, CBS *binders) {
- uint8_t context[EVP_MAX_MD_SIZE];
- size_t context_len;
+bool tls13_verify_psk_binder(const SSL_HANDSHAKE *hs,
+ const SSL_SESSION *session, const SSLMessage &msg,
+ CBS *binders) {
uint8_t verify_data[EVP_MAX_MD_SIZE];
size_t verify_data_len;
CBS binder;
- if (!hash_transcript_and_truncated_client_hello(hs, context, &context_len,
- hs->transcript.Digest(),
- msg.raw, CBS_len(binders)) ||
- !tls13_psk_binder(verify_data, &verify_data_len, hs->ssl->version,
- hs->transcript.Digest(),
- MakeConstSpan(session->secret, session->secret_length),
- MakeConstSpan(context, context_len)) ||
+ // The binders are computed over |msg| with |binders| and its u16 length
+ // prefix removed. The caller is assumed to have parsed |msg|, extracted
+ // |binders|, and verified the PSK extension is last.
+ if (!tls13_psk_binder(verify_data, &verify_data_len, session, hs->transcript,
+ msg.raw, 2 + CBS_len(binders)) ||
// We only consider the first PSK, so compare against the first binder.
!CBS_get_u8_length_prefixed(binders, &binder)) {
OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);