Factor SSL_AEAD_CTX into a dedicated type.

tls1_enc is replaced by SSL_AEAD_CTX_{open,seal}. This begins tidying up
the record-layer logic. It removes rr->input, since encrypting and
decrypting records no longer pass data through fields of the SSL object.
It also removes wrec altogether: SSL3_RECORD is now only used to track
the state of the current incoming record, and outgoing records are
written straight into the write buffer.

This also removes the outgoing alignment memcpy: the caller's plaintext
is passed directly to SSL_AEAD_CTX_seal, which writes the sealed record
into the write buffer. In bssl speed tests, this appears to be faster on
non-ARM platforms and roughly neutral on ARM.

Later it may be worth recasting these open/seal functions to write into
a CBB (tweaked so it can be malloc-averse), but for now they take an
out/out_len/max_out trio like their EVP_AEAD counterparts.

BUG=468889

Change-Id: Ie9266a818cc053f695d35ef611fd74c5d4def6c3
Reviewed-on: https://boringssl-review.googlesource.com/4792
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/d1_pkt.c b/ssl/d1_pkt.c
index 9e056ac..e80e773 100644
--- a/ssl/d1_pkt.c
+++ b/ssl/d1_pkt.c
@@ -189,21 +189,7 @@
 
 static int dtls1_process_record(SSL *s) {
   int al;
-  SSL3_RECORD *rr;
-
-  rr = &(s->s3->rrec);
-
-  /* At this point, s->packet_length == SSL3_RT_HEADER_LNGTH + rr->length, and
-   * we have that many bytes in s->packet. */
-  rr->input = &(s->packet[DTLS1_RT_HEADER_LENGTH]);
-
-  /* ok, we can now read from 's->packet' data into 'rr' rr->input points at
-   * rr->length bytes, which need to be copied into rr->data by either the
-   * decryption or by the decompression When the data is 'copied' into the
-   * rr->data buffer, rr->input will be pointed at the new buffer */
-
-  /* We now have - encrypted [ MAC [ compressed [ plain ] ] ] rr->length bytes
-   * of encrypted compressed stuff. */
+  SSL3_RECORD *rr = &s->s3->rrec;
 
   /* check is not needed I believe */
   if (rr->length > SSL3_RT_MAX_ENCRYPTED_LENGTH) {
@@ -213,10 +199,23 @@
     goto f_err;
   }
 
-  /* decrypt in place in 'rr->input' */
-  rr->data = rr->input;
+  /* |rr->data| points to |rr->length| bytes of ciphertext in |s->packet|. */
+  rr->data = &s->packet[DTLS1_RT_HEADER_LENGTH];
 
-  if (!s->enc_method->enc(s, 0)) {
+  uint8_t seq[8];
+  seq[0] = rr->epoch >> 8;
+  seq[1] = rr->epoch & 0xff;
+  memcpy(&seq[2], &rr->seq_num[2], 6);
+
+  /* Decrypt the packet in-place. Note it is important that |SSL_AEAD_CTX_open|
+   * not write beyond |rr->length|. There may be another record in the packet.
+   *
+   * TODO(davidben): This assumes |s->version| is the same as the record-layer
+   * version which isn't always true, but it only differs with the NULL cipher
+   * which ignores the parameter. */
+  size_t plaintext_len;
+  if (!SSL_AEAD_CTX_open(s->aead_read_ctx, rr->data, &plaintext_len, rr->length,
+                         rr->type, s->version, seq, rr->data, rr->length)) {
     /* Bad packets are silently dropped in DTLS. Clear the error queue of any
      * errors decryption may have added. */
     ERR_clear_error();
@@ -225,19 +224,20 @@
     goto err;
   }
 
-  if (rr->length > SSL3_RT_MAX_PLAIN_LENGTH) {
+  if (plaintext_len > SSL3_RT_MAX_PLAIN_LENGTH) {
     al = SSL_AD_RECORD_OVERFLOW;
     OPENSSL_PUT_ERROR(SSL, dtls1_process_record, SSL_R_DATA_LENGTH_TOO_LONG);
     goto f_err;
   }
+  assert(plaintext_len < (1u << 16));
+  rr->length = plaintext_len;
 
   rr->off = 0;
   /* So at this point the following is true
    * ssl->s3->rrec.type 	is the type of record
    * ssl->s3->rrec.length	== number of bytes in record
    * ssl->s3->rrec.off	== offset to first valid byte
-   * ssl->s3->rrec.data	== where to take bytes from, increment
-   *			   after use :-). */
+   * ssl->s3->rrec.data	== the first byte of the record body. */
 
   /* we have pulled in a full packet so zero things */
   s->packet_length = 0;
@@ -260,11 +260,11 @@
  *
  * used only by dtls1_read_bytes */
 int dtls1_get_record(SSL *s) {
-  int ssl_major, ssl_minor;
+  uint8_t ssl_major, ssl_minor;
   int i, n;
   SSL3_RECORD *rr;
-  unsigned char *p = NULL;
-  unsigned short version;
+  uint8_t *p = NULL;
+  uint16_t version;
 
   rr = &(s->s3->rrec);
 
@@ -298,7 +298,7 @@
     rr->type = *(p++);
     ssl_major = *(p++);
     ssl_minor = *(p++);
-    version = (ssl_major << 8) | ssl_minor;
+    version = (((uint16_t)ssl_major) << 8) | ssl_minor;
 
     /* sequence number is 64 bits, with top 2 bytes = epoch */
     n2s(p, rr->epoch);
@@ -710,14 +710,56 @@
   return i;
 }
 
+/* dtls1_seal_record writes a DTLS record header for type |type| to |out|,
+ * then the |in_len| bytes of |in| sealed under |s->aead_write_ctx|, writing at
+ * most |max_out| bytes and setting |*out_len| to the total. It returns one on
+ * success and zero on error. A successful seal advances |write_sequence|. */
+static int dtls1_seal_record(SSL *s, uint8_t *out, size_t *out_len,
+                             size_t max_out, uint8_t type, const uint8_t *in,
+                             size_t in_len) {
+  if (max_out < DTLS1_RT_HEADER_LENGTH) {
+    OPENSSL_PUT_ERROR(SSL, dtls1_seal_record, SSL_R_BUFFER_TOO_SMALL);
+    return 0;
+  }
+
+  out[0] = type; /* Header: type, version, epoch, sequence, length. */
+  /* Before negotiation use DTLS 1.0, as some clients ignore other versions. */
+  uint16_t wire_version = s->s3->have_version ? s->version : DTLS1_VERSION;
+  out[1] = wire_version >> 8;
+  out[2] = wire_version & 0xff;
+
+  out[3] = s->d1->w_epoch >> 8; /* big-endian epoch */
+  out[4] = s->d1->w_epoch & 0xff;
+  memcpy(&out[5], &s->s3->write_sequence[2], 6); /* 48-bit sequence number */
+  /* |out[3..10]| (epoch || sequence) is also the 8-byte |seq| for sealing. */
+  size_t ciphertext_len;
+  if (!SSL_AEAD_CTX_seal(s->aead_write_ctx, out + DTLS1_RT_HEADER_LENGTH,
+                         &ciphertext_len, max_out - DTLS1_RT_HEADER_LENGTH,
+                         type, wire_version, &out[3] /* seq */, in, in_len) ||
+      !ssl3_record_sequence_update(&s->s3->write_sequence[2], 6)) {
+    return 0;
+  }
+
+  if (ciphertext_len >= 1 << 16) { /* must fit the 2-byte length field */
+    OPENSSL_PUT_ERROR(SSL, dtls1_seal_record, ERR_R_OVERFLOW);
+    return 0; /* |write_sequence| has already been advanced. */
+  }
+  out[11] = ciphertext_len >> 8; /* big-endian record length */
+  out[12] = ciphertext_len & 0xff;
+
+  *out_len = DTLS1_RT_HEADER_LENGTH + ciphertext_len;
+
+  if (s->msg_callback) { /* reports only the record header */
+    s->msg_callback(1 /* write */, 0, SSL3_RT_HEADER, out,
+                    DTLS1_RT_HEADER_LENGTH, s, s->msg_callback_arg);
+  }
+
+  return 1;
+}
+
 static int do_dtls1_write(SSL *s, int type, const uint8_t *buf,
                           unsigned int len) {
-  uint8_t *p, *pseq;
-  int i;
-  int prefix_len = 0;
-  int eivlen = 0;
-  SSL3_RECORD *wr;
-  SSL3_BUFFER *wb;
+  SSL3_BUFFER *wb = &s->s3->wbuf;
 
   /* ssl3_write_pending drops the write if |BIO_write| fails in DTLS, so there
    * is never pending data. */
@@ -725,88 +767,35 @@
 
   /* If we have an alert to send, lets send it */
   if (s->s3->alert_dispatch) {
-    i = s->method->ssl_dispatch_alert(s);
-    if (i <= 0) {
-      return i;
+    int ret = s->method->ssl_dispatch_alert(s);
+    if (ret <= 0) {
+      return ret;
     }
     /* if it went, fall through and send more stuff */
   }
 
+  if (wb->buf == NULL && !ssl3_setup_write_buffer(s)) {
+    return -1;
+  }
+
   if (len == 0) {
     return 0;
   }
 
-  wr = &(s->s3->wrec);
-  wb = &(s->s3->wbuf);
+  /* Align the output so the ciphertext is aligned to |SSL3_ALIGN_PAYLOAD|. */
+  uintptr_t align = (uintptr_t)wb->buf + DTLS1_RT_HEADER_LENGTH;
+  align = (-align) & (SSL3_ALIGN_PAYLOAD - 1);
+  uint8_t *out = wb->buf + align;
+  wb->offset = align;
+  size_t max_out = wb->len - wb->offset;
 
-  if (wb->buf == NULL && !ssl3_setup_write_buffer(s)) {
+  size_t ciphertext_len;
+  if (!dtls1_seal_record(s, out, &ciphertext_len, max_out, type, buf, len)) {
     return -1;
   }
-  p = wb->buf + prefix_len;
-
-  /* write the header */
-
-  *(p++) = type & 0xff;
-  wr->type = type;
-  /* Special case: for hello verify request, client version 1.0 and
-   * we haven't decided which version to use yet send back using
-   * version 1.0 header: otherwise some clients will ignore it.
-   */
-  if (!s->s3->have_version) {
-    *(p++) = DTLS1_VERSION >> 8;
-    *(p++) = DTLS1_VERSION & 0xff;
-  } else {
-    *(p++) = s->version >> 8;
-    *(p++) = s->version & 0xff;
-  }
-
-  /* field where we are to write out packet epoch, seq num and len */
-  pseq = p;
-  p += 10;
-
-  /* Leave room for the variable nonce for AEADs which specify it explicitly. */
-  if (s->aead_write_ctx != NULL &&
-      s->aead_write_ctx->variable_nonce_included_in_record) {
-    eivlen = s->aead_write_ctx->variable_nonce_len;
-  }
-
-  /* Assemble the input for |s->enc_method->enc|. The input is the plaintext
-   * with |eivlen| bytes of space prepended for the explicit nonce. */
-  wr->input = p;
-  wr->length = eivlen + len;
-  memcpy(p + eivlen, buf, len);
-
-  /* Encrypt in-place, so the output also goes into |p|. */
-  wr->data = p;
-
-  if (!s->enc_method->enc(s, 1)) {
-    goto err;
-  }
-
-  /* there's only one epoch between handshake and app data */
-  s2n(s->d1->w_epoch, pseq);
-
-  memcpy(pseq, &(s->s3->write_sequence[2]), 6);
-  pseq += 6;
-  s2n(wr->length, pseq);
-
-  if (s->msg_callback) {
-    s->msg_callback(1, 0, SSL3_RT_HEADER, pseq - DTLS1_RT_HEADER_LENGTH,
-                    DTLS1_RT_HEADER_LENGTH, s, s->msg_callback_arg);
-  }
-
-  /* we should now have wr->data pointing to the encrypted data, which is
-   * wr->length long */
-  wr->type = type; /* not needed but helps for debugging */
-  wr->length += DTLS1_RT_HEADER_LENGTH;
-
-  if (!ssl3_record_sequence_update(&s->s3->write_sequence[2], 6)) {
-    goto err;
-  }
 
   /* now let's set up wb */
-  wb->left = prefix_len + wr->length;
-  wb->offset = 0;
+  wb->left = ciphertext_len;
 
   /* memorize arguments so that ssl3_write_pending can detect bad write retries
    * later */
@@ -817,9 +806,6 @@
 
   /* we now just need to write the buffer */
   return ssl3_write_pending(s, type, buf, len);
-
-err:
-  return -1;
 }
 
 static int dtls1_record_replay_check(SSL *s, DTLS1_BITMAP *bitmap) {