Skip to content

Commit d0e89a6

Browse files
committed
ENH,MAINT: Rename "relevant" to "dispatch" args and expose in context
This is useful for generic wrappers. I still wan to hold back to do all the work for them, but the information is rather readily available, so I think we can pass it on. NumPy uses "relevant args" naming for the arguments we dispatch on, I decided to rename it to "dispatch" instead. (Also changed to pass a tuple of types, because I think that is more future proof. We expect few types, a set isn't useful e.g. if done in C.)
1 parent fe991c2 commit d0e89a6

4 files changed

Lines changed: 148 additions & 82 deletions

File tree

src/spatch/backend_system.py

Lines changed: 57 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import warnings
99
import textwrap
1010
from types import MethodType
11-
from typing import Callable
11+
from typing import Any, Callable
1212

1313
from spatch import from_identifier, get_identifier
1414
from spatch.utils import TypeIdentifier, valid_backend_name
@@ -50,16 +50,16 @@ def from_namespace(cls, info):
5050
requires_opt_in=info.requires_opt_in,
5151
)
5252

53-
def known_type(self, relevant_type):
54-
if relevant_type in self.primary_types:
53+
def known_type(self, dispatch_type):
54+
if dispatch_type in self.primary_types:
5555
return "primary" # TODO: maybe make it an enum?
56-
elif relevant_type in self.secondary_types:
56+
elif dispatch_type in self.secondary_types:
5757
return "secondary"
5858
else:
5959
return False
6060

61-
def matches(self, relevant_types):
62-
matches = frozenset(self.known_type(t) for t in relevant_types)
61+
def matches(self, dispatch_types):
62+
matches = frozenset(self.known_type(t) for t in dispatch_types)
6363
if "primary" in matches and False not in matches:
6464
return True
6565
return False
@@ -594,19 +594,19 @@ def _get_entry_points(group, blocked):
594594
return sorted(backends, key=lambda x: x.name)
595595

596596
@functools.lru_cache(maxsize=128)
597-
def known_type(self, relevant_type, primary=False):
597+
def known_type(self, dispatch_type, primary=False):
598598
for backend in self.backends.values():
599-
if backend.known_type(relevant_type):
599+
if backend.known_type(dispatch_type):
600600
return True
601601
return False
602602

603-
def get_known_unique_types(self, relevant_types):
604-
# From a list of args, return only the set of relevant types
605-
return frozenset(val for val in relevant_types if self.known_type(val))
603+
def get_known_unique_types(self, dispatch_types):
604+
# From a list of args, return only the set of dispatch types
605+
return frozenset(val for val in dispatch_types if self.known_type(val))
606606

607607
@functools.lru_cache(maxsize=128)
608-
def get_types_and_backends(self, relevant_types, ordered_backends):
609-
"""Fetch relevant types and matching backends.
608+
def get_types_and_backends(self, dispatch_types, ordered_backends):
609+
"""Fetch dispatch types and matching backends.
610610
611611
The main purpose of this function is to cache the results for a set
612612
of unique input types to functions.
@@ -617,18 +617,18 @@ def get_types_and_backends(self, relevant_types, ordered_backends):
617617
618618
Returns
619619
-------
620-
relevant_types : frozenset
621-
The set of relevant types that are known to the backend system.
620+
dispatch_types : frozenset
621+
The set of dispatch types that are known to the backend system.
622622
matching_backends : tuple
623623
A tuple of backend names sorted by priority.
624624
"""
625625
# Filter out unknown types:
626-
relevant_types = self.get_known_unique_types(relevant_types)
626+
dispatch_types = self.get_known_unique_types(dispatch_types)
627627

628628
matching_backends = tuple(
629-
n for n in ordered_backends if self.backends[n].matches(relevant_types)
629+
n for n in ordered_backends if self.backends[n].matches(dispatch_types)
630630
)
631-
return relevant_types, matching_backends
631+
return dispatch_types, matching_backends
632632

633633
def backend_from_namespace(self, info_namespace):
634634
new_backend = Backend.from_namespace(info_namespace)
@@ -640,19 +640,20 @@ def backend_from_namespace(self, info_namespace):
640640
return
641641
self.backends[new_backend.name] = new_backend
642642

643-
def dispatchable(self, relevant_args=None, *, module=None, qualname=None):
643+
def dispatchable(self, dispatch_args=None, *, module=None, qualname=None):
644644
"""
645645
Decorator to mark functions as dispatchable.
646646
647647
Decorate a Python function with information on how to extract
648-
the "relevant" arguments, i.e. arguments we wish to dispatch for.
648+
the "dispatch" arguments, i.e. arguments we wish to dispatch for.
649649
650650
Parameters
651651
----------
652-
relevant_args : str, list, tuple, or None
652+
dispatch_args : str, list, tuple, or None
653653
The names of parameters to extract (we use inspect to
654654
map these correctly).
655-
If ``None`` all parameters will be considered relevant.
655+
If ``None`` all parameters will be considered relevant for
656+
dispatching.
656657
module : str
657658
Override the module of the function (actually modifies it)
658659
to ensure a well defined and stable public API.
@@ -675,7 +676,7 @@ def wrap_callable(func):
675676
if qualname is not None:
676677
func.__qualname__ = qualname
677678

