Refactor async logic in bssl_shim slightly.

Move the state to TestState rather than passing pointers to them everywhere.
Also move SSL_read and SSL_write retry loops into helper functions so they
aren't repeated everywhere. This also makes the SSL_write calls all
consistently account for partial writes.

Change-Id: I9bc083a03da6a77ab2fc03c29d4028435fc02620
Reviewed-on: https://boringssl-review.googlesource.com/4214
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index b22819e..eca7b6b 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -68,6 +68,13 @@
 }
 
 struct TestState {
+  // async_bio is async BIO which pauses reads and writes.
+  BIO *async_bio = nullptr;
+  // clock is the current time for the SSL connection.
+  OPENSSL_timeval clock = {0};
+  // clock_delta is how far the clock advanced in the most recent failed
+  // |BIO_read|.
+  OPENSSL_timeval clock_delta = {0};
   ScopedEVP_PKEY channel_id;
   bool cert_ready = false;
   ScopedSSL_SESSION session;
@@ -82,7 +89,6 @@
 }
 
 static int g_config_index = 0;
-static int g_clock_index = 0;
 static int g_state_index = 0;
 
 static bool SetConfigPtr(SSL *ssl, const TestConfig *config) {
@@ -93,14 +99,6 @@
   return (const TestConfig *)SSL_get_ex_data(ssl, g_config_index);
 }
 
