Fix transcript hash for DTLS 1.3.

The DTLS 1.3 handshake transcript uses the same format as TLS 1.3, i.e.
without the message_seq, fragment_offset, and fragment_length values
from the DTLSHandshake struct.

Change-Id: Ic46ee5519d92d15b194a47149e4497c18be52876
Bug: 42290594
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/71947
Reviewed-by: Adam Langley <agl@google.com>
Commit-Queue: Nick Harper <nharper@chromium.org>
Auto-Submit: Nick Harper <nharper@chromium.org>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/handshake.cc b/ssl/handshake.cc
index 7195d66..0c5895f 100644
--- a/ssl/handshake.cc
+++ b/ssl/handshake.cc
@@ -126,6 +126,8 @@
 
 SSL_HANDSHAKE::SSL_HANDSHAKE(SSL *ssl_arg)
     : ssl(ssl_arg),
+      transcript(SSL_is_dtls(ssl_arg)),
+      inner_transcript(SSL_is_dtls(ssl_arg)),
       ech_is_inner(false),
       ech_authenticated_reject(false),
       scts_requested(false),
diff --git a/ssl/internal.h b/ssl/internal.h
index a3b1ce0..dc539a9 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -937,7 +937,7 @@
 // buffer and running hash.
 class SSLTranscript {
  public:
-  SSLTranscript();
+  explicit SSLTranscript(bool is_dtls);
   ~SSLTranscript();
 
   SSLTranscript(SSLTranscript &&other) = default;
@@ -1000,10 +1000,23 @@
                       bool from_server) const;
 
  private:
+  // HashBuffer initializes |ctx| to use |digest| and writes the contents of
+  // |buffer_| to |ctx|. If this SSLTranscript is for DTLS 1.3, the appropriate
+  // bytes in |buffer_| will be skipped when hashing the buffer.
+  bool HashBuffer(EVP_MD_CTX *ctx, const EVP_MD *digest) const;
+
+  // AddToBufferOrHash directly adds the contents of |in| to |buffer_| and/or
+  // |hash_|.
+  bool AddToBufferOrHash(Span<const uint8_t> in);
+
   // buffer_, if non-null, contains the handshake transcript.
   UniquePtr<BUF_MEM> buffer_;
   // hash, if initialized with an |EVP_MD|, maintains the handshake hash.
   ScopedEVP_MD_CTX hash_;
+  // is_dtls_ indicates whether this is a transcript for a DTLS connection.
+  bool is_dtls_ : 1;
+  // version_ contains the version for the connection (if known).
+  uint16_t version_ = 0;
 };
 
 // tls1_prf computes the PRF function for |ssl|. It fills |out|, using |secret|
diff --git a/ssl/ssl_transcript.cc b/ssl/ssl_transcript.cc
index 239363d..4d813df 100644
--- a/ssl/ssl_transcript.cc
+++ b/ssl/ssl_transcript.cc
@@ -144,7 +144,7 @@
 
 BSSL_NAMESPACE_BEGIN
 
-SSLTranscript::SSLTranscript() {}
+SSLTranscript::SSLTranscript(bool is_dtls) : is_dtls_(is_dtls) {}
 
 SSLTranscript::~SSLTranscript() {}
 
