acvp: split ACVP modulewrapper for reuse by Trusty

Trusty requires its own trusted app to implement the ACVP modulewrapper
functionality for validation. Separate the frontend from the generic
functions that implement each algorithm.

Change-Id: I86802b66c627ce4f5b5ddd54555a386e8e993eed
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/45604
Reviewed-by: Adam Langley <agl@google.com>
Commit-Queue: Adam Langley <agl@google.com>
diff --git a/util/fipstools/acvp/modulewrapper/CMakeLists.txt b/util/fipstools/acvp/modulewrapper/CMakeLists.txt
index 8bee5cd..267f82c 100644
--- a/util/fipstools/acvp/modulewrapper/CMakeLists.txt
+++ b/util/fipstools/acvp/modulewrapper/CMakeLists.txt
@@ -4,6 +4,7 @@
   add_executable(
     modulewrapper
 
+    main.cc
     modulewrapper.cc
   )
 
diff --git a/util/fipstools/acvp/modulewrapper/main.cc b/util/fipstools/acvp/modulewrapper/main.cc
new file mode 100644
index 0000000..283c340
--- /dev/null
+++ b/util/fipstools/acvp/modulewrapper/main.cc
@@ -0,0 +1,49 @@
+/* Copyright (c) 2021, Google Inc.
+ *
+ * Permission to use, copy, modify, and/or distribute this software for any
+ * purpose with or without fee is hereby granted, provided that the above
+ * copyright notice and this permission notice appear in all copies.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
+ * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
+ * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
+ * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
+
+#include <stdio.h>
+#include <string>
+#include <unistd.h>
+
+#include <openssl/span.h>
+
+#include "modulewrapper.h"
+
+
+int main() {
+  std::unique_ptr<bssl::acvp::RequestBuffer> buffer =
+      bssl::acvp::RequestBuffer::New();
+  const bssl::acvp::ReplyCallback write_reply = std::bind(
+      bssl::acvp::WriteReplyToFd, STDOUT_FILENO, std::placeholders::_1);
+
+  for (;;) {
+    const bssl::Span<const bssl::Span<const uint8_t>> args =
+        ParseArgsFromFd(STDIN_FILENO, buffer.get());
+    if (args.empty()) {
+      return 1;
+    }
+
+    const bssl::acvp::Handler handler = bssl::acvp::FindHandler(args);
+    if (!handler) {
+      return 2;
+    }
+
+    if (!handler(args.subspan(1).data(), write_reply)) {
+      const std::string name(reinterpret_cast<const char *>(args[0].data()),
+                             args[0].size());
+      fprintf(stderr, "\'%s\' operation failed.\n", name.c_str());
+      return 3;
+    }
+  }
+};
diff --git a/util/fipstools/acvp/modulewrapper/modulewrapper.cc b/util/fipstools/acvp/modulewrapper/modulewrapper.cc
index a08e670..ddf91c6 100644
--- a/util/fipstools/acvp/modulewrapper/modulewrapper.cc
+++ b/util/fipstools/acvp/modulewrapper/modulewrapper.cc
@@ -45,15 +45,35 @@
 #include "../../../../crypto/fipsmodule/ec/internal.h"
 #include "../../../../crypto/fipsmodule/rand/internal.h"
 #include "../../../../crypto/fipsmodule/tls/internal.h"
+#include "modulewrapper.h"
 
-static constexpr size_t kMaxArgs = 8;
-static constexpr size_t kMaxArgLength = (1 << 20);
-static constexpr size_t kMaxNameLength = 30;
 
-static_assert((kMaxArgs - 1 * kMaxArgLength) + kMaxNameLength > (1 << 30),
-              "Argument limits permit excessive messages");
+namespace bssl {
+namespace acvp {
 
-using namespace bssl;
+#if defined(OPENSSL_TRUSTY)
+#include <trusty_log.h>
+#define LOG_ERROR(...) TLOGE(__VA_ARGS__)
+#else
+#define LOG_ERROR(...) fprintf(stderr, __VA_ARGS__)
+#endif  // OPENSSL_TRUSTY
+
+constexpr size_t kMaxArgLength = (1 << 20);
+
+RequestBuffer::~RequestBuffer() = default;
+
+class RequestBufferImpl : public RequestBuffer {
+ public:
+  ~RequestBufferImpl() = default;
+
+  std::vector<uint8_t> buf;
+  Span<const uint8_t> args[kMaxArgs];
+};
+
+// static
+std::unique_ptr<RequestBuffer> RequestBuffer::New() {
+  return std::unique_ptr<RequestBuffer>(new RequestBufferImpl);
+}
 
 static bool ReadAll(int fd, void *in_data, size_t data_len) {
   uint8_t *data = reinterpret_cast<uint8_t *>(in_data);
@@ -75,9 +95,74 @@
   return true;
 }
 
-template <typename... Args>
-static bool WriteReply(int fd, Args... args) {
-  std::vector<Span<const uint8_t>> spans = {args...};
+Span<const Span<const uint8_t>> ParseArgsFromFd(int fd,
+                                                RequestBuffer *in_buffer) {
+  RequestBufferImpl *buffer = reinterpret_cast<RequestBufferImpl *>(in_buffer);
+  uint32_t nums[1 + kMaxArgs];
+  const Span<const Span<const uint8_t>> empty_span;
+
+  if (!ReadAll(fd, nums, sizeof(uint32_t) * 2)) {
+    return empty_span;
+  }
+
+  const size_t num_args = nums[0];
+  if (num_args == 0) {
+    LOG_ERROR("Invalid, zero-argument operation requested.\n");
+    return empty_span;
+  } else if (num_args > kMaxArgs) {
+    LOG_ERROR("Operation requested with %zu args, but %zu is the limit.\n",
+              num_args, kMaxArgs);
+    return empty_span;
+  }
+
+  if (num_args > 1 &&
+      !ReadAll(fd, &nums[2], sizeof(uint32_t) * (num_args - 1))) {
+    return empty_span;
+  }
+
+  size_t need = 0;
+  for (size_t i = 0; i < num_args; i++) {
+    const size_t arg_length = nums[i + 1];
+    if (i == 0 && arg_length > kMaxNameLength) {
+      LOG_ERROR("Operation with name of length %zu exceeded limit of %zu.\n",
+                arg_length, kMaxNameLength);
+      return empty_span;
+    } else if (arg_length > kMaxArgLength) {
+      LOG_ERROR(
+          "Operation with argument of length %zu exceeded limit of %zu.\n",
+          arg_length, kMaxArgLength);
+      return empty_span;
+    }
+
+    // This static_assert confirms that the following addition doesn't
+    // overflow.
+    static_assert((kMaxArgs - 1 * kMaxArgLength) + kMaxNameLength > (1 << 30),
+                  "Argument limits permit excessive messages");
+    need += arg_length;
+  }
+
+  if (need > buffer->buf.size()) {
+    size_t alloced = need + (need >> 1);
+    if (alloced < need) {
+      abort();
+    }
+    buffer->buf.resize(alloced);
+  }
+
+  if (!ReadAll(fd, buffer->buf.data(), need)) {
+    return empty_span;
+  }
+
+  size_t offset = 0;
+  for (size_t i = 0; i < num_args; i++) {
+    buffer->args[i] = Span<const uint8_t>(&buffer->buf[offset], nums[i + 1]);
+    offset += nums[i + 1];
+  }
+
+  return Span<const Span<const uint8_t>>(buffer->args, num_args);
+}
+
+bool WriteReplyToFd(int fd, const std::vector<Span<const uint8_t>> &spans) {
   if (spans.empty() || spans.size() > kMaxArgs) {
     abort();
   }
@@ -136,7 +221,7 @@
   return true;
 }
 
-static bool GetConfig(const Span<const uint8_t> args[]) {
+static bool GetConfig(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   static constexpr char kConfig[] =
       R"([
       {
@@ -741,35 +826,33 @@
         ]
       }
     ])";
-  return WriteReply(
-      STDOUT_FILENO,
-      Span<const uint8_t>(reinterpret_cast<const uint8_t *>(kConfig),
-                          sizeof(kConfig) - 1));
+  return write_reply({Span<const uint8_t>(
+      reinterpret_cast<const uint8_t *>(kConfig), sizeof(kConfig) - 1)});
 }
 
 template <uint8_t *(*OneShotHash)(const uint8_t *, size_t, uint8_t *),
           size_t DigestLength>
-static bool Hash(const Span<const uint8_t> args[]) {
+static bool Hash(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   uint8_t digest[DigestLength];
   OneShotHash(args[0].data(), args[0].size(), digest);
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(digest));
+  return write_reply({Span<const uint8_t>(digest)});
 }
 
 static uint32_t GetIterations(const Span<const uint8_t> iterations_bytes) {
   uint32_t iterations;
   if (iterations_bytes.size() != sizeof(iterations)) {
-    fprintf(stderr,
-            "Expected %u-byte input for number of iterations, but found %u "
-            "bytes.\n",
-            static_cast<unsigned>(sizeof(iterations)),
-            static_cast<unsigned>(iterations_bytes.size()));
+    LOG_ERROR(
+        "Expected %u-byte input for number of iterations, but found %u "
+        "bytes.\n",
+        static_cast<unsigned>(sizeof(iterations)),
+        static_cast<unsigned>(iterations_bytes.size()));
     abort();
   }
 
   memcpy(&iterations, iterations_bytes.data(), sizeof(iterations));
   if (iterations == 0 || iterations == UINT32_MAX) {
-    fprintf(stderr, "Invalid number of iterations: %x.\n",
-            static_cast<unsigned>(iterations));
+    LOG_ERROR("Invalid number of iterations: %x.\n",
+         static_cast<unsigned>(iterations));
     abort();
   }
 
@@ -778,7 +861,7 @@
 
 template <int (*SetKey)(const uint8_t *key, unsigned bits, AES_KEY *out),
           void (*Block)(const uint8_t *in, uint8_t *out, const AES_KEY *key)>
-static bool AES(const Span<const uint8_t> args[]) {
+static bool AES(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   AES_KEY key;
   if (SetKey(args[0].data(), args[0].size() * 8, &key) != 0) {
     return false;
@@ -800,13 +883,13 @@
     }
   }
 
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(result),
-                    Span<const uint8_t>(prev_result));
+  return write_reply(
+      {Span<const uint8_t>(result), Span<const uint8_t>(prev_result)});
 }
 
 template <int (*SetKey)(const uint8_t *key, unsigned bits, AES_KEY *out),
           int Direction>
-static bool AES_CBC(const Span<const uint8_t> args[]) {
+static bool AES_CBC(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   AES_KEY key;
   if (SetKey(args[0].data(), args[0].size() * 8, &key) != 0) {
     return false;
@@ -849,15 +932,15 @@
     }
   }
 
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(result),
-                    Span<const uint8_t>(prev_result));
+  return write_reply(
+      {Span<const uint8_t>(result), Span<const uint8_t>(prev_result)});
 }
 
-static bool AES_CTR(const Span<const uint8_t> args[]) {
+static bool AES_CTR(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   static const uint32_t kOneIteration = 1;
   if (args[3].size() != sizeof(kOneIteration) ||
       memcmp(args[3].data(), &kOneIteration, sizeof(kOneIteration))) {
-    fprintf(stderr, "Only a single iteration supported with AES-CTR\n");
+    LOG_ERROR("Only a single iteration supported with AES-CTR\n");
     return false;
   }
 
@@ -871,7 +954,7 @@
   uint8_t iv[AES_BLOCK_SIZE];
   memcpy(iv, args[2].data(), AES_BLOCK_SIZE);
   if (GetIterations(args[3]) != 1) {
-    fprintf(stderr, "Multiple iterations of AES-CTR is not supported.\n");
+    LOG_ERROR("Multiple iterations of AES-CTR is not supported.\n");
     return false;
   }
 
@@ -881,15 +964,15 @@
   uint8_t ecount_buf[AES_BLOCK_SIZE];
   AES_ctr128_encrypt(args[1].data(), out.data(), args[1].size(), &key, iv,
                      ecount_buf, &num);
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(out));
+  return write_reply({Span<const uint8_t>(out)});
 }
 
 static bool AESGCMSetup(EVP_AEAD_CTX *ctx, Span<const uint8_t> tag_len_span,
                         Span<const uint8_t> key) {
   uint32_t tag_len_32;
   if (tag_len_span.size() != sizeof(tag_len_32)) {
-    fprintf(stderr, "Tag size value is %u bytes, not an uint32_t\n",
-            static_cast<unsigned>(tag_len_span.size()));
+    LOG_ERROR("Tag size value is %u bytes, not an uint32_t\n",
+              static_cast<unsigned>(tag_len_span.size()));
     return false;
   }
   memcpy(&tag_len_32, tag_len_span.data(), sizeof(tag_len_32));
@@ -906,15 +989,14 @@
       aead = EVP_aead_aes_256_gcm();
       break;
     default:
-      fprintf(stderr, "Bad AES-GCM key length %u\n",
-              static_cast<unsigned>(key.size()));
+      LOG_ERROR("Bad AES-GCM key length %u\n", static_cast<unsigned>(key.size()));
       return false;
   }
 
   if (!EVP_AEAD_CTX_init(ctx, aead, key.data(), key.size(), tag_len_32,
                          nullptr)) {
-    fprintf(stderr, "Failed to setup AES-GCM with tag length %u\n",
-            static_cast<unsigned>(tag_len_32));
+    LOG_ERROR("Failed to setup AES-GCM with tag length %u\n",
+              static_cast<unsigned>(tag_len_32));
     return false;
   }
 
@@ -925,28 +1007,27 @@
                         Span<const uint8_t> key) {
   uint32_t tag_len_32;
   if (tag_len_span.size() != sizeof(tag_len_32)) {
-    fprintf(stderr, "Tag size value is %u bytes, not an uint32_t\n",
-            static_cast<unsigned>(tag_len_span.size()));
+    LOG_ERROR("Tag size value is %u bytes, not an uint32_t\n",
+              static_cast<unsigned>(tag_len_span.size()));
     return false;
   }
   memcpy(&tag_len_32, tag_len_span.data(), sizeof(tag_len_32));
   if (tag_len_32 != 4) {
-    fprintf(stderr, "AES-CCM only supports 4-byte tags, but %u was requested\n",
-            static_cast<unsigned>(tag_len_32));
+    LOG_ERROR("AES-CCM only supports 4-byte tags, but %u was requested\n",
+              static_cast<unsigned>(tag_len_32));
     return false;
   }
 
   if (key.size() != 16) {
-    fprintf(stderr,
-            "AES-CCM only supports 128-bit keys, but %u bits were given\n",
-            static_cast<unsigned>(key.size() * 8));
+    LOG_ERROR("AES-CCM only supports 128-bit keys, but %u bits were given\n",
+              static_cast<unsigned>(key.size() * 8));
     return false;
   }
 
   if (!EVP_AEAD_CTX_init(ctx, EVP_aead_aes_128_ccm_bluetooth(), key.data(),
                          key.size(), tag_len_32, nullptr)) {
-    fprintf(stderr, "Failed to setup AES-CCM with tag length %u\n",
-            static_cast<unsigned>(tag_len_32));
+    LOG_ERROR("Failed to setup AES-CCM with tag length %u\n",
+              static_cast<unsigned>(tag_len_32));
     return false;
   }
 
@@ -955,7 +1036,7 @@
 
 template <bool (*SetupFunc)(EVP_AEAD_CTX *ctx, Span<const uint8_t> tag_len_span,
                             Span<const uint8_t> key)>
-static bool AEADSeal(const Span<const uint8_t> args[]) {
+static bool AEADSeal(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   Span<const uint8_t> tag_len_span = args[0];
   Span<const uint8_t> key = args[1];
   Span<const uint8_t> plaintext = args[2];
@@ -980,12 +1061,12 @@
   }
 
   out.resize(out_len);
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(out));
+  return write_reply({Span<const uint8_t>(out)});
 }
 
 template <bool (*SetupFunc)(EVP_AEAD_CTX *ctx, Span<const uint8_t> tag_len_span,
                             Span<const uint8_t> key)>
-static bool AEADOpen(const Span<const uint8_t> args[]) {
+static bool AEADOpen(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   Span<const uint8_t> tag_len_span = args[0];
   Span<const uint8_t> key = args[1];
   Span<const uint8_t> ciphertext = args[2];
@@ -1004,22 +1085,22 @@
   if (!EVP_AEAD_CTX_open(ctx.get(), out.data(), &out_len, out.size(),
                          nonce.data(), nonce.size(), ciphertext.data(),
                          ciphertext.size(), ad.data(), ad.size())) {
-    return WriteReply(STDOUT_FILENO, Span<const uint8_t>(success_flag),
-                      Span<const uint8_t>());
+    return write_reply(
+        {Span<const uint8_t>(success_flag), Span<const uint8_t>()});
   }
 
   out.resize(out_len);
   success_flag[0] = 1;
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(success_flag),
-                    Span<const uint8_t>(out));
+  return write_reply(
+      {Span<const uint8_t>(success_flag), Span<const uint8_t>(out)});
 }
 
 static bool AESPaddedKeyWrapSetup(AES_KEY *out, bool decrypt,
                                   Span<const uint8_t> key) {
   if ((decrypt ? AES_set_decrypt_key : AES_set_encrypt_key)(
           key.data(), key.size() * 8, out) != 0) {
-    fprintf(stderr, "Invalid AES key length for AES-KW(P): %u\n",
-            static_cast<unsigned>(key.size()));
+    LOG_ERROR("Invalid AES key length for AES-KW(P): %u\n",
+              static_cast<unsigned>(key.size()));
     return false;
   }
   return true;
@@ -1032,15 +1113,15 @@
   }
 
   if (input.size() % 8) {
-    fprintf(stderr, "Invalid AES-KW input length: %u\n",
-            static_cast<unsigned>(input.size()));
+    LOG_ERROR("Invalid AES-KW input length: %u\n",
+              static_cast<unsigned>(input.size()));
     return false;
   }
 
   return true;
 }
 
-static bool AESKeyWrapSeal(const Span<const uint8_t> args[]) {
+static bool AESKeyWrapSeal(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   Span<const uint8_t> key = args[1];
   Span<const uint8_t> plaintext = args[2];
 
@@ -1053,21 +1134,20 @@
   std::vector<uint8_t> out(plaintext.size() + 8);
   if (AES_wrap_key(&aes, /*iv=*/nullptr, out.data(), plaintext.data(),
                    plaintext.size()) != static_cast<int>(out.size())) {
-    fprintf(stderr, "AES-KW failed\n");
+    LOG_ERROR("AES-KW failed\n");
     return false;
   }
 
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(out));
+  return write_reply({Span<const uint8_t>(out)});
 }
 
