Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/yarpc/yab

go 1.23
go 1.23.0

toolchain go1.24.0

Expand Down
4 changes: 4 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ type TransportOptions struct {
HTTPMethod string `long:"http-method" description:"The HTTP method to use"`
GRPCMaxResponseSize int `long:"grpc-max-response-size" description:"Maximum response size for gRPC requests. Default value is 4MB"`
ForceJaegerSample bool `long:"force-jaeger-sample" description:"Force all requests to be sampled for Jaeger tracing (use with --jaeger)"`

// Enables HTTP2 transport
UseHTTP2 bool `long:"http2" description:"Enable HTTP/2 for HTTP transport"`

// This is a hack to work around go-flags not allowing disabling flags:
// https://github.com/jessevdk/go-flags/issues/191
// Do not specify this value in a defaults.ini file as it is not possible
Expand Down
1 change: 1 addition & 0 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ func getTransport(opts TransportOptions, resolved resolvedProtocolEncoding, trac
Encoding: resolved.enc.String(),
URLs: opts.Peers,
Tracer: tracer,
UseHTTP2: opts.UseHTTP2,
}
return transport.NewHTTP(hopts)
}
23 changes: 22 additions & 1 deletion transport/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,18 @@ package transport

import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"net"
"net/http"
"strconv"
"time"

"golang.org/x/net/http2"

"github.com/opentracing/opentracing-go"
"golang.org/x/net/context"
)
Expand All @@ -51,6 +55,9 @@ type HTTPOptions struct {
ShardKey string
Encoding string
Tracer opentracing.Tracer

// HTTP/2 specific options
UseHTTP2 bool
}

var (
Expand All @@ -70,11 +77,25 @@ func NewHTTP(opts HTTPOptions) (Transport, error) {
opts.Method = "POST"
}

var transport http.RoundTripper

if opts.UseHTTP2 {
transport = &http2.Transport{
AllowHTTP: true,
DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, network, addr)
},
}
} else {
transport = &http.Transport{}
}

return &httpTransport{
opts: opts,
// Use independent HTTP clients for each transport.
client: &http.Client{
Transport: &http.Transport{},
Transport: transport,
},
tracer: opts.Tracer,
}, nil
Expand Down
257 changes: 198 additions & 59 deletions transport/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,28 @@ package transport

import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
"time"

"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"

"golang.org/x/net/context"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var (
defaultIdleConnTimeout = 15 * time.Minute
serverTimeout = 20 * time.Millisecond
clientTimeout = 10 * time.Millisecond
)

