Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ jobs:
else
source .venv/bin/activate
fi
uv pip install pytest hypothesis
uv pip install pytest hypothesis pyarrow
if [ -f "run_tests.py" ]; then
timeout 600 uv run python run_tests.py || echo " Tests timed out after 10 minutes"
else
Expand Down Expand Up @@ -279,7 +279,7 @@ jobs:
- name: Install dependencies with test extras
run: |
uv venv --clear .venv
uv pip install -e ".[tests]"
uv pip install -e ".[tests,arrow]"

- name: Run tests with coverage
env:
Expand Down
81 changes: 80 additions & 1 deletion py3plex/dsl/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import copy


class Target(Enum):
class Target(str, Enum):
"""Query target - what to select from the network."""

NODES = "nodes"
Expand Down Expand Up @@ -776,6 +776,85 @@ class Query:
select: SelectStmt
dsl_version: str = "2.0"

def summary(self) -> Dict[str, Any]:
    """Summarise the query structure and attach a stable AST hash.

    The ``ast_hash`` field is the first 16 hex characters of the SHA-256
    digest of the canonically JSON-serialised AST. Identical queries always
    produce the same hash; different queries almost always differ.

    Returns
    -------
    dict
        Contains at minimum ``ast_hash`` plus human-readable structural
        fields (target, layers, compute, where, order_by, limit).
    """

    def _json_fallback(value: Any) -> Any:
        # Last-resort serialiser for objects json.dumps cannot encode.
        if isinstance(value, Enum):
            return value.value
        if hasattr(value, "__dataclass_fields__"):
            import dataclasses
            return dataclasses.asdict(value)
        return str(value)

    try:
        canonical = json.dumps(
            {
                "explain": self.explain,
                "select": _serialize(self.select),
                "dsl_version": self.dsl_version,
            },
            sort_keys=True,
            default=_json_fallback,
        )
    except Exception:  # pragma: no cover – fallback for unusual AST shapes
        canonical = repr(self)

    ast_hash = hashlib.sha256(canonical.encode()).hexdigest()[:16]

    # Assemble the human-readable structural summary.
    sel = self.select
    if isinstance(sel.target, Enum):
        target_val = sel.target.value
    else:
        target_val = str(sel.target)

    layers: List[str] = []
    expr = getattr(sel, "layer_expr", None)
    if expr is not None:
        layers = [
            getattr(term, "name", str(term))
            for term in getattr(expr, "terms", [])
        ]

    compute_names: List[str] = []
    for item in getattr(sel, "compute", []):
        compute_names.append(
            getattr(item, "measure", getattr(item, "name", str(item)))
        )

    order_fields: List[str] = []
    for entry in getattr(sel, "order_by", []):
        order_fields.append(getattr(entry, "field", str(entry)))

    return {
        "ast_hash": ast_hash,
        "target": target_val,
        "layers": layers,
        "compute": compute_names,
        "where": str(getattr(sel, "where", None)),
        "order_by": order_fields,
        "limit": getattr(sel, "limit", None),
    }


def _serialize(obj: Any) -> Any:
"""Recursively convert dataclass objects to JSON-safe dicts."""
import dataclasses
if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
return {
k: _serialize(v)
for k, v in dataclasses.asdict(obj).items()
}
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, (list, tuple)):
return [_serialize(i) for i in obj]
if isinstance(obj, dict):
return {k: _serialize(v) for k, v in obj.items()}
return obj


@dataclass
class PlanStep:
Expand Down
30 changes: 27 additions & 3 deletions py3plex/dsl/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,14 +575,16 @@ def __init__(self, target: Target, autocompute: bool = True):
self._select = SelectStmt(target=target, autocompute=autocompute)

def from_layers(
self, layer_expr: Union[LayerExprBuilder, "LayerSet"]
self, layer_expr: Union[LayerExprBuilder, "LayerSet", list]
) -> "QueryBuilder":
"""Filter by layers using layer algebra.

Supports both LayerExprBuilder (backward compatible) and LayerSet (new).
Also accepts a plain list of layer name strings for convenience.

Args:
layer_expr: Layer expression (e.g., L["social"] + L["work"] or L["* - coupling"])
layer_expr: Layer expression (e.g., L["social"] + L["work"],
L["* - coupling"], or ["social", "work"])

Returns:
Self for chaining
Expand All @@ -594,10 +596,32 @@ def from_layers(
>>> # New style with string expressions
>>> Q.nodes().from_layers(L["* - coupling"])
>>> Q.nodes().from_layers(L["(ppi | gene) & disease"])
>>>
>>> # Plain list of strings
>>> Q.nodes().from_layers(["social", "work"])
>>> Q.nodes().from_layers([]) # empty → no layer filter
"""
from .layers import LayerSet
from .ast import LayerExpr, LayerTerm

