Add SSL_AEAD_CTX_seal_scatter.

This plumbs EVP_AEAD_CTX_seal_scatter all the way through to
tls_record.c, so we can add a new zero-copy record sealing method on top
of the existing code.

Change-Id: I01fdd88abef5442dc16605ea31b29b4b1231c073
Reviewed-on: https://boringssl-review.googlesource.com/17684
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/internal.h b/ssl/internal.h
index 7a7c9ef..6b88070 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -447,6 +447,13 @@
  * |SSL_AEAD_CTX_seal|. |ctx| may be NULL to denote the null cipher. */
 size_t SSL_AEAD_CTX_max_overhead(const SSL_AEAD_CTX *ctx);
 
+/* SSL_AEAD_CTX_max_suffix_len returns the maximum suffix length written by
+ * |SSL_AEAD_CTX_seal_scatter|. |ctx| may be NULL to denote the null cipher.
+ * |extra_in_len| should equal the argument of the same name passed to
+ * |SSL_AEAD_CTX_seal_scatter|. */
+size_t SSL_AEAD_CTX_max_suffix_len(const SSL_AEAD_CTX *ctx,
+                                   size_t extra_in_len);
+
 /* SSL_AEAD_CTX_open authenticates and decrypts |in_len| bytes from |in|
  * in-place. On success, it sets |*out| to the plaintext in |in| and returns
  * one. Otherwise, it returns zero. |ctx| may be NULL to denote the null cipher.
@@ -465,6 +472,31 @@
                       const uint8_t seqnum[8], const uint8_t *in,
                       size_t in_len);
 
+/* SSL_AEAD_CTX_seal_scatter encrypts and authenticates |in_len| bytes from |in|
+ * and splits the result between |out_prefix|, |out| and |out_suffix|. It
+ * returns one on success and zero on error. |ctx| may be NULL to denote the
+ * null cipher.
+ *
+ * On successful return, exactly |SSL_AEAD_CTX_explicit_nonce_len| bytes are
+ * written to |out_prefix|, |in_len| bytes to |out|, and up to
+ * |SSL_AEAD_CTX_max_suffix_len| bytes to |out_suffix|. |*out_suffix_len| is set
+ * to the actual number of bytes written to |out_suffix|.
+ *
+ * |extra_in| may point to an additional plaintext buffer. If present,
+ * |extra_in_len| additional bytes are encrypted and authenticated, and the
+ * ciphertext is written to the beginning of |out_suffix|.
+ * |SSL_AEAD_CTX_max_suffix_len| may be used to size |out_suffix| accordingly.
+ *
+ * If |in| and |out| alias then |out| must be == |in|. Other arguments may not
+ * alias anything. */
+int SSL_AEAD_CTX_seal_scatter(SSL_AEAD_CTX *aead, uint8_t *out_prefix,
+                              uint8_t *out, uint8_t *out_suffix,
+                              size_t *out_suffix_len, size_t max_out_suffix_len,
+                              uint8_t type, uint16_t wire_version,
+                              const uint8_t seqnum[8], const uint8_t *in,
+                              size_t in_len, const uint8_t *extra_in,
+                              size_t extra_in_len);
+
 
 /* DTLS replay bitmap. */
 
diff --git a/ssl/ssl_aead_ctx.cc b/ssl/ssl_aead_ctx.cc
index 0cdf717..b78b06b 100644
--- a/ssl/ssl_aead_ctx.cc
+++ b/ssl/ssl_aead_ctx.cc
@@ -140,16 +140,19 @@
   return 0;
 }
 
