Implement KeyUpdate in Go.

Implemented in preparation for testing the C implementation. Tested
against itself.

BUG=74

Change-Id: Iec1b9ad22e09711fa4e67c97cc3eb257585c3ae5
Reviewed-on: https://boringssl-review.googlesource.com/8873
Reviewed-by: Nick Harper <nharper@chromium.org>
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index fd95781..1014dea 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -68,6 +68,7 @@
 	typeClientKeyExchange   uint8 = 16
 	typeFinished            uint8 = 20
 	typeCertificateStatus   uint8 = 22
+	typeKeyUpdate           uint8 = 24  // draft-ietf-tls-tls13-13
 	typeNextProtocol        uint8 = 67  // Not IANA assigned
 	typeChannelID           uint8 = 203 // Not IANA assigned
 )
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index 969c50a..cefdde3 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -159,6 +159,9 @@
 	// used to save allocating a new buffer for each MAC.
 	inDigestBuf, outDigestBuf []byte
 
+	trafficSecret       []byte
+	keyUpdateGeneration int
+
 	config *Config
 }
 
@@ -203,13 +206,23 @@
 	return nil
 }
 
-// updateKeys sets the current cipher state.
-func (hc *halfConn) updateKeys(cipher interface{}, version uint16) {
+// useTrafficSecret sets the current cipher state for TLS 1.3.
+func (hc *halfConn) useTrafficSecret(version uint16, suite *cipherSuite, secret, phase []byte, side trafficDirection) {
 	hc.version = version
-	hc.cipher = cipher
+	hc.cipher = deriveTrafficAEAD(version, suite, secret, phase, side)
+	hc.trafficSecret = secret
 	hc.incEpoch()
 }
 
+func (hc *halfConn) doKeyUpdate(c *Conn, isOutgoing bool) {
+	side := serverWrite
+	if c.isClient == isOutgoing {
+		side = clientWrite
+	}
+	hc.useTrafficSecret(hc.version, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), hc.trafficSecret), applicationPhase, side)
+	hc.keyUpdateGeneration++
+}
+
 // incSeq increments the sequence number.
 func (hc *halfConn) incSeq(isOutgoing bool) {
 	limit := 0
@@ -1175,6 +1188,8 @@
 		m = new(helloVerifyRequestMsg)
 	case typeChannelID:
 		m = new(channelIDMsg)
+	case typeKeyUpdate:
+		m = new(keyUpdateMsg)
 	default:
 		return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
 	}
@@ -1280,6 +1295,13 @@
 		return 0, alertInternalError
 	}
 
+	// Catch up with KeyUpdates from the peer.
+	for c.out.keyUpdateGeneration < c.in.keyUpdateGeneration {
+		if err := c.sendKeyUpdateLocked(); err != nil {
+			return 0, err
+		}
+	}
+
 	if c.config.Bugs.SendSpuriousAlert != 0 {
 		c.sendAlertLocked(alertLevelError, c.config.Bugs.SendSpuriousAlert)
 	}
@@ -1357,6 +1379,11 @@
 		}
 	}
 
+	if _, ok := msg.(*keyUpdateMsg); ok {
+		c.in.doKeyUpdate(c, true)
+		return nil
+	}
+
 	// TODO(davidben): Add support for KeyUpdate.
 	c.sendAlert(alertUnexpectedMessage)
 	return alertUnexpectedMessage
@@ -1648,3 +1675,21 @@
 	_, err := c.writeRecord(recordTypeHandshake, m.marshal())
 	return err
 }
+
+func (c *Conn) SendKeyUpdate() error {
+	c.out.Lock()
+	defer c.out.Unlock()
+	return c.sendKeyUpdateLocked()
+}
+
+func (c *Conn) sendKeyUpdateLocked() error {
+	m := new(keyUpdateMsg)
+	if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil {
+		return err
+	}
+	if err := c.flushHandshake(); err != nil {
+		return err
+	}
+	c.out.doKeyUpdate(c, false)
+	return nil
+}
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index 7718447..b32be0e 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -594,8 +594,8 @@
 
 	// Switch to handshake traffic keys.
 	handshakeTrafficSecret := hs.finishedHash.deriveSecret(handshakeSecret, handshakeTrafficLabel)
