integration: Add test case for trailer part

This commit is contained in:
Tatsuhiro Tsujikawa 2015-03-08 19:31:43 +09:00
parent b9d6fff962
commit 5c31c130bd
3 changed files with 126 additions and 3 deletions

View File

@ -211,6 +211,41 @@ func TestH1H1HTTP10NoHostRewrite(t *testing.T) {
}
}
// TestH1H1RequestTrailer tests request trailer part is forwarded to
// backend.
func TestH1H1RequestTrailer(t *testing.T) {
st := newServerTester(nil, t, func(w http.ResponseWriter, r *http.Request) {
buf := make([]byte, 4096)
for {
_, err := r.Body.Read(buf)
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("r.Body.Read() = %v", err)
}
}
if got, want := r.Trailer.Get("foo"), "bar"; got != want {
t.Errorf("r.Trailer.Get(foo): %v; want %v", got, want)
}
})
defer st.Close()
res, err := st.http1(requestParam{
name: "TestH1H1RequestTrailer",
body: []byte("1"),
trailer: []hpack.HeaderField{
pair("foo", "bar"),
},
})
if err != nil {
t.Fatalf("Error st.http1() = %v", err)
}
if got, want := res.status, 200; got != want {
t.Errorf("res.status: %v; want %v", got, want)
}
}
// TestH1H2ConnectFailure tests that server handles the situation that
// connection attempt to HTTP/2 backend failed.
func TestH1H2ConnectFailure(t *testing.T) {

View File

@ -520,6 +520,41 @@ func TestH2H1ServerPush(t *testing.T) {
}
}
// TestH2H1RequestTrailer tests request trailer part is forwarded to
// backend.
func TestH2H1RequestTrailer(t *testing.T) {
st := newServerTester(nil, t, func(w http.ResponseWriter, r *http.Request) {
buf := make([]byte, 4096)
for {
_, err := r.Body.Read(buf)
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("r.Body.Read() = %v", err)
}
}
if got, want := r.Trailer.Get("foo"), "bar"; got != want {
t.Errorf("r.Trailer.Get(foo): %v; want %v", got, want)
}
})
defer st.Close()
res, err := st.http2(requestParam{
name: "TestH2H1RequestTrailer",
body: []byte("1"),
trailer: []hpack.HeaderField{
pair("foo", "bar"),
},
})
if err != nil {
t.Fatalf("Error st.http2() = %v", err)
}
if got, want := res.status, 200; got != want {
t.Errorf("res.status: %v; want %v", got, want)
}
}
// TestH2H1GracefulShutdown tests graceful shutdown.
func TestH2H1GracefulShutdown(t *testing.T) {
st := newServerTester(nil, t, noopHandler)

View File

@ -255,6 +255,27 @@ type requestParam struct {
path string // path, defaults to /
header []hpack.HeaderField // additional request header fields
body []byte // request body
trailer []hpack.HeaderField // trailer part
}
// wrapper for request body to set trailer part
type chunkedBodyReader struct {
trailer []hpack.HeaderField
trailerWritten bool
body io.Reader
req *http.Request
}
func (cbr *chunkedBodyReader) Read(p []byte) (n int, err error) {
// document says that we have to set http.Request.Trailer
// after request was sent and before body returns EOF.
if !cbr.trailerWritten {
cbr.trailerWritten = true
for _, h := range cbr.trailer {
cbr.req.Trailer.Set(h.Name, h.Value)
}
}
return cbr.body.Read(p)
}
func (st *serverTester) http1(rp requestParam) (*serverResponse, error) {
@ -264,8 +285,16 @@ func (st *serverTester) http1(rp requestParam) (*serverResponse, error) {
}
var body io.Reader
var cbr *chunkedBodyReader
if rp.body != nil {
body = bytes.NewBuffer(rp.body)
if len(rp.trailer) != 0 {
cbr = &chunkedBodyReader{
trailer: rp.trailer,
body: body,
}
body = cbr
}
}
req, err := http.NewRequest(method, st.url, body)
if err != nil {
@ -275,7 +304,15 @@ func (st *serverTester) http1(rp requestParam) (*serverResponse, error) {
req.Header.Add(h.Name, h.Value)
}
req.Header.Add("Test-Case", rp.name)
if cbr != nil {
cbr.req = req
// this makes request use chunked encoding
req.ContentLength = -1
req.Trailer = make(http.Header)
for _, h := range cbr.trailer {
req.Trailer.Set(h.Name, "")
}
}
if err := req.Write(st.conn); err != nil {
return nil, err
}
@ -473,7 +510,7 @@ func (st *serverTester) http2(rp requestParam) (*serverResponse, error) {
err := st.fr.WriteHeaders(http2.HeadersFrameParam{
StreamID: id,
EndStream: len(rp.body) == 0,
EndStream: len(rp.body) == 0 && len(rp.trailer) == 0,
EndHeaders: true,
BlockFragment: st.headerBlkBuf.Bytes(),
})
@ -483,7 +520,23 @@ func (st *serverTester) http2(rp requestParam) (*serverResponse, error) {
if len(rp.body) != 0 {
// TODO we assume rp.body fits in 1 frame
if err := st.fr.WriteData(id, true, rp.body); err != nil {
if err := st.fr.WriteData(id, len(rp.trailer) == 0, rp.body); err != nil {
return nil, err
}
}
if len(rp.trailer) != 0 {
st.headerBlkBuf.Reset()
for _, h := range rp.trailer {
_ = st.enc.WriteField(h)
}
err := st.fr.WriteHeaders(http2.HeadersFrameParam{
StreamID: id,
EndStream: true,
EndHeaders: true,
BlockFragment: st.headerBlkBuf.Bytes(),
})
if err != nil {
return nil, err
}
}