Fix bssl client/server's error-handling.

Rather than printing the SSL_ERROR_* constants, print the actual error.
This should be a bit more understandable. Debugging this also uncovered
some other issues on Windows:

- We were mixing up C runtime and Winsock errors, which are separate in
  Windows.

- The thread local implementation interferes with WSAGetLastError due to
  a quirk of TlsGetValue. This could affect other Windows consumers.
  (Chromium uses a custom BIO, so it isn't affected.)

- SocketSetNonBlocking also interferes with WSAGetLastError.

- Listen for FD_CLOSE along with FD_READ. Connection close does not
  signal FD_READ. (The select loop only barely works on Windows anyway
  due to issues with stdin and line buffering, but if we take stdin out
  of the equation, FD_CLOSE can be tested.)

Change-Id: If991259915acc96606a314fbe795fe6ea1e295e8
Reviewed-on: https://boringssl-review.googlesource.com/28125
Commit-Queue: Steven Valdez <svaldez@google.com>
Reviewed-by: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/crypto/err/err_test.cc b/crypto/err/err_test.cc
index 489d248..d975721 100644
--- a/crypto/err/err_test.cc
+++ b/crypto/err/err_test.cc
@@ -23,6 +23,14 @@
 
 #include "./internal.h"
 
+#if defined(OPENSSL_WINDOWS)
+OPENSSL_MSVC_PRAGMA(warning(push, 3))
+#include <windows.h>
+OPENSSL_MSVC_PRAGMA(warning(pop))
+#else
+#include <errno.h>
+#endif
+
 
 TEST(ErrTest, Overflow) {
   for (unsigned i = 0; i < ERR_NUM_ERRORS*2; i++) {
@@ -212,3 +220,18 @@
     EXPECT_EQ(0u, ERR_get_error());
   }
 }
+
+// Querying the error queue should not affect the OS error.
+#if defined(OPENSSL_WINDOWS)
+TEST(ErrTest, PreservesLastError) {
+  SetLastError(ERROR_INVALID_FUNCTION);
+  ERR_get_error();
+  EXPECT_EQ(ERROR_INVALID_FUNCTION, GetLastError());
+}
+#else
+TEST(ErrTest, PreservesErrno) {
+  errno = EINVAL;
+  ERR_get_error();
+  EXPECT_EQ(EINVAL, errno);
+}
+#endif
diff --git a/crypto/thread_win.c b/crypto/thread_win.c
index d6fa548..248870a 100644
--- a/crypto/thread_win.c
+++ b/crypto/thread_win.c
@@ -190,13 +190,31 @@
 
 #endif  // _WIN64
 
+static void **get_thread_locals(void) {
+  // |TlsGetValue| clears the last error even on success, so that callers may
+  // distinguish it successfully returning NULL or failing. It is documented to
+  // never fail if the argument is a valid index from |TlsAlloc|, so we do not
+  // need to handle this.
+  //
+  // However, this error-mangling behavior interferes with the caller's use of
+  // |GetLastError|. In particular |SSL_get_error| queries the error queue to
+  // determine whether the caller should look at the OS's errors. To avoid
+  // destroying state, save and restore the Windows error.
+  //
+  // https://msdn.microsoft.com/en-us/library/windows/desktop/ms686812(v=vs.85).aspx
+  DWORD last_error = GetLastError();
+  void **ret = TlsGetValue(g_thread_local_key);
+  SetLastError(last_error);
+  return ret;
+}
+
 void *CRYPTO_get_thread_local(thread_local_data_t index) {
   CRYPTO_once(&g_thread_local_init_once, thread_local_init);
   if (g_thread_local_failed) {
     return NULL;
   }
 
-  void **pointers = TlsGetValue(g_thread_local_key);
+  void **pointers = get_thread_locals();
   if (pointers == NULL) {
     return NULL;
   }
@@ -211,7 +229,7 @@
     return 0;
   }
 
-  void **pointers = TlsGetValue(g_thread_local_key);
+  void **pointers = get_thread_locals();
   if (pointers == NULL) {
     pointers = OPENSSL_malloc(sizeof(void *) * NUM_OPENSSL_THREAD_LOCALS);
     if (pointers == NULL) {
diff --git a/tool/client.cc b/tool/client.cc
index bdb5de7..037e10c 100644
--- a/tool/client.cc
+++ b/tool/client.cc
@@ -181,7 +181,7 @@
     if (!PEM_write_bio_SSL_SESSION(session_out.get(), session) ||
         BIO_flush(session_out.get()) <= 0) {
       fprintf(stderr, "Error while saving session:\n");
-      ERR_print_errors_cb(PrintErrorCallback, stderr);
+      ERR_print_errors_fp(stderr);
       return 0;
     }
   }
@@ -221,8 +221,7 @@
       if (ssl_err == SSL_ERROR_WANT_READ) {
         continue;
       }
-      fprintf(stderr, "Error while reading: %d\n", ssl_err);
-      ERR_print_errors_cb(PrintErrorCallback, stderr);
+      PrintSSLError(stderr, "Error while reading", ssl_err, ssl_ret);
       return false;
     }
   }
