Pack encrypted handshake messages together.

This does not affect TLS 1.2 (beyond Channel ID or NPN) but, in TLS 1.3,
we send several encrypted handshake messages in a row. For the server,
this means 66 wasted bytes in TLS 1.3. Since OpenSSL has otherwise used
one record per message since the beginning and unencrypted overhead is
less interesting, leave that behavior as-is for the time being. (This
isn't the most pressing use of the breakage budget.) But TLS 1.3 is new,
so get this tight from the start.

Change-Id: I64dbd590a62469d296e1f10673c14bcd0c62919a
Reviewed-on: https://boringssl-review.googlesource.com/22068
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Steven Valdez <svaldez@google.com>
diff --git a/ssl/internal.h b/ssl/internal.h
index 3953fbd..881f930 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1010,6 +1010,9 @@
 // |tls_has_unprocessed_handshake_data| for DTLS.
 bool dtls_has_unprocessed_handshake_data(const SSL *ssl);
 
+// tls_flush_pending_hs_data flushes any handshake plaintext data.
+bool tls_flush_pending_hs_data(SSL *ssl);
+
 struct DTLS_OUTGOING_MESSAGE {
   DTLS_OUTGOING_MESSAGE() {}
   DTLS_OUTGOING_MESSAGE(const DTLS_OUTGOING_MESSAGE &) = delete;
@@ -2273,6 +2276,11 @@
   // hs_buf is the buffer of handshake data to process.
   UniquePtr<BUF_MEM> hs_buf;
 
+  // pending_hs_data contains the pending handshake data that has not yet
+  // been encrypted to |pending_flight|. This allows packing the handshake into
+  // fewer records.
+  UniquePtr<BUF_MEM> pending_hs_data;
+
   // pending_flight is the pending outgoing flight. This is used to flush each
   // handshake flight in a single write. |write_buffer| must be written out
   // before this data.
diff --git a/ssl/s3_both.cc b/ssl/s3_both.cc
index ede4ba7..6b55538 100644
--- a/ssl/s3_both.cc
+++ b/ssl/s3_both.cc
@@ -134,6 +134,8 @@
 
 static bool add_record_to_flight(SSL *ssl, uint8_t type,
                                  Span<const uint8_t> in) {
+  // The caller should have flushed |pending_hs_data| first.
+  assert(!ssl->s3->pending_hs_data);
   // We'll never add a flight while in the process of writing it out.
   assert(ssl->s3->pending_flight_offset == 0);
 
@@ -182,17 +184,49 @@
 }
 
 bool ssl3_add_message(SSL *ssl, Array<uint8_t> msg) {
-  // Add the message to the current flight, splitting into several records if
-  // needed.
+  // Pack handshake data into the minimal number of records. This avoids
+  // unnecessary encryption overhead, notably in TLS 1.3 where we send several
+  // encrypted messages in a row. For now, we do not do this for the null
+  // cipher. The benefit is smaller and there is a risk of breaking buggy
+  // implementations.
+  //
+  // TODO(davidben): See if we can do this uniformly.
   Span<const uint8_t> rest = msg;
-  do {
-    Span<const uint8_t> chunk = rest.subspan(0, ssl->max_send_fragment);
-    rest = rest.subspan(chunk.size());
+  if (ssl->s3->aead_write_ctx->is_null_cipher()) {
+    while (!rest.empty()) {
+      Span<const uint8_t> chunk = rest.subspan(0, ssl->max_send_fragment);
+      rest = rest.subspan(chunk.size());
 
-    if (!add_record_to_flight(ssl, SSL3_RT_HANDSHAKE, chunk)) {
-      return false;
+      if (!add_record_to_flight(ssl, SSL3_RT_HANDSHAKE, chunk)) {
+        return false;
+      }
     }
-  } while (!rest.empty());
+  } else {
+    while (!rest.empty()) {
+      // Flush if |pending_hs_data| is full.
+      if (ssl->s3->pending_hs_data &&
+          ssl->s3->pending_hs_data->length >= ssl->max_send_fragment &&
+          !tls_flush_pending_hs_data(ssl)) {
+        return false;
+      }
+
+      size_t pending_len =
+          ssl->s3->pending_hs_data ? ssl->s3->pending_hs_data->length : 0;
+      Span<const uint8_t> chunk =
+          rest.subspan(0, ssl->max_send_fragment - pending_len);
+      assert(!chunk.empty());
+      rest = rest.subspan(chunk.size());
+
+      if (!ssl->s3->pending_hs_data) {
+        ssl->s3->pending_hs_data.reset(BUF_MEM_new());
+      }
+      if (!ssl->s3->pending_hs_data ||
+          !BUF_MEM_append(ssl->s3->pending_hs_data.get(), chunk.data(),
+                          chunk.size())) {
+        return false;
+      }
+    }
+  }
 
   ssl_do_msg_callback(ssl, 1 /* write */, SSL3_RT_HANDSHAKE, msg);
   // TODO(svaldez): Move this up a layer to fix abstraction for SSLTranscript on
@@ -204,10 +238,23 @@
   return true;
 }
 
+bool tls_flush_pending_hs_data(SSL *ssl) {
+  if (!ssl->s3->pending_hs_data || ssl->s3->pending_hs_data->length == 0) {
+    return true;
+  }
+
+  UniquePtr<BUF_MEM> pending_hs_data = std::move(ssl->s3->pending_hs_data);
+  return add_record_to_flight(
+      ssl, SSL3_RT_HANDSHAKE,
+      MakeConstSpan(reinterpret_cast<const uint8_t *>(pending_hs_data->data),
+                    pending_hs_data->length));
+}
+
 bool ssl3_add_change_cipher_spec(SSL *ssl) {
   static const uint8_t kChangeCipherSpec[1] = {SSL3_MT_CCS};
 
-  if (!add_record_to_flight(ssl, SSL3_RT_CHANGE_CIPHER_SPEC,
+  if (!tls_flush_pending_hs_data(ssl) ||
+      !add_record_to_flight(ssl, SSL3_RT_CHANGE_CIPHER_SPEC,
                             kChangeCipherSpec)) {
     return false;
   }
@@ -219,7 +266,8 @@
 
 bool ssl3_add_alert(SSL *ssl, uint8_t level, uint8_t desc) {
   uint8_t alert[2] = {level, desc};
-  if (!add_record_to_flight(ssl, SSL3_RT_ALERT, alert)) {
+  if (!tls_flush_pending_hs_data(ssl) ||
+      !add_record_to_flight(ssl, SSL3_RT_ALERT, alert)) {
     return false;
   }
 
@@ -229,6 +277,10 @@
 }
 
 int ssl3_flush_flight(SSL *ssl) {
+  if (!tls_flush_pending_hs_data(ssl)) {
+    return -1;
+  }
+
   if (ssl->s3->pending_flight == nullptr) {
     return 1;
   }
diff --git a/ssl/s3_pkt.cc b/ssl/s3_pkt.cc
index 285abb3..4e9e89a 100644
--- a/ssl/s3_pkt.cc
+++ b/ssl/s3_pkt.cc
@@ -234,6 +234,9 @@
     return 0;
   }
 
+  if (!tls_flush_pending_hs_data(ssl)) {
+    return -1;
+  }
   size_t flight_len = 0;
   if (ssl->s3->pending_flight != nullptr) {
     flight_len =
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index eee1337..169ce63 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -1416,6 +1416,11 @@
 	// length accepted from the peer.
 	MaxReceivePlaintext int
 
+	// ExpectPackedEncryptedHandshake, if non-zero, requires that the peer maximally
+	// pack their encrypted handshake messages, fitting at most the
+	// specified number of plaintext bytes per record.
+	ExpectPackedEncryptedHandshake int
+
 	// SendTicketLifetime, if non-zero, is the ticket lifetime to send in
 	// NewSessionTicket messages.
 	SendTicketLifetime time.Duration
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index 0edbe5c..24edea9 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -101,6 +101,11 @@
 	keyUpdateRequested bool
 	seenOneByteRecord  bool
 
+	// seenHandshakePackEnd is whether the most recent handshake record was
+	// not full for ExpectPackedEncryptedHandshake. If true, no more
+	// handshake data may be received until the next flight or epoch change.
+	seenHandshakePackEnd bool
+
 	tmp [16]byte
 }
 
@@ -244,14 +249,6 @@
 	hc.incEpoch()
 }
 
-func (hc *halfConn) doKeyUpdate(c *Conn, isOutgoing bool) {
-	side := serverWrite
-	if c.isClient == isOutgoing {
-		side = clientWrite
-	}
-	hc.useTrafficSecret(hc.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), hc.trafficSecret), side)
-}
-
 // incSeq increments the sequence number.
 func (hc *halfConn) incSeq(isOutgoing bool) {
 	limit := 0
@@ -737,6 +734,23 @@
 	return b, bb
 }
 
+func (c *Conn) useInTrafficSecret(version uint16, suite *cipherSuite, secret []byte) {
+	side := serverWrite
+	if !c.isClient {
+		side = clientWrite
+	}
+	c.in.useTrafficSecret(version, suite, secret, side)
+	c.seenHandshakePackEnd = false
+}
+
+func (c *Conn) useOutTrafficSecret(version uint16, suite *cipherSuite, secret []byte) {
+	side := serverWrite
+	if c.isClient {
+		side = clientWrite
+	}
+	c.out.useTrafficSecret(version, suite, secret, side)
+}
+
 func (c *Conn) doReadRecord(want recordType) (recordType, *block, error) {
 RestartReadRecord:
 	if c.isDTLS {
@@ -901,6 +915,13 @@
 		return c.in.setErrorLocked(err)
 	}
 
+	if typ != recordTypeHandshake {
+		c.seenHandshakePackEnd = false
+	} else if c.seenHandshakePackEnd {
+		c.in.freeBlock(b)
+		return c.in.setErrorLocked(errors.New("tls: peer violated ExpectPackedEncryptedHandshake"))
+	}
+
 	switch typ {
 	default:
 		c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
@@ -962,6 +983,9 @@
 			return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
 		}
 		c.hand.Write(data)
+		if pack := c.config.Bugs.ExpectPackedEncryptedHandshake; pack > 0 && len(data) < pack && c.out.cipher != nil {
+			c.seenHandshakePackEnd = true
+		}
 	}
 
 	if b != nil {
@@ -1020,6 +1044,7 @@
 // to the connection and updates the record layer state.
 // c.out.Mutex <= L.
 func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
+	c.seenHandshakePackEnd = false
 	if typ == recordTypeHandshake {
 		msgType := data[0]
 		if c.config.Bugs.SendWrongMessageType != 0 && msgType == c.config.Bugs.SendWrongMessageType {
@@ -1522,7 +1547,7 @@
 		if c.config.Bugs.RejectUnsolicitedKeyUpdate {
 			return errors.New("tls: unexpected KeyUpdate message")
 		}
-		c.in.doKeyUpdate(c, false)
+		c.useInTrafficSecret(c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.in.trafficSecret))
 		if keyUpdate.keyUpdateRequest == keyUpdateRequested {
 			c.keyUpdateRequested = true
 		}
@@ -1554,7 +1579,7 @@
 		return errors.New("tls: received invalid KeyUpdate message")
 	}
 
-	c.in.doKeyUpdate(c, false)
+	c.useInTrafficSecret(c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.in.trafficSecret))
 	return nil
 }
 
@@ -1885,7 +1910,7 @@
 	if err := c.flushHandshake(); err != nil {
 		return err
 	}
-	c.out.doKeyUpdate(c, true)
+	c.useOutTrafficSecret(c.out.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.out.trafficSecret))
 	return nil
 }
 
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index ef6a464..3656c2b 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -411,7 +411,7 @@
 		finishedHash.addEntropy(session.masterSecret)
 		finishedHash.Write(helloBytes)
 		earlyTrafficSecret := finishedHash.deriveSecret(earlyTrafficLabel)
-		c.out.useTrafficSecret(session.wireVersion, pskCipherSuite, earlyTrafficSecret, clientWrite)
+		c.useOutTrafficSecret(session.wireVersion, pskCipherSuite, earlyTrafficSecret)
 		for _, earlyData := range c.config.Bugs.SendEarlyData {
 			if _, err := c.writeRecord(recordTypeApplicationData, earlyData); err != nil {
 				return err
@@ -754,7 +754,7 @@
 	// traffic key.
 	clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(clientHandshakeTrafficLabel)
 	serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(serverHandshakeTrafficLabel)
-	c.in.useTrafficSecret(c.wireVersion, hs.suite, serverHandshakeTrafficSecret, serverWrite)
+	c.useInTrafficSecret(c.wireVersion, hs.suite, serverHandshakeTrafficSecret)
 
 	msg, err := c.readHandshake()
 	if err != nil {
@@ -888,7 +888,7 @@
 
 	// Switch to application data keys on read. In particular, any alerts
 	// from the client certificate are read over these keys.
-	c.in.useTrafficSecret(c.wireVersion, hs.suite, serverTrafficSecret, serverWrite)
+	c.useInTrafficSecret(c.wireVersion, hs.suite, serverTrafficSecret)
 
 	// If we're expecting 0.5-RTT messages from the server, read them
 	// now.
@@ -934,7 +934,7 @@
 		c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
 	}
 
-	c.out.useTrafficSecret(c.wireVersion, hs.suite, clientHandshakeTrafficSecret, clientWrite)
+	c.useOutTrafficSecret(c.wireVersion, hs.suite, clientHandshakeTrafficSecret)
 
 	if certReq != nil && !c.config.Bugs.SkipClientCertificate {
 		certMsg := &certificateMsg{
@@ -1020,7 +1020,7 @@
 	c.flushHandshake()
 
 	// Switch to application data keys.
-	c.out.useTrafficSecret(c.wireVersion, hs.suite, clientTrafficSecret, clientWrite)
+	c.useOutTrafficSecret(c.wireVersion, hs.suite, clientTrafficSecret)
 
 	c.resumptionSecret = hs.finishedHash.deriveSecret(resumptionLabel)
 	return nil
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 0ffb72c..251e91f 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -664,7 +664,7 @@
 		}
 		if encryptedExtensions.extensions.hasEarlyData {
 			earlyTrafficSecret := hs.finishedHash.deriveSecret(earlyTrafficLabel)
-			c.in.useTrafficSecret(c.wireVersion, hs.suite, earlyTrafficSecret, clientWrite)
+			c.useInTrafficSecret(c.wireVersion, hs.suite, earlyTrafficSecret)
 
 			for _, expectedMsg := range config.Bugs.ExpectEarlyData {
 				if err := c.readRecord(recordTypeApplicationData); err != nil {
@@ -761,7 +761,7 @@
 
 	// Switch to handshake traffic keys.
 	serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(serverHandshakeTrafficLabel)
-	c.out.useTrafficSecret(c.wireVersion, hs.suite, serverHandshakeTrafficSecret, serverWrite)
+	c.useOutTrafficSecret(c.wireVersion, hs.suite, serverHandshakeTrafficSecret)
 	// Derive handshake traffic read key, but don't switch yet.
 	clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(clientHandshakeTrafficLabel)
 
@@ -902,7 +902,7 @@
 
 	// Switch to application data keys on write. In particular, any alerts
 	// from the client certificate are sent over these keys.
-	c.out.useTrafficSecret(c.wireVersion, hs.suite, serverTrafficSecret, serverWrite)
+	c.useOutTrafficSecret(c.wireVersion, hs.suite, serverTrafficSecret)
 
 	// Send 0.5-RTT messages.
 	for _, halfRTTMsg := range config.Bugs.SendHalfRTTData {
@@ -928,7 +928,7 @@
 	}
 
 	// Switch input stream to handshake traffic keys.
-	c.in.useTrafficSecret(c.wireVersion, hs.suite, clientHandshakeTrafficSecret, clientWrite)
+	c.useInTrafficSecret(c.wireVersion, hs.suite, clientHandshakeTrafficSecret)
 
 	// If we requested a client certificate, then the client must send a
 	// certificate message, even if it's empty.
@@ -1032,7 +1032,7 @@
 	hs.writeClientHash(clientFinished.marshal())
 
 	// Switch to application data keys on read.
-	c.in.useTrafficSecret(c.wireVersion, hs.suite, clientTrafficSecret, clientWrite)
+	c.useInTrafficSecret(c.wireVersion, hs.suite, clientTrafficSecret)
 
 	c.cipherSuite = hs.suite
 	c.resumptionSecret = hs.finishedHash.deriveSecret(resumptionLabel)
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index fe9afee..12d85a1 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -2761,11 +2761,12 @@
 			},
 		},
 		{
-			// Test the server so there is a large certificate as
-			// well as application data.
+			// Test the TLS 1.2 server so there is a large
+			// unencrypted certificate as well as application data.
 			testType: serverTest,
-			name:     "MaxSendFragment",
+			name:     "MaxSendFragment-TLS12",
 			config: Config{
+				MaxVersion: VersionTLS12,
 				Bugs: ProtocolBugs{
 					MaxReceivePlaintext: 512,
 				},
@@ -2777,11 +2778,12 @@
 			},
 		},
 		{
-			// Test the server so there is a large certificate as
-			// well as application data.
+			// Test the TLS 1.2 server so there is a large
+			// unencrypted certificate as well as application data.
 			testType: serverTest,
-			name:     "MaxSendFragment-TooLarge",
+			name:     "MaxSendFragment-TLS12-TooLarge",
 			config: Config{
+				MaxVersion: VersionTLS12,
 				Bugs: ProtocolBugs{
 					// Ensure that some of the records are
 					// 512.
@@ -2797,6 +2799,57 @@
 			expectedLocalError: "local error: record overflow",
 		},
 		{
+			// Test the TLS 1.3 server so there is a large encrypted
+			// certificate as well as application data.
+			testType: serverTest,
+			name:     "MaxSendFragment-TLS13",
+			config: Config{
+				MaxVersion: VersionTLS13,
+				Bugs: ProtocolBugs{
+					MaxReceivePlaintext:            512,
+					ExpectPackedEncryptedHandshake: 512,
+				},
+			},
+			messageLen: 1024,
+			flags: []string{
+				"-max-send-fragment", "512",
+				"-read-size", "1024",
+			},
+		},
+		{
+			// Test the TLS 1.3 server so there is a large encrypted
+			// certificate as well as application data.
+			testType: serverTest,
+			name:     "MaxSendFragment-TLS13-TooLarge",
+			config: Config{
+				MaxVersion: VersionTLS13,
+				Bugs: ProtocolBugs{
+					// Ensure that some of the records are
+					// 512.
+					MaxReceivePlaintext: 511,
+				},
+			},
+			messageLen: 1024,
+			flags: []string{
+				"-max-send-fragment", "512",
+				"-read-size", "1024",
+			},
+			shouldFail:         true,
+			expectedLocalError: "local error: record overflow",
+		},
+		{
+			// Test that, by default, handshake data is tightly
+			// packed in TLS 1.3.
+			testType: serverTest,
+			name:     "PackedEncryptedHandshake-TLS13",
+			config: Config{
+				MaxVersion: VersionTLS13,
+				Bugs: ProtocolBugs{
+					ExpectPackedEncryptedHandshake: 16384,
+				},
+			},
+		},
+		{
 			// Test that DTLS can handle multiple application data
 			// records in a single packet.
 			protocol: dtls,
diff --git a/ssl/tls_method.cc b/ssl/tls_method.cc
index 8aeb489..2dd27fc 100644
--- a/ssl/tls_method.cc
+++ b/ssl/tls_method.cc
@@ -92,14 +92,16 @@
   }
 
   OPENSSL_memset(ssl->s3->read_sequence, 0, sizeof(ssl->s3->read_sequence));
-
   ssl->s3->aead_read_ctx = std::move(aead_ctx);
   return true;
 }
 
 static bool ssl3_set_write_state(SSL *ssl, UniquePtr<SSLAEADContext> aead_ctx) {
-  OPENSSL_memset(ssl->s3->write_sequence, 0, sizeof(ssl->s3->write_sequence));
+  if (!tls_flush_pending_hs_data(ssl)) {
+    return false;
+  }
 
+  OPENSSL_memset(ssl->s3->write_sequence, 0, sizeof(ssl->s3->write_sequence));
   ssl->s3->aead_write_ctx = std::move(aead_ctx);
   return true;
 }