diff --git a/README.md b/README.md index 2df9146..5719c4c 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,8 @@ result2, err := interpreter.Run(map[string]any{ Pretty errors use the passed-in input along with the error's offset to display an arrow of where within the expression the error occurs. +Offsets and lengths are rune-based, so caret placement stays correct for Unicode input. + ```go inputStr := "2 * foo" _, err := mexpr.Eval(inputStr, nil) @@ -173,6 +175,8 @@ Current limitations: - functions must have exactly one return value - zero-argument scalar functions can also be used as lazy values, e.g. `id + 1` +Numeric arguments are coerced to the target Go type using standard Go conversions. For example, calling a function that takes an `int` with `1.9` will pass `1` to the function. If you need fractional behavior, make the function accept a float type. + ### String operators - Indexing, e.g. `foo[0]` @@ -208,6 +212,8 @@ String dates & times can be compared if they follow RFC 3339 / ISO 8601 with or - `in` (has item), e.g. `1 in foo` - `contains` e.g. `foo contains 1` +Common Go slice inputs like `[]any`, `[]int`, `[]float64`, and `[]string` are supported directly without normalizing them into `[]any` first. Other slice and array types fall back to reflection. + Indexes are zero-based. Slice indexes are optional and are _inclusive_. `foo[1:2]` returns `[2, 3]` if the `foo` is `[1, 2, 3, 4]`. Indexes can be negative, e.g. `foo[-1]` selects the last item in the array. #### Array/slice filtering diff --git a/conversions.go b/conversions.go index 31810c2..744aa5c 100644 --- a/conversions.go +++ b/conversions.go @@ -140,6 +140,282 @@ func stringSlice(v string, start, end int) string { return v[from:to] } +func isByteArrayOrSlice(t reflect.Type) bool { + if t == nil { + return false + } + if t.Kind() != reflect.Array && t.Kind() != reflect.Slice { + return false + } + return t.Elem().Kind() == reflect.Uint8 +} + +func sliceLen(v any) (int, bool) { + switch s := v.(type) { + case []any: + return len(s), true + case []int: + return len(s), true + case []float64: + return len(s), true + case []string: + return len(s), true + } + + rv := reflect.ValueOf(v) + if !rv.IsValid() || isByteArrayOrSlice(rv.Type()) { + return 0, false + } + if rv.Kind() != reflect.Array && rv.Kind() != reflect.Slice { + return 0, false + } + return rv.Len(), true +} + +func isSlice(v any) bool { + _, ok := sliceLen(v) + return ok +} + +func sliceItem(v any, idx int) (any, bool) { + switch s := v.(type) { + case []any: + return s[idx], true + case []int: + return s[idx], true + case []float64: + return s[idx], true + case []string: + return s[idx], true + } + + rv := reflect.ValueOf(v) + if !rv.IsValid() || isByteArrayOrSlice(rv.Type()) { + return nil, false + } + if rv.Kind() != reflect.Array && rv.Kind() != reflect.Slice { + return nil, false + } + return rv.Index(idx).Interface(), true +} + +func sliceRange(v any, start, end int) (any, bool) { + switch s := v.(type) { + case []any: + return s[start : end+1], true + case []int: + return s[start : end+1], true + case []float64: + return s[start : end+1], true + case []string: + return s[start : end+1], true + } + + rv := reflect.ValueOf(v) + if !rv.IsValid() || isByteArrayOrSlice(rv.Type()) { + return nil, false + } + if rv.Kind() == reflect.Array { + copyValue := reflect.New(rv.Type()).Elem() + copyValue.Set(rv) + return copyValue.Slice(start, end+1).Interface(), true + } + if rv.Kind() == reflect.Slice { + return rv.Slice(start, end+1).Interface(), true + } + return nil, false +} + +func appendSliceItems(dst []any, v any) ([]any, bool) { + switch s := v.(type) { + case []any: + return append(dst, s...), true + case []int: + for _, item := range s { + dst = append(dst, item) + } + return dst, true + case []float64: + for _, item := range s { + dst = append(dst, item) + } + return dst, true + case []string: + for _, item := range s { + dst = append(dst, item) + } + return dst, true + } + + rv := reflect.ValueOf(v) + if !rv.IsValid() || isByteArrayOrSlice(rv.Type()) { + return nil, false + } + if rv.Kind() != reflect.Array && rv.Kind() != reflect.Slice { + return nil, false + } + for idx := 0; idx < rv.Len(); idx++ { + dst = append(dst, rv.Index(idx).Interface()) + } + return dst, true +} + +func concatSlices(left, right any) (any, bool) { + switch l := left.(type) { + case []any: + if r, ok := right.([]any); ok { + out := make([]any, 0, len(l)+len(r)) + out = append(out, l...) + out = append(out, r...) + return out, true + } + case []int: + if r, ok := right.([]int); ok { + out := make([]int, 0, len(l)+len(r)) + out = append(out, l...) + out = append(out, r...) + return out, true + } + case []float64: + if r, ok := right.([]float64); ok { + out := make([]float64, 0, len(l)+len(r)) + out = append(out, l...) + out = append(out, r...) + return out, true + } + case []string: + if r, ok := right.([]string); ok { + out := make([]string, 0, len(l)+len(r)) + out = append(out, l...) + out = append(out, r...) + return out, true + } + } + + leftLen, ok := sliceLen(left) + if !ok { + return nil, false + } + rightLen, ok := sliceLen(right) + if !ok { + return nil, false + } + out := make([]any, 0, leftLen+rightLen) + out, ok = appendSliceItems(out, left) + if !ok { + return nil, false + } + return appendSliceItems(out, right) +} + +func iterateSlice(v any, yield func(any) bool) bool { + switch s := v.(type) { + case []any: + for _, item := range s { + if !yield(item) { + return true + } + } + return true + case []int: + for _, item := range s { + if !yield(item) { + return true + } + } + return true + case []float64: + for _, item := range s { + if !yield(item) { + return true + } + } + return true + case []string: + for _, item := range s { + if !yield(item) { + return true + } + } + return true + } + + rv := reflect.ValueOf(v) + if !rv.IsValid() || isByteArrayOrSlice(rv.Type()) { + return false + } + if rv.Kind() != reflect.Array && rv.Kind() != reflect.Slice { + return false + } + for idx := 0; idx < rv.Len(); idx++ { + if !yield(rv.Index(idx).Interface()) { + return true + } + } + return true +} + +func recursiveDeepEqual(left, right any) bool { + l := normalize(left) + r := normalize(right) + + switch lv := l.(type) { + case float64: + rv, ok := r.(float64) + return ok && lv == rv + case string: + rv, ok := r.(string) + return ok && lv == rv + case bool: + rv, ok := r.(bool) + return ok && lv == rv + } + + if lLen, ok := sliceLen(l); ok { + rLen, ok := sliceLen(r) + if !ok || lLen != rLen { + return false + } + for idx := 0; idx < lLen; idx++ { + leftItem, _ := sliceItem(l, idx) + rightItem, _ := sliceItem(r, idx) + if !recursiveDeepEqual(leftItem, rightItem) { + return false + } + } + return true + } + + switch lv := l.(type) { + case map[string]any: + rv, ok := r.(map[string]any) + if !ok || len(lv) != len(rv) { + return false + } + for key, leftValue := range lv { + rightValue, ok := rv[key] + if !ok || !recursiveDeepEqual(leftValue, rightValue) { + return false + } + } + return true + case map[any]any: + rv, ok := r.(map[any]any) + if !ok || len(lv) != len(rv) { + return false + } + for key, leftValue := range lv { + rightValue, ok := rv[key] + if !ok || !recursiveDeepEqual(leftValue, rightValue) { + return false + } + } + return true + } + + return reflect.DeepEqual(l, r) +} + // toTime converts a string value into a time.Time if possible, otherwise // returns a zero time. func toTime(v any) time.Time { @@ -156,13 +432,6 @@ func toTime(v any) time.Time { return time.Time{} } -func isSlice(v any) bool { - if _, ok := v.([]any); ok { - return true - } - return false -} - func toBool(v any) bool { switch n := v.(type) { case bool: @@ -195,13 +464,14 @@ func toBool(v any) bool { return len(n) > 0 case []byte: return len(n) > 0 - case []any: - return len(n) > 0 case map[string]any: return len(n) > 0 case map[any]any: return len(n) > 0 } + if l, ok := sliceLen(v); ok { + return l > 0 + } return false } @@ -269,21 +539,5 @@ func normalize(v any) any { // deepEqual returns whether two values are deeply equal. func deepEqual(left, right any) bool { - l := normalize(left) - r := normalize(right) - - // Optimization for simple types to prevent allocations - switch l.(type) { - case float64: - if f, ok := r.(float64); ok { - return l == f - } - case string: - if s, ok := r.(string); ok { - return l == s - } - } - - // Otherwise, just use the built-in deep equality check. - return reflect.DeepEqual(left, right) + return recursiveDeepEqual(left, right) } diff --git a/conversions_test.go b/conversions_test.go new file mode 100644 index 0000000..419f408 --- /dev/null +++ b/conversions_test.go @@ -0,0 +1,77 @@ +package mexpr + +import ( + "reflect" + "testing" +) + +func TestToNumber(t *testing.T) { + ast := &Node{} + cases := []struct { + name string + value any + want float64 + }{ + {name: "int32", value: int32(3), want: 3}, + {name: "uint64", value: uint64(4), want: 4}, + {name: "float32", value: float32(1.5), want: 1.5}, + {name: "func int16", value: func() int16 { return 7 }, want: 7}, + {name: "func uint32", value: func() uint32 { return 8 }, want: 8}, + {name: "func float32", value: func() float32 { return 2.5 }, want: 2.5}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, err := toNumber(ast, tc.value) + if err != nil { + t.Fatal(err) + } + if got != tc.want { + t.Fatalf("expected %v but found %v", tc.want, got) + } + }) + } +} + +func TestAppendAndConcatGenericSlices(t *testing.T) { + appended, ok := appendSliceItems(nil, [3]uint{1, 2, 3}) + if !ok { + t.Fatal("expected array append to succeed") + } + if !reflect.DeepEqual([]any{uint(1), uint(2), uint(3)}, appended) { + t.Fatalf("unexpected appended items: %v", appended) + } + + concatenated, ok := concatSlices([2]uint{1, 2}, [2]uint{3, 4}) + if !ok { + t.Fatal("expected array concat to succeed") + } + if !reflect.DeepEqual([]any{uint(1), uint(2), uint(3), uint(4)}, concatenated) { + t.Fatalf("unexpected concatenated items: %v", concatenated) + } +} + +func TestToBool(t *testing.T) { + cases := []struct { + name string + value any + want bool + }{ + {name: "int16 positive", value: int16(1), want: true}, + {name: "uint32 zero", value: uint32(0), want: false}, + {name: "float32 positive", value: float32(1.25), want: true}, + {name: "bytes empty", value: []byte{}, want: false}, + {name: "bytes non-empty", value: []byte("x"), want: true}, + {name: "map any empty", value: map[any]any{}, want: false}, + {name: "map any non-empty", value: map[any]any{"k": 1}, want: true}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := toBool(tc.value) + if got != tc.want { + t.Fatalf("expected %v but found %v", tc.want, got) + } + }) + } +} \ No newline at end of file diff --git a/error.go b/error.go index 873d4f7..4690f52 100644 --- a/error.go +++ b/error.go @@ -1,15 +1,19 @@ package mexpr -import "fmt" +import ( + "fmt" + "strings" + "unicode/utf8" +) // Error represents an error at a specific location. type Error interface { Error() string - // Offset returns the character offset of the error within the experssion. + // Offset returns the rune offset of the error within the expression. Offset() uint16 - // Length returns the length in bytes after the offset where the error ends. + // Length returns the rune length after the offset where the error ends. Length() uint8 // Pretty prints out a message with a pointer to the source location of the @@ -36,14 +40,22 @@ func (e *exprErr) Length() uint8 { } func (e *exprErr) Pretty(source string) string { - msg := e.Error() + "\n" + source + "\n" + var msg strings.Builder + msg.WriteString(e.Error()) + msg.WriteByte('\n') + msg.WriteString(source) + msg.WriteByte('\n') for i := uint16(0); i < e.offset; i++ { - msg += "." + msg.WriteByte('.') } - for i := uint8(0); i < e.length; i++ { - msg += "^" + length := e.length + if length == 0 && utf8.RuneCountInString(source) > int(e.offset) { + length = 1 } - return msg + for i := uint8(0); i < length; i++ { + msg.WriteByte('^') + } + return msg.String() } // NewError creates a new error at a specific location. diff --git a/interpreter.go b/interpreter.go index cbdea15..e26259c 100644 --- a/interpreter.go +++ b/interpreter.go @@ -32,9 +32,9 @@ func mapValues[M ~map[K]V, K comparable, V any](m M) []V { // checkBounds returns an error if the index is out of bounds. func checkBounds(ast *Node, input any, idx int) Error { - if v, ok := input.([]any); ok { - if idx < 0 || idx >= len(v) { - return NewError(ast.Offset, ast.Length, "invalid index %d for slice of length %d", int(idx), len(v)) + if l, ok := sliceLen(input); ok { + if idx < 0 || idx >= l { + return NewError(ast.Offset, ast.Length, "invalid index %d for slice of length %d", int(idx), l) } } if v, ok := input.(string); ok { @@ -111,8 +111,8 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { if s, ok := value.(string); ok { return stringLength(s), nil } - if a, ok := value.([]any); ok { - return len(a), nil + if l, ok := sliceLen(value); ok { + return l, nil } case "lower": if s, ok := value.(func() string); ok { @@ -187,23 +187,27 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { if err != nil { return nil, err } - if left, ok := resultLeft.([]any); ok { + if leftLen, ok := sliceLen(resultLeft); ok { if start < 0 { - start += float64(len(left)) + start += float64(leftLen) } if end < 0 { - end += float64(len(left)) + end += float64(leftLen) } - if err := checkBounds(ast, left, int(start)); err != nil { + if err := checkBounds(ast, resultLeft, int(start)); err != nil { return nil, err } - if err := checkBounds(ast, left, int(end)); err != nil { + if err := checkBounds(ast, resultLeft, int(end)); err != nil { return nil, err } if int(start) > int(end) { return nil, NewError(ast.Offset, ast.Length, "slice start cannot be greater than end") } - return left[int(start) : int(end)+1], nil + result, ok := sliceRange(resultLeft, int(start), int(end)) + if !ok { + return nil, NewError(ast.Offset, ast.Length, "can only index strings or arrays but got %v", resultLeft) + } + return result, nil } left := toString(resultLeft) leftLen := stringLength(left) @@ -228,32 +232,38 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { if err != nil { return nil, err } - if isSlice(resultRight) && len(resultRight.([]any)) == 2 { - start, err := toNumber(ast, resultRight.([]any)[0]) + if rightLen, ok := sliceLen(resultRight); ok && rightLen == 2 { + startValue, _ := sliceItem(resultRight, 0) + start, err := toNumber(ast, startValue) if err != nil { return nil, err } - end, err := toNumber(ast, resultRight.([]any)[1]) + endValue, _ := sliceItem(resultRight, 1) + end, err := toNumber(ast, endValue) if err != nil { return nil, err } - if left, ok := resultLeft.([]any); ok { + if leftLen, ok := sliceLen(resultLeft); ok { if start < 0 { - start += float64(len(left)) + start += float64(leftLen) } if end < 0 { - end += float64(len(left)) + end += float64(leftLen) } - if err := checkBounds(ast, left, int(start)); err != nil { + if err := checkBounds(ast, resultLeft, int(start)); err != nil { return nil, err } - if err := checkBounds(ast, left, int(end)); err != nil { + if err := checkBounds(ast, resultLeft, int(end)); err != nil { return nil, err } if int(start) > int(end) { return nil, NewError(ast.Offset, ast.Length, "slice start cannot be greater than end") } - return left[int(start) : int(end)+1], nil + result, ok := sliceRange(resultLeft, int(start), int(end)) + if !ok { + return nil, NewError(ast.Offset, ast.Length, "can only index strings or arrays but got %v", resultLeft) + } + return result, nil } left := toString(resultLeft) leftLen := stringLength(left) @@ -279,14 +289,18 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { if err != nil { return nil, err } - if left, ok := resultLeft.([]any); ok { + if leftLen, ok := sliceLen(resultLeft); ok { if idx < 0 { - idx += float64(len(left)) + idx += float64(leftLen) } - if err := checkBounds(ast, left, int(idx)); err != nil { + if err := checkBounds(ast, resultLeft, int(idx)); err != nil { return nil, err } - return left[int(idx)], nil + result, ok := sliceItem(resultLeft, int(idx)) + if !ok { + return nil, NewError(ast.Offset, ast.Length, "can only index strings or arrays but got %v", resultLeft) + } + return result, nil } left := toString(resultLeft) leftLen := stringLength(left) @@ -338,8 +352,9 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { return toString(resultLeft) + toString(resultRight), nil } if isSlice(resultLeft) && isSlice(resultRight) { - tmp := append([]any{}, resultLeft.([]any)...) - return append(tmp, resultRight.([]any)...), nil + if out, ok := concatSlices(resultLeft, resultRight); ok { + return out, nil + } } } if isNumber(resultLeft) && isNumber(resultRight) { @@ -367,7 +382,7 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { if int(right) == 0 { return nil, NewError(ast.Offset, ast.Length, "cannot divide by zero") } - return int(left) % int(right), nil + return float64(int(left) % int(right)), nil case NodePower: return math.Pow(left, right), nil } @@ -467,13 +482,16 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { } switch ast.Type { case NodeIn: - if a, ok := resultRight.([]any); ok { - for _, item := range a { + if isSlice(resultRight) { + matched := false + iterateSlice(resultRight, func(item any) bool { if deepEqual(item, resultLeft) { - return true, nil + matched = true + return false } - } - return false, nil + return true + }) + return matched, nil } if m, ok := resultRight.(map[string]any); ok { _, ok := m[toString(resultLeft)] @@ -485,13 +503,16 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { } return strings.Contains(toString(resultRight), toString(resultLeft)), nil case NodeContains: - if a, ok := resultLeft.([]any); ok { - for _, item := range a { + if isSlice(resultLeft) { + matched := false + iterateSlice(resultLeft, func(item any) bool { if deepEqual(item, resultRight) { - return true, nil + matched = true + return false } - } - return false, nil + return true + }) + return matched, nil } if m, ok := resultLeft.(map[string]any); ok { _, ok := m[toString(resultRight)] @@ -533,19 +554,24 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { } resultLeft = values } - if leftSlice, ok := resultLeft.([]any); ok { - for _, item := range leftSlice { + if isSlice(resultLeft) { + iterateSlice(resultLeft, func(item any) bool { // In an unquoted string scenario it makes no sense for the first/only // token after a `where` clause to be treated as a string. Instead we // treat a `where` the same as a field select `.` in this scenario. i.prevFieldSelect = true - resultRight, err := i.run(ast.Right, item) - if i.strict && err != nil { - return nil, err + resultRight, runErr := i.run(ast.Right, item) + if i.strict && runErr != nil { + err = runErr + return false } if toBool(resultRight) { results = append(results, item) } + return true + }) + if err != nil { + return nil, err } } return results, nil diff --git a/interpreter_test.go b/interpreter_test.go index d260086..2253c39 100644 --- a/interpreter_test.go +++ b/interpreter_test.go @@ -28,7 +28,7 @@ func TestInterpreter(t *testing.T) { {expr: `1_000_000 + 1`, output: 1000001.0}, // Mul/div {expr: "4 * 5 / 10", output: 2.0}, - {expr: `19 % x`, input: `{"x": 5}`, output: 4}, + {expr: `19 % x`, input: `{"x": 5}`, output: 4.0}, // Power {expr: "2^3", output: 8.0}, {expr: "2^3^2", output: 512.0}, @@ -156,7 +156,7 @@ func TestInterpreter(t *testing.T) { {expr: `foo where method == "GET"`, inputParsed: map[any]any{"foo": map[any]any{"op1": map[any]any{"method": "GET", "path": "/op1"}, "op2": map[any]any{"method": "PUT", "path": "/op2"}, "op3": map[any]any{"method": "DELETE", "path": "/op3"}}}, output: []any{map[any]any{"method": "GET", "path": "/op1"}}}, {expr: `items where id > 0`, input: `{"items": [{"id": 1}, "x", {"id": 2}]}`, output: []any{map[string]any{"id": 1.0}, map[string]any{"id": 2.0}}}, {expr: `foo where id > 0`, input: `{"foo": {"a": "x", "b": {"id": 1}, "c": {"id": 2}}}`, unordered: true, output: []any{map[string]any{"id": 1.0}, map[string]any{"id": 2.0}}}, - {expr: `items where id > 3`, input: `{"items": []}`, err: "where clause requires a non-empty array or object"}, + {expr: `items where id > 3`, input: `{"items": []}`, output: []any{}}, {expr: `items where id > 3`, input: `{"items": 1}`, skipTC: true, output: []any{}}, // Order of operations {expr: "1 + 2 + 3", output: 6.0}, @@ -168,7 +168,7 @@ func TestInterpreter(t *testing.T) { {expr: "6 -", err: "incomplete expression"}, {expr: `foo.bar + "baz"`, input: `{"foo": 1}`, err: "no property bar"}, {expr: `foo + 1`, input: `{"foo": [1, 2]}`, err: "cannot operate on incompatible types"}, - {expr: `foo > 1`, input: `{"foo": []}`, err: "cannot compare array[] with number"}, + {expr: `foo > 1`, input: `{"foo": []}`, err: "cannot compare array[unknown] with number"}, {expr: `foo[1-]`, input: `{"foo": "hello"}`, err: "unexpected right-bracket"}, {expr: `not (1- <= 5)`, err: "missing right operand"}, {expr: `(1 >=)`, err: "unexpected right-paren"}, @@ -311,6 +311,7 @@ func TestTypedFunctions(t *testing.T) { {expr: "name.lower == \"mexpr\"", output: true}, {expr: "name.length == 5", output: true}, {expr: "enabled and a > 1", output: true}, + {expr: "add(1.9, 2.1)", output: 3}, {expr: "add(a)", err: "expects 2 parameter"}, {expr: "isAdmin(a)", err: "expects string but found number"}, {expr: "toggle(a)", err: "expects boolean but found number"}, @@ -344,6 +345,229 @@ func TestTypedFunctions(t *testing.T) { } } +func TestTypedSlices(t *testing.T) { + input := map[string]any{ + "ints": []int{1, 2, 3}, + "floats": []float64{1.5, 2.5, 3.5}, + "strings": []string{"alpha", "beta", "gamma"}, + "bounds": []int{0, 1}, + } + + cases := []struct { + expr string + output any + }{ + {expr: "ints[1]", output: 2}, + {expr: "ints[1:]", output: []int{2, 3}}, + {expr: "ints[bounds]", output: []int{1, 2}}, + {expr: "ints + ints", output: []int{1, 2, 3, 1, 2, 3}}, + {expr: `2 in ints`, output: true}, + {expr: `strings contains "beta"`, output: true}, + {expr: `strings[1]`, output: "beta"}, + {expr: `floats.length`, output: 3}, + } + + for _, tc := range cases { + t.Run(tc.expr, func(t *testing.T) { + ast, err := Parse(tc.expr, input) + if err != nil { + t.Fatal(err.Pretty(tc.expr)) + } + result, err := Run(ast, input) + if err != nil { + t.Fatal(err.Pretty(tc.expr)) + } + if !reflect.DeepEqual(tc.output, result) { + t.Fatalf("expected %v but found %v", tc.output, result) + } + }) + } +} + +func TestReflectionBackedSlices(t *testing.T) { + input := map[string]any{ + "uints": []uint{1, 2, 3}, + "emptyUints": []uint{}, + } + + cases := []struct { + expr string + output any + }{ + {expr: "uints[1]", output: uint(2)}, + {expr: "uints[:1]", output: []uint{1, 2}}, + {expr: "uints[1:2]", output: []uint{2, 3}}, + {expr: "uints + uints", output: []any{uint(1), uint(2), uint(3), uint(1), uint(2), uint(3)}}, + {expr: `2 in uints`, output: true}, + {expr: `uints.length`, output: 3}, + {expr: `uints and 1`, output: true}, + {expr: `emptyUints or 1`, output: true}, + } + + for _, tc := range cases { + t.Run(tc.expr, func(t *testing.T) { + ast, err := Parse(tc.expr, input) + if err != nil { + t.Fatal(err.Pretty(tc.expr)) + } + result, err := Run(ast, input) + if err != nil { + t.Fatal(err.Pretty(tc.expr)) + } + if !reflect.DeepEqual(tc.output, result) { + t.Fatalf("expected %v but found %v", tc.output, result) + } + }) + } +} + +func TestTruthinessConversions(t *testing.T) { + cases := []struct { + name string + expr string + input map[string]any + output any + }{ + {name: "int64 true", expr: "value and 1", input: map[string]any{"value": int64(2)}, output: true}, + {name: "int64 false", expr: "value or 0", input: map[string]any{"value": int64(-1)}, output: false}, + {name: "uint8 true", expr: "value and 1", input: map[string]any{"value": uint8(1)}, output: true}, + {name: "float32 false", expr: "value or 0", input: map[string]any{"value": float32(0)}, output: false}, + {name: "bytes true", expr: "value and 1", input: map[string]any{"value": []byte("x")}, output: true}, + {name: "bytes false", expr: "value or 0", input: map[string]any{"value": []byte{}}, output: false}, + {name: "map any true", expr: "value and 1", input: map[string]any{"value": map[any]any{"k": 1}}, output: true}, + {name: "map any false", expr: "value or 0", input: map[string]any{"value": map[any]any{}}, output: false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ast, err := Parse(tc.expr, tc.input) + if err != nil { + t.Fatal(err.Pretty(tc.expr)) + } + result, err := Run(ast, tc.input) + if err != nil { + t.Fatal(err.Pretty(tc.expr)) + } + if !reflect.DeepEqual(tc.output, result) { + t.Fatalf("expected %v but found %v", tc.output, result) + } + }) + } +} + +func TestAdditionalNumericConversions(t *testing.T) { + cases := []struct { + name string + expr string + input map[string]any + output any + }{ + {name: "int8 and uint16", expr: "a + b", input: map[string]any{"a": int8(2), "b": uint16(3)}, output: 5.0}, + {name: "float32 and literal", expr: "a + 2", input: map[string]any{"a": float32(1.5)}, output: 3.5}, + {name: "numeric equality", expr: "a == b", input: map[string]any{"a": int16(1), "b": uint8(1)}, output: true}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ast, err := Parse(tc.expr, tc.input) + if err != nil { + t.Fatal(err.Pretty(tc.expr)) + } + result, err := Run(ast, tc.input) + if err != nil { + t.Fatal(err.Pretty(tc.expr)) + } + if !reflect.DeepEqual(tc.output, result) { + t.Fatalf("expected %v but found %v", tc.output, result) + } + }) + } +} + +func TestReflectionBackedArrays(t *testing.T) { + input := map[string]any{ + "arr": [3]uint{1, 2, 3}, + "tail": [2]uint{4, 5}, + } + + cases := []struct { + expr string + output any + }{ + {expr: "arr[1]", output: uint(2)}, + {expr: "arr[:1]", output: []uint{1, 2}}, + {expr: "arr + tail", output: []any{uint(1), uint(2), uint(3), uint(4), uint(5)}}, + {expr: "2 in arr", output: true}, + } + + for _, tc := range cases { + t.Run(tc.expr, func(t *testing.T) { + ast, err := Parse(tc.expr, input) + if err != nil { + t.Fatal(err.Pretty(tc.expr)) + } + result, err := Run(ast, input) + if err != nil { + t.Fatal(err.Pretty(tc.expr)) + } + if !reflect.DeepEqual(tc.output, result) { + t.Fatalf("expected %v but found %v", tc.output, result) + } + }) + } +} + +func TestEmptyArrayTypeCheck(t *testing.T) { + ast, err := Parse(`items where id > 3`, map[string]any{"items": []any{}}) + if err != nil { + t.Fatal(err.Pretty(`items where id > 3`)) + } + result, err := Run(ast, map[string]any{"items": []any{}}) + if err != nil { + t.Fatal(err.Pretty(`items where id > 3`)) + } + if !reflect.DeepEqual([]any{}, result) { + t.Fatalf("expected empty result, got %v", result) + } +} + +func TestEmptyObjectTypeCheck(t *testing.T) { + ast, err := Parse(`items where id > 3`, map[string]any{"items": map[string]any{}}) + if err != nil { + t.Fatal(err.Pretty(`items where id > 3`)) + } + result, err := Run(ast, map[string]any{"items": map[string]any{}}) + if err != nil { + t.Fatal(err.Pretty(`items where id > 3`)) + } + if !reflect.DeepEqual([]any{}, result) { + t.Fatalf("expected empty result, got %v", result) + } +} + +func TestWhereRequiresArrayOrObject(t *testing.T) { + _, err := Parse(`items where id > 3`, map[string]any{"items": 1}) + if err == nil { + t.Fatal("expected error") + } + if err.Error() != "where clause requires an array or object, but found number" { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestNestedNumericEquality(t *testing.T) { + result, err := Eval("a == b", map[string]any{ + "a": []int{1}, + "b": []float64{1}, + }) + if err != nil { + t.Fatal(err) + } + if result != true { + t.Fatalf("expected true but found %v", result) + } +} + func TestTypedFunctionsMapAny(t *testing.T) { input := map[any]any{ "upper": func(s string) string { return strings.ToUpper(s) }, diff --git a/lexer.go b/lexer.go index 7a72944..1ec071b 100644 --- a/lexer.go +++ b/lexer.go @@ -131,6 +131,7 @@ func NewLexer(expression string) Lexer { return &lexer{ expression: expression, pos: 0, + runePos: 0, lastWidth: 0, token: &Token{}, } @@ -139,6 +140,7 @@ func NewLexer(expression string) Lexer { type lexer struct { expression string pos uint16 + runePos uint16 lastWidth uint16 // token is a cached token to prevent new tokens from being allocated. @@ -154,6 +156,7 @@ func (l *lexer) next() rune { } r, w := utf8.DecodeRuneInString(l.expression[l.pos:]) l.pos += uint16(w) + l.runePos++ l.lastWidth = uint16(w) return r } @@ -161,6 +164,9 @@ func (l *lexer) next() rune { // back moves back one rune. func (l *lexer) back() { l.pos -= l.lastWidth + if l.lastWidth > 0 { + l.runePos-- + } } // peek returns the next rune without moving the position forward. @@ -173,8 +179,8 @@ func (l *lexer) peek() rune { func (l *lexer) newToken(typ TokenType, value string) *Token { l.token.Type = typ l.token.Value = value - l.token.Offset = l.pos - uint16(len(value)) - l.token.Length = uint8(len(value)) + l.token.Offset = l.runePos - uint16(utf8.RuneCountInString(value)) + l.token.Length = uint8(utf8.RuneCountInString(value)) if l.token.Length == 0 { l.token.Length = 1 } @@ -236,7 +242,7 @@ func (l *lexer) consumeIdentifier() *Token { // quote is encountered. Only double-quoted strings are supported. func (l *lexer) consumeString() (*Token, Error) { buf := bytes.NewBuffer(make([]byte, 0, 8)) - offset := l.pos - l.lastWidth + offset := l.runePos - 1 for { r := l.next() if r == '\\' && l.peek() == '"' { @@ -254,7 +260,7 @@ func (l *lexer) consumeString() (*Token, Error) { } tok := l.newToken(TokenString, buf.String()) tok.Offset = offset - tok.Length = uint8(l.pos - offset) + tok.Length = uint8(l.runePos - offset) return tok, nil } @@ -299,7 +305,7 @@ func (l *lexer) Next() (*Token, Error) { l.next() return l.newToken(TokenComparison, "=="), nil } - return nil, NewError(l.pos, 1, "= should be ==") + return nil, NewError(l.runePos-1, 1, "= should be ==") } if r == '"' { diff --git a/lexer_test.go b/lexer_test.go index 9b613a0..a8c869b 100644 --- a/lexer_test.go +++ b/lexer_test.go @@ -1,6 +1,10 @@ package mexpr -import "testing" +import ( + "strings" + "testing" + "unicode/utf8" +) func TestEscapedStringTokenOffsets(t *testing.T) { expr := `"a\"b"` @@ -20,8 +24,8 @@ func TestEscapedStringTokenOffsets(t *testing.T) { if tok.Offset != 0 { t.Fatalf("expected token offset 0, got %d", tok.Offset) } - if tok.Length != uint8(len(expr)) { - t.Fatalf("expected token length %d, got %d", len(expr), tok.Length) + if tok.Length != uint8(utf8.RuneCountInString(expr)) { + t.Fatalf("expected token length %d, got %d", utf8.RuneCountInString(expr), tok.Length) } ast, err := Parse(expr, nil) @@ -34,7 +38,88 @@ func TestEscapedStringTokenOffsets(t *testing.T) { if ast.Offset != 0 { t.Fatalf("expected node offset 0, got %d", ast.Offset) } - if ast.Length != uint8(len(expr)) { - t.Fatalf("expected node length %d, got %d", len(expr), ast.Length) + if ast.Length != uint8(utf8.RuneCountInString(expr)) { + t.Fatalf("expected node length %d, got %d", utf8.RuneCountInString(expr), ast.Length) + } +} + +func TestPrettyErrorUsesRuneOffsets(t *testing.T) { + _, err := Parse("é +", nil) + if err == nil { + t.Fatal("expected error") + } + pretty := err.Pretty("é +") + expected := "incomplete expression, EOF found\né +\n...^" + if pretty != expected { + t.Fatalf("expected %q but found %q", expected, pretty) + } +} + +func TestErrorAccessors(t *testing.T) { + _, err := Parse("1 ]", nil) + if err == nil { + t.Fatal("expected error") + } + if err.Offset() != 2 { + t.Fatalf("expected offset 2, got %d", err.Offset()) + } + if err.Length() != 1 { + t.Fatalf("expected length 1, got %d", err.Length()) + } +} + +func TestSingleEqualsUsesRuneOffsets(t *testing.T) { + _, err := Parse("é = 1", nil) + if err == nil { + t.Fatal("expected error") + } + pretty := err.Pretty("é = 1") + expected := "= should be ==\né = 1\n..^" + if pretty != expected { + t.Fatalf("expected %q but found %q", expected, pretty) + } +} + +func TestTokenFormatting(t *testing.T) { + l := NewLexer("foo") + tok, err := l.Next() + if err != nil { + t.Fatal(err) + } + if got := tok.String(); got != "0 (identifier) foo" { + t.Fatalf("expected formatted token, got %q", got) + } + + if TokenIdentifier.String() != "identifier" { + t.Fatalf("expected identifier string, got %q", TokenIdentifier.String()) + } + if TokenWhere.String() != "where" { + t.Fatalf("expected where string, got %q", TokenWhere.String()) + } + if TokenUnknown.String() != "unknown" { + t.Fatalf("expected unknown string, got %q", TokenUnknown.String()) + } +} + +func TestNodeFormatting(t *testing.T) { + ast, err := Parse("a + 1", nil) + if err != nil { + t.Fatal(err) + } + if ast == nil { + t.Fatal("expected ast") + } + if got := ast.String(); got != "+" { + t.Fatalf("expected root string +, got %q", got) + } + dot := ast.Dot("") + if !strings.Contains(dot, `"+" [label="+"]`) { + t.Fatalf("expected dot to contain root label, got %q", dot) + } + if !strings.Contains(dot, `"la" [label="a"]`) { + t.Fatalf("expected dot to contain identifier child, got %q", dot) + } + if !strings.Contains(dot, `"r1" [label="1"]`) { + t.Fatalf("expected dot to contain literal child, got %q", dot) } } diff --git a/typecheck.go b/typecheck.go index ecf45aa..11fe9e9 100644 --- a/typecheck.go +++ b/typecheck.go @@ -136,6 +136,9 @@ func getSchema(v any) *schema { for _, item := range i { s.items = mergeSchema(s.items, getSchema(item)) } + if s.items == nil { + s.items = newSchema(typeUnknown) + } return s case map[string]any: m := newSchema(typeObject) @@ -158,6 +161,17 @@ func getSchema(v any) *schema { } return fn } + if isSlice(v) { + s := newSchema(typeArray) + iterateSlice(v, func(item any) bool { + s.items = mergeSchema(s.items, getSchema(item)) + return true + }) + if s.items == nil { + s.items = newSchema(typeUnknown) + } + return s + } return newSchema(typeUnknown) } @@ -373,15 +387,21 @@ func (i *typeChecker) run(ast *Node, value any) (*schema, Error) { objectType := leftType keys := mapKeys(objectType.properties) sort.Strings(keys) + leftType = newSchema(typeArray) if len(keys) > 0 { - leftType = newSchema(typeArray) for _, key := range keys { leftType.items = mergeSchema(leftType.items, objectType.properties[key]) } } + if leftType.items == nil { + leftType.items = newSchema(typeUnknown) + } + } + if leftType.isArray() && leftType.items == nil { + leftType.items = newSchema(typeUnknown) } - if !leftType.isArray() || leftType.items == nil { - return nil, NewError(ast.Offset, ast.Length, "where clause requires a non-empty array or object, but found %s", leftType) + if !leftType.isArray() { + return nil, NewError(ast.Offset, ast.Length, "where clause requires an array or object, but found %s", leftType) } // In an unquoted string scenario it makes no sense for the first/only // token after a `where` clause to be treated as a string. Instead we