-size_t SSL_AEAD_CTX_max_overhead(const SSL_AEAD_CTX *aead) {
+size_t SSL_AEAD_CTX_max_suffix_len(const SSL_AEAD_CTX *aead,
+                                   size_t extra_in_len) {
 #if defined(BORINGSSL_UNSAFE_FUZZER_MODE)
   aead = NULL;
 #endif
 
-  if (aead == NULL) {
-    return 0;
-  }
-  return EVP_AEAD_max_overhead(aead->ctx.aead) +
-         SSL_AEAD_CTX_explicit_nonce_len(aead);
+  return extra_in_len +
+         (aead == NULL ? 0 : EVP_AEAD_max_overhead(aead->ctx.aead));
+}
+
+size_t SSL_AEAD_CTX_max_overhead(const SSL_AEAD_CTX *aead) {
+  return SSL_AEAD_CTX_explicit_nonce_len(aead) +
+         SSL_AEAD_CTX_max_suffix_len(aead, 0);
 }
 
 /* ssl_aead_ctx_get_ad writes the additional data for |aead| into |out| and
@@ -252,22 +255,32 @@
   return 1;
 }
 
-int SSL_AEAD_CTX_seal(SSL_AEAD_CTX *aead, uint8_t *out, size_t *out_len,
-                      size_t max_out, uint8_t type, uint16_t wire_version,
-                      const uint8_t seqnum[8], const uint8_t *in,
-                      size_t in_len) {
+int SSL_AEAD_CTX_seal_scatter(SSL_AEAD_CTX *aead, uint8_t *out_prefix,
+                              uint8_t *out, uint8_t *out_suffix,
+                              size_t *out_suffix_len, size_t max_out_suffix_len,
+                              uint8_t type, uint16_t wire_version,
+                              const uint8_t seqnum[8], const uint8_t *in,
+                              size_t in_len, const uint8_t *extra_in,
+                              size_t extra_in_len) {
 #if defined(BORINGSSL_UNSAFE_FUZZER_MODE)
   aead = NULL;
 #endif
 
+  if ((in != out && buffers_alias(in, in_len, out, in_len)) ||
+      buffers_alias(in, in_len, out_suffix, max_out_suffix_len)) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_OUTPUT_ALIASES_INPUT);
+    return 0;
+  }
+  if (extra_in_len > max_out_suffix_len) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_BUFFER_TOO_SMALL);
+    return 0;
+  }
+
   if (aead == NULL) {
     /* Handle the initial NULL cipher. */
-    if (in_len > max_out) {
-      OPENSSL_PUT_ERROR(SSL, SSL_R_BUFFER_TOO_SMALL);
-      return 0;
-    }
     OPENSSL_memmove(out, in, in_len);
-    *out_len = in_len;
+    OPENSSL_memmove(out_suffix, extra_in, extra_in_len);
+    *out_suffix_len = extra_in_len;
     return 1;
   }
 
@@ -303,22 +316,14 @@
   nonce_len += aead->variable_nonce_len;
 
   /* Emit the variable nonce if included in the record. */
-  size_t extra_len = 0;
   if (aead->variable_nonce_included_in_record) {
     assert(!aead->xor_fixed_nonce);
-    if (max_out < aead->variable_nonce_len) {
-      OPENSSL_PUT_ERROR(SSL, SSL_R_BUFFER_TOO_SMALL);
-      return 0;
-    }
-    if (out < in + in_len && in < out + aead->variable_nonce_len) {
+    if (buffers_alias(in, in_len, out_prefix, aead->variable_nonce_len)) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_OUTPUT_ALIASES_INPUT);
       return 0;
     }
-    OPENSSL_memcpy(out, nonce + aead->fixed_nonce_len,
+    OPENSSL_memcpy(out_prefix, nonce + aead->fixed_nonce_len,
                    aead->variable_nonce_len);
-    extra_len = aead->variable_nonce_len;
-    out += aead->variable_nonce_len;
-    max_out -= aead->variable_nonce_len;
   }
 
   /* XOR the fixed nonce, if necessary. */
@@ -329,10 +334,33 @@
     }
   }
 
