Fix handling of ServerHellos with omitted extensions.

Due to SSL 3.0 legacy, TLS 1.0 through 1.2 allow ClientHello and
ServerHello messages to omit the extensions field altogether, rather
than write an empty field. We broke this in
https://boringssl-review.googlesource.com/c/17704/ when we needed to a
second ServerHello parsing path.

Fix this and add some regression tests to explicitly test both the
omitted and empty extensions ClientHello and ServerHello cases.

Bug: chromium:743218
Change-Id: I8297ba608570238e19f12ea44a9fe2fe9d881d28
Reviewed-on: https://boringssl-review.googlesource.com/17904
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc
index 48fe052..dfb9c92 100644
--- a/ssl/handshake_client.cc
+++ b/ssl/handshake_client.cc
@@ -806,12 +806,80 @@
   return 1;
 }
 
+static int parse_server_version(SSL_HANDSHAKE *hs, uint16_t *out) {
+  SSL *const ssl = hs->ssl;
+  if (ssl->s3->tmp.message_type != SSL3_MT_SERVER_HELLO &&
+      ssl->s3->tmp.message_type != SSL3_MT_HELLO_RETRY_REQUEST) {
+    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
+    OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
+    return 0;
+  }
+
+  CBS server_hello;
+  CBS_init(&server_hello, ssl->init_msg, ssl->init_num);
+  if (!CBS_get_u16(&server_hello, out)) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
+    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
+    return 0;
+  }
+
+  /* The server version may also be in the supported_versions extension if
+   * applicable. */
+  if (ssl->s3->tmp.message_type != SSL3_MT_SERVER_HELLO ||
+      *out != TLS1_2_VERSION) {
+    return 1;
+  }
+
+  uint8_t sid_length;
+  if (!CBS_skip(&server_hello, SSL3_RANDOM_SIZE) ||
+      !CBS_get_u8(&server_hello, &sid_length) ||
+      !CBS_skip(&server_hello, sid_length + 2 /* cipher_suite */ +
+                1 /* compression_method */)) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
+    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
+    return 0;
+  }
+
+  /* The extensions block may not be present. */
+  if (CBS_len(&server_hello) == 0) {
+    return 1;
+  }
+
+  CBS extensions;
+  if (!CBS_get_u16_length_prefixed(&server_hello, &extensions) ||
+      CBS_len(&server_hello) != 0) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
+    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
+    return 0;
+  }
+
+  int have_supported_versions;
+  CBS supported_versions;
+  const SSL_EXTENSION_TYPE ext_types[] = {
+    {TLSEXT_TYPE_supported_versions, &have_supported_versions,
+     &supported_versions},
+  };
+
+  uint8_t alert = SSL_AD_DECODE_ERROR;
+  if (!ssl_parse_extensions(&extensions, &alert, ext_types,
+                            OPENSSL_ARRAY_SIZE(ext_types),
+                            1 /* ignore unknown */)) {
+    ssl3_send_alert(ssl, SSL3_AL_FATAL, alert);
+    return 0;
+  }
+
+  if (have_supported_versions &&
+      (!CBS_get_u16(&supported_versions, out) ||
+       CBS_len(&supported_versions) != 0)) {
+    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
+    return 0;
+  }
+
+  return 1;
+}
+
 static int ssl3_get_server_hello(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  CBS server_hello, server_random, session_id;
-  uint16_t server_version, cipher_suite;
-  uint8_t compression_method;
-
   int ret = ssl->method->ssl_get_message(ssl);
   if (ret <= 0) {
     uint32_t err = ERR_peek_error();
@@ -828,62 +896,11 @@
     return ret;
   }
 
-  if (ssl->s3->tmp.message_type != SSL3_MT_SERVER_HELLO &&
-      ssl->s3->tmp.message_type != SSL3_MT_HELLO_RETRY_REQUEST) {
-    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
-    OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
+  uint16_t server_version;
+  if (!parse_server_version(hs, &server_version)) {
     return -1;
   }
 
-  CBS_init(&server_hello, ssl->init_msg, ssl->init_num);
-
-  if (!CBS_get_u16(&server_hello, &server_version)) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
-    ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
-    return -1;
-  }
-
-  /* Parse out server version from supported_versions if available. */
-  if (ssl->s3->tmp.message_type == SSL3_MT_SERVER_HELLO &&
-      server_version == TLS1_2_VERSION) {
-    CBS copy = server_hello;
-    CBS extensions;
-    uint8_t sid_length;
-    if (!CBS_skip(&copy, SSL3_RANDOM_SIZE) ||
-        !CBS_get_u8(&copy, &sid_length) ||
-        !CBS_skip(&copy, sid_length + 2 /* cipher_suite */ +
-                             1 /* compression_method */) ||
-        !CBS_get_u16_length_prefixed(&copy, &extensions) ||
-        CBS_len(&copy) != 0) {
-      OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
-      ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
-      return -1;
-    }
-
-    int have_supported_versions;
-    CBS supported_versions;
-    const SSL_EXTENSION_TYPE ext_types[] = {
-        {TLSEXT_TYPE_supported_versions, &have_supported_versions,
-         &supported_versions},
-    };
-
-    uint8_t alert = SSL_AD_DECODE_ERROR;
-    if (!ssl_parse_extensions(&extensions, &alert, ext_types,
-                              OPENSSL_ARRAY_SIZE(ext_types),
-                              1 /* ignore unknown */)) {
-      ssl3_send_alert(ssl, SSL3_AL_FATAL, alert);
-      return -1;
-    }
-
-    if (have_supported_versions) {
-      if (!CBS_get_u16(&supported_versions, &server_version) ||
-          CBS_len(&supported_versions) != 0) {
-        ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
-        return -1;
-      }
-    }
-  }
-
   if (!ssl_supports_version(hs, server_version)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_UNSUPPORTED_PROTOCOL);
     ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_PROTOCOL_VERSION);
