diff --git a/flag.go b/flag.go index 7c058de3..529c84bb 100644 --- a/flag.go +++ b/flag.go @@ -165,6 +165,7 @@ type FlagSet struct { normalizeNameFunc func(f *FlagSet, name string) NormalizedName addedGoFlagSets []*goflag.FlagSet + unknownFlags []*Flag } // A Flag represents the state of a flag. @@ -182,6 +183,12 @@ type Flag struct { Annotations map[string][]string // used by cobra.Command bash autocomple code } +// A UnknownFlag represents the state of a flag that is not expected. +type UnknownFlag struct { + Name string // name as it appears on command line + Value Value // value as set +} + // Value is the interface to the dynamic value stored in a flag. // (The default value is represented as a string.) type Value interface { @@ -275,6 +282,17 @@ func (f *FlagSet) SetOutput(output io.Writer) { f.output = output } +// VisitUnknowns visits all the flags that have not been registered. +func (f *FlagSet) VisitUnknowns(fn func(*Flag)) { + if len(f.unknownFlags) == 0 { + return + } + + for _, flag := range f.unknownFlags { + fn(flag) + } +} + // VisitAll visits the flags in lexicographical order or // in primordial order if f.SortFlags is false, calling fn for each. // It visits all flags, even those not set. @@ -956,6 +974,18 @@ func stripUnknownFlagValue(args []string) []string { return nil } +func createUnknownFlag(name string, value string) *Flag { + flag := new(Flag) + flag.Name = name + flag.Value = newStringValue(value, &value) + return flag +} + +func (f *FlagSet) addUnknownFlag(name string, value string) { + flag := createUnknownFlag(name, value) + f.unknownFlags = append(f.unknownFlags, flag) +} + func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []string, err error) { a = args name := s[2:] @@ -969,19 +999,11 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin flag, exists := f.formal[f.normalizeFlagName(name)] if !exists { - switch { - case name == "help": + if name == "help" { f.usage() return a, ErrHelp - case f.ParseErrorsWhitelist.UnknownFlags: - // --unknown=unknownval arg ... - // we do not want to lose arg in this case - if len(split) >= 2 { - return a, nil - } - - return stripUnknownFlagValue(a), nil - default: + } + if !f.ParseErrorsWhitelist.UnknownFlags { err = f.failf("unknown flag: --%s", name) return } @@ -991,16 +1013,23 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin if len(split) == 2 { // '--flag=arg' value = split[1] - } else if flag.NoOptDefVal != "" { + } else if exists && flag.NoOptDefVal != "" { // '--flag' (arg was optional) value = flag.NoOptDefVal } else if len(a) > 0 { // '--flag arg' - value = a[0] - a = a[1:] - } else { - // '--flag' (arg was required) - err = f.failf("flag needs an argument: %s", s) + if !exists && strings.HasPrefix(a[0], "-") { + value = "" + } else { + value = a[0] + a = a[1:] + } + } else if f.ParseErrorsWhitelist.UnknownFlags { + value = "" + } + + if !exists && f.ParseErrorsWhitelist.UnknownFlags { + f.addUnknownFlag(name, value) return } @@ -1023,22 +1052,12 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse flag, exists := f.shorthands[c] if !exists { - switch { - case c == 'h': + if c == 'h' { f.usage() err = ErrHelp return - case f.ParseErrorsWhitelist.UnknownFlags: - // '-f=arg arg ...' - // we do not want to lose arg in this case - if len(shorthands) > 2 && shorthands[1] == '=' { - outShorts = "" - return - } - - outArgs = stripUnknownFlagValue(outArgs) - return - default: + } + if !f.ParseErrorsWhitelist.UnknownFlags { err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands) return } @@ -1049,8 +1068,8 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse // '-f=arg' value = shorthands[2:] outShorts = "" - } else if flag.NoOptDefVal != "" { - // '-f' (arg was optional) + } else if exists && flag.NoOptDefVal != "" { + // '--flag' (arg was optional) value = flag.NoOptDefVal } else if len(shorthands) > 1 { // '-farg' @@ -1058,9 +1077,23 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse outShorts = "" } else if len(args) > 0 { // '-f arg' - value = args[0] - outArgs = args[1:] - } else { + if !exists && strings.HasPrefix(args[0], "-") { + value = "" + } else { + value = args[0] + outArgs = args[1:] + } + + } else if f.ParseErrorsWhitelist.UnknownFlags { + value = "" + } + + if !exists && f.ParseErrorsWhitelist.UnknownFlags { + f.addUnknownFlag(string(c), value) + return + } + + if flag.NoOptDefVal == "" && value == "" { // '-f' (arg was required) err = f.failf("flag needs an argument: %q in -%s", c, shorthands) return diff --git a/flag_test.go b/flag_test.go index 58a5d25a..13e9e242 100644 --- a/flag_test.go +++ b/flag_test.go @@ -480,6 +480,68 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) { } } +func testRetrieveUknowsWhenUnknownFlagsParsed(t *testing.T) { + f := NewFlagSet("unknwonFlags", ContinueOnError) + if f.Parsed() { + t.Error("f.Parse() = true before Parse") + } + boolaFlag := f.BoolP("boola", "a", false, "bool value") + stringaFlag := f.StringP("stringa", "s", "0", "string value") + + args := []string{ + "-a", + "--stringa", + "hello", + "--unknownFlag1", + "unknownValue1", + "--unknownFlag2", + "--unknownFlag3=unknownValue3", + "-e", + "unknownValue4", + "-f=unknownValue5", + "-g", + } + + f.ParseErrorsWhitelist.UnknownFlags = true + + want := map[string]string{ + "unknownFlag1": "unknownValue1", + "unknownFlag2": "", + "unknownFlag3": "unknownValue3", + "e": "unknownValue4", + "f": "unknownValue5", + "g": "", + } + + f.SetOutput(ioutil.Discard) + if err := f.Parse(args); err != nil { + t.Error("expected no error, got ", err) + } + if !f.Parsed() { + t.Error("f.Parse() = false after Parse") + } + if *boolaFlag != true { + t.Error("boola flag should be true, is ", *boolaFlag) + } + if *stringaFlag != "hello" { + t.Error("stringa flag should be `hello`, is ", *stringaFlag) + } + if len(f.unknownFlags) != len(want) { + t.Errorf("f.ParseAll() failed to parse unknown flags") + } + for _, flag := range f.unknownFlags { + wantedValue, ok := want[flag.Name] + if !ok { + t.Errorf("f.unknownFlags contains a flag \"%s\" and shouldn't", flag.Name) + break + } + if wantedValue != flag.Value.String() { + t.Errorf("value for the unknown flag \"%s\" should be \"%s\", got \"%s\"", flag.Name, wantedValue, flag.Value.String()) + } + + } +} + func TestShorthand(t *testing.T) { f := NewFlagSet("shorthand", ContinueOnError) if f.Parsed() { @@ -588,6 +650,7 @@ func TestParseAll(t *testing.T) { func TestIgnoreUnknownFlags(t *testing.T) { ResetForTesting(func() { t.Error("bad parse") }) + testRetrieveUknowsWhenUnknownFlagsParsed(t) testParseWithUnknownFlags(GetCommandLine(), t) }