Implement ALPN in runner.go.

Imported from upstream's https://codereview.appspot.com/108710046.

Change-Id: I66c879dcc9fd09446ac1a8380f796b1d68c89e4e
Reviewed-on: https://boringssl-review.googlesource.com/1751
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index 9af3063..acaa89c 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -76,6 +76,7 @@
 	extensionSupportedCurves     uint16 = 10
 	extensionSupportedPoints     uint16 = 11
 	extensionSignatureAlgorithms uint16 = 13
+	extensionALPN                uint16 = 16
 	extensionSessionTicket       uint16 = 35
 	extensionNextProtoNeg        uint16 = 13172 // not IANA assigned
 	extensionRenegotiationInfo   uint16 = 0xff01
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index 708d282..d61c6d2 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -43,6 +43,18 @@
 	c.sendHandshakeSeq = 0
 	c.recvHandshakeSeq = 0
 
+	nextProtosLength := 0
+	for _, proto := range c.config.NextProtos {
+		if l := len(proto); l == 0 || l > 255 {
+			return errors.New("tls: invalid NextProtos value")
+		} else {
+			nextProtosLength += 1 + l
+		}
+	}
+	if nextProtosLength > 0xffff {
+		return errors.New("tls: NextProtos values too large")
+	}
+
 	hello := &clientHelloMsg{
 		isDTLS:              c.isDTLS,
 		vers:                c.config.maxVersion(),
@@ -54,6 +66,7 @@
 		supportedPoints:     []uint8{pointFormatUncompressed},
 		nextProtoNeg:        len(c.config.NextProtos) > 0,
 		secureRenegotiation: true,
+		alpnProtocols:       c.config.NextProtos,
 		duplicateExtension:  c.config.Bugs.DuplicateExtension,
 		channelIDSupported:  c.config.ChannelID != nil,
 	}
@@ -565,11 +578,31 @@
 		return false, errors.New("tls: server selected unsupported compression format")
 	}
 
