Test non-blocking socket BIOs

As part of this, factor out some of the socket bits. I tried to write
the sockaddr mess in a way that's strict-aliasing-clean, at least as far
as code we own goes. But the API is really not designed for it, and who
knows what effective type the underlying libc functions expect.
(Fortunately it's mostly syscalls, which definitely escape the
abstract machine.)

Change-Id: I12621f6c40f074ff7423dd46ddceca120ba63db9
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/61728
Reviewed-by: Bob Beck <bbe@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/crypto/bio/bio_test.cc b/crypto/bio/bio_test.cc
index 1350392..a169b56 100644
--- a/crypto/bio/bio_test.cc
+++ b/crypto/bio/bio_test.cc
@@ -14,6 +14,7 @@
 
 #include <algorithm>
 #include <string>
+#include <utility>
 
 #include <gtest/gtest.h>
 
@@ -30,6 +31,7 @@
 #include <errno.h>
 #include <fcntl.h>
 #include <netinet/in.h>
+#include <poll.h>
 #include <string.h>
 #include <sys/socket.h>
 #include <unistd.h>
@@ -41,11 +43,13 @@
 OPENSSL_MSVC_PRAGMA(warning(pop))
 #endif
 
-
 #if !defined(OPENSSL_WINDOWS)
+using Socket = int;
+#define INVALID_SOCKET (-1)
 static int closesocket(int sock) { return close(sock); }
 static std::string LastSocketError() { return strerror(errno); }
 #else
+using Socket = SOCKET;
 static std::string LastSocketError() {
   char buf[DECIMAL_SIZE(int) + 1];
   snprintf(buf, sizeof(buf), "%d", WSAGetLastError());
@@ -53,78 +57,277 @@
 }
 #endif
 
