Make tls_open_record always in-place.

The business with ssl_record_prefix_len is rather a hassle. Instead, have
tls_open_record always decrypt in-place and give back a CBS to where the body
is.

This way the caller doesn't need to do an extra check all to avoid creating an
invalid pointer and underflow in subtraction.

Change-Id: I4e12b25a760870d8f8a503673ab00a2d774fc9ee
Reviewed-on: https://boringssl-review.googlesource.com/8173
Reviewed-by: Steven Valdez <svaldez@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/crypto/bytestring/bytestring_test.cc b/crypto/bytestring/bytestring_test.cc
index 84ecffc..e1d16f4 100644
--- a/crypto/bytestring/bytestring_test.cc
+++ b/crypto/bytestring/bytestring_test.cc
@@ -43,7 +43,7 @@
 }
 
 static bool TestGetUint() {
-  static const uint8_t kData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+  static const uint8_t kData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
   uint8_t u8;
   uint16_t u16;
   uint32_t u32;
@@ -58,7 +58,10 @@
     u32 == 0x40506 &&
     CBS_get_u32(&data, &u32) &&
     u32 == 0x708090a &&
-    !CBS_get_u8(&data, &u8);
+    CBS_get_last_u8(&data, &u8) &&
+    u8 == 0xb &&
+    !CBS_get_u8(&data, &u8) &&
+    !CBS_get_last_u8(&data, &u8);
 }
 
 static bool TestGetPrefixed() {
diff --git a/crypto/bytestring/cbs.c b/crypto/bytestring/cbs.c
index ed54b49..c86afbd 100644
--- a/crypto/bytestring/cbs.c
+++ b/crypto/bytestring/cbs.c
@@ -128,6 +128,15 @@
   return cbs_get_u(cbs, out, 4);
 }
 