-	c.out.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite), c.vers)
-	c.in.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite), c.vers)
+	c.out.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite)
+	c.in.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite)
 
 	msg, err := c.readHandshake()
 	if err != nil {
@@ -767,10 +767,9 @@
 	c.flushHandshake()
 
 	// Switch to application data keys.
-	c.out.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, trafficSecret, applicationPhase, clientWrite), c.vers)
-	c.in.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, trafficSecret, applicationPhase, serverWrite), c.vers)
+	c.out.useTrafficSecret(c.vers, hs.suite, trafficSecret, applicationPhase, clientWrite)
+	c.in.useTrafficSecret(c.vers, hs.suite, trafficSecret, applicationPhase, serverWrite)
 
-	// TODO(davidben): Save the traffic secret for KeyUpdate.
 	c.exporterSecret = hs.finishedHash.deriveSecret(masterSecret, exporterLabel)
 	c.resumptionSecret = hs.finishedHash.deriveSecret(masterSecret, resumptionLabel)
 	return nil
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index 41a8fb2..8e73a3c 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -1924,6 +1924,17 @@
 	return len(data) == 4
 }
 
+type keyUpdateMsg struct {
+}
+
+func (*keyUpdateMsg) marshal() []byte {
+	return []byte{typeKeyUpdate, 0, 0, 0}
+}
+
+func (*keyUpdateMsg) unmarshal(data []byte) bool {
+	return len(data) == 4
+}
+
 func eqUint16s(x, y []uint16) bool {
 	if len(x) != len(y) {
 		return false
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index a660f72..012c836 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -490,8 +490,8 @@
 
 	// Switch to handshake traffic keys.
 	handshakeTrafficSecret := hs.finishedHash.deriveSecret(handshakeSecret, handshakeTrafficLabel)
-	c.out.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite), c.vers)
-	c.in.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite), c.vers)
+	c.out.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite)
+	c.in.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite)
 
 	if hs.suite.flags&suitePSK != 0 {
 		return errors.New("tls: PSK ciphers not implemented for TLS 1.3")
@@ -591,7 +591,7 @@
 
 	// Switch to application data keys on write. In particular, any alerts
 	// from the client certificate are sent over these keys.
-	c.out.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, trafficSecret, applicationPhase, serverWrite), c.vers)
+	c.out.useTrafficSecret(c.vers, hs.suite, trafficSecret, applicationPhase, serverWrite)
 
 	// If we requested a client certificate, then the client must send a
 	// certificate message, even if it's empty.
@@ -664,9 +664,8 @@
 	hs.writeClientHash(clientFinished.marshal())
 
 	// Switch to application data keys on read.
-	c.in.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, trafficSecret, applicationPhase, clientWrite), c.vers)
+	c.in.useTrafficSecret(c.vers, hs.suite, trafficSecret, applicationPhase, clientWrite)
 
-	// TODO(davidben): Save the traffic secret for KeyUpdate.
 	c.cipherSuite = hs.suite
 	c.exporterSecret = hs.finishedHash.deriveSecret(masterSecret, exporterLabel)
 	c.resumptionSecret = hs.finishedHash.deriveSecret(masterSecret, resumptionLabel)
diff --git a/ssl/test/runner/prf.go b/ssl/test/runner/prf.go
index abee6e6..220aa44 100644
--- a/ssl/test/runner/prf.go
+++ b/ssl/test/runner/prf.go
@@ -493,3 +493,7 @@
 
 	return suite.aead(version, key, iv)
 }
+
+func updateTrafficSecret(hash crypto.Hash, secret []byte) []byte {
+	return hkdfExpandLabel(hash, secret, applicationTrafficLabel, nil, hash.Size())
+}