Add SSL_SESSION_{get,set}_protocol_version.

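SSL_SESSION_get_protocol_version matches upstream's name for this getter.
SSL_SESSION_set_protocol_version is useful when unit-testing a session
cache. The internal SSL_SESSION_protocol_version and SSL_SESSION_get_digest
helpers are renamed to ssl_session_protocol_version and
ssl_session_get_digest so they are not confused with the public API.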

Change-Id: I4b04e31d61ce40739323248e3e5fdae498c4645e
Reviewed-on: https://boringssl-review.googlesource.com/21044
Commit-Queue: Steven Valdez <svaldez@google.com>
Reviewed-by: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/internal.h b/ssl/internal.h
index 9a62b46..4247425 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -2229,12 +2229,13 @@
 int ssl_session_is_resumable(const SSL_HANDSHAKE *hs,
                              const SSL_SESSION *session);
 
-// SSL_SESSION_protocol_version returns the protocol version associated with
-// |session|.
-uint16_t SSL_SESSION_protocol_version(const SSL_SESSION *session);
+// ssl_session_protocol_version returns the protocol version associated with
+// |session|, as computed by |ssl_protocol_version_from_wire|. Despite the
+// name, it differs from |SSL_SESSION_get_protocol_version|, an upstream API.
+uint16_t ssl_session_protocol_version(const SSL_SESSION *session);
 
-// SSL_SESSION_get_digest returns the digest used in |session|.
-const EVP_MD *SSL_SESSION_get_digest(const SSL_SESSION *session);
+// ssl_session_get_digest returns the digest used in |session|.
+const EVP_MD *ssl_session_get_digest(const SSL_SESSION *session);
 
 void ssl_set_session(SSL *ssl, SSL_SESSION *session);
 
diff --git a/ssl/ssl_session.cc b/ssl/ssl_session.cc
index ed4c779..aff9c0f 100644
--- a/ssl/ssl_session.cc
+++ b/ssl/ssl_session.cc
@@ -340,7 +340,7 @@
   }
 }
 
-uint16_t SSL_SESSION_protocol_version(const SSL_SESSION *session) {
+uint16_t ssl_session_protocol_version(const SSL_SESSION *session) {
   uint16_t ret;
   if (!ssl_protocol_version_from_wire(&ret, session->ssl_version)) {
     // An |SSL_SESSION| will never have an invalid version. This is enforced by
@@ -352,8 +352,8 @@
   return ret;
 }
 
-const EVP_MD *SSL_SESSION_get_digest(const SSL_SESSION *session) {
-  return ssl_get_handshake_digest(SSL_SESSION_protocol_version(session),
+const EVP_MD *ssl_session_get_digest(const SSL_SESSION *session) {
+  return ssl_get_handshake_digest(ssl_session_protocol_version(session),
                                   session->cipher);
 }
 
@@ -967,7 +967,7 @@
 }
 
 int SSL_SESSION_should_be_single_use(const SSL_SESSION *session) {
-  return SSL_SESSION_protocol_version(session) >= TLS1_3_VERSION;
+  return ssl_session_protocol_version(session) >= TLS1_3_VERSION;
 }
 
 int SSL_SESSION_is_resumable(const SSL_SESSION *session) {
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 4653f73..f9cf83c 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -3740,6 +3740,34 @@
   EXPECT_EQ(SSL_R_NO_CIPHERS_AVAILABLE, ERR_GET_REASON(err));
 }
 
+TEST_P(SSLVersionTest, SessionVersion) {
+  SSL_CTX_set_session_cache_mode(client_ctx_.get(), SSL_SESS_CACHE_BOTH);
+  SSL_CTX_set_session_cache_mode(server_ctx_.get(), SSL_SESS_CACHE_BOTH);
+
+  bssl::UniquePtr<SSL_SESSION> session =
+      CreateClientSession(client_ctx_.get(), server_ctx_.get());
+  ASSERT_TRUE(session);
+  EXPECT_EQ(version(), SSL_SESSION_get_protocol_version(session.get()));
+
+  // Sessions in TLS 1.3 and later should be single-use.
+  EXPECT_EQ(version() == TLS1_3_VERSION,
+            !!SSL_SESSION_should_be_single_use(session.get()));
+
+  // Making fake sessions for testing works.
+  session.reset(SSL_SESSION_new(client_ctx_.get()));
+  ASSERT_TRUE(session);
+  ASSERT_TRUE(SSL_SESSION_set_protocol_version(session.get(), version()));
+  EXPECT_EQ(version(), SSL_SESSION_get_protocol_version(session.get()));
+
+  // Unknown versions and raw TLS 1.3 draft wire versions are rejected, leaving
+  // the session's version unchanged.
+  EXPECT_FALSE(SSL_SESSION_set_protocol_version(session.get(), 0x1234));
+  EXPECT_FALSE(
+      SSL_SESSION_set_protocol_version(session.get(), TLS1_3_DRAFT_VERSION));
+  EXPECT_EQ(version(), SSL_SESSION_get_protocol_version(session.get()));
+  ERR_clear_error();
+}
+
 // TODO(davidben): Convert this file to GTest properly.
 TEST(SSLTest, AllTests) {
   if (!TestSSL_SESSIONEncoding(kOpenSSLSession) ||
diff --git a/ssl/ssl_versions.cc b/ssl/ssl_versions.cc
index c06c5ab..56653b1 100644
--- a/ssl/ssl_versions.cc
+++ b/ssl/ssl_versions.cc
@@ -98,23 +98,83 @@
   return false;
 }
 
-static bool set_version_bound(const SSL_PROTOCOL_METHOD *method, uint16_t *out,
-                              uint16_t version) {
-  // The public API uses wire versions, except we use |TLS1_3_VERSION|
-  // everywhere to refer to any draft TLS 1.3 versions. In this direction, we
-  // map it to some representative TLS 1.3 draft version.
+// The following functions map between API versions and wire versions. The
+// public API works on wire versions, except that TLS 1.3 draft versions all
+// appear as TLS 1.3. This will get collapsed back down when TLS 1.3 is
+// finalized.
+
+static const char *ssl_version_to_string(uint16_t version) {
+  switch (version) {
+    case TLS1_3_DRAFT_VERSION:
+    case TLS1_3_EXPERIMENT_VERSION:
+    case TLS1_3_EXPERIMENT2_VERSION:
+    case TLS1_3_EXPERIMENT3_VERSION:
+      return "TLSv1.3";
+
+    case TLS1_2_VERSION:
+      return "TLSv1.2";
+
+    case TLS1_1_VERSION:
+      return "TLSv1.1";
+
+    case TLS1_VERSION:
+      return "TLSv1";
+
+    case SSL3_VERSION:
+      return "SSLv3";
+
+    case DTLS1_VERSION:
+      return "DTLSv1";
+
+    case DTLS1_2_VERSION:
+      return "DTLSv1.2";
+
+    default:
+      return "unknown";
+  }
+}
+
+static uint16_t wire_version_to_api(uint16_t version) {
+  switch (version) {
+    // Report TLS 1.3 draft versions as TLS 1.3 in the public API.
+    case TLS1_3_DRAFT_VERSION:
+    case TLS1_3_EXPERIMENT_VERSION:
+    case TLS1_3_EXPERIMENT2_VERSION:
+    case TLS1_3_EXPERIMENT3_VERSION:
+      return TLS1_3_VERSION;
+    default:
+      return version;
+  }
+}
+
+// api_version_to_wire maps |version| to some representative wire version. In
+// particular, it picks an arbitrary TLS 1.3 representative. This should only be
+// used in contexts where that does not matter. Draft versions are rejected.
+static bool api_version_to_wire(uint16_t *out, uint16_t version) {
   if (version == TLS1_3_DRAFT_VERSION ||
       version == TLS1_3_EXPERIMENT_VERSION ||
       version == TLS1_3_EXPERIMENT2_VERSION ||
       version == TLS1_3_EXPERIMENT3_VERSION) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_UNKNOWN_SSL_VERSION);
     return false;
   }
   if (version == TLS1_3_VERSION) {
     version = TLS1_3_DRAFT_VERSION;
   }
 
-  if (!method_supports_version(method, version) ||
+  // Reject anything that is not a real protocol version.
+  uint16_t unused;
+  if (!ssl_protocol_version_from_wire(&unused, version)) {
+    return false;
+  }
+
+  *out = version;
+  return true;
+}
+
+static bool set_version_bound(const SSL_PROTOCOL_METHOD *method, uint16_t *out,
+                              uint16_t version) {
+  if (!api_version_to_wire(&version, version) ||
+      !method_supports_version(method, version) ||
       !ssl_protocol_version_from_wire(out, version)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_UNKNOWN_SSL_VERSION);
     return false;
@@ -227,38 +287,6 @@
   return ssl->version;
 }
 
-static const char *ssl_version_to_string(uint16_t version) {
-  switch (version) {
-    // Report TLS 1.3 draft version as TLS 1.3 in the public API.
-    case TLS1_3_DRAFT_VERSION:
-    case TLS1_3_EXPERIMENT_VERSION:
-    case TLS1_3_EXPERIMENT2_VERSION:
-    case TLS1_3_EXPERIMENT3_VERSION:
-      return "TLSv1.3";
-
-    case TLS1_2_VERSION:
-      return "TLSv1.2";
-
-    case TLS1_1_VERSION:
-      return "TLSv1.1";
-
-    case TLS1_VERSION:
-      return "TLSv1";
-
-    case SSL3_VERSION:
-      return "SSLv3";
-
-    case DTLS1_VERSION:
-      return "DTLSv1";
-
-    case DTLS1_2_VERSION:
-      return "DTLSv1.2";
-
-    default:
-      return "unknown";
-  }
-}
-
 uint16_t ssl3_protocol_version(const SSL *ssl) {
   assert(ssl->s3->have_version);
   uint16_t version;
@@ -389,15 +417,7 @@
 }
 
 int SSL_version(const SSL *ssl) {
-  uint16_t ret = ssl_version(ssl);
-  // Report TLS 1.3 draft version as TLS 1.3 in the public API.
-  if (ret == TLS1_3_DRAFT_VERSION ||
-      ret == TLS1_3_EXPERIMENT_VERSION ||
-      ret == TLS1_3_EXPERIMENT2_VERSION ||
-      ret == TLS1_3_EXPERIMENT3_VERSION) {
-    return TLS1_3_VERSION;
-  }
-  return ret;
+  return wire_version_to_api(ssl_version(ssl));
 }
 
 const char *SSL_get_version(const SSL *ssl) {
@@ -407,3 +427,18 @@
 const char *SSL_SESSION_get_version(const SSL_SESSION *session) {
   return ssl_version_to_string(session->ssl_version);
 }
+
+uint16_t SSL_SESSION_get_protocol_version(const SSL_SESSION *session) {
+  return wire_version_to_api(session->ssl_version);
+}
+
+int SSL_SESSION_set_protocol_version(SSL_SESSION *session, uint16_t version) {
+  // This picks a representative TLS 1.3 version, but this API should only be
+  // used on unit test sessions anyway. On failure, |session| is left unchanged
+  // and, like |set_version_bound|, we push |SSL_R_UNKNOWN_SSL_VERSION|.
+  if (!api_version_to_wire(&session->ssl_version, version)) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_UNKNOWN_SSL_VERSION);
+    return 0;
+  }
+  return 1;
+}
diff --git a/ssl/t1_enc.cc b/ssl/t1_enc.cc
index 8f8d328..85c368c 100644
--- a/ssl/t1_enc.cc
+++ b/ssl/t1_enc.cc
@@ -462,7 +462,7 @@
                     SSL3_RANDOM_SIZE, ssl->s3->client_random, SSL3_RANDOM_SIZE);
   }
 
-  const EVP_MD *digest = SSL_SESSION_get_digest(session);
+  const EVP_MD *digest = ssl_session_get_digest(session);
   return tls1_prf(digest, out, out_len, session->master_key,
                   session->master_key_length, TLS_MD_KEY_EXPANSION_CONST,
                   TLS_MD_KEY_EXPANSION_CONST_SIZE, ssl->s3->server_random,
@@ -511,7 +511,7 @@
   }
 
   const SSL_SESSION *session = SSL_get_session(ssl);
-  const EVP_MD *digest = SSL_SESSION_get_digest(session);
+  const EVP_MD *digest = ssl_session_get_digest(session);
   int ret = tls1_prf(digest, out, out_len, session->master_key,
                      session->master_key_length, label, label_len, seed,
                      seed_len, NULL, 0);
diff --git a/ssl/t1_lib.cc b/ssl/t1_lib.cc
index 3b0a335..9c4231d 100644
--- a/ssl/t1_lib.cc
+++ b/ssl/t1_lib.cc
@@ -930,7 +930,7 @@
       ssl->session != NULL &&
       ssl->session->tlsext_tick != NULL &&
       // Don't send TLS 1.3 session tickets in the ticket extension.
-      SSL_SESSION_protocol_version(ssl->session) < TLS1_3_VERSION) {
+      ssl_session_protocol_version(ssl->session) < TLS1_3_VERSION) {
     ticket_data = ssl->session->tlsext_tick;
     ticket_len = ssl->session->tlsext_ticklen;
   }
@@ -1808,18 +1808,18 @@
 static size_t ext_pre_shared_key_clienthello_length(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
   if (hs->max_version < TLS1_3_VERSION || ssl->session == NULL ||
-      SSL_SESSION_protocol_version(ssl->session) < TLS1_3_VERSION) {
+      ssl_session_protocol_version(ssl->session) < TLS1_3_VERSION) {
     return 0;
   }
 
-  size_t binder_len = EVP_MD_size(SSL_SESSION_get_digest(ssl->session));
+  size_t binder_len = EVP_MD_size(ssl_session_get_digest(ssl->session));
   return 15 + ssl->session->tlsext_ticklen + binder_len;
 }
 
 static int ext_pre_shared_key_add_clienthello(SSL_HANDSHAKE *hs, CBB *out) {
   SSL *const ssl = hs->ssl;
   if (hs->max_version < TLS1_3_VERSION || ssl->session == NULL ||
-      SSL_SESSION_protocol_version(ssl->session) < TLS1_3_VERSION) {
+      ssl_session_protocol_version(ssl->session) < TLS1_3_VERSION) {
     return 1;
   }
 
@@ -1831,7 +1831,7 @@
   // Fill in a placeholder zero binder of the appropriate length. It will be
   // computed and filled in later after length prefixes are computed.
   uint8_t zero_binder[EVP_MAX_MD_SIZE] = {0};
-  size_t binder_len = EVP_MD_size(SSL_SESSION_get_digest(ssl->session));
+  size_t binder_len = EVP_MD_size(ssl_session_get_digest(ssl->session));
 
   CBB contents, identity, ticket, binders, binder;
   if (!CBB_add_u16(out, TLSEXT_TYPE_pre_shared_key) ||
@@ -1997,7 +1997,7 @@
 static int ext_early_data_add_clienthello(SSL_HANDSHAKE *hs, CBB *out) {
   SSL *const ssl = hs->ssl;
   if (ssl->session == NULL ||
-      SSL_SESSION_protocol_version(ssl->session) < TLS1_3_VERSION ||
+      ssl_session_protocol_version(ssl->session) < TLS1_3_VERSION ||
       ssl->session->ticket_max_early_data == 0 ||
       hs->received_hello_retry_request ||
       !ssl->cert->enable_early_data) {
diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc
index b68a39e..854fae0 100644
--- a/ssl/tls13_enc.cc
+++ b/ssl/tls13_enc.cc
@@ -57,7 +57,7 @@
 
 int tls13_init_early_key_schedule(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  return init_key_schedule(hs, SSL_SESSION_protocol_version(ssl->session),
+  return init_key_schedule(hs, ssl_session_protocol_version(ssl->session),
                            ssl->session->cipher);
 }
 
@@ -116,7 +116,7 @@
                           const uint8_t *traffic_secret,
                           size_t traffic_secret_len) {
   const SSL_SESSION *session = SSL_get_session(ssl);
-  uint16_t version = SSL_SESSION_protocol_version(session);
+  uint16_t version = ssl_session_protocol_version(session);
 
   if (traffic_secret_len > 0xff) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
@@ -131,7 +131,7 @@
     return 0;
   }
 
-  const EVP_MD *digest = SSL_SESSION_get_digest(session);
+  const EVP_MD *digest = ssl_session_get_digest(session);
 
   // Derive the key.
   size_t key_len = EVP_AEAD_key_length(aead);
@@ -253,7 +253,7 @@
     secret_len = ssl->s3->write_traffic_secret_len;
   }
 
-  const EVP_MD *digest = SSL_SESSION_get_digest(SSL_get_session(ssl));
+  const EVP_MD *digest = ssl_session_get_digest(SSL_get_session(ssl));
   if (!hkdf_expand_label(secret, digest, secret, secret_len,
                          (const uint8_t *)kTLS13LabelApplicationTraffic,
                          strlen(kTLS13LabelApplicationTraffic), NULL, 0,
@@ -328,7 +328,7 @@
     hash_len = context_len;
   }
 
-  const EVP_MD *digest = SSL_SESSION_get_digest(SSL_get_session(ssl));
+  const EVP_MD *digest = ssl_session_get_digest(SSL_get_session(ssl));
   return hkdf_expand_label(out, digest, ssl->s3->exporter_secret,
                            ssl->s3->exporter_secret_len, (const uint8_t *)label,
                            label_len, hash, hash_len, out_len);
@@ -368,7 +368,7 @@
 
 int tls13_write_psk_binder(SSL_HANDSHAKE *hs, uint8_t *msg, size_t len) {
   SSL *const ssl = hs->ssl;
-  const EVP_MD *digest = SSL_SESSION_get_digest(ssl->session);
+  const EVP_MD *digest = ssl_session_get_digest(ssl->session);
   size_t hash_len = EVP_MD_size(digest);
 
   if (len < hash_len + 3) {