Test we can round-trip PKCS8_{encrypt,decrypt}.

This is a very basic test, but it's something.

Change-Id: Ic044297e97ce5719673869113ce581de4621ebbd
Reviewed-on: https://boringssl-review.googlesource.com/13061
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <alangley@gmail.com>
diff --git a/crypto/pkcs8/pkcs8_test.cc b/crypto/pkcs8/pkcs8_test.cc
index af6dede..877884d 100644
--- a/crypto/pkcs8/pkcs8_test.cc
+++ b/crypto/pkcs8/pkcs8_test.cc
@@ -21,6 +21,8 @@
 #include <openssl/pkcs8.h>
 #include <openssl/x509.h>
 
+#include "../internal.h"
+
 
 /* kDER is a PKCS#8 encrypted private key. It was generated with:
  *
@@ -141,7 +143,8 @@
     0xed,
 };
 
-static bool test(const uint8_t *der, size_t der_len, const char *password) {
+static bool TestDecrypt(const uint8_t *der, size_t der_len,
+                        const char *password) {
   const uint8_t *data = der;
   bssl::UniquePtr<X509_SIG> sig(d2i_X509_SIG(NULL, &data, der_len));
   if (sig.get() == NULL || data != der + der_len) {
@@ -160,11 +163,90 @@
   return true;
 }
 
+static bool TestRoundTrip(int pbe_nid, const EVP_CIPHER *cipher,
+                          const char *password, const uint8_t *salt,
+                          size_t salt_len, int iterations) {
+  static const uint8_t kSampleKey[] = {
+      0x30, 0x81, 0x87, 0x02, 0x01, 0x00, 0x30, 0x13, 0x06, 0x07, 0x2a, 0x86,
+      0x48, 0xce, 0x3d, 0x02, 0x01, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d,
+      0x03, 0x01, 0x07, 0x04, 0x6d, 0x30, 0x6b, 0x02, 0x01, 0x01, 0x04, 0x20,
+      0x8a, 0x87, 0x2f, 0xb6, 0x28, 0x93, 0xc4, 0xd1, 0xff, 0xc5, 0xb9, 0xf0,
+      0xf9, 0x17, 0x58, 0x06, 0x9f, 0x83, 0x52, 0xe0, 0x8f, 0xa0, 0x5a, 0x49,
+      0xf8, 0xdb, 0x92, 0x6c, 0xb5, 0x72, 0x87, 0x25, 0xa1, 0x44, 0x03, 0x42,
+      0x00, 0x04, 0x2c, 0x15, 0x0f, 0x42, 0x9c, 0xe7, 0x0f, 0x21, 0x6c, 0x25,
+      0x2c, 0xf5, 0xe0, 0x62, 0xce, 0x1f, 0x63, 0x9c, 0xd5, 0xd1, 0x65, 0xc7,
+      0xf8, 0x94, 0x24, 0x07, 0x2c, 0x27, 0x19, 0x7d, 0x78, 0xb3, 0x3b, 0x92,
+      0x0e, 0x95, 0xcd, 0xb6, 0x64, 0xe9, 0x90, 0xdc, 0xf0, 0xcf, 0xea, 0x0d,
+      0x94, 0xe2, 0xa8, 0xe6, 0xaf, 0x9d, 0x0e, 0x58, 0x05, 0x6e, 0x65, 0x31,
+      0x04, 0x92, 0x5b, 0x9f, 0xe6, 0xc9,
+  };
+
+  const uint8_t *ptr = kSampleKey;
+  bssl::UniquePtr<PKCS8_PRIV_KEY_INFO> key(
+      d2i_PKCS8_PRIV_KEY_INFO(nullptr, &ptr, sizeof(kSampleKey)));
+  if (!key || ptr != kSampleKey + sizeof(kSampleKey)) {
+    return false;
+  }
+
+  bssl::UniquePtr<X509_SIG> encrypted(PKCS8_encrypt(
+      pbe_nid, cipher, password, -1, salt, salt_len, iterations, key.get()));
+  if (!encrypted) {
+    fprintf(stderr, "Failed to encrypt private key.\n");
+    return false;
+  }
+
+  bssl::UniquePtr<PKCS8_PRIV_KEY_INFO> key2(
+      PKCS8_decrypt(encrypted.get(), password, -1));
+  if (!key2) {
+    fprintf(stderr, "Failed to decrypt private key.\n");
+    return false;
+  }
+
+  uint8_t *encoded = nullptr;
+  int len = i2d_PKCS8_PRIV_KEY_INFO(key2.get(), &encoded);
+  bssl::UniquePtr<uint8_t> free_encoded(encoded);
+  if (len < 0 ||
+      static_cast<size_t>(len) != sizeof(kSampleKey) ||
+      OPENSSL_memcmp(encoded, kSampleKey, sizeof(kSampleKey)) != 0) {
+    fprintf(stderr, "Decrypted private key did not round-trip.");
+    return false;
+  }
+
+  return true;
+}
+
 int main(int argc, char **argv) {
-  if (!test(kDER, sizeof(kDER), "testing") ||
-      !test(kNullPassword, sizeof(kNullPassword), NULL) ||
-      !test(kNullPasswordNSS, sizeof(kNullPasswordNSS), NULL) ||
-      !test(kEmptyPasswordOpenSSL, sizeof(kEmptyPasswordOpenSSL), "")) {
+  if (!TestDecrypt(kDER, sizeof(kDER), "testing") ||
+      !TestDecrypt(kNullPassword, sizeof(kNullPassword), NULL) ||
+      !TestDecrypt(kNullPasswordNSS, sizeof(kNullPasswordNSS), NULL) ||
+      !TestDecrypt(kEmptyPasswordOpenSSL, sizeof(kEmptyPasswordOpenSSL), "") ||
+      !TestRoundTrip(NID_pbe_WithSHA1And3_Key_TripleDES_CBC, nullptr,
+                     "password", nullptr, 0, 10) ||
+      // Vary the salt
+      !TestRoundTrip(NID_pbe_WithSHA1And3_Key_TripleDES_CBC, nullptr,
+                     "password", nullptr, 4, 10) ||
+      !TestRoundTrip(NID_pbe_WithSHA1And3_Key_TripleDES_CBC, nullptr,
+                     "password", (const uint8_t *)"salt", 4, 10) ||
+      // Vary the iteration count.
+      !TestRoundTrip(NID_pbe_WithSHA1And3_Key_TripleDES_CBC, nullptr,
+                     "password", nullptr, 0, 1) ||
+      // Vary the password.
+      !TestRoundTrip(NID_pbe_WithSHA1And3_Key_TripleDES_CBC, nullptr, "",
+                     nullptr, 0, 1) ||
+      !TestRoundTrip(NID_pbe_WithSHA1And3_Key_TripleDES_CBC, nullptr, nullptr,
+                     nullptr, 0, 1) ||
+      // Vary the PBE suite.
+      !TestRoundTrip(NID_pbe_WithSHA1And40BitRC2_CBC, nullptr, "password",
+                     nullptr, 0, 10) ||
+      !TestRoundTrip(NID_pbe_WithSHA1And128BitRC4, nullptr, "password", nullptr,
+                     0, 10) ||
+      // Test PBES2.
+      !TestRoundTrip(-1, EVP_aes_128_cbc(), "password", nullptr, 0, 10) ||
+      !TestRoundTrip(-1, EVP_aes_128_cbc(), "password", nullptr, 4, 10) ||
+      !TestRoundTrip(-1, EVP_aes_128_cbc(), "password", (const uint8_t *)"salt",
+                     4, 10) ||
+      !TestRoundTrip(-1, EVP_aes_128_cbc(), "password", nullptr, 0, 1) ||
+      !TestRoundTrip(-1, EVP_rc2_cbc(), "password", nullptr, 0, 10)) {
     return 1;
   }