diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index bfdbf99..986cb57 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -7,27 +7,36 @@ jobs: name: Test and Build strategy: matrix: - go-version: [1.16.x] + go-version: [1.24.x] os: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.os }} steps: - name: Install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} + - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 + + - name: Run make + run: make update && make correct + - name: Test - run: go test -v ./... -race + run: go test -v . -race + coverage: name: Code Coverage runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 + + - name: Run make + run: make update && make correct + - name: Generage coverage - run: go test -v ./... -coverprofile=coverage.txt -covermode=atomic + run: go test -v . -coverprofile=coverage.txt -covermode=atomic + - name: Publish coverage - uses: codecov/codecov-action@v1 - with: - file: ./coverage.txt + uses: codecov/codecov-action@v5 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/Makefile b/Makefile index cd90d17..db09e59 100644 --- a/Makefile +++ b/Makefile @@ -6,8 +6,17 @@ include go.mk all: update update: - curl -LO https://raw.githubusercontent.com/golang/go/master/src/flag/example_test.go - curl -LO https://raw.githubusercontent.com/golang/go/master/src/flag/example_value_test.go - curl -LO https://raw.githubusercontent.com/golang/go/master/src/flag/export_test.go - curl -LO https://raw.githubusercontent.com/golang/go/master/src/flag/flag.go - curl -LO https://raw.githubusercontent.com/golang/go/master/src/flag/flag_test.go + git clone --depth=1 --no-checkout https://github.com/golang/go third_party/go + git submodule add https://github.com/golang/go/ third_party/go + git submodule absorbgitdirs + git -C third_party/go config core.sparseCheckout true + cp test/sparse-checkout .git/modules/third_party/go/info/sparse-checkout + git submodule update --force --checkout third_party/go + +correct: + @cp -R third_party/go/src/internal/* internal +ifeq ($(OS), Darwin) + find internal/* -type f -name "*.go" | xargs -I {} sed -i '' -e "s/\"internal\//\"github.com\/jnovack\/flag\/internal\//g" {} +else + find internal/* -type f -name "*.go" | xargs -I {} sed -i'' -e "s/\"internal\//\"github.com\/jnovack\/flag\/internal\//g" {} +endif diff --git a/README.md b/README.md index d49996d..1c83646 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,7 @@ flag.String(flag.DefaultConfigFlagname, "", "path to config file") Run the command: ```go -$ go run ./gopher.go -config ./gopher.conf +go run ./gopher.go -config ./gopher.conf ``` The default flag name for the configuration file is "config" and can be diff --git a/example_test.go b/example_test.go index 04a0d20..088447d 100644 --- a/example_test.go +++ b/example_test.go @@ -78,6 +78,8 @@ func Example() { // to enable the flag package to see the flags defined there, one must // execute, typically at the start of main (not init!): // flag.Parse() - // We don't run it here because this is not a main function and - // the testing suite has already parsed the flags. + // We don't call it here because this code is a function called "Example" + // that is part of the testing suite for the package, which has already + // parsed the flags. When viewed at pkg.go.dev, however, the function is + // renamed to "main" and it could be run as a standalone example. } diff --git a/export_test.go b/export_test.go index 838cfaf..9ef93ed 100644 --- a/export_test.go +++ b/export_test.go @@ -4,7 +4,10 @@ package flag -import "os" +import ( + "io" + "os" +) // Additional routines compiled into the package only during testing. @@ -15,6 +18,7 @@ var DefaultUsage = Usage // exit the program. func ResetForTesting(usage func()) { CommandLine = NewFlagSet(os.Args[0], ContinueOnError) + CommandLine.SetOutput(io.Discard) CommandLine.Usage = commandLineUsage Usage = usage } diff --git a/extras.go b/extras.go index aedb075..a1f1389 100644 --- a/extras.go +++ b/extras.go @@ -139,7 +139,7 @@ func (f *FlagSet) ParseFile(path string) error { } } - if hasValue == false { + if !hasValue { name = line } diff --git a/flag.go b/flag.go index 539ff63..e3998b1 100644 --- a/flag.go +++ b/flag.go @@ -3,77 +3,94 @@ // license that can be found in the LICENSE file. /* - Package flag implements command-line flag parsing. +Package flag implements command-line flag parsing. - Usage +# Usage - Define flags using flag.String(), Bool(), Int(), etc. +Define flags using [flag.String], [Bool], [Int], etc. - This declares an integer flag, -n, stored in the pointer nFlag, with type *int: - import "flag" - var nFlag = flag.Int("n", 1234, "help message for flag n") - If you like, you can bind the flag to a variable using the Var() functions. - var flagvar int - func init() { - flag.IntVar(&flagvar, "flagname", 1234, "help message for flagname") - } - Or you can create custom flags that satisfy the Value interface (with - pointer receivers) and couple them to flag parsing by - flag.Var(&flagVal, "name", "help message for flagname") - For such flags, the default value is just the initial value of the variable. - - After all flags are defined, call - flag.Parse() - to parse the command line into the defined flags. - - Flags may then be used directly. If you're using the flags themselves, - they are all pointers; if you bind to variables, they're values. - fmt.Println("ip has value ", *ip) - fmt.Println("flagvar has value ", flagvar) - - After parsing, the arguments following the flags are available as the - slice flag.Args() or individually as flag.Arg(i). - The arguments are indexed from 0 through flag.NArg()-1. - - Command line flag syntax - - The following forms are permitted: - - -flag - -flag=x - -flag x // non-boolean flags only - One or two minus signs may be used; they are equivalent. - The last form is not permitted for boolean flags because the - meaning of the command - cmd -x * - where * is a Unix shell wildcard, will change if there is a file - called 0, false, etc. You must use the -flag=false form to turn - off a boolean flag. - - Flag parsing stops just before the first non-flag argument - ("-" is a non-flag argument) or after the terminator "--". - - Integer flags accept 1234, 0664, 0x1234 and may be negative. - Boolean flags may be: - 1, 0, t, f, T, F, true, false, TRUE, FALSE, True, False - Duration flags accept any input valid for time.ParseDuration. - - The default set of command-line flags is controlled by - top-level functions. The FlagSet type allows one to define - independent sets of flags, such as to implement subcommands - in a command-line interface. The methods of FlagSet are - analogous to the top-level functions for the command-line - flag set. +This declares an integer flag, -n, stored in the pointer nFlag, with type *int: + + import "flag" + var nFlag = flag.Int("n", 1234, "help message for flag n") + +If you like, you can bind the flag to a variable using the Var() functions. + + var flagvar int + func init() { + flag.IntVar(&flagvar, "flagname", 1234, "help message for flagname") + } + +Or you can create custom flags that satisfy the Value interface (with +pointer receivers) and couple them to flag parsing by + + flag.Var(&flagVal, "name", "help message for flagname") + +For such flags, the default value is just the initial value of the variable. + +After all flags are defined, call + + flag.Parse() + +to parse the command line into the defined flags. + +Flags may then be used directly. If you're using the flags themselves, +they are all pointers; if you bind to variables, they're values. + + fmt.Println("ip has value ", *ip) + fmt.Println("flagvar has value ", flagvar) + +After parsing, the arguments following the flags are available as the +slice [flag.Args] or individually as [flag.Arg](i). +The arguments are indexed from 0 through [flag.NArg]-1. + +# Command line flag syntax + +The following forms are permitted: + + -flag + --flag // double dashes are also permitted + -flag=x + -flag x // non-boolean flags only + +One or two dashes may be used; they are equivalent. +The last form is not permitted for boolean flags because the +meaning of the command + + cmd -x * + +where * is a Unix shell wildcard, will change if there is a file +called 0, false, etc. You must use the -flag=false form to turn +off a boolean flag. + +Flag parsing stops just before the first non-flag argument +("-" is a non-flag argument) or after the terminator "--". + +Integer flags accept 1234, 0664, 0x1234 and may be negative. +Boolean flags may be: + + 1, 0, t, f, T, F, true, false, TRUE, FALSE, True, False + +Duration flags accept any input valid for time.ParseDuration. + +The default set of command-line flags is controlled by +top-level functions. The [FlagSet] type allows one to define +independent sets of flags, such as to implement subcommands +in a command-line interface. The methods of [FlagSet] are +analogous to the top-level functions for the command-line +flag set. */ package flag import ( + "encoding" "errors" "fmt" "io" "os" "reflect" - "sort" + "runtime" + "slices" "strconv" "strings" "time" @@ -122,7 +139,7 @@ func (b *boolValue) Set(s string) error { return err } -func (b *boolValue) Get() interface{} { return bool(*b) } +func (b *boolValue) Get() any { return bool(*b) } func (b *boolValue) String() string { return strconv.FormatBool(bool(*b)) } @@ -152,7 +169,7 @@ func (i *intValue) Set(s string) error { return err } -func (i *intValue) Get() interface{} { return int(*i) } +func (i *intValue) Get() any { return int(*i) } func (i *intValue) String() string { return strconv.Itoa(int(*i)) } @@ -173,7 +190,7 @@ func (i *int64Value) Set(s string) error { return err } -func (i *int64Value) Get() interface{} { return int64(*i) } +func (i *int64Value) Get() any { return int64(*i) } func (i *int64Value) String() string { return strconv.FormatInt(int64(*i), 10) } @@ -194,7 +211,7 @@ func (i *uintValue) Set(s string) error { return err } -func (i *uintValue) Get() interface{} { return uint(*i) } +func (i *uintValue) Get() any { return uint(*i) } func (i *uintValue) String() string { return strconv.FormatUint(uint64(*i), 10) } @@ -215,7 +232,7 @@ func (i *uint64Value) Set(s string) error { return err } -func (i *uint64Value) Get() interface{} { return uint64(*i) } +func (i *uint64Value) Get() any { return uint64(*i) } func (i *uint64Value) String() string { return strconv.FormatUint(uint64(*i), 10) } @@ -232,7 +249,7 @@ func (s *stringValue) Set(val string) error { return nil } -func (s *stringValue) Get() interface{} { return string(*s) } +func (s *stringValue) Get() any { return string(*s) } func (s *stringValue) String() string { return string(*s) } @@ -253,7 +270,7 @@ func (f *float64Value) Set(s string) error { return err } -func (f *float64Value) Get() interface{} { return float64(*f) } +func (f *float64Value) Get() any { return float64(*f) } func (f *float64Value) String() string { return strconv.FormatFloat(float64(*f), 'g', -1, 64) } @@ -274,16 +291,62 @@ func (d *durationValue) Set(s string) error { return err } -func (d *durationValue) Get() interface{} { return time.Duration(*d) } +func (d *durationValue) Get() any { return time.Duration(*d) } func (d *durationValue) String() string { return (*time.Duration)(d).String() } +// -- encoding.TextUnmarshaler Value +type textValue struct{ p encoding.TextUnmarshaler } + +func newTextValue(val encoding.TextMarshaler, p encoding.TextUnmarshaler) textValue { + ptrVal := reflect.ValueOf(p) + if ptrVal.Kind() != reflect.Ptr { + panic("variable value type must be a pointer") + } + defVal := reflect.ValueOf(val) + if defVal.Kind() == reflect.Ptr { + defVal = defVal.Elem() + } + if defVal.Type() != ptrVal.Type().Elem() { + panic(fmt.Sprintf("default type does not match variable type: %v != %v", defVal.Type(), ptrVal.Type().Elem())) + } + ptrVal.Elem().Set(defVal) + return textValue{p} +} + +func (v textValue) Set(s string) error { + return v.p.UnmarshalText([]byte(s)) +} + +func (v textValue) Get() any { + return v.p +} + +func (v textValue) String() string { + if m, ok := v.p.(encoding.TextMarshaler); ok { + if b, err := m.MarshalText(); err == nil { + return string(b) + } + } + return "" +} + +// -- func Value type funcValue func(string) error func (f funcValue) Set(s string) error { return f(s) } func (f funcValue) String() string { return "" } +// -- boolFunc Value +type boolFuncValue func(string) error + +func (f boolFuncValue) Set(s string) error { return f(s) } + +func (f boolFuncValue) String() string { return "" } + +func (f boolFuncValue) IsBoolFlag() bool { return true } + // Value is the interface to the dynamic value stored in a flag. // (The default value is represented as a string.) // @@ -292,26 +355,26 @@ func (f funcValue) String() string { return "" } // rather than using the next command-line argument. // // Set is called once, in command line order, for each flag present. -// The flag package may call the String method with a zero-valued receiver, +// The flag package may call the [String] method with a zero-valued receiver, // such as a nil pointer. type Value interface { String() string Set(string) error } -// Getter is an interface that allows the contents of a Value to be retrieved. -// It wraps the Value interface, rather than being part of it, because it -// appeared after Go 1 and its compatibility rules. All Value types provided -// by this package satisfy the Getter interface, except the type used by Func. +// Getter is an interface that allows the contents of a [Value] to be retrieved. +// It wraps the [Value] interface, rather than being part of it, because it +// appeared after Go 1 and its compatibility rules. All [Value] types provided +// by this package satisfy the [Getter] interface, except the type used by [Func]. type Getter interface { Value - Get() interface{} + Get() any } -// ErrorHandling defines how FlagSet.Parse behaves if the parse fails. +// ErrorHandling defines how [FlagSet.Parse] behaves if the parse fails. type ErrorHandling int -// These constants cause FlagSet.Parse to behave as described if the parse fails. +// These constants cause [FlagSet.Parse] to behave as described if the parse fails. const ( ContinueOnError ErrorHandling = iota // Return a descriptive error. ExitOnError // Call os.Exit(2) or for -h/-help Exit(0). @@ -319,9 +382,9 @@ const ( ) // A FlagSet represents a set of defined flags. The zero value of a FlagSet -// has no name and has ContinueOnError error handling. +// has no name and has [ContinueOnError] error handling. // -// Flag names must be unique within a FlagSet. An attempt to define a flag whose +// [Flag] names must be unique within a FlagSet. An attempt to define a flag whose // name is already in use will cause a panic. type FlagSet struct { // Usage is the function called when an error occurs while parsing flags. @@ -338,7 +401,8 @@ type FlagSet struct { envPrefix string // prefix to all env variable names /* jnovack/flag */ args []string // arguments after flags errorHandling ErrorHandling - output io.Writer // nil means stderr; use Output() accessor + output io.Writer // nil means stderr; use Output() accessor + undef map[string]string // flags which didn't exist at the time of Set } // A Flag represents the state of a flag. @@ -357,13 +421,13 @@ func sortFlags(flags map[string]*Flag) []*Flag { result[i] = f i++ } - sort.Slice(result, func(i, j int) bool { - return result[i].Name < result[j].Name + slices.SortFunc(result, func(a, b *Flag) int { + return strings.Compare(a.Name, b.Name) }) return result } -// Output returns the destination for usage and error messages. os.Stderr is returned if +// Output returns the destination for usage and error messages. [os.Stderr] is returned if // output was not set or was set to nil. func (f *FlagSet) Output() io.Writer { if f.output == nil { @@ -383,7 +447,7 @@ func (f *FlagSet) ErrorHandling() ErrorHandling { } // SetOutput sets the destination for usage and error messages. -// If output is nil, os.Stderr is used. +// If output is nil, [os.Stderr] is used. func (f *FlagSet) SetOutput(output io.Writer) { f.output = output } @@ -416,12 +480,12 @@ func Visit(fn func(*Flag)) { CommandLine.Visit(fn) } -// Lookup returns the Flag structure of the named flag, returning nil if none exists. +// Lookup returns the [Flag] structure of the named flag, returning nil if none exists. func (f *FlagSet) Lookup(name string) *Flag { return f.formal[name] } -// Lookup returns the Flag structure of the named command-line flag, +// Lookup returns the [Flag] structure of the named command-line flag, // returning nil if none exists. func Lookup(name string) *Flag { return CommandLine.formal[name] @@ -429,8 +493,29 @@ func Lookup(name string) *Flag { // Set sets the value of the named flag. func (f *FlagSet) Set(name, value string) error { + return f.set(name, value) +} +func (f *FlagSet) set(name, value string) error { flag, ok := f.formal[name] if !ok { + // Remember that a flag that isn't defined is being set. + // We return an error in this case, but in addition if + // subsequently that flag is defined, we want to panic + // at the definition point. + // This is a problem which occurs if both the definition + // and the Set call are in init code and for whatever + // reason the init code changes evaluation order. + // See issue 57411. + _, file, line, ok := runtime.Caller(2) + if !ok { + file = "?" + line = 0 + } + if f.undef == nil { + f.undef = map[string]string{} + } + f.undef[name] = fmt.Sprintf("%s:%d", file, line) + return fmt.Errorf("no such flag -%v", name) } err := flag.Value.Set(value) @@ -446,23 +531,34 @@ func (f *FlagSet) Set(name, value string) error { // Set sets the value of the named command-line flag. func Set(name, value string) error { - return CommandLine.Set(name, value) + return CommandLine.set(name, value) } // isZeroValue determines whether the string represents the zero // value for a flag. -func isZeroValue(flag *Flag, value string) bool { +func isZeroValue(flag *Flag, value string) (ok bool, err error) { // Build a zero value of the flag's Value type, and see if the // result of calling its String method equals the value passed in. // This works unless the Value type is itself an interface type. typ := reflect.TypeOf(flag.Value) var z reflect.Value - if typ.Kind() == reflect.Ptr { + if typ.Kind() == reflect.Pointer { z = reflect.New(typ.Elem()) } else { z = reflect.Zero(typ) } - return value == z.Interface().(Value).String() + // Catch panics calling the String method, which shouldn't prevent the + // usage message from being printed, but that we should report to the + // user so that they know to fix their code. + defer func() { + if e := recover(); e != nil { + if typ.Kind() == reflect.Pointer { + typ = typ.Elem() + } + err = fmt.Errorf("panic calling String method on zero %v for flag %s: %v", typ, flag.Name, e) + } + }() + return value == z.Interface().(Value).String(), nil } // UnquoteUsage extracts a back-quoted name from the usage @@ -487,9 +583,11 @@ func UnquoteUsage(flag *Flag) (name string, usage string) { } // No explicit name, so use type if we can find one. name = "value" - switch flag.Value.(type) { + switch fv := flag.Value.(type) { case boolFlag: - name = "" + if fv.IsBoolFlag() { + name = "" + } case *durationValue: name = "duration" case *float64Value: @@ -508,41 +606,59 @@ func UnquoteUsage(flag *Flag) (name string, usage string) { // default values of all defined command-line flags in the set. See the // documentation for the global function PrintDefaults for more information. func (f *FlagSet) PrintDefaults() { + var isZeroValueErrs []error f.VisitAll(func(flag *Flag) { - s := fmt.Sprintf(" -%s", flag.Name) // Two spaces before -; see next two comments. + var b strings.Builder + fmt.Fprintf(&b, " -%s", flag.Name) // Two spaces before -; see next two comments. name, usage := UnquoteUsage(flag) if len(name) > 0 { - s += " " + name + b.WriteString(" ") + b.WriteString(name) } // Boolean flags of one ASCII letter are so common we // treat them specially, putting their usage on the same line. - if len(s) <= 4 { // space, space, '-', 'x'. - s += "\t" + if b.Len() <= 4 { // space, space, '-', 'x'. + b.WriteString("\t") } else { // Four spaces before the tab triggers good alignment // for both 4- and 8-space tab stops. - s += "\n \t" + b.WriteString("\n \t") } - s += strings.ReplaceAll(usage, "\n", "\n \t") + b.WriteString(strings.ReplaceAll(usage, "\n", "\n \t")) - if !isZeroValue(flag, flag.DefValue) { + // Print the default value only if it differs to the zero value + // for this flag type. + if isZero, err := isZeroValue(flag, flag.DefValue); err != nil { + isZeroValueErrs = append(isZeroValueErrs, err) + } else if !isZero { if _, ok := flag.Value.(*stringValue); ok { // put quotes on the value - s += fmt.Sprintf(" (default %q)", flag.DefValue) + fmt.Fprintf(&b, " (default %q)", flag.DefValue) } else { - s += fmt.Sprintf(" (default %v)", flag.DefValue) + fmt.Fprintf(&b, " (default %v)", flag.DefValue) } } - fmt.Fprint(f.Output(), s, "\n") + fmt.Fprint(f.Output(), b.String(), "\n") }) + // If calling String on any zero flag.Values triggered a panic, print + // the messages after the full set of defaults so that the programmer + // knows to fix the panic. + if errs := isZeroValueErrs; len(errs) > 0 { + fmt.Fprintln(f.Output()) + for _, err := range errs { + fmt.Fprintln(f.Output(), err) + } + } } // PrintDefaults prints, to standard error unless configured otherwise, // a usage message showing the default settings of all defined // command-line flags. // For an integer valued flag x, the default output has the form +// // -x int // usage-message-for-x (default 7) +// // The usage message will appear on a separate line for anything but // a bool flag with a one-byte name. For bool flags, the type is // omitted and if the flag name is one byte the usage message appears @@ -552,12 +668,15 @@ func (f *FlagSet) PrintDefaults() { // string; the first such item in the message is taken to be a parameter // name to show in the message and the back quotes are stripped from // the message when displayed. For instance, given +// // flag.String("I", "", "search `directory` for include files") +// // the output will be +// // -I directory // search directory for include files. // -// To change the destination for flag messages, call CommandLine.SetOutput. +// To change the destination for flag messages, call [CommandLine].SetOutput. func PrintDefaults() { CommandLine.PrintDefaults() } @@ -577,14 +696,14 @@ func (f *FlagSet) defaultUsage() { // for how to write your own usage function. // Usage prints a usage message documenting all defined command-line flags -// to CommandLine's output, which by default is os.Stderr. +// to [CommandLine]'s output, which by default is [os.Stderr]. // It is called when an error occurs while parsing flags. // The function is a variable that may be changed to point to a custom function. -// By default it prints a simple header and calls PrintDefaults; for details about the -// format of the output and how to control it, see the documentation for PrintDefaults. +// By default it prints a simple header and calls [PrintDefaults]; for details about the +// format of the output and how to control it, see the documentation for [PrintDefaults]. // Custom usage functions may choose to exit the program; by default exiting // happens anyway as the command line's error handling strategy is set to -// ExitOnError. +// [ExitOnError]. var Usage = func() { fmt.Fprintf(CommandLine.Output(), "Usage of %s:\n", os.Args[0]) PrintDefaults() @@ -837,6 +956,24 @@ func Duration(name string, value time.Duration, usage string) *time.Duration { return CommandLine.Duration(name, value, usage) } +// TextVar defines a flag with a specified name, default value, and usage string. +// The argument p must be a pointer to a variable that will hold the value +// of the flag, and p must implement encoding.TextUnmarshaler. +// If the flag is used, the flag value will be passed to p's UnmarshalText method. +// The type of the default value must be the same as the type of p. +func (f *FlagSet) TextVar(p encoding.TextUnmarshaler, name string, value encoding.TextMarshaler, usage string) { + f.Var(newTextValue(value, p), name, usage) +} + +// TextVar defines a flag with a specified name, default value, and usage string. +// The argument p must be a pointer to a variable that will hold the value +// of the flag, and p must implement encoding.TextUnmarshaler. +// If the flag is used, the flag value will be passed to p's UnmarshalText method. +// The type of the default value must be the same as the type of p. +func TextVar(p encoding.TextUnmarshaler, name string, value encoding.TextMarshaler, usage string) { + CommandLine.Var(newTextValue(value, p), name, usage) +} + // Func defines a flag with the specified name and usage string. // Each time the flag is seen, fn is called with the value of the flag. // If fn returns a non-nil error, it will be treated as a flag value parsing error. @@ -851,11 +988,25 @@ func Func(name, usage string, fn func(string) error) { CommandLine.Func(name, usage, fn) } +// BoolFunc defines a flag with the specified name and usage string without requiring values. +// Each time the flag is seen, fn is called with the value of the flag. +// If fn returns a non-nil error, it will be treated as a flag value parsing error. +func (f *FlagSet) BoolFunc(name, usage string, fn func(string) error) { + f.Var(boolFuncValue(fn), name, usage) +} + +// BoolFunc defines a flag with the specified name and usage string without requiring values. +// Each time the flag is seen, fn is called with the value of the flag. +// If fn returns a non-nil error, it will be treated as a flag value parsing error. +func BoolFunc(name, usage string, fn func(string) error) { + CommandLine.BoolFunc(name, usage, fn) +} + // Var defines a flag with the specified name and usage string. The type and -// value of the flag are represented by the first argument, of type Value, which -// typically holds a user-defined implementation of Value. For instance, the +// value of the flag are represented by the first argument, of type [Value], which +// typically holds a user-defined implementation of [Value]. For instance, the // caller could create a flag that turns a comma-separated string into a slice -// of strings by giving the slice the methods of Value; in particular, Set would +// of strings by giving the slice the methods of [Value]; in particular, [Set] would // decompose the comma-separated string into the slice. func (f *FlagSet) Var(value Value, name string, usage string) { // Flag must not begin "-" or contain "=". @@ -877,6 +1028,9 @@ func (f *FlagSet) Var(value Value, name string, usage string) { } panic(msg) // Happens only if flags are declared with identical names } + if pos := f.undef[name]; pos != "" { + panic(fmt.Sprintf("flag %s set at %s before being defined", name, pos)) + } if f.formal == nil { f.formal = make(map[string]*Flag) } @@ -884,17 +1038,17 @@ func (f *FlagSet) Var(value Value, name string, usage string) { } // Var defines a flag with the specified name and usage string. The type and -// value of the flag are represented by the first argument, of type Value, which -// typically holds a user-defined implementation of Value. For instance, the +// value of the flag are represented by the first argument, of type [Value], which +// typically holds a user-defined implementation of [Value]. For instance, the // caller could create a flag that turns a comma-separated string into a slice -// of strings by giving the slice the methods of Value; in particular, Set would +// of strings by giving the slice the methods of [Value]; in particular, [Set] would // decompose the comma-separated string into the slice. func Var(value Value, name string, usage string) { CommandLine.Var(value, name, usage) } // sprintf formats the message, prints it to output, and returns it. -func (f *FlagSet) sprintf(format string, a ...interface{}) string { +func (f *FlagSet) sprintf(format string, a ...any) string { msg := fmt.Sprintf(format, a...) fmt.Fprintln(f.Output(), msg) return msg @@ -902,7 +1056,7 @@ func (f *FlagSet) sprintf(format string, a ...interface{}) string { // failf prints to standard error a formatted error and usage message and // returns the error. -func (f *FlagSet) failf(format string, a ...interface{}) error { +func (f *FlagSet) failf(format string, a ...any) error { msg := f.sprintf(format, a...) f.usage() return errors.New(msg) @@ -952,9 +1106,9 @@ func (f *FlagSet) parseOne() (bool, error) { break } } - m := f.formal - flag, alreadythere := m[name] // BUG - if !alreadythere { + + flag, ok := f.formal[name] + if !ok { if name == "help" || name == "h" { // special case for nice help message. f.usage() return false, ErrHelp @@ -994,9 +1148,9 @@ func (f *FlagSet) parseOne() (bool, error) { } // Parse parses flag definitions from the argument list, which should not -// include the command name. Must be called after all flags in the FlagSet +// include the command name. Must be called after all flags in the [FlagSet] // are defined and before flags are accessed by the program. -// The return value will be ErrHelp if -help or -h were set but not defined. +// The return value will be [ErrHelp] if -help or -h were set but not defined. func (f *FlagSet) Parse(arguments []string) error { f.parsed = true f.args = arguments @@ -1066,7 +1220,7 @@ func (f *FlagSet) Parsed() bool { return f.parsed } -// Parse parses the command-line flags from os.Args[1:]. Must be called +// Parse parses the command-line flags from [os.Args][1:]. Must be called // after all flags are defined and before flags are accessed by the program. func Parse() { // Ignore errors; CommandLine is set for ExitOnError. @@ -1078,12 +1232,19 @@ func Parsed() bool { return CommandLine.Parsed() } -// CommandLine is the default set of command-line flags, parsed from os.Args. -// The top-level functions such as BoolVar, Arg, and so on are wrappers for the +// CommandLine is the default set of command-line flags, parsed from [os.Args]. +// The top-level functions such as [BoolVar], [Arg], and so on are wrappers for the // methods of CommandLine. -var CommandLine = NewFlagSet(os.Args[0], ExitOnError) +var CommandLine *FlagSet func init() { + // It's possible for execl to hand us an empty os.Args. + if len(os.Args) == 0 { + CommandLine = NewFlagSet("", ExitOnError) + } else { + CommandLine = NewFlagSet(os.Args[0], ExitOnError) + } + // Override generic FlagSet default Usage with call to global Usage. // Note: This is not CommandLine.Usage = Usage, // because we want any eventual call to use any updated value of Usage, @@ -1108,9 +1269,8 @@ func NewFlagSet(name string, errorHandling ErrorHandling) *FlagSet { } // Init sets the name and error handling property for a flag set. -// By default, the zero FlagSet uses an empty name and the -// ContinueOnError error handling policy. -// /* jnovack/flag */ Adds Environment Prefix +// By default, the zero [FlagSet] uses an empty name and the +// [ContinueOnError] error handling policy. func (f *FlagSet) Init(name string, errorHandling ErrorHandling) { f.name = name f.envPrefix = EnvironmentPrefix /* jnovack/flag */ diff --git a/flag_test.go b/flag_test.go index 3d7206a..2e9a7d3 100644 --- a/flag_test.go +++ b/flag_test.go @@ -9,13 +9,17 @@ import ( "fmt" "io" "os" - "sort" + "os/exec" + "regexp" + "runtime" + "slices" "strconv" "strings" "testing" "time" . "github.com/jnovack/flag" + "github.com/jnovack/flag/internal/testenv" ) func boolString(s string) string { @@ -36,6 +40,7 @@ func TestEverything(t *testing.T) { Float64("test_float64", 0, "float64 value") Duration("test_duration", 0, "time.Duration value") Func("test_func", "func value", func(string) error { return nil }) + BoolFunc("test_boolfunc", "func", func(string) error { return nil }) m := make(map[string]*Flag) desired := "0" @@ -52,6 +57,8 @@ func TestEverything(t *testing.T) { ok = true case f.Name == "test_func" && f.Value.String() == "": ok = true + case f.Name == "test_boolfunc" && f.Value.String() == "": + ok = true } if !ok { t.Error("Visit: bad value", f.Value.String(), "for", f.Name) @@ -59,7 +66,7 @@ func TestEverything(t *testing.T) { } } VisitAll(visitor) - if len(m) != 9 { + if len(m) != 10 { t.Error("VisitAll misses some flags") for k, v := range m { t.Log(k, *v) @@ -83,9 +90,10 @@ func TestEverything(t *testing.T) { Set("test_float64", "1") Set("test_duration", "1s") Set("test_func", "1") + Set("test_boolfunc", "") desired = "1" Visit(visitor) - if len(m) != 9 { + if len(m) != 10 { t.Error("Visit fails after set") for k, v := range m { t.Log(k, *v) @@ -94,7 +102,7 @@ func TestEverything(t *testing.T) { // Now test they're visited in sort order. var flagNames []string Visit(func(f *Flag) { flagNames = append(flagNames, f.Name) }) - if !sort.StringsAreSorted(flagNames) { + if !slices.IsSorted(flagNames) { t.Errorf("flag names not sorted: %v", flagNames) } } @@ -244,6 +252,7 @@ func (f *flagVar) Set(value string) error { func TestUserDefined(t *testing.T) { var flags FlagSet flags.Init("test", ContinueOnError) + flags.SetOutput(io.Discard) var v flagVar flags.Var(&v, "v", "usage") if err := flags.Parse([]string{"-v", "1", "-v", "2", "-v=3"}); err != nil { @@ -259,8 +268,8 @@ func TestUserDefined(t *testing.T) { } func TestUserDefinedFunc(t *testing.T) { - var flags FlagSet - flags.Init("test", ContinueOnError) + flags := NewFlagSet("test", ContinueOnError) + flags.SetOutput(io.Discard) var ss []string flags.Func("v", "usage", func(s string) error { ss = append(ss, s) @@ -284,7 +293,8 @@ func TestUserDefinedFunc(t *testing.T) { t.Errorf("usage string not included: %q", usage) } // test Func error - flags = *NewFlagSet("test", ContinueOnError) + flags = NewFlagSet("test", ContinueOnError) + flags.SetOutput(io.Discard) flags.Func("v", "usage", func(s string) error { return fmt.Errorf("test error") }) @@ -333,6 +343,7 @@ func (b *boolFlagVar) IsBoolFlag() bool { func TestUserDefinedBool(t *testing.T) { var flags FlagSet flags.Init("test", ContinueOnError) + flags.SetOutput(io.Discard) var b boolFlagVar var err error flags.Var(&b, "b", "usage") @@ -351,10 +362,35 @@ func TestUserDefinedBool(t *testing.T) { } } -func TestSetOutput(t *testing.T) { +func TestUserDefinedBoolUsage(t *testing.T) { var flags FlagSet + flags.Init("test", ContinueOnError) var buf bytes.Buffer flags.SetOutput(&buf) + var b boolFlagVar + flags.Var(&b, "b", "X") + b.count = 0 + // b.IsBoolFlag() will return true and usage will look boolean. + flags.PrintDefaults() + got := buf.String() + want := " -b\tX\n" + if got != want { + t.Errorf("false: want %q; got %q", want, got) + } + b.count = 4 + // b.IsBoolFlag() will return false and usage will look non-boolean. + flags.PrintDefaults() + got = buf.String() + want = " -b\tX\n -b value\n \tX\n" + if got != want { + t.Errorf("false: want %q; got %q", want, got) + } +} + +func TestSetOutput(t *testing.T) { + var flags FlagSet + var buf strings.Builder + flags.SetOutput(&buf) flags.Init("test", ContinueOnError) flags.Parse([]string{"-unknown"}) if out := buf.String(); !strings.Contains(out, "-unknown") { @@ -427,6 +463,25 @@ func TestHelp(t *testing.T) { } } +// zeroPanicker is a flag.Value whose String method panics if its dontPanic +// field is false. +type zeroPanicker struct { + dontPanic bool + v string +} + +func (f *zeroPanicker) Set(s string) error { + f.v = s + return nil +} + +func (f *zeroPanicker) String() string { + if !f.dontPanic { + panic("panic!") + } + return f.v +} + const defaultOutput = ` -A for bootstrapping, allow 'any' type -Alongflagname disable bounds checking @@ -447,15 +502,24 @@ const defaultOutput = ` -A for bootstrapping, allow 'any' type a non-zero int (default 27) -O a flag multiline help string (default true) + -V list + a list of strings (default [a b]) -Z int an int that defaults to zero + -ZP0 value + a flag whose String method panics when it is zero + -ZP1 value + a flag whose String method panics when it is zero -maxT timeout set timeout for dial + +panic calling String method on zero flag_test.zeroPanicker for flag ZP0: panic! +panic calling String method on zero flag_test.zeroPanicker for flag ZP1: panic! ` func TestPrintDefaults(t *testing.T) { fs := NewFlagSet("print defaults test", ContinueOnError) - var buf bytes.Buffer + var buf strings.Builder fs.SetOutput(&buf) fs.Bool("A", false, "for bootstrapping, allow 'any' type") fs.Bool("Alongflagname", false, "disable bounds checking") @@ -467,12 +531,15 @@ func TestPrintDefaults(t *testing.T) { fs.String("M", "", "a multiline\nhelp\nstring") fs.Int("N", 27, "a non-zero int") fs.Bool("O", true, "a flag\nmultiline help string") + fs.Var(&flagVar{"a", "b"}, "V", "a `list` of strings") fs.Int("Z", 0, "an int that defaults to zero") + fs.Var(&zeroPanicker{true, ""}, "ZP0", "a flag whose String method panics when it is zero") + fs.Var(&zeroPanicker{true, "something"}, "ZP1", "a flag whose String method panics when it is zero") fs.Duration("maxT", 0, "set `timeout` for dial") fs.PrintDefaults() got := buf.String() if got != defaultOutput { - t.Errorf("got %q want %q\n", got, defaultOutput) + t.Errorf("got:\n%q\nwant:\n%q", got, defaultOutput) } } @@ -495,7 +562,7 @@ func TestIntFlagOverflow(t *testing.T) { // Issue 20998: Usage should respect CommandLine.output. func TestUsageOutput(t *testing.T) { ResetForTesting(DefaultUsage) - var buf bytes.Buffer + var buf strings.Builder CommandLine.SetOutput(&buf) defer func(old []string) { os.Args = old }(os.Args) os.Args = []string{"app", "-i=1", "-unknown"} @@ -591,8 +658,6 @@ func TestRangeError(t *testing.T) { } } -// jnovack/flag cannot import TextExitCode because it relies on internal/testenv -/* func TestExitCode(t *testing.T) { testenv.MustHaveExec(t) @@ -637,7 +702,7 @@ func TestExitCode(t *testing.T) { } for _, test := range tests { - cmd := exec.Command(os.Args[0], "-test.run=TestExitCode") + cmd := exec.Command(testenv.Executable(t), "-test.run=^TestExitCode$") cmd.Env = append( os.Environ(), "GO_CHILD_FLAG="+test.flag, @@ -663,7 +728,7 @@ func mustPanic(t *testing.T, testName string, expected string, f func()) { case nil: t.Errorf("%s\n: expected panic(%q), but did not panic", testName, expected) case string: - if msg != expected { + if ok, _ := regexp.MatchString(expected, msg); !ok { t.Errorf("%s\n: expected panic(%q), but got panic(%q)", testName, expected, msg) } default: @@ -692,7 +757,7 @@ func TestInvalidFlags(t *testing.T) { testName := fmt.Sprintf("FlagSet.Var(&v, %q, \"\")", test.flag) fs := NewFlagSet("", ContinueOnError) - buf := bytes.NewBuffer(nil) + buf := &strings.Builder{} fs.SetOutput(buf) mustPanic(t, testName, test.errorMsg, func() { @@ -724,7 +789,7 @@ func TestRedefinedFlags(t *testing.T) { testName := fmt.Sprintf("flag redefined in FlagSet(%q)", test.flagSetName) fs := NewFlagSet(test.flagSetName, ContinueOnError) - buf := bytes.NewBuffer(nil) + buf := &strings.Builder{} fs.SetOutput(buf) var v flagVar @@ -738,4 +803,57 @@ func TestRedefinedFlags(t *testing.T) { } } } -*/ + +func TestUserDefinedBoolFunc(t *testing.T) { + flags := NewFlagSet("test", ContinueOnError) + flags.SetOutput(io.Discard) + var ss []string + flags.BoolFunc("v", "usage", func(s string) error { + ss = append(ss, s) + return nil + }) + if err := flags.Parse([]string{"-v", "", "-v", "1", "-v=2"}); err != nil { + t.Error(err) + } + if len(ss) != 1 { + t.Fatalf("got %d args; want 1 arg", len(ss)) + } + want := "[true]" + if got := fmt.Sprint(ss); got != want { + t.Errorf("got %q; want %q", got, want) + } + // test usage + var buf strings.Builder + flags.SetOutput(&buf) + flags.Parse([]string{"-h"}) + if usage := buf.String(); !strings.Contains(usage, "usage") { + t.Errorf("usage string not included: %q", usage) + } + // test BoolFunc error + flags = NewFlagSet("test", ContinueOnError) + flags.SetOutput(io.Discard) + flags.BoolFunc("v", "usage", func(s string) error { + return fmt.Errorf("test error") + }) + // flag not set, so no error + if err := flags.Parse(nil); err != nil { + t.Error(err) + } + // flag set, expect error + if err := flags.Parse([]string{"-v", ""}); err == nil { + t.Error("got err == nil; want err != nil") + } else if errMsg := err.Error(); !strings.Contains(errMsg, "test error") { + t.Errorf(`got %q; error should contain "test error"`, errMsg) + } +} + +func TestDefineAfterSet(t *testing.T) { + flags := NewFlagSet("test", ContinueOnError) + // Set by itself doesn't panic. + flags.Set("myFlag", "value") + + // Define-after-set panics. + mustPanic(t, "DefineAfterSet", "flag myFlag set at .*/flag_test.go:.* before being defined", func() { + _ = flags.String("myFlag", "default", "usage") + }) +} diff --git a/go.mod b/go.mod index 14cff16..bc64bc1 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/jnovack/flag -go 1.16 +go 1.24.2 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/internal/README.md b/internal/README.md new file mode 100644 index 0000000..e14d850 --- /dev/null +++ b/internal/README.md @@ -0,0 +1,7 @@ +# internal + +The following files are copied from go's src and modified because calling `internal/testenv` +is prohibited. + +I am only including `internal/testenv` because this module is a fork/fix of go's original +`flag` functions and their original source uses `internal/testenv`. \ No newline at end of file diff --git a/test/sparse-checkout b/test/sparse-checkout new file mode 100644 index 0000000..1b3edf9 --- /dev/null +++ b/test/sparse-checkout @@ -0,0 +1,11 @@ +!/*/ +!/src/*/ +/src/internal/ +!/src/internal/*/ +/src/flag/ +/src/internal/cfg/ +/src/internal/diff/ +/src/internal/goarch/ +/src/internal/platform/ +/src/internal/testenv/ +/src/internal/txtar/