diff --git a/ssl/dtls_method.cc b/ssl/dtls_method.cc
index 620a2e1..ae26de7 100644
--- a/ssl/dtls_method.cc
+++ b/ssl/dtls_method.cc
@@ -77,7 +77,8 @@
   }
 }
 
-static bool dtls1_set_read_state(SSL *ssl, UniquePtr<SSLAEADContext> aead_ctx) {
+static bool dtls1_set_read_state(SSL *ssl, ssl_encryption_level_t level,
+                                 UniquePtr<SSLAEADContext> aead_ctx) {
   // Cipher changes are forbidden if the current epoch has leftover data.
   if (dtls_has_unprocessed_handshake_data(ssl)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESS_HANDSHAKE_DATA);
@@ -90,11 +91,12 @@
   OPENSSL_memset(ssl->s3->read_sequence, 0, sizeof(ssl->s3->read_sequence));
 
   ssl->s3->aead_read_ctx = std::move(aead_ctx);
+  ssl->s3->read_level = level;
   ssl->d1->has_change_cipher_spec = 0;
   return true;
 }
 
-static bool dtls1_set_write_state(SSL *ssl,
+static bool dtls1_set_write_state(SSL *ssl, ssl_encryption_level_t level,
                                   UniquePtr<SSLAEADContext> aead_ctx) {
   ssl->d1->w_epoch++;
   OPENSSL_memcpy(ssl->d1->last_write_sequence, ssl->s3->write_sequence,
@@ -103,6 +105,7 @@
 
   ssl->d1->last_aead_write_ctx = std::move(ssl->s3->aead_write_ctx);
   ssl->s3->aead_write_ctx = std::move(aead_ctx);
+  ssl->s3->write_level = level;
   return true;
 }
 
diff --git a/ssl/internal.h b/ssl/internal.h
index a6b58c5..78c56b6 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -2147,14 +2147,16 @@
   int (*flush_flight)(SSL *ssl);
   // on_handshake_complete is called when the handshake is complete.
   void (*on_handshake_complete)(SSL *ssl);
-  // set_read_state sets |ssl|'s read cipher state to |aead_ctx|. It returns
-  // true on success and false if changing the read state is forbidden at this
-  // point.
-  bool (*set_read_state)(SSL *ssl, UniquePtr<SSLAEADContext> aead_ctx);
-  // set_write_state sets |ssl|'s write cipher state to |aead_ctx|. It returns
-  // true on success and false if changing the write state is forbidden at this
-  // point.
-  bool (*set_write_state)(SSL *ssl, UniquePtr<SSLAEADContext> aead_ctx);
+  // set_read_state sets |ssl|'s read cipher state and level to |aead_ctx| and
+  // |level|. It returns true on success and false if changing the read state is
+  // forbidden at this point.
+  bool (*set_read_state)(SSL *ssl, ssl_encryption_level_t level,
+                         UniquePtr<SSLAEADContext> aead_ctx);
+  // set_write_state sets |ssl|'s write cipher state and level to |aead_ctx| and
+  // |level|. It returns true on success and false if changing the write state
+  // is forbidden at this point.
+  bool (*set_write_state)(SSL *ssl, ssl_encryption_level_t level,
+                          UniquePtr<SSLAEADContext> aead_ctx);
 };
 
 // The following wrappers call |open_*| but handle |read_shutdown| correctly.
diff --git a/ssl/t1_enc.cc b/ssl/t1_enc.cc
index 4c2fffb..8091021 100644
--- a/ssl/t1_enc.cc
+++ b/ssl/t1_enc.cc
@@ -236,10 +236,12 @@
   }
 
   if (direction == evp_aead_open) {
-    return ssl->method->set_read_state(ssl, std::move(aead_ctx));
+    return ssl->method->set_read_state(ssl, ssl_encryption_application,
+                                       std::move(aead_ctx));
   }
 
-  return ssl->method->set_write_state(ssl, std::move(aead_ctx));
+  return ssl->method->set_write_state(ssl, ssl_encryption_application,
+                                      std::move(aead_ctx));
 }
 
 int tls1_change_cipher_state(SSL_HANDSHAKE *hs,
diff --git a/ssl/tls13_client.cc b/ssl/tls13_client.cc
index e22c1e1..716c7b4 100644
--- a/ssl/tls13_client.cc
+++ b/ssl/tls13_client.cc
@@ -207,7 +207,8 @@
   bssl::UniquePtr<SSLAEADContext> null_ctx =
       SSLAEADContext::CreateNullCipher(SSL_is_dtls(ssl));
   if (!null_ctx ||
-      !ssl->method->set_write_state(ssl, std::move(null_ctx))) {
+      !ssl->method->set_write_state(ssl, ssl_encryption_initial,
+                                    std::move(null_ctx))) {
     return ssl_hs_error;
   }
 
diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc
index d0c27b6..bd12f63 100644
--- a/ssl/tls13_enc.cc
+++ b/ssl/tls13_enc.cc
@@ -190,17 +190,6 @@
     return false;
   }
 
