Improve handling of DTLS 1.3 post-handshake messages.

Currently, if we receive a post-handshake message in DTLS 1.3, we treat
it as an error. This changes the behavior to ignore the messages and
provides an entry point for processing them.

Bug: 42290594
Change-Id: I3c4900e4b2d2d0b43033cb7b67f8568cfdfb80ff
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72047
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: Nick Harper <nharper@chromium.org>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc
index 016e3a2..80da717 100644
--- a/ssl/d1_both.cc
+++ b/ssl/d1_both.cc
@@ -289,6 +289,66 @@
   return ssl->d1->incoming_messages[idx].get();
 }
 
+bool dtls1_process_handshake_fragments(SSL *ssl, uint8_t *out_alert,
+                                       Span<uint8_t> record) {
+  CBS cbs;
+  CBS_init(&cbs, record.data(), record.size());
+  while (CBS_len(&cbs) > 0) {
+    // Read a handshake fragment.
+    struct hm_header_st msg_hdr;
+    CBS body;
+    if (!dtls1_parse_fragment(&cbs, &msg_hdr, &body)) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_HANDSHAKE_RECORD);
+      *out_alert = SSL_AD_DECODE_ERROR;
+      return false;
+    }
+
+    const size_t frag_off = msg_hdr.frag_off;
+    const size_t frag_len = msg_hdr.frag_len;
+    const size_t msg_len = msg_hdr.msg_len;
+    if (frag_off > msg_len || frag_off + frag_len < frag_off ||
+        frag_off + frag_len > msg_len ||
+        msg_len > ssl_max_handshake_message_len(ssl)) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
+      *out_alert = SSL_AD_ILLEGAL_PARAMETER;
+      return false;
+    }
+
+    // The encrypted epoch in DTLS has only one handshake message.
+    if (ssl->d1->r_epoch == 1 && msg_hdr.seq != ssl->d1->handshake_read_seq) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
+      *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
+      return false;
+    }
+
+    if (msg_hdr.seq < ssl->d1->handshake_read_seq ||
+        msg_hdr.seq >
+            (unsigned)ssl->d1->handshake_read_seq + SSL_MAX_HANDSHAKE_FLIGHT) {
+      // Ignore fragments from the past, or ones too far in the future.
+      continue;
+    }
+
+    hm_fragment *frag = dtls1_get_incoming_message(ssl, out_alert, &msg_hdr);
+    if (frag == NULL) {
+      return false;
+    }
+    assert(frag->msg_len == msg_len);
+
+    if (frag->reassembly == NULL) {
+      // The message is already assembled.
+      continue;
+    }
+    assert(msg_len > 0);
+
+    // Copy the body into the fragment.
+    OPENSSL_memcpy(frag->data + DTLS1_HM_HEADER_LENGTH + frag_off,
+                   CBS_data(&body), CBS_len(&body));
+    dtls1_hm_fragment_mark(frag, frag_off, frag_off + frag_len);
+  }
+
+  return true;
+}
+
 ssl_open_record_t dtls1_open_handshake(SSL *ssl, size_t *out_consumed,
                                        uint8_t *out_alert, Span<uint8_t> in) {
   uint8_t type;
@@ -342,61 +402,9 @@
       return ssl_open_record_error;
   }
 
-  CBS cbs;
-  CBS_init(&cbs, record.data(), record.size());
-  while (CBS_len(&cbs) > 0) {
-    // Read a handshake fragment.
-    struct hm_header_st msg_hdr;
-    CBS body;
-    if (!dtls1_parse_fragment(&cbs, &msg_hdr, &body)) {
-      OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_HANDSHAKE_RECORD);
-      *out_alert = SSL_AD_DECODE_ERROR;
-      return ssl_open_record_error;
-    }
-
-    const size_t frag_off = msg_hdr.frag_off;
-    const size_t frag_len = msg_hdr.frag_len;
-    const size_t msg_len = msg_hdr.msg_len;
-    if (frag_off > msg_len || frag_off + frag_len < frag_off ||
-        frag_off + frag_len > msg_len ||
-        msg_len > ssl_max_handshake_message_len(ssl)) {
-      OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
-      *out_alert = SSL_AD_ILLEGAL_PARAMETER;
-      return ssl_open_record_error;
-    }
-
-    // The encrypted epoch in DTLS has only one handshake message.
-    if (ssl->d1->r_epoch == 1 && msg_hdr.seq != ssl->d1->handshake_read_seq) {
-      OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
-      *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
-      return ssl_open_record_error;
-    }
-
-    if (msg_hdr.seq < ssl->d1->handshake_read_seq ||
-        msg_hdr.seq >
-            (unsigned)ssl->d1->handshake_read_seq + SSL_MAX_HANDSHAKE_FLIGHT) {
-      // Ignore fragments from the past, or ones too far in the future.
-      continue;
-    }
-
-    hm_fragment *frag = dtls1_get_incoming_message(ssl, out_alert, &msg_hdr);
-    if (frag == NULL) {
-      return ssl_open_record_error;
-    }
-    assert(frag->msg_len == msg_len);
-
-    if (frag->reassembly == NULL) {
-      // The message is already assembled.
-      continue;
-    }
-    assert(msg_len > 0);
-
-    // Copy the body into the fragment.
-    OPENSSL_memcpy(frag->data + DTLS1_HM_HEADER_LENGTH + frag_off,
-                   CBS_data(&body), CBS_len(&body));
-    dtls1_hm_fragment_mark(frag, frag_off, frag_off + frag_len);
+  if (!dtls1_process_handshake_fragments(ssl, out_alert, record)) {
+    return ssl_open_record_error;
   }
