runner: Parse SH/HRR/EE with byteReader.

Bug: 212
Change-Id: I454db0bfd59bac3729338c6f8d9e51efde0735eb
Reviewed-on: https://boringssl-review.googlesource.com/23446
Reviewed-by: Steven Valdez <svaldez@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index 766d515..bec97dd 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -537,11 +537,7 @@
 
 		srtpProtectionProfiles := useSrtpExt.addU16LengthPrefixed()
 		for _, p := range m.srtpProtectionProfiles {
-			// An SRTPProtectionProfile is defined as uint8[2],
-			// not uint16. For some reason, we're storing it
-			// as a uint16.
-			srtpProtectionProfiles.addU8(byte(p >> 8))
-			srtpProtectionProfiles.addU8(byte(p))
+			srtpProtectionProfiles.addU16(p)
 		}
 		srtpMki := useSrtpExt.addU8LengthPrefixed()
 		srtpMki.addBytes([]byte(m.srtpMasterKeyIdentifier))
@@ -967,76 +963,56 @@
 }
 
 func (m *serverHelloMsg) unmarshal(data []byte) bool {
-	if len(data) < 42 {
+	m.raw = data
+	reader := byteReader(data[4:])
+	if !reader.readU16(&m.vers) ||
+		!reader.readBytes(&m.random, 32) {
 		return false
 	}
-	m.raw = data
-	m.vers = uint16(data[4])<<8 | uint16(data[5])
 	vers, ok := wireToVersion(m.vers, m.isDTLS)
 	if !ok {
 		return false
 	}
-	m.random = data[6:38]
-	data = data[38:]
 	if vers < VersionTLS13 || isResumptionExperiment(m.vers) {
-		sessionIdLen := int(data[0])
-		if sessionIdLen > 32 || len(data) < 1+sessionIdLen {
+		if !reader.readU8LengthPrefixedBytes(&m.sessionId) {
 			return false
 		}
-		m.sessionId = data[1 : 1+sessionIdLen]
-		data = data[1+sessionIdLen:]
 	}
-	if len(data) < 2 {
+	if !reader.readU16(&m.cipherSuite) {
 		return false
 	}
-	m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
-	data = data[2:]
 	if vers < VersionTLS13 || isResumptionExperiment(m.vers) {
-		if len(data) < 1 {
+		if !reader.readU8(&m.compressionMethod) {
 			return false
 		}
-		m.compressionMethod = data[0]
-		data = data[1:]
 	}
 
-	if len(data) == 0 && m.vers < VersionTLS13 {
+	if len(reader) == 0 && m.vers < VersionTLS13 {
 		// Extension data is optional before TLS 1.3.
 		m.extensions = serverExtensions{}
 		m.omitExtensions = true
 		return true
 	}
-	if len(data) < 2 {
-		return false
-	}
 
-	extensionsLength := int(data[0])<<8 | int(data[1])
-	data = data[2:]
-	if len(data) != extensionsLength {
+	var extensions byteReader
+	if !reader.readU16LengthPrefixed(&extensions) || len(reader) != 0 {
 		return false
 	}
 
 	// Parse out the version from supported_versions if available.
 	if m.vers == VersionTLS12 {
-		vdata := data
-		for len(vdata) != 0 {
-			if len(vdata) < 4 {
+		extensionsCopy := extensions
+		for len(extensionsCopy) > 0 {
+			var extension uint16
+			var body byteReader
+			if !extensionsCopy.readU16(&extension) ||
+				!extensionsCopy.readU16LengthPrefixed(&body) {
 				return false
 			}
-			extension := uint16(vdata[0])<<8 | uint16(vdata[1])
-			length := int(vdata[2])<<8 | int(vdata[3])
-			vdata = vdata[4:]
-
-			if len(vdata) < length {
-				return false
-			}
-			d := vdata[:length]
-			vdata = vdata[length:]
-
 			if extension == extensionSupportedVersions {
-				if len(d) < 2 {
+				if !body.readU16(&m.vers) || len(body) != 0 {
 					return false
 				}
-				m.vers = uint16(d[0])<<8 | uint16(d[1])
 				vers, ok = wireToVersion(m.vers, m.isDTLS)
 				if !ok {
 					return false
@@ -1046,38 +1022,27 @@
 	}
 
 	if vers >= VersionTLS13 {
-		for len(data) != 0 {
-			if len(data) < 4 {
+		for len(extensions) > 0 {
+			var extension uint16
+			var body byteReader
+			if !extensions.readU16(&extension) ||
+				!extensions.readU16LengthPrefixed(&body) {
 				return false
 			}
-			extension := uint16(data[0])<<8 | uint16(data[1])
-			length := int(data[2])<<8 | int(data[3])
-			data = data[4:]
-
-			if len(data) < length {
-				return false
-			}
-			d := data[:length]
-			data = data[length:]
-
 			switch extension {
 			case extensionKeyShare:
 				m.hasKeyShare = true
-				if len(d) < 4 {
+				var group uint16
+				if !body.readU16(&group) ||
+					!body.readU16LengthPrefixedBytes(&m.keyShare.keyExchange) ||
+					len(body) != 0 {
 					return false
 				}
-				m.keyShare.group = CurveID(uint16(d[0])<<8 | uint16(d[1]))
-				keyExchLen := int(d[2])<<8 | int(d[3])
-				if keyExchLen != len(d)-4 {
-					return false
-				}
-				m.keyShare.keyExchange = make([]byte, keyExchLen)
-				copy(m.keyShare.keyExchange, d[4:])
+				m.keyShare.group = CurveID(group)
 			case extensionPreSharedKey:
-				if len(d) != 2 {
+				if !body.readU16(&m.pskIdentity) || len(body) != 0 {
 					return false
 				}
-				m.pskIdentity = uint16(d[0])<<8 | uint16(d[1])
 				m.hasPSKIdentity = true
 			case extensionSupportedVersions:
 				if !isResumptionExperiment(m.vers) {
@@ -1089,7 +1054,7 @@
 				return false
 			}
 		}
-	} else if !m.extensions.unmarshal(data, vers) {
+	} else if !m.extensions.unmarshal(extensions, vers) {
 		return false
 	}
 
@@ -1121,23 +1086,12 @@
 
 func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
 	m.raw = data
-	if len(data) < 6 {
+	reader := byteReader(data[4:])
+	var extensions byteReader
+	if !reader.readU16LengthPrefixed(&extensions) || len(reader) != 0 {
 		return false
 	}
-	if data[0] != typeEncryptedExtensions {
-		return false
-	}
-	msgLen := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
-	data = data[4:]
-	if len(data) != msgLen {
-		return false
-	}
-	extLen := int(data[0])<<8 | int(data[1])
-	data = data[2:]
-	if extLen != len(data) {
-		return false
-	}
-	return m.extensions.unmarshal(data, VersionTLS13)
+	return m.extensions.unmarshal(extensions, VersionTLS13)
 }
 
 type serverExtensions struct {
@@ -1223,8 +1177,7 @@
 		extension := extensions.addU16LengthPrefixed()
 
 		srtpProtectionProfiles := extension.addU16LengthPrefixed()
-		srtpProtectionProfiles.addU8(byte(m.srtpProtectionProfile >> 8))
-		srtpProtectionProfiles.addU8(byte(m.srtpProtectionProfile))
+		srtpProtectionProfiles.addU16(m.srtpProtectionProfile)
 		srtpMki := extension.addU8LengthPrefixed()
 		srtpMki.addBytes([]byte(m.srtpMasterKeyIdentifier))
 	}
@@ -1288,96 +1241,77 @@
 	}
 }
 
-func (m *serverExtensions) unmarshal(data []byte, version uint16) bool {
+func (m *serverExtensions) unmarshal(data byteReader, version uint16) bool {
 	// Reset all fields.
 	*m = serverExtensions{}
 
-	for len(data) != 0 {
-		if len(data) < 4 {
+	for len(data) > 0 {
+		var extension uint16
+		var body byteReader
+		if !data.readU16(&extension) ||
+			!data.readU16LengthPrefixed(&body) {
 			return false
 		}
-		extension := uint16(data[0])<<8 | uint16(data[1])
-		length := int(data[2])<<8 | int(data[3])
-		data = data[4:]
-		if len(data) < length {
-			return false
-		}
-
 		switch extension {
 		case extensionNextProtoNeg:
 			m.nextProtoNeg = true
-			d := data[:length]
-			for len(d) > 0 {
-				l := int(d[0])
-				d = d[1:]
-				if l == 0 || l > len(d) {
+			for len(body) > 0 {
+				var protocol []byte
+				if !body.readU8LengthPrefixedBytes(&protocol) {
 					return false
 				}
-				m.nextProtos = append(m.nextProtos, string(d[:l]))
-				d = d[l:]
+				m.nextProtos = append(m.nextProtos, string(protocol))
 			}
 		case extensionStatusRequest:
-			if length > 0 {
+			if len(body) != 0 {
 				return false
 			}
 			m.ocspStapling = true
 		case extensionSessionTicket:
-			if length > 0 {
+			if len(body) != 0 {
 				return false
 			}
 			m.ticketSupported = true
 		case extensionRenegotiationInfo:
-			if length < 1 || length != int(data[0])+1 {
+			if !body.readU8LengthPrefixedBytes(&m.secureRenegotiation) || len(body) != 0 {
 				return false
 			}
-			m.secureRenegotiation = data[1:length]
 		case extensionALPN:
-			d := data[:length]
-			if len(d) < 3 {
+			var protocols, protocol byteReader
+			if !body.readU16LengthPrefixed(&protocols) ||
+				len(body) != 0 ||
+				!protocols.readU8LengthPrefixed(&protocol) ||
+				len(protocols) != 0 {
 				return false
 			}
-			l := int(d[0])<<8 | int(d[1])
-			if l != len(d)-2 {
-				return false
-			}
-			d = d[2:]
-			l = int(d[0])
-			if l != len(d)-1 {
-				return false
-			}
-			d = d[1:]
-			m.alpnProtocol = string(d)
-			m.alpnProtocolEmpty = len(d) == 0
+			m.alpnProtocol = string(protocol)
+			m.alpnProtocolEmpty = len(protocol) == 0
 		case extensionChannelID:
-			if length > 0 {
+			if len(body) != 0 {
 				return false
 			}
 			m.channelIDRequested = true
 		case extensionExtendedMasterSecret:
-			if length != 0 {
+			if len(body) != 0 {
 				return false
 			}
 			m.extendedMasterSecret = true
 		case extensionUseSRTP:
-			if length < 2+2+1 {
+			var profiles, mki byteReader
+			if !body.readU16LengthPrefixed(&profiles) ||
+				!profiles.readU16(&m.srtpProtectionProfile) ||
+				len(profiles) != 0 ||
+				!body.readU8LengthPrefixed(&mki) ||
+				len(body) != 0 {
 				return false
 			}
-			if data[0] != 0 || data[1] != 2 {
-				return false
-			}
-			m.srtpProtectionProfile = uint16(data[2])<<8 | uint16(data[3])
-			d := data[4:length]
-			l := int(d[0])
-			if l != len(d)-1 {
-				return false
-			}
-			m.srtpMasterKeyIdentifier = string(d[1:])
+			m.srtpMasterKeyIdentifier = string(mki)
 		case extensionSignedCertificateTimestamp:
-			m.sctList = data[:length]
+			m.sctList = []byte(body)
 		case extensionCustom:
-			m.customExtension = string(data[:length])
+			m.customExtension = string(body)
 		case extensionServerName:
-			if length != 0 {
+			if len(body) != 0 {
 				return false
 			}
 			m.serverNameAck = true
@@ -1387,21 +1321,16 @@
 				return false
 			}
 			// http://tools.ietf.org/html/rfc4492#section-5.5.2
-			if length < 1 {
+			if !body.readU8LengthPrefixedBytes(&m.supportedPoints) || len(body) != 0 {
 				return false
 			}
-			l := int(data[0])
-			if length != l+1 {
-				return false
-			}
-			m.supportedPoints = data[1 : 1+l]
 		case extensionSupportedCurves:
 			// The server can only send supported_curves in TLS 1.3.
 			if version < VersionTLS13 {
 				return false
 			}
 		case extensionEarlyData:
-			if version < VersionTLS13 || length != 0 {
+			if version < VersionTLS13 || len(body) != 0 {
 				return false
 			}
 			m.hasEarlyData = true
@@ -1409,7 +1338,6 @@
 			// Unknown extensions are illegal from the server.
 			return false
 		}
-		data = data[length:]
 	}
 
 	return true
@@ -1490,73 +1418,56 @@
 
 func (m *helloRetryRequestMsg) unmarshal(data []byte) bool {
 	m.raw = data
-	if len(data) < 8 {
+	reader := byteReader(data[4:])
+	if !reader.readU16(&m.vers) {
 		return false
 	}
-	m.vers = uint16(data[4])<<8 | uint16(data[5])
-	data = data[6:]
 	if m.isServerHello {
-		if len(data) < 33 {
+		var random []byte
+		var compressionMethod byte
+		if !reader.readBytes(&random, 32) ||
+			!reader.readU8LengthPrefixedBytes(&m.sessionId) ||
+			!reader.readU16(&m.cipherSuite) ||
+			!reader.readU8(&compressionMethod) ||
+			compressionMethod != 0 {
 			return false
 		}
-		data = data[32:] // Random
-		sessionIdLen := int(data[0])
-		if sessionIdLen > 32 || len(data) < 1+sessionIdLen+3 {
-			return false
-		}
-		m.sessionId = data[1 : 1+sessionIdLen]
-		data = data[1+sessionIdLen:]
-		m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
-		data = data[2:]
-		data = data[1:] // Compression Method
-	} else {
-		if isDraft21(m.vers) {
-			m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
-			data = data[2:]
-		}
-	}
-	extLen := int(data[0])<<8 | int(data[1])
-	data = data[2:]
-	if len(data) != extLen || len(data) == 0 {
+	} else if isDraft21(m.vers) && !reader.readU16(&m.cipherSuite) {
 		return false
 	}
-	for len(data) > 0 {
-		if len(data) < 4 {
+	var extensions byteReader
+	if !reader.readU16LengthPrefixed(&extensions) || len(reader) != 0 {
+		return false
+	}
+	for len(extensions) > 0 {
+		var extension uint16
+		var body byteReader
+		if !extensions.readU16(&extension) ||
+			!extensions.readU16LengthPrefixed(&body) {
 			return false
 		}
-		extension := uint16(data[0])<<8 | uint16(data[1])
-		length := int(data[2])<<8 | int(data[3])
-		data = data[4:]
-		if len(data) < length {
-			return false
-		}
-
 		switch extension {
 		case extensionSupportedVersions:
-			if length != 2 || !m.isServerHello {
+			if !m.isServerHello ||
+				!body.readU16(&m.vers) ||
+				len(body) != 0 {
 				return false
 			}
-			m.vers = uint16(data[0])<<8 | uint16(data[1])
 		case extensionKeyShare:
-			if length != 2 {
+			var v uint16
+			if !body.readU16(&v) || len(body) != 0 {
 				return false
 			}
 			m.hasSelectedGroup = true
-			m.selectedGroup = CurveID(data[0])<<8 | CurveID(data[1])
+			m.selectedGroup = CurveID(v)
 		case extensionCookie:
-			if length < 2 {
+			if !body.readU16LengthPrefixedBytes(&m.cookie) || len(body) != 0 {
 				return false
 			}
-			cookieLen := int(data[0])<<8 | int(data[1])
-			if 2+cookieLen != length {
-				return false
-			}
-			m.cookie = data[2 : 2+cookieLen]
 		default:
 			// Unknown extensions are illegal from the server.
 			return false
 		}
-		data = data[length:]
 	}
 	return true
 }