diff --git a/integration-tests/nghttpx_http1_test.go b/integration-tests/nghttpx_http1_test.go index d93877a1..a57ae90c 100644 --- a/integration-tests/nghttpx_http1_test.go +++ b/integration-tests/nghttpx_http1_test.go @@ -961,3 +961,30 @@ func TestH1Healthmon(t *testing.T) { t.Errorf("res.status: %v; want %v", got, want) } } + +// TestH1ResponseBeforeRequestEnd tests the situation where response +// ends before request body finishes. +func TestH1ResponseBeforeRequestEnd(t *testing.T) { + st := newServerTester([]string{"--mruby-file=" + testDir + "/req-return.rb"}, t, func(w http.ResponseWriter, r *http.Request) { + t.Fatal("request should not be forwarded") + }) + defer st.Close() + + if _, err := io.WriteString(st.conn, fmt.Sprintf(`POST / HTTP/1.1 +Host: %v +Test-Case: TestH1ResponseBeforeRequestEnd +Content-Length: 1000000 + +`, st.authority)); err != nil { + t.Fatalf("Error io.WriteString() = %v", err) + } + + resp, err := http.ReadResponse(bufio.NewReader(st.conn), nil) + if err != nil { + t.Fatalf("Error http.ReadResponse() = %v", err) + } + + if got, want := resp.StatusCode, 404; got != want { + t.Errorf("status: %v; want %v", got, want) + } +} diff --git a/integration-tests/nghttpx_http2_test.go b/integration-tests/nghttpx_http2_test.go index 86036865..2c5e58bd 100644 --- a/integration-tests/nghttpx_http2_test.go +++ b/integration-tests/nghttpx_http2_test.go @@ -2011,3 +2011,23 @@ func TestH2Healthmon(t *testing.T) { t.Errorf("res.status: %v; want %v", got, want) } } + +// TestH2ResponseBeforeRequestEnd tests the situation where response +// ends before request body finishes. +func TestH2ResponseBeforeRequestEnd(t *testing.T) { + st := newServerTester([]string{"--mruby-file=" + testDir + "/req-return.rb"}, t, func(w http.ResponseWriter, r *http.Request) { + t.Fatal("request should not be forwarded") + }) + defer st.Close() + + res, err := st.http2(requestParam{ + name: "TestH2ResponseBeforeRequestEnd", + noEndStream: true, + }) + if err != nil { + t.Fatalf("Error st.http2() = %v", err) + } + if got, want := res.status, 404; got != want { + t.Errorf("res.status: %v; want %v", got, want) + } +} diff --git a/integration-tests/nghttpx_spdy_test.go b/integration-tests/nghttpx_spdy_test.go index 4bb3b0f1..b54d020b 100644 --- a/integration-tests/nghttpx_spdy_test.go +++ b/integration-tests/nghttpx_spdy_test.go @@ -642,3 +642,23 @@ func TestS3Healthmon(t *testing.T) { t.Errorf("res.status: %v; want %v", got, want) } } + +// TestS3ResponseBeforeRequestEnd tests the situation where response +// ends before request body finishes. +func TestS3ResponseBeforeRequestEnd(t *testing.T) { + st := newServerTesterTLS([]string{"--npn-list=spdy/3.1", "--mruby-file=" + testDir + "/req-return.rb"}, t, func(w http.ResponseWriter, r *http.Request) { + t.Fatal("request should not be forwarded") + }) + defer st.Close() + + res, err := st.spdy(requestParam{ + name: "TestS3ResponseBeforeRequestEnd", + noEndStream: true, + }) + if err != nil { + t.Fatalf("Error st.spdy() = %v", err) + } + if got, want := res.status, 404; got != want { + t.Errorf("res.status: %v; want %v", got, want) + } +} diff --git a/integration-tests/server_tester.go b/integration-tests/server_tester.go index ae80861e..bccf314b 100644 --- a/integration-tests/server_tester.go +++ b/integration-tests/server_tester.go @@ -169,6 +169,8 @@ func newServerTesterInternal(src_args []string, t *testing.T, handler http.Handl retry := 0 for { + time.Sleep(50 * time.Millisecond) + var conn net.Conn var err error if frontendTLS { @@ -190,7 +192,6 @@ func newServerTesterInternal(src_args []string, t *testing.T, handler http.Handl st.Close() st.t.Fatalf("Error server is not responding too long; server command-line arguments may be invalid") } - time.Sleep(150 * time.Millisecond) continue } if frontendTLS { @@ -296,6 +297,7 @@ type requestParam struct { body []byte // request body trailer []hpack.HeaderField // trailer part httpUpgrade bool // true if upgraded to HTTP/2 through HTTP Upgrade + noEndStream bool // true if END_STREAM should not be sent } // wrapper for request body to set trailer part @@ -470,7 +472,7 @@ func (st *serverTester) spdy(rp requestParam) (*serverResponse, error) { } var synStreamFlags spdy.ControlFlags - if len(rp.body) == 0 { + if len(rp.body) == 0 && !rp.noEndStream { synStreamFlags = spdy.ControlFlagFin } if err := st.spdyFr.WriteFrame(&spdy.SynStreamFrame{ @@ -484,9 +486,13 @@ func (st *serverTester) spdy(rp requestParam) (*serverResponse, error) { } if len(rp.body) != 0 { + var dataFlags spdy.DataFlags + if !rp.noEndStream { + dataFlags = spdy.DataFlagFin + } if err := st.spdyFr.WriteFrame(&spdy.DataFrame{ StreamId: id, - Flags: spdy.DataFlagFin, + Flags: dataFlags, Data: rp.body, }); err != nil { return nil, err @@ -599,7 +605,7 @@ func (st *serverTester) http2(rp requestParam) (*serverResponse, error) { err := st.fr.WriteHeaders(http2.HeadersFrameParam{ StreamID: id, - EndStream: len(rp.body) == 0 && len(rp.trailer) == 0, + EndStream: len(rp.body) == 0 && len(rp.trailer) == 0 && !rp.noEndStream, EndHeaders: true, BlockFragment: st.headerBlkBuf.Bytes(), }) @@ -609,7 +615,7 @@ func (st *serverTester) http2(rp requestParam) (*serverResponse, error) { if len(rp.body) != 0 { // TODO we assume rp.body fits in 1 frame - if err := st.fr.WriteData(id, len(rp.trailer) == 0, rp.body); err != nil { + if err := st.fr.WriteData(id, len(rp.trailer) == 0 && !rp.noEndStream, rp.body); err != nil { return nil, err } }