Move key_share computation out of ClientHello callbacks.

Like the early_data CL, this does shift a bit of logic that was
previously hidden away in the callbacks. For key_share, this is probably
a good move independent of ECH. The logic around HRR, etc., was a little
messy.

Bug: 275
Change-Id: Iafbcebdf66ce1f7957d798a98ee6b996fff24639
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/47986
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc
index aa015d4..7607d56 100644
--- a/ssl/handshake_client.cc
+++ b/ssl/handshake_client.cc
@@ -497,7 +497,8 @@
     hs->early_data_offered = true;
   }
 
-  if (!ssl_write_client_hello(hs)) {
+  if (!ssl_setup_key_shares(hs, /*override_group_id=*/0) ||
+      !ssl_write_client_hello(hs)) {
     return ssl_hs_error;
   }
 
diff --git a/ssl/internal.h b/ssl/internal.h
index d6801df..1893aa5 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1742,10 +1742,6 @@
     uint32_t received;
   } extensions;
 
-  // retry_group is the group ID selected by the server in HelloRetryRequest in
-  // TLS 1.3.
-  uint16_t retry_group = 0;
-
   // error, if |wait| is |ssl_hs_error|, is the error the handshake failed on.
   UniquePtr<ERR_SAVE_STATE> error;
 
@@ -1768,8 +1764,7 @@
   // reconstructed ClientHelloInner message.
   Array<uint8_t> ech_client_hello_buf;
 
-  // key_share_bytes is the value of the previously sent KeyShare extension by
-  // the client in TLS 1.3.
+  // key_share_bytes is the key_share extension that the client should send.
   Array<uint8_t> key_share_bytes;
 
   // ecdh_public_key, for servers, is the key share to be sent to the client in
@@ -2040,6 +2035,12 @@
 bssl::UniquePtr<SSL_SESSION> tls13_create_session_with_ticket(SSL *ssl,
                                                               CBS *body);
 
+// ssl_setup_key_shares computes client key shares and saves them in |hs|. It
+// returns true on success and false on failure. If |override_group_id| is zero,
+// it offers the default groups, including GREASE. If it is non-zero, it offers
+// a single key share of the specified group.
+bool ssl_setup_key_shares(SSL_HANDSHAKE *hs, uint16_t override_group_id);
+
 bool ssl_ext_key_share_parse_serverhello(SSL_HANDSHAKE *hs,
                                          Array<uint8_t> *out_secret,
                                          uint8_t *out_alert, CBS *contents);
diff --git a/ssl/t1_lib.cc b/ssl/t1_lib.cc
index da3400f..f48ed84 100644
--- a/ssl/t1_lib.cc
+++ b/ssl/t1_lib.cc
@@ -405,6 +405,11 @@
     return false;
   }
 
