blob: d2625070b080996521e84048260e12a7de06f3ef [file] [log] [blame]
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// TLS low level connection and record layer
package runner
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/ecdsa"
"crypto/subtle"
"crypto/x509"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"slices"
"sync"
"time"
"golang.org/x/crypto/chacha20"
"golang.org/x/crypto/cryptobyte"
)
// dtlsRecordInfo describes a single received DTLS record. Conn stores one in
// lastRecordInFlight to remember the previous handshake or ChangeCipherSpec
// record of the peer's current flight.
type dtlsRecordInfo struct {
	// typ is the record's content type.
	typ recordType
	// epoch is the DTLS epoch the record was received in.
	epoch uint16
	// bytesAvailable is the number of additional bytes of plaintext that could
	// have been added to this record without exceeding the packet limit.
	bytesAvailable int
}
// A Conn represents a secured connection.
// It implements the net.Conn interface.
type Conn struct {
	// constant
	conn net.Conn
	isDTLS bool
	isClient bool
	// constant after handshake; protected by handshakeMutex
	handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
	handshakeErr error // error resulting from handshake
	wireVersion uint16 // TLS wire version
	vers uint16 // TLS version
	haveVers bool // version has been negotiated
	config *Config // configuration passed to constructor
	handshakeComplete bool
	skipEarlyData bool // On a server, indicates that the client is sending early data that must be skipped over.
	didResume bool // whether this connection was a session resumption
	extendedMasterSecret bool // whether this session used an extended master secret
	cipherSuite *cipherSuite
	ocspResponse []byte // stapled OCSP response
	sctList []byte // signed certificate timestamp list
	peerCertificates []*x509.Certificate
	peerDelegatedCredential []byte
	// verifiedChains contains the certificate chains that we built, as
	// opposed to the ones presented by the server.
	verifiedChains [][]*x509.Certificate
	// serverName contains the server name indicated by the client, if any.
	serverName string
	// firstFinished contains the first Finished hash sent during the
	// handshake. This is the "tls-unique" channel binding value.
	firstFinished [12]byte
	// peerSignatureAlgorithm contains the signature algorithm that was used
	// by the peer in the handshake, or zero if not applicable.
	peerSignatureAlgorithm signatureAlgorithm
	// curveID contains the curve that was used in the handshake, or zero if
	// not applicable.
	curveID CurveID
	// quicTransportParams contains the QUIC transport params received
	// by the peer using codepoint 57.
	quicTransportParams []byte
	// quicTransportParamsLegacy contains the QUIC transport params received
	// by the peer using legacy codepoint 0xffa5.
	quicTransportParamsLegacy []byte
	clientRandom, serverRandom [32]byte
	earlyExporterSecret []byte
	exporterSecret []byte
	resumptionSecret []byte
	clientProtocol string
	clientProtocolFallback bool
	usedALPN bool
	localApplicationSettings, peerApplicationSettings []byte
	hasApplicationSettings bool
	localApplicationSettingsOld, peerApplicationSettingsOld []byte
	hasApplicationSettingsOld bool
	// verify_data values for the renegotiation extension.
	clientVerify []byte
	serverVerify []byte
	channelID *ecdsa.PublicKey
	srtpProtectionProfile uint16
	clientVersion uint16
	// input/output
	in, out halfConn // in.Mutex < out.Mutex
	rawInput bytes.Buffer // raw input, right off the wire
	input bytes.Buffer // application record waiting to be read
	hand bytes.Buffer // handshake record waiting to be read
	// pendingFlight, if PackHandshakeFlight is enabled, is the buffer of
	// handshake data to be split into records at the end of the flight.
	pendingFlight bytes.Buffer
	// DTLS state
	sendHandshakeSeq uint16
	recvHandshakeSeq uint16
	handMsg []byte // pending assembled handshake message
	handMsgLen int // handshake message length, not including the header
	pendingPacket []byte // pending outgoing packet.
	// maxPacketLen is initialized from config.Bugs.MaxPacketLength in init.
	maxPacketLen int
	previousFlight []DTLSMessage
	receivedFlight []DTLSMessage
	receivedFlightRecords []DTLSRecordNumberInfo
	nextFlight []DTLSMessage
	expectedACK []DTLSRecordNumber
	keyUpdateSeen bool
	keyUpdateRequested bool
	// seenOneByteRecord is whether the previous record read by doReadRecord
	// was a one-byte application data record (for ExpectRecordSplitting).
	seenOneByteRecord bool
	// expectTLS13ChangeCipherSpec is set when a TLS 1.3 compatibility
	// ChangeCipherSpec record is expected before the next record is read.
	expectTLS13ChangeCipherSpec bool
	// seenHandshakePackEnd is whether the most recent handshake record was
	// not full for ExpectPackedEncryptedHandshake. If true, no more
	// handshake data may be received until the next flight or epoch change.
	seenHandshakePackEnd bool
	// lastRecordInFlight contains information about the previous handshake or
	// ChangeCipherSpec record from the current flight, or nil if we are not in
	// the middle of reading a flight from the peer.
	lastRecordInFlight *dtlsRecordInfo
	// bytesAvailableInPacket is the number of bytes that were still available
	// in the current DTLS packet, up to a budget of maxPacketLen.
	bytesAvailableInPacket int
	// skipRecordVersionCheck, if true, causes the DTLS record layer to skip the
	// record version check, even if the version is known. This is used when
	// simulating retransmits.
	skipRecordVersionCheck bool
	// echAccepted indicates whether ECH was accepted for this connection.
	echAccepted bool
	// NOTE(review): scratch buffer; its users are outside this chunk.
	tmp [16]byte
}
// init propagates connection-wide settings into both record-layer halves:
// each halfConn learns whether this is DTLS, shares the Conn's configuration,
// and gets a back-pointer to the Conn. It also seeds maxPacketLen from the
// MaxPacketLength bug setting.
func (c *Conn) init() {
	for _, half := range []*halfConn{&c.in, &c.out} {
		half.isDTLS = c.isDTLS
		half.config = c.config
		half.conn = c
	}
	c.maxPacketLen = c.config.Bugs.MaxPacketLength
}
// Access to net.Conn methods.
// Cannot just embed net.Conn because that would
// export the struct field too.
// LocalAddr reports the local network address of the underlying transport.
func (c *Conn) LocalAddr() net.Addr {
	transport := c.conn
	return transport.LocalAddr()
}
// RemoteAddr reports the peer's network address as seen by the underlying
// transport.
func (c *Conn) RemoteAddr() net.Addr {
	transport := c.conn
	return transport.RemoteAddr()
}
// SetDeadline forwards a combined read and write deadline to the underlying
// connection. Passing the zero time.Time disables both timeouts. Once a Write
// has timed out, the TLS state is corrupt and every later write returns the
// same error.
func (c *Conn) SetDeadline(t time.Time) error {
	transport := c.conn
	return transport.SetDeadline(t)
}
// SetReadDeadline forwards a read deadline to the underlying connection.
// Passing the zero time.Time means reads never time out.
func (c *Conn) SetReadDeadline(t time.Time) error {
	transport := c.conn
	return transport.SetReadDeadline(t)
}
// SetWriteDeadline forwards a write deadline to the underlying connection.
// Passing the zero time.Time means writes never time out. Once a Write has
// timed out, the TLS state is corrupt and every later write returns the same
// error.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	transport := c.conn
	return transport.SetWriteDeadline(t)
}
// maxEpochs is the maximum number of past epochs a halfConn retains in
// pastEpochs; changeEpoch drops the oldest once this many are stored.
//
// Arbitrarily cap the number of past epochs to 4. This is far more than is
// necessary. We set a limit only so tests can freely trigger unboundedly many
// KeyUpdates.
const maxEpochs = 4
// epochState holds the record-protection state for one epoch in one direction
// of the connection.
type epochState struct {
	// epoch is the epoch number. In DTLS, newEpochState also writes it into
	// the top two bytes of seq.
	epoch uint16
	// cipher is one of cipher.Stream, *tlsAead, *cbcMode or nullCipher, or nil
	// for no encryption (the cases explicitIVLen accepts).
	cipher any // cipher algorithm
	// recordNumberEncrypter masks DTLS record numbers; useTrafficSecret sets
	// it only for DTLS.
	recordNumberEncrypter recordNumberEncrypter
	// mac is the record MAC, or nil when no separate MAC is used
	// (useTrafficSecret always passes nil).
	mac macFunction
	// seq is the big-endian sequence number of the next record.
	seq [8]byte
}
// A halfConn represents one direction of the record layer
// connection, either sending or receiving.
type halfConn struct {
	sync.Mutex
	// NOTE(review): setErrorLocked overwrites err unconditionally, so this
	// holds the most recently recorded error rather than strictly the first.
	err error // first permanent error
	version uint16 // protocol version
	wireVersion uint16 // wire version
	isDTLS bool
	// epoch is the epoch currently used to protect records.
	epoch epochState
	// pastEpochs holds up to maxEpochs previous epochs, oldest first.
	pastEpochs []epochState
	// nextEpoch is staged by prepareCipherSpec and installed by
	// changeCipherSpec.
	nextEpoch epochState
	// used to save allocating a new buffer for each MAC.
	macBuf []byte
	// trafficSecret is the secret most recently passed to useTrafficSecret.
	trafficSecret []byte
	config *Config
	conn *Conn
}
// setErrorLocked stores err as this half-connection's error and hands it back,
// which lets callers write `return hc.setErrorLocked(err)`. It does not take
// the mutex itself.
func (hc *halfConn) setErrorLocked(err error) error {
	hc.err = err
	return hc.err
}
// error returns the error last stored by setErrorLocked, or nil.
//
// Access ought to be locked, but the lock was dropped for the renegotiation
// tests: testing never reads and writes the same tls.Conn concurrently.
func (hc *halfConn) error() error {
	return hc.err
}
// getEpoch looks up the state for epochValue. The current epoch is checked
// first, then pastEpochs from oldest to newest. The returned pointer refers
// into hc itself. The boolean is false when the epoch is not retained.
func (hc *halfConn) getEpoch(epochValue uint16) (*epochState, bool) {
	if hc.epoch.epoch != epochValue {
		idx := slices.IndexFunc(hc.pastEpochs, func(e epochState) bool {
			return e.epoch == epochValue
		})
		if idx < 0 {
			return nil, false
		}
		return &hc.pastEpochs[idx], true
	}
	return &hc.epoch, true
}
// changeEpoch makes epoch current and pushes the previously current epoch onto
// the end of pastEpochs. When maxEpochs past epochs are already stored, the
// oldest (index 0) is discarded to make room.
func (hc *halfConn) changeEpoch(epoch epochState) {
	if n := len(hc.pastEpochs); n >= maxEpochs {
		// Slide the history down by one and reuse the final slot.
		copy(hc.pastEpochs, hc.pastEpochs[1:])
		hc.pastEpochs[n-1] = hc.epoch
	} else {
		hc.pastEpochs = append(hc.pastEpochs, hc.epoch)
	}
	hc.epoch = epoch
}
// newEpochState builds the state for a new epoch whose sequence number starts
// at zero. In DTLS the epoch number is also stored in the top two bytes of the
// sequence number.
func (hc *halfConn) newEpochState(epoch uint16, cipher any, mac macFunction) epochState {
	var state epochState
	state.epoch = epoch
	state.cipher = cipher
	state.mac = mac
	if hc.isDTLS {
		binary.BigEndian.PutUint16(state.seq[0:2], epoch)
	}
	return state
}
// prepareCipherSpec stages cipher and mac as the state for the next epoch; a
// later changeCipherSpec installs it. It also records version as the wire
// version along with its decoded protocol version. It panics if the version is
// unknown or the epoch counter would wrap around.
func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac macFunction) {
	hc.wireVersion = version
	decoded, known := wireToVersion(version, hc.isDTLS)
	if !known {
		panic("TLS: unknown version")
	}
	hc.version = decoded

	following := hc.epoch.epoch + 1
	if following == 0 {
		panic("TLS: epoch overflow")
	}
	hc.nextEpoch = hc.newEpochState(following, cipher, mac)
}
// changeCipherSpec installs the epoch staged by prepareCipherSpec as the
// current one and clears the staging slot. It returns alertInternalError if
// nothing was staged. With the NullAllCiphers bug, the new epoch has its
// cipher replaced by nullCipher{} and its MAC removed.
func (hc *halfConn) changeCipherSpec() error {
	staged := hc.nextEpoch
	if staged.cipher == nil {
		return alertInternalError
	}
	hc.nextEpoch = epochState{}
	hc.changeEpoch(staged)
	if hc.config.Bugs.NullAllCiphers {
		hc.epoch.cipher, hc.epoch.mac = nullCipher{}, nil
	}
	return nil
}
// useTrafficSecret derives the TLS 1.3 record protection for side and epoch
// from secret and makes it the current epoch, retaining the old one in the
// history. For DTLS it also derives a record number encrypter from the "sn"
// label. With the NullAllCiphers bug, the cipher is replaced by nullCipher{}
// and no record number encrypter is set. It panics on an unknown version or on
// a suite with no TLS 1.3 record number encryption.
func (hc *halfConn) useTrafficSecret(version uint16, suite *cipherSuite, secret []byte, side trafficDirection, epoch uint16) {
	hc.wireVersion = version
	decoded, known := wireToVersion(version, hc.isDTLS)
	if !known {
		panic("TLS: unknown version")
	}
	hc.version = decoded

	aead := deriveTrafficAEAD(version, suite, secret, side, hc.isDTLS)
	state := hc.newEpochState(epoch, aead, nil)
	if hc.config.Bugs.NullAllCiphers {
		state.cipher = nullCipher{}
	} else if hc.isDTLS {
		snKey := hkdfExpandLabel(suite.hash(), secret, []byte("sn"), nil, suite.keyLen, hc.isDTLS)
		switch suite.id {
		case TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384:
			state.recordNumberEncrypter = newAESRecordNumberEncrypter(snKey)
		case TLS_CHACHA20_POLY1305_SHA256:
			state.recordNumberEncrypter = newChachaRecordNumberEncrypter(snKey)
		default:
			panic("Cipher suite does not support TLS 1.3")
		}
	}

	hc.trafficSecret = secret
	hc.changeEpoch(state)
}
// resetCipher goes back to the unencrypted epoch 0 and forgets every other
// epoch. This allows an unencrypted ClientHello to be sent in response to
// HelloRetryRequest after 0-RTT data was rejected. It panics if epoch 0 is no
// longer retained.
func (hc *halfConn) resetCipher() {
	initial, found := hc.getEpoch(0)
	if !found {
		panic("tls: could not find initial epoch")
	}
	// *initial is copied before pastEpochs (which it may point into) is cleared.
	hc.epoch, hc.pastEpochs = *initial, nil
}
// incSeq adds one to epoch's big-endian sequence number in place. In DTLS only
// bytes 2..7 form the counter, because bytes 0..1 hold the epoch. Wraparound
// is not allowed; the code would need to renegotiate first, and reaching that
// point is too unlikely to bother, so it panics instead.
func (hc *halfConn) incSeq(epoch *epochState) {
	first := 0
	if hc.isDTLS {
		first = 2 // leave the epoch bytes untouched
	}
	carry := uint64(1)
	for i := len(epoch.seq) - 1; i >= first && carry != 0; i-- {
		sum := uint64(epoch.seq[i]) + carry
		epoch.seq[i] = byte(sum)
		carry = sum >> 8
	}
	if carry != 0 {
		panic("TLS: sequence number wraparound")
	}
}
// lastRecordNumber returns the most recent record number decrypted or encrypted
// on a halfConn. For output in DTLS, the SequenceNumberMapping bug is applied.
// It panics if the epoch has never been used.
//
// TODO(crbug.com/376641666): This function is a bit hacky. It needs to rewind
// the state back to what the last call actually used. Fix the TLS/DTLS
// abstractions so we can return this value out directly.
func (hc *halfConn) lastRecordNumber(epoch *epochState, isOut bool) DTLSRecordNumber {
	const low48 = 1<<48 - 1
	next := binary.BigEndian.Uint64(epoch.seq[:])
	if next&low48 == 0 {
		panic("tls: epoch has never been used")
	}
	// seq holds the next record number, so step back by one.
	last := next - 1
	if !hc.isDTLS {
		return DTLSRecordNumber{Epoch: uint64(epoch.epoch), Sequence: last}
	}
	if isOut && hc.config.Bugs.SequenceNumberMapping != nil {
		last = hc.config.Bugs.SequenceNumberMapping(last)
	}
	// Drop the epoch embedded in the top two bytes.
	return DTLSRecordNumber{Epoch: uint64(epoch.epoch), Sequence: last & low48}
}
// sequenceNumberForOutput returns the 8-byte sequence number for the next
// record written in epoch. Usually this is epoch.seq itself, not a copy. In
// DTLS with the SequenceNumberMapping bug, it is a fresh buffer holding the
// mapped value, with the original epoch bytes restored.
func (hc *halfConn) sequenceNumberForOutput(epoch *epochState) []byte {
	if !hc.isDTLS {
		return epoch.seq[:]
	}
	mapping := hc.config.Bugs.SequenceNumberMapping
	if mapping == nil {
		return epoch.seq[:]
	}
	mapped := make([]byte, 8)
	binary.BigEndian.PutUint64(mapped, mapping(binary.BigEndian.Uint64(epoch.seq[:])))
	// The DTLS epoch cannot be changed.
	copy(mapped[:2], epoch.seq[:2])
	return mapped
}
// explicitIVLen returns the number of explicit IV/nonce bytes that precede
// each record's ciphertext in epoch:
//   - CBC in TLS 1.1+ or DTLS: the block size;
//   - an AEAD with an explicit nonce: 8;
//   - anything else: 0.
// It panics on an unrecognized cipher type.
func (hc *halfConn) explicitIVLen(epoch *epochState) int {
	switch c := epoch.cipher.(type) {
	case nil, cipher.Stream, nullCipher:
		return 0
	case *tlsAead:
		if !c.explicitNonce {
			return 0
		}
		return 8
	case *cbcMode:
		if hc.isDTLS || hc.version >= VersionTLS11 {
			return c.BlockSize()
		}
		return 0
	default:
		panic("unknown cipher type")
	}
}
// computeMAC computes epoch's record MAC over seq, the first three header bytes
// (type and version), the last two header bytes (length), and data. The
// result is stored in hc.macBuf and is only valid until the next call.
func (hc *halfConn) computeMAC(epoch *epochState, seq, header, data []byte) []byte {
	typeAndVersion := header[:3]
	length := header[len(header)-2:]
	hc.macBuf = epoch.mac.MAC(hc.macBuf[:0], seq, typeAndVersion, length, data)
	return hc.macBuf
}
// removePadding returns an unpadded slice, in constant time, which is a prefix
// of the input. It also returns a byte which is equal to 255 if the padding
// was valid and 0 otherwise. See RFC 2246, section 6.2.3.2
//
// The last byte of payload is the padding length p. Valid padding means the
// final p+1 bytes (p padding bytes plus the length byte) all equal p, so up
// to 256 bytes must be checked. The loop bound depends only on len(payload),
// which is public. If the padding is invalid, only the length byte is removed.
func removePadding(payload []byte) ([]byte, byte) {
	if len(payload) < 1 {
		return payload, 0
	}

	paddingLen := payload[len(payload)-1]
	t := uint(len(payload)-1) - uint(paddingLen)
	// if len(payload)-1 >= paddingLen then the MSB of t is zero
	good := byte(int32(^t) >> 31)

	// Check the length byte (i == 0) and up to 255 padding bytes: 256 in
	// total. Checking only 255 would leave the first padding byte unverified
	// for maximal padding, or payload[0] when the whole payload is padding.
	toCheck := 256
	// The length of the padded data is public, so we can use an if here
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// if i <= paddingLen then the MSB of t is zero
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		good &^= mask&paddingLen ^ mask&b
	}

	// We AND together the bits of good and replicate the result across
	// all the bits.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// Widen before adding: in byte arithmetic 255+1 wraps to 0, which would
	// leave a maximal (256-byte) padding in place.
	toRemove := int(good&paddingLen) + 1
	return payload[:len(payload)-toRemove], good
}
// roundUp returns a rounded up to the next multiple of b (a itself when it is
// already a multiple).
func roundUp(a, b int) int {
	shortfall := (b - a%b) % b
	return a + shortfall
}
// decrypt checks and strips the MAC and decrypts the data in record. It
// returns a success boolean, the encrypted (inner) record type for TLS 1.3
// (or 0 if there is none), the application payload, and an alert value that
// is meaningful only when ok is false. Decryption occurs in-place, so the
// contents of record will be overwritten as part of this process.
//
// The first recordHeaderLen bytes of record are the record header. When a MAC
// is in use, the header's length field (its last two bytes) is rewritten to
// the plaintext length before the MAC is computed. The epoch's sequence number
// is incremented only on success.
func (hc *halfConn) decrypt(epoch *epochState, recordHeaderLen int, record []byte) (ok bool, contentType recordType, data []byte, alertValue alert) {
	// pull out payload
	payload := record[recordHeaderLen:]
	macSize := 0
	if epoch.mac != nil {
		macSize = epoch.mac.Size()
	}
	// paddingGood stays 255 unless CBC padding removal reports bad padding;
	// it is folded into the MAC comparison below.
	paddingGood := byte(255)
	explicitIVLen := hc.explicitIVLen(epoch)
	// decrypt
	if epoch.cipher != nil {
		switch c := epoch.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case *tlsAead:
			nonce := epoch.seq[:]
			if hc.isDTLS && hc.version >= VersionTLS13 && !hc.conn.useDTLSPlaintextHeader() {
				// Unlike DTLS 1.2, DTLS 1.3's nonce construction does not use
				// the epoch number. We store the epoch and nonce numbers
				// together, so make a copy without the epoch.
				nonce = make([]byte, 8)
				copy(nonce[2:], epoch.seq[2:])
			}
			if explicitIVLen != 0 {
				if len(payload) < explicitIVLen {
					return false, 0, nil, alertBadRecordMAC
				}
				// The explicit nonce is carried at the front of the payload.
				nonce = payload[:explicitIVLen]
				payload = payload[explicitIVLen:]
			}
			var additionalData []byte
			if hc.version < VersionTLS13 {
				// seq_num || type || version || plaintext length
				additionalData = make([]byte, 13)
				copy(additionalData, epoch.seq[:])
				copy(additionalData[8:], record[:3])
				n := len(payload) - c.Overhead()
				additionalData[11] = byte(n >> 8)
				additionalData[12] = byte(n)
			} else {
				// TLS 1.3 authenticates the record header as received.
				additionalData = record[:recordHeaderLen]
			}
			var err error
			payload, err = c.Open(payload[:0], nonce, payload, additionalData)
			if err != nil {
				return false, 0, nil, alertBadRecordMAC
			}
		case *cbcMode:
			blockSize := c.BlockSize()
			if len(payload)%blockSize != 0 || len(payload) < roundUp(explicitIVLen+macSize+1, blockSize) {
				return false, 0, nil, alertBadRecordMAC
			}
			if explicitIVLen > 0 {
				c.SetIV(payload[:explicitIVLen])
				payload = payload[explicitIVLen:]
			}
			c.CryptBlocks(payload, payload)
			payload, paddingGood = removePadding(payload)
			// note that we still have a timing side-channel in the
			// MAC check, below. An attacker can align the record
			// so that a correct padding will cause one less hash
			// block to be calculated. Then they can iteratively
			// decrypt a record by breaking each byte. See
			// "Password Interception in a SSL/TLS Channel", Brice
			// Canvel et al.
			//
			// However, our behavior matches OpenSSL, so we leak
			// only as much as they do.
		case nullCipher:
			break
		default:
			panic("unknown cipher type")
		}
		if hc.version >= VersionTLS13 {
			// Strip trailing zero padding; the last non-zero byte is the
			// inner content type. An all-zero plaintext is rejected.
			i := len(payload)
			for i > 0 && payload[i-1] == 0 {
				i--
			}
			payload = payload[:i]
			if len(payload) == 0 {
				return false, 0, nil, alertUnexpectedMessage
			}
			contentType = recordType(payload[len(payload)-1])
			payload = payload[:len(payload)-1]
		}
	}
	// check, strip mac
	if epoch.mac != nil {
		if len(payload) < macSize {
			return false, 0, nil, alertBadRecordMAC
		}
		// strip mac off payload
		n := len(payload) - macSize
		remoteMAC := payload[n:]
		payload = payload[:n]
		// The MAC covers the plaintext length, so patch it into the header.
		record[recordHeaderLen-2] = byte(n >> 8)
		record[recordHeaderLen-1] = byte(n)
		localMAC := hc.computeMAC(epoch, epoch.seq[:], record[:recordHeaderLen], payload)
		if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
			return false, 0, nil, alertBadRecordMAC
		}
	}
	hc.incSeq(epoch)
	return true, contentType, payload, 0
}
// extendSlice lengthens *data by n bytes, reallocating only if its capacity is
// too small, and returns the newly added tail. The tail may contain stale
// bytes from the backing array, so callers must overwrite it.
func extendSlice(data *[]byte, n int) []byte {
	start := len(*data)
	*data = slices.Grow(*data, n)[:start+n]
	return (*data)[start:]
}
// computingCBCPaddingLength returns the number of CBC padding bytes, including
// the trailing length byte, for a payload (plaintext + MAC) of payloadLen
// bytes. Normally this is the smallest amount that reaches a block boundary.
// With the MaxPadding bug, as many extra whole blocks as possible are added
// while keeping the total at most 256.
func computingCBCPaddingLength(payloadLen, blockSize int, config *Config) int {
	minimal := blockSize - payloadLen%blockSize
	if !config.Bugs.MaxPadding {
		return minimal
	}
	extraBlocks := (256 - minimal) / blockSize
	return minimal + extraBlocks*blockSize
}
// appendCBCPadding appends paddingLen bytes of CBC padding to b and returns the
// extended slice. Every appended byte, including the final length byte, is
// paddingLen-1. The first padding byte is corrupted in two cases: when the
// PaddingFirstByteBad bug is set, or when PaddingFirstByteBadIf255 is set and
// paddingLen is 256.
func appendCBCPadding(b []byte, paddingLen int, config *Config) []byte {
	fill := byte(paddingLen - 1)
	padding := extendSlice(&b, paddingLen)
	for i := range padding {
		padding[i] = fill
	}
	corruptFirst := config.Bugs.PaddingFirstByteBad ||
		(config.Bugs.PaddingFirstByteBadIf255 && paddingLen == 256)
	if corruptFirst {
		padding[0] ^= 0xff
	}
	return b
}
// maxEncryptOverhead returns an upper bound on how many bytes encrypt adds to a
// payload of payloadLen bytes in epoch. The bound sums the MAC, the explicit
// IV, the cipher's own expansion (AEAD tag or CBC padding), and, for TLS 1.3,
// the content type byte plus the configured RecordPadding. It panics on an
// unrecognized cipher type.
func (hc *halfConn) maxEncryptOverhead(epoch *epochState, payloadLen int) int {
	macSize := 0
	if epoch.mac != nil {
		macSize = epoch.mac.Size()
	}
	total := hc.explicitIVLen(epoch) + macSize
	if hc.version >= VersionTLS13 {
		total += 1 + hc.config.Bugs.RecordPadding // type + padding
	}
	switch c := epoch.cipher.(type) {
	case nil, cipher.Stream, *nullCipher, nullCipher:
		// No expansion beyond what is already counted.
	case *tlsAead:
		total += c.Overhead()
	case *cbcMode:
		total += computingCBCPaddingLength(payloadLen+macSize, c.BlockSize(), hc.config)
	default:
		panic("unknown cipher type")
	}
	return total
}
// useDTLSPlaintextHeader reports whether the DTLSUsePlaintextRecordHeader bug
// is in effect, which happens only after the handshake completes. When it is,
// encrypt and decrypt keep the DTLS 1.2-style nonce (including the epoch) for
// DTLS 1.3 AEADs.
func (c *Conn) useDTLSPlaintextHeader() bool {
	if !c.config.Bugs.DTLSUsePlaintextRecordHeader {
		return false
	}
	return c.handshakeComplete
}
// encrypt encrypts and MACs the data in payload, appending it to record. On
// entry, the last headerLen bytes of record must be the header. The length
// (which must be in the last two bytes of the header) should be computed for
// the unencrypted, unpadded payload. It will be updated, potentially in-place,
// with the final length.
//
// For TLS 1.3 with an active cipher, typ is appended as the inner content
// type, followed by RecordPadding zero bytes. headerHasLength reports whether
// the header carries a length field to rewrite. The returned slice, which may
// have been reallocated, holds the original prefix (ending in the header)
// followed by the protected record. The epoch's sequence number is incremented
// on success.
func (hc *halfConn) encrypt(epoch *epochState, record, payload []byte, typ recordType, headerLen int, headerHasLength bool) ([]byte, error) {
	seq := hc.sequenceNumberForOutput(epoch)
	prefixLen := len(record)
	// header may keep pointing at the old backing array if record is
	// reallocated below. It is only read (for the MAC and the 1.2 additional
	// data) before any header bytes are rewritten, so that is harmless.
	header := record[prefixLen-headerLen:]
	explicitIVLen := hc.explicitIVLen(epoch)
	// Reserve some space for the explicit IV. The slice may get reallocated
	// after this, so don't use the return value.
	extendSlice(&record, explicitIVLen)
	// Stage the plaintext, TLS 1.3 padding, and TLS 1.2 MAC in the record, to
	// be encrypted in-place.
	record = append(record, payload...)
	if hc.version >= VersionTLS13 && epoch.cipher != nil {
		if hc.config.Bugs.OmitRecordContents {
			// Drop the payload and omit the type byte entirely.
			record = record[:len(record)-len(payload)]
		} else {
			record = append(record, byte(typ))
		}
		padding := extendSlice(&record, hc.config.Bugs.RecordPadding)
		clear(padding)
	}
	if epoch.mac != nil {
		record = append(record, hc.computeMAC(epoch, seq, header, payload)...)
	}
	explicitIV := record[prefixLen : prefixLen+explicitIVLen]
	if epoch.cipher != nil {
		switch c := epoch.cipher.(type) {
		case cipher.Stream:
			if explicitIVLen != 0 {
				panic("tls: unexpected explicit IV length")
			}
			c.XORKeyStream(record[prefixLen:], record[prefixLen:])
		case *tlsAead:
			nonce := seq
			if hc.isDTLS && hc.version >= VersionTLS13 && !hc.conn.useDTLSPlaintextHeader() {
				// Unlike DTLS 1.2, DTLS 1.3's nonce construction does not use
				// the epoch number. We store the epoch and nonce numbers
				// together, so make a copy without the epoch.
				nonce = make([]byte, 8)
				copy(nonce[2:], seq[2:])
			}
			// Save the explicit IV, if not empty.
			if len(explicitIV) != 0 {
				if explicitIVLen != len(nonce) {
					panic("tls: unexpected explicit IV length")
				}
				copy(explicitIV, nonce)
			}
			var additionalData []byte
			if hc.version < VersionTLS13 {
				// (D)TLS 1.2's AD is seq_num || type || version || plaintext length
				additionalData = make([]byte, 13)
				copy(additionalData, seq)
				copy(additionalData[8:], header[:3])
				additionalData[11] = byte(len(payload) >> 8)
				additionalData[12] = byte(len(payload))
			} else {
				// (D)TLS 1.3's AD is the ciphertext record header, so update the
				// length now.
				if headerHasLength {
					n := len(record) - prefixLen + c.Overhead()
					record[prefixLen-2] = byte(n >> 8)
					record[prefixLen-1] = byte(n)
				}
				additionalData = record[prefixLen-headerLen : prefixLen]
			}
			record = c.Seal(record[:prefixLen+explicitIVLen], nonce, record[prefixLen+explicitIVLen:], additionalData)
		case *cbcMode:
			if explicitIVLen > 0 {
				if _, err := io.ReadFull(hc.config.rand(), explicitIV); err != nil {
					return nil, err
				}
				c.SetIV(explicitIV)
			}
			blockSize := c.BlockSize()
			paddingLen := computingCBCPaddingLength(len(record)-prefixLen, blockSize, hc.config)
			record = appendCBCPadding(record, paddingLen, hc.config)
			c.CryptBlocks(record[prefixLen:], record[prefixLen:])
		case nullCipher:
			break
		default:
			panic("unknown cipher type")
		}
	}
	// Update the record header to include the encryption overhead.
	if headerHasLength {
		n := len(record) - prefixLen
		record[prefixLen-2] = byte(n >> 8)
		record[prefixLen-1] = byte(n)
	}
	hc.incSeq(epoch)
	return record, nil
}
// recordNumberEncrypter produces the masks for one epoch's DTLS record
// numbers. useTrafficSecret installs one per epoch for DTLS.
type recordNumberEncrypter interface {
	// generateMask takes a sample of the encrypted record and returns the
	// mask used to encrypt and decrypt record numbers.
	generateMask(sample []byte) []byte
}
// aesRecordNumberEncrypter derives record number masks by AES-encrypting a
// sample of the ciphertext.
type aesRecordNumberEncrypter struct {
	aesCipher cipher.Block
}

