Replace the incoming message buffer with a ring buffer.

It has size 7. There's no need for a priority queue structure, especially one
that's O(N^2) anyway.

Change-Id: I7609794aac1925c9bbf3015744cae266dcb79bff
Reviewed-on: https://boringssl-review.googlesource.com/8437
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/d1_both.c b/ssl/d1_both.c
index 9adc2aa..f021d0a 100644
--- a/ssl/d1_both.c
+++ b/ssl/d1_both.c
@@ -139,6 +139,15 @@
  * the underlying BIO supplies one. */
 static const unsigned int kDefaultMTU = 1500 - 28;
 
+static void dtls1_hm_fragment_free(hm_fragment *frag) {
+  if (frag == NULL) {
+    return;
+  }
+  OPENSSL_free(frag->fragment);
+  OPENSSL_free(frag->reassembly);
+  OPENSSL_free(frag);
+}
+
 static hm_fragment *dtls1_hm_fragment_new(size_t frag_len) {
   hm_fragment *frag = OPENSSL_malloc(sizeof(hm_fragment));
   if (frag == NULL) {
@@ -177,15 +186,6 @@
   return NULL;
 }
 
-void dtls1_hm_fragment_free(hm_fragment *frag) {
-  if (frag == NULL) {
-    return;
-  }
-  OPENSSL_free(frag->fragment);
-  OPENSSL_free(frag->reassembly);
-  OPENSSL_free(frag);
-}
-
 #if !defined(inline)
 #define inline __inline
 #endif
@@ -407,16 +407,9 @@
 /* dtls1_is_next_message_complete returns one if the next handshake message is
  * complete and zero otherwise. */
 static int dtls1_is_next_message_complete(SSL *ssl) {
-  pitem *item = pqueue_peek(ssl->d1->buffered_messages);
-  if (item == NULL) {
-    return 0;
-  }
-
-  hm_fragment *frag = (hm_fragment *)item->data;
-  assert(ssl->d1->handshake_read_seq <= frag->msg_header.seq);
-
-  return ssl->d1->handshake_read_seq == frag->msg_header.seq &&
-         frag->reassembly == NULL;
+  hm_fragment *frag = ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
+                                                 SSL_MAX_HANDSHAKE_FLIGHT];
+  return frag != NULL && frag->reassembly == NULL;
 }
 
 /* dtls1_get_buffered_message returns the buffered message corresponding to
@@ -425,41 +418,33 @@
  * returns NULL on failure. The caller does not take ownership of the result. */
 static hm_fragment *dtls1_get_buffered_message(
     SSL *ssl, const struct hm_header_st *msg_hdr) {
-  uint8_t seq64be[8];
-  memset(seq64be, 0, sizeof(seq64be));
-  seq64be[6] = (uint8_t)(msg_hdr->seq >> 8);
-  seq64be[7] = (uint8_t)msg_hdr->seq;
-  pitem *item = pqueue_find(ssl->d1->buffered_messages, seq64be);
+  if (msg_hdr->seq < ssl->d1->handshake_read_seq ||
+      msg_hdr->seq - ssl->d1->handshake_read_seq >= SSL_MAX_HANDSHAKE_FLIGHT) {
+    return NULL;
+  }
 
-  hm_fragment *frag;
-  if (item == NULL) {
-    /* This is the first fragment from this message. */
-    frag = dtls1_hm_fragment_new(msg_hdr->msg_len);
-    if (frag == NULL) {
-      return NULL;
-    }
-    memcpy(&frag->msg_header, msg_hdr, sizeof(*msg_hdr));
-    item = pitem_new(seq64be, frag);
-    if (item == NULL) {
-      dtls1_hm_fragment_free(frag);
-      return NULL;
-    }
-    item = pqueue_insert(ssl->d1->buffered_messages, item);
-    /* |pqueue_insert| fails iff a duplicate item is inserted, but |item| cannot
-     * be a duplicate. */
-    assert(item != NULL);
-  } else {
-    frag = item->data;
+  size_t idx = msg_hdr->seq % SSL_MAX_HANDSHAKE_FLIGHT;
+  hm_fragment *frag = ssl->d1->incoming_messages[idx];
+  if (frag != NULL) {
     assert(frag->msg_header.seq == msg_hdr->seq);
+    /* The new fragment must be compatible with the previous fragments from this
+     * message. */
     if (frag->msg_header.type != msg_hdr->type ||
         frag->msg_header.msg_len != msg_hdr->msg_len) {
-      /* The new fragment must be compatible with the previous fragments from
-       * this message. */
       OPENSSL_PUT_ERROR(SSL, SSL_R_FRAGMENT_MISMATCH);
       ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
       return NULL;
     }
+    return frag;
   }
+
+  /* This is the first fragment from this message. */
+  frag = dtls1_hm_fragment_new(msg_hdr->msg_len);
+  if (frag == NULL) {
+    return NULL;
+  }
+  memcpy(&frag->msg_header, msg_hdr, sizeof(*msg_hdr));
+  ssl->d1->incoming_messages[idx] = frag;
   return frag;
 }
 
@@ -557,7 +542,6 @@
  * arrive in fragments. */
 long dtls1_get_message(SSL *ssl, int msg_type,
                        enum ssl_hash_message_t hash_message, int *ok) {
-  pitem *item = NULL;
   hm_fragment *frag = NULL;
   int al;
 
@@ -590,12 +574,17 @@
     }
   }
 
-  /* Read out the next complete handshake message. */
-  item = pqueue_pop(ssl->d1->buffered_messages);
-  assert(item != NULL);
-  frag = (hm_fragment *)item->data;
-  assert(ssl->d1->handshake_read_seq == frag->msg_header.seq);
+  /* Pop an entry from the ring buffer. */
+  frag = ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
+                                    SSL_MAX_HANDSHAKE_FLIGHT];
+  ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
+                             SSL_MAX_HANDSHAKE_FLIGHT] = NULL;
+
+  assert(frag != NULL);
   assert(frag->reassembly == NULL);
