Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/dependency_groups/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
DependencyGroupInclude,
DependencyGroupResolver,
resolve,
resolve_all,
)

__all__ = (
"CyclicDependencyError",
"DependencyGroupInclude",
"DependencyGroupResolver",
"resolve",
"resolve_all",
)
75 changes: 72 additions & 3 deletions src/dependency_groups/_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,21 @@ def _normalize_name(name: str) -> str:

def _normalize_group_names(
dependency_groups: Mapping[str, str | Mapping[str, str]],
) -> Mapping[str, str | Mapping[str, str]]:
) -> tuple[Mapping[str, str | Mapping[str, str]], Mapping[str, str]]:
"""
Normalize group names and return both normalized groups and reverse mapping.

Returns a tuple of (normalized_groups, normalized_to_original).
"""
original_names: dict[str, list[str]] = {}
normalized_groups = {}
normalized_to_original: dict[str, str] = {}

for group_name, value in dependency_groups.items():
normed_group_name = _normalize_name(group_name)
original_names.setdefault(normed_group_name, []).append(group_name)
normalized_groups[normed_group_name] = value
normalized_to_original[normed_group_name] = group_name

errors = []
for normed_name, names in original_names.items():
Expand All @@ -29,7 +36,7 @@ def _normalize_group_names(
if errors:
raise ValueError(f"Duplicate dependency group names: {', '.join(errors)}")

return normalized_groups
return normalized_groups, normalized_to_original


@dataclasses.dataclass
Expand Down Expand Up @@ -75,7 +82,9 @@ def __init__(
) -> None:
if not isinstance(dependency_groups, Mapping):
raise TypeError("Dependency Groups table is not a mapping")
self.dependency_groups = _normalize_group_names(dependency_groups)
self.dependency_groups, self._normalized_to_original = _normalize_group_names(
dependency_groups
)
# a map of group names to parsed data
self._parsed_groups: dict[
str, tuple[Requirement | DependencyGroupInclude, ...]
Expand Down Expand Up @@ -189,6 +198,25 @@ def _resolve(self, group: str, requested_group: str) -> tuple[Requirement, ...]:
self._resolve_cache[group] = tuple(resolved_group)
return self._resolve_cache[group]

def resolve_all(self) -> Mapping[str, tuple[Requirement, ...]]:
"""
Resolve all dependency groups, returning a mapping of normalized group
names to resolved requirements.

This is more efficient than calling resolve() on each group individually
because it avoids repeated work when groups share common includes.

:raises TypeError: if the data appears to be the wrong types
:raises ValueError: if the data does not appear to be valid dependency group
data
:raises packaging.requirements.InvalidRequirement: if a specifier is not valid
"""
# Resolve all groups that haven't been resolved yet
for group in self.dependency_groups:
self._resolve(group, group)

return dict(self._resolve_cache)


def resolve(
dependency_groups: Mapping[str, str | Mapping[str, str]], /, *groups: str
Expand All @@ -207,3 +235,44 @@ def resolve(
"""
resolver = DependencyGroupResolver(dependency_groups)
return tuple(str(r) for group in groups for r in resolver.resolve(group))


def resolve_all(
dependency_groups: Mapping[str, str | Mapping[str, str]],
/,
*,
normalize: bool = False,
) -> Mapping[str, tuple[str, ...]]:
"""
Resolve all dependency groups, returning a mapping of group names to
resolved requirements.

:param dependency_groups: the parsed contents of the ``[dependency-groups]`` table
from ``pyproject.toml``
:param normalize: if True normalize names, otherwise use original names
when returning keys, but still normalize for lookup. Defaults to False.

:raises TypeError: if the inputs appear to be the wrong types
:raises ValueError: if the data does not appear to be valid dependency group data
:raises packaging.requirements.InvalidRequirement: if a specifier is not valid

Example usage::

resolved = dependency_groups.resolve_all(dep_groups)
# {'test': ('pytest', 'sqlalchemy'), 'runtime': ('sqlalchemy',)}
"""
resolver = DependencyGroupResolver(dependency_groups)
resolved = resolver.resolve_all()
if normalize:
return {
group: tuple(str(r) for r in requirements)
for group, requirements in resolved.items()
}
else:
# Map back to original names
return {
resolver._normalized_to_original.get(group, group): tuple(
str(r) for r in requirements
)
for group, requirements in resolved.items()
}
150 changes: 148 additions & 2 deletions tests/test_resolve_func.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from dependency_groups import resolve
from dependency_groups import resolve, resolve_all


def test_empty_group():
Expand Down Expand Up @@ -95,7 +95,7 @@ def test_cyclic_include():
def test_cyclic_include_many_steps():
groups = {}
for i in range(100):
groups[f"group{i}"] = [{"include-group": f"group{i+1}"}]
groups[f"group{i}"] = [{"include-group": f"group{i + 1}"}]
groups["group100"] = [{"include-group": "group0"}]
with pytest.raises(
ValueError,
Expand Down Expand Up @@ -161,3 +161,149 @@ def test_unknown_object_shape(item):
groups = {"test": [item]}
with pytest.raises(ValueError, match="Invalid dependency group item:"):
resolve(groups, "test")


def test_resolve_all_empty():
groups = {}
assert resolve_all(groups) == {}


def test_resolve_all_single_group():
groups = {"test": ["pytest"]}
assert resolve_all(groups) == {"test": ("pytest",)}


def test_resolve_all_multiple_groups():
groups = {
"test": ["pytest", {"include-group": "runtime"}],
"runtime": ["sqlalchemy"],
}
result = resolve_all(groups)
assert "test" in result
assert "runtime" in result
assert set(result["test"]) == {"pytest", "sqlalchemy"}
assert result["runtime"] == ("sqlalchemy",)


def test_resolve_all_with_normalize_false():
groups = {
"TEST": ["pytest"],
}
result = resolve_all(groups, normalize=False)
assert "TEST" in result
assert result["TEST"] == ("pytest",)


def test_resolve_all_with_normalize_true():
groups = {
"TEST": ["pytest"],
}
result = resolve_all(groups, normalize=True)
assert "test" in result
assert result["test"] == ("pytest",)


def test_resolve_all_default_preserves_names():
"""Test that normalize=False is the default."""
groups = {
"TEST": ["pytest"],
}
result = resolve_all(groups)
assert "TEST" in result
assert "test" not in result
assert result["TEST"] == ("pytest",)


def test_resolve_all_shared_includes():
"""Test that resolve_all correctly handles groups with shared includes.

Structure:
dev = [{include-group = "test"}, {include-group = "lint"}]
test = [{include-group = "pytest"}, {include-group = "coverage"}]
lint = ["prek", {include-group = "typing"}]
typing = ["mypy"]
pytest = ["pytest>=7"]
coverage = ["coverage[toml]"]
"""
groups = {
"dev": [{"include-group": "test"}, {"include-group": "lint"}],
"test": [{"include-group": "pytest"}, {"include-group": "coverage"}],
"lint": ["prek", {"include-group": "typing"}],
"typing": ["mypy"],
"pytest": ["pytest>=7"],
"coverage": ["coverage[toml]"],
}

result = resolve_all(groups)

# Verify all groups are present
assert set(result.keys()) == {"dev", "test", "lint", "typing", "pytest", "coverage"}

# Verify each group has correct contents
assert set(result["typing"]) == {"mypy"}
assert set(result["pytest"]) == {"pytest>=7"}
assert set(result["coverage"]) == {"coverage[toml]"}
assert set(result["lint"]) == {"prek", "mypy"}
assert set(result["test"]) == {"pytest>=7", "coverage[toml]"}
assert set(result["dev"]) == {"pytest>=7", "coverage[toml]", "prek", "mypy"}


def test_resolve_all_uses_cache():
"""Test that resolve_all uses the resolver's internal cache properly.

This tests that calling resolve_all uses the optimized resolve_all method
which avoids duplicate work when groups share includes.
"""
from dependency_groups._implementation import DependencyGroupResolver

groups = {
"dev": [{"include-group": "test"}, {"include-group": "lint"}],
"test": [{"include-group": "pytest"}, {"include-group": "coverage"}],
"lint": ["prek", {"include-group": "typing"}],
"typing": ["mypy"],
"pytest": ["pytest>=7"],
"coverage": ["coverage[toml]"],
}

# Create a resolver and check that resolve_all populates the cache
resolver = DependencyGroupResolver(groups)
assert len(resolver._resolve_cache) == 0

resolved = resolver.resolve_all()

# All groups should now be in the cache
assert len(resolver._resolve_cache) == len(groups)
for group in resolver.dependency_groups:
assert group in resolver._resolve_cache

# Verify the results are correct (resolve_all returns Requirement objects)
assert {str(r) for r in resolved["dev"]} == {
"pytest>=7",
"coverage[toml]",
"prek",
"mypy",
}
assert {str(r) for r in resolved["test"]} == {"pytest>=7", "coverage[toml]"}
assert {str(r) for r in resolved["lint"]} == {"prek", "mypy"}


def test_resolve_all_vs_individual_resolve():
"""Test that resolve_all produces the same results as individual resolve calls."""
groups = {
"dev": [{"include-group": "test"}, {"include-group": "lint"}],
"test": [{"include-group": "pytest"}, {"include-group": "coverage"}],
"lint": ["prek", {"include-group": "typing"}],
"typing": ["mypy"],
"pytest": ["pytest>=7"],
"coverage": ["coverage[toml]"],
}

# Using resolve_all
result_all = resolve_all(groups)

# Using individual resolve calls
result_individual = {group: resolve(groups, group) for group in groups}

assert set(result_all.keys()) == set(result_individual.keys())
for group in groups:
assert set(result_all[group]) == set(result_individual[group])
Loading