Fix DTLS asynchronous write handling.

Although the DTLS transport layer logic drops failed writes on the floor, it is
actually set up to work correctly. If an SSL_write fails at the transport,
dropping the buffer is fine. Arguably it works better than in TLS because we
don't have the weird "half-committed to data" behavior. Likewise, the handshake
keeps track of how far its gotten and resumes the message at the right point.

This broke when the buffering logic was rewritten because I didn't understand
what the DTLS code was doing. The one thing that doesn't work as one might
expect is non-fatal write errors during rexmit are not recoverable. The next
timeout must fire before we try again.

This code is quite badly sprinkled in here, so add tests to guard it against
future turbulence. Because of the rexmit issues, the tests need some hacks
around calls which may trigger them. It also changes the Go DTLS implementation
from being completely strict about sequence numbers to only requiring they be
monotonic.

The tests also revealed another bug. This one seems to be upstream's fault, not
mine. The logic to reset the handshake hash on the second ClientHello (in the
HelloVerifyRequest case) was a little overenthusiastic and breaks if the
ClientHello took multiple tries to send.

Change-Id: I9b38b93fff7ae62faf8e36c4beaf848850b3f4b9
Reviewed-on: https://boringssl-review.googlesource.com/6417
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 2acf455..f9e0d85 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -317,18 +317,24 @@
  * returns the number of bytes read. Otherwise, it returns <= 0. The caller
  * should pass the value into |SSL_get_error| to determine how to proceed.
  *
- * A non-blocking |SSL_write| differs from non-blocking |write| in that a failed
- * |SSL_write| still commits to the data passed in. When retrying, the caller
- * must supply the original write buffer (or a larger one containing the
+ * In TLS, a non-blocking |SSL_write| differs from non-blocking |write| in that
+ * a failed |SSL_write| still commits to the data passed in. When retrying, the
+ * caller must supply the original write buffer (or a larger one containing the
  * original as a prefix). By default, retries will fail if they also do not
  * reuse the same |buf| pointer. This may be relaxed with
  * |SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER|, but the buffer contents still must be
  * unchanged.
  *
- * By default, |SSL_write| will not return success until all |num| bytes are
- * written. This may be relaxed with |SSL_MODE_ENABLE_PARTIAL_WRITE|. It allows
- * |SSL_write| to complete with a partial result when only part of the input was
- * written in a single record.
+ * By default, in TLS, |SSL_write| will not return success until all |num| bytes
+ * are written. This may be relaxed with |SSL_MODE_ENABLE_PARTIAL_WRITE|. It
+ * allows |SSL_write| to complete with a partial result when only part of the
+ * input was written in a single record.
+ *
+ * In DTLS, neither |SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER| and
+ * |SSL_MODE_ENABLE_PARTIAL_WRITE| do anything. The caller may retry with a
+ * different buffer freely. A single call to |SSL_write| only ever writes a
+ * single record in a single packet, so |num| must be at most
+ * |SSL3_RT_MAX_PLAIN_LENGTH|.
  *
  * TODO(davidben): Ensure 0 is only returned on transport EOF.
  * https://crbug.com/466303. */
@@ -482,11 +488,18 @@
  * flight of handshake messages and returns 1. If too many timeouts had expired
  * without progress or an error occurs, it returns -1.
  *
- * NOTE: The caller's external timer should be compatible with the one |ssl|
- * queries within some fudge factor. Otherwise, the call will be a no-op, but
+ * The caller's external timer should be compatible with the one |ssl| queries
+ * within some fudge factor. Otherwise, the call will be a no-op, but
  * |DTLSv1_get_timeout| will return an updated timeout.
  *
- * WARNING: This function breaks the usual return value convention. */
+ * If the function returns -1, checking if |SSL_get_error| returns
+ * |SSL_ERROR_WANT_WRITE| may be used to determine if the retransmit failed due
+ * to a non-fatal error at the write |BIO|. However, the operation may not be
+ * retried until the next timeout fires.
+ *
+ * WARNING: This function breaks the usual return value convention.
+ *
+ * TODO(davidben): This |SSL_ERROR_WANT_WRITE| behavior is kind of bizarre. */
 OPENSSL_EXPORT int DTLSv1_handle_timeout(SSL *ssl);
 
 
@@ -596,14 +609,16 @@
  *
  * Modes configure API behavior. */
 
-/* SSL_MODE_ENABLE_PARTIAL_WRITE allows |SSL_write| to complete with a partial
- * result when the only part of the input was written in a single record. */
+/* SSL_MODE_ENABLE_PARTIAL_WRITE, in TLS, allows |SSL_write| to complete with a
+ * partial result when the only part of the input was written in a single
+ * record. In DTLS, it does nothing. */
 #define SSL_MODE_ENABLE_PARTIAL_WRITE 0x00000001L
 
-/* SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER allows retrying an incomplete |SSL_write|
- * with a different buffer. However, |SSL_write| still assumes the buffer
- * contents are unchanged. This is not the default to avoid the misconception
- * that non-blocking |SSL_write| behaves like non-blocking |write|. */
+/* SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER, in TLS, allows retrying an incomplete
+ * |SSL_write| with a different buffer. However, |SSL_write| still assumes the
+ * buffer contents are unchanged. This is not the default to avoid the
+ * misconception that non-blocking |SSL_write| behaves like non-blocking
+ * |write|. In DTLS, it does nothing. */
 #define SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER 0x00000002L
 
 /* SSL_MODE_NO_AUTO_CHAIN disables automatically building a certificate chain
diff --git a/ssl/d1_clnt.c b/ssl/d1_clnt.c
index 3dd5f8c..7924598 100644
--- a/ssl/d1_clnt.c
+++ b/ssl/d1_clnt.c
@@ -190,13 +190,6 @@
       case SSL3_ST_CW_CLNT_HELLO_A:
       case SSL3_ST_CW_CLNT_HELLO_B:
         s->shutdown = 0;
-
-        if (!ssl3_init_handshake_buffer(s)) {
-          OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-          ret = -1;
-          goto end;
-        }
-
         dtls1_start_timer(s);
         ret = ssl3_send_client_hello(s);
         if (ret <= 0) {
diff --git a/ssl/s3_clnt.c b/ssl/s3_clnt.c
index 13bc0e8..12eb5e0 100644
--- a/ssl/s3_clnt.c
+++ b/ssl/s3_clnt.c
@@ -645,6 +645,13 @@
     return ssl_do_write(ssl);
   }
 
+  /* In DTLS, reset the handshake buffer each time a new ClientHello is
+   * assembled. We may send multiple if we receive HelloVerifyRequest. */
+  if (SSL_IS_DTLS(ssl) && !ssl3_init_handshake_buffer(ssl)) {
+    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+    return -1;
+  }
+
   CBB cbb;
   CBB_zero(&cbb);
 
diff --git a/ssl/ssl_buffer.c b/ssl/ssl_buffer.c
index 63dcd80..f1abc53 100644
--- a/ssl/ssl_buffer.c
+++ b/ssl/ssl_buffer.c
@@ -292,11 +292,18 @@
     return 1;
   }
 