// newAESRecordNumberEncrypter returns an AES-based mask generator keyed with
// key. It panics if key is not a valid AES key.
func newAESRecordNumberEncrypter(key []byte) *aesRecordNumberEncrypter {
	block, err := aes.NewCipher(key)
	if err != nil {
		panic("Incorrect usage of newAESRecordNumberEncrypter")
	}
	return &aesRecordNumberEncrypter{aesCipher: block}
}

// generateMask encrypts the first AES block of sample. The mask has the same
// length as sample, and any bytes beyond the first block are zero. sample
// must be at least one block long.
func (a *aesRecordNumberEncrypter) generateMask(sample []byte) []byte {
	mask := make([]byte, len(sample))
	a.aesCipher.Encrypt(mask, sample)
	return mask
}
// chachaRecordNumberEncrypter derives record number masks from the ChaCha20
// keystream.
type chachaRecordNumberEncrypter struct {
	key []byte
}

// newChachaRecordNumberEncrypter returns a ChaCha20-based mask generator that
// keeps a reference to key (the slice is not copied).
func newChachaRecordNumberEncrypter(key []byte) *chachaRecordNumberEncrypter {
	return &chachaRecordNumberEncrypter{key: key}
}

// generateMask treats the first 4 bytes of sample as a little-endian block
// counter and the next 12 as the nonce, then returns the first two bytes of the
// ChaCha20 keystream at that position. It panics if sample is shorter than
// 16 bytes or the cipher cannot be created.
func (c *chachaRecordNumberEncrypter) generateMask(sample []byte) []byte {
	reader := cryptobyte.String(sample)
	var counterBytes, nonce []byte
	if !reader.ReadBytes(&counterBytes, 4) || !reader.ReadBytes(&nonce, 12) {
		panic("chachaRecordNumberEncrypter.GenerateMask called with wrong size sample")
	}
	// Named stream rather than cipher to avoid shadowing crypto/cipher.
	stream, err := chacha20.NewUnauthenticatedCipher(c.key, nonce)
	if err != nil {
		panic("Failed to create chacha20 cipher for record number encryption")
	}
	stream.SetCounter(binary.LittleEndian.Uint32(counterBytes))
	mask := make([]byte, 2)
	stream.XORKeyStream(mask, mask)
	return mask
}
// useInTrafficSecret installs secret as the read-side TLS 1.3 traffic secret
// for epoch. Incoming records are written by the peer, so the opposite side's
// write direction is used. It fails if handshake data is still buffered in
// c.hand. With MockQUICTransport, it also records the read level, secret and
// suite on the mock, and panics for epochs beyond the application level.
// Installing the secret resets seenHandshakePackEnd.
func (c *Conn) useInTrafficSecret(epoch uint16, version uint16, suite *cipherSuite, secret []byte) error {
	if c.hand.Len() > 0 {
		return c.in.setErrorLocked(errors.New("tls: buffered handshake messages on cipher change"))
	}
	side := clientWrite
	if c.isClient {
		side = serverWrite
	}
	if quic := c.config.Bugs.MockQUICTransport; quic != nil {
		if epoch > uint16(encryptionApplication) {
			panic("tls: KeyUpdate processed in QUIC")
		}
		quic.readLevel = encryptionLevel(epoch)
		quic.readSecret = secret
		quic.readCipherSuite = suite.id
	}
	c.in.useTrafficSecret(version, suite, secret, side, epoch)
	c.seenHandshakePackEnd = false
	return nil
}
// useOutTrafficSecret installs secret as the write-side TLS 1.3 traffic secret
// for epoch, using our own write direction. For TLS (not DTLS), any packed
// handshake data is flushed first. With MockQUICTransport, it also records the
// write level, secret and suite on the mock, and panics for epochs beyond the
// application level.
func (c *Conn) useOutTrafficSecret(epoch uint16, version uint16, suite *cipherSuite, secret []byte) {
	if !c.isDTLS {
		// The TLS logic relies on flushHandshake to write out packed handshake
		// data on key changes. The DTLS logic handles key changes directly.
		c.flushHandshake()
	}
	side := serverWrite
	if c.isClient {
		side = clientWrite
	}
	if quic := c.config.Bugs.MockQUICTransport; quic != nil {
		if epoch > uint16(encryptionApplication) {
			panic("tls: KeyUpdate processed in QUIC")
		}
		quic.writeLevel = encryptionLevel(epoch)
		quic.writeSecret = secret
		quic.writeCipherSuite = suite.id
	}
	c.out.useTrafficSecret(version, suite, secret, side, epoch)
}
// setSkipEarlyData marks that early data from the client must be skipped. The
// flag is recorded on the mock QUIC transport when one is configured, and on
// the Conn otherwise.
func (c *Conn) setSkipEarlyData() {
	if quic := c.config.Bugs.MockQUICTransport; quic != nil {
		quic.skipEarlyData = true
		return
	}
	c.skipEarlyData = true
}
// shouldSkipEarlyData reports the flag set by setSkipEarlyData. It reads the
// mock QUIC transport's flag when a mock is configured, and the Conn's own
// flag otherwise.
func (c *Conn) shouldSkipEarlyData() bool {
	quic := c.config.Bugs.MockQUICTransport
	if quic == nil {
		return c.skipEarlyData
	}
	return quic.skipEarlyData
}
// readRawInputUntil reads from the underlying connection until rawInput holds
// at least n bytes. A single read may buffer more than requested. Bytes read
// before an error are still appended to rawInput.
func (c *Conn) readRawInputUntil(n int) error {
	missing := n - c.rawInput.Len()
	if missing <= 0 {
		return nil
	}
	c.rawInput.Grow(missing)
	spare := c.rawInput.AvailableBuffer()
	spare = spare[:cap(spare)]
	got, err := io.ReadAtLeast(c.conn, spare, missing)
	c.rawInput.Write(spare[:got])
	return err
}
// doReadRecord reads, validates and decrypts the next record. It returns the
// record type and the decrypted payload; for encrypted TLS 1.3 records, the
// type is the inner content type. DTLS is delegated to dtlsDoReadRecord with
// the current read epoch. want is the record type the caller expects; it only
// affects the SSLv2 and first-record sanity checks here. Fatal errors are
// recorded via c.in.setErrorLocked.
func (c *Conn) doReadRecord(want recordType) (recordType, []byte, error) {
RestartReadRecord:
	if c.isDTLS {
		return c.dtlsDoReadRecord(&c.in.epoch, want)
	}
	recordHeaderLen := tlsRecordHeaderLen
	// Read header, payload.
	if err := c.readRawInputUntil(recordHeaderLen); err != nil {
		// RFC suggests that EOF without an alertCloseNotify is
		// an error, but popular web sites seem to do this,
		// so we can't make it an error, outside of tests.
		if err == io.EOF && c.config.Bugs.ExpectCloseNotify {
			err = io.ErrUnexpectedEOF
		}
		// Temporary network errors are returned without being recorded, so
		// a later read may retry. NOTE(review): net.Error.Temporary is
		// deprecated in the Go standard library.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return 0, nil, err
	}
	header := c.rawInput.Bytes()[:recordHeaderLen]
	typ := recordType(header[0])
	// No valid TLS record has a type of 0x80, however SSLv2 handshakes
	// start with a uint16 length where the MSB is set and the first record
	// is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
	// an SSLv2 client.
	if want == recordTypeHandshake && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return 0, nil, c.in.setErrorLocked(errors.New("tls: unsupported SSLv2 handshake received"))
	}
	vers := uint16(header[1])<<8 | uint16(header[2])
	n := int(header[3])<<8 | int(header[4])
	// Alerts sent near version negotiation do not have a well-defined
	// record-layer version prior to TLS 1.3. (In TLS 1.3, the record-layer
	// version is irrelevant.)
	if typ != recordTypeAlert {
		var expect uint16
		if c.haveVers {
			expect = c.vers
			if c.vers >= VersionTLS13 {
				expect = VersionTLS12
			}
		} else {
			expect = c.config.Bugs.ExpectInitialRecordVersion
		}
		if expect != 0 && vers != expect {
			c.sendAlert(alertProtocolVersion)
			return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: received record with version %x when expecting version %x", vers, expect))
		}
	}
	if n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: oversized record received with length %d", n))
	}
	if !c.haveVers {
		// First message, be extra suspicious:
		// this might not be a TLS client.
		// Bail out before reading a full 'body', if possible.
		// The current max version is 3.1.
		// If the version is >= 16.0, it's probably not real.
		// Similarly, a clientHello message encodes in
		// well under a kilobyte. If the length is >= 12 kB,
		// it's probably not real.
		if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
			c.sendAlert(alertUnexpectedMessage)
			return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: first record does not look like a TLS handshake"))
		}
	}
	if err := c.readRawInputUntil(recordHeaderLen + n); err != nil {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return 0, nil, err
	}
	// Process message.
	b := c.rawInput.Next(recordHeaderLen + n)
	epoch := &c.in.epoch
	ok, encTyp, data, alertValue := c.in.decrypt(epoch, recordHeaderLen, b)
	if !ok {
		// TLS 1.3 early data uses trial decryption.
		// The undecryptable record has already been consumed from rawInput
		// and the sequence number was not advanced, so just read the next one.
		if c.skipEarlyData {
			goto RestartReadRecord
		}
		return 0, nil, c.in.setErrorLocked(c.sendAlert(alertValue))
	}
	// If the server is expecting a second ClientHello (in response to
	// a HelloRetryRequest) and the client sends early data, there
	// won't be a decryption failure (we will interpret the ciphertext
	// as plaintext application data) but it still needs to be skipped.
	if epoch.cipher == nil && typ == recordTypeApplicationData && c.skipEarlyData {
		goto RestartReadRecord
	}
	// The first record that decrypts successfully ends early data skipping.
	c.skipEarlyData = false
	if c.vers >= VersionTLS13 && epoch.cipher != nil {
		if typ != recordTypeApplicationData {
			return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: outer record type is not application data"))
		}
		typ = encTyp
	}
	if c.config.Bugs.ExpectRecordSplitting && typ == recordTypeApplicationData && len(data) != 1 && !c.seenOneByteRecord {
		return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: application data records were not split"))
	}
	c.seenOneByteRecord = typ == recordTypeApplicationData && len(data) == 1
	return typ, data, nil
}
// readTLS13ChangeCipherSpec consumes the six-byte compatibility
// ChangeCipherSpec record expected in TLS 1.3. It does nothing under the mock
// QUIC transport or DTLS. If the next record is an alert, the
// ChangeCipherSpec is treated as skipped. Otherwise the bytes must match
// exactly: the record version is 3.3 once TLS 1.3 is negotiated and 3.1
// before that. The caller must have set expectTLS13ChangeCipherSpec.
func (c *Conn) readTLS13ChangeCipherSpec() error {
	switch {
	case c.config.Bugs.MockQUICTransport != nil:
		return nil
	case c.isDTLS:
		// ChangeCipherSpec in DTLS 1.3 is handled within dtlsDoReadRecord.
		return nil
	case !c.expectTLS13ChangeCipherSpec:
		panic("c.expectTLS13ChangeCipherSpec not set")
	}

	const ccsLen = 6
	if err := c.readRawInputUntil(ccsLen); err != nil {
		return c.in.setErrorLocked(fmt.Errorf("tls: error reading TLS 1.3 ChangeCipherSpec: %s", err))
	}
	pending := c.rawInput.Bytes()
	if recordType(pending[0]) == recordTypeAlert {
		// The peer may be rejecting a ServerHello too malformed to parse the
		// version from, so let the alert through without a ChangeCipherSpec.
		c.expectTLS13ChangeCipherSpec = false
		return nil
	}

	minor := byte(1)
	if c.vers >= VersionTLS13 {
		minor = 3
	}
	want := []byte{byte(recordTypeChangeCipherSpec), 3, minor, 0, 1, 1}
	if got := pending[:ccsLen]; !bytes.Equal(got, want) {
		return c.in.setErrorLocked(fmt.Errorf("tls: error invalid TLS 1.3 ChangeCipherSpec: %x", got))
	}

	// Discard the record now that it has been validated.
	c.rawInput.Next(ccsLen)
	c.expectTLS13ChangeCipherSpec = false
	return nil
}
// readRecord reads the next record from the connection, dispatches it by
// type, and updates the record-layer state.
//
// want is the record type the caller expects. Handshake records are also
// accepted while reading application data, so that post-handshake messages
// can be processed. Warning-level alerts are dropped and reading continues.
// Any failure is recorded on c.in, and the return value is c.in.err.
// c.in.Mutex <= L; c.input == nil.
func (c *Conn) readRecord(want recordType) error {
	// Caller must be in sync with connection:
	// handshake data if handshake not yet completed,
	// else application data.
	switch want {
	default:
		c.sendAlert(alertInternalError)
		return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
	case recordTypeChangeCipherSpec:
		if c.handshakeComplete {
			c.sendAlert(alertInternalError)
			return c.in.setErrorLocked(errors.New("tls: ChangeCipherSpec requested after handshake complete"))
		}
	case recordTypeApplicationData, recordTypeAlert, recordTypeHandshake, recordTypeACK:
		break
	}

	// A TLS 1.3 compatibility ChangeCipherSpec may precede the next record.
	if c.expectTLS13ChangeCipherSpec {
		if err := c.readTLS13ChangeCipherSpec(); err != nil {
			return err
		}
	}

Again:
	doReadRecord := c.doReadRecord
	if c.config.Bugs.MockQUICTransport != nil {
		doReadRecord = c.config.Bugs.MockQUICTransport.readRecord
	}
	typ, data, err := doReadRecord(want)
	if err != nil {
		return err
	}

	// Enforce the plaintext size limit, which tests may override.
	limit := maxPlaintext
	if c.config.Bugs.MaxReceivePlaintext != 0 {
		limit = c.config.Bugs.MaxReceivePlaintext
	}
	if len(data) > limit {
		err := c.sendAlert(alertRecordOverflow)
		return c.in.setErrorLocked(err)
	}

	// Under ExpectPackedEncryptedHandshake, a short handshake record ends the
	// packed flight, so no further handshake record may follow it directly.
	if typ != recordTypeHandshake {
		c.seenHandshakePackEnd = false
	} else if c.seenHandshakePackEnd {
		return c.in.setErrorLocked(errors.New("tls: peer violated ExpectPackedEncryptedHandshake"))
	}

	switch typ {
	default:
		c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	case recordTypeAlert:
		if len(data) != 2 {
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
			break
		}
		if alert(data[1]) == alertCloseNotify {
			c.in.setErrorLocked(io.EOF)
			break
		}
		switch data[0] {
		case alertLevelWarning:
			// drop on the floor
			goto Again
		case alertLevelError:
			c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
	case recordTypeChangeCipherSpec:
		if typ != want || len(data) != 1 || data[0] != 1 {
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
			break
		}
		if c.hand.Len() != 0 {
			c.in.setErrorLocked(errors.New("tls: buffered handshake messages on cipher change"))
			break
		}
		if c.isDTLS {
			// Track the ChangeCipherSpec record in the current flight.
			c.receivedFlight = append(c.receivedFlight, DTLSMessage{
				Epoch:              c.in.epoch.epoch,
				IsChangeCipherSpec: true,
				Data:               slices.Clone(data),
			})
		}
		if err := c.in.changeCipherSpec(); err != nil {
			c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}
	case recordTypeApplicationData:
		if typ != want {
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
			break
		}
		c.input.Write(data)
	case recordTypeHandshake:
		// Allow handshake data while reading application data to
		// trigger post-handshake messages.
		// TODO(rsc): Should at least pick off connection close.
		if typ != want && want != recordTypeApplicationData {
			return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
		}
		c.hand.Write(data)
		if pack := c.config.Bugs.ExpectPackedEncryptedHandshake; pack > 0 && len(data) < pack && c.out.epoch.cipher != nil {
			c.seenHandshakePackEnd = true
		}
		if c.isDTLS {
			// One info entry is recorded per record, so build it from this
			// record's payload only. c.hand may still hold bytes buffered
			// from earlier records.
			record, err := c.makeDTLSRecordNumberInfo(&c.in.epoch, data)
			if err != nil {
				return err
			}
			c.receivedFlightRecords = append(c.receivedFlightRecords, record)
		}
	case recordTypeACK:
		if typ != want || !c.isDTLS {
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
			break
		}
		if err := c.checkACK(data); err != nil {
			c.in.setErrorLocked(err)
			break
		}
	}

	return c.in.err
}
// sendAlertLocked writes an alert record with the given level and description.
// The FragmentAlert and DoubleAlert bugs change how the two alert bytes are
// framed. An error-level alert also records a "local error" on the outgoing
// half-connection and returns it. Warning-level alerts return nil.
// c.out.Mutex <= L.
func (c *Conn) sendAlertLocked(level byte, err alert) error {
	c.tmp[0], c.tmp[1] = level, byte(err)
	switch {
	case c.config.Bugs.FragmentAlert:
		// Split the alert across two one-byte records.
		c.writeRecord(recordTypeAlert, c.tmp[:1])
		c.writeRecord(recordTypeAlert, c.tmp[1:2])
	case c.config.Bugs.DoubleAlert:
		// Repeat the alert twice inside a single record.
		copy(c.tmp[2:4], c.tmp[:2])
		c.writeRecord(recordTypeAlert, c.tmp[:4])
	default:
		c.writeRecord(recordTypeAlert, c.tmp[:2])
	}

	if level != alertLevelError {
		return nil
	}
	// Error alerts are fatal to the connection.
	return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
}
// sendAlert sends an alert and picks its level automatically.
// no_renegotiation and close_notify are sent as warnings, and every other
// description is sent as a fatal error. It takes c.out.Mutex itself via
// SendAlert.
// L < c.out.Mutex.
func (c *Conn) sendAlert(err alert) error {
	switch err {
	case alertNoRenegotiation, alertCloseNotify:
		return c.SendAlert(alertLevelWarning, err)
	default:
		return c.SendAlert(alertLevelError, err)
	}
}
// SendAlert sends an alert with an explicit level and description. It holds
// c.out.Mutex for the duration of the write. For error-level alerts it
// returns the resulting local error.
func (c *Conn) SendAlert(level byte, err alert) error {
	c.out.Lock()
	defer c.out.Unlock()

	return c.sendAlertLocked(level, err)
}
// writeV2Record frames data as an SSLv2-style record for a V2ClientHello. The
// frame is a two-byte big-endian length with the high bit set, followed by
// the data. It writes straight to the underlying connection and bypasses the
// record layer.
func (c *Conn) writeV2Record(data []byte) (n int, err error) {
	length := len(data)
	framed := append([]byte{byte(length>>8) | 0x80, byte(length)}, data...)
	return c.conn.Write(framed)
}
// writeRecord writes a TLS record with the given type and payload
// to the connection and updates the record layer state.
//
// Before anything is written, handshake messages may be altered by the
// SendWrongMessageType, SendTrailingMessageData and TrailingDataWithFinished
// bugs. DTLS connections and the mock QUIC transport then take their own
// write paths. For TLS, PackHandshakeFlight buffers handshake data in
// pendingFlight instead of sending it. Otherwise any buffered flight is
// flushed first and the record is sent via doWriteRecord. The returned count
// is that of the (possibly modified) data.
// c.out.Mutex <= L.
func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
	c.seenHandshakePackEnd = false
	// NOTE(review): c.hand is the inbound handshake buffer. Presumably this
	// drops flight-tracking state once no received handshake bytes are
	// pending; confirm the intent.
	if c.hand.Len() == 0 {
		c.lastRecordInFlight = nil
	}
	if typ == recordTypeHandshake {
		msgType := data[0]
		if c.config.Bugs.SendWrongMessageType != 0 && msgType == c.config.Bugs.SendWrongMessageType {
			msgType += 42
		}
		// Copy rather than mutate data in place, since the caller owns it.
		if msgType != data[0] {
			data = append([]byte{msgType}, data[1:]...)
		}
		if c.config.Bugs.SendTrailingMessageData != 0 && msgType == c.config.Bugs.SendTrailingMessageData {
			// Add a 0 to the body.
			newData := make([]byte, len(data)+1)
			copy(newData, data)

			// Fix the header.
			newLen := len(newData) - 4
			newData[1] = byte(newLen >> 16)
			newData[2] = byte(newLen >> 8)
			newData[3] = byte(newLen)

			data = newData
		}

		if c.config.Bugs.TrailingDataWithFinished && msgType == typeFinished {
			// Add a 0 to the record. Note unused bytes in |data| may be owned by the
			// caller, so we force a new allocation.
			data = append(data[:len(data):len(data)], 0)
		}
	}

	if c.isDTLS {
		return c.dtlsWriteRecord(typ, data)
	}
	if c.config.Bugs.MockQUICTransport != nil {
		return c.config.Bugs.MockQUICTransport.writeRecord(typ, data)
	}

	if typ == recordTypeHandshake {
		// Prefix the message with an empty HelloRequest (type plus a zero
		// 24-bit length).
		if c.config.Bugs.SendHelloRequestBeforeEveryHandshakeMessage {
			newData := make([]byte, 0, 4+len(data))
			newData = append(newData, typeHelloRequest, 0, 0, 0)
			newData = append(newData, data...)
			data = newData
		}

		// Defer the write; flushHandshake later sends pendingFlight packed
		// into as few records as possible.
		if c.config.Bugs.PackHandshakeFlight {
			c.pendingFlight.Write(data)
			return len(data), nil
		}
	}

	// Flush buffered data before writing anything.
	if err := c.flushHandshake(); err != nil {
		return 0, err
	}

	// Inject a ChangeCipherSpec ahead of application data.
	if typ == recordTypeApplicationData && c.config.Bugs.SendPostHandshakeChangeCipherSpec {
		if _, err := c.doWriteRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
			return 0, err
		}
	}

	return c.doWriteRecord(typ, data)
}
// doWriteRecord encrypts data under the current write epoch and sends it on
// the underlying connection as one or more TLS records. Each record carries
// at most maxPlaintext bytes, unless Bugs.SendLargeRecords is set. Handshake
// data is further capped by Bugs.MaxHandshakeRecordLength. Empty data still
// produces a single empty record.
//
// It returns the number of plaintext bytes successfully written. A write
// failure is returned as-is. In that case a pre-TLS-1.3 ChangeCipherSpec does
// not advance the write cipher state, and the error is not masked by the
// ChangeCipherSpec handling.
// c.out.Mutex <= L.
func (c *Conn) doWriteRecord(typ recordType, data []byte) (n int, err error) {
	first := true
	for len(data) > 0 || first {
		m := len(data)
		if m > maxPlaintext && !c.config.Bugs.SendLargeRecords {
			m = maxPlaintext
		}
		if typ == recordTypeHandshake && c.config.Bugs.MaxHandshakeRecordLength > 0 && m > c.config.Bugs.MaxHandshakeRecordLength {
			m = c.config.Bugs.MaxHandshakeRecordLength
		}
		first = false

		// Determine record version.
		vers := c.vers
		if vers == 0 {
			// Some TLS servers fail if the record version is
			// greater than TLS 1.0 for the initial ClientHello.
			vers = VersionTLS10
		}
		// TLS 1.3 fixes the legacy record-layer version at TLS 1.2 ({3, 3}).
		if c.vers >= VersionTLS13 || c.out.version >= VersionTLS13 {
			vers = VersionTLS12
		}

		if c.config.Bugs.SendRecordVersion != 0 {
			vers = c.config.Bugs.SendRecordVersion
		}
		if c.vers == 0 && c.config.Bugs.SendInitialRecordVersion != 0 {
			vers = c.config.Bugs.SendInitialRecordVersion
		}

		// Assemble the record header.
		epoch := &c.out.epoch
		record := make([]byte, tlsRecordHeaderLen, tlsRecordHeaderLen+m+c.out.maxEncryptOverhead(epoch, m))
		record[0] = byte(typ)
		// Encrypted TLS 1.3 records carry their real type inside the
		// ciphertext; the outer type is application_data unless overridden.
		if c.vers >= VersionTLS13 && epoch.cipher != nil {
			record[0] = byte(recordTypeApplicationData)
			if outerType := c.config.Bugs.OuterRecordType; outerType != 0 {
				record[0] = byte(outerType)
			}
		}
		record[1] = byte(vers >> 8)
		record[2] = byte(vers)
		record[3] = byte(m >> 8) // encrypt will update this
		record[4] = byte(m)

		record, err = c.out.encrypt(epoch, record, data[:m], typ, tlsRecordHeaderLen, true /* header has length */)
		if err != nil {
			return n, err
		}
		_, err = c.conn.Write(record)
		if err != nil {
			break
		}
		n += m
		data = data[m:]
	}

	if err != nil {
		// Report the write failure rather than letting the ChangeCipherSpec
		// handling below overwrite it. Also do not switch cipher state for a
		// record that was never fully sent.
		return n, err
	}

	if typ == recordTypeChangeCipherSpec && c.vers < VersionTLS13 {
		if err := c.out.changeCipherSpec(); err != nil {
			return n, c.sendAlertLocked(alertLevelError, err.(alert))
		}
	}
	return n, nil
}
// flushHandshake sends any handshake data buffered in pendingFlight, in
// records of at most maxPlaintext bytes, and then clears the buffer. DTLS
// connections delegate to dtlsFlushHandshake.
func (c *Conn) flushHandshake() error {
	if c.isDTLS {
		return c.dtlsFlushHandshake()
	}

	for c.pendingFlight.Len() != 0 {
		chunk := make([]byte, maxPlaintext)
		// Reading from a non-empty buffer cannot fail.
		read, _ := c.pendingFlight.Read(chunk)
		if _, err := c.doWriteRecord(recordTypeHandshake, chunk[:read]); err != nil {
			return err
		}
	}
	c.pendingFlight.Reset()
	return nil
}
// ackHandshake acknowledges received handshake data. Only DTLS has anything
// to do here; for TLS it is a no-op.
func (c *Conn) ackHandshake() error {
	if !c.isDTLS {
		return nil
	}
	return c.dtlsACKHandshake()
}
// doReadHandshake returns the next complete handshake message, including its
// four-byte header. It reads handshake records until c.hand buffers enough
// bytes. A declared body length above maxHandshake triggers an internal_error
// alert. DTLS connections delegate to dtlsDoReadHandshake.
func (c *Conn) doReadHandshake() ([]byte, error) {
	if c.isDTLS {
		return c.dtlsDoReadHandshake()
	}

	// fill reads records until at least want bytes are buffered in c.hand.
	fill := func(want int) error {
		for c.hand.Len() < want {
			if err := c.in.err; err != nil {
				return err
			}
			if err := c.readRecord(recordTypeHandshake); err != nil {
				return err
			}
		}
		return nil
	}

	// Header: one type byte followed by a 24-bit big-endian body length.
	if err := fill(4); err != nil {
		return nil, err
	}
	header := c.hand.Bytes()[:4]
	bodyLen := int(header[1])<<16 | int(header[2])<<8 | int(header[3])
	if bodyLen > maxHandshake {
		return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
	}

	if err := fill(4 + bodyLen); err != nil {
		return nil, err
	}
	return c.hand.Next(4 + bodyLen), nil
}
// readHandshake reads the next handshake message from the record layer and
// parses it into the message type named by its first byte. An unknown type
// sends an unexpected_message alert, and a body that fails to parse sends a
// decode_error alert. A ServerHello whose version bytes are TLS 1.2 and whose
// random equals tls13HelloRetryRequest is parsed as a helloRetryRequestMsg.
// c.in.Mutex < L; c.out.Mutex < L.
func (c *Conn) readHandshake() (any, error) {
	raw, err := c.doReadHandshake()
	if err != nil {
		return nil, err
	}

	var msg handshakeMessage
	switch raw[0] {
	case typeHelloRequest:
		msg = new(helloRequestMsg)
	case typeClientHello:
		msg = &clientHelloMsg{isDTLS: c.isDTLS}
	case typeServerHello:
		msg = &serverHelloMsg{isDTLS: c.isDTLS}
	case typeNewSessionTicket:
		msg = &newSessionTicketMsg{
			vers:   c.wireVersion,
			isDTLS: c.isDTLS,
		}
	case typeEncryptedExtensions:
		if c.isClient {
			msg = new(encryptedExtensionsMsg)
		} else {
			msg = new(clientEncryptedExtensionsMsg)
		}
	case typeCertificate:
		msg = &certificateMsg{hasRequestContext: c.vers >= VersionTLS13}
	case typeCompressedCertificate:
		msg = new(compressedCertificateMsg)
	case typeCertificateRequest:
		msg = &certificateRequestMsg{
			vers:                  c.wireVersion,
			hasSignatureAlgorithm: c.vers >= VersionTLS12,
			hasRequestContext:     c.vers >= VersionTLS13,
		}
	case typeCertificateStatus:
		msg = new(certificateStatusMsg)
	case typeServerKeyExchange:
		msg = new(serverKeyExchangeMsg)
	case typeServerHelloDone:
		msg = new(serverHelloDoneMsg)
	case typeClientKeyExchange:
		msg = new(clientKeyExchangeMsg)
	case typeCertificateVerify:
		msg = &certificateVerifyMsg{hasSignatureAlgorithm: c.vers >= VersionTLS12}
	case typeNextProtocol:
		msg = new(nextProtoMsg)
	case typeFinished:
		msg = new(finishedMsg)
	case typeHelloVerifyRequest:
		msg = new(helloVerifyRequestMsg)
	case typeChannelID:
		msg = new(channelIDMsg)
	case typeKeyUpdate:
		msg = new(keyUpdateMsg)
	case typeEndOfEarlyData:
		msg = new(endOfEarlyDataMsg)
	default:
		return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	// Unmarshallers may keep references into their input, so hand them a
	// private copy that later reads cannot overwrite.
	raw = slices.Clone(raw)

	// After the 4-byte header, bytes 4-5 hold the version and 6-37 the
	// random. A TLS 1.2 version with the HelloRetryRequest random marks an HRR.
	if raw[0] == typeServerHello && len(raw) >= 38 {
		legacyVersion := uint16(raw[4])<<8 | uint16(raw[5])
		if legacyVersion == VersionTLS12 && bytes.Equal(raw[6:38], tls13HelloRetryRequest) {
			msg = new(helloRetryRequestMsg)
		}
	}

	if !msg.unmarshal(raw) {
		return nil, c.in.setErrorLocked(c.sendAlert(alertDecodeError))
	}
	return msg, nil
}
// readHandshakeType reads the next handshake message and requires it to be a
// *T. Any other message sends an unexpected_message alert and returns an
// error describing the mismatch.
func readHandshakeType[T any](c *Conn) (*T, error) {
	msg, err := c.readHandshake()
	if err != nil {
		return nil, err
	}
	typed, ok := msg.(*T)
	if ok {
		return typed, nil
	}
	c.sendAlert(alertUnexpectedMessage)
	return nil, unexpectedMessageError(typed, msg)
}
// SendHalfHelloRequest completes the handshake if needed. It then sends only
// the first two bytes of a HelloRequest, leaving the four-byte handshake
// header incomplete, and flushes them.
func (c *Conn) SendHalfHelloRequest() error {
	if err := c.Handshake(); err != nil {
		return err
	}

	c.out.Lock()
	defer c.out.Unlock()

	partial := []byte{typeHelloRequest, 0}
	if _, err := c.writeRecord(recordTypeHandshake, partial); err != nil {
		return err
	}
	return c.flushHandshake()
}
// Write sends b to the peer as application data, running the handshake first
// if needed. Before the data it does three things:
//   - answers a pending KeyUpdate request,
//   - applies the SendSpuriousAlert bug,
//   - applies the SendHelloRequestBeforeEveryAppDataRecord bug.
//
// The returned count includes a byte sent in a separate split record.
func (c *Conn) Write(b []byte) (int, error) {
	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	if err := c.out.err; err != nil {
		return 0, err
	}
	if !c.handshakeComplete {
		return 0, alertInternalError
	}

	// The peer asked for a key update; send ours before more data.
	if c.keyUpdateRequested {
		if err := c.sendKeyUpdateLocked(keyUpdateNotRequested); err != nil {
			return 0, err
		}
		c.keyUpdateRequested = false
	}

	if spurious := c.config.Bugs.SendSpuriousAlert; spurious != 0 {
		c.sendAlertLocked(alertLevelError, spurious)
	}

	if c.config.Bugs.SendHelloRequestBeforeEveryAppDataRecord {
		c.writeRecord(recordTypeHandshake, []byte{typeHelloRequest, 0, 0, 0})
		c.flushHandshake()
	}

	// SSL 3.0 and TLS 1.0 CBC ciphers use predictable IVs and are open to a
	// chosen-plaintext attack. Sending the first byte in its own record
	// (1/n-1 splitting) effectively randomizes the IV of the remainder.
	//
	// http://www.openssl.org/~bodo/tls-cbc.txt
	// https://bugzilla.mozilla.org/show_bug.cgi?id=665814
	// http://www.imperialviolet.org/2012/01/15/beastfollowup.html
	sent := 0
	if len(b) > 1 && c.vers <= VersionTLS10 && !c.isDTLS {
		if _, isCBC := c.out.epoch.cipher.(*cbcMode); isCBC {
			n, err := c.writeRecord(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			sent, b = 1, b[1:]
		}
	}

	n, err := c.writeRecord(recordTypeApplicationData, b)
	return sent + n, c.out.setErrorLocked(err)
}
// processTLS13NewSessionTicket validates a TLS 1.3 NewSessionTicket against
// the configured test expectations. When a client session cache is configured
// and the ticket has a non-zero lifetime, it stores the resulting session
// under the peer's cache key. With UseFirstSessionTicket, an existing entry is
// kept instead of being replaced. Finally it acknowledges the handshake
// message.
func (c *Conn) processTLS13NewSessionTicket(newSessionTicket *newSessionTicketMsg, cipherSuite *cipherSuite) error {
	ticket := newSessionTicket
	switch {
	case c.config.Bugs.ExpectGREASE && !ticket.hasGREASEExtension:
		return errors.New("tls: no GREASE ticket extension found")
	case c.config.Bugs.ExpectTicketEarlyData && ticket.maxEarlyDataSize == 0:
		return errors.New("tls: no early_data ticket extension found")
	case c.config.Bugs.ExpectNoNewSessionTicket || c.config.Bugs.ExpectNoNonEmptyNewSessionTicket:
		return errors.New("tls: received unexpected NewSessionTicket")
	case c.config.ClientSessionCache == nil || ticket.ticketLifetime == 0:
		// Nothing to store.
		return nil
	}

	lifetime := time.Duration(ticket.ticketLifetime) * time.Second
	session := &ClientSessionState{
		sessionTicket:               ticket.ticket,
		vers:                        c.vers,
		wireVersion:                 c.wireVersion,
		cipherSuite:                 cipherSuite,
		secret:                      deriveSessionPSK(cipherSuite, c.wireVersion, c.resumptionSecret, ticket.ticketNonce, c.isDTLS),
		serverCertificates:          c.peerCertificates,
		sctList:                     c.sctList,
		ocspResponse:                c.ocspResponse,
		ticketCreationTime:          c.config.time(),
		ticketExpiration:            c.config.time().Add(lifetime),
		ticketAgeAdd:                ticket.ticketAgeAdd,
		maxEarlyDataSize:            ticket.maxEarlyDataSize,
		earlyALPN:                   c.clientProtocol,
		hasApplicationSettings:      c.hasApplicationSettings,
		localApplicationSettings:    c.localApplicationSettings,
		peerApplicationSettings:     c.peerApplicationSettings,
		hasApplicationSettingsOld:   c.hasApplicationSettingsOld,
		localApplicationSettingsOld: c.localApplicationSettingsOld,
		peerApplicationSettingsOld:  c.peerApplicationSettingsOld,
	}

	cache := c.config.ClientSessionCache
	key := clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
	if _, exists := cache.Get(key); !exists || !c.config.Bugs.UseFirstSessionTicket {
		cache.Put(key, session)
	}
	return c.ackHandshake()
}
// processKeyUpdate installs the next inbound traffic secret and epoch in
// response to a KeyUpdate from the peer. It refuses to wrap the epoch counter
// to zero unless AllowEpochOverflow is set. If the peer requested an update
// in return, keyUpdateRequested is set. It then acknowledges the message.
func (c *Conn) processKeyUpdate(keyUpdate *keyUpdateMsg) error {
	nextEpoch := c.in.epoch.epoch + 1
	if nextEpoch == 0 && !c.config.Bugs.AllowEpochOverflow {
		return errors.New("tls: too many KeyUpdates")
	}

	nextSecret := updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret, c.isDTLS)
	if err := c.useInTrafficSecret(nextEpoch, c.in.wireVersion, c.cipherSuite, nextSecret); err != nil {
		return err
	}

	if keyUpdate.keyUpdateRequest == keyUpdateRequested {
		c.keyUpdateRequested = true
	}
	return c.ackHandshake()
}
// handlePostHandshakeMessage reads and acts on a handshake message received
// after the handshake completed.
//
// Before TLS 1.3, only a client may receive one, and only a HelloRequest,
// which triggers renegotiation. In TLS 1.3, clients accept NewSessionTicket,
// and both sides accept KeyUpdate unless RejectUnsolicitedKeyUpdate is set.
// Anything else draws an unexpected_message alert.
func (c *Conn) handlePostHandshakeMessage() error {
	msg, err := c.readHandshake()
	if err != nil {
		return err
	}

	if c.vers < VersionTLS13 {
		if !c.isClient {
			c.sendAlert(alertUnexpectedMessage)
			return errors.New("tls: unexpected post-handshake message")
		}
		if _, isHelloRequest := msg.(*helloRequestMsg); !isHelloRequest {
			c.sendAlert(alertUnexpectedMessage)
			return alertUnexpectedMessage
		}
		// The server asked to renegotiate.
		c.handshakeComplete = false
		return c.Handshake()
	}

	switch m := msg.(type) {
	case *newSessionTicketMsg:
		if c.isClient {
			return c.processTLS13NewSessionTicket(m, c.cipherSuite)
		}
	case *keyUpdateMsg:
		c.keyUpdateSeen = true
		if c.config.Bugs.RejectUnsolicitedKeyUpdate {
			return errors.New("tls: unexpected KeyUpdate message")
		}
		return c.processKeyUpdate(m)
	}

	c.sendAlert(alertUnexpectedMessage)
	return errors.New("tls: unexpected post-handshake message")
}
// ReadKeyUpdate reads a KeyUpdate message from the peer and applies it. The
// message must be of type key_update_not_requested. Because it is read as a
// handshake record, any application data record arriving first is rejected.
func (c *Conn) ReadKeyUpdate() error {
	c.in.Lock()
	defer c.in.Unlock()

	msg, err := readHandshakeType[keyUpdateMsg](c)
	switch {
	case err != nil:
		return err
	case msg.keyUpdateRequest != keyUpdateNotRequested:
		return errors.New("tls: received invalid KeyUpdate message")
	}
	return c.processKeyUpdate(msg)
}
// Renegotiate starts a new handshake on an established connection. A server
// first sends a HelloRequest (or Bugs.BadHelloRequest, if set) to prompt the
// client. A client begins the new handshake directly. If the HelloRequest
// cannot be written or flushed, that error is returned instead of starting a
// handshake the peer was never asked for.
func (c *Conn) Renegotiate() error {
	if !c.isClient {
		helloReq := new(helloRequestMsg).marshal()
		if c.config.Bugs.BadHelloRequest != nil {
			helloReq = c.config.Bugs.BadHelloRequest
		}
		if _, err := c.writeRecord(recordTypeHandshake, helloReq); err != nil {
			return err
		}
		if err := c.flushHandshake(); err != nil {
			return err
		}
	}

	c.handshakeComplete = false
	return c.Handshake()
}
// Read reads decrypted application data into b, running the handshake first if
// needed. Handshake records that arrive while waiting for data are processed
// as post-handshake messages. If a bounded number of attempts yields no data
// and no error, io.ErrNoProgress is returned.
//
// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
func (c *Conn) Read(b []byte) (n int, err error) {
	if err = c.Handshake(); err != nil {
		return
	}

	c.in.Lock()
	defer c.in.Unlock()

	// Some OpenSSL servers send empty records in order to randomize the
	// CBC IV. So this loop ignores a limited number of empty records.
	const maxConsecutiveEmptyRecords = 100
	for emptyRecordCount := 0; emptyRecordCount <= maxConsecutiveEmptyRecords; emptyRecordCount++ {
		// Read records until application data is buffered or an error is recorded.
		for c.input.Len() == 0 && c.in.err == nil {
			if err := c.readRecord(recordTypeApplicationData); err != nil {
				// Soft error, like EAGAIN
				return 0, err
			}
			for c.hand.Len() > 0 {
				// We received handshake bytes, indicating a
				// post-handshake message.
				if err := c.handlePostHandshakeMessage(); err != nil {
					return 0, err
				}
			}
		}
		if err := c.in.err; err != nil {
			return 0, err
		}

		n, err = c.input.Read(b)
		// In DTLS, any part of the record that did not fit in b is discarded.
		if c.input.Len() == 0 || c.isDTLS {
			c.input.Reset()
		}

		// If a close-notify alert is waiting, read it so that
		// we can return (n, EOF) instead of (n, nil), to signal
		// to the HTTP response reading goroutine that the
		// connection is now closed. This eliminates a race
		// where the HTTP response reading goroutine would
		// otherwise not observe the EOF until its next read,
		// by which time a client goroutine might have already
		// tried to reuse the HTTP connection for a new
		// request.
		// See https://codereview.appspot.com/76400046
		// and http://golang.org/issue/3514
		if ri := c.rawInput.Bytes(); !c.isDTLS && n != 0 && err == nil &&
			c.input.Len() == 0 && len(ri) > 0 && recordType(ri[0]) == recordTypeAlert {
			if recErr := c.readRecord(recordTypeApplicationData); recErr != nil {
				err = recErr // will be io.EOF on closeNotify
			}
		}

		if n != 0 || err != nil {
			return n, err
		}
	}

	return 0, io.ErrNoProgress
}
// Close shuts down the connection in three steps:
//  1. If the handshake completed (and NoCloseNotify is unset), it sends
//     close_notify, or Bugs.SendAlertOnShutdown if set.
//  2. With ExpectCloseNotify, it reads records until the peer's close_notify
//     arrives.
//  3. It closes the underlying connection.
//
// An error closing the socket takes precedence over alert-related errors.
func (c *Conn) Close() error {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	var alertErr error
	if c.handshakeComplete && !c.config.Bugs.NoCloseNotify {
		closing := alertCloseNotify
		if override := c.config.Bugs.SendAlertOnShutdown; override != 0 {
			closing = override
		}
		alertErr = c.sendAlert(closing)
		// A fatal alert we sent surfaces as a "local error". Ignore it so we
		// still wait for the peer rather than closing the socket early.
		if opErr, ok := alertErr.(*net.OpError); ok && opErr.Op == "local error" {
			alertErr = nil
		}
	}

	// Drain records until the peer's close_notify (seen as io.EOF) or another
	// error. This keeps the peer's |SSL_shutdown| from failing on a write.
	if c.handshakeComplete && alertErr == nil && c.config.Bugs.ExpectCloseNotify {
		for c.in.error() == nil {
			c.readRecord(recordTypeAlert)
		}
		if readErr := c.in.error(); readErr != io.EOF {
			alertErr = readErr
		}
	}

	if err := c.conn.Close(); err != nil {
		return err
	}
	return alertErr
}
// Handshake runs the client or server handshake if it has not yet run, and
// returns its result. The first error is remembered and returned on every
// later call. Most callers never need to invoke it explicitly: the first
// Read or Write does so automatically.
func (c *Conn) Handshake() error {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	if c.handshakeErr != nil {
		return c.handshakeErr
	}
	if c.handshakeComplete {
		return nil
	}

	if c.isDTLS && c.config.Bugs.SendSplitAlert {
		// Send a fatal internal_error alert with its header and body in two
		// separate writes.
		header := []byte{
			byte(recordTypeAlert),                  // type
			0xfe, 0xff,                             // version
			0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, // sequence
			0x0, 0x2, // length
		}
		c.conn.Write(header)
		c.conn.Write([]byte{alertLevelError, byte(alertInternalError)})
	}

	if early := c.config.Bugs.AppDataBeforeHandshake; early != nil {
		c.writeRecord(recordTypeApplicationData, early)
	}

	if !c.isClient {
		c.handshakeErr = c.serverHandshake()
	} else {
		c.handshakeErr = c.clientHandshake()
	}

	if c.handshakeErr == nil && c.config.Bugs.SendInvalidRecordType {
		c.writeRecord(recordType(42), []byte("invalid record"))
	}
	return c.handshakeErr
}
// ConnectionState returns a snapshot of the negotiated connection parameters.
// Until the handshake completes, the zero ConnectionState is returned.
func (c *Conn) ConnectionState() ConnectionState {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	if !c.handshakeComplete {
		return ConnectionState{}
	}
	return ConnectionState{
		HandshakeComplete:          true,
		Version:                    c.vers,
		NegotiatedProtocol:         c.clientProtocol,
		DidResume:                  c.didResume,
		NegotiatedProtocolIsMutual: !c.clientProtocolFallback,
		NegotiatedProtocolFromALPN: c.usedALPN,
		CipherSuite:                c.cipherSuite.id,
		PeerCertificates:           c.peerCertificates,
		PeerDelegatedCredential:    c.peerDelegatedCredential,
		VerifiedChains:             c.verifiedChains,
		OCSPResponse:               c.ocspResponse,
		ServerName:                 c.serverName,
		ChannelID:                  c.channelID,
		SRTPProtectionProfile:      c.srtpProtectionProfile,
		TLSUnique:                  c.firstFinished[:],
		SCTList:                    c.sctList,
		PeerSignatureAlgorithm:     c.peerSignatureAlgorithm,
		CurveID:                    c.curveID,
		QUICTransportParams:        c.quicTransportParams,
		QUICTransportParamsLegacy:  c.quicTransportParamsLegacy,
		HasApplicationSettings:     c.hasApplicationSettings,
		PeerApplicationSettings:    c.peerApplicationSettings,
		HasApplicationSettingsOld:  c.hasApplicationSettingsOld,
		PeerApplicationSettingsOld: c.peerApplicationSettingsOld,
		ECHAccepted:                c.echAccepted,
	}
}
// VerifyHostname checks that the peer's leaf certificate is valid for host.
// It returns nil if so, or an error describing the problem otherwise. It
// fails in three cases: on server connections, before the handshake
// completes, and when the peer presented no certificates. The last check
// prevents indexing past the end of c.peerCertificates.
func (c *Conn) VerifyHostname(host string) error {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	if !c.isClient {
		return errors.New("tls: VerifyHostname called on TLS server connection")
	}
	if !c.handshakeComplete {
		return errors.New("tls: handshake has not yet been performed")
	}
	if len(c.peerCertificates) == 0 {
		return errors.New("tls: peer did not present a certificate")
	}
	return c.peerCertificates[0].VerifyHostname(host)
}
// exportKeyingMaterialTLS13 computes the TLS 1.3 exporter over secret, using
// the negotiated cipher suite's hash:
//
//	HKDF-Expand-Label(Derive-Secret(secret, label, ""), "exporter", Hash(context), length)
//
// Derive-Secret's empty transcript hash is Hash("").
func (c *Conn) exportKeyingMaterialTLS13(length int, secret, label, context []byte) []byte {
	hash := c.cipherSuite.hash()

	emptyTranscript := hash.New().Sum(nil)
	derived := hkdfExpandLabel(hash, secret, label, emptyTranscript, hash.Size(), c.isDTLS)

	contextDigest := hash.New()
	contextDigest.Write(context)
	return hkdfExpandLabel(hash, derived, []byte("exporter"), contextDigest.Sum(nil), length, c.isDTLS)
}
// ExportKeyingMaterial derives length bytes of keying material from the
// current connection state, as per RFC 5705. TLS 1.3 connections use the
// TLS 1.3 exporter and ignore useContext. Earlier versions run the PRF over
// client_random || server_random, plus a 16-bit context length and the
// context when useContext is set.
func (c *Conn) ExportKeyingMaterial(length int, label, context []byte, useContext bool) ([]byte, error) {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	if !c.handshakeComplete {
		return nil, errors.New("tls: handshake has not yet been performed")
	}
	if c.vers >= VersionTLS13 {
		return c.exportKeyingMaterialTLS13(length, c.exporterSecret, label, context), nil
	}

	var seed []byte
	seed = append(seed, c.clientRandom[:]...)
	seed = append(seed, c.serverRandom[:]...)
	if useContext {
		seed = binary.BigEndian.AppendUint16(seed, uint16(len(context)))
		seed = append(seed, context...)
	}

	out := make([]byte, length)
	prfForVersion(c.vers, c.cipherSuite)(out, c.exporterSecret, label, seed)
	return out, nil
}
// ExportEarlyKeyingMaterial derives keying material from the TLS 1.3 early
// exporter secret. It fails for earlier versions or when no early exporter
// secret is available.
func (c *Conn) ExportEarlyKeyingMaterial(length int, label, context []byte) ([]byte, error) {
	switch {
	case c.vers < VersionTLS13:
		return nil, errors.New("tls: early exporters not defined before TLS 1.3")
	case c.earlyExporterSecret == nil:
		return nil, errors.New("tls: no early exporter secret")
	}
	return c.exportKeyingMaterialTLS13(length, c.earlyExporterSecret, label, context), nil
}
// noRenegotiationInfo reports whether the renegotiation info extension should
// NOT be used in the current handshake. This is always the case under
// Bugs.NoRenegotiationInfo. Otherwise, Bugs.NoRenegotiationInfoInInitial
// applies while no cipher suite has been negotiated yet, and
// Bugs.NoRenegotiationInfoAfterInitial applies once one has.
func (c *Conn) noRenegotiationInfo() bool {
	initial := c.cipherSuite == nil
	return c.config.Bugs.NoRenegotiationInfo ||
		(initial && c.config.Bugs.NoRenegotiationInfoInInitial) ||
		(!initial && c.config.Bugs.NoRenegotiationInfoAfterInitial)
}
// SendNewSessionTicket sends a TLS 1.3 NewSessionTicket from the server.
//
// The ticket has a random ticket_age_add and a lifetime of 24 hours (or
// Bugs.SendTicketLifetime). Its contents are an encrypted session state
// derived with nonce. Under Bugs.SendEmptySessionTicket the ticket is left
// empty. The message is written with writeRecord while holding c.out.Mutex.
func (c *Conn) SendNewSessionTicket(nonce []byte) error {
	if c.isClient || c.vers < VersionTLS13 {
		return errors.New("tls: cannot send post-handshake NewSessionTicket")
	}

	var rawCerts [][]byte
	for _, cert := range c.peerCertificates {
		rawCerts = append(rawCerts, cert.Raw)
	}

	var ageAddBytes [4]byte
	if _, err := io.ReadFull(c.config.rand(), ageAddBytes[:]); err != nil {
		c.sendAlert(alertInternalError)
		return errors.New("tls: short read from Rand: " + err.Error())
	}
	ticketAgeAdd := binary.LittleEndian.Uint32(ageAddBytes[:])

	// TODO(davidben): Allow configuring these values.
	m := &newSessionTicketMsg{
		vers:                        c.wireVersion,
		isDTLS:                      c.isDTLS,
		ticketLifetime:              uint32(24 * time.Hour / time.Second),
		duplicateEarlyDataExtension: c.config.Bugs.DuplicateTicketEarlyData,
		customExtension:             c.config.Bugs.CustomTicketExtension,
		ticketAgeAdd:                ticketAgeAdd,
		ticketNonce:                 nonce,
		maxEarlyDataSize:            c.config.MaxEarlyDataSize,
	}
	if c.config.Bugs.MockQUICTransport != nil && m.maxEarlyDataSize > 0 {
		m.maxEarlyDataSize = 0xffffffff
	}
	if c.config.Bugs.SendTicketLifetime != 0 {
		m.ticketLifetime = uint32(c.config.Bugs.SendTicketLifetime / time.Second)
	}

	state := sessionState{
		vers:                        c.vers,
		cipherSuite:                 c.cipherSuite.id,
		secret:                      deriveSessionPSK(c.cipherSuite, c.wireVersion, c.resumptionSecret, nonce, c.isDTLS),
		certificates:                rawCerts,
		ticketCreationTime:          c.config.time(),
		ticketExpiration:            c.config.time().Add(time.Duration(m.ticketLifetime) * time.Second),
		ticketAgeAdd:                ticketAgeAdd,
		earlyALPN:                   []byte(c.clientProtocol),
		hasApplicationSettings:      c.hasApplicationSettings,
		localApplicationSettings:    c.localApplicationSettings,
		peerApplicationSettings:     c.peerApplicationSettings,
		hasApplicationSettingsOld:   c.hasApplicationSettingsOld,
		localApplicationSettingsOld: c.localApplicationSettingsOld,
		peerApplicationSettingsOld:  c.peerApplicationSettingsOld,
	}

	if !c.config.Bugs.SendEmptySessionTicket {
		sealed, err := c.encryptTicket(&state)
		if err != nil {
			return err
		}
		m.ticket = sealed
	}

	c.out.Lock()
	defer c.out.Unlock()
	_, err := c.writeRecord(recordTypeHandshake, m.marshal())
	return err
}
// SendKeyUpdate sends a KeyUpdate with the given request type and switches to
// the next outbound traffic secret. It takes c.out.Mutex for the duration.
func (c *Conn) SendKeyUpdate(keyUpdateRequest byte) error {
	c.out.Lock()
	defer c.out.Unlock()

	return c.sendKeyUpdateLocked(keyUpdateRequest)
}
// sendKeyUpdateLocked writes a KeyUpdate message under the current keys. It
// then installs the next outbound epoch and traffic secret and flushes. It is
// an error before TLS 1.3, or if the epoch would wrap to zero without
// AllowEpochOverflow.
// c.out.Mutex <= L.
func (c *Conn) sendKeyUpdateLocked(keyUpdateRequest byte) error {
	if c.vers < VersionTLS13 {
		return errors.New("tls: attempted to send KeyUpdate before TLS 1.3")
	}

	nextEpoch := c.out.epoch.epoch + 1
	if nextEpoch == 0 && !c.config.Bugs.AllowEpochOverflow {
		return errors.New("tls: too many KeyUpdates")
	}

	update := keyUpdateMsg{keyUpdateRequest: keyUpdateRequest}
	if _, err := c.writeRecord(recordTypeHandshake, update.marshal()); err != nil {
		return err
	}

	// A real DTLS 1.3 implementation would wait for an ACK before installing
	// the new epoch. The test transport is ordered and reliable, so that wait
	// is unnecessary; tests simulate ACK effects via the WriteFlight callback.
	nextSecret := updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.out.trafficSecret, c.isDTLS)
	c.useOutTrafficSecret(nextEpoch, c.out.wireVersion, c.cipherSuite, nextSecret)
	return c.flushHandshake()
}
// sendFakeEarlyData writes one application_data record, with record version
// {3, 3} and a zero-filled body of the given length, straight to the
// underlying connection. It deliberately bypasses writeRecord, because the
// record layer may be using different keys at this point.
func (c *Conn) sendFakeEarlyData(length int) error {
	record := make([]byte, 5+length)
	record[0] = byte(recordTypeApplicationData)
	record[1], record[2] = 3, 3
	record[3], record[4] = byte(length>>8), byte(length)

	_, err := c.conn.Write(record)
	return err
}
// usesEndOfEarlyData reports whether EndOfEarlyData is exchanged on this
// connection. That is the case for plain TLS (neither DTLS nor the mock QUIC
// transport), and for a client with Bugs.SendEndOfEarlyDataInQUICAndDTLS set.
func (c *Conn) usesEndOfEarlyData() bool {
	forcedByBug := c.isClient && c.config.Bugs.SendEndOfEarlyDataInQUICAndDTLS
	return forcedByBug || (c.config.Bugs.MockQUICTransport == nil && !c.isDTLS)
}