Turn SocketCloser in bssl_shim into a proper owning type

It's a bit more verbose to set up, but makes the error paths in
Connect() tidier. While I'm here, stick to Windows' actual SOCKET
type until we have to cross into BIO. It doesn't really matter
(Windows cannot use the upper half of that type without badly
breaking backwards compatibility), but it silences some 64/32
truncation warnings.

Change-Id: I7be7c2b543373a7a9fc50711131e5345d84ebb8b
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/60886
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index c14d7c0..6afbcc0 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -68,14 +68,14 @@
 
 
 #if !defined(OPENSSL_WINDOWS)
-static int closesocket(int sock) {
-  return close(sock);
-}
+using Socket = int;
+#define INVALID_SOCKET (-1)
 
-static void PrintSocketError(const char *func) {
-  perror(func);
-}
+static int closesocket(int sock) { return close(sock); }
+static void PrintSocketError(const char *func) { perror(func); }
 #else
+using Socket = SOCKET;
+
 static void PrintSocketError(const char *func) {
   int error = WSAGetLastError();
   char *buffer;
@@ -94,6 +94,57 @@
 }
 #endif
 
+class OwnedSocket {
+ public:
+  OwnedSocket() = default;
+  explicit OwnedSocket(Socket sock) : sock_(sock) {}
+  OwnedSocket(OwnedSocket &&other) { *this = std::move(other); }
+  ~OwnedSocket() { reset(); }
+  OwnedSocket &operator=(OwnedSocket &&other) {
+    drain_on_close_ = other.drain_on_close_;
+    reset(other.release());
+    return *this;
+  }
+
+  bool is_valid() const { return sock_ != INVALID_SOCKET; }
+  void set_drain_on_close(bool drain) { drain_on_close_ = drain; }
+
+  void reset(Socket sock = INVALID_SOCKET) {
+    if (is_valid()) {
+      if (drain_on_close_) {
+#if defined(OPENSSL_WINDOWS)
+        shutdown(sock_, SD_SEND);
+#else
+        shutdown(sock_, SHUT_WR);
+#endif
+        while (true) {
+          char buf[1024];
+          if (recv(sock_, buf, sizeof(buf), 0) <= 0) {
+            break;
+          }
+        }
+        closesocket(sock_);
+      }
+    }
+
+    drain_on_close_ = false;
+    sock_ = sock;
+  }
+
+  Socket get() const { return sock_; }
+
+  Socket release() {
+    Socket sock = sock_;
+    sock_ = INVALID_SOCKET;
+    drain_on_close_ = false;
+    return sock;
+  }
+
+ private:
+  Socket sock_ = INVALID_SOCKET;
+  bool drain_on_close_ = false;
+};
+
 static int Usage(const char *program) {
   fprintf(stderr, "Usage: %s [flags...]\n", program);
   return 1;
@@ -107,7 +158,7 @@
 };
 
 // Connect returns a new socket connected to the runner, or -1 on error.
-static int Connect(const TestConfig *config) {
+static OwnedSocket Connect(const TestConfig *config) {
   sockaddr_storage addr;
   socklen_t addr_len = 0;
   if (config->ipv6) {
@@ -117,7 +168,7 @@
     sin6.sin6_port = htons(config->port);
     if (!inet_pton(AF_INET6, "::1", &sin6.sin6_addr)) {
       PrintSocketError("inet_pton");
-      return -1;
+      return OwnedSocket();
     }
     addr_len = sizeof(sin6);
     memcpy(&addr, &sin6, addr_len);
@@ -128,60 +179,34 @@
     sin.sin_port = htons(config->port);
     if (!inet_pton(AF_INET, "127.0.0.1", &sin.sin_addr)) {
       PrintSocketError("inet_pton");
-      return -1;
+      return OwnedSocket();
     }
     addr_len = sizeof(sin);
     memcpy(&addr, &sin, addr_len);
   }
 
-  int sock = socket(addr.ss_family, SOCK_STREAM, 0);
-  if (sock == -1) {
+  OwnedSocket sock(socket(addr.ss_family, SOCK_STREAM, 0));
+  if (!sock.is_valid()) {
     PrintSocketError("socket");
-    return -1;
+    return OwnedSocket();
   }
   int nodelay = 1;
-  if (setsockopt(sock, IPPROTO_TCP, TCP_NODELAY,
+  if (setsockopt(sock.get(), IPPROTO_TCP, TCP_NODELAY,
                  reinterpret_cast<const char *>(&nodelay),
                  sizeof(nodelay)) != 0) {
     PrintSocketError("setsockopt");
-    closesocket(sock);
-    return -1;
+    return OwnedSocket();
   }
 
-  if (connect(sock, reinterpret_cast<const sockaddr *>(&addr), addr_len) != 0) {
+  if (connect(sock.get(), reinterpret_cast<const sockaddr *>(&addr),
+              addr_len) != 0) {
     PrintSocketError("connect");
-    closesocket(sock);
-    return -1;
+    return OwnedSocket();
   }
 
   return sock;
 }
 
-class SocketCloser {
- public:
-  explicit SocketCloser(int sock) : sock_(sock) {}
-  ~SocketCloser() {
-    // Half-close and drain the socket before releasing it. This seems to be
-    // necessary for graceful shutdown on Windows. It will also avoid write
-    // failures in the test runner.
-#if defined(OPENSSL_WINDOWS)
-    shutdown(sock_, SD_SEND);
-#else
-    shutdown(sock_, SHUT_WR);
-#endif
-    while (true) {
-      char buf[1024];
-      if (recv(sock_, buf, sizeof(buf), 0) <= 0) {
-        break;
-      }
-    }
-    closesocket(sock_);
-  }
-
- private:
-  const int sock_;
-};
-
 // DoRead reads from |ssl|, resolving any asynchronous operations. It returns
 // the result value of the final |SSL_read| call.
 static int DoRead(SSL *ssl, uint8_t *out, size_t max_out) {
@@ -787,13 +812,20 @@
 #endif
   }
 
-  int sock = Connect(config);
-  if (sock == -1) {
+  OwnedSocket sock = Connect(config);
+  if (!sock.is_valid()) {
     return false;
   }
-  SocketCloser closer(sock);
 
-  bssl::UniquePtr<BIO> bio(BIO_new_socket(sock, BIO_NOCLOSE));
+  // Half-close and drain the socket before releasing it. This seems to be
+  // necessary for graceful shutdown on Windows. It will also avoid write
+  // failures in the test runner.
+  sock.set_drain_on_close(true);
+
+  // Windows uses |SOCKET| for socket types, but OpenSSL's API requires casting
+  // them to |int|.
+  bssl::UniquePtr<BIO> bio(
+      BIO_new_socket(static_cast<int>(sock.get()), BIO_NOCLOSE));
   if (!bio) {
     return false;
   }