tool: make speed use EVP_AEAD_CTX_seal_scatter

Change-Id: I41854e61d87d365b923349a5ec8e71d73a0141bb
Reviewed-on: https://boringssl-review.googlesource.com/18844
Reviewed-by: Adam Langley <agl@google.com>
Commit-Queue: Adam Langley <agl@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/tool/speed.cc b/tool/speed.cc
index 6652298..cf7e70e 100644
--- a/tool/speed.cc
+++ b/tool/speed.cc
@@ -215,15 +215,24 @@
   std::unique_ptr<uint8_t[]> nonce(new uint8_t[nonce_len]);
   OPENSSL_memset(nonce.get(), 0, nonce_len);
   std::unique_ptr<uint8_t[]> in_storage(new uint8_t[chunk_len + kAlignment]);
-  std::unique_ptr<uint8_t[]> out_storage(new uint8_t[chunk_len + overhead_len + kAlignment]);
+  // N.B. for EVP_AEAD_CTX_seal_scatter the input and output buffers may be the
+  // same size. However, in the direction == evp_aead_open case we still use
+  // non-scattering seal, hence we add overhead_len to the size of this buffer.
+  std::unique_ptr<uint8_t[]> out_storage(
+      new uint8_t[chunk_len + overhead_len + kAlignment]);
   std::unique_ptr<uint8_t[]> in2_storage(new uint8_t[chunk_len + kAlignment]);
   std::unique_ptr<uint8_t[]> ad(new uint8_t[ad_len]);
   OPENSSL_memset(ad.get(), 0, ad_len);
+  std::unique_ptr<uint8_t[]> tag_storage(
+      new uint8_t[overhead_len + kAlignment]);
+
 
   uint8_t *const in = align(in_storage.get(), kAlignment);
   OPENSSL_memset(in, 0, chunk_len);
   uint8_t *const out = align(out_storage.get(), kAlignment);
   OPENSSL_memset(out, 0, chunk_len + overhead_len);
+  uint8_t *const tag = align(tag_storage.get(), kAlignment);
+  OPENSSL_memset(tag, 0, overhead_len);
   uint8_t *const in2 = align(in2_storage.get(), kAlignment);
 
   if (!EVP_AEAD_CTX_init_with_direction(ctx.get(), aead, key.get(), key_len,
@@ -236,13 +245,15 @@
 
   TimeResults results;
   if (direction == evp_aead_seal) {
-    if (!TimeFunction(&results, [chunk_len, overhead_len, nonce_len, ad_len, in,
-                                 out, &ctx, &nonce, &ad]() -> bool {
-          size_t out_len;
-          return EVP_AEAD_CTX_seal(ctx.get(), out, &out_len,
-                                   chunk_len + overhead_len, nonce.get(),
-                                   nonce_len, in, chunk_len, ad.get(), ad_len);
-        })) {
+    if (!TimeFunction(&results,
+                      [chunk_len, nonce_len, ad_len, overhead_len, in, out, tag,
+                       &ctx, &nonce, &ad]() -> bool {
+                        size_t tag_len;
+                        return EVP_AEAD_CTX_seal_scatter(
+                            ctx.get(), out, tag, &tag_len, overhead_len,
+                            nonce.get(), nonce_len, in, chunk_len, nullptr, 0,
+                            ad.get(), ad_len);
+                      })) {
       fprintf(stderr, "EVP_AEAD_CTX_seal failed.\n");
       ERR_print_errors_fp(stderr);
       return false;
@@ -252,13 +263,16 @@
     EVP_AEAD_CTX_seal(ctx.get(), out, &out_len, chunk_len + overhead_len,
                       nonce.get(), nonce_len, in, chunk_len, ad.get(), ad_len);
 
-    if (!TimeFunction(&results, [chunk_len, nonce_len, ad_len, in2, out, &ctx,
-                                 &nonce, &ad, out_len]() -> bool {
-          size_t in2_len;
-          return EVP_AEAD_CTX_open(ctx.get(), in2, &in2_len, chunk_len,
-                                   nonce.get(), nonce_len, out, out_len,
-                                   ad.get(), ad_len);
-        })) {
+    if (!TimeFunction(&results,
+                      [chunk_len, nonce_len, ad_len, in2, out, out_len, &ctx,
+                       &nonce, &ad]() -> bool {
+                        size_t in2_len;
+                        // N.B. EVP_AEAD_CTX_open_gather is not implemented for
+                        // all AEADs.
+                        return EVP_AEAD_CTX_open(
+                            ctx.get(), in2, &in2_len, chunk_len, nonce.get(),
+                            nonce_len, out, out_len, ad.get(), ad_len);
+                      })) {
       fprintf(stderr, "EVP_AEAD_CTX_open failed.\n");
       ERR_print_errors_fp(stderr);
       return false;