diff --git a/integration-tests/nghttpx_http2_test.go b/integration-tests/nghttpx_http2_test.go index 0698dd24..84646fa0 100644 --- a/integration-tests/nghttpx_http2_test.go +++ b/integration-tests/nghttpx_http2_test.go @@ -9,6 +9,7 @@ import ( "golang.org/x/net/http2/hpack" "io" "io/ioutil" + "net" "net/http" "regexp" "strings" @@ -1506,6 +1507,235 @@ func TestH2H1ProxyProtocolV1InvalidID(t *testing.T) { } } +// TestH2H1ProxyProtocolV2TCP4 tests PROXY protocol version 2 +// containing AF_INET family is accepted and X-Forwarded-For contains +// advertised src address. +func TestH2H1ProxyProtocolV2TCP4(t *testing.T) { + st := newServerTester([]string{"--accept-proxy-protocol", "--add-x-forwarded-for", "--add-forwarded=for", "--forwarded-for=ip"}, t, func(w http.ResponseWriter, r *http.Request) { + if got, want := r.Header.Get("X-Forwarded-For"), "192.168.0.2"; got != want { + t.Errorf("X-Forwarded-For: %v; want %v", got, want) + } + if got, want := r.Header.Get("Forwarded"), "for=192.168.0.2"; got != want { + t.Errorf("Forwarded: %v; want %v", got, want) + } + }) + defer st.Close() + + var b bytes.Buffer + writeProxyProtocolV2(&b, proxyProtocolV2{ + command: proxyProtocolV2CommandProxy, + sourceAddress: &net.TCPAddr{ + IP: net.ParseIP("192.168.0.2").To4(), + Port: 12345, + }, + destinationAddress: &net.TCPAddr{ + IP: net.ParseIP("192.168.0.100").To4(), + Port: 8080, + }, + additionalData: []byte("foobar"), + }) + st.conn.Write(b.Bytes()) + + res, err := st.http2(requestParam{ + name: "TestH2H1ProxyProtocolV2TCP4", + }) + + if err != nil { + t.Fatalf("Error st.http2() = %v", err) + } + + if got, want := res.status, 200; got != want { + t.Errorf("res.status: %v; want %v", got, want) + } +} + +// TestH2H1ProxyProtocolV2TCP6 tests PROXY protocol version 2 +// containing AF_INET6 family is accepted and X-Forwarded-For contains +// advertised src address. +func TestH2H1ProxyProtocolV2TCP6(t *testing.T) { + st := newServerTester([]string{"--accept-proxy-protocol", "--add-x-forwarded-for", "--add-forwarded=for", "--forwarded-for=ip"}, t, func(w http.ResponseWriter, r *http.Request) { + if got, want := r.Header.Get("X-Forwarded-For"), "2001:db8:85a3::8a2e:370:7334"; got != want { + t.Errorf("X-Forwarded-For: %v; want %v", got, want) + } + if got, want := r.Header.Get("Forwarded"), `for="[2001:db8:85a3::8a2e:370:7334]"`; got != want { + t.Errorf("Forwarded: %v; want %v", got, want) + } + }) + defer st.Close() + + var b bytes.Buffer + writeProxyProtocolV2(&b, proxyProtocolV2{ + command: proxyProtocolV2CommandProxy, + sourceAddress: &net.TCPAddr{ + IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334"), + Port: 12345, + }, + destinationAddress: &net.TCPAddr{ + IP: net.ParseIP("::1"), + Port: 8080, + }, + additionalData: []byte("foobar"), + }) + st.conn.Write(b.Bytes()) + + res, err := st.http2(requestParam{ + name: "TestH2H1ProxyProtocolV2TCP6", + }) + + if err != nil { + t.Fatalf("Error st.http2() = %v", err) + } + + if got, want := res.status, 200; got != want { + t.Errorf("res.status: %v; want %v", got, want) + } +} + +// TestH2H1ProxyProtocolV2Local tests PROXY protocol version 2 +// containing cmd == Local is ignored. +func TestH2H1ProxyProtocolV2Local(t *testing.T) { + st := newServerTester([]string{"--accept-proxy-protocol", "--add-x-forwarded-for", "--add-forwarded=for", "--forwarded-for=ip"}, t, func(w http.ResponseWriter, r *http.Request) { + if got, want := r.Header.Get("X-Forwarded-For"), "127.0.0.1"; got != want { + t.Errorf("X-Forwarded-For: %v; want %v", got, want) + } + if got, want := r.Header.Get("Forwarded"), "for=127.0.0.1"; got != want { + t.Errorf("Forwarded: %v; want %v", got, want) + } + }) + defer st.Close() + + var b bytes.Buffer + writeProxyProtocolV2(&b, proxyProtocolV2{ + command: proxyProtocolV2CommandLocal, + sourceAddress: &net.TCPAddr{ + IP: net.ParseIP("192.168.0.2").To4(), + Port: 12345, + }, + destinationAddress: &net.TCPAddr{ + IP: net.ParseIP("192.168.0.100").To4(), + Port: 8080, + }, + additionalData: []byte("foobar"), + }) + st.conn.Write(b.Bytes()) + + res, err := st.http2(requestParam{ + name: "TestH2H1ProxyProtocolV2Local", + }) + + if err != nil { + t.Fatalf("Error st.http2() = %v", err) + } + + if got, want := res.status, 200; got != want { + t.Errorf("res.status: %v; want %v", got, want) + } +} + +// TestH2H1ProxyProtocolV2UnknownCmd tests PROXY protocol version 2 +// containing unknown cmd should be rejected. +func TestH2H1ProxyProtocolV2UnknownCmd(t *testing.T) { + st := newServerTester([]string{"--accept-proxy-protocol"}, t, noopHandler) + defer st.Close() + + var b bytes.Buffer + writeProxyProtocolV2(&b, proxyProtocolV2{ + command: 0xf, + sourceAddress: &net.TCPAddr{ + IP: net.ParseIP("192.168.0.2").To4(), + Port: 12345, + }, + destinationAddress: &net.TCPAddr{ + IP: net.ParseIP("192.168.0.100").To4(), + Port: 8080, + }, + additionalData: []byte("foobar"), + }) + st.conn.Write(b.Bytes()) + + _, err := st.http2(requestParam{ + name: "TestH2H1ProxyProtocolV2UnknownCmd", + }) + + if err == nil { + t.Fatalf("connection was not terminated") + } +} + +// TestH2H1ProxyProtocolV2Unix tests PROXY protocol version 2 +// containing AF_UNIX family is ignored. +func TestH2H1ProxyProtocolV2Unix(t *testing.T) { + st := newServerTester([]string{"--accept-proxy-protocol", "--add-x-forwarded-for", "--add-forwarded=for", "--forwarded-for=ip"}, t, func(w http.ResponseWriter, r *http.Request) { + if got, want := r.Header.Get("X-Forwarded-For"), "127.0.0.1"; got != want { + t.Errorf("X-Forwarded-For: %v; want %v", got, want) + } + if got, want := r.Header.Get("Forwarded"), "for=127.0.0.1"; got != want { + t.Errorf("Forwarded: %v; want %v", got, want) + } + }) + defer st.Close() + + var b bytes.Buffer + writeProxyProtocolV2(&b, proxyProtocolV2{ + command: proxyProtocolV2CommandProxy, + sourceAddress: &net.UnixAddr{ + Name: "/foo", + Net: "unix", + }, + destinationAddress: &net.UnixAddr{ + Name: "/bar", + Net: "unix", + }, + additionalData: []byte("foobar"), + }) + st.conn.Write(b.Bytes()) + + res, err := st.http2(requestParam{ + name: "TestH2H1ProxyProtocolV2Unix", + }) + + if err != nil { + t.Fatalf("Error st.http2() = %v", err) + } + + if got, want := res.status, 200; got != want { + t.Errorf("res.status: %v; want %v", got, want) + } +} + +// TestH2H1ProxyProtocolV2Unspec tests PROXY protocol version 2 +// containing AF_UNSPEC family is ignored. +func TestH2H1ProxyProtocolV2Unspec(t *testing.T) { + st := newServerTester([]string{"--accept-proxy-protocol", "--add-x-forwarded-for", "--add-forwarded=for", "--forwarded-for=ip"}, t, func(w http.ResponseWriter, r *http.Request) { + if got, want := r.Header.Get("X-Forwarded-For"), "127.0.0.1"; got != want { + t.Errorf("X-Forwarded-For: %v; want %v", got, want) + } + if got, want := r.Header.Get("Forwarded"), "for=127.0.0.1"; got != want { + t.Errorf("Forwarded: %v; want %v", got, want) + } + }) + defer st.Close() + + var b bytes.Buffer + writeProxyProtocolV2(&b, proxyProtocolV2{ + command: proxyProtocolV2CommandProxy, + additionalData: []byte("foobar"), + }) + st.conn.Write(b.Bytes()) + + res, err := st.http2(requestParam{ + name: "TestH2H1ProxyProtocolV2Unspec", + }) + + if err != nil { + t.Fatalf("Error st.http2() = %v", err) + } + + if got, want := res.status, 200; got != want { + t.Errorf("res.status: %v; want %v", got, want) + } +} + // TestH2H1ExternalDNS tests that DNS resolution using external DNS // with HTTP/1 backend works. func TestH2H1ExternalDNS(t *testing.T) { diff --git a/integration-tests/server_tester.go b/integration-tests/server_tester.go index 89e7f29d..113e27d1 100644 --- a/integration-tests/server_tester.go +++ b/integration-tests/server_tester.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "crypto/tls" + "encoding/binary" "errors" "fmt" "github.com/tatsuhiro-t/go-nghttp2" @@ -671,3 +672,93 @@ type APIResponse struct { Code int `json:"code,omitempty"` Data map[string]interface{} `json:"data,omitempty"` } + +type proxyProtocolV2 struct { + command proxyProtocolV2Command + sourceAddress net.Addr + destinationAddress net.Addr + additionalData []byte +} + +type proxyProtocolV2Command int + +const ( + proxyProtocolV2CommandLocal proxyProtocolV2Command = 0x0 + proxyProtocolV2CommandProxy proxyProtocolV2Command = 0x1 +) + +type proxyProtocolV2Family int + +const ( + proxyProtocolV2FamilyUnspec proxyProtocolV2Family = 0x0 + proxyProtocolV2FamilyInet proxyProtocolV2Family = 0x1 + proxyProtocolV2FamilyInet6 proxyProtocolV2Family = 0x2 + proxyProtocolV2FamilyUnix proxyProtocolV2Family = 0x3 +) + +type proxyProtocolV2Protocol int + +const ( + proxyProtocolV2ProtocolUnspec proxyProtocolV2Protocol = 0x0 + proxyProtocolV2ProtocolStream proxyProtocolV2Protocol = 0x1 + proxyProtocolV2ProtocolDgram proxyProtocolV2Protocol = 0x2 +) + +func writeProxyProtocolV2(w io.Writer, hdr proxyProtocolV2) { + w.Write([]byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}) + w.Write([]byte{byte(0x20 | hdr.command)}) + + switch srcAddr := hdr.sourceAddress.(type) { + case *net.TCPAddr: + dstAddr := hdr.destinationAddress.(*net.TCPAddr) + if len(srcAddr.IP) != len(dstAddr.IP) { + panic("len(srcAddr.IP) != len(dstAddr.IP)") + } + var fam byte + if len(srcAddr.IP) == 4 { + fam = byte(proxyProtocolV2FamilyInet << 4) + } else { + fam = byte(proxyProtocolV2FamilyInet6 << 4) + } + fam |= byte(proxyProtocolV2ProtocolStream) + w.Write([]byte{fam}) + length := uint16(len(srcAddr.IP)*2 + 4 + len(hdr.additionalData)) + binary.Write(w, binary.BigEndian, length) + w.Write(srcAddr.IP) + w.Write(dstAddr.IP) + binary.Write(w, binary.BigEndian, uint16(srcAddr.Port)) + binary.Write(w, binary.BigEndian, uint16(dstAddr.Port)) + case *net.UnixAddr: + dstAddr := hdr.destinationAddress.(*net.UnixAddr) + if len(srcAddr.Name) > 108 { + panic("too long Unix source address") + } + if len(dstAddr.Name) > 108 { + panic("too long Unix destination address") + } + fam := byte(proxyProtocolV2FamilyUnix << 4) + switch srcAddr.Net { + case "unix": + fam |= byte(proxyProtocolV2ProtocolStream) + case "unixdgram": + fam |= byte(proxyProtocolV2ProtocolDgram) + default: + fam |= byte(proxyProtocolV2ProtocolUnspec) + } + w.Write([]byte{fam}) + length := uint16(216 + len(hdr.additionalData)) + binary.Write(w, binary.BigEndian, length) + zeros := make([]byte, 108) + w.Write([]byte(srcAddr.Name)) + w.Write(zeros[:108-len(srcAddr.Name)]) + w.Write([]byte(dstAddr.Name)) + w.Write(zeros[:108-len(dstAddr.Name)]) + default: + fam := byte(proxyProtocolV2FamilyUnspec<<4) | byte(proxyProtocolV2ProtocolUnspec) + w.Write([]byte{fam}) + length := uint16(len(hdr.additionalData)) + binary.Write(w, binary.BigEndian, length) + } + + w.Write(hdr.additionalData) +}