Fix and test EVP_PKEY_CTX copying.

The RSA-PSS salt length was not being copied, and copying an Ed25519
EVP_MD_CTX did not work.

This is rather pointless (an EVP_PKEY_CTX is just a bundle of
parameters), and it's unlikely anyone ever will use this. But since
OpenSSL's EVP_PKEY signing API reuses EVP_MD_CTX and EVP_MD_CTX_copy_ex
is plausible in that scenario, we're stuck making EVP_MD_CTX_copy_ex
reachable for EVP_PKEY too. That then implies EVP_PKEY_dup should exist,
and if it exists we should be testing it.

Change-Id: I189435d0c716a83f58e1d8ac4abc2c409ecfea64
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/35626
Commit-Queue: David Benjamin <davidben@google.com>
Commit-Queue: Adam Langley <agl@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/evp/evp_test.cc b/crypto/evp/evp_test.cc
index 6571c30..4d74292 100644
--- a/crypto/evp/evp_test.cc
+++ b/crypto/evp/evp_test.cc
@@ -297,16 +297,19 @@
   }
 
   if (md_op_init) {
-    bssl::ScopedEVP_MD_CTX ctx;
+    bssl::ScopedEVP_MD_CTX ctx, copy;
     EVP_PKEY_CTX *pctx;
     if (!md_op_init(ctx.get(), &pctx, digest, nullptr, key) ||
-        !SetupContext(t, pctx)) {
+        !SetupContext(t, pctx) ||
+        !EVP_MD_CTX_copy_ex(copy.get(), ctx.get())) {
       return false;
     }
 
     if (is_verify) {
-      return !!EVP_DigestVerify(ctx.get(), output.data(), output.size(),
-                                input.data(), input.size());
+      return EVP_DigestVerify(ctx.get(), output.data(), output.size(),
+                              input.data(), input.size()) &&
+             EVP_DigestVerify(copy.get(), output.data(), output.size(),
+                              input.data(), input.size());
     }
 
     size_t len;
@@ -321,6 +324,21 @@
     }
     actual.resize(len);
     EXPECT_EQ(Bytes(output), Bytes(actual));
+
+    // Repeat the test with |copy|, to check |EVP_MD_CTX_copy_ex| duplicated
+    // everything.
+    if (!EVP_DigestSign(copy.get(), nullptr, &len, input.data(),
+                        input.size())) {
+      return false;
+    }
+    actual.resize(len);
+    if (!EVP_DigestSign(copy.get(), actual.data(), &len, input.data(),
+                        input.size()) ||
+        !t->GetBytes(&output, "Output")) {
+      return false;
+    }
+    actual.resize(len);
+    EXPECT_EQ(Bytes(output), Bytes(actual));
     return true;
   }
 
@@ -333,72 +351,78 @@
     return false;
   }
 
+  bssl::UniquePtr<EVP_PKEY_CTX> copy(EVP_PKEY_CTX_dup(ctx.get()));
+  if (!copy) {
+    return false;
+  }
+
   if (is_verify) {
-    return !!EVP_PKEY_verify(ctx.get(), output.data(), output.size(),
-                             input.data(), input.size());
+    return EVP_PKEY_verify(ctx.get(), output.data(), output.size(),
+                           input.data(), input.size()) &&
+           EVP_PKEY_verify(copy.get(), output.data(), output.size(),
+                           input.data(), input.size());
   }
 
