diff --git a/flag.go b/flag.go index 107fa190..e6726966 100644 --- a/flag.go +++ b/flag.go @@ -137,10 +137,42 @@ const ( PanicOnError ) +// UnknownFlagsHandling decides how to handle unknown flags +type UnknownFlagsHandling int + +const ( + // UnknownFlagsHandlingErrorOnUnknown will return an error if an unknown flag is found + UnknownFlagsHandlingErrorOnUnknown UnknownFlagsHandling = iota + // UnknownFlagsHandlingIgnoreUnknown will ignore unknown flags and continue parsing rest of the flags + UnknownFlagsHandlingIgnoreUnknown + // UnknownFlagsHandlingPassUnknownToArgs will treat unknown flags as non-flag arguments. + // Combined shorthand flags mixed with known ones and unknown ones results + // combined flags only with unknown ones. + // E.g. -fghi results -gh if only `f` and `i` are known. + UnknownFlagsHandlingPassUnknownToArgs +) + // ParseErrorsAllowlist defines the parsing errors that can be ignored type ParseErrorsAllowlist struct { // UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags + // Deprecated: Use UnknownFlagsHandling instead UnknownFlags bool + + // UnknownFlagsHandling decides how to handle unknown flags. Defaults to UnknownFlagsHandlingErrorOnUnknown. + UnknownFlagsHandling UnknownFlagsHandling +} + +// getUnknownFlagsHandling returns the UnknownFlagsHandling value, considering deprecated UnknownFlags field +func (a *ParseErrorsAllowlist) getUnknownFlagsHandling() UnknownFlagsHandling { + // if UnknownFlagsHandling is set, use it + if a.UnknownFlagsHandling != UnknownFlagsHandlingErrorOnUnknown { + return a.UnknownFlagsHandling + } + + if a.UnknownFlags { + return UnknownFlagsHandlingIgnoreUnknown + } + return UnknownFlagsHandlingErrorOnUnknown } // NormalizedName is a flag name that has been normalized according to rules @@ -967,6 +999,17 @@ func stripUnknownFlagValue(args []string) []string { return nil } +// errUnknownFlag is used for internal unknown flag handling. +type unknownFlagError struct { + // UnknownFlags is flags that are unknown and unprocessed. + // It depends on the context whether this has a prefix like '-' or '--'. + UnknownFlags string +} + +func (e *unknownFlagError) Error() string { + return fmt.Sprintf("unknown flag: %v", e.UnknownFlags) +} + func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []string, err error) { a = args name := s[2:] @@ -978,13 +1021,14 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin split := strings.SplitN(name, "=", 2) name = split[0] flag, exists := f.formal[f.normalizeFlagName(name)] + unknownFlagsHandling := f.ParseErrorsAllowlist.getUnknownFlagsHandling() if !exists { switch { case name == "help": f.usage() return a, ErrHelp - case f.ParseErrorsAllowlist.UnknownFlags: + case unknownFlagsHandling == UnknownFlagsHandlingIgnoreUnknown: // --unknown=unknownval arg ... // we do not want to lose arg in this case if len(split) >= 2 { @@ -992,6 +1036,10 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin } return stripUnknownFlagValue(a), nil + case unknownFlagsHandling == UnknownFlagsHandlingPassUnknownToArgs: + return a, &unknownFlagError{ + UnknownFlags: s, + } default: err = f.fail(&NotExistError{name: name, messageType: flagUnknownFlagMessage}) return @@ -1037,12 +1085,14 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse flag, exists := f.shorthands[c] if !exists { + unknownFlagsHandling := f.ParseErrorsAllowlist.getUnknownFlagsHandling() + switch { case c == 'h': f.usage() err = ErrHelp return - case f.ParseErrorsAllowlist.UnknownFlags: + case unknownFlagsHandling == UnknownFlagsHandlingIgnoreUnknown: // '-f=arg arg ...' // we do not want to lose arg in this case if len(shorthands) > 2 && shorthands[1] == '=' { @@ -1052,6 +1102,20 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse outArgs = stripUnknownFlagValue(outArgs) return + case unknownFlagsHandling == UnknownFlagsHandlingPassUnknownToArgs: + // '-f=arg': pass all the argument + if len(shorthands) > 2 && shorthands[1] == '=' { + outShorts = "" + err = &unknownFlagError{ + UnknownFlags: shorthands, + } + return + } + // '-fgh': pass only the first switch + err = &unknownFlagError{ + UnknownFlags: shorthands[0:1], + } + return default: err = f.fail(&NotExistError{ name: string(c), @@ -1102,14 +1166,31 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse func (f *FlagSet) parseShortArg(s string, args []string, fn parseFunc) (a []string, err error) { a = args shorthands := s[1:] + var errUnknownFlagAll *unknownFlagError // "shorthands" can be a series of shorthand letters of flags (e.g. "-vvv"). for len(shorthands) > 0 { shorthands, a, err = f.parseSingleShortArg(shorthands, args, fn) if err != nil { - return + if errUnknownFlag, ok := err.(*unknownFlagError); ok { + // this means f.ParseErrorsAllowlist.UnknownFlagsHandling is set to UnknownFlagsHandlingPassUnknownToArgs + if errUnknownFlagAll == nil { + errUnknownFlagAll = &unknownFlagError{ + UnknownFlags: "-", + } + } + + errUnknownFlagAll.UnknownFlags = errUnknownFlagAll.UnknownFlags + + errUnknownFlag.UnknownFlags + err = nil + } else { + return + } } } + if errUnknownFlagAll != nil { + err = errUnknownFlagAll + } return } @@ -1139,7 +1220,13 @@ func (f *FlagSet) parseArgs(args []string, fn parseFunc) (err error) { args, err = f.parseShortArg(s, args, fn) } if err != nil { - return + if errUnknownFlag, ok := err.(*unknownFlagError); ok { + // this means f.ParseErrorsAllowlist.UnknownFlagsHandling is set to UnknownFlagsHandlingPassUnknownToArgs + f.args = append(f.args, errUnknownFlag.UnknownFlags) + err = nil + } else { + return + } } } return diff --git a/flag_test.go b/flag_test.go index 2df3ea20..ccad6cfa 100644 --- a/flag_test.go +++ b/flag_test.go @@ -523,6 +523,111 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) { } } +func testParseWithUnknownFlagsAndPassToArgs(f *FlagSet, t *testing.T) { + if f.Parsed() { + t.Fatal("f.Parse() = true before Parse") + } + f.ParseErrorsAllowlist.UnknownFlagsHandling = UnknownFlagsHandlingPassUnknownToArgs + f.SetInterspersed(true) + + f.BoolP("boola", "a", false, "bool value") + f.BoolP("boolb", "b", false, "bool2 value") + f.BoolP("boolc", "c", false, "bool3 value") + f.BoolP("boold", "d", false, "bool4 value") + f.BoolP("boole", "e", false, "bool4 value") + f.StringP("stringa", "s", "0", "string value") + f.StringP("stringz", "z", "0", "string value") + f.StringP("stringx", "x", "0", "string value") + f.StringP("stringy", "y", "0", "string value") + f.StringP("stringo", "o", "0", "string value") + f.Lookup("stringx").NoOptDefVal = "1" + args := []string{ + "-ab", + // -f and -g is unknown + "-fcgs=xx", + "--stringz=something", + "--unknown1", + "unknown1Value", + "-d=true", + "-x", + "--unknown2=unknown2Value", + "-u=unknown3Value", + "-p", + "unknown4Value", + "-q", //another unknown with bool value + "-y", + "ee", + "--unknown7=unknown7value", + "--stringo=ovalue", + "--unknown8=unknown8value", + "--boole", + "--unknown6", + "", + "-uuuuu", + "", + "--unknown10", + "--unknown11", + "arg0", + "arg1", + } + want := []string{ + "boola", "true", + "boolb", "true", + "boolc", "true", + "stringa", "xx", + "stringz", "something", + "boold", "true", + "stringx", "1", + "stringy", "ee", + "stringo", "ovalue", + "boole", "true", + } + wantArgs := []string{ + "-fg", + "--unknown1", + "unknown1Value", + "--unknown2=unknown2Value", + "-u=unknown3Value", + "-p", + "unknown4Value", + "-q", //another unknown with bool value + "--unknown7=unknown7value", + "--unknown8=unknown8value", + "--unknown6", + "", + "-uuuuu", + "", + "--unknown10", + "--unknown11", + "arg0", + "arg1", + } + got := []string{} + store := func(flag *Flag, value string) error { + got = append(got, flag.Name) + if len(value) > 0 { + got = append(got, value) + } + return nil + } + if err := f.ParseAll(args, store); err != nil { + t.Errorf("expected no error, got %s", err) + } + if !f.Parsed() { + t.Errorf("f.Parse() = false after Parse") + } + if !reflect.DeepEqual(got, want) { + t.Errorf("f.ParseAll() fail to restore the args") + t.Errorf("Got: %v", got) + t.Errorf("Want: %v", want) + } + if !reflect.DeepEqual(f.Args(), wantArgs) { + t.Errorf("f.ParseAll() fail to restore the args") + t.Errorf("Got: %v", f.Args()) + t.Errorf("Want: %v", wantArgs) + } +} + func TestShorthand(t *testing.T) { f := NewFlagSet("shorthand", ContinueOnError) if f.Parsed() { @@ -652,6 +757,10 @@ func TestIgnoreUnknownFlags(t *testing.T) { testParseWithUnknownFlags(GetCommandLine(), t) } +func TestIgnoreUnknownFlagsAndPassToArgs(t *testing.T) { + ResetForTesting(func() { t.Error("bad parse") }) + testParseWithUnknownFlagsAndPassToArgs(GetCommandLine(), t) +} func TestFlagSetParse(t *testing.T) { testParse(NewFlagSet("test", ContinueOnError), t) }