fipsoracle: Add MCT test for AES.

Change-Id: I5e48e78f0cc9962bc0302fd9642789016c84945c
Reviewed-on: https://boringssl-review.googlesource.com/15646
Reviewed-by: Adam Langley <agl@google.com>
Commit-Queue: Adam Langley <agl@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/crypto/fipsoracle/cavp_aes_test.cc b/crypto/fipsoracle/cavp_aes_test.cc
index 83dc0fe..21e49fb 100644
--- a/crypto/fipsoracle/cavp_aes_test.cc
+++ b/crypto/fipsoracle/cavp_aes_test.cc
@@ -30,8 +30,120 @@
   const EVP_CIPHER *cipher;
   std::unique_ptr<FileTest> response_sample;
   bool has_iv;
+  enum Mode {
+    kKAT,  // Known Answer Test
+    kMCT,  // Monte Carlo Test
+  };
+  Mode mode;
 };
 
+static bool MonteCarlo(const TestCtx *ctx, FileTest *t,
+                       const EVP_CIPHER *cipher, std::vector<uint8_t> *out,
+                       bool encrypt, std::vector<uint8_t> key,
+                       std::vector<uint8_t> iv, std::vector<uint8_t> in) {
+  const std::string in_label = encrypt ? "PLAINTEXT" : "CIPHERTEXT",
+                    result_label = encrypt ? "CIPHERTEXT" : "PLAINTEXT";
+  std::vector<uint8_t> prev_result, result, prev_in;
+  for (int i = 0; i < 100; i++) {
+    printf("COUNT = %d\r\nKEY = %s\r\n", i,
+           EncodeHex(key.data(), key.size()).c_str());
+    if (ctx->has_iv) {
+      printf("IV = %s\r\n", EncodeHex(iv.data(), iv.size()).c_str());
+    }
+    printf("%s = %s\r\n", in_label.c_str(),
+           EncodeHex(in.data(), in.size()).c_str());
+
+    if (!ctx->has_iv) {  // ECB mode
+      for (int j = 0; j < 1000; j++) {
+        prev_result = result;
+        if (!CipherOperation(cipher, &result, encrypt, key, iv, in)) {
+          return false;
+        }
+        in = result;
+      }
+    } else {
+      for (int j = 0; j < 1000; j++) {
+        prev_result = result;
+        if (j > 0) {
+          if (encrypt) {
+            iv = result;
+          } else {
+            iv = prev_in;
+          }
+        }
+
+        if (!CipherOperation(cipher, &result, encrypt, key, iv, in)) {
+          return false;
+        }
+
+        prev_in = in;
+
+        if (j == 0) {
+          in = iv;
+        } else {
+          in = prev_result;
+        }
+      }
+    }
+
+    printf("%s = %s\r\n\r\n", result_label.c_str(),
+           EncodeHex(result.data(), result.size()).c_str());
+
+    // Check if sample response file matches.
+    if (ctx->response_sample) {
+      if (ctx->response_sample->ReadNext() != FileTest::kReadSuccess) {
+        t->PrintLine("invalid sample file");
+        return false;
+      }
+      std::string expected_count;
+      std::vector<uint8_t> expected_key, expected_result;
+      if (!ctx->response_sample->GetBytes(&expected_key, "KEY") ||
+          !t->ExpectBytesEqual(expected_key.data(), expected_key.size(),
+                               key.data(), key.size()) ||
+          !ctx->response_sample->GetBytes(&expected_result, result_label) ||
+          !t->ExpectBytesEqual(expected_result.data(), expected_result.size(),
+                               result.data(), result.size())) {
+        t->PrintLine("result doesn't match");
+        return false;
+      }
+    }
+
+    const size_t key_len = key.size() * 8;
+    if (key_len == 128) {
+      for (size_t k = 0; k < key.size(); k++) {
+        key[k] ^= result[k];
+      }
+    } else if (key_len == 192) {
+      for (size_t k = 0; k < key.size(); k++) {
+        // Key[i+1] = Key[i] xor (last 64-bits of CT[j-1] || CT[j])
+        if (k < 8) {
+          key[k] ^= prev_result[prev_result.size() - 8 + k];
+        } else {
+          key[k] ^= result[k - 8];
+        }
+      }
+    } else {  // key_len == 256
+      for (size_t k = 0; k < key.size(); k++) {
+        // Key[i+1] = Key[i] xor (CT[j-1] || CT[j])
+        if (k < 16) {
+          key[k] ^= prev_result[k];
+        } else {
+          key[k] ^= result[k - 16];
+        }
+      }
+    }
+
+    if (ctx->has_iv) {
+      iv = result;
+      in = prev_result;
+    } else {
+      in = result;
+    }
+  }
+
+  return true;
+}
+
 static bool TestCipher(FileTest *t, void *arg) {
   TestCtx *ctx = reinterpret_cast<TestCtx *>(arg);
 
@@ -54,71 +166,104 @@
 
   const EVP_CIPHER *cipher = ctx->cipher;
   if (operation == kEncrypt) {
-    if (!t->GetBytes(&in, "PLAINTEXT") ||
-        !CipherOperation(cipher, &result, true /* encrypt */, key, iv, in)) {
+    if (!t->GetBytes(&in, "PLAINTEXT")) {
       return false;
     }
-    printf("%sCIPHERTEXT = %s\r\n\r\n", t->CurrentTestToString().c_str(),
-           EncodeHex(result.data(), result.size()).c_str());
-  } else {
-    if (!t->GetBytes(&in, "CIPHERTEXT") ||
-        !CipherOperation(cipher, &result, false /* decrypt */, key, iv, in)) {
+  } else {  // operation == kDecrypt
+    if (!t->GetBytes(&in, "CIPHERTEXT")) {
       return false;
     }
-    printf("%sPLAINTEXT = %s\r\n\r\n", t->CurrentTestToString().c_str(),
-           EncodeHex(result.data(), result.size()).c_str());
   }
 
-  // Check if sample response file matches.
-  if (ctx->response_sample) {
-    if (ctx->response_sample->ReadNext() != FileTest::kReadSuccess) {
-      t->PrintLine("invalid sample file");
+  if (ctx->mode == TestCtx::kKAT) {
+    if (!CipherOperation(cipher, &result, operation == kEncrypt, key, iv, in)) {
       return false;
     }
-    std::string expected_count;
-    std::vector<uint8_t> expected_result;
-    if (!ctx->response_sample->GetAttribute(&expected_count, "COUNT") ||
-        count != expected_count ||
-        (operation == kEncrypt &&
-         (!ctx->response_sample->GetBytes(&expected_result, "CIPHERTEXT") ||
-          !t->ExpectBytesEqual(expected_result.data(), expected_result.size(),
-                               result.data(), result.size()))) ||
-        (operation == kDecrypt &&
-         (!ctx->response_sample->GetBytes(&expected_result, "PLAINTEXT") ||
-          !t->ExpectBytesEqual(expected_result.data(), expected_result.size(),
-                               result.data(), result.size())))) {
-      t->PrintLine("result doesn't match");
+    const std::string label =
+        operation == kEncrypt ? "CIPHERTEXT" : "PLAINTEXT";
+    printf("%s%s = %s\r\n\r\n", t->CurrentTestToString().c_str(), label.c_str(),
+           EncodeHex(result.data(), result.size()).c_str());
+
+    // Check if sample response file matches.
+    if (ctx->response_sample) {
+      if (ctx->response_sample->ReadNext() != FileTest::kReadSuccess) {
+        t->PrintLine("invalid sample file");
+        return false;
+      }
+      std::string expected_count;
+      std::vector<uint8_t> expected_result;
+      if (!ctx->response_sample->GetAttribute(&expected_count, "COUNT") ||
+          count != expected_count ||
+          (operation == kEncrypt &&
+           (!ctx->response_sample->GetBytes(&expected_result, "CIPHERTEXT") ||
+            !t->ExpectBytesEqual(expected_result.data(), expected_result.size(),
+                                 result.data(), result.size()))) ||
+          (operation == kDecrypt &&
+           (!ctx->response_sample->GetBytes(&expected_result, "PLAINTEXT") ||
+            !t->ExpectBytesEqual(expected_result.data(), expected_result.size(),
+                                 result.data(), result.size())))) {
+        t->PrintLine("result doesn't match");
+        return false;
+      }
+    }
+
+  } else {  // ctx->mode == kMCT
+    const std::string op_label =
+        operation == kEncrypt ? "[ENCRYPT]" : "[DECRYPT]";
+    printf("%s\r\n\r\n", op_label.c_str());
+    if (!MonteCarlo(ctx, t, cipher, &result, operation == kEncrypt, key, iv,
+                    in)) {
       return false;
     }
+    if (operation == kEncrypt) {
+      // MCT tests contain a stray blank line after the ENCRYPT section.
+      printf("\r\n");
+    }
   }
 
   return true;
 }
 
+static int usage(char *arg) {
+  fprintf(stderr,
+          "usage: %s (kat|mct) <cipher> <test file> [<sample response file>]\n",
+          arg);
+  return 1;
+}
+
 int main(int argc, char **argv) {
   CRYPTO_library_init();
 
-  if (argc < 3 || argc > 4) {
-    fprintf(stderr, "usage: %s <cipher> <test file> [<sample response file>]\n",
-            argv[0]);
-    return 1;
+  if (argc < 4 || argc > 5) {
+    return usage(argv[0]);
   }
 
-  const EVP_CIPHER *cipher = GetCipher(argv[1]);
+  const std::string tm(argv[1]);
+  enum TestCtx::Mode test_mode;
+  if (tm == "kat") {
+    test_mode = TestCtx::kKAT;
+  } else if (tm == "mct") {
+    test_mode = TestCtx::kMCT;
+  } else {
+    fprintf(stderr, "invalid test_mode: %s\n", tm.c_str());
+    return usage(argv[0]);
+  }
+
+  const std::string cipher_name(argv[2]);
+  const EVP_CIPHER *cipher = GetCipher(argv[2]);
   if (cipher == nullptr) {
-    fprintf(stderr, "invalid cipher: %s\n", argv[1]);
+    fprintf(stderr, "invalid cipher: %s\n", argv[2]);
     return 1;
   }
-  const std::string cipher_name(argv[1]);
   const bool has_iv =
       (cipher_name != "aes-128-ecb" &&
        cipher_name != "aes-192-ecb" &&
        cipher_name != "aes-256-ecb");
 
-  TestCtx ctx = {cipher, nullptr, has_iv};
+  TestCtx ctx = {cipher, nullptr, has_iv, test_mode};
 
-  if (argc == 4) {
-    ctx.response_sample.reset(new FileTest(argv[3]));
+  if (argc == 5) {
+    ctx.response_sample.reset(new FileTest(argv[4]));
     if (!ctx.response_sample->is_open()) {
       return 1;
     }
@@ -131,5 +276,5 @@
   }
   printf("\r\n\r\n");
 
-  return FileTestMainSilent(TestCipher, &ctx, argv[2]);
+  return FileTestMainSilent(TestCipher, &ctx, argv[3]);
 }
diff --git a/crypto/fipsoracle/run_cavp.go b/crypto/fipsoracle/run_cavp.go
index 21c9c0a..bfd59cd 100644
--- a/crypto/fipsoracle/run_cavp.go
+++ b/crypto/fipsoracle/run_cavp.go
@@ -14,6 +14,7 @@
 
 var (
 	binaryDir = flag.String("bin-dir", "", "Directory containing fipsoracle binaries")
+	suiteDir  = flag.String("suite-dir", "", "Base directory containing the CAVP test suite")
 )
 
 // test describes a single request file.
@@ -39,6 +40,10 @@
 	tests  []test
 }
 
+func (t *testSuite) getDirectory() string {
+	return filepath.Join(*suiteDir, t.directory)
+}
+
 var aesGCMTests = testSuite{
 	"AES_GCM",
 	"cavp_aes_gcm_test",
@@ -54,47 +59,46 @@
 	"AES",
 	"cavp_aes_test",
 	[]test{
-		{"CBCGFSbox128", []string{"aes-128-cbc"}, false},
-		{"CBCGFSbox192", []string{"aes-192-cbc"}, false},
-		{"CBCGFSbox256", []string{"aes-256-cbc"}, false},
-		{"CBCKeySbox128", []string{"aes-128-cbc"}, false},
-		{"CBCKeySbox192", []string{"aes-192-cbc"}, false},
-		{"CBCKeySbox256", []string{"aes-256-cbc"}, false},
-		{"CBCMMT128", []string{"aes-128-cbc"}, false},
-		{"CBCMMT192", []string{"aes-192-cbc"}, false},
-		{"CBCMMT256", []string{"aes-256-cbc"}, false},
-		{"CBCVarKey128", []string{"aes-128-cbc"}, false},
-		{"CBCVarKey192", []string{"aes-192-cbc"}, false},
-		{"CBCVarKey256", []string{"aes-256-cbc"}, false},
-		{"CBCVarTxt128", []string{"aes-128-cbc"}, false},
-		{"CBCVarTxt192", []string{"aes-192-cbc"}, false},
-		{"CBCVarTxt256", []string{"aes-256-cbc"}, false},
-		{"ECBGFSbox128", []string{"aes-128-ecb"}, false},
-		{"ECBGFSbox192", []string{"aes-192-ecb"}, false},
-		{"ECBGFSbox256", []string{"aes-256-ecb"}, false},
-		{"ECBKeySbox128", []string{"aes-128-ecb"}, false},
-		{"ECBKeySbox192", []string{"aes-192-ecb"}, false},
-		{"ECBKeySbox256", []string{"aes-256-ecb"}, false},
-		{"ECBMMT128", []string{"aes-128-ecb"}, false},
-		{"ECBMMT192", []string{"aes-192-ecb"}, false},
-		{"ECBMMT256", []string{"aes-256-ecb"}, false},
-		{"ECBVarKey128", []string{"aes-128-ecb"}, false},
-		{"ECBVarKey192", []string{"aes-192-ecb"}, false},
-		{"ECBVarKey256", []string{"aes-256-ecb"}, false},
-		{"ECBVarTxt128", []string{"aes-128-ecb"}, false},
-		{"ECBVarTxt192", []string{"aes-192-ecb"}, false},
-		{"ECBVarTxt256", []string{"aes-256-ecb"}, false},
+		{"CBCGFSbox128", []string{"kat", "aes-128-cbc"}, false},
+		{"CBCGFSbox192", []string{"kat", "aes-192-cbc"}, false},
+		{"CBCGFSbox256", []string{"kat", "aes-256-cbc"}, false},
+		{"CBCKeySbox128", []string{"kat", "aes-128-cbc"}, false},
+		{"CBCKeySbox192", []string{"kat", "aes-192-cbc"}, false},
+		{"CBCKeySbox256", []string{"kat", "aes-256-cbc"}, false},
+		{"CBCMMT128", []string{"kat", "aes-128-cbc"}, false},
+		{"CBCMMT192", []string{"kat", "aes-192-cbc"}, false},
+		{"CBCMMT256", []string{"kat", "aes-256-cbc"}, false},
+		{"CBCVarKey128", []string{"kat", "aes-128-cbc"}, false},
+		{"CBCVarKey192", []string{"kat", "aes-192-cbc"}, false},
+		{"CBCVarKey256", []string{"kat", "aes-256-cbc"}, false},
+		{"CBCVarTxt128", []string{"kat", "aes-128-cbc"}, false},
+		{"CBCVarTxt192", []string{"kat", "aes-192-cbc"}, false},
+		{"CBCVarTxt256", []string{"kat", "aes-256-cbc"}, false},
+		{"ECBGFSbox128", []string{"kat", "aes-128-ecb"}, false},
+		{"ECBGFSbox192", []string{"kat", "aes-192-ecb"}, false},
+		{"ECBGFSbox256", []string{"kat", "aes-256-ecb"}, false},
+		{"ECBKeySbox128", []string{"kat", "aes-128-ecb"}, false},
+		{"ECBKeySbox192", []string{"kat", "aes-192-ecb"}, false},
+		{"ECBKeySbox256", []string{"kat", "aes-256-ecb"}, false},
+		{"ECBMMT128", []string{"kat", "aes-128-ecb"}, false},
+		{"ECBMMT192", []string{"kat", "aes-192-ecb"}, false},
+		{"ECBMMT256", []string{"kat", "aes-256-ecb"}, false},
+		{"ECBVarKey128", []string{"kat", "aes-128-ecb"}, false},
+		{"ECBVarKey192", []string{"kat", "aes-192-ecb"}, false},
+		{"ECBVarKey256", []string{"kat", "aes-256-ecb"}, false},
+		{"ECBVarTxt128", []string{"kat", "aes-128-ecb"}, false},
+		{"ECBVarTxt192", []string{"kat", "aes-192-ecb"}, false},
+		{"ECBVarTxt256", []string{"kat", "aes-256-ecb"}, false},
+		// AES Monte-Carlo tests
+		{"ECBMCT128", []string{"mct", "aes-128-ecb"}, false},
+		{"ECBMCT192", []string{"mct", "aes-192-ecb"}, false},
+		{"ECBMCT256", []string{"mct", "aes-256-ecb"}, false},
+		{"CBCMCT128", []string{"mct", "aes-128-cbc"}, false},
+		{"CBCMCT192", []string{"mct", "aes-192-cbc"}, false},
+		{"CBCMCT256", []string{"mct", "aes-256-cbc"}, false},
 	},
 }
 
-// AES Monte-Carlo tests need a different binary.
-//{"ECBMCT128", []string{"aes-128-ecb"}, false},
-//{"ECBMCT192", []string{"aes-192-ecb"}, false},
-//{"ECBMCT256", []string{"aes-256-ecb"}, false},
-//{"CBCMCT128", []string{"aes-128-cbc"}, false},
-//{"CBCMCT192", []string{"aes-192-cbc"}, false},
-//{"CBCMCT256", []string{"aes-256-cbc"}, false},
-
 var ecdsa2PKVTests = testSuite{
 	"ECDSA2",
 	"cavp_ecdsa2_pkv_test",
@@ -144,12 +148,12 @@
 
 	var args []string
 	args = append(args, test.args...)
-	args = append(args, filepath.Join(suite.directory, "req", test.inFile+".req"))
+	args = append(args, filepath.Join(suite.getDirectory(), "req", test.inFile+".req"))
 
-	outPath := filepath.Join(suite.directory, "resp", test.inFile+".resp")
+	outPath := filepath.Join(suite.getDirectory(), "resp", test.inFile+".resp")
 	outFile, err := os.OpenFile(outPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
 	if err != nil {
-		return fmt.Errorf("cannot open output file for %q %q: %s", suite.directory, test.inFile, err)
+		return fmt.Errorf("cannot open output file for %q %q: %s", suite.getDirectory(), test.inFile, err)
 	}
 	defer outFile.Close()
 
@@ -158,7 +162,7 @@
 	cmd.Stderr = os.Stderr
 
 	if err := cmd.Run(); err != nil {
-		return fmt.Errorf("cannot run command for %q %q: %s", suite.directory, test.inFile, err)
+		return fmt.Errorf("cannot run command for %q %q: %s", suite.getDirectory(), test.inFile, err)
 	}
 
 	return nil
@@ -175,17 +179,17 @@
 }
 
 func compareFAX(suite *testSuite, test test) error {
-	respPath := filepath.Join(suite.directory, "resp", test.inFile+".resp")
+	respPath := filepath.Join(suite.getDirectory(), "resp", test.inFile+".resp")
 	respFile, err := os.Open(respPath)
 	if err != nil {
-		return fmt.Errorf("cannot read output of %q %q: %s", suite.directory, test.inFile, err)
+		return fmt.Errorf("cannot read output of %q %q: %s", suite.getDirectory(), test.inFile, err)
 	}
 	defer respFile.Close()
 
-	faxPath := filepath.Join(suite.directory, "fax", test.inFile+".fax")
+	faxPath := filepath.Join(suite.getDirectory(), "fax", test.inFile+".fax")
 	faxFile, err := os.Open(faxPath)
 	if err != nil {
-		return fmt.Errorf("cannot open fax file for %q %q: %s", suite.directory, test.inFile, err)
+		return fmt.Errorf("cannot open fax file for %q %q: %s", suite.getDirectory(), test.inFile, err)
 	}
 	defer faxFile.Close()
 
@@ -225,7 +229,7 @@
 			}
 
 			if !haveFaxLine {
-				return fmt.Errorf("resp file is longer than fax for %q %q", suite.directory, test.inFile)
+				return fmt.Errorf("resp file is longer than fax for %q %q", suite.getDirectory(), test.inFile)
 			}
 
 			if strings.HasPrefix(faxLine, " (Reason: ") {
@@ -239,11 +243,11 @@
 			continue
 		}
 
-		return fmt.Errorf("resp and fax differ at line %d for %q %q: %q vs %q", lineNo, suite.directory, test.inFile, respLine, faxLine)
+		return fmt.Errorf("resp and fax differ at line %d for %q %q: %q vs %q", lineNo, suite.getDirectory(), test.inFile, respLine, faxLine)
 	}
 
 	if faxScanner.Scan() {
-		return fmt.Errorf("fax file is longer than resp for %q %q", suite.directory, test.inFile)
+		return fmt.Errorf("fax file is longer than resp for %q %q", suite.getDirectory(), test.inFile)
 	}
 
 	return nil