@@ -920,7 +937,12 @@
     return -1;
   }
 
-  if (!CBS_get_bytes(&server_hello, &server_random, SSL3_RANDOM_SIZE) ||
+  CBS server_hello, server_random, session_id;
+  uint16_t cipher_suite;
+  uint8_t compression_method;
+  CBS_init(&server_hello, ssl->init_msg, ssl->init_num);
+  if (!CBS_skip(&server_hello, 2 /* version */) ||
+      !CBS_get_bytes(&server_hello, &server_random, SSL3_RANDOM_SIZE) ||
       !CBS_get_u8_length_prefixed(&server_hello, &session_id) ||
       CBS_len(&session_id) > SSL3_SESSION_ID_SIZE ||
       !CBS_get_u16(&server_hello, &cipher_suite) ||
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index a3c744c..fd9fb3d 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -1384,6 +1384,14 @@
 	// RejectUnsolicitedKeyUpdate, if true, causes all unsolicited
 	// KeyUpdates from the peer to be rejected.
 	RejectUnsolicitedKeyUpdate bool
+
+	// OmitExtensions, if true, causes the extensions field in ClientHello
+	// and ServerHello messages to be omitted.
+	OmitExtensions bool
+
+	// EmptyExtensions, if true, causese the extensions field in ClientHello
+	// and ServerHello messages to be present, but empty.
+	EmptyExtensions bool
 }
 
 func (c *Config) serverInit() {
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index cac1ebf..05e7311 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -84,7 +84,6 @@
 		sctListSupported:        !c.config.Bugs.NoSignedCertificateTimestamps,
 		serverName:              c.config.ServerName,
 		supportedCurves:         c.config.curvePreferences(),
-		pskKEModes:              []byte{pskDHEKEMode},
 		supportedPoints:         []uint8{pointFormatUncompressed},
 		nextProtoNeg:            len(c.config.NextProtos) > 0,
 		secureRenegotiation:     []byte{},
@@ -97,6 +96,8 @@
 		srtpMasterKeyIdentifier: c.config.Bugs.SRTPMasterKeyIdentifer,
 		customExtension:         c.config.Bugs.CustomExtension,
 		pskBinderFirst:          c.config.Bugs.PSKBinderFirst,
+		omitExtensions:          c.config.Bugs.OmitExtensions,
+		emptyExtensions:         c.config.Bugs.EmptyExtensions,
 	}
 
 	if maxVersion >= VersionTLS13 {
@@ -104,6 +105,7 @@
 		if !c.config.Bugs.OmitSupportedVersions {
 			hello.supportedVersions = c.config.supportedVersions(c.isDTLS)
 		}
+		hello.pskKEModes = []byte{pskDHEKEMode}
 	} else {
 		hello.vers = mapClientHelloVersion(maxVersion, c.isDTLS)
 	}
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index 1e36f08..4be873d 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -7,6 +7,7 @@
 import (
 	"bytes"
 	"encoding/binary"
+	"fmt"
 )
 
 func writeLen(buf []byte, v, size int) {
@@ -35,6 +36,11 @@
 	return len(*bb.buf) - bb.start - bb.prefixLen
 }
 
+func (bb *byteBuilder) data() []byte {
+	bb.flush()
+	return (*bb.buf)[bb.start+bb.prefixLen:]
+}
+
 func (bb *byteBuilder) flush() {
 	if bb.child == nil {
 		return
@@ -112,11 +118,11 @@
 }
 
 func (bb *byteBuilder) discardChild() {
-	if bb.child != nil {
+	if bb.child == nil {
 		return
 	}
+	*bb.buf = (*bb.buf)[:bb.child.start]
 	bb.child = nil
-	*bb.buf = (*bb.buf)[:bb.start]
 }
 
 type keyShareEntry struct {
@@ -167,6 +173,8 @@
 	customExtension         string
 	hasGREASEExtension      bool
 	pskBinderFirst          bool
+	omitExtensions          bool
+	emptyExtensions         bool
 }
 
 func (m *clientHelloMsg) equal(i interface{}) bool {
@@ -212,7 +220,9 @@
 		m.sctListSupported == m1.sctListSupported &&
 		m.customExtension == m1.customExtension &&
 		m.hasGREASEExtension == m1.hasGREASEExtension &&
-		m.pskBinderFirst == m1.pskBinderFirst
+		m.pskBinderFirst == m1.pskBinderFirst &&
+		m.omitExtensions == m1.omitExtensions &&
+		m.emptyExtensions == m1.emptyExtensions
 }
 
 func (m *clientHelloMsg) marshal() []byte {
@@ -444,8 +454,12 @@
 		}
 	}
 
-	if extensions.len() == 0 {
+	if m.omitExtensions || m.emptyExtensions {
+		// Silently erase any extensions which were sent.
 		hello.discardChild()
+		if m.emptyExtensions {
+			hello.addU16(0)
+		}
 	}
 
 	m.raw = handshakeMsg.finish()
@@ -828,6 +842,8 @@
 	compressionMethod     uint8
 	customExtension       string
 	unencryptedALPN       string
+	omitExtensions        bool
+	emptyExtensions       bool
 	extensions            serverExtensions
 }
 
@@ -904,8 +920,17 @@
 		}
 	} else {
 		m.extensions.marshal(extensions)
-		if extensions.len() == 0 {
+		if m.omitExtensions || m.emptyExtensions {
+			// Silently erasing server extensions will break the handshake. Instead,
+			// assert that tests which use this field also disable all features which
+			// would write an extension.
+			if extensions.len() != 0 {
+				panic(fmt.Sprintf("ServerHello unexpectedly contained extensions: %x, %+v", extensions.data(), m))
+			}
 			hello.discardChild()
+			if m.emptyExtensions {
+				hello.addU16(0)
+			}
 		}
 	}
 
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 35d005e..b31a562 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -1068,6 +1068,8 @@
 		extensions: serverExtensions{
 			supportedVersion: config.Bugs.SendServerSupportedExtensionVersion,
 		},
+		omitExtensions:  config.Bugs.OmitExtensions,
+		emptyExtensions: config.Bugs.EmptyExtensions,
 	}
 
 	hs.hello.random = make([]byte, 32)
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index 5c0906a..29747db 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -12074,6 +12074,75 @@
 	})
 }
 
