Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion docs/source/notebooks/BeginnersGuide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,42 @@
"As you can see, a `pointer` parameter is represented in the graph as a shaded arrow. It will now return the same value as the `x0` parameter in `secondsim`."
]
},
{
"cell_type": "markdown",
"id": "link_unlink_md",
"metadata": {},
"source": [
"### Linking and unlinking params\n",
"\n",
"Pointer parameters can be linked to and unlinked from other nodes. ",
"Use `link(node)` to connect a child node, and `unlink(node)` (or a key string) to disconnect a specific child. ",
"Calling `unlink()` with no arguments removes **all** children at once, acting as a convenient clear."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "link_unlink_code",
"metadata": {},
"outputs": [],
"source": [
"time_param = ck.Param(\"mytime\") # a standalone param to link against\n",
"shared_x0 = ck.Param(\"shared_x0\", shape=(2,))\n",
"\n",
"# Link a child node using a key or by passing the node directly\n",
"shared_x0.link(\"mytime\", time_param)\n",
"print(\"Children after link:\", list(shared_x0.children))\n",
"\n",
"# Unlink a specific child by key or node reference\n",
"shared_x0.unlink(\"mytime\")\n",
"print(\"Children after unlink(key):\", list(shared_x0.children))\n",
"\n",
"# Re-link and then clear all children at once\n",
"shared_x0.link(\"mytime\", time_param)\n",
"shared_x0.unlink() # removes all children\n",
"print(\"Children after unlink():\", list(shared_x0.children))"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -568,4 +604,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ keywords = [
"pytorch"
]
classifiers=[
"Development Status :: 1 - Planning",
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
Expand Down
24 changes: 16 additions & 8 deletions src/caskade/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,31 +272,39 @@ def _unlink(self, key: str):
del self.children[key]
self.update_graph()

def unlink(self, key: Union[str, "Node", list, tuple]):
"""
Unlink one or more child nodes from this node.
def unlink(self, key: Union[str, "Node", list, tuple, None] = None):
"""Unlink one or more ``Node`` objects from this ``Node``.

Parameters
----------
key : str, Node, list, or tuple
Identifier of the child(ren) to remove. May be a link key
string, the child ``Node`` object itself, or a list/tuple of
keys or nodes to unlink in bulk.

key: (str, Node, list, tuple, or None, optional)
The key, ``Node`` object, or collection of keys/nodes to unlink.
If a string, the child with that key is unlinked. If a ``Node``
object, the matching child is located and unlinked. If a list or
tuple, each element is unlinked in turn. If ``None`` (the
default), all children are unlinked.

Raises
------
GraphError
If the graph is currently active.
"""
if key is None:
self.unlink(list(self.children))
return
if isinstance(key, Node):
for node in self.children:
if self.children[node] is key:
key = node
break
else:
raise KeyError(f"Node {key.name} not found in parent {self.name}")
elif isinstance(key, (tuple, list)):
for k in key:
self.unlink(k)
return
if key not in self.children:
raise KeyError(f"Child key '{key}' not found in parent {self.name}")
self.__delattr__(key)

def topological_ordering(self) -> tuple["Node"]:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,18 @@ def test_linking(node_graph):
assert e in a.topological_ordering()
with pytest.raises(AttributeError):
a.e
with pytest.raises(KeyError):
a.unlink(e)
with pytest.raises(KeyError):
a.unlink("e")
a.unlink((b, c))

# Check unlink with no arguments clears all children
a.link(e)
assert len(a.children) > 0
a.unlink()
assert len(a.children) == 0


def test_graphviz(node_graph):
a, *_ = node_graph
Expand Down
Loading