Revise version negotiation on the Go half.

This is in preparation for supporting multiple TLS 1.3 variants.

Change-Id: Ia2caf984f576f1b9e5915bdaf6ff952c8be10417
Reviewed-on: https://boringssl-review.googlesource.com/17526
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index bd490d3..9bd9c77 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -34,6 +34,19 @@
 // A draft version of TLS 1.3 that is sent over the wire for the current draft.
 const tls13DraftVersion = 0x7f12
 
+var allTLSWireVersions = []uint16{
+	tls13DraftVersion,
+	VersionTLS12,
+	VersionTLS11,
+	VersionTLS10,
+	VersionSSL30,
+}
+
+var allDTLSWireVersions = []uint16{
+	VersionDTLS12,
+	VersionDTLS10,
+}
+
 const (
 	maxPlaintext        = 16384        // maximum plaintext payload length
 	maxCiphertext       = 16384 + 2048 // maximum ciphertext payload length
@@ -630,12 +643,12 @@
 	SendSupportedVersions []uint16
 
 	// NegotiateVersion, if non-zero, causes the server to negotiate the
-	// specifed TLS version rather than the version supported by either
+	// specifed wire version rather than the version supported by either
 	// peer.
 	NegotiateVersion uint16
 
 	// NegotiateVersionOnRenego, if non-zero, causes the server to negotiate
-	// the specified TLS version on renegotiation rather than retaining it.
+	// the specified wire version on renegotiation rather than retaining it.
 	NegotiateVersionOnRenego uint16
 
 	// ExpectFalseStart causes the server to, on full handshakes,
@@ -1443,10 +1456,29 @@
 	return defaultCurves
 }
 
-// isSupportedVersion returns true if the specified protocol version is
-// acceptable.
-func (c *Config) isSupportedVersion(vers uint16, isDTLS bool) bool {
-	return c.minVersion(isDTLS) <= vers && vers <= c.maxVersion(isDTLS)
+// isSupportedVersion checks if the specified wire version is acceptable. If so,
+// it returns true and the corresponding protocol version. Otherwise, it returns
+// false.
+func (c *Config) isSupportedVersion(wireVers uint16, isDTLS bool) (uint16, bool) {
+	vers, ok := wireToVersion(wireVers, isDTLS)
+	if !ok || c.minVersion(isDTLS) > vers || vers > c.maxVersion(isDTLS) {
+		return 0, false
+	}
+	return vers, true
+}
+
+func (c *Config) supportedVersions(isDTLS bool) []uint16 {
+	versions := allTLSWireVersions
+	if isDTLS {
+		versions = allDTLSWireVersions
+	}
+	var ret []uint16
+	for _, vers := range versions {
+		if _, ok := c.isSupportedVersion(vers, isDTLS); ok {
+			ret = append(ret, vers)
+		}
+	}
+	return ret
 }
 
 // getCertificateForName returns the best certificate for the given name,
@@ -1722,3 +1754,12 @@
 	downgradeTLS13 = []byte{0x44, 0x4f, 0x57, 0x4e, 0x47, 0x52, 0x44, 0x01}
 	downgradeTLS12 = []byte{0x44, 0x4f, 0x57, 0x4e, 0x47, 0x52, 0x44, 0x00}
 )
+
+func containsGREASE(values []uint16) bool {
+	for _, v := range values {
+		if isGREASEValue(v) {
+			return true
+		}
+	}
+	return false
+}
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index fce0049..61fc9d3 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -35,6 +35,7 @@
 	// constant after handshake; protected by handshakeMutex
 	handshakeMutex       sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
 	handshakeErr         error      // error resulting from handshake
+	wireVersion          uint16     // TLS wire version
 	vers                 uint16     // TLS version
 	haveVers             bool       // version has been negotiated
 	config               *Config    // configuration passed to constructor
diff --git a/ssl/test/runner/dtls.go b/ssl/test/runner/dtls.go
index e273bc7..d46b247 100644
--- a/ssl/test/runner/dtls.go
+++ b/ssl/test/runner/dtls.go
@@ -23,32 +23,12 @@
 	"net"
 )
 
