Implement HelloRetryRequest in the Go test runner.

Change-Id: Ibde837040d2332bc8570589ba5be9b32e774bfcf
Reviewed-on: https://boringssl-review.googlesource.com/8811
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index ccb17c9..fba180c 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -364,6 +364,12 @@
 	// be used.
 	CurvePreferences []CurveID
 
+	// DefaultCurves contains the elliptic curves for which public values will
+	// be sent in the ClientHello's KeyShare extension. If this value is nil,
+	// all supported curves will have public values sent. This field is ignored
+	// on servers.
+	DefaultCurves []CurveID
+
 	// ChannelID contains the ECDSA key for the client to use as
 	// its TLS Channel ID.
 	ChannelID *ecdsa.PrivateKey
@@ -1041,6 +1047,24 @@
 	return c.CurvePreferences
 }
 
+// defaultCurves returns the set of curves for which the client sends key
+// shares in its initial ClientHello: DefaultCurves if set, otherwise
+// curvePreferences(). c may be nil.
+func (c *Config) defaultCurves() map[CurveID]bool {
+	defaultCurves := make(map[CurveID]bool)
+	var curves []CurveID
+	if c != nil {
+		curves = c.DefaultCurves
+	}
+	if curves == nil {
+		curves = c.curvePreferences()
+	}
+	for _, curveID := range curves {
+		defaultCurves[curveID] = true
+	}
+	return defaultCurves
+}
+
 // mutualVersion returns the protocol version to use given the advertised
 // version of the peer.
 func (c *Config) mutualVersion(vers uint16, isDTLS bool) (uint16, bool) {
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index a9e9231..3789b28 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -1137,6 +1137,8 @@
 		m = &serverHelloMsg{
 			isDTLS: c.isDTLS,
 		}
+	case typeHelloRetryRequest:
+		m = new(helloRetryRequestMsg)
 	case typeNewSessionTicket:
 		m = new(newSessionTicketMsg)
 	case typeEncryptedExtensions:
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index 52951d3..291fb75 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -105,13 +105,13 @@
 
 	var keyShares map[CurveID]ecdhCurve
 	if hello.vers >= VersionTLS13 {
-		// Offer every supported curve in the initial ClientHello.
-		//
-		// TODO(davidben): For real code, default to a more conservative
-		// set like P-256 and X25519. Make it configurable for tests to
-		// stress the HelloRetryRequest logic when implemented.
 		keyShares = make(map[CurveID]ecdhCurve)
+		hello.hasKeyShares = true
+		curvesToSend := c.config.defaultCurves()
 		for _, curveID := range hello.supportedCurves {
+			if !curvesToSend[curveID] {
+				continue
+			}
 			curve, ok := curveForCurveID(curveID)
 			if !ok {
 				continue
@@ -314,19 +314,80 @@
 		}
 	}
 
-	// TODO(davidben): Handle HelloRetryRequest.
+	var serverVersion uint16
+	switch m := msg.(type) {
+	case *helloRetryRequestMsg:
+		serverVersion = m.vers
+	case *serverHelloMsg:
+		serverVersion = m.vers
+	default:
+		c.sendAlert(alertUnexpectedMessage)
+		return fmt.Errorf("tls: received unexpected message of type %T when waiting for HelloRetryRequest or ServerHello", msg)
+	}
+
+	var ok bool
+	c.vers, ok = c.config.mutualVersion(serverVersion, c.isDTLS)
+	if !ok {
+		c.sendAlert(alertProtocolVersion)
+		return fmt.Errorf("tls: server selected unsupported protocol version %x", serverVersion)
+	}
+	c.haveVers = true
+
+	helloRetryRequest, haveHelloRetryRequest := msg.(*helloRetryRequestMsg)
+	var secondHelloBytes []byte
+	if haveHelloRetryRequest {
+		var hrrCurveFound bool
+		group := helloRetryRequest.selectedGroup
+		for _, curveID := range hello.supportedCurves {
+			if group == curveID {
+				hrrCurveFound = true
+				break
+			}
+		}
+		// HelloRetryRequest is only valid in TLS 1.3, and keyShares is only
+		// allocated for TLS 1.3; writing to a nil map below would panic.
+		if c.vers < VersionTLS13 || keyShares == nil || !hrrCurveFound || keyShares[group] != nil {
+			c.sendAlert(alertHandshakeFailure)
+			return errors.New("tls: received invalid HelloRetryRequest")
+		}
+		curve, ok := curveForCurveID(group)
+		if !ok {
+			return errors.New("tls: Unable to get curve requested in HelloRetryRequest")
+		}
+		publicKey, err := curve.offer(c.config.rand())
+		if err != nil {
+			return err
+		}
+		keyShares[group] = curve
+		hello.keyShares = append(hello.keyShares, keyShareEntry{
+			group:       group,
+			keyExchange: publicKey,
+		})
+
+		hello.hasEarlyData = false
+		hello.earlyDataContext = nil
+		hello.raw = nil
+
+		secondHelloBytes = hello.marshal()
+		c.writeRecord(recordTypeHandshake, secondHelloBytes)
+		c.flushHandshake()
+
+		msg, err = c.readHandshake()
+		if err != nil {
+			return err
+		}
+	}
+
 	serverHello, ok := msg.(*serverHelloMsg)
 	if !ok {
 		c.sendAlert(alertUnexpectedMessage)
 		return unexpectedMessageError(serverHello, msg)
 	}
 
-	c.vers, ok = c.config.mutualVersion(serverHello.vers, c.isDTLS)
-	if !ok {
+	if c.vers != serverHello.vers {
 		c.sendAlert(alertProtocolVersion)
-		return fmt.Errorf("tls: server selected unsupported protocol version %x", serverHello.vers)
+		return fmt.Errorf("tls: server sent non-matching version %x vs %x", serverHello.vers, c.vers)
 	}
-	c.haveVers = true
 
 	// Check for downgrade signals in the server random, per
 	// draft-ietf-tls-tls13-14, section 6.3.1.2.
@@ -349,6 +410,11 @@
 		return fmt.Errorf("tls: server selected an unsupported cipher suite")
 	}
 
+	if haveHelloRetryRequest && (helloRetryRequest.cipherSuite != serverHello.cipherSuite || helloRetryRequest.selectedGroup != serverHello.keyShare.group) {
+		c.sendAlert(alertHandshakeFailure)
+		return errors.New("tls: ServerHello parameters did not match HelloRetryRequest")
+	}
+
 	hs := &clientHandshakeState{
 		c:            c,
 		serverHello:  serverHello,
@@ -360,6 +426,10 @@
 	}
 
 	hs.writeHash(helloBytes, hs.c.sendHandshakeSeq-1)
+	if haveHelloRetryRequest {
+		hs.writeServerHash(helloRetryRequest.marshal())
+		hs.writeClientHash(secondHelloBytes)
+	}
 	hs.writeServerHash(hs.serverHello.marshal())
 
 	if c.vers >= VersionTLS13 {
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index b856344..5ede674 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -126,6 +126,7 @@
 	ocspStapling            bool
 	supportedCurves         []CurveID
 	supportedPoints         []uint8
+	hasKeyShares            bool
 	keyShares               []keyShareEntry
 	pskIdentities           [][]uint8
 	hasEarlyData            bool
@@ -164,6 +165,7 @@
 		m.ocspStapling == m1.ocspStapling &&
 		eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
 		bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
+		m.hasKeyShares == m1.hasKeyShares &&
 		eqKeyShareEntryLists(m.keyShares, m1.keyShares) &&
 		eqByteSlices(m.pskIdentities, m1.pskIdentities) &&
 		m.hasEarlyData == m1.hasEarlyData &&
@@ -274,7 +276,7 @@
 			supportedPoints.addU8(pointFormat)
 		}
 	}
-	if len(m.keyShares) > 0 {
+	if m.hasKeyShares {
 		extensions.addU16(extensionKeyShare)
 		keyShareList := extensions.addU16LengthPrefixed()
 
@@ -549,6 +551,7 @@
 				return false
 			}
 			d := data[2:length]
+			m.hasKeyShares = true
 			for len(d) > 0 {
 				// The next KeyShareEntry contains a NamedGroup (2 bytes) and a
 				// key_exchange (2-byte length prefix with at least 1 byte of content).
@@ -1142,6 +1145,58 @@
 	return true
 }
 
+// helloRetryRequestMsg is the TLS 1.3 HelloRetryRequest handshake message,
+// with which a server asks the client for a second ClientHello that carries
+// a key share for selectedGroup.
+type helloRetryRequestMsg struct {
+	raw           []byte
+	vers          uint16
+	cipherSuite   uint16
+	selectedGroup CurveID
+}
+
+// marshal encodes the message, including the 4-byte handshake header, and
+// caches the encoding in m.raw. The extensions block is always empty.
+func (m *helloRetryRequestMsg) marshal() []byte {
+	if m.raw != nil {
+		return m.raw
+	}
+
+	retryRequestMsg := newByteBuilder()
+	retryRequestMsg.addU8(typeHelloRetryRequest)
+	retryRequest := retryRequestMsg.addU24LengthPrefixed()
+	retryRequest.addU16(m.vers)
+	retryRequest.addU16(m.cipherSuite)
+	retryRequest.addU16(uint16(m.selectedGroup))
+	// Extensions field. We have none to send.
+	retryRequest.addU16(0)
+
+	m.raw = retryRequestMsg.finish()
+	return m.raw
+}
+
+// unmarshal parses a complete HelloRetryRequest, including the 4-byte
+// handshake header. It returns false if data is truncated or if the
+// extensions length does not cover exactly the remaining bytes; the
+// extensions themselves are not parsed.
+func (m *helloRetryRequestMsg) unmarshal(data []byte) bool {
+	m.raw = data
+	// Header (4 bytes), then version, cipher suite, selected group and
+	// extensions length (2 bytes each).
+	if len(data) < 12 {
+		return false
+	}
+	m.vers = uint16(data[4])<<8 | uint16(data[5])
+	m.cipherSuite = uint16(data[6])<<8 | uint16(data[7])
+	m.selectedGroup = CurveID(data[8])<<8 | CurveID(data[9])
+	extLen := int(data[10])<<8 | int(data[11])
+	data = data[12:]
+	if len(data) != extLen {
+		return false
+	}
+	return true
+}
+
 type certificateMsg struct {
 	raw               []byte
 	hasRequestContext bool
 type certificateMsg struct {
 	raw               []byte
 	hasRequestContext bool
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index e9f94b8..300ab50 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -353,8 +353,49 @@
 		}
 
 		if selectedKeyShare == nil {
-			// TODO(davidben,nharper): Implement HelloRetryRequest.
-			return errors.New("tls: HelloRetryRequest not implemented")
+			// Send HelloRetryRequest, flushing it before waiting for the reply.
+			helloRetryRequest := helloRetryRequestMsg{
+				vers:          c.vers,
+				cipherSuite:   hs.hello.cipherSuite,
+				selectedGroup: selectedCurve,
+			}
+			hs.writeServerHash(helloRetryRequest.marshal())
+			c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal())
+			c.flushHandshake()
+
+			// Read new ClientHello.
+			newMsg, err := c.readHandshake()
+			if err != nil {
+				return err
+			}
+			newClientHello, ok := newMsg.(*clientHelloMsg)
+			if !ok {
+				c.sendAlert(alertUnexpectedMessage)
+				return unexpectedMessageError(newClientHello, newMsg)
+			}
+			hs.writeClientHash(newClientHello.marshal())
+
+			// Check that the new ClientHello matches the old ClientHello, except for
+			// the addition of the new KeyShareEntry at the end of the list, and
+			// removing the EarlyDataIndication extension (if present).
+			newKeyShares := newClientHello.keyShares
+			if len(newKeyShares) == 0 || newKeyShares[len(newKeyShares)-1].group != selectedCurve {
+				c.sendAlert(alertHandshakeFailure)
+				return errors.New("tls: KeyShare from HelloRetryRequest not present in new ClientHello")
+			}
+			oldClientHelloCopy := *hs.clientHello
+			oldClientHelloCopy.raw = nil
+			oldClientHelloCopy.hasEarlyData = false
+			oldClientHelloCopy.earlyDataContext = nil
+			newClientHelloCopy := *newClientHello
+			newClientHelloCopy.raw = nil
+			newClientHelloCopy.keyShares = newKeyShares[:len(newKeyShares)-1]
+			if !oldClientHelloCopy.equal(&newClientHelloCopy) {
+				c.sendAlert(alertHandshakeFailure)
+				return errors.New("tls: new ClientHello does not match")
+			}
+
+			selectedKeyShare = &newKeyShares[len(newKeyShares)-1]
 		}
 
 		// Once a curve has been selected and a key share identified,