Implement basic TLS 1.3 client handshake in Go.

[Originally written by nharper and then revised by davidben.]

Most features are missing, but it works for a start. To avoid breaking
the fake TLS 1.3 tests before the corresponding C code has landed, all
of the new logic is gated on a global constant, enableTLS13Handshake.
Once the C code lands, we'll flip it to true and then remove it.

Change-Id: I6b3a369890864c26203fc9cda37c8250024ce91b
Reviewed-on: https://boringssl-review.googlesource.com/8601
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/test/runner/alert.go b/ssl/test/runner/alert.go
index 8144b50..89c907f 100644
--- a/ssl/test/runner/alert.go
+++ b/ssl/test/runner/alert.go
@@ -40,6 +40,7 @@
 	alertUserCanceled           alert = 90
 	alertNoRenegotiation        alert = 100
 	alertMissingExtension       alert = 109
+	alertUnsupportedExtension   alert = 110
 )
 
 var alertText = map[alert]string{
@@ -67,6 +68,7 @@
 	alertUserCanceled:           "user canceled",
 	alertNoRenegotiation:        "no renegotiation",
 	alertMissingExtension:       "missing extension",
+	alertUnsupportedExtension:   "unsupported extension",
 }
 
 func (e alert) String() string {
diff --git a/ssl/test/runner/cipher_suites.go b/ssl/test/runner/cipher_suites.go
index ad2a638..ab22905 100644
--- a/ssl/test/runner/cipher_suites.go
+++ b/ssl/test/runner/cipher_suites.go
@@ -5,6 +5,7 @@
 package runner
 
 import (
+	"crypto"
 	"crypto/aes"
 	"crypto/cipher"
 	"crypto/des"
@@ -93,6 +94,13 @@
 	aead   func(version uint16, key, fixedNonce []byte) *tlsAead
 }
 
+func (cs cipherSuite) hash() crypto.Hash {
+	if cs.flags&suiteSHA384 != 0 {
+		return crypto.SHA384
+	}
+	return crypto.SHA256
+}
+
 var cipherSuites = []*cipherSuite{
 	// Ciphersuite order is chosen so that ECDHE comes before plain RSA
 	// and RC4 comes before AES (because of the Lucky13 attack).
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index eabd0c6..db3f270 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -18,6 +18,9 @@
 	"time"
 )
 
+// TODO(davidben): Flip this to true when the C code lands.
+const enableTLS13Handshake = false
+
 const (
 	VersionSSL30 = 0x0300
 	VersionTLS10 = 0x0301
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index c33ac0c..601c731 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -199,6 +199,13 @@
 	return nil
 }
 
+// updateKeys sets the current cipher state.
+func (hc *halfConn) updateKeys(cipher interface{}, version uint16) {
+	hc.version = version
+	hc.cipher = cipher
+	hc.incEpoch()
+}
+
 // incSeq increments the sequence number.
 func (hc *halfConn) incSeq(isOutgoing bool) {
 	limit := 0
@@ -1114,11 +1121,16 @@
 		}
 	case typeNewSessionTicket:
 		m = new(newSessionTicketMsg)
+	case typeEncryptedExtensions:
+		m = new(encryptedExtensionsMsg)
 	case typeCertificate:
-		m = new(certificateMsg)
+		m = &certificateMsg{
+			hasRequestContext: c.vers >= VersionTLS13 && enableTLS13Handshake,
+		}
 	case typeCertificateRequest:
 		m = &certificateRequestMsg{
 			hasSignatureAlgorithm: c.vers >= VersionTLS12,
+			hasRequestContext:     c.vers >= VersionTLS13 && enableTLS13Handshake,
 		}
 	case typeCertificateStatus:
 		m = new(certificateStatusMsg)
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index a2f6f65..56c49d9 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -26,6 +26,7 @@
 	hello         *clientHelloMsg
 	suite         *cipherSuite
 	finishedHash  finishedHash
+	keyShares     map[CurveID]ecdhCurve
 	masterSecret  []byte
 	session       *ClientSessionState
 	finishedBytes []byte
@@ -102,6 +103,31 @@
 		hello.secureRenegotiation = nil
 	}
 
+	var keyShares map[CurveID]ecdhCurve
+	if hello.vers >= VersionTLS13 && enableTLS13Handshake {
+		// Offer every supported curve in the initial ClientHello.
+		//
+		// TODO(davidben): For real code, default to a more conservative
+		// set like P-256 and X25519. Make it configurable for tests to
+		// stress the HelloRetryRequest logic when implemented.
+		keyShares = make(map[CurveID]ecdhCurve)
+		for _, curveID := range hello.supportedCurves {
+			curve, ok := curveForCurveID(curveID)
+			if !ok {
+				continue
+			}
+			publicKey, err := curve.offer(c.config.rand())
+			if err != nil {
+				return err
+			}
+			hello.keyShares = append(hello.keyShares, keyShareEntry{
+				group:       curveID,
+				keyExchange: publicKey,
+			})
+			keyShares[curveID] = curve
+		}
+	}
+
 	possibleCipherSuites := c.config.cipherSuites()
 	hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites))
 
@@ -261,6 +287,7 @@
 		}
 	}
 
+	// TODO(davidben): Handle HelloRetryRequest.
 	serverHello, ok := msg.(*serverHelloMsg)
 	if !ok {
 		c.sendAlert(alertUnexpectedMessage)
@@ -286,82 +313,90 @@
 		hello:        hello,
 		suite:        suite,
 		finishedHash: newFinishedHash(c.vers, suite),
+		keyShares:    keyShares,
 		session:      session,
 	}
 
 	hs.writeHash(helloBytes, hs.c.sendHandshakeSeq-1)
 	hs.writeServerHash(hs.serverHello.marshal())
 
