From 1a42aed772ae85bc868fc867238d06eccf1c6e69 Mon Sep 17 00:00:00 2001 From: Feny Mehta Date: Fri, 13 Mar 2026 12:50:35 +0530 Subject: [PATCH] Unit Test Cases addition Signed-off-by: Feny Mehta --- go.mod | 6 ++ go.sum | 16 ++- pkg/metrics/metrics_test.go | 61 +++++++++++ pkg/middleware/logging_test.go | 173 +++++++++++++++++++++++++++++++ pkg/middleware/metrics_test.go | 183 +++++++++++++++++++++++++++++++++ pkg/version/version_test.go | 45 ++++++++ 6 files changed, 482 insertions(+), 2 deletions(-) create mode 100644 pkg/metrics/metrics_test.go create mode 100644 pkg/middleware/logging_test.go create mode 100644 pkg/middleware/metrics_test.go create mode 100644 pkg/version/version_test.go diff --git a/go.mod b/go.mod index 1cddaf2..f1a34e3 100644 --- a/go.mod +++ b/go.mod @@ -7,13 +7,18 @@ toolchain go1.24.13 require ( github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/prometheus/client_golang v1.22.0 + github.com/stretchr/testify v1.11.1 ) require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/jsonschema-go v0.3.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.62.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect @@ -21,4 +26,5 @@ require ( golang.org/x/oauth2 v0.30.0 // indirect golang.org/x/sys v0.30.0 // indirect google.golang.org/protobuf v1.36.5 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f107951..33fa335 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,7 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= @@ -10,6 +11,12 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/modelcontextprotocol/go-sdk v1.2.0 h1:Y23co09300CEk8iZ/tMxIX1dVmKZkzoSBZOpJwUnc/s= github.com/modelcontextprotocol/go-sdk v1.2.0/go.mod h1:6fM3LCm3yV7pAs8isnKLn07oKtB0MP9LHd3DfAcKw10= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= @@ -24,8 +31,10 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= @@ -36,5 +45,8 @@ golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/metrics/metrics_test.go b/pkg/metrics/metrics_test.go new file mode 100644 index 0000000..ff22ad5 --- /dev/null +++ b/pkg/metrics/metrics_test.go @@ -0,0 +1,61 @@ +package metrics_test + +import ( + "testing" + + "github.com/codeready-toolchain/mcp-common/pkg/metrics" + + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMCPCallsTotal(t *testing.T) { + t.Run("counter is registered and can be incremented", func(t *testing.T) { + // given + require.NotNil(t, metrics.MCPCallsTotal) + counter := metrics.MCPCallsTotal.WithLabelValues("test-server", "tools/call", "my-tool", "true") + before := testutil.ToFloat64(counter) + + // when + counter.Inc() + + // then + after := testutil.ToFloat64(counter) + assert.InDelta(t, float64(1), after-before, 0) + }) + + t.Run("counter tracks different label combinations independently", func(t *testing.T) { + // given + require.NotNil(t, metrics.MCPCallsTotal) + successCounter := metrics.MCPCallsTotal.WithLabelValues("test-server", "tools/call", "tool-a", "true") + failureCounter := metrics.MCPCallsTotal.WithLabelValues("test-server", "tools/call", "tool-a", "false") + + beforeSuccess := testutil.ToFloat64(successCounter) + beforeFailure := testutil.ToFloat64(failureCounter) + + // when + successCounter.Inc() + successCounter.Inc() + failureCounter.Inc() + + // then + assert.InDelta(t, float64(2), testutil.ToFloat64(successCounter)-beforeSuccess, 0) + assert.InDelta(t, float64(1), testutil.ToFloat64(failureCounter)-beforeFailure, 0) + }) +} + +func TestMCPCallDuration(t *testing.T) { + t.Run("histogram is registered and can observe values", func(t *testing.T) { + // given + require.NotNil(t, metrics.MCPCallDuration) + histogram := metrics.MCPCallDuration.WithLabelValues("test-server", "tools/call", "my-tool", "true") + + // when + histogram.Observe(0.5) + histogram.Observe(1.0) + + // then (no panic means histogram is working) + assert.NotNil(t, histogram) + }) +} diff --git a/pkg/middleware/logging_test.go b/pkg/middleware/logging_test.go new file mode 100644 index 0000000..957de27 --- /dev/null +++ b/pkg/middleware/logging_test.go @@ -0,0 +1,173 @@ +package middleware_test + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "log/slog" + "testing" + + "github.com/codeready-toolchain/mcp-common/pkg/middleware" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewLoggingMiddleware(t *testing.T) { + t.Run("logs CallToolRequest with tool name on success", func(t *testing.T) { + // given + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) + + nextCalled := false + expectedResult := &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "result"}}, + } + next := func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + nextCalled = true + return expectedResult, nil + } + + req := &mcp.ServerRequest[*mcp.CallToolParamsRaw]{ + Session: &mcp.ServerSession{}, + Params: &mcp.CallToolParamsRaw{ + Name: "my-tool", + Arguments: json.RawMessage(`{"key":"value"}`), + }, + } + + mw := middleware.NewLoggingMiddleware(logger) + handler := mw(next) + + // when + result, err := handler(context.Background(), "tools/call", req) + + // then + require.NoError(t, err) + assert.True(t, nextCalled) + assert.Equal(t, expectedResult, result) + + logs := buf.String() + assert.Contains(t, logs, "MCP method started") + assert.Contains(t, logs, "my-tool") + assert.Contains(t, logs, "MCP call completed") + assert.Contains(t, logs, "tools/call") + assert.NotContains(t, logs, "MCP call failed") + }) + + t.Run("logs CallToolRequest without arguments", func(t *testing.T) { + // given + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) + + next := func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + return &mcp.CallToolResult{}, nil + } + + req := &mcp.ServerRequest[*mcp.CallToolParamsRaw]{ + Session: &mcp.ServerSession{}, + Params: &mcp.CallToolParamsRaw{Name: "no-args-tool"}, + } + + handler := middleware.NewLoggingMiddleware(logger)(next) + + // when + _, err := handler(context.Background(), "tools/call", req) + + // then + require.NoError(t, err) + + logs := buf.String() + assert.Contains(t, logs, "no-args-tool") + assert.Contains(t, logs, `"has_args":false`) + }) + + t.Run("logs error when next handler returns error", func(t *testing.T) { + // given + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) + + expectedErr := errors.New("something went wrong") + next := func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + return nil, expectedErr + } + + req := &mcp.ServerRequest[*mcp.CallToolParamsRaw]{ + Session: &mcp.ServerSession{}, + Params: &mcp.CallToolParamsRaw{Name: "failing-tool"}, + } + + handler := middleware.NewLoggingMiddleware(logger)(next) + + // when + result, err := handler(context.Background(), "tools/call", req) + + // then + assert.Nil(t, result) + require.ErrorIs(t, err, expectedErr) + + logs := buf.String() + assert.Contains(t, logs, "MCP call failed") + assert.Contains(t, logs, "something went wrong") + assert.NotContains(t, logs, "MCP call completed") + }) + + t.Run("logs non-CallToolRequest with has_params", func(t *testing.T) { + // given + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})) + + next := func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + return nil, nil + } + + req := &mcp.ServerRequest[*mcp.ListToolsParams]{ + Session: &mcp.ServerSession{}, + Params: &mcp.ListToolsParams{}, + } + + handler := middleware.NewLoggingMiddleware(logger)(next) + + // when + _, err := handler(context.Background(), "tools/list", req) + + // then + require.NoError(t, err) + + logs := buf.String() + assert.Contains(t, logs, "MCP method started") + assert.Contains(t, logs, "tools/list") + assert.Contains(t, logs, "has_params") + assert.NotContains(t, logs, "MCP call failed") + }) + + t.Run("next handler result is passed through", func(t *testing.T) { + // given + var buf bytes.Buffer + logger := slog.New(slog.NewJSONHandler(&buf, nil)) + + expectedResult := &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "hello"}}, + IsError: false, + } + next := func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + return expectedResult, nil + } + + req := &mcp.ServerRequest[*mcp.CallToolParamsRaw]{ + Session: &mcp.ServerSession{}, + Params: &mcp.CallToolParamsRaw{Name: "pass-through-tool"}, + } + + handler := middleware.NewLoggingMiddleware(logger)(next) + + // when + result, err := handler(context.Background(), "tools/call", req) + + // then + require.NoError(t, err) + assert.Equal(t, expectedResult, result) + }) +} diff --git a/pkg/middleware/metrics_test.go b/pkg/middleware/metrics_test.go new file mode 100644 index 0000000..1b07326 --- /dev/null +++ b/pkg/middleware/metrics_test.go @@ -0,0 +1,183 @@ +package middleware_test + +import ( + "context" + "encoding/json" + "errors" + "log/slog" + "testing" + + "github.com/codeready-toolchain/mcp-common/pkg/metrics" + "github.com/codeready-toolchain/mcp-common/pkg/middleware" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewMetricsMiddleware(t *testing.T) { + logger := slog.Default() + + t.Run("increments counter on successful tool call", func(t *testing.T) { + // given + counter := metrics.MCPCallsTotal.WithLabelValues("test-server", "tools/call", "my-tool", "true") + before := testutil.ToFloat64(counter) + + next := func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "ok"}}, + }, nil + } + + req := &mcp.ServerRequest[*mcp.CallToolParamsRaw]{ + Session: &mcp.ServerSession{}, + Params: &mcp.CallToolParamsRaw{ + Name: "my-tool", + Arguments: json.RawMessage(`{}`), + }, + } + + handler := middleware.NewMetricsMiddleware("test-server", logger)(next) + + // when + result, err := handler(context.Background(), "tools/call", req) + + // then + require.NoError(t, err) + assert.NotNil(t, result) + assert.InDelta(t, float64(1), testutil.ToFloat64(counter)-before, 0) + }) + + t.Run("records failure when next handler returns error", func(t *testing.T) { + // given + failCounter := metrics.MCPCallsTotal.WithLabelValues("test-server", "tools/call", "error-tool", "false") + before := testutil.ToFloat64(failCounter) + + next := func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + return nil, errors.New("handler error") + } + + req := &mcp.ServerRequest[*mcp.CallToolParamsRaw]{ + Session: &mcp.ServerSession{}, + Params: &mcp.CallToolParamsRaw{ + Name: "error-tool", + }, + } + + handler := middleware.NewMetricsMiddleware("test-server", logger)(next) + + // when + _, err := handler(context.Background(), "tools/call", req) + + // then + require.Error(t, err) + assert.InDelta(t, float64(1), testutil.ToFloat64(failCounter)-before, 0) + }) + + t.Run("records failure when CallToolResult has IsError true", func(t *testing.T) { + // given + failCounter := metrics.MCPCallsTotal.WithLabelValues("test-server", "tools/call", "is-error-tool", "false") + before := testutil.ToFloat64(failCounter) + + next := func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "error output"}}, + IsError: true, + }, nil + } + + req := &mcp.ServerRequest[*mcp.CallToolParamsRaw]{ + Session: &mcp.ServerSession{}, + Params: &mcp.CallToolParamsRaw{ + Name: "is-error-tool", + }, + } + + handler := middleware.NewMetricsMiddleware("test-server", logger)(next) + + // when + result, err := handler(context.Background(), "tools/call", req) + + // then + require.NoError(t, err) + assert.NotNil(t, result) + assert.InDelta(t, float64(1), testutil.ToFloat64(failCounter)-before, 0) + }) + + t.Run("records duration histogram", func(t *testing.T) { + // given + next := func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "ok"}}, + }, nil + } + + req := &mcp.ServerRequest[*mcp.CallToolParamsRaw]{ + Session: &mcp.ServerSession{}, + Params: &mcp.CallToolParamsRaw{ + Name: "duration-tool", + }, + } + + handler := middleware.NewMetricsMiddleware("test-server", logger)(next) + + // when + _, err := handler(context.Background(), "tools/call", req) + + // then + require.NoError(t, err) + histogram := metrics.MCPCallDuration.WithLabelValues("test-server", "tools/call", "duration-tool", "true") + assert.NotNil(t, histogram) + }) + + t.Run("handles non-tool-call method with empty tool name", func(t *testing.T) { + // given + counter := metrics.MCPCallsTotal.WithLabelValues("test-server", "tools/list", "", "true") + before := testutil.ToFloat64(counter) + + next := func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + return nil, nil + } + + req := &mcp.ServerRequest[*mcp.ListToolsParams]{ + Session: &mcp.ServerSession{}, + Params: &mcp.ListToolsParams{}, + } + + handler := middleware.NewMetricsMiddleware("test-server", logger)(next) + + // when + _, err := handler(context.Background(), "tools/list", req) + + // then + require.NoError(t, err) + assert.InDelta(t, float64(1), testutil.ToFloat64(counter)-before, 0) + }) + + t.Run("uses correct server name in metrics labels", func(t *testing.T) { + // given + counter := metrics.MCPCallsTotal.WithLabelValues("custom-server", "tools/call", "label-tool", "true") + before := testutil.ToFloat64(counter) + + next := func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "ok"}}, + }, nil + } + + req := &mcp.ServerRequest[*mcp.CallToolParamsRaw]{ + Session: &mcp.ServerSession{}, + Params: &mcp.CallToolParamsRaw{Name: "label-tool"}, + } + + handler := middleware.NewMetricsMiddleware("custom-server", logger)(next) + + // when + _, err := handler(context.Background(), "tools/call", req) + + // then + require.NoError(t, err) + assert.InDelta(t, float64(1), testutil.ToFloat64(counter)-before, 0) + }) +} diff --git a/pkg/version/version_test.go b/pkg/version/version_test.go new file mode 100644 index 0000000..e9d807e --- /dev/null +++ b/pkg/version/version_test.go @@ -0,0 +1,45 @@ +package version_test + +import ( + "testing" + + "github.com/codeready-toolchain/mcp-common/pkg/version" + + "github.com/stretchr/testify/assert" +) + +func TestDefaultVersionValues(t *testing.T) { + t.Run("Commit has default value", func(t *testing.T) { + // given/when + commit := version.Commit + + // then + assert.Equal(t, "unknown", commit) + }) + + t.Run("BuildTime has default value", func(t *testing.T) { + // given/when + buildTime := version.BuildTime + + // then + assert.Equal(t, "unknown", buildTime) + }) +} + +func TestVersionValuesCanBeOverridden(t *testing.T) { + // given + originalCommit := version.Commit + originalBuildTime := version.BuildTime + defer func() { + version.Commit = originalCommit + version.BuildTime = originalBuildTime + }() + + // when + version.Commit = "abc123" + version.BuildTime = "2025-01-01T00:00:00Z" + + // then + assert.Equal(t, "abc123", version.Commit) + assert.Equal(t, "2025-01-01T00:00:00Z", version.BuildTime) +}