runner: Add a helper to read and downcast a message

We can't always use this because sometimes (particular in TLS 1.2), you
have to account for optional messages, but often we know exactly which
message we're expecting.

Change-Id: I4f6f59111fbf3e5f8a8fefa35802def9b2029196
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72949
Reviewed-by: Nick Harper <nharper@chromium.org>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index 28c4f72..e634449 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -1528,6 +1528,19 @@
 	return m, nil
 }
 
+func readHandshakeType[T any](c *Conn) (*T, error) {
+	m, err := c.readHandshake()
+	if err != nil {
+		return nil, err
+	}
+	mType, ok := m.(*T)
+	if !ok {
+		c.sendAlert(alertUnexpectedMessage)
+		return nil, unexpectedMessageError(mType, m)
+	}
+	return mType, nil
+}
+
 // skipPacket processes all the DTLS records in packet. It updates
 // sequence number expectations but otherwise ignores them.
 func (c *Conn) skipPacket(packet []byte) error {
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index b7f1817..c71c295 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -1152,16 +1152,10 @@
 		return err
 	}
 
-	msg, err := c.readHandshake()
+	encryptedExtensions, err := readHandshakeType[encryptedExtensionsMsg](c)
 	if err != nil {
 		return err
 	}
-
-	encryptedExtensions, ok := msg.(*encryptedExtensionsMsg)
-	if !ok {
-		c.sendAlert(alertUnexpectedMessage)
-		return unexpectedMessageError(encryptedExtensions, msg)
-	}
 	hs.writeServerHash(encryptedExtensions.marshal())
 
 	if !bytes.Equal(encryptedExtensions.extensions.echRetryConfigs, c.config.Bugs.ExpectECHRetryConfigs) {
@@ -1277,15 +1271,10 @@
 		c.ocspResponse = certMsg.certificates[0].ocspResponse
 		c.sctList = certMsg.certificates[0].sctList
 
-		msg, err = c.readHandshake()
+		certVerifyMsg, err := readHandshakeType[certificateVerifyMsg](c)
 		if err != nil {
 			return err
 		}
-		certVerifyMsg, ok := msg.(*certificateVerifyMsg)
-		if !ok {
-			c.sendAlert(alertUnexpectedMessage)
-			return unexpectedMessageError(certVerifyMsg, msg)
-		}
 
 		c.peerSignatureAlgorithm = certVerifyMsg.signatureAlgorithm
 		input := hs.finishedHash.certificateVerifyInput(serverCertificateVerifyContextTLS13)
@@ -1301,16 +1290,10 @@
 		hs.writeServerHash(certVerifyMsg.marshal())
 	}
 
-	msg, err = c.readHandshake()
+	serverFinished, err := readHandshakeType[finishedMsg](c)
 	if err != nil {
 		return err
 	}
-	serverFinished, ok := msg.(*finishedMsg)
-	if !ok {
-		c.sendAlert(alertUnexpectedMessage)
-		return unexpectedMessageError(serverFinished, msg)
-	}
-
 	verify := hs.finishedHash.serverSum(serverHandshakeTrafficSecret)
 	if len(verify) != len(serverFinished.verifyData) ||
 		subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
@@ -1341,14 +1324,10 @@
 		// BoringSSL will always send two tickets half-RTT when
 		// negotiating 0-RTT.
 		for i := 0; i < shimConfig.HalfRTTTickets; i++ {
-			msg, err := c.readHandshake()
+			newSessionTicket, err := readHandshakeType[newSessionTicketMsg](c)
 			if err != nil {
 				return fmt.Errorf("tls: error reading half-RTT ticket: %s", err)
 			}
-			newSessionTicket, ok := msg.(*newSessionTicketMsg)
-			if !ok {
-				return errors.New("tls: expected half-RTT ticket")
-			}
 			// Defer processing until the resumption secret is computed.
 			deferredTickets = append(deferredTickets, newSessionTicket)
 		}
