Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions astTraversal/callExpression.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package astTraversal

import (
"errors"
"go/ast"
"go/types"
)
Expand All @@ -14,7 +15,7 @@ type CallExpressionTraverser struct {
func (t *BaseTraverser) CallExpression(node ast.Node) (*CallExpressionTraverser, error) {
callExpr, ok := node.(*ast.CallExpr)
if !ok {
return nil, ErrInvalidNodeType
return nil, errors.Join(ErrInvalidNodeType, errors.New("expected *ast.CallExpr"))
}

return &CallExpressionTraverser{
Expand Down Expand Up @@ -59,7 +60,7 @@ func (c *CallExpressionTraverser) Args() []ast.Expr {

func (c *CallExpressionTraverser) Type() (*types.Func, error) {
if c.Node.Fun == nil {
return nil, ErrInvalidNodeType
return nil, errors.Join(ErrInvalidNodeType, errors.New("expected *ast.CallExpr.Fun to not be nil"))
}

var obj types.Object
Expand All @@ -70,7 +71,7 @@ func (c *CallExpressionTraverser) Type() (*types.Func, error) {
case *ast.SelectorExpr:
obj, err = c.File.Package.FindObjectForIdent(nodeFun.Sel)
default:
err = ErrInvalidNodeType
err = errors.Join(ErrInvalidNodeType, errors.New("expected *ast.CallExpr.Fun to be *ast.Ident or *ast.SelectorExpr"))
}
if err != nil {
return nil, err
Expand All @@ -83,7 +84,7 @@ func (c *CallExpressionTraverser) Type() (*types.Func, error) {
return nil, ErrBuiltInFunction
}

return nil, ErrInvalidNodeType
return nil, errors.Join(ErrInvalidNodeType, errors.New("expected *types.Func"))
}

func (c *CallExpressionTraverser) ReturnType(returnNum int) (types.Type, error) {
Expand All @@ -94,11 +95,11 @@ func (c *CallExpressionTraverser) ReturnType(returnNum int) (types.Type, error)

signature, ok := funcType.Type().(*types.Signature)
if !ok {
return nil, ErrInvalidNodeType
return nil, errors.Join(ErrInvalidNodeType, errors.New("expected *types.Signature"))
}

if signature.Results().Len() <= returnNum {
return nil, ErrInvalidIndex
return nil, errors.Join(ErrInvalidIndex, errors.New("expected returnNum to be less than signature.Results().Len()"))
}

return signature.Results().At(returnNum).Type(), nil
Expand All @@ -112,11 +113,11 @@ func (c *CallExpressionTraverser) ArgType(argNum int) (types.Object, error) {

signature, ok := funcType.Type().(*types.Signature)
if !ok {
return nil, ErrInvalidNodeType
return nil, errors.Join(ErrInvalidNodeType, errors.New("expected *types.Signature"))
}

if signature.Params().Len() <= argNum {
return nil, ErrInvalidIndex
return nil, errors.Join(ErrInvalidIndex, errors.New("expected argNum to be less than signature.Params().Len()"))
}

return signature.Params().At(argNum), nil
Expand Down
3 changes: 2 additions & 1 deletion astTraversal/declaration.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package astTraversal

import (
"errors"
"go/ast"
)

Expand Down Expand Up @@ -43,7 +44,7 @@ func (d *DeclarationTraverser) Value() (ast.Node, error) {

return n.Rhs[index], nil
default:
return nil, ErrInvalidNodeType
return nil, errors.Join(ErrInvalidNodeType, errors.New("expected *ast.ValueSpec or *ast.AssignStmt"))
}
}

Expand Down
3 changes: 2 additions & 1 deletion astTraversal/extractStatusCode.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package astTraversal

import (
"errors"
"go/ast"
"strconv"
)
Expand All @@ -9,7 +10,7 @@ import (
func (t *BaseTraverser) ExtractStatusCode(status ast.Node) (int, error) {
exprNode, ok := status.(ast.Expr)
if !ok {
return 0, ErrInvalidNodeType
return 0, errors.Join(ErrInvalidNodeType, errors.New("expected ast.Expr"))
}

expr := t.Expression(exprNode)
Expand Down
3 changes: 2 additions & 1 deletion astTraversal/function.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package astTraversal

import (
"errors"
"go/ast"
"go/types"
)
Expand All @@ -25,7 +26,7 @@ func (t *BaseTraverser) Function(node ast.Node) (*FunctionTraverser, error) {
}
funcDecl = n
default:
return nil, ErrInvalidNodeType
return nil, errors.Join(ErrInvalidNodeType, errors.New("expected *ast.FuncLit or *ast.FuncDecl"))
}

return &FunctionTraverser{
Expand Down
3 changes: 2 additions & 1 deletion astTraversal/literal.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package astTraversal

import (
"errors"
"go/ast"
"go/types"
)
Expand All @@ -22,7 +23,7 @@ func (t *BaseTraverser) Literal(node ast.Node, returnNum int) (*LiteralTraverser
func (lt *LiteralTraverser) Type() (types.Type, error) {
exprNode, ok := lt.Node.(ast.Expr)
if !ok {
return nil, ErrInvalidNodeType
return nil, errors.Join(ErrInvalidNodeType, errors.New("expected ast.Expr"))
}

return lt.Traverser.Expression(exprNode).SetReturnNum(lt.ReturnNum).Type()
Expand Down
3 changes: 2 additions & 1 deletion astTraversal/type.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package astTraversal

import (
"errors"
"go/ast"
"go/token"
"go/types"
Expand Down Expand Up @@ -265,7 +266,7 @@ func (t *TypeTraverser) Result() (Result, error) {
if result.Type != "" {
return result, nil
} else {
return Result{}, ErrInvalidNodeType
return Result{}, errors.Join(ErrInvalidNodeType, errors.New("expected *types.Basic, *types.Named, *types.Pointer, *types.Slice, *types.Array, *types.Map, *types.Struct, or *types.Interface"))
}
}

Expand Down
51 changes: 42 additions & 9 deletions inputs/gin/createRoute.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,22 @@ package gin
import (
"os"
"path/filepath"
"runtime"

"github.com/ls6-events/astra"

"github.com/gin-gonic/gin"
)

var deniedHandlers = []string{
"github.com/gin-gonic/gin.LoggerWithConfig.func1",
"github.com/gin-gonic/gin.CustomRecoveryWithWriter.func1",
}

// createRoute creates a route from a gin RouteInfo.
// It will only create the route and refer to the handler function by name, file and line number.
// The route will be populated later by parseRoute.
func createRoute(s *astra.Service, file string, line int, info gin.RouteInfo) error {
// The route will be populated later by parseHandler.
func createRoute(s *astra.Service, handlers []uintptr, info gin.RouteInfo) error {
log := s.Log.With().Str("path", info.Path).Str("method", info.Method).Str("handler", info.Handler).Logger()

cwd, err := os.Getwd()
Expand All @@ -21,16 +27,43 @@ func createRoute(s *astra.Service, file string, line int, info gin.RouteInfo) er
return err
}

relativePath, err := filepath.Rel(cwd, file)
if err != nil {
log.Error().Err(err).Msg("Failed to get relative path")
return err
log.Debug().Msg("Found route handlers")
resultHandlers := make([]astra.Handler, 0)
for _, handler := range handlers {
funcForPC := runtime.FuncForPC(handler)
name := funcForPC.Name()
file, line := funcForPC.FileLine(handler)

var found bool
for _, deniedHandler := range deniedHandlers {
if deniedHandler == name {
found = true
break
}
}

if found {
log.Debug().Str("name", name).Str("file", file).Int("line", line).Msg("Denied handler")
continue
}

log.Debug().Str("name", name).Str("file", file).Int("line", line).Msg("Found handler")

relativePath, err := filepath.Rel(cwd, file)
if err != nil {
log.Error().Err(err).Msg("Failed to get relative path")
return err
}

resultHandlers = append(resultHandlers, astra.Handler{
Name: name,
File: relativePath,
LineNo: line,
})
}

baseRoute := astra.Route{
Handler: info.Handler,
File: relativePath,
LineNo: line,
Handlers: resultHandlers,
Path: info.Path,
Method: info.Method,
PathParams: make([]astra.Param, 0),
Expand Down
36 changes: 28 additions & 8 deletions inputs/gin/createRoutes.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package gin

import (
"errors"
"reflect"
"runtime"

"github.com/gin-gonic/gin"

"github.com/ls6-events/astra"
)

"github.com/gin-gonic/gin"
var (
ErrRouteNotFound = errors.New("route not found")
)

// CreateRoutes creates routes from a gin routes.
Expand All @@ -16,6 +20,13 @@ import (
func CreateRoutes(router *gin.Engine) astra.ServiceFunction {
return func(s *astra.Service) error {
s.Log.Debug().Msg("Populating service with gin routes")

// To prevent performance issues, only find the gin.Engine tree if middleware is enabled
var trees reflect.Value
if s.UnstableEnableMiddleware {
trees = reflect.ValueOf(router).Elem().FieldByName("trees")
}

for _, route := range router.Routes() {
s.Log.Debug().Str("path", route.Path).Str("method", route.Method).Msg("Populating route")

Expand All @@ -31,15 +42,24 @@ func CreateRoutes(router *gin.Engine) astra.ServiceFunction {
continue
}

pc := reflect.ValueOf(route.HandlerFunc).Pointer()
file, line := runtime.FuncForPC(pc).FileLine(pc)
var handlers []uintptr

s.Log.Debug().Str("path", route.Path).Str("method", route.Method).Str("file", file).Int("line", line).Msg("Found route handler")
// If middleware is enabled, find the handlers for the route using the gin.Engine tree
if s.UnstableEnableMiddleware {
var found bool
handlers, found = findHandlersForRoute(trees, route)
if !found {
s.Log.Error().Str("path", route.Path).Str("method", route.Method).Msg("Route not found")
return ErrRouteNotFound
}
} else {
// If middleware is disabled, use the gin.Engine handlers
handlers = []uintptr{reflect.ValueOf(route.HandlerFunc).Pointer()}
}

s.Log.Debug().Str("path", route.Path).Str("method", route.Method).Str("file", file).Int("line", line).Msg("Parsing route")
err := createRoute(s, file, line, route)
err := createRoute(s, handlers, route)
if err != nil {
s.Log.Error().Str("path", route.Path).Str("method", route.Method).Str("file", file).Int("line", line).Err(err).Msg("Failed to parse route")
s.Log.Error().Str("path", route.Path).Err(err).Msg("Failed to parse route")
return err
}
}
Expand Down
54 changes: 54 additions & 0 deletions inputs/gin/findHandlersForRoute.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package gin

import (
"reflect"

"github.com/gin-gonic/gin"
)

// findHandlersForRoute finds all handlers for a given route.
// It uses reflection to access the private properties of the gin.Engine,specifically it's method tree.
// It returns the handler pointers and a boolean indicating if the route was found.
func findHandlersForRoute(tree reflect.Value, route gin.RouteInfo) ([]uintptr, bool) {
for i := 0; i < tree.Len(); i++ {
method := tree.Index(i)
methodString := method.FieldByName("method").String()
if methodString != route.Method {
continue
}

node := method.FieldByName("root")

foundRoute, found := searchNode(node.Elem(), route.Path)
if found {
handlersField := foundRoute.FieldByName("handlers")
handlers := make([]uintptr, handlersField.Len())

for j := 0; j < handlersField.Len(); j++ {
handlers[j] = handlersField.Index(j).Pointer()
}

return handlers, true
}
}

return nil, false
}

// searchNode searches a gin.Engine node for a given path.
// It uses recursion to search all children nodes.
// It returns the node and a boolean indicating if the path was found.
func searchNode(node reflect.Value, path string) (reflect.Value, bool) {
if node.FieldByName("fullPath").String() == path && node.FieldByName("handlers").Len() > 0 {
return node, true
}

for i := 0; i < node.FieldByName("children").Len(); i++ {
child, found := searchNode(node.FieldByName("children").Index(i).Elem(), path)
if found {
return child, true
}
}

return reflect.Value{}, false
}
8 changes: 4 additions & 4 deletions inputs/gin/parseFunction.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ const (
// And the package name and path are used to determine the package of the currently analysed function.
// The currRoute reference is used to manipulate the current route being analysed.
// The imports are used to determine the package of the context variable.
func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTraverser, currRoute *astra.Route, activeFile *astTraversal.FileNode, level int) error {
func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTraverser, currRoute *astra.Route, activeFile *astTraversal.FileNode, isLastHandler bool, level int) error {
traverser := funcTraverser.Traverser

traverser.SetActiveFile(activeFile)
Expand Down Expand Up @@ -95,7 +95,7 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers
return false
}

err = parseFunction(s, function, currRoute, function.Traverser.ActiveFile(), level+1)
err = parseFunction(s, function, currRoute, function.Traverser.ActiveFile(), isLastHandler, level+1)
if err != nil {
traverser.Log.Error().Err(err).Msg("error parsing function")
return false
Expand Down Expand Up @@ -696,11 +696,11 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers
return true
})

if err != nil {
if err != nil && !errors.Is(err, astTraversal.ErrInvalidNodeType) {
return err
}

if len(currRoute.ReturnTypes) == 0 && level == 0 {
if len(currRoute.ReturnTypes) == 0 && level == 0 && isLastHandler {
return errors.New("return type not found")
}

Expand Down
Loading