@@ -267,14 +266,14 @@
                                          "rb"));
     if (!in) {
       fprintf(stderr, "Error reading session\n");
-      ERR_print_errors_cb(PrintErrorCallback, stderr);
+      ERR_print_errors_fp(stderr);
       return false;
     }
     bssl::UniquePtr<SSL_SESSION> session(PEM_read_bio_SSL_SESSION(in.get(),
                                          nullptr, nullptr, nullptr));
     if (!session) {
       fprintf(stderr, "Error reading session\n");
-      ERR_print_errors_cb(PrintErrorCallback, stderr);
+      ERR_print_errors_fp(stderr);
       return false;
     }
     SSL_set_session(ssl.get(), session.get());
@@ -294,8 +293,7 @@
   int ret = SSL_connect(ssl.get());
   if (ret != 1) {
     int ssl_err = SSL_get_error(ssl.get(), ret);
-    fprintf(stderr, "Error while connecting: %d\n", ssl_err);
-    ERR_print_errors_cb(PrintErrorCallback, stderr);
+    PrintSSLError(stderr, "Error while connecting", ssl_err, ret);
     return false;
   }
 
@@ -315,8 +313,7 @@
     int ssl_ret = SSL_write(ssl.get(), early_data.data(), ed_size);
     if (ssl_ret <= 0) {
       int ssl_err = SSL_get_error(ssl.get(), ssl_ret);
-      fprintf(stderr, "Error while writing: %d\n", ssl_err);
-      ERR_print_errors_cb(PrintErrorCallback, stderr);
+      PrintSSLError(stderr, "Error while writing", ssl_err, ssl_ret);
       return false;
     } else if (ssl_ret != ed_size) {
       fprintf(stderr, "Short write from SSL_write.\n");
@@ -500,7 +497,7 @@
     if (!session_out) {
       fprintf(stderr, "Error while opening %s:\n",
               args_map["-session-out"].c_str());
-      ERR_print_errors_cb(PrintErrorCallback, stderr);
+      ERR_print_errors_fp(stderr);
       return false;
     }
   }
@@ -513,7 +510,7 @@
     if (!SSL_CTX_load_verify_locations(
             ctx.get(), args_map["-root-certs"].c_str(), nullptr)) {
       fprintf(stderr, "Failed to load root certificates.\n");
-      ERR_print_errors_cb(PrintErrorCallback, stderr);
+      ERR_print_errors_fp(stderr);
       return false;
     }
     SSL_CTX_set_verify(ctx.get(), SSL_VERIFY_PEER, nullptr);
diff --git a/tool/server.cc b/tool/server.cc
index 23a47e9..7a4e53b 100644
--- a/tool/server.cc
+++ b/tool/server.cc
@@ -185,8 +185,7 @@
         SSL_read(ssl, request + request_len, sizeof(request) - request_len);
     if (ssl_ret <= 0) {
       int ssl_err = SSL_get_error(ssl, ssl_ret);
-      fprintf(stderr, "Error while reading: %d\n", ssl_err);
-      ERR_print_errors_cb(PrintErrorCallback, stderr);
+      PrintSSLError(stderr, "Error while reading", ssl_err, ssl_ret);
       return false;
     }
     request_len += static_cast<size_t>(ssl_ret);
@@ -342,8 +341,7 @@
     int ret = SSL_accept(ssl.get());
     if (ret != 1) {
       int ssl_err = SSL_get_error(ssl.get(), ret);
-      fprintf(stderr, "Error while connecting: %d\n", ssl_err);
-      ERR_print_errors_cb(PrintErrorCallback, stderr);
+      PrintSSLError(stderr, "Error while connecting", ssl_err, ret);
       result = false;
       continue;
     }
diff --git a/tool/transport_common.cc b/tool/transport_common.cc
index 55f2059..dcb8e0d 100644
--- a/tool/transport_common.cc
+++ b/tool/transport_common.cc
@@ -91,6 +91,33 @@
   }
 }
 
+static std::string GetLastSocketErrorString() {
+#if defined(OPENSSL_WINDOWS)
+  int error = WSAGetLastError();
+  char *buffer;
+  DWORD len = FormatMessageA(
+      FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_ALLOCATE_BUFFER, 0, error, 0,
+      reinterpret_cast<char *>(&buffer), 0, nullptr);
+  if (len == 0) {
+    char buf[256];
+    snprintf(buf, sizeof(buf), "unknown error (0x%x)", error);
+    return buf;
+  }
+  std::string ret(buffer, len);
+  LocalFree(buffer);
+  return ret;
+#else
+  return strerror(errno);
+#endif
+}
+
+static void PrintSocketError(const char *function) {
+  // On Windows, |perror| and |errno| are part of the C runtime, while sockets
+  // are separate, so we must print errors manually.
+  std::string error = GetLastSocketErrorString();
+  fprintf(stderr, "%s: %s\n", function, error.c_str());
+}
+
 // Connect sets |*out_sock| to be a socket connected to the destination given
 // in |hostname_and_port|, which should be of the form "www.example.com:123".
 // It returns true on success and false otherwise.
