Tidy up shutdown state.

The existing logic gets confused in a number of cases around close_notify vs.
fatal alert. SSL_shutdown does not distinguish a received fatal alert from a
close_notify and can report a successful bidirectional shutdown, even though the
alert is still pushed to the error queue. We also get confused if we try to send
an alert after we've already sent a close_notify or a fatal alert.

Change-Id: I9b1d217fbf1ee8a9c59efbebba60165b7de9689e
Reviewed-on: https://boringssl-review.googlesource.com/7952
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/ssl/d1_pkt.c b/ssl/d1_pkt.c
index 96418df..ed64676 100644
--- a/ssl/d1_pkt.c
+++ b/ssl/d1_pkt.c
@@ -131,12 +131,14 @@
  * more data is needed. */
 static int dtls1_get_record(SSL *ssl) {
 again:
-  if (ssl->shutdown & SSL_RECEIVED_SHUTDOWN) {
-    if (ssl->s3->clean_shutdown) {
+  switch (ssl->s3->recv_shutdown) {
+    case ssl_shutdown_none:
+      break;
+    case ssl_shutdown_fatal_alert:
+      OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
+      return -1;
+    case ssl_shutdown_close_notify:
       return 0;
-    }
-    OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
-    return -1;
   }
 
   /* Read a new packet if there is no unconsumed one. */
@@ -225,7 +227,9 @@
    * alerts also aren't delivered reliably, so we may even time out because the
    * peer never received our close_notify. Report to the caller that the channel
    * has fully shut down. */
-  ssl->shutdown |= SSL_RECEIVED_SHUTDOWN;
+  if (ssl->s3->recv_shutdown == ssl_shutdown_none) {
+    ssl->s3->recv_shutdown = ssl_shutdown_close_notify;
+  }
 }
 
 /* Return up to 'len' payload bytes received in 'type' records.
@@ -354,8 +358,7 @@
 
     if (alert_level == SSL3_AL_WARNING) {
       if (alert_descr == SSL_AD_CLOSE_NOTIFY) {
-        ssl->s3->clean_shutdown = 1;
-        ssl->shutdown |= SSL_RECEIVED_SHUTDOWN;
+        ssl->s3->recv_shutdown = ssl_shutdown_close_notify;
         return 0;
       }
     } else if (alert_level == SSL3_AL_FATAL) {
@@ -364,7 +367,7 @@
       OPENSSL_PUT_ERROR(SSL, SSL_AD_REASON_OFFSET + alert_descr);
       BIO_snprintf(tmp, sizeof tmp, "%d", alert_descr);
       ERR_add_error_data(2, "SSL alert number ", tmp);
-      ssl->shutdown |= SSL_RECEIVED_SHUTDOWN;
+      ssl->s3->recv_shutdown = ssl_shutdown_fatal_alert;
       SSL_CTX_remove_session(ssl->ctx, ssl->session);
       return 0;
     } else {
diff --git a/ssl/s3_pkt.c b/ssl/s3_pkt.c
index a034c29..bacbfe6 100644
--- a/ssl/s3_pkt.c
+++ b/ssl/s3_pkt.c
@@ -133,12 +133,14 @@
 static int ssl3_get_record(SSL *ssl) {
   int ret;
 again:
-  if (ssl->shutdown & SSL_RECEIVED_SHUTDOWN) {
-    if (ssl->s3->clean_shutdown) {
+  switch (ssl->s3->recv_shutdown) {
+    case ssl_shutdown_none:
+      break;
+    case ssl_shutdown_fatal_alert:
+      OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
+      return -1;
+    case ssl_shutdown_close_notify:
       return 0;
-    }
-    OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
-    return -1;
   }
 
   /* Ensure the buffer is large enough to decrypt in-place. */
@@ -547,8 +549,7 @@
 
     if (alert_level == SSL3_AL_WARNING) {
       if (alert_descr == SSL_AD_CLOSE_NOTIFY) {
-        ssl->s3->clean_shutdown = 1;
-        ssl->shutdown |= SSL_RECEIVED_SHUTDOWN;
+        ssl->s3->recv_shutdown = ssl_shutdown_close_notify;
         return 0;
       }
 
@@ -564,7 +565,7 @@
       OPENSSL_PUT_ERROR(SSL, SSL_AD_REASON_OFFSET + alert_descr);
       BIO_snprintf(tmp, sizeof(tmp), "%d", alert_descr);
       ERR_add_error_data(2, "SSL alert number ", tmp);
-      ssl->shutdown |= SSL_RECEIVED_SHUTDOWN;
+      ssl->s3->recv_shutdown = ssl_shutdown_fatal_alert;
       SSL_CTX_remove_session(ssl->ctx, ssl->session);
       return 0;
     } else {
@@ -576,7 +577,7 @@
     goto start;
   }
 
-  if (ssl->shutdown & SSL_SENT_SHUTDOWN) {
+  if (ssl->s3->send_shutdown == ssl_shutdown_close_notify) {
     /* close_notify has been sent, so discard all records other than alerts. */
     rr->length = 0;
     goto start;
@@ -592,9 +593,19 @@
 }
 
 int ssl3_send_alert(SSL *ssl, int level, int desc) {
-  /* If a fatal one, remove from cache */
-  if (level == 2 && ssl->session != NULL) {
-    SSL_CTX_remove_session(ssl->ctx, ssl->session);
+  /* It is illegal to send an alert when we've already sent a closing one. */
+  if (ssl->s3->send_shutdown != ssl_shutdown_none) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
+    return -1;
+  }
+
+  if (level == SSL3_AL_FATAL) {
+    if (ssl->session != NULL) {
+      SSL_CTX_remove_session(ssl->ctx, ssl->session);
+    }
+    ssl->s3->send_shutdown = ssl_shutdown_fatal_alert;
+  } else if (level == SSL3_AL_WARNING && desc == SSL_AD_CLOSE_NOTIFY) {
+    ssl->s3->send_shutdown = ssl_shutdown_close_notify;
   }
 
   ssl->s3->alert_dispatch = 1;
@@ -606,8 +617,7 @@
     return ssl->method->ssl_dispatch_alert(ssl);
   }
 
-  /* else data is still being written out, we will get written some time in the
-   * future */
+  /* The alert will be dispatched later. */
   return -1;
 }
 
diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c
index ade7120..25e1349 100644
--- a/ssl/ssl_lib.c
+++ b/ssl/ssl_lib.c
@@ -516,14 +516,12 @@
 
 void SSL_set_connect_state(SSL *ssl) {
   ssl->server = 0;
-  ssl->shutdown = 0;
   ssl->state = SSL_ST_CONNECT;
   ssl->handshake_func = ssl->method->ssl_connect;
 }
 
 void SSL_set_accept_state(SSL *ssl) {
   ssl->server = 1;
-  ssl->shutdown = 0;
   ssl->state = SSL_ST_ACCEPT;
   ssl->handshake_func = ssl->method->ssl_accept;
 }
@@ -637,7 +635,7 @@
     return -1;
   }
 