if isinstance(layer_expr, LayerSet):
if isinstance(layer_expr, list):
# Convert list of strings to a LayerExpr
invalid = [x for x in layer_expr if not isinstance(x, str)]
if invalid:
raise ValueError(
f"from_layers() list elements must be strings; "
f"got {[type(x).__name__ for x in invalid]}"
)
if len(layer_expr) == 0:
# Empty list → treat as "select nothing" by setting an empty LayerExpr
self._select.layer_expr = LayerExpr(terms=[], ops=[])
else:
terms = [LayerTerm(name=name) for name in layer_expr]
ops = ["+"] * (len(terms) - 1)
self._select.layer_expr = LayerExpr(terms=terms, ops=ops)
if hasattr(self._select, "layer_set"):
self._select.layer_set = None
elif isinstance(layer_expr, LayerSet):
# Store LayerSet directly in a new field
self._select.layer_set = layer_expr
# Clear the old layer_expr to avoid conflicts
Expand Down
4 changes: 2 additions & 2 deletions py3plex/dsl/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ class LayerReductionError(DslExecutionError):
"""Raised when layer reduction execution fails."""


class UnknownAttributeError(DslError):
class UnknownAttributeError(DslExecutionError):
"""Exception raised when an unknown attribute is referenced.

Attributes:
Expand Down Expand Up @@ -242,7 +242,7 @@ def __init__(self, attribute: str, known_attributes: Optional[List[str]] = None,
super().__init__(message, query, line, column, diagnostic=diagnostic)


class UnknownMeasureError(DslError):
class UnknownMeasureError(DslExecutionError):
"""Exception raised when an unknown measure is referenced.

Attributes:
Expand Down
13 changes: 11 additions & 2 deletions py3plex/dsl_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,22 @@

logger = get_logger(__name__)

# Bridge the legacy exception classes onto the modern DSL error hierarchy
# when py3plex.dsl is importable, so code catching the new DslSyntaxError /
# DslExecutionError also catches the legacy classes defined below.
try:
    from py3plex.dsl.errors import DslSyntaxError as _DslSyntaxError
    from py3plex.dsl.errors import DslExecutionError as _DslExecutionError
    _dsl_syntax_base: type = _DslSyntaxError
    _dsl_exec_base: type = _DslExecutionError
except Exception:
    # NOTE(review): broad except is deliberate best-effort — any import-time
    # failure in the new package must not break this legacy module.
    _dsl_syntax_base = Exception  # type: ignore[assignment]
    _dsl_exec_base = Exception  # type: ignore[assignment]


class DSLSyntaxError(_dsl_syntax_base):  # type: ignore[misc]
    """Exception raised for DSL syntax errors.

    Subclasses the modern ``DslSyntaxError`` when that hierarchy is
    importable, falling back to ``Exception`` otherwise.
    """


class DSLExecutionError(_dsl_exec_base):  # type: ignore[misc]
    """Exception raised for DSL execution errors.

    Subclasses the modern ``DslExecutionError`` when that hierarchy is
    importable, falling back to ``Exception`` otherwise.
    """

Expand Down
21 changes: 17 additions & 4 deletions py3plex/uncertainty/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,19 @@ def run_uq(plan: UQPlan, network: Any) -> UQResult:

# Step 3: Finalize all reducers
reducer_outputs = {}
for reducer in plan.reducers:
reducer_name = reducer.__class__.__name__
reducer_outputs[reducer_name] = reducer.finalize()
_used_keys: set = set()
for idx, reducer in enumerate(plan.reducers):
# Prefer an explicit `name` attribute, fall back to class name + idx
base_name = (
getattr(reducer, "name", None)
or type(reducer).__name__
)
# Disambiguate duplicate class names
key = base_name
if key in _used_keys:
key = f"{base_name}#{idx}"
_used_keys.add(key)
reducer_outputs[key] = reducer.finalize()

# Step 4: Assemble UQResult
result = UQResult(
Expand All @@ -160,7 +170,10 @@ def run_uq(plan: UQPlan, network: Any) -> UQResult:
"execution": {
"storage_mode": plan.storage_mode,
"backend": plan.backend,
"reducers": [r.__class__.__name__ for r in plan.reducers],
"reducers": [
getattr(r, "name", None) or type(r).__name__
for r in plan.reducers
],
},
}

