diff --git a/integration-tests/nghttpx_http2_test.go b/integration-tests/nghttpx_http2_test.go index 30ebe3c9..b33e4efe 100644 --- a/integration-tests/nghttpx_http2_test.go +++ b/integration-tests/nghttpx_http2_test.go @@ -1383,6 +1383,82 @@ func TestH2H1ProxyProtocolV1TCP6(t *testing.T) { } } +// TestH2H1ProxyProtocolV1TCP4TLS tests PROXY protocol version 1 over +// TLS containing TCP4 entry is accepted and X-Forwarded-For contains +// advertised src address. +func TestH2H1ProxyProtocolV1TCP4TLS(t *testing.T) { + opts := options{ + args: []string{ + "--accept-proxy-protocol", + "--add-x-forwarded-for", + "--add-forwarded=for", + "--forwarded-for=ip", + }, + handler: func(w http.ResponseWriter, r *http.Request) { + if got, want := r.Header.Get("X-Forwarded-For"), "192.168.0.2"; got != want { + t.Errorf("X-Forwarded-For: %v; want %v", got, want) + } + if got, want := r.Header.Get("Forwarded"), "for=192.168.0.2"; got != want { + t.Errorf("Forwarded: %v; want %v", got, want) + } + }, + tls: true, + tcpData: []byte("PROXY TCP4 192.168.0.2 192.168.0.100 12345 8080\r\n"), + } + st := newServerTester(t, opts) + defer st.Close() + + res, err := st.http2(requestParam{ + name: "TestH2H1ProxyProtocolV1TCP4TLS", + }) + + if err != nil { + t.Fatalf("Error st.http2() = %v", err) + } + + if got, want := res.status, 200; got != want { + t.Errorf("res.status: %v; want %v", got, want) + } +} + +// TestH2H1ProxyProtocolV1TCP6TLS tests PROXY protocol version 1 over +// TLS containing TCP6 entry is accepted and X-Forwarded-For contains +// advertised src address. +func TestH2H1ProxyProtocolV1TCP6TLS(t *testing.T) { + opts := options{ + args: []string{ + "--accept-proxy-protocol", + "--add-x-forwarded-for", + "--add-forwarded=for", + "--forwarded-for=ip", + }, + handler: func(w http.ResponseWriter, r *http.Request) { + if got, want := r.Header.Get("X-Forwarded-For"), "2001:0db8:85a3:0000:0000:8a2e:0370:7334"; got != want { + t.Errorf("X-Forwarded-For: %v; want %v", got, want) + } + if got, want := r.Header.Get("Forwarded"), `for="[2001:0db8:85a3:0000:0000:8a2e:0370:7334]"`; got != want { + t.Errorf("Forwarded: %v; want %v", got, want) + } + }, + tls: true, + tcpData: []byte("PROXY TCP6 2001:0db8:85a3:0000:0000:8a2e:0370:7334 ::1 12345 8080\r\n"), + } + st := newServerTester(t, opts) + defer st.Close() + + res, err := st.http2(requestParam{ + name: "TestH2H1ProxyProtocolV1TCP6TLS", + }) + + if err != nil { + t.Fatalf("Error st.http2() = %v", err) + } + + if got, want := res.status, 200; got != want { + t.Errorf("res.status: %v; want %v", got, want) + } +} + // TestH2H1ProxyProtocolV1Unknown tests PROXY protocol version 1 // containing UNKNOWN entry is accepted. func TestH2H1ProxyProtocolV1Unknown(t *testing.T) { @@ -1855,6 +1931,110 @@ func TestH2H1ProxyProtocolV2TCP6(t *testing.T) { } } +// TestH2H1ProxyProtocolV2TCP4TLS tests PROXY protocol version 2 over +// TLS containing AF_INET family is accepted and X-Forwarded-For +// contains advertised src address. +func TestH2H1ProxyProtocolV2TCP4TLS(t *testing.T) { + var v2Hdr bytes.Buffer + writeProxyProtocolV2(&v2Hdr, proxyProtocolV2{ + command: proxyProtocolV2CommandProxy, + sourceAddress: &net.TCPAddr{ + IP: net.ParseIP("192.168.0.2").To4(), + Port: 12345, + }, + destinationAddress: &net.TCPAddr{ + IP: net.ParseIP("192.168.0.100").To4(), + Port: 8080, + }, + additionalData: []byte("foobar"), + }) + + opts := options{ + args: []string{ + "--accept-proxy-protocol", + "--add-x-forwarded-for", + "--add-forwarded=for", + "--forwarded-for=ip", + }, + handler: func(w http.ResponseWriter, r *http.Request) { + if got, want := r.Header.Get("X-Forwarded-For"), "192.168.0.2"; got != want { + t.Errorf("X-Forwarded-For: %v; want %v", got, want) + } + if got, want := r.Header.Get("Forwarded"), "for=192.168.0.2"; got != want { + t.Errorf("Forwarded: %v; want %v", got, want) + } + }, + tls: true, + tcpData: v2Hdr.Bytes(), + } + st := newServerTester(t, opts) + defer st.Close() + + res, err := st.http2(requestParam{ + name: "TestH2H1ProxyProtocolV2TCP4TLS", + }) + + if err != nil { + t.Fatalf("Error st.http2() = %v", err) + } + + if got, want := res.status, 200; got != want { + t.Errorf("res.status: %v; want %v", got, want) + } +} + +// TestH2H1ProxyProtocolV2TCP6TLS tests PROXY protocol version 2 over +// TLS containing AF_INET6 family is accepted and X-Forwarded-For +// contains advertised src address. +func TestH2H1ProxyProtocolV2TCP6TLS(t *testing.T) { + var v2Hdr bytes.Buffer + writeProxyProtocolV2(&v2Hdr, proxyProtocolV2{ + command: proxyProtocolV2CommandProxy, + sourceAddress: &net.TCPAddr{ + IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + Port: 12345, + }, + destinationAddress: &net.TCPAddr{ + IP: net.ParseIP("::1"), + Port: 8080, + }, + additionalData: []byte("foobar"), + }) + + opts := options{ + args: []string{ + "--accept-proxy-protocol", + "--add-x-forwarded-for", + "--add-forwarded=for", + "--forwarded-for=ip", + }, + handler: func(w http.ResponseWriter, r *http.Request) { + if got, want := r.Header.Get("X-Forwarded-For"), "2001:db8:85a3::8a2e:370:7334"; got != want { + t.Errorf("X-Forwarded-For: %v; want %v", got, want) + } + if got, want := r.Header.Get("Forwarded"), `for="[2001:db8:85a3::8a2e:370:7334]"`; got != want { + t.Errorf("Forwarded: %v; want %v", got, want) + } + }, + tls: true, + tcpData: v2Hdr.Bytes(), + } + st := newServerTester(t, opts) + defer st.Close() + + res, err := st.http2(requestParam{ + name: "TestH2H1ProxyProtocolV2TCP6TLS", + }) + + if err != nil { + t.Fatalf("Error st.http2() = %v", err) + } + + if got, want := res.status, 200; got != want { + t.Errorf("res.status: %v; want %v", got, want) + } +} + // TestH2H1ProxyProtocolV2Local tests PROXY protocol version 2 // containing cmd == Local is ignored. func TestH2H1ProxyProtocolV2Local(t *testing.T) { diff --git a/integration-tests/server_tester.go b/integration-tests/server_tester.go index 1eed639c..9c5c323a 100644 --- a/integration-tests/server_tester.go +++ b/integration-tests/server_tester.go @@ -76,6 +76,10 @@ type options struct { // tlsConfig is the client side TLS configuration that is used // when tls is true. tlsConfig *tls.Config + // tcpData is additional data that are written to connection + // before TLS handshake starts. This field is ignored if tls + // is false. + tcpData []byte } // newServerTester creates test context. @@ -204,9 +208,15 @@ func newServerTester(t *testing.T, opts options) *serverTester { for { time.Sleep(50 * time.Millisecond) - var conn net.Conn - var err error - if opts.tls { + conn, err := net.Dial("tcp", authority) + if err == nil && opts.tls { + if len(opts.tcpData) > 0 { + if _, err := conn.Write(opts.tcpData); err != nil { + st.Close() + st.t.Fatal("Error writing TCP data") + } + } + var tlsConfig *tls.Config if opts.tlsConfig == nil { tlsConfig = new(tls.Config) @@ -219,9 +229,16 @@ func newServerTester(t *testing.T, opts options) *serverTester { } else { tlsConfig.NextProtos = []string{"h2"} } - conn, err = tls.Dial("tcp", authority, tlsConfig) - } else { - conn, err = net.Dial("tcp", authority) + tlsConn := tls.Client(conn, tlsConfig) + err = tlsConn.Handshake() + if err == nil { + cs := tlsConn.ConnectionState() + if !cs.NegotiatedProtocolIsMutual { + st.Close() + st.t.Fatalf("Error negotiated next protocol is not mutual") + } + conn = tlsConn + } } if err != nil { retry += 1 @@ -231,14 +248,6 @@ func newServerTester(t *testing.T, opts options) *serverTester { } continue } - if opts.tls { - tlsConn := conn.(*tls.Conn) - cs := tlsConn.ConnectionState() - if !cs.NegotiatedProtocolIsMutual { - st.Close() - st.t.Fatalf("Error negotiated next protocol is not mutual") - } - } st.conn = conn break } diff --git a/src/shrpx_client_handler.cc b/src/shrpx_client_handler.cc index 77177de8..081ab153 100644 --- a/src/shrpx_client_handler.cc +++ b/src/shrpx_client_handler.cc @@ -181,6 +181,35 @@ int ClientHandler::write_clear() { return 0; } +int ClientHandler::proxy_protocol_peek_clear() { + rb_.ensure_chunk(); + + assert(rb_.rleft() == 0); + + auto nread = conn_.peek_clear(rb_.last(), rb_.wleft()); + if (nread < 0) { + return -1; + } + if (nread == 0) { + return 0; + } + + if (LOG_ENABLED(INFO)) { + CLOG(INFO, this) << "PROXY-protocol: Peek " << nread + << " bytes from socket"; + } + + rb_.write(nread); + + if (on_read() != 0) { + return -1; + } + + rb_.reset(); + + return 0; +} + int ClientHandler::tls_handshake() { ev_timer_again(conn_.loop, &conn_.rt); @@ -446,7 +475,7 @@ ClientHandler::ClientHandler(Worker *worker, int fd, SSL *ssl, if (!faddr->quic) { if (faddr_->accept_proxy_protocol || config->conn.upstream.accept_proxy_protocol) { - read_ = &ClientHandler::read_clear; + read_ = &ClientHandler::proxy_protocol_peek_clear; write_ = &ClientHandler::noop; on_read_ = &ClientHandler::proxy_protocol_read; on_write_ = &ClientHandler::upstream_noop; @@ -1257,19 +1286,25 @@ ssize_t parse_proxy_line_port(const uint8_t *first, const uint8_t *last) { } // namespace int ClientHandler::on_proxy_protocol_finish() { - if (conn_.tls.ssl) { - conn_.tls.rbuf.append(rb_.pos(), rb_.rleft()); - rb_.reset(); + auto len = rb_.pos() - rb_.begin(); + + assert(len); + + if (LOG_ENABLED(INFO)) { + CLOG(INFO, this) << "PROXY-protocol: Draining " << len + << " bytes from socket"; } - setup_upstream_io_callback(); + rb_.reset(); - // Run on_read to process data left in buffer since they are not - // notified further - if (on_read() != 0) { + if (conn_.read_nolim_clear(rb_.pos(), len) < 0) { return -1; } + rb_.reset(); + + setup_upstream_io_callback(); + return 0; } diff --git a/src/shrpx_client_handler.h b/src/shrpx_client_handler.h index 85063cca..511dd910 100644 --- a/src/shrpx_client_handler.h +++ b/src/shrpx_client_handler.h @@ -68,6 +68,8 @@ public: // Performs clear text I/O int read_clear(); int write_clear(); + // Specialized for PROXY-protocol use; peek data from socket. + int proxy_protocol_peek_clear(); // Performs TLS handshake int tls_handshake(); // Performs TLS I/O diff --git a/src/shrpx_connection.cc b/src/shrpx_connection.cc index 2bb23459..8fcf54cb 100644 --- a/src/shrpx_connection.cc +++ b/src/shrpx_connection.cc @@ -1156,6 +1156,42 @@ ssize_t Connection::read_clear(void *data, size_t len) { return nread; } +ssize_t Connection::read_nolim_clear(void *data, size_t len) { + ssize_t nread; + while ((nread = read(fd, data, len)) == -1 && errno == EINTR) + ; + if (nread == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return 0; + } + return SHRPX_ERR_NETWORK; + } + + if (nread == 0) { + return SHRPX_ERR_EOF; + } + + return nread; +} + +ssize_t Connection::peek_clear(void *data, size_t len) { + ssize_t nread; + while ((nread = recv(fd, data, len, MSG_PEEK)) == -1 && errno == EINTR) + ; + if (nread == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return 0; + } + return SHRPX_ERR_NETWORK; + } + + if (nread == 0) { + return SHRPX_ERR_EOF; + } + + return nread; +} + void Connection::handle_tls_pending_read() { if (!ev_is_active(&rev)) { return; diff --git a/src/shrpx_connection.h b/src/shrpx_connection.h index 52245689..22934dfb 100644 --- a/src/shrpx_connection.h +++ b/src/shrpx_connection.h @@ -141,6 +141,10 @@ struct Connection { ssize_t write_clear(const void *data, size_t len); ssize_t writev_clear(struct iovec *iov, int iovcnt); ssize_t read_clear(void *data, size_t len); + // Read at most |len| bytes of data from socket without rate limit. + ssize_t read_nolim_clear(void *data, size_t len); + // Peek at most |len| bytes of data from socket without rate limit. + ssize_t peek_clear(void *data, size_t len); void handle_tls_pending_read();