Fix bssl select loop on Windows.

While |WaitForMultipleObjects| works for both sockets and stdin, the
latter is often a line-buffered console. The |HANDLE| is considered
readable if there are any console events available, but reading blocks
until a full line is available. (In POSIX, line buffering is implemented
in the kernel via termios, which is differently concerning, but does
mean |select| works as expected.)

So that |Wait| reflects final stdin read, we spawn a stdin reader thread
that writes to an in-memory buffer and signals a |WSAEVENT| to
coordinate with the socket. This is kind of silly, but it works.

I tried just writing it to a pipe, but it appears
|WaitForMultipleObjects| does not work on pipes!

Change-Id: I2bfa323fa91aad7d2035bb1fe86ee6f54b85d811
Reviewed-on: https://boringssl-review.googlesource.com/28165
Reviewed-by: Steven Valdez <svaldez@google.com>
Commit-Queue: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/tool/transport_common.cc b/tool/transport_common.cc
index dcb8e0d..f1b03ab 100644
--- a/tool/transport_common.cc
+++ b/tool/transport_common.cc
@@ -12,6 +12,12 @@
  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
 
+// Suppress MSVC's STL warnings. It flags |std::copy| calls with a raw output
+// pointer, on grounds that MSVC cannot check them. Unfortunately, there is no
+// way to suppress the warning just on one line. The warning is flagged inside
+// the STL itself, so suppressing at the |std::copy| call does not work.
+#define _SCL_SECURE_NO_WARNINGS
+
 #include <openssl/base.h>
 
 #include <string>
@@ -33,6 +39,13 @@
 #include <sys/socket.h>
 #include <unistd.h>
 #else
+#include <algorithm>
+#include <deque>
+#include <memory>
+#include <mutex>
+#include <thread>
+#include <utility>
+
 #include <io.h>
 OPENSSL_MSVC_PRAGMA(warning(push, 3))
 #include <winsock2.h>
@@ -347,64 +360,265 @@
   ok = 0 == fcntl(sock, F_SETFL, flags);
 #endif
   if (!ok) {
-    fprintf(stderr, "Failed to set socket non-blocking.\n");
+    PrintSocketError("Failed to set socket non-blocking");
   }
   return ok;
 }
 