func TestHTTPConstructor(t *testing.T) {
tests := []struct {
opts HTTPOptions
Expand Down Expand Up @@ -71,73 +79,19 @@ func TestHTTPConstructor(t *testing.T) {
}

func TestHTTPCall(t *testing.T) {
timeoutCtx, _ := context.WithTimeout(context.Background(), 3*time.Second)
immediateTimeout, _ := context.WithTimeout(context.Background(), time.Nanosecond)

tests := []struct {
msg string
ctxOverride context.Context
hook string
method string
errMsg string
ttlMin time.Duration
ttlMax time.Duration
wantCode int
wantBody []byte // If nil, uses the request body
}{
{
msg: "ok",
ttlMin: time.Second,
ttlMax: time.Second,
wantCode: http.StatusOK,
},
{
msg: "3 second timeout",
ctxOverride: timeoutCtx,
ttlMin: 3*time.Second - 100*time.Millisecond,
ttlMax: 3 * time.Second,
wantCode: http.StatusOK,
},
{
msg: "timed out",
ctxOverride: immediateTimeout,
errMsg: context.DeadlineExceeded.Error(),
},
tests := getCommonHttpTestCases()
tests = append(tests, []TestBody{
{
msg: "connection closed before data",
hook: "kill_conn",
errMsg: "EOF",
},
{
msg: "bad request response",
hook: "bad_req",
errMsg: "non-success response code: 400, body: bad request",
},
{
msg: "connection closed after data",
hook: "flush_and_kill",
errMsg: "unexpected EOF",
},
{
msg: "no content",
hook: "no_content",
wantCode: http.StatusNoContent,
wantBody: []byte{},
},
{
msg: "default method to POST",
hook: "method",
wantCode: http.StatusOK,
wantBody: []byte("POST"),
},
{
msg: "override method to GET",
method: "GET",
hook: "method",
wantCode: http.StatusOK,
wantBody: []byte("GET"),
},
}
}...)

lastReq := struct {
url *url.URL
Expand All @@ -149,7 +103,7 @@ func TestHTTPCall(t *testing.T) {
var err error
lastReq.url = r.URL
lastReq.headers = r.Header
lastReq.body, err = ioutil.ReadAll(r.Body)
lastReq.body, err = io.ReadAll(r.Body)
require.NoError(t, err, "Failed to read body from request")

// Test hooks to change the request behaviour.
Expand All @@ -169,6 +123,9 @@ func TestHTTPCall(t *testing.T) {
case "server_err":
w.WriteHeader(http.StatusInternalServerError)
return
case "timeout_hook":
time.Sleep(serverTimeout)
return
case "flush_and_kill":
io.WriteString(w, "some data")
flusher := w.(http.Flusher)
Expand Down Expand Up @@ -254,3 +211,185 @@ func TestHTTPCall(t *testing.T) {
})
}
}

func TestHTTP2Call(t *testing.T) {
tests := getCommonHttpTestCases()

lastReq := struct {
url *url.URL
headers http.Header
body []byte
}{}

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var err error
lastReq.url = r.URL
lastReq.headers = r.Header
lastReq.body, err = io.ReadAll(r.Body)
require.NoError(t, err, "Failed to read body from request")

// Test hooks to change the request behaviour.
switch f := r.Header.Get("hook"); f {
case "no_content":
w.Header().Set("Custom-Header", "ok")
w.WriteHeader(http.StatusNoContent)
return
case "method":
w.Header().Set("Custom-Header", "ok")
io.WriteString(w, r.Method)
return
case "bad_req":
w.WriteHeader(http.StatusBadRequest)
io.WriteString(w, "bad request")
return
case "server_err":
w.WriteHeader(http.StatusInternalServerError)
return
case "timeout_hook":
time.Sleep(serverTimeout)
return
}

w.Header().Set("Custom-Header", "ok")
io.WriteString(w, "ok")
})

h2s := &http2.Server{
IdleTimeout: defaultIdleConnTimeout,
}

svr := httptest.NewServer(h2c.NewHandler(handler, h2s))
defer svr.Close()

for _, tt := range tests {
t.Run(tt.msg, func(t *testing.T) {
transport, err := NewHTTP(HTTPOptions{
Method: tt.method,
URLs: []string{svr.URL + "/rpc"},
SourceService: "source",
TargetService: "target",
ShardKey: "sk",
RoutingKey: "rk",
RoutingDelegate: "rd",
Encoding: "raw",
UseHTTP2: true,
})
require.NoError(t, err, "Failed to create HTTP transport")

ctx := context.Background()
if tt.ctxOverride != nil {
ctx = tt.ctxOverride
}

r := &Request{Method: "method", Body: []byte{1, 2, 3}}

r.TransportHeaders = map[string]string{"hook": tt.hook}
r.Headers = map[string]string{"headerkey": "headervalue"}
got, err := transport.Call(ctx, r)
if tt.errMsg != "" {
if assert.Error(t, err, "Call should fail") {
assert.Contains(t, err.Error(), tt.errMsg, "Unexpected error")
}
return
}

if !assert.NoError(t, err, "Call shouldn't fail") {
return
}

wantBody := tt.wantBody
if wantBody == nil {
wantBody = []byte("ok")
}
if !assert.Equal(t, wantBody, got.Body, "Response body mismatch") {
return
}

assert.Equal(t, "/rpc", lastReq.url.Path, "Path mismatch")
assert.Equal(t, "target", lastReq.headers.Get("Rpc-Service"), "Service header mismatch")
assert.Equal(t, "source", lastReq.headers.Get("Rpc-Caller"), "Caller header mismatch")
assert.Equal(t, "sk", lastReq.headers.Get("Rpc-Shard-Key"), "Shard key header mismatch")
assert.Equal(t, "rk", lastReq.headers.Get("Rpc-Routing-Key"), "Routing key header mismatch")
assert.Equal(t, "rd", lastReq.headers.Get("Rpc-Routing-Delegate"), "Routing delegate header mismatch")
assert.Equal(t, r.Method, lastReq.headers.Get("Rpc-Procedure"), "Method header mismatch")
assert.Equal(t, "raw", lastReq.headers.Get("Rpc-Encoding"), "Encoding header mismatch")
assert.Equal(t, "headervalue", lastReq.headers.Get("Rpc-Header-Headerkey"), "Application header is sent with prefix")

if tt.ttlMin != 0 && tt.ttlMax != 0 {
ttlMS, err := strconv.Atoi(lastReq.headers.Get("Context-TTL-MS"))
if assert.NoError(t, err, "Failed to parse TTLms header: %v", lastReq.headers.Get("YARPC-TTLms")) {
gotTTL := time.Duration(ttlMS) * time.Millisecond
assert.True(t, gotTTL >= tt.ttlMin && gotTTL <= tt.ttlMax,
"Got TTL %v out of range [%v,%v]", gotTTL, tt.ttlMin, tt.ttlMax)
}
}

assert.Equal(t, "ok", got.Headers["Custom-Header"], "Header mismatch")
assert.Equal(t, tt.wantCode, got.TransportFields["statusCode"], "Status code mismatch")
assert.Equal(t, lastReq.body, r.Body, "Body mismatch")
})
}
}

type TestBody struct {
msg string
ctxOverride context.Context
hook string
method string
errMsg string
ttlMin time.Duration
ttlMax time.Duration
wantCode int
wantBody []byte // If nil, uses the request body
}

func getCommonHttpTestCases() []TestBody {
timeoutCtx, _ := context.WithTimeout(context.Background(), 3*time.Second)
immediateTimeout, _ := context.WithTimeout(context.Background(), clientTimeout)

return []TestBody{
{
msg: "ok",
ttlMin: time.Second,
ttlMax: time.Second,
wantCode: http.StatusOK,
},
{
msg: "3 second timeout",
ctxOverride: timeoutCtx,
ttlMin: 3*time.Second - 100*time.Millisecond,
ttlMax: 3 * time.Second,
wantCode: http.StatusOK,
},
{
msg: "timed out",
ctxOverride: immediateTimeout,
hook: "timeout_hook",
errMsg: context.DeadlineExceeded.Error(),
},
{
msg: "bad request response",
hook: "bad_req",
errMsg: "non-success response code: 400, body: bad request",
},
{
msg: "no content",
hook: "no_content",
wantCode: http.StatusNoContent,
wantBody: []byte{},
},
{
msg: "default method to POST",
hook: "method",
wantCode: http.StatusOK,
wantBody: []byte("POST"),
},
{
msg: "override method to GET",
method: "GET",
hook: "method",
wantCode: http.StatusOK,
wantBody: []byte("GET"),
},
}
}