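Add tests for ECDHE_PSK.

pskKeyAgreement is now a wrapper over a base key agreement: the PSK
identity (hint) is prepended to the base agreement's ClientKeyExchange
and ServerKeyExchange, and plain PSK uses the new nilKeyAgreement as
its base. Authentication of the ServerKeyExchange parameters is
factored out into a keyAgreementAuthentication interface, with
nilKeyAgreementAuthentication for parameters that are sent unsigned.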

Change-Id: Ic18862d3e98f7513476f878b8df5dcd8d36a0eac
Reviewed-on: https://boringssl-review.googlesource.com/2053
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/runner/key_agreement.go b/ssl/test/runner/key_agreement.go
index 4f76cb1..af54a8f 100644
--- a/ssl/test/runner/key_agreement.go
+++ b/ssl/test/runner/key_agreement.go
@@ -187,8 +187,29 @@
 
 }
 
-// signedKeyAgreement implements helper functions for key agreement
-// methods that involve signed parameters in the ServerKeyExchange.
+// keyAgreementAuthentication is a helper interface that specifies how
+// to authenticate the ServerKeyExchange parameters.
+type keyAgreementAuthentication interface {
+	signParameters(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg, params []byte) (*serverKeyExchangeMsg, error)
+	verifyParameters(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, params []byte, sig []byte) error
+}
+
+// nilKeyAgreementAuthentication does not authenticate the key
+// agreement parameters.
+type nilKeyAgreementAuthentication struct{}
+
+func (ka *nilKeyAgreementAuthentication) signParameters(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg, params []byte) (*serverKeyExchangeMsg, error) {
+	skx := new(serverKeyExchangeMsg)
+	skx.key = params
+	return skx, nil
+}
+
+func (ka *nilKeyAgreementAuthentication) verifyParameters(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, params []byte, sig []byte) error {
+	return nil
+}
+
+// signedKeyAgreement signs the ServerKeyExchange parameters with the
+// server's private key.
 type signedKeyAgreement struct {
 	version uint16
 	sigType uint8
@@ -328,7 +349,7 @@
 // pre-master secret is then calculated using ECDH. The signature may
 // either be ECDSA or RSA.
 type ecdheKeyAgreement struct {
-	signedKeyAgreement
+	auth       keyAgreementAuthentication
 	privateKey []byte
 	curve      elliptic.Curve
 	x, y       *big.Int
@@ -394,7 +415,7 @@
 	serverECDHParams[3] = byte(len(ecdhePublic))
 	copy(serverECDHParams[4:], ecdhePublic)
 
-	return ka.signParameters(config, cert, clientHello, hello, serverECDHParams)
+	return ka.auth.signParameters(config, cert, clientHello, hello, serverECDHParams)
 }
 
 func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
@@ -438,7 +459,7 @@
 	serverECDHParams := skx.key[:4+publicLen]
 	sig := skx.key[4+publicLen:]
 
-	return ka.verifyParameters(config, clientHello, serverHello, cert, serverECDHParams, sig)
+	return ka.auth.verifyParameters(config, clientHello, serverHello, cert, serverECDHParams, sig)
 }
 
 func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
@@ -468,7 +489,7 @@
 // an ephemeral Diffie-Hellman public/private key pair and signs it. The
 // pre-master secret is then calculated using Diffie-Hellman.
 type dheKeyAgreement struct {
-	signedKeyAgreement
+	auth    keyAgreementAuthentication
 	p, g    *big.Int
 	yTheirs *big.Int
 	xOurs   *big.Int
@@ -500,7 +521,7 @@
 	serverDHParams = append(serverDHParams, byte(len(yBytes)>>8), byte(len(yBytes)))
 	serverDHParams = append(serverDHParams, yBytes...)
 
-	return ka.signParameters(config, cert, clientHello, hello, serverDHParams)
+	return ka.auth.signParameters(config, cert, clientHello, hello, serverDHParams)
 }
 
 func (ka *dheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
@@ -562,7 +583,7 @@
 	sig := k
 	serverDHParams := skx.key[:len(skx.key)-len(sig)]
 
-	return ka.verifyParameters(config, clientHello, serverHello, cert, serverDHParams, sig)
+	return ka.auth.verifyParameters(config, clientHello, serverHello, cert, serverDHParams, sig)
 }
 
 func (ka *dheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
@@ -587,8 +608,43 @@
 	return preMasterSecret, ckx, nil
 }
 
-// makePSKPremaster formats a PSK pre-master secret based on
-// otherSecret from the base key exchange and psk.
+// nilKeyAgreement is a fake key agreement used to implement the plain PSK key
+// exchange.
+type nilKeyAgreement struct{}
+
+func (ka *nilKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
+	return nil, nil
+}
+
+func (ka *nilKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
+	if len(ckx.ciphertext) != 0 {
+		return nil, errClientKeyExchange
+	}
+
+	// Although in plain PSK, otherSecret is all zeros, the base key
+	// agreement does not have access to the length of the pre-shared
+	// key. pskKeyAgreement instead interprets nil as all zeros of the
+	// appropriate length.
+	return nil, nil
+}
+
+func (ka *nilKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
+	if len(skx.key) != 0 {
+		return errServerKeyExchange
+	}
+	return nil
+}
+
+func (ka *nilKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
+	// Although in plain PSK, otherSecret is all zeros, the base key
+	// agreement does not have access to the length of the pre-shared
+	// key. pskKeyAgreement instead interprets nil as all zeros of the
+	// appropriate length.
+	return nil, &clientKeyExchangeMsg{}, nil
+}
+
+// makePSKPremaster formats a PSK pre-master secret based on otherSecret from
+// the base key exchange and psk.
 func makePSKPremaster(otherSecret, psk []byte) []byte {
 	out := make([]byte, 0, 2+len(otherSecret)+2+len(psk))
 	out = append(out, byte(len(otherSecret)>>8), byte(len(otherSecret)))
@@ -600,33 +656,47 @@
 
 // pskKeyAgreement implements the PSK key agreement.
 type pskKeyAgreement struct {
+	base         keyAgreement
 	identityHint string
 }
 
 func (ka *pskKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
-	// ServerKeyExchange is optional if the identity hint is empty.
-	if config.PreSharedKeyIdentity == "" {
-		return nil, nil
-	}
+	// Assemble the identity hint.
 	bytes := make([]byte, 2+len(config.PreSharedKeyIdentity))
 	bytes[0] = byte(len(config.PreSharedKeyIdentity) >> 8)
 	bytes[1] = byte(len(config.PreSharedKeyIdentity))
 	copy(bytes[2:], []byte(config.PreSharedKeyIdentity))
 
+	// If there is one, append the base key agreement's
+	// ServerKeyExchange.
+	baseSkx, err := ka.base.generateServerKeyExchange(config, cert, clientHello, hello)
+	if err != nil {
+		return nil, err
+	}
+
+	if baseSkx != nil {
+		bytes = append(bytes, baseSkx.key...)
+	} else if config.PreSharedKeyIdentity == "" {
+		// ServerKeyExchange is optional if the identity hint is empty
+		// and there would otherwise be no ServerKeyExchange.
+		return nil, nil
+	}
+
 	skx := new(serverKeyExchangeMsg)
 	skx.key = bytes
 	return skx, nil
 }
 
 func (ka *pskKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
+	// First, process the PSK identity.
 	if len(ckx.ciphertext) < 2 {
 		return nil, errClientKeyExchange
 	}
 	identityLen := (int(ckx.ciphertext[0]) << 8) | int(ckx.ciphertext[1])
-	if 2+identityLen != len(ckx.ciphertext) {
+	if 2+identityLen > len(ckx.ciphertext) {
 		return nil, errClientKeyExchange
 	}
-	identity := string(ckx.ciphertext[2:])
+	identity := string(ckx.ciphertext[2 : 2+identityLen])
 
 	if identity != config.PreSharedKeyIdentity {
 		return nil, errors.New("tls: unexpected identity")
@@ -635,7 +705,20 @@
 	if config.PreSharedKey == nil {
 		return nil, errors.New("tls: pre-shared key not configured")
 	}
-	otherSecret := make([]byte, len(config.PreSharedKey))
+
+	// Process the remainder of the ClientKeyExchange to compute the base
+	// pre-master secret.
+	newCkx := new(clientKeyExchangeMsg)
+	newCkx.ciphertext = ckx.ciphertext[2+identityLen:]
+	otherSecret, err := ka.base.processClientKeyExchange(config, cert, newCkx, version)
+	if err != nil {
+		return nil, err
+	}
+
+	if otherSecret == nil {
+		// Special-case for the plain PSK key exchanges.
+		otherSecret = make([]byte, len(config.PreSharedKey))
+	}
 	return makePSKPremaster(otherSecret, config.PreSharedKey), nil
 }
 
@@ -644,11 +727,15 @@
 		return errServerKeyExchange
 	}
 	identityLen := (int(skx.key[0]) << 8) | int(skx.key[1])
-	if 2+identityLen != len(skx.key) {
+	if 2+identityLen > len(skx.key) {
 		return errServerKeyExchange
 	}
-	ka.identityHint = string(skx.key[2:])
-	return nil
+	ka.identityHint = string(skx.key[2 : 2+identityLen])
+
+	// Process the remainder of the ServerKeyExchange.
+	newSkx := new(serverKeyExchangeMsg)
+	newSkx.key = skx.key[2+identityLen:]
+	return ka.base.processServerKeyExchange(config, clientHello, serverHello, cert, newSkx)
 }
 
 func (ka *pskKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
@@ -659,17 +746,25 @@
 		return nil, nil, errors.New("tls: unexpected identity")
 	}
 
+	// Serialize the identity.
 	bytes := make([]byte, 2+len(config.PreSharedKeyIdentity))
 	bytes[0] = byte(len(config.PreSharedKeyIdentity) >> 8)
 	bytes[1] = byte(len(config.PreSharedKeyIdentity))
 	copy(bytes[2:], []byte(config.PreSharedKeyIdentity))
 
+	// Append the base key exchange's ClientKeyExchange.
+	otherSecret, baseCkx, err := ka.base.generateClientKeyExchange(config, clientHello, cert)
+	if err != nil {
+		return nil, nil, err
+	}
 	ckx := new(clientKeyExchangeMsg)
-	ckx.ciphertext = bytes
+	ckx.ciphertext = append(bytes, baseCkx.ciphertext...)
 
 	if config.PreSharedKey == nil {
 		return nil, nil, errors.New("tls: pre-shared key not configured")
 	}
-	otherSecret := make([]byte, len(config.PreSharedKey))
+	if otherSecret == nil {
+		otherSecret = make([]byte, len(config.PreSharedKey))
+	}
 	return makePSKPremaster(otherSecret, config.PreSharedKey), ckx, nil
 }