-static bool AESKeyWrapOpen(const Span<const uint8_t> args[]) {
+static bool AESKeyWrapOpen(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   Span<const uint8_t> key = args[1];
   Span<const uint8_t> ciphertext = args[2];
 
   AES_KEY aes;
   if (!AESKeyWrapSetup(&aes, /*decrypt=*/true, key, ciphertext) ||
-      ciphertext.size() < 8 ||
-      ciphertext.size() > INT_MAX) {
+      ciphertext.size() < 8 || ciphertext.size() > INT_MAX) {
     return false;
   }
 
@@ -1075,16 +1155,16 @@
   uint8_t success_flag[1] = {0};
   if (AES_unwrap_key(&aes, /*iv=*/nullptr, out.data(), ciphertext.data(),
                      ciphertext.size()) != static_cast<int>(out.size())) {
-    return WriteReply(STDOUT_FILENO, Span<const uint8_t>(success_flag),
-                      Span<const uint8_t>());
+    return write_reply(
+        {Span<const uint8_t>(success_flag), Span<const uint8_t>()});
   }
 
   success_flag[0] = 1;
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(success_flag),
-                    Span<const uint8_t>(out));
+  return write_reply(
+      {Span<const uint8_t>(success_flag), Span<const uint8_t>(out)});
 }
 
