From 142a51a3f0947ec954ad611b8e8876815d9c46e8 Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Mon, 26 Apr 2021 10:43:36 -0700 Subject: [PATCH] Chained cl PiperOrigin-RevId: 370492077 --- gin/__init__.py | 1 + gin/config.py | 95 ++++++++++++++++++++++++++++++++++++-------- tests/config_test.py | 68 +++++++++++++++++++++++++++++++ 3 files changed, 147 insertions(+), 17 deletions(-) diff --git a/gin/__init__.py b/gin/__init__.py index cd796c4..3b428fe 100644 --- a/gin/__init__.py +++ b/gin/__init__.py @@ -29,6 +29,7 @@ from gin.config import exit_interactive_mode from gin.config import external_configurable from gin.config import finalize +from gin.config import get_bindings from gin.config import operative_config_str from gin.config import parse_config from gin.config import parse_config_file diff --git a/gin/config.py b/gin/config.py index 8d45bc6..9112e2e 100644 --- a/gin/config.py +++ b/gin/config.py @@ -93,7 +93,7 @@ def drink(cocktail): import sys import threading import traceback -from typing import Optional, Sequence, Union +from typing import Any, Callable, Dict, Optional, Sequence, Type, Union from gin import config_parser from gin import selector_map @@ -139,6 +139,8 @@ def exit_scope(self): # Maintains the registry of configurable functions and classes. _REGISTRY = selector_map.SelectorMap() +# Inverse registery to recover a binding from a function or class +_FN_OR_CLS_TO_SELECTOR = {} # Maps tuples of `(scope, selector)` to associated parameter values. This # specifies the current global "configuration" set through `bind_parameter` or @@ -983,6 +985,51 @@ def load_eval_data(): _SCOPE_MANAGER.exit_scope() +def get_bindings( + fn_or_cls: Union[str, Callable[..., Any], Type[Any]], +) -> Dict[str, Any]: + """Returns the bindings associated with the given configurable. + + Example: + + ```python + config.parse_config('MyParams.kwarg0 = 123') + + gin.get_bindings('MyParams') == {'kwarg0': 123} + ``` + + Note: The scope in which `get_bindings` is called will be used. + + Args: + fn_or_cls: Configurable function, class or selector `str` too. + + Returns: + The bindings kwargs injected by gin. + """ + if isinstance(fn_or_cls, str): + # Resolve partial selector -> full selector + selector = _REGISTRY.get_match(fn_or_cls) + if selector: + selector = selector.selector + else: + selector = _FN_OR_CLS_TO_SELECTOR.get(fn_or_cls) + + if selector is None: + raise ValueError(f'Could not find {fn_or_cls} in the gin register.') + + return _get_bindings(selector) + + +def _get_bindings(selector: str) -> Dict[str, Any]: + """Returns the bindings for the current full selector.""" + scope_components = current_scope() + new_kwargs = {} + for i in range(len(scope_components) + 1): + partial_scope_str = '/'.join(scope_components[:i]) + new_kwargs.update(_CONFIG.get((partial_scope_str, selector), {})) + return new_kwargs + + def _make_gin_wrapper(fn, fn_or_cls, name, selector, allowlist, denylist): """Creates the final Gin wrapper for the given function. @@ -1015,13 +1062,9 @@ def _make_gin_wrapper(fn, fn_or_cls, name, selector, allowlist, denylist): @functools.wraps(fn) def gin_wrapper(*args, **kwargs): """Supplies fn with parameter values from the configuration.""" - scope_components = current_scope() - new_kwargs = {} - for i in range(len(scope_components) + 1): - partial_scope_str = '/'.join(scope_components[:i]) - new_kwargs.update(_CONFIG.get((partial_scope_str, selector), {})) + new_kwargs = _get_bindings(selector) gin_bound_args = list(new_kwargs.keys()) - scope_str = partial_scope_str + scope_str = '/'.join(current_scope()) arg_names = _get_supplied_positional_parameter_names(signature_fn, args) @@ -1147,6 +1190,27 @@ def gin_wrapper(*args, **kwargs): return gin_wrapper +def _make_selector( + fn_or_cls, + *, + name: Optional[str], + module: Optional[str], +) -> str: + """Returns the gin name selector.""" + name = fn_or_cls.__name__ if name is None else name + if config_parser.IDENTIFIER_RE.match(name): + default_module = getattr(fn_or_cls, '__module__', None) + module = default_module if module is None else module + elif not config_parser.MODULE_RE.match(name): + raise ValueError("Configurable name '{}' is invalid.".format(name)) + + if module is not None and not config_parser.MODULE_RE.match(module): + raise ValueError("Module '{}' is invalid.".format(module)) + + selector = module + '.' + name if module else name + return selector + + def _make_configurable(fn_or_cls, name=None, module=None, @@ -1188,17 +1252,12 @@ def _make_configurable(fn_or_cls, err_str = 'Attempted to add a new configurable after the config was locked.' raise RuntimeError(err_str) - name = fn_or_cls.__name__ if name is None else name - if config_parser.IDENTIFIER_RE.match(name): - default_module = getattr(fn_or_cls, '__module__', None) - module = default_module if module is None else module - elif not config_parser.MODULE_RE.match(name): - raise ValueError("Configurable name '{}' is invalid.".format(name)) - - if module is not None and not config_parser.MODULE_RE.match(module): - raise ValueError("Module '{}' is invalid.".format(module)) + selector = _make_selector( + fn_or_cls, + name=name, + module=module, + ) - selector = module + '.' + name if module else name if not _INTERACTIVE_MODE and selector in _REGISTRY: err_str = ("A configurable matching '{}' already exists.\n\n" 'To allow re-registration of configurables in an interactive ' @@ -1234,6 +1293,8 @@ def decorator(fn): allowlist=allowlist, denylist=denylist, selector=selector) + # Inverse registery + _FN_OR_CLS_TO_SELECTOR[decorated_fn_or_cls] = selector return decorated_fn_or_cls diff --git a/tests/config_test.py b/tests/config_test.py index 2aeb505..c49f44a 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -1985,6 +1985,74 @@ def testEmptyNestedIncludesAndImports(self): [], ['TEST=1'], print_includes_and_imports=True) self.assertListEqual(result, []) + def testGetBindings(self): + # Bindings can be accessed through name or object + # Default are empty + self.assertDictEqual(config.get_bindings('configurable1'), {}) + self.assertDictEqual(config.get_bindings(fn1), {}) + + self.assertDictEqual(config.get_bindings('ConfigurableClass'), {}) + self.assertDictEqual(config.get_bindings(ConfigurableClass), {}) + + config_str = """ + configurable1.non_kwarg = 'kwarg1' + configurable1.kwarg2 = 123 + ConfigurableClass.kwarg1 = 'okie dokie' + """ + config.parse_config(config_str) + + self.assertDictEqual(config.get_bindings('configurable1'), { + 'non_kwarg': 'kwarg1', + 'kwarg2': 123, + }) + self.assertDictEqual(config.get_bindings(fn1), { + 'non_kwarg': 'kwarg1', + 'kwarg2': 123, + }) + + self.assertDictEqual(config.get_bindings('ConfigurableClass'), { + 'kwarg1': 'okie dokie', + }) + self.assertDictEqual(config.get_bindings(ConfigurableClass), { + 'kwarg1': 'okie dokie', + }) + + def testGetBindingsScope(self): + config_str = """ + configurable1.non_kwarg = 'kwarg1' + configurable1.kwarg2 = 123 + scope/configurable1.kwarg2 = 456 + """ + config.parse_config(config_str) + + self.assertDictEqual(config.get_bindings('configurable1'), { + 'non_kwarg': 'kwarg1', + 'kwarg2': 123, + }) + self.assertDictEqual(config.get_bindings(fn1), { + 'non_kwarg': 'kwarg1', + 'kwarg2': 123, + }) + + with config.config_scope('scope'): + self.assertDictEqual(config.get_bindings('configurable1'), { + 'non_kwarg': 'kwarg1', + 'kwarg2': 456, + }) + self.assertDictEqual(config.get_bindings(fn1), { + 'non_kwarg': 'kwarg1', + 'kwarg2': 456, + }) + + def testGetBindingsUnknown(self): + + expected_msg = 'Could not find .* in the gin register' + with self.assertRaisesRegex(ValueError, expected_msg): + config.get_bindings('UnknownParam') + + with self.assertRaisesRegex(ValueError, expected_msg): + config.get_bindings(lambda x: None) + if __name__ == '__main__': absltest.main()