Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 57 additions & 44 deletions src/spatch/backend_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import warnings
import textwrap
from types import MethodType
from typing import Callable
from typing import Any, Callable

from spatch import from_identifier, get_identifier
from spatch.utils import TypeIdentifier, valid_backend_name
Expand Down Expand Up @@ -50,16 +50,16 @@ def from_namespace(cls, info):
requires_opt_in=info.requires_opt_in,
)

def known_type(self, relevant_type):
if relevant_type in self.primary_types:
def known_type(self, dispatch_type):
if dispatch_type in self.primary_types:
return "primary" # TODO: maybe make it an enum?
elif relevant_type in self.secondary_types:
elif dispatch_type in self.secondary_types:
return "secondary"
else:
return False

def matches(self, relevant_types):
matches = frozenset(self.known_type(t) for t in relevant_types)
def matches(self, dispatch_types):
matches = frozenset(self.known_type(t) for t in dispatch_types)
if "primary" in matches and False not in matches:
return True
return False
Expand Down Expand Up @@ -594,19 +594,19 @@ def _get_entry_points(group, blocked):
return sorted(backends, key=lambda x: x.name)

@functools.lru_cache(maxsize=128)
def known_type(self, relevant_type, primary=False):
def known_type(self, dispatch_type, primary=False):
for backend in self.backends.values():
if backend.known_type(relevant_type):
if backend.known_type(dispatch_type):
return True
return False

def get_known_unique_types(self, relevant_types):
# From a list of args, return only the set of relevant types
return frozenset(val for val in relevant_types if self.known_type(val))
def get_known_unique_types(self, dispatch_types):
# From a list of args, return only the set of dispatch types
return frozenset(val for val in dispatch_types if self.known_type(val))

@functools.lru_cache(maxsize=128)
def get_types_and_backends(self, relevant_types, ordered_backends):
"""Fetch relevant types and matching backends.
def get_types_and_backends(self, dispatch_types, ordered_backends):
"""Fetch dispatch types and matching backends.

The main purpose of this function is to cache the results for a set
of unique input types to functions.
Expand All @@ -617,18 +617,18 @@ def get_types_and_backends(self, relevant_types, ordered_backends):

Returns
-------
relevant_types : frozenset
The set of relevant types that are known to the backend system.
dispatch_types : frozenset
The set of dispatch types that are known to the backend system.
matching_backends : tuple
A tuple of backend names sorted by priority.
"""
# Filter out unknown types:
relevant_types = self.get_known_unique_types(relevant_types)
dispatch_types = self.get_known_unique_types(dispatch_types)

matching_backends = tuple(
n for n in ordered_backends if self.backends[n].matches(relevant_types)
n for n in ordered_backends if self.backends[n].matches(dispatch_types)
)
return relevant_types, matching_backends
return dispatch_types, matching_backends

def backend_from_namespace(self, info_namespace):
new_backend = Backend.from_namespace(info_namespace)
Expand All @@ -640,19 +640,20 @@ def backend_from_namespace(self, info_namespace):
return
self.backends[new_backend.name] = new_backend

def dispatchable(self, relevant_args=None, *, module=None, qualname=None):
def dispatchable(self, dispatch_args=None, *, module=None, qualname=None):
"""
Decorator to mark functions as dispatchable.

Decorate a Python function with information on how to extract
the "relevant" arguments, i.e. arguments we wish to dispatch for.
the "dispatch" arguments, i.e. arguments we wish to dispatch for.

Parameters
----------
relevant_args : str, list, tuple, or None
dispatch_args : str, list, tuple, or None
The names of parameters to extract (we use inspect to
map these correctly).
If ``None`` all parameters will be considered relevant.
If ``None`` all parameters will be considered relevant for
dispatching.
module : str
Override the module of the function (actually modifies it)
to ensure a well defined and stable public API.
Expand All @@ -675,7 +676,7 @@ def wrap_callable(func):
if qualname is not None:
func.__qualname__ = qualname

disp = Dispatchable(self, func, relevant_args)
disp = Dispatchable(self, func, dispatch_args)

return disp