-
   return ssl_open_record_success;
 }
 
diff --git a/ssl/d1_pkt.cc b/ssl/d1_pkt.cc
index 9fead5c..13da69a 100644
--- a/ssl/d1_pkt.cc
+++ b/ssl/d1_pkt.cc
@@ -140,6 +140,14 @@
   }
 
   if (type == SSL3_RT_HANDSHAKE) {
+    // Process handshake fragments for DTLS 1.3 post-handshake messages.
+    if (ssl_protocol_version(ssl) >= TLS1_3_VERSION) {
+      if (!dtls1_process_handshake_fragments(ssl, out_alert, record)) {
+        return ssl_open_record_error;
+      }
+      return ssl_open_record_discard;
+    }
+
     // Parse the first fragment header to determine if this is a pre-CCS or
     // post-CCS handshake record. DTLS resets handshake message numbers on each
     // handshake, so renegotiations and retransmissions are ambiguous.
diff --git a/ssl/internal.h b/ssl/internal.h
index 092b298..b3b1f67 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -3691,6 +3691,8 @@
 bool dtls1_new(SSL *ssl);
 void dtls1_free(SSL *ssl);
 
+bool dtls1_process_handshake_fragments(SSL *ssl, uint8_t *out_alert,
+                                       Span<uint8_t> record);
 bool dtls1_get_message(const SSL *ssl, SSLMessage *out);
 ssl_open_record_t dtls1_open_handshake(SSL *ssl, size_t *out_consumed,
                                        uint8_t *out_alert, Span<uint8_t> in);
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index c019645..8de4b17 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -2050,7 +2050,12 @@
 	if err := c.flushHandshake(); err != nil {
 		return err
 	}
-	c.useOutTrafficSecret(encryptionApplication, c.out.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.out.trafficSecret, c.isDTLS))
+	if !c.isDTLS {
+		// TODO(crbug.com/42290594): Properly implement KeyUpdate. Right
+		// now we only support sending KeyUpdate to test that we drop
+		// post-HS messages on the floor (instead of erroring).
+		c.useOutTrafficSecret(encryptionApplication, c.out.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.out.trafficSecret, c.isDTLS))
+	}
 	return nil
 }
 
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index e598ac7..72937af 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -3322,6 +3322,25 @@
 			keyUpdateRequest: keyUpdateNotRequested,
 		},
 		{
+			protocol: dtls,
+			name:     "KeyUpdate-ToClient-DTLS",
+			config: Config{
+				MaxVersion: VersionTLS13,
+			},
+			sendKeyUpdates:   1,
+			keyUpdateRequest: keyUpdateNotRequested,
+		},
+		{
+			protocol: dtls,
+			testType: serverTest,
+			name:     "KeyUpdate-ToServerDTLS",
+			config: Config{
+				MaxVersion: VersionTLS13,
+			},
+			sendKeyUpdates:   1,
+			keyUpdateRequest: keyUpdateNotRequested,
+		},
+		{
 			name: "KeyUpdate-FromClient",
 			config: Config{
 				MaxVersion: VersionTLS13,
diff --git a/ssl/tls13_both.cc b/ssl/tls13_both.cc
index 4386a6f..67c7b42 100644
--- a/ssl/tls13_both.cc
+++ b/ssl/tls13_both.cc
@@ -664,6 +664,10 @@
 }
 
 bool tls13_post_handshake(SSL *ssl, const SSLMessage &msg) {
+  if (SSL_is_dtls(ssl)) {
+    // TODO(crbug.com/42290594): Process post-handshake messages in DTLS 1.3.
+    return true;
+  }
   if (msg.type == SSL3_MT_KEY_UPDATE) {
     ssl->s3->key_update_count++;
     if (ssl->quic_method != nullptr ||