Add tests for installing the certificate on the early callback.

Test both asynchronous and synchronous versions. This callback is somewhat
different from others. It's NOT called a second time when the handshake is
resumed. This appears to be intentional and not a mismerge from the internal
patch. The caller is expected to set up any state before resuming the handshake
state machine.

Also test the early callback returning an error.

Change-Id: If5e6eddd7007ea5cdd7533b4238e456106b95cbd
Reviewed-on: https://boringssl-review.googlesource.com/3590
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index c5645a4..9cc8d3f 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -805,7 +805,8 @@
    * before the decision whether to resume a session is made. It may return one
    * to continue the handshake or zero to cause the handshake loop to return
    * with an error and cause SSL_get_error to return
-   * SSL_ERROR_PENDING_CERTIFICATE. */
+   * SSL_ERROR_PENDING_CERTIFICATE. Note: when the handshake loop is resumed, it
+   * will not call the callback a second time. */
   int (*select_certificate_cb)(const struct ssl_early_callback_ctx *);
 
   /* quiet_shutdown is true if the connection should not send a close_notify on
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index 28846ca..ad4fa1a 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -139,39 +139,51 @@
   const TestConfig *config = GetConfigPtr(ctx->ssl);
   GetTestState(ctx->ssl)->early_callback_called = true;
 
-  if (config->expected_server_name.empty()) {
-    return 1;
+  if (!config->expected_server_name.empty()) {
+    const uint8_t *extension_data;
+    size_t extension_len;
+    CBS extension, server_name_list, host_name;
+    uint8_t name_type;
+
+    if (!SSL_early_callback_ctx_extension_get(ctx, TLSEXT_TYPE_server_name,
+                                              &extension_data,
+                                              &extension_len)) {
+      fprintf(stderr, "Could not find server_name extension.\n");
+      return -1;
+    }
+
+    CBS_init(&extension, extension_data, extension_len);
+    if (!CBS_get_u16_length_prefixed(&extension, &server_name_list) ||
+        CBS_len(&extension) != 0 ||
+        !CBS_get_u8(&server_name_list, &name_type) ||
+        name_type != TLSEXT_NAMETYPE_host_name ||
+        !CBS_get_u16_length_prefixed(&server_name_list, &host_name) ||
+        CBS_len(&server_name_list) != 0) {
+      fprintf(stderr, "Could not decode server_name extension.\n");
+      return -1;
+    }
+
+    if (!CBS_mem_equal(&host_name,
+                       (const uint8_t*)config->expected_server_name.data(),
+                       config->expected_server_name.size())) {
+      fprintf(stderr, "Server name mismatch.\n");
+    }
   }
 
-  const uint8_t *extension_data;
-  size_t extension_len;
-  CBS extension, server_name_list, host_name;
-  uint8_t name_type;
-
-  if (!SSL_early_callback_ctx_extension_get(ctx, TLSEXT_TYPE_server_name,
-                                            &extension_data,
-                                            &extension_len)) {
-    fprintf(stderr, "Could not find server_name extension.\n");
+  if (config->fail_early_callback) {
     return -1;
   }
 
-  CBS_init(&extension, extension_data, extension_len);
-  if (!CBS_get_u16_length_prefixed(&extension, &server_name_list) ||
-      CBS_len(&extension) != 0 ||
-      !CBS_get_u8(&server_name_list, &name_type) ||
-      name_type != TLSEXT_NAMETYPE_host_name ||
-      !CBS_get_u16_length_prefixed(&server_name_list, &host_name) ||
-      CBS_len(&server_name_list) != 0) {
-    fprintf(stderr, "Could not decode server_name extension.\n");
-    return -1;
+  // Install the certificate in the early callback.
+  if (config->use_early_callback) {
+    if (config->async) {
+      // Install the certificate asynchronously.
+      return 0;
+    }
+    if (!InstallCertificate(ctx->ssl)) {
+      return -1;
+    }
   }
-
-  if (!CBS_mem_equal(&host_name,
-                     (const uint8_t*)config->expected_server_name.data(),
-                     config->expected_server_name.size())) {
-    fprintf(stderr, "Server name mismatch.\n");
-  }
-
   return 1;
 }
 
@@ -464,6 +476,9 @@
       GetTestState(ssl)->session =
           std::move(GetTestState(ssl)->pending_session);
       return true;
+    case SSL_ERROR_PENDING_CERTIFICATE:
+      // The handshake will resume without a second call to the early callback.
+      return InstallCertificate(ssl);
     default:
       return false;
   }
@@ -492,12 +507,13 @@
       !SSL_set_mode(ssl.get(), SSL_MODE_SEND_FALLBACK_SCSV)) {
     return false;
   }
-  if (config->async) {
-    // TODO(davidben): Also test |s->ctx->client_cert_cb| on the client and
-    // |s->ctx->select_certificate_cb| on the server.
-    SSL_set_cert_cb(ssl.get(), CertCallback, NULL);
-  } else if (!InstallCertificate(ssl.get())) {
-    return false;
+  if (!config->use_early_callback) {
+    if (config->async) {
+      // TODO(davidben): Also test |s->ctx->client_cert_cb| on the client.
+      SSL_set_cert_cb(ssl.get(), CertCallback, NULL);
+    } else if (!InstallCertificate(ssl.get())) {
+      return false;
+    }
   }
   if (config->require_any_client_certificate) {
     SSL_set_verify(ssl.get(), SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
@@ -650,6 +666,11 @@
       return false;
     }
 
+    if (config->is_server && !GetTestState(ssl.get())->early_callback_called) {
+      fprintf(stderr, "early callback not called\n");
+      return false;
+    }
+
     if (!config->expected_server_name.empty()) {
       const char *server_name =
         SSL_get_servername(ssl.get(), TLSEXT_NAMETYPE_host_name);
@@ -658,11 +679,6 @@
                 server_name, config->expected_server_name.c_str());
         return false;
       }
-
-      if (!GetTestState(ssl.get())->early_callback_called) {
-        fprintf(stderr, "early callback not called\n");
-        return false;
-      }
     }
 
     if (!config->expected_certificate_types.empty()) {
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index 1c84440..4079863 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -780,6 +780,14 @@
 		shouldFail:    true,
 		expectedError: ":UNEXPECTED_RECORD:",
 	},
+	{
+		testType:           serverTest,
+		name:               "FailEarlyCallback",
+		flags:              []string{"-fail-early-callback"},
+		shouldFail:         true,
+		expectedError:      ":CONNECTION_REJECTED:",
+		expectedLocalError: "remote error: access denied",
+	},
 }
 
 func doExchange(test *testCase, config *Config, conn net.Conn, messageLen int, isResume bool) error {
@@ -1657,6 +1665,18 @@
 		flags:         append(flags, "-implicit-handshake"),
 		resumeSession: true,
 	})
+	testCases = append(testCases, testCase{
+		protocol: protocol,
+		testType: serverTest,
+		name:     "Basic-Server-EarlyCallback" + suffix,
+		config: Config{
+			Bugs: ProtocolBugs{
+				MaxHandshakeRecordLength: maxHandshakeRecordLength,
+			},
+		},
+		flags:         append(flags, "-use-early-callback"),
+		resumeSession: true,
+	})
 
 	// TLS client auth.
 	testCases = append(testCases, testCase{
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index afcb106..4db72b4 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -73,6 +73,8 @@
     &TestConfig::enable_signed_cert_timestamps },
   { "-fastradio-padding", &TestConfig::fastradio_padding },
   { "-implicit-handshake", &TestConfig::implicit_handshake },
+  { "-use-early-callback", &TestConfig::use_early_callback },
+  { "-fail-early-callback", &TestConfig::fail_early_callback },
 };
 
 const Flag<std::string> kStringFlags[] = {
@@ -138,7 +140,9 @@
       min_version(0),
       max_version(0),
       mtu(0),
-      implicit_handshake(false) {
+      implicit_handshake(false),
+      use_early_callback(false),
+      fail_early_callback(false) {
 }
 
 bool ParseConfig(int argc, char **argv, TestConfig *out_config) {
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index 61b25c0..a54fb23 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -68,6 +68,8 @@
   int max_version;
   int mtu;
   bool implicit_handshake;
+  bool use_early_callback;
+  bool fail_early_callback;
 };
 
 bool ParseConfig(int argc, char **argv, TestConfig *out_config);