Add DTLS 1.3 sequence number encryption

Bug: 715
Change-Id: I87f8a08e9a2258dede21cffb1cfde5802608d30d
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/70667
Reviewed-by: Bob Beck <bbe@google.com>
Commit-Queue: Bob Beck <bbe@google.com>
diff --git a/ssl/dtls_record.cc b/ssl/dtls_record.cc
index 783e950..a83d6b1 100644
--- a/ssl/dtls_record.cc
+++ b/ssl/dtls_record.cc
@@ -195,7 +195,7 @@
   return seqnum;
 }
 
-static bool parse_dtls13_record_header(SSL *ssl, CBS *in, size_t packet_size,
+static bool parse_dtls13_record_header(SSL *ssl, CBS *in, Span<uint8_t> packet,
                                        uint8_t type, CBS *out_body,
                                        uint64_t *out_sequence,
                                        uint16_t *out_epoch,
@@ -206,29 +206,23 @@
     // Connection ID bit set, which we didn't negotiate.
     return false;
   }
+
   // TODO(crbug.com/boringssl/715): Add a runner test that performs many
   // key updates to verify epoch reconstruction works for epochs larger than
   // 3.
   *out_epoch = reconstruct_epoch(type, ssl->d1->r_epoch);
+  size_t seqlen = 1;
   if ((type & 0x08) == 0x08) {
-    // 16-bit sequence number.
-    uint16_t seq;
-    if (!CBS_get_u16(in, &seq)) {
-      // The record header was incomplete or malformed.
-      return false;
-    }
-    *out_sequence =
-        reconstruct_seqnum(seq, 0xffff, ssl->d1->bitmap.max_seq_num);
-  } else {
-    // 8-bit sequence number.
-    uint8_t seq;
-    if (!CBS_get_u8(in, &seq)) {
-      // The record header was incomplete or malformed.
-      return false;
-    }
-    *out_sequence = reconstruct_seqnum(seq, 0xff, ssl->d1->bitmap.max_seq_num);
+    // If this bit is set, the sequence number is 16 bits long, otherwise it is
+    // 8 bits. The seqlen variable tracks the length of the sequence number in
+    // bytes.
+    seqlen = 2;
   }
-  *out_header_len = packet_size - CBS_len(in);
+  if (!CBS_skip(in, seqlen)) {
+    // The record header was incomplete or malformed.
+    return false;
+  }
+  *out_header_len = packet.size() - CBS_len(in);
   if ((type & 0x04) == 0x04) {
     *out_header_len += 2;
     // 16-bit length present
@@ -244,6 +238,26 @@
       return false;
     }
   }
+
+  // Decrypt and reconstruct the sequence number:
+  uint8_t mask[AES_BLOCK_SIZE];
+  SSLAEADContext *aead = ssl->s3->aead_read_ctx.get();
+  if (!aead->GenerateRecordNumberMask(mask, *out_body)) {
+    // GenerateRecordNumberMask most likely failed because the record body was
+    // not long enough.
+    return false;
+  }
+  // Apply the mask to the sequence number as it exists in the header. The
+  // header (with the decrypted sequence number bytes) is used as the
+  // additional data for the AEAD function. Since we don't support Connection
+  // ID, the sequence number starts immediately after the type byte.
+  uint64_t seq = 0;
+  for (size_t i = 0; i < seqlen; i++) {
+    packet[i + 1] ^= mask[i];
+    seq = (seq << 8) | packet[i + 1];
+  }
+  *out_sequence = reconstruct_seqnum(seq, (1 << (seqlen * 8)) - 1,
+                                     ssl->d1->bitmap.max_seq_num);
   return true;
 }
 
