Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased

### Added

- Support compile-time automatic instrumentation of packages
that don't import `errtrace` by specifying the flag `unsafe-packages`.
This still requires at least one import of `errtrace` in the binary.

## 0.4.0 - 2025-07-21

This release supports compile-time rewriting of source files via `toolexec`.
Expand Down
35 changes: 25 additions & 10 deletions cmd/errtrace/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
// -l list files that would be modified without making any changes.
package main

// TODO
// - -toolexec: run as a tool executor, fit for use with 'go build -toolexec'

import (
"bytes"
"encoding/json"
Expand All @@ -52,6 +49,8 @@ import (
"braces.dev/errtrace"
)

const errtracePkgImport = "braces.dev/errtrace"

func main() {
cmd := &mainCmd{
Stdin: os.Stdin,
Expand Down Expand Up @@ -413,6 +412,17 @@ type parsedFile struct {
importsErrtrace bool // includes blank imports
inserts []insert
unusedOptouts []int // list of line numbers

// Used for toolexec unsafe mode, when rewriting packages that don't import
// errtrace, so we use go:linkname wrappers.
errtraceUnsafePrefix string
}

func (f *parsedFile) errtracePkgPrefix() string {
if f.errtraceUnsafePrefix == "" {
return fmt.Sprintf("%s.", f.errtracePkg)
}
return f.errtraceUnsafePrefix
}

func (cmd *mainCmd) parseFile(filename string, src []byte, opts rewriteOpts) (parsedFile, error) {
Expand All @@ -426,7 +436,7 @@ func (cmd *mainCmd) parseFile(filename string, src []byte, opts rewriteOpts) (pa
var importsErrtrace bool // whether there's any errtrace import, including blank imports
needErrtraceImport := true // whether to add a new import.
for _, imp := range f.Imports {
if imp.Path.Value == `"braces.dev/errtrace"` {
if imp.Path.Value == `"`+errtracePkgImport+`"` {
importsErrtrace = true
if imp.Name != nil {
if imp.Name.Name == "_" {
Expand Down Expand Up @@ -574,15 +584,20 @@ func (cmd *mainCmd) rewriteFile(f parsedFile, out *bytes.Buffer) error {
_, _ = io.WriteString(out, "import ")
}

if f.errtracePkg == "errtrace" {
if f.errtraceUnsafePrefix != "" {
// unsafe mode uses go:link, which requires importing "unsafe" instead
// of errtrace. Since duplicate blank imports are allowed by Go, we can
// always add it, without checking if it has been imported.
fmt.Fprint(out, `_ "unsafe"`)
} else if f.errtracePkg == "errtrace" {
// Don't use named imports if we're using the default name.
fmt.Fprintf(out, "%q", "braces.dev/errtrace")
fmt.Fprintf(out, "%q", errtracePkgImport)
} else {
fmt.Fprintf(out, "%s %q", f.errtracePkg, "braces.dev/errtrace")
fmt.Fprintf(out, "%s %q", f.errtracePkg, errtracePkgImport)
}

case *insertWrapOpen:
fmt.Fprintf(out, "%s.Wrap", f.errtracePkg)
fmt.Fprintf(out, "%sWrap", f.errtracePkgPrefix())
if it.N > 1 {
fmt.Fprintf(out, "%d", it.N)
}
Expand All @@ -605,7 +620,7 @@ func (cmd *mainCmd) rewriteFile(f parsedFile, out *bytes.Buffer) error {

// Last return is an error, wrap it.
last := &vars[len(vars)-1]
*last = fmt.Sprintf("%s.Wrap(%v)", f.errtracePkg, *last)
*last = fmt.Sprintf("%sWrap(%v)", f.errtracePkgPrefix(), *last)

fmt.Fprintf(out, "; return %s }", strings.Join(vars, ", "))

Expand All @@ -625,7 +640,7 @@ func (cmd *mainCmd) rewriteFile(f parsedFile, out *bytes.Buffer) error {
if i > 0 {
_, _ = out.WriteString(", ")
}
fmt.Fprintf(out, "%s.Wrap(%s)", f.errtracePkg, name)
fmt.Fprintf(out, "%sWrap(%s)", f.errtracePkgPrefix(), name)
}
_, _ = out.WriteString("; ")

Expand Down
2 changes: 1 addition & 1 deletion cmd/errtrace/testdata/toolexec-test/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ func main() {
}

func callP1() error {
return p1.WrapP2() // @trace
return p1.WrapP2OnlyErr() // @trace
}
14 changes: 11 additions & 3 deletions cmd/errtrace/testdata/toolexec-test/p1/p1.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@ import (
"braces.dev/errtrace/cmd/errtrace/testdata/toolexec-test/p2"
)

// WrapP2 wraps an error return from p2.
func WrapP2() error {
return fmt.Errorf("test2: %w", p2.CallP3())
// WrapP2OnlyErr only returns the error from WrapP2.
func WrapP2OnlyErr() error {
if _, err := WrapP2(); err != nil {
return fmt.Errorf("test2: %w", err) // @unsafe-trace
}
return nil
}

// WrapRet2 calls WrapP2, but has a multi-return.
func WrapP2() (string, error) {
return p2.CallP3() // @unsafe-trace
}
4 changes: 2 additions & 2 deletions cmd/errtrace/testdata/toolexec-test/p2/p2.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ import (
)

// CallP3 calls p3, and wraps the error.
func CallP3() error {
return errtrace.Wrap(p3.ReturnErr()) // @trace
func CallP3() (string, error) {
return errtrace.Wrap2(p3.ReturnStrErr()) // @trace
}
6 changes: 3 additions & 3 deletions cmd/errtrace/testdata/toolexec-test/p3/p3.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"errors"
)

// ReturnErr returns an error.
func ReturnErr() error {
return errors.New("test") // @trace
// ReturnStrErr returns an error.
func ReturnStrErr() (string, error) {
return "", errors.New("test") // @trace
}
13 changes: 13 additions & 0 deletions cmd/errtrace/testdata/toolexec-unsafe-no-import/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package main

import "fmt"

func main() {
if err := getErr(); err != nil {
fmt.Printf("%+v\n", err)
}
}

func getErr() error {
return fmt.Errorf("err")
}
142 changes: 106 additions & 36 deletions cmd/errtrace/toolexec.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ import (
"braces.dev/errtrace"
)

// Note: Choose a prefix that is not likely to clash with user symbols.
const errtraceUnsafePrefix = "__errtrace_"

func (cmd *mainCmd) handleToolExec(args []string) (exitCode int, handled bool) {
// In toolexec mode, we're passed the original command + arguments.
if len(args) == 0 {
Expand Down Expand Up @@ -51,6 +54,7 @@ func (cmd *mainCmd) handleToolExec(args []string) (exitCode int, handled bool) {

type toolExecParams struct {
RequiredPkgSelectors []string
UnsafePkgSelectors []string

Tool string
ToolArgs []string
Expand All @@ -64,9 +68,11 @@ func (p *toolExecParams) Parse(w io.Writer, args []string) error {
logln(w, `usage with go build/run/test: -toolexec="errtrace [options]"`)
flag.PrintDefaults()
}
var requiredPkgs string
var requiredPkgs, unsafePkgs string
p.flags.StringVar(&requiredPkgs, "required-packages", "", "comma-separated list of package selectors "+
"that are expected to be import errtrace if they return errors.")
p.flags.StringVar(&unsafePkgs, "unsafe-packages", "", "comma-separated list of package selectors "+
"to rewrite using unsafe go:link, regardless of whether they import errtrace.")

// Flag parsing stops at the first non-flag argument (no "-").
if err := p.flags.Parse(args); err != nil {
Expand All @@ -81,6 +87,7 @@ func (p *toolExecParams) Parse(w io.Writer, args []string) error {
p.Tool = remArgs[0]
p.ToolArgs = remArgs[1:]
p.RequiredPkgSelectors = strings.Split(requiredPkgs, ",")
p.UnsafePkgSelectors = strings.Split(unsafePkgs, ",")
return nil
}

Expand All @@ -106,6 +113,27 @@ func (p *toolExecParams) requiredPackage(pkg string) bool {
return false
}

func (p *toolExecParams) unsafeRewriteStd() bool {
// stdlib requires an explicit opt-in.
// Since there's known issues with error checks in the stdlib
// which can break with error wrapping, we call it std-unsafe.
return slices.Contains(p.UnsafePkgSelectors, "std-unsafe")
}

func (p *toolExecParams) unsafeRewrite(pkg string) bool {
if pkg == errtracePkgImport {
// Never rewrite the errtrace package, which leads to circular deps.
return false
}

for _, selector := range p.UnsafePkgSelectors {
if packageSelectorMatch(selector, pkg) {
return true
}
}
return false
}

func (cmd *mainCmd) toolExecVersion(p toolExecParams) int {
version, err := binaryVersion()
if err != nil {
Expand Down Expand Up @@ -145,8 +173,8 @@ func (cmd *mainCmd) toolExecRewrite(pkg string, p toolExecParams) (exitCode int)
return cmd.runOriginal(p)
}

// We only modify files that import errtrace, so stdlib is never eliglble.
if isStdLib(p.ToolArgs) {
// We only modify files that import errtrace, so stdlib is only eligible in unsafe mode.
if isStdLib(p.ToolArgs) && !p.unsafeRewriteStd() {
return cmd.runOriginal(p)
}

Expand All @@ -160,43 +188,25 @@ func (cmd *mainCmd) toolExecRewrite(pkg string, p toolExecParams) (exitCode int)
}

func (cmd *mainCmd) rewriteCompile(pkg string, p toolExecParams) (exitCode int, _ error) {
var canRewrite, needRewrite bool
parsed := make(map[string]parsedFile)
for _, arg := range p.ToolArgs {
if !isGoFile(arg) {
continue
}

contents, err := os.ReadFile(arg)
if err != nil {
return -1, errtrace.Wrap(err)
}

f, err := cmd.parseFile(arg, contents, rewriteOpts{})
if err != nil {
return -1, errtrace.Wrap(err)
}
parsed[arg] = f

// TODO: Support an "unsafe" mode to rewrite packages without errtrace imports.
if f.importsErrtrace {
canRewrite = true
}
if len(f.inserts) > 0 {
needRewrite = true
}
parsed, err := cmd.parsePkg(pkg, p.ToolArgs)
if err != nil {
return -1, errtrace.Wrap(err)
}

if !needRewrite {
if !parsed.needsRewrite {
return cmd.runOriginal(p), nil
}

if !canRewrite {
if p.requiredPackage(pkg) {
var unsafeForceImport bool
if !parsed.importsErrtrace {
unsafeForceImport = p.unsafeRewrite(pkg)
if !unsafeForceImport && p.requiredPackage(pkg) {
logf(cmd.Stderr, "errtrace required package %v missing errtrace import, needs rewrite", pkg)
return 1, nil
}
return cmd.runOriginal(p), nil
if !unsafeForceImport {
return cmd.runOriginal(p), nil
}
}

// Use a temporary directory per-package that is rewritten.
Expand All @@ -206,22 +216,36 @@ func (cmd *mainCmd) rewriteCompile(pkg string, p toolExecParams) (exitCode int,
}
defer os.RemoveAll(tempDir) //nolint:errcheck // best-effort removal of temp files.

// If a package doesn't already import errtrace, add `go:linkname` to the
// package to link to errtrace symbols. Only required once per-package.
addLinkName := unsafeForceImport

newArgs := make([]string, 0, len(p.ToolArgs))
for _, arg := range p.ToolArgs {
f, ok := parsed[arg]
f, ok := parsed.files[arg]
if !ok || len(f.inserts) == 0 {
newArgs = append(newArgs, arg)
continue
}

if unsafeForceImport {
f.errtraceUnsafePrefix = errtraceUnsafePrefix
}

// Add a //line directive so the original filepath is used in errors and panics.
var out bytes.Buffer
_, _ = fmt.Fprintf(&out, "//line %v:1\n", arg)
out := &bytes.Buffer{}
_, _ = fmt.Fprintf(out, "//line %v:1\n", arg)

if err := cmd.rewriteFile(f, &out); err != nil {
if err := cmd.rewriteFile(f, out); err != nil {
return -1, errtrace.Wrap(err)
}

if addLinkName {
_, _ = fmt.Fprintf(out, "\n\n//go:linkname %vWrap %v.Wrap\n", errtraceUnsafePrefix, errtracePkgImport)
_, _ = fmt.Fprintf(out, "func %vWrap(err error) error\n", errtraceUnsafePrefix)
addLinkName = false
}

// TODO: Handle clashes with the same base name in different directories (E.g., with bazel).
newFile := filepath.Join(tempDir, filepath.Base(arg))
if err := os.WriteFile(newFile, out.Bytes(), 0o666); err != nil {
Expand All @@ -235,6 +259,52 @@ func (cmd *mainCmd) rewriteCompile(pkg string, p toolExecParams) (exitCode int,
return cmd.runOriginal(p), nil
}

type parsePkgState struct {
pkg string
files map[string]parsedFile
importsErrtrace bool
needsRewrite bool
}

func (cmd *mainCmd) parsePkg(pkg string, toolArgs []string) (*parsePkgState, error) {
s := &parsePkgState{
pkg: pkg,
files: make(map[string]parsedFile),
}

for _, arg := range toolArgs {
if !isGoFile(arg) {
continue
}

contents, err := os.ReadFile(arg)
if err != nil {
return nil, errtrace.Wrap(err)
}

f, err := cmd.parseFile(arg, contents, rewriteOpts{
// WrapN is not compatible with unsafe rewrites, as `go:linkname`
// can't be used for generic functions like WrapN.
// We don't need WrapN, as it's is meant for direct source file changes,
// while toolexec writes ephemeral temp files.
NoWrapN: true,
})
if err != nil {
return nil, errtrace.Wrap(err)
}
s.files[arg] = f

if f.importsErrtrace {
s.importsErrtrace = true
}
if len(f.inserts) > 0 {
s.needsRewrite = true
}
}

return s, nil
}

func isCompile(arg string) bool {
if runtime.GOOS == "windows" {
arg = strings.TrimSuffix(arg, ".exe")
Expand Down
Loading
Loading