Expand Down Expand Up @@ -703,7 +704,7 @@ class DispatchContext:

Attributes
----------
types : Sequence[type]
types : tuple[type, ...]
The (unique) types we are dispatching for. It is possible that
not all types are passed as arguments if the user is requesting
a specific type.
Expand All @@ -718,8 +719,17 @@ class DispatchContext:
Backends that strictly match a single primary type can safely ignore this
(they always return the same type).

dispatch_args : tuple[Any, ...]
The arguments for which we dispatched. This can be useful information
for some generic wrappers who still need to inspect all dispatch arguments.

.. note::
This is a frozenset currently, but please consider it a sequence.
``dispatch_args`` can be empty if a function takes no arguments.
Yet, a backend version may be called explicitly e.g. in a
``with backend_opts(type=): ...`` context.

name : str
The name of the backend that was selected.

prioritized : bool
Whether the backend is prioritized. You may use this for example when
Expand All @@ -730,7 +740,8 @@ class DispatchContext:
# The idea is for the context to be very light-weight so that specific
# information should be properties (because most likely we will never need it).
# This object can grow to provide more information to backends.
types: tuple[type]
types: tuple[type, ...]
dispatch_args: tuple[Any, ...]
name: str
_state: tuple

Expand Down Expand Up @@ -797,7 +808,7 @@ class Dispatchable:
#
# TODO: We may want to return a function just to be nice (not having a func was
# OK in NumPy for example, but has a few little stumbling blocks)
def __init__(self, backend_system, func, relevant_args, ident=None):
def __init__(self, backend_system, func, dispatch_args, ident=None):
functools.update_wrapper(self, func)

self._backend_system = backend_system
Expand All @@ -807,11 +818,11 @@ def __init__(self, backend_system, func, relevant_args, ident=None):

self._ident = ident

if isinstance(relevant_args, str):
relevant_args = {relevant_args: 0}
elif isinstance(relevant_args, list | tuple):
relevant_args = {val: i for i, val in enumerate(relevant_args)}
self._relevant_args = relevant_args
if isinstance(dispatch_args, str):
dispatch_args = {dispatch_args: 0}
elif isinstance(dispatch_args, list | tuple):
dispatch_args = {val: i for i, val in enumerate(dispatch_args)}
self._dispatch_args = dispatch_args

new_doc = []
impl_infos = {}
Expand Down Expand Up @@ -851,27 +862,29 @@ def __get__(self, obj, objtype=None):
return self
return MethodType(self, obj)

