diff --git a/src/shrpx_connection.cc b/src/shrpx_connection.cc index 2269c63f..0b3c0302 100644 --- a/src/shrpx_connection.cc +++ b/src/shrpx_connection.cc @@ -60,7 +60,8 @@ Connection::Connection(struct ev_loop *loop, int fd, SSL *ssl, IOCb readcb, TimerCb timeoutcb, void *data, size_t tls_dyn_rec_warmup_threshold, ev_tstamp tls_dyn_rec_idle_timeout, shrpx_proto proto) - : tls{DefaultMemchunks(mcpool), DefaultPeekMemchunks(mcpool)}, + : tls{DefaultMemchunks(mcpool), DefaultPeekMemchunks(mcpool), + DefaultMemchunks(mcpool)}, wlimit(loop, &wev, write_limit.rate, write_limit.burst), rlimit(loop, &rev, read_limit.rate, read_limit.burst, this), loop(loop), @@ -120,10 +121,11 @@ void Connection::disconnect() { tls.warmup_writelen = 0; tls.last_writelen = 0; tls.last_readlen = 0; - tls.handshake_state = 0; + tls.handshake_state = TLS_CONN_NORMAL; tls.initial_handshake_done = false; tls.reneg_started = false; tls.sct_requested = false; + tls.early_data_finish = false; } if (fd != -1) { @@ -141,7 +143,11 @@ void Connection::disconnect() { wlimit.stopw(); } -void Connection::prepare_client_handshake() { SSL_set_connect_state(tls.ssl); } +void Connection::prepare_client_handshake() { + SSL_set_connect_state(tls.ssl); + // This prevents SSL_read_early_data from being called. + tls.early_data_finish = true; +} void Connection::prepare_server_handshake() { SSL_set_accept_state(tls.ssl); @@ -327,8 +333,9 @@ int Connection::tls_handshake() { wlimit.stopw(); ev_timer_stop(loop, &wt); + std::array buf; + if (ev_is_active(&rev)) { - std::array buf; auto nread = read_clear(buf.data(), buf.size()); if (nread < 0) { if (LOG_ENABLED(INFO)) { @@ -381,9 +388,59 @@ int Connection::tls_handshake() { break; } + int rv; + ERR_clear_error(); - auto rv = SSL_do_handshake(tls.ssl); +#if OPENSSL_1_1_1_API + if (!tls.server_handshake || tls.early_data_finish) { + rv = SSL_do_handshake(tls.ssl); + } else { + for (;;) { + size_t nread; + + rv = SSL_read_early_data(tls.ssl, buf.data(), buf.size(), &nread); + if (rv == SSL_READ_EARLY_DATA_ERROR) { + // If we have early data, and server sends ServerHello, assume + // that handshake is completed in server side, and start + // processing request. If we don't exit handshake code here, + // server waits for EndOfEarlyData and Finished message from + // client, which voids the purpose of 0-RTT data. The left + // over of handshake is done through write_tls or read_tls. + rv = (tls.handshake_state == TLS_CONN_WRITE_STARTED || + tls.wbuf.rleft()) && + tls.earlybuf.rleft(); + break; + } + + if (LOG_ENABLED(INFO)) { + LOG(INFO) << "tls: read early data " << nread << " bytes"; + } + + tls.earlybuf.append(buf.data(), nread); + + if (rv == SSL_READ_EARLY_DATA_FINISH) { + if (LOG_ENABLED(INFO)) { + LOG(INFO) << "tls: read all early data; total " + << tls.earlybuf.rleft() << " bytes"; + } + tls.early_data_finish = true; + // The same reason stated above. + if ((tls.handshake_state == TLS_CONN_WRITE_STARTED || + tls.wbuf.rleft()) && + tls.earlybuf.rleft()) { + rv = 1; + } else { + ERR_clear_error(); + rv = SSL_do_handshake(tls.ssl); + } + break; + } + } + } +#else // !OPENSSL_1_1_1_API + rv = SSL_do_handshake(tls.ssl); +#endif // !OPENSSL_1_1_1_API if (rv <= 0) { auto err = SSL_get_error(tls.ssl, rv); @@ -621,7 +678,21 @@ ssize_t Connection::write_tls(const void *data, size_t len) { ERR_clear_error(); +#if OPENSSL_1_1_1_API + int rv; + if (SSL_is_init_finished(tls.ssl)) { + rv = SSL_write(tls.ssl, data, len); + } else { + size_t nwrite; + rv = SSL_write_early_data(tls.ssl, data, len, &nwrite); + // Use the same semantics with SSL_write. + if (rv == 1) { + rv = nwrite; + } + } +#else // !OPENSSL_1_1_1_API auto rv = SSL_write(tls.ssl, data, len); +#endif // !OPENSSL_1_1_1_API if (rv <= 0) { auto err = SSL_get_error(tls.ssl, rv); @@ -656,6 +727,52 @@ ssize_t Connection::write_tls(const void *data, size_t len) { } ssize_t Connection::read_tls(void *data, size_t len) { + ERR_clear_error(); + +#if OPENSSL_1_1_1_API + if (tls.earlybuf.rleft()) { + return tls.earlybuf.remove(data, len); + } + if (!tls.early_data_finish) { + // TLSv1.3 handshake is still going on. + size_t nread; + auto rv = SSL_read_early_data(tls.ssl, data, len, &nread); + if (rv == SSL_READ_EARLY_DATA_ERROR) { + auto err = SSL_get_error(tls.ssl, rv); + switch (err) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: // TODO Probably not required. + return 0; + case SSL_ERROR_SSL: + if (LOG_ENABLED(INFO)) { + LOG(INFO) << "SSL_read: " + << ERR_error_string(ERR_get_error(), nullptr); + } + return SHRPX_ERR_NETWORK; + default: + if (LOG_ENABLED(INFO)) { + LOG(INFO) << "SSL_read: SSL_get_error returned " << err; + } + return SHRPX_ERR_NETWORK; + } + } + + if (LOG_ENABLED(INFO)) { + LOG(INFO) << "tls: read early data " << nread << " bytes"; + } + + if (rv == SSL_READ_EARLY_DATA_FINISH) { + if (LOG_ENABLED(INFO)) { + LOG(INFO) << "tls: read all early data"; + } + tls.early_data_finish = true; + // We may have stopped write watcher in write_tls. + wlimit.startw(); + } + return nread; + } +#endif // OPENSSL_1_1_1_API + // SSL_read requires the same arguments (buf pointer and its // length) on SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE. // rlimit_.avail() or rlimit_.avail() may return different length @@ -673,8 +790,6 @@ ssize_t Connection::read_tls(void *data, size_t len) { tls.last_readlen = 0; } - ERR_clear_error(); - auto rv = SSL_read(tls.ssl, data, len); if (rv <= 0) { diff --git a/src/shrpx_connection.h b/src/shrpx_connection.h index 441ff513..d9ca5c0f 100644 --- a/src/shrpx_connection.h +++ b/src/shrpx_connection.h @@ -56,6 +56,8 @@ enum { struct TLSConnection { DefaultMemchunks wbuf; DefaultPeekMemchunks rbuf; + // Stores TLSv1.3 early data. + DefaultMemchunks earlybuf; SSL *ssl; SSL_SESSION *cached_session; MemcachedRequest *cached_session_lookup_req; @@ -74,6 +76,12 @@ struct TLSConnection { // true if ssl is initialized as server, and client requested // signed_certificate_timestamp extension. bool sct_requested; + // true if TLSv1.3 early data has been completely received. Since + // SSL_read_early_data acts like SSL_do_handshake, this field may be + // true even if the negotiated TLS version is TLSv1.2 or earlier. + // This value is also true if this is client side connection for + // convenience. + bool early_data_finish; }; struct TCPHint { diff --git a/src/shrpx_rate_limit.cc b/src/shrpx_rate_limit.cc index 77a1fe22..f05ae64c 100644 --- a/src/shrpx_rate_limit.cc +++ b/src/shrpx_rate_limit.cc @@ -108,8 +108,9 @@ void RateLimit::stopw() { } void RateLimit::handle_tls_pending_read() { - if (!conn_ || !conn_->tls.ssl || - (SSL_pending(conn_->tls.ssl) == 0 && conn_->tls.rbuf.rleft() == 0)) { + if (!conn_ || !conn_->tls.ssl || !conn_->tls.initial_handshake_done || + (SSL_pending(conn_->tls.ssl) == 0 && conn_->tls.rbuf.rleft() == 0 && + conn_->tls.earlybuf.rleft() == 0)) { return; } diff --git a/src/shrpx_tls.cc b/src/shrpx_tls.cc index 74ca9e16..bbfd6cfc 100644 --- a/src/shrpx_tls.cc +++ b/src/shrpx_tls.cc @@ -517,6 +517,13 @@ int ticket_key_cb(SSL *ssl, unsigned char *key_name, unsigned char *iv, namespace { void info_callback(const SSL *ssl, int where, int ret) { +#ifdef TLS1_3_VERSION + // TLSv1.3 has no renegotiation. + if (SSL_version(ssl) == TLS1_3_VERSION) { + return; + } +#endif // TLS1_3_VERSION + // To mitigate possible DOS attack using lots of renegotiations, we // disable renegotiation. Since OpenSSL does not provide an easy way // to disable it, we check that renegotiation is started in this