Move AES_KEY into GCM128_KEY, GCM128_KEY out of GCM128_CONTEXT

AES_KEY was separate from GCM128_KEY because crypto/modes was originally
written as if there were non-AES 128-bit block ciphers to worry about.
This has long since stopped being the case in BoringSSL. Merging them
avoids some duplicate key setup logic in EVP_CIPHER and EVP_AEAD.

GCM128_KEY was embedded into GCM128_CONTEXT because OpenSSL assembly
once relied on the exact order of a bunch of the fields, some of which
were per-operation and some of which were per-key. See
https://boringssl-review.googlesource.com/c/boringssl/+/13122

This assumption has since been removed. See
https://boringssl-review.googlesource.com/c/boringssl/+/59526

Now that that is done, we can pull GCM128_KEY out of GCM128_CONTEXT and
instead pass it in as a separate pointer, like how AES_KEY used to be
passed. (I made the key the first argument because key-then-op seems
more natural to me than op-then-key. Also, I didn't notice I'd flipped
them until I was done with the CL.) By pulling it out, we avoid a
pointless 500-ish byte memcpy before every AES-GCM operation, which
speeds up shorter inputs non-trivially; see the numbers below. A nice
bonus on top of cleaner code.
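The shape of the layout change can be sketched with a self-contained toy
(the struct names and sizes below are illustrative stand-ins, not
BoringSSL's actual types):

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

// example_key stands in for the per-key state (expanded AES key schedule
// plus GHASH tables), roughly 500 bytes in the real code.
typedef struct {
  uint8_t round_keys[480];
} example_key;

// Old shape: the per-operation context embedded the key, so starting an
// operation meant copying the whole key schedule into the context.
typedef struct {
  example_key key;      // copied on every operation
  uint8_t counter[16];  // per-operation state
} old_ctx;

// New shape: the context holds only per-operation state; the (const) key
// is passed alongside it to each operation instead.
typedef struct {
  uint8_t counter[16];
} new_ctx;

static void old_init(old_ctx *ctx, const example_key *key) {
  memcpy(&ctx->key, key, sizeof(*key));  // the memcpy this CL eliminates
  memset(ctx->counter, 0, sizeof(ctx->counter));
}

static void new_init(new_ctx *ctx, const example_key *key) {
  (void)key;  // key is no longer copied; callers pass it per operation
  memset(ctx->counter, 0, sizeof(ctx->counter));
}
```

The per-operation context shrinks by the size of the key material, which is
where the short-input speedup comes from.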

This doesn't do all the things in crbug.com/382503563, but gets us
partway towards it. It also gets the GCM128_* functions slightly closer
to being a self-contained abstraction.

As part of this, move aes_ctr_set_key into crypto/fipsmodule/aes
instead of crypto/fipsmodule/cipher. It was previously in cipher
because it depended on modes, and modes already depends on aes. Now
that the GCM dependency has been moved out, I think aes makes more
sense for it.
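The relocated vpaes/bsaes splitter hands the tail of a CTR stream to a
second implementation, which requires advancing the 32-bit big-endian
counter in the IV by the number of blocks already processed. A
self-contained sketch of that IV arithmetic (the load/store helpers below
are hypothetical stand-ins for CRYPTO_load_u32_be/CRYPTO_store_u32_be):

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

// Big-endian 32-bit load/store, as used for the GCM/CTR32 counter word.
static uint32_t load_u32_be(const uint8_t *p) {
  return ((uint32_t)p[0] << 24) | ((uint32_t)p[1] << 16) |
         ((uint32_t)p[2] << 8) | (uint32_t)p[3];
}

static void store_u32_be(uint8_t *p, uint32_t v) {
  p[0] = (uint8_t)(v >> 24);
  p[1] = (uint8_t)(v >> 16);
  p[2] = (uint8_t)(v >> 8);
  p[3] = (uint8_t)v;
}

// ctr32_advance builds the IV for the remaining blocks: the 12-byte nonce
// prefix is unchanged and the trailing counter word is bumped by |blocks|,
// mirroring the new_ivec computation in the splitter.
static void ctr32_advance(uint8_t new_ivec[16], const uint8_t ivec[16],
                          uint32_t blocks) {
  memcpy(new_ivec, ivec, 12);
  store_u32_be(new_ivec + 12, load_u32_be(ivec + 12) + blocks);
}
```

For example, an IV whose counter word is 254 advanced by 8 blocks yields a
counter word of 262 (0x00000106) with the nonce bytes untouched.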

On an Intel(R) Xeon(R) Gold 6154 CPU @ 3.00GHz

Before:
Did 22847000 AES-128-GCM (16 bytes) seal operations in 2000070us (182.8 MB/sec)
Did 13704000 AES-128-GCM (256 bytes) seal operations in 2000025us (1754.1 MB/sec)
Did 5856000 AES-128-GCM (1350 bytes) seal operations in 2000107us (3952.6 MB/sec)
Did 1323000 AES-128-GCM (8192 bytes) seal operations in 2001392us (5415.2 MB/sec)
Did 683000 AES-128-GCM (16384 bytes) seal operations in 2000266us (5594.4 MB/sec)
Did 20866750 AES-256-GCM (16 bytes) seal operations in 2000005us (166.9 MB/sec)
Did 12114000 AES-256-GCM (256 bytes) seal operations in 2000109us (1550.5 MB/sec)
Did 4572000 AES-256-GCM (1350 bytes) seal operations in 2000142us (3085.9 MB/sec)
Did 972000 AES-256-GCM (8192 bytes) seal operations in 2000988us (3979.3 MB/sec)
Did 497000 AES-256-GCM (16384 bytes) seal operations in 2000832us (4069.7 MB/sec)

After:
Did 25786000 AES-128-GCM (16 bytes) seal operations in 2000050us (206.3 MB/sec) [+12.9%]
Did 14489000 AES-128-GCM (256 bytes) seal operations in 2000004us (1854.6 MB/sec) [+5.7%]
Did 5927000 AES-128-GCM (1350 bytes) seal operations in 2000248us (4000.2 MB/sec) [+1.2%]
Did 1316000 AES-128-GCM (8192 bytes) seal operations in 2000236us (5389.7 MB/sec) [-0.5%]
Did 679000 AES-128-GCM (16384 bytes) seal operations in 2001792us (5557.4 MB/sec) [-0.7%]
Did 23180500 AES-256-GCM (16 bytes) seal operations in 2000016us (185.4 MB/sec) [+11.1%]
Did 12703000 AES-256-GCM (256 bytes) seal operations in 2000070us (1625.9 MB/sec) [+4.9%]
Did 4668000 AES-256-GCM (1350 bytes) seal operations in 2000238us (3150.5 MB/sec) [+2.1%]
Did 976000 AES-256-GCM (8192 bytes) seal operations in 2000115us (3997.5 MB/sec) [+0.5%]
Did 500000 AES-256-GCM (16384 bytes) seal operations in 2001380us (4093.2 MB/sec) [+0.6%]

The difference is even more pronounced on GCC:

Before:
Did 19500000 AES-128-GCM (16 bytes) seal operations in 2000077us (156.0 MB/sec)
Did 12833000 AES-128-GCM (256 bytes) seal operations in 2000040us (1642.6 MB/sec)
Did 5544000 AES-128-GCM (1350 bytes) seal operations in 2000325us (3741.6 MB/sec)
Did 1305000 AES-128-GCM (8192 bytes) seal operations in 2000029us (5345.2 MB/sec)
Did 677000 AES-128-GCM (16384 bytes) seal operations in 2002554us (5538.9 MB/sec)
Did 18222000 AES-256-GCM (16 bytes) seal operations in 2000026us (145.8 MB/sec)
Did 11351750 AES-256-GCM (256 bytes) seal operations in 2000036us (1453.0 MB/sec)
Did 4431000 AES-256-GCM (1350 bytes) seal operations in 2000278us (2990.5 MB/sec)
Did 965000 AES-256-GCM (8192 bytes) seal operations in 2000617us (3951.4 MB/sec)
Did 497000 AES-256-GCM (16384 bytes) seal operations in 2003070us (4065.2 MB/sec)

After:
Did 25878000 AES-128-GCM (16 bytes) seal operations in 2000001us (207.0 MB/sec) [+32.7%]
Did 14510250 AES-128-GCM (256 bytes) seal operations in 2000034us (1857.3 MB/sec) [+13.1%]
Did 5936000 AES-128-GCM (1350 bytes) seal operations in 2000273us (4006.3 MB/sec) [+7.1%]
Did 1314000 AES-128-GCM (8192 bytes) seal operations in 2000033us (5382.1 MB/sec) [+0.7%]
Did 677000 AES-128-GCM (16384 bytes) seal operations in 2002827us (5538.2 MB/sec) [-0.0%]
Did 23281000 AES-256-GCM (16 bytes) seal operations in 2000048us (186.2 MB/sec) [+27.8%]
Did 12750000 AES-256-GCM (256 bytes) seal operations in 2000008us (1632.0 MB/sec) [+12.3%]
Did 4685000 AES-256-GCM (1350 bytes) seal operations in 2000205us (3162.1 MB/sec) [+5.7%]
Did 977000 AES-256-GCM (8192 bytes) seal operations in 2001842us (3998.1 MB/sec) [+1.2%]
Did 501000 AES-256-GCM (16384 bytes) seal operations in 2003419us (4097.2 MB/sec) [+0.8%]

Bug: 382503563, 42290602
Change-Id: I4690b79212242084cbcde49aa59979344012e5f6
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/74268
Reviewed-by: Adam Langley <agl@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/crypto/cipher_extra/e_aesctrhmac.cc b/crypto/cipher_extra/e_aesctrhmac.cc
index be94d73..f9845a4 100644
--- a/crypto/cipher_extra/e_aesctrhmac.cc
+++ b/crypto/cipher_extra/e_aesctrhmac.cc
@@ -21,6 +21,7 @@
 #include <openssl/err.h>
 #include <openssl/sha.h>
 
+#include "../fipsmodule/aes/internal.h"
 #include "../fipsmodule/cipher/internal.h"
 
 
diff --git a/crypto/cipher_extra/e_aesgcmsiv.cc b/crypto/cipher_extra/e_aesgcmsiv.cc
index 9574b3b..036c7a1 100644
--- a/crypto/cipher_extra/e_aesgcmsiv.cc
+++ b/crypto/cipher_extra/e_aesgcmsiv.cc
@@ -20,6 +20,7 @@
 #include <openssl/crypto.h>
 #include <openssl/err.h>
 
+#include "../fipsmodule/aes/internal.h"
 #include "../fipsmodule/cipher/internal.h"
 #include "../internal.h"
 
diff --git a/crypto/fipsmodule/aes/aes.cc.inc b/crypto/fipsmodule/aes/aes.cc.inc
index 05f7a2b..152f76a 100644
--- a/crypto/fipsmodule/aes/aes.cc.inc
+++ b/crypto/fipsmodule/aes/aes.cc.inc
@@ -144,3 +144,83 @@
   }
 }
 #endif
+
+#if defined(BSAES)
+void vpaes_ctr32_encrypt_blocks_with_bsaes(const uint8_t *in, uint8_t *out,
+                                           size_t blocks, const AES_KEY *key,
+                                           const uint8_t ivec[16]) {
+  // |bsaes_ctr32_encrypt_blocks| is faster than |vpaes_ctr32_encrypt_blocks|,
+  // but it takes at least one full 8-block batch to amortize the conversion.
+  if (blocks < 8) {
+    vpaes_ctr32_encrypt_blocks(in, out, blocks, key, ivec);
+    return;
+  }
+
+  size_t bsaes_blocks = blocks;
+  if (bsaes_blocks % 8 < 6) {
+    // |bsaes_ctr32_encrypt_blocks| internally works in 8-block batches. If the
+    // final batch is too small (under six blocks), it is faster to loop over
+    // |vpaes_encrypt|. Round |bsaes_blocks| down to a multiple of 8.
+    bsaes_blocks -= bsaes_blocks % 8;
+  }
+
+  AES_KEY bsaes;
+  vpaes_encrypt_key_to_bsaes(&bsaes, key);
+  bsaes_ctr32_encrypt_blocks(in, out, bsaes_blocks, &bsaes, ivec);
+  OPENSSL_cleanse(&bsaes, sizeof(bsaes));
+
+  in += 16 * bsaes_blocks;
+  out += 16 * bsaes_blocks;
+  blocks -= bsaes_blocks;
+
+  uint8_t new_ivec[16];
+  memcpy(new_ivec, ivec, 12);
+  uint32_t ctr = CRYPTO_load_u32_be(ivec + 12) + bsaes_blocks;
+  CRYPTO_store_u32_be(new_ivec + 12, ctr);
+
+  // Finish any remaining blocks with |vpaes_ctr32_encrypt_blocks|.
+  vpaes_ctr32_encrypt_blocks(in, out, blocks, key, new_ivec);
+}
+#endif  // BSAES
+
+ctr128_f aes_ctr_set_key(AES_KEY *aes_key, int *out_is_hwaes,
+                         block128_f *out_block, const uint8_t *key,
+                         size_t key_bytes) {
+  // This function assumes the key length was previously validated.
+  assert(key_bytes == 128 / 8 || key_bytes == 192 / 8 || key_bytes == 256 / 8);
+  if (hwaes_capable()) {
+    aes_hw_set_encrypt_key(key, (int)key_bytes * 8, aes_key);
+    if (out_is_hwaes) {
+      *out_is_hwaes = 1;
+    }
+    if (out_block) {
+      *out_block = aes_hw_encrypt;
+    }
+    return aes_hw_ctr32_encrypt_blocks;
+  }
+
+  if (vpaes_capable()) {
+    vpaes_set_encrypt_key(key, (int)key_bytes * 8, aes_key);
+    if (out_block) {
+      *out_block = vpaes_encrypt;
+    }
+    if (out_is_hwaes) {
+      *out_is_hwaes = 0;
+    }
+#if defined(BSAES)
+    assert(bsaes_capable());
+    return vpaes_ctr32_encrypt_blocks_with_bsaes;
+#else
+    return vpaes_ctr32_encrypt_blocks;
+#endif
+  }
+
+  aes_nohw_set_encrypt_key(key, (int)key_bytes * 8, aes_key);
+  if (out_is_hwaes) {
+    *out_is_hwaes = 0;
+  }
+  if (out_block) {
+    *out_block = aes_nohw_encrypt;
+  }
+  return aes_nohw_ctr32_encrypt_blocks;
+}
diff --git a/crypto/fipsmodule/aes/internal.h b/crypto/fipsmodule/aes/internal.h
index 7730b00..fa5ba87 100644
--- a/crypto/fipsmodule/aes/internal.h
+++ b/crypto/fipsmodule/aes/internal.h
@@ -24,6 +24,34 @@
 extern "C" {
 
 
+// block128_f is the type of an AES block cipher implementation.
+//
+// Unlike upstream OpenSSL, it and the other functions in this file hard-code
+// |AES_KEY|. It is undefined in C to call a function pointer with anything
+// other than the original type. Thus we either must match |block128_f| to the
+// type signature of |AES_encrypt| and friends or pass in |void*| wrapper
+// functions.
+//
+// These functions are called exclusively with AES, so we use the former.
+typedef void (*block128_f)(const uint8_t in[16], uint8_t out[16],
+                           const AES_KEY *key);
+
+// ctr128_f is the type of a function that performs CTR-mode encryption.
+typedef void (*ctr128_f)(const uint8_t *in, uint8_t *out, size_t blocks,
+                         const AES_KEY *key, const uint8_t ivec[16]);
+
+// aes_ctr_set_key initialises |*aes_key| using |key_bytes| bytes from |key|,
+// where |key_bytes| must either be 16, 24 or 32. If not NULL, |*out_block| is
+// set to a function that encrypts single blocks. If not NULL, |*out_is_hwaes|
+// is set to whether the hardware AES implementation was used. It returns a
+// function for optimised CTR-mode.
+ctr128_f aes_ctr_set_key(AES_KEY *aes_key, int *out_is_hwaes,
+                         block128_f *out_block, const uint8_t *key,
+                         size_t key_bytes);
+
+
+// AES implementations.
+
 #if !defined(OPENSSL_NO_ASM)
 
 #if defined(OPENSSL_X86) || defined(OPENSSL_X86_64)
@@ -152,6 +180,9 @@
 // VPAES to BSAES conversions are available on all BSAES platforms.
 void vpaes_encrypt_key_to_bsaes(AES_KEY *out_bsaes, const AES_KEY *vpaes);
 void vpaes_decrypt_key_to_bsaes(AES_KEY *out_bsaes, const AES_KEY *vpaes);
+void vpaes_ctr32_encrypt_blocks_with_bsaes(const uint8_t *in, uint8_t *out,
+                                           size_t blocks, const AES_KEY *key,
+                                           const uint8_t ivec[16]);
 #else
 OPENSSL_INLINE char bsaes_capable(void) { return 0; }
 
diff --git a/crypto/fipsmodule/cipher/e_aes.cc.inc b/crypto/fipsmodule/cipher/e_aes.cc.inc
index 30957f0..346cd8b 100644
--- a/crypto/fipsmodule/cipher/e_aes.cc.inc
+++ b/crypto/fipsmodule/cipher/e_aes.cc.inc
@@ -71,45 +71,6 @@
 
 #define AES_GCM_NONCE_LENGTH 12
 
-#if defined(BSAES)
-static void vpaes_ctr32_encrypt_blocks_with_bsaes(const uint8_t *in,
-                                                  uint8_t *out, size_t blocks,
-                                                  const AES_KEY *key,
-                                                  const uint8_t ivec[16]) {
-  // |bsaes_ctr32_encrypt_blocks| is faster than |vpaes_ctr32_encrypt_blocks|,
-  // but it takes at least one full 8-block batch to amortize the conversion.
-  if (blocks < 8) {
-    vpaes_ctr32_encrypt_blocks(in, out, blocks, key, ivec);
-    return;
-  }
-
-  size_t bsaes_blocks = blocks;
-  if (bsaes_blocks % 8 < 6) {
-    // |bsaes_ctr32_encrypt_blocks| internally works in 8-block batches. If the
-    // final batch is too small (under six blocks), it is faster to loop over
-    // |vpaes_encrypt|. Round |bsaes_blocks| down to a multiple of 8.
-    bsaes_blocks -= bsaes_blocks % 8;
-  }
-
-  AES_KEY bsaes;
-  vpaes_encrypt_key_to_bsaes(&bsaes, key);
-  bsaes_ctr32_encrypt_blocks(in, out, bsaes_blocks, &bsaes, ivec);
-  OPENSSL_cleanse(&bsaes, sizeof(bsaes));
-
-  in += 16 * bsaes_blocks;
-  out += 16 * bsaes_blocks;
-  blocks -= bsaes_blocks;
-
-  uint8_t new_ivec[16];
-  memcpy(new_ivec, ivec, 12);
-  uint32_t ctr = CRYPTO_load_u32_be(ivec + 12) + bsaes_blocks;
-  CRYPTO_store_u32_be(new_ivec + 12, ctr);
-
-  // Finish any remaining blocks with |vpaes_ctr32_encrypt_blocks|.
-  vpaes_ctr32_encrypt_blocks(in, out, blocks, key, new_ivec);
-}
-#endif  // BSAES
-
 typedef struct {
   union {
     double align;
@@ -123,11 +84,8 @@
 } EVP_AES_KEY;
 
 typedef struct {
+  GCM128_KEY key;
   GCM128_CONTEXT gcm;
-  union {
-    double align;
-    AES_KEY ks;
-  } ks;         // AES key schedule to use
   int key_set;  // Set if key initialised
   int iv_set;   // Set if an iv is set
   uint8_t *iv;  // Temporary IV store
@@ -283,48 +241,6 @@
   return 1;
 }
 
-ctr128_f aes_ctr_set_key(AES_KEY *aes_key, GCM128_KEY *gcm_key,
-                         block128_f *out_block, const uint8_t *key,
-                         size_t key_bytes) {
-  // This function assumes the key length was previously validated.
-  assert(key_bytes == 128 / 8 || key_bytes == 192 / 8 || key_bytes == 256 / 8);
-  if (hwaes_capable()) {
-    aes_hw_set_encrypt_key(key, (int)key_bytes * 8, aes_key);
-    if (gcm_key != NULL) {
-      CRYPTO_gcm128_init_key(gcm_key, aes_key, aes_hw_encrypt, 1);
-    }
-    if (out_block) {
-      *out_block = aes_hw_encrypt;
-    }
-    return aes_hw_ctr32_encrypt_blocks;
-  }
-
-  if (vpaes_capable()) {
-    vpaes_set_encrypt_key(key, (int)key_bytes * 8, aes_key);
-    if (out_block) {
-      *out_block = vpaes_encrypt;
-    }
-    if (gcm_key != NULL) {
-      CRYPTO_gcm128_init_key(gcm_key, aes_key, vpaes_encrypt, 0);
-    }
-#if defined(BSAES)
-    assert(bsaes_capable());
-    return vpaes_ctr32_encrypt_blocks_with_bsaes;
-#else
-    return vpaes_ctr32_encrypt_blocks;
-#endif
-  }
-
-  aes_nohw_set_encrypt_key(key, (int)key_bytes * 8, aes_key);
-  if (gcm_key != NULL) {
-    CRYPTO_gcm128_init_key(gcm_key, aes_key, aes_nohw_encrypt, 0);
-  }
-  if (out_block) {
-    *out_block = aes_nohw_encrypt;
-  }
-  return aes_nohw_ctr32_encrypt_blocks;
-}
-
 #if defined(OPENSSL_32_BIT)
 #define EVP_AES_GCM_CTX_PADDING (4 + 8)
 #else
@@ -368,24 +284,25 @@
       break;
   }
 
+  // We must configure the key first, then the IV, but the caller may pass both
+  // together, or separately in either order.
   if (key) {
     OPENSSL_memset(&gctx->gcm, 0, sizeof(gctx->gcm));
-    gctx->ctr = aes_ctr_set_key(&gctx->ks.ks, &gctx->gcm.gcm_key, NULL, key,
-                                ctx->key_len);
-    // If we have an iv can set it directly, otherwise use saved IV.
+    CRYPTO_gcm128_init_aes_key(&gctx->key, key, ctx->key_len);
+    // Use the IV if specified. Otherwise, use the saved IV, if any.
     if (iv == NULL && gctx->iv_set) {
       iv = gctx->iv;
     }
     if (iv) {
-      CRYPTO_gcm128_setiv(&gctx->gcm, &gctx->ks.ks, iv, gctx->ivlen);
+      CRYPTO_gcm128_init_ctx(&gctx->key, &gctx->gcm, iv, gctx->ivlen);
       gctx->iv_set = 1;
     }
     gctx->key_set = 1;
   } else {
-    // If key set use IV, otherwise copy
     if (gctx->key_set) {
-      CRYPTO_gcm128_setiv(&gctx->gcm, &gctx->ks.ks, iv, gctx->ivlen);
+      CRYPTO_gcm128_init_ctx(&gctx->key, &gctx->gcm, iv, gctx->ivlen);
     } else {
+      // The caller specified the IV before the key. Save the IV for later.
       OPENSSL_memcpy(gctx->iv, iv, gctx->ivlen);
     }
     gctx->iv_set = 1;
@@ -396,6 +313,7 @@
 
 static void aes_gcm_cleanup(EVP_CIPHER_CTX *c) {
   EVP_AES_GCM_CTX *gctx = aes_gcm_from_cipher_ctx(c);
+  OPENSSL_cleanse(&gctx->key, sizeof(gctx->key));
   OPENSSL_cleanse(&gctx->gcm, sizeof(gctx->gcm));
   if (gctx->iv != c->iv) {
     OPENSSL_free(gctx->iv);
@@ -479,7 +397,7 @@
       if (gctx->iv_gen == 0 || gctx->key_set == 0) {
         return 0;
       }
-      CRYPTO_gcm128_setiv(&gctx->gcm, &gctx->ks.ks, gctx->iv, gctx->ivlen);
+      CRYPTO_gcm128_init_ctx(&gctx->key, &gctx->gcm, gctx->iv, gctx->ivlen);
       if (arg <= 0 || arg > gctx->ivlen) {
         arg = gctx->ivlen;
       }
@@ -497,7 +415,7 @@
         return 0;
       }
       OPENSSL_memcpy(gctx->iv + gctx->ivlen - arg, ptr, arg);
-      CRYPTO_gcm128_setiv(&gctx->gcm, &gctx->ks.ks, gctx->iv, gctx->ivlen);
+      CRYPTO_gcm128_init_ctx(&gctx->key, &gctx->gcm, gctx->iv, gctx->ivlen);
       gctx->iv_set = 1;
       return 1;
 
@@ -546,31 +464,29 @@
 
   if (in) {
     if (out == NULL) {
-      if (!CRYPTO_gcm128_aad(&gctx->gcm, in, len)) {
+      if (!CRYPTO_gcm128_aad(&gctx->key, &gctx->gcm, in, len)) {
         return -1;
       }
     } else if (ctx->encrypt) {
-      if (!CRYPTO_gcm128_encrypt_ctr32(&gctx->gcm, &gctx->ks.ks, in, out, len,
-                                       gctx->ctr)) {
+      if (!CRYPTO_gcm128_encrypt(&gctx->key, &gctx->gcm, in, out, len)) {
         return -1;
       }
     } else {
-      if (!CRYPTO_gcm128_decrypt_ctr32(&gctx->gcm, &gctx->ks.ks, in, out, len,
-                                       gctx->ctr)) {
+      if (!CRYPTO_gcm128_decrypt(&gctx->key, &gctx->gcm, in, out, len)) {
         return -1;
       }
     }
     return (int)len;
   } else {
     if (!ctx->encrypt) {
-      if (gctx->taglen < 0 ||
-          !CRYPTO_gcm128_finish(&gctx->gcm, ctx->buf, gctx->taglen)) {
+      if (gctx->taglen < 0 || !CRYPTO_gcm128_finish(&gctx->key, &gctx->gcm,
+                                                    ctx->buf, gctx->taglen)) {
         return -1;
       }
       gctx->iv_set = 0;
       return 0;
     }
-    CRYPTO_gcm128_tag(&gctx->gcm, ctx->buf, 16);
+    CRYPTO_gcm128_tag(&gctx->key, &gctx->gcm, ctx->buf, 16);
     gctx->taglen = 16;
     // Don't reuse the IV
     gctx->iv_set = 0;
@@ -860,12 +776,7 @@
 #define EVP_AEAD_AES_GCM_TAG_LEN 16
 
 struct aead_aes_gcm_ctx {
-  union {
-    double align;
-    AES_KEY ks;
-  } ks;
-  GCM128_KEY gcm_key;
-  ctr128_f ctr;
+  GCM128_KEY key;
 };
 
 static int aead_aes_gcm_init_impl(struct aead_aes_gcm_ctx *gcm_ctx,
@@ -897,8 +808,7 @@
     return 0;
   }
 
-  gcm_ctx->ctr =
-      aes_ctr_set_key(&gcm_ctx->ks.ks, &gcm_ctx->gcm_key, NULL, key, key_len);
+  CRYPTO_gcm128_init_aes_key(&gcm_ctx->key, key, key_len);
   *out_tag_len = tag_len;
   return 1;
 }
@@ -944,28 +854,24 @@
     return 0;
   }
 
-  const AES_KEY *key = &gcm_ctx->ks.ks;
-
+  const GCM128_KEY *key = &gcm_ctx->key;
   GCM128_CONTEXT gcm;
-  OPENSSL_memset(&gcm, 0, sizeof(gcm));
-  OPENSSL_memcpy(&gcm.gcm_key, &gcm_ctx->gcm_key, sizeof(gcm.gcm_key));
-  CRYPTO_gcm128_setiv(&gcm, key, nonce, nonce_len);
+  CRYPTO_gcm128_init_ctx(key, &gcm, nonce, nonce_len);
 
-  if (ad_len > 0 && !CRYPTO_gcm128_aad(&gcm, ad, ad_len)) {
+  if (ad_len > 0 && !CRYPTO_gcm128_aad(key, &gcm, ad, ad_len)) {
     return 0;
   }
 
-  if (!CRYPTO_gcm128_encrypt_ctr32(&gcm, key, in, out, in_len, gcm_ctx->ctr)) {
+  if (!CRYPTO_gcm128_encrypt(key, &gcm, in, out, in_len)) {
     return 0;
   }
 
   if (extra_in_len > 0 &&
-      !CRYPTO_gcm128_encrypt_ctr32(&gcm, key, extra_in, out_tag, extra_in_len,
-                                   gcm_ctx->ctr)) {
+      !CRYPTO_gcm128_encrypt(key, &gcm, extra_in, out_tag, extra_in_len)) {
     return 0;
   }
 
-  CRYPTO_gcm128_tag(&gcm, out_tag + extra_in_len, tag_len);
+  CRYPTO_gcm128_tag(key, &gcm, out_tag + extra_in_len, tag_len);
   *out_tag_len = tag_len + extra_in_len;
 
   return 1;
@@ -1001,22 +907,19 @@
     return 0;
   }
 
-  const AES_KEY *key = &gcm_ctx->ks.ks;
-
+  const GCM128_KEY *key = &gcm_ctx->key;
   GCM128_CONTEXT gcm;
-  OPENSSL_memset(&gcm, 0, sizeof(gcm));
-  OPENSSL_memcpy(&gcm.gcm_key, &gcm_ctx->gcm_key, sizeof(gcm.gcm_key));
-  CRYPTO_gcm128_setiv(&gcm, key, nonce, nonce_len);
+  CRYPTO_gcm128_init_ctx(key, &gcm, nonce, nonce_len);
 
-  if (!CRYPTO_gcm128_aad(&gcm, ad, ad_len)) {
+  if (!CRYPTO_gcm128_aad(key, &gcm, ad, ad_len)) {
     return 0;
   }
 
-  if (!CRYPTO_gcm128_decrypt_ctr32(&gcm, key, in, out, in_len, gcm_ctx->ctr)) {
+  if (!CRYPTO_gcm128_decrypt(key, &gcm, in, out, in_len)) {
     return 0;
   }
 
-  CRYPTO_gcm128_tag(&gcm, tag, tag_len);
+  CRYPTO_gcm128_tag(key, &gcm, tag, tag_len);
   if (CRYPTO_memcmp(tag, in_tag, tag_len) != 0) {
     OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_BAD_DECRYPT);
     return 0;
diff --git a/crypto/fipsmodule/cipher/e_aesccm.cc.inc b/crypto/fipsmodule/cipher/e_aesccm.cc.inc
index e1682b6..ede2988 100644
--- a/crypto/fipsmodule/cipher/e_aesccm.cc.inc
+++ b/crypto/fipsmodule/cipher/e_aesccm.cc.inc
@@ -55,6 +55,7 @@
 #include <openssl/mem.h>
 
 #include "../delocate.h"
+#include "../aes/internal.h"
 #include "../modes/internal.h"
 #include "../service_indicator/internal.h"
 #include "internal.h"
diff --git a/crypto/fipsmodule/cipher/internal.h b/crypto/fipsmodule/cipher/internal.h
index e9709f0..6e9e73b 100644
--- a/crypto/fipsmodule/cipher/internal.h
+++ b/crypto/fipsmodule/cipher/internal.h
@@ -148,15 +148,6 @@
   int (*ctrl)(EVP_CIPHER_CTX *, int type, int arg, void *ptr);
 };
 
-// aes_ctr_set_key initialises |*aes_key| using |key_bytes| bytes from |key|,
-// where |key_bytes| must either be 16, 24 or 32. If not NULL, |*out_block| is
-// set to a function that encrypts single blocks. If not NULL, |*gcm_key| is
-// initialised to do GHASH with the given key. It returns a function for
-// optimised CTR-mode.
-ctr128_f aes_ctr_set_key(AES_KEY *aes_key, GCM128_KEY *gcm_key,
-                         block128_f *out_block, const uint8_t *key,
-                         size_t key_bytes);
-
 #if defined(__cplusplus)
 }  // extern C
 #endif
diff --git a/crypto/fipsmodule/modes/gcm.cc.inc b/crypto/fipsmodule/modes/gcm.cc.inc
index cbd1858..10ec1be 100644
--- a/crypto/fipsmodule/modes/gcm.cc.inc
+++ b/crypto/fipsmodule/modes/gcm.cc.inc
@@ -48,13 +48,13 @@
 
 #include <openssl/base.h>
 
-#include <assert.h>
 #include <string.h>
 
 #include <openssl/mem.h>
 
-#include "internal.h"
 #include "../../internal.h"
+#include "../aes/internal.h"
+#include "internal.h"
 
 
 // kSizeTWithoutLower4Bits is a mask that can be used to zero the lower four
@@ -62,9 +62,9 @@
 static const size_t kSizeTWithoutLower4Bits = (size_t) -16;
 
 
-#define GCM_MUL(ctx, Xi) gcm_gmult_nohw((ctx)->Xi, (ctx)->gcm_key.Htable)
-#define GHASH(ctx, in, len) \
-  gcm_ghash_nohw((ctx)->Xi, (ctx)->gcm_key.Htable, in, len)
+#define GCM_MUL(key, ctx, Xi) gcm_gmult_nohw((ctx)->Xi, (key)->Htable)
+#define GHASH(key, ctx, in, len) \
+  gcm_ghash_nohw((ctx)->Xi, (key)->Htable, in, len)
 // GHASH_CHUNK is "stride parameter" missioned to mitigate cache
 // trashing effect. In other words idea is to hash data while it's
 // still in L1 cache after encryption pass...
@@ -126,10 +126,10 @@
 
 #ifdef GCM_FUNCREF
 #undef GCM_MUL
-#define GCM_MUL(ctx, Xi) (*gcm_gmult_p)((ctx)->Xi, (ctx)->gcm_key.Htable)
+#define GCM_MUL(key, ctx, Xi) (*gcm_gmult_p)((ctx)->Xi, (key)->Htable)
 #undef GHASH
-#define GHASH(ctx, in, len) \
-  (*gcm_ghash_p)((ctx)->Xi, (ctx)->gcm_key.Htable, in, len)
+#define GHASH(key, ctx, in, len) \
+  (*gcm_ghash_p)((ctx)->Xi, (key)->Htable, in, len)
 #endif  // GCM_FUNCREF
 
 #if defined(HW_GCM) && defined(OPENSSL_X86_64)
@@ -272,14 +272,16 @@
   *out_hash = gcm_ghash_nohw;
 }
 
-void CRYPTO_gcm128_init_key(GCM128_KEY *gcm_key, const AES_KEY *aes_key,
-                            block128_f block, int block_is_hwaes) {
+void CRYPTO_gcm128_init_aes_key(GCM128_KEY *gcm_key, const uint8_t *key,
+                                size_t key_bytes) {
   OPENSSL_memset(gcm_key, 0, sizeof(*gcm_key));
-  gcm_key->block = block;
+  int is_hwaes;
+  gcm_key->ctr = aes_ctr_set_key(&gcm_key->aes, &is_hwaes, &gcm_key->block, key,
+                                 key_bytes);
 
   uint8_t ghash_key[16];
   OPENSSL_memset(ghash_key, 0, sizeof(ghash_key));
-  (*block)(ghash_key, ghash_key, aes_key);
+  gcm_key->block(ghash_key, ghash_key, &gcm_key->aes);
 
   CRYPTO_ghash_init(&gcm_key->gmult, &gcm_key->ghash, gcm_key->Htable,
                     ghash_key);
@@ -292,22 +294,21 @@
   } else if (gcm_key->ghash == gcm_ghash_vpclmulqdq_avx10_512 &&
              CRYPTO_is_VAES_capable()) {
     gcm_key->impl = gcm_x86_vaes_avx10_512;
-  } else if (gcm_key->ghash == gcm_ghash_avx && block_is_hwaes) {
+  } else if (gcm_key->ghash == gcm_ghash_avx && is_hwaes) {
     gcm_key->impl = gcm_x86_aesni;
   }
 #elif defined(OPENSSL_AARCH64)
-  if (gcm_pmull_capable() && block_is_hwaes) {
+  if (gcm_pmull_capable() && is_hwaes) {
     gcm_key->impl = gcm_arm64_aes;
   }
 #endif
 #endif
 }
 
-void CRYPTO_gcm128_setiv(GCM128_CONTEXT *ctx, const AES_KEY *key,
-                         const uint8_t *iv, size_t len) {
+void CRYPTO_gcm128_init_ctx(const GCM128_KEY *key, GCM128_CONTEXT *ctx,
+                            const uint8_t *iv, size_t iv_len) {
 #ifdef GCM_FUNCREF
-  void (*gcm_gmult_p)(uint8_t Xi[16], const u128 Htable[16]) =
-      ctx->gcm_key.gmult;
+  void (*gcm_gmult_p)(uint8_t Xi[16], const u128 Htable[16]) = key->gmult;
 #endif
 
   OPENSSL_memset(&ctx->Yi, 0, sizeof(ctx->Yi));
@@ -318,24 +319,24 @@
   ctx->mres = 0;
 
   uint32_t ctr;
-  if (len == 12) {
+  if (iv_len == 12) {
     OPENSSL_memcpy(ctx->Yi, iv, 12);
     ctx->Yi[15] = 1;
     ctr = 1;
   } else {
-    uint64_t len0 = len;
+    uint64_t len0 = iv_len;
 
-    while (len >= 16) {
+    while (iv_len >= 16) {
       CRYPTO_xor16(ctx->Yi, ctx->Yi, iv);
-      GCM_MUL(ctx, Yi);
+      GCM_MUL(key, ctx, Yi);
       iv += 16;
-      len -= 16;
+      iv_len -= 16;
     }
-    if (len) {
-      for (size_t i = 0; i < len; ++i) {
+    if (iv_len) {
+      for (size_t i = 0; i < iv_len; ++i) {
         ctx->Yi[i] ^= iv[i];
       }
-      GCM_MUL(ctx, Yi);
+      GCM_MUL(key, ctx, Yi);
     }
 
     uint8_t len_block[16];
@@ -343,21 +344,21 @@
     CRYPTO_store_u64_be(len_block + 8, len0 << 3);
     CRYPTO_xor16(ctx->Yi, ctx->Yi, len_block);
 
-    GCM_MUL(ctx, Yi);
+    GCM_MUL(key, ctx, Yi);
     ctr = CRYPTO_load_u32_be(ctx->Yi + 12);
   }
 
-  (*ctx->gcm_key.block)(ctx->Yi, ctx->EK0, key);
+  key->block(ctx->Yi, ctx->EK0, &key->aes);
   ++ctr;
   CRYPTO_store_u32_be(ctx->Yi + 12, ctr);
 }
 
-int CRYPTO_gcm128_aad(GCM128_CONTEXT *ctx, const uint8_t *aad, size_t len) {
+int CRYPTO_gcm128_aad(const GCM128_KEY *key, GCM128_CONTEXT *ctx,
+                      const uint8_t *aad, size_t aad_len) {
 #ifdef GCM_FUNCREF
-  void (*gcm_gmult_p)(uint8_t Xi[16], const u128 Htable[16]) =
-      ctx->gcm_key.gmult;
+  void (*gcm_gmult_p)(uint8_t Xi[16], const u128 Htable[16]) = key->gmult;
   void (*gcm_ghash_p)(uint8_t Xi[16], const u128 Htable[16], const uint8_t *inp,
-                      size_t len) = ctx->gcm_key.ghash;
+                      size_t len) = key->ghash;
 #endif
 
   if (ctx->len.msg != 0) {
@@ -365,21 +366,21 @@
     return 0;
   }
 
-  uint64_t alen = ctx->len.aad + len;
-  if (alen > (UINT64_C(1) << 61) || (sizeof(len) == 8 && alen < len)) {
+  uint64_t alen = ctx->len.aad + aad_len;
+  if (alen > (UINT64_C(1) << 61) || (sizeof(aad_len) == 8 && alen < aad_len)) {
     return 0;
   }
   ctx->len.aad = alen;
 
   unsigned n = ctx->ares;
   if (n) {
-    while (n && len) {
+    while (n && aad_len) {
       ctx->Xi[n] ^= *(aad++);
-      --len;
+      --aad_len;
       n = (n + 1) % 16;
     }
     if (n == 0) {
-      GCM_MUL(ctx, Xi);
+      GCM_MUL(key, ctx, Xi);
     } else {
       ctx->ares = n;
       return 1;
@@ -387,17 +388,17 @@
   }
 
   // Process a whole number of blocks.
-  size_t len_blocks = len & kSizeTWithoutLower4Bits;
+  size_t len_blocks = aad_len & kSizeTWithoutLower4Bits;
   if (len_blocks != 0) {
-    GHASH(ctx, aad, len_blocks);
+    GHASH(key, ctx, aad, len_blocks);
     aad += len_blocks;
-    len -= len_blocks;
+    aad_len -= len_blocks;
   }
 
   // Process the remainder.
-  if (len != 0) {
-    n = (unsigned int)len;
-    for (size_t i = 0; i < len; ++i) {
+  if (aad_len != 0) {
+    n = (unsigned int)aad_len;
+    for (size_t i = 0; i < aad_len; ++i) {
       ctx->Xi[i] ^= aad[i];
     }
   }
@@ -406,14 +407,12 @@
   return 1;
 }
 
-int CRYPTO_gcm128_encrypt_ctr32(GCM128_CONTEXT *ctx, const AES_KEY *key,
-                                const uint8_t *in, uint8_t *out, size_t len,
-                                ctr128_f stream) {
+int CRYPTO_gcm128_encrypt(const GCM128_KEY *key, GCM128_CONTEXT *ctx,
+                          const uint8_t *in, uint8_t *out, size_t len) {
 #ifdef GCM_FUNCREF
-  void (*gcm_gmult_p)(uint8_t Xi[16], const u128 Htable[16]) =
-      ctx->gcm_key.gmult;
+  void (*gcm_gmult_p)(uint8_t Xi[16], const u128 Htable[16]) = key->gmult;
   void (*gcm_ghash_p)(uint8_t Xi[16], const u128 Htable[16], const uint8_t *inp,
-                      size_t len) = ctx->gcm_key.ghash;
+                      size_t len) = key->ghash;
 #endif
 
   uint64_t mlen = ctx->len.msg + len;
@@ -425,7 +424,7 @@
 
   if (ctx->ares) {
     // First call to encrypt finalizes GHASH(AAD)
-    GCM_MUL(ctx, Xi);
+    GCM_MUL(key, ctx, Xi);
     ctx->ares = 0;
   }
 
@@ -437,7 +436,7 @@
       n = (n + 1) % 16;
     }
     if (n == 0) {
-      GCM_MUL(ctx, Xi);
+      GCM_MUL(key, ctx, Xi);
     } else {
       ctx->mres = n;
       return 1;
@@ -446,11 +445,11 @@
 
 #if defined(HW_GCM)
   // Check |len| to work around a C language bug. See https://crbug.com/1019588.
-  if (ctx->gcm_key.impl != gcm_separate && len > 0) {
+  if (key->impl != gcm_separate && len > 0) {
     // |hw_gcm_encrypt| may not process all the input given to it. It may
     // not process *any* of its input if it is deemed too small.
-    size_t bulk = hw_gcm_encrypt(in, out, len, key, ctx->Yi, ctx->Xi,
-                                 ctx->gcm_key.Htable, ctx->gcm_key.impl);
+    size_t bulk = hw_gcm_encrypt(in, out, len, &key->aes, ctx->Yi, ctx->Xi,
+                                 key->Htable, key->impl);
     in += bulk;
     out += bulk;
     len -= bulk;
@@ -458,29 +457,31 @@
 #endif
 
   uint32_t ctr = CRYPTO_load_u32_be(ctx->Yi + 12);
+  ctr128_f stream = key->ctr;
   while (len >= GHASH_CHUNK) {
-    (*stream)(in, out, GHASH_CHUNK / 16, key, ctx->Yi);
+    (*stream)(in, out, GHASH_CHUNK / 16, &key->aes, ctx->Yi);
     ctr += GHASH_CHUNK / 16;
     CRYPTO_store_u32_be(ctx->Yi + 12, ctr);
-    GHASH(ctx, out, GHASH_CHUNK);
+    GHASH(key, ctx, out, GHASH_CHUNK);
     out += GHASH_CHUNK;
     in += GHASH_CHUNK;
     len -= GHASH_CHUNK;
   }
+
   size_t len_blocks = len & kSizeTWithoutLower4Bits;
   if (len_blocks != 0) {
     size_t j = len_blocks / 16;
-
-    (*stream)(in, out, j, key, ctx->Yi);
-    ctr += (unsigned int)j;
+    (*stream)(in, out, j, &key->aes, ctx->Yi);
+    ctr += (uint32_t)j;
     CRYPTO_store_u32_be(ctx->Yi + 12, ctr);
     in += len_blocks;
     len -= len_blocks;
-    GHASH(ctx, out, len_blocks);
+    GHASH(key, ctx, out, len_blocks);
     out += len_blocks;
   }
+
   if (len) {
-    (*ctx->gcm_key.block)(ctx->Yi, ctx->EKi, key);
+    key->block(ctx->Yi, ctx->EKi, &key->aes);
     ++ctr;
     CRYPTO_store_u32_be(ctx->Yi + 12, ctr);
     while (len--) {
@@ -493,14 +494,12 @@
   return 1;
 }
 
-int CRYPTO_gcm128_decrypt_ctr32(GCM128_CONTEXT *ctx, const AES_KEY *key,
-                                const uint8_t *in, uint8_t *out, size_t len,
-                                ctr128_f stream) {
+int CRYPTO_gcm128_decrypt(const GCM128_KEY *key, GCM128_CONTEXT *ctx,
+                          const uint8_t *in, uint8_t *out, size_t len) {
 #ifdef GCM_FUNCREF
-  void (*gcm_gmult_p)(uint8_t Xi[16], const u128 Htable[16]) =
-      ctx->gcm_key.gmult;
+  void (*gcm_gmult_p)(uint8_t Xi[16], const u128 Htable[16]) = key->gmult;
   void (*gcm_ghash_p)(uint8_t Xi[16], const u128 Htable[16], const uint8_t *inp,
-                      size_t len) = ctx->gcm_key.ghash;
+                      size_t len) = key->ghash;
 #endif
 
   uint64_t mlen = ctx->len.msg + len;
@@ -512,7 +511,7 @@
 
   if (ctx->ares) {
     // First call to decrypt finalizes GHASH(AAD)
-    GCM_MUL(ctx, Xi);
+    GCM_MUL(key, ctx, Xi);
     ctx->ares = 0;
   }
 
@@ -526,7 +525,7 @@
       n = (n + 1) % 16;
     }
     if (n == 0) {
-      GCM_MUL(ctx, Xi);
+      GCM_MUL(key, ctx, Xi);
     } else {
       ctx->mres = n;
       return 1;
@@ -535,11 +534,11 @@
 
 #if defined(HW_GCM)
   // Check |len| to work around a C language bug. See https://crbug.com/1019588.
-  if (ctx->gcm_key.impl != gcm_separate && len > 0) {
+  if (key->impl != gcm_separate && len > 0) {
     // |hw_gcm_decrypt| may not process all the input given to it. It may
     // not process *any* of its input if it is deemed too small.
-    size_t bulk = hw_gcm_decrypt(in, out, len, key, ctx->Yi, ctx->Xi,
-                                 ctx->gcm_key.Htable, ctx->gcm_key.impl);
+    size_t bulk = hw_gcm_decrypt(in, out, len, &key->aes, ctx->Yi, ctx->Xi,
+                                 key->Htable, key->impl);
     in += bulk;
     out += bulk;
     len -= bulk;
@@ -547,29 +546,31 @@
 #endif
 
   uint32_t ctr = CRYPTO_load_u32_be(ctx->Yi + 12);
+  ctr128_f stream = key->ctr;
   while (len >= GHASH_CHUNK) {
-    GHASH(ctx, in, GHASH_CHUNK);
-    (*stream)(in, out, GHASH_CHUNK / 16, key, ctx->Yi);
+    GHASH(key, ctx, in, GHASH_CHUNK);
+    (*stream)(in, out, GHASH_CHUNK / 16, &key->aes, ctx->Yi);
     ctr += GHASH_CHUNK / 16;
     CRYPTO_store_u32_be(ctx->Yi + 12, ctr);
     out += GHASH_CHUNK;
     in += GHASH_CHUNK;
     len -= GHASH_CHUNK;
   }
+
   size_t len_blocks = len & kSizeTWithoutLower4Bits;
   if (len_blocks != 0) {
     size_t j = len_blocks / 16;
-
-    GHASH(ctx, in, len_blocks);
-    (*stream)(in, out, j, key, ctx->Yi);
-    ctr += (unsigned int)j;
+    GHASH(key, ctx, in, len_blocks);
+    (*stream)(in, out, j, &key->aes, ctx->Yi);
+    ctr += (uint32_t)j;
     CRYPTO_store_u32_be(ctx->Yi + 12, ctr);
     out += len_blocks;
     in += len_blocks;
     len -= len_blocks;
   }
+
   if (len) {
-    (*ctx->gcm_key.block)(ctx->Yi, ctx->EKi, key);
+    key->block(ctx->Yi, ctx->EKi, &key->aes);
     ++ctr;
     CRYPTO_store_u32_be(ctx->Yi + 12, ctr);
     while (len--) {
@@ -584,21 +585,21 @@
   return 1;
 }
 
-int CRYPTO_gcm128_finish(GCM128_CONTEXT *ctx, const uint8_t *tag, size_t len) {
+int CRYPTO_gcm128_finish(const GCM128_KEY *key, GCM128_CONTEXT *ctx,
+                         const uint8_t *tag, size_t len) {
 #ifdef GCM_FUNCREF
-  void (*gcm_gmult_p)(uint8_t Xi[16], const u128 Htable[16]) =
-      ctx->gcm_key.gmult;
+  void (*gcm_gmult_p)(uint8_t Xi[16], const u128 Htable[16]) = key->gmult;
 #endif
 
   if (ctx->mres || ctx->ares) {
-    GCM_MUL(ctx, Xi);
+    GCM_MUL(key, ctx, Xi);
   }
 
   uint8_t len_block[16];
   CRYPTO_store_u64_be(len_block, ctx->len.aad << 3);
   CRYPTO_store_u64_be(len_block + 8, ctx->len.msg << 3);
   CRYPTO_xor16(ctx->Xi, ctx->Xi, len_block);
-  GCM_MUL(ctx, Xi);
+  GCM_MUL(key, ctx, Xi);
   CRYPTO_xor16(ctx->Xi, ctx->Xi, ctx->EK0);
 
   if (tag && len <= sizeof(ctx->Xi)) {
@@ -608,8 +609,9 @@
   }
 }
 
-void CRYPTO_gcm128_tag(GCM128_CONTEXT *ctx, unsigned char *tag, size_t len) {
-  CRYPTO_gcm128_finish(ctx, NULL, 0);
+void CRYPTO_gcm128_tag(const GCM128_KEY *key, GCM128_CONTEXT *ctx, uint8_t *tag,
+                       size_t len) {
+  CRYPTO_gcm128_finish(key, ctx, NULL, 0);
   OPENSSL_memcpy(tag, ctx->Xi, len <= sizeof(ctx->Xi) ? len : sizeof(ctx->Xi));
 }
 
diff --git a/crypto/fipsmodule/modes/internal.h b/crypto/fipsmodule/modes/internal.h
index 263b314..1878121 100644
--- a/crypto/fipsmodule/modes/internal.h
+++ b/crypto/fipsmodule/modes/internal.h
@@ -58,24 +58,13 @@
 #include <string.h>
 
 #include "../../internal.h"
+#include "../aes/internal.h"
 
 #if defined(__cplusplus)
 extern "C" {
 #endif
 
 
-// block128_f is the type of an AES block cipher implementation.
-//
-// Unlike upstream OpenSSL, it and the other functions in this file hard-code
-// |AES_KEY|. It is undefined in C to call a function pointer with anything
-// other than the original type. Thus we either must match |block128_f| to the
-// type signature of |AES_encrypt| and friends or pass in |void*| wrapper
-// functions.
-//
-// These functions are called exclusively with AES, so we use the former.
-typedef void (*block128_f)(const uint8_t in[16], uint8_t out[16],
-                           const AES_KEY *key);
-
 OPENSSL_INLINE void CRYPTO_xor16(uint8_t out[16], const uint8_t a[16],
                                  const uint8_t b[16]) {
   // TODO(davidben): Ideally we'd leave this to the compiler, which could use
@@ -93,10 +82,6 @@
 
 // CTR.
 
-// ctr128_f is the type of a function that performs CTR-mode encryption.
-typedef void (*ctr128_f)(const uint8_t *in, uint8_t *out, size_t blocks,
-                         const AES_KEY *key, const uint8_t ivec[16]);
-
 // CRYPTO_ctr128_encrypt_ctr32 encrypts (or decrypts, it's the same in CTR mode)
 // |len| bytes from |in| to |out| using |block| in counter mode. There's no
 // requirement that |len| be a multiple of any value and any partial blocks are
@@ -148,7 +133,9 @@
   u128 Htable[16];
   gmult_func gmult;
   ghash_func ghash;
+  AES_KEY aes;
 
+  ctr128_f ctr;
   block128_f block;
   enum gcm_impl_t impl;
 } GCM128_KEY;
@@ -165,11 +152,6 @@
     uint64_t msg;
   } len;
   uint8_t Xi[16];
-
-  // |gcm_*_ssse3| require |Htable| to be 16-byte-aligned.
-  // TODO(crbug.com/boringssl/604): Revisit this.
-  alignas(16) GCM128_KEY gcm_key;
-
   unsigned mres, ares;
 } GCM128_CONTEXT;
 
@@ -185,46 +167,44 @@
 void CRYPTO_ghash_init(gmult_func *out_mult, ghash_func *out_hash,
                        u128 out_table[16], const uint8_t gcm_key[16]);
 
-// CRYPTO_gcm128_init_key initialises |gcm_key| to use |block| (typically AES)
-// with the given key. |block_is_hwaes| is one if |block| is |aes_hw_encrypt|.
-void CRYPTO_gcm128_init_key(GCM128_KEY *gcm_key, const AES_KEY *key,
-                            block128_f block, int block_is_hwaes);
+// CRYPTO_gcm128_init_aes_key initializes |gcm_key| with the AES key |key|.
+void CRYPTO_gcm128_init_aes_key(GCM128_KEY *gcm_key, const uint8_t *key,
+                                size_t key_bytes);
 
-// CRYPTO_gcm128_setiv sets the IV (nonce) for |ctx|. The |key| must be the
-// same key that was passed to |CRYPTO_gcm128_init|.
-void CRYPTO_gcm128_setiv(GCM128_CONTEXT *ctx, const AES_KEY *key,
-                         const uint8_t *iv, size_t iv_len);
+// CRYPTO_gcm128_init_ctx initializes |ctx| to encrypt or decrypt with |key|
+// and |iv|.
+void CRYPTO_gcm128_init_ctx(const GCM128_KEY *key, GCM128_CONTEXT *ctx,
+                            const uint8_t *iv, size_t iv_len);
 
-// CRYPTO_gcm128_aad sets the authenticated data for an instance of GCM.
-// This must be called before and data is encrypted. It returns one on success
+// CRYPTO_gcm128_aad adds to the authenticated data for an instance of GCM.
+// This must be called before any data is encrypted. |key| must be the same
+// value that was passed to |CRYPTO_gcm128_init_ctx|. It returns one on success
 // and zero otherwise.
-int CRYPTO_gcm128_aad(GCM128_CONTEXT *ctx, const uint8_t *aad, size_t len);
+int CRYPTO_gcm128_aad(const GCM128_KEY *key, GCM128_CONTEXT *ctx,
+                      const uint8_t *aad, size_t aad_len);
 
-// CRYPTO_gcm128_encrypt_ctr32 encrypts |len| bytes from |in| to |out| using
-// a CTR function that only handles the bottom 32 bits of the nonce, like
-// |CRYPTO_ctr128_encrypt_ctr32|. The |key| must be the same key that was
-// passed to |CRYPTO_gcm128_init|. It returns one on success and zero
-// otherwise.
-int CRYPTO_gcm128_encrypt_ctr32(GCM128_CONTEXT *ctx, const AES_KEY *key,
-                                const uint8_t *in, uint8_t *out, size_t len,
-                                ctr128_f stream);
+// CRYPTO_gcm128_encrypt encrypts |len| bytes from |in| to |out|. |key| must be
+// the same value that was passed to |CRYPTO_gcm128_init_ctx|. It returns one on
+// success and zero otherwise.
+int CRYPTO_gcm128_encrypt(const GCM128_KEY *key, GCM128_CONTEXT *ctx,
+                          const uint8_t *in, uint8_t *out, size_t len);
 
-// CRYPTO_gcm128_decrypt_ctr32 decrypts |len| bytes from |in| to |out| using
-// a CTR function that only handles the bottom 32 bits of the nonce, like
-// |CRYPTO_ctr128_encrypt_ctr32|. The |key| must be the same key that was
-// passed to |CRYPTO_gcm128_init|. It returns one on success and zero
-// otherwise.
-int CRYPTO_gcm128_decrypt_ctr32(GCM128_CONTEXT *ctx, const AES_KEY *key,
-                                const uint8_t *in, uint8_t *out, size_t len,
-                                ctr128_f stream);
+// CRYPTO_gcm128_decrypt decrypts |len| bytes from |in| to |out|. |key| must be
+// the same value that was passed to |CRYPTO_gcm128_init_ctx|. It returns one on
+// success and zero otherwise.
+int CRYPTO_gcm128_decrypt(const GCM128_KEY *key, GCM128_CONTEXT *ctx,
+                          const uint8_t *in, uint8_t *out, size_t len);
 
 // CRYPTO_gcm128_finish calculates the authenticator and compares it against
-// |len| bytes of |tag|. It returns one on success and zero otherwise.
-int CRYPTO_gcm128_finish(GCM128_CONTEXT *ctx, const uint8_t *tag, size_t len);
+// |len| bytes of |tag|. |key| must be the same value that was passed to
+// |CRYPTO_gcm128_init_ctx|. It returns one on success and zero otherwise.
+int CRYPTO_gcm128_finish(const GCM128_KEY *key, GCM128_CONTEXT *ctx,
+                         const uint8_t *tag, size_t len);
 
 // CRYPTO_gcm128_tag calculates the authenticator and copies it into |tag|.
-// The minimum of |len| and 16 bytes are copied into |tag|.
-void CRYPTO_gcm128_tag(GCM128_CONTEXT *ctx, uint8_t *tag, size_t len);
+// The minimum of |len| and 16 bytes are copied into |tag|. |key| must be the
+// same value that was passed to |CRYPTO_gcm128_init_ctx|.
+void CRYPTO_gcm128_tag(const GCM128_KEY *key, GCM128_CONTEXT *ctx, uint8_t *tag,
+                       size_t len);
 
 
 // GCM assembly.
diff --git a/crypto/fipsmodule/rand/ctrdrbg.cc.inc b/crypto/fipsmodule/rand/ctrdrbg.cc.inc
index c60eb8d..0f10a21 100644
--- a/crypto/fipsmodule/rand/ctrdrbg.cc.inc
+++ b/crypto/fipsmodule/rand/ctrdrbg.cc.inc
@@ -18,7 +18,7 @@
 
 #include <openssl/mem.h>
 
-#include "../cipher/internal.h"
+#include "../aes/internal.h"
 #include "../service_indicator/internal.h"
 #include "internal.h"