-  if (direction == evp_aead_open) {
-    if (!ssl->method->set_read_state(ssl, std::move(traffic_aead))) {
-      return false;
-    }
-  } else {
-    if (!ssl->method->set_write_state(ssl, std::move(traffic_aead))) {
-      return false;
-    }
-  }
-
-  // Save the traffic secret.
   if (traffic_secret.size() >
           OPENSSL_ARRAY_SIZE(ssl->s3->read_traffic_secret) ||
       traffic_secret.size() >
@@ -208,16 +197,21 @@
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return false;
   }
+
   if (direction == evp_aead_open) {
+    if (!ssl->method->set_read_state(ssl, level, std::move(traffic_aead))) {
+      return false;
+    }
     OPENSSL_memmove(ssl->s3->read_traffic_secret, traffic_secret.data(),
                     traffic_secret.size());
     ssl->s3->read_traffic_secret_len = traffic_secret.size();
-    ssl->s3->read_level = level;
   } else {
+    if (!ssl->method->set_write_state(ssl, level, std::move(traffic_aead))) {
+      return false;
+    }
     OPENSSL_memmove(ssl->s3->write_traffic_secret, traffic_secret.data(),
                     traffic_secret.size());
     ssl->s3->write_traffic_secret_len = traffic_secret.size();
-    ssl->s3->write_level = level;
   }
 
   return true;
diff --git a/ssl/tls_method.cc b/ssl/tls_method.cc
index 241a3fd..3868852 100644
--- a/ssl/tls_method.cc
+++ b/ssl/tls_method.cc
@@ -82,7 +82,8 @@
   }
 }
 
-static bool tls_set_read_state(SSL *ssl, UniquePtr<SSLAEADContext> aead_ctx) {
+static bool tls_set_read_state(SSL *ssl, ssl_encryption_level_t level,
+                               UniquePtr<SSLAEADContext> aead_ctx) {
   // Cipher changes are forbidden if the current epoch has leftover data.
   if (tls_has_unprocessed_handshake_data(ssl)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESS_HANDSHAKE_DATA);
@@ -92,16 +93,19 @@
 
   OPENSSL_memset(ssl->s3->read_sequence, 0, sizeof(ssl->s3->read_sequence));
   ssl->s3->aead_read_ctx = std::move(aead_ctx);
+  ssl->s3->read_level = level;
   return true;
 }
 
-static bool tls_set_write_state(SSL *ssl, UniquePtr<SSLAEADContext> aead_ctx) {
+static bool tls_set_write_state(SSL *ssl, ssl_encryption_level_t level,
+                                UniquePtr<SSLAEADContext> aead_ctx) {
   if (!tls_flush_pending_hs_data(ssl)) {
     return false;
   }
 
   OPENSSL_memset(ssl->s3->write_sequence, 0, sizeof(ssl->s3->write_sequence));
   ssl->s3->aead_write_ctx = std::move(aead_ctx);
+  ssl->s3->write_level = level;
   return true;
 }
 
