Add tests for PSK cipher suites.

This covers only the three plain PSK suites for now;
ECDHE_PSK_WITH_AES_128_GCM_SHA256 will be added in a follow-up.

Change-Id: Iafc116a5b2798c61d90c139b461cf98897ae23b3
Reviewed-on: https://boringssl-review.googlesource.com/2051
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/runner/cipher_suites.go b/ssl/test/runner/cipher_suites.go
index 6cd0de9..a0e6b94 100644
--- a/ssl/test/runner/cipher_suites.go
+++ b/ssl/test/runner/cipher_suites.go
@@ -57,6 +57,9 @@
 	// suiteNoDTLS indicates that the cipher suite cannot be used
 	// in DTLS.
 	suiteNoDTLS
+	// suitePSK indicates that the cipher suite authenticates with
+	// a pre-shared key rather than a server private key.
+	suitePSK
 )
 
 // A cipherSuite is a specific combination of key agreement, cipher and MAC
@@ -109,6 +112,9 @@
 	{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil},
 	{TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, dheRSAKA, 0, cipher3DES, macSHA1, nil},
 	{TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil},
+	{TLS_PSK_WITH_RC4_128_SHA, 16, 20, 0, pskKA, suiteNoDTLS | suitePSK, cipherRC4, macSHA1, nil},
+	{TLS_PSK_WITH_AES_128_CBC_SHA, 16, 20, 16, pskKA, suitePSK, cipherAES, macSHA1, nil},
+	{TLS_PSK_WITH_AES_256_CBC_SHA, 32, 20, 16, pskKA, suitePSK, cipherAES, macSHA1, nil},
 }
 
 func cipherRC4(key, iv []byte, isRead bool) interface{} {
@@ -312,6 +318,10 @@
 	}
 }
 
