Adding support for receiving early data on the server.

BUG=76

Change-Id: Ie894ea5d327f88e66b234767de437dbe5c67c41d
Reviewed-on: https://boringssl-review.googlesource.com/12960
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/dtls_method.c b/ssl/dtls_method.c
index 806212d..6084789 100644
--- a/ssl/dtls_method.c
+++ b/ssl/dtls_method.c
@@ -145,7 +145,6 @@
     dtls1_release_current_message,
     dtls1_read_app_data,
     dtls1_read_change_cipher_spec,
-    NULL,
     dtls1_read_close_notify,
     dtls1_write_app_data,
     dtls1_dispatch_alert,
diff --git a/ssl/handshake_client.c b/ssl/handshake_client.c
index 3898c1b..7eddd35 100644
--- a/ssl/handshake_client.c
+++ b/ssl/handshake_client.c
@@ -406,6 +406,7 @@
       case SSL3_ST_FALSE_START:
         hs->state = SSL3_ST_CR_SESSION_TICKET_A;
         hs->in_false_start = 1;
+        hs->can_early_write = 1;
         ret = 1;
         goto end;
 
@@ -457,13 +458,21 @@
         }
         break;
 
-      case SSL_ST_TLS13:
-        ret = tls13_handshake(hs);
+      case SSL_ST_TLS13: {
+        int early_return = 0;
+        ret = tls13_handshake(hs, &early_return);
         if (ret <= 0) {
           goto end;
         }
+
+        if (early_return) {
+          ret = 1;
+          goto end;
+        }
+
         hs->state = SSL3_ST_FINISH_CLIENT_HANDSHAKE;
         break;
+      }
 
       case SSL3_ST_FINISH_CLIENT_HANDSHAKE:
         ssl->method->release_current_message(ssl, 1 /* free_buffer */);
diff --git a/ssl/handshake_server.c b/ssl/handshake_server.c
index fd6c8e9..a1341d6 100644
--- a/ssl/handshake_server.c
+++ b/ssl/handshake_server.c
@@ -448,13 +448,21 @@
         }
         break;
 
-      case SSL_ST_TLS13:
-        ret = tls13_handshake(hs);
+      case SSL_ST_TLS13: {
+        int early_return = 0;
+        ret = tls13_handshake(hs, &early_return);
         if (ret <= 0) {
           goto end;
         }
+
+        if (early_return) {
+          ret = 1;
+          goto end;
+        }
+
         hs->state = SSL_ST_OK;
         break;
+      }
 
       case SSL_ST_OK:
         ssl->method->release_current_message(ssl, 1 /* free_buffer */);
diff --git a/ssl/internal.h b/ssl/internal.h
index 5ef4094..b405fb6 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1111,6 +1111,14 @@
   /* early_data_offered is one if the client sent the early_data extension. */
   unsigned early_data_offered:1;
 
+  /* can_early_read is one if application data may be read at this point in the
+   * handshake. */
+  unsigned can_early_read:1;
+
+  /* can_early_write is one if application data may be written at this point in
+   * the handshake. */
+  unsigned can_early_write:1;
+
   /* next_proto_neg_seen is one of NPN was negotiated. */
   unsigned next_proto_neg_seen:1;
 
@@ -1139,8 +1147,9 @@
 int ssl_check_message_type(SSL *ssl, int type);
 
 /* tls13_handshake runs the TLS 1.3 handshake. It returns one on success and <=
- * 0 on error. */
-int tls13_handshake(SSL_HANDSHAKE *hs);
+ * 0 on error. It sets |*out_early_return| to one if it returned before the
+ * handshake finished so that early data may be processed. */
+int tls13_handshake(SSL_HANDSHAKE *hs, int *out_early_return);
 
 /* The following are implementations of |do_tls13_handshake| for the client and
  * server. */
@@ -1413,7 +1422,6 @@
   int (*read_app_data)(SSL *ssl, int *out_got_handshake, uint8_t *buf, int len,
                        int peek);
   int (*read_change_cipher_spec)(SSL *ssl);
-  int (*read_end_of_early_data)(SSL *ssl);
   void (*read_close_notify)(SSL *ssl);
   int (*write_app_data)(SSL *ssl, const uint8_t *buf, int len);
   int (*dispatch_alert)(SSL *ssl);
@@ -2231,6 +2239,12 @@
  * otherwise. */
 int ssl3_can_false_start(const SSL *ssl);
 
