Add get0 getters for EVP_PKEY.

Right now your options are:
- Bounce on a reference and deal with cleanup needlessly.
- Manually check the type tag and peek into the union.

We probably have no hope of opaquifying this struct, but for new code, let's
recommend using this function rather than the more error-prone thing.

Change-Id: I9b39ff95fe4264a3f7d1e0d2894db337aa968f6c
Reviewed-on: https://boringssl-review.googlesource.com/6551
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/evp/evp.c b/crypto/evp/evp.c
index 188a368..097a31c 100644
--- a/crypto/evp/evp.c
+++ b/crypto/evp/evp.c
@@ -231,15 +231,22 @@
   return EVP_PKEY_assign(pkey, EVP_PKEY_RSA, key);
 }
 
-RSA *EVP_PKEY_get1_RSA(EVP_PKEY *pkey) {
+RSA *EVP_PKEY_get0_RSA(EVP_PKEY *pkey) {
   if (pkey->type != EVP_PKEY_RSA) {
     OPENSSL_PUT_ERROR(EVP, EVP_R_EXPECTING_AN_RSA_KEY);
     return NULL;
   }
-  RSA_up_ref(pkey->pkey.rsa);
   return pkey->pkey.rsa;
 }
 
+RSA *EVP_PKEY_get1_RSA(EVP_PKEY *pkey) {
+  RSA *rsa = EVP_PKEY_get0_RSA(pkey);
+  if (rsa != NULL) {
+    RSA_up_ref(rsa);
+  }
+  return rsa;
+}
+
 int EVP_PKEY_set1_DSA(EVP_PKEY *pkey, DSA *key) {
   if (EVP_PKEY_assign_DSA(pkey, key)) {
     DSA_up_ref(key);
@@ -252,15 +259,22 @@
   return EVP_PKEY_assign(pkey, EVP_PKEY_DSA, key);
 }
 
-DSA *EVP_PKEY_get1_DSA(EVP_PKEY *pkey) {
+DSA *EVP_PKEY_get0_DSA(EVP_PKEY *pkey) {
   if (pkey->type != EVP_PKEY_DSA) {
     OPENSSL_PUT_ERROR(EVP, EVP_R_EXPECTING_A_DSA_KEY);
     return NULL;
   }
-  DSA_up_ref(pkey->pkey.dsa);
   return pkey->pkey.dsa;
 }
 
+DSA *EVP_PKEY_get1_DSA(EVP_PKEY *pkey) {
+  DSA *dsa = EVP_PKEY_get0_DSA(pkey);
+  if (dsa != NULL) {
+    DSA_up_ref(dsa);
+  }
+  return dsa;
+}
+
 int EVP_PKEY_set1_EC_KEY(EVP_PKEY *pkey, EC_KEY *key) {
   if (EVP_PKEY_assign_EC_KEY(pkey, key)) {
     EC_KEY_up_ref(key);
@@ -273,15 +287,22 @@
   return EVP_PKEY_assign(pkey, EVP_PKEY_EC, key);
 }
 
-EC_KEY *EVP_PKEY_get1_EC_KEY(EVP_PKEY *pkey) {
+EC_KEY *EVP_PKEY_get0_EC_KEY(EVP_PKEY *pkey) {
   if (pkey->type != EVP_PKEY_EC) {
     OPENSSL_PUT_ERROR(EVP, EVP_R_EXPECTING_AN_EC_KEY_KEY);
     return NULL;
   }
-  EC_KEY_up_ref(pkey->pkey.ec);
   return pkey->pkey.ec;
 }
 
+EC_KEY *EVP_PKEY_get1_EC_KEY(EVP_PKEY *pkey) {
+  EC_KEY *ec_key = EVP_PKEY_get0_EC_KEY(pkey);
+  if (ec_key != NULL) {
+    EC_KEY_up_ref(ec_key);
+  }
+  return ec_key;
+}
+
 int EVP_PKEY_set1_DH(EVP_PKEY *pkey, DH *key) {
   if (EVP_PKEY_assign_DH(pkey, key)) {
     DH_up_ref(key);
diff --git a/include/openssl/evp.h b/include/openssl/evp.h
index 6f594e5..e479e5e 100644
--- a/include/openssl/evp.h
+++ b/include/openssl/evp.h
@@ -143,19 +143,24 @@
  * The following functions get and set the underlying public key in an
  * |EVP_PKEY| object. The |set1| functions take an additional reference to the
  * underlying key and return one on success or zero on error. The |assign|
- * functions adopt the caller's reference. The getters return a fresh reference
- * to the underlying object. */
+ * functions adopt the caller's reference. The |get1| functions return a fresh
+ * reference to the underlying object or NULL if |pkey| is not of the correct
+ * type. The |get0| functions behave the same but return a non-owning
+ * pointer. */
 
 OPENSSL_EXPORT int EVP_PKEY_set1_RSA(EVP_PKEY *pkey, RSA *key);
 OPENSSL_EXPORT int EVP_PKEY_assign_RSA(EVP_PKEY *pkey, RSA *key);
+OPENSSL_EXPORT RSA *EVP_PKEY_get0_RSA(EVP_PKEY *pkey);
 OPENSSL_EXPORT RSA *EVP_PKEY_get1_RSA(EVP_PKEY *pkey);
 
 OPENSSL_EXPORT int EVP_PKEY_set1_DSA(EVP_PKEY *pkey, DSA *key);
 OPENSSL_EXPORT int EVP_PKEY_assign_DSA(EVP_PKEY *pkey, DSA *key);
+OPENSSL_EXPORT DSA *EVP_PKEY_get0_DSA(EVP_PKEY *pkey);
 OPENSSL_EXPORT DSA *EVP_PKEY_get1_DSA(EVP_PKEY *pkey);
 
 OPENSSL_EXPORT int EVP_PKEY_set1_EC_KEY(EVP_PKEY *pkey, EC_KEY *key);
 OPENSSL_EXPORT int EVP_PKEY_assign_EC_KEY(EVP_PKEY *pkey, EC_KEY *key);
+OPENSSL_EXPORT EC_KEY *EVP_PKEY_get0_EC_KEY(EVP_PKEY *pkey);
 OPENSSL_EXPORT EC_KEY *EVP_PKEY_get1_EC_KEY(EVP_PKEY *pkey);
 
 OPENSSL_EXPORT int EVP_PKEY_set1_DH(EVP_PKEY *pkey, DH *key);
diff --git a/ssl/s3_clnt.c b/ssl/s3_clnt.c
index b881966..f5af366 100644
--- a/ssl/s3_clnt.c
+++ b/ssl/s3_clnt.c
@@ -1664,7 +1664,6 @@
 
     /* Depending on the key exchange method, compute |pms| and |pms_len|. */
     if (alg_k & SSL_kRSA) {
-      RSA *rsa;
       size_t enc_pms_len;
 
       pms_len = SSL_MAX_MASTER_KEY_LENGTH;
@@ -1675,16 +1674,18 @@
       }
 
       pkey = X509_get_pubkey(s->session->peer);
-      if (pkey == NULL ||
-          pkey->type != EVP_PKEY_RSA ||
-          pkey->pkey.rsa == NULL) {
+      if (pkey == NULL) {
+        goto err;
+      }
+
+      RSA *rsa = EVP_PKEY_get0_RSA(pkey);
+      if (rsa == NULL) {
         OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
         EVP_PKEY_free(pkey);
         goto err;
       }
 
       s->session->key_exchange_info = EVP_PKEY_bits(pkey);
-      rsa = pkey->pkey.rsa;
       EVP_PKEY_free(pkey);
 
       pms[0] = s->client_version >> 8;
@@ -2161,13 +2162,13 @@
   }
   ssl->rwstate = SSL_NOTHING;
 
-  if (EVP_PKEY_id(ssl->tlsext_channel_id_private) != EVP_PKEY_EC) {
+  EC_KEY *ec_key = EVP_PKEY_get0_EC_KEY(ssl->tlsext_channel_id_private);
+  if (ec_key == NULL) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return -1;
   }
 
   int ret = -1;
-  EC_KEY *ec_key = ssl->tlsext_channel_id_private->pkey.ec;
   BIGNUM *x = BN_new();
   BIGNUM *y = BN_new();
   ECDSA_SIG *sig = NULL;
diff --git a/ssl/s3_lib.c b/ssl/s3_lib.c
index 7bf223d..f6d400a 100644
--- a/ssl/s3_lib.c
+++ b/ssl/s3_lib.c
@@ -335,8 +335,9 @@
 }
 
 int SSL_set1_tls_channel_id(SSL *ssl, EVP_PKEY *private_key) {
-  if (EVP_PKEY_id(private_key) != EVP_PKEY_EC ||
-      EC_GROUP_get_curve_name(EC_KEY_get0_group(private_key->pkey.ec)) !=
+  EC_KEY *ec_key = EVP_PKEY_get0_EC_KEY(private_key);
+  if (ec_key == NULL ||
+      EC_GROUP_get_curve_name(EC_KEY_get0_group(ec_key)) !=
           NID_X9_62_prime256v1) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_CHANNEL_ID_NOT_P256);
     return 0;
diff --git a/ssl/ssl_rsa.c b/ssl/ssl_rsa.c
index b6ae370..990979b 100644
--- a/ssl/ssl_rsa.c
+++ b/ssl/ssl_rsa.c
@@ -401,20 +401,19 @@
                                           in_len);
   }
 
-  if (ssl_private_key_type(ssl) != EVP_PKEY_RSA) {
+  RSA *rsa = EVP_PKEY_get0_RSA(ssl->cert->privatekey);
+  if (rsa == NULL) {
     /* Decrypt operations are only supported for RSA keys. */
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return ssl_private_key_failure;
   }
 
-  enum ssl_private_key_result_t ret = ssl_private_key_failure;
-  RSA *rsa = ssl->cert->privatekey->pkey.rsa;
   /* Decrypt with no padding. PKCS#1 padding will be removed as part
    * of the timing-sensitive code by the caller. */
-  if (RSA_decrypt(rsa, out_len, out, max_out, in, in_len, RSA_NO_PADDING)) {
-    ret = ssl_private_key_success;
+  if (!RSA_decrypt(rsa, out_len, out, max_out, in, in_len, RSA_NO_PADDING)) {
+    return ssl_private_key_failure;
   }
-  return ret;
+  return ssl_private_key_success;
 }
 
 enum ssl_private_key_result_t ssl_private_key_decrypt_complete(
diff --git a/ssl/t1_lib.c b/ssl/t1_lib.c
index 2a3ba7f..c7e52a9 100644
--- a/ssl/t1_lib.c
+++ b/ssl/t1_lib.c
@@ -584,9 +584,12 @@
   uint16_t curve_id;
   uint8_t comp_id;
 
-  if (!pkey ||
-      pkey->type != EVP_PKEY_EC ||
-      !tls1_curve_params_from_ec_key(&curve_id, &comp_id, pkey->pkey.ec) ||
+  if (!pkey) {
+    goto done;
+  }
+  EC_KEY *ec_key = EVP_PKEY_get0_EC_KEY(pkey);
+  if (ec_key == NULL ||
+      !tls1_curve_params_from_ec_key(&curve_id, &comp_id, ec_key) ||
       !tls1_check_curve_id(s, curve_id) ||
       comp_id != TLSEXT_ECPOINTFORMAT_uncompressed) {
     goto done;
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index 22ce889..1321b2a 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -222,13 +222,12 @@
     abort();
   }
 
-  EVP_PKEY *pkey = test_state->private_key.get();
-  if (pkey->type != EVP_PKEY_RSA || pkey->pkey.rsa == NULL) {
+  RSA *rsa = EVP_PKEY_get0_RSA(test_state->private_key.get());
+  if (rsa == NULL) {
     fprintf(stderr,
             "AsyncPrivateKeyDecrypt called with incorrect key type.\n");
     abort();
   }
-  RSA *rsa = pkey->pkey.rsa;
   test_state->private_key_result.resize(RSA_size(rsa));
   if (!RSA_decrypt(rsa, out_len, test_state->private_key_result.data(),
                    RSA_size(rsa), in, in_len, RSA_NO_PADDING)) {