diff --git a/integration-tests/nghttpx_http2_test.go b/integration-tests/nghttpx_http2_test.go index 909816c3..89cd353a 100644 --- a/integration-tests/nghttpx_http2_test.go +++ b/integration-tests/nghttpx_http2_test.go @@ -558,6 +558,39 @@ func TestH2H1RequestTrailer(t *testing.T) { } } +// TestH2H1Upgrade tests HTTP Upgrade to HTTP/2 +func TestH2H1Upgrade(t *testing.T) { + st := newServerTester(nil, t, func(w http.ResponseWriter, r *http.Request) {}) + defer st.Close() + + res, err := st.http1(requestParam{ + name: "TestH2H1Upgrade", + header: []hpack.HeaderField{ + pair("Connection", "Upgrade, HTTP2-Settings"), + pair("Upgrade", "h2c-14"), + pair("HTTP2-Settings", "AAMAAABkAAQAAP__"), + }, + }) + + if err != nil { + t.Fatalf("Error st.http1() = %v", err) + } + + if got, want := res.status, 101; got != want { + t.Errorf("res.status: %v; want %v", got, want) + } + + res, err = st.http2(requestParam{ + httpUpgrade: true, + }) + if err != nil { + t.Fatalf("Error st.http2() = %v", err) + } + if got, want := res.status, 200; got != want { + t.Errorf("res.status: %v; want %v", got, want) + } +} + // TestH2H1GracefulShutdown tests graceful shutdown. func TestH2H1GracefulShutdown(t *testing.T) { st := newServerTester(nil, t, noopHandler) diff --git a/integration-tests/server_tester.go b/integration-tests/server_tester.go index 58a3051a..3a4b5568 100644 --- a/integration-tests/server_tester.go +++ b/integration-tests/server_tester.go @@ -247,15 +247,16 @@ func (st *serverTester) readSpdyFrame() (spdy.Frame, error) { } type requestParam struct { - name string // name for this request to identify the request in log easily - streamID uint32 // stream ID, automatically assigned if 0 - method string // method, defaults to GET - scheme string // scheme, defaults to http - authority string // authority, defaults to backend server address - path string // path, defaults to / - header []hpack.HeaderField // additional request header fields - body []byte // request body - trailer []hpack.HeaderField // trailer part + name string // name for this request to identify the request in log easily + streamID uint32 // stream ID, automatically assigned if 0 + method string // method, defaults to GET + scheme string // scheme, defaults to http + authority string // authority, defaults to backend server address + path string // path, defaults to / + header []hpack.HeaderField // additional request header fields + body []byte // request body + trailer []hpack.HeaderField // trailer part + httpUpgrade bool // true if upgraded to HTTP/2 through HTTP Upgrade } // wrapper for request body to set trailer part @@ -478,69 +479,70 @@ func (st *serverTester) http2(rp requestParam) (*serverResponse, error) { streams := make(map[uint32]*serverResponse) streams[id] = res - method := "GET" - if rp.method != "" { - method = rp.method - } - _ = st.enc.WriteField(pair(":method", method)) - - scheme := "http" - if rp.scheme != "" { - scheme = rp.scheme - } - _ = st.enc.WriteField(pair(":scheme", scheme)) - - authority := st.authority - if rp.authority != "" { - authority = rp.authority - } - _ = st.enc.WriteField(pair(":authority", authority)) - - path := "/" - if rp.path != "" { - path = rp.path - } - _ = st.enc.WriteField(pair(":path", path)) - - _ = st.enc.WriteField(pair("test-case", rp.name)) - - for _, h := range rp.header { - _ = st.enc.WriteField(h) - } - - err := st.fr.WriteHeaders(http2.HeadersFrameParam{ - StreamID: id, - EndStream: len(rp.body) == 0 && len(rp.trailer) == 0, - EndHeaders: true, - BlockFragment: st.headerBlkBuf.Bytes(), - }) - if err != nil { - return nil, err - } - - if len(rp.body) != 0 { - // TODO we assume rp.body fits in 1 frame - if err := st.fr.WriteData(id, len(rp.trailer) == 0, rp.body); err != nil { - return nil, err + if !rp.httpUpgrade { + method := "GET" + if rp.method != "" { + method = rp.method } - } + _ = st.enc.WriteField(pair(":method", method)) - if len(rp.trailer) != 0 { - st.headerBlkBuf.Reset() - for _, h := range rp.trailer { + scheme := "http" + if rp.scheme != "" { + scheme = rp.scheme + } + _ = st.enc.WriteField(pair(":scheme", scheme)) + + authority := st.authority + if rp.authority != "" { + authority = rp.authority + } + _ = st.enc.WriteField(pair(":authority", authority)) + + path := "/" + if rp.path != "" { + path = rp.path + } + _ = st.enc.WriteField(pair(":path", path)) + + _ = st.enc.WriteField(pair("test-case", rp.name)) + + for _, h := range rp.header { _ = st.enc.WriteField(h) } + err := st.fr.WriteHeaders(http2.HeadersFrameParam{ StreamID: id, - EndStream: true, + EndStream: len(rp.body) == 0 && len(rp.trailer) == 0, EndHeaders: true, BlockFragment: st.headerBlkBuf.Bytes(), }) if err != nil { return nil, err } - } + if len(rp.body) != 0 { + // TODO we assume rp.body fits in 1 frame + if err := st.fr.WriteData(id, len(rp.trailer) == 0, rp.body); err != nil { + return nil, err + } + } + + if len(rp.trailer) != 0 { + st.headerBlkBuf.Reset() + for _, h := range rp.trailer { + _ = st.enc.WriteField(h) + } + err := st.fr.WriteHeaders(http2.HeadersFrameParam{ + StreamID: id, + EndStream: true, + EndHeaders: true, + BlockFragment: st.headerBlkBuf.Bytes(), + }) + if err != nil { + return nil, err + } + } + } loop: for { fr, err := st.readFrame()