diff --git a/src/dependency_groups/__init__.py b/src/dependency_groups/__init__.py index 9fec202..2a8f71b 100644 --- a/src/dependency_groups/__init__.py +++ b/src/dependency_groups/__init__.py @@ -3,6 +3,7 @@ DependencyGroupInclude, DependencyGroupResolver, resolve, + resolve_all, ) __all__ = ( @@ -10,4 +11,5 @@ "DependencyGroupInclude", "DependencyGroupResolver", "resolve", + "resolve_all", ) diff --git a/src/dependency_groups/_implementation.py b/src/dependency_groups/_implementation.py index c89edaf..4609fbe 100644 --- a/src/dependency_groups/_implementation.py +++ b/src/dependency_groups/_implementation.py @@ -13,14 +13,21 @@ def _normalize_name(name: str) -> str: def _normalize_group_names( dependency_groups: Mapping[str, str | Mapping[str, str]], -) -> Mapping[str, str | Mapping[str, str]]: +) -> tuple[Mapping[str, str | Mapping[str, str]], Mapping[str, str]]: + """ + Normalize group names and return both normalized groups and reverse mapping. + + Returns a tuple of (normalized_groups, normalized_to_original). + """ original_names: dict[str, list[str]] = {} normalized_groups = {} + normalized_to_original: dict[str, str] = {} for group_name, value in dependency_groups.items(): normed_group_name = _normalize_name(group_name) original_names.setdefault(normed_group_name, []).append(group_name) normalized_groups[normed_group_name] = value + normalized_to_original[normed_group_name] = group_name errors = [] for normed_name, names in original_names.items(): @@ -29,7 +36,7 @@ def _normalize_group_names( if errors: raise ValueError(f"Duplicate dependency group names: {', '.join(errors)}") - return normalized_groups + return normalized_groups, normalized_to_original @dataclasses.dataclass @@ -75,7 +82,9 @@ def __init__( ) -> None: if not isinstance(dependency_groups, Mapping): raise TypeError("Dependency Groups table is not a mapping") - self.dependency_groups = _normalize_group_names(dependency_groups) + self.dependency_groups, self._normalized_to_original = _normalize_group_names( + dependency_groups + ) # a map of group names to parsed data self._parsed_groups: dict[ str, tuple[Requirement | DependencyGroupInclude, ...] @@ -189,6 +198,25 @@ def _resolve(self, group: str, requested_group: str) -> tuple[Requirement, ...]: self._resolve_cache[group] = tuple(resolved_group) return self._resolve_cache[group] + def resolve_all(self) -> Mapping[str, tuple[Requirement, ...]]: + """ + Resolve all dependency groups, returning a mapping of normalized group + names to resolved requirements. + + This is more efficient than calling resolve() on each group individually + because it avoids repeated work when groups share common includes. + + :raises TypeError: if the data appears to be the wrong types + :raises ValueError: if the data does not appear to be valid dependency group + data + :raises packaging.requirements.InvalidRequirement: if a specifier is not valid + """ + # Resolve all groups that haven't been resolved yet + for group in self.dependency_groups: + self._resolve(group, group) + + return dict(self._resolve_cache) + def resolve( dependency_groups: Mapping[str, str | Mapping[str, str]], /, *groups: str @@ -207,3 +235,44 @@ def resolve( """ resolver = DependencyGroupResolver(dependency_groups) return tuple(str(r) for group in groups for r in resolver.resolve(group)) + + +def resolve_all( + dependency_groups: Mapping[str, str | Mapping[str, str]], + /, + *, + normalize: bool = False, +) -> Mapping[str, tuple[str, ...]]: + """ + Resolve all dependency groups, returning a mapping of group names to + resolved requirements. + + :param dependency_groups: the parsed contents of the ``[dependency-groups]`` table + from ``pyproject.toml`` + :param normalize: if True normalize names, otherwise use original names + when returning keys, but still normalize for lookup. Defaults to False. + + :raises TypeError: if the inputs appear to be the wrong types + :raises ValueError: if the data does not appear to be valid dependency group data + :raises packaging.requirements.InvalidRequirement: if a specifier is not valid + + Example usage:: + + resolved = dependency_groups.resolve_all(dep_groups) + # {'test': ('pytest', 'sqlalchemy'), 'runtime': ('sqlalchemy',)} + """ + resolver = DependencyGroupResolver(dependency_groups) + resolved = resolver.resolve_all() + if normalize: + return { + group: tuple(str(r) for r in requirements) + for group, requirements in resolved.items() + } + else: + # Map back to original names + return { + resolver._normalized_to_original.get(group, group): tuple( + str(r) for r in requirements + ) + for group, requirements in resolved.items() + } diff --git a/tests/test_resolve_func.py b/tests/test_resolve_func.py index 076a7f7..5a04d36 100644 --- a/tests/test_resolve_func.py +++ b/tests/test_resolve_func.py @@ -1,6 +1,6 @@ import pytest -from dependency_groups import resolve +from dependency_groups import resolve, resolve_all def test_empty_group(): @@ -95,7 +95,7 @@ def test_cyclic_include(): def test_cyclic_include_many_steps(): groups = {} for i in range(100): - groups[f"group{i}"] = [{"include-group": f"group{i+1}"}] + groups[f"group{i}"] = [{"include-group": f"group{i + 1}"}] groups["group100"] = [{"include-group": "group0"}] with pytest.raises( ValueError, @@ -161,3 +161,149 @@ def test_unknown_object_shape(item): groups = {"test": [item]} with pytest.raises(ValueError, match="Invalid dependency group item:"): resolve(groups, "test") + + +def test_resolve_all_empty(): + groups = {} + assert resolve_all(groups) == {} + + +def test_resolve_all_single_group(): + groups = {"test": ["pytest"]} + assert resolve_all(groups) == {"test": ("pytest",)} + + +def test_resolve_all_multiple_groups(): + groups = { + "test": ["pytest", {"include-group": "runtime"}], + "runtime": ["sqlalchemy"], + } + result = resolve_all(groups) + assert "test" in result + assert "runtime" in result + assert set(result["test"]) == {"pytest", "sqlalchemy"} + assert result["runtime"] == ("sqlalchemy",) + + +def test_resolve_all_with_normalize_false(): + groups = { + "TEST": ["pytest"], + } + result = resolve_all(groups, normalize=False) + assert "TEST" in result + assert result["TEST"] == ("pytest",) + + +def test_resolve_all_with_normalize_true(): + groups = { + "TEST": ["pytest"], + } + result = resolve_all(groups, normalize=True) + assert "test" in result + assert result["test"] == ("pytest",) + + +def test_resolve_all_default_preserves_names(): + """Test that normalize=False is the default.""" + groups = { + "TEST": ["pytest"], + } + result = resolve_all(groups) + assert "TEST" in result + assert "test" not in result + assert result["TEST"] == ("pytest",) + + +def test_resolve_all_shared_includes(): + """Test that resolve_all correctly handles groups with shared includes. + + Structure: + dev = [{include-group = "test"}, {include-group = "lint"}] + test = [{include-group = "pytest"}, {include-group = "coverage"}] + lint = ["prek", {include-group = "typing"}] + typing = ["mypy"] + pytest = ["pytest>=7"] + coverage = ["coverage[toml]"] + """ + groups = { + "dev": [{"include-group": "test"}, {"include-group": "lint"}], + "test": [{"include-group": "pytest"}, {"include-group": "coverage"}], + "lint": ["prek", {"include-group": "typing"}], + "typing": ["mypy"], + "pytest": ["pytest>=7"], + "coverage": ["coverage[toml]"], + } + + result = resolve_all(groups) + + # Verify all groups are present + assert set(result.keys()) == {"dev", "test", "lint", "typing", "pytest", "coverage"} + + # Verify each group has correct contents + assert set(result["typing"]) == {"mypy"} + assert set(result["pytest"]) == {"pytest>=7"} + assert set(result["coverage"]) == {"coverage[toml]"} + assert set(result["lint"]) == {"prek", "mypy"} + assert set(result["test"]) == {"pytest>=7", "coverage[toml]"} + assert set(result["dev"]) == {"pytest>=7", "coverage[toml]", "prek", "mypy"} + + +def test_resolve_all_uses_cache(): + """Test that resolve_all uses the resolver's internal cache properly. + + This tests that calling resolve_all uses the optimized resolve_all method + which avoids duplicate work when groups share includes. + """ + from dependency_groups._implementation import DependencyGroupResolver + + groups = { + "dev": [{"include-group": "test"}, {"include-group": "lint"}], + "test": [{"include-group": "pytest"}, {"include-group": "coverage"}], + "lint": ["prek", {"include-group": "typing"}], + "typing": ["mypy"], + "pytest": ["pytest>=7"], + "coverage": ["coverage[toml]"], + } + + # Create a resolver and check that resolve_all populates the cache + resolver = DependencyGroupResolver(groups) + assert len(resolver._resolve_cache) == 0 + + resolved = resolver.resolve_all() + + # All groups should now be in the cache + assert len(resolver._resolve_cache) == len(groups) + for group in resolver.dependency_groups: + assert group in resolver._resolve_cache + + # Verify the results are correct (resolve_all returns Requirement objects) + assert {str(r) for r in resolved["dev"]} == { + "pytest>=7", + "coverage[toml]", + "prek", + "mypy", + } + assert {str(r) for r in resolved["test"]} == {"pytest>=7", "coverage[toml]"} + assert {str(r) for r in resolved["lint"]} == {"prek", "mypy"} + + +def test_resolve_all_vs_individual_resolve(): + """Test that resolve_all produces the same results as individual resolve calls.""" + groups = { + "dev": [{"include-group": "test"}, {"include-group": "lint"}], + "test": [{"include-group": "pytest"}, {"include-group": "coverage"}], + "lint": ["prek", {"include-group": "typing"}], + "typing": ["mypy"], + "pytest": ["pytest>=7"], + "coverage": ["coverage[toml]"], + } + + # Using resolve_all + result_all = resolve_all(groups) + + # Using individual resolve calls + result_individual = {group: resolve(groups, group) for group in groups} + + assert set(result_all.keys()) == set(result_individual.keys()) + for group in groups: + assert set(result_all[group]) == set(result_individual[group])