Fix a bug in and test the message callback.
reuse_message and V2ClientHellos each caused messages to be
double-reported.
Change-Id: I8722a3761ede272408ac9cf8e1b2ce383911cc6f
Reviewed-on: https://boringssl-review.googlesource.com/18764
Reviewed-by: Steven Valdez <svaldez@google.com>
Commit-Queue: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc
index f019bae..3ac21c3 100644
--- a/ssl/d1_both.cc
+++ b/ssl/d1_both.cc
@@ -442,14 +442,16 @@
assert(frag->reassembly == NULL);
assert(ssl->d1->handshake_read_seq == frag->seq);
+ if (ssl->init_msg == NULL) {
+ ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE, frag->data,
+ frag->msg_len + DTLS1_HM_HEADER_LENGTH);
+ }
+
/* TODO(davidben): This function has a lot of implicit outputs. Simplify the
* |ssl_get_message| API. */
ssl->s3->tmp.message_type = frag->type;
ssl->init_msg = frag->data + DTLS1_HM_HEADER_LENGTH;
ssl->init_num = frag->msg_len;
-
- ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE, frag->data,
- ssl->init_num + DTLS1_HM_HEADER_LENGTH);
return 1;
}
diff --git a/ssl/s3_both.cc b/ssl/s3_both.cc
index b9aa3bc..4d53d53 100644
--- a/ssl/s3_both.cc
+++ b/ssl/s3_both.cc
@@ -731,8 +731,10 @@
}
/* We have now received a complete message. */
- ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE, ssl->init_buf->data,
- ssl->init_buf->length);
+ if (ssl->init_msg == NULL && !ssl->s3->is_v2_hello) {
+ ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE,
+ ssl->init_buf->data, ssl->init_buf->length);
+ }
ssl->s3->tmp.message_type = ((const uint8_t *)ssl->init_buf->data)[0];
ssl->init_msg = (uint8_t*)ssl->init_buf->data + SSL3_HM_HEADER_LENGTH;
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index a056be0..8f4126e 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -113,6 +113,8 @@
bool is_resume = false;
bool early_callback_ready = false;
bool custom_verify_ready = false;
+ std::string msg_callback_text;
+ bool msg_callback_ok = true;
};
static void TestStateExFree(void *parent, void *ptr, CRYPTO_EX_DATA *ad,
@@ -993,6 +995,84 @@
return SSL_TLSEXT_ERR_OK;
}
+static void MessageCallback(int is_write, int version, int content_type,
+ const void *buf, size_t len, SSL *ssl, void *arg) {
+ const uint8_t *buf_u8 = reinterpret_cast<const uint8_t *>(buf);
+ const TestConfig *config = GetTestConfig(ssl);
+ TestState *state = GetTestState(ssl);
+ if (!state->msg_callback_ok) {
+ return;
+ }
+
+ if (content_type == SSL3_RT_HEADER) {
+ if (len !=
+ (config->is_dtls ? DTLS1_RT_HEADER_LENGTH : SSL3_RT_HEADER_LENGTH)) {
+ fprintf(stderr, "Incorrect length for record header: %zu\n", len);
+ state->msg_callback_ok = false;
+ }
+ return;
+ }
+
+ state->msg_callback_text += is_write ? "write " : "read ";
+ switch (content_type) {
+ case 0:
+ if (version != SSL2_VERSION) {
+ fprintf(stderr, "Incorrect version for V2ClientHello: %x\n", version);
+ state->msg_callback_ok = false;
+ return;
+ }
+ state->msg_callback_text += "v2clienthello\n";
+ return;
+
+ case SSL3_RT_HANDSHAKE: {
+ CBS cbs;
+ CBS_init(&cbs, buf_u8, len);
+ uint8_t type;
+ uint32_t msg_len;
+ if (!CBS_get_u8(&cbs, &type) ||
+ /* TODO(davidben): Reporting on entire messages would be more
+ * consistent than fragments. */
+ (config->is_dtls &&
+ !CBS_skip(&cbs, 3 /* total */ + 2 /* seq */ + 3 /* frag_off */)) ||
+ !CBS_get_u24(&cbs, &msg_len) ||
+ !CBS_skip(&cbs, msg_len) ||
+ CBS_len(&cbs) != 0) {
+ fprintf(stderr, "Could not parse handshake message.\n");
+ state->msg_callback_ok = false;
+ return;
+ }
+ char text[16];
+ snprintf(text, sizeof(text), "hs %d\n", type);
+ state->msg_callback_text += text;
+ return;
+ }
+
+ case SSL3_RT_CHANGE_CIPHER_SPEC:
+ if (len != 1 || buf_u8[0] != 1) {
+ fprintf(stderr, "Invalid ChangeCipherSpec.\n");
+ state->msg_callback_ok = false;
+ return;
+ }
+ state->msg_callback_text += "ccs\n";
+ return;
+
+ case SSL3_RT_ALERT:
+ if (len != 2) {
+ fprintf(stderr, "Invalid alert.\n");
+ state->msg_callback_ok = false;
+ return;
+ }
+ char text[16];
+ snprintf(text, sizeof(text), "alert %d %d\n", buf_u8[0], buf_u8[1]);
+ state->msg_callback_text += text;
+ return;
+
+ default:
+ fprintf(stderr, "Invalid content_type: %d\n", content_type);
+ state->msg_callback_ok = false;
+ }
+}
+
// Connect returns a new socket connected to localhost on |port| or -1 on
// error.
static int Connect(uint16_t port) {
@@ -1224,6 +1304,8 @@
}
}
+ SSL_CTX_set_msg_callback(ssl_ctx.get(), MessageCallback);
+
if (old_ctx) {
uint8_t keys[48];
if (!SSL_CTX_get_tlsext_ticket_keys(old_ctx, &keys, sizeof(keys)) ||
@@ -2026,7 +2108,25 @@
ret = DoExchange(out_session, ssl.get(), retry_config, is_resume, true);
}
- return ret;
+
+ if (!ret) {
+ return false;
+ }
+
+ if (!GetTestState(ssl.get())->msg_callback_ok) {
+ return false;
+ }
+
+ if (!config->expect_msg_callback.empty() &&
+ GetTestState(ssl.get())->msg_callback_text !=
+ config->expect_msg_callback) {
+ fprintf(stderr, "Bad message callback trace. Wanted:\n%s\nGot:\n%s\n",
+ config->expect_msg_callback.c_str(),
+ GetTestState(ssl.get())->msg_callback_text.c_str());
+ return false;
+ }
+
+ return true;
}
static bool DoExchange(bssl::UniquePtr<SSL_SESSION> *out_session, SSL *ssl,
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index d53c041..9898101 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -1448,6 +1448,54 @@
},
flags: []string{
"-enable-ocsp-stapling",
+ // This test involves an optional message. Test the message callback
+ // trace to ensure we do not miss or double-report any.
+ "-expect-msg-callback",
+ `write hs 1
+read hs 2
+read hs 11
+read hs 12
+read hs 14
+write hs 16
+write ccs
+write hs 20
+read hs 4
+read ccs
+read hs 20
+read alert 1 0
+`,
+ },
+ },
+ {
+ protocol: dtls,
+ name: "SkipCertificateStatus-DTLS",
+ config: Config{
+ MaxVersion: VersionTLS12,
+ CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
+ Bugs: ProtocolBugs{
+ SkipCertificateStatus: true,
+ },
+ },
+ flags: []string{
+ "-enable-ocsp-stapling",
+ // This test involves an optional message. Test the message callback
+ // trace to ensure we do not miss or double-report any.
+ "-expect-msg-callback",
+ `write hs 1
+read hs 3
+write hs 1
+read hs 2
+read hs 11
+read hs 12
+read hs 14
+write hs 16
+write ccs
+write hs 20
+read hs 4
+read ccs
+read hs 20
+read alert 1 0
+`,
},
},
{
@@ -4759,6 +4807,20 @@
SendV2ClientHello: true,
},
},
+ flags: []string{
+ "-expect-msg-callback",
+ `read v2clienthello
+write hs 2
+write hs 11
+write hs 14
+read hs 16
+read ccs
+read hs 20
+write ccs
+write hs 20
+read alert 1 0
+`,
+ },
})
// Test Channel ID
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index fa7dfe1..8b2f7f2 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -157,6 +157,7 @@
{ "-expect-peer-cert-file", &TestConfig::expect_peer_cert_file },
{ "-use-client-ca-list", &TestConfig::use_client_ca_list },
{ "-expect-client-ca-list", &TestConfig::expected_client_ca_list },
+ { "-expect-msg-callback", &TestConfig::expect_msg_callback },
};
const Flag<std::string> kBase64Flags[] = {
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index 1e5912e..af75548 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -145,6 +145,7 @@
bool allow_unknown_alpn_protos = false;
bool enable_ed25519 = false;
bool use_custom_verify_callback = false;
+ std::string expect_msg_callback;
};
bool ParseConfig(int argc, char **argv, TestConfig *out_initial,