@@ -159,13 +159,73 @@
 }
 
 bool SSLTranscript::InitHash(uint16_t version, const SSL_CIPHER *cipher) {
+  version_ = version;
   const EVP_MD *md = ssl_get_handshake_digest(version, cipher);
   if (Digest() == md) {
     // No need to re-hash the buffer.
     return true;
   }
-  return EVP_DigestInit_ex(hash_.get(), md, nullptr) &&
-         EVP_DigestUpdate(hash_.get(), buffer_->data, buffer_->length);
+  if (!HashBuffer(hash_.get(), md)) {
+    return false;
+  }
+  if (is_dtls_ && version_ >= TLS1_3_VERSION) {
+    // In DTLS 1.3, prior to the call to InitHash, the message (if present) in
+    // the buffer has the DTLS 1.2 header. After the call to InitHash, the TLS
+    // 1.3 header is written by SSLTranscript::Update. If the buffer isn't freed
+    // here, it would have a mix of different header formats and using it would
+    // yield wrong results. However, there's no need for the buffer once the
+    // version and the digest for the cipher suite are known, so the buffer is
+    // freed here to avoid potential misuse of the SSLTranscript object.
+    FreeBuffer();
+  }
+  return true;
+}
+
+bool SSLTranscript::HashBuffer(EVP_MD_CTX *ctx, const EVP_MD *digest) const {
+  if (!EVP_DigestInit_ex(ctx, digest, nullptr)) {
+    return false;
+  }
+  if (!is_dtls_ || version_ < TLS1_3_VERSION) {
+    return EVP_DigestUpdate(ctx, buffer_->data, buffer_->length);
+  }
+
+  // If the version is DTLS 1.3 and we still have a buffer, then there should be
+  // at most a single DTLSHandshake message in the buffer, for the ClientHello.
+  // On the server side, the version (DTLS 1.3) and cipher suite are chosen in
+  // response to the first ClientHello, and InitHash is called before that
+  // ClientHello is added to the SSLTranscript, so the buffer is empty if this
+  // SSLTranscript is on the server.
+  if (buffer_->length == 0) {
+    return true;
+  }
+
+  // On the client side, we can receive either a ServerHello or
+  // HelloRetryRequest in response to the ClientHello. Regardless of which
+  // message we receive, the client code calls InitHash before updating the
+  // transcript with that message, so the ClientHello is the only message in the
+  // buffer. In DTLS 1.3, we need to skip the message_seq, fragment_offset, and
+  // fragment_length fields from the DTLSHandshake message in the buffer. The
+  // structure of a DTLSHandshake message is as follows (RFC 9147, section 5.2):
+  //
+  //   struct {
+  //       HandshakeType msg_type;    /* handshake type */
+  //       uint24 length;             /* bytes in message */
+  //       uint16 message_seq;        /* DTLS-required field */
+  //       uint24 fragment_offset;    /* DTLS-required field */
+  //       uint24 fragment_length;    /* DTLS-required field */
+  //       select (msg_type) {
+  //         /* omitted for brevity */
+  //       } body;
+  //   } DTLSHandshake;
+  CBS buf, header;
+  CBS_init(&buf, reinterpret_cast<uint8_t *>(buffer_->data), buffer_->length);
+  if (!CBS_get_bytes(&buf, &header, 4) ||                             //
+      !CBS_skip(&buf, 8) ||                                           //
+      !EVP_DigestUpdate(ctx, CBS_data(&header), CBS_len(&header)) ||  //
+      !EVP_DigestUpdate(ctx, CBS_data(&buf), CBS_len(&buf))) {
+    return false;
+  }
+  return true;
 }
 
 void SSLTranscript::FreeBuffer() {
@@ -193,8 +253,8 @@
   const uint8_t header[4] = {SSL3_MT_MESSAGE_HASH, 0, 0,
                              static_cast<uint8_t>(hash_len)};
   if (!EVP_DigestInit_ex(hash_.get(), Digest(), nullptr) ||
-      !Update(header) ||
-      !Update(MakeConstSpan(old_hash, hash_len))) {
+      !AddToBufferOrHash(header) ||
+      !AddToBufferOrHash(MakeConstSpan(old_hash, hash_len))) {
     return false;
   }
   return true;
@@ -209,8 +269,7 @@
   }
 
   if (buffer_) {
-    return EVP_DigestInit_ex(ctx, digest, nullptr) &&
-           EVP_DigestUpdate(ctx, buffer_->data, buffer_->length);
+    return HashBuffer(ctx, digest);
   }
 
   OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
@@ -218,6 +277,27 @@
 }
 
 bool SSLTranscript::Update(Span<const uint8_t> in) {
+  if (!is_dtls_ || version_ < TLS1_3_VERSION) {
+    return AddToBufferOrHash(in);
+  }
+  if (in.size() < DTLS1_HM_HEADER_LENGTH) {
+    return false;
+  }
+  // The message passed into Update is the whole Handshake or DTLSHandshake
+  // message, including the msg_type and length. In DTLS, the DTLSHandshake
+  // message also has message_seq, fragment_offset, and fragment_length
+  // fields. In DTLS 1.3, those fields are omitted so that the same
+  // transcript format as TLS 1.3 is used. This means we write the 1-byte
+  // msg_type, 3-byte length, then skip 2+3+3 bytes for the DTLS-specific
+  // fields that get omitted.
+  if (!AddToBufferOrHash(in.subspan(0, 4)) ||
+      !AddToBufferOrHash(in.subspan(12))) {
+    return false;
+  }
+  return true;
+}
+
+bool SSLTranscript::AddToBufferOrHash(Span<const uint8_t> in) {
   // Depending on the state of the handshake, either the handshake buffer may be
   // active, the rolling hash, or both.
   if (buffer_ &&
diff --git a/ssl/test/runner/prf.go b/ssl/test/runner/prf.go
index 3ef23df..0324c87 100644
--- a/ssl/test/runner/prf.go
+++ b/ssl/test/runner/prf.go
@@ -258,8 +258,8 @@
 // handshake message with a TLS header. In DTLS, the header is rewritten to a
 // DTLS header with |seqno| as the sequence number.
 func (h *finishedHash) WriteHandshake(msg []byte, seqno uint16) {
-	if h.isDTLS {
-		// This is somewhat hacky. DTLS hashes a slightly different format.
+	if h.isDTLS && h.version <= VersionTLS12 {
+		// This is somewhat hacky. DTLS <= 1.2 hashes a slightly different format. (DTLS 1.3 uses the same format as TLS.)
 		// First, the TLS header.
 		h.Write(msg[:4])
 		// Then the sequence number and reassembled fragment offset (always 0).