Fix a bug in and test the message callback.

reuse_message and V2ClientHellos each caused messages to be
double-reported.

Change-Id: I8722a3761ede272408ac9cf8e1b2ce383911cc6f
Reviewed-on: https://boringssl-review.googlesource.com/18764
Reviewed-by: Steven Valdez <svaldez@google.com>
Commit-Queue: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc
index f019bae..3ac21c3 100644
--- a/ssl/d1_both.cc
+++ b/ssl/d1_both.cc
@@ -442,14 +442,16 @@
   assert(frag->reassembly == NULL);
   assert(ssl->d1->handshake_read_seq == frag->seq);
 
+  if (ssl->init_msg == NULL) {
+    ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE, frag->data,
+                        frag->msg_len + DTLS1_HM_HEADER_LENGTH);
+  }
+
   /* TODO(davidben): This function has a lot of implicit outputs. Simplify the
    * |ssl_get_message| API. */
   ssl->s3->tmp.message_type = frag->type;
   ssl->init_msg = frag->data + DTLS1_HM_HEADER_LENGTH;
   ssl->init_num = frag->msg_len;
-
-  ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE, frag->data,
-                      ssl->init_num + DTLS1_HM_HEADER_LENGTH);
   return 1;
 }
 
diff --git a/ssl/s3_both.cc b/ssl/s3_both.cc
index b9aa3bc..4d53d53 100644
--- a/ssl/s3_both.cc
+++ b/ssl/s3_both.cc
@@ -731,8 +731,10 @@
   }
 
   /* We have now received a complete message. */
-  ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE, ssl->init_buf->data,
-                      ssl->init_buf->length);
+  if (ssl->init_msg == NULL && !ssl->s3->is_v2_hello) {
+    ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE,
+                        ssl->init_buf->data, ssl->init_buf->length);
+  }
 
   ssl->s3->tmp.message_type = ((const uint8_t *)ssl->init_buf->data)[0];
   ssl->init_msg = (uint8_t*)ssl->init_buf->data + SSL3_HM_HEADER_LENGTH;
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index a056be0..8f4126e 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -113,6 +113,8 @@
   bool is_resume = false;
   bool early_callback_ready = false;
   bool custom_verify_ready = false;
+  std::string msg_callback_text;
+  bool msg_callback_ok = true;
 };
 
 static void TestStateExFree(void *parent, void *ptr, CRYPTO_EX_DATA *ad,
@@ -993,6 +995,84 @@
   return SSL_TLSEXT_ERR_OK;
 }
 
+static void MessageCallback(int is_write, int version, int content_type,
+                            const void *buf, size_t len, SSL *ssl, void *arg) {
+  const uint8_t *buf_u8 = reinterpret_cast<const uint8_t *>(buf);
+  const TestConfig *config = GetTestConfig(ssl);
+  TestState *state = GetTestState(ssl);
+  if (!state->msg_callback_ok) {
+    return;
+  }
+
+  if (content_type == SSL3_RT_HEADER) {
+    if (len !=
+        (config->is_dtls ? DTLS1_RT_HEADER_LENGTH : SSL3_RT_HEADER_LENGTH)) {
+      fprintf(stderr, "Incorrect length for record header: %zu\n", len);
+      state->msg_callback_ok = false;
+    }
+    return;
+  }
+
+  state->msg_callback_text += is_write ? "write " : "read ";
+  switch (content_type) {
+    case 0:
+      if (version != SSL2_VERSION) {
+        fprintf(stderr, "Incorrect version for V2ClientHello: %x\n", version);
+        state->msg_callback_ok = false;
+        return;
+      }
+      state->msg_callback_text += "v2clienthello\n";
+      return;
+
+    case SSL3_RT_HANDSHAKE: {
+      CBS cbs;
+      CBS_init(&cbs, buf_u8, len);
+      uint8_t type;
+      uint32_t msg_len;
+      if (!CBS_get_u8(&cbs, &type) ||
+          /* TODO(davidben): Reporting on entire messages would be more
+           * consistent than fragments. */
+          (config->is_dtls &&
+           !CBS_skip(&cbs, 3 /* total */ + 2 /* seq */ + 3 /* frag_off */)) ||
+          !CBS_get_u24(&cbs, &msg_len) ||
+          !CBS_skip(&cbs, msg_len) ||
+          CBS_len(&cbs) != 0) {
+        fprintf(stderr, "Could not parse handshake message.\n");
+        state->msg_callback_ok = false;
+        return;
+      }
+      char text[16];
+      snprintf(text, sizeof(text), "hs %d\n", type);
+      state->msg_callback_text += text;
+      return;
+    }
+
+    case SSL3_RT_CHANGE_CIPHER_SPEC:
+      if (len != 1 || buf_u8[0] != 1) {
+        fprintf(stderr, "Invalid ChangeCipherSpec.\n");
+        state->msg_callback_ok = false;
+        return;
+      }
+      state->msg_callback_text += "ccs\n";
+      return;
+
+    case SSL3_RT_ALERT:
+      if (len != 2) {
+        fprintf(stderr, "Invalid alert.\n");
+        state->msg_callback_ok = false;
+        return;
+      }
+      char text[16];
+      snprintf(text, sizeof(text), "alert %d %d\n", buf_u8[0], buf_u8[1]);
+      state->msg_callback_text += text;
+      return;
+
+    default:
+      fprintf(stderr, "Invalid content_type: %d\n", content_type);
+      state->msg_callback_ok = false;
+  }
+}
+
 // Connect returns a new socket connected to localhost on |port| or -1 on
 // error.
 static int Connect(uint16_t port) {
@@ -1224,6 +1304,8 @@
     }
   }
 
