Update the key schedule and KeyUpdate to draft 16.

This does not yet send the KeyUpdate response required when the peer
requests one. That will be done in a follow-up.

BUG=74

Change-Id: I750fc41278736cb24230303815e839c6f6967b6a
Reviewed-on: https://boringssl-review.googlesource.com/11412
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Steven Valdez <svaldez@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index 8fe61a4..c03cedb 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -206,6 +206,12 @@
 	pskSignAuthMode = 1
 )
 
+// KeyUpdateRequest values (see draft-ietf-tls-tls13-16, section 4.5.3)
+const (
+	keyUpdateNotRequested = 0
+	keyUpdateRequested    = 1
+)
+
 // ConnectionState records basic TLS details about the connection.
 type ConnectionState struct {
 	Version                    uint16                // TLS version used by the connection (e.g. VersionTLS12)
@@ -956,10 +962,6 @@
 	// message. This only makes sense for a server.
 	SendHelloRequestBeforeEveryHandshakeMessage bool
 
-	// SendKeyUpdateBeforeEveryAppDataRecord, if true, causes a KeyUpdate
-	// handshake message to be sent before each application data record.
-	SendKeyUpdateBeforeEveryAppDataRecord bool
-
 	// RequireDHPublicValueLen causes a fatal error if the length (in
 	// bytes) of the server's Diffie-Hellman public value is not equal to
 	// this.
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index 24f0d60..f5014d4 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -92,6 +92,8 @@
 	handMsgLen       int      // handshake message length, not including the header
 	pendingFragments [][]byte // pending outgoing handshake fragments.
 
+	keyUpdateRequested bool
+
 	tmp [16]byte
 }
 
@@ -159,8 +161,7 @@
 	// used to save allocating a new buffer for each MAC.
 	inDigestBuf, outDigestBuf []byte
 
-	trafficSecret       []byte
-	keyUpdateGeneration int
+	trafficSecret []byte
 
 	config *Config
 }
@@ -223,7 +224,6 @@
 		side = clientWrite
 	}
 	hc.useTrafficSecret(hc.version, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), hc.trafficSecret), applicationPhase, side)
-	hc.keyUpdateGeneration++
 }
 
 // incSeq increments the sequence number.
@@ -1328,11 +1328,11 @@
 		return 0, alertInternalError
 	}
 
-	// Catch up with KeyUpdates from the peer.
-	for c.out.keyUpdateGeneration < c.in.keyUpdateGeneration {
-		if err := c.sendKeyUpdateLocked(); err != nil {
+	if c.keyUpdateRequested {
+		if err := c.sendKeyUpdateLocked(keyUpdateNotRequested); err != nil {
 			return 0, err
 		}
+		c.keyUpdateRequested = false
 	}
 
 	if c.config.Bugs.SendSpuriousAlert != 0 {
@@ -1344,12 +1344,6 @@
 		c.flushHandshake()
 	}
 
-	if c.config.Bugs.SendKeyUpdateBeforeEveryAppDataRecord {
-		if err := c.sendKeyUpdateLocked(); err != nil {
-			return 0, err
-		}
-	}
-
 	// SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext
 	// attack when using block mode ciphers due to predictable IVs.
 	// This can be prevented by splitting each Application Data
@@ -1441,8 +1435,11 @@
 		}
 	}
 
-	if _, ok := msg.(*keyUpdateMsg); ok {
+	if keyUpdate, ok := msg.(*keyUpdateMsg); ok {
 		c.in.doKeyUpdate(c, false)
+		if keyUpdate.keyUpdateRequest == keyUpdateRequested {
+			c.keyUpdateRequested = true
+		}
 		return nil
 	}
 
@@ -1751,18 +1748,20 @@
 	return err
 }
 
