diff --git a/catalogue/__init__.py b/catalogue/__init__.py index 32490bd..8e2b8c7 100644 --- a/catalogue/__init__.py +++ b/catalogue/__init__.py @@ -1,4 +1,5 @@ -from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union +from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union, Generic, Type +from types import ModuleType, MethodType, FunctionType, TracebackType, FrameType, CodeType from typing import List import inspect @@ -15,9 +16,9 @@ InFunc = TypeVar("InFunc") +S = TypeVar('S') - -def create(*namespace: str, entry_points: bool = False) -> "Registry": +def create(*namespace: str, entry_points: bool = False, generic_type: Optional[Type[S]] = None) -> "Registry[S]": """Create a new registry. *namespace (str): The namespace, e.g. "spacy" or "spacy", "architectures". @@ -26,10 +27,14 @@ def create(*namespace: str, entry_points: bool = False) -> "Registry": """ if check_exists(*namespace): raise RegistryError(f"Namespace already exists: {namespace}") - return Registry(namespace, entry_points=entry_points) + + if generic_type is None: + return Registry[Any](namespace, entry_points=entry_points) + else: + return Registry[S](namespace, entry_points=entry_points) -class Registry(object): +class Registry(Generic[InFunc]): def __init__(self, namespace: Sequence[str], entry_points: bool = False) -> None: """Initialize a new registry. @@ -47,27 +52,27 @@ def __contains__(self, name: str) -> bool: RETURNS (bool): Whether the name is in the registry. """ namespace = tuple(list(self.namespace) + [name]) - has_entry_point = self.entry_points and self.get_entry_point(name) + has_entry_point = self.entry_points and self.get_entry_point(name) is not None return has_entry_point or namespace in REGISTRY def __call__( - self, name: str, func: Optional[Any] = None - ) -> Callable[[InFunc], InFunc]: + self, name: str, func: Optional[InFunc] = None + ) -> Union[Callable[[InFunc], InFunc], InFunc]: """Register a function for a given namespace. Same as Registry.register. name (str): The name to register under the namespace. - func (Any): Optional function to register (if not used as decorator). + func (InFunc): Optional function to register (if not used as decorator). RETURNS (Callable): The decorator. """ return self.register(name, func=func) def register( - self, name: str, *, func: Optional[Any] = None - ) -> Callable[[InFunc], InFunc]: + self, name: str, *, func: Optional[InFunc] = None + ) -> Union[Callable[[InFunc], InFunc], InFunc]: """Register a function for a given namespace. name (str): The name to register under the namespace. - func (Any): Optional function to register (if not used as decorator). + func (InFunc): Optional function to register (if not used as decorator). RETURNS (Callable): The decorator. """ @@ -79,11 +84,11 @@ def do_registration(func): return do_registration(func) return do_registration - def get(self, name: str) -> Any: + def get(self, name: str) -> InFunc: """Get the registered function for a given name. name (str): The name. - RETURNS (Any): The registered function. + RETURNS (InFunc): The registered function. """ if self.entry_points: from_entry_point = self.get_entry_point(name) @@ -98,11 +103,11 @@ def get(self, name: str) -> Any: ) return _get(namespace) - def get_all(self) -> Dict[str, Any]: + def get_all(self) -> Dict[str, InFunc]: """Get a all functions for a given namespace. namespace (Tuple[str]): The namespace to get. - RETURNS (Dict[str, Any]): The functions, keyed by name. + RETURNS (Dict[str, InFunc]): The functions, keyed by name. """ global REGISTRY result = {} @@ -115,22 +120,22 @@ def get_all(self) -> Dict[str, Any]: result[keys[-1]] = value return result - def get_entry_points(self) -> Dict[str, Any]: + def get_entry_points(self) -> Dict[str, InFunc]: """Get registered entry points from other packages for this namespace. - RETURNS (Dict[str, Any]): Entry points, keyed by name. + RETURNS (Dict[str, InFunc]): Entry points, keyed by name. """ result = {} for entry_point in self._get_entry_points(): result[entry_point.name] = entry_point.load() return result - def get_entry_point(self, name: str, default: Optional[Any] = None) -> Any: + def get_entry_point(self, name: str, default: Optional[InFunc] = None) -> Optional[InFunc]: """Check if registered entry point is available for a given name in the namespace and load it. Otherwise, return the default value. name (str): Name of entry point to load. - default (Any): The default value to return. + default (InFunc): The default value to return. RETURNS (Any): The loaded entry point or the default value. """ for entry_point in self._get_entry_points(): @@ -138,10 +143,11 @@ def get_entry_point(self, name: str, default: Optional[Any] = None) -> Any: return entry_point.load() return default - def _get_entry_points(self) -> List[importlib_metadata.EntryPoint]: + def _get_entry_points(self) -> Union[List[importlib_metadata.EntryPoint], importlib_metadata.EntryPoints]: if hasattr(AVAILABLE_ENTRY_POINTS, "select"): return AVAILABLE_ENTRY_POINTS.select(group=self.entry_point_namespace) else: # dict + assert isinstance(AVAILABLE_ENTRY_POINTS, dict) return AVAILABLE_ENTRY_POINTS.get(self.entry_point_namespace, []) def find(self, name: str) -> Dict[str, Optional[Union[str, int]]]: @@ -158,6 +164,9 @@ def find(self, name: str) -> Dict[str, Optional[Union[str, int]]]: line_no: Optional[int] = None file_name: Optional[str] = None try: + if not isinstance(func, (ModuleType, MethodType, FunctionType, TracebackType, FrameType, CodeType, type)): + raise TypeError(f"func type {type(func)} is not a valid type for inspect.getsourcelines()") + _, line_no = inspect.getsourcelines(func) file_name = inspect.getfile(func) except (TypeError, ValueError): @@ -170,7 +179,6 @@ def find(self, name: str) -> Dict[str, Optional[Union[str, int]]]: "docstring": inspect.cleandoc(docstring) if docstring else None, } - def check_exists(*namespace: str) -> bool: """Check if a namespace exists. diff --git a/catalogue/tests/test_catalogue.py b/catalogue/tests/test_catalogue.py index e53ebf2..7cff281 100644 --- a/catalogue/tests/test_catalogue.py +++ b/catalogue/tests/test_catalogue.py @@ -159,3 +159,31 @@ def a(): assert info["file"] == str(Path(__file__)) assert info["docstring"] == "This is a registered function." assert info["line_no"] + +def test_registry_find_module(): + import json + + test_registry = catalogue.create("test_registry_find_module") + + test_registry.register("json", func=json) + + info = test_registry.find("json") + assert info["module"] == "json" + assert info["file"] == json.__file__ + assert info["docstring"] == json.__doc__.strip('\n') + assert info["line_no"] == 0 + +def test_registry_find_class(): + test_registry = catalogue.create("test_registry_find_class") + + class TestClass: + """This is a registered class.""" + pass + + test_registry.register("test_class", func=TestClass) + + info = test_registry.find("test_class") + assert info["module"] == "catalogue.tests.test_catalogue" + assert info["file"] == str(Path(__file__)) + assert info["docstring"] == TestClass.__doc__ + assert info["line_no"] \ No newline at end of file