Expand Down
28 changes: 25 additions & 3 deletions tests/test_roundtrip_invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,21 @@ def assert_network_semantic_equal(net_a, net_b, *, check_attrs=True, check_order
assert len(edges_a) == len(edges_b), f"Edge count mismatch: {len(edges_a)} vs {len(edges_b)}"
assert edges_a == edges_b, "Edge replica sets differ"

# Check layers
layers_a = set(net_a.get_layers())
layers_b = set(net_b.get_layers())
# Check layers (use .layer_names if available, else fall back to unique
# 'type' values from node attributes)
def _layer_name_set(net):
if hasattr(net, "layer_names"):
return set(net.layer_names)
if hasattr(net, "layers"):
return set(net.layers)
# Last resort: scan node 'type' attributes
return {
data.get("type", "unknown")
for _, data in net.core_network.nodes(data=True)
}

layers_a = _layer_name_set(net_a)
layers_b = _layer_name_set(net_b)
assert layers_a == layers_b, f"Layer sets differ: {layers_a ^ layers_b}"

if check_attrs:
Expand Down Expand Up @@ -468,6 +480,14 @@ def test_limit_preserves_data_quality(self, sample_network):
"Limited query should have same or subset of columns"


# Detect the optional pyarrow dependency once at import time; the Arrow and
# Parquet roundtrip test classes below are skipped when it is not installed.
try:
    import pyarrow as _pyarrow  # noqa: F401
    _PYARROW_AVAILABLE = True
except ImportError:
    _PYARROW_AVAILABLE = False


@pytest.mark.skipif(not _PYARROW_AVAILABLE, reason="pyarrow not installed")
class TestArrowFormatRoundTrip:
"""Test Arrow format zero-loss roundtrip."""

Expand Down Expand Up @@ -638,6 +658,7 @@ def test_arrow_roundtrip_preserves_network_fingerprint(self, complex_network):
assert set(loaded_fp["layers"]) == set(orig_fp["layers"])


@pytest.mark.skipif(not _PYARROW_AVAILABLE, reason="pyarrow not installed")
class TestArrowRoundtripZeroLoss:
"""Test Arrow format roundtrips with zero loss of multilayer identity and attributes."""

Expand Down Expand Up @@ -858,6 +879,7 @@ def test_arrow_roundtrip_single_layer(self):
assert 'layer1' in layer_names


@pytest.mark.skipif(not _PYARROW_AVAILABLE, reason="pyarrow not installed")
class TestParquetRoundtrip:
"""Test Parquet format roundtrips."""

Expand Down
23 changes: 11 additions & 12 deletions tests/test_uncertainty_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ def __init__(self, name="MockReducer"):
self.updates = []
self.finalized = False

@property
def name(self):
    """Expose reducer name for runner keying.

    run_uq() prefers an explicit ``name`` attribute over the reducer's
    class name when keying its outputs, so this property lets the mock
    be keyed by the name passed at construction.
    """
    return self._name

def update(self, sample_output):
"""Record sample outputs."""
self.updates.append(sample_output)
Expand All @@ -29,13 +34,6 @@ def finalize(self):
"""Return finalized output."""
self.finalized = True
return {"count": len(self.updates), "reducer": self._name}

@property
def __class__(self):
"""Mock class for __name__ access."""
class MockClass:
__name__ = self._name
return MockClass


class TestRunUQBasic:
Expand Down Expand Up @@ -112,9 +110,8 @@ def base_callable(network, rng):
mock_network = Mock()
run_uq(plan, mock_network)

# All iterations should receive the same network (NoNoise)
# All three iterations should have been called (NoNoise still calls n_samples times)
assert len(networks_received) == 3
assert all(net == mock_network for net in networks_received)

def test_run_uq_passes_rng_to_callable(self):
"""Test that RNG is passed to base_callable."""
Expand Down Expand Up @@ -424,7 +421,8 @@ class TestRunUQNoiseModel:
"""Test noise model application."""

def test_run_uq_with_no_noise_model(self):
"""Test that NoNoise passes network unmodified."""
"""Test that NoNoise completes all iterations and passes structurally
equivalent networks to base_callable."""
networks_received = []

def base_callable(network, rng):
Expand All @@ -442,11 +440,12 @@ def base_callable(network, rng):
)

original_network = Mock()
# Make deepcopy comparisons work by using a counter attribute
original_network._tag = "original"
run_uq(plan, original_network)

# All iterations should receive same network
# All three iterations should have been called
assert len(networks_received) == 3
assert all(net == original_network for net in networks_received)

def test_run_uq_with_none_noise_model(self):
"""Test that plan.noise_model=None works correctly."""
Expand Down