Split off DTLS record header parsing in test runner.
The DTLS 1.3 record header is formatted differently than the old record
header, but the code to read/process a DTLS record mixes record header
parsing with other record processing code. This change provides a clear
delineation between processing the record header and processing the
record, which will assist in adding support for the DTLS 1.3 record
header.
Bug: 715
Change-Id: I13a0bb5c184e79b88f064e9ac8ecbc82eb56750a
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/69950
Commit-Queue: Bob Beck <bbe@google.com>
Reviewed-by: Bob Beck <bbe@google.com>
diff --git a/ssl/test/runner/dtls.go b/ssl/test/runner/dtls.go
index 95e1a9a..1893afe 100644
--- a/ssl/test/runner/dtls.go
+++ b/ssl/test/runner/dtls.go
@@ -23,29 +23,13 @@
"net"
)
-func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) {
- recordHeaderLen := dtlsMaxRecordHeaderLen
-
- if c.rawInput == nil {
- c.rawInput = c.in.newBlock()
- }
- b := c.rawInput
-
- // Read a new packet only if the current one is empty.
- var newPacket bool
- if len(b.data) == 0 {
- // Pick some absurdly large buffer size.
- b.resize(maxCiphertext + recordHeaderLen)
- n, err := c.conn.Read(c.rawInput.data)
- if err != nil {
- return 0, nil, err
- }
- if c.config.Bugs.MaxPacketLength != 0 && n > c.config.Bugs.MaxPacketLength {
- return 0, nil, fmt.Errorf("dtls: exceeded maximum packet length")
- }
- c.rawInput.resize(n)
- newPacket = true
- }
+// readDTLSRecordHeader reads the record header from the block. Based on the
+// header it reads, it checks the header's validity and sets appropriate state
+// as needed. This function returns the record header, the record type indicated
+// in the header (if it contains the type), and the sequence number to use for
+// record decryption.
+func (c *Conn) readDTLSRecordHeader(b *block) (headerLen int, recordLen int, typ recordType, seq []byte, err error) {
+ recordHeaderLen := 13
// Read out one record.
//
@@ -53,9 +37,9 @@
// but this is test code. We should not be tolerant of our
// peer sending garbage.
if len(b.data) < recordHeaderLen {
- return 0, nil, errors.New("dtls: failed to read record header")
+ return 0, 0, 0, nil, errors.New("dtls: failed to read record header")
}
- typ := recordType(b.data[0])
+ typ = recordType(b.data[0])
vers := uint16(b.data[1])<<8 | uint16(b.data[2])
// Alerts sent near version negotiation do not have a well-defined
// record-layer version prior to TLS 1.3. (In TLS 1.3, the record-layer
@@ -68,32 +52,61 @@
}
if vers != wireVersion {
c.sendAlert(alertProtocolVersion)
- return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.wireVersion))
+ return 0, 0, 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.wireVersion))
}
} else {
// Pre-version-negotiation alerts may be sent with any version.
if expect := c.config.Bugs.ExpectInitialRecordVersion; expect != 0 && vers != expect {
c.sendAlert(alertProtocolVersion)
- return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect))
+ return 0, 0, 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect))
}
}
}
epoch := b.data[3:5]
- seq := b.data[5:11]
+ seq = b.data[5:11]
// For test purposes, require the sequence number be monotonically
// increasing, so c.in includes the minimum next sequence number. Gaps
// may occur if packets failed to be sent out. A real implementation
// would maintain a replay window and such.
if !bytes.Equal(epoch, c.in.seq[:2]) {
c.sendAlert(alertIllegalParameter)
- return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch, want %x, got %x", c.in.seq[:2], epoch))
+ return 0, 0, 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch"))
}
if bytes.Compare(seq, c.in.seq[2:]) < 0 {
c.sendAlert(alertIllegalParameter)
- return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number"))
+ return 0, 0, 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number"))
}
copy(c.in.seq[2:], seq)
- n := int(b.data[11])<<8 | int(b.data[12])
+ recordLen = int(b.data[11])<<8 | int(b.data[12])
+ return recordHeaderLen, recordLen, typ, b.data[3:11], nil
+}
+
+func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) {
+ if c.rawInput == nil {
+ c.rawInput = c.in.newBlock()
+ }
+ b := c.rawInput
+
+ // Read a new packet only if the current one is empty.
+ var newPacket bool
+ if len(b.data) == 0 {
+ // Pick some absurdly large buffer size.
+ b.resize(maxCiphertext + dtlsMaxRecordHeaderLen)
+ n, err := c.conn.Read(c.rawInput.data)
+ if err != nil {
+ return 0, nil, err
+ }
+ if c.config.Bugs.MaxPacketLength != 0 && n > c.config.Bugs.MaxPacketLength {
+ return 0, nil, fmt.Errorf("dtls: exceeded maximum packet length")
+ }
+ c.rawInput.resize(n)
+ newPacket = true
+ }
+
+ recordHeaderLen, n, typ, seq, err := c.readDTLSRecordHeader(b)
+ if err != nil {
+ return 0, nil, err
+ }
if n > maxCiphertext || len(b.data) < recordHeaderLen+n {
c.sendAlert(alertRecordOverflow)
return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: oversized record received with length %d", n))
@@ -101,7 +114,7 @@
b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
// Process message.
- ok, off, _, alertValue := c.in.decrypt(b.data[3:11], recordHeaderLen, b)
+ ok, off, _, alertValue := c.in.decrypt(seq, recordHeaderLen, b)
if !ok {
// A real DTLS implementation would silently ignore bad records,
// but we want to notice errors from the implementation under