diff --git a/.gitignore b/.gitignore index 589cb21..ec2786b 100644 --- a/.gitignore +++ b/.gitignore @@ -166,6 +166,5 @@ cython_debug/ **.h5 **settings.json - -# Test artifacts test_graph +test_graph_dict diff --git a/docs/source/notebooks/AdvancedGuide.ipynb b/docs/source/notebooks/AdvancedGuide.ipynb index bd46ab3..9e4fe73 100644 --- a/docs/source/notebooks/AdvancedGuide.ipynb +++ b/docs/source/notebooks/AdvancedGuide.ipynb @@ -702,6 +702,90 @@ "Note that the groups will automatically order themselves by sorting the group integers, so you can easily pick where each group goes in the params. You can even place groups ahead of group 0 by using negative integers." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Node Collections: `NodeList`, `NodeTuple`, and `NodeDict`\n", + "\n", + "Sometimes you want to group a subset of nodes together without wrapping them inside a new `Module`. `caskade` provides three lightweight collection types for this purpose:\n", + "\n", + "- **`NodeList`** – a mutable, ordered list of nodes (supports `append`, `insert`, `pop`, etc.)\n", + "- **`NodeTuple`** – an immutable, ordered tuple of nodes\n", + "- **`NodeDict`** – a mutable mapping of `str → node`, with attribute-style access\n", + "\n", + "All three behave like their built-in Python counterparts while also being `Node` objects themselves, so they participate fully in the caskade graph.\n", + "\n", + "When you assign a plain Python `list`, `tuple`, or `dict` of nodes as an attribute of a `Module`, `caskade` automatically converts it to the appropriate collection type:\n", + "\n", + "```python\n", + "self.my_params = [p1, p2, p3] # becomes NodeList\n", + "self.my_params = (p1, p2, p3) # becomes NodeTuple\n", + "self.my_params = {'a': p1, 'b': p2} # becomes NodeDict\n", + "```\n", + "\n", + "This means you never have to construct `NodeList` / `NodeTuple` / `NodeDict` explicitly inside a `Module` — just assign the plain collection and caskade handles the rest.\n", + "\n", + "### Use cases\n", + "\n", + "- **Gibbs-style sampling** – keep one `NodeList` of the parameters you currently want to sample (dynamic) and another for the parameters you want to hold fixed (static). Calling `.to_dynamic()` / `.to_static()` on the collection flips the whole subset at once.\n", + "- **Selective saving / updating** – build a `NodeDict` of the parameters you care about and call `get_values()` / `set_values()` on just that subset, without touching the rest of the graph.\n", + "- **Quick inspection** – group modules or params of interest into a collection and iterate over them for printing, plotting, or diagnostics without restructuring the model.\n", + "\n", + "Note that `set_values` and `get_values` also work on Node collections." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Collections can be constructed directly\n", + "G = Gaussian(\"G\", x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)\n", + "\n", + "# NodeList: mutable, ordered\n", + "position_params = ck.NodeList([G.x0, G.y0])\n", + "print(\"NodeList:\", position_params)\n", + "\n", + "# NodeTuple: immutable, ordered\n", + "shape_params = ck.NodeTuple((G.q, G.phi, G.sigma))\n", + "print(\"NodeTuple:\", shape_params)\n", + "\n", + "# NodeDict: mutable, keyed\n", + "named_params = ck.NodeDict({\"x0\": G.x0, \"y0\": G.y0, \"sigma\": G.sigma})\n", + "print(\"NodeDict:\", named_params)\n", + "print(\"Attribute access:\", named_params[\"x0\"]) # same as named_params.x0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Auto-conversion when assigned to a Module attribute\n", + "class GaussianWithCollections(ck.Module):\n", + " def __init__(self, name, submod):\n", + " super().__init__(name)\n", + " self.g = submod\n", + " # plain list/dict – caskade converts them automatically\n", + " self.position = [submod.x0, submod.y0] # -> NodeList\n", + " self.named = {\"q\": submod.q, \"phi\": submod.phi} # -> NodeDict\n", + "\n", + " @ck.forward\n", + " def __call__(self, x, y):\n", + " return self.g(x, y)\n", + "\n", + "\n", + "G2 = Gaussian(\"G2\", x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)\n", + "gc = GaussianWithCollections(\"gc\", G2)\n", + "print(type(gc.position)) # NodeList\n", + "print(type(gc.named)) # NodeDict\n", + "gc.position.to_dynamic() # can call collection methods\n", + "display(gc.graphviz())" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/caskade/__init__.py b/src/caskade/__init__.py index 58ed274..cb22394 100644 --- a/src/caskade/__init__.py +++ b/src/caskade/__init__.py @@ -25,7 +25,7 @@ from .decorators import forward, active_cache from .module import Module from .param import Param -from .collection import NodeCollection, NodeList, NodeTuple +from .collection import NodeCollection, NodeList, NodeTuple, NodeDict from .tests import test from .errors import ( CaskadeException, @@ -58,6 +58,7 @@ "NodeCollection", "NodeList", "NodeTuple", + "NodeDict", "ActiveContext", "ValidContext", "OverrideParam", diff --git a/src/caskade/collection.py b/src/caskade/collection.py index 782ceb9..c8bdddb 100644 --- a/src/caskade/collection.py +++ b/src/caskade/collection.py @@ -208,39 +208,51 @@ def _link_nodes(self): def append(self, node): """Append a node to the list and update graph links.""" self._unlink_nodes() - super().append(node) - self._link_nodes() + try: + super().append(node) + finally: + self._link_nodes() def insert(self, index, node): """Insert a node at the given index and update graph links.""" self._unlink_nodes() - super().insert(index, node) - self._link_nodes() + try: + super().insert(index, node) + finally: + self._link_nodes() def extend(self, iterable): """Extend the list with nodes from an iterable and update graph links.""" self._unlink_nodes() - super().extend(iterable) - self._link_nodes() + try: + super().extend(iterable) + finally: + self._link_nodes() def clear(self): """Remove all nodes from the list and update graph links.""" self._unlink_nodes() - super().clear() - self._link_nodes() + try: + super().clear() + finally: + self._link_nodes() def pop(self, index=-1): """Remove and return a node at the given index, updating graph links.""" self._unlink_nodes() - node = super().pop(index) - self._link_nodes() + try: + node = super().pop(index) + finally: + self._link_nodes() return node def remove(self, value): """Remove the first occurrence of a node and update graph links.""" self._unlink_nodes() - super().remove(value) - self._link_nodes() + try: + super().remove(value) + finally: + self._link_nodes() def __getitem__(self, key): if isinstance(key, str): @@ -251,13 +263,17 @@ def __getitem__(self, key): def __setitem__(self, key, value): self._unlink_nodes() - super().__setitem__(key, value) - self._link_nodes() + try: + super().__setitem__(key, value) + finally: + self._link_nodes() def __delitem__(self, key): self._unlink_nodes() - super().__delitem__(key) - self._link_nodes() + try: + super().__delitem__(key) + finally: + self._link_nodes() def __add__(self, other): res = super().__add__(other) @@ -265,9 +281,114 @@ def __add__(self, other): def __iadd__(self, other): self._unlink_nodes() - ret = super().__iadd__(other) - self._link_nodes() + try: + ret = super().__iadd__(other) + finally: + self._link_nodes() return ret def __imul__(self, other): raise NotImplementedError + + +class NodeDict(NodeCollection, dict): + """Mutable, keyed collection of nodes. + + Behaves like a standard ``dict`` but also participates in the caskade + node graph. All elements must be ``Node`` instances. Graph links are + automatically updated whenever the dict is modified. + + Parameters + ---------- + mapping : mapping of str to Node, optional + Nodes to include in the dict. Defaults to an empty dict. + name : str, optional + Human-readable name for this collection of nodes. + """ + + def __init__(self, mapping=None, name=None): + if mapping is None: + mapping = {} + dict.__init__(self, mapping) + Node.__init__(self, name=name) + self.node_type = "ndict" + self._link_nodes() + + @property + def graphviz_style(self): + return {"style": "solid", "color": "black", "shape": "component"} + + @property + def dynamic(self): + return any(node.dynamic for node in dict.values(self)) + + def _unlink_nodes(self): + for node in dict.values(self): + self.unlink(node) + + def _link_nodes(self): + for key, node in dict.items(self): + if not isinstance(node, Node): + raise TypeError(f"NodeDict values must be Node objects, not {type(node)}") + self.link(key, node) + + def __getitem__(self, key): + return dict.__getitem__(self, key) + + def __setitem__(self, key, node): + self._unlink_nodes() + try: + dict.__setitem__(self, key, node) + finally: + self._link_nodes() + + def __delitem__(self, key): + self._unlink_nodes() + try: + dict.__delitem__(self, key) + finally: + self._link_nodes() + + def update(self, mapping=None, **kwargs): + """Update the dict with another mapping (i.e. dict) and update graph links.""" + self._unlink_nodes() + try: + if mapping is not None: + dict.update(self, mapping) + if kwargs: + dict.update(self, kwargs) + finally: + self._link_nodes() + + def pop(self, key, *args): + """Remove and return a node from the dict and update graph links.""" + self._unlink_nodes() + try: + node = dict.pop(self, key, *args) + finally: + self._link_nodes() + return node + + def popitem(self): + """Remove and return an arbitrary (key, node) pair from the dict (the last one inserted) and update graph links.""" + self._unlink_nodes() + try: + key, node = dict.popitem(self) + finally: + self._link_nodes() + return key, node + + def clear(self): + """Remove all nodes from the dict and update graph links.""" + self._unlink_nodes() + dict.clear(self) + + def setdefault(self, key, default): + """If key is in the dictionary, return its value. If not, insert key with a value of default and return default. Update graph links.""" + # Preserve dict.setdefault API shape but enforce NodeDict invariants + if key in self: + return self[key] + if not isinstance(default, Node): + raise TypeError(f"NodeDict values must be Node objects, not {type(default)}") + self[key] = default + return default diff --git a/src/caskade/module.py b/src/caskade/module.py index 64d4fe0..146f5fd 100644 --- a/src/caskade/module.py +++ b/src/caskade/module.py @@ -3,7 +3,7 @@ from .backend import ArrayLike from .base import Node from .param import Param -from .collection import NodeTuple, NodeList +from .collection import NodeTuple, NodeList, NodeDict from .mixins import GetSetValues from .errors import ActiveStateError, FillParamsError @@ -281,6 +281,9 @@ def __setattr__(self, key: str, value: Any): ): if len(value) > 0 and all(isinstance(v, Node) for v in value): value = NodeTuple(value, name=key) + elif isinstance(value, dict) and not isinstance(value, NodeDict): + if len(value) > 0 and all(isinstance(v, Node) for v in value.values()): + value = NodeDict(value, name=key) except AttributeError: pass super().__setattr__(key, value) diff --git a/tests/test_collection.py b/tests/test_collection.py index c02e259..f41a2b8 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -4,6 +4,7 @@ from caskade import ( NodeList, NodeTuple, + NodeDict, Param, Module, backend, @@ -172,15 +173,19 @@ def test_collection_in_module(): l1 = [Param("ptest1"), Param("ptest2"), Module("mtest1"), Module("mtest2")] t1 = (Param("ptest3"), Param("ptest4"), Module("mtest3"), Module("mtest4")) + d1 = {"ptest5": Param("ptest5"), "ptest6": Param("ptest6"), "mtest5": Module("mtest5")} m1 = Module("test") m1.l = l1 m1.t = t1 + m1.d = d1 assert m1["l"][2] == l1[2] assert m1["t"][2] == t1[2] assert m1.l[3] == l1[3] assert m1.t[3] == t1[3] + assert m1["d"]["ptest5"] == d1["ptest5"] + assert m1.d["mtest5"] == d1["mtest5"] @pytest.mark.parametrize("node_type", [NodeTuple, NodeList]) @@ -277,3 +282,140 @@ def test_valid_tuple(node_tuple, params_type, group): for i in range(len(node_tuple.dynamic_param_groups)): assert backend.module.allclose(init_params[i], round_trip_params[i]) assert backend.module.allclose(init_params[i], final_params[i]) + + +def test_node_dict_creation(): + + # Minimal creation + n1 = NodeDict() + assert n1.name.startswith("NodeDict") + assert len(n1) == 0 + + # Creation with dict of param nodes + params = {"p1": Param("p1"), "p2": Param("p2")} + n2 = NodeDict(params) + assert len(n2) == 2 + assert n2["p1"] is params["p1"] + assert n2.p1 is params["p1"] + assert n2["p2"] is params["p2"] + + # Creation with dict of module nodes + modules = {"m1": Module("m1"), "m2": Module("m2"), "m3": Module("m3")} + n3 = NodeDict(modules) + assert len(n3) == 3 + assert n3["m1"] is modules["m1"] + assert n3.m1 is modules["m1"] + assert n3["m2"] is modules["m2"] + + # Check repr + assert isinstance(repr(n3), str) + assert "[3]" in repr(n3) + + # Check to static/dynamic + n2.to_dynamic(False) + assert len(n2.static_params) == 0 + n2.to_static(False) + assert len(n2.static_params) == 2 + assert len(n2.pointer_params) == 0 + + # Graphviz + graph = n3.graphviz(saveto="test_graph_dict.pdf") + assert graph is not None, "should return a graphviz object" + assert os.path.exists("test_graph_dict.pdf") + os.remove("test_graph_dict.pdf") + + # Check copy + with pytest.raises(NotImplementedError): + n3.copy() + with pytest.raises(NotImplementedError): + n3.deepcopy() + + # Check bad init + with pytest.raises(TypeError): + NodeDict({"bad": 1}) + + +def test_node_dict_manipulation(): + + params = {"p1": Param("p1", 1), "p2": Param("p2", 2)} + modules = {"m1": Module("m1"), "m2": Module("m2"), "m3": Module("m3")} + nd = NodeDict(params) + + # Set item + p3 = Param("p3", 3) + nd["p3"] = p3 + assert len(nd) == 3 + assert nd["p3"] is p3 + assert nd.p3 is p3 + + # Update + nd.update(modules) + assert len(nd) == 6 + assert nd["m1"] is modules["m1"] + assert nd.m1 is modules["m1"] + + # Pop + popped = nd.pop("m3") + assert popped is modules["m3"] + assert len(nd) == 5 + assert "m3" not in nd + + # Del item + del nd["m2"] + assert len(nd) == 4 + assert "m2" not in nd + + # Popitem + key, _ = nd.popitem() + assert key not in nd + + # Clear + nd.clear() + assert len(nd) == 0 + + # Setdefault + nd2 = NodeDict({"p1": Param("p1")}) + p_new = Param("p_new") + nd2.setdefault("new_key", p_new) + assert nd2["new_key"] is p_new + assert len(nd2) == 2 + # setdefault should not overwrite existing + existing = nd2["p1"] + nd2.setdefault("p1", Param("p1_other")) + assert nd2["p1"] is existing + with pytest.raises(TypeError): + nd2.setdefault("bad_key", "not a node") + + # Check to static/dynamic + nd3 = NodeDict({"p1": Param("p1", 1), "p2": Param("p2", 2)}) + nd3.to_dynamic() + nd3.to_static() + + # dynamic property + assert nd3.static + assert not nd3.dynamic + nd3.to_dynamic() + assert nd3.dynamic + assert not nd3.static + + # Update with kwargs + nd4 = NodeDict({"p1": Param("p1")}) + m_kw = Module("mkw") + nd4.update(mkw=m_kw) + assert len(nd4) == 2 + assert nd4["mkw"] is m_kw + assert nd4.mkw is m_kw + + # mul raises NotImplementedError + with pytest.raises(NotImplementedError): + nd3 * 2 + + +def test_node_dict_param_values(): + nd = NodeDict({"p1": Param("p1"), "p2": Param("p2"), "p3": Param("p3")}) + + nd.set_values([1, 2, 3]) + + assert nd["p1"].value.item() == 1.0 + assert nd["p2"].value.item() == 2.0 + assert nd["p3"].value.item() == 3.0