+  ssl->rwstate = SSL_WRITING;
   int ret = BIO_write(ssl->wbio, buf->buf + buf->offset, buf->len);
-  /* Drop the write buffer whether or not the write succeeded synchronously.
-   * TODO(davidben): How does this interact with the retry flag? */
+  if (ret <= 0) {
+    /* If the write failed, drop the write buffer anyway. Datagram transports
+     * can't write half a packet, so the caller is expected to retry from the
+     * top. */
+    ssl_write_buffer_clear(ssl);
+    return ret;
+  }
+  ssl->rwstate = SSL_NOTHING;
   ssl_write_buffer_clear(ssl);
-  return (ret <= 0) ? ret : 1;
+  return 1;
 }
 
 int ssl_write_buffer_flush(SSL *ssl) {
diff --git a/ssl/test/async_bio.cc b/ssl/test/async_bio.cc
index 0534845..7a5737b 100644
--- a/ssl/test/async_bio.cc
+++ b/ssl/test/async_bio.cc
@@ -26,6 +26,7 @@
 
 struct AsyncBio {
   bool datagram;
+  bool enforce_write_quota;
   size_t read_quota;
   size_t write_quota;
 };
@@ -43,9 +44,7 @@
     return 0;
   }
 
-  if (a->datagram) {
-    // Perform writes synchronously; the DTLS implementation drops any packets
-    // that failed to send.
+  if (!a->enforce_write_quota) {
     return BIO_write(bio->next_bio, in, inl);
   }
 
@@ -111,6 +110,7 @@
     return 0;
   }
   memset(a, 0, sizeof(*a));
+  a->enforce_write_quota = true;
   bio->init = 1;
   bio->ptr = (char *)a;
   return 1;
@@ -178,3 +178,11 @@
   }
   a->write_quota += count;
 }
+
+void AsyncBioEnforceWriteQuota(BIO *bio, bool enforce) {
+  AsyncBio *a = GetData(bio);
+  if (a == NULL) {
+    return;
+  }
+  a->enforce_write_quota = enforce;
+}
diff --git a/ssl/test/async_bio.h b/ssl/test/async_bio.h
index 1ccdf9b..fbc4016 100644
--- a/ssl/test/async_bio.h
+++ b/ssl/test/async_bio.h
@@ -38,5 +38,8 @@
 // AsyncBioAllowWrite increments |bio|'s write quota by |count|.
 void AsyncBioAllowWrite(BIO *bio, size_t count);
 
