Fix record header callback on writes. These broke at some point. Add a test for them. Change-Id: Ie45869e07d9615ae33aae4613f6d9b996af39528 Reviewed-on: https://boringssl-review.googlesource.com/17330 Commit-Queue: Adam Langley <agl@google.com> Reviewed-by: Adam Langley <agl@google.com> CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc index 31899a5..84b7496 100644 --- a/ssl/ssl_test.cc +++ b/ssl/ssl_test.cc
@@ -3046,6 +3046,81 @@ return true; } +static bool TestRecordCallback(bool is_dtls, const SSL_METHOD *method, + uint16_t version) { + bssl::UniquePtr<X509> cert = GetChainTestCertificate(); + bssl::UniquePtr<X509> intermediate = GetChainTestIntermediate(); + bssl::UniquePtr<EVP_PKEY> key = GetChainTestKey(); + if (!cert || !intermediate || !key) { + return false; + } + + bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(method)); + if (!ctx || + !SSL_CTX_use_certificate(ctx.get(), cert.get()) || + !SSL_CTX_use_PrivateKey(ctx.get(), key.get()) || + !SSL_CTX_set_min_proto_version(ctx.get(), version) || + !SSL_CTX_set_max_proto_version(ctx.get(), version)) { + return false; + } + + bool read_seen = false; + bool write_seen = false; + auto cb = [&](int is_write, int cb_version, int cb_type, const void *buf, + size_t len, SSL *ssl) { + if (cb_type != SSL3_RT_HEADER) { + return; + } + + // The callback does not report a version for records. + EXPECT_EQ(0, cb_version); + + if (is_write) { + write_seen = true; + } else { + read_seen = true; + } + + // Sanity-check that the record header is plausible. + CBS cbs; + CBS_init(&cbs, reinterpret_cast<const uint8_t *>(buf), len); + uint8_t type; + uint16_t record_version, length; + ASSERT_TRUE(CBS_get_u8(&cbs, &type)); + ASSERT_TRUE(CBS_get_u16(&cbs, &record_version)); + EXPECT_TRUE(record_version == version || + record_version == (is_dtls ? DTLS1_VERSION : TLS1_VERSION)) + << "Invalid record version: " << record_version; + if (is_dtls) { + uint16_t epoch; + ASSERT_TRUE(CBS_get_u16(&cbs, &epoch)); + EXPECT_TRUE(epoch == 0 || epoch == 1) << "Invalid epoch: " << epoch; + ASSERT_TRUE(CBS_skip(&cbs, 6)); + } + ASSERT_TRUE(CBS_get_u16(&cbs, &length)); + EXPECT_EQ(0u, CBS_len(&cbs)); + }; + using CallbackType = decltype(cb); + SSL_CTX_set_msg_callback( + ctx.get(), [](int is_write, int cb_version, int cb_type, const void *buf, + size_t len, SSL *ssl, void *arg) { + CallbackType *cb_ptr = reinterpret_cast<CallbackType *>(arg); + (*cb_ptr)(is_write, cb_version, cb_type, buf, len, ssl); + }); + SSL_CTX_set_msg_callback_arg(ctx.get(), &cb); + + bssl::UniquePtr<SSL> client, server; + if (!ConnectClientAndServer(&client, &server, ctx.get(), ctx.get(), + nullptr /* no session */)) { + return false; + } + + EXPECT_TRUE(read_seen); + EXPECT_TRUE(write_seen); + return true; +} + + static bool ForEachVersion(bool (*test_func)(bool is_dtls, const SSL_METHOD *method, uint16_t version)) { @@ -3539,7 +3614,8 @@ !ForEachVersion(TestALPNCipherAvailable) || !ForEachVersion(TestSSLClearSessionResumption) || !ForEachVersion(TestAutoChain) || - !ForEachVersion(TestSSLWriteRetry)) { + !ForEachVersion(TestSSLWriteRetry) || + !ForEachVersion(TestRecordCallback)) { ADD_FAILURE() << "Tests failed"; } }
diff --git a/ssl/tls_record.c b/ssl/tls_record.c index e67e0b4..a5bbe93 100644 --- a/ssl/tls_record.c +++ b/ssl/tls_record.c
@@ -398,14 +398,12 @@ out[0] = type; out[1] = wire_version >> 8; out[2] = wire_version & 0xff; - out += 3; - max_out -= 3; /* Write the ciphertext, leaving two bytes for the length. */ size_t ciphertext_len; - if (!SSL_AEAD_CTX_seal(ssl->s3->aead_write_ctx, out + 2, &ciphertext_len, - max_out - 2, type, wire_version, - ssl->s3->write_sequence, in, in_len) || + if (!SSL_AEAD_CTX_seal(ssl->s3->aead_write_ctx, out + SSL3_RT_HEADER_LENGTH, + &ciphertext_len, max_out - SSL3_RT_HEADER_LENGTH, type, + wire_version, ssl->s3->write_sequence, in, in_len) || !ssl_record_sequence_update(ssl->s3->write_sequence, 8)) { return 0; } @@ -415,8 +413,8 @@ OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW); return 0; } - out[0] = ciphertext_len >> 8; - out[1] = ciphertext_len & 0xff; + out[3] = ciphertext_len >> 8; + out[4] = ciphertext_len & 0xff; *out_len = SSL3_RT_HEADER_LENGTH + ciphertext_len;