From d0e89a60bb34dd958f69714e5cfc5fa42e58fec7 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Sat, 5 Jul 2025 12:35:55 +0200 Subject: [PATCH] ENH,MAINT: Rename "relevant" to "dispatch" args and expose in context This is useful for generic wrappers. I still wan to hold back to do all the work for them, but the information is rather readily available, so I think we can pass it on. NumPy uses "relevant args" naming for the arguments we dispatch on, I decided to rename it to "dispatch" instead. (Also changed to pass a tuple of types, because I think that is more future proof. We expect few types, a set isn't useful e.g. if done in C.) --- src/spatch/backend_system.py | 101 ++++++++++++++++++++--------------- src/spatch/testing.py | 39 ++++++++++++++ tests/test_context.py | 51 ++++++++++++++++++ tests/test_priority.py | 39 +------------- 4 files changed, 148 insertions(+), 82 deletions(-) create mode 100644 src/spatch/testing.py create mode 100644 tests/test_context.py diff --git a/src/spatch/backend_system.py b/src/spatch/backend_system.py index a729f26..84882c0 100644 --- a/src/spatch/backend_system.py +++ b/src/spatch/backend_system.py @@ -8,7 +8,7 @@ import warnings import textwrap from types import MethodType -from typing import Callable +from typing import Any, Callable from spatch import from_identifier, get_identifier from spatch.utils import TypeIdentifier, valid_backend_name @@ -50,16 +50,16 @@ def from_namespace(cls, info): requires_opt_in=info.requires_opt_in, ) - def known_type(self, relevant_type): - if relevant_type in self.primary_types: + def known_type(self, dispatch_type): + if dispatch_type in self.primary_types: return "primary" # TODO: maybe make it an enum? - elif relevant_type in self.secondary_types: + elif dispatch_type in self.secondary_types: return "secondary" else: return False - def matches(self, relevant_types): - matches = frozenset(self.known_type(t) for t in relevant_types) + def matches(self, dispatch_types): + matches = frozenset(self.known_type(t) for t in dispatch_types) if "primary" in matches and False not in matches: return True return False @@ -594,19 +594,19 @@ def _get_entry_points(group, blocked): return sorted(backends, key=lambda x: x.name) @functools.lru_cache(maxsize=128) - def known_type(self, relevant_type, primary=False): + def known_type(self, dispatch_type, primary=False): for backend in self.backends.values(): - if backend.known_type(relevant_type): + if backend.known_type(dispatch_type): return True return False - def get_known_unique_types(self, relevant_types): - # From a list of args, return only the set of relevant types - return frozenset(val for val in relevant_types if self.known_type(val)) + def get_known_unique_types(self, dispatch_types): + # From a list of args, return only the set of dispatch types + return frozenset(val for val in dispatch_types if self.known_type(val)) @functools.lru_cache(maxsize=128) - def get_types_and_backends(self, relevant_types, ordered_backends): - """Fetch relevant types and matching backends. + def get_types_and_backends(self, dispatch_types, ordered_backends): + """Fetch dispatch types and matching backends. The main purpose of this function is to cache the results for a set of unique input types to functions. @@ -617,18 +617,18 @@ def get_types_and_backends(self, relevant_types, ordered_backends): Returns ------- - relevant_types : frozenset - The set of relevant types that are known to the backend system. + dispatch_types : frozenset + The set of dispatch types that are known to the backend system. matching_backends : tuple A tuple of backend names sorted by priority. """ # Filter out unknown types: - relevant_types = self.get_known_unique_types(relevant_types) + dispatch_types = self.get_known_unique_types(dispatch_types) matching_backends = tuple( - n for n in ordered_backends if self.backends[n].matches(relevant_types) + n for n in ordered_backends if self.backends[n].matches(dispatch_types) ) - return relevant_types, matching_backends + return dispatch_types, matching_backends def backend_from_namespace(self, info_namespace): new_backend = Backend.from_namespace(info_namespace) @@ -640,19 +640,20 @@ def backend_from_namespace(self, info_namespace): return self.backends[new_backend.name] = new_backend - def dispatchable(self, relevant_args=None, *, module=None, qualname=None): + def dispatchable(self, dispatch_args=None, *, module=None, qualname=None): """ Decorator to mark functions as dispatchable. Decorate a Python function with information on how to extract - the "relevant" arguments, i.e. arguments we wish to dispatch for. + the "dispatch" arguments, i.e. arguments we wish to dispatch for. Parameters ---------- - relevant_args : str, list, tuple, or None + dispatch_args : str, list, tuple, or None The names of parameters to extract (we use inspect to map these correctly). - If ``None`` all parameters will be considered relevant. + If ``None`` all parameters will be considered relevant for + dispatching. module : str Override the module of the function (actually modifies it) to ensure a well defined and stable public API. @@ -675,7 +676,7 @@ def wrap_callable(func): if qualname is not None: func.__qualname__ = qualname - disp = Dispatchable(self, func, relevant_args) + disp = Dispatchable(self, func, dispatch_args) return disp @@ -703,7 +704,7 @@ class DispatchContext: Attributes ---------- - types : Sequence[type] + types : tuple[type, ...] The (unique) types we are dispatching for. It is possible that not all types are passed as arguments if the user is requesting a specific type. @@ -718,8 +719,17 @@ class DispatchContext: Backends that strictly match a single primary type can safely ignore this (they always return the same type). + dispatch_args : tuple[Any, ...] + The arguments for which we dispatched. This can be useful information + for some generic wrappers who still need to inspect all dispatch arguments. + .. note:: - This is a frozenset currently, but please consider it a sequence. + ``dispatch_args`` can be empty if a function takes no arguments. + Yet, a backend version may be called explicitly e.g. in a + ``with backend_opts(type=): ...`` context. + + name : str + The name of the backend that was selected. prioritized : bool Whether the backend is prioritized. You may use this for example when @@ -730,7 +740,8 @@ class DispatchContext: # The idea is for the context to be very light-weight so that specific # information should be properties (because most likely we will never need it). # This object can grow to provide more information to backends. - types: tuple[type] + types: tuple[type, ...] + dispatch_args: tuple[Any, ...] name: str _state: tuple @@ -797,7 +808,7 @@ class Dispatchable: # # TODO: We may want to return a function just to be nice (not having a func was # OK in NumPy for example, but has a few little stumbling blocks) - def __init__(self, backend_system, func, relevant_args, ident=None): + def __init__(self, backend_system, func, dispatch_args, ident=None): functools.update_wrapper(self, func) self._backend_system = backend_system @@ -807,11 +818,11 @@ def __init__(self, backend_system, func, relevant_args, ident=None): self._ident = ident - if isinstance(relevant_args, str): - relevant_args = {relevant_args: 0} - elif isinstance(relevant_args, list | tuple): - relevant_args = {val: i for i, val in enumerate(relevant_args)} - self._relevant_args = relevant_args + if isinstance(dispatch_args, str): + dispatch_args = {dispatch_args: 0} + elif isinstance(dispatch_args, list | tuple): + dispatch_args = {val: i for i, val in enumerate(dispatch_args)} + self._dispatch_args = dispatch_args new_doc = [] impl_infos = {} @@ -851,27 +862,29 @@ def __get__(self, obj, objtype=None): return self return MethodType(self, obj) - def _get_relevant_types(self, *args, **kwargs): - # Return all relevant types, these are not filtered by the known_types - if self._relevant_args is None: - return set(type(val) for val in args) | set(type(k) for k in kwargs.values()) + def _get_dispatch_args(self, *args, **kwargs): + # Return all dispatch args + if self._dispatch_args is None: + return args + tuple(kwargs.values()) else: - return set( - type(val) for name, pos in self._relevant_args.items() + return tuple( + val for name, pos in self._dispatch_args.items() if (val := args[pos] if pos < len(args) else kwargs.get(name)) is not None ) def __call__(self, *args, **kwargs): - relevant_types = self._get_relevant_types(*args, **kwargs) + dispatch_args = self._get_dispatch_args(*args, **kwargs) + # At this point dispatch_types is not filtered for known types. + dispatch_types = set(type(val) for val in dispatch_args) state = self._backend_system._dispatch_state.get() ordered_backends, type_, prioritized, trace = state if type_ is not None: - relevant_types.add(type_) - relevant_types = frozenset(relevant_types) + dispatch_types.add(type_) + dispatch_types = frozenset(dispatch_types) - relevant_types, matching_backends = self._backend_system.get_types_and_backends( - relevant_types, ordered_backends) + dispatch_types, matching_backends = self._backend_system.get_types_and_backends( + dispatch_types, ordered_backends) if trace is not None: call_trace = [] @@ -886,7 +899,7 @@ def __call__(self, *args, **kwargs): # may want to optimize this (in case many backends have few functions). continue - context = DispatchContext(relevant_types, state, name) + context = DispatchContext(tuple(dispatch_types), dispatch_args, name, state) should_run = impl.should_run if should_run is None or (should_run := should_run(context, *args, **kwargs)) is True: diff --git a/src/spatch/testing.py b/src/spatch/testing.py new file mode 100644 index 0000000..de75a75 --- /dev/null +++ b/src/spatch/testing.py @@ -0,0 +1,39 @@ +from spatch.utils import get_identifier + +class _FuncGetter: + def __init__(self, get): + self.get = get + + +class BackendDummy: + """Helper to construct a minimal "backend" for testing. + + Forwards any lookup to the class. Documentation are used from + the function which must match in the name. + """ + def __init__(self): + self.functions = _FuncGetter(self.get_function) + + @classmethod + def get_function(cls, name, default=None): + # Simply ignore the module for testing purposes. + _, name = name.split(":") + + # Not get_identifier because it would find the super-class name. + res = {"function": f"{cls.__module__}:{cls.__name__}.{name}" } + if hasattr(cls, "uses_context"): + res["uses_context"] = cls.uses_context + if hasattr(cls, "should_run"): + res["should_run"] = get_identifier(cls.should_run) + + func = getattr(cls, name) + if func.__doc__ is not None: + res["additional_docs"] = func.__doc__ + + return res + + @classmethod + def dummy_func(cls, *args, **kwargs): + # Always define a small function that mainly forwards. + return cls.name, args, kwargs + diff --git a/tests/test_context.py b/tests/test_context.py new file mode 100644 index 0000000..e425e0a --- /dev/null +++ b/tests/test_context.py @@ -0,0 +1,51 @@ +import pytest + +from spatch.backend_system import BackendSystem +from spatch.testing import BackendDummy + + +class FloatWithContext(BackendDummy): + name = "FloatWithContext" + primary_types = ("~builtins:float",) + secondary_types = ("builtins:int",) + uses_context = True + requires_opt_in = False + + +def test_context_basic(): + bs = BackendSystem( + None, + environ_prefix="SPATCH_TEST", + default_primary_types=("builtin:int",), + backends=[FloatWithContext()] + ) + + # Add a dummy dispatchable function that dispatches on all arguments. + @bs.dispatchable(None, module="", qualname="dummy_func") + def dummy_func(*args, **kwargs): + return "fallback", args, kwargs + + _, (ctx, *args), kwargs = dummy_func(1, 1.) + assert ctx.name == "FloatWithContext" + assert set(ctx.types) == {int, float} + assert ctx.dispatch_args == (1, 1.) + assert not ctx.prioritized + + class float_subclass(float): + pass + + with bs.backend_opts(prioritize=("FloatWithContext",)): + _, (ctx, *args), kwargs = dummy_func(float_subclass(1.)) + assert ctx.name == "FloatWithContext" + assert set(ctx.types) == {float_subclass} + assert ctx.dispatch_args == (float_subclass(1.),) + assert ctx.prioritized + + with bs.backend_opts(type=float): + # No argument, works if explicitly prioritized... + _, (ctx, *args), kwargs = dummy_func() + assert ctx.name == "FloatWithContext" + assert set(ctx.types) == {float} + assert ctx.dispatch_args == () + assert not ctx.prioritized # not prioritized "just" type enforced + diff --git a/tests/test_priority.py b/tests/test_priority.py index fdef874..9ef36da 100644 --- a/tests/test_priority.py +++ b/tests/test_priority.py @@ -1,44 +1,7 @@ import pytest from spatch.backend_system import BackendSystem -from spatch.utils import get_identifier - -class FuncGetter: - def __init__(self, get): - self.get = get - - -class BackendDummy: - """Helper to construct a minimal "backend" for testing. - - Forwards any lookup to the class. Documentation are used from - the function which must match in the name. - """ - def __init__(self): - self.functions = FuncGetter(self.get_function) - - @classmethod - def get_function(cls, name, default=None): - # Simply ignore the module for testing purposes. - _, name = name.split(":") - - # Not get_identifier because it would find the super-class name. - res = {"function": f"{cls.__module__}:{cls.__name__}.{name}" } - if hasattr(cls, "uses_context"): - res["uses_context"] = cls.uses_context - if hasattr(cls, "should_run"): - res["should_run"] = get_identifier(cls.should_run) - - func = getattr(cls, name) - if func.__doc__ is not None: - res["additional_docs"] = func.__doc__ - - return res - - @classmethod - def dummy_func(cls, *args, **kwargs): - # Always define a small function that mainly forwards. - return cls.name, args, kwargs +from spatch.testing import BackendDummy class IntB(BackendDummy):