Add PKCS8_{decrypt,encrypt}_pbe.

The original functions do an ascii_to_ucs2 transformation on the password.
Deprecate them in favor of making that encoding the caller's problem.
ascii_to_ucs2 doesn't handle, say, UTF-8 anyway. And with the original OpenSSL
function, some ciphers would do the transformation, and some wouldn't making
the text-string/bytes-string confusion even messier.

BUG=399121

Change-Id: I7d1cea20a260f21eec2e8ffb7cd6be239fe92873
Reviewed-on: https://boringssl-review.googlesource.com/1347
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/pkcs8/pkcs8.c b/crypto/pkcs8/pkcs8.c
index 310af83..b39f10b 100644
--- a/crypto/pkcs8/pkcs8.c
+++ b/crypto/pkcs8/pkcs8.c
@@ -70,7 +70,7 @@
 #define PKCS12_IV_ID 2
 
 static int ascii_to_ucs2(const char *ascii, size_t ascii_len,
-                              uint8_t **out, size_t *out_len) {
+                         uint8_t **out, size_t *out_len) {
   uint8_t *unitmp;
   size_t ulen, i;
 
@@ -95,8 +95,9 @@
   return 1;
 }
 
-static int pkcs12_key_gen_uni(uint8_t *pass, size_t pass_len, uint8_t *salt,
-                              size_t salt_len, int id, int iterations,
+static int pkcs12_key_gen_raw(const uint8_t *pass_raw, size_t pass_raw_len,
+                              uint8_t *salt, size_t salt_len,
+                              int id, int iterations,
                               size_t out_len, uint8_t *out,
                               const EVP_MD *md_type) {
   uint8_t *B, *D, *I, *p, *Ai;
@@ -114,8 +115,8 @@
   Ai = OPENSSL_malloc(u);
   B = OPENSSL_malloc(v + 1);
   Slen = v * ((salt_len + v - 1) / v);
-  if (pass_len)
-    Plen = v * ((pass_len + v - 1) / v);
+  if (pass_raw_len)
+    Plen = v * ((pass_raw_len + v - 1) / v);
   else
     Plen = 0;
   Ilen = Slen + Plen;
@@ -130,7 +131,7 @@
   for (i = 0; i < Slen; i++)
     *p++ = salt[i % salt_len];
   for (i = 0; i < Plen; i++)
-    *p++ = pass[i % pass_len];
+    *p++ = pass_raw[i % pass_raw_len];
   for (;;) {
     if (!EVP_DigestInit_ex(&ctx, md_type, NULL) ||
         !EVP_DigestUpdate(&ctx, D, v) ||
@@ -184,7 +185,7 @@
   }
 
 err:
-  OPENSSL_PUT_ERROR(PKCS8, pkcs12_key_gen_uni, ERR_R_MALLOC_FAILURE);
+  OPENSSL_PUT_ERROR(PKCS8, pkcs12_key_gen_raw, ERR_R_MALLOC_FAILURE);
 
 end:
   OPENSSL_free(Ai);
@@ -198,31 +199,8 @@
   return ret;
 }
 
-static int pkcs12_key_gen_asc(const char *pass, size_t pass_len, uint8_t *salt,
-                              size_t salt_len, int id, int iterations,
-                              int out_len, uint8_t *out,
-                              const EVP_MD *md_type) {
-  int ret;
-  uint8_t *ucs2_pass = NULL;
-  size_t ucs2_pass_len = 0;
-
-  if (pass && !ascii_to_ucs2(pass, pass_len, &ucs2_pass, &ucs2_pass_len)) {
-    OPENSSL_PUT_ERROR(PKCS8, pkcs12_key_gen_asc, PKCS8_R_DECODE_ERROR);
-    return 0;
-  }
-  ret = pkcs12_key_gen_uni(ucs2_pass, ucs2_pass_len, salt, salt_len, id,
-                           iterations, out_len, out, md_type);
-
-  if (ucs2_pass) {
-    OPENSSL_cleanse(ucs2_pass, ucs2_pass_len);
-    OPENSSL_free(ucs2_pass);
-  }
-
-  return ret;
-}
-
-static int pkcs12_pbe_keyivgen(EVP_CIPHER_CTX *ctx, const char *pass,
-                               size_t pass_len, ASN1_TYPE *param,
+static int pkcs12_pbe_keyivgen(EVP_CIPHER_CTX *ctx, const uint8_t *pass_raw,
+                               size_t pass_raw_len, ASN1_TYPE *param,
                                const EVP_CIPHER *cipher, const EVP_MD *md,
                                int is_encrypt) {
   PBEPARAM *pbe;
@@ -252,13 +230,13 @@
   }
   salt = pbe->salt->data;
   salt_len = pbe->salt->length;
-  if (!pkcs12_key_gen_asc(pass, pass_len, salt, salt_len, PKCS12_KEY_ID,
+  if (!pkcs12_key_gen_raw(pass_raw, pass_raw_len, salt, salt_len, PKCS12_KEY_ID,
                           iterations, EVP_CIPHER_key_length(cipher), key, md)) {
     OPENSSL_PUT_ERROR(PKCS8, pkcs12_pbe_keyivgen, PKCS8_R_KEY_GEN_ERROR);
     PBEPARAM_free(pbe);
     return 0;
   }
-  if (!pkcs12_key_gen_asc(pass, pass_len, salt, salt_len, PKCS12_IV_ID,
+  if (!pkcs12_key_gen_raw(pass_raw, pass_raw_len, salt, salt_len, PKCS12_IV_ID,
                           iterations, EVP_CIPHER_iv_length(cipher), iv, md)) {
     OPENSSL_PUT_ERROR(PKCS8, pkcs12_pbe_keyivgen, PKCS8_R_KEY_GEN_ERROR);
     PBEPARAM_free(pbe);
@@ -271,8 +249,8 @@
   return ret;
 }
 
-typedef int (*keygen_func)(EVP_CIPHER_CTX *ctx, const char *pass,
-                           size_t pass_len, ASN1_TYPE *param,
+typedef int (*keygen_func)(EVP_CIPHER_CTX *ctx, const uint8_t *pass_raw,
+                           size_t pass_raw_len, ASN1_TYPE *param,
                            const EVP_CIPHER *cipher, const EVP_MD *md,
                            int is_encrypt);
 
@@ -293,8 +271,9 @@
     },
 };
 
-static int pbe_cipher_init(ASN1_OBJECT *pbe_obj, const char *pass,
-                           size_t pass_len, ASN1_TYPE *param,
+static int pbe_cipher_init(ASN1_OBJECT *pbe_obj,
+                           const uint8_t *pass_raw, size_t pass_raw_len,
+                           ASN1_TYPE *param,
                            EVP_CIPHER_CTX *ctx, int is_encrypt) {
   const EVP_CIPHER *cipher;
   const EVP_MD *md;
@@ -342,7 +321,8 @@
     }
   }
 
-  if (!suite->keygen(ctx, pass, pass_len, param, cipher, md, is_encrypt)) {
+  if (!suite->keygen(ctx, pass_raw, pass_raw_len, param, cipher, md,
+                     is_encrypt)) {
     OPENSSL_PUT_ERROR(PKCS8, pbe_cipher_init, PKCS8_R_KEYGEN_FAILURE);
     return 0;
   }
@@ -350,7 +330,8 @@
   return 1;
 }
 
-static int pbe_crypt(const X509_ALGOR *algor, const char *pass, size_t pass_len,
+static int pbe_crypt(const X509_ALGOR *algor,
+                     const uint8_t *pass_raw, size_t pass_raw_len,
                      uint8_t *in, size_t in_len, uint8_t **out, size_t *out_len,
                      int is_encrypt) {
   uint8_t *buf;
@@ -360,8 +341,8 @@
 
   EVP_CIPHER_CTX_init(&ctx);
 
-  if (!pbe_cipher_init(algor->algorithm, pass, pass_len, algor->parameter, &ctx,
-                       is_encrypt)) {
+  if (!pbe_cipher_init(algor->algorithm, pass_raw, pass_raw_len,
+                       algor->parameter, &ctx, is_encrypt)) {
     OPENSSL_PUT_ERROR(PKCS8, pbe_crypt, PKCS8_R_UNKNOWN_CIPHER_ALGORITHM);
     return 0;
   }
@@ -400,15 +381,16 @@
 }
 
 static void *pkcs12_item_decrypt_d2i(X509_ALGOR *algor, const ASN1_ITEM *it,
-                                     const char *pass, size_t pass_len,
+                                     const uint8_t *pass_raw,
+                                     size_t pass_raw_len,
                                      ASN1_OCTET_STRING *oct) {
   uint8_t *out;
   const uint8_t *p;
   void *ret;
   size_t out_len;
 
-  if (!pbe_crypt(algor, pass, pass_len, oct->data, oct->length, &out, &out_len,
-                 0 /* decrypt */)) {
+  if (!pbe_crypt(algor, pass_raw, pass_raw_len, oct->data, oct->length,
+                 &out, &out_len, 0 /* decrypt */)) {
     OPENSSL_PUT_ERROR(PKCS8, pkcs12_item_decrypt_d2i, PKCS8_R_CRYPT_ERROR);
     return NULL;
   }
@@ -424,18 +406,40 @@
 
 PKCS8_PRIV_KEY_INFO *PKCS8_decrypt(X509_SIG *pkcs8, const char *pass,
                                    int pass_len) {
-  if (pass && pass_len == -1) {
-    pass_len = strlen(pass);
+  uint8_t *pass_raw = NULL;
+  size_t pass_raw_len = 0;
+  PKCS8_PRIV_KEY_INFO *ret;
+
+  if (pass) {
+    if (pass_len == -1) {
+      pass_len = strlen(pass);
+    }
+    if (!ascii_to_ucs2(pass, pass_len, &pass_raw, &pass_raw_len)) {
+      OPENSSL_PUT_ERROR(PKCS8, pkcs12_key_gen_asc, PKCS8_R_DECODE_ERROR);
+      return NULL;
+    }
   }
+
+  ret = PKCS8_decrypt_pbe(pkcs8, pass_raw, pass_raw_len);
+
+  if (pass_raw) {
+    OPENSSL_cleanse(pass_raw, pass_raw_len);
+    OPENSSL_free(pass_raw);
+  }
+  return ret;
+}
+
+PKCS8_PRIV_KEY_INFO *PKCS8_decrypt_pbe(X509_SIG *pkcs8, const uint8_t *pass_raw,
+                                       size_t pass_raw_len) {
   return pkcs12_item_decrypt_d2i(pkcs8->algor,
-                                 ASN1_ITEM_rptr(PKCS8_PRIV_KEY_INFO), pass,
-                                 pass_len, pkcs8->digest);
+                                 ASN1_ITEM_rptr(PKCS8_PRIV_KEY_INFO), pass_raw,
+                                 pass_raw_len, pkcs8->digest);
 }
 
 static ASN1_OCTET_STRING *pkcs12_item_i2d_encrypt(X509_ALGOR *algor,
                                                   const ASN1_ITEM *it,
-                                                  const char *pass,
-                                                  size_t passlen, void *obj) {
+                                                  const uint8_t *pass_raw,
+                                                  size_t pass_raw_len, void *obj) {
   ASN1_OCTET_STRING *oct;
   uint8_t *in = NULL;
   int in_len;
@@ -451,7 +455,7 @@
     OPENSSL_PUT_ERROR(PKCS8, pkcs12_item_i2d_encrypt, PKCS8_R_ENCODE_ERROR);
     return NULL;
   }
-  if (!pbe_crypt(algor, pass, passlen, in, in_len, &oct->data, &crypt_len,
+  if (!pbe_crypt(algor, pass_raw, pass_raw_len, in, in_len, &oct->data, &crypt_len,
                  1 /* encrypt */)) {
     OPENSSL_PUT_ERROR(PKCS8, pkcs12_item_i2d_encrypt, PKCS8_R_ENCRYPT_ERROR);
     OPENSSL_free(in);
@@ -466,27 +470,46 @@
 X509_SIG *PKCS8_encrypt(int pbe_nid, const EVP_CIPHER *cipher, const char *pass,
                         int pass_len, uint8_t *salt, size_t salt_len,
                         int iterations, PKCS8_PRIV_KEY_INFO *p8inf) {
+  uint8_t *pass_raw = NULL;
+  size_t pass_raw_len = 0;
+  X509_SIG *ret;
+
+  if (pass) {
+    if (pass_len == -1) {
+      pass_len = strlen(pass);
+    }
+    if (!ascii_to_ucs2(pass, pass_len, &pass_raw, &pass_raw_len)) {
+      OPENSSL_PUT_ERROR(PKCS8, pkcs12_key_gen_asc, PKCS8_R_DECODE_ERROR);
+      return NULL;
+    }
+  }
+
+  ret = PKCS8_encrypt_pbe(pbe_nid, pass_raw, pass_raw_len,
+                          salt, salt_len, iterations, p8inf);
+
+  if (pass_raw) {
+    OPENSSL_cleanse(pass_raw, pass_raw_len);
+    OPENSSL_free(pass_raw);
+  }
+  return ret;
+}
+
+X509_SIG *PKCS8_encrypt_pbe(int pbe_nid,
+                            const uint8_t *pass_raw, size_t pass_raw_len,
+                            uint8_t *salt, size_t salt_len,
+                            int iterations, PKCS8_PRIV_KEY_INFO *p8inf) {
   X509_SIG *pkcs8 = NULL;
   X509_ALGOR *pbe;
 
-  if (pass && pass_len == -1) {
-    pass_len = strlen(pass);
-  }
-
   pkcs8 = X509_SIG_new();
   if (pkcs8 == NULL) {
-    OPENSSL_PUT_ERROR(PKCS8, PKCS8_encrypt, ERR_R_MALLOC_FAILURE);
+    OPENSSL_PUT_ERROR(PKCS8, PKCS8_encrypt_pbe, ERR_R_MALLOC_FAILURE);
     goto err;
   }
 
-  if (pbe_nid == -1) {
-    pbe = PKCS5_pbe2_set(cipher, iterations, salt, salt_len);
-  } else {
-    pbe = PKCS5_pbe_set(pbe_nid, iterations, salt, salt_len);
-  }
-
+  pbe = PKCS5_pbe_set(pbe_nid, iterations, salt, salt_len);
   if (!pbe) {
-    OPENSSL_PUT_ERROR(PKCS8, PKCS8_encrypt, ERR_R_ASN1_LIB);
+    OPENSSL_PUT_ERROR(PKCS8, PKCS8_encrypt_pbe, ERR_R_ASN1_LIB);
     goto err;
   }
 
@@ -494,9 +517,9 @@
   pkcs8->algor = pbe;
   M_ASN1_OCTET_STRING_free(pkcs8->digest);
   pkcs8->digest = pkcs12_item_i2d_encrypt(
-      pbe, ASN1_ITEM_rptr(PKCS8_PRIV_KEY_INFO), pass, pass_len, p8inf);
+      pbe, ASN1_ITEM_rptr(PKCS8_PRIV_KEY_INFO), pass_raw, pass_raw_len, p8inf);
   if (!pkcs8->digest) {
-    OPENSSL_PUT_ERROR(PKCS8, PKCS8_encrypt, PKCS8_R_ENCRYPT_ERROR);
+    OPENSSL_PUT_ERROR(PKCS8, PKCS8_encrypt_pbe, PKCS8_R_ENCRYPT_ERROR);
     goto err;
   }
 
diff --git a/crypto/pkcs8/pkcs8_error.c b/crypto/pkcs8/pkcs8_error.c
index 7536bd1..de1661b 100644
--- a/crypto/pkcs8/pkcs8_error.c
+++ b/crypto/pkcs8/pkcs8_error.c
@@ -23,13 +23,14 @@
   {ERR_PACK(ERR_LIB_PKCS8, PKCS8_F_PKCS5_pbe_set, 0), "PKCS5_pbe_set"},
   {ERR_PACK(ERR_LIB_PKCS8, PKCS8_F_PKCS5_pbe_set0_algor, 0), "PKCS5_pbe_set0_algor"},
   {ERR_PACK(ERR_LIB_PKCS8, PKCS8_F_PKCS5_pbkdf2_set, 0), "PKCS5_pbkdf2_set"},
+  {ERR_PACK(ERR_LIB_PKCS8, PKCS8_F_PKCS8_decrypt, 0), "PKCS8_decrypt"},
   {ERR_PACK(ERR_LIB_PKCS8, PKCS8_F_PKCS8_encrypt, 0), "PKCS8_encrypt"},
+  {ERR_PACK(ERR_LIB_PKCS8, PKCS8_F_PKCS8_encrypt_pbe, 0), "PKCS8_encrypt_pbe"},
   {ERR_PACK(ERR_LIB_PKCS8, PKCS8_F_pbe_cipher_init, 0), "pbe_cipher_init"},
   {ERR_PACK(ERR_LIB_PKCS8, PKCS8_F_pbe_crypt, 0), "pbe_crypt"},
   {ERR_PACK(ERR_LIB_PKCS8, PKCS8_F_pkcs12_item_decrypt_d2i, 0), "pkcs12_item_decrypt_d2i"},
   {ERR_PACK(ERR_LIB_PKCS8, PKCS8_F_pkcs12_item_i2d_encrypt, 0), "pkcs12_item_i2d_encrypt"},
-  {ERR_PACK(ERR_LIB_PKCS8, PKCS8_F_pkcs12_key_gen_asc, 0), "pkcs12_key_gen_asc"},
-  {ERR_PACK(ERR_LIB_PKCS8, PKCS8_F_pkcs12_key_gen_uni, 0), "pkcs12_key_gen_uni"},
+  {ERR_PACK(ERR_LIB_PKCS8, PKCS8_F_pkcs12_key_gen_raw, 0), "pkcs12_key_gen_raw"},
   {ERR_PACK(ERR_LIB_PKCS8, PKCS8_F_pkcs12_pbe_keyivgen, 0), "pkcs12_pbe_keyivgen"},
   {ERR_PACK(ERR_LIB_PKCS8, 0, PKCS8_R_CIPHER_HAS_NO_OBJECT_IDENTIFIER), "CIPHER_HAS_NO_OBJECT_IDENTIFIER"},
   {ERR_PACK(ERR_LIB_PKCS8, 0, PKCS8_R_CRYPT_ERROR), "CRYPT_ERROR"},
diff --git a/include/openssl/pkcs8.h b/include/openssl/pkcs8.h
index 917c7db..d871e57 100644
--- a/include/openssl/pkcs8.h
+++ b/include/openssl/pkcs8.h
@@ -65,12 +65,56 @@
 extern "C" {
 #endif
 
+
+/* PKCS8_encrypt_pbe serializes and encrypts a PKCS8_PRIV_KEY_INFO with PBES1 as
+ * defined in PKCS #5. Only pbeWithSHAAnd128BitRC4 and
+ * pbeWithSHAAnd3-KeyTripleDES-CBC, defined in PKCS #12, are supported. The
+ * |pass_raw_len| bytes pointed to by |pass_raw| are used as the password. Note
+ * that any conversions from the password as supplied in a text string (such as
+ * those specified in B.1 of PKCS #12) must be performed by the caller.
+ *
+ * If |salt| is NULL, a random salt of |salt_len| bytes is generated. If
+ * |salt_len| is zero, a default salt length is used instead.
+ *
+ * The resulting structure is stored in an X509_SIG which must be freed by the
+ * caller.
+ *
+ * TODO(davidben): Really? An X509_SIG? OpenSSL probably did that because it has
+ * the same structure as EncryptedPrivateKeyInfo. */
+OPENSSL_EXPORT X509_SIG *PKCS8_encrypt_pbe(int pbe_nid,
+                                           const uint8_t *pass_raw,
+                                           size_t pass_raw_len,
+                                           uint8_t *salt, size_t salt_len,
+                                           int iterations,
+                                           PKCS8_PRIV_KEY_INFO *p8inf);
+
+/* PKCS8_decrypt_pbe decrypts and decodes a PKCS8_PRIV_KEY_INFO with PBES1 as
+ * defined in PKCS #5. Only pbeWithSHAAnd128BitRC4 and
+ * pbeWithSHAAnd3-KeyTripleDES-CBC, defined in PKCS #12, are supported. The
+ * |pass_raw_len| bytes pointed to by |pass_raw| are used as the password. Note
+ * that any conversions from the password as supplied in a text string (such as
+ * those specified in B.1 of PKCS #12) must be performed by the caller.
+ *
+ * The resulting structure must be freed by the caller. */
+OPENSSL_EXPORT PKCS8_PRIV_KEY_INFO *PKCS8_decrypt_pbe(X509_SIG *pkcs8,
+                                                      const uint8_t *pass_raw,
+                                                      size_t pass_raw_len);
+
+
+/* Deprecated functions. */
+
+/* PKCS8_encrypt calls PKCS8_encrypt_pbe after treating |pass| as an ASCII
+ * string, appending U+0000, and converting to UCS-2. (So the empty password
+ * encodes as two NUL bytes.) The |cipher| argument is ignored. */
 OPENSSL_EXPORT X509_SIG *PKCS8_encrypt(int pbe_nid, const EVP_CIPHER *cipher,
                                        const char *pass, int pass_len,
                                        uint8_t *salt, size_t salt_len,
                                        int iterations,
                                        PKCS8_PRIV_KEY_INFO *p8inf);
 
+/* PKCS8_decrypt calls PKCS8_decrypt_pbe after treating |pass| as an ASCII
+ * string, appending U+0000, and converting to UCS-2. (So the empty password
+ * encodes as two NUL bytes.) */
 OPENSSL_EXPORT PKCS8_PRIV_KEY_INFO *PKCS8_decrypt(X509_SIG *pkcs8,
                                                   const char *pass,
                                                   int pass_len);
@@ -94,6 +138,9 @@
 #define PKCS8_F_pkcs12_item_i2d_encrypt 111
 #define PKCS8_F_PKCS5_pbe2_set_iv 112
 #define PKCS8_F_PKCS5_pbkdf2_set 113
+#define PKCS8_F_pkcs12_key_gen_raw 114
+#define PKCS8_F_PKCS8_decrypt 115
+#define PKCS8_F_PKCS8_encrypt_pbe 116
 #define PKCS8_R_ERROR_SETTING_CIPHER_PARAMS 100
 #define PKCS8_R_PRIVATE_KEY_ENCODE_ERROR 101
 #define PKCS8_R_UNKNOWN_ALGORITHM 102