Test server-side renegotiation.

This change adds client-side renegotiation support to the Go test runner,
which lets us test BoringSSL's renegotiation as a server.

Change-Id: Iaa9fb1a6022c51023bce36c47d4ef7abee74344b
Reviewed-on: https://boringssl-review.googlesource.com/2082
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index 1cf81a7..ce2a3da 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -542,6 +542,38 @@
     }
   }
 
+  if (config->renegotiate) {
+    if (config->async) {
+      fprintf(stderr, "--renegotiate is not supported with --async.\n");
+      return 2;
+    }
+
+    // Request a renegotiation. The following SSL_do_handshake is expected to
+    // send the HelloRequest that the runner's Go client waits for.
+    if (!SSL_renegotiate(ssl)) {
+      SSL_free(ssl);
+      BIO_print_errors_fp(stdout);
+      return 2;
+    }
+
+    ret = SSL_do_handshake(ssl);
+    if (ret != 1) {
+      SSL_free(ssl);
+      BIO_print_errors_fp(stdout);
+      return 2;
+    }
+
+    // NOTE(review): resetting the state to SSL_ST_ACCEPT appears to be what
+    // makes the next call process the client's new ClientHello; confirm.
+    SSL_set_state(ssl, SSL_ST_ACCEPT);
+    ret = SSL_do_handshake(ssl);
+    if (ret != 1) {
+      SSL_free(ssl);
+      BIO_print_errors_fp(stdout);
+      return 2;
+    }
+  }
+
   if (config->write_different_record_sizes) {
     if (config->is_dtls) {
       fprintf(stderr, "write_different_record_sizes not supported for DTLS\n");
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index 4aa21bb..6f146af 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -48,6 +48,7 @@
 
 // TLS handshake message types.
 const (
+	typeHelloRequest        uint8 = 0
 	typeClientHello         uint8 = 1
 	typeServerHello         uint8 = 2
 	typeHelloVerifyRequest  uint8 = 3
@@ -490,6 +491,14 @@
 	// NoExtendedMasterSecret causes the client and server to behave is if
 	// they didn't support an extended master secret.
 	NoExtendedMasterSecret bool
+
+	// EmptyRenegotiationInfo causes the client to send an empty
+	// renegotiation_info extension when renegotiating.
+	EmptyRenegotiationInfo bool
+
+	// BadRenegotiationInfo causes the client, when renegotiating, to send
+	// its verify_data from the previous handshake with the high bit of the
+	// first byte flipped.
+	BadRenegotiationInfo bool
 }
 
 func (c *Config) serverInit() {
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index 3ce6c76..897175e 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -50,6 +50,10 @@
 	clientProtocolFallback bool
 	usedALPN               bool
 
+	// Most recent Finished verify_data, for renegotiation_info (RFC 5746).
+	clientVerify []byte
+	serverVerify []byte
+
 	channelID *ecdsa.PublicKey
 
 	// input/output
@@ -129,9 +133,10 @@
 }
 
 func (hc *halfConn) error() error {
-	hc.Lock()
+	// NOTE: hc.err is read without taking hc's lock. Locking was removed
+	// for the renegotiation tests, which never read and write the same
+	// tls.Conn concurrently; concurrent use of a Conn would race here.
 	err := hc.err
-	hc.Unlock()
 	return err
 }
 
@@ -651,7 +656,7 @@
 func (c *Conn) readRecord(want recordType) error {
 	// Caller must be in sync with connection:
 	// handshake data if handshake not yet completed,
-	// else application data.  (We don't support renegotiation.)
+	// else application data.
 	switch want {
 	default:
 		c.sendAlert(alertInternalError)
@@ -725,7 +730,12 @@
 	case recordTypeHandshake:
 		// TODO(rsc): Should at least pick off connection close.
 		if typ != want {
-			return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
+			// A client might need to process a HelloRequest from
+			// the server, thus receiving a handshake message when
+			// application data is expected is ok.
+			if !c.isClient {
+				return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
+			}
 		}
 		c.hand.Write(data)
 	}
@@ -908,6 +918,8 @@
 
 	var m handshakeMessage
 	switch data[0] {
+	case typeHelloRequest:
+		m = new(helloRequestMsg)
 	case typeClientHello:
 		m = &clientHelloMsg{
 			isDTLS: c.isDTLS,
@@ -1000,6 +1012,25 @@
 	return n + m, c.out.setErrorLocked(err)
 }
 
+func (c *Conn) handleRenegotiation() error {
+	c.handshakeComplete = false // presumably lets c.Handshake() below run again
+	if !c.isClient {
+		panic("renegotiation should only happen for a client")
+	}
+
+	msg, err := c.readHandshake()
+	if err != nil {
+		return err
+	}
+	_, ok := msg.(*helloRequestMsg)
+	if !ok {
+		c.sendAlert(alertUnexpectedMessage)
+		return alertUnexpectedMessage
+	}
+
+	return c.Handshake()
+}
+
 // Read can be made to time out and return a net.Error with Timeout() == true
 // after a fixed time limit; see SetDeadline and SetReadDeadline.
 func (c *Conn) Read(b []byte) (n int, err error) {
@@ -1019,6 +1050,14 @@
 				// Soft error, like EAGAIN
 				return 0, err
 			}
+			if c.hand.Len() > 0 {
+				// We received handshake bytes, indicating the
+				// start of a renegotiation.
+				if err := c.handleRenegotiation(); err != nil {
+					return 0, err
+				}
+				continue
+			}
 		}
 		if err := c.in.err; err != nil {
 			return 0, err
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index 11a1ed3..0c5192f 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -65,7 +65,7 @@
 		supportedCurves:      c.config.curvePreferences(),
 		supportedPoints:      []uint8{pointFormatUncompressed},
 		nextProtoNeg:         len(c.config.NextProtos) > 0,
-		secureRenegotiation:  true,
+		secureRenegotiation:  []byte{},
 		alpnProtocols:        c.config.NextProtos,
 		duplicateExtension:   c.config.Bugs.DuplicateExtension,
 		channelIDSupported:   c.config.ChannelID != nil,
@@ -81,6 +81,15 @@
 		hello.extendedMasterSecret = false
 	}
 
+	if len(c.clientVerify) > 0 && !c.config.Bugs.EmptyRenegotiationInfo {
+		// Copy rather than alias c.clientVerify: it is overwritten in
+		// place, via append(c.clientVerify[:0], ...), later in the handshake.
+		hello.secureRenegotiation = append(hello.secureRenegotiation, c.clientVerify...)
+		if c.config.Bugs.BadRenegotiationInfo {
+			hello.secureRenegotiation[0] ^= 0x80
+		}
+	}
+
 	possibleCipherSuites := c.config.cipherSuites()
 	hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites))
 
@@ -240,6 +249,18 @@
 		return fmt.Errorf("tls: server selected an unsupported cipher suite")
 	}
 
