Adding support for sending early data on the client.

BUG=76

Change-Id: If58a73da38e46549fd55f84a9104e2dfebfda43f
Reviewed-on: https://boringssl-review.googlesource.com/14164
Reviewed-by: Steven Valdez <svaldez@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: Steven Valdez <svaldez@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 55e53da..462e4ce 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -512,6 +512,12 @@
  * See also |SSL_CTX_set_ticket_aead_method|. */
 #define SSL_ERROR_PENDING_TICKET 14
 
+/* SSL_ERROR_EARLY_DATA_REJECTED indicates that early data was rejected. The
+ * caller should treat this as a connection failure and retry any operations
+ * associated with the rejected early data. |SSL_reset_early_data_reject| may be
+ * used to reuse the underlying connection for the retry. */
+#define SSL_ERROR_EARLY_DATA_REJECTED 15
+
 /* SSL_set_mtu sets the |ssl|'s MTU in DTLS to |mtu|. It returns one on success
  * and zero on failure. */
 OPENSSL_EXPORT int SSL_set_mtu(SSL *ssl, unsigned mtu);
@@ -2833,6 +2839,86 @@
 OPENSSL_EXPORT const char *SSL_get_psk_identity(const SSL *ssl);
 
 
+/* Early data.
+ *
+ * WARNING: 0-RTT support in BoringSSL is currently experimental and not fully
+ * implemented. It may cause interoperability or security failures when used.
+ *
+ * Early data, or 0-RTT, is a feature in TLS 1.3 which allows clients to send
+ * data on the first flight during a resumption handshake. This can save a
+ * round-trip in some application protocols.
+ *
+ * WARNING: A 0-RTT handshake has different security properties from normal
+ * handshake, so it is off by default unless opted in. In particular, early data
+ * is replayable by a network attacker. Callers must account for this when
+ * sending or processing data before the handshake is confirmed. See
+ * draft-ietf-tls-tls13-18 for more information.
+ *
+ * As a server, if early data is accepted, |SSL_do_handshake| will complete as
+ * soon as the ClientHello is processed and server flight sent. |SSL_write| may
+ * be used to send half-RTT data. |SSL_read| will consume early data and
+ * transition to 1-RTT data as appropriate. Prior to the transition,
+ * |SSL_in_init| will report the handshake is still in progress. Callers may use
+ * it or |SSL_in_early_data| to defer or reject requests as needed.
+ *
+ * Early data as a client is more complex. If the offered session (see
+ * |SSL_set_session|) is 0-RTT-capable, the handshake will return after sending
+ * the ClientHello. The predicted peer certificate and ALPN protocol will be
+ * available via the usual APIs. |SSL_write| will write early data, up to the
+ * session's limit. Writes past this limit and |SSL_read| will complete the
+ * handshake before continuing. Callers may also call |SSL_do_handshake| again
+ * to complete the handshake sooner.
+ *
+ * If the server accepts early data, the handshake will succeed. |SSL_read| and
+ * |SSL_write| will then act as in a 1-RTT handshake. The peer certificate and
+ * ALPN protocol will be as predicted and need not be re-queried.
+ *
+ * If the server rejects early data, |SSL_do_handshake| (and thus |SSL_read| and
+ * |SSL_write|) will then fail with |SSL_get_error| returning
+ * |SSL_ERROR_EARLY_DATA_REJECTED|. The caller should treat this as a connection
+ * error and most likely perform a high-level retry. Note the server may still
+ * have processed the early data due to attacker replays.
+ *
+ * To then continue the handshake on the original connection, use
+ * |SSL_reset_early_data_reject|. This allows a faster retry than making a fresh
+ * connection. |SSL_do_handshake| will the complete the full handshake as in a
+ * fresh connection. Once reset, the peer certificate, ALPN protocol, and other
+ * properties may change so the caller must query them again.
+ *
+ * Finally, to implement the fallback described in draft-ietf-tls-tls13-18
+ * appendix C.3, retry on a fresh connection without 0-RTT if the handshake
+ * fails with |SSL_R_WRONG_VERSION_ON_EARLY_DATA|. */
+
+/* SSL_CTX_set_early_data_enabled sets whether early data is allowed to be used
+ * with resumptions using |ctx|. */
+OPENSSL_EXPORT void SSL_CTX_set_early_data_enabled(SSL_CTX *ctx, int enabled);
+
+/* SSL_set_early_data_enabled sets whether early data is allowed to be used
+ * with resumptions using |ssl|. See |SSL_CTX_set_early_data_enabled| for more
+ * information. */
+OPENSSL_EXPORT void SSL_set_early_data_enabled(SSL *ssl, int enabled);
+
+/* SSL_in_early_data returns one if |ssl| has a pending handshake that has
+ * progressed enough to send or receive early data. Clients may call |SSL_write|
+ * to send early data, but |SSL_read| will complete the handshake before
+ * accepting application data. Servers may call |SSL_read| to read early data
+ * and |SSL_write| to send half-RTT data. */
+OPENSSL_EXPORT int SSL_in_early_data(const SSL *ssl);
+
+/* SSL_early_data_accepted returns whether early data was accepted on the
+ * handshake performed by |ssl|. */
+OPENSSL_EXPORT int SSL_early_data_accepted(const SSL *ssl);
+
+/* SSL_reset_early_data_reject resets |ssl| after an early data reject. All
+ * 0-RTT state is discarded, including any pending |SSL_write| calls. The caller
+ * should treat |ssl| as a logically fresh connection, usually by driving the
+ * handshake to completion using |SSL_do_handshake|.
+ *
+ * It is an error to call this function on an |SSL| object that is not signaling
+ * |SSL_ERROR_EARLY_DATA_REJECTED|. */
+OPENSSL_EXPORT void SSL_reset_early_data_reject(SSL *ssl);
+
+
 /* Alerts.
  *
  * TLS and SSL 3.0 use alerts to signal error conditions. Alerts have a type
@@ -3061,32 +3147,6 @@
  * performed by |ssl|. This includes the pending renegotiation, if any. */
 OPENSSL_EXPORT int SSL_total_renegotiations(const SSL *ssl);
 
-/* SSL_CTX_set_early_data_enabled sets whether early data is allowed to be used
- * with resumptions using |ctx|.
- *
- * As a server, if the client's early data is accepted, |SSL_do_handshake| will
- * complete as soon as the ClientHello is processed and server flight sent.
- * |SSL_write| may be used to send half-RTT data. |SSL_read| will consume early
- * data and transition to 1-RTT data as appropriate.
- *
- * Note early data is replayable by a network attacker. |SSL_in_init| and
- * |SSL_is_init_finished| will report the handshake is still in progress until
- * the client's Finished message is received. Callers may use these functions
- * to defer some processing if desired.
- *
- * WARNING: This is experimental and may cause interoperability failures until
- * fully implemented. */
-OPENSSL_EXPORT void SSL_CTX_set_early_data_enabled(SSL_CTX *ctx, int enabled);
-
-/* SSL_set_early_data_enabled sets whether early data is allowed to be used
- * with resumptions using |ssl|. See |SSL_CTX_set_early_data_enabled| for more
- * information. */
-OPENSSL_EXPORT void SSL_set_early_data_enabled(SSL *ssl, int enabled);
-
-/* SSL_early_data_accepted returns whether early data was accepted on the
- * handshake performed by |ssl|. */
-OPENSSL_EXPORT int SSL_early_data_accepted(const SSL *ssl);
-
 /* SSL_MAX_CERT_LIST_DEFAULT is the default maximum length, in bytes, of a peer
  * certificate chain. */
 #define SSL_MAX_CERT_LIST_DEFAULT (1024 * 100)
@@ -3676,6 +3736,7 @@
 #define SSL_CERTIFICATE_SELECTION_PENDING 8
 #define SSL_PRIVATE_KEY_OPERATION 9
 #define SSL_PENDING_TICKET 10
+#define SSL_EARLY_DATA_REJECTED 11
 
 /* SSL_want returns one of the above values to determine what the most recent
  * operation on |ssl| was blocked on. Use |SSL_get_error| instead. */
diff --git a/include/openssl/ssl3.h b/include/openssl/ssl3.h
index b5f5d6a..25c31c1 100644
--- a/include/openssl/ssl3.h
+++ b/include/openssl/ssl3.h
@@ -302,6 +302,7 @@
 #define SSL3_ST_FALSE_START (0x101 | SSL_ST_CONNECT)
 #define SSL3_ST_VERIFY_SERVER_CERT (0x102 | SSL_ST_CONNECT)
 #define SSL3_ST_FINISH_CLIENT_HANDSHAKE (0x103 | SSL_ST_CONNECT)
