From 16e91746d99b3eb78519ec64a74b63b5d421582b Mon Sep 17 00:00:00 2001 From: Tatsuhiro Tsujikawa Date: Wed, 21 Jan 2015 01:03:56 +0900 Subject: [PATCH] nghttpx: Return 400 error if multiple CLs are received in SPDY upstream This change adds SPDY upstream tests. --- integration-tests/nghttpx_test.go | 85 +++++++++++ integration-tests/server_tester.go | 236 ++++++++++++++++++++++++++--- src/shrpx_spdy_upstream.cc | 23 +-- 3 files changed, 310 insertions(+), 34 deletions(-) diff --git a/integration-tests/nghttpx_test.go b/integration-tests/nghttpx_test.go index 7d75d6f8..76ead2e8 100644 --- a/integration-tests/nghttpx_test.go +++ b/integration-tests/nghttpx_test.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/bradfitz/http2" "github.com/bradfitz/http2/hpack" + "golang.org/x/net/spdy" "io" "io/ioutil" "net/http" @@ -385,3 +386,87 @@ func TestH2H2InvalidResponseCL(t *testing.T) { t.Errorf("status: %v; want %v", got, want) } } + +func TestS3H1PlainGET(t *testing.T) { + st := newServerTesterTLS([]string{"--npn-list=spdy/3.1"}, t, noopHandler) + defer st.Close() + + res, err := st.spdy(requestParam{ + name: "TestS3H1PlainGET", + }) + if err != nil { + t.Fatalf("Error st.spdy() = %v", err) + } + + want := 200 + if got := res.status; got != want { + t.Errorf("status = %v; want %v", got, want) + } +} + +func TestS3H1BadRequestCL(t *testing.T) { + st := newServerTesterTLS([]string{"--npn-list=spdy/3.1"}, t, noopHandler) + defer st.Close() + + // we set content-length: 1024, but the actual request body is + // 3 bytes. + res, err := st.spdy(requestParam{ + name: "TestS3H1BadRequestCL", + method: "POST", + header: []hpack.HeaderField{ + pair("content-length", "1024"), + }, + body: []byte("foo"), + }) + if err != nil { + t.Fatalf("Error st.spdy() = %v", err) + } + + want := spdy.ProtocolError + if got := res.spdyRstErrCode; got != want { + t.Errorf("res.spdyRstErrCode = %v; want %v", got, want) + } +} + +func TestS3H1MultipleRequestCL(t *testing.T) { + st := newServerTesterTLS([]string{"--npn-list=spdy/3.1"}, t, func(w http.ResponseWriter, r *http.Request) { + t.Errorf("server should not forward bad request") + }) + defer st.Close() + + res, err := st.spdy(requestParam{ + name: "TestS3H1MultipleRequestCL", + header: []hpack.HeaderField{ + pair("content-length", "1"), + pair("content-length", "2"), + }, + }) + if err != nil { + t.Fatalf("Error st.spdy() = %v", err) + } + want := 400 + if got := res.status; got != want { + t.Errorf("status: %v; want %v", got, want) + } +} + +func TestS3H1InvalidRequestCL(t *testing.T) { + st := newServerTesterTLS([]string{"--npn-list=spdy/3.1"}, t, func(w http.ResponseWriter, r *http.Request) { + t.Errorf("server should not forward bad request") + }) + defer st.Close() + + res, err := st.spdy(requestParam{ + name: "TestS3H1InvalidRequestCL", + header: []hpack.HeaderField{ + pair("content-length", ""), + }, + }) + if err != nil { + t.Fatalf("Error st.spdy() = %v", err) + } + want := 400 + if got := res.status; got != want { + t.Errorf("status: %v; want %v", got, want) + } +} diff --git a/integration-tests/server_tester.go b/integration-tests/server_tester.go index 6b859222..5d18d54e 100644 --- a/integration-tests/server_tester.go +++ b/integration-tests/server_tester.go @@ -9,6 +9,7 @@ import ( "github.com/bradfitz/http2" "github.com/bradfitz/http2/hpack" "github.com/tatsuhiro-t/go-nghttp2" + "golang.org/x/net/spdy" "io" "io/ioutil" "net" @@ -25,6 +26,7 @@ import ( const ( serverBin = buildDir + "/src/nghttpx" serverPort = 3009 + testDir = buildDir + "/integration-tests" ) func pair(name, value string) hpack.HeaderField { @@ -43,24 +45,39 @@ type serverTester struct { conn net.Conn // connection to frontend server h2PrefaceSent bool // HTTP/2 preface was sent in conn nextStreamID uint32 // next stream ID - fr *http2.Framer - headerBlkBuf bytes.Buffer // buffer to store encoded header block - enc *hpack.Encoder - header http.Header // received header fields - dec *hpack.Decoder - authority string // server's host:port - frCh chan http2.Frame + fr *http2.Framer // HTTP/2 framer + spdyFr *spdy.Framer // SPDY/3.1 framer + headerBlkBuf bytes.Buffer // buffer to store encoded header block + enc *hpack.Encoder // HTTP/2 HPACK encoder + header http.Header // received header fields + dec *hpack.Decoder // HTTP/2 HPACK decoder + authority string // server's host:port + frCh chan http2.Frame // used for incoming HTTP/2 frame + spdyFrCh chan spdy.Frame // used for incoming SPDY frame errCh chan error } +// newServerTester creates test context for plain TCP frontend +// connection. func newServerTester(args []string, t *testing.T, handler http.HandlerFunc) *serverTester { + return newServerTesterInternal(args, t, handler, false) +} + +// newServerTester creates test context for TLS frontend connection. +func newServerTesterTLS(args []string, t *testing.T, handler http.HandlerFunc) *serverTester { + return newServerTesterInternal(args, t, handler, true) +} + +// newServerTesterInternal creates test context. If frontendTLS is +// true, set up TLS frontend connection. +func newServerTesterInternal(args []string, t *testing.T, handler http.HandlerFunc, frontendTLS bool) *serverTester { ts := httptest.NewUnstartedServer(handler) backendTLS := false for _, k := range args { - if k == "--http2-bridge" { + switch k { + case "--http2-bridge": backendTLS = true - break } } if backendTLS { @@ -75,26 +92,36 @@ func newServerTester(args []string, t *testing.T, handler http.HandlerFunc) *ser } else { ts.Start() } - u, err := url.Parse(ts.URL) + scheme := "http" + if frontendTLS { + scheme = "https" + args = append(args, testDir+"/server.key", testDir+"/server.crt") + } else { + args = append(args, "--frontend-no-tls") + } + + backendURL, err := url.Parse(ts.URL) if err != nil { t.Fatalf("Error parsing URL from httptest.Server: %v", err) } // URL.Host looks like "127.0.0.1:8080", but we want // "127.0.0.1,8080" - b := "-b" + strings.Replace(u.Host, ":", ",", -1) + b := "-b" + strings.Replace(backendURL.Host, ":", ",", -1) args = append(args, fmt.Sprintf("-f127.0.0.1,%v", serverPort), b, - "--errorlog-file="+buildDir+"/integration-tests/log.txt", - "-LINFO", "--frontend-no-tls") + "--errorlog-file="+testDir+"/log.txt", "-LINFO") + + authority := fmt.Sprintf("127.0.0.1:%v", serverPort) st := &serverTester{ cmd: exec.Command(serverBin, args...), t: t, ts: ts, - url: fmt.Sprintf("http://127.0.0.1:%v", serverPort), + url: fmt.Sprintf("%v://%v", scheme, authority), nextStreamID: 1, - authority: u.Host, + authority: authority, frCh: make(chan http2.Frame), + spdyFrCh: make(chan spdy.Frame), errCh: make(chan error), } @@ -104,20 +131,45 @@ func newServerTester(args []string, t *testing.T, handler http.HandlerFunc) *ser retry := 0 for { - conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%v", serverPort)) + var conn net.Conn + var err error + if frontendTLS { + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h2-14", "spdy/3.1"}, + } + conn, err = tls.Dial("tcp", authority, tlsConfig) + } else { + conn, err = net.Dial("tcp", authority) + } if err != nil { retry += 1 if retry >= 100 { + st.Close() st.t.Fatalf("Error server is not responding too long; server command-line arguments may be invalid") } time.Sleep(150 * time.Millisecond) continue } + if frontendTLS { + tlsConn := conn.(*tls.Conn) + cs := tlsConn.ConnectionState() + if !cs.NegotiatedProtocolIsMutual { + st.Close() + st.t.Fatalf("Error negotiated next protocol is not mutual") + } + } st.conn = conn break } st.fr = http2.NewFramer(st.conn, st.conn) + spdyFr, err := spdy.NewFramer(st.conn, st.conn) + if err != nil { + st.Close() + st.t.Fatalf("Error spdy.NewFramer: %v", err) + } + st.spdyFr = spdyFr st.enc = hpack.NewEncoder(&st.headerBlkBuf) st.dec = hpack.NewDecoder(4096, func(f hpack.HeaderField) { st.header.Add(f.Name, f.Value) @@ -159,6 +211,26 @@ func (st *serverTester) readFrame() (http2.Frame, error) { } } +func (st *serverTester) readSpdyFrame() (spdy.Frame, error) { + go func() { + f, err := st.spdyFr.ReadFrame() + if err != nil { + st.errCh <- err + return + } + st.spdyFrCh <- f + }() + + select { + case f := <-st.spdyFrCh: + return f, nil + case err := <-st.errCh: + return nil, err + case <-time.After(2 * time.Second): + return nil, errors.New("timeout waiting for frame") + } +} + type requestParam struct { name string // name for this request to identify the request in log easily streamID uint32 // stream ID, automatically assigned if 0 @@ -211,6 +283,118 @@ func (st *serverTester) http1(rp requestParam) (*serverResponse, error) { return res, nil } +func (st *serverTester) spdy(rp requestParam) (*serverResponse, error) { + res := &serverResponse{} + + var id spdy.StreamId + if rp.streamID != 0 { + id = spdy.StreamId(rp.streamID) + if id >= spdy.StreamId(st.nextStreamID) && id%2 == 1 { + st.nextStreamID = uint32(id) + 2 + } + } else { + id = spdy.StreamId(st.nextStreamID) + st.nextStreamID += 2 + } + + method := "GET" + if rp.method != "" { + method = rp.method + } + + scheme := "http" + if rp.scheme != "" { + scheme = rp.scheme + } + + host := st.authority + if rp.authority != "" { + host = rp.authority + } + + path := "/" + if rp.path != "" { + path = rp.path + } + + header := make(http.Header) + header.Add(":method", method) + header.Add(":scheme", scheme) + header.Add(":host", host) + header.Add(":path", path) + header.Add(":version", "HTTP/1.1") + header.Add("test-case", rp.name) + for _, h := range rp.header { + header.Add(h.Name, h.Value) + } + + var synStreamFlags spdy.ControlFlags + if len(rp.body) == 0 { + synStreamFlags = spdy.ControlFlagFin + } + if err := st.spdyFr.WriteFrame(&spdy.SynStreamFrame{ + CFHeader: spdy.ControlFrameHeader{ + Flags: synStreamFlags, + }, + StreamId: id, + Headers: header, + }); err != nil { + return nil, err + } + + if len(rp.body) != 0 { + if err := st.spdyFr.WriteFrame(&spdy.DataFrame{ + StreamId: id, + Flags: spdy.DataFlagFin, + Data: rp.body, + }); err != nil { + return nil, err + } + } + +loop: + for { + fr, err := st.readSpdyFrame() + if err != nil { + return res, err + } + switch f := fr.(type) { + case *spdy.SynReplyFrame: + if f.StreamId != id { + break + } + res.header = cloneHeader(f.Headers) + if _, err := fmt.Sscan(res.header.Get(":status"), &res.status); err != nil { + return res, fmt.Errorf("Error parsing status code: %v", err) + } + if f.CFHeader.Flags&spdy.ControlFlagFin != 0 { + break loop + } + case *spdy.DataFrame: + if f.StreamId != id { + break + } + res.body = append(res.body, f.Data...) + if f.Flags&spdy.DataFlagFin != 0 { + break loop + } + case *spdy.RstStreamFrame: + if f.StreamId != id { + break + } + res.spdyRstErrCode = f.Status + break loop + case *spdy.GoAwayFrame: + if f.Status == spdy.GoAwayOK { + break + } + res.spdyGoAwayErrCode = f.Status + break loop + } + } + return res, nil +} + func (st *serverTester) http2(rp requestParam) (*serverResponse, error) { res := &serverResponse{} st.headerBlkBuf.Reset() @@ -299,11 +483,12 @@ loop: break } res.header = cloneHeader(st.header) - res.status, err = strconv.Atoi(res.header.Get(":status")) + var status int + status, err = strconv.Atoi(res.header.Get(":status")) if err != nil { return res, fmt.Errorf("Error parsing status code: %v", err) } - + res.status = status if f.StreamEnded() { break loop } @@ -322,7 +507,7 @@ loop: res.errCode = f.ErrCode break loop case *http2.GoAwayFrame: - if f.FrameHeader.StreamID != id || f.ErrCode == http2.ErrCodeNo { + if f.ErrCode == http2.ErrCodeNo { break } res.errCode = f.ErrCode @@ -335,17 +520,20 @@ loop: if err := st.fr.WriteSettingsAck(); err != nil { return res, err } + // TODO handle PUSH_PROMISE as well, since it alters HPACK context } } return res, nil } type serverResponse struct { - status int // HTTP status code - header http.Header // response header fields - body []byte // response body - errCode http2.ErrCode // error code received in RST_STREAM or GOAWAY - connErr bool // true if connection error + status int // HTTP status code + header http.Header // response header fields + body []byte // response body + errCode http2.ErrCode // error code received in HTTP/2 RST_STREAM or GOAWAY + connErr bool // true if HTTP/2 connection error + spdyGoAwayErrCode spdy.GoAwayStatus // status code received in SPDY RST_STREAM + spdyRstErrCode spdy.RstStreamStatus // status code received in SPDY GOAWAY } func cloneHeader(h http.Header) http.Header { diff --git a/src/shrpx_spdy_upstream.cc b/src/shrpx_spdy_upstream.cc index f57e38ed..4a03485f 100644 --- a/src/shrpx_spdy_upstream.cc +++ b/src/shrpx_spdy_upstream.cc @@ -156,11 +156,23 @@ void on_ctrl_recv_callback(spdylay_session *session, spdylay_frame_type type, auto nv = frame->syn_stream.nv; + if (LOG_ENABLED(INFO)) { + std::stringstream ss; + for (size_t i = 0; nv[i]; i += 2) { + ss << TTY_HTTP_HD << nv[i] << TTY_RST << ": " << nv[i + 1] << "\n"; + } + ULOG(INFO, upstream) << "HTTP request headers. stream_id=" + << downstream->get_stream_id() << "\n" << ss.str(); + } + for (size_t i = 0; nv[i]; i += 2) { downstream->add_request_header(nv[i], nv[i + 1]); } - downstream->index_request_headers(); + if (downstream->index_request_headers() != 0) { + upstream->error_reply(downstream, 400); + return; + } auto path = downstream->get_request_header(http2::HD__PATH); auto scheme = downstream->get_request_header(http2::HD__SCHEME); @@ -193,15 +205,6 @@ void on_ctrl_recv_callback(spdylay_session *session, spdylay_frame_type type, downstream->inspect_http2_request(); - if (LOG_ENABLED(INFO)) { - std::stringstream ss; - for (size_t i = 0; nv[i]; i += 2) { - ss << TTY_HTTP_HD << nv[i] << TTY_RST << ": " << nv[i + 1] << "\n"; - } - ULOG(INFO, upstream) << "HTTP request headers. stream_id=" - << downstream->get_stream_id() << "\n" << ss.str(); - } - downstream->set_request_state(Downstream::HEADER_COMPLETE); if (frame->syn_stream.hd.flags & SPDYLAY_CTRL_FLAG_FIN) { if (!downstream->validate_request_bodylen()) {