| // Copyright 2014 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| // DTLS implementation. |
| // |
// NOTE: This is not even a remotely production-quality DTLS
| // implementation. It is the bare minimum necessary to be able to |
| // achieve coverage on BoringSSL's implementation. Of note is that |
| // this implementation assumes the underlying net.PacketConn is not |
| // only reliable but also ordered. BoringSSL will be expected to deal |
| // with simulated loss, but there is no point in forcing the test |
| // driver to. |
| |
| package runner |
| |
| import ( |
| "bytes" |
| "cmp" |
| "encoding/binary" |
| "errors" |
| "fmt" |
| "math/rand" |
| "net" |
| "slices" |
| "time" |
| |
| "golang.org/x/crypto/cryptobyte" |
| ) |
| |
// A DTLSMessage is a DTLS handshake message or ChangeCipherSpec, along with the
// epoch that it is to be sent under.
type DTLSMessage struct {
	// Epoch is the epoch the message is to be sent under.
	Epoch uint16
	// IsChangeCipherSpec marks the message as a ChangeCipherSpec rather than a
	// handshake message.
	IsChangeCipherSpec bool
	// The following fields are only used if IsChangeCipherSpec is false.
	// (Exception: Data is also read for a ChangeCipherSpec, where it carries
	// the record payload; see Fragment.)
	Type uint8
	Sequence uint16
	// Data is the message body. For handshake messages it excludes the
	// four-byte TLS handshake header.
	Data []byte
}
| |
| // Fragment returns a DTLSFragment for the message with the specified offset and |
| // length. |
| func (m *DTLSMessage) Fragment(offset, length int) DTLSFragment { |
| if m.IsChangeCipherSpec { |
| // Ignore the offset. ChangeCipherSpec cannot be fragmented. |
| return DTLSFragment{ |
| Epoch: m.Epoch, |
| IsChangeCipherSpec: m.IsChangeCipherSpec, |
| Data: m.Data, |
| } |
| } |
| |
| return DTLSFragment{ |
| Epoch: m.Epoch, |
| Sequence: m.Sequence, |
| Type: m.Type, |
| Data: m.Data[offset : offset+length], |
| Offset: offset, |
| TotalLength: len(m.Data), |
| } |
| } |
| |
| // Split returns two fragments for the message. |
| func (m *DTLSMessage) Split(offset int) (DTLSFragment, DTLSFragment) { |
| if m.IsChangeCipherSpec { |
| panic("tls: cannot split ChangeCipherSpec") |
| } |
| |
| return m.Fragment(0, offset), m.Fragment(offset, len(m.Data)-offset) |
| } |
| |
// A DTLSFragment is a DTLS handshake fragment or ChangeCipherSpec, along with
// the epoch that it is to be sent under.
type DTLSFragment struct {
	// Epoch is the epoch the fragment is to be sent under.
	Epoch uint16
	// IsChangeCipherSpec marks the fragment as a ChangeCipherSpec. Bytes then
	// returns Data unmodified.
	IsChangeCipherSpec bool
	// The following fields are only used if IsChangeCipherSpec is false.
	// (Exception: Data is also read for a ChangeCipherSpec; see Bytes.)
	Type uint8
	// TotalLength is the length of the complete handshake message body that
	// this fragment belongs to.
	TotalLength int
	Sequence uint16
	// Offset is the position of Data within the complete message body.
	Offset int
	Data []byte
	// ShouldDiscard, if true, indicates the shim is expected to discard this
	// fragment. A record with such a fragment must not be ACKed by the shim.
	ShouldDiscard bool
}
| |
| func (f *DTLSFragment) Bytes() []byte { |
| if f.IsChangeCipherSpec { |
| return f.Data |
| } |
| |
| bb := cryptobyte.NewBuilder(make([]byte, 0, 12+len(f.Data))) |
| bb.AddUint8(f.Type) |
| bb.AddUint24(uint32(f.TotalLength)) |
| bb.AddUint16(f.Sequence) |
| bb.AddUint24(uint32(f.Offset)) |
| addUint24LengthPrefixedBytes(bb, f.Data) |
| return bb.BytesOrPanic() |
| } |
| |
// comparePair orders the pair (a1, a2) against the pair (b1, b2)
// lexicographically: the first components decide unless they are equal, in
// which case the second components decide. The result follows cmp.Compare:
// negative, zero or positive.
func comparePair[T1 cmp.Ordered, T2 cmp.Ordered](a1 T1, a2 T2, b1 T1, b2 T2) int {
	if first := cmp.Compare(a1, b1); first != 0 {
		return first
	}
	return cmp.Compare(a2, b2)
}
| |
// A DTLSRecordNumber identifies a record by its epoch and its sequence number
// within that epoch. It is the unit parsed out of ACK records by checkACK.
type DTLSRecordNumber struct {
	// Store the Epoch as a uint64, so that tests can send ACKs for epochs that
	// the shim would never use.
	Epoch uint64
	Sequence uint64
}
| |
// A DTLSRecordNumberInfo contains information about a record received from the
// shim, which we may attempt to ACK.
//
// The covered range is expressed as two (message sequence, byte offset)
// positions; see makeDTLSRecordNumberInfo for how they are derived from the
// fragments in the record.
type DTLSRecordNumberInfo struct {
	DTLSRecordNumber
	// The first byte covered by this record, inclusive. We only need to store
	// one range because we require that the shim arrange fragments in order.
	// Any gaps will have been previously-ACKed data, so there is no harm in
	// double-ACKing.
	MessageStartSequence uint16
	MessageStartOffset int
	// The last byte covered by this record, exclusive.
	MessageEndSequence uint16
	MessageEndOffset int
}
| |
| func (r *DTLSRecordNumberInfo) HasACKInformation() bool { |
| return comparePair(r.MessageStartSequence, r.MessageStartOffset, r.MessageEndSequence, r.MessageEndOffset) < 0 |
| } |
| |
| func (c *Conn) readDTLS13RecordHeader(epoch *epochState, b []byte) (headerLen int, recordLen int, recTyp recordType, err error) { |
| // The DTLS 1.3 record header starts with the type byte containing |
| // 0b001CSLEE, where C, S, L, and EE are bits with the following |
| // meanings: |
| // |
| // C=1: Connection ID is present (C=0: CID is absent) |
| // S=1: the sequence number is 16 bits (S=0: it is 8 bits) |
| // L=1: 16-bit length field is present (L=0: record goes to end of packet) |
| // EE: low two bits of the epoch. |
| // |
| // A real DTLS implementation would parse these bits and take |
| // appropriate action based on them. However, this is a test |
| // implementation, and the code we are testing only ever sends C=0, S=1, |
| // L=1. This code expects those bits to be set and will error if |
| // anything else is set. This means we expect the type byte to look like |
| // 0b001011EE, or 0x2c-0x2f. |
| recordHeaderLen := 5 |
| if len(b) < recordHeaderLen { |
| return 0, 0, 0, errors.New("dtls: failed to read record header") |
| } |
| typ := b[0] |
| if typ&0xfc != 0x2c { |
| return 0, 0, 0, errors.New("dtls: DTLS 1.3 record header has bad type byte") |
| } |
| // For test purposes, require the epoch received be the same as the |
| // epoch we expect to receive. |
| epochBits := typ & 0x03 |
| if epochBits != byte(epoch.epoch&0x03) { |
| c.sendAlert(alertIllegalParameter) |
| return 0, 0, 0, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch")) |
| } |
| wireSeq := b[1:3] |
| if !c.config.Bugs.NullAllCiphers { |
| sample := b[recordHeaderLen:] |
| mask := epoch.recordNumberEncrypter.generateMask(sample) |
| xorSlice(wireSeq, mask) |
| } |
| decWireSeq := binary.BigEndian.Uint16(wireSeq) |
| // Reconstruct the sequence number from the low 16 bits on the wire. |
| // A real implementation would compute the full sequence number that is |
| // closest to the highest successfully decrypted record in the |
| // identified epoch. Since this test implementation errors on decryption |
| // failures instead of simply discarding packets, it reconstructs a |
| // sequence number that is not less than c.in.seq. (This matches the |
| // behavior of the check of the sequence number in the old record |
| // header format.) |
| seqInt := binary.BigEndian.Uint64(epoch.seq[:]) |
| // epoch.seq has the epoch in the upper two bytes - clear those. |
| seqInt = seqInt &^ (0xffff << 48) |
| newSeq := seqInt&^0xffff | uint64(decWireSeq) |
| if newSeq < seqInt { |
| newSeq += 0x10000 |
| } |
| |
| seq := make([]byte, 8) |
| binary.BigEndian.PutUint64(seq, newSeq) |
| copy(epoch.seq[2:], seq[2:]) |
| |
| recordLen = int(b[3])<<8 | int(b[4]) |
| return recordHeaderLen, recordLen, 0, nil |
| } |
| |
| // readDTLSRecordHeader reads the record header from the input. Based on the |
| // header it reads, it checks the header's validity and sets appropriate state |
| // as needed. This function returns the record header and the record type |
| // indicated in the header (if it contains the type). The connection's internal |
| // sequence number is updated to the value from the header. |
| func (c *Conn) readDTLSRecordHeader(epoch *epochState, b []byte) (headerLen int, recordLen int, typ recordType, err error) { |
| if epoch.cipher != nil && c.in.version >= VersionTLS13 { |
| return c.readDTLS13RecordHeader(epoch, b) |
| } |
| |
| recordHeaderLen := 13 |
| // Read out one record. |
| // |
| // A real DTLS implementation should be tolerant of errors, |
| // but this is test code. We should not be tolerant of our |
| // peer sending garbage. |
| if len(b) < recordHeaderLen { |
| return 0, 0, 0, errors.New("dtls: failed to read record header") |
| } |
| typ = recordType(b[0]) |
| vers := uint16(b[1])<<8 | uint16(b[2]) |
| // Alerts sent near version negotiation do not have a well-defined |
| // record-layer version prior to TLS 1.3. (In TLS 1.3, the record-layer |
| // version is irrelevant.) Additionally, if we're reading a retransmission, |
| // the peer may not know the version yet. |
| if typ != recordTypeAlert && !c.skipRecordVersionCheck { |
| if c.haveVers { |
| wireVersion := c.wireVersion |
| if c.vers >= VersionTLS13 { |
| wireVersion = VersionDTLS12 |
| } |
| if vers != wireVersion { |
| c.sendAlert(alertProtocolVersion) |
| return 0, 0, 0, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.wireVersion)) |
| } |
| } else { |
| if expect := c.config.Bugs.ExpectInitialRecordVersion; expect != 0 && vers != expect { |
| c.sendAlert(alertProtocolVersion) |
| return 0, 0, 0, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect)) |
| } |
| } |
| } |
| epochValue := binary.BigEndian.Uint16(b[3:5]) |
| seq := b[5:11] |
| // For test purposes, require the sequence number be monotonically |
| // increasing, so c.in includes the minimum next sequence number. Gaps |
| // may occur if packets failed to be sent out. A real implementation |
| // would maintain a replay window and such. |
| if epochValue != epoch.epoch { |
| c.sendAlert(alertIllegalParameter) |
| return 0, 0, 0, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch")) |
| } |
| if bytes.Compare(seq, epoch.seq[2:]) < 0 { |
| c.sendAlert(alertIllegalParameter) |
| return 0, 0, 0, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number")) |
| } |
| copy(epoch.seq[2:], seq) |
| recordLen = int(b[11])<<8 | int(b[12]) |
| return recordHeaderLen, recordLen, typ, nil |
| } |
| |
// dtlsDoReadRecord consumes one record from the buffered packet, first reading
// a new packet from c.conn if the buffer is empty. The header is parsed with
// readDTLSRecordHeader and the record is passed to c.in.decrypt. For handshake
// and ChangeCipherSpec records it additionally checks that the shim packed
// records and packets as tightly as it could, and remembers how much room was
// left in c.lastRecordInFlight. Any other record type clears
// c.lastRecordInFlight.
//
// It returns the record type and the data returned by decrypt. The want
// parameter is not consulted in this function.
func (c *Conn) dtlsDoReadRecord(epoch *epochState, want recordType) (recordType, []byte, error) {
	// Read a new packet only if the current one is empty.
	var newPacket bool
	// Remember the spare room of the packet we are about to leave; it is
	// overwritten below when a new packet is read.
	bytesAvailableInLastPacket := c.bytesAvailableInPacket
	if c.rawInput.Len() == 0 {
		// Pick some absurdly large buffer size.
		c.rawInput.Grow(maxCiphertext + dtlsMaxRecordHeaderLen)
		buf := c.rawInput.AvailableBuffer()
		n, err := c.conn.Read(buf[:cap(buf)])
		if err != nil {
			return 0, nil, err
		}
		if c.maxPacketLen != 0 {
			if n > c.maxPacketLen {
				return 0, nil, fmt.Errorf("dtls: exceeded maximum packet length")
			}
			// Track how many more bytes the shim could have put in this
			// packet.
			c.bytesAvailableInPacket = c.maxPacketLen - n
		} else {
			c.bytesAvailableInPacket = 0
		}
		c.rawInput.Write(buf[:n])
		newPacket = true
	}

	// Consume the next record from the buffer.
	recordHeaderLen, n, typ, err := c.readDTLSRecordHeader(epoch, c.rawInput.Bytes())
	if err != nil {
		return 0, nil, err
	}
	// Reject records that are too large or that claim more bytes than the
	// packet holds.
	if n > maxCiphertext || c.rawInput.Len() < recordHeaderLen+n {
		c.sendAlert(alertRecordOverflow)
		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: oversized record received with length %d", n))
	}
	b := c.rawInput.Next(recordHeaderLen + n)

	// Process message.
	ok, encTyp, data, alertValue := c.in.decrypt(epoch, recordHeaderLen, b)
	if !ok {
		// A real DTLS implementation would silently ignore bad records,
		// but we want to notice errors from the implementation under
		// test.
		return 0, nil, c.in.setErrorLocked(c.sendAlert(alertValue))
	}

	if typ == 0 {
		// readDTLSRecordHeader sets typ=0 when decoding the DTLS 1.3
		// record header. When the new record header format is used, the
		// type is returned by decrypt() in encTyp.
		typ = encTyp
	}

	if typ == recordTypeChangeCipherSpec || typ == recordTypeHandshake {
		// If this is not the first record in the flight, check if it was packed
		// efficiently.
		if c.lastRecordInFlight != nil {
			// 12-byte header + 1-byte fragment is the minimum to make progress.
			const handshakeBytesNeeded = 13
			if typ == recordTypeHandshake && c.lastRecordInFlight.typ == recordTypeHandshake && epoch.epoch == c.lastRecordInFlight.epoch {
				// The previous record was compatible with this one. The shim
				// should have fit more in this record before making a new one.
				// TODO(crbug.com/374991962): Enforce this for plaintext records
				// too.
				if c.lastRecordInFlight.bytesAvailable > handshakeBytesNeeded && epoch.epoch > 0 {
					return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: previous handshake record had %d bytes available, but shim did not fit another fragment in it", c.lastRecordInFlight.bytesAvailable))
				}
			} else if newPacket {
				// The shim had to make a new record, but it did not need to
				// make a new packet if this record fit in the previous.
				bytesNeeded := 1
				if typ == recordTypeHandshake {
					bytesNeeded = handshakeBytesNeeded
				}
				// Account for the record header and the encryption overhead
				// of the smallest useful record.
				bytesNeeded += recordHeaderLen + c.in.maxEncryptOverhead(epoch, bytesNeeded)
				if bytesNeeded < bytesAvailableInLastPacket {
					return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: previous packet had %d bytes available, but shim did not fit record of type %d into it", bytesAvailableInLastPacket, typ))
				}
			}
		}

		// Save information about the current record, including how many more
		// bytes the shim could have added.
		recordBytesAvailable := c.bytesAvailableInPacket + c.rawInput.Len()
		if cbc, ok := epoch.cipher.(*cbcMode); ok {
			// It is possible that adding a byte would have added another block.
			recordBytesAvailable = max(0, recordBytesAvailable-cbc.BlockSize())
		}
		c.lastRecordInFlight = &dtlsRecordInfo{typ: typ, epoch: epoch.epoch, bytesAvailable: recordBytesAvailable}
	} else {
		c.lastRecordInFlight = nil
	}

	return typ, data, nil
}
| |
| func (c *Conn) dtlsWriteRecord(typ recordType, data []byte) (n int, err error) { |
| epoch := &c.out.epoch |
| |
| // Outgoing DTLS records are buffered in several stages, to test the various |
| // layers that data may be combined in. |
| // |
| // First, handshake and ChangeCipherSpec records are buffered in |
| // c.nextFlight, to be flushed by dtlsWriteFlight. |
| // |
| // dtlsWriteFlight, with the aid of a test-supplied callback, will divide |
| // them into handshake records containing fragments, possibly with some |
| // rounds of shim retransmit tests. Those records and any other |
| // non-handshake application data records are encrypted by dtlsPackRecord |
| // into c.pendingPacket, which may combine multiple records into one packet. |
| // |
| // Finally, dtlsFlushPacket writes the packet to the shim. |
| |
| if typ == recordTypeChangeCipherSpec { |
| // Don't send ChangeCipherSpec in DTLS 1.3. |
| // TODO(crbug.com/42290594): Add an option to send them anyway and test |
| // what our implementation does with unexpected ones. |
| if c.vers >= VersionTLS13 { |
| return |
| } |
| c.nextFlight = append(c.nextFlight, DTLSMessage{ |
| Epoch: epoch.epoch, |
| IsChangeCipherSpec: true, |
| Data: slices.Clone(data), |
| }) |
| err = c.out.changeCipherSpec() |
| if err != nil { |
| return n, c.sendAlertLocked(alertLevelError, err.(alert)) |
| } |
| return len(data), nil |
| } |
| if typ == recordTypeHandshake { |
| // Handshake messages have to be modified to include fragment |
| // offset and length and with the header replicated. Save the |
| // TLS header here. |
| header := data[:4] |
| body := data[4:] |
| |
| c.nextFlight = append(c.nextFlight, DTLSMessage{ |
| Epoch: epoch.epoch, |
| Sequence: c.sendHandshakeSeq, |
| Type: header[0], |
| Data: slices.Clone(body), |
| }) |
| c.sendHandshakeSeq++ |
| return len(data), nil |
| } |
| |
| // Flush any packets buffered from the handshake. |
| err = c.dtlsWriteFlight() |
| if err != nil { |
| return |
| } |
| |
| if typ == recordTypeApplicationData && len(data) > 1 && c.config.Bugs.SplitAndPackAppData { |
| _, _, err = c.dtlsPackRecord(epoch, typ, data[:len(data)/2], false) |
| if err != nil { |
| return |
| } |
| _, _, err = c.dtlsPackRecord(epoch, typ, data[len(data)/2:], true) |
| if err != nil { |
| return |
| } |
| n = len(data) |
| } else { |
| n, _, err = c.dtlsPackRecord(epoch, typ, data, false) |
| if err != nil { |
| return |
| } |
| } |
| |
| err = c.dtlsFlushPacket() |
| return |
| } |
| |
| // dtlsWriteFlight packs the pending handshake flight into the pending record. |
| // Callers should follow up with dtlsFlushPacket to write the packets. |
| func (c *Conn) dtlsWriteFlight() error { |
| if len(c.nextFlight) == 0 { |
| return nil |
| } |
| |
| // Avoid re-entrancy issues by updating the state immediately. The callback |
| // may try to write records. |
| prev, received, next, records := c.previousFlight, c.receivedFlight, c.nextFlight, c.receivedFlightRecords |
| c.previousFlight, c.receivedFlight, c.nextFlight, c.receivedFlightRecords = next, nil, nil, nil |
| |
| controller := newDTLSController(c, received) |
| if c.config.Bugs.WriteFlightDTLS != nil { |
| c.config.Bugs.WriteFlightDTLS(&controller, prev, received, next, records) |
| } else { |
| controller.WriteFlight(next) |
| } |
| if err := controller.Err(); err != nil { |
| return err |
| } |
| |
| if c.receivedFlight != nil || c.receivedFlightRecords != nil || c.nextFlight != nil { |
| panic("tls: flight state changed while writing flight") |
| } |
| if controller.mergeIntoNextFlight { |
| c.previousFlight, c.receivedFlight, c.nextFlight, c.receivedFlightRecords = prev, received, next, records |
| } |
| |
| // Flush any ACKs, etc., we may have written. |
| return c.dtlsFlushPacket() |
| } |
| |
| func (c *Conn) dtlsFlushHandshake() error { |
| if err := c.dtlsWriteFlight(); err != nil { |
| return err |
| } |
| if err := c.dtlsFlushPacket(); err != nil { |
| return err |
| } |
| |
| return nil |
| } |
| |
| func (c *Conn) dtlsACKHandshake() error { |
| if len(c.receivedFlight) == 0 { |
| return nil |
| } |
| |
| if len(c.nextFlight) != 0 { |
| panic("tls: not a final flight; more messages were queued up") |
| } |
| |
| // Avoid re-entrancy issues by updating the state immediately. The callback |
| // may try to write records. |
| prev, received, records := c.previousFlight, c.receivedFlight, c.receivedFlightRecords |
| c.previousFlight, c.receivedFlight, c.receivedFlightRecords = nil, nil, nil |
| |
| controller := newDTLSController(c, received) |
| if c.config.Bugs.ACKFlightDTLS != nil { |
| c.config.Bugs.ACKFlightDTLS(&controller, prev, received, records) |
| } else { |
| if c.vers >= VersionTLS13 { |
| controller.WriteACK(controller.OutEpoch(), records) |
| } |
| } |
| if err := controller.Err(); err != nil { |
| return err |
| } |
| |
| if c.previousFlight != nil || c.receivedFlight != nil || c.receivedFlightRecords != nil { |
| panic("tls: flight state changed while ACKing flight") |
| } |
| if controller.mergeIntoNextFlight { |
| c.previousFlight, c.receivedFlight, c.receivedFlightRecords = prev, received, records |
| } |
| |
| // Flush any ACKs, etc., we may have written. |
| return c.dtlsFlushPacket() |
| } |
| |
| // appendDTLS13RecordHeader appends to b the record header for a record of length |
| // recordLen. |
| func (c *Conn) appendDTLS13RecordHeader(b, seq []byte, recordLen int) []byte { |
| // Set the top 3 bits on the type byte to indicate the DTLS 1.3 record |
| // header format. |
| typ := byte(0x20) |
| // Set the Connection ID bit |
| if c.config.Bugs.DTLS13RecordHeaderSetCIDBit && c.handshakeComplete { |
| typ |= 0x10 |
| } |
| // Set the sequence number length bit |
| if !c.config.DTLSUseShortSeqNums { |
| typ |= 0x08 |
| } |
| // Set the length presence bit |
| if !c.config.DTLSRecordHeaderOmitLength { |
| typ |= 0x04 |
| } |
| // Set the epoch bits |
| typ |= seq[1] & 0x3 |
| b = append(b, typ) |
| if c.config.DTLSUseShortSeqNums { |
| b = append(b, seq[7]) |
| } else { |
| b = append(b, seq[6], seq[7]) |
| } |
| if !c.config.DTLSRecordHeaderOmitLength { |
| b = append(b, byte(recordLen>>8), byte(recordLen)) |
| } |
| return b |
| } |
| |
| // dtlsPackRecord packs a single record to the pending packet, flushing it |
| // if necessary. The caller should call dtlsFlushPacket to flush the current |
| // pending packet afterwards. |
| func (c *Conn) dtlsPackRecord(epoch *epochState, typ recordType, data []byte, mustPack bool) (n int, num DTLSRecordNumber, err error) { |
| maxLen := c.config.Bugs.MaxHandshakeRecordLength |
| if maxLen <= 0 { |
| maxLen = 1024 |
| } |
| |
| vers := c.wireVersion |
| if vers == 0 { |
| // Some TLS servers fail if the record version is greater than |
| // TLS 1.0 for the initial ClientHello. |
| if c.isDTLS { |
| vers = VersionDTLS10 |
| } else { |
| vers = VersionTLS10 |
| } |
| } |
| if c.vers >= VersionTLS13 || c.out.version >= VersionTLS13 { |
| vers = VersionDTLS12 |
| } |
| |
| useDTLS13RecordHeader := c.out.version >= VersionTLS13 && epoch.cipher != nil && !c.useDTLSPlaintextHeader() |
| headerHasLength := true |
| record := make([]byte, 0, dtlsMaxRecordHeaderLen+len(data)+c.out.maxEncryptOverhead(epoch, len(data))) |
| seq := c.out.sequenceNumberForOutput(epoch) |
| if useDTLS13RecordHeader { |
| record = c.appendDTLS13RecordHeader(record, seq, len(data)) |
| headerHasLength = !c.config.DTLSRecordHeaderOmitLength |
| } else { |
| record = append(record, byte(typ)) |
| record = append(record, byte(vers>>8)) |
| record = append(record, byte(vers)) |
| // DTLS records include an explicit sequence number. |
| record = append(record, seq...) |
| record = append(record, byte(len(data)>>8)) |
| record = append(record, byte(len(data))) |
| } |
| |
| recordHeaderLen := len(record) |
| record, err = c.out.encrypt(epoch, record, data, typ, recordHeaderLen, headerHasLength) |
| if err != nil { |
| return |
| } |
| num = c.out.lastRecordNumber(epoch, true /* isOut */) |
| |
| // Encrypt the sequence number. |
| if useDTLS13RecordHeader && !c.config.Bugs.NullAllCiphers { |
| sample := record[recordHeaderLen:] |
| mask := epoch.recordNumberEncrypter.generateMask(sample) |
| seqLen := 2 |
| if c.config.DTLSUseShortSeqNums { |
| seqLen = 1 |
| } |
| // The sequence number starts at index 1 in the record header. |
| xorSlice(record[1:1+seqLen], mask) |
| } |
| |
| // Flush the current pending packet if necessary. |
| if !mustPack && len(record)+len(c.pendingPacket) > c.config.Bugs.PackHandshakeRecords { |
| err = c.dtlsFlushPacket() |
| if err != nil { |
| return |
| } |
| } |
| |
| // Add the record to the pending packet. |
| c.pendingPacket = append(c.pendingPacket, record...) |
| if c.config.DTLSRecordHeaderOmitLength { |
| if c.config.Bugs.SplitAndPackAppData { |
| panic("incompatible config") |
| } |
| err = c.dtlsFlushPacket() |
| if err != nil { |
| return |
| } |
| } |
| n = len(data) |
| return |
| } |
| |
| func (c *Conn) dtlsFlushPacket() error { |
| if c.hand.Len() == 0 { |
| c.lastRecordInFlight = nil |
| } |
| if len(c.pendingPacket) == 0 { |
| return nil |
| } |
| _, err := c.conn.Write(c.pendingPacket) |
| c.pendingPacket = nil |
| return err |
| } |
| |
| func readDTLSFragment(s *cryptobyte.String) (DTLSFragment, error) { |
| var f DTLSFragment |
| var totLen, fragOffset uint32 |
| if !s.ReadUint8(&f.Type) || |
| !s.ReadUint24(&totLen) || |
| !s.ReadUint16(&f.Sequence) || |
| !s.ReadUint24(&fragOffset) || |
| !readUint24LengthPrefixedBytes(s, &f.Data) { |
| return DTLSFragment{}, errors.New("dtls: bad handshake record") |
| } |
| f.TotalLength = int(totLen) |
| f.Offset = int(fragOffset) |
| if f.Offset > f.TotalLength || len(f.Data) > f.TotalLength-f.Offset { |
| return DTLSFragment{}, errors.New("dtls: bad fragment offset") |
| } |
| // Although syntactically valid, the shim should never send empty fragments |
| // of non-empty messages. |
| if len(f.Data) == 0 && f.TotalLength != 0 { |
| return DTLSFragment{}, errors.New("dtls: fragment makes no progress") |
| } |
| return f, nil |
| } |
| |
| func (c *Conn) makeDTLSRecordNumberInfo(epoch *epochState, data []byte) (DTLSRecordNumberInfo, error) { |
| info := DTLSRecordNumberInfo{DTLSRecordNumber: c.in.lastRecordNumber(epoch, false /* isOut */)} |
| |
| s := cryptobyte.String(data) |
| first := true |
| for !s.Empty() { |
| f, err := readDTLSFragment(&s) |
| if err != nil { |
| return DTLSRecordNumberInfo{}, err |
| } |
| // This assumes the shim sent fragments in order. This isn't checked |
| // here, but the caller will check when processing the fragments. |
| if first { |
| info.MessageStartSequence = f.Sequence |
| info.MessageStartOffset = f.Offset |
| first = false |
| } |
| info.MessageEndSequence = f.Sequence |
| info.MessageEndOffset = f.Offset + len(f.Data) |
| } |
| |
| return info, nil |
| } |
| |
// dtlsDoReadHandshake reassembles and returns the next complete handshake
// message, reading further handshake records as needed. The returned bytes are
// the four-byte TLS handshake header (type and 24-bit length) followed by the
// message body, without the DTLS-specific header fields. The message is also
// appended to c.receivedFlight.
//
// Fragments must arrive in order, belong to the expected message sequence
// number, and agree on the total length; anything else is an error.
func (c *Conn) dtlsDoReadHandshake() ([]byte, error) {
	// Assemble a full handshake message. For test purposes, this
	// implementation assumes fragments arrive in order. It may
	// need to be cleverer if we ever test BoringSSL's retransmit
	// behavior.
	for len(c.handMsg) < 4+c.handMsgLen {
		// Get a new handshake record if the previous has been
		// exhausted.
		if c.hand.Len() == 0 {
			if err := c.in.err; err != nil {
				return nil, err
			}
			if err := c.readRecord(recordTypeHandshake); err != nil {
				return nil, err
			}
		}

		// Read the next fragment. It must fit entirely within
		// the record.
		s := cryptobyte.String(c.hand.Bytes())
		f, err := readDTLSFragment(&s)
		if err != nil {
			return nil, err
		}
		// Drop the bytes readDTLSFragment consumed from the buffer.
		c.hand.Next(c.hand.Len() - len(s))

		// Check it's a fragment for the right message.
		if f.Sequence != c.recvHandshakeSeq {
			return nil, errors.New("dtls: bad handshake sequence number")
		}

		// Check that the length is consistent.
		if c.handMsg == nil {
			c.handMsgLen = f.TotalLength
			if c.handMsgLen > maxHandshake {
				return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
			}
			// Start with the TLS handshake header,
			// without the DTLS bits.
			c.handMsg = []byte{f.Type, byte(f.TotalLength >> 16), byte(f.TotalLength >> 8), byte(f.TotalLength)}
		} else if f.TotalLength != c.handMsgLen {
			return nil, errors.New("dtls: bad handshake length")
		}

		// Add the fragment to the pending message. It must start exactly
		// where the data assembled so far ends.
		if 4+f.Offset != len(c.handMsg) {
			return nil, errors.New("dtls: bad fragment offset")
		}
		c.handMsg = append(c.handMsg, f.Data...)
		// If the message isn't complete, check that the peer could not have
		// fit more into the record.
		if len(c.handMsg) < 4+c.handMsgLen {
			if c.hand.Len() != 0 {
				return nil, errors.New("dtls: truncated handshake fragment was not last in the record")
			}
			// NOTE(review): this dereferences c.lastRecordInFlight without a
			// nil check; presumably it is always set after a handshake
			// record has been read — confirm.
			if c.lastRecordInFlight.bytesAvailable > 0 {
				return nil, fmt.Errorf("dtls: handshake fragment was truncated, but record could have fit %d more bytes", c.lastRecordInFlight.bytesAvailable)
			}
		}

		// Sending part of the next flight implicitly ACKs the previous flight.
		// Having triggered this, the shim is expected to clear its ACK buffer.
		c.expectedACK = nil
	}
	// The message is complete: advance the expected sequence number, reset
	// the reassembly state and record the message in the received flight.
	c.recvHandshakeSeq++
	ret := c.handMsg
	c.handMsg, c.handMsgLen = nil, 0
	c.receivedFlight = append(c.receivedFlight, DTLSMessage{
		Epoch: c.in.epoch.epoch,
		Type: ret[0],
		Sequence: c.recvHandshakeSeq - 1,
		Data: ret[4:],
	})
	return ret, nil
}
| |
| func (c *Conn) checkACK(data []byte) error { |
| s := cryptobyte.String(data) |
| var child cryptobyte.String |
| if !s.ReadUint16LengthPrefixed(&child) || !s.Empty() { |
| return fmt.Errorf("tls: could not parse ACK record") |
| } |
| |
| var acks []DTLSRecordNumber |
| for !child.Empty() { |
| var num DTLSRecordNumber |
| if !child.ReadUint64(&num.Epoch) || !child.ReadUint64(&num.Sequence) { |
| return fmt.Errorf("tls: could not parse ACK record") |
| } |
| acks = append(acks, num) |
| } |
| |
| // Determine the expected ACKs, if any. |
| expected := c.expectedACK |
| if len(expected) > shimConfig.MaxACKBuffer { |
| expected = expected[len(expected)-shimConfig.MaxACKBuffer:] |
| } |
| |
| // If we've configured a tighter MTU, the shim might have needed to truncate |
| // the list. Tolerate this as long as the shim sent the more recent records |
| // and still sent a plausible minimum number of ACKs. |
| if c.maxPacketLen != 0 && len(acks) > 10 && len(acks) < len(expected) { |
| expected = expected[len(expected)-len(acks):] |
| } |
| |
| // The shim is expected to sort the record numbers in the ACK. |
| expected = slices.Clone(expected) |
| slices.SortFunc(expected, func(a, b DTLSRecordNumber) int { |
| cmp1 := cmp.Compare(a.Epoch, b.Epoch) |
| if cmp1 != 0 { |
| return cmp1 |
| } |
| return cmp.Compare(a.Sequence, b.Sequence) |
| }) |
| |
| if !slices.Equal(acks, expected) { |
| return fmt.Errorf("tls: got ACKs %+v, but expected %+v", acks, expected) |
| } |
| |
| return nil |
| } |
| |
| // DTLSServer returns a new DTLS server side connection |
| // using conn as the underlying transport. |
| // The configuration config must be non-nil and must have |
| // at least one certificate. |
| func DTLSServer(conn net.Conn, config *Config) *Conn { |
| c := &Conn{config: config, isDTLS: true, conn: conn} |
| c.init() |
| return c |
| } |
| |
| // DTLSClient returns a new DTLS client side connection |
| // using conn as the underlying transport. |
| // The config cannot be nil: users must set either ServerHostname or |
| // InsecureSkipVerify in the config. |
| func DTLSClient(conn net.Conn, config *Config) *Conn { |
| c := &Conn{config: config, isClient: true, isDTLS: true, conn: conn} |
| c.init() |
| return c |
| } |
| |
// WriteFlightFunc is the shape of the callback that dtlsWriteFlight invokes, in
// place of the default behavior, to send the flight next. prev and received
// are the previously sent and most recently received flights, and records
// describes the records of the received flight. See DTLSController.
type WriteFlightFunc = func(c *DTLSController, prev, received, next []DTLSMessage, records []DTLSRecordNumberInfo)
// ACKFlightFunc is the shape of the callback that dtlsACKHandshake invokes, in
// place of the default behavior, when the runner has no flight of its own to
// send in reply. See DTLSController.
type ACKFlightFunc = func(c *DTLSController, prev, received []DTLSMessage, records []DTLSRecordNumberInfo)
| |
| // A DTLSController is passed to a test callback and allows the callback to |
| // customize how an individual flight is sent. This is used to test DTLS's |
| // retransmission logic. |
| // |
| // Although DTLS runs over a lossy, reordered channel, runner assumes a |
| // reliable, ordered channel. When simulating packet loss, runner processes the |
| // shim's "lost" flight as usual. But, instead of responding, it calls a |
| // test-provided function type WriteFlightFunc. |
| // |
| // WriteFlight will be called next as the flight for the runner to send. prev is |
| // the previous flight sent by the runner, and received is the most recent |
| // flight received by the shim. prev and received may be nil if those flights do |
| // not exist. |
| // |
| // WriteFlight should send next to the shim, by calling methods on the |
| // DTLSController, and then return. The shim will then be expected to progress |
| // the connection. However, WriteFlight, may send fragments arbitrarily |
| // reordered or duplicated. It may also simulate packet loss with timeouts or |
| // retransmitted past fragments, and then test that the shim retransmits. |
| // |
| // WriteFlight must return as soon as the shim is expected to progress the |
| // connection. If WriteFlight expects the shim to send an alert, it must also |
| // return, at which point the logic to progress the connection will consume the |
| // alert and report it as a connection failure, to be captured in the test |
| // expectation. |
| // |
| // If unspecified, the default implementation of WriteFlight is: |
| // |
| // func WriteFlight(c *DTLSController, prev, received, next []DTLSMessage, records []DTLSRecordNumberInfo) { |
| // c.WriteFlight(next) |
| // } |
| // |
| // When the shim speaks last in a handshake or post-handshake transaction, there |
| // is no reply to implicitly acknowledge the flight. The runner will instead |
| // call a second callback type ACKFlightFunc. |
| // |
| // Like WriteFlight, ACKFlight may simulate packet loss with the DTLSController. |
| // It returns when it is ready to proceed. If not specified, it does nothing in |
| // DTLS 1.2 and ACKs the final flight in DTLS 1.3. |
| // |
| // This test design implicitly assumes the shim will never start a |
| // post-handshake transaction before the previous one is complete. Otherwise the |
| // retransmissions will get mixed up with the second transaction. |
| // |
| // For convenience, the DTLSController internally tracks whether it has |
| // encountered an error (e.g. an I/O error with the shim) and, if so, silently |
| // makes all methods do nothing. The Err method may be used to query if it is in |
| // this state, if it would otherwise cause an infinite loop. |
type DTLSController struct {
	// conn is the underlying connection used to exchange records with the
	// shim.
	conn *Conn
	// err is the error that stopped the controller, if any. It is reported by
	// Err, and the I/O methods return without doing anything once it is set.
	err error
	// retransmitNeeded contains the list of fragments which the shim must
	// retransmit.
	retransmitNeeded []DTLSFragment
	// mergeIntoNextFlight is set by MergeIntoNextFlight.
	mergeIntoNextFlight bool
}
| |
| func newDTLSController(conn *Conn, received []DTLSMessage) DTLSController { |
| var retransmitNeeded []DTLSFragment |
| for i := range received { |
| msg := &received[i] |
| retransmitNeeded = append(retransmitNeeded, msg.Fragment(0, len(msg.Data))) |
| } |
| |
| return DTLSController{conn: conn, retransmitNeeded: retransmitNeeded} |
| } |
| |
| func (c *DTLSController) getOutEpochOrPanic(epochValue uint16) *epochState { |
| epoch, ok := c.conn.out.getEpoch(epochValue) |
| if !ok { |
| panic(fmt.Sprintf("tls: could not find epoch %d", epochValue)) |
| } |
| return epoch |
| } |
| |
| func (c *DTLSController) getInEpochOrPanic(epochValue uint16) *epochState { |
| epoch, ok := c.conn.in.getEpoch(epochValue) |
| if !ok { |
| panic(fmt.Sprintf("tls: could not find epoch %d", epochValue)) |
| } |
| return epoch |
| } |
| |
| // Err returns whether the controller has stopped due to an error, or nil |
| // otherwise. If it returns non-nil, other methods will silently do nothing. |
| func (c *DTLSController) Err() error { return c.err } |
| |
| // OutEpoch returns the current outgoing epoch. |
| func (c *DTLSController) OutEpoch() uint16 { |
| return c.conn.out.epoch.epoch |
| } |
| |
| // InEpoch returns the current incoming epoch. |
| func (c *DTLSController) InEpoch() uint16 { |
| return c.conn.in.epoch.epoch |
| } |
| |
| // AdvanceClock advances the shim's clock by duration. It is a test failure if |
| // the shim sends anything before picking up the command. |
| func (c *DTLSController) AdvanceClock(duration time.Duration) { |
| if c.err != nil { |
| return |
| } |
| |
| c.err = c.conn.dtlsFlushPacket() |
| if c.err != nil { |
| return |
| } |
| |
| adaptor := c.conn.config.Bugs.PacketAdaptor |
| if adaptor == nil { |
| panic("tls: no PacketAdapter set") |
| } |
| |
| received, err := adaptor.SendReadTimeout(duration) |
| if err != nil { |
| c.err = err |
| } else if len(received) != 0 { |
| c.err = fmt.Errorf("tls: received %d unexpected packets while simulating a timeout", len(received)) |
| } |
| } |
| |
| // SetMTU updates the shim's MTU to mtu. |
| func (c *DTLSController) SetMTU(mtu int) { |
| if c.err != nil { |
| return |
| } |
| |
| adaptor := c.conn.config.Bugs.PacketAdaptor |
| if adaptor == nil { |
| panic("tls: no PacketAdapter set") |
| } |
| |
| c.conn.maxPacketLen = mtu |
| c.err = adaptor.SetPeerMTU(mtu) |
| } |
| |
| // WriteFlight writes msgs to the shim, using the default fragmenting logic. |
| // This may be used when the test is not concerned with fragmentation. |
| func (c *DTLSController) WriteFlight(msgs []DTLSMessage) { |
| config := c.conn.config |
| if c.err != nil { |
| return |
| } |
| |
| // Buffer up fragments to reorder them. |
| var fragments []DTLSFragment |
| |
| // TODO(davidben): All this could also have been implemented in the custom |
| // fallbacks. These options date to before we had the callback. Should some |
| // of them be moved out? |
| for _, msg := range msgs { |
| if msg.IsChangeCipherSpec { |
| fragments = append(fragments, msg.Fragment(0, len(msg.Data))) |
| continue |
| } |
| |
| maxLen := config.Bugs.MaxHandshakeRecordLength |
| if maxLen <= 0 { |
| maxLen = 1024 |
| } |
| |
| if config.Bugs.SendEmptyFragments { |
| fragments = append(fragments, msg.Fragment(0, 0)) |
| fragments = append(fragments, msg.Fragment(len(msg.Data), 0)) |
| } |
| |
| firstRun := true |
| fragOffset := 0 |
| for firstRun || fragOffset < len(msg.Data) { |
| firstRun = false |
| fragLen := min(len(msg.Data)-fragOffset, maxLen) |
| |
| fragment := msg.Fragment(fragOffset, fragLen) |
| fragments = append(fragments, fragment) |
| if config.Bugs.ReorderHandshakeFragments { |
| // Don't duplicate Finished to avoid the peer |
| // interpreting it as a retransmit request. |
| if msg.Type != typeFinished { |
| fragments = append(fragments, fragment) |
| } |
| |
| if fragLen > (maxLen+1)/2 { |
| // Overlap each fragment by half. |
| fragLen = (maxLen + 1) / 2 |
| } |
| } |
| fragOffset += fragLen |
| } |
| if config.Bugs.MixCompleteMessageWithFragments { |
| fragments = append(fragments, msg.Fragment(0, len(msg.Data))) |
| } |
| } |
| |
| // Reorder the fragments, but only within an epoch. |
| for start := 0; start < len(fragments); { |
| end := start + 1 |
| for end < len(fragments) && fragments[start].Epoch == fragments[end].Epoch { |
| end++ |
| } |
| chunk := fragments[start:end] |
| if config.Bugs.ReorderHandshakeFragments { |
| rand.Shuffle(len(chunk), func(i, j int) { chunk[i], chunk[j] = chunk[j], chunk[i] }) |
| } |
| start = end |
| } |
| |
| c.WriteFragments(fragments) |
| } |
| |
| // WriteFragments writes the specified handshake fragments to the shim. |
| func (c *DTLSController) WriteFragments(fragments []DTLSFragment) { |
| config := c.conn.config |
| if c.err != nil { |
| return |
| } |
| |
| maxRecordLen := config.Bugs.PackHandshakeFragments |
| packRecord := func(epoch *epochState, typ recordType, data []byte, anyDiscard bool) error { |
| _, num, err := c.conn.dtlsPackRecord(epoch, typ, data, false) |
| if err != nil { |
| return err |
| } |
| if !anyDiscard && typ == recordTypeHandshake { |
| c.conn.expectedACK = append(c.conn.expectedACK, num) |
| } |
| return nil |
| } |
| |
| // Pack handshake fragments into records. |
| var record []byte |
| var epoch *epochState |
| var anyDiscard bool |
| flush := func() error { |
| if len(record) > 0 { |
| if err := packRecord(epoch, recordTypeHandshake, record, anyDiscard); err != nil { |
| return err |
| } |
| } |
| record = nil |
| anyDiscard = false |
| return nil |
| } |
| |
| for i := range fragments { |
| f := &fragments[i] |
| if epoch != nil && (f.Epoch != epoch.epoch || f.IsChangeCipherSpec) { |
| c.err = flush() |
| if c.err != nil { |
| return |
| } |
| epoch = nil |
| } |
| |
| if epoch == nil { |
| epoch = c.getOutEpochOrPanic(f.Epoch) |
| } |
| |
| if f.IsChangeCipherSpec { |
| c.err = packRecord(epoch, recordTypeChangeCipherSpec, f.Bytes(), false) |
| if c.err != nil { |
| return |
| } |
| continue |
| } |
| |
| if f.ShouldDiscard { |
| anyDiscard = true |
| } |
| |
| fBytes := f.Bytes() |
| if n := config.Bugs.SplitFragments; n > 0 { |
| if len(fBytes) > n { |
| c.err = packRecord(epoch, recordTypeHandshake, fBytes[:n], f.ShouldDiscard) |
| if c.err != nil { |
| return |
| } |
| c.err = packRecord(epoch, recordTypeHandshake, fBytes[n:], f.ShouldDiscard) |
| if c.err != nil { |
| return |
| } |
| } else { |
| c.err = packRecord(epoch, recordTypeHandshake, fBytes, f.ShouldDiscard) |
| if c.err != nil { |
| return |
| } |
| } |
| } else { |
| if len(record)+len(fBytes) > maxRecordLen { |
| c.err = flush() |
| if c.err != nil { |
| return |
| } |
| } |
| record = append(record, fBytes...) |
| } |
| } |
| |
| c.err = flush() |
| } |
| |
| // WriteACK writes the specified record numbers in an ACK record to the shim, |
| // and updates shim expectations according to the specified byte ranges. To send |
| // an ACK which the shim is expected to ignore (e.g. because it should have |
| // forgotten a packet number), use a DTLSRecordNumberInfo with the |
| // MessageStartSequence, etc., fields all set to zero. |
| func (c *DTLSController) WriteACK(epoch uint16, records []DTLSRecordNumberInfo) { |
| if c.err != nil { |
| return |
| } |
| |
| // Send the ACK. |
| ack := cryptobyte.NewBuilder(make([]byte, 0, 2+8*len(records))) |
| ack.AddUint16LengthPrefixed(func(recordNumbers *cryptobyte.Builder) { |
| for _, r := range records { |
| recordNumbers.AddUint64(r.Epoch) |
| recordNumbers.AddUint64(r.Sequence) |
| } |
| }) |
| _, _, c.err = c.conn.dtlsPackRecord(c.getOutEpochOrPanic(epoch), recordTypeACK, ack.BytesOrPanic(), false) |
| if c.err != nil { |
| return |
| } |
| |
| // Update the list of expectations. This is inefficient, but is fine for |
| // test code. |
| for _, r := range records { |
| if !r.HasACKInformation() { |
| continue |
| } |
| var update []DTLSFragment |
| for _, f := range c.retransmitNeeded { |
| endOffset := f.Offset + len(f.Data) |
| // Compute two, possibly empty, intersections: the fragment with |
| // [0, ackStart) and the fragment with [ackStart, infinity). |
| |
| // First, the portion of the fragment that is before the ACK: |
| if comparePair(f.Sequence, f.Offset, r.MessageStartSequence, r.MessageStartOffset) < 0 { |
| // The fragment begins before the ACK. |
| if comparePair(f.Sequence, endOffset, r.MessageStartSequence, r.MessageStartOffset) <= 0 { |
| // The fragment ends before the ACK. |
| update = append(update, f) |
| } else { |
| // The ACK starts in the middle of the fragment. Retain a |
| // prefix of the fragment. |
| prefix := f |
| prefix.Data = f.Data[:r.MessageStartOffset-f.Offset] |
| update = append(update, prefix) |
| } |
| } |
| |
| // Next, the portion of the fragment that is after the ACK: |
| if comparePair(r.MessageEndSequence, r.MessageEndOffset, f.Sequence, endOffset) < 0 { |
| // The fragment ends after the ACK. |
| if comparePair(r.MessageEndSequence, r.MessageEndOffset, f.Sequence, f.Offset) <= 0 { |
| // The fragment begins after the ACK. |
| update = append(update, f) |
| } else { |
| // The ACK ends in the middle of the fragment. Retain a |
| // suffix of the fragment. |
| suffix := f |
| suffix.Offset = r.MessageEndOffset |
| suffix.Data = f.Data[r.MessageEndOffset-f.Offset:] |
| update = append(update, suffix) |
| } |
| } |
| } |
| c.retransmitNeeded = update |
| } |
| } |
| |
| // ReadRetransmit indicates the shim is expected to retransmit its current |
| // flight and consumes the retransmission. It returns the record numbers of the |
| // retransmission, for the test to ACK if it chooses. |
| func (c *DTLSController) ReadRetransmit() []DTLSRecordNumberInfo { |
| if c.err != nil { |
| return nil |
| } |
| |
| var ret []DTLSRecordNumberInfo |
| ret, c.err = c.doReadRetransmit() |
| return ret |
| } |
| |
| func (c *DTLSController) doReadRetransmit() ([]DTLSRecordNumberInfo, error) { |
| if err := c.conn.dtlsFlushPacket(); err != nil { |
| return nil, err |
| } |
| |
| var records []DTLSRecordNumberInfo |
| expected := slices.Clone(c.retransmitNeeded) |
| for len(expected) > 0 { |
| // Read a record from the expected epoch. The peer should retransmit in |
| // order. |
| wantTyp := recordTypeHandshake |
| if expected[0].IsChangeCipherSpec { |
| wantTyp = recordTypeChangeCipherSpec |
| } |
| epoch := c.getInEpochOrPanic(expected[0].Epoch) |
| // Retransmitted ClientHellos predate the shim learning the version. |
| // Ideally we would enforce the initial record-layer version, but |
| // post-HelloVerifyRequest ClientHellos and post-HelloRetryRequest |
| // ClientHellos look the same, but have different expectations. |
| c.conn.skipRecordVersionCheck = !expected[0].IsChangeCipherSpec && expected[0].Type == typeClientHello |
| typ, data, err := c.conn.dtlsDoReadRecord(epoch, wantTyp) |
| c.conn.skipRecordVersionCheck = false |
| if err != nil { |
| return nil, err |
| } |
| if typ != wantTyp { |
| return nil, fmt.Errorf("tls: got record of type %d in retransmit, but expected %d", typ, wantTyp) |
| } |
| if typ == recordTypeChangeCipherSpec { |
| if len(data) != 1 || data[0] != 1 { |
| return nil, errors.New("tls: got invalid ChangeCipherSpec") |
| } |
| expected = expected[1:] |
| continue |
| } |
| |
| // Consume all the handshake fragments and match them to what we expect. |
| s := cryptobyte.String(data) |
| if s.Empty() { |
| return nil, fmt.Errorf("tls: got empty record in retransmit") |
| } |
| for !s.Empty() { |
| if len(expected) == 0 || expected[0].Epoch != epoch.epoch || expected[0].IsChangeCipherSpec { |
| return nil, fmt.Errorf("tls: got excess data at epoch %d in retransmit", epoch.epoch) |
| } |
| |
| exp := &expected[0] |
| var f DTLSFragment |
| f, err = readDTLSFragment(&s) |
| if f.Type != exp.Type || f.TotalLength != exp.TotalLength || f.Sequence != exp.Sequence || f.Offset != exp.Offset { |
| return nil, fmt.Errorf("tls: got offset %d of message %d (type %d, length %d), expected offset %d of message %d (type %d, length %d)", f.Offset, f.Sequence, f.Type, f.TotalLength, exp.Offset, exp.Sequence, exp.Type, exp.TotalLength) |
| } |
| if len(f.Data) > len(exp.Data) { |
| return nil, fmt.Errorf("tls: got %d bytes at offset %d of message %d but only %d bytes were missing", len(f.Data), f.Offset, f.Sequence, len(exp.Data)) |
| } |
| if !bytes.Equal(f.Data, exp.Data[:len(f.Data)]) { |
| return nil, fmt.Errorf("tls: got %d bytes at offset %d of message %d but did not match original", len(f.Data), f.Offset, f.Sequence) |
| } |
| if len(f.Data) == len(exp.Data) { |
| expected = expected[1:] |
| } else { |
| // We only got part of the fragment we wanted. |
| exp.Offset += len(f.Data) |
| exp.Data = exp.Data[len(f.Data):] |
| // Check that the peer could not have fit more into the record. |
| if !s.Empty() { |
| return nil, errors.New("dtls: truncated handshake fragment was not last in the record") |
| } |
| if c.conn.lastRecordInFlight.bytesAvailable > 0 { |
| return nil, fmt.Errorf("dtls: handshake fragment was truncated, but record could have fit %d more bytes", c.conn.lastRecordInFlight.bytesAvailable) |
| } |
| } |
| } |
| |
| record, err := c.conn.makeDTLSRecordNumberInfo(epoch, data) |
| if err != nil { |
| return nil, err |
| } |
| records = append(records, record) |
| } |
| return records, nil |
| } |
| |
| // ReadACK indicates the shim is expected to send an ACK at the specified epoch. |
| // The contents of the ACK are checked against the connection's internal |
| // simulation of the shim's expected behavior. |
| func (c *DTLSController) ReadACK(epoch uint16) { |
| if c.err != nil { |
| return |
| } |
| |
| c.err = c.conn.dtlsFlushPacket() |
| if c.err != nil { |
| return |
| } |
| |
| typ, data, err := c.conn.dtlsDoReadRecord(c.getInEpochOrPanic(epoch), recordTypeACK) |
| if err != nil { |
| c.err = err |
| return |
| } |
| if typ != recordTypeACK { |
| c.err = fmt.Errorf("tls: got record of type %d, but expected ACK", typ) |
| return |
| } |
| |
| c.err = c.conn.checkACK(data) |
| } |
| |
| // WriteAppData writes an application data record to the shim. This may be used |
| // to test that post-handshake retransmits may interleave with application data. |
| func (c *DTLSController) WriteAppData(epoch uint16, data []byte) { |
| if c.err != nil { |
| return |
| } |
| |
| _, _, c.err = c.conn.dtlsPackRecord(c.getOutEpochOrPanic(epoch), recordTypeApplicationData, data, false) |
| } |
| |
| // ReadAppData indicates the shim is expected to send the specified application |
| // data record. This may be used to test that post-handshake retransmits may |
| // interleave with application data. |
| func (c *DTLSController) ReadAppData(epoch uint16, expected []byte) { |
| if c.err != nil { |
| return |
| } |
| |
| if err := c.conn.dtlsFlushPacket(); err != nil { |
| c.err = err |
| return |
| } |
| |
| typ, data, err := c.conn.dtlsDoReadRecord(c.getInEpochOrPanic(epoch), recordTypeApplicationData) |
| if err != nil { |
| c.err = err |
| return |
| } |
| if typ != recordTypeApplicationData { |
| c.err = fmt.Errorf("tls: got record of type %d, but expected application data", typ) |
| return |
| } |
| if !bytes.Equal(data, expected) { |
| c.err = fmt.Errorf("tls: got app data record containing %x, but expected %x", data, expected) |
| return |
| } |
| } |
| |
| // ExpectNextTimeout indicates the shim's next timeout should be d from now. |
| func (c *DTLSController) ExpectNextTimeout(d time.Duration) { |
| if c.err != nil { |
| return |
| } |
| if err := c.conn.dtlsFlushPacket(); err != nil { |
| c.err = err |
| return |
| } |
| c.err = c.conn.config.Bugs.PacketAdaptor.ExpectNextTimeout(d) |
| } |
| |
| // ExpectNoNext indicates the shim should not have a next timeout. |
| func (c *DTLSController) ExpectNoNextTimeout() { |
| if c.err != nil { |
| return |
| } |
| if err := c.conn.dtlsFlushPacket(); err != nil { |
| c.err = err |
| return |
| } |
| c.err = c.conn.config.Bugs.PacketAdaptor.ExpectNoNextTimeout() |
| } |
| |
// MergeIntoNextFlight indicates the state from this flight should be merged
// into the next WriteFlight or ACKFlight call. This allows the test to control
// two independent post-handshake messages as a single unit.
//
// It only sets a flag on the controller, and does so even if the controller
// has already stopped due to an error.
func (c *DTLSController) MergeIntoNextFlight() {
	c.mergeIntoNextFlight = true
}