+// AsyncBioEnforceWriteQuota configures where |bio| enforces its write quota.
+void AsyncBioEnforceWriteQuota(BIO *bio, bool enforce);
+
 
 #endif  // HEADER_ASYNC_BIO
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index 3f4e9cf..32c572e 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -808,6 +808,7 @@
     return false;
   }
 
+  const TestConfig *config = GetConfigPtr(ssl);
   TestState *test_state = GetTestState(ssl);
   if (test_state->clock_delta.tv_usec != 0 ||
       test_state->clock_delta.tv_sec != 0) {
@@ -818,7 +819,17 @@
     test_state->clock.tv_sec += test_state->clock_delta.tv_sec;
     memset(&test_state->clock_delta, 0, sizeof(test_state->clock_delta));
 
-    if (DTLSv1_handle_timeout(ssl) < 0) {
+    // The DTLS retransmit logic silently ignores write failures. So the test
+    // may progress, allow writes through synchronously.
+    if (config->async) {
+      AsyncBioEnforceWriteQuota(test_state->async_bio, false);
+    }
+    int timeout_ret = DTLSv1_handle_timeout(ssl);
+    if (config->async) {
+      AsyncBioEnforceWriteQuota(test_state->async_bio, true);
+    }
+
+    if (timeout_ret < 0) {
       fprintf(stderr, "Error retransmitting.\n");
       return false;
     }
@@ -863,9 +874,19 @@
 // the result value of the final |SSL_read| call.
 static int DoRead(SSL *ssl, uint8_t *out, size_t max_out) {
   const TestConfig *config = GetConfigPtr(ssl);
+  TestState *test_state = GetTestState(ssl);
   int ret;
   do {
+    if (config->async) {
+      // The DTLS retransmit logic silently ignores write failures. So the test
+      // may progress, allow writes through synchronously. |SSL_read| may
+      // trigger a retransmit, so disconnect the write quota.
+      AsyncBioEnforceWriteQuota(test_state->async_bio, false);
+    }
     ret = SSL_read(ssl, out, max_out);
+    if (config->async) {
+      AsyncBioEnforceWriteQuota(test_state->async_bio, true);
+    }
   } while (config->async && RetryAsync(ssl, ret));
   return ret;
 }
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index 986e2b5..c911ad0 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -1078,14 +1078,16 @@
 		seq := packet[5:11]
 		length := uint16(packet[11])<<8 | uint16(packet[12])
 		if bytes.Equal(c.in.seq[:2], epoch) {
-			if !bytes.Equal(c.in.seq[2:], seq) {
+			if bytes.Compare(seq, c.in.seq[2:]) < 0 {
 				return errors.New("tls: sequence mismatch")
 			}
+			copy(c.in.seq[2:], seq)
 			c.in.incSeq(false)
 		} else {
-			if !bytes.Equal(c.in.nextSeq[:], seq) {
+			if bytes.Compare(seq, c.in.nextSeq[:]) < 0 {
 				return errors.New("tls: sequence mismatch")
 			}
+			copy(c.in.nextSeq[:], seq)
 			c.in.incNextSeq()
 		}
 		if len(packet) < 13+int(length) {
diff --git a/ssl/test/runner/dtls.go b/ssl/test/runner/dtls.go
index fac035e..c3ee521 100644
--- a/ssl/test/runner/dtls.go
+++ b/ssl/test/runner/dtls.go
@@ -80,15 +80,21 @@
 			return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect))
 		}
 	}
-	seq := b.data[3:11]
-	// For test purposes, we assume a reliable channel. Require
-	// that the explicit sequence number matches the incrementing
-	// one we maintain. A real implementation would maintain a
-	// replay window and such.
-	if !bytes.Equal(seq, c.in.seq[:]) {
+	epoch := b.data[3:5]
+	seq := b.data[5:11]
+	// For test purposes, require the sequence number be monotonically
+	// increasing, so c.in includes the minimum next sequence number. Gaps
+	// may occur if packets failed to be sent out. A real implementation
+	// would maintain a replay window and such.
+	if !bytes.Equal(epoch, c.in.seq[:2]) {
+		c.sendAlert(alertIllegalParameter)
+		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch"))
+	}
+	if bytes.Compare(seq, c.in.seq[2:]) < 0 {
 		c.sendAlert(alertIllegalParameter)
 		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number"))
 	}
+	copy(c.in.seq[2:], seq)
 	n := int(b.data[11])<<8 | int(b.data[12])
 	if n > maxCiphertext || len(b.data) < recordHeaderLen+n {
 		c.sendAlert(alertRecordOverflow)