diff --git a/integration-tests/nghttpx_http2_test.go b/integration-tests/nghttpx_http2_test.go index c0a19a3f..a33d4c60 100644 --- a/integration-tests/nghttpx_http2_test.go +++ b/integration-tests/nghttpx_http2_test.go @@ -140,7 +140,7 @@ func TestH2H1StripAddXff(t *testing.T) { // Forwarded header field with obfuscated "by" and "for" parameters. func TestH2H1AddForwardedObfuscated(t *testing.T) { st := newServerTester([]string{"--add-forwarded=by,for,host,proto"}, t, func(w http.ResponseWriter, r *http.Request) { - pattern := fmt.Sprintf(`by="_[^"]+";for="_[^"]+";host="127.0.0.1:%v";proto="http"`, serverPort) + pattern := fmt.Sprintf(`by=_[^;]+;for=_[^;]+;host="127.0.0.1:%v";proto=http`, serverPort) validFwd := regexp.MustCompile(pattern) got := r.Header.Get("Forwarded") @@ -165,7 +165,7 @@ func TestH2H1AddForwardedObfuscated(t *testing.T) { // field with IP address in "by" parameter. func TestH2H1AddForwardedByIP(t *testing.T) { st := newServerTester([]string{"--add-forwarded=by,for,host,proto", "--forwarded-by=ip", "--forwarded-for=_bravo"}, t, func(w http.ResponseWriter, r *http.Request) { - want := fmt.Sprintf(`by="127.0.0.1:%v";for="_bravo";host="127.0.0.1:%v";proto="http"`, serverPort, serverPort) + want := fmt.Sprintf(`by="127.0.0.1:%v";for=_bravo;host="127.0.0.1:%v";proto=http`, serverPort, serverPort) if got := r.Header.Get("Forwarded"); got != want { t.Errorf("Forwarded = %v; want %v", got, want) } @@ -187,7 +187,7 @@ func TestH2H1AddForwardedByIP(t *testing.T) { // field with IP address in "for" parameters. func TestH2H1AddForwardedForIP(t *testing.T) { st := newServerTester([]string{"--add-forwarded=by,for,host,proto", "--forwarded-by=_alpha", "--forwarded-for=ip"}, t, func(w http.ResponseWriter, r *http.Request) { - want := fmt.Sprintf(`by="_alpha";for="127.0.0.1";host="127.0.0.1:%v";proto="http"`, serverPort) + want := fmt.Sprintf(`by=_alpha;for=127.0.0.1;host="127.0.0.1:%v";proto=http`, serverPort) if got := r.Header.Get("Forwarded"); got != want { t.Errorf("Forwarded = %v; want %v", got, want) } @@ -210,7 +210,7 @@ func TestH2H1AddForwardedForIP(t *testing.T) { // generated values must be appended to the existing value. func TestH2H1AddForwardedMerge(t *testing.T) { st := newServerTester([]string{"--add-forwarded=proto"}, t, func(w http.ResponseWriter, r *http.Request) { - if got, want := r.Header.Get("Forwarded"), `host=foo, proto="http"`; got != want { + if got, want := r.Header.Get("Forwarded"), `host=foo, proto=http`; got != want { t.Errorf("Forwarded = %v; want %v", got, want) } }) @@ -235,7 +235,7 @@ func TestH2H1AddForwardedMerge(t *testing.T) { // generated values must not include the existing value. func TestH2H1AddForwardedStrip(t *testing.T) { st := newServerTester([]string{"--strip-incoming-forwarded", "--add-forwarded=proto"}, t, func(w http.ResponseWriter, r *http.Request) { - if got, want := r.Header.Get("Forwarded"), `proto="http"`; got != want { + if got, want := r.Header.Get("Forwarded"), `proto=http`; got != want { t.Errorf("Forwarded = %v; want %v", got, want) } }) @@ -284,7 +284,7 @@ func TestH2H1StripForwarded(t *testing.T) { // "for" parameters. func TestH2H1AddForwardedStatic(t *testing.T) { st := newServerTester([]string{"--add-forwarded=by,for", "--forwarded-by=_alpha", "--forwarded-for=_bravo"}, t, func(w http.ResponseWriter, r *http.Request) { - if got, want := r.Header.Get("Forwarded"), `by="_alpha";for="_bravo"`; got != want { + if got, want := r.Header.Get("Forwarded"), `by=_alpha;for=_bravo`; got != want { t.Errorf("Forwarded = %v; want %v", got, want) } }) @@ -615,7 +615,7 @@ func TestH2H1BadAuthority(t *testing.T) { defer st.Close() res, err := st.http2(requestParam{ - name: "TestH2H1BadAuthority", + name: "TestH2H1BadAuthority", authority: `foo\bar`, }) if err != nil { @@ -635,7 +635,7 @@ func TestH2H1BadScheme(t *testing.T) { defer st.Close() res, err := st.http2(requestParam{ - name: "TestH2H1BadScheme", + name: "TestH2H1BadScheme", scheme: "http*", }) if err != nil { @@ -1672,7 +1672,7 @@ func TestH2H2StripAddXff(t *testing.T) { // field using static obfuscated "by" and "for" parameter. func TestH2H2AddForwarded(t *testing.T) { st := newServerTesterTLS([]string{"--http2-bridge", "--add-forwarded=by,for,host,proto", "--forwarded-by=_alpha", "--forwarded-for=_bravo"}, t, func(w http.ResponseWriter, r *http.Request) { - want := fmt.Sprintf(`by="_alpha";for="_bravo";host="127.0.0.1:%v";proto="https"`, serverPort) + want := fmt.Sprintf(`by=_alpha;for=_bravo;host="127.0.0.1:%v";proto=https`, serverPort) if got := r.Header.Get("Forwarded"); got != want { t.Errorf("Forwarded = %v; want %v", got, want) } @@ -1696,7 +1696,7 @@ func TestH2H2AddForwarded(t *testing.T) { // existing Forwarded header field. func TestH2H2AddForwardedMerge(t *testing.T) { st := newServerTesterTLS([]string{"--http2-bridge", "--add-forwarded=by,for,host,proto", "--forwarded-by=_alpha", "--forwarded-for=_bravo"}, t, func(w http.ResponseWriter, r *http.Request) { - want := fmt.Sprintf(`host=foo, by="_alpha";for="_bravo";host="127.0.0.1:%v";proto="https"`, serverPort) + want := fmt.Sprintf(`host=foo, by=_alpha;for=_bravo;host="127.0.0.1:%v";proto=https`, serverPort) if got := r.Header.Get("Forwarded"); got != want { t.Errorf("Forwarded = %v; want %v", got, want) } @@ -1723,7 +1723,7 @@ func TestH2H2AddForwardedMerge(t *testing.T) { // existing Forwarded header field stripped. func TestH2H2AddForwardedStrip(t *testing.T) { st := newServerTesterTLS([]string{"--http2-bridge", "--strip-incoming-forwarded", "--add-forwarded=by,for,host,proto", "--forwarded-by=_alpha", "--forwarded-for=_bravo"}, t, func(w http.ResponseWriter, r *http.Request) { - want := fmt.Sprintf(`by="_alpha";for="_bravo";host="127.0.0.1:%v";proto="https"`, serverPort) + want := fmt.Sprintf(`by=_alpha;for=_bravo;host="127.0.0.1:%v";proto=https`, serverPort) if got := r.Header.Get("Forwarded"); got != want { t.Errorf("Forwarded = %v; want %v", got, want) } diff --git a/src/Makefile.am b/src/Makefile.am index 99408894..90fc5d21 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -169,6 +169,7 @@ nghttpx_unittest_SOURCES = shrpx-unittest.cc \ shrpx_ssl_test.cc shrpx_ssl_test.h \ shrpx_downstream_test.cc shrpx_downstream_test.h \ shrpx_config_test.cc shrpx_config_test.h \ + shrpx_http_test.cc shrpx_http_test.h \ http2_test.cc http2_test.h \ util_test.cc util_test.h \ nghttp2_gzip_test.c nghttp2_gzip_test.h \ diff --git a/src/shrpx-unittest.cc b/src/shrpx-unittest.cc index 3e6d4a12..227aa356 100644 --- a/src/shrpx-unittest.cc +++ b/src/shrpx-unittest.cc @@ -39,6 +39,7 @@ #include "buffer_test.h" #include "memchunk_test.h" #include "template_test.h" +#include "shrpx_http_test.h" #include "shrpx_config.h" #include "ssl.h" @@ -124,6 +125,8 @@ int main(int argc, char *argv[]) { shrpx::test_shrpx_config_read_tls_ticket_key_file_aes_256) || !CU_add_test(pSuite, "config_match_downstream_addr_group", shrpx::test_shrpx_config_match_downstream_addr_group) || + !CU_add_test(pSuite, "http_create_forwarded", + shrpx::test_shrpx_http_create_forwarded) || !CU_add_test(pSuite, "util_streq", shrpx::test_util_streq) || !CU_add_test(pSuite, "util_strieq", shrpx::test_util_strieq) || !CU_add_test(pSuite, "util_inp_strlower", diff --git a/src/shrpx_http.cc b/src/shrpx_http.cc index faf83afb..f764f21a 100644 --- a/src/shrpx_http.cc +++ b/src/shrpx_http.cc @@ -69,24 +69,43 @@ std::string create_forwarded(int params, const std::string &node_by, const std::string &proto) { std::string res; if ((params & FORWARDED_BY) && !node_by.empty()) { - res += "by=\""; - res += node_by; - res += "\";"; + // This must be quoted-string unless it is obfuscated version + // (which starts with "_"), since ':' is not allowed in token. + // ':' is used to separate host and port. + if (node_by[0] == '_') { + res += "by="; + res += node_by; + res += ";"; + } else { + res += "by=\""; + res += node_by; + res += "\";"; + } } if ((params & FORWARDED_FOR) && !node_for.empty()) { - res += "for=\""; - res += node_for; - res += "\";"; + // We only quote IPv6 literal address only, which starts with '['. + if (node_for[0] == '[') { + res += "for=\""; + res += node_for; + res += "\";"; + } else { + res += "for="; + res += node_for; + res += ";"; + } } if ((params & FORWARDED_HOST) && !host.empty()) { + // Just be quoted to skip checking characters. res += "host=\""; res += host; res += "\";"; } if ((params & FORWARDED_PROTO) && !proto.empty()) { - res += "proto=\""; + // Scheme production rule only allow characters which are all in + // token. + res += "proto="; res += proto; - res += "\";"; + res += ";"; } if (res.empty()) {