Rearrange key share and early data logic.

We currently determine whether we need HelloRetryRequest at the same
time as resolving key share machinery. That is a little too late for
early data negotiation, so we end up accepting early data and then
clearing it later on in the function. This works but is easy to mess up,
given the preceding CL. There's also some ALPS logic that got this
wrong, but I believe it didn't result in any incorrect behavior.

Instead, this pulls secret computation out of the key_share helper
function, which now just finds the matching key share. We then check
early whether we need HRR, before deciding on early data.

Change-Id: I108865da08addfefed4a7db73c60e11cf4335093
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/46765
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/internal.h b/ssl/internal.h
index 9c048bb..7a3979b 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -2018,10 +2018,10 @@
                                          Array<uint8_t> *out_secret,
                                          uint8_t *out_alert, CBS *contents);
 bool ssl_ext_key_share_parse_clienthello(SSL_HANDSHAKE *hs, bool *out_found,
-                                         Array<uint8_t> *out_secret,
-                                         uint8_t *out_alert, CBS *contents);
-bool ssl_ext_key_share_add_serverhello(SSL_HANDSHAKE *hs, CBB *out,
-                                       bool dry_run);
+                                         Span<const uint8_t> *out_peer_key,
+                                         uint8_t *out_alert,
+                                         const SSL_CLIENT_HELLO *client_hello);
+bool ssl_ext_key_share_add_serverhello(SSL_HANDSHAKE *hs, CBB *out);
 
 bool ssl_ext_pre_shared_key_parse_serverhello(SSL_HANDSHAKE *hs,
                                               uint8_t *out_alert,
diff --git a/ssl/t1_lib.cc b/ssl/t1_lib.cc
index 4216fcd..20517c4 100644
--- a/ssl/t1_lib.cc
+++ b/ssl/t1_lib.cc
@@ -2444,25 +2444,29 @@
 }
 
 bool ssl_ext_key_share_parse_clienthello(SSL_HANDSHAKE *hs, bool *out_found,
-                                         Array<uint8_t> *out_secret,
-                                         uint8_t *out_alert, CBS *contents) {
-  uint16_t group_id;
-  CBS key_shares;
-  if (!tls1_get_shared_group(hs, &group_id)) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_NO_SHARED_GROUP);
-    *out_alert = SSL_AD_HANDSHAKE_FAILURE;
+                                         Span<const uint8_t> *out_peer_key,
+                                         uint8_t *out_alert,
+                                         const SSL_CLIENT_HELLO *client_hello) {
+  // We only support connections that include an ECDHE key exchange.
+  CBS contents;
+  if (!ssl_client_hello_get_extension(client_hello, &contents,
+                                      TLSEXT_TYPE_key_share)) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_MISSING_KEY_SHARE);
+    *out_alert = SSL_AD_MISSING_EXTENSION;
     return false;
   }
 
-  if (!CBS_get_u16_length_prefixed(contents, &key_shares) ||
-      CBS_len(contents) != 0) {
+  CBS key_shares;
+  if (!CBS_get_u16_length_prefixed(&contents, &key_shares) ||
+      CBS_len(&contents) != 0) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
     return false;
   }
 
   // Find the corresponding key share.
+  const uint16_t group_id = hs->new_session->group_id;
   CBS peer_key;
-  CBS_init(&peer_key, NULL, 0);
+  CBS_init(&peer_key, nullptr, 0);
   while (CBS_len(&key_shares) > 0) {
     uint16_t id;
     CBS peer_key_tmp;
@@ -2485,47 +2489,24 @@
     }
   }
 
-  if (CBS_len(&peer_key) == 0) {
-    *out_found = false;
-    out_secret->Reset();
-    return true;
+  if (out_peer_key != nullptr) {
+    *out_peer_key = peer_key;
   }
-
-  // Compute the DH secret.
-  Array<uint8_t> secret;
-  ScopedCBB public_key;
-  UniquePtr<SSLKeyShare> key_share = SSLKeyShare::Create(group_id);
-  if (!key_share ||
-      !CBB_init(public_key.get(), 32) ||
-      !key_share->Accept(public_key.get(), &secret, out_alert, peer_key) ||
-      !CBBFinishArray(public_key.get(), &hs->ecdh_public_key)) {
-    *out_alert = SSL_AD_ILLEGAL_PARAMETER;
-    return false;
-  }
-
-  *out_secret = std::move(secret);
-  *out_found = true;
+  *out_found = CBS_len(&peer_key) != 0;
   return true;
 }
 
