Add support for TLS 1.3 PSK resumption in Go.
Change-Id: I998f69269cdf813da19ccccc208b476f3501c8c4
Reviewed-on: https://boringssl-review.googlesource.com/8991
Reviewed-by: Steven Valdez <svaldez@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/test/runner/alert.go b/ssl/test/runner/alert.go
index 89c907f..363a770 100644
--- a/ssl/test/runner/alert.go
+++ b/ssl/test/runner/alert.go
@@ -41,6 +41,7 @@
alertNoRenegotiation alert = 100
alertMissingExtension alert = 109
alertUnsupportedExtension alert = 110
+ alertUnknownPSKIdentity alert = 115
)
var alertText = map[alert]string{
@@ -69,6 +70,7 @@
alertNoRenegotiation: "no renegotiation",
alertMissingExtension: "missing extension",
alertUnsupportedExtension: "unsupported extension",
+ alertUnknownPSKIdentity: "unknown PSK identity",
}
func (e alert) String() string {
diff --git a/ssl/test/runner/cipher_suites.go b/ssl/test/runner/cipher_suites.go
index 495ec34..4ce4629 100644
--- a/ssl/test/runner/cipher_suites.go
+++ b/ssl/test/runner/cipher_suites.go
@@ -101,6 +101,27 @@
return crypto.SHA256
}
+// TODO(nharper): Remove this function when TLS 1.3 cipher suites get
+// refactored to break out the AEAD/PRF from everything else. Once that's
+// done, this won't be necessary anymore.
+func ecdhePSKSuite(id uint16) uint16 {
+ switch id {
+ case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
+ TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
+ TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256:
+ return TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256
+ case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+ TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+ TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256:
+ return TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256
+ case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+ TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+ TLS_ECDHE_PSK_WITH_AES_256_GCM_SHA384:
+ return TLS_ECDHE_PSK_WITH_AES_256_GCM_SHA384
+ }
+ return 0
+}
+
var cipherSuites = []*cipherSuite{
// Ciphersuite order is chosen so that ECDHE comes before plain RSA
// and RC4 comes before AES (because of the Lucky13 attack).
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index e9a3fdb..aca308c 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -28,7 +28,7 @@
// The draft version of TLS 1.3 that is implemented here and sent in the draft
// indicator extension.
-const tls13DraftVersion = 13
+const tls13DraftVersion = 14
const (
maxPlaintext = 16384 // maximum plaintext payload length
@@ -242,6 +242,10 @@
extendedMasterSecret bool // Whether an extended master secret was used to generate the session
sctList []byte
ocspResponse []byte
+ ticketCreationTime time.Time
+ ticketExpiration time.Time
+ ticketFlags uint32
+ ticketAgeAdd uint32
}
// ClientSessionCache is a cache of ClientSessionState objects that can be used
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index 703908a..77543e6 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -1389,6 +1389,10 @@
serverCertificates: c.peerCertificates,
sctList: c.sctList,
ocspResponse: c.ocspResponse,
+ ticketCreationTime: c.config.time(),
+ ticketExpiration: c.config.time().Add(time.Duration(newSessionTicket.ticketLifetime) * time.Second),
+ ticketFlags: newSessionTicket.ticketFlags,
+ ticketAgeAdd: newSessionTicket.ticketAgeAdd,
}
cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
@@ -1667,11 +1671,10 @@
for _, cert := range c.peerCertificates {
peerCertificatesRaw = append(peerCertificatesRaw, cert.Raw)
}
- state := sessionState{
- vers: c.vers,
- cipherSuite: c.cipherSuite.id,
- masterSecret: c.resumptionSecret,
- certificates: peerCertificatesRaw,
+
+ var ageAdd uint32
+ if err := binary.Read(c.config.rand(), binary.LittleEndian, &ageAdd); err != nil {
+ return err
}
// TODO(davidben): Allow configuring these values.
@@ -1679,7 +1682,20 @@
version: c.vers,
ticketLifetime: uint32(24 * time.Hour / time.Second),
ticketFlags: ticketAllowDHEResumption | ticketAllowPSKResumption,
+ ticketAgeAdd: ageAdd,
}
+
+ state := sessionState{
+ vers: c.vers,
+ cipherSuite: c.cipherSuite.id,
+ masterSecret: c.resumptionSecret,
+ certificates: peerCertificatesRaw,
+ ticketCreationTime: c.config.time(),
+ ticketExpiration: c.config.time().Add(time.Duration(m.ticketLifetime) * time.Second),
+ ticketFlags: m.ticketFlags,
+ ticketAgeAdd: ageAdd,
+ }
+
if !c.config.Bugs.SendEmptySessionTicket {
var err error
m.ticket, err = c.encryptTicket(&state)
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index 36bd7e4..46b4732 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -18,6 +18,7 @@
"math/big"
"net"
"strconv"
+ "time"
)
type clientHandshakeState struct {
@@ -197,6 +198,8 @@
// Try to resume a previously negotiated TLS session, if
// available.
cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
+ // TODO(nharper): Support storing more than one session
+ // ticket for TLS 1.3.
candidateSession, ok := sessionCache.Get(cacheKey)
if ok {
ticketOk := !c.config.SessionTicketsDisabled || candidateSession.sessionTicket == nil
@@ -219,7 +222,7 @@
}
}
- if session != nil {
+ if session != nil && c.config.time().Before(session.ticketExpiration) {
ticket := session.sessionTicket
if c.config.Bugs.CorruptTicket && len(ticket) > 0 {
ticket = make([]byte, len(session.sessionTicket))
@@ -232,7 +235,21 @@
}
if session.vers >= VersionTLS13 {
- // TODO(davidben): Offer TLS 1.3 tickets.
+ // TODO(nharper): Support sending more
+ // than one PSK identity.
+ if session.ticketFlags&ticketAllowDHEResumption != 0 {
+ var found bool
+ for _, id := range hello.cipherSuites {
+ if id == session.cipherSuite {
+ found = true
+ break
+ }
+ }
+ if found {
+ hello.pskIdentities = [][]uint8{ticket}
+ hello.cipherSuites = append(hello.cipherSuites, ecdhePSKSuite(session.cipherSuite))
+ }
+ }
} else if ticket != nil {
hello.sessionTicket = ticket
// A random session ID is used to detect when the
@@ -411,7 +428,7 @@
}
}
- suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite)
+ suite := mutualCipherSuite(hello.cipherSuites, serverHello.cipherSuite)
if suite == nil {
c.sendAlert(alertHandshakeFailure)
return fmt.Errorf("tls: server selected an unsupported cipher suite")
@@ -546,9 +563,18 @@
return errors.New("tls: server omitted the PSK identity extension")
}
- // TODO(davidben): Support PSK ciphers and PSK resumption. Set
- // the resumption context appropriately if resuming.
- return errors.New("tls: PSK ciphers not implemented for TLS 1.3")
+ // We send at most one PSK identity.
+ if hs.session == nil || hs.serverHello.pskIdentity != 0 {
+ c.sendAlert(alertUnknownPSKIdentity)
+ return errors.New("tls: server sent unknown PSK identity")
+ }
+ if ecdhePSKSuite(hs.session.cipherSuite) != hs.suite.id {
+ c.sendAlert(alertHandshakeFailure)
+ return errors.New("tls: server sent invalid cipher suite for PSK")
+ }
+ psk = deriveResumptionPSK(hs.suite, hs.session.masterSecret)
+ hs.finishedHash.setResumptionContext(deriveResumptionContext(hs.suite, hs.session.masterSecret))
+ c.didResume = true
} else {
if hs.serverHello.hasPSKIdentity {
c.sendAlert(alertUnsupportedExtension)
@@ -626,6 +652,11 @@
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent SCT list without a certificate")
}
+
+ // Copy over authentication from the session.
+ c.peerCertificates = hs.session.serverCertificates
+ c.sctList = hs.session.sctList
+ c.ocspResponse = hs.session.ocspResponse
} else {
c.ocspResponse = encryptedExtensions.extensions.ocspResponse
c.sctList = encryptedExtensions.extensions.sctList
@@ -1223,6 +1254,7 @@
serverCertificates: c.peerCertificates,
sctList: c.sctList,
ocspResponse: c.ocspResponse,
+ ticketExpiration: c.config.time().Add(time.Duration(7 * 24 * time.Hour)),
}
if !hs.serverHello.extensions.ticketSupported {
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index 8e73a3c..8f87881 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -4,7 +4,10 @@
package runner
-import "bytes"
+import (
+ "bytes"
+ "encoding/binary"
+)
func writeLen(buf []byte, v, size int) {
for i := 0; i < size; i++ {
@@ -67,6 +70,13 @@
*bb.buf = append(*bb.buf, byte(u>>24), byte(u>>16), byte(u>>8), byte(u))
}
+func (bb *byteBuilder) addU64(u uint64) {
+ bb.flush()
+ var b [8]byte
+ binary.BigEndian.PutUint64(b[:], u)
+ *bb.buf = append(*bb.buf, b[:]...)
+}
+
func (bb *byteBuilder) addU8LengthPrefixed() *byteBuilder {
return bb.createChild(1)
}
@@ -79,6 +89,10 @@
return bb.createChild(3)
}
+func (bb *byteBuilder) addU32LengthPrefixed() *byteBuilder {
+ return bb.createChild(4)
+}
+
func (bb *byteBuilder) addBytes(b []byte) {
bb.flush()
*bb.buf = append(*bb.buf, b...)
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 4d8b5c1..c2b28f2 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -305,22 +305,54 @@
_, ecdsaOk := hs.cert.PrivateKey.(*ecdsa.PrivateKey)
- // TODO(davidben): Implement PSK support.
- pskOk := false
-
- // Select the cipher suite.
- var preferenceList, supportedList []uint16
- if config.PreferServerCipherSuites {
- preferenceList = config.cipherSuites()
- supportedList = hs.clientHello.cipherSuites
- } else {
- preferenceList = hs.clientHello.cipherSuites
- supportedList = config.cipherSuites()
+ for i, pskIdentity := range hs.clientHello.pskIdentities {
+ sessionState, ok := c.decryptTicket(pskIdentity)
+ if !ok {
+ continue
+ }
+ if sessionState.vers != c.vers {
+ continue
+ }
+ if sessionState.ticketFlags&ticketAllowDHEResumption == 0 {
+ continue
+ }
+ if sessionState.ticketExpiration.Before(c.config.time()) {
+ continue
+ }
+ suiteId := ecdhePSKSuite(sessionState.cipherSuite)
+ suite := mutualCipherSuite(hs.clientHello.cipherSuites, suiteId)
+ var found bool
+ for _, id := range config.cipherSuites() {
+ if id == sessionState.cipherSuite {
+ found = true
+ break
+ }
+ }
+ if suite != nil && found {
+ hs.sessionState = sessionState
+ hs.suite = suite
+ hs.hello.hasPSKIdentity = true
+ hs.hello.pskIdentity = uint16(i)
+ c.didResume = true
+ break
+ }
}
- for _, id := range preferenceList {
- if hs.suite = c.tryCipherSuite(id, supportedList, c.vers, supportedCurve, ecdsaOk, pskOk); hs.suite != nil {
- break
+ // If not resuming, select the cipher suite.
+ if hs.suite == nil {
+ var preferenceList, supportedList []uint16
+ if config.PreferServerCipherSuites {
+ preferenceList = config.cipherSuites()
+ supportedList = hs.clientHello.cipherSuites
+ } else {
+ preferenceList = hs.clientHello.cipherSuites
+ supportedList = config.cipherSuites()
+ }
+
+ for _, id := range preferenceList {
+ if hs.suite = c.tryCipherSuite(id, supportedList, c.vers, supportedCurve, ecdsaOk, false); hs.suite != nil {
+ break
+ }
}
}
@@ -339,9 +371,19 @@
hs.writeClientHash(hs.clientHello.marshal())
// Resolve PSK and compute the early secret.
- // TODO(davidben): Implement PSK in TLS 1.3.
- psk := hs.finishedHash.zeroSecret()
- hs.finishedHash.setResumptionContext(hs.finishedHash.zeroSecret())
+ var psk []byte
+ // The only way for hs.suite to be a PSK suite yet for there to be
+ // no sessionState is if config.Bugs.EnableAllCiphers is true and
+ // the test runner forced us to negotiated a PSK suite. It doesn't
+ // really matter what we do here so long as we continue the
+ // handshake and let the client error out.
+ if hs.suite.flags&suitePSK != 0 && hs.sessionState != nil {
+ psk = deriveResumptionPSK(hs.suite, hs.sessionState.masterSecret)
+ hs.finishedHash.setResumptionContext(deriveResumptionContext(hs.suite, hs.sessionState.masterSecret))
+ } else {
+ psk = hs.finishedHash.zeroSecret()
+ hs.finishedHash.setResumptionContext(hs.finishedHash.zeroSecret())
+ }
earlySecret := hs.finishedHash.extractKey(hs.finishedHash.zeroSecret(), psk)
@@ -494,9 +536,7 @@
c.out.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite)
c.in.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite)
- if hs.suite.flags&suitePSK != 0 {
- return errors.New("tls: PSK ciphers not implemented for TLS 1.3")
- } else {
+ if hs.suite.flags&suitePSK == 0 {
if hs.clientHello.ocspStapling {
encryptedExtensions.extensions.ocspResponse = hs.cert.OCSPStaple
}
@@ -574,6 +614,15 @@
hs.writeServerHash(certVerify.marshal())
c.writeRecord(recordTypeHandshake, certVerify.marshal())
+ } else {
+ // Pick up certificates from the session instead.
+ // hs.sessionState may be nil if config.Bugs.EnableAllCiphers is
+ // true.
+ if hs.sessionState != nil && len(hs.sessionState.certificates) > 0 {
+ if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil {
+ return err
+ }
+ }
}
finished := new(finishedMsg)
diff --git a/ssl/test/runner/prf.go b/ssl/test/runner/prf.go
index 220aa44..33ad75a 100644
--- a/ssl/test/runner/prf.go
+++ b/ssl/test/runner/prf.go
@@ -497,3 +497,11 @@
func updateTrafficSecret(hash crypto.Hash, secret []byte) []byte {
return hkdfExpandLabel(hash, secret, applicationTrafficLabel, nil, hash.Size())
}
+
+func deriveResumptionPSK(suite *cipherSuite, resumptionSecret []byte) []byte {
+ return hkdfExpandLabel(suite.hash(), resumptionSecret, []byte("resumption psk"), nil, suite.hash().Size())
+}
+
+func deriveResumptionContext(suite *cipherSuite, resumptionSecret []byte) []byte {
+ return hkdfExpandLabel(suite.hash(), resumptionSecret, []byte("resumption context"), nil, suite.hash().Size())
+}
diff --git a/ssl/test/runner/ticket.go b/ssl/test/runner/ticket.go
index e121c05..4a4540c 100644
--- a/ssl/test/runner/ticket.go
+++ b/ssl/test/runner/ticket.go
@@ -5,14 +5,15 @@
package runner
import (
- "bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
+ "encoding/binary"
"errors"
"io"
+ "time"
)
// sessionState contains the information that is serialized into a session
@@ -24,79 +25,40 @@
handshakeHash []byte
certificates [][]byte
extendedMasterSecret bool
-}
-
-func (s *sessionState) equal(i interface{}) bool {
- s1, ok := i.(*sessionState)
- if !ok {
- return false
- }
-
- if s.vers != s1.vers ||
- s.cipherSuite != s1.cipherSuite ||
- !bytes.Equal(s.masterSecret, s1.masterSecret) ||
- !bytes.Equal(s.handshakeHash, s1.handshakeHash) ||
- s.extendedMasterSecret != s1.extendedMasterSecret {
- return false
- }
-
- if len(s.certificates) != len(s1.certificates) {
- return false
- }
-
- for i := range s.certificates {
- if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
- return false
- }
- }
-
- return true
+ ticketCreationTime time.Time
+ ticketExpiration time.Time
+ ticketFlags uint32
+ ticketAgeAdd uint32
}
func (s *sessionState) marshal() []byte {
- length := 2 + 2 + 2 + len(s.masterSecret) + 2 + len(s.handshakeHash) + 2
+ msg := newByteBuilder()
+ msg.addU16(s.vers)
+ msg.addU16(s.cipherSuite)
+ masterSecret := msg.addU16LengthPrefixed()
+ masterSecret.addBytes(s.masterSecret)
+ handshakeHash := msg.addU16LengthPrefixed()
+ handshakeHash.addBytes(s.handshakeHash)
+ msg.addU16(uint16(len(s.certificates)))
for _, cert := range s.certificates {
- length += 4 + len(cert)
- }
- length++
-
- ret := make([]byte, length)
- x := ret
- x[0] = byte(s.vers >> 8)
- x[1] = byte(s.vers)
- x[2] = byte(s.cipherSuite >> 8)
- x[3] = byte(s.cipherSuite)
- x[4] = byte(len(s.masterSecret) >> 8)
- x[5] = byte(len(s.masterSecret))
- x = x[6:]
- copy(x, s.masterSecret)
- x = x[len(s.masterSecret):]
-
- x[0] = byte(len(s.handshakeHash) >> 8)
- x[1] = byte(len(s.handshakeHash))
- x = x[2:]
- copy(x, s.handshakeHash)
- x = x[len(s.handshakeHash):]
-
- x[0] = byte(len(s.certificates) >> 8)
- x[1] = byte(len(s.certificates))
- x = x[2:]
-
- for _, cert := range s.certificates {
- x[0] = byte(len(cert) >> 24)
- x[1] = byte(len(cert) >> 16)
- x[2] = byte(len(cert) >> 8)
- x[3] = byte(len(cert))
- copy(x[4:], cert)
- x = x[4+len(cert):]
+ certMsg := msg.addU32LengthPrefixed()
+ certMsg.addBytes(cert)
}
if s.extendedMasterSecret {
- x[0] = 1
+ msg.addU8(1)
+ } else {
+ msg.addU8(0)
}
- x = x[1:]
- return ret
+ if s.vers >= VersionTLS13 {
+ msg.addU64(uint64(s.ticketCreationTime.UnixNano()))
+ msg.addU64(uint64(s.ticketExpiration.UnixNano()))
+ msg.addU32(s.ticketFlags)
+ msg.addU32(s.ticketAgeAdd)
+ }
+
+ return msg.finish()
}
func (s *sessionState) unmarshal(data []byte) bool {
@@ -162,6 +124,20 @@
}
data = data[1:]
+ if s.vers >= VersionTLS13 {
+ if len(data) < 24 {
+ return false
+ }
+ s.ticketCreationTime = time.Unix(0, int64(binary.BigEndian.Uint64(data)))
+ data = data[8:]
+ s.ticketExpiration = time.Unix(0, int64(binary.BigEndian.Uint64(data)))
+ data = data[8:]
+ s.ticketFlags = binary.BigEndian.Uint32(data)
+ data = data[4:]
+ s.ticketAgeAdd = binary.BigEndian.Uint32(data)
+ data = data[4:]
+ }
+
if len(data) > 0 {
return false
}