88import warnings
99import textwrap
1010from types import MethodType
11- from typing import Callable
11+ from typing import Any , Callable
1212
1313from spatch import from_identifier , get_identifier
1414from spatch .utils import TypeIdentifier , valid_backend_name
@@ -50,16 +50,16 @@ def from_namespace(cls, info):
5050 requires_opt_in = info .requires_opt_in ,
5151 )
5252
53- def known_type (self , relevant_type ):
54- if relevant_type in self .primary_types :
53+ def known_type (self , dispatch_type ):
54+ if dispatch_type in self .primary_types :
5555 return "primary" # TODO: maybe make it an enum?
56- elif relevant_type in self .secondary_types :
56+ elif dispatch_type in self .secondary_types :
5757 return "secondary"
5858 else :
5959 return False
6060
61- def matches (self , relevant_types ):
62- matches = frozenset (self .known_type (t ) for t in relevant_types )
61+ def matches (self , dispatch_types ):
62+ matches = frozenset (self .known_type (t ) for t in dispatch_types )
6363 if "primary" in matches and False not in matches :
6464 return True
6565 return False
@@ -594,19 +594,19 @@ def _get_entry_points(group, blocked):
594594 return sorted (backends , key = lambda x : x .name )
595595
596596 @functools .lru_cache (maxsize = 128 )
597- def known_type (self , relevant_type , primary = False ):
597+ def known_type (self , dispatch_type , primary = False ):
598598 for backend in self .backends .values ():
599- if backend .known_type (relevant_type ):
599+ if backend .known_type (dispatch_type ):
600600 return True
601601 return False
602602
603- def get_known_unique_types (self , relevant_types ):
604- # From a list of args, return only the set of relevant types
605- return frozenset (val for val in relevant_types if self .known_type (val ))
603+ def get_known_unique_types (self , dispatch_types ):
604+ # From a list of args, return only the set of dispatch types
605+ return frozenset (val for val in dispatch_types if self .known_type (val ))
606606
607607 @functools .lru_cache (maxsize = 128 )
608- def get_types_and_backends (self , relevant_types , ordered_backends ):
609- """Fetch relevant types and matching backends.
608+ def get_types_and_backends (self , dispatch_types , ordered_backends ):
609+ """Fetch dispatch types and matching backends.
610610
611611 The main purpose of this function is to cache the results for a set
612612 of unique input types to functions.
@@ -617,18 +617,18 @@ def get_types_and_backends(self, relevant_types, ordered_backends):
617617
618618 Returns
619619 -------
620- relevant_types : frozenset
621- The set of relevant types that are known to the backend system.
620+ dispatch_types : frozenset
621+ The set of dispatch types that are known to the backend system.
622622 matching_backends : tuple
623623 A tuple of backend names sorted by priority.
624624 """
625625 # Filter out unknown types:
626- relevant_types = self .get_known_unique_types (relevant_types )
626+ dispatch_types = self .get_known_unique_types (dispatch_types )
627627
628628 matching_backends = tuple (
629- n for n in ordered_backends if self .backends [n ].matches (relevant_types )
629+ n for n in ordered_backends if self .backends [n ].matches (dispatch_types )
630630 )
631- return relevant_types , matching_backends
631+ return dispatch_types , matching_backends
632632
633633 def backend_from_namespace (self , info_namespace ):
634634 new_backend = Backend .from_namespace (info_namespace )
@@ -640,19 +640,20 @@ def backend_from_namespace(self, info_namespace):
640640 return
641641 self .backends [new_backend .name ] = new_backend
642642
643- def dispatchable (self , relevant_args = None , * , module = None , qualname = None ):
643+ def dispatchable (self , dispatch_args = None , * , module = None , qualname = None ):
644644 """
645645 Decorator to mark functions as dispatchable.
646646
647647 Decorate a Python function with information on how to extract
648- the "relevant " arguments, i.e. arguments we wish to dispatch for.
648+ the "dispatch " arguments, i.e. arguments we wish to dispatch for.
649649
650650 Parameters
651651 ----------
652- relevant_args : str, list, tuple, or None
652+ dispatch_args : str, list, tuple, or None
653653 The names of parameters to extract (we use inspect to
654654 map these correctly).
655- If ``None`` all parameters will be considered relevant.
655+ If ``None`` all parameters will be considered relevant for
656+ dispatching.
656657 module : str
657658 Override the module of the function (actually modifies it)
658659 to ensure a well defined and stable public API.
@@ -675,7 +676,7 @@ def wrap_callable(func):
675676 if qualname is not None :
676677 func .__qualname__ = qualname
677678
678- disp = Dispatchable (self , func , relevant_args )
679+ disp = Dispatchable (self , func , dispatch_args )
679680
680681 return disp
681682
@@ -703,7 +704,7 @@ class DispatchContext:
703704
704705 Attributes
705706 ----------
706- types : Sequence [type]
707+ types : tuple [type, ... ]
707708 The (unique) types we are dispatching for. It is possible that
708709 not all types are passed as arguments if the user is requesting
709710 a specific type.
@@ -718,8 +719,17 @@ class DispatchContext:
718719 Backends that strictly match a single primary type can safely ignore this
719720 (they always return the same type).
720721
722+ dispatch_args : tuple[Any, ...]
723+ The arguments for which we dispatched. This can be useful information
724+ for some generic wrappers who still need to inspect all dispatch arguments.
725+
721726 .. note::
722- This is a frozenset currently, but please consider it a sequence.
727+ ``dispatch_args`` can be empty if a function takes no arguments.
728+ Yet, a backend version may be called explicitly e.g. in a
729+ ``with backend_opts(type=): ...`` context.
730+
731+ name : str
732+ The name of the backend that was selected.
723733
724734 prioritized : bool
725735 Whether the backend is prioritized. You may use this for example when
@@ -730,7 +740,8 @@ class DispatchContext:
730740 # The idea is for the context to be very light-weight so that specific
731741 # information should be properties (because most likely we will never need it).
732742 # This object can grow to provide more information to backends.
733- types : tuple [type ]
743+ types : tuple [type , ...]
744+ dispatch_args : tuple [Any , ...]
734745 name : str
735746 _state : tuple
736747
@@ -797,7 +808,7 @@ class Dispatchable:
797808 #
798809 # TODO: We may want to return a function just to be nice (not having a func was
799810 # OK in NumPy for example, but has a few little stumbling blocks)
800- def __init__ (self , backend_system , func , relevant_args , ident = None ):
811+ def __init__ (self , backend_system , func , dispatch_args , ident = None ):
801812 functools .update_wrapper (self , func )
802813
803814 self ._backend_system = backend_system
@@ -807,11 +818,11 @@ def __init__(self, backend_system, func, relevant_args, ident=None):
807818
808819 self ._ident = ident
809820
810- if isinstance (relevant_args , str ):
811- relevant_args = {relevant_args : 0 }
812- elif isinstance (relevant_args , list | tuple ):
813- relevant_args = {val : i for i , val in enumerate (relevant_args )}
814- self ._relevant_args = relevant_args
821+ if isinstance (dispatch_args , str ):
822+ dispatch_args = {dispatch_args : 0 }
823+ elif isinstance (dispatch_args , list | tuple ):
824+ dispatch_args = {val : i for i , val in enumerate (dispatch_args )}
825+ self ._dispatch_args = dispatch_args
815826
816827 new_doc = []
817828 impl_infos = {}
@@ -851,27 +862,29 @@ def __get__(self, obj, objtype=None):
851862 return self
852863 return MethodType (self , obj )
853864
854- def _get_relevant_types (self , * args , ** kwargs ):
855- # Return all relevant types, these are not filtered by the known_types
856- if self ._relevant_args is None :
857- return set ( type ( val ) for val in args ) | set ( type ( k ) for k in kwargs .values ())
865+ def _get_dispatch_args (self , * args , ** kwargs ):
866+ # Return all dispatch args
867+ if self ._dispatch_args is None :
868+ return args + tuple ( kwargs .values ())
858869 else :
859- return set (
860- type ( val ) for name , pos in self ._relevant_args .items ()
870+ return tuple (
871+ val for name , pos in self ._dispatch_args .items ()
861872 if (val := args [pos ] if pos < len (args ) else kwargs .get (name )) is not None
862873 )
863874
864875 def __call__ (self , * args , ** kwargs ):
865- relevant_types = self ._get_relevant_types (* args , ** kwargs )
876+ dispatch_args = self ._get_dispatch_args (* args , ** kwargs )
877+ # At this point dispatch_types is not filtered for known types.
878+ dispatch_types = set (type (val ) for val in dispatch_args )
866879 state = self ._backend_system ._dispatch_state .get ()
867880 ordered_backends , type_ , prioritized , trace = state
868881
869882 if type_ is not None :
870- relevant_types .add (type_ )
871- relevant_types = frozenset (relevant_types )
883+ dispatch_types .add (type_ )
884+ dispatch_types = frozenset (dispatch_types )
872885
873- relevant_types , matching_backends = self ._backend_system .get_types_and_backends (
874- relevant_types , ordered_backends )
886+ dispatch_types , matching_backends = self ._backend_system .get_types_and_backends (
887+ dispatch_types , ordered_backends )
875888
876889 if trace is not None :
877890 call_trace = []
@@ -886,7 +899,7 @@ def __call__(self, *args, **kwargs):
886899 # may want to optimize this (in case many backends have few functions).
887900 continue
888901
889- context = DispatchContext (relevant_types , state , name )
902+ context = DispatchContext (tuple ( dispatch_types ), dispatch_args , name , state )
890903
891904 should_run = impl .should_run
892905 if should_run is None or (should_run := should_run (context , * args , ** kwargs )) is True :
0 commit comments