Move references to init_buf into SSL_PROTOCOL_METHOD.
Both DTLS and TLS still use it, but that will change in the following
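commit. This also removes the handshake's knowledge of the
dtls_clear_incoming_messages function. As a result, the DTLS client now also
clears incoming messages when a handshake finishes; previously only the
server did.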
(It's possible we'll want to get rid of begin_handshake in favor of
allocating init_buf lazily, depending on how TLS 1.3 post-handshake messages
end up working out. But this should work for now.)
Change-Id: I0f512788bbc330ab2c947890939c73e0a1aca18b
Reviewed-on: https://boringssl-review.googlesource.com/8666
Reviewed-by: Steven Valdez <svaldez@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/dtls_method.c b/ssl/dtls_method.c
index 00454dd..f6376bb 100644
--- a/ssl/dtls_method.c
+++ b/ssl/dtls_method.c
@@ -58,6 +58,8 @@
#include <assert.h>
+#include <openssl/buf.h>
+
#include "internal.h"
@@ -88,6 +90,32 @@
return ~(version - 0x0201);
}
+/* dtls1_begin_handshake prepares |ssl| for a new DTLS handshake. It makes sure
+ * |ssl->init_buf| is allocated, reusing an existing buffer if there is one,
+ * and resets |ssl->init_num|. It returns one on success and zero on error. */
+static int dtls1_begin_handshake(SSL *ssl) {
+  if (ssl->init_buf == NULL) {
+    BUF_MEM *buf = BUF_MEM_new();
+    if (buf == NULL || !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
+      BUF_MEM_free(buf);
+      return 0;
+    }
+    ssl->init_buf = buf;
+  }
+
+  /* Reset even when an existing buffer is reused, so no stale length from an
+   * earlier handshake carries over into this one. */
+  ssl->init_num = 0;
+  return 1;
+}
+
+/* dtls1_finish_handshake tears down per-handshake DTLS state once a handshake
+ * completes. It:
+ *   - releases the handshake message buffer (|ssl->init_buf|);
+ *   - zeroes the handshake read and write sequence numbers;
+ *   - calls dtls_clear_incoming_messages (presumably discarding any queued
+ *     incoming handshake messages — confirm against its definition). */
+static void dtls1_finish_handshake(SSL *ssl) {
+  BUF_MEM *handshake_buf = ssl->init_buf;
+  ssl->init_buf = NULL;
+  ssl->init_num = 0;
+  BUF_MEM_free(handshake_buf);
+
+  ssl->d1->handshake_write_seq = 0;
+  ssl->d1->handshake_read_seq = 0;
+  dtls_clear_incoming_messages(ssl);
+}
+
static const SSL_PROTOCOL_METHOD kDTLSProtocolMethod = {
1 /* is_dtls */,
TLS1_1_VERSION,
@@ -96,6 +124,8 @@
dtls1_version_to_wire,
dtls1_new,
dtls1_free,
+ dtls1_begin_handshake,
+ dtls1_finish_handshake,
dtls1_get_message,
dtls1_read_app_data,
dtls1_read_change_cipher_spec,
diff --git a/ssl/handshake_client.c b/ssl/handshake_client.c
index 2ac7dee..3c90310 100644
--- a/ssl/handshake_client.c
+++ b/ssl/handshake_client.c
@@ -187,7 +187,6 @@
static int ssl3_get_new_session_ticket(SSL *ssl);
int ssl3_connect(SSL *ssl) {
- BUF_MEM *buf = NULL;
int ret = -1;
int state, skip = 0;
@@ -201,18 +200,10 @@
case SSL_ST_CONNECT:
ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_START, 1);
- if (ssl->init_buf == NULL) {
- buf = BUF_MEM_new();
- if (buf == NULL ||
- !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
- ret = -1;
- goto end;
- }
-
- ssl->init_buf = buf;
- buf = NULL;
+ if (!ssl->method->begin_handshake(ssl)) {
+ ret = -1;
+ goto end;
}
- ssl->init_num = 0;
if (!ssl_init_wbio_buffer(ssl)) {
ret = -1;
@@ -503,9 +494,7 @@
/* clean a few things up */
ssl3_cleanup_key_block(ssl);
- BUF_MEM_free(ssl->init_buf);
- ssl->init_buf = NULL;
- ssl->init_num = 0;
+ ssl->method->finish_handshake(ssl);
/* Remove write buffering now. */
ssl_free_wbio_buffer(ssl);
@@ -520,11 +509,6 @@
ssl_update_cache(ssl, SSL_SESS_CACHE_CLIENT);
}
- if (SSL_IS_DTLS(ssl)) {
- ssl->d1->handshake_read_seq = 0;
- ssl->d1->handshake_write_seq = 0;
- }
-
ret = 1;
ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_DONE, 1);
goto end;
@@ -545,7 +529,6 @@
}
end:
- BUF_MEM_free(buf);
ssl_do_info_callback(ssl, SSL_CB_CONNECT_EXIT, ret);
return ret;
}
diff --git a/ssl/handshake_server.c b/ssl/handshake_server.c
index e067904..375f0e3 100644
--- a/ssl/handshake_server.c
+++ b/ssl/handshake_server.c
@@ -188,7 +188,6 @@
static int ssl3_send_new_session_ticket(SSL *ssl);
int ssl3_accept(SSL *ssl) {
- BUF_MEM *buf = NULL;
uint32_t alg_a;
int ret = -1;
int state, skip = 0;
@@ -203,16 +202,10 @@
case SSL_ST_ACCEPT:
ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_START, 1);
- if (ssl->init_buf == NULL) {
- buf = BUF_MEM_new();
- if (!buf || !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
- ret = -1;
- goto end;
- }
- ssl->init_buf = buf;
- buf = NULL;
+ if (!ssl->method->begin_handshake(ssl)) {
+ ret = -1;
+ goto end;
}
- ssl->init_num = 0;
/* Enable a write buffer. This groups handshake messages within a flight
* into a single write. */
@@ -470,9 +463,7 @@
/* clean a few things up */
ssl3_cleanup_key_block(ssl);
- BUF_MEM_free(ssl->init_buf);
- ssl->init_buf = NULL;
- ssl->init_num = 0;
+ ssl->method->finish_handshake(ssl);
/* remove buffering on output */
ssl_free_wbio_buffer(ssl);
@@ -486,12 +477,6 @@
ssl->session->cert_chain = NULL;
}
- if (SSL_IS_DTLS(ssl)) {
- ssl->d1->handshake_read_seq = 0;
- ssl->d1->handshake_write_seq = 0;
- dtls_clear_incoming_messages(ssl);
- }
-
ssl->s3->initial_handshake_complete = 1;
ssl_update_cache(ssl, SSL_SESS_CACHE_SERVER);
@@ -517,7 +502,6 @@
}
end:
- BUF_MEM_free(buf);
ssl_do_info_callback(ssl, SSL_CB_ACCEPT_EXIT, ret);
return ret;
}
diff --git a/ssl/internal.h b/ssl/internal.h
index ab79dcc..2e4cb46 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -829,6 +829,11 @@
uint16_t (*version_to_wire)(uint16_t version);
int (*ssl_new)(SSL *ssl);
void (*ssl_free)(SSL *ssl);
+ /* begin_handshake is called to start a new handshake. It returns one on
+ * success and zero on error. */
+ int (*begin_handshake)(SSL *ssl);
+ /* finish_handshake is called when a handshake completes. */
+ void (*finish_handshake)(SSL *ssl);
long (*ssl_get_message)(SSL *ssl, int msg_type,
enum ssl_hash_message_t hash_message, int *ok);
int (*read_app_data)(SSL *ssl, uint8_t *buf, int len, int peek);
diff --git a/ssl/tls_method.c b/ssl/tls_method.c
index e8cf1d6..dab5c47 100644
--- a/ssl/tls_method.c
+++ b/ssl/tls_method.c
@@ -56,6 +56,8 @@
#include <openssl/ssl.h>
+#include <openssl/buf.h>
+
#include "internal.h"
@@ -65,6 +67,28 @@
static uint16_t ssl3_version_to_wire(uint16_t version) { return version; }
+/* ssl3_begin_handshake prepares |ssl| for a new TLS handshake. It makes sure
+ * |ssl->init_buf| is allocated, reusing an existing buffer if there is one,
+ * and resets |ssl->init_num|. It returns one on success and zero on error. */
+static int ssl3_begin_handshake(SSL *ssl) {
+  if (ssl->init_buf == NULL) {
+    BUF_MEM *buf = BUF_MEM_new();
+    if (buf == NULL || !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
+      BUF_MEM_free(buf);
+      return 0;
+    }
+    ssl->init_buf = buf;
+  }
+
+  /* Reset even when an existing buffer is reused, so no stale length from an
+   * earlier handshake carries over into this one. */
+  ssl->init_num = 0;
+  return 1;
+}
+
+/* ssl3_finish_handshake releases the handshake message buffer once a TLS
+ * handshake completes and clears the |ssl| fields that refer to it. */
+static void ssl3_finish_handshake(SSL *ssl) {
+  BUF_MEM *handshake_buf = ssl->init_buf;
+  ssl->init_buf = NULL;
+  ssl->init_num = 0;
+  BUF_MEM_free(handshake_buf);
+}
+
static const SSL_PROTOCOL_METHOD kTLSProtocolMethod = {
0 /* is_dtls */,
SSL3_VERSION,
@@ -73,6 +97,8 @@
ssl3_version_to_wire,
ssl3_new,
ssl3_free,
+ ssl3_begin_handshake,
+ ssl3_finish_handshake,
ssl3_get_message,
ssl3_read_app_data,
ssl3_read_change_cipher_spec,