diff --git a/README.org b/README.org index 18524d5..de11404 100644 --- a/README.org +++ b/README.org @@ -76,6 +76,21 @@ def _(): def _(): return "MacOS" #+end_src +*** Dispatching based on environment variables +Environment variables can be checked by prefixing them with a ``$``. Only equality tests are supported for that. +#+begin_src python +@versiondispatch +def func(): + return "default" + +@func.register("$LANG==en_US.UTF-8") +def _(): + return "English" + +@func.register("$LANG==de_DE.UTF-8") +def _(): + return "German" +#+end_src *** Optional warnings It is possible to register warnings for specific versions. These warnings are shown to the user in case their version matches with the registerd version. @@ -138,8 +153,6 @@ ruff . * TODOs Under consideration to be implemented: ** Special keys -*** Environment variables -It might be nice to be able to check env vars, even if only for exact equality, like ~@foo.register("$LANG==en_US.UTF-8")~ ** Version inequality Add support for the ~!=~ operator ([[https://peps.python.org/pep-0440/#version-exclusion][PEP440]]) ** More checks on indicated versions diff --git a/src.py b/src.py index ba02748..21592b7 100644 --- a/src.py +++ b/src.py @@ -9,6 +9,7 @@ import collections import itertools import operator +import os import re import sys import warnings @@ -58,6 +59,8 @@ def _is_valid_package(package: str) -> bool: return True if package.lower() == "os": return True + if package.startswith("$"): + return True valid = True try: @@ -72,6 +75,8 @@ def get_version(package: str) -> Union[str, "Version"]: return Version(".".join(map(str, sys.version_info[:3]))) if package.lower() == "os": return sys.platform + if package.startswith("$"): + return os.environ.get(package[1:], "") return Version(_get_version(package)) @@ -120,9 +125,12 @@ def get_version(package: str) -> Union[str, "Version"]: # we can't do version_dict.get(package, _get_version(package)) since the # 2nd argument may raise an error, e.g. for 'Python' if version is None: - version = _get_version(package) + if package.startswith("$"): + version = os.environ.get(package[1:], "") + else: + version = _get_version(package) - if package.lower() == "os": + if package.lower() == "os" or package.startswith("$"): return version return Version(version) @@ -220,7 +228,11 @@ def _register( if not ( _is_valid_package(package) - and ((package.lower() == "os") or _is_valid_version(version)) + and ( + package.lower() == "os" + or package.startswith("$") + or _is_valid_version(version) + ) ): raise ValueError( f"{self._funcname} uses incorrect version spec or package is not " diff --git a/test.py b/test.py index 8b03c43..ab3dc61 100644 --- a/test.py +++ b/test.py @@ -655,6 +655,52 @@ def _old(): return "Windows" +class TestCheckEnvVar: + # dispatching on environment variables (only equality) + + def get_func(self): + @versiondispatch + def func(): + return "none" + + @func.register("$FOO==bar") + def _bar(): + return "bar" + + @func.register("$FOO==baz") + def _baz(): + return "baz" + + return func + + def test_default(self): + func = self.get_func() + assert func() == "none" + + def test_bar(self): + with pretend_version({"$FOO": "bar"}): + func = self.get_func() + assert func() == "bar" + + def test_baz(self): + with pretend_version({"$FOO": "baz"}): + func = self.get_func() + assert func() == "baz" + + @pytest.mark.parametrize("op", ["<", "<=", ">", ">="]) + def test_operator_not_eq_raises(self, op): + @versiondispatch + def func(): + return "none" + + match = "string comparison only possible with ==" + with pytest.raises(ValueError, match=match): + + @func.register(f"$FOO{op}bar") + def _old(): + return "bar" + + class TestWarnings: # test that warnings are shown if the version matches