diff --git a/configure.ac b/configure.ac index 4152b8d9..9237b4ad 100644 --- a/configure.ac +++ b/configure.ac @@ -651,6 +651,7 @@ AC_CONFIG_FILES([ examples/Makefile python/Makefile python/setup.py + integration-tests/config.go doc/Makefile doc/conf.py doc/index.rst diff --git a/integration-tests/config.go.in b/integration-tests/config.go.in new file mode 100644 index 00000000..0a6fd6b7 --- /dev/null +++ b/integration-tests/config.go.in @@ -0,0 +1,5 @@ +package nghttp2 + +const ( + buildDir = "@top_builddir@" +) diff --git a/integration-tests/nghttpx_test.go b/integration-tests/nghttpx_test.go new file mode 100644 index 00000000..ddeee27c --- /dev/null +++ b/integration-tests/nghttpx_test.go @@ -0,0 +1,174 @@ +package nghttp2 + +import ( + "fmt" + "github.com/bradfitz/http2" + "net/http" + "testing" +) + +func TestPlainGET(t *testing.T) { + st := newServerTester(nil, t, noopHandler) + defer st.Close() + + res, err := st.http2(requestParam{ + name: "TestPlainGet", + }) + if err != nil { + t.Errorf("Error st.http2() = %v", err) + } + + want := 200 + if res.status != want { + t.Errorf("status = %v; want %v", res.status, want) + } +} + +func TestAddXff(t *testing.T) { + st := newServerTester([]string{"--add-x-forwarded-for"}, t, func(w http.ResponseWriter, r *http.Request) { + xff := r.Header.Get("X-Forwarded-For") + want := "127.0.0.1" + if xff != want { + t.Errorf("X-Forwarded-For = %v; want %v", xff, want) + } + }) + defer st.Close() + + _, err := st.http2(requestParam{ + name: "TestAddXff", + }) + if err != nil { + t.Errorf("Error st.http2() = %v", err) + } +} + +func TestAddXff2(t *testing.T) { + st := newServerTester([]string{"--add-x-forwarded-for"}, t, func(w http.ResponseWriter, r *http.Request) { + xff := r.Header.Get("X-Forwarded-For") + want := "host, 127.0.0.1" + if xff != want { + t.Errorf("X-Forwarded-For = %v; want %v", xff, want) + } + }) + defer st.Close() + + _, err := st.http2(requestParam{ + name: "TestAddXff2", + header: http.Header{ + "x-forwarded-for": []string{"host"}, + }, + }) + if err != nil { + t.Errorf("Error st.http2() = %v", err) + } +} + +func TestStripXff(t *testing.T) { + st := newServerTester([]string{"--strip-incoming-x-forwarded-for"}, t, func(w http.ResponseWriter, r *http.Request) { + if xff, found := r.Header["X-Forwarded-For"]; found { + t.Errorf("X-Forwarded-For = %v; want nothing", xff) + } + }) + defer st.Close() + + _, err := st.http2(requestParam{ + name: "TestStripXff1", + header: http.Header{ + "x-forwarded-for": []string{"host"}, + }, + }) + if err != nil { + t.Errorf("Error st.http2() = %v", err) + } +} + +func TestStripAddXff(t *testing.T) { + args := []string{ + "--strip-incoming-x-forwarded-for", + "--add-x-forwarded-for", + } + st := newServerTester(args, t, func(w http.ResponseWriter, r *http.Request) { + xff := r.Header.Get("X-Forwarded-For") + want := "127.0.0.1" + if xff != want { + t.Errorf("X-Forwarded-For = %v; want %v", xff, want) + } + }) + defer st.Close() + + _, err := st.http2(requestParam{ + name: "TestStripAddXff", + header: http.Header{ + "x-forwarded-for": []string{"host"}, + }, + }) + if err != nil { + t.Errorf("Error st.http2() = %v", err) + } +} + +func TestHTTP2BadRequestCL(t *testing.T) { + st := newServerTester(nil, t, noopHandler) + defer st.Close() + + // we set content-length: 1024, but the actual request body is + // 3 bytes. + res, err := st.http2(requestParam{ + name: "TestHTTP2BadRequestCL", + method: "POST", + header: http.Header{ + "content-length": []string{"1024"}, + }, + body: []byte("foo"), + }) + if err != nil { + t.Errorf("Error st.http2() = %v", err) + } + + want := http2.ErrCodeProtocol + if res.errCode != want { + t.Errorf("res.errCode = %v; want %v", res.errCode, want) + } +} + +func TestHTTP2BadResponseCL(t *testing.T) { + st := newServerTester(nil, t, func(w http.ResponseWriter, r *http.Request) { + // we set content-length: 1024, but only send 3 bytes. + w.Header().Add("Content-Length", "1024") + w.Write([]byte("foo")) + }) + defer st.Close() + + res, err := st.http2(requestParam{ + name: "TestHTTP2BadResponseCL", + }) + if err != nil { + t.Errorf("Error st.http2() = %v", err) + } + + want := http2.ErrCodeProtocol + if res.errCode != want { + t.Errorf("res.errCode = %v; want %v", res.errCode, want) + } +} + +func TestHTTP2LocationRewrite(t *testing.T) { + st := newServerTester(nil, t, func(w http.ResponseWriter, r *http.Request) { + // TODO we cannot get st.ts's port number here.. 8443 + // is just a place holder. We ignore it on rewrite. + w.Header().Add("Location", "http://127.0.0.1:8443/p/q?a=b#fragment") + }) + defer st.Close() + + res, err := st.http2(requestParam{ + name: "TestHTTP2LocationRewrite", + }) + if err != nil { + t.Errorf("Error st.http2() = %v", err) + } + + want := fmt.Sprintf("http://127.0.0.1:%v/p/q?a=b#fragment", serverPort) + if got := res.header.Get("Location"); got != want { + t.Errorf("Location: %v; want %v", got, want) + } +} diff --git a/integration-tests/server_tester.go b/integration-tests/server_tester.go new file mode 100644 index 00000000..c7ecb53d --- /dev/null +++ b/integration-tests/server_tester.go @@ -0,0 +1,301 @@ +package nghttp2 + +import ( + "bytes" + "errors" + "fmt" + "github.com/bradfitz/http2" + "github.com/bradfitz/http2/hpack" + "net" + "net/http" + "net/http/httptest" + "net/url" + "os/exec" + "strconv" + "strings" + "testing" + "time" +) + +const ( + serverBin = buildDir + "/src/nghttpx" + serverPort = 3009 +) + +func pair(name, value string) hpack.HeaderField { + return hpack.HeaderField{ + Name: name, + Value: value, + } +} + +type serverTester struct { + args []string // command-line arguments + cmd *exec.Cmd // test frontend server process, which is test subject + t *testing.T + ts *httptest.Server // backend server + conn net.Conn // connection to frontend server + h2PrefaceSent bool // HTTP/2 preface was sent in conn + nextStreamID uint32 // next stream ID + fr *http2.Framer + headerBlkBuf bytes.Buffer // buffer to store encoded header block + enc *hpack.Encoder + header http.Header // received header fields + dec *hpack.Decoder + authority string // server's host:port + frCh chan http2.Frame + errCh chan error +} + +func newServerTester(args []string, t *testing.T, handler http.HandlerFunc) *serverTester { + ts := httptest.NewServer(handler) + + u, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("Error parsing URL from httptest.Server: %v", err) + } + + // URL.Host looks like "127.0.0.1:8080", but we want + // "127.0.0.1,8080" + b := "-b" + strings.Replace(u.Host, ":", ",", -1) + args = append(args, fmt.Sprintf("-f127.0.0.1,%v", serverPort), b, + "--errorlog-file="+buildDir+"/integration-tests/log.txt", + "-LINFO", "--frontend-no-tls") + + st := &serverTester{ + cmd: exec.Command(serverBin, args...), + t: t, + ts: ts, + nextStreamID: 1, + authority: u.Host, + frCh: make(chan http2.Frame), + errCh: make(chan error), + } + + if err := st.cmd.Start(); err != nil { + st.t.Fatalf("Error starting %v: %v", serverBin, err) + } + + retry := 0 + for { + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%v", serverPort)) + if err != nil { + retry += 1 + if retry >= 10 { + st.t.Fatalf("Error server is not responding too long; server command-line arguments may be invalid") + } + time.Sleep(500 * time.Millisecond) + continue + } + st.conn = conn + break + } + + st.fr = http2.NewFramer(st.conn, st.conn) + st.enc = hpack.NewEncoder(&st.headerBlkBuf) + st.dec = hpack.NewDecoder(4096, func(f hpack.HeaderField) { + st.header.Add(f.Name, f.Value) + }) + + return st +} + +func (st *serverTester) Close() { + if st.conn != nil { + st.conn.Close() + } + if st.cmd != nil { + st.cmd.Process.Kill() + st.cmd.Wait() + } + if st.ts != nil { + st.ts.Close() + } + close(st.frCh) + close(st.errCh) +} + +func (st *serverTester) readFrame() (http2.Frame, error) { + go func() { + f, err := st.fr.ReadFrame() + if err != nil { + st.errCh <- err + return + } + st.frCh <- f + }() + + t := time.NewTimer(2 * time.Second) + defer t.Stop() + select { + case f := <-st.frCh: + return f, nil + case err := <-st.errCh: + return nil, err + case <-t.C: + return nil, errors.New("timeout waiting for frame") + } +} + +type requestParam struct { + name string // name for this request to identify the request + // in log easily + streamID uint32 // stream ID, automatically assigned if 0 + method string // method, defaults to GET + scheme string // scheme, defaults to http + authority string // authority, defaults to backend server address + path string // path, defaults to / + header http.Header // additional request header fields + body []byte // request body +} + +func (st *serverTester) http2(rp requestParam) (*serverResponse, error) { + res := &serverResponse{} + st.headerBlkBuf.Reset() + st.header = make(http.Header) + + var id uint32 + if rp.streamID != 0 { + id = rp.streamID + if id >= st.nextStreamID && id%2 == 1 { + st.nextStreamID = id + 2 + } + } else { + id = st.nextStreamID + st.nextStreamID += 2 + } + + if !st.h2PrefaceSent { + st.h2PrefaceSent = true + fmt.Fprint(st.conn, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + if err := st.fr.WriteSettings(); err != nil { + return nil, err + } + } + + method := "GET" + if rp.method != "" { + method = rp.method + } + _ = st.enc.WriteField(pair(":method", method)) + + scheme := "http" + if rp.scheme != "" { + scheme = rp.scheme + } + _ = st.enc.WriteField(pair(":scheme", scheme)) + + authority := st.authority + if rp.authority != "" { + authority = rp.authority + } + _ = st.enc.WriteField(pair(":authority", authority)) + + path := "/" + if rp.path != "" { + path = rp.path + } + _ = st.enc.WriteField(pair(":path", path)) + + _ = st.enc.WriteField(pair("test-case", rp.name)) + + for k, v := range rp.header { + for _, h := range v { + _ = st.enc.WriteField(pair(strings.ToLower(k), h)) + } + } + + err := st.fr.WriteHeaders(http2.HeadersFrameParam{ + StreamID: id, + EndStream: len(rp.body) == 0, + EndHeaders: true, + BlockFragment: st.headerBlkBuf.Bytes(), + }) + if err != nil { + return nil, err + } + + if len(rp.body) != 0 { + // TODO we assume rp.body fits in 1 frame + if err := st.fr.WriteData(id, true, rp.body); err != nil { + return nil, err + } + } + +loop: + for { + fr, err := st.readFrame() + if err != nil { + return res, err + } + switch f := fr.(type) { + case *http2.HeadersFrame: + _, err := st.dec.Write(f.HeaderBlockFragment()) + if err != nil { + return res, err + } + if f.FrameHeader.StreamID != id { + st.header = make(http.Header) + break + } + res.header = cloneHeader(st.header) + res.status, err = strconv.Atoi(res.header.Get(":status")) + if err != nil { + return res, fmt.Errorf("Error parsing status code: %v", err) + } + + if f.StreamEnded() { + break loop + } + case *http2.DataFrame: + if f.FrameHeader.StreamID != id { + break + } + res.body = append(res.body, f.Data()...) + if f.StreamEnded() { + break loop + } + case *http2.RSTStreamFrame: + if f.FrameHeader.StreamID != id { + break + } + res.errCode = f.ErrCode + break loop + case *http2.GoAwayFrame: + if f.FrameHeader.StreamID != id || f.ErrCode == http2.ErrCodeNo { + break + } + res.errCode = f.ErrCode + res.connErr = true + break loop + case *http2.SettingsFrame: + if f.IsAck() { + break + } + if err := st.fr.WriteSettingsAck(); err != nil { + return res, err + } + } + } + return res, nil +} + +type serverResponse struct { + status int // HTTP status code + header http.Header // response header fields + body []byte // response body + errCode http2.ErrCode // error code received in RST_STREAM or GOAWAY + connErr bool // true if connection error +} + +func cloneHeader(h http.Header) http.Header { + h2 := make(http.Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} + +func noopHandler(w http.ResponseWriter, r *http.Request) {}