+  // We internally assume zero is never allocated as a group ID.
+  if (group_id == 0) {
+    return false;
+  }
+
   for (uint16_t supported : tls1_get_grouplist(hs)) {
     if (supported == group_id) {
       return true;
@@ -2242,42 +2247,33 @@
 //
 // https://tools.ietf.org/html/rfc8446#section-4.2.8
 
-static bool ext_key_share_add_clienthello(SSL_HANDSHAKE *hs, CBB *out) {
+bool ssl_setup_key_shares(SSL_HANDSHAKE *hs, uint16_t override_group_id) {
   SSL *const ssl = hs->ssl;
+  hs->key_shares[0].reset();
+  hs->key_shares[1].reset();
+  hs->key_share_bytes.Reset();
+
   if (hs->max_version < TLS1_3_VERSION) {
     return true;
   }
 
-  CBB contents, kse_bytes;
-  if (!CBB_add_u16(out, TLSEXT_TYPE_key_share) ||
-      !CBB_add_u16_length_prefixed(out, &contents) ||
-      !CBB_add_u16_length_prefixed(&contents, &kse_bytes)) {
+  bssl::ScopedCBB cbb;
+  if (!CBB_init(cbb.get(), 64)) {
     return false;
   }
 
-  uint16_t group_id = hs->retry_group;
-  uint16_t second_group_id = 0;
-  if (ssl->s3 && ssl->s3->used_hello_retry_request) {
-    // We received a HelloRetryRequest without a new curve, so there is no new
-    // share to append. Leave |hs->key_share| as-is.
-    if (group_id == 0 &&
-        !CBB_add_bytes(&kse_bytes, hs->key_share_bytes.data(),
-                       hs->key_share_bytes.size())) {
-      return false;
-    }
-    if (group_id == 0) {
-      return CBB_flush(out);
-    }
-  } else {
+  if (override_group_id == 0 && ssl->ctx->grease_enabled) {
     // Add a fake group. See RFC 8701.
-    if (ssl->ctx->grease_enabled &&
-        (!CBB_add_u16(&kse_bytes,
-                      ssl_get_grease_value(hs, ssl_grease_group)) ||
-         !CBB_add_u16(&kse_bytes, 1 /* length */) ||
-         !CBB_add_u8(&kse_bytes, 0 /* one byte key share */))) {
+    if (!CBB_add_u16(cbb.get(), ssl_get_grease_value(hs, ssl_grease_group)) ||
+        !CBB_add_u16(cbb.get(), 1 /* length */) ||
+        !CBB_add_u8(cbb.get(), 0 /* one byte key share */)) {
       return false;
     }
+  }
 
+  uint16_t group_id = override_group_id;
+  uint16_t second_group_id = 0;
+  if (override_group_id == 0) {
     // Predict the most preferred group.
     Span<const uint16_t> groups = tls1_get_grouplist(hs);
     if (groups.empty()) {
@@ -2297,34 +2293,43 @@
 
   CBB key_exchange;
   hs->key_shares[0] = SSLKeyShare::Create(group_id);
-  if (!hs->key_shares[0] ||
-      !CBB_add_u16(&kse_bytes, group_id) ||
-      !CBB_add_u16_length_prefixed(&kse_bytes, &key_exchange) ||
-      !hs->key_shares[0]->Offer(&key_exchange) ||
-      !CBB_flush(&kse_bytes)) {
+  if (!hs->key_shares[0] ||  //
+      !CBB_add_u16(cbb.get(), group_id) ||
+      !CBB_add_u16_length_prefixed(cbb.get(), &key_exchange) ||
+      !hs->key_shares[0]->Offer(&key_exchange)) {
     return false;
   }
 
   if (second_group_id != 0) {
     hs->key_shares[1] = SSLKeyShare::Create(second_group_id);
-    if (!hs->key_shares[1] ||
-        !CBB_add_u16(&kse_bytes, second_group_id) ||
-        !CBB_add_u16_length_prefixed(&kse_bytes, &key_exchange) ||
-        !hs->key_shares[1]->Offer(&key_exchange) ||
-        !CBB_flush(&kse_bytes)) {
+    if (!hs->key_shares[1] ||  //
+        !CBB_add_u16(cbb.get(), second_group_id) ||
+        !CBB_add_u16_length_prefixed(cbb.get(), &key_exchange) ||
+        !hs->key_shares[1]->Offer(&key_exchange)) {
       return false;
     }
   }
 
-  // Save the contents of the extension to repeat it in the second
-  // ClientHello.
-  if (ssl->s3 && !ssl->s3->used_hello_retry_request &&
-      !hs->key_share_bytes.CopyFrom(
-          MakeConstSpan(CBB_data(&kse_bytes), CBB_len(&kse_bytes)))) {
+  return CBBFinishArray(cbb.get(), &hs->key_share_bytes);
+}
+
+static bool ext_key_share_add_clienthello(SSL_HANDSHAKE *hs, CBB *out) {
+  if (hs->max_version < TLS1_3_VERSION) {
+    return true;
+  }
+
+  assert(!hs->key_share_bytes.empty());
+  CBB contents, kse_bytes;
+  if (!CBB_add_u16(out, TLSEXT_TYPE_key_share) ||
+      !CBB_add_u16_length_prefixed(out, &contents) ||
+      !CBB_add_u16_length_prefixed(&contents, &kse_bytes) ||
+      !CBB_add_bytes(&kse_bytes, hs->key_share_bytes.data(),
+                     hs->key_share_bytes.size()) ||
+      !CBB_flush(out)) {
     return false;
   }
 
-  return CBB_flush(out);
+  return true;
 }
 
 bool ssl_ext_key_share_parse_serverhello(SSL_HANDSHAKE *hs,
diff --git a/ssl/tls13_client.cc b/ssl/tls13_client.cc
index c7f45f6..92ccf62 100644
--- a/ssl/tls13_client.cc
+++ b/ssl/tls13_client.cc
@@ -222,9 +222,9 @@
       return ssl_hs_error;
     }
 
-    hs->key_shares[0].reset();
-    hs->key_shares[1].reset();
-    hs->retry_group = group_id;
+    if (!ssl_setup_key_shares(hs, group_id)) {
+      return ssl_hs_error;
+    }
   }
 
   if (!ssl_hash_message(hs, msg)) {