Simplify ssl_private_key_* state machine points.

The original motivation behind the sign/complete split was to avoid
needlessly hashing the input on each pass through the state machine, but
we're payload-based now and, in all cases, the payload is either cheap
to compute or readily available. (Even the hashing worry was probably
unnecessary.)

Tweak ssl_private_key_{sign,decrypt} to automatically call
ssl_private_key_complete as needed and take advantage of this in the
handshake state machines:

- TLS 1.3 signing now recomputes the payload on each pass. The payload
  is small and we're already allocating a comparable-sized buffer on
  each iteration to hold the signature, so this shouldn't be a big deal.

- The TLS 1.2 decryption code still needs two states because it reads
  the message (fixed in the new state machine style), but otherwise it
  just repeats cheap idempotent tasks on re-entry. The PSK code is
  reshuffled to guarantee the callback is not called twice (though this
  was impossible anyway because we don't support RSA_PSK).

- TLS 1.2 CertificateVerify signing is easy as the transcript is readily
  available. The buffer is released very slightly later, but it
  shouldn't matter.

- TLS 1.2 ServerKeyExchange signing required some reshuffling.
  Assembling the ServerKeyExchange parameters is moved to the previous
  state. The signing payload has some randoms prepended. This is cheap
  enough, but a nuisance in C. Pre-prepend the randoms in
  hs->server_params.

With this change, we are *nearly* rid of the pattern where a pair of
A/B states dispatches to the same function.

BUG=128

Change-Id: Iec4fe0be7cfc88a6de027ba2760fae70794ea810
Reviewed-on: https://boringssl-review.googlesource.com/17265
Commit-Queue: David Benjamin <davidben@google.com>
Commit-Queue: Steven Valdez <svaldez@google.com>
Reviewed-by: Steven Valdez <svaldez@google.com>
diff --git a/ssl/handshake_server.c b/ssl/handshake_server.c
index 872347d..615302a 100644
--- a/ssl/handshake_server.c
+++ b/ssl/handshake_server.c
@@ -186,7 +186,6 @@
 
 int ssl3_accept(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  uint32_t alg_a;
   int ret = -1;
 
   assert(ssl->handshake_func == ssl3_accept);
@@ -257,12 +256,7 @@
         break;
 
       case SSL3_ST_SW_KEY_EXCH_A:
-      case SSL3_ST_SW_KEY_EXCH_B:
-        alg_a = hs->new_cipher->algorithm_auth;
-
-        /* PSK ciphers send ServerKeyExchange if there is an identity hint. */
-        if (ssl_cipher_requires_server_key_exchange(hs->new_cipher) ||
-            ((alg_a & SSL_aPSK) && ssl->psk_identity_hint)) {
+        if (hs->server_params_len > 0) {
           ret = ssl3_send_server_key_exchange(hs);
           if (ret <= 0) {
             goto end;
@@ -1030,50 +1024,48 @@
 
 static int ssl3_send_server_certificate(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  if (!ssl_cipher_uses_certificate_auth(hs->new_cipher)) {
-    return 1;
-  }
+  int ret = -1;
+  CBB cbb;
+  CBB_zero(&cbb);
 
-  if (!ssl_has_certificate(ssl)) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_NO_CERTIFICATE_SET);
-    return -1;
-  }
+  if (ssl_cipher_uses_certificate_auth(hs->new_cipher)) {
+    if (!ssl_has_certificate(ssl)) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_NO_CERTIFICATE_SET);
+      goto err;
+    }
 
-  if (!ssl3_output_cert_chain(ssl)) {
-    return -1;
-  }
+    if (!ssl3_output_cert_chain(ssl)) {
+      goto err;
+    }
 
-  if (hs->certificate_status_expected) {
-    CBB cbb, body, ocsp_response;
-    if (!ssl->method->init_message(ssl, &cbb, &body,
-                                   SSL3_MT_CERTIFICATE_STATUS) ||
-        !CBB_add_u8(&body, TLSEXT_STATUSTYPE_ocsp) ||
-        !CBB_add_u24_length_prefixed(&body, &ocsp_response) ||
-        !CBB_add_bytes(&ocsp_response,
-                       CRYPTO_BUFFER_data(ssl->cert->ocsp_response),
-                       CRYPTO_BUFFER_len(ssl->cert->ocsp_response)) ||
-        !ssl_add_message_cbb(ssl, &cbb)) {
-      OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-      CBB_cleanup(&cbb);
-      return -1;
+    if (hs->certificate_status_expected) {
+      CBB body, ocsp_response;
+      if (!ssl->method->init_message(ssl, &cbb, &body,
+                                     SSL3_MT_CERTIFICATE_STATUS) ||
+          !CBB_add_u8(&body, TLSEXT_STATUSTYPE_ocsp) ||
+          !CBB_add_u24_length_prefixed(&body, &ocsp_response) ||
+          !CBB_add_bytes(&ocsp_response,
+                         CRYPTO_BUFFER_data(ssl->cert->ocsp_response),
+                         CRYPTO_BUFFER_len(ssl->cert->ocsp_response)) ||
+          !ssl_add_message_cbb(ssl, &cbb)) {
+        OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+        goto err;
+      }
     }
   }
 
-  return 1;
-}
+  /* Assemble ServerKeyExchange parameters if needed. */
+  uint32_t alg_k = hs->new_cipher->algorithm_mkey;
+  uint32_t alg_a = hs->new_cipher->algorithm_auth;
+  if (ssl_cipher_requires_server_key_exchange(hs->new_cipher) ||
+      ((alg_a & SSL_aPSK) && ssl->psk_identity_hint)) {
 
-static int ssl3_send_server_key_exchange(SSL_HANDSHAKE *hs) {
-  SSL *const ssl = hs->ssl;
-  CBB cbb, child;
-  CBB_zero(&cbb);
-
-  /* Put together the parameters. */
-  if (hs->state == SSL3_ST_SW_KEY_EXCH_A) {
-    uint32_t alg_k = hs->new_cipher->algorithm_mkey;
-    uint32_t alg_a = hs->new_cipher->algorithm_auth;
-
-    /* Pre-allocate enough room to comfortably fit an ECDHE public key. */
-    if (!CBB_init(&cbb, 128)) {
+    /* Pre-allocate enough room to comfortably fit an ECDHE public key. Prepend
+     * the client and server randoms for the signing transcript. */
+    CBB child;
+    if (!CBB_init(&cbb, SSL3_RANDOM_SIZE * 2 + 128) ||
+        !CBB_add_bytes(&cbb, ssl->s3->client_random, SSL3_RANDOM_SIZE) ||
+        !CBB_add_bytes(&cbb, ssl->s3->server_random, SSL3_RANDOM_SIZE)) {
       goto err;
     }
 
@@ -1115,11 +1107,22 @@
     }
   }
 
-  /* Assemble the message. */
-  CBB body;
+  ret = 1;
+
+err:
+  CBB_cleanup(&cbb);
+  return ret;
+}
+
+static int ssl3_send_server_key_exchange(SSL_HANDSHAKE *hs) {
+  SSL *const ssl = hs->ssl;
+  CBB cbb, body, child;
   if (!ssl->method->init_message(ssl, &cbb, &body,
                                  SSL3_MT_SERVER_KEY_EXCHANGE) ||
-      !CBB_add_bytes(&body, hs->server_params, hs->server_params_len)) {
+      /* |hs->server_params| contains a prefix for signing. */
+      hs->server_params_len < 2 * SSL3_RANDOM_SIZE ||
+      !CBB_add_bytes(&body, hs->server_params + 2 * SSL3_RANDOM_SIZE,
+                     hs->server_params_len - 2 * SSL3_RANDOM_SIZE)) {
     goto err;
   }
 
@@ -1152,36 +1155,9 @@
     }
 
     size_t sig_len;
-    enum ssl_private_key_result_t sign_result;
-    if (hs->state == SSL3_ST_SW_KEY_EXCH_A) {
-      CBB transcript;
-      uint8_t *transcript_data;
-      size_t transcript_len;
-      if (!CBB_init(&transcript,
-                    2 * SSL3_RANDOM_SIZE + hs->server_params_len) ||
-          !CBB_add_bytes(&transcript, ssl->s3->client_random,
-                         SSL3_RANDOM_SIZE) ||
-          !CBB_add_bytes(&transcript, ssl->s3->server_random,
-                         SSL3_RANDOM_SIZE) ||
-          !CBB_add_bytes(&transcript, hs->server_params,
-                         hs->server_params_len) ||
-          !CBB_finish(&transcript, &transcript_data, &transcript_len)) {
-        CBB_cleanup(&transcript);
-        OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-        ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
-        goto err;
-      }
-
-      sign_result = ssl_private_key_sign(ssl, ptr, &sig_len, max_sig_len,
-                                         signature_algorithm, transcript_data,
-                                         transcript_len);
-      OPENSSL_free(transcript_data);
-    } else {
-      assert(hs->state == SSL3_ST_SW_KEY_EXCH_B);
-      sign_result = ssl_private_key_complete(ssl, ptr, &sig_len, max_sig_len);
-    }
-
-    switch (sign_result) {
+    switch (ssl_private_key_sign(hs, ptr, &sig_len, max_sig_len,
+                                 signature_algorithm, hs->server_params,
+                                 hs->server_params_len)) {
       case ssl_private_key_success:
         if (!CBB_did_write(&child, sig_len)) {
           goto err;
@@ -1191,7 +1167,6 @@
         goto err;
       case ssl_private_key_retry:
         ssl->rwstate = SSL_PRIVATE_KEY_OPERATION;
-        hs->state = SSL3_ST_SW_KEY_EXCH_B;
         goto err;
     }
   }
@@ -1345,32 +1320,26 @@
 static int ssl3_get_client_key_exchange(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
   CBS client_key_exchange;
-  uint32_t alg_k;
-  uint32_t alg_a;
   uint8_t *premaster_secret = NULL;
   size_t premaster_secret_len = 0;
   uint8_t *decrypt_buf = NULL;
 
-  unsigned psk_len = 0;
-  uint8_t psk[PSK_MAX_PSK_LEN];
-
   if (hs->state == SSL3_ST_SR_KEY_EXCH_A) {
     int ret = ssl->method->ssl_get_message(ssl);
     if (ret <= 0) {
       return ret;
     }
+  }
 
-    if (!ssl_check_message_type(ssl, SSL3_MT_CLIENT_KEY_EXCHANGE) ||
-        !ssl_hash_current_message(hs)) {
-      return -1;
-    }
+  if (!ssl_check_message_type(ssl, SSL3_MT_CLIENT_KEY_EXCHANGE)) {
+    return -1;
   }
 
   CBS_init(&client_key_exchange, ssl->init_msg, ssl->init_num);
-  alg_k = hs->new_cipher->algorithm_mkey;
-  alg_a = hs->new_cipher->algorithm_auth;
+  uint32_t alg_k = hs->new_cipher->algorithm_mkey;
+  uint32_t alg_a = hs->new_cipher->algorithm_auth;
 
-  /* If using a PSK key exchange, prepare the pre-shared key. */
+  /* If using a PSK key exchange, parse the PSK identity. */
   if (alg_a & SSL_aPSK) {
     CBS psk_identity;
 
@@ -1383,12 +1352,6 @@
       goto err;
     }
 
-    if (ssl->psk_server_callback == NULL) {
-      OPENSSL_PUT_ERROR(SSL, SSL_R_PSK_NO_SERVER_CB);
-      ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
-      goto err;
-    }
-
     if (CBS_len(&psk_identity) > PSK_MAX_IDENTITY_LEN ||
         CBS_contains_zero_byte(&psk_identity)) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_DATA_LENGTH_TOO_LONG);
@@ -1401,25 +1364,24 @@
       ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
       goto err;
     }
-
-    /* Look up the key for the identity. */
-    psk_len = ssl->psk_server_callback(ssl, hs->new_session->psk_identity, psk,
-                                       sizeof(psk));
-    if (psk_len > PSK_MAX_PSK_LEN) {
-      OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-      ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
-      goto err;
-    } else if (psk_len == 0) {
-      /* PSK related to the given identity not found */
-      OPENSSL_PUT_ERROR(SSL, SSL_R_PSK_IDENTITY_NOT_FOUND);
-      ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNKNOWN_PSK_IDENTITY);
-      goto err;
-    }
   }
 
   /* Depending on the key exchange method, compute |premaster_secret| and
    * |premaster_secret_len|. */
   if (alg_k & SSL_kRSA) {
+    CBS encrypted_premaster_secret;
+    if (ssl->version > SSL3_VERSION) {
+      if (!CBS_get_u16_length_prefixed(&client_key_exchange,
+                                       &encrypted_premaster_secret) ||
+          CBS_len(&client_key_exchange) != 0) {
+        OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
+        ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
+        goto err;
+      }
+    } else {
+      encrypted_premaster_secret = client_key_exchange;
+    }
+
     /* Allocate a buffer large enough for an RSA decryption. */
     const size_t rsa_size = EVP_PKEY_size(hs->local_pubkey);
     decrypt_buf = OPENSSL_malloc(rsa_size);
@@ -1428,43 +1390,12 @@
       goto err;
     }
 
-    enum ssl_private_key_result_t decrypt_result;
+    /* Decrypt with no padding. PKCS#1 padding will be removed as part of the
+     * timing-sensitive code below. */
     size_t decrypt_len;
-    if (hs->state == SSL3_ST_SR_KEY_EXCH_A) {
-      if (!ssl_has_private_key(ssl) ||
-          EVP_PKEY_id(hs->local_pubkey) != EVP_PKEY_RSA) {
-        OPENSSL_PUT_ERROR(SSL, SSL_R_MISSING_RSA_CERTIFICATE);
-        ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
-        goto err;
-      }
-      CBS encrypted_premaster_secret;
-      if (ssl->version > SSL3_VERSION) {
-        if (!CBS_get_u16_length_prefixed(&client_key_exchange,
-                                         &encrypted_premaster_secret) ||
-            CBS_len(&client_key_exchange) != 0) {
-          OPENSSL_PUT_ERROR(SSL,
-                            SSL_R_TLS_RSA_ENCRYPTED_VALUE_LENGTH_IS_WRONG);
-          ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
-          goto err;
-        }
-      } else {
-        encrypted_premaster_secret = client_key_exchange;
-      }
-
-      /* Decrypt with no padding. PKCS#1 padding will be removed as part of the
-       * timing-sensitive code below. */
-      decrypt_result = ssl_private_key_decrypt(
-          ssl, decrypt_buf, &decrypt_len, rsa_size,
-          CBS_data(&encrypted_premaster_secret),
-          CBS_len(&encrypted_premaster_secret));
-    } else {
-      assert(hs->state == SSL3_ST_SR_KEY_EXCH_B);
-      /* Complete async decrypt. */
-      decrypt_result =
-          ssl_private_key_complete(ssl, decrypt_buf, &decrypt_len, rsa_size);
-    }
-
-    switch (decrypt_result) {
+    switch (ssl_private_key_decrypt(hs, decrypt_buf, &decrypt_len, rsa_size,
+                                    CBS_data(&encrypted_premaster_secret),
+                                    CBS_len(&encrypted_premaster_secret))) {
       case ssl_private_key_success:
         break;
       case ssl_private_key_failure:
@@ -1547,17 +1478,7 @@
 
     /* The key exchange state may now be discarded. */
     SSL_ECDH_CTX_cleanup(&hs->ecdh_ctx);
-  } else if (alg_k & SSL_kPSK) {
-    /* For plain PSK, other_secret is a block of 0s with the same length as the
-     * pre-shared key. */
-    premaster_secret_len = psk_len;
-    premaster_secret = OPENSSL_malloc(premaster_secret_len);
-    if (premaster_secret == NULL) {
-      OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
-      goto err;
-    }
-    OPENSSL_memset(premaster_secret, 0, premaster_secret_len);
-  } else {
+  } else if (!(alg_k & SSL_kPSK)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_UNKNOWN_CIPHER_TYPE);
     ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
     goto err;
@@ -1566,10 +1487,42 @@
   /* For a PSK cipher suite, the actual pre-master secret is combined with the
    * pre-shared key. */
   if (alg_a & SSL_aPSK) {
+    if (ssl->psk_server_callback == NULL) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_PSK_NO_SERVER_CB);
+      ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
+      goto err;
+    }
+
+    /* Look up the key for the identity. */
+    uint8_t psk[PSK_MAX_PSK_LEN];
+    unsigned psk_len = ssl->psk_server_callback(
+        ssl, hs->new_session->psk_identity, psk, sizeof(psk));
+    if (psk_len > PSK_MAX_PSK_LEN) {
+      OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+      ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
+      goto err;
+    } else if (psk_len == 0) {
+      /* PSK related to the given identity not found */
+      OPENSSL_PUT_ERROR(SSL, SSL_R_PSK_IDENTITY_NOT_FOUND);
+      ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNKNOWN_PSK_IDENTITY);
+      goto err;
+    }
+
+    if (alg_k & SSL_kPSK) {
+      /* In plain PSK, other_secret is a block of 0s with the same length as the
+       * pre-shared key. */
+      premaster_secret_len = psk_len;
+      premaster_secret = OPENSSL_malloc(premaster_secret_len);
+      if (premaster_secret == NULL) {
+        OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+        goto err;
+      }
+      OPENSSL_memset(premaster_secret, 0, premaster_secret_len);
+    }
+
     CBB new_premaster, child;
     uint8_t *new_data;
     size_t new_len;
-
     CBB_zero(&new_premaster);
     if (!CBB_init(&new_premaster, 2 + psk_len + 2 + premaster_secret_len) ||
         !CBB_add_u16_length_prefixed(&new_premaster, &child) ||
@@ -1588,6 +1541,10 @@
     premaster_secret_len = new_len;
   }
 
+  if (!ssl_hash_current_message(hs)) {
+    goto err;
+  }
+
   /* Compute the master secret */
   hs->new_session->master_key_length = tls1_generate_master_secret(
       hs, hs->new_session->master_key, premaster_secret, premaster_secret_len);