diff --git a/README.MD b/README.MD index b388b9d..c389c6a 100644 --- a/README.MD +++ b/README.MD @@ -16,6 +16,25 @@ GoWeb is a lightweight Go web framework that mimics the structure, design, and f --- +## ๐Ÿ“š Table of Contents + +- [โœจ Features](#-features) +- [๐Ÿš€ Usage](#-usage) + 1. [Define Request/Response DTOs](#1-define-requestresponse-dtos) + 2. [Create a Controller](#2-create-a-controller) + 3. [Register Controllers in main.go](#3-register-controllers-in-maingo) +- [๐Ÿงฑ Core Concepts](#-core-concepts) +- [๐Ÿ’ก Why These Matter](#-why-these-matter) +- [๐Ÿงช Response Builder](#-response-builder) +- [๐Ÿ“Œ Example JSON Response](#-example-json-response) +- [๐ŸŒ CORS Middleware](#-cors-middleware) + - [โœ… Registering the CORS Middleware](#-registering-the-cors-middleware) + - [โš™๏ธ Behavior](#-behavior) + - [๐Ÿ›ก Example: Block all but GET](#-example-block-all-but-get) +- [โค๏ธ Inspired By](#-inspired-by) + +--- + ## ๐Ÿš€ Usage ### 1. Define Request/Response DTOs @@ -39,11 +58,11 @@ type UserResponse struct { type UsersController struct{} func (c *UsersController) BasePath() string { - return "/users" + return "/api/v1/users" } -func (c *UsersController) Routes() []core.RouteEntry { - return []core.RouteEntry{ +func (c *UsersController) Routes() []types.Route { + return []types.Route{ {Method: "GET", Path: "/", Handler: "GetAll"}, {Method: "GET", Path: "/{userid}", Handler: "Get"}, {Method: "POST", Path: "/", Handler: "Post"}, @@ -85,18 +104,37 @@ func main() { --- -### ๐Ÿงฑ Core Concepts +## ๐Ÿงฑ Core Concepts +GoWeb is built on a clean and extendable foundation inspired by Spring Boot, but optimized for Go. Below are the key architectural components of the framework: + +| Concept | Description | +| --------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| **`Controller`** | A struct implementing `BasePath()` and `Routes()` to define a route group. It may embed `ControllerBase` to enable optional controller-specific middleware. | +| **`Route`** | Defines a single HTTP endpoint via `Method`, `Path`, and a string-based `Handler` name that maps to a function in the controller. | +| **DTOs** | Plain structs representing request or response data (e.g., `UserRequest`, `UserResponse`). GoWeb automatically binds path/query/body values to arguments. | +| **`ResponseEntity`** | A fluent builder for setting status, body, and headers. Example: `ResponseEntity.Status(200).Body(data).Send(w)` or return it directly from controllers. | +| **`HttpStatus`** | Enum-style constants for all HTTP status codes, e.g., `HttpStatus.OK`, `HttpStatus.CREATED`, etc., making your response code more readable. | +| **`HttpMethod`** | Enum-like constants for HTTP methods (`GET`, `POST`, etc.) and helpers like `IsValid(method)` to validate custom usage. | +| **`exception`** | Standardized error response utilities like `BadRequestException(...)` or `InternalServerException(...)` that send JSON error responses with status codes. | +| **Middleware** | Middleware objects implement the `Middleware` interface. They're registered globally or per controller using `app.Use(...)` or `controller.Use(...)`. | +| **Middleware Builder** | Use `NewMiddlewareBuilder(...)` to create strongly-typed, reusable middleware with config (`.Config`), init logic (`.WithInit()`), and error hooks (`.OnError()`). | +| **Request Context Helpers** | Access path params, query strings, and headers using `types.PathVar(ctx, "id")`, `QueryParam(ctx, "q")`, and `Header(ctx, "X-Token")`. Injected automatically by the router. | + +### ๐Ÿ’ก Why These Matter + +- โœ… **Minimal and clean:** Core concepts like `Controller` and `Route` are simple, composable structs. + +- โœ… **Extensible:** The middleware system uses generics and fluent chaining to support per-middleware config and lifecycle hooks. -| Concept | Description | -| ----------------- | --------------------------------------------------------------- | -| `Controller` | A struct implementing `BasePath()` and `Routes()` | -| `RouteEntry` | Defines each HTTP method, subpath, and method handler name | -| DTOs | Define `UserRequest`, `UserResponse`, etc. in `Models/` package | -| `BindArguments()` | Automatically binds path vars and JSON body to method arguments | -| `ResponseEntity` | Fluent builder for response body, status, and headers | +- โœ… **Type-safe binding:** Reflect-based argument resolution injects only what your handler expects โ€” nothing more. +- โœ… **Production-ready responses:** Use `ResponseEntity` and `exception` to consistently shape output without boilerplate. -### ๐Ÿงช Response Builder Examples +- โœ… **Testable architecture:** Middleware and controllers can be unit tested with standard Go tools (`httptest`). + +--- + +## ๐Ÿงช Response Builder Example of just sending a message: ```go ResponseEntity.Status(HttpStatus.OK). @@ -123,10 +161,56 @@ Content-Type: application/json X-Custom: example { - "message": "User created", - "id": 1 + "Id": 1 + "Name": "Test" + "Email": "Test@example.com" } ``` +--- +## ๐ŸŒ CORS Middleware + +GoWeb includes a built-in CORS middleware that allows you to control which origins, methods, and headers are allowed to access your server across different domains. This is especially useful when building frontend-backend systems or public APIs. + +### โœ… Registering the CORS Middleware + +To enable it globally: +```go +app.Use(middlewares.CORS) +``` + +Then configure it as needed: +```go +middlewares.CORS.Config.AllowedOrigins = []string{"https://example.com"} +middlewares.CORS.Config.AllowedMethods = []string{"GET", "POST"} +middlewares.CORS.Config.AllowedHeaders = []string{"Content-Type", "Authorization"} +middlewares.CORS.Config.AllowCredentials = true +``` +### โš™๏ธ Behavior +| Feature | Description | +| ------------------------- | ----------------------------------------------------------------------- | +| `AllowedOrigins` | List of allowed domains (use `"*"` for all) | +| `AllowedMethods` | List of allowed HTTP methods (`GET`, `POST`, etc.) | +| `AllowedHeaders` | List of allowed request headers | +| `AllowCredentials` | Enables `Access-Control-Allow-Credentials: true` | +| Auto-Handles `OPTIONS` | Returns `204 No Content` and skips route logic | +| Blocks Disallowed Methods | Returns `405 Method Not Allowed` if the request method is not permitted | + +### ๐Ÿ›ก Example: Block all but GET + +```go +app.Use(middlewares.CORS) +middlewares.CORS.Config.AllowedMethods = []string{"GET"} +``` + +If a client sends a `POST` request, the server will respond with: + +```http request +HTTP/1.1 405 Method Not Allowed +Access-Control-Allow-Methods: GET +Access-Control-Allow-Origin: https://example.com +``` + +--- ### โค๏ธ Inspired By diff --git a/app/internal/router.go b/app/internal/router.go index 6a6834e..0fff3d9 100644 --- a/app/internal/router.go +++ b/app/internal/router.go @@ -56,15 +56,18 @@ func ListenImpl(routes []CompiledRoute, addr string) error { } func Dispatch(routes []CompiledRoute, w http.ResponseWriter, req *http.Request) { + normalizedPath := normalizePath(req.URL.Path) + for _, route := range routes { - if req.Method != route.Method { - continue - } - normalizedPath := normalizePath(req.URL.Path) matches := route.Regex.FindStringSubmatch(normalizedPath) if matches == nil { continue } + + if req.Method != route.Method && req.Method != http.MethodOptions { + continue + } + pathVars := extractPathVars(route.ParamNames, matches[1:]) paramTypes := getParamTypes(route.Handler.Type()) argNames := buildArgNames(paramTypes, route.ParamNames) @@ -73,38 +76,42 @@ func Dispatch(routes []CompiledRoute, w http.ResponseWriter, req *http.Request) return } - // Controller-level pre/post middleware - var PreMiddlewares []types.MiddlewareFunc - if ctrl, ok := route.CtrlValue.Interface().(interface{ PreMiddleware() []types.MiddlewareFunc }); ok { - PreMiddlewares = ctrl.PreMiddleware() + // --- Controller-level middleware + var ctrlPre []types.Middleware + if ctrl, ok := route.CtrlValue.Interface().(interface{ PreMiddleware() []types.Middleware }); ok { + ctrlPre = ctrl.PreMiddleware() } - var PostMiddlewares []types.MiddlewareFunc - if ctrl, ok := route.CtrlValue.Interface().(interface{ PostMiddleware() []types.MiddlewareFunc }); ok { - PostMiddlewares = ctrl.PostMiddleware() + var ctrlPost []types.Middleware + if ctrl, ok := route.CtrlValue.Interface().(interface{ PostMiddleware() []types.Middleware }); ok { + ctrlPost = ctrl.PostMiddleware() } - // Build the middleware chain - chain := make([]types.MiddlewareFunc, 0, len(types.PreMiddlewares)+len(PreMiddlewares)+1+len(PostMiddlewares)+len(types.PostMiddlewares)) - chain = append(chain, types.PreMiddlewares...) - chain = append(chain, PreMiddlewares...) - chain = append(chain, func(ctx *types.MiddlewareContext) error { - result := route.Handler.Call(args) - if len(result) != 1 { - exception.InternalServerException("Expected 1 return value").Send(w) - return nil - } - resp, ok := result[0].Interface().(*types.ResponseEntity) - if ok { - ctx.ResponseEntity = resp - } - return ctx.Next() - }) + // --- Build the chain + chain := make([]types.MiddlewareFunc, 0, + len(types.PreMiddlewares)+len(ctrlPre)+1+len(ctrlPost)+len(types.PostMiddlewares), + ) + + chain = append(chain, types.ConvertMiddewaresToFuncs(types.PreMiddlewares)...) + chain = append(chain, types.ConvertMiddewaresToFuncs(ctrlPre)...) + + if req.Method != http.MethodOptions { + chain = append(chain, func(ctx *types.MiddlewareContext) error { + result := route.Handler.Call(args) + if len(result) != 1 { + exception.InternalServerException("Expected 1 return value").Send(w) + return nil + } + if resp, ok := result[0].Interface().(*types.ResponseEntity); ok { + ctx.ResponseEntity = resp + } + return ctx.Next() + }) + } - chain = append(chain, PostMiddlewares...) - chain = append(chain, types.PostMiddlewares...) + chain = append(chain, types.ConvertMiddewaresToFuncs(ctrlPost)...) + chain = append(chain, types.ConvertMiddewaresToFuncs(types.PostMiddlewares)...) - // Create the middleware context mwCtx := &types.MiddlewareContext{ Request: req, ResponseWriter: w, @@ -115,14 +122,25 @@ func Dispatch(routes []CompiledRoute, w http.ResponseWriter, req *http.Request) _ = mwCtx.Next() - // Serve the response if set if mwCtx.ResponseEntity != nil { mwCtx.ResponseEntity.Send(w) } return } - // Not found + if req.Method == http.MethodOptions { + mwCtx := &types.MiddlewareContext{ + Request: req, + ResponseWriter: w, + ResponseEntity: nil, + Index: -1, + Chain: types.ConvertMiddewaresToFuncs(types.PreMiddlewares), + } + _ = mwCtx.Next() + return + } + + // Fallback exception.NotFoundException("Route not found").Send(w) } diff --git a/app/middleware.go b/app/middleware.go index 53685aa..08996f2 100644 --- a/app/middleware.go +++ b/app/middleware.go @@ -2,15 +2,21 @@ package app import "github.com/isaacwallace123/GoWeb/app/types" -// Register pre-middleware -func Use(mw ...types.MiddlewareFunc) { +// Register pre-middleware (as Middleware interface, not just funcs) +func Use(mw ...types.Middleware) { types.PreMiddlewares = append(types.PreMiddlewares, mw...) } // Register post-middleware -func UseAfter(mw ...types.MiddlewareFunc) { +func UseAfter(mw ...types.Middleware) { types.PostMiddlewares = append(types.PostMiddlewares, mw...) } -func Pre() []types.MiddlewareFunc { return types.PreMiddlewares } -func Post() []types.MiddlewareFunc { return types.PostMiddlewares } +// Optional accessors +func Pre() []types.MiddlewareFunc { + return types.ConvertMiddewaresToFuncs(types.PreMiddlewares) +} + +func Post() []types.MiddlewareFunc { + return types.ConvertMiddewaresToFuncs(types.PostMiddlewares) +} diff --git a/app/middleware_test.go b/app/middleware_test.go index bb56e54..b8594f9 100644 --- a/app/middleware_test.go +++ b/app/middleware_test.go @@ -5,41 +5,60 @@ import ( "testing" ) -// Clears global middleware slices before each test for isolation. func clearMiddleware() { types.PreMiddlewares = nil types.PostMiddlewares = nil } -func dummyPre(ctx *types.MiddlewareContext) error { return nil } -func dummyPost(ctx *types.MiddlewareContext) error { return nil } +// Dummy middleware implementation +type dummyMiddleware struct { + name string +} + +func (d *dummyMiddleware) Func() types.MiddlewareFunc { + return func(ctx *types.MiddlewareContext) error { + return nil + } +} func TestUseRegistersPreMiddleware(t *testing.T) { clearMiddleware() - Use(dummyPre) + d := &dummyMiddleware{name: "pre"} + Use(d) + if len(types.PreMiddlewares) != 1 { t.Errorf("expected 1 pre-middleware, got %d", len(types.PreMiddlewares)) } - if Pre()[0] == nil { + + fn := Pre()[0] + if fn == nil { t.Error("Pre() did not return a valid middleware func") } } func TestUseAfterRegistersPostMiddleware(t *testing.T) { clearMiddleware() - UseAfter(dummyPost) + d := &dummyMiddleware{name: "post"} + UseAfter(d) + if len(types.PostMiddlewares) != 1 { t.Errorf("expected 1 post-middleware, got %d", len(types.PostMiddlewares)) } - if Post()[0] == nil { + + fn := Post()[0] + if fn == nil { t.Error("Post() did not return a valid middleware func") } } func TestMultipleMiddlewares(t *testing.T) { clearMiddleware() - Use(dummyPre, dummyPre) - UseAfter(dummyPost, dummyPost) + d1 := &dummyMiddleware{name: "a"} + d2 := &dummyMiddleware{name: "b"} + + Use(d1, d2) + UseAfter(d1, d2) + if len(types.PreMiddlewares) != 2 { t.Errorf("expected 2 pre-middlewares, got %d", len(types.PreMiddlewares)) } diff --git a/app/types/controller.go b/app/types/controller.go index 846cc57..327a926 100644 --- a/app/types/controller.go +++ b/app/types/controller.go @@ -8,36 +8,36 @@ type Controller interface { type ControllerBase struct { basePath string routes []Route - preMiddleware []MiddlewareFunc - postMiddleware []MiddlewareFunc + preMiddleware []Middleware + postMiddleware []Middleware } -// WithBasePath will set the URI of the controller (Like "/api/v1/users") +// WithBasePath sets the base path for the controller (e.g. "/api/users") func (c *ControllerBase) WithBasePath(path string) *ControllerBase { c.basePath = path return c } -// WithRoutes adds the routes that will be handled by the controller created +// WithRoutes adds the route list for this controller func (c *ControllerBase) WithRoutes(routes []Route) *ControllerBase { c.routes = routes return c } -// Use adds pre-middleware (runs before handler) -func (c *ControllerBase) Use(mw ...MiddlewareFunc) *ControllerBase { +// Use adds pre-handler middleware +func (c *ControllerBase) Use(mw ...Middleware) *ControllerBase { c.preMiddleware = append(c.preMiddleware, mw...) return c } -// UseAfter adds post-middleware (runs after handler, before global post-middleware) -func (c *ControllerBase) UseAfter(mw ...MiddlewareFunc) *ControllerBase { +// UseAfter adds post-handler middleware +func (c *ControllerBase) UseAfter(mw ...Middleware) *ControllerBase { c.postMiddleware = append(c.postMiddleware, mw...) return c } -// BasePath, Routes, PreMiddleware, & PostMiddleware These are the pre-determined methods that users are essentially FORCED to use because GoWeb uses these methods in the dispatch -func (c *ControllerBase) BasePath() string { return c.basePath } -func (c *ControllerBase) Routes() []Route { return c.routes } -func (c *ControllerBase) PreMiddleware() []MiddlewareFunc { return c.preMiddleware } -func (c *ControllerBase) PostMiddleware() []MiddlewareFunc { return c.postMiddleware } +// Required interface implementations +func (c *ControllerBase) BasePath() string { return c.basePath } +func (c *ControllerBase) Routes() []Route { return c.routes } +func (c *ControllerBase) PreMiddleware() []Middleware { return c.preMiddleware } +func (c *ControllerBase) PostMiddleware() []Middleware { return c.postMiddleware } diff --git a/app/types/middleware.go b/app/types/middleware.go index 167fa13..e05f275 100644 --- a/app/types/middleware.go +++ b/app/types/middleware.go @@ -1,6 +1,7 @@ package types import ( + "github.com/isaacwallace123/GoUtils/logger" "net/http" ) @@ -14,9 +15,14 @@ type MiddlewareContext struct { } // MiddlewareFunc represents a single middleware function. -// It receives a MiddlewareContext and returns an error if any. type MiddlewareFunc func(ctx *MiddlewareContext) error +// Middleware represents a middleware object with a Func method. +// This allows middleware structs to be registered and called. +type Middleware interface { + Func() MiddlewareFunc +} + // Next advances the middleware chain to the next function. // It returns any error produced by the next middleware. func (ctx *MiddlewareContext) Next() error { @@ -27,8 +33,69 @@ func (ctx *MiddlewareContext) Next() error { return nil // End of middleware chain } -// PreMiddlewares holds globally registered middleware that runs before the handler. -var PreMiddlewares []MiddlewareFunc +// PreMiddlewares holds globally registered middleware objects. +var PreMiddlewares []Middleware + +// PostMiddlewares holds globally registered middleware objects. +var PostMiddlewares []Middleware + +// --- Middleware Builder Pattern --- \\ + +// MiddlewareBuilder is a reusable, typed middleware object with attached config and logic. +type MiddlewareBuilder[T any] struct { + Config *T + Handler MiddlewareFunc + OnErrorHandler func(ctx *MiddlewareContext, err error) +} -// PostMiddlewares holds globally registered middleware that runs after the handler. -var PostMiddlewares []MiddlewareFunc +// Func allows the builder to be treated as a Middleware interface. +func (middleware *MiddlewareBuilder[T]) Func() MiddlewareFunc { + return func(ctx *MiddlewareContext) error { + err := middleware.Handler(ctx) + + if err != nil && middleware.OnErrorHandler != nil { + middleware.OnErrorHandler(ctx, err) + + return nil + } + + return err + } +} + +// WithInit is a script that the middleware's creator can tune to make it run once after the middleware is hooked +func (middleware *MiddlewareBuilder[T]) WithInit(initFn func(*T)) *MiddlewareBuilder[T] { + initFn(middleware.Config) + return middleware +} + +func (middleware *MiddlewareBuilder[T]) OnError(handler func(ctx *MiddlewareContext, err error)) *MiddlewareBuilder[T] { + middleware.OnErrorHandler = handler + return middleware +} + +// NewMiddlewareBuilder constructs a typed middleware with a config and handler function. +func NewMiddlewareBuilder[T any]( + name string, + defaultConfig *T, + handler func(ctx *MiddlewareContext, config *T) error, +) *MiddlewareBuilder[T] { + builder := &MiddlewareBuilder[T]{ + Config: defaultConfig, + Handler: func(ctx *MiddlewareContext) error { + return handler(ctx, defaultConfig) + }, + OnErrorHandler: func(ctx *MiddlewareContext, err error) { + logger.Error("Middleware '%s' error: %v\n", name, err) + }, + } + return builder +} + +func ConvertMiddewaresToFuncs(mw []Middleware) []MiddlewareFunc { + funcs := make([]MiddlewareFunc, len(mw)) + for i, m := range mw { + funcs[i] = m.Func() + } + return funcs +} diff --git a/pkg/HttpMethod/HttpMethod.go b/pkg/HttpMethod/HttpMethod.go new file mode 100644 index 0000000..6ce83fc --- /dev/null +++ b/pkg/HttpMethod/HttpMethod.go @@ -0,0 +1,26 @@ +package HttpMethod + +const ( + GET string = "GET" + POST string = "POST" + PUT string = "PUT" + DELETE string = "DELETE" + PATCH string = "PATCH" + HEAD string = "HEAD" + OPTIONS string = "OPTIONS" + CONNECT string = "CONNECT" + TRACE string = "TRACE" +) + +var allMethods = []string{ + GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS, CONNECT, TRACE, +} + +func IsValid(method string) bool { + for _, m := range allMethods { + if string(m) == method { + return true + } + } + return false +} diff --git a/pkg/middlewares/CORS.go b/pkg/middlewares/CORS.go new file mode 100644 index 0000000..84396e3 --- /dev/null +++ b/pkg/middlewares/CORS.go @@ -0,0 +1,91 @@ +package middlewares + +import ( + "fmt" + "github.com/isaacwallace123/GoUtils/color" + "github.com/isaacwallace123/GoUtils/logger" + + "net/http" + "strings" + + "github.com/isaacwallace123/GoWeb/app/types" +) + +type CORSConfig struct { + AllowedOrigins []string + AllowedMethods []string + AllowedHeaders []string + AllowCredentials bool +} + +var CORS_TAG = fmt.Sprintf("%sCORS%s", color.BrightCyan, color.Reset) + +var CORS = types.NewMiddlewareBuilder("cors", &CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + AllowedHeaders: []string{"Content-Type", "Authorization"}, + AllowCredentials: true, +}, func(ctx *types.MiddlewareContext, config *CORSConfig) error { + origin := ctx.Request.Header.Get("Origin") + + // Check and set all headers if needed + if origin != "" { + if isOriginAllowed(origin, config.AllowedOrigins) { + headers := ctx.ResponseWriter.Header() + headers.Set("Access-Control-Allow-Origin", origin) + headers.Set("Vary", "Origin") + headers.Set("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", ")) + headers.Set("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", ")) + + if config.AllowCredentials { + headers.Set("Access-Control-Allow-Credentials", "true") + } + + logger.Debug("%s Origin allowed: %s", CORS_TAG, origin) + } else { + logger.Warn("%s Disallowed origin: %s", CORS_TAG, origin) + } + } else { + logger.Debug("%s No Origin header present", CORS_TAG) + } + + // Always intercept OPTIONS + if ctx.Request.Method == http.MethodOptions { + logger.Debug("%s Preflight OPTIONS request intercepted", CORS_TAG) + + ctx.ResponseWriter.WriteHeader(http.StatusNoContent) + + return nil + } + + // Enforce method restriction on all other requests + if !isMethodAllowed(ctx.Request.Method, config.AllowedMethods) { + logger.Warn("%s Method not allowed: %s", CORS_TAG, color.HTTPMethodToColor[ctx.Request.Method]+ctx.Request.Method+color.Reset) + + ctx.ResponseWriter.WriteHeader(http.StatusMethodNotAllowed) + + return nil + } + + return ctx.Next() +}).WithInit(func(config *CORSConfig) { + logger.Info("%s Middleware successfully initialized", CORS_TAG) +}) + +func isOriginAllowed(origin string, allowed []string) bool { + for _, o := range allowed { + if o == "*" || o == origin { + return true + } + } + return false +} + +func isMethodAllowed(method string, allowed []string) bool { + for _, m := range allowed { + if strings.EqualFold(m, method) { + return true + } + } + return false +} diff --git a/pkg/middlewares/CORS_test.go b/pkg/middlewares/CORS_test.go new file mode 100644 index 0000000..7723bb9 --- /dev/null +++ b/pkg/middlewares/CORS_test.go @@ -0,0 +1,81 @@ +package middlewares + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/isaacwallace123/GoWeb/app/types" +) + +func setupRequest(method, origin string) *http.Request { + req, _ := http.NewRequest(method, "/test", nil) + if origin != "" { + req.Header.Set("Origin", origin) + } + return req +} + +func runMiddleware(req *http.Request) *httptest.ResponseRecorder { + rr := httptest.NewRecorder() + + ctx := &types.MiddlewareContext{ + Request: req, + ResponseWriter: rr, + ResponseEntity: nil, + Index: -1, + Chain: []types.MiddlewareFunc{ + CORS.Func(), // only CORS middleware + }, + } + + _ = ctx.Next() + return rr +} + +func TestCORS_AllowsOrigin(t *testing.T) { + CORS.Config.AllowedOrigins = []string{"http://example.com"} + CORS.Config.AllowedMethods = []string{"GET"} + + req := setupRequest("GET", "http://example.com") + rr := runMiddleware(req) + + if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" { + t.Errorf("expected Access-Control-Allow-Origin to be set, got %q", got) + } +} + +func TestCORS_RejectsDisallowedMethod(t *testing.T) { + CORS.Config.AllowedOrigins = []string{"http://example.com"} + CORS.Config.AllowedMethods = []string{"GET"} + + req := setupRequest("POST", "http://example.com") + rr := runMiddleware(req) + + if rr.Code != http.StatusMethodNotAllowed { + t.Errorf("expected status 405, got %d", rr.Code) + } +} + +func TestCORS_PreflightOPTIONS(t *testing.T) { + CORS.Config.AllowedOrigins = []string{"http://example.com"} + CORS.Config.AllowedMethods = []string{"GET", "POST"} + + req := setupRequest("OPTIONS", "http://example.com") + rr := runMiddleware(req) + + if rr.Code != http.StatusNoContent { + t.Errorf("expected status 204 for OPTIONS, got %d", rr.Code) + } +} + +func TestCORS_MissingOrigin(t *testing.T) { + CORS.Config.AllowedOrigins = []string{"http://example.com"} + + req := setupRequest("GET", "") // no Origin header + rr := runMiddleware(req) + + if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "" { + t.Errorf("expected no CORS header when Origin is missing, got %q", got) + } +} diff --git a/pkg/middlewares/logging.go b/pkg/middlewares/logging.go index db6581f..6a623d4 100644 --- a/pkg/middlewares/logging.go +++ b/pkg/middlewares/logging.go @@ -6,20 +6,28 @@ import ( "github.com/isaacwallace123/GoWeb/app/types" ) -// Pre-middleware: logs before the handler runs (no response yet) -func LoggingPre(ctx *types.MiddlewareContext) error { - methodColored := color.HTTPMethodToColor[ctx.Request.Method] + ctx.Request.Method + color.Reset - - logger.Info("%s %s", methodColored, ctx.Request.URL.Path) - - return ctx.Next() +type LoggingConfig struct { + Enabled bool } -// Post-middleware: logs after the handler, when ResponseEntity exists -func LoggingPost(ctx *types.MiddlewareContext) error { - methodColored := color.HTTPMethodToColor[ctx.Request.Method] + ctx.Request.Method + color.Reset - - logger.Info("%s %s %d", methodColored, ctx.Request.URL.Path, ctx.ResponseEntity.StatusCode) +// Logs before the handler runs (no response yet) +var LoggingPre = types.NewMiddlewareBuilder("logging_pre", &LoggingConfig{ + Enabled: true, +}, func(ctx *types.MiddlewareContext, cfg *LoggingConfig) error { + if cfg.Enabled { + methodColored := color.HTTPMethodToColor[ctx.Request.Method] + ctx.Request.Method + color.Reset + logger.Info("%s %s", methodColored, ctx.Request.URL.Path) + } + return ctx.Next() +}) +// Logs after the handler, when ResponseEntity exists +var LoggingPost = types.NewMiddlewareBuilder("logging_post", &LoggingConfig{ + Enabled: true, +}, func(ctx *types.MiddlewareContext, cfg *LoggingConfig) error { + if cfg.Enabled && ctx.ResponseEntity != nil { + methodColored := color.HTTPMethodToColor[ctx.Request.Method] + ctx.Request.Method + color.Reset + logger.Info("%s %s %d", methodColored, ctx.Request.URL.Path, ctx.ResponseEntity.StatusCode) + } return ctx.Next() -} +}) diff --git a/pkg/middlewares/logging_test.go b/pkg/middlewares/logging_test.go new file mode 100644 index 0000000..246f1bb --- /dev/null +++ b/pkg/middlewares/logging_test.go @@ -0,0 +1,113 @@ +package middlewares + +import ( + "bytes" + "github.com/isaacwallace123/GoWeb/app/types" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" +) + +func captureStdout(fn func()) string { + // Save original stdout + old := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + // Run the function while writing to pipe + fn() + + // Close and restore stdout + w.Close() + os.Stdout = old + + // Read captured output + var buf bytes.Buffer + _, _ = io.Copy(&buf, r) + return buf.String() +} + +func dummyHandler(ctx *types.MiddlewareContext) error { + ctx.ResponseEntity = &types.ResponseEntity{ + StatusCode: http.StatusOK, + BodyData: []byte(`OK`), + } + return ctx.Next() +} + +func TestLoggingPre_Enabled(t *testing.T) { + LoggingPre.Config.Enabled = true + + req := httptest.NewRequest("GET", "/test/pre", nil) + res := httptest.NewRecorder() + + logs := captureStdout(func() { + ctx := &types.MiddlewareContext{ + Request: req, + ResponseWriter: res, + Chain: []types.MiddlewareFunc{ + LoggingPre.Func(), + dummyHandler, + }, + Index: -1, + } + _ = ctx.Next() + }) + + if !strings.Contains(logs, "GET") && !strings.Contains(logs, "/test/pre") { + t.Errorf("expected log to contain 'GET /test/pre', got: %s", logs) + } +} + +func TestLoggingPost_Enabled(t *testing.T) { + LoggingPost.Config.Enabled = true + + req := httptest.NewRequest("POST", "/test/post", nil) + res := httptest.NewRecorder() + + logs := captureStdout(func() { + ctx := &types.MiddlewareContext{ + Request: req, + ResponseWriter: res, + Chain: []types.MiddlewareFunc{ + dummyHandler, + LoggingPost.Func(), + }, + Index: -1, + } + _ = ctx.Next() + }) + + if !strings.Contains(logs, "POST") && !strings.Contains(logs, "/test/post") { + t.Errorf("expected log to contain 'POST /test/post 200', got: %s", logs) + } +} + +func TestLogging_Disabled(t *testing.T) { + LoggingPre.Config.Enabled = false + LoggingPost.Config.Enabled = false + + req := httptest.NewRequest("GET", "/no/logs", nil) + res := httptest.NewRecorder() + + logs := captureStdout(func() { + ctx := &types.MiddlewareContext{ + Request: req, + ResponseWriter: res, + Chain: []types.MiddlewareFunc{ + LoggingPre.Func(), + dummyHandler, + LoggingPost.Func(), + }, + Index: -1, + } + _ = ctx.Next() + }) + + if logs != "" { + t.Errorf("expected no logs, got: %s", logs) + } +}