diff --git a/rpc/client_test.go b/rpc/client_test.go index ff65c26eb..4500bc076 100644 --- a/rpc/client_test.go +++ b/rpc/client_test.go @@ -24,6 +24,9 @@ import ( stdjson "encoding/json" "fmt" "math/big" + "net/http" + "net/http/httptest" + "sync/atomic" "testing" "github.com/AlekSi/pointer" @@ -3123,3 +3126,76 @@ func TestClient_GetRecentPrioritizationFees(t *testing.T) { assert.Equal(t, expected, got, "both deserialized values must be equal") } + +// mockServer will check if the request header contains the specified key/value +func mockServer(t *testing.T, wantKey, wantValue string, called *atomic.Bool) *httptest.Server { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get(wantKey); got != wantValue { + t.Errorf("header %q: got %q, want %q", wantKey, got, wantValue) + } + called.Store(true) + _, _ = w.Write([]byte(`{"jsonrpc":"2.0","result":{},"id":1}`)) + }) + return httptest.NewServer(handler) +} + +func TestClient_SetHeader_AddAndModify(t *testing.T) { + var called atomic.Bool + server := mockServer(t, "Authorization", "Bearer testtoken", &called) + defer server.Close() + + cli := New(server.URL) + + cli.SetHeader("Authorization", "Bearer testtoken") + // send a request to trigger the server + _ = cli.RPCCallForInto(context.Background(), &struct{}{}, "getVersion", nil) + if !called.Load() { + t.Fatalf("server was not called") + } + + // modify header + called.Store(false) + cli.SetHeader("Authorization", "Bearer newtoken") + s2 := mockServer(t, "Authorization", "Bearer newtoken", &called) + defer s2.Close() + cli = New(s2.URL) + cli.SetHeader("Authorization", "Bearer newtoken") + _ = cli.RPCCallForInto(context.Background(), &struct{}{}, "getVersion", nil) + if !called.Load() { + t.Fatalf("server was not called after modifying header") + } +} + +func TestClient_SetHeader_RemoveHeader(t *testing.T) { + var called atomic.Bool + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("X-Remove-Me"); got != "" { + t.Errorf("header X-Remove-Me should be empty after removal, got %q", got) + } + called.Store(true) + _, _ = w.Write([]byte(`{"jsonrpc":"2.0","result":{},"id":1}`)) + })) + defer server.Close() + + cli := New(server.URL) + cli.SetHeader("X-Remove-Me", "to-be-removed") + cli.RemoveHeader("X-Remove-Me") + _ = cli.RPCCallForInto(context.Background(), &struct{}{}, "getVersion", nil) + if !called.Load() { + t.Fatalf("server was not called for RemoveHeader test") + } +} + +func TestClient_GetHeaders(t *testing.T) { + cli := New("http://localhost") + cli.SetHeader("A", "1") + cli.SetHeader("B", "2") + cli.RemoveHeader("A") + headers := cli.GetHeaders() + if _, ok := headers["A"]; ok { + t.Errorf("header A should be removed") + } + if v, ok := headers["B"]; !ok || v != "2" { + t.Errorf("header B missing or incorrect") + } +} diff --git a/rpc/jsonrpc/context_headers.go b/rpc/jsonrpc/context_headers.go new file mode 100644 index 000000000..810eeb473 --- /dev/null +++ b/rpc/jsonrpc/context_headers.go @@ -0,0 +1,56 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package jsonrpc + +import ( + "context" + "net/http" +) + +type mdHeaderKey struct{} + +// NewContextWithHeaders wraps the given context, adding HTTP headers. These headers will +// be applied by Client when making a request using the returned context. +func NewContextWithHeaders(ctx context.Context, h http.Header) context.Context { + if len(h) == 0 { + // This check ensures the header map set in context will never be nil. + return ctx + } + + var ctxh http.Header + prev, ok := ctx.Value(mdHeaderKey{}).(http.Header) + if ok { + ctxh = setHeaders(prev.Clone(), h) + } else { + ctxh = h.Clone() + } + return context.WithValue(ctx, mdHeaderKey{}, ctxh) +} + +// headersFromContext is used to extract http.Header from context. +func headersFromContext(ctx context.Context) http.Header { + source, _ := ctx.Value(mdHeaderKey{}).(http.Header) + return source +} + +// setHeaders sets all headers from src in dst. +func setHeaders(dst http.Header, src http.Header) http.Header { + for key, values := range src { + dst[http.CanonicalHeaderKey(key)] = values + } + return dst +} diff --git a/rpc/jsonrpc/jsonrpc.go b/rpc/jsonrpc/jsonrpc.go index 6a48ec95f..6fe7be67d 100644 --- a/rpc/jsonrpc/jsonrpc.go +++ b/rpc/jsonrpc/jsonrpc.go @@ -7,6 +7,7 @@ import ( stdjson "encoding/json" "errors" "fmt" + "io" "net/http" "reflect" "sync/atomic" @@ -235,9 +236,22 @@ func (e *RPCError) Error() string { // and the body could not be parsed to a valid RPCResponse object that holds a RPCError. // // Otherwise a RPCResponse object is returned with a RPCError field that is not nil. +//type HTTPError struct { +// Code int +// err error +//} + type HTTPError struct { - Code int - err error + StatusCode int + Status string + Body []byte +} + +func (err HTTPError) Error() string { + if len(err.Body) == 0 { + return err.Status + } + return fmt.Sprintf("%v: %s", err.Status, err.Body) } // HTTPClient is an abstraction for a HTTP client @@ -246,17 +260,18 @@ type HTTPClient interface { CloseIdleConnections() } -func NewHTTPError(code int, err error) *HTTPError { +func NewHTTPError(code int, message string, body []byte) *HTTPError { return &HTTPError{ - Code: code, - err: err, + StatusCode: code, + Status: message, + Body: body, } } // Error function is provided to be used as error object. -func (e *HTTPError) Error() string { - return e.err.Error() -} +//func (e *HTTPError) Error() string { +// return e.err.Error() +//} type rpcClient struct { endpoint string @@ -335,6 +350,11 @@ func NewClient(endpoint string) RPCClient { // // opts: RPCClientOpts provide custom configuration func NewClientWithOpts(endpoint string, opts *RPCClientOpts) RPCClient { + if endpoint == "" { + panic("endpoint must not be empty") + } + + // create the rpcClient and set default values rpcClient := &rpcClient{ endpoint: endpoint, httpClient: &http.Client{}, @@ -480,6 +500,8 @@ func (client *rpcClient) newRequest(ctx context.Context, req interface{}) (*http for k, v := range client.customHeaders { request.Header.Set(k, v) } + // set headers from context + setHeaders(request.Header, headersFromContext(ctx)) return request, nil } @@ -499,26 +521,28 @@ func (client *rpcClient) doCall( err := decoder.Decode(&rpcResponse) // parsing error if err != nil { + return err // if we have some http error, return it - if httpResponse.StatusCode >= 400 { - return &HTTPError{ - Code: httpResponse.StatusCode, - err: fmt.Errorf("rpc call %v() on %v status code: %v. could not decode body to rpc response: %w", RPCRequest.Method, httpRequest.URL.String(), httpResponse.StatusCode, err), - } - } - return fmt.Errorf("rpc call %v() on %v status code: %v. could not decode body to rpc response: %w", RPCRequest.Method, httpRequest.URL.String(), httpResponse.StatusCode, err) + //if httpResponse.StatusCode >= 400 { + // return &HTTPError{ + // Code: httpResponse.StatusCode, + // err: fmt.Errorf("rpc call %v() on %v status code: %v. could not decode body to rpc response: %w", RPCRequest.Method, httpRequest.URL.String(), httpResponse.StatusCode, err), + // } + //} + //return fmt.Errorf("rpc call %v() on %v status code: %v. could not decode body to rpc response: %w", RPCRequest.Method, httpRequest.URL.String(), httpResponse.StatusCode, err) } // response body empty if rpcResponse == nil { + return fmt.Errorf("rpc response is empty") // if we have some http error, return it - if httpResponse.StatusCode >= 400 { - return &HTTPError{ - Code: httpResponse.StatusCode, - err: fmt.Errorf("rpc call %v() on %v status code: %v. rpc response missing", RPCRequest.Method, httpRequest.URL.String(), httpResponse.StatusCode), - } - } - return fmt.Errorf("rpc call %v() on %v status code: %v. rpc response missing", RPCRequest.Method, httpRequest.URL.String(), httpResponse.StatusCode) + //if httpResponse.StatusCode >= 400 { + // return &HTTPError{ + // Code: httpResponse.StatusCode, + // err: fmt.Errorf("rpc call %v() on %v status code: %v. rpc response missing", RPCRequest.Method, httpRequest.URL.String(), httpResponse.StatusCode), + // } + //} + //return fmt.Errorf("rpc call %v() on %v status code: %v. rpc response missing", RPCRequest.Method, httpRequest.URL.String(), httpResponse.StatusCode) } return nil }, @@ -558,18 +582,30 @@ func (client *rpcClient) doCallWithCallbackOnHTTPResponse( } httpRequest, err := client.newRequest(ctx, RPCRequest) if err != nil { - if httpRequest != nil { - return fmt.Errorf("rpc call %v() on %v: %w", RPCRequest.Method, httpRequest.URL.String(), err) - } - return fmt.Errorf("rpc call %v(): %w", RPCRequest.Method, err) + return err } httpResponse, err := client.httpClient.Do(httpRequest) if err != nil { - return fmt.Errorf("rpc call %v() on %v: %w", RPCRequest.Method, httpRequest.URL.String(), err) + return err } defer httpResponse.Body.Close() - return callback(httpRequest, httpResponse) + // allow callback to process first (regardless of status code) + if err := callback(httpRequest, httpResponse); err != nil { + return err + } + + // if HTTP status code is not 2xx, return HTTPError additionally + if httpResponse.StatusCode < 200 || httpResponse.StatusCode >= 300 { + body, _ := io.ReadAll(io.LimitReader(httpResponse.Body, 4<<10)) // 最多4KB + return HTTPError{ + Status: httpResponse.Status, + StatusCode: httpResponse.StatusCode, + Body: body, + } + } + + return nil } func (client *rpcClient) doBatchCall(ctx context.Context, rpcRequest []*RPCRequest) ([]*RPCResponse, error) { @@ -581,10 +617,23 @@ func (client *rpcClient) doBatchCall(ctx context.Context, rpcRequest []*RPCReque return nil, fmt.Errorf("rpc batch call: %w", err) } httpResponse, err := client.httpClient.Do(httpRequest) - if err != nil { - return nil, fmt.Errorf("rpc batch call on %v: %w", httpRequest.URL.String(), err) + if httpResponse.StatusCode < 200 || httpResponse.StatusCode >= 300 { + var buf bytes.Buffer + var body []byte + if _, err := buf.ReadFrom(httpResponse.Body); err == nil { + body = buf.Bytes() + } + httpResponse.Body.Close() + return nil, HTTPError{ + Status: httpResponse.Status, + StatusCode: httpResponse.StatusCode, + Body: body, + } } - defer httpResponse.Body.Close() + //if err != nil { + // return nil, fmt.Errorf("rpc batch call on %v: %w", httpRequest.URL.String(), err) + //} + //defer httpResponse.Body.Close() var rpcResponse RPCResponses decoder := json.NewDecoder(httpResponse.Body) @@ -593,26 +642,30 @@ func (client *rpcClient) doBatchCall(ctx context.Context, rpcRequest []*RPCReque err = decoder.Decode(&rpcResponse) // parsing error if err != nil { + return nil, err // if we have some http error, return it - if httpResponse.StatusCode >= 400 { - return nil, &HTTPError{ - Code: httpResponse.StatusCode, - err: fmt.Errorf("rpc batch call on %v status code: %v. could not decode body to rpc response: %w", httpRequest.URL.String(), httpResponse.StatusCode, err), - } - } - return nil, fmt.Errorf("rpc batch call on %v status code: %v. could not decode body to rpc response: %w", httpRequest.URL.String(), httpResponse.StatusCode, err) + //if httpResponse.StatusCode >= 400 { + // return nil, &HTTPError{ + // StatusCode: httpResponse.StatusCode, + // Status: httpResponse.Status, + // Body: + // //err: fmt.Errorf("rpc batch call on %v status code: %v. could not decode body to rpc response: %w", httpRequest.URL.String(), httpResponse.StatusCode, err), + // } + //} + //return nil, fmt.Errorf("rpc batch call on %v status code: %v. could not decode body to rpc response: %w", httpRequest.URL.String(), httpResponse.StatusCode, err) } // response body empty if rpcResponse == nil || len(rpcResponse) == 0 { + return nil, fmt.Errorf("JSON-RPC response has no result") // if we have some http error, return it - if httpResponse.StatusCode >= 400 { - return nil, &HTTPError{ - Code: httpResponse.StatusCode, - err: fmt.Errorf("rpc batch call on %v status code: %v. rpc response missing", httpRequest.URL.String(), httpResponse.StatusCode), - } - } - return nil, fmt.Errorf("rpc batch call on %v status code: %v. rpc response missing", httpRequest.URL.String(), httpResponse.StatusCode) + //if httpResponse.StatusCode >= 400 { + // return nil, &HTTPError{ + // Code: httpResponse.StatusCode, + // err: fmt.Errorf("rpc batch call on %v status code: %v. rpc response missing", httpRequest.URL.String(), httpResponse.StatusCode), + // } + //} + //return nil, fmt.Errorf("rpc batch call on %v status code: %v. rpc response missing", httpRequest.URL.String(), httpResponse.StatusCode) } return rpcResponse, nil