+	// RFC 5746, sections 3.4 and 3.5: the ServerHello's renegotiation_info
+	// must equal client_verify_data || server_verify_data from the previous
+	// handshake. Both are empty on the initial handshake, so there the
+	// extension must be empty or absent (bytes.Equal treats nil as empty).
+	var expectedRenegInfo []byte
+	expectedRenegInfo = append(expectedRenegInfo, c.clientVerify...)
+	expectedRenegInfo = append(expectedRenegInfo, c.serverVerify...)
+	if !bytes.Equal(serverHello.secureRenegotiation, expectedRenegInfo) {
+		c.sendAlert(alertHandshakeFailure)
+		return fmt.Errorf("tls: renegotiation mismatch")
+	}
+
 	hs := &clientHandshakeState{
 		c:            c,
 		serverHello:  serverHello,
@@ -680,6 +701,7 @@
 			return errors.New("tls: server's Finished message was incorrect")
 		}
 	}
+	c.serverVerify = append(c.serverVerify[:0], serverFinished.verifyData...)
 	hs.writeServerHash(serverFinished.marshal())
 	return nil
 }
@@ -766,6 +788,7 @@
 	} else {
 		finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret)
 	}
+	c.clientVerify = append(c.clientVerify[:0], finished.verifyData...)
 	finishedBytes := finished.marshal()
 	hs.writeHash(finishedBytes, seqno)
 	postCCSBytes = append(postCCSBytes, finishedBytes...)
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go
index 1114a6f..12a9f3d 100644
--- a/ssl/test/runner/handshake_messages.go
+++ b/ssl/test/runner/handshake_messages.go
@@ -23,7 +23,7 @@
 	ticketSupported      bool
 	sessionTicket        []uint8
 	signatureAndHashes   []signatureAndHash
-	secureRenegotiation  bool
+	secureRenegotiation  []byte // nil: not sent; non-nil (even empty): sent
 	alpnProtocols        []string
 	duplicateExtension   bool
 	channelIDSupported   bool
@@ -53,7 +53,8 @@
 		m.ticketSupported == m1.ticketSupported &&
 		bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
 		eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
-		m.secureRenegotiation == m1.secureRenegotiation &&
+		bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
+		(m.secureRenegotiation == nil) == (m1.secureRenegotiation == nil) && // bytes.Equal treats nil as empty
 		eqStrings(m.alpnProtocols, m1.alpnProtocols) &&
 		m.duplicateExtension == m1.duplicateExtension &&
 		m.channelIDSupported == m1.channelIDSupported &&
