Skip to content

Commit 82de601

Browse files
authored
Add flag name suggestions for misspelled flags (#4723)
## Why When users misspell a flag name (e.g., `--outpu` instead of `--output`), they get a generic "unknown flag" error with no help. Cobra already suggests corrections for misspelled command names, but not flags. ## Changes Before: Misspelled flags produce a generic "unknown flag" error with no guidance. Now: The `flagErrorFunc` suggests the closest matching flag using Levenshtein distance (threshold of 2, matching Cobra's own suggestion threshold for commands). Both long flags (`--flagname`) and shorthand flags (`-x`) are handled. Hidden and deprecated flags are excluded from suggestions. A small Levenshtein distance function is included inline (no new dependencies). ## Test plan - [x] Unit tests for suggestion matching (close match, no match, shorthand) - [x] Unit tests for hidden flag exclusion - [x] Unit tests for the Levenshtein distance function - [x] Regression test for Cobra's error format parsing - [x] `make checks` passes
1 parent 1f4a349 commit 82de601

File tree

3 files changed

+366
-1
lines changed

3 files changed

+366
-1
lines changed

cmd/root/flag_suggestions.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package root
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"strings"
7+
8+
"github.com/spf13/cobra"
9+
"github.com/spf13/pflag"
10+
)
11+
12+
const maxSuggestionDistance = 2
13+
14+
// levenshteinDistance computes the edit distance between two strings.
15+
func levenshteinDistance(a, b string) int {
16+
if len(a) == 0 {
17+
return len(b)
18+
}
19+
if len(b) == 0 {
20+
return len(a)
21+
}
22+
23+
// Use a single row for the DP table.
24+
prev := make([]int, len(b)+1)
25+
for j := range len(b) + 1 {
26+
prev[j] = j
27+
}
28+
29+
for i := range len(a) {
30+
curr := make([]int, len(b)+1)
31+
curr[0] = i + 1
32+
for j := range len(b) {
33+
cost := 1
34+
if a[i] == b[j] {
35+
cost = 0
36+
}
37+
curr[j+1] = min(
38+
curr[j]+1, // insertion
39+
prev[j+1]+1, // deletion
40+
prev[j]+cost, // substitution
41+
)
42+
}
43+
prev = curr
44+
}
45+
46+
return prev[len(b)]
47+
}
48+
49+
// suggestFlagFromError inspects the error from Cobra for unknown-flag errors.
50+
// If a close match is found among the command's flags, it returns an enhanced error
51+
// with a "Did you mean" suggestion appended. Otherwise it returns the original error.
52+
func suggestFlagFromError(cmd *cobra.Command, err error) error {
53+
var notExist *pflag.NotExistError
54+
if !errors.As(err, &notExist) {
55+
return err
56+
}
57+
58+
flagName := notExist.GetSpecifiedName()
59+
isShorthand := notExist.GetSpecifiedShortnames() != ""
60+
61+
if isShorthand {
62+
return suggestShorthandFlag(cmd, err, flagName)
63+
}
64+
65+
return suggestLongFlag(cmd, err, flagName)
66+
}
67+
68+
// suggestLongFlag suggests a matching long flag name for an unknown long flag error.
69+
func suggestLongFlag(cmd *cobra.Command, original error, flagName string) error {
70+
if flagName == "" {
71+
return original
72+
}
73+
74+
best, bestDist := findClosestFlag(cmd, flagName)
75+
if best == "" || bestDist > maxSuggestionDistance {
76+
return original
77+
}
78+
79+
return fmt.Errorf("%w\n\nDid you mean \"--%s\"?", original, best)
80+
}
81+
82+
// suggestShorthandFlag suggests a matching shorthand for an unknown shorthand flag error.
83+
func suggestShorthandFlag(cmd *cobra.Command, original error, flagName string) error {
84+
if flagName == "" {
85+
return original
86+
}
87+
ch := string(flagName[0])
88+
89+
best := findClosestShorthand(cmd, ch)
90+
if best == "" {
91+
return original
92+
}
93+
94+
return fmt.Errorf("%w\n\nDid you mean \"-%s\"?", original, best)
95+
}
96+
97+
// findClosestFlag returns the closest non-hidden, non-deprecated long flag name
98+
// and its edit distance from the given misspelled name.
99+
func findClosestFlag(cmd *cobra.Command, name string) (string, int) {
100+
best := ""
101+
bestDist := maxSuggestionDistance + 1
102+
103+
seen := map[string]bool{}
104+
check := func(f *pflag.Flag) {
105+
if f.Hidden || f.Deprecated != "" {
106+
return
107+
}
108+
if seen[f.Name] {
109+
return
110+
}
111+
seen[f.Name] = true
112+
113+
d := levenshteinDistance(name, f.Name)
114+
if d < bestDist {
115+
bestDist = d
116+
best = f.Name
117+
}
118+
}
119+
120+
cmd.Flags().VisitAll(check)
121+
cmd.InheritedFlags().VisitAll(check)
122+
123+
return best, bestDist
124+
}
125+
126+
// findClosestShorthand returns a case-insensitive exact match for the given
127+
// shorthand character. Levenshtein is not useful for single characters because
128+
// any two distinct characters always have distance 1.
129+
func findClosestShorthand(cmd *cobra.Command, ch string) string {
130+
best := ""
131+
seen := map[string]bool{}
132+
check := func(f *pflag.Flag) {
133+
if f.Hidden || f.Deprecated != "" || f.ShorthandDeprecated != "" || f.Shorthand == "" {
134+
return
135+
}
136+
if seen[f.Shorthand] {
137+
return
138+
}
139+
seen[f.Shorthand] = true
140+
if strings.EqualFold(ch, f.Shorthand) {
141+
best = f.Shorthand
142+
}
143+
}
144+
cmd.Flags().VisitAll(check)
145+
cmd.InheritedFlags().VisitAll(check)
146+
return best
147+
}

