Make ssl_parse_extensions a little easier to use.

std::initializer_list appears to work by instantiating a T[N] at the
call site (which is what we were doing anyway), so I don't believe there
is a runtime dependency.

This also adds a way for individual entries to turn themselves off,
which means we don't need to manually check for some unsolicited
extensions.

Change-Id: I40f79b6a0e9c005fc621f4a798fe201bfbf08411
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/48910
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/handshake.cc b/ssl/handshake.cc
index db4ee71..3608888 100644
--- a/ssl/handshake.cc
+++ b/ssl/handshake.cc
@@ -268,12 +268,15 @@
 }
 
 bool ssl_parse_extensions(const CBS *cbs, uint8_t *out_alert,
-                          Span<const SSL_EXTENSION_TYPE> ext_types,
+                          std::initializer_list<SSLExtension *> extensions,
                           bool ignore_unknown) {
   // Reset everything.
-  for (const SSL_EXTENSION_TYPE &ext_type : ext_types) {
-    *ext_type.out_present = false;
-    CBS_init(ext_type.out_data, nullptr, 0);
+  for (SSLExtension *ext : extensions) {
+    ext->present = false;
+    CBS_init(&ext->data, nullptr, 0);
+    if (!ext->allowed) {
+      assert(!ignore_unknown);
+    }
   }
 
   CBS copy = *cbs;
@@ -287,10 +290,10 @@
       return false;
     }
 
-    const SSL_EXTENSION_TYPE *found = nullptr;
-    for (const SSL_EXTENSION_TYPE &ext_type : ext_types) {
-      if (type == ext_type.type) {
-        found = &ext_type;
+    SSLExtension *found = nullptr;
+    for (SSLExtension *ext : extensions) {
+      if (type == ext->type && ext->allowed) {
+        found = ext;
         break;
       }
     }
@@ -305,14 +308,14 @@
     }
 
     // Duplicate ext_types are forbidden.
-    if (*found->out_present) {
+    if (found->present) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_DUPLICATE_EXTENSION);
       *out_alert = SSL_AD_ILLEGAL_PARAMETER;
       return false;
     }
 
-    *found->out_present = 1;
-    *found->out_data = data;
+    found->present = true;
+    found->data = data;
   }
 
   return true;
diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc
index b5199ba..ee9045e 100644
--- a/ssl/handshake_client.cc
+++ b/ssl/handshake_client.cc
@@ -364,25 +364,20 @@
     return true;
   }
 
-  bool have_supported_versions;
-  CBS supported_versions;
-  const SSL_EXTENSION_TYPE ext_types[] = {
-      {TLSEXT_TYPE_supported_versions, &have_supported_versions,
-       &supported_versions},
-  };
+  SSLExtension supported_versions(TLSEXT_TYPE_supported_versions);
   CBS extensions = server_hello.extensions;
