diff --git a/configurator.go b/configurator.go index d0b5ecf..fcd3a16 100644 --- a/configurator.go +++ b/configurator.go @@ -4,6 +4,8 @@ import ( "context" "os" "reflect" + "sort" + "strconv" "github.com/upfluence/errors" @@ -112,6 +114,14 @@ func (c *configurator) walkFunc(ctx context.Context, f *walker.Field) error { s := c.factory.Build(f.Field.Type) if s == nil { + if reflectutil.SubKeyMapElem(f.Field.Type) != nil { + return c.populateMapField(ctx, f) + } + + if reflectutil.SubKeySliceElem(f.Field.Type) != nil { + return c.populateSliceField(ctx, f) + } + return nil } @@ -180,6 +190,155 @@ func (c *configurator) walkFunc(ctx context.Context, f *walker.Field) error { return nil } +func (c *configurator) collectSubKeys(ctx context.Context, f *walker.Field) ([]string, error) { + seen := make(map[string]struct{}) + + var keys []string + + for _, p := range c.providers { + fqp := provider.WrapFullyQualifiedProvider(p) + + for _, prefix := range walker.BuildFieldKeys(fqp, f, c.ignoreMissingTag) { + sks, err := fqp.SubKeys(ctx, prefix) + + if err != nil { + return nil, errors.WithStack( + &ProvidingError{ + Err: err, + Key: prefix, + Field: f.Field, + Provider: p, + }, + ) + } + + for _, k := range sks { + if _, dup := seen[k]; dup { + continue + } + + seen[k] = struct{}{} + keys = append(keys, k) + } + } + } + + return keys, nil +} + +func (c *configurator) populateMapField(ctx context.Context, f *walker.Field) error { + ft := reflectutil.IndirectedType(f.Field.Type) + elemIsPtr := ft.Elem().Kind() == reflect.Ptr + structType := reflectutil.SubKeyMapElem(f.Field.Type) + + keys, err := c.collectSubKeys(ctx, f) + + if err != nil { + return err + } + + if len(keys) == 0 { + return walker.SkipStruct + } + + fv := reflectutil.IndirectedValue(f.Value).FieldByName(f.Field.Name) + mapVal := reflect.MakeMap(ft) + + for _, subKey := range keys { + elem := reflect.New(structType) + + prefixed := &walker.SubKeyPrefixed{ + Ancestor: f, + SubKey: subKey, + Value: elem.Interface(), + } + + if err := c.Populate(ctx, prefixed); err != nil { + return err + } + + if reflectutil.IsZero(elem.Elem()) { + continue + } + + if elemIsPtr { + mapVal.SetMapIndex(reflect.ValueOf(subKey), elem) + } else { + mapVal.SetMapIndex(reflect.ValueOf(subKey), elem.Elem()) + } + } + + if mapVal.Len() > 0 { + fv.Set(mapVal) + } + + return walker.SkipStruct +} + +func (c *configurator) populateSliceField(ctx context.Context, f *walker.Field) error { + ft := reflectutil.IndirectedType(f.Field.Type) + elemIsPtr := ft.Elem().Kind() == reflect.Ptr + structType := reflectutil.SubKeySliceElem(f.Field.Type) + + keys, err := c.collectSubKeys(ctx, f) + + if err != nil { + return err + } + + type indexedKey struct { + index int + key string + } + + var indices []indexedKey + + for _, k := range keys { + idx, err := strconv.Atoi(k) + + if err != nil { + continue + } + + indices = append(indices, indexedKey{index: idx, key: k}) + } + + if len(indices) == 0 { + return walker.SkipStruct + } + + sort.Slice(indices, func(i, j int) bool { + return indices[i].index < indices[j].index + }) + + fv := reflectutil.IndirectedValue(f.Value).FieldByName(f.Field.Name) + sliceVal := reflect.MakeSlice(ft, len(indices), len(indices)) + + for i, ik := range indices { + elem := reflect.New(structType) + + prefixed := &walker.SubKeyPrefixed{ + Ancestor: f, + SubKey: ik.key, + Value: elem.Interface(), + } + + if err := c.Populate(ctx, prefixed); err != nil { + return err + } + + if elemIsPtr { + sliceVal.Index(i).Set(elem) + } else { + sliceVal.Index(i).Set(elem.Elem()) + } + } + + fv.Set(sliceVal) + + return walker.SkipStruct +} + func isRequired(f reflect.StructField) bool { v, ok := f.Tag.Lookup("required") diff --git a/configurator_test.go b/configurator_test.go index df743de..8dc2330 100644 --- a/configurator_test.go +++ b/configurator_test.go @@ -6,6 +6,8 @@ import ( "fmt" "net" "os" + "reflect" + "strings" "testing" "time" @@ -42,6 +44,38 @@ func (p *mockProvider) Provide(_ context.Context, k string) (string, bool, error return v, ok, nil } +func (p *mockProvider) DefaultFieldValue(fieldName string) string { + return fieldName +} + +func (p *mockProvider) JoinFieldKeys(prefix, key string) string { + return prefix + "." + key +} + +func (p *mockProvider) SubKeys(_ context.Context, prefix string) ([]string, error) { + pfx := prefix + "." + seen := make(map[string]struct{}) + + var keys []string + + for k := range p.st { + if !strings.HasPrefix(k, pfx) { + continue + } + + seg, _, _ := strings.Cut(k[len(pfx):], ".") + + if _, dup := seen[seg]; dup { + continue + } + + seen[seg] = struct{}{} + keys = append(keys, seg) + } + + return keys, nil +} + type testCase struct { caseName string input interface{} @@ -604,6 +638,142 @@ func TestConfigurator(t *testing.T) { errAssertion: hasStaticError(walker.ErrShouldBeAStructPtr), dataAssertion: func(t *testing.T, y interface{}) {}, }, + + // SubKeys: map[string]Struct + { + caseName: "subkeys-map-of-structs", + input: &mapStructConfig{}, + provider: &mockProvider{st: map[string]string{ + "Databases.PRIMARY.Host": "h1", + "Databases.PRIMARY.Port": "5432", + "Databases.REPLICA.Host": "h2", + "Databases.REPLICA.Port": "5433", + }}, + dataAssertion: deepEqual(&mapStructConfig{ + Databases: map[string]dbConfig{ + "PRIMARY": {Host: "h1", Port: 5432}, + "REPLICA": {Host: "h2", Port: 5433}, + }, + }), + errAssertion: noError, + }, + { + caseName: "subkeys-map-of-ptr-structs", + input: &mapPtrStructConfig{}, + provider: &mockProvider{st: map[string]string{ + "Databases.PRIMARY.Host": "h1", + "Databases.PRIMARY.Port": "5432", + "Databases.REPLICA.Host": "h2", + "Databases.REPLICA.Port": "5433", + }}, + dataAssertion: deepEqual(&mapPtrStructConfig{ + Databases: map[string]*dbConfig{ + "PRIMARY": {Host: "h1", Port: 5432}, + "REPLICA": {Host: "h2", Port: 5433}, + }, + }), + errAssertion: noError, + }, + { + caseName: "subkeys-map-no-keys", + input: &mapStructConfig{}, + provider: &mockProvider{st: map[string]string{}}, + dataAssertion: deepEqual(&mapStructConfig{}), + errAssertion: noError, + }, + { + caseName: "subkeys-map-union", + input: &mapStructConfig{}, + provider: &mockProvider{st: map[string]string{ + "Databases.B.Host": "h2", + "Databases.B.Port": "2", + }}, + options: []Option{WithProviders(&mockProvider{st: map[string]string{ + "Databases.A.Host": "h1", + "Databases.A.Port": "1", + }})}, + dataAssertion: deepEqual(&mapStructConfig{ + Databases: map[string]dbConfig{ + "A": {Host: "h1", Port: 1}, + "B": {Host: "h2", Port: 2}, + }, + }), + errAssertion: noError, + }, + + // SubKeys: []Struct + { + caseName: "subkeys-slice-of-structs", + input: &sliceStructConfig{}, + provider: &mockProvider{st: map[string]string{ + "Workers.0.Host": "h0", + "Workers.0.Port": "1000", + "Workers.1.Host": "h1", + "Workers.1.Port": "1001", + }}, + dataAssertion: deepEqual(&sliceStructConfig{ + Workers: []dbConfig{ + {Host: "h0", Port: 1000}, + {Host: "h1", Port: 1001}, + }, + }), + errAssertion: noError, + }, + { + caseName: "subkeys-slice-of-ptr-structs", + input: &slicePtrStructConfig{}, + provider: &mockProvider{st: map[string]string{ + "Workers.0.Host": "h0", + "Workers.0.Port": "1000", + "Workers.1.Host": "h1", + "Workers.1.Port": "1001", + }}, + dataAssertion: deepEqual(&slicePtrStructConfig{ + Workers: []*dbConfig{ + {Host: "h0", Port: 1000}, + {Host: "h1", Port: 1001}, + }, + }), + errAssertion: noError, + }, + { + caseName: "subkeys-slice-filters-non-numeric", + input: &sliceStructConfig{}, + provider: &mockProvider{st: map[string]string{ + "Workers.0.Host": "h0", + "Workers.0.Port": "80", + "Workers.abc.Host": "bad", + "Workers.1x.Host": "bad", + }}, + dataAssertion: deepEqual(&sliceStructConfig{ + Workers: []dbConfig{{Host: "h0", Port: 80}}, + }), + errAssertion: noError, + }, + { + caseName: "subkeys-slice-sorts-by-index", + input: &sliceStructConfig{}, + provider: &mockProvider{st: map[string]string{ + "Workers.2.Host": "h2", + "Workers.2.Port": "2", + "Workers.0.Host": "h0", + "Workers.0.Port": "0", + }}, + dataAssertion: deepEqual(&sliceStructConfig{ + Workers: []dbConfig{ + {Host: "h0", Port: 0}, + {Host: "h2", Port: 2}, + }, + }), + errAssertion: noError, + }, + { + caseName: "subkeys-slice-no-keys", + input: &sliceStructConfig{}, + provider: &mockProvider{st: map[string]string{}}, + dataAssertion: deepEqual(&sliceStructConfig{}), + errAssertion: noError, + }, } { t.Run( tCase.caseName, @@ -747,12 +917,24 @@ type prefixedConfig struct { value any } -func (p *prefixedConfig) WalkPrefix() []string { return p.prefix } -func (p *prefixedConfig) WalkValue() any { return p.value } +func (p *prefixedConfig) WalkAncestor() *walker.Field { + var ancestor *walker.Field + + for _, seg := range p.prefix { + ancestor = &walker.Field{ + Field: reflect.StructField{Name: seg}, + Ancestor: ancestor, + } + } + + return ancestor +} + +func (p *prefixedConfig) WalkValue() any { return p.value } type outerPrefixedConfig struct { Direct string `mock:"direct"` - Nested *prefixedConfig + Nested *walker.SubKeyPrefixed } func TestPrefixedPopulate(t *testing.T) { @@ -826,49 +1008,25 @@ func TestPrefixedPopulate(t *testing.T) { } } -func TestNestedPrefixedPopulate(t *testing.T) { - for _, tc := range []struct { - name string - have *outerPrefixedConfig - provider provider.Provider - wantOuter string - wantInner *basicStruct1 - }{ - { - name: "nested prefixed field", - have: &outerPrefixedConfig{ - Nested: &prefixedConfig{ - prefix: []string{"ns"}, - value: &basicStruct1{}, - }, - }, - provider: &mockProvider{st: map[string]string{"direct": "top", "Nested.ns.Fiz": "deep"}}, - wantOuter: "top", - wantInner: &basicStruct1{Fiz: "deep"}, - }, - { - name: "nested prefixed with multi-segment prefix", - have: &outerPrefixedConfig{ - Nested: &prefixedConfig{ - prefix: []string{"a", "b"}, - value: &basicStruct1{}, - }, - }, - provider: &mockProvider{st: map[string]string{"Nested.a.b.Fiz": "val"}}, - wantOuter: "", - wantInner: &basicStruct1{Fiz: "val"}, - }, - } { - t.Run(tc.name, func(t *testing.T) { - c := NewConfiguratorWithOptions(WithProviders(tc.provider)) +type dbConfig struct { + Host string + Port int32 +} - err := c.Populate(context.Background(), tc.have) +type mapStructConfig struct { + Databases map[string]dbConfig +} - require.NoError(t, err) - assert.Equal(t, tc.wantOuter, tc.have.Direct) - assert.Equal(t, tc.wantInner, tc.have.Nested.value) - }) - } +type mapPtrStructConfig struct { + Databases map[string]*dbConfig +} + +type sliceStructConfig struct { + Workers []dbConfig +} + +type slicePtrStructConfig struct { + Workers []*dbConfig } func ExampleNewDefaultConfigurator() { diff --git a/internal/help/writer.go b/internal/help/writer.go index 1c07215..a4e001e 100644 --- a/internal/help/writer.go +++ b/internal/help/writer.go @@ -7,6 +7,8 @@ import ( "reflect" "strings" + "github.com/upfluence/errors" + "github.com/upfluence/cfg/internal/reflectutil" "github.com/upfluence/cfg/internal/setter" "github.com/upfluence/cfg/internal/walker" @@ -64,73 +66,91 @@ type Writer struct { func (w *Writer) writeConfig(out io.Writer, in interface{}) (int, error) { var n int - return n, walker.Walk( - in, - func(f *walker.Field) error { - s := w.Factory.Build(f.Field.Type) + return n, errors.Wrap(walker.Walk(in, w.buildWalkFn(out, &n, true)), "walk") +} - if s == nil { - return nil - } +func (w *Writer) buildWalkFn(out io.Writer, n *int, includeDefaults bool) walker.WalkFunc { + return func(f *walker.Field) error { + s := w.Factory.Build(f.Field.Type) - fks := walker.BuildFieldKeys( - provider.WrapFullyQualifiedProvider( - provider.NewStaticProvider("", nil, nil), - ), - f, - w.IgnoreMissingTag, - ) + if s == nil { + return w.writeSubKeyField(out, n, f) + } - if len(fks) == 0 { - return nil - } + fks := walker.BuildFieldKeys( + provider.WrapFullyQualifiedProvider( + provider.NewStaticProvider("", nil, nil), + ), + f, + w.IgnoreMissingTag, + ) - if setter.IsUnmarshaler(f.Value.Type()) { - return walker.SkipStruct - } + if len(fks) == 0 { + return nil + } + + if setter.IsUnmarshaler(f.Value.Type()) { + return walker.SkipStruct + } - var b bytes.Buffer + var b bytes.Buffer - b.WriteString("\t- ") - b.WriteString(fks[0]) - b.WriteString(": ") - b.WriteString(s.String()) + b.WriteString("\t- ") + b.WriteString(fks[0]) + b.WriteString(": ") + b.WriteString(s.String()) + if includeDefaults { if h := fieldHelp(f); h != "" { b.WriteString(" ") b.WriteString(h) } + } - defaultValue := fieldDefault(f) - providedKeys, tagDefault := w.providerKeys(f) + var defaultValue string + + providedKeys, tagDefault := w.providerKeys(f) + + if includeDefaults { + defaultValue = fieldDefault(f) if tagDefault != "" { defaultValue = tagDefault } + } - if len(providedKeys) == 0 { - return nil - } - - if defaultValue != "" { - b.WriteString(" (default: ") - b.WriteString(defaultValue) - b.WriteString(")") - } + if len(providedKeys) == 0 { + return nil + } - b.WriteString(" (") - b.WriteString(strings.Join(providedKeys, ", ")) + if defaultValue != "" { + b.WriteString(" (default: ") + b.WriteString(defaultValue) b.WriteString(")") + } - b.WriteRune('\n') + b.WriteString(" (") + b.WriteString(strings.Join(providedKeys, ", ")) + b.WriteString(")") - nn, err := b.WriteTo(out) + b.WriteRune('\n') - n += int(nn) + nn, err := b.WriteTo(out) - return err - }, - ) + *n += int(nn) + + return errors.Wrap(err, "write") + } +} + +func (w *Writer) writeSubKeyField(out io.Writer, n *int, f *walker.Field) error { + prefixed := walker.BuildSubKeyField(f) + + if prefixed == nil { + return nil + } + + return errors.Wrap(walker.Walk(prefixed, w.buildWalkFn(out, n, false)), "walk") } func fieldDefault(f *walker.Field) string { diff --git a/internal/help/writer_test.go b/internal/help/writer_test.go index c2f95a9..e6a4820 100644 --- a/internal/help/writer_test.go +++ b/internal/help/writer_test.go @@ -56,6 +56,18 @@ type helperEmptyFallsBackConfig struct { Dynamic helpString `env:"-" flag:"dyn" help:"from tag"` } +type mapStructConfig struct { + Databases map[string]dbConfig `env:"DATABASES" flag:"databases"` +} + +type sliceStructConfig struct { + Workers []dbConfig `env:"WORKERS" flag:"workers"` +} + +type mapPtrStructConfig struct { + Databases map[string]*dbConfig `env:"DATABASES" flag:"databases"` +} + func TestPrintDefaults(t *testing.T) { for _, tt := range []struct { name string @@ -119,6 +131,27 @@ func TestPrintDefaults(t *testing.T) { out: "Arguments:\n" + "\t- Dynamic: string from tag (flag: --dyn)\n", }, + { + name: "map of structs shows inner fields with placeholder", + in: &mapStructConfig{}, + out: "Arguments:\n" + + "\t- Databases..Host: string (env: DATABASES__HOST, flag: --databases..host)\n" + + "\t- Databases..Port: integer (env: DATABASES__PORT, flag: --databases..port)\n", + }, + { + name: "slice of structs shows inner fields with placeholder", + in: &sliceStructConfig{}, + out: "Arguments:\n" + + "\t- Workers..Host: string (env: WORKERS__HOST, flag: --workers..host)\n" + + "\t- Workers..Port: integer (env: WORKERS__PORT, flag: --workers..port)\n", + }, + { + name: "map of ptr structs shows inner fields with placeholder", + in: &mapPtrStructConfig{}, + out: "Arguments:\n" + + "\t- Databases..Host: string (env: DATABASES__HOST, flag: --databases..host)\n" + + "\t- Databases..Port: integer (env: DATABASES__PORT, flag: --databases..port)\n", + }, } { t.Run(tt.name, func(t *testing.T) { var b bytes.Buffer diff --git a/internal/reflectutil/value.go b/internal/reflectutil/value.go index 1cde4f0..e107bf1 100644 --- a/internal/reflectutil/value.go +++ b/internal/reflectutil/value.go @@ -50,3 +50,35 @@ func IndirectedType(t reflect.Type) reflect.Type { return t } + +// SubKeyMapElem returns the struct type of the map's element if t is +// map[string]S where S (or *S) is a struct. Otherwise it returns nil. +func SubKeyMapElem(t reflect.Type) reflect.Type { + t = IndirectedType(t) + + if t.Kind() != reflect.Map || t.Key().Kind() != reflect.String { + return nil + } + + if et := IndirectedType(t.Elem()); et.Kind() == reflect.Struct { + return et + } + + return nil +} + +// SubKeySliceElem returns the struct type of the slice's element if t +// is []S where S (or *S) is a struct. Otherwise it returns nil. +func SubKeySliceElem(t reflect.Type) reflect.Type { + t = IndirectedType(t) + + if t.Kind() != reflect.Slice { + return nil + } + + if et := IndirectedType(t.Elem()); et.Kind() == reflect.Struct { + return et + } + + return nil +} diff --git a/internal/synopsis/writer.go b/internal/synopsis/writer.go index e91d0b4..08f1bfd 100644 --- a/internal/synopsis/writer.go +++ b/internal/synopsis/writer.go @@ -4,6 +4,8 @@ import ( "bytes" "io" + "github.com/upfluence/errors" + "github.com/upfluence/cfg/internal/setter" "github.com/upfluence/cfg/internal/walker" "github.com/upfluence/cfg/provider" @@ -24,50 +26,69 @@ type Writer struct { func (w *Writer) Write(out io.Writer, in interface{}) (int, error) { var b bytes.Buffer - if err := walker.Walk( - in, - func(f *walker.Field) error { - if s := w.Factory.Build(f.Field.Type); s == nil { - return nil - } + writeFn := w.buildWriteFn(&b) + + if err := walker.Walk(in, writeFn); err != nil { + return 0, errors.Wrap(err, "walk") + } - fks := walker.BuildFieldKeys( - provider.WrapFullyQualifiedProvider(w.Provider), - f, - w.IgnoreMissingTag, - ) + n, err := out.Write(b.Bytes()) - if len(fks) == 0 { - return nil - } + return n, errors.Wrap(err, "write") +} - if setter.IsUnmarshaler(f.Value.Type()) { - return walker.SkipStruct - } +func (w *Writer) buildWriteFn(b *bytes.Buffer) walker.WalkFunc { + return func(f *walker.Field) error { + if s := w.Factory.Build(f.Field.Type); s == nil { + return w.writeSubKeyField(b, f) + } - b.WriteRune('[') + fks := walker.BuildFieldKeys( + provider.WrapFullyQualifiedProvider(w.Provider), + f, + w.IgnoreMissingTag, + ) - kf, hasFormatter := w.Provider.(provider.KeyFormatter) + if len(fks) == 0 { + return nil + } - for i, fk := range fks { - if hasFormatter { - fk = kf.FormatKey(fk) - } + if setter.IsUnmarshaler(f.Value.Type()) { + return walker.SkipStruct + } - b.WriteString(fk) + w.writeKeys(b, fks) - if i < len(fks)-1 { - b.WriteString(", ") - } - } + return nil + } +} - b.WriteString("] ") +func (w *Writer) writeKeys(b *bytes.Buffer, fks []string) { + b.WriteRune('[') - return nil - }, - ); err != nil { - return 0, err + kf, hasFormatter := w.Provider.(provider.KeyFormatter) + + for i, fk := range fks { + if hasFormatter { + fk = kf.FormatKey(fk) + } + + b.WriteString(fk) + + if i < len(fks)-1 { + b.WriteString(", ") + } + } + + b.WriteString("] ") +} + +func (w *Writer) writeSubKeyField(b *bytes.Buffer, f *walker.Field) error { + prefixed := walker.BuildSubKeyField(f) + + if prefixed == nil { + return nil } - return out.Write(b.Bytes()) + return errors.Wrap(walker.Walk(prefixed, w.buildWriteFn(b)), "walk") } diff --git a/internal/walker/prefixed.go b/internal/walker/prefixed.go new file mode 100644 index 0000000..aaaa04f --- /dev/null +++ b/internal/walker/prefixed.go @@ -0,0 +1,61 @@ +package walker + +import ( + "reflect" + + "github.com/upfluence/cfg/internal/reflectutil" +) + +// Prefixed is an optional interface that a value passed to Walk can +// implement to inject a dynamic ancestor chain. When Walk receives a +// Prefixed value it uses the ancestor returned by WalkAncestor and +// walks the inner value returned by WalkValue. +type Prefixed interface { + WalkAncestor() *Field + WalkValue() any +} + +// SubKeyPrefixed is a Prefixed implementation used by the configurator +// and help/synopsis writers to handle dynamic map[string]Struct and +// []Struct fields. It preserves the real ancestor Field (with struct +// tags) and appends one synthetic segment for the sub-key. +type SubKeyPrefixed struct { + Ancestor *Field + SubKey string + Value any +} + +func (p *SubKeyPrefixed) WalkAncestor() *Field { + return &Field{ + Field: reflect.StructField{Name: p.SubKey}, + Ancestor: p.Ancestor, + } +} + +func (p *SubKeyPrefixed) WalkValue() any { return p.Value } + +// BuildSubKeyField returns a SubKeyPrefixed for a map[string]Struct or +// []Struct field, using "" or "" as the placeholder sub-key. +// It returns nil if the field type is neither. +func BuildSubKeyField(f *Field) *SubKeyPrefixed { + var ( + placeholder string + structType reflect.Type + ) + + if st := reflectutil.SubKeyMapElem(f.Field.Type); st != nil { + placeholder = "" + structType = st + } else if st := reflectutil.SubKeySliceElem(f.Field.Type); st != nil { + placeholder = "" + structType = st + } else { + return nil + } + + return &SubKeyPrefixed{ + Ancestor: f, + SubKey: placeholder, + Value: reflect.New(structType).Interface(), + } +} diff --git a/internal/walker/walker.go b/internal/walker/walker.go index b583088..3111175 100644 --- a/internal/walker/walker.go +++ b/internal/walker/walker.go @@ -21,15 +21,6 @@ type Field struct { type WalkFunc func(*Field) error -// Prefixed is an optional interface that a value passed to Walk can -// implement to inject dynamic key prefix segments. When Walk receives a -// Prefixed value it builds a synthetic ancestor chain from the prefix -// segments and walks the inner value returned by WalkValue. -type Prefixed interface { - WalkPrefix() []string - WalkValue() any -} - func Walk(in any, fn WalkFunc) error { return walkValue(in, fn, nil) } @@ -43,14 +34,25 @@ func walkValue(in any, fn WalkFunc, ancestor *Field) error { } func walkPrefixed(p Prefixed, fn WalkFunc, ancestor *Field) error { - for _, seg := range p.WalkPrefix() { - ancestor = &Field{ - Field: reflect.StructField{Name: seg}, - Ancestor: ancestor, - } + extra := p.WalkAncestor() + + if extra == nil { + return walkValue(p.WalkValue(), fn, ancestor) } - return walkValue(p.WalkValue(), fn, ancestor) + // Clone the extra chain and graft it onto the incoming ancestor + // so that the original chain is not mutated. + clone := &Field{Field: extra.Field} + tip := clone + + for cur := extra.Ancestor; cur != nil; cur = cur.Ancestor { + tip.Ancestor = &Field{Field: cur.Field} + tip = tip.Ancestor + } + + tip.Ancestor = ancestor + + return walkValue(p.WalkValue(), fn, clone) } func walkStruct(in any, fn WalkFunc, ancestor *Field) error { diff --git a/internal/walker/walker_test.go b/internal/walker/walker_test.go index 26e0ba4..45df914 100644 --- a/internal/walker/walker_test.go +++ b/internal/walker/walker_test.go @@ -1,6 +1,7 @@ package walker import ( + "reflect" "strings" "testing" @@ -38,8 +39,20 @@ type prefixed struct { value any } -func (p *prefixed) WalkPrefix() []string { return p.prefix } -func (p *prefixed) WalkValue() any { return p.value } +func (p *prefixed) WalkAncestor() *Field { + var ancestor *Field + + for _, seg := range p.prefix { + ancestor = &Field{ + Field: reflect.StructField{Name: seg}, + Ancestor: ancestor, + } + } + + return ancestor +} + +func (p *prefixed) WalkValue() any { return p.value } type outerWithPrefixed struct { Nested *prefixed diff --git a/provider/default/provider.go b/provider/default/provider.go index c3f0414..5ebe5fc 100644 --- a/provider/default/provider.go +++ b/provider/default/provider.go @@ -12,6 +12,10 @@ func (Provider) JoinFieldKeys(_, key string) string { return key } +func (Provider) SubKeys(context.Context, string) ([]string, error) { + return nil, nil +} + func (Provider) Provide(_ context.Context, k string) (string, bool, error) { if k == "" { return "", false, nil diff --git a/provider/env/provider.go b/provider/env/provider.go index 7629f17..74fd76a 100644 --- a/provider/env/provider.go +++ b/provider/env/provider.go @@ -45,3 +45,39 @@ func (p *Provider) Provide(_ context.Context, v string) (string, bool, error) { return res, ok, nil } + +func (p *Provider) SubKeys(_ context.Context, prefix string) ([]string, error) { + fullPrefix := p.buildPrefix() + prefix + "_" + + seen := make(map[string]struct{}) + + for _, entry := range os.Environ() { + if !strings.HasPrefix(entry, fullPrefix) { + continue + } + + rest := entry[len(fullPrefix):] + + if idx := strings.IndexByte(rest, '='); idx >= 0 { + rest = rest[:idx] + } + + if idx := strings.IndexByte(rest, '_'); idx >= 0 { + rest = rest[:idx] + } + + if rest == "" { + continue + } + + seen[rest] = struct{}{} + } + + keys := make([]string, 0, len(seen)) + + for k := range seen { + keys = append(keys, k) + } + + return keys, nil +} diff --git a/provider/env/provider_test.go b/provider/env/provider_test.go index 9016044..503131b 100644 --- a/provider/env/provider_test.go +++ b/provider/env/provider_test.go @@ -3,7 +3,11 @@ package env import ( "context" "os" + "sort" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestProvider_Provide(t *testing.T) { @@ -78,3 +82,71 @@ func TestProvider_Provide(t *testing.T) { }) } } + +func TestProvider_SubKeys(t *testing.T) { + for _, tc := range []struct { + name string + haveEnv map[string]string + haveKey string + prefix string + want []string + }{ + { + name: "no matching vars", + haveKey: "DB", + want: []string{}, + }, + { + name: "single sub-key", + haveEnv: map[string]string{"DB_PRIMARY_HOST": "localhost"}, + haveKey: "DB", + want: []string{"PRIMARY"}, + }, + { + name: "multiple sub-keys", + haveEnv: map[string]string{ + "DB_PRIMARY_HOST": "h1", + "DB_PRIMARY_PORT": "5432", + "DB_REPLICA_HOST": "h2", + "DB_SECONDARY_HOST": "h3", + }, + haveKey: "DB", + want: []string{"PRIMARY", "REPLICA", "SECONDARY"}, + }, + { + name: "with global prefix", + haveEnv: map[string]string{ + "APP_CACHE_REDIS_HOST": "r1", + "APP_CACHE_MEMCACHE_HOST": "m1", + "CACHE_OTHER_HOST": "x", + }, + haveKey: "CACHE", + prefix: "app", + want: []string{"MEMCACHE", "REDIS"}, + }, + { + name: "ignores vars with empty segment after prefix", + haveEnv: map[string]string{ + "DB_HOST": "localhost", + }, + haveKey: "DB_HOST", + want: []string{}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + for k, v := range tc.haveEnv { + t.Setenv(k, v) + } + + p := &Provider{prefix: tc.prefix} + + got, err := p.SubKeys(context.Background(), tc.haveKey) + + require.NoError(t, err) + + sort.Strings(got) + sort.Strings(tc.want) + assert.Equal(t, tc.want, got) + }) + } +} diff --git a/provider/flags/provider.go b/provider/flags/provider.go index edf322c..9eb077e 100644 --- a/provider/flags/provider.go +++ b/provider/flags/provider.go @@ -78,17 +78,21 @@ func NewDefaultProvider() *Provider { } func NewProvider(args []string) *Provider { + fs := parseFlags(args) + return &Provider{ + flags: fs, sp: provider.NewStaticProvider( StructTag, - parseFlags(args), + fs, strings.ToLower, ), } } type Provider struct { - sp provider.Provider + flags map[string]string + sp provider.Provider } func kebabCase(s string) string { @@ -121,6 +125,38 @@ func (*Provider) JoinFieldKeys(prefix, key string) string { return prefix + "." + key } +func (p *Provider) SubKeys(_ context.Context, prefix string) ([]string, error) { + fullPrefix := prefix + "." + + seen := make(map[string]struct{}) + + for k := range p.flags { + if !strings.HasPrefix(k, fullPrefix) { + continue + } + + rest := k[len(fullPrefix):] + + if idx := strings.IndexByte(rest, '.'); idx >= 0 { + rest = rest[:idx] + } + + if rest == "" { + continue + } + + seen[rest] = struct{}{} + } + + keys := make([]string, 0, len(seen)) + + for k := range seen { + keys = append(keys, k) + } + + return keys, nil +} + func (*Provider) FormatKey(n string) string { n = strings.ToLower(n) diff --git a/provider/flags/provider_test.go b/provider/flags/provider_test.go index 695c753..881dcb6 100644 --- a/provider/flags/provider_test.go +++ b/provider/flags/provider_test.go @@ -1,9 +1,12 @@ package flags import ( + "context" + "sort" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestParseFlags(t *testing.T) { @@ -124,3 +127,60 @@ func TestKebabCase(t *testing.T) { }) } } + +func TestProvider_SubKeys(t *testing.T) { + for _, tc := range []struct { + name string + haveArgs []string + haveKey string + want []string + }{ + { + name: "no matching flags", + haveKey: "workers", + want: []string{}, + }, + { + name: "single sub-key", + haveArgs: []string{"--workers.0.host", "h0"}, + haveKey: "workers", + want: []string{"0"}, + }, + { + name: "multiple sub-keys", + haveArgs: []string{ + "--workers.0.host", "h0", + "--workers.0.port", "80", + "--workers.1.host", "h1", + }, + haveKey: "workers", + want: []string{"0", "1"}, + }, + { + name: "preserves original key casing", + haveArgs: []string{ + "--Workers.PRIMARY.Host", "h1", + }, + haveKey: "Workers", + want: []string{"PRIMARY"}, + }, + { + name: "ignores exact prefix without sub-segment", + haveArgs: []string{"--workers=true"}, + haveKey: "workers", + want: []string{}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + p := NewProvider(tc.haveArgs) + + got, err := p.SubKeys(context.Background(), tc.haveKey) + + require.NoError(t, err) + + sort.Strings(got) + sort.Strings(tc.want) + assert.Equal(t, tc.want, got) + }) + } +} diff --git a/provider/json/provider.go b/provider/json/provider.go index dfc6b29..fa60bd8 100644 --- a/provider/json/provider.go +++ b/provider/json/provider.go @@ -74,6 +74,44 @@ func (p *Provider) Provide(_ context.Context, v string) (string, bool, error) { return stringifyValue(res), true, nil } +func (p *Provider) SubKeys(_ context.Context, prefix string) ([]string, error) { + cur := p.navigateTo(prefix) + + if cur == nil { + return nil, nil + } + + keys := make([]string, 0, len(cur)) + + for k := range cur { + keys = append(keys, k) + } + + return keys, nil +} + +func (p *Provider) navigateTo(prefix string) map[string]any { + cur := p.store + + for k := range strings.SplitSeq(prefix, ".") { + t := cur[k] + + if t == nil { + return nil + } + + next, ok := t.(map[string]any) + + if !ok { + return nil + } + + cur = next + } + + return cur +} + func stringifyValue(v interface{}) string { vv := reflect.ValueOf(v) diff --git a/provider/json/provider_test.go b/provider/json/provider_test.go index 48dce87..878cf0e 100644 --- a/provider/json/provider_test.go +++ b/provider/json/provider_test.go @@ -2,87 +2,133 @@ package json import ( "context" + "sort" + "strings" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestProvider_Provide(t *testing.T) { - tests := []struct { - name string - p *Provider - in string - wantValue string - wantValueFn func(string) bool - wantExist bool - wantErr bool + for _, tc := range []struct { + name string + haveJSON string + haveKey string + wantValue string + assertVal func(*testing.T, string) + wantExist bool + wantErr error }{ { - name: "empty store", - p: &Provider{}, + name: "empty store", + haveJSON: `{}`, + haveKey: "foo", }, { name: "top level value", - p: &Provider{map[string]interface{}{"foo": "bar"}}, - in: "foo", + haveJSON: `{"foo":"bar"}`, + haveKey: "foo", wantValue: "bar", wantExist: true, }, { name: "slice value", - p: &Provider{map[string]interface{}{"foo": []int64{1, 2, 3}}}, - in: "foo", + haveJSON: `{"foo":[1,2,3]}`, + haveKey: "foo", wantValue: "1,2,3", wantExist: true, }, { - name: "map value", - p: &Provider{ - map[string]interface{}{ - "foo": map[string]int64{"foo": 1, "bar": 2}, - }, - }, - in: "foo", - wantValueFn: func(got string) bool { - switch got { - case "foo=1,bar=2", "bar=2,foo=1": - return true - } - - return false + name: "map value", + haveJSON: `{"foo":{"foo":1,"bar":2}}`, + haveKey: "foo", + assertVal: func(t *testing.T, got string) { + t.Helper() + assert.Contains(t, []string{"foo=1,bar=2", "bar=2,foo=1"}, got) }, wantExist: true, }, { name: "second level value", - p: &Provider{map[string]interface{}{"foo": map[string]interface{}{"fiz": "bar"}}}, - in: "foo.fiz", + haveJSON: `{"foo":{"fiz":"bar"}}`, + haveKey: "foo.fiz", wantValue: "bar", wantExist: true, }, { - name: "wrong format", - p: &Provider{map[string]interface{}{"foo": map[string]interface{}{"fiz": "bar"}}}, - in: "foo.fiz.buz", - wantErr: true, + name: "wrong format", + haveJSON: `{"foo":{"fiz":"bar"}}`, + haveKey: "foo.fiz.buz", + wantErr: ErrJSONMalformated, }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, got1, err := tt.p.Provide(context.Background(), tt.in) - if (err != nil) != tt.wantErr { - t.Errorf("Provider.Provide() error = %v, wantErr %v", err, tt.wantErr) - return - } + } { + t.Run(tc.name, func(t *testing.T) { + p := NewProviderFromReader(strings.NewReader(tc.haveJSON)) - if tt.wantValueFn != nil { - if !tt.wantValueFn(got) { - t.Errorf("Provider.Provide() got = %v", got) - } - } else if got != tt.wantValue { - t.Errorf("Provider.Provide() got = %v, want %v", got, tt.wantValue) - } - if got1 != tt.wantExist { - t.Errorf("Provider.Provide() got1 = %v, want %v", got1, tt.wantExist) + got, gotExist, err := p.Provide(context.Background(), tc.haveKey) + + require.ErrorIs(t, err, tc.wantErr) + + if tc.assertVal != nil { + tc.assertVal(t, got) + } else { + assert.Equal(t, tc.wantValue, got) } + + assert.Equal(t, tc.wantExist, gotExist) + }) + } +} + +func TestProvider_SubKeys(t *testing.T) { + for _, tc := range []struct { + name string + haveJSON string + haveKey string + want []string + }{ + { + name: "empty store", + haveJSON: `{}`, + haveKey: "workers", + want: nil, + }, + { + name: "top level map keys", + haveJSON: `{"workers":{"0":{"host":"h0"},"1":{"host":"h1"}}}`, + haveKey: "workers", + want: []string{"0", "1"}, + }, + { + name: "nested prefix", + haveJSON: `{"db":{"shards":{"primary":{"host":"h1"},"replica":{"host":"h2"}}}}`, + haveKey: "db.shards", + want: []string{"primary", "replica"}, + }, + { + name: "prefix points to non-map value", + haveJSON: `{"workers":"not-a-map"}`, + haveKey: "workers", + want: nil, + }, + { + name: "prefix not found", + haveJSON: `{"foo":{"bar":"baz"}}`, + haveKey: "missing", + want: nil, + }, + } { + t.Run(tc.name, func(t *testing.T) { + p := NewProviderFromReader(strings.NewReader(tc.haveJSON)) + + got, err := p.(*Provider).SubKeys(context.Background(), tc.haveKey) + + require.NoError(t, err) + + sort.Strings(got) + sort.Strings(tc.want) + assert.Equal(t, tc.want, got) }) } } diff --git a/provider/provider.go b/provider/provider.go index a0212fd..476cb37 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -57,11 +57,19 @@ func (sp *staticProvider) Provide(_ context.Context, k string) (string, bool, er // The standard provider joins them with "."; other providers may need a // different strategy (e.g. the default-value provider filters out empty // parts). +// +// SubKeys enumerates dynamic sub-keys under a given prefix. This is +// used by the configurator to populate map[string]Struct fields: for +// each discovered sub-key a new struct value is allocated and +// recursively walked with the sub-key injected as an ancestor prefix. +// Providers that do not support key enumeration should return a nil +// slice. type FullyQualifiedProvider interface { Provider DefaultFieldValue(fieldName string) string JoinFieldKeys(prefix, key string) string + SubKeys(ctx context.Context, prefix string) ([]string, error) } // WrapFullyQualifiedProvider returns p as a FullyQualifiedProvider. If @@ -88,6 +96,10 @@ func (d *defaultFQProvider) JoinFieldKeys(prefix, key string) string { return prefix + "." + key } +func (*defaultFQProvider) SubKeys(context.Context, string) ([]string, error) { + return nil, nil +} + // KeyFormatter is an optional interface that providers can implement to // control how keys are displayed in help and synopsis output. type KeyFormatter interface { diff --git a/x/cli/output/command.go b/x/cli/output/command.go index 00d9e36..5d7bf2c 100644 --- a/x/cli/output/command.go +++ b/x/cli/output/command.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "reflect" "slices" "sort" "strings" @@ -11,6 +12,7 @@ import ( "github.com/upfluence/errors" "github.com/upfluence/cfg" + "github.com/upfluence/cfg/internal/walker" "github.com/upfluence/cfg/x/cli" "github.com/upfluence/cfg/x/cli/output/printer" "github.com/upfluence/cfg/x/cli/output/printer/json" @@ -86,21 +88,21 @@ func WrapDefaultCommand[T any](cmd Command[T], additionalPrinters ...printer.Pri ) } -type prefixedConfig struct { - prefix string - value any +var outputAncestor = &walker.Field{ + Field: reflect.StructField{Name: "output"}, } -func (p *prefixedConfig) WalkPrefix() []string { return []string{"output", p.prefix} } -func (p *prefixedConfig) WalkValue() any { return p.value } - type prefixedConfigurator struct { inner cfg.Configurator prefix string } func (pc *prefixedConfigurator) Populate(ctx context.Context, in any) error { - return pc.inner.Populate(ctx, &prefixedConfig{prefix: pc.prefix, value: in}) //nolint:wrapcheck + return pc.inner.Populate(ctx, &walker.SubKeyPrefixed{ //nolint:wrapcheck + Ancestor: outputAncestor, + SubKey: pc.prefix, + Value: in, + }) } func (pc *prefixedConfigurator) WithOptions(opts ...cfg.Option) cfg.Configurator { @@ -130,7 +132,11 @@ func (wc *wrappedCommand[T]) wrapIntrospectionOptions(opts cli.IntrospectionOpti key := p.Key() for i, c := range def.Configs { - def.Configs[i] = &prefixedConfig{prefix: key, value: c} + def.Configs[i] = &walker.SubKeyPrefixed{ + Ancestor: outputAncestor, + SubKey: key, + Value: c, + } } opts.Definitions = append(opts.Definitions, def)