diff --git a/src/panel_reactflow/base.py b/src/panel_reactflow/base.py index a4ff643..9baf0d6 100644 --- a/src/panel_reactflow/base.py +++ b/src/panel_reactflow/base.py @@ -597,6 +597,7 @@ class Node(param.Parameterized): deletable = param.Boolean(default=True, doc="Whether node is deletable.") style = param.Dict(default=None, allow_None=True, doc="Optional node style.") className = param.String(default=None, allow_None=True, doc="Optional CSS class.") + flow = param.Parameter(default=None, allow_None=True, precedence=-1, doc="Parent ReactFlow instance.") @classmethod def _data_param_names(cls) -> list[str]: @@ -838,6 +839,7 @@ class Edge(param.Parameterized): markerEnd = param.Dict(default=None, allow_None=True, doc="Optional edge end marker.") sourceHandle = param.String(default=None, allow_None=True, doc="Optional source handle id.") targetHandle = param.String(default=None, allow_None=True, doc="Optional target handle id.") + flow = param.Parameter(default=None, allow_None=True, precedence=-1, doc="Parent ReactFlow instance.") @classmethod def _data_param_names(cls) -> list[str]: @@ -1417,6 +1419,8 @@ def __init__(self, **params: Any): params["color_mode"] = "dark" if pn.config.theme == "dark" else "light" self._node_ids: list[str] = [] self._edge_ids: list[str] = [] + self._attached_node_instances: dict[int, Node] = {} + self._attached_edge_instances: dict[int, Edge] = {} self._node_data_param_watchers: dict[str, tuple[Node, list[Any]]] = {} self._edge_data_param_watchers: dict[str, tuple[Edge, list[Any]]] = {} # Normalize type specs before parent init so the frontend receives @@ -1434,6 +1438,7 @@ def __init__(self, **params: Any): self._event_handlers: dict[str, list[Callable]] = {"*": []} self.param.watch(self._normalize_nodes, ["nodes"]) self.param.watch(self._normalize_edges, ["edges"]) + self.param.watch(self._sync_instance_flow_refs, ["nodes", "edges"]) self.param.watch(self._update_instance_data_param_watchers, ["nodes", "edges"]) self.param.watch(self._update_selection_from_graph, ["nodes", "edges"]) self.param.watch(self._normalize_specs, ["node_types", "edge_types"]) @@ -1447,6 +1452,7 @@ def __init__(self, **params: Any): ) self._update_node_editors() self._update_edge_editors() + self._sync_instance_flow_refs() self._update_instance_data_param_watchers() @classmethod @@ -1663,6 +1669,25 @@ def _update_instance_data_param_watchers(self, *_: param.parameterized.Event) -> self._update_node_data_param_watchers() self._update_edge_data_param_watchers() + def _sync_instance_flow_refs(self, *_: param.parameterized.Event) -> None: + current_nodes = {id(node): node for node in self.nodes if isinstance(node, Node)} + for node_ref, old_node in list(self._attached_node_instances.items()): + if node_ref not in current_nodes: + old_node.flow = None + for node in current_nodes.values(): + if node.flow is not self: + node.flow = self + self._attached_node_instances = current_nodes + + current_edges = {id(edge): edge for edge in self.edges if isinstance(edge, Edge)} + for edge_ref, old_edge in list(self._attached_edge_instances.items()): + if edge_ref not in current_edges: + old_edge.flow = None + for edge in current_edges.values(): + if edge.flow is not self: + edge.flow = self + self._attached_edge_instances = current_edges + def _update_node_data_param_watchers(self) -> None: current_nodes = {node.id: node for node in self.nodes if isinstance(node, Node) and node.id} for node_id, (watched_node, _) in list(self._node_data_param_watchers.items()): @@ -2050,6 +2075,7 @@ def add_node(self, node: dict[str, Any] | NodeSpec | Node) -> None: raw_node.draggable = payload["draggable"] raw_node.connectable = payload["connectable"] raw_node.deletable = payload["deletable"] + raw_node.flow = self self._sync_node_data_params_from_data(raw_node) self._validate_graph_payload(payload, kind="node") if self.validate_on_add: @@ -2207,6 +2233,7 @@ def remove_node(self, node_id: str) -> None: }, ) if isinstance(removed_node, Node): + removed_node.flow = None payload = { "type": "node_deleted", "node_id": node_id, @@ -2293,6 +2320,7 @@ def add_edge(self, edge: dict[str, Any] | EdgeSpec | Edge) -> None: if isinstance(raw_edge, Edge): raw_edge.id = payload["id"] raw_edge.data = dict(payload["data"]) + raw_edge.flow = self self._sync_edge_data_params_from_data(raw_edge) self._validate_graph_payload(payload, kind="edge") if self.validate_on_add: @@ -2329,6 +2357,7 @@ def remove_edge(self, edge_id: str) -> None: if removed: self._emit("edge_deleted", {"type": "edge_deleted", "edge_id": edge_id}) if isinstance(removed_edge, Edge): + removed_edge.flow = None payload = {"type": "edge_deleted", "edge_id": edge_id} self._invoke_edge_hook(removed_edge, "on_delete", payload) self._invoke_edge_hook(removed_edge, "on_event", payload) diff --git a/tests/test_api.py b/tests/test_api.py index 4bc8d9e..99ca30e 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -99,6 +99,7 @@ def test_reactflow_accepts_node_instance() -> None: node = Node(id="n1", position={"x": 0, "y": 0}, label="Node object", data={"status": "idle"}) flow.add_node(node) assert flow.nodes[0] is node + assert node.flow is flow assert flow.nodes[0].data["status"] == "idle" @@ -148,6 +149,7 @@ def test_node_hooks_receive_events() -> None: assert ("event", "node_moved") in node.events assert ("delete", "n1") in node.events assert ("event", "node_deleted") in node.events + assert node.flow is None def test_node_can_provide_custom_editor() -> None: @@ -200,12 +202,22 @@ def test_parameterized_node_watchers_clean_up_on_delete() -> None: assert "n1" in flow._node_data_param_watchers flow.remove_node("n1") assert "n1" not in flow._node_data_param_watchers + assert node.flow is None events = [] flow.on("node_data_changed", events.append) node.threshold = 0.31 assert events == [] +def test_node_flow_ref_updates_on_nodes_assignment() -> None: + flow = ReactFlow() + node = Node(id="n1", position={"x": 0, "y": 0}, data={}) + flow.nodes = [node] + assert node.flow is flow + flow.nodes = [] + assert node.flow is None + + def test_edge_spec_roundtrip() -> None: edge = EdgeSpec(id="e1", source="n1", target="n2", data={"weight": 0.5}) payload = edge.to_dict() @@ -238,6 +250,7 @@ def test_reactflow_accepts_edge_instance() -> None: edge = Edge(id="e1", source="n1", target="n2", data={"weight": 1}) flow.add_edge(edge) assert flow.edges[0] is edge + assert edge.flow is flow assert flow.edges[0].data["weight"] == 1 @@ -305,6 +318,7 @@ def test_edge_hooks_receive_events() -> None: assert ("event", "edge_data_changed") in edge.events assert ("delete", "e1") in edge.events assert ("event", "edge_deleted") in edge.events + assert edge.flow is None def test_edge_can_provide_custom_editor() -> None: @@ -387,12 +401,27 @@ def test_parameterized_edge_watchers_clean_up_on_delete() -> None: assert "e1" in flow._edge_data_param_watchers flow.remove_edge("e1") assert "e1" not in flow._edge_data_param_watchers + assert edge.flow is None events = [] flow.on("edge_data_changed", events.append) edge.confidence = 0.2 assert events == [] +def test_edge_flow_ref_updates_on_edges_assignment() -> None: + edge = Edge(id="e1", source="n1", target="n2", data={}) + flow = ReactFlow( + nodes=[ + {"id": "n1", "position": {"x": 0, "y": 0}, "data": {}}, + {"id": "n2", "position": {"x": 1, "y": 1}, "data": {}}, + ] + ) + flow.edges = [edge] + assert edge.flow is flow + flow.edges = [] + assert edge.flow is None + + def test_edge_spec_with_handles() -> None: """Test that EdgeSpec correctly handles sourceHandle and targetHandle.""" edge = EdgeSpec(id="e1", source="producer", target="consumer", sourceHandle="result", targetHandle="mode")