diff --git a/integration-tests/nghttpx_http2_test.go b/integration-tests/nghttpx_http2_test.go index ee7fbf69..29e5cc82 100644 --- a/integration-tests/nghttpx_http2_test.go +++ b/integration-tests/nghttpx_http2_test.go @@ -491,6 +491,37 @@ func TestH2H1SNI(t *testing.T) { } } +// TestH2H1ServerPush tests server push using Link header field from +// backend server. +func TestH2H1ServerPush(t *testing.T) { + st := newServerTester(nil, t, func(w http.ResponseWriter, r *http.Request) { + // only resources marked as rel=preload are pushed + w.Header().Add("Link", "; rel=preload, , ; rel=preload") + }) + defer st.Close() + + res, err := st.http2(requestParam{ + name: "TestH2H1ServerPush", + }) + if err != nil { + t.Fatalf("Error st.http2() = %v", err) + } + if got, want := res.status, 200; got != want { + t.Errorf("res.status: %v; want %v", got, want) + } + if got, want := len(res.pushResponse), 2; got != want { + t.Fatalf("len(res.pushResponse): %v; want %v", got, want) + } + mainCSS := res.pushResponse[0] + if got, want := mainCSS.status, 200; got != want { + t.Errorf("mainCSS.status: %v; want %v", got, want) + } + themeCSS := res.pushResponse[1] + if got, want := themeCSS.status, 200; got != want { + t.Errorf("themeCSS.status: %v; want %v", got, want) + } +} + // TestH2H1GracefulShutdown tests graceful shutdown. func TestH2H1GracefulShutdown(t *testing.T) { st := newServerTester(nil, t, noopHandler) diff --git a/integration-tests/server_tester.go b/integration-tests/server_tester.go index 93f421eb..b8a50300 100644 --- a/integration-tests/server_tester.go +++ b/integration-tests/server_tester.go @@ -17,6 +17,7 @@ import ( "net/http/httptest" "net/url" "os/exec" + "sort" "strconv" "strings" "testing" @@ -411,7 +412,6 @@ loop: } func (st *serverTester) http2(rp requestParam) (*serverResponse, error) { - res := &serverResponse{} st.headerBlkBuf.Reset() st.header = make(http.Header) @@ -434,6 +434,13 @@ func (st *serverTester) http2(rp requestParam) (*serverResponse, error) { } } + res := &serverResponse{ + streamID: id, + } + + streams := make(map[uint32]*serverResponse) + streams[id] = res + method := "GET" if rp.method != "" { method = rp.method @@ -493,34 +500,53 @@ loop: if err != nil { return res, err } - if f.FrameHeader.StreamID != id { + sr, ok := streams[f.FrameHeader.StreamID] + if !ok { st.header = make(http.Header) break } - res.header = cloneHeader(st.header) + sr.header = cloneHeader(st.header) var status int - status, err = strconv.Atoi(res.header.Get(":status")) + status, err = strconv.Atoi(sr.header.Get(":status")) if err != nil { return res, fmt.Errorf("Error parsing status code: %v", err) } - res.status = status + sr.status = status if f.StreamEnded() { - break loop + if streamEnded(res, streams, sr) { + break loop + } } + case *http2.PushPromiseFrame: + _, err := st.dec.Write(f.HeaderBlockFragment()) + if err != nil { + return res, err + } + sr := &serverResponse{ + streamID: f.PromiseID, + reqHeader: cloneHeader(st.header), + } + streams[sr.streamID] = sr case *http2.DataFrame: - if f.FrameHeader.StreamID != id { + sr, ok := streams[f.FrameHeader.StreamID] + if !ok { break } - res.body = append(res.body, f.Data()...) + sr.body = append(sr.body, f.Data()...) if f.StreamEnded() { - break loop + if streamEnded(res, streams, sr) { + break loop + } } case *http2.RSTStreamFrame: - if f.FrameHeader.StreamID != id { + sr, ok := streams[f.FrameHeader.StreamID] + if !ok { break } - res.errCode = f.ErrCode - break loop + sr.errCode = f.ErrCode + if streamEnded(res, streams, sr) { + break loop + } case *http2.GoAwayFrame: if f.ErrCode == http2.ErrCodeNo { break @@ -535,21 +561,46 @@ loop: if err := st.fr.WriteSettingsAck(); err != nil { return res, err } - // TODO handle PUSH_PROMISE as well, since it alters HPACK context } } + sort.Sort(ByStreamID(res.pushResponse)) return res, nil } +func streamEnded(mainSr *serverResponse, streams map[uint32]*serverResponse, sr *serverResponse) bool { + delete(streams, sr.streamID) + if mainSr.streamID != sr.streamID { + mainSr.pushResponse = append(mainSr.pushResponse, sr) + } + return len(streams) == 0 +} + type serverResponse struct { status int // HTTP status code header http.Header // response header fields body []byte // response body + streamID uint32 // stream ID in HTTP/2 errCode http2.ErrCode // error code received in HTTP/2 RST_STREAM or GOAWAY connErr bool // true if HTTP/2 connection error spdyGoAwayErrCode spdy.GoAwayStatus // status code received in SPDY RST_STREAM spdyRstErrCode spdy.RstStreamStatus // status code received in SPDY GOAWAY connClose bool // Conection: close is included in response header in HTTP/1 test + reqHeader http.Header // http request header, currently only sotres pushed request header + pushResponse []*serverResponse // pushed response +} + +type ByStreamID []*serverResponse + +func (b ByStreamID) Len() int { + return len(b) +} + +func (b ByStreamID) Swap(i, j int) { + b[i], b[j] = b[j], b[i] +} + +func (b ByStreamID) Less(i, j int) bool { + return b[i].streamID < b[j].streamID } func cloneHeader(h http.Header) http.Header {