acvptool: add support for ECDSA

Change-Id: I0c643de16d5215a20bb21e8523efccd5555098eb
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/38764
Reviewed-by: Gurleen Grewal <gurleengrewal@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/util/fipstools/acvp/acvptool/subprocess/ecdsa.go b/util/fipstools/acvp/acvptool/subprocess/ecdsa.go
new file mode 100644
index 0000000..5b5b1d1
--- /dev/null
+++ b/util/fipstools/acvp/acvptool/subprocess/ecdsa.go
@@ -0,0 +1,228 @@
+// Copyright (c) 2019, Google Inc.
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
+// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
+// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
+// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+package subprocess
+
+import (
+	"bytes"
+	"encoding/hex"
+	"encoding/json"
+	"fmt"
+)
+
+// The following structures reflect the JSON of ACVP hash tests. See
+// https://usnistgov.github.io/ACVP/artifacts/acvp_sub_ecdsa.html#test_vectors
+
+type ecdsaTestVectorSet struct {
+	Groups []ecdsaTestGroup `json:"testGroups"`
+	Mode   string           `json:"mode"`
+}
+
+type ecdsaTestGroup struct {
+	ID                   uint64 `json:"tgId"`
+	Curve                string `json:"curve"`
+	SecretGenerationMode string `json:"secretGenerationMode,omitempty"`
+	HashAlgo             string `json:"hashAlg,omitEmpty"`
+	ComponentTest        bool   `json:"componentTest"`
+	Tests                []struct {
+		ID     uint64 `json:"tcId"`
+		QxHex  string `json:"qx,omitempty"`
+		QyHex  string `json:"qy,omitempty"`
+		RHex   string `json:"r,omitempty"`
+		SHex   string `json:"s,omitempty"`
+		MsgHex string `json:"message,omitempty"`
+	} `json:"tests"`
+}
+
+type ecdsaTestGroupResponse struct {
+	ID    uint64              `json:"tgId"`
+	Tests []ecdsaTestResponse `json:"tests"`
+	QxHex string              `json:"qx,omitempty"`
+	QyHex string              `json:"qy,omitempty"`
+}
+
+type ecdsaTestResponse struct {
+	ID     uint64 `json:"tcId"`
+	DHex   string `json:"d,omitempty"`
+	QxHex  string `json:"qx,omitempty"`
+	QyHex  string `json:"qy,omitempty"`
+	RHex   string `json:"r,omitempty"`
+	SHex   string `json:"s,omitempty"`
+	Passed *bool  `json:"testPassed,omitempty"` // using pointer so value is not omitted when it is false
+}
+
+// ecdsa implements an ACVP algorithm by making requests to the
+// subprocess to generate and verify ECDSA keys and signatures.
+type ecdsa struct {
+	// algo is the ACVP name for this algorithm and also the command name
+	// given to the subprocess to hash with this hash function.
+	algo   string
+	curves map[string]bool // supported curve names
+	m      *Subprocess
+}
+
+func (e *ecdsa) Process(vectorSet []byte) (interface{}, error) {
+	var parsed ecdsaTestVectorSet
+	if err := json.Unmarshal(vectorSet, &parsed); err != nil {
+		return nil, err
+	}
+
+	var ret []ecdsaTestGroupResponse
+	// See
+	// https://usnistgov.github.io/ACVP/artifacts/draft-celi-acvp-sha-00.html#rfc.section.3
+	// for details about the tests.
+	for _, group := range parsed.Groups {
+		if _, ok := e.curves[group.Curve]; !ok {
+			return nil, fmt.Errorf("curve %q in test group %d not supported", group.Curve, group.ID)
+		}
+
+		response := ecdsaTestGroupResponse{
+			ID: group.ID,
+		}
+		var sigGenPrivateKey []byte
+
+		for _, test := range group.Tests {
+			var testResp ecdsaTestResponse
+
+			switch parsed.Mode {
+			case "keyGen":
+				if group.SecretGenerationMode != "testing candidates" {
+					return nil, fmt.Errorf("invalid secret generation mode in test group %d: %q", group.ID, group.SecretGenerationMode)
+				}
+				result, err := e.m.transact(e.algo+"/"+"keyGen", 3, []byte(group.Curve))
+				if err != nil {
+					return nil, fmt.Errorf("key generation failed for test case %d/%d: %s", group.ID, test.ID, err)
+				}
+				testResp.DHex = hex.EncodeToString(result[0])
+				testResp.QxHex = hex.EncodeToString(result[1])
+				testResp.QyHex = hex.EncodeToString(result[2])
+
+			case "keyVer":
+				qx, err := hex.DecodeString(test.QxHex)
+				if err != nil {
+					return nil, fmt.Errorf("failed to decode qx in test case %d/%d: %s", group.ID, test.ID, err)
+				}
+				qy, err := hex.DecodeString(test.QyHex)
+				if err != nil {
+					return nil, fmt.Errorf("failed to decode qy in test case %d/%d: %s", group.ID, test.ID, err)
+				}
+				result, err := e.m.transact(e.algo+"/"+"keyVer", 1, []byte(group.Curve), qx, qy)
+				if err != nil {
+					return nil, fmt.Errorf("key verification failed for test case %d/%d: %s", group.ID, test.ID, err)
+				}
+				// result[0] should be a single byte: zero if false, one if true
+				switch {
+				case bytes.Equal(result[0], []byte{00}):
+					f := false
+					testResp.Passed = &f
+				case bytes.Equal(result[0], []byte{01}):
+					t := true
+					testResp.Passed = &t
+				default:
+					return nil, fmt.Errorf("key verification returned unexpected result: %q", result[0])
+				}
+
+			case "sigGen":
+				p := e.m.primitives[group.HashAlgo]
+				h, ok := p.(*hashPrimitive)
+				if !ok {
+					return nil, fmt.Errorf("unsupported hash algorithm %q in test group %d", group.HashAlgo, group.ID)
+				}
+
+				if len(sigGenPrivateKey) == 0 {
+					// Ask the subprocess to generate a key for this test group.
+					result, err := e.m.transact(e.algo+"/"+"keyGen", 3, []byte(group.Curve))
+					if err != nil {
+						return nil, fmt.Errorf("key generation failed for test case %d/%d: %s", group.ID, test.ID, err)
+					}
+
+					sigGenPrivateKey = result[0]
+					response.QxHex = hex.EncodeToString(result[1])
+					response.QyHex = hex.EncodeToString(result[2])
+				}
+
+				msg, err := hex.DecodeString(test.MsgHex)
+				if err != nil {
+					return nil, fmt.Errorf("failed to decode message hex in test case %d/%d: %s", group.ID, test.ID, err)
+				}
+				op := e.algo + "/" + "sigGen"
+				if group.ComponentTest {
+					if len(msg) != h.size {
+						return nil, fmt.Errorf("test case %d/%d contains message %q of length %d, but expected length %d", group.ID, test.ID, test.MsgHex, len(msg), h.size)
+					}
+					op += "/componentTest"
+				}
+				result, err := e.m.transact(op, 2, []byte(group.Curve), sigGenPrivateKey, []byte(group.HashAlgo), msg)
+				if err != nil {
+					return nil, fmt.Errorf("signature generation failed for test case %d/%d: %s", group.ID, test.ID, err)
+				}
+				testResp.RHex = hex.EncodeToString(result[0])
+				testResp.SHex = hex.EncodeToString(result[1])
+
+			case "sigVer":
+				p := e.m.primitives[group.HashAlgo]
+				_, ok := p.(*hashPrimitive)
+				if !ok {
+					return nil, fmt.Errorf("unsupported hash algorithm %q in test group %d", group.HashAlgo, group.ID)
+				}
+
+				msg, err := hex.DecodeString(test.MsgHex)
+				if err != nil {
+					return nil, fmt.Errorf("failed to decode message hex in test case %d/%d: %s", group.ID, test.ID, err)
+				}
+				qx, err := hex.DecodeString(test.QxHex)
+				if err != nil {
+					return nil, fmt.Errorf("failed to decode qx in test case %d/%d: %s", group.ID, test.ID, err)
+				}
+				qy, err := hex.DecodeString(test.QyHex)
+				if err != nil {
+					return nil, fmt.Errorf("failed to decode qy in test case %d/%d: %s", group.ID, test.ID, err)
+				}
+				r, err := hex.DecodeString(test.RHex)
+				if err != nil {
+					return nil, fmt.Errorf("failed to decode R in test case %d/%d: %s", group.ID, test.ID, err)
+				}
+				s, err := hex.DecodeString(test.SHex)
+				if err != nil {
+					return nil, fmt.Errorf("failed to decode S in test case %d/%d: %s", group.ID, test.ID, err)
+				}
+				result, err := e.m.transact(e.algo+"/"+"sigVer", 1, []byte(group.Curve), []byte(group.HashAlgo), msg, qx, qy, r, s)
+				if err != nil {
+					return nil, fmt.Errorf("signature verification failed for test case %d/%d: %s", group.ID, test.ID, err)
+				}
+				// result[0] should be a single byte: zero if false, one if true
+				switch {
+				case bytes.Equal(result[0], []byte{00}):
+					f := false
+					testResp.Passed = &f
+				case bytes.Equal(result[0], []byte{01}):
+					t := true
+					testResp.Passed = &t
+				default:
+					return nil, fmt.Errorf("signature verification returned unexpected result: %q", result[0])
+				}
+
+			default:
+				return nil, fmt.Errorf("invalid mode %q in ECDSA vector set", parsed.Mode)
+			}
+
+			testResp.ID = test.ID
+			response.Tests = append(response.Tests, testResp)
+		}
+
+		ret = append(ret, response)
+	}
+
+	return ret, nil
+}
diff --git a/util/fipstools/acvp/acvptool/subprocess/subprocess.go b/util/fipstools/acvp/acvptool/subprocess/subprocess.go
index 460041c..6f3e6ad 100644
--- a/util/fipstools/acvp/acvptool/subprocess/subprocess.go
+++ b/util/fipstools/acvp/acvptool/subprocess/subprocess.go
@@ -77,6 +77,7 @@
 		"HMAC-SHA2-512": &hmacPrimitive{"HMAC-SHA2-512", 64, m},
 		"ctrDRBG":       &drbg{"ctrDRBG", map[string]bool{"AES-128": true, "AES-192": true, "AES-256": true}, m},
 		"hmacDRBG":      &drbg{"hmacDRBG", map[string]bool{"SHA-1": true, "SHA2-224": true, "SHA2-256": true, "SHA2-384": true, "SHA2-512": true}, m},
