Send one ACK immediately after the handshake in DTLS 1.3 servers
This does not yet implement sending ACKs in general, just the one
immediate ACK when the handshake completes. The general case will
require scheduling an ACK-send timer, but this one can be sent
immediately.
One interesting case to test is when the server would like to ACK
Finished, but cannot because the records were merged with fragments that
we had to discard.
Bug: 42290594
Change-Id: I64b3f8ecbef4ffee68d923f83ea89d2349847f8b
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72867
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Nick Harper <nharper@chromium.org>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc
index 2b53b85..381186f 100644
--- a/ssl/d1_both.cc
+++ b/ssl/d1_both.cc
@@ -350,8 +350,10 @@
}
bool dtls1_process_handshake_fragments(SSL *ssl, uint8_t *out_alert,
+ DTLSRecordNumber record_number,
Span<const uint8_t> record) {
bool implicit_ack = false;
+ bool skipped_fragments = false;
CBS cbs = record;
while (CBS_len(&cbs) > 0) {
// Read a handshake fragment.
@@ -381,6 +383,7 @@
continue;
}
+ assert(record_number.epoch() == ssl->d1->read_epoch.epoch);
if (ssl->d1->next_read_epoch != nullptr) {
// Any any time, we only expect new messages in one epoch. If
// |next_read_epoch| is set, we've started a new epoch but haven't
@@ -416,6 +419,7 @@
if (msg_hdr.seq - ssl->d1->handshake_read_seq > SSL_MAX_HANDSHAKE_FLIGHT) {
// Ignore fragments too far in the future.
+ skipped_fragments = true;
continue;
}
@@ -443,6 +447,10 @@
dtls_clear_outgoing_messages(ssl);
}
+ if (!skipped_fragments) {
+ ssl->d1->records_to_ack.PushBack(record_number);
+ }
+
return true;
}
@@ -498,7 +506,8 @@
return ssl_open_record_error;
}
- if (!dtls1_process_handshake_fragments(ssl, out_alert, record)) {
+ if (!dtls1_process_handshake_fragments(ssl, out_alert, record_number,
+ record)) {
return ssl_open_record_error;
}
return ssl_open_record_success;
@@ -986,13 +995,83 @@
return 1;
}
-int dtls1_flush_flight(SSL *ssl) {
+int dtls1_flush_flight(SSL *ssl, bool post_handshake) {
ssl->d1->outgoing_messages_complete = true;
+ if (!post_handshake) {
+ // Our new flight implicitly ACKs the previous flight, so there is no need
+ // to ACK previous records. This clears the ACK buffer slightly earlier than
+ // the specification suggests. See the discussion in
+ // https://mailarchive.ietf.org/arch/msg/tls/kjJnquJOVaWxu5hUCmNzB35eqY0/
+ //
+ // TODO(crbug.com/42290594): When we introduce the ACK timer, this should
+ // also stop the ACK timer.
+ ssl->d1->records_to_ack.Clear();
+ }
// Start the retransmission timer for the next flight (if any).
dtls1_start_timer(ssl);
return send_flight(ssl);
}
+int dtls1_send_ack(SSL *ssl) {
+ assert(ssl_protocol_version(ssl) >= TLS1_3_VERSION);
+ if (ssl->d1->records_to_ack.empty()) {
+ return 1;
+ }
+
+ // Ensure we don't send so many ACKs that we overflow the MTU. There is a
+ // 2-byte length prefix and each ACK is 16 bytes.
+ dtls1_update_mtu(ssl);
+ size_t max_plaintext =
+ dtls_seal_max_input_len(ssl, ssl->d1->write_epoch.epoch(), ssl->d1->mtu);
+ if (max_plaintext < 2 + 16) {
+ OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL); // No room for even one ACK.
+ return -1;
+ }
+ size_t num_acks =
+ std::min((max_plaintext - 2) / 16, ssl->d1->records_to_ack.size());
+
+ // Assemble the ACK. RFC 9147 says to sort ACKs numerically. It is unclear if
+ // other implementations do this, but go ahead and sort for now. See
+ // https://mailarchive.ietf.org/arch/msg/tls/kjJnquJOVaWxu5hUCmNzB35eqY0/.
+ // Remove this if rfc9147bis removes this requirement.
+ InplaceVector<DTLSRecordNumber, DTLS_MAX_ACK_BUFFER> sorted;
+ for (size_t i = ssl->d1->records_to_ack.size() - num_acks;
+ i < ssl->d1->records_to_ack.size(); i++) {
+ sorted.PushBack(ssl->d1->records_to_ack[i]);
+ }
+ std::sort(sorted.begin(), sorted.end());
+
+ uint8_t buf[2 + 16 * DTLS_MAX_ACK_BUFFER];
+ CBB cbb, child;
+ CBB_init_fixed(&cbb, buf, sizeof(buf));
+ BSSL_CHECK(CBB_add_u16_length_prefixed(&cbb, &child));
+ for (const auto &number : sorted) {
+ BSSL_CHECK(CBB_add_u64(&child, number.epoch()));
+ BSSL_CHECK(CBB_add_u64(&child, number.sequence()));
+ }
+ BSSL_CHECK(CBB_flush(&cbb));
+
+ // Encrypt it.
+ uint8_t record[DTLS1_3_RECORD_HEADER_WRITE_LENGTH + sizeof(buf) +
+ 1 /* record type */ + EVP_AEAD_MAX_OVERHEAD];
+ size_t record_len;
+ DTLSRecordNumber record_number;
+ if (!dtls_seal_record(ssl, &record_number, record, &record_len,
+ sizeof(record), SSL3_RT_ACK, CBB_data(&cbb),
+ CBB_len(&cbb), ssl->d1->write_epoch.epoch())) {
+ return -1;
+ }
+
+ int bio_ret =
+ BIO_write(ssl->wbio.get(), record, static_cast<int>(record_len));
+ if (bio_ret <= 0) {
+ ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
+ return bio_ret;
+ }
+
+ return 1;
+}
+
int dtls1_retransmit_outgoing_messages(SSL *ssl) {
// Rewind to the start of the flight and write it again.
//
diff --git a/ssl/d1_pkt.cc b/ssl/d1_pkt.cc
index 6069676..168558d 100644
--- a/ssl/d1_pkt.cc
+++ b/ssl/d1_pkt.cc
@@ -256,7 +256,8 @@
if (type == SSL3_RT_HANDSHAKE) {
// Process handshake fragments for DTLS 1.3 post-handshake messages.
if (ssl_protocol_version(ssl) >= TLS1_3_VERSION) {
- if (!dtls1_process_handshake_fragments(ssl, out_alert, record)) {
+ if (!dtls1_process_handshake_fragments(ssl, out_alert, record_number,
+ record)) {
return ssl_open_record_error;
}
return ssl_open_record_discard;
diff --git a/ssl/dtls_method.cc b/ssl/dtls_method.cc
index e2e525f..df0f7ed 100644
--- a/ssl/dtls_method.cc
+++ b/ssl/dtls_method.cc
@@ -166,6 +166,7 @@
dtls1_add_message,
dtls1_add_change_cipher_spec,
dtls1_flush_flight,
+ dtls1_send_ack,
dtls1_on_handshake_complete,
dtls1_set_read_state,
dtls1_set_write_state,
diff --git a/ssl/handshake.cc b/ssl/handshake.cc
index 0c5895f..51eeeea 100644
--- a/ssl/handshake.cc
+++ b/ssl/handshake.cc
@@ -598,8 +598,10 @@
ERR_restore_state(hs->error.get());
return -1;
+ case ssl_hs_flush_post_handshake:
case ssl_hs_flush: {
- int ret = ssl->method->flush_flight(ssl);
+ bool post_handshake = hs->wait == ssl_hs_flush_post_handshake;
+ int ret = ssl->method->flush_flight(ssl, post_handshake);
if (ret <= 0) {
return ret;
}
@@ -677,7 +679,7 @@
return -1;
case ssl_hs_handback: {
- int ret = ssl->method->flush_flight(ssl);
+ int ret = ssl->method->flush_flight(ssl, /*post_handshake=*/false);
if (ret <= 0) {
return ret;
}
@@ -730,6 +732,15 @@
ssl->s3->rwstate = SSL_ERROR_HANDSHAKE_HINTS_READY;
return -1;
+ case ssl_hs_ack:
+ if (ssl->method->send_ack != nullptr) {
+ int ret = ssl->method->send_ack(ssl);
+ if (ret <= 0) {
+ return ret;
+ }
+ }
+ break;
+
case ssl_hs_ok:
break;
}
diff --git a/ssl/internal.h b/ssl/internal.h
index 97c9ed1..e5cb15a 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1298,6 +1298,7 @@
return combined() == r.combined();
}
bool operator!=(DTLSRecordNumber r) const { return !((*this) == r); }
+ bool operator<(DTLSRecordNumber r) const { return combined() < r.combined(); }
uint64_t combined() const { return combined_; }
uint16_t epoch() const { return combined_ >> 48; }
@@ -2122,6 +2123,7 @@
ssl_hs_read_server_hello,
ssl_hs_read_message,
ssl_hs_flush,
+ ssl_hs_flush_post_handshake,
ssl_hs_certificate_selection_pending,
ssl_hs_handoff,
ssl_hs_handback,
@@ -2135,6 +2137,7 @@
ssl_hs_read_change_cipher_spec,
ssl_hs_certificate_verify,
ssl_hs_hints_ready,
+ ssl_hs_ack,
};
enum ssl_grease_index_t {
@@ -2954,8 +2957,12 @@
// flight. It returns true on success and false on error.
bool (*add_change_cipher_spec)(SSL *ssl);
// flush_flight flushes the pending flight to the transport. It returns one on
- // success and <= 0 on error.
- int (*flush_flight)(SSL *ssl);
+ // success and <= 0 on error. If |post_handshake| is true, the flight is a
+ // post-handshake flight.
+ int (*flush_flight)(SSL *ssl, bool post_handshake);
+ // send_ack, if not NULL, sends a DTLS ACK record to the peer. It returns one
+ // on success and <= 0 on error.
+ int (*send_ack)(SSL *ssl);
// on_handshake_complete is called when the handshake is complete.
void (*on_handshake_complete)(SSL *ssl);
// set_read_state sets |ssl|'s read cipher state and level to |aead_ctx| and
@@ -3538,6 +3545,12 @@
// when empty.
UniquePtr<MRUQueue<DTLSSentRecord, DTLS_MAX_ACK_BUFFER>> sent_records;
+ // records_to_ack is a queue of received records that we should ACK. This is
+ // not stored on the heap because, in the steady state, DTLS 1.3 does not
+ // necessarily empty this list. (We probably could drop records from here once
+ // they are sufficiently old.)
+ MRUQueue<DTLSRecordNumber, DTLS_MAX_ACK_BUFFER> records_to_ack;
+
// outgoing_written is the number of outgoing messages that have been
// written.
uint8_t outgoing_written = 0;
@@ -3879,13 +3892,14 @@
bool tls_finish_message(const SSL *ssl, CBB *cbb, Array<uint8_t> *out_msg);
bool tls_add_message(SSL *ssl, Array<uint8_t> msg);
bool tls_add_change_cipher_spec(SSL *ssl);
-int tls_flush_flight(SSL *ssl);
+int tls_flush_flight(SSL *ssl, bool post_handshake);
bool dtls1_init_message(const SSL *ssl, CBB *cbb, CBB *body, uint8_t type);
bool dtls1_finish_message(const SSL *ssl, CBB *cbb, Array<uint8_t> *out_msg);
bool dtls1_add_message(SSL *ssl, Array<uint8_t> msg);
bool dtls1_add_change_cipher_spec(SSL *ssl);
-int dtls1_flush_flight(SSL *ssl);
+int dtls1_flush_flight(SSL *ssl, bool post_handshake);
+int dtls1_send_ack(SSL *ssl);
// ssl_add_message_cbb finishes the handshake message in |cbb| and adds it to
// the pending flight. It returns true on success and false on error.
@@ -3927,6 +3941,7 @@
void dtls1_free(SSL *ssl);
bool dtls1_process_handshake_fragments(SSL *ssl, uint8_t *out_alert,
+ DTLSRecordNumber record_number,
Span<const uint8_t> record);
bool dtls1_get_message(const SSL *ssl, SSLMessage *out);
ssl_open_record_t dtls1_open_handshake(SSL *ssl, size_t *out_consumed,
diff --git a/ssl/s3_both.cc b/ssl/s3_both.cc
index 4ffc1d8..6912dc4 100644
--- a/ssl/s3_both.cc
+++ b/ssl/s3_both.cc
@@ -281,7 +281,7 @@
return true;
}
-int tls_flush_flight(SSL *ssl) {
+int tls_flush_flight(SSL *ssl, bool post_handshake) {
if (!tls_flush_pending_hs_data(ssl)) {
return -1;
}
@@ -302,6 +302,15 @@
return 1;
}
+ if (post_handshake) {
+ // Don't flush post-handshake messages like NewSessionTicket until the
+ // server performs a write, to prevent a non-reading client from causing the
+ // server to hang in the case of a small server write buffer. Consumers
+ // which don't write data to the client will need to do a zero-byte write if
+ // they wish to flush the tickets.
+ return 1;
+ }
+
if (ssl->s3->write_shutdown != ssl_shutdown_none) {
OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
return -1;
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index c025548..28c4f72 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -125,6 +125,7 @@
receivedFlight []DTLSMessage
receivedFlightRecords []DTLSRecordNumberInfo
nextFlight []DTLSMessage
+ expectedACK []DTLSRecordNumber
keyUpdateSeen bool
keyUpdateRequested bool
@@ -375,6 +376,29 @@
}
}
+// lastRecordNumber returns the most recent record number decrypted or encrypted
+// on a halfConn.
+//
+// TODO(crbug.com/376641666): This function is a bit hacky. It needs to rewind
+// the state back to what the last call actually used. Fix the TLS/DTLS
+// abstractions so we can return this value out directly.
+func (hc *halfConn) lastRecordNumber(epoch *epochState, isOut bool) DTLSRecordNumber {
+ seq := binary.BigEndian.Uint64(epoch.seq[:])
+ // We maintain the next record number, so undo the increment.
+ if seq&(1<<48-1) == 0 {
+ panic("tls: epoch has never been used")
+ }
+ seq--
+ if hc.isDTLS {
+ if isOut && hc.config.Bugs.SequenceNumberMapping != nil {
+ seq = hc.config.Bugs.SequenceNumberMapping(seq)
+ }
+ // Remove the embedded epoch number.
+ seq &= 1<<48 - 1
+ }
+ return DTLSRecordNumber{Epoch: uint64(epoch.epoch), Sequence: seq}
+}
+
func (hc *halfConn) sequenceNumberForOutput(epoch *epochState) []byte {
if !hc.isDTLS || hc.config.Bugs.SequenceNumberMapping == nil {
return epoch.seq[:]
@@ -1054,7 +1078,7 @@
c.sendAlert(alertInternalError)
return c.in.setErrorLocked(errors.New("tls: ChangeCipherSpec requested after handshake complete"))
}
- case recordTypeApplicationData, recordTypeAlert, recordTypeHandshake:
+ case recordTypeApplicationData, recordTypeAlert, recordTypeHandshake, recordTypeACK:
break
}
@@ -1150,6 +1174,17 @@
if pack := c.config.Bugs.ExpectPackedEncryptedHandshake; pack > 0 && len(data) < pack && c.out.epoch.cipher != nil {
c.seenHandshakePackEnd = true
}
+
+ case recordTypeACK:
+ if typ != want || !c.isDTLS {
+ c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
+ break
+ }
+
+ if err := c.checkACK(data); err != nil {
+ c.in.setErrorLocked(err)
+ break
+ }
}
return c.in.err
diff --git a/ssl/test/runner/dtls.go b/ssl/test/runner/dtls.go
index a8944c1..266b9c3 100644
--- a/ssl/test/runner/dtls.go
+++ b/ssl/test/runner/dtls.go
@@ -72,6 +72,9 @@
Sequence uint16
Offset int
Data []byte
+ // ShouldDiscard, if true, indicates the shim is expected to discard this
+ // fragment. A record with such a fragment must not be ACKed by the shim.
+ ShouldDiscard bool
}
func (f *DTLSFragment) Bytes() []byte {
@@ -96,13 +99,17 @@
return cmp.Compare(a2, b2)
}
-// A DTLSRecordNumberInfo contains information about a record received from the
-// shim, which we may attempt to ACK.
-type DTLSRecordNumberInfo struct {
+type DTLSRecordNumber struct {
// Store the Epoch as a uint64, so that tests can send ACKs for epochs that
// the shim would never use.
Epoch uint64
Sequence uint64
+}
+
+// A DTLSRecordNumberInfo contains information about a record received from the
+// shim, which we may attempt to ACK.
+type DTLSRecordNumberInfo struct {
+ DTLSRecordNumber
// The first byte covered by this record, inclusive. We only need to store
// one range because we require that the shim arrange fragments in order.
// Any gaps will have been previously-ACKed data, so there is no harm in
@@ -394,17 +401,17 @@
}
if typ == recordTypeApplicationData && len(data) > 1 && c.config.Bugs.SplitAndPackAppData {
- _, err = c.dtlsPackRecord(epoch, typ, data[:len(data)/2], false)
+ _, _, err = c.dtlsPackRecord(epoch, typ, data[:len(data)/2], false)
if err != nil {
return
}
- _, err = c.dtlsPackRecord(epoch, typ, data[len(data)/2:], true)
+ _, _, err = c.dtlsPackRecord(epoch, typ, data[len(data)/2:], true)
if err != nil {
return
}
n = len(data)
} else {
- n, err = c.dtlsPackRecord(epoch, typ, data, false)
+ n, _, err = c.dtlsPackRecord(epoch, typ, data, false)
if err != nil {
return
}
@@ -514,7 +521,7 @@
// dtlsPackRecord packs a single record to the pending packet, flushing it
// if necessary. The caller should call dtlsFlushPacket to flush the current
// pending packet afterwards.
-func (c *Conn) dtlsPackRecord(epoch *epochState, typ recordType, data []byte, mustPack bool) (n int, err error) {
+func (c *Conn) dtlsPackRecord(epoch *epochState, typ recordType, data []byte, mustPack bool) (n int, num DTLSRecordNumber, err error) {
maxLen := c.config.Bugs.MaxHandshakeRecordLength
if maxLen <= 0 {
maxLen = 1024
@@ -556,6 +563,7 @@
if err != nil {
return
}
+ num = c.out.lastRecordNumber(epoch, true /* isOut */)
// Encrypt the sequence number.
if useDTLS13RecordHeader && !c.config.Bugs.NullAllCiphers {
@@ -627,16 +635,8 @@
return f, nil
}
-func makeDTLSRecordNumberInfo(epoch *epochState, data []byte) (DTLSRecordNumberInfo, error) {
- info := DTLSRecordNumberInfo{
- Epoch: uint64(epoch.epoch),
- // Remove the embedded epoch number. The sequence number has also since
- // been incremented, so adjust it back down.
- //
- // TODO(crbug.com/376641666): The record abstractions should reliably
- // return the sequence number.
- Sequence: (binary.BigEndian.Uint64(epoch.seq[:]) & (1<<48 - 1)) - 1,
- }
+func (c *Conn) makeDTLSRecordNumberInfo(epoch *epochState, data []byte) (DTLSRecordNumberInfo, error) {
+ info := DTLSRecordNumberInfo{DTLSRecordNumber: c.in.lastRecordNumber(epoch, false /* isOut */)}
s := cryptobyte.String(data)
first := true
@@ -674,7 +674,7 @@
if err := c.readRecord(recordTypeHandshake); err != nil {
return nil, err
}
- record, err := makeDTLSRecordNumberInfo(&c.in.epoch, c.hand.Bytes())
+ record, err := c.makeDTLSRecordNumberInfo(&c.in.epoch, c.hand.Bytes())
if err != nil {
return nil, err
}
@@ -723,6 +723,10 @@
return nil, fmt.Errorf("dtls: handshake fragment was truncated, but record could have fit %d more bytes", c.lastRecordInFlight.bytesAvailable)
}
}
+
+ // Sending part of the next flight implicitly ACKs the previous flight.
+ // Having triggered this, the shim is expected to clear its ACK buffer.
+ c.expectedACK = nil
}
c.recvHandshakeSeq++
ret := c.handMsg
@@ -736,6 +740,52 @@
return ret, nil
}
+func (c *Conn) checkACK(data []byte) error {
+ s := cryptobyte.String(data)
+ var child cryptobyte.String
+ if !s.ReadUint16LengthPrefixed(&child) || !s.Empty() {
+ return fmt.Errorf("tls: could not parse ACK record")
+ }
+
+ var acks []DTLSRecordNumber
+ for !child.Empty() {
+ var num DTLSRecordNumber
+ if !child.ReadUint64(&num.Epoch) || !child.ReadUint64(&num.Sequence) {
+ return fmt.Errorf("tls: could not parse ACK record")
+ }
+ acks = append(acks, num)
+ }
+
+ // Determine the expected ACKs, if any.
+ expected := c.expectedACK
+ if len(expected) > shimConfig.MaxACKBuffer {
+ expected = expected[len(expected)-shimConfig.MaxACKBuffer:]
+ }
+
+ // If we've configured a tighter MTU, the shim might have needed to truncate
+ // the list. Tolerate this as long as the shim sent the more recent records
+ // and still sent a plausible minimum number of ACKs.
+ if c.maxPacketLen != 0 && len(acks) > 10 && len(acks) < len(expected) {
+ expected = expected[len(expected)-len(acks):]
+ }
+
+ // The shim is expected to sort the record numbers in the ACK.
+ expected = slices.Clone(expected)
+ slices.SortFunc(expected, func(a, b DTLSRecordNumber) int {
+ cmp1 := cmp.Compare(a.Epoch, b.Epoch)
+ if cmp1 != 0 {
+ return cmp1
+ }
+ return cmp.Compare(a.Sequence, b.Sequence)
+ })
+
+ if !slices.Equal(acks, expected) {
+ return fmt.Errorf("tls: got ACKs %+v, but expected %+v", acks, expected)
+ }
+
+ return nil
+}
+
// DTLSServer returns a new DTLS server side connection
// using conn as the underlying transport.
// The configuration config must be non-nil and must have
@@ -998,18 +1048,29 @@
}
maxRecordLen := config.Bugs.PackHandshakeFragments
+ packRecord := func(epoch *epochState, typ recordType, data []byte, anyDiscard bool) error {
+ _, num, err := c.conn.dtlsPackRecord(epoch, typ, data, false)
+ if err != nil {
+ return err
+ }
+ if !anyDiscard && typ == recordTypeHandshake {
+ c.conn.expectedACK = append(c.conn.expectedACK, num)
+ }
+ return nil
+ }
// Pack handshake fragments into records.
var record []byte
var epoch *epochState
+ var anyDiscard bool
flush := func() error {
if len(record) > 0 {
- _, err := c.conn.dtlsPackRecord(epoch, recordTypeHandshake, record, false)
- if err != nil {
+ if err := packRecord(epoch, recordTypeHandshake, record, anyDiscard); err != nil {
return err
}
}
record = nil
+ anyDiscard = false
return nil
}
@@ -1028,26 +1089,30 @@
}
if f.IsChangeCipherSpec {
- _, c.err = c.conn.dtlsPackRecord(epoch, recordTypeChangeCipherSpec, f.Bytes(), false)
+ c.err = packRecord(epoch, recordTypeChangeCipherSpec, f.Bytes(), false)
if c.err != nil {
return
}
continue
}
+ if f.ShouldDiscard {
+ anyDiscard = true
+ }
+
fBytes := f.Bytes()
if n := config.Bugs.SplitFragments; n > 0 {
if len(fBytes) > n {
- _, c.err = c.conn.dtlsPackRecord(epoch, recordTypeHandshake, fBytes[:n], false)
+ c.err = packRecord(epoch, recordTypeHandshake, fBytes[:n], f.ShouldDiscard)
if c.err != nil {
return
}
- _, c.err = c.conn.dtlsPackRecord(epoch, recordTypeHandshake, fBytes[n:], false)
+ c.err = packRecord(epoch, recordTypeHandshake, fBytes[n:], f.ShouldDiscard)
if c.err != nil {
return
}
} else {
- _, c.err = c.conn.dtlsPackRecord(epoch, recordTypeHandshake, fBytes, false)
+ c.err = packRecord(epoch, recordTypeHandshake, fBytes, f.ShouldDiscard)
if c.err != nil {
return
}
@@ -1078,13 +1143,13 @@
// Send the ACK.
ack := cryptobyte.NewBuilder(make([]byte, 0, 2+8*len(records)))
- ack.AddUint16LengthPrefixed(func(child *cryptobyte.Builder) {
+ ack.AddUint16LengthPrefixed(func(recordNumbers *cryptobyte.Builder) {
for _, r := range records {
- child.AddUint64(r.Epoch)
- child.AddUint64(r.Sequence)
+ recordNumbers.AddUint64(r.Epoch)
+ recordNumbers.AddUint64(r.Sequence)
}
})
- _, c.err = c.conn.dtlsPackRecord(c.getOutEpochOrPanic(epoch), recordTypeACK, ack.BytesOrPanic(), false)
+ _, _, c.err = c.conn.dtlsPackRecord(c.getOutEpochOrPanic(epoch), recordTypeACK, ack.BytesOrPanic(), false)
if c.err != nil {
return
}
@@ -1223,7 +1288,7 @@
}
}
- record, err := makeDTLSRecordNumberInfo(epoch, data)
+ record, err := c.conn.makeDTLSRecordNumberInfo(epoch, data)
if err != nil {
return nil, err
}
@@ -1232,6 +1297,32 @@
return records, nil
}
+// ReadACK indicates the shim is expected to send an ACK at the specified epoch.
+// The contents of the ACK are checked against the connection's internal
+// simulation of the shim's expected behavior.
+func (c *DTLSController) ReadACK(epoch uint16) {
+ if c.err != nil {
+ return
+ }
+
+ c.err = c.conn.dtlsFlushPacket()
+ if c.err != nil {
+ return
+ }
+
+ typ, data, err := c.conn.dtlsDoReadRecord(c.getInEpochOrPanic(epoch), recordTypeACK)
+ if err != nil {
+ c.err = err
+ return
+ }
+ if typ != recordTypeACK {
+ c.err = fmt.Errorf("tls: got record of type %d, but expected ACK", typ)
+ return
+ }
+
+ c.err = c.conn.checkACK(data)
+}
+
// WriteAppData writes an application data record to the shim. This may be used
// to test that post-handshake retransmits may interleave with application data.
func (c *DTLSController) WriteAppData(epoch uint16, data []byte) {
@@ -1239,7 +1330,7 @@
return
}
- _, c.err = c.conn.dtlsPackRecord(c.getOutEpochOrPanic(epoch), recordTypeApplicationData, data, false)
+ _, _, c.err = c.conn.dtlsPackRecord(c.getOutEpochOrPanic(epoch), recordTypeApplicationData, data, false)
}
// ReadAppData indicates the shim is expected to send the specified application
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index 35ba616..b7f1817 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -1515,6 +1515,12 @@
}
}
+ if c.isDTLS && len(c.expectedACK) != 0 {
+ if err := c.readRecord(recordTypeACK); err != nil {
+ return err
+ }
+ }
+
return nil
}
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index fad3ac8..d4a82e6 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -104,11 +104,16 @@
// This is currently used to control tests that enable all curves but may
// automatically disable tests in the future.
AllCurves []int
+
+ // MaxACKBuffer is the maximum number of received records the shim is
+ // expected to retain when ACKing.
+ MaxACKBuffer int
}
// Setup shimConfig defaults aligning with BoringSSL.
var shimConfig ShimConfiguration = ShimConfiguration{
HalfRTTTickets: 2,
+ MaxACKBuffer: 32,
}
//go:embed rsa_2048_key.pem
@@ -11791,10 +11796,13 @@
name: "DTLS-Retransmit-Server-ACKEverything" + suffix,
config: Config{
MaxVersion: vers.version,
+ Credential: &rsaChainCertificate,
CurvePreferences: []CurveID{CurveX25519MLKEM768},
DefaultCurves: []CurveID{}, // Force HelloRetryRequest.
Bugs: ProtocolBugs{
- MaxPacketLength: 512,
+ // Send smaller packets to exercise more ACK cases.
+ MaxPacketLength: 512,
+ MaxHandshakeRecordLength: 512,
WriteFlightDTLS: func(c *DTLSController, prev, received, next []DTLSMessage, records []DTLSRecordNumberInfo) {
if len(received) > 0 {
c.WriteACK(c.OutEpoch(), records)
@@ -11806,10 +11814,20 @@
}
c.WriteFlight(next)
},
+ SequenceNumberMapping: func(in uint64) uint64 {
+ // Perturb sequence numbers to test that ACKs are sorted.
+ return in ^ 63
+ },
},
},
shimCertificate: &rsaChainCertificate,
- flags: slices.Concat(flags, []string{"-mtu", "512", "-curves", strconv.Itoa(int(CurveX25519MLKEM768))}),
+ flags: slices.Concat(flags, []string{
+ "-mtu", "512",
+ "-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+ // Request a client certificate so the client final flight is
+ // larger.
+ "-require-any-client-certificate",
+ }),
})
// ACK packets one by one, in reverse.
@@ -12061,7 +12079,7 @@
// ACK the first record the shim ever sent. It will have
// fallen off the queue by now, so it is expected to not
// impact the shim's retransmissions.
- c.WriteACK(c.OutEpoch(), []DTLSRecordNumberInfo{{Epoch: records[0].Epoch, Sequence: records[0].Sequence}})
+ c.WriteACK(c.OutEpoch(), []DTLSRecordNumberInfo{{DTLSRecordNumber: records[0].DTLSRecordNumber}})
c.AdvanceClock(useTimeouts[len(useTimeouts)-2])
c.ReadRetransmit()
}
@@ -12088,7 +12106,7 @@
// to the shim's ServerHello. ACK the shim's first
// record, which would have been part of
// HelloRetryRequest. This should not impact retransmit.
- c.WriteACK(c.OutEpoch(), []DTLSRecordNumberInfo{{Epoch: 0, Sequence: 0}})
+ c.WriteACK(c.OutEpoch(), []DTLSRecordNumberInfo{{DTLSRecordNumber: DTLSRecordNumber{Epoch: 0, Sequence: 0}}})
c.AdvanceClock(useTimeouts[0])
c.ReadRetransmit()
}
@@ -12099,6 +12117,41 @@
flags: flags,
})
+ // Records that contain a mix of discarded and processed fragments should
+ // not be ACKed.
+ testCases = append(testCases, testCase{
+ protocol: dtls,
+ testType: serverTest,
+ name: "DTLS-Retransmit-Server-DoNotACKDiscardedFragments" + suffix,
+ config: Config{
+ MaxVersion: vers.version,
+ DefaultCurves: []CurveID{}, // Force a HelloRetryRequest.
+ Bugs: ProtocolBugs{
+ PackHandshakeFragments: 4096,
+ WriteFlightDTLS: func(c *DTLSController, prev, received, next []DTLSMessage, records []DTLSRecordNumberInfo) {
+ // Send the flight, but combine every fragment with a far future
+ // fragment, which the shim will discard. During the handshake,
+ // the shim has enough information to reject this entirely, but
+ // that would require coordinating with the handshake state
+ // machine. Instead, BoringSSL discards the fragment and skips
+ // ACKing the packet.
+ //
+ // runner implicitly tests that the shim ACKs the Finished flight
+ // (or, in case, that it is does not), so this exercises the final
+ // ACK.
+ //
+ // TODO(crbug.com/42290594): Once we send partial ACKs, exercise
+ // those here.
+ for _, msg := range next {
+ shouldDiscard := DTLSFragment{Epoch: msg.Epoch, Sequence: 1000, ShouldDiscard: true}
+ c.WriteFragments([]DTLSFragment{shouldDiscard, msg.Fragment(0, len(msg.Data))})
+ }
+ },
+ },
+ },
+ flags: flags,
+ })
+
// As a client, the shim must tolerate ACKs in response to its
// initial ClientHello, but it will not process them because the
// version is not yet known. The second ClientHello, in response
diff --git a/ssl/tls13_server.cc b/ssl/tls13_server.cc
index 73eb3c5..989c2fd 100644
--- a/ssl/tls13_server.cc
+++ b/ssl/tls13_server.cc
@@ -1276,7 +1276,7 @@
}
ssl->method->next_message(ssl);
- return ssl_hs_ok;
+ return ssl_hs_ack;
}
static enum ssl_hs_wait_t do_send_new_session_ticket(SSL_HANDSHAKE *hs) {
@@ -1286,16 +1286,7 @@
}
hs->tls13_state = state13_done;
- // In TLS 1.3, the NewSessionTicket isn't flushed until the server performs a
- // write, to prevent a non-reading client from causing the server to hang in
- // the case of a small server write buffer. Consumers which don't write data
- // to the client will need to do a zero-byte write if they wish to flush the
- // tickets.
- if ((hs->ssl->quic_method != nullptr || SSL_is_dtls(hs->ssl)) &&
- sent_tickets) {
- return ssl_hs_flush;
- }
- return ssl_hs_ok;
+ return sent_tickets ? ssl_hs_flush_post_handshake : ssl_hs_ok;
}
enum ssl_hs_wait_t tls13_server_handshake(SSL_HANDSHAKE *hs) {
diff --git a/ssl/tls_method.cc b/ssl/tls_method.cc
index 7b1c21f..f585734 100644
--- a/ssl/tls_method.cc
+++ b/ssl/tls_method.cc
@@ -160,6 +160,7 @@
tls_add_message,
tls_add_change_cipher_spec,
tls_flush_flight,
+ /*send_ack=*/nullptr,
tls_on_handshake_complete,
tls_set_read_state,
tls_set_write_state,