Simplify PSK binder calculation

The PSK extension has a really annoying step. It contains a list of
"binders" whose values depend on a hash of the entire ClientHello,
truncated just before the binders list at the end. This means they depend
on two length prefixes that, using our CBB API, have not yet been closed out:

1. The length prefix on the extension list
2. The length prefix on the entire ClientHello message

As a result, we implemented this by writing a zero-filled placeholder
binder, and then post-processing the finished message to fill in the real
value. This got messier with ECH, which constructs lots of different
ClientHellos and ClientHello-like structures. It will get even messier with
external PSKs, which will require us to compute potentially multiple binders.

Now that the PSK extension is treated fairly specially anyway, we can
change its calling convention and fold the logic into the PSK extension.
This is effectively several steps, but the intermediate states are
awkward, so this is one CL.

First, we take care of the message header by having the binder
calculation compute the header on demand. This was already a bit hairy
in DTLS, where the in-memory message carries the longer DTLS handshake
header, which had to be juggled back into the TLS-style header for the
transcript hash. Now that is moot.

That lets us push binder calculation into ssl_add_clienthello_tlsext, as
that is passed a CBB that will ultimately contain the ClientHello, minus
message header. But ssl_add_clienthello_tlsext bifurcates a bit between
ECH and non-ECH. Ideally we'd push it one layer deeper.

Next we change the ext_pre_shared_key_add_clienthello calling convention
to take two generations of CBB (a parent and its child) into the same
function: the unfinished ClientHello and the unfinished extensions block.
This is a bit unusual but lets it perform all three steps together:

1. Write out the PSK extension, with a placeholder binder
2. Close out the extension block to make the ClientHello coherent
3. Replace the placeholder binder with the real binder, computed over
   the now-coherent ClientHello.

This should make it a lot easier to add more complex PSK support as it
can all be encapsulated in that one function. We now don't need to
thread a needs_psk_binder output, though we do need to return the length
of the PSK extension to help ECH copy it in two places. (That could
probably be avoided with more math, but this seemed simpler. The main
issue is that, after ext_pre_shared_key_add_clienthello returns,
CBB_len(&extensions) is no longer usable.)

Bug: 369963041
Change-Id: I16d66567bd4eec84397b0e8e05df57bb257d3b7e
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/88409
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Lily Chen <chlily@google.com>
diff --git a/ssl/encrypted_client_hello.cc b/ssl/encrypted_client_hello.cc
index f88e236..44a3c03 100644
--- a/ssl/encrypted_client_hello.cc
+++ b/ssl/encrypted_client_hello.cc
@@ -793,7 +793,6 @@
   // draft-ietf-tls-esni-13, sections 5.1 and 6.1.
   ScopedCBB cbb, encoded_cbb;
   CBB body;
-  bool needs_psk_binder;
   Array<uint8_t> hello_inner;
   if (!ssl->method->init_message(ssl, cbb.get(), &body, SSL3_MT_CLIENT_HELLO) ||
       !CBB_init(encoded_cbb.get(), 256) ||
@@ -804,25 +803,12 @@
                                                  ssl_client_hello_inner,
                                                  /*empty_session_id=*/true) ||
       !ssl_add_clienthello_tlsext(hs, &body, encoded_cbb.get(),
-                                  &needs_psk_binder, ssl_client_hello_inner) ||
+                                  ssl_client_hello_inner) ||
       !ssl->method->finish_message(ssl, cbb.get(), &hello_inner)) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return false;
   }
 
-  if (needs_psk_binder) {
-    size_t binder_len;
-    if (!tls13_write_psk_binder(hs, hs->inner_transcript, Span(hello_inner),
-                                &binder_len)) {
-      return false;
-    }
-    // Also update the EncodedClientHelloInner.
-    auto encoded_binder = CBBAsSpan(encoded_cbb.get()).last(binder_len);
-    auto hello_inner_binder = Span(hello_inner).last(binder_len);
-    OPENSSL_memcpy(encoded_binder.data(), hello_inner_binder.data(),
-                   binder_len);
-  }
-
   ssl_do_msg_callback(ssl, /*is_write=*/1, SSL3_RT_CLIENT_HELLO_INNER,
                       hello_inner);
   if (!hs->inner_transcript.Update(hello_inner)) {
@@ -879,15 +865,11 @@
                                                  ssl_client_hello_outer,
                                                  /*empty_session_id=*/false) ||
       !ssl_add_clienthello_tlsext(hs, aad.get(), /*out_encoded=*/nullptr,
-                                  &needs_psk_binder, ssl_client_hello_outer)) {
+                                  ssl_client_hello_outer)) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return false;
   }
 
