diff --git a/docs/source/notebooks/BeginnersGuide.ipynb b/docs/source/notebooks/BeginnersGuide.ipynb index bf39e9c..e1e05a4 100644 --- a/docs/source/notebooks/BeginnersGuide.ipynb +++ b/docs/source/notebooks/BeginnersGuide.ipynb @@ -312,6 +312,42 @@ "As you can see, a `pointer` parameter is represented in the graph as a shaded arrow. It will now return the same value as the `x0` parameter in `secondsim`." ] }, + { + "cell_type": "markdown", + "id": "link_unlink_md", + "metadata": {}, + "source": [ + "### Linking and unlinking params\n", + "\n", + "Pointer parameters can be linked to and unlinked from other nodes. ", + "Use `link(node)` to connect a child node, and `unlink(node)` (or a key string) to disconnect a specific child. ", + "Calling `unlink()` with no arguments removes **all** children at once, acting as a convenient clear." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "link_unlink_code", + "metadata": {}, + "outputs": [], + "source": [ + "time_param = ck.Param(\"mytime\") # a standalone param to link against\n", + "shared_x0 = ck.Param(\"shared_x0\", shape=(2,))\n", + "\n", + "# Link a child node using a key or by passing the node directly\n", + "shared_x0.link(\"mytime\", time_param)\n", + "print(\"Children after link:\", list(shared_x0.children))\n", + "\n", + "# Unlink a specific child by key or node reference\n", + "shared_x0.unlink(\"mytime\")\n", + "print(\"Children after unlink(key):\", list(shared_x0.children))\n", + "\n", + "# Re-link and then clear all children at once\n", + "shared_x0.link(\"mytime\", time_param)\n", + "shared_x0.unlink() # removes all children\n", + "print(\"Children after unlink():\", list(shared_x0.children))" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -568,4 +604,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 4ac2230..0f62a30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ keywords = [ "pytorch" ] classifiers=[ - "Development Status :: 1 - Planning", + "Development Status :: 5 - Production/Stable", "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", diff --git a/src/caskade/base.py b/src/caskade/base.py index c46550f..7a80a95 100644 --- a/src/caskade/base.py +++ b/src/caskade/base.py @@ -272,31 +272,39 @@ def _unlink(self, key: str): del self.children[key] self.update_graph() - def unlink(self, key: Union[str, "Node", list, tuple]): - """ - Unlink one or more child nodes from this node. + def unlink(self, key: Union[str, "Node", list, tuple, None] = None): + """Unlink one or more ``Node`` objects from this ``Node``. Parameters ---------- - key : str, Node, list, or tuple - Identifier of the child(ren) to remove. May be a link key - string, the child ``Node`` object itself, or a list/tuple of - keys or nodes to unlink in bulk. - + key: (str, Node, list, tuple, or None, optional) + The key, ``Node`` object, or collection of keys/nodes to unlink. + If a string, the child with that key is unlinked. If a ``Node`` + object, the matching child is located and unlinked. If a list or + tuple, each element is unlinked in turn. If ``None`` (the + default), all children are unlinked. + Raises ------ GraphError If the graph is currently active. """ + if key is None: + self.unlink(list(self.children)) + return if isinstance(key, Node): for node in self.children: if self.children[node] is key: key = node break + else: + raise KeyError(f"Node {key.name} not found in parent {self.name}") elif isinstance(key, (tuple, list)): for k in key: self.unlink(k) return + if key not in self.children: + raise KeyError(f"Child key '{key}' not found in parent {self.name}") self.__delattr__(key) def topological_ordering(self) -> tuple["Node"]: diff --git a/tests/test_base.py b/tests/test_base.py index e566889..0929710 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -85,8 +85,18 @@ def test_linking(node_graph): assert e in a.topological_ordering() with pytest.raises(AttributeError): a.e + with pytest.raises(KeyError): + a.unlink(e) + with pytest.raises(KeyError): + a.unlink("e") a.unlink((b, c)) + # Check unlink with no arguments clears all children + a.link(e) + assert len(a.children) > 0 + a.unlink() + assert len(a.children) == 0 + def test_graphviz(node_graph): a, *_ = node_graph