From 2d8740ad41727b4069d76d1eb92838a5f7409dab Mon Sep 17 00:00:00 2001 From: Josh Park <50765702+JoshParkSJ@users.noreply.github.com> Date: Thu, 5 Feb 2026 15:10:22 -0500 Subject: [PATCH 1/4] add interrupt support for convo coded agent --- src/uipath_langchain/runtime/messages.py | 83 +++++++++------ src/uipath_langchain/runtime/runtime.py | 127 ++++++++++++++++++++++- 2 files changed, 177 insertions(+), 33 deletions(-) diff --git a/src/uipath_langchain/runtime/messages.py b/src/uipath_langchain/runtime/messages.py index 1ab4e014..fa0ec11c 100644 --- a/src/uipath_langchain/runtime/messages.py +++ b/src/uipath_langchain/runtime/messages.py @@ -216,22 +216,27 @@ async def map_ai_message_chunk_to_events( ) ) - # Check if this is the last chunk by examining chunk_position, send end message event only if there are no pending tool calls + # Check if this is the last chunk by examining chunk_position if message.chunk_position == "last": if ( self.current_message.tool_calls is not None and len(self.current_message.tool_calls) > 0 ): - events.extend( - await self.map_current_message_to_start_tool_call_events() - ) + # Store tool call mappings but DON'T emit startToolCall yet + # Tool calls will be emitted when they actually execute (in map_tool_message_to_events) + await self._store_tool_call_mappings() + # Don't emit endMessage yet - will be emitted after all tool calls complete else: events.append(self.map_to_message_end_event(message.id)) return events - async def map_current_message_to_start_tool_call_events(self): - events: list[UiPathConversationMessageEvent] = [] + async def _store_tool_call_mappings(self): + """Store tool call ID to (message ID, tool call data) mappings without emitting events. + + This allows us to correlate ToolMessages with their originating AI message later, + and emit startToolCall events with full tool data when the tool actually executes. + """ if ( self.current_message and self.current_message.id is not None @@ -240,7 +245,7 @@ async def map_current_message_to_start_tool_call_events(self): async with self._storage_lock: if self.storage is not None: tool_call_id_to_message_id_map: dict[ - str, str + str, dict[str, Any] ] = await self.storage.get_value( self.runtime_id, STORAGE_NAMESPACE_EVENT_MAPPER, @@ -252,17 +257,14 @@ async def map_current_message_to_start_tool_call_events(self): else: tool_call_id_to_message_id_map = {} + # Store full tool call data for each tool call for tool_call in self.current_message.tool_calls: tool_call_id = tool_call["id"] if tool_call_id is not None: - tool_call_id_to_message_id_map[tool_call_id] = ( - self.current_message.id - ) - events.append( - self.map_tool_call_to_tool_call_start_event( - self.current_message.id, tool_call - ) - ) + tool_call_id_to_message_id_map[tool_call_id] = { + "message_id": self.current_message.id, + "tool_call": tool_call, + } if self.storage is not None: await self.storage.set_value( @@ -272,18 +274,21 @@ async def map_current_message_to_start_tool_call_events(self): tool_call_id_to_message_id_map, ) - return events - async def map_tool_message_to_events( self, message: ToolMessage ) -> list[UiPathConversationMessageEvent]: - # Look up the AI message ID using the tool_call_id - message_id, is_last_tool_call = await self.get_message_id_for_tool_call( + """Map a ToolMessage to conversation events. + + Emits both startToolCall and endToolCall events together, as the tool + has now actually executed (unlike when the AI message was generated). + """ + # Look up the AI message ID and tool call data using the tool_call_id + message_id, tool_call, is_last_tool_call = await self.get_message_id_for_tool_call( message.tool_call_id ) - if message_id is None: + if message_id is None or tool_call is None: logger.warning( - f"Tool message {message.tool_call_id} has no associated AI message ID. Skipping." + f"Tool message {message.tool_call_id} has no associated AI message ID or tool call data. Skipping." ) return [] @@ -295,7 +300,12 @@ async def map_tool_message_to_events( # Keep as string if not valid JSON pass + # Emit BOTH startToolCall and endToolCall together + # This represents the tool actually executing (not just the LLM's intention) events = [ + # First: tool call starts + self.map_tool_call_to_tool_call_start_event(message_id, tool_call), + # Then: tool call ends with result UiPathConversationMessageEvent( message_id=message_id, tool_call=UiPathConversationToolCallEvent( @@ -308,6 +318,7 @@ async def map_tool_message_to_events( ) ] + # End the AI message after all tool calls complete if is_last_tool_call: events.append(self.map_to_message_end_event(message_id)) @@ -315,16 +326,21 @@ async def map_tool_message_to_events( async def get_message_id_for_tool_call( self, tool_call_id: str - ) -> tuple[str | None, bool]: + ) -> tuple[str | None, ToolCall | None, bool]: + """Get message ID and tool call data for a given tool call ID. + + Returns: + Tuple of (message_id, tool_call, is_last_tool_call) + """ if self.storage is None: logger.error( f"attempt to lookup tool call id {tool_call_id} when no storage provided" ) - return None, False + return None, None, False async with self._storage_lock: tool_call_id_to_message_id_map: dict[ - str, str + str, dict[str, Any] ] = await self.storage.get_value( self.runtime_id, STORAGE_NAMESPACE_EVENT_MAPPER, @@ -335,14 +351,17 @@ async def get_message_id_for_tool_call( logger.error( f"attempt to lookup tool call id {tool_call_id} when no map present in storage" ) - return None, False + return None, None, False - message_id = tool_call_id_to_message_id_map.get(tool_call_id) - if message_id is None: + tool_call_info = tool_call_id_to_message_id_map.get(tool_call_id) + if tool_call_info is None: logger.error( f"tool call to message map does not contain tool call id {tool_call_id}" ) - return None, False + return None, None, False + + message_id = tool_call_info["message_id"] + tool_call = tool_call_info["tool_call"] del tool_call_id_to_message_id_map[tool_call_id] @@ -353,9 +372,13 @@ async def get_message_id_for_tool_call( tool_call_id_to_message_id_map, ) - is_last = message_id not in tool_call_id_to_message_id_map.values() + # Check if this is the last tool call by seeing if message_id appears in remaining values + is_last = not any( + info["message_id"] == message_id + for info in tool_call_id_to_message_id_map.values() + ) - return message_id, is_last + return message_id, tool_call, is_last def map_tool_call_to_tool_call_start_event( self, message_id: str, tool_call: ToolCall diff --git a/src/uipath_langchain/runtime/runtime.py b/src/uipath_langchain/runtime/runtime.py index 5e9ac903..95eb41bc 100644 --- a/src/uipath_langchain/runtime/runtime.py +++ b/src/uipath_langchain/runtime/runtime.py @@ -5,12 +5,16 @@ from langchain_core.callbacks import BaseCallbackHandler from langchain_core.runnables.config import RunnableConfig +from langchain_core.tools import BaseTool from langgraph.errors import EmptyInputError, GraphRecursionError, InvalidUpdateError from langgraph.graph.state import CompiledStateGraph +from langgraph.prebuilt import ToolNode from langgraph.types import Command, Interrupt, StateSnapshot from uipath.runtime import ( UiPathBreakpointResult, UiPathExecuteOptions, + UiPathResumeTrigger, + UiPathResumeTriggerType, UiPathRuntimeResult, UiPathRuntimeStatus, UiPathRuntimeStorageProtocol, @@ -23,10 +27,20 @@ UiPathRuntimeStateEvent, ) from uipath.runtime.schema import UiPathRuntimeSchema +from uipath.core.chat import ( + InterruptTypeEnum, + UiPathConversationGenericInterruptStart, + UiPathConversationToolCallConfirmationInterruptStart, + UiPathConversationToolCallConfirmationValue, +) from uipath_langchain.runtime.errors import LangGraphErrorCode, LangGraphRuntimeError from uipath_langchain.runtime.messages import UiPathChatMessagesMapper -from uipath_langchain.runtime.schema import get_entrypoints_schema, get_graph_schema +from uipath_langchain.runtime.schema import ( + _unwrap_runnable_callable, + get_entrypoints_schema, + get_graph_schema, +) from ._serialize import serialize_output @@ -60,6 +74,7 @@ def __init__( self.callbacks: list[BaseCallbackHandler] = callbacks or [] self.chat = UiPathChatMessagesMapper(self.runtime_id, storage) self._middleware_node_names: set[str] = self._detect_middleware_nodes() + self._tools_by_name: dict[str, BaseTool] = self._build_tools_map() async def execute( self, @@ -229,9 +244,43 @@ async def _get_graph_input( if messages and isinstance(messages, list): graph_input["messages"] = self.chat.map_messages(messages) if options and options.resume: - return Command(resume=graph_input) + # Transform CAS generic format to LangGraph-specific format + resume_data = self._transform_interrupt_resume(graph_input) + return Command(resume=resume_data) return graph_input + def _transform_interrupt_resume(self, cas_input: dict[str, Any]) -> Any: + """Transform CAS resume input to LangGraph Command.resume format. + + The bridge includes interrupt metadata (interrupt_type, lg_interrupt_id) + alongside the widget response, allowing type-specific transformation: + - HITL tool call confirmation: transforms to middleware decisions format + - Generic interrupts: passes through the widget response directly + """ + # Bridge format: { interrupt_type, lg_interrupt_id, response: {approved, input?} } + if "response" not in cas_input: + return cas_input + + interrupt_type = cas_input.get("interrupt_type", "") + response = cas_input["response"] + + if not isinstance(response, dict): + return response + + # HITL tool call confirmation: transform to middleware decisions format + if interrupt_type == InterruptTypeEnum.TOOL_CALL_CONFIRMATION: + approved = response.get("approved", True) + modified_input = response.get("input") + if not approved: + return {"decisions": [{"type": "reject", "message": "User rejected"}]} + decision: dict[str, Any] = {"type": "approve"} + if modified_input: + decision["modified_args"] = modified_input + return {"decisions": [decision]} + + # Generic interrupt: pass through widget response directly + return response + async def _get_graph_state( self, graph_config: RunnableConfig, @@ -323,6 +372,75 @@ async def _create_runtime_result( # Normal completion return self._create_success_result(graph_output) + def _format_interrupt_value(self, value: Any) -> dict[str, Any]: + """Format interrupt value for CAS consumption.""" + # HITL middleware format: { action_requests: [...] } + if isinstance(value, dict) and "action_requests" in value: + actions = value.get("action_requests", []) + if actions: + action = actions[0] # First action for now + tool_name = action.get("name", "") + input_schema = self._get_tool_input_schema(tool_name) + return UiPathConversationToolCallConfirmationInterruptStart( + type=InterruptTypeEnum.TOOL_CALL_CONFIRMATION, + value=UiPathConversationToolCallConfirmationValue( + tool_call_id=action.get("tool_call_id", str(uuid4())), + tool_name=tool_name, + input_schema=input_schema, + input_value=action.get("args"), + ), + ).model_dump(by_alias=True) + # Already CAS-compatible or generic - pass through + return UiPathConversationGenericInterruptStart( + type="generic", + value=value, + ).model_dump(by_alias=True) + + def _get_tool_input_schema(self, tool_name: str) -> dict[str, Any]: + """Get a tool's input JSON schema by name from the graph's tools.""" + tool = self._tools_by_name.get(tool_name) + if not tool: + return {} + try: + schema = tool.args_schema.model_json_schema() + return self._resolve_schema_refs(schema) + except Exception: + return {} + + @staticmethod + def _resolve_schema_refs(schema: dict[str, Any]) -> dict[str, Any]: + """Inline $defs/$ref references so the frontend gets a flat schema.""" + defs = schema.pop("$defs", {}) + if not defs: + return schema + + def resolve(obj: Any) -> Any: + if isinstance(obj, dict): + if "$ref" in obj: + ref_name = obj["$ref"].rsplit("/", 1)[-1] + return resolve(defs.get(ref_name, obj)) + return {k: resolve(v) for k, v in obj.items()} + if isinstance(obj, list): + return [resolve(item) for item in obj] + return obj + + return resolve(schema) + + def _build_tools_map(self) -> dict[str, BaseTool]: + """Build a map of tool name -> BaseTool from the graph's ToolNode instances.""" + tools_map: dict[str, BaseTool] = {} + try: + graph = self.graph.get_graph(xray=0) + for _, node in graph.nodes.items(): + if node.data is None: + continue + tool_node = _unwrap_runnable_callable(node.data, ToolNode) + if tool_node and hasattr(tool_node, "_tools_by_name"): + tools_map.update(tool_node._tools_by_name) + except Exception: + pass + return tools_map + async def _create_suspended_result( self, graph_state: StateSnapshot, @@ -338,7 +456,9 @@ async def _create_suspended_result( if task.interrupts and interrupt in task.interrupts: # Only include if this task is still waiting for interrupt resolution if task.interrupts and not task.result: - interrupt_map[interrupt.id] = interrupt.value + interrupt_map[interrupt.id] = self._format_interrupt_value( + interrupt.value + ) break # If we have dynamic interrupts, return suspended with interrupt map @@ -347,6 +467,7 @@ async def _create_suspended_result( return UiPathRuntimeResult( output=interrupt_map, status=UiPathRuntimeStatus.SUSPENDED, + trigger=UiPathResumeTrigger(trigger_type=UiPathResumeTriggerType.API), ) else: # Static interrupt (breakpoint) From 0f5a977dbff27a18b9e9ffa122f48cde2b4acd8a Mon Sep 17 00:00:00 2001 From: Josh Park <50765702+JoshParkSJ@users.noreply.github.com> Date: Thu, 5 Feb 2026 15:46:11 -0500 Subject: [PATCH 2/4] clean up --- src/uipath_langchain/runtime/messages.py | 35 ++---------------------- src/uipath_langchain/runtime/runtime.py | 17 ++---------- 2 files changed, 5 insertions(+), 47 deletions(-) diff --git a/src/uipath_langchain/runtime/messages.py b/src/uipath_langchain/runtime/messages.py index fa0ec11c..625e158c 100644 --- a/src/uipath_langchain/runtime/messages.py +++ b/src/uipath_langchain/runtime/messages.py @@ -216,27 +216,19 @@ async def map_ai_message_chunk_to_events( ) ) - # Check if this is the last chunk by examining chunk_position + # Check if this is the last chunk by examining chunk_position, send end message event only if there are no pending tool calls if message.chunk_position == "last": if ( self.current_message.tool_calls is not None and len(self.current_message.tool_calls) > 0 ): - # Store tool call mappings but DON'T emit startToolCall yet - # Tool calls will be emitted when they actually execute (in map_tool_message_to_events) - await self._store_tool_call_mappings() - # Don't emit endMessage yet - will be emitted after all tool calls complete + await self.map_current_message_to_start_tool_call_events() else: events.append(self.map_to_message_end_event(message.id)) return events - async def _store_tool_call_mappings(self): - """Store tool call ID to (message ID, tool call data) mappings without emitting events. - - This allows us to correlate ToolMessages with their originating AI message later, - and emit startToolCall events with full tool data when the tool actually executes. - """ + async def map_current_message_to_start_tool_call_events(self): if ( self.current_message and self.current_message.id is not None @@ -252,12 +244,9 @@ async def _store_tool_call_mappings(self): STORAGE_KEY_TOOL_CALL_ID_TO_MESSAGE_ID_MAP, ) - if tool_call_id_to_message_id_map is None: - tool_call_id_to_message_id_map = {} else: tool_call_id_to_message_id_map = {} - # Store full tool call data for each tool call for tool_call in self.current_message.tool_calls: tool_call_id = tool_call["id"] if tool_call_id is not None: @@ -277,12 +266,6 @@ async def _store_tool_call_mappings(self): async def map_tool_message_to_events( self, message: ToolMessage ) -> list[UiPathConversationMessageEvent]: - """Map a ToolMessage to conversation events. - - Emits both startToolCall and endToolCall events together, as the tool - has now actually executed (unlike when the AI message was generated). - """ - # Look up the AI message ID and tool call data using the tool_call_id message_id, tool_call, is_last_tool_call = await self.get_message_id_for_tool_call( message.tool_call_id ) @@ -297,15 +280,10 @@ async def map_tool_message_to_events( try: content_value = json.loads(content_value) except (json.JSONDecodeError, TypeError): - # Keep as string if not valid JSON pass - # Emit BOTH startToolCall and endToolCall together - # This represents the tool actually executing (not just the LLM's intention) events = [ - # First: tool call starts self.map_tool_call_to_tool_call_start_event(message_id, tool_call), - # Then: tool call ends with result UiPathConversationMessageEvent( message_id=message_id, tool_call=UiPathConversationToolCallEvent( @@ -318,7 +296,6 @@ async def map_tool_message_to_events( ) ] - # End the AI message after all tool calls complete if is_last_tool_call: events.append(self.map_to_message_end_event(message_id)) @@ -327,11 +304,6 @@ async def map_tool_message_to_events( async def get_message_id_for_tool_call( self, tool_call_id: str ) -> tuple[str | None, ToolCall | None, bool]: - """Get message ID and tool call data for a given tool call ID. - - Returns: - Tuple of (message_id, tool_call, is_last_tool_call) - """ if self.storage is None: logger.error( f"attempt to lookup tool call id {tool_call_id} when no storage provided" @@ -372,7 +344,6 @@ async def get_message_id_for_tool_call( tool_call_id_to_message_id_map, ) - # Check if this is the last tool call by seeing if message_id appears in remaining values is_last = not any( info["message_id"] == message_id for info in tool_call_id_to_message_id_map.values() diff --git a/src/uipath_langchain/runtime/runtime.py b/src/uipath_langchain/runtime/runtime.py index 95eb41bc..dbab9032 100644 --- a/src/uipath_langchain/runtime/runtime.py +++ b/src/uipath_langchain/runtime/runtime.py @@ -244,20 +244,12 @@ async def _get_graph_input( if messages and isinstance(messages, list): graph_input["messages"] = self.chat.map_messages(messages) if options and options.resume: - # Transform CAS generic format to LangGraph-specific format resume_data = self._transform_interrupt_resume(graph_input) return Command(resume=resume_data) return graph_input def _transform_interrupt_resume(self, cas_input: dict[str, Any]) -> Any: - """Transform CAS resume input to LangGraph Command.resume format. - - The bridge includes interrupt metadata (interrupt_type, lg_interrupt_id) - alongside the widget response, allowing type-specific transformation: - - HITL tool call confirmation: transforms to middleware decisions format - - Generic interrupts: passes through the widget response directly - """ - # Bridge format: { interrupt_type, lg_interrupt_id, response: {approved, input?} } + """Transform CAS resume input to LangGraph Command.resume format.""" if "response" not in cas_input: return cas_input @@ -267,7 +259,6 @@ def _transform_interrupt_resume(self, cas_input: dict[str, Any]) -> Any: if not isinstance(response, dict): return response - # HITL tool call confirmation: transform to middleware decisions format if interrupt_type == InterruptTypeEnum.TOOL_CALL_CONFIRMATION: approved = response.get("approved", True) modified_input = response.get("input") @@ -278,7 +269,6 @@ def _transform_interrupt_resume(self, cas_input: dict[str, Any]) -> Any: decision["modified_args"] = modified_input return {"decisions": [decision]} - # Generic interrupt: pass through widget response directly return response async def _get_graph_state( @@ -374,7 +364,6 @@ async def _create_runtime_result( def _format_interrupt_value(self, value: Any) -> dict[str, Any]: """Format interrupt value for CAS consumption.""" - # HITL middleware format: { action_requests: [...] } if isinstance(value, dict) and "action_requests" in value: actions = value.get("action_requests", []) if actions: @@ -390,14 +379,12 @@ def _format_interrupt_value(self, value: Any) -> dict[str, Any]: input_value=action.get("args"), ), ).model_dump(by_alias=True) - # Already CAS-compatible or generic - pass through return UiPathConversationGenericInterruptStart( type="generic", value=value, ).model_dump(by_alias=True) def _get_tool_input_schema(self, tool_name: str) -> dict[str, Any]: - """Get a tool's input JSON schema by name from the graph's tools.""" tool = self._tools_by_name.get(tool_name) if not tool: return {} @@ -431,7 +418,7 @@ def _build_tools_map(self) -> dict[str, BaseTool]: tools_map: dict[str, BaseTool] = {} try: graph = self.graph.get_graph(xray=0) - for _, node in graph.nodes.items(): + for node in graph.nodes.values(): if node.data is None: continue tool_node = _unwrap_runnable_callable(node.data, ToolNode) From 4ebe0ceb637fc87d146f41cb420b56e1ee1f3829 Mon Sep 17 00:00:00 2001 From: Josh Park <50765702+JoshParkSJ@users.noreply.github.com> Date: Thu, 5 Feb 2026 16:14:11 -0500 Subject: [PATCH 3/4] add better types --- src/uipath_langchain/runtime/messages.py | 30 ++++++++++++++---------- src/uipath_langchain/runtime/runtime.py | 30 +++++++++++++++--------- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/src/uipath_langchain/runtime/messages.py b/src/uipath_langchain/runtime/messages.py index 625e158c..9c546c76 100644 --- a/src/uipath_langchain/runtime/messages.py +++ b/src/uipath_langchain/runtime/messages.py @@ -236,7 +236,7 @@ async def map_current_message_to_start_tool_call_events(self): ): async with self._storage_lock: if self.storage is not None: - tool_call_id_to_message_id_map: dict[ + tool_call_map: dict[ str, dict[str, Any] ] = await self.storage.get_value( self.runtime_id, @@ -244,13 +244,15 @@ async def map_current_message_to_start_tool_call_events(self): STORAGE_KEY_TOOL_CALL_ID_TO_MESSAGE_ID_MAP, ) + if tool_call_map is None: + tool_call_map = {} else: - tool_call_id_to_message_id_map = {} + tool_call_map = {} for tool_call in self.current_message.tool_calls: tool_call_id = tool_call["id"] if tool_call_id is not None: - tool_call_id_to_message_id_map[tool_call_id] = { + tool_call_map[tool_call_id] = { "message_id": self.current_message.id, "tool_call": tool_call, } @@ -260,18 +262,19 @@ async def map_current_message_to_start_tool_call_events(self): self.runtime_id, STORAGE_NAMESPACE_EVENT_MAPPER, STORAGE_KEY_TOOL_CALL_ID_TO_MESSAGE_ID_MAP, - tool_call_id_to_message_id_map, + tool_call_map, ) async def map_tool_message_to_events( self, message: ToolMessage ) -> list[UiPathConversationMessageEvent]: + # Look up the AI message ID using the tool_call_id message_id, tool_call, is_last_tool_call = await self.get_message_id_for_tool_call( message.tool_call_id ) - if message_id is None or tool_call is None: + if message_id is None: logger.warning( - f"Tool message {message.tool_call_id} has no associated AI message ID or tool call data. Skipping." + f"Tool message {message.tool_call_id} has no associated AI message ID. Skipping." ) return [] @@ -280,6 +283,7 @@ async def map_tool_message_to_events( try: content_value = json.loads(content_value) except (json.JSONDecodeError, TypeError): + # Keep as string if not valid JSON pass events = [ @@ -311,21 +315,21 @@ async def get_message_id_for_tool_call( return None, None, False async with self._storage_lock: - tool_call_id_to_message_id_map: dict[ - str, dict[str, Any] + tool_call_map: dict[ + str, dict[str, Any] # tool_call_id -> {message_id, tool_call} ] = await self.storage.get_value( self.runtime_id, STORAGE_NAMESPACE_EVENT_MAPPER, STORAGE_KEY_TOOL_CALL_ID_TO_MESSAGE_ID_MAP, ) - if tool_call_id_to_message_id_map is None: + if tool_call_map is None: logger.error( f"attempt to lookup tool call id {tool_call_id} when no map present in storage" ) return None, None, False - tool_call_info = tool_call_id_to_message_id_map.get(tool_call_id) + tool_call_info = tool_call_map.get(tool_call_id) if tool_call_info is None: logger.error( f"tool call to message map does not contain tool call id {tool_call_id}" @@ -335,18 +339,18 @@ async def get_message_id_for_tool_call( message_id = tool_call_info["message_id"] tool_call = tool_call_info["tool_call"] - del tool_call_id_to_message_id_map[tool_call_id] + del tool_call_map[tool_call_id] await self.storage.set_value( self.runtime_id, STORAGE_NAMESPACE_EVENT_MAPPER, STORAGE_KEY_TOOL_CALL_ID_TO_MESSAGE_ID_MAP, - tool_call_id_to_message_id_map, + tool_call_map, ) is_last = not any( info["message_id"] == message_id - for info in tool_call_id_to_message_id_map.values() + for info in tool_call_map.values() ) return message_id, tool_call, is_last diff --git a/src/uipath_langchain/runtime/runtime.py b/src/uipath_langchain/runtime/runtime.py index dbab9032..c44b8799 100644 --- a/src/uipath_langchain/runtime/runtime.py +++ b/src/uipath_langchain/runtime/runtime.py @@ -3,6 +3,14 @@ from typing import Any, AsyncGenerator from uuid import uuid4 +from langchain.agents.middleware.human_in_the_loop import ( + ActionRequest, + ApproveDecision, + EditDecision, + HITLRequest, + HITLResponse, + RejectDecision, +) from langchain_core.callbacks import BaseCallbackHandler from langchain_core.runnables.config import RunnableConfig from langchain_core.tools import BaseTool @@ -261,13 +269,12 @@ def _transform_interrupt_resume(self, cas_input: dict[str, Any]) -> Any: if interrupt_type == InterruptTypeEnum.TOOL_CALL_CONFIRMATION: approved = response.get("approved", True) - modified_input = response.get("input") if not approved: - return {"decisions": [{"type": "reject", "message": "User rejected"}]} - decision: dict[str, Any] = {"type": "approve"} - if modified_input: - decision["modified_args"] = modified_input - return {"decisions": [decision]} + return HITLResponse(decisions=[RejectDecision(type="reject", message="User rejected")]) + edited_args = response.get("input") + if edited_args: + return HITLResponse(decisions=[EditDecision(type="edit", edited_action=edited_args)]) + return HITLResponse(decisions=[ApproveDecision(type="approve")]) return response @@ -365,18 +372,19 @@ async def _create_runtime_result( def _format_interrupt_value(self, value: Any) -> dict[str, Any]: """Format interrupt value for CAS consumption.""" if isinstance(value, dict) and "action_requests" in value: - actions = value.get("action_requests", []) + request: HITLRequest = value + actions = request["action_requests"] if actions: - action = actions[0] # First action for now - tool_name = action.get("name", "") + action: ActionRequest = actions[0] + tool_name = action["name"] input_schema = self._get_tool_input_schema(tool_name) return UiPathConversationToolCallConfirmationInterruptStart( type=InterruptTypeEnum.TOOL_CALL_CONFIRMATION, value=UiPathConversationToolCallConfirmationValue( - tool_call_id=action.get("tool_call_id", str(uuid4())), + tool_call_id=str(uuid4()), tool_name=tool_name, input_schema=input_schema, - input_value=action.get("args"), + input_value=action["args"], ), ).model_dump(by_alias=True) return UiPathConversationGenericInterruptStart( From 16a5ae9be6a496c751d8beda47b6eddf3affa666 Mon Sep 17 00:00:00 2001 From: Josh Park <50765702+JoshParkSJ@users.noreply.github.com> Date: Thu, 5 Feb 2026 16:26:49 -0500 Subject: [PATCH 4/4] handle edited arguments path --- src/uipath_langchain/runtime/runtime.py | 30 ++++++------------------- 1 file changed, 7 insertions(+), 23 deletions(-) diff --git a/src/uipath_langchain/runtime/runtime.py b/src/uipath_langchain/runtime/runtime.py index c44b8799..ffdfa3cd 100644 --- a/src/uipath_langchain/runtime/runtime.py +++ b/src/uipath_langchain/runtime/runtime.py @@ -4,6 +4,7 @@ from uuid import uuid4 from langchain.agents.middleware.human_in_the_loop import ( + Action, ActionRequest, ApproveDecision, EditDecision, @@ -83,6 +84,7 @@ def __init__( self.chat = UiPathChatMessagesMapper(self.runtime_id, storage) self._middleware_node_names: set[str] = self._detect_middleware_nodes() self._tools_by_name: dict[str, BaseTool] = self._build_tools_map() + self._last_hitl_tool_name: str | None = None async def execute( self, @@ -272,8 +274,9 @@ def _transform_interrupt_resume(self, cas_input: dict[str, Any]) -> Any: if not approved: return HITLResponse(decisions=[RejectDecision(type="reject", message="User rejected")]) edited_args = response.get("input") - if edited_args: - return HITLResponse(decisions=[EditDecision(type="edit", edited_action=edited_args)]) + if edited_args and self._last_hitl_tool_name: + edited_action = Action(name=self._last_hitl_tool_name, args=edited_args) + return HITLResponse(decisions=[EditDecision(type="edit", edited_action=edited_action)]) return HITLResponse(decisions=[ApproveDecision(type="approve")]) return response @@ -377,6 +380,7 @@ def _format_interrupt_value(self, value: Any) -> dict[str, Any]: if actions: action: ActionRequest = actions[0] tool_name = action["name"] + self._last_hitl_tool_name = tool_name input_schema = self._get_tool_input_schema(tool_name) return UiPathConversationToolCallConfirmationInterruptStart( type=InterruptTypeEnum.TOOL_CALL_CONFIRMATION, @@ -397,30 +401,10 @@ def _get_tool_input_schema(self, tool_name: str) -> dict[str, Any]: if not tool: return {} try: - schema = tool.args_schema.model_json_schema() - return self._resolve_schema_refs(schema) + return tool.args_schema.model_json_schema() except Exception: return {} - @staticmethod - def _resolve_schema_refs(schema: dict[str, Any]) -> dict[str, Any]: - """Inline $defs/$ref references so the frontend gets a flat schema.""" - defs = schema.pop("$defs", {}) - if not defs: - return schema - - def resolve(obj: Any) -> Any: - if isinstance(obj, dict): - if "$ref" in obj: - ref_name = obj["$ref"].rsplit("/", 1)[-1] - return resolve(defs.get(ref_name, obj)) - return {k: resolve(v) for k, v in obj.items()} - if isinstance(obj, list): - return [resolve(item) for item in obj] - return obj - - return resolve(schema) - def _build_tools_map(self) -> dict[str, BaseTool]: """Build a map of tool name -> BaseTool from the graph's ToolNode instances.""" tools_map: dict[str, BaseTool] = {}