-static bool AESPaddedKeyWrapSeal(const Span<const uint8_t> args[]) {
+static bool AESPaddedKeyWrapSeal(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   Span<const uint8_t> key = args[1];
   Span<const uint8_t> plaintext = args[2];
 
@@ -1098,15 +1178,15 @@
   size_t out_len;
   if (!AES_wrap_key_padded(&aes, out.data(), &out_len, out.size(),
                            plaintext.data(), plaintext.size())) {
-    fprintf(stderr, "AES-KWP failed\n");
+    LOG_ERROR("AES-KWP failed\n");
     return false;
   }
 
   out.resize(out_len);
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(out));
+  return write_reply({Span<const uint8_t>(out)});
 }
 
-static bool AESPaddedKeyWrapOpen(const Span<const uint8_t> args[]) {
+static bool AESPaddedKeyWrapOpen(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   Span<const uint8_t> key = args[1];
   Span<const uint8_t> ciphertext = args[2];
 
@@ -1121,23 +1201,23 @@
   uint8_t success_flag[1] = {0};
   if (!AES_unwrap_key_padded(&aes, out.data(), &out_len, out.size(),
                              ciphertext.data(), ciphertext.size())) {
-    return WriteReply(STDOUT_FILENO, Span<const uint8_t>(success_flag),
-                      Span<const uint8_t>());
+    return write_reply(
+        {Span<const uint8_t>(success_flag), Span<const uint8_t>()});
   }
 
   success_flag[0] = 1;
   out.resize(out_len);
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(success_flag),
-                    Span<const uint8_t>(out));
+  return write_reply(
+      {Span<const uint8_t>(success_flag), Span<const uint8_t>(out)});
 }
 
 template <bool Encrypt>
