Add basic TLS Channel ID tests.

Change-Id: I7ccf2b8282dfa8f3985775e8b67edcf3c2949752
Reviewed-on: https://boringssl-review.googlesource.com/1606
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index c976e7c..8c1e75e 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -48,6 +48,20 @@
   return (const TestConfig *)SSL_get_ex_data(ssl, g_ex_data_index);
 }
 
+static EVP_PKEY *LoadPrivateKey(const std::string &file) {
+  BIO *bio = BIO_new(BIO_s_file());
+  if (bio == NULL) {
+    return NULL;
+  }
+  if (!BIO_read_filename(bio, file.c_str())) {
+    BIO_free(bio);
+    return NULL;
+  }
+  EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL);
+  BIO_free(bio);
+  return pkey;
+}
+
 static int early_callback_called = 0;
 
 static int select_certificate_callback(const struct ssl_early_callback_ctx *ctx) {
@@ -205,6 +219,8 @@
   SSL_CTX_set_cookie_generate_cb(ssl_ctx, cookie_generate_callback);
   SSL_CTX_set_cookie_verify_cb(ssl_ctx, cookie_verify_callback);
 
+  ssl_ctx->tlsext_channel_id_enabled_new = 1;
+
   DH_free(dh);
   return ssl_ctx;
 
@@ -300,6 +316,23 @@
   if (config->cookie_exchange) {
     SSL_set_options(ssl, SSL_OP_COOKIE_EXCHANGE);
   }
+  if (!config->expected_channel_id.empty()) {
+    SSL_enable_tls_channel_id(ssl);
+  }
+  if (!config->send_channel_id.empty()) {
+    EVP_PKEY *pkey = LoadPrivateKey(config->send_channel_id);
+    if (pkey == NULL) {
+      BIO_print_errors_fp(stdout);
+      return 1;
+    }
+    SSL_enable_tls_channel_id(ssl);
+    if (!SSL_set1_tls_channel_id(ssl, pkey)) {
+      EVP_PKEY_free(pkey);
+      BIO_print_errors_fp(stdout);
+      return 1;
+    }
+    EVP_PKEY_free(pkey);
+  }
 
   BIO *bio = BIO_new_fd(fd, 1 /* take ownership */);
   if (bio == NULL) {
@@ -363,7 +396,7 @@
   if (!config->expected_certificate_types.empty()) {
     uint8_t *certificate_types;
     int num_certificate_types =
-      SSL_get0_certificate_types(ssl, &certificate_types);
+        SSL_get0_certificate_types(ssl, &certificate_types);
     if (num_certificate_types !=
         (int)config->expected_certificate_types.size() ||
         memcmp(certificate_types,
@@ -386,6 +419,20 @@
     }
   }
 
+  if (!config->expected_channel_id.empty()) {
+    uint8_t channel_id[64];
+    if (!SSL_get_tls_channel_id(ssl, channel_id, sizeof(channel_id))) {
+      fprintf(stderr, "no channel id negotiated\n");
+      return 2;
+    }
+    if (config->expected_channel_id.size() != 64 ||
+        memcmp(config->expected_channel_id.data(),
+               channel_id, 64) != 0) {
+      fprintf(stderr, "channel id mismatch\n");
+      return 2;
+    }
+  }
+
   if (config->write_different_record_sizes) {
     if (config->is_dtls) {
       fprintf(stderr, "write_different_record_sizes not supported for DTLS\n");
diff --git a/ssl/test/runner/channel_id_key.pem b/ssl/test/runner/channel_id_key.pem
new file mode 100644
index 0000000..604752b
--- /dev/null
+++ b/ssl/test/runner/channel_id_key.pem
@@ -0,0 +1,5 @@
+-----BEGIN EC PRIVATE KEY-----
+MHcCAQEEIPwxu50c7LEhVNRYJFRWBUnoaz7JSos96T5hBp4rjyptoAoGCCqGSM49
+AwEHoUQDQgAEzFSVTE5guxJRQ0VbZ8dicPs5e/DT7xpW7Yc9hq0VOchv7cbXuI/T
+CwadDjGWX/oaz0ftFqrVmfkwZu+C58ioWg==
+-----END EC PRIVATE KEY-----
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index 360de12..72e8cce 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -2,8 +2,11 @@
 
 import (
 	"bytes"
+	"crypto/ecdsa"
+	"crypto/elliptic"
 	"crypto/x509"
 	"encoding/base64"
+	"encoding/pem"
 	"flag"
 	"fmt"
 	"io"
@@ -26,11 +29,14 @@
 )
 
 const (
-	rsaKeyFile   = "key.pem"
-	ecdsaKeyFile = "ecdsa_key.pem"
+	rsaKeyFile       = "key.pem"
+	ecdsaKeyFile     = "ecdsa_key.pem"
+	channelIDKeyFile = "channel_id_key.pem"
 )
 
 var rsaCertificate, ecdsaCertificate Certificate
+var channelIDKey *ecdsa.PrivateKey
+var channelIDBytes []byte
 
 func initCertificates() {
 	var err error
@@ -43,6 +49,26 @@
 	if err != nil {
 		panic(err)
 	}
+
+	channelIDPEMBlock, err := ioutil.ReadFile(channelIDKeyFile)
+	if err != nil {
+		panic(err)
+	}
+	channelIDDERBlock, _ := pem.Decode(channelIDPEMBlock)
+	if channelIDDERBlock.Type != "EC PRIVATE KEY" {
+		panic("bad key type")
+	}
+	channelIDKey, err = x509.ParseECPrivateKey(channelIDDERBlock.Bytes)
+	if err != nil {
+		panic(err)
+	}
+	if channelIDKey.Curve != elliptic.P256() {
+		panic("bad curve")
+	}
+
+	channelIDBytes = make([]byte, 64)
+	writeIntPadded(channelIDBytes[:32], channelIDKey.X)
+	writeIntPadded(channelIDBytes[32:], channelIDKey.Y)
 }
 
 var certificateOnce sync.Once
@@ -84,6 +110,9 @@
 	// expectedVersion, if non-zero, specifies the TLS version that must be
 	// negotiated.
 	expectedVersion uint16
+	// expectChannelID controls whether the connection should have
+	// negotiated a Channel ID with channelIDKey.
+	expectChannelID bool
 	// messageLen is the length, in bytes, of the test message that will be
 	// sent.
 	messageLen int
@@ -488,6 +517,18 @@
 		return fmt.Errorf("got version %x, expected %x", vers, test.expectedVersion)
 	}
 
+	if test.expectChannelID {
+		channelID := tlsConn.ConnectionState().ChannelID
+		if channelID == nil {
+			return fmt.Errorf("no channel ID negotiated")
+		}
+		if channelID.Curve != channelIDKey.Curve ||
+			channelIDKey.X.Cmp(channelIDKey.X) != 0 ||
+			channelIDKey.Y.Cmp(channelIDKey.Y) != 0 {
+			return fmt.Errorf("incorrect channel ID")
+		}
+	}
+
 	if messageLen < 0 {
 		if test.protocol == dtls {
 			return fmt.Errorf("messageLen < 0 not supported for DTLS tests")
@@ -1141,7 +1182,7 @@
 			),
 		})
 
-		// Client sends a V2ClientHello.
+		// Server parses a V2ClientHello.
 		testCases = append(testCases, testCase{
 			protocol: protocol,
 			testType: serverTest,
@@ -1158,6 +1199,42 @@
 			},
 			flags: flags,
 		})
+
+		// Client sends a Channel ID.
+		testCases = append(testCases, testCase{
+			protocol: protocol,
+			name:     "ChannelID-Client" + suffix,
+			config: Config{
+				RequestChannelID: true,
+				Bugs: ProtocolBugs{
+					MaxHandshakeRecordLength: maxHandshakeRecordLength,
+				},
+			},
+			flags: append(flags,
+				"-send-channel-id", channelIDKeyFile,
+			),
+			resumeSession:   true,
+			expectChannelID: true,
+		})
+
+		// Server accepts a Channel ID.
+		testCases = append(testCases, testCase{
+			protocol: protocol,
+			testType: serverTest,
+			name:     "ChannelID-Server" + suffix,
+			config: Config{
+				ChannelID: channelIDKey,
+				Bugs: ProtocolBugs{
+					MaxHandshakeRecordLength: maxHandshakeRecordLength,
+				},
+			},
+			flags: append(flags,
+				"-expect-channel-id",
+				base64.StdEncoding.EncodeToString(channelIDBytes),
+			),
+			resumeSession:   true,
+			expectChannelID: true,
+		})
 	} else {
 		testCases = append(testCases, testCase{
 			protocol: protocol,
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index e69d570..cf0d0af 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -64,12 +64,14 @@
   { "-advertise-npn", &TestConfig::advertise_npn },
   { "-expect-next-proto", &TestConfig::expected_next_proto },
   { "-select-next-proto", &TestConfig::select_next_proto },
+  { "-send-channel-id", &TestConfig::send_channel_id },
 };
 
 const size_t kNumStringFlags = sizeof(kStringFlags) / sizeof(kStringFlags[0]);
 
 const StringFlag kBase64Flags[] = {
   { "-expect-certificate-types", &TestConfig::expected_certificate_types },
+  { "-expect-channel-id", &TestConfig::expected_channel_id },
 };
 
 const size_t kNumBase64Flags = sizeof(kBase64Flags) / sizeof(kBase64Flags[0]);
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index 34d720e..dc5e8d3 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -43,6 +43,8 @@
   bool no_tls1;
   bool no_ssl3;
   bool cookie_exchange;
+  std::string expected_channel_id;
+  std::string send_channel_id;
 };
 
 bool ParseConfig(int argc, char **argv, TestConfig *out_config);