diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..8e4aa4f --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "gopls": { + "buildFlags": [ + "-tags=race" + ] + } +} \ No newline at end of file diff --git a/README.md b/README.md index c28ff71..40abf8c 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ SRouter is a high-performance HTTP router for Go that wraps [julienschmidt/httpr - **Advanced Features** - [IP Configuration](./docs/ip-configuration.md) - [Rate Limiting](./docs/rate-limiting.md) + - [WebSocket Support](#websocket-support) - [Authentication](./docs/authentication.md) - [Context Management](./docs/context-management.md) - [Custom Error Handling](./docs/error-handling.md) @@ -321,6 +322,26 @@ func GetUserHandler(w http.ResponseWriter, r *http.Request) { } ``` +### WebSocket Support + +SRouter supports WebSocket connections by allowing you to disable the automatic request timeout for specific routes. This is crucial for long-lived connections. + +To enable WebSocket support for a route, set the `DisableTimeout` flag to `true` in your `RouteConfigBase`. This will prevent the global or sub-router timeout from terminating the connection. This is also useful for other long-lived connections such as Server-Sent Events (SSE). + +```go +// Register a WebSocket route +r.RegisterRoute(router.RouteConfigBase{ + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, + DisableTimeout: true, // Disables timeout for this route + Handler: func(w http.ResponseWriter, r *http.Request) { + // Upgrade the connection to a WebSocket + // conn, err := upgrader.Upgrade(w, r, nil) + // ... handle connection ... + }, +}) +``` + ### Trace ID Logging SRouter provides built-in support for trace ID logging, which allows you to correlate log entries across different parts of your application for a single request. Each request is assigned a unique trace ID (UUID) that is automatically included in all log entries when `EnableTraceLogging` is true. @@ -1224,6 +1245,7 @@ type RouteConfigBase struct { Overrides common.RouteOverrides // Optional per-route overrides Handler http.HandlerFunc // Standard HTTP handler function Middlewares []common.Middleware // Middlewares applied to this specific route + DisableTimeout bool // Indicates if the timeout should be disabled for this route (e.g., for WebSockets or long-lived connections). } ``` diff --git a/examples/websocket/main.go b/examples/websocket/main.go new file mode 100644 index 0000000..a1c3d06 --- /dev/null +++ b/examples/websocket/main.go @@ -0,0 +1,146 @@ +package main + +import ( + "context" + "fmt" + "io" + "log" + "net/http" + "net/url" + "time" + + "github.com/Suhaibinator/SRouter/pkg/router" + "github.com/gorilla/websocket" + "go.uber.org/zap" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + // Allow all origins for this example + CheckOrigin: func(r *http.Request) bool { return true }, +} + +func main() { + // 1. Setup Server + logger, _ := zap.NewProduction() + defer logger.Sync() + + routerConfig := router.RouterConfig{ + ServiceName: "websocket-example", + Logger: logger, + GlobalTimeout: 5 * time.Second, // Global timeout to test DisableTimeout bypass + } + + // Simple auth - accept everything + authFunc := func(ctx context.Context, token string) (*string, bool) { + user := "generic-user" + return &user, true + } + userIdFunc := func(user *string) string { return *user } + + r := router.NewRouter(routerConfig, authFunc, userIdFunc) + + // REST Endpoint + r.RegisterRoute(router.RouteConfigBase{ + Path: "/hello", + Methods: []router.HttpMethod{router.MethodGet}, + Handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("Hello, World!")) + }, + }) + + // WebSocket Endpoint + r.RegisterRoute(router.RouteConfigBase{ + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, + DisableTimeout: true, // Crucial: disables global timeout + Handler: func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + logger.Error("upgrade failed", zap.Error(err)) + return + } + defer conn.Close() + + for { + messageType, p, err := conn.ReadMessage() + if err != nil { + return + } + // Echo message back + if err := conn.WriteMessage(messageType, p); err != nil { + return + } + } + }, + }) + + // Start server in goroutine + port := "8089" + server := &http.Server{Addr: ":" + port, Handler: r} + + go func() { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("ListenAndServe(): %v", err) + } + }() + fmt.Printf("Server started on port %s\n", port) + + // Give server a moment to start + time.Sleep(100 * time.Millisecond) + + // 2. Test Client Logic + testREST(port) + testWebSocket(port) + + // Shutdown + server.Shutdown(context.Background()) + fmt.Println("Done.") +} + +func testREST(port string) { + fmt.Println("--- Testing REST Endpoint ---") + resp, err := http.Get(fmt.Sprintf("http://localhost:%s/hello", port)) + if err != nil { + log.Fatalf("REST request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + log.Fatalf("REST expected status 200, got %d", resp.StatusCode) + } + + body, _ := io.ReadAll(resp.Body) + fmt.Printf("REST Response: %s\n", string(body)) + fmt.Println("REST Test Passed!") +} + +func testWebSocket(port string) { + fmt.Println("--- Testing WebSocket Endpoint ---") + u := url.URL{Scheme: "ws", Host: "localhost:" + port, Path: "/ws"} + + c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + log.Fatalf("WebSocket dial failed: %v", err) + } + defer c.Close() + + msg := "hello websocket" + err = c.WriteMessage(websocket.TextMessage, []byte(msg)) + if err != nil { + log.Fatalf("WebSocket write failed: %v", err) + } + + _, message, err := c.ReadMessage() + if err != nil { + log.Fatalf("WebSocket read failed: %v", err) + } + + fmt.Printf("WebSocket Response: %s\n", string(message)) + if string(message) != msg { + log.Fatalf("WebSocket expected echo '%s', got '%s'", msg, string(message)) + } + fmt.Println("WebSocket Test Passed!") +} diff --git a/go.mod b/go.mod index ea335bc..224ee55 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,14 @@ go 1.24.0 require ( github.com/julienschmidt/httprouter v1.3.0 - go.uber.org/zap v1.27.0 + go.uber.org/zap v1.27.1 ) require ( github.com/google/uuid v1.6.0 - github.com/stretchr/testify v1.10.0 - gorm.io/gorm v1.30.1 + github.com/gorilla/websocket v1.5.3 + github.com/stretchr/testify v1.11.1 + gorm.io/gorm v1.31.1 ) require ( @@ -18,7 +19,8 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/text v0.28.0 // indirect + go.yaml.in/yaml/v2 v2.4.3 // indirect + golang.org/x/text v0.32.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) @@ -28,14 +30,14 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/prometheus/client_model v0.6.2 - github.com/prometheus/common v0.65.0 // indirect - github.com/prometheus/procfs v0.17.0 // indirect + github.com/prometheus/common v0.67.4 // indirect + github.com/prometheus/procfs v0.19.2 // indirect go.uber.org/ratelimit v0.3.1 - golang.org/x/sys v0.35.0 // indirect - google.golang.org/protobuf v1.36.7 + golang.org/x/sys v0.39.0 // indirect + google.golang.org/protobuf v1.36.11 ) require ( - github.com/prometheus/client_golang v1.23.0 + github.com/prometheus/client_golang v1.23.2 go.uber.org/multierr v1.11.0 // indirect ) diff --git a/go.sum b/go.sum index 9ab7e1b..ca0d78e 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= @@ -28,24 +30,18 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= -github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= -github.com/prometheus/client_golang v1.23.0 h1:ust4zpdl9r4trLY/gSjlm07PuiBq2ynaXXlptpfy8Uc= -github.com/prometheus/client_golang v1.23.0/go.mod h1:i/o0R9ByOnHX0McrTMTyhYvKE4haaf2mW08I+jGAjEE= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= -github.com/prometheus/common v0.64.0 h1:pdZeA+g617P7oGv1CzdTzyeShxAGrTBsolKNOLQPGO4= -github.com/prometheus/common v0.64.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= -github.com/prometheus/common v0.65.0 h1:QDwzd+G1twt//Kwj/Ww6E9FQq1iVMmODnILtW1t2VzE= -github.com/prometheus/common v0.65.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8= -github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= -github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= -github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0= -github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= +github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+Lvsc= +github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI= +github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= +github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -54,30 +50,20 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/ratelimit v0.3.1 h1:K4qVE+byfv/B3tC+4nYWP7v/6SimcO7HzHekoMNBma0= go.uber.org/ratelimit v0.3.1/go.mod h1:6euWsTB6U/Nb3X++xEUXA8ciPJvr19Q/0h1+oDcJhRk= -go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= -go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= -golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= -golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= -golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= -golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= -golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= -golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= -google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= -google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= -google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= -google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= +go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= +go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs= -gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= -gorm.io/gorm v1.30.1 h1:lSHg33jJTBxs2mgJRfRZeLDG+WZaHYCk3Wtfl6Ngzo4= -gorm.io/gorm v1.30.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= diff --git a/pkg/router/config.go b/pkg/router/config.go index 4c8f701..5ae1f0d 100644 --- a/pkg/router/config.go +++ b/pkg/router/config.go @@ -189,12 +189,13 @@ type SubRouterConfig struct { // - Sub-router settings override global settings // - Middlewares are additive (not replaced) type RouteConfigBase struct { - Path string // Route path (will be prefixed with sub-router path prefix if applicable) - Methods []HttpMethod // HTTP methods this route handles (use constants like MethodGet) - AuthLevel *AuthLevel // Authentication level for this route. If nil, inherits from sub-router or defaults to NoAuth - Overrides common.RouteOverrides // Configuration overrides for this specific route - Handler http.HandlerFunc // Standard HTTP handler function - Middlewares []common.Middleware // Middlewares applied to this specific route (combined with sub-router and global middlewares) + Path string // Route path (will be prefixed with sub-router path prefix if applicable) + Methods []HttpMethod // HTTP methods this route handles (use constants like MethodGet) + AuthLevel *AuthLevel // Authentication level for this route. If nil, inherits from sub-router or defaults to NoAuth + Overrides common.RouteOverrides // Configuration overrides for this specific route + Handler http.HandlerFunc // Standard HTTP handler function + Middlewares []common.Middleware // Middlewares applied to this specific route (combined with sub-router and global middlewares) + DisableTimeout bool // Indicates if the timeout should be disabled for this route (e.g., for WebSockets or long-lived connections). } // Implement RouteDefinition for RouteConfigBase diff --git a/pkg/router/handler_error_test.go b/pkg/router/handler_error_test.go index 25eccd8..afc5120 100644 --- a/pkg/router/handler_error_test.go +++ b/pkg/router/handler_error_test.go @@ -23,7 +23,7 @@ func TestGenericRouteHandlerError(t *testing.T) { return 0 } - router := NewRouter[int, interface{}](RouterConfig{ + router := NewRouter(RouterConfig{ Logger: zap.NewNop(), }, getUserByID, getUserID) @@ -177,11 +177,11 @@ func TestHandlerErrorWithMultipleMiddleware(t *testing.T) { getUserByID := func(ctx context.Context, userID string) (*interface{}, bool) { return nil, false } - getUserID := func(user *interface{}) int { + getUserID := func(user *any) int { return 0 } - router := NewRouter[int, interface{}](RouterConfig{ + router := NewRouter(RouterConfig{ Logger: zap.NewNop(), }, getUserByID, getUserID) diff --git a/pkg/router/route.go b/pkg/router/route.go index a82f7a2..cf826b1 100644 --- a/pkg/router/route.go +++ b/pkg/router/route.go @@ -25,6 +25,12 @@ import ( func (r *Router[T, U]) RegisterRoute(route RouteConfigBase) { // Get effective timeout, max body size, and rate limit for this route timeout := r.getEffectiveTimeout(route.Overrides.Timeout, 0) + + // If route has timeout disabled, set timeout to 0 + if route.DisableTimeout { + timeout = 0 + } + maxBodySize := r.getEffectiveMaxBodySize(route.Overrides.MaxBodySize, 0) // Pass the specific route config (which is *common.RateLimitConfig[any, any]) // to getEffectiveRateLimit. The conversion happens inside getEffectiveRateLimit. diff --git a/pkg/router/route_query_decode_error_test.go b/pkg/router/route_query_decode_error_test.go new file mode 100644 index 0000000..74e95f7 --- /dev/null +++ b/pkg/router/route_query_decode_error_test.go @@ -0,0 +1,94 @@ +package router + +import ( + "encoding/base64" + "encoding/json" + "math/big" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Suhaibinator/SRouter/pkg/codec" + "github.com/Suhaibinator/SRouter/pkg/router/internal/mocks" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func encodeBase62(b []byte) string { + const alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + + n := new(big.Int).SetBytes(b) + if n.Sign() == 0 { + return "0" + } + + base := big.NewInt(62) + mod := new(big.Int) + + var out []byte + for n.Sign() > 0 { + n.DivMod(n, base, mod) + out = append(out, alphabet[mod.Int64()]) + } + + for i, j := 0, len(out)-1; i < j; i, j = i+1, j-1 { + out[i], out[j] = out[j], out[i] + } + return string(out) +} + +func TestRegisterGenericRoute_Base64QueryParameter_DecodeBytesError(t *testing.T) { + logger := zap.NewNop() + r := NewRouter(RouterConfig{Logger: logger}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + RegisterGenericRoute(r, RouteConfig[RequestType, ResponseType]{ + Path: "/test", + Methods: []HttpMethod{MethodGet}, + Codec: codec.NewJSONCodec[RequestType, ResponseType](), + SourceType: Base64QueryParameter, + SourceKey: "qdata", + Handler: func(r *http.Request, req RequestType) (ResponseType, error) { + t.Fatalf("handler should not be called on decode error") + return ResponseType{}, nil + }, + }, 0, 0, nil) + + invalidJSONBase64 := base64.StdEncoding.EncodeToString([]byte("{invalid json")) + req := httptest.NewRequest(http.MethodGet, "/test?qdata="+invalidJSONBase64, nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusBadRequest, rr.Code) + + var body map[string]map[string]string + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &body)) + require.Equal(t, "Failed to decode query parameter data", body["error"]["message"]) +} + +func TestRegisterGenericRoute_Base62QueryParameter_DecodeBytesError(t *testing.T) { + logger := zap.NewNop() + r := NewRouter(RouterConfig{Logger: logger}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + RegisterGenericRoute(r, RouteConfig[RequestType, ResponseType]{ + Path: "/test", + Methods: []HttpMethod{MethodGet}, + Codec: codec.NewJSONCodec[RequestType, ResponseType](), + SourceType: Base62QueryParameter, + SourceKey: "qdata", + Handler: func(r *http.Request, req RequestType) (ResponseType, error) { + t.Fatalf("handler should not be called on decode error") + return ResponseType{}, nil + }, + }, 0, 0, nil) + + invalidJSONBase62 := encodeBase62([]byte("{invalid json")) + req := httptest.NewRequest(http.MethodGet, "/test?qdata="+invalidJSONBase62, nil) + rr := httptest.NewRecorder() + r.ServeHTTP(rr, req) + + require.Equal(t, http.StatusBadRequest, rr.Code) + + var body map[string]map[string]string + require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &body)) + require.Equal(t, "Failed to decode query parameter data", body["error"]["message"]) +} diff --git a/pkg/router/router.go b/pkg/router/router.go index 38b71da..3b605d8 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -3,10 +3,12 @@ package router import ( + "bufio" "context" "encoding/json" // Added for JSON marshalling "errors" "fmt" + "net" "net/http" "slices" // Added for CORS "strconv" // Added for CORS @@ -210,6 +212,12 @@ func (r *Router[T, U]) registerSubRouter(sr SubRouterConfig) { // Get effective settings considering overrides timeout := r.getEffectiveTimeout(route.Overrides.Timeout, sr.Overrides.Timeout) + + // If route has timeout disabled, set timeout to 0 + if route.DisableTimeout { + timeout = 0 + } + maxBodySize := r.getEffectiveMaxBodySize(route.Overrides.MaxBodySize, sr.Overrides.MaxBodySize) rateLimit := r.getEffectiveRateLimit(route.Overrides.RateLimit, sr.Overrides.RateLimit) authLevel := route.AuthLevel // Use route-specific first @@ -415,19 +423,48 @@ func (r *Router[T, U]) timeoutMiddleware(timeout time.Duration) common.Middlewar fields = r.addTrace(fields, req) r.logger.Error("Request timed out", fields...) - // Acquire lock to safely check and potentially write timeout response. - wrappedW.mu.Lock() - // Check if handler already started writing. Use Swap for atomic check-and-set. - if !wrappedW.wroteHeader.Swap(true) { - // Handler hasn't written yet, we can write the timeout error. - // Hold the lock while writing headers and body for timeout. - // Use the new JSON error writer, passing the request - traceID := scontext.GetTraceIDFromRequest[T, U](req) - r.writeJSONError(wrappedW.ResponseWriter, req, http.StatusRequestTimeout, "Request Timeout", traceID) + // If the handler already started writing, don't attempt to take over the response. + // Wait for the handler to finish to avoid returning while another goroutine is writing. + if wrappedW.wroteHeader.Load() { + <-done + select { + case p := <-panicChan: + panic(p) + default: + } + return + } + + // Mark timed out so any in-flight handler writes fail fast and don't touch the underlying writer. + wrappedW.timedOut.Store(true) + + // Reserve the response so the handler can't race to write its own error response. + if !wrappedW.wroteHeader.CompareAndSwap(false, true) { + <-done + select { + case p := <-panicChan: + panic(p) + default: + } + return } - // If wroteHeader was already true, handler won the race, do nothing here. - // Unlock should happen regardless of whether we wrote the error or not. + + // Serialize the timeout response write with any handler goroutine currently inside rw methods. + wrappedW.mu.Lock() + traceID := scontext.GetTraceIDFromRequest[T, U](req) + r.writeJSONError(wrappedW.ResponseWriter, req, http.StatusRequestTimeout, "Request Timeout", traceID) wrappedW.mu.Unlock() + + // Give the handler a chance to observe cancellation and exit promptly. + select { + case <-done: + select { + case p := <-panicChan: + panic(p) + default: + } + case <-time.After(50 * time.Millisecond): + } return } }) @@ -786,6 +823,13 @@ type baseResponseWriter struct { http.ResponseWriter } +// Unwrap returns the underlying ResponseWriter. +// This enables Go 1.20+'s http.ResponseController to reach optional interfaces (e.g. Flusher, Hijacker) +// implemented by the original writer when this writer is wrapped. +func (bw *baseResponseWriter) Unwrap() http.ResponseWriter { + return bw.ResponseWriter +} + // WriteHeader calls the underlying ResponseWriter's WriteHeader. func (bw *baseResponseWriter) WriteHeader(statusCode int) { bw.ResponseWriter.WriteHeader(statusCode) @@ -803,6 +847,16 @@ func (bw *baseResponseWriter) Flush() { } } +// Hijack delegates to the underlying ResponseWriter when it supports http.Hijacker. +// This is required for WebSocket upgrades to work through ResponseWriter wrappers. +func (bw *baseResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + h, ok := bw.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("underlying ResponseWriter (%T) does not support hijacking: %w", bw.ResponseWriter, http.ErrNotSupported) + } + return h.Hijack() +} + // metricsResponseWriter is a wrapper around http.ResponseWriter that captures metrics. // It tracks the status code, bytes written, and timing information for each response. type metricsResponseWriter[T comparable, U any] struct { @@ -988,6 +1042,56 @@ func (r *Router[T, U]) handleError(w http.ResponseWriter, req *http.Request, err // It includes the trace ID in the JSON payload if available and enabled. // It also adds CORS headers based on information stored in the context by the CORS middleware. func (r *Router[T, U]) writeJSONError(w http.ResponseWriter, req *http.Request, statusCode int, message string, traceID string) { // Add req parameter + if mrw, ok := w.(*mutexResponseWriter); ok { + if mrw.timedOut.Load() { + return + } + if !mrw.wroteHeader.CompareAndSwap(false, true) { + return + } + + mrw.mu.Lock() + defer mrw.mu.Unlock() + + allowedOrigin, credentialsAllowed, corsOK := scontext.GetCORSInfoFromRequest[T, U](req) + header := mrw.ResponseWriter.Header() + + if corsOK { + if allowedOrigin != "" { + header.Set("Access-Control-Allow-Origin", allowedOrigin) + } + if credentialsAllowed { + header.Set("Access-Control-Allow-Credentials", "true") + } + if allowedOrigin != "" && allowedOrigin != "*" { + header.Add("Vary", "Origin") + } + } + + header.Set("Content-Type", "application/json; charset=utf-8") + mrw.ResponseWriter.WriteHeader(statusCode) + + errorPayload := map[string]any{ + "error": map[string]string{ + "message": message, + }, + } + if r.config.TraceIDBufferSize > 0 && traceID != "" { + errorMap := errorPayload["error"].(map[string]string) + errorMap["trace_id"] = traceID + } + + if err := json.NewEncoder(mrw.ResponseWriter).Encode(errorPayload); err != nil { + r.logger.Error("Failed to write JSON error response", + zap.Error(err), + zap.Int("original_status", statusCode), + zap.String("original_message", message), + zap.String("trace_id", traceID), + ) + } + return + } + // Retrieve CORS info from context using the passed-in request allowedOrigin, credentialsAllowed, corsOK := scontext.GetCORSInfoFromRequest[T, U](req) @@ -1195,10 +1299,14 @@ type mutexResponseWriter struct { http.ResponseWriter mu *sync.Mutex wroteHeader atomic.Bool // Tracks if WriteHeader or Write has been called + timedOut atomic.Bool // When true, reject all writes to the underlying writer } // Header acquires the mutex and returns the underlying Header map. func (rw *mutexResponseWriter) Header() http.Header { + if rw.timedOut.Load() { + return make(http.Header) + } rw.mu.Lock() defer rw.mu.Unlock() return rw.ResponseWriter.Header() @@ -1206,6 +1314,9 @@ func (rw *mutexResponseWriter) Header() http.Header { // WriteHeader acquires the mutex, marks headers as written, and calls the underlying ResponseWriter.WriteHeader. func (rw *mutexResponseWriter) WriteHeader(statusCode int) { + if rw.timedOut.Load() { + return + } rw.mu.Lock() defer rw.mu.Unlock() if !rw.wroteHeader.Swap(true) { // Atomically set flag and check previous value @@ -1216,6 +1327,9 @@ func (rw *mutexResponseWriter) WriteHeader(statusCode int) { // Write acquires the mutex, marks headers/body as written, and calls the underlying ResponseWriter.Write. func (rw *mutexResponseWriter) Write(b []byte) (int, error) { + if rw.timedOut.Load() { + return 0, http.ErrHandlerTimeout + } rw.mu.Lock() defer rw.mu.Unlock() rw.wroteHeader.Store(true) // Mark as written (headers might be implicitly written here) @@ -1224,6 +1338,9 @@ func (rw *mutexResponseWriter) Write(b []byte) (int, error) { // Flush acquires the mutex and calls the underlying ResponseWriter.Flush if it implements http.Flusher. func (rw *mutexResponseWriter) Flush() { + if rw.timedOut.Load() { + return + } rw.mu.Lock() defer rw.mu.Unlock() if f, ok := rw.ResponseWriter.(http.Flusher); ok { diff --git a/pkg/router/timeout_middleware_race_test.go b/pkg/router/timeout_middleware_race_test.go new file mode 100644 index 0000000..b8ffa20 --- /dev/null +++ b/pkg/router/timeout_middleware_race_test.go @@ -0,0 +1,113 @@ +//go:build race + +package router + +import ( + "context" + "net/http" + "net/http/httptest" + "runtime" + "testing" + "time" + + "github.com/Suhaibinator/SRouter/pkg/router/internal/mocks" + "go.uber.org/zap" +) + +func TestTimeoutMiddleware_WhenHandlerWritesBetweenHeaderCheckAndTimeoutStore_TakeoverCASFails(t *testing.T) { + oldProcs := runtime.GOMAXPROCS(2) + defer runtime.GOMAXPROCS(oldProcs) + + r := NewRouter(RouterConfig{Logger: zap.NewNop()}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + timeout := 50 * time.Microsecond + + deadline := time.Now().Add(5 * time.Second) + attempts := 0 + + for time.Now().Before(deadline) { + attempts++ + + mrwCh := make(chan *mutexResponseWriter, 1) + ctxErrCh := make(chan error, 1) + + handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if mrw, ok := w.(*mutexResponseWriter); ok { + select { + case mrwCh <- mrw: + default: + } + } + + <-req.Context().Done() + ctxErrCh <- req.Context().Err() + + w.WriteHeader(http.StatusAccepted) + }) + + h := r.timeoutMiddleware(timeout)(handler) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + + mrw := <-mrwCh + if rr.Code == http.StatusAccepted && mrw.timedOut.Load() { + select { + case err := <-ctxErrCh: + if err != context.DeadlineExceeded { + t.Fatalf("expected context deadline exceeded, got %v", err) + } + default: + t.Fatalf("expected handler to observe context cancellation") + } + return + } + } + + t.Fatalf("did not observe timeout takeover CAS failure within deadline (attempts=%d)", attempts) +} + +func TestTimeoutMiddleware_WhenHandlerPanicsInCASFailurePath_RethrowsToRecovery(t *testing.T) { + oldProcs := runtime.GOMAXPROCS(2) + defer runtime.GOMAXPROCS(oldProcs) + + r := NewRouter(RouterConfig{Logger: zap.NewNop()}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + timeout := 50 * time.Microsecond + + deadline := time.Now().Add(5 * time.Second) + attempts := 0 + + for time.Now().Before(deadline) { + attempts++ + + mrwCh := make(chan *mutexResponseWriter, 1) + + handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if mrw, ok := w.(*mutexResponseWriter); ok { + select { + case mrwCh <- mrw: + default: + } + } + <-req.Context().Done() + w.WriteHeader(http.StatusAccepted) + panic("boom") + }) + + h := r.recoveryMiddleware(r.timeoutMiddleware(timeout)(handler)) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + + mrw := <-mrwCh + if rr.Code == http.StatusAccepted && mrw.timedOut.Load() { + if msg := parseJSONErrorMessage(t, rr.Body.Bytes()); msg != "Internal Server Error" { + t.Fatalf("expected internal server error payload, got %q", msg) + } + return + } + } + + t.Fatalf("did not observe panic rethrow in CAS-failure path within deadline (attempts=%d)", attempts) +} diff --git a/pkg/router/timeout_middleware_test.go b/pkg/router/timeout_middleware_test.go new file mode 100644 index 0000000..cfca017 --- /dev/null +++ b/pkg/router/timeout_middleware_test.go @@ -0,0 +1,101 @@ +package router + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Suhaibinator/SRouter/pkg/router/internal/mocks" + "go.uber.org/zap" +) + +func parseJSONErrorMessage(t *testing.T, body []byte) string { + t.Helper() + + var payload struct { + Error struct { + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("expected JSON error payload, got %q: %v", string(body), err) + } + return payload.Error.Message +} + +func TestTimeoutMiddleware_WhenHandlerStartedWriting_DoesNotOverrideResponse(t *testing.T) { + r := NewRouter(RouterConfig{Logger: zap.NewNop()}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + timeout := 25 * time.Millisecond + wroteHeader := make(chan struct{}) + ctxErrCh := make(chan error, 1) + + handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusCreated) + close(wroteHeader) + + <-req.Context().Done() + ctxErrCh <- req.Context().Err() + time.Sleep(10 * time.Millisecond) + + _, _ = w.Write([]byte("handler-finished")) + }) + + h := r.recoveryMiddleware(r.timeoutMiddleware(timeout)(handler)) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + rr := httptest.NewRecorder() + + select { + case <-wroteHeader: + t.Fatalf("handler should not have executed before ServeHTTP") + default: + } + + h.ServeHTTP(rr, req) + + if rr.Code != http.StatusCreated { + t.Fatalf("expected status %d, got %d", http.StatusCreated, rr.Code) + } + if rr.Body.String() != "handler-finished" { + t.Fatalf("expected body %q, got %q", "handler-finished", rr.Body.String()) + } + + select { + case err := <-ctxErrCh: + if err != context.DeadlineExceeded { + t.Fatalf("expected context deadline exceeded, got %v", err) + } + default: + t.Fatalf("expected handler to observe context cancellation") + } +} + +func TestTimeoutMiddleware_WhenHandlerPanicsAfterTimeoutAndStartedWrite_RethrowsToRecovery(t *testing.T) { + r := NewRouter(RouterConfig{Logger: zap.NewNop()}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + timeout := 15 * time.Millisecond + handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusTeapot) + <-req.Context().Done() + time.Sleep(10 * time.Millisecond) + panic("boom") + }) + + h := r.recoveryMiddleware(r.timeoutMiddleware(timeout)(handler)) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + rr := httptest.NewRecorder() + + h.ServeHTTP(rr, req) + + if rr.Code != http.StatusTeapot { + t.Fatalf("expected status %d, got %d", http.StatusTeapot, rr.Code) + } + if msg := parseJSONErrorMessage(t, rr.Body.Bytes()); msg != "Internal Server Error" { + t.Fatalf("expected internal server error payload, got %q", msg) + } +} diff --git a/pkg/router/websocket_test.go b/pkg/router/websocket_test.go new file mode 100644 index 0000000..1c5b707 --- /dev/null +++ b/pkg/router/websocket_test.go @@ -0,0 +1,364 @@ +package router_test + +import ( + "bufio" + "errors" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Suhaibinator/SRouter/pkg/common" + "github.com/Suhaibinator/SRouter/pkg/router" + "go.uber.org/zap" +) + +type hijackableRecorder struct { + *httptest.ResponseRecorder + hijacked bool + serverConn net.Conn + clientConn net.Conn + + readDeadline time.Time + writeDeadline time.Time + fullDuplexEnabled bool +} + +func newHijackableRecorder() *hijackableRecorder { + return &hijackableRecorder{ResponseRecorder: httptest.NewRecorder()} +} + +func (rw *hijackableRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if rw.serverConn != nil || rw.clientConn != nil { + return nil, nil, errors.New("connection already hijacked") + } + + rw.hijacked = true + rw.clientConn, rw.serverConn = net.Pipe() + return rw.serverConn, bufio.NewReadWriter(bufio.NewReader(rw.serverConn), bufio.NewWriter(rw.serverConn)), nil +} + +func (rw *hijackableRecorder) SetReadDeadline(deadline time.Time) error { + rw.readDeadline = deadline + return nil +} + +func (rw *hijackableRecorder) SetWriteDeadline(deadline time.Time) error { + rw.writeDeadline = deadline + return nil +} + +func (rw *hijackableRecorder) EnableFullDuplex() error { + rw.fullDuplexEnabled = true + return nil +} + +func (rw *hijackableRecorder) Close() { + if rw.serverConn != nil { + _ = rw.serverConn.Close() + } + if rw.clientConn != nil { + _ = rw.clientConn.Close() + } +} + +func TestWebSocketRoute(t *testing.T) { + logger := zap.NewNop() + config := router.RouterConfig{ + Logger: logger, + GlobalTimeout: 100 * time.Millisecond, + } + + r := router.NewRouter[string, string](config, nil, nil) + + // Register a "WebSocket" route that sleeps longer than the global timeout + // Since DisableTimeout is true, it should NOT timeout. + r.RegisterRoute(router.RouteConfigBase{ + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, + DisableTimeout: true, + Handler: func(w http.ResponseWriter, r *http.Request) { + // Simulate long-lived connection + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + }, + }) + + // Register a normal route that SHOULD timeout + r.RegisterRoute(router.RouteConfigBase{ + Path: "/normal", + Methods: []router.HttpMethod{router.MethodGet}, + Handler: func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + }, + }) + + server := httptest.NewServer(r) + defer server.Close() + + client := server.Client() + + // Test WebSocket Route + t.Run("WebSocket Route should not timeout", func(t *testing.T) { + start := time.Now() + resp, err := client.Get(server.URL + "/ws") + duration := time.Since(start) + + if err != nil { + t.Fatalf("/ws request failed: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("/ws: expected 200 OK, got %d", resp.StatusCode) + } + if duration < 200*time.Millisecond { + t.Errorf("/ws: completed too fast (%v), sleep didn't happen?", duration) + } + }) + + // Test Normal Route (Control Case) + t.Run("Normal Route should timeout", func(t *testing.T) { + resp, err := client.Get(server.URL + "/normal") + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusRequestTimeout { + t.Errorf("/normal: expected 408 Timeout, got %d", resp.StatusCode) + } + }) +} + +func TestWebSocketRoutePreservesHijackerWithTracingEnabled(t *testing.T) { + logger := zap.NewNop() + config := router.RouterConfig{ + Logger: logger, + GlobalTimeout: 100 * time.Millisecond, + TraceIDBufferSize: 1, + } + + r := router.NewRouter[string, string](config, nil, nil) + + var sawHijacker bool + var hijackErr error + + r.RegisterRoute(router.RouteConfigBase{ + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, + DisableTimeout: true, + Handler: func(w http.ResponseWriter, _ *http.Request) { + h, ok := w.(http.Hijacker) + if !ok { + hijackErr = errors.New("response writer does not implement http.Hijacker") + return + } + sawHijacker = true + + conn, _, err := h.Hijack() + if err != nil { + hijackErr = err + return + } + _ = conn.Close() + }, + }) + + req := httptest.NewRequest(http.MethodGet, "/ws", nil) + rr := newHijackableRecorder() + defer rr.Close() + + r.ServeHTTP(rr, req) + + if !sawHijacker { + t.Fatalf("expected handler to receive an http.Hijacker when tracing is enabled") + } + if hijackErr != nil { + t.Fatalf("expected Hijack to succeed, got %v", hijackErr) + } + if !rr.hijacked { + t.Fatalf("expected Hijack to be delegated to the underlying response writer") + } +} + +func TestWebSocketRouteHijackNotSupportedIsWrapped(t *testing.T) { + logger := zap.NewNop() + config := router.RouterConfig{ + Logger: logger, + GlobalTimeout: 100 * time.Millisecond, + TraceIDBufferSize: 1, + } + + r := router.NewRouter[string, string](config, nil, nil) + + var hijackErr error + + r.RegisterRoute(router.RouteConfigBase{ + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, + DisableTimeout: true, + Handler: func(w http.ResponseWriter, _ *http.Request) { + h, ok := w.(http.Hijacker) + if !ok { + hijackErr = errors.New("response writer does not implement http.Hijacker") + return + } + + _, _, hijackErr = h.Hijack() + }, + }) + + req := httptest.NewRequest(http.MethodGet, "/ws", nil) + rr := httptest.NewRecorder() + + r.ServeHTTP(rr, req) + + if hijackErr == nil { + t.Fatalf("expected Hijack to fail") + } + if !errors.Is(hijackErr, http.ErrNotSupported) { + t.Fatalf("expected errors.Is(hijackErr, http.ErrNotSupported) to be true, got %v", hijackErr) + } + if !strings.Contains(hijackErr.Error(), "does not support hijacking") { + t.Fatalf("expected Hijack error to include context, got %q", hijackErr.Error()) + } +} + +func TestWebSocketRouteResponseControllerCanReachOptionalInterfaces(t *testing.T) { + logger := zap.NewNop() + config := router.RouterConfig{ + Logger: logger, + GlobalTimeout: 100 * time.Millisecond, + TraceIDBufferSize: 1, // ensures the router wraps the ResponseWriter + } + + r := router.NewRouter[string, string](config, nil, nil) + + var controllerErr error + var sawDeadlines bool + var sawFullDuplex bool + + r.RegisterRoute(router.RouteConfigBase{ + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, + DisableTimeout: true, + Handler: func(w http.ResponseWriter, _ *http.Request) { + rc := http.NewResponseController(w) + deadline := time.Now().Add(5 * time.Second) + + if err := rc.SetReadDeadline(deadline); err != nil { + controllerErr = err + return + } + if err := rc.SetWriteDeadline(deadline); err != nil { + controllerErr = err + return + } + if err := rc.EnableFullDuplex(); err != nil { + controllerErr = err + return + } + + // Also exercise Hijack through ResponseController, which is commonly used by WebSocket implementations. + conn, _, err := rc.Hijack() + if err != nil { + controllerErr = err + return + } + _ = conn.Close() + + sawDeadlines = true + sawFullDuplex = true + }, + }) + + req := httptest.NewRequest(http.MethodGet, "/ws", nil) + rr := newHijackableRecorder() + defer rr.Close() + + r.ServeHTTP(rr, req) + + if controllerErr != nil { + t.Fatalf("expected ResponseController methods to succeed, got %v", controllerErr) + } + if !sawDeadlines || rr.readDeadline.IsZero() || rr.writeDeadline.IsZero() { + t.Fatalf("expected ResponseController to reach SetReadDeadline/SetWriteDeadline on the underlying writer") + } + if !sawFullDuplex || !rr.fullDuplexEnabled { + t.Fatalf("expected ResponseController to reach EnableFullDuplex on the underlying writer") + } + if !rr.hijacked { + t.Fatalf("expected Hijack to be delegated to the underlying response writer") + } +} + +func TestSubRouterWebSocketRoute(t *testing.T) { + logger := zap.NewNop() + config := router.RouterConfig{ + Logger: logger, + GlobalTimeout: 100 * time.Millisecond, + SubRouters: []router.SubRouterConfig{ + { + PathPrefix: "/sub", + Overrides: common.RouteOverrides{ + Timeout: 50 * time.Millisecond, + }, + Routes: []router.RouteDefinition{ + router.RouteConfigBase{ + Path: "/ws", + Methods: []router.HttpMethod{router.MethodGet}, + DisableTimeout: true, + Handler: func(w http.ResponseWriter, r *http.Request) { + // Simulate long-lived connection + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + }, + }, + router.RouteConfigBase{ + Path: "/normal", + Methods: []router.HttpMethod{router.MethodGet}, + Handler: func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + }, + }, + }, + }, + }, + } + + r := router.NewRouter[string, string](config, nil, nil) + + server := httptest.NewServer(r) + defer server.Close() + + client := server.Client() + + // Test SubRouter WebSocket Route + t.Run("SubRouter WebSocket Route should not timeout", func(t *testing.T) { + start := time.Now() + resp, err := client.Get(server.URL + "/sub/ws") + duration := time.Since(start) + + if err != nil { + t.Fatalf("/sub/ws request failed: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("/sub/ws: expected 200 OK, got %d", resp.StatusCode) + } + if duration < 200*time.Millisecond { + t.Errorf("/sub/ws: completed too fast (%v), sleep didn't happen?", duration) + } + }) + + // Test SubRouter Normal Route (Control Case) + t.Run("SubRouter Normal Route should timeout", func(t *testing.T) { + resp, err := client.Get(server.URL + "/sub/normal") + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusRequestTimeout { + t.Errorf("/sub/normal: expected 408 Timeout, got %d", resp.StatusCode) + } + }) +} diff --git a/pkg/router/write_json_error_test.go b/pkg/router/write_json_error_test.go new file mode 100644 index 0000000..aa28ce1 --- /dev/null +++ b/pkg/router/write_json_error_test.go @@ -0,0 +1,156 @@ +package router + +import ( + "errors" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "github.com/Suhaibinator/SRouter/pkg/router/internal/mocks" + "github.com/Suhaibinator/SRouter/pkg/scontext" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" +) + +func TestWriteJSONError_MutexResponseWriter_SetsCORSHeaders(t *testing.T) { + tests := []struct { + name string + allowedOrigin string + credentialsAllowed bool + wantVaryOrigin bool + }{ + { + name: "specific_origin_with_credentials_sets_vary", + allowedOrigin: "https://example.com", + credentialsAllowed: true, + wantVaryOrigin: true, + }, + { + name: "wildcard_origin_no_vary", + allowedOrigin: "*", + credentialsAllowed: false, + wantVaryOrigin: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + r := NewRouter[string, string](RouterConfig{Logger: zap.NewNop()}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + req = req.WithContext(scontext.WithCORSInfo[string, string](req.Context(), tc.allowedOrigin, tc.credentialsAllowed)) + + rr := httptest.NewRecorder() + var mu sync.Mutex + mrw := &mutexResponseWriter{ResponseWriter: rr, mu: &mu} + + r.writeJSONError(mrw, req, http.StatusBadRequest, "Bad Request", "") + + if got := rr.Header().Get("Access-Control-Allow-Origin"); got != tc.allowedOrigin { + t.Fatalf("expected Access-Control-Allow-Origin %q, got %q", tc.allowedOrigin, got) + } + + if tc.credentialsAllowed { + if got := rr.Header().Get("Access-Control-Allow-Credentials"); got != "true" { + t.Fatalf("expected Access-Control-Allow-Credentials %q, got %q", "true", got) + } + } else if got := rr.Header().Get("Access-Control-Allow-Credentials"); got != "" { + t.Fatalf("expected no Access-Control-Allow-Credentials header, got %q", got) + } + + if tc.wantVaryOrigin { + if got := rr.Header().Get("Vary"); got != "Origin" { + t.Fatalf("expected Vary %q, got %q", "Origin", got) + } + } else if got := rr.Header().Get("Vary"); got != "" { + t.Fatalf("expected no Vary header, got %q", got) + } + }) + } +} + +type errResponseWriter struct { + header http.Header +} + +func (w *errResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *errResponseWriter) WriteHeader(statusCode int) {} + +func (w *errResponseWriter) Write([]byte) (int, error) { + return 0, errors.New("write failed") +} + +func TestWriteJSONError_MutexResponseWriter_LogsOnEncodeFailure(t *testing.T) { + core, logs := observer.New(zap.ErrorLevel) + logger := zap.New(core) + r := NewRouter[string, string](RouterConfig{Logger: logger, TraceIDBufferSize: 1}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + + var mu sync.Mutex + mrw := &mutexResponseWriter{ResponseWriter: &errResponseWriter{}, mu: &mu} + + r.writeJSONError(mrw, req, http.StatusInternalServerError, "Internal Server Error", "trace-123") + + entries := logs.All() + if len(entries) != 1 { + t.Fatalf("expected 1 log entry, got %d", len(entries)) + } + if entries[0].Message != "Failed to write JSON error response" { + t.Fatalf("expected log message %q, got %q", "Failed to write JSON error response", entries[0].Message) + } + + var foundStatus, foundMessage, foundTrace bool + for _, f := range entries[0].Context { + switch f.Key { + case "original_status": + foundStatus = f.Integer == int64(http.StatusInternalServerError) + case "original_message": + foundMessage = f.String == "Internal Server Error" + case "trace_id": + foundTrace = f.String == "trace-123" + } + } + + if !foundStatus { + t.Fatalf("expected original_status field to be present") + } + if !foundMessage { + t.Fatalf("expected original_message field to be present") + } + if !foundTrace { + t.Fatalf("expected trace_id field to be present") + } +} + +func TestWriteJSONError_MutexResponseWriter_NoOpWhenHeaderAlreadyWritten(t *testing.T) { + r := NewRouter[string, string](RouterConfig{Logger: zap.NewNop()}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + req := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + rr := httptest.NewRecorder() + rr.Header().Set("X-Existing", "1") + + var mu sync.Mutex + mrw := &mutexResponseWriter{ResponseWriter: rr, mu: &mu} + + // Simulate the handler having already started the response. + mrw.WriteHeader(http.StatusCreated) + + r.writeJSONError(mrw, req, http.StatusBadRequest, "Bad Request", "trace-ignored") + + require.Equal(t, http.StatusCreated, rr.Code) + require.Equal(t, "1", rr.Header().Get("X-Existing")) + require.Equal(t, "", rr.Header().Get("Content-Type")) + require.Equal(t, "", rr.Header().Get("Access-Control-Allow-Origin")) + require.Equal(t, "", rr.Header().Get("Access-Control-Allow-Credentials")) + require.Equal(t, "", rr.Header().Get("Vary")) + require.Equal(t, "", rr.Body.String()) +} diff --git a/pkg/scontext/copy_test.go b/pkg/scontext/copy_test.go index 71646d1..fc8092b 100644 --- a/pkg/scontext/copy_test.go +++ b/pkg/scontext/copy_test.go @@ -44,7 +44,7 @@ func createFullSRouterContext() context.Context { // Set all values in context ctx = WithUserID[int, testUser](ctx, userID) - ctx = WithUser[int, testUser](ctx, user) + ctx = WithUser[int](ctx, user) ctx = WithTraceID[int, testUser](ctx, traceID) ctx = WithClientIP[int, testUser](ctx, clientIP) ctx = WithUserAgent[int, testUser](ctx, userAgent)