integration: Add test case for trailer part
This commit is contained in:
parent
b9d6fff962
commit
5c31c130bd
|
@ -211,6 +211,41 @@ func TestH1H1HTTP10NoHostRewrite(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestH1H1RequestTrailer tests request trailer part is forwarded to
|
||||
// backend.
|
||||
func TestH1H1RequestTrailer(t *testing.T) {
|
||||
st := newServerTester(nil, t, func(w http.ResponseWriter, r *http.Request) {
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
_, err := r.Body.Read(buf)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("r.Body.Read() = %v", err)
|
||||
}
|
||||
}
|
||||
if got, want := r.Trailer.Get("foo"), "bar"; got != want {
|
||||
t.Errorf("r.Trailer.Get(foo): %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
defer st.Close()
|
||||
|
||||
res, err := st.http1(requestParam{
|
||||
name: "TestH1H1RequestTrailer",
|
||||
body: []byte("1"),
|
||||
trailer: []hpack.HeaderField{
|
||||
pair("foo", "bar"),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Error st.http1() = %v", err)
|
||||
}
|
||||
if got, want := res.status, 200; got != want {
|
||||
t.Errorf("res.status: %v; want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestH1H2ConnectFailure tests that server handles the situation that
|
||||
// connection attempt to HTTP/2 backend failed.
|
||||
func TestH1H2ConnectFailure(t *testing.T) {
|
||||
|
|
|
@ -520,6 +520,41 @@ func TestH2H1ServerPush(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestH2H1RequestTrailer tests request trailer part is forwarded to
|
||||
// backend.
|
||||
func TestH2H1RequestTrailer(t *testing.T) {
|
||||
st := newServerTester(nil, t, func(w http.ResponseWriter, r *http.Request) {
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
_, err := r.Body.Read(buf)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("r.Body.Read() = %v", err)
|
||||
}
|
||||
}
|
||||
if got, want := r.Trailer.Get("foo"), "bar"; got != want {
|
||||
t.Errorf("r.Trailer.Get(foo): %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
defer st.Close()
|
||||
|
||||
res, err := st.http2(requestParam{
|
||||
name: "TestH2H1RequestTrailer",
|
||||
body: []byte("1"),
|
||||
trailer: []hpack.HeaderField{
|
||||
pair("foo", "bar"),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Error st.http2() = %v", err)
|
||||
}
|
||||
if got, want := res.status, 200; got != want {
|
||||
t.Errorf("res.status: %v; want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestH2H1GracefulShutdown tests graceful shutdown.
|
||||
func TestH2H1GracefulShutdown(t *testing.T) {
|
||||
st := newServerTester(nil, t, noopHandler)
|
||||
|
|
|
@ -255,6 +255,27 @@ type requestParam struct {
|
|||
path string // path, defaults to /
|
||||
header []hpack.HeaderField // additional request header fields
|
||||
body []byte // request body
|
||||
trailer []hpack.HeaderField // trailer part
|
||||
}
|
||||
|
||||
// wrapper for request body to set trailer part
|
||||
type chunkedBodyReader struct {
|
||||
trailer []hpack.HeaderField
|
||||
trailerWritten bool
|
||||
body io.Reader
|
||||
req *http.Request
|
||||
}
|
||||
|
||||
func (cbr *chunkedBodyReader) Read(p []byte) (n int, err error) {
|
||||
// document says that we have to set http.Request.Trailer
|
||||
// after request was sent and before body returns EOF.
|
||||
if !cbr.trailerWritten {
|
||||
cbr.trailerWritten = true
|
||||
for _, h := range cbr.trailer {
|
||||
cbr.req.Trailer.Set(h.Name, h.Value)
|
||||
}
|
||||
}
|
||||
return cbr.body.Read(p)
|
||||
}
|
||||
|
||||
func (st *serverTester) http1(rp requestParam) (*serverResponse, error) {
|
||||
|
@ -264,8 +285,16 @@ func (st *serverTester) http1(rp requestParam) (*serverResponse, error) {
|
|||
}
|
||||
|
||||
var body io.Reader
|
||||
var cbr *chunkedBodyReader
|
||||
if rp.body != nil {
|
||||
body = bytes.NewBuffer(rp.body)
|
||||
if len(rp.trailer) != 0 {
|
||||
cbr = &chunkedBodyReader{
|
||||
trailer: rp.trailer,
|
||||
body: body,
|
||||
}
|
||||
body = cbr
|
||||
}
|
||||
}
|
||||
req, err := http.NewRequest(method, st.url, body)
|
||||
if err != nil {
|
||||
|
@ -275,7 +304,15 @@ func (st *serverTester) http1(rp requestParam) (*serverResponse, error) {
|
|||
req.Header.Add(h.Name, h.Value)
|
||||
}
|
||||
req.Header.Add("Test-Case", rp.name)
|
||||
|
||||
if cbr != nil {
|
||||
cbr.req = req
|
||||
// this makes request use chunked encoding
|
||||
req.ContentLength = -1
|
||||
req.Trailer = make(http.Header)
|
||||
for _, h := range cbr.trailer {
|
||||
req.Trailer.Set(h.Name, "")
|
||||
}
|
||||
}
|
||||
if err := req.Write(st.conn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -473,7 +510,7 @@ func (st *serverTester) http2(rp requestParam) (*serverResponse, error) {
|
|||
|
||||
err := st.fr.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: id,
|
||||
EndStream: len(rp.body) == 0,
|
||||
EndStream: len(rp.body) == 0 && len(rp.trailer) == 0,
|
||||
EndHeaders: true,
|
||||
BlockFragment: st.headerBlkBuf.Bytes(),
|
||||
})
|
||||
|
@ -483,7 +520,23 @@ func (st *serverTester) http2(rp requestParam) (*serverResponse, error) {
|
|||
|
||||
if len(rp.body) != 0 {
|
||||
// TODO we assume rp.body fits in 1 frame
|
||||
if err := st.fr.WriteData(id, true, rp.body); err != nil {
|
||||
if err := st.fr.WriteData(id, len(rp.trailer) == 0, rp.body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(rp.trailer) != 0 {
|
||||
st.headerBlkBuf.Reset()
|
||||
for _, h := range rp.trailer {
|
||||
_ = st.enc.WriteField(h)
|
||||
}
|
||||
err := st.fr.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: id,
|
||||
EndStream: true,
|
||||
EndHeaders: true,
|
||||
BlockFragment: st.headerBlkBuf.Bytes(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue