runner: Use some new Go conveniences
I was trying to figure out why "slices" didn't have function to zero
things, but it's because it's a new built-in, clear().
Also replace a ton of loops with slices.Contains.
Change-Id: I0e3777b17217eef7cf166f1e3c9f5dab77f1ba34
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/71968
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/runner/cipher_suites.go b/ssl/test/runner/cipher_suites.go
index 4702df6..c43fc09 100644
--- a/ssl/test/runner/cipher_suites.go
+++ b/ssl/test/runner/cipher_suites.go
@@ -16,6 +16,7 @@
"crypto/sha512"
"crypto/x509"
"hash"
+ "slices"
"golang.org/x/crypto/chacha20poly1305"
)
@@ -338,11 +339,9 @@
// mutualCipherSuite returns a cipherSuite given a list of supported
// ciphersuites and the id requested by the peer.
-func mutualCipherSuite(have []uint16, want uint16) *cipherSuite {
- for _, id := range have {
- if id == want {
- return cipherSuiteFromID(id)
- }
+func mutualCipherSuite(have []uint16, id uint16) *cipherSuite {
+ if slices.Contains(have, id) {
+ return cipherSuiteFromID(id)
}
return nil
}
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index 9229302..c019645 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -337,13 +337,9 @@
}
}
copy(hc.seq[2:], hc.nextSeq[:])
- for i := range hc.nextSeq {
- hc.nextSeq[i] = 0
- }
+ clear(hc.nextSeq[:])
} else {
- for i := range hc.seq {
- hc.seq[i] = 0
- }
+ clear(hc.seq[:])
}
}
@@ -354,9 +350,7 @@
hc.seq[0] = byte(epoch >> 8)
hc.seq[1] = byte(epoch)
copy(hc.seq[2:], hc.nextSeq[:])
- for i := range hc.nextSeq {
- hc.nextSeq[i] = 0
- }
+ clear(hc.nextSeq[:])
}
func (hc *halfConn) sequenceNumberForOutput() []byte {
@@ -653,9 +647,7 @@
record = append(record, byte(typ))
}
padding := extendSlice(&record, hc.config.Bugs.RecordPadding)
- for i := range padding {
- padding[i] = 0
- }
+ clear(padding)
}
if hc.mac != nil {
@@ -1430,7 +1422,7 @@
// The handshake message unmarshallers
// expect to be able to keep references to data,
// so pass in a fresh copy that won't be overwritten.
- data = append([]byte(nil), data...)
+ data = slices.Clone(data)
if data[0] == typeServerHello && len(data) >= 38 {
vers := uint16(data[4])<<8 | uint16(data[5])
diff --git a/ssl/test/runner/deterministic.go b/ssl/test/runner/deterministic.go
index 50ae47f..6ce9b51 100644
--- a/ssl/test/runner/deterministic.go
+++ b/ssl/test/runner/deterministic.go
@@ -28,9 +28,7 @@
}
func (d *deterministicRand) Read(buf []byte) (int, error) {
- for i := range buf {
- buf[i] = 0
- }
+ clear(buf)
var nonce [12]byte
binary.LittleEndian.PutUint64(nonce[:8], d.numCalls)
cipher, err := chacha20.NewUnauthenticatedCipher(deterministicRandKey, nonce[:])
diff --git a/ssl/test/runner/dtls.go b/ssl/test/runner/dtls.go
index 5eef430..63bc864 100644
--- a/ssl/test/runner/dtls.go
+++ b/ssl/test/runner/dtls.go
@@ -21,6 +21,7 @@
"fmt"
"math/rand"
"net"
+ "slices"
)
func (c *Conn) readDTLS13RecordHeader(b []byte) (headerLen int, recordLen int, recTyp recordType, seq []byte, err error) {
@@ -383,7 +384,7 @@
records[i] = append(records[i], fragment...)
} else {
// The fragment will be appended to, so copy it.
- records = append(records, append([]byte{}, fragment...))
+ records = append(records, slices.Clone(fragment))
}
}
@@ -577,7 +578,7 @@
}
// Start with the TLS handshake header,
// without the DTLS bits.
- c.handMsg = append([]byte{}, header[:4]...)
+ c.handMsg = slices.Clone(header[:4])
} else if fragN != c.handMsgLen {
return nil, errors.New("dtls: bad handshake length")
}
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index f968d09..827fa18 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -18,6 +18,7 @@
"io"
"math/big"
"net"
+ "slices"
"time"
"boringssl.googlesource.com/boringssl/ssl/test/runner/hpke"
@@ -63,7 +64,7 @@
// this function does not update internal handshake state, so the test must be
// configured compatibly with |in|.
func replaceClientHello(hello *clientHelloMsg, in []byte) (*clientHelloMsg, error) {
- copied := append([]byte{}, in...)
+ copied := slices.Clone(in)
newHello := new(clientHelloMsg)
if !newHello.unmarshal(copied) {
return nil, errors.New("tls: invalid ClientHello")
@@ -197,9 +198,7 @@
}
if challengeLength <= len(hs.hello.random) {
skip := len(hs.hello.random) - challengeLength
- for i := 0; i < skip; i++ {
- hs.hello.random[i] = 0
- }
+ clear(hs.hello.random[:skip])
hs.hello.v2Challenge = hs.hello.random[skip:]
} else {
hs.hello.v2Challenge = make([]byte, challengeLength)
@@ -451,14 +450,9 @@
return HPKECipherSuite{}, false
}
- for _, wantSuite := range config.echCipherSuitePreferences() {
- if config.Bugs.IgnoreECHConfigCipherPreferences {
- return wantSuite, true
- }
- for _, cipherSuite := range echConfig.CipherSuites {
- if cipherSuite == wantSuite {
- return cipherSuite, true
- }
+ for _, suite := range config.echCipherSuitePreferences() {
+ if config.Bugs.IgnoreECHConfigCipherPreferences || slices.Contains(echConfig.CipherSuites, suite) {
+ return suite, true
}
}
return HPKECipherSuite{}, false
@@ -691,20 +685,17 @@
possibleCipherSuites := c.config.cipherSuites()
hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites))
-NextCipherSuite:
for _, suiteID := range possibleCipherSuites {
- for _, suite := range cipherSuites {
- if suite.id != suiteID {
- continue
- }
- // Don't advertise TLS 1.2-only cipher suites unless
- // we're attempting TLS 1.2.
- if maxVersion < VersionTLS12 && suite.flags&suiteTLS12 != 0 {
- continue
- }
- hello.cipherSuites = append(hello.cipherSuites, suiteID)
- continue NextCipherSuite
+ suite := cipherSuiteFromID(suiteID)
+ if suite == nil {
+ continue
}
+ // Don't advertise TLS 1.2-only cipher suites unless
+ // we're attempting TLS 1.2.
+ if maxVersion < VersionTLS12 && suite.flags&suiteTLS12 != 0 {
+ continue
+ }
+ hello.cipherSuites = append(hello.cipherSuites, suiteID)
}
if c.config.Bugs.AdvertiseAllConfiguredCiphers {
@@ -953,10 +944,8 @@
offset = 4 + 2 + 32 - echAcceptConfirmationLength
}
- withZeros := append(make([]byte, 0, len(raw)), raw...)
- for i := 0; i < echAcceptConfirmationLength; i++ {
- withZeros[i+offset] = 0
- }
+ withZeros := slices.Clone(raw)
+ clear(withZeros[offset : offset+echAcceptConfirmationLength])
confirmation := finishedHash.echAcceptConfirmation(hello.random, label, withZeros)
return bytes.Equal(confirmation, raw[offset:offset+echAcceptConfirmationLength])
@@ -1542,15 +1531,8 @@
helloRetryRequest.selectedGroup = c.config.Bugs.MisinterpretHelloRetryRequestCurve
}
if helloRetryRequest.hasSelectedGroup {
- var hrrCurveFound bool
group := helloRetryRequest.selectedGroup
- for _, curveID := range hello.supportedCurves {
- if group == curveID {
- hrrCurveFound = true
- break
- }
- }
- if !hrrCurveFound || hs.keyShares[group] != nil {
+ if !slices.Contains(hello.supportedCurves, group) || hs.keyShares[group] != nil {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: received invalid HelloRetryRequest")
}
@@ -2408,10 +2390,8 @@
// indicating if the fallback case was reached.
func mutualProtocol(protos, preferenceProtos []string) (string, bool) {
for _, s := range preferenceProtos {
- for _, c := range protos {
- if s == c {
- return s, false
- }
+ if slices.Contains(protos, s) {
+ return s, false
}
}
@@ -2421,9 +2401,7 @@
// writeIntPadded writes x into b, padded up with leading zeros as
// needed.
func writeIntPadded(b []byte, x *big.Int) {
- for i := range b {
- b[i] = 0
- }
+ clear(b)
xb := x.Bytes()
copy(b[len(b)-len(xb):], xb)
}
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index b3c1f98..2c210eb 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -17,6 +17,7 @@
"fmt"
"io"
"math/big"
+ "slices"
"time"
"boringssl.googlesource.com/boringssl/ssl/test/runner/hpke"
@@ -302,7 +303,7 @@
panic("Could not map wire version")
}
- clientProtocol, ok := wireToVersion(c.clientVersion, c.isDTLS)
+ clientProtocol, clientProtocolOK := wireToVersion(c.clientVersion, c.isDTLS)
if c.shouldSendHelloVerifyRequest() {
// Per RFC 6347, the version field in HelloVerifyRequest SHOULD
@@ -367,12 +368,8 @@
}
}
- if config.Bugs.FailIfPostQuantumOffered {
- for _, offeredCurve := range hs.clientHello.supportedCurves {
- if isPqGroup(offeredCurve) {
- return errors.New("tls: post-quantum group was offered")
- }
- }
+ if config.Bugs.FailIfPostQuantumOffered && slices.ContainsFunc(hs.clientHello.supportedCurves, isPqGroup) {
+ return errors.New("tls: post-quantum group was offered")
}
if expected := config.Bugs.ExpectedKeyShares; expected != nil {
@@ -388,17 +385,13 @@
}
// Reject < 1.2 ClientHellos with signature_algorithms.
- if ok && clientProtocol < VersionTLS12 && len(hs.clientHello.signatureAlgorithms) > 0 {
+ if clientProtocolOK && clientProtocol < VersionTLS12 && len(hs.clientHello.signatureAlgorithms) > 0 {
return fmt.Errorf("tls: client included signature_algorithms before TLS 1.2")
}
// Check the client cipher list is consistent with the version.
- if ok && clientProtocol < VersionTLS12 {
- for _, id := range hs.clientHello.cipherSuites {
- if isTLS12Cipher(id) {
- return fmt.Errorf("tls: client offered TLS 1.2 cipher before TLS 1.2")
- }
- }
+ if clientProtocolOK && clientProtocol < VersionTLS12 && slices.ContainsFunc(hs.clientHello.cipherSuites, isTLS12Cipher) {
+ return fmt.Errorf("tls: client offered TLS 1.2 cipher before TLS 1.2")
}
if config.Bugs.MockQUICTransport != nil && len(hs.clientHello.sessionID) > 0 {
@@ -422,14 +415,7 @@
return fmt.Errorf("tls: client offered unexpected PSK identities")
}
- var scsvFound bool
- for _, cipherSuite := range hs.clientHello.cipherSuites {
- if cipherSuite == fallbackSCSV {
- scsvFound = true
- break
- }
- }
-
+ scsvFound := slices.Contains(hs.clientHello.cipherSuites, fallbackSCSV)
if !scsvFound && config.Bugs.FailIfNotFallbackSCSV {
return errors.New("tls: no fallback SCSV found when expected")
} else if scsvFound && !config.Bugs.FailIfNotFallbackSCSV {
@@ -617,14 +603,11 @@
supportedCurve := false
var selectedCurve CurveID
preferredCurves := config.curvePreferences()
-Curves:
for _, curve := range hs.clientHello.supportedCurves {
- for _, supported := range preferredCurves {
- if supported == curve {
- supportedCurve = true
- selectedCurve = curve
- break Curves
- }
+ if slices.Contains(preferredCurves, curve) {
+ supportedCurve = true
+ selectedCurve = curve
+ break
}
}
@@ -1047,9 +1030,7 @@
// Emit the ECH confirmation signal when requested.
if hs.clientHello.echInner && !config.Bugs.OmitServerHelloECHConfirmation {
randomSuffix := hs.hello.random[len(hs.hello.random)-echAcceptConfirmationLength:]
- for i := range randomSuffix {
- randomSuffix[i] = 0
- }
+ clear(randomSuffix)
copy(randomSuffix, hs.finishedHash.echAcceptConfirmation(hs.clientHello.random, echAcceptConfirmationLabel, hs.hello.marshal()))
hs.hello.raw = nil
}
@@ -1148,38 +1129,35 @@
certMsgBytes := certMsg.marshal()
sentCompressedCertMsg := false
- FindCertCompressionAlg:
- for candidate, alg := range c.config.CertCompressionAlgs {
- for _, id := range hs.clientHello.compressedCertAlgs {
- if id == candidate {
- if expected := config.Bugs.ExpectedCompressedCert; expected != 0 && expected != id {
- return fmt.Errorf("tls: expected to send compressed cert with alg %d, but picked %d", expected, id)
- }
- if config.Bugs.ExpectUncompressedCert {
- return errors.New("tls: expected to send uncompressed cert")
- }
-
- if override := config.Bugs.SendCertCompressionAlgID; override != 0 {
- id = override
- }
-
- uncompressed := certMsgBytes[4:]
- uncompressedLen := uint32(len(uncompressed))
- if override := config.Bugs.SendCertUncompressedLength; override != 0 {
- uncompressedLen = override
- }
-
- compressedCertMsgBytes := (&compressedCertificateMsg{
- algID: id,
- uncompressedLength: uncompressedLen,
- compressed: alg.Compress(uncompressed),
- }).marshal()
-
- hs.writeServerHash(compressedCertMsgBytes)
- c.writeRecord(recordTypeHandshake, compressedCertMsgBytes)
- sentCompressedCertMsg = true
- break FindCertCompressionAlg
+ for id, alg := range c.config.CertCompressionAlgs {
+ if slices.Contains(hs.clientHello.compressedCertAlgs, id) {
+ if expected := config.Bugs.ExpectedCompressedCert; expected != 0 && expected != id {
+ return fmt.Errorf("tls: expected to send compressed cert with alg %d, but picked %d", expected, id)
}
+ if config.Bugs.ExpectUncompressedCert {
+ return errors.New("tls: expected to send uncompressed cert")
+ }
+
+ if override := config.Bugs.SendCertCompressionAlgID; override != 0 {
+ id = override
+ }
+
+ uncompressed := certMsgBytes[4:]
+ uncompressedLen := uint32(len(uncompressed))
+ if override := config.Bugs.SendCertUncompressedLength; override != 0 {
+ uncompressedLen = override
+ }
+
+ compressedCertMsgBytes := (&compressedCertificateMsg{
+ algID: id,
+ uncompressedLength: uncompressedLen,
+ compressed: alg.Compress(uncompressed),
+ }).marshal()
+
+ hs.writeServerHash(compressedCertMsgBytes)
+ c.writeRecord(recordTypeHandshake, compressedCertMsgBytes)
+ sentCompressedCertMsg = true
+ break
}
}
@@ -1497,16 +1475,8 @@
copy(hs.hello.random[len(hs.hello.random)-8:], downgradeJDK11)
}
- foundCompression := false
// We only support null compression, so check that the client offered it.
- for _, compression := range hs.clientHello.compressionMethods {
- if compression == compressionNone {
- foundCompression = true
- break
- }
- }
-
- if !foundCompression {
+ if !slices.Contains(hs.clientHello.compressionMethods, compressionNone) {
c.sendAlert(alertHandshakeFailure)
return false, errors.New("tls: client does not support uncompressed connections")
}
@@ -1517,28 +1487,19 @@
supportedCurve := false
preferredCurves := config.curvePreferences()
-Curves:
for _, curve := range hs.clientHello.supportedCurves {
if isPqGroup(curve) && c.vers < VersionTLS13 {
// Post-quantum is TLS 1.3 only.
continue
}
- for _, supported := range preferredCurves {
- if supported == curve {
- supportedCurve = true
- break Curves
- }
- }
- }
-
- supportedPointFormat := false
- for _, pointFormat := range hs.clientHello.supportedPoints {
- if pointFormat == pointFormatUncompressed {
- supportedPointFormat = true
+ if slices.Contains(preferredCurves, curve) {
+ supportedCurve = true
break
}
}
+
+ supportedPointFormat := slices.Contains(hs.clientHello.supportedPoints, pointFormatUncompressed)
hs.ellipticOk = supportedCurve && supportedPointFormat
_, hs.ecdsaOk = hs.cert.PrivateKey.(*ecdsa.PrivateKey)
@@ -1646,18 +1607,8 @@
var alpsAllowed, alpsAllowedOld bool
if c.vers >= VersionTLS13 {
- for _, proto := range hs.clientHello.alpsProtocols {
- if proto == c.clientProtocol {
- alpsAllowed = true
- break
- }
- }
- for _, proto := range hs.clientHello.alpsProtocolsOld {
- if proto == c.clientProtocol {
- alpsAllowedOld = true
- break
- }
- }
+ alpsAllowed = slices.Contains(hs.clientHello.alpsProtocols, c.clientProtocol)
+ alpsAllowedOld = slices.Contains(hs.clientHello.alpsProtocolsOld, c.clientProtocol)
}
if c.config.Bugs.AlwaysNegotiateApplicationSettingsBoth {
@@ -1727,14 +1678,11 @@
}
if hs.clientHello.srtpProtectionProfiles != nil {
- SRTPLoop:
- for _, p1 := range c.config.SRTPProtectionProfiles {
- for _, p2 := range hs.clientHello.srtpProtectionProfiles {
- if p1 == p2 {
- serverExtensions.srtpProtectionProfile = p1
- c.srtpProtectionProfile = p1
- break SRTPLoop
- }
+ for _, p := range c.config.SRTPProtectionProfiles {
+ if slices.Contains(hs.clientHello.srtpProtectionProfiles, p) {
+ serverExtensions.srtpProtectionProfile = p
+ c.srtpProtectionProfile = p
+ break
}
}
}
@@ -2422,53 +2370,34 @@
// tryCipherSuite returns a cipherSuite with the given id if that cipher suite
// is acceptable to use.
func (c *Conn) tryCipherSuite(id uint16, supportedCipherSuites []uint16, version uint16, ellipticOk, ecdsaOk bool) *cipherSuite {
- for _, supported := range supportedCipherSuites {
- if id == supported {
- var candidate *cipherSuite
-
- for _, s := range cipherSuites {
- if s.id == id {
- candidate = s
- break
- }
- }
- if candidate == nil {
- continue
- }
-
- // Don't select a ciphersuite which we can't
- // support for this client.
- if version >= VersionTLS13 || candidate.flags&suiteTLS13 != 0 {
- if version < VersionTLS13 || candidate.flags&suiteTLS13 == 0 {
- continue
- }
- return candidate
- }
- if (candidate.flags&suiteECDHE != 0) && !ellipticOk {
- continue
- }
- if (candidate.flags&suiteECDSA != 0) != ecdsaOk {
- continue
- }
- if version < VersionTLS12 && candidate.flags&suiteTLS12 != 0 {
- continue
- }
- return candidate
- }
+ candidate := mutualCipherSuite(supportedCipherSuites, id)
+ if candidate == nil {
+ return nil
}
- return nil
+ // Don't select a ciphersuite which we can't
+ // support for this client.
+ if version >= VersionTLS13 || candidate.flags&suiteTLS13 != 0 {
+ if version < VersionTLS13 || candidate.flags&suiteTLS13 == 0 {
+ return nil
+ }
+ return candidate
+ }
+ if (candidate.flags&suiteECDHE != 0) && !ellipticOk {
+ return nil
+ }
+ if (candidate.flags&suiteECDSA != 0) != ecdsaOk {
+ return nil
+ }
+ if version < VersionTLS12 && candidate.flags&suiteTLS12 != 0 {
+ return nil
+ }
+ return candidate
}
func isTLS12Cipher(id uint16) bool {
- for _, cipher := range cipherSuites {
- if cipher.id != id {
- continue
- }
- return cipher.flags&suiteTLS12 != 0
- }
- // Unknown cipher.
- return false
+ cipher := cipherSuiteFromID(id)
+ return cipher != nil && cipher.flags&suiteTLS12 != 0
}
func isGREASEValue(val uint16) bool {
diff --git a/ssl/test/runner/key_agreement.go b/ssl/test/runner/key_agreement.go
index acea236..5294702 100644
--- a/ssl/test/runner/key_agreement.go
+++ b/ssl/test/runner/key_agreement.go
@@ -728,19 +728,15 @@
func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Credential, clientHello *clientHelloMsg, hello *serverHelloMsg, version uint16) (*serverKeyExchangeMsg, error) {
var curveID CurveID
preferredCurves := config.curvePreferences()
-
-NextCandidate:
for _, candidate := range preferredCurves {
if isPqGroup(candidate) && version < VersionTLS13 {
// Post-quantum "groups" require TLS 1.3.
continue
}
- for _, c := range clientHello.supportedCurves {
- if candidate == c {
- curveID = c
- break NextCandidate
- }
+ if slices.Contains(clientHello.supportedCurves, candidate) {
+ curveID = candidate
+ break
}
}
diff --git a/ssl/test/runner/kyber/kyber.go b/ssl/test/runner/kyber/kyber.go
index dd113f0..82c9dbd 100644
--- a/ssl/test/runner/kyber/kyber.go
+++ b/ssl/test/runner/kyber/kyber.go
@@ -18,14 +18,15 @@
import (
"crypto/subtle"
- "golang.org/x/crypto/sha3"
"io"
+
+ "golang.org/x/crypto/sha3"
)
-const(
- CiphertextSize = 1088
- PublicKeySize = 1184
- PrivateKeySize = 2400
+const (
+ CiphertextSize = 1088
+ PublicKeySize = 1184
+ PrivateKeySize = 2400
)
const (
@@ -104,9 +105,7 @@
type scalar [degree]uint16
func (s *scalar) zero() {
- for i := range s {
- s[i] = 0
- }
+ clear(s[:])
}
// This bit of Python will be referenced in some of the following comments:
diff --git a/ssl/test/runner/packet_adapter.go b/ssl/test/runner/packet_adapter.go
index a8da311..42684fb 100644
--- a/ssl/test/runner/packet_adapter.go
+++ b/ssl/test/runner/packet_adapter.go
@@ -9,6 +9,7 @@
"fmt"
"io"
"net"
+ "slices"
"time"
)
@@ -172,7 +173,7 @@
func (d *damageAdaptor) Write(b []byte) (int, error) {
if d.damage && len(b) > 0 {
- b = append([]byte{}, b...)
+ b = slices.Clone(b)
b[len(b)-1]++
}
return d.Conn.Write(b)