-  if (!EVP_AEAD_CTX_seal(&aead->ctx, out, out_len, max_out, nonce, nonce_len,
-                         in, in_len, ad, ad_len)) {
+  return EVP_AEAD_CTX_seal_scatter(&aead->ctx, out, out_suffix, out_suffix_len,
+                                   max_out_suffix_len, nonce, nonce_len, in,
+                                   in_len, extra_in, extra_in_len, ad, ad_len);
+}
+
+int SSL_AEAD_CTX_seal(SSL_AEAD_CTX *aead, uint8_t *out, size_t *out_len,
+                      size_t max_out_len, uint8_t type, uint16_t wire_version,
+                      const uint8_t seqnum[8], const uint8_t *in,
+                      size_t in_len) {
+  size_t prefix_len = SSL_AEAD_CTX_explicit_nonce_len(aead);
+  if (in_len + prefix_len < in_len) {
+    OPENSSL_PUT_ERROR(CIPHER, SSL_R_RECORD_TOO_LARGE);
     return 0;
   }
-  *out_len += extra_len;
+  if (in_len + prefix_len > max_out_len) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_BUFFER_TOO_SMALL);
+    return 0;
+  }
+
+  size_t suffix_len;
+  if (!SSL_AEAD_CTX_seal_scatter(aead, out, out + prefix_len,
+                                 out + prefix_len + in_len, &suffix_len,
+                                 max_out_len - prefix_len - in_len, type,
+                                 wire_version, seqnum, in, in_len, 0, 0)) {
+    return 0;
+  }
+  assert(suffix_len <= SSL_AEAD_CTX_max_suffix_len(aead, 0));
+  *out_len = prefix_len + in_len + suffix_len;
   return 1;
 }
diff --git a/ssl/tls_record.cc b/ssl/tls_record.cc
index 3bc0b29..28ffb4e 100644
--- a/ssl/tls_record.cc
+++ b/ssl/tls_record.cc
@@ -358,30 +358,24 @@
   return ssl_open_record_discard;
 }
 
-static int do_seal_record(SSL *ssl, uint8_t *out, size_t *out_len,
-                          size_t max_out, uint8_t type, const uint8_t *in,
-                          size_t in_len) {
-  assert(!buffers_alias(in, in_len, out, max_out));
+static int do_seal_record(SSL *ssl, uint8_t *out_prefix, uint8_t *out,
+                          uint8_t *out_suffix, size_t *out_suffix_len,
+                          const size_t max_out_suffix_len, uint8_t type,
+                          const uint8_t *in, const size_t in_len) {
+  assert(in == out || !buffers_alias(in, in_len, out, in_len));
+  assert(!buffers_alias(in, in_len, out_prefix, ssl_record_prefix_len(ssl)));
+  assert(!buffers_alias(in, in_len, out_suffix, max_out_suffix_len));
 
   /* TLS 1.3 hides the actual record type inside the encrypted data. */
+  uint8_t *extra_in = NULL;
+  size_t extra_in_len = 0;
   if (ssl->s3->aead_write_ctx != NULL &&
       ssl->s3->aead_write_ctx->version >= TLS1_3_VERSION) {
-    if (in_len > in_len + SSL3_RT_HEADER_LENGTH + 1 ||
-        max_out < in_len + SSL3_RT_HEADER_LENGTH + 1) {
-      OPENSSL_PUT_ERROR(SSL, SSL_R_BUFFER_TOO_SMALL);
-      return 0;
-    }
-
-    OPENSSL_memcpy(out + SSL3_RT_HEADER_LENGTH, in, in_len);
-    out[SSL3_RT_HEADER_LENGTH + in_len] = type;
-    in = out + SSL3_RT_HEADER_LENGTH;
-    type = SSL3_RT_APPLICATION_DATA;
-    in_len++;
-  }
-
-  if (max_out < SSL3_RT_HEADER_LENGTH) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_BUFFER_TOO_SMALL);
-    return 0;
+    extra_in = &type;
+    extra_in_len = 1;
+    out_prefix[0] = SSL3_RT_APPLICATION_DATA;
+  } else {
+    out_prefix[0] = type;
   }
 
   /* The TLS record-layer version number is meaningless and, starting in
@@ -395,65 +389,144 @@
   if (ssl->s3->have_version && ssl3_protocol_version(ssl) < TLS1_3_VERSION) {
     wire_version = ssl->version;
   }
-
-  /* Write the non-length portions of the header. */
-  out[0] = type;
-  out[1] = wire_version >> 8;
-  out[2] = wire_version & 0xff;
+  out_prefix[1] = wire_version >> 8;
+  out_prefix[2] = wire_version & 0xff;
 
   /* Write the ciphertext, leaving two bytes for the length. */
