Simplify BoGo's TLS 1.3 key derivation.

finishedHash should keep a running secret and incorporate entropy as is
available.

Change-Id: I2d245897e7520b2317bc0051fa4d821c32eeaa10
Reviewed-on: https://boringssl-review.googlesource.com/12586
Reviewed-by: Nick Harper <nharper@chromium.org>
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index e138f22..7fa7ea2 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -644,7 +644,6 @@
 	//
 	// TODO(davidben): This will need to be handled slightly earlier once
 	// 0-RTT is implemented.
-	var psk []byte
 	if hs.serverHello.hasPSKIdentity {
 		// We send at most one PSK identity.
 		if hs.session == nil || hs.serverHello.pskIdentity != 0 {
@@ -656,21 +655,18 @@
 			c.sendAlert(alertHandshakeFailure)
 			return errors.New("tls: server resumed an invalid session for the cipher suite")
 		}
-		psk = hs.session.masterSecret
+		hs.finishedHash.addEntropy(hs.session.masterSecret)
 		c.didResume = true
 	} else {
-		psk = zeroSecret
+		hs.finishedHash.addEntropy(zeroSecret)
 	}
 
-	earlySecret := hs.finishedHash.extractKey(zeroSecret, psk)
-
 	if !hs.serverHello.hasKeyShare {
 		c.sendAlert(alertUnsupportedExtension)
 		return errors.New("tls: server omitted KeyShare on resumption.")
 	}
 
 	// Resolve ECDHE and compute the handshake secret.
-	var ecdheSecret []byte
 	if !c.config.Bugs.MissingKeyShare && !c.config.Bugs.SecondClientHelloMissingKeyShare {
 		curve, ok := hs.keyShares[hs.serverHello.keyShare.group]
 		if !ok {
@@ -679,22 +675,19 @@
 		}
 		c.curveID = hs.serverHello.keyShare.group
 
-		var err error
-		ecdheSecret, err = curve.finish(hs.serverHello.keyShare.keyExchange)
+		ecdheSecret, err := curve.finish(hs.serverHello.keyShare.keyExchange)
 		if err != nil {
 			return err
 		}
+		hs.finishedHash.addEntropy(ecdheSecret)
 	} else {
-		ecdheSecret = zeroSecret
+		hs.finishedHash.addEntropy(zeroSecret)
 	}
 
-	// Compute the handshake secret.
-	handshakeSecret := hs.finishedHash.extractKey(earlySecret, ecdheSecret)
-
 	// Switch to handshake traffic keys.
-	clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(handshakeSecret, clientHandshakeTrafficLabel)
+	clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(clientHandshakeTrafficLabel)
 	c.out.useTrafficSecret(c.vers, hs.suite, clientHandshakeTrafficSecret, clientWrite)
-	serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(handshakeSecret, serverHandshakeTrafficLabel)
+	serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(serverHandshakeTrafficLabel)
 	c.in.useTrafficSecret(c.vers, hs.suite, serverHandshakeTrafficSecret, serverWrite)
 
 	msg, err := c.readHandshake()
@@ -822,9 +815,9 @@
 
 	// The various secrets do not incorporate the client's final leg, so
 	// derive them now before updating the handshake context.