+/* ssl_can_write returns one if |ssl| is allowed to write and zero otherwise. */
+int ssl_can_write(const SSL *ssl);
+
+/* ssl_can_read returns one if |ssl| is allowed to read and zero otherwise. */
+int ssl_can_read(const SSL *ssl);
+
 /* ssl_get_version_range sets |*out_min_version| and |*out_max_version| to the
  * minimum and maximum enabled protocol versions, respectively. */
 int ssl_get_version_range(const SSL *ssl, uint16_t *out_min_version,
diff --git a/ssl/s3_pkt.c b/ssl/s3_pkt.c
index 69696ed..42fffb1 100644
--- a/ssl/s3_pkt.c
+++ b/ssl/s3_pkt.c
@@ -189,7 +189,7 @@
 }
 
 int ssl3_write_app_data(SSL *ssl, const uint8_t *buf, int len) {
-  assert(!SSL_in_init(ssl) || SSL_in_false_start(ssl));
+  assert(ssl_can_write(ssl));
 
   unsigned tot, n, nw;
 
@@ -325,10 +325,11 @@
 
 int ssl3_read_app_data(SSL *ssl, int *out_got_handshake, uint8_t *buf, int len,
                        int peek) {
-  assert(!SSL_in_init(ssl));
-  assert(ssl->s3->initial_handshake_complete);
+  assert(ssl_can_read(ssl));
   *out_got_handshake = 0;
 
+  ssl->method->release_current_message(ssl, 0 /* don't free buffer */);
+
   SSL3_RECORD *rr = &ssl->s3->rrec;
 
   for (;;) {
@@ -345,6 +346,14 @@
     }
 
     if (has_hs_data || rr->type == SSL3_RT_HANDSHAKE) {
+      /* If reading 0-RTT data, reject handshake data. 0-RTT data is terminated
+       * by an alert. */
+      if (SSL_in_init(ssl)) {
+        OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
+        ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
+        return -1;
+      }
+
       /* Post-handshake data prior to TLS 1.3 is always renegotiation, which we
        * never accept as a server. Otherwise |ssl3_get_message| will send
        * |SSL_R_EXCESSIVE_MESSAGE_SIZE|. */
@@ -363,6 +372,20 @@
       return -1;
     }
 
+    if (rr->type == SSL3_RT_ALERT &&
+        ssl->server &&
+        ssl->s3->hs != NULL &&
+        ssl->s3->hs->can_early_read &&
+        ssl3_protocol_version(ssl) >= TLS1_3_VERSION) {
+      int ret = ssl3_read_end_of_early_data(ssl);
+      if (ret <= 0) {
+        return ret;
+      }
+      ssl->s3->hs->can_early_read = 0;
+      *out_got_handshake = 1;
+      return -1;
+    }
+
     if (rr->type != SSL3_RT_APPLICATION_DATA) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
       ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c
index 2405306..d01f6a2 100644
--- a/ssl/ssl_lib.c
+++ b/ssl/ssl_lib.c
@@ -613,6 +613,14 @@
   return SSL_do_handshake(ssl);
 }
 
+int ssl_can_write(const SSL *ssl) {
+  return !SSL_in_init(ssl) || ssl->s3->hs->can_early_write;
+}
+
+int ssl_can_read(const SSL *ssl) {
+  return !SSL_in_init(ssl) || ssl->s3->hs->can_early_read;
+}
+
 static int ssl_do_renegotiate(SSL *ssl) {
   /* We do not accept renegotiations as a server or SSL 3.0. SSL 3.0 will be
    * removed entirely in the future and requires retaining more data for
@@ -693,7 +701,7 @@
     /* Complete the current handshake, if any. False Start will cause
      * |SSL_do_handshake| to return mid-handshake, so this may require multiple
      * iterations. */
-    while (SSL_in_init(ssl)) {
+    while (!ssl_can_read(ssl)) {
       int ret = SSL_do_handshake(ssl);
       if (ret < 0) {
         return ret;
@@ -711,6 +719,12 @@
       return ret;
     }
 
+    /* If we received an interrupt in early read (the end_of_early_data alert),
+     * loop again for the handshake to process it. */
+    if (SSL_in_init(ssl)) {
+      continue;
+    }
+
     /* Handle the post-handshake message and try again. */
     if (!ssl_do_post_handshake(ssl)) {
       return -1;
@@ -741,7 +755,7 @@
   }
 
   /* If necessary, complete the handshake implicitly. */
-  if (SSL_in_init(ssl) && !SSL_in_false_start(ssl)) {
+  if (!ssl_can_write(ssl)) {
     int ret = SSL_do_handshake(ssl);
     if (ret < 0) {
       return ret;
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index f04ed1f..21497fd 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -1368,7 +1368,9 @@
     return false;
   }
 
-  bool expect_handshake_done = is_resume || !config->false_start;
+  bool expect_handshake_done =
+      (is_resume || !config->false_start) &&
+      !(config->is_server && SSL_early_data_accepted(ssl));
   if (expect_handshake_done != GetTestState(ssl)->handshake_done) {
     fprintf(stderr, "handshake was%s completed\n",
             GetTestState(ssl)->handshake_done ? "" : " not");
@@ -2005,8 +2007,9 @@
         }
 
         // After a successful read, with or without False Start, the handshake
-        // must be complete.
-        if (!GetTestState(ssl.get())->handshake_done) {
+        // must be complete unless we are doing early data.
+        if (!GetTestState(ssl.get())->handshake_done &&
+            !SSL_early_data_accepted(ssl.get())) {
           fprintf(stderr, "handshake was not completed after SSL_read\n");
           return false;
         }
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index db4bcf6..95dcbd0 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -1137,6 +1137,10 @@
 	// send after the ClientHello.
 	SendFakeEarlyDataLength int
 
+	// SendStrayEarlyHandshake, if true, causes the client to send a stray
+	// handshake record before sending the end_of_early_data alert.
+	SendStrayEarlyHandshake bool
+
 	// OmitEarlyDataExtension, if true, causes the early data extension to
 	// be omitted in the ClientHello.
 	OmitEarlyDataExtension bool
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index d3ae110..d73722c 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -392,7 +392,6 @@
 		finishedHash.Write(helloBytes)
 		earlyTrafficSecret := finishedHash.deriveSecret(earlyTrafficLabel)
 		c.out.useTrafficSecret(session.vers, pskCipherSuite, earlyTrafficSecret, clientWrite)
-
 		for _, earlyData := range c.config.Bugs.SendEarlyData {
 			if _, err := c.writeRecord(recordTypeApplicationData, earlyData); err != nil {
 				return err
@@ -892,6 +891,10 @@
 	// Send EndOfEarlyData and then switch write key to handshake
 	// traffic key.
 	if c.out.cipher != nil && !c.config.Bugs.SkipEndOfEarlyData {
+		if c.config.Bugs.SendStrayEarlyHandshake {
+			helloRequest := new(helloRequestMsg)
+			c.writeRecord(recordTypeHandshake, helloRequest.marshal())
+		}
 		c.sendAlert(alertEndOfEarlyData)
 	}
 	c.out.useTrafficSecret(c.vers, hs.suite, clientHandshakeTrafficSecret, clientWrite)
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index e19df1a..3aa2c46 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -3584,13 +3584,16 @@
 				MaxVersion: VersionTLS13,
 				MinVersion: VersionTLS13,
 				Bugs: ProtocolBugs{
-					SendEarlyData:           [][]byte{},
+					SendEarlyData:           [][]byte{{1, 2, 3, 4}},
 					ExpectEarlyDataAccepted: true,
+					ExpectHalfRTTData:       [][]byte{{254, 253, 252, 251}},
 				},
 			},
+			messageCount:  2,
 			resumeSession: true,
 			flags: []string{
 				"-enable-early-data",
+				"-expect-accept-early-data",
 			},
 		})
 	}
@@ -9982,39 +9985,6 @@
 		},
 	})
 
-	// Test that we accept data-less early data.
-	testCases = append(testCases, testCase{
-		testType: serverTest,
-		name:     "TLS13-DataLessEarlyData-Server",
-		config: Config{
-			MaxVersion: VersionTLS13,
-			Bugs: ProtocolBugs{
-				SendEarlyData:           [][]byte{},
-				ExpectEarlyDataAccepted: true,
-			},
-		},
-		resumeSession: true,
-		flags: []string{
-			"-enable-early-data",
-			"-expect-accept-early-data",
-		},
-	})
-
-	testCases = append(testCases, testCase{
-		testType: clientTest,
-		name:     "TLS13-DataLessEarlyData-Client",
-		config: Config{
-			MaxVersion:       VersionTLS13,
-			MaxEarlyDataSize: 16384,
-		},
-		resumeSession: true,
-		flags: []string{
-			"-enable-early-data",
-			"-expect-early-data-info",
-			"-expect-accept-early-data",
-		},
-	})
-
 	testCases = append(testCases, testCase{
 		testType: clientTest,
 		name:     "TLS13-DataLessEarlyData-Reject-Client",
@@ -10254,7 +10224,7 @@
 		resumeConfig: &Config{
 			MaxVersion: VersionTLS13,
 			Bugs: ProtocolBugs{
-				SendEarlyData:           [][]byte{{}},
+				SendEarlyData:           [][]byte{{1, 2, 3, 4}},
 				ExpectEarlyDataAccepted: false,
 			},
 		},
@@ -10278,7 +10248,7 @@
 			MaxVersion: VersionTLS13,
 			NextProtos: []string{"foo"},
 			Bugs: ProtocolBugs{
-				SendEarlyData:           [][]byte{{}},
+				SendEarlyData:           [][]byte{{1, 2, 3, 4}},
 				ExpectEarlyDataAccepted: false,
 			},
 		},
@@ -10303,7 +10273,7 @@
 			MaxVersion: VersionTLS13,
 			NextProtos: []string{},
 			Bugs: ProtocolBugs{
-				SendEarlyData:           [][]byte{{}},
+				SendEarlyData:           [][]byte{{1, 2, 3, 4}},
 				ExpectEarlyDataAccepted: false,
 			},
 		},
@@ -10327,7 +10297,7 @@
 			MaxVersion: VersionTLS13,
 			NextProtos: []string{"bar"},
 			Bugs: ProtocolBugs{
-				SendEarlyData:           [][]byte{{}},
+				SendEarlyData:           [][]byte{{1, 2, 3, 4}},
 				ExpectEarlyDataAccepted: false,
 			},
 		},
@@ -10388,7 +10358,7 @@
 		config: Config{
 			MaxVersion: VersionTLS13,
 			Bugs: ProtocolBugs{
-				SendEarlyData:           [][]byte{},
+				SendEarlyData:           [][]byte{{1, 2, 3, 4}},
 				ExpectEarlyDataAccepted: true,
 				SkipEndOfEarlyData:      true,
 			},
@@ -10399,6 +10369,28 @@
 		expectedLocalError: "remote error: bad record MAC",
 		expectedError:      ":BAD_DECRYPT:",
 	})