-	if !hs.hello.nextProtoNeg && hs.serverHello.nextProtoNeg {
+	clientDidNPN := hs.hello.nextProtoNeg
+	clientDidALPN := len(hs.hello.alpnProtocols) > 0
+	serverHasNPN := hs.serverHello.nextProtoNeg
+	serverHasALPN := len(hs.serverHello.alpnProtocol) > 0
+
+	if !clientDidNPN && serverHasNPN {
 		c.sendAlert(alertHandshakeFailure)
 		return false, errors.New("server advertised unrequested NPN extension")
 	}
 
+	if !clientDidALPN && serverHasALPN {
+		c.sendAlert(alertHandshakeFailure)
+		return false, errors.New("server advertised unrequested ALPN extension")
+	}
+
+	if serverHasNPN && serverHasALPN {
+		c.sendAlert(alertHandshakeFailure)
+		return false, errors.New("server advertised both NPN and ALPN extensions")
+	}
+
+	if serverHasALPN {
+		c.clientProtocol = hs.serverHello.alpnProtocol
+		c.clientProtocolFallback = false
+	}
+
 	if !hs.hello.channelIDSupported && hs.serverHello.channelIDRequested {
 		c.sendAlert(alertHandshakeFailure)
 		return false, errors.New("server advertised unrequested Channel ID extension")
@@ -750,20 +783,20 @@
 	return serverAddr.String()
 }
 
-// mutualProtocol finds the mutual Next Protocol Negotiation protocol given the
-// set of client and server supported protocols. The set of client supported
-// protocols must not be empty. It returns the resulting protocol and flag
+// mutualProtocol finds the mutual Next Protocol Negotiation or ALPN protocol
+// given list of possible protocols and a list of the preference order. The
+// first list must not be empty. It returns the resulting protocol and flag
 // indicating if the fallback case was reached.
-func mutualProtocol(clientProtos, serverProtos []string) (string, bool) {
-	for _, s := range serverProtos {
-		for _, c := range clientProtos {
+func mutualProtocol(protos, preferenceProtos []string) (string, bool) {
+	for _, s := range preferenceProtos {
+		for _, c := range protos {
 			if s == c {
 				return s, false
 			}
 		}
 	}
 
-	return clientProtos[0], true
+	return protos[0], true
 }
 
 // writeIntPadded writes x into b, padded up with leading zeros as
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index 472aa87..f0a1493 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -24,6 +24,7 @@
 	sessionTicket       []uint8
 	signatureAndHashes  []signatureAndHash
 	secureRenegotiation bool
+	alpnProtocols       []string
 	duplicateExtension  bool
 	channelIDSupported  bool
 }
@@ -51,6 +52,7 @@
 		bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
 		eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
 		m.secureRenegotiation == m1.secureRenegotiation &&
+		eqStrings(m.alpnProtocols, m1.alpnProtocols) &&
 		m.duplicateExtension == m1.duplicateExtension &&
 		m.channelIDSupported == m1.channelIDSupported
 }
@@ -103,6 +105,17 @@
 	if m.channelIDSupported {
 		numExtensions++
 	}
+	if len(m.alpnProtocols) > 0 {
+		extensionsLength += 2
+		for _, s := range m.alpnProtocols {
+			if l := len(s); l == 0 || l > 255 {
+				panic("invalid ALPN protocol")
+			}
+			extensionsLength++
+			extensionsLength += len(s)
+		}
+		numExtensions++
+	}
 	if numExtensions > 0 {
 		extensionsLength += 4 * numExtensions
 		length += 2 + extensionsLength
@@ -266,6 +279,27 @@
 		z[3] = 1
 		z = z[5:]
 	}
+	if len(m.alpnProtocols) > 0 {
+		z[0] = byte(extensionALPN >> 8)
+		z[1] = byte(extensionALPN & 0xff)
+		lengths := z[2:]
+		z = z[6:]
+
+		stringsLength := 0
+		for _, s := range m.alpnProtocols {
+			l := len(s)
+			z[0] = byte(l)
+			copy(z[1:], s)
+			z = z[1+l:]
+			stringsLength += 1 + l
+		}
+
+		lengths[2] = byte(stringsLength >> 8)
+		lengths[3] = byte(stringsLength)
+		stringsLength += 2
+		lengths[0] = byte(stringsLength >> 8)
+		lengths[1] = byte(stringsLength)
+	}
 	if m.channelIDSupported {
 		z[0] = byte(extensionChannelID >> 8)
 		z[1] = byte(extensionChannelID & 0xff)
@@ -342,6 +376,7 @@
 	m.ticketSupported = false
 	m.sessionTicket = nil
 	m.signatureAndHashes = nil
+	m.alpnProtocols = nil
 
 	if len(data) == 0 {
 		// ClientHello is optionally followed by extension data
@@ -451,6 +486,24 @@
 				return false
 			}
 			m.secureRenegotiation = true
+		case extensionALPN:
+			if length < 2 {
+				return false
+			}
+			l := int(data[0])<<8 | int(data[1])
+			if l != length-2 {
+				return false
+			}
+			d := data[2:length]
+			for len(d) != 0 {
+				stringLen := int(d[0])
+				d = d[1:]
+				if stringLen == 0 || stringLen > len(d) {
+					return false
+				}
+				m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
+				d = d[stringLen:]
+			}
 		case extensionChannelID:
 			if length > 0 {
 				return false
@@ -476,6 +529,7 @@
 	ocspStapling        bool
 	ticketSupported     bool
 	secureRenegotiation bool
+	alpnProtocol        string
 	duplicateExtension  bool
 	channelIDRequested  bool
 }
@@ -498,6 +552,7 @@
 		m.ocspStapling == m1.ocspStapling &&
 		m.ticketSupported == m1.ticketSupported &&
 		m.secureRenegotiation == m1.secureRenegotiation &&
+		m.alpnProtocol == m1.alpnProtocol &&
 		m.duplicateExtension == m1.duplicateExtension &&
 		m.channelIDRequested == m1.channelIDRequested
 }
@@ -536,6 +591,14 @@
 	if m.channelIDRequested {
 		numExtensions++
 	}
+	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
+		if alpnLen >= 256 {
+			panic("invalid ALPN protocol")
+		}
+		extensionsLength += 2 + 1 + alpnLen
+		numExtensions++
+	}
+
 	if numExtensions > 0 {
 		extensionsLength += 4 * numExtensions
 		length += 2 + extensionsLength
@@ -603,6 +666,20 @@
 		z[3] = 1
 		z = z[5:]
 	}
+	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
+		z[0] = byte(extensionALPN >> 8)
+		z[1] = byte(extensionALPN & 0xff)
+		l := 2 + 1 + alpnLen
+		z[2] = byte(l >> 8)
+		z[3] = byte(l)
+		l -= 2
+		z[4] = byte(l >> 8)
+		z[5] = byte(l)
+		l -= 1
+		z[6] = byte(l)
+		copy(z[7:], []byte(m.alpnProtocol))
+		z = z[7+alpnLen:]
+	}
 	if m.channelIDRequested {
 		z[0] = byte(extensionChannelID >> 8)
 		z[1] = byte(extensionChannelID & 0xff)
@@ -644,6 +721,7 @@
 	m.nextProtos = nil
 	m.ocspStapling = false
 	m.ticketSupported = false
+	m.alpnProtocol = ""
 
 	if len(data) == 0 {
 		// ServerHello is optionally followed by extension data
@@ -698,6 +776,22 @@
 				return false
 			}
 			m.secureRenegotiation = true
+		case extensionALPN:
+			d := data[:length]
+			if len(d) < 3 {
+				return false
+			}
+			l := int(d[0])<<8 | int(d[1])
+			if l != len(d)-2 {
+				return false
+			}
+			d = d[2:]
+			l = int(d[0])
+			if l != len(d)-1 {
+				return false
+			}
+			d = d[1:]
+			m.alpnProtocol = string(d)
 		case extensionChannelID:
 			if length > 0 {
 				return false
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index e456891..7f6b521 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -220,13 +220,21 @@
 	if len(hs.clientHello.serverName) > 0 {
 		c.serverName = hs.clientHello.serverName
 	}
-	// Although sending an empty NPN extension is reasonable, Firefox has
-	// had a bug around this. Best to send nothing at all if
-	// config.NextProtos is empty. See
-	// https://code.google.com/p/go/issues/detail?id=5445.
-	if hs.clientHello.nextProtoNeg && len(config.NextProtos) > 0 {
-		hs.hello.nextProtoNeg = true
-		hs.hello.nextProtos = config.NextProtos
+
+	if len(hs.clientHello.alpnProtocols) > 0 {
+		if selectedProto, fallback := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); !fallback {
+			hs.hello.alpnProtocol = selectedProto
+			c.clientProtocol = selectedProto
+		}
+	} else {
+		// Although sending an empty NPN extension is reasonable, Firefox has
+		// had a bug around this. Best to send nothing at all if
+		// config.NextProtos is empty. See
+		// https://code.google.com/p/go/issues/detail?id=5445.
+		if hs.clientHello.nextProtoNeg && len(config.NextProtos) > 0 {
+			hs.hello.nextProtoNeg = true
+			hs.hello.nextProtos = config.NextProtos
+		}
 	}
 
 	if len(config.Certificates) == 0 {