@@ -371,49 +371,63 @@ def _match_function(
371371
372372 # Check positional arguments
373373 if is_scalar :
374- # Scalar functions: only ConstParams are matched against
375- # invocation.arguments. Column params come from input batches.
376- required_positional = [
377- p for p in positional_params if p .required and p .is_const
378- ]
379- # For scalar functions, column params can optionally be passed
380- # as arguments (for literals) or come from batches
381- max_positional = len (positional_params )
374+ # Scalar functions have two variants:
375+ # 1. PolarsScalarFunction (has _polars_params): Column bindings are
376+ # declared in the class, so only ConstParams are passed as args.
377+ # 2. Regular ScalarFunction (has _compute_params only): Column NAMES
378+ # are passed as positional args to specify which columns to bind.
379+ #
380+ # All scalar params are always required (no defaults).
381+ # Scalar functions don't support named arguments.
382+ is_polars_scalar = getattr (func_cls , "_polars_params" , None ) is not None
383+
384+ if is_polars_scalar :
385+ # PolarsScalarFunction: only count ConstParams for matching
386+ const_positional = [p for p in positional_params if p .is_const ]
387+ expected_positional = len (const_positional )
388+ if num_positional != expected_positional :
389+ continue # Must match exactly
390+ else :
391+ # Regular ScalarFunction: count ALL params
392+ # (column names + ConstParams)
393+ has_varargs = any (p .is_varargs for p in positional_params )
394+ expected_positional = len (positional_params )
395+ if has_varargs :
396+ # With varargs, need at least expected params
397+ if num_positional < expected_positional :
398+ continue
399+ else :
400+ if num_positional != expected_positional :
401+ continue # Must match exactly
402+
403+ # Scalar functions don't support named arguments
404+ if named_keys :
405+ continue
382406 else :
383407 # Table functions: all params come from invocation.arguments
384408 required_positional = [p for p in positional_params if p .required ]
409+ min_positional = len (required_positional )
385410 max_positional = len (positional_params )
411+ has_varargs = any (p .is_varargs for p in positional_params )
386412
387- has_varargs = any (p .is_varargs for p in positional_params )
388- min_positional = len (required_positional )
389-
390- if has_varargs :
391- # Varargs: allow any number >= min_positional
392- if num_positional < min_positional :
393- continue # Too few positional arguments
394- else :
395- # Fixed positional: must be within [min, max]
396- if not (min_positional <= num_positional <= max_positional ):
397- continue # Wrong number of positional arguments
413+ if has_varargs :
414+ if num_positional < min_positional :
415+ continue # Too few positional arguments
416+ else :
417+ if not (min_positional <= num_positional <= max_positional ):
418+ continue # Wrong number of positional arguments
398419
399- # Check named arguments
400- if is_scalar :
401- # Scalar: only match ConstParams for named arguments
402- valid_named_keys = {p .position for p in named_params if p .is_const }
403- required_named_keys = {
404- p .position for p in named_params if p .required and p .is_const
405- }
406- else :
420+ # Check named arguments
407421 valid_named_keys = {p .position for p in named_params }
408422 required_named_keys = {p .position for p in named_params if p .required }
409423
410- # All provided named args must be valid
411- if not named_keys .issubset (valid_named_keys ):
412- continue # Unknown named argument
424+ # All provided named args must be valid
425+ if not named_keys .issubset (valid_named_keys ):
426+ continue # Unknown named argument
413427
414- # All required named args must be provided
415- if not required_named_keys .issubset (named_keys ):
416- continue # Missing required named argument
428+ # All required named args must be provided
429+ if not required_named_keys .issubset (named_keys ):
430+ continue # Missing required named argument
417431
418432 matches .append (func_cls )
419433
0 commit comments