Replace the incoming message buffer with a ring buffer.

The buffer only ever holds SSL_MAX_HANDSHAKE_FLIGHT (7) messages, so there's no
need for a priority queue structure, especially one that's O(N^2) anyway.

Change-Id: I7609794aac1925c9bbf3015744cae266dcb79bff
Reviewed-on: https://boringssl-review.googlesource.com/8437
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/d1_both.c b/ssl/d1_both.c
index 9adc2aa..f021d0a 100644
--- a/ssl/d1_both.c
+++ b/ssl/d1_both.c
@@ -139,6 +139,15 @@
* the underlying BIO supplies one. */
static const unsigned int kDefaultMTU = 1500 - 28;
+static void dtls1_hm_fragment_free(hm_fragment *frag) {
+ if (frag == NULL) {
+ return;
+ }
+ OPENSSL_free(frag->fragment);
+ OPENSSL_free(frag->reassembly);
+ OPENSSL_free(frag);
+}
+
static hm_fragment *dtls1_hm_fragment_new(size_t frag_len) {
hm_fragment *frag = OPENSSL_malloc(sizeof(hm_fragment));
if (frag == NULL) {
@@ -177,15 +186,6 @@
return NULL;
}
-void dtls1_hm_fragment_free(hm_fragment *frag) {
- if (frag == NULL) {
- return;
- }
- OPENSSL_free(frag->fragment);
- OPENSSL_free(frag->reassembly);
- OPENSSL_free(frag);
-}
-
#if !defined(inline)
#define inline __inline
#endif
@@ -407,16 +407,9 @@
/* dtls1_is_next_message_complete returns one if the next handshake message is
* complete and zero otherwise. */
static int dtls1_is_next_message_complete(SSL *ssl) {
- pitem *item = pqueue_peek(ssl->d1->buffered_messages);
- if (item == NULL) {
- return 0;
- }
-
- hm_fragment *frag = (hm_fragment *)item->data;
- assert(ssl->d1->handshake_read_seq <= frag->msg_header.seq);
-
- return ssl->d1->handshake_read_seq == frag->msg_header.seq &&
- frag->reassembly == NULL;
+ hm_fragment *frag = ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
+ SSL_MAX_HANDSHAKE_FLIGHT];
+ return frag != NULL && frag->reassembly == NULL;
}
/* dtls1_get_buffered_message returns the buffered message corresponding to
@@ -425,41 +418,33 @@
* returns NULL on failure. The caller does not take ownership of the result. */
static hm_fragment *dtls1_get_buffered_message(
SSL *ssl, const struct hm_header_st *msg_hdr) {
- uint8_t seq64be[8];
- memset(seq64be, 0, sizeof(seq64be));
- seq64be[6] = (uint8_t)(msg_hdr->seq >> 8);
- seq64be[7] = (uint8_t)msg_hdr->seq;
- pitem *item = pqueue_find(ssl->d1->buffered_messages, seq64be);
+ if (msg_hdr->seq < ssl->d1->handshake_read_seq ||
+ msg_hdr->seq - ssl->d1->handshake_read_seq >= SSL_MAX_HANDSHAKE_FLIGHT) {
+ return NULL;
+ }
- hm_fragment *frag;
- if (item == NULL) {
- /* This is the first fragment from this message. */
- frag = dtls1_hm_fragment_new(msg_hdr->msg_len);
- if (frag == NULL) {
- return NULL;
- }
- memcpy(&frag->msg_header, msg_hdr, sizeof(*msg_hdr));
- item = pitem_new(seq64be, frag);
- if (item == NULL) {
- dtls1_hm_fragment_free(frag);
- return NULL;
- }
- item = pqueue_insert(ssl->d1->buffered_messages, item);
- /* |pqueue_insert| fails iff a duplicate item is inserted, but |item| cannot
- * be a duplicate. */
- assert(item != NULL);
- } else {
- frag = item->data;
+ size_t idx = msg_hdr->seq % SSL_MAX_HANDSHAKE_FLIGHT;
+ hm_fragment *frag = ssl->d1->incoming_messages[idx];
+ if (frag != NULL) {
assert(frag->msg_header.seq == msg_hdr->seq);
+ /* The new fragment must be compatible with the previous fragments from this
+ * message. */
if (frag->msg_header.type != msg_hdr->type ||
frag->msg_header.msg_len != msg_hdr->msg_len) {
- /* The new fragment must be compatible with the previous fragments from
- * this message. */
OPENSSL_PUT_ERROR(SSL, SSL_R_FRAGMENT_MISMATCH);
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
return NULL;
}
+ return frag;
}
+
+ /* This is the first fragment from this message. */
+ frag = dtls1_hm_fragment_new(msg_hdr->msg_len);
+ if (frag == NULL) {
+ return NULL;
+ }
+ memcpy(&frag->msg_header, msg_hdr, sizeof(*msg_hdr));
+ ssl->d1->incoming_messages[idx] = frag;
return frag;
}
@@ -557,7 +542,6 @@
* arrive in fragments. */
long dtls1_get_message(SSL *ssl, int msg_type,
enum ssl_hash_message_t hash_message, int *ok) {
- pitem *item = NULL;
hm_fragment *frag = NULL;
int al;
@@ -590,12 +574,17 @@
}
}
- /* Read out the next complete handshake message. */
- item = pqueue_pop(ssl->d1->buffered_messages);
- assert(item != NULL);
- frag = (hm_fragment *)item->data;
- assert(ssl->d1->handshake_read_seq == frag->msg_header.seq);
+ /* Pop an entry from the ring buffer. */
+ frag = ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
+ SSL_MAX_HANDSHAKE_FLIGHT];
+ ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
+ SSL_MAX_HANDSHAKE_FLIGHT] = NULL;
+
+ assert(frag != NULL);
assert(frag->reassembly == NULL);
+ assert(ssl->d1->handshake_read_seq == frag->msg_header.seq);
+
+ ssl->d1->handshake_read_seq++;
/* Reconstruct the assembled message. */
CBB cbb;
@@ -618,8 +607,6 @@
assert(ssl->init_buf->length ==
(size_t)frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH);
- ssl->d1->handshake_read_seq++;
-
/* TODO(davidben): This function has a lot of implicit outputs. Simplify the
* |ssl_get_message| API. */
ssl->s3->tmp.message_type = frag->msg_header.type;
@@ -639,7 +626,6 @@
ssl->init_buf->data,
ssl->init_num + DTLS1_HM_HEADER_LENGTH);
- pitem_free(item);
dtls1_hm_fragment_free(frag);
*ok = 1;
@@ -648,7 +634,6 @@
f_err:
ssl3_send_alert(ssl, SSL3_AL_FATAL, al);
err:
- pitem_free(item);
dtls1_hm_fragment_free(frag);
*ok = 0;
return -1;
@@ -765,6 +750,14 @@
ssl->d1->outgoing_messages_len = 0;
}
+void dtls_clear_incoming_messages(SSL *ssl) {
+ size_t i;
+ for (i = 0; i < SSL_MAX_HANDSHAKE_FLIGHT; i++) {
+ dtls1_hm_fragment_free(ssl->d1->incoming_messages[i]);
+ ssl->d1->incoming_messages[i] = NULL;
+ }
+}
+
unsigned int dtls1_min_mtu(void) {
return kMinMTU;
}
diff --git a/ssl/d1_lib.c b/ssl/d1_lib.c
index 0c47dc6..d738c57 100644
--- a/ssl/d1_lib.c
+++ b/ssl/d1_lib.c
@@ -97,13 +97,6 @@
}
memset(d1, 0, sizeof *d1);
- d1->buffered_messages = pqueue_new();
- if (d1->buffered_messages == NULL) {
- OPENSSL_free(d1);
- ssl3_free(ssl);
- return 0;
- }
-
ssl->d1 = d1;
/* Set the version to the highest supported version.
@@ -115,17 +108,6 @@
return 1;
}
-static void dtls1_clear_queues(SSL *ssl) {
- pitem *item = NULL;
- hm_fragment *frag = NULL;
-
- while ((item = pqueue_pop(ssl->d1->buffered_messages)) != NULL) {
- frag = (hm_fragment *)item->data;
- dtls1_hm_fragment_free(frag);
- pitem_free(item);
- }
-}
-
void dtls1_free(SSL *ssl) {
ssl3_free(ssl);
@@ -133,9 +115,7 @@
return;
}
- dtls1_clear_queues(ssl);
- pqueue_free(ssl->d1->buffered_messages);
-
+ dtls_clear_incoming_messages(ssl);
dtls_clear_outgoing_messages(ssl);
OPENSSL_free(ssl->d1);
diff --git a/ssl/handshake_server.c b/ssl/handshake_server.c
index b0e4c56..7f7f3b8 100644
--- a/ssl/handshake_server.c
+++ b/ssl/handshake_server.c
@@ -516,6 +516,7 @@
if (SSL_IS_DTLS(ssl)) {
ssl->d1->handshake_read_seq = 0;
ssl->d1->handshake_write_seq = 0;
+ dtls_clear_incoming_messages(ssl);
}
ssl->s3->initial_handshake_complete = 1;
diff --git a/ssl/internal.h b/ssl/internal.h
index 64a63d8..369bdcc 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -145,7 +145,6 @@
#include <openssl/base.h>
#include <openssl/aead.h>
-#include <openssl/pqueue.h>
#include <openssl/ssl.h>
#include <openssl/stack.h>
@@ -645,6 +644,9 @@
* in a handshake message for |ssl|. */
size_t ssl_max_handshake_message_len(const SSL *ssl);
+/* dtls_clear_incoming_messages releases all buffered incoming messages. */
+void dtls_clear_incoming_messages(SSL *ssl);
+
typedef struct dtls_outgoing_message_st {
uint8_t *data;
uint32_t len;
@@ -922,12 +924,11 @@
/* save last sequence number for retransmissions */
uint8_t last_write_sequence[8];
- /* buffered_messages is a priority queue of incoming handshake messages that
- * have yet to be processed.
- *
- * TODO(davidben): This data structure may as well be a ring buffer of fixed
- * size. */
- pqueue buffered_messages;
+ /* incoming_messages is a ring buffer of incoming handshake messages that have
+ * yet to be processed. The front of the ring buffer is message number
+ * |handshake_read_seq|, at position |handshake_read_seq| %
+ * |SSL_MAX_HANDSHAKE_FLIGHT|. */
+ hm_fragment *incoming_messages[SSL_MAX_HANDSHAKE_FLIGHT];
/* outgoing_messages is the queue of outgoing messages from the last handshake
* flight. */
@@ -1095,7 +1096,6 @@
int dtls1_is_timer_expired(SSL *ssl);
void dtls1_double_timeout(SSL *ssl);
unsigned int dtls1_min_mtu(void);
-void dtls1_hm_fragment_free(hm_fragment *frag);
int dtls1_new(SSL *ssl);
int dtls1_accept(SSL *ssl);