runner: Remove explicit seq parameter to encrypt and decrypt

It was always derived from hc.seq. The only subtleties are:

1. The DTLS code relies on record header parsing zooming the sequence
   number forward. (It was already doing this.)

2. Outgoing records need to accomodate the goofy SequenceNumberMapping
   feature.

3. The funny sequence number business in DTLS 1.2 vs 1.3 was previously
   handled at the header parser for incoming records and at encrypt()
   for outgoing records. Unify everything on doing it at
   encrypt/decrypt.

I added this parameter in
https://boringssl-review.googlesource.com/c/boringssl/+/71407, but I
think that was a mistake. We (mostly) always know the expected sequence
number, and this is one more field we can derive from the epoch.

Change-Id: I00124aee57618dfbde5e458d0f9572d16946c0bc
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72648
Reviewed-by: Nick Harper <nharper@chromium.org>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index 9ac2323..839c2ee 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -448,7 +448,7 @@
 // success boolean, the application payload, the encrypted record type (or 0
 // if there is none), and an optional alert value. Decryption occurs in-place,
 // so the contents of record will be overwritten as part of this process.
-func (hc *halfConn) decrypt(seq []byte, recordHeaderLen int, record []byte) (ok bool, contentType recordType, data []byte, alertValue alert) {
+func (hc *halfConn) decrypt(recordHeaderLen int, record []byte) (ok bool, contentType recordType, data []byte, alertValue alert) {
 	// pull out payload
 	payload := record[recordHeaderLen:]
 
@@ -466,7 +466,15 @@
 		case cipher.Stream:
 			c.XORKeyStream(payload, payload)
 		case *tlsAead:
-			nonce := seq
+			nonce := hc.seq[:]
+			if hc.isDTLS && hc.version >= VersionTLS13 && !hc.conn.useDTLSPlaintextHeader() {
+				// Unlike DTLS 1.2, DTLS 1.3's nonce construction does not use
+				// the epoch number. We store the epoch and nonce numbers
+				// together, so make a copy without the epoch.
+				nonce = make([]byte, 8)
+				copy(nonce[2:], hc.seq[2:])
+			}
+
 			if explicitIVLen != 0 {
 				if len(payload) < explicitIVLen {
 					return false, 0, nil, alertBadRecordMAC
@@ -478,7 +486,7 @@
 			var additionalData []byte
 			if hc.version < VersionTLS13 {
 				additionalData = make([]byte, 13)
-				copy(additionalData, seq)
+				copy(additionalData, hc.seq[:])
 				copy(additionalData[8:], record[:3])
 				n := len(payload) - c.Overhead()
 				additionalData[11] = byte(n >> 8)
@@ -546,7 +554,7 @@
 		payload = payload[:n]
 		record[recordHeaderLen-2] = byte(n >> 8)
 		record[recordHeaderLen-1] = byte(n)
-		localMAC := hc.computeMAC(seq, record[:recordHeaderLen], payload)
+		localMAC := hc.computeMAC(hc.seq[:], record[:recordHeaderLen], payload)
 		if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
 			return false, 0, nil, alertBadRecordMAC
 		}
@@ -627,7 +635,8 @@
 // (which must be in the last two bytes of the header) should be computed for
 // the unencrypted, unpadded payload. It will be updated, potentially in-place,
 // with the final length.
-func (hc *halfConn) encrypt(record, payload []byte, typ recordType, headerLen int, headerHasLength bool, seq []byte) ([]byte, error) {
+func (hc *halfConn) encrypt(record, payload []byte, typ recordType, headerLen int, headerHasLength bool) ([]byte, error) {
+	seq := hc.sequenceNumberForOutput()
 	prefixLen := len(record)
 	header := record[prefixLen-headerLen:]
 	explicitIVLen := hc.explicitIVLen()
@@ -932,7 +941,7 @@
 
 	// Process message.
 	b := c.rawInput.Next(recordHeaderLen + n)
-	ok, encTyp, data, alertValue := c.in.decrypt(c.in.seq[:], recordHeaderLen, b)
+	ok, encTyp, data, alertValue := c.in.decrypt(recordHeaderLen, b)
 	if !ok {
 		// TLS 1.3 early data uses trial decryption.
 		if c.skipEarlyData {
@@ -1284,7 +1293,7 @@
 		record[3] = byte(m >> 8) // encrypt will update this
 		record[4] = byte(m)
 
-		record, err = c.out.encrypt(record, data[:m], typ, tlsRecordHeaderLen, true /* header has length */, c.out.seq[:])
+		record, err = c.out.encrypt(record, data[:m], typ, tlsRecordHeaderLen, true /* header has length */)
 		if err != nil {
 			return
 		}
diff --git a/ssl/test/runner/dtls.go b/ssl/test/runner/dtls.go
index ad7d40f..b7a0dfa 100644
--- a/ssl/test/runner/dtls.go
+++ b/ssl/test/runner/dtls.go
@@ -26,7 +26,7 @@
 	"golang.org/x/crypto/cryptobyte"
 )
 
-func (c *Conn) readDTLS13RecordHeader(b []byte) (headerLen int, recordLen int, recTyp recordType, seq []byte, err error) {
+func (c *Conn) readDTLS13RecordHeader(b []byte) (headerLen int, recordLen int, recTyp recordType, err error) {
 	// The DTLS 1.3 record header starts with the type byte containing
 	// 0b001CSLEE, where C, S, L, and EE are bits with the following
 	// meanings:
@@ -44,18 +44,18 @@
 	// 0b001011EE, or 0x2c-0x2f.
 	recordHeaderLen := 5
 	if len(b) < recordHeaderLen {
-		return 0, 0, 0, nil, errors.New("dtls: failed to read record header")
+		return 0, 0, 0, errors.New("dtls: failed to read record header")
 	}
 	typ := b[0]
 	if typ&0xfc != 0x2c {
-		return 0, 0, 0, nil, errors.New("dtls: DTLS 1.3 record header has bad type byte")
+		return 0, 0, 0, errors.New("dtls: DTLS 1.3 record header has bad type byte")
 	}
 	// For test purposes, require the epoch received be the same as the
 	// epoch we expect to receive.
 	epoch := typ & 0x03
 	if epoch != c.in.seq[1]&0x03 {
 		c.sendAlert(alertIllegalParameter)
-		return 0, 0, 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch"))
+		return 0, 0, 0, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch"))
 	}
 	wireSeq := b[1:3]
 	if !c.config.Bugs.NullAllCiphers {
@@ -80,20 +80,20 @@
 		newSeq += 0x10000
 	}
 
-	seq = make([]byte, 8)
+	seq := make([]byte, 8)
 	binary.BigEndian.PutUint64(seq, newSeq)
 	copy(c.in.seq[2:], seq[2:])
 
 	recordLen = int(b[3])<<8 | int(b[4])
-	return recordHeaderLen, recordLen, 0, seq, nil
+	return recordHeaderLen, recordLen, 0, nil
 }
 
 // readDTLSRecordHeader reads the record header from the input. Based on the
 // header it reads, it checks the header's validity and sets appropriate state
-// as needed. This function returns the record header, the record type indicated
-// in the header (if it contains the type), and the sequence number to use for
-// record decryption.
-func (c *Conn) readDTLSRecordHeader(b []byte) (headerLen int, recordLen int, typ recordType, seq []byte, err error) {
+// as needed. This function returns the record header and the record type
+// indicated in the header (if it contains the type). The connection's internal
+// sequence number is updated to the value from the header.
+func (c *Conn) readDTLSRecordHeader(b []byte) (headerLen int, recordLen int, typ recordType, err error) {
 	if c.in.cipher != nil && c.in.version >= VersionTLS13 {
 		return c.readDTLS13RecordHeader(b)
 	}
@@ -105,7 +105,7 @@
 	// but this is test code. We should not be tolerant of our
 	// peer sending garbage.
 	if len(b) < recordHeaderLen {
-		return 0, 0, 0, nil, errors.New("dtls: failed to read record header")
+		return 0, 0, 0, errors.New("dtls: failed to read record header")
 	}
 	typ = recordType(b[0])
 	vers := uint16(b[1])<<8 | uint16(b[2])
@@ -120,33 +120,33 @@
 			}
 			if vers != wireVersion {
 				c.sendAlert(alertProtocolVersion)
-				return 0, 0, 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.wireVersion))
+				return 0, 0, 0, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.wireVersion))
 			}
 		} else {
 			// Pre-version-negotiation alerts may be sent with any version.
 			if expect := c.config.Bugs.ExpectInitialRecordVersion; expect != 0 && vers != expect {
 				c.sendAlert(alertProtocolVersion)
-				return 0, 0, 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect))
+				return 0, 0, 0, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect))
 			}
 		}
 	}
 	epoch := b[3:5]