@@ -321,9 +335,8 @@
   // records use the old record header format.
   if ((type & 0xe0) == 0x20 && !aead->is_null_cipher() &&
       aead->ProtocolVersion() >= TLS1_3_VERSION) {
-    valid_record_header =
-        parse_dtls13_record_header(ssl, &cbs, in.size(), type, &body, &sequence,
-                                   &epoch, &record_header_len);
+    valid_record_header = parse_dtls13_record_header(
+        ssl, &cbs, in, type, &body, &sequence, &epoch, &record_header_len);
   } else {
     valid_record_header = parse_dtls_plaintext_record_header(
         ssl, &cbs, in.size(), type, &body, &sequence, &epoch,
@@ -539,8 +552,24 @@
     return false;
   }
 
-  // TODO(crbug.com/boringssl/715): Perform record number encryption (RFC 9147
-  // section 4.2.3).
+  // Perform record number encryption (RFC 9147 section 4.2.3).
+  if (dtls13_header) {
+    // Record number encryption uses bytes from the ciphertext as a sample to
+    // generate the mask used for encryption. For simplicity, pass in the whole
+    // ciphertext as the sample - GenerateRecordNumberMask will read only what
+    // it needs (and error if |sample| is too short).
+    Span<const uint8_t> sample =
+        MakeConstSpan(out + record_header_len, ciphertext_len);
+    // AES cipher suites require the mask be exactly AES_BLOCK_SIZE; ChaCha20
+    // cipher suites have no requirements on the mask size. We only need the
+    // first two bytes from the mask.
+    uint8_t mask[AES_BLOCK_SIZE];
+    if (!aead->GenerateRecordNumberMask(mask, sample)) {
+      return false;
+    }
+    out[1] ^= mask[0];
+    out[2] ^= mask[1];
+  }
 
   (*seq)++;
   *out_len = record_header_len + ciphertext_len;
diff --git a/ssl/internal.h b/ssl/internal.h
index febb676..e651828 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -155,6 +155,7 @@
 #include <utility>
 
 #include <openssl/aead.h>
+#include <openssl/aes.h>
 #include <openssl/curve25519.h>
 #include <openssl/err.h>
 #include <openssl/hpke.h>
@@ -811,6 +812,16 @@
 
 // Encryption layer.
 
+class RecordNumberEncrypter {
+ public:
+  virtual ~RecordNumberEncrypter() = default;
+  static constexpr bool kAllowUniquePtr = true;
+
+  virtual size_t KeySize() = 0;
+  virtual bool SetKey(Span<const uint8_t> key) = 0;
+  virtual bool GenerateMask(Span<uint8_t> out, Span<const uint8_t> sample) = 0;
+};
+
 // SSLAEADContext contains information about an AEAD that is being used to
 // encrypt an SSL connection.
 class SSLAEADContext {
@@ -916,6 +927,17 @@
 
   bool GetIV(const uint8_t **out_iv, size_t *out_iv_len) const;
 
+  RecordNumberEncrypter *GetRecordNumberEncrypter() {
+    return rn_encrypter_.get();
+  }
+
+  // GenerateRecordNumberMask computes the mask used for DTLS 1.3 record number
+  // encryption (RFC 9147 section 4.2.3), writing it to |out|. The |out| buffer
+  // must be sized to AES_BLOCK_SIZE. The |sample| buffer must be at least 16
+  // bytes, as required by the AES and ChaCha20 cipher suites in RFC 9147. Extra
+  // bytes in |sample| will be ignored.
+  bool GenerateRecordNumberMask(Span<uint8_t> out, Span<const uint8_t> sample);
+
  private:
   // GetAdditionalData returns the additional data, writing into |storage| if
   // necessary.
@@ -924,6 +946,8 @@
                                         uint64_t seqnum, size_t plaintext_len,
                                         Span<const uint8_t> header);
 
+  void CreateRecordNumberEncrypter();
+
   const SSL_CIPHER *cipher_;
   ScopedEVP_AEAD_CTX ctx_;
   // fixed_nonce_ contains any bytes of the nonce that are fixed for all
@@ -932,6 +956,7 @@
   uint8_t fixed_nonce_len_ = 0, variable_nonce_len_ = 0;
   // version_ is the wire version that should be used with this AEAD.
   uint16_t version_;
+  UniquePtr<RecordNumberEncrypter> rn_encrypter_;
   // is_dtls_ is whether DTLS is being used with this AEAD.
   bool is_dtls_;
   // variable_nonce_included_in_record_ is true if the variable nonce
@@ -951,6 +976,45 @@
   bool ad_is_header_ : 1;
 };
 
