Skip to content

Commit bbb454e

Browse files
rustyconover and claude committed
Improve function matching error message to show argument values and types
Add a _format_arguments_for_error() helper that formats Arguments instances, showing the actual values and their Arrow types. The error message now displays `positional=[3 (int64), 'hello' (string)], named={sep: ',' (string)}` instead of the previous `1 positional, named=[]`. This makes it much easier to debug why a function call failed to match any overload. This commit also simplifies scalar function argument matching by removing the special case for PolarsScalarFunction (column bindings are now handled uniformly). Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent da6c200 commit bbb454e

1 file changed

Lines changed: 73 additions & 19 deletions

File tree

vgi/worker.py

Lines changed: 73 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -96,6 +96,7 @@ class MyWorker(Worker):
9696
from pyarrow import ipc
9797

9898
from vgi import tracing
99+
from vgi.arguments import Arguments
99100
from vgi.catalog import CatalogInterface
100101
from vgi.catalog.setting import SettingSpec, extract_setting_specs
101102
from vgi.exceptions import SchemaValidationError
@@ -124,6 +125,67 @@ class MyWorker(Worker):
124125
_BIND_ERROR_SCHEMA = pa.schema([("_error", pa.null())])
125126

126127

128+
def _format_scalar_for_error(scalar: "pa.Scalar | None") -> str:
    """Render a single argument scalar as ``value (type)`` for an error message.

    Args:
        scalar: The argument value, or None when no scalar was supplied.

    Returns:
        ``"null"`` for None, ``"null (<type>)"`` for an invalid (null) scalar,
        otherwise the Python value followed by its type in parentheses.
        Strings and short bytes are shown via ``repr``; bytes longer than
        20 bytes are summarised as ``<N bytes>`` to keep the message short.
    """
    max_bytes_shown = 20

    if scalar is None:
        return "null"
    if not scalar.is_valid:
        return f"null ({scalar.type})"

    value = scalar.as_py()
    type_name = str(scalar.type)
    if isinstance(value, bytes) and len(value) > max_bytes_shown:
        # Large binary payloads would drown out the rest of the message.
        return f"<{len(value)} bytes> ({type_name})"
    if isinstance(value, (str, bytes)):
        # repr() keeps quotes and escapes visible (e.g. ',' or '\n').
        return f"{value!r} ({type_name})"
    return f"{value} ({type_name})"


def _format_arguments_for_error(args: "Arguments") -> str:
    """Format Arguments for error messages, showing values and types.

    Produces output like:
        positional=[3 (int64), 'hello' (string)], named={sep: ',' (string)}

    Positional and named arguments are rendered by the same per-scalar
    formatter, so None values and long bytes are handled identically in both
    sections. Named arguments are listed in sorted name order.

    Args:
        args: The Arguments instance to format. Its ``positional`` sequence
            and ``named`` mapping may each be empty.

    Returns:
        Human-readable string showing argument values and types.
    """
    positional = ", ".join(
        _format_scalar_for_error(scalar) for scalar in (args.positional or ())
    )
    named = ", ".join(
        f"{name}: {_format_scalar_for_error(scalar)}"
        for name, scalar in sorted((args.named or {}).items())
    )
    return f"positional=[{positional}], named={{{named}}}"
187+
188+
127189
def _inject_trace_context(
128190
batch: pa.RecordBatch, traceparent: str, tracestate: str | None
129191
) -> pa.RecordBatch:
@@ -373,32 +435,24 @@ def _match_function(
373435
if is_scalar:
374436
# Scalar functions have two variants:
375437
# 1. PolarsScalarFunction (has _polars_params): Column bindings are
376-
# declared in the class, so only ConstParams are passed as args.
438+
# declared in the class.
377439
# 2. Regular ScalarFunction (has _compute_params only): Column NAMES
378440
# are passed as positional args to specify which columns to bind.
379441
#
380442
# All scalar params are always required (no defaults).
381443
# Scalar functions don't support named arguments.
382-
is_polars_scalar = getattr(func_cls, "_polars_params", None) is not None
383444

384-
if is_polars_scalar:
385-
# PolarsScalarFunction: only count ConstParams for matching
386-
const_positional = [p for p in positional_params if p.is_const]
387-
expected_positional = len(const_positional)
445+
# Regular ScalarFunction: count ALL params
446+
# (column names + ConstParams)
447+
has_varargs = any(p.is_varargs for p in positional_params)
448+
expected_positional = len(positional_params)
449+
if has_varargs:
450+
# With varargs, need at least expected params
451+
if num_positional < expected_positional:
452+
continue
453+
else:
388454
if num_positional != expected_positional:
389455
continue # Must match exactly
390-
else:
391-
# Regular ScalarFunction: count ALL params
392-
# (column names + ConstParams)
393-
has_varargs = any(p.is_varargs for p in positional_params)
394-
expected_positional = len(positional_params)
395-
if has_varargs:
396-
# With varargs, need at least expected params
397-
if num_positional < expected_positional:
398-
continue
399-
else:
400-
if num_positional != expected_positional:
401-
continue # Must match exactly
402456

403457
# Scalar functions don't support named arguments
404458
if named_keys:
@@ -446,7 +500,7 @@ def _match_function(
446500

447501
raise ValueError(
448502
f"No matching function '{invocation.function_name}' for arguments: "
449-
f"{num_positional} positional, named={sorted(named_keys)}. "
503+
f"{_format_arguments_for_error(args)}. "
450504
f"Available overloads:\n" + "\n".join(param_summaries)
451505
)
452506

0 commit comments

Comments
 (0)