Add support for TLS 1.3 PSK resumption in Go.

Change-Id: I998f69269cdf813da19ccccc208b476f3501c8c4
Reviewed-on: https://boringssl-review.googlesource.com/8991
Reviewed-by: Steven Valdez <svaldez@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/test/runner/alert.go b/ssl/test/runner/alert.go
index 89c907f..363a770 100644
--- a/ssl/test/runner/alert.go
+++ b/ssl/test/runner/alert.go
@@ -41,6 +41,7 @@
 	alertNoRenegotiation        alert = 100
 	alertMissingExtension       alert = 109
 	alertUnsupportedExtension   alert = 110
+	alertUnknownPSKIdentity     alert = 115
 )
 
 var alertText = map[alert]string{
@@ -69,6 +70,7 @@
 	alertNoRenegotiation:        "no renegotiation",
 	alertMissingExtension:       "missing extension",
 	alertUnsupportedExtension:   "unsupported extension",
+	alertUnknownPSKIdentity:     "unknown PSK identity",
 }
 
 func (e alert) String() string {
diff --git a/ssl/test/runner/cipher_suites.go b/ssl/test/runner/cipher_suites.go
index 495ec34..4ce4629 100644
--- a/ssl/test/runner/cipher_suites.go
+++ b/ssl/test/runner/cipher_suites.go
@@ -101,6 +101,27 @@
 	return crypto.SHA256
 }
 
+// TODO(nharper): Remove this function when TLS 1.3 cipher suites get
+// refactored to break out the AEAD/PRF from everything else. Once that's
+// done, this won't be necessary anymore.
+func ecdhePSKSuite(id uint16) uint16 {
+	switch id {
+	case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
+		TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
+		TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256:
+		return TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256
+	case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+		TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+		TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256:
+		return TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256
+	case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+		TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+		TLS_ECDHE_PSK_WITH_AES_256_GCM_SHA384:
+		return TLS_ECDHE_PSK_WITH_AES_256_GCM_SHA384
+	}
+	return 0
+}
+
 var cipherSuites = []*cipherSuite{
 	// Ciphersuite order is chosen so that ECDHE comes before plain RSA
 	// and RC4 comes before AES (because of the Lucky13 attack).
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index e9a3fdb..aca308c 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -28,7 +28,7 @@
 
 // The draft version of TLS 1.3 that is implemented here and sent in the draft
 // indicator extension.
-const tls13DraftVersion = 13
+const tls13DraftVersion = 14
 
 const (
 	maxPlaintext        = 16384        // maximum plaintext payload length
@@ -242,6 +242,10 @@
 	extendedMasterSecret bool                // Whether an extended master secret was used to generate the session
 	sctList              []byte
 	ocspResponse         []byte
+	ticketCreationTime   time.Time
+	ticketExpiration     time.Time
+	ticketFlags          uint32
+	ticketAgeAdd         uint32
 }
 
 // ClientSessionCache is a cache of ClientSessionState objects that can be used
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index 703908a..77543e6 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -1389,6 +1389,10 @@
 				serverCertificates: c.peerCertificates,
 				sctList:            c.sctList,
 				ocspResponse:       c.ocspResponse,
+				ticketCreationTime: c.config.time(),
+				ticketExpiration:   c.config.time().Add(time.Duration(newSessionTicket.ticketLifetime) * time.Second),
+				ticketFlags:        newSessionTicket.ticketFlags,
+				ticketAgeAdd:       newSessionTicket.ticketAgeAdd,
 			}
 
 			cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
@@ -1667,11 +1671,10 @@
 	for _, cert := range c.peerCertificates {
 		peerCertificatesRaw = append(peerCertificatesRaw, cert.Raw)
 	}
-	state := sessionState{
-		vers:         c.vers,
-		cipherSuite:  c.cipherSuite.id,
-		masterSecret: c.resumptionSecret,
-		certificates: peerCertificatesRaw,
+
+	var ageAdd uint32
+	if err := binary.Read(c.config.rand(), binary.LittleEndian, &ageAdd); err != nil {
+		return err
 	}
 
 	// TODO(davidben): Allow configuring these values.
@@ -1679,7 +1682,20 @@
 		version:        c.vers,
 		ticketLifetime: uint32(24 * time.Hour / time.Second),
 		ticketFlags:    ticketAllowDHEResumption | ticketAllowPSKResumption,
+		ticketAgeAdd:   ageAdd,
 	}
