Factor SSLv3 key derivation steps into an ssl3_PRF.

Fix up the generate_master_secret parameter while we're here.

Change-Id: I1c80796d1f481be0c3eefcf3222f2d9fc1de4a51
Reviewed-on: https://boringssl-review.googlesource.com/2696
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 201f1e4..456481b 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -2435,6 +2435,7 @@
 #define SSL_F_ssl3_get_v2_client_hello 295
 #define SSL_F_ssl3_get_initial_bytes 296
 #define SSL_F_tls1_enc 297
+#define SSL_F_ssl3_PRF 298
 #define SSL_R_UNABLE_TO_FIND_ECDH_PARAMETERS 100
 #define SSL_R_DECRYPTION_FAILED_OR_BAD_RECORD_MAC 101
 #define SSL_R_INVALID_NULL_CMD_NAME 102
diff --git a/ssl/s3_enc.c b/ssl/s3_enc.c
index 9209278..68684b3 100644
--- a/ssl/s3_enc.c
+++ b/ssl/s3_enc.c
@@ -162,21 +162,24 @@
 static int ssl3_handshake_mac(SSL *s, int md_nid, const char *sender, int len,
                               uint8_t *p);
 
-static int ssl3_generate_key_block(SSL *s, uint8_t *km, int num) {
+static int ssl3_PRF(uint8_t *out, size_t out_len,
+                    const uint8_t *secret, size_t secret_len,
+                    const uint8_t *seed1, size_t seed1_len,
+                    const uint8_t *seed2, size_t seed2_len) {
   EVP_MD_CTX md5;
   EVP_MD_CTX sha1;
   uint8_t buf[16], smd[SHA_DIGEST_LENGTH];
   uint8_t c = 'A';
-  unsigned int i, j, k;
+  size_t i, j, k;
 
   k = 0;
   EVP_MD_CTX_init(&md5);
   EVP_MD_CTX_init(&sha1);
-  for (i = 0; (int)i < num; i += MD5_DIGEST_LENGTH) {
+  for (i = 0; i < out_len; i += MD5_DIGEST_LENGTH) {
     k++;
     if (k > sizeof(buf)) {
       /* bug: 'buf' is too small for this ciphersuite */
-      OPENSSL_PUT_ERROR(SSL, ssl3_generate_key_block, ERR_R_INTERNAL_ERROR);
+      OPENSSL_PUT_ERROR(SSL, ssl3_PRF, ERR_R_INTERNAL_ERROR);
       return 0;
     }
 
@@ -185,31 +188,33 @@
     }
     c++;
     if (!EVP_DigestInit_ex(&sha1, EVP_sha1(), NULL)) {
-      OPENSSL_PUT_ERROR(SSL, ssl3_generate_key_block, ERR_LIB_EVP);
+      OPENSSL_PUT_ERROR(SSL, ssl3_PRF, ERR_LIB_EVP);
       return 0;
     }
     EVP_DigestUpdate(&sha1, buf, k);
-    EVP_DigestUpdate(&sha1, s->session->master_key,
-                     s->session->master_key_length);
-    EVP_DigestUpdate(&sha1, s->s3->server_random, SSL3_RANDOM_SIZE);
-    EVP_DigestUpdate(&sha1, s->s3->client_random, SSL3_RANDOM_SIZE);
+    EVP_DigestUpdate(&sha1, secret, secret_len);
+    if (seed1_len) {
+      EVP_DigestUpdate(&sha1, seed1, seed1_len);
+    }
+    if (seed2_len) {
+      EVP_DigestUpdate(&sha1, seed2, seed2_len);
+    }
     EVP_DigestFinal_ex(&sha1, smd, NULL);
 
     if (!EVP_DigestInit_ex(&md5, EVP_md5(), NULL)) {
-      OPENSSL_PUT_ERROR(SSL, ssl3_generate_key_block, ERR_LIB_EVP);
+      OPENSSL_PUT_ERROR(SSL, ssl3_PRF, ERR_LIB_EVP);
       return 0;
     }
-    EVP_DigestUpdate(&md5, s->session->master_key,
-                     s->session->master_key_length);
+    EVP_DigestUpdate(&md5, secret, secret_len);
     EVP_DigestUpdate(&md5, smd, SHA_DIGEST_LENGTH);
-    if ((int)(i + MD5_DIGEST_LENGTH) > num) {
+    if (i + MD5_DIGEST_LENGTH > out_len) {
       EVP_DigestFinal_ex(&md5, smd, NULL);
-      memcpy(km, smd, (num - i));
+      memcpy(out, smd, out_len - i);
     } else {
-      EVP_DigestFinal_ex(&md5, km, NULL);
+      EVP_DigestFinal_ex(&md5, out, NULL);
     }
 
-    km += MD5_DIGEST_LENGTH;
+    out += MD5_DIGEST_LENGTH;
   }
 
   OPENSSL_cleanse(smd, SHA_DIGEST_LENGTH);
@@ -219,6 +224,12 @@
   return 1;
 }
 
+static int ssl3_generate_key_block(SSL *s, uint8_t *out, size_t out_len) {
+  return ssl3_PRF(out, out_len, s->session->master_key,
+                  s->session->master_key_length, s->s3->server_random,
+                  SSL3_RANDOM_SIZE, s->s3->client_random, SSL3_RANDOM_SIZE);
+}
+
 int ssl3_change_cipher_state(SSL *s, int which) {
   uint8_t *p, *mac_secret;
   uint8_t exp_key[EVP_MAX_KEY_LENGTH];
@@ -739,45 +750,15 @@
   }
 }
 
-int ssl3_generate_master_secret(SSL *s, uint8_t *out, uint8_t *p, int len) {
-  uint8_t buf[EVP_MAX_MD_SIZE];
-  EVP_MD_CTX ctx;
-  int i, ret = 0;
-  unsigned int n;
-
-  EVP_MD_CTX_init(&ctx);
-  for (i = 0; i < 3; i++) {
-    if (!EVP_DigestInit_ex(&ctx, EVP_sha1(), NULL)) {
-      ret = 0;
-      break;
-    }
-
-    if (i == 0) {
-      EVP_DigestUpdate(&ctx, (const uint8_t*) "A", 1);
-    } else if (i == 1) {
-      EVP_DigestUpdate(&ctx, (const uint8_t*) "BB", 2);
-    } else {
-      EVP_DigestUpdate(&ctx, (const uint8_t*) "CCC", 3);
-    }
-    EVP_DigestUpdate(&ctx, p, len);
-    EVP_DigestUpdate(&ctx, &s->s3->client_random[0], SSL3_RANDOM_SIZE);
-    EVP_DigestUpdate(&ctx, &s->s3->server_random[0], SSL3_RANDOM_SIZE);
-    EVP_DigestFinal_ex(&ctx, buf, &n);
-
-    if (!EVP_DigestInit_ex(&ctx, EVP_md5(), NULL)) {
-      ret = 0;
-      break;
-    }
-
-    EVP_DigestUpdate(&ctx, p, len);
-    EVP_DigestUpdate(&ctx, buf, n);
-    EVP_DigestFinal_ex(&ctx, out, &n);
-    out += n;
-    ret += n;
+int ssl3_generate_master_secret(SSL *s, uint8_t *out, const uint8_t *premaster,
+                                size_t premaster_len) {
+  if (!ssl3_PRF(out, SSL3_MASTER_SECRET_SIZE, premaster, premaster_len,
+                s->s3->client_random, SSL3_RANDOM_SIZE, s->s3->server_random,
+                SSL3_RANDOM_SIZE)) {
+    return 0;
   }
-  EVP_MD_CTX_cleanup(&ctx);
 
-  return ret;
+  return SSL3_MASTER_SECRET_SIZE;
 }
 
 int ssl3_alert_code(int code) {
diff --git a/ssl/ssl_error.c b/ssl/ssl_error.c
index a0a5ac7..691a30e 100644
--- a/ssl/ssl_error.c
+++ b/ssl/ssl_error.c
@@ -105,6 +105,7 @@
   {ERR_PACK(ERR_LIB_SSL, SSL_F_ssl23_peek, 0), "ssl23_peek"},
   {ERR_PACK(ERR_LIB_SSL, SSL_F_ssl23_read, 0), "ssl23_read"},
   {ERR_PACK(ERR_LIB_SSL, SSL_F_ssl23_write, 0), "ssl23_write"},
+  {ERR_PACK(ERR_LIB_SSL, SSL_F_ssl3_PRF, 0), "ssl3_PRF"},
   {ERR_PACK(ERR_LIB_SSL, SSL_F_ssl3_accept, 0), "ssl3_accept"},
   {ERR_PACK(ERR_LIB_SSL, SSL_F_ssl3_callback_ctrl, 0), "ssl3_callback_ctrl"},
   {ERR_PACK(ERR_LIB_SSL, SSL_F_ssl3_cert_verify_hash, 0), "ssl3_cert_verify_hash"},
diff --git a/ssl/ssl_locl.h b/ssl/ssl_locl.h
index 46683a5e..0264a87 100644
--- a/ssl/ssl_locl.h
+++ b/ssl/ssl_locl.h
@@ -579,7 +579,7 @@
   int (*enc)(SSL *, int);
   int (*mac)(SSL *, uint8_t *, int);
   int (*setup_key_block)(SSL *);
-  int (*generate_master_secret)(SSL *, uint8_t *, uint8_t *, int);
+  int (*generate_master_secret)(SSL *, uint8_t *, const uint8_t *, size_t);
   int (*change_cipher_state)(SSL *, int);
   int (*final_finish_mac)(SSL *, const char *, int, uint8_t *);
   int finish_mac_length;
@@ -741,7 +741,8 @@
 void ssl3_cleanup_key_block(SSL *s);
 int ssl3_do_write(SSL *s, int type);
 int ssl3_send_alert(SSL *s, int level, int desc);
-int ssl3_generate_master_secret(SSL *s, uint8_t *out, uint8_t *p, int len);
+int ssl3_generate_master_secret(SSL *s, uint8_t *out,
+                                const uint8_t *premaster, size_t premaster_len);
 int ssl3_get_req_cert_type(SSL *s, uint8_t *p);
 long ssl3_get_message(SSL *s, int header_state, int body_state, int msg_type,
                       long max, int hash_message, int *ok);
@@ -900,8 +901,8 @@
 int tls1_final_finish_mac(SSL *s, const char *str, int slen, uint8_t *p);
 int tls1_cert_verify_mac(SSL *s, int md_nid, uint8_t *p);
 int tls1_mac(SSL *ssl, uint8_t *md, int snd);
-int tls1_generate_master_secret(SSL *s, uint8_t *out, uint8_t *premaster,
-                                int premaster_len);
+int tls1_generate_master_secret(SSL *s, uint8_t *out, const uint8_t *premaster,
+                                size_t premaster_len);
 int tls1_export_keying_material(SSL *s, uint8_t *out, size_t olen,
                                 const char *label, size_t llen,
                                 const uint8_t *p, size_t plen, int use_context);
diff --git a/ssl/t1_enc.c b/ssl/t1_enc.c
index b24781a..7385e7f 100644
--- a/ssl/t1_enc.c
+++ b/ssl/t1_enc.c
@@ -1142,8 +1142,8 @@
   return md_size;
 }
 
-int tls1_generate_master_secret(SSL *s, uint8_t *out, uint8_t *premaster,
-                                int premaster_len) {
+int tls1_generate_master_secret(SSL *s, uint8_t *out, const uint8_t *premaster,
+                                size_t premaster_len) {
   if (s->s3->tmp.extended_master_secret) {
     uint8_t digests[2 * EVP_MAX_MD_SIZE];
     int digests_len;