acvp: add support for KAS

Change-Id: Ida3ec65e81398881a71828dc1d51cf80be41bdbb
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/44444
Commit-Queue: Adam Langley <agl@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/util/fipstools/acvp/acvptool/subprocess/kas.go b/util/fipstools/acvp/acvptool/subprocess/kas.go
new file mode 100644
index 0000000..b95e48a
--- /dev/null
+++ b/util/fipstools/acvp/acvptool/subprocess/kas.go
@@ -0,0 +1,162 @@
+// Copyright (c) 2020, Google Inc.
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
+// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
+// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
+// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+package subprocess
+
+import (
+	"bytes"
+	"encoding/hex"
+	"encoding/json"
+	"fmt"
+)
+
+type kasVectorSet struct {
+	Groups []kasTestGroup `json:"testGroups"`
+}
+
+type kasTestGroup struct {
+	ID     uint64    `json:"tgId"`
+	Type   string    `json:"testType"`
+	Curve  string    `json:"domainParameterGenerationMode"`
+	Role   string    `json:"kasRole"`
+	Scheme string    `json:"scheme"`
+	Tests  []kasTest `json:"tests"`
+}
+
+type kasTest struct {
+	ID            uint64 `json:"tcId"`
+	XHex          string `json:"ephemeralPublicServerX"`
+	YHex          string `json:"ephemeralPublicServerY"`
+	PrivateKeyHex string `json:"ephemeralPrivateIut"`
+	ResultHex     string `json:"z"`
+}
+
+type kasTestGroupResponse struct {
+	ID    uint64            `json:"tgId"`
+	Tests []kasTestResponse `json:"tests"`
+}
+
+type kasTestResponse struct {
+	ID        uint64 `json:"tcId"`
+	XHex      string `json:"ephemeralPublicIutX,omitempty"`
+	YHex      string `json:"ephemeralPublicIutY,omitempty"`
+	ResultHex string `json:"z,omitempty"`
+	Passed    *bool  `json:"testPassed,omitempty"`
+}
+
+type kas struct{}
+
+func (k *kas) Process(vectorSet []byte, m Transactable) (interface{}, error) {
+	var parsed kasVectorSet
+	if err := json.Unmarshal(vectorSet, &parsed); err != nil {
+		return nil, err
+	}
+
+	// See https://usnistgov.github.io/ACVP/draft-hammett-acvp-kas-ssc-ecc.html
+	var ret []kasTestGroupResponse
+	for _, group := range parsed.Groups {
+		response := kasTestGroupResponse{
+			ID: group.ID,
+		}
+
+		var privateKeyGiven bool
+		switch group.Type {
+		case "AFT":
+			privateKeyGiven = false
+		case "VAL":
+			privateKeyGiven = true
+		default:
+			return nil, fmt.Errorf("unknown test type %q", group.Type)
+		}
+
+		switch group.Curve {
+		case "P-224", "P-256", "P-384", "P-521":
+			break
+		default:
+			return nil, fmt.Errorf("unknown curve %q", group.Curve)
+		}
+
+		switch group.Role {
+		case "initiator", "responder":
+			break
+		default:
+			return nil, fmt.Errorf("unknown role %q", group.Role)
+		}
+
+		if group.Scheme != "ephemeralUnified" {
+			return nil, fmt.Errorf("unknown scheme %q", group.Scheme)
+		}
+
+		method := "ECDH/" + group.Curve
+
+		for _, test := range group.Tests {
+			if len(test.XHex) == 0 || len(test.YHex) == 0 {
+				return nil, fmt.Errorf("%d/%d is missing peer's point", group.ID, test.ID)
+			}
+
+			peerX, err := hex.DecodeString(test.XHex)
+			if err != nil {
+				return nil, err
+			}
+
+			peerY, err := hex.DecodeString(test.YHex)
+			if err != nil {
+				return nil, err
+			}
+
+			if (len(test.PrivateKeyHex) != 0) != privateKeyGiven {
+				return nil, fmt.Errorf("%d/%d incorrect private key presence", group.ID, test.ID)
+			}
+
+			if privateKeyGiven {
+				privateKey, err := hex.DecodeString(test.PrivateKeyHex)
+				if err != nil {
+					return nil, err
+				}
+
+				expectedOutput, err := hex.DecodeString(test.ResultHex)
+				if err != nil {
+					return nil, err
+				}
+
+				result, err := m.Transact(method, 3, peerX, peerY, privateKey)
+				if err != nil {
+					return nil, err
+				}
+
+				ok := bytes.Equal(result[2], expectedOutput)
+				response.Tests = append(response.Tests, kasTestResponse{
+					ID:     test.ID,
+					Passed: &ok,
+				})
+			} else {
+				result, err := m.Transact(method, 3, peerX, peerY, nil)
+				if err != nil {
+					return nil, err
+				}
+
+				response.Tests = append(response.Tests, kasTestResponse{
+					ID:        test.ID,
+					XHex:      hex.EncodeToString(result[0]),
+					YHex:      hex.EncodeToString(result[1]),
+					ResultHex: hex.EncodeToString(result[2]),
+				})
+			}
+		}
+
+		ret = append(ret, response)
+	}
+
+	return ret, nil
+}
diff --git a/util/fipstools/acvp/acvptool/subprocess/subprocess.go b/util/fipstools/acvp/acvptool/subprocess/subprocess.go
index e5c1d1a..e3f11cc 100644
--- a/util/fipstools/acvp/acvptool/subprocess/subprocess.go
+++ b/util/fipstools/acvp/acvptool/subprocess/subprocess.go
@@ -97,6 +97,7 @@
 		"CMAC-AES":       &keyedMACPrimitive{"CMAC-AES"},
 		"RSA":            &rsa{},
 		"kdf-components": &tlsKDF{},