+
+	state := sessionState{
+		vers:               c.vers,
+		cipherSuite:        c.cipherSuite.id,
+		masterSecret:       c.resumptionSecret,
+		certificates:       peerCertificatesRaw,
+		ticketCreationTime: c.config.time(),
+		ticketExpiration:   c.config.time().Add(time.Duration(m.ticketLifetime) * time.Second),
+		ticketFlags:        m.ticketFlags,
+		ticketAgeAdd:       ageAdd,
+	}
+
 	if !c.config.Bugs.SendEmptySessionTicket {
 		var err error
 		m.ticket, err = c.encryptTicket(&state)
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index 36bd7e4..46b4732 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -18,6 +18,7 @@
 	"math/big"
 	"net"
 	"strconv"
+	"time"
 )
 
 type clientHandshakeState struct {
@@ -197,6 +198,8 @@
 		// Try to resume a previously negotiated TLS session, if
 		// available.
 		cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
+		// TODO(nharper): Support storing more than one session
+		// ticket for TLS 1.3.
 		candidateSession, ok := sessionCache.Get(cacheKey)
 		if ok {
 			ticketOk := !c.config.SessionTicketsDisabled || candidateSession.sessionTicket == nil
@@ -219,7 +222,7 @@
 		}
 	}
 
-	if session != nil {
+	if session != nil && c.config.time().Before(session.ticketExpiration) {
 		ticket := session.sessionTicket
 		if c.config.Bugs.CorruptTicket && len(ticket) > 0 {
 			ticket = make([]byte, len(session.sessionTicket))
@@ -232,7 +235,21 @@
 		}
 
 		if session.vers >= VersionTLS13 {
-			// TODO(davidben): Offer TLS 1.3 tickets.
+			// TODO(nharper): Support sending more
+			// than one PSK identity.
+			if session.ticketFlags&ticketAllowDHEResumption != 0 {
+				var found bool
+				for _, id := range hello.cipherSuites {
+					if id == session.cipherSuite {
+						found = true
+						break
+					}
+				}
+				if found {
+					hello.pskIdentities = [][]uint8{ticket}
+					hello.cipherSuites = append(hello.cipherSuites, ecdhePSKSuite(session.cipherSuite))
+				}
+			}
 		} else if ticket != nil {
 			hello.sessionTicket = ticket
 			// A random session ID is used to detect when the
@@ -411,7 +428,7 @@
 		}
 	}
 
-	suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite)
+	suite := mutualCipherSuite(hello.cipherSuites, serverHello.cipherSuite)
 	if suite == nil {
 		c.sendAlert(alertHandshakeFailure)
 		return fmt.Errorf("tls: server selected an unsupported cipher suite")
@@ -546,9 +563,18 @@
 			return errors.New("tls: server omitted the PSK identity extension")
 		}
 
