diff --git a/golangflag.go b/golangflag.go index f563907e..e62eab53 100644 --- a/golangflag.go +++ b/golangflag.go @@ -8,6 +8,7 @@ import ( goflag "flag" "reflect" "strings" + "time" ) // go test flags prefixes @@ -113,6 +114,38 @@ func (f *FlagSet) AddGoFlagSet(newSet *goflag.FlagSet) { f.addedGoFlagSets = append(f.addedGoFlagSets, newSet) } +// CopyToGoFlagSet will add all current flags to the given Go flag set. +// Deprecation remarks get copied into the usage description. +// Whenever possible, a flag gets added for which Go flags shows +// a proper type in the help message. +func (f *FlagSet) CopyToGoFlagSet(newSet *goflag.FlagSet) { + f.VisitAll(func(flag *Flag) { + usage := flag.Usage + if flag.Deprecated != "" { + usage += " (DEPRECATED: " + flag.Deprecated + ")" + } + + switch value := flag.Value.(type) { + case *stringValue: + newSet.StringVar((*string)(value), flag.Name, flag.DefValue, usage) + case *intValue: + newSet.IntVar((*int)(value), flag.Name, *(*int)(value), usage) + case *int64Value: + newSet.Int64Var((*int64)(value), flag.Name, *(*int64)(value), usage) + case *uintValue: + newSet.UintVar((*uint)(value), flag.Name, *(*uint)(value), usage) + case *uint64Value: + newSet.Uint64Var((*uint64)(value), flag.Name, *(*uint64)(value), usage) + case *durationValue: + newSet.DurationVar((*time.Duration)(value), flag.Name, *(*time.Duration)(value), usage) + case *float64Value: + newSet.Float64Var((*float64)(value), flag.Name, *(*float64)(value), usage) + default: + newSet.Var(flag.Value, flag.Name, usage) + } + }) +} + // ParseSkippedFlags explicitly Parses go test flags (i.e. the one starting with '-test.') with goflag.Parse(), // since by default those are skipped by pflag.Parse(). // Typical usage example: `ParseGoTestFlags(os.Args[1:], goflag.CommandLine)` @@ -125,3 +158,4 @@ func ParseSkippedFlags(osArgs []string, goFlagSet *goflag.FlagSet) error { } return goFlagSet.Parse(skippedFlags) } + diff --git a/golangflag_test.go b/golangflag_test.go index 2ecefefa..7309808d 100644 --- a/golangflag_test.go +++ b/golangflag_test.go @@ -7,6 +7,7 @@ package pflag import ( goflag "flag" "testing" + "time" ) func TestGoflags(t *testing.T) { @@ -59,3 +60,76 @@ func TestGoflags(t *testing.T) { t.Fatal("goflag.CommandLine.Parsed() return false after f.Parse() called") } } + +func TestToGoflags(t *testing.T) { + pfs := FlagSet{} + gfs := goflag.FlagSet{} + pfs.String("StringFlag", "String value", "String flag usage") + pfs.Int("IntFlag", 1, "Int flag usage") + pfs.Uint("UintFlag", 2, "Uint flag usage") + pfs.Int64("Int64Flag", 3, "Int64 flag usage") + pfs.Uint64("Uint64Flag", 4, "Uint64 flag usage") + pfs.Int8("Int8Flag", 5, "Int8 flag usage") + pfs.Float64("Float64Flag", 6.0, "Float64 flag usage") + pfs.Duration("DurationFlag", time.Second, "Duration flag usage") + pfs.Bool("BoolFlag", true, "Bool flag usage") + pfs.String("deprecated", "Deprecated value", "Deprecated flag usage") + pfs.MarkDeprecated("deprecated", "obsolete") + + pfs.CopyToGoFlagSet(&gfs) + + // Modify via pfs. Should be visible via gfs because both share the + // same values. + for name, value := range map[string]string{ + "StringFlag": "Modified String value", + "IntFlag": "11", + "UintFlag": "12", + "Int64Flag": "13", + "Uint64Flag": "14", + "Int8Flag": "15", + "Float64Flag": "16.0", + "BoolFlag": "false", + } { + pf := pfs.Lookup(name) + if pf == nil { + t.Errorf("%s: not found in pflag flag set", name) + continue + } + if err := pf.Value.Set(value); err != nil { + t.Errorf("error setting %s = %s: %v", name, value, err) + } + } + + // Check that all flags were added and share the same value. + pfs.VisitAll(func(pf *Flag) { + gf := gfs.Lookup(pf.Name) + if gf == nil { + t.Errorf("%s: not found in Go flag set", pf.Name) + return + } + if gf.Value.String() != pf.Value.String() { + t.Errorf("%s: expected value %v from Go flag set, got %v", + pf.Name, pf.Value, gf.Value) + return + } + }) + + // Check for unexpected additional flags. + gfs.VisitAll(func(gf *goflag.Flag) { + pf := gfs.Lookup(gf.Name) + if pf == nil { + t.Errorf("%s: not found in pflag flag set", gf.Name) + return + } + }) + + deprecated := gfs.Lookup("deprecated") + if deprecated == nil { + t.Error("deprecated: not found in Go flag set") + } else { + expectedUsage := "Deprecated flag usage (DEPRECATED: obsolete)" + if deprecated.Usage != expectedUsage { + t.Errorf("deprecation remark not added, expected usage %q, got %q", expectedUsage, deprecated.Usage) + } + } +}