EVP_AEAD: implement sealv/openv for ChaCha20-Poly1305.

Note that for now, the asm optimizations are only used when sealing up
to 2 plaintext segments or opening up to 1 ciphertext segment, and in
both cases only with at most 1 AAD segment.

Performance note: this change seems to cost 14 to 40 CPU cycles per call
with LTO, and 45 to 51 CPU cycles per call without LTO; there is no
measurable per-byte overhead, though.

Bug: 383343306
Change-Id: I8ebbba3280ac51c200e463b6f71e6e3ed028d450
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/84507
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: Rudolf Polzer <rpolzer@google.com>
Auto-Submit: Rudolf Polzer <rpolzer@google.com>
diff --git a/crypto/cipher/e_chacha20poly1305.cc b/crypto/cipher/e_chacha20poly1305.cc
index 97c6b79..77793fe 100644
--- a/crypto/cipher/e_chacha20poly1305.cc
+++ b/crypto/cipher/e_chacha20poly1305.cc
@@ -22,11 +22,12 @@
 #include <openssl/err.h>
 #include <openssl/mem.h>
 #include <openssl/poly1305.h>
+#include <openssl/span.h>
 
-#include "internal.h"
 #include "../chacha/internal.h"
 #include "../fipsmodule/cipher/internal.h"
 #include "../internal.h"
+#include "internal.h"
 
 struct aead_chacha20_poly1305_ctx {
   uint8_t key[32];
@@ -76,46 +77,46 @@
   CRYPTO_poly1305_update(poly1305, length_bytes, sizeof(length_bytes));
 }
 
