diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 73b8c68..670c758 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -39,7 +39,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v2 with: - go-version: 1.14 + go-version: 1.16 - name: Run GoReleaser uses: goreleaser/goreleaser-action@v2 env: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2df87e5..76a06b2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,7 +16,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v2 with: - go-version: 1.14 + go-version: 1.16 - name: Restore cache uses: actions/cache@v1 with: diff --git a/cache.go b/cache.go new file mode 100644 index 0000000..f5b6099 --- /dev/null +++ b/cache.go @@ -0,0 +1,99 @@ +package goone + +import ( + "go/token" + "golang.org/x/tools/go/analysis" + "strconv" + "sync" +) + +type ReportCache struct { + sync.Mutex + reportMemo map[string]bool +} + +func NewReportCache() *ReportCache { + return &ReportCache{ + reportMemo: make(map[string]bool), + } +} +func (m *ReportCache) toKey(pass *analysis.Pass, pos token.Pos) (key string) { + posn := pass.Fset.Position(pos) + fileName, lineNum := posn.Filename, posn.Line + key = fileName + strconv.Itoa(lineNum) + return +} +func (m *ReportCache) Set(pass *analysis.Pass, pos token.Pos, value bool) { + key := m.toKey(pass, pos) + m.reportMemo[key] = value +} + +func (m *ReportCache) Get(pass *analysis.Pass, pos token.Pos) bool { + key := m.toKey(pass, pos) + value := m.reportMemo[key] + return value +} + +type FuncCache struct { + sync.Mutex + funcMemo map[token.Pos]bool +} + +func NewFuncCache() *FuncCache { + return &FuncCache{ + funcMemo: make(map[token.Pos]bool), + } +} + +func (m *FuncCache) Set(key token.Pos, value bool) { + m.Lock() + m.funcMemo[key] = value + m.Unlock() +} + + +func (m *FuncCache) Get(key token.Pos) (bool,bool) { + m.Lock() + value, ok := m.funcMemo[key] + m.Unlock() + if !ok { + return false, false + } + return value, true +} + +type PkgCache struct { + sync.Mutex + pkgMemo map[string]bool +} + +func NewPkgCache() *PkgCache { + return &PkgCache{ + pkgMemo: make(map[string]bool), + } +} + +func (m *PkgCache) Set(name string, value bool) { + m.pkgMemo[name] = value +} + +func (m *PkgCache) Get(name string) bool { + value := m.pkgMemo[name] + return value +} + +func (m *PkgCache) Exists(key string) bool { + m.Lock() + _, exist := m.pkgMemo[key] + m.Unlock() + return exist +} + +// pkgCache contains pkg AST +var pkgCache *PkgCache + +// reportCache manages whether if this line already reported +var reportCache *ReportCache + +// funcCache manages whether if this function contains queries +var funcCache *FuncCache \ No newline at end of file diff --git a/go.mod b/go.mod index adeb3ae..cf8afb0 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,11 @@ module github.com/masibw/goone -go 1.14 +go 1.16 require ( github.com/gostaticanalysis/analysisutil v0.6.1 + // idk why, but we need this indirect dependency to pass the test github.com/masibw/goone_test v0.0.0-20210112093021-7d2e0b363db0 + github.com/masibw/goone_test v0.0.0-20210112093021-7d2e0b363db0 // indirect golang.org/x/tools v0.0.0-20210111221946-d33bae441459 gopkg.in/yaml.v2 v2.4.0 ) diff --git a/go.sum b/go.sum index f37a473..f982b15 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/gostaticanalysis/analysisutil v0.6.1 h1:/1JkoHe4DVxur+0wPvi26FoQfe1E3 github.com/gostaticanalysis/analysisutil v0.6.1/go.mod h1:18U/DLpRgIUd459wGxVHE0fRgmo1UgHDcbw7F5idXu0= github.com/gostaticanalysis/comment v1.4.1 h1:xHopR5L2lRz6OsjH4R2HG5wRhW9ySl3FsHIvi5pcXwc= github.com/gostaticanalysis/comment v1.4.1/go.mod h1:ih6ZxzTHLdadaiSnF5WY3dxUoXfXAlTaRzuaNDlSado= +github.com/masibw/goone_test v0.0.0-20210112093021-7d2e0b363db0 h1:PEBlkDmuNQMxcBaznRvs2sP0UqVPoXsNx5oWWK3MmX0= +github.com/masibw/goone_test v0.0.0-20210112093021-7d2e0b363db0/go.mod h1:yBWoicU1E30NC++4C6bor5y7dCFrobTb0jGVEPpH98Q= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1 h1:ruQGxdhGHe7FWOJPT0mKs5+pD2Xs1Bm/kdGlHO04FmM= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/goone.go b/goone.go index 28af593..0e15e27 100644 --- a/goone.go +++ b/goone.go @@ -2,25 +2,15 @@ package goone import ( "fmt" + "github.com/gostaticanalysis/analysisutil" "go/ast" - "go/token" "go/types" - "io/ioutil" - "log" - "os" - "path/filepath" - "strconv" - "strings" - "sync" - - "golang.org/x/tools/go/packages" - - "gopkg.in/yaml.v2" - - "github.com/gostaticanalysis/analysisutil" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/inspect" "golang.org/x/tools/go/ast/inspector" + "golang.org/x/tools/go/packages" + "log" + "strings" ) type Types struct { @@ -34,96 +24,6 @@ type Types struct { const doc = "goone finds N+1 query " -type ReportCache struct { - sync.Mutex - reportMemo map[string]bool -} - -func NewReportCache() *ReportCache { - return &ReportCache{ - reportMemo: make(map[string]bool), - } -} -func (m *ReportCache) toKey(pass *analysis.Pass, pos token.Pos) (key string) { - posn := pass.Fset.Position(pos) - fileName, lineNum := posn.Filename, posn.Line - key = fileName + strconv.Itoa(lineNum) - return -} -func (m *ReportCache) Set(pass *analysis.Pass, pos token.Pos, value bool) { - key := m.toKey(pass, pos) - m.reportMemo[key] = value -} - -func (m *ReportCache) Get(pass *analysis.Pass, pos token.Pos) bool { - key := m.toKey(pass, pos) - value := m.reportMemo[key] - return value -} - -type SearchCache struct { - sync.Mutex - searchMemo map[token.Pos]bool -} - -func NewSearchCache() *SearchCache { - return &SearchCache{ - searchMemo: make(map[token.Pos]bool), - } -} - -func (m *SearchCache) Set(key token.Pos, value bool) { - m.Lock() - m.searchMemo[key] = value - m.Unlock() -} - -func (m *SearchCache) Get(key token.Pos) bool { - m.Lock() - value := m.searchMemo[key] - m.Unlock() - return value -} - -type FuncCache struct { - sync.Mutex - funcMemo map[token.Pos]bool -} - -func NewFuncCache() *FuncCache { - return &FuncCache{ - funcMemo: make(map[token.Pos]bool), - } -} - -func (m *FuncCache) Set(key token.Pos, value bool) { - m.Lock() - m.funcMemo[key] = value - m.Unlock() -} - -func (m *FuncCache) Exists(key token.Pos) bool { - m.Lock() - _, exist := m.funcMemo[key] - m.Unlock() - return exist -} - -func (m *FuncCache) Get(key token.Pos) bool { - m.Lock() - value := m.funcMemo[key] - m.Unlock() - return value -} - -// searchCache manages whether if already searched this node -var searchCache *SearchCache - -// reportCache manages whether if this line already reported -var reportCache *ReportCache - -// funcCache manages whether if this function contains queries -var funcCache *FuncCache var configPath string // Analyzer is analysis files @@ -138,95 +38,10 @@ var Analyzer = &analysis.Analyzer{ } var sqlTypes []string -func init() { - Analyzer.Flags.StringVar(&configPath, "configPath", "", "config file path(abs)") -} - -func appendTypes(pass *analysis.Pass, pkg, name string) { - - //TODO Use types instead of the name of types - //if typ := analysisutil.TypeOf(pass, pkg, name); typ != nil { - // sqlTypes = append(sqlTypes, typ) - //} else { - // if name[0] == '*' { - // name = name[1:] - // pkg = "*" + pkg - // } - // for _, v := range pass.TypesInfo.Types { - // if v.Type.String() == pkg+"."+name { - // sqlTypes = append(sqlTypes, v.Type) - // return - // } - // } - //} - - if name[0] == '*' { - name = name[1:] - pkg = "*" + pkg - } - sqlTypes = append(sqlTypes, pkg+"."+name) -} - -func fileExists(filename string) bool { - _, err := os.Stat(filename) - return err == nil -} - -func readTypeConfig() *Types { - var cp string - // if configPath flag is not set - if configPath ==""{ - - curDir, _ := os.Getwd() - for !fileExists(curDir+"/goone.yml"){ - // Search up to the root - if curDir == filepath.Dir(curDir) || curDir == ""{ - // If goone.yml is not found - return nil - } - curDir = filepath.Dir(curDir) - } - cp = curDir+"/goone.yml" - }else{ - cp = configPath - } - - buf, err := ioutil.ReadFile(cp) - if err != nil { - return nil - } - typesFromConfig := Types{} - err = yaml.Unmarshal([]byte(buf), &typesFromConfig) - if err != nil { - log.Fatalf("yml parse error:%v", err) - } - - return &typesFromConfig -} - -func prepareTypes(pass *analysis.Pass){ - typesFromConfig := readTypeConfig() - if typesFromConfig != nil { - for _, pkg := range typesFromConfig.Package { - pkgName := pkg.PkgName - for _, typeName := range pkg.TypeNames { - appendTypes(pass, pkgName, typeName.TypeName) - } - } - } - - appendTypes(pass, "database/sql", "*DB") - appendTypes(pass, "gorm.io/gorm", "*DB") - appendTypes(pass, "gopkg.in/gorp.v1", "*DbMap") - appendTypes(pass, "gopkg.in/gorp.v2", "*DbMap") - appendTypes(pass, "github.com/go-gorp/gorp/v3", "*DbMap") - appendTypes(pass, "github.com/jmoiron/sqlx", "*DB") -} - func run(pass *analysis.Pass) (interface{}, error) { - searchCache = NewSearchCache() funcCache = NewFuncCache() reportCache = NewReportCache() + pkgCache = NewPkgCache() inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) prepareTypes(pass) forFilter := []ast.Node{ @@ -237,7 +52,7 @@ func run(pass *analysis.Pass) (interface{}, error) { inspect.Preorder(forFilter, func(n ast.Node) { switch n := n.(type) { case *ast.ForStmt, *ast.RangeStmt: - findQuery(pass, n, nil, nil) + findQuery(pass, n, nil, nil) } }) return nil, nil @@ -246,12 +61,11 @@ func run(pass *analysis.Pass) (interface{}, error) { func inspectFile(pass *analysis.Pass, parentNode ast.Node, file *ast.File, funcName string, pkgTypes *types.Info) { inspect := inspector.New([]*ast.File{file}) types := []ast.Node{new(ast.FuncDecl)} - inspect.Preorder(types, func(n ast.Node) { switch n := n.(type) { case *ast.FuncDecl: if n.Name.Name == funcName { - findQuery(pass, n, parentNode, pkgTypes) + findQuery(pass, n, parentNode, pkgTypes) } } }) @@ -264,6 +78,7 @@ func loadImportPackages(pkgName string) (pkgs []*packages.Package) { if err != nil { log.Fatalf(err.Error()) } + pkgCache.Set(pkgName,true) return } @@ -273,11 +88,21 @@ func anotherFileSearch(pass *analysis.Pass, funcExpr *ast.Ident, parentNode ast. if file == nil { if anotherFileNode.Pkg() != nil { importPath := convertToImportPath(pass, anotherFileNode.Pkg().Name()) - pkgs := loadImportPackages(importPath) - for _, pkg := range pkgs { - for _, file := range pkg.Syntax { - inspectFile(pass, parentNode, file, funcExpr.Name, pkg.TypesInfo) + if pkgCache.Exists(importPath){ + // Check whether if already checked function + if containsQuery, ok := funcCache.Get(funcExpr.Pos()); ok{ + if containsQuery { + pass.Reportf(parentNode.Pos(), "this query is called in a loop") + } + return false + } + } + // If the package has not been loaded yet, load it. + pkgs := loadImportPackages(importPath) + for i := range pkgs { + for j := range pkgs[i].Syntax { + inspectFile(pass, parentNode, pkgs[i].Syntax[j], funcExpr.Name, pkgs[i].TypesInfo) } } } @@ -291,8 +116,7 @@ func anotherFileSearch(pass *analysis.Pass, funcExpr *ast.Ident, parentNode ast. func findQuery(pass *analysis.Pass, rootNode, parentNode ast.Node, pkgTypes *types.Info) { //nolint:gocognit - if funcCache.Exists(rootNode.Pos()) { - if funcCache.Get(rootNode.Pos()) { + if containsQuery, ok := funcCache.Get(rootNode.Pos()); ok && containsQuery{ reportCache.Lock() if reportCache.Get(pass, parentNode.Pos()) { @@ -304,7 +128,6 @@ func findQuery(pass *analysis.Pass, rootNode, parentNode ast.Node, pkgTypes *typ pass.Reportf(parentNode.Pos(), "this query is called in a loop") - } return } @@ -327,9 +150,9 @@ func findQuery(pass *analysis.Pass, rootNode, parentNode ast.Node, pkgTypes *typ if reportNode == nil { reportNode = node } - for _, typ := range sqlTypes { + for i := range sqlTypes { //TODO Comparing by string is bad, but I don't have any ideas to compare another package's type - if strings.TrimPrefix(tv.Type.String(), "vendor/") == typ || strings.TrimPrefix(tv.Type.String(), "*vendor/") == strings.TrimPrefix(typ, "*") { + if strings.TrimPrefix(tv.Type.String(), "vendor/") == sqlTypes[i] || strings.TrimPrefix(tv.Type.String(), "*vendor/") == strings.TrimPrefix(sqlTypes[i], "*") { //if types.Identical(tv.Type,typ){ reportCache.Lock() if reportCache.Get(pass, reportNode.Pos()) { @@ -357,16 +180,14 @@ func findQuery(pass *analysis.Pass, rootNode, parentNode ast.Node, pkgTypes *typ } switch decl := obj.Decl.(type) { case *ast.FuncDecl: - if !searchCache.Get(decl.Pos()) { - searchCache.Set(decl.Pos(), true) + if _, ok := funcCache.Get(decl.Pos()); !ok { newParentNode := parentNode if parentNode == nil { newParentNode = node } - - findQuery(pass, decl, newParentNode, nil) + findQuery(pass, decl, newParentNode, nil) } else { - if funcCache.Get(decl.Pos()) { + if containsQuery,_ := funcCache.Get(decl.Pos()); containsQuery{ reportCache.Lock() if reportCache.Get(pass, node.Pos()) { @@ -388,20 +209,28 @@ func findQuery(pass *analysis.Pass, rootNode, parentNode ast.Node, pkgTypes *typ //TODO get packageName without fmt.Sprinf importPath := convertToImportPath(pass, fmt.Sprintf("%s", funcExpr.X)) + if pkgCache.Exists(importPath){ + if containsQuery, ok := funcCache.Get(funcExpr.Pos()); ok{ + if containsQuery { + pass.Reportf(parentNode.Pos(), "this query is called in a loop") + } + return false + } + } + // If the package has not been loaded yet, load it pkgs := loadImportPackages(importPath) - for _, pkg := range pkgs { + for i := range pkgs { //Do not scan standard packages. Is it...ok? - if !strings.Contains(pkg.PkgPath, ".") { + if !strings.Contains(pkgs[i].PkgPath, ".") { continue } - for _, file := range pkg.Syntax { + for j := range pkgs[i].Syntax { newParentNode := parentNode if parentNode == nil { newParentNode = node } - - inspectFile(pass, newParentNode, file, funcExpr.Sel.Name, pkg.TypesInfo) + inspectFile(pass, newParentNode, pkgs[i].Syntax[j], funcExpr.Sel.Name, pkgs[i].TypesInfo) } } } @@ -415,9 +244,9 @@ func findQuery(pass *analysis.Pass, rootNode, parentNode ast.Node, pkgTypes *typ } func convertToImportPath(pass *analysis.Pass, pkgName string) (importPath string) { - for _, v := range pass.Pkg.Imports() { - if strings.HasSuffix("/"+v.Path(), pkgName) { - importPath = v.Path() + for i := range pass.Pkg.Imports() { + if strings.HasSuffix("/"+pass.Pkg.Imports()[i].Path(), pkgName) { + importPath = pass.Pkg.Imports()[i].Path() if strings.HasPrefix(importPath, "vendor/") { importPath = strings.TrimPrefix(importPath, "vendor/") } diff --git a/initialize.go b/initialize.go new file mode 100644 index 0000000..297dfae --- /dev/null +++ b/initialize.go @@ -0,0 +1,95 @@ +package goone + +import ( + "golang.org/x/tools/go/analysis" + "gopkg.in/yaml.v2" + "io/ioutil" + "log" + "os" + "path/filepath" +) + +func init() { + Analyzer.Flags.StringVar(&configPath, "configPath", "", "config file path(abs)") +} + +func appendTypes(pass *analysis.Pass, pkg, name string) { + + //TODO Use types instead of the name of types + //if typ := analysisutil.TypeOf(pass, pkg, name); typ != nil { + // sqlTypes = append(sqlTypes, typ) + //} else { + // if name[0] == '*' { + // name = name[1:] + // pkg = "*" + pkg + // } + // for _, v := range pass.TypesInfo.Types { + // if v.Type.String() == pkg+"."+name { + // sqlTypes = append(sqlTypes, v.Type) + // return + // } + // } + //} + + if name[0] == '*' { + name = name[1:] + pkg = "*" + pkg + } + sqlTypes = append(sqlTypes, pkg+"."+name) +} + +func fileExists(filename string) bool { + _, err := os.Stat(filename) + return err == nil +} + +func readTypeConfig() *Types { + var cp string + // if configPath flag is not set + if configPath ==""{ + + curDir, _ := os.Getwd() + for !fileExists(curDir+"/goone.yml"){ + // Search up to the root + if curDir == filepath.Dir(curDir) || curDir == ""{ + // If goone.yml is not found + return nil + } + curDir = filepath.Dir(curDir) + } + cp = curDir+"/goone.yml" + }else{ + cp = configPath + } + + buf, err := ioutil.ReadFile(cp) + if err != nil { + return nil + } + typesFromConfig := Types{} + err = yaml.Unmarshal([]byte(buf), &typesFromConfig) + if err != nil { + log.Fatalf("yml parse error:%v", err) + } + + return &typesFromConfig +} + +func prepareTypes(pass *analysis.Pass){ + typesFromConfig := readTypeConfig() + if typesFromConfig != nil { + for _, pkg := range typesFromConfig.Package { + pkgName := pkg.PkgName + for _, typeName := range pkg.TypeNames { + appendTypes(pass, pkgName, typeName.TypeName) + } + } + } + + appendTypes(pass, "database/sql", "*DB") + appendTypes(pass, "gorm.io/gorm", "*DB") + appendTypes(pass, "gopkg.in/gorp.v1", "*DbMap") + appendTypes(pass, "gopkg.in/gorp.v2", "*DbMap") + appendTypes(pass, "github.com/go-gorp/gorp/v3", "*DbMap") + appendTypes(pass, "github.com/jmoiron/sqlx", "*DB") +} \ No newline at end of file diff --git a/testdata/src/go.mod b/testdata/src/go.mod index 63226c4..504a16d 100644 --- a/testdata/src/go.mod +++ b/testdata/src/go.mod @@ -1,11 +1,11 @@ module github.com/masibw/goone/testdata/src -go 1.14 +go 1.16 require ( github.com/go-sql-driver/mysql v1.5.0 github.com/jmoiron/sqlx v1.2.0 - github.com/masibw/goone_test v0.0.0-20210109133101-80916951cb8c + github.com/masibw/goone_test v0.0.0-20210112093021-7d2e0b363db0 github.com/ziutek/mymysql v1.5.4 // indirect gopkg.in/gorp.v1 v1.7.2 gorm.io/driver/mysql v1.0.3 diff --git a/testdata/src/go.sum b/testdata/src/go.sum index 23eba8c..7ad4e45 100644 --- a/testdata/src/go.sum +++ b/testdata/src/go.sum @@ -11,6 +11,8 @@ github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/masibw/goone_test v0.0.0-20210109133101-80916951cb8c h1:p/c3QqAkPJAYH0dIAPn4T1sDPpnxbf69w71tmTWN6Mw= github.com/masibw/goone_test v0.0.0-20210109133101-80916951cb8c/go.mod h1:yBWoicU1E30NC++4C6bor5y7dCFrobTb0jGVEPpH98Q= +github.com/masibw/goone_test v0.0.0-20210112093021-7d2e0b363db0 h1:PEBlkDmuNQMxcBaznRvs2sP0UqVPoXsNx5oWWK3MmX0= +github.com/masibw/goone_test v0.0.0-20210112093021-7d2e0b363db0/go.mod h1:yBWoicU1E30NC++4C6bor5y7dCFrobTb0jGVEPpH98Q= github.com/mattn/go-sqlite3 v1.9.0 h1:pDRiWfl+++eC2FEFRy6jXmQlvp4Yh3z1MJKg4UeYM/4= github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/ziutek/mymysql v1.5.4 h1:GB0qdRGsTwQSBVYuVShFBKaXSnSnYYC2d9knnE1LHFs=