blob: 0f48703c17f43052cc000781939b3b0019783725 [file] [log] [blame]
Dave Tapuskab8a824d2014-12-10 19:09:52 -05001/* Copyright (c) 2014, Google Inc.
2 *
3 * Permission to use, copy, modify, and/or distribute this software for any
4 * purpose with or without fee is hereby granted, provided that the above
5 * copyright notice and this permission notice appear in all copies.
6 *
7 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14
15#include <openssl/base.h>
16
Dave Tapuskab8a824d2014-12-10 19:09:52 -050017#include <string>
18#include <vector>
19
20#include <errno.h>
21#include <stdlib.h>
22#include <sys/types.h>
Dave Tapuskab8a824d2014-12-10 19:09:52 -050023
24#if !defined(OPENSSL_WINDOWS)
25#include <arpa/inet.h>
26#include <fcntl.h>
27#include <netdb.h>
28#include <netinet/in.h>
29#include <sys/select.h>
Brian Smith33970e62015-01-27 22:32:08 -080030#include <sys/socket.h>
Dave Tapuskab8a824d2014-12-10 19:09:52 -050031#include <unistd.h>
32#else
Brian Smith33970e62015-01-27 22:32:08 -080033#define NOMINMAX
34#include <io.h>
Dave Tapuskab8a824d2014-12-10 19:09:52 -050035#include <WinSock2.h>
36#include <WS2tcpip.h>
Brian Smith33970e62015-01-27 22:32:08 -080037typedef int ssize_t;
38#define read _read
39#define write _write
40#pragma comment(lib, "Ws2_32.lib")
Dave Tapuskab8a824d2014-12-10 19:09:52 -050041#endif
42
43#include <openssl/err.h>
44#include <openssl/ssl.h>
45
46#include "internal.h"
47
48
#if !defined(OPENSSL_WINDOWS)
// closesocket gives non-Windows builds a function with the same name as the
// Winsock call so the rest of this file can close sockets uniformly. It
// forwards to |close| and returns its result.
static int closesocket(int sock) {
  const int close_result = close(sock);
  return close_result;
}
#endif
54
// InitSocketLibrary performs the platform's socket library start-up. On
// Windows it requests Winsock 2.2 through |WSAStartup| and reports a failure
// on stderr; elsewhere there is nothing to do. It returns true on success and
// false otherwise.
bool InitSocketLibrary() {
#if defined(OPENSSL_WINDOWS)
  WSADATA winsock_data;
  const int startup_error = WSAStartup(MAKEWORD(2, 2), &winsock_data);
  if (startup_error == 0) {
    return true;
  }
  fprintf(stderr, "WSAStartup failed with error %d\n", startup_error);
  return false;
#else
  return true;
#endif
}
66
Dave Tapuskab8a824d2014-12-10 19:09:52 -050067// Connect sets |*out_sock| to be a socket connected to the destination given
68// in |hostname_and_port|, which should be of the form "www.example.com:123".
69// It returns true on success and false otherwise.
70bool Connect(int *out_sock, const std::string &hostname_and_port) {
71 const size_t colon_offset = hostname_and_port.find_last_of(':');
72 std::string hostname, port;
73
74 if (colon_offset == std::string::npos) {
75 hostname = hostname_and_port;
76 port = "443";
77 } else {
78 hostname = hostname_and_port.substr(0, colon_offset);
79 port = hostname_and_port.substr(colon_offset + 1);
80 }
81
82 struct addrinfo hint, *result;
83 memset(&hint, 0, sizeof(hint));
84 hint.ai_family = AF_UNSPEC;
85 hint.ai_socktype = SOCK_STREAM;
86
87 int ret = getaddrinfo(hostname.c_str(), port.c_str(), &hint, &result);
88 if (ret != 0) {
89 fprintf(stderr, "getaddrinfo returned: %s\n", gai_strerror(ret));
90 return false;
91 }
92
93 bool ok = false;
94 char buf[256];
95
96 *out_sock =
97 socket(result->ai_family, result->ai_socktype, result->ai_protocol);
98 if (*out_sock < 0) {
99 perror("socket");
100 goto out;
101 }
102
103 switch (result->ai_family) {
104 case AF_INET: {
105 struct sockaddr_in *sin =
106 reinterpret_cast<struct sockaddr_in *>(result->ai_addr);
107 fprintf(stderr, "Connecting to %s:%d\n",
108 inet_ntop(result->ai_family, &sin->sin_addr, buf, sizeof(buf)),
109 ntohs(sin->sin_port));
110 break;
111 }
112 case AF_INET6: {
113 struct sockaddr_in6 *sin6 =
114 reinterpret_cast<struct sockaddr_in6 *>(result->ai_addr);
115 fprintf(stderr, "Connecting to [%s]:%d\n",
116 inet_ntop(result->ai_family, &sin6->sin6_addr, buf, sizeof(buf)),
117 ntohs(sin6->sin6_port));
118 break;
119 }
120 }
121
122 if (connect(*out_sock, result->ai_addr, result->ai_addrlen) != 0) {
123 perror("connect");
124 goto out;
125 }
126 ok = true;
127
128out:
129 freeaddrinfo(result);
130 return ok;
131}
132
133bool Accept(int *out_sock, const std::string &port) {
134 struct sockaddr_in addr, cli_addr;
135 socklen_t cli_addr_len = sizeof(cli_addr);
136 memset(&addr, 0, sizeof(addr));
137
138 addr.sin_family = AF_INET;
139 addr.sin_addr.s_addr = INADDR_ANY;
140 addr.sin_port = htons(atoi(port.c_str()));
141
142 bool ok = false;
143 int server_sock = -1;
144
145 server_sock =
146 socket(addr.sin_family, SOCK_STREAM, 0);
147 if (server_sock < 0) {
148 perror("socket");
149 goto out;
150 }
151
152 if (bind(server_sock, (struct sockaddr*)&addr, sizeof(addr)) != 0) {
153 perror("connect");
154 goto out;
155 }
156 listen(server_sock, 1);
157 *out_sock = accept(server_sock, (struct sockaddr*)&cli_addr, &cli_addr_len);
158
159 ok = true;
160
161out:
Brian Smith33970e62015-01-27 22:32:08 -0800162 closesocket(server_sock);
Dave Tapuskab8a824d2014-12-10 19:09:52 -0500163 return ok;
164}
165
166void PrintConnectionInfo(const SSL *ssl) {
167 const SSL_CIPHER *cipher = SSL_get_current_cipher(ssl);
168
169 fprintf(stderr, " Version: %s\n", SSL_get_version(ssl));
170 fprintf(stderr, " Cipher: %s\n", SSL_CIPHER_get_name(cipher));
171 fprintf(stderr, " Secure renegotiation: %s\n",
172 SSL_get_secure_renegotiation_support(ssl) ? "yes" : "no");
173}
174
// SocketSetNonBlocking puts |sock| into non-blocking mode when
// |is_non_blocking| is true and back into blocking mode otherwise. It returns
// true on success and false otherwise.
bool SocketSetNonBlocking(int sock, bool is_non_blocking) {
#if defined(OPENSSL_WINDOWS)
  u_long mode = is_non_blocking;
  const bool succeeded = ioctlsocket(sock, FIONBIO, &mode) == 0;
#else
  const int current_flags = fcntl(sock, F_GETFL, 0);
  if (current_flags < 0) {
    // The flags could not be queried; this path prints no diagnostic.
    return false;
  }
  const int wanted_flags = is_non_blocking ? (current_flags | O_NONBLOCK)
                                           : (current_flags & ~O_NONBLOCK);
  const bool succeeded = fcntl(sock, F_SETFL, wanted_flags) == 0;
#endif
  if (succeeded) {
    return true;
  }
  fprintf(stderr, "Failed to set socket non-blocking.\n");
  return false;
}
198
// PrintErrorCallback is a callback for OpenSSL's |ERR_print_errors_cb|. It
// writes the |len| bytes at |str| to the |FILE*| passed in |ctx| and always
// returns one; the result of the write is not checked.
int PrintErrorCallback(const char *str, size_t len, void *ctx) {
  FILE *const out = static_cast<FILE *>(ctx);
  fwrite(str, len, 1, out);
  return 1;
}
205
206bool TransferData(SSL *ssl, int sock) {
207 bool stdin_open = true;
208
209 fd_set read_fds;
210 FD_ZERO(&read_fds);
211
212 if (!SocketSetNonBlocking(sock, true)) {
213 return false;
214 }
215
216 for (;;) {
217 if (stdin_open) {
218 FD_SET(0, &read_fds);
219 }
220 FD_SET(sock, &read_fds);
221
222 int ret = select(sock + 1, &read_fds, NULL, NULL, NULL);
223 if (ret <= 0) {
224 perror("select");
225 return false;
226 }
227
228 if (FD_ISSET(0, &read_fds)) {
229 uint8_t buffer[512];
230 ssize_t n;
231
232 do {
233 n = read(0, buffer, sizeof(buffer));
234 } while (n == -1 && errno == EINTR);
235
236 if (n == 0) {
237 FD_CLR(0, &read_fds);
238 stdin_open = false;
Brian Smith33970e62015-01-27 22:32:08 -0800239#if !defined(OPENSSL_WINDOWS)
Dave Tapuskab8a824d2014-12-10 19:09:52 -0500240 shutdown(sock, SHUT_WR);
Brian Smith33970e62015-01-27 22:32:08 -0800241#else
242 shutdown(sock, SD_SEND);
243#endif
Dave Tapuskab8a824d2014-12-10 19:09:52 -0500244 continue;
245 } else if (n < 0) {
246 perror("read from stdin");
247 return false;
248 }
249
250 if (!SocketSetNonBlocking(sock, false)) {
251 return false;
252 }
253 int ssl_ret = SSL_write(ssl, buffer, n);
254 if (!SocketSetNonBlocking(sock, true)) {
255 return false;
256 }
257
258 if (ssl_ret <= 0) {
259 int ssl_err = SSL_get_error(ssl, ssl_ret);
260 fprintf(stderr, "Error while writing: %d\n", ssl_err);
261 ERR_print_errors_cb(PrintErrorCallback, stderr);
262 return false;
263 } else if (ssl_ret != n) {
264 fprintf(stderr, "Short write from SSL_write.\n");
265 return false;
266 }
267 }
268
269 if (FD_ISSET(sock, &read_fds)) {
270 uint8_t buffer[512];
271 int ssl_ret = SSL_read(ssl, buffer, sizeof(buffer));
272
273 if (ssl_ret < 0) {
274 int ssl_err = SSL_get_error(ssl, ssl_ret);
275 if (ssl_err == SSL_ERROR_WANT_READ) {
276 continue;
277 }
278 fprintf(stderr, "Error while reading: %d\n", ssl_err);
279 ERR_print_errors_cb(PrintErrorCallback, stderr);
280 return false;
281 } else if (ssl_ret == 0) {
282 return true;
283 }
284
285 ssize_t n;
286 do {
287 n = write(1, buffer, ssl_ret);
288 } while (n == -1 && errno == EINTR);
289
290 if (n != ssl_ret) {
291 fprintf(stderr, "Short write to stderr.\n");
292 return false;
293 }
294 }
295 }
296}