nghttpx: Fix broken PROXY-protocol

Fix PROXY-protocol that is enabled for TLS connection.
This commit is contained in:
Tatsuhiro Tsujikawa 2022-07-04 21:21:02 +09:00
parent e065cbccb6
commit d9acf873ed
6 changed files with 288 additions and 22 deletions

View File

@ -1383,6 +1383,82 @@ func TestH2H1ProxyProtocolV1TCP6(t *testing.T) {
}
}
// TestH2H1ProxyProtocolV1TCP4TLS tests PROXY protocol version 1 over
// TLS containing TCP4 entry is accepted and X-Forwarded-For contains
// advertised src address.
func TestH2H1ProxyProtocolV1TCP4TLS(t *testing.T) {
opts := options{
args: []string{
"--accept-proxy-protocol",
"--add-x-forwarded-for",
"--add-forwarded=for",
"--forwarded-for=ip",
},
handler: func(w http.ResponseWriter, r *http.Request) {
if got, want := r.Header.Get("X-Forwarded-For"), "192.168.0.2"; got != want {
t.Errorf("X-Forwarded-For: %v; want %v", got, want)
}
if got, want := r.Header.Get("Forwarded"), "for=192.168.0.2"; got != want {
t.Errorf("Forwarded: %v; want %v", got, want)
}
},
tls: true,
tcpData: []byte("PROXY TCP4 192.168.0.2 192.168.0.100 12345 8080\r\n"),
}
st := newServerTester(t, opts)
defer st.Close()
res, err := st.http2(requestParam{
name: "TestH2H1ProxyProtocolV1TCP4TLS",
})
if err != nil {
t.Fatalf("Error st.http2() = %v", err)
}
if got, want := res.status, 200; got != want {
t.Errorf("res.status: %v; want %v", got, want)
}
}
// TestH2H1ProxyProtocolV1TCP6TLS tests PROXY protocol version 1 over
// TLS containing TCP6 entry is accepted and X-Forwarded-For contains
// advertised src address.
func TestH2H1ProxyProtocolV1TCP6TLS(t *testing.T) {
opts := options{
args: []string{
"--accept-proxy-protocol",
"--add-x-forwarded-for",
"--add-forwarded=for",
"--forwarded-for=ip",
},
handler: func(w http.ResponseWriter, r *http.Request) {
if got, want := r.Header.Get("X-Forwarded-For"), "2001:0db8:85a3:0000:0000:8a2e:0370:7334"; got != want {
t.Errorf("X-Forwarded-For: %v; want %v", got, want)
}
if got, want := r.Header.Get("Forwarded"), `for="[2001:0db8:85a3:0000:0000:8a2e:0370:7334]"`; got != want {
t.Errorf("Forwarded: %v; want %v", got, want)
}
},
tls: true,
tcpData: []byte("PROXY TCP6 2001:0db8:85a3:0000:0000:8a2e:0370:7334 ::1 12345 8080\r\n"),
}
st := newServerTester(t, opts)
defer st.Close()
res, err := st.http2(requestParam{
name: "TestH2H1ProxyProtocolV1TCP6TLS",
})
if err != nil {
t.Fatalf("Error st.http2() = %v", err)
}
if got, want := res.status, 200; got != want {
t.Errorf("res.status: %v; want %v", got, want)
}
}
// TestH2H1ProxyProtocolV1Unknown tests PROXY protocol version 1
// containing UNKNOWN entry is accepted.
func TestH2H1ProxyProtocolV1Unknown(t *testing.T) {
@ -1855,6 +1931,110 @@ func TestH2H1ProxyProtocolV2TCP6(t *testing.T) {
}
}
// TestH2H1ProxyProtocolV2TCP4TLS tests PROXY protocol version 2 over
// TLS containing AF_INET family is accepted and X-Forwarded-For
// contains advertised src address.
func TestH2H1ProxyProtocolV2TCP4TLS(t *testing.T) {
var v2Hdr bytes.Buffer
writeProxyProtocolV2(&v2Hdr, proxyProtocolV2{
command: proxyProtocolV2CommandProxy,
sourceAddress: &net.TCPAddr{
IP: net.ParseIP("192.168.0.2").To4(),
Port: 12345,
},
destinationAddress: &net.TCPAddr{
IP: net.ParseIP("192.168.0.100").To4(),
Port: 8080,
},
additionalData: []byte("foobar"),
})
opts := options{
args: []string{
"--accept-proxy-protocol",
"--add-x-forwarded-for",
"--add-forwarded=for",
"--forwarded-for=ip",
},
handler: func(w http.ResponseWriter, r *http.Request) {
if got, want := r.Header.Get("X-Forwarded-For"), "192.168.0.2"; got != want {
t.Errorf("X-Forwarded-For: %v; want %v", got, want)
}
if got, want := r.Header.Get("Forwarded"), "for=192.168.0.2"; got != want {
t.Errorf("Forwarded: %v; want %v", got, want)
}
},
tls: true,
tcpData: v2Hdr.Bytes(),
}
st := newServerTester(t, opts)
defer st.Close()
res, err := st.http2(requestParam{
name: "TestH2H1ProxyProtocolV2TCP4TLS",
})
if err != nil {
t.Fatalf("Error st.http2() = %v", err)
}
if got, want := res.status, 200; got != want {
t.Errorf("res.status: %v; want %v", got, want)
}
}
// TestH2H1ProxyProtocolV2TCP6TLS tests PROXY protocol version 2 over
// TLS containing AF_INET6 family is accepted and X-Forwarded-For
// contains advertised src address.
func TestH2H1ProxyProtocolV2TCP6TLS(t *testing.T) {
var v2Hdr bytes.Buffer
writeProxyProtocolV2(&v2Hdr, proxyProtocolV2{
command: proxyProtocolV2CommandProxy,
sourceAddress: &net.TCPAddr{
IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"),
Port: 12345,
},
destinationAddress: &net.TCPAddr{
IP: net.ParseIP("::1"),
Port: 8080,
},
additionalData: []byte("foobar"),
})
opts := options{
args: []string{
"--accept-proxy-protocol",
"--add-x-forwarded-for",
"--add-forwarded=for",
"--forwarded-for=ip",
},
handler: func(w http.ResponseWriter, r *http.Request) {
if got, want := r.Header.Get("X-Forwarded-For"), "2001:db8:85a3::8a2e:370:7334"; got != want {
t.Errorf("X-Forwarded-For: %v; want %v", got, want)
}
if got, want := r.Header.Get("Forwarded"), `for="[2001:db8:85a3::8a2e:370:7334]"`; got != want {
t.Errorf("Forwarded: %v; want %v", got, want)
}
},
tls: true,
tcpData: v2Hdr.Bytes(),
}
st := newServerTester(t, opts)
defer st.Close()
res, err := st.http2(requestParam{
name: "TestH2H1ProxyProtocolV2TCP6TLS",
})
if err != nil {
t.Fatalf("Error st.http2() = %v", err)
}
if got, want := res.status, 200; got != want {
t.Errorf("res.status: %v; want %v", got, want)
}
}
// TestH2H1ProxyProtocolV2Local tests PROXY protocol version 2
// containing cmd == Local is ignored.
func TestH2H1ProxyProtocolV2Local(t *testing.T) {

View File

@ -76,6 +76,10 @@ type options struct {
// tlsConfig is the client side TLS configuration that is used
// when tls is true.
tlsConfig *tls.Config
// tcpData is additional data that are written to connection
// before TLS handshake starts. This field is ignored if tls
// is false.
tcpData []byte
}
// newServerTester creates test context.
@ -204,9 +208,15 @@ func newServerTester(t *testing.T, opts options) *serverTester {
for {
time.Sleep(50 * time.Millisecond)
var conn net.Conn
var err error
if opts.tls {
conn, err := net.Dial("tcp", authority)
if err == nil && opts.tls {
if len(opts.tcpData) > 0 {
if _, err := conn.Write(opts.tcpData); err != nil {
st.Close()
st.t.Fatal("Error writing TCP data")
}
}
var tlsConfig *tls.Config
if opts.tlsConfig == nil {
tlsConfig = new(tls.Config)
@ -219,9 +229,16 @@ func newServerTester(t *testing.T, opts options) *serverTester {
} else {
tlsConfig.NextProtos = []string{"h2"}
}
conn, err = tls.Dial("tcp", authority, tlsConfig)
} else {
conn, err = net.Dial("tcp", authority)
tlsConn := tls.Client(conn, tlsConfig)
err = tlsConn.Handshake()
if err == nil {
cs := tlsConn.ConnectionState()
if !cs.NegotiatedProtocolIsMutual {
st.Close()
st.t.Fatalf("Error negotiated next protocol is not mutual")
}
conn = tlsConn
}
}
if err != nil {
retry += 1
@ -231,14 +248,6 @@ func newServerTester(t *testing.T, opts options) *serverTester {
}
continue
}
if opts.tls {
tlsConn := conn.(*tls.Conn)
cs := tlsConn.ConnectionState()
if !cs.NegotiatedProtocolIsMutual {
st.Close()
st.t.Fatalf("Error negotiated next protocol is not mutual")
}
}
st.conn = conn
break
}

View File

@ -181,6 +181,35 @@ int ClientHandler::write_clear() {
return 0;
}
int ClientHandler::proxy_protocol_peek_clear() {
rb_.ensure_chunk();
assert(rb_.rleft() == 0);
auto nread = conn_.peek_clear(rb_.last(), rb_.wleft());
if (nread < 0) {
return -1;
}
if (nread == 0) {
return 0;
}
if (LOG_ENABLED(INFO)) {
CLOG(INFO, this) << "PROXY-protocol: Peek " << nread
<< " bytes from socket";
}
rb_.write(nread);
if (on_read() != 0) {
return -1;
}
rb_.reset();
return 0;
}
int ClientHandler::tls_handshake() {
ev_timer_again(conn_.loop, &conn_.rt);
@ -446,7 +475,7 @@ ClientHandler::ClientHandler(Worker *worker, int fd, SSL *ssl,
if (!faddr->quic) {
if (faddr_->accept_proxy_protocol ||
config->conn.upstream.accept_proxy_protocol) {
read_ = &ClientHandler::read_clear;
read_ = &ClientHandler::proxy_protocol_peek_clear;
write_ = &ClientHandler::noop;
on_read_ = &ClientHandler::proxy_protocol_read;
on_write_ = &ClientHandler::upstream_noop;
@ -1257,19 +1286,25 @@ ssize_t parse_proxy_line_port(const uint8_t *first, const uint8_t *last) {
} // namespace
int ClientHandler::on_proxy_protocol_finish() {
if (conn_.tls.ssl) {
conn_.tls.rbuf.append(rb_.pos(), rb_.rleft());
rb_.reset();
auto len = rb_.pos() - rb_.begin();
assert(len);
if (LOG_ENABLED(INFO)) {
CLOG(INFO, this) << "PROXY-protocol: Draining " << len
<< " bytes from socket";
}
setup_upstream_io_callback();
rb_.reset();
// Run on_read to process data left in buffer since they are not
// notified further
if (on_read() != 0) {
if (conn_.read_nolim_clear(rb_.pos(), len) < 0) {
return -1;
}
rb_.reset();
setup_upstream_io_callback();
return 0;
}

View File

@ -68,6 +68,8 @@ public:
// Performs clear text I/O
int read_clear();
int write_clear();
// Specialized for PROXY-protocol use; peek data from socket.
int proxy_protocol_peek_clear();
// Performs TLS handshake
int tls_handshake();
// Performs TLS I/O

View File

@ -1156,6 +1156,42 @@ ssize_t Connection::read_clear(void *data, size_t len) {
return nread;
}
ssize_t Connection::read_nolim_clear(void *data, size_t len) {
ssize_t nread;
while ((nread = read(fd, data, len)) == -1 && errno == EINTR)
;
if (nread == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
return 0;
}
return SHRPX_ERR_NETWORK;
}
if (nread == 0) {
return SHRPX_ERR_EOF;
}
return nread;
}
ssize_t Connection::peek_clear(void *data, size_t len) {
ssize_t nread;
while ((nread = recv(fd, data, len, MSG_PEEK)) == -1 && errno == EINTR)
;
if (nread == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
return 0;
}
return SHRPX_ERR_NETWORK;
}
if (nread == 0) {
return SHRPX_ERR_EOF;
}
return nread;
}
void Connection::handle_tls_pending_read() {
if (!ev_is_active(&rev)) {
return;

View File

@ -141,6 +141,10 @@ struct Connection {
ssize_t write_clear(const void *data, size_t len);
ssize_t writev_clear(struct iovec *iov, int iovcnt);
ssize_t read_clear(void *data, size_t len);
// Read at most |len| bytes of data from socket without rate limit.
ssize_t read_nolim_clear(void *data, size_t len);
// Peek at most |len| bytes of data from socket without rate limit.
ssize_t peek_clear(void *data, size_t len);
void handle_tls_pending_read();