-  if (!ssl_parse_extensions(&extensions, out_alert, ext_types,
+  if (!ssl_parse_extensions(&extensions, out_alert, {&supported_versions},
                             /*ignore_unknown=*/true)) {
     return false;
   }
 
-  if (!have_supported_versions) {
+  if (!supported_versions.present) {
     *out_version = server_hello.legacy_version;
     return true;
   }
 
-  if (!CBS_get_u16(&supported_versions, out_version) ||
-       CBS_len(&supported_versions) != 0) {
+  if (!CBS_get_u16(&supported_versions.data, out_version) ||
+       CBS_len(&supported_versions.data) != 0) {
     *out_alert = SSL_AD_DECODE_ERROR;
     return false;
   }
diff --git a/ssl/internal.h b/ssl/internal.h
index 0c91724..f9cee53 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -146,6 +146,7 @@
 
 #include <stdlib.h>
 
+#include <initializer_list>
 #include <limits>
 #include <new>
 #include <type_traits>
@@ -2219,19 +2220,25 @@
 bool ssl_negotiate_alps(SSL_HANDSHAKE *hs, uint8_t *out_alert,
                         const SSL_CLIENT_HELLO *client_hello);
 
-struct SSL_EXTENSION_TYPE {
+struct SSLExtension {
+  SSLExtension(uint16_t type_arg, bool allowed_arg = true)
+      : type(type_arg), allowed(allowed_arg), present(false) {
+    CBS_init(&data, nullptr, 0);
+  }
+
   uint16_t type;
-  bool *out_present;
-  CBS *out_data;
+  bool allowed;
+  bool present;
+  CBS data;
 };
 
 // ssl_parse_extensions parses a TLS extensions block out of |cbs| and advances
-// it. It writes the parsed extensions to pointers denoted by |ext_types|. On
-// success, it fills in the |out_present| and |out_data| fields and returns
-// true. Otherwise, it sets |*out_alert| to an alert to send and returns false.
-// Unknown extensions are rejected unless |ignore_unknown| is true.
+// it. It writes the parsed extensions to pointers in |extensions|. On success,
+// it fills in the |present| and |data| fields and returns true. Otherwise, it
+// sets |*out_alert| to an alert to send and returns false. Unknown extensions
+// are rejected unless |ignore_unknown| is true.
 bool ssl_parse_extensions(const CBS *cbs, uint8_t *out_alert,
-                          Span<const SSL_EXTENSION_TYPE> ext_types,
+                          std::initializer_list<SSLExtension *> extensions,
                           bool ignore_unknown);
 
 // ssl_verify_peer_cert verifies the peer certificate for |hs|.
diff --git a/ssl/tls13_both.cc b/ssl/tls13_both.cc
index 0354f39..226c67b 100644
--- a/ssl/tls13_both.cc
+++ b/ssl/tls13_both.cc
@@ -235,15 +235,14 @@
     }
 
     // Parse out the extensions.
-    bool have_status_request = false, have_sct = false;
-    CBS status_request, sct;
-    const SSL_EXTENSION_TYPE ext_types[] = {
-        {TLSEXT_TYPE_status_request, &have_status_request, &status_request},
-        {TLSEXT_TYPE_certificate_timestamp, &have_sct, &sct},
-    };
-
+    SSLExtension status_request(
+        TLSEXT_TYPE_status_request,
+        !ssl->server && hs->config->ocsp_stapling_enabled);
+    SSLExtension sct(
+        TLSEXT_TYPE_certificate_timestamp,
+        !ssl->server && hs->config->signed_cert_timestamps_enabled);
     uint8_t alert = SSL_AD_DECODE_ERROR;
-    if (!ssl_parse_extensions(&extensions, &alert, ext_types,
+    if (!ssl_parse_extensions(&extensions, &alert, {&status_request, &sct},
                               /*ignore_unknown=*/false)) {
       ssl_send_alert(ssl, SSL3_AL_FATAL, alert);
       return false;
@@ -251,20 +250,14 @@
 
     // All Certificate extensions are parsed, but only the leaf extensions are
     // stored.
-    if (have_status_request) {
-      if (ssl->server || !hs->config->ocsp_stapling_enabled) {
-        OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_EXTENSION);
-        ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNSUPPORTED_EXTENSION);
-        return false;
-      }
-
+    if (status_request.present) {
       uint8_t status_type;
       CBS ocsp_response;
-      if (!CBS_get_u8(&status_request, &status_type) ||
+      if (!CBS_get_u8(&status_request.data, &status_type) ||
           status_type != TLSEXT_STATUSTYPE_ocsp ||
-          !CBS_get_u24_length_prefixed(&status_request, &ocsp_response) ||
+          !CBS_get_u24_length_prefixed(&status_request.data, &ocsp_response) ||
           CBS_len(&ocsp_response) == 0 ||
-          CBS_len(&status_request) != 0) {
+          CBS_len(&status_request.data) != 0) {
         ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
         return false;
       }
@@ -279,14 +272,8 @@
       }
     }
 
-    if (have_sct) {
-      if (ssl->server || !hs->config->signed_cert_timestamps_enabled) {
-        OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_EXTENSION);
-        ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNSUPPORTED_EXTENSION);
-        return false;
-      }
-
-      if (!ssl_is_sct_list_valid(&sct)) {
+    if (sct.present) {
+      if (!ssl_is_sct_list_valid(&sct.data)) {
         OPENSSL_PUT_ERROR(SSL, SSL_R_ERROR_PARSING_EXTENSION);
         ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
         return false;
@@ -294,7 +281,7 @@
 
       if (sk_CRYPTO_BUFFER_num(certs.get()) == 1) {
         hs->new_session->signed_cert_timestamp_list.reset(
-            CRYPTO_BUFFER_new_from_CBS(&sct, ssl->ctx->pool));
+            CRYPTO_BUFFER_new_from_CBS(&sct.data, ssl->ctx->pool));
         if (hs->new_session->signed_cert_timestamp_list == nullptr) {
           ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
           return false;
diff --git a/ssl/tls13_client.cc b/ssl/tls13_client.cc
index 7a19b2a..bd7e63f 100644
--- a/ssl/tls13_client.cc
+++ b/ssl/tls13_client.cc
@@ -162,31 +162,25 @@
     return ssl_hs_ok;
   }
 
-  bool have_cookie, have_key_share, have_supported_versions;
-  CBS cookie, key_share, supported_versions;
-  SSL_EXTENSION_TYPE ext_types[] = {
-      {TLSEXT_TYPE_key_share, &have_key_share, &key_share},
-      {TLSEXT_TYPE_cookie, &have_cookie, &cookie},
-      {TLSEXT_TYPE_supported_versions, &have_supported_versions,
-       &supported_versions},
-  };
-
-  if (!ssl_parse_extensions(&server_hello.extensions, &alert, ext_types,
+  SSLExtension cookie(TLSEXT_TYPE_cookie), key_share(TLSEXT_TYPE_key_share),
+      supported_versions(TLSEXT_TYPE_supported_versions);
+  if (!ssl_parse_extensions(&server_hello.extensions, &alert,
+                            {&cookie, &key_share, &supported_versions},
                             /*ignore_unknown=*/false)) {
     ssl_send_alert(ssl, SSL3_AL_FATAL, alert);
     return ssl_hs_error;
   }
 
-  if (!have_cookie && !have_key_share) {
+  if (!cookie.present && !key_share.present) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_EMPTY_HELLO_RETRY_REQUEST);
     ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
     return ssl_hs_error;
   }
-  if (have_cookie) {
+  if (cookie.present) {
     CBS cookie_value;
-    if (!CBS_get_u16_length_prefixed(&cookie, &cookie_value) ||
+    if (!CBS_get_u16_length_prefixed(&cookie.data, &cookie_value) ||
         CBS_len(&cookie_value) == 0 ||
-        CBS_len(&cookie) != 0) {
+        CBS_len(&cookie.data) != 0) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
       ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
       return ssl_hs_error;
@@ -197,9 +191,10 @@
     }
   }
 
-  if (have_key_share) {
+  if (key_share.present) {
     uint16_t group_id;
-    if (!CBS_get_u16(&key_share, &group_id) || CBS_len(&key_share) != 0) {
+    if (!CBS_get_u16(&key_share.data, &group_id) ||
+        CBS_len(&key_share.data) != 0) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
       ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
       return ssl_hs_error;
@@ -316,18 +311,11 @@
   OPENSSL_memcpy(ssl->s3->server_random, CBS_data(&server_hello.random),
                  SSL3_RANDOM_SIZE);
 
-  // Parse out the extensions.
-  bool have_key_share = false, have_pre_shared_key = false,
-       have_supported_versions = false;
-  CBS key_share, pre_shared_key, supported_versions;
-  SSL_EXTENSION_TYPE ext_types[] = {
-      {TLSEXT_TYPE_key_share, &have_key_share, &key_share},
-      {TLSEXT_TYPE_pre_shared_key, &have_pre_shared_key, &pre_shared_key},
-      {TLSEXT_TYPE_supported_versions, &have_supported_versions,
-       &supported_versions},
-  };
-
-  if (!ssl_parse_extensions(&server_hello.extensions, &alert, ext_types,
+  SSLExtension key_share(TLSEXT_TYPE_key_share),
+      pre_shared_key(TLSEXT_TYPE_pre_shared_key, ssl->session != nullptr),
+      supported_versions(TLSEXT_TYPE_supported_versions);
+  if (!ssl_parse_extensions(&server_hello.extensions, &alert,
+                            {&key_share, &pre_shared_key, &supported_versions},
                             /*ignore_unknown=*/false)) {
     ssl_send_alert(ssl, SSL3_AL_FATAL, alert);
     return ssl_hs_error;
@@ -335,8 +323,9 @@
 
   // Recheck supported_versions, in case this is after HelloRetryRequest.
   uint16_t version;
-  if (!have_supported_versions ||
-      !CBS_get_u16(&supported_versions, &version) ||
+  if (!supported_versions.present ||
+      !CBS_get_u16(&supported_versions.data, &version) ||
+      CBS_len(&supported_versions.data) != 0 ||
       version != ssl->version) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_SECOND_SERVERHELLO_VERSION_MISMATCH);
     ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
@@ -344,15 +333,9 @@
   }
 
   alert = SSL_AD_DECODE_ERROR;
-  if (have_pre_shared_key) {
-    if (ssl->session == NULL) {
-      OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_EXTENSION);
-      ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNSUPPORTED_EXTENSION);
-      return ssl_hs_error;
-    }
-
+  if (pre_shared_key.present) {
     if (!ssl_ext_pre_shared_key_parse_serverhello(hs, &alert,
-                                                  &pre_shared_key)) {
+                                                  &pre_shared_key.data)) {
       ssl_send_alert(ssl, SSL3_AL_FATAL, alert);
       return ssl_hs_error;
     }
@@ -409,7 +392,7 @@
     return ssl_hs_error;
   }
 
-  if (!have_key_share) {
+  if (!key_share.present) {
     // We do not support psk_ke and thus always require a key share.
     OPENSSL_PUT_ERROR(SSL, SSL_R_MISSING_KEY_SHARE);
     ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_MISSING_EXTENSION);
@@ -420,7 +403,7 @@
   Array<uint8_t> dhe_secret;
   alert = SSL_AD_DECODE_ERROR;
   if (!ssl_ext_key_share_parse_serverhello(hs, &dhe_secret, &alert,
-                                           &key_share)) {
+                                           &key_share.data)) {
     ssl_send_alert(ssl, SSL3_AL_FATAL, alert);
     return ssl_hs_error;
   }
@@ -456,7 +439,7 @@
                      SSL3_RANDOM_SIZE);
     } else {
       // Resuming against the ClientHelloOuter was an unsolicited extension.
-      if (have_pre_shared_key) {
+      if (pre_shared_key.present) {
         OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_EXTENSION);
         ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNSUPPORTED_EXTENSION);
         return ssl_hs_error;
@@ -600,25 +583,19 @@
   }
 
 
-  bool have_sigalgs = false, have_ca = false;
-  CBS sigalgs, ca;
-  const SSL_EXTENSION_TYPE ext_types[] = {
-    {TLSEXT_TYPE_signature_algorithms, &have_sigalgs, &sigalgs},
-    {TLSEXT_TYPE_certificate_authorities, &have_ca, &ca},
-  };
-
+  SSLExtension sigalgs(TLSEXT_TYPE_signature_algorithms),
+      ca(TLSEXT_TYPE_certificate_authorities);
   CBS body = msg.body, context, extensions, supported_signature_algorithms;
   uint8_t alert = SSL_AD_DECODE_ERROR;
   if (!CBS_get_u8_length_prefixed(&body, &context) ||
       // The request context is always empty during the handshake.
       CBS_len(&context) != 0 ||
-      !CBS_get_u16_length_prefixed(&body, &extensions) ||
+      !CBS_get_u16_length_prefixed(&body, &extensions) ||  //
       CBS_len(&body) != 0 ||
-      !ssl_parse_extensions(&extensions, &alert, ext_types,
+      !ssl_parse_extensions(&extensions, &alert, {&sigalgs, &ca},
                             /*ignore_unknown=*/true) ||
-      (have_ca && CBS_len(&ca) == 0) ||
-      !have_sigalgs ||
-      !CBS_get_u16_length_prefixed(&sigalgs,
+      !sigalgs.present ||
+      !CBS_get_u16_length_prefixed(&sigalgs.data,
                                    &supported_signature_algorithms) ||
       !tls1_parse_peer_sigalgs(hs, &supported_signature_algorithms)) {
     ssl_send_alert(ssl, SSL3_AL_FATAL, alert);
@@ -626,8 +603,8 @@
     return ssl_hs_error;
   }
 
-  if (have_ca) {
-    hs->ca_names = ssl_parse_client_CA_list(ssl, &alert, &ca);
+  if (ca.present) {
+    hs->ca_names = ssl_parse_client_CA_list(ssl, &alert, &ca.data);
     if (!hs->ca_names) {
       ssl_send_alert(ssl, SSL3_AL_FATAL, alert);
       return ssl_hs_error;
@@ -1050,23 +1027,17 @@
     return nullptr;
   }
 
-  // Parse out the extensions.
-  bool have_early_data = false;
-  CBS early_data;
-  const SSL_EXTENSION_TYPE ext_types[] = {
-      {TLSEXT_TYPE_early_data, &have_early_data, &early_data},
-  };
-
+  SSLExtension early_data(TLSEXT_TYPE_early_data);
   uint8_t alert = SSL_AD_DECODE_ERROR;
-  if (!ssl_parse_extensions(&extensions, &alert, ext_types,
+  if (!ssl_parse_extensions(&extensions, &alert, {&early_data},
                             /*ignore_unknown=*/true)) {
     ssl_send_alert(ssl, SSL3_AL_FATAL, alert);
     return nullptr;
   }
 
-  if (have_early_data) {
-    if (!CBS_get_u32(&early_data, &session->ticket_max_early_data) ||
-        CBS_len(&early_data) != 0) {
+  if (early_data.present) {
+    if (!CBS_get_u32(&early_data.data, &session->ticket_max_early_data) ||
+        CBS_len(&early_data.data) != 0) {
       ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
       OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
       return nullptr;
diff --git a/ssl/tls13_server.cc b/ssl/tls13_server.cc
index f1a62b2..6509cc4 100644
--- a/ssl/tls13_server.cc
+++ b/ssl/tls13_server.cc
@@ -1040,20 +1040,15 @@
       return ssl_hs_error;
     }
 
-    // Parse out the extensions.
-    bool have_application_settings = false;
-    CBS application_settings;
-    SSL_EXTENSION_TYPE ext_types[] = {{TLSEXT_TYPE_application_settings,
-                                       &have_application_settings,
-                                       &application_settings}};
+    SSLExtension application_settings(TLSEXT_TYPE_application_settings);
     uint8_t alert = SSL_AD_DECODE_ERROR;
-    if (!ssl_parse_extensions(&extensions, &alert, ext_types,
+    if (!ssl_parse_extensions(&extensions, &alert, {&application_settings},
                               /*ignore_unknown=*/false)) {
       ssl_send_alert(ssl, SSL3_AL_FATAL, alert);
       return ssl_hs_error;
     }
 
-    if (!have_application_settings) {
+    if (!application_settings.present) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_MISSING_EXTENSION);
       ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_MISSING_EXTENSION);
       return ssl_hs_error;
@@ -1062,7 +1057,7 @@
     // Note that, if 0-RTT was accepted, these values will already have been
     // initialized earlier.
     if (!hs->new_session->peer_application_settings.CopyFrom(
-            application_settings) ||
+            application_settings.data) ||
         !ssl_hash_message(hs, msg)) {
       ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
       return ssl_hs_error;