Reimplement serverHelloMsg with byteBuilder in Go.

[Originally written by nharper and tweaked by davidben.]

This will end up being split in two with most of the ServerHello
extensions being serializable in both ServerHello and
EncryptedExtensions depending on version.

Change-Id: Ida5876d55fbafb982bc2e5fdaf82872e733d6536
Reviewed-on: https://boringssl-review.googlesource.com/8580
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index 2dff77e..b4f0c5b 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -604,211 +604,109 @@
 		return m.raw
 	}
 
-	length := 38 + len(m.sessionId)
-	numExtensions := 0
-	extensionsLength := 0
-
-	nextProtoLen := 0
-	if m.nextProtoNeg {
-		numExtensions++
-		for _, v := range m.nextProtos {
-			nextProtoLen += len(v)
-		}
-		nextProtoLen += len(m.nextProtos)
-		extensionsLength += nextProtoLen
-	}
-	if m.ocspStapling {
-		numExtensions++
-	}
-	if m.ticketSupported {
-		numExtensions++
-	}
-	if m.secureRenegotiation != nil {
-		extensionsLength += 1 + len(m.secureRenegotiation)
-		numExtensions++
-	}
-	if m.duplicateExtension {
-		numExtensions += 2
-	}
-	if m.channelIDRequested {
-		numExtensions++
-	}
-	if alpnLen := len(m.alpnProtocol); alpnLen > 0 || m.alpnProtocolEmpty {
-		if alpnLen >= 256 {
-			panic("invalid ALPN protocol")
-		}
-		extensionsLength += 2 + 1 + alpnLen
-		numExtensions++
-	}
-	if m.extendedMasterSecret {
-		numExtensions++
-	}
-	if m.srtpProtectionProfile != 0 {
-		extensionsLength += 2 + 2 + 1 + len(m.srtpMasterKeyIdentifier)
-		numExtensions++
-	}
-	if m.sctList != nil {
-		extensionsLength += len(m.sctList)
-		numExtensions++
-	}
-	if l := len(m.customExtension); l > 0 {
-		extensionsLength += l
-		numExtensions++
-	}
-
-	if numExtensions > 0 {
-		extensionsLength += 4 * numExtensions
-		length += 2 + extensionsLength
-	}
-
-	x := make([]byte, 4+length)
-	x[0] = typeServerHello
-	x[1] = uint8(length >> 16)
-	x[2] = uint8(length >> 8)
-	x[3] = uint8(length)
+	handshakeMsg := newByteBuilder()
+	handshakeMsg.addU8(typeServerHello)
+	hello := handshakeMsg.addU24LengthPrefixed()
 	vers := versionToWire(m.vers, m.isDTLS)
-	x[4] = uint8(vers >> 8)
-	x[5] = uint8(vers)
-	copy(x[6:38], m.random)
-	x[38] = uint8(len(m.sessionId))
-	copy(x[39:39+len(m.sessionId)], m.sessionId)
-	z := x[39+len(m.sessionId):]
-	z[0] = uint8(m.cipherSuite >> 8)
-	z[1] = uint8(m.cipherSuite)
-	z[2] = uint8(m.compressionMethod)
+	hello.addU16(vers)
+	hello.addBytes(m.random)
+	sessionId := hello.addU8LengthPrefixed()
+	sessionId.addBytes(m.sessionId)
+	hello.addU16(m.cipherSuite)
+	hello.addU8(m.compressionMethod)
 
