Add a callback for DDoS protection.

This callback receives information about the ClientHello and can decide
whether or not to allow the handshake to continue.

Change-Id: I21be28335fa74fedb5b73a310ee24310670fc923
Reviewed-on: https://boringssl-review.googlesource.com/3721
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 19ed07b..82d632a 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -809,6 +809,11 @@
    * will not call the callback a second time. */
   int (*select_certificate_cb)(const struct ssl_early_callback_ctx *);
 
+  /* dos_protection_cb is called once the resumption decision for a ClientHello
+   * has been made. It returns one to continue the handshake or zero to
+   * abort. */
+  int (*dos_protection_cb) (const struct ssl_early_callback_ctx *);
+
   /* quiet_shutdown is true if the connection should not send a close_notify on
    * shutdown. */
   int quiet_shutdown;
@@ -2228,6 +2233,12 @@
 OPENSSL_EXPORT int SSL_cache_hit(SSL *s);
 OPENSSL_EXPORT int SSL_is_server(SSL *s);
 
+/* SSL_CTX_set_dos_protection_cb sets a callback that is called once the
+ * resumption decision for a ClientHello has been made. It can return 1 to
+ * allow the handshake to continue or zero to cause the handshake to abort. */
+void SSL_CTX_set_dos_protection_cb(
+    SSL_CTX *ctx, int (*cb)(const struct ssl_early_callback_ctx *));
+
 /* SSL_get_structure_sizes returns the sizes of the SSL, SSL_CTX and
  * SSL_SESSION structures so that a test can ensure that outside code agrees on
  * these values. */
diff --git a/ssl/s3_srvr.c b/ssl/s3_srvr.c
index ddff5c5..535c1df 100644
--- a/ssl/s3_srvr.c
+++ b/ssl/s3_srvr.c
@@ -1086,6 +1086,13 @@
     }
   }
 
+  if (s->ctx->dos_protection_cb != NULL && s->ctx->dos_protection_cb(&early_ctx) == 0) {
+    /* Connection rejected for DOS reasons. */
+    al = SSL_AD_ACCESS_DENIED;
+    OPENSSL_PUT_ERROR(SSL, ssl3_get_client_hello, SSL_R_CONNECTION_REJECTED);
+    goto f_err;
+  }
+
   if (!CBS_get_u16_length_prefixed(&client_hello, &cipher_suites) ||
       !CBS_get_u8_length_prefixed(&client_hello, &compression_methods) ||
       CBS_len(&compression_methods) == 0) {
diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c
index 68f4ccb..46b9cb6 100644
--- a/ssl/ssl_lib.c
+++ b/ssl/ssl_lib.c
@@ -3128,6 +3128,11 @@
 
 int SSL_is_server(SSL *s) { return s->server; }
 
+void SSL_CTX_set_dos_protection_cb(
+    SSL_CTX *ctx, int (*cb)(const struct ssl_early_callback_ctx *)) {
+  ctx->dos_protection_cb = cb;
+}
+
 void SSL_enable_fastradio_padding(SSL *s, char on_off) {
   s->fastradio_padding = on_off;
 }
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index ad4fa1a..3e712fc 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -308,6 +308,18 @@
   }
 }
 
+static int DDoSCallback(const struct ssl_early_callback_ctx *early_context) {
+  const TestConfig *config = GetConfigPtr(early_context->ssl);
+  static int callback_num = 0;
+
+  callback_num++;
+  if (config->fail_ddos_callback ||
+      (config->fail_second_ddos_callback && callback_num == 2)) {
+    return 0;
+  }
+  return 1;
+}
+
 // Connect returns a new socket connected to localhost on |port| or -1 on
 // error.
 static int Connect(uint16_t port) {
@@ -600,6 +612,9 @@
     SSL_set_options(ssl.get(), SSL_OP_NO_QUERY_MTU);
     SSL_set_mtu(ssl.get(), config->mtu);
   }
+  if (config->install_ddos_callback) {
+    SSL_CTX_set_dos_protection_cb(ssl_ctx, DDoSCallback);
+  }
 
   int sock = Connect(config->port);
   if (sock == -1) {
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index d1dbe87..3f26786 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -2037,6 +2037,37 @@
 	}
 }
 
+func addDDoSCallbackTests() {
+	// DDoS callback.
+
+	for _, resume := range []bool{false, true} {
+		suffix := "Resume"
+		if resume {
+			suffix = "No" + suffix
+		}
+
+		testCases = append(testCases, testCase{
+			testType:      serverTest,
+			name:          "Server-DDoS-OK-" + suffix,
+			flags:         []string{"-install-ddos-callback"},
+			resumeSession: resume,
+		})
+
+		failFlag := "-fail-ddos-callback"
+		if resume {
+			failFlag = "-fail-second-ddos-callback"
+		}
+		testCases = append(testCases, testCase{
+			testType:      serverTest,
+			name:          "Server-DDoS-Reject-" + suffix,
+			flags:         []string{"-install-ddos-callback", failFlag},
+			resumeSession: resume,
+			shouldFail:    true,
+			expectedError: ":CONNECTION_REJECTED:",
+		})
+	}
+}
+
 func addVersionNegotiationTests() {
 	for i, shimVers := range tlsVersions {
 		// Assemble flags to disable all newer versions on the shim.
@@ -3029,6 +3060,7 @@
 	addCBCPaddingTests()
 	addCBCSplittingTests()
 	addClientAuthTests()
+	addDDoSCallbackTests()
 	addVersionNegotiationTests()
 	addMinimumVersionTests()
 	addD5BugTests()
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index 4db72b4..cbfc10f 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -75,6 +75,9 @@
   { "-implicit-handshake", &TestConfig::implicit_handshake },
   { "-use-early-callback", &TestConfig::use_early_callback },
   { "-fail-early-callback", &TestConfig::fail_early_callback },
+  { "-install-ddos-callback", &TestConfig::install_ddos_callback },
+  { "-fail-ddos-callback", &TestConfig::fail_ddos_callback },
+  { "-fail-second-ddos-callback", &TestConfig::fail_second_ddos_callback },
 };
 
 const Flag<std::string> kStringFlags[] = {
@@ -142,7 +145,9 @@
       mtu(0),
       implicit_handshake(false),
       use_early_callback(false),
-      fail_early_callback(false) {
+      fail_early_callback(false),
+      install_ddos_callback(false),
+      fail_ddos_callback(false) {
 }
 
 bool ParseConfig(int argc, char **argv, TestConfig *out_config) {
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index a54fb23..380e845 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -70,6 +70,9 @@
   bool implicit_handshake;
   bool use_early_callback;
   bool fail_early_callback;
+  bool install_ddos_callback;
+  bool fail_ddos_callback;
+  bool fail_second_ddos_callback;
 };
 
 bool ParseConfig(int argc, char **argv, TestConfig *out_config);