From 8efb11b65f3cd7900f38a6a8d688e0d441ef269f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Mar 2026 19:54:02 +0000 Subject: [PATCH 01/11] Initial plan From bee4b5e59fc601a097b775c5b64f12c50b1c607f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 3 Mar 2026 20:08:40 +0000 Subject: [PATCH 02/11] feat: add NodeDict collection class for dictionary of nodes Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> --- .gitignore | 4 +- src/caskade/__init__.py | 3 +- src/caskade/collection.py | 65 ++++++++++++++++++++ src/caskade/module.py | 5 +- tests/test_collection.py | 121 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 195 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index a56f3b3..ec2786b 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,6 @@ cython_debug/ # HDF5 files **.h5 -**settings.json \ No newline at end of file +**settings.json +test_graph +test_graph_dict diff --git a/src/caskade/__init__.py b/src/caskade/__init__.py index 1021295..5f65efc 100644 --- a/src/caskade/__init__.py +++ b/src/caskade/__init__.py @@ -6,7 +6,7 @@ from .decorators import forward, active_cache from .module import Module from .param import Param -from .collection import NodeCollection, NodeList, NodeTuple +from .collection import NodeCollection, NodeList, NodeTuple, NodeDict from .tests import test from .errors import ( CaskadeException, @@ -39,6 +39,7 @@ "NodeCollection", "NodeList", "NodeTuple", + "NodeDict", "ActiveContext", "ValidContext", "OverrideParam", diff --git a/src/caskade/collection.py b/src/caskade/collection.py index 5e4eab3..4e6b52c 100644 --- a/src/caskade/collection.py +++ b/src/caskade/collection.py @@ -187,3 +187,68 @@ def __iadd__(self, other): def __imul__(self, other): raise NotImplementedError + + +class NodeDict(NodeCollection, dict): + + def __init__(self, mapping=None, name=None): + if mapping is None: + mapping = {} + dict.__init__(self, mapping) + Node.__init__(self, name=name) + self.node_type = "ndict" + self._link_nodes() + + @property + def graphviz_style(self): + return {"style": "solid", "color": "black", "shape": "component"} + + @property + def dynamic(self): + return any(node.dynamic for node in dict.values(self)) + + def _unlink_nodes(self): + for node in dict.values(self): + self.unlink(node) + + def _link_nodes(self): + for key, node in dict.items(self): + if not isinstance(node, Node): + raise TypeError(f"NodeDict values must be Node objects, not {type(node)}") + self.link(key, node) + + def __getitem__(self, key): + return dict.__getitem__(self, key) + + def __setitem__(self, key, node): + self._unlink_nodes() + dict.__setitem__(self, key, node) + self._link_nodes() + + def __delitem__(self, key): + self._unlink_nodes() + dict.__delitem__(self, key) + self._link_nodes() + + def update(self, mapping=None, **kwargs): + self._unlink_nodes() + if mapping is not None: + dict.update(self, mapping) + if kwargs: + dict.update(self, kwargs) + self._link_nodes() + + def pop(self, key, *args): + self._unlink_nodes() + node = dict.pop(self, key, *args) + self._link_nodes() + return node + + def clear(self): + self._unlink_nodes() + dict.clear(self) + + def setdefault(self, key, default=None): + if key not in self: + self[key] = default + return self[key] diff --git a/src/caskade/module.py b/src/caskade/module.py index 774faa8..8e47384 100644 --- a/src/caskade/module.py +++ b/src/caskade/module.py @@ -3,7 +3,7 @@ from .backend import ArrayLike from .base import Node from .param import Param -from .collection import NodeTuple, NodeList +from .collection import NodeTuple, NodeList, NodeDict from .mixins import GetSetValues from .errors import ActiveStateError, FillParamsError @@ -236,6 +236,9 @@ def __setattr__(self, key: str, value: Any): ): if len(value) > 0 and all(isinstance(v, Node) for v in value): value = NodeTuple(value, name=key) + elif isinstance(value, dict) and not isinstance(value, NodeDict): + if len(value) > 0 and all(isinstance(v, Node) for v in value.values()): + value = NodeDict(value, name=key) except AttributeError: pass super().__setattr__(key, value) diff --git a/tests/test_collection.py b/tests/test_collection.py index c02e259..ad2a103 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -4,6 +4,7 @@ from caskade import ( NodeList, NodeTuple, + NodeDict, Param, Module, backend, @@ -172,15 +173,19 @@ def test_collection_in_module(): l1 = [Param("ptest1"), Param("ptest2"), Module("mtest1"), Module("mtest2")] t1 = (Param("ptest3"), Param("ptest4"), Module("mtest3"), Module("mtest4")) + d1 = {"ptest5": Param("ptest5"), "ptest6": Param("ptest6"), "mtest5": Module("mtest5")} m1 = Module("test") m1.l = l1 m1.t = t1 + m1.d = d1 assert m1["l"][2] == l1[2] assert m1["t"][2] == t1[2] assert m1.l[3] == l1[3] assert m1.t[3] == t1[3] + assert m1["d"]["ptest5"] == d1["ptest5"] + assert m1.d["mtest5"] == d1["mtest5"] @pytest.mark.parametrize("node_type", [NodeTuple, NodeList]) @@ -277,3 +282,119 @@ def test_valid_tuple(node_tuple, params_type, group): for i in range(len(node_tuple.dynamic_param_groups)): assert backend.module.allclose(init_params[i], round_trip_params[i]) assert backend.module.allclose(init_params[i], final_params[i]) + + +def test_node_dict_creation(): + + # Minimal creation + n1 = NodeDict() + assert n1.name.startswith("NodeDict") + assert len(n1) == 0 + + # Creation with dict of param nodes + params = {"p1": Param("p1"), "p2": Param("p2")} + n2 = NodeDict(params) + assert len(n2) == 2 + assert n2["p1"] is params["p1"] + assert n2.p1 is params["p1"] + assert n2["p2"] is params["p2"] + + # Creation with dict of module nodes + modules = {"m1": Module("m1"), "m2": Module("m2"), "m3": Module("m3")} + n3 = NodeDict(modules) + assert len(n3) == 3 + assert n3["m1"] is modules["m1"] + assert n3.m1 is modules["m1"] + assert n3["m2"] is modules["m2"] + + # Check repr + assert isinstance(repr(n3), str) + assert "[3]" in repr(n3) + + # Check to static/dynamic + n2.to_dynamic(False) + assert len(n2.static_params) == 0 + n2.to_static(False) + assert len(n2.static_params) == 2 + assert len(n2.pointer_params) == 0 + + # Graphviz + graph = n3.graphviz(saveto="test_graph_dict.pdf") + assert graph is not None, "should return a graphviz object" + assert os.path.exists("test_graph_dict.pdf") + os.remove("test_graph_dict.pdf") + + # Check copy + with pytest.raises(NotImplementedError): + n3.copy() + with pytest.raises(NotImplementedError): + n3.deepcopy() + + # Check bad init + with pytest.raises(TypeError): + NodeDict({"bad": 1}) + + +def test_node_dict_manipulation(): + + params = {"p1": Param("p1", 1), "p2": Param("p2", 2)} + modules = {"m1": Module("m1"), "m2": Module("m2"), "m3": Module("m3")} + nd = NodeDict(params) + + # Set item + p3 = Param("p3", 3) + nd["p3"] = p3 + assert len(nd) == 3 + assert nd["p3"] is p3 + assert nd.p3 is p3 + + # Update + nd.update(modules) + assert len(nd) == 6 + assert nd["m1"] is modules["m1"] + assert nd.m1 is modules["m1"] + + # Pop + popped = nd.pop("m3") + assert popped is modules["m3"] + assert len(nd) == 5 + assert "m3" not in nd + + # Del item + del nd["m2"] + assert len(nd) == 4 + assert "m2" not in nd + + # Clear + nd.clear() + assert len(nd) == 0 + + # Setdefault + nd2 = NodeDict({"p1": Param("p1")}) + p_new = Param("p_new") + nd2.setdefault("new_key", p_new) + assert nd2["new_key"] is p_new + assert len(nd2) == 2 + # setdefault should not overwrite existing + existing = nd2["p1"] + nd2.setdefault("p1", Param("p1_other")) + assert nd2["p1"] is existing + + # Check to static/dynamic + nd3 = NodeDict({"p1": Param("p1", 1), "p2": Param("p2", 2)}) + nd3.to_dynamic() + nd3.to_static() + + # mul raises NotImplementedError + with pytest.raises(NotImplementedError): + nd3 * 2 + + +def test_node_dict_param_values(): + nd = NodeDict({"p1": Param("p1"), "p2": Param("p2"), "p3": Param("p3")}) + + nd.set_values([1, 2, 3]) + + assert nd["p1"].value.item() == 1.0 + assert nd["p2"].value.item() == 2.0 + assert nd["p3"].value.item() == 3.0 From 9b605d4cd265f07cd52ab3afe93e9d0912ceefdb Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 4 Mar 2026 02:14:32 +0000 Subject: [PATCH 03/11] test: add dynamic property and update-with-kwargs tests for NodeDict Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> --- tests/test_collection.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_collection.py b/tests/test_collection.py index ad2a103..8f2ce69 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -385,6 +385,21 @@ def test_node_dict_manipulation(): nd3.to_dynamic() nd3.to_static() + # dynamic property + assert nd3.static + assert not nd3.dynamic + nd3.to_dynamic() + assert nd3.dynamic + assert not nd3.static + + # Update with kwargs + nd4 = NodeDict({"p1": Param("p1")}) + m_kw = Module("mkw") + nd4.update(mkw=m_kw) + assert len(nd4) == 2 + assert nd4["mkw"] is m_kw + assert nd4.mkw is m_kw + # mul raises NotImplementedError with pytest.raises(NotImplementedError): nd3 * 2 From b6f81d9ffa263132ab07d26f40a8d38101ab0a56 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 4 Mar 2026 03:57:44 +0000 Subject: [PATCH 04/11] docs: add Node Collections section to AdvancedGuide tutorial Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> --- docs/source/notebooks/AdvancedGuide.ipynb | 124 +++++++++++++++++++++- 1 file changed, 123 insertions(+), 1 deletion(-) diff --git a/docs/source/notebooks/AdvancedGuide.ipynb b/docs/source/notebooks/AdvancedGuide.ipynb index bd46ab3..b7530ac 100644 --- a/docs/source/notebooks/AdvancedGuide.ipynb +++ b/docs/source/notebooks/AdvancedGuide.ipynb @@ -702,6 +702,128 @@ "Note that the groups will automatically order themselves by sorting the group integers, so you can easily pick where each group goes in the params. You can even place groups ahead of group 0 by using negative integers." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Node Collections: `NodeList`, `NodeTuple`, and `NodeDict`\n", + "\n", + "Sometimes you want to group a subset of nodes together without wrapping them inside a new `Module`. `caskade` provides three lightweight collection types for this purpose:\n", + "\n", + "- **`NodeList`** \u2013 a mutable, ordered list of nodes (supports `append`, `insert`, `pop`, etc.)\n", + "- **`NodeTuple`** \u2013 an immutable, ordered tuple of nodes\n", + "- **`NodeDict`** \u2013 a mutable mapping of `str \u2192 node`, with attribute-style access\n", + "\n", + "All three behave like their built-in Python counterparts while also being `Node` objects themselves, so they participate fully in the caskade graph.\n", + "\n", + "### Auto-conversion inside a `Module`\n", + "\n", + "When you assign a plain Python `list`, `tuple`, or `dict` of nodes as an attribute of a `Module`, `caskade` automatically converts it to the appropriate collection type::\n", + "\n", + " self.my_params = [p1, p2, p3] # becomes NodeList\n", + " self.my_params = (p1, p2, p3) # becomes NodeTuple\n", + " self.my_params = {'a': p1, 'b': p2} # becomes NodeDict\n", + "\n", + "This means you never have to construct `NodeList` / `NodeTuple` / `NodeDict` explicitly inside a `Module` \u2014 just assign the plain collection and caskade handles the rest.\n", + "\n", + "### Use cases\n", + "\n", + "- **Gibbs-style sampling** \u2013 keep one `NodeList` of the parameters you currently want to sample (dynamic) and another for the parameters you want to hold fixed (static). Calling `.to_dynamic()` / `.to_static()` on the collection flips the whole subset at once.\n", + "- **Selective saving / updating** \u2013 build a `NodeDict` of the parameters you care about and call `get_values()` / `set_values()` on just that subset, without touching the rest of the graph.\n", + "- **Quick inspection** \u2013 group modules or params of interest into a collection and iterate over them for printing, plotting, or diagnostics without restructuring the model.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Collections can be constructed directly\n", + "G = Gaussian('G', x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)\n", + "\n", + "# NodeList: mutable, ordered\n", + "position_params = ck.NodeList([G.x0, G.y0])\n", + "print('NodeList:', position_params)\n", + "\n", + "# NodeTuple: immutable, ordered\n", + "shape_params = ck.NodeTuple((G.q, G.phi, G.sigma))\n", + "print('NodeTuple:', shape_params)\n", + "\n", + "# NodeDict: mutable, keyed\n", + "named_params = ck.NodeDict({'x0': G.x0, 'y0': G.y0, 'sigma': G.sigma})\n", + "print('NodeDict:', named_params)\n", + "print('Attribute access:', named_params.x0) # same as named_params['x0']\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Auto-conversion when assigned to a Module attribute\n", + "class GaussianWithCollections(ck.Module):\n", + " def __init__(self, name, submod):\n", + " super().__init__()\n", + " self.g = submod\n", + " # plain list/dict \u2013 caskade converts them automatically\n", + " self.position = [submod.x0, submod.y0] # -> NodeList\n", + " self.named = {'q': submod.q, 'phi': submod.phi} # -> NodeDict\n", + "\n", + " @ck.forward\n", + " def __call__(self, x, y):\n", + " return self.g(x, y)\n", + "\n", + "G2 = Gaussian('G2', x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)\n", + "gc = GaussianWithCollections('gc', G2)\n", + "print(type(gc.position)) # NodeList\n", + "print(type(gc.named)) # NodeDict\n", + "display(gc.graphviz())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# get_values and set_values work on collections the same way as on Modules\n", + "G3 = Gaussian('G3', x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)\n", + "G3.to_dynamic()\n", + "\n", + "subset = ck.NodeDict({'x0': G3.x0, 'y0': G3.y0})\n", + "print('Values from subset:', subset.get_values())\n", + "\n", + "# Update only the position params; everything else stays unchanged\n", + "subset.set_values([10.0, -3.0])\n", + "print('After set_values \u2013 x0:', G3.x0.value, ' y0:', G3.y0.value)\n", + "print('sigma unchanged:', G3.sigma.value)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Gibbs-style sampling: flip subsets between dynamic and static\n", + "G4 = Gaussian('G4', x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)\n", + "G4.to_dynamic()\n", + "\n", + "group_A = ck.NodeList([G4.x0, G4.y0]) # position\n", + "group_B = ck.NodeList([G4.q, G4.phi, G4.sigma]) # shape\n", + "\n", + "# Sample group A while holding group B fixed\n", + "group_B.to_static()\n", + "print('Dynamic params (group A active):', [p.name for p in G4.dynamic_params])\n", + "\n", + "# Now swap: fix group A, sample group B\n", + "group_A.to_static()\n", + "group_B.to_dynamic()\n", + "print('Dynamic params (group B active):', [p.name for p in G4.dynamic_params])\n" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -929,4 +1051,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file From 64254a428182557955c3ca722965bac2fd416a2d Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 3 Mar 2026 23:33:14 -0500 Subject: [PATCH 05/11] clean up node collections in advanced notebook --- docs/source/notebooks/AdvancedGuide.ipynb | 90 +++++++---------------- 1 file changed, 25 insertions(+), 65 deletions(-) diff --git a/docs/source/notebooks/AdvancedGuide.ipynb b/docs/source/notebooks/AdvancedGuide.ipynb index b7530ac..ae016e6 100644 --- a/docs/source/notebooks/AdvancedGuide.ipynb +++ b/docs/source/notebooks/AdvancedGuide.ipynb @@ -710,27 +710,27 @@ "\n", "Sometimes you want to group a subset of nodes together without wrapping them inside a new `Module`. `caskade` provides three lightweight collection types for this purpose:\n", "\n", - "- **`NodeList`** \u2013 a mutable, ordered list of nodes (supports `append`, `insert`, `pop`, etc.)\n", - "- **`NodeTuple`** \u2013 an immutable, ordered tuple of nodes\n", - "- **`NodeDict`** \u2013 a mutable mapping of `str \u2192 node`, with attribute-style access\n", + "- **`NodeList`** – a mutable, ordered list of nodes (supports `append`, `insert`, `pop`, etc.)\n", + "- **`NodeTuple`** – an immutable, ordered tuple of nodes\n", + "- **`NodeDict`** – a mutable mapping of `str → node`, with attribute-style access\n", "\n", "All three behave like their built-in Python counterparts while also being `Node` objects themselves, so they participate fully in the caskade graph.\n", "\n", - "### Auto-conversion inside a `Module`\n", - "\n", "When you assign a plain Python `list`, `tuple`, or `dict` of nodes as an attribute of a `Module`, `caskade` automatically converts it to the appropriate collection type::\n", "\n", " self.my_params = [p1, p2, p3] # becomes NodeList\n", " self.my_params = (p1, p2, p3) # becomes NodeTuple\n", " self.my_params = {'a': p1, 'b': p2} # becomes NodeDict\n", "\n", - "This means you never have to construct `NodeList` / `NodeTuple` / `NodeDict` explicitly inside a `Module` \u2014 just assign the plain collection and caskade handles the rest.\n", + "This means you never have to construct `NodeList` / `NodeTuple` / `NodeDict` explicitly inside a `Module` — just assign the plain collection and caskade handles the rest.\n", "\n", "### Use cases\n", "\n", - "- **Gibbs-style sampling** \u2013 keep one `NodeList` of the parameters you currently want to sample (dynamic) and another for the parameters you want to hold fixed (static). Calling `.to_dynamic()` / `.to_static()` on the collection flips the whole subset at once.\n", - "- **Selective saving / updating** \u2013 build a `NodeDict` of the parameters you care about and call `get_values()` / `set_values()` on just that subset, without touching the rest of the graph.\n", - "- **Quick inspection** \u2013 group modules or params of interest into a collection and iterate over them for printing, plotting, or diagnostics without restructuring the model.\n" + "- **Gibbs-style sampling** – keep one `NodeList` of the parameters you currently want to sample (dynamic) and another for the parameters you want to hold fixed (static). Calling `.to_dynamic()` / `.to_static()` on the collection flips the whole subset at once.\n", + "- **Selective saving / updating** – build a `NodeDict` of the parameters you care about and call `get_values()` / `set_values()` on just that subset, without touching the rest of the graph.\n", + "- **Quick inspection** – group modules or params of interest into a collection and iterate over them for printing, plotting, or diagnostics without restructuring the model.\n", + "\n", + "Note that `set_values` and `get_values` also work on Node collections." ] }, { @@ -740,20 +740,20 @@ "outputs": [], "source": [ "# Collections can be constructed directly\n", - "G = Gaussian('G', x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)\n", + "G = Gaussian(\"G\", x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)\n", "\n", "# NodeList: mutable, ordered\n", "position_params = ck.NodeList([G.x0, G.y0])\n", - "print('NodeList:', position_params)\n", + "print(\"NodeList:\", position_params)\n", "\n", "# NodeTuple: immutable, ordered\n", "shape_params = ck.NodeTuple((G.q, G.phi, G.sigma))\n", - "print('NodeTuple:', shape_params)\n", + "print(\"NodeTuple:\", shape_params)\n", "\n", "# NodeDict: mutable, keyed\n", - "named_params = ck.NodeDict({'x0': G.x0, 'y0': G.y0, 'sigma': G.sigma})\n", - "print('NodeDict:', named_params)\n", - "print('Attribute access:', named_params.x0) # same as named_params['x0']\n" + "named_params = ck.NodeDict({\"x0\": G.x0, \"y0\": G.y0, \"sigma\": G.sigma})\n", + "print(\"NodeDict:\", named_params)\n", + "print(\"Attribute access:\", named_params[\"x0\"]) # same as named_params.x0" ] }, { @@ -767,61 +767,21 @@ " def __init__(self, name, submod):\n", " super().__init__()\n", " self.g = submod\n", - " # plain list/dict \u2013 caskade converts them automatically\n", - " self.position = [submod.x0, submod.y0] # -> NodeList\n", - " self.named = {'q': submod.q, 'phi': submod.phi} # -> NodeDict\n", + " # plain list/dict – caskade converts them automatically\n", + " self.position = [submod.x0, submod.y0] # -> NodeList\n", + " self.named = {\"q\": submod.q, \"phi\": submod.phi} # -> NodeDict\n", "\n", " @ck.forward\n", " def __call__(self, x, y):\n", " return self.g(x, y)\n", "\n", - "G2 = Gaussian('G2', x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)\n", - "gc = GaussianWithCollections('gc', G2)\n", - "print(type(gc.position)) # NodeList\n", - "print(type(gc.named)) # NodeDict\n", - "display(gc.graphviz())\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# get_values and set_values work on collections the same way as on Modules\n", - "G3 = Gaussian('G3', x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)\n", - "G3.to_dynamic()\n", - "\n", - "subset = ck.NodeDict({'x0': G3.x0, 'y0': G3.y0})\n", - "print('Values from subset:', subset.get_values())\n", "\n", - "# Update only the position params; everything else stays unchanged\n", - "subset.set_values([10.0, -3.0])\n", - "print('After set_values \u2013 x0:', G3.x0.value, ' y0:', G3.y0.value)\n", - "print('sigma unchanged:', G3.sigma.value)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Gibbs-style sampling: flip subsets between dynamic and static\n", - "G4 = Gaussian('G4', x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)\n", - "G4.to_dynamic()\n", - "\n", - "group_A = ck.NodeList([G4.x0, G4.y0]) # position\n", - "group_B = ck.NodeList([G4.q, G4.phi, G4.sigma]) # shape\n", - "\n", - "# Sample group A while holding group B fixed\n", - "group_B.to_static()\n", - "print('Dynamic params (group A active):', [p.name for p in G4.dynamic_params])\n", - "\n", - "# Now swap: fix group A, sample group B\n", - "group_A.to_static()\n", - "group_B.to_dynamic()\n", - "print('Dynamic params (group B active):', [p.name for p in G4.dynamic_params])\n" + "G2 = Gaussian(\"G2\", x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)\n", + "gc = GaussianWithCollections(\"gc\", G2)\n", + "print(type(gc.position)) # NodeList\n", + "print(type(gc.named)) # NodeDict\n", + "gc.position.to_dynamic() # can call collection methods\n", + "display(gc.graphviz())" ] }, { @@ -1051,4 +1011,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +} From fe75ede7561d2d0f02d927cbc72882eddba5534a Mon Sep 17 00:00:00 2001 From: "Connor Stone, PhD" Date: Wed, 4 Mar 2026 16:25:41 -0500 Subject: [PATCH 06/11] make setdefault behaviour safer to improper inputs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/caskade/collection.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/caskade/collection.py b/src/caskade/collection.py index f2a62c2..10323af 100644 --- a/src/caskade/collection.py +++ b/src/caskade/collection.py @@ -333,6 +333,15 @@ def clear(self): dict.clear(self) def setdefault(self, key, default=None): - if key not in self: - self[key] = default - return self[key] + # Preserve dict.setdefault API shape but enforce NodeDict invariants + if key in self: + return self[key] + if default is None: + raise TypeError( + "NodeDict.setdefault() requires a default Node when key is absent; " + "None is not a valid NodeDict value" + ) + if not isinstance(default, Node): + raise TypeError(f"NodeDict values must be Node objects, not {type(default)}") + self[key] = default + return default From f5b70cb8bc4529c19bd346c267bed22ca6bad656 Mon Sep 17 00:00:00 2001 From: "Connor Stone, PhD" Date: Wed, 4 Mar 2026 16:26:55 -0500 Subject: [PATCH 07/11] improve code rendering in tutorial notebook Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/source/notebooks/AdvancedGuide.ipynb | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/source/notebooks/AdvancedGuide.ipynb b/docs/source/notebooks/AdvancedGuide.ipynb index ae016e6..e99c35c 100644 --- a/docs/source/notebooks/AdvancedGuide.ipynb +++ b/docs/source/notebooks/AdvancedGuide.ipynb @@ -716,11 +716,13 @@ "\n", "All three behave like their built-in Python counterparts while also being `Node` objects themselves, so they participate fully in the caskade graph.\n", "\n", - "When you assign a plain Python `list`, `tuple`, or `dict` of nodes as an attribute of a `Module`, `caskade` automatically converts it to the appropriate collection type::\n", + "When you assign a plain Python `list`, `tuple`, or `dict` of nodes as an attribute of a `Module`, `caskade` automatically converts it to the appropriate collection type:\n", "\n", - " self.my_params = [p1, p2, p3] # becomes NodeList\n", - " self.my_params = (p1, p2, p3) # becomes NodeTuple\n", - " self.my_params = {'a': p1, 'b': p2} # becomes NodeDict\n", + "```python\n", + "self.my_params = [p1, p2, p3] # becomes NodeList\n", + "self.my_params = (p1, p2, p3) # becomes NodeTuple\n", + "self.my_params = {'a': p1, 'b': p2} # becomes NodeDict\n", + "```\n", "\n", "This means you never have to construct `NodeList` / `NodeTuple` / `NodeDict` explicitly inside a `Module` — just assign the plain collection and caskade handles the rest.\n", "\n", From 387df50281af3dbb83d97c3ebed84cbf3590a4ce Mon Sep 17 00:00:00 2001 From: "Connor Stone, PhD" Date: Wed, 4 Mar 2026 16:27:45 -0500 Subject: [PATCH 08/11] include name passed to super in example Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/source/notebooks/AdvancedGuide.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/notebooks/AdvancedGuide.ipynb b/docs/source/notebooks/AdvancedGuide.ipynb index e99c35c..9e4fe73 100644 --- a/docs/source/notebooks/AdvancedGuide.ipynb +++ b/docs/source/notebooks/AdvancedGuide.ipynb @@ -767,7 +767,7 @@ "# Auto-conversion when assigned to a Module attribute\n", "class GaussianWithCollections(ck.Module):\n", " def __init__(self, name, submod):\n", - " super().__init__()\n", + " super().__init__(name)\n", " self.g = submod\n", " # plain list/dict – caskade converts them automatically\n", " self.position = [submod.x0, submod.y0] # -> NodeList\n", From a87467e1b2916be2fe37cf7e5bea313aea269feb Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 4 Mar 2026 22:05:54 +0000 Subject: [PATCH 09/11] fix: wrap all NodeList/NodeDict mutation methods in try/finally to ensure re-linking Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> --- src/caskade/collection.py | 84 +++++++++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 29 deletions(-) diff --git a/src/caskade/collection.py b/src/caskade/collection.py index 10323af..6d3057f 100644 --- a/src/caskade/collection.py +++ b/src/caskade/collection.py @@ -208,39 +208,51 @@ def _link_nodes(self): def append(self, node): """Append a node to the list and update graph links.""" self._unlink_nodes() - super().append(node) - self._link_nodes() + try: + super().append(node) + finally: + self._link_nodes() def insert(self, index, node): """Insert a node at the given index and update graph links.""" self._unlink_nodes() - super().insert(index, node) - self._link_nodes() + try: + super().insert(index, node) + finally: + self._link_nodes() def extend(self, iterable): """Extend the list with nodes from an iterable and update graph links.""" self._unlink_nodes() - super().extend(iterable) - self._link_nodes() + try: + super().extend(iterable) + finally: + self._link_nodes() def clear(self): """Remove all nodes from the list and update graph links.""" self._unlink_nodes() - super().clear() - self._link_nodes() + try: + super().clear() + finally: + self._link_nodes() def pop(self, index=-1): """Remove and return a node at the given index, updating graph links.""" self._unlink_nodes() - node = super().pop(index) - self._link_nodes() + try: + node = super().pop(index) + finally: + self._link_nodes() return node def remove(self, value): """Remove the first occurrence of a node and update graph links.""" self._unlink_nodes() - super().remove(value) - self._link_nodes() + try: + super().remove(value) + finally: + self._link_nodes() def __getitem__(self, key): if isinstance(key, str): @@ -251,13 +263,17 @@ def __getitem__(self, key): def __setitem__(self, key, value): self._unlink_nodes() - super().__setitem__(key, value) - self._link_nodes() + try: + super().__setitem__(key, value) + finally: + self._link_nodes() def __delitem__(self, key): self._unlink_nodes() - super().__delitem__(key) - self._link_nodes() + try: + super().__delitem__(key) + finally: + self._link_nodes() def __add__(self, other): res = super().__add__(other) @@ -265,8 +281,10 @@ def __add__(self, other): def __iadd__(self, other): self._unlink_nodes() - ret = super().__iadd__(other) - self._link_nodes() + try: + ret = super().__iadd__(other) + finally: + self._link_nodes() return ret def __imul__(self, other): @@ -306,26 +324,34 @@ def __getitem__(self, key): def __setitem__(self, key, node): self._unlink_nodes() - dict.__setitem__(self, key, node) - self._link_nodes() + try: + dict.__setitem__(self, key, node) + finally: + self._link_nodes() def __delitem__(self, key): self._unlink_nodes() - dict.__delitem__(self, key) - self._link_nodes() + try: + dict.__delitem__(self, key) + finally: + self._link_nodes() def update(self, mapping=None, **kwargs): self._unlink_nodes() - if mapping is not None: - dict.update(self, mapping) - if kwargs: - dict.update(self, kwargs) - self._link_nodes() + try: + if mapping is not None: + dict.update(self, mapping) + if kwargs: + dict.update(self, kwargs) + finally: + self._link_nodes() def pop(self, key, *args): self._unlink_nodes() - node = dict.pop(self, key, *args) - self._link_nodes() + try: + node = dict.pop(self, key, *args) + finally: + self._link_nodes() return node def clear(self): From 28eac31dbc0991fa117aac52d65500e59f00040d Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 5 Mar 2026 09:13:07 -0500 Subject: [PATCH 10/11] add docstrings, handle setdefault error cases --- src/caskade/collection.py | 33 +++++++++++++++++++++++++++------ tests/test_collection.py | 2 ++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/caskade/collection.py b/src/caskade/collection.py index 6d3057f..c8bdddb 100644 --- a/src/caskade/collection.py +++ b/src/caskade/collection.py @@ -292,6 +292,19 @@ def __imul__(self, other): class NodeDict(NodeCollection, dict): + """Mutable, keyed collection of nodes. + + Behaves like a standard ``dict`` but also participates in the caskade + node graph. All elements must be ``Node`` instances. Graph links are + automatically updated whenever the dict is modified. + + Parameters + ---------- + mapping : mapping of str to Node, optional + Nodes to include in the dict. Defaults to an empty dict. + name : str, optional + Human-readable name for this collection of nodes. + """ def __init__(self, mapping=None, name=None): if mapping is None: @@ -337,6 +350,7 @@ def __delitem__(self, key): self._link_nodes() def update(self, mapping=None, **kwargs): + """Update the dict with another mapping (i.e. dict) and update graph links.""" self._unlink_nodes() try: if mapping is not None: @@ -347,6 +361,7 @@ def update(self, mapping=None, **kwargs): self._link_nodes() def pop(self, key, *args): + """Remove and return a node from the dict and update graph links.""" self._unlink_nodes() try: node = dict.pop(self, key, *args) @@ -354,19 +369,25 @@ def pop(self, key, *args): self._link_nodes() return node + def popitem(self): + """Remove and return an arbitrary (key, node) pair from the dict (the last one inserted) and update graph links.""" + self._unlink_nodes() + try: + key, node = dict.popitem(self) + finally: + self._link_nodes() + return key, node + def clear(self): + """Remove all nodes from the dict and update graph links.""" self._unlink_nodes() dict.clear(self) - def setdefault(self, key, default=None): + def setdefault(self, key, default): + """If key is in the dictionary, return its value. If not, insert key with a value of default and return default. Update graph links.""" # Preserve dict.setdefault API shape but enforce NodeDict invariants if key in self: return self[key] - if default is None: - raise TypeError( - "NodeDict.setdefault() requires a default Node when key is absent; " - "None is not a valid NodeDict value" - ) if not isinstance(default, Node): raise TypeError(f"NodeDict values must be Node objects, not {type(default)}") self[key] = default diff --git a/tests/test_collection.py b/tests/test_collection.py index 8f2ce69..a737e5b 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -379,6 +379,8 @@ def test_node_dict_manipulation(): existing = nd2["p1"] nd2.setdefault("p1", Param("p1_other")) assert nd2["p1"] is existing + with pytest.raises(TypeError): + nd2.setdefault("bad_key", "not a node") # Check to static/dynamic nd3 = NodeDict({"p1": Param("p1", 1), "p2": Param("p2", 2)}) From 6ee3af7d92f7c70ac897e568fab3a195514ca0b3 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 5 Mar 2026 09:28:21 -0500 Subject: [PATCH 11/11] add popitem test --- tests/test_collection.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_collection.py b/tests/test_collection.py index a737e5b..f41a2b8 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -365,6 +365,10 @@ def test_node_dict_manipulation(): assert len(nd) == 4 assert "m2" not in nd + # Popitem + key, _ = nd.popitem() + assert key not in nd + # Clear nd.clear() assert len(nd) == 0