nghttpx: Return 400 error if multiple CLs are received in SPDY upstream
This change adds SPDY upstream tests.
This commit is contained in:
parent
b9a9a23b1e
commit
16e91746d9
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"github.com/bradfitz/http2"
|
||||
"github.com/bradfitz/http2/hpack"
|
||||
"golang.org/x/net/spdy"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
@ -385,3 +386,87 @@ func TestH2H2InvalidResponseCL(t *testing.T) {
|
|||
t.Errorf("status: %v; want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestS3H1PlainGET(t *testing.T) {
|
||||
st := newServerTesterTLS([]string{"--npn-list=spdy/3.1"}, t, noopHandler)
|
||||
defer st.Close()
|
||||
|
||||
res, err := st.spdy(requestParam{
|
||||
name: "TestS3H1PlainGET",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Error st.spdy() = %v", err)
|
||||
}
|
||||
|
||||
want := 200
|
||||
if got := res.status; got != want {
|
||||
t.Errorf("status = %v; want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestS3H1BadRequestCL(t *testing.T) {
|
||||
st := newServerTesterTLS([]string{"--npn-list=spdy/3.1"}, t, noopHandler)
|
||||
defer st.Close()
|
||||
|
||||
// we set content-length: 1024, but the actual request body is
|
||||
// 3 bytes.
|
||||
res, err := st.spdy(requestParam{
|
||||
name: "TestS3H1BadRequestCL",
|
||||
method: "POST",
|
||||
header: []hpack.HeaderField{
|
||||
pair("content-length", "1024"),
|
||||
},
|
||||
body: []byte("foo"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Error st.spdy() = %v", err)
|
||||
}
|
||||
|
||||
want := spdy.ProtocolError
|
||||
if got := res.spdyRstErrCode; got != want {
|
||||
t.Errorf("res.spdyRstErrCode = %v; want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestS3H1MultipleRequestCL(t *testing.T) {
|
||||
st := newServerTesterTLS([]string{"--npn-list=spdy/3.1"}, t, func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Errorf("server should not forward bad request")
|
||||
})
|
||||
defer st.Close()
|
||||
|
||||
res, err := st.spdy(requestParam{
|
||||
name: "TestS3H1MultipleRequestCL",
|
||||
header: []hpack.HeaderField{
|
||||
pair("content-length", "1"),
|
||||
pair("content-length", "2"),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Error st.spdy() = %v", err)
|
||||
}
|
||||
want := 400
|
||||
if got := res.status; got != want {
|
||||
t.Errorf("status: %v; want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestS3H1InvalidRequestCL(t *testing.T) {
|
||||
st := newServerTesterTLS([]string{"--npn-list=spdy/3.1"}, t, func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Errorf("server should not forward bad request")
|
||||
})
|
||||
defer st.Close()
|
||||
|
||||
res, err := st.spdy(requestParam{
|
||||
name: "TestS3H1InvalidRequestCL",
|
||||
header: []hpack.HeaderField{
|
||||
pair("content-length", ""),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Error st.spdy() = %v", err)
|
||||
}
|
||||
want := 400
|
||||
if got := res.status; got != want {
|
||||
t.Errorf("status: %v; want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/bradfitz/http2"
|
||||
"github.com/bradfitz/http2/hpack"
|
||||
"github.com/tatsuhiro-t/go-nghttp2"
|
||||
"golang.org/x/net/spdy"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
|
@ -25,6 +26,7 @@ import (
|
|||
const (
|
||||
serverBin = buildDir + "/src/nghttpx"
|
||||
serverPort = 3009
|
||||
testDir = buildDir + "/integration-tests"
|
||||
)
|
||||
|
||||
func pair(name, value string) hpack.HeaderField {
|
||||
|
@ -43,24 +45,39 @@ type serverTester struct {
|
|||
conn net.Conn // connection to frontend server
|
||||
h2PrefaceSent bool // HTTP/2 preface was sent in conn
|
||||
nextStreamID uint32 // next stream ID
|
||||
fr *http2.Framer
|
||||
fr *http2.Framer // HTTP/2 framer
|
||||
spdyFr *spdy.Framer // SPDY/3.1 framer
|
||||
headerBlkBuf bytes.Buffer // buffer to store encoded header block
|
||||
enc *hpack.Encoder
|
||||
enc *hpack.Encoder // HTTP/2 HPACK encoder
|
||||
header http.Header // received header fields
|
||||
dec *hpack.Decoder
|
||||
dec *hpack.Decoder // HTTP/2 HPACK decoder
|
||||
authority string // server's host:port
|
||||
frCh chan http2.Frame
|
||||
frCh chan http2.Frame // used for incoming HTTP/2 frame
|
||||
spdyFrCh chan spdy.Frame // used for incoming SPDY frame
|
||||
errCh chan error
|
||||
}
|
||||
|
||||
// newServerTester creates test context for plain TCP frontend
|
||||
// connection.
|
||||
func newServerTester(args []string, t *testing.T, handler http.HandlerFunc) *serverTester {
|
||||
return newServerTesterInternal(args, t, handler, false)
|
||||
}
|
||||
|
||||
// newServerTester creates test context for TLS frontend connection.
|
||||
func newServerTesterTLS(args []string, t *testing.T, handler http.HandlerFunc) *serverTester {
|
||||
return newServerTesterInternal(args, t, handler, true)
|
||||
}
|
||||
|
||||
// newServerTesterInternal creates test context. If frontendTLS is
|
||||
// true, set up TLS frontend connection.
|
||||
func newServerTesterInternal(args []string, t *testing.T, handler http.HandlerFunc, frontendTLS bool) *serverTester {
|
||||
ts := httptest.NewUnstartedServer(handler)
|
||||
|
||||
backendTLS := false
|
||||
for _, k := range args {
|
||||
if k == "--http2-bridge" {
|
||||
switch k {
|
||||
case "--http2-bridge":
|
||||
backendTLS = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if backendTLS {
|
||||
|
@ -75,26 +92,36 @@ func newServerTester(args []string, t *testing.T, handler http.HandlerFunc) *ser
|
|||
} else {
|
||||
ts.Start()
|
||||
}
|
||||
u, err := url.Parse(ts.URL)
|
||||
scheme := "http"
|
||||
if frontendTLS {
|
||||
scheme = "https"
|
||||
args = append(args, testDir+"/server.key", testDir+"/server.crt")
|
||||
} else {
|
||||
args = append(args, "--frontend-no-tls")
|
||||
}
|
||||
|
||||
backendURL, err := url.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Error parsing URL from httptest.Server: %v", err)
|
||||
}
|
||||
|
||||
// URL.Host looks like "127.0.0.1:8080", but we want
|
||||
// "127.0.0.1,8080"
|
||||
b := "-b" + strings.Replace(u.Host, ":", ",", -1)
|
||||
b := "-b" + strings.Replace(backendURL.Host, ":", ",", -1)
|
||||
args = append(args, fmt.Sprintf("-f127.0.0.1,%v", serverPort), b,
|
||||
"--errorlog-file="+buildDir+"/integration-tests/log.txt",
|
||||
"-LINFO", "--frontend-no-tls")
|
||||
"--errorlog-file="+testDir+"/log.txt", "-LINFO")
|
||||
|
||||
authority := fmt.Sprintf("127.0.0.1:%v", serverPort)
|
||||
|
||||
st := &serverTester{
|
||||
cmd: exec.Command(serverBin, args...),
|
||||
t: t,
|
||||
ts: ts,
|
||||
url: fmt.Sprintf("http://127.0.0.1:%v", serverPort),
|
||||
url: fmt.Sprintf("%v://%v", scheme, authority),
|
||||
nextStreamID: 1,
|
||||
authority: u.Host,
|
||||
authority: authority,
|
||||
frCh: make(chan http2.Frame),
|
||||
spdyFrCh: make(chan spdy.Frame),
|
||||
errCh: make(chan error),
|
||||
}
|
||||
|
||||
|
@ -104,20 +131,45 @@ func newServerTester(args []string, t *testing.T, handler http.HandlerFunc) *ser
|
|||
|
||||
retry := 0
|
||||
for {
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%v", serverPort))
|
||||
var conn net.Conn
|
||||
var err error
|
||||
if frontendTLS {
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
NextProtos: []string{"h2-14", "spdy/3.1"},
|
||||
}
|
||||
conn, err = tls.Dial("tcp", authority, tlsConfig)
|
||||
} else {
|
||||
conn, err = net.Dial("tcp", authority)
|
||||
}
|
||||
if err != nil {
|
||||
retry += 1
|
||||
if retry >= 100 {
|
||||
st.Close()
|
||||
st.t.Fatalf("Error server is not responding too long; server command-line arguments may be invalid")
|
||||
}
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if frontendTLS {
|
||||
tlsConn := conn.(*tls.Conn)
|
||||
cs := tlsConn.ConnectionState()
|
||||
if !cs.NegotiatedProtocolIsMutual {
|
||||
st.Close()
|
||||
st.t.Fatalf("Error negotiated next protocol is not mutual")
|
||||
}
|
||||
}
|
||||
st.conn = conn
|
||||
break
|
||||
}
|
||||
|
||||
st.fr = http2.NewFramer(st.conn, st.conn)
|
||||
spdyFr, err := spdy.NewFramer(st.conn, st.conn)
|
||||
if err != nil {
|
||||
st.Close()
|
||||
st.t.Fatalf("Error spdy.NewFramer: %v", err)
|
||||
}
|
||||
st.spdyFr = spdyFr
|
||||
st.enc = hpack.NewEncoder(&st.headerBlkBuf)
|
||||
st.dec = hpack.NewDecoder(4096, func(f hpack.HeaderField) {
|
||||
st.header.Add(f.Name, f.Value)
|
||||
|
@ -159,6 +211,26 @@ func (st *serverTester) readFrame() (http2.Frame, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (st *serverTester) readSpdyFrame() (spdy.Frame, error) {
|
||||
go func() {
|
||||
f, err := st.spdyFr.ReadFrame()
|
||||
if err != nil {
|
||||
st.errCh <- err
|
||||
return
|
||||
}
|
||||
st.spdyFrCh <- f
|
||||
}()
|
||||
|
||||
select {
|
||||
case f := <-st.spdyFrCh:
|
||||
return f, nil
|
||||
case err := <-st.errCh:
|
||||
return nil, err
|
||||
case <-time.After(2 * time.Second):
|
||||
return nil, errors.New("timeout waiting for frame")
|
||||
}
|
||||
}
|
||||
|
||||
type requestParam struct {
|
||||
name string // name for this request to identify the request in log easily
|
||||
streamID uint32 // stream ID, automatically assigned if 0
|
||||
|
@ -211,6 +283,118 @@ func (st *serverTester) http1(rp requestParam) (*serverResponse, error) {
|
|||
return res, nil
|
||||
}
|
||||
|
||||
func (st *serverTester) spdy(rp requestParam) (*serverResponse, error) {
|
||||
res := &serverResponse{}
|
||||
|
||||
var id spdy.StreamId
|
||||
if rp.streamID != 0 {
|
||||
id = spdy.StreamId(rp.streamID)
|
||||
if id >= spdy.StreamId(st.nextStreamID) && id%2 == 1 {
|
||||
st.nextStreamID = uint32(id) + 2
|
||||
}
|
||||
} else {
|
||||
id = spdy.StreamId(st.nextStreamID)
|
||||
st.nextStreamID += 2
|
||||
}
|
||||
|
||||
method := "GET"
|
||||
if rp.method != "" {
|
||||
method = rp.method
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if rp.scheme != "" {
|
||||
scheme = rp.scheme
|
||||
}
|
||||
|
||||
host := st.authority
|
||||
if rp.authority != "" {
|
||||
host = rp.authority
|
||||
}
|
||||
|
||||
path := "/"
|
||||
if rp.path != "" {
|
||||
path = rp.path
|
||||
}
|
||||
|
||||
header := make(http.Header)
|
||||
header.Add(":method", method)
|
||||
header.Add(":scheme", scheme)
|
||||
header.Add(":host", host)
|
||||
header.Add(":path", path)
|
||||
header.Add(":version", "HTTP/1.1")
|
||||
header.Add("test-case", rp.name)
|
||||
for _, h := range rp.header {
|
||||
header.Add(h.Name, h.Value)
|
||||
}
|
||||
|
||||
var synStreamFlags spdy.ControlFlags
|
||||
if len(rp.body) == 0 {
|
||||
synStreamFlags = spdy.ControlFlagFin
|
||||
}
|
||||
if err := st.spdyFr.WriteFrame(&spdy.SynStreamFrame{
|
||||
CFHeader: spdy.ControlFrameHeader{
|
||||
Flags: synStreamFlags,
|
||||
},
|
||||
StreamId: id,
|
||||
Headers: header,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(rp.body) != 0 {
|
||||
if err := st.spdyFr.WriteFrame(&spdy.DataFrame{
|
||||
StreamId: id,
|
||||
Flags: spdy.DataFlagFin,
|
||||
Data: rp.body,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
loop:
|
||||
for {
|
||||
fr, err := st.readSpdyFrame()
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
switch f := fr.(type) {
|
||||
case *spdy.SynReplyFrame:
|
||||
if f.StreamId != id {
|
||||
break
|
||||
}
|
||||
res.header = cloneHeader(f.Headers)
|
||||
if _, err := fmt.Sscan(res.header.Get(":status"), &res.status); err != nil {
|
||||
return res, fmt.Errorf("Error parsing status code: %v", err)
|
||||
}
|
||||
if f.CFHeader.Flags&spdy.ControlFlagFin != 0 {
|
||||
break loop
|
||||
}
|
||||
case *spdy.DataFrame:
|
||||
if f.StreamId != id {
|
||||
break
|
||||
}
|
||||
res.body = append(res.body, f.Data...)
|
||||
if f.Flags&spdy.DataFlagFin != 0 {
|
||||
break loop
|
||||
}
|
||||
case *spdy.RstStreamFrame:
|
||||
if f.StreamId != id {
|
||||
break
|
||||
}
|
||||
res.spdyRstErrCode = f.Status
|
||||
break loop
|
||||
case *spdy.GoAwayFrame:
|
||||
if f.Status == spdy.GoAwayOK {
|
||||
break
|
||||
}
|
||||
res.spdyGoAwayErrCode = f.Status
|
||||
break loop
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (st *serverTester) http2(rp requestParam) (*serverResponse, error) {
|
||||
res := &serverResponse{}
|
||||
st.headerBlkBuf.Reset()
|
||||
|
@ -299,11 +483,12 @@ loop:
|
|||
break
|
||||
}
|
||||
res.header = cloneHeader(st.header)
|
||||
res.status, err = strconv.Atoi(res.header.Get(":status"))
|
||||
var status int
|
||||
status, err = strconv.Atoi(res.header.Get(":status"))
|
||||
if err != nil {
|
||||
return res, fmt.Errorf("Error parsing status code: %v", err)
|
||||
}
|
||||
|
||||
res.status = status
|
||||
if f.StreamEnded() {
|
||||
break loop
|
||||
}
|
||||
|
@ -322,7 +507,7 @@ loop:
|
|||
res.errCode = f.ErrCode
|
||||
break loop
|
||||
case *http2.GoAwayFrame:
|
||||
if f.FrameHeader.StreamID != id || f.ErrCode == http2.ErrCodeNo {
|
||||
if f.ErrCode == http2.ErrCodeNo {
|
||||
break
|
||||
}
|
||||
res.errCode = f.ErrCode
|
||||
|
@ -335,6 +520,7 @@ loop:
|
|||
if err := st.fr.WriteSettingsAck(); err != nil {
|
||||
return res, err
|
||||
}
|
||||
// TODO handle PUSH_PROMISE as well, since it alters HPACK context
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
|
@ -344,8 +530,10 @@ type serverResponse struct {
|
|||
status int // HTTP status code
|
||||
header http.Header // response header fields
|
||||
body []byte // response body
|
||||
errCode http2.ErrCode // error code received in RST_STREAM or GOAWAY
|
||||
connErr bool // true if connection error
|
||||
errCode http2.ErrCode // error code received in HTTP/2 RST_STREAM or GOAWAY
|
||||
connErr bool // true if HTTP/2 connection error
|
||||
spdyGoAwayErrCode spdy.GoAwayStatus // status code received in SPDY RST_STREAM
|
||||
spdyRstErrCode spdy.RstStreamStatus // status code received in SPDY GOAWAY
|
||||
}
|
||||
|
||||
func cloneHeader(h http.Header) http.Header {
|
||||
|
|
|
@ -156,11 +156,23 @@ void on_ctrl_recv_callback(spdylay_session *session, spdylay_frame_type type,
|
|||
|
||||
auto nv = frame->syn_stream.nv;
|
||||
|
||||
if (LOG_ENABLED(INFO)) {
|
||||
std::stringstream ss;
|
||||
for (size_t i = 0; nv[i]; i += 2) {
|
||||
ss << TTY_HTTP_HD << nv[i] << TTY_RST << ": " << nv[i + 1] << "\n";
|
||||
}
|
||||
ULOG(INFO, upstream) << "HTTP request headers. stream_id="
|
||||
<< downstream->get_stream_id() << "\n" << ss.str();
|
||||
}
|
||||
|
||||
for (size_t i = 0; nv[i]; i += 2) {
|
||||
downstream->add_request_header(nv[i], nv[i + 1]);
|
||||
}
|
||||
|
||||
downstream->index_request_headers();
|
||||
if (downstream->index_request_headers() != 0) {
|
||||
upstream->error_reply(downstream, 400);
|
||||
return;
|
||||
}
|
||||
|
||||
auto path = downstream->get_request_header(http2::HD__PATH);
|
||||
auto scheme = downstream->get_request_header(http2::HD__SCHEME);
|
||||
|
@ -193,15 +205,6 @@ void on_ctrl_recv_callback(spdylay_session *session, spdylay_frame_type type,
|
|||
|
||||
downstream->inspect_http2_request();
|
||||
|
||||
if (LOG_ENABLED(INFO)) {
|
||||
std::stringstream ss;
|
||||
for (size_t i = 0; nv[i]; i += 2) {
|
||||
ss << TTY_HTTP_HD << nv[i] << TTY_RST << ": " << nv[i + 1] << "\n";
|
||||
}
|
||||
ULOG(INFO, upstream) << "HTTP request headers. stream_id="
|
||||
<< downstream->get_stream_id() << "\n" << ss.str();
|
||||
}
|
||||
|
||||
downstream->set_request_state(Downstream::HEADER_COMPLETE);
|
||||
if (frame->syn_stream.hd.flags & SPDYLAY_CTRL_FLAG_FIN) {
|
||||
if (!downstream->validate_request_bodylen()) {
|
||||
|
|
Loading…
Reference in New Issue