Even more fun with Span.

Change-Id: If9f9fdc209b97f955b1ef3dea052393412865e59
Reviewed-on: https://boringssl-review.googlesource.com/22464
Commit-Queue: Steven Valdez <svaldez@google.com>
Reviewed-by: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/include/openssl/tls1.h b/include/openssl/tls1.h
index 8eafe4c..0237e32 100644
--- a/include/openssl/tls1.h
+++ b/include/openssl/tls1.h
@@ -600,22 +600,6 @@
 #define TLS_CT_ECDSA_FIXED_ECDH 66
 
 #define TLS_MD_MAX_CONST_SIZE 20
-#define TLS_MD_CLIENT_FINISH_CONST "client finished"
-#define TLS_MD_CLIENT_FINISH_CONST_SIZE 15
-#define TLS_MD_SERVER_FINISH_CONST "server finished"
-#define TLS_MD_SERVER_FINISH_CONST_SIZE 15
-#define TLS_MD_KEY_EXPANSION_CONST "key expansion"
-#define TLS_MD_KEY_EXPANSION_CONST_SIZE 13
-#define TLS_MD_CLIENT_WRITE_KEY_CONST "client write key"
-#define TLS_MD_CLIENT_WRITE_KEY_CONST_SIZE 16
-#define TLS_MD_SERVER_WRITE_KEY_CONST "server write key"
-#define TLS_MD_SERVER_WRITE_KEY_CONST_SIZE 16
-#define TLS_MD_IV_BLOCK_CONST "IV block"
-#define TLS_MD_IV_BLOCK_CONST_SIZE 8
-#define TLS_MD_MASTER_SECRET_CONST "master secret"
-#define TLS_MD_MASTER_SECRET_CONST_SIZE 13
-#define TLS_MD_EXTENDED_MASTER_SECRET_CONST "extended master secret"
-#define TLS_MD_EXTENDED_MASTER_SECRET_CONST_SIZE 22
 
 
 #ifdef  __cplusplus
diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc
index ff8ebd8..48466fb 100644
--- a/ssl/handshake_client.cc
+++ b/ssl/handshake_client.cc
@@ -1360,8 +1360,8 @@
     return ssl_hs_error;
   }
 
-  hs->new_session->master_key_length = tls1_generate_master_secret(
-      hs, hs->new_session->master_key, pms.data(), pms.size());
+  hs->new_session->master_key_length =
+      tls1_generate_master_secret(hs, hs->new_session->master_key, pms);
   if (hs->new_session->master_key_length == 0) {
     return ssl_hs_error;
   }
diff --git a/ssl/handshake_server.cc b/ssl/handshake_server.cc
index d346875..bb565e9 100644
--- a/ssl/handshake_server.cc
+++ b/ssl/handshake_server.cc
@@ -1243,8 +1243,7 @@
 
   // Compute the master secret.
   hs->new_session->master_key_length = tls1_generate_master_secret(
-      hs, hs->new_session->master_key, premaster_secret.data(),
-      premaster_secret.size());
+      hs, hs->new_session->master_key, premaster_secret);
   if (hs->new_session->master_key_length == 0) {
     return ssl_hs_error;
   }
diff --git a/ssl/internal.h b/ssl/internal.h
index 7e1801a..5844b11 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -596,14 +596,12 @@
   ScopedEVP_MD_CTX md5_;
 };
 
-// tls1_prf computes the PRF function for |ssl|. It writes |out_len| bytes to
-// |out|, using |secret| as the secret and |label| as the label. |seed1| and
-// |seed2| are concatenated to form the seed parameter. It returns one on
-// success and zero on failure.
-int tls1_prf(const EVP_MD *digest, uint8_t *out, size_t out_len,
-             const uint8_t *secret, size_t secret_len, const char *label,
-             size_t label_len, const uint8_t *seed1, size_t seed1_len,
-             const uint8_t *seed2, size_t seed2_len);
+// tls1_prf computes the PRF function for |ssl|. It fills |out|, using |secret|
+// as the secret and |label| as the label. |seed1| and |seed2| are concatenated
+// to form the seed parameter. It returns true on success and false on failure.
+bool tls1_prf(const EVP_MD *digest, Span<uint8_t> out,
+              Span<const uint8_t> secret, Span<const char> label,
+              Span<const uint8_t> seed1, Span<const uint8_t> seed2);
 
 
 // Encryption layer.