+func pskKA(version uint16) keyAgreement {
+	return &pskKeyAgreement{}
+}
+
 // mutualCipherSuite returns a cipherSuite given a list of supported
 // ciphersuites and the id requested by the peer.
 func mutualCipherSuite(have []uint16, want uint16) *cipherSuite {
@@ -343,6 +353,9 @@
 	TLS_RSA_WITH_AES_256_CBC_SHA256         uint16 = 0x003d
 	TLS_DHE_RSA_WITH_AES_128_CBC_SHA256     uint16 = 0x0067
 	TLS_DHE_RSA_WITH_AES_256_CBC_SHA256     uint16 = 0x006b
+	TLS_PSK_WITH_RC4_128_SHA                uint16 = 0x008a
+	TLS_PSK_WITH_AES_128_CBC_SHA            uint16 = 0x008c
+	TLS_PSK_WITH_AES_256_CBC_SHA            uint16 = 0x008d
 	TLS_RSA_WITH_AES_128_GCM_SHA256         uint16 = 0x009c
 	TLS_RSA_WITH_AES_256_GCM_SHA384         uint16 = 0x009d
 	TLS_DHE_RSA_WITH_AES_128_GCM_SHA256     uint16 = 0x009e
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index 935fd15..4aa21bb 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -325,6 +325,14 @@
 	// returned in the ConnectionState.
 	RequestChannelID bool
 
+	// PreSharedKey, if not nil, is the pre-shared key to use with
+	// the PSK cipher suites.
+	PreSharedKey []byte
+
+	// PreSharedKeyIdentity, if not empty, is the identity to use
+	// with the PSK cipher suites.
+	PreSharedKeyIdentity string
+
 	// Bugs specifies optional misbehaviour to be used for testing other
 	// implementations.
 	Bugs ProtocolBugs
@@ -737,9 +745,10 @@
 }
 
 func initDefaultCipherSuites() {
-	varDefaultCipherSuites = make([]uint16, len(cipherSuites))
-	for i, suite := range cipherSuites {
-		varDefaultCipherSuites[i] = suite.id
+	for _, suite := range cipherSuites {
+		if suite.flags&suitePSK == 0 {
+			varDefaultCipherSuites = append(varDefaultCipherSuites, suite.id)
+		}
 	}
 }
 
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index 2f9fe12..11a1ed3 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -308,60 +308,65 @@
 func (hs *clientHandshakeState) doFullHandshake() error {
 	c := hs.c
 
-	msg, err := c.readHandshake()
-	if err != nil {
-		return err
-	}
-	certMsg, ok := msg.(*certificateMsg)
-	if !ok || len(certMsg.certificates) == 0 {
-		c.sendAlert(alertUnexpectedMessage)
-		return unexpectedMessageError(certMsg, msg)
-	}
-	hs.writeServerHash(certMsg.marshal())
-
-	certs := make([]*x509.Certificate, len(certMsg.certificates))
-	for i, asn1Data := range certMsg.certificates {
-		cert, err := x509.ParseCertificate(asn1Data)
+	var leaf *x509.Certificate
+	if hs.suite.flags&suitePSK == 0 {
+		msg, err := c.readHandshake()
 		if err != nil {
-			c.sendAlert(alertBadCertificate)
-			return errors.New("tls: failed to parse certificate from server: " + err.Error())
-		}
-		certs[i] = cert
-	}
-
-	if !c.config.InsecureSkipVerify {
-		opts := x509.VerifyOptions{
-			Roots:         c.config.RootCAs,
-			CurrentTime:   c.config.time(),
-			DNSName:       c.config.ServerName,
-			Intermediates: x509.NewCertPool(),
-		}
-
-		for i, cert := range certs {
-			if i == 0 {
-				continue
-			}
-			opts.Intermediates.AddCert(cert)
-		}
-		c.verifiedChains, err = certs[0].Verify(opts)
-		if err != nil {
-			c.sendAlert(alertBadCertificate)
 			return err
 		}
-	}
 
-	switch certs[0].PublicKey.(type) {
-	case *rsa.PublicKey, *ecdsa.PublicKey:
-		break
-	default:
-		c.sendAlert(alertUnsupportedCertificate)
-		return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey)
-	}
+		certMsg, ok := msg.(*certificateMsg)
+		if !ok || len(certMsg.certificates) == 0 {
+			c.sendAlert(alertUnexpectedMessage)
+			return unexpectedMessageError(certMsg, msg)
+		}
+		hs.writeServerHash(certMsg.marshal())
 
-	c.peerCertificates = certs
+		certs := make([]*x509.Certificate, len(certMsg.certificates))
+		for i, asn1Data := range certMsg.certificates {
+			cert, err := x509.ParseCertificate(asn1Data)
+			if err != nil {
+				c.sendAlert(alertBadCertificate)
+				return errors.New("tls: failed to parse certificate from server: " + err.Error())
+			}
+			certs[i] = cert
+		}
+		leaf = certs[0]
+
+		if !c.config.InsecureSkipVerify {
+			opts := x509.VerifyOptions{
+				Roots:         c.config.RootCAs,
+				CurrentTime:   c.config.time(),
+				DNSName:       c.config.ServerName,
+				Intermediates: x509.NewCertPool(),
+			}
+
+			for i, cert := range certs {
+				if i == 0 {
+					continue
+				}
+				opts.Intermediates.AddCert(cert)
+			}
+			c.verifiedChains, err = leaf.Verify(opts)
+			if err != nil {
+				c.sendAlert(alertBadCertificate)
+				return err
+			}
+		}
+
+		switch leaf.PublicKey.(type) {
+		case *rsa.PublicKey, *ecdsa.PublicKey:
+			break
+		default:
+			c.sendAlert(alertUnsupportedCertificate)
+			return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", leaf.PublicKey)
+		}
+
+		c.peerCertificates = certs
+	}
 
 	if hs.serverHello.ocspStapling {
-		msg, err = c.readHandshake()
+		msg, err := c.readHandshake()
 		if err != nil {
 			return err
 		}
@@ -377,7 +382,7 @@
 		}
 	}
 
-	msg, err = c.readHandshake()
+	msg, err := c.readHandshake()
 	if err != nil {
 		return err
 	}