@@ -121,7 +148,7 @@
   *out_sock =
       socket(result->ai_family, result->ai_socktype, result->ai_protocol);
   if (*out_sock < 0) {
-    perror("socket");
+    PrintSocketError("socket");
     goto out;
   }
 
@@ -145,7 +172,7 @@
   }
 
   if (connect(*out_sock, result->ai_addr, result->ai_addrlen) != 0) {
-    perror("connect");
+    PrintSocketError("connect");
     goto out;
   }
   ok = true;
@@ -188,18 +215,18 @@
 
   server_sock_ = socket(addr.sin6_family, SOCK_STREAM, 0);
   if (server_sock_ < 0) {
-    perror("socket");
+    PrintSocketError("socket");
     return false;
   }
 
   if (setsockopt(server_sock_, SOL_SOCKET, SO_REUSEADDR, (const char *)&enable,
                  sizeof(enable)) < 0) {
-    perror("setsockopt");
+    PrintSocketError("setsockopt");
     return false;
   }
 
   if (bind(server_sock_, (struct sockaddr *)&addr, sizeof(addr)) != 0) {
-    perror("connect");
+    PrintSocketError("connect");
     return false;
   }
 
@@ -350,7 +377,7 @@
 #else
   WSAEVENT socket_handle = WSACreateEvent();
   if (socket_handle == WSA_INVALID_EVENT ||
-      WSAEventSelect(sock, socket_handle, FD_READ) != 0) {
+      WSAEventSelect(sock, socket_handle, FD_READ | FD_CLOSE) != 0) {
     WSACloseEvent(socket_handle);
     return false;
   }
@@ -379,11 +406,26 @@
 #endif
 }
 
-// PrintErrorCallback is a callback function from OpenSSL's
-// |ERR_print_errors_cb| that writes errors to a given |FILE*|.
-int PrintErrorCallback(const char *str, size_t len, void *ctx) {
-  fwrite(str, len, 1, reinterpret_cast<FILE*>(ctx));
-  return 1;
+void PrintSSLError(FILE *file, const char *msg, int ssl_err, int ret) {
+  switch (ssl_err) {
+    case SSL_ERROR_SSL:
+      fprintf(file, "%s: %s\n", msg, ERR_reason_error_string(ERR_peek_error()));
+      break;
+    case SSL_ERROR_SYSCALL:
+      if (ret == 0) {
+        fprintf(file, "%s: peer closed connection\n", msg);
+      } else {
+        std::string error = GetLastSocketErrorString();
+        fprintf(file, "%s: %s\n", msg, error.c_str());
+      }
+      break;
+    case SSL_ERROR_ZERO_RETURN:
+      fprintf(file, "%s: received close_notify\n", msg);
+      break;
+    default:
+      fprintf(file, "%s: unknown error type (%d)\n", msg, ssl_err);
+  }
+  ERR_print_errors_fp(file);
 }
 
 bool TransferData(SSL *ssl, int sock) {
@@ -427,19 +469,20 @@
       }
 #endif
       int ssl_ret = SSL_write(ssl, buffer, n);
-      if (!SocketSetNonBlocking(sock, true)) {
-        return false;
-      }
-
       if (ssl_ret <= 0) {
         int ssl_err = SSL_get_error(ssl, ssl_ret);
-        fprintf(stderr, "Error while writing: %d\n", ssl_err);
-        ERR_print_errors_cb(PrintErrorCallback, stderr);
+        PrintSSLError(stderr, "Error while writing", ssl_err, ssl_ret);
         return false;
       } else if (ssl_ret != n) {
         fprintf(stderr, "Short write from SSL_write.\n");
         return false;
       }
+
+      // Note we handle errors before restoring the non-blocking state. On
+      // Windows, |SocketSetNonBlocking| internally clears the last error.
+      if (!SocketSetNonBlocking(sock, true)) {
+        return false;
+      }
     }
 
     if (socket_ready) {
@@ -451,8 +494,7 @@
         if (ssl_err == SSL_ERROR_WANT_READ) {
           continue;
         }
-        fprintf(stderr, "Error while reading: %d\n", ssl_err);
-        ERR_print_errors_cb(PrintErrorCallback, stderr);
+        PrintSSLError(stderr, "Error while reading", ssl_err, ssl_ret);
         return false;
       } else if (ssl_ret == 0) {
         return true;
diff --git a/tool/transport_common.h b/tool/transport_common.h
index 492416a..7d45d1c 100644
--- a/tool/transport_common.h
+++ b/tool/transport_common.h
@@ -53,7 +53,10 @@
 
 bool SocketSetNonBlocking(int sock, bool is_non_blocking);
 
-int PrintErrorCallback(const char *str, size_t len, void *ctx);
+// PrintSSLError prints information about the most recent SSL error to stderr.
+// |ssl_err| must be the output of |SSL_get_error| and the |SSL| object must be
+// connected to socket from |Connect|.
+void PrintSSLError(FILE *file, const char *msg, int ssl_err, int ret);
 
 bool TransferData(SSL *ssl, int sock);