Factor out the buffering and low-level record code.
This begins decoupling the transport from the SSL state machine. The buffering
logic is now hidden behind an opaque API, and fields like ssl->packet and
ssl->packet_length are gone.
ssl3_get_record and dtls1_get_record now call the low-level tls_open_record and
dtls_open_record functions, which unpack a single record independently of who
owns the buffer. Both may be called in-place. This removes ssl->rstate, which
was redundant with the buffer length.
Future work will push the buffer up the stack until it is above the handshake.
Then we can expose SSL_open and SSL_seal APIs, which act like *_open_record but
return a slightly larger enum because other events are possible. Likewise, the
handshake state machine will be detached from its buffer. The existing
SSL_read, SSL_write, etc., APIs will be implemented on top of SSL_open, etc.,
combined with ssl_read_buffer_* and ssl_write_buffer_*. (This is why
ssl_read_buffer_extend still tries to abstract over TLS's and DTLS's fairly
different needs.)
The new buffering logic does not support read-ahead (removed previously), since
it lacks a memmove in ssl_read_buffer_discard for TLS, but this could be added
if desired. The old buffering logic wasn't quite right anyway: it tried to
avoid the memmove in some cases and could get stuck too far into the buffer,
unable to accept further records. (The memmove is only optional in DTLS, or if
enough of the record header is available to know that the entire next record
would fit in the buffer.)
The new logic also once again decrypts the ciphertext in-place, rather than
almost in-place when there's an explicit nonce/IV. (This was accidentally
changed in https://boringssl-review.googlesource.com/#/c/4792/; see
3d59e04bce96474099ba76786a2337e99ae14505.)
BUG=468889
Change-Id: I403c1626253c46897f47c7ae93aeab1064b767b2
Reviewed-on: https://boringssl-review.googlesource.com/5715
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/d1_pkt.c b/ssl/d1_pkt.c
index 308d695..6681490 100644
--- a/ssl/d1_pkt.c
+++ b/ssl/d1_pkt.c
@@ -122,221 +122,75 @@
#include "internal.h"
-/* to_u64_be treats |in| as a 8-byte big-endian integer and returns the value as
- * a |uint64_t|. */
-static uint64_t to_u64_be(const uint8_t in[8]) {
- uint64_t ret = 0;
- unsigned i;
- for (i = 0; i < 8; i++) {
- ret <<= 8;
- ret |= in[i];
- }
- return ret;
-}
-
-/* dtls1_bitmap_should_discard returns one if |seq_num| has been seen in |bitmap|
- * or is stale. Otherwise it returns zero. */
-static int dtls1_bitmap_should_discard(DTLS1_BITMAP *bitmap,
- const uint8_t seq_num[8]) {
- const unsigned kWindowSize = sizeof(bitmap->map) * 8;
-
- uint64_t seq_num_u = to_u64_be(seq_num);
- if (seq_num_u > bitmap->max_seq_num) {
- return 0;
- }
- uint64_t idx = bitmap->max_seq_num - seq_num_u;
- return idx >= kWindowSize || (bitmap->map & (((uint64_t)1) << idx));
-}
-
-/* dtls1_bitmap_record updates |bitmap| to record receipt of sequence number
- * |seq_num|. It slides the window forward if needed. It is an error to call
- * this function on a stale sequence number. */
-static void dtls1_bitmap_record(DTLS1_BITMAP *bitmap,
- const uint8_t seq_num[8]) {
- const unsigned kWindowSize = sizeof(bitmap->map) * 8;
-
- uint64_t seq_num_u = to_u64_be(seq_num);
- /* Shift the window if necessary. */
- if (seq_num_u > bitmap->max_seq_num) {
- uint64_t shift = seq_num_u - bitmap->max_seq_num;
- if (shift >= kWindowSize) {
- bitmap->map = 0;
- } else {
- bitmap->map <<= shift;
- }
- bitmap->max_seq_num = seq_num_u;
- }
-
- uint64_t idx = bitmap->max_seq_num - seq_num_u;
- if (idx < kWindowSize) {
- bitmap->map |= ((uint64_t)1) << idx;
- }
-}
-
static int do_dtls1_write(SSL *s, int type, const uint8_t *buf,
unsigned int len, enum dtls1_use_epoch_t use_epoch);
-/* Call this to get a new input record.
- * It will return <= 0 if more data is needed, normally due to an error
- * or non-blocking IO.
- * When it finishes, one packet has been decoded and can be found in
- * ssl->s3->rrec.type - is the type of record
- * ssl->s3->rrec.data, - data
- * ssl->s3->rrec.length, - number of bytes
- *
- * used only by dtls1_read_bytes */
-int dtls1_get_record(SSL *s) {
- uint8_t ssl_major, ssl_minor;
- int n;
- SSL3_RECORD *rr;
- uint8_t *p = NULL;
- uint16_t version;
-
- rr = &(s->s3->rrec);
-
- /* get something from the wire */
+/* dtls1_get_record reads a new input record. On success, it places it in
+ * |ssl->s3->rrec| and returns one. Otherwise it returns <= 0 on error or if
+ * more data is needed. */
+static int dtls1_get_record(SSL *ssl) {
again:
- /* check if we have the header */
- if ((s->rstate != SSL_ST_READ_BODY) ||
- (s->packet_length < DTLS1_RT_HEADER_LENGTH)) {
- n = ssl3_read_n(s, DTLS1_RT_HEADER_LENGTH, 0);
- /* read timeout is handled by dtls1_read_bytes */
- if (n <= 0) {
- return n; /* error or non-blocking */
+ /* Read a new packet if there is no unconsumed one. */
+ if (ssl_read_buffer_len(ssl) == 0) {
+ int ret = ssl_read_buffer_extend_to(ssl, 0 /* unused */);
+ if (ret <= 0) {
+ return ret;
}
-
- /* this packet contained a partial record, dump it */
- if (s->packet_length != DTLS1_RT_HEADER_LENGTH) {
- s->packet_length = 0;
- goto again;
- }
-
- s->rstate = SSL_ST_READ_BODY;
-
- p = s->packet;
-
- if (s->msg_callback) {
- s->msg_callback(0, 0, SSL3_RT_HEADER, p, DTLS1_RT_HEADER_LENGTH, s,
- s->msg_callback_arg);
- }
-
- /* Pull apart the header into the DTLS1_RECORD */
- rr->type = *(p++);
- ssl_major = *(p++);
- ssl_minor = *(p++);
- version = (((uint16_t)ssl_major) << 8) | ssl_minor;
-
- /* sequence number is 64 bits, with top 2 bytes = epoch */
- n2s(p, rr->epoch);
-
- memcpy(&(s->s3->read_sequence[2]), p, 6);
- p += 6;
-
- n2s(p, rr->length);
-
- /* Check the header. */
- if ((s->s3->have_version && version != s->version) ||
- (version & 0xff00) != (s->version & 0xff00) ||
- rr->length > SSL3_RT_MAX_ENCRYPTED_LENGTH) {
- /* The record's header is invalid, so silently drop it.
- *
- * TODO(davidben): This doesn't work. The DTLS record layer is not
- * packet-based, so the remainder of the packet isn't dropped and we
- * get a framing error. It's also unclear what it means to silently
- * drop a record in a packet containing two records. */
- rr->length = 0;
- s->packet_length = 0;
- goto again;
- }
-
- /* now s->rstate == SSL_ST_READ_BODY */
}
+ assert(ssl_read_buffer_len(ssl) > 0);
- /* s->rstate == SSL_ST_READ_BODY, get and decode the data */
-
- if (rr->length > s->packet_length - DTLS1_RT_HEADER_LENGTH) {
- /* now s->packet_length == DTLS1_RT_HEADER_LENGTH */
- n = ssl3_read_n(s, rr->length, 1);
- /* This packet contained a partial record, dump it. */
- if (n != rr->length) {
- rr->length = 0;
- s->packet_length = 0;
- goto again;
- }
-
- /* now n == rr->length,
- * and s->packet_length == DTLS1_RT_HEADER_LENGTH + rr->length */
- }
- s->rstate = SSL_ST_READ_HEADER; /* set state for later operations */
-
- if (rr->epoch != s->d1->r_epoch) {
- /* This record is from the wrong epoch. If it is the next epoch, it could be
- * buffered. For simplicity, drop it and expect retransmit to handle it
- * later; DTLS is supposed to handle packet loss. */
- rr->length = 0;
- s->packet_length = 0;
+ /* Ensure the packet is large enough to decrypt in-place. */
+ if (ssl_read_buffer_len(ssl) < ssl_record_prefix_len(ssl)) {
+ ssl_read_buffer_clear(ssl);
goto again;
}
- /* Check whether this is a repeat, or aged record. */
- if (dtls1_bitmap_should_discard(&s->d1->bitmap, s->s3->read_sequence)) {
- rr->length = 0;
- s->packet_length = 0; /* dump this record */
- goto again; /* get another record */
+ uint8_t *out = ssl_read_buffer(ssl) + ssl_record_prefix_len(ssl);
+ size_t max_out = ssl_read_buffer_len(ssl) - ssl_record_prefix_len(ssl);
+ uint8_t type, alert;
+ size_t len, consumed;
+ switch (dtls_open_record(ssl, &type, out, &len, &consumed, &alert, max_out,
+ ssl_read_buffer(ssl), ssl_read_buffer_len(ssl))) {
+ case ssl_open_record_success:
+ ssl_read_buffer_consume(ssl, consumed);
+
+ /* Discard empty records.
+ * TODO(davidben): This logic should be moved to a higher level. See
+ * https://crbug.com/521840.
+ * TODO(davidben): Limit the number of empty records as in TLS? This is
+ * useful if we also limit discarded packets. */
+ if (len == 0) {
+ goto again;
+ }
+
+ if (len > 0xffff) {
+ OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
+ return -1;
+ }
+
+ SSL3_RECORD *rr = &ssl->s3->rrec;
+ rr->type = type;
+ rr->length = (uint16_t)len;
+ rr->off = 0;
+ rr->data = out;
+ return 1;
+
+ case ssl_open_record_discard:
+ ssl_read_buffer_consume(ssl, consumed);
+ goto again;
+
+ case ssl_open_record_error:
+ ssl3_send_alert(ssl, SSL3_AL_FATAL, alert);
+ return -1;
+
+ case ssl_open_record_partial:
+ /* Impossible in DTLS. */
+ break;
}
- /* |rr->data| points to |rr->length| bytes of ciphertext in |s->packet|. */
- rr->data = &s->packet[DTLS1_RT_HEADER_LENGTH];
-
- uint8_t seq[8];
- seq[0] = rr->epoch >> 8;
- seq[1] = rr->epoch & 0xff;
- memcpy(&seq[2], &s->s3->read_sequence[2], 6);
-
- /* Decrypt the packet in-place. Note it is important that |ssl_aead_ctx_open|
- * not write beyond |rr->length|. There may be another record in the packet.
- *
- * TODO(davidben): This assumes |s->version| is the same as the record-layer
- * version which isn't always true, but it only differs with the NULL cipher
- * which ignores the parameter. */
- size_t plaintext_len;
- if (!SSL_AEAD_CTX_open(s->aead_read_ctx, rr->data, &plaintext_len, rr->length,
- rr->type, s->version, seq, rr->data, rr->length)) {
- /* Bad packets are silently dropped in DTLS. Clear the error queue of any
- * errors decryption may have added. */
- ERR_clear_error();
- rr->length = 0;
- s->packet_length = 0;
- goto again;
- }
-
- if (plaintext_len > SSL3_RT_MAX_PLAIN_LENGTH) {
- OPENSSL_PUT_ERROR(SSL, SSL_R_DATA_LENGTH_TOO_LONG);
- ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_RECORD_OVERFLOW);
- return -1;
- }
- assert(plaintext_len < (1u << 16));
- rr->length = plaintext_len;
-
- rr->off = 0;
- /* So at this point the following is true
- * ssl->s3->rrec.type is the type of record
- * ssl->s3->rrec.length == number of bytes in record
- * ssl->s3->rrec.off == offset to first valid byte
- * ssl->s3->rrec.data == the first byte of the record body. */
-
- /* we have pulled in a full packet so zero things */
- s->packet_length = 0;
-
- dtls1_bitmap_record(&s->d1->bitmap, s->s3->read_sequence);
-
- /* just read a 0 length packet
- * TODO(davidben): Reject excess 0-length packets? */
- if (rr->length == 0) {
- goto again;
- }
-
- return 1;
+ assert(0);
+ OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+ return -1;
}
int dtls1_read_app_data(SSL *ssl, uint8_t *buf, int len, int peek) {
@@ -415,7 +269,7 @@
}
/* get new packet if necessary */
- if (rr->length == 0 || s->rstate == SSL_ST_READ_BODY) {
+ if (rr->length == 0) {
ret = dtls1_get_record(s);
if (ret <= 0) {
ret = dtls1_read_failed(s, ret);
@@ -477,8 +331,9 @@
rr->length -= n;
rr->off += n;
if (rr->length == 0) {
- s->rstate = SSL_ST_READ_HEADER;
rr->off = 0;
+ /* The record has been consumed, so we may now clear the buffer. */
+ ssl_read_buffer_discard(s);
}
}
@@ -663,73 +518,11 @@
return i;
}
-/* dtls1_seal_record seals a new record of type |type| and plaintext |in| and
- * writes it to |out|. At most |max_out| bytes will be written. It returns one
- * on success and zero on error. On success, it updates the write sequence
- * number. */
-static int dtls1_seal_record(SSL *s, uint8_t *out, size_t *out_len,
- size_t max_out, uint8_t type, const uint8_t *in,
- size_t in_len, enum dtls1_use_epoch_t use_epoch) {
- if (max_out < DTLS1_RT_HEADER_LENGTH) {
- OPENSSL_PUT_ERROR(SSL, SSL_R_BUFFER_TOO_SMALL);
- return 0;
- }
-
- /* Determine the parameters for the current epoch. */
- uint16_t epoch = s->d1->w_epoch;
- SSL_AEAD_CTX *aead = s->aead_write_ctx;
- uint8_t *seq = s->s3->write_sequence;
- if (use_epoch == dtls1_use_previous_epoch) {
- /* DTLS renegotiation is unsupported, so only epochs 0 (NULL cipher) and 1
- * (negotiated cipher) exist. */
- assert(s->d1->w_epoch == 1);
- epoch = s->d1->w_epoch - 1;
- aead = NULL;
- seq = s->d1->last_write_sequence;
- }
-
- out[0] = type;
-
- uint16_t wire_version = s->s3->have_version ? s->version : DTLS1_VERSION;
- out[1] = wire_version >> 8;
- out[2] = wire_version & 0xff;
-
- out[3] = epoch >> 8;
- out[4] = epoch & 0xff;
- memcpy(&out[5], &seq[2], 6);
-
- size_t ciphertext_len;
- if (!SSL_AEAD_CTX_seal(aead, out + DTLS1_RT_HEADER_LENGTH, &ciphertext_len,
- max_out - DTLS1_RT_HEADER_LENGTH, type, wire_version,
- &out[3] /* seq */, in, in_len) ||
- !ssl3_record_sequence_update(&seq[2], 6)) {
- return 0;
- }
-
- if (ciphertext_len >= 1 << 16) {
- OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
- return 0;
- }
- out[11] = ciphertext_len >> 8;
- out[12] = ciphertext_len & 0xff;
-
- *out_len = DTLS1_RT_HEADER_LENGTH + ciphertext_len;
-
- if (s->msg_callback) {
- s->msg_callback(1 /* write */, 0, SSL3_RT_HEADER, out,
- DTLS1_RT_HEADER_LENGTH, s, s->msg_callback_arg);
- }
-
- return 1;
-}
-
static int do_dtls1_write(SSL *s, int type, const uint8_t *buf,
unsigned int len, enum dtls1_use_epoch_t use_epoch) {
- SSL3_BUFFER *wb = &s->s3->wbuf;
-
/* ssl3_write_pending drops the write if |BIO_write| fails in DTLS, so there
* is never pending data. */
- assert(s->s3->wbuf.left == 0);
+ assert(!ssl_write_buffer_is_pending(s));
/* If we have an alert to send, lets send it */
if (s->s3->alert_dispatch) {
@@ -740,7 +533,8 @@
/* if it went, fall through and send more stuff */
}
- if (wb->buf == NULL && !ssl3_setup_write_buffer(s)) {
+ if (len > SSL3_RT_MAX_PLAIN_LENGTH) {
+ OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
return -1;
}
@@ -748,21 +542,15 @@
return 0;
}
- /* Align the output so the ciphertext is aligned to |SSL3_ALIGN_PAYLOAD|. */
- uintptr_t align = (uintptr_t)wb->buf + DTLS1_RT_HEADER_LENGTH;
- align = (0 - align) & (SSL3_ALIGN_PAYLOAD - 1);
- uint8_t *out = wb->buf + align;
- wb->offset = align;
- size_t max_out = wb->len - wb->offset;
-
+ size_t max_out = len + ssl_max_seal_overhead(s);
+ uint8_t *out;
size_t ciphertext_len;
- if (!dtls1_seal_record(s, out, &ciphertext_len, max_out, type, buf, len,
- use_epoch)) {
+ if (!ssl_write_buffer_init(s, &out, max_out) ||
+ !dtls_seal_record(s, out, &ciphertext_len, max_out, type, buf, len,
+ use_epoch)) {
return -1;
}
-
- /* now let's set up wb */
- wb->left = ciphertext_len;
+ ssl_write_buffer_set_len(s, ciphertext_len);
/* memorize arguments so that ssl3_write_pending can detect bad write retries
* later */