@@ -387,7 +392,7 @@
 	skx, ok := msg.(*serverKeyExchangeMsg)
 	if ok {
 		hs.writeServerHash(skx.marshal())
-		err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, certs[0], skx)
+		err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, leaf, skx)
 		if err != nil {
 			c.sendAlert(alertUnexpectedMessage)
 			return err
@@ -488,7 +493,7 @@
 	// Certificate message, even if it's empty because we don't have a
 	// certificate to send.
 	if certRequested {
-		certMsg = new(certificateMsg)
+		certMsg := new(certificateMsg)
 		if chainToSend != nil {
 			certMsg.certificates = chainToSend.Certificate
 		}
@@ -496,7 +501,7 @@
 		c.writeRecord(recordTypeHandshake, certMsg.marshal())
 	}
 
-	preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, certs[0])
+	preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, leaf)
 	if err != nil {
 		c.sendAlert(alertInternalError)
 		return err
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 645a67c..4bf8f1c 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -383,7 +383,8 @@
 	config := hs.c.config
 	c := hs.c
 
-	if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
+	isPSK := hs.suite.flags&suitePSK != 0
+	if !isPSK && hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
 		hs.hello.ocspStapling = true
 	}
 
@@ -397,11 +398,13 @@
 
 	c.writeRecord(recordTypeHandshake, hs.hello.marshal())
 
