diff --git a/Makefile b/Makefile index 9c4251f..00d6629 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ get: .PHONY: get_lint get_lint: @if [ ! -f ./bin/golangci-lint ]; then \ - curl -sfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s v1.57.2; \ + curl -sfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s v1.63.0; \ fi; .PHONY: lint diff --git a/README.md b/README.md index ac4545f..f38c62c 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ make analyser #### Analyze Command ```sh -./bin/analyzer analyze [command options] +./bin/analyzer analyze [command options] arg[source path] ``` #### Analyze Options @@ -111,7 +111,7 @@ make analyser #### Trace Command ```sh -./bin/analyzer trace [command options] +./bin/analyzer trace [command options] arg[source path] ``` #### Trace Options @@ -120,7 +120,8 @@ make analyser |-----------------------|----------------------------------------------------------------------------------------|---------| | `--vm-profile value` | Path to the VM profile config file (required). | None | | `--function value` | Name of the function to trace. Include package name (e.g., `syscall.read`). (required) | None | -| `--help, -h` | Show help. | None | +| `--source-type value` | Assembly or go source code. | None | +| `--help, -h` | Show help. 
| None | ## Example Usage @@ -134,7 +135,7 @@ make analyser ### Running a Trace ```sh -./bin/analyzer trace --vm-profile=./profile/cannon/cannon-64.yaml --function=syscall.read +./bin/analyzer trace --vm-profile=./profile/cannon/cannon-64.yaml --function=syscall.read sample.asm ```` diff --git a/analyzer/opcode/opcode.go b/analyzer/opcode/opcode.go index 061c7d0..c120659 100644 --- a/analyzer/opcode/opcode.go +++ b/analyzer/opcode/opcode.go @@ -23,18 +23,11 @@ func NewAnalyser(profile *profile.VMProfile) analyzer.Analyzer { } func (op *opcode) Analyze(path string, withTrace bool) ([]*analyzer.Issue, error) { - var err error - var callGraph asmparser.CallGraph - - switch op.profile.GOARCH { - case "mips32", "mips64": - callGraph, err = mips.NewParser().Parse(path) - default: - return nil, fmt.Errorf("unsupported GOARCH %s", op.profile.GOARCH) - } + callGraph, err := op.buildCallGraph(path) if err != nil { return nil, err } + absPath, err := filepath.Abs(path) if err != nil { return nil, err @@ -43,7 +36,7 @@ func (op *opcode) Analyze(path string, withTrace bool) ([]*analyzer.Issue, error for _, segment := range callGraph.Segments() { for _, instruction := range segment.Instructions() { if !op.isAllowedOpcode(instruction.OpcodeHex(), instruction.Funct()) { - source, err := common.TraceAsmCaller(absPath, callGraph, segment.Label()) + source, err := common.TraceAsmCaller(absPath, callGraph, segment.Label(), endCondition) if err != nil { // non-reachable portion ignored continue } @@ -62,11 +55,37 @@ func (op *opcode) Analyze(path string, withTrace bool) ([]*analyzer.Issue, error return issues, nil } +func (op *opcode) buildCallGraph(path string) (asmparser.CallGraph, error) { + var ( + err error + callGraph asmparser.CallGraph + ) + + // Select the correct parser based on architecture. 
+ switch op.profile.GOARCH { + case "mips32", "mips64": + callGraph, err = mips.NewParser().Parse(path) + default: + return nil, fmt.Errorf("unsupported GOARCH: %s", op.profile.GOARCH) + } + if err != nil { + return nil, fmt.Errorf("error parsing assembly file: %w", err) + } + return callGraph, nil +} + // TraceStack generates callstack for a function to debug func (op *opcode) TraceStack(path string, function string) (*analyzer.IssueSource, error) { - return nil, fmt.Errorf("stack trace is not supported for assembly code") + graph, err := op.buildCallGraph(path) + if err != nil { + return nil, err + } + absPath, err := filepath.Abs(path) + if err != nil { + return nil, err + } + return common.TraceAsmCaller(absPath, graph, function, endCondition) } - func (op *opcode) isAllowedOpcode(opcode, funct string) bool { return slices.ContainsFunc(op.profile.AllowedOpcodes, func(instr profile.OpcodeInstruction) bool { if !strings.EqualFold(instr.Opcode, opcode) { @@ -80,3 +99,10 @@ func (op *opcode) isAllowedOpcode(opcode, funct string) bool { }) }) } + +func endCondition(function string) bool { + return function == "runtime.rt0_go" || // start point of a go program + function == "main.main" || // main + strings.Contains(function, ".init.") || // all init functions + strings.HasSuffix(function, ".init") // vars +} diff --git a/analyzer/syscall/asm_syscall.go b/analyzer/syscall/asm_syscall.go index 53e2be7..c4a1d4a 100644 --- a/analyzer/syscall/asm_syscall.go +++ b/analyzer/syscall/asm_syscall.go @@ -5,6 +5,7 @@ import ( "fmt" "path/filepath" "slices" + "strings" "github.com/ChainSafe/vm-compat/analyzer" "github.com/ChainSafe/vm-compat/asmparser" @@ -13,8 +14,6 @@ import ( "github.com/ChainSafe/vm-compat/profile" ) -var syscallAPISForAsm = append(syscallAPIs, "runtime/internal/syscall.Syscall6") - // asmSyscallAnalyser analyzes system calls in assembly files. 
type asmSyscallAnalyser struct { profile *profile.VMProfile @@ -29,76 +28,93 @@ func NewAssemblySyscallAnalyser(profile *profile.VMProfile) analyzer.Analyzer { // //nolint:cyclop func (a *asmSyscallAnalyser) Analyze(path string, withTrace bool) ([]*analyzer.Issue, error) { - var ( - err error - callGraph asmparser.CallGraph - ) - - // Select the correct parser based on architecture. - switch a.profile.GOARCH { - case "mips32", "mips64": - callGraph, err = mips.NewParser().Parse(path) - default: - return nil, fmt.Errorf("unsupported GOARCH: %s", a.profile.GOARCH) - } + callGraph, err := a.buildCallGraph(path) if err != nil { - return nil, fmt.Errorf("error parsing assembly file: %w", err) + return nil, err } - - issues := make([]*analyzer.Issue, 0) - absPath, err := filepath.Abs(path) if err != nil { return nil, err } + issues := make([]*analyzer.Issue, 0) // Iterate through segments and check for syscall. for _, segment := range callGraph.Segments() { - segmentLabel := segment.Label() for _, instruction := range segment.Instructions() { if !instruction.IsSyscall() { continue } - // Ignore indirect syscall calling from syscall apis - if slices.Contains(syscallAPISForAsm, segmentLabel) { - continue - } - syscallNum, err := segment.RetrieveSyscallNum(instruction) + syscalls, err := callGraph.RetrieveSyscallNum(segment, instruction) if err != nil { return nil, fmt.Errorf("failed to retrieve syscall number: %w", err) } + for _, syscall := range syscalls { + // Categorize syscall + if slices.Contains(a.profile.AllowedSycalls, syscall.Number) { + continue + } + source, err := common.TraceAsmCaller(absPath, callGraph, syscall.Segment.Label(), endCondition) + if err != nil { // non-reachable portion ignored + continue + } + if !withTrace { + source.CallStack = nil + } - // Categorize syscall - if slices.Contains(a.profile.AllowedSycalls, syscallNum) { - continue - } - // Better to develop a new algo to check all segments at once like go_syscall - source, err := 
common.TraceAsmCaller(absPath, callGraph, segment.Label()) - if err != nil { // non-reachable portion ignored - continue - } - if !withTrace { - source.CallStack = nil - } + severity := analyzer.IssueSeverityCritical + message := fmt.Sprintf("Potential Incompatible Syscall Detected: %d", syscall.Number) + if slices.Contains(a.profile.NOOPSyscalls, syscall.Number) { + message = fmt.Sprintf("Potential NOOP Syscall Detected: %d", syscall.Number) + severity = analyzer.IssueSeverityWarning + } - severity := analyzer.IssueSeverityCritical - message := fmt.Sprintf("Potential Incompatible Syscall Detected: %d", syscallNum) - if slices.Contains(a.profile.NOOPSyscalls, syscallNum) { - message = fmt.Sprintf("Potential NOOP Syscall Detected: %d", syscallNum) - severity = analyzer.IssueSeverityWarning + issues = append(issues, &analyzer.Issue{ + Severity: severity, + Message: message, + Sources: source, + }) } - - issues = append(issues, &analyzer.Issue{ - Severity: severity, - Message: message, - Sources: source, - }) } } return issues, nil } +func (a *asmSyscallAnalyser) buildCallGraph(path string) (asmparser.CallGraph, error) { + var ( + err error + callGraph asmparser.CallGraph + ) + + // Select the correct parser based on architecture. 
+ switch a.profile.GOARCH { + case "mips32", "mips64": + callGraph, err = mips.NewParser().Parse(path) + default: + return nil, fmt.Errorf("unsupported GOARCH: %s", a.profile.GOARCH) + } + if err != nil { + return nil, fmt.Errorf("error parsing assembly file: %w", err) + } + return callGraph, nil +} + // TraceStack generates callstack for a function to debug func (a *asmSyscallAnalyser) TraceStack(path string, function string) (*analyzer.IssueSource, error) { - return nil, fmt.Errorf("stack trace is not supported for assembly code") + graph, err := a.buildCallGraph(path) + if err != nil { + return nil, err + } + + absPath, err := filepath.Abs(path) + if err != nil { + return nil, err + } + return common.TraceAsmCaller(absPath, graph, function, endCondition) +} + +func endCondition(function string) bool { + return function == "runtime.rt0_go" || // start point of a go program + function == "main.main" || // main + strings.Contains(function, ".init.") || // all init functions + strings.HasSuffix(function, ".init") // vars } diff --git a/analyzer/syscall/go_syscall.go b/analyzer/syscall/go_syscall.go index 1eeab8c..83718f0 100644 --- a/analyzer/syscall/go_syscall.go +++ b/analyzer/syscall/go_syscall.go @@ -11,6 +11,7 @@ import ( "github.com/ChainSafe/vm-compat/analyzer" "github.com/ChainSafe/vm-compat/common" + "github.com/ChainSafe/vm-compat/common/lifo" "github.com/ChainSafe/vm-compat/profile" "golang.org/x/tools/go/callgraph" "golang.org/x/tools/go/callgraph/rta" @@ -44,42 +45,16 @@ func (a *goSyscallAnalyser) Analyze(path string, withTrace bool) ([]*analyzer.Is if err != nil { return nil, err } - syscalls := make([]syscallSource, 0) - err = callgraph.GraphVisitEdges(cg, func(edge *callgraph.Edge) error { - callee := edge.Callee.Func - if callee != nil && callee.Pkg != nil && callee.Pkg.Pkg != nil { - if slices.Contains(syscallAPIs, callee.String()) { - calls := traceSyscalls(edge.Site.Common().Args[0], edge, fset) - syscalls = append(syscalls, calls...) 
- } - } - return nil - }) - if err != nil { - return nil, err - } + syscalls := a.extractSyscalls(cg) // Check against allowed syscalls. issues := make([]*analyzer.Issue, 0) - functions := make([]string, 0) - for _, syscall := range syscalls { - functions = append(functions, syscall.source.Function) - } - tracker := a.reachableFunctions(cg, functions) - stackTraceMap := make(map[string]*analyzer.IssueSource) for i := range syscalls { syscll := syscalls[i] - if slices.Contains(a.profile.AllowedSycalls, syscll.num) || !tracker[syscll.source.Function] { + if slices.Contains(a.profile.AllowedSycalls, syscll.num) { continue } - stackTrace := syscll.source - if withTrace { - stackTrace = stackTraceMap[syscll.source.Function] - if stackTrace == nil { - stackTrace, _ = a.trackStack(cg, fset, syscll.source.Function) - stackTraceMap[syscll.source.Function] = stackTrace - } - } + stackTrace := a.edgeToCallStack(syscll.edgeStack, fset, withTrace) severity := analyzer.IssueSeverityCritical message := fmt.Sprintf("Potential Incompatible Syscall Detected: %d", syscll.num) @@ -103,13 +78,92 @@ func (a *goSyscallAnalyser) TraceStack(path string, function string) (*analyzer. 
if err != nil { return nil, err } - return a.trackStack(cg, fset, function) + sources := a.buildCallStack(cg, fset, []string{function}) + if sources[function] == nil { + return nil, fmt.Errorf("no trace found to main for function %s not found", function) + } + return sources[function], nil +} + +func (a *goSyscallAnalyser) extractSyscalls(cg *callgraph.Graph) []*syscallSource { + sources := make([]*lifo.Stack[*callgraph.Edge], 0) + currentStack := lifo.Stack[*callgraph.Edge]{} + seen := make(map[*callgraph.Edge]bool) + + var visit func(n *callgraph.Node, edge *callgraph.Edge) + + visit = func(n *callgraph.Node, edge *callgraph.Edge) { + if edge != nil { + currentStack.Push(edge) + } + + if edge != nil && edge.Callee != nil && slices.Contains(syscallAPIs, edge.Callee.Func.String()) { + sources = append(sources, currentStack.Copy()) + } else { + seen[edge] = true + for _, e := range n.Out { + if !seen[e] { + visit(e.Callee, e) + } + } + } + if edge != nil { + currentStack.Pop() + } + } + + for _, n := range cg.Nodes { + if isRoot(n.Func.String()) { + visit(n, nil) + } + } + + syscalls := make([]*syscallSource, 0) + for _, stack := range sources { + edge, _ := stack.Peek() // It must be a syscall API + calls := resolveSyscallValue(edge.Site.Common().Args[0], stack) + for _, call := range calls { + call.edgeStack = stack + } + syscalls = append(syscalls, calls...) 
+ } + + return syscalls } -func (a *goSyscallAnalyser) trackStack(cg *callgraph.Graph, fset *token.FileSet, function string) (*analyzer.IssueSource, error) { +func (a *goSyscallAnalyser) edgeToCallStack(stack *lifo.Stack[*callgraph.Edge], fset *token.FileSet, fullStack bool) *analyzer.IssueSource { + var issueSource *analyzer.IssueSource + for !stack.IsEmpty() { + edge, _ := stack.Pop() + if edge.Site == nil { + continue + } + position := fset.Position(edge.Site.Pos()) + src := &analyzer.IssueSource{ + File: position.Filename, + Line: position.Line, + Function: edge.Caller.Func.String(), + AbsPath: filepath.Clean(position.Filename), + } + if issueSource != nil { + src.CallStack = issueSource + } + issueSource = src + if !fullStack { + return issueSource + } + } + + return issueSource +} + +func (a *goSyscallAnalyser) buildCallStack(cg *callgraph.Graph, fset *token.FileSet, functions []string) map[string]*analyzer.IssueSource { + sources := make(map[string]*lifo.Stack[*analyzer.IssueSource]) + currentStack := lifo.Stack[*analyzer.IssueSource]{} seen := make(map[*callgraph.Node]bool) - var visit func(n *callgraph.Node, edge *callgraph.Edge) *analyzer.IssueSource - visit = func(n *callgraph.Node, edge *callgraph.Edge) *analyzer.IssueSource { + var visit func(n *callgraph.Node, edge *callgraph.Edge) + + visit = func(n *callgraph.Node, edge *callgraph.Edge) { var src *analyzer.IssueSource if edge != nil && edge.Caller != nil && edge.Site != nil { position := fset.Position(edge.Site.Pos()) @@ -120,36 +174,44 @@ func (a *goSyscallAnalyser) trackStack(cg *callgraph.Graph, fset *token.FileSet, Function: fn, AbsPath: filepath.Clean(position.Filename), } - if fn == function { - return src + currentStack.Push(src) + + if slices.Contains(functions, fn) { + sources[fn] = currentStack.Copy() + if len(sources) == len(functions) { + return + } } } // as we are checking edge.Caller we need to get 1 step deeper everytime, that requires to re-visit the node if seen[n] { - return nil + 
return } seen[n] = true for _, e := range n.Out { - ch := visit(e.Callee, e) - if ch != nil { - if src != nil { - ch.AddCallStack(src) - } - return ch - } + visit(e.Callee, e) + currentStack.Pop() } - return nil } for _, n := range cg.Nodes { - if n.Func.String() == "command-line-arguments.main" || n.Func.String() == "command-line-arguments.init" { - if source := visit(n, nil); source != nil { - return source, nil - } + if isRoot(n.Func.String()) { + visit(n, nil) + } + } + issuesSources := make(map[string]*analyzer.IssueSource) + for fn, stack := range sources { + source, _ := stack.Pop() + for !stack.IsEmpty() { + parent, _ := stack.Pop() + parent.CallStack = source + source = parent } + issuesSources[fn] = source } - return nil, fmt.Errorf("no trace found to root for the given function") + + return issuesSources } func (a *goSyscallAnalyser) buildCallGraph(path string) (*callgraph.Graph, *token.FileSet, error) { @@ -199,65 +261,43 @@ func (a *goSyscallAnalyser) buildCallGraph(path string) (*callgraph.Graph, *toke return cg, initial[0].Fset, nil } -func (a *goSyscallAnalyser) reachableFunctions(cg *callgraph.Graph, functions []string) map[string]bool { - seen := make(map[*callgraph.Node]bool) - tracker := make(map[string]bool) - - var visit func(n *callgraph.Node) - visit = func(n *callgraph.Node) { - if seen[n] { - return - } - seen[n] = true - - if slices.Contains(functions, n.Func.String()) { - tracker[n.Func.String()] = true - } - - for _, e := range n.Out { - visit(e.Callee) - } - } - - for _, n := range cg.Nodes { - if n.Func.String() == "command-line-arguments.main" || n.Func.String() == "command-line-arguments.init" { - visit(n) - } - } - return tracker -} - type syscallSource struct { - num int - source *analyzer.IssueSource + num int + edgeStack *lifo.Stack[*callgraph.Edge] } -func traceSyscalls(value ssa.Value, edge *callgraph.Edge, fset *token.FileSet) []syscallSource { - result := make([]syscallSource, 0) +func resolveSyscallValue(value ssa.Value, 
edgeStack *lifo.Stack[*callgraph.Edge]) []*syscallSource { + result := make([]*syscallSource, 0) switch v := value.(type) { case *ssa.Const: valInt, err := strconv.Atoi(v.Value.String()) if err == nil { - position := fset.Position(edge.Site.Pos()) - return []syscallSource{{num: valInt, - source: &analyzer.IssueSource{ - File: position.Filename, - Line: position.Line, - Function: edge.Caller.Func.String(), - AbsPath: filepath.Clean(position.Filename), - }, - }} + return []*syscallSource{{num: valInt, edgeStack: edgeStack.Copy()}} } case *ssa.Global: - result = append(result, traceInit(v, v.Pkg.Members, edge, fset)...) - case *ssa.Parameter: - prev := edge.Caller.In - for _, p := range prev { - result = append(result, traceSyscalls(p.Site.Common().Args[0], p, fset)...) + // Iterate through instructions in the Init function + // Iterate through all functions in the package to find the initialization + for _, member := range v.Pkg.Members { + if fn, ok := member.(*ssa.Function); ok { + for _, block := range fn.Blocks { + for _, instr := range block.Instrs { + // Look for Store instructions + if store, ok := instr.(*ssa.Store); ok { + if store.Addr == v { + result = append(result, resolveSyscallValue(store.Val, edgeStack)...) + } + } + } + } + } } + case *ssa.Parameter: + cpStack := edgeStack.Copy() + prev, _ := cpStack.Pop() + result = append(result, resolveSyscallValue(prev.Site.Common().Args[0], cpStack)...) case *ssa.Phi: for _, val := range v.Edges { - result = append(result, traceSyscalls(val, edge, fset)...) + result = append(result, resolveSyscallValue(val, edgeStack)...) // TODO: debug } case *ssa.Call: // Trace nested calls @@ -267,15 +307,15 @@ func traceSyscalls(value ssa.Value, edge *callgraph.Edge, fset *token.FileSet) [ // Look for return instructions if ret, ok := instr.(*ssa.Return); ok { for _, val := range ret.Results { - result = append(result, traceSyscalls(val, edge, fset)...) + result = append(result, resolveSyscallValue(val, edgeStack)...) 
} } } } case *ssa.UnOp: - result = append(result, traceSyscalls(v.X, edge, fset)...) + result = append(result, resolveSyscallValue(v.X, edgeStack)...) case *ssa.Convert: - result = append(result, traceSyscalls(v.X, edge, fset)...) + result = append(result, resolveSyscallValue(v.X, edgeStack)...) case *ssa.FieldAddr: // check all instructions to get the latest value store for this field address var val ssa.Value @@ -288,7 +328,7 @@ func traceSyscalls(value ssa.Value, edge *callgraph.Edge, fset *token.FileSet) [ } } } - result = append(result, traceSyscalls(val, edge, fset)...) + result = append(result, resolveSyscallValue(val, edgeStack)...) default: fmt.Printf("Unhandled value type: %T\n", v) panic("not handled") @@ -296,26 +336,6 @@ func traceSyscalls(value ssa.Value, edge *callgraph.Edge, fset *token.FileSet) [ return result } -func traceInit(v *ssa.Global, members map[string]ssa.Member, edge *callgraph.Edge, fset *token.FileSet) (result []syscallSource) { - // Iterate through instructions in the Init function - // Iterate through all functions in the package to find the initialization - for _, member := range members { - if fn, ok := member.(*ssa.Function); ok { - for _, block := range fn.Blocks { - for _, instr := range block.Instrs { - // Look for Store instructions - if store, ok := instr.(*ssa.Store); ok { - if store.Addr == v { - result = append(result, traceSyscalls(store.Val, edge, fset)...) - } - } - } - } - } - } - return result -} - // mainPackages returns the main packages to analyze. // Each resulting package is named "main" and has a main function. 
func mainPackages(pkgs []*ssa.Package) ([]*ssa.Package, error) { @@ -350,3 +370,7 @@ func initFuncs(pkgs []*ssa.Package) []*ssa.Function { } return inits } + +func isRoot(function string) bool { + return function == "command-line-arguments.main" || function == "command-line-arguments.init" +} diff --git a/asmparser/mips/mips_parser.go b/asmparser/mips/mips_parser.go index cb304e6..1b4a473 100644 --- a/asmparser/mips/mips_parser.go +++ b/asmparser/mips/mips_parser.go @@ -4,7 +4,6 @@ package mips import ( "bufio" "fmt" - "math" "os" "path/filepath" "regexp" @@ -16,8 +15,9 @@ import ( // Constants defining MIPS register indexes. const ( - registerZero = 0 // $zero register index in MIPS - registerV0 = 2 // $v0 register index in MIPS + registerZero = 0 // $zero register index in MIPS + registerV0 = 2 // $v0 register index in MIPS + registerSP = 29 // $sp (Stack Pointer) ) var ( @@ -270,41 +270,6 @@ func (s *segment) Instructions() []asmparser.Instruction { return instrs } -// RetrieveSyscallNum extracts the syscall number by analyzing the preceding instructions. -// Limitations: -// - Only supports `daddui` and `addui` instructions for loading syscall numbers. -// - Assumes that `v0` is set by an immediate operation and does not track register dependencies. -// - Does not handle indirect loading methods or data-dependent values. 
-func (s *segment) RetrieveSyscallNum(instr asmparser.Instruction) (int, error) { - ins, ok := instr.(*instruction) - if !ok { - return 0, fmt.Errorf("invalid instruction type: expected MIPS instruction, got %T", instr) - } - offset := ins.address - s.address - indexOfInstr := offset / uint64(4) - - // every value of i is a uint64 which is always >= 0, hence check against max uint64 - // TODO: if some instruction is skipped this may fail to target the correct one - for i := indexOfInstr - 1; i < math.MaxUint64; i-- { - currInstr := s.instructions[i] - if currInstr.instType == asmparser.RType && len(currInstr.operands) > 2 && currInstr.operands[2] == registerV0 { - return 0, fmt.Errorf("unsupported operation: register v0 modified before syscall assignment at %s", - currInstr.Address()) - } - if currInstr.instType == asmparser.IType && len(currInstr.operands) > 2 && currInstr.operands[1] == registerV0 { - if currInstr.opcode == 0x19 || currInstr.opcode == 0x09 { // daddui or addui - if currInstr.operands[0] != registerZero { - return 0, fmt.Errorf("unsupported operation: syscall number must be loaded from $zero at address %s", - currInstr.Address()) - } - return int(currInstr.operands[2]), nil - } - } - } - - return 0, fmt.Errorf("failed to retrieve syscall number: no valid assignment to register $v0 found in segment") -} - // callGraph represents a graph structure implementing asmparser.CallGraph. type callGraph struct { segments map[uint64]*segment @@ -349,3 +314,139 @@ func (g *callGraph) addSegment(seg *segment) { } g.segments[seg.address] = seg } + +// RetrieveSyscallNum extracts the syscall number by analyzing the preceding instructions. 
+// Limitation: If syscall number is dynamically generated, it cannot trace that +func (g *callGraph) RetrieveSyscallNum(seg asmparser.Segment, instr asmparser.Instruction) ([]*asmparser.Syscall, error) { + ins, ok := instr.(*instruction) + if !ok { + return nil, fmt.Errorf("invalid instruction type: expected MIPS instruction, got %T", instr) + } + s, ok := seg.(*segment) + if !ok { + return nil, fmt.Errorf("invalid segment type: expected MIPS segment, got %T", seg) + } + var indexOfInstr int + for i, _instr := range s.instructions { + if _instr.Address() == ins.Address() { + indexOfInstr = i + } + } + var resolveRegisterValue func(register, offset int64, instrIdx int, seg, childSeg *segment) ([]*asmparser.Syscall, error) + seen := make(map[*segment]bool) + resolveRegisterValue = func(register, offset int64, instrIdx int, seg, childSeg *segment) ([]*asmparser.Syscall, error) { + result := make([]*asmparser.Syscall, 0) + // Special case, where we don't know from where to start + // Need to find out instruction index + if instrIdx == -2 { + if seen[seg] { + return result, nil + } + // multiple jump possible + for i, inst := range seg.instructions { + if inst.isJump() && uint64(inst.jumpTarget()) == childSeg.address { //nolint:gosec + res, err := resolveRegisterValue(register, offset, i, seg, childSeg) + if err != nil { + return nil, err + } + result = append(result, res...) + } + } + return result, nil + } + // When all the instruction has finished while processing, + // Need to track back to it's caller + if instrIdx == -1 { + seen[seg] = true + parents := g.ParentsOf(seg) + if len(parents) == 0 { + // Here, we cannot resolve any value for syscall, reasons can be it's being assigned in runtime. + // Fine to ignore those syscall + return result, nil + } + for _, sg := range parents { + res, err := resolveRegisterValue(register, offset, -2, sg.(*segment), seg) + if err != nil { + return nil, err + } + result = append(result, res...) 
+ } + return result, nil + } + + currInstr := seg.instructions[instrIdx] + switch currInstr.instType { + case asmparser.RType: + if len(currInstr.operands) > 2 { + rd := currInstr.operands[2] // destination register + // If the destination register is our target register, + // we need to resolve the value for it + if rd == register { + return nil, fmt.Errorf("not handled modification of register in r-type instruction, instruction:%s", currInstr.Address()) + } + } + case asmparser.IType: + if len(currInstr.operands) > 1 { + rs := currInstr.operands[0] + rt := currInstr.operands[1] + if rs == register || rt == register { + switch currInstr.opcode { + case 0x23, 0x24, 0x27, 0x37: // load from rs to rt - rt matters + if register == rt { + // Load to sp - need to match offset + if rt == registerSP && offset == currInstr.operands[2] { + register = rs + } + // Load from SP - update offset + if rs == registerSP { + offset = currInstr.operands[2] + register = rs + } + } + case 0x2B, 0x3F, 0x28: // store to rs from rt - rs matters + if register == rs { + // Store to SP - need to match the offset + if rs == registerSP && offset == currInstr.operands[2] { + register = rt + } + // Store from SP - need to update the offset + if rt == registerSP { + offset = currInstr.operands[2] + register = rt + } + } + return resolveRegisterValue(register, offset, instrIdx-1, seg, childSeg) + case 0x08, 0x09, 0x18, 0x19: // add operations + if register == rt { + // need to check rs carefully + // case 1- memory shift of sp(daddi sp, sp, -88) + if rs == registerSP { + offset += currInstr.operands[2] + return resolveRegisterValue(register, offset, instrIdx-1, seg, childSeg) + } + // case 2- direct assigment to register where rs=registerZero + if rs == registerZero { + return []*asmparser.Syscall{{ + Number: int(currInstr.operands[2]), + Segment: seg, + Instruction: currInstr, + }}, nil + } + return nil, fmt.Errorf("not handled modification of register in i-type instruction, instruction:%s", 
currInstr.Address()) + } + default: + return nil, fmt.Errorf("not handled opcode, instruction:%s", currInstr.Address()) + } + } + } + default: + } + return resolveRegisterValue(register, offset, instrIdx-1, seg, childSeg) + } + + result, err := resolveRegisterValue(registerV0, 0, indexOfInstr-1, s, nil) + if err != nil { + return nil, err + } + return result, nil +} diff --git a/asmparser/mips/mips_parser_test.go b/asmparser/mips/mips_parser_test.go index 3722f91..c8a708e 100644 --- a/asmparser/mips/mips_parser_test.go +++ b/asmparser/mips/mips_parser_test.go @@ -89,9 +89,6 @@ Disassembly of section .text: assert.Equal(t, asmparser.RType, instrs[2].Type()) assert.Equal(t, "syscall", instrs[2].Mnemonic()) - _, err = segment1.RetrieveSyscallNum(instrs[2]) - require.Error(t, err) - assert.Equal(t, "0x1100c", instrs[3].Address()) assert.Equal(t, "0x3", instrs[3].OpcodeHex()) assert.False(t, instrs[3].IsSyscall()) @@ -117,9 +114,9 @@ Disassembly of section .text: assert.Equal(t, asmparser.IType, instrs[3].Type()) assert.Equal(t, "daddiu", instrs[3].Mnemonic()) - syscallNum, err := segment2.RetrieveSyscallNum(instrs[4]) + syscallNums, err := graph.RetrieveSyscallNum(segment2, instrs[4]) require.NoError(t, err) - assert.Equal(t, 5000, syscallNum) + assert.Equal(t, 5000, syscallNums[0].Number) assert.Equal(t, "0x8d9ec", instrs[5].Address()) assert.Equal(t, "0x4", instrs[5].OpcodeHex()) @@ -135,3 +132,80 @@ Disassembly of section .text: assert.Equal(t, asmparser.RType, instrs[6].Type()) assert.Equal(t, "sync", instrs[6].Mnemonic()) } + +func TestIndirectSyscall(t *testing.T) { + tempFile, err := os.CreateTemp("", "sample.asm") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tempFile.Name()) + + content := `/sample: file format elf64-tradbigmips + +Disassembly of section .text: + +0000000000011000
: + 937e8: 64 01 00 02 daddiu at,zero,2 + 937ec: ff a1 00 08 sd at,8(sp) + 937f0: 64 01 00 01 daddiu at,zero,1 + 937f4: ff a1 00 10 sd at,16(sp) + 937f8: ff a1 00 18 sd at,24(sp) + 937fc: ff a1 00 20 sd at,32(sp) + 93800: ff a1 00 28 sd at,40(sp) + 93804: ff a1 00 30 sd at,48(sp) + 93808: ff a1 00 38 sd at,56(sp) + 9380c: 0c 00 48 e6 jal 12398 +0000000000012398 : + 12398: ff bf ff a8 sd ra,-88(sp) + 1239c: 63 bd ff a8 daddi sp,sp,-88 + 123a0: ff bf 00 00 sd ra,0(sp) + 123a4: df a1 00 60 ld at,96(sp) + 123a8: ff a1 00 08 sd at,8(sp) + 123ac: df a1 00 68 ld at,104(sp) + 123b0: ff a1 00 10 sd at,16(sp) + 123b4: df a1 00 70 ld at,112(sp) + 123b8: ff a1 00 18 sd at,24(sp) + 123bc: df a1 00 78 ld at,120(sp) + 123c0: ff a1 00 20 sd at,32(sp) + 123c4: df a1 00 80 ld at,128(sp) + 123c8: ff a1 00 28 sd at,40(sp) + 123cc: df a1 00 88 ld at,136(sp) + 123d0: ff a1 00 30 sd at,48(sp) + 123d4: df a1 00 90 ld at,144(sp) + 123d8: ff a1 00 38 sd at,56(sp) + 123dc: 0c 00 49 04 jal 12410 +0000000000012410 : + 12410: df a2 00 08 ld v0,8(sp) + 12414: df a4 00 10 ld a0,16(sp) + 12418: df a5 00 18 ld a1,24(sp) + 1241c: df a6 00 20 ld a2,32(sp) + 12420: df a7 00 28 ld a3,40(sp) + 12424: df a8 00 30 ld a4,48(sp) + 12428: df a9 00 38 ld a5,56(sp) + 1242c: 00 00 18 25 move v1,zero + 12430: 00 00 00 0c syscall +` + if _, err = tempFile.WriteString(content); err != nil { + t.Fatal(err) + } + defer func() { + _ = tempFile.Close() + }() + + parser := NewParser() + graph, err := parser.Parse(tempFile.Name()) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + var syscalls []*asmparser.Syscall + for _, seg := range graph.Segments() { + for _, instr := range seg.Instructions() { + if instr.IsSyscall() { + res, err := graph.RetrieveSyscallNum(seg, instr) + assert.NoError(t, err) + syscalls = append(syscalls, res...) 
+ } + } + } + assert.Equal(t, 2, syscalls[0].Number) +} diff --git a/asmparser/parser.go b/asmparser/parser.go index 625ba07..9e06f14 100644 --- a/asmparser/parser.go +++ b/asmparser/parser.go @@ -34,12 +34,21 @@ type Segment interface { Label() string // Instructions return the list of instructions in the segment. Instructions() []Instruction - // RetrieveSyscallNum returns the number of the syscall from the instr - RetrieveSyscallNum(instr Instruction) (int, error) } // CallGraph defines an interface representing a call graph of segments. type CallGraph interface { - Segments() []Segment // Segments returns all segments in the call graph. - ParentsOf(segment Segment) []Segment // ParentsOf returns the parent segments of a given segment. + // Segments returns all segments in the call graph. + Segments() []Segment + // ParentsOf returns the parent segments of a given segment. + ParentsOf(segment Segment) []Segment + // RetrieveSyscallNum returns the number of the syscall from the instr + RetrieveSyscallNum(segment Segment, instr Instruction) ([]*Syscall, error) +} + +// Syscall holds syscall origin related details +type Syscall struct { + Number int + Segment Segment + Instruction Instruction } diff --git a/cmd/analyze.go b/cmd/analyze.go index 3c71980..fedf4db 100644 --- a/cmd/analyze.go +++ b/cmd/analyze.go @@ -90,7 +90,7 @@ func AnalyzeCompatibility(ctx *cli.Context) error { return fmt.Errorf("error disassembling the file: %w", err) } - issues, err := analyze(prof, source, disassemblyPath, analysisType, withTrace) + issues, err := analyze(prof, disassemblyPath, analysisType, withTrace) if err != nil { return fmt.Errorf("analysis failed: %w", err) } @@ -117,19 +117,19 @@ func disassemble(prof *profile.VMProfile, path, outputPath string) (string, erro } // analyze runs the selected analyzer(s). 
-func analyze(prof *profile.VMProfile, path, disassemblyPath, mode string, withTrace bool) ([]*analyzer.Issue, error) { +func analyze(prof *profile.VMProfile, disassemblyPath, mode string, withTrace bool) ([]*analyzer.Issue, error) { if mode == "opcode" { return opcode.NewAnalyser(prof).Analyze(disassemblyPath, withTrace) } if mode == "syscall" { - return analyzeSyscalls(prof, path, disassemblyPath, withTrace) + return syscall.NewAssemblySyscallAnalyser(prof).Analyze(disassemblyPath, withTrace) } // by default analyze both opIssues, err := opcode.NewAnalyser(prof).Analyze(disassemblyPath, withTrace) if err != nil { return nil, err } - sysIssues, err := analyzeSyscalls(prof, path, disassemblyPath, withTrace) + sysIssues, err := syscall.NewAssemblySyscallAnalyser(prof).Analyze(disassemblyPath, withTrace) if err != nil { return nil, err } @@ -168,15 +168,3 @@ func writeReport(issues []*analyzer.Issue, format, outputPath string, prof *prof return rendererInstance.Render(issues, output) } - -func analyzeSyscalls(profile *profile.VMProfile, source string, disassemblyPath string, withTrace bool) ([]*analyzer.Issue, error) { - issues, err := syscall.NewGOSyscallAnalyser(profile).Analyze(source, withTrace) - if err != nil { - return nil, err - } - issues2, err := syscall.NewAssemblySyscallAnalyser(profile).Analyze(disassemblyPath, withTrace) - if err != nil { - return nil, err - } - return append(issues, issues2...), nil -} diff --git a/cmd/trace.go b/cmd/trace.go index d2ddc51..2562336 100644 --- a/cmd/trace.go +++ b/cmd/trace.go @@ -19,6 +19,12 @@ var ( Usage: "Name of the function to trace. Name should include with package name. Ex: syscall.read", Required: true, } + SourceTypeFlag = &cli.StringFlag{ + Name: "source-type", + Usage: "Tracing on 'go' source code or 'assembly' code. 
// Stack is a generic last-in, first-out (LIFO) container backed by a slice.
// The zero value is an empty stack that is ready to use. Stack does no
// locking of its own.
type Stack[T any] struct {
	items []T // the last element is the top of the stack
}

// Push adds value to the top of the stack.
func (s *Stack[T]) Push(value T) {
	s.items = append(s.items, value)
}

// Pop removes and returns the item on top of the stack. When the stack is
// empty it returns the zero value of T and false.
func (s *Stack[T]) Pop() (T, bool) {
	var zero T
	if len(s.items) == 0 {
		return zero, false
	}
	last := len(s.items) - 1
	val := s.items[last]
	// Clear the vacated slot: the backing array outlives the reslice below,
	// so a pointer-bearing T would otherwise stay reachable after the pop.
	s.items[last] = zero
	s.items = s.items[:last]
	return val, true
}

// Peek returns the item on top of the stack without removing it. When the
// stack is empty it returns the zero value of T and false.
func (s *Stack[T]) Peek() (T, bool) {
	if len(s.items) == 0 {
		var zero T
		return zero, false
	}
	return s.items[len(s.items)-1], true
}

// Len returns the number of items in the stack.
func (s *Stack[T]) Len() int {
	return len(s.items)
}

// IsEmpty reports whether the stack holds no items.
func (s *Stack[T]) IsEmpty() bool {
	return len(s.items) == 0
}

// Copy returns a new stack holding the same elements in the same order. The
// copy owns its own backing slice, so pushing to or popping from one stack
// does not affect the other; the elements themselves are copied by
// assignment.
func (s *Stack[T]) Copy() *Stack[T] {
	return &Stack[T]{items: append([]T{}, s.items...)}
}
TestIsEmpty(t *testing.T) { + stack := Stack[int]{} + + if !stack.IsEmpty() { + t.Errorf("Expected empty stack, got non-empty") + } + + stack.Push(42) + if stack.IsEmpty() { + t.Errorf("Expected non-empty stack, got empty") + } + + stack.Pop() + if !stack.IsEmpty() { + t.Errorf("Expected empty stack after pop, got non-empty") + } +} + +// TestPopEmpty tests popping from an empty stack +func TestPopEmpty(t *testing.T) { + stack := Stack[bool]{} + + val, ok := stack.Pop() + if ok { + t.Errorf("Expected false for Pop from empty stack, got %v", val) + } +} + +// TestPeekEmpty tests peeking from an empty stack +func TestPeekEmpty(t *testing.T) { + stack := Stack[rune]{} + + val, ok := stack.Peek() + if ok { + t.Errorf("Expected false for Peek from empty stack, got %v", val) + } +} + +// TestCopy tests the Copy() function +func TestCopy(t *testing.T) { + original := Stack[int]{} + original.Push(1) + original.Push(2) + original.Push(3) + + copied := original.Copy() + + // Ensure the copied stack has the same length + if copied.Len() != original.Len() { + t.Errorf("Expected copied stack length %d, got %d", original.Len(), copied.Len()) + } + + // Ensure elements are in the same LIFO order + for !original.IsEmpty() { + origVal, _ := original.Pop() + copyVal, _ := copied.Pop() + if origVal != copyVal { + t.Errorf("Copy failed: expected %d, got %d", origVal, copyVal) + } + } + + // Ensure copied stack is now empty after popping all elements + if !copied.IsEmpty() { + t.Errorf("Expected copied stack to be empty after popping all elements") + } +} + +// TestCopyEmpty tests Copy() on an empty stack +func TestCopyEmpty(t *testing.T) { + original := Stack[string]{} + + copied := original.Copy() + + if !copied.IsEmpty() { + t.Errorf("Expected copied stack to be empty, but it's not") + } +} diff --git a/common/stack_tracer.go b/common/stack_tracer.go index 6d905e2..1e93f3d 100644 --- a/common/stack_tracer.go +++ b/common/stack_tracer.go @@ -3,7 +3,6 @@ package common import ( "fmt" 
"path/filepath" - "strings" "github.com/ChainSafe/vm-compat/analyzer" "github.com/ChainSafe/vm-compat/asmparser" @@ -14,6 +13,7 @@ func TraceAsmCaller( filePath string, graph asmparser.CallGraph, function string, + endCond func(string) bool, ) (*analyzer.IssueSource, error) { var segment asmparser.Segment for _, seg := range graph.Segments() { @@ -40,7 +40,7 @@ func TraceAsmCaller( AbsPath: filePath, Function: segment.Label(), } - if strings.Contains(source.Function, ".init") { // where to end + if endCond(source.Function) { return source } for _, seg := range graph.ParentsOf(segment) {