+class AESRecordNumberEncrypter : public RecordNumberEncrypter {
+ public:
+  bool SetKey(Span<const uint8_t> key) override;
+  bool GenerateMask(Span<uint8_t> out, Span<const uint8_t> sample) override;
+
+ private:
+  AES_KEY key_;
+};
+
+class AES128RecordNumberEncrypter : public AESRecordNumberEncrypter {
+ public:
+  size_t KeySize() override;
+};
+
+class AES256RecordNumberEncrypter : public AESRecordNumberEncrypter {
+ public:
+  size_t KeySize() override;
+};
+
+class ChaChaRecordNumberEncrypter : public RecordNumberEncrypter {
+ public:
+  size_t KeySize() override;
+  bool SetKey(Span<const uint8_t> key) override;
+  bool GenerateMask(Span<uint8_t> out, Span<const uint8_t> sample) override;
+
+ private:
+  static const size_t kKeySize = 32;
+  uint8_t key_[kKeySize];
+};
+
+#if defined(BORINGSSL_UNSAFE_FUZZER_MODE)
+class NullRecordNumberEncrypter : public RecordNumberEncrypter {
+ public:
+  size_t KeySize() override;
+  bool SetKey(Span<const uint8_t> key) override;
+  bool GenerateMask(Span<uint8_t> out, Span<const uint8_t> sample) override;
+};
+#endif  // BORINGSSL_UNSAFE_FUZZER_MODE
+
 
 // DTLS replay bitmap.
 
diff --git a/ssl/ssl_aead_ctx.cc b/ssl/ssl_aead_ctx.cc
index 85617a4..4f532e9 100644
--- a/ssl/ssl_aead_ctx.cc
+++ b/ssl/ssl_aead_ctx.cc
@@ -18,6 +18,7 @@
 #include <string.h>
 
 #include <openssl/aead.h>
+#include <openssl/chacha.h>
 #include <openssl/err.h>
 #include <openssl/rand.h>
 
@@ -44,6 +45,7 @@
       omit_length_in_ad_(false),
       ad_is_header_(false) {
   OPENSSL_memset(fixed_nonce_, 0, sizeof(fixed_nonce_));
+  CreateRecordNumberEncrypter();
 }
 
 SSLAEADContext::~SSLAEADContext() {}
@@ -145,6 +147,23 @@
   return aead_ctx;
 }
 
+void SSLAEADContext::CreateRecordNumberEncrypter() {
+  if (!cipher_) {
+    return;
+  }
+#if defined(BORINGSSL_UNSAFE_FUZZER_MODE)
+  rn_encrypter_ = MakeUnique<NullRecordNumberEncrypter>();
+#else
+  if (cipher_->algorithm_enc == SSL_AES128GCM) {
+    rn_encrypter_ = MakeUnique<AES128RecordNumberEncrypter>();
+  } else if (cipher_->algorithm_enc == SSL_AES256GCM) {
+    rn_encrypter_ = MakeUnique<AES256RecordNumberEncrypter>();
+  } else if (cipher_->algorithm_enc == SSL_CHACHA20POLY1305) {
+    rn_encrypter_ = MakeUnique<ChaChaRecordNumberEncrypter>();
+  }
+#endif  // BORINGSSL_UNSAFE_FUZZER_MODE
+}
+
 UniquePtr<SSLAEADContext> SSLAEADContext::CreatePlaceholderForQUIC(
     uint16_t version, const SSL_CIPHER *cipher) {
   return MakeUnique<SSLAEADContext>(version, false, cipher);
@@ -427,4 +446,70 @@
          EVP_AEAD_CTX_get_iv(ctx_.get(), out_iv, out_iv_len);
 }
 