-	certMsg := new(certificateMsg)
-	certMsg.certificates = hs.cert.Certificate
-	if !config.Bugs.UnauthenticatedECDH {
-		hs.writeServerHash(certMsg.marshal())
-		c.writeRecord(recordTypeHandshake, certMsg.marshal())
+	if !isPSK {
+		certMsg := new(certificateMsg)
+		certMsg.certificates = hs.cert.Certificate
+		if !config.Bugs.UnauthenticatedECDH {
+			hs.writeServerHash(certMsg.marshal())
+			c.writeRecord(recordTypeHandshake, certMsg.marshal())
+		}
 	}
 
 	if hs.hello.ocspStapling {
@@ -466,6 +469,7 @@
 	// If we requested a client certificate, then the client must send a
 	// certificate message, even if it's empty.
 	if config.ClientAuth >= RequestClientCert {
+		var certMsg *certificateMsg
 		if certMsg, ok = msg.(*certificateMsg); !ok {
 			c.sendAlert(alertUnexpectedMessage)
 			return unexpectedMessageError(certMsg, msg)
diff --git a/ssl/test/runner/key_agreement.go b/ssl/test/runner/key_agreement.go
index f8ba1f8..4f76cb1 100644
--- a/ssl/test/runner/key_agreement.go
+++ b/ssl/test/runner/key_agreement.go
@@ -586,3 +586,90 @@
 
 	return preMasterSecret, ckx, nil
 }
+
+// makePSKPremaster formats a PSK pre-master secret based on
+// otherSecret from the base key exchange and psk.
+func makePSKPremaster(otherSecret, psk []byte) []byte {
+	out := make([]byte, 0, 2+len(otherSecret)+2+len(psk))
+	out = append(out, byte(len(otherSecret)>>8), byte(len(otherSecret)))
+	out = append(out, otherSecret...)
+	out = append(out, byte(len(psk)>>8), byte(len(psk)))
+	out = append(out, psk...)
+	return out
+}
+
+// pskKeyAgreement implements the PSK key agreement.
+type pskKeyAgreement struct {
+	identityHint string
+}
+
+func (ka *pskKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
+	// ServerKeyExchange is optional if the identity hint is empty.
+	if config.PreSharedKeyIdentity == "" {
+		return nil, nil
+	}
+	bytes := make([]byte, 2+len(config.PreSharedKeyIdentity))
+	bytes[0] = byte(len(config.PreSharedKeyIdentity) >> 8)
+	bytes[1] = byte(len(config.PreSharedKeyIdentity))
+	copy(bytes[2:], []byte(config.PreSharedKeyIdentity))
+
+	skx := new(serverKeyExchangeMsg)
+	skx.key = bytes
+	return skx, nil
+}
+
+func (ka *pskKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
+	if len(ckx.ciphertext) < 2 {
+		return nil, errClientKeyExchange
+	}
+	identityLen := (int(ckx.ciphertext[0]) << 8) | int(ckx.ciphertext[1])
+	if 2+identityLen != len(ckx.ciphertext) {
+		return nil, errClientKeyExchange
+	}
+	identity := string(ckx.ciphertext[2:])
+
+	if identity != config.PreSharedKeyIdentity {
+		return nil, errors.New("tls: unexpected identity")
+	}
+
+	if config.PreSharedKey == nil {
+		return nil, errors.New("tls: pre-shared key not configured")
+	}
+	otherSecret := make([]byte, len(config.PreSharedKey))
+	return makePSKPremaster(otherSecret, config.PreSharedKey), nil
+}
+
+func (ka *pskKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
+	if len(skx.key) < 2 {
+		return errServerKeyExchange
+	}
+	identityLen := (int(skx.key[0]) << 8) | int(skx.key[1])
+	if 2+identityLen != len(skx.key) {
+		return errServerKeyExchange
+	}
+	ka.identityHint = string(skx.key[2:])
+	return nil
+}
+
+func (ka *pskKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
+	// The server only sends an identity hint but, for purposes of
+	// test code, the hint (empty when ServerKeyExchange is omitted)
+	// is required to match the configured identity.
+	if ka.identityHint != config.PreSharedKeyIdentity {
+		return nil, nil, errors.New("tls: unexpected identity")
+	}
+
+	bytes := make([]byte, 2+len(config.PreSharedKeyIdentity))
+	bytes[0] = byte(len(config.PreSharedKeyIdentity) >> 8)
+	bytes[1] = byte(len(config.PreSharedKeyIdentity))
+	copy(bytes[2:], []byte(config.PreSharedKeyIdentity))
+
+	ckx := new(clientKeyExchangeMsg)
+	ckx.ciphertext = bytes
+
+	if config.PreSharedKey == nil {
+		return nil, nil, errors.New("tls: pre-shared key not configured")
+	}
+	otherSecret := make([]byte, len(config.PreSharedKey))
+	return makePSKPremaster(otherSecret, config.PreSharedKey), ckx, nil
+}
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index 10f86c9..6f4562a 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -835,6 +835,9 @@
 	{"ECDHE-RSA-AES256-SHA", TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA},
 	{"ECDHE-RSA-AES256-SHA384", TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384},
 	{"ECDHE-RSA-RC4-SHA", TLS_ECDHE_RSA_WITH_RC4_128_SHA},
+	{"PSK-AES128-CBC-SHA", TLS_PSK_WITH_AES_128_CBC_SHA},
+	{"PSK-AES256-CBC-SHA", TLS_PSK_WITH_AES_256_CBC_SHA},
+	{"PSK-RC4-SHA", TLS_PSK_WITH_RC4_128_SHA},
 	{"RC4-MD5", TLS_RSA_WITH_RC4_128_MD5},
 	{"RC4-SHA", TLS_RSA_WITH_RC4_128_SHA},
 }
@@ -847,6 +850,9 @@
 
 func addCipherSuiteTests() {
 	for _, suite := range testCipherSuites {
+		const psk = "12345"
+		const pskIdentity = "luggage combo"
+
 		var cert Certificate
 		var certFile string
 		var keyFile string
@@ -860,6 +866,13 @@
 			keyFile = rsaKeyFile
 		}
 
+		var flags []string
+		if strings.HasPrefix(suite.name, "PSK-") || strings.Contains(suite.name, "-PSK-") {
+			flags = append(flags,
+				"-psk", psk,
+				"-psk-identity", pskIdentity)
+		}
+
 		for _, ver := range tlsVersions {
 			if ver.version < VersionTLS12 && isTLS12Only(suite.name) {
 				continue
@@ -874,11 +887,14 @@
 				testType: clientTest,
 				name:     ver.name + "-" + suite.name + "-client",
 				config: Config{
-					MinVersion:   ver.version,
-					MaxVersion:   ver.version,
-					CipherSuites: []uint16{suite.id},
-					Certificates: []Certificate{cert},
+					MinVersion:           ver.version,
+					MaxVersion:           ver.version,
+					CipherSuites:         []uint16{suite.id},
+					Certificates:         []Certificate{cert},
+					PreSharedKey:         []byte(psk),
+					PreSharedKeyIdentity: pskIdentity,
 				},
+				flags:         flags,
 				resumeSession: resumeSession,
 			})
 
@@ -886,13 +902,16 @@
 				testType: serverTest,
 				name:     ver.name + "-" + suite.name + "-server",
 				config: Config{
-					MinVersion:   ver.version,
-					MaxVersion:   ver.version,
-					CipherSuites: []uint16{suite.id},
-					Certificates: []Certificate{cert},
+					MinVersion:           ver.version,
+					MaxVersion:           ver.version,
+					CipherSuites:         []uint16{suite.id},
+					Certificates:         []Certificate{cert},
+					PreSharedKey:         []byte(psk),
+					PreSharedKeyIdentity: pskIdentity,
 				},
 				certFile:      certFile,
 				keyFile:       keyFile,
+				flags:         flags,
 				resumeSession: resumeSession,
 			})
 
@@ -903,11 +922,14 @@
 					protocol: dtls,
 					name:     "D" + ver.name + "-" + suite.name + "-client",
 					config: Config{
-						MinVersion:   ver.version,
-						MaxVersion:   ver.version,
-						CipherSuites: []uint16{suite.id},
-						Certificates: []Certificate{cert},
+						MinVersion:           ver.version,
+						MaxVersion:           ver.version,
+						CipherSuites:         []uint16{suite.id},
+						Certificates:         []Certificate{cert},
+						PreSharedKey:         []byte(psk),
+						PreSharedKeyIdentity: pskIdentity,
 					},
+					flags:         flags,
 					resumeSession: resumeSession,
 				})
 				testCases = append(testCases, testCase{
@@ -915,13 +937,16 @@
 					protocol: dtls,
 					name:     "D" + ver.name + "-" + suite.name + "-server",
 					config: Config{
-						MinVersion:   ver.version,
-						MaxVersion:   ver.version,
-						CipherSuites: []uint16{suite.id},
-						Certificates: []Certificate{cert},
+						MinVersion:           ver.version,
+						MaxVersion:           ver.version,
+						CipherSuites:         []uint16{suite.id},
+						Certificates:         []Certificate{cert},
+						PreSharedKey:         []byte(psk),
+						PreSharedKeyIdentity: pskIdentity,
 					},
 					certFile:      certFile,
 					keyFile:       keyFile,
+					flags:         flags,
 					resumeSession: resumeSession,
 				})
 			}
