diff --git a/astTraversal/callExpression.go b/astTraversal/callExpression.go index fb7dc06..7b5d1b8 100644 --- a/astTraversal/callExpression.go +++ b/astTraversal/callExpression.go @@ -1,6 +1,7 @@ package astTraversal import ( + "errors" "go/ast" "go/types" ) @@ -14,7 +15,7 @@ type CallExpressionTraverser struct { func (t *BaseTraverser) CallExpression(node ast.Node) (*CallExpressionTraverser, error) { callExpr, ok := node.(*ast.CallExpr) if !ok { - return nil, ErrInvalidNodeType + return nil, errors.Join(ErrInvalidNodeType, errors.New("expected *ast.CallExpr")) } return &CallExpressionTraverser{ @@ -59,7 +60,7 @@ func (c *CallExpressionTraverser) Args() []ast.Expr { func (c *CallExpressionTraverser) Type() (*types.Func, error) { if c.Node.Fun == nil { - return nil, ErrInvalidNodeType + return nil, errors.Join(ErrInvalidNodeType, errors.New("expected *ast.CallExpr.Fun to not be nil")) } var obj types.Object @@ -70,7 +71,7 @@ func (c *CallExpressionTraverser) Type() (*types.Func, error) { case *ast.SelectorExpr: obj, err = c.File.Package.FindObjectForIdent(nodeFun.Sel) default: - err = ErrInvalidNodeType + err = errors.Join(ErrInvalidNodeType, errors.New("expected *ast.CallExpr.Fun to be *ast.Ident or *ast.SelectorExpr")) } if err != nil { return nil, err @@ -83,7 +84,7 @@ func (c *CallExpressionTraverser) Type() (*types.Func, error) { return nil, ErrBuiltInFunction } - return nil, ErrInvalidNodeType + return nil, errors.Join(ErrInvalidNodeType, errors.New("expected *types.Func")) } func (c *CallExpressionTraverser) ReturnType(returnNum int) (types.Type, error) { @@ -94,11 +95,11 @@ func (c *CallExpressionTraverser) ReturnType(returnNum int) (types.Type, error) signature, ok := funcType.Type().(*types.Signature) if !ok { - return nil, ErrInvalidNodeType + return nil, errors.Join(ErrInvalidNodeType, errors.New("expected *types.Signature")) } if signature.Results().Len() <= returnNum { - return nil, ErrInvalidIndex + return nil, errors.Join(ErrInvalidIndex, errors.New("expected returnNum to be less than signature.Results().Len()")) } return signature.Results().At(returnNum).Type(), nil @@ -112,11 +113,11 @@ func (c *CallExpressionTraverser) ArgType(argNum int) (types.Object, error) { signature, ok := funcType.Type().(*types.Signature) if !ok { - return nil, ErrInvalidNodeType + return nil, errors.Join(ErrInvalidNodeType, errors.New("expected *types.Signature")) } if signature.Params().Len() <= argNum { - return nil, ErrInvalidIndex + return nil, errors.Join(ErrInvalidIndex, errors.New("expected argNum to be less than signature.Params().Len()")) } return signature.Params().At(argNum), nil diff --git a/astTraversal/declaration.go b/astTraversal/declaration.go index 17aefa3..cced283 100644 --- a/astTraversal/declaration.go +++ b/astTraversal/declaration.go @@ -1,6 +1,7 @@ package astTraversal import ( + "errors" "go/ast" ) @@ -43,7 +44,7 @@ func (d *DeclarationTraverser) Value() (ast.Node, error) { return n.Rhs[index], nil default: - return nil, ErrInvalidNodeType + return nil, errors.Join(ErrInvalidNodeType, errors.New("expected *ast.ValueSpec or *ast.AssignStmt")) } } diff --git a/astTraversal/extractStatusCode.go b/astTraversal/extractStatusCode.go index 3e0d7ac..18011a7 100644 --- a/astTraversal/extractStatusCode.go +++ b/astTraversal/extractStatusCode.go @@ -1,6 +1,7 @@ package astTraversal import ( + "errors" "go/ast" "strconv" ) @@ -9,7 +10,7 @@ import ( func (t *BaseTraverser) ExtractStatusCode(status ast.Node) (int, error) { exprNode, ok := status.(ast.Expr) if !ok { - return 0, ErrInvalidNodeType + return 0, errors.Join(ErrInvalidNodeType, errors.New("expected ast.Expr")) } expr := t.Expression(exprNode) diff --git a/astTraversal/function.go b/astTraversal/function.go index 8bee5cc..35d05df 100644 --- a/astTraversal/function.go +++ b/astTraversal/function.go @@ -1,6 +1,7 @@ package astTraversal import ( + "errors" "go/ast" "go/types" ) @@ -25,7 +26,7 @@ func (t *BaseTraverser) Function(node ast.Node) (*FunctionTraverser, error) { } funcDecl = n default: - return nil, ErrInvalidNodeType + return nil, errors.Join(ErrInvalidNodeType, errors.New("expected *ast.FuncLit or *ast.FuncDecl")) } return &FunctionTraverser{ diff --git a/astTraversal/literal.go b/astTraversal/literal.go index c234bc5..1c85e34 100644 --- a/astTraversal/literal.go +++ b/astTraversal/literal.go @@ -1,6 +1,7 @@ package astTraversal import ( + "errors" "go/ast" "go/types" ) @@ -22,7 +23,7 @@ func (t *BaseTraverser) Literal(node ast.Node, returnNum int) (*LiteralTraverser func (lt *LiteralTraverser) Type() (types.Type, error) { exprNode, ok := lt.Node.(ast.Expr) if !ok { - return nil, ErrInvalidNodeType + return nil, errors.Join(ErrInvalidNodeType, errors.New("expected ast.Expr")) } return lt.Traverser.Expression(exprNode).SetReturnNum(lt.ReturnNum).Type() diff --git a/astTraversal/type.go b/astTraversal/type.go index aca6adf..50bbe2b 100644 --- a/astTraversal/type.go +++ b/astTraversal/type.go @@ -1,6 +1,7 @@ package astTraversal import ( + "errors" "go/ast" "go/token" "go/types" @@ -265,7 +266,7 @@ func (t *TypeTraverser) Result() (Result, error) { if result.Type != "" { return result, nil } else { - return Result{}, ErrInvalidNodeType + return Result{}, errors.Join(ErrInvalidNodeType, errors.New("expected *types.Basic, *types.Named, *types.Pointer, *types.Slice, *types.Array, *types.Map, *types.Struct, or *types.Interface")) } } diff --git a/inputs/gin/createRoute.go b/inputs/gin/createRoute.go index 5d553f4..c9904b8 100644 --- a/inputs/gin/createRoute.go +++ b/inputs/gin/createRoute.go @@ -3,16 +3,22 @@ package gin import ( "os" "path/filepath" + "runtime" "github.com/ls6-events/astra" "github.com/gin-gonic/gin" ) +var deniedHandlers = []string{ + "github.com/gin-gonic/gin.LoggerWithConfig.func1", + "github.com/gin-gonic/gin.CustomRecoveryWithWriter.func1", +} + // createRoute creates a route from a gin RouteInfo. // It will only create the route and refer to the handler function by name, file and line number. -// The route will be populated later by parseRoute. -func createRoute(s *astra.Service, file string, line int, info gin.RouteInfo) error { +// The route will be populated later by parseHandler. +func createRoute(s *astra.Service, handlers []uintptr, info gin.RouteInfo) error { log := s.Log.With().Str("path", info.Path).Str("method", info.Method).Str("handler", info.Handler).Logger() cwd, err := os.Getwd() @@ -21,16 +27,43 @@ func createRoute(s *astra.Service, file string, line int, info gin.RouteInfo) er return err } - relativePath, err := filepath.Rel(cwd, file) - if err != nil { - log.Error().Err(err).Msg("Failed to get relative path") - return err + log.Debug().Msg("Found route handlers") + resultHandlers := make([]astra.Handler, 0) + for _, handler := range handlers { + funcForPC := runtime.FuncForPC(handler) + name := funcForPC.Name() + file, line := funcForPC.FileLine(handler) + + var found bool + for _, deniedHandler := range deniedHandlers { + if deniedHandler == name { + found = true + break + } + } + + if found { + log.Debug().Str("name", name).Str("file", file).Int("line", line).Msg("Denied handler") + continue + } + + log.Debug().Str("name", name).Str("file", file).Int("line", line).Msg("Found handler") + + relativePath, err := filepath.Rel(cwd, file) + if err != nil { + log.Error().Err(err).Msg("Failed to get relative path") + return err + } + + resultHandlers = append(resultHandlers, astra.Handler{ + Name: name, + File: relativePath, + LineNo: line, + }) } baseRoute := astra.Route{ - Handler: info.Handler, - File: relativePath, - LineNo: line, + Handlers: resultHandlers, Path: info.Path, Method: info.Method, PathParams: make([]astra.Param, 0), diff --git a/inputs/gin/createRoutes.go b/inputs/gin/createRoutes.go index c7344af..5cdd3e8 100644 --- a/inputs/gin/createRoutes.go +++ b/inputs/gin/createRoutes.go @@ -1,12 +1,16 @@ package gin import ( + "errors" "reflect" - "runtime" + + "github.com/gin-gonic/gin" "github.com/ls6-events/astra" +) - "github.com/gin-gonic/gin" +var ( + ErrRouteNotFound = errors.New("route not found") ) // CreateRoutes creates routes from a gin routes. @@ -16,6 +20,13 @@ import ( func CreateRoutes(router *gin.Engine) astra.ServiceFunction { return func(s *astra.Service) error { s.Log.Debug().Msg("Populating service with gin routes") + + // To prevent performance issues, only find the gin.Engine tree if middleware is enabled + var trees reflect.Value + if s.UnstableEnableMiddleware { + trees = reflect.ValueOf(router).Elem().FieldByName("trees") + } + for _, route := range router.Routes() { s.Log.Debug().Str("path", route.Path).Str("method", route.Method).Msg("Populating route") @@ -31,15 +42,24 @@ func CreateRoutes(router *gin.Engine) astra.ServiceFunction { continue } - pc := reflect.ValueOf(route.HandlerFunc).Pointer() - file, line := runtime.FuncForPC(pc).FileLine(pc) + var handlers []uintptr - s.Log.Debug().Str("path", route.Path).Str("method", route.Method).Str("file", file).Int("line", line).Msg("Found route handler") + // If middleware is enabled, find the handlers for the route using the gin.Engine tree + if s.UnstableEnableMiddleware { + var found bool + handlers, found = findHandlersForRoute(trees, route) + if !found { + s.Log.Error().Str("path", route.Path).Str("method", route.Method).Msg("Route not found") + return ErrRouteNotFound + } + } else { + // If middleware is disabled, use the gin.Engine handlers + handlers = []uintptr{reflect.ValueOf(route.HandlerFunc).Pointer()} + } - s.Log.Debug().Str("path", route.Path).Str("method", route.Method).Str("file", file).Int("line", line).Msg("Parsing route") - err := createRoute(s, file, line, route) + err := createRoute(s, handlers, route) if err != nil { - s.Log.Error().Str("path", route.Path).Str("method", route.Method).Str("file", file).Int("line", line).Err(err).Msg("Failed to parse route") + s.Log.Error().Str("path", route.Path).Err(err).Msg("Failed to parse route") return err } } diff --git a/inputs/gin/findHandlersForRoute.go b/inputs/gin/findHandlersForRoute.go new file mode 100644 index 0000000..0716e0c --- /dev/null +++ b/inputs/gin/findHandlersForRoute.go @@ -0,0 +1,54 @@ +package gin + +import ( + "reflect" + + "github.com/gin-gonic/gin" +) + +// findHandlersForRoute finds all handlers for a given route. +// It uses reflection to access the private properties of the gin.Engine,specifically it's method tree. +// It returns the handler pointers and a boolean indicating if the route was found. +func findHandlersForRoute(tree reflect.Value, route gin.RouteInfo) ([]uintptr, bool) { + for i := 0; i < tree.Len(); i++ { + method := tree.Index(i) + methodString := method.FieldByName("method").String() + if methodString != route.Method { + continue + } + + node := method.FieldByName("root") + + foundRoute, found := searchNode(node.Elem(), route.Path) + if found { + handlersField := foundRoute.FieldByName("handlers") + handlers := make([]uintptr, handlersField.Len()) + + for j := 0; j < handlersField.Len(); j++ { + handlers[j] = handlersField.Index(j).Pointer() + } + + return handlers, true + } + } + + return nil, false +} + +// searchNode searches a gin.Engine node for a given path. +// It uses recursion to search all children nodes. +// It returns the node and a boolean indicating if the path was found. +func searchNode(node reflect.Value, path string) (reflect.Value, bool) { + if node.FieldByName("fullPath").String() == path && node.FieldByName("handlers").Len() > 0 { + return node, true + } + + for i := 0; i < node.FieldByName("children").Len(); i++ { + child, found := searchNode(node.FieldByName("children").Index(i).Elem(), path) + if found { + return child, true + } + } + + return reflect.Value{}, false +} diff --git a/inputs/gin/parseFunction.go b/inputs/gin/parseFunction.go index bba2ef9..c9e5249 100644 --- a/inputs/gin/parseFunction.go +++ b/inputs/gin/parseFunction.go @@ -25,7 +25,7 @@ const ( // And the package name and path are used to determine the package of the currently analysed function. // The currRoute reference is used to manipulate the current route being analysed. // The imports are used to determine the package of the context variable. -func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTraverser, currRoute *astra.Route, activeFile *astTraversal.FileNode, level int) error { +func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTraverser, currRoute *astra.Route, activeFile *astTraversal.FileNode, isLastHandler bool, level int) error { traverser := funcTraverser.Traverser traverser.SetActiveFile(activeFile) @@ -95,7 +95,7 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers return false } - err = parseFunction(s, function, currRoute, function.Traverser.ActiveFile(), level+1) + err = parseFunction(s, function, currRoute, function.Traverser.ActiveFile(), isLastHandler, level+1) if err != nil { traverser.Log.Error().Err(err).Msg("error parsing function") return false @@ -696,11 +696,11 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers return true }) - if err != nil { + if err != nil && !errors.Is(err, astTraversal.ErrInvalidNodeType) { return err } - if len(currRoute.ReturnTypes) == 0 && level == 0 { + if len(currRoute.ReturnTypes) == 0 && level == 0 && isLastHandler { return errors.New("return type not found") } diff --git a/inputs/gin/parseRoute.go b/inputs/gin/parseHandler.go similarity index 80% rename from inputs/gin/parseRoute.go rename to inputs/gin/parseHandler.go index fbd814b..328b73d 100644 --- a/inputs/gin/parseRoute.go +++ b/inputs/gin/parseHandler.go @@ -12,14 +12,17 @@ import ( "github.com/iancoleman/strcase" ) -// parseRoute parses a route from a gin routes. +// parseHandler parses a route from a gin routes. // It will populate the route with the handler function. // createRoute must be called before this. // It will open the file as an AST and find the handler function using the line number and function name. // It can also find the path parameters from the handler function. // It calls the parseFunction function to parse the handler function. -func parseRoute(s *astra.Service, baseRoute *astra.Route) error { - log := s.Log.With().Str("path", baseRoute.Path).Str("method", baseRoute.Method).Str("file", baseRoute.File).Logger() +func parseHandler(s *astra.Service, baseRoute *astra.Route, handlerIndex int) error { + handler := baseRoute.Handlers[handlerIndex] + isLastHandler := handlerIndex == len(baseRoute.Handlers)-1 + + log := s.Log.With().Str("path", baseRoute.Path).Str("method", baseRoute.Method).Str("file", handler.File).Logger() traverser := astTraversal.New(s.WorkDir).SetLog(&log) @@ -30,18 +33,18 @@ func parseRoute(s *astra.Service, baseRoute *astra.Route) error { return path, nil }) - handler := utils.SplitHandlerPath(baseRoute.Handler) + splitHandler := utils.SplitHandlerPath(handler.Name) - pkgPath := handler.PackagePath() - pkgName := handler.PackageName() + pkgPath := splitHandler.PackagePath() + pkgName := splitHandler.PackageName() - if len(handler.HandlerParts) < 1 { - err := fmt.Errorf("invalid handler name for file: %s", baseRoute.Handler) + if len(splitHandler.HandlerParts) < 1 { + err := fmt.Errorf("invalid handler name for file: %s", handler.Name) log.Error().Err(err).Msg("Failed to parse handler name") return err } - funcName := handler.FuncName() + funcName := splitHandler.FuncName() pkgNode := traverser.Packages.AddPackage(pkgPath) @@ -56,7 +59,7 @@ func parseRoute(s *astra.Service, baseRoute *astra.Route) error { } for _, file := range pkgNode.Files { - if path.Base(file.FileName) == path.Base(baseRoute.File) { + if path.Base(file.FileName) == path.Base(handler.File) { log.Debug().Str("fileName", file.FileName).Msg("Found file") traverser.SetActiveFile(file) break @@ -64,7 +67,7 @@ func parseRoute(s *astra.Service, baseRoute *astra.Route) error { } if traverser.ActiveFile() == nil { - err := fmt.Errorf("could not find file: %s", baseRoute.File) + err := fmt.Errorf("could not find file: %s", handler.File) log.Error().Err(err).Msg("Failed to find file") return err } @@ -88,11 +91,11 @@ func parseRoute(s *astra.Service, baseRoute *astra.Route) error { startPos := traverser.ActiveFile().Package.Package.Fset.Position(funcDecl.Pos()) - if baseRoute.LineNo != startPos.Line { + if handler.LineNo != startPos.Line { // This means that the function is set inline in the route definition log.Debug().Str("funcName", funcName).Msg("Function is inline") - ast.Inspect(funcDecl, func(n ast.Node) bool { + ast.Inspect(funcDecl.Body, func(n ast.Node) bool { if n == nil { return true } @@ -102,7 +105,7 @@ func parseRoute(s *astra.Service, baseRoute *astra.Route) error { if ok { inlineStartPos := traverser.ActiveFile().Package.Package.Fset.Position(funcLit.Pos()) - if baseRoute.LineNo == inlineStartPos.Line { + if handler.LineNo == inlineStartPos.Line { log.Debug().Str("funcName", funcName).Msg("Found inline handler function") function, err := traverser.Function(funcLit) @@ -111,7 +114,7 @@ func parseRoute(s *astra.Service, baseRoute *astra.Route) error { return false } - err = parseFunction(s, function, baseRoute, traverser.ActiveFile(), 0) + err = parseFunction(s, function, baseRoute, traverser.ActiveFile(), isLastHandler, 0) if err != nil { log.Error().Err(err).Msg("Failed to parse inline function") return false @@ -139,7 +142,7 @@ func parseRoute(s *astra.Service, baseRoute *astra.Route) error { // And define the function name as the operation ID baseRoute.OperationID = strcase.ToLowerCamel(funcName) - err = parseFunction(s, function, baseRoute, traverser.ActiveFile(), 0) + err = parseFunction(s, function, baseRoute, traverser.ActiveFile(), isLastHandler, 0) if err != nil { log.Error().Err(err).Msg("Failed to parse function") return false diff --git a/inputs/gin/parseRoutes.go b/inputs/gin/parseRoutes.go index 7c028d2..a611d92 100644 --- a/inputs/gin/parseRoutes.go +++ b/inputs/gin/parseRoutes.go @@ -6,19 +6,21 @@ import ( // ParseRoutes parses routes from a gin routes. // It will populate the routes with the handler function. -// It will individually call parseRoute for each route. +// It will individually call parseHandler for each route. // createRoutes must be called before this. func ParseRoutes() astra.ServiceFunction { return func(s *astra.Service) error { s.Log.Debug().Msg("Populating routes from gin routes") for _, route := range s.Routes { - s.Log.Debug().Str("path", route.Path).Str("method", route.Method).Msg("Populating route") + s.Log.Debug().Str("path", route.Path).Str("method", route.Method).Msg("Parsing route") - s.Log.Debug().Str("path", route.Path).Str("method", route.Method).Str("file", route.File).Int("line", route.LineNo).Msg("Parsing route") - err := parseRoute(s, &route) - if err != nil { - s.Log.Error().Str("path", route.Path).Str("method", route.Method).Str("file", route.File).Int("line", route.LineNo).Err(err).Msg("Failed to parse route") - return err + for i := range route.Handlers { + s.Log.Debug().Str("path", route.Path).Str("method", route.Method).Str("file", route.Handlers[i].File).Int("line", route.Handlers[i].LineNo).Msg("Parsing handler") + err := parseHandler(s, &route, i) + if err != nil { + s.Log.Error().Str("path", route.Path).Str("method", route.Method).Str("file", route.Handlers[i].File).Int("line", route.Handlers[i].LineNo).Err(err).Msg("Failed to parse handler") + return err + } } s.ReplaceRoute(route) diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..e7da0e2 --- /dev/null +++ b/middleware.go @@ -0,0 +1,7 @@ +package astra + +func UnstableWithMiddleware() Option { + return func(service *Service) { + service.UnstableEnableMiddleware = true + } +} diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 0000000..89d52bc --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,16 @@ +package astra + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func Test_UnstableWithMiddleware(t *testing.T) { + service := &Service{ + UnstableEnableMiddleware: false, + } + + UnstableWithMiddleware()(service) + + require.True(t, service.UnstableEnableMiddleware) +} diff --git a/outputs/azureFunctions/generate.go b/outputs/azureFunctions/generate.go index 024eb1c..6e8a17f 100644 --- a/outputs/azureFunctions/generate.go +++ b/outputs/azureFunctions/generate.go @@ -35,7 +35,7 @@ func Generate(directoryPath string) astra.ServiceFunction { } for _, route := range s.Routes { - splitHandler := strings.Split(route.Handler, ".") + splitHandler := strings.Split(route.Handlers[len(route.Handlers)-1].Name, ".") funcName := splitHandler[len(splitHandler)-1] functionDirectoryPath := path.Join(tempOutputDirectoryPath, funcName) diff --git a/service.go b/service.go index e6431cc..5e8b7ac 100644 --- a/service.go +++ b/service.go @@ -44,4 +44,7 @@ type Service struct { CustomTypeMapping map[string]TypeFormat `json:"custom_type_mapping" yaml:"custom_type_mapping"` // fullTypeMapping is a full map of types to their OpenAPI type and format (to save merging the custom type mapping with the predefined type mapping every time) fullTypeMapping map[string]TypeFormat + + // UnstableEnableMiddleware is a flag to enable the middleware feature (this is unstable and may change in the future) + UnstableEnableMiddleware bool `json:"unstable_enable_middleware" yaml:"unstable_enable_middleware"` } diff --git a/tests/integration/13-middleware/.gitignore b/tests/integration/13-middleware/.gitignore new file mode 100644 index 0000000..325a564 --- /dev/null +++ b/tests/integration/13-middleware/.gitignore @@ -0,0 +1 @@ +output.json \ No newline at end of file diff --git a/tests/integration/13-middleware/README.md b/tests/integration/13-middleware/README.md new file mode 100644 index 0000000..d6be159 --- /dev/null +++ b/tests/integration/13-middleware/README.md @@ -0,0 +1,4 @@ +# Middleware +This example shows how to use middleware in Astra. We have 3 types of middleware available, all are in the `setupRouter` function in `router.go`. + +**Note: this is currently an unstable (alpha) feature and may change in the future as not all test cases currently work.** \ No newline at end of file diff --git a/tests/integration/13-middleware/handlers.go b/tests/integration/13-middleware/handlers.go new file mode 100644 index 0000000..32b28b8 --- /dev/null +++ b/tests/integration/13-middleware/handlers.go @@ -0,0 +1,62 @@ +package petstore + +import ( + "net/http" + "strconv" + + "github.com/gin-gonic/gin" + + "github.com/ls6-events/astra/tests/petstore" +) + +func getAllPets(c *gin.Context) { + allPets := petstore.Pets + + c.JSON(http.StatusOK, allPets) +} + +func getPetByID(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + pet, err := petstore.PetByID(int64(id)) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, pet) +} + +func createPet(c *gin.Context) { + var pet petstore.PetDTO + err := c.BindJSON(&pet) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + petstore.AddPet(petstore.Pet{ + Name: pet.Name, + PhotoURLs: pet.PhotoURLs, + Status: pet.Status, + Tags: pet.Tags, + }) + + c.JSON(http.StatusOK, pet) +} + +func deletePet(c *gin.Context) { + id, err := strconv.Atoi(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + petstore.RemovePet(int64(id)) + + c.Status(http.StatusOK) +} diff --git a/tests/integration/13-middleware/middleware.go b/tests/integration/13-middleware/middleware.go new file mode 100644 index 0000000..8390376 --- /dev/null +++ b/tests/integration/13-middleware/middleware.go @@ -0,0 +1,62 @@ +package petstore + +import "github.com/gin-gonic/gin" + +func petsMiddleware(c *gin.Context) { + apiKey, ok := c.GetQuery("api_key") + if !ok { + c.AbortWithStatusJSON(400, gin.H{ + "message": "api_key is missing", + }) + return + } + + if apiKey != "1234567890" { + c.AbortWithStatusJSON(401, gin.H{ + "message": "invalid api_key", + }) + return + } + + c.Next() +} + +func handlerFunc() gin.HandlerFunc { + return func(c *gin.Context) { + apiKey, ok := c.GetQuery("inline_api_key") + if !ok { + c.AbortWithStatusJSON(400, gin.H{ + "message": "api_key is missing", + }) + return + } + + if apiKey != "1234567890" { + c.AbortWithStatusJSON(401, gin.H{ + "message": "invalid api_key", + }) + return + } + + c.Next() + } +} + +func headerMiddleware(c *gin.Context) { + authorization := c.GetHeader("Authorization") + if authorization == "" { + c.AbortWithStatusJSON(400, gin.H{ + "message": "Authorization header is missing", + }) + return + } + + if authorization != "Bearer 1234567890" { + c.AbortWithStatusJSON(401, gin.H{ + "message": "invalid Authorization header", + }) + return + } + + c.Next() +} diff --git a/tests/integration/13-middleware/output.json b/tests/integration/13-middleware/output.json deleted file mode 100644 index b03d13c..0000000 --- a/tests/integration/13-middleware/output.json +++ /dev/null @@ -1,390 +0,0 @@ -{ - "openapi": "3.0.0", - "info": { - "title": "", - "description": "Generated by astra", - "contact": {}, - "license": { - "name": "" - }, - "version": "" - }, - "servers": [ - { - "url": "http://localhost:8000" - } - ], - "paths": { - "/middleware": { - "get": { - "operationId": "headerMiddleware", - "parameters": [ - { - "name": "Authorization", - "in": "header", - "schema": { - "type": "string" - } - } - ], - "responses": { - "200": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/gin.H" - } - } - } - }, - "400": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/gin.H" - } - } - } - }, - "401": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/gin.H" - } - } - } - } - } - } - }, - "/no-middleware": { - "get": { - "responses": { - "200": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/gin.H" - } - } - } - } - } - } - }, - "/pets": { - "get": { - "operationId": "getAllPets", - "parameters": [ - { - "name": "api_key", - "in": "query", - "style": "form", - "explode": true, - "schema": { - "type": "string" - } - } - ], - "responses": { - "200": { - "description": "", - "content": { - "application/json": { - "schema": { - "type": "array", - "items": { - "$ref": "#/components/schemas/petstore.Pet" - } - } - } - } - }, - "400": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/gin.H" - } - } - } - }, - "401": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/gin.H" - } - } - } - } - } - }, - "post": { - "operationId": "createPet", - "parameters": [ - { - "name": "api_key", - "in": "query", - "style": "form", - "explode": true, - "schema": { - "type": "string" - } - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/petstore.PetDTO" - } - } - } - }, - "responses": { - "200": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/petstore.PetDTO" - } - } - } - }, - "400": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/gin.H" - } - } - } - }, - "401": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/gin.H" - } - } - } - } - } - } - }, - "/pets/{id}": { - "get": { - "operationId": "getPetById", - "parameters": [ - { - "name": "id", - "in": "path", - "required": true, - "schema": { - "type": "string" - } - }, - { - "name": "api_key", - "in": "query", - "style": "form", - "explode": true, - "schema": { - "type": "string" - } - } - ], - "responses": { - "200": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/petstore.Pet" - } - } - } - }, - "400": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/gin.H" - } - } - } - }, - "401": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/gin.H" - } - } - } - }, - "404": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/gin.H" - } - } - } - } - } - }, - "delete": { - "operationId": "deletePet", - "parameters": [ - { - "name": "id", - "in": "path", - "required": true, - "schema": { - "type": "string" - } - }, - { - "name": "api_key", - "in": "query", - "style": "form", - "explode": true, - "schema": { - "type": "string" - } - } - ], - "responses": { - "200": { - "description": "" - }, - "400": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/gin.H" - } - } - } - }, - "401": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/gin.H" - } - } - } - } - } - } - }, - "/wrapper-func-middleware": { - "get": { - "responses": { - "200": { - "description": "", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/gin.H" - } - } - } - } - } - } - } - }, - "components": { - "schemas": { - "gin.H": { - "type": "object", - "additionalProperties": {}, - "description": "H is a shortcut for map[string]any" - }, - "petstore.Pet": { - "type": "object", - "properties": { - "id": { - "type": "integer", - "format": "int64" - }, - "name": { - "type": "string" - }, - "photoUrls": { - "type": "array", - "items": { - "type": "string" - } - }, - "status": { - "type": "string" - }, - "tags": { - "type": "array", - "items": { - "$ref": "#/components/schemas/petstore.Tag" - } - } - }, - "description": "Pet the pet model." - }, - "petstore.PetDTO": { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "photoUrls": { - "type": "array", - "items": { - "type": "string" - } - }, - "status": { - "type": "string" - }, - "tags": { - "type": "array", - "items": { - "$ref": "#/components/schemas/petstore.Tag" - } - } - }, - "description": "PetDTO the pet dto." - }, - "petstore.Tag": { - "type": "object", - "properties": { - "id": { - "type": "integer", - "format": "int64" - }, - "name": { - "type": "string" - } - }, - "description": "Tag the tag model." - } - } - } -} \ No newline at end of file diff --git a/tests/integration/13-middleware/paths_test.go b/tests/integration/13-middleware/paths_test.go new file mode 100644 index 0000000..ae98c1a --- /dev/null +++ b/tests/integration/13-middleware/paths_test.go @@ -0,0 +1,57 @@ +package petstore + +import ( + "github.com/ls6-events/astra" + "github.com/ls6-events/astra/tests/integration/helpers" + "github.com/stretchr/testify/require" + "testing" +) + +func TestMiddleware(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + r := setupRouter() + + testAstra, err := helpers.SetupTestAstraWithDefaultConfig(t, r, astra.UnstableWithMiddleware()) + require.NoError(t, err) + + paths := testAstra.Path("paths") + + // /pets middleware + // GET /pets + require.Equal(t, "string", paths.Search("/pets", "get", "parameters", "0", "schema", "type").Data().(string)) + require.Equal(t, "api_key", paths.Search("/pets", "get", "parameters", "0", "name").Data().(string)) + require.Equal(t, "query", paths.Search("/pets", "get", "parameters", "0", "in").Data().(string)) + + // GET /pets/{id} + require.Equal(t, "string", paths.Search("/pets/{id}", "get", "parameters", "0", "schema", "type").Data().(string)) + require.Equal(t, "api_key", paths.Search("/pets/{id}", "get", "parameters", "0", "name").Data().(string)) + require.Equal(t, "query", paths.Search("/pets/{id}", "get", "parameters", "0", "in").Data().(string)) + + // POST /pets + require.Equal(t, "string", paths.Search("/pets", "post", "parameters", "0", "schema", "type").Data().(string)) + require.Equal(t, "api_key", paths.Search("/pets", "post", "parameters", "0", "name").Data().(string)) + require.Equal(t, "query", paths.Search("/pets", "post", "parameters", "0", "in").Data().(string)) + + // DELETE /pets/{id} + require.Equal(t, "string", paths.Search("/pets/{id}", "delete", "parameters", "0", "schema", "type").Data().(string)) + require.Equal(t, "api_key", paths.Search("/pets/{id}", "delete", "parameters", "0", "name").Data().(string)) + require.Equal(t, "query", paths.Search("/pets/{id}", "delete", "parameters", "0", "in").Data().(string)) + + // /no-middleware + require.True(t, paths.Exists("/no-middleware", "get")) + require.False(t, paths.Exists("/no-middleware", "parameters", "0")) + + // /middleware + require.Equal(t, "string", paths.Search("/middleware", "get", "parameters", "0", "schema", "type").Data().(string)) + require.Equal(t, "Authorization", paths.Search("/middleware", "get", "parameters", "0", "name").Data().(string)) + require.Equal(t, "header", paths.Search("/middleware", "get", "parameters", "0", "in").Data().(string)) + + // /wrapper-func-middleware + // Higher order functions don't seem to play nicely with runtime.FuncForPC when not in JetBrains debugger (unknown reason, will require further investigation) + //require.Equal(t, "string", paths.Search("/wrapper-func-middleware", "get", "parameters", "0", "schema", "type").Data().(string)) + //require.Equal(t, "inline_api_key", paths.Search("/wrapper-func-middleware", "get", "parameters", "0", "name").Data().(string)) + //require.Equal(t, "query", paths.Search("/wrapper-func-middleware", "get", "parameters", "0", "in").Data().(string)) +} diff --git a/tests/integration/13-middleware/router.go b/tests/integration/13-middleware/router.go new file mode 100644 index 0000000..aa3ee11 --- /dev/null +++ b/tests/integration/13-middleware/router.go @@ -0,0 +1,36 @@ +package petstore + +import "github.com/gin-gonic/gin" + +func setupRouter() *gin.Engine { + r := gin.Default() + + pets := r.Group("/pets") + + pets.Use(petsMiddleware) + + pets.GET("", getAllPets) + pets.GET("/:id", getPetByID) + pets.POST("", createPet) + pets.DELETE("/:id", deletePet) + + r.GET("/no-middleware", func(c *gin.Context) { + c.JSON(200, gin.H{ + "message": "no middleware", + }) + }) + + r.GET("/middleware", headerMiddleware, func(c *gin.Context) { + c.JSON(200, gin.H{ + "message": "middleware", + }) + }) + + r.GET("/wrapper-func-middleware", handlerFunc(), func(c *gin.Context) { + c.JSON(200, gin.H{ + "message": "inline middleware", + }) + }) + + return r +} diff --git a/types.go b/types.go index 48deddd..4f87cb1 100644 --- a/types.go +++ b/types.go @@ -4,11 +4,15 @@ import "github.com/ls6-events/astra/astTraversal" // These are types that are used throughout the astra package. +type Handler struct { + Name string `json:"name" yaml:"name"` + File string `json:"file" yaml:"file"` + LineNo int `json:"lineNo" yaml:"lineNo"` +} + // Route is a route in the service and all of its potential options. type Route struct { - Handler string `json:"handler" yaml:"handler"` - File string `json:"file" yaml:"file"` - LineNo int `json:"lineNo" yaml:"lineNo"` + Handlers []Handler `json:"handlers,omitempty" yaml:"handlers,omitempty"` Method string `json:"method" yaml:"method"` Path string `json:"path" yaml:"path"` PathParams []Param `json:"params,omitempty" yaml:"params,omitempty"` // for now, we use :param in the path to denote a required path param, and *param to denote an optional path param.