+int CBS_get_last_u8(CBS *cbs, uint8_t *out) {
+  if (cbs->len == 0) {
+    return 0;
+  }
+  *out = cbs->data[cbs->len - 1];
+  cbs->len--;
+  return 1;
+}
+
 int CBS_get_bytes(CBS *cbs, CBS *out, size_t len) {
   const uint8_t *v;
   if (!cbs_get(cbs, &v, len)) {
diff --git a/include/openssl/bytestring.h b/include/openssl/bytestring.h
index 68ede2d..3a8d4e5 100644
--- a/include/openssl/bytestring.h
+++ b/include/openssl/bytestring.h
@@ -95,6 +95,10 @@
  * and advances |cbs|. It returns one on success and zero on error. */
 OPENSSL_EXPORT int CBS_get_u32(CBS *cbs, uint32_t *out);
 
+/* CBS_get_last_u8 sets |*out| to the last uint8_t from |cbs| and shortens
+ * |cbs|. It returns one on success and zero on error. */
+OPENSSL_EXPORT int CBS_get_last_u8(CBS *cbs, uint8_t *out);
+
 /* CBS_get_bytes sets |*out| to the next |len| bytes from |cbs| and advances
  * |cbs|. It returns one on success and zero on error. */
 OPENSSL_EXPORT int CBS_get_bytes(CBS *cbs, CBS *out, size_t len);
diff --git a/ssl/d1_pkt.c b/ssl/d1_pkt.c
index 68e6a4d..4f05f0f 100644
--- a/ssl/d1_pkt.c
+++ b/ssl/d1_pkt.c
@@ -159,18 +159,11 @@
   }
   assert(ssl_read_buffer_len(ssl) > 0);
 
-  /* Ensure the packet is large enough to decrypt in-place. */
-  if (ssl_read_buffer_len(ssl) < ssl_record_prefix_len(ssl)) {
-    ssl_read_buffer_clear(ssl);
-    goto again;
-  }
-
-  uint8_t *out = ssl_read_buffer(ssl) + ssl_record_prefix_len(ssl);
-  size_t max_out = ssl_read_buffer_len(ssl) - ssl_record_prefix_len(ssl);
+  CBS body;
   uint8_t type, alert;
-  size_t len, consumed;
+  size_t consumed;
   enum ssl_open_record_t open_ret =
-      dtls_open_record(ssl, &type, out, &len, &consumed, &alert, max_out,
+      dtls_open_record(ssl, &type, &body, &consumed, &alert,
                        ssl_read_buffer(ssl), ssl_read_buffer_len(ssl));
   ssl_read_buffer_consume(ssl, consumed);
   switch (open_ret) {
@@ -179,15 +172,15 @@
       break;
 
     case ssl_open_record_success:
-      if (len > 0xffff) {
+      if (CBS_len(&body) > 0xffff) {
         OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
         return -1;
       }
 
       SSL3_RECORD *rr = &ssl->s3->rrec;
       rr->type = type;
-      rr->length = (uint16_t)len;
-      rr->data = out;
+      rr->length = (uint16_t)CBS_len(&body);
+      rr->data = (uint8_t *)CBS_data(&body);
       return 1;
 
     case ssl_open_record_discard:
diff --git a/ssl/dtls_record.c b/ssl/dtls_record.c
index 94dfb28..76ae8e5 100644
--- a/ssl/dtls_record.c
+++ b/ssl/dtls_record.c
@@ -171,10 +171,10 @@
   }
 }
 
-enum ssl_open_record_t dtls_open_record(
-    SSL *ssl, uint8_t *out_type, uint8_t *out, size_t *out_len,
-    size_t *out_consumed, uint8_t *out_alert, size_t max_out, const uint8_t *in,
-    size_t in_len) {
+enum ssl_open_record_t dtls_open_record(SSL *ssl, uint8_t *out_type, CBS *out,
+                                        size_t *out_consumed,
+                                        uint8_t *out_alert, uint8_t *in,
+                                        size_t in_len) {
   *out_consumed = 0;
 
   CBS cbs;
@@ -211,11 +211,9 @@
     return ssl_open_record_discard;
   }
 
-  /* Decrypt the body. */
-  size_t plaintext_len;
-  if (!SSL_AEAD_CTX_open(ssl->s3->aead_read_ctx, out, &plaintext_len, max_out,
-                         type, version, sequence, CBS_data(&body),
-                         CBS_len(&body))) {
+  /* Decrypt the body in-place. */
+  if (!SSL_AEAD_CTX_open(ssl->s3->aead_read_ctx, out, type, version, sequence,
+                         (uint8_t *)CBS_data(&body), CBS_len(&body))) {
     /* Bad packets are silently dropped in DTLS. See section 4.2.1 of RFC 6347.
      * Clear the error queue of any errors decryption may have added. Drop the
      * entire packet as it must not have come from the peer.
@@ -229,7 +227,7 @@
   *out_consumed = in_len - CBS_len(&cbs);
 
   /* Check the plaintext length. */
-  if (plaintext_len > SSL3_RT_MAX_PLAIN_LENGTH) {
+  if (CBS_len(out) > SSL3_RT_MAX_PLAIN_LENGTH) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_DATA_LENGTH_TOO_LONG);
     *out_alert = SSL_AD_RECORD_OVERFLOW;
     return ssl_open_record_error;
@@ -241,13 +239,12 @@
    * useful if we also limit discarded packets. */
 
   if (type == SSL3_RT_ALERT) {
-    return ssl_process_alert(ssl, out_alert, out, plaintext_len);
+    return ssl_process_alert(ssl, out_alert, CBS_data(out), CBS_len(out));
   }
 
   ssl->s3->warning_alert_count = 0;
 
   *out_type = type;
-  *out_len = plaintext_len;
   return ssl_open_record_success;
 }
 
diff --git a/ssl/internal.h b/ssl/internal.h
index 13e7935..4856969 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -313,15 +313,13 @@
  * |SSL_AEAD_CTX_seal|. |ctx| may be NULL to denote the null cipher. */
 size_t SSL_AEAD_CTX_max_overhead(SSL_AEAD_CTX *ctx);
 
-/* SSL_AEAD_CTX_open authenticates and decrypts |in_len| bytes from |in| and
- * writes the result to |out|. It returns one on success and zero on
- * error. |ctx| may be NULL to denote the null cipher.
- *
- * If |in| and |out| alias then |out| must be <= |in| + |explicit_nonce_len|. */
-int SSL_AEAD_CTX_open(SSL_AEAD_CTX *ctx, uint8_t *out, size_t *out_len,
-                      size_t max_out, uint8_t type, uint16_t wire_version,
-                      const uint8_t seqnum[8], const uint8_t *in,
-                      size_t in_len);
+/* SSL_AEAD_CTX_open authenticates and decrypts |in_len| bytes from |in|
+ * in-place. On success, it sets |*out| to the plaintext in |in| and returns
+ * one. Otherwise, it returns zero. |ctx| may be NULL to denote the null cipher.
+ * The output will always be |explicit_nonce_len| bytes ahead of |in|. */
+int SSL_AEAD_CTX_open(SSL_AEAD_CTX *ctx, CBS *out, uint8_t type,
+                      uint16_t wire_version, const uint8_t seqnum[8],
+                      uint8_t *in, size_t in_len);
 
 /* SSL_AEAD_CTX_seal encrypts and authenticates |in_len| bytes from |in| and
  * writes the result to |out|. It returns one on success and zero on
@@ -370,7 +368,7 @@
   ssl_open_record_error,
 };
 
-/* tls_open_record decrypts a record from |in|.
+/* tls_open_record decrypts a record from |in| in-place.
  *
  * If the input did not contain a complete record, it returns
  * |ssl_open_record_partial|. It sets |*out_consumed| to the total number of
@@ -382,8 +380,8 @@
  * decrypted.
  *
  * On success, it returns |ssl_open_record_success|. It sets |*out_type| to the
- * record type, |*out_len| to the plaintext length, and writes the record body
- * to |out|. Note that |*out_len| may be zero.
+ * record type and |*out| to the record body in |in|. Note that |*out| may be
+ * empty.
  *
  * If a record was successfully processed but should be discarded, it returns
  * |ssl_open_record_discard|.
@@ -392,20 +390,17 @@
  * it returns |ssl_open_record_close_notify| or |ssl_open_record_fatal_alert|.
  *
  * On failure, it returns |ssl_open_record_error| and sets |*out_alert| to an
- * alert to emit.
- *
- * If |in| and |out| alias, |out| must be <= |in| + |ssl_record_prefix_len|. */
-enum ssl_open_record_t tls_open_record(
-    SSL *ssl, uint8_t *out_type, uint8_t *out, size_t *out_len,
-    size_t *out_consumed, uint8_t *out_alert, size_t max_out, const uint8_t *in,
-    size_t in_len);
+ * alert to emit. */
+enum ssl_open_record_t tls_open_record(SSL *ssl, uint8_t *out_type, CBS *out,
+                                       size_t *out_consumed, uint8_t *out_alert,
+                                       uint8_t *in, size_t in_len);
 
 /* dtls_open_record implements |tls_open_record| for DTLS. It never returns
  * |ssl_open_record_partial| but otherwise behaves analogously. */
-enum ssl_open_record_t dtls_open_record(
-    SSL *ssl, uint8_t *out_type, uint8_t *out, size_t *out_len,
-    size_t *out_consumed, uint8_t *out_alert, size_t max_out, const uint8_t *in,
-    size_t in_len);
+enum ssl_open_record_t dtls_open_record(SSL *ssl, uint8_t *out_type, CBS *out,
+                                        size_t *out_consumed,
+                                        uint8_t *out_alert, uint8_t *in,
+                                        size_t in_len);
 
 /* ssl_seal_prefix_len returns the length of the prefix before the ciphertext
  * when sealing a record with |ssl|. Note that this value may differ from
diff --git a/ssl/s3_pkt.c b/ssl/s3_pkt.c
index 04d41be..cd6de5d 100644
--- a/ssl/s3_pkt.c
+++ b/ssl/s3_pkt.c
@@ -138,41 +138,34 @@
       return 0;
   }
 
-  /* Ensure the buffer is large enough to decrypt in-place. */
-  int read_ret = ssl_read_buffer_extend_to(ssl, ssl_record_prefix_len(ssl));
-  if (read_ret <= 0) {
-    return read_ret;
-  }
-  assert(ssl_read_buffer_len(ssl) >= ssl_record_prefix_len(ssl));
-
-  uint8_t *out = ssl_read_buffer(ssl) + ssl_record_prefix_len(ssl);
-  size_t max_out = ssl_read_buffer_len(ssl) - ssl_record_prefix_len(ssl);
+  CBS body;
   uint8_t type, alert;
-  size_t len, consumed;
+  size_t consumed;
   enum ssl_open_record_t open_ret =
-      tls_open_record(ssl, &type, out, &len, &consumed, &alert, max_out,
+      tls_open_record(ssl, &type, &body, &consumed, &alert,
                       ssl_read_buffer(ssl), ssl_read_buffer_len(ssl));
   if (open_ret != ssl_open_record_partial) {
     ssl_read_buffer_consume(ssl, consumed);
   }
   switch (open_ret) {
-    case ssl_open_record_partial:
-      read_ret = ssl_read_buffer_extend_to(ssl, consumed);
+    case ssl_open_record_partial: {
+      int read_ret = ssl_read_buffer_extend_to(ssl, consumed);
       if (read_ret <= 0) {
         return read_ret;
       }
       goto again;
+    }
 
     case ssl_open_record_success:
-      if (len > 0xffff) {
+      if (CBS_len(&body) > 0xffff) {
         OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
         return -1;
       }
 
       SSL3_RECORD *rr = &ssl->s3->rrec;
       rr->type = type;
-      rr->length = (uint16_t)len;
-      rr->data = out;
+      rr->length = (uint16_t)CBS_len(&body);
+      rr->data = (uint8_t *)CBS_data(&body);
       return 1;
 
     case ssl_open_record_discard:
diff --git a/ssl/ssl_aead_ctx.c b/ssl/ssl_aead_ctx.c
index 1e549ea..88daddd 100644
--- a/ssl/ssl_aead_ctx.c
+++ b/ssl/ssl_aead_ctx.c
@@ -166,22 +166,16 @@
   return len;
 }
 
-int SSL_AEAD_CTX_open(SSL_AEAD_CTX *aead, uint8_t *out, size_t *out_len,
-                      size_t max_out, uint8_t type, uint16_t wire_version,
-                      const uint8_t seqnum[8], const uint8_t *in,
-                      size_t in_len) {
+int SSL_AEAD_CTX_open(SSL_AEAD_CTX *aead, CBS *out, uint8_t type,
+                      uint16_t wire_version, const uint8_t seqnum[8],
+                      uint8_t *in, size_t in_len) {
 #if defined(BORINGSSL_UNSAFE_FUZZER_MODE)
   aead = NULL;
 #endif
 
   if (aead == NULL) {
     /* Handle the initial NULL cipher. */
-    if (in_len > max_out) {
-      OPENSSL_PUT_ERROR(SSL, SSL_R_BUFFER_TOO_SMALL);
-      return 0;
-    }
-    memmove(out, in, in_len);
-    *out_len = in_len;
+    CBS_init(out, in, in_len);
     return 1;
   }
 
@@ -239,8 +233,14 @@
     }
   }
 
-  return EVP_AEAD_CTX_open(&aead->ctx, out, out_len, max_out, nonce, nonce_len,
-                           in, in_len, ad, ad_len);
+  /* Decrypt in-place. */
+  size_t len;
+  if (!EVP_AEAD_CTX_open(&aead->ctx, in, &len, in_len, nonce, nonce_len,
+                         in, in_len, ad, ad_len)) {
+    return 0;
+  }
+  CBS_init(out, in, len);
+  return 1;
 }
 
 int SSL_AEAD_CTX_seal(SSL_AEAD_CTX *aead, uint8_t *out, size_t *out_len,
diff --git a/ssl/tls_record.c b/ssl/tls_record.c
index 869831c..24dfb21 100644
--- a/ssl/tls_record.c
+++ b/ssl/tls_record.c
@@ -192,10 +192,9 @@
   return ret;
 }
 
-enum ssl_open_record_t tls_open_record(
-    SSL *ssl, uint8_t *out_type, uint8_t *out, size_t *out_len,
-    size_t *out_consumed, uint8_t *out_alert, size_t max_out, const uint8_t *in,
-    size_t in_len) {
+enum ssl_open_record_t tls_open_record(SSL *ssl, uint8_t *out_type, CBS *out,
+                                       size_t *out_consumed, uint8_t *out_alert,
+                                       uint8_t *in, size_t in_len) {
   *out_consumed = 0;
 
   CBS cbs;
@@ -236,10 +235,9 @@
   ssl_do_msg_callback(ssl, 0 /* read */, 0, SSL3_RT_HEADER, in,
                       SSL3_RT_HEADER_LENGTH);
 
-  /* Decrypt the body. */
-  size_t plaintext_len;
-  if (!SSL_AEAD_CTX_open(ssl->s3->aead_read_ctx, out, &plaintext_len, max_out,
-                         type, version, ssl->s3->read_sequence, CBS_data(&body),
+  /* Decrypt the body in-place. */
+  if (!SSL_AEAD_CTX_open(ssl->s3->aead_read_ctx, out, type, version,
+                         ssl->s3->read_sequence, (uint8_t *)CBS_data(&body),
                          CBS_len(&body))) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_DECRYPTION_FAILED_OR_BAD_RECORD_MAC);
     *out_alert = SSL_AD_BAD_RECORD_MAC;
@@ -256,28 +254,24 @@
   if (ssl->s3->have_version &&
       ssl3_protocol_version(ssl) >= TLS1_3_VERSION &&
       ssl->s3->aead_read_ctx != NULL) {
-    while (plaintext_len != 0 && out[plaintext_len - 1] == 0) {
-      plaintext_len--;
-    }
-
-    if (plaintext_len == 0) {
-      OPENSSL_PUT_ERROR(SSL, SSL_R_DECRYPTION_FAILED_OR_BAD_RECORD_MAC);
-      *out_alert = SSL_AD_DECRYPT_ERROR;
-      return ssl_open_record_error;
-    }
-    type = out[plaintext_len - 1];
-    plaintext_len--;
+    do {
+      if (!CBS_get_last_u8(out, &type)) {
+        OPENSSL_PUT_ERROR(SSL, SSL_R_DECRYPTION_FAILED_OR_BAD_RECORD_MAC);
+        *out_alert = SSL_AD_DECRYPT_ERROR;
+        return ssl_open_record_error;
+      }
+    } while (type == 0);
   }
 
   /* Check the plaintext length. */
-  if (plaintext_len > SSL3_RT_MAX_PLAIN_LENGTH) {
+  if (CBS_len(out) > SSL3_RT_MAX_PLAIN_LENGTH) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_DATA_LENGTH_TOO_LONG);
     *out_alert = SSL_AD_RECORD_OVERFLOW;
     return ssl_open_record_error;
   }
 
   /* Limit the number of consecutive empty records. */
-  if (plaintext_len == 0) {
+  if (CBS_len(out) == 0) {
     ssl->s3->empty_record_count++;
     if (ssl->s3->empty_record_count > kMaxEmptyRecords) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_TOO_MANY_EMPTY_FRAGMENTS);
@@ -291,13 +285,12 @@
   }
 
   if (type == SSL3_RT_ALERT) {
-    return ssl_process_alert(ssl, out_alert, out, plaintext_len);
+    return ssl_process_alert(ssl, out_alert, CBS_data(out), CBS_len(out));
   }
 
   ssl->s3->warning_alert_count = 0;
 
   *out_type = type;
-  *out_len = plaintext_len;
   return ssl_open_record_success;
 }