-static bool SocketSelect(int sock, bool stdin_open, bool *socket_ready,
-                         bool *stdin_ready) {
 #if !defined(OPENSSL_WINDOWS)
-  fd_set read_fds;
-  FD_ZERO(&read_fds);
-  if (stdin_open) {
-    FD_SET(0, &read_fds);
-  }
-  FD_SET(sock, &read_fds);
-  if (select(sock + 1, &read_fds, NULL, NULL, NULL) <= 0) {
-    perror("select");
-    return false;
-  }
 
-  if (FD_ISSET(0, &read_fds)) {
-    *stdin_ready = true;
-  }
-  if (FD_ISSET(sock, &read_fds)) {
+// SocketWaiter abstracts waiting for either the socket or stdin to be readable
+// between Windows and POSIX.
+class SocketWaiter {
+ public:
+  explicit SocketWaiter(int sock) : sock_(sock) {}
+  SocketWaiter(const SocketWaiter &) = delete;
+  SocketWaiter &operator=(const SocketWaiter &) = delete;
+
+  // Init initializes the SocketWaiter. It returns whether it succeeded.
+  bool Init() { return true; }
+
+  // Wait waits for at least on of the socket or stdin or be ready. On success,
+  // it sets |*socket_ready| and |*stdin_ready| to whether the respective
+  // objects are readable and returns true. On error, it returns false.
+  bool Wait(bool *socket_ready, bool *stdin_ready) {
     *socket_ready = true;
-  }
+    *stdin_ready = false;
 
-  return true;
-#else
-  WSAEVENT socket_handle = WSACreateEvent();
-  if (socket_handle == WSA_INVALID_EVENT ||
-      WSAEventSelect(sock, socket_handle, FD_READ | FD_CLOSE) != 0) {
-    WSACloseEvent(socket_handle);
-    return false;
-  }
-
-  HANDLE read_fds[2];
-  read_fds[0] = socket_handle;
-  read_fds[1] = GetStdHandle(STD_INPUT_HANDLE);
-
-  switch (
-      WaitForMultipleObjects(stdin_open ? 2 : 1, read_fds, FALSE, INFINITE)) {
-    case WAIT_OBJECT_0 + 0:
-      *socket_ready = true;
-      break;
-    case WAIT_OBJECT_0 + 1:
-      *stdin_ready = true;
-      break;
-    case WAIT_TIMEOUT:
-      break;
-    default:
-      WSACloseEvent(socket_handle);
+    fd_set read_fds;
+    FD_ZERO(&read_fds);
+    if (stdin_open_) {
+      FD_SET(STDIN_FILENO, &read_fds);
+    }
+    FD_SET(sock_, &read_fds);
+    if (select(sock_ + 1, &read_fds, NULL, NULL, NULL) <= 0) {
+      perror("select");
       return false;
+    }
+
+    if (FD_ISSET(STDIN_FILENO, &read_fds)) {
+      *stdin_ready = true;
+    }
+    if (FD_ISSET(sock_, &read_fds)) {
+      *socket_ready = true;
+    }
+
+    return true;
   }
 
-  WSACloseEvent(socket_handle);
-  return true;
-#endif
-}
+  // ReadStdin reads at most |max_out| bytes from stdin. On success, it writes
+  // them to |out| and sets |*out_len| to the number of bytes written. On error,
+  // it returns false. This method may only be called after |Wait| returned
+  // stdin was ready.
+  bool ReadStdin(void *out, size_t *out_len, size_t max_out) {
+    ssize_t n;
+    do {
+      n = read(STDIN_FILENO, out, max_out);
+    } while (n == -1 && errno == EINTR);
+    if (n < 0) {
+      perror("read from stdin");
+      return false;
+    }
+    *out_len = static_cast<size_t>(n);
+    return true;
+  }
+
+ private:
+   bool stdin_open_ = true;
+   int sock_;
+};
+
+#else // OPENSSL_WINDOWs
+
+class ScopedWSAEVENT {
+ public:
+  ScopedWSAEVENT() = default;
+  ScopedWSAEVENT(WSAEVENT event) { reset(event); }
+  ScopedWSAEVENT(const ScopedWSAEVENT &) = delete;
+  ScopedWSAEVENT(ScopedWSAEVENT &&other) { *this = std::move(other); }
+
+  ~ScopedWSAEVENT() { reset(); }
+
+  ScopedWSAEVENT &operator=(const ScopedWSAEVENT &) = delete;
+  ScopedWSAEVENT &operator=(ScopedWSAEVENT &&other) { reset(other.release()); }
+
+  explicit operator bool() const { return event_ != WSA_INVALID_EVENT; }
+  WSAEVENT get() const { return event_; }
+
+  WSAEVENT release() {
+    WSAEVENT ret = event_;
+    event_ = WSA_INVALID_EVENT;
+    return ret;
+  }
+
+  void reset(WSAEVENT event = WSA_INVALID_EVENT) {
+    if (event_ != WSA_INVALID_EVENT) {
+      WSACloseEvent(event_);
+    }
+    event_ = event;
+  }
+
+ private:
+  WSAEVENT event_ = WSA_INVALID_EVENT;
+};
+
+// SocketWaiter, on Windows, is more complicated. While |WaitForMultipleObjects|
+// works for both sockets and stdin, the latter is often a line-buffered
+// console. The |HANDLE| is considered readable if there are any console events
+// available, but reading blocks until a full line is available.
+//
+// So that |Wait| reflects final stdin read, we spawn a stdin reader thread that
+// writes to an in-memory buffer and signals a |WSAEVENT| to coordinate with the
+// socket.
+class SocketWaiter {
+ public:
+  explicit SocketWaiter(int sock) : sock_(sock) {}
+  SocketWaiter(const SocketWaiter &) = delete;
+  SocketWaiter &operator=(const SocketWaiter &) = delete;
+
+  bool Init() {
+    stdin_ = std::make_shared<StdinState>();
+    stdin_->event.reset(WSACreateEvent());
+    if (!stdin_->event) {
+      PrintSocketError("Error in WSACreateEvent");
+      return false;
+    }
+
+    // Spawn a thread to block on stdin.
+    std::shared_ptr<StdinState> state = stdin_;
+    std::thread thread([state]() {
+      for (;;) {
+        uint8_t buf[512];
+        int ret = _read(0 /* stdin */, buf, sizeof(buf));
+        if (ret <= 0) {
+          if (ret < 0) {
+            perror("read from stdin");
+          }
+          // Report the error or EOF to the caller.
+          std::lock_guard<std::mutex> lock(state->lock);
+          state->error = ret < 0;
+          state->open = false;
+          WSASetEvent(state->event.get());
+          return;
+        }
+
+        size_t len = static_cast<size_t>(ret);
+        size_t written = 0;
+        while (written < len) {
+          std::unique_lock<std::mutex> lock(state->lock);
+          // Wait for there to be room in the buffer.
+          state->cond.wait(lock, [&] { return !state->buffer_full(); });
+
+          // Copy what we can and signal to the caller.
+          size_t todo = std::min(len - written, state->buffer_remaining());
+          state->buffer.insert(state->buffer.end(), buf + written,
+                               buf + written + todo);
+          written += todo;
+          WSASetEvent(state->event.get());
+        }
+      }
+    });
+    thread.detach();
+    return true;
+  }
+
+  bool Wait(bool *socket_ready, bool *stdin_ready) {
+    *socket_ready = true;
+    *stdin_ready = false;
+
+    ScopedWSAEVENT sock_event(WSACreateEvent());
+    if (!sock_event ||
+        WSAEventSelect(sock_, sock_event.get(), FD_READ | FD_CLOSE) != 0) {
+      PrintSocketError("Error waiting for socket");
+      return false;
+    }
+
+    DWORD count = 1;
+    WSAEVENT events[2] = {sock_event.get(), WSA_INVALID_EVENT};
+    if (listen_stdin_) {
+      events[1] = stdin_->event.get();
+      count++;
+    }
+
+    switch (WSAWaitForMultipleEvents(count, events, FALSE /* wait all */,
+                                     WSA_INFINITE, FALSE /* alertable */)) {
+      case WSA_WAIT_EVENT_0 + 0:
+        *socket_ready = true;
+        return true;
+      case WSA_WAIT_EVENT_0 + 1:
+        *stdin_ready = true;
+        return true;
+      case WSA_WAIT_TIMEOUT:
+        return true;
+      default:
+        PrintSocketError("Error waiting for events");
+        return false;
+    }
+  }
+
+  bool ReadStdin(void *out, size_t *out_len, size_t max_out) {
+    std::lock_guard<std::mutex> locked(stdin_->lock);
+
+    if (stdin_->buffer.empty()) {
+      // |ReadStdin| may only be called when |Wait| signals it is ready, so
+      // stdin must have reached EOF or error.
+      assert(!stdin_->open);
+      listen_stdin_ = false;
+      if (stdin_->error) {
+        return false;
+      }
+      *out_len = 0;
+      return true;
+    }
+
+    bool was_full = stdin_->buffer_full();
+    // Copy as many bytes as well fit.
+    *out_len = std::min(max_out, stdin_->buffer.size());
+    auto begin = stdin_->buffer.begin();
+    auto end = stdin_->buffer.begin() + *out_len;
+    std::copy(begin, end, static_cast<uint8_t *>(out));
+    stdin_->buffer.erase(begin, end);
+    // Notify the stdin thread if there is more space.
+    if (was_full && !stdin_->buffer_full()) {
+      stdin_->cond.notify_one();
+    }
+    // If stdin is now waiting for input, clear the event.
+    if (stdin_->buffer.empty() && stdin_->open) {
+      WSAResetEvent(stdin_->event.get());
+    }
+    return true;
+  }
+
+ private:
+  struct StdinState {
+    static constexpr size_t kMaxBuffer = 1024;
+
+    StdinState() = default;
+    StdinState(const StdinState &) = delete;
+    StdinState &operator=(const StdinState &) = delete;
+
+    size_t buffer_remaining() const { return kMaxBuffer - buffer.size(); }
+    bool buffer_full() const { return buffer_remaining() == 0; }
+
+    ScopedWSAEVENT event;
+    // lock protects the following fields.
+    std::mutex lock;
+    // cond notifies the stdin thread that |buffer| is no longer full.
+    std::condition_variable cond;
+    std::deque<uint8_t> buffer;
+    bool open = true;
+    bool error = false;
+  };
+
+  int sock_;
+  std::shared_ptr<StdinState> stdin_;
+  // listen_stdin_ is set to false when we have consumed an EOF or error from
+  // |stdin_|. This is separate from |stdin_->open| because the signal may not
+  // have been consumed yet.
+  bool listen_stdin_ = true;
+};
+
+#endif  // OPENSSL_WINDOWS
 
 void PrintSSLError(FILE *file, const char *msg, int ssl_err, int ret) {
   switch (ssl_err) {
@@ -433,47 +647,46 @@
     return false;
   }
 
-  bool stdin_open = true;
+  SocketWaiter waiter(sock);
+  if (!waiter.Init()) {
+    return false;
+  }
   for (;;) {
     bool socket_ready = false;
     bool stdin_ready = false;
-    if (!SocketSelect(sock, stdin_open, &socket_ready, &stdin_ready)) {
+    if (!waiter.Wait(&socket_ready, &stdin_ready)) {
       return false;
     }
 
     if (stdin_ready) {
       uint8_t buffer[512];
-      ssize_t n;
-
-      do {
-        n = BORINGSSL_READ(0, buffer, sizeof(buffer));
-      } while (n == -1 && errno == EINTR);
-
+      size_t n;
+      if (!waiter.ReadStdin(buffer, &n, sizeof(buffer))) {
+        return false;
+      }
       if (n == 0) {
-        stdin_open = false;
 #if !defined(OPENSSL_WINDOWS)
         shutdown(sock, SHUT_WR);
 #else
         shutdown(sock, SD_SEND);
 #endif
         continue;
-      } else if (n < 0) {
-        perror("read from stdin");
-        return false;
       }
 
-      // On Windows, SocketSelect ends up setting sock to non-blocking.
+      // TODO(davidben): On Windows, |WSAEventSocket| sets |sock| to non-
+      // blocking and forbids setting it back to blocking. Rather than toggle
+      // the blocking state, loop waiting for the socket to be writable.
 #if !defined(OPENSSL_WINDOWS)
       if (!SocketSetNonBlocking(sock, false)) {
         return false;
       }
 #endif
-      int ssl_ret = SSL_write(ssl, buffer, n);
+      int ssl_ret = SSL_write(ssl, buffer, static_cast<int>(n));
       if (ssl_ret <= 0) {
         int ssl_err = SSL_get_error(ssl, ssl_ret);
         PrintSSLError(stderr, "Error while writing", ssl_err, ssl_ret);
         return false;
-      } else if (ssl_ret != n) {
+      } else if (ssl_ret != static_cast<int>(n)) {
         fprintf(stderr, "Short write from SSL_write.\n");
         return false;
       }