-	masterSecret := hs.finishedHash.extractKey(handshakeSecret, zeroSecret)
-	clientTrafficSecret := hs.finishedHash.deriveSecret(masterSecret, clientApplicationTrafficLabel)
-	serverTrafficSecret := hs.finishedHash.deriveSecret(masterSecret, serverApplicationTrafficLabel)
+	hs.finishedHash.addEntropy(zeroSecret)
+	clientTrafficSecret := hs.finishedHash.deriveSecret(clientApplicationTrafficLabel)
+	serverTrafficSecret := hs.finishedHash.deriveSecret(serverApplicationTrafficLabel)
 
 	if certReq != nil && !c.config.Bugs.SkipClientCertificate {
 		certMsg := &certificateMsg{
@@ -911,8 +904,8 @@
 	c.out.useTrafficSecret(c.vers, hs.suite, clientTrafficSecret, clientWrite)
 	c.in.useTrafficSecret(c.vers, hs.suite, serverTrafficSecret, serverWrite)
 
-	c.exporterSecret = hs.finishedHash.deriveSecret(masterSecret, exporterLabel)
-	c.resumptionSecret = hs.finishedHash.deriveSecret(masterSecret, resumptionLabel)
+	c.exporterSecret = hs.finishedHash.deriveSecret(exporterLabel)
+	c.resumptionSecret = hs.finishedHash.deriveSecret(resumptionLabel)
 	return nil
 }
 
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 67950ba..57566c5 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -501,15 +501,12 @@
 	}
 
 	// Resolve PSK and compute the early secret.
-	var psk []byte
 	if hs.sessionState != nil {
-		psk = hs.sessionState.masterSecret
+		hs.finishedHash.addEntropy(hs.sessionState.masterSecret)
 	} else {
-		psk = hs.finishedHash.zeroSecret()
+		hs.finishedHash.addEntropy(hs.finishedHash.zeroSecret())
 	}
 
-	earlySecret := hs.finishedHash.extractKey(hs.finishedHash.zeroSecret(), psk)
-
 	hs.hello.hasKeyShare = true
 	if hs.sessionState != nil && config.Bugs.NegotiatePSKResumption {
 		hs.hello.hasKeyShare = false
@@ -647,7 +644,6 @@
 	}
 
 	// Resolve ECDHE and compute the handshake secret.
-	var ecdheSecret []byte
 	if hs.hello.hasKeyShare {
 		// Once a curve has been selected and a key share identified,
 		// the server needs to generate a public value and send it in
@@ -672,13 +668,12 @@
 			peerKey = selectedKeyShare.keyExchange
 		}
 
-		var publicKey []byte
-		var err error
-		publicKey, ecdheSecret, err = curve.accept(config.rand(), peerKey)
+		publicKey, ecdheSecret, err := curve.accept(config.rand(), peerKey)
 		if err != nil {
 			c.sendAlert(alertHandshakeFailure)
 			return err
 		}
+		hs.finishedHash.addEntropy(ecdheSecret)
 		hs.hello.hasKeyShare = true
 
 		curveID := selectedCurve
@@ -702,7 +697,7 @@
 			}
 		}
 	} else {
-		ecdheSecret = hs.finishedHash.zeroSecret()
+		hs.finishedHash.addEntropy(hs.finishedHash.zeroSecret())
 	}
 
 	// Send unencrypted ServerHello.
@@ -718,13 +713,10 @@
 	}
 	c.flushHandshake()
 
-	// Compute the handshake secret.
-	handshakeSecret := hs.finishedHash.extractKey(earlySecret, ecdheSecret)
-
 	// Switch to handshake traffic keys.
-	serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(handshakeSecret, serverHandshakeTrafficLabel)
+	serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(serverHandshakeTrafficLabel)
 	c.out.useTrafficSecret(c.vers, hs.suite, serverHandshakeTrafficSecret, serverWrite)
-	clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(handshakeSecret, clientHandshakeTrafficLabel)
+	clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(clientHandshakeTrafficLabel)
 	c.in.useTrafficSecret(c.vers, hs.suite, clientHandshakeTrafficSecret, clientWrite)
 
 	// Send EncryptedExtensions.
@@ -842,10 +834,10 @@
 
 	// The various secrets do not incorporate the client's final leg, so
 	// derive them now before updating the handshake context.
