Test async channel ID callback.

Start exercising the various async callbacks, starting with channel ID. These
will run under the existing state machine coverage tests; -async will also
enable every asynchronous callback we can.

Change-Id: I173148d93d3a9c575b3abc3e2aceb77968b88f0e
Reviewed-on: https://boringssl-review.googlesource.com/3342
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index eddbefc..908b020 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -40,23 +40,45 @@
   return 1;
 }
 
-static int g_ex_data_index = 0;
-static int g_ex_data_clock_index = 0;
+struct AsyncState {
+  ScopedEVP_PKEY channel_id;
+};
+
+static void AsyncExFree(void *parent, void *ptr, CRYPTO_EX_DATA *ad, int index,
+                        long argl, void *argp) {
+  delete ((AsyncState *)ptr);
+}
+
+static int g_config_index = 0;
+static int g_clock_index = 0;
+static int g_async_index = 0;
 
 static bool SetConfigPtr(SSL *ssl, const TestConfig *config) {
-  return SSL_set_ex_data(ssl, g_ex_data_index, (void *)config) == 1;
+  return SSL_set_ex_data(ssl, g_config_index, (void *)config) == 1;
 }
 
 static const TestConfig *GetConfigPtr(SSL *ssl) {
-  return (const TestConfig *)SSL_get_ex_data(ssl, g_ex_data_index);
+  return (const TestConfig *)SSL_get_ex_data(ssl, g_config_index);
 }
 
 static bool SetClockPtr(SSL *ssl, OPENSSL_timeval *clock) {
-  return SSL_set_ex_data(ssl, g_ex_data_clock_index, (void *)clock) == 1;
+  return SSL_set_ex_data(ssl, g_clock_index, (void *)clock) == 1;
 }
 
 static OPENSSL_timeval *GetClockPtr(SSL *ssl) {
-  return (OPENSSL_timeval *)SSL_get_ex_data(ssl, g_ex_data_clock_index);
+  return (OPENSSL_timeval *)SSL_get_ex_data(ssl, g_clock_index);
+}
+
+static bool SetAsyncState(SSL *ssl, std::unique_ptr<AsyncState> async) {
+  if (SSL_set_ex_data(ssl, g_async_index, (void *)async.get()) == 1) {
+    async.release();
+    return true;
+  }
+  return false;
+}
+
+static AsyncState *GetAsyncState(SSL *ssl) {
+  return (AsyncState *)SSL_get_ex_data(ssl, g_async_index);
 }
 
 static ScopedEVP_PKEY LoadPrivateKey(const std::string &file) {
@@ -236,6 +258,10 @@
   *out_clock = *GetClockPtr(ssl);
 }
 
+static void channel_id_callback(SSL *ssl, EVP_PKEY **out_pkey) {
+  *out_pkey = GetAsyncState(ssl)->channel_id.release();
+}
+
 static ScopedSSL_CTX setup_ctx(const TestConfig *config) {
   ScopedSSL_CTX ssl_ctx(SSL_CTX_new(
       config->is_dtls ? DTLS_method() : TLS_method()));
@@ -283,6 +309,7 @@
   SSL_CTX_set_cookie_verify_cb(ssl_ctx.get(), cookie_verify_callback);
 
   ssl_ctx->tlsext_channel_id_enabled_new = 1;
+  SSL_CTX_set_channel_id_cb(ssl_ctx.get(), channel_id_callback);
 
   ssl_ctx->current_time_cb = current_time_cb;
 
@@ -314,15 +341,20 @@
 
   // See if we needed to read or write more. If so, allow one byte through on
   // the appropriate end to maximally stress the state machine.
-  int err = SSL_get_error(ssl, ret);
-  if (err == SSL_ERROR_WANT_READ) {
-    async_bio_allow_read(async, 1);
-    return 1;
-  } else if (err == SSL_ERROR_WANT_WRITE) {
-    async_bio_allow_write(async, 1);
-    return 1;
+  switch (SSL_get_error(ssl, ret)) {
+    case SSL_ERROR_WANT_READ:
+      async_bio_allow_read(async, 1);
+      return 1;
+    case SSL_ERROR_WANT_WRITE:
+      async_bio_allow_write(async, 1);
+      return 1;
+    case SSL_ERROR_WANT_CHANNEL_ID_LOOKUP:
+      GetAsyncState(ssl)->channel_id =
+          LoadPrivateKey(GetConfigPtr(ssl)->send_channel_id);
+      return 1;
+    default:
+      return 0;
   }
-  return 0;
 }
 
 static int do_exchange(ScopedSSL_SESSION *out_session,
@@ -341,7 +373,8 @@
   }
 
   if (!SetConfigPtr(ssl.get(), config) ||
-      !SetClockPtr(ssl.get(), &clock)) {
+      !SetClockPtr(ssl.get(), &clock) |
+      !SetAsyncState(ssl.get(), std::unique_ptr<AsyncState>(new AsyncState))) {
     BIO_print_errors_fp(stdout);
     return 1;
   }
@@ -405,10 +438,13 @@
   }
   if (!config->send_channel_id.empty()) {
     SSL_enable_tls_channel_id(ssl.get());
-    ScopedEVP_PKEY pkey = LoadPrivateKey(config->send_channel_id);
-    if (!pkey || !SSL_set1_tls_channel_id(ssl.get(), pkey.get())) {
-      BIO_print_errors_fp(stdout);
-      return 1;
+    if (!config->async) {
+      // The async case will be supplied by |channel_id_callback|.
+      ScopedEVP_PKEY pkey = LoadPrivateKey(config->send_channel_id);
+      if (!pkey || !SSL_set1_tls_channel_id(ssl.get(), pkey.get())) {
+        BIO_print_errors_fp(stdout);
+        return 1;
+      }
     }
   }
   if (!config->host_name.empty()) {
@@ -734,9 +770,10 @@
   if (!SSL_library_init()) {
     return 1;
   }
-  g_ex_data_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
-  g_ex_data_clock_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
-  if (g_ex_data_index < 0 || g_ex_data_clock_index < 0) {
+  g_config_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
+  g_clock_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
+  g_async_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, AsyncExFree);
+  if (g_config_index < 0 || g_clock_index < 0 || g_async_index < 0) {
     return 1;
   }