diff --git a/integration-tests/nghttpx_http3_test.go b/integration-tests/nghttpx_http3_test.go index 9bac1652..28478365 100644 --- a/integration-tests/nghttpx_http3_test.go +++ b/integration-tests/nghttpx_http3_test.go @@ -246,3 +246,88 @@ func TestH3H1AffinityCookieTLS(t *testing.T) { t.Errorf("Set-Cookie: %v; want pattern %v", got, pattern) } } + +// TestH3H2ReqPhaseReturn tests mruby request phase hook returns +// custom response. +func TestH3H2ReqPhaseReturn(t *testing.T) { + opts := options{ + args: []string{ + "--http2-bridge", + "--mruby-file=" + testDir + "/req-return.rb", + }, + handler: func(w http.ResponseWriter, r *http.Request) { + t.Fatalf("request should not be forwarded") + }, + quic: true, + } + st := newServerTester(t, opts) + defer st.Close() + + res, err := st.http3(requestParam{ + name: "TestH3H2ReqPhaseReturn", + }) + if err != nil { + t.Fatalf("Error st.http3() = %v", err) + } + + if got, want := res.status, 404; got != want { + t.Errorf("status = %v; want %v", got, want) + } + + hdtests := []struct { + k, v string + }{ + {"content-length", "20"}, + {"from", "mruby"}, + } + for _, tt := range hdtests { + if got, want := res.header.Get(tt.k), tt.v; got != want { + t.Errorf("%v = %v; want %v", tt.k, got, want) + } + } + + if got, want := string(res.body), "Hello World from req"; got != want { + t.Errorf("body = %v; want %v", got, want) + } +} + +// TestH3H2RespPhaseReturn tests mruby response phase hook returns +// custom response. +func TestH3H2RespPhaseReturn(t *testing.T) { + opts := options{ + args: []string{ + "--http2-bridge", + "--mruby-file=" + testDir + "/resp-return.rb", + }, + quic: true, + } + st := newServerTester(t, opts) + defer st.Close() + + res, err := st.http3(requestParam{ + name: "TestH3H2RespPhaseReturn", + }) + if err != nil { + t.Fatalf("Error st.http3() = %v", err) + } + + if got, want := res.status, 404; got != want { + t.Errorf("status = %v; want %v", got, want) + } + + hdtests := []struct { + k, v string + }{ + {"content-length", "21"}, + {"from", "mruby"}, + } + for _, tt := range hdtests { + if got, want := res.header.Get(tt.k), tt.v; got != want { + t.Errorf("%v = %v; want %v", tt.k, got, want) + } + } + + if got, want := string(res.body), "Hello World from resp"; got != want { + t.Errorf("body = %v; want %v", got, want) + } +}