Fix record header callback on writes.

These broke at some point. Add a test for them.

Change-Id: Ie45869e07d9615ae33aae4613f6d9b996af39528
Reviewed-on: https://boringssl-review.googlesource.com/17330
Commit-Queue: Adam Langley <agl@google.com>
Reviewed-by: Adam Langley <agl@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 31899a5..84b7496 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -3046,6 +3046,81 @@
   return true;
 }
 
+static bool TestRecordCallback(bool is_dtls, const SSL_METHOD *method,
+                               uint16_t version) {
+  bssl::UniquePtr<X509> cert = GetChainTestCertificate();
+  bssl::UniquePtr<X509> intermediate = GetChainTestIntermediate();
+  bssl::UniquePtr<EVP_PKEY> key = GetChainTestKey();
+  if (!cert || !intermediate || !key) {
+    return false;
+  }
+
+  bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(method));
+  if (!ctx ||
+      !SSL_CTX_use_certificate(ctx.get(), cert.get()) ||
+      !SSL_CTX_use_PrivateKey(ctx.get(), key.get()) ||
+      !SSL_CTX_set_min_proto_version(ctx.get(), version) ||
+      !SSL_CTX_set_max_proto_version(ctx.get(), version)) {
+    return false;
+  }
+
+  bool read_seen = false;
+  bool write_seen = false;
+  auto cb = [&](int is_write, int cb_version, int cb_type, const void *buf,
+                size_t len, SSL *ssl) {
+    if (cb_type != SSL3_RT_HEADER) {
+      return;
+    }
+
+    // The callback does not report a version for records.
+    EXPECT_EQ(0, cb_version);
+
+    if (is_write) {
+      write_seen = true;
+    } else {
+      read_seen = true;
+    }
+
+    // Sanity-check that the record header is plausible.
+    CBS cbs;
+    CBS_init(&cbs, reinterpret_cast<const uint8_t *>(buf), len);
+    uint8_t type;
+    uint16_t record_version, length;
+    ASSERT_TRUE(CBS_get_u8(&cbs, &type));
+    ASSERT_TRUE(CBS_get_u16(&cbs, &record_version));
+    EXPECT_TRUE(record_version == version ||
+                record_version == (is_dtls ? DTLS1_VERSION : TLS1_VERSION))
+        << "Invalid record version: " << record_version;
+    if (is_dtls) {
+      uint16_t epoch;
+      ASSERT_TRUE(CBS_get_u16(&cbs, &epoch));
+      EXPECT_TRUE(epoch == 0 || epoch == 1) << "Invalid epoch: " << epoch;
+      ASSERT_TRUE(CBS_skip(&cbs, 6));
+    }
+    ASSERT_TRUE(CBS_get_u16(&cbs, &length));
+    EXPECT_EQ(0u, CBS_len(&cbs));
+  };
+  using CallbackType = decltype(cb);
+  SSL_CTX_set_msg_callback(
+      ctx.get(), [](int is_write, int cb_version, int cb_type, const void *buf,
+                    size_t len, SSL *ssl, void *arg) {
+        CallbackType *cb_ptr = reinterpret_cast<CallbackType *>(arg);
+        (*cb_ptr)(is_write, cb_version, cb_type, buf, len, ssl);
+      });
+  SSL_CTX_set_msg_callback_arg(ctx.get(), &cb);
+
+  bssl::UniquePtr<SSL> client, server;
+  if (!ConnectClientAndServer(&client, &server, ctx.get(), ctx.get(),
+                              nullptr /* no session */)) {
+    return false;
+  }
+
+  EXPECT_TRUE(read_seen);
+  EXPECT_TRUE(write_seen);
+  return true;
+}
+
+
 static bool ForEachVersion(bool (*test_func)(bool is_dtls,
                                              const SSL_METHOD *method,
                                              uint16_t version)) {
@@ -3539,7 +3614,8 @@
       !ForEachVersion(TestALPNCipherAvailable) ||
       !ForEachVersion(TestSSLClearSessionResumption) ||
       !ForEachVersion(TestAutoChain) ||
-      !ForEachVersion(TestSSLWriteRetry)) {
+      !ForEachVersion(TestSSLWriteRetry) ||
+      !ForEachVersion(TestRecordCallback)) {
     ADD_FAILURE() << "Tests failed";
   }
 }
diff --git a/ssl/tls_record.c b/ssl/tls_record.c
index e67e0b4..a5bbe93 100644
--- a/ssl/tls_record.c
+++ b/ssl/tls_record.c
@@ -398,14 +398,12 @@
   out[0] = type;
   out[1] = wire_version >> 8;
   out[2] = wire_version & 0xff;
-  out += 3;
-  max_out -= 3;
 
   /* Write the ciphertext, leaving two bytes for the length. */
   size_t ciphertext_len;
-  if (!SSL_AEAD_CTX_seal(ssl->s3->aead_write_ctx, out + 2, &ciphertext_len,
-                         max_out - 2, type, wire_version,
-                         ssl->s3->write_sequence, in, in_len) ||
+  if (!SSL_AEAD_CTX_seal(ssl->s3->aead_write_ctx, out + SSL3_RT_HEADER_LENGTH,
+                         &ciphertext_len, max_out - SSL3_RT_HEADER_LENGTH, type,
+                         wire_version, ssl->s3->write_sequence, in, in_len) ||
       !ssl_record_sequence_update(ssl->s3->write_sequence, 8)) {
     return 0;
   }
@@ -415,8 +413,8 @@
     OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
     return 0;
   }
-  out[0] = ciphertext_len >> 8;
-  out[1] = ciphertext_len & 0xff;
+  out[3] = ciphertext_len >> 8;
+  out[4] = ciphertext_len & 0xff;
 
   *out_len = SSL3_RT_HEADER_LENGTH + ciphertext_len;