Implement SSL_CTX_set1_curves_list()

This function is used by NGINX to enable specific curves for ECDH from a
configuration file. However when building with BoringSSL, since it's not
implmeneted, it falls back to using EC_KEY_new_by_curve_name() wich doesn't
support X25519.

Change-Id: I533df4ef302592c1a9f9fc8880bd85f796ce0ef3
Reviewed-on: https://boringssl-review.googlesource.com/11382
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index b0242be..806b937 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -1934,6 +1934,18 @@
 OPENSSL_EXPORT int SSL_set1_curves(SSL *ssl, const int *curves,
                                    size_t curves_len);
 
+/* SSL_CTX_set1_curves_list sets the preferred curves for |ctx| to be the
+ * colon-separated list |curves|. Each element of |curves| should be a curve
+ * name (e.g. P-256, X25519, ...). It returns one on success and zero on
+ * failure. */
+OPENSSL_EXPORT int SSL_CTX_set1_curves_list(SSL_CTX *ctx, const char *curves);
+
+/* SSL_set1_curves_list sets the preferred curves for |ssl| to be the
+ * colon-separated list |curves|. Each element of |curves| should be a curve
+ * name (e.g. P-256, X25519, ...). It returns one on success and zero on
+ * failure. */
+OPENSSL_EXPORT int SSL_set1_curves_list(SSL *ssl, const char *curves);
+
 /* SSL_CURVE_* define TLS curve IDs. */
 #define SSL_CURVE_SECP256R1 23
 #define SSL_CURVE_SECP384R1 24
@@ -4514,6 +4526,7 @@
 #define SSL_CTRL_SESS_NUMBER doesnt_exist
 #define SSL_CTRL_SET_CHANNEL_ID doesnt_exist
 #define SSL_CTRL_SET_CURVES doesnt_exist
+#define SSL_CTRL_SET_CURVES_LIST doesnt_exist
 #define SSL_CTRL_SET_MAX_CERT_LIST doesnt_exist
 #define SSL_CTRL_SET_MAX_SEND_FRAGMENT doesnt_exist
 #define SSL_CTRL_SET_MSG_CALLBACK doesnt_exist
diff --git a/ssl/internal.h b/ssl/internal.h
index 3745592..9a78a46 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -609,6 +609,11 @@
  * zero. */
 int ssl_nid_to_group_id(uint16_t *out_group_id, int nid);
 
+/* ssl_name_to_group_id looks up the group corresponding to the |name| string
+ * of length |len|. On success, it sets |*out_group_id| to the group ID and
+ * returns one. Otherwise, it returns zero. */
+int ssl_name_to_group_id(uint16_t *out_group_id, const char *name, size_t len);
+
 /* SSL_ECDH_CTX_init sets up |ctx| for use with curve |group_id|. It returns one
  * on success and zero on error. */
 int SSL_ECDH_CTX_init(SSL_ECDH_CTX *ctx, uint16_t group_id);
@@ -1472,6 +1477,13 @@
 int tls1_set_curves(uint16_t **out_group_ids, size_t *out_group_ids_len,
                     const int *curves, size_t ncurves);
 
+/* tls1_set_curves_list converts the string of curves pointed to by |curves|
+ * into a newly allocated array of TLS group IDs. On success, the function
+ * returns one and writes the array to |*out_group_ids| and its size to
+ * |*out_group_ids_len|. Otherwise, it returns zero. */
+int tls1_set_curves_list(uint16_t **out_group_ids, size_t *out_group_ids_len,
+                         const char *curves);
+
 /* tls1_check_ec_cert returns one if |x| is an ECC certificate with curve and
  * point format compatible with the client's preferences. Otherwise it returns
  * zero. */
diff --git a/ssl/ssl_ecdh.c b/ssl/ssl_ecdh.c
index 16599e4..bcb3af4 100644
--- a/ssl/ssl_ecdh.c
+++ b/ssl/ssl_ecdh.c
@@ -521,6 +521,16 @@
   return NULL;
 }
 
+static const SSL_ECDH_METHOD *method_from_name(const char *name, size_t len) {
+  for (size_t i = 0; i < OPENSSL_ARRAY_SIZE(kMethods); i++) {
+    if (len == strlen(kMethods[i].name) &&
+        !strncmp(kMethods[i].name, name, len)) {
+      return &kMethods[i];
+    }
+  }
+  return NULL;
+}
+
 const char* SSL_get_curve_name(uint16_t group_id) {
   const SSL_ECDH_METHOD *method = method_from_group_id(group_id);
   if (method == NULL) {
@@ -538,6 +548,15 @@
   return 1;
 }
 
+int ssl_name_to_group_id(uint16_t *out_group_id, const char *name, size_t len) {
+  const SSL_ECDH_METHOD *method = method_from_name(name, len);
+  if (method == NULL) {
+    return 0;
+  }
+  *out_group_id = method->group_id;
+  return 1;
+}
+
 int SSL_ECDH_CTX_init(SSL_ECDH_CTX *ctx, uint16_t group_id) {
   SSL_ECDH_CTX_cleanup(ctx);
 
diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c
index 63f72ca..b580d95 100644
--- a/ssl/ssl_lib.c
+++ b/ssl/ssl_lib.c
@@ -1499,6 +1499,16 @@
                          curves_len);
 }
 
