From 22840dbfaf9178d93c946d62af8a761d400ea2b5 Mon Sep 17 00:00:00 2001 From: Tatsuhiro Tsujikawa Date: Sun, 14 Oct 2012 23:39:41 +0900 Subject: [PATCH] spdycat: Handle timeout in connect and SSL/TLS handshake --- src/spdycat.cc | 51 +++++++++++------- src/spdylay_ssl.cc | 128 +++++++++++++++++++++++++++++++++++++++++++++ src/spdylay_ssl.h | 8 +++ 3 files changed, 168 insertions(+), 19 deletions(-) diff --git a/src/spdycat.cc b/src/spdycat.cc index 0e60ae24..4020b4c0 100644 --- a/src/spdycat.cc +++ b/src/spdycat.cc @@ -68,6 +68,7 @@ struct Config { bool get_assets; bool stat; int spdy_version; + // milliseconds int timeout; std::string certfile; std::string keyfile; @@ -396,13 +397,6 @@ void on_stream_close_callback } } -int64_t time_delta(const timeval& a, const timeval& b) -{ - int64_t res = (a.tv_sec - b.tv_sec) * 1000; - res += (a.tv_usec - b.tv_usec) / 1000; - return res; -} - void print_stats(const SpdySession& spdySession) { std::cout << "***** Statistics *****" << std::endl; @@ -440,11 +434,19 @@ int communicate(const std::string& host, uint16_t port, SpdySession& spdySession, const spdylay_session_callbacks *callbacks) { - int fd = connect_to(host, port); + int rv; + int timeout = config.timeout; + int fd = nonblock_connect_to(host, port, timeout); if(fd == -1) { std::cerr << "Could not connect to the host" << std::endl; return -1; + } else if(fd == -2) { + std::cerr << "Request to " << spdySession.hostport << " timed out " + << "during establishing connection." + << std::endl; + return -1; } + set_tcp_nodelay(fd); SSL_CTX *ssl_ctx; ssl_ctx = SSL_CTX_new(TLSv1_client_method()); if(!ssl_ctx) { @@ -484,12 +486,17 @@ int communicate(const std::string& host, uint16_t port, std::cerr << ERR_error_string(ERR_get_error(), 0) << std::endl; return -1; } - if(ssl_handshake(ssl, fd) == -1) { + rv = ssl_nonblock_handshake(ssl, fd, timeout); + if(rv == -1) { + return -1; + } else if(rv == -2) { + std::cerr << "Request to " << spdySession.hostport + << " timed out in SSL/TLS handshake." + << std::endl; return -1; } + spdySession.record_handshake_time(); - make_non_block(fd); - set_tcp_nodelay(fd); int spdy_version = spdylay_npn_get_version( reinterpret_cast(next_proto.c_str()), next_proto.size()); @@ -516,11 +523,13 @@ int communicate(const std::string& host, uint16_t port, } pollfds[0].fd = fd; ctl_poll(pollfds, &sc); - int end_time = time(NULL) + config.timeout; - int timeout = config.timeout; bool ok = true; + timeval tv1, tv2; while(!sc.finish()) { + if(config.timeout != -1) { + gettimeofday(&tv1, 0); + } int nfds = poll(pollfds, npollfds, timeout); if(nfds == -1) { perror("poll"); @@ -543,11 +552,15 @@ int communicate(const std::string& host, uint16_t port, ok = false; break; } - timeout = timeout == -1 ? timeout : end_time - time(NULL); - if (config.timeout != -1 && timeout <= 0) { - std::cout << "Requests to " << spdySession.hostport << "timed out."; - ok = false; - break; + if(config.timeout != -1) { + gettimeofday(&tv2, 0); + timeout -= time_delta(tv2, tv1); + if (timeout <= 0) { + std::cerr << "Requests to " << spdySession.hostport << " timed out." + << std::endl; + ok = false; + break; + } } assert(ok); ctl_poll(pollfds, &sc); @@ -700,7 +713,7 @@ int main(int argc, char **argv) config.spdy_version = SPDYLAY_PROTO_SPDY3; break; case 't': - config.timeout = atoi(optarg); + config.timeout = atoi(optarg) * 1000; break; case 'w': { errno = 0; diff --git a/src/spdylay_ssl.cc b/src/spdylay_ssl.cc index 8cbb4da3..d3131bb7 100644 --- a/src/spdylay_ssl.cc +++ b/src/spdylay_ssl.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -234,6 +235,74 @@ int connect_to(const std::string& host, uint16_t port) return fd; } +int nonblock_connect_to(const std::string& host, uint16_t port, int timeout) +{ + struct addrinfo hints; + int fd = -1; + int r; + char service[10]; + snprintf(service, sizeof(service), "%u", port); + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + struct addrinfo *res; + r = getaddrinfo(host.c_str(), service, &hints, &res); + if(r != 0) { + std::cerr << "getaddrinfo: " << gai_strerror(r) << std::endl; + return -1; + } + for(struct addrinfo *rp = res; rp; rp = rp->ai_next) { + fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + if(fd == -1) { + continue; + } + if(make_non_block(fd) == -1) { + close(fd); + fd = -1; + continue; + } + while((r = connect(fd, rp->ai_addr, rp->ai_addrlen)) == -1 && + errno == EINTR); + if(r == 0) { + break; + } else if(errno == EINPROGRESS) { + struct timeval tv1, tv2; + struct pollfd pfd = {fd, POLLOUT, 0}; + if(timeout != -1) { + gettimeofday(&tv1, 0); + } + r = poll(&pfd, 1, timeout); + if(r == 0) { + return -2; + } else if(r == -1) { + return -1; + } else { + if(timeout != -1) { + gettimeofday(&tv2, 0); + timeout -= time_delta(tv2, tv1); + if(timeout <= 0) { + return -2; + } + } + int socket_error; + socklen_t optlen = sizeof(socket_error); + r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &socket_error, &optlen); + if(r == 0 && socket_error == 0) { + break; + } else { + close(fd); + fd = -1; + } + } + } else { + close(fd); + fd = -1; + } + } + freeaddrinfo(res); + return fd; +} + int make_listen_socket(const std::string& host, uint16_t port, int family) { addrinfo hints; @@ -664,6 +733,65 @@ int ssl_handshake(SSL *ssl, int fd) return 0; } +int ssl_nonblock_handshake(SSL *ssl, int fd, int& timeout) +{ + if(SSL_set_fd(ssl, fd) == 0) { + std::cerr << ERR_error_string(ERR_get_error(), 0) << std::endl; + return -1; + } + ERR_clear_error(); + pollfd pfd; + pfd.fd = fd; + pfd.events = POLLOUT; + timeval tv1, tv2; + while(1) { + if(timeout != -1) { + gettimeofday(&tv1, 0); + } + int rv = poll(&pfd, 1, timeout); + if(rv == 0) { + return -2; + } else if(rv == -1) { + return -1; + } + ERR_clear_error(); + rv = SSL_connect(ssl); + if(rv == 0) { + std::cerr << ERR_error_string(ERR_get_error(), 0) << std::endl; + return -1; + } else if(rv < 0) { + if(timeout != -1) { + gettimeofday(&tv2, 0); + timeout -= time_delta(tv2, tv1); + if(timeout <= 0) { + return -2; + } + } + switch(SSL_get_error(ssl, rv)) { + case SSL_ERROR_WANT_READ: + pfd.events = POLLIN; + break; + case SSL_ERROR_WANT_WRITE: + pfd.events = POLLOUT; + break; + default: + std::cerr << ERR_error_string(ERR_get_error(), 0) << std::endl; + return -1; + } + } else { + break; + } + } + return 0; +} + +int64_t time_delta(const timeval& a, const timeval& b) +{ + int64_t res = (a.tv_sec - b.tv_sec) * 1000; + res += (a.tv_usec - b.tv_usec) / 1000; + return res; +} + namespace { timeval base_tv; } // namespace diff --git a/src/spdylay_ssl.h b/src/spdylay_ssl.h index b2364667..f0909518 100644 --- a/src/spdylay_ssl.h +++ b/src/spdylay_ssl.h @@ -73,6 +73,8 @@ private: int connect_to(const std::string& host, uint16_t port); +int nonblock_connect_to(const std::string& host, uint16_t port, int timeout); + int make_listen_socket(const std::string& host, uint16_t port, int family); int make_non_block(int fd); @@ -134,6 +136,12 @@ void setup_ssl_ctx(SSL_CTX *ssl_ctx, void *next_proto_select_cb_arg); int ssl_handshake(SSL *ssl, int fd); +int ssl_nonblock_handshake(SSL *ssl, int fd, int& timeout); + +// Returns difference between |a| and |b| in milliseconds, assuming +// |a| is more recent than |b|. +int64_t time_delta(const timeval& a, const timeval& b); + void reset_timer(); void get_timer(timeval *tv);