+
+	testCases = append(testCases, testCase{
+		testType: serverTest,
+		name:     "TLS13-EarlyData-UnexpectedHandshake-Server",
+		config: Config{
+			MaxVersion: VersionTLS13,
+		},
+		resumeConfig: &Config{
+			MaxVersion: VersionTLS13,
+			Bugs: ProtocolBugs{
+				SendEarlyData:           [][]byte{{1, 2, 3, 4}},
+				SendStrayEarlyHandshake: true,
+				ExpectEarlyDataAccepted: true},
+		},
+		resumeSession:      true,
+		shouldFail:         true,
+		expectedError:      ":UNEXPECTED_RECORD:",
+		expectedLocalError: "remote error: unexpected message",
+		flags: []string{
+			"-enable-early-data",
+		},
+	})
 }
 
 func addTLS13CipherPreferenceTests() {
diff --git a/ssl/tls13_both.c b/ssl/tls13_both.c
index e334a6c..ec67cdc 100644
--- a/ssl/tls13_both.c
+++ b/ssl/tls13_both.c
@@ -33,7 +33,7 @@
  * without being able to return application data. */
 static const uint8_t kMaxKeyUpdates = 32;
 
-int tls13_handshake(SSL_HANDSHAKE *hs) {
+int tls13_handshake(SSL_HANDSHAKE *hs, int *out_early_return) {
   SSL *const ssl = hs->ssl;
   for (;;) {
     /* Resolve the operation the handshake was waiting on. */
@@ -65,10 +65,12 @@
       }
 
       case ssl_hs_read_end_of_early_data: {
-        int ret = ssl->method->read_end_of_early_data(ssl);
-        if (ret <= 0) {
-          return ret;
+        if (ssl->s3->hs->can_early_read) {
+          /* While we are processing early data, the handshake returns early. */
+          *out_early_return = 1;
+          return 1;
         }
+        hs->wait = ssl_hs_ok;
         break;
       }
 
diff --git a/ssl/tls13_server.c b/ssl/tls13_server.c
index dbb44d2..35ee4f7 100644
--- a/ssl/tls13_server.c
+++ b/ssl/tls13_server.c
@@ -673,6 +673,8 @@
                                hs->hash_len)) {
       return ssl_hs_error;
     }
+    hs->can_early_write = 1;
+    hs->can_early_read = 1;
     hs->tls13_state = state_process_end_of_early_data;
     return ssl_hs_read_end_of_early_data;
   }
diff --git a/ssl/tls_method.c b/ssl/tls_method.c
index 2af4f2c..6144f86 100644
--- a/ssl/tls_method.c
+++ b/ssl/tls_method.c
@@ -141,7 +141,6 @@
     ssl3_release_current_message,
     ssl3_read_app_data,
     ssl3_read_change_cipher_spec,
-    ssl3_read_end_of_early_data,
     ssl3_read_close_notify,
     ssl3_write_app_data,
     ssl3_dispatch_alert,