+int SSL_CTX_set1_curves_list(SSL_CTX *ctx, const char *curves) {
+  return tls1_set_curves_list(&ctx->supported_group_list,
+                              &ctx->supported_group_list_len, curves);
+}
+
+int SSL_set1_curves_list(SSL *ssl, const char *curves) {
+  return tls1_set_curves_list(&ssl->supported_group_list,
+                              &ssl->supported_group_list_len, curves);
+}
+
 uint16_t SSL_get_curve_id(const SSL *ssl) {
   /* TODO(davidben): This checks the wrong session if there is a renegotiation in
    * progress. */
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 4ad513e..9455117 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -56,6 +56,13 @@
   std::vector<ExpectedCipher> expected;
 };
 
+struct CurveTest {
+  // The rule string to apply.
+  const char *rule;
+  // The list of expected curves, in order.
+  std::vector<uint16_t> expected;
+};
+
 static const CipherTest kCipherTests[] = {
     // Selecting individual ciphers should work.
     {
@@ -286,6 +293,33 @@
   "CHACHA20",
 };
 
+static const CurveTest kCurveTests[] = {
+  {
+    "P-256",
+    { SSL_CURVE_SECP256R1 },
+  },
+  {
+    "P-256:P-384:P-521:X25519",
+    {
+      SSL_CURVE_SECP256R1,
+      SSL_CURVE_SECP384R1,
+      SSL_CURVE_SECP521R1,
+      SSL_CURVE_X25519,
+    },
+  },
+};
+
+static const char *kBadCurvesLists[] = {
+  "",
+  ":",
+  "::",
+  "P-256::X25519",
+  "RSA:P-256",
+  "P-256:RSA",
+  "X25519:P-256:",
+  ":X25519:P-256",
+};
+
 static void PrintCipherPreferenceList(ssl_cipher_preference_list_st *list) {
   bool in_group = false;
   for (size_t i = 0; i < sk_SSL_CIPHER_num(list->ciphers); i++) {
@@ -408,6 +442,55 @@
   return true;
 }
 
+static bool TestCurveRule(const CurveTest &t) {
+  bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
+  if (!ctx) {
+    return false;
+  }
+
+  if (!SSL_CTX_set1_curves_list(ctx.get(), t.rule)) {
+    fprintf(stderr, "Error testing curves list '%s'\n", t.rule);
+    return false;
+  }
+
+  // Compare the two lists.
+  if (ctx->supported_group_list_len != t.expected.size()) {
+    fprintf(stderr, "Error testing curves list '%s': length\n", t.rule);
+    return false;
+  }
+
+  for (size_t i = 0; i < t.expected.size(); i++) {
+    if (t.expected[i] != ctx->supported_group_list[i]) {
+      fprintf(stderr, "Error testing curves list '%s': mismatch\n", t.rule);
+      return false;
+    }
+  }
+
+  return true;
+}
+
+static bool TestCurveRules() {
+  for (const CurveTest &test : kCurveTests) {
+    if (!TestCurveRule(test)) {
+      return false;
+    }
+  }
+
+  for (const char *rule : kBadCurvesLists) {
+    bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(SSLv23_server_method()));
+    if (!ctx) {
+      return false;
+    }
+    if (SSL_CTX_set1_curves_list(ctx.get(), rule)) {
+      fprintf(stderr, "Curves list '%s' unexpectedly succeeded\n", rule);
+      return false;
+    }
+    ERR_clear_error();
+  }
+
+  return true;
+}
+
 // kOpenSSLSession is a serialized SSL_SESSION generated from openssl
 // s_client -sess_out.
 static const char kOpenSSLSession[] =
@@ -2213,6 +2296,7 @@
   CRYPTO_library_init();
 
   if (!TestCipherRules() ||
+      !TestCurveRules() ||
       !TestSSL_SESSIONEncoding(kOpenSSLSession) ||
       !TestSSL_SESSIONEncoding(kCustomSession) ||
       !TestSSL_SESSIONEncoding(kBoringSSLSession) ||
diff --git a/ssl/t1_lib.c b/ssl/t1_lib.c
index da446e0..f6eaeb7 100644
--- a/ssl/t1_lib.c
+++ b/ssl/t1_lib.c
@@ -402,6 +402,49 @@
   return 1;
 }
 
+int tls1_set_curves_list(uint16_t **out_group_ids, size_t *out_group_ids_len,
+                         const char *curves) {
+  uint16_t *group_ids = NULL;
+  size_t ncurves = 0;
+
+  const char *col;
+  const char *ptr = curves;
+
+  do {
+    col = strchr(ptr, ':');
+
+    uint16_t group_id;
+    if (!ssl_name_to_group_id(&group_id, ptr,
+                              col ? (size_t)(col - ptr) : strlen(ptr))) {
+      goto err;
+    }
+
+    uint16_t *new_group_ids = OPENSSL_realloc(group_ids,
+                                              (ncurves + 1) * sizeof(uint16_t));
+    if (new_group_ids == NULL) {
+      goto err;
+    }
+    group_ids = new_group_ids;
+
+    group_ids[ncurves] = group_id;
+    ncurves++;
+
+    if (col) {
+      ptr = col + 1;
+    }
+  } while (col);
+
+  OPENSSL_free(*out_group_ids);
+  *out_group_ids = group_ids;
+  *out_group_ids_len = ncurves;
+
+  return 1;
+
+err:
+  OPENSSL_free(group_ids);
+  return 0;
+}
+
 /* tls1_curve_params_from_ec_key sets |*out_group_id| and |*out_comp_id| to the
  * TLS group ID and point format, respectively, for |ec|. It returns one on
  * success and zero on failure. */