-		// TODO(davidben): Support PSK ciphers and PSK resumption. Set
-		// the resumption context appropriately if resuming.
-		return errors.New("tls: PSK ciphers not implemented for TLS 1.3")
+		// We send at most one PSK identity.
+		if hs.session == nil || hs.serverHello.pskIdentity != 0 {
+			c.sendAlert(alertUnknownPSKIdentity)
+			return errors.New("tls: server sent unknown PSK identity")
+		}
+		if ecdhePSKSuite(hs.session.cipherSuite) != hs.suite.id {
+			c.sendAlert(alertHandshakeFailure)
+			return errors.New("tls: server sent invalid cipher suite for PSK")
+		}
+		psk = deriveResumptionPSK(hs.suite, hs.session.masterSecret)
+		hs.finishedHash.setResumptionContext(deriveResumptionContext(hs.suite, hs.session.masterSecret))
+		c.didResume = true
 	} else {
 		if hs.serverHello.hasPSKIdentity {
 			c.sendAlert(alertUnsupportedExtension)
@@ -626,6 +652,11 @@
 			c.sendAlert(alertUnsupportedExtension)
 			return errors.New("tls: server sent SCT list without a certificate")
 		}
+
+		// Copy over authentication from the session.
+		c.peerCertificates = hs.session.serverCertificates
+		c.sctList = hs.session.sctList
+		c.ocspResponse = hs.session.ocspResponse
 	} else {
 		c.ocspResponse = encryptedExtensions.extensions.ocspResponse
 		c.sctList = encryptedExtensions.extensions.sctList
@@ -1223,6 +1254,7 @@
 		serverCertificates: c.peerCertificates,
 		sctList:            c.sctList,
 		ocspResponse:       c.ocspResponse,
+		ticketExpiration:   c.config.time().Add(time.Duration(7 * 24 * time.Hour)),
 	}
 
 	if !hs.serverHello.extensions.ticketSupported {
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index 8e73a3c..8f87881 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -4,7 +4,10 @@
 
 package runner
 
-import "bytes"
+import (
+	"bytes"
+	"encoding/binary"
+)
 
 func writeLen(buf []byte, v, size int) {
 	for i := 0; i < size; i++ {
@@ -67,6 +70,13 @@
 	*bb.buf = append(*bb.buf, byte(u>>24), byte(u>>16), byte(u>>8), byte(u))
 }
 
+func (bb *byteBuilder) addU64(u uint64) {
+	bb.flush()
+	var b [8]byte
+	binary.BigEndian.PutUint64(b[:], u)
+	*bb.buf = append(*bb.buf, b[:]...)
+}
+
 func (bb *byteBuilder) addU8LengthPrefixed() *byteBuilder {
 	return bb.createChild(1)
 }
@@ -79,6 +89,10 @@
 	return bb.createChild(3)
 }
 
+func (bb *byteBuilder) addU32LengthPrefixed() *byteBuilder {
+	return bb.createChild(4)
+}
+
 func (bb *byteBuilder) addBytes(b []byte) {
 	bb.flush()
 	*bb.buf = append(*bb.buf, b...)
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 4d8b5c1..c2b28f2 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -305,22 +305,54 @@
 
 	_, ecdsaOk := hs.cert.PrivateKey.(*ecdsa.PrivateKey)
 
-	// TODO(davidben): Implement PSK support.
-	pskOk := false
-
-	// Select the cipher suite.
-	var preferenceList, supportedList []uint16
-	if config.PreferServerCipherSuites {
-		preferenceList = config.cipherSuites()
-		supportedList = hs.clientHello.cipherSuites
-	} else {
-		preferenceList = hs.clientHello.cipherSuites
-		supportedList = config.cipherSuites()
+	for i, pskIdentity := range hs.clientHello.pskIdentities {
+		sessionState, ok := c.decryptTicket(pskIdentity)
+		if !ok {
+			continue
+		}
+		if sessionState.vers != c.vers {
+			continue
+		}
+		if sessionState.ticketFlags&ticketAllowDHEResumption == 0 {
+			continue
+		}
+		if sessionState.ticketExpiration.Before(c.config.time()) {
+			continue
+		}
+		suiteId := ecdhePSKSuite(sessionState.cipherSuite)
+		suite := mutualCipherSuite(hs.clientHello.cipherSuites, suiteId)
+		var found bool
+		for _, id := range config.cipherSuites() {
+			if id == sessionState.cipherSuite {
+				found = true
+				break
+			}
+		}
+		if suite != nil && found {
+			hs.sessionState = sessionState
+			hs.suite = suite
+			hs.hello.hasPSKIdentity = true
+			hs.hello.pskIdentity = uint16(i)
+			c.didResume = true
+			break
+		}
 	}
 
-	for _, id := range preferenceList {
-		if hs.suite = c.tryCipherSuite(id, supportedList, c.vers, supportedCurve, ecdsaOk, pskOk); hs.suite != nil {
-			break
+	// If not resuming, select the cipher suite.
+	if hs.suite == nil {
+		var preferenceList, supportedList []uint16
+		if config.PreferServerCipherSuites {
+			preferenceList = config.cipherSuites()
+			supportedList = hs.clientHello.cipherSuites
+		} else {
+			preferenceList = hs.clientHello.cipherSuites
+			supportedList = config.cipherSuites()
+		}
+
+		for _, id := range preferenceList {
+			if hs.suite = c.tryCipherSuite(id, supportedList, c.vers, supportedCurve, ecdsaOk, false); hs.suite != nil {
+				break
+			}
 		}
 	}
 
@@ -339,9 +371,19 @@
 	hs.writeClientHash(hs.clientHello.marshal())
 
 	// Resolve PSK and compute the early secret.
-	// TODO(davidben): Implement PSK in TLS 1.3.
-	psk := hs.finishedHash.zeroSecret()
-	hs.finishedHash.setResumptionContext(hs.finishedHash.zeroSecret())
+	var psk []byte
+	// The only way for hs.suite to be a PSK suite yet for there to be
+	// no sessionState is if config.Bugs.EnableAllCiphers is true and
+	// the test runner forced us to negotiated a PSK suite. It doesn't
+	// really matter what we do here so long as we continue the
+	// handshake and let the client error out.
+	if hs.suite.flags&suitePSK != 0 && hs.sessionState != nil {
+		psk = deriveResumptionPSK(hs.suite, hs.sessionState.masterSecret)
+		hs.finishedHash.setResumptionContext(deriveResumptionContext(hs.suite, hs.sessionState.masterSecret))
+	} else {
+		psk = hs.finishedHash.zeroSecret()
+		hs.finishedHash.setResumptionContext(hs.finishedHash.zeroSecret())
+	}
 
 	earlySecret := hs.finishedHash.extractKey(hs.finishedHash.zeroSecret(), psk)
 
@@ -494,9 +536,7 @@
 	c.out.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite)
 	c.in.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite)
 
-	if hs.suite.flags&suitePSK != 0 {
-		return errors.New("tls: PSK ciphers not implemented for TLS 1.3")
-	} else {
+	if hs.suite.flags&suitePSK == 0 {
 		if hs.clientHello.ocspStapling {
 			encryptedExtensions.extensions.ocspResponse = hs.cert.OCSPStaple
 		}
@@ -574,6 +614,15 @@
 
 		hs.writeServerHash(certVerify.marshal())
 		c.writeRecord(recordTypeHandshake, certVerify.marshal())
+	} else {
+		// Pick up certificates from the session instead.
+		// hs.sessionState may be nil if config.Bugs.EnableAllCiphers is
+		// true.
+		if hs.sessionState != nil && len(hs.sessionState.certificates) > 0 {
+			if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil {
+				return err
+			}
+		}
 	}
 
 	finished := new(finishedMsg)
diff --git a/ssl/test/runner/prf.go b/ssl/test/runner/prf.go
index 220aa44..33ad75a 100644
--- a/ssl/test/runner/prf.go
+++ b/ssl/test/runner/prf.go
@@ -497,3 +497,11 @@
 func updateTrafficSecret(hash crypto.Hash, secret []byte) []byte {
 	return hkdfExpandLabel(hash, secret, applicationTrafficLabel, nil, hash.Size())
 }
+
+func deriveResumptionPSK(suite *cipherSuite, resumptionSecret []byte) []byte {
+	return hkdfExpandLabel(suite.hash(), resumptionSecret, []byte("resumption psk"), nil, suite.hash().Size())
+}
+
+func deriveResumptionContext(suite *cipherSuite, resumptionSecret []byte) []byte {
+	return hkdfExpandLabel(suite.hash(), resumptionSecret, []byte("resumption context"), nil, suite.hash().Size())
+}
diff --git a/ssl/test/runner/ticket.go b/ssl/test/runner/ticket.go
index e121c05..4a4540c 100644
--- a/ssl/test/runner/ticket.go
+++ b/ssl/test/runner/ticket.go
@@ -5,14 +5,15 @@
 package runner
 
 import (
-	"bytes"
 	"crypto/aes"
 	"crypto/cipher"
 	"crypto/hmac"
 	"crypto/sha256"
 	"crypto/subtle"
+	"encoding/binary"
 	"errors"
 	"io"
+	"time"
 )
 
 // sessionState contains the information that is serialized into a session
@@ -24,79 +25,40 @@
 	handshakeHash        []byte
 	certificates         [][]byte
 	extendedMasterSecret bool
-}
-
-func (s *sessionState) equal(i interface{}) bool {
-	s1, ok := i.(*sessionState)
-	if !ok {
-		return false
-	}
-
-	if s.vers != s1.vers ||
-		s.cipherSuite != s1.cipherSuite ||
-		!bytes.Equal(s.masterSecret, s1.masterSecret) ||
-		!bytes.Equal(s.handshakeHash, s1.handshakeHash) ||
-		s.extendedMasterSecret != s1.extendedMasterSecret {
-		return false
-	}
-
-	if len(s.certificates) != len(s1.certificates) {
-		return false
-	}
-
-	for i := range s.certificates {
-		if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
-			return false
-		}
-	}
-
-	return true
+	ticketCreationTime   time.Time
+	ticketExpiration     time.Time
+	ticketFlags          uint32
+	ticketAgeAdd         uint32
 }
 
 func (s *sessionState) marshal() []byte {
-	length := 2 + 2 + 2 + len(s.masterSecret) + 2 + len(s.handshakeHash) + 2
+	msg := newByteBuilder()
+	msg.addU16(s.vers)
+	msg.addU16(s.cipherSuite)
+	masterSecret := msg.addU16LengthPrefixed()
+	masterSecret.addBytes(s.masterSecret)
+	handshakeHash := msg.addU16LengthPrefixed()
+	handshakeHash.addBytes(s.handshakeHash)
+	msg.addU16(uint16(len(s.certificates)))
 	for _, cert := range s.certificates {
-		length += 4 + len(cert)
-	}
-	length++
-
-	ret := make([]byte, length)
-	x := ret
-	x[0] = byte(s.vers >> 8)
-	x[1] = byte(s.vers)
-	x[2] = byte(s.cipherSuite >> 8)
-	x[3] = byte(s.cipherSuite)
-	x[4] = byte(len(s.masterSecret) >> 8)
-	x[5] = byte(len(s.masterSecret))
-	x = x[6:]
-	copy(x, s.masterSecret)
-	x = x[len(s.masterSecret):]
-
-	x[0] = byte(len(s.handshakeHash) >> 8)
-	x[1] = byte(len(s.handshakeHash))
-	x = x[2:]
-	copy(x, s.handshakeHash)
-	x = x[len(s.handshakeHash):]
-
-	x[0] = byte(len(s.certificates) >> 8)
-	x[1] = byte(len(s.certificates))
-	x = x[2:]
-
-	for _, cert := range s.certificates {
-		x[0] = byte(len(cert) >> 24)
-		x[1] = byte(len(cert) >> 16)
-		x[2] = byte(len(cert) >> 8)
-		x[3] = byte(len(cert))
-		copy(x[4:], cert)
-		x = x[4+len(cert):]
+		certMsg := msg.addU32LengthPrefixed()
+		certMsg.addBytes(cert)
 	}
 
 	if s.extendedMasterSecret {
-		x[0] = 1
+		msg.addU8(1)
+	} else {
+		msg.addU8(0)
 	}
-	x = x[1:]
 
-	return ret
+	if s.vers >= VersionTLS13 {
+		msg.addU64(uint64(s.ticketCreationTime.UnixNano()))
+		msg.addU64(uint64(s.ticketExpiration.UnixNano()))
+		msg.addU32(s.ticketFlags)
+		msg.addU32(s.ticketAgeAdd)
+	}
+
+	return msg.finish()
 }
 
 func (s *sessionState) unmarshal(data []byte) bool {
@@ -162,6 +124,20 @@
 	}
 	data = data[1:]
 
+	if s.vers >= VersionTLS13 {
+		if len(data) < 24 {
+			return false
+		}
+		s.ticketCreationTime = time.Unix(0, int64(binary.BigEndian.Uint64(data)))
+		data = data[8:]
+		s.ticketExpiration = time.Unix(0, int64(binary.BigEndian.Uint64(data)))
+		data = data[8:]
+		s.ticketFlags = binary.BigEndian.Uint32(data)
+		data = data[4:]
+		s.ticketAgeAdd = binary.BigEndian.Uint32(data)
+		data = data[4:]
+	}
+
 	if len(data) > 0 {
 		return false
 	}