-// calc_tag fills |tag| with the authentication tag for the given inputs.
-static void calc_tag(uint8_t tag[POLY1305_TAG_LEN], const uint8_t *key,
-                     const uint8_t nonce[12], const uint8_t *ad, size_t ad_len,
-                     const uint8_t *ciphertext, size_t ciphertext_len,
-                     const uint8_t *ciphertext_extra,
-                     size_t ciphertext_extra_len) {
+// calc_tag_pre initializes |ctx| with the Poly1305 key derived from |key| and
+// |nonce|, absorbs the zero-padded AAD, and returns the total AAD length.
+static size_t calc_tag_pre(poly1305_state *ctx, const uint8_t key[32],
+                           const uint8_t nonce[12],
+                           bssl::Span<const CRYPTO_IVEC> aadvecs) {
   alignas(16) uint8_t poly1305_key[32];
   OPENSSL_memset(poly1305_key, 0, sizeof(poly1305_key));
   CRYPTO_chacha_20(poly1305_key, poly1305_key, sizeof(poly1305_key), key, nonce,
                    0);
 
-  static const uint8_t padding[16] = { 0 };  // Padding is all zeros.
-  poly1305_state ctx;
-  CRYPTO_poly1305_init(&ctx, poly1305_key);
-  CRYPTO_poly1305_update(&ctx, ad, ad_len);
+  static const uint8_t padding[16] = {0};  // Padding is all zeros.
+  CRYPTO_poly1305_init(ctx, poly1305_key);
+  size_t ad_len = 0;
+  for (const CRYPTO_IVEC &aadvec : aadvecs) {
+    CRYPTO_poly1305_update(ctx, aadvec.in, aadvec.len);
+    ad_len += aadvec.len;
+  }
   if (ad_len % 16 != 0) {
-    CRYPTO_poly1305_update(&ctx, padding, sizeof(padding) - (ad_len % 16));
+    CRYPTO_poly1305_update(ctx, padding, sizeof(padding) - (ad_len % 16));
   }
-  CRYPTO_poly1305_update(&ctx, ciphertext, ciphertext_len);
-  CRYPTO_poly1305_update(&ctx, ciphertext_extra, ciphertext_extra_len);
-  const size_t ciphertext_total = ciphertext_len + ciphertext_extra_len;
-  if (ciphertext_total % 16 != 0) {
-    CRYPTO_poly1305_update(&ctx, padding,
-                           sizeof(padding) - (ciphertext_total % 16));
-  }
-  poly1305_update_length(&ctx, ad_len);
-  poly1305_update_length(&ctx, ciphertext_total);
-  CRYPTO_poly1305_finish(&ctx, tag);
+  return ad_len;
 }
 
-static int chacha20_poly1305_seal_scatter(
-    const uint8_t *key, uint8_t *out, uint8_t *out_tag,
-    size_t *out_tag_len, size_t max_out_tag_len, const uint8_t *nonce,
-    size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *extra_in,
-    size_t extra_in_len, const uint8_t *ad, size_t ad_len, size_t tag_len) {
-  if (extra_in_len + tag_len < tag_len) {
-    OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_TOO_LARGE);
-    return 0;
+static void calc_tag_post(poly1305_state *ctx, uint8_t tag[POLY1305_TAG_LEN],
+                          size_t ciphertext_total, size_t ad_len) {
+  static const uint8_t padding[16] = {0};  // Padding is all zeros.
+  if (ciphertext_total % 16 != 0) {
+    CRYPTO_poly1305_update(ctx, padding,
+                           sizeof(padding) - (ciphertext_total % 16));
   }
-  if (max_out_tag_len < tag_len + extra_in_len) {
+  poly1305_update_length(ctx, ad_len);
+  poly1305_update_length(ctx, ciphertext_total);
+  CRYPTO_poly1305_finish(ctx, tag);
+}
+
+static int chacha20_poly1305_sealv(
+    const uint8_t *key, bssl::Span<const CRYPTO_IOVEC> iovecs, uint8_t *out_tag,
+    size_t *out_tag_len, size_t max_out_tag_len, const uint8_t *nonce,
+    size_t nonce_len, bssl::Span<const CRYPTO_IVEC> aadvecs, size_t tag_len) {
+  if (max_out_tag_len < tag_len) {
     OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_BUFFER_TOO_SMALL);
     return 0;
   }
@@ -130,74 +131,107 @@
   // 32-bits and this produces a warning because it's always false.
   // Casting to uint64_t inside the conditional is not sufficient to stop
   // the warning.
-  const uint64_t in_len_64 = in_len;
+  const uint64_t in_len_64 = bssl::iovec::TotalLength(iovecs);
   if (in_len_64 >= (UINT64_C(1) << 32) * 64 - 64) {
     OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_TOO_LARGE);
     return 0;
   }
 
-  if (max_out_tag_len < tag_len) {
-    OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_BUFFER_TOO_SMALL);
-    return 0;
-  }
-
-  // The the extra input is given, it is expected to be very short and so is
-  // encrypted byte-by-byte first.
-  if (extra_in_len) {
-    static const size_t kChaChaBlockSize = 64;
-    uint32_t block_counter = (uint32_t)(1 + (in_len / kChaChaBlockSize));
-    size_t offset = in_len % kChaChaBlockSize;
-    uint8_t block[64 /* kChaChaBlockSize */];
-
-    for (size_t done = 0; done < extra_in_len; block_counter++) {
-      memset(block, 0, sizeof(block));
-      CRYPTO_chacha_20(block, block, sizeof(block), key, nonce,
-                       block_counter);
-      for (size_t i = offset; i < sizeof(block) && done < extra_in_len;
-           i++, done++) {
-        out_tag[done] = extra_in[done] ^ block[i];
-      }
-      offset = 0;
-    }
-  }
-
   union chacha20_poly1305_seal_data data;
-  if (chacha20_poly1305_asm_capable()) {
+  if (chacha20_poly1305_asm_capable() && iovecs.size() <= 2 &&
+      aadvecs.size() <= 1) {
     OPENSSL_memcpy(data.in.key, key, 32);
     data.in.counter = 0;
     OPENSSL_memcpy(data.in.nonce, nonce, 12);
-    data.in.extra_ciphertext = out_tag;
-    data.in.extra_ciphertext_len = extra_in_len;
-    chacha20_poly1305_seal(out, in, in_len, ad, ad_len, &data);
+    if (iovecs.size() >= 2) {
+      // |chacha20_poly1305_seal| only supports one extra input and expects it
+      // to have been encrypted ahead of time. (Historically it was only used
+      // for very short inputs.)
+      constexpr size_t kChaChaBlockSize = 64;
+      uint32_t block_counter =
+          (uint32_t)(1 + (iovecs[0].len / kChaChaBlockSize));
+      size_t offset = iovecs[0].len % kChaChaBlockSize;
+      size_t done = 0;
+      if (offset != 0) {
+        uint8_t block[kChaChaBlockSize];
+        memset(block, 0, sizeof(block));
+        CRYPTO_chacha_20(block, block, sizeof(block), key, nonce,
+                         block_counter);
+        for (size_t i = offset; i < sizeof(block) && done < iovecs[1].len;
+             i++, done++) {
+          iovecs[1].out[done] = iovecs[1].in[done] ^ block[i];
+        }
+        ++block_counter;
+      }
+      if (done < iovecs[1].len) {
+        CRYPTO_chacha_20(iovecs[1].out + done, iovecs[1].in + done,
+                         iovecs[1].len - done, key, nonce, block_counter);
+      }
+      // TODO(crbug.com/383343306): Support more than 1 extra ciphertext.
+      data.in.extra_ciphertext = iovecs[1].out;
+      data.in.extra_ciphertext_len = iovecs[1].len;
+    } else {
+      data.in.extra_ciphertext = nullptr;
+      data.in.extra_ciphertext_len = 0;
+    }
+    chacha20_poly1305_seal(iovecs.size() >= 1 ? iovecs[0].out : nullptr,
+                           iovecs.size() >= 1 ? iovecs[0].in : nullptr,
+                           iovecs.size() >= 1 ? iovecs[0].len : 0,
+                           aadvecs.size() >= 1 ? aadvecs[0].in : nullptr,
+                           aadvecs.size() >= 1 ? aadvecs[0].len : 0, &data);
   } else {
-    CRYPTO_chacha_20(out, in, in_len, key, nonce, 1);
-    calc_tag(data.out.tag, key, nonce, ad, ad_len, out, in_len, out_tag,
-             extra_in_len);
+    poly1305_state ctx;
+    size_t ad_len = calc_tag_pre(&ctx, key, nonce, aadvecs);
+
+    size_t ciphertext_total = 0;
+    size_t block = 1;
+    bssl::iovec::ForEachBlockRange<64, /*WriteOut=*/true>(
+        iovecs,
+        [&](const uint8_t *in, uint8_t *out, size_t len) {
+          // TODO(crbug.com/383343306): Maybe just provide asm version of this?
+          // Here, len is always a multiple of 64.
+          CRYPTO_chacha_20(out, in, len, key, nonce, block);
+          CRYPTO_poly1305_update(&ctx, out, len);
+          ciphertext_total += len;
+          block += len / 64;
+          return true;
+        },
+        [&](const uint8_t *in, uint8_t *out, size_t len) {
+          // Here, len may be anything. If an asm version can't handle that,
+          // it will be worth splitting off multiples of 64 here.
+          CRYPTO_chacha_20(out, in, len, key, nonce, block);
+          CRYPTO_poly1305_update(&ctx, out, len);
+          ciphertext_total += len;
+          return true;
+        });
+
+    calc_tag_post(&ctx, data.out.tag, ciphertext_total, ad_len);
   }
 
-  OPENSSL_memcpy(out_tag + extra_in_len, data.out.tag, tag_len);
-  *out_tag_len = extra_in_len + tag_len;
+  OPENSSL_memcpy(out_tag, data.out.tag, tag_len);
+  *out_tag_len = tag_len;
   return 1;
 }
 
-static int aead_chacha20_poly1305_seal_scatter(
-    const EVP_AEAD_CTX *ctx, uint8_t *out, uint8_t *out_tag,
-    size_t *out_tag_len, size_t max_out_tag_len, const uint8_t *nonce,
-    size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *extra_in,
-    size_t extra_in_len, const uint8_t *ad, size_t ad_len) {
+static int aead_chacha20_poly1305_sealv(const EVP_AEAD_CTX *ctx,
+                                        bssl::Span<const CRYPTO_IOVEC> iovecs,
+                                        uint8_t *out_tag, size_t *out_tag_len,
+                                        size_t max_out_tag_len,
+                                        const uint8_t *nonce, size_t nonce_len,
+                                        bssl::Span<const CRYPTO_IVEC> aadvecs) {
   const struct aead_chacha20_poly1305_ctx *c20_ctx =
       (struct aead_chacha20_poly1305_ctx *)&ctx->state;
 
-  return chacha20_poly1305_seal_scatter(
-      c20_ctx->key, out, out_tag, out_tag_len, max_out_tag_len, nonce,
-      nonce_len, in, in_len, extra_in, extra_in_len, ad, ad_len, ctx->tag_len);
+  return chacha20_poly1305_sealv(c20_ctx->key, iovecs, out_tag, out_tag_len,
+                                 max_out_tag_len, nonce, nonce_len, aadvecs,
+                                 ctx->tag_len);
 }
 
-static int aead_xchacha20_poly1305_seal_scatter(
-    const EVP_AEAD_CTX *ctx, uint8_t *out, uint8_t *out_tag,
-    size_t *out_tag_len, size_t max_out_tag_len, const uint8_t *nonce,
-    size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *extra_in,
-    size_t extra_in_len, const uint8_t *ad, size_t ad_len) {
+static int aead_xchacha20_poly1305_sealv(
+    const EVP_AEAD_CTX *ctx, bssl::Span<const CRYPTO_IOVEC> iovecs,
+    uint8_t *out_tag, size_t *out_tag_len, size_t max_out_tag_len,
+    const uint8_t *nonce, size_t nonce_len,
+    bssl::Span<const CRYPTO_IVEC> aadvecs) {
   const struct aead_chacha20_poly1305_ctx *c20_ctx =
       (struct aead_chacha20_poly1305_ctx *)&ctx->state;
 
@@ -212,16 +246,15 @@
   OPENSSL_memset(derived_nonce, 0, 4);
   OPENSSL_memcpy(&derived_nonce[4], &nonce[16], 8);
 
-  return chacha20_poly1305_seal_scatter(
-      derived_key, out, out_tag, out_tag_len, max_out_tag_len,
-      derived_nonce, sizeof(derived_nonce), in, in_len, extra_in, extra_in_len,
-      ad, ad_len, ctx->tag_len);
+  return chacha20_poly1305_sealv(derived_key, iovecs, out_tag, out_tag_len,
+                                 max_out_tag_len, derived_nonce,
+                                 sizeof(derived_nonce), aadvecs, ctx->tag_len);
 }
 
-static int chacha20_poly1305_open_gather(
-    const uint8_t *key, uint8_t *out, const uint8_t *nonce,
-    size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *in_tag,
-    size_t in_tag_len, const uint8_t *ad, size_t ad_len, size_t tag_len) {
+static int chacha20_poly1305_openv_detached(
+    const uint8_t *key, bssl::Span<const CRYPTO_IOVEC> iovecs,
+    const uint8_t *nonce, size_t nonce_len, const uint8_t *in_tag,
+    size_t in_tag_len, bssl::Span<const CRYPTO_IVEC> aadvecs, size_t tag_len) {
   if (nonce_len != 12) {
     OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_UNSUPPORTED_NONCE_SIZE);
     return 0;
@@ -238,21 +271,51 @@
   // 32-bits and this produces a warning because it's always false.
   // Casting to uint64_t inside the conditional is not sufficient to stop
   // the warning.
-  const uint64_t in_len_64 = in_len;
+  const uint64_t in_len_64 = bssl::iovec::TotalLength(iovecs);
   if (in_len_64 >= (UINT64_C(1) << 32) * 64 - 64) {
     OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_TOO_LARGE);
     return 0;
   }
 
   union chacha20_poly1305_open_data data;
-  if (chacha20_poly1305_asm_capable()) {
+  if (chacha20_poly1305_asm_capable() && iovecs.size() <= 1 &&
+      aadvecs.size() <= 1) {
+    // TODO(crbug.com/383343306): Support more than 1 ciphertext segment.
     OPENSSL_memcpy(data.in.key, key, 32);
     data.in.counter = 0;
     OPENSSL_memcpy(data.in.nonce, nonce, 12);
-    chacha20_poly1305_open(out, in, in_len, ad, ad_len, &data);
+    chacha20_poly1305_open(iovecs.size() >= 1 ? iovecs[0].out : nullptr,
+                           iovecs.size() >= 1 ? iovecs[0].in : nullptr,
+                           iovecs.size() >= 1 ? iovecs[0].len : 0,
+                           aadvecs.size() >= 1 ? aadvecs[0].in : nullptr,
+                           aadvecs.size() >= 1 ? aadvecs[0].len : 0, &data);
   } else {
-    calc_tag(data.out.tag, key, nonce, ad, ad_len, in, in_len, nullptr, 0);
-    CRYPTO_chacha_20(out, in, in_len, key, nonce, 1);
+    poly1305_state ctx;
+    size_t ad_len = calc_tag_pre(&ctx, key, nonce, aadvecs);
+
+    size_t ciphertext_total = 0;
+    size_t block = 1;
+    bssl::iovec::ForEachBlockRange<64, /*WriteOut=*/true>(
+        iovecs,
+        [&](const uint8_t *in, uint8_t *out, size_t len) {
+          // TODO(crbug.com/383343306): Maybe just provide asm version of this?
+          // Here, len is always a multiple of 64.
+          CRYPTO_poly1305_update(&ctx, in, len);
+          CRYPTO_chacha_20(out, in, len, key, nonce, block);
+          ciphertext_total += len;
+          block += len / 64;
+          return true;
+        },
+        [&](const uint8_t *in, uint8_t *out, size_t len) {
+          // Here, len may be anything. If an asm version can't handle that,
+          // it will be worth splitting off multiples of 64 here.
+          CRYPTO_poly1305_update(&ctx, in, len);
+          CRYPTO_chacha_20(out, in, len, key, nonce, block);
+          ciphertext_total += len;
+          return true;
+        });
+
+    calc_tag_post(&ctx, data.out.tag, ciphertext_total, ad_len);
   }
 
   if (CRYPTO_memcmp(data.out.tag, in_tag, tag_len) != 0) {
@@ -263,22 +326,22 @@
   return 1;
 }
 
-static int aead_chacha20_poly1305_open_gather(
-    const EVP_AEAD_CTX *ctx, uint8_t *out, const uint8_t *nonce,
-    size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *in_tag,
-    size_t in_tag_len, const uint8_t *ad, size_t ad_len) {
+static int aead_chacha20_poly1305_openv_detached(
+    const EVP_AEAD_CTX *ctx, bssl::Span<const CRYPTO_IOVEC> iovecs,
+    const uint8_t *nonce, size_t nonce_len, const uint8_t *in_tag,
+    size_t in_tag_len, bssl::Span<const CRYPTO_IVEC> aadvecs) {
   const struct aead_chacha20_poly1305_ctx *c20_ctx =
       (struct aead_chacha20_poly1305_ctx *)&ctx->state;
 
-  return chacha20_poly1305_open_gather(c20_ctx->key, out, nonce, nonce_len, in,
-                                       in_len, in_tag, in_tag_len, ad, ad_len,
-                                       ctx->tag_len);
+  return chacha20_poly1305_openv_detached(c20_ctx->key, iovecs, nonce,
+                                          nonce_len, in_tag, in_tag_len,
+                                          aadvecs, ctx->tag_len);
 }
 
-static int aead_xchacha20_poly1305_open_gather(
-    const EVP_AEAD_CTX *ctx, uint8_t *out, const uint8_t *nonce,
-    size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *in_tag,
-    size_t in_tag_len, const uint8_t *ad, size_t ad_len) {
+static int aead_xchacha20_poly1305_openv_detached(
+    const EVP_AEAD_CTX *ctx, bssl::Span<const CRYPTO_IOVEC> iovecs,
+    const uint8_t *nonce, size_t nonce_len, const uint8_t *in_tag,
+    size_t in_tag_len, bssl::Span<const CRYPTO_IVEC> aadvecs) {
   const struct aead_chacha20_poly1305_ctx *c20_ctx =
       (struct aead_chacha20_poly1305_ctx *)&ctx->state;
 
@@ -293,9 +356,9 @@
   OPENSSL_memset(derived_nonce, 0, 4);
   OPENSSL_memcpy(&derived_nonce[4], &nonce[16], 8);
 
-  return chacha20_poly1305_open_gather(
-      derived_key, out, derived_nonce, sizeof(derived_nonce), in, in_len,
-      in_tag, in_tag_len, ad, ad_len, ctx->tag_len);
+  return chacha20_poly1305_openv_detached(derived_key, iovecs, derived_nonce,
+                                          sizeof(derived_nonce), in_tag,
+                                          in_tag_len, aadvecs, ctx->tag_len);
 }
 
 static const EVP_AEAD aead_chacha20_poly1305 = {
@@ -308,12 +371,12 @@
     aead_chacha20_poly1305_init,
     nullptr,  // init_with_direction
     aead_chacha20_poly1305_cleanup,
-    nullptr /* open */,
-    aead_chacha20_poly1305_seal_scatter,
-    aead_chacha20_poly1305_open_gather,
+    nullptr,  // open
+    nullptr,  // seal_scatter
+    nullptr,  // open_gather
     nullptr,  // openv
-    nullptr,  // sealv
-    nullptr,  // openv_detached
+    aead_chacha20_poly1305_sealv,
+    aead_chacha20_poly1305_openv_detached,
     nullptr,  // get_iv
     nullptr,  // tag_len
 };
@@ -328,12 +391,12 @@
     aead_chacha20_poly1305_init,
     nullptr,  // init_with_direction
     aead_chacha20_poly1305_cleanup,
-    nullptr /* open */,
-    aead_xchacha20_poly1305_seal_scatter,
-    aead_xchacha20_poly1305_open_gather,
+    nullptr,  // open
+    nullptr,  // seal_scatter
+    nullptr,  // open_gather
     nullptr,  // openv
-    nullptr,  // sealv
-    nullptr,  // openv_detached
+    aead_xchacha20_poly1305_sealv,
+    aead_xchacha20_poly1305_openv_detached,
     nullptr,  // get_iv
     nullptr,  // tag_len
 };