@@ -1622,16 +1601,10 @@
 
 	var leaf *x509.Certificate
 	if hs.suite.flags&suitePSK == 0 {
-		msg, err := c.readHandshake()
+		certMsg, err := readHandshakeType[certificateMsg](c)
 		if err != nil {
 			return err
 		}
-
-		certMsg, ok := msg.(*certificateMsg)
-		if !ok {
-			c.sendAlert(alertUnexpectedMessage)
-			return unexpectedMessageError(certMsg, msg)
-		}
 		hs.writeServerHash(certMsg.marshal())
 
 		if err := hs.verifyCertificates(certMsg); err != nil {
@@ -1641,15 +1614,10 @@
 	}
 
 	if hs.serverHello.extensions.ocspStapling {
-		msg, err := c.readHandshake()
+		cs, err := readHandshakeType[certificateStatusMsg](c)
 		if err != nil {
 			return err
 		}
-		cs, ok := msg.(*certificateStatusMsg)
-		if !ok {
-			c.sendAlert(alertUnexpectedMessage)
-			return unexpectedMessageError(cs, msg)
-		}
 		hs.writeServerHash(cs.marshal())
 
 		if cs.statusType == statusTypeOCSP {
@@ -2176,15 +2144,10 @@
 		return err
 	}
 
-	msg, err := c.readHandshake()
+	serverFinished, err := readHandshakeType[finishedMsg](c)
 	if err != nil {
 		return err
 	}
-	serverFinished, ok := msg.(*finishedMsg)
-	if !ok {
-		c.sendAlert(alertUnexpectedMessage)
-		return unexpectedMessageError(serverFinished, msg)
-	}
 
 	if c.config.Bugs.EarlyChangeCipherSpec == 0 {
 		verify := hs.finishedHash.serverSum(hs.masterSecret)
@@ -2233,15 +2196,10 @@
 		return errors.New("tls: received unexpected NewSessionTicket")
 	}
 
-	msg, err := c.readHandshake()
+	sessionTicketMsg, err := readHandshakeType[newSessionTicketMsg](c)
 	if err != nil {
 		return err
 	}
-	sessionTicketMsg, ok := msg.(*newSessionTicketMsg)
-	if !ok {
-		c.sendAlert(alertUnexpectedMessage)
-		return unexpectedMessageError(sessionTicketMsg, msg)
-	}
 
 	if c.config.Bugs.ExpectNoNonEmptyNewSessionTicket && len(sessionTicketMsg.ticket) != 0 {
 		return errors.New("tls: received unexpected non-empty NewSessionTicket")
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 96490cb..9dddd13 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -166,16 +166,11 @@
 	config := hs.c.config
 	c := hs.c
 
-	msg, err := c.readHandshake()
+	var err error
+	hs.clientHello, err = readHandshakeType[clientHelloMsg](c)
 	if err != nil {
 		return err
 	}
-	var ok bool
-	hs.clientHello, ok = msg.(*clientHelloMsg)
-	if !ok {
-		c.sendAlert(alertUnexpectedMessage)
-		return unexpectedMessageError(hs.clientHello, msg)
-	}
 	if size := config.Bugs.RequireClientHelloSize; size != 0 && len(hs.clientHello.raw) != size {
 		return fmt.Errorf("tls: ClientHello record size is %d, but expected %d", len(hs.clientHello.raw), size)
 	}
@@ -290,6 +285,7 @@
 		c.wireVersion = config.Bugs.NegotiateVersionOnRenego
 	}
 
+	var ok bool
 	c.vers, ok = wireToVersion(c.wireVersion, c.isDTLS)
 	if !ok {
 		panic("Could not map wire version")
@@ -320,15 +316,10 @@
 			return err
 		}
 
-		msg, err := c.readHandshake()
+		newClientHello, err := readHandshakeType[clientHelloMsg](c)
 		if err != nil {
 			return err
 		}
-		newClientHello, ok := msg.(*clientHelloMsg)
-		if !ok {
-			c.sendAlert(alertUnexpectedMessage)
-			return unexpectedMessageError(hs.clientHello, msg)
-		}
 		if !bytes.Equal(newClientHello.cookie, helloVerifyRequest.cookie) {
 			return errors.New("dtls: invalid cookie")
 		}
@@ -799,15 +790,10 @@
 		}
 
 		// Read new ClientHello.
-		newMsg, err := c.readHandshake()
+		newClientHello, err := readHandshakeType[clientHelloMsg](c)
 		if err != nil {
 			return err
 		}
-		newClientHello, ok := newMsg.(*clientHelloMsg)
-		if !ok {
-			c.sendAlert(alertUnexpectedMessage)
-			return unexpectedMessageError(newClientHello, newMsg)
-		}
 
 		if expected := config.Bugs.ExpectOuterServerName; len(expected) != 0 && expected != newClientHello.serverName {
 			return fmt.Errorf("tls: unexpected ClientHelloOuter server name: wanted %q, got %q", expected, newClientHello.serverName)
@@ -1252,15 +1238,10 @@
 			c.input.Reset()
 		}
 		if c.usesEndOfEarlyData() {
-			msg, err := c.readHandshake()
+			endOfEarlyData, err := readHandshakeType[endOfEarlyDataMsg](c)
 			if err != nil {
 				return err
 			}
-			endOfEarlyData, ok := msg.(*endOfEarlyDataMsg)
-			if !ok {
-				c.sendAlert(alertUnexpectedMessage)
-				return unexpectedMessageError(endOfEarlyData, msg)
-			}
 			hs.writeClientHash(endOfEarlyData.marshal())
 		}
 	}
