From d211f2d6d7e6dd8cd54907d60f78929ae5502c74 Mon Sep 17 00:00:00 2001 From: suhaib Date: Fri, 19 Dec 2025 21:22:50 -0800 Subject: [PATCH 1/6] feat: implement authentication token source configuration and middleware support --- pkg/common/config.go | 32 +++++++ pkg/router/auth_middleware_test.go | 42 +++++++++ pkg/router/route.go | 8 +- pkg/router/router.go | 147 +++++++++++++++++++++-------- 4 files changed, 189 insertions(+), 40 deletions(-) diff --git a/pkg/common/config.go b/pkg/common/config.go index ae3f89e..ffb149f 100644 --- a/pkg/common/config.go +++ b/pkg/common/config.go @@ -2,6 +2,29 @@ package common import "time" +// AuthTokenSource defines where to extract authentication tokens from. +type AuthTokenSource int + +const ( + // AuthTokenSourceHeader reads the token from a request header. + AuthTokenSourceHeader AuthTokenSource = iota + // AuthTokenSourceCookie reads the token from a request cookie. + AuthTokenSourceCookie +) + +// AuthTokenConfig defines how to extract authentication tokens from requests. +type AuthTokenConfig struct { + // Source determines where to look for the token. + Source AuthTokenSource + + // HeaderName is used when Source is AuthTokenSourceHeader. + // If empty, defaults to "Authorization". + HeaderName string + + // CookieName is used when Source is AuthTokenSourceCookie. + CookieName string +} + // RouteOverrides contains settings that can be overridden at different levels (global, sub-router, route). // These overrides follow a hierarchy where the most specific setting takes precedence. type RouteOverrides struct { @@ -16,6 +39,10 @@ type RouteOverrides struct { // RateLimit overrides the rate limiting configuration. // A nil value means no override is set. RateLimit *RateLimitConfig[any, any] + + // AuthToken overrides the authentication token source. + // A nil value means no override is set. + AuthToken *AuthTokenConfig } // HasTimeout returns true if a timeout override is set (non-zero). @@ -32,3 +59,8 @@ func (ro *RouteOverrides) HasMaxBodySize() bool { func (ro *RouteOverrides) HasRateLimit() bool { return ro.RateLimit != nil } + +// HasAuthToken returns true if an auth token override is set (non-nil). +func (ro *RouteOverrides) HasAuthToken() bool { + return ro.AuthToken != nil +} diff --git a/pkg/router/auth_middleware_test.go b/pkg/router/auth_middleware_test.go index 8a7eade..17b65d3 100644 --- a/pkg/router/auth_middleware_test.go +++ b/pkg/router/auth_middleware_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "testing" + "github.com/Suhaibinator/SRouter/pkg/common" // Keep middleware alias if needed for other types "github.com/Suhaibinator/SRouter/pkg/router/internal/mocks" // Use centralized mocks "github.com/Suhaibinator/SRouter/pkg/scontext" // Added scontext import @@ -313,6 +314,47 @@ func TestAuthRequiredMiddlewareWithUserObject(t *testing.T) { } } +func TestAuthRequiredMiddlewareWithCookieSource(t *testing.T) { + logger := zap.NewNop() + r := NewRouter(RouterConfig{Logger: logger}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userID, ok := scontext.GetUserIDFromRequest[string, string](r) + if !ok { + t.Error("Expected user ID to be in context") + } + if userID != "user123" { + t.Errorf("Expected user ID %q, got %q", "user123", userID) + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK")) + }) + + authTokenConfig := common.AuthTokenConfig{ + Source: common.AuthTokenSourceCookie, + CookieName: "auth_token", + } + wrappedHandler := r.authRequiredMiddlewareWithConfig(authTokenConfig)(handler) + + // Test with valid auth cookie + req, _ := http.NewRequest("GET", "/test", nil) + req.AddCookie(&http.Cookie{Name: "auth_token", Value: "valid-token"}) + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code) + } + + // Test with valid Authorization header but no cookie (should not fallback) + req, _ = http.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer valid-token") + rr = httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + if rr.Code != http.StatusUnauthorized { + t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, rr.Code) + } +} + // TestAuthRequiredMiddlewareWithTraceID tests the authRequiredMiddleware function with trace ID // (from auth_required_middleware_test.go) func TestAuthRequiredMiddlewareWithTraceID(t *testing.T) { diff --git a/pkg/router/route.go b/pkg/router/route.go index cf826b1..e9b4891 100644 --- a/pkg/router/route.go +++ b/pkg/router/route.go @@ -35,9 +35,10 @@ func (r *Router[T, U]) RegisterRoute(route RouteConfigBase) { // Pass the specific route config (which is *common.RateLimitConfig[any, any]) // to getEffectiveRateLimit. The conversion happens inside getEffectiveRateLimit. rateLimit := r.getEffectiveRateLimit(route.Overrides.RateLimit, nil) + authTokenConfig := r.getEffectiveAuthTokenConfig(route.Overrides.AuthToken, nil) // Create a handler with all middlewares applied - handler := r.wrapHandler(route.Handler, route.AuthLevel, timeout, maxBodySize, rateLimit, route.Middlewares) + handler := r.wrapHandler(route.Handler, route.AuthLevel, authTokenConfig, timeout, maxBodySize, rateLimit, route.Middlewares) // Register the route with httprouter for _, method := range route.Methods { @@ -271,7 +272,8 @@ func RegisterGenericRoute[Req any, Resp any, UserID comparable, User any]( }) // Create a handler with all middlewares applied, using the effective settings passed in - wrappedHandler := r.wrapHandler(handler, route.AuthLevel, effectiveTimeout, effectiveMaxBodySize, effectiveRateLimit, route.Middlewares) + authTokenConfig := r.getEffectiveAuthTokenConfig(route.Overrides.AuthToken, nil) + wrappedHandler := r.wrapHandler(handler, route.AuthLevel, authTokenConfig, effectiveTimeout, effectiveMaxBodySize, effectiveRateLimit, route.Middlewares) // Register the route with httprouter for _, method := range route.Methods { @@ -321,6 +323,8 @@ func NewGenericRouteDefinition[Req any, Resp any, UserID comparable, User any]( // Pass the specific route config (which is *common.RateLimitConfig[any, any]) // to getEffectiveRateLimit. The conversion happens inside getEffectiveRateLimit. effectiveRateLimit := r.getEffectiveRateLimit(route.Overrides.RateLimit, sr.Overrides.RateLimit) + effectiveAuthTokenConfig := r.getEffectiveAuthTokenConfig(route.Overrides.AuthToken, sr.Overrides.AuthToken) + finalRouteConfig.Overrides.AuthToken = &effectiveAuthTokenConfig // Call the underlying generic registration function with the modified config and effective settings RegisterGenericRoute(r, finalRouteConfig, effectiveTimeout, effectiveMaxBodySize, effectiveRateLimit) diff --git a/pkg/router/router.go b/pkg/router/router.go index 3b605d8..b2e298a 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -49,6 +49,8 @@ type Router[T comparable, U any] struct { corsMaxAge string } +const defaultAuthHeaderName = "Authorization" + // RegisterSubRouterWithSubRouter registers a nested SubRouter with a parent SubRouter. // This helper function enables hierarchical route organization by adding a child // SubRouter to the parent's SubRouters slice. @@ -220,6 +222,7 @@ func (r *Router[T, U]) registerSubRouter(sr SubRouterConfig) { maxBodySize := r.getEffectiveMaxBodySize(route.Overrides.MaxBodySize, sr.Overrides.MaxBodySize) rateLimit := r.getEffectiveRateLimit(route.Overrides.RateLimit, sr.Overrides.RateLimit) + authTokenConfig := r.getEffectiveAuthTokenConfig(route.Overrides.AuthToken, sr.Overrides.AuthToken) authLevel := route.AuthLevel // Use route-specific first if authLevel == nil { authLevel = sr.AuthLevel // Fallback to sub-router default @@ -231,7 +234,7 @@ func (r *Router[T, U]) registerSubRouter(sr SubRouterConfig) { allMiddlewares = append(allMiddlewares, route.Middlewares...) // Create a handler with all middlewares applied (global middlewares are added inside wrapHandler) - handler := r.wrapHandler(route.Handler, authLevel, timeout, maxBodySize, rateLimit, allMiddlewares) + handler := r.wrapHandler(route.Handler, authLevel, authTokenConfig, timeout, maxBodySize, rateLimit, allMiddlewares) // Register the route with httprouter for _, method := range route.Methods { @@ -292,7 +295,7 @@ func (r *Router[T, U]) convertToHTTPRouterHandle(handler http.Handler, routeTemp // 7. Shutdown check and body size limit (in the base handler) // // Middlewares are combined additively, not replaced. -func (r *Router[T, U]) wrapHandler(handler http.HandlerFunc, authLevel *AuthLevel, timeout time.Duration, maxBodySize int64, rateLimit *common.RateLimitConfig[T, U], middlewares []common.Middleware) http.Handler { // Use common.RateLimitConfig +func (r *Router[T, U]) wrapHandler(handler http.HandlerFunc, authLevel *AuthLevel, authTokenConfig common.AuthTokenConfig, timeout time.Duration, maxBodySize int64, rateLimit *common.RateLimitConfig[T, U], middlewares []common.Middleware) http.Handler { // Use common.RateLimitConfig // Create a base handler that only handles shutdown check and body size limit directly // Timeout is now handled by timeoutMiddleware setting the context. h := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { @@ -334,9 +337,9 @@ func (r *Router[T, U]) wrapHandler(handler http.HandlerFunc, authLevel *AuthLeve if authLevel != nil { switch *authLevel { case AuthRequired: - chain = chain.Append(r.authRequiredMiddleware) + chain = chain.Append(r.authRequiredMiddlewareWithConfig(authTokenConfig)) case AuthOptional: - chain = chain.Append(r.authOptionalMiddleware) + chain = chain.Append(r.authOptionalMiddlewareWithConfig(authTokenConfig)) } } @@ -549,6 +552,8 @@ func RegisterGenericRouteOnSubRouter[Req any, Resp any, UserID comparable, User effectiveTimeout := r.getEffectiveTimeout(route.Overrides.Timeout, subRouterTimeout) effectiveMaxBodySize := r.getEffectiveMaxBodySize(route.Overrides.MaxBodySize, subRouterMaxBodySize) effectiveRateLimit := r.getEffectiveRateLimit(route.Overrides.RateLimit, subRouterRateLimit) // This returns *common.RateLimitConfig[UserID, User] + effectiveAuthTokenConfig := r.getEffectiveAuthTokenConfig(route.Overrides.AuthToken, sr.Overrides.AuthToken) + finalRouteConfig.Overrides.AuthToken = &effectiveAuthTokenConfig // Call the underlying generic registration function with the modified config RegisterGenericRoute(r, finalRouteConfig, effectiveTimeout, effectiveMaxBodySize, effectiveRateLimit) @@ -934,6 +939,35 @@ func (r *Router[T, U]) getEffectiveTimeout(routeTimeout, subRouterTimeout time.D return r.config.GlobalTimeout } +func defaultAuthTokenConfig() common.AuthTokenConfig { + return common.AuthTokenConfig{ + Source: common.AuthTokenSourceHeader, + HeaderName: defaultAuthHeaderName, + } +} + +func normalizeAuthTokenConfig(config common.AuthTokenConfig) common.AuthTokenConfig { + if config.Source == common.AuthTokenSourceHeader && config.HeaderName == "" { + config.HeaderName = defaultAuthHeaderName + } + return config +} + +// getEffectiveAuthTokenConfig returns the effective auth token config for a route. +// Precedence order (first non-nil value wins): +// 1. Route-specific auth token config +// 2. Sub-router auth token override (NOT inherited by nested sub-routers) +// 3. Default header-based auth token config +func (r *Router[T, U]) getEffectiveAuthTokenConfig(routeAuth, subRouterAuth *common.AuthTokenConfig) common.AuthTokenConfig { + if routeAuth != nil { + return normalizeAuthTokenConfig(*routeAuth) + } + if subRouterAuth != nil { + return normalizeAuthTokenConfig(*subRouterAuth) + } + return defaultAuthTokenConfig() +} + // getEffectiveMaxBodySize returns the effective max body size for a route. // Precedence order (first non-zero value wins): // 1. Route-specific max body size @@ -1204,13 +1238,36 @@ func (r *Router[T, U]) recoveryMiddleware(next http.Handler) http.Handler { // authenticateRequest attempts to authenticate the request and, if successful, // returns a new request with user information stored in the context. // It does not perform any logging; callers handle logging based on the result. -func (r *Router[T, U]) authenticateRequest(req *http.Request) (*http.Request, bool, string) { - authHeader := req.Header.Get("Authorization") - if authHeader == "" { - return req, false, "no authorization header" +func (r *Router[T, U]) authenticateRequest(req *http.Request, authTokenConfig common.AuthTokenConfig) (*http.Request, bool, string) { + var token string + + switch authTokenConfig.Source { + case common.AuthTokenSourceHeader: + headerName := authTokenConfig.HeaderName + if headerName == "" { + headerName = defaultAuthHeaderName + } + authHeader := req.Header.Get(headerName) + if authHeader == "" { + if headerName == defaultAuthHeaderName { + return req, false, "no authorization header" + } + return req, false, "no auth header" + } + token = strings.TrimPrefix(authHeader, "Bearer ") + case common.AuthTokenSourceCookie: + if authTokenConfig.CookieName == "" { + return req, false, "auth cookie name not configured" + } + cookie, err := req.Cookie(authTokenConfig.CookieName) + if err != nil { + return req, false, "no auth cookie" + } + token = cookie.Value + default: + return req, false, "unsupported auth token source" } - token := strings.TrimPrefix(authHeader, "Bearer ") if user, valid := r.authFunction(req.Context(), token); valid { id := r.getUserIdFromUser(user) ctx := scontext.WithUserID[T, U](req.Context(), id) @@ -1230,26 +1287,33 @@ func (r *Router[T, U]) authenticateRequest(req *http.Request) (*http.Request, bo // If authentication fails, it returns a 401 Unauthorized response. // It uses the middleware.AuthenticationWithUser function with a configurable authentication function. func (r *Router[T, U]) authRequiredMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - var ok bool - var reason string - req, ok, reason = r.authenticateRequest(req) - if !ok { - fields := append(r.baseFields(req), - zap.String("remote_addr", req.RemoteAddr), - zap.String("error", reason), - ) - fields = r.addTrace(fields, req) - r.logger.Warn("Authentication failed", fields...) - traceID := scontext.GetTraceIDFromRequest[T, U](req) - r.writeJSONError(w, req, http.StatusUnauthorized, "Unauthorized", traceID) - return - } + return r.authRequiredMiddlewareWithConfig(defaultAuthTokenConfig())(next) +} - fields := r.addTrace(r.baseFields(req), req) - r.logger.Debug("Authentication successful", fields...) - next.ServeHTTP(w, req) - }) +func (r *Router[T, U]) authRequiredMiddlewareWithConfig(authTokenConfig common.AuthTokenConfig) common.Middleware { + authTokenConfig = normalizeAuthTokenConfig(authTokenConfig) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var ok bool + var reason string + req, ok, reason = r.authenticateRequest(req, authTokenConfig) + if !ok { + fields := append(r.baseFields(req), + zap.String("remote_addr", req.RemoteAddr), + zap.String("error", reason), + ) + fields = r.addTrace(fields, req) + r.logger.Warn("Authentication failed", fields...) + traceID := scontext.GetTraceIDFromRequest[T, U](req) + r.writeJSONError(w, req, http.StatusUnauthorized, "Unauthorized", traceID) + return + } + + fields := r.addTrace(r.baseFields(req), req) + r.logger.Debug("Authentication successful", fields...) + next.ServeHTTP(w, req) + }) + } } // authOptionalMiddleware is a middleware that attempts authentication for a request, @@ -1257,17 +1321,24 @@ func (r *Router[T, U]) authRequiredMiddleware(next http.Handler) http.Handler { // It tries to authenticate the request and adds the user ID to the context if successful, // but allows the request to proceed even if authentication fails. func (r *Router[T, U]) authOptionalMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - var ok bool - req, ok, _ = r.authenticateRequest(req) - if ok { - fields := r.addTrace(r.baseFields(req), req) - r.logger.Debug("Authentication successful", fields...) - } + return r.authOptionalMiddlewareWithConfig(defaultAuthTokenConfig())(next) +} - // Call the next handler regardless of authentication result - next.ServeHTTP(w, req) - }) +func (r *Router[T, U]) authOptionalMiddlewareWithConfig(authTokenConfig common.AuthTokenConfig) common.Middleware { + authTokenConfig = normalizeAuthTokenConfig(authTokenConfig) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + var ok bool + req, ok, _ = r.authenticateRequest(req, authTokenConfig) + if ok { + fields := r.addTrace(r.baseFields(req), req) + r.logger.Debug("Authentication successful", fields...) + } + + // Call the next handler regardless of authentication result + next.ServeHTTP(w, req) + }) + } } // responseWriter is a wrapper around http.ResponseWriter that captures the status code. From 4fb5145991855aef217ec06b070ee57783b8b541 Mon Sep 17 00:00:00 2001 From: suhaib Date: Fri, 19 Dec 2025 21:24:37 -0800 Subject: [PATCH 2/6] feat: implement auth token extraction logic for header and cookie sources --- pkg/router/router.go | 80 +++++++++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 30 deletions(-) diff --git a/pkg/router/router.go b/pkg/router/router.go index b2e298a..e615e8f 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -51,6 +51,8 @@ type Router[T comparable, U any] struct { const defaultAuthHeaderName = "Authorization" +type authTokenExtractor func(*http.Request) (string, bool, string) + // RegisterSubRouterWithSubRouter registers a nested SubRouter with a parent SubRouter. // This helper function enables hierarchical route organization by adding a child // SubRouter to the parent's SubRouters slice. @@ -953,6 +955,46 @@ func normalizeAuthTokenConfig(config common.AuthTokenConfig) common.AuthTokenCon return config } +func buildAuthTokenExtractor(config common.AuthTokenConfig) authTokenExtractor { + switch config.Source { + case common.AuthTokenSourceHeader: + headerName := config.HeaderName + if headerName == "" { + headerName = defaultAuthHeaderName + } + missingReason := "no auth header" + if headerName == defaultAuthHeaderName { + missingReason = "no authorization header" + } + return func(req *http.Request) (string, bool, string) { + authHeader := req.Header.Get(headerName) + if authHeader == "" { + return "", false, missingReason + } + token := strings.TrimPrefix(authHeader, "Bearer ") + return token, true, "" + } + case common.AuthTokenSourceCookie: + cookieName := config.CookieName + if cookieName == "" { + return func(*http.Request) (string, bool, string) { + return "", false, "auth cookie name not configured" + } + } + return func(req *http.Request) (string, bool, string) { + cookie, err := req.Cookie(cookieName) + if err != nil { + return "", false, "no auth cookie" + } + return cookie.Value, true, "" + } + default: + return func(*http.Request) (string, bool, string) { + return "", false, "unsupported auth token source" + } + } +} + // getEffectiveAuthTokenConfig returns the effective auth token config for a route. // Precedence order (first non-nil value wins): // 1. Route-specific auth token config @@ -1238,34 +1280,10 @@ func (r *Router[T, U]) recoveryMiddleware(next http.Handler) http.Handler { // authenticateRequest attempts to authenticate the request and, if successful, // returns a new request with user information stored in the context. // It does not perform any logging; callers handle logging based on the result. -func (r *Router[T, U]) authenticateRequest(req *http.Request, authTokenConfig common.AuthTokenConfig) (*http.Request, bool, string) { - var token string - - switch authTokenConfig.Source { - case common.AuthTokenSourceHeader: - headerName := authTokenConfig.HeaderName - if headerName == "" { - headerName = defaultAuthHeaderName - } - authHeader := req.Header.Get(headerName) - if authHeader == "" { - if headerName == defaultAuthHeaderName { - return req, false, "no authorization header" - } - return req, false, "no auth header" - } - token = strings.TrimPrefix(authHeader, "Bearer ") - case common.AuthTokenSourceCookie: - if authTokenConfig.CookieName == "" { - return req, false, "auth cookie name not configured" - } - cookie, err := req.Cookie(authTokenConfig.CookieName) - if err != nil { - return req, false, "no auth cookie" - } - token = cookie.Value - default: - return req, false, "unsupported auth token source" +func (r *Router[T, U]) authenticateRequest(req *http.Request, extractToken authTokenExtractor) (*http.Request, bool, string) { + token, ok, reason := extractToken(req) + if !ok { + return req, false, reason } if user, valid := r.authFunction(req.Context(), token); valid { @@ -1292,11 +1310,12 @@ func (r *Router[T, U]) authRequiredMiddleware(next http.Handler) http.Handler { func (r *Router[T, U]) authRequiredMiddlewareWithConfig(authTokenConfig common.AuthTokenConfig) common.Middleware { authTokenConfig = normalizeAuthTokenConfig(authTokenConfig) + extractToken := buildAuthTokenExtractor(authTokenConfig) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { var ok bool var reason string - req, ok, reason = r.authenticateRequest(req, authTokenConfig) + req, ok, reason = r.authenticateRequest(req, extractToken) if !ok { fields := append(r.baseFields(req), zap.String("remote_addr", req.RemoteAddr), @@ -1326,10 +1345,11 @@ func (r *Router[T, U]) authOptionalMiddleware(next http.Handler) http.Handler { func (r *Router[T, U]) authOptionalMiddlewareWithConfig(authTokenConfig common.AuthTokenConfig) common.Middleware { authTokenConfig = normalizeAuthTokenConfig(authTokenConfig) + extractToken := buildAuthTokenExtractor(authTokenConfig) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { var ok bool - req, ok, _ = r.authenticateRequest(req, authTokenConfig) + req, ok, _ = r.authenticateRequest(req, extractToken) if ok { fields := r.addTrace(r.baseFields(req), req) r.logger.Debug("Authentication successful", fields...) From 20dedab635a8de150a4b49702c0c0972f030be3b Mon Sep 17 00:00:00 2001 From: suhaib Date: Fri, 19 Dec 2025 21:24:46 -0800 Subject: [PATCH 3/6] feat: add warning for unconfigured auth token cookie name --- pkg/router/router.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pkg/router/router.go b/pkg/router/router.go index e615e8f..ceeed43 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -955,6 +955,12 @@ func normalizeAuthTokenConfig(config common.AuthTokenConfig) common.AuthTokenCon return config } +func (r *Router[T, U]) warnOnInvalidAuthTokenConfig(config common.AuthTokenConfig) { + if config.Source == common.AuthTokenSourceCookie && config.CookieName == "" { + r.logger.Warn("Auth token cookie name not configured") + } +} + func buildAuthTokenExtractor(config common.AuthTokenConfig) authTokenExtractor { switch config.Source { case common.AuthTokenSourceHeader: @@ -1310,6 +1316,7 @@ func (r *Router[T, U]) authRequiredMiddleware(next http.Handler) http.Handler { func (r *Router[T, U]) authRequiredMiddlewareWithConfig(authTokenConfig common.AuthTokenConfig) common.Middleware { authTokenConfig = normalizeAuthTokenConfig(authTokenConfig) + r.warnOnInvalidAuthTokenConfig(authTokenConfig) extractToken := buildAuthTokenExtractor(authTokenConfig) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { @@ -1345,6 +1352,7 @@ func (r *Router[T, U]) authOptionalMiddleware(next http.Handler) http.Handler { func (r *Router[T, U]) authOptionalMiddlewareWithConfig(authTokenConfig common.AuthTokenConfig) common.Middleware { authTokenConfig = normalizeAuthTokenConfig(authTokenConfig) + r.warnOnInvalidAuthTokenConfig(authTokenConfig) extractToken := buildAuthTokenExtractor(authTokenConfig) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { From 4fed79a747387f98ceb88cac869d77548d558a23 Mon Sep 17 00:00:00 2001 From: suhaib Date: Fri, 19 Dec 2025 21:28:36 -0800 Subject: [PATCH 4/6] test: add unit tests for auth token configuration and extraction logic --- pkg/router/auth_token_config_test.go | 91 ++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 pkg/router/auth_token_config_test.go diff --git a/pkg/router/auth_token_config_test.go b/pkg/router/auth_token_config_test.go new file mode 100644 index 0000000..b8946a6 --- /dev/null +++ b/pkg/router/auth_token_config_test.go @@ -0,0 +1,91 @@ +package router + +import ( + "net/http/httptest" + "testing" + + "github.com/Suhaibinator/SRouter/pkg/common" + "github.com/Suhaibinator/SRouter/pkg/router/internal/mocks" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" +) + +func TestNormalizeAuthTokenConfigDefaultsHeaderName(t *testing.T) { + config := common.AuthTokenConfig{Source: common.AuthTokenSourceHeader} + normalized := normalizeAuthTokenConfig(config) + if normalized.HeaderName != defaultAuthHeaderName { + t.Fatalf("expected header name %q, got %q", defaultAuthHeaderName, normalized.HeaderName) + } +} + +func TestBuildAuthTokenExtractorDefaultsHeaderName(t *testing.T) { + extractor := buildAuthTokenExtractor(common.AuthTokenConfig{Source: common.AuthTokenSourceHeader}) + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set(defaultAuthHeaderName, "Bearer valid-token") + + token, ok, reason := extractor(req) + if !ok { + t.Fatalf("expected token extraction to succeed, got reason %q", reason) + } + if token != "valid-token" { + t.Fatalf("expected token %q, got %q", "valid-token", token) + } +} + +func TestBuildAuthTokenExtractorCookieMissingName(t *testing.T) { + extractor := buildAuthTokenExtractor(common.AuthTokenConfig{Source: common.AuthTokenSourceCookie}) + req := httptest.NewRequest("GET", "/test", nil) + + _, ok, reason := extractor(req) + if ok { + t.Fatal("expected token extraction to fail") + } + if reason != "auth cookie name not configured" { + t.Fatalf("expected reason %q, got %q", "auth cookie name not configured", reason) + } +} + +func TestBuildAuthTokenExtractorUnsupportedSource(t *testing.T) { + extractor := buildAuthTokenExtractor(common.AuthTokenConfig{Source: common.AuthTokenSource(99)}) + req := httptest.NewRequest("GET", "/test", nil) + + _, ok, reason := extractor(req) + if ok { + t.Fatal("expected token extraction to fail") + } + if reason != "unsupported auth token source" { + t.Fatalf("expected reason %q, got %q", "unsupported auth token source", reason) + } +} + +func TestWarnOnInvalidAuthTokenConfigLogs(t *testing.T) { + core, logs := observer.New(zap.WarnLevel) + logger := zap.New(core) + r := NewRouter(RouterConfig{Logger: logger}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + r.authRequiredMiddlewareWithConfig(common.AuthTokenConfig{ + Source: common.AuthTokenSourceCookie, + }) + + logEntries := logs.All() + if len(logEntries) != 1 { + t.Fatalf("expected 1 warning log, got %d", len(logEntries)) + } + if logEntries[0].Message != "Auth token cookie name not configured" { + t.Fatalf("expected warning message %q, got %q", "Auth token cookie name not configured", logEntries[0].Message) + } +} + +func TestGetEffectiveAuthTokenConfigUsesSubRouter(t *testing.T) { + logger := zap.NewNop() + r := NewRouter(RouterConfig{Logger: logger}, mocks.MockAuthFunction, mocks.MockUserIDFromUser) + + subRouterAuth := common.AuthTokenConfig{ + Source: common.AuthTokenSourceHeader, + } + + config := r.getEffectiveAuthTokenConfig(nil, &subRouterAuth) + if config.HeaderName != defaultAuthHeaderName { + t.Fatalf("expected header name %q, got %q", defaultAuthHeaderName, config.HeaderName) + } +} From b688ad986459c6c917777e538003aeb9f1b6adc3 Mon Sep 17 00:00:00 2001 From: suhaib Date: Fri, 19 Dec 2025 21:32:11 -0800 Subject: [PATCH 5/6] test: add unit tests for RouteOverrides rate limit and auth token checks --- pkg/common/config_test.go | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 pkg/common/config_test.go diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go new file mode 100644 index 0000000..641c707 --- /dev/null +++ b/pkg/common/config_test.go @@ -0,0 +1,27 @@ +package common + +import "testing" + +func TestRouteOverridesHasRateLimit(t *testing.T) { + overrides := RouteOverrides{} + if overrides.HasRateLimit() { + t.Fatal("expected HasRateLimit to be false when rate limit is nil") + } + + overrides.RateLimit = &RateLimitConfig[any, any]{} + if !overrides.HasRateLimit() { + t.Fatal("expected HasRateLimit to be true when rate limit is set") + } +} + +func TestRouteOverridesHasAuthToken(t *testing.T) { + overrides := RouteOverrides{} + if overrides.HasAuthToken() { + t.Fatal("expected HasAuthToken to be false when auth token is nil") + } + + overrides.AuthToken = &AuthTokenConfig{} + if !overrides.HasAuthToken() { + t.Fatal("expected HasAuthToken to be true when auth token is set") + } +} From 456f5d1e701288008bc9ef3db610ccb061f7b656 Mon Sep 17 00:00:00 2001 From: suhaib Date: Fri, 19 Dec 2025 21:41:24 -0800 Subject: [PATCH 6/6] docs: update authentication and configuration documentation to include auth token source details --- README.md | 2 ++ docs/authentication.md | 26 +++++++++++++++++---- docs/configuration.md | 51 +++++++++++++++++++++++++++++++++++++---- docs/getting-started.md | 4 ++-- docs/routing.md | 9 ++++---- 5 files changed, 77 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 40abf8c..8d42a1c 100644 --- a/README.md +++ b/README.md @@ -696,6 +696,8 @@ SRouter supports three authentication levels, specified in `RouteConfig` or `Rou 2. **AuthOptional**: Authentication is attempted (e.g., by middleware). If successful, user info is added to the context. The request proceeds regardless. 3. **AuthRequired**: Authentication is required (e.g., by middleware). If authentication fails, the middleware should reject the request (e.g., with 401 Unauthorized). If successful, user info is added to the context. +When using the built-in `AuthOptional`/`AuthRequired` middleware, the token is extracted from the configured auth token source (`common.RouteOverrides.AuthToken`). The default source is the `Authorization` header. Cookie-based auth is supported by setting `AuthToken` to a cookie source on a sub-router or route. + ```go // Example route configurations routePublic := router.RouteConfigBase{ AuthLevel: router.Ptr(router.NoAuth), ... } diff --git a/docs/authentication.md b/docs/authentication.md index 1bdea0b..a560b14 100644 --- a/docs/authentication.md +++ b/docs/authentication.md @@ -9,9 +9,9 @@ SRouter defines three authentication levels using the `router.AuthLevel` type. Y Setting `AuthLevel` to `AuthOptional` or `AuthRequired` activates **built-in middleware** within the router. This middleware performs the following based on the level: 1. **`router.NoAuth`**: No authentication is required or attempted by the built-in middleware. The request proceeds directly to the next middleware or handler. This is the default if `AuthLevel` is not set. -2. **`router.AuthOptional`**: The built-in authentication middleware is activated. It attempts to validate credentials (currently expects a Bearer token in the `Authorization` header) using the `authFunction` provided to `NewRouter`. +2. **`router.AuthOptional`**: The built-in authentication middleware is activated. It attempts to validate credentials using the `authFunction` provided to `NewRouter`. The token is extracted from the configured auth token source (see "Auth Token Source" below); the default is the `Authorization` header. * If authentication succeeds, the middleware populates the user ID (using the `userIdFromUserFunction` from `NewRouter`) and optionally the user object into the request context using `scontext.WithUserID` and `scontext.WithUser`. Storing the user object requires `RouterConfig.AddUserObjectToCtx` to be `true`. The request then proceeds to the next middleware or handler. - * If authentication fails (or no `Authorization` header is provided), the request *still proceeds* to the next middleware or handler, but without user information in the context. The handler must check for the presence of user information using `scontext.GetUserIDFromRequest` or `scontext.GetUserFromRequest`. + * If authentication fails (or no token is provided from the configured source), the request *still proceeds* to the next middleware or handler, but without user information in the context. The handler must check for the presence of user information using `scontext.GetUserIDFromRequest` or `scontext.GetUserFromRequest`. 3. **`router.AuthRequired`**: The built-in authentication middleware is activated and authentication is mandatory. It attempts validation as described for `AuthOptional`. * If authentication succeeds, the middleware populates the context (as above) and proceeds to the next middleware or handler. * If authentication fails, the built-in middleware **rejects** the request by sending an HTTP `401 Unauthorized` response and stops the middleware chain. The handler is not called. @@ -51,7 +51,7 @@ The core of the built-in authentication mechanism relies on two functions you ** 1. **`authFunction func(ctx context.Context, token string) (*UserObjectType, bool)`**: * This function is called by the built-in middleware when `AuthLevel` is `AuthOptional` or `AuthRequired`. - * It receives the request context and the token string extracted from the `Authorization: Bearer ` header. + * It receives the request context and the token string extracted from the configured auth token source (header or cookie). * It should validate the token (e.g., check a database, validate a JWT signature). * It must return the corresponding `UserObjectType` (your application's user struct/type) and `true` if the token is valid, or a zero-value `UserObjectType` and `false` if invalid. @@ -83,9 +83,27 @@ r := router.NewRouter[string, MyUserType](routerConfig, myAuthValidator, myGetID **If you do not intend to use the built-in `AuthLevel` mechanism** (e.g., you rely solely on custom authentication middleware), you must still provide non-nil functions to `NewRouter`. These can be simple dummy functions that always return `false` or zero values. +## Auth Token Source + +By default, the built-in middleware reads the token from the `Authorization` header and trims a `Bearer ` prefix if present. You can override the source per sub-router or per route via `common.RouteOverrides.AuthToken`: + +```go +Overrides: common.RouteOverrides{ + AuthToken: &common.AuthTokenConfig{ + Source: common.AuthTokenSourceCookie, + CookieName: "auth_token", + }, +}, +``` + +Notes: +- Only the configured source is honored (no fallback to other sources). +- If `Source` is `AuthTokenSourceHeader` and `HeaderName` is empty, it defaults to `Authorization`. +- If `Source` is `AuthTokenSourceCookie` and `CookieName` is empty, the built-in middleware logs a warning at registration time. + ## Custom Authentication Middleware -While the `AuthLevel` setting provides convenient Bearer token authentication via the built-in mechanism, you can implement **custom authentication middleware** for other schemes (Cookies, API Keys, Basic Auth, etc.) or more complex logic. +While the `AuthLevel` setting provides convenient token authentication via the built-in mechanism, you can implement **custom authentication middleware** for other schemes (Cookies, API Keys, Basic Auth, etc.) or more complex logic. Your custom middleware is responsible for: diff --git a/docs/configuration.md b/docs/configuration.md index 03fff3e..4aa8b39 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -130,7 +130,8 @@ type SubRouterConfig struct { // within this group (e.g., "/api/v1"). PathPrefix string - // Overrides allows this sub-router to specify timeout, body size, or rate limit + // Overrides allows this sub-router to specify timeout, body size, rate limit, + // or auth token source settings that override the global configuration. // settings that override the global configuration. Zero values mean no override. Overrides common.RouteOverrides @@ -174,8 +175,9 @@ type RouteConfigBase struct { // Nil inherits from parent sub-router or defaults to NoAuth. AuthLevel *AuthLevel - // Overrides allows this route to specify timeout, body size, or rate limit - // settings. Zero values mean inherit from the sub-router or global configuration. + // Overrides allows this route to specify timeout, body size, rate limit, + // or auth token source settings. Zero values mean inherit from the sub-router + // or global configuration. Overrides common.RouteOverrides // Handler is the standard Go HTTP handler function. Required. @@ -187,6 +189,44 @@ type RouteConfigBase struct { } ``` +## `common.RouteOverrides` and Auth Token Source + +Route overrides control per-route and per-sub-router settings, including the auth token source used by the built-in authentication middleware. + +```go +package common + +import "time" + +type AuthTokenSource int + +const ( + // AuthTokenSourceHeader reads the token from a request header. + AuthTokenSourceHeader AuthTokenSource = iota + // AuthTokenSourceCookie reads the token from a request cookie. + AuthTokenSourceCookie +) + +type AuthTokenConfig struct { + // Source determines where to look for the token. + Source AuthTokenSource + + // HeaderName is used when Source is AuthTokenSourceHeader. + // If empty, defaults to "Authorization". + HeaderName string + + // CookieName is used when Source is AuthTokenSourceCookie. + CookieName string +} + +type RouteOverrides struct { + Timeout time.Duration + MaxBodySize int64 + RateLimit *RateLimitConfig[any, any] + AuthToken *AuthTokenConfig +} +``` + ## `RouteConfig[T, U]` Used for defining generic routes with type-safe request (`T`) and response (`U`) handling. @@ -214,7 +254,8 @@ type RouteConfig[T any, U any] struct { // Nil inherits. AuthLevel *AuthLevel - // Overrides allows this route to specify timeout, body size, or rate limit + // Overrides allows this route to specify timeout, body size, rate limit, + // or auth token source settings. // settings. Zero values mean inherit from the sub-router or global configuration. Overrides common.RouteOverrides @@ -308,4 +349,4 @@ type CORSConfig struct { AllowCredentials bool // Whether to allow credentials (cookies, authorization headers). MaxAge time.Duration // How long the results of a preflight request can be cached. } -``` \ No newline at end of file +``` diff --git a/docs/getting-started.md b/docs/getting-started.md index 0932b69..8b2439a 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -113,7 +113,7 @@ func main() { ### Key Components - **`RouterConfig`**: Holds global settings like logger, timeouts, body size limits, and global middleware. -- **`authFunction`**: A function `func(ctx context.Context, token string) (UserObjectType, bool)` that validates an authentication token (currently expects Bearer token) and returns the user object and a boolean indicating success. Used by the built-in middleware when `AuthLevel` is set. +- **`authFunction`**: A function `func(ctx context.Context, token string) (UserObjectType, bool)` that validates an authentication token and returns the user object and a boolean indicating success. The token is extracted from the configured auth token source (default is the `Authorization` header). Used by the built-in middleware when `AuthLevel` is set. - **`userIdFromUserFunction`**: A function `func(user UserObjectType) UserIDType` that extracts the comparable User ID from the user object returned by `authFunction`. Used by the built-in middleware. - **`NewRouter[UserIDType, UserObjectType]`**: The constructor for the router. The type parameters define the type used for user IDs (`UserIDType`, must be comparable) and the type used for the user object (`UserObjectType`, can be any type) potentially stored in the context. - **`RouterConfig.SubRouters`**: A slice of `SubRouterConfig` where routes are defined. Each `SubRouterConfig` has a `PathPrefix` and a `Routes` slice. @@ -124,4 +124,4 @@ func main() { - Learn about [Authentication](./authentication.md) to secure your routes - Explore the [Configuration Reference](./configuration.md) for all available options - Check out [Routing](./routing.md) for advanced routing features -- See [Examples](./examples.md) for more complex use cases \ No newline at end of file +- See [Examples](./examples.md) for more complex use cases diff --git a/docs/routing.md b/docs/routing.md index ca5c623..d84fd13 100644 --- a/docs/routing.md +++ b/docs/routing.md @@ -4,7 +4,7 @@ This guide covers the routing system in SRouter, including sub-routers for organ ## Sub-Routers -Sub-routers allow you to group related routes under a common path prefix and apply shared configuration like middleware, timeouts, body size limits, and rate limits. +Sub-routers allow you to group related routes under a common path prefix and apply shared configuration like middleware, timeouts, body size limits, rate limits, and auth token source settings. ### Defining Sub-Routers @@ -18,6 +18,7 @@ apiV1SubRouter := router.SubRouterConfig{ Timeout: 3 * time.Second, // Overrides GlobalTimeout MaxBodySize: 2 << 20, // 2 MB, overrides GlobalMaxBodySize // RateLimit: &common.RateLimitConfig[any, any]{...}, + // AuthToken: &common.AuthTokenConfig{Source: common.AuthTokenSourceCookie, CookieName: "auth_token"}, }, // Middlewares specific to /api/v1 routes can be added here // Middlewares: []common.Middleware{ myV1Middleware }, @@ -80,7 +81,7 @@ r := router.NewRouter[string, string](routerConfig, authFunction, userIdFromUser Key points: - `PathPrefix`: Defines the base path for all routes within the sub-router. -- `Overrides`: `common.RouteOverrides` allowing timeout, body size, or rate limit overrides specific to this sub-router. +- `Overrides`: `common.RouteOverrides` allowing timeout, body size, rate limit, or auth token source overrides specific to this sub-router. - `Routes`: A slice of `router.RouteDefinition` that can contain `RouteConfigBase` or `GenericRouteDefinition`. Paths within these routes are relative to the `PathPrefix`. - `Middlewares`: Middleware applied to routes within this sub-router. These are **added to** (not replacing) any global middlewares defined in RouterConfig. - `AuthLevel`: Default authentication level for all routes in this sub-router (can be overridden at the route level). @@ -140,7 +141,7 @@ r := router.NewRouter[string, string](routerConfig, authFunction, userIdFromUser ``` **Configuration precedence:** -- **Overrides** (timeouts, body size, rate limits, auth level): The most specific setting wins (Route > Sub-Router > Global). Each level must explicitly set overrides; they are not inherited. +- **Overrides** (timeouts, body size, rate limits, auth token source): The most specific setting wins (Route > Sub-Router > Global). Each level must explicitly set overrides; they are not inherited. - **Middlewares**: These are combined additively in order: Global → Outer Sub-Router → Inner Sub-Router → Route-specific. All applicable middlewares run in this sequence. ### Imperative Route Registration @@ -360,4 +361,4 @@ Generally, `router.GetParam` is more convenient when you know the specific param - **`router.GetParam(r *http.Request, name string) string`**: Retrieves a specific parameter by name from the request context. Returns an empty string if the parameter is not found. - **`router.GetParams(r *http.Request) httprouter.Params`**: Retrieves all parameters from the request context as an `httprouter.Params` slice. - **`scontext.GetPathParamsFromRequest[T, U](r *http.Request) (httprouter.Params, bool)`**: Returns all parameters from the generic `scontext` wrapper along with a boolean indicating presence. -- **`scontext.GetRouteTemplateFromRequest[T, U](r *http.Request) (string, bool)`**: Retrieves the original route pattern from the context for metrics or logging. \ No newline at end of file +- **`scontext.GetRouteTemplateFromRequest[T, U](r *http.Request) (string, bool)`**: Retrieves the original route pattern from the context for metrics or logging.