Test-only DTLS implementation in runner.go.
Run against openssl s_client and openssl s_server. This seems to work for a
start, although it may need to become cleverer to stress more of BoringSSL's
implementation for test purposes.
In particular, it assumes a reliable, in-order channel. And it requires that
the peer send handshake fragments in order. Retransmit and whatnot are not
implemented. The peer under test will be expected to handle a lossy channel,
but all loss in the channel will be controlled. MAC errors, etc., are fatal.
Change-Id: I329233cfb0994938fd012667ddf7c6a791ac7164
Reviewed-on: https://boringssl-review.googlesource.com/1390
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/runner/cipher_suites.go b/ssl/test/runner/cipher_suites.go
index 8a9df4c..ed26f09 100644
--- a/ssl/test/runner/cipher_suites.go
+++ b/ssl/test/runner/cipher_suites.go
@@ -52,6 +52,9 @@
// suiteSHA384 indicates that the cipher suite uses SHA384 as the
// handshake hash.
suiteSHA384
+ // suiteNoDTLS indicates that the cipher suite cannot be used
+ // in DTLS.
+ suiteNoDTLS
)
// A cipherSuite is a specific combination of key agreement, cipher and MAC
@@ -76,8 +79,8 @@
{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM},
{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECDSA | suiteTLS12, nil, nil, aeadAESGCM},
{TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
- {TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE, cipherRC4, macSHA1, nil},
- {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECDSA, cipherRC4, macSHA1, nil},
+ {TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE | suiteNoDTLS, cipherRC4, macSHA1, nil},
+ {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECDSA | suiteNoDTLS, cipherRC4, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
{TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECDSA, cipherAES, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil},
@@ -88,8 +91,8 @@
{TLS_DHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, dheRSAKA, 0, cipherAES, macSHA1, nil},
{TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM},
{TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM},
- {TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, 0, cipherRC4, macSHA1, nil},
- {TLS_RSA_WITH_RC4_128_MD5, 16, 16, 0, rsaKA, 0, cipherRC4, macMD5, nil},
+ {TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, suiteNoDTLS, cipherRC4, macSHA1, nil},
+ {TLS_RSA_WITH_RC4_128_MD5, 16, 16, 0, rsaKA, suiteNoDTLS, cipherRC4, macMD5, nil},
{TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
{TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil},
{TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil},
@@ -145,7 +148,7 @@
type macFunction interface {
Size() int
- MAC(digestBuf, seq, header, data []byte) []byte
+ MAC(digestBuf, seq, header, length, data []byte) []byte
}
// fixedNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to
@@ -203,7 +206,7 @@
var ssl30Pad2 = [48]byte{0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c}
-func (s ssl30MAC) MAC(digestBuf, seq, header, data []byte) []byte {
+func (s ssl30MAC) MAC(digestBuf, seq, header, length, data []byte) []byte {
padLength := 48
if s.h.Size() == 20 {
padLength = 40
@@ -214,7 +217,7 @@
s.h.Write(ssl30Pad1[:padLength])
s.h.Write(seq)
s.h.Write(header[:1])
- s.h.Write(header[3:5])
+ s.h.Write(length)
s.h.Write(data)
digestBuf = s.h.Sum(digestBuf[:0])
@@ -234,10 +237,11 @@
return s.h.Size()
}
-func (s tls10MAC) MAC(digestBuf, seq, header, data []byte) []byte {
+func (s tls10MAC) MAC(digestBuf, seq, header, length, data []byte) []byte {
s.h.Reset()
s.h.Write(seq)
s.h.Write(header)
+ s.h.Write(length)
s.h.Write(data)
return s.h.Sum(digestBuf[:0])
}
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index ed60a3b..f14f4e9 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -25,10 +25,11 @@
)
const (
- maxPlaintext = 16384 // maximum plaintext payload length
- maxCiphertext = 16384 + 2048 // maximum ciphertext payload length
- recordHeaderLen = 5 // record header length
- maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB)
+ maxPlaintext = 16384 // maximum plaintext payload length
+ maxCiphertext = 16384 + 2048 // maximum ciphertext payload length
+ tlsRecordHeaderLen = 5 // record header length
+ dtlsRecordHeaderLen = 13
+ maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB)
minVersion = VersionSSL30
maxVersion = VersionTLS12
@@ -48,6 +49,7 @@
const (
typeClientHello uint8 = 1
typeServerHello uint8 = 2
+ typeHelloVerifyRequest uint8 = 3
typeNewSessionTicket uint8 = 4
typeCertificate uint8 = 11
typeServerKeyExchange uint8 = 12
@@ -414,6 +416,10 @@
// SendClientVersion, if non-zero, causes the client to send a different
// TLS version in the ClientHello than the maximum supported version.
SendClientVersion uint16
+
+ // SkipHelloVerifyRequest causes a DTLS server to skip the
+ // HelloVerifyRequest message.
+ SkipHelloVerifyRequest bool
}
func (c *Config) serverInit() {
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index f3e2495..5371a64 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -24,6 +24,7 @@
type Conn struct {
// constant
conn net.Conn
+ isDTLS bool
isClient bool
// constant after handshake; protected by handshakeMutex
@@ -49,8 +50,14 @@
// input/output
in, out halfConn // in.Mutex < out.Mutex
rawInput *block // raw input, right off the wire
- input *block // application data waiting to be read
- hand bytes.Buffer // handshake data waiting to be read
+ input *block // application record waiting to be read
+ hand bytes.Buffer // handshake record waiting to be read
+
+ // DTLS state
+ sendHandshakeSeq uint16
+ recvHandshakeSeq uint16
+ handMsg []byte // pending assembled handshake message
+ handMsgLen int // handshake message length, not including the header
tmp [16]byte
}
@@ -94,8 +101,9 @@
type halfConn struct {
sync.Mutex
- err error // first permanent error
- version uint16 // protocol version
+ err error // first permanent error
+ version uint16 // protocol version
+ isDTLS bool
cipher interface{} // cipher algorithm
mac macFunction
seq [8]byte // 64-bit sequence number
@@ -141,15 +149,18 @@
hc.nextCipher = nil
hc.nextMac = nil
hc.config = config
- for i := range hc.seq {
- hc.seq[i] = 0
- }
+ hc.incEpoch()
return nil
}
// incSeq increments the sequence number.
func (hc *halfConn) incSeq() {
- for i := 7; i >= 0; i-- {
+ limit := 0
+ if hc.isDTLS {
+ // Increment up to the epoch in DTLS.
+ limit = 2
+ }
+ for i := 7; i >= limit; i-- {
hc.seq[i]++
if hc.seq[i] != 0 {
return
@@ -162,11 +173,33 @@
panic("TLS: sequence number wraparound")
}
-// resetSeq resets the sequence number to zero.
-func (hc *halfConn) resetSeq() {
- for i := range hc.seq {
- hc.seq[i] = 0
+// incEpoch resets the sequence number. In DTLS, it increments the
+// epoch half of the sequence number.
+func (hc *halfConn) incEpoch() {
+ limit := 0
+ if hc.isDTLS {
+ for i := 1; i >= 0; i-- {
+ hc.seq[i]++
+ if hc.seq[i] != 0 {
+ break
+ }
+ if i == 0 {
+ panic("TLS: epoch number wraparound")
+ }
+ }
+ limit = 2
}
+ seq := hc.seq[limit:]
+ for i := range seq {
+ seq[i] = 0
+ }
+}
+
+func (hc *halfConn) recordHeaderLen() int {
+ if hc.isDTLS {
+ return dtlsRecordHeaderLen
+ }
+ return tlsRecordHeaderLen
}
// removePadding returns an unpadded slice, in constant time, which is a prefix
@@ -237,6 +270,8 @@
// success boolean, the number of bytes to skip from the start of the record in
// order to get the application payload, and an optional alert value.
func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) {
+ recordHeaderLen := hc.recordHeaderLen()
+
// pull out payload
payload := b.data[recordHeaderLen:]
@@ -248,6 +283,12 @@
paddingGood := byte(255)
explicitIVLen := 0
+ seq := hc.seq[:]
+ if hc.isDTLS {
+ // DTLS sequence numbers are explicit.
+ seq = b.data[3:11]
+ }
+
// decrypt
if hc.cipher != nil {
switch c := hc.cipher.(type) {
@@ -262,7 +303,7 @@
payload = payload[8:]
var additionalData [13]byte
- copy(additionalData[:], hc.seq[:])
+ copy(additionalData[:], seq)
copy(additionalData[8:], b.data[:3])
n := len(payload) - c.Overhead()
additionalData[11] = byte(n >> 8)
@@ -275,7 +316,7 @@
b.resize(recordHeaderLen + explicitIVLen + len(payload))
case cbcMode:
blockSize := c.BlockSize()
- if hc.version >= VersionTLS11 {
+ if hc.version >= VersionTLS11 || hc.isDTLS {
explicitIVLen = blockSize
}
@@ -318,11 +359,11 @@
// strip mac off payload, b.data
n := len(payload) - macSize
- b.data[3] = byte(n >> 8)
- b.data[4] = byte(n)
+ b.data[recordHeaderLen-2] = byte(n >> 8)
+ b.data[recordHeaderLen-1] = byte(n)
b.resize(recordHeaderLen + explicitIVLen + n)
remoteMAC := payload[n:]
- localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], payload[:n])
+ localMAC := hc.mac.MAC(hc.inDigestBuf, seq, b.data[:3], b.data[recordHeaderLen-2:recordHeaderLen], payload[:n])
if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
return false, 0, alertBadRecordMAC
@@ -364,9 +405,11 @@
// encrypt encrypts and macs the data in b.
func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) {
+ recordHeaderLen := hc.recordHeaderLen()
+
// mac
if hc.mac != nil {
- mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:])
+ mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:3], b.data[recordHeaderLen-2:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:])
n := len(b.data)
b.resize(n + len(mac))
@@ -412,8 +455,8 @@
// update length to include MAC and any block padding needed.
n := len(b.data) - recordHeaderLen
- b.data[3] = byte(n >> 8)
- b.data[4] = byte(n)
+ b.data[recordHeaderLen-2] = byte(n >> 8)
+ b.data[recordHeaderLen-1] = byte(n)
hc.incSeq()
return true, 0
@@ -517,6 +560,86 @@
return b, bb
}
+func (c *Conn) doReadRecord(want recordType) (recordType, *block, error) {
+ if c.isDTLS {
+ return c.dtlsDoReadRecord(want)
+ }
+
+ recordHeaderLen := tlsRecordHeaderLen
+
+ if c.rawInput == nil {
+ c.rawInput = c.in.newBlock()
+ }
+ b := c.rawInput
+
+ // Read header, payload.
+ if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
+ // RFC suggests that EOF without an alertCloseNotify is
+ // an error, but popular web sites seem to do this,
+ // so we can't make it an error.
+ // if err == io.EOF {
+ // err = io.ErrUnexpectedEOF
+ // }
+ if e, ok := err.(net.Error); !ok || !e.Temporary() {
+ c.in.setErrorLocked(err)
+ }
+ return 0, nil, err
+ }
+ typ := recordType(b.data[0])
+
+ // No valid TLS record has a type of 0x80, however SSLv2 handshakes
+ // start with a uint16 length where the MSB is set and the first record
+ // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
+ // an SSLv2 client.
+ if want == recordTypeHandshake && typ == 0x80 {
+ c.sendAlert(alertProtocolVersion)
+ return 0, nil, c.in.setErrorLocked(errors.New("tls: unsupported SSLv2 handshake received"))
+ }
+
+ vers := uint16(b.data[1])<<8 | uint16(b.data[2])
+ n := int(b.data[3])<<8 | int(b.data[4])
+ if c.haveVers && vers != c.vers {
+ c.sendAlert(alertProtocolVersion)
+ return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers))
+ }
+ if n > maxCiphertext {
+ c.sendAlert(alertRecordOverflow)
+ return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: oversized record received with length %d", n))
+ }
+ if !c.haveVers {
+ // First message, be extra suspicious:
+ // this might not be a TLS client.
+ // Bail out before reading a full 'body', if possible.
+ // The current max version is 3.1.
+ // If the version is >= 16.0, it's probably not real.
+ // Similarly, a clientHello message encodes in
+ // well under a kilobyte. If the length is >= 12 kB,
+ // it's probably not real.
+ if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
+ c.sendAlert(alertUnexpectedMessage)
+ return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: first record does not look like a TLS handshake"))
+ }
+ }
+ if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ if e, ok := err.(net.Error); !ok || !e.Temporary() {
+ c.in.setErrorLocked(err)
+ }
+ return 0, nil, err
+ }
+
+ // Process message.
+ b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
+ ok, off, err := c.in.decrypt(b)
+ if !ok {
+ c.in.setErrorLocked(c.sendAlert(err))
+ }
+ b.off = off
+ return typ, b, nil
+}
+
// readRecord reads the next TLS record from the connection
// and updates the record layer state.
// c.in.Mutex <= L; c.input == nil.
@@ -541,76 +664,10 @@
}
Again:
- if c.rawInput == nil {
- c.rawInput = c.in.newBlock()
- }
- b := c.rawInput
-
- // Read header, payload.
- if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
- // RFC suggests that EOF without an alertCloseNotify is
- // an error, but popular web sites seem to do this,
- // so we can't make it an error.
- // if err == io.EOF {
- // err = io.ErrUnexpectedEOF
- // }
- if e, ok := err.(net.Error); !ok || !e.Temporary() {
- c.in.setErrorLocked(err)
- }
+ typ, b, err := c.doReadRecord(want)
+ if err != nil {
return err
}
- typ := recordType(b.data[0])
-
- // No valid TLS record has a type of 0x80, however SSLv2 handshakes
- // start with a uint16 length where the MSB is set and the first record
- // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
- // an SSLv2 client.
- if want == recordTypeHandshake && typ == 0x80 {
- c.sendAlert(alertProtocolVersion)
- return c.in.setErrorLocked(errors.New("tls: unsupported SSLv2 handshake received"))
- }
-
- vers := uint16(b.data[1])<<8 | uint16(b.data[2])
- n := int(b.data[3])<<8 | int(b.data[4])
- if c.haveVers && vers != c.vers {
- c.sendAlert(alertProtocolVersion)
- return c.in.setErrorLocked(fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers))
- }
- if n > maxCiphertext {
- c.sendAlert(alertRecordOverflow)
- return c.in.setErrorLocked(fmt.Errorf("tls: oversized record received with length %d", n))
- }
- if !c.haveVers {
- // First message, be extra suspicious:
- // this might not be a TLS client.
- // Bail out before reading a full 'body', if possible.
- // The current max version is 3.1.
- // If the version is >= 16.0, it's probably not real.
- // Similarly, a clientHello message encodes in
- // well under a kilobyte. If the length is >= 12 kB,
- // it's probably not real.
- if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
- c.sendAlert(alertUnexpectedMessage)
- return c.in.setErrorLocked(fmt.Errorf("tls: first record does not look like a TLS handshake"))
- }
- }
- if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
- if err == io.EOF {
- err = io.ErrUnexpectedEOF
- }
- if e, ok := err.(net.Error); !ok || !e.Temporary() {
- c.in.setErrorLocked(err)
- }
- return err
- }
-
- // Process message.
- b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
- ok, off, err := c.in.decrypt(b)
- if !ok {
- c.in.setErrorLocked(c.sendAlert(err))
- }
- b.off = off
data := b.data[b.off:]
if len(data) > maxPlaintext {
err := c.sendAlert(alertRecordOverflow)
@@ -713,6 +770,11 @@
// to the connection and updates the record layer state.
// c.out.Mutex <= L.
func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
+ if c.isDTLS {
+ return c.dtlsWriteRecord(typ, data)
+ }
+
+ recordHeaderLen := tlsRecordHeaderLen
b := c.out.newBlock()
first := true
isClientHello := typ == recordTypeHandshake && len(data) > 0 && data[0] == typeClientHello
@@ -800,10 +862,11 @@
return
}
-// readHandshake reads the next handshake message from
-// the record layer.
-// c.in.Mutex < L; c.out.Mutex < L.
-func (c *Conn) readHandshake() (interface{}, error) {
+func (c *Conn) doReadHandshake() ([]byte, error) {
+ if c.isDTLS {
+ return c.dtlsDoReadHandshake()
+ }
+
for c.hand.Len() < 4 {
if err := c.in.err; err != nil {
return nil, err
@@ -826,13 +889,28 @@
return nil, err
}
}
- data = c.hand.Next(4 + n)
+ return c.hand.Next(4 + n), nil
+}
+
+// readHandshake reads the next handshake message from
+// the record layer.
+// c.in.Mutex < L; c.out.Mutex < L.
+func (c *Conn) readHandshake() (interface{}, error) {
+ data, err := c.doReadHandshake()
+ if err != nil {
+ return nil, err
+ }
+
var m handshakeMessage
switch data[0] {
case typeClientHello:
- m = new(clientHelloMsg)
+ m = &clientHelloMsg{
+ isDTLS: c.isDTLS,
+ }
case typeServerHello:
- m = new(serverHelloMsg)
+ m = &serverHelloMsg{
+ isDTLS: c.isDTLS,
+ }
case typeNewSessionTicket:
m = new(newSessionTicketMsg)
case typeCertificate:
@@ -857,6 +935,8 @@
m = new(nextProtoMsg)
case typeFinished:
m = new(finishedMsg)
+ case typeHelloVerifyRequest:
+ m = new(helloVerifyRequestMsg)
default:
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
@@ -899,7 +979,7 @@
// http://www.imperialviolet.org/2012/01/15/beastfollowup.html
var m int
- if len(b) > 1 && c.vers <= VersionTLS10 {
+ if len(b) > 1 && c.vers <= VersionTLS10 && !c.isDTLS {
if _, ok := c.out.cipher.(cipher.BlockMode); ok {
n, err := c.writeRecord(recordTypeApplicationData, b[:1])
if err != nil {
@@ -938,7 +1018,7 @@
}
n, err = c.input.Read(b)
- if c.input.off >= len(c.input.data) {
+ if c.input.off >= len(c.input.data) || c.isDTLS {
c.in.freeBlock(c.input)
c.input = nil
}
diff --git a/ssl/test/runner/dtls.go b/ssl/test/runner/dtls.go
new file mode 100644
index 0000000..3b3f5a2
--- /dev/null
+++ b/ssl/test/runner/dtls.go
@@ -0,0 +1,326 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// DTLS implementation.
+//
+// NOTE: This is a not even a remotely production-quality DTLS
+// implementation. It is the bare minimum necessary to be able to
+// achieve coverage on BoringSSL's implementation. Of note is that
+// this implementation assumes the underlying net.PacketConn is not
+// only reliable but also ordered. BoringSSL will be expected to deal
+// with simulated loss, but there is no point in forcing the test
+// driver to.
+
+package main
+
+import (
+ "bytes"
+ "crypto/cipher"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+)
+
+func versionToWire(vers uint16, isDTLS bool) uint16 {
+ if isDTLS {
+ return ^(vers - 0x0201)
+ }
+ return vers
+}
+
+func wireToVersion(vers uint16, isDTLS bool) uint16 {
+ if isDTLS {
+ return ^vers + 0x0201
+ }
+ return vers
+}
+
+func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) {
+ recordHeaderLen := dtlsRecordHeaderLen
+
+ if c.rawInput == nil {
+ c.rawInput = c.in.newBlock()
+ }
+ b := c.rawInput
+
+ // Read a new packet only if the current one is empty.
+ if len(b.data) == 0 {
+ // Pick some absurdly large buffer size.
+ b.resize(maxCiphertext + recordHeaderLen)
+ n, err := c.conn.Read(c.rawInput.data)
+ if err != nil {
+ return 0, nil, err
+ }
+ c.rawInput.resize(n)
+ }
+
+ // Read out one record.
+ //
+ // A real DTLS implementation should be tolerant of errors,
+ // but this is test code. We should not be tolerant of our
+ // peer sending garbage.
+ if len(b.data) < recordHeaderLen {
+ return 0, nil, errors.New("dtls: failed to read record header")
+ }
+ typ := recordType(b.data[0])
+ vers := wireToVersion(uint16(b.data[1])<<8|uint16(b.data[2]), c.isDTLS)
+ if c.haveVers && vers != c.vers {
+ c.sendAlert(alertProtocolVersion)
+ return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.vers))
+ }
+ seq := b.data[3:11]
+ // For test purposes, we assume a reliable channel. Require
+ // that the explicit sequence number matches the incrementing
+ // one we maintain. A real implementation would maintain a
+ // replay window and such.
+ if !bytes.Equal(seq, c.in.seq[:]) {
+ c.sendAlert(alertIllegalParameter)
+ return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number"))
+ }
+ n := int(b.data[11])<<8 | int(b.data[12])
+ if n > maxCiphertext || len(b.data) < recordHeaderLen+n {
+ c.sendAlert(alertRecordOverflow)
+ return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: oversized record received with length %d", n))
+ }
+
+ // Process message.
+ b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
+ ok, off, err := c.in.decrypt(b)
+ if !ok {
+ c.in.setErrorLocked(c.sendAlert(err))
+ }
+ b.off = off
+ return typ, b, nil
+}
+
+func (c *Conn) dtlsWriteRecord(typ recordType, data []byte) (n int, err error) {
+ recordHeaderLen := dtlsRecordHeaderLen
+ maxLen := c.config.Bugs.MaxHandshakeRecordLength
+ if maxLen <= 0 {
+ maxLen = 1024
+ }
+
+ b := c.out.newBlock()
+
+ var header []byte
+ if typ == recordTypeHandshake {
+ // Handshake messages have to be modified to include
+ // fragment offset and length and with the header
+ // replicated. Save the header here.
+ //
+ // TODO(davidben): This assumes that data contains
+ // exactly one handshake message. This is incompatible
+ // with FragmentAcrossChangeCipherSpec. (Which is
+ // unfortunate because OpenSSL's DTLS implementation
+ // will probably accept such fragmentation and could
+ // do with a fix + tests.)
+ if len(data) < 4 {
+ // This should not happen.
+ panic(data)
+ }
+ header = data[:4]
+ data = data[4:]
+ }
+
+ firstRun := true
+ for firstRun || len(data) > 0 {
+ firstRun = false
+ m := len(data)
+ var fragment []byte
+ // Handshake messages get fragmented. Other records we
+ // pass-through as is. DTLS should be a packet
+ // interface.
+ if typ == recordTypeHandshake {
+ if m > maxLen {
+ m = maxLen
+ }
+
+ // Standard handshake header.
+ fragment = make([]byte, 0, 12+m)
+ fragment = append(fragment, header...)
+ // message_seq
+ fragment = append(fragment, byte(c.sendHandshakeSeq>>8), byte(c.sendHandshakeSeq))
+ // fragment_offset
+ fragment = append(fragment, byte(n>>16), byte(n>>8), byte(n))
+ // fragment_length
+ fragment = append(fragment, byte(m>>16), byte(m>>8), byte(m))
+ fragment = append(fragment, data[:m]...)
+ } else {
+ fragment = data[:m]
+ }
+
+ // Send the fragment.
+ explicitIVLen := 0
+ explicitIVIsSeq := false
+
+ if cbc, ok := c.out.cipher.(cbcMode); ok {
+ // Block cipher modes have an explicit IV.
+ explicitIVLen = cbc.BlockSize()
+ } else if _, ok := c.out.cipher.(cipher.AEAD); ok {
+ explicitIVLen = 8
+ // The AES-GCM construction in TLS has an
+ // explicit nonce so that the nonce can be
+ // random. However, the nonce is only 8 bytes
+ // which is too small for a secure, random
+ // nonce. Therefore we use the sequence number
+ // as the nonce.
+ explicitIVIsSeq = true
+ } else if c.out.cipher != nil {
+ panic("Unknown cipher")
+ }
+ b.resize(recordHeaderLen + explicitIVLen + len(fragment))
+ b.data[0] = byte(typ)
+ vers := c.vers
+ if vers == 0 {
+ // Some TLS servers fail if the record version is
+ // greater than TLS 1.0 for the initial ClientHello.
+ vers = VersionTLS10
+ }
+ vers = versionToWire(vers, c.isDTLS)
+ b.data[1] = byte(vers >> 8)
+ b.data[2] = byte(vers)
+ // DTLS records include an explicit sequence number.
+ copy(b.data[3:11], c.out.seq[0:])
+ b.data[11] = byte(len(fragment) >> 8)
+ b.data[12] = byte(len(fragment))
+ if explicitIVLen > 0 {
+ explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
+ if explicitIVIsSeq {
+ copy(explicitIV, c.out.seq[:])
+ } else {
+ if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil {
+ break
+ }
+ }
+ }
+ copy(b.data[recordHeaderLen+explicitIVLen:], fragment)
+ c.out.encrypt(b, explicitIVLen)
+
+ // TODO(davidben): A real DTLS implementation needs to
+ // retransmit handshake messages. For testing
+ // purposes, we don't actually care.
+ _, err = c.conn.Write(b.data)
+ if err != nil {
+ break
+ }
+ n += m
+ data = data[m:]
+ }
+ c.out.freeBlock(b)
+
+ // Increment the handshake sequence number for the next
+ // handshake message.
+ if typ == recordTypeHandshake {
+ c.sendHandshakeSeq++
+ }
+
+ if typ == recordTypeChangeCipherSpec {
+ err = c.out.changeCipherSpec(c.config)
+ if err != nil {
+ // Cannot call sendAlert directly,
+ // because we already hold c.out.Mutex.
+ c.tmp[0] = alertLevelError
+ c.tmp[1] = byte(err.(alert))
+ c.writeRecord(recordTypeAlert, c.tmp[0:2])
+ return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
+ }
+ }
+ return
+}
+
+func (c *Conn) dtlsDoReadHandshake() ([]byte, error) {
+ // Assemble a full handshake message. For test purposes, this
+ // implementation assumes fragments arrive in order. It may
+ // need to be cleverer if we ever test BoringSSL's retransmit
+ // behavior.
+ for len(c.handMsg) < 4+c.handMsgLen {
+ // Get a new handshake record if the previous has been
+ // exhausted.
+ if c.hand.Len() == 0 {
+ if err := c.in.err; err != nil {
+ return nil, err
+ }
+ if err := c.readRecord(recordTypeHandshake); err != nil {
+ return nil, err
+ }
+ }
+
+ // Read the next fragment. It must fit entirely within
+ // the record.
+ if c.hand.Len() < 12 {
+ return nil, errors.New("dtls: bad handshake record")
+ }
+ header := c.hand.Next(12)
+ fragN := int(header[1])<<16 | int(header[2])<<8 | int(header[3])
+ fragSeq := uint16(header[4])<<8 | uint16(header[5])
+ fragOff := int(header[6])<<16 | int(header[7])<<8 | int(header[8])
+ fragLen := int(header[9])<<16 | int(header[10])<<8 | int(header[11])
+
+ if c.hand.Len() < fragLen {
+ return nil, errors.New("dtls: fragment length too long")
+ }
+ fragment := c.hand.Next(fragLen)
+
+ // Check it's a fragment for the right message.
+ if fragSeq != c.recvHandshakeSeq {
+ return nil, errors.New("dtls: bad handshake sequence number")
+ }
+
+ // Check that the length is consistent.
+ if c.handMsg == nil {
+ c.handMsgLen = fragN
+ if c.handMsgLen > maxHandshake {
+ return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
+ }
+ // Start with the TLS handshake header,
+ // without the DTLS bits.
+ c.handMsg = append([]byte{}, header[:4]...)
+ } else if fragN != c.handMsgLen {
+ return nil, errors.New("dtls: bad handshake length")
+ }
+
+ // Add the fragment to the pending message.
+ if 4+fragOff != len(c.handMsg) {
+ return nil, errors.New("dtls: bad fragment offset")
+ }
+ if fragOff+fragLen > c.handMsgLen {
+ return nil, errors.New("dtls: bad fragment length")
+ }
+ c.handMsg = append(c.handMsg, fragment...)
+ }
+ c.recvHandshakeSeq++
+ ret := c.handMsg
+ c.handMsg, c.handMsgLen = nil, 0
+ return ret, nil
+}
+
+// DTLSServer returns a new DTLS server side connection
+// using conn as the underlying transport.
+// The configuration config must be non-nil and must have
+// at least one certificate.
+func DTLSServer(conn net.Conn, config *Config) *Conn {
+ return &Conn{
+ config: config,
+ isDTLS: true,
+ conn: conn,
+ in: halfConn{isDTLS: true},
+ out: halfConn{isDTLS: true},
+ }
+}
+
+// DTLSClient returns a new DTLS client side connection
+// using conn as the underlying transport.
+// The config cannot be nil: users must set either ServerHostname or
+// InsecureSkipVerify in the config.
+func DTLSClient(conn net.Conn, config *Config) *Conn {
+ return &Conn{
+ config: config,
+ isClient: true,
+ isDTLS: true,
+ conn: conn,
+ in: halfConn{isDTLS: true},
+ out: halfConn{isDTLS: true},
+ }
+}
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index fa84074..f2cbbe4 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -39,7 +39,11 @@
return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
}
+ c.sendHandshakeSeq = 0
+ c.recvHandshakeSeq = 0
+
hello := &clientHelloMsg{
+ isDTLS: c.isDTLS,
vers: c.config.maxVersion(),
compressionMethods: []uint8{compressionNone},
random: make([]byte, 32),
@@ -70,6 +74,10 @@
if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 {
continue
}
+ // Don't advertise non-DTLS cipher suites on DTLS.
+ if c.isDTLS && suite.flags&suiteNoDTLS != 0 {
+ continue
+ }
hello.cipherSuites = append(hello.cipherSuites, suiteId)
continue NextCipherSuite
}
@@ -154,6 +162,22 @@
if err != nil {
return err
}
+
+ if c.isDTLS {
+ helloVerifyRequest, ok := msg.(*helloVerifyRequestMsg)
+ if ok {
+ hello.raw = nil
+ hello.cookie = helloVerifyRequest.cookie
+ helloBytes = hello.marshal()
+ c.writeRecord(recordTypeHandshake, helloBytes)
+
+ msg, err = c.readHandshake()
+ if err != nil {
+ return err
+ }
+ }
+ }
+
serverHello, ok := msg.(*serverHelloMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
@@ -184,8 +208,8 @@
session: session,
}
- hs.finishedHash.Write(helloBytes)
- hs.finishedHash.Write(hs.serverHello.marshal())
+ hs.writeHash(helloBytes, hs.c.sendHandshakeSeq-1)
+ hs.writeServerHash(hs.serverHello.marshal())
if c.config.Bugs.EarlyChangeCipherSpec > 0 {
hs.establishKeys()
@@ -252,7 +276,7 @@
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
- hs.finishedHash.Write(certMsg.marshal())
+ hs.writeServerHash(certMsg.marshal())
certs := make([]*x509.Certificate, len(certMsg.certificates))
for i, asn1Data := range certMsg.certificates {
@@ -305,7 +329,7 @@
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(cs, msg)
}
- hs.finishedHash.Write(cs.marshal())
+ hs.writeServerHash(cs.marshal())
if cs.statusType == statusTypeOCSP {
c.ocspResponse = cs.response
@@ -321,7 +345,7 @@
skx, ok := msg.(*serverKeyExchangeMsg)
if ok {
- hs.finishedHash.Write(skx.marshal())
+ hs.writeServerHash(skx.marshal())
err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, certs[0], skx)
if err != nil {
c.sendAlert(alertUnexpectedMessage)
@@ -351,7 +375,7 @@
// ClientCertificateType, unless there is some external
// arrangement to the contrary.
- hs.finishedHash.Write(certReq.marshal())
+ hs.writeServerHash(certReq.marshal())
var rsaAvail, ecdsaAvail bool
for _, certType := range certReq.certificateTypes {
@@ -417,7 +441,7 @@
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(shd, msg)
}
- hs.finishedHash.Write(shd.marshal())
+ hs.writeServerHash(shd.marshal())
// If the server requested a certificate then we have to send a
// Certificate message, even if it's empty because we don't have a
@@ -427,7 +451,7 @@
if chainToSend != nil {
certMsg.certificates = chainToSend.Certificate
}
- hs.finishedHash.Write(certMsg.marshal())
+ hs.writeClientHash(certMsg.marshal())
c.writeRecord(recordTypeHandshake, certMsg.marshal())
}
@@ -438,7 +462,7 @@
}
if ckx != nil {
if c.config.Bugs.EarlyChangeCipherSpec < 2 {
- hs.finishedHash.Write(ckx.marshal())
+ hs.writeClientHash(ckx.marshal())
}
c.writeRecord(recordTypeHandshake, ckx.marshal())
}
@@ -486,7 +510,7 @@
}
certVerify.signature = signed
- hs.finishedHash.Write(certVerify.marshal())
+ hs.writeClientHash(certVerify.marshal())
c.writeRecord(recordTypeHandshake, certVerify.marshal())
}
@@ -571,7 +595,7 @@
return errors.New("tls: server's Finished message was incorrect")
}
}
- hs.finishedHash.Write(serverFinished.marshal())
+ hs.writeServerHash(serverFinished.marshal())
return nil
}
@@ -590,7 +614,7 @@
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(sessionTicketMsg, msg)
}
- hs.finishedHash.Write(sessionTicketMsg.marshal())
+ hs.writeServerHash(sessionTicketMsg.marshal())
hs.session = &ClientSessionState{
sessionTicket: sessionTicketMsg.ticket,
@@ -607,6 +631,7 @@
c := hs.c
var postCCSBytes []byte
+ seqno := hs.c.sendHandshakeSeq
if hs.serverHello.nextProtoNeg {
nextProto := new(nextProtoMsg)
proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos)
@@ -615,7 +640,8 @@
c.clientProtocolFallback = fallback
nextProtoBytes := nextProto.marshal()
- hs.finishedHash.Write(nextProtoBytes)
+ hs.writeHash(nextProtoBytes, seqno)
+ seqno++
postCCSBytes = append(postCCSBytes, nextProtoBytes...)
}
@@ -626,7 +652,7 @@
finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret)
}
finishedBytes := finished.marshal()
- hs.finishedHash.Write(finishedBytes)
+ hs.writeHash(finishedBytes, seqno)
postCCSBytes = append(postCCSBytes, finishedBytes...)
if c.config.Bugs.FragmentAcrossChangeCipherSpec {
@@ -643,6 +669,32 @@
return nil
}
+func (hs *clientHandshakeState) writeClientHash(msg []byte) {
+ // writeClientHash is called before writeRecord.
+ hs.writeHash(msg, hs.c.sendHandshakeSeq)
+}
+
+func (hs *clientHandshakeState) writeServerHash(msg []byte) {
+ // writeServerHash is called after readHandshake.
+ hs.writeHash(msg, hs.c.recvHandshakeSeq-1)
+}
+
+func (hs *clientHandshakeState) writeHash(msg []byte, seqno uint16) {
+ if hs.c.isDTLS {
+ // This is somewhat hacky. DTLS hashes a slightly different format.
+ // First, the TLS header.
+ hs.finishedHash.Write(msg[:4])
+ // Then the sequence number and reassembled fragment offset (always 0).
+ hs.finishedHash.Write([]byte{byte(seqno >> 8), byte(seqno), 0, 0, 0})
+ // Then the reassembled fragment (always equal to the message length).
+ hs.finishedHash.Write(msg[1:4])
+ // And then the message body.
+ hs.finishedHash.Write(msg[4:])
+ } else {
+ hs.finishedHash.Write(msg)
+ }
+}
+
// clientSessionCacheKey returns a key used to cache sessionTickets that could
// be used to resume previously negotiated TLS sessions with a server.
func clientSessionCacheKey(serverAddr net.Addr, config *Config) string {
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index 1c633bb..7fe8bf5 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -8,9 +8,11 @@
type clientHelloMsg struct {
raw []byte
+ isDTLS bool
vers uint16
random []byte
sessionId []byte
+ cookie []byte
cipherSuites []uint16
compressionMethods []uint8
nextProtoNeg bool
@@ -32,9 +34,11 @@
}
return bytes.Equal(m.raw, m1.raw) &&
+ m.isDTLS == m1.isDTLS &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
bytes.Equal(m.sessionId, m1.sessionId) &&
+ bytes.Equal(m.cookie, m1.cookie) &&
eqUint16s(m.cipherSuites, m1.cipherSuites) &&
bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
m.nextProtoNeg == m1.nextProtoNeg &&
@@ -54,6 +58,9 @@
}
length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
+ if m.isDTLS {
+ length += 1 + len(m.cookie)
+ }
numExtensions := 0
extensionsLength := 0
if m.nextProtoNeg {
@@ -100,12 +107,18 @@
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
- x[4] = uint8(m.vers >> 8)
- x[5] = uint8(m.vers)
+ vers := versionToWire(m.vers, m.isDTLS)
+ x[4] = uint8(vers >> 8)
+ x[5] = uint8(vers)
copy(x[6:38], m.random)
x[38] = uint8(len(m.sessionId))
copy(x[39:39+len(m.sessionId)], m.sessionId)
y := x[39+len(m.sessionId):]
+ if m.isDTLS {
+ y[0] = uint8(len(m.cookie))
+ copy(y[1:], m.cookie)
+ y = y[1+len(m.cookie):]
+ }
y[0] = uint8(len(m.cipherSuites) >> 7)
y[1] = uint8(len(m.cipherSuites) << 1)
for i, suite := range m.cipherSuites {
@@ -264,7 +277,7 @@
return false
}
m.raw = data
- m.vers = uint16(data[4])<<8 | uint16(data[5])
+ m.vers = wireToVersion(uint16(data[4])<<8|uint16(data[5]), m.isDTLS)
m.random = data[6:38]
sessionIdLen := int(data[38])
if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
@@ -272,6 +285,17 @@
}
m.sessionId = data[39 : 39+sessionIdLen]
data = data[39+sessionIdLen:]
+ if m.isDTLS {
+ if len(data) < 1 {
+ return false
+ }
+ cookieLen := int(data[0])
+ if cookieLen > 32 || len(data) < 1+cookieLen {
+ return false
+ }
+ m.cookie = data[1 : 1+cookieLen]
+ data = data[1+cookieLen:]
+ }
if len(data) < 2 {
return false
}
@@ -425,6 +449,7 @@
type serverHelloMsg struct {
raw []byte
+ isDTLS bool
vers uint16
random []byte
sessionId []byte
@@ -445,6 +470,7 @@
}
return bytes.Equal(m.raw, m1.raw) &&
+ m.isDTLS == m1.isDTLS &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
bytes.Equal(m.sessionId, m1.sessionId) &&
@@ -498,8 +524,9 @@
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
- x[4] = uint8(m.vers >> 8)
- x[5] = uint8(m.vers)
+ vers := versionToWire(m.vers, m.isDTLS)
+ x[4] = uint8(vers >> 8)
+ x[5] = uint8(vers)
copy(x[6:38], m.random)
x[38] = uint8(len(m.sessionId))
copy(x[39:39+len(m.sessionId)], m.sessionId)
@@ -571,7 +598,7 @@
return false
}
m.raw = data
- m.vers = uint16(data[4])<<8 | uint16(data[5])
+ m.vers = wireToVersion(uint16(data[4])<<8|uint16(data[5]), m.isDTLS)
m.random = data[6:38]
sessionIdLen := int(data[38])
if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
@@ -1368,6 +1395,58 @@
return x
}
+type helloVerifyRequestMsg struct {
+ raw []byte
+ vers uint16
+ cookie []byte
+}
+
+func (m *helloVerifyRequestMsg) equal(i interface{}) bool {
+ m1, ok := i.(*helloVerifyRequestMsg)
+ if !ok {
+ return false
+ }
+
+ return m.vers == m1.vers &&
+ bytes.Equal(m.cookie, m1.cookie)
+}
+
+func (m *helloVerifyRequestMsg) marshal() []byte {
+ if m.raw != nil {
+ return m.raw
+ }
+
+ length := 2 + 1 + len(m.cookie)
+
+ x := make([]byte, 4+length)
+ x[0] = typeHelloVerifyRequest
+ x[1] = uint8(length >> 16)
+ x[2] = uint8(length >> 8)
+ x[3] = uint8(length)
+ vers := versionToWire(m.vers, true)
+ x[4] = uint8(vers >> 8)
+ x[5] = uint8(vers)
+ x[6] = uint8(len(m.cookie))
+ copy(x[7:7+len(m.cookie)], m.cookie)
+
+ return x
+}
+
+func (m *helloVerifyRequestMsg) unmarshal(data []byte) bool {
+ if len(data) < 4+2+1 {
+ return false
+ }
+ m.raw = data
+ m.vers = wireToVersion(uint16(data[4])<<8|uint16(data[5]), true)
+ cookieLen := int(data[6])
+ if cookieLen > 32 || len(data) != 7+cookieLen {
+ return false
+ }
+ m.cookie = data[7 : 7+cookieLen]
+
+ return true
+}
+
func eqUint16s(x, y []uint16) bool {
if len(x) != len(y) {
return false
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 3b8ad6a..3a54eb2 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -5,6 +5,7 @@
package main
import (
+ "bytes"
"crypto"
"crypto/ecdsa"
"crypto/rsa"
@@ -40,6 +41,9 @@
// encrypt the tickets with.
config.serverInitOnce.Do(config.serverInit)
+ c.sendHandshakeSeq = 0
+ c.recvHandshakeSeq = 0
+
hs := serverHandshakeState{
c: c,
}
@@ -114,9 +118,44 @@
c.sendAlert(alertProtocolVersion)
return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers)
}
+
+ if c.isDTLS && !config.Bugs.SkipHelloVerifyRequest {
+ helloVerifyRequest := &helloVerifyRequestMsg{
+ vers: c.vers,
+ cookie: make([]byte, 32),
+ }
+ if _, err := io.ReadFull(c.config.rand(), helloVerifyRequest.cookie); err != nil {
+ c.sendAlert(alertInternalError)
+ return false, errors.New("dtls: short read from Rand: " + err.Error())
+ }
+ c.writeRecord(recordTypeHandshake, helloVerifyRequest.marshal())
+
+ msg, err := c.readHandshake()
+ if err != nil {
+ return false, err
+ }
+ newClientHello, ok := msg.(*clientHelloMsg)
+ if !ok {
+ c.sendAlert(alertUnexpectedMessage)
+ return false, unexpectedMessageError(hs.clientHello, msg)
+ }
+ if !bytes.Equal(newClientHello.cookie, helloVerifyRequest.cookie) {
+ return false, errors.New("dtls: invalid cookie")
+ }
+ // Apart from the cookie, client hello must match.
+ hs.clientHello.cookie = newClientHello.cookie
+ if hs.clientHello.equal(newClientHello) {
+ return false, errors.New("dtls: retransmitted ClientHello does not match")
+ }
+ hs.clientHello = newClientHello
+ }
+
+ // Do not set c.haveVers until after HelloVerifyRequest; the
+ // retransmitted ClientHello may not have the final version.
c.haveVers = true
hs.hello = new(serverHelloMsg)
+ hs.hello.isDTLS = c.isDTLS
supportedCurve := false
preferredCurves := config.curvePreferences()
@@ -284,8 +323,8 @@
hs.hello.ticketSupported = c.config.Bugs.RenewTicketOnResume
hs.finishedHash = newFinishedHash(c.vers, hs.suite)
- hs.finishedHash.Write(hs.clientHello.marshal())
- hs.finishedHash.Write(hs.hello.marshal())
+ hs.writeClientHash(hs.clientHello.marshal())
+ hs.writeServerHash(hs.hello.marshal())
c.writeRecord(recordTypeHandshake, hs.hello.marshal())
@@ -312,15 +351,15 @@
hs.hello.cipherSuite = hs.suite.id
hs.finishedHash = newFinishedHash(c.vers, hs.suite)
- hs.finishedHash.Write(hs.clientHello.marshal())
- hs.finishedHash.Write(hs.hello.marshal())
+ hs.writeClientHash(hs.clientHello.marshal())
+ hs.writeServerHash(hs.hello.marshal())
c.writeRecord(recordTypeHandshake, hs.hello.marshal())
certMsg := new(certificateMsg)
certMsg.certificates = hs.cert.Certificate
if !config.Bugs.UnauthenticatedECDH {
- hs.finishedHash.Write(certMsg.marshal())
+ hs.writeServerHash(certMsg.marshal())
c.writeRecord(recordTypeHandshake, certMsg.marshal())
}
@@ -328,7 +367,7 @@
certStatus := new(certificateStatusMsg)
certStatus.statusType = statusTypeOCSP
certStatus.response = hs.cert.OCSPStaple
- hs.finishedHash.Write(certStatus.marshal())
+ hs.writeServerHash(certStatus.marshal())
c.writeRecord(recordTypeHandshake, certStatus.marshal())
}
@@ -339,7 +378,7 @@
return err
}
if skx != nil && !config.Bugs.SkipServerKeyExchange {
- hs.finishedHash.Write(skx.marshal())
+ hs.writeServerHash(skx.marshal())
c.writeRecord(recordTypeHandshake, skx.marshal())
}
@@ -367,12 +406,12 @@
if config.ClientCAs != nil {
certReq.certificateAuthorities = config.ClientCAs.Subjects()
}
- hs.finishedHash.Write(certReq.marshal())
+ hs.writeServerHash(certReq.marshal())
c.writeRecord(recordTypeHandshake, certReq.marshal())
}
helloDone := new(serverHelloDoneMsg)
- hs.finishedHash.Write(helloDone.marshal())
+ hs.writeServerHash(helloDone.marshal())
c.writeRecord(recordTypeHandshake, helloDone.marshal())
var pub crypto.PublicKey // public key for client auth, if any
@@ -390,7 +429,7 @@
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
- hs.finishedHash.Write(certMsg.marshal())
+ hs.writeClientHash(certMsg.marshal())
if len(certMsg.certificates) == 0 {
// The client didn't actually send a certificate
@@ -418,7 +457,7 @@
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(ckx, msg)
}
- hs.finishedHash.Write(ckx.marshal())
+ hs.writeClientHash(ckx.marshal())
// If we received a client cert in response to our certificate request message,
// the client will send us a certificateVerifyMsg immediately after the
@@ -494,7 +533,7 @@
return errors.New("could not validate signature of connection nonces: " + err.Error())
}
- hs.finishedHash.Write(certVerify.marshal())
+ hs.writeClientHash(certVerify.marshal())
}
preMasterSecret, err := keyAgreement.processClientKeyExchange(config, hs.cert, ckx, c.vers)
@@ -550,7 +589,7 @@
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(nextProto, msg)
}
- hs.finishedHash.Write(nextProto.marshal())
+ hs.writeClientHash(nextProto.marshal())
c.clientProtocol = nextProto.proto
}
@@ -571,7 +610,7 @@
return errors.New("tls: client's Finished message is incorrect")
}
- hs.finishedHash.Write(clientFinished.marshal())
+ hs.writeClientHash(clientFinished.marshal())
return nil
}
@@ -595,7 +634,7 @@
return err
}
- hs.finishedHash.Write(m.marshal())
+ hs.writeServerHash(m.marshal())
c.writeRecord(recordTypeHandshake, m.marshal())
return nil
@@ -607,7 +646,7 @@
finished := new(finishedMsg)
finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
postCCSBytes := finished.marshal()
- hs.finishedHash.Write(postCCSBytes)
+ hs.writeServerHash(postCCSBytes)
if c.config.Bugs.FragmentAcrossChangeCipherSpec {
c.writeRecord(recordTypeHandshake, postCCSBytes[:5])
@@ -690,6 +729,32 @@
return nil, nil
}
+func (hs *serverHandshakeState) writeServerHash(msg []byte) {
+ // writeServerHash is called before writeRecord.
+ hs.writeHash(msg, hs.c.sendHandshakeSeq)
+}
+
+func (hs *serverHandshakeState) writeClientHash(msg []byte) {
+ // writeClientHash is called after readHandshake.
+ hs.writeHash(msg, hs.c.recvHandshakeSeq-1)
+}
+
+func (hs *serverHandshakeState) writeHash(msg []byte, seqno uint16) {
+ if hs.c.isDTLS {
+ // This is somewhat hacky. DTLS hashes a slightly different format.
+ // First, the TLS header.
+ hs.finishedHash.Write(msg[:4])
+ // Then the sequence number and reassembled fragment offset (always 0).
+ hs.finishedHash.Write([]byte{byte(seqno >> 8), byte(seqno), 0, 0, 0})
+ // Then the reassembled fragment (always equal to the message length).
+ hs.finishedHash.Write(msg[1:4])
+ // And then the message body.
+ hs.finishedHash.Write(msg[4:])
+ } else {
+ hs.finishedHash.Write(msg)
+ }
+}
+
// tryCipherSuite returns a cipherSuite with the given id if that cipher suite
// is acceptable to use.
func (c *Conn) tryCipherSuite(id uint16, supportedCipherSuites []uint16, version uint16, ellipticOk, ecdsaOk bool) *cipherSuite {
@@ -717,6 +782,9 @@
if version < VersionTLS12 && candidate.flags&suiteTLS12 != 0 {
continue
}
+ if c.isDTLS && candidate.flags&suiteNoDTLS != 0 {
+ continue
+ }
return candidate
}
}
diff --git a/ssl/test/runner/key_agreement.go b/ssl/test/runner/key_agreement.go
index 929eb06..a678fee 100644
--- a/ssl/test/runner/key_agreement.go
+++ b/ssl/test/runner/key_agreement.go
@@ -74,6 +74,7 @@
if config.Bugs.RsaClientKeyExchangeVersion != 0 {
vers = config.Bugs.RsaClientKeyExchangeVersion
}
+ vers = versionToWire(vers, clientHello.isDTLS)
preMasterSecret[0] = byte(vers >> 8)
preMasterSecret[1] = byte(vers)
_, err := io.ReadFull(config.rand(), preMasterSecret[2:])