-func versionToWire(vers uint16, isDTLS bool) uint16 {
-	if isDTLS {
-		switch vers {
-		case VersionTLS12:
-			return 0xfefd
-		case VersionTLS10:
-			return 0xfeff
-		}
-	} else {
-		switch vers {
-		case VersionSSL30, VersionTLS10, VersionTLS11, VersionTLS12:
-			return vers
-		case VersionTLS13:
-			return tls13DraftVersion
-		}
-	}
-
-	panic("unknown version")
-}
-
 func wireToVersion(vers uint16, isDTLS bool) (uint16, bool) {
 	if isDTLS {
 		switch vers {
-		case 0xfefd:
+		case VersionDTLS12:
 			return VersionTLS12, true
-		case 0xfeff:
+		case VersionDTLS10:
 			return VersionTLS10, true
 		}
 	} else {
@@ -102,9 +82,9 @@
 	// version is irrelevant.)
 	if typ != recordTypeAlert {
 		if c.haveVers {
-			if wireVers := versionToWire(c.vers, c.isDTLS); vers != wireVers {
+			if vers != c.wireVersion {
 				c.sendAlert(alertProtocolVersion)
-				return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, wireVers))
+				return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.wireVersion))
 			}
 		} else {
 			// Pre-version-negotiation alerts may be sent with any version.
@@ -368,13 +348,16 @@
 	// TODO(nharper): DTLS 1.3 will likely need to set this to
 	// recordTypeApplicationData if c.out.cipher != nil.
 	b.data[0] = byte(typ)
-	vers := c.vers
+	vers := c.wireVersion
 	if vers == 0 {
 		// Some TLS servers fail if the record version is greater than
 		// TLS 1.0 for the initial ClientHello.
-		vers = VersionTLS10
+		if c.isDTLS {
+			vers = VersionDTLS10
+		} else {
+			vers = VersionTLS10
+		}
 	}
-	vers = versionToWire(vers, c.isDTLS)
 	b.data[1] = byte(vers >> 8)
 	b.data[2] = byte(vers)
 	// DTLS records include an explicit sequence number.
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index c531a28..f0bfca4 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -35,6 +35,21 @@
 	finishedBytes []byte
 }
 