-static bool TDES(const Span<const uint8_t> args[]) {
+static bool TDES(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   const EVP_CIPHER *cipher = EVP_des_ede3();
 
   if (args[0].size() != 24) {
-    fprintf(stderr, "Bad key length %u for 3DES.\n",
-            static_cast<unsigned>(args[0].size()));
+    LOG_ERROR("Bad key length %u for 3DES.\n",
+              static_cast<unsigned>(args[0].size()));
     return false;
   }
   bssl::ScopedEVP_CIPHER_CTX ctx;
@@ -1148,8 +1228,8 @@
   }
 
   if (args[1].size() % 8) {
-    fprintf(stderr, "Bad input length %u for 3DES.\n",
-            static_cast<unsigned>(args[1].size()));
+    LOG_ERROR("Bad input length %u for 3DES.\n",
+              static_cast<unsigned>(args[1].size()));
     return false;
   }
   std::vector<uint8_t> result(args[1].begin(), args[1].end());
@@ -1172,31 +1252,31 @@
     }
   }
 
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(result),
-                    Span<const uint8_t>(prev_result),
-                    Span<const uint8_t>(prev_prev_result));
+  return write_reply({Span<const uint8_t>(result),
+                      Span<const uint8_t>(prev_result),
+                      Span<const uint8_t>(prev_prev_result)});
 }
 
 template <bool Encrypt>