+#define SSL3_ST_WRITE_EARLY_DATA (0x104 | SSL_ST_CONNECT)
 /* write to server */
 #define SSL3_ST_CW_CLNT_HELLO_A (0x110 | SSL_ST_CONNECT)
 /* read from server */
diff --git a/ssl/d1_pkt.c b/ssl/d1_pkt.c
index 3444825..e2c7315 100644
--- a/ssl/d1_pkt.c
+++ b/ssl/d1_pkt.c
@@ -335,8 +335,10 @@
   }
 }
 
-int dtls1_write_app_data(SSL *ssl, const uint8_t *buf, int len) {
+int dtls1_write_app_data(SSL *ssl, int *out_needs_handshake, const uint8_t *buf,
+                         int len) {
   assert(!SSL_in_init(ssl));
+  *out_needs_handshake = 0;
 
   if (len > SSL3_RT_MAX_PLAIN_LENGTH) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_DTLS_MESSAGE_TOO_BIG);
diff --git a/ssl/handshake_client.c b/ssl/handshake_client.c
index f204286..857cdf3 100644
--- a/ssl/handshake_client.c
+++ b/ssl/handshake_client.c
@@ -218,8 +218,10 @@
               ret = -1;
               goto end;
             }
+            hs->next_state = SSL3_ST_WRITE_EARLY_DATA;
+          } else {
+            hs->next_state = SSL3_ST_CR_SRVR_HELLO_A;
           }
-          hs->next_state = SSL3_ST_CR_SRVR_HELLO_A;
         } else {
           hs->next_state = DTLS1_ST_CR_HELLO_VERIFY_REQUEST_A;
         }
@@ -240,6 +242,18 @@
         }
         break;
 
+      case SSL3_ST_WRITE_EARLY_DATA:
+        /* Stash the early data session, so connection properties may be queried
+         * out of it. */
+        hs->in_early_data = 1;
+        hs->early_session = ssl->session;
+        SSL_SESSION_up_ref(ssl->session);
+
+        hs->state = SSL3_ST_CR_SRVR_HELLO_A;
+        hs->can_early_write = 1;
+        ret = 1;
+        goto end;
+
       case SSL3_ST_CR_SRVR_HELLO_A:
         ret = ssl3_get_server_hello(hs);
         if (hs->state == SSL_ST_TLS13) {
diff --git a/ssl/internal.h b/ssl/internal.h
index bf0ef02..450b812 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -925,6 +925,7 @@
   ssl_hs_channel_id_lookup,
   ssl_hs_private_key_operation,
   ssl_hs_pending_ticket,
+  ssl_hs_early_data_rejected,
   ssl_hs_read_end_of_early_data,
 };
 
@@ -1057,6 +1058,10 @@
    * handshake. It should not be cached. */
   SSL_SESSION *new_session;
 
+  /* early_session is the session corresponding to the current 0-RTT state on
+   * the client if |in_early_data| is true. */
+  SSL_SESSION *early_session;
+
   /* new_cipher is the cipher being negotiated in this handshake. */
   const SSL_CIPHER *new_cipher;
 
@@ -1097,6 +1102,10 @@
    * Start. The client may write data at this point. */
   unsigned in_false_start:1;
 
+  /* in_early_data is one if there is a pending handshake that has progressed
+   * enough to send and receive early data. */
+  unsigned in_early_data:1;
+
   /* early_data_offered is one if the client sent the early_data extension. */
   unsigned early_data_offered:1;
 
@@ -1128,6 +1137,10 @@
   /* early_data_read is the amount of early data that has been read by the
    * record layer. */
   uint16_t early_data_read;
+
+  /* early_data_written is the amount of early data that has been written by the
+   * record layer. */
+  uint16_t early_data_written;
 } /* SSL_HANDSHAKE */;
 
 SSL_HANDSHAKE *ssl_handshake_new(SSL *ssl);
@@ -1421,7 +1434,8 @@
                        int peek);
   int (*read_change_cipher_spec)(SSL *ssl);
   void (*read_close_notify)(SSL *ssl);
-  int (*write_app_data)(SSL *ssl, const uint8_t *buf, int len);
+  int (*write_app_data)(SSL *ssl, int *out_needs_handshake, const uint8_t *buf,
+                        int len);
   int (*dispatch_alert)(SSL *ssl);
   /* supports_cipher returns one if |cipher| is supported by this protocol and
    * zero otherwise. */
@@ -1632,6 +1646,9 @@
    * outstanding. */
   unsigned key_update_pending:1;
 
+  /* wpend_pending is one if we have a pending write outstanding. */
+  unsigned wpend_pending:1;
+
   uint8_t send_alert[2];
 
   /* pending_flight is the pending outgoing flight. This is used to flush each
@@ -2088,7 +2105,8 @@
 int ssl3_read_change_cipher_spec(SSL *ssl);
 void ssl3_read_close_notify(SSL *ssl);
 int ssl3_read_handshake_bytes(SSL *ssl, uint8_t *buf, int len);
-int ssl3_write_app_data(SSL *ssl, const uint8_t *buf, int len);
+int ssl3_write_app_data(SSL *ssl, int *out_needs_handshake, const uint8_t *buf,
+                        int len);
 int ssl3_output_cert_chain(SSL *ssl);
 
 int ssl3_new(SSL *ssl);
@@ -2129,7 +2147,8 @@
 int dtls1_read_change_cipher_spec(SSL *ssl);
 void dtls1_read_close_notify(SSL *ssl);
 
-int dtls1_write_app_data(SSL *ssl, const uint8_t *buf, int len);
+int dtls1_write_app_data(SSL *ssl, int *out_needs_handshake, const uint8_t *buf,
+                         int len);
 
 /* dtls1_write_record sends a record. It returns one on success and <= 0 on
  * error. */
diff --git a/ssl/s3_both.c b/ssl/s3_both.c
index e05a16e..1b6c3c7 100644
--- a/ssl/s3_both.c
+++ b/ssl/s3_both.c
@@ -168,6 +168,7 @@
   OPENSSL_free(hs->key_share_bytes);
   OPENSSL_free(hs->ecdh_public_key);
   SSL_SESSION_free(hs->new_session);
+  SSL_SESSION_free(hs->early_session);
   OPENSSL_free(hs->peer_sigalgs);
   OPENSSL_free(hs->peer_supported_group_list);
   OPENSSL_free(hs->peer_key);
@@ -331,12 +332,14 @@
     return -1;
   }
 
-  /* The handshake flight buffer is mutually exclusive with application data.
-   *
-   * TODO(davidben): This will not be true when closure alerts use this. */
+  /* If there is pending data in the write buffer, it must be flushed out before
+   * any new data in pending_flight. */
   if (ssl_write_buffer_is_pending(ssl)) {
-    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-    return -1;
+    int ret = ssl_write_buffer_flush(ssl);
+    if (ret <= 0) {
+      ssl->rwstate = SSL_WRITING;
+      return ret;
+    }
   }
 
   /* Write the pending flight. */
diff --git a/ssl/s3_pkt.c b/ssl/s3_pkt.c
index 23b39f2..b624135 100644
--- a/ssl/s3_pkt.c
+++ b/ssl/s3_pkt.c
@@ -188,10 +188,13 @@
   return -1;
 }
 
