From 6e070abaf055976f7501c0bec570700bb6f1de2a Mon Sep 17 00:00:00 2001 From: cornfeedhobo Date: Wed, 30 Sep 2020 10:17:15 -0500 Subject: [PATCH] revise exposure of unknowns - add VisitUnknowns and remove previous implementation - add Get methods to match Visit methods - change ResetForTesting to use standard constructor --- export_test.go | 7 +- flag.go | 236 +++++++++++++++++++++++-------------------------- flag_test.go | 39 ++++---- 3 files changed, 133 insertions(+), 149 deletions(-) diff --git a/export_test.go b/export_test.go index 404b0995..3869c9f8 100644 --- a/export_test.go +++ b/export_test.go @@ -14,11 +14,8 @@ import ( // After calling ResetForTesting, parse errors in flag handling will not // exit the program. func ResetForTesting(usage func()) { - CommandLine = &FlagSet{ - name: os.Args[0], - errorHandling: ContinueOnError, - output: ioutil.Discard, - } + CommandLine = NewFlagSet(os.Args[0], ContinueOnError) + CommandLine.output = ioutil.Discard Usage = usage } diff --git a/flag.go b/flag.go index b6cab784..d783c7d4 100644 --- a/flag.go +++ b/flag.go @@ -72,9 +72,15 @@ type FlagSet struct { output io.Writer // nil means stderr; use Output() accessor interspersed bool // allow interspersed option/non-option args normalizeNameFunc func(f *FlagSet, name string) NormalizedName - unknownFlags *[]string addedGoFlagSets []*goflag.FlagSet + unknownFlags []*UnknownFlag +} + +// A UnknownFlag represents the state of a flag that is not expected. +type UnknownFlag struct { + Name string // name, as it appears on command line + Value string // argument, if provided } // A Flag represents the state of a flag. @@ -193,20 +199,18 @@ func (f *FlagSet) VisitAll(fn func(*Flag)) { if len(f.formal) == 0 { return } - - var flags []*Flag - if f.SortFlags { - if len(f.formal) != len(f.sortedFormal) { - f.sortedFormal = sortFlags(f.formal) - } - flags = f.sortedFormal - } else { - flags = f.orderedFormal + for _, flag := range f.GetAllFlags() { + fn(flag) } +} - for _, flag := range flags { - fn(flag) +// GetAllFlags return the flags in lexicographical order or +// in primordial order if f.SortFlags is false. +func (f *FlagSet) GetAllFlags() []*Flag { + if f.SortFlags && len(f.formal) != len(f.sortedFormal) { + f.sortedFormal = sortFlags(f.formal) } + return f.sortedFormal } // HasFlags returns a bool to indicate if the FlagSet has any flags defined. @@ -232,6 +236,12 @@ func VisitAll(fn func(*Flag)) { CommandLine.VisitAll(fn) } +// GetAllFlags return the flags in lexicographical order or +// in primordial order if f.SortFlags is false. +func GetAllFlags() []*Flag { + return CommandLine.GetAllFlags() +} + // Visit visits the flags in lexicographical order or // in primordial order if f.SortFlags is false, calling fn for each. // It visits only those flags that have been set. @@ -239,20 +249,18 @@ func (f *FlagSet) Visit(fn func(*Flag)) { if len(f.actual) == 0 { return } - - var flags []*Flag - if f.SortFlags { - if len(f.actual) != len(f.sortedActual) { - f.sortedActual = sortFlags(f.actual) - } - flags = f.sortedActual - } else { - flags = f.orderedActual + for _, flag := range f.GetFlags() { + fn(flag) } +} - for _, flag := range flags { - fn(flag) +// GetFlags return the flags in lexicographical order or +// in primordial order if f.SortFlags is false. +func (f *FlagSet) GetFlags() []*Flag { + if f.SortFlags && len(f.actual) != len(f.sortedActual) { + f.sortedActual = sortFlags(f.actual) } + return f.sortedActual } // Visit visits the command-line flags in lexicographical order or @@ -262,6 +270,45 @@ func Visit(fn func(*Flag)) { CommandLine.Visit(fn) } +// GetFlags return the flags in lexicographical order or +// in primordial order if f.SortFlags is false. +func GetFlags() []*Flag { + return CommandLine.GetFlags() +} + +// VisitUnknowns visits all the flags that have not been registered. +func (f *FlagSet) VisitUnknowns(fn func(*UnknownFlag)) { + if len(f.unknownFlags) == 0 { + return + } + for _, flag := range f.unknownFlags { + fn(flag) + } +} + +// GetUnknownFlags returns unknown flags found during Parse. +// This requires ParseErrorsWhitelist.UnknownFlags to be set so that +// parsing does not abort on the first unknown flag. +func (f *FlagSet) GetUnknownFlags() []*UnknownFlag { + return f.unknownFlags +} + +func (f *FlagSet) addUnknownFlag(name, value string) { + f.unknownFlags = append(f.unknownFlags, &UnknownFlag{name, value}) +} + +// VisitUnknowns visits all the flags that have not been registered. +func VisitUnknowns(fn func(*UnknownFlag)) { + CommandLine.VisitUnknowns(fn) +} + +// GetUnknownFlags returns unknown flags found during Parse. +// This requires ParseErrorsWhitelist.UnknownFlags to be set so that +// parsing does not abort on the first unknown flag. +func GetUnknownFlags() []*UnknownFlag { + return CommandLine.GetUnknownFlags() +} + // Lookup returns the Flag structure of the named flag, returning nil if none exists. func (f *FlagSet) Lookup(name string) *Flag { return f.lookup(f.normalizeFlagName(name)) @@ -883,36 +930,6 @@ func (f *FlagSet) usage() { } } -func (f *FlagSet) addUnknownFlag(s string) { - if f.unknownFlags == nil { - f.unknownFlags = new([]string) - } - *f.unknownFlags = append(*f.unknownFlags, s) -} - -//--unknown (args will be empty) -//--unknown --next-flag ... (args will be --next-flag ...) -//--unknown arg ... (args will be arg ...) -func (f *FlagSet) stripUnknownFlagValue(args []string) []string { - if len(args) == 0 { - //--unknown - return args - } - - first := args[0] - if len(first) > 0 && first[0] == '-' { - //--unknown --next-flag ... - return args - } - - //--unknown arg ... (args will be arg ...) - if len(args) > 1 { - f.addUnknownFlag(args[0]) - return args[1:] - } - return nil -} - func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []string, err error) { a = args name := s[2:] @@ -926,21 +943,12 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin flag, exists := f.formal[f.normalizeFlagName(name)] if !exists || flag.ShorthandOnly { - switch { - case !f.DisableBuiltinHelp && name == "help": + if !f.DisableBuiltinHelp && name == "help" { f.usage() err = ErrHelp return - case f.ParseErrorsWhitelist.UnknownFlags: - f.addUnknownFlag(s) - // --unknown=unknownval arg ... - // we do not want to lose arg in this case - if len(split) >= 2 { - return a, nil - } - - return f.stripUnknownFlagValue(a), nil - default: + } + if !f.ParseErrorsWhitelist.UnknownFlags { err = f.failf("unknown flag: --%s", name) return } @@ -950,16 +958,23 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin if len(split) == 2 { // '--flag=arg' value = split[1] - } else if flag.NoOptDefVal != "" { + } else if exists && flag.NoOptDefVal != "" { // '--flag' (arg was optional) value = flag.NoOptDefVal } else if len(a) > 0 { // '--flag arg' - value = a[0] - a = a[1:] - } else { - // '--flag' (arg was required) - err = f.failf("flag needs an argument: %s", s) + if !exists && strings.HasPrefix(a[0], "-") { + value = "" + } else { + value = a[0] + a = a[1:] + } + } else if f.ParseErrorsWhitelist.UnknownFlags { + value = "" + } + + if !exists && f.ParseErrorsWhitelist.UnknownFlags { + f.addUnknownFlag(name, value) return } @@ -978,26 +993,12 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse flag, exists := f.shorthands[c] if !exists { - switch { - case !f.DisableBuiltinHelp && c == 'h': + if !f.DisableBuiltinHelp && c == 'h' { f.usage() err = ErrHelp return - case f.ParseErrorsWhitelist.UnknownFlags: - // '-f=arg arg ...' - // we do not want to lose arg in this case - if len(shorthands) > 2 && shorthands[1] == '=' { - f.addUnknownFlag("-" + shorthands) - outShorts = "" - return - } - - f.addUnknownFlag("-" + string(c)) - if len(outShorts) == 0 { - outArgs = f.stripUnknownFlagValue(outArgs) - } - return - default: + } + if !f.ParseErrorsWhitelist.UnknownFlags { err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands) return } @@ -1008,18 +1009,33 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse // '-f=arg' value = shorthands[2:] outShorts = "" - } else if flag.NoOptDefVal != "" { + } else if exists && flag.NoOptDefVal != "" { // '-f' (arg was optional) value = flag.NoOptDefVal } else if len(shorthands) > 1 { // '-farg' - value = shorthands[1:] - outShorts = "" + if next := f.ShorthandLookup(string(shorthands[1])); next == nil { + // preserve arg if it's a known flag + value = shorthands[1:] + outShorts = "" + } } else if len(args) > 0 { - // '-f arg' - value = args[0] - outArgs = args[1:] - } else { + if !exists && strings.HasPrefix(args[0], "-") { + value = "" + } else { + value = args[0] + outArgs = args[1:] + } + } else if f.ParseErrorsWhitelist.UnknownFlags { + value = "" + } + + if !exists && f.ParseErrorsWhitelist.UnknownFlags { + f.addUnknownFlag(string(c), value) + return + } + + if flag.NoOptDefVal == "" && value == "" { // '-f' (arg was required) err = f.failf("flag needs an argument: %q in -%s", c, shorthands) return @@ -1150,21 +1166,6 @@ func (f *FlagSet) Parsed() bool { return f.parsed } -// SetUnknownFlags sets the store for unknown flags found during Parse. -// The argument s points to a slice variable in which to store the values. -// This requires ParseErrorsWhitelist.UnknownFlags to be set so that -// parsing does not abort on the first unknown flag. -func (f *FlagSet) SetUnknownFlags(s *[]string) { - f.unknownFlags = s -} - -// GetUnknownFlags returns unknown flags found during Parse. -// This requires ParseErrorsWhitelist.UnknownFlags to be set so that -// parsing does not abort on the first unknown flag. -func (f *FlagSet) GetUnknownFlags() *[]string { - return f.unknownFlags -} - // Parse parses the command-line flags from os.Args[1:]. Must be called // after all flags are defined and before flags are accessed by the program. func Parse() { @@ -1190,21 +1191,6 @@ func Parsed() bool { return CommandLine.Parsed() } -// SetUnknownFlags sets the store for unknown flags found during Parse. -// The argument s points to a slice variable in which to store the values. -// This requires ParseErrorsWhitelist.UnknownFlags to be set so that -// parsing does not abort on the first unknown flag. -func SetUnknownFlags(s *[]string) { - CommandLine.SetUnknownFlags(s) -} - -// GetUnknownFlags returns unknown flags found during Parse. -// This requires ParseErrorsWhitelist.UnknownFlags to be set so that -// parsing does not abort on the first unknown flag. -func GetUnknownFlags() *[]string { - return CommandLine.GetUnknownFlags() -} - // CommandLine is the default set of command-line flags, parsed from os.Args. var CommandLine = NewFlagSet(os.Args[0], ExitOnError) diff --git a/flag_test.go b/flag_test.go index 1d5b5679..fc69b443 100644 --- a/flag_test.go +++ b/flag_test.go @@ -409,14 +409,13 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) { t.Error("f.Parse() = true before Parse") } f.ParseErrorsWhitelist.UnknownFlags = true - var unknownFlags []string - f.SetUnknownFlags(&unknownFlags) f.BoolP("boola", "a", false, "bool value") f.BoolP("boolb", "b", false, "bool2 value") f.BoolP("boolc", "c", false, "bool3 value") f.BoolP("boold", "d", false, "bool4 value") - f.BoolP("boole", "e", false, "bool4 value") + f.BoolP("boole", "e", false, "bool5 value") + f.BoolP("boolf", "f", false, "bool6 value") f.StringP("stringa", "s", "0", "string value") f.StringP("stringz", "z", "0", "string value") f.StringP("stringx", "x", "0", "string value") @@ -444,7 +443,7 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) { "--boole", "--unknown6", "", - "-uuuuu", + "-ufuuuu", "", "--unknown10", "--unknown11", @@ -460,19 +459,21 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) { "stringy", "ee", "stringo", "ovalue", "boole", "true", - } - wantUnknowns := []string{ - "--unknown1", "unknown1Value", - "--unknown2=unknown2Value", - "-u=unknown3Value", - "-p", "unknown4Value", - "-q", - "--unknown7=unknown7value", - "--unknown8=unknown8value", - "--unknown6", "", - "-u", "-u", "-u", "-u", "-u", "", - "--unknown10", - "--unknown11", + "boolf", "true", + } + wantUnknowns := []*UnknownFlag{ + {"unknown1", "unknown1Value"}, + {"unknown2", "unknown2Value"}, + {"u", "unknown3Value"}, + {"p", "unknown4Value"}, + {"q", ""}, + {"unknown7", "unknown7value"}, + {"unknown8", "unknown8value"}, + {"unknown6", ""}, + {"u", ""}, + {"u", "uuu"}, + {"unknown10", ""}, + {"unknown11", ""}, } got := []string{} store := func(flag *Flag, value string) error { @@ -493,9 +494,9 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) { t.Errorf("Got: %v", got) t.Errorf("Want: %v", want) } - if !reflect.DeepEqual(unknownFlags, wantUnknowns) { + if !reflect.DeepEqual(f.GetUnknownFlags(), wantUnknowns) { t.Errorf("f.Parse() failed to enumerate the unknown flags") - t.Errorf("Got: %v", unknownFlags) + t.Errorf("Got: %v", f.GetUnknownFlags()) t.Errorf("Want: %v", wantUnknowns) } }