Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/panel_reactflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,7 @@ class Node(param.Parameterized):
deletable = param.Boolean(default=True, doc="Whether node is deletable.")
style = param.Dict(default=None, allow_None=True, doc="Optional node style.")
className = param.String(default=None, allow_None=True, doc="Optional CSS class.")
flow = param.Parameter(default=None, allow_None=True, precedence=-1, doc="Parent ReactFlow instance.")

@classmethod
def _data_param_names(cls) -> list[str]:
Expand Down Expand Up @@ -838,6 +839,7 @@ class Edge(param.Parameterized):
markerEnd = param.Dict(default=None, allow_None=True, doc="Optional edge end marker.")
sourceHandle = param.String(default=None, allow_None=True, doc="Optional source handle id.")
targetHandle = param.String(default=None, allow_None=True, doc="Optional target handle id.")
flow = param.Parameter(default=None, allow_None=True, precedence=-1, doc="Parent ReactFlow instance.")

@classmethod
def _data_param_names(cls) -> list[str]:
Expand Down Expand Up @@ -1417,6 +1419,8 @@ def __init__(self, **params: Any):
params["color_mode"] = "dark" if pn.config.theme == "dark" else "light"
self._node_ids: list[str] = []
self._edge_ids: list[str] = []
self._attached_node_instances: dict[int, Node] = {}
self._attached_edge_instances: dict[int, Edge] = {}
self._node_data_param_watchers: dict[str, tuple[Node, list[Any]]] = {}
self._edge_data_param_watchers: dict[str, tuple[Edge, list[Any]]] = {}
# Normalize type specs before parent init so the frontend receives
Expand All @@ -1434,6 +1438,7 @@ def __init__(self, **params: Any):
self._event_handlers: dict[str, list[Callable]] = {"*": []}
self.param.watch(self._normalize_nodes, ["nodes"])
self.param.watch(self._normalize_edges, ["edges"])
self.param.watch(self._sync_instance_flow_refs, ["nodes", "edges"])
self.param.watch(self._update_instance_data_param_watchers, ["nodes", "edges"])
self.param.watch(self._update_selection_from_graph, ["nodes", "edges"])
self.param.watch(self._normalize_specs, ["node_types", "edge_types"])
Expand All @@ -1447,6 +1452,7 @@ def __init__(self, **params: Any):
)
self._update_node_editors()
self._update_edge_editors()
self._sync_instance_flow_refs()
self._update_instance_data_param_watchers()

@classmethod
Expand Down Expand Up @@ -1663,6 +1669,25 @@ def _update_instance_data_param_watchers(self, *_: param.parameterized.Event) ->
self._update_node_data_param_watchers()
self._update_edge_data_param_watchers()

def _sync_instance_flow_refs(self, *_: param.parameterized.Event) -> None:
current_nodes = {id(node): node for node in self.nodes if isinstance(node, Node)}
for node_ref, old_node in list(self._attached_node_instances.items()):
if node_ref not in current_nodes:
old_node.flow = None
for node in current_nodes.values():
if node.flow is not self:
node.flow = self
self._attached_node_instances = current_nodes

current_edges = {id(edge): edge for edge in self.edges if isinstance(edge, Edge)}
for edge_ref, old_edge in list(self._attached_edge_instances.items()):
if edge_ref not in current_edges:
old_edge.flow = None
for edge in current_edges.values():
if edge.flow is not self:
edge.flow = self
self._attached_edge_instances = current_edges