-  size_t ciphertext_len;
-  if (!SSL_AEAD_CTX_seal(ssl->s3->aead_write_ctx, out + SSL3_RT_HEADER_LENGTH,
-                         &ciphertext_len, max_out - SSL3_RT_HEADER_LENGTH, type,
-                         wire_version, ssl->s3->write_sequence, in, in_len) ||
+  if (!SSL_AEAD_CTX_seal_scatter(
+          ssl->s3->aead_write_ctx, out_prefix + SSL3_RT_HEADER_LENGTH, out,
+          out_suffix, out_suffix_len, max_out_suffix_len, type, wire_version,
+          ssl->s3->write_sequence, in, in_len, extra_in, extra_in_len) ||
       !ssl_record_sequence_update(ssl->s3->write_sequence, 8)) {
     return 0;
   }
 
   /* Fill in the length. */
+  const size_t ciphertext_len =
+      SSL_AEAD_CTX_explicit_nonce_len(ssl->s3->aead_write_ctx) + in_len +
+      *out_suffix_len;
   if (ciphertext_len >= 1 << 15) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
     return 0;
   }
-  out[3] = ciphertext_len >> 8;
-  out[4] = ciphertext_len & 0xff;
+  out_prefix[3] = ciphertext_len >> 8;
+  out_prefix[4] = ciphertext_len & 0xff;
 
-  *out_len = SSL3_RT_HEADER_LENGTH + ciphertext_len;
-
-  ssl_do_msg_callback(ssl, 1 /* write */, SSL3_RT_HEADER, out,
+  ssl_do_msg_callback(ssl, 1 /* write */, SSL3_RT_HEADER, out_prefix,
                       SSL3_RT_HEADER_LENGTH);
   return 1;
 }
 
