Keep the encryption state and encryption level in sync.

This is a little bit of internal cleanup. The original intent was so
QUIC could install secrets in set_(read|write)_state, but that was
somewhat annoying, so I've left it just before the call for now.

There is one TLS 1.3 state transition which doesn't carry an encryption
level: switching from 0-RTT keys back to unencrypted on an HRR-based
0-RTT reject. The TCP code doesn't care about write_level and the QUIC
code is currently fine because we never "install" the 0-RTT keys. But we
should get this correct.

This also opens the door for DTLS 1.3, if we ever implement it, because
DTLS 1.3 will need to know which level it is to handle 0-RTT keys funny.
(Clients sending 0-RTT will briefly have handshake and 0-RTT write keys
active simultaneously.)

QUIC has the same property, but we can fudge it because only the caller
is aware of this.

Change-Id: Ia76d787e1b96a058d9818948b6d9a051e8592207
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/40124
Reviewed-by: Steven Valdez <svaldez@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/dtls_method.cc b/ssl/dtls_method.cc
index 620a2e1..ae26de7 100644
--- a/ssl/dtls_method.cc
+++ b/ssl/dtls_method.cc
@@ -77,7 +77,8 @@
   }
 }
 
-static bool dtls1_set_read_state(SSL *ssl, UniquePtr<SSLAEADContext> aead_ctx) {
+static bool dtls1_set_read_state(SSL *ssl, ssl_encryption_level_t level,
+                                 UniquePtr<SSLAEADContext> aead_ctx) {
   // Cipher changes are forbidden if the current epoch has leftover data.
   if (dtls_has_unprocessed_handshake_data(ssl)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESS_HANDSHAKE_DATA);
@@ -90,11 +91,12 @@
   OPENSSL_memset(ssl->s3->read_sequence, 0, sizeof(ssl->s3->read_sequence));
 
   ssl->s3->aead_read_ctx = std::move(aead_ctx);
+  ssl->s3->read_level = level;
   ssl->d1->has_change_cipher_spec = 0;
   return true;
 }
 
-static bool dtls1_set_write_state(SSL *ssl,
+static bool dtls1_set_write_state(SSL *ssl, ssl_encryption_level_t level,
                                   UniquePtr<SSLAEADContext> aead_ctx) {
   ssl->d1->w_epoch++;
   OPENSSL_memcpy(ssl->d1->last_write_sequence, ssl->s3->write_sequence,
@@ -103,6 +105,7 @@
 
   ssl->d1->last_aead_write_ctx = std::move(ssl->s3->aead_write_ctx);
   ssl->s3->aead_write_ctx = std::move(aead_ctx);
+  ssl->s3->write_level = level;
   return true;
 }
 
diff --git a/ssl/internal.h b/ssl/internal.h
index a6b58c5..78c56b6 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -2147,14 +2147,16 @@
   int (*flush_flight)(SSL *ssl);
   // on_handshake_complete is called when the handshake is complete.
   void (*on_handshake_complete)(SSL *ssl);
-  // set_read_state sets |ssl|'s read cipher state to |aead_ctx|. It returns
-  // true on success and false if changing the read state is forbidden at this
-  // point.
-  bool (*set_read_state)(SSL *ssl, UniquePtr<SSLAEADContext> aead_ctx);
-  // set_write_state sets |ssl|'s write cipher state to |aead_ctx|. It returns
-  // true on success and false if changing the write state is forbidden at this
-  // point.
-  bool (*set_write_state)(SSL *ssl, UniquePtr<SSLAEADContext> aead_ctx);
+  // set_read_state sets |ssl|'s read cipher state and level to |aead_ctx| and
+  // |level|. It returns true on success and false if changing the read state is
+  // forbidden at this point.
+  bool (*set_read_state)(SSL *ssl, ssl_encryption_level_t level,
+                         UniquePtr<SSLAEADContext> aead_ctx);
+  // set_write_state sets |ssl|'s write cipher state and level to |aead_ctx| and
+  // |level|. It returns true on success and false if changing the write state
+  // is forbidden at this point.
+  bool (*set_write_state)(SSL *ssl, ssl_encryption_level_t level,
+                          UniquePtr<SSLAEADContext> aead_ctx);
 };
 
 // The following wrappers call |open_*| but handle |read_shutdown| correctly.
diff --git a/ssl/t1_enc.cc b/ssl/t1_enc.cc
index 4c2fffb..8091021 100644
--- a/ssl/t1_enc.cc
+++ b/ssl/t1_enc.cc
@@ -236,10 +236,12 @@
   }
 
   if (direction == evp_aead_open) {
-    return ssl->method->set_read_state(ssl, std::move(aead_ctx));
+    return ssl->method->set_read_state(ssl, ssl_encryption_application,
+                                       std::move(aead_ctx));
   }
 
-  return ssl->method->set_write_state(ssl, std::move(aead_ctx));
+  return ssl->method->set_write_state(ssl, ssl_encryption_application,
+                                      std::move(aead_ctx));
 }
 
 int tls1_change_cipher_state(SSL_HANDSHAKE *hs,
diff --git a/ssl/tls13_client.cc b/ssl/tls13_client.cc
index e22c1e1..716c7b4 100644
--- a/ssl/tls13_client.cc
+++ b/ssl/tls13_client.cc
@@ -207,7 +207,8 @@
   bssl::UniquePtr<SSLAEADContext> null_ctx =
       SSLAEADContext::CreateNullCipher(SSL_is_dtls(ssl));
   if (!null_ctx ||
-      !ssl->method->set_write_state(ssl, std::move(null_ctx))) {
+      !ssl->method->set_write_state(ssl, ssl_encryption_initial,
+                                    std::move(null_ctx))) {
     return ssl_hs_error;
   }
 
diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc
index d0c27b6..bd12f63 100644
--- a/ssl/tls13_enc.cc
+++ b/ssl/tls13_enc.cc
@@ -190,17 +190,6 @@
     return false;
   }
 