-func (c *Conn) SendKeyUpdate() error {
+func (c *Conn) SendKeyUpdate(keyUpdateRequest byte) error {
 	c.out.Lock()
 	defer c.out.Unlock()
-	return c.sendKeyUpdateLocked()
+	return c.sendKeyUpdateLocked(keyUpdateRequest)
 }
 
-func (c *Conn) sendKeyUpdateLocked() error {
+func (c *Conn) sendKeyUpdateLocked(keyUpdateRequest byte) error {
 	if c.vers < VersionTLS13 {
 		return errors.New("tls: attempted to send KeyUpdate before TLS 1.3")
 	}
 
-	m := new(keyUpdateMsg)
+	m := keyUpdateMsg{
+		keyUpdateRequest: keyUpdateRequest,
+	}
 	if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil {
 		return err
 	}
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index c5be2b7..291a3b4 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -644,9 +644,10 @@
 	handshakeSecret := hs.finishedHash.extractKey(earlySecret, ecdheSecret)
 
 	// Switch to handshake traffic keys.
-	handshakeTrafficSecret := hs.finishedHash.deriveSecret(handshakeSecret, handshakeTrafficLabel)
-	c.out.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite)
-	c.in.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite)
+	clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(handshakeSecret, clientHandshakeTrafficLabel)
+	c.out.useTrafficSecret(c.vers, hs.suite, clientHandshakeTrafficSecret, handshakePhase, clientWrite)
+	serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(handshakeSecret, serverHandshakeTrafficLabel)
+	c.in.useTrafficSecret(c.vers, hs.suite, serverHandshakeTrafficSecret, handshakePhase, serverWrite)
 
 	msg, err := c.readHandshake()
 	if err != nil {
@@ -756,7 +757,7 @@
 		return unexpectedMessageError(serverFinished, msg)
 	}
 
-	verify := hs.finishedHash.serverSum(handshakeTrafficSecret)
+	verify := hs.finishedHash.serverSum(serverHandshakeTrafficSecret)
 	if len(verify) != len(serverFinished.verifyData) ||
 		subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
 		c.sendAlert(alertHandshakeFailure)
@@ -768,7 +769,8 @@
 	// The various secrets do not incorporate the client's final leg, so
 	// derive them now before updating the handshake context.
 	masterSecret := hs.finishedHash.extractKey(handshakeSecret, zeroSecret)
-	trafficSecret := hs.finishedHash.deriveSecret(masterSecret, applicationTrafficLabel)
+	clientTrafficSecret := hs.finishedHash.deriveSecret(masterSecret, clientApplicationTrafficLabel)
+	serverTrafficSecret := hs.finishedHash.deriveSecret(masterSecret, serverApplicationTrafficLabel)
 
 	if certReq != nil && !c.config.Bugs.SkipClientCertificate {
 		certMsg := &certificateMsg{
@@ -813,7 +815,7 @@
 
 	// Send a client Finished message.
 	finished := new(finishedMsg)
-	finished.verifyData = hs.finishedHash.clientSum(handshakeTrafficSecret)
+	finished.verifyData = hs.finishedHash.clientSum(clientHandshakeTrafficSecret)
 	if c.config.Bugs.BadFinished {
 		finished.verifyData[0]++
 	}
@@ -830,8 +832,8 @@
 	c.flushHandshake()
 
 	// Switch to application data keys.
-	c.out.useTrafficSecret(c.vers, hs.suite, trafficSecret, applicationPhase, clientWrite)
-	c.in.useTrafficSecret(c.vers, hs.suite, trafficSecret, applicationPhase, serverWrite)
+	c.out.useTrafficSecret(c.vers, hs.suite, clientTrafficSecret, applicationPhase, clientWrite)
+	c.in.useTrafficSecret(c.vers, hs.suite, serverTrafficSecret, applicationPhase, serverWrite)
 
 	c.exporterSecret = hs.finishedHash.deriveSecret(masterSecret, exporterLabel)
 	c.resumptionSecret = hs.finishedHash.deriveSecret(masterSecret, resumptionLabel)
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index 1e91ac3..b92735e 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -2094,14 +2094,32 @@
 }
 
 type keyUpdateMsg struct {
+	raw              []byte
+	keyUpdateRequest byte
 }
 
-func (*keyUpdateMsg) marshal() []byte {
-	return []byte{typeKeyUpdate, 0, 0, 0}
+func (m *keyUpdateMsg) marshal() []byte {
+	if m.raw != nil {
+		return m.raw
+	}
+
+	return []byte{typeKeyUpdate, 0, 0, 1, m.keyUpdateRequest}
 }
 
-func (*keyUpdateMsg) unmarshal(data []byte) bool {
-	return len(data) == 4
+func (m *keyUpdateMsg) unmarshal(data []byte) bool {
+	m.raw = data
+
+	if len(data) != 5 {
+		return false
+	}
+
+	length := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
+	if len(data)-4 != length {
+		return false
+	}
+
+	m.keyUpdateRequest = data[4]
+	return m.keyUpdateRequest == keyUpdateNotRequested || m.keyUpdateRequest == keyUpdateRequested
 }
 
 func eqUint16s(x, y []uint16) bool {
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 59b34fa..7b26cb6 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -693,9 +693,10 @@
 	handshakeSecret := hs.finishedHash.extractKey(earlySecret, ecdheSecret)
 
 	// Switch to handshake traffic keys.
-	handshakeTrafficSecret := hs.finishedHash.deriveSecret(handshakeSecret, handshakeTrafficLabel)
-	c.out.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite)
-	c.in.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite)
+	serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(handshakeSecret, serverHandshakeTrafficLabel)
+	c.out.useTrafficSecret(c.vers, hs.suite, serverHandshakeTrafficSecret, handshakePhase, serverWrite)
+	clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(handshakeSecret, clientHandshakeTrafficLabel)
+	c.in.useTrafficSecret(c.vers, hs.suite, clientHandshakeTrafficSecret, handshakePhase, clientWrite)
 
 	if hs.hello.useCertAuth {
 		if hs.clientHello.ocspStapling {
@@ -793,7 +794,7 @@
 	}
 
 	finished := new(finishedMsg)
-	finished.verifyData = hs.finishedHash.serverSum(handshakeTrafficSecret)
+	finished.verifyData = hs.finishedHash.serverSum(serverHandshakeTrafficSecret)
 	if config.Bugs.BadFinished {
 		finished.verifyData[0]++
 	}
@@ -807,11 +808,12 @@
 	// The various secrets do not incorporate the client's final leg, so
 	// derive them now before updating the handshake context.
 	masterSecret := hs.finishedHash.extractKey(handshakeSecret, hs.finishedHash.zeroSecret())
-	trafficSecret := hs.finishedHash.deriveSecret(masterSecret, applicationTrafficLabel)
+	clientTrafficSecret := hs.finishedHash.deriveSecret(masterSecret, clientApplicationTrafficLabel)
+	serverTrafficSecret := hs.finishedHash.deriveSecret(masterSecret, serverApplicationTrafficLabel)
 
 	// Switch to application data keys on write. In particular, any alerts
 	// from the client certificate are sent over these keys.
-	c.out.useTrafficSecret(c.vers, hs.suite, trafficSecret, applicationPhase, serverWrite)
+	c.out.useTrafficSecret(c.vers, hs.suite, serverTrafficSecret, applicationPhase, serverWrite)
 
 	// If we requested a client certificate, then the client must send a
 	// certificate message, even if it's empty.
@@ -875,7 +877,7 @@
 		return unexpectedMessageError(clientFinished, msg)
 	}
 
-	verify := hs.finishedHash.clientSum(handshakeTrafficSecret)
+	verify := hs.finishedHash.clientSum(clientHandshakeTrafficSecret)
 	if len(verify) != len(clientFinished.verifyData) ||
 		subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
 		c.sendAlert(alertHandshakeFailure)
@@ -884,7 +886,7 @@
 	hs.writeClientHash(clientFinished.marshal())
 
 	// Switch to application data keys on read.
-	c.in.useTrafficSecret(c.vers, hs.suite, trafficSecret, applicationPhase, clientWrite)
+	c.in.useTrafficSecret(c.vers, hs.suite, clientTrafficSecret, applicationPhase, clientWrite)
 
 	c.cipherSuite = hs.suite
 	c.exporterSecret = hs.finishedHash.deriveSecret(masterSecret, exporterLabel)
diff --git a/ssl/test/runner/prf.go b/ssl/test/runner/prf.go
index 5c7b3ab..99ef64f 100644
--- a/ssl/test/runner/prf.go
+++ b/ssl/test/runner/prf.go
@@ -119,6 +119,7 @@
 var keyExpansionLabel = []byte("key expansion")
 var clientFinishedLabel = []byte("client finished")
 var serverFinishedLabel = []byte("server finished")
+var finishedLabel = []byte("finished")
 var channelIDLabel = []byte("TLS Channel ID signature\x00")
 var channelIDResumeLabel = []byte("Resumption\x00")
 
@@ -311,7 +312,7 @@
 		return out
 	}
 
-	clientFinishedKey := hkdfExpandLabel(h.hash, baseKey, clientFinishedLabel, nil, h.hash.Size())
+	clientFinishedKey := hkdfExpandLabel(h.hash, baseKey, finishedLabel, nil, h.hash.Size())
 	finishedHMAC := hmac.New(h.hash.New, clientFinishedKey)
 	finishedHMAC.Write(h.appendContextHashes(nil))
 	return finishedHMAC.Sum(nil)
@@ -330,7 +331,7 @@
 		return out
 	}
 
-	serverFinishedKey := hkdfExpandLabel(h.hash, baseKey, serverFinishedLabel, nil, h.hash.Size())
+	serverFinishedKey := hkdfExpandLabel(h.hash, baseKey, finishedLabel, nil, h.hash.Size())
 	finishedHMAC := hmac.New(h.hash.New, serverFinishedKey)
 	finishedHMAC.Write(h.appendContextHashes(nil))
 	return finishedHMAC.Sum(nil)
@@ -417,11 +418,14 @@
 
 // The following are labels for traffic secret derivation in TLS 1.3.
 var (
-	earlyTrafficLabel       = []byte("early traffic secret")
-	handshakeTrafficLabel   = []byte("handshake traffic secret")
-	applicationTrafficLabel = []byte("application traffic secret")
-	exporterLabel           = []byte("exporter master secret")
-	resumptionLabel         = []byte("resumption master secret")
+	earlyTrafficLabel             = []byte("client early traffic secret")
+	clientHandshakeTrafficLabel   = []byte("client handshake traffic secret")
+	serverHandshakeTrafficLabel   = []byte("server handshake traffic secret")
+	clientApplicationTrafficLabel = []byte("client application traffic secret")
+	serverApplicationTrafficLabel = []byte("server application traffic secret")
+	applicationTrafficLabel       = []byte("application traffic secret")
+	exporterLabel                 = []byte("exporter master secret")
+	resumptionLabel               = []byte("resumption master secret")
 )
 
 // deriveSecret implements TLS 1.3's Derive-Secret function, as defined in
@@ -474,11 +478,7 @@
 func deriveTrafficAEAD(version uint16, suite *cipherSuite, secret, phase []byte, side trafficDirection) interface{} {
 	label := make([]byte, 0, len(phase)+2+16)
 	label = append(label, phase...)
-	if side == clientWrite {
-		label = append(label, []byte(", client write key")...)
-	} else {
-		label = append(label, []byte(", server write key")...)
-	}
+	label = append(label, []byte(", key")...)
 	key := hkdfExpandLabel(suite.hash(), secret, label, nil, suite.keyLen)
 
 	label = label[:len(label)-3] // Remove "key" from the end.
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index fe2cf84..2256346 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -359,6 +359,8 @@
 	// sendKeyUpdates is the number of consecutive key updates to send
 	// before and after the test message.
 	sendKeyUpdates int
+	// keyUpdateRequest is the KeyUpdateRequest value to send in KeyUpdate messages.
+	keyUpdateRequest byte
 	// expectMessageDropped, if true, means the test message is expected to
 	// be dropped by the client rather than echoed back.
 	expectMessageDropped bool
@@ -616,7 +618,7 @@
 	}
 
 	for i := 0; i < test.sendKeyUpdates; i++ {
-		if err := tlsConn.SendKeyUpdate(); err != nil {
+		if err := tlsConn.SendKeyUpdate(test.keyUpdateRequest); err != nil {
 			return err
 		}
 	}
@@ -678,7 +680,7 @@
 		tlsConn.Write(testMessage)
 
 		for i := 0; i < test.sendKeyUpdates; i++ {
-			tlsConn.SendKeyUpdate()
+			tlsConn.SendKeyUpdate(test.keyUpdateRequest)
 		}
 
 		for i := 0; i < test.sendEmptyRecords; i++ {
@@ -1981,13 +1983,14 @@
 			expectedError:     ":TOO_MANY_WARNING_ALERTS:",
 		},
 		{
-			name: "SendKeyUpdates",
+			name: "TooManyKeyUpdates",
 			config: Config{
 				MaxVersion: VersionTLS13,
 			},
-			sendKeyUpdates: 33,
-			shouldFail:     true,
-			expectedError:  ":TOO_MANY_KEY_UPDATES:",
+			sendKeyUpdates:   33,
+			keyUpdateRequest: keyUpdateNotRequested,
+			shouldFail:       true,
+			expectedError:    ":TOO_MANY_KEY_UPDATES:",
 		},
 		{
 			name: "EmptySessionID",
@@ -2195,14 +2198,22 @@
 			expectedError: ":WRONG_VERSION_NUMBER:",
 		},
 		{
-			testType: clientTest,
-			name:     "KeyUpdate",
+			name: "KeyUpdate",
 			config: Config{
 				MaxVersion: VersionTLS13,
-				Bugs: ProtocolBugs{
-					SendKeyUpdateBeforeEveryAppDataRecord: true,
-				},
 			},
+			sendKeyUpdates:   1,
+			keyUpdateRequest: keyUpdateNotRequested,
+		},
+		{
+			name: "KeyUpdate-InvalidRequestMode",
+			config: Config{
+				MaxVersion: VersionTLS13,
+			},
+			sendKeyUpdates:   1,
+			keyUpdateRequest: 42,
+			shouldFail:       true,
+			expectedError:    ":DECODE_ERROR:",
 		},
 		{
 			name: "SendSNIWarningAlert",
@@ -8723,11 +8734,10 @@
 		name: "Peek-KeyUpdate",
 		config: Config{
 			MaxVersion: VersionTLS13,
-			Bugs: ProtocolBugs{
-				SendKeyUpdateBeforeEveryAppDataRecord: true,
-			},
 		},
-		flags: []string{"-peek-then-read"},
+		sendKeyUpdates:   1,
+		keyUpdateRequest: keyUpdateNotRequested,
+		flags:            []string{"-peek-then-read"},
 	})
 }