spdycat: Handle timeout in connect and SSL/TLS handshake

This commit is contained in:
Tatsuhiro Tsujikawa 2012-10-14 23:39:41 +09:00
parent a28e1c6e7d
commit 22840dbfaf
3 changed files with 168 additions and 19 deletions

View File

@ -68,6 +68,7 @@ struct Config {
bool get_assets;
bool stat;
int spdy_version;
// milliseconds
int timeout;
std::string certfile;
std::string keyfile;
@ -396,13 +397,6 @@ void on_stream_close_callback
}
}
int64_t time_delta(const timeval& a, const timeval& b)
{
int64_t res = (a.tv_sec - b.tv_sec) * 1000;
res += (a.tv_usec - b.tv_usec) / 1000;
return res;
}
void print_stats(const SpdySession& spdySession)
{
std::cout << "***** Statistics *****" << std::endl;
@ -440,11 +434,19 @@ int communicate(const std::string& host, uint16_t port,
SpdySession& spdySession,
const spdylay_session_callbacks *callbacks)
{
int fd = connect_to(host, port);
int rv;
int timeout = config.timeout;
int fd = nonblock_connect_to(host, port, timeout);
if(fd == -1) {
std::cerr << "Could not connect to the host" << std::endl;
return -1;
} else if(fd == -2) {
std::cerr << "Request to " << spdySession.hostport << " timed out "
<< "during establishing connection."
<< std::endl;
return -1;
}
set_tcp_nodelay(fd);
SSL_CTX *ssl_ctx;
ssl_ctx = SSL_CTX_new(TLSv1_client_method());
if(!ssl_ctx) {
@ -484,12 +486,17 @@ int communicate(const std::string& host, uint16_t port,
std::cerr << ERR_error_string(ERR_get_error(), 0) << std::endl;
return -1;
}
if(ssl_handshake(ssl, fd) == -1) {
rv = ssl_nonblock_handshake(ssl, fd, timeout);
if(rv == -1) {
return -1;
} else if(rv == -2) {
std::cerr << "Request to " << spdySession.hostport
<< " timed out in SSL/TLS handshake."
<< std::endl;
return -1;
}
spdySession.record_handshake_time();
make_non_block(fd);
set_tcp_nodelay(fd);
int spdy_version = spdylay_npn_get_version(
reinterpret_cast<const unsigned char*>(next_proto.c_str()),
next_proto.size());
@ -516,11 +523,13 @@ int communicate(const std::string& host, uint16_t port,
}
pollfds[0].fd = fd;
ctl_poll(pollfds, &sc);
int end_time = time(NULL) + config.timeout;
int timeout = config.timeout;
bool ok = true;
timeval tv1, tv2;
while(!sc.finish()) {
if(config.timeout != -1) {
gettimeofday(&tv1, 0);
}
int nfds = poll(pollfds, npollfds, timeout);
if(nfds == -1) {
perror("poll");
@ -543,12 +552,16 @@ int communicate(const std::string& host, uint16_t port,
ok = false;
break;
}
timeout = timeout == -1 ? timeout : end_time - time(NULL);
if (config.timeout != -1 && timeout <= 0) {
std::cout << "Requests to " << spdySession.hostport << "timed out.";
if(config.timeout != -1) {
gettimeofday(&tv2, 0);
timeout -= time_delta(tv2, tv1);
if (timeout <= 0) {
std::cerr << "Requests to " << spdySession.hostport << " timed out."
<< std::endl;
ok = false;
break;
}
}
assert(ok);
ctl_poll(pollfds, &sc);
}
@ -700,7 +713,7 @@ int main(int argc, char **argv)
config.spdy_version = SPDYLAY_PROTO_SPDY3;
break;
case 't':
config.timeout = atoi(optarg);
config.timeout = atoi(optarg) * 1000;
break;
case 'w': {
errno = 0;

View File

@ -29,6 +29,7 @@
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <poll.h>
#include <cassert>
#include <cstdio>
@ -234,6 +235,74 @@ int connect_to(const std::string& host, uint16_t port)
return fd;
}
int nonblock_connect_to(const std::string& host, uint16_t port, int timeout)
{
struct addrinfo hints;
int fd = -1;
int r;
char service[10];
snprintf(service, sizeof(service), "%u", port);
memset(&hints, 0, sizeof(struct addrinfo));
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
struct addrinfo *res;
r = getaddrinfo(host.c_str(), service, &hints, &res);
if(r != 0) {
std::cerr << "getaddrinfo: " << gai_strerror(r) << std::endl;
return -1;
}
for(struct addrinfo *rp = res; rp; rp = rp->ai_next) {
fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
if(fd == -1) {
continue;
}
if(make_non_block(fd) == -1) {
close(fd);
fd = -1;
continue;
}
while((r = connect(fd, rp->ai_addr, rp->ai_addrlen)) == -1 &&
errno == EINTR);
if(r == 0) {
break;
} else if(errno == EINPROGRESS) {
struct timeval tv1, tv2;
struct pollfd pfd = {fd, POLLOUT, 0};
if(timeout != -1) {
gettimeofday(&tv1, 0);
}
r = poll(&pfd, 1, timeout);
if(r == 0) {
return -2;
} else if(r == -1) {
return -1;
} else {
if(timeout != -1) {
gettimeofday(&tv2, 0);
timeout -= time_delta(tv2, tv1);
if(timeout <= 0) {
return -2;
}
}
int socket_error;
socklen_t optlen = sizeof(socket_error);
r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &socket_error, &optlen);
if(r == 0 && socket_error == 0) {
break;
} else {
close(fd);
fd = -1;
}
}
} else {
close(fd);
fd = -1;
}
}
freeaddrinfo(res);
return fd;
}
int make_listen_socket(const std::string& host, uint16_t port, int family)
{
addrinfo hints;
@ -664,6 +733,65 @@ int ssl_handshake(SSL *ssl, int fd)
return 0;
}
int ssl_nonblock_handshake(SSL *ssl, int fd, int& timeout)
{
if(SSL_set_fd(ssl, fd) == 0) {
std::cerr << ERR_error_string(ERR_get_error(), 0) << std::endl;
return -1;
}
ERR_clear_error();
pollfd pfd;
pfd.fd = fd;
pfd.events = POLLOUT;
timeval tv1, tv2;
while(1) {
if(timeout != -1) {
gettimeofday(&tv1, 0);
}
int rv = poll(&pfd, 1, timeout);
if(rv == 0) {
return -2;
} else if(rv == -1) {
return -1;
}
ERR_clear_error();
rv = SSL_connect(ssl);
if(rv == 0) {
std::cerr << ERR_error_string(ERR_get_error(), 0) << std::endl;
return -1;
} else if(rv < 0) {
if(timeout != -1) {
gettimeofday(&tv2, 0);
timeout -= time_delta(tv2, tv1);
if(timeout <= 0) {
return -2;
}
}
switch(SSL_get_error(ssl, rv)) {
case SSL_ERROR_WANT_READ:
pfd.events = POLLIN;
break;
case SSL_ERROR_WANT_WRITE:
pfd.events = POLLOUT;
break;
default:
std::cerr << ERR_error_string(ERR_get_error(), 0) << std::endl;
return -1;
}
} else {
break;
}
}
return 0;
}
int64_t time_delta(const timeval& a, const timeval& b)
{
int64_t res = (a.tv_sec - b.tv_sec) * 1000;
res += (a.tv_usec - b.tv_usec) / 1000;
return res;
}
namespace {
timeval base_tv;
} // namespace

View File

@ -73,6 +73,8 @@ private:
int connect_to(const std::string& host, uint16_t port);
int nonblock_connect_to(const std::string& host, uint16_t port, int timeout);
int make_listen_socket(const std::string& host, uint16_t port, int family);
int make_non_block(int fd);
@ -134,6 +136,12 @@ void setup_ssl_ctx(SSL_CTX *ssl_ctx, void *next_proto_select_cb_arg);
int ssl_handshake(SSL *ssl, int fd);
int ssl_nonblock_handshake(SSL *ssl, int fd, int& timeout);
// Returns difference between |a| and |b| in milliseconds, assuming
// |a| is more recent than |b|.
int64_t time_delta(const timeval& a, const timeval& b);
void reset_timer();
void get_timer(timeval *tv);