+// Test that omitted and empty extensions blocks are tolerated.
+func addOmitExtensionsTests() {
+	for _, ver := range tlsVersions {
+		if ver.version > VersionTLS12 {
+			continue
+		}
+
+		testCases = append(testCases, testCase{
+			testType: serverTest,
+			name:     "OmitExtensions-ClientHello-" + ver.name,
+			config: Config{
+				MinVersion:             ver.version,
+				MaxVersion:             ver.version,
+				SessionTicketsDisabled: true,
+				Bugs: ProtocolBugs{
+					OmitExtensions: true,
+				},
+			},
+		})
+
+		testCases = append(testCases, testCase{
+			testType: serverTest,
+			name:     "EmptyExtensions-ClientHello-" + ver.name,
+			config: Config{
+				MinVersion:             ver.version,
+				MaxVersion:             ver.version,
+				SessionTicketsDisabled: true,
+				Bugs: ProtocolBugs{
+					EmptyExtensions: true,
+				},
+			},
+		})
+
+		testCases = append(testCases, testCase{
+			testType: clientTest,
+			name:     "OmitExtensions-ServerHello-" + ver.name,
+			config: Config{
+				MinVersion:             ver.version,
+				MaxVersion:             ver.version,
+				SessionTicketsDisabled: true,
+				Bugs: ProtocolBugs{
+					OmitExtensions: true,
+					// Disable all ServerHello extensions so
+					// OmitExtensions works.
+					NoExtendedMasterSecret: true,
+					NoRenegotiationInfo:    true,
+				},
+			},
+		})
+
+		testCases = append(testCases, testCase{
+			testType: clientTest,
+			name:     "EmptyExtensions-ServerHello-" + ver.name,
+			config: Config{
+				MinVersion:             ver.version,
+				MaxVersion:             ver.version,
+				SessionTicketsDisabled: true,
+				Bugs: ProtocolBugs{
+					EmptyExtensions: true,
+					// Disable all ServerHello extensions so
+					// EmptyExtensions works.
+					NoExtendedMasterSecret: true,
+					NoRenegotiationInfo:    true,
+				},
+			},
+		})
+	}
+}
+
 func worker(statusChan chan statusMsg, c chan *testCase, shimPath string, wg *sync.WaitGroup) {
 	defer wg.Done()
 
@@ -12201,6 +12270,7 @@
 	addRetainOnlySHA256ClientCertTests()
 	addECDSAKeyUsageTests()
 	addExtraHandshakeTests()
+	addOmitExtensionsTests()
 
 	var wg sync.WaitGroup