Implement basic TLS 1.3 client handshake in Go.

[Originally written by nharper and then revised by davidben.]

Most features are missing, but it works for a start. To avoid breaking
the fake TLS 1.3 tests while the C implementation has not yet landed,
all of the new logic is gated on a global boolean, enableTLS13Handshake.
Once the C code lands, we'll set it to true and then remove the boolean.

Change-Id: I6b3a369890864c26203fc9cda37c8250024ce91b
Reviewed-on: https://boringssl-review.googlesource.com/8601
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/test/runner/prf.go b/ssl/test/runner/prf.go
index a91a319..0fd5762 100644
--- a/ssl/test/runner/prf.go
+++ b/ssl/test/runner/prf.go
@@ -5,11 +5,11 @@
 package runner
 
 import (
+	"crypto"
 	"crypto/hmac"
 	"crypto/md5"
 	"crypto/sha1"
 	"crypto/sha256"
-	"crypto/sha512"
 	"hash"
 )
 
@@ -133,13 +133,11 @@
 	// Once we no longer support Fake TLS 1.3, the VersionTLS13 should be
 	// removed from this case statement.
 	case VersionTLS12, VersionTLS13:
-		if suite.flags&suiteSHA384 != 0 {
-			return prf12(sha512.New384)
+		if version == VersionTLS12 || !enableTLS13Handshake {
+			return prf12(suite.hash().New)
 		}
-		return prf12(sha256.New)
-	default:
-		panic("unknown version")
 	}
+	panic("unknown version")
 }
 
 // masterFromPreMasterSecret generates the master secret from the pre-master
@@ -188,20 +186,38 @@
 }
 
 func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash {
-	if version >= VersionTLS12 {
-		newHash := sha256.New
-		if cipherSuite.flags&suiteSHA384 != 0 {
-			newHash = sha512.New384
-		}
+	var ret finishedHash // zero value; populated per version below
 
-		return finishedHash{newHash(), newHash(), nil, nil, []byte{}, version, prf12(newHash)}
+	if version >= VersionTLS12 {
+		ret.hash = cipherSuite.hash()
+		// TLS 1.2+: client and server running hashes both use the suite's hash.
+		ret.client = ret.hash.New()
+		ret.server = ret.hash.New()
+		// With enableTLS13Handshake, TLS 1.3 leaves prf nil; its Finished uses HKDF.
+		if version == VersionTLS12 || !enableTLS13Handshake {
+			ret.prf = prf12(ret.hash.New)
+		}
+	} else {
+		ret.hash = crypto.MD5SHA1
+		// Pre-TLS 1.2: each side keeps both a SHA-1 and an MD5 running hash.
+		ret.client = sha1.New()
+		ret.server = sha1.New()
+		ret.clientMD5 = md5.New()
+		ret.serverMD5 = md5.New()
+		// prf10 backs TLS Finished; SSL 3.0 Finished uses finishedSum30 instead.
+		ret.prf = prf10
 	}
-	return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), []byte{}, version, prf10}
+	// Fields shared by every version.
+	ret.buffer = []byte{}
+	ret.version = version
+	return ret
 }
 
 // A finishedHash calculates the hash of a set of handshake messages suitable
 // for including in a Finished message.
 type finishedHash struct {
+	hash crypto.Hash
+
 	client hash.Hash
 	server hash.Hash
 
@@ -213,6 +229,10 @@
 	// full buffer is required.
 	buffer []byte
 
+	// TLS 1.3 has a resumption context which is carried over on PSK
+	// resumption.
+	resumptionContextHash []byte
+
 	version uint16
 	prf     func(result, secret, label, seed []byte)
 }
@@ -280,26 +300,40 @@
 
 // clientSum returns the contents of the verify_data member of a client's
 // Finished message.
