From 2c2274243f9dec6757de832ad0771ac7a3b64798 Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Thu, 23 Apr 2026 10:27:24 -0700 Subject: [PATCH 1/2] Improve hot-path performance checks --- benchmark_test.go | 52 +++++++++++++++++++++++ conversions.go | 88 +++++++++++++++++++++++++++++++++++++++ expr.go | 27 +++++++++++- functions.go | 45 ++++++++++++++++++++ interpreter.go | 12 +----- lexer.go | 102 ++++++++++++++++++++++++++++++---------------- typecheck.go | 9 +--- 7 files changed, 279 insertions(+), 56 deletions(-) create mode 100644 benchmark_test.go diff --git a/benchmark_test.go b/benchmark_test.go new file mode 100644 index 0000000..2cdb744 --- /dev/null +++ b/benchmark_test.go @@ -0,0 +1,52 @@ +package mexpr + +import "testing" + +func BenchmarkInternals(b *testing.B) { + b.Run("lexer-complex", func(b *testing.B) { + b.ReportAllocs() + expression := `foo.bar / (1 * 1024 * 1024) >= 1.0 and "v" in baz and baz.length > 3 and arr[2:].length == 1` + for n := 0; n < b.N; n++ { + l := lexer{expression: expression} + for { + tok, err := l.Next() + if err != nil { + b.Fatal(err) + } + if tok.Type == TokenEOF { + break + } + } + } + }) + + b.Run("resolve-lazy-value-non-function", func(b *testing.B) { + b.ReportAllocs() + input := map[string]any{"foo": "bar"} + for n := 0; n < b.N; n++ { + if _, ok := resolveLazyValue(input); ok { + b.Fatal("unexpected lazy value") + } + } + }) + + b.Run("resolve-lazy-value-number-func", func(b *testing.B) { + b.ReportAllocs() + input := func() int { return 42 } + for n := 0; n < b.N; n++ { + out, ok := resolveLazyValue(input) + if !ok || out.(int) != 42 { + b.Fatalf("unexpected lazy value result: %v %v", out, ok) + } + } + }) + + b.Run("deep-equal-number", func(b *testing.B) { + b.ReportAllocs() + for n := 0; n < b.N; n++ { + if !deepEqual(1, 1.0) { + b.Fatal("expected equal numbers") + } + } + }) +} diff --git a/conversions.go b/conversions.go index 744aa5c..18432e1 100644 --- a/conversions.go +++ b/conversions.go @@ -537,7 +537,95 @@ func normalize(v any) any { return v } +func normalizedNumber(v any) (float64, bool) { + switch n := v.(type) { + case int: + return float64(n), true + case int8: + return float64(n), true + case int16: + return float64(n), true + case int32: + return float64(n), true + case int64: + return float64(n), true + case uint: + return float64(n), true + case uint8: + return float64(n), true + case uint16: + return float64(n), true + case uint32: + return float64(n), true + case uint64: + return float64(n), true + case float32: + return float64(n), true + case float64: + return n, true + case func() int: + return float64(n()), true + case func() int8: + return float64(n()), true + case func() int16: + return float64(n()), true + case func() int32: + return float64(n()), true + case func() int64: + return float64(n()), true + case func() uint: + return float64(n()), true + case func() uint8: + return float64(n()), true + case func() uint16: + return float64(n()), true + case func() uint32: + return float64(n()), true + case func() uint64: + return float64(n()), true + case func() float32: + return float64(n()), true + case func() float64: + return n(), true + } + return 0, false +} + +func normalizedString(v any) (string, bool) { + switch s := v.(type) { + case string: + return s, true + case []byte: + return string(s), true + case func() string: + return s(), true + } + return "", false +} + +func normalizedBool(v any) (bool, bool) { + switch b := v.(type) { + case bool: + return b, true + case func() bool: + return b(), true + } + return false, false +} + // deepEqual returns whether two values are deeply equal. func deepEqual(left, right any) bool { + if leftNum, ok := normalizedNumber(left); ok { + rightNum, ok := normalizedNumber(right) + return ok && leftNum == rightNum + } + if leftStr, ok := normalizedString(left); ok { + rightStr, ok := normalizedString(right) + return ok && leftStr == rightStr + } + if leftBool, ok := normalizedBool(left); ok { + rightBool, ok := normalizedBool(right) + return ok && leftBool == rightBool + } return recursiveDeepEqual(left, right) } diff --git a/expr.go b/expr.go index 5a7266f..d4ba2b3 100644 --- a/expr.go +++ b/expr.go @@ -1,6 +1,20 @@ // Package mexpr provides a simple expression parser. package mexpr +func parseInterpreterOptions(options []InterpreterOption) (bool, bool) { + strict := false + unquoted := false + for _, opt := range options { + switch opt { + case StrictMode: + strict = true + case UnquotedStrings: + unquoted = true + } + } + return strict, unquoted +} + // Parse an expression and return the abstract syntax tree. If `types` is // passed, it should be a set of representative example values for the input // which will be used to type check the expression against. @@ -22,13 +36,22 @@ func Parse(expression string, types any, options ...InterpreterOption) (*Node, E // TypeCheck will take a parsed AST and type check against the given input // structure with representative example values. func TypeCheck(ast *Node, types any, options ...InterpreterOption) Error { - i := NewTypeChecker(ast, options...) + _, unquoted := parseInterpreterOptions(options) + i := typeChecker{ + ast: ast, + unquoted: unquoted, + } return i.Run(types) } // Run executes an AST with the given input and returns the output. func Run(ast *Node, input any, options ...InterpreterOption) (any, Error) { - i := NewInterpreter(ast, options...) + strict, unquoted := parseInterpreterOptions(options) + i := interpreter{ + ast: ast, + strict: strict, + unquoted: unquoted, + } return i.Run(input) } diff --git a/functions.go b/functions.go index c8e2583..aa2c297 100644 --- a/functions.go +++ b/functions.go @@ -44,6 +44,51 @@ func getFunctionSchema(v any) (*schema, bool) { } func resolveLazyValue(v any) (any, bool) { + switch fn := v.(type) { + case nil: + return nil, false + case bool, string, []byte: + return nil, false + case int, int8, int16, int32, int64: + return nil, false + case uint, uint8, uint16, uint32, uint64: + return nil, false + case float32, float64: + return nil, false + case []any, []int, []float64, []string: + return nil, false + case map[string]any, map[any]any: + return nil, false + case func() bool: + return fn(), true + case func() int: + return fn(), true + case func() int8: + return fn(), true + case func() int16: + return fn(), true + case func() int32: + return fn(), true + case func() int64: + return fn(), true + case func() uint: + return fn(), true + case func() uint8: + return fn(), true + case func() uint16: + return fn(), true + case func() uint32: + return fn(), true + case func() uint64: + return fn(), true + case func() float32: + return fn(), true + case func() float64: + return fn(), true + case func() string: + return fn(), true + } + s, ok := getFunctionSchema(v) if !ok || len(s.parameters) != 0 { return nil, false diff --git a/interpreter.go b/interpreter.go index e26259c..62fc7d2 100644 --- a/interpreter.go +++ b/interpreter.go @@ -57,17 +57,7 @@ type Interpreter interface { // NewInterpreter returns an interpreter for the given AST. func NewInterpreter(ast *Node, options ...InterpreterOption) Interpreter { - strict := false - unquoted := false - - for _, opt := range options { - switch opt { - case StrictMode: - strict = true - case UnquotedStrings: - unquoted = true - } - } + strict, unquoted := parseInterpreterOptions(options) return &interpreter{ ast: ast, diff --git a/lexer.go b/lexer.go index 1ec071b..44638e4 100644 --- a/lexer.go +++ b/lexer.go @@ -133,7 +133,6 @@ func NewLexer(expression string) Lexer { pos: 0, runePos: 0, lastWidth: 0, - token: &Token{}, } } @@ -145,7 +144,7 @@ type lexer struct { // token is a cached token to prevent new tokens from being allocated. // It is re-used on each call to `Next()`. - token *Token + token Token } // next returns the next rune in the expression at the current position. @@ -176,25 +175,22 @@ func (l *lexer) peek() rune { return r } -func (l *lexer) newToken(typ TokenType, value string) *Token { +func (l *lexer) newToken(typ TokenType, value string, offset, length uint16) *Token { l.token.Type = typ l.token.Value = value - l.token.Offset = l.runePos - uint16(utf8.RuneCountInString(value)) - l.token.Length = uint8(utf8.RuneCountInString(value)) + l.token.Offset = offset + l.token.Length = uint8(length) if l.token.Length == 0 { l.token.Length = 1 } - if typ == TokenString { - // Account for quotes - l.token.Offset-- - } - return l.token + return &l.token } // consumeNumber reads runes from the expression until a non-number or // non-decimal is encountered. func (l *lexer) consumeNumber() *Token { start := l.pos - l.lastWidth + offset := l.runePos - 1 for { r := l.next() if r != '.' && r != '_' && (r < '0' || r > '9') { @@ -202,7 +198,7 @@ func (l *lexer) consumeNumber() *Token { break } } - return l.newToken(TokenNumber, l.expression[start:l.pos]) + return l.newToken(TokenNumber, l.expression[start:l.pos], offset, l.runePos-offset) } // consumeIdentifier reads runes from the expression until a non-identifier @@ -210,6 +206,7 @@ func (l *lexer) consumeNumber() *Token { // then that corresponding token is returned, otherwise a normal identifier. func (l *lexer) consumeIdentifier() *Token { start := l.pos - l.lastWidth + offset := l.runePos - 1 for { r := l.next() if r == -1 || basic(r) != TokenUnknown || r == ' ' || r == '\t' || r == '\r' || r == '\n' || r == '<' || r == '>' || r == '=' || r == '!' || r == '.' || r == '[' || r == '(' { @@ -224,44 +221,63 @@ func (l *lexer) consumeIdentifier() *Token { // keywords to be used as properties without issue. switch string(value) { case "and": - return l.newToken(TokenAnd, value) + return l.newToken(TokenAnd, value, offset, l.runePos-offset) case "or": - return l.newToken(TokenOr, value) + return l.newToken(TokenOr, value, offset, l.runePos-offset) case "not": - return l.newToken(TokenNot, value) + return l.newToken(TokenNot, value, offset, l.runePos-offset) case "in", "contains", "startsWith", "endsWith", "before", "after": - return l.newToken(TokenStringCompare, value) + return l.newToken(TokenStringCompare, value, offset, l.runePos-offset) case "where": - return l.newToken(TokenWhere, value) + return l.newToken(TokenWhere, value, offset, l.runePos-offset) } } - return l.newToken(TokenIdentifier, value) + return l.newToken(TokenIdentifier, value, offset, l.runePos-offset) } // consumeString reads runes from the expression until a non-escaped double // quote is encountered. Only double-quoted strings are supported. func (l *lexer) consumeString() (*Token, Error) { - buf := bytes.NewBuffer(make([]byte, 0, 8)) offset := l.runePos - 1 + start := l.pos + for { r := l.next() - if r == '\\' && l.peek() == '"' { - l.next() - buf.WriteRune('"') - continue - } if r == -1 { return nil, NewError(offset, 1, "unterminated string") } if r == '"' { - break + return l.newToken(TokenString, l.expression[start:l.pos-l.lastWidth], offset, l.runePos-offset), nil + } + if r != '\\' { + continue + } + + buf := bytes.NewBuffer(make([]byte, 0, int(l.pos-start)+8)) + buf.WriteString(l.expression[start : l.pos-l.lastWidth]) + if l.peek() == '"' { + l.next() + buf.WriteRune('"') + } else { + buf.WriteRune('\\') + } + + for { + r = l.next() + if r == '\\' && l.peek() == '"' { + l.next() + buf.WriteRune('"') + continue + } + if r == -1 { + return nil, NewError(offset, 1, "unterminated string") + } + if r == '"' { + return l.newToken(TokenString, buf.String(), offset, l.runePos-offset), nil + } + buf.WriteRune(r) } - buf.WriteRune(r) } - tok := l.newToken(TokenString, buf.String()) - tok.Offset = offset - tok.Length = uint8(l.runePos - offset) - return tok, nil } func (l *lexer) Next() (*Token, Error) { @@ -270,7 +286,7 @@ func (l *lexer) Next() (*Token, Error) { r = l.next() } if r == -1 { - return l.newToken(TokenEOF, ""), nil + return l.newToken(TokenEOF, "", l.runePos, 0), nil } b := basic(r) @@ -282,9 +298,10 @@ func (l *lexer) Next() (*Token, Error) { } } if l.pos-l.lastWidth > uint16(len(l.expression)-1) { - return l.newToken(TokenEOF, ""), nil + return l.newToken(TokenEOF, "", l.runePos, 0), nil } - return l.newToken(b, l.expression[l.pos-l.lastWidth:l.pos]), nil + offset := l.runePos - 1 + return l.newToken(b, l.expression[l.pos-l.lastWidth:l.pos], offset, 1), nil } if r >= '0' && r <= '9' { @@ -292,18 +309,33 @@ func (l *lexer) Next() (*Token, Error) { } if r == '<' || r == '>' || r == '!' { + offset := l.runePos - 1 eq := l.next() if eq == '=' { - return l.newToken(TokenComparison, string([]rune{r, eq})), nil + switch r { + case '<': + return l.newToken(TokenComparison, "<=", offset, 2), nil + case '>': + return l.newToken(TokenComparison, ">=", offset, 2), nil + default: + return l.newToken(TokenComparison, "!=", offset, 2), nil + } } l.back() - return l.newToken(TokenComparison, string(r)), nil + switch r { + case '<': + return l.newToken(TokenComparison, "<", offset, 1), nil + case '>': + return l.newToken(TokenComparison, ">", offset, 1), nil + default: + return l.newToken(TokenComparison, "!", offset, 1), nil + } } if r == '=' { if l.peek() == '=' { l.next() - return l.newToken(TokenComparison, "=="), nil + return l.newToken(TokenComparison, "==", l.runePos-2, 2), nil } return nil, NewError(l.runePos-1, 1, "= should be ==") } diff --git a/typecheck.go b/typecheck.go index 11fe9e9..834cd42 100644 --- a/typecheck.go +++ b/typecheck.go @@ -182,14 +182,7 @@ type TypeChecker interface { // NewTypeChecker returns a type checker for the given AST. func NewTypeChecker(ast *Node, options ...InterpreterOption) TypeChecker { - unquoted := false - - for _, opt := range options { - switch opt { - case UnquotedStrings: - unquoted = true - } - } + _, unquoted := parseInterpreterOptions(options) return &typeChecker{ ast: ast, From d5461d4803304f1956433c467e4f1fc942c6bd6c Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Thu, 23 Apr 2026 10:32:06 -0700 Subject: [PATCH 2/2] Optimize parser and slice length fast paths --- expr.go | 4 +- interpreter.go | 161 ++++++++++++++++++++++++++++++++----------------- parser.go | 24 ++++++-- 3 files changed, 126 insertions(+), 63 deletions(-) diff --git a/expr.go b/expr.go index d4ba2b3..f735f7b 100644 --- a/expr.go +++ b/expr.go @@ -19,8 +19,8 @@ func parseInterpreterOptions(options []InterpreterOption) (bool, bool) { // passed, it should be a set of representative example values for the input // which will be used to type check the expression against. func Parse(expression string, types any, options ...InterpreterOption) (*Node, Error) { - l := NewLexer(expression) - p := NewParser(l) + l := lexer{expression: expression} + p := parser{lexer: &l} ast, err := p.Parse() if err != nil { return nil, err diff --git a/interpreter.go b/interpreter.go index 62fc7d2..6c49adf 100644 --- a/interpreter.go +++ b/interpreter.go @@ -50,6 +50,48 @@ func checkStringBounds(ast *Node, length, idx int) Error { return nil } +func normalizeSliceBounds(ast *Node, length int, start, end float64) (int, int, Error) { + if start < 0 { + start += float64(length) + } + if end < 0 { + end += float64(length) + } + startIdx := int(start) + endIdx := int(end) + if startIdx < 0 || startIdx >= length { + return 0, 0, NewError(ast.Offset, ast.Length, "invalid index %d for slice of length %d", startIdx, length) + } + if endIdx < 0 || endIdx >= length { + return 0, 0, NewError(ast.Offset, ast.Length, "invalid index %d for slice of length %d", endIdx, length) + } + if startIdx > endIdx { + return 0, 0, NewError(ast.Offset, ast.Length, "slice start cannot be greater than end") + } + return startIdx, endIdx, nil +} + +func normalizeStringSliceBounds(ast *Node, length int, start, end float64) (int, int, Error) { + if start < 0 { + start += float64(length) + } + if end < 0 { + end += float64(length) + } + startIdx := int(start) + endIdx := int(end) + if err := checkStringBounds(ast, length, startIdx); err != nil { + return 0, 0, err + } + if startIdx > endIdx { + return 0, 0, NewError(ast.Offset, ast.Length, "string slice start cannot be greater than end") + } + if err := checkStringBounds(ast, length, endIdx); err != nil { + return 0, 0, err + } + return startIdx, endIdx, nil +} + // Interpreter executes expression AST programs. type Interpreter interface { Run(value any) (any, Error) @@ -77,6 +119,52 @@ func (i *interpreter) Run(value any) (any, Error) { return i.run(i.ast, value) } +func (i *interpreter) fastLength(ast *Node, value any) (any, bool, Error) { + if ast == nil || ast.Type != NodeArrayIndex || ast.Right == nil || ast.Right.Type != NodeSlice { + return nil, false, nil + } + + resultLeft, err := i.run(ast.Left, value) + if err != nil { + return nil, true, err + } + startValue, err := i.run(ast.Right.Left, value) + if err != nil { + return nil, true, err + } + endValue, err := i.run(ast.Right.Right, value) + if err != nil { + return nil, true, err + } + start, err := toNumber(ast.Right.Left, startValue) + if err != nil { + return nil, true, err + } + end, err := toNumber(ast.Right.Right, endValue) + if err != nil { + return nil, true, err + } + + if leftLen, ok := sliceLen(resultLeft); ok { + startIdx, endIdx, err := normalizeSliceBounds(ast, leftLen, start, end) + if err != nil { + return nil, true, err + } + return endIdx - startIdx + 1, true, nil + } + if !isString(resultLeft) { + return nil, true, NewError(ast.Offset, ast.Length, "can only index strings or arrays but got %v", resultLeft) + } + + left := toString(resultLeft) + leftLen := stringLength(left) + startIdx, endIdx, err := normalizeStringSliceBounds(ast, leftLen, start, end) + if err != nil { + return nil, true, err + } + return endIdx - startIdx + 1, true, nil +} + func (i *interpreter) run(ast *Node, value any) (any, Error) { if ast == nil { return nil, nil @@ -145,6 +233,11 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { } return nil, NewError(ast.Offset, ast.Length, "cannot get %v from %v", ast.Value, value) case NodeFieldSelect: + if ast.Right != nil && ast.Right.Type == NodeIdentifier && ast.Right.Value == "length" { + if result, ok, err := i.fastLength(ast.Left, value); ok { + return result, err + } + } i.prevFieldSelect = true leftValue, err := i.run(ast.Left, value) if err != nil { @@ -178,22 +271,11 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { return nil, err } if leftLen, ok := sliceLen(resultLeft); ok { - if start < 0 { - start += float64(leftLen) - } - if end < 0 { - end += float64(leftLen) - } - if err := checkBounds(ast, resultLeft, int(start)); err != nil { + startIdx, endIdx, err := normalizeSliceBounds(ast, leftLen, start, end) + if err != nil { return nil, err } - if err := checkBounds(ast, resultLeft, int(end)); err != nil { - return nil, err - } - if int(start) > int(end) { - return nil, NewError(ast.Offset, ast.Length, "slice start cannot be greater than end") - } - result, ok := sliceRange(resultLeft, int(start), int(end)) + result, ok := sliceRange(resultLeft, startIdx, endIdx) if !ok { return nil, NewError(ast.Offset, ast.Length, "can only index strings or arrays but got %v", resultLeft) } @@ -201,22 +283,11 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { } left := toString(resultLeft) leftLen := stringLength(left) - if start < 0 { - start += float64(leftLen) - } - if end < 0 { - end += float64(leftLen) - } - if err := checkStringBounds(ast, leftLen, int(start)); err != nil { - return nil, err - } - if int(start) > int(end) { - return nil, NewError(ast.Offset, ast.Length, "string slice start cannot be greater than end") - } - if err := checkStringBounds(ast, leftLen, int(end)); err != nil { + startIdx, endIdx, err := normalizeStringSliceBounds(ast, leftLen, start, end) + if err != nil { return nil, err } - return stringSlice(left, int(start), int(end)), nil + return stringSlice(left, startIdx, endIdx), nil } resultRight, err := i.run(ast.Right, value) if err != nil { @@ -234,22 +305,11 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { return nil, err } if leftLen, ok := sliceLen(resultLeft); ok { - if start < 0 { - start += float64(leftLen) - } - if end < 0 { - end += float64(leftLen) - } - if err := checkBounds(ast, resultLeft, int(start)); err != nil { + startIdx, endIdx, err := normalizeSliceBounds(ast, leftLen, start, end) + if err != nil { return nil, err } - if err := checkBounds(ast, resultLeft, int(end)); err != nil { - return nil, err - } - if int(start) > int(end) { - return nil, NewError(ast.Offset, ast.Length, "slice start cannot be greater than end") - } - result, ok := sliceRange(resultLeft, int(start), int(end)) + result, ok := sliceRange(resultLeft, startIdx, endIdx) if !ok { return nil, NewError(ast.Offset, ast.Length, "can only index strings or arrays but got %v", resultLeft) } @@ -257,22 +317,11 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { } left := toString(resultLeft) leftLen := stringLength(left) - if start < 0 { - start += float64(leftLen) - } - if end < 0 { - end += float64(leftLen) - } - if err := checkStringBounds(ast, leftLen, int(start)); err != nil { - return nil, err - } - if int(start) > int(end) { - return nil, NewError(ast.Offset, ast.Length, "string slice start cannot be greater than end") - } - if err := checkStringBounds(ast, leftLen, int(end)); err != nil { + startIdx, endIdx, err := normalizeStringSliceBounds(ast, leftLen, start, end) + if err != nil { return nil, err } - return stringSlice(left, int(start), int(end)), nil + return stringSlice(left, startIdx, endIdx), nil } if isNumber(resultRight) { idx, err := toNumber(ast, resultRight) diff --git a/parser.go b/parser.go index 7d7c0b7..b2bca6b 100644 --- a/parser.go +++ b/parser.go @@ -194,20 +194,34 @@ type Parser interface { // NewParser creates a new parser that uses the given lexer to get and process // tokens into an abstract syntax tree. -func NewParser(lexer Lexer) Parser { +func NewParser(lx Lexer) Parser { + if concrete, ok := lx.(*lexer); ok { + return &parser{ + lexer: concrete, + } + } return &parser{ - lexer: lexer, + genericLexer: lx, } } // parser is an implementation of a Pratt or top-down operator precedence parser type parser struct { - lexer Lexer - token *Token + lexer *lexer + genericLexer Lexer + token *Token } func (p *parser) advance() Error { - t, err := p.lexer.Next() + var ( + t *Token + err Error + ) + if p.lexer != nil { + t, err = p.lexer.Next() + } else { + t, err = p.genericLexer.Next() + } if err != nil { return err }