Add min_version and max_version APIs.

Amend the version negotiation tests to test this new spelling of max_version.
min_version will be tested in a follow-up.

Change-Id: Ic4bfcd43bc4e5f951140966f64bb5fd3e2472b01
Reviewed-on: https://boringssl-review.googlesource.com/2583
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index b4cda2b..1bc9604 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -470,6 +470,8 @@
 /* SSL_OP_TLS_ROLLBACK_BUG does nothing. */
 #define SSL_OP_TLS_ROLLBACK_BUG				0x00800000L
 
+/* Deprecated: Use SSL_CTX_set_min_version and SSL_CTX_set_max_version
+ * instead. */
 #define SSL_OP_NO_SSLv2					0x01000000L
 #define SSL_OP_NO_SSLv3					0x02000000L
 #define SSL_OP_NO_TLSv1					0x04000000L
@@ -605,6 +607,22 @@
 #define SSL_clear_cert_flags(s,op) \
 	SSL_ctrl((s),SSL_CTRL_CLEAR_CERT_FLAGS,(op),NULL)
 
+/* SSL_CTX_set_min_version sets the minimum protocol version for |ctx| to
+ * |version|. */
+void SSL_CTX_set_min_version(SSL_CTX *ctx, uint16_t version);
+
+/* SSL_CTX_set_max_version sets the maximum protocol version for |ctx| to
+ * |version|. */
+void SSL_CTX_set_max_version(SSL_CTX *ctx, uint16_t version);
+
+/* SSL_set_min_version sets the minimum protocol version for |ssl| to
+ * |version|. */
+void SSL_set_min_version(SSL *ssl, uint16_t version);
+
+/* SSL_set_max_version sets the maximum protocol version for |ssl| to
+ * |version|. */
+void SSL_set_max_version(SSL *ssl, uint16_t version);
+
 OPENSSL_EXPORT void SSL_CTX_set_msg_callback(SSL_CTX *ctx, void (*cb)(int write_p, int version, int content_type, const void *buf, size_t len, SSL *ssl, void *arg));
 OPENSSL_EXPORT void SSL_set_msg_callback(SSL *ssl, void (*cb)(int write_p, int version, int content_type, const void *buf, size_t len, SSL *ssl, void *arg));
 #define SSL_CTX_set_msg_callback_arg(ctx, arg) SSL_CTX_ctrl((ctx), SSL_CTRL_SET_MSG_CALLBACK_ARG, 0, (arg))
@@ -723,6 +741,16 @@
 	{
 	const SSL_METHOD *method;
 
+	/* max_version is the maximum acceptable protocol version. If
+	 * zero, the maximum supported version, currently (D)TLS 1.2,
+	 * is used. */
+	uint16_t max_version;
+
+	/* min_version is the minimum acceptable protocl version. If
+	 * zero, the minimum supported version, currently SSL 3.0 and
+	 * DTLS 1.0, is used */
+	uint16_t min_version;
+
 	struct ssl_cipher_preference_list_st *cipher_list;
 	/* same as above but sorted for lookup */
 	STACK_OF(SSL_CIPHER) *cipher_list_by_id;
@@ -1178,6 +1206,14 @@
 	 * version. */
 	const SSL3_ENC_METHOD *enc_method;
 
+	/* max_version is the maximum acceptable protocol version. If zero, the
+	 * maximum supported version, currently (D)TLS 1.2, is used. */
+	uint16_t max_version;
+
+	/* min_version is the minimum acceptable protocl version. If zero, the
+	 * minimum supported version, currently SSL 3.0 and DTLS 1.0, is used */
+	uint16_t min_version;
+
 	/* There are 2 BIO's even though they are normally both the
 	 * same.  This is so data can be read and written to different
 	 * handlers */
diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c
index 241bcc3..574b85f 100644
--- a/ssl/ssl_lib.c
+++ b/ssl/ssl_lib.c
@@ -246,6 +246,9 @@
 	if (s == NULL) goto err;
 	memset(s,0,sizeof(SSL));
 
+	s->min_version = ctx->min_version;
+	s->max_version = ctx->max_version;
+
 	s->options=ctx->options;
 	s->mode=ctx->mode;
 	s->max_cert_list=ctx->max_cert_list;
@@ -2878,6 +2881,26 @@
 	ctx->psk_server_callback = cb;
 	}
 
