diff --git a/encode.go b/encode.go index 302d993..d23ef5e 100644 --- a/encode.go +++ b/encode.go @@ -144,24 +144,178 @@ func encodeValue(v reflect.Value, z bool, o bool) interface{} { } } +type encoderField struct { + index []int + name string + omitempty bool +} + func encodeStruct(v reflect.Value, z bool, o bool) interface{} { - t := v.Type() + fields := collectFields(v.Type()) n := node{} - for i := 0; i < t.NumField(); i++ { - f := t.Field(i) - k, oe := fieldInfo(f) - - if k == "-" { + for _, f := range fields { + fv := fieldByIndex(v, f.index) + if !fv.IsValid() { continue - } else if fv := v.Field(i); (o || oe) && isEmptyValue(fv) { - delete(n, k) - } else { - n[k] = encodeValue(fv, z, o) } + if (o || f.omitempty) && isEmptyValue(fv) { + continue + } + n[f.name] = encodeValue(fv, z, o) } return n } +func hasExplicitTag(f reflect.StructField) bool { + tag := f.Tag.Get("form") + if tag == "" { + tag = f.Tag.Get("json") + } + if tag == "" { + return false + } + return strings.SplitN(tag, ",", 2)[0] != "" +} + +func shouldPromote(f reflect.StructField) bool { + return f.Anonymous && !hasExplicitTag(f) +} + +func collectFields(t reflect.Type) []encoderField { + type queueItem struct { + typ reflect.Type + index []int + depth int + } + type fieldCandidate struct { + field encoderField + depth int + tagged bool + } + + current := []queueItem{{typ: t}} + visited := map[reflect.Type]bool{} + candidatesByName := map[string][]fieldCandidate{} + nameOrder := []string{} + + for len(current) > 0 { + var next []queueItem + for _, item := range current { + if visited[item.typ] { + continue + } + visited[item.typ] = true + + for i := 0; i < item.typ.NumField(); i++ { + f := item.typ.Field(i) + k, oe := fieldInfo(f) + if k == omittedKey { + continue + } + + idx := make([]int, len(item.index)+1) + copy(idx, item.index) + idx[len(item.index)] = i + + if shouldPromote(f) { + ft := f.Type + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + if ft.Kind() == reflect.Struct && !isLeafStruct(ft) { + next = append(next, queueItem{typ: ft, index: idx, depth: item.depth + 1}) + continue + } + } + + tagged := hasExplicitTag(f) + fc := fieldCandidate{ + field: encoderField{ + index: idx, + name: k, + omitempty: oe, + }, + depth: item.depth, + tagged: tagged, + } + + if _, exists := candidatesByName[k]; !exists { + nameOrder = append(nameOrder, k) + } + candidatesByName[k] = append(candidatesByName[k], fc) + } + } + + current = next + } + + // Resolve conflicts + var result []encoderField + for _, name := range nameOrder { + cands := candidatesByName[name] + if len(cands) == 1 { + result = append(result, cands[0].field) + continue + } + + // Multiple candidates: keep only those at minimum depth + minDepth := cands[0].depth + for _, c := range cands[1:] { + if c.depth < minDepth { + minDepth = c.depth + } + } + var filtered []fieldCandidate + for _, c := range cands { + if c.depth == minDepth { + filtered = append(filtered, c) + } + } + + if len(filtered) == 1 { + result = append(result, filtered[0].field) + continue + } + + // Still multiple at same depth: keep only tagged ones + var tagged []fieldCandidate + for _, c := range filtered { + if c.tagged { + tagged = append(tagged, c) + } + } + + if len(tagged) == 1 { + result = append(result, tagged[0].field) + continue + } + + // Still multiple or none tagged: ambiguous, omit entirely + } + + return result +} + +func fieldByIndex(v reflect.Value, index []int) reflect.Value { + for _, i := range index { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return reflect.Value{} + } + v = v.Elem() + } + v = v.Field(i) + } + return v +} + +func isLeafStruct(ft reflect.Type) bool { + if ft.ConvertibleTo(timeType) || ft.ConvertibleTo(urlType) { + return true + } + return ft.Implements(textMarshalerType) || reflect.PtrTo(ft).Implements(textMarshalerType) +} + func encodeMap(v reflect.Value, z bool, o bool) interface{} { n := node{} for _, i := range v.MapKeys() { @@ -358,11 +512,12 @@ func findField(v reflect.Value, n string, ignoreCase bool) (reflect.Value, bool) } var ( - stringType = reflect.TypeOf(string("")) - stringMapType = reflect.TypeOf(map[string]interface{}{}) - timeType = reflect.TypeOf(time.Time{}) - timePtrType = reflect.TypeOf(&time.Time{}) - urlType = reflect.TypeOf(url.URL{}) + stringType = reflect.TypeOf(string("")) + stringMapType = reflect.TypeOf(map[string]interface{}{}) + timeType = reflect.TypeOf(time.Time{}) + timePtrType = reflect.TypeOf(&time.Time{}) + urlType = reflect.TypeOf(url.URL{}) + textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() ) func skipTextMarshalling(t reflect.Type) bool { diff --git a/encode_test.go b/encode_test.go index 10326d5..3246f33 100644 --- a/encode_test.go +++ b/encode_test.go @@ -136,3 +136,35 @@ func TestEncode_OmitEmpty(t *testing.T) { } } } + +func TestEncode_ConflictResolution(t *testing.T) { + for _, c := range []struct { + name string + b interface{} + s string + }{ + { + "depth shadow: parent field wins", + &DepthShadow{X: "parent", DepthInner: DepthInner{X: "child"}}, + "X=parent", + }, + { + "ambiguous: same-depth fields omitted", + &Ambiguous{AmbigA: AmbigA{X: "a"}, AmbigB: AmbigB{X: "b"}}, + "", + }, + { + "tagged wins: tagged field beats untagged at same depth", + &TaggedWins{TaggedInner: TaggedInner{X: "tagged"}, UntaggedInner: UntaggedInner{X: "untagged"}}, + "X=tagged", + }, + } { + t.Run(c.name, func(t *testing.T) { + if s, err := EncodeToString(c.b); err != nil { + t.Errorf("EncodeToString(%#v): %s", c.b, err) + } else if s != c.s { + t.Errorf("EncodeToString(%#v)\n want %q\n have %q", c.b, c.s, s) + } + }) + } +} diff --git a/form_test.go b/form_test.go index 36408ce..44d8754 100644 --- a/form_test.go +++ b/form_test.go @@ -32,8 +32,8 @@ type Struct struct { type SXs map[string]interface{} type E struct { - Bytes1 []byte // For testing explicit (qualified by embedder) name, e.g. "E.Bytes1". - Bytes2 []byte // For testing implicit (unqualified) name, e.g. just "Bytes2" + Bytes1 []byte // Promoted to parent during struct encoding, e.g. just "Bytes1". + Bytes2 []byte // Promoted to parent during struct encoding, e.g. just "Bytes2". } type Z time.Time // Defined as such to test conversions. @@ -126,6 +126,7 @@ func testCases(dir direction) (cs []testCase) { var T time.Time var U url.URL const canonical = `A.0=x&A.1=y&A.2=z&B=true&C=42%2B6.6i&E.Bytes1=%00%01%02&E.Bytes2=%03%04%05&F=6.6&M.Bar=8&M.Foo=7&M.Qux=9&P%5C.D%5C%5CQ%5C.B.A=P%2FD&P%5C.D%5C%5CQ%5C.B.B=Q-B&R=8734&S=Hello%2C+there.&T=2013-10-01T07%3A05%3A34.000000088Z&U=http%3A%2F%2Fexample.org%2Ffoo%23bar&Zs.0.Q=11_22&Zs.0.Qp=33_44&Zs.0.Z=2006-12-01&life=42` + const structCanonical = `A.0=x&A.1=y&A.2=z&B=true&Bytes1=%00%01%02&Bytes2=%03%04%05&C=42%2B6.6i&F=6.6&M.Bar=8&M.Foo=7&M.Qux=9&P%5C.D%5C%5CQ%5C.B.A=P%2FD&P%5C.D%5C%5CQ%5C.B.B=Q-B&R=8734&S=Hello%2C+there.&T=2013-10-01T07%3A05%3A34.000000088Z&U=http%3A%2F%2Fexample.org%2Ffoo%23bar&Zs.0.Q=11_22&Zs.0.Qp=33_44&Zs.0.Z=2006-12-01&life=42` const variation = `C=42%2B6.6i&A.0=x&M.Bar=8&F=6.6&A.1=y&R=8734&A.2=z&Zs.0.Qp=33_44&B=true&M.Foo=7&T=2013-10-01T07:05:34.000000088Z&E.Bytes1=%00%01%02&Bytes2=%03%04%05&Zs.0.Q=11_22&Zs.0.Z=2006-12-01&M.Qux=9&life=42&S=Hello,+there.&P\.D\\Q\.B.A=P/D&P\.D\\Q\.B.B=Q-B&U=http%3A%2F%2Fexample.org%2Ffoo%23bar` for _, c := range []testCase{ @@ -176,7 +177,7 @@ func testCases(dir direction) (cs []testCase) { {rndTrip, &U, "=git%3A%2F%2Fgithub.com%2Fajg%2Fform.git", u(url.URL{Scheme: "git", Host: "github.com", Path: "/ajg/form.git"})}, // Structs - {rndTrip, &Struct{Y: 786}, canonical, + {rndTrip, &Struct{Y: 786}, structCanonical, &Struct{ true, 42, @@ -340,3 +341,27 @@ func mustParseQuery(s string) url.Values { } return vs } + +// Conflict resolution test types + +type DepthInner struct{ X string } +type DepthShadow struct { + X string + DepthInner +} + +type AmbigA struct{ X string } +type AmbigB struct{ X string } +type Ambiguous struct { + AmbigA + AmbigB +} + +type TaggedInner struct { + X string `form:"X"` +} +type UntaggedInner struct{ X string } +type TaggedWins struct { + TaggedInner + UntaggedInner +}