+bool SSLAEADContext::GenerateRecordNumberMask(Span<uint8_t> out,
+                                              Span<const uint8_t> sample) {
+  if (!rn_encrypter_) {
+    return false;
+  }
+  return rn_encrypter_->GenerateMask(out, sample);
+}
+
+size_t AES128RecordNumberEncrypter::KeySize() { return 16; }
+
+size_t AES256RecordNumberEncrypter::KeySize() { return 32; }
+
+bool AESRecordNumberEncrypter::SetKey(Span<const uint8_t> key) {
+  return AES_set_encrypt_key(key.data(), key.size() * 8, &key_) == 0;
+}
+
+bool AESRecordNumberEncrypter::GenerateMask(Span<uint8_t> out,
+                                            Span<const uint8_t> sample) {
+  if (sample.size() < AES_BLOCK_SIZE || out.size() != AES_BLOCK_SIZE) {
+    return false;
+  }
+  AES_encrypt(sample.data(), out.data(), &key_);
+  return true;
+}
+
+size_t ChaChaRecordNumberEncrypter::KeySize() { return kKeySize; }
+
+bool ChaChaRecordNumberEncrypter::SetKey(Span<const uint8_t> key) {
+  if (key.size() != kKeySize) {
+    return false;
+  }
+  OPENSSL_memcpy(key_, key.data(), key.size());
+  return true;
+}
+
+bool ChaChaRecordNumberEncrypter::GenerateMask(Span<uint8_t> out,
+                                               Span<const uint8_t> sample) {
+  Array<uint8_t> zeroes;
+  if (!zeroes.Init(out.size())) {
+    return false;
+  }
+  OPENSSL_memset(zeroes.data(), 0, zeroes.size());
+  // RFC 9147 section 4.2.3 uses the first 4 bytes of the sample as the counter
+  // and the next 12 bytes as the nonce. If we have less than 4+12=16 bytes in
+  // the sample, then we'll read past the end of the |sample| buffer.
+  if (sample.size() < 16) {
+    return false;
+  }
+  uint32_t counter = CRYPTO_load_u32_be(sample.data());
+  Span<const uint8_t> nonce = sample.subspan(4);
+  CRYPTO_chacha_20(out.data(), zeroes.data(), zeroes.size(), key_, nonce.data(),
+                   counter);
+  return true;
+}
+
+#if defined(BORINGSSL_UNSAFE_FUZZER_MODE)
+size_t NullRecordNumberEncrypter::KeySize() { return 0; }
+bool NullRecordNumberEncrypter::SetKey(Span<const uint8_t> key) { return true; }
+
+bool NullRecordNumberEncrypter::GenerateMask(Span<uint8_t> out,
+                                             Span<const uint8_t> sample) {
+  OPENSSL_memset(out.data(), 0, out.size());
+  return true;
+}
+#endif  // BORINGSSL_UNSAFE_FUZZER_MODE
+
 BSSL_NAMESPACE_END
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index ce425a0..1988ab2 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -8,6 +8,7 @@
 
 import (
 	"bytes"
+	"crypto/aes"
 	"crypto/cipher"
 	"crypto/ecdsa"
 	"crypto/subtle"
@@ -19,6 +20,9 @@
 	"net"
 	"sync"
 	"time"
+
+	"golang.org/x/crypto/chacha20"
+	"golang.org/x/crypto/cryptobyte"
 )
 
 // A Conn represents a secured connection.
@@ -175,15 +179,16 @@
 type halfConn struct {
 	sync.Mutex
 
-	err         error  // first permanent error
-	version     uint16 // protocol version
-	wireVersion uint16 // wire version
-	isDTLS      bool
-	cipher      any // cipher algorithm
-	mac         macFunction
-	seq         [8]byte // 64-bit sequence number
-	outSeq      [8]byte // Mapped sequence number
-	bfree       *block  // list of free blocks
+	err                   error  // first permanent error
+	version               uint16 // protocol version
+	wireVersion           uint16 // wire version
+	isDTLS                bool
+	cipher                any // cipher algorithm
+	recordNumberEncrypter recordNumberEncrypter
+	mac                   macFunction
+	seq                   [8]byte // 64-bit sequence number
+	outSeq                [8]byte // Mapped sequence number
+	bfree                 *block  // list of free blocks
 
 	nextCipher any         // next encryption state
 	nextMac    macFunction // next MAC algorithm
@@ -253,6 +258,17 @@
 	}
 	hc.version = protocolVersion
 	hc.cipher = deriveTrafficAEAD(version, suite, secret, side, hc.isDTLS)
+	if hc.isDTLS && !hc.config.Bugs.NullAllCiphers {
+		sn_key := hkdfExpandLabel(suite.hash(), secret, []byte("sn"), nil, suite.keyLen, hc.isDTLS)
+		switch suite.id {
+		case TLS_CHACHA20_POLY1305_SHA256:
+			hc.recordNumberEncrypter = newChachaRecordNumberEncrypter(sn_key)
+		case TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384:
+			hc.recordNumberEncrypter = newAESRecordNumberEncrypter(sn_key)
+		default:
+			panic("Cipher suite does not support TLS 1.3")
+		}
+	}
 	if hc.config.Bugs.NullAllCiphers {
 		hc.cipher = nullCipher{}
 	}
@@ -762,6 +778,63 @@
 	return b, bb
 }
 
