diff --git a/CHANGELOG.md b/CHANGELOG.md index 3cfa482..504d602 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## Unreleased + +### Added + +- Support compile-time automatic instrumentation of packages + that don't import `errtrace` by specifying the flag `unsafe-packages`. + This still requires at least one import of `errtrace` in the binary. + ## 0.4.0 - 2025-07-21 This release supports compile-time rewriting of source files via `toolexec`. diff --git a/cmd/errtrace/main.go b/cmd/errtrace/main.go index b463f1a..718ffcc 100644 --- a/cmd/errtrace/main.go +++ b/cmd/errtrace/main.go @@ -26,9 +26,6 @@ // -l list files that would be modified without making any changes. package main -// TODO -// - -toolexec: run as a tool executor, fit for use with 'go build -toolexec' - import ( "bytes" "encoding/json" @@ -52,6 +49,8 @@ import ( "braces.dev/errtrace" ) +const errtracePkgImport = "braces.dev/errtrace" + func main() { cmd := &mainCmd{ Stdin: os.Stdin, @@ -413,6 +412,17 @@ type parsedFile struct { importsErrtrace bool // includes blank imports inserts []insert unusedOptouts []int // list of line numbers + + // Used for toolexec unsafe mode, when rewriting packages that don't import + // errtrace, so we use go:linkname wrappers. + errtraceUnsafePrefix string +} + +func (f *parsedFile) errtracePkgPrefix() string { + if f.errtraceUnsafePrefix == "" { + return fmt.Sprintf("%s.", f.errtracePkg) + } + return f.errtraceUnsafePrefix } func (cmd *mainCmd) parseFile(filename string, src []byte, opts rewriteOpts) (parsedFile, error) { @@ -426,7 +436,7 @@ func (cmd *mainCmd) parseFile(filename string, src []byte, opts rewriteOpts) (pa var importsErrtrace bool // whether there's any errtrace import, including blank imports needErrtraceImport := true // whether to add a new import. for _, imp := range f.Imports { - if imp.Path.Value == `"braces.dev/errtrace"` { + if imp.Path.Value == `"`+errtracePkgImport+`"` { importsErrtrace = true if imp.Name != nil { if imp.Name.Name == "_" { @@ -574,15 +584,20 @@ func (cmd *mainCmd) rewriteFile(f parsedFile, out *bytes.Buffer) error { _, _ = io.WriteString(out, "import ") } - if f.errtracePkg == "errtrace" { + if f.errtraceUnsafePrefix != "" { + // unsafe mode uses go:link, which requires importing "unsafe" instead + // of errtrace. Since duplicate blank imports are allowed by Go, we can + // always add it, without checking if it has been imported. + fmt.Fprint(out, `_ "unsafe"`) + } else if f.errtracePkg == "errtrace" { // Don't use named imports if we're using the default name. - fmt.Fprintf(out, "%q", "braces.dev/errtrace") + fmt.Fprintf(out, "%q", errtracePkgImport) } else { - fmt.Fprintf(out, "%s %q", f.errtracePkg, "braces.dev/errtrace") + fmt.Fprintf(out, "%s %q", f.errtracePkg, errtracePkgImport) } case *insertWrapOpen: - fmt.Fprintf(out, "%s.Wrap", f.errtracePkg) + fmt.Fprintf(out, "%sWrap", f.errtracePkgPrefix()) if it.N > 1 { fmt.Fprintf(out, "%d", it.N) } @@ -605,7 +620,7 @@ func (cmd *mainCmd) rewriteFile(f parsedFile, out *bytes.Buffer) error { // Last return is an error, wrap it. last := &vars[len(vars)-1] - *last = fmt.Sprintf("%s.Wrap(%v)", f.errtracePkg, *last) + *last = fmt.Sprintf("%sWrap(%v)", f.errtracePkgPrefix(), *last) fmt.Fprintf(out, "; return %s }", strings.Join(vars, ", ")) @@ -625,7 +640,7 @@ func (cmd *mainCmd) rewriteFile(f parsedFile, out *bytes.Buffer) error { if i > 0 { _, _ = out.WriteString(", ") } - fmt.Fprintf(out, "%s.Wrap(%s)", f.errtracePkg, name) + fmt.Fprintf(out, "%sWrap(%s)", f.errtracePkgPrefix(), name) } _, _ = out.WriteString("; ") diff --git a/cmd/errtrace/testdata/toolexec-test/main.go b/cmd/errtrace/testdata/toolexec-test/main.go index fa384ee..533d4fa 100644 --- a/cmd/errtrace/testdata/toolexec-test/main.go +++ b/cmd/errtrace/testdata/toolexec-test/main.go @@ -14,5 +14,5 @@ func main() { } func callP1() error { - return p1.WrapP2() // @trace + return p1.WrapP2OnlyErr() // @trace } diff --git a/cmd/errtrace/testdata/toolexec-test/p1/p1.go b/cmd/errtrace/testdata/toolexec-test/p1/p1.go index 217045c..174ac9b 100644 --- a/cmd/errtrace/testdata/toolexec-test/p1/p1.go +++ b/cmd/errtrace/testdata/toolexec-test/p1/p1.go @@ -6,7 +6,15 @@ import ( "braces.dev/errtrace/cmd/errtrace/testdata/toolexec-test/p2" ) -// WrapP2 wraps an error return from p2. -func WrapP2() error { - return fmt.Errorf("test2: %w", p2.CallP3()) +// WrapP2OnlyErr only returns the error from WrapP2. +func WrapP2OnlyErr() error { + if _, err := WrapP2(); err != nil { + return fmt.Errorf("test2: %w", err) // @unsafe-trace + } + return nil +} + +// WrapRet2 calls WrapP2, but has a multi-return. +func WrapP2() (string, error) { + return p2.CallP3() // @unsafe-trace } diff --git a/cmd/errtrace/testdata/toolexec-test/p2/p2.go b/cmd/errtrace/testdata/toolexec-test/p2/p2.go index 32e99ee..5569473 100644 --- a/cmd/errtrace/testdata/toolexec-test/p2/p2.go +++ b/cmd/errtrace/testdata/toolexec-test/p2/p2.go @@ -7,6 +7,6 @@ import ( ) // CallP3 calls p3, and wraps the error. -func CallP3() error { - return errtrace.Wrap(p3.ReturnErr()) // @trace +func CallP3() (string, error) { + return errtrace.Wrap2(p3.ReturnStrErr()) // @trace } diff --git a/cmd/errtrace/testdata/toolexec-test/p3/p3.go b/cmd/errtrace/testdata/toolexec-test/p3/p3.go index b60a12f..9a06a84 100644 --- a/cmd/errtrace/testdata/toolexec-test/p3/p3.go +++ b/cmd/errtrace/testdata/toolexec-test/p3/p3.go @@ -4,7 +4,7 @@ import ( "errors" ) -// ReturnErr returns an error. -func ReturnErr() error { - return errors.New("test") // @trace +// ReturnStrErr returns an error. +func ReturnStrErr() (string, error) { + return "", errors.New("test") // @trace } diff --git a/cmd/errtrace/testdata/toolexec-unsafe-no-import/main.go b/cmd/errtrace/testdata/toolexec-unsafe-no-import/main.go new file mode 100644 index 0000000..81be4a6 --- /dev/null +++ b/cmd/errtrace/testdata/toolexec-unsafe-no-import/main.go @@ -0,0 +1,13 @@ +package main + +import "fmt" + +func main() { + if err := getErr(); err != nil { + fmt.Printf("%+v\n", err) + } +} + +func getErr() error { + return fmt.Errorf("err") +} diff --git a/cmd/errtrace/toolexec.go b/cmd/errtrace/toolexec.go index 76be1c3..7ed95a3 100644 --- a/cmd/errtrace/toolexec.go +++ b/cmd/errtrace/toolexec.go @@ -18,6 +18,9 @@ import ( "braces.dev/errtrace" ) +// Note: Choose a prefix that is not likely to clash with user symbols. +const errtraceUnsafePrefix = "__errtrace_" + func (cmd *mainCmd) handleToolExec(args []string) (exitCode int, handled bool) { // In toolexec mode, we're passed the original command + arguments. if len(args) == 0 { @@ -51,6 +54,7 @@ func (cmd *mainCmd) handleToolExec(args []string) (exitCode int, handled bool) { type toolExecParams struct { RequiredPkgSelectors []string + UnsafePkgSelectors []string Tool string ToolArgs []string @@ -64,9 +68,11 @@ func (p *toolExecParams) Parse(w io.Writer, args []string) error { logln(w, `usage with go build/run/test: -toolexec="errtrace [options]"`) flag.PrintDefaults() } - var requiredPkgs string + var requiredPkgs, unsafePkgs string p.flags.StringVar(&requiredPkgs, "required-packages", "", "comma-separated list of package selectors "+ "that are expected to be import errtrace if they return errors.") + p.flags.StringVar(&unsafePkgs, "unsafe-packages", "", "comma-separated list of package selectors "+ + "to rewrite using unsafe go:link, regardless of whether they import errtrace.") // Flag parsing stops at the first non-flag argument (no "-"). if err := p.flags.Parse(args); err != nil { @@ -81,6 +87,7 @@ func (p *toolExecParams) Parse(w io.Writer, args []string) error { p.Tool = remArgs[0] p.ToolArgs = remArgs[1:] p.RequiredPkgSelectors = strings.Split(requiredPkgs, ",") + p.UnsafePkgSelectors = strings.Split(unsafePkgs, ",") return nil } @@ -106,6 +113,27 @@ func (p *toolExecParams) requiredPackage(pkg string) bool { return false } +func (p *toolExecParams) unsafeRewriteStd() bool { + // stdlib requires an explicit opt-in. + // Since there's known issues with error checks in the stdlib + // which can break with error wrapping, we call it std-unsafe. + return slices.Contains(p.UnsafePkgSelectors, "std-unsafe") +} + +func (p *toolExecParams) unsafeRewrite(pkg string) bool { + if pkg == errtracePkgImport { + // Never rewrite the errtrace package, which leads to circular deps. + return false + } + + for _, selector := range p.UnsafePkgSelectors { + if packageSelectorMatch(selector, pkg) { + return true + } + } + return false +} + func (cmd *mainCmd) toolExecVersion(p toolExecParams) int { version, err := binaryVersion() if err != nil { @@ -145,8 +173,8 @@ func (cmd *mainCmd) toolExecRewrite(pkg string, p toolExecParams) (exitCode int) return cmd.runOriginal(p) } - // We only modify files that import errtrace, so stdlib is never eliglble. - if isStdLib(p.ToolArgs) { + // We only modify files that import errtrace, so stdlib is only eligible in unsafe mode. + if isStdLib(p.ToolArgs) && !p.unsafeRewriteStd() { return cmd.runOriginal(p) } @@ -160,43 +188,25 @@ func (cmd *mainCmd) toolExecRewrite(pkg string, p toolExecParams) (exitCode int) } func (cmd *mainCmd) rewriteCompile(pkg string, p toolExecParams) (exitCode int, _ error) { - var canRewrite, needRewrite bool - parsed := make(map[string]parsedFile) - for _, arg := range p.ToolArgs { - if !isGoFile(arg) { - continue - } - - contents, err := os.ReadFile(arg) - if err != nil { - return -1, errtrace.Wrap(err) - } - - f, err := cmd.parseFile(arg, contents, rewriteOpts{}) - if err != nil { - return -1, errtrace.Wrap(err) - } - parsed[arg] = f - - // TODO: Support an "unsafe" mode to rewrite packages without errtrace imports. - if f.importsErrtrace { - canRewrite = true - } - if len(f.inserts) > 0 { - needRewrite = true - } + parsed, err := cmd.parsePkg(pkg, p.ToolArgs) + if err != nil { + return -1, errtrace.Wrap(err) } - if !needRewrite { + if !parsed.needsRewrite { return cmd.runOriginal(p), nil } - if !canRewrite { - if p.requiredPackage(pkg) { + var unsafeForceImport bool + if !parsed.importsErrtrace { + unsafeForceImport = p.unsafeRewrite(pkg) + if !unsafeForceImport && p.requiredPackage(pkg) { logf(cmd.Stderr, "errtrace required package %v missing errtrace import, needs rewrite", pkg) return 1, nil } - return cmd.runOriginal(p), nil + if !unsafeForceImport { + return cmd.runOriginal(p), nil + } } // Use a temporary directory per-package that is rewritten. @@ -206,22 +216,36 @@ func (cmd *mainCmd) rewriteCompile(pkg string, p toolExecParams) (exitCode int, } defer os.RemoveAll(tempDir) //nolint:errcheck // best-effort removal of temp files. + // If a package doesn't already import errtrace, add `go:linkname` to the + // package to link to errtrace symbols. Only required once per-package. + addLinkName := unsafeForceImport + newArgs := make([]string, 0, len(p.ToolArgs)) for _, arg := range p.ToolArgs { - f, ok := parsed[arg] + f, ok := parsed.files[arg] if !ok || len(f.inserts) == 0 { newArgs = append(newArgs, arg) continue } + if unsafeForceImport { + f.errtraceUnsafePrefix = errtraceUnsafePrefix + } + // Add a //line directive so the original filepath is used in errors and panics. - var out bytes.Buffer - _, _ = fmt.Fprintf(&out, "//line %v:1\n", arg) + out := &bytes.Buffer{} + _, _ = fmt.Fprintf(out, "//line %v:1\n", arg) - if err := cmd.rewriteFile(f, &out); err != nil { + if err := cmd.rewriteFile(f, out); err != nil { return -1, errtrace.Wrap(err) } + if addLinkName { + _, _ = fmt.Fprintf(out, "\n\n//go:linkname %vWrap %v.Wrap\n", errtraceUnsafePrefix, errtracePkgImport) + _, _ = fmt.Fprintf(out, "func %vWrap(err error) error\n", errtraceUnsafePrefix) + addLinkName = false + } + // TODO: Handle clashes with the same base name in different directories (E.g., with bazel). newFile := filepath.Join(tempDir, filepath.Base(arg)) if err := os.WriteFile(newFile, out.Bytes(), 0o666); err != nil { @@ -235,6 +259,52 @@ func (cmd *mainCmd) rewriteCompile(pkg string, p toolExecParams) (exitCode int, return cmd.runOriginal(p), nil } +type parsePkgState struct { + pkg string + files map[string]parsedFile + importsErrtrace bool + needsRewrite bool +} + +func (cmd *mainCmd) parsePkg(pkg string, toolArgs []string) (*parsePkgState, error) { + s := &parsePkgState{ + pkg: pkg, + files: make(map[string]parsedFile), + } + + for _, arg := range toolArgs { + if !isGoFile(arg) { + continue + } + + contents, err := os.ReadFile(arg) + if err != nil { + return nil, errtrace.Wrap(err) + } + + f, err := cmd.parseFile(arg, contents, rewriteOpts{ + // WrapN is not compatible with unsafe rewrites, as `go:linkname` + // can't be used for generic functions like WrapN. + // We don't need WrapN, as it's is meant for direct source file changes, + // while toolexec writes ephemeral temp files. + NoWrapN: true, + }) + if err != nil { + return nil, errtrace.Wrap(err) + } + s.files[arg] = f + + if f.importsErrtrace { + s.importsErrtrace = true + } + if len(f.inserts) > 0 { + s.needsRewrite = true + } + } + + return s, nil +} + func isCompile(arg string) bool { if runtime.GOOS == "windows" { arg = strings.TrimSuffix(arg, ".exe") diff --git a/cmd/errtrace/toolexec_test.go b/cmd/errtrace/toolexec_test.go index 9cf64b9..d87e86b 100644 --- a/cmd/errtrace/toolexec_test.go +++ b/cmd/errtrace/toolexec_test.go @@ -90,6 +90,27 @@ func TestToolExec(t *testing.T) { }, wantTraces: wantTraces, }, + { + name: "toolexec with unsafe-packages ...", + goArgs: func(t testing.TB) []string { + return []string{"-toolexec", errTraceCmd + " -unsafe-packages=...", "."} + }, + wantTraces: append(wantTraces, tracePaths(t, testProgDir, "@unsafe-trace")...), + }, + { + name: "toolexec with unsafe-packages test/...", + goArgs: func(t testing.TB) []string { + return []string{"-toolexec", errTraceCmd + " -unsafe-packages=braces.dev/errtrace/cmd/errtrace/testdata/toolexec-test/...", "."} + }, + wantTraces: append(wantTraces, tracePaths(t, testProgDir, "@unsafe-trace")...), + }, + { + name: "toolexec with unsafe-packages test/p1", + goArgs: func(t testing.TB) []string { + return []string{"-toolexec", errTraceCmd + " -unsafe-packages=braces.dev/errtrace/cmd/errtrace/testdata/toolexec-test/p1", "."} + }, + wantTraces: append(wantTraces, tracePaths(t, testProgDir, "@unsafe-trace")...), + }, } for _, tt := range tests { @@ -116,6 +137,7 @@ func TestToolExec(t *testing.T) { verifyTraces := func(t testing.TB, stdout string) { gotLines := fileLines(stdout) sort.Strings(gotLines) + sort.Strings(tt.wantTraces) if d := diff.Diff(tt.wantTraces, gotLines); d != "" { t.Errorf("diff in traces:\n%s", d) @@ -152,6 +174,20 @@ func TestToolExec(t *testing.T) { }) }) } + + // When using -unsafe-packages, packages that don't import errtrace can be + // rewritten to use Wrap, but at least one package in the binary still needs + // to import errtrace, otherwise, the final link will fail. + t.Run("unsafe no errtrace import", func(t *testing.T) { + args := []string{"run", "-toolexec", errTraceCmd + " -unsafe-packages=...", "."} + _, stderr, err := runGo(t, "./testdata/toolexec-unsafe-no-import", args...) + if err == nil { + t.Fatal("run should fail") + } + if want := "relocation target braces.dev/errtrace.Wrap not defined"; !strings.Contains(stderr, want) { + t.Fatalf("stderr missing expected error: %v, got:\n%s", want, stderr) + } + }) } func tracePaths(t testing.TB, path string, traceMarker string) []string {