-static bool TDES_CBC(const Span<const uint8_t> args[]) {
+static bool TDES_CBC(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   const EVP_CIPHER *cipher = EVP_des_ede3_cbc();
 
   if (args[0].size() != 24) {
-    fprintf(stderr, "Bad key length %u for 3DES.\n",
-            static_cast<unsigned>(args[0].size()));
+    LOG_ERROR("Bad key length %u for 3DES.\n",
+              static_cast<unsigned>(args[0].size()));
     return false;
   }
 
   if (args[1].size() % 8 || args[1].size() == 0) {
-    fprintf(stderr, "Bad input length %u for 3DES.\n",
-            static_cast<unsigned>(args[1].size()));
+    LOG_ERROR("Bad input length %u for 3DES.\n",
+              static_cast<unsigned>(args[1].size()));
     return false;
   }
   std::vector<uint8_t> input(args[1].begin(), args[1].end());
 
   if (args[2].size() != EVP_CIPHER_iv_length(cipher)) {
-    fprintf(stderr, "Bad IV length %u for 3DES.\n",
-            static_cast<unsigned>(args[2].size()));
+    LOG_ERROR("Bad IV length %u for 3DES.\n",
+              static_cast<unsigned>(args[2].size()));
     return false;
   }
   std::vector<uint8_t> iv(args[2].begin(), args[2].end());
@@ -1238,13 +1318,13 @@
     }
   }
 
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(result),
-                    Span<const uint8_t>(prev_result),
-                    Span<const uint8_t>(prev_prev_result));
+  return write_reply({Span<const uint8_t>(result),
+                     Span<const uint8_t>(prev_result),
+                     Span<const uint8_t>(prev_prev_result)});
 }
 
 template <const EVP_MD *HashFunc()>
-static bool HMAC(const Span<const uint8_t> args[]) {
+static bool HMAC(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   const EVP_MD *const md = HashFunc();
   uint8_t digest[EVP_MAX_MD_SIZE];
   unsigned digest_len;
@@ -1252,10 +1332,10 @@
              digest, &digest_len) == nullptr) {
     return false;
   }
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(digest, digest_len));
+  return write_reply({Span<const uint8_t>(digest, digest_len)});
 }
 
-static bool DRBG(const Span<const uint8_t> args[]) {
+static bool DRBG(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   const auto out_len_bytes = args[0];
   const auto entropy = args[1];
   const auto personalisation = args[2];
@@ -1286,7 +1366,7 @@
     return false;
   }
 
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(out));
+  return write_reply({Span<const uint8_t>(out)});
 }
 
 static bool StringEq(Span<const uint8_t> a, const char *b) {
@@ -1334,7 +1414,7 @@
   return std::make_pair(std::move(x_bytes), std::move(y_bytes));
 }
 
-static bool ECDSAKeyGen(const Span<const uint8_t> args[]) {
+static bool ECDSAKeyGen(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   bssl::UniquePtr<EC_KEY> key = ECKeyFromName(args[0]);
   if (!key || !EC_KEY_generate_key_fips(key.get())) {
     return false;
@@ -1344,9 +1424,9 @@
   std::vector<uint8_t> d_bytes =
       BIGNUMBytes(EC_KEY_get0_private_key(key.get()));
 
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(d_bytes),
-                    Span<const uint8_t>(pub_key.first),
-                    Span<const uint8_t>(pub_key.second));
+  return write_reply({Span<const uint8_t>(d_bytes),
+                      Span<const uint8_t>(pub_key.first),
+                      Span<const uint8_t>(pub_key.second)});
 }
 
 static bssl::UniquePtr<BIGNUM> BytesToBIGNUM(Span<const uint8_t> bytes) {
@@ -1355,7 +1435,7 @@
   return bn;
 }
 
-static bool ECDSAKeyVer(const Span<const uint8_t> args[]) {
+static bool ECDSAKeyVer(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   bssl::UniquePtr<EC_KEY> key = ECKeyFromName(args[0]);
   if (!key) {
     return false;
@@ -1376,7 +1456,7 @@
     reply[0] = 1;
   }
 
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(reply));
+  return write_reply({Span<const uint8_t>(reply)});
 }
 
 static const EVP_MD *HashFromName(Span<const uint8_t> name) {
@@ -1393,7 +1473,7 @@
   }
 }
 
-static bool ECDSASigGen(const Span<const uint8_t> args[]) {
+static bool ECDSASigGen(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   bssl::UniquePtr<EC_KEY> key = ECKeyFromName(args[0]);
   bssl::UniquePtr<BIGNUM> d = BytesToBIGNUM(args[1]);
   const EVP_MD *hash = HashFromName(args[2]);
@@ -1414,11 +1494,11 @@
   std::vector<uint8_t> r_bytes(BIGNUMBytes(sig->r));
   std::vector<uint8_t> s_bytes(BIGNUMBytes(sig->s));
 
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(r_bytes),
-                    Span<const uint8_t>(s_bytes));
+  return write_reply(
+      {Span<const uint8_t>(r_bytes), Span<const uint8_t>(s_bytes)});
 }
 