cmd/root/flag_suggestions_test.go

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
package root
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"testing"
7+
8+
"github.com/spf13/cobra"
9+
"github.com/spf13/pflag"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
// parseUnknownFlag triggers Cobra's flag parsing on args and returns the error.
15+
// The command is set up with DisableFlagParsing=false (default) and a
16+
// RunE that does nothing, so the only errors come from flag parsing.
17+
func parseUnknownFlag(cmd *cobra.Command, args []string) error {
18+
cmd.RunE = func(cmd *cobra.Command, args []string) error { return nil }
19+
cmd.SetArgs(args)
20+
return cmd.Execute()
21+
}
22+
23+
func TestLevenshteinDistance(t *testing.T) {
24+
tests := []struct {
25+
a, b string
26+
want int
27+
}{
28+
{"", "", 0},
29+
{"abc", "abc", 0},
30+
{"", "abc", 3},
31+
{"abc", "", 3},
32+
{"kitten", "sitting", 3},
33+
{"output", "outpu", 1}, // deletion
34+
{"output", "ouptut", 2}, // transposition = 2 edits
35+
{"output", "outpux", 1}, // substitution
36+
{"output", "outputx", 1}, // insertion
37+
}
38+
39+
for _, tt := range tests {
40+
t.Run(fmt.Sprintf("%s_%s", tt.a, tt.b), func(t *testing.T) {
41+
assert.Equal(t, tt.want, levenshteinDistance(tt.a, tt.b))
42+
})
43+
}
44+
}
45+
46+
func TestSuggestFlagFromError_LongFlagCloseMatch(t *testing.T) {
47+
cmd := &cobra.Command{Use: "test"}
48+
cmd.Flags().String("output", "", "output format")
49+
50+
err := &pflag.NotExistError{}
51+
// Parse "--outpu" to get a real error
52+
parseErr := parseUnknownFlag(cmd, []string{"--outpu"})
53+
require.Error(t, parseErr)
54+
55+
// Extract the pflag error from the cobra wrapping
56+
require.ErrorAs(t, parseErr, &err)
57+
58+
got := suggestFlagFromError(cmd, parseErr)
59+
assert.Contains(t, got.Error(), `Did you mean "--output"?`)
60+
assert.Contains(t, got.Error(), "unknown flag: --outpu")
61+
}
62+
63+
func TestSuggestFlagFromError_LongFlagNoMatch(t *testing.T) {
64+
cmd := &cobra.Command{Use: "test"}
65+
cmd.Flags().String("output", "", "output format")
66+
67+
parseErr := parseUnknownFlag(cmd, []string{"--zzzzzzz"})
68+
require.Error(t, parseErr)
69+
70+
got := suggestFlagFromError(cmd, parseErr)
71+
assert.NotContains(t, got.Error(), "Did you mean")
72+
}
73+
74+
func TestSuggestFlagFromError_ShorthandFlag(t *testing.T) {
75+
cmd := &cobra.Command{Use: "test"}
76+
cmd.Flags().StringP("output", "o", "", "output format")
77+
78+
parseErr := parseUnknownFlag(cmd, []string{"-O"})
79+
require.Error(t, parseErr)
80+
81+
got := suggestFlagFromError(cmd, parseErr)
82+
assert.Contains(t, got.Error(), `Did you mean "-o"?`)
83+
}
84+
85+
func TestSuggestFlagFromError_HiddenFlagsExcluded(t *testing.T) {
86+
cmd := &cobra.Command{Use: "test"}
87+
cmd.Flags().String("secret", "", "secret flag")
88+
require.NoError(t, cmd.Flags().MarkHidden("secret"))
89+
90+
parseErr := parseUnknownFlag(cmd, []string{"--secre"})
91+
require.Error(t, parseErr)
92+
93+
got := suggestFlagFromError(cmd, parseErr)
94+
assert.NotContains(t, got.Error(), "Did you mean")
95+
}
96+
97+
func TestSuggestFlagFromError_DeprecatedFlagsExcluded(t *testing.T) {
98+
cmd := &cobra.Command{Use: "test"}
99+
cmd.Flags().String("legacy", "", "old flag")
100+
require.NoError(t, cmd.Flags().MarkDeprecated("legacy", "use --new instead"))
101+
102+
parseErr := parseUnknownFlag(cmd, []string{"--legac"})
103+
require.Error(t, parseErr)
104+
105+
got := suggestFlagFromError(cmd, parseErr)
106+
assert.NotContains(t, got.Error(), "Did you mean")
107+
}
108+
109+
func TestSuggestFlagFromError_InheritedFlags(t *testing.T) {
110+
parent := &cobra.Command{Use: "parent"}
111+
parent.PersistentFlags().String("profile", "", "auth profile")
112+
113+
child := &cobra.Command{Use: "child"}
114+
child.RunE = func(cmd *cobra.Command, args []string) error { return nil }
115+
parent.AddCommand(child)
116+
117+
parent.SetArgs([]string{"child", "--profil"})
118+
parseErr := parent.Execute()
119+
require.Error(t, parseErr)
120+
121+
got := suggestFlagFromError(child, parseErr)
122+
assert.Contains(t, got.Error(), `Did you mean "--profile"?`)
123+
}
124+
125+
func TestSuggestFlagFromError_NonFlagError(t *testing.T) {
126+
cmd := &cobra.Command{Use: "test"}
127+
cmd.Flags().String("output", "", "output format")
128+
129+
err := errors.New("some other error")
130+
got := suggestFlagFromError(cmd, err)
131+
assert.Equal(t, err.Error(), got.Error())
132+
}
133+
134+
func TestSuggestFlagFromError_DeduplicatesLocalAndInherited(t *testing.T) {
135+
parent := &cobra.Command{Use: "parent"}
136+
parent.PersistentFlags().String("target", "", "deployment target")
137+
138+
child := &cobra.Command{Use: "child"}
139+
child.Flags().String("target", "", "deployment target")
140+
child.RunE = func(cmd *cobra.Command, args []string) error { return nil }
141+
parent.AddCommand(child)
142+
143+
parent.SetArgs([]string{"child", "--targe"})
144+
parseErr := parent.Execute()
145+
require.Error(t, parseErr)
146+
147+
got := suggestFlagFromError(child, parseErr)
148+
assert.Contains(t, got.Error(), `Did you mean "--target"?`)
149+
}
150+
151+
func TestSuggestFlagFromError_ShorthandUnrelatedNoSuggestion(t *testing.T) {
152+
cmd := &cobra.Command{Use: "test"}
153+
cmd.Flags().StringP("output", "o", "", "output format")
154+
155+
parseErr := parseUnknownFlag(cmd, []string{"-z"})
156+
require.Error(t, parseErr)
157+
158+
got := suggestFlagFromError(cmd, parseErr)
159+
assert.NotContains(t, got.Error(), "Did you mean")
160+
}
161+
162+
func TestSuggestFlagFromError_ShorthandDeprecatedStillSuggestsLongFlag(t *testing.T) {
163+
cmd := &cobra.Command{Use: "test"}
164+
cmd.Flags().StringP("output", "o", "", "output format")
165+
require.NoError(t, cmd.Flags().MarkShorthandDeprecated("output", "use --output instead"))
166+
167+
parseErr := parseUnknownFlag(cmd, []string{"--outpu"})
168+
require.Error(t, parseErr)
169+
170+
// The long flag should still be suggested even though the shorthand is deprecated.
171+
got := suggestFlagFromError(cmd, parseErr)
172+
assert.Contains(t, got.Error(), `Did you mean "--output"?`)
173+
}
174+
175+
func TestSuggestFlagFromError_ShorthandDeprecatedExcludedFromShorthandSuggestions(t *testing.T) {
176+
cmd := &cobra.Command{Use: "test"}
177+
cmd.Flags().StringP("output", "o", "", "output format")
178+
require.NoError(t, cmd.Flags().MarkShorthandDeprecated("output", "use --output instead"))
179+
180+
parseErr := parseUnknownFlag(cmd, []string{"-O"})
181+
require.Error(t, parseErr)
182+
183+
// The deprecated shorthand should NOT be suggested.
184+
got := suggestFlagFromError(cmd, parseErr)
185+
assert.NotContains(t, got.Error(), "Did you mean")
186+
}
187+
188+
func TestSuggestFlagFromError_TieBreakingEquidistantFlags(t *testing.T) {
189+
cmd := &cobra.Command{Use: "test"}
190+
// "ab" and "ac" are both distance 1 from "aa"
191+
cmd.Flags().String("ab", "", "")
192+
cmd.Flags().String("ac", "", "")
193+
194+
parseErr := parseUnknownFlag(cmd, []string{"--aa"})
195+
require.Error(t, parseErr)
196+
197+
got := suggestFlagFromError(cmd, parseErr)
198+
// Both are equidistant; we accept whichever is returned (order depends on
199+
// flag iteration) but a suggestion must be present.
200+
assert.Contains(t, got.Error(), "Did you mean")
201+
}
202+
203+
func TestSuggestFlagFromError_IntegrationThroughFlagErrorFunc(t *testing.T) {
204+
cmd := &cobra.Command{Use: "test"}
205+
cmd.Flags().String("output", "", "output format")
206+
cmd.SetFlagErrorFunc(flagErrorFunc)
207+
cmd.RunE = func(cmd *cobra.Command, args []string) error { return nil }
208+
cmd.SetArgs([]string{"--outpu"})
209+
210+
err := cmd.Execute()
211+
require.Error(t, err)
212+
213+
assert.Contains(t, err.Error(), `Did you mean "--output"?`)
214+
// flagErrorFunc also appends usage
215+
assert.Contains(t, err.Error(), "Usage:")
216+
}

cmd/root/root.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,10 @@ func New(ctx context.Context) *cobra.Command {
9797
return cmd
9898
}
9999

100-
// Wrap flag errors to include the usage string.
100+
// flagErrorFunc wraps flag errors to include the usage string and, for unknown
101+
// flags, a "Did you mean" suggestion based on Levenshtein distance.
101102
func flagErrorFunc(c *cobra.Command, err error) error {
103+
err = suggestFlagFromError(c, err)
102104
return fmt.Errorf("%w\n\n%s", err, c.UsageString())
103105
}
104106

0 commit comments

Comments
 (0)