diff --git a/go.mod b/go.mod index 52950c0e..05c60529 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/yarpc/yab -go 1.23 +go 1.23.0 toolchain go1.24.0 diff --git a/options.go b/options.go index 921e3fc7..9a5be9af 100644 --- a/options.go +++ b/options.go @@ -95,6 +95,10 @@ type TransportOptions struct { HTTPMethod string `long:"http-method" description:"The HTTP method to use"` GRPCMaxResponseSize int `long:"grpc-max-response-size" description:"Maximum response size for gRPC requests. Default value is 4MB"` ForceJaegerSample bool `long:"force-jaeger-sample" description:"Force all requests to be sampled for Jaeger tracing (use with --jaeger)"` + + // Enables HTTP2 transport + UseHTTP2 bool `long:"http2" description:"Enable HTTP/2 for HTTP transport"` + // This is a hack to work around go-flags not allowing disabling flags: // https://github.com/jessevdk/go-flags/issues/191 // Do not specify this value in a defaults.ini file as it is not possible diff --git a/transport.go b/transport.go index cef4f177..7dfff2a3 100644 --- a/transport.go +++ b/transport.go @@ -185,6 +185,7 @@ func getTransport(opts TransportOptions, resolved resolvedProtocolEncoding, trac Encoding: resolved.enc.String(), URLs: opts.Peers, Tracer: tracer, + UseHTTP2: opts.UseHTTP2, } return transport.NewHTTP(hopts) } diff --git a/transport/http.go b/transport/http.go index e4b9f19d..91b638f3 100644 --- a/transport/http.go +++ b/transport/http.go @@ -22,14 +22,18 @@ package transport import ( "bytes" + "crypto/tls" "errors" "fmt" "io/ioutil" "math/rand" + "net" "net/http" "strconv" "time" + "golang.org/x/net/http2" + "github.com/opentracing/opentracing-go" "golang.org/x/net/context" ) @@ -51,6 +55,9 @@ type HTTPOptions struct { ShardKey string Encoding string Tracer opentracing.Tracer + + // HTTP/2 specific options + UseHTTP2 bool } var ( @@ -70,11 +77,25 @@ func NewHTTP(opts HTTPOptions) (Transport, error) { opts.Method = "POST" } + var transport http.RoundTripper + + if opts.UseHTTP2 { + transport = &http2.Transport{ + AllowHTTP: true, + DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, addr) + }, + } + } else { + transport = &http.Transport{} + } + return &httpTransport{ opts: opts, // Use independent HTTP clients for each transport. client: &http.Client{ - Transport: &http.Transport{}, + Transport: transport, }, tracer: opts.Tracer, }, nil diff --git a/transport/http_test.go b/transport/http_test.go index bdac7d3e..6aba71a8 100644 --- a/transport/http_test.go +++ b/transport/http_test.go @@ -22,7 +22,6 @@ package transport import ( "io" - "io/ioutil" "net/http" "net/http/httptest" "net/url" @@ -30,12 +29,21 @@ import ( "testing" "time" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + "golang.org/x/net/context" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +var ( + defaultIdleConnTimeout = 15 * time.Minute + serverTimeout = 20 * time.Millisecond + clientTimeout = 10 * time.Millisecond +) + func TestHTTPConstructor(t *testing.T) { tests := []struct { opts HTTPOptions @@ -71,73 +79,19 @@ func TestHTTPConstructor(t *testing.T) { } func TestHTTPCall(t *testing.T) { - timeoutCtx, _ := context.WithTimeout(context.Background(), 3*time.Second) - immediateTimeout, _ := context.WithTimeout(context.Background(), time.Nanosecond) - - tests := []struct { - msg string - ctxOverride context.Context - hook string - method string - errMsg string - ttlMin time.Duration - ttlMax time.Duration - wantCode int - wantBody []byte // If nil, uses the request body - }{ - { - msg: "ok", - ttlMin: time.Second, - ttlMax: time.Second, - wantCode: http.StatusOK, - }, - { - msg: "3 second timeout", - ctxOverride: timeoutCtx, - ttlMin: 3*time.Second - 100*time.Millisecond, - ttlMax: 3 * time.Second, - wantCode: http.StatusOK, - }, - { - msg: "timed out", - ctxOverride: immediateTimeout, - errMsg: context.DeadlineExceeded.Error(), - }, + tests := getCommonHttpTestCases() + tests = append(tests, []TestBody{ { msg: "connection closed before data", hook: "kill_conn", errMsg: "EOF", }, - { - msg: "bad request response", - hook: "bad_req", - errMsg: "non-success response code: 400, body: bad request", - }, { msg: "connection closed after data", hook: "flush_and_kill", errMsg: "unexpected EOF", }, - { - msg: "no content", - hook: "no_content", - wantCode: http.StatusNoContent, - wantBody: []byte{}, - }, - { - msg: "default method to POST", - hook: "method", - wantCode: http.StatusOK, - wantBody: []byte("POST"), - }, - { - msg: "override method to GET", - method: "GET", - hook: "method", - wantCode: http.StatusOK, - wantBody: []byte("GET"), - }, - } + }...) lastReq := struct { url *url.URL @@ -149,7 +103,7 @@ func TestHTTPCall(t *testing.T) { var err error lastReq.url = r.URL lastReq.headers = r.Header - lastReq.body, err = ioutil.ReadAll(r.Body) + lastReq.body, err = io.ReadAll(r.Body) require.NoError(t, err, "Failed to read body from request") // Test hooks to change the request behaviour. @@ -169,6 +123,9 @@ func TestHTTPCall(t *testing.T) { case "server_err": w.WriteHeader(http.StatusInternalServerError) return + case "timeout_hook": + time.Sleep(serverTimeout) + return case "flush_and_kill": io.WriteString(w, "some data") flusher := w.(http.Flusher) @@ -254,3 +211,185 @@ func TestHTTPCall(t *testing.T) { }) } } + +func TestHTTP2Call(t *testing.T) { + tests := getCommonHttpTestCases() + + lastReq := struct { + url *url.URL + headers http.Header + body []byte + }{} + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + lastReq.url = r.URL + lastReq.headers = r.Header + lastReq.body, err = io.ReadAll(r.Body) + require.NoError(t, err, "Failed to read body from request") + + // Test hooks to change the request behaviour. + switch f := r.Header.Get("hook"); f { + case "no_content": + w.Header().Set("Custom-Header", "ok") + w.WriteHeader(http.StatusNoContent) + return + case "method": + w.Header().Set("Custom-Header", "ok") + io.WriteString(w, r.Method) + return + case "bad_req": + w.WriteHeader(http.StatusBadRequest) + io.WriteString(w, "bad request") + return + case "server_err": + w.WriteHeader(http.StatusInternalServerError) + return + case "timeout_hook": + time.Sleep(serverTimeout) + return + } + + w.Header().Set("Custom-Header", "ok") + io.WriteString(w, "ok") + }) + + h2s := &http2.Server{ + IdleTimeout: defaultIdleConnTimeout, + } + + svr := httptest.NewServer(h2c.NewHandler(handler, h2s)) + defer svr.Close() + + for _, tt := range tests { + t.Run(tt.msg, func(t *testing.T) { + transport, err := NewHTTP(HTTPOptions{ + Method: tt.method, + URLs: []string{svr.URL + "/rpc"}, + SourceService: "source", + TargetService: "target", + ShardKey: "sk", + RoutingKey: "rk", + RoutingDelegate: "rd", + Encoding: "raw", + UseHTTP2: true, + }) + require.NoError(t, err, "Failed to create HTTP transport") + + ctx := context.Background() + if tt.ctxOverride != nil { + ctx = tt.ctxOverride + } + + r := &Request{Method: "method", Body: []byte{1, 2, 3}} + + r.TransportHeaders = map[string]string{"hook": tt.hook} + r.Headers = map[string]string{"headerkey": "headervalue"} + got, err := transport.Call(ctx, r) + if tt.errMsg != "" { + if assert.Error(t, err, "Call should fail") { + assert.Contains(t, err.Error(), tt.errMsg, "Unexpected error") + } + return + } + + if !assert.NoError(t, err, "Call shouldn't fail") { + return + } + + wantBody := tt.wantBody + if wantBody == nil { + wantBody = []byte("ok") + } + if !assert.Equal(t, wantBody, got.Body, "Response body mismatch") { + return + } + + assert.Equal(t, "/rpc", lastReq.url.Path, "Path mismatch") + assert.Equal(t, "target", lastReq.headers.Get("Rpc-Service"), "Service header mismatch") + assert.Equal(t, "source", lastReq.headers.Get("Rpc-Caller"), "Caller header mismatch") + assert.Equal(t, "sk", lastReq.headers.Get("Rpc-Shard-Key"), "Shard key header mismatch") + assert.Equal(t, "rk", lastReq.headers.Get("Rpc-Routing-Key"), "Routing key header mismatch") + assert.Equal(t, "rd", lastReq.headers.Get("Rpc-Routing-Delegate"), "Routing delegate header mismatch") + assert.Equal(t, r.Method, lastReq.headers.Get("Rpc-Procedure"), "Method header mismatch") + assert.Equal(t, "raw", lastReq.headers.Get("Rpc-Encoding"), "Encoding header mismatch") + assert.Equal(t, "headervalue", lastReq.headers.Get("Rpc-Header-Headerkey"), "Application header is sent with prefix") + + if tt.ttlMin != 0 && tt.ttlMax != 0 { + ttlMS, err := strconv.Atoi(lastReq.headers.Get("Context-TTL-MS")) + if assert.NoError(t, err, "Failed to parse TTLms header: %v", lastReq.headers.Get("YARPC-TTLms")) { + gotTTL := time.Duration(ttlMS) * time.Millisecond + assert.True(t, gotTTL >= tt.ttlMin && gotTTL <= tt.ttlMax, + "Got TTL %v out of range [%v,%v]", gotTTL, tt.ttlMin, tt.ttlMax) + } + } + + assert.Equal(t, "ok", got.Headers["Custom-Header"], "Header mismatch") + assert.Equal(t, tt.wantCode, got.TransportFields["statusCode"], "Status code mismatch") + assert.Equal(t, lastReq.body, r.Body, "Body mismatch") + }) + } +} + +type TestBody struct { + msg string + ctxOverride context.Context + hook string + method string + errMsg string + ttlMin time.Duration + ttlMax time.Duration + wantCode int + wantBody []byte // If nil, uses the request body +} + +func getCommonHttpTestCases() []TestBody { + timeoutCtx, _ := context.WithTimeout(context.Background(), 3*time.Second) + immediateTimeout, _ := context.WithTimeout(context.Background(), clientTimeout) + + return []TestBody{ + { + msg: "ok", + ttlMin: time.Second, + ttlMax: time.Second, + wantCode: http.StatusOK, + }, + { + msg: "3 second timeout", + ctxOverride: timeoutCtx, + ttlMin: 3*time.Second - 100*time.Millisecond, + ttlMax: 3 * time.Second, + wantCode: http.StatusOK, + }, + { + msg: "timed out", + ctxOverride: immediateTimeout, + hook: "timeout_hook", + errMsg: context.DeadlineExceeded.Error(), + }, + { + msg: "bad request response", + hook: "bad_req", + errMsg: "non-success response code: 400, body: bad request", + }, + { + msg: "no content", + hook: "no_content", + wantCode: http.StatusNoContent, + wantBody: []byte{}, + }, + { + msg: "default method to POST", + hook: "method", + wantCode: http.StatusOK, + wantBody: []byte("POST"), + }, + { + msg: "override method to GET", + method: "GET", + hook: "method", + wantCode: http.StatusOK, + wantBody: []byte("GET"), + }, + } +}