-static bool ECDSASigVer(const Span<const uint8_t> args[]) {
+static bool ECDSASigVer(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   bssl::UniquePtr<EC_KEY> key = ECKeyFromName(args[0]);
   const EVP_MD *hash = HashFromName(args[1]);
   auto msg = args[2];
@@ -1451,10 +1531,10 @@
     reply[0] = 1;
   }
 
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(reply));
+  return write_reply({Span<const uint8_t>(reply)});
 }
 
-static bool CMAC_AES(const Span<const uint8_t> args[]) {
+static bool CMAC_AES(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   uint8_t mac[16];
   if (!AES_CMAC(mac, args[1].data(), args[1].size(), args[2].data(),
                 args[2].size())) {
@@ -1470,10 +1550,10 @@
     return false;
   }
 
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(mac, mac_len));
+  return write_reply({Span<const uint8_t>(mac, mac_len)});
 }
 
-static bool CMAC_AESVerify(const Span<const uint8_t> args[]) {
+static bool CMAC_AESVerify(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   // This function is just for testing since libcrypto doesn't do the
   // verification itself. The regcap doesn't advertise "ver" support.
   uint8_t mac[16];
@@ -1484,7 +1564,7 @@
   }
 
   const uint8_t ok = (OPENSSL_memcmp(mac, args[2].data(), args[2].size()) == 0);
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(&ok, sizeof(ok)));
+  return write_reply({Span<const uint8_t>(&ok, sizeof(ok))});
 }
 
 static std::map<unsigned, bssl::UniquePtr<RSA>>& CachedRSAKeys() {
@@ -1509,7 +1589,7 @@
   return ret;
 }
 
