Add DTLS-SRTP tests.

Only the negotiation is tested, since everything else is handled externally.
This feature is used in WebRTC.

Change-Id: Iccc3983ea99e7d054b59010182f9a56a8099e116
Reviewed-on: https://boringssl-review.googlesource.com/2310
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index 9f79778..a4bdef8 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -77,6 +77,7 @@
 	extensionSupportedCurves      uint16 = 10
 	extensionSupportedPoints      uint16 = 11
 	extensionSignatureAlgorithms  uint16 = 13
+	extensionUseSRTP              uint16 = 14
 	extensionALPN                 uint16 = 16
 	extensionExtendedMasterSecret uint16 = 23
 	extensionSessionTicket        uint16 = 35
@@ -161,6 +162,12 @@
 	{signatureECDSA, hashSHA256},
 }
 
+// SRTP protection profiles, typed uint16 (see RFC 5764, section 4.1.2).
+const (
+	SRTP_AES128_CM_HMAC_SHA1_80 uint16 = 0x0001
+	SRTP_AES128_CM_HMAC_SHA1_32 uint16 = 0x0002
+)
+
 // ConnectionState records basic TLS details about the connection.
 type ConnectionState struct {
 	Version                    uint16                // TLS version used by the connection (e.g. VersionTLS12)
@@ -174,6 +181,7 @@
 	PeerCertificates           []*x509.Certificate   // certificate chain presented by remote peer
 	VerifiedChains             [][]*x509.Certificate // verified chains built from PeerCertificates
 	ChannelID                  *ecdsa.PublicKey      // the channel ID for this connection
+	SRTPProtectionProfile      uint16                // the negotiated DTLS-SRTP protection profile
 }
 
 // ClientAuthType declares the policy the server will follow for
@@ -334,6 +342,10 @@
 	// with the PSK cipher suites.
 	PreSharedKeyIdentity string
 
+	// SRTPProtectionProfiles, if not nil, is the list of DTLS-SRTP protection
+	// profiles a client offers, or that a server accepts in preference order.
+	SRTPProtectionProfiles []uint16
+
 	// Bugs specifies optional misbehaviour to be used for testing other
 	// implementations.
 	Bugs ProtocolBugs
@@ -520,6 +532,15 @@
 	// RSAServerKeyExchange, if true, causes the server to send a
 	// ServerKeyExchange message in the plain RSA key exchange.
 	RSAServerKeyExchange bool
+
+	// SRTPMasterKeyIdentifer, if not empty, is the SRTP MKI value that the
+	// client offers in its use_srtp extension. MKI support is otherwise
+	// missing: the client rejects a ServerHello whose use_srtp carries an MKI.
+	SRTPMasterKeyIdentifer string
+
+	// SendSRTPProtectionProfile, if non-zero, replaces the SRTP profile in the
+	// ServerHello; the server's ConnectionState still reports the negotiated one.
+	SendSRTPProtectionProfile uint16
 }
 
 func (c *Config) serverInit() {
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index f4b4c36..94d7434 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -56,6 +56,8 @@
 
 	channelID *ecdsa.PublicKey
 
+	srtpProtectionProfile uint16
+
 	// input/output
 	in, out  halfConn     // in.Mutex < out.Mutex
 	rawInput *block       // raw input, right off the wire
@@ -1184,6 +1186,7 @@
 		state.VerifiedChains = c.verifiedChains
 		state.ServerName = c.serverName
 		state.ChannelID = c.channelID
+		state.SRTPProtectionProfile = c.srtpProtectionProfile
 	}
 
 	return state
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index 702797b..71712a9 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -56,21 +56,23 @@
 	}
 
 	hello := &clientHelloMsg{
-		isDTLS:               c.isDTLS,
-		vers:                 c.config.maxVersion(),
-		compressionMethods:   []uint8{compressionNone},
-		random:               make([]byte, 32),
-		ocspStapling:         true,
-		serverName:           c.config.ServerName,
-		supportedCurves:      c.config.curvePreferences(),
-		supportedPoints:      []uint8{pointFormatUncompressed},
-		nextProtoNeg:         len(c.config.NextProtos) > 0,
-		secureRenegotiation:  []byte{},
-		alpnProtocols:        c.config.NextProtos,
-		duplicateExtension:   c.config.Bugs.DuplicateExtension,
-		channelIDSupported:   c.config.ChannelID != nil,
-		npnLast:              c.config.Bugs.SwapNPNAndALPN,
-		extendedMasterSecret: c.config.maxVersion() >= VersionTLS10,
+		isDTLS:                  c.isDTLS,
+		vers:                    c.config.maxVersion(),
+		compressionMethods:      []uint8{compressionNone},
+		random:                  make([]byte, 32),
+		ocspStapling:            true,
+		serverName:              c.config.ServerName,
+		supportedCurves:         c.config.curvePreferences(),
+		supportedPoints:         []uint8{pointFormatUncompressed},
+		nextProtoNeg:            len(c.config.NextProtos) > 0,
+		secureRenegotiation:     []byte{},
+		alpnProtocols:           c.config.NextProtos,
+		duplicateExtension:      c.config.Bugs.DuplicateExtension,
+		channelIDSupported:      c.config.ChannelID != nil,
+		npnLast:                 c.config.Bugs.SwapNPNAndALPN,
+		extendedMasterSecret:    c.config.maxVersion() >= VersionTLS10,
+		srtpProtectionProfiles:  c.config.SRTPProtectionProfiles,
+		srtpMasterKeyIdentifier: c.config.Bugs.SRTPMasterKeyIdentifer,
 	}
 
 	if c.config.Bugs.SendClientVersion != 0 {
@@ -666,6 +668,25 @@
 		return false, errors.New("server advertised unrequested Channel ID extension")
 	}
 
