Add more aggressive DTLS replay tests.

The existing tests only went monotonic. Allow an arbitrary mapping
function. Also test by sending more app data. The handshake is fairly
resilient to replayed packets, whereas our test code intentionally
isn't.

Change-Id: I0fb74bbacc260c65ec5f6a1ca8f3cb23b4192855
Reviewed-on: https://boringssl-review.googlesource.com/5556
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index 07cb175..6c10992 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -589,11 +589,11 @@
 	// error if the server doesn't reply with the renegotiation extension.
 	RequireRenegotiationInfo bool
 
-	// SequenceNumberIncrement, if non-zero, causes outgoing sequence
-	// numbers in DTLS to increment by that value rather by 1. This is to
-	// stress the replay bitmap window by simulating extreme packet loss and
-	// retransmit at the record layer.
-	SequenceNumberIncrement uint64
+	// SequenceNumberMapping, if non-nil, is the mapping function to apply
+	// to the sequence number of outgoing packets. For both TLS and DTLS,
+	// the two most-significant bytes in the resulting sequence number are
+	// ignored so that the DTLS epoch cannot be changed.
+	SequenceNumberMapping func(uint64) uint64
 
 	// RSAEphemeralKey, if true, causes the server to send a
 	// ServerKeyExchange message containing an ephemeral key (as in
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index ea9f3bb..b755c46 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -12,6 +12,7 @@
 	"crypto/ecdsa"
 	"crypto/subtle"
 	"crypto/x509"
+	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
@@ -87,6 +88,8 @@
 	c.out.isDTLS = c.isDTLS
 	c.in.config = c.config
 	c.out.config = c.config
+
+	c.out.updateOutSeq()
 }
 
 // Access to net.Conn methods.
@@ -134,6 +137,7 @@
 	cipher  interface{} // cipher algorithm
 	mac     macFunction
 	seq     [8]byte // 64-bit sequence number
+	outSeq  [8]byte // Mapped sequence number
 	bfree   *block  // list of free blocks
 
 	nextCipher interface{} // next encryption state
@@ -189,10 +193,6 @@
 	if hc.isDTLS {
 		// Increment up to the epoch in DTLS.
 		limit = 2
-
-		if isOutgoing && hc.config.Bugs.SequenceNumberIncrement != 0 {
-			increment = hc.config.Bugs.SequenceNumberIncrement
-		}
 	}
 	for i := 7; i >= limit; i-- {
 		increment += uint64(hc.seq[i])
@@ -206,6 +206,8 @@
 	if increment != 0 {
 		panic("TLS: sequence number wraparound")
 	}
+
+	hc.updateOutSeq()
 }
 
 // incNextSeq increments the starting sequence number for the next epoch.
@@ -241,6 +243,22 @@
 			hc.seq[i] = 0
 		}
 	}
+
+	hc.updateOutSeq()
+}
+
+func (hc *halfConn) updateOutSeq() {
+	if hc.config.Bugs.SequenceNumberMapping != nil {
+		seqU64 := binary.BigEndian.Uint64(hc.seq[:])
+		seqU64 = hc.config.Bugs.SequenceNumberMapping(seqU64)
+		binary.BigEndian.PutUint64(hc.outSeq[:], seqU64)
+
+		// The DTLS epoch cannot be changed.
+		copy(hc.outSeq[:2], hc.seq[:2])
+		return
+	}
+
+	copy(hc.outSeq[:], hc.seq[:])
 }
 
 func (hc *halfConn) recordHeaderLen() int {
@@ -460,7 +478,7 @@
 
 	// mac
 	if hc.mac != nil {
-		mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:3], b.data[recordHeaderLen-2:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:])
+		mac := hc.mac.MAC(hc.outDigestBuf, hc.outSeq[0:], b.data[:3], b.data[recordHeaderLen-2:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:])
 
 		n := len(b.data)
 		b.resize(n + len(mac))
@@ -478,7 +496,7 @@
 		case *tlsAead:
 			payloadLen := len(b.data) - recordHeaderLen - explicitIVLen
 			b.resize(len(b.data) + c.Overhead())
-			nonce := hc.seq[:]
+			nonce := hc.outSeq[:]
 			if c.explicitNonce {
 				nonce = b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
 			}
@@ -486,7 +504,7 @@
 			payload = payload[:payloadLen]
 
 			var additionalData [13]byte
-			copy(additionalData[:], hc.seq[:])
+			copy(additionalData[:], hc.outSeq[:])
 			copy(additionalData[8:], b.data[:3])
 			additionalData[11] = byte(payloadLen >> 8)
 			additionalData[12] = byte(payloadLen)
diff --git a/ssl/test/runner/dtls.go b/ssl/test/runner/dtls.go
index 538bf51..5c59dea 100644
--- a/ssl/test/runner/dtls.go
+++ b/ssl/test/runner/dtls.go
@@ -298,13 +298,13 @@
 	b.data[1] = byte(vers >> 8)
 	b.data[2] = byte(vers)
 	// DTLS records include an explicit sequence number.