-	if c.config.Bugs.EarlyChangeCipherSpec > 0 {
-		hs.establishKeys()
-		c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
-	}
-
-	if hs.serverHello.compressionMethod != compressionNone {
-		c.sendAlert(alertUnexpectedMessage)
-		return errors.New("tls: server selected unsupported compression format")
-	}
-
-	err = hs.processServerExtensions(&serverHello.extensions)
-	if err != nil {
-		return err
-	}
-
-	isResume, err := hs.processServerHello()
-	if err != nil {
-		return err
-	}
-
-	if isResume {
-		if c.config.Bugs.EarlyChangeCipherSpec == 0 {
-			if err := hs.establishKeys(); err != nil {
-				return err
-			}
-		}
-		if err := hs.readSessionTicket(); err != nil {
-			return err
-		}
-		if err := hs.readFinished(c.firstFinished[:]); err != nil {
-			return err
-		}
-		if err := hs.sendFinished(nil, isResume); err != nil {
+	if c.vers >= VersionTLS13 && enableTLS13Handshake {
+		if err := hs.doTLS13Handshake(); err != nil {
 			return err
 		}
 	} else {
-		if err := hs.doFullHandshake(); err != nil {
+		if c.config.Bugs.EarlyChangeCipherSpec > 0 {
+			hs.establishKeys()
+			c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
+		}
+
+		if hs.serverHello.compressionMethod != compressionNone {
+			c.sendAlert(alertUnexpectedMessage)
+			return errors.New("tls: server selected unsupported compression format")
+		}
+
+		err = hs.processServerExtensions(&serverHello.extensions)
+		if err != nil {
 			return err
 		}
-		if err := hs.establishKeys(); err != nil {
+
+		isResume, err := hs.processServerHello()
+		if err != nil {
 			return err
 		}
-		if err := hs.sendFinished(c.firstFinished[:], isResume); err != nil {
-			return err
+
+		if isResume {
+			if c.config.Bugs.EarlyChangeCipherSpec == 0 {
+				if err := hs.establishKeys(); err != nil {
+					return err
+				}
+			}
+			if err := hs.readSessionTicket(); err != nil {
+				return err
+			}
+			if err := hs.readFinished(c.firstFinished[:]); err != nil {
+				return err
+			}
+			if err := hs.sendFinished(nil, isResume); err != nil {
+				return err
+			}
+		} else {
+			if err := hs.doFullHandshake(); err != nil {
+				return err
+			}
+			if err := hs.establishKeys(); err != nil {
+				return err
+			}
+			if err := hs.sendFinished(c.firstFinished[:], isResume); err != nil {
+				return err
+			}
+			// Most retransmits are triggered by a timeout, but the final
+			// leg of the handshake is retransmited upon re-receiving a
+			// Finished.
+			if err := c.simulatePacketLoss(func() {
+				c.writeRecord(recordTypeHandshake, hs.finishedBytes)
+				c.flushHandshake()
+			}); err != nil {
+				return err
+			}
+			if err := hs.readSessionTicket(); err != nil {
+				return err
+			}
+			if err := hs.readFinished(nil); err != nil {
+				return err
+			}
 		}
-		// Most retransmits are triggered by a timeout, but the final
-		// leg of the handshake is retransmited upon re-receiving a
-		// Finished.
-		if err := c.simulatePacketLoss(func() {
-			c.writeRecord(recordTypeHandshake, hs.finishedBytes)
-			c.flushHandshake()
-		}); err != nil {
-			return err
+
+		if sessionCache != nil && hs.session != nil && session != hs.session {
+			if c.config.Bugs.RequireSessionTickets && len(hs.session.sessionTicket) == 0 {
+				return errors.New("tls: new session used session IDs instead of tickets")
+			}
+			sessionCache.Put(cacheKey, hs.session)
 		}
-		if err := hs.readSessionTicket(); err != nil {
-			return err
-		}
-		if err := hs.readFinished(nil); err != nil {
-			return err
-		}
+
+		c.didResume = isResume
 	}
 
-	if sessionCache != nil && hs.session != nil && session != hs.session {
-		if c.config.Bugs.RequireSessionTickets && len(hs.session.sessionTicket) == 0 {
-			return errors.New("tls: new session used session IDs instead of tickets")
-		}
-		sessionCache.Put(cacheKey, hs.session)
-	}
-
-	c.didResume = isResume
 	c.handshakeComplete = true
 	c.cipherSuite = suite
 	copy(c.clientRandom[:], hs.hello.random)
@@ -371,6 +406,213 @@
 	return nil
 }
 
+// doTLS13Handshake completes a TLS 1.3 handshake once the ServerHello has
+// been received and hashed. It derives the handshake and application
+// traffic keys, processes the server's encrypted flight, and sends the
+// client Finished. Client authentication is not yet supported.
+func (hs *clientHandshakeState) doTLS13Handshake() error {
+	c := hs.c
+
+	// 0-RTT is not implemented yet, so the client does not offer early
+	// data and must reject an unsolicited early_data acceptance.
+	if hs.serverHello.earlyDataIndication {
+		c.sendAlert(alertUnsupportedExtension)
+		return errors.New("tls: server sent unexpected early data indication")
+	}
+
+	// Once the PRF hash is known, TLS 1.3 does not require a handshake
+	// buffer.
+	hs.finishedHash.discardHandshakeBuffer()
+
+	zeroSecret := hs.finishedHash.zeroSecret()
+
+	// Resolve PSK and compute the early secret.
+	//
+	// TODO(davidben): This will need to be handled slightly earlier once
+	// 0-RTT is implemented.
+	var psk []byte
+	if hs.suite.flags&suitePSK != 0 {
+		if !hs.serverHello.hasPSKIdentity {
+			c.sendAlert(alertMissingExtension)
+			return errors.New("tls: server omitted the PSK identity extension")
+		}
+
+		// TODO(davidben): Support PSK ciphers and PSK resumption. Set
+		// the resumption context appropriately if resuming.
+		return errors.New("tls: PSK ciphers not implemented for TLS 1.3")
+	} else {
+		if hs.serverHello.hasPSKIdentity {
+			c.sendAlert(alertUnsupportedExtension)
+			return errors.New("tls: server sent unexpected PSK identity")
+		}
+
+		psk = zeroSecret
+		hs.finishedHash.setResumptionContext(zeroSecret)
+	}
+
+	earlySecret := hs.finishedHash.extractKey(zeroSecret, psk)
+
+	// Resolve ECDHE and compute the handshake secret.
+	var ecdheSecret []byte
+	if hs.suite.flags&suiteECDHE != 0 {
+		if !hs.serverHello.hasKeyShare {
+			c.sendAlert(alertMissingExtension)
+			return errors.New("tls: server omitted the key share extension")
+		}
+
+		curve, ok := hs.keyShares[hs.serverHello.keyShare.group]
+		if !ok {
+			c.sendAlert(alertHandshakeFailure)
+			return errors.New("tls: server selected an unsupported group")
+		}
+
+		var err error
+		ecdheSecret, err = curve.finish(hs.serverHello.keyShare.keyExchange)
+		if err != nil {
+			return err
+		}
+	} else {
+		if hs.serverHello.hasKeyShare {
+			c.sendAlert(alertUnsupportedExtension)
+			return errors.New("tls: server sent unexpected key share extension")
+		}
+
+		ecdheSecret = zeroSecret
+	}
+
+	// Compute the handshake secret.
+	handshakeSecret := hs.finishedHash.extractKey(earlySecret, ecdheSecret)
+
+	// Switch to handshake traffic keys.
+	handshakeTrafficSecret := hs.finishedHash.deriveSecret(handshakeSecret, handshakeTrafficLabel)
+	c.out.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite), c.vers)
+	c.in.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite), c.vers)
+
+	msg, err := c.readHandshake()
+	if err != nil {
+		return err
+	}
+
+	encryptedExtensions, ok := msg.(*encryptedExtensionsMsg)
+	if !ok {
+		c.sendAlert(alertUnexpectedMessage)
+		return unexpectedMessageError(encryptedExtensions, msg)
+	}
+	hs.writeServerHash(encryptedExtensions.marshal())
+
+	err = hs.processServerExtensions(&encryptedExtensions.extensions)
+	if err != nil {
+		return err
+	}
+
+	var chainToSend *Certificate
+	var certRequested bool
+	var certRequestContext []byte
+	if hs.suite.flags&suitePSK == 0 {
+		// TODO(davidben): Save OCSP response and SCT list. Forbid them
+		// if not negotiating a certificate-based extension.
+
+		msg, err := c.readHandshake()
+		if err != nil {
+			return err
+		}
+
+		certReq, ok := msg.(*certificateRequestMsg)
+		if ok {
+			hs.writeServerHash(certReq.marshal())
+			certRequested = true
+			certRequestContext = certReq.requestContext
+
+			chainToSend, err = selectClientCertificate(c, certReq)
+			if err != nil {
+				return err
+			}
+
+			msg, err = c.readHandshake()
+			if err != nil {
+				return err
+			}
+		}
+
+		certMsg, ok := msg.(*certificateMsg)
+		if !ok {
+			c.sendAlert(alertUnexpectedMessage)
+			return unexpectedMessageError(certMsg, msg)
+		}
+		hs.writeServerHash(certMsg.marshal())
+
+		if err := hs.verifyCertificates(certMsg); err != nil {
+			return err
+		}
+		leaf := c.peerCertificates[0]
+
+		msg, err = c.readHandshake()
+		if err != nil {
+			return err
+		}
+		certVerifyMsg, ok := msg.(*certificateVerifyMsg)
+		if !ok {
+			c.sendAlert(alertUnexpectedMessage)
+			return unexpectedMessageError(certVerifyMsg, msg)
+		}
+
+		input := hs.finishedHash.certificateVerifyInput(serverCertificateVerifyContextTLS13)
+		err = verifyMessage(c.vers, leaf.PublicKey, certVerifyMsg.signatureAlgorithm, input, certVerifyMsg.signature)
+		if err != nil {
+			return err
+		}
+
+		hs.writeServerHash(certVerifyMsg.marshal())
+	}
+
+	msg, err = c.readHandshake()
+	if err != nil {
+		return err
+	}
+	serverFinished, ok := msg.(*finishedMsg)
+	if !ok {
+		c.sendAlert(alertUnexpectedMessage)
+		return unexpectedMessageError(serverFinished, msg)
+	}
+
+	verify := hs.finishedHash.serverSum(handshakeTrafficSecret)
+	if len(verify) != len(serverFinished.verifyData) ||
+		subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
+		c.sendAlert(alertHandshakeFailure)
+		return errors.New("tls: server's Finished message was incorrect")
+	}
+
+	hs.writeServerHash(serverFinished.marshal())
+
+	// The various secrets do not incorporate the client's final leg, so
+	// derive them now before updating the handshake context.
+	masterSecret := hs.finishedHash.extractKey(handshakeSecret, zeroSecret)
+	trafficSecret := hs.finishedHash.deriveSecret(masterSecret, applicationTrafficLabel)
+
+	if certRequested {
+		_ = chainToSend
+		_ = certRequestContext
+		return errors.New("tls: client auth not implemented.")
+	}
+
+	// Send a client Finished message.
+	finished := new(finishedMsg)
+	finished.verifyData = hs.finishedHash.clientSum(handshakeTrafficSecret)
+	if c.config.Bugs.BadFinished {
+		finished.verifyData[0]++
+	}
+	c.writeRecord(recordTypeHandshake, finished.marshal())
+
+	// Switch to application data keys.
+	c.out.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, trafficSecret, applicationPhase, clientWrite), c.vers)
+	c.in.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, trafficSecret, applicationPhase, serverWrite), c.vers)
+
+	// TODO(davidben): Derive and save the exporter master secret for key exporters. Swap out the masterSecret field.
+	// TODO(davidben): Derive and save the resumption master secret for receiving tickets.
+	// TODO(davidben): Save the traffic secret for KeyUpdate.
+	return nil
+}
+
 func (hs *clientHandshakeState) doFullHandshake() error {
 	c := hs.c
 
@@ -633,17 +864,19 @@
 func (hs *clientHandshakeState) processServerExtensions(serverExtensions *serverExtensions) error {
 	c := hs.c
 
-	if c.config.Bugs.RequireRenegotiationInfo && serverExtensions.secureRenegotiation == nil {
-		return errors.New("tls: renegotiation extension missing")
-	}
+	if c.vers < VersionTLS13 || !enableTLS13Handshake {
+		if c.config.Bugs.RequireRenegotiationInfo && serverExtensions.secureRenegotiation == nil {
+			return errors.New("tls: renegotiation extension missing")
+		}
 
-	if len(c.clientVerify) > 0 && !c.noRenegotiationInfo() {
-		var expectedRenegInfo []byte
-		expectedRenegInfo = append(expectedRenegInfo, c.clientVerify...)
-		expectedRenegInfo = append(expectedRenegInfo, c.serverVerify...)
-		if !bytes.Equal(serverExtensions.secureRenegotiation, expectedRenegInfo) {
-			c.sendAlert(alertHandshakeFailure)
-			return fmt.Errorf("tls: renegotiation mismatch")
+		if len(c.clientVerify) > 0 && !c.noRenegotiationInfo() {
+			var expectedRenegInfo []byte
+			expectedRenegInfo = append(expectedRenegInfo, c.clientVerify...)
+			expectedRenegInfo = append(expectedRenegInfo, c.serverVerify...)
+			if !bytes.Equal(serverExtensions.secureRenegotiation, expectedRenegInfo) {
+				c.sendAlert(alertHandshakeFailure)
+				return fmt.Errorf("tls: renegotiation mismatch")
+			}
 		}
 	}
 
@@ -679,11 +912,21 @@
 		c.usedALPN = true
 	}
 
+	if serverHasNPN && c.vers >= VersionTLS13 && enableTLS13Handshake {
+		c.sendAlert(alertHandshakeFailure)
+		return errors.New("server advertised NPN over TLS 1.3")
+	}
+
 	if !hs.hello.channelIDSupported && serverExtensions.channelIDRequested {
 		c.sendAlert(alertHandshakeFailure)
 		return errors.New("server advertised unrequested Channel ID extension")
 	}
 
+	if serverExtensions.channelIDRequested && c.vers >= VersionTLS13 && enableTLS13Handshake {
+		c.sendAlert(alertHandshakeFailure)
+		return errors.New("server advertised Channel ID over TLS 1.3")
+	}
+
 	if serverExtensions.srtpProtectionProfile != 0 {
 		if serverExtensions.srtpMasterKeyIdentifier != "" {
 			return errors.New("tls: server selected SRTP MKI value")
@@ -960,12 +1203,14 @@
 	// the contrary.
 
 	var rsaAvail, ecdsaAvail bool
-	for _, certType := range certReq.certificateTypes {
-		switch certType {
-		case CertTypeRSASign:
-			rsaAvail = true
-		case CertTypeECDSASign:
-			ecdsaAvail = true
+	if !certReq.hasRequestContext {
+		for _, certType := range certReq.certificateTypes {
+			switch certType {
+			case CertTypeRSASign:
+				rsaAvail = true
+			case CertTypeECDSASign:
+				ecdsaAvail = true
+			}
 		}
 	}
 
@@ -974,7 +1219,7 @@
 	// certReq.certificateAuthorities
 findCert:
 	for i, chain := range c.config.Certificates {
-		if !rsaAvail && !ecdsaAvail {
+		if !certReq.hasRequestContext && !rsaAvail && !ecdsaAvail {
 			continue
 		}
 
@@ -998,11 +1243,13 @@
 				}
 			}
 
-			switch {
-			case rsaAvail && x509Cert.PublicKeyAlgorithm == x509.RSA:
-			case ecdsaAvail && x509Cert.PublicKeyAlgorithm == x509.ECDSA:
-			default:
-				continue findCert
+			if !certReq.hasRequestContext {
+				switch {
+				case rsaAvail && x509Cert.PublicKeyAlgorithm == x509.RSA:
+				case ecdsaAvail && x509Cert.PublicKeyAlgorithm == x509.ECDSA:
+				default:
+					continue findCert
+				}
 			}
 
 			if len(certReq.certificateAuthorities) == 0 {
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index 17fb5cb..d349ae7 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -111,11 +111,13 @@
 }
 
 type clientHelloMsg struct {
-	raw                     []byte
-	isDTLS                  bool
-	vers                    uint16
-	random                  []byte
-	sessionId               []byte
+	raw       []byte
+	isDTLS    bool
+	vers      uint16
+	random    []byte
+	sessionId []byte
+	// TODO(davidben): Add support for TLS 1.3 cookies which are larger and
+	// use an extension.
 	cookie                  []byte
 	cipherSuites            []uint16
 	compressionMethods      []uint8
@@ -681,14 +683,19 @@
 }
 
 type serverHelloMsg struct {
-	raw               []byte
-	isDTLS            bool
-	vers              uint16
-	random            []byte
-	sessionId         []byte
-	cipherSuite       uint16
-	compressionMethod uint8
-	extensions        serverExtensions
+	raw                 []byte
+	isDTLS              bool
+	vers                uint16
+	random              []byte
+	sessionId           []byte
+	cipherSuite         uint16
+	hasKeyShare         bool
+	keyShare            keyShareEntry
+	hasPSKIdentity      bool
+	pskIdentity         uint16
+	earlyDataIndication bool
+	compressionMethod   uint8
+	extensions          serverExtensions
 }
 
 func (m *serverHelloMsg) marshal() []byte {
@@ -702,17 +709,39 @@
 	vers := versionToWire(m.vers, m.isDTLS)
 	hello.addU16(vers)
 	hello.addBytes(m.random)
-	sessionId := hello.addU8LengthPrefixed()
-	sessionId.addBytes(m.sessionId)
+	if m.vers < VersionTLS13 || !enableTLS13Handshake {
+		sessionId := hello.addU8LengthPrefixed()
+		sessionId.addBytes(m.sessionId)
+	}
 	hello.addU16(m.cipherSuite)
-	hello.addU8(m.compressionMethod)
+	if m.vers < VersionTLS13 || !enableTLS13Handshake {
+		hello.addU8(m.compressionMethod)
+	}
 
 	extensions := hello.addU16LengthPrefixed()
 
-	m.extensions.marshal(extensions)
-
-	if extensions.len() == 0 {
-		hello.discardChild()
+	if m.vers >= VersionTLS13 && enableTLS13Handshake {
+		if m.hasKeyShare {
+			extensions.addU16(extensionKeyShare)
+			keyShare := extensions.addU16LengthPrefixed()
+			keyShare.addU16(uint16(m.keyShare.group))
+			keyExchange := keyShare.addU16LengthPrefixed()
+			keyExchange.addBytes(m.keyShare.keyExchange)
+		}
+		if m.hasPSKIdentity {
+			extensions.addU16(extensionPreSharedKey)
+			extensions.addU16(2) // Length
+			extensions.addU16(m.pskIdentity)
+		}
+		if m.earlyDataIndication {
+			extensions.addU16(extensionEarlyData)
+			extensions.addU16(0) // Length
+		}
+	} else {
+		m.extensions.marshal(extensions)
+		if extensions.len() == 0 {
+			hello.discardChild()
+		}
 	}
 
 	m.raw = handshakeMsg.finish()
@@ -726,21 +755,30 @@
 	m.raw = data
 	m.vers = wireToVersion(uint16(data[4])<<8|uint16(data[5]), m.isDTLS)
 	m.random = data[6:38]
-	sessionIdLen := int(data[38])
-	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
-		return false
+	data = data[38:]
+	if m.vers < VersionTLS13 || !enableTLS13Handshake {
+		sessionIdLen := int(data[0])
+		if sessionIdLen > 32 || len(data) < 1+sessionIdLen {
+			return false
+		}
+		m.sessionId = data[1 : 1+sessionIdLen]
+		data = data[1+sessionIdLen:]
 	}
-	m.sessionId = data[39 : 39+sessionIdLen]
-	data = data[39+sessionIdLen:]
-	if len(data) < 3 {
+	if len(data) < 2 {
 		return false
 	}
 	m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
-	m.compressionMethod = data[2]
-	data = data[3:]
+	data = data[2:]
+	if m.vers < VersionTLS13 || !enableTLS13Handshake {
+		if len(data) < 1 {
+			return false
+		}
+		m.compressionMethod = data[0]
+		data = data[1:]
+	}
 
-	if len(data) == 0 {
-		// ServerHello is optionally followed by extension data
+	if len(data) == 0 && (m.vers < VersionTLS13 || !enableTLS13Handshake) {
+		// Extension data is optional in the pre-TLS 1.3 wire format.
 		m.extensions = serverExtensions{}
 		return true
 	}
@@ -754,13 +792,103 @@
 		return false
 	}
 
-	if !m.extensions.unmarshal(data) {
+	if m.vers >= VersionTLS13 && enableTLS13Handshake {
+		for len(data) != 0 {
+			if len(data) < 4 {
+				return false
+			}
+			extension := uint16(data[0])<<8 | uint16(data[1])
+			length := int(data[2])<<8 | int(data[3])
+			data = data[4:]
+
+			if len(data) < length {
+				return false
+			}
+			d := data[:length]
+			data = data[length:]
+
+			switch extension {
+			case extensionKeyShare:
+				m.hasKeyShare = true
+				if len(d) < 4 {
+					return false
+				}
+				m.keyShare.group = CurveID(uint16(d[0])<<8 | uint16(d[1]))
+				keyExchLen := int(d[2])<<8 | int(d[3])
+				if keyExchLen != len(d)-4 {
+					return false
+				}
+				m.keyShare.keyExchange = make([]byte, keyExchLen)
+				copy(m.keyShare.keyExchange, d[4:])
+			case extensionPreSharedKey:
+				if len(d) != 2 {
+					return false
+				}
+				m.pskIdentity = uint16(d[0])<<8 | uint16(d[1])
+				m.hasPSKIdentity = true
+			case extensionEarlyData:
+				if len(d) != 0 {
+					return false
+				}
+				m.earlyDataIndication = true
+			default:
+				// Only allow the 3 extensions that are sent in
+				// the clear in TLS 1.3.
+				return false
+			}
+		}
+	} else if !m.extensions.unmarshal(data) {
 		return false
 	}
 
 	return true
 }
 
+// encryptedExtensionsMsg is the TLS 1.3 EncryptedExtensions handshake
+// message. Its body is a single u16-length-prefixed list of server extensions.
+type encryptedExtensionsMsg struct {
+	raw        []byte
+	extensions serverExtensions
+}
+
+func (m *encryptedExtensionsMsg) marshal() []byte {
+	if m.raw != nil {
+		return m.raw
+	}
+
+	encryptedExtensionsMsg := newByteBuilder()
+	encryptedExtensionsMsg.addU8(typeEncryptedExtensions)
+	encryptedExtensions := encryptedExtensionsMsg.addU24LengthPrefixed()
+	extensions := encryptedExtensions.addU16LengthPrefixed()
+	m.extensions.marshal(extensions)
+
+	m.raw = encryptedExtensionsMsg.finish()
+	return m.raw
+}
+
+func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
+	// Keep the received bytes so that marshal, which feeds the handshake
+	// transcript, returns them verbatim rather than a re-encoding.
+	m.raw = data
+	if len(data) < 6 {
+		return false
+	}
+	if data[0] != typeEncryptedExtensions {
+		return false
+	}
+	msgLen := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
+	data = data[4:]
+	if len(data) != msgLen {
+		return false
+	}
+	extLen := int(data[0])<<8 | int(data[1])
+	data = data[2:]
+	if extLen != len(data) {
+		return false
+	}
+	return m.extensions.unmarshal(data)
+}
+
 type serverExtensions struct {
 	nextProtoNeg            bool
 	nextProtos              []string
@@ -962,8 +1085,10 @@
 }
 
 type certificateMsg struct {
-	raw          []byte
-	certificates [][]byte
+	raw               []byte
+	hasRequestContext bool
+	requestContext    []byte
+	certificates      [][]byte
 }
 
 func (m *certificateMsg) marshal() (x []byte) {
@@ -974,6 +1099,10 @@
 	certMsg := newByteBuilder()
 	certMsg.addU8(typeCertificate)
 	certificate := certMsg.addU24LengthPrefixed()
+	if m.hasRequestContext {
+		context := certificate.addU8LengthPrefixed()
+		context.addBytes(m.requestContext)
+	}
 	certificateList := certificate.addU24LengthPrefixed()
 	for _, cert := range m.certificates {
 		certEntry := certificateList.addU24LengthPrefixed()
@@ -985,24 +1114,43 @@
 }
 
 func (m *certificateMsg) unmarshal(data []byte) bool {
-	if len(data) < 7 {
+	if len(data) < 4 {
 		return false
 	}
 
 	m.raw = data
-	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
-	if uint32(len(data)) != certsLen+7 {
+	data = data[4:]
+
+	if m.hasRequestContext {
+		if len(data) == 0 {
+			return false
+		}
+		contextLen := int(data[0])
+		if len(data) < 1+contextLen {
+			return false
+		}
+		m.requestContext = make([]byte, contextLen)
+		copy(m.requestContext, data[1:])
+		data = data[1+contextLen:]
+	}
+
+	if len(data) < 3 {
+		return false
+	}
+	certsLen := int(data[0])<<16 | int(data[1])<<8 | int(data[2])
+	data = data[3:]
+	if len(data) != certsLen {
 		return false
 	}
 
 	numCerts := 0
-	d := data[7:]
+	d := data
 	for certsLen > 0 {
 		if len(d) < 4 {
 			return false
 		}
-		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
-		if uint32(len(d)) < 3+certLen {
+		certLen := int(d[0])<<16 | int(d[1])<<8 | int(d[2])
+		if len(d) < 3+certLen {
 			return false
 		}
 		d = d[3+certLen:]
@@ -1011,7 +1159,7 @@
 	}
 
 	m.certificates = make([][]byte, numCerts)
-	d = data[7:]
+	d = data
 	for i := 0; i < numCerts; i++ {
 		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
 		m.certificates[i] = d[3 : 3+certLen]
@@ -1245,8 +1393,13 @@
 	// of signature and hash functions. This change was introduced with TLS
 	// 1.2.
 	hasSignatureAlgorithm bool
+	// hasRequestContext indicates whether this message includes a context
+	// field instead of certificateTypes. This change was introduced with
+	// TLS 1.3.
+	hasRequestContext bool
 
 	certificateTypes       []byte
+	requestContext         []byte
 	signatureAlgorithms    []signatureAlgorithm
 	certificateAuthorities [][]byte
 }
@@ -1261,8 +1414,13 @@
 	builder.addU8(typeCertificateRequest)
 	body := builder.addU24LengthPrefixed()
 
-	certificateTypes := body.addU8LengthPrefixed()
-	certificateTypes.addBytes(m.certificateTypes)
+	if m.hasRequestContext {
+		requestContext := body.addU8LengthPrefixed()
+		requestContext.addBytes(m.requestContext)
+	} else {
+		certificateTypes := body.addU8LengthPrefixed()
+		certificateTypes.addBytes(m.certificateTypes)
+	}
 
 	if m.hasSignatureAlgorithm {
 		signatureAlgorithms := body.addU16LengthPrefixed()
@@ -1287,25 +1445,26 @@
 	if len(data) < 5 {
 		return false
 	}
+	data = data[4:]
 
-	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
-	if uint32(len(data))-4 != length {
-		return false
+	if m.hasRequestContext {
+		contextLen := int(data[0])
+		if len(data) < 1+contextLen {
+			return false
+		}
+		m.requestContext = make([]byte, contextLen)
+		copy(m.requestContext, data[1:])
+		data = data[1+contextLen:]
+	} else {
+		numCertTypes := int(data[0])
+		if len(data) < 1+numCertTypes {
+			return false
+		}
+		m.certificateTypes = make([]byte, numCertTypes)
+		copy(m.certificateTypes, data[1:])
+		data = data[1+numCertTypes:]
 	}
 
-	numCertTypes := int(data[4])
-	data = data[5:]
-	if numCertTypes == 0 || len(data) <= numCertTypes {
-		return false
-	}
-
-	m.certificateTypes = make([]byte, numCertTypes)
-	if copy(m.certificateTypes, data) != numCertTypes {
-		return false
-	}
-
-	data = data[numCertTypes:]
-
 	if m.hasSignatureAlgorithm {
 		if len(data) < 2 {
 			return false
diff --git a/ssl/test/runner/prf.go b/ssl/test/runner/prf.go
index a91a319..0fd5762 100644
--- a/ssl/test/runner/prf.go
+++ b/ssl/test/runner/prf.go
@@ -5,11 +5,11 @@
 package runner
 
 import (
+	"crypto"
 	"crypto/hmac"
 	"crypto/md5"
 	"crypto/sha1"
 	"crypto/sha256"
-	"crypto/sha512"
 	"hash"
 )
 
@@ -133,13 +133,11 @@
 	// Once we no longer support Fake TLS 1.3, the VersionTLS13 should be
 	// removed from this case statement.
 	case VersionTLS12, VersionTLS13:
-		if suite.flags&suiteSHA384 != 0 {
-			return prf12(sha512.New384)
+		if version == VersionTLS12 || !enableTLS13Handshake {
+			return prf12(suite.hash().New)
 		}
-		return prf12(sha256.New)
-	default:
-		panic("unknown version")
 	}
+	panic("unknown version")
 }
 
 // masterFromPreMasterSecret generates the master secret from the pre-master
@@ -188,20 +186,38 @@
 }
 
 func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash {
-	if version >= VersionTLS12 {
-		newHash := sha256.New
-		if cipherSuite.flags&suiteSHA384 != 0 {
-			newHash = sha512.New384
-		}
+	var ret finishedHash
 
-		return finishedHash{newHash(), newHash(), nil, nil, []byte{}, version, prf12(newHash)}
+	if version >= VersionTLS12 {
+		ret.hash = cipherSuite.hash()
+
+		ret.client = ret.hash.New()
+		ret.server = ret.hash.New()
+
+		if version == VersionTLS12 || !enableTLS13Handshake {
+			ret.prf = prf12(ret.hash.New)
+		}
+	} else {
+		ret.hash = crypto.MD5SHA1
+
+		ret.client = sha1.New()
+		ret.server = sha1.New()
+		ret.clientMD5 = md5.New()
+		ret.serverMD5 = md5.New()
+
+		ret.prf = prf10
 	}
-	return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), []byte{}, version, prf10}
+
+	ret.buffer = []byte{}
+	ret.version = version
+	return ret
 }
 
 // A finishedHash calculates the hash of a set of handshake messages suitable
 // for including in a Finished message.
 type finishedHash struct {
+	hash crypto.Hash
+
 	client hash.Hash
 	server hash.Hash
 
@@ -213,6 +229,10 @@
 	// full buffer is required.
 	buffer []byte
 
+	// TLS 1.3 has a resumption context which is carried over on PSK
+	// resumption.
+	resumptionContextHash []byte
+
 	version uint16
 	prf     func(result, secret, label, seed []byte)
 }
@@ -280,26 +300,40 @@
 
 // clientSum returns the contents of the verify_data member of a client's
 // Finished message.
-func (h finishedHash) clientSum(masterSecret []byte) []byte {
+func (h finishedHash) clientSum(baseKey []byte) []byte {
 	if h.version == VersionSSL30 {
-		return finishedSum30(h.clientMD5, h.client, masterSecret, ssl3ClientFinishedMagic[:])
+		return finishedSum30(h.clientMD5, h.client, baseKey, ssl3ClientFinishedMagic[:])
 	}
 
-	out := make([]byte, finishedVerifyLength)
-	h.prf(out, masterSecret, clientFinishedLabel, h.Sum())
-	return out
+	if h.version < VersionTLS13 || !enableTLS13Handshake { // TLS 1.0-1.2 and fake TLS 1.3: PRF-based verify_data
+		out := make([]byte, finishedVerifyLength)
+		h.prf(out, baseKey, clientFinishedLabel, h.Sum())
+		return out
+	}
+
+	clientFinishedKey := hkdfExpandLabel(h.hash, baseKey, clientFinishedLabel, nil, h.hash.Size()) // TLS 1.3: HMAC over the context hashes
+	finishedHMAC := hmac.New(h.hash.New, clientFinishedKey)
+	finishedHMAC.Write(h.appendContextHashes(nil))
+	return finishedHMAC.Sum(nil)
 }
 
 // serverSum returns the contents of the verify_data member of a server's
 // Finished message.
-func (h finishedHash) serverSum(masterSecret []byte) []byte {
+func (h finishedHash) serverSum(baseKey []byte) []byte {
 	if h.version == VersionSSL30 {
-		return finishedSum30(h.serverMD5, h.server, masterSecret, ssl3ServerFinishedMagic[:])
+		return finishedSum30(h.serverMD5, h.server, baseKey, ssl3ServerFinishedMagic[:])
 	}
 
-	out := make([]byte, finishedVerifyLength)
-	h.prf(out, masterSecret, serverFinishedLabel, h.Sum())
-	return out
+	if h.version < VersionTLS13 || !enableTLS13Handshake { // TLS 1.0-1.2 and fake TLS 1.3: PRF-based verify_data
+		out := make([]byte, finishedVerifyLength)
+		h.prf(out, baseKey, serverFinishedLabel, h.Sum())
+		return out
+	}
+
+	serverFinishedKey := hkdfExpandLabel(h.hash, baseKey, serverFinishedLabel, nil, h.hash.Size()) // TLS 1.3: HMAC over the context hashes
+	finishedHMAC := hmac.New(h.hash.New, serverFinishedKey)
+	finishedHMAC.Write(h.appendContextHashes(nil))
+	return finishedHMAC.Sum(nil)
 }
 
 // hashForClientCertificateSSL3 returns the hash to be signed for client
@@ -331,3 +365,125 @@
 func (h *finishedHash) discardHandshakeBuffer() {
 	h.buffer = nil
 }
+
+// zeroSecret returns the default all-zeros secret for TLS 1.3, used when a
+// given secret is not available in the handshake. See draft-ietf-tls-tls13-13,
+// section 7.1.
+func (h *finishedHash) zeroSecret() []byte {
+	return make([]byte, h.hash.Size()) // one hash-output length of zero bytes
+}
+
+// setResumptionContext stores the hash of the TLS 1.3 resumption context.
+func (h *finishedHash) setResumptionContext(resumptionContext []byte) {
+	hash := h.hash.New()
+	hash.Write(resumptionContext)
+	h.resumptionContextHash = hash.Sum(nil) // deriveSecret panics until this is set
+}
+
+// extractKey combines two secrets together with HKDF-Extract in the TLS 1.3 key
+// derivation schedule: salt is the HKDF salt and ikm the input keying material.
+func (h *finishedHash) extractKey(salt, ikm []byte) []byte {
+	return hkdfExtract(h.hash.New, salt, ikm)
+}
+
+// hkdfExpandLabel implements TLS 1.3's HKDF-Expand-Label function, as defined
+// in section 7.1 of draft-ietf-tls-tls13-13. It panics if an input cannot be encoded.
+func hkdfExpandLabel(hash crypto.Hash, secret, label, hashValue []byte, length int) []byte {
+	if 9+len(label) > 255 || len(hashValue) > 255 || length < 0 || length > 0xffff {
+		panic("hkdfExpandLabel: label, hashValue or length out of range")
+	}
+	hkdfLabel := make([]byte, 3+9+len(label)+1+len(hashValue))
+	x := hkdfLabel
+	x[0] = byte(length >> 8) // uint16 length, big-endian
+	x[1] = byte(length)
+	x[2] = byte(9 + len(label)) // 9 == len("TLS 1.3, "); bounded by the check above
+	x = x[3:]
+	copy(x, "TLS 1.3, ")
+	x = x[9:]
+	copy(x, label)
+	x = x[len(label):]
+	x[0] = byte(len(hashValue))
+	copy(x[1:], hashValue)
+	return hkdfExpand(hash.New, secret, hkdfLabel, length)
+}
+
+// appendContextHashes appends the current handshake hash followed by the
+// resumption context hash to b and returns the result, as used in TLS 1.3.
+func (h *finishedHash) appendContextHashes(b []byte) []byte {
+	b = h.client.Sum(b) // Sum appends without resetting the running hash
+	b = append(b, h.resumptionContextHash...)
+	return b
+}
+
+// The following are labels for secret derivation in the TLS 1.3 key schedule.
+var (
+	earlyTrafficLabel       = []byte("early traffic secret")
+	handshakeTrafficLabel   = []byte("handshake traffic secret")
+	applicationTrafficLabel = []byte("application traffic secret")
+	exporterLabel           = []byte("exporter master secret")
+	resumptionLabel         = []byte("resumption master secret")
+)
+
+// deriveSecret implements TLS 1.3's Derive-Secret function, as defined in
+// section 7.1 of draft-ietf-tls-tls13-13.
+func (h *finishedHash) deriveSecret(secret, label []byte) []byte {
+	if h.resumptionContextHash == nil { // setResumptionContext must be called first
+		panic("Resumption context not set.")
+	}
+
+	return hkdfExpandLabel(h.hash, secret, label, h.appendContextHashes(nil), h.hash.Size())
+}
+
+// The following are context strings for the TLS 1.3 CertificateVerify signature input.
+var (
+	clientCertificateVerifyContextTLS13 = []byte("TLS 1.3, client CertificateVerify")
+	serverCertificateVerifyContextTLS13 = []byte("TLS 1.3, server CertificateVerify")
+)
+
+// certificateVerifyInput returns the input to be signed for CertificateVerify
+// in TLS 1.3: 64 spaces, the context string, a zero byte, then the context hashes.
+func (h *finishedHash) certificateVerifyInput(context []byte) []byte {
+	const paddingLen = 64
+	b := make([]byte, paddingLen, paddingLen+len(context)+1+2*h.hash.Size())
+	for i := 0; i < paddingLen; i++ {
+		b[i] = 32 // 0x20, ASCII space
+	}
+	b = append(b, context...)
+	b = append(b, 0) // separates the context string from the hashes
+	b = h.appendContextHashes(b)
+	return b
+}
+
+// The following are phase values (key/IV label prefixes) for TLS 1.3 traffic key derivation.
+var (
+	earlyHandshakePhase   = []byte("early handshake key expansion")
+	earlyApplicationPhase = []byte("early application data key expansion")
+	handshakePhase        = []byte("handshake key expansion")
+	applicationPhase      = []byte("application data key expansion")
+)
+
+type trafficDirection int // selects the "client write" or "server write" labels in deriveTrafficAEAD
+
+const (
+	clientWrite trafficDirection = iota
+	serverWrite
+)
+
+// deriveTrafficAEAD derives traffic keys and constructs an AEAD given a traffic
+// secret, using phase as the prefix of both the key and IV labels.
+func deriveTrafficAEAD(version uint16, suite *cipherSuite, secret, phase []byte, side trafficDirection) *tlsAead {
+	label := make([]byte, 0, len(phase)+2+16) // len(", client write key") == 2+16
+	label = append(label, phase...)
+	if side == clientWrite {
+		label = append(label, []byte(", client write key")...)
+	} else {
+		label = append(label, []byte(", server write key")...)
+	}
+	key := hkdfExpandLabel(suite.hash(), secret, label, nil, suite.keyLen)
+
+	label = label[:len(label)-3] // Remove "key" from the end.
+	label = append(label, []byte("iv")...)
+	iv := hkdfExpandLabel(suite.hash(), secret, label, nil, suite.ivLen(version))
+
+	return suite.aead(version, key, iv)
+}