Implement TLS Channel ID in runner.go

Change-Id: Ia349c7a7cdcfd49965cd0c4d6cf81a76fbffb696
Reviewed-on: https://boringssl-review.googlesource.com/1604
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index f14f4e9..daeeb5e 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -7,6 +7,7 @@
 import (
 	"container/list"
 	"crypto"
+	"crypto/ecdsa"
 	"crypto/rand"
 	"crypto/x509"
 	"fmt"
@@ -47,19 +48,20 @@
 
 // TLS handshake message types.
 const (
-	typeClientHello        uint8 = 1
-	typeServerHello        uint8 = 2
-	typeHelloVerifyRequest uint8 = 3
-	typeNewSessionTicket   uint8 = 4
-	typeCertificate        uint8 = 11
-	typeServerKeyExchange  uint8 = 12
-	typeCertificateRequest uint8 = 13
-	typeServerHelloDone    uint8 = 14
-	typeCertificateVerify  uint8 = 15
-	typeClientKeyExchange  uint8 = 16
-	typeFinished           uint8 = 20
-	typeCertificateStatus  uint8 = 22
-	typeNextProtocol       uint8 = 67 // Not IANA assigned
+	typeClientHello         uint8 = 1
+	typeServerHello         uint8 = 2
+	typeHelloVerifyRequest  uint8 = 3
+	typeNewSessionTicket    uint8 = 4
+	typeCertificate         uint8 = 11
+	typeServerKeyExchange   uint8 = 12
+	typeCertificateRequest  uint8 = 13
+	typeServerHelloDone     uint8 = 14
+	typeCertificateVerify   uint8 = 15
+	typeClientKeyExchange   uint8 = 16
+	typeFinished            uint8 = 20
+	typeCertificateStatus   uint8 = 22
+	typeNextProtocol        uint8 = 67  // Not IANA assigned
+	typeEncryptedExtensions uint8 = 203 // Not IANA assigned
 )
 
 // TLS compression types.
@@ -77,6 +79,7 @@
 	extensionSessionTicket       uint16 = 35
 	extensionNextProtoNeg        uint16 = 13172 // not IANA assigned
 	extensionRenegotiationInfo   uint16 = 0xff01
+	extensionChannelID           uint16 = 30032 // not IANA assigned
 )
 
 // TLS signaling cipher suite values
@@ -166,6 +169,7 @@
 	ServerName                 string                // server name requested by client, if any (server side only)
 	PeerCertificates           []*x509.Certificate   // certificate chain presented by remote peer
 	VerifiedChains             [][]*x509.Certificate // verified chains built from PeerCertificates
+	ChannelID                  *ecdsa.PublicKey      // the channel ID for this connection
 }
 
 // ClientAuthType declares the policy the server will follow for
@@ -187,6 +191,7 @@
 	vers               uint16              // SSL/TLS version negotiated for the session
 	cipherSuite        uint16              // Ciphersuite negotiated for the session
 	masterSecret       []byte              // MasterSecret generated by client on a full handshake
+	handshakeHash      []byte              // Handshake hash for Channel ID purposes.
 	serverCertificates []*x509.Certificate // Certificate chain presented by the server
 }
 
@@ -307,6 +312,15 @@
 	// be used.
 	CurvePreferences []CurveID
 
+	// ChannelID contains the ECDSA key for the client to use as
+	// its TLS Channel ID.
+	ChannelID *ecdsa.PrivateKey
+
+	// RequestChannelID controls whether the server requests a TLS
+	// Channel ID. If negotiated, the client's public key is
+	// returned in the ConnectionState.
+	RequestChannelID bool
+
 	// Bugs specifies optional misbehaviour to be used for testing other
 	// implementations.
 	Bugs ProtocolBugs
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index 5371a64..2c5e920 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -9,6 +9,7 @@
 import (
 	"bytes"
 	"crypto/cipher"
+	"crypto/ecdsa"
 	"crypto/subtle"
 	"crypto/x509"
 	"errors"
@@ -47,6 +48,8 @@
 	clientProtocol         string
 	clientProtocolFallback bool
 
+	channelID *ecdsa.PublicKey
+
 	// input/output
 	in, out  halfConn     // in.Mutex < out.Mutex
 	rawInput *block       // raw input, right off the wire
@@ -937,6 +940,8 @@
 		m = new(finishedMsg)
 	case typeHelloVerifyRequest:
 		m = new(helloVerifyRequestMsg)
+	case typeEncryptedExtensions:
+		m = new(encryptedExtensionsMsg)
 	default:
 		return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
 	}
