Skip to content

Commit 17ba204

Browse files
rustyconoverclaude
andcommitted
Fix function matching for scalar vs table functions
Scalar functions have two variants with different argument passing: 1. PolarsScalarFunction: Column bindings declared in class, only ConstParams passed as positional args 2. Regular ScalarFunction: Column names passed as positional args, with varargs support for variable argument counts Table functions retain existing matching logic with min/max positional counts, defaults, and named argument support. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 6dfbbc6 commit 17ba204

1 file changed

Lines changed: 47 additions & 33 deletions

File tree

vgi/worker.py

Lines changed: 47 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -371,49 +371,63 @@ def _match_function(
371371

372372
# Check positional arguments
373373
if is_scalar:
374-
# Scalar functions: only ConstParams are matched against
375-
# invocation.arguments. Column params come from input batches.
376-
required_positional = [
377-
p for p in positional_params if p.required and p.is_const
378-
]
379-
# For scalar functions, column params can optionally be passed
380-
# as arguments (for literals) or come from batches
381-
max_positional = len(positional_params)
374+
# Scalar functions have two variants:
375+
# 1. PolarsScalarFunction (has _polars_params): Column bindings are
376+
# declared in the class, so only ConstParams are passed as args.
377+
# 2. Regular ScalarFunction (has _compute_params only): Column NAMES
378+
# are passed as positional args to specify which columns to bind.
379+
#
380+
# All scalar params are always required (no defaults).
381+
# Scalar functions don't support named arguments.
382+
is_polars_scalar = getattr(func_cls, "_polars_params", None) is not None
383+
384+
if is_polars_scalar:
385+
# PolarsScalarFunction: only count ConstParams for matching
386+
const_positional = [p for p in positional_params if p.is_const]
387+
expected_positional = len(const_positional)
388+
if num_positional != expected_positional:
389+
continue # Must match exactly
390+
else:
391+
# Regular ScalarFunction: count ALL params
392+
# (column names + ConstParams)
393+
has_varargs = any(p.is_varargs for p in positional_params)
394+
expected_positional = len(positional_params)
395+
if has_varargs:
396+
# With varargs, need at least expected params
397+
if num_positional < expected_positional:
398+
continue
399+
else:
400+
if num_positional != expected_positional:
401+
continue # Must match exactly
402+
403+
# Scalar functions don't support named arguments
404+
if named_keys:
405+
continue
382406
else:
383407
# Table functions: all params come from invocation.arguments
384408
required_positional = [p for p in positional_params if p.required]
409+
min_positional = len(required_positional)
385410
max_positional = len(positional_params)
411+
has_varargs = any(p.is_varargs for p in positional_params)
386412

387-
has_varargs = any(p.is_varargs for p in positional_params)
388-
min_positional = len(required_positional)
389-
390-
if has_varargs:
391-
# Varargs: allow any number >= min_positional
392-
if num_positional < min_positional:
393-
continue # Too few positional arguments
394-
else:
395-
# Fixed positional: must be within [min, max]
396-
if not (min_positional <= num_positional <= max_positional):
397-
continue # Wrong number of positional arguments
413+
if has_varargs:
414+
if num_positional < min_positional:
415+
continue # Too few positional arguments
416+
else:
417+
if not (min_positional <= num_positional <= max_positional):
418+
continue # Wrong number of positional arguments
398419

399-
# Check named arguments
400-
if is_scalar:
401-
# Scalar: only match ConstParams for named arguments
402-
valid_named_keys = {p.position for p in named_params if p.is_const}
403-
required_named_keys = {
404-
p.position for p in named_params if p.required and p.is_const
405-
}
406-
else:
420+
# Check named arguments
407421
valid_named_keys = {p.position for p in named_params}
408422
required_named_keys = {p.position for p in named_params if p.required}
409423

410-
# All provided named args must be valid
411-
if not named_keys.issubset(valid_named_keys):
412-
continue # Unknown named argument
424+
# All provided named args must be valid
425+
if not named_keys.issubset(valid_named_keys):
426+
continue # Unknown named argument
413427

414-
# All required named args must be provided
415-
if not required_named_keys.issubset(named_keys):
416-
continue # Missing required named argument
428+
# All required named args must be provided
429+
if not required_named_keys.issubset(named_keys):
430+
continue # Missing required named argument
417431

418432
matches.append(func_cls)
419433

0 commit comments

Comments
 (0)