-func (h finishedHash) clientSum(masterSecret []byte) []byte {
+func (h finishedHash) clientSum(baseKey []byte) []byte {
 	if h.version == VersionSSL30 {
-		return finishedSum30(h.clientMD5, h.client, masterSecret, ssl3ClientFinishedMagic[:])
+		return finishedSum30(h.clientMD5, h.client, baseKey, ssl3ClientFinishedMagic[:])
 	}
 
-	out := make([]byte, finishedVerifyLength)
-	h.prf(out, masterSecret, clientFinishedLabel, h.Sum())
-	return out
+	if h.version < VersionTLS13 || !enableTLS13Handshake { // Pre-1.3 (or fake TLS 1.3): baseKey is the master secret.
+		out := make([]byte, finishedVerifyLength)
+		h.prf(out, baseKey, clientFinishedLabel, h.Sum())
+		return out
+	}
+	// TLS 1.3: HMAC of the context hashes under a finished key expanded from baseKey.
+	clientFinishedKey := hkdfExpandLabel(h.hash, baseKey, clientFinishedLabel, nil, h.hash.Size())
+	finishedHMAC := hmac.New(h.hash.New, clientFinishedKey)
+	finishedHMAC.Write(h.appendContextHashes(nil))
+	return finishedHMAC.Sum(nil)
 }
 
 // serverSum returns the contents of the verify_data member of a server's
 // Finished message.
-func (h finishedHash) serverSum(masterSecret []byte) []byte {
+func (h finishedHash) serverSum(baseKey []byte) []byte {
 	if h.version == VersionSSL30 {
-		return finishedSum30(h.serverMD5, h.server, masterSecret, ssl3ServerFinishedMagic[:])
+		return finishedSum30(h.serverMD5, h.server, baseKey, ssl3ServerFinishedMagic[:])
 	}
 
-	out := make([]byte, finishedVerifyLength)
-	h.prf(out, masterSecret, serverFinishedLabel, h.Sum())
-	return out
+	if h.version < VersionTLS13 || !enableTLS13Handshake { // Pre-1.3 (or fake TLS 1.3): baseKey is the master secret.
+		out := make([]byte, finishedVerifyLength)
+		h.prf(out, baseKey, serverFinishedLabel, h.Sum())
+		return out
+	}
+	// TLS 1.3: HMAC of the context hashes under a finished key expanded from baseKey.
+	serverFinishedKey := hkdfExpandLabel(h.hash, baseKey, serverFinishedLabel, nil, h.hash.Size())
+	finishedHMAC := hmac.New(h.hash.New, serverFinishedKey)
+	finishedHMAC.Write(h.appendContextHashes(nil))
+	return finishedHMAC.Sum(nil)
 }
 
 // hashForClientCertificateSSL3 returns the hash to be signed for client
@@ -331,3 +365,141 @@
 func (h *finishedHash) discardHandshakeBuffer() {
 	h.buffer = nil
 }
+
+// zeroSecret returns the default all-zeros secret for TLS 1.3, used when a
+// given secret is not available in the handshake. Its length is the output
+// size of the handshake hash. See draft-ietf-tls-tls13-13, section 7.1.
+func (h *finishedHash) zeroSecret() []byte {
+	return make([]byte, h.hash.Size())
+}
+
+// setResumptionContext sets the TLS 1.3 resumption context by storing its
+// hash, which appendContextHashes appends after the handshake transcript hash.
+func (h *finishedHash) setResumptionContext(resumptionContext []byte) {
+	hash := h.hash.New()
+	hash.Write(resumptionContext)
+	h.resumptionContextHash = hash.Sum(nil)
+}
+
+// extractKey combines two secrets with HKDF-Extract in the TLS 1.3 key
+// derivation schedule, using salt as the HKDF salt and ikm as the input keying
+// material.
+func (h *finishedHash) extractKey(salt, ikm []byte) []byte {
+	return hkdfExtract(h.hash.New, salt, ikm)
+}
+
+// hkdfExpandLabel implements TLS 1.3's HKDF-Expand-Label function, as defined
+// in section 7.1 of draft-ietf-tls-tls13-13. It serializes length as a 2-byte
+// big-endian integer, followed by "TLS 1.3, "+label and hashValue, each with a
+// 1-byte length prefix, and passes the result to hkdfExpand along with secret.
+//
+// It panics if length does not fit in 16 bits or if either length-prefixed
+// field exceeds 255 bytes (note the label field includes the prefix).
+func hkdfExpandLabel(hash crypto.Hash, secret, label, hashValue []byte, length int) []byte {
+	const labelPrefix = "TLS 1.3, "
+	if length < 0 || length > 0xffff {
+		panic("hkdfExpandLabel: length out of range")
+	}
+	if len(labelPrefix)+len(label) > 255 || len(hashValue) > 255 {
+		panic("hkdfExpandLabel: label or hashValue too long")
+	}
+
+	hkdfLabel := make([]byte, 0, 2+1+len(labelPrefix)+len(label)+1+len(hashValue))
+	hkdfLabel = append(hkdfLabel, byte(length>>8), byte(length))
+	hkdfLabel = append(hkdfLabel, byte(len(labelPrefix)+len(label)))
+	hkdfLabel = append(hkdfLabel, labelPrefix...)
+	hkdfLabel = append(hkdfLabel, label...)
+	hkdfLabel = append(hkdfLabel, byte(len(hashValue)))
+	hkdfLabel = append(hkdfLabel, hashValue...)
+	return hkdfExpand(hash.New, secret, hkdfLabel, length)
+}
+
+// appendContextHashes appends the current handshake transcript hash (from the
+// running client-side hash) followed by the resumption context hash to b, and
+// returns the extended slice. This is the hash context used in TLS 1.3.
+func (h *finishedHash) appendContextHashes(b []byte) []byte {
+	b = h.client.Sum(b)
+	b = append(b, h.resumptionContextHash...)
+	return b
+}
+
+// The following are labels for traffic secret derivation in TLS 1.3.
+var (
+	earlyTrafficLabel       = []byte("early traffic secret")
+	handshakeTrafficLabel   = []byte("handshake traffic secret")
+	applicationTrafficLabel = []byte("application traffic secret")
+	exporterLabel           = []byte("exporter master secret")
+	resumptionLabel         = []byte("resumption master secret")
+)
+
+// deriveSecret implements TLS 1.3's Derive-Secret function, as defined in
+// section 7.1 of draft-ietf-tls-tls13-13: it expands secret under label, with
+// the context hashes as the hash value, to the handshake hash's output size.
+// It panics if the resumption context has not been set.
+func (h *finishedHash) deriveSecret(secret, label []byte) []byte {
+	if h.resumptionContextHash == nil {
+		panic("Resumption context not set.")
+	}
+
+	return hkdfExpandLabel(h.hash, secret, label, h.appendContextHashes(nil), h.hash.Size())
+}
+
+// The following are context strings for CertificateVerify in TLS 1.3.
+var (
+	clientCertificateVerifyContextTLS13 = []byte("TLS 1.3, client CertificateVerify")
+	serverCertificateVerifyContextTLS13 = []byte("TLS 1.3, server CertificateVerify")
+)
+
+// certificateVerifyInput returns the input to be signed for CertificateVerify
+// in TLS 1.3: 64 bytes of 0x20 (space) padding, context, a zero byte, and then
+// the context hashes from appendContextHashes.
+func (h *finishedHash) certificateVerifyInput(context []byte) []byte {
+	const paddingLen = 64
+	b := make([]byte, paddingLen, paddingLen+len(context)+1+2*h.hash.Size())
+	for i := range b {
+		b[i] = 0x20
+	}
+	b = append(b, context...)
+	b = append(b, 0)
+	b = h.appendContextHashes(b)
+	return b
+}
+
+// The following are phase values for traffic key derivation in TLS 1.3.
+var (
+	earlyHandshakePhase   = []byte("early handshake key expansion")
+	earlyApplicationPhase = []byte("early application data key expansion")
+	handshakePhase        = []byte("handshake key expansion")
+	applicationPhase      = []byte("application data key expansion")
+)
+
+// trafficDirection selects whose write keys deriveTrafficAEAD derives.
+type trafficDirection int
+
+const (
+	clientWrite trafficDirection = iota
+	serverWrite
+)
+
+// deriveTrafficAEAD derives the write key and IV for one direction from a TLS
+// 1.3 traffic secret and constructs the AEAD from them. The HKDF-Expand-Label
+// labels are phase followed by ", client write key"/", client write iv" when
+// side is clientWrite, and the ", server write ..." equivalents otherwise.
+func deriveTrafficAEAD(version uint16, suite *cipherSuite, secret, phase []byte, side trafficDirection) *tlsAead {
+	sideLabel := ", server write "
+	if side == clientWrite {
+		sideLabel = ", client write "
+	}
+	// Each call builds a fresh label, so the key and IV labels never share
+	// backing storage.
+	label := func(suffix string) []byte {
+		l := make([]byte, 0, len(phase)+len(sideLabel)+len(suffix))
+		l = append(l, phase...)
+		l = append(l, sideLabel...)
+		return append(l, suffix...)
+	}
+
+	key := hkdfExpandLabel(suite.hash(), secret, label("key"), nil, suite.keyLen)
+	iv := hkdfExpandLabel(suite.hash(), secret, label("iv"), nil, suite.ivLen(version))
+	return suite.aead(version, key, iv)
+}