@@ -1104,6 +1109,7 @@
 		state.PeerCertificates = c.peerCertificates
 		state.VerifiedChains = c.verifiedChains
 		state.ServerName = c.serverName
+		state.ChannelID = c.channelID
 	}
 
 	return state
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index ecc2bed..c683913 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -8,6 +8,7 @@
 	"bytes"
 	"crypto"
 	"crypto/ecdsa"
+	"crypto/elliptic"
 	"crypto/rsa"
 	"crypto/subtle"
 	"crypto/x509"
@@ -54,6 +55,7 @@
 		nextProtoNeg:        len(c.config.NextProtos) > 0,
 		secureRenegotiation: true,
 		duplicateExtension:  c.config.Bugs.DuplicateExtension,
+		channelIDSupported:  c.config.ChannelID != nil,
 	}
 
 	if c.config.Bugs.SendClientVersion != 0 {
@@ -238,7 +240,7 @@
 		if err := hs.readFinished(); err != nil {
 			return err
 		}
-		if err := hs.sendFinished(); err != nil {
+		if err := hs.sendFinished(isResume); err != nil {
 			return err
 		}
 	} else {
@@ -248,7 +250,7 @@
 		if err := hs.establishKeys(); err != nil {
 			return err
 		}
-		if err := hs.sendFinished(); err != nil {
+		if err := hs.sendFinished(isResume); err != nil {
 			return err
 		}
 		if err := hs.readSessionTicket(); err != nil {
@@ -565,6 +567,11 @@
 		return false, errors.New("server advertised unrequested NPN extension")
 	}
 
+	if !hs.hello.channelIDSupported && hs.serverHello.channelIDRequested {
+		c.sendAlert(alertHandshakeFailure)
+		return false, errors.New("server advertised unrequested Channel ID extension")
+	}
+
 	if hs.serverResumedSession() {
 		// Restore masterSecret and peerCerts from previous state
 		hs.masterSecret = hs.session.masterSecret
@@ -619,20 +626,22 @@
 		c.sendAlert(alertUnexpectedMessage)
 		return unexpectedMessageError(sessionTicketMsg, msg)
 	}
-	hs.writeServerHash(sessionTicketMsg.marshal())
 
 	hs.session = &ClientSessionState{
 		sessionTicket:      sessionTicketMsg.ticket,
 		vers:               c.vers,
 		cipherSuite:        hs.suite.id,
 		masterSecret:       hs.masterSecret,
+		handshakeHash:      hs.finishedHash.server.Sum(nil),
 		serverCertificates: c.peerCertificates,
 	}
 
+	hs.writeServerHash(sessionTicketMsg.marshal())
+
 	return nil
 }
 
-func (hs *clientHandshakeState) sendFinished() error {
+func (hs *clientHandshakeState) sendFinished(isResume bool) error {
 	c := hs.c
 
 	var postCCSBytes []byte
@@ -650,6 +659,34 @@
 		postCCSBytes = append(postCCSBytes, nextProtoBytes...)
 	}
 
+	if hs.serverHello.channelIDRequested {
+		encryptedExtensions := new(encryptedExtensionsMsg)
+		if c.config.ChannelID.Curve != elliptic.P256() {
+			return fmt.Errorf("tls: Channel ID is not on P-256.")
+		}
+		var resumeHash []byte
+		if isResume {
+			resumeHash = hs.session.handshakeHash
+		}
+		r, s, err := ecdsa.Sign(c.config.rand(), c.config.ChannelID, hs.finishedHash.hashForChannelID(resumeHash))
+		if err != nil {
+			return err
+		}
+		channelID := make([]byte, 128)
+		writeIntPadded(channelID[0:32], c.config.ChannelID.X)
+		writeIntPadded(channelID[32:64], c.config.ChannelID.Y)
+		writeIntPadded(channelID[64:96], r)
+		writeIntPadded(channelID[96:128], s)
+		encryptedExtensions.channelID = channelID
+
+		c.channelID = &c.config.ChannelID.PublicKey
+
+		encryptedExtensionsBytes := encryptedExtensions.marshal()
+		hs.writeHash(encryptedExtensionsBytes, seqno)
+		seqno++
+		postCCSBytes = append(postCCSBytes, encryptedExtensionsBytes...)
+	}
+
 	finished := new(finishedMsg)
 	if c.config.Bugs.EarlyChangeCipherSpec == 2 {
 		finished.verifyData = hs.finishedHash.clientSum(nil)
@@ -724,3 +761,13 @@
 
 	return clientProtos[0], true
 }
+
+// writeIntPadded writes x into b, padded up with leading zeros as
+// needed.
+func writeIntPadded(b []byte, x *big.Int) {
+	for i := range b {
+		b[i] = 0
+	}
+	xb := x.Bytes()
+	copy(b[len(b)-len(xb):], xb)
+}
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index 7fe8bf5..472aa87 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -25,6 +25,7 @@
 	signatureAndHashes  []signatureAndHash
 	secureRenegotiation bool
 	duplicateExtension  bool
+	channelIDSupported  bool
 }
 
 func (m *clientHelloMsg) equal(i interface{}) bool {
@@ -49,7 +50,9 @@
 		m.ticketSupported == m1.ticketSupported &&
 		bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
 		eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
-		m.secureRenegotiation == m1.secureRenegotiation
+		m.secureRenegotiation == m1.secureRenegotiation &&
+		m.duplicateExtension == m1.duplicateExtension &&
+		m.channelIDSupported == m1.channelIDSupported
 }
 
 func (m *clientHelloMsg) marshal() []byte {
@@ -97,6 +100,9 @@
 	if m.duplicateExtension {
 		numExtensions += 2
 	}
+	if m.channelIDSupported {
+		numExtensions++
+	}
 	if numExtensions > 0 {
 		extensionsLength += 4 * numExtensions
 		length += 2 + extensionsLength
@@ -260,6 +266,11 @@
 		z[3] = 1
 		z = z[5:]
 	}
+	if m.channelIDSupported {
+		z[0] = byte(extensionChannelID >> 8)
+		z[1] = byte(extensionChannelID & 0xff)
+		z = z[4:]
+	}
 	if m.duplicateExtension {
 		// Add a duplicate bogus extension at the beginning and end.
 		z[0] = 0xff
@@ -440,6 +451,11 @@
 				return false
 			}
 			m.secureRenegotiation = true
+		case extensionChannelID:
+			if length > 0 {
+				return false
+			}
+			m.channelIDSupported = true
 		}
 		data = data[length:]
 	}
@@ -461,6 +477,7 @@
 	ticketSupported     bool
 	secureRenegotiation bool
 	duplicateExtension  bool
+	channelIDRequested  bool
 }
 
 func (m *serverHelloMsg) equal(i interface{}) bool {
@@ -480,7 +497,9 @@
 		eqStrings(m.nextProtos, m1.nextProtos) &&
 		m.ocspStapling == m1.ocspStapling &&
 		m.ticketSupported == m1.ticketSupported &&
-		m.secureRenegotiation == m1.secureRenegotiation
+		m.secureRenegotiation == m1.secureRenegotiation &&
+		m.duplicateExtension == m1.duplicateExtension &&
+		m.channelIDRequested == m1.channelIDRequested
 }
 
 func (m *serverHelloMsg) marshal() []byte {
@@ -514,6 +533,9 @@
 	if m.duplicateExtension {
 		numExtensions += 2
 	}
+	if m.channelIDRequested {
+		numExtensions++
+	}
 	if numExtensions > 0 {
 		extensionsLength += 4 * numExtensions
 		length += 2 + extensionsLength
@@ -581,6 +603,11 @@
 		z[3] = 1
 		z = z[5:]
 	}
+	if m.channelIDRequested {
+		z[0] = byte(extensionChannelID >> 8)
+		z[1] = byte(extensionChannelID & 0xff)
+		z = z[4:]
+	}
 	if m.duplicateExtension {
 		// Add a duplicate bogus extension at the beginning and end.
 		z[0] = 0xff
@@ -671,6 +698,11 @@
 				return false
 			}
 			m.secureRenegotiation = true
+		case extensionChannelID:
+			if length > 0 {
+				return false
+			}
+			m.channelIDRequested = true
 		}
 		data = data[length:]
 	}
