Add support for TLS 1.3 PSK resumption in Go. Change-Id: I998f69269cdf813da19ccccc208b476f3501c8c4 Reviewed-on: https://boringssl-review.googlesource.com/8991 Reviewed-by: Steven Valdez <svaldez@google.com> Reviewed-by: David Benjamin <davidben@google.com> Commit-Queue: David Benjamin <davidben@google.com> CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/test/runner/alert.go b/ssl/test/runner/alert.go index 89c907f..363a770 100644 --- a/ssl/test/runner/alert.go +++ b/ssl/test/runner/alert.go
@@ -41,6 +41,7 @@ alertNoRenegotiation alert = 100 alertMissingExtension alert = 109 alertUnsupportedExtension alert = 110 + alertUnknownPSKIdentity alert = 115 ) var alertText = map[alert]string{ @@ -69,6 +70,7 @@ alertNoRenegotiation: "no renegotiation", alertMissingExtension: "missing extension", alertUnsupportedExtension: "unsupported extension", + alertUnknownPSKIdentity: "unknown PSK identity", } func (e alert) String() string {
diff --git a/ssl/test/runner/cipher_suites.go b/ssl/test/runner/cipher_suites.go index 495ec34..4ce4629 100644 --- a/ssl/test/runner/cipher_suites.go +++ b/ssl/test/runner/cipher_suites.go
@@ -101,6 +101,27 @@ return crypto.SHA256 } +// TODO(nharper): Remove this function when TLS 1.3 cipher suites get +// refactored to break out the AEAD/PRF from everything else. Once that's +// done, this won't be necessary anymore. +func ecdhePSKSuite(id uint16) uint16 { + switch id { + case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256: + return TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256 + case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256: + return TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256 + case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_PSK_WITH_AES_256_GCM_SHA384: + return TLS_ECDHE_PSK_WITH_AES_256_GCM_SHA384 + } + return 0 +} + var cipherSuites = []*cipherSuite{ // Ciphersuite order is chosen so that ECDHE comes before plain RSA // and RC4 comes before AES (because of the Lucky13 attack).
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go index e9a3fdb..aca308c 100644 --- a/ssl/test/runner/common.go +++ b/ssl/test/runner/common.go
@@ -28,7 +28,7 @@ // The draft version of TLS 1.3 that is implemented here and sent in the draft // indicator extension. -const tls13DraftVersion = 13 +const tls13DraftVersion = 14 const ( maxPlaintext = 16384 // maximum plaintext payload length @@ -242,6 +242,10 @@ extendedMasterSecret bool // Whether an extended master secret was used to generate the session sctList []byte ocspResponse []byte + ticketCreationTime time.Time + ticketExpiration time.Time + ticketFlags uint32 + ticketAgeAdd uint32 } // ClientSessionCache is a cache of ClientSessionState objects that can be used
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go index 703908a..77543e6 100644 --- a/ssl/test/runner/conn.go +++ b/ssl/test/runner/conn.go
@@ -1389,6 +1389,10 @@ serverCertificates: c.peerCertificates, sctList: c.sctList, ocspResponse: c.ocspResponse, + ticketCreationTime: c.config.time(), + ticketExpiration: c.config.time().Add(time.Duration(newSessionTicket.ticketLifetime) * time.Second), + ticketFlags: newSessionTicket.ticketFlags, + ticketAgeAdd: newSessionTicket.ticketAgeAdd, } cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config) @@ -1667,11 +1671,10 @@ for _, cert := range c.peerCertificates { peerCertificatesRaw = append(peerCertificatesRaw, cert.Raw) } - state := sessionState{ - vers: c.vers, - cipherSuite: c.cipherSuite.id, - masterSecret: c.resumptionSecret, - certificates: peerCertificatesRaw, + + var ageAdd uint32 + if err := binary.Read(c.config.rand(), binary.LittleEndian, &ageAdd); err != nil { + return err } // TODO(davidben): Allow configuring these values. @@ -1679,7 +1682,20 @@ version: c.vers, ticketLifetime: uint32(24 * time.Hour / time.Second), ticketFlags: ticketAllowDHEResumption | ticketAllowPSKResumption, + ticketAgeAdd: ageAdd, } + + state := sessionState{ + vers: c.vers, + cipherSuite: c.cipherSuite.id, + masterSecret: c.resumptionSecret, + certificates: peerCertificatesRaw, + ticketCreationTime: c.config.time(), + ticketExpiration: c.config.time().Add(time.Duration(m.ticketLifetime) * time.Second), + ticketFlags: m.ticketFlags, + ticketAgeAdd: ageAdd, + } + if !c.config.Bugs.SendEmptySessionTicket { var err error m.ticket, err = c.encryptTicket(&state)
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go index 36bd7e4..46b4732 100644 --- a/ssl/test/runner/handshake_client.go +++ b/ssl/test/runner/handshake_client.go
@@ -18,6 +18,7 @@ "math/big" "net" "strconv" + "time" ) type clientHandshakeState struct { @@ -197,6 +198,8 @@ // Try to resume a previously negotiated TLS session, if // available. cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + // TODO(nharper): Support storing more than one session + // ticket for TLS 1.3. candidateSession, ok := sessionCache.Get(cacheKey) if ok { ticketOk := !c.config.SessionTicketsDisabled || candidateSession.sessionTicket == nil @@ -219,7 +222,7 @@ } } - if session != nil { + if session != nil && c.config.time().Before(session.ticketExpiration) { ticket := session.sessionTicket if c.config.Bugs.CorruptTicket && len(ticket) > 0 { ticket = make([]byte, len(session.sessionTicket)) @@ -232,7 +235,21 @@ } if session.vers >= VersionTLS13 { - // TODO(davidben): Offer TLS 1.3 tickets. + // TODO(nharper): Support sending more + // than one PSK identity. + if session.ticketFlags&ticketAllowDHEResumption != 0 { + var found bool + for _, id := range hello.cipherSuites { + if id == session.cipherSuite { + found = true + break + } + } + if found { + hello.pskIdentities = [][]uint8{ticket} + hello.cipherSuites = append(hello.cipherSuites, ecdhePSKSuite(session.cipherSuite)) + } + } } else if ticket != nil { hello.sessionTicket = ticket // A random session ID is used to detect when the @@ -411,7 +428,7 @@ } } - suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite) + suite := mutualCipherSuite(hello.cipherSuites, serverHello.cipherSuite) if suite == nil { c.sendAlert(alertHandshakeFailure) return fmt.Errorf("tls: server selected an unsupported cipher suite") @@ -546,9 +563,18 @@ return errors.New("tls: server omitted the PSK identity extension") } - // TODO(davidben): Support PSK ciphers and PSK resumption. Set - // the resumption context appropriately if resuming. - return errors.New("tls: PSK ciphers not implemented for TLS 1.3") + // We send at most one PSK identity. + if hs.session == nil || hs.serverHello.pskIdentity != 0 { + c.sendAlert(alertUnknownPSKIdentity) + return errors.New("tls: server sent unknown PSK identity") + } + if ecdhePSKSuite(hs.session.cipherSuite) != hs.suite.id { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server sent invalid cipher suite for PSK") + } + psk = deriveResumptionPSK(hs.suite, hs.session.masterSecret) + hs.finishedHash.setResumptionContext(deriveResumptionContext(hs.suite, hs.session.masterSecret)) + c.didResume = true } else { if hs.serverHello.hasPSKIdentity { c.sendAlert(alertUnsupportedExtension) @@ -626,6 +652,11 @@ c.sendAlert(alertUnsupportedExtension) return errors.New("tls: server sent SCT list without a certificate") } + + // Copy over authentication from the session. + c.peerCertificates = hs.session.serverCertificates + c.sctList = hs.session.sctList + c.ocspResponse = hs.session.ocspResponse } else { c.ocspResponse = encryptedExtensions.extensions.ocspResponse c.sctList = encryptedExtensions.extensions.sctList @@ -1223,6 +1254,7 @@ serverCertificates: c.peerCertificates, sctList: c.sctList, ocspResponse: c.ocspResponse, + ticketExpiration: c.config.time().Add(time.Duration(7 * 24 * time.Hour)), } if !hs.serverHello.extensions.ticketSupported {
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go index 8e73a3c..8f87881 100644 --- a/ssl/test/runner/handshake_messages.go +++ b/ssl/test/runner/handshake_messages.go
@@ -4,7 +4,10 @@ package runner -import "bytes" +import ( + "bytes" + "encoding/binary" +) func writeLen(buf []byte, v, size int) { for i := 0; i < size; i++ { @@ -67,6 +70,13 @@ *bb.buf = append(*bb.buf, byte(u>>24), byte(u>>16), byte(u>>8), byte(u)) } +func (bb *byteBuilder) addU64(u uint64) { + bb.flush() + var b [8]byte + binary.BigEndian.PutUint64(b[:], u) + *bb.buf = append(*bb.buf, b[:]...) +} + func (bb *byteBuilder) addU8LengthPrefixed() *byteBuilder { return bb.createChild(1) } @@ -79,6 +89,10 @@ return bb.createChild(3) } +func (bb *byteBuilder) addU32LengthPrefixed() *byteBuilder { + return bb.createChild(4) +} + func (bb *byteBuilder) addBytes(b []byte) { bb.flush() *bb.buf = append(*bb.buf, b...)
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go index 4d8b5c1..c2b28f2 100644 --- a/ssl/test/runner/handshake_server.go +++ b/ssl/test/runner/handshake_server.go
@@ -305,22 +305,54 @@ _, ecdsaOk := hs.cert.PrivateKey.(*ecdsa.PrivateKey) - // TODO(davidben): Implement PSK support. - pskOk := false - - // Select the cipher suite. - var preferenceList, supportedList []uint16 - if config.PreferServerCipherSuites { - preferenceList = config.cipherSuites() - supportedList = hs.clientHello.cipherSuites - } else { - preferenceList = hs.clientHello.cipherSuites - supportedList = config.cipherSuites() + for i, pskIdentity := range hs.clientHello.pskIdentities { + sessionState, ok := c.decryptTicket(pskIdentity) + if !ok { + continue + } + if sessionState.vers != c.vers { + continue + } + if sessionState.ticketFlags&ticketAllowDHEResumption == 0 { + continue + } + if sessionState.ticketExpiration.Before(c.config.time()) { + continue + } + suiteId := ecdhePSKSuite(sessionState.cipherSuite) + suite := mutualCipherSuite(hs.clientHello.cipherSuites, suiteId) + var found bool + for _, id := range config.cipherSuites() { + if id == sessionState.cipherSuite { + found = true + break + } + } + if suite != nil && found { + hs.sessionState = sessionState + hs.suite = suite + hs.hello.hasPSKIdentity = true + hs.hello.pskIdentity = uint16(i) + c.didResume = true + break + } } - for _, id := range preferenceList { - if hs.suite = c.tryCipherSuite(id, supportedList, c.vers, supportedCurve, ecdsaOk, pskOk); hs.suite != nil { - break + // If not resuming, select the cipher suite. + if hs.suite == nil { + var preferenceList, supportedList []uint16 + if config.PreferServerCipherSuites { + preferenceList = config.cipherSuites() + supportedList = hs.clientHello.cipherSuites + } else { + preferenceList = hs.clientHello.cipherSuites + supportedList = config.cipherSuites() + } + + for _, id := range preferenceList { + if hs.suite = c.tryCipherSuite(id, supportedList, c.vers, supportedCurve, ecdsaOk, false); hs.suite != nil { + break + } } } @@ -339,9 +371,19 @@ hs.writeClientHash(hs.clientHello.marshal()) // Resolve PSK and compute the early secret. - // TODO(davidben): Implement PSK in TLS 1.3. - psk := hs.finishedHash.zeroSecret() - hs.finishedHash.setResumptionContext(hs.finishedHash.zeroSecret()) + var psk []byte + // The only way for hs.suite to be a PSK suite yet for there to be + // no sessionState is if config.Bugs.EnableAllCiphers is true and + // the test runner forced us to negotiated a PSK suite. It doesn't + // really matter what we do here so long as we continue the + // handshake and let the client error out. + if hs.suite.flags&suitePSK != 0 && hs.sessionState != nil { + psk = deriveResumptionPSK(hs.suite, hs.sessionState.masterSecret) + hs.finishedHash.setResumptionContext(deriveResumptionContext(hs.suite, hs.sessionState.masterSecret)) + } else { + psk = hs.finishedHash.zeroSecret() + hs.finishedHash.setResumptionContext(hs.finishedHash.zeroSecret()) + } earlySecret := hs.finishedHash.extractKey(hs.finishedHash.zeroSecret(), psk) @@ -494,9 +536,7 @@ c.out.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite) c.in.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite) - if hs.suite.flags&suitePSK != 0 { - return errors.New("tls: PSK ciphers not implemented for TLS 1.3") - } else { + if hs.suite.flags&suitePSK == 0 { if hs.clientHello.ocspStapling { encryptedExtensions.extensions.ocspResponse = hs.cert.OCSPStaple } @@ -574,6 +614,15 @@ hs.writeServerHash(certVerify.marshal()) c.writeRecord(recordTypeHandshake, certVerify.marshal()) + } else { + // Pick up certificates from the session instead. + // hs.sessionState may be nil if config.Bugs.EnableAllCiphers is + // true. + if hs.sessionState != nil && len(hs.sessionState.certificates) > 0 { + if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil { + return err + } + } } finished := new(finishedMsg)
diff --git a/ssl/test/runner/prf.go b/ssl/test/runner/prf.go index 220aa44..33ad75a 100644 --- a/ssl/test/runner/prf.go +++ b/ssl/test/runner/prf.go
@@ -497,3 +497,11 @@ func updateTrafficSecret(hash crypto.Hash, secret []byte) []byte { return hkdfExpandLabel(hash, secret, applicationTrafficLabel, nil, hash.Size()) } + +func deriveResumptionPSK(suite *cipherSuite, resumptionSecret []byte) []byte { + return hkdfExpandLabel(suite.hash(), resumptionSecret, []byte("resumption psk"), nil, suite.hash().Size()) +} + +func deriveResumptionContext(suite *cipherSuite, resumptionSecret []byte) []byte { + return hkdfExpandLabel(suite.hash(), resumptionSecret, []byte("resumption context"), nil, suite.hash().Size()) +}
diff --git a/ssl/test/runner/ticket.go b/ssl/test/runner/ticket.go index e121c05..4a4540c 100644 --- a/ssl/test/runner/ticket.go +++ b/ssl/test/runner/ticket.go
@@ -5,14 +5,15 @@ package runner import ( - "bytes" "crypto/aes" "crypto/cipher" "crypto/hmac" "crypto/sha256" "crypto/subtle" + "encoding/binary" "errors" "io" + "time" ) // sessionState contains the information that is serialized into a session @@ -24,79 +25,40 @@ handshakeHash []byte certificates [][]byte extendedMasterSecret bool -} - -func (s *sessionState) equal(i interface{}) bool { - s1, ok := i.(*sessionState) - if !ok { - return false - } - - if s.vers != s1.vers || - s.cipherSuite != s1.cipherSuite || - !bytes.Equal(s.masterSecret, s1.masterSecret) || - !bytes.Equal(s.handshakeHash, s1.handshakeHash) || - s.extendedMasterSecret != s1.extendedMasterSecret { - return false - } - - if len(s.certificates) != len(s1.certificates) { - return false - } - - for i := range s.certificates { - if !bytes.Equal(s.certificates[i], s1.certificates[i]) { - return false - } - } - - return true + ticketCreationTime time.Time + ticketExpiration time.Time + ticketFlags uint32 + ticketAgeAdd uint32 } func (s *sessionState) marshal() []byte { - length := 2 + 2 + 2 + len(s.masterSecret) + 2 + len(s.handshakeHash) + 2 + msg := newByteBuilder() + msg.addU16(s.vers) + msg.addU16(s.cipherSuite) + masterSecret := msg.addU16LengthPrefixed() + masterSecret.addBytes(s.masterSecret) + handshakeHash := msg.addU16LengthPrefixed() + handshakeHash.addBytes(s.handshakeHash) + msg.addU16(uint16(len(s.certificates))) for _, cert := range s.certificates { - length += 4 + len(cert) - } - length++ - - ret := make([]byte, length) - x := ret - x[0] = byte(s.vers >> 8) - x[1] = byte(s.vers) - x[2] = byte(s.cipherSuite >> 8) - x[3] = byte(s.cipherSuite) - x[4] = byte(len(s.masterSecret) >> 8) - x[5] = byte(len(s.masterSecret)) - x = x[6:] - copy(x, s.masterSecret) - x = x[len(s.masterSecret):] - - x[0] = byte(len(s.handshakeHash) >> 8) - x[1] = byte(len(s.handshakeHash)) - x = x[2:] - copy(x, s.handshakeHash) - x = x[len(s.handshakeHash):] - - x[0] = byte(len(s.certificates) >> 8) - x[1] = byte(len(s.certificates)) - x = x[2:] - - for _, cert := range s.certificates { - x[0] = byte(len(cert) >> 24) - x[1] = byte(len(cert) >> 16) - x[2] = byte(len(cert) >> 8) - x[3] = byte(len(cert)) - copy(x[4:], cert) - x = x[4+len(cert):] + certMsg := msg.addU32LengthPrefixed() + certMsg.addBytes(cert) } if s.extendedMasterSecret { - x[0] = 1 + msg.addU8(1) + } else { + msg.addU8(0) } - x = x[1:] - return ret + if s.vers >= VersionTLS13 { + msg.addU64(uint64(s.ticketCreationTime.UnixNano())) + msg.addU64(uint64(s.ticketExpiration.UnixNano())) + msg.addU32(s.ticketFlags) + msg.addU32(s.ticketAgeAdd) + } + + return msg.finish() } func (s *sessionState) unmarshal(data []byte) bool { @@ -162,6 +124,20 @@ } data = data[1:] + if s.vers >= VersionTLS13 { + if len(data) < 24 { + return false + } + s.ticketCreationTime = time.Unix(0, int64(binary.BigEndian.Uint64(data))) + data = data[8:] + s.ticketExpiration = time.Unix(0, int64(binary.BigEndian.Uint64(data))) + data = data[8:] + s.ticketFlags = binary.BigEndian.Uint32(data) + data = data[4:] + s.ticketAgeAdd = binary.BigEndian.Uint32(data) + data = data[4:] + } + if len(data) > 0 { return false }