-	masterSecret := hs.finishedHash.extractKey(handshakeSecret, hs.finishedHash.zeroSecret())
-	clientTrafficSecret := hs.finishedHash.deriveSecret(masterSecret, clientApplicationTrafficLabel)
-	serverTrafficSecret := hs.finishedHash.deriveSecret(masterSecret, serverApplicationTrafficLabel)
-	c.exporterSecret = hs.finishedHash.deriveSecret(masterSecret, exporterLabel)
+	hs.finishedHash.addEntropy(hs.finishedHash.zeroSecret())
+	clientTrafficSecret := hs.finishedHash.deriveSecret(clientApplicationTrafficLabel)
+	serverTrafficSecret := hs.finishedHash.deriveSecret(serverApplicationTrafficLabel)
+	c.exporterSecret = hs.finishedHash.deriveSecret(exporterLabel)
 
 	// Switch to application data keys on write. In particular, any alerts
 	// from the client certificate are sent over these keys.
@@ -956,7 +948,7 @@
 	c.in.useTrafficSecret(c.vers, hs.suite, clientTrafficSecret, clientWrite)
 
 	c.cipherSuite = hs.suite
-	c.resumptionSecret = hs.finishedHash.deriveSecret(masterSecret, resumptionLabel)
+	c.resumptionSecret = hs.finishedHash.deriveSecret(resumptionLabel)
 
 	// TODO(davidben): Allow configuring the number of tickets sent for
 	// testing.
diff --git a/ssl/test/runner/prf.go b/ssl/test/runner/prf.go
index c311e99..50f37df 100644
--- a/ssl/test/runner/prf.go
+++ b/ssl/test/runner/prf.go
@@ -197,6 +197,8 @@
 
 		if version == VersionTLS12 {
 			ret.prf = prf12(ret.hash.New)
+		} else {
+			ret.secret = make([]byte, ret.hash.Size())
 		}
 	} else {
 		ret.hash = crypto.MD5SHA1
@@ -232,6 +234,9 @@
 
 	version uint16
 	prf     func(result, secret, label, seed []byte)
+
+	// secret, in TLS 1.3, is the running input secret.
+	secret []byte
 }
 
 func (h *finishedHash) Write(msg []byte) (n int, err error) {
@@ -370,10 +375,9 @@
 	return make([]byte, h.hash.Size())
 }
 
-// extractKey combines two secrets together with HKDF-Expand in the TLS 1.3 key
-// derivation schedule.
-func (h *finishedHash) extractKey(salt, ikm []byte) []byte {
-	return hkdfExtract(h.hash.New, salt, ikm)
+// addEntropy incorporates ikm into the running TLS 1.3 secret with HKDF-Expand.
+func (h *finishedHash) addEntropy(ikm []byte) {
+	h.secret = hkdfExtract(h.hash.New, h.secret, ikm)
 }
 
 // hkdfExpandLabel implements TLS 1.3's HKDF-Expand-Label function, as defined
@@ -420,8 +424,8 @@
 
 // deriveSecret implements TLS 1.3's Derive-Secret function, as defined in
 // section 7.1 of draft ietf-tls-tls13-16.
-func (h *finishedHash) deriveSecret(secret, label []byte) []byte {
-	return hkdfExpandLabel(h.hash, secret, label, h.appendContextHashes(nil), h.hash.Size())
+func (h *finishedHash) deriveSecret(label []byte) []byte {
+	return hkdfExpandLabel(h.hash, h.secret, label, h.appendContextHashes(nil), h.hash.Size())
 }
 
 // The following are context strings for CertificateVerify in TLS 1.3.
@@ -472,8 +476,8 @@
 
 func computePSKBinder(psk, label []byte, cipherSuite *cipherSuite, transcript, truncatedHello []byte) []byte {
 	finishedHash := newFinishedHash(VersionTLS13, cipherSuite)
-	earlySecret := finishedHash.extractKey(finishedHash.zeroSecret(), psk)
-	binderKey := finishedHash.deriveSecret(earlySecret, label)
+	finishedHash.addEntropy(psk)
+	binderKey := finishedHash.deriveSecret(label)
 	finishedHash.Write(transcript)
 	finishedHash.Write(truncatedHello)
 	return finishedHash.clientSum(binderKey)