Move init_buf and rwstate into SSL3_STATE.

This finally clears most of the SSL_clear special-cases.

Change-Id: I00fc240ccbf13f4290322845f585ca6f5786ad80
Reviewed-on: https://boringssl-review.googlesource.com/21947
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Steven Valdez <svaldez@google.com>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc
index 798deb0..93c2724 100644
--- a/ssl/d1_both.cc
+++ b/ssl/d1_both.cc
@@ -808,14 +808,14 @@
       // Retry this packet the next time around.
       ssl->d1->outgoing_written = old_written;
       ssl->d1->outgoing_offset = old_offset;
-      ssl->rwstate = SSL_WRITING;
+      ssl->s3->rwstate = SSL_WRITING;
       ret = bio_ret;
       goto err;
     }
   }
 
   if (BIO_flush(ssl->wbio) <= 0) {
-    ssl->rwstate = SSL_WRITING;
+    ssl->s3->rwstate = SSL_WRITING;
     goto err;
   }
 
diff --git a/ssl/handshake.cc b/ssl/handshake.cc
index df47ed1..ed11484 100644
--- a/ssl/handshake.cc
+++ b/ssl/handshake.cc
@@ -546,42 +546,42 @@
       }
 
       case ssl_hs_certificate_selection_pending:
-        ssl->rwstate = SSL_CERTIFICATE_SELECTION_PENDING;
+        ssl->s3->rwstate = SSL_CERTIFICATE_SELECTION_PENDING;
         hs->wait = ssl_hs_ok;
         return -1;
 
       case ssl_hs_x509_lookup:
-        ssl->rwstate = SSL_X509_LOOKUP;
+        ssl->s3->rwstate = SSL_X509_LOOKUP;
         hs->wait = ssl_hs_ok;
         return -1;
 
       case ssl_hs_channel_id_lookup:
-        ssl->rwstate = SSL_CHANNEL_ID_LOOKUP;
+        ssl->s3->rwstate = SSL_CHANNEL_ID_LOOKUP;
         hs->wait = ssl_hs_ok;
         return -1;
 
       case ssl_hs_private_key_operation:
-        ssl->rwstate = SSL_PRIVATE_KEY_OPERATION;
+        ssl->s3->rwstate = SSL_PRIVATE_KEY_OPERATION;
         hs->wait = ssl_hs_ok;
         return -1;
 
       case ssl_hs_pending_session:
-        ssl->rwstate = SSL_PENDING_SESSION;
+        ssl->s3->rwstate = SSL_PENDING_SESSION;
         hs->wait = ssl_hs_ok;
         return -1;
 
       case ssl_hs_pending_ticket:
-        ssl->rwstate = SSL_PENDING_TICKET;
+        ssl->s3->rwstate = SSL_PENDING_TICKET;
         hs->wait = ssl_hs_ok;
         return -1;
 
       case ssl_hs_certificate_verify:
-        ssl->rwstate = SSL_CERTIFICATE_VERIFY;
+        ssl->s3->rwstate = SSL_CERTIFICATE_VERIFY;
         hs->wait = ssl_hs_ok;
         return -1;
 
       case ssl_hs_early_data_rejected:
-        ssl->rwstate = SSL_EARLY_DATA_REJECTED;
+        ssl->s3->rwstate = SSL_EARLY_DATA_REJECTED;
         // Cause |SSL_write| to start failing immediately.
         hs->can_early_write = false;
         return -1;
diff --git a/ssl/internal.h b/ssl/internal.h
index e76be27..9744a1a 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -2199,6 +2199,11 @@
 
   int total_renegotiations = 0;
 
+  // This holds a variable that indicates what we were doing when a 0 or -1 is
+  // returned.  This is needed for non-blocking IO so we know what request
+  // needs re-doing when in SSL_accept or SSL_connect
+  int rwstate = SSL_NOTHING;
+
   // early_data_skipped is the amount of early data that has been skipped by the
   // record layer.
   uint16_t early_data_skipped = 0;
@@ -2257,6 +2262,9 @@
 
   uint8_t send_alert[2] = {0};
 
+  // hs_buf is the buffer of handshake data to process.
+  UniquePtr<BUF_MEM> hs_buf;
+
   // pending_flight is the pending outgoing flight. This is used to flush each
   // handshake flight in a single write. |write_buffer| must be written out
   // before this data.
@@ -2479,8 +2487,6 @@
   // progress.
   enum ssl_hs_wait_t (*do_handshake)(SSL_HANDSHAKE *hs);
 