+type recordNumberEncrypter interface {
+	// GenerateMask takes a sample of the encrypted record and returns the
+	// mask used to encrypt and decrypt record numbers.
+	generateMask(sample []byte) []byte
+}
+
+type aesRecordNumberEncrypter struct {
+	aesCipher cipher.Block
+}
+
+func newAESRecordNumberEncrypter(key []byte) *aesRecordNumberEncrypter {
+	aesCipher, err := aes.NewCipher(key)
+	if err != nil {
+		panic("Incorrect usage of newAESRecordNumberEncrypter")
+	}
+	return &aesRecordNumberEncrypter{
+		aesCipher: aesCipher,
+	}
+}
+
+func (a *aesRecordNumberEncrypter) generateMask(sample []byte) []byte {
+	out := make([]byte, len(sample))
+	a.aesCipher.Encrypt(out, sample)
+	return out
+}
+
+type chachaRecordNumberEncrypter struct {
+	key []byte
+}
+
+func newChachaRecordNumberEncrypter(key []byte) *chachaRecordNumberEncrypter {
+	out := &chachaRecordNumberEncrypter{
+		key: key,
+	}
+	fmt.Printf("new RNE with key %x\n", key)
+	return out
+}
+
+func (c *chachaRecordNumberEncrypter) generateMask(sample []byte) []byte {
+	var counter uint32
+	nonce := make([]byte, 12)
+	sampleReader := cryptobyte.String(sample)
+	if !sampleReader.ReadUint32(&counter) || !sampleReader.CopyBytes(nonce) {
+		panic("chachaRecordNumberEncrypter.GenerateMask called with wrong size sample")
+	}
+	cipher, err := chacha20.NewUnauthenticatedCipher(c.key, nonce)
+	if err != nil {
+		panic("Failed to create chacha20 cipher for record number encryption")
+	}
+	cipher.SetCounter(counter)
+	zeroes := make([]byte, 2)
+	out := make([]byte, 2)
+	cipher.XORKeyStream(out, zeroes)
+	fmt.Printf("golang generateMask: sample: %x, key: %x, mask: %x\n", sample[:16], c.key, out)
+	return out
+}
+
 func (c *Conn) useInTrafficSecret(level encryptionLevel, version uint16, suite *cipherSuite, secret []byte) error {
 	if c.hand.Len() != 0 {
 		return c.in.setErrorLocked(errors.New("tls: buffered handshake messages on cipher change"))
diff --git a/ssl/test/runner/dtls.go b/ssl/test/runner/dtls.go
index f4921d4..8c723f2 100644
--- a/ssl/test/runner/dtls.go
+++ b/ssl/test/runner/dtls.go
@@ -55,7 +55,13 @@
 		c.sendAlert(alertIllegalParameter)
 		return 0, 0, 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch"))
 	}
-	wireSeq := binary.BigEndian.Uint16(b.data[1:3])
+	wireSeq := b.data[1:3]
+	if !c.config.Bugs.NullAllCiphers {
+		sample := b.data[recordHeaderLen:]
+		mask := c.in.recordNumberEncrypter.generateMask(sample)
+		xorSlice(wireSeq, mask)
+	}
+	decWireSeq := binary.BigEndian.Uint16(wireSeq)
 	// Reconstruct the sequence number from the low 16 bits on the wire.
 	// A real implementation would compute the full sequence number that is
 	// closest to the highest successfully decrypted record in the
@@ -67,7 +73,7 @@
 	seqInt := binary.BigEndian.Uint64(c.in.seq[:])
 	// c.in.seq has the epoch in the upper two bytes - clear those.
 	seqInt = seqInt &^ (0xffff << 48)
-	newSeq := seqInt&^0xffff | uint64(wireSeq)
+	newSeq := seqInt&^0xffff | uint64(decWireSeq)
 	if newSeq < seqInt {
 		newSeq += 0x10000
 	}
@@ -500,7 +506,10 @@
 	}
 	copy(b.data[recordHeaderLen+explicitIVLen:], data)
 	recordLen := c.addTLS13Padding(b, recordHeaderLen, len(data), typ)