-int ssl3_write_app_data(SSL *ssl, const uint8_t *buf, int len) {
+int ssl3_write_app_data(SSL *ssl, int *out_needs_handshake, const uint8_t *buf,
+                        int len) {
   assert(ssl_can_write(ssl));
   assert(ssl->s3->aead_write_ctx != NULL);
 
+  *out_needs_handshake = 0;
+
   unsigned tot, n, nw;
 
   assert(ssl->s3->wnum <= INT_MAX);
@@ -210,11 +213,25 @@
     return -1;
   }
 
+  const int is_early_data_write =
+      !ssl->server && SSL_in_early_data(ssl) && ssl->s3->hs->can_early_write;
+
   n = len - tot;
   for (;;) {
     /* max contains the maximum number of bytes that we can put into a
      * record. */
     unsigned max = ssl->max_send_fragment;
+    if (is_early_data_write && max > ssl->session->ticket_max_early_data -
+                                         ssl->s3->hs->early_data_written) {
+      max = ssl->session->ticket_max_early_data - ssl->s3->hs->early_data_written;
+      if (max == 0) {
+        ssl->s3->wnum = tot;
+        ssl->s3->hs->can_early_write = 0;
+        *out_needs_handshake = 1;
+        return -1;
+      }
+    }
+
     if (n > max) {
       nw = max;
     } else {
@@ -227,6 +244,10 @@
       return ret;
     }
 
+    if (is_early_data_write) {
+      ssl->s3->hs->early_data_written += ret;
+    }
+
     if (ret == (int)n || (ssl->mode & SSL_MODE_ENABLE_PARTIAL_WRITE)) {
       return tot + ret;
     }
@@ -250,13 +271,14 @@
   if (ret <= 0) {
     return ret;
   }
+  ssl->s3->wpend_pending = 0;
   return ssl->s3->wpend_ret;
 }
 
 /* do_ssl3_write writes an SSL record of the given type. */
 static int do_ssl3_write(SSL *ssl, int type, const uint8_t *buf, unsigned len) {
   /* If there is still data from the previous record, flush it. */
-  if (ssl_write_buffer_is_pending(ssl)) {
+  if (ssl->s3->wpend_pending) {
     return ssl3_write_pending(ssl, type, buf, len);
   }
 
@@ -317,6 +339,7 @@
   ssl->s3->wpend_buf = buf;
   ssl->s3->wpend_type = type;
   ssl->s3->wpend_ret = len;
+  ssl->s3->wpend_pending = 1;
 
   /* we now just need to write the buffer */
   return ssl3_write_pending(ssl, type, buf, len);
diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c
index b3e397d..d936633 100644
--- a/ssl/ssl_lib.c
+++ b/ssl/ssl_lib.c
@@ -757,19 +757,23 @@
     return -1;
   }
 
-  /* If necessary, complete the handshake implicitly. */
-  if (!ssl_can_write(ssl)) {
-    int ret = SSL_do_handshake(ssl);
-    if (ret < 0) {
-      return ret;
+  int ret = 0, needs_handshake = 0;
+  do {
+    /* If necessary, complete the handshake implicitly. */
+    if (!ssl_can_write(ssl)) {
+      ret = SSL_do_handshake(ssl);
+      if (ret < 0) {
+        return ret;
+      }
+      if (ret == 0) {
+        OPENSSL_PUT_ERROR(SSL, SSL_R_SSL_HANDSHAKE_FAILURE);
+        return -1;
+      }
     }
-    if (ret == 0) {
-      OPENSSL_PUT_ERROR(SSL, SSL_R_SSL_HANDSHAKE_FAILURE);
-      return -1;
-    }
-  }
 
-  return ssl->method->write_app_data(ssl, buf, num);
+    ret = ssl->method->write_app_data(ssl, &needs_handshake, buf, num);
+  } while (needs_handshake);
+  return ret;
 }
 
 int SSL_shutdown(SSL *ssl) {
@@ -842,10 +846,35 @@
   ssl->cert->enable_early_data = !!enabled;
 }
 
+int SSL_in_early_data(const SSL *ssl) {
+  if (ssl->s3->hs == NULL) {
+    return 0;
+  }
+  return ssl->s3->hs->in_early_data;
+}
+
 int SSL_early_data_accepted(const SSL *ssl) {
   return ssl->early_data_accepted;
 }
 
+void SSL_reset_early_data_reject(SSL *ssl) {
+  SSL_HANDSHAKE *hs = ssl->s3->hs;
+  if (hs == NULL ||
+      hs->wait != ssl_hs_early_data_rejected) {
+    abort();
+  }
+
+  hs->wait = ssl_hs_ok;
+  hs->in_early_data = 0;
+  SSL_SESSION_free(hs->early_session);
+  hs->early_session = NULL;
+
+  /* Discard any unfinished writes from the perspective of |SSL_write|'s
+   * retry. The handshake will transparently flush out the pending record
+   * (discarded by the server) to keep the framing correct. */
+  ssl->s3->wpend_pending = 0;
+}
+
 static int bio_retry_reason_to_error(int reason) {
   switch (reason) {
     case BIO_RR_CONNECT:
@@ -938,6 +967,9 @@
 
     case SSL_PENDING_TICKET:
       return SSL_ERROR_PENDING_TICKET;
+
+    case SSL_EARLY_DATA_REJECTED:
+      return SSL_ERROR_EARLY_DATA_REJECTED;
   }
 
   return SSL_ERROR_SYSCALL;
@@ -1737,13 +1769,11 @@
 
 void SSL_get0_alpn_selected(const SSL *ssl, const uint8_t **out_data,
                             unsigned *out_len) {
-  *out_data = NULL;
-  if (ssl->s3) {
-    *out_data = ssl->s3->alpn_selected;
-  }
-  if (*out_data == NULL) {
-    *out_len = 0;
+  if (SSL_in_early_data(ssl) && !ssl->server) {
+    *out_data = ssl->s3->hs->early_session->early_alpn;
+    *out_len = ssl->s3->hs->early_session->early_alpn_len;
   } else {
+    *out_data = ssl->s3->alpn_selected;
     *out_len = ssl->s3->alpn_selected_len;
   }
 }
@@ -1934,7 +1964,7 @@
 }
 
 int SSL_session_reused(const SSL *ssl) {
-  return ssl->s3->session_reused;
+  return ssl->s3->session_reused || SSL_in_early_data(ssl);
 }
 
 const COMP_METHOD *SSL_get_current_compression(SSL *ssl) { return NULL; }
diff --git a/ssl/ssl_session.c b/ssl/ssl_session.c
index b105cd0..f025364 100644
--- a/ssl/ssl_session.c
+++ b/ssl/ssl_session.c
@@ -467,8 +467,12 @@
   if (!SSL_in_init(ssl)) {
     return ssl->s3->established_session;
   }
-  if (ssl->s3->hs->new_session != NULL) {
-    return ssl->s3->hs->new_session;
+  SSL_HANDSHAKE *hs = ssl->s3->hs;
+  if (hs->early_session != NULL) {
+    return hs->early_session;
+  }
+  if (hs->new_session != NULL) {
+    return hs->new_session;
   }
   return ssl->session;
 }
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 80465ce..2bf6eeb 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -1945,7 +1945,7 @@
 }
 
 static bssl::UniquePtr<SSL_SESSION> CreateClientSession(SSL_CTX *client_ctx,
-                                             SSL_CTX *server_ctx) {
+                                                        SSL_CTX *server_ctx) {
   g_last_session = nullptr;
   SSL_CTX_sess_set_new_cb(client_ctx, SaveLastSession);
 
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index f9b6553..e088aa9 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -1361,24 +1361,23 @@
 // CheckHandshakeProperties checks, immediately after |ssl| completes its
 // initial handshake (or False Starts), whether all the properties are
 // consistent with the test configuration and invariants.
-static bool CheckHandshakeProperties(SSL *ssl, bool is_resume) {
-  const TestConfig *config = GetTestConfig(ssl);
-
+static bool CheckHandshakeProperties(SSL *ssl, bool is_resume,
+                                     const TestConfig *config) {
   if (SSL_get_current_cipher(ssl) == nullptr) {
     fprintf(stderr, "null cipher after handshake\n");
     return false;
   }
 
-  if (is_resume &&
-      (!!SSL_session_reused(ssl) == config->expect_session_miss)) {
-    fprintf(stderr, "session was%s reused\n",
+  bool expect_resume =
+      is_resume && (!config->expect_session_miss || SSL_in_early_data(ssl));
+  if (!!SSL_session_reused(ssl) != expect_resume) {
+    fprintf(stderr, "session unexpectedly was%s reused\n",
             SSL_session_reused(ssl) ? "" : " not");
     return false;
   }
 
   bool expect_handshake_done =
-      (is_resume || !config->false_start) &&
-      !(config->is_server && SSL_early_data_accepted(ssl));
+      (is_resume || !config->false_start) && !SSL_in_early_data(ssl);
   if (expect_handshake_done != GetTestState(ssl)->handshake_done) {
     fprintf(stderr, "handshake was%s completed\n",
             GetTestState(ssl)->handshake_done ? "" : " not");
@@ -1552,7 +1551,7 @@
     return false;
   }
 
-  if (is_resume) {
+  if (is_resume && !SSL_in_early_data(ssl)) {
     if ((config->expect_accept_early_data && !SSL_early_data_accepted(ssl)) ||
         (config->expect_reject_early_data && SSL_early_data_accepted(ssl))) {
       fprintf(stderr,
@@ -1643,13 +1642,17 @@
   return true;
 }
 
-// DoExchange runs a test SSL exchange against the peer. On success, it returns
+static bool DoExchange(bssl::UniquePtr<SSL_SESSION> *out_session, SSL *ssl,
+                       const TestConfig *config, bool is_resume, bool is_retry);
+
+// DoConnection tests an SSL connection against the peer. On success, it returns
 // true and sets |*out_session| to the negotiated SSL session. If the test is a
 // resumption attempt, |is_resume| is true and |session| is the session from the
 // previous exchange.
-static bool DoExchange(bssl::UniquePtr<SSL_SESSION> *out_session,
-                       SSL_CTX *ssl_ctx, const TestConfig *config,
-                       bool is_resume, SSL_SESSION *session) {
+static bool DoConnection(bssl::UniquePtr<SSL_SESSION> *out_session,
+                         SSL_CTX *ssl_ctx, const TestConfig *config,
+                         const TestConfig *retry_config, bool is_resume,
+                         SSL_SESSION *session) {
   bssl::UniquePtr<SSL> ssl(SSL_new(ssl_ctx));
   if (!ssl) {
     return false;
@@ -1865,20 +1868,62 @@
     SSL_set_connect_state(ssl.get());
   }
 
+  bool ret = DoExchange(out_session, ssl.get(), config, is_resume, false);
+  if (!config->is_server && is_resume && config->expect_reject_early_data) {
+    // We must have failed due to an early data rejection.
+    if (ret) {
+      fprintf(stderr, "0-RTT exchange unexpected succeeded.\n");
+      return false;
+    }
+    if (SSL_get_error(ssl.get(), -1) != SSL_ERROR_EARLY_DATA_REJECTED) {
+      fprintf(stderr,
+              "SSL_get_error did not signal SSL_ERROR_EARLY_DATA_REJECTED.\n");
+      return false;
+    }
+
+    // Before reseting, early state should still be available.
+    if (!SSL_in_early_data(ssl.get()) ||
+        !CheckHandshakeProperties(ssl.get(), is_resume, config)) {
+      fprintf(stderr, "SSL_in_early_data returned false before reset.\n");
+      return false;
+    }
+
+    // Reset the connection and try again at 1-RTT.
+    SSL_reset_early_data_reject(ssl.get());
+
+    // After reseting, the socket should report it is no longer in an early data
+    // state.
+    if (SSL_in_early_data(ssl.get())) {
+      fprintf(stderr, "SSL_in_early_data returned true after reset.\n");
+      return false;
+    }
+
+    if (!SetTestConfig(ssl.get(), retry_config)) {
+      return false;
+    }
+
+    ret = DoExchange(out_session, ssl.get(), retry_config, is_resume, true);
+  }
+  return ret;
+}
+
+static bool DoExchange(bssl::UniquePtr<SSL_SESSION> *out_session, SSL *ssl,
+                       const TestConfig *config, bool is_resume,
+                       bool is_retry) {
   int ret;
   if (!config->implicit_handshake) {
     do {
-      ret = SSL_do_handshake(ssl.get());
-    } while (config->async && RetryAsync(ssl.get(), ret));
+      ret = SSL_do_handshake(ssl);
+    } while (config->async && RetryAsync(ssl, ret));
     if (ret != 1 ||
-        !CheckHandshakeProperties(ssl.get(), is_resume)) {
+        !CheckHandshakeProperties(ssl, is_resume, config)) {
       return false;
     }
 
     if (config->handshake_twice) {
       do {
-        ret = SSL_do_handshake(ssl.get());
-      } while (config->async && RetryAsync(ssl.get(), ret));
+        ret = SSL_do_handshake(ssl);
+      } while (config->async && RetryAsync(ssl, ret));
       if (ret != 1) {
         return false;
       }
@@ -1886,28 +1931,28 @@
 
     // Skip the |config->async| logic as this should be a no-op.
     if (config->no_op_extra_handshake &&
-        SSL_do_handshake(ssl.get()) != 1) {
+        SSL_do_handshake(ssl) != 1) {
       fprintf(stderr, "Extra SSL_do_handshake was not a no-op.\n");
       return false;
     }
 
     // Reset the state to assert later that the callback isn't called in
     // renegotations.
-    GetTestState(ssl.get())->got_new_session = false;
+    GetTestState(ssl)->got_new_session = false;
   }
 
   if (config->export_keying_material > 0) {
     std::vector<uint8_t> result(
         static_cast<size_t>(config->export_keying_material));
     if (!SSL_export_keying_material(
-            ssl.get(), result.data(), result.size(),
-            config->export_label.data(), config->export_label.size(),
-            reinterpret_cast<const uint8_t*>(config->export_context.data()),
+            ssl, result.data(), result.size(), config->export_label.data(),
+            config->export_label.size(),
+            reinterpret_cast<const uint8_t *>(config->export_context.data()),
             config->export_context.size(), config->use_export_context)) {
       fprintf(stderr, "failed to export keying material\n");
       return false;
     }
-    if (WriteAll(ssl.get(), result.data(), result.size()) < 0) {
+    if (WriteAll(ssl, result.data(), result.size()) < 0) {
       return false;
     }
   }
@@ -1915,7 +1960,7 @@
   if (config->tls_unique) {
     uint8_t tls_unique[16];
     size_t tls_unique_len;
-    if (!SSL_get_tls_unique(ssl.get(), tls_unique, &tls_unique_len,
+    if (!SSL_get_tls_unique(ssl, tls_unique, &tls_unique_len,
                             sizeof(tls_unique))) {
       fprintf(stderr, "failed to get tls-unique\n");
       return false;
@@ -1927,13 +1972,13 @@
       return false;
     }
 
-    if (WriteAll(ssl.get(), tls_unique, tls_unique_len) < 0) {
+    if (WriteAll(ssl, tls_unique, tls_unique_len) < 0) {
       return false;
     }
   }
 
   if (config->send_alert) {
-    if (DoSendFatalAlert(ssl.get(), SSL_AD_DECOMPRESSION_FAILURE) < 0) {
+    if (DoSendFatalAlert(ssl, SSL_AD_DECOMPRESSION_FAILURE) < 0) {
       return false;
     }
     return true;
@@ -1957,7 +2002,7 @@
         fprintf(stderr, "Bad kRecordSizes value.\n");
         return false;
       }
-      if (WriteAll(ssl.get(), buf.get(), len) < 0) {
+      if (WriteAll(ssl, buf.get(), len) < 0) {
         return false;
       }
     }
@@ -1970,15 +2015,17 @@
         return false;
       }
 
+      // Let only one byte of the record through.
+      AsyncBioAllowWrite(GetTestState(ssl)->async_bio, 1);
       int write_ret =
-          SSL_write(ssl.get(), kInitialWrite, strlen(kInitialWrite));
-      if (SSL_get_error(ssl.get(), write_ret) != SSL_ERROR_WANT_WRITE) {
+          SSL_write(ssl, kInitialWrite, strlen(kInitialWrite));
+      if (SSL_get_error(ssl, write_ret) != SSL_ERROR_WANT_WRITE) {
         fprintf(stderr, "Failed to leave unfinished write.\n");
         return false;
       }
       pending_initial_write = true;
     } else if (config->shim_writes_first) {
-      if (WriteAll(ssl.get(), kInitialWrite, strlen(kInitialWrite)) < 0) {
+      if (WriteAll(ssl, kInitialWrite, strlen(kInitialWrite)) < 0) {
         return false;
       }
     }
@@ -1992,8 +2039,8 @@
         }
         std::unique_ptr<uint8_t[]> buf(new uint8_t[read_size]);
 
-        int n = DoRead(ssl.get(), buf.get(), read_size);
-        int err = SSL_get_error(ssl.get(), n);
+        int n = DoRead(ssl, buf.get(), read_size);
+        int err = SSL_get_error(ssl, n);
         if (err == SSL_ERROR_ZERO_RETURN ||
             (n == 0 && err == SSL_ERROR_SYSCALL)) {
           if (n != 0) {
@@ -2015,17 +2062,24 @@
           return false;
         }
 
+        if (!config->is_server && is_resume && !is_retry &&
+            config->expect_reject_early_data) {
+          fprintf(stderr,
+                  "Unexpectedly received data instead of 0-RTT reject.\n");
+          return false;
+        }
+
         // After a successful read, with or without False Start, the handshake
         // must be complete unless we are doing early data.
-        if (!GetTestState(ssl.get())->handshake_done &&
-            !SSL_early_data_accepted(ssl.get())) {
+        if (!GetTestState(ssl)->handshake_done &&
+            !SSL_early_data_accepted(ssl)) {
           fprintf(stderr, "handshake was not completed after SSL_read\n");
           return false;
         }
 
         // Clear the initial write, if unfinished.
         if (pending_initial_write) {
-          if (WriteAll(ssl.get(), kInitialWrite, strlen(kInitialWrite)) < 0) {
+          if (WriteAll(ssl, kInitialWrite, strlen(kInitialWrite)) < 0) {
             return false;
           }
           pending_initial_write = false;
@@ -2034,7 +2088,7 @@
         for (int i = 0; i < n; i++) {
           buf[i] ^= 0xff;
         }
-        if (WriteAll(ssl.get(), buf.get(), n) < 0) {
+        if (WriteAll(ssl, buf.get(), n) < 0) {
           return false;
         }
       }
@@ -2044,25 +2098,25 @@
   if (!config->is_server && !config->false_start &&
       !config->implicit_handshake &&
       // Session tickets are sent post-handshake in TLS 1.3.
-      GetProtocolVersion(ssl.get()) < TLS1_3_VERSION &&
-      GetTestState(ssl.get())->got_new_session) {
+      GetProtocolVersion(ssl) < TLS1_3_VERSION &&
+      GetTestState(ssl)->got_new_session) {
     fprintf(stderr, "new session was established after the handshake\n");
     return false;
   }
 
-  if (GetProtocolVersion(ssl.get()) >= TLS1_3_VERSION && !config->is_server) {
+  if (GetProtocolVersion(ssl) >= TLS1_3_VERSION && !config->is_server) {
     bool expect_new_session =
         !config->expect_no_session && !config->shim_shuts_down;
-    if (expect_new_session != GetTestState(ssl.get())->got_new_session) {
+    if (expect_new_session != GetTestState(ssl)->got_new_session) {
       fprintf(stderr,
               "new session was%s cached, but we expected the opposite\n",
-              GetTestState(ssl.get())->got_new_session ? "" : " not");
+              GetTestState(ssl)->got_new_session ? "" : " not");
       return false;
     }
 
     if (expect_new_session) {
       bool got_early_data_info =
-          GetTestState(ssl.get())->new_session->ticket_max_early_data != 0;
+          GetTestState(ssl)->new_session->ticket_max_early_data != 0;
       if (config->expect_early_data_info != got_early_data_info) {
         fprintf(
             stderr,
@@ -2075,10 +2129,10 @@
   }
 
   if (out_session) {
-    *out_session = std::move(GetTestState(ssl.get())->new_session);
+    *out_session = std::move(GetTestState(ssl)->new_session);
   }
 
-  ret = DoShutdown(ssl.get());
+  ret = DoShutdown(ssl);
 
   if (config->shim_shuts_down && config->check_close_notify) {
     // We initiate shutdown, so |SSL_shutdown| will return in two stages. First
@@ -2088,7 +2142,7 @@
       fprintf(stderr, "Unexpected SSL_shutdown result: %d != 0\n", ret);
       return false;
     }
-    ret = DoShutdown(ssl.get());
+    ret = DoShutdown(ssl);
   }
 
   if (ret != 1) {
@@ -2096,11 +2150,9 @@
     return false;
   }
 
-  if (SSL_total_renegotiations(ssl.get()) !=
-      config->expect_total_renegotiations) {
+  if (SSL_total_renegotiations(ssl) != config->expect_total_renegotiations) {
     fprintf(stderr, "Expected %d renegotiations, got %d\n",
-            config->expect_total_renegotiations,
-            SSL_total_renegotiations(ssl.get()));
+            config->expect_total_renegotiations, SSL_total_renegotiations(ssl));
     return false;
   }
 
@@ -2141,9 +2193,9 @@
     return 1;
   }
 
-  TestConfig initial_config, resume_config;
-  if (!ParseConfig(argc - 1, argv + 1, false, &initial_config) ||
-      !ParseConfig(argc - 1, argv + 1, true, &resume_config)) {
+  TestConfig initial_config, resume_config, retry_config;
+  if (!ParseConfig(argc - 1, argv + 1, &initial_config, &resume_config,
+                   &retry_config)) {
     return Usage(argv[0]);
   }
 
@@ -2172,8 +2224,8 @@
     }
 
     bssl::UniquePtr<SSL_SESSION> offer_session = std::move(session);
-    if (!DoExchange(&session, ssl_ctx.get(), config, is_resume,
-                    offer_session.get())) {
+    if (!DoConnection(&session, ssl_ctx.get(), config, &retry_config, is_resume,
+                      offer_session.get())) {
       fprintf(stderr, "Connection %d failed.\n", i + 1);
       ERR_print_errors_fp(stderr);
       return 1;
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index 052f879..0a6648f 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -1175,6 +1175,14 @@
 	// the number of records or their content do not match.
 	ExpectEarlyData [][]byte
 
+	// ExpectLateEarlyData causes a TLS 1.3 server to read application
+	// data after the ServerFinished (assuming the server is able to
+	// derive the key under which the data is encrypted) before it
+	// sends the ClientFinished. It checks that the application data it
+	// reads matches what is provided in ExpectLateEarlyData and errors if
+	// the number of records or their content do not match.
+	ExpectLateEarlyData [][]byte
+
 	// SendHalfRTTData causes a TLS 1.3 server to send the provided
 	// data in application data records before reading the client's
 	// Finished message.
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index 0eb64e7..fce0049 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -857,17 +857,8 @@
 			c.sendAlert(alertInternalError)
 			return c.in.setErrorLocked(errors.New("tls: ChangeCipherSpec requested after handshake complete"))
 		}
-	case recordTypeApplicationData:
-		if !c.handshakeComplete && !c.config.Bugs.ExpectFalseStart && len(c.config.Bugs.ExpectHalfRTTData) == 0 && len(c.config.Bugs.ExpectEarlyData) == 0 {
-			c.sendAlert(alertInternalError)
-			return c.in.setErrorLocked(errors.New("tls: application data record requested before handshake complete"))
-		}
-	case recordTypeAlert, recordTypeHandshake:
-		// Looking for a close_notify or handshake message. Note: unlike
-		// a real implementation, this is not tolerant of additional
-		// records. See the documentation for ExpectCloseNotify.
-		// Post-handshake requests for handshake messages are allowed if
-		// the caller used ReadKeyUpdateACK.
+	case recordTypeApplicationData, recordTypeAlert, recordTypeHandshake:
+		break
 	}
 
 Again:
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 8dc0446..3a182ec 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -680,7 +680,6 @@
 				}
 				c.in.freeBlock(c.input)
 				c.input = nil
-
 			}
 		} else {
 			c.skipEarlyData = true
@@ -880,6 +879,19 @@
 	}
 	c.flushHandshake()
 
+	if encryptedExtensions.extensions.hasEarlyData && !c.skipEarlyData {
+		for _, expectedMsg := range config.Bugs.ExpectLateEarlyData {
+			if err := c.readRecord(recordTypeApplicationData); err != nil {
+				return err
+			}
+			if !bytes.Equal(c.input.data[c.input.off:], expectedMsg) {
+				return errors.New("ExpectLateEarlyData: did not get expected message")
+			}
+			c.in.freeBlock(c.input)
+			c.input = nil
+		}
+	}
+
 	// The various secrets do not incorporate the client's final leg, so
 	// derive them now before updating the handshake context.
 	hs.finishedHash.addEntropy(hs.finishedHash.zeroSecret())
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index f714c77..62e0056 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -431,6 +431,11 @@
 	// expectPeerCertificate, if not nil, is the certificate chain the peer
 	// is expected to send.
 	expectPeerCertificate *Certificate
+	// shimPrefix is the prefix that the shim will send to the server.
+	shimPrefix string
+	// resumeShimPrefix is the prefix that the shim will send to the server on a
+	// resumption.
+	resumeShimPrefix string
 }
 
 var testCases []testCase
@@ -688,20 +693,26 @@
 		tlsConn.SendHalfHelloRequest()
 	}
 
-	shimPrefixPending := test.shimWritesFirst || test.readWithUnfinishedWrite
+	shimPrefix := test.shimPrefix
+	if isResume {
+		shimPrefix = test.resumeShimPrefix
+	}
+	if test.shimWritesFirst || test.readWithUnfinishedWrite {
+		shimPrefix = "hello"
+	}
 	if test.renegotiate > 0 {
 		// If readWithUnfinishedWrite is set, the shim prefix will be
 		// available later.
-		if shimPrefixPending && !test.readWithUnfinishedWrite {
-			var buf [5]byte
-			_, err := io.ReadFull(tlsConn, buf[:])
+		if shimPrefix != "" && !test.readWithUnfinishedWrite {
+			var buf = make([]byte, len(shimPrefix))
+			_, err := io.ReadFull(tlsConn, buf)
 			if err != nil {
 				return err
 			}
-			if string(buf[:]) != "hello" {
-				return fmt.Errorf("bad initial message")
+			if string(buf) != shimPrefix {
+				return fmt.Errorf("bad initial message %v vs %v", string(buf), shimPrefix)
 			}
-			shimPrefixPending = false
+			shimPrefix = ""
 		}
 
 		if test.renegotiateCiphers != nil {
@@ -764,16 +775,16 @@
 		tlsConn.Write(testMessage)
 
 		// Consume the shim prefix if needed.
-		if shimPrefixPending {
-			var buf [5]byte
-			_, err := io.ReadFull(tlsConn, buf[:])
+		if shimPrefix != "" {
+			var buf = make([]byte, len(shimPrefix))
+			_, err := io.ReadFull(tlsConn, buf)
 			if err != nil {
 				return err
 			}
-			if string(buf[:]) != "hello" {
-				return fmt.Errorf("bad initial message")
+			if string(buf) != shimPrefix {
+				return fmt.Errorf("bad initial message %v vs %v", string(buf), shimPrefix)
 			}
-			shimPrefixPending = false
+			shimPrefix = ""
 		}
 
 		if test.shimShutsDown || test.expectMessageDropped {
@@ -3639,7 +3650,6 @@
 			resumeSession: true,
 		})
 
-		// TODO(svaldez): Send data on early data once implemented.
 		tests = append(tests, testCase{
 			testType: clientTest,
 			name:     "TLS13-EarlyData-Client",
@@ -3648,15 +3658,110 @@
 				MinVersion:       VersionTLS13,
 				MaxEarlyDataSize: 16384,
 			},
+			resumeConfig: &Config{
+				MaxVersion:       VersionTLS13,
+				MinVersion:       VersionTLS13,
+				MaxEarlyDataSize: 16384,
+				Bugs: ProtocolBugs{
+					ExpectEarlyData: [][]byte{{'h', 'e', 'l', 'l', 'o'}},
+				},
+			},
 			resumeSession: true,
 			flags: []string{
 				"-enable-early-data",
 				"-expect-early-data-info",
 				"-expect-accept-early-data",
+				"-on-resume-shim-writes-first",
 			},
 		})
 
 		tests = append(tests, testCase{
+			testType: clientTest,
+			name:     "TLS13-EarlyData-TooMuchData-Client",
+			config: Config{
+				MaxVersion:       VersionTLS13,
+				MinVersion:       VersionTLS13,
+				MaxEarlyDataSize: 2,
+			},
+			resumeConfig: &Config{
+				MaxVersion:       VersionTLS13,
+				MinVersion:       VersionTLS13,
+				MaxEarlyDataSize: 2,
+				Bugs: ProtocolBugs{
+					ExpectEarlyData: [][]byte{{'h', 'e'}},
+				},
+			},
+			resumeShimPrefix: "llo",
+			resumeSession:    true,
+			flags: []string{
+				"-enable-early-data",
+				"-expect-early-data-info",
+				"-expect-accept-early-data",
+				"-on-resume-shim-writes-first",
+			},
+		})
+
+		// Unfinished writes can only be tested when operations are async. EarlyData
+		// can't be tested as part of an ImplicitHandshake in this case since
+		// otherwise the early data will be sent as normal data.
+		if config.async && !config.implicitHandshake {
+			tests = append(tests, testCase{
+				testType: clientTest,
+				name:     "TLS13-EarlyData-UnfinishedWrite-Client",
+				config: Config{
+					MaxVersion:       VersionTLS13,
+					MinVersion:       VersionTLS13,
+					MaxEarlyDataSize: 16384,
+				},
+				resumeConfig: &Config{
+					MaxVersion:       VersionTLS13,
+					MinVersion:       VersionTLS13,
+					MaxEarlyDataSize: 16384,
+					Bugs: ProtocolBugs{
+						ExpectLateEarlyData: [][]byte{{'h', 'e', 'l', 'l', 'o'}},
+					},
+				},
+				resumeSession: true,
+				flags: []string{
+					"-enable-early-data",
+					"-expect-early-data-info",
+					"-expect-accept-early-data",
+					"-on-resume-read-with-unfinished-write",
+					"-on-resume-shim-writes-first",
+				},
+			})
+
+			// Rejected unfinished writes are discarded (from the
+			// perspective of the calling application) on 0-RTT
+			// reject.
+			tests = append(tests, testCase{
+				testType: clientTest,
+				name:     "TLS13-EarlyData-RejectUnfinishedWrite-Client",
+				config: Config{
+					MaxVersion:       VersionTLS13,
+					MinVersion:       VersionTLS13,
+					MaxEarlyDataSize: 16384,
+				},
+				resumeConfig: &Config{
+					MaxVersion:       VersionTLS13,
+					MinVersion:       VersionTLS13,
+					MaxEarlyDataSize: 16384,
+					Bugs: ProtocolBugs{
+						AlwaysRejectEarlyData: true,
+					},
+				},
+				resumeSession: true,
+				flags: []string{
+					"-enable-early-data",
+					"-expect-early-data-info",
+					"-expect-reject-early-data",
+					"-on-resume-read-with-unfinished-write",
+					"-on-resume-shim-writes-first",
+				},
+			})
+		}
+
+		tests = append(tests, testCase{
 			testType: serverTest,
 			name:     "TLS13-EarlyData-Server",
 			config: Config{
@@ -10286,7 +10391,7 @@
 
 	testCases = append(testCases, testCase{
 		testType: clientTest,
-		name:     "TLS13-DataLessEarlyData-Reject-Client",
+		name:     "TLS13-EarlyData-Reject-Client",
 		config: Config{
 			MaxVersion:       VersionTLS13,
 			MaxEarlyDataSize: 16384,
@@ -10303,12 +10408,42 @@
 			"-enable-early-data",
 			"-expect-early-data-info",
 			"-expect-reject-early-data",
+			"-on-resume-shim-writes-first",
 		},
 	})
 
 	testCases = append(testCases, testCase{
 		testType: clientTest,
-		name:     "TLS13-DataLessEarlyData-HRR-Client",
+		name:     "TLS13-EarlyData-RejectTicket-Client",
+		config: Config{
+			MaxVersion:       VersionTLS13,
+			MaxEarlyDataSize: 16384,
+			Certificates:     []Certificate{rsaCertificate},
+		},
+		resumeConfig: &Config{
+			MaxVersion:             VersionTLS13,
+			MaxEarlyDataSize:       16384,
+			Certificates:           []Certificate{ecdsaP256Certificate},
+			SessionTicketsDisabled: true,
+		},
+		resumeSession:        true,
+		expectResumeRejected: true,
+		flags: []string{
+			"-enable-early-data",
+			"-expect-early-data-info",
+			"-expect-reject-early-data",
+			"-on-resume-shim-writes-first",
+			"-on-initial-expect-peer-cert-file", path.Join(*resourceDir, rsaCertificateFile),
+			"-on-resume-expect-peer-cert-file", path.Join(*resourceDir, rsaCertificateFile),
+			"-on-retry-expect-peer-cert-file", path.Join(*resourceDir, ecdsaP256CertificateFile),
+			// Session tickets are disabled, so the runner will not send a ticket.
+			"-on-retry-expect-no-session",
+		},
+	})
+
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "TLS13-EarlyData-HRR-Client",
 		config: Config{
 			MaxVersion:       VersionTLS13,
 			MaxEarlyDataSize: 16384,
@@ -10395,6 +10530,7 @@
 		flags: []string{
 			"-enable-early-data",
 			"-expect-early-data-info",
+			"-expect-reject-early-data",
 		},
 		shouldFail:    true,
 		expectedError: ":UNEXPECTED_EXTENSION:",
@@ -10407,7 +10543,7 @@
 	// that changed it.
 	testCases = append(testCases, testCase{
 		testType: clientTest,
-		name:     "TLS13-DataLessEarlyData-ALPNMismatch-Client",
+		name:     "TLS13-EarlyData-ALPNMismatch-Client",
 		config: Config{
 			MaxVersion:       VersionTLS13,
 			MaxEarlyDataSize: 16384,
@@ -10429,7 +10565,8 @@
 			"-expect-early-data-info",
 			"-expect-reject-early-data",
 			"-on-initial-expect-alpn", "foo",
-			"-on-resume-expect-alpn", "bar",
+			"-on-resume-expect-alpn", "foo",
+			"-on-retry-expect-alpn", "bar",
 		},
 	})
 
@@ -10437,7 +10574,7 @@
 	// ALPN was omitted from the first connection.
 	testCases = append(testCases, testCase{
 		testType: clientTest,
-		name:     "TLS13-DataLessEarlyData-ALPNOmitted1-Client",
+		name:     "TLS13-EarlyData-ALPNOmitted1-Client",
 		config: Config{
 			MaxVersion:       VersionTLS13,
 			MaxEarlyDataSize: 16384,
@@ -10454,7 +10591,9 @@
 			"-expect-early-data-info",
 			"-expect-reject-early-data",
 			"-on-initial-expect-alpn", "",
-			"-on-resume-expect-alpn", "foo",
+			"-on-resume-expect-alpn", "",
+			"-on-retry-expect-alpn", "foo",
+			"-on-resume-shim-writes-first",
 		},
 	})
 
@@ -10462,7 +10601,7 @@
 	// ALPN was omitted from the second connection.
 	testCases = append(testCases, testCase{
 		testType: clientTest,
-		name:     "TLS13-DataLessEarlyData-ALPNOmitted2-Client",
+		name:     "TLS13-EarlyData-ALPNOmitted2-Client",
 		config: Config{
 			MaxVersion:       VersionTLS13,
 			MaxEarlyDataSize: 16384,
@@ -10479,14 +10618,16 @@
 			"-expect-early-data-info",
 			"-expect-reject-early-data",
 			"-on-initial-expect-alpn", "foo",
-			"-on-resume-expect-alpn", "",
+			"-on-resume-expect-alpn", "foo",
+			"-on-retry-expect-alpn", "",
+			"-on-resume-shim-writes-first",
 		},
 	})
 
 	// Test that the client enforces ALPN match on 0-RTT accept.
 	testCases = append(testCases, testCase{
 		testType: clientTest,
-		name:     "TLS13-DataLessEarlyData-BadALPNMismatch-Client",
+		name:     "TLS13-EarlyData-BadALPNMismatch-Client",
 		config: Config{
 			MaxVersion:       VersionTLS13,
 			MaxEarlyDataSize: 16384,
@@ -10508,7 +10649,8 @@
 			"-enable-early-data",
 			"-expect-early-data-info",
 			"-on-initial-expect-alpn", "foo",
-			"-on-resume-expect-alpn", "bar",
+			"-on-resume-expect-alpn", "foo",
+			"-on-retry-expect-alpn", "bar",
 		},
 		shouldFail:    true,
 		expectedError: ":ALPN_MISMATCH_ON_EARLY_DATA:",
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index d7c3239..960240e 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -142,6 +142,7 @@
   { "-host-name", &TestConfig::host_name },
   { "-advertise-alpn", &TestConfig::advertise_alpn },
   { "-expect-alpn", &TestConfig::expected_alpn },
+  { "-expect-late-alpn", &TestConfig::expected_late_alpn },
   { "-expect-advertised-alpn", &TestConfig::expected_advertised_alpn },
   { "-select-alpn", &TestConfig::select_alpn },
   { "-psk", &TestConfig::psk },
@@ -192,110 +193,123 @@
   { "-verify-prefs", &TestConfig::verify_prefs },
 };
 
+bool ParseFlag(char *flag, int argc, char **argv, int *i,
+               bool skip, TestConfig *out_config) {
+  bool *bool_field = FindField(out_config, kBoolFlags, flag);
+  if (bool_field != NULL) {
+    if (!skip) {
+      *bool_field = true;
+    }
+    return true;
+  }
+
+  std::string *string_field = FindField(out_config, kStringFlags, flag);
+  if (string_field != NULL) {
+    *i = *i + 1;
+    if (*i >= argc) {
+      fprintf(stderr, "Missing parameter\n");
+      return false;
+    }
+    if (!skip) {
+      string_field->assign(argv[*i]);
+    }
+    return true;
+  }
+
+  std::string *base64_field = FindField(out_config, kBase64Flags, flag);
+  if (base64_field != NULL) {
+    *i = *i + 1;
+    if (*i >= argc) {
+      fprintf(stderr, "Missing parameter\n");
+      return false;
+    }
+    size_t len;
+    if (!EVP_DecodedLength(&len, strlen(argv[*i]))) {
+      fprintf(stderr, "Invalid base64: %s\n", argv[*i]);
+      return false;
+    }
+    std::unique_ptr<uint8_t[]> decoded(new uint8_t[len]);
+    if (!EVP_DecodeBase64(decoded.get(), &len, len,
+                          reinterpret_cast<const uint8_t *>(argv[*i]),
+                          strlen(argv[*i]))) {
+      fprintf(stderr, "Invalid base64: %s\n", argv[*i]);
+      return false;
+    }
+    if (!skip) {
+      base64_field->assign(reinterpret_cast<const char *>(decoded.get()),
+                           len);
+    }
+    return true;
+  }
+
+  int *int_field = FindField(out_config, kIntFlags, flag);
+  if (int_field) {
+    *i = *i + 1;
+    if (*i >= argc) {
+      fprintf(stderr, "Missing parameter\n");
+      return false;
+    }
+    if (!skip) {
+      *int_field = atoi(argv[*i]);
+    }
+    return true;
+  }
+
+  std::vector<int> *int_vector_field =
+      FindField(out_config, kIntVectorFlags, flag);
+  if (int_vector_field) {
+    *i = *i + 1;
+    if (*i >= argc) {
+      fprintf(stderr, "Missing parameter\n");
+      return false;
+    }
+
+    // Each instance of the flag adds to the list.
+    if (!skip) {
+      int_vector_field->push_back(atoi(argv[*i]));
+    }
+    return true;
+  }
+
+  fprintf(stderr, "Unknown argument: %s\n", flag);
+  return false;
+}
+
 const char kInit[] = "-on-initial";
 const char kResume[] = "-on-resume";
+const char kRetry[] = "-on-retry";
 
 }  // namespace
 
-bool ParseConfig(int argc, char **argv, bool is_resume,
-                 TestConfig *out_config) {
+bool ParseConfig(int argc, char **argv,
+                 TestConfig *out_initial,
+                 TestConfig *out_resume,
+                 TestConfig *out_retry) {
   for (int i = 0; i < argc; i++) {
     bool skip = false;
     char *flag = argv[i];
-    const char *prefix = is_resume ? kResume : kInit;
-    const char *opposite = is_resume ? kInit : kResume;
-    if (strncmp(flag, prefix, strlen(prefix)) == 0) {
-      flag = flag + strlen(prefix);
-      for (int j = 0; j < argc; j++) {
-        if (strcmp(argv[j], flag) == 0) {
-          fprintf(stderr, "Can't use default and prefixed arguments: %s\n",
-                  flag);
-          return false;
-        }
-      }
-    } else if (strncmp(flag, opposite, strlen(opposite)) == 0) {
-      flag = flag + strlen(opposite);
-      skip = true;
-    }
-
-    bool *bool_field = FindField(out_config, kBoolFlags, flag);
-    if (bool_field != NULL) {
-      if (!skip) {
-        *bool_field = true;
-      }
-      continue;
-    }
-
-    std::string *string_field = FindField(out_config, kStringFlags, flag);
-    if (string_field != NULL) {
-      i++;
-      if (i >= argc) {
-        fprintf(stderr, "Missing parameter\n");
+    if (strncmp(flag, kInit, strlen(kInit)) == 0) {
+      if (!ParseFlag(flag + strlen(kInit), argc, argv, &i, skip, out_initial)) {
         return false;
       }
-      if (!skip) {
-        string_field->assign(argv[i]);
+    } else if (strncmp(flag, kResume, strlen(kResume)) == 0) {
+      if (!ParseFlag(flag + strlen(kResume), argc, argv, &i, skip,
+                     out_resume)) {
+        return false;
       }
-      continue;
+    } else if (strncmp(flag, kRetry, strlen(kRetry)) == 0) {
+      if (!ParseFlag(flag + strlen(kRetry), argc, argv, &i, skip, out_retry)) {
+        return false;
+      }
+    } else {
+      int i_init = i;
+      int i_resume = i;
+      if (!ParseFlag(flag, argc, argv, &i_init, skip, out_initial) ||
+          !ParseFlag(flag, argc, argv, &i_resume, skip, out_resume) ||
+          !ParseFlag(flag, argc, argv, &i, skip, out_retry)) {
+        return false;
+      }
     }
-
-    std::string *base64_field = FindField(out_config, kBase64Flags, flag);
-    if (base64_field != NULL) {
-      i++;
-      if (i >= argc) {
-        fprintf(stderr, "Missing parameter\n");
-        return false;
-      }
-      size_t len;
-      if (!EVP_DecodedLength(&len, strlen(argv[i]))) {
-        fprintf(stderr, "Invalid base64: %s\n", argv[i]);
-        return false;
-      }
-      std::unique_ptr<uint8_t[]> decoded(new uint8_t[len]);
-      if (!EVP_DecodeBase64(decoded.get(), &len, len,
-                            reinterpret_cast<const uint8_t *>(argv[i]),
-                            strlen(argv[i]))) {
-        fprintf(stderr, "Invalid base64: %s\n", argv[i]);
-        return false;
-      }
-      if (!skip) {
-        base64_field->assign(reinterpret_cast<const char *>(decoded.get()),
-                             len);
-      }
-      continue;
-    }
-
-    int *int_field = FindField(out_config, kIntFlags, flag);
-    if (int_field) {
-      i++;
-      if (i >= argc) {
-        fprintf(stderr, "Missing parameter\n");
-        return false;
-      }
-      if (!skip) {
-        *int_field = atoi(argv[i]);
-      }
-      continue;
-    }
-
-    std::vector<int> *int_vector_field =
-        FindField(out_config, kIntVectorFlags, flag);
-    if (int_vector_field) {
-      i++;
-      if (i >= argc) {
-        fprintf(stderr, "Missing parameter\n");
-        return false;
-      }
-
-      // Each instance of the flag adds to the list.
-      if (!skip) {
-        int_vector_field->push_back(atoi(argv[i]));
-      }
-      continue;
-    }
-
-    fprintf(stderr, "Unknown argument: %s\n", flag);
-    return false;
   }
 
   return true;
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index c63a1cb..8bc8892 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -53,6 +53,7 @@
   std::string host_name;
   std::string advertise_alpn;
   std::string expected_alpn;
+  std::string expected_late_alpn;
   std::string expected_advertised_alpn;
   std::string select_alpn;
   bool decline_alpn = false;
@@ -141,7 +142,8 @@
   bool enable_ed25519 = false;
 };
 
-bool ParseConfig(int argc, char **argv, bool is_resume, TestConfig *out_config);
+bool ParseConfig(int argc, char **argv, TestConfig *out_initial,
+                 TestConfig *out_resume, TestConfig *out_retry);
 
 
 #endif  // HEADER_TEST_CONFIG
diff --git a/ssl/tls13_both.c b/ssl/tls13_both.c
index f44933f..7a375fe 100644
--- a/ssl/tls13_both.c
+++ b/ssl/tls13_both.c
@@ -94,6 +94,12 @@
         hs->wait = ssl_hs_ok;
         return -1;
 
+      case ssl_hs_early_data_rejected:
+        ssl->rwstate = SSL_EARLY_DATA_REJECTED;
+        /* Cause |SSL_write| to start failing immediately. */
+        hs->can_early_write = 0;
+        return -1;
+
       case ssl_hs_ok:
         break;
     }
diff --git a/ssl/tls13_client.c b/ssl/tls13_client.c
index 6de51e5..6b39371 100644
--- a/ssl/tls13_client.c
+++ b/ssl/tls13_client.c
@@ -33,6 +33,7 @@
   state_send_second_client_hello,
   state_process_server_hello,
   state_process_encrypted_extensions,
+  state_continue_second_server_flight,
   state_process_certificate_request,
   state_process_server_certificate,
   state_process_server_certificate_verify,
@@ -141,13 +142,15 @@
 
   hs->received_hello_retry_request = 1;
   hs->tls13_state = state_send_second_client_hello;
+  /* 0-RTT is rejected if we receive a HelloRetryRequest. */
+  if (hs->in_early_data) {
+    return ssl_hs_early_data_rejected;
+  }
   return ssl_hs_ok;
 }
 
 static enum ssl_hs_wait_t do_send_second_client_hello(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  /* TODO(svaldez): Ensure that we set can_early_write to false since 0-RTT is
-   * rejected if we receive a HelloRetryRequest. */
   if (!ssl->method->set_write_state(ssl, NULL) ||
       !ssl_write_client_hello(hs)) {
     return ssl_hs_error;
@@ -259,6 +262,7 @@
       ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
       return ssl_hs_error;
     }
+    ssl_set_session(ssl, NULL);
 
     /* Resumption incorporates fresh key material, so refresh the timeout. */
     ssl_session_renew_timeout(ssl, hs->new_session,
@@ -358,9 +362,9 @@
   }
 
   if (ssl->early_data_accepted) {
-    if (ssl->session->cipher != hs->new_session->cipher ||
-        ssl->session->early_alpn_len != ssl->s3->alpn_selected_len ||
-        OPENSSL_memcmp(ssl->session->early_alpn, ssl->s3->alpn_selected,
+    if (hs->early_session->cipher != hs->new_session->cipher ||
+        hs->early_session->early_alpn_len != ssl->s3->alpn_selected_len ||
+        OPENSSL_memcmp(hs->early_session->early_alpn, ssl->s3->alpn_selected,
                        ssl->s3->alpn_selected_len) != 0) {
       OPENSSL_PUT_ERROR(SSL, SSL_R_ALPN_MISMATCH_ON_EARLY_DATA);
       return ssl_hs_error;
@@ -371,15 +375,18 @@
     }
   }
 
-  /* Release offered session now that it is no longer needed. */
-  if (ssl->s3->session_reused) {
-    ssl_set_session(ssl, NULL);
-  }
-
   if (!ssl_hash_current_message(hs)) {
     return ssl_hs_error;
   }
 
+  hs->tls13_state = state_continue_second_server_flight;
+  if (hs->in_early_data && !ssl->early_data_accepted) {
+    return ssl_hs_early_data_rejected;
+  }
+  return ssl_hs_ok;
+}
+
+static enum ssl_hs_wait_t do_continue_second_server_flight(SSL_HANDSHAKE *hs) {
   hs->tls13_state = state_process_certificate_request;
   return ssl_hs_read_message;
 }
@@ -485,11 +492,13 @@
 
 static enum ssl_hs_wait_t do_send_end_of_early_data(SSL_HANDSHAKE *hs) {
   SSL *const ssl = hs->ssl;
-  /* TODO(svaldez): Stop sending early data. */
-  if (ssl->early_data_accepted &&
-      !ssl->method->add_alert(ssl, SSL3_AL_WARNING,
-                              TLS1_AD_END_OF_EARLY_DATA)) {
-    return ssl_hs_error;
+
+  if (ssl->early_data_accepted) {
+    hs->can_early_write = 0;
+    if (!ssl->method->add_alert(ssl, SSL3_AL_WARNING,
+                                TLS1_AD_END_OF_EARLY_DATA)) {
+      return ssl_hs_error;
+    }
   }
 
   if (hs->early_data_offered &&
@@ -618,6 +627,9 @@
       case state_process_encrypted_extensions:
         ret = do_process_encrypted_extensions(hs);
         break;
+      case state_continue_second_server_flight:
+        ret = do_continue_second_server_flight(hs);
+        break;
       case state_process_certificate_request:
         ret = do_process_certificate_request(hs);
         break;
diff --git a/ssl/tls13_server.c b/ssl/tls13_server.c
index 9e8513c..162da22 100644
--- a/ssl/tls13_server.c
+++ b/ssl/tls13_server.c
@@ -657,6 +657,7 @@
     }
     hs->can_early_write = 1;
     hs->can_early_read = 1;
+    hs->in_early_data = 1;
     hs->tls13_state = state_process_end_of_early_data;
     return ssl_hs_read_end_of_early_data;
   }
diff --git a/tool/client.cc b/tool/client.cc
index 005afa8..c4a071d 100644
--- a/tool/client.cc
+++ b/tool/client.cc
@@ -121,7 +121,7 @@
         "verification is required.",
     },
     {
-        "-early-data", kBooleanArgument, "Allow early data",
+        "-early-data", kOptionalArgument, "Allow early data",
     },
     {
         "-ed25519", kBooleanArgument, "Advertise Ed25519 support",
@@ -264,6 +264,20 @@
     return false;
   }
 
+  if (args_map.count("-early-data") != 0 && SSL_in_early_data(ssl.get())) {
+    int ed_size = args_map["-early-data"].size();
+    int ssl_ret = SSL_write(ssl.get(), args_map["-early-data"].data(), ed_size);
+    if (ssl_ret <= 0) {
+      int ssl_err = SSL_get_error(ssl.get(), ssl_ret);
+      fprintf(stderr, "Error while writing: %d\n", ssl_err);
+      ERR_print_errors_cb(PrintErrorCallback, stderr);
+      return false;
+    } else if (ssl_ret != ed_size) {
+      fprintf(stderr, "Short write from SSL_write.\n");
+      return false;
+    }
+  }
+
   fprintf(stderr, "Connected.\n");
   PrintConnectionInfo(ssl.get());