@@ -2803,7 +2801,7 @@
 
 int tls1_change_cipher_state(SSL_HANDSHAKE *hs, evp_aead_direction_t direction);
 int tls1_generate_master_secret(SSL_HANDSHAKE *hs, uint8_t *out,
-                                const uint8_t *premaster, size_t premaster_len);
+                                Span<const uint8_t> premaster);
 
 // tls1_get_grouplist returns the locally-configured group preference list.
 Span<const uint16_t> tls1_get_grouplist(const SSL *ssl);
diff --git a/ssl/ssl_transcript.cc b/ssl/ssl_transcript.cc
index ab9822f..8486e5f 100644
--- a/ssl/ssl_transcript.cc
+++ b/ssl/ssl_transcript.cc
@@ -348,12 +348,11 @@
   // its own.
   assert(!buffer_);
 
-  const char *label = TLS_MD_CLIENT_FINISH_CONST;
-  size_t label_len = TLS_MD_SERVER_FINISH_CONST_SIZE;
-  if (from_server) {
-    label = TLS_MD_SERVER_FINISH_CONST;
-    label_len = TLS_MD_SERVER_FINISH_CONST_SIZE;
-  }
+  static const char kClientLabel[] = "client finished";
+  static const char kServerLabel[] = "server finished";
+  auto label = from_server
+                   ? MakeConstSpan(kServerLabel, sizeof(kServerLabel) - 1)
+                   : MakeConstSpan(kClientLabel, sizeof(kClientLabel) - 1);
 
   uint8_t digests[EVP_MAX_MD_SIZE];
   size_t digests_len;
@@ -362,9 +361,9 @@
   }
 
   static const size_t kFinishedLen = 12;