-class ScopedSocket {
+class OwnedSocket {
  public:
-  explicit ScopedSocket(int sock) : sock_(sock) {}
-  ~ScopedSocket() {
-    closesocket(sock_);
+  OwnedSocket() = default;
+  explicit OwnedSocket(Socket sock) : sock_(sock) {}
+  OwnedSocket(OwnedSocket &&other) { *this = std::move(other); }
+  ~OwnedSocket() { reset(); }
+  OwnedSocket &operator=(OwnedSocket &&other) {
+    reset(other.release());
+    return *this;
+  }
+
+  bool is_valid() const { return sock_ != INVALID_SOCKET; }
+  Socket get() const { return sock_; }
+  Socket release() { return std::exchange(sock_, INVALID_SOCKET); }
+
+  void reset(Socket sock = INVALID_SOCKET) {
+    if (is_valid()) {
+      closesocket(sock_);
+    }
+
+    sock_ = sock;
   }
 
  private:
-  const int sock_;
+  Socket sock_ = INVALID_SOCKET;
 };
 
+struct SockaddrStorage {
+  int family() const { return storage.ss_family; }
+
+  sockaddr *addr_mut() { return reinterpret_cast<sockaddr *>(&storage); }
+  const sockaddr *addr() const {
+    return reinterpret_cast<const sockaddr *>(&storage);
+  }
+
+  sockaddr_in ToIPv4() const {
+    if (family() != AF_INET || len != sizeof(sockaddr_in)) {
+      abort();
+    }
+    // These APIs were seemingly designed before C's strict aliasing rule, and
+    // C++'s strict union handling. Make a copy so the compiler does not read
+    // this as an aliasing violation.
+    sockaddr_in ret;
+    OPENSSL_memcpy(&ret, &storage, sizeof(ret));
+    return ret;
+  }
+
+  sockaddr_in6 ToIPv6() const {
+    if (family() != AF_INET6 || len != sizeof(sockaddr_in6)) {
+      abort();
+    }
+    // These APIs were seemingly designed before C's strict aliasing rule, and
+    // C++'s strict union handling. Make a copy so the compiler does not read
+    // this as an aliasing violation.
+    sockaddr_in6 ret;
+    OPENSSL_memcpy(&ret, &storage, sizeof(ret));
+    return ret;
+  }
+
+  sockaddr_storage storage = {};
+  socklen_t len = sizeof(storage);
+};
+
+static OwnedSocket Bind(int family, const sockaddr *addr, socklen_t addr_len) {
+  OwnedSocket sock(socket(family, SOCK_STREAM, 0));
+  if (!sock.is_valid()) {
+    return OwnedSocket();
+  }
+
+  if (bind(sock.get(), addr, addr_len) != 0) {
+    return OwnedSocket();
+  }
+
+  return sock;
+}
+
+static OwnedSocket ListenLoopback(int backlog) {
+  // Try binding to IPv6.
+  sockaddr_in6 sin6;
+  OPENSSL_memset(&sin6, 0, sizeof(sin6));
+  sin6.sin6_family = AF_INET6;
+  if (inet_pton(AF_INET6, "::1", &sin6.sin6_addr) != 1) {
+    return OwnedSocket();
+  }
+  OwnedSocket sock =
+      Bind(AF_INET6, reinterpret_cast<const sockaddr *>(&sin6), sizeof(sin6));
+  if (!sock.is_valid()) {
+    // Try binding to IPv4.
+    sockaddr_in sin;
+    OPENSSL_memset(&sin, 0, sizeof(sin));
+    sin.sin_family = AF_INET;
+    if (inet_pton(AF_INET, "127.0.0.1", &sin.sin_addr) != 1) {
+      return OwnedSocket();
+    }
+    sock = Bind(AF_INET, reinterpret_cast<const sockaddr *>(&sin), sizeof(sin));
+  }
+  if (!sock.is_valid()) {
+    return OwnedSocket();
+  }
+
+  if (listen(sock.get(), backlog) != 0) {
+    return OwnedSocket();
+  }
+
+  return sock;
+}
+
+static bool SocketSetNonBlocking(Socket sock) {
+#if defined(OPENSSL_WINDOWS)
+  u_long arg = 1;
+  return ioctlsocket(sock, FIONBIO, &arg) == 0;
+#else
+  int flags = fcntl(sock, F_GETFL, 0);
+  if (flags < 0) {
+    return false;
+  }
+  flags |= O_NONBLOCK;
+  return fcntl(sock, F_SETFL, flags) == 0;
+#endif
+}
+
+enum class WaitType { kRead, kWrite };
+
+static bool WaitForSocket(Socket sock, WaitType wait_type) {
+  // Use an arbitrary 5 second timeout, so the test doesn't hang indefinitely if
+  // there's an issue.
+  static const int kTimeoutSeconds = 5;
+#if defined(OPENSSL_WINDOWS)
+  fd_set read_set, write_set;
+  FD_ZERO(&read_set);
+  FD_ZERO(&write_set);
+  fd_set *wait_set = wait_type == WaitType::kRead ? &read_set : &write_set;
+  FD_SET(sock, wait_set);
+  timeval timeout;
+  timeout.tv_sec = kTimeoutSeconds;
+  timeout.tv_usec = 0;
+  if (select(0 /* unused on Windows */, &read_set, &write_set, nullptr,
+             &timeout) <= 0) {
+    return false;
+  }
+  return FD_ISSET(sock, wait_set);
+#else
+  short events = wait_type == WaitType::kRead ? POLLIN : POLLOUT;
+  pollfd fd = {/*fd=*/sock, events, /*revents=*/0};
+  return poll(&fd, 1, kTimeoutSeconds * 1000) == 1 && (fd.revents & events);
+#endif
+}
+
 TEST(BIOTest, SocketConnect) {
   static const char kTestMessage[] = "test";
-  int listening_sock = -1;
-  socklen_t len = 0;
-  sockaddr_storage ss;
-  struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *) &ss;
-  struct sockaddr_in *sin = (struct sockaddr_in *) &ss;
-  OPENSSL_memset(&ss, 0, sizeof(ss));
+  OwnedSocket listening_sock = ListenLoopback(/*backlog=*/1);
+  ASSERT_TRUE(listening_sock.is_valid()) << LastSocketError();
 
-  ss.ss_family = AF_INET6;
-  listening_sock = socket(AF_INET6, SOCK_STREAM, 0);
-  ASSERT_NE(-1, listening_sock) << LastSocketError();
-  len = sizeof(*sin6);
-  ASSERT_EQ(1, inet_pton(AF_INET6, "::1", &sin6->sin6_addr))
+  SockaddrStorage addr;
+  ASSERT_EQ(getsockname(listening_sock.get(), addr.addr_mut(), &addr.len), 0)
       << LastSocketError();
-  if (bind(listening_sock, (struct sockaddr *)sin6, sizeof(*sin6)) == -1) {
-    closesocket(listening_sock);
-
-    ss.ss_family = AF_INET;
-    listening_sock = socket(AF_INET, SOCK_STREAM, 0);
-    ASSERT_NE(-1, listening_sock) << LastSocketError();
-    len = sizeof(*sin);
-    ASSERT_EQ(1, inet_pton(AF_INET, "127.0.0.1", &sin->sin_addr))
-        << LastSocketError();
-    ASSERT_EQ(0, bind(listening_sock, (struct sockaddr *)sin, sizeof(*sin)))
-        << LastSocketError();
-  }
-
-  ScopedSocket listening_sock_closer(listening_sock);
-  ASSERT_EQ(0, listen(listening_sock, 1)) << LastSocketError();
-  ASSERT_EQ(0, getsockname(listening_sock, (struct sockaddr *)&ss, &len))
-        << LastSocketError();
 
   char hostname[80];
-  if (ss.ss_family == AF_INET6) {
-    snprintf(hostname, sizeof(hostname), "[::1]:%d", ntohs(sin6->sin6_port));
-  } else if (ss.ss_family == AF_INET) {
-    snprintf(hostname, sizeof(hostname), "127.0.0.1:%d", ntohs(sin->sin_port));
+  if (addr.family() == AF_INET6) {
+    snprintf(hostname, sizeof(hostname), "[::1]:%d",
+             ntohs(addr.ToIPv6().sin6_port));
+  } else {
+    snprintf(hostname, sizeof(hostname), "127.0.0.1:%d",
+             ntohs(addr.ToIPv4().sin_port));
   }
 
   // Connect to it with a connect BIO.
   bssl::UniquePtr<BIO> bio(BIO_new_connect(hostname));
   ASSERT_TRUE(bio);
 
-  // Write a test message to the BIO.
+  // Write a test message to the BIO. This is assumed to be smaller than the
+  // transport buffer.
   ASSERT_EQ(static_cast<int>(sizeof(kTestMessage)),
-            BIO_write(bio.get(), kTestMessage, sizeof(kTestMessage)));
+            BIO_write(bio.get(), kTestMessage, sizeof(kTestMessage)))
+      << LastSocketError();
 
   // Accept the socket.
-  int sock = accept(listening_sock, (struct sockaddr *) &ss, &len);
-  ASSERT_NE(-1, sock) << LastSocketError();
-  ScopedSocket sock_closer(sock);
+  OwnedSocket sock(accept(listening_sock.get(), addr.addr_mut(), &addr.len));
+  ASSERT_TRUE(sock.is_valid()) << LastSocketError();
 
   // Check the same message is read back out.
   char buf[sizeof(kTestMessage)];
   ASSERT_EQ(static_cast<int>(sizeof(kTestMessage)),
-            recv(sock, buf, sizeof(buf), 0))
+            recv(sock.get(), buf, sizeof(buf), 0))
       << LastSocketError();
   EXPECT_EQ(Bytes(kTestMessage, sizeof(kTestMessage)), Bytes(buf, sizeof(buf)));
 }
 