+		"KAS-ECC-SSC":    &kas{},
 	}
 	m.primitives["ECDSA"] = &ecdsa{"ECDSA", map[string]bool{"P-224": true, "P-256": true, "P-384": true, "P-521": true}, m.primitives}
 
diff --git a/util/fipstools/acvp/modulewrapper/modulewrapper.cc b/util/fipstools/acvp/modulewrapper/modulewrapper.cc
index 9d915ab..37db281 100644
--- a/util/fipstools/acvp/modulewrapper/modulewrapper.cc
+++ b/util/fipstools/acvp/modulewrapper/modulewrapper.cc
@@ -32,6 +32,7 @@
 #include <openssl/digest.h>
 #include <openssl/ec.h>
 #include <openssl/ec_key.h>
+#include <openssl/ecdh.h>
 #include <openssl/ecdsa.h>
 #include <openssl/err.h>
 #include <openssl/hmac.h>
@@ -40,6 +41,7 @@
 #include <openssl/sha.h>
 #include <openssl/span.h>
 
+#include "../../../../crypto/fipsmodule/ec/internal.h"
 #include "../../../../crypto/fipsmodule/rand/internal.h"
 #include "../../../../crypto/fipsmodule/tls/internal.h"
 
@@ -717,6 +719,24 @@
           "SHA2-384",
           "SHA2-512"
         ]
+      },
+      {
+        "algorithm": "KAS-ECC-SSC",
+        "revision": "Sp800-56Ar3",
+        "scheme": {
+          "ephemeralUnified": {
+            "kasRole": [
+              "initiator",
+              "responder"
+            ]
+          }
+        },
+        "domainParameterGenerationMethods": [
+          "P-224",
+          "P-256",
+          "P-384",
+          "P-521"
+        ]
       }
     ])";
   return WriteReply(
@@ -1594,6 +1614,70 @@
   return WriteReply(STDOUT_FILENO, out);
 }
 