+	// A DTLS-SRTP profile chosen by the server must carry no MKI and must be
+	// one of the profiles this client offered.
+	if selected := hs.serverHello.srtpProtectionProfile; selected != 0 {
+		if len(hs.serverHello.srtpMasterKeyIdentifier) != 0 {
+			return false, errors.New("tls: server selected SRTP MKI value")
+		}
+		offered := false
+		for _, candidate := range c.config.SRTPProtectionProfiles {
+			if candidate == selected {
+				offered = true
+				break
+			}
+		}
+		if !offered {
+			return false, errors.New("tls: server advertised unsupported SRTP profile")
+		}
+		c.srtpProtectionProfile = selected
+	}
+
 	if hs.serverResumedSession() {
 		// Restore masterSecret and peerCerts from previous state
 		hs.masterSecret = hs.session.masterSecret
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index 12a9f3d..cb3b5c4 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -7,28 +7,30 @@
 import "bytes"
 
 type clientHelloMsg struct {
-	raw                  []byte
-	isDTLS               bool
-	vers                 uint16
-	random               []byte
-	sessionId            []byte
-	cookie               []byte
-	cipherSuites         []uint16
-	compressionMethods   []uint8
-	nextProtoNeg         bool
-	serverName           string
-	ocspStapling         bool
-	supportedCurves      []CurveID
-	supportedPoints      []uint8
-	ticketSupported      bool
-	sessionTicket        []uint8
-	signatureAndHashes   []signatureAndHash
-	secureRenegotiation  []byte
-	alpnProtocols        []string
-	duplicateExtension   bool
-	channelIDSupported   bool
-	npnLast              bool
-	extendedMasterSecret bool
+	raw                     []byte
+	isDTLS                  bool
+	vers                    uint16
+	random                  []byte
+	sessionId               []byte
+	cookie                  []byte
+	cipherSuites            []uint16
+	compressionMethods      []uint8
+	nextProtoNeg            bool
+	serverName              string
+	ocspStapling            bool
+	supportedCurves         []CurveID
+	supportedPoints         []uint8
+	ticketSupported         bool
+	sessionTicket           []uint8
+	signatureAndHashes      []signatureAndHash
+	secureRenegotiation     []byte
+	alpnProtocols           []string
+	duplicateExtension      bool
+	channelIDSupported      bool
+	npnLast                 bool
+	extendedMasterSecret    bool
+	srtpProtectionProfiles  []uint16
+	srtpMasterKeyIdentifier string
 }
 
 func (m *clientHelloMsg) equal(i interface{}) bool {
@@ -59,7 +61,9 @@
 		m.duplicateExtension == m1.duplicateExtension &&
 		m.channelIDSupported == m1.channelIDSupported &&
 		m.npnLast == m1.npnLast &&
-		m.extendedMasterSecret == m1.extendedMasterSecret
+		m.extendedMasterSecret == m1.extendedMasterSecret &&
+		eqUint16s(m.srtpProtectionProfiles, m1.srtpProtectionProfiles) &&
+		m.srtpMasterKeyIdentifier == m1.srtpMasterKeyIdentifier
 }
 
 func (m *clientHelloMsg) marshal() []byte {
@@ -124,6 +128,11 @@
 	if m.extendedMasterSecret {
 		numExtensions++
 	}
+	if len(m.srtpProtectionProfiles) > 0 {
+		extensionsLength += 2 + 2*len(m.srtpProtectionProfiles)
+		extensionsLength += 1 + len(m.srtpMasterKeyIdentifier)
+		numExtensions++
+	}
 	if numExtensions > 0 {
 		extensionsLength += 4 * numExtensions
 		length += 2 + extensionsLength
@@ -334,6 +343,29 @@
 		z[1] = byte(extensionExtendedMasterSecret & 0xff)
 		z = z[4:]
 	}
+	if len(m.srtpProtectionProfiles) > 0 {
+		z[0] = byte(extensionUseSRTP >> 8)
+		z[1] = byte(extensionUseSRTP & 0xff)
+
+		profilesLen := 2 * len(m.srtpProtectionProfiles)
+		mkiLen := len(m.srtpMasterKeyIdentifier)
+		l := 2 + profilesLen + 1 + mkiLen
+		z[2] = byte(l >> 8)
+		z[3] = byte(l & 0xff)
+
+		z[4] = byte(profilesLen >> 8)
+		z[5] = byte(profilesLen & 0xff)
+		z = z[6:]
+		for _, p := range m.srtpProtectionProfiles {
+			z[0] = byte(p >> 8)
+			z[1] = byte(p & 0xff)
+			z = z[2:]
+		}
+
+		z[0] = byte(mkiLen)
+		copy(z[1:], []byte(m.srtpMasterKeyIdentifier))
+		z = z[1+mkiLen:]
+	}
 
 	m.raw = x
 
@@ -538,6 +570,25 @@
 				return false
 			}
 			m.extendedMasterSecret = true
+		case extensionUseSRTP:
+			if length < 2 {
+				return false
+			}
+			l := int(data[0])<<8 | int(data[1])
+			if l > length-2 || l%2 != 0 {
+				return false
+			}
+			n := l / 2
+			m.srtpProtectionProfiles = make([]uint16, n)
+			d := data[2:length]
+			for i := 0; i < n; i++ {
+				m.srtpProtectionProfiles[i] = uint16(d[0])<<8 | uint16(d[1])
+				d = d[2:]
+			}
+			if len(d) < 1 || int(d[0]) != len(d)-1 {
+				return false
+			}
+			m.srtpMasterKeyIdentifier = string(d[1:])
 		}
 		data = data[length:]
 	}