def _update_node_data_param_watchers(self) -> None:
current_nodes = {node.id: node for node in self.nodes if isinstance(node, Node) and node.id}
for node_id, (watched_node, _) in list(self._node_data_param_watchers.items()):
Expand Down Expand Up @@ -2050,6 +2075,7 @@ def add_node(self, node: dict[str, Any] | NodeSpec | Node) -> None:
raw_node.draggable = payload["draggable"]
raw_node.connectable = payload["connectable"]
raw_node.deletable = payload["deletable"]
raw_node.flow = self
self._sync_node_data_params_from_data(raw_node)
self._validate_graph_payload(payload, kind="node")
if self.validate_on_add:
Expand Down Expand Up @@ -2207,6 +2233,7 @@ def remove_node(self, node_id: str) -> None:
},
)
if isinstance(removed_node, Node):
removed_node.flow = None
payload = {
"type": "node_deleted",
"node_id": node_id,
Expand Down Expand Up @@ -2293,6 +2320,7 @@ def add_edge(self, edge: dict[str, Any] | EdgeSpec | Edge) -> None:
if isinstance(raw_edge, Edge):
raw_edge.id = payload["id"]
raw_edge.data = dict(payload["data"])
raw_edge.flow = self
self._sync_edge_data_params_from_data(raw_edge)
self._validate_graph_payload(payload, kind="edge")
if self.validate_on_add:
Expand Down Expand Up @@ -2329,6 +2357,7 @@ def remove_edge(self, edge_id: str) -> None:
if removed:
self._emit("edge_deleted", {"type": "edge_deleted", "edge_id": edge_id})
if isinstance(removed_edge, Edge):
removed_edge.flow = None
payload = {"type": "edge_deleted", "edge_id": edge_id}
self._invoke_edge_hook(removed_edge, "on_delete", payload)
self._invoke_edge_hook(removed_edge, "on_event", payload)
Expand Down
29 changes: 29 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def test_reactflow_accepts_node_instance() -> None:
node = Node(id="n1", position={"x": 0, "y": 0}, label="Node object", data={"status": "idle"})
flow.add_node(node)
assert flow.nodes[0] is node
assert node.flow is flow
assert flow.nodes[0].data["status"] == "idle"


Expand Down Expand Up @@ -148,6 +149,7 @@ def test_node_hooks_receive_events() -> None:
assert ("event", "node_moved") in node.events
assert ("delete", "n1") in node.events
assert ("event", "node_deleted") in node.events
assert node.flow is None


def test_node_can_provide_custom_editor() -> None:
Expand Down Expand Up @@ -200,12 +202,22 @@ def test_parameterized_node_watchers_clean_up_on_delete() -> None:
assert "n1" in flow._node_data_param_watchers
flow.remove_node("n1")
assert "n1" not in flow._node_data_param_watchers
assert node.flow is None
events = []
flow.on("node_data_changed", events.append)
node.threshold = 0.31
assert events == []


def test_node_flow_ref_updates_on_nodes_assignment() -> None:
flow = ReactFlow()
node = Node(id="n1", position={"x": 0, "y": 0}, data={})
flow.nodes = [node]
assert node.flow is flow
flow.nodes = []
assert node.flow is None


def test_edge_spec_roundtrip() -> None:
edge = EdgeSpec(id="e1", source="n1", target="n2", data={"weight": 0.5})
payload = edge.to_dict()
Expand Down Expand Up @@ -238,6 +250,7 @@ def test_reactflow_accepts_edge_instance() -> None:
edge = Edge(id="e1", source="n1", target="n2", data={"weight": 1})
flow.add_edge(edge)
assert flow.edges[0] is edge
assert edge.flow is flow
assert flow.edges[0].data["weight"] == 1


Expand Down Expand Up @@ -305,6 +318,7 @@ def test_edge_hooks_receive_events() -> None:
assert ("event", "edge_data_changed") in edge.events
assert ("delete", "e1") in edge.events
assert ("event", "edge_deleted") in edge.events
assert edge.flow is None


def test_edge_can_provide_custom_editor() -> None:
Expand Down Expand Up @@ -387,12 +401,27 @@ def test_parameterized_edge_watchers_clean_up_on_delete() -> None:
assert "e1" in flow._edge_data_param_watchers
flow.remove_edge("e1")
assert "e1" not in flow._edge_data_param_watchers
assert edge.flow is None
events = []
flow.on("edge_data_changed", events.append)
edge.confidence = 0.2
assert events == []


def test_edge_flow_ref_updates_on_edges_assignment() -> None:
edge = Edge(id="e1", source="n1", target="n2", data={})
flow = ReactFlow(
nodes=[
{"id": "n1", "position": {"x": 0, "y": 0}, "data": {}},
{"id": "n2", "position": {"x": 1, "y": 1}, "data": {}},
]
)
flow.edges = [edge]
assert edge.flow is flow
flow.edges = []
assert edge.flow is None


def test_edge_spec_with_handles() -> None:
"""Test that EdgeSpec correctly handles sourceHandle and targetHandle."""
edge = EdgeSpec(id="e1", source="producer", target="consumer", sourceHandle="result", targetHandle="mode")
Expand Down
Loading