Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,5 @@ cython_debug/
**.h5

**settings.json

# Test artifacts
test_graph
test_graph_dict
84 changes: 84 additions & 0 deletions docs/source/notebooks/AdvancedGuide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,90 @@
"Note that the groups will automatically order themselves by sorting the group integers, so you can easily pick where each group goes in the params. You can even place groups ahead of group 0 by using negative integers."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Node Collections: `NodeList`, `NodeTuple`, and `NodeDict`\n",
"\n",
"Sometimes you want to group a subset of nodes together without wrapping them inside a new `Module`. `caskade` provides three lightweight collection types for this purpose:\n",
"\n",
"- **`NodeList`** – a mutable, ordered list of nodes (supports `append`, `insert`, `pop`, etc.)\n",
"- **`NodeTuple`** – an immutable, ordered tuple of nodes\n",
"- **`NodeDict`** – a mutable mapping of `str → node`, with attribute-style access\n",
"\n",
"All three behave like their built-in Python counterparts while also being `Node` objects themselves, so they participate fully in the caskade graph.\n",
"\n",
"When you assign a plain Python `list`, `tuple`, or `dict` of nodes as an attribute of a `Module`, `caskade` automatically converts it to the appropriate collection type:\n",
"\n",
"```python\n",
"self.my_params = [p1, p2, p3] # becomes NodeList\n",
"self.my_params = (p1, p2, p3) # becomes NodeTuple\n",
"self.my_params = {'a': p1, 'b': p2} # becomes NodeDict\n",
"```\n",
"\n",
"This means you never have to construct `NodeList` / `NodeTuple` / `NodeDict` explicitly inside a `Module` — just assign the plain collection and caskade handles the rest.\n",
"\n",
"### Use cases\n",
"\n",
"- **Gibbs-style sampling** – keep one `NodeList` of the parameters you currently want to sample (dynamic) and another for the parameters you want to hold fixed (static). Calling `.to_dynamic()` / `.to_static()` on the collection flips the whole subset at once.\n",
"- **Selective saving / updating** – build a `NodeDict` of the parameters you care about and call `get_values()` / `set_values()` on just that subset, without touching the rest of the graph.\n",
"- **Quick inspection** – group modules or params of interest into a collection and iterate over them for printing, plotting, or diagnostics without restructuring the model.\n",
"\n",
"Note that `set_values` and `get_values` also work on Node collections."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Collections can be constructed directly\n",
"G = Gaussian(\"G\", x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)\n",
"\n",
"# NodeList: mutable, ordered\n",
"position_params = ck.NodeList([G.x0, G.y0])\n",
"print(\"NodeList:\", position_params)\n",
"\n",
"# NodeTuple: immutable, ordered\n",
"shape_params = ck.NodeTuple((G.q, G.phi, G.sigma))\n",
"print(\"NodeTuple:\", shape_params)\n",
"\n",
"# NodeDict: mutable, keyed\n",
"named_params = ck.NodeDict({\"x0\": G.x0, \"y0\": G.y0, \"sigma\": G.sigma})\n",
"print(\"NodeDict:\", named_params)\n",
"print(\"Attribute access:\", named_params[\"x0\"]) # same as named_params.x0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Auto-conversion when assigned to a Module attribute\n",
"class GaussianWithCollections(ck.Module):\n",
" def __init__(self, name, submod):\n",
" super().__init__(name)\n",
" self.g = submod\n",
" # plain list/dict – caskade converts them automatically\n",
" self.position = [submod.x0, submod.y0] # -> NodeList\n",
" self.named = {\"q\": submod.q, \"phi\": submod.phi} # -> NodeDict\n",
"\n",
" @ck.forward\n",
" def __call__(self, x, y):\n",
" return self.g(x, y)\n",
"\n",
"\n",
"G2 = Gaussian(\"G2\", x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)\n",
"gc = GaussianWithCollections(\"gc\", G2)\n",
"print(type(gc.position)) # NodeList\n",
"print(type(gc.named)) # NodeDict\n",
"gc.position.to_dynamic() # can call collection methods\n",
"display(gc.graphviz())"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
3 changes: 2 additions & 1 deletion src/caskade/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .decorators import forward, active_cache
from .module import Module
from .param import Param
from .collection import NodeCollection, NodeList, NodeTuple
from .collection import NodeCollection, NodeList, NodeTuple, NodeDict
from .tests import test
from .errors import (
CaskadeException,
Expand Down Expand Up @@ -58,6 +58,7 @@
"NodeCollection",
"NodeList",
"NodeTuple",
"NodeDict",
"ActiveContext",
"ValidContext",
"OverrideParam",
Expand Down
157 changes: 139 additions & 18 deletions src/caskade/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,39 +208,51 @@ def _link_nodes(self):
def append(self, node):
"""Append a node to the list and update graph links."""
self._unlink_nodes()
super().append(node)
self._link_nodes()
try:
super().append(node)
finally:
self._link_nodes()

def insert(self, index, node):
"""Insert a node at the given index and update graph links."""
self._unlink_nodes()
super().insert(index, node)
self._link_nodes()
try:
super().insert(index, node)
finally:
self._link_nodes()

def extend(self, iterable):
"""Extend the list with nodes from an iterable and update graph links."""
self._unlink_nodes()
super().extend(iterable)
self._link_nodes()
try:
super().extend(iterable)
finally:
self._link_nodes()

def clear(self):
"""Remove all nodes from the list and update graph links."""
self._unlink_nodes()
super().clear()
self._link_nodes()
try:
super().clear()
finally:
self._link_nodes()

def pop(self, index=-1):
"""Remove and return a node at the given index, updating graph links."""
self._unlink_nodes()
node = super().pop(index)
self._link_nodes()
try:
node = super().pop(index)
finally:
self._link_nodes()
return node

def remove(self, value):
"""Remove the first occurrence of a node and update graph links."""
self._unlink_nodes()
super().remove(value)
self._link_nodes()
try:
super().remove(value)
finally:
self._link_nodes()

def __getitem__(self, key):
if isinstance(key, str):
Expand All @@ -251,23 +263,132 @@ def __getitem__(self, key):

def __setitem__(self, key, value):
self._unlink_nodes()
super().__setitem__(key, value)
self._link_nodes()
try:
super().__setitem__(key, value)
finally:
self._link_nodes()

def __delitem__(self, key):
self._unlink_nodes()
super().__delitem__(key)
self._link_nodes()
try:
super().__delitem__(key)
finally:
self._link_nodes()

def __add__(self, other):
res = super().__add__(other)
return NodeList(res, name=self.name)

def __iadd__(self, other):
self._unlink_nodes()
ret = super().__iadd__(other)
self._link_nodes()
try:
ret = super().__iadd__(other)
finally:
self._link_nodes()
return ret

def __imul__(self, other):
raise NotImplementedError


class NodeDict(NodeCollection, dict):
"""Mutable, keyed collection of nodes.

Behaves like a standard ``dict`` but also participates in the caskade
node graph. All elements must be ``Node`` instances. Graph links are
automatically updated whenever the dict is modified.

Parameters
----------
mapping : mapping of str to Node, optional
Nodes to include in the dict. Defaults to an empty dict.
name : str, optional
Human-readable name for this collection of nodes.
"""

def __init__(self, mapping=None, name=None):
if mapping is None:
mapping = {}
dict.__init__(self, mapping)
Node.__init__(self, name=name)
self.node_type = "ndict"
self._link_nodes()

@property
def graphviz_style(self):
return {"style": "solid", "color": "black", "shape": "component"}

@property
def dynamic(self):
return any(node.dynamic for node in dict.values(self))

def _unlink_nodes(self):
for node in dict.values(self):
self.unlink(node)

def _link_nodes(self):
for key, node in dict.items(self):
if not isinstance(node, Node):
raise TypeError(f"NodeDict values must be Node objects, not {type(node)}")
self.link(key, node)

def __getitem__(self, key):
return dict.__getitem__(self, key)

def __setitem__(self, key, node):
self._unlink_nodes()
try:
dict.__setitem__(self, key, node)
finally:
self._link_nodes()

def __delitem__(self, key):
self._unlink_nodes()
try:
dict.__delitem__(self, key)
finally:
self._link_nodes()

def update(self, mapping=None, **kwargs):
"""Update the dict with another mapping (i.e. dict) and update graph links."""
self._unlink_nodes()
try:
if mapping is not None:
dict.update(self, mapping)
if kwargs:
dict.update(self, kwargs)
finally:
self._link_nodes()

def pop(self, key, *args):
"""Remove and return a node from the dict and update graph links."""
self._unlink_nodes()
try:
node = dict.pop(self, key, *args)
finally:
self._link_nodes()
return node

def popitem(self):
"""Remove and return an arbitrary (key, node) pair from the dict (the last one inserted) and update graph links."""
self._unlink_nodes()
try:
key, node = dict.popitem(self)
finally:
self._link_nodes()
return key, node

def clear(self):
"""Remove all nodes from the dict and update graph links."""
self._unlink_nodes()
dict.clear(self)

def setdefault(self, key, default):
"""If key is in the dictionary, return its value. If not, insert key with a value of default and return default. Update graph links."""
# Preserve dict.setdefault API shape but enforce NodeDict invariants
if key in self:
return self[key]
if not isinstance(default, Node):
raise TypeError(f"NodeDict values must be Node objects, not {type(default)}")
self[key] = default
return default
5 changes: 4 additions & 1 deletion src/caskade/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .backend import ArrayLike
from .base import Node
from .param import Param
from .collection import NodeTuple, NodeList
from .collection import NodeTuple, NodeList, NodeDict
from .mixins import GetSetValues
from .errors import ActiveStateError, FillParamsError

Expand Down Expand Up @@ -281,6 +281,9 @@ def __setattr__(self, key: str, value: Any):
):
if len(value) > 0 and all(isinstance(v, Node) for v in value):
value = NodeTuple(value, name=key)
elif isinstance(value, dict) and not isinstance(value, NodeDict):
if len(value) > 0 and all(isinstance(v, Node) for v in value.values()):
value = NodeDict(value, name=key)
except AttributeError:
pass
super().__setattr__(key, value)
Loading
Loading