diff --git a/cmd/version.go b/cmd/version.go index ec22c99..6ec9184 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -6,7 +6,7 @@ import ( "github.com/spf13/cobra" ) -const Version = "1.0.4" +const Version = "1.0.6" var versionCmd = &cobra.Command{ Use: "version", diff --git a/go.mod b/go.mod index fa3fb6f..bcc1017 100644 --- a/go.mod +++ b/go.mod @@ -6,13 +6,13 @@ require ( github.com/bmatcuk/doublestar/v4 v4.8.1 github.com/pkg/errors v0.9.1 github.com/spf13/cobra v1.9.1 - golang.org/x/tools v0.21.0 + golang.org/x/tools v0.24.0 gopkg.in/yaml.v3 v3.0.1 ) require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/spf13/pflag v1.0.6 // indirect - golang.org/x/mod v0.17.0 // indirect - golang.org/x/sync v0.7.0 // indirect + golang.org/x/mod v0.20.0 // indirect + golang.org/x/sync v0.8.0 // indirect ) diff --git a/go.sum b/go.sum index 6411f59..d89572c 100644 --- a/go.sum +++ b/go.sum @@ -10,12 +10,12 @@ github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw= -golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= +golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= +golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/parser/dependency.go b/parser/dependency.go index f88840b..252f8e4 100644 --- a/parser/dependency.go +++ b/parser/dependency.go @@ -27,9 +27,9 @@ type interfaceInfo struct { // interfaceImplementations 存储接口和实现类型的映射关系 type interfaceImplementations struct { - // 接口ID -> 实现该接口的类型ID列表 + // InterfaceID -> 实现该接口的类型ID列表 ImplementersMap map[string][]string - // 接口ID -> 接口方法名称 -> 实现该方法的类型ID列表 + // InterfaceID -> 接口方法名称 -> 实现该方法的类型ID列表 MethodImplementersMap map[string]map[string][]string } @@ -54,7 +54,7 @@ func GetObjectID(pkg string, fileName string, obj string) string { type DependencyInfo struct { // 项目内所有的顶级声明, key: NodeID, value: Node nodes map[string]*node - // 依赖图的反向图, key: 依赖的节点ID, value: 集合(set)内存放被依赖的节点ID + // 依赖图的反向图, key: NodeID, value: 依赖key的NodeID列表 revGraph Graph } @@ -64,11 +64,10 @@ func (d *DependencyInfo) GetDependency(targetID string) ([]string, error) { return nil, fmt.Errorf("target %s is not defined in project", targetID) } - // 利用深度优先搜索(DFS)查找所有直接或间接依赖 targetID 的节点 visited := make(map[string]struct{}) var dfs func(string) dfs = func(node string) { - for dep, _ := range d.revGraph[node] { + for dep := range d.revGraph[node] { if _, ok := visited[dep]; !ok { visited[dep] = struct{}{} dfs(dep) @@ -77,10 +76,6 @@ func (d *DependencyInfo) GetDependency(targetID string) ([]string, error) { } dfs(targetID) - if len(visited) == 0 { - return []string{}, nil - } - deps := make([]string, 0, len(visited)) for id := range visited { deps = append(deps, id) @@ -98,7 +93,7 @@ func BuildDependency(pkgs []*packages.Package) (*DependencyInfo, error) { nodesInfo := make(map[string]*node) // 所有接口信息 - // key: 接口的fullName表示, value: 接口信息 + // key: InterfaceID, value: 接口信息 interfacesInfo := make(map[string]*interfaceInfo) // 接口与实现类型的映射关系 @@ -107,15 +102,15 @@ func BuildDependency(pkgs []*packages.Package) (*DependencyInfo, error) { MethodImplementersMap: make(map[string]map[string][]string), } - // 依赖图: key->当前节点ID, value->集合(set)内存放依赖的节点ID + // 依赖图: key->NodeID, value->依赖Key的NodeID列表 graph := make(Graph) // 遍历所有包和文件,提取顶级声明,构建接口表 for _, pkg := range pkgs { fset := pkg.Fset for _, file := range pkg.Syntax { - fullFilename := fset.File(file.Pos()).Name() // 文件的绝对路径 - baseFilename := filepath.Base(fullFilename) // 文件名 + fullFilename := fset.File(file.Pos()).Name() + baseFilename := filepath.Base(fullFilename) // 遍历文件中的所有顶级声明 for _, decl := range file.Decls { switch d := decl.(type) { @@ -153,6 +148,7 @@ func BuildDependency(pkgs []*packages.Package) (*DependencyInfo, error) { continue } id := GetObjectID(pkg.ID, baseFilename, ident.Name) + obj := pkg.TypesInfo.Defs[ident] if obj == nil { continue @@ -189,6 +185,45 @@ func BuildDependency(pkgs []*packages.Package) (*DependencyInfo, error) { graph[id] = make(map[string]struct{}) } + // 处理泛型类型参数 + typeParams := make(map[string]*types.TypeParam) + if s.TypeParams != nil && len(s.TypeParams.List) > 0 { + for _, field := range s.TypeParams.List { + for _, name := range field.Names { + // 获取类型参数对象 + typeParamObj := pkg.TypesInfo.Defs[name] + if typeParamObj == nil { + continue + } + + // 转换为TypeParam类型 + typeParam, ok := typeParamObj.Type().(*types.TypeParam) + if !ok { + continue + } + + // 存储类型参数信息 + typeParams[name.Name] = typeParam + + // 处理类型参数之间的依赖关系 + if field.Type != nil { + // 解析类型参数的约束,添加对其他类型的依赖 + ast.Inspect(field.Type, func(n ast.Node) bool { + if ident, ok := n.(*ast.Ident); ok && ident.Name != name.Name { + // 获取约束中引用的类型 + if refObj := pkg.TypesInfo.Uses[ident]; refObj != nil { + if depID, exists := nodesMap[refObj]; exists && depID != id { + graph[id][depID] = struct{}{} + } + } + } + return true + }) + } + } + } + } + // 检查是否为接口定义 if t, ok := s.Type.(*ast.InterfaceType); ok { // 记录接口信息 @@ -212,6 +247,25 @@ func BuildDependency(pkgs []*packages.Package) (*DependencyInfo, error) { } iface.Methods[name.Name] = methodFunc + + // 处理方法中使用的泛型参数 + ast.Inspect(method.Type, func(n ast.Node) bool { + if ident, ok := n.(*ast.Ident); ok { + // 检查是否引用了类型参数 + if _, exists := typeParams[ident.Name]; exists { + // 方法使用了泛型参数,这里可以做额外处理 + // 例如记录方法与泛型参数的关联 + } + + // 检查是否引用了其他类型 + if refObj := pkg.TypesInfo.Uses[ident]; refObj != nil { + if depID, exists := nodesMap[refObj]; exists && depID != id { + graph[id][depID] = struct{}{} + } + } + } + return true + }) } } } @@ -228,6 +282,19 @@ func BuildDependency(pkgs []*packages.Package) (*DependencyInfo, error) { for methodName := range iface.Methods { interfaceImpls.MethodImplementersMap[id][methodName] = []string{} } + } else { + // 非接口类型定义,处理泛型类型参数在类型定义中的使用 + ast.Inspect(s.Type, func(n ast.Node) bool { + if ident, ok := n.(*ast.Ident); ok { + // 检查是否引用了其他类型 + if refObj := pkg.TypesInfo.Uses[ident]; refObj != nil { + if depID, exists := nodesMap[refObj]; exists && depID != id { + graph[id][depID] = struct{}{} + } + } + } + return true + }) } } } @@ -235,10 +302,8 @@ func BuildDependency(pkgs []*packages.Package) (*DependencyInfo, error) { } } } - // key: 类型ID, value: 方法名 -> 节点ID typeMethodsMap := make(map[string]map[string]string) - // 提取所有类型的方法 for nodeID, node := range nodesInfo { // 检查是否为方法声明(形如 (Type).Method 的名称) @@ -320,7 +385,6 @@ func BuildDependency(pkgs []*packages.Package) (*DependencyInfo, error) { if methodNode == nil || methodNode.Obj == nil { continue } - typeMethod, ok := methodNode.Obj.(*types.Func) if !ok { continue @@ -361,6 +425,93 @@ func BuildDependency(pkgs []*packages.Package) (*DependencyInfo, error) { // 辅助函数:处理一个AST节点(函数体或变量初始化表达式)来查找依赖的顶级对象 collectDependencies := func(n ast.Node, curNodeID string, pkg *packages.Package) { ast.Inspect(n, func(n ast.Node) bool { + // 处理类型转换和类型断言 + if typeAssert, ok := n.(*ast.TypeAssertExpr); ok { + if typeAssert.Type != nil { + ast.Inspect(typeAssert.Type, func(n ast.Node) bool { + if ident, ok := n.(*ast.Ident); ok { + if obj := pkg.TypesInfo.Uses[ident]; obj != nil { + if depID, ok := nodesMap[obj]; ok && depID != curNodeID { + graph[curNodeID][depID] = struct{}{} + } + } + } + return true + }) + } + } + + // 处理结构体字面量 + if compLit, ok := n.(*ast.CompositeLit); ok { + if compLit.Type != nil { + ast.Inspect(compLit.Type, func(n ast.Node) bool { + if ident, ok := n.(*ast.Ident); ok { + if obj := pkg.TypesInfo.Uses[ident]; obj != nil { + if depID, ok := nodesMap[obj]; ok && depID != curNodeID { + graph[curNodeID][depID] = struct{}{} + } + } + } + return true + }) + } + // 处理结构体字段 + for _, elt := range compLit.Elts { + if kv, ok := elt.(*ast.KeyValueExpr); ok { + if ident, ok := kv.Key.(*ast.Ident); ok { + if obj := pkg.TypesInfo.Uses[ident]; obj != nil { + if depID, ok := nodesMap[obj]; ok && depID != curNodeID { + graph[curNodeID][depID] = struct{}{} + } + } + } + } + } + } + + // 处理函数类型(包括参数和返回值) + if funcType, ok := n.(*ast.FuncType); ok { + // 处理参数 + if funcType.Params != nil { + for _, field := range funcType.Params.List { + ast.Inspect(field.Type, func(n ast.Node) bool { + if ident, ok := n.(*ast.Ident); ok { + if obj := pkg.TypesInfo.Uses[ident]; obj != nil { + if depID, ok := nodesMap[obj]; ok && depID != curNodeID { + graph[curNodeID][depID] = struct{}{} + } + } + } + return true + }) + } + } + // 处理返回值 + if funcType.Results != nil { + for _, field := range funcType.Results.List { + ast.Inspect(field.Type, func(n ast.Node) bool { + if ident, ok := n.(*ast.Ident); ok { + if obj := pkg.TypesInfo.Uses[ident]; obj != nil { + if depID, ok := nodesMap[obj]; ok && depID != curNodeID { + graph[curNodeID][depID] = struct{}{} + } + } + } + return true + }) + } + } + } + + // 处理标签 + if label, ok := n.(*ast.LabeledStmt); ok { + if obj := pkg.TypesInfo.Defs[label.Label]; obj != nil { + if depID, ok := nodesMap[obj]; ok && depID != curNodeID { + graph[curNodeID][depID] = struct{}{} + } + } + } + // 处理选择器表达式(如 a.b 形式的调用) if sel, ok := n.(*ast.SelectorExpr); ok { var obj types.Object @@ -398,6 +549,7 @@ func BuildDependency(pkgs []*packages.Package) (*DependencyInfo, error) { // 查找右侧方法调用 methodName := sel.Sel.Name + // 检查对象类型是否是接口类型 _, isIface := objType.Underlying().(*types.Interface) @@ -408,6 +560,135 @@ func BuildDependency(pkgs []*packages.Package) (*DependencyInfo, error) { } } + // 检查是否是泛型方法调用 + if methodType, ok := pkg.TypesInfo.Types[sel].Type.(*types.Signature); ok { + // 处理泛型方法的类型参数 + if methodType.TypeParams() != nil && methodType.TypeParams().Len() > 0 { + // 处理泛型方法的类型参数 + for i := 0; i < methodType.TypeParams().Len(); i++ { + typeParam := methodType.TypeParams().At(i) + if depID, ok := nodesMap[typeParam.Obj()]; ok && depID != curNodeID { + graph[curNodeID][depID] = struct{}{} + } + } + } + + // 处理泛型方法的约束 + if methodType.TypeParams() != nil && methodType.TypeParams().Len() > 0 { + for i := 0; i < methodType.TypeParams().Len(); i++ { + typeParam := methodType.TypeParams().At(i) + // 获取类型参数的约束 + if constraint := typeParam.Constraint(); constraint != nil { + // 处理约束中的类型 + ast.Inspect(sel, func(n ast.Node) bool { + if ident, ok := n.(*ast.Ident); ok { + if obj := pkg.TypesInfo.Uses[ident]; obj != nil { + if depID, ok := nodesMap[obj]; ok && depID != curNodeID { + graph[curNodeID][depID] = struct{}{} + } + } + } + return true + }) + } + } + } + + // 处理泛型方法的接收器 + if recv := methodType.Recv(); recv != nil { + recvType := recv.Type() + // 处理指针接收器 + if ptr, ok := recvType.(*types.Pointer); ok { + recvType = ptr.Elem() + } + // 处理命名类型 + if named, ok := recvType.(*types.Named); ok { + // 处理接收器类型的类型参数 + if named.TypeParams() != nil && named.TypeParams().Len() > 0 { + for i := 0; i < named.TypeParams().Len(); i++ { + typeParam := named.TypeParams().At(i) + if depID, ok := nodesMap[typeParam.Obj()]; ok && depID != curNodeID { + graph[curNodeID][depID] = struct{}{} + } + } + } + // 处理接收器类型本身 + if depID, ok := nodesMap[named.Obj()]; ok && depID != curNodeID { + graph[curNodeID][depID] = struct{}{} + } + } + } + } + + // 处理泛型方法的调用参数 + if call, ok := sel.X.(*ast.CallExpr); ok { + // 处理调用参数中的类型 + for _, arg := range call.Args { + ast.Inspect(arg, func(n ast.Node) bool { + if ident, ok := n.(*ast.Ident); ok { + if obj := pkg.TypesInfo.Uses[ident]; obj != nil { + if depID, ok := nodesMap[obj]; ok && depID != curNodeID { + graph[curNodeID][depID] = struct{}{} + } + } + } + return true + }) + } + } + + // 处理泛型方法的类型参数使用 + if methodType, ok := pkg.TypesInfo.Types[sel].Type.(*types.Signature); ok { + // 处理参数中的类型参数使用 + if methodType.Params() != nil { + for i := 0; i < methodType.Params().Len(); i++ { + param := methodType.Params().At(i) + paramType := param.Type() + // 检查参数类型是否是类型参数 + if typeParam, ok := paramType.(*types.TypeParam); ok { + if depID, ok := nodesMap[typeParam.Obj()]; ok && depID != curNodeID { + graph[curNodeID][depID] = struct{}{} + } + } + // 检查参数类型是否包含类型参数 + if named, ok := paramType.(*types.Named); ok { + if named.TypeParams() != nil && named.TypeParams().Len() > 0 { + for j := 0; j < named.TypeParams().Len(); j++ { + typeParam := named.TypeParams().At(j) + if depID, ok := nodesMap[typeParam.Obj()]; ok && depID != curNodeID { + graph[curNodeID][depID] = struct{}{} + } + } + } + } + } + } + // 处理返回值中的类型参数使用 + if methodType.Results() != nil { + for i := 0; i < methodType.Results().Len(); i++ { + result := methodType.Results().At(i) + resultType := result.Type() + // 检查返回值类型是否是类型参数 + if typeParam, ok := resultType.(*types.TypeParam); ok { + if depID, ok := nodesMap[typeParam.Obj()]; ok && depID != curNodeID { + graph[curNodeID][depID] = struct{}{} + } + } + // 检查返回值类型是否包含类型参数 + if named, ok := resultType.(*types.Named); ok { + if named.TypeParams() != nil && named.TypeParams().Len() > 0 { + for j := 0; j < named.TypeParams().Len(); j++ { + typeParam := named.TypeParams().At(j) + if depID, ok := nodesMap[typeParam.Obj()]; ok && depID != curNodeID { + graph[curNodeID][depID] = struct{}{} + } + } + } + } + } + } + } + if isIface { // 它是接口类型,需要查找所有实现了这个接口的类型 // 首先尝试查找精确匹配的接口(如果已经记录在 interfacesInfo 中) @@ -471,6 +752,12 @@ func BuildDependency(pkgs []*packages.Package) (*DependencyInfo, error) { if obj == nil { return true } + // 如果obj是函数,则获取其原始函数 + if f, ok := obj.(*types.Func); ok { + if f.Origin() != nil { + obj = f.Origin() + } + } // 判断这个对象是否在我们的顶级声明中(只考虑同一项目内部) if depID, ok := nodesMap[obj]; ok { // 避免自引用 @@ -495,6 +782,55 @@ func BuildDependency(pkgs []*packages.Package) (*DependencyInfo, error) { } funcName := GetFuncOrMethodName(d) curID := GetObjectID(pkg.ID, baseFilename, funcName) + + // 处理函数参数 + if d.Type.Params != nil { + for _, field := range d.Type.Params.List { + ast.Inspect(field.Type, func(n ast.Node) bool { + if ident, ok := n.(*ast.Ident); ok { + if obj := pkg.TypesInfo.Uses[ident]; obj != nil { + if depID, ok := nodesMap[obj]; ok && depID != curID { + graph[curID][depID] = struct{}{} + } + } + } + return true + }) + } + } + + // 处理函数返回值 + if d.Type.Results != nil { + for _, field := range d.Type.Results.List { + ast.Inspect(field.Type, func(n ast.Node) bool { + if ident, ok := n.(*ast.Ident); ok { + if obj := pkg.TypesInfo.Uses[ident]; obj != nil { + if depID, ok := nodesMap[obj]; ok && depID != curID { + graph[curID][depID] = struct{}{} + } + } + } + return true + }) + } + } + + // 处理函数接收器 + if d.Recv != nil && len(d.Recv.List) > 0 { + for _, field := range d.Recv.List { + ast.Inspect(field.Type, func(n ast.Node) bool { + if ident, ok := n.(*ast.Ident); ok { + if obj := pkg.TypesInfo.Uses[ident]; obj != nil { + if depID, ok := nodesMap[obj]; ok && depID != curID { + graph[curID][depID] = struct{}{} + } + } + } + return true + }) + } + } + // 处理函数体 collectDependencies(d.Body, curID, pkg) case *ast.GenDecl: for _, spec := range d.Specs { @@ -514,17 +850,6 @@ func BuildDependency(pkgs []*packages.Package) (*DependencyInfo, error) { } } - // 输出接口实现的信息(用于调试) - // fmt.Println("\n接口实现关系:") - // for ifaceID, impls := range interfaceImpls.ImplementersMap { - // if len(impls) > 0 { - // fmt.Printf("接口 %s 的实现类型:\n", ifaceID) - // for _, impl := range impls { - // fmt.Printf(" - %s\n", impl) - // } - // } - // } - // 构建反向图 revGraph := make(Graph) for nodeID, deps := range graph { @@ -572,6 +897,27 @@ func GetFuncOrMethodName(fn *ast.FuncDecl) string { // signaturesCompatible 检查类型方法的签名是否与接口方法的签名兼容 func signaturesCompatible(ifaceMethodSig, typeMethodSig *types.Signature) bool { + // 检查接口方法是否有类型参数 + ifaceHasTypeParams := ifaceMethodSig.TypeParams() != nil && ifaceMethodSig.TypeParams().Len() > 0 + + // 检查接收器是否有类型参数 + ifaceHasRecvTypeParams := false + if recv := ifaceMethodSig.Recv(); recv != nil { + // 获取接收器类型 + recvType := recv.Type() + + // 检查是否是命名类型 + if named, ok := recvType.(*types.Named); ok { + // 检查命名类型是否有类型参数 + if named.TypeParams() != nil && named.TypeParams().Len() > 0 { + ifaceHasRecvTypeParams = true + } + } + } + + // 如果接口方法有类型参数或接收器有类型参数,标记为泛型方法 + hasGenericParams := ifaceHasTypeParams || ifaceHasRecvTypeParams + // 检查参数数量是否相同 if ifaceMethodSig.Params().Len() != typeMethodSig.Params().Len() { return false @@ -587,7 +933,140 @@ func signaturesCompatible(ifaceMethodSig, typeMethodSig *types.Signature) bool { ifaceParam := ifaceMethodSig.Params().At(i) typeParam := typeMethodSig.Params().At(i) - if !types.AssignableTo(typeParam.Type(), ifaceParam.Type()) { + ifaceParamType := ifaceParam.Type() + typeParamType := typeParam.Type() + + // 如果接口有类型参数 + if hasGenericParams { + // 检查参数本身是否是类型参数 + _, isTypeParam := ifaceParamType.(*types.TypeParam) + if isTypeParam { + // 泛型类型参数可以匹配任何类型,跳过详细检查 + continue + } + + // 检查参数是否是包含泛型的复合类型 + + // 处理切片类型 []T + ifaceSlice, ifaceIsSlice := ifaceParamType.(*types.Slice) + _, typeIsSlice := typeParamType.(*types.Slice) + if ifaceIsSlice && typeIsSlice { + // 获取切片元素类型 + ifaceElem := ifaceSlice.Elem() + + // 检查元素是否是类型参数 + _, elemIsTypeParam := ifaceElem.(*types.TypeParam) + if elemIsTypeParam { + // 如果接口参数是[]T,T是类型参数,则认为兼容任何切片类型 + continue + } + + // 检查是否是命名类型且包含类型参数 + if named, ok := ifaceElem.(*types.Named); ok { + if named.TypeParams() != nil && named.TypeParams().Len() > 0 { + // 如果元素是泛型类型,也认为兼容 + continue + } + } + } + + // 处理map类型 map[K]V + ifaceMap, ifaceIsMap := ifaceParamType.(*types.Map) + _, typeIsMap := typeParamType.(*types.Map) + if ifaceIsMap && typeIsMap { + // 获取map的键和值类型 + ifaceKey := ifaceMap.Key() + ifaceVal := ifaceMap.Elem() + + // 检查键或值是否是类型参数 + _, keyIsTypeParam := ifaceKey.(*types.TypeParam) + _, valIsTypeParam := ifaceVal.(*types.TypeParam) + + // 检查键和值是否是命名类型且包含类型参数 + keyHasGenericParams := false + valHasGenericParams := false + + if named, ok := ifaceKey.(*types.Named); ok { + if named.TypeParams() != nil && named.TypeParams().Len() > 0 { + keyHasGenericParams = true + } + } + + if named, ok := ifaceVal.(*types.Named); ok { + if named.TypeParams() != nil && named.TypeParams().Len() > 0 { + valHasGenericParams = true + } + } + + // 如果键或值是类型参数或泛型类型,则认为兼容 + if keyIsTypeParam || valIsTypeParam || keyHasGenericParams || valHasGenericParams { + continue + } + } + + // 处理通道类型 chan T + ifaceChan, ifaceIsChan := ifaceParamType.(*types.Chan) + typeChan, typeIsChan := typeParamType.(*types.Chan) + if ifaceIsChan && typeIsChan { + // 获取通道元素类型 + ifaceElem := ifaceChan.Elem() + + // 检查元素是否是类型参数 + _, elemIsTypeParam := ifaceElem.(*types.TypeParam) + + // 检查是否是命名类型且包含类型参数 + elemHasGenericParams := false + if named, ok := ifaceElem.(*types.Named); ok { + if named.TypeParams() != nil && named.TypeParams().Len() > 0 { + elemHasGenericParams = true + } + } + + // 如果元素是类型参数或泛型类型,并且通道方向相同,则认为兼容 + if (elemIsTypeParam || elemHasGenericParams) && (ifaceChan.Dir() == typeChan.Dir()) { + continue + } + } + + // 处理指针类型 *T + ifacePtr, ifaceIsPtr := ifaceParamType.(*types.Pointer) + _, typeIsPtr := typeParamType.(*types.Pointer) + if ifaceIsPtr && typeIsPtr { + // 获取指针的基础类型 + ifaceElem := ifacePtr.Elem() + + // 检查基础类型是否是类型参数 + _, elemIsTypeParam := ifaceElem.(*types.TypeParam) + + // 检查是否是命名类型且包含类型参数 + elemHasGenericParams := false + if named, ok := ifaceElem.(*types.Named); ok { + if named.TypeParams() != nil && named.TypeParams().Len() > 0 { + elemHasGenericParams = true + } + } + + // 如果基础类型是类型参数或泛型类型,则认为兼容 + if elemIsTypeParam || elemHasGenericParams { + continue + } + } + + // 检查命名类型是否包含类型参数 + if named, ok := ifaceParamType.(*types.Named); ok { + if named.TypeParams() != nil && named.TypeParams().Len() > 0 { + // 对于带有类型参数的命名类型,我们简化处理,认为兼容 + continue + } + } + } + + // 对于非泛型参数或未能识别的泛型场景,使用标准的类型赋值检查 + if !types.AssignableTo(typeParamType, ifaceParamType) { + if hasGenericParams { + //panic(fmt.Sprintf("类型不兼容: %v 不能赋值给 %v", typeParamType, ifaceParamType)) + return false + } return false } } @@ -597,21 +1076,146 @@ func signaturesCompatible(ifaceMethodSig, typeMethodSig *types.Signature) bool { ifaceResult := ifaceMethodSig.Results().At(i) typeResult := typeMethodSig.Results().At(i) - if !types.AssignableTo(typeResult.Type(), ifaceResult.Type()) { + ifaceResultType := ifaceResult.Type() + typeResultType := typeResult.Type() + + // 如果接口有类型参数 + if hasGenericParams { + // 检查返回值是否是类型参数 + _, isTypeParam := ifaceResultType.(*types.TypeParam) + if isTypeParam { + // 泛型类型参数可以匹配任何类型 + continue + } + + // 检查返回值是否是包含泛型的复合类型 + + // 处理切片类型 []T + ifaceSlice, ifaceIsSlice := ifaceResultType.(*types.Slice) + _, typeIsSlice := typeResultType.(*types.Slice) + if ifaceIsSlice && typeIsSlice { + // 获取切片元素类型 + ifaceElem := ifaceSlice.Elem() + + // 检查元素是否是类型参数 + _, elemIsTypeParam := ifaceElem.(*types.TypeParam) + + // 检查是否是命名类型且包含类型参数 + elemHasGenericParams := false + if named, ok := ifaceElem.(*types.Named); ok { + if named.TypeParams() != nil && named.TypeParams().Len() > 0 { + elemHasGenericParams = true + } + } + + // 如果元素是类型参数或泛型类型,则认为兼容 + if elemIsTypeParam || elemHasGenericParams { + continue + } + } + + // 处理map类型 map[K]V + ifaceMap, ifaceIsMap := ifaceResultType.(*types.Map) + _, typeIsMap := typeResultType.(*types.Map) + if ifaceIsMap && typeIsMap { + // 获取map的键和值类型 + ifaceKey := ifaceMap.Key() + ifaceVal := ifaceMap.Elem() + + // 检查键或值是否是类型参数 + _, keyIsTypeParam := ifaceKey.(*types.TypeParam) + _, valIsTypeParam := ifaceVal.(*types.TypeParam) + + // 检查键和值是否是命名类型且包含类型参数 + keyHasGenericParams := false + valHasGenericParams := false + + if named, ok := ifaceKey.(*types.Named); ok { + if named.TypeParams() != nil && named.TypeParams().Len() > 0 { + keyHasGenericParams = true + } + } + + if named, ok := ifaceVal.(*types.Named); ok { + if named.TypeParams() != nil && named.TypeParams().Len() > 0 { + valHasGenericParams = true + } + } + + // 如果键或值是类型参数或泛型类型,则认为兼容 + if keyIsTypeParam || valIsTypeParam || keyHasGenericParams || valHasGenericParams { + continue + } + } + + // 处理通道类型 chan T + ifaceChan, ifaceIsChan := ifaceResultType.(*types.Chan) + typeChan, typeIsChan := typeResultType.(*types.Chan) + if ifaceIsChan && typeIsChan { + // 获取通道元素类型 + ifaceElem := ifaceChan.Elem() + + // 检查元素是否是类型参数 + _, elemIsTypeParam := ifaceElem.(*types.TypeParam) + + // 检查是否是命名类型且包含类型参数 + elemHasGenericParams := false + if named, ok := ifaceElem.(*types.Named); ok { + if named.TypeParams() != nil && named.TypeParams().Len() > 0 { + elemHasGenericParams = true + } + } + + // 如果元素是类型参数或泛型类型,并且通道方向相同,则认为兼容 + if (elemIsTypeParam || elemHasGenericParams) && (ifaceChan.Dir() == typeChan.Dir()) { + continue + } + } + + // 处理指针类型 *T + ifacePtr, ifaceIsPtr := ifaceResultType.(*types.Pointer) + _, typeIsPtr := typeResultType.(*types.Pointer) + if ifaceIsPtr && typeIsPtr { + // 获取指针的基础类型 + ifaceElem := ifacePtr.Elem() + + // 检查基础类型是否是类型参数 + _, elemIsTypeParam := ifaceElem.(*types.TypeParam) + + // 检查是否是命名类型且包含类型参数 + elemHasGenericParams := false + if named, ok := ifaceElem.(*types.Named); ok { + if named.TypeParams() != nil && named.TypeParams().Len() > 0 { + elemHasGenericParams = true + } + } + + // 如果基础类型是类型参数或泛型类型,则认为兼容 + if elemIsTypeParam || elemHasGenericParams { + continue + } + } + + // 检查命名类型是否包含类型参数 + if named, ok := ifaceResultType.(*types.Named); ok { + if named.TypeParams() != nil && named.TypeParams().Len() > 0 { + // 对于带有类型参数的命名类型,我们简化处理,认为兼容 + continue + } + } + } + + // 对于非泛型返回值或未能识别的泛型场景,使用标准的类型赋值检查 + if !types.AssignableTo(typeResultType, ifaceResultType) { return false } } // 检查可变参数特性是否一致 - if ifaceMethodSig.Variadic() != typeMethodSig.Variadic() { - return false - } - - return true + return ifaceMethodSig.Variadic() == typeMethodSig.Variadic() } func LoadPackages(repo string) ([]*packages.Package, error) { - // 加载包信息 cfg := &packages.Config{ Mode: packages.NeedName | packages.NeedFiles | packages.NeedSyntax | packages.NeedTypes | packages.NeedTypesInfo, Dir: repo, diff --git a/parser/dependency_test.go b/parser/dependency_test.go new file mode 100644 index 0000000..692b620 --- /dev/null +++ b/parser/dependency_test.go @@ -0,0 +1,77 @@ +package parser + +import ( + "go/ast" + "go/parser" + "go/token" + "testing" +) + +func TestBuildDependency(t *testing.T) { + pkgs, err := LoadPackages("./material") + if err != nil { + t.Fatalf("failed to load packages: %v", err) + } + + depInfo, err := BuildDependency(pkgs) + if err != nil { + t.Fatalf("failed to build dependency: %v", err) + } + + t.Logf("depInfo: %v", depInfo) +} + +func TestGetFuncOrMethodName(t *testing.T) { + type args struct { + fn string + } + tests := []struct { + name string + args args + want string + }{ + // 函数 + {name: "func1", args: args{fn: "func Func1() {}"}, want: "Func1"}, + {name: "func2", args: args{fn: "func Func2() int { return 1 }"}, want: "Func2"}, + {name: "func3", args: args{fn: "func Func3() (int, error) { return 1, nil }"}, want: "Func3"}, + {name: "func4", args: args{fn: "func Func4(a int, b int) (int, error) { return a + b, nil }"}, want: "Func4"}, + {name: "func5", args: args{fn: "func Func5(a int, b int) (int, error) { return a + b, nil }"}, want: "Func5"}, + // 方法 + {name: "method1", args: args{fn: "func (m *MyType) Method1() { return 1 }"}, want: "(*MyType).Method1"}, + {name: "method2", args: args{fn: "func (m *MyType) Method2() (int, error) { return 1, nil }"}, want: "(*MyType).Method2"}, + // 范型 + {name: "genericFunc1", args: args{fn: "func GenericFunc[T any](a T) T { return a }"}, want: "GenericFunc"}, + {name: "genericFunc2", args: args{fn: "func GenericFunc2[T any, U any](a T, b U) (T, U) { return a, b }"}, want: "GenericFunc2"}, + {name: "genericFunc3", args: args{fn: "func GenericFunc3[T any](a T, b int) (T, int) { return a, b }"}, want: "GenericFunc3"}, + // 泛型方法 + {name: "genericMethod1", args: args{fn: "func (a *AutoFlushBuffer[T]) WriteMessage(msg T) { return a }"}, want: "(*AutoFlushBuffer[T]).WriteMessage"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + src := "package test\n\n" + tt.args.fn + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", src, parser.ParseComments) + if err != nil { + t.Fatalf("failed to parse file: %v", err) + } + + // 获取第一个函数声明 + var funcDecl *ast.FuncDecl + for _, decl := range f.Decls { + if fd, ok := decl.(*ast.FuncDecl); ok { + funcDecl = fd + break + } + } + + if funcDecl == nil { + t.Fatal("no function declaration found") + } + + if got := GetFuncOrMethodName(funcDecl); got != tt.want { + t.Errorf("GetFuncOrMethodName() = %s, want %v", got, tt.want) + } + }) + } +} diff --git a/parser/material/bufutil/bufutil.go b/parser/material/bufutil/bufutil.go new file mode 100644 index 0000000..6989fd3 --- /dev/null +++ b/parser/material/bufutil/bufutil.go @@ -0,0 +1,93 @@ +package bufutil + +import ( + "context" + "log" + "sync" + "time" +) + +type BatchWriter[T any] interface { + // BatchWrite 由写入者提供,将内容批量写入,并返回成功写入的数量和错误 + // 如果返回了错误,且数量不等于缓冲区队列长度,缓冲区会将剩下的消息重新放进队列下次重试。 + // 因此,只有当你确定这条消息写入失败后需要下次重试时,返回的数量才 != len(msgs)。 + // 如果某条消息写入失败,但业务上可以接受,建议在BatchWrite里面打条日志,继续处理 + // 下一条消息,最终依旧返回len(msgs), nil。 如果你遇到写入失败但不想丢弃消息,可以 + // return 成功的数量, err + BatchWrite(msgs []T) (int, error) +} + +type AutoFlushBuffer[T any] struct { + buf []T // 缓冲区 + flushDuration time.Duration // 定时刷入的时间间隔 + flushSize int // 批量刷入的阈值 + lock sync.RWMutex // 锁,修改缓冲区时进行保护 + writer BatchWriter[T] // 写入的实现 +} + +// NewAutoFlushBuffer 创建一个缓冲区, 该缓冲区支持自动定时刷入, 支持手动刷入, 支持达到阈值批量刷入 +// 该缓冲区是线程安全的, 多个协程可以并发写入, 消息会按顺序调用用户提供的BatchWrite方法进行写入 +// 如果你需要并发或异步写入,只需要在BatchWrite方法中自行实现进行控制即可 +func NewAutoFlushBuffer[T any](flushDuration time.Duration, flushSize int, writer BatchWriter[T]) *AutoFlushBuffer[T] { + return &AutoFlushBuffer[T]{ + buf: make([]T, 0, flushSize), + flushDuration: flushDuration, + flushSize: flushSize, + lock: sync.RWMutex{}, + writer: writer, + } +} + +// StartTimedFlusher 会根据设置的时间间隔,定期将缓冲区中的消息进行刷入 +func (a *AutoFlushBuffer[T]) StartTimedFlusher(ctx context.Context) { + ticker := time.NewTicker(a.flushDuration) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := a.Flush(); err != nil { + log.Printf("定时刷入时发生了错误: %v", err) + } + case <-ctx.Done(): + // 退出前尝试把缓冲区中的消息刷入 + _ = a.Flush() + return + } + } +} + +// WriteMessage 往缓冲区队列中写入一条消息 +func (a *AutoFlushBuffer[T]) WriteMessage(msg T) { + bufLen := a.writeBuf(msg) + // 如果写入后的长度大于等于设置的阈值,则尝试刷入 + if bufLen >= a.flushSize { + if err := a.Flush(); err != nil { + log.Printf("写入时发生了错误: %v", err) + return + } + } +} + +// writeBuf 写入缓冲区并返回缓冲区的当前长度 +func (a *AutoFlushBuffer[T]) writeBuf(msg T) int { + a.lock.Lock() + defer a.lock.Unlock() + a.buf = append(a.buf, msg) + return len(a.buf) +} + +// Flush 调用用户提供的处理函数,将buffer中的内容刷入 +func (a *AutoFlushBuffer[T]) Flush() error { + a.lock.Lock() + defer a.lock.Unlock() + if len(a.buf) > 0 { + n, err := a.writer.BatchWrite(a.buf) + if err != nil { + // 原地将队列里未写完的数据拷贝到缓冲区前面 + a.buf = append(a.buf[:0], a.buf[n:]...) + return err + } + a.buf = a.buf[:0] + } + return nil +}