Implement HelloRetryRequest in the Go test runner.
Change-Id: Ibde837040d2332bc8570589ba5be9b32e774bfcf
Reviewed-on: https://boringssl-review.googlesource.com/8811
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index ccb17c9..fba180c 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -364,6 +364,12 @@
// be used.
CurvePreferences []CurveID
+ // DefaultCurves contains the curves for which public values will be sent
+ // in the initial ClientHello's KeyShare extension. If nil, all supported
+ // curves have public values sent; a non-nil empty slice sends an empty
+ // KeyShare extension. This field is ignored on servers.
+ DefaultCurves []CurveID
+
// ChannelID contains the ECDSA key for the client to use as
// its TLS Channel ID.
ChannelID *ecdsa.PrivateKey
@@ -1041,6 +1047,23 @@
return c.CurvePreferences
}
+// defaultCurves returns the set of curves for which the client sends public
+// values in its initial ClientHello's KeyShare extension: DefaultCurves when
+// it is non-nil, otherwise every curve returned by curvePreferences.
+func (c *Config) defaultCurves() map[CurveID]bool {
+ curves := c.curvePreferences()
+ // Test c for nil before reading c.DefaultCurves so a nil Config is never
+ // dereferenced. NOTE(review): relies on curvePreferences accepting nil c.
+ if c != nil && c.DefaultCurves != nil {
+ curves = c.DefaultCurves
+ }
+ defaultCurves := make(map[CurveID]bool, len(curves))
+ for _, curveID := range curves {
+ defaultCurves[curveID] = true
+ }
+ return defaultCurves
+}
+
// mutualVersion returns the protocol version to use given the advertised
// version of the peer.
func (c *Config) mutualVersion(vers uint16, isDTLS bool) (uint16, bool) {
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index a9e9231..3789b28 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -1137,6 +1137,8 @@
m = &serverHelloMsg{
isDTLS: c.isDTLS,
}
+ case typeHelloRetryRequest:
+ m = new(helloRetryRequestMsg)
case typeNewSessionTicket:
m = new(newSessionTicketMsg)
case typeEncryptedExtensions:
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index 52951d3..291fb75 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -105,13 +105,13 @@
var keyShares map[CurveID]ecdhCurve
if hello.vers >= VersionTLS13 {
- // Offer every supported curve in the initial ClientHello.
- //
- // TODO(davidben): For real code, default to a more conservative
- // set like P-256 and X25519. Make it configurable for tests to
- // stress the HelloRetryRequest logic when implemented.
keyShares = make(map[CurveID]ecdhCurve)
+ hello.hasKeyShares = true
+ curvesToSend := c.config.defaultCurves()
for _, curveID := range hello.supportedCurves {
+ if !curvesToSend[curveID] {
+ continue
+ }
curve, ok := curveForCurveID(curveID)
if !ok {
continue
@@ -314,19 +314,78 @@
}
}
- // TODO(davidben): Handle HelloRetryRequest.
+ var serverVersion uint16
+ switch m := msg.(type) {
+ case *helloRetryRequestMsg:
+ serverVersion = m.vers
+ case *serverHelloMsg:
+ serverVersion = m.vers
+ default:
+ c.sendAlert(alertUnexpectedMessage)
+ return fmt.Errorf("tls: received unexpected message of type %T when waiting for HelloRetryRequest or ServerHello", msg)
+ }
+
+ var ok bool
+ c.vers, ok = c.config.mutualVersion(serverVersion, c.isDTLS)
+ if !ok {
+ c.sendAlert(alertProtocolVersion)
+ return fmt.Errorf("tls: server selected unsupported protocol version %x", c.vers)
+ }
+ c.haveVers = true
+
+ helloRetryRequest, haveHelloRetryRequest := msg.(*helloRetryRequestMsg)
+ var secondHelloBytes []byte
+ if haveHelloRetryRequest {
+ var hrrCurveFound bool
+ group := helloRetryRequest.selectedGroup
+ for _, curveID := range hello.supportedCurves {
+ if group == curveID {
+ hrrCurveFound = true
+ break
+ }
+ }
+ if !hrrCurveFound || keyShares[group] != nil {
+ c.sendAlert(alertHandshakeFailure)
+ return errors.New("tls: received invalid HelloRetryRequest")
+ }
+ curve, ok := curveForCurveID(group)
+ if !ok {
+ return errors.New("tls: Unable to get curve requested in HelloRetryRequest")
+ }
+ publicKey, err := curve.offer(c.config.rand())
+ if err != nil {
+ return err
+ }
+ keyShares[group] = curve
+ hello.keyShares = append(hello.keyShares, keyShareEntry{
+ group: group,
+ keyExchange: publicKey,
+ })
+
+ hello.hasEarlyData = false
+ hello.earlyDataContext = nil
+ hello.raw = nil
+
+ secondHelloBytes = hello.marshal()
+ c.writeRecord(recordTypeHandshake, secondHelloBytes)
+ c.flushHandshake()
+
+ msg, err = c.readHandshake()
+ if err != nil {
+ return err
+ }
+ }
+
serverHello, ok := msg.(*serverHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(serverHello, msg)
}
- c.vers, ok = c.config.mutualVersion(serverHello.vers, c.isDTLS)
- if !ok {
+ if c.vers != serverHello.vers {
c.sendAlert(alertProtocolVersion)
- return fmt.Errorf("tls: server selected unsupported protocol version %x", serverHello.vers)
+ return fmt.Errorf("tls: server sent non-matching version %x vs %x", serverHello.vers, c.vers)
}
- c.haveVers = true
// Check for downgrade signals in the server random, per
// draft-ietf-tls-tls13-14, section 6.3.1.2.
@@ -349,6 +408,11 @@
return fmt.Errorf("tls: server selected an unsupported cipher suite")
}
+ if haveHelloRetryRequest && (helloRetryRequest.cipherSuite != serverHello.cipherSuite || helloRetryRequest.selectedGroup != serverHello.keyShare.group) {
+ c.sendAlert(alertHandshakeFailure)
+ return errors.New("tls: ServerHello parameters did not match HelloRetryRequest")
+ }
+
hs := &clientHandshakeState{
c: c,
serverHello: serverHello,
@@ -360,6 +424,10 @@
}
hs.writeHash(helloBytes, hs.c.sendHandshakeSeq-1)
+ if haveHelloRetryRequest {
+ hs.writeServerHash(helloRetryRequest.marshal())
+ hs.writeClientHash(secondHelloBytes)
+ }
hs.writeServerHash(hs.serverHello.marshal())
if c.vers >= VersionTLS13 {
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index b856344..5ede674 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -126,6 +126,7 @@
ocspStapling bool
supportedCurves []CurveID
supportedPoints []uint8
+ hasKeyShares bool
keyShares []keyShareEntry
pskIdentities [][]uint8
hasEarlyData bool
@@ -164,6 +165,7 @@
m.ocspStapling == m1.ocspStapling &&
eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
+ m.hasKeyShares == m1.hasKeyShares &&
eqKeyShareEntryLists(m.keyShares, m1.keyShares) &&
eqByteSlices(m.pskIdentities, m1.pskIdentities) &&
m.hasEarlyData == m1.hasEarlyData &&
@@ -274,7 +276,7 @@
supportedPoints.addU8(pointFormat)
}
}
- if len(m.keyShares) > 0 {
+ if m.hasKeyShares {
extensions.addU16(extensionKeyShare)
keyShareList := extensions.addU16LengthPrefixed()
@@ -549,6 +551,7 @@
return false
}
d := data[2:length]
+ m.hasKeyShares = true
for len(d) > 0 {
// The next KeyShareEntry contains a NamedGroup (2 bytes) and a
// key_exchange (2-byte length prefix with at least 1 byte of content).
@@ -1142,6 +1145,49 @@
return true
}
+// helloRetryRequestMsg is a HelloRetryRequest: the server's request that the
+// client resend its ClientHello with a key share for selectedGroup.
+type helloRetryRequestMsg struct {
+ raw []byte
+ vers uint16
+ cipherSuite uint16
+ selectedGroup CurveID
+}
+
+// marshal encodes the message with its handshake header, caching the result
+// in m.raw. The extensions block is always written empty.
+func (m *helloRetryRequestMsg) marshal() []byte {
+ if m.raw == nil {
+ msg := newByteBuilder()
+ msg.addU8(typeHelloRetryRequest)
+ body := msg.addU24LengthPrefixed()
+ body.addU16(m.vers)
+ body.addU16(m.cipherSuite)
+ body.addU16(uint16(m.selectedGroup))
+ // Extensions field. We have none to send.
+ body.addU16(0)
+ m.raw = msg.finish()
+ }
+ return m.raw
+}
+
+// unmarshal parses a HelloRetryRequest including its 4-byte handshake header.
+// Extensions are not interpreted; their declared length must exactly cover
+// the remainder of the message.
+func (m *helloRetryRequestMsg) unmarshal(data []byte) bool {
+ m.raw = data
+ // Header (4) + version (2) + cipher suite (2) + group (2) + extensions length (2).
+ if len(data) < 12 {
+ return false
+ }
+ body := data[4:]
+ m.vers = uint16(body[0])<<8 | uint16(body[1])
+ m.cipherSuite = uint16(body[2])<<8 | uint16(body[3])
+ m.selectedGroup = CurveID(body[4])<<8 | CurveID(body[5])
+ extensionsLen := int(body[6])<<8 | int(body[7])
+ return len(body[8:]) == extensionsLen
+}
+
type certificateMsg struct {
raw []byte
hasRequestContext bool
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index e9f94b8..300ab50 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -353,8 +353,46 @@
}
if selectedKeyShare == nil {
- // TODO(davidben,nharper): Implement HelloRetryRequest.
- return errors.New("tls: HelloRetryRequest not implemented")
+ // Send HelloRetryRequest.
+ helloRetryRequestMsg := helloRetryRequestMsg{
+ vers: c.vers,
+ cipherSuite: hs.hello.cipherSuite,
+ selectedGroup: selectedCurve,
+ }
+ hs.writeServerHash(helloRetryRequestMsg.marshal())
+ c.writeRecord(recordTypeHandshake, helloRetryRequestMsg.marshal())
+
+ // Read new ClientHello.
+ newMsg, err := c.readHandshake()
+ if err != nil {
+ return err
+ }
+ newClientHello, ok := newMsg.(*clientHelloMsg)
+ if !ok {
+ c.sendAlert(alertUnexpectedMessage)
+ return unexpectedMessageError(newClientHello, newMsg)
+ }
+ hs.writeClientHash(newClientHello.marshal())
+
+ // Check that the new ClientHello matches the old ClientHello, except for
+ // the addition of the new KeyShareEntry at the end of the list, and
+ // removing the EarlyDataIndication extension (if present).
+ newKeyShares := newClientHello.keyShares
+ if len(newKeyShares) == 0 || newKeyShares[len(newKeyShares)-1].group != selectedCurve {
+ return errors.New("tls: KeyShare from HelloRetryRequest not present in new ClientHello")
+ }
+ oldClientHelloCopy := *hs.clientHello
+ oldClientHelloCopy.raw = nil
+ oldClientHelloCopy.hasEarlyData = false
+ oldClientHelloCopy.earlyDataContext = nil
+ newClientHelloCopy := *newClientHello
+ newClientHelloCopy.raw = nil
+ newClientHelloCopy.keyShares = newKeyShares[:len(newKeyShares)-1]
+ if !oldClientHelloCopy.equal(&newClientHelloCopy) {
+ return errors.New("tls: new ClientHello does not match")
+ }
+
+ selectedKeyShare = &newKeyShares[len(newKeyShares)-1]
}
// Once a curve has been selected and a key share identified,