+  assert(ssl->d1->handshake_read_seq == frag->msg_header.seq);
+
+  ssl->d1->handshake_read_seq++;
 
   /* Reconstruct the assembled message. */
   CBB cbb;
@@ -618,8 +607,6 @@
   assert(ssl->init_buf->length ==
          (size_t)frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH);
 
-  ssl->d1->handshake_read_seq++;
-
   /* TODO(davidben): This function has a lot of implicit outputs. Simplify the
    * |ssl_get_message| API. */
   ssl->s3->tmp.message_type = frag->msg_header.type;
@@ -639,7 +626,6 @@
                       ssl->init_buf->data,
                       ssl->init_num + DTLS1_HM_HEADER_LENGTH);
 
-  pitem_free(item);
   dtls1_hm_fragment_free(frag);
 
   *ok = 1;
@@ -648,7 +634,6 @@
 f_err:
   ssl3_send_alert(ssl, SSL3_AL_FATAL, al);
 err:
-  pitem_free(item);
   dtls1_hm_fragment_free(frag);
   *ok = 0;
   return -1;
@@ -765,6 +750,14 @@
   ssl->d1->outgoing_messages_len = 0;
 }
 
+void dtls_clear_incoming_messages(SSL *ssl) {
+  size_t i;
+  for (i = 0; i < SSL_MAX_HANDSHAKE_FLIGHT; i++) {
+    dtls1_hm_fragment_free(ssl->d1->incoming_messages[i]);
+    ssl->d1->incoming_messages[i] = NULL;
+  }
+}
+
 unsigned int dtls1_min_mtu(void) {
   return kMinMTU;
 }
diff --git a/ssl/d1_lib.c b/ssl/d1_lib.c
index 0c47dc6..d738c57 100644
--- a/ssl/d1_lib.c
+++ b/ssl/d1_lib.c
@@ -97,13 +97,6 @@
   }
   memset(d1, 0, sizeof *d1);
 
-  d1->buffered_messages = pqueue_new();
-  if (d1->buffered_messages == NULL) {
-    OPENSSL_free(d1);
-    ssl3_free(ssl);
-    return 0;
-  }
-
   ssl->d1 = d1;
 
   /* Set the version to the highest supported version.
@@ -115,17 +108,6 @@
   return 1;
 }
 
-static void dtls1_clear_queues(SSL *ssl) {
-  pitem *item = NULL;
-  hm_fragment *frag = NULL;
-
-  while ((item = pqueue_pop(ssl->d1->buffered_messages)) != NULL) {
-    frag = (hm_fragment *)item->data;
-    dtls1_hm_fragment_free(frag);
-    pitem_free(item);
-  }
-}
-
 void dtls1_free(SSL *ssl) {
   ssl3_free(ssl);
 
@@ -133,9 +115,7 @@
     return;
   }
 
-  dtls1_clear_queues(ssl);
-  pqueue_free(ssl->d1->buffered_messages);
-
+  dtls_clear_incoming_messages(ssl);
   dtls_clear_outgoing_messages(ssl);
 
   OPENSSL_free(ssl->d1);
diff --git a/ssl/handshake_server.c b/ssl/handshake_server.c
index b0e4c56..7f7f3b8 100644
--- a/ssl/handshake_server.c
+++ b/ssl/handshake_server.c
@@ -516,6 +516,7 @@
         if (SSL_IS_DTLS(ssl)) {
           ssl->d1->handshake_read_seq = 0;
           ssl->d1->handshake_write_seq = 0;
+          dtls_clear_incoming_messages(ssl);
         }
 
         ssl->s3->initial_handshake_complete = 1;
diff --git a/ssl/internal.h b/ssl/internal.h
index 64a63d8..369bdcc 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -145,7 +145,6 @@
 #include <openssl/base.h>
 
 #include <openssl/aead.h>
-#include <openssl/pqueue.h>
 #include <openssl/ssl.h>
 #include <openssl/stack.h>
 
@@ -645,6 +644,9 @@
  * in a handshake message for |ssl|. */
 size_t ssl_max_handshake_message_len(const SSL *ssl);
 
+/* dtls_clear_incoming_messages releases all buffered incoming messages. */
+void dtls_clear_incoming_messages(SSL *ssl);
+
 typedef struct dtls_outgoing_message_st {
   uint8_t *data;
   uint32_t len;
@@ -922,12 +924,11 @@
   /* save last sequence number for retransmissions */
   uint8_t last_write_sequence[8];
 
-  /* buffered_messages is a priority queue of incoming handshake messages that
-   * have yet to be processed.
-   *
-   * TODO(davidben): This data structure may as well be a ring buffer of fixed
-   * size. */
-  pqueue buffered_messages;
+  /* incoming_messages is a ring buffer of incoming handshake messages that have
+   * yet to be processed. The front of the ring buffer is message number
+   * |handshake_read_seq|, at position |handshake_read_seq| %
+   * |SSL_MAX_HANDSHAKE_FLIGHT|. */
+  hm_fragment *incoming_messages[SSL_MAX_HANDSHAKE_FLIGHT];
 
   /* outgoing_messages is the queue of outgoing messages from the last handshake
    * flight. */
@@ -1095,7 +1096,6 @@
 int dtls1_is_timer_expired(SSL *ssl);
 void dtls1_double_timeout(SSL *ssl);
 unsigned int dtls1_min_mtu(void);
-void dtls1_hm_fragment_free(hm_fragment *frag);
 
 int dtls1_new(SSL *ssl);
 int dtls1_accept(SSL *ssl);