From 4ad8093479a09f0606fec1f6f6788a4fec956fad Mon Sep 17 00:00:00 2001 From: thomasloux Date: Sun, 1 Mar 2026 11:48:04 +0000 Subject: [PATCH 1/3] remove count_dof --- tests/test_constraints.py | 34 ---------------------------------- torch_sim/constraints.py | 29 ----------------------------- 2 files changed, 63 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index a26c93bc..bcc3368c 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -9,7 +9,6 @@ Constraint, FixAtoms, FixCom, - count_degrees_of_freedom, merge_constraints, validate_constraints, ) @@ -669,39 +668,6 @@ def test_multiple_constraints_and_dof( ([FixCom([0]), FixAtoms(atom_idx=[0])], 6), ], ) -def test_count_degrees_of_freedom_single_system( - cu_sim_state: ts.SimState, constraint_list: list[Constraint], removed_dof: int -) -> None: - """count_degrees_of_freedom returns expected scalar for one system.""" - total_dof = 3 * cu_sim_state.n_atoms - computed_dof = count_degrees_of_freedom(cu_sim_state, constraint_list) - assert computed_dof == total_dof - removed_dof - - -def test_count_degrees_of_freedom_multi_system_sum( - mixed_double_sim_state: ts.SimState, -) -> None: - """count_degrees_of_freedom correctly sums removed dof across systems.""" - n_atoms_in_first_system = int(mixed_double_sim_state.n_atoms_per_system[0].item()) - constraint_list: list[Constraint] = [ - FixCom([0, 1]), - FixAtoms(atom_idx=[0, n_atoms_in_first_system]), - ] - total_dof = 3 * mixed_double_sim_state.n_atoms - computed_dof = count_degrees_of_freedom(mixed_double_sim_state, constraint_list) - assert computed_dof == total_dof - 12 - - -def test_count_degrees_of_freedom_clamped_to_zero( - cu_sim_state: ts.SimState, -) -> None: - """count_degrees_of_freedom never returns a negative value.""" - all_atom_indices = torch.arange(cu_sim_state.n_atoms, device=cu_sim_state.device) - constraint_list: list[Constraint] = [FixAtoms(atom_idx=all_atom_indices), FixCom([0])] - computed_dof = count_degrees_of_freedom(cu_sim_state, constraint_list) - assert computed_dof == 0 - - @pytest.mark.parametrize( ("cell_filter", "fire_flavor"), [ diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 239f6c8d..b5f5ea05 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -610,35 +610,6 @@ def __repr__(self) -> str: return f"FixCom(system_idx={self.system_idx})" -def count_degrees_of_freedom( - state: SimState, constraints: list[Constraint] | None = None -) -> int: - """Count the total degrees of freedom in a system with constraints. - - This function calculates the total number of degrees of freedom by starting - with the unconstrained count (n_atoms * 3) and subtracting the degrees of - freedom removed by each constraint. - - Args: - state: Simulation state - constraints: List of active constraints (optional) - - Returns: - Total number of degrees of freedom - """ - # Start with unconstrained DOF - total_dof: int | torch.Tensor = state.n_atoms * 3 - - # Subtract DOF removed by constraints (get_removed_dof returns per-system tensor) - if constraints is not None: - for constraint in constraints: - removed = constraint.get_removed_dof(state) - total_dof = total_dof - removed.sum() - - result = max(0, total_dof) - return int(result.item()) if isinstance(result, torch.Tensor) else result - - def check_no_index_out_of_bounds( indices: torch.Tensor, max_state_indices: int, constraint_name: str ) -> None: From c2aa80c14ec35e2f627faeb9f729ffe99e65627f Mon Sep 17 00:00:00 2001 From: thomasloux Date: Sun, 1 Mar 2026 11:55:01 +0000 Subject: [PATCH 2/3] remove parametrization that was used with a erased test --- tests/test_constraints.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index bcc3368c..a315afd7 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -658,16 +658,6 @@ def test_multiple_constraints_and_dof( ) assert torch.allclose(final_com, initial_com, atol=1e-5) - -@pytest.mark.parametrize( - ("constraint_list", "removed_dof"), - [ - ([], 0), - ([FixAtoms(atom_idx=[0, 1])], 6), - ([FixCom([0])], 3), - ([FixCom([0]), FixAtoms(atom_idx=[0])], 6), - ], -) @pytest.mark.parametrize( ("cell_filter", "fire_flavor"), [ From bee056b48768e9d9845bfc88e1f8b7907e7db363 Mon Sep 17 00:00:00 2001 From: thomasloux Date: Sun, 1 Mar 2026 12:05:56 +0000 Subject: [PATCH 3/3] lint --- tests/test_constraints.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_constraints.py b/tests/test_constraints.py index a315afd7..4e67209b 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -658,6 +658,7 @@ def test_multiple_constraints_and_dof( ) assert torch.allclose(final_com, initial_com, atol=1e-5) + @pytest.mark.parametrize( ("cell_filter", "fire_flavor"), [