Replace init_msg/init_num with a get_message hook. Rather than init_msg/init_num, there is a get_message function which either returns success or asks the caller to try again. This function does not advance the current message (see the previous preparatory change). It only completes the current one if necessary. Being idempotent means it may be freely placed at the top of states which otherwise have other asynchronous operations. It also eases converting the TLS 1.2 state machine. See https://docs.google.com/a/google.com/document/d/11n7LHsT3GwE34LAJIe3EFs4165TI4UR_3CqiM9LJVpI/edit?usp=sharing for details. The read_message hook (later to be replaced by something which doesn't depend on BIO) intentionally does not finish the handshake, only "makes progress". A follow-up change will align both TLS and DTLS on consuming one handshake record and always consuming the entire record (so init_buf may contain trailing data). In a few places I've gone ahead and accounted for that case because it was more natural to do so. This change also removes a couple of pointers of redundant state from every socket. Bug: 128 Change-Id: I89d8f3622d3b53147d69ee3ac34bb654ed044a71 Reviewed-on: https://boringssl-review.googlesource.com/18806 Reviewed-by: David Benjamin <davidben@google.com> Commit-Queue: David Benjamin <davidben@google.com> CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc index 70af6c1..2538d28 100644 --- a/ssl/d1_both.cc +++ b/ssl/d1_both.cc
@@ -298,9 +298,7 @@ return frag; } -/* dtls1_process_handshake_record reads a record for the handshake and processes - * it. It returns one on success and 0 or -1 on error. */ -static int dtls1_process_handshake_record(SSL *ssl) { +int dtls1_read_message(SSL *ssl) { SSL3_RECORD *rr = &ssl->s3->rrec; if (rr->length == 0) { int ret = dtls1_get_record(ssl); @@ -419,52 +417,33 @@ return 1; } -int dtls1_get_message(SSL *ssl) { - /* Process handshake records until the current message is ready. */ - while (!dtls1_is_current_message_complete(ssl)) { - int ret = dtls1_process_handshake_record(ssl); - if (ret <= 0) { - return ret; - } +bool dtls1_get_message(SSL *ssl, SSLMessage *out) { + if (!dtls1_is_current_message_complete(ssl)) { + return false; } hm_fragment *frag = ssl->d1->incoming_messages[ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT]; - assert(frag != NULL); - assert(frag->reassembly == NULL); - assert(ssl->d1->handshake_read_seq == frag->seq); - - if (ssl->init_msg == NULL) { + out->type = frag->type; + CBS_init(&out->body, frag->data + DTLS1_HM_HEADER_LENGTH, frag->msg_len); + CBS_init(&out->raw, frag->data, DTLS1_HM_HEADER_LENGTH + frag->msg_len); + out->is_v2_hello = false; + if (!ssl->s3->has_message) { ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE, frag->data, frag->msg_len + DTLS1_HM_HEADER_LENGTH); + ssl->s3->has_message = 1; } - - /* TODO(davidben): This function has a lot of implicit outputs. Simplify the - * |ssl_get_message| API. 
*/ - ssl->s3->tmp.message_type = frag->type; - ssl->init_msg = frag->data + DTLS1_HM_HEADER_LENGTH; - ssl->init_num = frag->msg_len; - return 1; -} - -void dtls1_get_current_message(const SSL *ssl, CBS *out) { - assert(dtls1_is_current_message_complete(ssl)); - - hm_fragment *frag = ssl->d1->incoming_messages[ssl->d1->handshake_read_seq % - SSL_MAX_HANDSHAKE_FLIGHT]; - CBS_init(out, frag->data, DTLS1_HM_HEADER_LENGTH + frag->msg_len); + return true; } void dtls1_next_message(SSL *ssl) { - assert(ssl->init_msg != NULL); + assert(ssl->s3->has_message); assert(dtls1_is_current_message_complete(ssl)); size_t index = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT; dtls1_hm_fragment_free(ssl->d1->incoming_messages[index]); ssl->d1->incoming_messages[index] = NULL; ssl->d1->handshake_read_seq++; - - ssl->init_msg = NULL; - ssl->init_num = 0; + ssl->s3->has_message = 0; } void dtls_clear_incoming_messages(SSL *ssl) { @@ -478,7 +457,7 @@ size_t current = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT; for (size_t i = 0; i < SSL_MAX_HANDSHAKE_FLIGHT; i++) { /* Skip the current message. */ - if (ssl->init_msg != NULL && i == current) { + if (ssl->s3->has_message && i == current) { assert(dtls1_is_current_message_complete(ssl)); continue; } @@ -508,7 +487,7 @@ int dtls1_read_change_cipher_spec(SSL *ssl) { /* Process handshake records until there is a ChangeCipherSpec. */ while (!ssl->d1->has_change_cipher_spec) { - int ret = dtls1_process_handshake_record(ssl); + int ret = dtls1_read_message(ssl); if (ret <= 0) { return ret; }