+void SSL_CTX_set_min_version(SSL_CTX *ctx, uint16_t version)
+	{
+	ctx->min_version = version;
+	}
+
+void SSL_CTX_set_max_version(SSL_CTX *ctx, uint16_t version)
+	{
+	ctx->max_version = version;
+	}
+
+void SSL_set_min_version(SSL *ssl, uint16_t version)
+	{
+	ssl->min_version = version;
+	}
+
+void SSL_set_max_version(SSL *ssl, uint16_t version)
+	{
+	ssl->max_version = version;
+	}
+
 void SSL_CTX_set_msg_callback(SSL_CTX *ctx, void (*cb)(int write_p, int version, int content_type, const void *buf, size_t len, SSL *ssl, void *arg))
 	{
 	SSL_CTX_callback_ctrl(ctx, SSL_CTRL_SET_MSG_CALLBACK, (void (*)(void))cb);
@@ -3094,47 +3117,69 @@
 
 uint16_t ssl3_get_max_server_version(const SSL *s)
 	{
+	uint16_t max_version;
+
 	if (SSL_IS_DTLS(s))
 		{
-		if (!(s->options & SSL_OP_NO_DTLSv1_2))
+		max_version = (s->max_version != 0) ? s->max_version : DTLS1_2_VERSION;
+		if (!(s->options & SSL_OP_NO_DTLSv1_2) && DTLS1_2_VERSION >= max_version)
 			return DTLS1_2_VERSION;
-		if (!(s->options & SSL_OP_NO_DTLSv1))
+		if (!(s->options & SSL_OP_NO_DTLSv1) && DTLS1_VERSION >= max_version)
 			return DTLS1_VERSION;
 		return 0;
 		}
 
-	if (!(s->options & SSL_OP_NO_TLSv1_2))
+	max_version = (s->max_version != 0) ? s->max_version : TLS1_2_VERSION;
+	if (!(s->options & SSL_OP_NO_TLSv1_2) && TLS1_2_VERSION <= max_version)
 		return TLS1_2_VERSION;
-	if (!(s->options & SSL_OP_NO_TLSv1_1))
+	if (!(s->options & SSL_OP_NO_TLSv1_1) && TLS1_1_VERSION <= max_version)
 		return TLS1_1_VERSION;
-	if (!(s->options & SSL_OP_NO_TLSv1))
+	if (!(s->options & SSL_OP_NO_TLSv1) && TLS1_VERSION <= max_version)
 		return TLS1_VERSION;
-	if (!(s->options & SSL_OP_NO_SSLv3))
+	if (!(s->options & SSL_OP_NO_SSLv3) && SSL3_VERSION <= max_version)
 		return SSL3_VERSION;
 	return 0;
 	}
 
 uint16_t ssl3_get_mutual_version(SSL *s, uint16_t client_version)
 	{
+	uint16_t version = 0;
+
 	if (SSL_IS_DTLS(s))
 		{
+		/* Clamp client_version to max_version. */
+		if (s->max_version != 0 && client_version < s->max_version)
+			client_version = s->max_version;
+
 		if (client_version <= DTLS1_2_VERSION && !(s->options & SSL_OP_NO_DTLSv1_2))
-			return DTLS1_2_VERSION;
-		if (client_version <= DTLS1_VERSION && !(s->options & SSL_OP_NO_DTLSv1))
-			return DTLS1_VERSION;
-		return 0;
+			version = DTLS1_2_VERSION;
+		else if (client_version <= DTLS1_VERSION && !(s->options & SSL_OP_NO_DTLSv1))
+			version = DTLS1_VERSION;
+
+		/* Check against min_version. */
+		if (version != 0 && s->min_version != 0 && version > s->min_version)
+			return 0;
+		return version;
 		}
 	else
 		{
+		/* Clamp client_version to max_version. */
+		if (s->max_version != 0 && client_version > s->max_version)
+			client_version = s->max_version;
+
 		if (client_version >= TLS1_2_VERSION && !(s->options & SSL_OP_NO_TLSv1_2))
-			return TLS1_2_VERSION;
-		if (client_version >= TLS1_1_VERSION && !(s->options & SSL_OP_NO_TLSv1_1))
-			return TLS1_1_VERSION;
-		if (client_version >= TLS1_VERSION && !(s->options & SSL_OP_NO_TLSv1))
-			return TLS1_VERSION;
-		if (client_version >= SSL3_VERSION && !(s->options & SSL_OP_NO_SSLv3))
-			return SSL3_VERSION;
-		return 0;
+			version =  TLS1_2_VERSION;
+		else if (client_version >= TLS1_1_VERSION && !(s->options & SSL_OP_NO_TLSv1_1))
+			version = TLS1_1_VERSION;
+		else if (client_version >= TLS1_VERSION && !(s->options & SSL_OP_NO_TLSv1))
+			version = TLS1_VERSION;
+		else if (client_version >= SSL3_VERSION && !(s->options & SSL_OP_NO_SSLv3))
+			version = SSL3_VERSION;
+
+		/* Check against min_version. */
+		if (version != 0 && s->min_version != 0 && version < s->min_version)
+			return 0;
+		return version;
 		}
 	}
 
@@ -3155,16 +3200,15 @@
 	 * set a maximum version of TLS 1.2 in a future-proof way.
 	 *
 	 * By this scheme, the maximum version is the lowest version V such that
-	 * V is enabled and V+1 is disabled or unimplemented.
-	 *
-	 * TODO(davidben): Deprecate this API in favor of more sensible
-	 * min_version/max_version settings. */
+	 * V is enabled and V+1 is disabled or unimplemented. */
 	if (SSL_IS_DTLS(s))
 		{
 		if (!(options & SSL_OP_NO_DTLSv1_2))
 			version = DTLS1_2_VERSION;
 		if (!(options & SSL_OP_NO_DTLSv1) && (options & SSL_OP_NO_DTLSv1_2))
 			version = DTLS1_VERSION;
+		if (s->max_version != 0 && version < s->max_version)
+			version = s->max_version;
 		}
 	else
 		{
@@ -3176,6 +3220,8 @@
 			version = TLS1_VERSION;
 		if (!(options & SSL_OP_NO_SSLv3) && (options & SSL_OP_NO_TLSv1))
 			version = SSL3_VERSION;
+		if (s->max_version != 0 && version > s->max_version)
+			version = s->max_version;
 		}
 
 	return version;
@@ -3185,6 +3231,10 @@
 	{
 	if (SSL_IS_DTLS(s))
 		{
+		if (s->max_version != 0 && version < s->max_version)
+			return 0;
+		if (s->min_version != 0 && version > s->min_version)
+			return 0;
 		switch (version)
 			{
 		case DTLS1_VERSION:
@@ -3197,6 +3247,10 @@
 		}
 	else
 		{
+		if (s->max_version != 0 && version > s->max_version)
+			return 0;
+		if (s->min_version != 0 && version < s->min_version)
+			return 0;
 		switch (version)
 			{
 		case SSL3_VERSION:
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index e04e44b..3d78c1c 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -430,6 +430,12 @@
     return 1;
   }
   SSL_enable_fastradio_padding(ssl, config->fastradio_padding);
+  if (config->min_version != 0) {
+    SSL_set_min_version(ssl, (uint16_t)config->min_version);
+  }
+  if (config->max_version != 0) {
+    SSL_set_max_version(ssl, (uint16_t)config->max_version);
+  }
 
   BIO *bio = BIO_new_fd(fd, 1 /* take ownership */);
   if (bio == NULL) {
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index 5605966..2b91f43 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -1639,6 +1639,8 @@
 					suffix += "-DTLS"
 				}
 
+				shimVersFlag := strconv.Itoa(int(versionToWire(shimVers.version, protocol == dtls)))
+
 				clientVers := shimVers.version
 				if clientVers > VersionTLS10 {
 					clientVers = VersionTLS10
@@ -1656,6 +1658,19 @@
 					flags:           flags,
 					expectedVersion: expectedVersion,
 				})
+				testCases = append(testCases, testCase{
+					protocol: protocol,
+					testType: clientTest,
+					name:     "VersionNegotiation-Client2-" + suffix,
+					config: Config{
+						MaxVersion: runnerVers.version,
+						Bugs: ProtocolBugs{
+							ExpectInitialRecordVersion: clientVers,
+						},
+					},
+					flags:           []string{"-max-version", shimVersFlag},
+					expectedVersion: expectedVersion,
+				})
 
 				testCases = append(testCases, testCase{
 					protocol: protocol,
@@ -1670,6 +1685,19 @@
 					flags:           flags,
 					expectedVersion: expectedVersion,
 				})
+				testCases = append(testCases, testCase{
+					protocol: protocol,
+					testType: serverTest,
+					name:     "VersionNegotiation-Server2-" + suffix,
+					config: Config{
+						MaxVersion: runnerVers.version,
+						Bugs: ProtocolBugs{
+							ExpectInitialRecordVersion: expectedVersion,
+						},
+					},
+					flags:           []string{"-max-version", shimVersFlag},
+					expectedVersion: expectedVersion,
+				})
 			}
 		}
 	}
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index 59874ef..a678a69 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -15,6 +15,7 @@
 #include "test_config.h"
 
 #include <stdio.h>
