EVP_AEAD: implement sealv/openv for AES-EAX.

Performance note: this seems to cost 22 to 85 CPU cycles per call with
LTO, and 31 to 55 CPU cycles without; no measurable per-byte overhead
though.

Bug: 383343306
Change-Id: Ib220fdfb8926da941f0ddaf994a14dee2b2951a3
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/84409
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
Auto-Submit: Rudolf Polzer <rpolzer@google.com>
diff --git a/crypto/cipher/e_aeseax.cc b/crypto/cipher/e_aeseax.cc
index 4f6a774..56026c7 100644
--- a/crypto/cipher/e_aeseax.cc
+++ b/crypto/cipher/e_aeseax.cc
@@ -25,6 +25,7 @@
 #include <openssl/crypto.h>
 #include <openssl/err.h>
 #include <openssl/mem.h>
+#include <openssl/span.h>
 
 #include "../fipsmodule/cipher/internal.h"
 #include "../internal.h"
@@ -114,9 +115,13 @@
   CRYPTO_xor16(out, aes_ctx->p, out);
 }
 
-static void omac(const struct aead_aes_eax_ctx *aes_ctx,
-                 uint8_t out[AES_BLOCK_SIZE], const uint8_t *in,
-                 size_t in_len) {
+template <typename IVEC>
+static void omac_with_tag(const struct aead_aes_eax_ctx *aes_ctx,
+                          uint8_t out[AES_BLOCK_SIZE],
+                          bssl::Span<const IVEC> ivecs, uint8_t tag) {
+  OPENSSL_memset(out, 0, AES_BLOCK_SIZE);
+  out[AES_BLOCK_SIZE - 1] = tag;
+  size_t in_len = bssl::iovec::TotalLength(ivecs);
   if (in_len == 0) {
     // CBK(pad(M;B,P)) = CBK(B). Avoiding padding to skip a copy.
     cbk_block(aes_ctx, aes_ctx->b, out);
@@ -124,46 +129,98 @@
   }
   // CBK(M1) = Ek(M1 ^ 0^n)
   AES_encrypt(out, out, &aes_ctx->ks.ks);
-  while (in_len > AES_BLOCK_SIZE) {
-    // Full blocks, no padding needed.
-    cbk_block(aes_ctx, in, out);
-    in += AES_BLOCK_SIZE;
-    in_len -= AES_BLOCK_SIZE;
-  }
-  // Last block to be padded.
-  uint8_t padded_block[AES_BLOCK_SIZE];
-  pad(aes_ctx, padded_block, in, in_len);
-  cbk_block(aes_ctx, padded_block, out);
+  bssl::iovec::ForEachBlockRange<AES_BLOCK_SIZE>(
+      ivecs,
+      [&](const uint8_t *in, size_t len) {
+        while (len >= AES_BLOCK_SIZE) {
+          // Full blocks, no padding needed.
+          cbk_block(aes_ctx, in, out);
+          in += AES_BLOCK_SIZE;
+          len -= AES_BLOCK_SIZE;
+        }
+        BSSL_CHECK(len == 0);
+        return true;
+      },
+      [&](const uint8_t *in, size_t len) {
+        // Remaining blocks.
+        while (len > AES_BLOCK_SIZE) {
+          // Full blocks, no padding needed.
+          cbk_block(aes_ctx, in, out);
+          in += AES_BLOCK_SIZE;
+          len -= AES_BLOCK_SIZE;
+        }
+        // Last partial block.
+        uint8_t padded_block[AES_BLOCK_SIZE];
+        pad(aes_ctx, padded_block, in, len);
+        cbk_block(aes_ctx, padded_block, out);
+        return true;
+      });
 }
 
-static void omac_with_tag(const struct aead_aes_eax_ctx *aes_ctx,
-                          uint8_t out[AES_BLOCK_SIZE], const uint8_t *in,
-                          size_t in_len, int tag) {
+static void omac_with_tag_iovec_out(const struct aead_aes_eax_ctx *aes_ctx,
+                                    uint8_t out[AES_BLOCK_SIZE],
+                                    bssl::Span<const CRYPTO_IOVEC> iovecs,
+                                    uint8_t tag) {
   OPENSSL_memset(out, 0, AES_BLOCK_SIZE);
   out[AES_BLOCK_SIZE - 1] = tag;
-  omac(aes_ctx, out, in, in_len);
+  size_t in_len = bssl::iovec::TotalLength(iovecs);
+  if (in_len == 0) {
+    // CBK(pad(M;B,P)) = CBK(B). Avoiding padding to skip a copy.
+    cbk_block(aes_ctx, aes_ctx->b, out);
+    return;
+  }
+  // CBK(M1) = Ek(M1 ^ 0^n)
+  AES_encrypt(out, out, &aes_ctx->ks.ks);
+  bssl::iovec::ForEachOutBlockRange<AES_BLOCK_SIZE>(
+      iovecs,
+      [&](const uint8_t *in, size_t len) {
+        while (len >= AES_BLOCK_SIZE) {
+          // Full blocks, no padding needed.
+          cbk_block(aes_ctx, in, out);
+          in += AES_BLOCK_SIZE;
+          len -= AES_BLOCK_SIZE;
+        }
+        BSSL_CHECK(len == 0);
+        return true;
+      },
+      [&](const uint8_t *in, size_t len) {
+        while (len > AES_BLOCK_SIZE) {
+          // Full blocks, no padding needed.
+          cbk_block(aes_ctx, in, out);
+          in += AES_BLOCK_SIZE;
+          len -= AES_BLOCK_SIZE;
+        }
+        // Last partial block.
+        uint8_t padded_block[AES_BLOCK_SIZE];
+        pad(aes_ctx, padded_block, in, len);
+        cbk_block(aes_ctx, padded_block, out);
+        return true;
+      });
 }
 
 // Encrypts/decrypts |in_len| bytes from |in| to |out| using AES-CTR with |n| as
 // the IV.
-static void aes_ctr(const struct aead_aes_eax_ctx *aes_ctx, uint8_t *out,
-                    const uint8_t n[AES_BLOCK_SIZE], const uint8_t *in,
-                    size_t in_len) {
+static void aes_ctr(const struct aead_aes_eax_ctx *aes_ctx,
+                    bssl::Span<const CRYPTO_IOVEC> iovecs,
+                    const uint8_t n[AES_BLOCK_SIZE]) {
   uint8_t ivec[AES_BLOCK_SIZE];
   OPENSSL_memcpy(ivec, n, AES_BLOCK_SIZE);
 
-  uint8_t unused_ecount_buf[AES_BLOCK_SIZE];
-  unsigned int unused_num = 0;
-  AES_ctr128_encrypt(in, out, in_len, &aes_ctx->ks.ks, ivec, unused_ecount_buf,
-                     &unused_num);
+  uint8_t ecount_buf[AES_BLOCK_SIZE];
+  unsigned int num = 0;
+
+  for (const CRYPTO_IOVEC &iovec : iovecs) {
+    AES_ctr128_encrypt(iovec.in, iovec.out, iovec.len, &aes_ctx->ks.ks, ivec,
+                       ecount_buf, &num);
+  }
 }
 
-static int aead_aes_eax_seal_scatter(
-    const EVP_AEAD_CTX *ctx, uint8_t *out, uint8_t *out_tag,
-    size_t *out_tag_len, size_t max_out_tag_len, const uint8_t *nonce,
-    size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *extra_in,
-    size_t extra_in_len, const uint8_t *ad, size_t ad_len) {
-  assert(extra_in_len == 0);
+static int aead_aes_eax_sealv(const EVP_AEAD_CTX *ctx,
+                              bssl::Span<const CRYPTO_IOVEC> iovecs,
+                              uint8_t *out_tag, size_t *out_tag_len,
+                              size_t max_out_tag_len, const uint8_t *nonce,
+                              size_t nonce_len,
+                              bssl::Span<const CRYPTO_IVEC> aadvecs) {
   // We use the full 128 bits of the nonce as counter, so no need to check the
   // plaintext size.
 
@@ -182,16 +239,19 @@
 
   // N <- OMAC(0 || nonce)
   uint8_t n[AES_BLOCK_SIZE];
-  omac_with_tag(aes_ctx, n, nonce, nonce_len, /*tag=*/0);
+  CRYPTO_IVEC noncevec[1];
+  noncevec[0].in = nonce;
+  noncevec[0].len = nonce_len;
+  omac_with_tag(aes_ctx, n, bssl::Span<const CRYPTO_IVEC>(noncevec), /*tag=*/0);
   // H <- OMAC(1 || ad)
   uint8_t h[AES_BLOCK_SIZE];
-  omac_with_tag(aes_ctx, h, ad, ad_len, /*tag=*/1);
+  omac_with_tag(aes_ctx, h, aadvecs, /*tag=*/1);
 
   // C <- CTR^{N}_{K}(M)
-  aes_ctr(aes_ctx, out, n, in, in_len);
+  aes_ctr(aes_ctx, iovecs, n);
 
   // MAC <- OMAC(2 || C)
-  omac_with_tag(aes_ctx, out_tag, out, in_len, /*tag=*/2);
+  omac_with_tag_iovec_out(aes_ctx, out_tag, iovecs, /*tag=*/2);
   // MAC <- N ^ C ^ H
   CRYPTO_xor16(out_tag, n, out_tag);
   CRYPTO_xor16(out_tag, h, out_tag);
@@ -200,18 +260,18 @@
   return 1;
 }
 
-static int aead_aes_eax_open_gather(const EVP_AEAD_CTX *ctx, uint8_t *out,
-                                    const uint8_t *nonce, size_t nonce_len,
-                                    const uint8_t *in, size_t in_len,
-                                    const uint8_t *in_tag, size_t in_tag_len,
-                                    const uint8_t *ad, size_t ad_len) {
-  const uint64_t ad_len_64 = ad_len;
+static int aead_aes_eax_openv_detached(const EVP_AEAD_CTX *ctx,
+                                       bssl::Span<const CRYPTO_IOVEC> iovecs,
+                                       const uint8_t *nonce, size_t nonce_len,
+                                       const uint8_t *in_tag, size_t in_tag_len,
+                                       bssl::Span<const CRYPTO_IVEC> aadvecs) {
+  const uint64_t ad_len_64 = bssl::iovec::TotalLength(aadvecs);
   if (ad_len_64 >= (UINT64_C(1) << 61)) {
     OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_TOO_LARGE);
     return 0;
   }
 
-  const uint64_t in_len_64 = in_len;
+  const uint64_t in_len_64 = bssl::iovec::TotalLength(iovecs);
   if (in_tag_len != EVP_AEAD_AES_EAX_TAG_LEN ||
       in_len_64 > (UINT64_C(1) << 36) + AES_BLOCK_SIZE) {
     OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_BAD_DECRYPT);
@@ -228,14 +288,18 @@
 
   // N <- OMAC(0 || nonce)
   uint8_t n[AES_BLOCK_SIZE];
-  omac_with_tag(aes_ctx, n, nonce, nonce_len, /*tag=*/0);
+  CRYPTO_IVEC noncevec[1];
+  noncevec[0].in = nonce;
+  noncevec[0].len = nonce_len;
+  omac_with_tag(aes_ctx, n, bssl::Span<const CRYPTO_IVEC>(noncevec),
+                /*tag=*/0);
   // H <- OMAC(1 || ad)
   uint8_t h[AES_BLOCK_SIZE];
-  omac_with_tag(aes_ctx, h, ad, ad_len, /*tag=*/1);
+  omac_with_tag(aes_ctx, h, aadvecs, /*tag=*/1);
 
   // MAC <- OMAC(2 || C)
   uint8_t mac[AES_BLOCK_SIZE];
-  omac_with_tag(aes_ctx, mac, in, in_len, /*tag=*/2);
+  omac_with_tag(aes_ctx, mac, iovecs, /*tag=*/2);
   // MAC <- N ^ C ^ H
   CRYPTO_xor16(mac, n, mac);
   CRYPTO_xor16(mac, h, mac);
@@ -246,7 +310,7 @@
   }
 
   // M <- CTR^{N}_{K}(C)
-  aes_ctr(aes_ctx, out, n, in, in_len);
+  aes_ctr(aes_ctx, iovecs, n);
   return 1;
 }
 
@@ -261,11 +325,11 @@
     nullptr,  // init_with_direction
     aead_aes_eax_cleanup,
     nullptr,  // open
-    aead_aes_eax_seal_scatter,
-    aead_aes_eax_open_gather,
+    nullptr,  // seal_scatter
+    nullptr,  // open_gather
     nullptr,  // openv
-    nullptr,  // sealv
-    nullptr,  // openv_detached
+    aead_aes_eax_sealv,
+    aead_aes_eax_openv_detached,
     nullptr,  // get_iv
     nullptr,  // tag_len
 };
@@ -281,11 +345,11 @@
     nullptr,  // init_with_direction
     aead_aes_eax_cleanup,
     nullptr,  // open
-    aead_aes_eax_seal_scatter,
-    aead_aes_eax_open_gather,
+    nullptr,  // seal_scatter
+    nullptr,  // open_gather
     nullptr,  // openv
-    nullptr,  // sealv
-    nullptr,  // openv_detached
+    aead_aes_eax_sealv,
+    aead_aes_eax_openv_detached,
     nullptr,  // get_iv
     nullptr,  // tag_len
 };