From 44aa4aa86e19d69d62479bfa1a244d4858a62ee5 Mon Sep 17 00:00:00 2001 From: Patrick Ohly Date: Wed, 15 Sep 2021 12:49:31 +0200 Subject: [PATCH] add CopyToGoFlagSet This is useful for programs which want to define some flags with pflag (for example, in external packages) but still need to use Go flag command line parsing to preserve backward compatibility with previous releases, in particular support for single-dash flags. Without this in pflag, such tools have to resort to copying via the public API, which leads to less useful help messages (type of basic values will be unknown). --- golangflag.go | 34 +++++++++++++++++++++ golangflag_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+) diff --git a/golangflag.go b/golangflag.go index f563907e..e62eab53 100644 --- a/golangflag.go +++ b/golangflag.go @@ -8,6 +8,7 @@ import ( goflag "flag" "reflect" "strings" + "time" ) // go test flags prefixes @@ -113,6 +114,38 @@ func (f *FlagSet) AddGoFlagSet(newSet *goflag.FlagSet) { f.addedGoFlagSets = append(f.addedGoFlagSets, newSet) } +// CopyToGoFlagSet will add all current flags to the given Go flag set. +// Deprecation remarks get copied into the usage description. +// Whenever possible, a flag gets added for which Go flags shows +// a proper type in the help message. +func (f *FlagSet) CopyToGoFlagSet(newSet *goflag.FlagSet) { + f.VisitAll(func(flag *Flag) { + usage := flag.Usage + if flag.Deprecated != "" { + usage += " (DEPRECATED: " + flag.Deprecated + ")" + } + + switch value := flag.Value.(type) { + case *stringValue: + newSet.StringVar((*string)(value), flag.Name, flag.DefValue, usage) + case *intValue: + newSet.IntVar((*int)(value), flag.Name, *(*int)(value), usage) + case *int64Value: + newSet.Int64Var((*int64)(value), flag.Name, *(*int64)(value), usage) + case *uintValue: + newSet.UintVar((*uint)(value), flag.Name, *(*uint)(value), usage) + case *uint64Value: + newSet.Uint64Var((*uint64)(value), flag.Name, *(*uint64)(value), usage) + case *durationValue: + newSet.DurationVar((*time.Duration)(value), flag.Name, *(*time.Duration)(value), usage) + case *float64Value: + newSet.Float64Var((*float64)(value), flag.Name, *(*float64)(value), usage) + default: + newSet.Var(flag.Value, flag.Name, usage) + } + }) +} + // ParseSkippedFlags explicitly Parses go test flags (i.e. the one starting with '-test.') with goflag.Parse(), // since by default those are skipped by pflag.Parse(). // Typical usage example: `ParseGoTestFlags(os.Args[1:], goflag.CommandLine)` @@ -125,3 +158,4 @@ func ParseSkippedFlags(osArgs []string, goFlagSet *goflag.FlagSet) error { } return goFlagSet.Parse(skippedFlags) } + diff --git a/golangflag_test.go b/golangflag_test.go index 2ecefefa..7309808d 100644 --- a/golangflag_test.go +++ b/golangflag_test.go @@ -7,6 +7,7 @@ package pflag import ( goflag "flag" "testing" + "time" ) func TestGoflags(t *testing.T) { @@ -59,3 +60,76 @@ func TestGoflags(t *testing.T) { t.Fatal("goflag.CommandLine.Parsed() return false after f.Parse() called") } } + +func TestToGoflags(t *testing.T) { + pfs := FlagSet{} + gfs := goflag.FlagSet{} + pfs.String("StringFlag", "String value", "String flag usage") + pfs.Int("IntFlag", 1, "Int flag usage") + pfs.Uint("UintFlag", 2, "Uint flag usage") + pfs.Int64("Int64Flag", 3, "Int64 flag usage") + pfs.Uint64("Uint64Flag", 4, "Uint64 flag usage") + pfs.Int8("Int8Flag", 5, "Int8 flag usage") + pfs.Float64("Float64Flag", 6.0, "Float64 flag usage") + pfs.Duration("DurationFlag", time.Second, "Duration flag usage") + pfs.Bool("BoolFlag", true, "Bool flag usage") + pfs.String("deprecated", "Deprecated value", "Deprecated flag usage") + pfs.MarkDeprecated("deprecated", "obsolete") + + pfs.CopyToGoFlagSet(&gfs) + + // Modify via pfs. Should be visible via gfs because both share the + // same values. + for name, value := range map[string]string{ + "StringFlag": "Modified String value", + "IntFlag": "11", + "UintFlag": "12", + "Int64Flag": "13", + "Uint64Flag": "14", + "Int8Flag": "15", + "Float64Flag": "16.0", + "BoolFlag": "false", + } { + pf := pfs.Lookup(name) + if pf == nil { + t.Errorf("%s: not found in pflag flag set", name) + continue + } + if err := pf.Value.Set(value); err != nil { + t.Errorf("error setting %s = %s: %v", name, value, err) + } + } + + // Check that all flags were added and share the same value. + pfs.VisitAll(func(pf *Flag) { + gf := gfs.Lookup(pf.Name) + if gf == nil { + t.Errorf("%s: not found in Go flag set", pf.Name) + return + } + if gf.Value.String() != pf.Value.String() { + t.Errorf("%s: expected value %v from Go flag set, got %v", + pf.Name, pf.Value, gf.Value) + return + } + }) + + // Check for unexpected additional flags. + gfs.VisitAll(func(gf *goflag.Flag) { + pf := gfs.Lookup(gf.Name) + if pf == nil { + t.Errorf("%s: not found in pflag flag set", gf.Name) + return + } + }) + + deprecated := gfs.Lookup("deprecated") + if deprecated == nil { + t.Error("deprecated: not found in Go flag set") + } else { + expectedUsage := "Deprecated flag usage (DEPRECATED: obsolete)" + if deprecated.Usage != expectedUsage { + t.Errorf("deprecation remark not added, expected usage %q, got %q", expectedUsage, deprecated.Usage) + } + } +}