Add SSL_SESSION_{get,set}_protocol_version.
SSL_SESSION_set_protocol_version is useful when unit-testing a session
cache.
Change-Id: I4b04e31d61ce40739323248e3e5fdae498c4645e
Reviewed-on: https://boringssl-review.googlesource.com/21044
Commit-Queue: Steven Valdez <svaldez@google.com>
Reviewed-by: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/internal.h b/ssl/internal.h
index 9a62b46..4247425 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -2229,12 +2229,13 @@
int ssl_session_is_resumable(const SSL_HANDSHAKE *hs,
const SSL_SESSION *session);
-// SSL_SESSION_protocol_version returns the protocol version associated with
-// |session|.
-uint16_t SSL_SESSION_protocol_version(const SSL_SESSION *session);
+// ssl_session_protocol_version returns the protocol version associated with
+// |session|. Note that despite the name, this is not the same as
+// |SSL_SESSION_get_protocol_version|. The latter is based on upstream's name.
+uint16_t ssl_session_protocol_version(const SSL_SESSION *session);
-// SSL_SESSION_get_digest returns the digest used in |session|.
-const EVP_MD *SSL_SESSION_get_digest(const SSL_SESSION *session);
+// ssl_session_get_digest returns the digest used in |session|.
+const EVP_MD *ssl_session_get_digest(const SSL_SESSION *session);
void ssl_set_session(SSL *ssl, SSL_SESSION *session);
diff --git a/ssl/ssl_session.cc b/ssl/ssl_session.cc
index ed4c779..aff9c0f 100644
--- a/ssl/ssl_session.cc
+++ b/ssl/ssl_session.cc
@@ -340,7 +340,7 @@
}
}
-uint16_t SSL_SESSION_protocol_version(const SSL_SESSION *session) {
+uint16_t ssl_session_protocol_version(const SSL_SESSION *session) {
uint16_t ret;
if (!ssl_protocol_version_from_wire(&ret, session->ssl_version)) {
// An |SSL_SESSION| will never have an invalid version. This is enforced by
@@ -352,8 +352,8 @@
return ret;
}
-const EVP_MD *SSL_SESSION_get_digest(const SSL_SESSION *session) {
- return ssl_get_handshake_digest(SSL_SESSION_protocol_version(session),
+const EVP_MD *ssl_session_get_digest(const SSL_SESSION *session) {
+ return ssl_get_handshake_digest(ssl_session_protocol_version(session),
session->cipher);
}
@@ -967,7 +967,7 @@
}
int SSL_SESSION_should_be_single_use(const SSL_SESSION *session) {
- return SSL_SESSION_protocol_version(session) >= TLS1_3_VERSION;
+ return ssl_session_protocol_version(session) >= TLS1_3_VERSION;
}
int SSL_SESSION_is_resumable(const SSL_SESSION *session) {
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 4653f73..f9cf83c 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -3740,6 +3740,26 @@
EXPECT_EQ(SSL_R_NO_CIPHERS_AVAILABLE, ERR_GET_REASON(err));
}
+TEST_P(SSLVersionTest, SessionVersion) {
+ SSL_CTX_set_session_cache_mode(client_ctx_.get(), SSL_SESS_CACHE_BOTH);
+ SSL_CTX_set_session_cache_mode(server_ctx_.get(), SSL_SESS_CACHE_BOTH);
+
+ bssl::UniquePtr<SSL_SESSION> session =
+ CreateClientSession(client_ctx_.get(), server_ctx_.get());
+ ASSERT_TRUE(session);
+ EXPECT_EQ(version(), SSL_SESSION_get_protocol_version(session.get()));
+
+ // Sessions in TLS 1.3 and later should be single-use.
+ EXPECT_EQ(version() == TLS1_3_VERSION,
+ !!SSL_SESSION_should_be_single_use(session.get()));
+
+ // Making fake sessions for testing works.
+ session.reset(SSL_SESSION_new(client_ctx_.get()));
+ ASSERT_TRUE(session);
+ ASSERT_TRUE(SSL_SESSION_set_protocol_version(session.get(), version()));
+ EXPECT_EQ(version(), SSL_SESSION_get_protocol_version(session.get()));
+}
+
// TODO(davidben): Convert this file to GTest properly.
TEST(SSLTest, AllTests) {
if (!TestSSL_SESSIONEncoding(kOpenSSLSession) ||
diff --git a/ssl/ssl_versions.cc b/ssl/ssl_versions.cc
index c06c5ab..56653b1 100644
--- a/ssl/ssl_versions.cc
+++ b/ssl/ssl_versions.cc
@@ -98,23 +98,83 @@
return false;
}
-static bool set_version_bound(const SSL_PROTOCOL_METHOD *method, uint16_t *out,
- uint16_t version) {
- // The public API uses wire versions, except we use |TLS1_3_VERSION|
- // everywhere to refer to any draft TLS 1.3 versions. In this direction, we
- // map it to some representative TLS 1.3 draft version.
+// The following functions map between API versions and wire versions. The
+// public API works on wire versions, except that TLS 1.3 draft versions all
+// appear as TLS 1.3. This will get collapsed back down when TLS 1.3 is
+// finalized.
+
+static const char *ssl_version_to_string(uint16_t version) {
+ switch (version) {
+ case TLS1_3_DRAFT_VERSION:
+ case TLS1_3_EXPERIMENT_VERSION:
+ case TLS1_3_EXPERIMENT2_VERSION:
+ case TLS1_3_EXPERIMENT3_VERSION:
+ return "TLSv1.3";
+
+ case TLS1_2_VERSION:
+ return "TLSv1.2";
+
+ case TLS1_1_VERSION:
+ return "TLSv1.1";
+
+ case TLS1_VERSION:
+ return "TLSv1";
+
+ case SSL3_VERSION:
+ return "SSLv3";
+
+ case DTLS1_VERSION:
+ return "DTLSv1";
+
+ case DTLS1_2_VERSION:
+ return "DTLSv1.2";
+
+ default:
+ return "unknown";
+ }
+}
+
+static uint16_t wire_version_to_api(uint16_t version) {
+ switch (version) {
+ // Report TLS 1.3 draft versions as TLS 1.3 in the public API.
+ case TLS1_3_DRAFT_VERSION:
+ case TLS1_3_EXPERIMENT_VERSION:
+ case TLS1_3_EXPERIMENT2_VERSION:
+ case TLS1_3_EXPERIMENT3_VERSION:
+ return TLS1_3_VERSION;
+ default:
+ return version;
+ }
+}
+
+// api_version_to_wire maps |version| to some representative wire version. In
+// particular, it picks an arbitrary TLS 1.3 representative. This should only be
+// used in contexts where that does not matter.
+static bool api_version_to_wire(uint16_t *out, uint16_t version) {
if (version == TLS1_3_DRAFT_VERSION ||
version == TLS1_3_EXPERIMENT_VERSION ||
version == TLS1_3_EXPERIMENT2_VERSION ||
version == TLS1_3_EXPERIMENT3_VERSION) {
- OPENSSL_PUT_ERROR(SSL, SSL_R_UNKNOWN_SSL_VERSION);
return false;
}
if (version == TLS1_3_VERSION) {
version = TLS1_3_DRAFT_VERSION;
}
- if (!method_supports_version(method, version) ||
+ // Check it is a real protocol version.
+ uint16_t unused;
+ if (!ssl_protocol_version_from_wire(&unused, version)) {
+ return false;
+ }
+
+ *out = version;
+ return true;
+}
+
+static bool set_version_bound(const SSL_PROTOCOL_METHOD *method, uint16_t *out,
+ uint16_t version) {
+ if (!api_version_to_wire(&version, version) ||
+ !method_supports_version(method, version) ||
!ssl_protocol_version_from_wire(out, version)) {
OPENSSL_PUT_ERROR(SSL, SSL_R_UNKNOWN_SSL_VERSION);
return false;
@@ -227,38 +287,6 @@
return ssl->version;
}
-static const char *ssl_version_to_string(uint16_t version) {
- switch (version) {
- // Report TLS 1.3 draft version as TLS 1.3 in the public API.
- case TLS1_3_DRAFT_VERSION:
- case TLS1_3_EXPERIMENT_VERSION:
- case TLS1_3_EXPERIMENT2_VERSION:
- case TLS1_3_EXPERIMENT3_VERSION:
- return "TLSv1.3";
-
- case TLS1_2_VERSION:
- return "TLSv1.2";
-
- case TLS1_1_VERSION:
- return "TLSv1.1";
-
- case TLS1_VERSION:
- return "TLSv1";
-
- case SSL3_VERSION:
- return "SSLv3";
-
- case DTLS1_VERSION:
- return "DTLSv1";
-
- case DTLS1_2_VERSION:
- return "DTLSv1.2";
-
- default:
- return "unknown";
- }
-}
-
uint16_t ssl3_protocol_version(const SSL *ssl) {
assert(ssl->s3->have_version);
uint16_t version;
@@ -389,15 +417,7 @@
}
int SSL_version(const SSL *ssl) {
- uint16_t ret = ssl_version(ssl);
- // Report TLS 1.3 draft version as TLS 1.3 in the public API.
- if (ret == TLS1_3_DRAFT_VERSION ||
- ret == TLS1_3_EXPERIMENT_VERSION ||
- ret == TLS1_3_EXPERIMENT2_VERSION ||
- ret == TLS1_3_EXPERIMENT3_VERSION) {
- return TLS1_3_VERSION;
- }
- return ret;
+ return wire_version_to_api(ssl_version(ssl));
}
const char *SSL_get_version(const SSL *ssl) {
@@ -407,3 +427,13 @@
const char *SSL_SESSION_get_version(const SSL_SESSION *session) {
return ssl_version_to_string(session->ssl_version);
}
+
+uint16_t SSL_SESSION_get_protocol_version(const SSL_SESSION *session) {
+ return wire_version_to_api(session->ssl_version);
+}
+
+int SSL_SESSION_set_protocol_version(SSL_SESSION *session, uint16_t version) {
+ // This picks a representative TLS 1.3 version, but this API should only be
+ // used on unit test sessions anyway.
+ return api_version_to_wire(&session->ssl_version, version);
+}
diff --git a/ssl/t1_enc.cc b/ssl/t1_enc.cc
index 8f8d328..85c368c 100644
--- a/ssl/t1_enc.cc
+++ b/ssl/t1_enc.cc
@@ -462,7 +462,7 @@
SSL3_RANDOM_SIZE, ssl->s3->client_random, SSL3_RANDOM_SIZE);
}
- const EVP_MD *digest = SSL_SESSION_get_digest(session);
+ const EVP_MD *digest = ssl_session_get_digest(session);
return tls1_prf(digest, out, out_len, session->master_key,
session->master_key_length, TLS_MD_KEY_EXPANSION_CONST,
TLS_MD_KEY_EXPANSION_CONST_SIZE, ssl->s3->server_random,
@@ -511,7 +511,7 @@
}
const SSL_SESSION *session = SSL_get_session(ssl);
- const EVP_MD *digest = SSL_SESSION_get_digest(session);
+ const EVP_MD *digest = ssl_session_get_digest(session);
int ret = tls1_prf(digest, out, out_len, session->master_key,
session->master_key_length, label, label_len, seed,
seed_len, NULL, 0);
diff --git a/ssl/t1_lib.cc b/ssl/t1_lib.cc
index 3b0a335..9c4231d 100644
--- a/ssl/t1_lib.cc
+++ b/ssl/t1_lib.cc
@@ -930,7 +930,7 @@
ssl->session != NULL &&
ssl->session->tlsext_tick != NULL &&
// Don't send TLS 1.3 session tickets in the ticket extension.
- SSL_SESSION_protocol_version(ssl->session) < TLS1_3_VERSION) {
+ ssl_session_protocol_version(ssl->session) < TLS1_3_VERSION) {
ticket_data = ssl->session->tlsext_tick;
ticket_len = ssl->session->tlsext_ticklen;
}
@@ -1808,18 +1808,18 @@
static size_t ext_pre_shared_key_clienthello_length(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
if (hs->max_version < TLS1_3_VERSION || ssl->session == NULL ||
- SSL_SESSION_protocol_version(ssl->session) < TLS1_3_VERSION) {
+ ssl_session_protocol_version(ssl->session) < TLS1_3_VERSION) {
return 0;
}
- size_t binder_len = EVP_MD_size(SSL_SESSION_get_digest(ssl->session));
+ size_t binder_len = EVP_MD_size(ssl_session_get_digest(ssl->session));
return 15 + ssl->session->tlsext_ticklen + binder_len;
}
static int ext_pre_shared_key_add_clienthello(SSL_HANDSHAKE *hs, CBB *out) {
SSL *const ssl = hs->ssl;
if (hs->max_version < TLS1_3_VERSION || ssl->session == NULL ||
- SSL_SESSION_protocol_version(ssl->session) < TLS1_3_VERSION) {
+ ssl_session_protocol_version(ssl->session) < TLS1_3_VERSION) {
return 1;
}
@@ -1831,7 +1831,7 @@
// Fill in a placeholder zero binder of the appropriate length. It will be
// computed and filled in later after length prefixes are computed.
uint8_t zero_binder[EVP_MAX_MD_SIZE] = {0};
- size_t binder_len = EVP_MD_size(SSL_SESSION_get_digest(ssl->session));
+ size_t binder_len = EVP_MD_size(ssl_session_get_digest(ssl->session));
CBB contents, identity, ticket, binders, binder;
if (!CBB_add_u16(out, TLSEXT_TYPE_pre_shared_key) ||
@@ -1997,7 +1997,7 @@
static int ext_early_data_add_clienthello(SSL_HANDSHAKE *hs, CBB *out) {
SSL *const ssl = hs->ssl;
if (ssl->session == NULL ||
- SSL_SESSION_protocol_version(ssl->session) < TLS1_3_VERSION ||
+ ssl_session_protocol_version(ssl->session) < TLS1_3_VERSION ||
ssl->session->ticket_max_early_data == 0 ||
hs->received_hello_retry_request ||
!ssl->cert->enable_early_data) {
diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc
index b68a39e..854fae0 100644
--- a/ssl/tls13_enc.cc
+++ b/ssl/tls13_enc.cc
@@ -57,7 +57,7 @@
int tls13_init_early_key_schedule(SSL_HANDSHAKE *hs) {
SSL *const ssl = hs->ssl;
- return init_key_schedule(hs, SSL_SESSION_protocol_version(ssl->session),
+ return init_key_schedule(hs, ssl_session_protocol_version(ssl->session),
ssl->session->cipher);
}
@@ -116,7 +116,7 @@
const uint8_t *traffic_secret,
size_t traffic_secret_len) {
const SSL_SESSION *session = SSL_get_session(ssl);
- uint16_t version = SSL_SESSION_protocol_version(session);
+ uint16_t version = ssl_session_protocol_version(session);
if (traffic_secret_len > 0xff) {
OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
@@ -131,7 +131,7 @@
return 0;
}
- const EVP_MD *digest = SSL_SESSION_get_digest(session);
+ const EVP_MD *digest = ssl_session_get_digest(session);
// Derive the key.
size_t key_len = EVP_AEAD_key_length(aead);
@@ -253,7 +253,7 @@
secret_len = ssl->s3->write_traffic_secret_len;
}
- const EVP_MD *digest = SSL_SESSION_get_digest(SSL_get_session(ssl));
+ const EVP_MD *digest = ssl_session_get_digest(SSL_get_session(ssl));
if (!hkdf_expand_label(secret, digest, secret, secret_len,
(const uint8_t *)kTLS13LabelApplicationTraffic,
strlen(kTLS13LabelApplicationTraffic), NULL, 0,
@@ -328,7 +328,7 @@
hash_len = context_len;
}
- const EVP_MD *digest = SSL_SESSION_get_digest(SSL_get_session(ssl));
+ const EVP_MD *digest = ssl_session_get_digest(SSL_get_session(ssl));
return hkdf_expand_label(out, digest, ssl->s3->exporter_secret,
ssl->s3->exporter_secret_len, (const uint8_t *)label,
label_len, hash, hash_len, out_len);
@@ -368,7 +368,7 @@
int tls13_write_psk_binder(SSL_HANDSHAKE *hs, uint8_t *msg, size_t len) {
SSL *const ssl = hs->ssl;
- const EVP_MD *digest = SSL_SESSION_get_digest(ssl->session);
+ const EVP_MD *digest = ssl_session_get_digest(ssl->session);
size_t hash_len = EVP_MD_size(digest);
if (len < hash_len + 3) {