+template <int Nid>
+static bool ECDH(const Span<const uint8_t> args[]) {
+  bssl::UniquePtr<BIGNUM> their_x(BytesToBIGNUM(args[0]));
+  bssl::UniquePtr<BIGNUM> their_y(BytesToBIGNUM(args[1]));
+  const Span<const uint8_t> private_key = args[2];
+
+  bssl::UniquePtr<EC_KEY> ec_key(EC_KEY_new_by_curve_name(Nid));
+  bssl::UniquePtr<BN_CTX> ctx(BN_CTX_new());
+
+  const EC_GROUP *const group = EC_KEY_get0_group(ec_key.get());
+  bssl::UniquePtr<EC_POINT> their_point(EC_POINT_new(group));
+  if (!EC_POINT_set_affine_coordinates_GFp(
+          group, their_point.get(), their_x.get(), their_y.get(), ctx.get())) {
+    fprintf(stderr, "Invalid peer point for ECDH.\n");
+    return false;
+  }
+
+  if (!private_key.empty()) {
+    bssl::UniquePtr<BIGNUM> our_k(BytesToBIGNUM(private_key));
+    if (!EC_KEY_set_private_key(ec_key.get(), our_k.get())) {
+      fprintf(stderr, "EC_KEY_set_private_key failed.\n");
+      return false;
+    }
+
+    bssl::UniquePtr<EC_POINT> our_pub(EC_POINT_new(group));
+    if (!EC_POINT_mul(group, our_pub.get(), our_k.get(), nullptr, nullptr,
+                      ctx.get()) ||
+        !EC_KEY_set_public_key(ec_key.get(), our_pub.get())) {
+      fprintf(stderr, "Calculating public key failed.\n");
+      return false;
+    }
+  } else if (!EC_KEY_generate_key_fips(ec_key.get())) {
+    fprintf(stderr, "EC_KEY_generate_key_fips failed.\n");
+    return false;
+  }
+
+  // The output buffer is one larger than |EC_MAX_BYTES| so that truncation
+  // can be detected.
+  std::vector<uint8_t> output(EC_MAX_BYTES + 1);
+  const int out_len =
+      ECDH_compute_key(output.data(), output.size(), their_point.get(),
+                       ec_key.get(), /*kdf=*/nullptr);
+  if (out_len < 0) {
+    fprintf(stderr, "ECDH_compute_key failed.\n");
+    return false;
+  } else if (static_cast<size_t>(out_len) == output.size()) {
+    fprintf(stderr, "ECDH_compute_key output may have been truncated.\n");
+    return false;
+  }
+  output.resize(static_cast<size_t>(out_len));
+
+  const EC_POINT *pub = EC_KEY_get0_public_key(ec_key.get());
+  bssl::UniquePtr<BIGNUM> x(BN_new());
+  bssl::UniquePtr<BIGNUM> y(BN_new());
+  if (!EC_POINT_get_affine_coordinates_GFp(group, pub, x.get(), y.get(),
+                                           ctx.get())) {
+    fprintf(stderr, "EC_POINT_get_affine_coordinates_GFp failed.\n");
+    return false;
+  }
+
+  return WriteReply(STDOUT_FILENO, BIGNUMBytes(x.get()), BIGNUMBytes(y.get()),
+                    output);
+}
+
 static constexpr struct {
   const char name[kMaxNameLength + 1];
   uint8_t expected_args;
@@ -1660,6 +1744,10 @@
     {"TLSKDF/1.2/SHA2-256", 5, TLSKDF<EVP_sha256>},
     {"TLSKDF/1.2/SHA2-384", 5, TLSKDF<EVP_sha384>},
     {"TLSKDF/1.2/SHA2-512", 5, TLSKDF<EVP_sha512>},
+    {"ECDH/P-224", 3, ECDH<NID_secp224r1>},
+    {"ECDH/P-256", 3, ECDH<NID_X9_62_prime256v1>},
+    {"ECDH/P-384", 3, ECDH<NID_secp384r1>},
+    {"ECDH/P-521", 3, ECDH<NID_secp521r1>},
 };
 
 int main() {