-  BUF_MEM *init_buf;  // buffer used during init
-
   SSL3_STATE *s3;   // SSLv3 variables
   DTLS1_STATE *d1;  // DTLSv1 variables
 
@@ -2500,11 +2506,6 @@
   // This is used to hold the server certificate used
   CERT *cert;
 
-  // This holds a variable that indicates what we were doing when a 0 or -1 is
-  // returned.  This is needed for non-blocking IO so we know what request
-  // needs re-doing when in SSL_accept or SSL_connect
-  int rwstate;
-
   // initial_timeout_duration_ms is the default DTLS timeout duration in
   // milliseconds. It's used to initialize the timer any time it's restarted.
   unsigned initial_timeout_duration_ms;
diff --git a/ssl/s3_both.cc b/ssl/s3_both.cc
index e29d19f..ede4ba7 100644
--- a/ssl/s3_both.cc
+++ b/ssl/s3_both.cc
@@ -249,7 +249,7 @@
   if (!ssl->s3->write_buffer.empty()) {
     int ret = ssl_write_buffer_flush(ssl);
     if (ret <= 0) {
-      ssl->rwstate = SSL_WRITING;
+      ssl->s3->rwstate = SSL_WRITING;
       return ret;
     }
   }
@@ -261,7 +261,7 @@
         ssl->s3->pending_flight->data + ssl->s3->pending_flight_offset,
         ssl->s3->pending_flight->length - ssl->s3->pending_flight_offset);
     if (ret <= 0) {
-      ssl->rwstate = SSL_WRITING;
+      ssl->s3->rwstate = SSL_WRITING;
       return ret;
     }
 
@@ -269,7 +269,7 @@
   }
 
   if (BIO_flush(ssl->wbio) <= 0) {
-    ssl->rwstate = SSL_WRITING;
+    ssl->s3->rwstate = SSL_WRITING;
     return -1;
   }
 
@@ -351,9 +351,9 @@
                                1 /* compression length */ + 1 /* compression */;
   ScopedCBB client_hello;
   CBB hello_body, cipher_suites;