@@ -546,22 +597,24 @@
 }
 
 type serverHelloMsg struct {
-	raw                  []byte
-	isDTLS               bool
-	vers                 uint16
-	random               []byte
-	sessionId            []byte
-	cipherSuite          uint16
-	compressionMethod    uint8
-	nextProtoNeg         bool
-	nextProtos           []string
-	ocspStapling         bool
-	ticketSupported      bool
-	secureRenegotiation  []byte
-	alpnProtocol         string
-	duplicateExtension   bool
-	channelIDRequested   bool
-	extendedMasterSecret bool
+	raw                     []byte
+	isDTLS                  bool
+	vers                    uint16
+	random                  []byte
+	sessionId               []byte
+	cipherSuite             uint16
+	compressionMethod       uint8
+	nextProtoNeg            bool
+	nextProtos              []string
+	ocspStapling            bool
+	ticketSupported         bool
+	secureRenegotiation     []byte
+	alpnProtocol            string
+	duplicateExtension      bool
+	channelIDRequested      bool
+	extendedMasterSecret    bool
+	srtpProtectionProfile   uint16
+	srtpMasterKeyIdentifier string
 }
 
 func (m *serverHelloMsg) equal(i interface{}) bool {
@@ -586,7 +639,9 @@
 		m.alpnProtocol == m1.alpnProtocol &&
 		m.duplicateExtension == m1.duplicateExtension &&
 		m.channelIDRequested == m1.channelIDRequested &&
-		m.extendedMasterSecret == m1.extendedMasterSecret
+		m.extendedMasterSecret == m1.extendedMasterSecret &&
+		m.srtpProtectionProfile == m1.srtpProtectionProfile &&
+		m.srtpMasterKeyIdentifier == m1.srtpMasterKeyIdentifier
 }
 
 func (m *serverHelloMsg) marshal() []byte {
@@ -633,6 +688,10 @@
 	if m.extendedMasterSecret {
 		numExtensions++
 	}
+	if m.srtpProtectionProfile != 0 {
+		extensionsLength += 2 + 2 + 1 + len(m.srtpMasterKeyIdentifier)
+		numExtensions++
+	}
 
 	if numExtensions > 0 {
 		extensionsLength += 4 * numExtensions
@@ -734,6 +793,21 @@
 		z[1] = byte(extensionExtendedMasterSecret & 0xff)
 		z = z[4:]
 	}
+	if m.srtpProtectionProfile != 0 {
+		z[0] = byte(extensionUseSRTP >> 8)
+		z[1] = byte(extensionUseSRTP & 0xff)
+		l := 2 + 2 + 1 + len(m.srtpMasterKeyIdentifier)
+		z[2] = byte(l >> 8)
+		z[3] = byte(l & 0xff)
+		z[4] = 0
+		z[5] = 2
+		z[6] = byte(m.srtpProtectionProfile >> 8)
+		z[7] = byte(m.srtpProtectionProfile & 0xff)
+		l = len(m.srtpMasterKeyIdentifier)
+		z[8] = byte(l)
+		copy(z[9:], []byte(m.srtpMasterKeyIdentifier))
+		z = z[9+l:]
+	}
 
 	m.raw = x
 
@@ -846,6 +920,20 @@
 				return false
 			}
 			m.extendedMasterSecret = true