+  SSL_CTX_set_msg_callback(ssl_ctx.get(), MessageCallback);
+
   if (old_ctx) {
     uint8_t keys[48];
     if (!SSL_CTX_get_tlsext_ticket_keys(old_ctx, &keys, sizeof(keys)) ||
@@ -2026,7 +2108,25 @@
 
     ret = DoExchange(out_session, ssl.get(), retry_config, is_resume, true);
   }
-  return ret;
+
+  if (!ret) {
+    return false;
+  }
+
+  if (!GetTestState(ssl.get())->msg_callback_ok) {
+    return false;
+  }
+
+  if (!config->expect_msg_callback.empty() &&
+      GetTestState(ssl.get())->msg_callback_text !=
+          config->expect_msg_callback) {
+    fprintf(stderr, "Bad message callback trace. Wanted:\n%s\nGot:\n%s\n",
+            config->expect_msg_callback.c_str(),
+            GetTestState(ssl.get())->msg_callback_text.c_str());
+    return false;
+  }
+
+  return true;
 }
 
 static bool DoExchange(bssl::UniquePtr<SSL_SESSION> *out_session, SSL *ssl,
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index d53c041..9898101 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -1448,6 +1448,54 @@
 			},
 			flags: []string{
 				"-enable-ocsp-stapling",
+				// This test involves an optional message. Test the message callback
+				// trace to ensure we do not miss or double-report any.
+				"-expect-msg-callback",
+				`write hs 1
+read hs 2
+read hs 11
+read hs 12
+read hs 14
+write hs 16
+write ccs
+write hs 20
+read hs 4
+read ccs
+read hs 20
+read alert 1 0
+`,
+			},
+		},
+		{
+			protocol: dtls,
+			name:     "SkipCertificateStatus-DTLS",
+			config: Config{
+				MaxVersion:   VersionTLS12,
+				CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
+				Bugs: ProtocolBugs{
+					SkipCertificateStatus: true,
+				},
+			},
+			flags: []string{
+				"-enable-ocsp-stapling",
+				// This test involves an optional message. Test the message callback
+				// trace to ensure we do not miss or double-report any.
+				"-expect-msg-callback",
+				`write hs 1
+read hs 3
+write hs 1
+read hs 2
+read hs 11
+read hs 12
+read hs 14
+write hs 16
+write ccs
+write hs 20
+read hs 4
+read ccs
+read hs 20
+read alert 1 0
+`,
 			},
 		},
 		{
@@ -4759,6 +4807,20 @@
 					SendV2ClientHello: true,
 				},
 			},
+			flags: []string{
+				"-expect-msg-callback",
+				`read v2clienthello
+write hs 2
+write hs 11
+write hs 14
+read hs 16
+read ccs
+read hs 20
+write ccs
+write hs 20
+read alert 1 0
+`,
+			},
 		})
 
 		// Test Channel ID
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index fa7dfe1..8b2f7f2 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -157,6 +157,7 @@
   { "-expect-peer-cert-file", &TestConfig::expect_peer_cert_file },
   { "-use-client-ca-list", &TestConfig::use_client_ca_list },
   { "-expect-client-ca-list", &TestConfig::expected_client_ca_list },
+  { "-expect-msg-callback", &TestConfig::expect_msg_callback },
 };
 
 const Flag<std::string> kBase64Flags[] = {
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index 1e5912e..af75548 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -145,6 +145,7 @@
   bool allow_unknown_alpn_protos = false;
   bool enable_ed25519 = false;
   bool use_custom_verify_callback = false;
+  std::string expect_msg_callback;
 };
 
 bool ParseConfig(int argc, char **argv, TestConfig *out_initial,