+		"ECDSA":         &ecdsa{"ECDSA", map[string]bool{"P-224": true, "P-256": true, "P-384": true, "P-521": true}, m},
 	}
 
 	return m
diff --git a/util/fipstools/acvp/modulewrapper/modulewrapper.cc b/util/fipstools/acvp/modulewrapper/modulewrapper.cc
index f373b9e..109e3b0 100644
--- a/util/fipstools/acvp/modulewrapper/modulewrapper.cc
+++ b/util/fipstools/acvp/modulewrapper/modulewrapper.cc
@@ -23,8 +23,13 @@
 #include <cstdarg>
 
 #include <openssl/aes.h>
+#include <openssl/bn.h>
 #include <openssl/digest.h>
+#include <openssl/ec.h>
+#include <openssl/ec_key.h>
+#include <openssl/ecdsa.h>
 #include <openssl/hmac.h>
+#include <openssl/obj.h>
 #include <openssl/sha.h>
 #include <openssl/span.h>
 
@@ -231,8 +236,71 @@
           ],
           "returnedBitsLen": 2048
         }]
+      },
+      {
+        "algorithm": "ECDSA",
+        "mode": "keyGen",
+        "revision": "1.0",
+        "curve": [
+          "P-224",
+          "P-256",
+          "P-384",
+          "P-521"
+        ],
+        "secretGenerationMode": [
+          "testing candidates"
+        ]
+      },
+      {
+        "algorithm": "ECDSA",
+        "mode": "keyVer",
+        "revision": "1.0",
+        "curve": [
+          "P-224",
+          "P-256",
+          "P-384",
+          "P-521"
+        ]
+      },
+      {
+        "algorithm": "ECDSA",
+        "mode": "sigGen",
+        "revision": "1.0",
+        "capabilities": [{
+          "curve": [
+            "P-224",
+            "P-256",
+            "P-384",
+            "P-521"
+          ],
+          "hashAlg": [
+            "SHA2-224",
+            "SHA2-256",
+            "SHA2-384",
+            "SHA2-512"
+          ]
+        }]
+      },
+      {
+        "algorithm": "ECDSA",
+        "mode": "sigVer",
+        "revision": "1.0",
+        "capabilities": [{
+          "curve": [
+            "P-224",
+            "P-256",
+            "P-384",
+            "P-521"
+          ],
+          "hashAlg": [
+            "SHA2-224",
+            "SHA2-256",
+            "SHA2-384",
+            "SHA2-512"
+          ]
+        }]
       }