+		case extensionUseSRTP:
+			if length < 2+2+1 {
+				return false
+			}
+			if data[0] != 0 || data[1] != 2 {
+				return false
+			}
+			m.srtpProtectionProfile = uint16(data[2])<<8 | uint16(data[3])
+			d := data[4:length]
+			l := int(d[0])
+			if l != len(d)-1 {
+				return false
+			}
+			m.srtpMasterKeyIdentifier = string(d[1:])
 		}
 		data = data[length:]
 	}
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 89c7b8d..bd6f702 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -270,6 +270,23 @@
 		hs.hello.channelIDRequested = true
 	}
 
+	// Select the first of the server's own SRTP profiles that the client
+	// also offered, so the server's preference order wins.
+	if offered := hs.clientHello.srtpProtectionProfiles; offered != nil {
+	SRTPLoop:
+		for _, ours := range c.config.SRTPProtectionProfiles {
+			for _, theirs := range offered {
+				if ours == theirs {
+					hs.hello.srtpProtectionProfile, c.srtpProtectionProfile = ours, ours
+					break SRTPLoop
+				}
+			}
+		}
+	}
+	if override := c.config.Bugs.SendSRTPProtectionProfile; override != 0 {
+		hs.hello.srtpProtectionProfile = override
+	}
+
 	_, hs.ecdsaOk = hs.cert.PrivateKey.(*ecdsa.PrivateKey)
 
 	if hs.checkForResumption() {
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index a302687..8c661a6 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -129,6 +129,9 @@
 	// expectedNextProtoType, if non-zero, is the expected next
 	// protocol negotiation mechanism.
 	expectedNextProtoType int
+	// expectedSRTPProtectionProfile is the DTLS-SRTP profile that
+	// should be negotiated. If zero, none should be negotiated.
+	expectedSRTPProtectionProfile uint16
 	// messageLen is the length, in bytes, of the test message that will be
 	// sent.
 	messageLen int
@@ -357,7 +360,7 @@
 		name:     "FragmentAlert",
 		config: Config{
 			Bugs: ProtocolBugs{
-				FragmentAlert: true,
+				FragmentAlert:     true,
 				SendSpuriousAlert: true,
 			},
 		},
@@ -589,6 +592,10 @@
 		}
 	}
 
