Implement DTLS 1.3 record header.

Bug: 715
Change-Id: I69c82eed41946da404fb13129aa790d61ec0fb78
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/69689
Auto-Submit: Nick Harper <nharper@chromium.org>
Reviewed-by: Bob Beck <bbe@google.com>
Commit-Queue: Bob Beck <bbe@google.com>
diff --git a/ssl/dtls_method.cc b/ssl/dtls_method.cc
index ac42e0b..6108f5c 100644
--- a/ssl/dtls_method.cc
+++ b/ssl/dtls_method.cc
@@ -95,13 +95,10 @@
     // reordering around KeyUpdate (i.e. accept records from both epochs), we'll
     // need a separate bitmap for each epoch.
     ssl->d1->r_epoch = level;
-    // |ssl->d1->bitmap| incorporates epochs into sequence numbers, so it
-    // doesn't need to be reset. Preserving it allows |SSL_get_read_sequence| to
-    // query the maximum sequence number received.
   } else {
     ssl->d1->r_epoch++;
-    ssl->d1->bitmap = DTLS1_BITMAP();
   }
+  ssl->d1->bitmap = DTLS1_BITMAP();
   ssl->s3->read_sequence = 0;
 
   ssl->s3->aead_read_ctx = std::move(aead_ctx);
diff --git a/ssl/dtls_record.cc b/ssl/dtls_record.cc
index f7aeac2..1d32630 100644
--- a/ssl/dtls_record.cc
+++ b/ssl/dtls_record.cc
@@ -159,6 +159,132 @@
   }
 }
 