@@ -1407,7 +1439,8 @@
 		return false
 	}
 
-	return m.vers == m1.vers &&
+	return bytes.Equal(m.raw, m1.raw) &&
+		m.vers == m1.vers &&
 		bytes.Equal(m.cookie, m1.cookie)
 }
 
@@ -1447,6 +1480,58 @@
 	return true
 }
 
+type encryptedExtensionsMsg struct {
+	raw       []byte
+	channelID []byte
+}
+
+func (m *encryptedExtensionsMsg) equal(i interface{}) bool {
+	m1, ok := i.(*encryptedExtensionsMsg)
+	if !ok {
+		return false
+	}
+
+	return bytes.Equal(m.raw, m1.raw) &&
+		bytes.Equal(m.channelID, m1.channelID)
+}
+
+func (m *encryptedExtensionsMsg) marshal() []byte {
+	if m.raw != nil {
+		return m.raw
+	}
+
+	length := 2 + 2 + len(m.channelID)
+
+	x := make([]byte, 4+length)
+	x[0] = typeEncryptedExtensions
+	x[1] = uint8(length >> 16)
+	x[2] = uint8(length >> 8)
+	x[3] = uint8(length)
+	x[4] = uint8(extensionChannelID >> 8)
+	x[5] = uint8(extensionChannelID & 0xff)
+	x[6] = uint8(len(m.channelID) >> 8)
+	x[7] = uint8(len(m.channelID) & 0xff)
+	copy(x[8:], m.channelID)
+
+	return x
+}
+
+func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
+	if len(data) != 4+2+2+128 {
+		return false
+	}
+	m.raw = data
+	if (uint16(data[4])<<8)|uint16(data[5]) != extensionChannelID {
+		return false
+	}
+	if int(data[6])<<8|int(data[7]) != 128 {
+		return false
+	}
+	m.channelID = data[4+2+2:]
+
+	return true
+}
+
 func eqUint16s(x, y []uint16) bool {
 	if len(x) != len(y) {
 		return false
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index d5a660a..40e1b88 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -8,6 +8,7 @@
 	"bytes"
 	"crypto"
 	"crypto/ecdsa"
+	"crypto/elliptic"
 	"crypto/rsa"
 	"crypto/subtle"
 	"crypto/x509"
@@ -15,6 +16,7 @@
 	"errors"
 	"fmt"
 	"io"
+	"math/big"
 )
 
 // serverHandshakeState contains details of a server handshake in progress.
@@ -69,7 +71,7 @@
 		if err := hs.sendFinished(); err != nil {
 			return err
 		}
-		if err := hs.readFinished(); err != nil {
+		if err := hs.readFinished(isResume); err != nil {
 			return err
 		}
 		c.didResume = true
@@ -82,7 +84,7 @@
 		if err := hs.establishKeys(); err != nil {
 			return err
 		}
-		if err := hs.readFinished(); err != nil {
+		if err := hs.readFinished(isResume); err != nil {
 			return err
 		}
 		if err := hs.sendSessionTicket(); err != nil {
@@ -231,6 +233,10 @@
 		hs.cert = config.getCertificateForName(hs.clientHello.serverName)
 	}
 
+	if hs.clientHello.channelIDSupported && config.RequestChannelID {
+		hs.hello.channelIDRequested = true
+	}
+
 	_, hs.ecdsaOk = hs.cert.PrivateKey.(*ecdsa.PrivateKey)
 
 	if hs.checkForResumption() {
@@ -579,7 +585,7 @@
 	return nil
 }
 
-func (hs *serverHandshakeState) readFinished() error {
+func (hs *serverHandshakeState) readFinished(isResume bool) error {
 	c := hs.c
 
 	c.readRecord(recordTypeChangeCipherSpec)
@@ -601,6 +607,36 @@
 		c.clientProtocol = nextProto.proto
 	}
 
+	if hs.hello.channelIDRequested {
+		msg, err := c.readHandshake()
+		if err != nil {
+			return err
+		}
+		encryptedExtensions, ok := msg.(*encryptedExtensionsMsg)
+		if !ok {
+			c.sendAlert(alertUnexpectedMessage)
+			return unexpectedMessageError(encryptedExtensions, msg)
+		}
+		x := new(big.Int).SetBytes(encryptedExtensions.channelID[0:32])
+		y := new(big.Int).SetBytes(encryptedExtensions.channelID[32:64])
+		r := new(big.Int).SetBytes(encryptedExtensions.channelID[64:96])
+		s := new(big.Int).SetBytes(encryptedExtensions.channelID[96:128])
+		if !elliptic.P256().IsOnCurve(x, y) {
+			return errors.New("tls: invalid channel ID public key")
+		}
+		channelID := &ecdsa.PublicKey{elliptic.P256(), x, y}
+		var resumeHash []byte
+		if isResume {
+			resumeHash = hs.sessionState.handshakeHash
+		}
+		if !ecdsa.Verify(channelID, hs.finishedHash.hashForChannelID(resumeHash), r, s) {
+			return errors.New("tls: invalid channel ID signature")
+		}
+		c.channelID = channelID
+
+		hs.writeClientHash(encryptedExtensions.marshal())
+	}
+
 	msg, err := c.readHandshake()
 	if err != nil {
 		return err
@@ -632,10 +668,11 @@
 
 	var err error
 	state := sessionState{
-		vers:         c.vers,
-		cipherSuite:  hs.suite.id,
-		masterSecret: hs.masterSecret,
-		certificates: hs.certsFromClient,
+		vers:          c.vers,
+		cipherSuite:   hs.suite.id,
+		masterSecret:  hs.masterSecret,
+		certificates:  hs.certsFromClient,
+		handshakeHash: hs.finishedHash.server.Sum(nil),
 	}
 	m.ticket, err = c.encryptTicket(&state)
 	if err != nil {
diff --git a/ssl/test/runner/prf.go b/ssl/test/runner/prf.go
index 991196f..55a3614 100644
--- a/ssl/test/runner/prf.go
+++ b/ssl/test/runner/prf.go
@@ -120,6 +120,8 @@
 var keyExpansionLabel = []byte("key expansion")
 var clientFinishedLabel = []byte("client finished")
 var serverFinishedLabel = []byte("server finished")
+var channelIDLabel = []byte("TLS Channel ID signature\x00")
+var channelIDResumeLabel = []byte("Resumption\x00")
 
 func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) {
 	switch version {
@@ -321,3 +323,17 @@
 	digest = h.server.Sum(digest)
 	return digest, crypto.MD5SHA1, nil
 }
+
+// hashForChannelID returns the hash to be signed for TLS Channel
+// ID. If a resumption, resumeHash has the previous handshake
+// hash. Otherwise, it is nil.
+func (h finishedHash) hashForChannelID(resumeHash []byte) []byte {
+	hash := sha256.New()
+	hash.Write(channelIDLabel)
+	if resumeHash != nil {
+		hash.Write(channelIDResumeLabel)
+		hash.Write(resumeHash)
+	}
+	hash.Write(h.server.Sum(nil))
+	return hash.Sum(nil)
+}
diff --git a/ssl/test/runner/ticket.go b/ssl/test/runner/ticket.go
index 519543b..74791d6 100644
--- a/ssl/test/runner/ticket.go
+++ b/ssl/test/runner/ticket.go
@@ -18,10 +18,11 @@
 // sessionState contains the information that is serialized into a session
 // ticket in order to later resume a connection.
 type sessionState struct {
-	vers         uint16
-	cipherSuite  uint16
-	masterSecret []byte
-	certificates [][]byte
+	vers          uint16
+	cipherSuite   uint16
+	masterSecret  []byte
+	handshakeHash []byte
+	certificates  [][]byte
 }
 
 func (s *sessionState) equal(i interface{}) bool {
@@ -32,7 +33,8 @@
 
 	if s.vers != s1.vers ||
 		s.cipherSuite != s1.cipherSuite ||
-		!bytes.Equal(s.masterSecret, s1.masterSecret) {
+		!bytes.Equal(s.masterSecret, s1.masterSecret) ||
+		!bytes.Equal(s.handshakeHash, s1.handshakeHash) {
 		return false
 	}
 
@@ -50,7 +52,7 @@
 }
 
 func (s *sessionState) marshal() []byte {
-	length := 2 + 2 + 2 + len(s.masterSecret) + 2
+	length := 2 + 2 + 2 + len(s.masterSecret) + 2 + len(s.handshakeHash) + 2
 	for _, cert := range s.certificates {
 		length += 4 + len(cert)
 	}
@@ -67,6 +69,12 @@
 	copy(x, s.masterSecret)
 	x = x[len(s.masterSecret):]
 
+	x[0] = byte(len(s.handshakeHash) >> 8)
+	x[1] = byte(len(s.handshakeHash))
+	x = x[2:]
+	copy(x, s.handshakeHash)
+	x = x[len(s.handshakeHash):]
+
 	x[0] = byte(len(s.certificates) >> 8)
 	x[1] = byte(len(s.certificates))
 	x = x[2:]
@@ -103,6 +111,19 @@
 		return false
 	}
 
+	handshakeHashLen := int(data[0])<<8 | int(data[1])
+	data = data[2:]
+	if len(data) < handshakeHashLen {
+		return false
+	}
+
+	s.handshakeHash = data[:handshakeHashLen]
+	data = data[handshakeHashLen:]
+
+	if len(data) < 2 {
+		return false
+	}
+
 	numCerts := int(data[0])<<8 | int(data[1])
 	data = data[2:]