diff --git a/fixtures/multi_vql_queries.golden b/fixtures/multi_vql_queries.golden index b0a53b6..7aaff23 100644 --- a/fixtures/multi_vql_queries.golden +++ b/fixtures/multi_vql_queries.golden @@ -1416,5 +1416,20 @@ } ] } + ], + "090/000 Function definition with defaults: LET Y \u003c= 7": null, + "090/001 Function definition with defaults: LET F(X=2, Y) = X + Y": null, + "090/002 Function definition with defaults: SELECT F(Y=1), F(), F(X=1, Y=4) FROM scope()": [ + { + "F(Y=1)": 3, + "F()": 9, + "F(X=1, Y=4)": 5 + } + ], + "091/000 Stored Query definition with defaults: LET F(X=2, Y) = SELECT X + Y AS Total FROM scope()": null, + "091/001 Stored Query definition with defaults: SELECT * FROM F(Y=1)": [ + { + "Total": 3 + } ] } \ No newline at end of file diff --git a/functions/generic.go b/functions/generic.go index 162a827..07dfcbb 100644 --- a/functions/generic.go +++ b/functions/generic.go @@ -16,17 +16,18 @@ type GenericFunctionInterface func(ctx context.Context, scope types.Scope, args // helper plugin allows callers to use these within VFilter // easily. Example: -// scope.AppendPlugins(GenericListPlugin{ -// PluginName: "my_plugin", -// Function: func(args types.Row) []types.Row { -// .... -// } -// }) +// scope.AppendPlugins(GenericListPlugin{ +// PluginName: "my_plugin", +// Function: func(args types.Row) []types.Row { +// .... +// } +// }) type GenericFunction struct { FunctionName string Doc string Function GenericFunctionInterface Metadata *ordereddict.Dict + Version int ArgType types.Any } @@ -47,6 +48,7 @@ func (self GenericFunction) Info(scope types.Scope, type_map *types.TypeMap) *ty Name: self.FunctionName, Doc: self.Doc, Metadata: self.Metadata, + Version: self.Version, } if self.ArgType != nil { diff --git a/lambda.go b/lambda.go index 012b843..77d2781 100644 --- a/lambda.go +++ b/lambda.go @@ -24,8 +24,16 @@ type Lambda struct { func (self *Lambda) GetParameters() []string { result := []string{} + var visitor func(parameters *_ParameterList) + visitor = func(parameters *_ParameterList) { + result = append(result, parameters.Left) + if parameters.Right != nil { + visitor(parameters.Right.Term) + } + } + if self.Parameters != nil { - visitor(self.Parameters, &result) + visitor(self.Parameters) } return result diff --git a/stored.go b/stored.go index e8cac20..8ac8c51 100644 --- a/stored.go +++ b/stored.go @@ -37,6 +37,7 @@ type _StoredQuery struct { query *_Select name string parameters []string + defaults map[string]*_Args } func NewStoredQuery(query *_Select, name string) *_StoredQuery { @@ -46,6 +47,15 @@ func NewStoredQuery(query *_Select, name string) *_StoredQuery { } } +func (self *_StoredQuery) ApplyDefaults( + ctx context.Context, scope types.Scope, args *ordereddict.Dict) { + if self.defaults == nil { + return + } + + applyDefaults(ctx, scope, args, self.defaults) +} + func (self *_StoredQuery) GoString() string { scope := NewScope() return fmt.Sprintf("StoredQuery{name: %v, query: {%v}, parameters: %v}", @@ -155,6 +165,7 @@ type StoredExpression struct { Expr *_AndExpression name string parameters []string + defaults map[string]*_Args } func (self *StoredExpression) Reduce( @@ -162,6 +173,40 @@ func (self *StoredExpression) Reduce( return self.Expr.Reduce(ctx, scope) } +func (self *StoredExpression) ApplyDefaults( + ctx context.Context, scope types.Scope, args *ordereddict.Dict) { + // Add in any missing args + if self.defaults == nil { + return + } + + applyDefaults(ctx, scope, args, self.defaults) +} + +func applyDefaults(ctx context.Context, scope types.Scope, + args *ordereddict.Dict, + defaults map[string]*_Args) { + + for name, arg := range defaults { + _, pres := args.Get(name) + + // If the user did not specify the arg we fill it in from + // the defaults. NOTE: The default expressions are + // evaluated at the calling scope. + if !pres { + if arg.Right != nil { + args.Set(name, arg.Right.Reduce(ctx, scope)) + } else if arg.SubSelect != nil { + args.Set(arg.Left, arg.SubSelect) + } else if arg.Array != nil { + args.Set(name, arg.Array.Reduce(ctx, scope)) + } else if arg.ArrayOpenBrace != "" { + args.Set(name, []Row{}) + } + } + } +} + // Act as a function func (self *StoredExpression) Call(ctx context.Context, scope types.Scope, args *ordereddict.Dict) types.Any { @@ -183,6 +228,8 @@ func (self *StoredExpression) Call(ctx context.Context, vars.Set(k, v) } + self.ApplyDefaults(ctx, scope, vars) + sub_scope.AppendVars(vars) return self.Reduce(ctx, sub_scope) diff --git a/types/defaults.go b/types/defaults.go new file mode 100644 index 0000000..cabf528 --- /dev/null +++ b/types/defaults.go @@ -0,0 +1,20 @@ +package types + +import ( + "context" + + "github.com/Velocidex/ordereddict" +) + +type DefaultArgInterface interface { + ApplyDefaults(ctx context.Context, scope Scope, vars *ordereddict.Dict) +} + +func MaybeApplyDefaultArgs(callable interface{}, + ctx context.Context, scope Scope, args *ordereddict.Dict) { + + defaults, ok := callable.(DefaultArgInterface) + if ok { + defaults.ApplyDefaults(ctx, scope, args) + } +} diff --git a/types/frozen.go b/types/frozen.go deleted file mode 100644 index 96660d1..0000000 --- a/types/frozen.go +++ /dev/null @@ -1,26 +0,0 @@ -package types - -import ( - "context" -) - -// A FrozenStoredQuery is a stored query which will be evaluated -// inside the defined scope instead of the calling scope. -type FrozenStoredQuery struct { - query StoredQuery - defined_scope Scope -} - -func (self FrozenStoredQuery) Query() StoredQuery { - return self.query -} - -func (self FrozenStoredQuery) Eval( - ctx context.Context, scope Scope) <-chan Row { - return self.query.Eval(ctx, self.defined_scope) -} - -func NewFrozenStoredQuery( - query StoredQuery, scope Scope) StoredQuery { - return &FrozenStoredQuery{query: query, defined_scope: scope} -} diff --git a/vfilter.go b/vfilter.go index 6829eb2..5007f2f 100644 --- a/vfilter.go +++ b/vfilter.go @@ -295,15 +295,25 @@ type _Comment struct { MultiLine *string `@MLineComment )` } +type LetParameter struct { + DefaultArg *_Args ` ( @@ | ` + Name *string ` @Ident )` +} + // An opaque object representing the VQL expression. type VQL struct { - Let string `LET @Ident ` - Parameters *_ParameterList `{ "(" @@ ")" }` - LetOperator string ` ( @"=" | @"<=" ) ` - StoredQuery *_Select ` ( @@ | ` - Expression *_AndExpression ` @@ ) |` - Query *_Select ` @@ ` - Comments []*_Comment + Let string `LET @Ident ` + LetParameters []*LetParameter ` [ "(" [ @@ { "," @@ } ] ")" ] ` + LetOperator string ` ( @"=" | @"<=" ) ` + StoredQuery *_Select ` ( @@ | ` + Expression *_AndExpression ` @@ ) |` + Query *_Select ` @@ ` + Comments []*_Comment + + // JIT Compile these for faster execution + mu sync.Mutex + argsCache map[string]*_Args + parametersCache []string } type _ParameterList struct { @@ -342,7 +352,9 @@ func (self *VQL) Eval(ctx context.Context, scope types.Scope) <-chan Row { // If this is a Let expression we need to create a stored // query and assign to the scope. if len(self.Let) > 0 { - if self.Parameters != nil && self.LetOperator == "<=" { + parameters, defaults := self.getParameters() + + if parameters != nil && self.LetOperator == "<=" { scope.Log("WARN:Expression %v takes parameters but is "+ "materialized! Did you mean to use '='? ", self.Let) } @@ -357,12 +369,10 @@ func (self *VQL) Eval(ctx context.Context, scope types.Scope) <-chan Row { // Let assigning an expression. if self.Expression != nil { expr := &StoredExpression{ - Expr: self.Expression, - name: name, - } - - if self.Parameters != nil { - expr.parameters = self.getParameters() + Expr: self.Expression, + name: name, + parameters: parameters, + defaults: defaults, } switch self.LetOperator { @@ -393,9 +403,7 @@ func (self *VQL) Eval(ctx context.Context, scope types.Scope) <-chan Row { switch self.LetOperator { case "=": stored_query := NewStoredQuery(self.StoredQuery, name) - if self.Parameters != nil { - stored_query.parameters = self.getParameters() - } + stored_query.parameters, stored_query.defaults = self.getParameters() scope.AppendVars(ordereddict.NewDict().Set(name, stored_query)) case "<=": @@ -427,7 +435,11 @@ func (self *VQL) Eval(ctx context.Context, scope types.Scope) <-chan Row { if !ok { return } - output_chan <- row + select { + case <-ctx.Done(): + return + case output_chan <- row: + } } } }() @@ -436,22 +448,36 @@ func (self *VQL) Eval(ctx context.Context, scope types.Scope) <-chan Row { } } -// Walk the parameters list and collect all the parameter names. -func visitor(parameters *_ParameterList, result *[]string) { - *result = append(*result, parameters.Left) - if parameters.Right != nil { - visitor(parameters.Right.Term, result) +func (self *VQL) getParameters() ([]string, map[string]*_Args) { + if self.Let == "" || len(self.LetParameters) == 0 { + return nil, nil } -} -func (self *VQL) getParameters() []string { - result := []string{} + self.mu.Lock() + defer self.mu.Unlock() - if self.Let != "" && self.Parameters != nil { - visitor(self.Parameters, &result) + if self.argsCache != nil { + return self.parametersCache, self.argsCache } - return result + self.argsCache = make(map[string]*_Args) + + for _, arg := range self.LetParameters { + + // Two possibilities - either the arg has a default or not. + if arg.Name != nil { + name := utils.Unquote_ident(*arg.Name) + + self.parametersCache = append(self.parametersCache, name) + } else if arg.DefaultArg != nil { + name := utils.Unquote_ident(arg.DefaultArg.Left) + self.parametersCache = append(self.parametersCache, name) + + self.argsCache[name] = arg.DefaultArg + } + } + + return self.parametersCache, self.argsCache } type _Select struct { @@ -1095,86 +1121,101 @@ func (self *Plugin) Eval(ctx context.Context, scope types.Scope) <-chan Row { } if self.Call { - return self.evalSymbol( - ctx, scope, - symbol, self.Name, buildArgsFromParameters(ctx, scope, self.Args)) + return self.evalSymbolWithArgs(ctx, scope, symbol, self.Name) } - return self.evalSymbol(ctx, scope, symbol, self.Name, nil) + return self.evalSymbol(ctx, scope, symbol, self.Name) } -func (self *Plugin) evalSymbol( +func (self *Plugin) evalSymbolWithArgs( ctx context.Context, scope types.Scope, - symbol types.Any, name string, args *ordereddict.Dict) <-chan Row { - - output_chan := make(chan Row) + symbol types.Any, name string) <-chan Row { if scope.CheckForOverflow() { + output_chan := make(chan Row) close(output_chan) return output_chan } // We need to call the symbol depending on what it is. - if args != nil { - switch t := symbol.(type) { + switch t := symbol.(type) { - // Stored Expression e.g. LET Foo(X) = X + 1 - case types.StoredExpression: - subscope := scope.Copy() - defer subscope.Close() + // Stored Expression e.g. LET Foo(X) = X + 1 + case types.StoredExpression: + subscope := scope.Copy() + defer subscope.Close() - subscope.AppendVars(args) - return self.evalSymbol( - ctx, scope, t.Reduce(ctx, subscope), name, nil) + args := buildArgsFromParameters(ctx, scope, self.Args) + types.MaybeApplyDefaultArgs(t, ctx, scope, args) - // A plugin like item - case PluginGeneratorInterface: - scope.GetStats().IncPluginsCalled() + subscope.AppendVars(args) + return self.evalSymbol(ctx, scope, t.Reduce(ctx, subscope), name) - return t.Call(ctx, scope, args) + // A plugin like item + case PluginGeneratorInterface: + scope.GetStats().IncPluginsCalled() - default: - scope.Log("ERROR:Symbol %v is not callable", name) - close(output_chan) - return output_chan - } + args := buildArgsFromParameters(ctx, scope, self.Args) + types.MaybeApplyDefaultArgs(t, ctx, scope, args) - // Symbol is not called - } else { + return t.Call(ctx, scope, args) - switch t := symbol.(type) { - case types.StoredExpression: - return self.evalSymbol(ctx, scope, t.Reduce(ctx, scope), name, nil) + default: + scope.Log("ERROR:Symbol %v is not callable as a plugin", name) + utils.DlvBreak() - case StoredQuery: - return t.Eval(ctx, scope) + output_chan := make(chan Row) + close(output_chan) + return output_chan + } +} - } +// Evaluate the symbol with the current scope. +func (self *Plugin) evalSymbol( + ctx context.Context, scope types.Scope, + symbol types.Any, name string) <-chan Row { + + if scope.CheckForOverflow() { + output_chan := make(chan Row) + close(output_chan) + return output_chan } - go func() { - defer close(output_chan) + switch t := symbol.(type) { + case types.StoredExpression: + return self.evalSymbol(ctx, scope, t.Reduce(ctx, scope), name) - if utils.IsArray(symbol) { - var_slice := reflect.ValueOf(symbol) - for i := 0; i < var_slice.Len(); i++ { - select { - case <-ctx.Done(): - return - case output_chan <- var_slice.Index(i).Interface(): + case StoredQuery: + return t.Eval(ctx, scope) + + default: + // Send the value to the caller. + output_chan := make(chan Row) + + go func() { + defer close(output_chan) + + if utils.IsArray(symbol) { + var_slice := reflect.ValueOf(symbol) + for i := 0; i < var_slice.Len(); i++ { + select { + case <-ctx.Done(): + return + case output_chan <- var_slice.Index(i).Interface(): + } } + return } - return - } - select { - case <-ctx.Done(): - return - case output_chan <- symbol: - } + select { + case <-ctx.Done(): + return + case output_chan <- symbol: + } - }() + }() - return output_chan + return output_chan + } } func (self *_MemberExpression) IsAggregate(scope types.Scope) bool { @@ -1644,8 +1685,10 @@ func (self *_SymbolRef) Reduce(ctx context.Context, scope types.Scope) Any { return &Null{} } - subscope.AppendVars(self.buildArgsFromParameters( - ctx, scope)) + args := self.buildArgsFromParameters(ctx, scope) + types.MaybeApplyDefaultArgs(t, ctx, scope, args) + + subscope.AppendVars(args) scope.GetStats().IncFunctionsCalled() @@ -1672,6 +1715,8 @@ func (self *_SymbolRef) Reduce(ctx context.Context, scope types.Scope) Any { } vars := self.buildArgsFromParameters(ctx, scope) + types.MaybeApplyDefaultArgs(t, ctx, scope, vars) + subscope.AppendVars(vars) scope.GetStats().IncFunctionsCalled() @@ -1716,30 +1761,32 @@ func (self *_SymbolRef) buildArgsFromParameters( func buildArgsFromParameters( ctx context.Context, - scope types.Scope, parameters []*_Args) *ordereddict.Dict { + scope types.Scope, + parameters []*_Args) *ordereddict.Dict { args := ordereddict.NewDict() // When calling into a VQL stored function, we materialize all // args. for _, arg := range parameters { + name := utils.Unquote_ident(arg.Left) + // e.g. X=func(foo=Bar) // This is evaluated at the point of definition. if arg.Right != nil { - name := utils.Unquote_ident(arg.Left) args.Set(name, arg.Right.Reduce(ctx, scope)) // e.g. X={ SELECT * FROM ... } } else if arg.SubSelect != nil { - args.Set(arg.Left, arg.SubSelect) + args.Set(name, arg.SubSelect) // e.g. X=[1,2,3,4] } else if arg.Array != nil { value := arg.Array.Reduce(ctx, scope) - args.Set(arg.Left, value) + args.Set(name, value) } else if arg.ArrayOpenBrace != "" { - args.Set(arg.Left, []Row{}) + args.Set(name, []Row{}) } } diff --git a/vfilter_test.go b/vfilter_test.go index 5b8af75..4ab7ab7 100644 --- a/vfilter_test.go +++ b/vfilter_test.go @@ -31,7 +31,8 @@ type execTest struct { } var compareOptions = cmpopts.IgnoreUnexported( - _Value{}, Plugin{}, _SymbolRef{}, _AliasedExpression{}) + _Value{}, Plugin{}, _SymbolRef{}, _AliasedExpression{}, VQL{}, +) var execTestsSerialization = []execTest{ {"1 or sleep(a=100)", true}, @@ -1344,6 +1345,24 @@ SELECT (dict(Foo=12), ) + X, X + (dict(Foo=12),) FROM scope() LET X <= SELECT value AS Foo FROM range(start=1, end=4) SELECT (dict(Foo=12), ) + X, X + (dict(Foo=12),) FROM scope() +`}, + + // The F() call is not valid and will trigger the verifier but it + // is allowed. The arg with no default will be evaluated in the + // scope of the caller. + {"Function definition with defaults", ` +LET Y <= 7 + +LET F(X = 2, Y ) = X+Y +SELECT F(Y=1), F(), F(X=1, Y=4) FROM scope() +`}, + + {"Stored Query definition with defaults", ` +LET F(X = 2, Y ) = + SELECT X + Y AS Total + FROM scope() + +SELECT * FROM F(Y=1) `}, } @@ -1544,7 +1563,7 @@ func TestMultiVQLQueries(t *testing.T) { // Store the result in ordered dict so we have a consistent golden file. result := ordereddict.NewDict() for i, testCase := range multiVQLTest { - if false && i != 88 { + if false && i != 66 { continue } scope := makeTestScope() diff --git a/visitor.go b/visitor.go index adeb15b..f7b154f 100644 --- a/visitor.go +++ b/visitor.go @@ -48,17 +48,28 @@ type FormatOptions struct { test bool } +// The CallSite describes the place where a callable is called +// from. For example, SELECT Foo(X=45) FROM scope() type CallSite struct { Type string Name string Args []string } +// The DefinitionSite describes where a function is declared: +// For example: LET Foo(X, Y) = .... +type DefinitionSite struct { + Type string + Name string + Args []string + Defaults []string +} + type Visitor struct { CallSites []CallSite // A list of LET definitions - Definitions []CallSite + Definitions []DefinitionSite // Tokens added to the visitor as we encounter each token during // parsing. Combining all the Fragments yields a reformatted @@ -227,9 +238,6 @@ func (self *Visitor) Visit(node interface{}) { self.line_break() } - case *types.FrozenStoredQuery: - self.Visit(t.Query()) - case *_StoredQuery: self.Visit(t.query) @@ -1049,24 +1057,39 @@ func (self *Visitor) visitVQL(node *VQL) { if node.Expression != nil || node.StoredQuery != nil { self.push("LET ", node.Let) - if node.Parameters != nil { + parameters, defaults := node.getParameters() + + if defaults != nil { self.push("(") - parameters := node.getParameters() if self.opts.CollectCallSites { - callsite := CallSite{ + defsite := DefinitionSite{ Type: "definition", Name: node.Let, } for _, p := range parameters { - callsite.Args = append(callsite.Args, + defsite.Args = append(defsite.Args, utils.Unquote_ident(p)) } - self.Definitions = append(self.Definitions, callsite) + + for k := range defaults { + defsite.Defaults = append(defsite.Defaults, + utils.Unquote_ident(k)) + } + self.Definitions = append(self.Definitions, defsite) } for idx, p := range parameters { - self.push(p) + // Is it an arg with default? + def_value, pres := defaults[p] + if pres { + self.Visit(def_value) + + } else { + // Otherwise just emit the plain name + self.push(p) + } + if idx < len(parameters)-1 { self.push(",", " ") } @@ -1074,7 +1097,7 @@ func (self *Visitor) visitVQL(node *VQL) { self.push(")") } else if self.opts.CollectCallSites { - self.Definitions = append(self.Definitions, CallSite{ + self.Definitions = append(self.Definitions, DefinitionSite{ Type: "definition", Name: node.Let, })