+func mapClientHelloVersion(vers uint16, isDTLS bool) uint16 {
+	if !isDTLS {
+		return vers
+	}
+
+	switch vers {
+	case VersionTLS12:
+		return VersionDTLS12
+	case VersionTLS10:
+		return VersionDTLS10
+	}
+
+	panic("Unknown ClientHello version.")
+}
+
 func (c *Conn) clientHandshake() error {
 	if c.config == nil {
 		c.config = defaultConfig()
@@ -63,7 +78,6 @@
 	maxVersion := c.config.maxVersion(c.isDTLS)
 	hello := &clientHelloMsg{
 		isDTLS:                  c.isDTLS,
-		vers:                    versionToWire(maxVersion, c.isDTLS),
 		compressionMethods:      []uint8{compressionNone},
 		random:                  make([]byte, 32),
 		ocspStapling:            !c.config.Bugs.NoOCSPStapling,
@@ -85,6 +99,23 @@
 		pskBinderFirst:          c.config.Bugs.PSKBinderFirst,
 	}
 
+	if maxVersion >= VersionTLS13 {
+		hello.vers = mapClientHelloVersion(VersionTLS12, c.isDTLS)
+		if !c.config.Bugs.OmitSupportedVersions {
+			hello.supportedVersions = c.config.supportedVersions(c.isDTLS)
+		}
+	} else {
+		hello.vers = mapClientHelloVersion(maxVersion, c.isDTLS)
+	}
+
+	if c.config.Bugs.SendClientVersion != 0 {
+		hello.vers = c.config.Bugs.SendClientVersion
+	}
+
+	if len(c.config.Bugs.SendSupportedVersions) > 0 {
+		hello.supportedVersions = c.config.Bugs.SendSupportedVersions
+	}
+
 	disableEMS := c.config.Bugs.NoExtendedMasterSecret
 	if c.cipherSuite != nil {
 		disableEMS = c.config.Bugs.NoExtendedMasterSecretOnRenegotiation
@@ -310,23 +341,6 @@
 		}
 	}
 
-	if maxVersion == VersionTLS13 && !c.config.Bugs.OmitSupportedVersions {
-		if hello.vers >= VersionTLS13 {
-			hello.vers = VersionTLS12
-		}
-		for version := maxVersion; version >= minVersion; version-- {
-			hello.supportedVersions = append(hello.supportedVersions, versionToWire(version, c.isDTLS))
-		}
-	}
-
-	if len(c.config.Bugs.SendSupportedVersions) > 0 {
-		hello.supportedVersions = c.config.Bugs.SendSupportedVersions
-	}
-
-	if c.config.Bugs.SendClientVersion != 0 {
-		hello.vers = c.config.Bugs.SendClientVersion
-	}
-
 	if c.config.Bugs.SendCipherSuites != nil {
 		hello.cipherSuites = c.config.Bugs.SendCipherSuites
 	}
@@ -409,7 +423,7 @@
 	if c.isDTLS {
 		helloVerifyRequest, ok := msg.(*helloVerifyRequestMsg)
 		if ok {
-			if helloVerifyRequest.vers != versionToWire(VersionTLS10, c.isDTLS) {
+			if helloVerifyRequest.vers != VersionDTLS10 {
 				// Per RFC 6347, the version field in
 				// HelloVerifyRequest SHOULD be always DTLS
 				// 1.0. Enforce this for testing purposes.
@@ -443,14 +457,12 @@
 		return fmt.Errorf("tls: received unexpected message of type %T when waiting for HelloRetryRequest or ServerHello", msg)
 	}
 
-	serverVersion, ok := wireToVersion(serverWireVersion, c.isDTLS)
-	if ok {
-		ok = c.config.isSupportedVersion(serverVersion, c.isDTLS)
-	}
+	serverVersion, ok := c.config.isSupportedVersion(serverWireVersion, c.isDTLS)
 	if !ok {
 		c.sendAlert(alertProtocolVersion)
 		return fmt.Errorf("tls: server selected unsupported protocol version %x", c.vers)
 	}
+	c.wireVersion = serverWireVersion
 	c.vers = serverVersion
 	c.haveVers = true
 
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 3a182ec..a29a812 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -161,7 +161,7 @@
 		// Per RFC 6347, the version field in HelloVerifyRequest SHOULD
 		// be always DTLS 1.0
 		helloVerifyRequest := &helloVerifyRequestMsg{
-			vers:   versionToWire(VersionTLS10, c.isDTLS),
+			vers:   VersionDTLS10,
 			cookie: make([]byte, 32),
 		}
 		if _, err := io.ReadFull(c.config.rand(), helloVerifyRequest.cookie); err != nil {
@@ -210,74 +210,69 @@
 
 	c.clientVersion = hs.clientHello.vers
 
-	// Convert the ClientHello wire version to a protocol version.
-	var clientVersion uint16
-	if c.isDTLS {
-		if hs.clientHello.vers <= 0xfefd {
-			clientVersion = VersionTLS12
-		} else if hs.clientHello.vers <= 0xfeff {
-			clientVersion = VersionTLS10
+	// Use the versions extension if supplied, otherwise use the legacy ClientHello version.
+	if len(hs.clientHello.supportedVersions) == 0 {
+		if c.isDTLS {
+			if hs.clientHello.vers <= VersionDTLS12 {
+				hs.clientHello.supportedVersions = append(hs.clientHello.supportedVersions, VersionDTLS12)
+			}
+			if hs.clientHello.vers <= VersionDTLS10 {
+				hs.clientHello.supportedVersions = append(hs.clientHello.supportedVersions, VersionDTLS10)
+			}
+		} else {
+			if hs.clientHello.vers >= VersionTLS12 {
+				hs.clientHello.supportedVersions = append(hs.clientHello.supportedVersions, VersionTLS12)
+			}
+			if hs.clientHello.vers >= VersionTLS11 {
+				hs.clientHello.supportedVersions = append(hs.clientHello.supportedVersions, VersionTLS11)
+			}
+			if hs.clientHello.vers >= VersionTLS10 {
+				hs.clientHello.supportedVersions = append(hs.clientHello.supportedVersions, VersionTLS10)
+			}
+			if hs.clientHello.vers >= VersionSSL30 {
+				hs.clientHello.supportedVersions = append(hs.clientHello.supportedVersions, VersionSSL30)
+			}
 		}
-	} else {
-		if hs.clientHello.vers >= VersionTLS12 {
-			clientVersion = VersionTLS12
-		} else if hs.clientHello.vers >= VersionTLS11 {
-			clientVersion = VersionTLS11
-		} else if hs.clientHello.vers >= VersionTLS10 {
-			clientVersion = VersionTLS10
-		} else if hs.clientHello.vers >= VersionSSL30 {
-			clientVersion = VersionSSL30
-		}
+	} else if config.Bugs.ExpectGREASE && !containsGREASE(hs.clientHello.supportedVersions) {
+		return errors.New("tls: no GREASE version value found")
 	}
 
-	if config.Bugs.NegotiateVersion != 0 {
-		c.vers = config.Bugs.NegotiateVersion
-	} else if c.haveVers && config.Bugs.NegotiateVersionOnRenego != 0 {
-		c.vers = config.Bugs.NegotiateVersionOnRenego
-	} else if len(hs.clientHello.supportedVersions) > 0 {
-		// Use the versions extension if supplied.
-		var foundVersion, foundGREASE bool
-		for _, extVersion := range hs.clientHello.supportedVersions {
-			if isGREASEValue(extVersion) {
-				foundGREASE = true
+	if !c.haveVers {
+		if config.Bugs.NegotiateVersion != 0 {
+			c.wireVersion = config.Bugs.NegotiateVersion
+		} else {
+			var found bool
+			for _, vers := range hs.clientHello.supportedVersions {
+				if _, ok := config.isSupportedVersion(vers, c.isDTLS); ok {
+					c.wireVersion = vers
+					found = true
+					break
+				}
 			}
-			extVersion, ok = wireToVersion(extVersion, c.isDTLS)
-			if !ok {
-				continue
-			}
-			if config.isSupportedVersion(extVersion, c.isDTLS) && !foundVersion {
-				c.vers = extVersion
-				foundVersion = true
-				break
+			if !found {
+				c.sendAlert(alertProtocolVersion)
+				return errors.New("tls: client did not offer any supported protocol versions")
 			}
 		}
-		if !foundVersion {
-			c.sendAlert(alertProtocolVersion)
-			return errors.New("tls: client did not offer any supported protocol versions")
-		}
-		if config.Bugs.ExpectGREASE && !foundGREASE {
-			return errors.New("tls: no GREASE version value found")
-		}
-	} else {
-		// Otherwise, use the legacy ClientHello version.
-		version := clientVersion
-		if maxVersion := config.maxVersion(c.isDTLS); version > maxVersion {
-			version = maxVersion
-		}
-		if version == 0 || !config.isSupportedVersion(version, c.isDTLS) {
-			return fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers)
-		}
-		c.vers = version
+	} else if config.Bugs.NegotiateVersionOnRenego != 0 {
+		c.wireVersion = config.Bugs.NegotiateVersionOnRenego
+	}
+
+	c.vers, ok = wireToVersion(c.wireVersion, c.isDTLS)
+	if !ok {
+		panic("Could not map wire version")
 	}
 	c.haveVers = true
 
+	clientProtocol, ok := wireToVersion(c.clientVersion, c.isDTLS)
+
 	// Reject < 1.2 ClientHellos with signature_algorithms.
-	if clientVersion < VersionTLS12 && len(hs.clientHello.signatureAlgorithms) > 0 {
+	if ok && clientProtocol < VersionTLS12 && len(hs.clientHello.signatureAlgorithms) > 0 {
 		return fmt.Errorf("tls: client included signature_algorithms before TLS 1.2")
 	}
 
 	// Check the client cipher list is consistent with the version.
-	if clientVersion < VersionTLS12 {
+	if ok && clientProtocol < VersionTLS12 {
 		for _, id := range hs.clientHello.cipherSuites {
 			if isTLS12Cipher(id) {
 				return fmt.Errorf("tls: client offered TLS 1.2 cipher before TLS 1.2")
@@ -298,13 +293,11 @@
 		return fmt.Errorf("tls: client offered unexpected PSK identities")
 	}
 
-	var scsvFound, greaseFound bool
+	var scsvFound bool
 	for _, cipherSuite := range hs.clientHello.cipherSuites {
 		if cipherSuite == fallbackSCSV {
 			scsvFound = true
-		}
-		if isGREASEValue(cipherSuite) {
-			greaseFound = true
+			break
 		}
 	}
 
@@ -314,11 +307,11 @@
 		return errors.New("tls: fallback SCSV found when not expected")
 	}
 
-	if !greaseFound && config.Bugs.ExpectGREASE {
+	if config.Bugs.ExpectGREASE && !containsGREASE(hs.clientHello.cipherSuites) {
 		return errors.New("tls: no GREASE cipher suite value found")
 	}
 
-	greaseFound = false
+	var greaseFound bool
 	for _, curve := range hs.clientHello.supportedCurves {
 		if isGREASEValue(uint16(curve)) {
 			greaseFound = true
@@ -367,7 +360,7 @@
 
 	hs.hello = &serverHelloMsg{
 		isDTLS:          c.isDTLS,
-		vers:            versionToWire(c.vers, c.isDTLS),
+		vers:            c.wireVersion,
 		versOverride:    config.Bugs.SendServerHelloVersion,
 		customExtension: config.Bugs.CustomUnencryptedExtension,
 		unencryptedALPN: config.Bugs.SendUnencryptedALPN,
@@ -526,7 +519,7 @@
 ResendHelloRetryRequest:
 	var sendHelloRetryRequest bool
 	helloRetryRequest := &helloRetryRequestMsg{
-		vers:                versionToWire(c.vers, c.isDTLS),
+		vers:                c.wireVersion,
 		duplicateExtensions: config.Bugs.DuplicateHelloRetryRequestExtensions,
 	}
 
@@ -1049,7 +1042,7 @@
 
 	hs.hello = &serverHelloMsg{
 		isDTLS:            c.isDTLS,
-		vers:              versionToWire(c.vers, c.isDTLS),
+		vers:              c.wireVersion,
 		versOverride:      config.Bugs.SendServerHelloVersion,
 		compressionMethod: compressionNone,
 	}
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index ba6cc54..7ba0c08 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -263,6 +263,31 @@
 	panic("Unknown test certificate")
 }
 
+// configVersionToWire maps a protocol version to the default wire version to
+// test at that protocol.
+//
+// TODO(davidben): Rather than mapping these, make tlsVersions contains a list
+// of wire versions and test all of them.
+func configVersionToWire(vers uint16, protocol protocol) uint16 {
+	if protocol == dtls {
+		switch vers {
+		case VersionTLS12:
+			return VersionDTLS12
+		case VersionTLS10:
+			return VersionDTLS10
+		}
+	} else {
+		switch vers {
+		case VersionSSL30, VersionTLS10, VersionTLS11, VersionTLS12:
+			return vers
+		case VersionTLS13:
+			return tls13DraftVersion
+		}
+	}
+
+	panic("unknown version")
+}
+
 // encodeDERValues encodes a series of bytestrings in comma-separated-hex form.
 func encodeDERValues(values [][]byte) string {
 	var ret string
@@ -4577,12 +4602,12 @@
 				if clientVers > VersionTLS10 {
 					clientVers = VersionTLS10
 				}
-				clientVers = versionToWire(clientVers, protocol == dtls)
+				clientVers = configVersionToWire(clientVers, protocol)
 				serverVers := expectedVersion
 				if expectedVersion >= VersionTLS13 {
 					serverVers = VersionTLS10
 				}
-				serverVers = versionToWire(serverVers, protocol == dtls)
+				serverVers = configVersionToWire(serverVers, protocol)
 
 				testCases = append(testCases, testCase{
 					protocol: protocol,
@@ -4653,7 +4678,7 @@
 				suffix += "-DTLS"
 			}
 
-			wireVersion := versionToWire(vers.version, protocol == dtls)
+			wireVersion := configVersionToWire(vers.version, protocol)
 			testCases = append(testCases, testCase{
 				protocol: protocol,
 				testType: serverTest,
@@ -4926,7 +4951,7 @@
 							// Ensure the server does not decline to
 							// select a version (versions extension) or
 							// cipher (some ciphers depend on versions).
-							NegotiateVersion:            runnerVers.version,
+							NegotiateVersion:            configVersionToWire(runnerVers.version, protocol),
 							IgnorePeerCipherPreferences: shouldFail,
 						},
 					},
@@ -4946,7 +4971,7 @@
 							// Ensure the server does not decline to
 							// select a version (versions extension) or
 							// cipher (some ciphers depend on versions).
-							NegotiateVersion:            runnerVers.version,
+							NegotiateVersion:            configVersionToWire(runnerVers.version, protocol),
 							IgnorePeerCipherPreferences: shouldFail,
 						},
 					},