Add tests for bidirectional shutdown.

Now that bidirectional shutdown works at all (the type = 0 bug aside), add
tests for it. Test close_notify being received both before and after
SSL_shutdown is called. In the latter case, have the peer first send some
junk, which must be ignored, to test that this works.

Also test that SSL_shutdown fails on unclean shutdown and that quiet
shutdowns ignore it.

BUG=526437

Change-Id: Iff13b08feb03e82f21ecab0c66d5f85aec256137
Reviewed-on: https://boringssl-review.googlesource.com/5769
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index d839f5f..dabe5ec 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -748,6 +748,17 @@
   return ret;
 }
 
+// DoShutdown calls |SSL_shutdown|, re-invoking it while |config->async| is set
+// and |RetryAsync| requests a retry. Returns the last |SSL_shutdown| result.
+static int DoShutdown(SSL *ssl) {
+  const TestConfig *config = GetConfigPtr(ssl);
+  int ret;
+  do {
+    ret = SSL_shutdown(ssl);
+  } while (config->async && RetryAsync(ssl, ret));
+  return ret;
+}
+
 // CheckHandshakeProperties checks, immediately after |ssl| completes its
 // initial handshake (or False Starts), whether all the properties are
 // consistent with the test configuration and invariants.
@@ -1021,6 +1032,9 @@
     /* Renegotiations are disabled by default. */
     SSL_set_reject_peer_renegotiations(ssl.get(), 0);
   }
+  if (!config->check_close_notify) {
+    SSL_set_quiet_shutdown(ssl.get(), 1);
+  }
 
   int sock = Connect(config->port);
   if (sock == -1) {
@@ -1157,44 +1171,45 @@
         return false;
       }
     }
