// nghttp2/integration-tests/server_tester.go
//
// 851 lines
// 20 KiB
// Go

package nghttp2
import (
"bufio"
"bytes"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"os/exec"
"sort"
"strconv"
"strings"
"syscall"
"testing"
"time"
"github.com/lucas-clemente/quic-go/http3"
"github.com/tatsuhiro-t/go-nghttp2"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/websocket"
)
// Fixed locations and the port number shared by the integration tests.
// buildDir and sourceDir are declared outside of this block.
const (
	// serverPort is the port number used in the nghttpx frontend (-f)
	// address and the default port the test client connects to.
	serverPort = 3009

	// serverBin is the path of the nghttpx executable that is launched
	// as the test subject.
	serverBin = buildDir + "/src/nghttpx"

	// testDir is the directory holding test fixtures such as
	// server.key and server.crt.
	testDir = sourceDir + "/integration-tests"

	// logDir is the directory where the nghttpx error log (log.txt) is
	// written.
	logDir = buildDir + "/integration-tests"
)
// pair returns an hpack.HeaderField whose Name is name and whose Value
// is value. All other fields are left at their zero value.
func pair(name, value string) hpack.HeaderField {
	var field hpack.HeaderField
	field.Name = name
	field.Value = value
	return field
}
// serverTester bundles everything one integration test needs: the
// frontend server process under test, the backend httptest server, the
// client connection to the frontend, and the HTTP/2 framing state used
// on that connection.
type serverTester struct {
	// args is the command-line arguments.
	args []string
	// cmd is the test frontend server process, which is the test
	// subject.
	cmd *exec.Cmd
	// url is the test frontend server URL.
	url string
	t   *testing.T
	// ts is the backend server.
	ts *httptest.Server
	// frontendHost is the frontend server host.
	frontendHost string
	// backendHost is the backend server host.
	backendHost string
	// conn is the connection to the frontend server.
	conn net.Conn
	// h2PrefaceSent is true once the HTTP/2 preface was sent in conn.
	h2PrefaceSent bool
	// nextStreamID is the next stream ID.
	nextStreamID uint32
	// fr is the HTTP/2 framer.
	fr *http2.Framer
	// headerBlkBuf is the buffer to store the encoded header block.
	headerBlkBuf bytes.Buffer
	// enc is the HTTP/2 HPACK encoder.
	enc *hpack.Encoder
	// header is the received header fields.
	header http.Header
	// dec is the HTTP/2 HPACK decoder.
	dec *hpack.Decoder
	// authority is the server's host:port.
	authority string
	// frCh is used for incoming HTTP/2 frames.
	frCh chan http2.Frame
	// errCh is used for errors encountered while reading a frame.
	errCh chan error
}
// options customizes the test context created by newServerTester. The
// zero value is usable; unset fields fall back to the defaults described
// on each field.
type options struct {
	// args holds extra arguments for nghttpx.
	args []string

	// handler serves the requests that reach the backend server. When
	// nil, noopHandler is used.
	handler http.HandlerFunc

	// connectPort is the server side port the client connects to. When
	// 0, serverPort is used.
	connectPort int

	// tls requests a TLS frontend connection when true.
	tls bool

	// tlsConfig is the client side TLS configuration. It is consulted
	// only when tls is true.
	tlsConfig *tls.Config

	// tcpData is extra data written to the connection before the TLS
	// handshake starts. It is ignored when tls is false.
	tcpData []byte

	// quic requests a QUIC frontend connection when true. Setting quic
	// implies tls = true.
	quic bool
}
// newServerTester creates test context.
//
// It starts the backend httptest server with opts.handler, launches
// nghttpx (serverBin) with command-line arguments derived from opts.args,
// and then connects to the frontend at 127.0.0.1:opts.connectPort,
// retrying every 50ms (up to 100 attempts) until the connection, and the
// TLS handshake when opts.tls is set, succeeds.
//
// The following entries of opts.args are not passed to nghttpx verbatim
// but are translated into test setup or backend (-b) parameters:
// --http2-bridge, --dns, --external-dns, --accept-proxy-protocol,
// --redirect-if-not-tls, --affinity-cookie and --alpn-h1.
//
// Any setup failure aborts the test via t.Fatal/t.Fatalf. The caller is
// responsible for calling Close on the returned serverTester.
func newServerTester(t *testing.T, opts options) *serverTester {
	if opts.quic {
		opts.tls = true
	}
	if opts.handler == nil {
		opts.handler = noopHandler
	}
	if opts.connectPort == 0 {
		opts.connectPort = serverPort
	}

	ts := httptest.NewUnstartedServer(opts.handler)

	var args []string
	var backendTLS, dns, externalDNS, acceptProxyProtocol, redirectIfNotTLS, affinityCookie, alpnH1 bool

	for _, k := range opts.args {
		switch k {
		case "--http2-bridge":
			backendTLS = true
		case "--dns":
			dns = true
		case "--external-dns":
			dns = true
			externalDNS = true
		case "--accept-proxy-protocol":
			acceptProxyProtocol = true
		case "--redirect-if-not-tls":
			redirectIfNotTLS = true
		case "--affinity-cookie":
			affinityCookie = true
		case "--alpn-h1":
			alpnH1 = true
		default:
			args = append(args, k)
		}
	}

	if backendTLS {
		nghttp2.ConfigureServer(ts.Config, &nghttp2.Server{})
		// According to httptest/server.go, we have to set
		// NextProtos separately for ts.TLS. NextProtos set
		// in nghttp2.ConfigureServer is effectively ignored.
		ts.TLS = new(tls.Config)
		ts.TLS.NextProtos = append(ts.TLS.NextProtos, "h2")
		ts.StartTLS()
		args = append(args, "-k")
	} else {
		ts.Start()
	}

	scheme := "http"
	if opts.tls {
		scheme = "https"
		args = append(args, testDir+"/server.key", testDir+"/server.crt")
	}

	backendURL, err := url.Parse(ts.URL)
	if err != nil {
		// Do not leave the backend server running when setup fails.
		ts.Close()
		t.Fatalf("Error parsing URL from httptest.Server: %v", err)
	}

	// URL.Host looks like "127.0.0.1:8080", but we want
	// "127.0.0.1,8080"
	b := "-b"
	if !externalDNS {
		b += fmt.Sprintf("%v;", strings.Replace(backendURL.Host, ":", ",", -1))
	} else {
		sep := strings.LastIndex(backendURL.Host, ":")
		if sep == -1 {
			ts.Close()
			t.Fatalf("backendURL.Host %v does not contain separator ':'", backendURL.Host)
		}
		// We use awesome service nip.io.
		b += fmt.Sprintf("%v.nip.io,%v;", backendURL.Host[:sep], backendURL.Host[sep+1:])
	}

	// The address part above ends with ';', so each parameter below
	// starts with its own ';' separator.
	if backendTLS {
		b += ";proto=h2;tls"
	}
	if dns {
		b += ";dns"
	}
	if redirectIfNotTLS {
		b += ";redirect-if-not-tls"
	}
	if affinityCookie {
		b += ";affinity=cookie;affinity-cookie-name=affinity;affinity-cookie-path=/foo/bar"
	}

	noTLS := ";no-tls"
	if opts.tls {
		noTLS = ""
	}

	var proxyProto string
	if acceptProxyProtocol {
		proxyProto = ";proxyproto"
	}

	args = append(args, fmt.Sprintf("-f127.0.0.1,%v%v%v", serverPort, noTLS, proxyProto), b,
		"--errorlog-file="+logDir+"/log.txt", "-LINFO")

	if opts.quic {
		args = append(args,
			fmt.Sprintf("-f127.0.0.1,%v;quic", serverPort),
			"--no-quic-bpf")
	}

	authority := fmt.Sprintf("127.0.0.1:%v", opts.connectPort)

	st := &serverTester{
		cmd:          exec.Command(serverBin, args...),
		t:            t,
		ts:           ts,
		url:          fmt.Sprintf("%v://%v", scheme, authority),
		frontendHost: fmt.Sprintf("127.0.0.1:%v", serverPort),
		backendHost:  backendURL.Host,
		nextStreamID: 1,
		authority:    authority,
		frCh:         make(chan http2.Frame),
		errCh:        make(chan error),
	}

	st.cmd.Stdout = os.Stdout
	st.cmd.Stderr = os.Stderr
	if err := st.cmd.Start(); err != nil {
		// The process never started, so st.Close (which signals the
		// process) must not be used; only the backend needs cleanup.
		ts.Close()
		st.t.Fatalf("Error starting %v: %v", serverBin, err)
	}

	retry := 0
	for {
		time.Sleep(50 * time.Millisecond)

		conn, err := net.Dial("tcp", authority)
		if err == nil && opts.tls {
			if len(opts.tcpData) > 0 {
				if _, err := conn.Write(opts.tcpData); err != nil {
					// conn is not yet owned by st, so close it here.
					conn.Close()
					st.Close()
					st.t.Fatal("Error writing TCP data")
				}
			}

			var tlsConfig *tls.Config
			if opts.tlsConfig == nil {
				tlsConfig = new(tls.Config)
			} else {
				tlsConfig = opts.tlsConfig.Clone()
			}
			tlsConfig.InsecureSkipVerify = true
			if alpnH1 {
				tlsConfig.NextProtos = []string{"http/1.1"}
			} else {
				tlsConfig.NextProtos = []string{"h2"}
			}

			tlsConn := tls.Client(conn, tlsConfig)
			err = tlsConn.Handshake()
			if err == nil {
				cs := tlsConn.ConnectionState()
				if !cs.NegotiatedProtocolIsMutual {
					tlsConn.Close()
					st.Close()
					st.t.Fatalf("Error negotiated next protocol is not mutual")
				}
				conn = tlsConn
			} else {
				// The handshake is retried on a fresh connection;
				// without this close every failed attempt would
				// leak a socket.
				conn.Close()
			}
		}
		if err != nil {
			retry++
			if retry >= 100 {
				st.Close()
				st.t.Fatalf("Error server is not responding too long; server command-line arguments may be invalid")
			}
			continue
		}
		st.conn = conn
		break
	}

	st.fr = http2.NewFramer(st.conn, st.conn)
	st.enc = hpack.NewEncoder(&st.headerBlkBuf)
	// Decoded header fields are accumulated into st.header, which is
	// (re)initialized by the request methods before use.
	st.dec = hpack.NewDecoder(4096, func(f hpack.HeaderField) {
		st.header.Add(f.Name, f.Value)
	})

	return st
}
// Close releases the resources held by st: it closes the frontend
// connection, asks the frontend server process to quit with SIGQUIT and
// waits for it to exit (killing it if it has not exited within 10
// seconds), and finally closes the backend server. Each step is skipped
// when the corresponding field is nil.
func (st *serverTester) Close() {
	if st.conn != nil {
		st.conn.Close()
	}

	if st.cmd != nil {
		// exited is closed once the process has been reaped.
		exited := make(chan struct{})
		go func() {
			st.cmd.Wait()
			close(exited)
		}()

		st.cmd.Process.Signal(syscall.SIGQUIT)

		deadline := time.After(10 * time.Second)
		select {
		case <-deadline:
			// Graceful shutdown took too long; force termination
			// and wait for the process to be reaped.
			st.cmd.Process.Kill()
			<-exited
		case <-exited:
		}
	}

	if st.ts != nil {
		st.ts.Close()
	}
}
// readFrame reads the next HTTP/2 frame from st.fr. The read runs in its
// own goroutine, which reports its outcome on st.frCh (frame) or st.errCh
// (error). readFrame returns whichever arrives first, or an error if
// nothing arrives within 5 seconds.
//
// NOTE(review): the reader goroutine is not cancelled on timeout; it
// stays blocked in ReadFrame, or on the channel send, until something
// unblocks it.
func (st *serverTester) readFrame() (http2.Frame, error) {
	go func() {
		if frame, err := st.fr.ReadFrame(); err != nil {
			st.errCh <- err
		} else {
			st.frCh <- frame
		}
	}()

	expired := time.After(5 * time.Second)

	select {
	case <-expired:
		return nil, errors.New("timeout waiting for frame")
	case err := <-st.errCh:
		return nil, err
	case frame := <-st.frCh:
		return frame, nil
	}
}
// requestParam describes a single request issued by the test client.
// Fields left at their zero value fall back to the defaults noted on
// each field.
type requestParam struct {
	// name is the name for this request, used to identify the request
	// in the log easily.
	name string
	// streamID is the stream ID; it is automatically assigned if 0.
	streamID uint32
	// method is the request method; defaults to GET.
	method string
	// scheme is the request scheme; defaults to http.
	scheme string
	// authority is the request authority; defaults to the backend
	// server address.
	authority string
	// path is the request path; defaults to /.
	path string
	// header holds additional request header fields.
	header []hpack.HeaderField
	// body is the request body.
	body []byte
	// trailer is the trailer part.
	trailer []hpack.HeaderField
	// httpUpgrade is true if upgraded to HTTP/2 through HTTP Upgrade.
	httpUpgrade bool
	// noEndStream is true if END_STREAM should not be sent.
	noEndStream bool
}
// chunkedBodyReader wraps a request body so that the trailer part can be
// set on the request while the body is being read.
type chunkedBodyReader struct {
	// trailer holds the trailer fields to copy into req.Trailer.
	trailer []hpack.HeaderField
	// trailerWritten is true once trailer has been copied into
	// req.Trailer.
	trailerWritten bool
	// body is the wrapped request body.
	body io.Reader
	// req is the request whose Trailer is populated.
	req *http.Request
}
// Read implements io.Reader by delegating to the wrapped body. On the
// first call it copies the trailer fields into cbr.req.Trailer before
// reading.
func (cbr *chunkedBodyReader) Read(p []byte) (n int, err error) {
	if cbr.trailerWritten {
		return cbr.body.Read(p)
	}

	// The documentation says that http.Request.Trailer has to be set
	// after the request was sent and before the body returns EOF, so
	// fill it in on the first read.
	cbr.trailerWritten = true
	for i := range cbr.trailer {
		field := cbr.trailer[i]
		cbr.req.Trailer.Set(field.Name, field.Value)
	}

	return cbr.body.Read(p)
}
// websocket opens a WebSocket session to the /echo endpoint over st.conn,
// writes rp.body, and returns the first message read back (up to 1024
// bytes) as the response body. rp.name is sent in the Test-Case header,
// followed by the fields in rp.header. Every failure aborts the test via
// st.t.Fatalf, so the returned error is always nil.
func (st *serverTester) websocket(rp requestParam) (*serverResponse, error) {
	target := st.url + "/echo"

	cfg, err := websocket.NewConfig(target, st.url)
	if err != nil {
		st.t.Fatalf("websocket.NewConfig(%q, %q) returned error: %v", target, st.url, err)
	}

	cfg.Header.Add("Test-Case", rp.name)
	for i := range rp.header {
		cfg.Header.Add(rp.header[i].Name, rp.header[i].Value)
	}

	ws, err := websocket.NewClient(cfg, st.conn)
	if err != nil {
		st.t.Fatalf("Error creating websocket client: %v", err)
	}

	if _, err := ws.Write(rp.body); err != nil {
		st.t.Fatalf("ws.Write() returned error: %v", err)
	}

	buf := make([]byte, 1024)
	n, err := ws.Read(buf)
	if err != nil {
		st.t.Fatalf("ws.Read() returned error: %v", err)
	}

	return &serverResponse{body: buf[:n]}, nil
}
// http3 sends the request described by rp through a freshly created
// http3.RoundTripper (certificate verification disabled) and returns the
// status, header, fully read body and Close flag of the response.
//
// rp.method defaults to GET. When rp.path is set, it replaces the path
// and query of st.url. rp.header fields are added first, followed by the
// Test-Case header carrying rp.name. rp.trailer is not sent.
func (st *serverTester) http3(rp requestParam) (*serverResponse, error) {
	transport := &http3.RoundTripper{
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
	}
	defer transport.Close()

	client := &http.Client{Transport: transport}

	method := rp.method
	if method == "" {
		method = "GET"
	}

	// reqBody stays a nil interface when there is no body.
	var reqBody io.Reader
	if rp.body != nil {
		reqBody = bytes.NewBuffer(rp.body)
	}

	target := st.url
	if rp.path != "" {
		base, err := url.Parse(st.url)
		if err != nil {
			st.t.Fatalf("Error parsing URL from st.url %v: %v", st.url, err)
		}
		base.Path = ""
		base.RawQuery = ""
		target = base.String() + rp.path
	}

	req, err := http.NewRequest(method, target, reqBody)
	if err != nil {
		return nil, err
	}
	for i := range rp.header {
		req.Header.Add(rp.header[i].Name, rp.header[i].Value)
	}
	req.Header.Add("Test-Case", rp.name)

	// TODO http3 package does not support trailer at the time of
	// this writing.

	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	return &serverResponse{
		status:    resp.StatusCode,
		header:    resp.Header,
		body:      payload,
		connClose: resp.Close,
	}, nil
}
// http1 writes the HTTP/1 request described by rp directly to st.conn and
// reads one response from it, returning the status, header, fully read
// body and Close flag of the response.
//
// rp.method defaults to GET. When rp.path is set, it replaces the path
// and query of st.url. rp.header fields are added first, followed by the
// Test-Case header carrying rp.name. When both rp.body and rp.trailer are
// set, the request is sent with chunked encoding and the trailer fields
// are attached through chunkedBodyReader.
//
// An error is returned if building, writing or reading fails; a failure
// to parse st.url aborts the test via st.t.Fatalf.
func (st *serverTester) http1(rp requestParam) (*serverResponse, error) {
	method := "GET"
	if rp.method != "" {
		method = rp.method
	}

	var body io.Reader
	var cbr *chunkedBodyReader
	if rp.body != nil {
		body = bytes.NewBuffer(rp.body)
		if len(rp.trailer) != 0 {
			cbr = &chunkedBodyReader{
				trailer: rp.trailer,
				body:    body,
			}
			body = cbr
		}
	}

	reqURL := st.url
	if rp.path != "" {
		u, err := url.Parse(st.url)
		if err != nil {
			st.t.Fatalf("Error parsing URL from st.url %v: %v", st.url, err)
		}
		u.Path = ""
		u.RawQuery = ""
		reqURL = u.String() + rp.path
	}

	req, err := http.NewRequest(method, reqURL, body)
	if err != nil {
		return nil, err
	}
	for _, h := range rp.header {
		req.Header.Add(h.Name, h.Value)
	}
	req.Header.Add("Test-Case", rp.name)
	if cbr != nil {
		cbr.req = req
		// this makes request use chunked encoding
		req.ContentLength = -1
		// Announce the trailer names up front; the values are filled
		// in by chunkedBodyReader while the body is being read.
		req.Trailer = make(http.Header)
		for _, h := range cbr.trailer {
			req.Trailer.Set(h.Name, "")
		}
	}

	if err := req.Write(st.conn); err != nil {
		return nil, err
	}

	resp, err := http.ReadResponse(bufio.NewReader(st.conn), req)
	if err != nil {
		return nil, err
	}
	// Close the body on every exit path, including a failed read below.
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	res := &serverResponse{
		status:    resp.StatusCode,
		header:    resp.Header,
		body:      respBody,
		connClose: resp.Close,
	}
	return res, nil
}
// http2 sends the request described by rp as HTTP/2 frames on st.conn and
// then reads frames until the request stream and every stream promised
// via PUSH_PROMISE have ended, or until a GOAWAY carrying an error code
// is received.
//
// The client connection preface and an empty SETTINGS frame are written
// on the first call only. When rp.httpUpgrade is true no request frames
// are written and only the response is read.
//
// Errors while writing the request are returned with a nil response.
// Errors while reading or decoding the response are returned together
// with the response collected so far. On success the pushed responses in
// res.pushResponse are sorted by stream ID.
func (st *serverTester) http2(rp requestParam) (*serverResponse, error) {
	// Start from a clean header block buffer and received-header set.
	st.headerBlkBuf.Reset()
	st.header = make(http.Header)

	var id uint32
	if rp.streamID != 0 {
		id = rp.streamID
		// Advance the automatic counter past an explicitly chosen odd
		// stream ID that is not below it.
		if id >= st.nextStreamID && id%2 == 1 {
			st.nextStreamID = id + 2
		}
	} else {
		id = st.nextStreamID
		st.nextStreamID += 2
	}

	if !st.h2PrefaceSent {
		st.h2PrefaceSent = true
		fmt.Fprint(st.conn, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
		if err := st.fr.WriteSettings(); err != nil {
			return nil, err
		}
	}

	res := &serverResponse{
		streamID: id,
	}

	// streams tracks the streams that have not ended yet: the request
	// stream plus any promised streams.
	streams := make(map[uint32]*serverResponse)
	streams[id] = res

	if !rp.httpUpgrade {
		method := "GET"
		if rp.method != "" {
			method = rp.method
		}
		_ = st.enc.WriteField(pair(":method", method))

		scheme := "http"
		if rp.scheme != "" {
			scheme = rp.scheme
		}
		_ = st.enc.WriteField(pair(":scheme", scheme))

		authority := st.authority
		if rp.authority != "" {
			authority = rp.authority
		}
		_ = st.enc.WriteField(pair(":authority", authority))

		path := "/"
		if rp.path != "" {
			path = rp.path
		}
		_ = st.enc.WriteField(pair(":path", path))

		_ = st.enc.WriteField(pair("test-case", rp.name))

		for _, h := range rp.header {
			_ = st.enc.WriteField(h)
		}

		// END_STREAM goes on HEADERS only when neither a body nor a
		// trailer follows and the caller did not suppress it.
		err := st.fr.WriteHeaders(http2.HeadersFrameParam{
			StreamID:      id,
			EndStream:     len(rp.body) == 0 && len(rp.trailer) == 0 && !rp.noEndStream,
			EndHeaders:    true,
			BlockFragment: st.headerBlkBuf.Bytes(),
		})
		if err != nil {
			return nil, err
		}

		if len(rp.body) != 0 {
			// TODO we assume rp.body fits in 1 frame
			if err := st.fr.WriteData(id, len(rp.trailer) == 0 && !rp.noEndStream, rp.body); err != nil {
				return nil, err
			}
		}

		if len(rp.trailer) != 0 {
			// Reuse the header block buffer for the trailer block.
			st.headerBlkBuf.Reset()
			for _, h := range rp.trailer {
				_ = st.enc.WriteField(h)
			}
			err := st.fr.WriteHeaders(http2.HeadersFrameParam{
				StreamID:      id,
				EndStream:     true,
				EndHeaders:    true,
				BlockFragment: st.headerBlkBuf.Bytes(),
			})
			if err != nil {
				return nil, err
			}
		}
	}
loop:
	for {
		fr, err := st.readFrame()
		if err != nil {
			return res, err
		}
		switch f := fr.(type) {
		case *http2.HeadersFrame:
			// Decoded fields are added to st.header by the decoder
			// callback.
			_, err := st.dec.Write(f.HeaderBlockFragment())
			if err != nil {
				return res, err
			}
			sr, ok := streams[f.FrameHeader.StreamID]
			if !ok {
				// Header block for a stream that is not tracked:
				// discard the decoded fields.
				st.header = make(http.Header)
				break
			}
			sr.header = cloneHeader(st.header)
			var status int
			status, err = strconv.Atoi(sr.header.Get(":status"))
			if err != nil {
				return res, fmt.Errorf("Error parsing status code: %w", err)
			}
			sr.status = status
			if f.StreamEnded() {
				if streamEnded(res, streams, sr) {
					break loop
				}
			}
		case *http2.PushPromiseFrame:
			_, err := st.dec.Write(f.HeaderBlockFragment())
			if err != nil {
				return res, err
			}
			// Track the promised stream so that the loop also waits
			// for the pushed response.
			sr := &serverResponse{
				streamID:  f.PromiseID,
				reqHeader: cloneHeader(st.header),
			}
			streams[sr.streamID] = sr
		case *http2.DataFrame:
			sr, ok := streams[f.FrameHeader.StreamID]
			if !ok {
				break
			}
			sr.body = append(sr.body, f.Data()...)
			if f.StreamEnded() {
				if streamEnded(res, streams, sr) {
					break loop
				}
			}
		case *http2.RSTStreamFrame:
			sr, ok := streams[f.FrameHeader.StreamID]
			if !ok {
				break
			}
			// A reset ends the stream; remember why.
			sr.errCode = f.ErrCode
			if streamEnded(res, streams, sr) {
				break loop
			}
		case *http2.GoAwayFrame:
			// GOAWAY without an error code is ignored; keep reading.
			if f.ErrCode == http2.ErrCodeNo {
				break
			}
			res.errCode = f.ErrCode
			res.connErr = true
			break loop
		case *http2.SettingsFrame:
			// Acknowledge the peer's SETTINGS; ACK frames themselves
			// need no reply.
			if f.IsAck() {
				break
			}
			if err := st.fr.WriteSettingsAck(); err != nil {
				return res, err
			}
		}
	}
	sort.Sort(ByStreamID(res.pushResponse))
	return res, nil
}
// streamEnded records that the stream of sr has ended: sr is removed from
// streams and, unless it is the main response itself, appended to
// mainSr.pushResponse. It reports whether no streams are left in streams.
func streamEnded(mainSr *serverResponse, streams map[uint32]*serverResponse, sr *serverResponse) bool {
	delete(streams, sr.streamID)

	if pushed := sr.streamID != mainSr.streamID; pushed {
		mainSr.pushResponse = append(mainSr.pushResponse, sr)
	}

	remaining := len(streams)
	return remaining == 0
}
// serverResponse is the response observed by the test client for one
// request or one pushed stream.
type serverResponse struct {
	// status is the HTTP status code.
	status int
	// header holds the response header fields.
	header http.Header
	// body is the response body.
	body []byte
	// streamID is the stream ID in HTTP/2.
	streamID uint32
	// errCode is the error code received in HTTP/2 RST_STREAM or
	// GOAWAY.
	errCode http2.ErrCode
	// connErr is true if an HTTP/2 connection error occurred.
	connErr bool
	// connClose is true if Connection: close is included in the
	// response header in an HTTP/1 test.
	connClose bool
	// reqHeader is the http request header; currently it only stores
	// the pushed request header.
	reqHeader http.Header
	// pushResponse holds the pushed responses.
	pushResponse []*serverResponse
}
// ByStreamID implements sort.Interface for a slice of responses, ordering
// them by ascending stream ID.
type ByStreamID []*serverResponse

// Len returns the number of responses.
func (b ByStreamID) Len() int { return len(b) }

// Swap exchanges the responses at indexes i and j.
func (b ByStreamID) Swap(i, j int) {
	first, second := b[i], b[j]
	b[i], b[j] = second, first
}

// Less reports whether the response at index i has a smaller stream ID
// than the response at index j.
func (b ByStreamID) Less(i, j int) bool {
	left, right := b[i].streamID, b[j].streamID
	return left < right
}
// cloneHeader returns a deep copy of h: the returned map and each of its
// value slices are freshly allocated, so changes to the copy do not
// affect h. The result is never nil, even when h is nil.
func cloneHeader(h http.Header) http.Header {
	clone := make(http.Header, len(h))
	for name, values := range h {
		// Appending to a fresh zero-length slice copies the values
		// into new backing storage.
		clone[name] = append(make([]string, 0, len(values)), values...)
	}
	return clone
}
// noopHandler is the default backend handler. It reads the request body
// to the end, discards it, and writes nothing to w itself.
func noopHandler(w http.ResponseWriter, r *http.Request) {
	// Read errors are deliberately ignored; the handler only drains.
	_, _ = io.Copy(io.Discard, r.Body)
}
// APIResponse is the JSON document shape used for API responses. Each
// field is omitted from the encoded JSON when it is empty.
type APIResponse struct {
	// Status is encoded as the "status" member.
	Status string `json:"status,omitempty"`

	// Code is encoded as the "code" member.
	Code int `json:"code,omitempty"`

	// Data is encoded as the "data" member.
	Data map[string]interface{} `json:"data,omitempty"`
}
// proxyProtocolV2 describes one PROXY protocol version 2 header to be
// serialized by writeProxyProtocolV2.
type proxyProtocolV2 struct {
	// command is written into the low nibble of the version/command
	// byte.
	command proxyProtocolV2Command
	// sourceAddress selects the address family of the header:
	// *net.TCPAddr, *net.UnixAddr, or anything else (including nil)
	// for an unspecified family.
	sourceAddress net.Addr
	// destinationAddress must have the same concrete type as
	// sourceAddress when that is *net.TCPAddr or *net.UnixAddr.
	destinationAddress net.Addr
	// additionalData is appended after the address block and is
	// counted in the header length field.
	additionalData []byte
}

// proxyProtocolV2Command is the command nibble of the header.
type proxyProtocolV2Command int

const (
	proxyProtocolV2CommandLocal proxyProtocolV2Command = 0x0
	proxyProtocolV2CommandProxy proxyProtocolV2Command = 0x1
)

// proxyProtocolV2Family is the address family, stored in the high nibble
// of the family/protocol byte.
type proxyProtocolV2Family int

const (
	proxyProtocolV2FamilyUnspec proxyProtocolV2Family = 0x0
	proxyProtocolV2FamilyInet   proxyProtocolV2Family = 0x1
	proxyProtocolV2FamilyInet6  proxyProtocolV2Family = 0x2
	proxyProtocolV2FamilyUnix   proxyProtocolV2Family = 0x3
)

// proxyProtocolV2Protocol is the transport protocol, stored in the low
// nibble of the family/protocol byte.
type proxyProtocolV2Protocol int

const (
	proxyProtocolV2ProtocolUnspec proxyProtocolV2Protocol = 0x0
	proxyProtocolV2ProtocolStream proxyProtocolV2Protocol = 0x1
	proxyProtocolV2ProtocolDgram  proxyProtocolV2Protocol = 0x2
)

// writeProxyProtocolV2 serializes hdr to w as a PROXY protocol version 2
// header: the 12-byte signature, the version/command byte, the
// family/protocol byte, the big-endian 16-bit length, the address block
// and finally hdr.additionalData.
//
// For TCP addresses, IPv4 addresses are accepted in both their 4-byte and
// 16-byte forms; when source and destination are both IPv4 they are
// written as 4-byte AF_INET addresses. Unix addresses are zero-padded to
// 108 bytes each.
//
// It panics if the destination address does not have the same concrete
// type as the source address, if the TCP address lengths differ, or if a
// Unix address name is longer than 108 bytes. Errors from w are ignored.
func writeProxyProtocolV2(w io.Writer, hdr proxyProtocolV2) {
	// Protocol signature.
	w.Write([]byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A})
	// Version 2 in the high nibble, command in the low nibble.
	w.Write([]byte{byte(0x20 | hdr.command)})

	switch srcAddr := hdr.sourceAddress.(type) {
	case *net.TCPAddr:
		dstAddr := hdr.destinationAddress.(*net.TCPAddr)

		// net.ParseIP and friends return IPv4 addresses in 16-byte
		// form. Normalize to 4 bytes when both ends are IPv4 so that
		// such addresses are encoded as AF_INET rather than AF_INET6.
		srcIP, dstIP := srcAddr.IP, dstAddr.IP
		if src4, dst4 := srcIP.To4(), dstIP.To4(); src4 != nil && dst4 != nil {
			srcIP, dstIP = src4, dst4
		}
		if len(srcIP) != len(dstIP) {
			panic("len(srcAddr.IP) != len(dstAddr.IP)")
		}

		var fam byte
		if len(srcIP) == 4 {
			fam = byte(proxyProtocolV2FamilyInet << 4)
		} else {
			fam = byte(proxyProtocolV2FamilyInet6 << 4)
		}
		fam |= byte(proxyProtocolV2ProtocolStream)
		w.Write([]byte{fam})

		// Two addresses plus two 16-bit ports, then the extra data.
		length := uint16(len(srcIP)*2 + 4 + len(hdr.additionalData))
		binary.Write(w, binary.BigEndian, length)
		w.Write(srcIP)
		w.Write(dstIP)
		binary.Write(w, binary.BigEndian, uint16(srcAddr.Port))
		binary.Write(w, binary.BigEndian, uint16(dstAddr.Port))
	case *net.UnixAddr:
		dstAddr := hdr.destinationAddress.(*net.UnixAddr)
		if len(srcAddr.Name) > 108 {
			panic("too long Unix source address")
		}
		if len(dstAddr.Name) > 108 {
			panic("too long Unix destination address")
		}

		fam := byte(proxyProtocolV2FamilyUnix << 4)
		switch srcAddr.Net {
		case "unix":
			fam |= byte(proxyProtocolV2ProtocolStream)
		case "unixdgram":
			fam |= byte(proxyProtocolV2ProtocolDgram)
		default:
			fam |= byte(proxyProtocolV2ProtocolUnspec)
		}
		w.Write([]byte{fam})

		// Two 108-byte address fields, then the extra data.
		length := uint16(216 + len(hdr.additionalData))
		binary.Write(w, binary.BigEndian, length)

		// Each name is padded with zero bytes up to 108 bytes.
		zeros := make([]byte, 108)
		w.Write([]byte(srcAddr.Name))
		w.Write(zeros[:108-len(srcAddr.Name)])
		w.Write([]byte(dstAddr.Name))
		w.Write(zeros[:108-len(dstAddr.Name)])
	default:
		// No address block: unspecified family and protocol.
		fam := byte(proxyProtocolV2FamilyUnspec<<4) | byte(proxyProtocolV2ProtocolUnspec)
		w.Write([]byte{fam})
		length := uint16(len(hdr.additionalData))
		binary.Write(w, binary.BigEndian, length)
	}

	w.Write(hdr.additionalData)
}