def _get_relevant_types(self, *args, **kwargs):
# Return all relevant types, these are not filtered by the known_types
if self._relevant_args is None:
return set(type(val) for val in args) | set(type(k) for k in kwargs.values())
def _get_dispatch_args(self, *args, **kwargs):
# Return all dispatch args
if self._dispatch_args is None:
return args + tuple(kwargs.values())
else:
return set(
type(val) for name, pos in self._relevant_args.items()
return tuple(
val for name, pos in self._dispatch_args.items()
if (val := args[pos] if pos < len(args) else kwargs.get(name)) is not None
)

def __call__(self, *args, **kwargs):
relevant_types = self._get_relevant_types(*args, **kwargs)
dispatch_args = self._get_dispatch_args(*args, **kwargs)
# At this point dispatch_types is not filtered for known types.
dispatch_types = set(type(val) for val in dispatch_args)
state = self._backend_system._dispatch_state.get()
ordered_backends, type_, prioritized, trace = state

if type_ is not None:
relevant_types.add(type_)
relevant_types = frozenset(relevant_types)
dispatch_types.add(type_)
dispatch_types = frozenset(dispatch_types)

relevant_types, matching_backends = self._backend_system.get_types_and_backends(
relevant_types, ordered_backends)
dispatch_types, matching_backends = self._backend_system.get_types_and_backends(
dispatch_types, ordered_backends)

if trace is not None:
call_trace = []
Expand All @@ -886,7 +899,7 @@ def __call__(self, *args, **kwargs):
# may want to optimize this (in case many backends have few functions).
continue

context = DispatchContext(relevant_types, state, name)
context = DispatchContext(tuple(dispatch_types), dispatch_args, name, state)

should_run = impl.should_run
if should_run is None or (should_run := should_run(context, *args, **kwargs)) is True:
Expand Down
39 changes: 39 additions & 0 deletions src/spatch/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from spatch.utils import get_identifier

class _FuncGetter:
def __init__(self, get):
self.get = get


class BackendDummy:
"""Helper to construct a minimal "backend" for testing.

Forwards any lookup to the class. Documentation are used from
the function which must match in the name.
"""
def __init__(self):
self.functions = _FuncGetter(self.get_function)

@classmethod
def get_function(cls, name, default=None):
# Simply ignore the module for testing purposes.
_, name = name.split(":")

# Not get_identifier because it would find the super-class name.
res = {"function": f"{cls.__module__}:{cls.__name__}.{name}" }
if hasattr(cls, "uses_context"):
res["uses_context"] = cls.uses_context
if hasattr(cls, "should_run"):
res["should_run"] = get_identifier(cls.should_run)

func = getattr(cls, name)
if func.__doc__ is not None:
res["additional_docs"] = func.__doc__

return res

@classmethod
def dummy_func(cls, *args, **kwargs):
# Always define a small function that mainly forwards.
return cls.name, args, kwargs

51 changes: 51 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest

from spatch.backend_system import BackendSystem
from spatch.testing import BackendDummy


class FloatWithContext(BackendDummy):
name = "FloatWithContext"
primary_types = ("~builtins:float",)
secondary_types = ("builtins:int",)
uses_context = True
requires_opt_in = False


def test_context_basic():
bs = BackendSystem(
None,
environ_prefix="SPATCH_TEST",
default_primary_types=("builtin:int",),
backends=[FloatWithContext()]
)

# Add a dummy dispatchable function that dispatches on all arguments.
@bs.dispatchable(None, module="<test>", qualname="dummy_func")
def dummy_func(*args, **kwargs):
return "fallback", args, kwargs

_, (ctx, *args), kwargs = dummy_func(1, 1.)
assert ctx.name == "FloatWithContext"
assert set(ctx.types) == {int, float}
assert ctx.dispatch_args == (1, 1.)
assert not ctx.prioritized

class float_subclass(float):
pass

with bs.backend_opts(prioritize=("FloatWithContext",)):
_, (ctx, *args), kwargs = dummy_func(float_subclass(1.))
assert ctx.name == "FloatWithContext"
assert set(ctx.types) == {float_subclass}
assert ctx.dispatch_args == (float_subclass(1.),)
assert ctx.prioritized

with bs.backend_opts(type=float):
# No argument, works if explicitly prioritized...
_, (ctx, *args), kwargs = dummy_func()
assert ctx.name == "FloatWithContext"
assert set(ctx.types) == {float}
assert ctx.dispatch_args == ()
assert not ctx.prioritized # not prioritized "just" type enforced

39 changes: 1 addition & 38 deletions tests/test_priority.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,7 @@
import pytest

from spatch.backend_system import BackendSystem
from spatch.utils import get_identifier

class FuncGetter:
def __init__(self, get):
self.get = get


class BackendDummy:
"""Helper to construct a minimal "backend" for testing.

Forwards any lookup to the class. Documentation are used from
the function which must match in the name.
"""
def __init__(self):
self.functions = FuncGetter(self.get_function)

@classmethod
def get_function(cls, name, default=None):
# Simply ignore the module for testing purposes.
_, name = name.split(":")

# Not get_identifier because it would find the super-class name.
res = {"function": f"{cls.__module__}:{cls.__name__}.{name}" }
if hasattr(cls, "uses_context"):
res["uses_context"] = cls.uses_context
if hasattr(cls, "should_run"):
res["should_run"] = get_identifier(cls.should_run)

func = getattr(cls, name)
if func.__doc__ is not None:
res["additional_docs"] = func.__doc__

return res

@classmethod
def dummy_func(cls, *args, **kwargs):
# Always define a small function that mainly forwards.
return cls.name, args, kwargs
from spatch.testing import BackendDummy


class IntB(BackendDummy):
Expand Down