-	if c.out.version < VersionTLS13 || c.out.cipher == nil || (c.config.Bugs.DTLSUsePlaintextRecordHeader && c.handshakeComplete) {
+	useDTLS13RecordHeader := c.out.version >= VersionTLS13 && c.out.cipher != nil && !(c.config.Bugs.DTLSUsePlaintextRecordHeader && c.handshakeComplete)
+	if useDTLS13RecordHeader {
+		c.writeDTLS13RecordHeader(b, recordLen)
+	} else {
 		b.data[0] = byte(typ)
 		b.data[1] = byte(vers >> 8)
 		b.data[2] = byte(vers)
@@ -508,10 +517,26 @@
 		copy(b.data[3:11], c.out.outSeq[0:])
 		b.data[11] = byte(recordLen >> 8)
 		b.data[12] = byte(recordLen)
-	} else {
-		c.writeDTLS13RecordHeader(b, recordLen)
 	}
+	// encrypt will increment the sequence number. Copy it here to use when
+	// performing sequence number encryption.
+	seqBytes := make([]byte, 2)
+	copy(seqBytes, c.out.outSeq[6:8])
 	c.out.encrypt(b, explicitIVLen, typ)
+	if useDTLS13RecordHeader && !c.config.Bugs.NullAllCiphers {
+		recordHeaderLen := c.out.writeRecordHeaderLen()
+		sample := b.data[recordHeaderLen:]
+		mask := c.out.recordNumberEncrypter.generateMask(sample)
+		if c.config.DTLSUseShortSeqNums {
+			seqBytes = seqBytes[1:2]
+		}
+		xorSlice(seqBytes, mask)
+		for i := range seqBytes {
+			// The sequence number starts at index 1 in the record
+			// header.
+			b.data[1+i] = seqBytes[i]
+		}
+	}
 
 	// Flush the current pending packet if necessary.
 	if !mustPack && len(b.data)+len(c.pendingPacket) > c.config.Bugs.PackHandshakeRecords {
diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc
index 1613a3a..7c193a3 100644
--- a/ssl/tls13_enc.cc
+++ b/ssl/tls13_enc.cc
@@ -184,6 +184,8 @@
                            const SSL_SESSION *session,
                            Span<const uint8_t> traffic_secret) {
   uint16_t version = ssl_session_protocol_version(session);
+  const EVP_MD *digest = ssl_session_get_digest(session);
+  bool is_dtls = SSL_is_dtls(ssl);
   UniquePtr<SSLAEADContext> traffic_aead;
   Span<const uint8_t> secret_for_quic;
   if (ssl->quic_method != nullptr) {
@@ -197,18 +199,16 @@
     const EVP_AEAD *aead;
     size_t discard;
     if (!ssl_cipher_get_evp_aead(&aead, &discard, &discard, session->cipher,
-                                 version, SSL_is_dtls(ssl))) {
+                                 version, is_dtls)) {
       return false;
     }
 
-    const EVP_MD *digest = ssl_session_get_digest(session);
-
     // Derive the key.
     size_t key_len = EVP_AEAD_key_length(aead);
     uint8_t key_buf[EVP_AEAD_MAX_KEY_LENGTH];
     auto key = MakeSpan(key_buf, key_len);
     if (!hkdf_expand_label(key, digest, traffic_secret, label_to_span("key"),
-                           {}, SSL_is_dtls(ssl))) {
+                           {}, is_dtls)) {
       return false;
     }
 
@@ -217,19 +217,34 @@
     uint8_t iv_buf[EVP_AEAD_MAX_NONCE_LENGTH];
     auto iv = MakeSpan(iv_buf, iv_len);
     if (!hkdf_expand_label(iv, digest, traffic_secret, label_to_span("iv"), {},
-                           SSL_is_dtls(ssl))) {
+                           is_dtls)) {
       return false;
     }
 
-    traffic_aead = SSLAEADContext::Create(direction, session->ssl_version,
-                                          SSL_is_dtls(ssl), session->cipher,
-                                          key, Span<const uint8_t>(), iv);
+    traffic_aead =
+        SSLAEADContext::Create(direction, session->ssl_version, is_dtls,
+                               session->cipher, key, Span<const uint8_t>(), iv);
   }
 
   if (!traffic_aead) {
     return false;
   }
 
+  if (is_dtls) {
+    RecordNumberEncrypter *rn_encrypter =
+        traffic_aead->GetRecordNumberEncrypter();
+    if (!rn_encrypter) {
+      return false;
+    }
+    Array<uint8_t> rne_key;
+    if (!rne_key.Init(rn_encrypter->KeySize()) ||
+        !hkdf_expand_label(MakeSpan(rne_key), digest, traffic_secret,
+                           label_to_span("sn"), {}, is_dtls) ||
+        !rn_encrypter->SetKey(MakeSpan(rne_key))) {
+      return false;
+    }
+  }
+
   if (traffic_secret.size() >
           OPENSSL_ARRAY_SIZE(ssl->s3->read_traffic_secret) ||
       traffic_secret.size() >