@@ -1272,15 +1253,10 @@
 
 	// If we sent an ALPS extension, the client must respond with a single EncryptedExtensions.
 	if encryptedExtensions.extensions.hasApplicationSettings || encryptedExtensions.extensions.hasApplicationSettingsOld {
-		msg, err := c.readHandshake()
+		clientEncryptedExtensions, err := readHandshakeType[clientEncryptedExtensionsMsg](c)
 		if err != nil {
 			return err
 		}
-		clientEncryptedExtensions, ok := msg.(*clientEncryptedExtensionsMsg)
-		if !ok {
-			c.sendAlert(alertUnexpectedMessage)
-			return unexpectedMessageError(clientEncryptedExtensions, msg)
-		}
 		hs.writeClientHash(clientEncryptedExtensions.marshal())
 
 		// Expect client send new application settings not old.
@@ -1317,16 +1293,10 @@
 	// If we requested a client certificate, then the client must send a
 	// certificate message, even if it's empty.
 	if config.ClientAuth >= RequestClientCert {
-		msg, err := c.readHandshake()
+		certMsg, err := readHandshakeType[certificateMsg](c)
 		if err != nil {
 			return err
 		}
-
-		certMsg, ok := msg.(*certificateMsg)
-		if !ok {
-			c.sendAlert(alertUnexpectedMessage)
-			return unexpectedMessageError(certMsg, msg)
-		}
 		hs.writeClientHash(certMsg.marshal())
 
 		if len(certMsg.certificates) == 0 {
@@ -1354,17 +1324,10 @@
 		}
 
 		if len(c.peerCertificates) > 0 {
-			msg, err = c.readHandshake()
+			certVerify, err := readHandshakeType[certificateVerifyMsg](c)
 			if err != nil {
 				return err
 			}
-
-			certVerify, ok := msg.(*certificateVerifyMsg)
-			if !ok {
-				c.sendAlert(alertUnexpectedMessage)
-				return unexpectedMessageError(certVerify, msg)
-			}
-
 			c.peerSignatureAlgorithm = certVerify.signatureAlgorithm
 			input := hs.finishedHash.certificateVerifyInput(clientCertificateVerifyContextTLS13)
 			if err := verifyMessage(c.isClient, c.vers, pub, config, certVerify.signatureAlgorithm, input, certVerify.signature); err != nil {
@@ -1376,15 +1339,10 @@
 	}
 
 	if encryptedExtensions.extensions.channelIDRequested {
-		msg, err := c.readHandshake()
+		channelIDMsg, err := readHandshakeType[channelIDMsg](c)
 		if err != nil {
 			return err
 		}
-		channelIDMsg, ok := msg.(*channelIDMsg)
-		if !ok {
-			c.sendAlert(alertUnexpectedMessage)
-			return unexpectedMessageError(channelIDMsg, msg)
-		}
 		channelIDHash := crypto.SHA256.New()
 		channelIDHash.Write(hs.finishedHash.certificateVerifyInput(channelIDContextTLS13))
 		channelID, err := verifyChannelIDMessage(channelIDMsg, channelIDHash.Sum(nil))
@@ -1397,16 +1355,10 @@
 	}
 
 	// Read the client Finished message.
-	msg, err := c.readHandshake()
+	clientFinished, err := readHandshakeType[finishedMsg](c)
 	if err != nil {
 		return err
 	}
-	clientFinished, ok := msg.(*finishedMsg)
-	if !ok {
-		c.sendAlert(alertUnexpectedMessage)
-		return unexpectedMessageError(clientFinished, msg)
-	}
-
 	verify := hs.finishedHash.clientSum(clientHandshakeTrafficSecret)
 	if len(verify) != len(clientFinished.verifyData) ||
 		subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
@@ -1985,18 +1937,12 @@
 
 	var pub crypto.PublicKey // public key for client auth, if any
 
-	msg, err := c.readHandshake()
-	if err != nil {
-		return err
-	}
-
 	// If we requested a client certificate, then the client must send a
 	// certificate message, even if it's empty.
 	if config.ClientAuth >= RequestClientCert {
-		certMsg, ok := msg.(*certificateMsg)
-		if !ok {
-			c.sendAlert(alertUnexpectedMessage)
-			return unexpectedMessageError(certMsg, msg)
+		certMsg, err := readHandshakeType[certificateMsg](c)
+		if err != nil {
+			return err
 		}
 		hs.writeClientHash(certMsg.marshal())
 
@@ -2018,18 +1964,12 @@
 		if err != nil {
 			return err
 		}
-
-		msg, err = c.readHandshake()
-		if err != nil {
-			return err
-		}
 	}
 
 	// Get client key exchange
-	ckx, ok := msg.(*clientKeyExchangeMsg)
-	if !ok {
-		c.sendAlert(alertUnexpectedMessage)
-		return unexpectedMessageError(ckx, msg)
+	ckx, err := readHandshakeType[clientKeyExchangeMsg](c)
+	if err != nil {
+		return err
 	}
 	hs.writeClientHash(ckx.marshal())
 
@@ -2054,15 +1994,10 @@
 	// to the client's certificate. This allows us to verify that the client is in
 	// possession of the private key of the certificate.
 	if len(c.peerCertificates) > 0 {
-		msg, err = c.readHandshake()
+		certVerify, err := readHandshakeType[certificateVerifyMsg](c)
 		if err != nil {
 			return err
 		}
-		certVerify, ok := msg.(*certificateVerifyMsg)
-		if !ok {
-			c.sendAlert(alertUnexpectedMessage)
-			return unexpectedMessageError(certVerify, msg)
-		}
 
 		// Determine the signature type.
 		var sigAlg signatureAlgorithm
@@ -2118,29 +2053,19 @@
 	}
 
 	if hs.hello.extensions.nextProtoNeg {
-		msg, err := c.readHandshake()
+		nextProto, err := readHandshakeType[nextProtoMsg](c)
 		if err != nil {
 			return err
 		}
-		nextProto, ok := msg.(*nextProtoMsg)
-		if !ok {
-			c.sendAlert(alertUnexpectedMessage)
-			return unexpectedMessageError(nextProto, msg)
-		}
 		hs.writeClientHash(nextProto.marshal())
 		c.clientProtocol = nextProto.proto
 	}
 
 	if hs.hello.extensions.channelIDRequested {
-		msg, err := c.readHandshake()
+		channelIDMsg, err := readHandshakeType[channelIDMsg](c)
 		if err != nil {
 			return err
 		}
-		channelIDMsg, ok := msg.(*channelIDMsg)
-		if !ok {
-			c.sendAlert(alertUnexpectedMessage)
-			return unexpectedMessageError(channelIDMsg, msg)
-		}
 		var resumeHash []byte
 		if isResume {
 			resumeHash = hs.sessionState.handshakeHash
@@ -2154,15 +2079,10 @@
 		hs.writeClientHash(channelIDMsg.marshal())
 	}
 
-	msg, err := c.readHandshake()
+	clientFinished, err := readHandshakeType[finishedMsg](c)
 	if err != nil {
 		return err
 	}
-	clientFinished, ok := msg.(*finishedMsg)
-	if !ok {
-		c.sendAlert(alertUnexpectedMessage)
-		return unexpectedMessageError(clientFinished, msg)
-	}
 
 	verify := hs.finishedHash.clientSum(hs.masterSecret)
 	if len(verify) != len(clientFinished.verifyData) ||