diff --git a/download/download_test.go b/download/download_test.go index 93bec23..1b054f9 100644 --- a/download/download_test.go +++ b/download/download_test.go @@ -3,14 +3,18 @@ package download_test import ( "testing" - "github.com/stretchr/testify/require" - "github.com/murfffi/getaduck/download" + "github.com/stretchr/testify/require" ) func TestDo(t *testing.T) { - res, err := download.Do(download.DefaultSpec()) - require.NoError(t, err) - require.FileExists(t, res.OutputFile) - + if !testing.Short() { + t.Skip("skipping test that downloads from Github in short mode.") + } + t.Run("default lib", func(t *testing.T) { + res, err := download.Do(download.DefaultSpec()) + require.NoError(t, err) + require.FileExists(t, res.OutputFile) + }) + // cli is tested e2e in shell/run_test.go . Avoid multiple downloads. } diff --git a/internal/enumflag/enumflag.go b/internal/enumflag/enumflag.go new file mode 100644 index 0000000..f515123 --- /dev/null +++ b/internal/enumflag/enumflag.go @@ -0,0 +1,68 @@ +package enumflag + +// Adapted from https://github.com/creachadair/goflags/blob/main/enumflag/flag.go +// under BSD 3-Clause license + +import ( + "fmt" + "strings" +) + +// A Value represents an enumeration of string values. A pointer to a Value +// satisfies the flag.Value interface. Use the Key method to recover the +// currently-selected value of the enumeration. +type Value struct { + keys []string + index int // The selected index in the enumeration +} + +// Help concatenates a human-readable string summarizing the legal values of v +// to h, for use in generating a documentation string. +func (v *Value) Help(h string) string { + return fmt.Sprintf("%s (%s)", h, strings.Join(v.keys, "|")) +} + +// New returns a *Value for the specified enumerators, where defaultKey is the +// default value and otherKeys are additional options. The index of a selected +// key reflects its position in the order given to this function, so that if: +// +// v := enumflag.New("a", "b", "c", "d") +// +// then the index of "a" is 0, "b" is 1, "c" is 2, "d" is 3. The default key is +// always stored at index 0. +func New(defaultKey string, otherKeys ...string) *Value { + return &Value{keys: append([]string{defaultKey}, otherKeys...)} +} + +// Key returns the currently-selected key in the enumeration. The original +// spelling of the selected value is returned, as given to the constructor, not +// the value as parsed. +func (v *Value) Key() string { + if len(v.keys) == 0 { + return "" // BUG: https://github.com/golang/go/issues/16694 + } + return v.keys[v.index] +} + +// Get satisfies the flag.Getter interface. +// The concrete value is the the string of the current key. +func (v *Value) Get() any { return v.Key() } + +// Index returns the currently-selected index in the enumeration. +// The order of keys reflects the original order in which they were passed to +// the constructor, so index 0 is the default value. +func (v *Value) Index() int { return v.index } + +// String satisfies part of the flag.Value interface. +func (v *Value) String() string { return fmt.Sprintf("%q", v.Key()) } + +// Set satisfies part of the flag.Value interface. +func (v *Value) Set(s string) error { + for i, key := range v.keys { + if strings.EqualFold(s, key) { + v.index = i + return nil + } + } + return fmt.Errorf("expected one of (%s)", strings.Join(v.keys, "|")) +} diff --git a/internal/enumflag/enumflag_test.go b/internal/enumflag/enumflag_test.go new file mode 100644 index 0000000..910a751 --- /dev/null +++ b/internal/enumflag/enumflag_test.go @@ -0,0 +1,59 @@ +package enumflag + +// Adapted from https://github.com/creachadair/goflags/blob/main/enumflag/flag.go +// under BSD 3-Clause license + +import ( + "bytes" + "flag" + "io" + "testing" +) + +func newFlagSet(name string, w io.Writer) *flag.FlagSet { + fs := flag.NewFlagSet(name, flag.ContinueOnError) + fs.SetOutput(w) + return fs +} + +func TestFlagBits(t *testing.T) { + color := New("red", "orange", "yellow", "green", "blue") + + const initial = "red" + const flagged = "green" + const flaggedIndex = 3 + + var buf bytes.Buffer + fs := newFlagSet("color", &buf) + fs.Var(color, "color", color.Help("The color to paint the bike shed")) + fs.PrintDefaults() + t.Logf("Color flag set:\n%s", buf.String()) + buf.Reset() + + if key := color.Key(); key != initial { + t.Errorf("Initial value for -color: got %q, want %q", key, initial) + } + + if err := fs.Parse([]string{"-color", "GREEN"}); err != nil { + t.Fatalf("Argument parsing failed: %v", err) + } + + if key := color.Key(); key != flagged { + t.Errorf("Value for -color: got %q, want %q", key, flagged) + } + if idx := color.Index(); idx != flaggedIndex { + t.Errorf("Index for -color: got %d, want %d", idx, flaggedIndex) + } + + taste := New("", "sweet", "sour") + fs = newFlagSet("taste", &buf) + fs.Var(taste, "taste", taste.Help("The flavour of the ice cream")) + fs.PrintDefaults() + t.Logf("Taste flag set:\n%s", buf.String()) + + if err := fs.Parse([]string{"-taste", "crud"}); err == nil { + t.Error("Expected error from bogus flag, but got none") + } else { + t.Logf("Got expected error from bogus -taste: %v", err) + } +} diff --git a/main.go b/main.go index e940b1c..3cf5044 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,16 @@ package main import ( + "flag" + "log" + "os" + "github.com/murfffi/getaduck/shell" ) func main() { - shell.Run() + err := shell.RunArgs(os.Args, flag.ExitOnError) + if err != nil { + log.Fatal(err) + } } diff --git a/shell/run.go b/shell/run.go index 7172405..bf736e6 100644 --- a/shell/run.go +++ b/shell/run.go @@ -2,18 +2,24 @@ package shell import ( + "flag" "log" "path/filepath" "github.com/murfffi/getaduck/download" + "github.com/murfffi/getaduck/internal/enumflag" ) -// Run executes getaduck -func Run() { - spec := download.DefaultSpec() +// RunArgs executes getaduck CLI +func RunArgs(args []string, onError flag.ErrorHandling) error { + spec, err := parseSpec(args, onError) + if err != nil { + return err + } + res, err := download.Do(spec) if err != nil { - log.Fatalf("download failed: %v", err) + return err } outFileName := res.OutputFile absPath, err := filepath.Abs(outFileName) @@ -21,4 +27,26 @@ func Run() { absPath = outFileName } log.Print("downloaded: ", absPath) + return nil +} + +func parseSpec(args []string, onError flag.ErrorHandling) (download.Spec, error) { + fs := flag.NewFlagSet(args[0], onError) + spec := download.DefaultSpec() + + // order of args must match download.BinType const order + binType := enumflag.New("lib", "cli") + fs.Var(binType, "type", binType.Help("type of binary to download")) + version := fs.String("version", spec.Version, "DuckDB version") + binOS := fs.String("os", spec.OS, "target OS") + binArch := fs.String("arch", spec.Arch, "target architecture") + if err := fs.Parse(args[1:]); err != nil { + return download.Spec{}, err + } + + spec.Type = download.BinType(binType.Index()) + spec.Version = *version + spec.OS = *binOS + spec.Arch = *binArch + return spec, nil } diff --git a/shell/run_test.go b/shell/run_test.go new file mode 100644 index 0000000..868dbb8 --- /dev/null +++ b/shell/run_test.go @@ -0,0 +1,26 @@ +package shell + +import ( + "flag" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRunArgs(t *testing.T) { + if !testing.Short() { + t.Skip("skipping test that downloads from Github in short mode.") + } + t.Run("cli", func(t *testing.T) { + err := RunArgs([]string{"test", "-type", "cli"}, flag.ContinueOnError) + require.NoError(t, err) + }) +} + +func TestParseSpec(t *testing.T) { + t.Run("version", func(t *testing.T) { + spec, err := parseSpec([]string{"test", "--version", "1.1.0"}, flag.ContinueOnError) + require.NoError(t, err) + require.Equal(t, "1.1.0", spec.Version) + }) +}