-static bool SetClockPtr(SSL *ssl, OPENSSL_timeval *clock) {
-  return SSL_set_ex_data(ssl, g_clock_index, (void *)clock) == 1;
-}
-
-static OPENSSL_timeval *GetClockPtr(const SSL *ssl) {
-  return (OPENSSL_timeval *)SSL_get_ex_data(ssl, g_clock_index);
-}
-
 static bool SetTestState(SSL *ssl, std::unique_ptr<TestState> async) {
   if (SSL_set_ex_data(ssl, g_state_index, (void *)async.get()) == 1) {
     async.release();
@@ -280,7 +278,7 @@
 }
 
 static void CurrentTimeCallback(SSL *ssl, OPENSSL_timeval *out_clock) {
-  *out_clock = *GetClockPtr(ssl);
+  *out_clock = GetTestState(ssl)->clock;
 }
 
 static void ChannelIdCallback(SSL *ssl, EVP_PKEY **out_pkey) {
@@ -455,23 +453,22 @@
 
 // RetryAsync is called after a failed operation on |ssl| with return code
 // |ret|. If the operation should be retried, it simulates one asynchronous
-// event and returns true. Otherwise it returns false. |async| and |clock_delta|
-// are the AsyncBio and simulated timeout for |ssl|, respectively.
-static bool RetryAsync(SSL *ssl, int ret, BIO *async,
-                       OPENSSL_timeval *clock_delta) {
+// event and returns true. Otherwise it returns false.
+static bool RetryAsync(SSL *ssl, int ret) {
   // No error; don't retry.
   if (ret >= 0) {
     return false;
   }
 
-  if (clock_delta->tv_usec != 0 || clock_delta->tv_sec != 0) {
+  TestState *test_state = GetTestState(ssl);
+  if (test_state->clock_delta.tv_usec != 0 ||
+      test_state->clock_delta.tv_sec != 0) {
     // Process the timeout and retry.
-    OPENSSL_timeval *clock = GetClockPtr(ssl);
-    clock->tv_usec += clock_delta->tv_usec;
-    clock->tv_sec += clock->tv_usec / 1000000;
-    clock->tv_usec %= 1000000;
-    clock->tv_sec += clock_delta->tv_sec;
-    memset(clock_delta, 0, sizeof(*clock_delta));
+    test_state->clock.tv_usec += test_state->clock_delta.tv_usec;
+    test_state->clock.tv_sec += test_state->clock.tv_usec / 1000000;
+    test_state->clock.tv_usec %= 1000000;
+    test_state->clock.tv_sec += test_state->clock_delta.tv_sec;
+    memset(&test_state->clock_delta, 0, sizeof(test_state->clock_delta));
 
     if (DTLSv1_handle_timeout(ssl) < 0) {
       printf("Error retransmitting.\n");
@@ -484,25 +481,24 @@
   // the appropriate end to maximally stress the state machine.
   switch (SSL_get_error(ssl, ret)) {
     case SSL_ERROR_WANT_READ:
-      AsyncBioAllowRead(async, 1);
+      AsyncBioAllowRead(test_state->async_bio, 1);
       return true;
     case SSL_ERROR_WANT_WRITE:
-      AsyncBioAllowWrite(async, 1);
+      AsyncBioAllowWrite(test_state->async_bio, 1);
       return true;
     case SSL_ERROR_WANT_CHANNEL_ID_LOOKUP: {
       ScopedEVP_PKEY pkey = LoadPrivateKey(GetConfigPtr(ssl)->send_channel_id);
       if (!pkey) {
         return false;
       }
-      GetTestState(ssl)->channel_id = std::move(pkey);
+      test_state->channel_id = std::move(pkey);
       return true;
     }
     case SSL_ERROR_WANT_X509_LOOKUP:
-      GetTestState(ssl)->cert_ready = true;
+      test_state->cert_ready = true;
       return true;
     case SSL_ERROR_PENDING_SESSION:
-      GetTestState(ssl)->session =
-          std::move(GetTestState(ssl)->pending_session);
+      test_state->session = std::move(test_state->pending_session);
       return true;
     case SSL_ERROR_PENDING_CERTIFICATE:
       // The handshake will resume without a second call to the early callback.
@@ -512,6 +508,32 @@
   }
 }
 
+// DoRead reads from |ssl|, resolving any asynchronous operations. It returns
+// the result value of the final |SSL_read| call.
+static int DoRead(SSL *ssl, uint8_t *out, size_t max_out) {
+  const TestConfig *config = GetConfigPtr(ssl);
+  int ret;
+  do {
+    ret = SSL_read(ssl, out, max_out);
+  } while (config->async && RetryAsync(ssl, ret));
+  return ret;
+}
+
+// WriteAll writes |in_len| bytes from |in| to |ssl|, resolving any asynchronous
+// operations. It returns the result of the final |SSL_write| call.
+static int WriteAll(SSL *ssl, const uint8_t *in, size_t in_len) {
+  const TestConfig *config = GetConfigPtr(ssl);
+  int ret;
+  do {
+    ret = SSL_write(ssl, in, in_len);
+    if (ret > 0) {
+      in += ret;
+      in_len -= ret;
+    }
+  } while ((config->async && RetryAsync(ssl, ret)) || (ret > 0 && in_len > 0));
+  return ret;
+}
+
 // DoExchange runs a test SSL exchange against the peer. On success, it returns
 // true and sets |*out_session| to the negotiated SSL session. If the test is a
 // resumption attempt, |is_resume| is true and |session| is the session from the
@@ -519,14 +541,12 @@
 static bool DoExchange(ScopedSSL_SESSION *out_session, SSL_CTX *ssl_ctx,
                        const TestConfig *config, bool is_resume,
                        SSL_SESSION *session) {
-  OPENSSL_timeval clock = {0}, clock_delta = {0};
   ScopedSSL ssl(SSL_new(ssl_ctx));
   if (!ssl) {
     return false;
   }
 
   if (!SetConfigPtr(ssl.get(), config) ||
-      !SetClockPtr(ssl.get(), &clock) |
       !SetTestState(ssl.get(), std::unique_ptr<TestState>(new TestState))) {
     return false;
   }
@@ -647,16 +667,16 @@
     return false;
   }
   if (config->is_dtls) {
-    ScopedBIO packeted = PacketedBioCreate(&clock_delta);
+    ScopedBIO packeted =
+        PacketedBioCreate(&GetTestState(ssl.get())->clock_delta);
     BIO_push(packeted.get(), bio.release());
     bio = std::move(packeted);
   }
-  BIO *async = NULL;
   if (config->async) {
     ScopedBIO async_scoped =
         config->is_dtls ? AsyncBioCreateDatagram() : AsyncBioCreate();
     BIO_push(async_scoped.get(), bio.release());
-    async = async_scoped.get();
+    GetTestState(ssl.get())->async_bio = async_scoped.get();
     bio = std::move(async_scoped);
   }
   SSL_set_bio(ssl.get(), bio.get(), bio.get());
@@ -689,7 +709,7 @@
       } else {
         ret = SSL_connect(ssl.get());
       }
-    } while (config->async && RetryAsync(ssl.get(), ret, async, &clock_delta));
+    } while (config->async && RetryAsync(ssl.get(), ret));
     if (ret != 1) {
       return false;
     }
@@ -843,40 +863,25 @@
         0, 1, 255, 256, 257, 16383, 16384, 16385, 32767, 32768, 32769};
     for (size_t i = 0; i < sizeof(kRecordSizes) / sizeof(kRecordSizes[0]);
          i++) {
-      int w;
       const size_t len = kRecordSizes[i];
-      size_t off = 0;
-
       if (len > sizeof(buf)) {
         fprintf(stderr, "Bad kRecordSizes value.\n");
         return false;
       }
-
-      do {
-        w = SSL_write(ssl.get(), buf + off, len - off);
-        if (w > 0) {
-          off += (size_t) w;
-        }
-      } while ((config->async && RetryAsync(ssl.get(), w, async, &clock_delta)) ||
-               (w > 0 && off < len));
-
-      if (w < 0 || off != len) {
+      if (WriteAll(ssl.get(), buf, len) < 0) {
         return false;
       }
     }
   } else {
     if (config->shim_writes_first) {
-      int w;
-      do {
-        w = SSL_write(ssl.get(), "hello", 5);
-      } while (config->async && RetryAsync(ssl.get(), w, async, &clock_delta));
+      if (WriteAll(ssl.get(), reinterpret_cast<const uint8_t *>("hello"),
+                   5) < 0) {
+        return false;
+      }
     }
     for (;;) {
       uint8_t buf[512];
-      int n;
-      do {
-        n = SSL_read(ssl.get(), buf, sizeof(buf));
-      } while (config->async && RetryAsync(ssl.get(), n, async, &clock_delta));
+      int n = DoRead(ssl.get(), buf, sizeof(buf));
       int err = SSL_get_error(ssl.get(), n);
       if (err == SSL_ERROR_ZERO_RETURN ||
           (n == 0 && err == SSL_ERROR_SYSCALL)) {
@@ -910,11 +915,7 @@
       for (int i = 0; i < n; i++) {
         buf[i] ^= 0xff;
       }
-      int w;
-      do {
-        w = SSL_write(ssl.get(), buf, n);
-      } while (config->async && RetryAsync(ssl.get(), w, async, &clock_delta));
-      if (w != n) {
+      if (WriteAll(ssl.get(), buf, n) < 0) {
         return false;
       }
     }
@@ -950,9 +951,8 @@
     return 1;
   }
   g_config_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
-  g_clock_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
   g_state_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, TestStateExFree);
-  if (g_config_index < 0 || g_clock_index < 0 || g_state_index < 0) {
+  if (g_config_index < 0 || g_state_index < 0) {
     return 1;
   }