-int tls_seal_record(SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out,
-                    uint8_t type, const uint8_t *in, size_t in_len) {
-  if (buffers_alias(in, in_len, out, max_out)) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_OUTPUT_ALIASES_INPUT);
-    return 0;
-  }
-
-  size_t frag_len = 0;
+static size_t tls_seal_scatter_prefix_len(const SSL *ssl, uint8_t type,
+                                          size_t in_len) {
+  size_t ret = SSL3_RT_HEADER_LENGTH;
   if (type == SSL3_RT_APPLICATION_DATA && in_len > 1 &&
       ssl_needs_record_splitting(ssl)) {
-    if (!do_seal_record(ssl, out, &frag_len, max_out, type, in, 1)) {
+    /* In the case of record splitting, the 1-byte record (of the 1/n-1 split)
+     * will be placed in the prefix, as will four of the five bytes of the
+     * record header for the main record. The final byte will replace the first
+     * byte of the plaintext that was used in the small record. */
+    ret += ssl_cipher_get_record_split_len(ssl->s3->aead_write_ctx->cipher);
+    ret += SSL3_RT_HEADER_LENGTH - 1;
+  } else {
+    ret += SSL_AEAD_CTX_explicit_nonce_len(ssl->s3->aead_write_ctx);
+  }
+  return ret;
+}
+
+/* tls_seal_scatter_record seals a new record of type |type| and body |in| and
+ * splits it between |out_prefix|, |out|, and |out_suffix|. Exactly
+ * |tls_seal_scatter_prefix_len| bytes are written to |out_prefix|, |in_len|
+ * bytes to |out|, and up to 1 + |SSL_AEAD_CTX_max_overhead| bytes to
+ * |out_suffix|. |*out_suffix_len| is set to the actual number of bytes written
+ * to |out_suffix|. It returns one on success and zero on error. If enabled,
+ * |tls_seal_scatter_record| implements TLS 1.0 CBC 1/n-1 record splitting and
+ * may write two records concatenated. */
+static int tls_seal_scatter_record(SSL *ssl, uint8_t *out_prefix, uint8_t *out,
+                                   uint8_t *out_suffix, size_t *out_suffix_len,
+                                   size_t max_out_suffix_len, uint8_t type,
+                                   const uint8_t *in, size_t in_len) {
+  if (type == SSL3_RT_APPLICATION_DATA && in_len > 1 &&
+      ssl_needs_record_splitting(ssl)) {
+    assert(SSL_AEAD_CTX_explicit_nonce_len(ssl->s3->aead_write_ctx) == 0);
+    const size_t prefix_len = SSL3_RT_HEADER_LENGTH;
+
+    /* Write the 1-byte fragment into |out_prefix|. */
+    uint8_t *split_body = out_prefix + prefix_len;
+    uint8_t *split_suffix = split_body + 1;
+
+    /* TODO(martinkr): Make AEAD code not complain if max_suffix_len is lower
+     * than |EVP_AEAD_max_overhead| but still sufficiently large. */
+    size_t split_max_suffix_len =
+        SSL_AEAD_CTX_max_suffix_len(ssl->s3->aead_write_ctx, 0);
+    size_t split_suffix_len = 0;
+    if (!do_seal_record(ssl, out_prefix, split_body, split_suffix,
+                        &split_suffix_len, split_max_suffix_len, type, in, 1)) {
       return 0;
     }
-    in++;
-    in_len--;
-    out += frag_len;
-    max_out -= frag_len;
+
+    size_t split_record_len = prefix_len + 1 + split_suffix_len;
 
 #if !defined(BORINGSSL_UNSAFE_FUZZER_MODE)
     assert(SSL3_RT_HEADER_LENGTH + ssl_cipher_get_record_split_len(
                                        ssl->s3->aead_write_ctx->cipher) ==
-           frag_len);
+           split_record_len);
 #endif
+
+    /* Write the n-1-byte fragment. The header gets split between |out_prefix|
+     * (header[:-1]) and |out| (header[-1:]). */
+    uint8_t tmp_prefix[SSL3_RT_HEADER_LENGTH];
+    if (!do_seal_record(ssl, tmp_prefix, out + 1, out_suffix, out_suffix_len,
+                        max_out_suffix_len, type, in + 1, in_len - 1)) {
+      return 0;
+    }
+    assert(tls_seal_scatter_prefix_len(ssl, type, in_len) ==
+           split_record_len + SSL3_RT_HEADER_LENGTH - 1);
+    OPENSSL_memcpy(out_prefix + split_record_len, tmp_prefix,
+                   SSL3_RT_HEADER_LENGTH - 1);
+    OPENSSL_memcpy(out, tmp_prefix + SSL3_RT_HEADER_LENGTH - 1, 1);
+    return 1;
   }
 
-  if (!do_seal_record(ssl, out, out_len, max_out, type, in, in_len)) {
+  return do_seal_record(ssl, out_prefix, out, out_suffix, out_suffix_len,
+                        max_out_suffix_len, type, in, in_len);
+}
+
+int tls_seal_record(SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out_len,
+                    uint8_t type, const uint8_t *in, size_t in_len) {
+  if (buffers_alias(in, in_len, out, max_out_len)) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_OUTPUT_ALIASES_INPUT);
     return 0;
   }
-  *out_len += frag_len;
+
+  const size_t prefix_len = tls_seal_scatter_prefix_len(ssl, type, in_len);
+
+  if (in_len + prefix_len < in_len) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_RECORD_TOO_LARGE);
+    return 0;
+  }
+  if (max_out_len < in_len + prefix_len) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_BUFFER_TOO_SMALL);
+    return 0;
+  }
+
+  uint8_t *prefix = out;
+  uint8_t *body = out + prefix_len;
+  uint8_t *suffix = body + in_len;
+  size_t max_suffix_len = max_out_len - prefix_len - in_len;
+  size_t suffix_len = 0;
+
+  if (!tls_seal_scatter_record(ssl, prefix, body, suffix, &suffix_len,
+                               max_suffix_len, type, in, in_len)) {
+    return 0;
+  }
+
+  if (prefix_len + in_len + suffix_len < prefix_len + in_len) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_RECORD_TOO_LARGE);
+    return 0;
+  }
+
+  *out_len = prefix_len + in_len + suffix_len;
   return 1;
 }