-    for (;;) {
-      uint8_t buf[512];
-      int n = DoRead(ssl.get(), buf, sizeof(buf));
-      int err = SSL_get_error(ssl.get(), n);
-      if (err == SSL_ERROR_ZERO_RETURN ||
-          (n == 0 && err == SSL_ERROR_SYSCALL)) {
-        if (n != 0) {
+    if (!config->shim_shuts_down) {
+      for (;;) {
+        uint8_t buf[512];
+        int n = DoRead(ssl.get(), buf, sizeof(buf));
+        int err = SSL_get_error(ssl.get(), n);
+        if (err == SSL_ERROR_ZERO_RETURN ||
+            (n == 0 && err == SSL_ERROR_SYSCALL)) {
+          if (n != 0) {
+            fprintf(stderr, "Invalid SSL_get_error output\n");
+            return false;
+          }
+          // Stop on either clean or unclean shutdown.
+          break;
+        } else if (err != SSL_ERROR_NONE) {
+          if (n > 0) {
+            fprintf(stderr, "Invalid SSL_get_error output\n");
+            return false;
+          }
+          return false;
+        }
+        // Successfully read data.
+        if (n <= 0) {
           fprintf(stderr, "Invalid SSL_get_error output\n");
           return false;
         }
-        // Accept shutdowns with or without close_notify.
-        // TODO(davidben): Write tests which distinguish these two cases.
-        break;
-      } else if (err != SSL_ERROR_NONE) {
-        if (n > 0) {
-          fprintf(stderr, "Invalid SSL_get_error output\n");
+
+        // After a successful read, with or without False Start, the handshake
+        // must be complete.
+        if (!GetTestState(ssl.get())->handshake_done) {
+          fprintf(stderr, "handshake was not completed after SSL_read\n");
           return false;
         }
-        return false;
-      }
-      // Successfully read data.
-      if (n <= 0) {
-        fprintf(stderr, "Invalid SSL_get_error output\n");
-        return false;
-      }
 
-      // After a successful read, with or without False Start, the handshake
-      // must be complete.
-      if (!GetTestState(ssl.get())->handshake_done) {
-        fprintf(stderr, "handshake was not completed after SSL_read\n");
-        return false;
-      }
-
-      for (int i = 0; i < n; i++) {
-        buf[i] ^= 0xff;
-      }
-      if (WriteAll(ssl.get(), buf, n) < 0) {
-        return false;
+        for (int i = 0; i < n; i++) {
+          buf[i] ^= 0xff;
+        }
+        if (WriteAll(ssl.get(), buf, n) < 0) {
+          return false;
+        }
       }
     }
   }
@@ -1210,7 +1225,24 @@
     out_session->reset(SSL_get1_session(ssl.get()));
   }
 
-  SSL_shutdown(ssl.get());
+  ret = DoShutdown(ssl.get());
+
+  if (config->shim_shuts_down && config->check_close_notify) {
+    // We initiate shutdown, so |SSL_shutdown| will return in two stages. First
+    // it returns zero when our close_notify is sent, then one when the peer's
+    // is received.
+    if (ret != 0) {
+      fprintf(stderr, "Unexpected SSL_shutdown result: %d != 0\n", ret);
+      return false;
+    }
+    ret = DoShutdown(ssl.get());
+  }
+
+  if (ret != 1) {
+    fprintf(stderr, "Unexpected SSL_shutdown result: %d != 1\n", ret);
+    return false;
+  }
+
   return true;
 }
 
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index f841293..2b7e29b 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -753,6 +753,15 @@
 	// ExpectedCustomExtension, if not nil, contains the expected contents
 	// of a custom extension.
 	ExpectedCustomExtension *string
+
+	// NoCloseNotify, if true, causes the close_notify alert to be skipped
+	// on connection shutdown.
+	NoCloseNotify bool
+
+	// ExpectCloseNotify, if true, requires a close_notify from the peer on
+	// shutdown. Records from the peer received after close_notify is sent
+	// are not discarded; they are treated as an error instead.
+	ExpectCloseNotify bool
 }
 
 func (c *Config) serverInit() {
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index f09cb7c..42bc840 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -648,10 +648,10 @@
 	if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
 		// RFC suggests that EOF without an alertCloseNotify is
 		// an error, but popular web sites seem to do this,
-		// so we can't make it an error.
-		// if err == io.EOF {
-		// 	err = io.ErrUnexpectedEOF
-		// }
+		// so we can't make it an error, outside of tests.
+		if err == io.EOF && c.config.Bugs.ExpectCloseNotify {
+			err = io.ErrUnexpectedEOF
+		}
 		if e, ok := err.(net.Error); !ok || !e.Temporary() {
 			c.in.setErrorLocked(err)
 		}
@@ -740,6 +740,10 @@
 			c.sendAlert(alertInternalError)
 			return c.in.setErrorLocked(errors.New("tls: application data record requested before handshake complete"))
 		}
+	case recordTypeAlert:
+		// Looking for a close_notify. Note: unlike a real
+		// implementation, this is not tolerant of additional records.
+		// See the documentation for ExpectCloseNotify.
 	}
 
 Again:
@@ -802,7 +806,7 @@
 			// A client might need to process a HelloRequest from
 			// the server, thus receiving a handshake message when
 			// application data is expected is ok.
-			if !c.isClient {
+			if !c.isClient || want != recordTypeApplicationData {
 				return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
 			}
 		}
@@ -1260,10 +1264,22 @@
 
 	c.handshakeMutex.Lock()
 	defer c.handshakeMutex.Unlock()
-	if c.handshakeComplete {
+	if c.handshakeComplete && !c.config.Bugs.NoCloseNotify {
 		alertErr = c.sendAlert(alertCloseNotify)
 	}
 
+	// Consume a close_notify from the peer if one hasn't been received
+	// already. This prevents the peer's |SSL_shutdown| from failing due
+	// to a failed write.
+	if c.handshakeComplete && alertErr == nil && c.config.Bugs.ExpectCloseNotify {
+		for c.in.error() == nil {
+			c.readRecord(recordTypeAlert)
+		}
+		if c.in.error() != io.EOF {
+			alertErr = c.in.error()
+		}
+	}
+
 	if err := c.conn.Close(); err != nil {
 		return err
 	}
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index 950c02a..7ada5f1 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -191,6 +191,10 @@
 	// shimWritesFirst controls whether the shim sends an initial "hello"
 	// message before doing a roundtrip with the runner.
 	shimWritesFirst bool
+	// shimShutsDown, if true, runs a test where the shim shuts down the
+	// connection immediately after the handshake rather than echoing
+	// messages from the runner.
+	shimShutsDown bool
 	// renegotiate indicates the the connection should be renegotiated
 	// during the exchange.
 	renegotiate bool
@@ -270,6 +274,7 @@
 			tlsConn = Client(conn, config)
 		}
 	}