678-
disp = Dispatchable(self, func, relevant_args)
679+
disp = Dispatchable(self, func, dispatch_args)
679680

680681
return disp
681682

@@ -703,7 +704,7 @@ class DispatchContext:
703704
704705
Attributes
705706
----------
706-
types : Sequence[type]
707+
types : tuple[type, ...]
707708
The (unique) types we are dispatching for. It is possible that
708709
not all types are passed as arguments if the user is requesting
709710
a specific type.
@@ -718,8 +719,17 @@ class DispatchContext:
718719
Backends that strictly match a single primary type can safely ignore this
719720
(they always return the same type).
720721
722+
dispatch_args : tuple[Any, ...]
723+
The arguments for which we dispatched. This can be useful information
724+
for some generic wrappers who still need to inspect all dispatch arguments.
725+
721726
.. note::
722-
This is a frozenset currently, but please consider it a sequence.
727+
``dispatch_args`` can be empty if a function takes no arguments.
728+
Yet, a backend version may be called explicitly e.g. in a
729+
``with backend_opts(type=): ...`` context.
730+
731+
name : str
732+
The name of the backend that was selected.
723733
724734
prioritized : bool
725735
Whether the backend is prioritized. You may use this for example when
@@ -730,7 +740,8 @@ class DispatchContext:
730740
# The idea is for the context to be very light-weight so that specific
731741
# information should be properties (because most likely we will never need it).
732742
# This object can grow to provide more information to backends.
733-
types: tuple[type]
743+
types: tuple[type, ...]
744+
dispatch_args: tuple[Any, ...]
734745
name: str
735746
_state: tuple
736747

@@ -797,7 +808,7 @@ class Dispatchable:
797808
#
798809
# TODO: We may want to return a function just to be nice (not having a func was
799810
# OK in NumPy for example, but has a few little stumbling blocks)
800-
def __init__(self, backend_system, func, relevant_args, ident=None):
811+
def __init__(self, backend_system, func, dispatch_args, ident=None):
801812
functools.update_wrapper(self, func)
802813

803814
self._backend_system = backend_system
@@ -807,11 +818,11 @@ def __init__(self, backend_system, func, relevant_args, ident=None):
807818

808819
self._ident = ident
809820

810-
if isinstance(relevant_args, str):
811-
relevant_args = {relevant_args: 0}
812-
elif isinstance(relevant_args, list | tuple):
813-
relevant_args = {val: i for i, val in enumerate(relevant_args)}
814-
self._relevant_args = relevant_args
821+
if isinstance(dispatch_args, str):
822+
dispatch_args = {dispatch_args: 0}
823+
elif isinstance(dispatch_args, list | tuple):
824+
dispatch_args = {val: i for i, val in enumerate(dispatch_args)}
825+
self._dispatch_args = dispatch_args
815826

816827
new_doc = []
817828
impl_infos = {}
@@ -851,27 +862,29 @@ def __get__(self, obj, objtype=None):
851862
return self
852863
return MethodType(self, obj)
853864

