runner: Add a helper to read and downcast a message
We can't always use this because sometimes (particular in TLS 1.2), you
have to account for optional messages, but often we know exactly which
message we're expecting.
Change-Id: I4f6f59111fbf3e5f8a8fefa35802def9b2029196
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72949
Reviewed-by: Nick Harper <nharper@chromium.org>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index 28c4f72..e634449 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -1528,6 +1528,19 @@
return m, nil
}
+func readHandshakeType[T any](c *Conn) (*T, error) {
+ m, err := c.readHandshake()
+ if err != nil {
+ return nil, err
+ }
+ mType, ok := m.(*T)
+ if !ok {
+ c.sendAlert(alertUnexpectedMessage)
+ return nil, unexpectedMessageError(mType, m)
+ }
+ return mType, nil
+}
+
// skipPacket processes all the DTLS records in packet. It updates
// sequence number expectations but otherwise ignores them.
func (c *Conn) skipPacket(packet []byte) error {
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index b7f1817..c71c295 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -1152,16 +1152,10 @@
return err
}
- msg, err := c.readHandshake()
+ encryptedExtensions, err := readHandshakeType[encryptedExtensionsMsg](c)
if err != nil {
return err
}
-
- encryptedExtensions, ok := msg.(*encryptedExtensionsMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(encryptedExtensions, msg)
- }
hs.writeServerHash(encryptedExtensions.marshal())
if !bytes.Equal(encryptedExtensions.extensions.echRetryConfigs, c.config.Bugs.ExpectECHRetryConfigs) {
@@ -1277,15 +1271,10 @@
c.ocspResponse = certMsg.certificates[0].ocspResponse
c.sctList = certMsg.certificates[0].sctList
- msg, err = c.readHandshake()
+ certVerifyMsg, err := readHandshakeType[certificateVerifyMsg](c)
if err != nil {
return err
}
- certVerifyMsg, ok := msg.(*certificateVerifyMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(certVerifyMsg, msg)
- }
c.peerSignatureAlgorithm = certVerifyMsg.signatureAlgorithm
input := hs.finishedHash.certificateVerifyInput(serverCertificateVerifyContextTLS13)
@@ -1301,16 +1290,10 @@
hs.writeServerHash(certVerifyMsg.marshal())
}
- msg, err = c.readHandshake()
+ serverFinished, err := readHandshakeType[finishedMsg](c)
if err != nil {
return err
}
- serverFinished, ok := msg.(*finishedMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(serverFinished, msg)
- }
-
verify := hs.finishedHash.serverSum(serverHandshakeTrafficSecret)
if len(verify) != len(serverFinished.verifyData) ||
subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
@@ -1341,14 +1324,10 @@
// BoringSSL will always send two tickets half-RTT when
// negotiating 0-RTT.
for i := 0; i < shimConfig.HalfRTTTickets; i++ {
- msg, err := c.readHandshake()
+ newSessionTicket, err := readHandshakeType[newSessionTicketMsg](c)
if err != nil {
return fmt.Errorf("tls: error reading half-RTT ticket: %s", err)
}
- newSessionTicket, ok := msg.(*newSessionTicketMsg)
- if !ok {
- return errors.New("tls: expected half-RTT ticket")
- }
// Defer processing until the resumption secret is computed.
deferredTickets = append(deferredTickets, newSessionTicket)
}
@@ -1622,16 +1601,10 @@
var leaf *x509.Certificate
if hs.suite.flags&suitePSK == 0 {
- msg, err := c.readHandshake()
+ certMsg, err := readHandshakeType[certificateMsg](c)
if err != nil {
return err
}
-
- certMsg, ok := msg.(*certificateMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(certMsg, msg)
- }
hs.writeServerHash(certMsg.marshal())
if err := hs.verifyCertificates(certMsg); err != nil {
@@ -1641,15 +1614,10 @@
}
if hs.serverHello.extensions.ocspStapling {
- msg, err := c.readHandshake()
+ cs, err := readHandshakeType[certificateStatusMsg](c)
if err != nil {
return err
}
- cs, ok := msg.(*certificateStatusMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(cs, msg)
- }
hs.writeServerHash(cs.marshal())
if cs.statusType == statusTypeOCSP {
@@ -2176,15 +2144,10 @@
return err
}
- msg, err := c.readHandshake()
+ serverFinished, err := readHandshakeType[finishedMsg](c)
if err != nil {
return err
}
- serverFinished, ok := msg.(*finishedMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(serverFinished, msg)
- }
if c.config.Bugs.EarlyChangeCipherSpec == 0 {
verify := hs.finishedHash.serverSum(hs.masterSecret)
@@ -2233,15 +2196,10 @@
return errors.New("tls: received unexpected NewSessionTicket")
}
- msg, err := c.readHandshake()
+ sessionTicketMsg, err := readHandshakeType[newSessionTicketMsg](c)
if err != nil {
return err
}
- sessionTicketMsg, ok := msg.(*newSessionTicketMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(sessionTicketMsg, msg)
- }
if c.config.Bugs.ExpectNoNonEmptyNewSessionTicket && len(sessionTicketMsg.ticket) != 0 {
return errors.New("tls: received unexpected non-empty NewSessionTicket")
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 96490cb..9dddd13 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -166,16 +166,11 @@
config := hs.c.config
c := hs.c
- msg, err := c.readHandshake()
+ var err error
+ hs.clientHello, err = readHandshakeType[clientHelloMsg](c)
if err != nil {
return err
}
- var ok bool
- hs.clientHello, ok = msg.(*clientHelloMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(hs.clientHello, msg)
- }
if size := config.Bugs.RequireClientHelloSize; size != 0 && len(hs.clientHello.raw) != size {
return fmt.Errorf("tls: ClientHello record size is %d, but expected %d", len(hs.clientHello.raw), size)
}
@@ -290,6 +285,7 @@
c.wireVersion = config.Bugs.NegotiateVersionOnRenego
}
+ var ok bool
c.vers, ok = wireToVersion(c.wireVersion, c.isDTLS)
if !ok {
panic("Could not map wire version")
@@ -320,15 +316,10 @@
return err
}
- msg, err := c.readHandshake()
+ newClientHello, err := readHandshakeType[clientHelloMsg](c)
if err != nil {
return err
}
- newClientHello, ok := msg.(*clientHelloMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(hs.clientHello, msg)
- }
if !bytes.Equal(newClientHello.cookie, helloVerifyRequest.cookie) {
return errors.New("dtls: invalid cookie")
}
@@ -799,15 +790,10 @@
}
// Read new ClientHello.
- newMsg, err := c.readHandshake()
+ newClientHello, err := readHandshakeType[clientHelloMsg](c)
if err != nil {
return err
}
- newClientHello, ok := newMsg.(*clientHelloMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(newClientHello, newMsg)
- }
if expected := config.Bugs.ExpectOuterServerName; len(expected) != 0 && expected != newClientHello.serverName {
return fmt.Errorf("tls: unexpected ClientHelloOuter server name: wanted %q, got %q", expected, newClientHello.serverName)
@@ -1252,15 +1238,10 @@
c.input.Reset()
}
if c.usesEndOfEarlyData() {
- msg, err := c.readHandshake()
+ endOfEarlyData, err := readHandshakeType[endOfEarlyDataMsg](c)
if err != nil {
return err
}
- endOfEarlyData, ok := msg.(*endOfEarlyDataMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(endOfEarlyData, msg)
- }
hs.writeClientHash(endOfEarlyData.marshal())
}
}
@@ -1272,15 +1253,10 @@
// If we sent an ALPS extension, the client must respond with a single EncryptedExtensions.
if encryptedExtensions.extensions.hasApplicationSettings || encryptedExtensions.extensions.hasApplicationSettingsOld {
- msg, err := c.readHandshake()
+ clientEncryptedExtensions, err := readHandshakeType[clientEncryptedExtensionsMsg](c)
if err != nil {
return err
}
- clientEncryptedExtensions, ok := msg.(*clientEncryptedExtensionsMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(clientEncryptedExtensions, msg)
- }
hs.writeClientHash(clientEncryptedExtensions.marshal())
// Expect client send new application settings not old.
@@ -1317,16 +1293,10 @@
// If we requested a client certificate, then the client must send a
// certificate message, even if it's empty.
if config.ClientAuth >= RequestClientCert {
- msg, err := c.readHandshake()
+ certMsg, err := readHandshakeType[certificateMsg](c)
if err != nil {
return err
}
-
- certMsg, ok := msg.(*certificateMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(certMsg, msg)
- }
hs.writeClientHash(certMsg.marshal())
if len(certMsg.certificates) == 0 {
@@ -1354,17 +1324,10 @@
}
if len(c.peerCertificates) > 0 {
- msg, err = c.readHandshake()
+ certVerify, err := readHandshakeType[certificateVerifyMsg](c)
if err != nil {
return err
}
-
- certVerify, ok := msg.(*certificateVerifyMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(certVerify, msg)
- }
-
c.peerSignatureAlgorithm = certVerify.signatureAlgorithm
input := hs.finishedHash.certificateVerifyInput(clientCertificateVerifyContextTLS13)
if err := verifyMessage(c.isClient, c.vers, pub, config, certVerify.signatureAlgorithm, input, certVerify.signature); err != nil {
@@ -1376,15 +1339,10 @@
}
if encryptedExtensions.extensions.channelIDRequested {
- msg, err := c.readHandshake()
+ channelIDMsg, err := readHandshakeType[channelIDMsg](c)
if err != nil {
return err
}
- channelIDMsg, ok := msg.(*channelIDMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(channelIDMsg, msg)
- }
channelIDHash := crypto.SHA256.New()
channelIDHash.Write(hs.finishedHash.certificateVerifyInput(channelIDContextTLS13))
channelID, err := verifyChannelIDMessage(channelIDMsg, channelIDHash.Sum(nil))
@@ -1397,16 +1355,10 @@
}
// Read the client Finished message.
- msg, err := c.readHandshake()
+ clientFinished, err := readHandshakeType[finishedMsg](c)
if err != nil {
return err
}
- clientFinished, ok := msg.(*finishedMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(clientFinished, msg)
- }
-
verify := hs.finishedHash.clientSum(clientHandshakeTrafficSecret)
if len(verify) != len(clientFinished.verifyData) ||
subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
@@ -1985,18 +1937,12 @@
var pub crypto.PublicKey // public key for client auth, if any
- msg, err := c.readHandshake()
- if err != nil {
- return err
- }
-
// If we requested a client certificate, then the client must send a
// certificate message, even if it's empty.
if config.ClientAuth >= RequestClientCert {
- certMsg, ok := msg.(*certificateMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(certMsg, msg)
+ certMsg, err := readHandshakeType[certificateMsg](c)
+ if err != nil {
+ return err
}
hs.writeClientHash(certMsg.marshal())
@@ -2018,18 +1964,12 @@
if err != nil {
return err
}
-
- msg, err = c.readHandshake()
- if err != nil {
- return err
- }
}
// Get client key exchange
- ckx, ok := msg.(*clientKeyExchangeMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(ckx, msg)
+ ckx, err := readHandshakeType[clientKeyExchangeMsg](c)
+ if err != nil {
+ return err
}
hs.writeClientHash(ckx.marshal())
@@ -2054,15 +1994,10 @@
// to the client's certificate. This allows us to verify that the client is in
// possession of the private key of the certificate.
if len(c.peerCertificates) > 0 {
- msg, err = c.readHandshake()
+ certVerify, err := readHandshakeType[certificateVerifyMsg](c)
if err != nil {
return err
}
- certVerify, ok := msg.(*certificateVerifyMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(certVerify, msg)
- }
// Determine the signature type.
var sigAlg signatureAlgorithm
@@ -2118,29 +2053,19 @@
}
if hs.hello.extensions.nextProtoNeg {
- msg, err := c.readHandshake()
+ nextProto, err := readHandshakeType[nextProtoMsg](c)
if err != nil {
return err
}
- nextProto, ok := msg.(*nextProtoMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(nextProto, msg)
- }
hs.writeClientHash(nextProto.marshal())
c.clientProtocol = nextProto.proto
}
if hs.hello.extensions.channelIDRequested {
- msg, err := c.readHandshake()
+ channelIDMsg, err := readHandshakeType[channelIDMsg](c)
if err != nil {
return err
}
- channelIDMsg, ok := msg.(*channelIDMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(channelIDMsg, msg)
- }
var resumeHash []byte
if isResume {
resumeHash = hs.sessionState.handshakeHash
@@ -2154,15 +2079,10 @@
hs.writeClientHash(channelIDMsg.marshal())
}
- msg, err := c.readHandshake()
+ clientFinished, err := readHandshakeType[finishedMsg](c)
if err != nil {
return err
}
- clientFinished, ok := msg.(*finishedMsg)
- if !ok {
- c.sendAlert(alertUnexpectedMessage)
- return unexpectedMessageError(clientFinished, msg)
- }
verify := hs.finishedHash.clientSum(hs.masterSecret)
if len(verify) != len(clientFinished.verifyData) ||