Fix bssl client/server's error-handling. Rather than printing the SSL_ERROR_* constants, print the actual error. This should be a bit more understandable. Debugging this also uncovered some other issues on Windows: - We were mixing up C runtime and Winsock errors, which are separate in Windows. - The thread local implementation interferes with WSAGetLastError due to a quirk of TlsGetValue. This could affect other Windows consumers. (Chromium uses a custom BIO, so it isn't affected.) - SocketSetNonBlocking also interferes with WSAGetLastError. - Listen for FD_CLOSE along with FD_READ. Connection close does not signal FD_READ. (The select loop only barely works on Windows anyway due to issues with stdin and line buffering, but if we take stdin out of the equation, FD_CLOSE can be tested.) Change-Id: If991259915acc96606a314fbe795fe6ea1e295e8 Reviewed-on: https://boringssl-review.googlesource.com/28125 Commit-Queue: Steven Valdez <svaldez@google.com> Reviewed-by: Steven Valdez <svaldez@google.com> CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/crypto/err/err_test.cc b/crypto/err/err_test.cc index 489d248..d975721 100644 --- a/crypto/err/err_test.cc +++ b/crypto/err/err_test.cc
@@ -23,6 +23,14 @@ #include "./internal.h" +#if defined(OPENSSL_WINDOWS) +OPENSSL_MSVC_PRAGMA(warning(push, 3)) +#include <windows.h> +OPENSSL_MSVC_PRAGMA(warning(pop)) +#else +#include <errno.h> +#endif + TEST(ErrTest, Overflow) { for (unsigned i = 0; i < ERR_NUM_ERRORS*2; i++) { @@ -212,3 +220,18 @@ EXPECT_EQ(0u, ERR_get_error()); } } + +// Querying the error queue should not affect the OS error. +#if defined(OPENSSL_WINDOWS) +TEST(ErrTest, PreservesLastError) { + SetLastError(ERROR_INVALID_FUNCTION); + ERR_get_error(); + EXPECT_EQ(ERROR_INVALID_FUNCTION, GetLastError()); +} +#else +TEST(ErrTest, PreservesErrno) { + errno = EINVAL; + ERR_get_error(); + EXPECT_EQ(EINVAL, errno); +} +#endif
diff --git a/crypto/thread_win.c b/crypto/thread_win.c index d6fa548..248870a 100644 --- a/crypto/thread_win.c +++ b/crypto/thread_win.c
@@ -190,13 +190,31 @@ #endif // _WIN64 +static void **get_thread_locals(void) { + // |TlsGetValue| clears the last error even on success, so that callers may + // distinguish it successfully returning NULL or failing. It is documented to + // never fail if the argument is a valid index from |TlsAlloc|, so we do not + // need to handle this. + // + // However, this error-mangling behavior interferes with the caller's use of + // |GetLastError|. In particular |SSL_get_error| queries the error queue to + // determine whether the caller should look at the OS's errors. To avoid + // destroying state, save and restore the Windows error. + // + // https://msdn.microsoft.com/en-us/library/windows/desktop/ms686812(v=vs.85).aspx + DWORD last_error = GetLastError(); + void **ret = TlsGetValue(g_thread_local_key); + SetLastError(last_error); + return ret; +} + void *CRYPTO_get_thread_local(thread_local_data_t index) { CRYPTO_once(&g_thread_local_init_once, thread_local_init); if (g_thread_local_failed) { return NULL; } - void **pointers = TlsGetValue(g_thread_local_key); + void **pointers = get_thread_locals(); if (pointers == NULL) { return NULL; } @@ -211,7 +229,7 @@ return 0; } - void **pointers = TlsGetValue(g_thread_local_key); + void **pointers = get_thread_locals(); if (pointers == NULL) { pointers = OPENSSL_malloc(sizeof(void *) * NUM_OPENSSL_THREAD_LOCALS); if (pointers == NULL) {
diff --git a/tool/client.cc b/tool/client.cc index bdb5de7..037e10c 100644 --- a/tool/client.cc +++ b/tool/client.cc
@@ -181,7 +181,7 @@ if (!PEM_write_bio_SSL_SESSION(session_out.get(), session) || BIO_flush(session_out.get()) <= 0) { fprintf(stderr, "Error while saving session:\n"); - ERR_print_errors_cb(PrintErrorCallback, stderr); + ERR_print_errors_fp(stderr); return 0; } } @@ -221,8 +221,7 @@ if (ssl_err == SSL_ERROR_WANT_READ) { continue; } - fprintf(stderr, "Error while reading: %d\n", ssl_err); - ERR_print_errors_cb(PrintErrorCallback, stderr); + PrintSSLError(stderr, "Error while reading", ssl_err, ssl_ret); return false; } } @@ -267,14 +266,14 @@ "rb")); if (!in) { fprintf(stderr, "Error reading session\n"); - ERR_print_errors_cb(PrintErrorCallback, stderr); + ERR_print_errors_fp(stderr); return false; } bssl::UniquePtr<SSL_SESSION> session(PEM_read_bio_SSL_SESSION(in.get(), nullptr, nullptr, nullptr)); if (!session) { fprintf(stderr, "Error reading session\n"); - ERR_print_errors_cb(PrintErrorCallback, stderr); + ERR_print_errors_fp(stderr); return false; } SSL_set_session(ssl.get(), session.get()); @@ -294,8 +293,7 @@ int ret = SSL_connect(ssl.get()); if (ret != 1) { int ssl_err = SSL_get_error(ssl.get(), ret); - fprintf(stderr, "Error while connecting: %d\n", ssl_err); - ERR_print_errors_cb(PrintErrorCallback, stderr); + PrintSSLError(stderr, "Error while connecting", ssl_err, ret); return false; } @@ -315,8 +313,7 @@ int ssl_ret = SSL_write(ssl.get(), early_data.data(), ed_size); if (ssl_ret <= 0) { int ssl_err = SSL_get_error(ssl.get(), ssl_ret); - fprintf(stderr, "Error while writing: %d\n", ssl_err); - ERR_print_errors_cb(PrintErrorCallback, stderr); + PrintSSLError(stderr, "Error while writing", ssl_err, ssl_ret); return false; } else if (ssl_ret != ed_size) { fprintf(stderr, "Short write from SSL_write.\n"); @@ -500,7 +497,7 @@ if (!session_out) { fprintf(stderr, "Error while opening %s:\n", args_map["-session-out"].c_str()); - ERR_print_errors_cb(PrintErrorCallback, stderr); + ERR_print_errors_fp(stderr); return false; } } @@ -513,7 +510,7 @@ if (!SSL_CTX_load_verify_locations( ctx.get(), args_map["-root-certs"].c_str(), nullptr)) { fprintf(stderr, "Failed to load root certificates.\n"); - ERR_print_errors_cb(PrintErrorCallback, stderr); + ERR_print_errors_fp(stderr); return false; } SSL_CTX_set_verify(ctx.get(), SSL_VERIFY_PEER, nullptr);
diff --git a/tool/server.cc b/tool/server.cc index 23a47e9..7a4e53b 100644 --- a/tool/server.cc +++ b/tool/server.cc
@@ -185,8 +185,7 @@ SSL_read(ssl, request + request_len, sizeof(request) - request_len); if (ssl_ret <= 0) { int ssl_err = SSL_get_error(ssl, ssl_ret); - fprintf(stderr, "Error while reading: %d\n", ssl_err); - ERR_print_errors_cb(PrintErrorCallback, stderr); + PrintSSLError(stderr, "Error while reading", ssl_err, ssl_ret); return false; } request_len += static_cast<size_t>(ssl_ret); @@ -342,8 +341,7 @@ int ret = SSL_accept(ssl.get()); if (ret != 1) { int ssl_err = SSL_get_error(ssl.get(), ret); - fprintf(stderr, "Error while connecting: %d\n", ssl_err); - ERR_print_errors_cb(PrintErrorCallback, stderr); + PrintSSLError(stderr, "Error while connecting", ssl_err, ret); result = false; continue; }
diff --git a/tool/transport_common.cc b/tool/transport_common.cc index 55f2059..dcb8e0d 100644 --- a/tool/transport_common.cc +++ b/tool/transport_common.cc
@@ -91,6 +91,33 @@ } } +static std::string GetLastSocketErrorString() { +#if defined(OPENSSL_WINDOWS) + int error = WSAGetLastError(); + char *buffer; + DWORD len = FormatMessageA( + FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_ALLOCATE_BUFFER, 0, error, 0, + reinterpret_cast<char *>(&buffer), 0, nullptr); + if (len == 0) { + char buf[256]; + snprintf(buf, sizeof(buf), "unknown error (0x%x)", error); + return buf; + } + std::string ret(buffer, len); + LocalFree(buffer); + return ret; +#else + return strerror(errno); +#endif +} + +static void PrintSocketError(const char *function) { + // On Windows, |perror| and |errno| are part of the C runtime, while sockets + // are separate, so we must print errors manually. + std::string error = GetLastSocketErrorString(); + fprintf(stderr, "%s: %s\n", function, error.c_str()); +} + // Connect sets |*out_sock| to be a socket connected to the destination given // in |hostname_and_port|, which should be of the form "www.example.com:123". // It returns true on success and false otherwise. @@ -121,7 +148,7 @@ *out_sock = socket(result->ai_family, result->ai_socktype, result->ai_protocol); if (*out_sock < 0) { - perror("socket"); + PrintSocketError("socket"); goto out; } @@ -145,7 +172,7 @@ } if (connect(*out_sock, result->ai_addr, result->ai_addrlen) != 0) { - perror("connect"); + PrintSocketError("connect"); goto out; } ok = true; @@ -188,18 +215,18 @@ server_sock_ = socket(addr.sin6_family, SOCK_STREAM, 0); if (server_sock_ < 0) { - perror("socket"); + PrintSocketError("socket"); return false; } if (setsockopt(server_sock_, SOL_SOCKET, SO_REUSEADDR, (const char *)&enable, sizeof(enable)) < 0) { - perror("setsockopt"); + PrintSocketError("setsockopt"); return false; } if (bind(server_sock_, (struct sockaddr *)&addr, sizeof(addr)) != 0) { - perror("connect"); + PrintSocketError("connect"); return false; } @@ -350,7 +377,7 @@ #else WSAEVENT socket_handle = WSACreateEvent(); if (socket_handle == WSA_INVALID_EVENT || - WSAEventSelect(sock, socket_handle, FD_READ) != 0) { + WSAEventSelect(sock, socket_handle, FD_READ | FD_CLOSE) != 0) { WSACloseEvent(socket_handle); return false; } @@ -379,11 +406,26 @@ #endif } -// PrintErrorCallback is a callback function from OpenSSL's -// |ERR_print_errors_cb| that writes errors to a given |FILE*|. -int PrintErrorCallback(const char *str, size_t len, void *ctx) { - fwrite(str, len, 1, reinterpret_cast<FILE*>(ctx)); - return 1; +void PrintSSLError(FILE *file, const char *msg, int ssl_err, int ret) { + switch (ssl_err) { + case SSL_ERROR_SSL: + fprintf(file, "%s: %s\n", msg, ERR_reason_error_string(ERR_peek_error())); + break; + case SSL_ERROR_SYSCALL: + if (ret == 0) { + fprintf(file, "%s: peer closed connection\n", msg); + } else { + std::string error = GetLastSocketErrorString(); + fprintf(file, "%s: %s\n", msg, error.c_str()); + } + break; + case SSL_ERROR_ZERO_RETURN: + fprintf(file, "%s: received close_notify\n", msg); + break; + default: + fprintf(file, "%s: unknown error type (%d)\n", msg, ssl_err); + } + ERR_print_errors_fp(file); } bool TransferData(SSL *ssl, int sock) { @@ -427,19 +469,20 @@ } #endif int ssl_ret = SSL_write(ssl, buffer, n); - if (!SocketSetNonBlocking(sock, true)) { - return false; - } - if (ssl_ret <= 0) { int ssl_err = SSL_get_error(ssl, ssl_ret); - fprintf(stderr, "Error while writing: %d\n", ssl_err); - ERR_print_errors_cb(PrintErrorCallback, stderr); + PrintSSLError(stderr, "Error while writing", ssl_err, ssl_ret); return false; } else if (ssl_ret != n) { fprintf(stderr, "Short write from SSL_write.\n"); return false; } + + // Note we handle errors before restoring the non-blocking state. On + // Windows, |SocketSetNonBlocking| internally clears the last error. + if (!SocketSetNonBlocking(sock, true)) { + return false; + } } if (socket_ready) { @@ -451,8 +494,7 @@ if (ssl_err == SSL_ERROR_WANT_READ) { continue; } - fprintf(stderr, "Error while reading: %d\n", ssl_err); - ERR_print_errors_cb(PrintErrorCallback, stderr); + PrintSSLError(stderr, "Error while reading", ssl_err, ssl_ret); return false; } else if (ssl_ret == 0) { return true;
diff --git a/tool/transport_common.h b/tool/transport_common.h index 492416a..7d45d1c 100644 --- a/tool/transport_common.h +++ b/tool/transport_common.h
@@ -53,7 +53,10 @@ bool SocketSetNonBlocking(int sock, bool is_non_blocking); -int PrintErrorCallback(const char *str, size_t len, void *ctx); +// PrintSSLError prints information about the most recent SSL error to stderr. +// |ssl_err| must be the output of |SSL_get_error| and the |SSL| object must be +// connected to socket from |Connect|. +void PrintSSLError(FILE *file, const char *msg, int ssl_err, int ret); bool TransferData(SSL *ssl, int sock);