@@ -99,8 +100,8 @@
 		extensionsLength += 2 + 2*len(m.signatureAndHashes)
 		numExtensions++
 	}
-	if m.secureRenegotiation {
-		extensionsLength += 1
+	if m.secureRenegotiation != nil {
+		extensionsLength += 1 + len(m.secureRenegotiation)
 		numExtensions++
 	}
 	if m.duplicateExtension {
@@ -279,12 +280,15 @@
 			z = z[2:]
 		}
 	}
-	if m.secureRenegotiation {
+	if m.secureRenegotiation != nil {
 		z[0] = byte(extensionRenegotiationInfo >> 8)
 		z[1] = byte(extensionRenegotiationInfo & 0xff)
 		z[2] = 0
-		z[3] = 1
+		z[3] = byte(1 + len(m.secureRenegotiation)) // NOTE(review): wraps if len > 254
+		z[4] = byte(len(m.secureRenegotiation))
 		z = z[5:]
+		copy(z, m.secureRenegotiation)
+		z = z[len(m.secureRenegotiation):]
 	}
 	if len(m.alpnProtocols) > 0 {
 		z[0] = byte(extensionALPN >> 8)
@@ -374,7 +378,7 @@
 	for i := 0; i < numCipherSuites; i++ {
 		m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
 		if m.cipherSuites[i] == scsvRenegotiation {
-			m.secureRenegotiation = true
+			m.secureRenegotiation = []byte{} // SCSV counts as an empty renegotiation_info
 		}
 	}
 	data = data[2+cipherSuiteLen:]
@@ -501,11 +505,11 @@
 				m.signatureAndHashes[i].signature = d[1]
 				d = d[2:]
 			}
-		case extensionRenegotiationInfo + 1:
-			if length != 1 || data[0] != 0 {
+		case extensionRenegotiationInfo:
+			if length < 1 || length != int(data[0])+1 { // 1-byte length, then that many bytes
 				return false
 			}
-			m.secureRenegotiation = true
+			m.secureRenegotiation = data[1:length]
 		case extensionALPN:
 			if length < 2 {
 				return false
@@ -553,7 +557,7 @@
 	nextProtos           []string
 	ocspStapling         bool
 	ticketSupported      bool
-	secureRenegotiation  bool
+	secureRenegotiation  []byte // nil: not sent; non-nil (even empty): sent
 	alpnProtocol         string
 	duplicateExtension   bool
 	channelIDRequested   bool
@@ -577,7 +581,8 @@
 		eqStrings(m.nextProtos, m1.nextProtos) &&
 		m.ocspStapling == m1.ocspStapling &&
 		m.ticketSupported == m1.ticketSupported &&
-		m.secureRenegotiation == m1.secureRenegotiation &&
+		bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
+		(m.secureRenegotiation == nil) == (m1.secureRenegotiation == nil) && // bytes.Equal treats nil as empty
 		m.alpnProtocol == m1.alpnProtocol &&
 		m.duplicateExtension == m1.duplicateExtension &&
 		m.channelIDRequested == m1.channelIDRequested &&
@@ -608,8 +613,8 @@
 	if m.ticketSupported {
 		numExtensions++
 	}
-	if m.secureRenegotiation {
-		extensionsLength += 1
+	if m.secureRenegotiation != nil {
+		extensionsLength += 1 + len(m.secureRenegotiation)
 		numExtensions++
 	}
 	if m.duplicateExtension {
@@ -689,12 +694,15 @@
 		z[1] = byte(extensionSessionTicket)
 		z = z[4:]
 	}
-	if m.secureRenegotiation {
+	if m.secureRenegotiation != nil {
 		z[0] = byte(extensionRenegotiationInfo >> 8)
 		z[1] = byte(extensionRenegotiationInfo & 0xff)
 		z[2] = 0
-		z[3] = 1
+		z[3] = byte(1 + len(m.secureRenegotiation)) // NOTE(review): wraps if len > 254
+		z[4] = byte(len(m.secureRenegotiation))
 		z = z[5:]
+		copy(z, m.secureRenegotiation)
+		z = z[len(m.secureRenegotiation):]
 	}
 	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
 		z[0] = byte(extensionALPN >> 8)
@@ -808,10 +816,10 @@
 			}
 			m.ticketSupported = true
 		case extensionRenegotiationInfo:
-			if length != 1 || data[0] != 0 {
+			if length < 1 || length != int(data[0])+1 { // 1-byte length, then that many bytes
 				return false
 			}
-			m.secureRenegotiation = true
+			m.secureRenegotiation = data[1:length]
 		case extensionALPN:
 			d := data[:length]
 			if len(d) < 3 {
@@ -1667,6 +1675,17 @@
 	return true
 }
 
+type helloRequestMsg struct {
+}
+
+func (*helloRequestMsg) marshal() []byte {
+	return []byte{typeHelloRequest, 0, 0, 0} // type byte, then an all-zero body length
+}
+
+func (*helloRequestMsg) unmarshal(data []byte) bool {
+	return len(data) == 4 // header only: a HelloRequest has no body
+}
+
 func eqUint16s(x, y []uint16) bool {
 	if len(x) != len(y) {
 		return false
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 4bf8f1c..3288b0d 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -214,6 +214,14 @@
 		c.sendAlert(alertInternalError)
 		return false, err
 	}
+
+	// RFC 5746, section 3.6: on an initial handshake the client's
+	// renegotiation_info must be empty. Any non-empty value means the
+	// client is renegotiating, which this server does not support.
+	if len(hs.clientHello.secureRenegotiation) > 0 {
+		c.sendAlert(alertHandshakeFailure)
+		return false, errors.New("tls: client is doing a renegotiation handshake")
+	}
 	hs.hello.secureRenegotiation = hs.clientHello.secureRenegotiation
 	hs.hello.compressionMethod = compressionNone
 	hs.hello.duplicateExtension = c.config.Bugs.DuplicateExtension
@@ -693,6 +701,7 @@
 		c.sendAlert(alertHandshakeFailure)
 		return errors.New("tls: client's Finished message is incorrect")
 	}
+	c.clientVerify = append(c.clientVerify[:0], clientFinished.verifyData...)
 
 	hs.writeClientHash(clientFinished.marshal())
 	return nil
@@ -730,6 +739,7 @@
 
 	finished := new(finishedMsg)
 	finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
+	c.serverVerify = append(c.serverVerify[:0], finished.verifyData...)
 	postCCSBytes := finished.marshal()
 	hs.writeServerHash(postCCSBytes)
 
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index 1b461e2..ef72374 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -1759,6 +1759,42 @@
 	}
 }
 
+func addRenegotiationTests() {
+	testCases = append(testCases, testCase{
+		testType:        serverTest,
+		name:            "Renegotiate-Server",
+		flags:           []string{"-renegotiate"},
+		shimWritesFirst: true,
+	})
+	testCases = append(testCases, testCase{
+		testType: serverTest,
+		name:     "Renegotiate-Server-EmptyExt",
+		config: Config{
+			Bugs: ProtocolBugs{
+				EmptyRenegotiationInfo: true,
+			},
+		},
+		flags:           []string{"-renegotiate"},
+		shimWritesFirst: true,
+		shouldFail:      true,
+		expectedError:   ":RENEGOTIATION_MISMATCH:",
+	})
+	testCases = append(testCases, testCase{
+		testType: serverTest,
+		name:     "Renegotiate-Server-BadExt",
+		config: Config{
+			Bugs: ProtocolBugs{
+				BadRenegotiationInfo: true,
+			},
+		},
+		flags:           []string{"-renegotiate"},
+		shimWritesFirst: true,
+		shouldFail:      true,
+		expectedError:   ":RENEGOTIATION_MISMATCH:",
+	})
+	// TODO(agl): test the renegotiation info SCSV (support signalled in the cipher suite list rather than the extension).
+}
+
 func worker(statusChan chan statusMsg, c chan *testCase, buildDir string, wg *sync.WaitGroup) {
 	defer wg.Done()
 
@@ -1815,6 +1851,7 @@
 	addExtensionTests()
 	addResumptionVersionTests()
 	addExtendedMasterSecretTests()
+	addRenegotiationTests()
 	for _, async := range []bool{false, true} {
 		for _, splitHandshake := range []bool{false, true} {
 			for _, protocol := range []protocol{tls, dtls} {
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index c50d9de..b717bd3 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -59,6 +59,7 @@
   { "-expect-session-miss", &TestConfig::expect_session_miss },
   { "-expect-extended-master-secret",
     &TestConfig::expect_extended_master_secret },
+  { "-renegotiate", &TestConfig::renegotiate },
 };
 
 const size_t kNumBoolFlags = sizeof(kBoolFlags) / sizeof(kBoolFlags[0]);
@@ -110,7 +111,8 @@
       shim_writes_first(false),
       tls_d5_bug(false),
       expect_session_miss(false),
-      expect_extended_master_secret(false) {
+      expect_extended_master_secret(false),
+      renegotiate(false) {
 }
 
 bool ParseConfig(int argc, char **argv, TestConfig *out_config) {
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index e5ff8ad..2dc4dc1 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -56,6 +56,7 @@
   bool expect_extended_master_secret;
   std::string psk;
   std::string psk_identity;
+  bool renegotiate;  // -renegotiate: the shim renegotiates once (not with -async).
 };
 
 bool ParseConfig(int argc, char **argv, TestConfig *out_config);