+#include <stdlib.h>
 #include <string.h>
 
 #include <memory>
@@ -23,20 +24,26 @@
 
 namespace {
 
-typedef bool TestConfig::*BoolMember;
-typedef std::string TestConfig::*StringMember;
-
-struct BoolFlag {
+template <typename T>
+struct Flag {
   const char *flag;
-  BoolMember member;
+  T TestConfig::*member;
 };
 
-struct StringFlag {
-  const char *flag;
-  StringMember member;
-};
+// FindField looks for the flag in |flags| that matches |flag|. If one is found,
+// it returns a pointer to the corresponding field in |config|. Otherwise, it
+// returns NULL.
+template<typename T, size_t N>
+T *FindField(TestConfig *config, const Flag<T> (&flags)[N], const char *flag) {
+  for (size_t i = 0; i < N; i++) {
+    if (strcmp(flag, flags[i].flag) == 0) {
+      return &(config->*(flags[i].member));
+    }
+  }
+  return NULL;
+}
 
-const BoolFlag kBoolFlags[] = {
+const Flag<bool> kBoolFlags[] = {
   { "-server", &TestConfig::is_server },
   { "-dtls", &TestConfig::is_dtls },
   { "-resume", &TestConfig::resume },
@@ -68,9 +75,7 @@
   { "-fastradio-padding", &TestConfig::fastradio_padding },
 };
 
-const size_t kNumBoolFlags = sizeof(kBoolFlags) / sizeof(kBoolFlags[0]);
-
-const StringFlag kStringFlags[] = {
+const Flag<std::string> kStringFlags[] = {
   { "-key-file", &TestConfig::key_file },
   { "-cert-file", &TestConfig::cert_file },
   { "-expect-server-name", &TestConfig::expected_server_name },
@@ -88,9 +93,7 @@
   { "-srtp-profiles", &TestConfig::srtp_profiles },
 };
 
-const size_t kNumStringFlags = sizeof(kStringFlags) / sizeof(kStringFlags[0]);
-
-const StringFlag kBase64Flags[] = {
+const Flag<std::string> kBase64Flags[] = {
   { "-expect-certificate-types", &TestConfig::expected_certificate_types },
   { "-expect-channel-id", &TestConfig::expected_channel_id },
   { "-expect-ocsp-response", &TestConfig::expected_ocsp_response },
@@ -98,7 +101,10 @@
     &TestConfig::expected_signed_cert_timestamps },
 };
 
-const size_t kNumBase64Flags = sizeof(kBase64Flags) / sizeof(kBase64Flags[0]);
+const Flag<int> kIntFlags[] = {
+  { "-min-version", &TestConfig::min_version },
+  { "-max-version", &TestConfig::max_version },
+};
 
 }  // namespace
 
@@ -126,43 +132,32 @@
       allow_unsafe_legacy_renegotiation(false),
       enable_ocsp_stapling(false),
       enable_signed_cert_timestamps(false),
-      fastradio_padding(false) {
+      fastradio_padding(false),
+      min_version(0),
+      max_version(0) {
 }
 
 bool ParseConfig(int argc, char **argv, TestConfig *out_config) {
   for (int i = 0; i < argc; i++) {
-    size_t j;
-    for (j = 0; j < kNumBoolFlags; j++) {
-      if (strcmp(argv[i], kBoolFlags[j].flag) == 0) {
-        break;
-      }
-    }
-    if (j < kNumBoolFlags) {
-      out_config->*(kBoolFlags[j].member) = true;
+    bool *bool_field = FindField(out_config, kBoolFlags, argv[i]);
+    if (bool_field != NULL) {
+      *bool_field = true;
       continue;
     }
 
-    for (j = 0; j < kNumStringFlags; j++) {
-      if (strcmp(argv[i], kStringFlags[j].flag) == 0) {
-        break;
-      }
-    }
-    if (j < kNumStringFlags) {
+    std::string *string_field = FindField(out_config, kStringFlags, argv[i]);
+    if (string_field != NULL) {
       i++;
       if (i >= argc) {
         fprintf(stderr, "Missing parameter\n");
         return false;
       }
-      out_config->*(kStringFlags[j].member) = argv[i];
+      string_field->assign(argv[i]);
       continue;
     }
 
-    for (j = 0; j < kNumBase64Flags; j++) {
-      if (strcmp(argv[i], kBase64Flags[j].flag) == 0) {
-        break;
-      }
-    }
-    if (j < kNumBase64Flags) {
+    std::string *base64_field = FindField(out_config, kBase64Flags, argv[i]);
+    if (base64_field != NULL) {
       i++;
       if (i >= argc) {
         fprintf(stderr, "Missing parameter\n");
@@ -178,8 +173,18 @@
                             strlen(argv[i]))) {
         fprintf(stderr, "Invalid base64: %s\n", argv[i]);
       }
-      out_config->*(kBase64Flags[j].member) = std::string(
-          reinterpret_cast<const char *>(decoded.get()), len);
+      base64_field->assign(reinterpret_cast<const char *>(decoded.get()), len);
+      continue;
+    }
+
+    int *int_field = FindField(out_config, kIntFlags, argv[i]);
+    if (int_field) {
+      i++;
+      if (i >= argc) {
+        fprintf(stderr, "Missing parameter\n");
+        return false;
+      }
+      *int_field = atoi(argv[i]);
       continue;
     }
 
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index f778c28..c7fd136 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -64,6 +64,8 @@
   bool enable_signed_cert_timestamps;
   std::string expected_signed_cert_timestamps;
   bool fastradio_padding;
+  int min_version;
+  int max_version;
 };
 
 bool ParseConfig(int argc, char **argv, TestConfig *out_config);