-bool ssl_ext_key_share_add_serverhello(SSL_HANDSHAKE *hs, CBB *out,
-                                       bool dry_run) {
-  uint16_t group_id;
+bool ssl_ext_key_share_add_serverhello(SSL_HANDSHAKE *hs, CBB *out) {
   CBB kse_bytes, public_key;
-  if (!tls1_get_shared_group(hs, &group_id) ||
-      !CBB_add_u16(out, TLSEXT_TYPE_key_share) ||
+  if (!CBB_add_u16(out, TLSEXT_TYPE_key_share) ||
       !CBB_add_u16_length_prefixed(out, &kse_bytes) ||
-      !CBB_add_u16(&kse_bytes, group_id) ||
+      !CBB_add_u16(&kse_bytes, hs->new_session->group_id) ||
       !CBB_add_u16_length_prefixed(&kse_bytes, &public_key) ||
       !CBB_add_bytes(&public_key, hs->ecdh_public_key.data(),
                      hs->ecdh_public_key.size()) ||
       !CBB_flush(out)) {
     return false;
   }
-  if (!dry_run) {
-    hs->ecdh_public_key.Reset();
-    hs->new_session->group_id = group_id;
-  }
   return true;
 }
 
diff --git a/ssl/tls13_server.cc b/ssl/tls13_server.cc
index adc9d39..b4f5297 100644
--- a/ssl/tls13_server.cc
+++ b/ssl/tls13_server.cc
@@ -42,35 +42,38 @@
 // See RFC 8446, section 8.3.
 static const int32_t kMaxTicketAgeSkewSeconds = 60;
 
-static int resolve_ecdhe_secret(SSL_HANDSHAKE *hs, bool *out_need_retry,
-                                SSL_CLIENT_HELLO *client_hello) {
+static bool resolve_ecdhe_secret(SSL_HANDSHAKE *hs,
+                                 const SSL_CLIENT_HELLO *client_hello) {
   SSL *const ssl = hs->ssl;
-  *out_need_retry = false;
-
-  // We only support connections that include an ECDHE key exchange.
-  CBS key_share;
-  if (!ssl_client_hello_get_extension(client_hello, &key_share,
-                                      TLSEXT_TYPE_key_share)) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_MISSING_KEY_SHARE);
-    ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_MISSING_EXTENSION);
-    return 0;
-  }
+  const uint16_t group_id = hs->new_session->group_id;
 
   bool found_key_share;
-  Array<uint8_t> dhe_secret;
+  Span<const uint8_t> peer_key;
   uint8_t alert = SSL_AD_DECODE_ERROR;
-  if (!ssl_ext_key_share_parse_clienthello(hs, &found_key_share, &dhe_secret,
-                                           &alert, &key_share)) {
+  if (!ssl_ext_key_share_parse_clienthello(hs, &found_key_share, &peer_key,
+                                           &alert, client_hello)) {
     ssl_send_alert(ssl, SSL3_AL_FATAL, alert);
-    return 0;
+    return false;
   }
 
   if (!found_key_share) {
-    *out_need_retry = true;
-    return 0;
+    ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
+    OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_CURVE);
+    return false;
   }
 
-  return tls13_advance_key_schedule(hs, dhe_secret);
+  Array<uint8_t> secret;
+  ScopedCBB public_key;
+  UniquePtr<SSLKeyShare> key_share = SSLKeyShare::Create(group_id);
+  if (!key_share ||
+      !CBB_init(public_key.get(), 32) ||
+      !key_share->Accept(public_key.get(), &secret, &alert, peer_key) ||
+      !CBBFinishArray(public_key.get(), &hs->ecdh_public_key)) {
+    ssl_send_alert(ssl, SSL3_AL_FATAL, alert);
+    return false;
+  }
+
+  return tls13_advance_key_schedule(hs, secret);
 }
 
 static int ssl_ext_supported_versions_add_serverhello(SSL_HANDSHAKE *hs,
@@ -394,6 +397,23 @@
     return ssl_hs_error;
   }
 