-  size_t len;
-  if (!key_op(ctx.get(), nullptr, &len, input.data(), input.size())) {
-    return false;
-  }
-  actual.resize(len);
-  if (!key_op(ctx.get(), actual.data(), &len, input.data(), input.size())) {
-    return false;
-  }
+  for (EVP_PKEY_CTX *pctx : {ctx.get(), copy.get()}) {
+    size_t len;
+    if (!key_op(pctx, nullptr, &len, input.data(), input.size())) {
+      return false;
+    }
+    actual.resize(len);
+    if (!key_op(pctx, actual.data(), &len, input.data(), input.size())) {
+      return false;
+    }
 
-  // Encryption is non-deterministic, so we check by decrypting.
-  if (t->HasAttribute("CheckDecrypt")) {
-    size_t plaintext_len;
-    ctx.reset(EVP_PKEY_CTX_new(key, nullptr));
-    if (!ctx ||
-        !EVP_PKEY_decrypt_init(ctx.get()) ||
-        (digest != nullptr &&
-         !EVP_PKEY_CTX_set_signature_md(ctx.get(), digest)) ||
-        !SetupContext(t, ctx.get()) ||
-        !EVP_PKEY_decrypt(ctx.get(), nullptr, &plaintext_len, actual.data(),
-                          actual.size())) {
-      return false;
+    if (t->HasAttribute("CheckDecrypt")) {
+      // Encryption is non-deterministic, so we check by decrypting.
+      size_t plaintext_len;
+      bssl::UniquePtr<EVP_PKEY_CTX> decrypt_ctx(EVP_PKEY_CTX_new(key, nullptr));
+      if (!decrypt_ctx ||
+          !EVP_PKEY_decrypt_init(decrypt_ctx.get()) ||
+          (digest != nullptr &&
+           !EVP_PKEY_CTX_set_signature_md(decrypt_ctx.get(), digest)) ||
+          !SetupContext(t, decrypt_ctx.get()) ||
+          !EVP_PKEY_decrypt(decrypt_ctx.get(), nullptr, &plaintext_len,
+                            actual.data(), actual.size())) {
+        return false;
+      }
+      output.resize(plaintext_len);
+      if (!EVP_PKEY_decrypt(decrypt_ctx.get(), output.data(), &plaintext_len,
+                            actual.data(), actual.size())) {
+        ADD_FAILURE() << "Could not decrypt result.";
+        return false;
+      }
+      output.resize(plaintext_len);
+      EXPECT_EQ(Bytes(input), Bytes(output)) << "Decrypted result mismatch.";
+    } else if (t->HasAttribute("CheckVerify")) {
+      // Some signature schemes are non-deterministic, so we check by verifying.
+      bssl::UniquePtr<EVP_PKEY_CTX> verify_ctx(EVP_PKEY_CTX_new(key, nullptr));
+      if (!verify_ctx ||
+          !EVP_PKEY_verify_init(verify_ctx.get()) ||
+          (digest != nullptr &&
+           !EVP_PKEY_CTX_set_signature_md(verify_ctx.get(), digest)) ||
+          !SetupContext(t, verify_ctx.get())) {
+        return false;
+      }
+      if (t->HasAttribute("VerifyPSSSaltLength")) {
+        if (!EVP_PKEY_CTX_set_rsa_pss_saltlen(
+                verify_ctx.get(),
+                atoi(t->GetAttributeOrDie("VerifyPSSSaltLength").c_str()))) {
+          return false;
+        }
+      }
+      EXPECT_TRUE(EVP_PKEY_verify(verify_ctx.get(), actual.data(),
+                                  actual.size(), input.data(), input.size()))
+          << "Could not verify result.";
+    } else {
+      // By default, check by comparing the result against Output.
+      if (!t->GetBytes(&output, "Output")) {
+        return false;
+      }
+      actual.resize(len);
+      EXPECT_EQ(Bytes(output), Bytes(actual));
     }
-    output.resize(plaintext_len);
-    if (!EVP_PKEY_decrypt(ctx.get(), output.data(), &plaintext_len,
-                          actual.data(), actual.size())) {
-      ADD_FAILURE() << "Could not decrypt result.";
-      return false;
-    }
-    output.resize(plaintext_len);
-    EXPECT_EQ(Bytes(input), Bytes(output)) << "Decrypted result mismatch.";
-    return true;
   }
-
-  // Some signature schemes are non-deterministic, so we check by verifying.
-  if (t->HasAttribute("CheckVerify")) {
-    ctx.reset(EVP_PKEY_CTX_new(key, nullptr));
-    if (!ctx ||
-        !EVP_PKEY_verify_init(ctx.get()) ||
-        (digest != nullptr &&
-         !EVP_PKEY_CTX_set_signature_md(ctx.get(), digest)) ||
-        !SetupContext(t, ctx.get())) {
-      return false;
-    }
-    if (t->HasAttribute("VerifyPSSSaltLength") &&
-        !EVP_PKEY_CTX_set_rsa_pss_saltlen(
-            ctx.get(),
-            atoi(t->GetAttributeOrDie("VerifyPSSSaltLength").c_str()))) {
-      return false;
-    }
-    EXPECT_TRUE(EVP_PKEY_verify(ctx.get(), actual.data(), actual.size(),
-                                input.data(), input.size()))
-        << "Could not verify result.";
-    return true;
-  }
-
-  // By default, check by comparing the result against Output.
-  if (!t->GetBytes(&output, "Output")) {
-    return false;
-  }
-  actual.resize(len);
-  EXPECT_EQ(Bytes(output), Bytes(actual));
   return true;
 }
 
diff --git a/crypto/evp/evp_tests.txt b/crypto/evp/evp_tests.txt
index ff08ee7..9dbe1cb 100644
--- a/crypto/evp/evp_tests.txt
+++ b/crypto/evp/evp_tests.txt
@@ -274,6 +274,17 @@
 Input = "0123456789ABCDEF0123456789ABCDEF"
 CheckVerify
 
+# Check a salt length with a non-standard digest length, to verify things are
+# not just working due to defaults. (The current default is a maximum salt
+# length, but the ecosystem has converged on matching the digest length, so we
+# may change this in the future.)
+Sign = RSA-2048
+RSAPadding = PSS
+PSSSaltLength = 42
+Digest = SHA256
+Input = "0123456789ABCDEF0123456789ABCDEF"
+CheckVerify
+
 # Auto-detected salt length
 Verify = RSA-2048-SPKI
 RSAPadding = PSS
diff --git a/crypto/evp/p_rsa.c b/crypto/evp/p_rsa.c
index eb59901..865b36a 100644
--- a/crypto/evp/p_rsa.c
+++ b/crypto/evp/p_rsa.c
@@ -132,6 +132,7 @@
   dctx->pad_mode = sctx->pad_mode;
   dctx->md = sctx->md;
   dctx->mgf1md = sctx->mgf1md;
+  dctx->saltlen = sctx->saltlen;
   if (sctx->oaep_label) {
     OPENSSL_free(dctx->oaep_label);
     dctx->oaep_label = BUF_memdup(sctx->oaep_label, sctx->oaep_labellen);
diff --git a/crypto/fipsmodule/digest/digest.c b/crypto/fipsmodule/digest/digest.c
index e49d552..6705867 100644
--- a/crypto/fipsmodule/digest/digest.c
+++ b/crypto/fipsmodule/digest/digest.c
@@ -116,7 +116,9 @@
 void EVP_MD_CTX_destroy(EVP_MD_CTX *ctx) { EVP_MD_CTX_free(ctx); }
 
 int EVP_MD_CTX_copy_ex(EVP_MD_CTX *out, const EVP_MD_CTX *in) {
-  if (in == NULL || in->digest == NULL) {
+  // |in->digest| may be NULL if this is a signing |EVP_MD_CTX| for, e.g.,
+  // Ed25519 which does not hash with |EVP_MD_CTX|.
+  if (in == NULL || (in->pctx == NULL && in->digest == NULL)) {
     OPENSSL_PUT_ERROR(DIGEST, DIGEST_R_INPUT_NOT_INITIALIZED);
     return 0;
   }
@@ -131,29 +133,34 @@
     }
   }
 
-  uint8_t *tmp_buf;
-  if (out->digest != in->digest) {
-    assert(in->digest->ctx_size != 0);
-    tmp_buf = OPENSSL_malloc(in->digest->ctx_size);
-    if (tmp_buf == NULL) {
-      if (pctx) {
-        in->pctx_ops->free(pctx);
+  uint8_t *tmp_buf = NULL;
+  if (in->digest != NULL) {
+    if (out->digest != in->digest) {
+      assert(in->digest->ctx_size != 0);
+      tmp_buf = OPENSSL_malloc(in->digest->ctx_size);
+      if (tmp_buf == NULL) {
+        if (pctx) {
+          in->pctx_ops->free(pctx);
+        }
+        OPENSSL_PUT_ERROR(DIGEST, ERR_R_MALLOC_FAILURE);
+        return 0;
       }
-      OPENSSL_PUT_ERROR(DIGEST, ERR_R_MALLOC_FAILURE);
-      return 0;
+    } else {
+      // |md_data| will be the correct size in this case. It's removed from
+      // |out| so that |EVP_MD_CTX_cleanup| doesn't free it, and then it's
+      // reused.
+      tmp_buf = out->md_data;
+      out->md_data = NULL;
     }
-  } else {
-    // |md_data| will be the correct size in this case. It's removed from |out|
-    // so that |EVP_MD_CTX_cleanup| doesn't free it, and then it's reused.
-    tmp_buf = out->md_data;
-    out->md_data = NULL;
   }
 
   EVP_MD_CTX_cleanup(out);
 
   out->digest = in->digest;
   out->md_data = tmp_buf;
-  OPENSSL_memcpy(out->md_data, in->md_data, in->digest->ctx_size);
+  if (in->digest != NULL) {
+    OPENSSL_memcpy(out->md_data, in->md_data, in->digest->ctx_size);
+  }
   out->pctx = pctx;
   out->pctx_ops = in->pctx_ops;
   assert(out->pctx == NULL || out->pctx_ops != NULL);