Unify AEAD and EVP code paths for AES-GCM.

This change makes the AEAD and EVP code paths use the same code for
AES-GCM. When AVX instructions are enabled in the assembly, both paths
will be able to use the stitched AES-GCM implementation.

Note that the stitched implementations are no-ops for small inputs
(smaller than 288 bytes for encryption; smaller than 96 bytes for
decryption). This means that only a handful of test cases with
sufficiently long inputs actually exercise the stitched code.
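
For reference, the callers added in crypto/modes/gcm.c treat the
stitched routine as a best-effort bulk pass and hand whatever it leaves
unprocessed to the generic ctr32/GHASH loops. A minimal sketch of that
pattern, condensed from the encryption-side hunk below (the surrounding
function body is elided):

    /* |aesni_gcm_encrypt| returns how many bytes it processed; this is
     * zero for inputs below its internal threshold, in which case the
     * generic ctr32/GHASH loops that follow handle everything. */
    if (aesni_gcm_enabled(ctx, stream)) {
      size_t bulk = aesni_gcm_encrypt(in, out, len, key, ctx->Yi.c, ctx->Xi.u);
      in += bulk;
      out += bulk;
      len -= bulk;
    }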

Change-Id: Iece8003d90448dcac9e0bde1f42ff102ebe1a1c9
Reviewed-on: https://boringssl-review.googlesource.com/7173
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/cipher/e_aes.c b/crypto/cipher/e_aes.c
index f7b6fa3..a0087d8 100644
--- a/crypto/cipher/e_aes.c
+++ b/crypto/cipher/e_aes.c
@@ -252,22 +252,6 @@
 void aesni_cbc_encrypt(const uint8_t *in, uint8_t *out, size_t length,
                        const AES_KEY *key, uint8_t *ivec, int enc);
 
-void aesni_ctr32_encrypt_blocks(const uint8_t *in, uint8_t *out, size_t blocks,
-                                const void *key, const uint8_t *ivec);
-
-#if defined(OPENSSL_X86_64)
-size_t aesni_gcm_encrypt(const uint8_t *in, uint8_t *out, size_t len,
-                         const void *key, uint8_t ivec[16], uint64_t *Xi);
-#define AES_gcm_encrypt aesni_gcm_encrypt
-size_t aesni_gcm_decrypt(const uint8_t *in, uint8_t *out, size_t len,
-                         const void *key, uint8_t ivec[16], uint64_t *Xi);
-#define AES_gcm_decrypt aesni_gcm_decrypt
-void gcm_ghash_avx(uint64_t Xi[2], const u128 Htable[16], const uint8_t *in,
-                   size_t len);
-#define AES_GCM_ASM(gctx) \
-  (gctx->ctr == aesni_ctr32_encrypt_blocks && gctx->gcm.ghash == gcm_ghash_avx)
-#endif  /* OPENSSL_X86_64 */
-
 #else
 
 /* On other platforms, aesni_capable() will always return false and so the
@@ -651,57 +635,23 @@
       }
     } else if (ctx->encrypt) {
       if (gctx->ctr) {
-        size_t bulk = 0;
-#if defined(AES_GCM_ASM)
-        if (len >= 32 && AES_GCM_ASM(gctx)) {
-          size_t res = (16 - gctx->gcm.mres) % 16;
-
-          if (!CRYPTO_gcm128_encrypt(&gctx->gcm, &gctx->ks.ks, in, out, res)) {
-            return -1;
-          }
-
-          bulk = AES_gcm_encrypt(in + res, out + res, len - res, &gctx->ks.ks,
-                                 gctx->gcm.Yi.c, gctx->gcm.Xi.u);
-          gctx->gcm.len.u[1] += bulk;
-          bulk += res;
-        }
-#endif
-        if (!CRYPTO_gcm128_encrypt_ctr32(&gctx->gcm, &gctx->ks.ks, in + bulk,
-                                         out + bulk, len - bulk, gctx->ctr)) {
+        if (!CRYPTO_gcm128_encrypt_ctr32(&gctx->gcm, &gctx->ks.ks, in, out, len,
+                                         gctx->ctr)) {
           return -1;
         }
       } else {
-        size_t bulk = 0;
-        if (!CRYPTO_gcm128_encrypt(&gctx->gcm, &gctx->ks.ks, in + bulk,
-                                   out + bulk, len - bulk)) {
+        if (!CRYPTO_gcm128_encrypt(&gctx->gcm, &gctx->ks.ks, in, out, len)) {
           return -1;
         }
       }
     } else {
       if (gctx->ctr) {
-        size_t bulk = 0;
-#if defined(AES_GCM_ASM)
-        if (len >= 16 && AES_GCM_ASM(gctx)) {
-          size_t res = (16 - gctx->gcm.mres) % 16;
-
-          if (!CRYPTO_gcm128_decrypt(&gctx->gcm, &gctx->ks.ks, in, out, res)) {
-            return -1;
-          }
-
-          bulk = AES_gcm_decrypt(in + res, out + res, len - res, &gctx->ks.ks,
-                                 gctx->gcm.Yi.c, gctx->gcm.Xi.u);
-          gctx->gcm.len.u[1] += bulk;
-          bulk += res;
-        }
-#endif
-        if (!CRYPTO_gcm128_decrypt_ctr32(&gctx->gcm, &gctx->ks.ks, in + bulk,
-                                         out + bulk, len - bulk, gctx->ctr)) {
+        if (!CRYPTO_gcm128_decrypt_ctr32(&gctx->gcm, &gctx->ks.ks, in, out, len,
+                                         gctx->ctr)) {
           return -1;
         }
       } else {
-        size_t bulk = 0;
-        if (!CRYPTO_gcm128_decrypt(&gctx->gcm, &gctx->ks.ks, in + bulk,
-                                   out + bulk, len - bulk)) {
+        if (!CRYPTO_gcm128_decrypt(&gctx->gcm, &gctx->ks.ks, in, out, len)) {
           return -1;
         }
       }
diff --git a/crypto/modes/asm/aesni-gcm-x86_64.pl b/crypto/modes/asm/aesni-gcm-x86_64.pl
index 26135e6..c7402f5 100644
--- a/crypto/modes/asm/aesni-gcm-x86_64.pl
+++ b/crypto/modes/asm/aesni-gcm-x86_64.pl
@@ -41,11 +41,12 @@
 ( $xlate="${dir}../../perlasm/x86_64-xlate.pl" and -f $xlate) or
 die "can't locate x86_64-xlate.pl";
 
+# This must be kept in sync with |$avx| in ghash-x86_64.pl; otherwise tags will
+# be computed incorrectly.
+#
 # In upstream, this is controlled by shelling out to the compiler to check
 # versions, but BoringSSL is intended to be used with pre-generated perlasm
 # output, so this isn't useful anyway.
-#
-# TODO(davidben): Enable this after testing. $avx goes up to 2.
 $avx = 0;
 
 open OUT,"| \"$^X\" $xlate $flavour $output";
diff --git a/crypto/modes/asm/ghash-x86_64.pl b/crypto/modes/asm/ghash-x86_64.pl
index e42ca32..5a11fb9 100644
--- a/crypto/modes/asm/ghash-x86_64.pl
+++ b/crypto/modes/asm/ghash-x86_64.pl
@@ -90,11 +90,12 @@
 ( $xlate="${dir}../../perlasm/x86_64-xlate.pl" and -f $xlate) or
 die "can't locate x86_64-xlate.pl";
 
+# This must be kept in sync with |$avx| in aesni-gcm-x86_64.pl; otherwise tags
+# will be computed incorrectly.
+#
 # In upstream, this is controlled by shelling out to the compiler to check
 # versions, but BoringSSL is intended to be used with pre-generated perlasm
 # output, so this isn't useful anyway.
-#
-# TODO(davidben): Enable this after testing. $avx goes up to 2.
 $avx = 0;
 
 open OUT,"| \"$^X\" $xlate $flavour $output";
diff --git a/crypto/modes/gcm.c b/crypto/modes/gcm.c
index 8cc138d..153ff17 100644
--- a/crypto/modes/gcm.c
+++ b/crypto/modes/gcm.c
@@ -55,6 +55,7 @@
 #include <openssl/cpu.h>
 
 #include "internal.h"
+#include "../internal.h"
 
 
 #if !defined(OPENSSL_NO_ASM) &&                         \
@@ -337,7 +338,18 @@
 #else
 void gcm_init_avx(u128 Htable[16], const uint64_t Xi[2]);
 void gcm_gmult_avx(uint64_t Xi[2], const u128 Htable[16]);
-void gcm_ghash_avx(uint64_t Xi[2], const u128 Htable[16], const uint8_t *inp, size_t len);
+void gcm_ghash_avx(uint64_t Xi[2], const u128 Htable[16], const uint8_t *in,
+                   size_t len);
+#define AESNI_GCM
+static int aesni_gcm_enabled(GCM128_CONTEXT *ctx, ctr128_f stream) {
+  return stream == aesni_ctr32_encrypt_blocks &&
+         ctx->ghash == gcm_ghash_avx;
+}
+
+size_t aesni_gcm_encrypt(const uint8_t *in, uint8_t *out, size_t len,
+                         const void *key, uint8_t ivec[16], uint64_t *Xi);
+size_t aesni_gcm_decrypt(const uint8_t *in, uint8_t *out, size_t len,
+                         const void *key, uint8_t ivec[16], uint64_t *Xi);
 #endif
 
 #if defined(OPENSSL_X86)
@@ -1011,6 +1023,18 @@
       return 1;
     }
   }
+
+#if defined(AESNI_GCM)
+  if (aesni_gcm_enabled(ctx, stream)) {
+    /* |aesni_gcm_encrypt| may not process all the input given to it. It may
+     * not process *any* of its input if it is deemed too small. */
+    size_t bulk = aesni_gcm_encrypt(in, out, len, key, ctx->Yi.c, ctx->Xi.u);
+    in += bulk;
+    out += bulk;
+    len -= bulk;
+  }
+#endif
+
 #if defined(GHASH)
   while (len >= GHASH_CHUNK) {
     (*stream)(in, out, GHASH_CHUNK / 16, key, ctx->Yi.c);
@@ -1100,12 +1124,6 @@
     ctx->ares = 0;
   }
 
-  if (is_endian.little) {
-    ctr = GETU32(ctx->Yi.c + 12);
-  } else {
-    ctr = ctx->Yi.d[3];
-  }
-
   n = ctx->mres;
   if (n) {
     while (n && len) {
@@ -1122,6 +1140,24 @@
       return 1;
     }
   }
+
+#if defined(AESNI_GCM)
+  if (aesni_gcm_enabled(ctx, stream)) {
+    /* |aesni_gcm_decrypt| may not process all the input given to it. It may
+     * not process *any* of its input if it is deemed too small. */
+    size_t bulk = aesni_gcm_decrypt(in, out, len, key, ctx->Yi.c, ctx->Xi.u);
+    in += bulk;
+    out += bulk;
+    len -= bulk;
+  }
+#endif
+
+  if (is_endian.little) {
+    ctr = GETU32(ctx->Yi.c + 12);
+  } else {
+    ctr = ctx->Yi.d[3];
+  }
+
 #if defined(GHASH)
   while (len >= GHASH_CHUNK) {
     GHASH(ctx, in, GHASH_CHUNK);
diff --git a/crypto/modes/internal.h b/crypto/modes/internal.h
index 7255a7c..51b6431 100644
--- a/crypto/modes/internal.h
+++ b/crypto/modes/internal.h
@@ -363,6 +363,12 @@
                                    block128_f block);
 
 
+#if !defined(OPENSSL_NO_ASM) && \
+    (defined(OPENSSL_X86) || defined(OPENSSL_X86_64))
+void aesni_ctr32_encrypt_blocks(const uint8_t *in, uint8_t *out, size_t blocks,
+                                const void *key, const uint8_t *ivec);
+#endif
+
 #if defined(__cplusplus)
 } /* extern C */
 #endif