diff --git a/integration-tests/nghttpx_http1_test.go b/integration-tests/nghttpx_http1_test.go index b84e3059..80d3b1ca 100644 --- a/integration-tests/nghttpx_http1_test.go +++ b/integration-tests/nghttpx_http1_test.go @@ -211,6 +211,41 @@ func TestH1H1HTTP10NoHostRewrite(t *testing.T) { } } +// TestH1H1RequestTrailer tests request trailer part is forwarded to +// backend. +func TestH1H1RequestTrailer(t *testing.T) { + st := newServerTester(nil, t, func(w http.ResponseWriter, r *http.Request) { + buf := make([]byte, 4096) + for { + _, err := r.Body.Read(buf) + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("r.Body.Read() = %v", err) + } + } + if got, want := r.Trailer.Get("foo"), "bar"; got != want { + t.Errorf("r.Trailer.Get(foo): %v; want %v", got, want) + } + }) + defer st.Close() + + res, err := st.http1(requestParam{ + name: "TestH1H1RequestTrailer", + body: []byte("1"), + trailer: []hpack.HeaderField{ + pair("foo", "bar"), + }, + }) + if err != nil { + t.Fatalf("Error st.http1() = %v", err) + } + if got, want := res.status, 200; got != want { + t.Errorf("res.status: %v; want %v", got, want) + } +} + // TestH1H2ConnectFailure tests that server handles the situation that // connection attempt to HTTP/2 backend failed. func TestH1H2ConnectFailure(t *testing.T) { diff --git a/integration-tests/nghttpx_http2_test.go b/integration-tests/nghttpx_http2_test.go index 6d9edf7f..fc1daf54 100644 --- a/integration-tests/nghttpx_http2_test.go +++ b/integration-tests/nghttpx_http2_test.go @@ -520,6 +520,41 @@ func TestH2H1ServerPush(t *testing.T) { } } +// TestH2H1RequestTrailer tests request trailer part is forwarded to +// backend. +func TestH2H1RequestTrailer(t *testing.T) { + st := newServerTester(nil, t, func(w http.ResponseWriter, r *http.Request) { + buf := make([]byte, 4096) + for { + _, err := r.Body.Read(buf) + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("r.Body.Read() = %v", err) + } + } + if got, want := r.Trailer.Get("foo"), "bar"; got != want { + t.Errorf("r.Trailer.Get(foo): %v; want %v", got, want) + } + }) + defer st.Close() + + res, err := st.http2(requestParam{ + name: "TestH2H1RequestTrailer", + body: []byte("1"), + trailer: []hpack.HeaderField{ + pair("foo", "bar"), + }, + }) + if err != nil { + t.Fatalf("Error st.http2() = %v", err) + } + if got, want := res.status, 200; got != want { + t.Errorf("res.status: %v; want %v", got, want) + } +} + // TestH2H1GracefulShutdown tests graceful shutdown. func TestH2H1GracefulShutdown(t *testing.T) { st := newServerTester(nil, t, noopHandler) diff --git a/integration-tests/server_tester.go b/integration-tests/server_tester.go index b8a50300..58a3051a 100644 --- a/integration-tests/server_tester.go +++ b/integration-tests/server_tester.go @@ -255,6 +255,27 @@ type requestParam struct { path string // path, defaults to / header []hpack.HeaderField // additional request header fields body []byte // request body + trailer []hpack.HeaderField // trailer part +} + +// wrapper for request body to set trailer part +type chunkedBodyReader struct { + trailer []hpack.HeaderField + trailerWritten bool + body io.Reader + req *http.Request +} + +func (cbr *chunkedBodyReader) Read(p []byte) (n int, err error) { + // document says that we have to set http.Request.Trailer + // after request was sent and before body returns EOF. + if !cbr.trailerWritten { + cbr.trailerWritten = true + for _, h := range cbr.trailer { + cbr.req.Trailer.Set(h.Name, h.Value) + } + } + return cbr.body.Read(p) } func (st *serverTester) http1(rp requestParam) (*serverResponse, error) { @@ -264,8 +285,16 @@ func (st *serverTester) http1(rp requestParam) (*serverResponse, error) { } var body io.Reader + var cbr *chunkedBodyReader if rp.body != nil { body = bytes.NewBuffer(rp.body) + if len(rp.trailer) != 0 { + cbr = &chunkedBodyReader{ + trailer: rp.trailer, + body: body, + } + body = cbr + } } req, err := http.NewRequest(method, st.url, body) if err != nil { @@ -275,7 +304,15 @@ func (st *serverTester) http1(rp requestParam) (*serverResponse, error) { req.Header.Add(h.Name, h.Value) } req.Header.Add("Test-Case", rp.name) - + if cbr != nil { + cbr.req = req + // this makes request use chunked encoding + req.ContentLength = -1 + req.Trailer = make(http.Header) + for _, h := range cbr.trailer { + req.Trailer.Set(h.Name, "") + } + } if err := req.Write(st.conn); err != nil { return nil, err } @@ -473,7 +510,7 @@ func (st *serverTester) http2(rp requestParam) (*serverResponse, error) { err := st.fr.WriteHeaders(http2.HeadersFrameParam{ StreamID: id, - EndStream: len(rp.body) == 0, + EndStream: len(rp.body) == 0 && len(rp.trailer) == 0, EndHeaders: true, BlockFragment: st.headerBlkBuf.Bytes(), }) @@ -483,7 +520,23 @@ func (st *serverTester) http2(rp requestParam) (*serverResponse, error) { if len(rp.body) != 0 { // TODO we assume rp.body fits in 1 frame - if err := st.fr.WriteData(id, true, rp.body); err != nil { + if err := st.fr.WriteData(id, len(rp.trailer) == 0, rp.body); err != nil { + return nil, err + } + } + + if len(rp.trailer) != 0 { + st.headerBlkBuf.Reset() + for _, h := range rp.trailer { + _ = st.enc.WriteField(h) + } + err := st.fr.WriteHeaders(http2.HeadersFrameParam{ + StreamID: id, + EndStream: true, + EndHeaders: true, + BlockFragment: st.headerBlkBuf.Bytes(), + }) + if err != nil { return nil, err } }