Implement basic TLS 1.3 client handshake in Go.
[Originally written by nharper and then revised by davidben.]
Most features are missing, but it works for a start. To avoid breaking
the fake TLS 1.3 tests before the C code has landed, all of the new
logic is gated on a global boolean, enableTLS13Handshake. Once the C
code lands, we'll set it to true and then remove the boolean.
Change-Id: I6b3a369890864c26203fc9cda37c8250024ce91b
Reviewed-on: https://boringssl-review.googlesource.com/8601
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/test/runner/alert.go b/ssl/test/runner/alert.go
index 8144b50..89c907f 100644
--- a/ssl/test/runner/alert.go
+++ b/ssl/test/runner/alert.go
@@ -40,6 +40,7 @@
alertUserCanceled alert = 90
alertNoRenegotiation alert = 100
alertMissingExtension alert = 109
+ alertUnsupportedExtension alert = 110
)
var alertText = map[alert]string{
@@ -67,6 +68,7 @@
alertUserCanceled: "user canceled",
alertNoRenegotiation: "no renegotiation",
alertMissingExtension: "missing extension",
+ alertUnsupportedExtension: "unsupported extension",
}
func (e alert) String() string {
diff --git a/ssl/test/runner/cipher_suites.go b/ssl/test/runner/cipher_suites.go
index ad2a638..ab22905 100644
--- a/ssl/test/runner/cipher_suites.go
+++ b/ssl/test/runner/cipher_suites.go
@@ -5,6 +5,7 @@
package runner
import (
+ "crypto"
"crypto/aes"
"crypto/cipher"
"crypto/des"
@@ -93,6 +94,13 @@
aead func(version uint16, key, fixedNonce []byte) *tlsAead
}
+func (cs cipherSuite) hash() crypto.Hash {
+	if cs.flags&suiteSHA384 == 0 { // SHA-256 unless the suite is flagged SHA-384
+		return crypto.SHA256
+	}
+	return crypto.SHA384
+}
+
var cipherSuites = []*cipherSuite{
// Ciphersuite order is chosen so that ECDHE comes before plain RSA
// and RC4 comes before AES (because of the Lucky13 attack).
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index eabd0c6..db3f270 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -18,6 +18,9 @@
"time"
)
+// TODO(davidben): Flip this to true when the C code lands.
+const enableTLS13Handshake = false
+
const (
VersionSSL30 = 0x0300
VersionTLS10 = 0x0301
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index c33ac0c..601c731 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -199,6 +199,13 @@
return nil
}
+// updateKeys installs cipher and version as this half connection's current
+// cipher state, then calls incEpoch.
+func (hc *halfConn) updateKeys(cipher interface{}, version uint16) {
+	hc.cipher, hc.version = cipher, version
+	hc.incEpoch()
+}
+
// incSeq increments the sequence number.
func (hc *halfConn) incSeq(isOutgoing bool) {
limit := 0
@@ -1114,11 +1121,16 @@
}
case typeNewSessionTicket:
m = new(newSessionTicketMsg)
+ case typeEncryptedExtensions:
+ m = new(encryptedExtensionsMsg)
case typeCertificate:
- m = new(certificateMsg)
+ m = &certificateMsg{
+ hasRequestContext: c.vers >= VersionTLS13 && enableTLS13Handshake,
+ }
case typeCertificateRequest:
m = &certificateRequestMsg{
hasSignatureAlgorithm: c.vers >= VersionTLS12,
+ hasRequestContext: c.vers >= VersionTLS13 && enableTLS13Handshake,
}
case typeCertificateStatus:
m = new(certificateStatusMsg)
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index a2f6f65..56c49d9 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -26,6 +26,7 @@
hello *clientHelloMsg
suite *cipherSuite
finishedHash finishedHash
+ keyShares map[CurveID]ecdhCurve
masterSecret []byte
session *ClientSessionState
finishedBytes []byte
@@ -102,6 +103,31 @@
hello.secureRenegotiation = nil
}
+ var keyShares map[CurveID]ecdhCurve
+ if hello.vers >= VersionTLS13 && enableTLS13Handshake {
+ // Offer every supported curve in the initial ClientHello.
+ //
+ // TODO(davidben): For real code, default to a more conservative
+ // set like P-256 and X25519. Make it configurable for tests to
+ // stress the HelloRetryRequest logic when implemented.
+ keyShares = make(map[CurveID]ecdhCurve)
+ for _, curveID := range hello.supportedCurves {
+ curve, ok := curveForCurveID(curveID)
+ if !ok {
+ continue
+ }
+ publicKey, err := curve.offer(c.config.rand())
+ if err != nil {
+ return err
+ }
+ hello.keyShares = append(hello.keyShares, keyShareEntry{
+ group: curveID,
+ keyExchange: publicKey,
+ })
+ keyShares[curveID] = curve
+ }
+ }
+
possibleCipherSuites := c.config.cipherSuites()
hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites))
@@ -261,6 +287,7 @@
}
}
+ // TODO(davidben): Handle HelloRetryRequest.
serverHello, ok := msg.(*serverHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
@@ -286,82 +313,90 @@
hello: hello,
suite: suite,
finishedHash: newFinishedHash(c.vers, suite),
+ keyShares: keyShares,
session: session,
}
hs.writeHash(helloBytes, hs.c.sendHandshakeSeq-1)
hs.writeServerHash(hs.serverHello.marshal())
- if c.config.Bugs.EarlyChangeCipherSpec > 0 {
- hs.establishKeys()
- c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
- }
-
- if hs.serverHello.compressionMethod != compressionNone {
- c.sendAlert(alertUnexpectedMessage)
- return errors.New("tls: server selected unsupported compression format")
- }
-
- err = hs.processServerExtensions(&serverHello.extensions)
- if err != nil {
- return err
- }
-
- isResume, err := hs.processServerHello()
- if err != nil {
- return err
- }
-
- if isResume {
- if c.config.Bugs.EarlyChangeCipherSpec == 0 {
- if err := hs.establishKeys(); err != nil {
- return err
- }
- }
- if err := hs.readSessionTicket(); err != nil {
- return err
- }
- if err := hs.readFinished(c.firstFinished[:]); err != nil {
- return err
- }
- if err := hs.sendFinished(nil, isResume); err != nil {
+ if c.vers >= VersionTLS13 && enableTLS13Handshake {
+ if err := hs.doTLS13Handshake(); err != nil {
return err
}
} else {
- if err := hs.doFullHandshake(); err != nil {
+ if c.config.Bugs.EarlyChangeCipherSpec > 0 {
+ hs.establishKeys()
+ c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
+ }
+
+ if hs.serverHello.compressionMethod != compressionNone {
+ c.sendAlert(alertUnexpectedMessage)
+ return errors.New("tls: server selected unsupported compression format")
+ }
+
+ err = hs.processServerExtensions(&serverHello.extensions)
+ if err != nil {
return err
}
- if err := hs.establishKeys(); err != nil {
+
+ isResume, err := hs.processServerHello()
+ if err != nil {
return err
}
- if err := hs.sendFinished(c.firstFinished[:], isResume); err != nil {
- return err
+
+ if isResume {
+ if c.config.Bugs.EarlyChangeCipherSpec == 0 {
+ if err := hs.establishKeys(); err != nil {
+ return err
+ }
+ }
+ if err := hs.readSessionTicket(); err != nil {
+ return err
+ }
+ if err := hs.readFinished(c.firstFinished[:]); err != nil {
+ return err
+ }
+ if err := hs.sendFinished(nil, isResume); err != nil {
+ return err
+ }
+ } else {
+ if err := hs.doFullHandshake(); err != nil {
+ return err
+ }
+ if err := hs.establishKeys(); err != nil {
+ return err
+ }
+ if err := hs.sendFinished(c.firstFinished[:], isResume); err != nil {
+ return err
+ }
+ // Most retransmits are triggered by a timeout, but the final
+ // leg of the handshake is retransmited upon re-receiving a
+ // Finished.
+ if err := c.simulatePacketLoss(func() {
+ c.writeRecord(recordTypeHandshake, hs.finishedBytes)
+ c.flushHandshake()
+ }); err != nil {
+ return err
+ }
+ if err := hs.readSessionTicket(); err != nil {
+ return err
+ }
+ if err := hs.readFinished(nil); err != nil {
+ return err
+ }
}
- // Most retransmits are triggered by a timeout, but the final
- // leg of the handshake is retransmited upon re-receiving a
- // Finished.
- if err := c.simulatePacketLoss(func() {
- c.writeRecord(recordTypeHandshake, hs.finishedBytes)
- c.flushHandshake()
- }); err != nil {
- return err
+
+ if sessionCache != nil && hs.session != nil && session != hs.session {
+ if c.config.Bugs.RequireSessionTickets && len(hs.session.sessionTicket) == 0 {
+ return errors.New("tls: new session used session IDs instead of tickets")
+ }
+ sessionCache.Put(cacheKey, hs.session)
}
- if err := hs.readSessionTicket(); err != nil {
- return err
- }
- if err := hs.readFinished(nil); err != nil {
- return err
- }
+
+ c.didResume = isResume
}
- if sessionCache != nil && hs.session != nil && session != hs.session {
- if c.config.Bugs.RequireSessionTickets && len(hs.session.sessionTicket) == 0 {
- return errors.New("tls: new session used session IDs instead of tickets")
- }
- sessionCache.Put(cacheKey, hs.session)
- }
-
- c.didResume = isResume
c.handshakeComplete = true
c.cipherSuite = suite
copy(c.clientRandom[:], hs.hello.random)
@@ -371,6 +406,202 @@
return nil
}
+// doTLS13Handshake completes a TLS 1.3 client handshake after the ServerHello.
+func (hs *clientHandshakeState) doTLS13Handshake() error {
+	c := hs.c
+
+	// Once the PRF hash is known, TLS 1.3 does not require a handshake
+	// buffer.
+	hs.finishedHash.discardHandshakeBuffer()
+
+	zeroSecret := hs.finishedHash.zeroSecret()
+
+	// Resolve PSK and compute the early secret.
+	//
+	// TODO(davidben): This will need to be handled slightly earlier once
+	// 0-RTT is implemented.
+	if hs.suite.flags&suitePSK != 0 {
+		if !hs.serverHello.hasPSKIdentity {
+			c.sendAlert(alertMissingExtension)
+			return errors.New("tls: server omitted the PSK identity extension")
+		}
+
+		// TODO(davidben): Support PSK ciphers and PSK resumption. Set
+		// the resumption context appropriately if resuming.
+		return errors.New("tls: PSK ciphers not implemented for TLS 1.3")
+	}
+
+	// Without a PSK suite, a PSK identity is an error and the PSK is zero.
+	if hs.serverHello.hasPSKIdentity {
+		c.sendAlert(alertUnsupportedExtension)
+		return errors.New("tls: server sent unexpected PSK identity")
+	}
+
+	psk := zeroSecret
+	hs.finishedHash.setResumptionContext(zeroSecret)
+
+	earlySecret := hs.finishedHash.extractKey(zeroSecret, psk)
+
+	// Resolve ECDHE and compute the handshake secret.
+	var ecdheSecret []byte
+	if hs.suite.flags&suiteECDHE != 0 {
+		if !hs.serverHello.hasKeyShare {
+			c.sendAlert(alertMissingExtension)
+			return errors.New("tls: server omitted the key share extension")
+		}
+
+		curve, ok := hs.keyShares[hs.serverHello.keyShare.group]
+		if !ok {
+			c.sendAlert(alertHandshakeFailure)
+			return errors.New("tls: server selected an unsupported group")
+		}
+
+		var err error
+		ecdheSecret, err = curve.finish(hs.serverHello.keyShare.keyExchange)
+		if err != nil {
+			return err
+		}
+	} else {
+		if hs.serverHello.hasKeyShare {
+			c.sendAlert(alertUnsupportedExtension)
+			return errors.New("tls: server sent unexpected key share extension")
+		}
+
+		ecdheSecret = zeroSecret
+	}
+
+	handshakeSecret := hs.finishedHash.extractKey(earlySecret, ecdheSecret)
+
+	// Switch to handshake traffic keys.
+	handshakeTrafficSecret := hs.finishedHash.deriveSecret(handshakeSecret, handshakeTrafficLabel)
+	c.out.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite), c.vers)
+	c.in.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite), c.vers)
+
+	msg, err := c.readHandshake()
+	if err != nil {
+		return err
+	}
+
+	encryptedExtensions, ok := msg.(*encryptedExtensionsMsg)
+	if !ok {
+		c.sendAlert(alertUnexpectedMessage)
+		return unexpectedMessageError(encryptedExtensions, msg)
+	}
+	hs.writeServerHash(encryptedExtensions.marshal())
+
+	err = hs.processServerExtensions(&encryptedExtensions.extensions)
+	if err != nil {
+		return err
+	}
+
+	var chainToSend *Certificate
+	var certRequested bool
+	var certRequestContext []byte
+	if hs.suite.flags&suitePSK == 0 {
+		// TODO(davidben): Save OCSP response and SCT list. Forbid them
+		// if not negotiating a certificate-based extension.
+
+		msg, err = c.readHandshake()
+		if err != nil {
+			return err
+		}
+
+		certReq, ok := msg.(*certificateRequestMsg)
+		if ok {
+			hs.writeServerHash(certReq.marshal())
+			certRequested = true
+			certRequestContext = certReq.requestContext
+
+			chainToSend, err = selectClientCertificate(c, certReq)
+			if err != nil {
+				return err
+			}
+
+			msg, err = c.readHandshake()
+			if err != nil {
+				return err
+			}
+		}
+
+		certMsg, ok := msg.(*certificateMsg)
+		if !ok {
+			c.sendAlert(alertUnexpectedMessage)
+			return unexpectedMessageError(certMsg, msg)
+		}
+		hs.writeServerHash(certMsg.marshal())
+
+		if err := hs.verifyCertificates(certMsg); err != nil {
+			return err
+		}
+		leaf := c.peerCertificates[0]
+
+		msg, err = c.readHandshake()
+		if err != nil {
+			return err
+		}
+		certVerifyMsg, ok := msg.(*certificateVerifyMsg)
+		if !ok {
+			c.sendAlert(alertUnexpectedMessage)
+			return unexpectedMessageError(certVerifyMsg, msg)
+		}
+
+		input := hs.finishedHash.certificateVerifyInput(serverCertificateVerifyContextTLS13)
+		err = verifyMessage(c.vers, leaf.PublicKey, certVerifyMsg.signatureAlgorithm, input, certVerifyMsg.signature)
+		if err != nil {
+			return err
+		}
+
+		hs.writeServerHash(certVerifyMsg.marshal())
+	}
+
+	msg, err = c.readHandshake()
+	if err != nil {
+		return err
+	}
+	serverFinished, ok := msg.(*finishedMsg)
+	if !ok {
+		c.sendAlert(alertUnexpectedMessage)
+		return unexpectedMessageError(serverFinished, msg)
+	}
+
+	verify := hs.finishedHash.serverSum(handshakeTrafficSecret)
+	if len(verify) != len(serverFinished.verifyData) ||
+		subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
+		c.sendAlert(alertHandshakeFailure)
+		return errors.New("tls: server's Finished message was incorrect")
+	}
+
+	hs.writeServerHash(serverFinished.marshal())
+
+	// The various secrets do not incorporate the client's final leg, so
+	// derive them now before updating the handshake context.
+	masterSecret := hs.finishedHash.extractKey(handshakeSecret, zeroSecret)
+	trafficSecret := hs.finishedHash.deriveSecret(masterSecret, applicationTrafficLabel)
+
+	if certRequested {
+		_ = chainToSend
+		_ = certRequestContext
+		return errors.New("tls: client auth not implemented.")
+	}
+
+	// Send a client Finished message.
+	finished := new(finishedMsg)
+	finished.verifyData = hs.finishedHash.clientSum(handshakeTrafficSecret)
+	if c.config.Bugs.BadFinished {
+		finished.verifyData[0]++
+	}
+	c.writeRecord(recordTypeHandshake, finished.marshal())
+
+	// Switch to application data keys.
+	c.out.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, trafficSecret, applicationPhase, clientWrite), c.vers)
+	c.in.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, trafficSecret, applicationPhase, serverWrite), c.vers)
+
+	// TODO(davidben): Derive and save the exporter master secret for key exporters. Swap out the masterSecret field.
+	// TODO(davidben): Derive and save the resumption master secret for receiving tickets.
+	// TODO(davidben): Save the traffic secret for KeyUpdate.
+	return nil
+}
+
func (hs *clientHandshakeState) doFullHandshake() error {
c := hs.c
@@ -633,17 +864,19 @@
func (hs *clientHandshakeState) processServerExtensions(serverExtensions *serverExtensions) error {
c := hs.c
- if c.config.Bugs.RequireRenegotiationInfo && serverExtensions.secureRenegotiation == nil {
- return errors.New("tls: renegotiation extension missing")
- }
+ if c.vers < VersionTLS13 || !enableTLS13Handshake {
+ if c.config.Bugs.RequireRenegotiationInfo && serverExtensions.secureRenegotiation == nil {
+ return errors.New("tls: renegotiation extension missing")
+ }
- if len(c.clientVerify) > 0 && !c.noRenegotiationInfo() {
- var expectedRenegInfo []byte
- expectedRenegInfo = append(expectedRenegInfo, c.clientVerify...)
- expectedRenegInfo = append(expectedRenegInfo, c.serverVerify...)
- if !bytes.Equal(serverExtensions.secureRenegotiation, expectedRenegInfo) {
- c.sendAlert(alertHandshakeFailure)
- return fmt.Errorf("tls: renegotiation mismatch")
+ if len(c.clientVerify) > 0 && !c.noRenegotiationInfo() {
+ var expectedRenegInfo []byte
+ expectedRenegInfo = append(expectedRenegInfo, c.clientVerify...)
+ expectedRenegInfo = append(expectedRenegInfo, c.serverVerify...)
+ if !bytes.Equal(serverExtensions.secureRenegotiation, expectedRenegInfo) {
+ c.sendAlert(alertHandshakeFailure)
+ return fmt.Errorf("tls: renegotiation mismatch")
+ }
}
}
@@ -679,11 +912,21 @@
c.usedALPN = true
}
+ if serverHasNPN && c.vers >= VersionTLS13 && enableTLS13Handshake {
+ c.sendAlert(alertHandshakeFailure)
+ return errors.New("server advertised NPN over TLS 1.3")
+ }
+
if !hs.hello.channelIDSupported && serverExtensions.channelIDRequested {
c.sendAlert(alertHandshakeFailure)
return errors.New("server advertised unrequested Channel ID extension")
}
+ if serverExtensions.channelIDRequested && c.vers >= VersionTLS13 && enableTLS13Handshake {
+ c.sendAlert(alertHandshakeFailure)
+ return errors.New("server advertised Channel ID over TLS 1.3")
+ }
+
if serverExtensions.srtpProtectionProfile != 0 {
if serverExtensions.srtpMasterKeyIdentifier != "" {
return errors.New("tls: server selected SRTP MKI value")
@@ -960,12 +1203,14 @@
// the contrary.
var rsaAvail, ecdsaAvail bool
- for _, certType := range certReq.certificateTypes {
- switch certType {
- case CertTypeRSASign:
- rsaAvail = true
- case CertTypeECDSASign:
- ecdsaAvail = true
+ if !certReq.hasRequestContext {
+ for _, certType := range certReq.certificateTypes {
+ switch certType {
+ case CertTypeRSASign:
+ rsaAvail = true
+ case CertTypeECDSASign:
+ ecdsaAvail = true
+ }
}
}
@@ -974,7 +1219,7 @@
// certReq.certificateAuthorities
findCert:
for i, chain := range c.config.Certificates {
- if !rsaAvail && !ecdsaAvail {
+ if !certReq.hasRequestContext && !rsaAvail && !ecdsaAvail {
continue
}
@@ -998,11 +1243,13 @@
}
}
- switch {
- case rsaAvail && x509Cert.PublicKeyAlgorithm == x509.RSA:
- case ecdsaAvail && x509Cert.PublicKeyAlgorithm == x509.ECDSA:
- default:
- continue findCert
+ if !certReq.hasRequestContext {
+ switch {
+ case rsaAvail && x509Cert.PublicKeyAlgorithm == x509.RSA:
+ case ecdsaAvail && x509Cert.PublicKeyAlgorithm == x509.ECDSA:
+ default:
+ continue findCert
+ }
}
if len(certReq.certificateAuthorities) == 0 {
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index 17fb5cb..d349ae7 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -111,11 +111,13 @@
}
type clientHelloMsg struct {
- raw []byte
- isDTLS bool
- vers uint16
- random []byte
- sessionId []byte
+ raw []byte
+ isDTLS bool
+ vers uint16
+ random []byte
+ sessionId []byte
+ // TODO(davidben): Add support for TLS 1.3 cookies which are larger and
+ // use an extension.
cookie []byte
cipherSuites []uint16
compressionMethods []uint8
@@ -681,14 +683,19 @@
}
type serverHelloMsg struct {
- raw []byte
- isDTLS bool
- vers uint16
- random []byte
- sessionId []byte
- cipherSuite uint16
- compressionMethod uint8
- extensions serverExtensions
+ raw []byte
+ isDTLS bool
+ vers uint16
+ random []byte
+ sessionId []byte
+ cipherSuite uint16
+ hasKeyShare bool
+ keyShare keyShareEntry
+ hasPSKIdentity bool
+ pskIdentity uint16
+ earlyDataIndication bool
+ compressionMethod uint8
+ extensions serverExtensions
}
func (m *serverHelloMsg) marshal() []byte {
@@ -702,17 +709,39 @@
vers := versionToWire(m.vers, m.isDTLS)
hello.addU16(vers)
hello.addBytes(m.random)
- sessionId := hello.addU8LengthPrefixed()
- sessionId.addBytes(m.sessionId)
+ if m.vers < VersionTLS13 || !enableTLS13Handshake {
+ sessionId := hello.addU8LengthPrefixed()
+ sessionId.addBytes(m.sessionId)
+ }
hello.addU16(m.cipherSuite)
- hello.addU8(m.compressionMethod)
+ if m.vers < VersionTLS13 || !enableTLS13Handshake {
+ hello.addU8(m.compressionMethod)
+ }
extensions := hello.addU16LengthPrefixed()
- m.extensions.marshal(extensions)
-
- if extensions.len() == 0 {
- hello.discardChild()
+ if m.vers >= VersionTLS13 && enableTLS13Handshake {
+ if m.hasKeyShare {
+ extensions.addU16(extensionKeyShare)
+ keyShare := extensions.addU16LengthPrefixed()
+ keyShare.addU16(uint16(m.keyShare.group))
+ keyExchange := keyShare.addU16LengthPrefixed()
+ keyExchange.addBytes(m.keyShare.keyExchange)
+ }
+ if m.hasPSKIdentity {
+ extensions.addU16(extensionPreSharedKey)
+ extensions.addU16(2) // Length
+ extensions.addU16(m.pskIdentity)
+ }
+ if m.earlyDataIndication {
+ extensions.addU16(extensionEarlyData)
+ extensions.addU16(0) // Length
+ }
+ } else {
+ m.extensions.marshal(extensions)
+ if extensions.len() == 0 {
+ hello.discardChild()
+ }
}
m.raw = handshakeMsg.finish()
@@ -726,21 +755,30 @@
m.raw = data
m.vers = wireToVersion(uint16(data[4])<<8|uint16(data[5]), m.isDTLS)
m.random = data[6:38]
- sessionIdLen := int(data[38])
- if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
- return false
+ data = data[38:]
+ if m.vers < VersionTLS13 || !enableTLS13Handshake {
+ sessionIdLen := int(data[0])
+ if sessionIdLen > 32 || len(data) < 1+sessionIdLen {
+ return false
+ }
+ m.sessionId = data[1 : 1+sessionIdLen]
+ data = data[1+sessionIdLen:]
}
- m.sessionId = data[39 : 39+sessionIdLen]
- data = data[39+sessionIdLen:]
- if len(data) < 3 {
+ if len(data) < 2 {
return false
}
m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
- m.compressionMethod = data[2]
- data = data[3:]
+ data = data[2:]
+ if m.vers < VersionTLS13 || !enableTLS13Handshake {
+ if len(data) < 1 {
+ return false
+ }
+ m.compressionMethod = data[0]
+ data = data[1:]
+ }
- if len(data) == 0 {
- // ServerHello is optionally followed by extension data
+ if len(data) == 0 && (m.vers < VersionTLS13 || !enableTLS13Handshake) {
+ // Extension data is optional unless the real TLS 1.3 handshake is in use.
m.extensions = serverExtensions{}
return true
}
@@ -754,13 +792,98 @@
return false
}
- if !m.extensions.unmarshal(data) {
+ if m.vers >= VersionTLS13 && enableTLS13Handshake {
+ for len(data) != 0 {
+ if len(data) < 4 {
+ return false
+ }
+ extension := uint16(data[0])<<8 | uint16(data[1])
+ length := int(data[2])<<8 | int(data[3])
+ data = data[4:]
+
+ if len(data) < length {
+ return false
+ }
+ d := data[:length]
+ data = data[length:]
+
+ switch extension {
+ case extensionKeyShare:
+ m.hasKeyShare = true
+ if len(d) < 4 {
+ return false
+ }
+ m.keyShare.group = CurveID(uint16(d[0])<<8 | uint16(d[1]))
+ keyExchLen := int(d[2])<<8 | int(d[3])
+ if keyExchLen != len(d)-4 {
+ return false
+ }
+ m.keyShare.keyExchange = make([]byte, keyExchLen)
+ copy(m.keyShare.keyExchange, d[4:])
+ case extensionPreSharedKey:
+ if len(d) != 2 {
+ return false
+ }
+ m.pskIdentity = uint16(d[0])<<8 | uint16(d[1])
+ m.hasPSKIdentity = true
+ case extensionEarlyData:
+ if len(d) != 0 {
+ return false
+ }
+ m.earlyDataIndication = true
+ default:
+ // Only allow the 3 extensions that are sent in
+ // the clear in TLS 1.3.
+ return false
+ }
+ }
+ } else if !m.extensions.unmarshal(data) {
return false
}
return true
}
+type encryptedExtensionsMsg struct {
+	raw        []byte // encoded message; marshal returns it verbatim when set
+	extensions serverExtensions
+}
+
+func (m *encryptedExtensionsMsg) marshal() []byte {
+	if m.raw != nil {
+		return m.raw
+	}
+
+	msg := newByteBuilder()
+	msg.addU8(typeEncryptedExtensions)
+	body := msg.addU24LengthPrefixed()
+	extensions := body.addU16LengthPrefixed()
+	m.extensions.marshal(extensions)
+
+	m.raw = msg.finish()
+	return m.raw
+}
+
+func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
+	// Keep the received encoding: the handshake transcript must hash the
+	// exact bytes from the wire, not a re-serialization of parsed fields.
+	m.raw = data
+	if len(data) < 6 || data[0] != typeEncryptedExtensions {
+		return false
+	}
+	msgLen := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
+	data = data[4:]
+	if len(data) != msgLen {
+		return false
+	}
+	extLen := int(data[0])<<8 | int(data[1])
+	data = data[2:]
+	if extLen != len(data) {
+		return false
+	}
+	return m.extensions.unmarshal(data)
+}
+
type serverExtensions struct {
nextProtoNeg bool
nextProtos []string
@@ -962,8 +1085,10 @@
}
type certificateMsg struct {
- raw []byte
- certificates [][]byte
+ raw []byte
+ hasRequestContext bool
+ requestContext []byte
+ certificates [][]byte
}
func (m *certificateMsg) marshal() (x []byte) {
@@ -974,6 +1099,10 @@
certMsg := newByteBuilder()
certMsg.addU8(typeCertificate)
certificate := certMsg.addU24LengthPrefixed()
+ if m.hasRequestContext {
+ context := certificate.addU8LengthPrefixed()
+ context.addBytes(m.requestContext)
+ }
certificateList := certificate.addU24LengthPrefixed()
for _, cert := range m.certificates {
certEntry := certificateList.addU24LengthPrefixed()
@@ -985,24 +1114,43 @@
}
func (m *certificateMsg) unmarshal(data []byte) bool {
- if len(data) < 7 {
+ if len(data) < 4 {
return false
}
m.raw = data
- certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
- if uint32(len(data)) != certsLen+7 {
+ data = data[4:]
+
+ if m.hasRequestContext {
+ if len(data) == 0 {
+ return false
+ }
+ contextLen := int(data[0])
+ if len(data) < 1+contextLen {
+ return false
+ }
+ m.requestContext = make([]byte, contextLen)
+ copy(m.requestContext, data[1:])
+ data = data[1+contextLen:]
+ }
+
+ if len(data) < 3 {
+ return false
+ }
+ certsLen := int(data[0])<<16 | int(data[1])<<8 | int(data[2])
+ data = data[3:]
+ if len(data) != certsLen {
return false
}
numCerts := 0
- d := data[7:]
+ d := data
for certsLen > 0 {
if len(d) < 4 {
return false
}
- certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
- if uint32(len(d)) < 3+certLen {
+ certLen := int(d[0])<<16 | int(d[1])<<8 | int(d[2])
+ if len(d) < 3+certLen {
return false
}
d = d[3+certLen:]
@@ -1011,7 +1159,7 @@
}
m.certificates = make([][]byte, numCerts)
- d = data[7:]
+ d = data
for i := 0; i < numCerts; i++ {
certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
m.certificates[i] = d[3 : 3+certLen]
@@ -1245,8 +1393,13 @@
// of signature and hash functions. This change was introduced with TLS
// 1.2.
hasSignatureAlgorithm bool
+ // hasRequestContext indicates whether this message includes a context
+ // field instead of certificateTypes. This change was introduced with
+ // TLS 1.3.
+ hasRequestContext bool
certificateTypes []byte
+ requestContext []byte
signatureAlgorithms []signatureAlgorithm
certificateAuthorities [][]byte
}
@@ -1261,8 +1414,13 @@
builder.addU8(typeCertificateRequest)
body := builder.addU24LengthPrefixed()
- certificateTypes := body.addU8LengthPrefixed()
- certificateTypes.addBytes(m.certificateTypes)
+ if m.hasRequestContext {
+ requestContext := body.addU8LengthPrefixed()
+ requestContext.addBytes(m.requestContext)
+ } else {
+ certificateTypes := body.addU8LengthPrefixed()
+ certificateTypes.addBytes(m.certificateTypes)
+ }
if m.hasSignatureAlgorithm {
signatureAlgorithms := body.addU16LengthPrefixed()
@@ -1287,25 +1445,26 @@
if len(data) < 5 {
return false
}
+ data = data[4:]
- length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
- if uint32(len(data))-4 != length {
- return false
+ if m.hasRequestContext {
+ contextLen := int(data[0])
+ if len(data) < 1+contextLen {
+ return false
+ }
+ m.requestContext = make([]byte, contextLen)
+ copy(m.requestContext, data[1:])
+ data = data[1+contextLen:]
+ } else {
+ numCertTypes := int(data[0])
+ if len(data) < 1+numCertTypes {
+ return false
+ }
+ m.certificateTypes = make([]byte, numCertTypes)
+ copy(m.certificateTypes, data[1:])
+ data = data[1+numCertTypes:]
}
- numCertTypes := int(data[4])
- data = data[5:]
- if numCertTypes == 0 || len(data) <= numCertTypes {
- return false
- }
-
- m.certificateTypes = make([]byte, numCertTypes)
- if copy(m.certificateTypes, data) != numCertTypes {
- return false
- }
-
- data = data[numCertTypes:]
-
if m.hasSignatureAlgorithm {
if len(data) < 2 {
return false
diff --git a/ssl/test/runner/prf.go b/ssl/test/runner/prf.go
index a91a319..0fd5762 100644
--- a/ssl/test/runner/prf.go
+++ b/ssl/test/runner/prf.go
@@ -5,11 +5,11 @@
package runner
import (
+ "crypto"
"crypto/hmac"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
- "crypto/sha512"
"hash"
)
@@ -133,13 +133,11 @@
// Once we no longer support Fake TLS 1.3, the VersionTLS13 should be
// removed from this case statement.
case VersionTLS12, VersionTLS13:
- if suite.flags&suiteSHA384 != 0 {
- return prf12(sha512.New384)
+ if version == VersionTLS12 || !enableTLS13Handshake {
+ return prf12(suite.hash().New)
}
- return prf12(sha256.New)
- default:
- panic("unknown version")
}
+ panic("unknown version")
}
// masterFromPreMasterSecret generates the master secret from the pre-master
@@ -188,20 +186,38 @@
}
func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash {
- if version >= VersionTLS12 {
- newHash := sha256.New
- if cipherSuite.flags&suiteSHA384 != 0 {
- newHash = sha512.New384
- }
+ var ret finishedHash
- return finishedHash{newHash(), newHash(), nil, nil, []byte{}, version, prf12(newHash)}
+ if version >= VersionTLS12 {
+ ret.hash = cipherSuite.hash()
+
+ ret.client = ret.hash.New()
+ ret.server = ret.hash.New()
+
+ if version == VersionTLS12 || !enableTLS13Handshake {
+ ret.prf = prf12(ret.hash.New)
+ }
+ } else {
+ ret.hash = crypto.MD5SHA1
+
+ ret.client = sha1.New()
+ ret.server = sha1.New()
+ ret.clientMD5 = md5.New()
+ ret.serverMD5 = md5.New()
+
+ ret.prf = prf10
}
- return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), []byte{}, version, prf10}
+
+ ret.buffer = []byte{}
+ ret.version = version
+ return ret
}
// A finishedHash calculates the hash of a set of handshake messages suitable
// for including in a Finished message.
type finishedHash struct {
+ hash crypto.Hash
+
client hash.Hash
server hash.Hash
@@ -213,6 +229,10 @@
// full buffer is required.
buffer []byte
+ // TLS 1.3 has a resumption context which is carried over on PSK
+ // resumption.
+ resumptionContextHash []byte
+
version uint16
prf func(result, secret, label, seed []byte)
}
@@ -280,26 +300,40 @@
// clientSum returns the contents of the verify_data member of a client's
// Finished message.
-func (h finishedHash) clientSum(masterSecret []byte) []byte {
+func (h finishedHash) clientSum(baseKey []byte) []byte {
if h.version == VersionSSL30 {
- return finishedSum30(h.clientMD5, h.client, masterSecret, ssl3ClientFinishedMagic[:])
+ return finishedSum30(h.clientMD5, h.client, baseKey, ssl3ClientFinishedMagic[:])
}
- out := make([]byte, finishedVerifyLength)
- h.prf(out, masterSecret, clientFinishedLabel, h.Sum())
- return out
+ if h.version < VersionTLS13 || !enableTLS13Handshake {
+ out := make([]byte, finishedVerifyLength)
+ h.prf(out, baseKey, clientFinishedLabel, h.Sum())
+ return out
+ }
+
+ clientFinishedKey := hkdfExpandLabel(h.hash, baseKey, clientFinishedLabel, nil, h.hash.Size())
+ finishedHMAC := hmac.New(h.hash.New, clientFinishedKey)
+ finishedHMAC.Write(h.appendContextHashes(nil))
+ return finishedHMAC.Sum(nil)
}
// serverSum returns the contents of the verify_data member of a server's
// Finished message.
-func (h finishedHash) serverSum(masterSecret []byte) []byte {
+func (h finishedHash) serverSum(baseKey []byte) []byte {
if h.version == VersionSSL30 {
- return finishedSum30(h.serverMD5, h.server, masterSecret, ssl3ServerFinishedMagic[:])
+ return finishedSum30(h.serverMD5, h.server, baseKey, ssl3ServerFinishedMagic[:])
}
- out := make([]byte, finishedVerifyLength)
- h.prf(out, masterSecret, serverFinishedLabel, h.Sum())
- return out
+ if h.version < VersionTLS13 || !enableTLS13Handshake {
+ out := make([]byte, finishedVerifyLength)
+ h.prf(out, baseKey, serverFinishedLabel, h.Sum())
+ return out
+ }
+
+ serverFinishedKey := hkdfExpandLabel(h.hash, baseKey, serverFinishedLabel, nil, h.hash.Size())
+ finishedHMAC := hmac.New(h.hash.New, serverFinishedKey)
+ finishedHMAC.Write(h.appendContextHashes(nil))
+ return finishedHMAC.Sum(nil)
}
// hashForClientCertificateSSL3 returns the hash to be signed for client
@@ -331,3 +365,125 @@
func (h *finishedHash) discardHandshakeBuffer() {
h.buffer = nil
}
+
+// zeroSecret returns the default all-zeros secret for TLS 1.3, as long as the
+// handshake hash output, used when a given secret is not available in the
+// handshake. See draft-ietf-tls-tls13-13, section 7.1.
+func (h *finishedHash) zeroSecret() []byte {
+ return make([]byte, h.hash.Size())
+}
+
+// setResumptionContext records the hash of the TLS 1.3 resumption context.
+func (h *finishedHash) setResumptionContext(resumptionContext []byte) {
+	digest := h.hash.New() // named to avoid shadowing the hash package
+	digest.Write(resumptionContext)
+	h.resumptionContextHash = digest.Sum(nil)
+}
+
+// extractKey combines two secrets together with HKDF-Extract in the TLS 1.3 key
+// derivation schedule, passing salt and ikm through to hkdfExtract.
+func (h *finishedHash) extractKey(salt, ikm []byte) []byte {
+ return hkdfExtract(h.hash.New, salt, ikm)
+}
+
+// hkdfExpandLabel implements TLS 1.3's HKDF-Expand-Label function, as defined
+// in section 7.1 of draft-ietf-tls-tls13-13.
+func hkdfExpandLabel(hash crypto.Hash, secret, label, hashValue []byte, length int) []byte {
+	// The encoded label is "TLS 1.3, " (9 bytes) followed by label under one
+	// u8 length prefix, so the prefix counts toward the 255-byte limit.
+	if 9+len(label) > 255 || len(hashValue) > 255 {
+		panic("hkdfExpandLabel: label or hashValue too long")
+	}
+	hkdfLabel := make([]byte, 3+9+len(label)+1+len(hashValue))
+	x := hkdfLabel
+	x[0], x[1], x[2] = byte(length>>8), byte(length), byte(9+len(label))
+	x = x[3:]
+	copy(x, []byte("TLS 1.3, "))
+	x = x[9:]
+	copy(x, label)
+	x = x[len(label):]
+	x[0] = byte(len(hashValue))
+	copy(x[1:], hashValue)
+	return hkdfExpand(hash.New, secret, hkdfLabel, length)
+}
+
+// appendContextHashes appends the TLS 1.3 context hashes to b and returns the
+// extended slice: first the handshake hash accumulated in h.client, then the
+// resumption context hash recorded by setResumptionContext.
+func (h *finishedHash) appendContextHashes(b []byte) []byte {
+	withTranscript := h.client.Sum(b)
+	return append(withTranscript, h.resumptionContextHash...)
+}
+
+// The following are labels for TLS 1.3 secret derivation (see deriveSecret).
+var (
+ earlyTrafficLabel = []byte("early traffic secret")
+ handshakeTrafficLabel = []byte("handshake traffic secret")
+ applicationTrafficLabel = []byte("application traffic secret")
+ exporterLabel = []byte("exporter master secret")
+ resumptionLabel = []byte("resumption master secret")
+)
+
+// deriveSecret implements TLS 1.3's Derive-Secret function, as defined in
+// section 7.1 of draft-ietf-tls-tls13-13. It panics if setResumptionContext
+// has not been called first.
+func (h *finishedHash) deriveSecret(secret, label []byte) []byte {
+	if h.resumptionContextHash != nil {
+		return hkdfExpandLabel(h.hash, secret, label, h.appendContextHashes(nil), h.hash.Size())
+	}
+	panic("Resumption context not set.")
+}
+
+// The following are the context strings signed in TLS 1.3 CertificateVerify (see certificateVerifyInput).
+var (
+ clientCertificateVerifyContextTLS13 = []byte("TLS 1.3, client CertificateVerify")
+ serverCertificateVerifyContextTLS13 = []byte("TLS 1.3, server CertificateVerify")
+)
+
+// certificateVerifyInput returns the input to be signed for CertificateVerify
+// in TLS 1.3: 64 bytes of 0x20, the context string, a zero byte, then the
+// handshake hash and the resumption context hash.
+func (h *finishedHash) certificateVerifyInput(context []byte) []byte {
+	const paddingLen = 64
+	b := make([]byte, 0, paddingLen+len(context)+1+2*h.hash.Size())
+	for len(b) < paddingLen {
+		b = append(b, ' ')
+	}
+	b = append(b, context...)
+	b = append(b, 0)
+	return h.appendContextHashes(b)
+}
+
+// The following are phase prefixes for TLS 1.3 traffic key labels (see deriveTrafficAEAD).
+var (
+ earlyHandshakePhase = []byte("early handshake key expansion")
+ earlyApplicationPhase = []byte("early application data key expansion")
+ handshakePhase = []byte("handshake key expansion")
+ applicationPhase = []byte("application data key expansion")
+)
+
+type trafficDirection int // selects whose write key/IV deriveTrafficAEAD derives
+
+const (
+ clientWrite trafficDirection = iota // ", client write key" / ", client write iv" labels
+ serverWrite // ", server write key" / ", server write iv" labels
+)
+
+// deriveTrafficAEAD expands a TLS 1.3 traffic secret into the write key and IV
+// for one direction, then returns the suite's AEAD built from them.
+func deriveTrafficAEAD(version uint16, suite *cipherSuite, secret, phase []byte, side trafficDirection) *tlsAead {
+	keyLabel := append([]byte(nil), phase...)
+	if side == clientWrite {
+		keyLabel = append(keyLabel, []byte(", client write key")...)
+	} else {
+		keyLabel = append(keyLabel, []byte(", server write key")...)
+	}
+	// The IV label swaps the trailing "key" for "iv". The full slice
+	// expression forces a fresh backing array so keyLabel stays intact.
+	stem := len(keyLabel) - 3
+	ivLabel := append(keyLabel[:stem:stem], []byte("iv")...)
+
+	key := hkdfExpandLabel(suite.hash(), secret, keyLabel, nil, suite.keyLen)
+	iv := hkdfExpandLabel(suite.hash(), secret, ivLabel, nil, suite.ivLen(version))
+	return suite.aead(version, key, iv)
+}