acvp: add AES-KW support.

Change-Id: I8cfa1f525a029a015db2f41e4502e9b5332ba102
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/43164
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/util/fipstools/acvp/acvptool/subprocess/subprocess.go b/util/fipstools/acvp/acvptool/subprocess/subprocess.go
index a9fbdfc..a3b81ae 100644
--- a/util/fipstools/acvp/acvptool/subprocess/subprocess.go
+++ b/util/fipstools/acvp/acvptool/subprocess/subprocess.go
@@ -80,6 +80,7 @@
 		"ACVP-AES-CBC":  &blockCipher{"AES-CBC", 16, true, true},
 		"ACVP-AES-CTR":  &blockCipher{"AES-CTR", 16, false, true},
 		"ACVP-AES-GCM":  &aead{"AES-GCM"},
+		"ACVP-AES-KW":   &aead{"AES-KW"},
 		"HMAC-SHA-1":    &hmacPrimitive{"HMAC-SHA-1", 20},
 		"HMAC-SHA2-224": &hmacPrimitive{"HMAC-SHA2-224", 28},
 		"HMAC-SHA2-256": &hmacPrimitive{"HMAC-SHA2-256", 32},
diff --git a/util/fipstools/acvp/modulewrapper/modulewrapper.cc b/util/fipstools/acvp/modulewrapper/modulewrapper.cc
index 2e9598b..d4a2207 100644
--- a/util/fipstools/acvp/modulewrapper/modulewrapper.cc
+++ b/util/fipstools/acvp/modulewrapper/modulewrapper.cc
@@ -17,6 +17,7 @@
 
 #include <assert.h>
 #include <errno.h>
+#include <limits.h>
 #include <string.h>
 #include <sys/uio.h>
 #include <unistd.h>
@@ -205,6 +206,21 @@
         "ivGen": "external"
       },
       {
+        "algorithm": "ACVP-AES-KW",
+        "revision": "1.0",
+        "direction": [
+            "encrypt",
+            "decrypt"
+        ],
+        "kwCipher": [
+            "cipher"
+        ],
+        "keyLen": [
+            128, 192, 256
+        ],
+        "payloadLen": [{"min": 128, "max": 1024, "increment": 64}]
+      },
+      {
         "algorithm": "HMAC-SHA-1",
         "revision": "1.0",
         "keyLen": [{
@@ -506,18 +522,79 @@
 
   std::vector<uint8_t> out(ciphertext.size());
   size_t out_len;
-  uint8_t success[1] = {0};
+  uint8_t success_flag[1] = {0};
 
   if (!EVP_AEAD_CTX_open(ctx.get(), out.data(), &out_len, out.size(),
                          nonce.data(), nonce.size(), ciphertext.data(),
                          ciphertext.size(), ad.data(), ad.size())) {
-    return WriteReply(STDOUT_FILENO, Span<const uint8_t>(success),
+    return WriteReply(STDOUT_FILENO, Span<const uint8_t>(success_flag),
                       Span<const uint8_t>());
   }
 
   out.resize(out_len);
-  success[0] = 1;
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(success),
+  success_flag[0] = 1;
+  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(success_flag),
+                    Span<const uint8_t>(out));
+}
+
+static bool AESKeyWrapSetup(AES_KEY *out, bool decrypt, Span<const uint8_t> key,
+                            Span<const uint8_t> input) {
+  if ((decrypt ? AES_set_decrypt_key : AES_set_encrypt_key)(
+          key.data(), key.size() * 8, out) != 0) {
+    fprintf(stderr, "Invalid AES key length for AES-KW: %u\n",
+            static_cast<unsigned>(key.size()));
+    return false;
+  }
+  if (input.size() % 8) {
+    fprintf(stderr, "Invalid AES-KW input length: %u\n",
+            static_cast<unsigned>(input.size()));
+    return false;
+  }
+
+  return true;
+}
+
+static bool AESKeyWrapSeal(const Span<const uint8_t> args[]) {
+  Span<const uint8_t> key = args[1];
+  Span<const uint8_t> plaintext = args[2];
+
+  AES_KEY aes;
+  if (!AESKeyWrapSetup(&aes, /*decrypt=*/false, key, plaintext) ||
+      plaintext.size() > INT_MAX - 8) {
+    return false;
+  }
+
+  std::vector<uint8_t> out(plaintext.size() + 8);
+  if (AES_wrap_key(&aes, /*iv=*/nullptr, out.data(), plaintext.data(),
+                   plaintext.size()) != static_cast<int>(out.size())) {
+    fprintf(stderr, "AES-KW failed\n");
+    return false;
+  }
+
+  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(out));
+}
+
+static bool AESKeyWrapOpen(const Span<const uint8_t> args[]) {
+  Span<const uint8_t> key = args[1];
+  Span<const uint8_t> ciphertext = args[2];
+
+  AES_KEY aes;
+  if (!AESKeyWrapSetup(&aes, /*decrypt=*/true, key, ciphertext) ||
+      ciphertext.size() < 8 ||
+      ciphertext.size() > INT_MAX) {
+    return false;
+  }
+
+  std::vector<uint8_t> out(ciphertext.size() - 8);
+  uint8_t success_flag[1] = {0};
+  if (AES_unwrap_key(&aes, /*iv=*/nullptr, out.data(), ciphertext.data(),
+                     ciphertext.size()) != static_cast<int>(out.size())) {
+    return WriteReply(STDOUT_FILENO, Span<const uint8_t>(success_flag),
+                      Span<const uint8_t>());
+  }
+
+  success_flag[0] = 1;
+  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(success_flag),
                     Span<const uint8_t>(out));
 }
 
@@ -770,6 +847,8 @@
     {"AES-CTR/decrypt", 3, AES_CTR},
     {"AES-GCM/seal", 5, AESGCMSeal},
     {"AES-GCM/open", 5, AESGCMOpen},
+    {"AES-KW/seal", 5, AESKeyWrapSeal},
+    {"AES-KW/open", 5, AESKeyWrapOpen},
     {"HMAC-SHA-1", 2, HMAC<EVP_sha1>},
     {"HMAC-SHA2-224", 2, HMAC<EVP_sha224>},
     {"HMAC-SHA2-256", 2, HMAC<EVP_sha256>},