Refactor async logic in bssl_shim slightly.
Move the state to TestState rather than passing pointers to them everywhere.
Also move SSL_read and SSL_write retry loops into helper functions so they
aren't repeated everywhere. This also makes the SSL_write calls all
consistently account for partial writes.
Change-Id: I9bc083a03da6a77ab2fc03c29d4028435fc02620
Reviewed-on: https://boringssl-review.googlesource.com/4214
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index b22819e..eca7b6b 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -68,6 +68,13 @@
}
struct TestState {
+ // async_bio is async BIO which pauses reads and writes.
+ BIO *async_bio = nullptr;
+ // clock is the current time for the SSL connection.
+ OPENSSL_timeval clock = {0};
+ // clock_delta is how far the clock advanced in the most recent failed
+ // |BIO_read|.
+ OPENSSL_timeval clock_delta = {0};
ScopedEVP_PKEY channel_id;
bool cert_ready = false;
ScopedSSL_SESSION session;
@@ -82,7 +89,6 @@
}
static int g_config_index = 0;
-static int g_clock_index = 0;
static int g_state_index = 0;
static bool SetConfigPtr(SSL *ssl, const TestConfig *config) {
@@ -93,14 +99,6 @@
return (const TestConfig *)SSL_get_ex_data(ssl, g_config_index);
}
-static bool SetClockPtr(SSL *ssl, OPENSSL_timeval *clock) {
- return SSL_set_ex_data(ssl, g_clock_index, (void *)clock) == 1;
-}
-
-static OPENSSL_timeval *GetClockPtr(const SSL *ssl) {
- return (OPENSSL_timeval *)SSL_get_ex_data(ssl, g_clock_index);
-}
-
static bool SetTestState(SSL *ssl, std::unique_ptr<TestState> async) {
if (SSL_set_ex_data(ssl, g_state_index, (void *)async.get()) == 1) {
async.release();
@@ -280,7 +278,7 @@
}
static void CurrentTimeCallback(SSL *ssl, OPENSSL_timeval *out_clock) {
- *out_clock = *GetClockPtr(ssl);
+ *out_clock = GetTestState(ssl)->clock;
}
static void ChannelIdCallback(SSL *ssl, EVP_PKEY **out_pkey) {
@@ -455,23 +453,22 @@
// RetryAsync is called after a failed operation on |ssl| with return code
// |ret|. If the operation should be retried, it simulates one asynchronous
-// event and returns true. Otherwise it returns false. |async| and |clock_delta|
-// are the AsyncBio and simulated timeout for |ssl|, respectively.
-static bool RetryAsync(SSL *ssl, int ret, BIO *async,
- OPENSSL_timeval *clock_delta) {
+// event and returns true. Otherwise it returns false.
+static bool RetryAsync(SSL *ssl, int ret) {
// No error; don't retry.
if (ret >= 0) {
return false;
}
- if (clock_delta->tv_usec != 0 || clock_delta->tv_sec != 0) {
+ TestState *test_state = GetTestState(ssl);
+ if (test_state->clock_delta.tv_usec != 0 ||
+ test_state->clock_delta.tv_sec != 0) {
// Process the timeout and retry.
- OPENSSL_timeval *clock = GetClockPtr(ssl);
- clock->tv_usec += clock_delta->tv_usec;
- clock->tv_sec += clock->tv_usec / 1000000;
- clock->tv_usec %= 1000000;
- clock->tv_sec += clock_delta->tv_sec;
- memset(clock_delta, 0, sizeof(*clock_delta));
+ test_state->clock.tv_usec += test_state->clock_delta.tv_usec;
+ test_state->clock.tv_sec += test_state->clock.tv_usec / 1000000;
+ test_state->clock.tv_usec %= 1000000;
+ test_state->clock.tv_sec += test_state->clock_delta.tv_sec;
+ memset(&test_state->clock_delta, 0, sizeof(test_state->clock_delta));
if (DTLSv1_handle_timeout(ssl) < 0) {
printf("Error retransmitting.\n");
@@ -484,25 +481,24 @@
// the appropriate end to maximally stress the state machine.
switch (SSL_get_error(ssl, ret)) {
case SSL_ERROR_WANT_READ:
- AsyncBioAllowRead(async, 1);
+ AsyncBioAllowRead(test_state->async_bio, 1);
return true;
case SSL_ERROR_WANT_WRITE:
- AsyncBioAllowWrite(async, 1);
+ AsyncBioAllowWrite(test_state->async_bio, 1);
return true;
case SSL_ERROR_WANT_CHANNEL_ID_LOOKUP: {
ScopedEVP_PKEY pkey = LoadPrivateKey(GetConfigPtr(ssl)->send_channel_id);
if (!pkey) {
return false;
}
- GetTestState(ssl)->channel_id = std::move(pkey);
+ test_state->channel_id = std::move(pkey);
return true;
}
case SSL_ERROR_WANT_X509_LOOKUP:
- GetTestState(ssl)->cert_ready = true;
+ test_state->cert_ready = true;
return true;
case SSL_ERROR_PENDING_SESSION:
- GetTestState(ssl)->session =
- std::move(GetTestState(ssl)->pending_session);
+ test_state->session = std::move(test_state->pending_session);
return true;
case SSL_ERROR_PENDING_CERTIFICATE:
// The handshake will resume without a second call to the early callback.
@@ -512,6 +508,32 @@
}
}
+// DoRead reads from |ssl|, resolving any asynchronous operations. It returns
+// the result value of the final |SSL_read| call.
+static int DoRead(SSL *ssl, uint8_t *out, size_t max_out) {
+ const TestConfig *config = GetConfigPtr(ssl);
+ int ret;
+ do {
+ ret = SSL_read(ssl, out, max_out);
+ } while (config->async && RetryAsync(ssl, ret));
+ return ret;
+}
+
+// WriteAll writes |in_len| bytes from |in| to |ssl|, resolving any asynchronous
+// operations. It returns the result of the final |SSL_write| call.
+static int WriteAll(SSL *ssl, const uint8_t *in, size_t in_len) {
+ const TestConfig *config = GetConfigPtr(ssl);
+ int ret;
+ do {
+ ret = SSL_write(ssl, in, in_len);
+ if (ret > 0) {
+ in += ret;
+ in_len -= ret;
+ }
+ } while ((config->async && RetryAsync(ssl, ret)) || (ret > 0 && in_len > 0));
+ return ret;
+}
+
// DoExchange runs a test SSL exchange against the peer. On success, it returns
// true and sets |*out_session| to the negotiated SSL session. If the test is a
// resumption attempt, |is_resume| is true and |session| is the session from the
@@ -519,14 +541,12 @@
static bool DoExchange(ScopedSSL_SESSION *out_session, SSL_CTX *ssl_ctx,
const TestConfig *config, bool is_resume,
SSL_SESSION *session) {
- OPENSSL_timeval clock = {0}, clock_delta = {0};
ScopedSSL ssl(SSL_new(ssl_ctx));
if (!ssl) {
return false;
}
if (!SetConfigPtr(ssl.get(), config) ||
- !SetClockPtr(ssl.get(), &clock) |
!SetTestState(ssl.get(), std::unique_ptr<TestState>(new TestState))) {
return false;
}
@@ -647,16 +667,16 @@
return false;
}
if (config->is_dtls) {
- ScopedBIO packeted = PacketedBioCreate(&clock_delta);
+ ScopedBIO packeted =
+ PacketedBioCreate(&GetTestState(ssl.get())->clock_delta);
BIO_push(packeted.get(), bio.release());
bio = std::move(packeted);
}
- BIO *async = NULL;
if (config->async) {
ScopedBIO async_scoped =
config->is_dtls ? AsyncBioCreateDatagram() : AsyncBioCreate();
BIO_push(async_scoped.get(), bio.release());
- async = async_scoped.get();
+ GetTestState(ssl.get())->async_bio = async_scoped.get();
bio = std::move(async_scoped);
}
SSL_set_bio(ssl.get(), bio.get(), bio.get());
@@ -689,7 +709,7 @@
} else {
ret = SSL_connect(ssl.get());
}
- } while (config->async && RetryAsync(ssl.get(), ret, async, &clock_delta));
+ } while (config->async && RetryAsync(ssl.get(), ret));
if (ret != 1) {
return false;
}
@@ -843,40 +863,25 @@
0, 1, 255, 256, 257, 16383, 16384, 16385, 32767, 32768, 32769};
for (size_t i = 0; i < sizeof(kRecordSizes) / sizeof(kRecordSizes[0]);
i++) {
- int w;
const size_t len = kRecordSizes[i];
- size_t off = 0;
-
if (len > sizeof(buf)) {
fprintf(stderr, "Bad kRecordSizes value.\n");
return false;
}
-
- do {
- w = SSL_write(ssl.get(), buf + off, len - off);
- if (w > 0) {
- off += (size_t) w;
- }
- } while ((config->async && RetryAsync(ssl.get(), w, async, &clock_delta)) ||
- (w > 0 && off < len));
-
- if (w < 0 || off != len) {
+ if (WriteAll(ssl.get(), buf, len) < 0) {
return false;
}
}
} else {
if (config->shim_writes_first) {
- int w;
- do {
- w = SSL_write(ssl.get(), "hello", 5);
- } while (config->async && RetryAsync(ssl.get(), w, async, &clock_delta));
+ if (WriteAll(ssl.get(), reinterpret_cast<const uint8_t *>("hello"),
+ 5) < 0) {
+ return false;
+ }
}
for (;;) {
uint8_t buf[512];
- int n;
- do {
- n = SSL_read(ssl.get(), buf, sizeof(buf));
- } while (config->async && RetryAsync(ssl.get(), n, async, &clock_delta));
+ int n = DoRead(ssl.get(), buf, sizeof(buf));
int err = SSL_get_error(ssl.get(), n);
if (err == SSL_ERROR_ZERO_RETURN ||
(n == 0 && err == SSL_ERROR_SYSCALL)) {
@@ -910,11 +915,7 @@
for (int i = 0; i < n; i++) {
buf[i] ^= 0xff;
}
- int w;
- do {
- w = SSL_write(ssl.get(), buf, n);
- } while (config->async && RetryAsync(ssl.get(), w, async, &clock_delta));
- if (w != n) {
+ if (WriteAll(ssl.get(), buf, n) < 0) {
return false;
}
}
@@ -950,9 +951,8 @@
return 1;
}
g_config_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
- g_clock_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
g_state_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, TestStateExFree);
- if (g_config_index < 0 || g_clock_index < 0 || g_state_index < 0) {
+ if (g_config_index < 0 || g_state_index < 0) {
return 1;
}