-      ])";
+    ])";
   return WriteReply(
       STDOUT_FILENO,
       Span<const uint8_t>(reinterpret_cast<const uint8_t *>(kConfig),
@@ -333,6 +401,171 @@
   return WriteReply(STDOUT_FILENO, Span<const uint8_t>(out));
 }
 
+static bool StringEq(Span<const uint8_t> a, const char *b) {
+  const size_t len = strlen(b);
+  return a.size() == len && memcmp(a.data(), b, len) == 0;
+}
+
+static bssl::UniquePtr<EC_KEY> ECKeyFromName(Span<const uint8_t> name) {
+  int nid;
+  if (StringEq(name, "P-224")) {
+    nid = NID_secp224r1;
+  } else if (StringEq(name, "P-256")) {
+    nid = NID_X9_62_prime256v1;
+  } else if (StringEq(name, "P-384")) {
+    nid = NID_secp384r1;
+  } else if (StringEq(name, "P-521")) {
+    nid = NID_secp521r1;
+  } else {
+    return nullptr;
+  }
+
+  return bssl::UniquePtr<EC_KEY>(EC_KEY_new_by_curve_name(nid));
+}
+
+static std::vector<uint8_t> BIGNUMBytes(const BIGNUM *bn) {
+  const size_t len = BN_num_bytes(bn);
+  std::vector<uint8_t> ret(len);
+  BN_bn2bin(bn, ret.data());
+  return ret;
+}
+
+static std::pair<std::vector<uint8_t>, std::vector<uint8_t>> GetPublicKeyBytes(
+    const EC_KEY *key) {
+  bssl::UniquePtr<BIGNUM> x(BN_new());
+  bssl::UniquePtr<BIGNUM> y(BN_new());
+  if (!EC_POINT_get_affine_coordinates_GFp(EC_KEY_get0_group(key),
+                                           EC_KEY_get0_public_key(key), x.get(),
+                                           y.get(), /*ctx=*/nullptr)) {
+    abort();
+  }
+
+  std::vector<uint8_t> x_bytes = BIGNUMBytes(x.get());
+  std::vector<uint8_t> y_bytes = BIGNUMBytes(y.get());
+
+  return std::make_pair(std::move(x_bytes), std::move(y_bytes));
+}
+
+static bool ECDSAKeyGen(const Span<const uint8_t> args[]) {
+  bssl::UniquePtr<EC_KEY> key = ECKeyFromName(args[0]);
+  if (!key || !EC_KEY_generate_key_fips(key.get())) {
+    return false;
+  }
+
+  const auto pub_key = GetPublicKeyBytes(key.get());
+  std::vector<uint8_t> d_bytes =
+      BIGNUMBytes(EC_KEY_get0_private_key(key.get()));
+
+  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(d_bytes),
+                    Span<const uint8_t>(pub_key.first),
+                    Span<const uint8_t>(pub_key.second));
+}
+
+static bssl::UniquePtr<BIGNUM> BytesToBIGNUM(Span<const uint8_t> bytes) {
+  bssl::UniquePtr<BIGNUM> bn(BN_new());
+  BN_bin2bn(bytes.data(), bytes.size(), bn.get());
+  return bn;
+}
+
+static bool ECDSAKeyVer(const Span<const uint8_t> args[]) {
+  bssl::UniquePtr<EC_KEY> key = ECKeyFromName(args[0]);
+  if (!key) {
+    return false;
+  }
+
+  bssl::UniquePtr<BIGNUM> x(BytesToBIGNUM(args[1]));
+  bssl::UniquePtr<BIGNUM> y(BytesToBIGNUM(args[2]));
+
+  bssl::UniquePtr<EC_POINT> point(EC_POINT_new(EC_KEY_get0_group(key.get())));
+  uint8_t reply[1];
+  if (!EC_POINT_set_affine_coordinates_GFp(EC_KEY_get0_group(key.get()),
+                                           point.get(), x.get(), y.get(),
+                                           /*ctx=*/nullptr) ||
+      !EC_KEY_set_public_key(key.get(), point.get()) ||
+      !EC_KEY_check_fips(key.get())) {
+    reply[0] = 0;
+  } else {
+    reply[0] = 1;
+  }
+
+  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(reply));
+}
+
+static const EVP_MD *HashFromName(Span<const uint8_t> name) {
+  if (StringEq(name, "SHA2-224")) {
+    return EVP_sha224();
+  } else if (StringEq(name, "SHA2-256")) {
+    return EVP_sha256();
+  } else if (StringEq(name, "SHA2-384")) {
+    return EVP_sha384();
+  } else if (StringEq(name, "SHA2-512")) {
+    return EVP_sha512();
+  } else {
+    return nullptr;
+  }
+}
+
+static bool ECDSASigGen(const Span<const uint8_t> args[]) {
+  bssl::UniquePtr<EC_KEY> key = ECKeyFromName(args[0]);
+  bssl::UniquePtr<BIGNUM> d = BytesToBIGNUM(args[1]);
+  const EVP_MD *hash = HashFromName(args[2]);
+  uint8_t digest[EVP_MAX_MD_SIZE];
+  unsigned digest_len;
+  if (!key || !hash ||
+      !EVP_Digest(args[3].data(), args[3].size(), digest, &digest_len, hash,
+                  /*impl=*/nullptr) ||
+      !EC_KEY_set_private_key(key.get(), d.get())) {
+    return false;
+  }
+
+  bssl::UniquePtr<ECDSA_SIG> sig(ECDSA_do_sign(digest, digest_len, key.get()));
+  if (!sig) {
+    return false;
+  }
+
+  std::vector<uint8_t> r_bytes(BIGNUMBytes(sig->r));
+  std::vector<uint8_t> s_bytes(BIGNUMBytes(sig->s));
+
+  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(r_bytes),
+                    Span<const uint8_t>(s_bytes));
+}
+
+static bool ECDSASigVer(const Span<const uint8_t> args[]) {
+  bssl::UniquePtr<EC_KEY> key = ECKeyFromName(args[0]);
+  const EVP_MD *hash = HashFromName(args[1]);
+  auto msg = args[2];
+  bssl::UniquePtr<BIGNUM> x(BytesToBIGNUM(args[3]));
+  bssl::UniquePtr<BIGNUM> y(BytesToBIGNUM(args[4]));
+  bssl::UniquePtr<BIGNUM> r(BytesToBIGNUM(args[5]));
+  bssl::UniquePtr<BIGNUM> s(BytesToBIGNUM(args[6]));
+  ECDSA_SIG sig;
+  sig.r = r.get();
+  sig.s = s.get();
+
+  uint8_t digest[EVP_MAX_MD_SIZE];
+  unsigned digest_len;
+  if (!key || !hash ||
+      !EVP_Digest(msg.data(), msg.size(), digest, &digest_len, hash,
+                  /*impl=*/nullptr)) {
+    return false;
+  }
+
+  bssl::UniquePtr<EC_POINT> point(EC_POINT_new(EC_KEY_get0_group(key.get())));
+  uint8_t reply[1];
+  if (!EC_POINT_set_affine_coordinates_GFp(EC_KEY_get0_group(key.get()),
+                                           point.get(), x.get(), y.get(),
+                                           /*ctx=*/nullptr) ||
+      !EC_KEY_set_public_key(key.get(), point.get()) ||
+      !EC_KEY_check_fips(key.get()) ||
+      !ECDSA_do_verify(digest, digest_len, &sig, key.get())) {
+    reply[0] = 0;
+  } else {
+    reply[0] = 1;
+  }
+
+  return WriteReply(STDOUT_FILENO, Span<const uint8_t>(reply));
+}
+
 static constexpr struct {
   const char name[kMaxNameLength + 1];
   uint8_t expected_args;
@@ -354,6 +587,10 @@
     {"HMAC-SHA2-384", 2, HMAC<EVP_sha384>},
     {"HMAC-SHA2-512", 2, HMAC<EVP_sha512>},
     {"ctrDRBG/AES-256", 6, DRBG},
+    {"ECDSA/keyGen", 1, ECDSAKeyGen},
+    {"ECDSA/keyVer", 3, ECDSAKeyVer},
+    {"ECDSA/sigGen", 4, ECDSASigGen},
+    {"ECDSA/sigVer", 7, ECDSASigVer},
 };
 
 int main() {