-  if (!tls1_prf(Digest(), out, kFinishedLen, session->master_key,
-                session->master_key_length, label, label_len, digests,
-                digests_len, NULL, 0)) {
+  if (!tls1_prf(Digest(), MakeSpan(out, kFinishedLen),
+                MakeConstSpan(session->master_key, session->master_key_length),
+                label, MakeConstSpan(digests, digests_len), {})) {
     return false;
   }
 
diff --git a/ssl/t1_enc.cc b/ssl/t1_enc.cc
index 1298a10..2a09987 100644
--- a/ssl/t1_enc.cc
+++ b/ssl/t1_enc.cc
@@ -155,26 +155,26 @@
 namespace bssl {
 
 // tls1_P_hash computes the TLS P_<hash> function as described in RFC 5246,
-// section 5. It XORs |out_len| bytes to |out|, using |md| as the hash and
-// |secret| as the secret. |seed1| through |seed3| are concatenated to form the
-// seed parameter. It returns one on success and zero on failure.
-static int tls1_P_hash(uint8_t *out, size_t out_len, const EVP_MD *md,
-                       const uint8_t *secret, size_t secret_len,
-                       const uint8_t *seed1, size_t seed1_len,
-                       const uint8_t *seed2, size_t seed2_len,
-                       const uint8_t *seed3, size_t seed3_len) {
+// section 5. It XORs |out.size()| bytes to |out|, using |md| as the hash and
+// |secret| as the secret. |label|, |seed1|, and |seed2| are concatenated to
+// form the seed parameter. It returns true on success and false on failure.
+static bool tls1_P_hash(Span<uint8_t> out, const EVP_MD *md,
+                        Span<const uint8_t> secret, Span<const char> label,
+                        Span<const uint8_t> seed1, Span<const uint8_t> seed2) {
   ScopedHMAC_CTX ctx, ctx_tmp, ctx_init;
   uint8_t A1[EVP_MAX_MD_SIZE];
   unsigned A1_len;
-  int ret = 0;
+  bool ret = false;
 
   size_t chunk = EVP_MD_size(md);
 
-  if (!HMAC_Init_ex(ctx_init.get(), secret, secret_len, md, NULL) ||
+  if (!HMAC_Init_ex(ctx_init.get(), secret.data(), secret.size(), md,
+                    nullptr) ||
       !HMAC_CTX_copy_ex(ctx.get(), ctx_init.get()) ||
-      !HMAC_Update(ctx.get(), seed1, seed1_len) ||
-      !HMAC_Update(ctx.get(), seed2, seed2_len) ||
-      !HMAC_Update(ctx.get(), seed3, seed3_len) ||
+      !HMAC_Update(ctx.get(), reinterpret_cast<const uint8_t *>(label.data()),
+                   label.size()) ||
+      !HMAC_Update(ctx.get(), seed1.data(), seed1.size()) ||
+      !HMAC_Update(ctx.get(), seed2.data(), seed2.size()) ||
       !HMAC_Final(ctx.get(), A1, &A1_len)) {
     goto err;
   }
@@ -185,27 +185,26 @@
     if (!HMAC_CTX_copy_ex(ctx.get(), ctx_init.get()) ||
         !HMAC_Update(ctx.get(), A1, A1_len) ||
         // Save a copy of |ctx| to compute the next A1 value below.
-        (out_len > chunk && !HMAC_CTX_copy_ex(ctx_tmp.get(), ctx.get())) ||
-        !HMAC_Update(ctx.get(), seed1, seed1_len) ||
-        !HMAC_Update(ctx.get(), seed2, seed2_len) ||
-        !HMAC_Update(ctx.get(), seed3, seed3_len) ||
+        (out.size() > chunk && !HMAC_CTX_copy_ex(ctx_tmp.get(), ctx.get())) ||
+        !HMAC_Update(ctx.get(), reinterpret_cast<const uint8_t *>(label.data()),
+                     label.size()) ||
+        !HMAC_Update(ctx.get(), seed1.data(), seed1.size()) ||
+        !HMAC_Update(ctx.get(), seed2.data(), seed2.size()) ||
         !HMAC_Final(ctx.get(), hmac, &len)) {
       goto err;
     }
     assert(len == chunk);
 
     // XOR the result into |out|.
-    if (len > out_len) {
-      len = out_len;
+    if (len > out.size()) {
+      len = out.size();
     }
-    unsigned i;
-    for (i = 0; i < len; i++) {
+    for (unsigned i = 0; i < len; i++) {
       out[i] ^= hmac[i];
     }
-    out += len;
-    out_len -= len;
+    out = out.subspan(len);
 
-    if (out_len == 0) {
+    if (out.empty()) {
       break;
     }
 
@@ -215,105 +214,86 @@
     }
   }
 
-  ret = 1;
+  ret = true;
 
 err:
   OPENSSL_cleanse(A1, sizeof(A1));
   return ret;
 }
 
-int tls1_prf(const EVP_MD *digest, uint8_t *out, size_t out_len,
-             const uint8_t *secret, size_t secret_len, const char *label,
-             size_t label_len, const uint8_t *seed1, size_t seed1_len,
-             const uint8_t *seed2, size_t seed2_len) {
-  if (out_len == 0) {
-    return 1;
+bool tls1_prf(const EVP_MD *digest, Span<uint8_t> out,
+              Span<const uint8_t> secret, Span<const char> label,
+              Span<const uint8_t> seed1, Span<const uint8_t> seed2) {
+  if (out.empty()) {
+    return true;
   }
 
-  OPENSSL_memset(out, 0, out_len);
+  OPENSSL_memset(out.data(), 0, out.size());
 
   if (digest == EVP_md5_sha1()) {
-    // If using the MD5/SHA1 PRF, |secret| is partitioned between SHA-1 and
-    // MD5, MD5 first.
-    size_t secret_half = secret_len - (secret_len / 2);
-    if (!tls1_P_hash(out, out_len, EVP_md5(), secret, secret_half,
-                     (const uint8_t *)label, label_len, seed1, seed1_len, seed2,
-                     seed2_len)) {
-      return 0;
+    // If using the MD5/SHA1 PRF, |secret| is partitioned between MD5 and SHA-1.
+    size_t secret_half = secret.size() - (secret.size() / 2);
+    if (!tls1_P_hash(out, EVP_md5(), secret.subspan(0, secret_half), label,
+                     seed1, seed2)) {
+      return false;
     }
 
-    // Note that, if |secret_len| is odd, the two halves share a byte.
-    secret = secret + (secret_len - secret_half);
-    secret_len = secret_half;
-
+    // Note that, if |secret.size()| is odd, the two halves share a byte.
+    secret = secret.subspan(secret.size() - secret_half);
     digest = EVP_sha1();
   }
 
-  if (!tls1_P_hash(out, out_len, digest, secret, secret_len,
-                   (const uint8_t *)label, label_len, seed1, seed1_len, seed2,
-                   seed2_len)) {
-    return 0;
-  }
-
-  return 1;
+  return tls1_P_hash(out, digest, secret, label, seed1, seed2);
 }
 
-static int ssl3_prf(uint8_t *out, size_t out_len, const uint8_t *secret,
-                    size_t secret_len, const char *label, size_t label_len,
-                    const uint8_t *seed1, size_t seed1_len,
-                    const uint8_t *seed2, size_t seed2_len) {
+static bool ssl3_prf(Span<uint8_t> out, Span<const uint8_t> secret,
+                     Span<const char> label, Span<const uint8_t> seed1,
+                     Span<const uint8_t> seed2) {
   ScopedEVP_MD_CTX md5;
   ScopedEVP_MD_CTX sha1;
   uint8_t buf[16], smd[SHA_DIGEST_LENGTH];
   uint8_t c = 'A';
-  size_t i, j, k;
-
-  k = 0;
-  for (i = 0; i < out_len; i += MD5_DIGEST_LENGTH) {
+  size_t k = 0;
+  while (!out.empty()) {
     k++;
     if (k > sizeof(buf)) {
       // bug: 'buf' is too small for this ciphersuite
       OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-      return 0;
+      return false;
     }
 
-    for (j = 0; j < k; j++) {
+    for (size_t j = 0; j < k; j++) {
       buf[j] = c;
     }
     c++;
     if (!EVP_DigestInit_ex(sha1.get(), EVP_sha1(), NULL)) {
       OPENSSL_PUT_ERROR(SSL, ERR_LIB_EVP);
-      return 0;
+      return false;
     }
     EVP_DigestUpdate(sha1.get(), buf, k);
-    EVP_DigestUpdate(sha1.get(), secret, secret_len);
+    EVP_DigestUpdate(sha1.get(), secret.data(), secret.size());
     // |label| is ignored for SSLv3.
-    if (seed1_len) {
-      EVP_DigestUpdate(sha1.get(), seed1, seed1_len);
-    }
-    if (seed2_len) {
-      EVP_DigestUpdate(sha1.get(), seed2, seed2_len);
-    }
+    EVP_DigestUpdate(sha1.get(), seed1.data(), seed1.size());
+    EVP_DigestUpdate(sha1.get(), seed2.data(), seed2.size());
     EVP_DigestFinal_ex(sha1.get(), smd, NULL);
 
     if (!EVP_DigestInit_ex(md5.get(), EVP_md5(), NULL)) {
       OPENSSL_PUT_ERROR(SSL, ERR_LIB_EVP);
-      return 0;
+      return false;
     }
-    EVP_DigestUpdate(md5.get(), secret, secret_len);
+    EVP_DigestUpdate(md5.get(), secret.data(), secret.size());
     EVP_DigestUpdate(md5.get(), smd, SHA_DIGEST_LENGTH);
-    if (i + MD5_DIGEST_LENGTH > out_len) {
+    if (out.size() < MD5_DIGEST_LENGTH) {
       EVP_DigestFinal_ex(md5.get(), smd, NULL);
-      OPENSSL_memcpy(out, smd, out_len - i);
-    } else {
-      EVP_DigestFinal_ex(md5.get(), out, NULL);
+      OPENSSL_memcpy(out.data(), smd, out.size());
+      break;
     }
-
-    out += MD5_DIGEST_LENGTH;
+    EVP_DigestFinal_ex(md5.get(), out.data(), NULL);
+    out = out.subspan(MD5_DIGEST_LENGTH);
   }
 
   OPENSSL_cleanse(smd, SHA_DIGEST_LENGTH);
-  return 1;
+  return true;
 }
 
 static bool get_key_block_lengths(const SSL *ssl, size_t *out_mac_secret_len,
@@ -405,33 +385,33 @@
 }
 
 int tls1_generate_master_secret(SSL_HANDSHAKE *hs, uint8_t *out,
-                                const uint8_t *premaster,
-                                size_t premaster_len) {
+                                Span<const uint8_t> premaster) {
+  static const char kMasterSecretLabel[] = "master secret";
+  static const char kExtendedMasterSecretLabel[] = "extended master secret";
+
   const SSL *ssl = hs->ssl;
+  auto out_span = MakeSpan(out, SSL3_MASTER_SECRET_SIZE);
   if (hs->extended_master_secret) {
+    auto label = MakeConstSpan(kExtendedMasterSecretLabel,
+                               sizeof(kExtendedMasterSecretLabel) - 1);
     uint8_t digests[EVP_MAX_MD_SIZE];
     size_t digests_len;
     if (!hs->transcript.GetHash(digests, &digests_len) ||
-        !tls1_prf(hs->transcript.Digest(), out, SSL3_MASTER_SECRET_SIZE,
-                  premaster, premaster_len, TLS_MD_EXTENDED_MASTER_SECRET_CONST,
-                  TLS_MD_EXTENDED_MASTER_SECRET_CONST_SIZE, digests,
-                  digests_len, NULL, 0)) {
+        !tls1_prf(hs->transcript.Digest(), out_span, premaster, label,
+                  MakeConstSpan(digests, digests_len), {})) {
       return 0;
     }
   } else {
+    auto label =
+        MakeConstSpan(kMasterSecretLabel, sizeof(kMasterSecretLabel) - 1);
     if (ssl_protocol_version(ssl) == SSL3_VERSION) {
-      if (!ssl3_prf(out, SSL3_MASTER_SECRET_SIZE, premaster, premaster_len,
-                    TLS_MD_MASTER_SECRET_CONST, TLS_MD_MASTER_SECRET_CONST_SIZE,
-                    ssl->s3->client_random, SSL3_RANDOM_SIZE,
-                    ssl->s3->server_random, SSL3_RANDOM_SIZE)) {
+      if (!ssl3_prf(out_span, premaster, label, ssl->s3->client_random,
+                    ssl->s3->server_random)) {
         return 0;
       }
     } else {
-      if (!tls1_prf(hs->transcript.Digest(), out, SSL3_MASTER_SECRET_SIZE,
-                    premaster, premaster_len, TLS_MD_MASTER_SECRET_CONST,
-                    TLS_MD_MASTER_SECRET_CONST_SIZE, ssl->s3->client_random,
-                    SSL3_RANDOM_SIZE, ssl->s3->server_random,
-                    SSL3_RANDOM_SIZE)) {
+      if (!tls1_prf(hs->transcript.Digest(), out_span, premaster, label,
+                    ssl->s3->client_random, ssl->s3->server_random)) {
         return 0;
       }
     }
@@ -457,18 +437,20 @@
 
 int SSL_generate_key_block(const SSL *ssl, uint8_t *out, size_t out_len) {
   const SSL_SESSION *session = SSL_get_session(ssl);
+  auto out_span = MakeSpan(out, out_len);
+  auto master_key =
+      MakeConstSpan(session->master_key, session->master_key_length);
+  static const char kLabel[] = "key expansion";
+  auto label = MakeConstSpan(kLabel, sizeof(kLabel) - 1);
+
   if (ssl_protocol_version(ssl) == SSL3_VERSION) {
-    return ssl3_prf(out, out_len, session->master_key,
-                    session->master_key_length, TLS_MD_KEY_EXPANSION_CONST,
-                    TLS_MD_KEY_EXPANSION_CONST_SIZE, ssl->s3->server_random,
-                    SSL3_RANDOM_SIZE, ssl->s3->client_random, SSL3_RANDOM_SIZE);
+    return ssl3_prf(out_span, master_key, label, ssl->s3->server_random,
+                    ssl->s3->client_random);
   }
 
   const EVP_MD *digest = ssl_session_get_digest(session);
-  return tls1_prf(digest, out, out_len, session->master_key,
-                  session->master_key_length, TLS_MD_KEY_EXPANSION_CONST,
-                  TLS_MD_KEY_EXPANSION_CONST_SIZE, ssl->s3->server_random,
-                  SSL3_RANDOM_SIZE, ssl->s3->client_random, SSL3_RANDOM_SIZE);
+  return tls1_prf(digest, out_span, master_key, label, ssl->s3->server_random,
+                  ssl->s3->client_random);
 }
 
 int SSL_export_keying_material(SSL *ssl, uint8_t *out, size_t out_len,
@@ -497,26 +479,25 @@
     }
     seed_len += 2 + context_len;
   }
-  uint8_t *seed = (uint8_t *)OPENSSL_malloc(seed_len);
-  if (seed == NULL) {
+  Array<uint8_t> seed;
+  if (!seed.Init(seed_len)) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
     return 0;
   }
 
-  OPENSSL_memcpy(seed, ssl->s3->client_random, SSL3_RANDOM_SIZE);
-  OPENSSL_memcpy(seed + SSL3_RANDOM_SIZE, ssl->s3->server_random,
+  OPENSSL_memcpy(seed.data(), ssl->s3->client_random, SSL3_RANDOM_SIZE);
+  OPENSSL_memcpy(seed.data() + SSL3_RANDOM_SIZE, ssl->s3->server_random,
                  SSL3_RANDOM_SIZE);
   if (use_context) {
-    seed[2 * SSL3_RANDOM_SIZE] = (uint8_t)(context_len >> 8);
-    seed[2 * SSL3_RANDOM_SIZE + 1] = (uint8_t)context_len;
-    OPENSSL_memcpy(seed + 2 * SSL3_RANDOM_SIZE + 2, context, context_len);
+    seed[2 * SSL3_RANDOM_SIZE] = static_cast<uint8_t>(context_len >> 8);
+    seed[2 * SSL3_RANDOM_SIZE + 1] = static_cast<uint8_t>(context_len);
+    OPENSSL_memcpy(seed.data() + 2 * SSL3_RANDOM_SIZE + 2, context, context_len);
   }
 
   const SSL_SESSION *session = SSL_get_session(ssl);
   const EVP_MD *digest = ssl_session_get_digest(session);
-  int ret = tls1_prf(digest, out, out_len, session->master_key,
-                     session->master_key_length, label, label_len, seed,
-                     seed_len, NULL, 0);
-  OPENSSL_free(seed);
-  return ret;
+  return tls1_prf(
+      digest, MakeSpan(out, out_len),
+      MakeConstSpan(session->master_key, session->master_key_length),
+      MakeConstSpan(label, label_len), seed, {});
 }