From 66727661191f733ce4fbac304201dda50b16f8be Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Fri, 13 Mar 2026 13:27:21 +0100 Subject: [PATCH 1/2] Allow creating graphs with parameterized Node and Edge instances --- docs/how-to/define-nodes-edges.md | 113 +++- docs/how-to/react-to-events.md | 56 +- examples/node_edge_instances.py | 138 +++++ src/panel_reactflow/__init__.py | 4 + src/panel_reactflow/base.py | 821 ++++++++++++++++++++++++++---- tests/test_api.py | 313 +++++++++++- tests/test_core.py | 41 +- 7 files changed, 1385 insertions(+), 101 deletions(-) create mode 100644 examples/node_edge_instances.py diff --git a/docs/how-to/define-nodes-edges.md b/docs/how-to/define-nodes-edges.md index a062a56..6ab4882 100644 --- a/docs/how-to/define-nodes-edges.md +++ b/docs/how-to/define-nodes-edges.md @@ -2,9 +2,9 @@ Every graph in Panel-ReactFlow is built from two lists: **nodes** and **edges**. Nodes represent entities on the canvas; edges represent -connections between them. Both are plain Python dictionaries, so you can -construct them from any data source — a database, a config file, or user -input at runtime. +connections between them. Nodes can be plain dictionaries, `NodeSpec` +objects, or `Node` instances, so you can choose between lightweight payloads +and object-oriented node classes. This guide covers how to create nodes and edges, use the helper dataclasses, and update data after the graph is live. @@ -104,6 +104,40 @@ nodes = [ --- +## Define nodes as classes + +Use `Node` when you want per-node Python state, event hooks, and optional +custom view/editor methods. + +```python +import panel as pn +from panel_reactflow import Node, ReactFlow + + +class JobNode(Node): + def __init__(self, **params): + super().__init__(type="job", data={"status": "idle"}, **params) + + def __panel__(self): + return pn.pane.Markdown(f"**{self.label}**: {self.data.get('status')}") + + def on_move(self, payload, flow): + print(f"{self.id} moved to {payload['position']}") + + +nodes = [ + JobNode(id="j1", label="Fetch", position={"x": 0, "y": 0}), + JobNode(id="j2", label="Process", position={"x": 260, "y": 60}), +] + +flow = ReactFlow(nodes=nodes) +``` + +`Node` instances stay as Python objects in `flow.nodes`; they are serialized +to dicts only when syncing to the frontend. + +--- + ## Define edges Edges link two nodes by their `id`. Use the top-level `label` for the @@ -128,6 +162,79 @@ edges = [ --- +## Define edges as classes + +Use `Edge` when you want object-oriented edge state and edge-specific hooks or +editor logic. + +```python +from panel_reactflow import Edge, ReactFlow + + +class FlowEdge(Edge): + def __init__(self, **params): + super().__init__(type="flow", data={"weight": 1.0}, **params) + + def on_data_change(self, payload, flow): + print(f"{self.id} updated:", payload["patch"]) + + +flow = ReactFlow( + nodes=[ + {"id": "n1", "position": {"x": 0, "y": 0}, "data": {}}, + {"id": "n2", "position": {"x": 260, "y": 60}, "data": {}}, + ], + edges=[FlowEdge(id="e1", source="n1", target="n2")], +) +``` + +`Edge` instances stay as Python objects in `flow.edges`; they are serialized +to dicts only when syncing to the frontend. + +--- + +## Data <-> parameter sync on `Node` and `Edge` + +For class-based nodes/edges, Panel-ReactFlow supports two-way synchronization +between `data` and declared parameters. + +### Which parameters are included? + +Only subclass parameters with **explicit non-negative precedence** +(`precedence >= 0`) are treated as data fields. + +```python +import param +from panel_reactflow import Node + + +class TaskNode(Node): + status = param.Selector(default="idle", objects=["idle", "running", "done"], precedence=0) + retries = param.Integer(default=0, precedence=0) + _internal_state = param.String(default="x", precedence=-1) +``` + +In this example: + +- `status` and `retries` are included in `data` +- `_internal_state` is not included + +### Sync behavior + +- **Parameter -> data**: updating `node.status` or `edge.weight` triggers an + automatic data patch to the graph and frontend. +- **Data -> parameter**: incoming graph patches/sync updates write values back + onto matching parameters. +- **Schema generation**: if no explicit type schema is provided, these + included parameters are used to generate a JSON schema for editors. + +### Editor implication + +If your editor widgets are bound with `from_param(...)`, you usually do not +need manual `on_patch` watchers for those data parameters. + +--- + ## Use the NodeSpec / EdgeSpec helpers If you prefer a typed API, use the dataclass helpers. They validate fields diff --git a/docs/how-to/react-to-events.md b/docs/how-to/react-to-events.md index 4d29581..f28cea6 100644 --- a/docs/how-to/react-to-events.md +++ b/docs/how-to/react-to-events.md @@ -23,10 +23,10 @@ the `ReactFlow` instance as a second argument. You can also listen for | `node_deleted` | A node is removed. | `node_id` | | `node_moved` | A node is dragged to a new position. | `node_id`, `position` | | `node_clicked` | A node is clicked (single click). | `node_id` | -| `node_data_changed` | `patch_node_data()` is called. | `node_id`, `patch` | +| `node_data_changed` | Node data is patched (via API, editor patch, or parameter-driven sync). | `node_id`, `patch` | | `edge_added` | An edge is created (UI connect or API). | `edge` | | `edge_deleted` | An edge is removed. | `edge_id` | -| `edge_data_changed` | `patch_edge_data()` is called. | `edge_id`, `patch` | +| `edge_data_changed` | Edge data is patched (via API, editor patch, or parameter-driven sync). | `edge_id`, `patch` | | `selection_changed` | The active selection changes. | `nodes`, `edges` | | `sync` | A batch sync from the frontend. | *(varies)* | @@ -57,6 +57,58 @@ pn.Column(log, flow).servable() --- +## Handle events on `Node` classes + +If you define nodes as `Node` subclasses, you can implement hooks directly on +the node instance: + +```python +from panel_reactflow import Node, ReactFlow + + +class TaskNode(Node): + def on_event(self, payload, flow): + print("any node event:", payload["type"]) + + def on_delete(self, payload, flow): + print("deleted:", self.id) + + +flow = ReactFlow(nodes=[TaskNode(id="t1", position={"x": 0, "y": 0}, data={})]) +``` + +Common hooks include `on_event` (wildcard), `on_add`, `on_move`, `on_click`, +`on_data_change`, and `on_delete`. + +When a `Node` subclass parameter with `precedence >= 0` changes, it +automatically patches node data and will trigger `on_data_change`. + +--- + +## Handle events on `Edge` classes + +`Edge` subclasses can handle edge lifecycle and patch events directly: + +```python +from panel_reactflow import Edge, ReactFlow + + +class WeightedEdge(Edge): + def on_data_change(self, payload, flow): + print("edge patch:", payload["patch"]) + + def on_delete(self, payload, flow): + print("edge deleted:", self.id) +``` + +Common edge hooks include `on_event`, `on_add`, `on_data_change`, +`on_selection_changed`, and `on_delete`. + +Likewise, changing an `Edge` subclass data parameter (`precedence >= 0`) +triggers `on_data_change` through the same data patch pipeline. + +--- + ## Listen for all events Use the wildcard `"*"` to receive every event. This is useful for diff --git a/examples/node_edge_instances.py b/examples/node_edge_instances.py new file mode 100644 index 0000000..fee3dc8 --- /dev/null +++ b/examples/node_edge_instances.py @@ -0,0 +1,138 @@ +"""Complex example using Node and Edge class instances. + +Demonstrates: +- ``Node`` / ``Edge`` subclass instances in ``ReactFlow`` +- Per-instance ``__panel__`` node views +- Per-instance custom editors via ``editor(...)`` +- Node/edge event hooks (``on_data_change``, ``on_selection_changed``) +- Programmatic updates with ``patch_node_data`` / ``patch_edge_data`` +""" + +import random + +import panel as pn +import panel_material_ui as pmui +import param + +from panel_reactflow import Edge, Node, ReactFlow + +pn.extension() + + +class PipelineNode(Node): + status = param.Selector(default="idle", objects=["idle", "running", "done", "failed"], precedence=0) + retries = param.Integer(default=0, bounds=(0, None), precedence=0) + owner = param.String(default="ops", precedence=0) + notes = param.String(default="", precedence=0) + + def __init__(self, **params): + params.setdefault("type", "pipeline") + super().__init__(**params) + self._summary = pn.pane.Markdown(margin=(0, 0, 6, 0)) + self._activity = pn.pane.Markdown("", styles={"font-size": "12px", "opacity": "0.8"}) + self.param.watch(self._refresh_view, ["status", "owner", "retries", "label"]) + self._refresh_view() + + def _refresh_view(self, *_): + self._summary.object = ( + f"**{self.label}** \n" + f"Status: `{self.status}` \n" + f"Owner: `{self.owner}` \n" + f"Retries: `{self.retries}`" + ) + + def __panel__(self): + return pn.Column(self._summary, self._activity, margin=0, sizing_mode="stretch_width") + + def editor(self, data, schema, *, id, type, on_patch): + status = pmui.Select.from_param(self.param.status, name="Status") + retries = pmui.IntInput.from_param(self.param.retries, name="Retries") + owner = pmui.TextInput.from_param(self.param.owner, name="Owner") + notes = pmui.TextAreaInput.from_param(self.param.notes, name="Notes", height=80) + return pn.Column(status, retries, owner, notes, sizing_mode="stretch_width") + + def on_data_change(self, payload, flow): + if payload.get("node_id") == self.id: + self._activity.object = f"Last patch: `{payload.get('patch', {})}`" + + def on_selection_changed(self, payload, flow): + selected = self.id in (payload.get("nodes") or []) + if selected: + self._activity.object = "Selected in canvas" + + +class WeightedEdge(Edge): + weight = param.Number(default=0.5, bounds=(0, 1), precedence=0) + channel = param.Selector(default="main", objects=["main", "backup", "shadow"], precedence=0) + enabled = param.Boolean(default=True, precedence=0) + + def __init__(self, **params): + params.setdefault("type", "weighted") + super().__init__(**params) + + def editor(self, data, schema, *, id, type, on_patch): + weight = pmui.FloatSlider.from_param(self.param.weight, name="Weight", step=0.01) + channel = pmui.Select.from_param(self.param.channel, name="Channel") + enabled = pmui.Checkbox.from_param(self.param.enabled, name="Enabled") + return pn.Column(weight, channel, enabled, sizing_mode="stretch_width") + + +nodes = [ + PipelineNode(id="extract", label="Extract", position={"x": 0, "y": 40}), + PipelineNode(id="transform", label="Transform", position={"x": 300, "y": 160}, status="running", retries=1, owner="ml", notes="Batch window"), + PipelineNode(id="load", label="Load", position={"x": 600, "y": 40}, owner="platform"), +] + +edges = [ + WeightedEdge(id="e1", source="extract", target="transform", weight=0.72), + WeightedEdge(id="e2", source="transform", target="load", weight=0.63, channel="backup"), +] + +event_log = pmui.TextAreaInput(name="Events", value="", disabled=True, height=180, sizing_mode="stretch_width") +last_event = pn.pane.Markdown("**Last event:** _none_") + +flow = ReactFlow( + nodes=nodes, + edges=edges, + editor_mode="side", + sizing_mode="stretch_both", +) + +def _log_event(payload): + event_type = payload.get("type", "unknown") + last_event.object = f"**Last event:** `{event_type}`" + snippet = str(payload) + event_log.value = f"{event_log.value}\n{event_type}: {snippet}"[-6000:] + + +flow.on("*", _log_event) + + +def _advance_nodes(_): + order = {"idle": "running", "running": "done", "done": "done", "failed": "idle"} + for node in nodes: + current = node.status + flow.patch_node_data(node.id, {"status": order.get(current, "idle")}) + + +def _randomize_weights(_): + for edge in edges: + flow.patch_edge_data(edge.id, {"weight": round(random.uniform(0.05, 0.95), 2)}) + + +advance_btn = pmui.Button(name="Advance pipeline") +advance_btn.on_click(_advance_nodes) + +weights_btn = pmui.Button(name="Randomize edge weights") +weights_btn.on_click(_randomize_weights) + +controls = pn.Row(advance_btn, weights_btn, sizing_mode="stretch_width") + +pn.Column( + pn.pane.Markdown("## Node/Edge Instance Workflow"), + controls, + last_event, + flow, + event_log, + sizing_mode="stretch_both", +).servable() diff --git a/src/panel_reactflow/__init__.py b/src/panel_reactflow/__init__.py index c3dd158..df168d1 100644 --- a/src/panel_reactflow/__init__.py +++ b/src/panel_reactflow/__init__.py @@ -2,10 +2,12 @@ from .__version import __version__ # noqa from .base import ( + Edge, EdgeSpec, EdgeType, Editor, JsonEditor, + Node, NodeSpec, NodeType, ReactFlow, @@ -14,10 +16,12 @@ ) __all__ = [ + "Edge", "EdgeSpec", "EdgeType", "Editor", "JsonEditor", + "Node", "NodeSpec", "NodeType", "ReactFlow", diff --git a/src/panel_reactflow/base.py b/src/panel_reactflow/base.py index faed470..a4ff643 100644 --- a/src/panel_reactflow/base.py +++ b/src/panel_reactflow/base.py @@ -87,6 +87,30 @@ def _param_to_jsonschema(parameterized_cls: type) -> dict[str, Any]: return {"type": "object", "properties": properties} +def _parameterized_data_param_names(parameterized_cls: type[param.Parameterized], base_cls: type[param.Parameterized]) -> list[str]: + """Return subclass-defined parameter names included in node/edge data. + + Only parameters with explicitly non-negative precedence are included. + """ + base_params = set(base_cls.param) + names: list[str] = [] + for name in parameterized_cls.param: + if name in base_params or name.startswith("_"): + continue + precedence = parameterized_cls.param[name].precedence + if precedence is not None and precedence >= 0: + names.append(name) + return names + + +def _parameterized_data_schema(parameterized_cls: type[param.Parameterized], base_cls: type[param.Parameterized]) -> dict[str, Any]: + """Build a JSON Schema for subclass-defined data parameters.""" + names = _parameterized_data_param_names(parameterized_cls, base_cls) + schema = _param_to_jsonschema(parameterized_cls) + properties = schema.get("properties", {}) + return {"type": "object", "properties": {name: properties[name] for name in names if name in properties}} + + def _pydantic_to_jsonschema(model_cls: type) -> dict[str, Any]: """Convert a Pydantic ``BaseModel`` class to a JSON Schema dict.""" return model_cls.model_json_schema() @@ -548,6 +572,101 @@ def from_dict(cls, payload: dict[str, Any]) -> "NodeSpec": return cls(**payload) +class Node(param.Parameterized): + """Base class for object-oriented nodes. + + Subclass this class when you want node instances to keep Python-side state + and react to graph events directly. Node instances can be passed anywhere + a node dict/``NodeSpec`` is accepted. + + Subclasses can customize: + + - ``__panel__`` to render node content. + - ``editor`` to provide a node-specific editor. + - ``on_event`` (wildcard) and event-specific ``on_*`` hooks. + """ + + id = param.String(default="", doc="Unique node identifier.") + position = param.Dict(default={"x": 0.0, "y": 0.0}, doc="Node position.") + type = param.String(default="panel", doc="Node type.") + label = param.String(default=None, allow_None=True, doc="Display label.") + data = param.Dict(default={}, doc="Custom node data.") + selected = param.Boolean(default=False, doc="Selection state.") + draggable = param.Boolean(default=True, doc="Whether node is draggable.") + connectable = param.Boolean(default=True, doc="Whether node is connectable.") + deletable = param.Boolean(default=True, doc="Whether node is deletable.") + style = param.Dict(default=None, allow_None=True, doc="Optional node style.") + className = param.String(default=None, allow_None=True, doc="Optional CSS class.") + + @classmethod + def _data_param_names(cls) -> list[str]: + return _parameterized_data_param_names(cls, Node) + + @classmethod + def _data_schema(cls) -> dict[str, Any]: + return _parameterized_data_schema(cls, Node) + + def to_dict(self) -> dict[str, Any]: + """Convert this node instance to a ReactFlow-compatible dictionary.""" + data = dict(self.data or {}) + for name in self._data_param_names(): + data[name] = getattr(self, name) + payload = { + "id": self.id, + "position": dict(self.position or {"x": 0.0, "y": 0.0}), + "type": self.type or "panel", + "label": self.label, + "data": data, + "selected": self.selected, + "draggable": self.draggable, + "connectable": self.connectable, + "deletable": self.deletable, + } + if self.style is not None: + payload["style"] = dict(self.style) + if self.className is not None: + payload["className"] = self.className + view = self.__panel__() + if view is not None: + payload["view"] = view + return payload + + def __panel__(self) -> Any | None: + """Optional view rendered inside the node.""" + return None + + def editor(self, data, schema, *, id, type, on_patch): + """Optional per-node editor factory. + + Return ``None`` to fall back to type/default editors. + """ + return None + + def on_event(self, payload: dict[str, Any], flow: "ReactFlow") -> None: + """Wildcard event hook for node-related events.""" + + def on_add(self, payload: dict[str, Any], flow: "ReactFlow") -> None: + """Hook called when this node is added.""" + + def on_delete(self, payload: dict[str, Any], flow: "ReactFlow") -> None: + """Hook called when this node is deleted.""" + + def on_move(self, payload: dict[str, Any], flow: "ReactFlow") -> None: + """Hook called when this node moves.""" + + def on_click(self, payload: dict[str, Any], flow: "ReactFlow") -> None: + """Hook called when this node is clicked.""" + + def on_data_change(self, payload: dict[str, Any], flow: "ReactFlow") -> None: + """Hook called when this node's data changes.""" + + def on_selection_changed(self, payload: dict[str, Any], flow: "ReactFlow") -> None: + """Hook called when this node participates in a selection update.""" + + def on_sync(self, payload: dict[str, Any], flow: "ReactFlow") -> None: + """Hook called when the graph receives a sync payload.""" + + @dataclass class EdgeSpec: """Builder for edge dictionaries with validation and type safety. @@ -705,6 +824,79 @@ def from_dict(cls, payload: dict[str, Any]) -> "EdgeSpec": return cls(**payload) +class Edge(param.Parameterized): + """Base class for object-oriented edges.""" + + id = param.String(default="", doc="Unique edge identifier.") + source = param.String(default="", doc="Source node id.") + target = param.String(default="", doc="Target node id.") + label = param.String(default=None, allow_None=True, doc="Display label.") + type = param.String(default=None, allow_None=True, doc="Edge type.") + selected = param.Boolean(default=False, doc="Selection state.") + data = param.Dict(default={}, doc="Custom edge data.") + style = param.Dict(default=None, allow_None=True, doc="Optional edge style.") + markerEnd = param.Dict(default=None, allow_None=True, doc="Optional edge end marker.") + sourceHandle = param.String(default=None, allow_None=True, doc="Optional source handle id.") + targetHandle = param.String(default=None, allow_None=True, doc="Optional target handle id.") + + @classmethod + def _data_param_names(cls) -> list[str]: + return _parameterized_data_param_names(cls, Edge) + + @classmethod + def _data_schema(cls) -> dict[str, Any]: + return _parameterized_data_schema(cls, Edge) + + def to_dict(self) -> dict[str, Any]: + """Convert this edge instance to a ReactFlow-compatible dictionary.""" + data = dict(self.data or {}) + for name in self._data_param_names(): + data[name] = getattr(self, name) + payload = { + "id": self.id, + "source": self.source, + "target": self.target, + "label": self.label, + "type": self.type, + "selected": self.selected, + "data": data, + } + if self.style is not None: + payload["style"] = dict(self.style) + if self.markerEnd is not None: + payload["markerEnd"] = dict(self.markerEnd) + if self.sourceHandle is not None: + payload["sourceHandle"] = self.sourceHandle + if self.targetHandle is not None: + payload["targetHandle"] = self.targetHandle + return payload + + def editor(self, data, schema, *, id, type, on_patch): + """Optional per-edge editor factory. + + Return ``None`` to fall back to type/default editors. + """ + return None + + def on_event(self, payload: dict[str, Any], flow: "ReactFlow") -> None: + """Wildcard event hook for edge-related events.""" + + def on_add(self, payload: dict[str, Any], flow: "ReactFlow") -> None: + """Hook called when this edge is added.""" + + def on_delete(self, payload: dict[str, Any], flow: "ReactFlow") -> None: + """Hook called when this edge is deleted.""" + + def on_data_change(self, payload: dict[str, Any], flow: "ReactFlow") -> None: + """Hook called when this edge's data changes.""" + + def on_selection_changed(self, payload: dict[str, Any], flow: "ReactFlow") -> None: + """Hook called when this edge participates in a selection update.""" + + def on_sync(self, payload: dict[str, Any], flow: "ReactFlow") -> None: + """Hook called when the graph receives a sync payload.""" + + class Editor(Viewer): """Base class for custom node and edge editors. @@ -975,14 +1167,16 @@ class ReactFlow(ReactComponent): Parameters ---------- - nodes : list of dict, default [] - List of node dictionaries defining the graph nodes. Each node should - have at minimum ``id``, ``position``, and ``type`` fields. Use - :class:`NodeSpec` for type-safe node creation. - edges : list of dict, default [] - List of edge dictionaries defining connections between nodes. Each edge - should have ``id``, ``source``, and ``target`` fields. Use - :class:`EdgeSpec` for type-safe edge creation. + nodes : list of dict or Node, default [] + List of node dictionaries or :class:`Node` instances defining the + graph nodes. Each node should have at minimum ``id``, ``position``, + and ``type`` fields. Use :class:`NodeSpec` or :class:`Node` for + structured node creation. + edges : list of dict or Edge, default [] + List of edge dictionaries or :class:`Edge` instances defining + connections between nodes. Each edge should have ``id``, ``source``, + and ``target`` fields. Use :class:`EdgeSpec` or :class:`Edge` for + structured edge creation. node_types : dict, default {} Dictionary mapping type names to :class:`NodeType` definitions or dicts. Define custom node types with schemas, ports, and validation. @@ -1134,7 +1328,9 @@ class ReactFlow(ReactComponent): See Also -------- + Node : Base class for object-oriented nodes NodeSpec : Builder for node dictionaries + Edge : Base class for object-oriented edges EdgeSpec : Builder for edge dictionaries NodeType : Define custom node types with schemas EdgeType : Define custom edge types with schemas @@ -1153,8 +1349,8 @@ class ReactFlow(ReactComponent): - Disabling ``show_minimap`` if not needed """ - nodes = param.List(default=[], doc="Canonical list of node dictionaries.") - edges = param.List(default=[], doc="Canonical list of edge dictionaries.") + nodes = param.List(default=[], doc="Canonical list of node dictionaries or Node instances.") + edges = param.List(default=[], doc="Canonical list of edge dictionaries or Edge instances.") node_types = param.Dict(default={}, doc="Node type descriptors keyed by type name.") edge_types = param.Dict(default={}, doc="Edge type descriptors keyed by type name.") @@ -1221,13 +1417,15 @@ def __init__(self, **params: Any): params["color_mode"] = "dark" if pn.config.theme == "dark" else "light" self._node_ids: list[str] = [] self._edge_ids: list[str] = [] + self._node_data_param_watchers: dict[str, tuple[Node, list[Any]]] = {} + self._edge_data_param_watchers: dict[str, tuple[Edge, list[Any]]] = {} # Normalize type specs before parent init so the frontend receives # JSON-serializable descriptors from the start. if "node_types" in params: params["node_types"] = _coerce_spec_map(params["node_types"]) if "edge_types" in params: params["edge_types"] = _coerce_spec_map(params["edge_types"], edge=True) - # Normalize nodes and edges to ensure NodeSpec/EdgeSpec are converted to dicts + # Normalize nodes and edges to ensure NodeSpec/EdgeSpec are converted. if "nodes" in params: params["nodes"] = [ReactFlow._coerce_node(node) for node in params["nodes"]] if "edges" in params: @@ -1236,6 +1434,7 @@ def __init__(self, **params: Any): self._event_handlers: dict[str, list[Callable]] = {"*": []} self.param.watch(self._normalize_nodes, ["nodes"]) self.param.watch(self._normalize_edges, ["edges"]) + self.param.watch(self._update_instance_data_param_watchers, ["nodes", "edges"]) self.param.watch(self._update_selection_from_graph, ["nodes", "edges"]) self.param.watch(self._normalize_specs, ["node_types", "edge_types"]) self.param.watch( @@ -1248,6 +1447,7 @@ def __init__(self, **params: Any): ) self._update_node_editors() self._update_edge_editors() + self._update_instance_data_param_watchers() @classmethod def _esm_path(cls, compiled: bool | Literal["compiling"] = True) -> os.PathLike | None: @@ -1273,16 +1473,28 @@ def _bundle_path(cls) -> os.PathLike | None: def _get_node_schema(self, node_type: str) -> dict[str, Any] | None: """Return the normalized JSON Schema for *node_type*, or ``None``.""" type_spec = self.node_types.get(node_type) - if type_spec is None: - return None - return type_spec.get("schema") + if type_spec is not None and type_spec.get("schema") is not None: + return type_spec.get("schema") + for node in self.nodes: + if isinstance(node, Node) and node.type == node_type: + schema = node._data_schema() + if schema.get("properties"): + return schema + break + return None def _get_edge_schema(self, edge_type: str) -> dict[str, Any] | None: """Return the normalized JSON Schema for *edge_type*, or ``None``.""" type_spec = self.edge_types.get(edge_type) - if type_spec is None: - return None - return type_spec.get("schema") + if type_spec is not None and type_spec.get("schema") is not None: + return type_spec.get("schema") + for edge in self.edges: + if isinstance(edge, Edge) and (edge.type or "") == edge_type: + schema = edge._data_schema() + if schema.get("properties"): + return schema + break + return None def _create_editor( self, @@ -1308,8 +1520,297 @@ def on_patch(patch: dict) -> None: return factory(data, schema, id=item_id, type=item_type, on_patch=on_patch) + @staticmethod + def _node_is_instance(node: Any) -> bool: + return isinstance(node, Node) + + @staticmethod + def _node_id(node: dict[str, Any] | Node) -> str | None: + return node.id if isinstance(node, Node) else node.get("id") + + @staticmethod + def _node_type(node: dict[str, Any] | Node) -> str: + if isinstance(node, Node): + return node.type or "panel" + return node.get("type", "panel") + + @staticmethod + def _node_data(node: dict[str, Any] | Node) -> dict[str, Any]: + if isinstance(node, Node): + return dict(node.data or {}) + return dict(node.get("data", {})) + + @staticmethod + def _node_view(node: dict[str, Any] | Node) -> Any | None: + if isinstance(node, Node): + return node.__panel__() + return node.get("view", None) + + @staticmethod + def _node_set_position(node: dict[str, Any] | Node, position: dict[str, Any]) -> None: + if isinstance(node, Node): + node.position = dict(position) + else: + node["position"] = position + + @staticmethod + def _node_set_selected(node: dict[str, Any] | Node, selected: bool) -> None: + if isinstance(node, Node): + node.selected = selected + else: + node["selected"] = selected + + @staticmethod + def _node_set_data(node: dict[str, Any] | Node, data: dict[str, Any]) -> None: + if isinstance(node, Node): + node.data = dict(data) + else: + node["data"] = data + + @staticmethod + def _node_payload(node: dict[str, Any] | NodeSpec | Node) -> dict[str, Any]: + if isinstance(node, Node): + return node.to_dict() + if isinstance(node, NodeSpec): + return node.to_dict() + return dict(node) + + @staticmethod + def _edge_id(edge: dict[str, Any] | Edge) -> str | None: + return edge.id if isinstance(edge, Edge) else edge.get("id") + + @staticmethod + def _edge_type(edge: dict[str, Any] | Edge) -> str: + if isinstance(edge, Edge): + return edge.type or "" + return edge.get("type", "") + + @staticmethod + def _edge_data(edge: dict[str, Any] | Edge) -> dict[str, Any]: + if isinstance(edge, Edge): + return dict(edge.data or {}) + return dict(edge.get("data", {})) + + @staticmethod + def _edge_set_selected(edge: dict[str, Any] | Edge, selected: bool) -> None: + if isinstance(edge, Edge): + edge.selected = selected + else: + edge["selected"] = selected + + @staticmethod + def _edge_set_data(edge: dict[str, Any] | Edge, data: dict[str, Any]) -> None: + if isinstance(edge, Edge): + edge.data = dict(data) + else: + edge["data"] = data + + @staticmethod + def _edge_payload(edge: dict[str, Any] | EdgeSpec | Edge) -> dict[str, Any]: + if isinstance(edge, Edge): + return edge.to_dict() + if isinstance(edge, EdgeSpec): + return edge.to_dict() + return dict(edge) + + def _get_node_instance(self, node_id: str) -> Node | None: + for node in self.nodes: + if isinstance(node, Node) and node.id == node_id: + return node + return None + + def _get_edge_instance(self, edge_id: str) -> Edge | None: + for edge in self.edges: + if isinstance(edge, Edge) and edge.id == edge_id: + return edge + return None + + @staticmethod + def _sync_node_data_params_from_data(node: Node) -> None: + for name in node._data_param_names(): + if name in (node.data or {}): + setattr(node, name, node.data[name]) + + @staticmethod + def _sync_edge_data_params_from_data(edge: Edge) -> None: + for name in edge._data_param_names(): + if name in (edge.data or {}): + setattr(edge, name, edge.data[name]) + + def _teardown_node_data_param_watcher(self, node_id: str) -> None: + record = self._node_data_param_watchers.pop(node_id, None) + if record is None: + return + node, watchers = record + for watcher in watchers: + try: + node.param.unwatch(watcher) + except Exception: + pass + + def _teardown_edge_data_param_watcher(self, edge_id: str) -> None: + record = self._edge_data_param_watchers.pop(edge_id, None) + if record is None: + return + edge, watchers = record + for watcher in watchers: + try: + edge.param.unwatch(watcher) + except Exception: + pass + + def _update_instance_data_param_watchers(self, *_: param.parameterized.Event) -> None: + self._update_node_data_param_watchers() + self._update_edge_data_param_watchers() + + def _update_node_data_param_watchers(self) -> None: + current_nodes = {node.id: node for node in self.nodes if isinstance(node, Node) and node.id} + for node_id, (watched_node, _) in list(self._node_data_param_watchers.items()): + current = current_nodes.get(node_id) + if current is None or current is not watched_node: + self._teardown_node_data_param_watcher(node_id) + for node_id, node in current_nodes.items(): + if node_id in self._node_data_param_watchers: + continue + watchers = [] + for name in node._data_param_names(): + watchers.append( + node.param.watch( + lambda event, _id=node_id, _name=name, _node=node: self._on_node_data_param_change(_id, _name, _node, event), + name, + ) + ) + self._node_data_param_watchers[node_id] = (node, watchers) + + def _update_edge_data_param_watchers(self) -> None: + current_edges = {edge.id: edge for edge in self.edges if isinstance(edge, Edge) and edge.id} + for edge_id, (watched_edge, _) in list(self._edge_data_param_watchers.items()): + current = current_edges.get(edge_id) + if current is None or current is not watched_edge: + self._teardown_edge_data_param_watcher(edge_id) + for edge_id, edge in current_edges.items(): + if edge_id in self._edge_data_param_watchers: + continue + watchers = [] + for name in edge._data_param_names(): + watchers.append( + edge.param.watch( + lambda event, _id=edge_id, _name=name, _edge=edge: self._on_edge_data_param_change(_id, _name, _edge, event), + name, + ) + ) + self._edge_data_param_watchers[edge_id] = (edge, watchers) + + def _on_node_data_param_change( + self, + node_id: str, + param_name: str, + node: Node, + event: param.parameterized.Event, + ) -> None: + if self._get_node_instance(node_id) is not node: + return + if (node.data or {}).get(param_name) == event.new: + return + self.patch_node_data(node_id, {param_name: event.new}) + + def _on_edge_data_param_change( + self, + edge_id: str, + param_name: str, + edge: Edge, + event: param.parameterized.Event, + ) -> None: + if self._get_edge_instance(edge_id) is not edge: + return + if (edge.data or {}).get(param_name) == event.new: + return + self.patch_edge_data(edge_id, {param_name: event.new}) + + @staticmethod + def _invoke_node_callback(callback: Callable, payload: dict[str, Any], flow: "ReactFlow") -> None: + if len(inspect.signature(callback).parameters) == 2: + cb = partial(callback, payload, flow) + else: + cb = partial(callback, payload) + pn.state.execute(cb) + + def _invoke_node_hook(self, node: Node, hook_name: str, payload: dict[str, Any]) -> None: + hook = getattr(node, hook_name, None) + if callable(hook): + self._invoke_node_callback(hook, payload, self) + + def _invoke_edge_hook(self, edge: Edge, hook_name: str, payload: dict[str, Any]) -> None: + hook = getattr(edge, hook_name, None) + if callable(hook): + self._invoke_node_callback(hook, payload, self) + + def _dispatch_node_hooks(self, event_type: str, payload: dict[str, Any]) -> None: + node_ids: list[str] = [] + if event_type == "node_added": + node_payload = payload.get("node", {}) + if isinstance(node_payload, dict) and node_payload.get("id"): + node_ids = [node_payload["id"]] + elif event_type in ("node_moved", "node_clicked", "node_data_changed"): + node_id = payload.get("node_id") + if node_id: + node_ids = [node_id] + elif event_type == "selection_changed": + node_ids = list(payload.get("nodes") or []) + elif event_type == "sync": + node_ids = [node.id for node in self.nodes if isinstance(node, Node)] + + hook_map = { + "node_added": "on_add", + "node_moved": "on_move", + "node_clicked": "on_click", + "node_data_changed": "on_data_change", + "selection_changed": "on_selection_changed", + "sync": "on_sync", + } + hook_name = hook_map.get(event_type) + for node_id in node_ids: + node = self._get_node_instance(node_id) + if node is None: + continue + if hook_name is not None: + self._invoke_node_hook(node, hook_name, payload) + self._invoke_node_hook(node, "on_event", payload) + + def _dispatch_edge_hooks(self, event_type: str, payload: dict[str, Any]) -> None: + edge_ids: list[str] = [] + if event_type == "edge_added": + edge_payload = payload.get("edge", {}) + if isinstance(edge_payload, dict) and edge_payload.get("id"): + edge_ids = [edge_payload["id"]] + elif event_type in ("edge_deleted", "edge_data_changed"): + edge_id = payload.get("edge_id") + if edge_id: + edge_ids = [edge_id] + elif event_type == "selection_changed": + edge_ids = list(payload.get("edges") or []) + elif event_type == "sync": + edge_ids = [edge.id for edge in self.edges if isinstance(edge, Edge)] + + hook_map = { + "edge_added": "on_add", + "edge_deleted": "on_delete", + "edge_data_changed": "on_data_change", + "selection_changed": "on_selection_changed", + "sync": "on_sync", + } + hook_name = hook_map.get(event_type) + for edge_id in edge_ids: + edge = self._get_edge_instance(edge_id) + if edge is None: + continue + if hook_name is not None: + self._invoke_edge_hook(edge, hook_name, payload) + self._invoke_edge_hook(edge, "on_event", payload) + def _update_node_editors(self, *events: tuple[param.parameterized.Event]) -> None: - node_ids = [node["id"] for node in self.nodes] + node_ids = [self._node_id(node) for node in self.nodes] + node_ids = [node_id for node_id in node_ids if node_id is not None] config_changed = any(event.name in ("editor_mode", "node_editors", "default_node_editor") for event in events) if node_ids == self._node_ids and not config_changed: return @@ -1317,28 +1818,38 @@ def _update_node_editors(self, *events: tuple[param.parameterized.Event]) -> Non editors = {} for node in self.nodes: - node_id = node.get("id") + node_id = self._node_id(node) + if node_id is None: + continue if node_id in self._node_editors and not config_changed: editors[node_id] = self._node_editors[node_id] continue - node_type = node.get("type", "panel") - editor_factory = self.node_editors.get(node_type) or self.default_node_editor or SchemaEditor + node_type = self._node_type(node) + editor_factory = None + if isinstance(node, Node) and type(node).editor is not Node.editor: + editor_factory = node.editor + if editor_factory is None: + editor_factory = self.node_editors.get(node_type) or self.default_node_editor or SchemaEditor schema = self._get_node_schema(node_type) - data = node.get("data", {}) - editor = self._create_editor( - editor_factory, - node_id, - data, - schema, - node_type, - patch_fn=self.patch_node_data, - ) + data = self._node_data(node) + if callable(editor_factory): + editor = self._create_editor( + editor_factory, + node_id, + data, + schema, + node_type, + patch_fn=self.patch_node_data, + ) + else: + editor = editor_factory editors[node_id] = editor self._node_editors = editors self.param.trigger("_node_editor_views") def _update_edge_editors(self, *events: tuple[param.parameterized.Event]) -> None: - edge_ids = [edge["id"] for edge in self.edges] + edge_ids = [self._edge_id(edge) for edge in self.edges] + edge_ids = [edge_id for edge_id in edge_ids if edge_id is not None] config_changed = any(event.name in ("edge_editors", "default_edge_editor") for event in events) if edge_ids == self._edge_ids and not config_changed: return @@ -1346,22 +1857,31 @@ def _update_edge_editors(self, *events: tuple[param.parameterized.Event]) -> Non editors = {} for edge in self.edges: - edge_id = edge.get("id") + edge_id = self._edge_id(edge) + if edge_id is None: + continue if edge_id in self._edge_editors and not config_changed: editors[edge_id] = self._edge_editors[edge_id] continue - edge_type = edge.get("type", "") - editor_factory = self.edge_editors.get(edge_type) or self.default_edge_editor or SchemaEditor + edge_type = self._edge_type(edge) + editor_factory = None + if isinstance(edge, Edge) and type(edge).editor is not Edge.editor: + editor_factory = edge.editor + if editor_factory is None: + editor_factory = self.edge_editors.get(edge_type) or self.default_edge_editor or SchemaEditor schema = self._get_edge_schema(edge_type) if edge_type else None - data = edge.get("data", {}) - editor = self._create_editor( - editor_factory, - edge_id, - data, - schema, - edge_type, - patch_fn=self.patch_edge_data, - ) + data = self._edge_data(edge) + if callable(editor_factory): + editor = self._create_editor( + editor_factory, + edge_id, + data, + schema, + edge_type, + patch_fn=self.patch_edge_data, + ) + else: + editor = editor_factory editors[edge_id] = editor self._edge_editors = editors self.param.trigger("_edge_editor_views") @@ -1379,11 +1899,11 @@ def _get_children(self, data_model, doc, root, parent, comm) -> tuple[dict[str, views = [] node_editors = [] for node in self.nodes: - view = node.get("view", None) + view = self._node_view(node) if view is not None: views.append(self._resolve_editor_view(view)) - node_editors.append(self._resolve_editor_view(self._node_editors.get(node.get("id")))) - edge_editors = [self._resolve_editor_view(self._edge_editors.get(edge.get("id"))) for edge in self.edges] + node_editors.append(self._resolve_editor_view(self._node_editors.get(self._node_id(node)))) + edge_editors = [self._resolve_editor_view(self._edge_editors.get(self._edge_id(edge))) for edge in self.edges] children: dict[str, list[UIElement] | UIElement | None] = {} old_models: list[UIElement] = [] @@ -1419,7 +1939,7 @@ def _process_param_change(self, params): nodes = [] view_idx = 0 for node in params["nodes"]: - node = dict(node) + node = self._node_payload(node) view = node.pop("view", None) data = dict(node.get("data", {})) if view is not None: @@ -1428,6 +1948,8 @@ def _process_param_change(self, params): node["data"] = data nodes.append(node) params["nodes"] = nodes + if "edges" in params: + params["edges"] = [self._edge_payload(edge) for edge in params["edges"]] # node_types / edge_types are now JSON-serializable descriptors # and intentionally synced to the frontend. # Pop Python-only editor registries and internal state. @@ -1441,7 +1963,7 @@ def _process_param_change(self, params): params.pop("_edge_editors", None) return params - def add_node(self, node: dict[str, Any] | NodeSpec) -> None: + def add_node(self, node: dict[str, Any] | NodeSpec | Node) -> None: """Add a node to the graph. Adds a new node to the graph with optional validation. If a ``view`` @@ -1450,8 +1972,8 @@ def add_node(self, node: dict[str, Any] | NodeSpec) -> None: Parameters ---------- - node : dict or NodeSpec - Node dictionary or :class:`NodeSpec` instance to add. The only + node : dict or NodeSpec or Node + Node dictionary, :class:`NodeSpec`, or :class:`Node` instance to add. The only required field is ``id``. Other fields have defaults: - ``id``: Unique node identifier (required) @@ -1511,15 +2033,29 @@ def add_node(self, node: dict[str, Any] | NodeSpec) -> None: remove_node : Remove a node from the graph NodeSpec : Helper for constructing node dictionaries """ - payload = self._coerce_node(node) + raw_node = self._coerce_node(node) + payload = self._node_payload(raw_node) payload.setdefault("type", "panel") payload.setdefault("data", {}) payload.setdefault("position", {"x": 0.0, "y": 0.0}) + payload.setdefault("selected", False) + payload.setdefault("draggable", True) + payload.setdefault("connectable", True) + payload.setdefault("deletable", True) + if isinstance(raw_node, Node): + raw_node.type = payload["type"] + raw_node.data = dict(payload["data"]) + raw_node.position = dict(payload["position"]) + raw_node.selected = payload["selected"] + raw_node.draggable = payload["draggable"] + raw_node.connectable = payload["connectable"] + raw_node.deletable = payload["deletable"] + self._sync_node_data_params_from_data(raw_node) self._validate_graph_payload(payload, kind="node") if self.validate_on_add: schema = self._get_node_schema(payload.get("type", "panel")) _validate_data(payload.get("data", {}), schema) - self.nodes = self.nodes + [payload] + self.nodes = self.nodes + [raw_node if isinstance(raw_node, Node) else payload] self._emit("node_added", {"type": "node_added", "node": payload}) def _handle_msg(self, msg: dict[str, Any]) -> None: @@ -1531,9 +2067,55 @@ def _handle_msg(self, msg: dict[str, Any]) -> None: nodes = msg.get("nodes") edges = msg.get("edges") if nodes is not None: - self.nodes = nodes + current_instances = {node.id: node for node in self.nodes if isinstance(node, Node)} + synced_nodes: list[dict[str, Any] | Node] = [] + for payload in nodes: + node_id = payload.get("id") + if node_id in current_instances: + node = current_instances[node_id] + node.position = dict(payload.get("position", node.position)) + node.type = payload.get("type", node.type) + node.label = payload.get("label", node.label) + node.data = dict(payload.get("data", node.data)) + self._sync_node_data_params_from_data(node) + node.selected = payload.get("selected", node.selected) + node.draggable = payload.get("draggable", node.draggable) + node.connectable = payload.get("connectable", node.connectable) + node.deletable = payload.get("deletable", node.deletable) + if "style" in payload: + node.style = payload.get("style") + if "className" in payload: + node.className = payload.get("className") + synced_nodes.append(node) + else: + synced_nodes.append(payload) + self.nodes = synced_nodes if edges is not None: - self.edges = edges + current_instances = {edge.id: edge for edge in self.edges if isinstance(edge, Edge)} + synced_edges: list[dict[str, Any] | Edge] = [] + for payload in edges: + edge_id = payload.get("id") + if edge_id in current_instances: + edge = current_instances[edge_id] + edge.source = payload.get("source", edge.source) + edge.target = payload.get("target", edge.target) + edge.label = payload.get("label", edge.label) + edge.type = payload.get("type", edge.type) + edge.selected = payload.get("selected", edge.selected) + edge.data = dict(payload.get("data", edge.data)) + self._sync_edge_data_params_from_data(edge) + if "style" in payload: + edge.style = payload.get("style") + if "markerEnd" in payload: + edge.markerEnd = payload.get("markerEnd") + if "sourceHandle" in payload: + edge.sourceHandle = payload.get("sourceHandle") + if "targetHandle" in payload: + edge.targetHandle = payload.get("targetHandle") + synced_edges.append(edge) + else: + synced_edges.append(payload) + self.edges = synced_edges self._emit("sync", msg) case "node_moved": node_id = msg.get("node_id") @@ -1541,16 +2123,16 @@ def _handle_msg(self, msg: dict[str, Any]) -> None: if node_id is None or position is None: return for node in self.nodes: - if node.get("id") == node_id: - node["position"] = position + if self._node_id(node) == node_id: + self._node_set_position(node, position) self._emit("node_moved", msg) case "selection_changed": node_ids = msg.get("nodes") or [] edge_ids = msg.get("edges") or [] for node in self.nodes: - node["selected"] = node.get("id") in node_ids + self._node_set_selected(node, self._node_id(node) in node_ids) for edge in self.edges: - edge["selected"] = edge.get("id") in edge_ids + self._edge_set_selected(edge, self._edge_id(edge) in edge_ids) self.selection = {"nodes": list(node_ids), "edges": list(edge_ids)} self._emit("selection_changed", msg) case "edge_added": @@ -1609,7 +2191,8 @@ def remove_node(self, node_id: str) -> None: add_node : Add a node to the graph remove_edge : Remove an edge from the graph """ - nodes = [node for node in self.nodes if node.get("id") != node_id] + removed_node = next((node for node in self.nodes if self._node_id(node) == node_id), None) + nodes = [node for node in self.nodes if self._node_id(node) != node_id] removed_edges = [edge for edge in self.edges if edge.get("source") == node_id or edge.get("target") == node_id] self.nodes = nodes if removed_edges: @@ -1623,8 +2206,16 @@ def remove_node(self, node_id: str) -> None: "deleted_edges": [edge.get("id") for edge in removed_edges], }, ) + if isinstance(removed_node, Node): + payload = { + "type": "node_deleted", + "node_id": node_id, + "deleted_edges": [edge.get("id") for edge in removed_edges], + } + self._invoke_node_hook(removed_node, "on_delete", payload) + self._invoke_node_hook(removed_node, "on_event", payload) - def add_edge(self, edge: dict[str, Any] | EdgeSpec) -> None: + def add_edge(self, edge: dict[str, Any] | EdgeSpec | Edge) -> None: """Add an edge to the graph. Adds a new edge connecting two nodes with optional validation. If no @@ -1633,8 +2224,8 @@ def add_edge(self, edge: dict[str, Any] | EdgeSpec) -> None: Parameters ---------- - edge : dict or EdgeSpec - Edge dictionary or :class:`EdgeSpec` instance to add. Required + edge : dict or EdgeSpec or Edge + Edge dictionary, :class:`EdgeSpec`, or :class:`Edge` instance to add. Required fields are ``source`` and ``target``. Other fields have defaults: - ``source``: ID of the source node (required) @@ -1694,16 +2285,21 @@ def add_edge(self, edge: dict[str, Any] | EdgeSpec) -> None: remove_edge : Remove an edge from the graph EdgeSpec : Helper for constructing edge dictionaries """ - payload = self._coerce_edge(edge) + raw_edge = self._coerce_edge(edge) + payload = self._edge_payload(raw_edge) payload.setdefault("data", {}) if not payload.get("id"): payload["id"] = self._generate_edge_id(payload["source"], payload["target"]) + if isinstance(raw_edge, Edge): + raw_edge.id = payload["id"] + raw_edge.data = dict(payload["data"]) + self._sync_edge_data_params_from_data(raw_edge) self._validate_graph_payload(payload, kind="edge") if self.validate_on_add: edge_type = payload.get("type") schema = self._get_edge_schema(edge_type) if edge_type else None _validate_data(payload.get("data", {}), schema) - self.edges = self.edges + [payload] + self.edges = self.edges + [raw_edge if isinstance(raw_edge, Edge) else payload] self._emit("edge_added", {"type": "edge_added", "edge": payload}) def remove_edge(self, edge_id: str) -> None: @@ -1727,10 +2323,15 @@ def remove_edge(self, edge_id: str) -> None: add_edge : Add an edge to the graph remove_node : Remove a node from the graph """ - removed = [edge for edge in self.edges if edge.get("id") == edge_id] - self.edges = [edge for edge in self.edges if edge.get("id") != edge_id] + removed_edge = next((edge for edge in self.edges if self._edge_id(edge) == edge_id), None) + removed = [edge for edge in self.edges if self._edge_id(edge) == edge_id] + self.edges = [edge for edge in self.edges if self._edge_id(edge) != edge_id] if removed: self._emit("edge_deleted", {"type": "edge_deleted", "edge_id": edge_id}) + if isinstance(removed_edge, Edge): + payload = {"type": "edge_deleted", "edge_id": edge_id} + self._invoke_edge_hook(removed_edge, "on_delete", payload) + self._invoke_edge_hook(removed_edge, "on_event", payload) def patch_node_data(self, node_id: str, patch: dict[str, Any]) -> None: """Update specific properties in a node's data dictionary. @@ -1786,13 +2387,15 @@ def patch_node_data(self, node_id: str, patch: dict[str, Any]) -> None: add_node : Add a new node to the graph """ for node in self.nodes: - if node.get("id") == node_id: - data = dict(node.get("data", {})) + if self._node_id(node) == node_id: + data = self._node_data(node) data.update(patch) if self.validate_on_patch: - schema = self._get_node_schema(node.get("type", "panel")) + schema = self._get_node_schema(self._node_type(node)) _validate_data(data, schema) - node["data"] = data + self._node_set_data(node, data) + if isinstance(node, Node): + self._sync_node_data_params_from_data(node) break self._send_msg({"type": "patch_node_data", "node_id": node_id, "patch": patch}) self._emit("node_data_changed", {"type": "node_data_changed", "node_id": node_id, "patch": patch}) @@ -1845,14 +2448,16 @@ def patch_edge_data(self, edge_id: str, patch: dict[str, Any]) -> None: add_edge : Add a new edge to the graph """ for edge in self.edges: - if edge.get("id") == edge_id: - data = dict(edge.get("data", {})) + if self._edge_id(edge) == edge_id: + data = self._edge_data(edge) data.update(patch) if self.validate_on_patch: - edge_type = edge.get("type") + edge_type = self._edge_type(edge) schema = self._get_edge_schema(edge_type) if edge_type else None _validate_data(data, schema) - edge["data"] = data + self._edge_set_data(edge, data) + if isinstance(edge, Edge): + self._sync_edge_data_params_from_data(edge) break self._send_msg({"type": "patch_edge_data", "edge_id": edge_id, "patch": patch}) self._emit("edge_data_changed", {"type": "edge_data_changed", "edge_id": edge_id, "patch": patch}) @@ -1944,25 +2549,27 @@ def to_networkx(self, *, multigraph: bool = False): graph = nx.MultiDiGraph() if multigraph else nx.DiGraph() for node in self.nodes: - data = dict(node.get("data", {})) - data.update({"position": node.get("position"), "type": node.get("type")}) - if node.get("label") is not None: - data["label"] = node.get("label") - graph.add_node(node["id"], **data) + payload = self._node_payload(node) + data = dict(payload.get("data", {})) + data.update({"position": payload.get("position"), "type": payload.get("type")}) + if payload.get("label") is not None: + data["label"] = payload.get("label") + graph.add_node(payload["id"], **data) for edge in self.edges: - data = dict(edge.get("data", {})) - if edge.get("label") is not None: - data["label"] = edge["label"] - if edge.get("type") is not None: - data["type"] = edge["type"] - if edge.get("sourceHandle") is not None: - data["sourceHandle"] = edge["sourceHandle"] - if edge.get("targetHandle") is not None: - data["targetHandle"] = edge["targetHandle"] + payload = self._edge_payload(edge) + data = dict(payload.get("data", {})) + if payload.get("label") is not None: + data["label"] = payload["label"] + if payload.get("type") is not None: + data["type"] = payload["type"] + if payload.get("sourceHandle") is not None: + data["sourceHandle"] = payload["sourceHandle"] + if payload.get("targetHandle") is not None: + data["targetHandle"] = payload["targetHandle"] if multigraph: - graph.add_edge(edge["source"], edge["target"], key=edge.get("id"), **data) + graph.add_edge(payload["source"], payload["target"], key=payload.get("id"), **data) else: - graph.add_edge(edge["source"], edge["target"], **data) + graph.add_edge(payload["source"], payload["target"], **data) return graph @classmethod @@ -2215,11 +2822,29 @@ def _emit(self, event_type: str, payload: dict[str, Any]) -> None: else: cb = partial(callback, payload) pn.state.execute(cb) + self._dispatch_node_hooks(event_type, payload) + self._dispatch_edge_hooks(event_type, payload) def _update_selection_from_graph(self, *_: param.parameterized.Event) -> None: + selected_node_ids: list[str] = [] + for node in self.nodes: + node_id = self._node_id(node) + if node_id is None: + continue + is_selected = node.selected if isinstance(node, Node) else node.get("selected") + if is_selected: + selected_node_ids.append(node_id) + selected_edge_ids: list[str] = [] + for edge in self.edges: + edge_id = self._edge_id(edge) + if edge_id is None: + continue + is_selected = edge.selected if isinstance(edge, Edge) else edge.get("selected") + if is_selected: + selected_edge_ids.append(edge_id) selection = { - "nodes": [node["id"] for node in self.nodes if node.get("selected")], - "edges": [edge["id"] for edge in self.edges if edge.get("selected")], + "nodes": selected_node_ids, + "edges": selected_edge_ids, } if selection != self.selection: self.selection = selection @@ -2254,11 +2879,19 @@ def _generate_edge_id(source: str, target: str) -> str: return f"{existing}-{uuid4().hex[:8]}" @staticmethod - def _coerce_node(node: dict[str, Any] | NodeSpec) -> dict[str, Any]: + def _coerce_node(node: dict[str, Any] | NodeSpec | Node): + if isinstance(node, Node): + return node + if isinstance(node, NodeSpec): + return node.to_dict() return node.to_dict() if hasattr(node, "to_dict") else node @staticmethod - def _coerce_edge(edge: dict[str, Any] | EdgeSpec) -> dict[str, Any]: + def _coerce_edge(edge: dict[str, Any] | EdgeSpec | Edge): + if isinstance(edge, Edge): + return edge + if isinstance(edge, EdgeSpec): + return edge.to_dict() return edge.to_dict() if hasattr(edge, "to_dict") else edge def _validate_graph_payload(self, payload: dict[str, Any], *, kind: str) -> None: diff --git a/tests/test_api.py b/tests/test_api.py index c38fb42..4bc8d9e 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -8,7 +8,7 @@ import param import pytest -from panel_reactflow import EdgeSpec, EdgeType, NodeSpec, NodeType, ReactFlow, SchemaSource +from panel_reactflow import Edge, EdgeSpec, EdgeType, Node, NodeSpec, NodeType, ReactFlow, SchemaSource from panel_reactflow.base import ( Editor, JsonEditor, @@ -74,6 +74,138 @@ def test_reactflow_add_node_with_nodespec_view() -> None: assert flow.nodes[0]["view"] is view +class _CountingNode(Node): + def __init__(self, **params): + super().__init__(**params) + self.events = [] + + def on_event(self, payload, flow): + self.events.append(("event", payload["type"])) + + def on_move(self, payload, flow): + self.events.append(("move", payload["position"])) + + def on_delete(self, payload, flow): + self.events.append(("delete", payload["node_id"])) + + +class _ParameterizedNode(Node): + threshold = param.Number(default=0.5, precedence=0) + hidden = param.String(default="secret", precedence=-1) + + +def test_reactflow_accepts_node_instance() -> None: + flow = ReactFlow() + node = Node(id="n1", position={"x": 0, "y": 0}, label="Node object", data={"status": "idle"}) + flow.add_node(node) + assert flow.nodes[0] is node + assert flow.nodes[0].data["status"] == "idle" + + +def test_patch_node_data_updates_node_instance() -> None: + node = Node(id="n1", position={"x": 0, "y": 0}, data={"value": 1}) + flow = ReactFlow(nodes=[node]) + flow.patch_node_data("n1", {"value": 42, "name": "patched"}) + assert node.data["value"] == 42 + assert node.data["name"] == "patched" + + +def test_sync_updates_node_instance_fields() -> None: + node = Node(id="n1", position={"x": 0, "y": 0}, data={"value": 1}, selected=False) + flow = ReactFlow(nodes=[node]) + flow._handle_msg( + { + "type": "sync", + "nodes": [ + { + "id": "n1", + "type": "panel", + "position": {"x": 10, "y": 20}, + "data": {"value": 3}, + "selected": True, + "draggable": False, + "connectable": False, + "deletable": False, + } + ], + } + ) + assert flow.nodes[0] is node + assert node.position == {"x": 10, "y": 20} + assert node.data == {"value": 3} + assert node.selected is True + assert node.draggable is False + assert node.connectable is False + assert node.deletable is False + + +def test_node_hooks_receive_events() -> None: + node = _CountingNode(id="n1", position={"x": 0, "y": 0}, data={}) + flow = ReactFlow(nodes=[node]) + flow._handle_msg({"type": "node_moved", "node_id": "n1", "position": {"x": 5, "y": 9}}) + flow.remove_node("n1") + assert ("move", {"x": 5, "y": 9}) in node.events + assert ("event", "node_moved") in node.events + assert ("delete", "n1") in node.events + assert ("event", "node_deleted") in node.events + + +def test_node_can_provide_custom_editor() -> None: + class _NodeWithEditor(Node): + def editor(self, data, schema, *, id, type, on_patch): + return pn.pane.Markdown(f"Editor for {id}") + + node = _NodeWithEditor(id="n1", position={"x": 0, "y": 0}, data={}) + flow = ReactFlow(nodes=[node]) + editor = flow._node_editors["n1"] + assert hasattr(editor, "object") + assert "n1" in editor.object + + +def test_node_subclass_params_with_non_negative_precedence_in_data_and_schema() -> None: + node = _ParameterizedNode(id="n1", type="custom", position={"x": 0, "y": 0}, data={}) + flow = ReactFlow(nodes=[node]) + payload = node.to_dict() + assert payload["data"]["threshold"] == 0.5 + assert "hidden" not in payload["data"] + schema = flow._get_node_schema("custom") + assert schema is not None + assert "threshold" in schema["properties"] + assert "hidden" not in schema["properties"] + + +def test_patch_node_data_updates_parameterized_node_params() -> None: + node = _ParameterizedNode(id="n1", type="custom", position={"x": 0, "y": 0}, data={}) + flow = ReactFlow(nodes=[node]) + flow.patch_node_data("n1", {"threshold": 0.9, "hidden": "still-hidden"}) + assert node.threshold == 0.9 + assert node.hidden == "secret" + assert node.data["threshold"] == 0.9 + assert node.data["hidden"] == "still-hidden" + + +def test_parameterized_node_param_change_auto_patches_data() -> None: + node = _ParameterizedNode(id="n1", type="custom", position={"x": 0, "y": 0}, data={}) + flow = ReactFlow(nodes=[node]) + events = [] + flow.on("node_data_changed", events.append) + node.threshold = 0.77 + assert node.data["threshold"] == 0.77 + assert events[-1]["patch"] == {"threshold": 0.77} + + +def test_parameterized_node_watchers_clean_up_on_delete() -> None: + node = _ParameterizedNode(id="n1", type="custom", position={"x": 0, "y": 0}, data={}) + flow = ReactFlow(nodes=[node]) + assert "n1" in flow._node_data_param_watchers + flow.remove_node("n1") + assert "n1" not in flow._node_data_param_watchers + events = [] + flow.on("node_data_changed", events.append) + node.threshold = 0.31 + assert events == [] + + def test_edge_spec_roundtrip() -> None: edge = EdgeSpec(id="e1", source="n1", target="n2", data={"weight": 0.5}) payload = edge.to_dict() @@ -82,6 +214,185 @@ def test_edge_spec_roundtrip() -> None: assert EdgeSpec.from_dict(payload).to_dict() == payload +class _CountingEdge(Edge): + def __init__(self, **params): + super().__init__(**params) + self.events = [] + + def on_event(self, payload, flow): + self.events.append(("event", payload["type"])) + + def on_delete(self, payload, flow): + self.events.append(("delete", payload["edge_id"])) + + +class _ParameterizedEdge(Edge): + confidence = param.Number(default=0.8, precedence=0) + internal = param.String(default="ignore", precedence=-1) + + +def test_reactflow_accepts_edge_instance() -> None: + flow = ReactFlow() + flow.add_node({"id": "n1", "position": {"x": 0, "y": 0}, "data": {}}) + flow.add_node({"id": "n2", "position": {"x": 1, "y": 1}, "data": {}}) + edge = Edge(id="e1", source="n1", target="n2", data={"weight": 1}) + flow.add_edge(edge) + assert flow.edges[0] is edge + assert flow.edges[0].data["weight"] == 1 + + +def test_patch_edge_data_updates_edge_instance() -> None: + edge = Edge(id="e1", source="n1", target="n2", data={"weight": 1}) + flow = ReactFlow( + nodes=[ + {"id": "n1", "position": {"x": 0, "y": 0}, "data": {}}, + {"id": "n2", "position": {"x": 1, "y": 1}, "data": {}}, + ], + edges=[edge], + ) + flow.patch_edge_data("e1", {"weight": 3, "label": "hi"}) + assert edge.data["weight"] == 3 + assert edge.data["label"] == "hi" + + +def test_sync_updates_edge_instance_fields() -> None: + edge = Edge(id="e1", source="n1", target="n2", data={"weight": 1}, selected=False) + flow = ReactFlow( + nodes=[ + {"id": "n1", "position": {"x": 0, "y": 0}, "data": {}}, + {"id": "n2", "position": {"x": 1, "y": 1}, "data": {}}, + ], + edges=[edge], + ) + flow._handle_msg( + { + "type": "sync", + "edges": [ + { + "id": "e1", + "source": "n1", + "target": "n2", + "label": "patched", + "type": "flow", + "selected": True, + "data": {"weight": 7}, + "sourceHandle": "out", + "targetHandle": "in", + } + ], + } + ) + assert flow.edges[0] is edge + assert edge.label == "patched" + assert edge.type == "flow" + assert edge.selected is True + assert edge.data == {"weight": 7} + assert edge.sourceHandle == "out" + assert edge.targetHandle == "in" + + +def test_edge_hooks_receive_events() -> None: + edge = _CountingEdge(id="e1", source="n1", target="n2", data={}) + flow = ReactFlow( + nodes=[ + {"id": "n1", "position": {"x": 0, "y": 0}, "data": {}}, + {"id": "n2", "position": {"x": 1, "y": 1}, "data": {}}, + ], + edges=[edge], + ) + flow.patch_edge_data("e1", {"weight": 2}) + flow.remove_edge("e1") + assert ("event", "edge_data_changed") in edge.events + assert ("delete", "e1") in edge.events + assert ("event", "edge_deleted") in edge.events + + +def test_edge_can_provide_custom_editor() -> None: + class _EdgeWithEditor(Edge): + def editor(self, data, schema, *, id, type, on_patch): + return pn.pane.Markdown(f"Edge editor for {id}") + + edge = _EdgeWithEditor(id="e1", source="n1", target="n2", data={}) + flow = ReactFlow( + nodes=[ + {"id": "n1", "position": {"x": 0, "y": 0}, "data": {}}, + {"id": "n2", "position": {"x": 1, "y": 1}, "data": {}}, + ], + edges=[edge], + ) + editor = flow._edge_editors["e1"] + assert hasattr(editor, "object") + assert "e1" in editor.object + + +def test_edge_subclass_params_with_non_negative_precedence_in_data_and_schema() -> None: + edge = _ParameterizedEdge(id="e1", source="n1", target="n2", type="weighted", data={}) + flow = ReactFlow( + nodes=[ + {"id": "n1", "position": {"x": 0, "y": 0}, "data": {}}, + {"id": "n2", "position": {"x": 1, "y": 1}, "data": {}}, + ], + edges=[edge], + ) + payload = edge.to_dict() + assert payload["data"]["confidence"] == 0.8 + assert "internal" not in payload["data"] + schema = flow._get_edge_schema("weighted") + assert schema is not None + assert "confidence" in schema["properties"] + assert "internal" not in schema["properties"] + + +def test_patch_edge_data_updates_parameterized_edge_params() -> None: + edge = _ParameterizedEdge(id="e1", source="n1", target="n2", type="weighted", data={}) + flow = ReactFlow( + nodes=[ + {"id": "n1", "position": {"x": 0, "y": 0}, "data": {}}, + {"id": "n2", "position": {"x": 1, "y": 1}, "data": {}}, + ], + edges=[edge], + ) + flow.patch_edge_data("e1", {"confidence": 0.25, "internal": "keep-data-only"}) + assert edge.confidence == 0.25 + assert edge.internal == "ignore" + assert edge.data["confidence"] == 0.25 + assert edge.data["internal"] == "keep-data-only" + + +def test_parameterized_edge_param_change_auto_patches_data() -> None: + edge = _ParameterizedEdge(id="e1", source="n1", target="n2", type="weighted", data={}) + flow = ReactFlow( + nodes=[ + {"id": "n1", "position": {"x": 0, "y": 0}, "data": {}}, + {"id": "n2", "position": {"x": 1, "y": 1}, "data": {}}, + ], + edges=[edge], + ) + events = [] + flow.on("edge_data_changed", events.append) + edge.confidence = 0.41 + assert edge.data["confidence"] == 0.41 + assert events[-1]["patch"] == {"confidence": 0.41} + + +def test_parameterized_edge_watchers_clean_up_on_delete() -> None: + edge = _ParameterizedEdge(id="e1", source="n1", target="n2", type="weighted", data={}) + flow = ReactFlow( + nodes=[ + {"id": "n1", "position": {"x": 0, "y": 0}, "data": {}}, + {"id": "n2", "position": {"x": 1, "y": 1}, "data": {}}, + ], + edges=[edge], + ) + assert "e1" in flow._edge_data_param_watchers + flow.remove_edge("e1") + assert "e1" not in flow._edge_data_param_watchers + events = [] + flow.on("edge_data_changed", events.append) + edge.confidence = 0.2 + assert events == [] + + def test_edge_spec_with_handles() -> None: """Test that EdgeSpec correctly handles sourceHandle and targetHandle.""" edge = EdgeSpec(id="e1", source="producer", target="consumer", sourceHandle="result", targetHandle="mode") diff --git a/tests/test_core.py b/tests/test_core.py index ba4eb9e..d20b810 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -3,7 +3,7 @@ from panel.pane import Markdown from panel.viewable import Viewer -from panel_reactflow import ReactFlow +from panel_reactflow import Edge, Node, ReactFlow def test_reactflow_add_node_with_arbitrary_object(document, comm) -> None: @@ -79,3 +79,42 @@ def test_reactflow_add_node_dynamically_creates_views(document, comm): assert len(model.data._views) == 1 assert len(model.data._node_editor_views) == 1 + + +def test_bokeh_children_initialize_for_object_views_and_editors(document, comm) -> None: + class ViewNode(Node): + def __panel__(self): + return Markdown("Node view content") + + def editor(self, data, schema, *, id, type, on_patch): + return Markdown(f"Node editor {id}") + + class EditorEdge(Edge): + def editor(self, data, schema, *, id, type, on_patch): + return Markdown(f"Edge editor {id}") + + flow = ReactFlow( + nodes=[ + ViewNode(id="n1", position={"x": 0, "y": 0}, data={}), + {"id": "n2", "position": {"x": 150, "y": 0}, "data": {}}, + ], + edges=[EditorEdge(id="e1", source="n1", target="n2", data={})], + ) + + model = flow.get_root(document, comm=comm) + assert model.children == [ + "_views", + "_node_editor_views", + "_edge_editor_views", + "top_panel", + "bottom_panel", + "left_panel", + "right_panel", + ] + assert len(model.data._views) == 1 + assert len(model.data._node_editor_views) == 2 + assert len(model.data._edge_editor_views) == 1 + + by_id = {node["id"]: node for node in model.data.nodes} + assert by_id["n1"]["data"]["view_idx"] == 0 + assert by_id["n2"]["data"].get("view_idx") is None From c9c708ef559129f4aa01f74f538627a1b0ca2670 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Fri, 13 Mar 2026 13:33:35 +0100 Subject: [PATCH 2/2] Add instance example of threejs --- examples/threejs_viewer_instances.py | 335 +++++++++++++++++++++++++++ 1 file changed, 335 insertions(+) create mode 100644 examples/threejs_viewer_instances.py diff --git a/examples/threejs_viewer_instances.py b/examples/threejs_viewer_instances.py new file mode 100644 index 0000000..45b01fa --- /dev/null +++ b/examples/threejs_viewer_instances.py @@ -0,0 +1,335 @@ +"""3D Cube Viewer using Node/Edge instances. + +This mirrors ``threejs_viewer.py`` but models graph elements as ``Node`` and +``Edge`` subclasses instead of plain dictionaries. +""" + +import panel as pn +import panel_material_ui as pmui +import param + +from panel.custom import JSComponent +from panel_reactflow import Edge, Node, NodeType, ReactFlow + +pn.extension("jsoneditor") + + +class CubeViewer(JSComponent): + color = param.Color(default="#9c5afd") + num_cubes = param.Integer(default=8, bounds=(1, 64)) + cube_size = param.Number(default=0.5, bounds=(0.1, 2.0)) + rotation_speed = param.Number(default=0.01, bounds=(0.0, 0.05)) + spacing = param.Number(default=1.8, bounds=(0.5, 4.0)) + background = param.Color(default="#0f172a") + + _importmap = {"imports": {"three": "https://esm.sh/three@0.160.0"}} + + _esm = """ + import * as THREE from "three" + export function render({ model, el }) { + const W = 420, H = 300; + const scene = new THREE.Scene(); + scene.background = new THREE.Color(model.background); + const camera = new THREE.PerspectiveCamera(45, W / H, 0.1, 100); + camera.position.set(6, 4.5, 8); camera.lookAt(0, 0, 0); + const renderer = new THREE.WebGLRenderer({ antialias: true }); + renderer.setSize(W, H); renderer.setPixelRatio(Math.min(window.devicePixelRatio, 2)); + el.appendChild(renderer.domElement); + scene.add(new THREE.AmbientLight(0xffffff, 0.5)); + const key = new THREE.DirectionalLight(0xffffff, 1.0); key.position.set(5, 10, 7); scene.add(key); + const rim = new THREE.DirectionalLight(0x8b5cf6, 0.3); rim.position.set(-5, 3, -5); scene.add(rim); + const grid = new THREE.GridHelper(14, 14, 0x334155, 0x1e293b); grid.position.y = -1.2; scene.add(grid); + const group = new THREE.Group(); scene.add(group); + const material = new THREE.MeshStandardMaterial({ color: model.color, roughness: 0.3, metalness: 0.65 }); + function rebuild() { + if (group.children.length > 0) group.children[0].geometry.dispose(); + group.clear(); + const n = model.num_cubes, s = model.cube_size, sp = model.spacing; + const cols = Math.max(1, Math.ceil(Math.sqrt(n))), rows = Math.ceil(n / cols); + const geo = new THREE.BoxGeometry(s, s, s); + for (let i = 0; i < n; i++) { + const c = i % cols, r = Math.floor(i / cols); + const mesh = new THREE.Mesh(geo, material); + mesh.position.set((c - (cols - 1) / 2) * sp, 0, (r - (rows - 1) / 2) * sp); + mesh.rotation.set(Math.random() * Math.PI, Math.random() * Math.PI, 0); + group.add(mesh); + } + } + rebuild(); + let raf; + (function loop() { + raf = requestAnimationFrame(loop); + group.rotation.y += model.rotation_speed; + group.children.forEach((m, i) => { m.rotation.x += 0.003 + i * 0.0002; m.rotation.z += 0.002 + i * 0.0001; }); + renderer.render(scene, camera); + })(); + model.on("change:color", () => material.color.set(model.color)); + model.on("change:background", () => { scene.background = new THREE.Color(model.background); }); + model.on("change:num_cubes", rebuild); + model.on("change:cube_size", rebuild); + model.on("change:spacing", rebuild); + model.on("remove", () => { cancelAnimationFrame(raf); renderer.dispose(); }); + } + """ + + +STYLES = """ +.react-flow__node-viewer { padding: 0; border-radius: 12px; border: 2px solid #7c3aed; background: #0f172a; box-shadow: 0 4px 24px rgba(124, 58, 237, .15); overflow: hidden; } +.react-flow__node-viewer.selected { box-shadow: 0 0 0 2.5px rgba(124, 58, 237, .35), 0 4px 24px rgba(124, 58, 237, .2); } +.rf-node-content { padding: 0; } +.react-flow__node-color,.react-flow__node-count,.react-flow__node-size,.react-flow__node-speed,.react-flow__node-spacing,.react-flow__node-background { border-radius: 8px; border: 1.5px solid #e2e8f0; border-left: 4px solid #94a3b8; background: #fff; box-shadow: 0 1px 4px rgba(0, 0, 0, .05); min-width: 180px; } +.react-flow__node-color { border-left-color: #ec4899; } .react-flow__node-count { border-left-color: #3b82f6; } .react-flow__node-size { border-left-color: #10b981; } +.react-flow__node-speed { border-left-color: #f59e0b; } .react-flow__node-spacing { border-left-color: #06b6d4; } .react-flow__node-background { border-left-color: #64748b; } +.react-flow__edge-path { stroke: #7c3aed; stroke-width: 2px; } +""" + + +class ViewerNode(Node): + def __init__(self, viewer: CubeViewer, **params): + params.setdefault("id", "viewer") + params.setdefault("type", "viewer") + params.setdefault("label", "") + params.setdefault("position", {"x": 500, "y": 100}) + super().__init__(**params) + self._viewer = viewer + + def __panel__(self): + return self._viewer + + +class LinkEdge(Edge): + pass + + +class ControllerNode(Node): + viewer_param = "" + ctrl_type = "" + + def __init__(self, viewer: CubeViewer, **params): + params.setdefault("type", self.ctrl_type) + super().__init__(**params) + self._viewer = viewer + self._widget = self._make_widget() + + def _make_widget(self): + raise NotImplementedError + + def _param_value(self): + return getattr(self, self.viewer_param) + + def editor(self, data, schema, *, id, type, on_patch): + return self._widget + + def _push_if_connected(self, flow: ReactFlow) -> None: + for edge in flow.edges: + source = edge.source if isinstance(edge, Edge) else edge.get("source") + target = edge.target if isinstance(edge, Edge) else edge.get("target") + if source == self.id and target == "viewer": + setattr(self._viewer, self.viewer_param, self._param_value()) + return + + def on_add(self, payload, flow): + self._push_if_connected(flow) + + def on_data_change(self, payload, flow): + if payload.get("node_id") == self.id: + self._push_if_connected(flow) + + +class ColorController(ControllerNode): + ctrl_type = "color" + viewer_param = "color" + color = param.Color(default=CubeViewer.color, precedence=0) + + def _make_widget(self): + return pmui.ColorPicker.from_param(self.param.color, name="") + + +class CountController(ControllerNode): + ctrl_type = "count" + viewer_param = "num_cubes" + num_cubes = param.Integer(default=CubeViewer.num_cubes, bounds=(1, 64), precedence=0) + + def _make_widget(self): + return pmui.IntSlider.from_param(self.param.num_cubes, name="") + + +class SizeController(ControllerNode): + ctrl_type = "size" + viewer_param = "cube_size" + cube_size = param.Number(default=CubeViewer.cube_size, bounds=(0.1, 2.0), precedence=0) + + def _make_widget(self): + return pmui.FloatSlider.from_param(self.param.cube_size, step=0.05, name="") + + +class SpeedController(ControllerNode): + ctrl_type = "speed" + viewer_param = "rotation_speed" + rotation_speed = param.Number(default=CubeViewer.rotation_speed, bounds=(0.0, 0.05), precedence=0) + + def _make_widget(self): + return pmui.FloatSlider.from_param(self.param.rotation_speed, step=0.001, name="") + + +class SpacingController(ControllerNode): + ctrl_type = "spacing" + viewer_param = "spacing" + spacing = param.Number(default=CubeViewer.spacing, bounds=(0.5, 4.0), precedence=0) + + def _make_widget(self): + return pmui.FloatSlider.from_param(self.param.spacing, step=0.1, name="") + + +class BackgroundController(ControllerNode): + ctrl_type = "background" + viewer_param = "background" + background = param.Color(default=CubeViewer.background, precedence=0) + + def _make_widget(self): + return pmui.ColorPicker.from_param(self.param.background, name="") + + +CTRL_CLASSES = { + "color": ColorController, + "count": CountController, + "size": SizeController, + "speed": SpeedController, + "spacing": SpacingController, + "background": BackgroundController, +} + +LABELS = { + "color": "Color", + "count": "Cube Count", + "size": "Cube Size", + "speed": "Rotation Speed", + "spacing": "Spacing", + "background": "Background", +} + +node_types = { + "viewer": NodeType(type="viewer", label="3D Viewer", inputs=["param"]), + **{t: NodeType(type=t, label=LABELS[t], outputs=["out"]) for t in CTRL_CLASSES}, +} + +viewer_component = CubeViewer(margin=0, width=420, height=300) +viewer_node = ViewerNode(viewer=viewer_component) + +flow = ReactFlow( + nodes=[viewer_node], + edges=[], + node_types=node_types, + editor_mode="node", + stylesheets=[STYLES], + sizing_mode="stretch_both", +) + +_counter = [0] +_active_controllers: dict[int, str] = {} +_syncing = [False] + + +def _controller_by_id(node_id: str): + for node in flow.nodes: + if isinstance(node, ControllerNode) and node.id == node_id: + return node + return None + + +def _on_edge_added(payload): + edge_payload = payload.get("edge", {}) + source = edge_payload.get("source") + target = edge_payload.get("target") + if source and target == "viewer": + controller = _controller_by_id(source) + if controller is not None: + controller._push_if_connected(flow) + + +def _on_node_deleted(payload): + node_ids = payload.get("node_ids") or [payload.get("node_id")] + changed = False + for node_id in node_ids: + if node_id is None: + continue + for idx, nid in list(_active_controllers.items()): + if nid == node_id: + del _active_controllers[idx] + changed = True + break + if changed: + _syncing[0] = True + menu_tree.active = [(idx,) for idx in sorted(_active_controllers)] + _syncing[0] = False + + +flow.on("edge_added", _on_edge_added) +flow.on("node_deleted", _on_node_deleted) + + +def add_controller(ctrl_type: str) -> str: + _counter[0] += 1 + node_id = f"{ctrl_type}_{_counter[0]}" + y_pos = 30 + ((_counter[0] - 1) % 6) * 120 + controller = CTRL_CLASSES[ctrl_type]( + viewer=viewer_component, + id=node_id, + label=LABELS[ctrl_type], + position={"x": 50, "y": y_pos}, + ) + flow.add_node(controller) + return node_id + + +CTRL_ORDER = list(CTRL_CLASSES) + +menu_tree = pmui.Tree( + items=[ + {"label": "Color", "icon": "palette"}, + {"label": "Cube Count", "icon": "grid_view"}, + {"label": "Cube Size", "icon": "open_with"}, + {"label": "Rotation Speed", "icon": "speed"}, + {"label": "Spacing", "icon": "space_bar"}, + {"label": "Background", "icon": "dark_mode"}, + ], + checkboxes=True, + active=[(0,), (1,)], + width=200, + margin=5, +) + + +def _on_tree_change(event): + if _syncing[0]: + return + new_indices = {idx for (idx,) in event.new} + old_indices = set(_active_controllers) + for idx in sorted(new_indices - old_indices): + ctrl_type = CTRL_ORDER[idx] + node_id = add_controller(ctrl_type) + flow.add_edge(LinkEdge(source=node_id, target="viewer")) + _active_controllers[idx] = node_id + for idx in old_indices - new_indices: + node_id = _active_controllers.pop(idx) + flow.remove_node(node_id) + + +menu_tree.param.watch(_on_tree_change, "active") + +menu = pn.Column( + pn.pane.Markdown("#### Controllers"), + menu_tree, + width=210, + margin=(10, 5), +) +flow.left_panel = [menu] + +_color_id = add_controller("color") +_count_id = add_controller("count") +flow.add_edge(LinkEdge(source=_color_id, target="viewer")) +flow.add_edge(LinkEdge(source=_count_id, target="viewer")) +_active_controllers[0] = _color_id +_active_controllers[1] = _count_id + +pn.Column(flow, sizing_mode="stretch_both").servable()