-	z = z[3:]
-	if numExtensions > 0 {
-		z[0] = byte(extensionsLength >> 8)
-		z[1] = byte(extensionsLength)
-		z = z[2:]
-	}
+	extensions := hello.addU16LengthPrefixed()
 	if m.duplicateExtension {
 		// Add a duplicate bogus extension at the beginning and end.
-		z[0] = 0xff
-		z[1] = 0xff
-		z = z[4:]
+		extensions.addU16(0xffff)
+		extensions.addU16(0) // length = 0 for empty extension
 	}
 	if m.nextProtoNeg && !m.npnLast {
-		z[0] = byte(extensionNextProtoNeg >> 8)
-		z[1] = byte(extensionNextProtoNeg & 0xff)
-		z[2] = byte(nextProtoLen >> 8)
-		z[3] = byte(nextProtoLen)
-		z = z[4:]
+		extensions.addU16(extensionNextProtoNeg)
+		extension := extensions.addU16LengthPrefixed()
 
 		for _, v := range m.nextProtos {
-			l := len(v)
-			if l > 255 {
-				l = 255
+			if len(v) > 255 {
+				v = v[:255]
 			}
-			z[0] = byte(l)
-			copy(z[1:], []byte(v[0:l]))
-			z = z[1+l:]
+			npn := extension.addU8LengthPrefixed()
+			npn.addBytes([]byte(v))
 		}
 	}
 	if m.ocspStapling {
-		z[0] = byte(extensionStatusRequest >> 8)
-		z[1] = byte(extensionStatusRequest)
-		z = z[4:]
+		extensions.addU16(extensionStatusRequest)
+		extensions.addU16(0)
 	}
 	if m.ticketSupported {
-		z[0] = byte(extensionSessionTicket >> 8)
-		z[1] = byte(extensionSessionTicket)
-		z = z[4:]
+		extensions.addU16(extensionSessionTicket)
+		extensions.addU16(0)
 	}
 	if m.secureRenegotiation != nil {
-		z[0] = byte(extensionRenegotiationInfo >> 8)
-		z[1] = byte(extensionRenegotiationInfo & 0xff)
-		z[2] = 0
-		z[3] = byte(1 + len(m.secureRenegotiation))
-		z[4] = byte(len(m.secureRenegotiation))
-		z = z[5:]
-		copy(z, m.secureRenegotiation)
-		z = z[len(m.secureRenegotiation):]
+		extensions.addU16(extensionRenegotiationInfo)
+		extension := extensions.addU16LengthPrefixed()
+		secureRenego := extension.addU8LengthPrefixed()
+		secureRenego.addBytes(m.secureRenegotiation)
 	}
-	if alpnLen := len(m.alpnProtocol); alpnLen > 0 || m.alpnProtocolEmpty {
-		z[0] = byte(extensionALPN >> 8)
-		z[1] = byte(extensionALPN & 0xff)
-		l := 2 + 1 + alpnLen
-		z[2] = byte(l >> 8)
-		z[3] = byte(l)
-		l -= 2
-		z[4] = byte(l >> 8)
-		z[5] = byte(l)
-		l -= 1
-		z[6] = byte(l)
-		copy(z[7:], []byte(m.alpnProtocol))
-		z = z[7+alpnLen:]
+	if len(m.alpnProtocol) > 0 || m.alpnProtocolEmpty {
+		extensions.addU16(extensionALPN)
+		extension := extensions.addU16LengthPrefixed()
+
+		protocolNameList := extension.addU16LengthPrefixed()
+		protocolName := protocolNameList.addU8LengthPrefixed()
+		protocolName.addBytes([]byte(m.alpnProtocol))
 	}
 	if m.channelIDRequested {
-		z[0] = byte(extensionChannelID >> 8)
-		z[1] = byte(extensionChannelID & 0xff)
-		z = z[4:]
+		extensions.addU16(extensionChannelID)
+		extensions.addU16(0)
 	}
 	if m.duplicateExtension {
 		// Add a duplicate bogus extension at the beginning and end.
-		z[0] = 0xff
-		z[1] = 0xff
-		z = z[4:]
+		extensions.addU16(0xffff)
+		extensions.addU16(0)
 	}
 	if m.extendedMasterSecret {
-		z[0] = byte(extensionExtendedMasterSecret >> 8)
-		z[1] = byte(extensionExtendedMasterSecret & 0xff)
-		z = z[4:]
+		extensions.addU16(extensionExtendedMasterSecret)
+		extensions.addU16(0)
 	}
 	if m.srtpProtectionProfile != 0 {
-		z[0] = byte(extensionUseSRTP >> 8)
-		z[1] = byte(extensionUseSRTP & 0xff)
-		l := 2 + 2 + 1 + len(m.srtpMasterKeyIdentifier)
-		z[2] = byte(l >> 8)
-		z[3] = byte(l & 0xff)
-		z[4] = 0
-		z[5] = 2
-		z[6] = byte(m.srtpProtectionProfile >> 8)
-		z[7] = byte(m.srtpProtectionProfile & 0xff)
-		l = len(m.srtpMasterKeyIdentifier)
-		z[8] = byte(l)
-		copy(z[9:], []byte(m.srtpMasterKeyIdentifier))
-		z = z[9+l:]
+		extensions.addU16(extensionUseSRTP)
+		extension := extensions.addU16LengthPrefixed()
+
+		srtpProtectionProfiles := extension.addU16LengthPrefixed()
+		srtpProtectionProfiles.addU8(byte(m.srtpProtectionProfile >> 8))
+		srtpProtectionProfiles.addU8(byte(m.srtpProtectionProfile))
+		srtpMki := extension.addU8LengthPrefixed()
+		srtpMki.addBytes([]byte(m.srtpMasterKeyIdentifier))
 	}
 	if m.sctList != nil {
-		z[0] = byte(extensionSignedCertificateTimestamp >> 8)
-		z[1] = byte(extensionSignedCertificateTimestamp & 0xff)
-		l := len(m.sctList)
-		z[2] = byte(l >> 8)
-		z[3] = byte(l & 0xff)
-		copy(z[4:], m.sctList)
-		z = z[4+l:]
+		extensions.addU16(extensionSignedCertificateTimestamp)
+		extension := extensions.addU16LengthPrefixed()
+		extension.addBytes(m.sctList)
 	}
 	if l := len(m.customExtension); l > 0 {
-		z[0] = byte(extensionCustom >> 8)
-		z[1] = byte(extensionCustom & 0xff)
-		z[2] = byte(l >> 8)
-		z[3] = byte(l & 0xff)
-		copy(z[4:], []byte(m.customExtension))
-		z = z[4+l:]
+		extensions.addU16(extensionCustom)
+		customExt := extensions.addU16LengthPrefixed()
+		customExt.addBytes([]byte(m.customExtension))
 	}
 	if m.nextProtoNeg && m.npnLast {
-		z[0] = byte(extensionNextProtoNeg >> 8)
-		z[1] = byte(extensionNextProtoNeg & 0xff)
-		z[2] = byte(nextProtoLen >> 8)
-		z[3] = byte(nextProtoLen)
-		z = z[4:]
+		extensions.addU16(extensionNextProtoNeg)
+		extension := extensions.addU16LengthPrefixed()
 
 		for _, v := range m.nextProtos {
-			l := len(v)
-			if l > 255 {
-				l = 255
+			if len(v) > 255 {
+				v = v[0:255]
 			}
-			z[0] = byte(l)
-			copy(z[1:], []byte(v[0:l]))
-			z = z[1+l:]
+			npn := extension.addU8LengthPrefixed()
+			npn.addBytes([]byte(v))
 		}
 	}
 
-	m.raw = x
+	if extensions.len() == 0 {
+		hello.discardChild()
+	}
 
-	return x
+	m.raw = handshakeMsg.finish()
+	return m.raw
 }
 
 func (m *serverHelloMsg) unmarshal(data []byte) bool {