Move references to init_buf into SSL_PROTOCOL_METHOD.

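Both DTLS and TLS still use init_buf, but that will change in the
following commit. This also removes the handshake state machines'
knowledge of the dtls_clear_incoming_messages function; it is now called
from the DTLS finish_handshake hook, so it runs when either a client or
a server handshake completes (previously only the server path called it).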

(It's possible we'll want to get rid of begin_handshake in favor of
allocating init_buf lazily, depending on how TLS 1.3 post-handshake
messages end up working out. But this should work for now.)

Change-Id: I0f512788bbc330ab2c947890939c73e0a1aca18b
Reviewed-on: https://boringssl-review.googlesource.com/8666
Reviewed-by: Steven Valdez <svaldez@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/dtls_method.c b/ssl/dtls_method.c
index 00454dd..f6376bb 100644
--- a/ssl/dtls_method.c
+++ b/ssl/dtls_method.c
@@ -58,6 +58,8 @@
 
 #include <assert.h>
 
+#include <openssl/buf.h>
+
 #include "internal.h"
 
 
@@ -88,6 +90,32 @@
   return ~(version - 0x0201);
 }
 
+static int dtls1_begin_handshake(SSL *ssl) {
+  if (ssl->init_buf != NULL) {
+    return 1;
+  }
+
+  BUF_MEM *buf = BUF_MEM_new();
+  if (buf == NULL || !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
+    BUF_MEM_free(buf);
+    return 0;
+  }
+
+  ssl->init_buf = buf;
+  ssl->init_num = 0;
+  return 1;
+}
+
+static void dtls1_finish_handshake(SSL *ssl) {
+  BUF_MEM_free(ssl->init_buf);
+  ssl->init_buf = NULL;
+  ssl->init_num = 0;
+
+  ssl->d1->handshake_read_seq = 0;
+  ssl->d1->handshake_write_seq = 0;
+  dtls_clear_incoming_messages(ssl);
+}
+
 static const SSL_PROTOCOL_METHOD kDTLSProtocolMethod = {
     1 /* is_dtls */,
     TLS1_1_VERSION,
@@ -96,6 +124,8 @@
     dtls1_version_to_wire,
     dtls1_new,
     dtls1_free,
+    dtls1_begin_handshake,
+    dtls1_finish_handshake,
     dtls1_get_message,
     dtls1_read_app_data,
     dtls1_read_change_cipher_spec,
diff --git a/ssl/handshake_client.c b/ssl/handshake_client.c
index 2ac7dee..3c90310 100644
--- a/ssl/handshake_client.c
+++ b/ssl/handshake_client.c
@@ -187,7 +187,6 @@
 static int ssl3_get_new_session_ticket(SSL *ssl);
 
 int ssl3_connect(SSL *ssl) {
-  BUF_MEM *buf = NULL;
   int ret = -1;
   int state, skip = 0;
 
@@ -201,18 +200,10 @@
       case SSL_ST_CONNECT:
         ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_START, 1);
 
-        if (ssl->init_buf == NULL) {
-          buf = BUF_MEM_new();
-          if (buf == NULL ||
-              !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
-            ret = -1;
-            goto end;
-          }
-
-          ssl->init_buf = buf;
-          buf = NULL;
+        if (!ssl->method->begin_handshake(ssl)) {
+          ret = -1;
+          goto end;
         }
-        ssl->init_num = 0;
 
         if (!ssl_init_wbio_buffer(ssl)) {
           ret = -1;
@@ -503,9 +494,7 @@
         /* clean a few things up */
         ssl3_cleanup_key_block(ssl);
 
-        BUF_MEM_free(ssl->init_buf);
-        ssl->init_buf = NULL;
-        ssl->init_num = 0;
+        ssl->method->finish_handshake(ssl);
 
         /* Remove write buffering now. */
         ssl_free_wbio_buffer(ssl);
@@ -520,11 +509,6 @@
           ssl_update_cache(ssl, SSL_SESS_CACHE_CLIENT);
         }
 
-        if (SSL_IS_DTLS(ssl)) {
-          ssl->d1->handshake_read_seq = 0;
-          ssl->d1->handshake_write_seq = 0;
-        }
-
         ret = 1;
         ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_DONE, 1);
         goto end;
@@ -545,7 +529,6 @@
   }
 
 end:
-  BUF_MEM_free(buf);
   ssl_do_info_callback(ssl, SSL_CB_CONNECT_EXIT, ret);
   return ret;
 }
diff --git a/ssl/handshake_server.c b/ssl/handshake_server.c
index e067904..375f0e3 100644
--- a/ssl/handshake_server.c
+++ b/ssl/handshake_server.c
@@ -188,7 +188,6 @@
 static int ssl3_send_new_session_ticket(SSL *ssl);
 
 int ssl3_accept(SSL *ssl) {
-  BUF_MEM *buf = NULL;
   uint32_t alg_a;
   int ret = -1;
   int state, skip = 0;
@@ -203,16 +202,10 @@
       case SSL_ST_ACCEPT:
         ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_START, 1);
 
-        if (ssl->init_buf == NULL) {
-          buf = BUF_MEM_new();
-          if (!buf || !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
-            ret = -1;
-            goto end;
-          }
-          ssl->init_buf = buf;
-          buf = NULL;
+        if (!ssl->method->begin_handshake(ssl)) {
+          ret = -1;
+          goto end;
         }
-        ssl->init_num = 0;
 
         /* Enable a write buffer. This groups handshake messages within a flight
          * into a single write. */
@@ -470,9 +463,7 @@
         /* clean a few things up */
         ssl3_cleanup_key_block(ssl);
 
-        BUF_MEM_free(ssl->init_buf);
-        ssl->init_buf = NULL;
-        ssl->init_num = 0;
+        ssl->method->finish_handshake(ssl);
 
         /* remove buffering on output */
         ssl_free_wbio_buffer(ssl);
@@ -486,12 +477,6 @@
           ssl->session->cert_chain = NULL;
         }
 
-        if (SSL_IS_DTLS(ssl)) {
-          ssl->d1->handshake_read_seq = 0;
-          ssl->d1->handshake_write_seq = 0;
-          dtls_clear_incoming_messages(ssl);
-        }
-
         ssl->s3->initial_handshake_complete = 1;
 
         ssl_update_cache(ssl, SSL_SESS_CACHE_SERVER);
@@ -517,7 +502,6 @@
   }
 
 end:
-  BUF_MEM_free(buf);
   ssl_do_info_callback(ssl, SSL_CB_ACCEPT_EXIT, ret);
   return ret;
 }
diff --git a/ssl/internal.h b/ssl/internal.h
index ab79dcc..2e4cb46 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -829,6 +829,11 @@
   uint16_t (*version_to_wire)(uint16_t version);
   int (*ssl_new)(SSL *ssl);
   void (*ssl_free)(SSL *ssl);
+  /* begin_handshake is called to start a new handshake. It returns one on
+   * success and zero on error. */
+  int (*begin_handshake)(SSL *ssl);
+  /* finish_handshake is called when a handshake completes. */
+  void (*finish_handshake)(SSL *ssl);
   long (*ssl_get_message)(SSL *ssl, int msg_type,
                           enum ssl_hash_message_t hash_message, int *ok);
   int (*read_app_data)(SSL *ssl, uint8_t *buf, int len, int peek);
diff --git a/ssl/tls_method.c b/ssl/tls_method.c
index e8cf1d6..dab5c47 100644
--- a/ssl/tls_method.c
+++ b/ssl/tls_method.c
@@ -56,6 +56,8 @@
 
 #include <openssl/ssl.h>
 
+#include <openssl/buf.h>
+
 #include "internal.h"
 
 
@@ -65,6 +67,28 @@
 
 static uint16_t ssl3_version_to_wire(uint16_t version) { return version; }
 
+static int ssl3_begin_handshake(SSL *ssl) {
+  if (ssl->init_buf != NULL) {
+    return 1;
+  }
+
+  BUF_MEM *buf = BUF_MEM_new();
+  if (buf == NULL || !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
+    BUF_MEM_free(buf);
+    return 0;
+  }
+
+  ssl->init_buf = buf;
+  ssl->init_num = 0;
+  return 1;
+}
+
+static void ssl3_finish_handshake(SSL *ssl) {
+  BUF_MEM_free(ssl->init_buf);
+  ssl->init_buf = NULL;
+  ssl->init_num = 0;
+}
+
 static const SSL_PROTOCOL_METHOD kTLSProtocolMethod = {
     0 /* is_dtls */,
     SSL3_VERSION,
@@ -73,6 +97,8 @@
     ssl3_version_to_wire,
     ssl3_new,
     ssl3_free,
+    ssl3_begin_handshake,
+    ssl3_finish_handshake,
     ssl3_get_message,
     ssl3_read_app_data,
     ssl3_read_change_cipher_spec,