Add additional features to bssl client.

This exposes the features needed to mimic Chrome's ClientHello, which is useful
in testing. Also use bssl_shim's scopers for SSL objects.

Change-Id: Icb88bb00c0a05c27610134d618f466a24f7f757a
Reviewed-on: https://boringssl-review.googlesource.com/4113
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/tool/args.cc b/tool/args.cc
index 52856d4..a164476 100644
--- a/tool/args.cc
+++ b/tool/args.cc
@@ -41,22 +41,26 @@
       return false;
     }
 
-    if (i + 1 >= args.size()) {
-      fprintf(stderr, "Missing argument for option: %s\n", arg.c_str());
-      return false;
-    }
-
     if (out_args->find(arg) != out_args->end()) {
-      fprintf(stderr, "Duplicate value given for: %s\n", arg.c_str());
+      fprintf(stderr, "Duplicate argument: %s\n", arg.c_str());
       return false;
     }
 
-    (*out_args)[arg] = args[++i];
+    if (templ->type == kBooleanArgument) {
+      (*out_args)[arg] = "";
+    } else {
+      if (i + 1 >= args.size()) {
+        fprintf(stderr, "Missing argument for option: %s\n", arg.c_str());
+        return false;
+      }
+      (*out_args)[arg] = args[++i];
+    }
   }
 
   for (size_t j = 0; templates[j].name[0] != 0; j++) {
     const struct argument *templ = &templates[j];
-    if (templ->required && out_args->find(templ->name) == out_args->end()) {
+    if (templ->type == kRequiredArgument &&
+        out_args->find(templ->name) == out_args->end()) {
       fprintf(stderr, "Missing value for required argument: %s\n", templ->name);
       return false;
     }
diff --git a/tool/client.cc b/tool/client.cc
index 59c5fe3..15592c4 100644
--- a/tool/client.cc
+++ b/tool/client.cc
@@ -22,26 +22,99 @@
 #include <sys/types.h>
 
 #include <openssl/err.h>
+#include <openssl/pem.h>
 #include <openssl/ssl.h>
 
+#include "../ssl/test/scoped_types.h"
 #include "internal.h"
 #include "transport_common.h"
 
 
 static const struct argument kArguments[] = {
     {
-     "-connect", true,
+     "-connect", kRequiredArgument,
      "The hostname and port of the server to connect to, e.g. foo.com:443",
     },
     {
-     "-cipher", false,
+     "-cipher", kOptionalArgument,
      "An OpenSSL-style cipher suite string that configures the offered ciphers",
     },
     {
-     "", false, "",
+     "-max-version", kOptionalArgument,
+     "The maximum acceptable protocol version",
+    },
+    {
+     "-min-version", kOptionalArgument,
+     "The minimum acceptable protocol version",
+    },
+    {
+     "-server-name", kOptionalArgument,
+     "The server name to advertise",
+    },
+    {
+     "-select-next-proto", kOptionalArgument,
+     "An NPN protocol to select if the server supports NPN",
+    },
+    {
+     "-alpn-protos", kOptionalArgument,
+     "A comma-separated list of ALPN protocols to advertise",
+    },
+    {
+     "-fallback-scsv", kBooleanArgument,
+     "Enable FALLBACK_SCSV",
+    },
+    {
+     "-ocsp-stapling", kBooleanArgument,
+     "Advertise support for OCSP stabling",
+    },
+    {
+     "-signed-certificate-timestamps", kBooleanArgument,
+     "Advertise support for signed certificate timestamps",
+    },
+    {
+     "-channel-id-key", kOptionalArgument,
+     "The key to use for signing a channel ID",
+    },
+    {
+     "", kOptionalArgument, "",
     },
 };
 
+static ScopedEVP_PKEY LoadPrivateKey(const std::string &file) {
+  ScopedBIO bio(BIO_new(BIO_s_file()));
+  if (!bio || !BIO_read_filename(bio.get(), file.c_str())) {
+    return nullptr;
+  }
+  ScopedEVP_PKEY pkey(PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr,
+                                              nullptr));
+  return pkey;
+}
+
+static bool VersionFromString(uint16_t *out_version,
+                              const std::string& version) {
+  if (version == "ssl3") {
+    *out_version = SSL3_VERSION;
+    return true;
+  } else if (version == "tls1" || version == "tls1.0") {
+    *out_version = TLS1_VERSION;
+    return true;
+  } else if (version == "tls1.1") {
+    *out_version = TLS1_1_VERSION;
+    return true;
+  } else if (version == "tls1.2") {
+    *out_version = TLS1_2_VERSION;
+    return true;
+  }
+  return false;
+}
+
+static int NextProtoSelectCallback(SSL* ssl, uint8_t** out, uint8_t* outlen,
+                                   const uint8_t* in, unsigned inlen, void* arg) {
+  *out = reinterpret_cast<uint8_t *>(arg);
+  *outlen = strlen(reinterpret_cast<const char *>(arg));
+  return SSL_TLSEXT_ERR_OK;
+}
+
 bool Client(const std::vector<std::string> &args) {
   if (!InitSocketLibrary()) {
     return false;
@@ -54,7 +127,7 @@
     return false;
   }
 
-  SSL_CTX *ctx = SSL_CTX_new(SSLv23_client_method());
+  ScopedSSL_CTX ctx(SSL_CTX_new(SSLv23_client_method()));
 
   const char *keylog_file = getenv("SSLKEYLOGFILE");
   if (keylog_file) {
@@ -63,38 +136,117 @@
       ERR_print_errors_cb(PrintErrorCallback, stderr);
       return false;
     }
-    SSL_CTX_set_keylog_bio(ctx, keylog_bio);
+    SSL_CTX_set_keylog_bio(ctx.get(), keylog_bio);
   }
 
   if (args_map.count("-cipher") != 0 &&
-      !SSL_CTX_set_cipher_list(ctx, args_map["-cipher"].c_str())) {
+      !SSL_CTX_set_cipher_list(ctx.get(), args_map["-cipher"].c_str())) {
     fprintf(stderr, "Failed setting cipher list\n");
     return false;
   }
 
+  if (args_map.count("-max-version") != 0) {
+    uint16_t version;
+    if (!VersionFromString(&version, args_map["-max-version"])) {
+      fprintf(stderr, "Unknown protocol version: '%s'\n",
+              args_map["-max-version"].c_str());
+      return false;
+    }
+    SSL_CTX_set_max_version(ctx.get(), version);
+  }
+
+  if (args_map.count("-min-version") != 0) {
+    uint16_t version;
+    if (!VersionFromString(&version, args_map["-min-version"])) {
+      fprintf(stderr, "Unknown protocol version: '%s'\n",
+              args_map["-min-version"].c_str());
+      return false;
+    }
+    SSL_CTX_set_min_version(ctx.get(), version);
+  }
+
+  if (args_map.count("-select-next-proto") != 0) {
+    const std::string &proto = args_map["-select-next-proto"];
+    if (proto.size() > 255) {
+      fprintf(stderr, "Bad NPN protocol: '%s'\n", proto.c_str());
+      return false;
+    }
+    // |SSL_CTX_set_next_proto_select_cb| is not const-correct.
+    SSL_CTX_set_next_proto_select_cb(ctx.get(), NextProtoSelectCallback,
+                                     const_cast<char *>(proto.c_str()));
+  }
+
+  if (args_map.count("-alpn-protos") != 0) {
+    const std::string &alpn_protos = args_map["-alpn-protos"];
+    std::vector<uint8_t> wire;
+    size_t i = 0;
+    while (i <= alpn_protos.size()) {
+      size_t j = alpn_protos.find(',', i);
+      if (j == std::string::npos) {
+        j = alpn_protos.size();
+      }
+      size_t len = j - i;
+      if (len > 255) {
+        fprintf(stderr, "Invalid ALPN protocols: '%s'\n", alpn_protos.c_str());
+        return false;
+      }
+      wire.push_back(static_cast<uint8_t>(len));
+      wire.resize(wire.size() + len);
+      memcpy(wire.data() + wire.size() - len, alpn_protos.data() + i, len);
+      i = j + 1;
+    }
+    if (SSL_CTX_set_alpn_protos(ctx.get(), wire.data(), wire.size()) != 0) {
+      return false;
+    }
+  }
+
+  if (args_map.count("-fallback-scsv") != 0) {
+    SSL_CTX_set_mode(ctx.get(), SSL_MODE_SEND_FALLBACK_SCSV);
+  }
+
+  if (args_map.count("-ocsp-stapling") != 0) {
+    SSL_CTX_enable_ocsp_stapling(ctx.get());
+  }
+
+  if (args_map.count("-signed-certificate-timestamps") != 0) {
+    SSL_CTX_enable_signed_cert_timestamps(ctx.get());
+  }
+
+  if (args_map.count("-channel-id-key") != 0) {
+    ScopedEVP_PKEY pkey = LoadPrivateKey(args_map["-channel-id-key"]);
+    if (!pkey || !SSL_CTX_set1_tls_channel_id(ctx.get(), pkey.get())) {
+      return false;
+    }
+    ctx->tlsext_channel_id_enabled_new = 1;
+  }
+
   int sock = -1;
   if (!Connect(&sock, args_map["-connect"])) {
     return false;
   }
 
-  BIO *bio = BIO_new_socket(sock, BIO_CLOSE);
-  SSL *ssl = SSL_new(ctx);
-  SSL_set_bio(ssl, bio, bio);
+  ScopedBIO bio(BIO_new_socket(sock, BIO_CLOSE));
+  ScopedSSL ssl(SSL_new(ctx.get()));
 
-  int ret = SSL_connect(ssl);
+  if (args_map.count("-server-name") != 0) {
+    SSL_set_tlsext_host_name(ssl.get(), args_map["-server-name"].c_str());
+  }
+
+  SSL_set_bio(ssl.get(), bio.get(), bio.get());
+  bio.release();
+
+  int ret = SSL_connect(ssl.get());
   if (ret != 1) {
-    int ssl_err = SSL_get_error(ssl, ret);
+    int ssl_err = SSL_get_error(ssl.get(), ret);
     fprintf(stderr, "Error while connecting: %d\n", ssl_err);
     ERR_print_errors_cb(PrintErrorCallback, stderr);
     return false;
   }
 
   fprintf(stderr, "Connected.\n");
-  PrintConnectionInfo(ssl);
+  PrintConnectionInfo(ssl.get());
 
-  bool ok = TransferData(ssl, sock);
+  bool ok = TransferData(ssl.get(), sock);
 
-  SSL_free(ssl);
-  SSL_CTX_free(ctx);
   return ok;
 }
