Reimplement serverHelloMsg with byteBuilder in Go.
[Originally written by nharper and tweaked by davidben.]
This will end up being split in two with most of the ServerHello
extensions being serializable in both ServerHello and
EncryptedExtensions depending on version.
Change-Id: Ida5876d55fbafb982bc2e5fdaf82872e733d6536
Reviewed-on: https://boringssl-review.googlesource.com/8580
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index 2dff77e..b4f0c5b 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -604,211 +604,109 @@
return m.raw
}
- length := 38 + len(m.sessionId)
- numExtensions := 0
- extensionsLength := 0
-
- nextProtoLen := 0
- if m.nextProtoNeg {
- numExtensions++
- for _, v := range m.nextProtos {
- nextProtoLen += len(v)
- }
- nextProtoLen += len(m.nextProtos)
- extensionsLength += nextProtoLen
- }
- if m.ocspStapling {
- numExtensions++
- }
- if m.ticketSupported {
- numExtensions++
- }
- if m.secureRenegotiation != nil {
- extensionsLength += 1 + len(m.secureRenegotiation)
- numExtensions++
- }
- if m.duplicateExtension {
- numExtensions += 2
- }
- if m.channelIDRequested {
- numExtensions++
- }
- if alpnLen := len(m.alpnProtocol); alpnLen > 0 || m.alpnProtocolEmpty {
- if alpnLen >= 256 {
- panic("invalid ALPN protocol")
- }
- extensionsLength += 2 + 1 + alpnLen
- numExtensions++
- }
- if m.extendedMasterSecret {
- numExtensions++
- }
- if m.srtpProtectionProfile != 0 {
- extensionsLength += 2 + 2 + 1 + len(m.srtpMasterKeyIdentifier)
- numExtensions++
- }
- if m.sctList != nil {
- extensionsLength += len(m.sctList)
- numExtensions++
- }
- if l := len(m.customExtension); l > 0 {
- extensionsLength += l
- numExtensions++
- }
-
- if numExtensions > 0 {
- extensionsLength += 4 * numExtensions
- length += 2 + extensionsLength
- }
-
- x := make([]byte, 4+length)
- x[0] = typeServerHello
- x[1] = uint8(length >> 16)
- x[2] = uint8(length >> 8)
- x[3] = uint8(length)
+ handshakeMsg := newByteBuilder()
+ handshakeMsg.addU8(typeServerHello)
+ hello := handshakeMsg.addU24LengthPrefixed()
vers := versionToWire(m.vers, m.isDTLS)
- x[4] = uint8(vers >> 8)
- x[5] = uint8(vers)
- copy(x[6:38], m.random)
- x[38] = uint8(len(m.sessionId))
- copy(x[39:39+len(m.sessionId)], m.sessionId)
- z := x[39+len(m.sessionId):]
- z[0] = uint8(m.cipherSuite >> 8)
- z[1] = uint8(m.cipherSuite)
- z[2] = uint8(m.compressionMethod)
+ hello.addU16(vers)
+ hello.addBytes(m.random)
+ sessionId := hello.addU8LengthPrefixed()
+ sessionId.addBytes(m.sessionId)
+ hello.addU16(m.cipherSuite)
+ hello.addU8(m.compressionMethod)
- z = z[3:]
- if numExtensions > 0 {
- z[0] = byte(extensionsLength >> 8)
- z[1] = byte(extensionsLength)
- z = z[2:]
- }
+ extensions := hello.addU16LengthPrefixed()
if m.duplicateExtension {
// Add a duplicate bogus extension at the beginning and end.
- z[0] = 0xff
- z[1] = 0xff
- z = z[4:]
+ extensions.addU16(0xffff)
+ extensions.addU16(0) // length = 0 for empty extension
}
if m.nextProtoNeg && !m.npnLast {
- z[0] = byte(extensionNextProtoNeg >> 8)
- z[1] = byte(extensionNextProtoNeg & 0xff)
- z[2] = byte(nextProtoLen >> 8)
- z[3] = byte(nextProtoLen)
- z = z[4:]
+ extensions.addU16(extensionNextProtoNeg)
+ extension := extensions.addU16LengthPrefixed()
for _, v := range m.nextProtos {
- l := len(v)
- if l > 255 {
- l = 255
+ if len(v) > 255 {
+ v = v[:255]
}
- z[0] = byte(l)
- copy(z[1:], []byte(v[0:l]))
- z = z[1+l:]
+ npn := extension.addU8LengthPrefixed()
+ npn.addBytes([]byte(v))
}
}
if m.ocspStapling {
- z[0] = byte(extensionStatusRequest >> 8)
- z[1] = byte(extensionStatusRequest)
- z = z[4:]
+ extensions.addU16(extensionStatusRequest)
+ extensions.addU16(0)
}
if m.ticketSupported {
- z[0] = byte(extensionSessionTicket >> 8)
- z[1] = byte(extensionSessionTicket)
- z = z[4:]
+ extensions.addU16(extensionSessionTicket)
+ extensions.addU16(0)
}
if m.secureRenegotiation != nil {
- z[0] = byte(extensionRenegotiationInfo >> 8)
- z[1] = byte(extensionRenegotiationInfo & 0xff)
- z[2] = 0
- z[3] = byte(1 + len(m.secureRenegotiation))
- z[4] = byte(len(m.secureRenegotiation))
- z = z[5:]
- copy(z, m.secureRenegotiation)
- z = z[len(m.secureRenegotiation):]
+ extensions.addU16(extensionRenegotiationInfo)
+ extension := extensions.addU16LengthPrefixed()
+ secureRenego := extension.addU8LengthPrefixed()
+ secureRenego.addBytes(m.secureRenegotiation)
}
- if alpnLen := len(m.alpnProtocol); alpnLen > 0 || m.alpnProtocolEmpty {
- z[0] = byte(extensionALPN >> 8)
- z[1] = byte(extensionALPN & 0xff)
- l := 2 + 1 + alpnLen
- z[2] = byte(l >> 8)
- z[3] = byte(l)
- l -= 2
- z[4] = byte(l >> 8)
- z[5] = byte(l)
- l -= 1
- z[6] = byte(l)
- copy(z[7:], []byte(m.alpnProtocol))
- z = z[7+alpnLen:]
+ if len(m.alpnProtocol) > 0 || m.alpnProtocolEmpty {
+ extensions.addU16(extensionALPN)
+ extension := extensions.addU16LengthPrefixed()
+
+ protocolNameList := extension.addU16LengthPrefixed()
+ protocolName := protocolNameList.addU8LengthPrefixed()
+ protocolName.addBytes([]byte(m.alpnProtocol))
}
if m.channelIDRequested {
- z[0] = byte(extensionChannelID >> 8)
- z[1] = byte(extensionChannelID & 0xff)
- z = z[4:]
+ extensions.addU16(extensionChannelID)
+ extensions.addU16(0)
}
if m.duplicateExtension {
// Add a duplicate bogus extension at the beginning and end.
- z[0] = 0xff
- z[1] = 0xff
- z = z[4:]
+ extensions.addU16(0xffff)
+ extensions.addU16(0)
}
if m.extendedMasterSecret {
- z[0] = byte(extensionExtendedMasterSecret >> 8)
- z[1] = byte(extensionExtendedMasterSecret & 0xff)
- z = z[4:]
+ extensions.addU16(extensionExtendedMasterSecret)
+ extensions.addU16(0)
}
if m.srtpProtectionProfile != 0 {
- z[0] = byte(extensionUseSRTP >> 8)
- z[1] = byte(extensionUseSRTP & 0xff)
- l := 2 + 2 + 1 + len(m.srtpMasterKeyIdentifier)
- z[2] = byte(l >> 8)
- z[3] = byte(l & 0xff)
- z[4] = 0
- z[5] = 2
- z[6] = byte(m.srtpProtectionProfile >> 8)
- z[7] = byte(m.srtpProtectionProfile & 0xff)
- l = len(m.srtpMasterKeyIdentifier)
- z[8] = byte(l)
- copy(z[9:], []byte(m.srtpMasterKeyIdentifier))
- z = z[9+l:]
+ extensions.addU16(extensionUseSRTP)
+ extension := extensions.addU16LengthPrefixed()
+
+ srtpProtectionProfiles := extension.addU16LengthPrefixed()
+ srtpProtectionProfiles.addU8(byte(m.srtpProtectionProfile >> 8))
+ srtpProtectionProfiles.addU8(byte(m.srtpProtectionProfile))
+ srtpMki := extension.addU8LengthPrefixed()
+ srtpMki.addBytes([]byte(m.srtpMasterKeyIdentifier))
}
if m.sctList != nil {
- z[0] = byte(extensionSignedCertificateTimestamp >> 8)
- z[1] = byte(extensionSignedCertificateTimestamp & 0xff)
- l := len(m.sctList)
- z[2] = byte(l >> 8)
- z[3] = byte(l & 0xff)
- copy(z[4:], m.sctList)
- z = z[4+l:]
+ extensions.addU16(extensionSignedCertificateTimestamp)
+ extension := extensions.addU16LengthPrefixed()
+ extension.addBytes(m.sctList)
}
if l := len(m.customExtension); l > 0 {
- z[0] = byte(extensionCustom >> 8)
- z[1] = byte(extensionCustom & 0xff)
- z[2] = byte(l >> 8)
- z[3] = byte(l & 0xff)
- copy(z[4:], []byte(m.customExtension))
- z = z[4+l:]
+ extensions.addU16(extensionCustom)
+ customExt := extensions.addU16LengthPrefixed()
+ customExt.addBytes([]byte(m.customExtension))
}
if m.nextProtoNeg && m.npnLast {
- z[0] = byte(extensionNextProtoNeg >> 8)
- z[1] = byte(extensionNextProtoNeg & 0xff)
- z[2] = byte(nextProtoLen >> 8)
- z[3] = byte(nextProtoLen)
- z = z[4:]
+ extensions.addU16(extensionNextProtoNeg)
+ extension := extensions.addU16LengthPrefixed()
for _, v := range m.nextProtos {
- l := len(v)
- if l > 255 {
- l = 255
+ if len(v) > 255 {
+ v = v[0:255]
}
- z[0] = byte(l)
- copy(z[1:], []byte(v[0:l]))
- z = z[1+l:]
+ npn := extension.addU8LengthPrefixed()
+ npn.addBytes([]byte(v))
}
}
- m.raw = x
+ if extensions.len() == 0 {
+ hello.discardChild()
+ }
- return x
+ m.raw = handshakeMsg.finish()
+ return m.raw
}
func (m *serverHelloMsg) unmarshal(data []byte) bool {