-	seq = b[5:11]
+	seq := b[5:11]
 	// For test purposes, require the sequence number be monotonically
 	// increasing, so c.in includes the minimum next sequence number. Gaps
 	// may occur if packets failed to be sent out. A real implementation
 	// would maintain a replay window and such.
 	if !bytes.Equal(epoch, c.in.seq[:2]) {
 		c.sendAlert(alertIllegalParameter)
-		return 0, 0, 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch"))
+		return 0, 0, 0, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch"))
 	}
 	if bytes.Compare(seq, c.in.seq[2:]) < 0 {
 		c.sendAlert(alertIllegalParameter)
-		return 0, 0, 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number"))
+		return 0, 0, 0, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number"))
 	}
 	copy(c.in.seq[2:], seq)
 	recordLen = int(b[11])<<8 | int(b[12])
-	return recordHeaderLen, recordLen, typ, b[3:11], nil
+	return recordHeaderLen, recordLen, typ, nil
 }
 
 func (c *Conn) writeACKs(seqnums []uint64) {
@@ -180,7 +180,7 @@
 	}
 
 	// Consume the next record from the buffer.
-	recordHeaderLen, n, typ, seq, err := c.readDTLSRecordHeader(c.rawInput.Bytes())
+	recordHeaderLen, n, typ, err := c.readDTLSRecordHeader(c.rawInput.Bytes())
 	if err != nil {
 		return 0, nil, err
 	}
@@ -191,7 +191,8 @@
 	b := c.rawInput.Next(recordHeaderLen + n)
 
 	// Process message.
-	ok, encTyp, data, alertValue := c.in.decrypt(seq, recordHeaderLen, b)
+	seq := slices.Clone(c.in.seq[:])
+	ok, encTyp, data, alertValue := c.in.decrypt(recordHeaderLen, b)
 	if !ok {
 		// A real DTLS implementation would silently ignore bad records,
 		// but we want to notice errors from the implementation under
@@ -507,7 +508,7 @@
 	}
 
 	recordHeaderLen := len(record)
-	record, err = c.out.encrypt(record, data, typ, recordHeaderLen, headerHasLength, seq)
+	record, err = c.out.encrypt(record, data, typ, recordHeaderLen, headerHasLength)
 	if err != nil {
 		return
 	}