Handle blocked writes in bssl client/server.

On Windows, just switching the socket to blocking doesn't work. Instead,
switch the stdin half of the waiter to waiting for either socket write
or stdin read, depending on whether we're in the middle of trying to
write a buffer.

Change-Id: I81414898f0491e78e6ab5b28c12148a3909ec1e0
Reviewed-on: https://boringssl-review.googlesource.com/28167
Commit-Queue: Steven Valdez <svaldez@google.com>
Reviewed-by: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/tool/transport_common.cc b/tool/transport_common.cc
index 1e9b72f..f1a944a 100644
--- a/tool/transport_common.cc
+++ b/tool/transport_common.cc
@@ -365,6 +365,11 @@
   return ok;
 }
 
+enum class StdinWait {
+  kStdinRead,
+  kSocketWrite,
+};
+
 #if !defined(OPENSSL_WINDOWS)
 
 // SocketWaiter abstracts waiting for either the socket or stdin to be readable
@@ -380,23 +385,28 @@
 
   // Wait waits for at least on of the socket or stdin or be ready. On success,
   // it sets |*socket_ready| and |*stdin_ready| to whether the respective
-  // objects are readable and returns true. On error, it returns false.
-  bool Wait(bool *socket_ready, bool *stdin_ready) {
+  // objects are readable and returns true. On error, it returns false. stdin's
+  // readiness may either be the socket being writable or stdin being readable,
+  // depending on |stdin_wait|.
+  bool Wait(StdinWait stdin_wait, bool *socket_ready, bool *stdin_ready) {
     *socket_ready = true;
     *stdin_ready = false;
 
-    fd_set read_fds;
+    fd_set read_fds, write_fds;
     FD_ZERO(&read_fds);
-    if (stdin_open_) {
+    FD_ZERO(&write_fds);
+    if (stdin_wait == StdinWait::kSocketWrite) {
+      FD_SET(sock_, &write_fds);
+    } else if (stdin_open_) {
       FD_SET(STDIN_FILENO, &read_fds);
     }
     FD_SET(sock_, &read_fds);
-    if (select(sock_ + 1, &read_fds, NULL, NULL, NULL) <= 0) {
+    if (select(sock_ + 1, &read_fds, &write_fds, NULL, NULL) <= 0) {
       perror("select");
       return false;
     }
 
-    if (FD_ISSET(STDIN_FILENO, &read_fds)) {
+    if (FD_ISSET(STDIN_FILENO, &read_fds) || FD_ISSET(sock_, &write_fds)) {
       *stdin_ready = true;
     }
     if (FD_ISSET(sock_, &read_fds)) {
@@ -522,20 +532,30 @@
     return true;
   }
 
-  bool Wait(bool *socket_ready, bool *stdin_ready) {
+  bool Wait(StdinWait stdin_wait, bool *socket_ready, bool *stdin_ready) {
     *socket_ready = true;
     *stdin_ready = false;
 
-    ScopedWSAEVENT sock_event(WSACreateEvent());
-    if (!sock_event ||
-        WSAEventSelect(sock_, sock_event.get(), FD_READ | FD_CLOSE) != 0) {
-      PrintSocketError("Error waiting for socket");
+    ScopedWSAEVENT sock_read_event(WSACreateEvent());
+    if (!sock_read_event ||
+        WSAEventSelect(sock_, sock_read_event.get(), FD_READ | FD_CLOSE) != 0) {
+      PrintSocketError("Error waiting for socket read");
       return false;
     }
 
     DWORD count = 1;
-    WSAEVENT events[2] = {sock_event.get(), WSA_INVALID_EVENT};
-    if (listen_stdin_) {
+    WSAEVENT events[3] = {sock_read_event.get(), WSA_INVALID_EVENT};
+    ScopedWSAEVENT sock_write_event;
+    if (stdin_wait == StdinWait::kSocketWrite) {
+      sock_write_event.reset(WSACreateEvent());
+      if (!sock_write_event || WSAEventSelect(sock_, sock_write_event.get(),
+                                              FD_WRITE | FD_CLOSE) != 0) {
+        PrintSocketError("Error waiting for socket write");
+        return false;
+      }
+      events[1] = sock_write_event.get();
+      count++;
+    } else if (listen_stdin_) {
       events[1] = stdin_->event.get();
       count++;
     }
@@ -651,51 +671,49 @@
   if (!waiter.Init()) {
     return false;
   }
+
+  uint8_t pending_write[512];
+  size_t pending_write_len = 0;
   for (;;) {
     bool socket_ready = false;
     bool stdin_ready = false;
-    if (!waiter.Wait(&socket_ready, &stdin_ready)) {
+    if (!waiter.Wait(pending_write_len == 0 ? StdinWait::kStdinRead
+                                            : StdinWait::kSocketWrite,
+                     &socket_ready, &stdin_ready)) {
       return false;
     }
 
     if (stdin_ready) {
-      uint8_t buffer[512];
-      size_t n;
-      if (!waiter.ReadStdin(buffer, &n, sizeof(buffer))) {
-        return false;
-      }
-      if (n == 0) {
-#if !defined(OPENSSL_WINDOWS)
-        shutdown(sock, SHUT_WR);
-#else
-        shutdown(sock, SD_SEND);
-#endif
-        continue;
+      if (pending_write_len == 0) {
+        if (!waiter.ReadStdin(pending_write, &pending_write_len,
+                              sizeof(pending_write))) {
+          return false;
+        }
+        if (pending_write_len == 0) {
+  #if !defined(OPENSSL_WINDOWS)
+          shutdown(sock, SHUT_WR);
+  #else
+          shutdown(sock, SD_SEND);
+  #endif
+          continue;
+        }
       }
 
-      // TODO(davidben): On Windows, |WSAEventSocket| sets |sock| to non-
-      // blocking and forbids setting it back to blocking. Rather than toggle
-      // the blocking state, loop waiting for the socket to be writable.
-#if !defined(OPENSSL_WINDOWS)
-      if (!SocketSetNonBlocking(sock, false)) {
-        return false;
-      }
-#endif
-      int ssl_ret = SSL_write(ssl, buffer, static_cast<int>(n));
+      int ssl_ret =
+          SSL_write(ssl, pending_write, static_cast<int>(pending_write_len));
       if (ssl_ret <= 0) {
         int ssl_err = SSL_get_error(ssl, ssl_ret);
+        if (ssl_err == SSL_ERROR_WANT_WRITE) {
+          continue;
+        }
         PrintSSLError(stderr, "Error while writing", ssl_err, ssl_ret);
         return false;
-      } else if (ssl_ret != static_cast<int>(n)) {
+      }
+      if (ssl_ret != static_cast<int>(pending_write_len)) {
         fprintf(stderr, "Short write from SSL_write.\n");
         return false;
       }
-
-      // Note we handle errors before restoring the non-blocking state. On
-      // Windows, |SocketSetNonBlocking| internally clears the last error.
-      if (!SocketSetNonBlocking(sock, true)) {
-        return false;
-      }
+      pending_write_len = 0;
     }
 
     if (socket_ready) {