-  if (!BUF_MEM_reserve(ssl->init_buf, max_v3_client_hello) ||
-      !CBB_init_fixed(client_hello.get(), (uint8_t *)ssl->init_buf->data,
-                      ssl->init_buf->max) ||
+  if (!BUF_MEM_reserve(ssl->s3->hs_buf.get(), max_v3_client_hello) ||
+      !CBB_init_fixed(client_hello.get(), (uint8_t *)ssl->s3->hs_buf->data,
+                      ssl->s3->hs_buf->max) ||
       !CBB_add_u8(client_hello.get(), SSL3_MT_CLIENT_HELLO) ||
       !CBB_add_u24_length_prefixed(client_hello.get(), &hello_body) ||
       !CBB_add_u16(&hello_body, version) ||
@@ -386,7 +386,7 @@
   // Add the null compression scheme and finish.
   if (!CBB_add_u8(&hello_body, 1) ||
       !CBB_add_u8(&hello_body, 0) ||
-      !CBB_finish(client_hello.get(), NULL, &ssl->init_buf->length)) {
+      !CBB_finish(client_hello.get(), NULL, &ssl->s3->hs_buf->length)) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return ssl_open_record_error;
   }
@@ -398,15 +398,15 @@
 
 static bool parse_message(const SSL *ssl, SSLMessage *out,
                           size_t *out_bytes_needed) {
-  if (ssl->init_buf == NULL) {
+  if (!ssl->s3->hs_buf) {
     *out_bytes_needed = 4;
     return false;
   }
 
   CBS cbs;
   uint32_t len;
-  CBS_init(&cbs, reinterpret_cast<const uint8_t *>(ssl->init_buf->data),
-           ssl->init_buf->length);
+  CBS_init(&cbs, reinterpret_cast<const uint8_t *>(ssl->s3->hs_buf->data),
+           ssl->s3->hs_buf->length);
   if (!CBS_get_u8(&cbs, &out->type) ||
       !CBS_get_u24(&cbs, &len)) {
     *out_bytes_needed = 4;
@@ -418,7 +418,7 @@
     return false;
   }
 
-  CBS_init(&out->raw, reinterpret_cast<const uint8_t *>(ssl->init_buf->data),
+  CBS_init(&out->raw, reinterpret_cast<const uint8_t *>(ssl->s3->hs_buf->data),
            4 + len);
   out->is_v2_hello = ssl->s3->is_v2_hello;
   return true;
@@ -468,16 +468,16 @@
     }
   }
 
-  return ssl->init_buf != NULL && ssl->init_buf->length > msg_len;
+  return ssl->s3->hs_buf && ssl->s3->hs_buf->length > msg_len;
 }
 
 ssl_open_record_t ssl3_open_handshake(SSL *ssl, size_t *out_consumed,
                                       uint8_t *out_alert, Span<uint8_t> in) {
   *out_consumed = 0;
   // Re-create the handshake buffer if needed.
-  if (ssl->init_buf == NULL) {
-    ssl->init_buf = BUF_MEM_new();
-    if (ssl->init_buf == NULL) {
+  if (!ssl->s3->hs_buf) {
+    ssl->s3->hs_buf.reset(BUF_MEM_new());
+    if (!ssl->s3->hs_buf) {
       *out_alert = SSL_AD_INTERNAL_ERROR;
       return ssl_open_record_error;
     }
@@ -551,7 +551,7 @@
   }
 
   // Append the entire handshake record to the buffer.
-  if (!BUF_MEM_append(ssl->init_buf, body.data(), body.size())) {
+  if (!BUF_MEM_append(ssl->s3->hs_buf.get(), body.data(), body.size())) {
     *out_alert = SSL_AD_INTERNAL_ERROR;
     return ssl_open_record_error;
   }
@@ -562,23 +562,23 @@
 void ssl3_next_message(SSL *ssl) {
   SSLMessage msg;
   if (!ssl3_get_message(ssl, &msg) ||
-      ssl->init_buf == NULL ||
-      ssl->init_buf->length < CBS_len(&msg.raw)) {
+      !ssl->s3->hs_buf ||
+      ssl->s3->hs_buf->length < CBS_len(&msg.raw)) {
     assert(0);
     return;
   }
 
-  OPENSSL_memmove(ssl->init_buf->data, ssl->init_buf->data + CBS_len(&msg.raw),
-                  ssl->init_buf->length - CBS_len(&msg.raw));
-  ssl->init_buf->length -= CBS_len(&msg.raw);
+  OPENSSL_memmove(ssl->s3->hs_buf->data,
+                  ssl->s3->hs_buf->data + CBS_len(&msg.raw),
+                  ssl->s3->hs_buf->length - CBS_len(&msg.raw));
+  ssl->s3->hs_buf->length -= CBS_len(&msg.raw);
   ssl->s3->is_v2_hello = false;
   ssl->s3->has_message = false;
 
   // Post-handshake messages are rare, so release the buffer after every
   // message. During the handshake, |on_handshake_complete| will release it.
-  if (!SSL_in_init(ssl) && ssl->init_buf->length == 0) {
-    BUF_MEM_free(ssl->init_buf);
-    ssl->init_buf = NULL;
+  if (!SSL_in_init(ssl) && ssl->s3->hs_buf->length == 0) {
+    ssl->s3->hs_buf.reset();
   }
 }
 
diff --git a/ssl/s3_pkt.cc b/ssl/s3_pkt.cc
index c4ccecc..285abb3 100644
--- a/ssl/s3_pkt.cc
+++ b/ssl/s3_pkt.cc
@@ -321,11 +321,11 @@
       return ssl_open_record_error;
     }
 
-    if (ssl->init_buf == NULL) {
-      ssl->init_buf = BUF_MEM_new();
+    if (!ssl->s3->hs_buf) {
+      ssl->s3->hs_buf.reset(BUF_MEM_new());
     }
-    if (ssl->init_buf == NULL ||
-        !BUF_MEM_append(ssl->init_buf, body.data(), body.size())) {
+    if (!ssl->s3->hs_buf ||
+        !BUF_MEM_append(ssl->s3->hs_buf.get(), body.data(), body.size())) {
       *out_alert = SSL_AD_INTERNAL_ERROR;
       return ssl_open_record_error;
     }
diff --git a/ssl/ssl_buffer.cc b/ssl/ssl_buffer.cc
index a942054..da1de93 100644
--- a/ssl/ssl_buffer.cc
+++ b/ssl/ssl_buffer.cc
@@ -115,7 +115,7 @@
   // Read a single packet from |ssl->rbio|. |buf->cap()| must fit in an int.
   int ret = BIO_read(ssl->rbio, buf->data(), static_cast<int>(buf->cap()));
   if (ret <= 0) {
-    ssl->rwstate = SSL_READING;
+    ssl->s3->rwstate = SSL_READING;
     return ret;
   }
   buf->DidWrite(static_cast<size_t>(ret));
@@ -137,7 +137,7 @@
     int ret = BIO_read(ssl->rbio, buf->data() + buf->size(),
                        static_cast<int>(len - buf->size()));
     if (ret <= 0) {
-      ssl->rwstate = SSL_READING;
+      ssl->s3->rwstate = SSL_READING;
       return ret;
     }
     buf->DidWrite(static_cast<size_t>(ret));
@@ -242,7 +242,7 @@
   while (!buf->empty()) {
     int ret = BIO_write(ssl->wbio, buf->data(), buf->size());
     if (ret <= 0) {
-      ssl->rwstate = SSL_WRITING;
+      ssl->s3->rwstate = SSL_WRITING;
       return ret;
     }
     buf->Consume(static_cast<size_t>(ret));
@@ -259,7 +259,7 @@
 
   int ret = BIO_write(ssl->wbio, buf->data(), buf->size());
   if (ret <= 0) {
-    ssl->rwstate = SSL_WRITING;
+    ssl->s3->rwstate = SSL_WRITING;
     // If the write failed, drop the write buffer anyway. Datagram transports
     // can't write half a packet, so the caller is expected to retry from the
     // top.
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index c8e6c94..2fc6ffd 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -201,7 +201,7 @@
 void ssl_reset_error_state(SSL *ssl) {
   // Functions which use |SSL_get_error| must reset I/O and error state on
   // entry.
-  ssl->rwstate = SSL_NOTHING;
+  ssl->s3->rwstate = SSL_NOTHING;
   ERR_clear_error();
   ERR_clear_system_error();
 }
@@ -716,8 +716,6 @@
     goto err;
   }
 
-  ssl->rwstate = SSL_NOTHING;
-
   CRYPTO_new_ex_data(&ssl->ex_data);
 
   ssl->psk_identity_hint = NULL;
@@ -762,8 +760,6 @@
   BIO_free_all(ssl->rbio);
   BIO_free_all(ssl->wbio);
 
-  BUF_MEM_free(ssl->init_buf);
-
   // add extra stuff
   ssl_cipher_preference_list_free(ssl->cipher_list);
 
@@ -1240,7 +1236,7 @@
     return SSL_ERROR_SYSCALL;
   }
 
-  switch (ssl->rwstate) {
+  switch (ssl->s3->rwstate) {
     case SSL_PENDING_SESSION:
       return SSL_ERROR_PENDING_SESSION;
 
@@ -2294,7 +2290,7 @@
   return CRYPTO_get_ex_data(&ctx->ex_data, idx);
 }
 
-int SSL_want(const SSL *ssl) { return ssl->rwstate; }
+int SSL_want(const SSL *ssl) { return ssl->s3->rwstate; }
 
 void SSL_CTX_set_tmp_rsa_callback(SSL_CTX *ctx,
                                   RSA *(*cb)(SSL *ssl, int is_export,
@@ -2578,17 +2574,6 @@
     SSL_SESSION_up_ref(session.get());
   }
 
-  // TODO(davidben): Some state on |ssl| is reset both in |SSL_new| and
-  // |SSL_clear| because it is per-connection state rather than configuration
-  // state. Per-connection state should be on |ssl->s3| and |ssl->d1| so it is
-  // naturally reset at the right points between |SSL_new|, |SSL_clear|, and
-  // |ssl3_new|.
-
-  ssl->rwstate = SSL_NOTHING;
-
-  BUF_MEM_free(ssl->init_buf);
-  ssl->init_buf = NULL;
-
   // The ssl->d1->mtu is simultaneously configuration (preserved across
   // clear) and connection-specific state (gets reset).
   //
diff --git a/ssl/tls_method.cc b/ssl/tls_method.cc
index 157cff4..8aeb489 100644
--- a/ssl/tls_method.cc
+++ b/ssl/tls_method.cc
@@ -73,14 +73,13 @@
   // The handshake should have released its final message.
   assert(!ssl->s3->has_message);
 
-  // During the handshake, |init_buf| is retained. Release if it there is no
+  // During the handshake, |hs_buf| is retained. Release if it there is no
   // excess in it. There may be excess left if there server sent Finished and
   // HelloRequest in the same record.
   //
   // TODO(davidben): SChannel does not support this. Reject this case.
-  if (ssl->init_buf != NULL && ssl->init_buf->length == 0) {
-    BUF_MEM_free(ssl->init_buf);
-    ssl->init_buf = NULL;
+  if (ssl->s3->hs_buf && ssl->s3->hs_buf->length == 0) {
+    ssl->s3->hs_buf.reset();
   }
 }