-static bool RSAKeyGen(const Span<const uint8_t> args[]) {
+static bool RSAKeyGen(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   uint32_t bits;
   if (args[0].size() != sizeof(bits)) {
     return false;
@@ -1518,8 +1598,7 @@
 
   bssl::UniquePtr<RSA> key(RSA_new());
   if (!RSA_generate_key_fips(key.get(), bits, nullptr)) {
-    fprintf(stderr, "RSA_generate_key_fips failed for modulus length %u.\n",
-            bits);
+    LOG_ERROR("RSA_generate_key_fips failed for modulus length %u.\n", bits);
     return false;
   }
 
@@ -1527,8 +1606,8 @@
   RSA_get0_key(key.get(), &n, &e, &d);
   RSA_get0_factors(key.get(), &p, &q);
 
-  if (!WriteReply(STDOUT_FILENO, BIGNUMBytes(e), BIGNUMBytes(p), BIGNUMBytes(q),
-                  BIGNUMBytes(n), BIGNUMBytes(d))) {
+  if (!write_reply({BIGNUMBytes(e), BIGNUMBytes(p), BIGNUMBytes(q),
+                    BIGNUMBytes(n), BIGNUMBytes(d)})) {
     return false;
   }
 
@@ -1536,8 +1615,8 @@
   return true;
 }
 
-template<const EVP_MD *(MDFunc)(), bool UsePSS>
-static bool RSASigGen(const Span<const uint8_t> args[]) {
+template <const EVP_MD *(MDFunc)(), bool UsePSS>
+static bool RSASigGen(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   uint32_t bits;
   if (args[0].size() != sizeof(bits)) {
     return false;
@@ -1571,12 +1650,12 @@
 
   sig.resize(sig_len);
 
-  return WriteReply(STDOUT_FILENO, BIGNUMBytes(RSA_get0_n(key)),
-                    BIGNUMBytes(RSA_get0_e(key)), sig);
+  return write_reply(
+      {BIGNUMBytes(RSA_get0_n(key)), BIGNUMBytes(RSA_get0_e(key)), sig});
 }
 
-template<const EVP_MD *(MDFunc)(), bool UsePSS>
-static bool RSASigVer(const Span<const uint8_t> args[]) {
+template <const EVP_MD *(MDFunc)(), bool UsePSS>
+static bool RSASigVer(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   const Span<const uint8_t> n_bytes = args[0];
   const Span<const uint8_t> e_bytes = args[1];
   const Span<const uint8_t> msg = args[2];
@@ -1608,11 +1687,11 @@
   }
   ERR_clear_error();
 
-  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(&ok, 1));
+  return write_reply({Span<const uint8_t>(&ok, 1)});
 }
 
-template<const EVP_MD *(MDFunc)()>
-static bool TLSKDF(const Span<const uint8_t> args[]) {
+template <const EVP_MD *(MDFunc)()>
+static bool TLSKDF(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   const Span<const uint8_t> out_len_bytes = args[0];
   const Span<const uint8_t> secret = args[1];
   const Span<const uint8_t> label = args[2];
@@ -1634,11 +1713,11 @@
     return 0;
   }
 
-  return WriteReply(STDOUT_FILENO, out);
+  return write_reply({out});
 }
 
 template <int Nid>
-static bool ECDH(const Span<const uint8_t> args[]) {
+static bool ECDH(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   bssl::UniquePtr<BIGNUM> their_x(BytesToBIGNUM(args[0]));
   bssl::UniquePtr<BIGNUM> their_y(BytesToBIGNUM(args[1]));
   const Span<const uint8_t> private_key = args[2];
@@ -1650,14 +1729,14 @@
   bssl::UniquePtr<EC_POINT> their_point(EC_POINT_new(group));
   if (!EC_POINT_set_affine_coordinates_GFp(
           group, their_point.get(), their_x.get(), their_y.get(), ctx.get())) {
-    fprintf(stderr, "Invalid peer point for ECDH.\n");
+    LOG_ERROR("Invalid peer point for ECDH.\n");
     return false;
   }
 
   if (!private_key.empty()) {
     bssl::UniquePtr<BIGNUM> our_k(BytesToBIGNUM(private_key));
     if (!EC_KEY_set_private_key(ec_key.get(), our_k.get())) {
-      fprintf(stderr, "EC_KEY_set_private_key failed.\n");
+      LOG_ERROR("EC_KEY_set_private_key failed.\n");
       return false;
     }
 
@@ -1665,11 +1744,11 @@
     if (!EC_POINT_mul(group, our_pub.get(), our_k.get(), nullptr, nullptr,
                       ctx.get()) ||
         !EC_KEY_set_public_key(ec_key.get(), our_pub.get())) {
-      fprintf(stderr, "Calculating public key failed.\n");
+      LOG_ERROR("Calculating public key failed.\n");
       return false;
     }
   } else if (!EC_KEY_generate_key_fips(ec_key.get())) {
-    fprintf(stderr, "EC_KEY_generate_key_fips failed.\n");
+    LOG_ERROR("EC_KEY_generate_key_fips failed.\n");
     return false;
   }
 
@@ -1680,10 +1759,10 @@
       ECDH_compute_key(output.data(), output.size(), their_point.get(),
                        ec_key.get(), /*kdf=*/nullptr);
   if (out_len < 0) {
-    fprintf(stderr, "ECDH_compute_key failed.\n");
+    LOG_ERROR("ECDH_compute_key failed.\n");
     return false;
   } else if (static_cast<size_t>(out_len) == output.size()) {
-    fprintf(stderr, "ECDH_compute_key output may have been truncated.\n");
+    LOG_ERROR("ECDH_compute_key output may have been truncated.\n");
     return false;
   }
   output.resize(static_cast<size_t>(out_len));
@@ -1693,15 +1772,14 @@
   bssl::UniquePtr<BIGNUM> y(BN_new());
   if (!EC_POINT_get_affine_coordinates_GFp(group, pub, x.get(), y.get(),
                                            ctx.get())) {
-    fprintf(stderr, "EC_POINT_get_affine_coordinates_GFp failed.\n");
+    LOG_ERROR("EC_POINT_get_affine_coordinates_GFp failed.\n");
     return false;
   }
 
-  return WriteReply(STDOUT_FILENO, BIGNUMBytes(x.get()), BIGNUMBytes(y.get()),
-                    output);
+  return write_reply({BIGNUMBytes(x.get()), BIGNUMBytes(y.get()), output});
 }
 
-static bool FFDH(const Span<const uint8_t> args[]) {
+static bool FFDH(const Span<const uint8_t> args[], ReplyCallback write_reply) {
   bssl::UniquePtr<BIGNUM> p(BytesToBIGNUM(args[0]));
   bssl::UniquePtr<BIGNUM> q(BytesToBIGNUM(args[1]));
   bssl::UniquePtr<BIGNUM> g(BytesToBIGNUM(args[2]));
@@ -1711,7 +1789,7 @@
 
   bssl::UniquePtr<DH> dh(DH_new());
   if (!DH_set0_pqg(dh.get(), p.get(), q.get(), g.get())) {
-    fprintf(stderr, "DH_set0_pqg failed.\n");
+    LOG_ERROR("DH_set0_pqg failed.\n");
     return 0;
   }
 
@@ -1725,7 +1803,7 @@
     bssl::UniquePtr<BIGNUM> public_key(BytesToBIGNUM(public_key_span));
 
     if (!DH_set0_key(dh.get(), public_key.get(), private_key.get())) {
-      fprintf(stderr, "DH_set0_key failed.\n");
+      LOG_ERROR("DH_set0_key failed.\n");
       return 0;
     }
 
@@ -1733,24 +1811,24 @@
     public_key.release();
     private_key.release();
   } else if (!DH_generate_key(dh.get())) {
-    fprintf(stderr, "DH_generate_key failed.\n");
+    LOG_ERROR("DH_generate_key failed.\n");
     return false;
   }
 
   std::vector<uint8_t> z(DH_size(dh.get()));
   if (DH_compute_key_padded(z.data(), their_pub.get(), dh.get()) !=
       static_cast<int>(z.size())) {
-    fprintf(stderr, "DH_compute_key_hashed failed.\n");
+    LOG_ERROR("DH_compute_key_hashed failed.\n");
     return false;
   }
 
-  return WriteReply(STDOUT_FILENO, BIGNUMBytes(DH_get0_pub_key(dh.get())), z);
+  return write_reply({BIGNUMBytes(DH_get0_pub_key(dh.get())), z});
 }
 
 static constexpr struct {
-  const char name[kMaxNameLength + 1];
-  uint8_t expected_args;
-  bool (*handler)(const Span<const uint8_t>[]);
+  char name[kMaxNameLength + 1];
+  uint8_t num_expected_args;
+  bool (*handler)(const Span<const uint8_t> args[], ReplyCallback write_reply);
 } kFunctions[] = {
     {"getConfig", 0, GetConfig},
     {"SHA-1", 1, Hash<SHA1, SHA_DIGEST_LENGTH>},
@@ -1821,98 +1899,26 @@
     {"FFDH", 6, FFDH},
 };
 
-int main() {
-  uint32_t nums[1 + kMaxArgs];
-  std::unique_ptr<uint8_t[]> buf;
-  size_t buf_len = 0;
-  Span<const uint8_t> args[kMaxArgs];
-
-  for (;;) {
-    if (!ReadAll(STDIN_FILENO, nums, sizeof(uint32_t) * 2)) {
-      return 1;
-    }
-
-    const size_t num_args = nums[0];
-    if (num_args == 0) {
-      fprintf(stderr, "Invalid, zero-argument operation requested.\n");
-      return 2;
-    } else if (num_args > kMaxArgs) {
-      fprintf(stderr,
-              "Operation requested with %zu args, but %zu is the limit.\n",
-              num_args, kMaxArgs);
-      return 2;
-    }
-
-    if (num_args > 1 &&
-        !ReadAll(STDIN_FILENO, &nums[2], sizeof(uint32_t) * (num_args - 1))) {
-      return 1;
-    }
-
-    size_t need = 0;
-    for (size_t i = 0; i < num_args; i++) {
-      const size_t arg_length = nums[i + 1];
-      if (i == 0 && arg_length > kMaxNameLength) {
-        fprintf(stderr,
-                "Operation with name of length %zu exceeded limit of %zu.\n",
-                arg_length, kMaxNameLength);
-        return 2;
-      } else if (arg_length > kMaxArgLength) {
-        fprintf(
-            stderr,
-            "Operation with argument of length %zu exceeded limit of %zu.\n",
-            arg_length, kMaxArgLength);
-        return 2;
+Handler FindHandler(Span<const Span<const uint8_t>> args) {
+  const bssl::Span<const uint8_t> algorithm = args[0];
+  for (const auto &func : kFunctions) {
+    if (algorithm.size() == strlen(func.name) &&
+        memcmp(algorithm.data(), func.name, algorithm.size()) == 0) {
+      if (args.size() - 1 != func.num_expected_args) {
+        LOG_ERROR("\'%s\' operation received %zu arguments but expected %u.\n",
+                  func.name, args.size() - 1, func.num_expected_args);
+        return nullptr;
       }
 
-      // static_assert around kMaxArgs etc enforces that this doesn't overflow.
-      need += arg_length;
-    }
-
-    if (need > buf_len) {
-      size_t alloced = need + (need >> 1);
-      if (alloced < need) {
-        abort();
-      }
-      buf.reset(new uint8_t[alloced]);
-      buf_len = alloced;
-    }
-
-    if (!ReadAll(STDIN_FILENO, buf.get(), need)) {
-      return 1;
-    }
-
-    size_t offset = 0;
-    for (size_t i = 0; i < num_args; i++) {
-      args[i] = Span<const uint8_t>(&buf[offset], nums[i + 1]);
-      offset += nums[i + 1];
-    }
-
-    bool found = false;
-    for (const auto &func : kFunctions) {
-      if (args[0].size() == strlen(func.name) &&
-          memcmp(args[0].data(), func.name, args[0].size()) == 0) {
-        if (num_args - 1 != func.expected_args) {
-          fprintf(stderr,
-                  "\'%s\' operation received %zu arguments but expected %u.\n",
-                  func.name, num_args - 1, func.expected_args);
-          return 2;
-        }
-
-        if (!func.handler(&args[1])) {
-          fprintf(stderr, "\'%s\' operation failed.\n", func.name);
-          return 4;
-        }
-
-        found = true;
-        break;
-      }
-    }
-
-    if (!found) {
-      const std::string name(reinterpret_cast<const char *>(args[0].data()),
-                             args[0].size());
-      fprintf(stderr, "Unknown operation: %s\n", name.c_str());
-      return 3;
+      return func.handler;
     }
   }
+
+  const std::string name(reinterpret_cast<const char *>(algorithm.data()),
+                         algorithm.size());
+  LOG_ERROR("Unknown operation: %s\n", name.c_str());
+  return nullptr;
 }
+
+}  // namespace acvp
+}  // namespace bssl
diff --git a/util/fipstools/acvp/modulewrapper/modulewrapper.h b/util/fipstools/acvp/modulewrapper/modulewrapper.h
new file mode 100644
index 0000000..0472800
--- /dev/null
+++ b/util/fipstools/acvp/modulewrapper/modulewrapper.h
@@ -0,0 +1,69 @@
+/* Copyright (c) 2021, Google Inc.
+ *
+ * Permission to use, copy, modify, and/or distribute this software for any
+ * purpose with or without fee is hereby granted, provided that the above
+ * copyright notice and this permission notice appear in all copies.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
+ * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
+ * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
+ * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
+
+#include <openssl/base.h>
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include <openssl/span.h>
+
+
+namespace bssl {
+namespace acvp {
+
+// kMaxArgs is the maximum number of arguments (including the function name)
+// that an ACVP request can contain.
+constexpr size_t kMaxArgs = 8;
+// kMaxNameLength is the maximum length of a function name in an ACVP request.
+constexpr size_t kMaxNameLength = 30;
+
+// RequestBuffer holds various buffers needed for parsing an ACVP request. It
+// can be reused between requests.
+class RequestBuffer {
+ public:
+  virtual ~RequestBuffer();
+
+  static std::unique_ptr<RequestBuffer> New();
+};
+
+// ParseArgsFromFd returns a span of arguments, the first of which is the name
+// of the requested function, from |fd|. The return values point into |buffer|
+// and so must not be used after |buffer| has been freed or reused for a
+// subsequent call. It returns an empty span on error, because std::optional
+// is still too new.
+Span<const Span<const uint8_t>> ParseArgsFromFd(int fd, RequestBuffer *buffer);
+
+// WriteReplyToFd writes a reply to the given file descriptor.
+bool WriteReplyToFd(int fd, const std::vector<Span<const uint8_t>> &spans);
+
+// ReplyCallback is the type of a callback that writes a reply to an ACVP
+// request.
+typedef std::function<bool(const std::vector<Span<const uint8_t>> &)>
+    ReplyCallback;
+
+// Handler is the type of a function that handles a specific ACVP request. If
+// successful it will call |write_reply| with the response arguments and return
+// |write_reply|'s return value. Otherwise it will return false. The given args
+// must not include the name at the beginning.
+typedef bool (*Handler)(const Span<const uint8_t> args[],
+                        ReplyCallback write_reply);
+
+// FindHandler returns a |Handler| that can process the given arguments, or logs
+// a reason and returns |nullptr| if none is found.
+Handler FindHandler(Span<const Span<const uint8_t>> args);
+
+}  // namespace acvp
+}  // namespace bssl