+// reconstruct_epoch finds the largest epoch that ends with the epoch bits from
+// |wire_epoch| that is less than or equal to |current_epoch|, to match the
+// epoch reconstruction algorithm described in RFC 9147 section 4.2.2.
+static uint16_t reconstruct_epoch(uint8_t wire_epoch, uint16_t current_epoch) {
+  uint16_t current_epoch_high = current_epoch & 0xfffc;
+  uint16_t epoch = (wire_epoch & 0x3) | current_epoch_high;
+  if (epoch > current_epoch && current_epoch_high > 0) {
+    epoch -= 0x4;
+  }
+  return epoch;
+}
+
+// reconstruct_seqnum returns the smallest sequence number that hasn't been seen
+// in |bitmap| and is still within |bitmap|'s window to handle as a reordered
+// record.
+//
+// Section 4.2.2 of RFC 9147 describes an algorithm for reconstructing sequence
+// numbers, which is implemented here. This algorithm finds the sequence number
+// that is numerically closest to one plus the largest sequence number seen in
+// this epoch.
+static uint64_t reconstruct_seqnum(uint16_t wire_seq, uint64_t seq_mask,
+                                   DTLS1_BITMAP *bitmap) {
+  uint64_t max_seqnum_plus_one = bitmap->max_seq_num + 1;
+  uint64_t diff = (wire_seq - max_seqnum_plus_one) & seq_mask;
+  uint64_t step = seq_mask + 1;
+  uint64_t seqnum = max_seqnum_plus_one + diff;
+  // diff is always non-negative, so seqnum is >= max_seqnum_plus_one. If the
+  // diff is larger than half the step size, then the numerically closest
+  // sequence number is less than max_seqnum_plus_one instead of greater.
+  if (diff > step / 2) {
+    seqnum -= step;
+  }
+  return seqnum;
+}
+
+static bool parse_dtls13_record_header(SSL *ssl, CBS *in, size_t packet_size,
+                                       uint8_t type, CBS *out_body,
+                                       uint64_t *out_sequence,
+                                       uint16_t *out_epoch,
+                                       size_t *out_header_len) {
+  // TODO(crbug.com/boringssl/715): Decrypt the sequence number before
+  // decoding it.
+  if ((type & 0x10) == 0x10) {
+    // Connection ID bit set, which we didn't negotiate.
+    return false;
+  }
+  // TODO(crbug.com/boringssl/715): Add a runner test that performs many
+  // key updates to verify epoch reconstruction works for epochs larger than
+  // 3.
+  *out_epoch = reconstruct_epoch(type, ssl->d1->r_epoch);
+  if ((type & 0x08) == 0x08) {
+    // 16-bit sequence number.
+    uint16_t seq;
+    if (!CBS_get_u16(in, &seq)) {
+      // The record header was incomplete or malformed.
+      return false;
+    }
+    *out_sequence = reconstruct_seqnum(seq, 0xffff, &ssl->d1->bitmap);
+  } else {
+    // 8-bit sequence number.
+    uint8_t seq;
+    if (!CBS_get_u8(in, &seq)) {
+      // The record header was incomplete or malformed.
+      return false;
+    }
+    *out_sequence = reconstruct_seqnum(seq, 0xff, &ssl->d1->bitmap);
+  }
+  *out_header_len = packet_size - CBS_len(in);
+  if ((type & 0x04) == 0x04) {
+    *out_header_len += 2;
+    // 16-bit length present
+    if (!CBS_get_u16_length_prefixed(in, out_body)) {
+      // The record header was incomplete or malformed.
+      return false;
+    }
+  } else {
+    // No length present - the remaining contents are the whole packet.
+    // CBS_get_bytes is used here to advance |in| to the end so that future
+    // code that computes the number of consumed bytes functions correctly.
+    if (!CBS_get_bytes(in, out_body, CBS_len(in))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+static bool parse_dtls_plaintext_record_header(
+    SSL *ssl, CBS *in, size_t packet_size, uint8_t type, CBS *out_body,
+    uint64_t *out_sequence, uint16_t *out_epoch, size_t *out_header_len,
+    uint16_t *out_version) {
+  SSLAEADContext *aead = ssl->s3->aead_read_ctx.get();
+  uint8_t sequence_bytes[8];
+  if (!CBS_get_u16(in, out_version) ||
+      !CBS_copy_bytes(in, sequence_bytes, sizeof(sequence_bytes))) {
+    return false;
+  }
+  *out_header_len = packet_size - CBS_len(in) + 2;
+  if (!CBS_get_u16_length_prefixed(in, out_body) ||
+      CBS_len(out_body) > SSL3_RT_MAX_ENCRYPTED_LENGTH) {
+    return false;
+  }
+
+  bool version_ok;
+  if (aead->is_null_cipher()) {
+    // Only check the first byte. Enforcing beyond that can prevent decoding
+    // version negotiation failure alerts.
+    version_ok = (*out_version >> 8) == DTLS1_VERSION_MAJOR;
+  } else {
+    version_ok = *out_version == aead->RecordVersion();
+  }
+
+  if (!version_ok) {
+    return false;
+  }
+
+  *out_sequence = CRYPTO_load_u64_be(sequence_bytes);
+  *out_epoch = static_cast<uint16_t>(*out_sequence >> 48);
+
+  // Discard the packet if we're expecting an encrypted DTLS 1.3 record but we
+  // get the old record header format.
+  if (!aead->is_null_cipher() && aead->ProtocolVersion() >= TLS1_3_VERSION) {
+    return false;
+  }
+  return true;
+}
+
 enum ssl_open_record_t dtls_open_record(SSL *ssl, uint8_t *out_type,
                                         Span<uint8_t> *out,
                                         size_t *out_consumed,
@@ -174,42 +300,42 @@
 
   CBS cbs = CBS(in);
 
-  // Decode the record.
   uint8_t type;
-  uint16_t version;
-  uint8_t sequence_bytes[8];
+  size_t record_header_len;
+  if (!CBS_get_u8(&cbs, &type)) {
+    // The record header was incomplete or malformed. Drop the entire packet.
+    *out_consumed = in.size();
+    return ssl_open_record_discard;
+  }
+  SSLAEADContext *aead = ssl->s3->aead_read_ctx.get();
+  uint64_t sequence;
+  uint16_t epoch;
+  uint16_t version = 0;
   CBS body;
-  if (!CBS_get_u8(&cbs, &type) ||
-      !CBS_get_u16(&cbs, &version) ||
-      !CBS_copy_bytes(&cbs, sequence_bytes, sizeof(sequence_bytes)) ||
-      !CBS_get_u16_length_prefixed(&cbs, &body) ||
-      CBS_len(&body) > SSL3_RT_MAX_ENCRYPTED_LENGTH) {
-    // The record header was incomplete or malformed. Drop the entire packet.
-    *out_consumed = in.size();
-    return ssl_open_record_discard;
-  }
-
-  bool version_ok;
-  if (ssl->s3->aead_read_ctx->is_null_cipher()) {
-    // Only check the first byte. Enforcing beyond that can prevent decoding
-    // version negotiation failure alerts.
-    version_ok = (version >> 8) == DTLS1_VERSION_MAJOR;
+  bool valid_record_header;
+  // Decode the record header. If the 3 high bits of the type are 001, then the
+  // record header is the DTLS 1.3 format. The DTLS 1.3 format should only be
+  // used for encrypted records with DTLS 1.3. Plaintext records or DTLS 1.2
+  // records use the old record header format.
+  if ((type & 0xe0) == 0x20 && !aead->is_null_cipher() &&
+      aead->ProtocolVersion() >= TLS1_3_VERSION) {
+    valid_record_header =
+        parse_dtls13_record_header(ssl, &cbs, in.size(), type, &body, &sequence,
+                                   &epoch, &record_header_len);
   } else {
-    version_ok = version == ssl->s3->aead_read_ctx->RecordVersion();
+    valid_record_header = parse_dtls_plaintext_record_header(
+        ssl, &cbs, in.size(), type, &body, &sequence, &epoch,
+        &record_header_len, &version);
   }
-
-  if (!version_ok) {
+  if (!valid_record_header) {
     // The record header was incomplete or malformed. Drop the entire packet.
     *out_consumed = in.size();
     return ssl_open_record_discard;
   }
 
-  Span<const uint8_t> header =
-      in.subspan(0, dtls_record_header_write_len(ssl, ssl->d1->r_epoch));
+  Span<const uint8_t> header = in.subspan(0, record_header_len);
   ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HEADER, header);
 
-  uint64_t sequence = CRYPTO_load_u64_be(sequence_bytes);
-  uint16_t epoch = static_cast<uint16_t>(sequence >> 48);
   if (epoch != ssl->d1->r_epoch ||
       dtls1_bitmap_should_discard(&ssl->d1->bitmap, sequence)) {
     // Drop this record. It's from the wrong epoch or is a replay. Note that if
@@ -221,7 +347,7 @@
   }
 
   // discard the body in-place.
-  if (!ssl->s3->aead_read_ctx->Open(
+  if (!aead->Open(
           out, type, version, sequence, header,
           MakeSpan(const_cast<uint8_t *>(CBS_data(&body)), CBS_len(&body)))) {
     // Bad packets are silently dropped in DTLS. See section 4.2.1 of RFC 6347.
@@ -236,13 +362,29 @@
   }
   *out_consumed = in.size() - CBS_len(&cbs);
 
+  // DTLS 1.3 hides the record type inside the encrypted data.
+  bool has_padding =
+      !aead->is_null_cipher() && aead->ProtocolVersion() >= TLS1_3_VERSION;
   // Check the plaintext length.
-  if (out->size() > SSL3_RT_MAX_PLAIN_LENGTH) {
+  size_t plaintext_limit = SSL3_RT_MAX_PLAIN_LENGTH + (has_padding ? 1 : 0);
+  if (out->size() > plaintext_limit) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_DATA_LENGTH_TOO_LONG);
     *out_alert = SSL_AD_RECORD_OVERFLOW;
     return ssl_open_record_error;
   }
 
+  if (has_padding) {
+    do {
+      if (out->empty()) {
+        OPENSSL_PUT_ERROR(SSL, SSL_R_DECRYPTION_FAILED_OR_BAD_RECORD_MAC);
+        *out_alert = SSL_AD_DECRYPT_ERROR;
+        return ssl_open_record_error;
+      }
+      type = out->back();
+      *out = out->subspan(0, out->size() - 1);
+    } while (type == 0);
+  }
+
   dtls1_bitmap_record(&ssl->d1->bitmap, sequence);
 
   // TODO(davidben): Limit the number of empty records as in TLS? This is only
@@ -272,15 +414,34 @@
   return ssl->s3->aead_write_ctx.get();
 }
 
+static bool use_dtls13_record_header(const SSL *ssl, uint16_t epoch) {
+  // Plaintext records in DTLS 1.3 also use the DTLSPlaintext structure for
+  // backwards compatibility.
+  return ssl->s3->have_version && ssl_protocol_version(ssl) > TLS1_2_VERSION &&
+         epoch > 0;
+}
+
 size_t dtls_record_header_write_len(const SSL *ssl, uint16_t epoch) {
-  // 13 is the value of the former DTLS1_RT_HEADER_LENGTH constant.
-  return 13;
+  if (!use_dtls13_record_header(ssl, epoch)) {
+    return DTLS_PLAINTEXT_RECORD_HEADER_LENGTH;
+  }
+  // The DTLS 1.3 has a variable length record header. We never send Connection
+  // ID, we always send 16-bit sequence numbers, and we send a length. (Length
+  // can be omitted, but only for the last record of a packet. Since we send
+  // multiple records in one packet, it's easier to implement always sending the
+  // length.)
+  return DTLS1_3_RECORD_HEADER_WRITE_LENGTH;
 }
 
 size_t dtls_max_seal_overhead(const SSL *ssl,
                               uint16_t epoch) {
-  return dtls_record_header_write_len(ssl, epoch) +
-         get_write_aead(ssl, epoch)->MaxOverhead();
+  size_t ret = dtls_record_header_write_len(ssl, epoch) +
+               get_write_aead(ssl, epoch)->MaxOverhead();
+  if (use_dtls13_record_header(ssl, epoch)) {
+    // Add 1 byte for the encrypted record type.
+    ret++;
+  }
+  return ret;
 }
 
 size_t dtls_seal_prefix_len(const SSL *ssl, uint16_t epoch) {
@@ -308,16 +469,6 @@
   // of seq is probably wrong for a retransmission.
 
   const size_t record_header_len = dtls_record_header_write_len(ssl, epoch);
-  if (max_out < record_header_len) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_BUFFER_TOO_SMALL);
-    return false;
-  }
-
-  out[0] = type;
-
-  uint16_t record_version = ssl->s3->aead_write_ctx->RecordVersion();
-  out[1] = record_version >> 8;
-  out[2] = record_version & 0xff;
 
   // Ensure the sequence number update does not overflow.
   const uint64_t kMaxSequenceNumber = (uint64_t{1} << 48) - 1;
@@ -326,25 +477,68 @@
     return false;
   }
 
+  uint16_t record_version = ssl->s3->aead_write_ctx->RecordVersion();
   uint64_t seq_with_epoch = (uint64_t{epoch} << 48) | *seq;
-  CRYPTO_store_u64_be(&out[3], seq_with_epoch);
+
+  bool dtls13_header = use_dtls13_record_header(ssl, epoch);
+  uint8_t *extra_in = NULL;
+  size_t extra_in_len = 0;
+  if (dtls13_header) {
+    extra_in = &type;
+    extra_in_len = 1;
+  }
 
   size_t ciphertext_len;
-  if (!aead->CiphertextLen(&ciphertext_len, in_len, 0)) {
+  if (!aead->CiphertextLen(&ciphertext_len, in_len, extra_in_len)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_RECORD_TOO_LARGE);
     return false;
   }
-  out[11] = ciphertext_len >> 8;
-  out[12] = ciphertext_len & 0xff;
-  Span<const uint8_t> header = MakeConstSpan(out, record_header_len);
-
-  size_t len_copy;
-  if (!aead->Seal(out + record_header_len, &len_copy,
-                  max_out - record_header_len, type, record_version,
-                  seq_with_epoch, header, in, in_len)) {
+  if (max_out < record_header_len + ciphertext_len) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_BUFFER_TOO_SMALL);
     return false;
   }
-  assert(ciphertext_len == len_copy);
+
+  if (dtls13_header) {
+    // The first byte of the DTLS 1.3 record header has the following format:
+    // 0 1 2 3 4 5 6 7
+    // +-+-+-+-+-+-+-+-+
+    // |0|0|1|C|S|L|E E|
+    // +-+-+-+-+-+-+-+-+
+    //
+    // We set C=0 (no Connection ID), S=1 (16-bit sequence number), L=1 (length
+    // is present), which is a mask of 0x2c. The E E bits are the low-order two
+    // bits of the epoch.
+    //
+    // +-+-+-+-+-+-+-+-+
+    // |0|0|1|0|1|1|E E|
+    // +-+-+-+-+-+-+-+-+
+    out[0] = 0x2c | (epoch & 0x3);
+    out[1] = *seq >> 8;
+    out[2] = *seq & 0xff;
+    out[3] = ciphertext_len >> 8;
+    out[4] = ciphertext_len & 0xff;
+    // DTLS 1.3 uses the sequence number without the epoch for the AEAD.
+    seq_with_epoch = *seq;
+  } else {
+    out[0] = type;
+    out[1] = record_version >> 8;
+    out[2] = record_version & 0xff;
+    CRYPTO_store_u64_be(&out[3], seq_with_epoch);
+    out[11] = ciphertext_len >> 8;
+    out[12] = ciphertext_len & 0xff;
+  }
+  Span<const uint8_t> header = MakeConstSpan(out, record_header_len);
+
+
+  if (!aead->SealScatter(out + record_header_len, out + prefix,
+                         out + prefix + in_len, type, record_version,
+                         seq_with_epoch, header, in, in_len, extra_in,
+                         extra_in_len)) {
+    return false;
+  }
+
+  // TODO(crbug.com/boringssl/715): Perform record number encryption (RFC 9147
+  // section 4.2.3).
 
   (*seq)++;
   *out_len = record_header_len + ciphertext_len;
diff --git a/ssl/internal.h b/ssl/internal.h
index 34f377d..8cf0339 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -2938,6 +2938,22 @@
 // lengths of messages
 #define DTLS1_RT_MAX_HEADER_LENGTH 13
 
+// DTLS_PLAINTEXT_RECORD_HEADER_LENGTH is the length of the DTLS record header
+// for plaintext records (in DTLS 1.3) or DTLS versions <= 1.2.
+#define DTLS_PLAINTEXT_RECORD_HEADER_LENGTH 13
+
+// DTLS1_3_RECORD_HEADER_LENGTH is the length of the DTLS 1.3 record header
+// sent by BoringSSL for encrypted records. Note that received encrypted DTLS
+// 1.3 records might have a different length header.
+#define DTLS1_3_RECORD_HEADER_WRITE_LENGTH 5
+
+static_assert(DTLS1_RT_MAX_HEADER_LENGTH >= DTLS_PLAINTEXT_RECORD_HEADER_LENGTH,
+              "DTLS1_RT_MAX_HEADER_LENGTH must not be smaller than defined "
+              "record header lengths");
+static_assert(DTLS1_RT_MAX_HEADER_LENGTH >= DTLS1_3_RECORD_HEADER_WRITE_LENGTH,
+              "DTLS1_RT_MAX_HEADER_LENGTH must not be smaller than defined "
+              "record header lengths");
+
 #define DTLS1_HM_HEADER_LENGTH 12
 
 #define DTLS1_CCS_HEADER_LENGTH 1
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 731116a..ba1a22b 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -4397,6 +4397,13 @@
 }
 
 TEST_P(SSLVersionTest, RecordCallback) {
+  if (version() == DTLS1_3_EXPERIMENTAL_VERSION) {
+    // The DTLS 1.3 record header is vastly different than the TLS or DTLS < 1.3
+    // header format. Instead of checking that the record header is formatted as
+    // expected here, the runner implementation in dtls.go is strict about what
+    // it accepts.
+    return;
+  }
   for (bool test_server : {true, false}) {
     SCOPED_TRACE(test_server);
     ASSERT_NO_FATAL_FAILURE(ResetContexts());
@@ -4430,9 +4437,6 @@
         uint16_t epoch;
         ASSERT_TRUE(CBS_get_u16(&cbs, &epoch));
         uint16_t max_epoch = 1;
-        if (version() == DTLS1_3_EXPERIMENTAL_VERSION) {
-          max_epoch = 3;
-        }
         EXPECT_LE(epoch, max_epoch) << "Invalid epoch: " << epoch;
         ASSERT_TRUE(CBS_skip(&cbs, 6));
       }
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index 89948b3..5907a35 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -617,6 +617,16 @@
 
 	CertCompressionAlgs map[uint16]CertCompressionAlg
 
+	// DTLSUseShortSeqNums specifies whether the DTLS 1.3 record header
+	// should use short (8-bit) or long (16-bit) sequence numbers. The
+	// default is to use long sequence numbers.
+	DTLSUseShortSeqNums bool
+
+	// DTLSRecordHeaderOmitLength specified whether the DTLS 1.3 record
+	// header includes a length field. The default is to include the length
+	// field.
+	DTLSRecordHeaderOmitLength bool
+
 	// Bugs specifies optional misbehaviour to be used for testing other
 	// implementations.
 	Bugs ProtocolBugs
@@ -1968,6 +1978,14 @@
 	// session ID in the ServerHello.
 	DTLS13EchoSessionID bool
 
+	// DTLSUsePlaintextRecord header, if true, has DTLS connections never
+	// use the DTLS 1.3 record header.
+	DTLSUsePlaintextRecordHeader bool
+
+	// DTLS13RecordHeaderSetCIDBit, if true, sets the Connection ID bit in
+	// the DTLS 1.3 record header.
+	DTLS13RecordHeaderSetCIDBit bool
+
 	// EncryptSessionTicketKey, if non-nil, is the ticket key to use when
 	// encrypting tickets.
 	EncryptSessionTicketKey *[32]byte
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index 434ba1e..ce425a0 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -130,6 +130,8 @@
 	c.out.isDTLS = c.isDTLS
 	c.in.config = c.config
 	c.out.config = c.config
+	c.in.conn = c
+	c.out.conn = c
 
 	c.out.updateOutSeq()
 }
@@ -193,6 +195,7 @@
 	trafficSecret []byte
 
 	config *Config
+	conn   *Conn
 }
 
 func (hc *halfConn) setErrorLocked(err error) error {
@@ -362,10 +365,25 @@
 // that can depend on the bytes read.
 func (hc *halfConn) writeRecordHeaderLen() int {
 	if hc.isDTLS {
-		// TODO(nharper): Change this to be the actual record header
-		// length that will be written. This will depend on version and
-		// write cipher, as well as configuration or protocol bugs to
-		// exercise all options of the DTLS 1.3 record header.
+		usePlaintextHeader := hc.config.Bugs.DTLSUsePlaintextRecordHeader && hc.conn.handshakeComplete
+		if hc.version >= VersionTLS13 && hc.cipher != nil && !usePlaintextHeader {
+			// The DTLS 1.3 record header consists of a
+			// demultiplexing/type byte, some number of connection
+			// ID bytes, 1 or 2 sequence number bytes, and 0 or 2
+			// length bytes. Configuration options or protocol bugs
+			// will change these values to test all options of the
+			// DTLS 1.3 record header.
+			cidSize := 0
+			seqSize := 2
+			if hc.config.DTLSUseShortSeqNums {
+				seqSize = 1
+			}
+			lenSize := 2
+			if hc.config.DTLSRecordHeaderOmitLength {
+				lenSize = 0
+			}
+			return 1 + cidSize + seqSize + lenSize
+		}
 		return dtlsMaxRecordHeaderLen
 	}
 	return tlsRecordHeaderLen
@@ -502,7 +520,7 @@
 			panic("unknown cipher type")
 		}
 
-		if hc.version >= VersionTLS13 && !hc.isDTLS {
+		if hc.version >= VersionTLS13 {
 			i := len(payload)
 			for i > 0 && payload[i-1] == 0 {
 				i--
@@ -597,6 +615,11 @@
 			if c.explicitNonce {
 				nonce = b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
 			}
+			usePlaintextHeader := hc.config.Bugs.DTLSUsePlaintextRecordHeader && hc.conn.handshakeComplete
+			if hc.isDTLS && hc.version >= VersionTLS13 && !usePlaintextHeader {
+				nonce = make([]byte, 8)
+				copy(nonce[2:], hc.outSeq[2:])
+			}
 			payload := b.data[recordHeaderLen+explicitIVLen:]
 			payload = payload[:payloadLen]
 
@@ -631,9 +654,11 @@
 	}
 
 	// update length to include MAC and any block padding needed.
-	n := len(b.data) - recordHeaderLen
-	b.data[recordHeaderLen-2] = byte(n >> 8)
-	b.data[recordHeaderLen-1] = byte(n)
+	if !hc.config.DTLSRecordHeaderOmitLength {
+		n := len(b.data) - recordHeaderLen
+		b.data[recordHeaderLen-2] = byte(n >> 8)
+		b.data[recordHeaderLen-1] = byte(n)
+	}
 	hc.incSeq(true)
 
 	return true, 0
@@ -1193,21 +1218,17 @@
 	if c.out.version < VersionTLS13 || c.out.cipher == nil {
 		return recordLen
 	}
-	// TODO(nharper): DTLS 1.3 should be adding padding, but the currently
-	// implemented DTLS 1.25 doesn't include padding.
-	if !c.isDTLS {
-		paddingLen := c.config.Bugs.RecordPadding
-		if c.config.Bugs.OmitRecordContents {
-			recordLen = paddingLen
-			b.resize(recordHeaderLen + paddingLen)
-		} else {
-			recordLen += 1 + paddingLen
-			b.resize(len(b.data) + 1 + paddingLen)
-			b.data[len(b.data)-paddingLen-1] = byte(typ)
-		}
-		for i := 0; i < paddingLen; i++ {
-			b.data[len(b.data)-paddingLen+i] = 0
-		}
+	paddingLen := c.config.Bugs.RecordPadding
+	if c.config.Bugs.OmitRecordContents {
+		recordLen = paddingLen
+		b.resize(recordHeaderLen + paddingLen)
+	} else {
+		recordLen += 1 + paddingLen
+		b.resize(len(b.data) + 1 + paddingLen)
+		b.data[len(b.data)-paddingLen-1] = byte(typ)
+	}
+	for i := 0; i < paddingLen; i++ {
+		b.data[len(b.data)-paddingLen+i] = 0
 	}
 	if c, ok := c.out.cipher.(*tlsAead); ok {
 		recordLen += c.Overhead()
diff --git a/ssl/test/runner/dtls.go b/ssl/test/runner/dtls.go
index 1893afe..f4921d4 100644
--- a/ssl/test/runner/dtls.go
+++ b/ssl/test/runner/dtls.go
@@ -16,6 +16,7 @@
 
 import (
 	"bytes"
+	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
@@ -23,14 +24,73 @@
 	"net"
 )
 
+func (c *Conn) readDTLS13RecordHeader(b *block) (headerLen int, recordLen int, recTyp recordType, seq []byte, err error) {
+	// The DTLS 1.3 record header starts with the type byte containing
+	// 0b001CSLEE, where C, S, L, and EE are bits with the following
+	// meanings:
+	//
+	// C=1: Connection ID is present (C=0: CID is absent)
+	// S=1: the sequence number is 16 bits (S=0: it is 8 bits)
+	// L=1: 16-bit length field is present (L=0: record goes to end of packet)
+	// EE: low two bits of the epoch.
+	//
+	// A real DTLS implementation would parse these bits and take
+	// appropriate action based on them. However, this is a test
+	// implementation, and the code we are testing only ever sends C=0, S=1,
+	// L=1. This code expects those bits to be set and will error if
+	// anything else is set. This means we expect the type byte to look like
+	// 0b001011EE, or 0x2c-0x2f.
+	recordHeaderLen := 5
+	if len(b.data) < recordHeaderLen {
+		return 0, 0, 0, nil, errors.New("dtls: failed to read record header")
+	}
+	typ := b.data[0]
+	if typ&0xfc != 0x2c {
+		return 0, 0, 0, nil, errors.New("dtls: DTLS 1.3 record header has bad type byte")
+	}
+	// For test purposes, require the epoch received be the same as the
+	// epoch we expect to receive.
+	epoch := typ & 0x03
+	if epoch != c.in.seq[1]&0x03 {
+		c.sendAlert(alertIllegalParameter)
+		return 0, 0, 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch"))
+	}
+	wireSeq := binary.BigEndian.Uint16(b.data[1:3])
+	// Reconstruct the sequence number from the low 16 bits on the wire.
+	// A real implementation would compute the full sequence number that is
+	// closest to the highest successfully decrypted record in the
+	// identified epoch. Since this test implementation errors on decryption
+	// failures instead of simply discarding packets, it reconstructs a
+	// sequence number that is not less than c.in.seq. (This matches the
+	// behavior of the check of the sequence number in the old record
+	// header format.)
+	seqInt := binary.BigEndian.Uint64(c.in.seq[:])
+	// c.in.seq has the epoch in the upper two bytes - clear those.
+	seqInt = seqInt &^ (0xffff << 48)
+	newSeq := seqInt&^0xffff | uint64(wireSeq)
+	if newSeq < seqInt {
+		newSeq += 0x10000
+	}
+
+	seq = make([]byte, 8)
+	binary.BigEndian.PutUint64(seq, newSeq)
+	copy(c.in.seq[2:], seq[2:])
+
+	recordLen = int(b.data[3])<<8 | int(b.data[4])
+	return recordHeaderLen, recordLen, 0, seq, nil
+}
+
 // readDTLSRecordHeader reads the record header from the block. Based on the
 // header it reads, it checks the header's validity and sets appropriate state
 // as needed. This function returns the record header, the record type indicated
 // in the header (if it contains the type), and the sequence number to use for
 // record decryption.
 func (c *Conn) readDTLSRecordHeader(b *block) (headerLen int, recordLen int, typ recordType, seq []byte, err error) {
-	recordHeaderLen := 13
+	if c.in.cipher != nil && c.in.version >= VersionTLS13 {
+		return c.readDTLS13RecordHeader(b)
+	}
 
+	recordHeaderLen := 13
 	// Read out one record.
 	//
 	// A real DTLS implementation should be tolerant of errors,
@@ -114,7 +174,7 @@
 	b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
 
 	// Process message.
-	ok, off, _, alertValue := c.in.decrypt(seq, recordHeaderLen, b)
+	ok, off, encTyp, alertValue := c.in.decrypt(seq, recordHeaderLen, b)
 	if !ok {
 		// A real DTLS implementation would silently ignore bad records,
 		// but we want to notice errors from the implementation under
@@ -123,8 +183,12 @@
 	}
 	b.off = off
 
-	// TODO(nharper): Once DTLS 1.3 is defined, handle the extra
-	// parameter from decrypt.
+	if typ == 0 {
+		// readDTLSRecordHeader sets typ=0 when decoding the DTLS 1.3
+		// record header. When the new record header format is used, the
+		// type is returned by decrypt() in encTyp.
+		typ = encTyp
+	}
 
 	// Require that ChangeCipherSpec always share a packet with either the
 	// previous or next handshake message.
@@ -344,6 +408,42 @@
 	return nil
 }
 
+// writeDTLS13RecordHeader writes to b the record header for a record of length
+// recordLen.
+func (c *Conn) writeDTLS13RecordHeader(b *block, recordLen int) {
+	// Set the top 3 bits on the type byte to indicate the DTLS 1.3 record
+	// header format.
+	typ := byte(0x20)
+
+	if c.config.Bugs.DTLS13RecordHeaderSetCIDBit && c.handshakeComplete {
+		// Set the Connection ID bit
+		typ |= 0x10
+	}
+
+	// Set the sequence number length bit
+	if !c.config.DTLSUseShortSeqNums {
+		typ |= 0x08
+	}
+	// Set the length presence bit
+	if !c.config.DTLSRecordHeaderOmitLength {
+		typ |= 0x04
+	}
+	// Set the epoch bits
+	typ |= c.out.outSeq[1] & 0x3
+	b.data[0] = typ
+	lenOffset := 3
+	if c.config.DTLSUseShortSeqNums {
+		b.data[1] = c.out.outSeq[7]
+		lenOffset = 2
+	} else {
+		copy(b.data[1:3], c.out.outSeq[6:8])
+	}
+	if !c.config.DTLSRecordHeaderOmitLength {
+		b.data[lenOffset] = byte(recordLen >> 8)
+		b.data[lenOffset+1] = byte(recordLen)
+	}
+}
+
 // dtlsPackRecord packs a single record to the pending packet, flushing it
 // if necessary. The caller should call dtlsFlushPacket to flush the current
 // pending packet afterwards.
@@ -375,9 +475,6 @@
 		panic("Unknown cipher")
 	}
 	b.resize(recordHeaderLen + explicitIVLen + len(data))
-	// TODO(nharper): DTLS 1.3 will likely need to set this to
-	// recordTypeApplicationData if c.out.cipher != nil.
-	b.data[0] = byte(typ)
 	vers := c.wireVersion
 	if vers == 0 {
 		// Some TLS servers fail if the record version is greater than
@@ -391,10 +488,6 @@
 	if c.vers >= VersionTLS13 || c.out.version >= VersionTLS13 {
 		vers = VersionDTLS12
 	}
-	b.data[1] = byte(vers >> 8)
-	b.data[2] = byte(vers)
-	// DTLS records include an explicit sequence number.
-	copy(b.data[3:11], c.out.outSeq[0:])
 	if explicitIVLen > 0 {
 		explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
 		if explicitIVIsSeq {
@@ -407,8 +500,17 @@
 	}
 	copy(b.data[recordHeaderLen+explicitIVLen:], data)
 	recordLen := c.addTLS13Padding(b, recordHeaderLen, len(data), typ)
-	b.data[11] = byte(recordLen >> 8)
-	b.data[12] = byte(recordLen)
+	if c.out.version < VersionTLS13 || c.out.cipher == nil || (c.config.Bugs.DTLSUsePlaintextRecordHeader && c.handshakeComplete) {
+		b.data[0] = byte(typ)
+		b.data[1] = byte(vers >> 8)
+		b.data[2] = byte(vers)
+		// DTLS records include an explicit sequence number.
+		copy(b.data[3:11], c.out.outSeq[0:])
+		b.data[11] = byte(recordLen >> 8)
+		b.data[12] = byte(recordLen)
+	} else {
+		c.writeDTLS13RecordHeader(b, recordLen)
+	}
 	c.out.encrypt(b, explicitIVLen, typ)
 
 	// Flush the current pending packet if necessary.
@@ -423,6 +525,15 @@
 	// Add the record to the pending packet.
 	c.pendingPacket = append(c.pendingPacket, b.data...)
 	c.out.freeBlock(b)
+	if c.config.DTLSRecordHeaderOmitLength {
+		if c.config.Bugs.SplitAndPackAppData {
+			panic("incompatible config")
+		}
+		err = c.dtlsFlushPacket()
+		if err != nil {
+			return
+		}
+	}
 	n = len(data)
 	return
 }
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index dae8a2a..16cebb2 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -3729,6 +3729,67 @@
 		shouldFail:    true,
 		expectedError: ":DECODE_ERROR:",
 	})
+
+	// DTLS 1.3 should work with record headers that don't set the
+	// length bit or that use the short sequence number format.
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		protocol: dtls,
+		name:     "DTLS13RecordHeader-NoLength-Client",
+		config: Config{
+			MinVersion:                 VersionTLS13,
+			DTLSRecordHeaderOmitLength: true,
+		},
+	})
+	testCases = append(testCases, testCase{
+		testType: serverTest,
+		protocol: dtls,
+		name:     "DTLS13RecordHeader-NoLength-Server",
+		config: Config{
+			MinVersion:                 VersionTLS13,
+			DTLSRecordHeaderOmitLength: true,
+		},
+	})
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		protocol: dtls,
+		name:     "DTLS13RecordHeader-ShortSeqNums-Client",
+		config: Config{
+			MinVersion:          VersionTLS13,
+			DTLSUseShortSeqNums: true,
+		},
+	})
+	testCases = append(testCases, testCase{
+		testType: serverTest,
+		protocol: dtls,
+		name:     "DTLS13RecordHeader-ShortSeqNums-Server",
+		config: Config{
+			MinVersion:          VersionTLS13,
+			DTLSUseShortSeqNums: true,
+		},
+	})
+	testCases = append(testCases, testCase{
+		protocol: dtls,
+		name:     "DTLS13RecordHeader-OldHeader",
+		config: Config{
+			MinVersion: VersionTLS13,
+			Bugs: ProtocolBugs{
+				DTLSUsePlaintextRecordHeader: true,
+			},
+		},
+		expectMessageDropped: true,
+	})
+	testCases = append(testCases, testCase{
+		protocol: dtls,
+		name:     "DTLS13RecordHeader-CIDBit",
+		config: Config{
+			MinVersion: VersionTLS13,
+			Bugs: ProtocolBugs{
+				DTLS13RecordHeaderSetCIDBit: true,
+			},
+		},
+		expectMessageDropped: true,
+	})
 }
 
 func addTestForCipherSuite(suite testCipherSuite, ver tlsVersion, protocol protocol) {