@@ -1115,8 +1140,8 @@
 							RequireExtendedMasterSecret: with,
 						},
 					},
-					flags:              flags,
-					shouldFail:         ver.version == VersionSSL30 && with,
+					flags:      flags,
+					shouldFail: ver.version == VersionSSL30 && with,
 				}
 				if test.shouldFail {
 					test.expectedLocalError = "extended master secret required but not supported by peer"
@@ -1248,6 +1273,34 @@
 		flags: flags,
 	})
 
+	// Skip ServerKeyExchange in PSK key exchange if there's no
+	// identity hint.
+	testCases = append(testCases, testCase{
+		protocol: protocol,
+		name:     "EmptyPSKHint-Client" + suffix,
+		config: Config{
+			CipherSuites: []uint16{TLS_PSK_WITH_AES_128_CBC_SHA},
+			PreSharedKey: []byte("secret"),
+			Bugs: ProtocolBugs{
+				MaxHandshakeRecordLength: maxHandshakeRecordLength,
+			},
+		},
+		flags: append(flags, "-psk", "secret"),
+	})
+	testCases = append(testCases, testCase{
+		protocol: protocol,
+		testType: serverTest,
+		name:     "EmptyPSKHint-Server" + suffix,
+		config: Config{
+			CipherSuites: []uint16{TLS_PSK_WITH_AES_128_CBC_SHA},
+			PreSharedKey: []byte("secret"),
+			Bugs: ProtocolBugs{
+				MaxHandshakeRecordLength: maxHandshakeRecordLength,
+			},
+		},
+		flags: append(flags, "-psk", "secret"),
+	})
+
 	if protocol == tls {
 		// NPN on client and server; results in post-handshake message.
 		testCases = append(testCases, testCase{