Convert more of the SSL write path to size_t and Spans.
We still have our <= 0 return values because anything with BIOs tries to
preserve BIO_write's error returns. (Maybe we can stop doing this?
BIO_read's error return is a little subtle with EOF vs error, but
BIO_write's is uninteresting.) But the rest of the logic is size_t-clean
and hopefully a little clearer. We still have to support SSL_write's
rather goofy calling convention, however.
I haven't pushed Spans down into the low-level record construction logic
yet. We should probably do that, but there are enough offsets tossed
around there that they warrant their own CL.
Bug: 507
Change-Id: Ia0c702d1a2d3713e71b0bbfa8d65649d3b20da9b
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/47544
Commit-Queue: Bob Beck <bbe@google.com>
Reviewed-by: Bob Beck <bbe@google.com>
diff --git a/ssl/d1_pkt.cc b/ssl/d1_pkt.cc
index b9b0ef9..b866156 100644
--- a/ssl/d1_pkt.cc
+++ b/ssl/d1_pkt.cc
@@ -186,8 +186,8 @@
return ssl_open_record_success;
}
-int dtls1_write_app_data(SSL *ssl, bool *out_needs_handshake, const uint8_t *in,
- int len) {
+int dtls1_write_app_data(SSL *ssl, bool *out_needs_handshake,
+ size_t *out_bytes_written, Span<const uint8_t> in) {
assert(!SSL_in_init(ssl));
*out_needs_handshake = false;
@@ -196,47 +196,46 @@
return -1;
}
- if (len > SSL3_RT_MAX_PLAIN_LENGTH) {
+ // DTLS does not split the input across records.
+ if (in.size() > SSL3_RT_MAX_PLAIN_LENGTH) {
OPENSSL_PUT_ERROR(SSL, SSL_R_DTLS_MESSAGE_TOO_BIG);
return -1;
}
- if (len < 0) {
- OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_LENGTH);
- return -1;
+ if (in.empty()) {
+ *out_bytes_written = 0;
+ return 1;
}
- if (len == 0) {
- return 0;
- }
-
- int ret = dtls1_write_record(ssl, SSL3_RT_APPLICATION_DATA, in, (size_t)len,
+ int ret = dtls1_write_record(ssl, SSL3_RT_APPLICATION_DATA, in,
dtls1_use_current_epoch);
if (ret <= 0) {
return ret;
}
- return len;
+ *out_bytes_written = in.size();
+ return 1;
}
-int dtls1_write_record(SSL *ssl, int type, const uint8_t *in, size_t len,
+int dtls1_write_record(SSL *ssl, int type, Span<const uint8_t> in,
enum dtls1_use_epoch_t use_epoch) {
SSLBuffer *buf = &ssl->s3->write_buffer;
- assert(len <= SSL3_RT_MAX_PLAIN_LENGTH);
+ assert(in.size() <= SSL3_RT_MAX_PLAIN_LENGTH);
// There should never be a pending write buffer in DTLS. One can't write half
// a datagram, so the write buffer is always dropped in
// |ssl_write_buffer_flush|.
assert(buf->empty());
- if (len > SSL3_RT_MAX_PLAIN_LENGTH) {
+ if (in.size() > SSL3_RT_MAX_PLAIN_LENGTH) {
OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
return -1;
}
size_t ciphertext_len;
if (!buf->EnsureCap(ssl_seal_align_prefix_len(ssl),
- len + SSL_max_seal_overhead(ssl)) ||
+ in.size() + SSL_max_seal_overhead(ssl)) ||
!dtls_seal_record(ssl, buf->remaining().data(), &ciphertext_len,
- buf->remaining().size(), type, in, len, use_epoch)) {
+ buf->remaining().size(), type, in.data(), in.size(),
+ use_epoch)) {
buf->Clear();
return -1;
}
@@ -250,7 +249,7 @@
}
int dtls1_dispatch_alert(SSL *ssl) {
- int ret = dtls1_write_record(ssl, SSL3_RT_ALERT, &ssl->s3->send_alert[0], 2,
+ int ret = dtls1_write_record(ssl, SSL3_RT_ALERT, ssl->s3->send_alert,
dtls1_use_current_epoch);
if (ret <= 0) {
return ret;
diff --git a/ssl/internal.h b/ssl/internal.h
index 1a78f63..0a15ace 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -2456,8 +2456,13 @@
ssl_open_record_t (*open_app_data)(SSL *ssl, Span<uint8_t> *out,
size_t *out_consumed, uint8_t *out_alert,
Span<uint8_t> in);
- int (*write_app_data)(SSL *ssl, bool *out_needs_handshake, const uint8_t *buf,
- int len);
+ // write_app_data encrypts and writes |in| as application data. On success, it
+ // returns one and sets |*out_bytes_written| to the number of bytes of |in|
+ // written. Otherwise, it returns <= 0 and sets |*out_needs_handshake| to
+ // whether the operation failed because the caller needs to drive the
+ // handshake.
+ int (*write_app_data)(SSL *ssl, bool *out_needs_handshake,
+ size_t *out_bytes_written, Span<const uint8_t> in);
int (*dispatch_alert)(SSL *ssl);
// init_message begins a new handshake message of type |type|. |cbb| is the
// root CBB to be passed into |finish_message|. |*body| is set to a child CBB
@@ -2646,11 +2651,23 @@
// |read_buffer|.
Span<uint8_t> pending_app_data;
- // partial write - check the numbers match
- unsigned int wnum = 0; // number of bytes sent so far
- int wpend_tot = 0; // number bytes written
- int wpend_type = 0;
- const uint8_t *wpend_buf = nullptr;
+ // unreported_bytes_written is the number of bytes successfully written to the
+ // transport, but not yet reported to the caller. The next |SSL_write| will
+ // skip this many bytes from the input. This is used if
+ // |SSL_MODE_ENABLE_PARTIAL_WRITE| is disabled, in which case |SSL_write| only
+ // reports bytes written when the full caller input is written.
+ size_t unreported_bytes_written = 0;
+
+ // pending_write, if |has_pending_write| is true, is the caller-supplied data
+ // corresponding to the current pending write. This is used to check the
+ // caller retried with a compatible buffer.
+ Span<const uint8_t> pending_write;
+
+ // pending_write_type, if |has_pending_write| is true, is the record type
+ // for the current pending write.
+ //
+ // TODO(davidben): Remove this when alerts are moved out of this write path.
+ uint8_t pending_write_type = 0;
// read_shutdown is the shutdown state for the read half of the connection.
enum ssl_shutdown_t read_shutdown = ssl_shutdown_none;
@@ -3214,8 +3231,8 @@
ssl_open_record_t tls_open_change_cipher_spec(SSL *ssl, size_t *out_consumed,
uint8_t *out_alert,
Span<uint8_t> in);
-int tls_write_app_data(SSL *ssl, bool *out_needs_handshake, const uint8_t *buf,
- int len);
+int tls_write_app_data(SSL *ssl, bool *out_needs_handshake,
+ size_t *out_bytes_written, Span<const uint8_t> in);
bool tls_new(SSL *ssl);
void tls_free(SSL *ssl);
@@ -3248,11 +3265,11 @@
Span<uint8_t> in);
int dtls1_write_app_data(SSL *ssl, bool *out_needs_handshake,
- const uint8_t *buf, int len);
+ size_t *out_bytes_written, Span<const uint8_t> in);
// dtls1_write_record sends a record. It returns one on success and <= 0 on
// error.
-int dtls1_write_record(SSL *ssl, int type, const uint8_t *buf, size_t len,
+int dtls1_write_record(SSL *ssl, int type, Span<const uint8_t> in,
enum dtls1_use_epoch_t use_epoch);
int dtls1_retransmit_outgoing_messages(SSL *ssl);
diff --git a/ssl/s3_pkt.cc b/ssl/s3_pkt.cc
index efe5905..bc0d13d 100644
--- a/ssl/s3_pkt.cc
+++ b/ssl/s3_pkt.cc
@@ -126,10 +126,11 @@
BSSL_NAMESPACE_BEGIN
-static int do_tls_write(SSL *ssl, int type, const uint8_t *in, unsigned len);
+static int do_tls_write(SSL *ssl, size_t *out_bytes_written, uint8_t type,
+ Span<const uint8_t> in);
-int tls_write_app_data(SSL *ssl, bool *out_needs_handshake, const uint8_t *in,
- int len) {
+int tls_write_app_data(SSL *ssl, bool *out_needs_handshake,
+ size_t *out_bytes_written, Span<const uint8_t> in) {
assert(ssl_can_write(ssl));
assert(!ssl->s3->aead_write_ctx->is_null_cipher());
@@ -140,32 +141,28 @@
return -1;
}
- // TODO(davidben): Switch this logic to |size_t| and |bssl::Span|.
- assert(ssl->s3->wnum <= INT_MAX);
- unsigned tot = ssl->s3->wnum;
-
- // Ensure that if we end up with a smaller value of data to write out than
- // the the original len from a write which didn't complete for non-blocking
- // I/O and also somehow ended up avoiding the check for this in
- // do_tls_write/SSL_R_BAD_WRITE_RETRY as it must never be possible to end up
- // with (len-tot) as a large number that will then promptly send beyond the
- // end of the users buffer ... so we trap and report the error in a way the
- // user will notice.
- if (len < 0 || (size_t)len < tot) {
+ size_t total_bytes_written = ssl->s3->unreported_bytes_written;
+ if (in.size() < total_bytes_written) {
+ // This can happen if the caller disables |SSL_MODE_ENABLE_PARTIAL_WRITE|,
+ // asks us to write some input of length N, we successfully encrypt M bytes
+ // and write it, but fail to write the rest. We will report
+ // |SSL_ERROR_WANT_WRITE|. If the caller then retries with fewer than M
+ // bytes, we cannot satisfy that request. The caller is required to always
+ // retry with at least as many bytes as the previous attempt.
OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_LENGTH);
return -1;
}
- const int is_early_data_write =
- !ssl->server && SSL_in_early_data(ssl) && ssl->s3->hs->can_early_write;
+ in = in.subspan(total_bytes_written);
- unsigned n = len - tot;
+ const bool is_early_data_write =
+ !ssl->server && SSL_in_early_data(ssl) && ssl->s3->hs->can_early_write;
for (;;) {
size_t max_send_fragment = ssl->max_send_fragment;
if (is_early_data_write) {
SSL_HANDSHAKE *hs = ssl->s3->hs.get();
if (hs->early_data_written >= hs->early_session->ticket_max_early_data) {
- ssl->s3->wnum = tot;
+ ssl->s3->unreported_bytes_written = total_bytes_written;
hs->can_early_write = false;
*out_needs_handshake = true;
return -1;
@@ -175,35 +172,43 @@
hs->early_data_written});
}
- const size_t nw = std::min(max_send_fragment, size_t{n});
- int ret = do_tls_write(ssl, SSL3_RT_APPLICATION_DATA, &in[tot], nw);
+ const size_t to_write = std::min(max_send_fragment, in.size());
+ size_t bytes_written;
+ int ret = do_tls_write(ssl, &bytes_written, SSL3_RT_APPLICATION_DATA,
+ in.subspan(0, to_write));
if (ret <= 0) {
- ssl->s3->wnum = tot;
+ ssl->s3->unreported_bytes_written = total_bytes_written;
return ret;
}
+ // Note |bytes_written| may be less than |to_write| if there was a pending
+ // record from a smaller write attempt.
+ assert(bytes_written <= to_write);
+ total_bytes_written += bytes_written;
+ in = in.subspan(bytes_written);
if (is_early_data_write) {
- ssl->s3->hs->early_data_written += ret;
+ ssl->s3->hs->early_data_written += bytes_written;
}
- if (ret == (int)n || (ssl->mode & SSL_MODE_ENABLE_PARTIAL_WRITE)) {
- ssl->s3->wnum = 0;
- return tot + ret;
+ if (in.empty() || (ssl->mode & SSL_MODE_ENABLE_PARTIAL_WRITE)) {
+ ssl->s3->unreported_bytes_written = 0;
+ *out_bytes_written = total_bytes_written;
+ return 1;
}
-
- n -= ret;
- tot += ret;
}
}
-// do_tls_write writes an SSL record of the given type.
-static int do_tls_write(SSL *ssl, int type, const uint8_t *in, unsigned len) {
+// do_tls_write writes an SSL record of the given type. On success, it sets
+// |*out_bytes_written| to number of bytes successfully written and returns one.
+// On error, it returns a value <= 0 from the underlying |BIO|.
+static int do_tls_write(SSL *ssl, size_t *out_bytes_written, uint8_t type,
+ Span<const uint8_t> in) {
// If there is a pending write, the retry must be consistent.
- if (ssl->s3->wpend_tot > 0 &&
- (ssl->s3->wpend_tot > (int)len ||
+ if (!ssl->s3->pending_write.empty() &&
+ (ssl->s3->pending_write.size() > in.size() ||
(!(ssl->mode & SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER) &&
- ssl->s3->wpend_buf != in) ||
- ssl->s3->wpend_type != type)) {
+ ssl->s3->pending_write.data() != in.data()) ||
+ ssl->s3->pending_write_type != type)) {
OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_WRITE_RETRY);
return -1;
}
@@ -216,15 +221,14 @@
}
// If there is a pending write, we just completed it. Report it to the caller.
- if (ssl->s3->wpend_tot > 0) {
- ret = ssl->s3->wpend_tot;
- ssl->s3->wpend_buf = nullptr;
- ssl->s3->wpend_tot = 0;
- return ret;
+ if (!ssl->s3->pending_write.empty()) {
+ *out_bytes_written = ssl->s3->pending_write.size();
+ ssl->s3->pending_write = {};
+ return 1;
}
SSLBuffer *buf = &ssl->s3->write_buffer;
- if (len > SSL3_RT_MAX_PLAIN_LENGTH || buf->size() > 0) {
+ if (in.size() > SSL3_RT_MAX_PLAIN_LENGTH || buf->size() > 0) {
OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
return -1;
}
@@ -233,16 +237,22 @@
return -1;
}
- size_t flight_len = 0;
+ // We may have unflushed handshake data that must be written before |in|. This
+ // may be a KeyUpdate acknowledgment, 0-RTT key change messages, or a
+ // NewSessionTicket.
+ Span<const uint8_t> pending_flight;
if (ssl->s3->pending_flight != nullptr) {
- flight_len =
- ssl->s3->pending_flight->length - ssl->s3->pending_flight_offset;
+ pending_flight = MakeConstSpan(
+ reinterpret_cast<const uint8_t *>(ssl->s3->pending_flight->data),
+ ssl->s3->pending_flight->length);
+ pending_flight = pending_flight.subspan(ssl->s3->pending_flight_offset);
}
- size_t max_out = flight_len;
- if (len > 0) {
- const size_t max_ciphertext_len = len + SSL_max_seal_overhead(ssl);
- if (max_ciphertext_len < len || max_out + max_ciphertext_len < max_out) {
+ size_t max_out = pending_flight.size();
+ if (!in.empty()) {
+ const size_t max_ciphertext_len = in.size() + SSL_max_seal_overhead(ssl);
+ if (max_ciphertext_len < in.size() ||
+ max_out + max_ciphertext_len < max_out) {
OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
return -1;
}
@@ -250,31 +260,29 @@
}
if (max_out == 0) {
- return 0;
+ // Nothing to write.
+ *out_bytes_written = 0;
+ return 1;
}
- if (!buf->EnsureCap(flight_len + ssl_seal_align_prefix_len(ssl), max_out)) {
+ if (!buf->EnsureCap(pending_flight.size() + ssl_seal_align_prefix_len(ssl),
+ max_out)) {
return -1;
}
- // Add any unflushed handshake data as a prefix. This may be a KeyUpdate
- // acknowledgment or 0-RTT key change messages. |pending_flight| must be clear
- // when data is added to |write_buffer| or it will be written in the wrong
- // order.
- if (ssl->s3->pending_flight != nullptr) {
- OPENSSL_memcpy(
- buf->remaining().data(),
- ssl->s3->pending_flight->data + ssl->s3->pending_flight_offset,
- flight_len);
+ // Copy |pending_flight| to the output.
+ if (!pending_flight.empty()) {
+ OPENSSL_memcpy(buf->remaining().data(), pending_flight.data(),
+ pending_flight.size());
ssl->s3->pending_flight.reset();
ssl->s3->pending_flight_offset = 0;
- buf->DidWrite(flight_len);
+ buf->DidWrite(pending_flight.size());
}
- if (len > 0) {
+ if (!in.empty()) {
size_t ciphertext_len;
if (!tls_seal_record(ssl, buf->remaining().data(), &ciphertext_len,
- buf->remaining().size(), type, in, len)) {
+ buf->remaining().size(), type, in.data(), in.size())) {
return -1;
}
buf->DidWrite(ciphertext_len);
@@ -288,15 +296,15 @@
ret = ssl_write_buffer_flush(ssl);
if (ret <= 0) {
// Track the unfinished write.
- if (len > 0) {
- ssl->s3->wpend_tot = len;
- ssl->s3->wpend_buf = in;
- ssl->s3->wpend_type = type;
+ if (!in.empty()) {
+ ssl->s3->pending_write = in;
+ ssl->s3->pending_write_type = type;
}
return ret;
}
- return len;
+ *out_bytes_written = in.size();
+ return 1;
}
ssl_open_record_t tls_open_app_data(SSL *ssl, Span<uint8_t> *out,
@@ -434,10 +442,13 @@
return 0;
}
} else {
- int ret = do_tls_write(ssl, SSL3_RT_ALERT, &ssl->s3->send_alert[0], 2);
+ size_t bytes_written;
+ int ret =
+ do_tls_write(ssl, &bytes_written, SSL3_RT_ALERT, ssl->s3->send_alert);
if (ret <= 0) {
return ret;
}
+ assert(bytes_written == 2);
}
ssl->s3->alert_dispatch = false;
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index 7035748..4d3ad44 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -1058,6 +1058,7 @@
}
int ret = 0;
+ size_t bytes_written = 0;
bool needs_handshake = false;
do {
// If necessary, complete the handshake implicitly.
@@ -1072,10 +1073,16 @@
}
}
- ret = ssl->method->write_app_data(ssl, &needs_handshake,
- (const uint8_t *)buf, num);
+ if (num < 0) {
+ OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_LENGTH);
+ return -1;
+ }
+ ret = ssl->method->write_app_data(
+ ssl, &needs_handshake, &bytes_written,
+ MakeConstSpan(static_cast<const uint8_t *>(buf),
+ static_cast<size_t>(num)));
} while (needs_handshake);
- return ret;
+ return ret <= 0 ? ret : static_cast<int>(bytes_written);
}
int SSL_key_update(SSL *ssl, int request_type) {
@@ -1239,8 +1246,7 @@
// Discard any unfinished writes from the perspective of |SSL_write|'s
// retry. The handshake will transparently flush out the pending record
// (discarded by the server) to keep the framing correct.
- ssl->s3->wpend_buf = nullptr;
- ssl->s3->wpend_tot = 0;
+ ssl->s3->pending_write = {};
}
enum ssl_early_data_reason_t SSL_get_early_data_reason(const SSL *ssl) {