-  if (ssl->shutdown & SSL_SENT_SHUTDOWN) {
+  if (ssl->s3->send_shutdown != ssl_shutdown_none) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
     return -1;
   }
@@ -662,11 +660,6 @@
   /* Functions which use SSL_get_error must clear the error queue on entry. */
   ERR_clear_error();
 
-  /* Note that this function behaves differently from what one might expect.
-   * Return values are 0 for no success (yet), 1 for success; but calling it
-   * once is usually not enough, even if blocking I/O is used (see
-   * ssl3_shutdown). */
-
   if (ssl->handshake_func == NULL) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_UNINITIALIZED);
     return -1;
@@ -678,44 +671,37 @@
     return -1;
   }
 
-  /* Do nothing if configured not to send a close_notify. */
   if (ssl->quiet_shutdown) {
-    ssl->shutdown = SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN;
+    /* Do nothing if configured not to send a close_notify. */
+    ssl->s3->send_shutdown = ssl_shutdown_close_notify;
+    ssl->s3->recv_shutdown = ssl_shutdown_close_notify;
     return 1;
   }
 
-  if (!(ssl->shutdown & SSL_SENT_SHUTDOWN)) {
-    ssl->shutdown |= SSL_SENT_SHUTDOWN;
-    ssl3_send_alert(ssl, SSL3_AL_WARNING, SSL_AD_CLOSE_NOTIFY);
+  /* This function completes in two stages. It sends a close_notify and then it
+   * waits for a close_notify to come in. Perform exactly one action and return
+   * whether or not it succeeds. */
 
-    /* our shutdown alert has been sent now, and if it still needs to be
-     * written, ssl->s3->alert_dispatch will be true */
-    if (ssl->s3->alert_dispatch) {
-      return -1; /* return WANT_WRITE */
+  if (ssl->s3->send_shutdown != ssl_shutdown_close_notify) {
+    /* Send a close_notify. */
+    if (ssl3_send_alert(ssl, SSL3_AL_WARNING, SSL_AD_CLOSE_NOTIFY) <= 0) {
+      return -1;
     }
   } else if (ssl->s3->alert_dispatch) {
-    /* resend it if not sent */
-    int ret = ssl->method->ssl_dispatch_alert(ssl);
-    if (ret == -1) {
-      /* we only get to return -1 here the 2nd/Nth invocation, we must  have
-       * already signalled return 0 upon a previous invoation, return
-       * WANT_WRITE */
-      return ret;
+    /* Finish sending the close_notify. */
+    if (ssl->method->ssl_dispatch_alert(ssl) <= 0) {
+      return -1;
     }
-  } else if (!(ssl->shutdown & SSL_RECEIVED_SHUTDOWN)) {
-    /* If we are waiting for a close from our peer, we are closed */
+  } else if (ssl->s3->recv_shutdown != ssl_shutdown_close_notify) {
+    /* Wait for the peer's close_notify. */
     ssl->method->ssl_read_close_notify(ssl);
-    if (!(ssl->shutdown & SSL_RECEIVED_SHUTDOWN)) {
-      return -1; /* return WANT_READ */
+    if (ssl->s3->recv_shutdown != ssl_shutdown_close_notify) {
+      return -1;
     }
   }
 
-  if (ssl->shutdown == (SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN) &&
-      !ssl->s3->alert_dispatch) {
-    return 1;
-  } else {
-    return 0;
-  }
+  /* Return 0 for unidirectional shutdown and 1 for bidirectional shutdown. */
+  return ssl->s3->recv_shutdown == ssl_shutdown_close_notify;
 }
 
 int SSL_get_error(const SSL *ssl, int ret_code) {
@@ -738,8 +724,7 @@
   }
 
   if (ret_code == 0) {
-    if ((ssl->shutdown & SSL_RECEIVED_SHUTDOWN) && ssl->s3->clean_shutdown) {
-      /* The socket was cleanly shut down with a close_notify. */
+    if (ssl->s3->recv_shutdown == ssl_shutdown_close_notify) {
       return SSL_ERROR_ZERO_RETURN;
     }
     /* An EOF was observed which violates the protocol, and the underlying
@@ -1929,12 +1914,32 @@
 void SSL_set_shutdown(SSL *ssl, int mode) {
   /* It is an error to clear any bits that have already been set. (We can't try
    * to get a second close_notify or send two.) */
-  assert((ssl->shutdown & mode) == ssl->shutdown);
+  assert((SSL_get_shutdown(ssl) & mode) == SSL_get_shutdown(ssl));
 
-  ssl->shutdown |= mode;
+  if (mode & SSL_RECEIVED_SHUTDOWN &&
+      ssl->s3->recv_shutdown == ssl_shutdown_none) {
+    ssl->s3->recv_shutdown = ssl_shutdown_close_notify;
+  }
+
+  if (mode & SSL_SENT_SHUTDOWN &&
+      ssl->s3->send_shutdown == ssl_shutdown_none) {
+    ssl->s3->send_shutdown = ssl_shutdown_close_notify;
+  }
 }
 
-int SSL_get_shutdown(const SSL *ssl) { return ssl->shutdown; }
+int SSL_get_shutdown(const SSL *ssl) {
+  int ret = 0;
+  if (ssl->s3->recv_shutdown != ssl_shutdown_none) {
+    /* Historically, OpenSSL set |SSL_RECEIVED_SHUTDOWN| on both close_notify
+     * and fatal alert. */
+    ret |= SSL_RECEIVED_SHUTDOWN;
+  }
+  if (ssl->s3->send_shutdown == ssl_shutdown_close_notify) {
+    /* Historically, OpenSSL set |SSL_SENT_SHUTDOWN| on only close_notify. */
+    ret |= SSL_SENT_SHUTDOWN;
+  }
+  return ret;
+}
 
 int SSL_version(const SSL *ssl) { return ssl->version; }
 
@@ -2642,7 +2647,6 @@
   }
 
   ssl->hit = 0;
-  ssl->shutdown = 0;
 
   /* SSL_clear may be called before or after the |ssl| is initialized in either
    * accept or connect state. In the latter case, SSL_clear should preserve the
diff --git a/ssl/ssl_session.c b/ssl/ssl_session.c
index 12d065e..009693b 100644
--- a/ssl/ssl_session.c
+++ b/ssl/ssl_session.c
@@ -658,7 +658,8 @@
 }
 
 int ssl_clear_bad_session(SSL *ssl) {
-  if (ssl->session != NULL && !(ssl->shutdown & SSL_SENT_SHUTDOWN) &&
+  if (ssl->session != NULL &&
+      ssl->s3->send_shutdown != ssl_shutdown_close_notify &&
       !SSL_in_init(ssl)) {
     SSL_CTX_remove_session(ssl->ctx, ssl->session);
     return 1;
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index 679969d..57b7b29 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -787,6 +787,10 @@
 	// on connection shutdown.
 	NoCloseNotify bool
 
+	// SendAlertOnShutdown, if non-zero, is the alert to send instead of
+	// close_notify on shutdown.
+	SendAlertOnShutdown alert
+
 	// ExpectCloseNotify, if true, requires a close_notify from the peer on
 	// shutdown. Records from the peer received after close_notify is sent
 	// are not discard.
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index 43548e8..36ca202 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -1292,7 +1292,11 @@
 	c.handshakeMutex.Lock()
 	defer c.handshakeMutex.Unlock()
 	if c.handshakeComplete && !c.config.Bugs.NoCloseNotify {
-		alertErr = c.sendAlert(alertCloseNotify)
+		alert := alertCloseNotify
+		if c.config.Bugs.SendAlertOnShutdown != 0 {
+			alert = c.config.Bugs.SendAlertOnShutdown
+		}
+		alertErr = c.sendAlert(alert)
 	}
 
 	// Consume a close_notify from the peer if one hasn't been received
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index f9cf3d6..eb1efdf 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -2061,6 +2061,19 @@
 			shimShutsDown: true,
 		},
 		{
+			name: "Unclean-Shutdown-Alert",
+			config: Config{
+				Bugs: ProtocolBugs{
+					SendAlertOnShutdown: alertDecompressionFailure,
+					ExpectCloseNotify:   true,
+				},
+			},
+			shimShutsDown: true,
+			flags:         []string{"-check-close-notify"},
+			shouldFail:    true,
+			expectedError: ":SSLV3_ALERT_DECOMPRESSION_FAILURE:",
+		},
+		{
 			name: "LargePlaintext",
 			config: Config{
 				Bugs: ProtocolBugs{