diff --git a/tool/internal.h b/tool/internal.h
index bc87c51..277d099 100644
--- a/tool/internal.h
+++ b/tool/internal.h
@@ -32,9 +32,15 @@
 #pragma warning(pop)
 #endif
 
+enum ArgumentType {
+  kRequiredArgument,
+  kOptionalArgument,
+  kBooleanArgument,
+};
+
 struct argument {
-  const char name[15];
-  bool required;
+  const char *name;
+  ArgumentType type;
   const char *description;
 };
 
diff --git a/tool/pkcs12.cc b/tool/pkcs12.cc
index e0133e5..7ce2bd0 100644
--- a/tool/pkcs12.cc
+++ b/tool/pkcs12.cc
@@ -46,10 +46,11 @@
 
 static const struct argument kArguments[] = {
     {
-     "-dump", false, "Dump the key and contents of the given file to stdout",
+     "-dump", kOptionalArgument,
+     "Dump the key and contents of the given file to stdout",
     },
     {
-     "", false, "",
+     "", kOptionalArgument, "",
     },
 };
 
diff --git a/tool/server.cc b/tool/server.cc
index 120e450..2890b09 100644
--- a/tool/server.cc
+++ b/tool/server.cc
@@ -30,19 +30,19 @@
 
 static const struct argument kArguments[] = {
     {
-     "-accept", true,
+     "-accept", kRequiredArgument,
      "The port of the server to bind on; eg 45102",
     },
     {
-     "-cipher", false,
+     "-cipher", kOptionalArgument,
      "An OpenSSL-style cipher suite string that configures the offered ciphers",
     },
     {
-      "-key", false,
+      "-key", kOptionalArgument,
       "Private-key file to use (default is server.pem)",
     },
     {
-     "", false, "",
+     "", kOptionalArgument, "",
     },
 };
 
diff --git a/tool/transport_common.cc b/tool/transport_common.cc
index c05742e..ddb41b6 100644
--- a/tool/transport_common.cc
+++ b/tool/transport_common.cc
@@ -172,6 +172,17 @@
   fprintf(stderr, "  Cipher: %s\n", SSL_CIPHER_get_name(cipher));
   fprintf(stderr, "  Secure renegotiation: %s\n",
           SSL_get_secure_renegotiation_support(ssl) ? "yes" : "no");
+
+  const uint8_t *next_proto;
+  unsigned next_proto_len;
+  SSL_get0_next_proto_negotiated(ssl, &next_proto, &next_proto_len);
+  fprintf(stderr, "  Next protocol negotiated: %.*s\n", next_proto_len,
+          next_proto);
+
+  const uint8_t *alpn;
+  unsigned alpn_len;
+  SSL_get0_alpn_selected(ssl, &alpn, &alpn_len);
+  fprintf(stderr, "  ALPN protocol: %.*s\n", alpn_len, alpn);
 }
 
 bool SocketSetNonBlocking(int sock, bool is_non_blocking) {