+	defer tlsConn.Close()
 
 	if err := tlsConn.Handshake(); err != nil {
 		return err
@@ -420,6 +425,11 @@
 			tlsConn.SendAlert(alertLevelWarning, alertUnexpectedMessage)
 		}
 
+		if test.shimShutsDown {
+			// The shim will not respond.
+			continue
+		}
+
 		buf := make([]byte, len(testMessage))
 		if test.protocol == dtls {
 			bufTmp := make([]byte, len(buf)+1)
@@ -547,6 +557,10 @@
 		flags = append(flags, "-shim-writes-first")
 	}
 
+	if test.shimShutsDown {
+		flags = append(flags, "-shim-shuts-down")
+	}
+
 	if test.exportKeyingMaterial > 0 {
 		flags = append(flags, "-export-keying-material", strconv.Itoa(test.exportKeyingMaterial))
 		flags = append(flags, "-export-label", test.exportLabel)
@@ -1847,8 +1861,29 @@
 			noSessionCache: true,
 			flags:          []string{"-expect-no-session"},
 		},
+		{
+			name: "Unclean-Shutdown",
+			config: Config{
+				Bugs: ProtocolBugs{
+					NoCloseNotify:     true,
+					ExpectCloseNotify: true,
+				},
+			},
+			shimShutsDown: true,
+			flags:         []string{"-check-close-notify"},
+			shouldFail:    true,
+			expectedError: "Unexpected SSL_shutdown result: -1 != 1",
+		},
+		{
+			name: "Unclean-Shutdown-Ignored",
+			config: Config{
+				Bugs: ProtocolBugs{
+					NoCloseNotify: true,
+				},
+			},
+			shimShutsDown: true,
+		},
 	}
-
 	testCases = append(testCases, basicTests...)
 }
 
@@ -2561,6 +2596,33 @@
 			resumeSession:   true,
 			expectChannelID: true,
 		})
+
+		// Bidirectional shutdown with the runner initiating.
+		tests = append(tests, testCase{
+			name: "Shutdown-Runner",
+			config: Config{
+				Bugs: ProtocolBugs{
+					ExpectCloseNotify: true,
+				},
+			},
+			flags: []string{"-check-close-notify"},
+		})
+
+		// Bidirectional shutdown with the shim initiating. The runner,
+		// in the meantime, sends garbage before the close_notify which
+		// the shim must ignore.
+		tests = append(tests, testCase{
+			name: "Shutdown-Shim",
+			config: Config{
+				Bugs: ProtocolBugs{
+					ExpectCloseNotify: true,
+				},
+			},
+			shimShutsDown:     true,
+			sendEmptyRecords:  1,
+			sendWarningAlerts: 1,
+			flags:             []string{"-check-close-notify"},
+		})
 	} else {
 		tests = append(tests, testCase{
 			name: "SkipHelloVerifyRequest",
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index 8c4b420..4191b2b 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -90,6 +90,8 @@
     &TestConfig::enable_server_custom_extension },
   { "-custom-extension-skip", &TestConfig::custom_extension_skip },
   { "-custom-extension-fail-add", &TestConfig::custom_extension_fail_add },
+  { "-check-close-notify", &TestConfig::check_close_notify },
+  { "-shim-shuts-down", &TestConfig::shim_shuts_down },
 };
 
 const Flag<std::string> kStringFlags[] = {
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index 4418ed3..c7bdab3 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -87,6 +87,8 @@
   bool custom_extension_skip = false;
   bool custom_extension_fail_add = false;
   std::string ocsp_response;
+  bool check_close_notify = false;
+  bool shim_shuts_down = false;
 };
 
 bool ParseConfig(int argc, char **argv, TestConfig *out_config);