-	copy(b.data[3:11], c.out.seq[0:])
+	copy(b.data[3:11], c.out.outSeq[0:])
 	b.data[11] = byte(len(data) >> 8)
 	b.data[12] = byte(len(data))
 	if explicitIVLen > 0 {
 		explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
 		if explicitIVIsSeq {
-			copy(explicitIV, c.out.seq[:])
+			copy(explicitIV, c.out.outSeq[:])
 		} else {
 			if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil {
 				return
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index 9fa394f..ff43678 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -158,6 +158,8 @@
 	// messageLen is the length, in bytes, of the test message that will be
 	// sent.
 	messageLen int
+	// messageCount is the number of test messages that will be sent.
+	messageCount int
 	// certFile is the path to the certificate to use for the server.
 	certFile string
 	// keyFile is the path to the private key to use for the server.
@@ -221,7 +223,7 @@
 
 var testCases []testCase
 
-func doExchange(test *testCase, config *Config, conn net.Conn, messageLen int, isResume bool) error {
+func doExchange(test *testCase, config *Config, conn net.Conn, isResume bool) error {
 	var connDebug *recordingConn
 	var connDamage *damageAdaptor
 	if *flagDebug {
@@ -379,6 +381,7 @@
 		connDamage.setDamage(false)
 	}
 
+	messageLen := test.messageLen
 	if messageLen < 0 {
 		if test.protocol == dtls {
 			return fmt.Errorf("messageLen < 0 not supported for DTLS tests")
@@ -387,45 +390,52 @@
 		_, err := io.Copy(ioutil.Discard, tlsConn)
 		return err
 	}
-
 	if messageLen == 0 {
 		messageLen = 32
 	}
-	testMessage := make([]byte, messageLen)
-	for i := range testMessage {
-		testMessage[i] = 0x42
-	}
-	tlsConn.Write(testMessage)
 
-	for i := 0; i < test.sendEmptyRecords; i++ {
-		tlsConn.Write(nil)
+	messageCount := test.messageCount
+	if messageCount == 0 {
+		messageCount = 1
 	}
 
-	for i := 0; i < test.sendWarningAlerts; i++ {
-		tlsConn.SendAlert(alertLevelWarning, alertUnexpectedMessage)
-	}
-
-	buf := make([]byte, len(testMessage))
-	if test.protocol == dtls {
-		bufTmp := make([]byte, len(buf)+1)
-		n, err := tlsConn.Read(bufTmp)
-		if err != nil {
-			return err
+	for j := 0; j < messageCount; j++ {
+		testMessage := make([]byte, messageLen)
+		for i := range testMessage {
+			testMessage[i] = 0x42 ^ byte(j)
 		}
-		if n != len(buf) {
-			return fmt.Errorf("bad reply; length mismatch (%d vs %d)", n, len(buf))
-		}
-		copy(buf, bufTmp)
-	} else {
-		_, err := io.ReadFull(tlsConn, buf)
-		if err != nil {
-			return err
-		}
-	}
+		tlsConn.Write(testMessage)
 
-	for i, v := range buf {
-		if v != testMessage[i]^0xff {
-			return fmt.Errorf("bad reply contents at byte %d", i)
+		for i := 0; i < test.sendEmptyRecords; i++ {
+			tlsConn.Write(nil)
+		}
+
+		for i := 0; i < test.sendWarningAlerts; i++ {
+			tlsConn.SendAlert(alertLevelWarning, alertUnexpectedMessage)
+		}
+
+		buf := make([]byte, len(testMessage))
+		if test.protocol == dtls {
+			bufTmp := make([]byte, len(buf)+1)
+			n, err := tlsConn.Read(bufTmp)
+			if err != nil {
+				return err
+			}
+			if n != len(buf) {
+				return fmt.Errorf("bad reply; length mismatch (%d vs %d)", n, len(buf))
+			}
+			copy(buf, bufTmp)
+		} else {
+			_, err := io.ReadFull(tlsConn, buf)
+			if err != nil {
+				return err
+			}
+		}
+
+		for i, v := range buf {
+			if v != testMessage[i]^0xff {
+				return fmt.Errorf("bad reply contents at byte %d", i)
+			}
 		}
 	}
 
@@ -595,7 +605,7 @@
 
 	conn, err := acceptOrWait(listener, waitChan)
 	if err == nil {
-		err = doExchange(test, &config, conn, test.messageLen, false /* not a resumption */)
+		err = doExchange(test, &config, conn, false /* not a resumption */)
 		conn.Close()
 	}
 
@@ -625,7 +635,7 @@
 		var connResume net.Conn
 		connResume, err = acceptOrWait(listener, waitChan)
 		if err == nil {
-			err = doExchange(test, &resumeConfig, connResume, test.messageLen, true /* resumption */)
+			err = doExchange(test, &resumeConfig, connResume, true /* resumption */)
 			connResume.Close()
 		}
 	}
@@ -3303,19 +3313,38 @@
 	testCases = append(testCases, testCase{
 		protocol:     dtls,
 		name:         "DTLS-Replay",
+		messageCount: 200,
 		replayWrites: true,
 	})
 
-	// Test the outgoing sequence number skipping by values larger
+	// Test the incoming sequence number skipping by values larger
 	// than the retransmit window.
 	testCases = append(testCases, testCase{
 		protocol: dtls,
 		name:     "DTLS-Replay-LargeGaps",
 		config: Config{
 			Bugs: ProtocolBugs{
-				SequenceNumberIncrement: 127,
+				SequenceNumberMapping: func(in uint64) uint64 {
+					return in * 127
+				},
 			},
 		},
+		messageCount: 200,
+		replayWrites: true,
+	})
+
+	// Test the incoming sequence number changing non-monotonically.
+	testCases = append(testCases, testCase{
+		protocol: dtls,
+		name:     "DTLS-Replay-NonMonotonic",
+		config: Config{
+			Bugs: ProtocolBugs{
+				SequenceNumberMapping: func(in uint64) uint64 {
+					return in ^ 31
+				},
+			},
+		},
+		messageCount: 200,
 		replayWrites: true,
 	})
 }