854-
def _get_relevant_types(self, *args, **kwargs):
855-
# Return all relevant types, these are not filtered by the known_types
856-
if self._relevant_args is None:
857-
return set(type(val) for val in args) | set(type(k) for k in kwargs.values())
865+
def _get_dispatch_args(self, *args, **kwargs):
866+
# Return all dispatch args
867+
if self._dispatch_args is None:
868+
return args + tuple(kwargs.values())
858869
else:
859-
return set(
860-
type(val) for name, pos in self._relevant_args.items()
870+
return tuple(
871+
val for name, pos in self._dispatch_args.items()
861872
if (val := args[pos] if pos < len(args) else kwargs.get(name)) is not None
862873
)
863874

864875
def __call__(self, *args, **kwargs):
865-
relevant_types = self._get_relevant_types(*args, **kwargs)
876+
dispatch_args = self._get_dispatch_args(*args, **kwargs)
877+
# At this point dispatch_types is not filtered for known types.
878+
dispatch_types = set(type(val) for val in dispatch_args)
866879
state = self._backend_system._dispatch_state.get()
867880
ordered_backends, type_, prioritized, trace = state
868881

869882
if type_ is not None:
870-
relevant_types.add(type_)
871-
relevant_types = frozenset(relevant_types)
883+
dispatch_types.add(type_)
884+
dispatch_types = frozenset(dispatch_types)
872885

873-
relevant_types, matching_backends = self._backend_system.get_types_and_backends(
874-
relevant_types, ordered_backends)
886+
dispatch_types, matching_backends = self._backend_system.get_types_and_backends(
887+
dispatch_types, ordered_backends)
875888

876889
if trace is not None:
877890
call_trace = []
@@ -886,7 +899,7 @@ def __call__(self, *args, **kwargs):
886899
# may want to optimize this (in case many backends have few functions).
887900
continue
888901

889-
context = DispatchContext(relevant_types, state, name)
902+
context = DispatchContext(tuple(dispatch_types), dispatch_args, name, state)
890903

891904
should_run = impl.should_run
892905
if should_run is None or (should_run := should_run(context, *args, **kwargs)) is True:

src/spatch/testing.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from spatch.utils import get_identifier
2+
3+
class _FuncGetter:
4+
def __init__(self, get):
5+
self.get = get
6+
7+
8+
class BackendDummy:
9+
"""Helper to construct a minimal "backend" for testing.
10+
11+
Forwards any lookup to the class. Documentation are used from
12+
the function which must match in the name.
13+
"""
14+
def __init__(self):
15+
self.functions = _FuncGetter(self.get_function)
16+
17+
@classmethod
18+
def get_function(cls, name, default=None):
19+
# Simply ignore the module for testing purposes.
20+
_, name = name.split(":")
21+
22+
# Not get_identifier because it would find the super-class name.
23+
res = {"function": f"{cls.__module__}:{cls.__name__}.{name}" }
24+
if hasattr(cls, "uses_context"):
25+
res["uses_context"] = cls.uses_context
26+
if hasattr(cls, "should_run"):
27+
res["should_run"] = get_identifier(cls.should_run)
28+
29+
func = getattr(cls, name)
30+
if func.__doc__ is not None:
31+
res["additional_docs"] = func.__doc__
32+
33+
return res
34+
35+
@classmethod
36+
def dummy_func(cls, *args, **kwargs):
37+
# Always define a small function that mainly forwards.
38+
return cls.name, args, kwargs
39+

tests/test_context.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import pytest
2+
3+
from spatch.backend_system import BackendSystem
4+
from spatch.testing import BackendDummy
5+
6+
7+
class FloatWithContext(BackendDummy):
8+
name = "FloatWithContext"
9+
primary_types = ("~builtins:float",)
10+
secondary_types = ("builtins:int",)
11+
uses_context = True
12+
requires_opt_in = False
13+
14+
15+
def test_context_basic():
16+
bs = BackendSystem(
17+
None,
18+
environ_prefix="SPATCH_TEST",
19+
default_primary_types=("builtin:int",),
20+
backends=[FloatWithContext()]
21+
)
22+
23+
# Add a dummy dispatchable function that dispatches on all arguments.
24+
@bs.dispatchable(None, module="<test>", qualname="dummy_func")
25+
def dummy_func(*args, **kwargs):
26+
return "fallback", args, kwargs
27+
28+
_, (ctx, *args), kwargs = dummy_func(1, 1.)
29+
assert ctx.name == "FloatWithContext"
30+
assert set(ctx.types) == {int, float}
31+
assert ctx.dispatch_args == (1, 1.)
32+
assert not ctx.prioritized
33+
34+
class float_subclass(float):
35+
pass
36+
37+
with bs.backend_opts(prioritize=("FloatWithContext",)):
38+
_, (ctx, *args), kwargs = dummy_func(float_subclass(1.))
39+
assert ctx.name == "FloatWithContext"
40+
assert set(ctx.types) == {float_subclass}
41+
assert ctx.dispatch_args == (float_subclass(1.),)
42+
assert ctx.prioritized
43+
44+
with bs.backend_opts(type=float):
45+
# No argument, works if explicitly prioritized...
46+
_, (ctx, *args), kwargs = dummy_func()
47+
assert ctx.name == "FloatWithContext"
48+
assert set(ctx.types) == {float}
49+
assert ctx.dispatch_args == ()
50+
assert not ctx.prioritized # not prioritized "just" type enforced
51+

tests/test_priority.py

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,7 @@
11
import pytest
22

33
from spatch.backend_system import BackendSystem
4-
from spatch.utils import get_identifier
5-
6-
class FuncGetter:
7-
def __init__(self, get):
8-
self.get = get
9-
10-
11-
class BackendDummy:
12-
"""Helper to construct a minimal "backend" for testing.
13-
14-
Forwards any lookup to the class. Documentation are used from
15-
the function which must match in the name.
16-
"""
17-
def __init__(self):
18-
self.functions = FuncGetter(self.get_function)
19-
20-
@classmethod
21-
def get_function(cls, name, default=None):
22-
# Simply ignore the module for testing purposes.
23-
_, name = name.split(":")
24-
25-
# Not get_identifier because it would find the super-class name.
26-
res = {"function": f"{cls.__module__}:{cls.__name__}.{name}" }
27-
if hasattr(cls, "uses_context"):
28-
res["uses_context"] = cls.uses_context
29-
if hasattr(cls, "should_run"):
30-
res["should_run"] = get_identifier(cls.should_run)
31-
32-
func = getattr(cls, name)
33-
if func.__doc__ is not None:
34-
res["additional_docs"] = func.__doc__
35-
36-
return res
37-
38-
@classmethod
39-
def dummy_func(cls, *args, **kwargs):
40-
# Always define a small function that mainly forwards.
41-
return cls.name, args, kwargs
4+
from spatch.testing import BackendDummy
425

436

447
class IntB(BackendDummy):

0 commit comments

Comments
 (0)