From 0e5522b9619637c92b66aae9a259d83d869ad0f7 Mon Sep 17 00:00:00 2001 From: Sergey Bykov Date: Fri, 13 Feb 2026 22:20:34 +0300 Subject: [PATCH 1/2] Improve performance, migrate tests to testify, expand test coverage - Server hot-path optimization: fast path for single requests without wrapping into an array, replaced unicode.IsSpace with direct byte comparison in IsArray, replaced strings.SplitN with strings.Cut, optimized ConvertToObject with pre-allocated buffer - PackageInfo.String() optimization: replaced string concatenation with strings.Builder - Test framework migration from goconvey to testify: replaced github.com/smartystreets/goconvey with github.com/stretchr/testify (assert/require) across all test files - Expanded test coverage: added unit tests for JSON-RPC 2.0 spec compliance (ID formats, version validation, batch/params edge cases), response structure tests, concurrency tests, context function tests, Error/Response API tests - Added benchmarks: IsArray, ConvertToObject, Do() (single/batch/notification/error), ServeHTTP (sequential/parallel), middleware overhead - Makefile: added -race flag to test target, added bench target - README: fixed code examples (pointer receivers, removed dependency on testdata package) --- Makefile | 5 +- README.md | 14 +- bench_test.go | 167 +++++++++++++++++ go.mod | 24 +-- go.sum | 24 ++- handlers_test.go | 107 ++++------- jsonrpc2_test.go | 181 +++++++++++++++++++ parser/helpers_test.go | 8 +- parser/parser.go | 35 ++-- server.go | 69 +++---- server_test.go | 399 +++++++++++++++++++++++++++++++++++++++-- 11 files changed, 858 insertions(+), 175 deletions(-) create mode 100644 bench_test.go create mode 100644 jsonrpc2_test.go diff --git a/Makefile b/Makefile index 098675c..36bf367 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,10 @@ lint: @golangci-lint run test: - @go test -v ./... + @go test -race -v ./... + +bench: + @go test -bench=. -benchmem -run=^$$ ./... mod: @go mod tidy diff --git a/README.md b/README.md index 0a3eb91..2e2cc43 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,6 @@ import ( "os" "github.com/vmkteam/zenrpc/v2" - "github.com/vmkteam/zenrpc/v2/testdata" ) type ArithService struct{ zenrpc.Service } @@ -74,7 +73,10 @@ type Quotient struct { Quo, Rem int } -func (as ArithService) Divide(a, b int) (quo *Quotient, err error) { +// Divide divides two numbers. +// +//zenrpc:401 we do not serve 1 +func (as *ArithService) Divide(a, b int) (quo *Quotient, err error) { if b == 0 { return nil, errors.New("divide by zero") } else if b == 1 { @@ -90,8 +92,8 @@ func (as ArithService) Divide(a, b int) (quo *Quotient, err error) { // Pow returns x**y, the base-x exponential of y. If Exp is not set then default value is 2. // //zenrpc:exp=2 -func (as ArithService) Pow(base float64, exp float64) float64 { - return math.Pow(base, exp) +func (as *ArithService) Pow(base float64, exp *float64) float64 { + return math.Pow(base, *exp) } //go:generate zenrpc @@ -101,8 +103,8 @@ func main() { flag.Parse() rpc := zenrpc.NewServer(zenrpc.Options{ExposeSMD: true}) - rpc.Register("arith", testdata.ArithService{}) - rpc.Register("", testdata.ArithService{}) // public + rpc.Register("arith", ArithService{}) + rpc.Register("", ArithService{}) // public rpc.Use(zenrpc.Logger(log.New(os.Stderr, "", log.LstdFlags))) http.Handle("/", rpc) diff --git a/bench_test.go b/bench_test.go new file mode 100644 index 0000000..83926a6 --- /dev/null +++ b/bench_test.go @@ -0,0 +1,167 @@ +package zenrpc_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "testing" + + "github.com/vmkteam/zenrpc/v2" + "github.com/vmkteam/zenrpc/v2/testdata" +) + +// --- Low-level utilities --- + +func BenchmarkIsArray(b *testing.B) { + object := json.RawMessage(`{"jsonrpc":"2.0","method":"arith.pi","id":1}`) + array := json.RawMessage(`[{"jsonrpc":"2.0","method":"arith.pi","id":1}]`) + + b.Run("Object", func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(object))) + for b.Loop() { + zenrpc.IsArray(object) + } + }) + b.Run("Array", func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(array))) + for b.Loop() { + zenrpc.IsArray(array) + } + }) +} + +func BenchmarkConvertToObject(b *testing.B) { + keys := []string{"a", "b"} + params := json.RawMessage(`[3,2]`) + + b.ReportAllocs() + b.SetBytes(int64(len(params))) + for b.Loop() { + _, _ = zenrpc.ConvertToObject(keys, params) + } +} + +// --- Core processing via Do() --- + +// benchDo is a helper that runs a Do() benchmark with throughput and response size metrics. +func benchDo(b *testing.B, srv *zenrpc.Server, req []byte) { + b.Helper() + b.ReportAllocs() + b.SetBytes(int64(len(req))) + ctx := context.Background() + + for b.Loop() { + _, _ = srv.Do(ctx, req) + } +} + +func BenchmarkDo_SimpleMethod(b *testing.B) { + benchDo(b, testRPC, []byte(`{"jsonrpc":"2.0","method":"arith.pi","id":1}`)) +} + +func BenchmarkDo_MethodWithObjectParams(b *testing.B) { + benchDo(b, testRPC, []byte(`{"jsonrpc":"2.0","method":"arith.multiply","params":{"a":3,"b":2},"id":1}`)) +} + +func BenchmarkDo_MethodWithArrayParams(b *testing.B) { + benchDo(b, testRPC, []byte(`{"jsonrpc":"2.0","method":"arith.multiply","params":[3,2],"id":1}`)) +} + +func BenchmarkDo_MethodWithDefaultParam(b *testing.B) { + benchDo(b, testRPC, []byte(`{"jsonrpc":"2.0","method":"arith.pow","params":{"base":3},"id":1}`)) +} + +func BenchmarkDo_MethodNotFound(b *testing.B) { + benchDo(b, testRPC, []byte(`{"jsonrpc":"2.0","method":"arith.nonexistent","id":1}`)) +} + +func BenchmarkDo_Notification(b *testing.B) { + benchDo(b, testRPC, []byte(`{"jsonrpc":"2.0","method":"arith.multiply","params":{"a":3,"b":2}}`)) +} + +func BenchmarkDo_InvalidJSON(b *testing.B) { + benchDo(b, testRPC, []byte(`{"jsonrpc": "2.0", "method": "foobar, "params": "bar", "baz]`)) +} + +// --- Batch processing via Do() --- + +func BenchmarkDo_Batch(b *testing.B) { + for _, size := range []int{1, 2, 5} { + b.Run(fmt.Sprintf("Size%d", size), func(b *testing.B) { + req := buildBatchRequest(size) + b.ReportAllocs() + b.SetBytes(int64(len(req))) + ctx := context.Background() + + for b.Loop() { + _, _ = testRPC.Do(ctx, req) + } + }) + } +} + +// --- HTTP transport via ServeHTTP --- + +func BenchmarkServeHTTP(b *testing.B) { + ts := httptest.NewServer(http.HandlerFunc(testRPC.ServeHTTP)) + defer ts.Close() + + body := []byte(`{"jsonrpc":"2.0","method":"arith.multiply","params":{"a":3,"b":2},"id":1}`) + b.ReportAllocs() + b.SetBytes(int64(len(body))) + + for b.Loop() { + resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body)) + if err != nil { + b.Fatal(err) + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } +} + +func BenchmarkServeHTTP_Parallel(b *testing.B) { + ts := httptest.NewServer(http.HandlerFunc(testRPC.ServeHTTP)) + defer ts.Close() + + body := []byte(`{"jsonrpc":"2.0","method":"arith.multiply","params":{"a":3,"b":2},"id":1}`) + b.ReportAllocs() + b.SetBytes(int64(len(body))) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + resp, err := http.Post(ts.URL, "application/json", bytes.NewReader(body)) + if err != nil { + b.Fatal(err) + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + }) +} + +// --- Middleware overhead --- + +func BenchmarkDo_WithMiddleware(b *testing.B) { + req := []byte(`{"jsonrpc":"2.0","method":"arith.multiply","params":{"a":3,"b":2},"id":1}`) + + b.Run("NoMiddleware", func(b *testing.B) { + srv := zenrpc.NewServer(zenrpc.Options{}) + srv.Register("arith", &testdata.ArithService{}) + benchDo(b, srv, req) + }) + + b.Run("WithLogger", func(b *testing.B) { + srv := zenrpc.NewServer(zenrpc.Options{}) + srv.Register("arith", &testdata.ArithService{}) + srv.Use(zenrpc.Logger(log.New(io.Discard, "", 0))) + benchDo(b, srv, req) + }) +} diff --git a/go.mod b/go.mod index 117d11f..c71404c 100644 --- a/go.mod +++ b/go.mod @@ -5,27 +5,27 @@ go 1.24.0 require ( github.com/gorilla/websocket v1.5.3 github.com/prometheus/client_golang v1.23.2 - github.com/smartystreets/goconvey v1.8.1 - golang.org/x/text v0.29.0 - golang.org/x/tools v0.37.0 + github.com/stretchr/testify v1.11.1 + golang.org/x/text v0.34.0 + golang.org/x/tools v0.42.0 ) require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/gopherjs/gopherjs v1.17.2 // indirect - github.com/jtolds/gls v4.20.0+incompatible // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.2 // indirect - github.com/prometheus/common v0.66.1 // indirect - github.com/prometheus/procfs v0.17.0 // indirect - github.com/smarty/assertions v1.16.0 // indirect + github.com/prometheus/common v0.67.5 // indirect + github.com/prometheus/procfs v0.19.2 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect - golang.org/x/mod v0.28.0 // indirect - golang.org/x/sync v0.17.0 // indirect - golang.org/x/sys v0.36.0 // indirect - google.golang.org/protobuf v1.36.9 // indirect + golang.org/x/mod v0.33.0 // indirect + golang.org/x/sync v0.19.0 // indirect + golang.org/x/sys v0.41.0 // indirect + google.golang.org/protobuf v1.36.11 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) retract ( diff --git a/go.sum b/go.sum index 9d8f4aa..16e6ad4 100644 --- a/go.sum +++ b/go.sum @@ -7,12 +7,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= -github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= -github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -27,14 +23,14 @@ github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNw github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= +github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4= +github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw= github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0= github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= +github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= +github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= -github.com/smarty/assertions v1.16.0 h1:EvHNkdRA4QHMrn75NZSoUQ/mAUXAYWfatfB01yTCzfY= -github.com/smarty/assertions v1.16.0/go.mod h1:duaaFdCS0K9dnoM50iyek/eYINOZ64gbh1Xlf6LG7AI= -github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= -github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -43,16 +39,28 @@ go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= golang.org/x/mod v0.28.0 h1:gQBtGhjxykdjY9YhZpSlZIsbnaE2+PgjfLWUQTnoZ1U= golang.org/x/mod v0.28.0/go.mod h1:yfB/L0NOf/kmEbXjzCPOx1iK1fRutOydrCMsqRhEBxI= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/tools v0.37.0 h1:DVSRzp7FwePZW356yEAChSdNcQo6Nsp+fex1SUW09lE= golang.org/x/tools v0.37.0/go.mod h1:MBN5QPQtLMHVdvsbtarmTNukZDdgwdwlO5qGacAzF0w= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw= google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/handlers_test.go b/handlers_test.go index 2f072df..ee20645 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -15,6 +15,8 @@ import ( "github.com/vmkteam/zenrpc/v2/testdata" "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestServer_ServeHTTPWithHeaders(t *testing.T) { @@ -41,13 +43,8 @@ func TestServer_ServeHTTPWithHeaders(t *testing.T) { for _, c := range tc { res, err := http.Post(ts.URL, c.h, bytes.NewBufferString(`{"jsonrpc": "2.0", "method": "arith.pi", "id": 2 }`)) - if err != nil { - t.Fatal(err) - } - - if res.StatusCode != c.s { - t.Errorf("Input: %s\n got %d expected %d", c.h, res.StatusCode, c.s) - } + require.NoError(t, err) + assert.Equal(t, c.s, res.StatusCode, "Input: %s", c.h) res.Body.Close() } } @@ -111,19 +108,13 @@ func TestServer_ServeHTTP(t *testing.T) { for _, c := range tc { res, err := http.Post(ts.URL, "application/json", bytes.NewBufferString(c.in)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) resp, err := io.ReadAll(res.Body) res.Body.Close() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - if string(resp) != c.out { - t.Errorf("Input: %s\n got %s expected %s", c.in, resp, c.out) - } + assert.Equal(t, c.out, string(resp), "Input: %s", c.in) } } @@ -164,19 +155,13 @@ func TestServer_ServeHTTPNotifications(t *testing.T) { for _, c := range tc { res, err := http.Post(ts.URL, "application/json", bytes.NewBufferString(c.in)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) resp, err := io.ReadAll(res.Body) res.Body.Close() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - if string(resp) != c.out { - t.Errorf("Input: %s\n got %s expected %s", c.in, resp, c.out) - } + assert.Equal(t, c.out, string(resp), "Input: %s", c.in) } } @@ -222,26 +207,19 @@ func TestServer_ServeHTTPBatch(t *testing.T) { for _, c := range tc { res, err := http.Post(ts.URL, "application/json", bytes.NewBufferString(c.in)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) resp, err := io.ReadAll(res.Body) res.Body.Close() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // checking if count of responses is correct - if cnt := strings.Count(string(resp), `"jsonrpc":"2.0"`); len(c.out) != cnt { - t.Errorf("Input: %s\n got %d in batch expected %d", c.in, cnt, len(c.out)) - } + cnt := strings.Count(string(resp), `"jsonrpc":"2.0"`) + assert.Equal(t, len(c.out), cnt, "Input: %s\n batch count mismatch", c.in) // checking every response variant to be in response for _, check := range c.out { - if !strings.Contains(string(resp), check) { - t.Errorf("Input: %s\n not found %s in batch %s", c.in, check, resp) - } + assert.Contains(t, string(resp), check, "Input: %s", c.in) } } } @@ -296,19 +274,13 @@ func TestServer_ServeHTTPWithErrors(t *testing.T) { for _, c := range tc { res, err := http.Post(c.url, "application/json", bytes.NewBufferString(c.in)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) resp, err := io.ReadAll(res.Body) res.Body.Close() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - if string(resp) != c.out { - t.Errorf("Input: %s\n got %s expected %s", c.in, resp, c.out) - } + assert.Equal(t, c.out, string(resp), "Input: %s", c.in) } } @@ -361,19 +333,13 @@ func TestServer_Extensions(t *testing.T) { for _, c := range tc { res, err := http.Post(c.url, "application/json", bytes.NewBufferString(c.in)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) resp, err := io.ReadAll(res.Body) res.Body.Close() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - if string(resp) != c.out { - t.Errorf("Input: %s\n got %s expected %s", c.in, resp, c.out) - } + assert.Equal(t, c.out, string(resp), "Input: %s", c.in) } } @@ -381,14 +347,13 @@ func TestServer_ServeWS(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(testRPC.ServeWS)) defer ts.Close() - u, _ := url.Parse(ts.URL) + u, err := url.Parse(ts.URL) + require.NoError(t, err) u.Scheme = "ws" //nolint:bodyclose ws, _, err := websocket.DefaultDialer.Dial(u.String(), nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer ws.Close() var tc = []struct { @@ -421,24 +386,16 @@ func TestServer_ServeWS(t *testing.T) { } for _, c := range tc { - if err := ws.WriteMessage(websocket.TextMessage, []byte(c.in)); err != nil { - t.Fatal(err) - return - } + err = ws.WriteMessage(websocket.TextMessage, []byte(c.in)) + require.NoError(t, err) - _, resp, err := ws.ReadMessage() - if err != nil { - t.Fatal(err) - return - } + var resp []byte + _, resp, err = ws.ReadMessage() + require.NoError(t, err) - if string(resp) != c.out { - t.Errorf("Input: %s\n got %s expected %s", c.in, resp, c.out) - } + assert.Equal(t, c.out, string(resp), "Input: %s", c.in) } - if err := ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil { - t.Fatal(err) - return - } + err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + require.NoError(t, err) } diff --git a/jsonrpc2_test.go b/jsonrpc2_test.go new file mode 100644 index 0000000..872b4a1 --- /dev/null +++ b/jsonrpc2_test.go @@ -0,0 +1,181 @@ +package zenrpc_test + +import ( + "encoding/json" + "errors" + "fmt" + "testing" + + "github.com/vmkteam/zenrpc/v2" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- ErrorMsg --- + +func TestErrorMsg(t *testing.T) { + tests := []struct { + code int + want string + }{ + {zenrpc.ParseError, "Parse error"}, + {zenrpc.InvalidRequest, "Invalid Request"}, + {zenrpc.MethodNotFound, "Method not found"}, + {zenrpc.InvalidParams, "Invalid params"}, + {zenrpc.InternalError, "Internal error"}, + {zenrpc.ServerError, "Server error"}, + {-99999, ""}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("code_%d", tt.code), func(t *testing.T) { + assert.Equal(t, tt.want, zenrpc.ErrorMsg(tt.code)) + }) + } +} + +// --- NewResponseError --- + +func TestNewResponseError(t *testing.T) { + t.Run("StandardMessage", func(t *testing.T) { + r := zenrpc.NewResponseError(nil, zenrpc.ParseError, "", nil) + require.NotNil(t, r.Error) + assert.Equal(t, zenrpc.ParseError, r.Error.Code) + assert.Equal(t, "Parse error", r.Error.Message) + assert.Equal(t, zenrpc.Version, r.Version) + assert.Nil(t, r.ID) + }) + + t.Run("CustomMessage", func(t *testing.T) { + r := zenrpc.NewResponseError(nil, zenrpc.InternalError, "custom msg", nil) + require.NotNil(t, r.Error) + assert.Equal(t, "custom msg", r.Error.Message) + }) + + t.Run("WithData", func(t *testing.T) { + r := zenrpc.NewResponseError(nil, zenrpc.InternalError, "", "extra") + require.NotNil(t, r.Error) + assert.Equal(t, "extra", r.Error.Data) + }) + + t.Run("WithID", func(t *testing.T) { + id := json.RawMessage(`42`) + r := zenrpc.NewResponseError(&id, zenrpc.InternalError, "", nil) + require.NotNil(t, r.ID) + assert.Equal(t, "42", string(*r.ID)) + }) +} + +// --- NewStringError --- + +func TestNewStringError(t *testing.T) { + e := zenrpc.NewStringError(zenrpc.InvalidParams, "bad param") + assert.Equal(t, zenrpc.InvalidParams, e.Code) + assert.Equal(t, "bad param", e.Message) + assert.NoError(t, e.Err) +} + +// --- NewError --- + +func TestNewError(t *testing.T) { + inner := errors.New("something went wrong") + e := zenrpc.NewError(zenrpc.InternalError, inner) + assert.Equal(t, zenrpc.InternalError, e.Code) + assert.Equal(t, "something went wrong", e.Message) + require.ErrorIs(t, e, inner) +} + +// --- Error.Error --- + +func TestError_Error(t *testing.T) { + t.Run("FromErr", func(t *testing.T) { + e := &zenrpc.Error{Code: zenrpc.InternalError, Err: errors.New("inner")} + assert.Equal(t, "inner", e.Error()) + }) + + t.Run("FromMessage", func(t *testing.T) { + e := &zenrpc.Error{Code: zenrpc.InternalError, Message: "custom"} + assert.Equal(t, "custom", e.Error()) + }) + + t.Run("FromCode", func(t *testing.T) { + e := &zenrpc.Error{Code: zenrpc.InternalError} + assert.Equal(t, "Internal error", e.Error()) + }) +} + +// --- Error.Unwrap --- + +func TestError_Unwrap(t *testing.T) { + inner := errors.New("root cause") + e := zenrpc.NewError(zenrpc.InternalError, inner) + + require.ErrorIs(t, e, inner) + assert.Equal(t, inner, errors.Unwrap(e)) +} + +// --- Response.Set --- + +func TestResponse_Set(t *testing.T) { + t.Run("Success", func(t *testing.T) { + var r zenrpc.Response + r.Set(42) + require.NotNil(t, r.Result) + assert.Nil(t, r.Error) + assert.Equal(t, "42", string(*r.Result)) + }) + + t.Run("StructValue", func(t *testing.T) { + var r zenrpc.Response + r.Set(struct{ X int }{1}) + require.NotNil(t, r.Result) + assert.Nil(t, r.Error) + assert.JSONEq(t, `{"X":1}`, string(*r.Result)) + }) + + t.Run("GoError", func(t *testing.T) { + var r zenrpc.Response + r.Set(nil, errors.New("boom")) + assert.Nil(t, r.Result) + require.NotNil(t, r.Error) + assert.Equal(t, zenrpc.InternalError, r.Error.Code) + }) + + t.Run("ZenrpcError", func(t *testing.T) { + var r zenrpc.Response + r.Set(nil, zenrpc.NewStringError(zenrpc.InvalidParams, "bad")) + assert.Nil(t, r.Result) + require.NotNil(t, r.Error) + assert.Equal(t, zenrpc.InvalidParams, r.Error.Code) + assert.Equal(t, "bad", r.Error.Message) + }) + + t.Run("NilZenrpcError", func(t *testing.T) { + var r zenrpc.Response + r.Set(nil, (*zenrpc.Error)(nil)) + require.NotNil(t, r.Result) + assert.Nil(t, r.Error) + assert.Equal(t, "null", string(*r.Result)) + }) + + t.Run("ErrorAsValue", func(t *testing.T) { + var r zenrpc.Response + r.Set(errors.New("boom")) + assert.Nil(t, r.Result) + require.NotNil(t, r.Error) + assert.Equal(t, zenrpc.InternalError, r.Error.Code) + }) +} + +// --- Response.JSON --- + +func TestResponse_JSON(t *testing.T) { + r := zenrpc.NewResponseError(nil, zenrpc.MethodNotFound, "", nil) + b := r.JSON() + + var m map[string]json.RawMessage + require.NoError(t, json.Unmarshal(b, &m)) + assert.Contains(t, m, "error") + assert.Equal(t, `"2.0"`, string(m["jsonrpc"])) +} diff --git a/parser/helpers_test.go b/parser/helpers_test.go index 8a44e64..cf46b53 100644 --- a/parser/helpers_test.go +++ b/parser/helpers_test.go @@ -3,12 +3,10 @@ package parser import ( "testing" - . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" ) func TestLoadPackage(t *testing.T) { - Convey("Should load package with syntax and imports", t, func() { - _, err := loadPackage("../testdata/subservice/subarithservice.go") - So(err, ShouldBeNil) - }) + _, err := loadPackage("../testdata/subservice/subarithservice.go") + assert.NoError(t, err) } diff --git a/parser/parser.go b/parser/parser.go index 83d6e62..5e45c96 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -297,53 +297,52 @@ func (pi *PackageInfo) parseMethods(f *ast.File, packagePath string) error { } func (pi *PackageInfo) String() string { - result := fmt.Sprintf("Generated services for package %s:\n", pi.PackageName) + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Generated services for package %s:\n", pi.PackageName)) for _, s := range pi.Services { - result += fmt.Sprintf("- %s\n", s.Name) + sb.WriteString(fmt.Sprintf("- %s\n", s.Name)) for _, m := range s.Methods { - result += fmt.Sprintf(" • %s", m.Name) + sb.WriteString(fmt.Sprintf(" • %s", m.Name)) // args - result += "(" + sb.WriteString("(") for i, a := range m.Args { if i != 0 { - result += ", " + sb.WriteString(", ") } - - result += fmt.Sprintf("%s %s", a.Name, a.Type) + sb.WriteString(fmt.Sprintf("%s %s", a.Name, a.Type)) } - result += ") " + sb.WriteString(") ") // no return args if len(m.Returns) == 0 { - result += "\n" + sb.WriteString("\n") continue } // only one return arg without name if len(m.Returns) == 1 && m.Returns[0].Name == "" { - result += m.Returns[0].Type + "\n" + sb.WriteString(m.Returns[0].Type + "\n") continue } // return - result += "(" + sb.WriteString("(") for i, a := range m.Returns { if i != 0 { - result += ", " + sb.WriteString(", ") } - if a.Name == "" { - result += a.Type + sb.WriteString(a.Type) } else { - result += fmt.Sprintf("%s %s", a.Name, a.Type) + sb.WriteString(fmt.Sprintf("%s %s", a.Name, a.Type)) } } - result += ")\n" + sb.WriteString(")\n") } } - return result + return sb.String() } func (pi *PackageInfo) OutputFilename() string { @@ -474,7 +473,7 @@ func (m *Method) parseReturns(pi *PackageInfo, fdecl *ast.FuncDecl, serviceNames // get Service.Method list methods := func() string { - methods := []string{} + methods := make([]string, 0, len(serviceNames)) for _, s := range serviceNames { methods = append(methods, s+"."+m.Name) } diff --git a/server.go b/server.go index 50a6e7a..137adc2 100644 --- a/server.go +++ b/server.go @@ -8,7 +8,6 @@ import ( "net/http" "strings" "sync" - "unicode" "github.com/vmkteam/zenrpc/v2/smd" @@ -134,16 +133,27 @@ func (s *Server) SetLogger(printer Printer) { // process processes JSON-RPC 2.0 message, invokes correct method for namespace and returns JSON-RPC 2.0 Response. func (s *Server) process(ctx context.Context, message json.RawMessage) interface{} { - var requests []Request // parsing batch requests batch := IsArray(message) - // making not batch request looks like batch to simplify further code + // fast path: single request — parse directly without wrapping in array if !batch { - message = append(append([]byte{'['}, message...), ']') + var req Request + if err := json.Unmarshal(message, &req); err != nil { + return NewResponseError(nil, ParseError, "", nil) + } + + if req.ID == nil { + // notification — fire and forget + s.processRequest(ctx, req) + return nil + } + + return s.processRequest(ctx, req) } - // unmarshal request(s) + // batch path + var requests []Request if err := json.Unmarshal(message, &requests); err != nil { return NewResponseError(nil, ParseError, "", nil) } @@ -156,14 +166,7 @@ func (s *Server) process(ctx context.Context, message json.RawMessage) interface } // set batch methods in request - if batch { - ctx = newBatchMethodsContext(ctx, methodsFromRequests(requests)) - } - - // process single request: if request single and not notification - just run it and return result - if !batch && requests[0].ID != nil { - return s.processRequest(ctx, requests[0]) - } + ctx = newBatchMethodsContext(ctx, methodsFromRequests(requests)) // process batch requests if res := s.processBatch(ctx, requests); len(res) > 0 { @@ -223,10 +226,9 @@ func (s *Server) processRequest(ctx context.Context, req Request) Response { // convert method to lower and find namespace lowerM := strings.ToLower(req.Method) - sp := strings.SplitN(lowerM, ".", 2) - namespace, method := "", lowerM - if len(sp) == 2 { - namespace, method = sp[0], sp[1] + namespace, method, found := strings.Cut(lowerM, ".") + if !found { + namespace, method = "", namespace } if _, ok := s.services[namespace]; !ok { @@ -306,7 +308,7 @@ func (s *Server) SMD() smd.Schema { // IsArray checks json message if it arrays or object. func IsArray(message json.RawMessage) bool { for _, b := range message { - if unicode.IsSpace(rune(b)) { + if b == ' ' || b == '\t' || b == '\n' || b == '\r' { continue } @@ -333,33 +335,20 @@ func ConvertToObject(keys []string, params json.RawMessage) (json.RawMessage, er return nil, fmt.Errorf("invalid params number, expected %d, got %d", paramCount, len(rawParams)) } - buf := bytes.Buffer{} - if _, err := buf.WriteString(`{`); err != nil { - return nil, err - } + buf := bytes.NewBuffer(make([]byte, 0, len(params)+64)) + buf.WriteByte('{') for i, p := range rawParams { - // Writing key - if _, err := buf.WriteString(`"` + keys[i] + `":`); err != nil { - return nil, err - } - - // Writing value - if _, err := buf.Write(p); err != nil { - return nil, err - } - - // Writing trailing comma if not last argument - if i != rawParamCount-1 { - if _, err := buf.WriteString(`,`); err != nil { - return nil, err - } + if i > 0 { + buf.WriteByte(',') } + buf.WriteByte('"') + buf.WriteString(keys[i]) + buf.WriteString(`":`) + buf.Write(p) } - if _, err := buf.WriteString(`}`); err != nil { - return nil, err - } + buf.WriteByte('}') return buf.Bytes(), nil } diff --git a/server_test.go b/server_test.go index 14c3478..44de8f6 100644 --- a/server_test.go +++ b/server_test.go @@ -2,15 +2,24 @@ package zenrpc_test import ( "bytes" + "context" "encoding/json" + "fmt" "log" + "net/http" + "net/http/httptest" "os" "testing" "github.com/vmkteam/zenrpc/v2" "github.com/vmkteam/zenrpc/v2/testdata" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +// --- Setup & helpers --- + var ( testRPC = zenrpc.NewServer(zenrpc.Options{BatchMaxLen: 5, AllowCORS: true}) logRequests = false @@ -28,13 +37,54 @@ func TestMain(m *testing.M) { os.Exit(result) } +// doRPC calls srv.Do and unmarshals the response into zenrpc.Response. +func doRPC(t *testing.T, srv *zenrpc.Server, req string) zenrpc.Response { + t.Helper() + + b, err := srv.Do(context.Background(), []byte(req)) + require.NoError(t, err, "Do() error") + + var r zenrpc.Response + require.NoError(t, json.Unmarshal(b, &r), "Unmarshal response, raw: %s", b) + + return r +} + +// doRPCRaw calls srv.Do and unmarshals the response into a raw JSON map, +// useful for testing presence/absence of specific JSON keys. +func doRPCRaw(t *testing.T, srv *zenrpc.Server, req string) map[string]json.RawMessage { + t.Helper() + + b, err := srv.Do(context.Background(), []byte(req)) + require.NoError(t, err, "Do() error") + + var m map[string]json.RawMessage + require.NoError(t, json.Unmarshal(b, &m), "Unmarshal raw response, raw: %s", b) + + return m +} + +// buildBatchRequest builds a JSON-RPC 2.0 batch request with n multiply calls. +func buildBatchRequest(n int) []byte { + var buf bytes.Buffer + buf.WriteByte('[') + for i := range n { + if i > 0 { + buf.WriteByte(',') + } + fmt.Fprintf(&buf, `{"jsonrpc":"2.0","method":"arith.multiply","params":{"a":%d,"b":2},"id":%d}`, i+1, i) + } + buf.WriteByte(']') + return buf.Bytes() +} + +// --- SMD --- + func TestServer_SMD(t *testing.T) { r := testRPC.SMD() - if b, err := json.Marshal(r); err != nil { - t.Fatal(err) - } else if !bytes.Contains(b, []byte("default")) { - t.Error(string(b)) - } + b, err := json.Marshal(r) + require.NoError(t, err) + assert.Contains(t, string(b), "default") } func TestServer_SmdGenerate(t *testing.T) { @@ -44,14 +94,343 @@ func TestServer_SmdGenerate(t *testing.T) { rpc.Register("printer", testdata.PrintService{}) rpc.Register("", testdata.ArithService{}) - b, _ := json.MarshalIndent(rpc.SMD(), "", " ") + b, err := json.MarshalIndent(rpc.SMD(), "", " ") + require.NoError(t, err) testData, err := os.ReadFile("./testdata/testdata/arithsrv-smd.json") - if err != nil { - t.Fatalf("open test data file") + require.NoError(t, err, "open test data file") + + assert.Equal(t, string(testData), string(b), "bad zenrpc output") +} + +// --- Server unit tests --- + +func TestNewServer_Defaults(t *testing.T) { + srv := zenrpc.NewServer(zenrpc.Options{}) + srv.Register("arith", &testdata.ArithService{}) + + t.Run("BatchMaxLen10OK", func(t *testing.T) { + resp, err := srv.Do(context.Background(), buildBatchRequest(10)) + require.NoError(t, err) + + var results []zenrpc.Response + require.NoError(t, json.Unmarshal(resp, &results)) + assert.Len(t, results, 10) + }) + + t.Run("BatchMaxLen11Fail", func(t *testing.T) { + r := doRPC(t, srv, string(buildBatchRequest(11))) + require.NotNil(t, r.Error) + assert.Equal(t, zenrpc.InvalidRequest, r.Error.Code) + }) + + t.Run("DefaultTargetURL", func(t *testing.T) { + srv := zenrpc.NewServer(zenrpc.Options{}) + assert.Equal(t, "/", srv.SMD().Target) + }) + + t.Run("CustomTargetURL", func(t *testing.T) { + srv := zenrpc.NewServer(zenrpc.Options{TargetURL: "/custom"}) + assert.Equal(t, "/custom", srv.SMD().Target) + }) +} + +func TestServer_RegisterAll(t *testing.T) { + srv := zenrpc.NewServer(zenrpc.Options{}) + srv.RegisterAll(map[string]zenrpc.Invoker{ + "arith": &testdata.ArithService{}, + }) + + r := doRPC(t, srv, `{"jsonrpc":"2.0","method":"arith.pi","id":1}`) + assert.Nil(t, r.Error) + assert.NotNil(t, r.Result) +} + +func TestServer_RegisterOverwrite(t *testing.T) { + srv := zenrpc.NewServer(zenrpc.Options{}) + srv.Register("svc", &testdata.ArithService{}) + + // Verify ArithService method works before overwrite. + r := doRPC(t, srv, `{"jsonrpc":"2.0","method":"svc.pi","id":1}`) + assert.Nil(t, r.Error) + assert.NotNil(t, r.Result) + + // Overwrite with PrintService — ArithService methods should no longer be available. + srv.Register("svc", &testdata.PrintService{}) + + r = doRPC(t, srv, `{"jsonrpc":"2.0","method":"svc.pi","id":2}`) + require.NotNil(t, r.Error, "old service method should not be available after overwrite") + assert.Equal(t, zenrpc.MethodNotFound, r.Error.Code) +} + +func TestIsArray_Whitespace(t *testing.T) { + tests := []struct { + name string + msg string + want bool + }{ + {"array", `[1]`, true}, + {"object", `{"a":1}`, false}, + {"space_before_array", " [1]", true}, + {"tab_before_array", "\t[1]", true}, + {"newline_before_array", "\n[1]", true}, + {"crlf_before_array", "\r\n[1]", true}, + {"space_before_object", " {\"a\":1}", false}, + {"tab_before_object", "\t{\"a\":1}", false}, + {"newline_before_object", "\n{\"a\":1}", false}, + {"empty", "", false}, } - if !bytes.Equal(b, testData) { - t.Fatalf("bad zenrpc output") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, zenrpc.IsArray(json.RawMessage(tt.msg))) + }) } } + +func TestConvertToObject_EdgeCases(t *testing.T) { + t.Run("TooManyParams", func(t *testing.T) { + keys := []string{"a"} + params := json.RawMessage(`[1,2]`) + _, err := zenrpc.ConvertToObject(keys, params) + assert.Error(t, err) + }) + + t.Run("FewerParams", func(t *testing.T) { + keys := []string{"a", "b", "c"} + params := json.RawMessage(`[1]`) + result, err := zenrpc.ConvertToObject(keys, params) + require.NoError(t, err) + assert.JSONEq(t, `{"a":1}`, string(result)) + }) + + t.Run("InvalidJSON", func(t *testing.T) { + keys := []string{"a"} + params := json.RawMessage(`not json`) + _, err := zenrpc.ConvertToObject(keys, params) + assert.Error(t, err) + }) + + t.Run("EmptyArray", func(t *testing.T) { + keys := []string{"a", "b"} + params := json.RawMessage(`[]`) + result, err := zenrpc.ConvertToObject(keys, params) + require.NoError(t, err) + assert.JSONEq(t, `{}`, string(result)) + }) +} + +func TestContextFunctions(t *testing.T) { + t.Run("EmptyContext", func(t *testing.T) { + ctx := context.Background() + + _, ok := zenrpc.RequestFromContext(ctx) + assert.False(t, ok) + + assert.Empty(t, zenrpc.NamespaceFromContext(ctx)) + assert.Nil(t, zenrpc.IDFromContext(ctx)) + assert.Nil(t, zenrpc.BatchMethodsFromContext(ctx)) + }) + + t.Run("RequestRoundTrip", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/", nil) + ctx := zenrpc.NewRequestContext(context.Background(), req) + + got, ok := zenrpc.RequestFromContext(ctx) + require.True(t, ok) + assert.Equal(t, req, got) + }) +} + +// --- JSON-RPC 2.0 spec compliance --- + +func TestDo_IDFormats(t *testing.T) { + tests := []struct { + name string + id string + }{ + {"string", `"42"`}, + {"negative", "-1"}, + {"fractional", "1.5"}, + {"large", "999999"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := fmt.Sprintf(`{"jsonrpc":"2.0","method":"arith.pi","id":%s}`, tt.id) + r := doRPC(t, testRPC, req) + assert.JSONEq(t, tt.id, string(*r.ID)) + assert.NotNil(t, r.Result, "result should be present") + }) + } +} + +// TestDo_NotificationNullID verifies that "id":null is treated as a notification +// (no response body). Go json.Unmarshal sets *json.RawMessage to nil for JSON null. +func TestDo_NotificationNullID(t *testing.T) { + resp, err := testRPC.Do(context.Background(), []byte(`{"jsonrpc":"2.0","method":"arith.pi","id":null}`)) + require.NoError(t, err) + assert.Equal(t, "null", string(resp)) +} + +func TestDo_Version(t *testing.T) { + tests := []struct { + name string + req string + }{ + {"version_1.0", `{"jsonrpc":"1.0","method":"arith.pi","id":1}`}, + {"version_absent", `{"method":"arith.pi","id":1}`}, + {"version_empty", `{"jsonrpc":"","method":"arith.pi","id":1}`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := doRPC(t, testRPC, tt.req) + require.NotNil(t, r.Error, "expected error response") + assert.Equal(t, zenrpc.InvalidRequest, r.Error.Code) + }) + } +} + +func TestDo_RpcDotReserved(t *testing.T) { + r := doRPC(t, testRPC, `{"jsonrpc":"2.0","method":"rpc.discover","id":1}`) + require.NotNil(t, r.Error) + assert.Equal(t, zenrpc.MethodNotFound, r.Error.Code) +} + +func TestDo_ParamsEdgeCases(t *testing.T) { + tests := []struct { + name string + req string + }{ + {"params_null", `{"jsonrpc":"2.0","method":"arith.pi","params":null,"id":1}`}, + {"params_empty_object", `{"jsonrpc":"2.0","method":"arith.pi","params":{},"id":1}`}, + {"params_empty_array", `{"jsonrpc":"2.0","method":"arith.pi","params":[],"id":1}`}, + {"params_omitted", `{"jsonrpc":"2.0","method":"arith.pi","id":1}`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := doRPC(t, testRPC, tt.req) + assert.Nil(t, r.Error, "expected success") + assert.NotNil(t, r.Result) + }) + } +} + +func TestDo_BatchEdgeCases(t *testing.T) { + t.Run("empty_array", func(t *testing.T) { + // [] → single InvalidRequest (spec: "not an Array with at least one value") + r := doRPC(t, testRPC, `[]`) + require.NotNil(t, r.Error) + assert.Equal(t, zenrpc.InvalidRequest, r.Error.Code) + }) + + t.Run("non_object_element", func(t *testing.T) { + // [1] → ParseError; Go cannot unmarshal the integer 1 into a Request struct, + // so the entire batch fails to parse (deviation from spec which expects per-element InvalidRequest). + r := doRPC(t, testRPC, `[1]`) + require.NotNil(t, r.Error) + assert.Equal(t, zenrpc.ParseError, r.Error.Code) + }) + + t.Run("single_element_batch", func(t *testing.T) { + // Batch with one element must return an array with one response, not a bare object. + resp, err := testRPC.Do(context.Background(), []byte(`[{"jsonrpc":"2.0","method":"arith.pi","id":1}]`)) + require.NoError(t, err) + + var results []zenrpc.Response + require.NoError(t, json.Unmarshal(resp, &results), "response must be a JSON array") + require.Len(t, results, 1) + assert.Nil(t, results[0].Error) + assert.NotNil(t, results[0].Result) + }) +} + +func TestDo_ResponseStructure(t *testing.T) { + t.Run("SuccessHasResultNoError", func(t *testing.T) { + m := doRPCRaw(t, testRPC, `{"jsonrpc":"2.0","method":"arith.pi","id":1}`) + assert.Contains(t, m, "result") + assert.NotContains(t, m, "error") + assert.Equal(t, `"2.0"`, string(m["jsonrpc"])) + }) + + t.Run("ErrorHasErrorNoResult", func(t *testing.T) { + m := doRPCRaw(t, testRPC, `{"jsonrpc":"2.0","method":"arith.checkerror","params":{"isErr":true},"id":1}`) + assert.Contains(t, m, "error") + assert.NotContains(t, m, "result") + assert.Equal(t, `"2.0"`, string(m["jsonrpc"])) + }) + + t.Run("NullResultPresent", func(t *testing.T) { + // checkerror with isErr=false returns nil error → result:null, no error field. + m := doRPCRaw(t, testRPC, `{"jsonrpc":"2.0","method":"arith.checkerror","params":{"isErr":false},"id":1}`) + assert.Contains(t, m, "result", `"result" must be present even when null`) + assert.NotContains(t, m, "error") + assert.Equal(t, "null", string(m["result"])) + }) + + t.Run("AlwaysHasVersion", func(t *testing.T) { + // Even error responses must have jsonrpc:"2.0". + m := doRPCRaw(t, testRPC, `{"jsonrpc":"2.0","method":"nonexistent.method","id":1}`) + assert.Equal(t, `"2.0"`, string(m["jsonrpc"])) + }) + + t.Run("ParseErrorHasNullID", func(t *testing.T) { + m := doRPCRaw(t, testRPC, `{invalid json`) + assert.Contains(t, m, "id") + assert.Equal(t, "null", string(m["id"])) + }) +} + +// --- Concurrency --- + +func TestDo_Parallel(t *testing.T) { + t.Parallel() + const n = 100 + + for i := range n { + a := i + 1 + t.Run(fmt.Sprintf("g%d", i), func(t *testing.T) { + t.Parallel() + + req := fmt.Sprintf(`{"jsonrpc":"2.0","method":"arith.multiply","params":{"a":%d,"b":2},"id":%d}`, a, a) + r := doRPC(t, testRPC, req) + + var result int + require.NoError(t, json.Unmarshal(*r.Result, &result)) + assert.Equal(t, a*2, result) + }) + } +} + +func TestDo_BatchParallel(t *testing.T) { + t.Parallel() + const n = 50 + + for i := range n { + t.Run(fmt.Sprintf("g%d", i), func(t *testing.T) { + t.Parallel() + + resp, err := testRPC.Do(context.Background(), buildBatchRequest(2)) + require.NoError(t, err) + + var results []zenrpc.Response + require.NoError(t, json.Unmarshal(resp, &results)) + assert.Len(t, results, 2) + }) + } +} + +func TestDo_ContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + // Server must handle a cancelled context without panic and return valid JSON-RPC. + resp, err := testRPC.Do(ctx, []byte(`{"jsonrpc":"2.0","method":"arith.pi","id":1}`)) + require.NoError(t, err) + + var r zenrpc.Response + require.NoError(t, json.Unmarshal(resp, &r)) + assert.Equal(t, zenrpc.Version, r.Version) +} From 7497f31158b4c821edba8f189e6ea5a5c790c6a5 Mon Sep 17 00:00:00 2001 From: Sergey Bykov Date: Sat, 14 Feb 2026 11:40:20 +0300 Subject: [PATCH 2/2] Add Fuzzing test --- .github/workflows/go.yml | 4 +- Makefile | 6 ++ fuzz_test.go | 217 +++++++++++++++++++++++++++++++++++++++ server.go | 4 +- 4 files changed, 228 insertions(+), 3 deletions(-) create mode 100644 fuzz_test.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index f2be515..f9ac09a 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -28,4 +28,6 @@ jobs: run: go test -race -coverprofile=coverage.txt -covermode=atomic - name: Upload coverage to Codecov - run: bash <(curl -s https://codecov.io/bash) + uses: codecov/codecov-action@v5 + with: + files: ./coverage.txt diff --git a/Makefile b/Makefile index 36bf367..ab95166 100644 --- a/Makefile +++ b/Makefile @@ -14,6 +14,12 @@ test: bench: @go test -bench=. -benchmem -run=^$$ ./... +fuzz: + @go test -fuzz FuzzServerDo -fuzztime 30s + @go test -fuzz FuzzIsArray -fuzztime 30s + @go test -fuzz FuzzConvertToObject -fuzztime 30s + @go test -fuzz FuzzServeHTTP -fuzztime 30s + mod: @go mod tidy diff --git a/fuzz_test.go b/fuzz_test.go new file mode 100644 index 0000000..dcc7b3d --- /dev/null +++ b/fuzz_test.go @@ -0,0 +1,217 @@ +package zenrpc_test + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/vmkteam/zenrpc/v2" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// FuzzServerDo feeds arbitrary bytes into Server.Do(). +// Invariants: +// - must never panic +// - must always return valid JSON (or nil for notifications) +// - response must always contain "jsonrpc":"2.0" +func FuzzServerDo(f *testing.F) { + // seed corpus: valid requests + f.Add([]byte(`{"jsonrpc":"2.0","method":"arith.pi","id":1}`)) + f.Add([]byte(`{"jsonrpc":"2.0","method":"arith.multiply","params":{"a":3,"b":2},"id":0}`)) + f.Add([]byte(`{"jsonrpc":"2.0","method":"arith.multiply","params":[3,2],"id":0}`)) + f.Add([]byte(`{"jsonrpc":"2.0","method":"arith.divide","params":{"a":1,"b":0},"id":1}`)) + f.Add([]byte(`{"jsonrpc":"2.0","method":"arith.pow","params":{"base":3},"id":0}`)) + + // seed corpus: notifications (no id) + f.Add([]byte(`{"jsonrpc":"2.0","method":"arith.multiply","params":{"a":3,"b":2}}`)) + f.Add([]byte(`{"jsonrpc":"2.0","method":"arith.pi"}`)) + + // seed corpus: batch requests + f.Add([]byte(`[{"jsonrpc":"2.0","method":"arith.pi","id":1}]`)) + f.Add([]byte(`[{"jsonrpc":"2.0","method":"arith.pi","id":1},{"jsonrpc":"2.0","method":"arith.multiply","params":{"a":1,"b":2},"id":2}]`)) + f.Add([]byte(`[]`)) + + // seed corpus: malformed JSON + f.Add([]byte(`{"jsonrpc": "2.0", "method": "foobar, "params": "bar", "baz]`)) + f.Add([]byte(`{invalid json`)) + f.Add([]byte(`not json at all`)) + f.Add([]byte(``)) + f.Add([]byte(`null`)) + f.Add([]byte(`123`)) + f.Add([]byte(`"string"`)) + f.Add([]byte(`true`)) + + // seed corpus: invalid requests + f.Add([]byte(`{"jsonrpc":"1.0","method":"arith.pi","id":1}`)) + f.Add([]byte(`{"jsonrpc":"2.0","params":{"a":1},"id":1}`)) + f.Add([]byte(`{"jsonrpc":"2.0","method":"nonexistent","id":1}`)) + f.Add([]byte(`{"jsonrpc":"2.0","method":"rpc.discover","id":1}`)) + f.Add([]byte(`{"jsonrpc":"2.0","method":"arith.pow","params":{"base":"3"},"id":0}`)) + + // seed corpus: edge-case IDs + f.Add([]byte(`{"jsonrpc":"2.0","method":"arith.pi","id":"string-id"}`)) + f.Add([]byte(`{"jsonrpc":"2.0","method":"arith.pi","id":-1}`)) + f.Add([]byte(`{"jsonrpc":"2.0","method":"arith.pi","id":1.5}`)) + f.Add([]byte(`{"jsonrpc":"2.0","method":"arith.pi","id":null}`)) + + // seed corpus: batch with non-object elements + f.Add([]byte(`[1]`)) + f.Add([]byte(`[1,2,3]`)) + f.Add([]byte(`[null]`)) + + // seed corpus: deeply nested / large + f.Add([]byte(`{"jsonrpc":"2.0","method":"arith.multiply","params":{"a":999999999999999999,"b":999999999999999999},"id":1}`)) + + f.Fuzz(func(t *testing.T, data []byte) { + resp, err := testRPC.Do(context.Background(), data) + require.NoError(t, err, "Do() returned error") + + // nil or "null" response is valid (notification or all-notification batch) + if resp == nil || string(resp) == "null" { + return + } + + // response must be valid JSON + require.True(t, json.Valid(resp), "Do() returned invalid JSON: %s", resp) + + // response must contain jsonrpc version + // could be a single response or a batch array + if resp[0] == '[' { + var batch []json.RawMessage + require.NoError(t, json.Unmarshal(resp, &batch), "batch response is not a JSON array") + for i, item := range batch { + assertValidResponse(t, i, item) + } + } else { + assertValidResponse(t, 0, resp) + } + }) +} + +// assertValidResponse checks that a single JSON-RPC response has required fields. +func assertValidResponse(t *testing.T, idx int, data []byte) { + t.Helper() + + var m map[string]json.RawMessage + require.NoError(t, json.Unmarshal(data, &m), "response[%d] is not a JSON object", idx) + require.Contains(t, m, "jsonrpc", "response[%d] missing 'jsonrpc' field", idx) + assert.Equal(t, `"2.0"`, string(m["jsonrpc"]), "response[%d] jsonrpc version", idx) + + _, hasResult := m["result"] + _, hasError := m["error"] + + // must have exactly one of result or error + assert.False(t, hasResult && hasError, "response[%d] has both 'result' and 'error'", idx) + assert.True(t, hasResult || hasError, "response[%d] has neither 'result' nor 'error'", idx) +} + +// FuzzIsArray checks that IsArray never panics on arbitrary input. +func FuzzIsArray(f *testing.F) { + f.Add([]byte(`[1]`)) + f.Add([]byte(`{"a":1}`)) + f.Add([]byte(` [1]`)) + f.Add([]byte("\t[1]")) + f.Add([]byte("\n[1]")) + f.Add([]byte("\r\n[1]")) + f.Add([]byte(` {"a":1}`)) + f.Add([]byte(``)) + f.Add([]byte(`null`)) + f.Add([]byte(`123`)) + f.Add([]byte{0x00}) + f.Add([]byte{0xff, 0xfe}) + + f.Fuzz(func(t *testing.T, data []byte) { + // must not panic; result is either true or false + result := zenrpc.IsArray(json.RawMessage(data)) + assert.IsType(t, true, result) + }) +} + +// FuzzConvertToObject checks that ConvertToObject never panics +// and returns valid JSON on success. +func FuzzConvertToObject(f *testing.F) { + f.Add([]byte(`[1,2]`)) + f.Add([]byte(`[3]`)) + f.Add([]byte(`[]`)) + f.Add([]byte(`[1,2,3,4,5]`)) + f.Add([]byte(`["a","b"]`)) + f.Add([]byte(`[null,null]`)) + f.Add([]byte(`[true,false]`)) + f.Add([]byte(`[1.5,2.5]`)) + f.Add([]byte(`[{"nested":1},[1,2]]`)) + f.Add([]byte(`not json`)) + f.Add([]byte(``)) + f.Add([]byte(`null`)) + f.Add([]byte(`{}`)) + + keys := []string{"a", "b", "c"} + + f.Fuzz(func(t *testing.T, data []byte) { + result, err := zenrpc.ConvertToObject(keys, json.RawMessage(data)) + if err != nil { + return // errors are fine, panics are not + } + + // on success, result must be valid JSON + require.True(t, json.Valid(result), "ConvertToObject returned invalid JSON: %s", result) + + // result must be a JSON object + var m map[string]json.RawMessage + require.NoError(t, json.Unmarshal(result, &m), "ConvertToObject result is not an object") + }) +} + +// FuzzServeHTTP tests the full HTTP handler with arbitrary request bodies. +// Invariants: +// - must never panic +// - must return valid HTTP status code +// - non-empty response body must be valid JSON +func FuzzServeHTTP(f *testing.F) { + // seed corpus: valid requests + f.Add([]byte(`{"jsonrpc":"2.0","method":"arith.pi","id":1}`)) + f.Add([]byte(`{"jsonrpc":"2.0","method":"arith.multiply","params":{"a":3,"b":2},"id":0}`)) + f.Add([]byte(`[{"jsonrpc":"2.0","method":"arith.pi","id":1}]`)) + + // seed corpus: malformed + f.Add([]byte(`{invalid}`)) + f.Add([]byte(``)) + f.Add([]byte(`null`)) + f.Add([]byte(`[]`)) + + // seed corpus: edge cases + f.Add([]byte(`{"jsonrpc":"2.0","method":"arith.multiply","params":{"a":3,"b":2}}`)) // notification + f.Add([]byte(`{"jsonrpc":"1.0","method":"arith.pi","id":1}`)) // wrong version + f.Add([]byte(`{"jsonrpc":"2.0","method":"","id":1}`)) // empty method + + f.Fuzz(func(t *testing.T, body []byte) { + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + testRPC.ServeHTTP(rec, req) + + resp := rec.Result() + defer resp.Body.Close() + + // status must be valid + assert.True(t, resp.StatusCode >= 100 && resp.StatusCode < 600, + "invalid HTTP status: %d", resp.StatusCode) + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err, "failed to read response body") + + // empty body is valid (notifications) + if len(respBody) == 0 { + return + } + + // non-empty body must be valid JSON + assert.True(t, json.Valid(respBody), "ServeHTTP returned invalid JSON: %s", respBody) + }) +} diff --git a/server.go b/server.go index 137adc2..15034ae 100644 --- a/server.go +++ b/server.go @@ -132,7 +132,7 @@ func (s *Server) SetLogger(printer Printer) { } // process processes JSON-RPC 2.0 message, invokes correct method for namespace and returns JSON-RPC 2.0 Response. -func (s *Server) process(ctx context.Context, message json.RawMessage) interface{} { +func (s *Server) process(ctx context.Context, message json.RawMessage) any { // parsing batch requests batch := IsArray(message) @@ -145,7 +145,7 @@ func (s *Server) process(ctx context.Context, message json.RawMessage) interface if req.ID == nil { // notification — fire and forget - s.processRequest(ctx, req) + go s.processRequest(ctx, req) return nil }