+  // Record connection properties in the new session.
+  hs->new_session->cipher = hs->new_cipher;
+  if (!tls1_get_shared_group(hs, &hs->new_session->group_id)) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_NO_SHARED_GROUP);
+    ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
+    return ssl_hs_error;
+  }
+
+  // Determine if we need HelloRetryRequest.
+  bool found_key_share;
+  if (!ssl_ext_key_share_parse_clienthello(hs, &found_key_share,
+                                           /*out_key_share=*/nullptr, &alert,
+                                           &client_hello)) {
+    ssl_send_alert(ssl, SSL3_AL_FATAL, alert);
+    return ssl_hs_error;
+  }
+
   // Determine if we're negotiating 0-RTT.
   if (!ssl->enable_early_data) {
     ssl->s3->early_data_reason = ssl_early_data_disabled;
@@ -424,6 +444,8 @@
     ssl->s3->early_data_reason = ssl_early_data_ticket_age_skew;
   } else if (!quic_ticket_compatible(session.get(), hs->config)) {
     ssl->s3->early_data_reason = ssl_early_data_quic_parameter_mismatch;
+  } else if (!found_key_share) {
+    ssl->s3->early_data_reason = ssl_early_data_hello_retry_request;
   } else {
     // |ssl_session_is_resumable| forbids cross-cipher resumptions even if the
     // PRF hashes match.
@@ -433,9 +455,6 @@
     ssl->s3->early_data_accepted = true;
   }
 
-  // Record connection properties in the new session.
-  hs->new_session->cipher = hs->new_cipher;
-
   // Store the ALPN and ALPS values in the session for 0-RTT. Note the peer
   // applications settings are not generally known until client
   // EncryptedExtensions.
@@ -498,24 +517,16 @@
     ssl->s3->skip_early_data = true;
   }
 
-  // Resolve ECDHE and incorporate it into the secret.
-  bool need_retry;
-  if (!resolve_ecdhe_secret(hs, &need_retry, &client_hello)) {
-    if (need_retry) {
-      if (ssl->s3->early_data_accepted) {
-        ssl->s3->early_data_reason = ssl_early_data_hello_retry_request;
-        ssl->s3->early_data_accepted = false;
-      }
-      if (hs->early_data_offered) {
-        ssl->s3->skip_early_data = true;
-      }
-      ssl->method->next_message(ssl);
-      if (!hs->transcript.UpdateForHelloRetryRequest()) {
-        return ssl_hs_error;
-      }
-      hs->tls13_state = state13_send_hello_retry_request;
-      return ssl_hs_ok;
+  if (!found_key_share) {
+    ssl->method->next_message(ssl);
+    if (!hs->transcript.UpdateForHelloRetryRequest()) {
+      return ssl_hs_error;
     }
+    hs->tls13_state = state13_send_hello_retry_request;
+    return ssl_hs_ok;
+  }
+
+  if (!resolve_ecdhe_secret(hs, &client_hello)) {
     return ssl_hs_error;
   }
 
@@ -676,13 +687,7 @@
     }
   }
 
-  bool need_retry;
-  if (!resolve_ecdhe_secret(hs, &need_retry, &client_hello)) {
-    if (need_retry) {
-      // Only send one HelloRetryRequest.
-      ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
-      OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_CURVE);
-    }
+  if (!resolve_ecdhe_secret(hs, &client_hello)) {
     return ssl_hs_error;
   }
 
@@ -729,7 +734,7 @@
         !CBB_add_u8(&body, 0) ||
         !CBB_add_u16_length_prefixed(&body, &extensions) ||
         !ssl_ext_pre_shared_key_add_serverhello(hs, &extensions) ||
-        !ssl_ext_key_share_add_serverhello(hs, &extensions, /*dry_run=*/true) ||
+        !ssl_ext_key_share_add_serverhello(hs, &extensions) ||
         !ssl_ext_supported_versions_add_serverhello(hs, &extensions) ||
         !CBB_flush(cbb.get())) {
       return ssl_hs_error;
@@ -756,12 +761,13 @@
       !CBB_add_u8(&body, 0) ||
       !CBB_add_u16_length_prefixed(&body, &extensions) ||
       !ssl_ext_pre_shared_key_add_serverhello(hs, &extensions) ||
-      !ssl_ext_key_share_add_serverhello(hs, &extensions, /*dry_run=*/false) ||
+      !ssl_ext_key_share_add_serverhello(hs, &extensions) ||
       !ssl_ext_supported_versions_add_serverhello(hs, &extensions) ||
       !ssl_add_message_cbb(ssl, cbb.get())) {
     return ssl_hs_error;
   }
 
+  hs->ecdh_public_key.Reset();  // No longer needed.
   if (!ssl->s3->used_hello_retry_request &&
       !ssl->method->add_change_cipher_spec(ssl)) {
     return ssl_hs_error;