+	if p := tlsConn.ConnectionState().SRTPProtectionProfile; p != test.expectedSRTPProtectionProfile {
+		return fmt.Errorf("SRTP profile mismatch: got %d, wanted %d", p, test.expectedSRTPProtectionProfile)
+	}
+
 	if test.shimWritesFirst {
 		var buf [5]byte
 		_, err := io.ReadFull(tlsConn, buf[:])
@@ -1741,6 +1748,82 @@
 		shouldFail:    true,
 		expectedError: ":DECODE_ERROR:",
 	})
+	// Basic DTLS-SRTP tests. Include fake profiles to ensure they
+	// are ignored.
+	testCases = append(testCases, testCase{
+		protocol: dtls,
+		name:     "SRTP-Client",
+		config: Config{
+			SRTPProtectionProfiles: []uint16{40, SRTP_AES128_CM_HMAC_SHA1_80, 42},
+		},
+		flags: []string{
+			"-srtp-profiles",
+			"SRTP_AES128_CM_SHA1_80:SRTP_AES128_CM_SHA1_32",
+		},
+		expectedSRTPProtectionProfile: SRTP_AES128_CM_HMAC_SHA1_80,
+	})
+	testCases = append(testCases, testCase{
+		protocol: dtls,
+		testType: serverTest,
+		name:     "SRTP-Server",
+		config: Config{
+			SRTPProtectionProfiles: []uint16{40, SRTP_AES128_CM_HMAC_SHA1_80, 42},
+		},
+		flags: []string{
+			"-srtp-profiles",
+			"SRTP_AES128_CM_SHA1_80:SRTP_AES128_CM_SHA1_32",
+		},
+		expectedSRTPProtectionProfile: SRTP_AES128_CM_HMAC_SHA1_80,
+	})
+	// Test that the MKI is ignored.
+	testCases = append(testCases, testCase{
+		protocol: dtls,
+		testType: serverTest,
+		name:     "SRTP-Server-IgnoreMKI",
+		config: Config{
+			SRTPProtectionProfiles: []uint16{SRTP_AES128_CM_HMAC_SHA1_80},
+			Bugs: ProtocolBugs{
+				SRTPMasterKeyIdentifer: "bogus",
+			},
+		},
+		flags: []string{
+			"-srtp-profiles",
+			"SRTP_AES128_CM_SHA1_80:SRTP_AES128_CM_SHA1_32",
+		},
+		expectedSRTPProtectionProfile: SRTP_AES128_CM_HMAC_SHA1_80,
+	})
+	// Test that SRTP isn't negotiated on the server if there were
+	// no matching profiles.
+	testCases = append(testCases, testCase{
+		protocol: dtls,
+		testType: serverTest,
+		name:     "SRTP-Server-NoMatch",
+		config: Config{
+			SRTPProtectionProfiles: []uint16{100, 101, 102},
+		},
+		flags: []string{
+			"-srtp-profiles",
+			"SRTP_AES128_CM_SHA1_80:SRTP_AES128_CM_SHA1_32",
+		},
+		expectedSRTPProtectionProfile: 0,
+	})
+	// Test that the server returning an invalid SRTP profile is
+	// flagged as an error by the client.
+	testCases = append(testCases, testCase{
+		protocol: dtls,
+		name:     "SRTP-Client-NoMatch",
+		config: Config{
+			Bugs: ProtocolBugs{
+				SendSRTPProtectionProfile: SRTP_AES128_CM_HMAC_SHA1_32,
+			},
+		},
+		flags: []string{
+			"-srtp-profiles",
+			"SRTP_AES128_CM_SHA1_80",
+		},
+		shouldFail:    true,
+		expectedError: ":BAD_SRTP_PROTECTION_PROFILE_LIST:",
+	})
 }
 
 func addResumptionVersionTests() {