Tidy up the PSK binder logic.

Computing the binders on ClientHelloInner is a little subtle. While I'm
in the area, tidy this up a bit: the exploded parameters (version,
digest, and PSK) may as well be a single SSL_SESSION, and
hash_transcript_and_truncated_client_hello can just be folded into
tls13_psk_binder.

Change-Id: I9d3a7e0ae9f391d6b9a23b51b5d7198e15569b11
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/47997
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc
index 9c3063f..2ac9985 100644
--- a/ssl/tls13_enc.cc
+++ b/ssl/tls13_enc.cc
@@ -395,29 +395,52 @@
 
 static const char kTLS13LabelPSKBinder[] = "res binder";
 
-static bool tls13_psk_binder(uint8_t *out, size_t *out_len, uint16_t version,
-                             const EVP_MD *digest, Span<const uint8_t> psk,
-                             Span<const uint8_t> context) {
+static bool tls13_psk_binder(uint8_t *out, size_t *out_len,
+                             const SSL_SESSION *session,
+                             const SSLTranscript &transcript,
+                             Span<const uint8_t> client_hello,
+                             size_t binders_len) {
+  const EVP_MD *digest = ssl_session_get_digest(session);
+
+  // Compute the binder key.
+  //
+  // TODO(davidben): Ideally we wouldn't recompute early secret and the binder
+  // key each time.
   uint8_t binder_context[EVP_MAX_MD_SIZE];
   unsigned binder_context_len;
-  if (!EVP_Digest(NULL, 0, binder_context, &binder_context_len, digest, NULL)) {
-    return false;
-  }
-
   uint8_t early_secret[EVP_MAX_MD_SIZE] = {0};
   size_t early_secret_len;
-  if (!HKDF_extract(early_secret, &early_secret_len, digest, psk.data(),
-                    psk.size(), NULL, 0)) {
+  uint8_t binder_key_buf[EVP_MAX_MD_SIZE] = {0};
+  auto binder_key = MakeSpan(binder_key_buf, EVP_MD_size(digest));
+  if (!EVP_Digest(nullptr, 0, binder_context, &binder_context_len, digest,
+                  nullptr) ||
+      !HKDF_extract(early_secret, &early_secret_len, digest, session->secret,
+                    session->secret_length, nullptr, 0) ||
+      !hkdf_expand_label(binder_key, digest,
+                         MakeConstSpan(early_secret, early_secret_len),
+                         label_to_span(kTLS13LabelPSKBinder),
+                         MakeConstSpan(binder_context, binder_context_len))) {
     return false;
   }
 
-  uint8_t binder_key_buf[EVP_MAX_MD_SIZE] = {0};
-  auto binder_key = MakeSpan(binder_key_buf, EVP_MD_size(digest));
-  if (!hkdf_expand_label(binder_key, digest,
-                         MakeConstSpan(early_secret, early_secret_len),
-                         label_to_span(kTLS13LabelPSKBinder),
-                         MakeConstSpan(binder_context, binder_context_len)) ||
-      !tls13_verify_data(out, out_len, digest, version, binder_key, context)) {
+  // Hash the transcript and truncated ClientHello.
+  if (client_hello.size() < binders_len) {
+    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+    return false;
+  }
+  auto truncated = client_hello.subspan(0, client_hello.size() - binders_len);
+  uint8_t context[EVP_MAX_MD_SIZE];
+  unsigned context_len;
+  ScopedEVP_MD_CTX ctx;
+  if (!transcript.CopyToHashContext(ctx.get(), digest) ||
+      !EVP_DigestUpdate(ctx.get(), truncated.data(),
+                        truncated.size()) ||
+      !EVP_DigestFinal_ex(ctx.get(), context, &context_len)) {
+    return false;
+  }
+
+  if (!tls13_verify_data(out, out_len, digest, session->ssl_version, binder_key,
+                         MakeConstSpan(context, context_len))) {
     return false;
   }
 
@@ -425,44 +448,19 @@
   return true;
 }
 
-static bool hash_transcript_and_truncated_client_hello(
-    SSL_HANDSHAKE *hs, uint8_t *out, size_t *out_len, const EVP_MD *digest,
-    Span<const uint8_t> client_hello, size_t binders_len) {
-  // Truncate the ClientHello.
-  if (binders_len + 2 < binders_len || client_hello.size() < binders_len + 2) {
-    return false;
-  }
-  client_hello = client_hello.subspan(0, client_hello.size() - binders_len - 2);
-
-  ScopedEVP_MD_CTX ctx;
-  unsigned len;
-  if (!hs->transcript.CopyToHashContext(ctx.get(), digest) ||
-      !EVP_DigestUpdate(ctx.get(), client_hello.data(), client_hello.size()) ||
-      !EVP_DigestFinal_ex(ctx.get(), out, &len)) {
-    return false;
-  }
-
-  *out_len = len;
-  return true;
-}
-
-bool tls13_write_psk_binder(SSL_HANDSHAKE *hs, Span<uint8_t> msg) {
-  SSL *const ssl = hs->ssl;
+bool tls13_write_psk_binder(const SSL_HANDSHAKE *hs,
+                            Span<uint8_t> msg) {
+  const SSL *const ssl = hs->ssl;
   const EVP_MD *digest = ssl_session_get_digest(ssl->session.get());
-  size_t hash_len = EVP_MD_size(digest);
-
-  ScopedEVP_MD_CTX ctx;
-  uint8_t context[EVP_MAX_MD_SIZE];
-  size_t context_len;
+  const size_t hash_len = EVP_MD_size(digest);
+  // We only offer one PSK, so the binders are a u16 and u8 length
+  // prefix, followed by the binder. The caller is assumed to have constructed
+  // |msg| with placeholder binders.
+  const size_t binders_len = 3 + hash_len;
   uint8_t verify_data[EVP_MAX_MD_SIZE];
   size_t verify_data_len;
-  if (!hash_transcript_and_truncated_client_hello(
-          hs, context, &context_len, digest, msg,
-          1 /* length prefix */ + hash_len) ||
-      !tls13_psk_binder(
-          verify_data, &verify_data_len, ssl->session->ssl_version, digest,
-          MakeConstSpan(ssl->session->secret, ssl->session->secret_length),
-          MakeConstSpan(context, context_len)) ||
+  if (!tls13_psk_binder(verify_data, &verify_data_len, ssl->session.get(),
+                        hs->transcript, msg, binders_len) ||
       verify_data_len != hash_len) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return false;
@@ -473,20 +471,17 @@
   return true;
 }
 
-bool tls13_verify_psk_binder(SSL_HANDSHAKE *hs, SSL_SESSION *session,
-                             const SSLMessage &msg, CBS *binders) {
-  uint8_t context[EVP_MAX_MD_SIZE];
-  size_t context_len;
+bool tls13_verify_psk_binder(const SSL_HANDSHAKE *hs,
+                             const SSL_SESSION *session, const SSLMessage &msg,
+                             CBS *binders) {
   uint8_t verify_data[EVP_MAX_MD_SIZE];
   size_t verify_data_len;
   CBS binder;
-  if (!hash_transcript_and_truncated_client_hello(hs, context, &context_len,
-                                                  hs->transcript.Digest(),
-                                                  msg.raw, CBS_len(binders)) ||
-      !tls13_psk_binder(verify_data, &verify_data_len, hs->ssl->version,
-                        hs->transcript.Digest(),
-                        MakeConstSpan(session->secret, session->secret_length),
-                        MakeConstSpan(context, context_len)) ||
+  // The binders are computed over |msg| with |binders| and its u16 length
+  // prefix removed. The caller is assumed to have parsed |msg|, extracted
+  // |binders|, and verified the PSK extension is last.
+  if (!tls13_psk_binder(verify_data, &verify_data_len, session, hs->transcript,
+                        msg.raw, 2 + CBS_len(binders)) ||
       // We only consider the first PSK, so compare against the first binder.
       !CBS_get_u8_length_prefixed(binders, &binder)) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);