-  if (direction == evp_aead_open) {
-    if (!ssl->method->set_read_state(ssl, std::move(traffic_aead))) {
-      return false;
-    }
-  } else {
-    if (!ssl->method->set_write_state(ssl, std::move(traffic_aead))) {
-      return false;
-    }
-  }
-
-  // Save the traffic secret.
   if (traffic_secret.size() >
           OPENSSL_ARRAY_SIZE(ssl->s3->read_traffic_secret) ||
       traffic_secret.size() >
@@ -208,16 +197,21 @@
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return false;
   }
+
   if (direction == evp_aead_open) {
+    if (!ssl->method->set_read_state(ssl, level, std::move(traffic_aead))) {
+      return false;
+    }
     OPENSSL_memmove(ssl->s3->read_traffic_secret, traffic_secret.data(),
                     traffic_secret.size());
     ssl->s3->read_traffic_secret_len = traffic_secret.size();
-    ssl->s3->read_level = level;
   } else {
+    if (!ssl->method->set_write_state(ssl, level, std::move(traffic_aead))) {
+      return false;
+    }
     OPENSSL_memmove(ssl->s3->write_traffic_secret, traffic_secret.data(),
                     traffic_secret.size());
     ssl->s3->write_traffic_secret_len = traffic_secret.size();
-    ssl->s3->write_level = level;
   }
 
   return true;
diff --git a/ssl/tls_method.cc b/ssl/tls_method.cc
index 241a3fd..3868852 100644
--- a/ssl/tls_method.cc
+++ b/ssl/tls_method.cc
@@ -82,7 +82,8 @@
   }
 }
 
-static bool tls_set_read_state(SSL *ssl, UniquePtr<SSLAEADContext> aead_ctx) {
+static bool tls_set_read_state(SSL *ssl, ssl_encryption_level_t level,
+                               UniquePtr<SSLAEADContext> aead_ctx) {
   // Cipher changes are forbidden if the current epoch has leftover data.
   if (tls_has_unprocessed_handshake_data(ssl)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESS_HANDSHAKE_DATA);
@@ -92,16 +93,19 @@
 
   OPENSSL_memset(ssl->s3->read_sequence, 0, sizeof(ssl->s3->read_sequence));
   ssl->s3->aead_read_ctx = std::move(aead_ctx);
+  ssl->s3->read_level = level;
   return true;
 }
 
-static bool tls_set_write_state(SSL *ssl, UniquePtr<SSLAEADContext> aead_ctx) {
+static bool tls_set_write_state(SSL *ssl, ssl_encryption_level_t level,
+                                UniquePtr<SSLAEADContext> aead_ctx) {
   if (!tls_flush_pending_hs_data(ssl)) {
     return false;
   }
 
   OPENSSL_memset(ssl->s3->write_sequence, 0, sizeof(ssl->s3->write_sequence));
   ssl->s3->aead_write_ctx = std::move(aead_ctx);
+  ssl->s3->write_level = level;
   return true;
 }