-  // ClientHelloOuter may not require a PSK binder. Otherwise, we have a
-  // circular dependency.
-  assert(!needs_psk_binder);
-
   // Replace the payload in |hs->ech_client_outer| with the encrypted value.
   auto payload_span = Span(hs->ech_client_outer).last(payload_len);
   if (CRYPTO_fuzzer_mode_enabled()) {
diff --git a/ssl/extensions.cc b/ssl/extensions.cc
index 36b9986..daf6435 100644
--- a/ssl/extensions.cc
+++ b/ssl/extensions.cc
@@ -1914,26 +1914,37 @@
   return 15 + ssl->session->ticket.size() + binder_len;
 }
 
+// ext_pre_shared_key_add_clienthello writes a pre_shared_key extension to
+// |out_extensions| and flushes |out_client_hello|, invalidating
+// |out_extensions|. |out_extensions| must be a child of |out_client_hello|.
+//
+// This function differs from other |CBB| functions because it needs to
+// accommodate PSK binders. It must write the PSK extension, flush the |CBB| to
+// write out a length prefix, and then finally sample the whole ClientHello.
 static bool ext_pre_shared_key_add_clienthello(const SSL_HANDSHAKE *hs,
-                                               CBB *out, bool *out_needs_binder,
+                                               CBB *out_client_hello,
+                                               CBB *out_extensions,
+                                               size_t *out_psk_len,
                                                ssl_client_hello_type_t type) {
   const SSL *const ssl = hs->ssl;
-  *out_needs_binder = false;
   if (!should_offer_psk(hs, type)) {
-    return true;
+    *out_psk_len = 0;
+    // Discard empty extensions blocks.
+    if (CBB_len(out_extensions) == 0) {
+      CBB_discard_child(out_client_hello);
+    }
+    return CBB_flush(out_client_hello);
   }
 
   OPENSSL_timeval now = ssl_ctx_get_current_time(ssl->ctx.get());
   uint32_t ticket_age = 1000 * (now.tv_sec - ssl->session->time);
   uint32_t obfuscated_ticket_age = ticket_age + ssl->session->ticket_age_add;
 
-  // Fill in a placeholder zero binder of the appropriate length. It will be
-  // computed and filled in later after length prefixes are computed.
   size_t binder_len = EVP_MD_size(ssl_session_get_digest(ssl->session.get()));
-
+  const size_t len_before = CBB_len(out_extensions);
   CBB contents, identity, ticket, binders, binder;
-  if (!CBB_add_u16(out, TLSEXT_TYPE_pre_shared_key) ||
-      !CBB_add_u16_length_prefixed(out, &contents) ||
+  if (!CBB_add_u16(out_extensions, TLSEXT_TYPE_pre_shared_key) ||
+      !CBB_add_u16_length_prefixed(out_extensions, &contents) ||
       !CBB_add_u16_length_prefixed(&contents, &identity) ||
       !CBB_add_u16_length_prefixed(&identity, &ticket) ||
       !CBB_add_bytes(&ticket, ssl->session->ticket.data(),
@@ -1941,12 +1952,24 @@
       !CBB_add_u32(&identity, obfuscated_ticket_age) ||
       !CBB_add_u16_length_prefixed(&contents, &binders) ||
       !CBB_add_u8_length_prefixed(&binders, &binder) ||
-      !CBB_add_zeros(&binder, binder_len)) {
+      // Fill in a placeholder zero binder of the appropriate length. It will be
+      // computed and filled in later after length prefixes are computed.
+      !CBB_add_zeros(&binder, binder_len) ||  //
+      !CBB_flush(out_extensions)) {
+    return false;
+  }
+  // Sample the length of the PSK extension.
+  *out_psk_len = CBB_len(out_extensions) - len_before;
+
+  // Close |out_extensions| and fill in the binder.
+  const auto &transcript =
+      type == ssl_client_hello_inner ? hs->inner_transcript : hs->transcript;
+  if (!CBB_flush(out_client_hello) ||
+      !tls13_write_psk_binder(hs, transcript, CBBAsSpan(out_client_hello))) {
     return false;
   }
 
-  *out_needs_binder = true;
-  return CBB_flush(out);
+  return true;
 }
 
 bool ssl_ext_pre_shared_key_parse_serverhello(SSL_HANDSHAKE *hs,
@@ -3810,8 +3833,7 @@
 }
 
 static bool ssl_add_clienthello_tlsext_inner(SSL_HANDSHAKE *hs, CBB *out,
-                                             CBB *out_encoded,
-                                             bool *out_needs_psk_binder) {
+                                             CBB *out_encoded) {
   // When writing ClientHelloInner, we construct the real and encoded
   // ClientHellos concurrently, to handle compression. Uncompressed extensions
   // are written to |extensions| and copied to |extensions_encoded|. Compressed
@@ -3901,15 +3923,16 @@
     }
   }
 
-  // The PSK extension must be last. It is never compressed. Note, if there is a
-  // binder, the caller will need to update both ClientHelloInner and
-  // EncodedClientHelloInner after computing it.
-  const size_t len_before = CBB_len(&extensions);
-  if (!ext_pre_shared_key_add_clienthello(hs, &extensions, out_needs_psk_binder,
-                                          ssl_client_hello_inner) ||
-      !CBB_add_bytes(&extensions_encoded, CBB_data(&extensions) + len_before,
-                     CBB_len(&extensions) - len_before) ||
-      !CBB_flush(out) ||  //
+  // The PSK extension must be last. It is never compressed.
+  size_t psk_len;
+  if (!ext_pre_shared_key_add_clienthello(hs, out, &extensions, &psk_len,
+                                          ssl_client_hello_inner)) {
+    return false;
+  }
+
+  // Copy the PSK extension to EncodedClientHelloInner.
+  auto psk = CBBAsSpan(out).last(psk_len);
+  if (!CBB_add_bytes(&extensions_encoded, psk.data(), psk.size()) ||
       !CBB_flush(out_encoded)) {
     return false;
   }
@@ -3918,18 +3941,14 @@
 }
 
 bool ssl_add_clienthello_tlsext(SSL_HANDSHAKE *hs, CBB *out, CBB *out_encoded,
-                                bool *out_needs_psk_binder,
                                 ssl_client_hello_type_t type) {
-  *out_needs_psk_binder = false;
-
   // |out| must contain the start of a ClientHello, which means it must begin
   // with a TLS or DTLS version.
   assert(CBB_len(out) != 0 && (CBB_data(out)[0] == SSL3_VERSION_MAJOR ||
                                CBB_data(out)[0] == DTLS1_VERSION_MAJOR));
 
   if (type == ssl_client_hello_inner) {
-    return ssl_add_clienthello_tlsext_inner(hs, out, out_encoded,
-                                            out_needs_psk_binder);
+    return ssl_add_clienthello_tlsext_inner(hs, out, out_encoded);
   }
 
   // Sample the length of the ClientHello thus far, including the message
@@ -3988,10 +4007,10 @@
   // In cleartext ClientHellos, we add the padding extension to work around
   // bugs. We also apply this padding to ClientHelloOuter, to keep the wire
   // images aligned.
-  size_t psk_extension_len = ext_pre_shared_key_clienthello_length(hs, type);
+  size_t psk_len = ext_pre_shared_key_clienthello_length(hs, type);
   if (!SSL_is_dtls(ssl) && !SSL_is_quic(ssl) &&
       !ssl->s3->used_hello_retry_request) {
-    msg_len += 2 /* length prefix */ + CBB_len(&extensions) + psk_extension_len;
+    msg_len += 2 /* length prefix */ + CBB_len(&extensions) + psk_len;
     // The length of the padding extension, excluding the four-byte extension
     // header.
     size_t padding_len = 0;
@@ -3999,7 +4018,7 @@
     // The final extension must be non-empty. WebSphere Application
     // Server 7.0 is intolerant to the last extension being zero-length. See
     // https://crbug.com/363583.
-    if (last_was_empty && psk_extension_len == 0) {
+    if (last_was_empty && psk_len == 0) {
       padding_len = 1;
       // The addition of the padding extension may push us into the F5 bug.
       msg_len += 4 + padding_len;
@@ -4034,21 +4053,13 @@
   }
 
   // The PSK extension must be last, including after the padding.
-  const size_t len_before = CBB_len(&extensions);
-  if (!ext_pre_shared_key_add_clienthello(hs, &extensions, out_needs_psk_binder,
+  size_t psk_len_actual;
+  if (!ext_pre_shared_key_add_clienthello(hs, out, &extensions, &psk_len_actual,
                                           type)) {
-    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return false;
   }
-  assert(psk_extension_len == CBB_len(&extensions) - len_before);
-  (void)len_before;  // |assert| is omitted in release builds.
-
-  // Discard empty extensions blocks.
-  if (CBB_len(&extensions) == 0) {
-    CBB_discard_child(out);
-  }
-
-  return CBB_flush(out);
+  assert(psk_len_actual == psk_len);
+  return true;
 }
 
 bool ssl_add_serverhello_tlsext(SSL_HANDSHAKE *hs, CBB *out) {
diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc
index 638c75f..901040c 100644
--- a/ssl/handshake_client.cc
+++ b/ssl/handshake_client.cc
@@ -223,29 +223,15 @@
   ssl_client_hello_type_t type = hs->selected_ech_config
                                      ? ssl_client_hello_outer
                                      : ssl_client_hello_unencrypted;
-  bool needs_psk_binder;
   Array<uint8_t> msg;
   if (!ssl->method->init_message(ssl, cbb.get(), &body, SSL3_MT_CLIENT_HELLO) ||
       !ssl_write_client_hello_without_extensions(hs, &body, type,
                                                  /*empty_session_id=*/false) ||
-      !ssl_add_clienthello_tlsext(hs, &body, /*out_encoded=*/nullptr,
-                                  &needs_psk_binder, type) ||
+      !ssl_add_clienthello_tlsext(hs, &body, /*out_encoded=*/nullptr, type) ||
       !ssl->method->finish_message(ssl, cbb.get(), &msg)) {
     return false;
   }
 
-  // Now that the length prefixes have been computed, fill in the placeholder
-  // PSK binder.
-  if (needs_psk_binder) {
-    // ClientHelloOuter cannot have a PSK binder. Otherwise the
-    // ClientHellOuterAAD computation would break.
-    assert(type != ssl_client_hello_outer);
-    if (!tls13_write_psk_binder(hs, hs->transcript, Span(msg),
-                                /*out_binder_len=*/nullptr)) {
-      return false;
-    }
-  }
-
   return ssl->method->add_message(ssl, std::move(msg));
 }
 
diff --git a/ssl/internal.h b/ssl/internal.h
index 307729d..9053d88 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1256,11 +1256,10 @@
 
 // tls13_write_psk_binder calculates the PSK binder value over |transcript| and
 // |msg|, and replaces the last bytes of |msg| with the resulting value. It
-// returns true on success, and false on failure. If |out_binder_len| is
-// non-NULL, it sets |*out_binder_len| to the length of the value computed.
+// returns true on success, and false on failure. |msg| should contain the body
+// of a ClientHello, but not the message header.
 bool tls13_write_psk_binder(const SSL_HANDSHAKE *hs,
-                            const SSLTranscript &transcript, Span<uint8_t> msg,
-                            size_t *out_binder_len);
+                            const SSLTranscript &transcript, Span<uint8_t> msg);
 
 // tls13_verify_psk_binder verifies that the handshake transcript, truncated up
 // to the binders has a valid signature using the value of |session|'s
@@ -3626,18 +3625,12 @@
 // ssl_add_clienthello_tlsext writes ClientHello extensions to |out| for |type|.
 // It returns true on success and false on failure. |out| must currently contain
 // a ClientHello message, not including the message and record header. (Its
-// current length will be used to compute padding.)
+// contents will be used to compute padding and PSK binders.)
 //
 // If |type| is |ssl_client_hello_inner|, this function also writes the
 // compressed extensions to |out_encoded|. Otherwise, |out_encoded| should be
 // nullptr.
-//
-// On success, the function sets |*out_needs_psk_binder| to whether the last
-// ClientHello extension was the pre_shared_key extension and needs a PSK binder
-// filled in. The caller should then update |out| and, if applicable,
-// |out_encoded| with the binder after completing the whole message.
 bool ssl_add_clienthello_tlsext(SSL_HANDSHAKE *hs, CBB *out, CBB *out_encoded,
-                                bool *out_needs_psk_binder,
                                 ssl_client_hello_type_t type);
 
 bool ssl_add_serverhello_tlsext(SSL_HANDSHAKE *hs, CBB *out);
diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc
index c5baa83..5c0099e 100644
--- a/ssl/tls13_enc.cc
+++ b/ssl/tls13_enc.cc
@@ -510,6 +510,10 @@
                              const SSLTranscript &transcript,
                              Span<const uint8_t> client_hello,
                              size_t binders_len, bool is_dtls) {
+  // |client_hello| should not include the message header and begin with a
+  // version number.
+  assert(!client_hello.empty() && (client_hello[0] == SSL3_VERSION_MAJOR ||
+                                   client_hello[0] == DTLS1_VERSION_MAJOR));
   const EVP_MD *digest = ssl_session_get_digest(session);
 
   // Compute the binder key.
@@ -534,38 +538,27 @@
     return false;
   }
 
-  // Hash the transcript and truncated ClientHello.
-  if (client_hello.size() < binders_len) {
+  // Hash the transcript and truncated ClientHello. As part of this, construct
+  // the expected ClientHello header.
+  if (client_hello.size() < binders_len || client_hello.size() > 0xffffff) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return false;
   }
+  uint8_t header[4] = {
+      SSL3_MT_CLIENT_HELLO,
+      static_cast<uint8_t>(client_hello.size() >> 16),
+      static_cast<uint8_t>(client_hello.size() >> 8),
+      static_cast<uint8_t>(client_hello.size()),
+  };
   auto truncated = client_hello.subspan(0, client_hello.size() - binders_len);
   uint8_t context[EVP_MAX_MD_SIZE];
   unsigned context_len;
   ScopedEVP_MD_CTX ctx;
-  if (!is_dtls) {
-    if (!transcript.CopyToHashContext(ctx.get(), digest) ||
-        !EVP_DigestUpdate(ctx.get(), truncated.data(), truncated.size()) ||
-        !EVP_DigestFinal_ex(ctx.get(), context, &context_len)) {
-      return false;
-    }
-  } else {
-    // In DTLS 1.3, the transcript hash is computed over only the TLS 1.3
-    // handshake messages (i.e. only type and length in the header), not the
-    // full DTLSHandshake messages that are in |truncated|. This code pulls
-    // the header and body out of the truncated ClientHello and writes those
-    // to the hash context so the correct binder value is computed.
-    if (truncated.size() < DTLS1_HM_HEADER_LENGTH) {
-      return false;
-    }
-    auto header = truncated.first<4>();
-    auto body = truncated.subspan<12>();
-    if (!transcript.CopyToHashContext(ctx.get(), digest) ||
-        !EVP_DigestUpdate(ctx.get(), header.data(), header.size()) ||
-        !EVP_DigestUpdate(ctx.get(), body.data(), body.size()) ||
-        !EVP_DigestFinal_ex(ctx.get(), context, &context_len)) {
-      return false;
-    }
+  if (!transcript.CopyToHashContext(ctx.get(), digest) ||
+      !EVP_DigestUpdate(ctx.get(), header, sizeof(header)) ||
+      !EVP_DigestUpdate(ctx.get(), truncated.data(), truncated.size()) ||
+      !EVP_DigestFinal_ex(ctx.get(), context, &context_len)) {
+    return false;
   }
 
   if (!tls13_verify_data(out, out_len, digest, session->ssl_version, binder_key,
@@ -578,8 +571,8 @@
 }
 
 bool tls13_write_psk_binder(const SSL_HANDSHAKE *hs,
-                            const SSLTranscript &transcript, Span<uint8_t> msg,
-                            size_t *out_binder_len) {
+                            const SSLTranscript &transcript,
+                            Span<uint8_t> msg) {
   const SSL *const ssl = hs->ssl;
   const EVP_MD *digest = ssl_session_get_digest(ssl->session.get());
   const size_t hash_len = EVP_MD_size(digest);
@@ -598,9 +591,6 @@
 
   auto msg_binder = msg.last(verify_data_len);
   OPENSSL_memcpy(msg_binder.data(), verify_data, verify_data_len);
-  if (out_binder_len != nullptr) {
-    *out_binder_len = verify_data_len;
-  }
   return true;
 }
 
@@ -614,7 +604,7 @@
   // prefix removed. The caller is assumed to have parsed |msg|, extracted
   // |binders|, and verified the PSK extension is last.
   if (!tls13_psk_binder(verify_data, &verify_data_len, session, hs->transcript,
-                        msg.raw, 2 + CBS_len(binders), SSL_is_dtls(hs->ssl)) ||
+                        msg.body, 2 + CBS_len(binders), SSL_is_dtls(hs->ssl)) ||
       // We only consider the first PSK, so compare against the first binder.
       !CBS_get_u8_length_prefixed(binders, &binder)) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);