From 8cf0cbfed4bac9a4f97ef1487e5be34c58d12e30 Mon Sep 17 00:00:00 2001 From: Mark Date: Mon, 18 Aug 2025 16:42:41 +0200 Subject: [PATCH] Added http2 support to yab Remove old comments Fix comment stating default http2 behavior in options Fixed test TimedOutUsingHTTP2Transport Fix comments Check if http2 flag is enabled when sending Thrift encoding Fix cli http2 option comment; Move test cases for http1/2 to method Reverted check for http2 flag in cli; Fix creation of http1 server --- go.mod | 2 +- options.go | 4 + transport.go | 1 + transport/http.go | 23 +++- transport/http_test.go | 257 +++++++++++++++++++++++++++++++---------- 5 files changed, 226 insertions(+), 61 deletions(-) diff --git a/go.mod b/go.mod index 52950c0e..05c60529 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/yarpc/yab -go 1.23 +go 1.23.0 toolchain go1.24.0 diff --git a/options.go b/options.go index 921e3fc7..9a5be9af 100644 --- a/options.go +++ b/options.go @@ -95,6 +95,10 @@ type TransportOptions struct { HTTPMethod string `long:"http-method" description:"The HTTP method to use"` GRPCMaxResponseSize int `long:"grpc-max-response-size" description:"Maximum response size for gRPC requests. Default value is 4MB"` ForceJaegerSample bool `long:"force-jaeger-sample" description:"Force all requests to be sampled for Jaeger tracing (use with --jaeger)"` + + // Enables HTTP2 transport + UseHTTP2 bool `long:"http2" description:"Enable HTTP/2 for HTTP transport"` + // This is a hack to work around go-flags not allowing disabling flags: // https://github.com/jessevdk/go-flags/issues/191 // Do not specify this value in a defaults.ini file as it is not possible diff --git a/transport.go b/transport.go index cef4f177..7dfff2a3 100644 --- a/transport.go +++ b/transport.go @@ -185,6 +185,7 @@ func getTransport(opts TransportOptions, resolved resolvedProtocolEncoding, trac Encoding: resolved.enc.String(), URLs: opts.Peers, Tracer: tracer, + UseHTTP2: opts.UseHTTP2, } return transport.NewHTTP(hopts) } diff --git a/transport/http.go b/transport/http.go index e4b9f19d..91b638f3 100644 --- a/transport/http.go +++ b/transport/http.go @@ -22,14 +22,18 @@ package transport import ( "bytes" + "crypto/tls" "errors" "fmt" "io/ioutil" "math/rand" + "net" "net/http" "strconv" "time" + "golang.org/x/net/http2" + "github.com/opentracing/opentracing-go" "golang.org/x/net/context" ) @@ -51,6 +55,9 @@ type HTTPOptions struct { ShardKey string Encoding string Tracer opentracing.Tracer + + // HTTP/2 specific options + UseHTTP2 bool } var ( @@ -70,11 +77,25 @@ func NewHTTP(opts HTTPOptions) (Transport, error) { opts.Method = "POST" } + var transport http.RoundTripper + + if opts.UseHTTP2 { + transport = &http2.Transport{ + AllowHTTP: true, + DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, addr) + }, + } + } else { + transport = &http.Transport{} + } + return &httpTransport{ opts: opts, // Use independent HTTP clients for each transport. client: &http.Client{ - Transport: &http.Transport{}, + Transport: transport, }, tracer: opts.Tracer, }, nil diff --git a/transport/http_test.go b/transport/http_test.go index bdac7d3e..6aba71a8 100644 --- a/transport/http_test.go +++ b/transport/http_test.go @@ -22,7 +22,6 @@ package transport import ( "io" - "io/ioutil" "net/http" "net/http/httptest" "net/url" @@ -30,12 +29,21 @@ import ( "testing" "time" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + "golang.org/x/net/context" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +var ( + defaultIdleConnTimeout = 15 * time.Minute + serverTimeout = 20 * time.Millisecond + clientTimeout = 10 * time.Millisecond +) + func TestHTTPConstructor(t *testing.T) { tests := []struct { opts HTTPOptions @@ -71,73 +79,19 @@ func TestHTTPConstructor(t *testing.T) { } func TestHTTPCall(t *testing.T) { - timeoutCtx, _ := context.WithTimeout(context.Background(), 3*time.Second) - immediateTimeout, _ := context.WithTimeout(context.Background(), time.Nanosecond) - - tests := []struct { - msg string - ctxOverride context.Context - hook string - method string - errMsg string - ttlMin time.Duration - ttlMax time.Duration - wantCode int - wantBody []byte // If nil, uses the request body - }{ - { - msg: "ok", - ttlMin: time.Second, - ttlMax: time.Second, - wantCode: http.StatusOK, - }, - { - msg: "3 second timeout", - ctxOverride: timeoutCtx, - ttlMin: 3*time.Second - 100*time.Millisecond, - ttlMax: 3 * time.Second, - wantCode: http.StatusOK, - }, - { - msg: "timed out", - ctxOverride: immediateTimeout, - errMsg: context.DeadlineExceeded.Error(), - }, + tests := getCommonHttpTestCases() + tests = append(tests, []TestBody{ { msg: "connection closed before data", hook: "kill_conn", errMsg: "EOF", }, - { - msg: "bad request response", - hook: "bad_req", - errMsg: "non-success response code: 400, body: bad request", - }, { msg: "connection closed after data", hook: "flush_and_kill", errMsg: "unexpected EOF", }, - { - msg: "no content", - hook: "no_content", - wantCode: http.StatusNoContent, - wantBody: []byte{}, - }, - { - msg: "default method to POST", - hook: "method", - wantCode: http.StatusOK, - wantBody: []byte("POST"), - }, - { - msg: "override method to GET", - method: "GET", - hook: "method", - wantCode: http.StatusOK, - wantBody: []byte("GET"), - }, - } + }...) lastReq := struct { url *url.URL @@ -149,7 +103,7 @@ func TestHTTPCall(t *testing.T) { var err error lastReq.url = r.URL lastReq.headers = r.Header - lastReq.body, err = ioutil.ReadAll(r.Body) + lastReq.body, err = io.ReadAll(r.Body) require.NoError(t, err, "Failed to read body from request") // Test hooks to change the request behaviour. @@ -169,6 +123,9 @@ func TestHTTPCall(t *testing.T) { case "server_err": w.WriteHeader(http.StatusInternalServerError) return + case "timeout_hook": + time.Sleep(serverTimeout) + return case "flush_and_kill": io.WriteString(w, "some data") flusher := w.(http.Flusher) @@ -254,3 +211,185 @@ func TestHTTPCall(t *testing.T) { }) } } + +func TestHTTP2Call(t *testing.T) { + tests := getCommonHttpTestCases() + + lastReq := struct { + url *url.URL + headers http.Header + body []byte + }{} + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var err error + lastReq.url = r.URL + lastReq.headers = r.Header + lastReq.body, err = io.ReadAll(r.Body) + require.NoError(t, err, "Failed to read body from request") + + // Test hooks to change the request behaviour. + switch f := r.Header.Get("hook"); f { + case "no_content": + w.Header().Set("Custom-Header", "ok") + w.WriteHeader(http.StatusNoContent) + return + case "method": + w.Header().Set("Custom-Header", "ok") + io.WriteString(w, r.Method) + return + case "bad_req": + w.WriteHeader(http.StatusBadRequest) + io.WriteString(w, "bad request") + return + case "server_err": + w.WriteHeader(http.StatusInternalServerError) + return + case "timeout_hook": + time.Sleep(serverTimeout) + return + } + + w.Header().Set("Custom-Header", "ok") + io.WriteString(w, "ok") + }) + + h2s := &http2.Server{ + IdleTimeout: defaultIdleConnTimeout, + } + + svr := httptest.NewServer(h2c.NewHandler(handler, h2s)) + defer svr.Close() + + for _, tt := range tests { + t.Run(tt.msg, func(t *testing.T) { + transport, err := NewHTTP(HTTPOptions{ + Method: tt.method, + URLs: []string{svr.URL + "/rpc"}, + SourceService: "source", + TargetService: "target", + ShardKey: "sk", + RoutingKey: "rk", + RoutingDelegate: "rd", + Encoding: "raw", + UseHTTP2: true, + }) + require.NoError(t, err, "Failed to create HTTP transport") + + ctx := context.Background() + if tt.ctxOverride != nil { + ctx = tt.ctxOverride + } + + r := &Request{Method: "method", Body: []byte{1, 2, 3}} + + r.TransportHeaders = map[string]string{"hook": tt.hook} + r.Headers = map[string]string{"headerkey": "headervalue"} + got, err := transport.Call(ctx, r) + if tt.errMsg != "" { + if assert.Error(t, err, "Call should fail") { + assert.Contains(t, err.Error(), tt.errMsg, "Unexpected error") + } + return + } + + if !assert.NoError(t, err, "Call shouldn't fail") { + return + } + + wantBody := tt.wantBody + if wantBody == nil { + wantBody = []byte("ok") + } + if !assert.Equal(t, wantBody, got.Body, "Response body mismatch") { + return + } + + assert.Equal(t, "/rpc", lastReq.url.Path, "Path mismatch") + assert.Equal(t, "target", lastReq.headers.Get("Rpc-Service"), "Service header mismatch") + assert.Equal(t, "source", lastReq.headers.Get("Rpc-Caller"), "Caller header mismatch") + assert.Equal(t, "sk", lastReq.headers.Get("Rpc-Shard-Key"), "Shard key header mismatch") + assert.Equal(t, "rk", lastReq.headers.Get("Rpc-Routing-Key"), "Routing key header mismatch") + assert.Equal(t, "rd", lastReq.headers.Get("Rpc-Routing-Delegate"), "Routing delegate header mismatch") + assert.Equal(t, r.Method, lastReq.headers.Get("Rpc-Procedure"), "Method header mismatch") + assert.Equal(t, "raw", lastReq.headers.Get("Rpc-Encoding"), "Encoding header mismatch") + assert.Equal(t, "headervalue", lastReq.headers.Get("Rpc-Header-Headerkey"), "Application header is sent with prefix") + + if tt.ttlMin != 0 && tt.ttlMax != 0 { + ttlMS, err := strconv.Atoi(lastReq.headers.Get("Context-TTL-MS")) + if assert.NoError(t, err, "Failed to parse TTLms header: %v", lastReq.headers.Get("YARPC-TTLms")) { + gotTTL := time.Duration(ttlMS) * time.Millisecond + assert.True(t, gotTTL >= tt.ttlMin && gotTTL <= tt.ttlMax, + "Got TTL %v out of range [%v,%v]", gotTTL, tt.ttlMin, tt.ttlMax) + } + } + + assert.Equal(t, "ok", got.Headers["Custom-Header"], "Header mismatch") + assert.Equal(t, tt.wantCode, got.TransportFields["statusCode"], "Status code mismatch") + assert.Equal(t, lastReq.body, r.Body, "Body mismatch") + }) + } +} + +type TestBody struct { + msg string + ctxOverride context.Context + hook string + method string + errMsg string + ttlMin time.Duration + ttlMax time.Duration + wantCode int + wantBody []byte // If nil, uses the request body +} + +func getCommonHttpTestCases() []TestBody { + timeoutCtx, _ := context.WithTimeout(context.Background(), 3*time.Second) + immediateTimeout, _ := context.WithTimeout(context.Background(), clientTimeout) + + return []TestBody{ + { + msg: "ok", + ttlMin: time.Second, + ttlMax: time.Second, + wantCode: http.StatusOK, + }, + { + msg: "3 second timeout", + ctxOverride: timeoutCtx, + ttlMin: 3*time.Second - 100*time.Millisecond, + ttlMax: 3 * time.Second, + wantCode: http.StatusOK, + }, + { + msg: "timed out", + ctxOverride: immediateTimeout, + hook: "timeout_hook", + errMsg: context.DeadlineExceeded.Error(), + }, + { + msg: "bad request response", + hook: "bad_req", + errMsg: "non-success response code: 400, body: bad request", + }, + { + msg: "no content", + hook: "no_content", + wantCode: http.StatusNoContent, + wantBody: []byte{}, + }, + { + msg: "default method to POST", + hook: "method", + wantCode: http.StatusOK, + wantBody: []byte("POST"), + }, + { + msg: "override method to GET", + method: "GET", + hook: "method", + wantCode: http.StatusOK, + wantBody: []byte("GET"), + }, + } +}