+TEST(BIOTest, SocketNonBlocking) {
+  OwnedSocket listening_sock = ListenLoopback(/*backlog=*/1);
+  ASSERT_TRUE(listening_sock.is_valid()) << LastSocketError();
+
+  // Connect to |listening_sock|.
+  SockaddrStorage addr;
+  ASSERT_EQ(getsockname(listening_sock.get(), addr.addr_mut(), &addr.len), 0)
+      << LastSocketError();
+  OwnedSocket connect_sock(socket(addr.family(), SOCK_STREAM, 0));
+  ASSERT_TRUE(connect_sock.is_valid()) << LastSocketError();
+  ASSERT_EQ(connect(connect_sock.get(), addr.addr(), addr.len), 0)
+      << LastSocketError();
+  ASSERT_TRUE(SocketSetNonBlocking(connect_sock.get())) << LastSocketError();
+  bssl::UniquePtr<BIO> connect_bio(
+      BIO_new_socket(connect_sock.get(), BIO_NOCLOSE));
+  ASSERT_TRUE(connect_bio);
+
+  // Make a corresponding accepting socket.
+  OwnedSocket accept_sock(
+      accept(listening_sock.get(), addr.addr_mut(), &addr.len));
+  ASSERT_TRUE(accept_sock.is_valid()) << LastSocketError();
+  ASSERT_TRUE(SocketSetNonBlocking(accept_sock.get())) << LastSocketError();
+  bssl::UniquePtr<BIO> accept_bio(
+      BIO_new_socket(accept_sock.get(), BIO_NOCLOSE));
+  ASSERT_TRUE(accept_bio);
+
+  // Exchange data through the socket.
+  static const char kTestMessage[] = "hello, world";
+
+  // Reading from |accept_bio| should not block.
+  char buf[sizeof(kTestMessage)];
+  int ret = BIO_read(accept_bio.get(), buf, sizeof(buf));
+  EXPECT_EQ(ret, -1);
+  EXPECT_TRUE(BIO_should_read(accept_bio.get())) << LastSocketError();
+
+  // Writing to |connect_bio| should eventually overflow the transport buffers
+  // and also give a retryable error.
+  int bytes_written = 0;
+  for (;;) {
+    ret = BIO_write(connect_bio.get(), kTestMessage, sizeof(kTestMessage));
+    if (ret <= 0) {
+      EXPECT_EQ(ret, -1);
+      EXPECT_TRUE(BIO_should_write(connect_bio.get())) << LastSocketError();
+      break;
+    }
+    bytes_written += ret;
+  }
+  EXPECT_GT(bytes_written, 0);
+
+  // |accept_bio| should readable. Drain it. Note data is not always available
+  // from loopback immediately, notably on macOS, so wait for the socket first.
+  int bytes_read = 0;
+  while (bytes_read < bytes_written) {
+    ASSERT_TRUE(WaitForSocket(accept_sock.get(), WaitType::kRead))
+        << LastSocketError();
+    ret = BIO_read(accept_bio.get(), buf, sizeof(buf));
+    ASSERT_GT(ret, 0);
+    bytes_read += ret;
+  }
+
+  // |connect_bio| should become writeable again.
+  ASSERT_TRUE(WaitForSocket(accept_sock.get(), WaitType::kWrite))
+      << LastSocketError();
+  ret = BIO_write(connect_bio.get(), kTestMessage, sizeof(kTestMessage));
+  EXPECT_EQ(static_cast<int>(sizeof(kTestMessage)), ret);
+
+  ASSERT_TRUE(WaitForSocket(accept_sock.get(), WaitType::kRead))
+      << LastSocketError();
+  ret = BIO_read(accept_bio.get(), buf, sizeof(buf));
+  EXPECT_EQ(static_cast<int>(sizeof(kTestMessage)), ret);
+  EXPECT_EQ(Bytes(buf), Bytes(kTestMessage));
+
+  // Close one socket. We should get an EOF out the other.
+  connect_bio.reset();
+  connect_sock.reset();
+
+  ASSERT_TRUE(WaitForSocket(accept_sock.get(), WaitType::kRead))
+      << LastSocketError();
+  ret = BIO_read(accept_bio.get(), buf, sizeof(buf));
+  EXPECT_EQ(ret, 0) << LastSocketError();
+  EXPECT_FALSE(BIO_should_read(accept_bio.get()));
+}
+
 TEST(BIOTest, Printf) {
   // Test a short output, a very long one, and various sizes around
   // 256 (the size of the buffer) to ensure edge cases are correct.