66 changes: 32 additions & 34 deletions src/uipath_langchain/runtime/messages.py
@@ -222,63 +222,54 @@ async def map_ai_message_chunk_to_events(
self.current_message.tool_calls is not None
and len(self.current_message.tool_calls) > 0
):
events.extend(
await self.map_current_message_to_start_tool_call_events()
)
await self.map_current_message_to_start_tool_call_events()

PR author comment: Defer startToolCall and endToolCall to the next message so we can guarantee the message ends with the start/end tool calls.

else:
events.append(self.map_to_message_end_event(message.id))

return events

async def map_current_message_to_start_tool_call_events(self):
events: list[UiPathConversationMessageEvent] = []
if (
self.current_message
and self.current_message.id is not None
and self.current_message.tool_calls
):
async with self._storage_lock:
if self.storage is not None:
tool_call_id_to_message_id_map: dict[
str, str
tool_call_map: dict[
str, dict[str, Any]
] = await self.storage.get_value(
self.runtime_id,
STORAGE_NAMESPACE_EVENT_MAPPER,
STORAGE_KEY_TOOL_CALL_ID_TO_MESSAGE_ID_MAP,
)

if tool_call_id_to_message_id_map is None:
tool_call_id_to_message_id_map = {}
if tool_call_map is None:
tool_call_map = {}
else:
tool_call_id_to_message_id_map = {}
tool_call_map = {}

for tool_call in self.current_message.tool_calls:
tool_call_id = tool_call["id"]
if tool_call_id is not None:
tool_call_id_to_message_id_map[tool_call_id] = (
self.current_message.id
)
events.append(
self.map_tool_call_to_tool_call_start_event(
self.current_message.id, tool_call
)
)
tool_call_map[tool_call_id] = {
"message_id": self.current_message.id,
"tool_call": tool_call,
}

if self.storage is not None:
await self.storage.set_value(
self.runtime_id,
STORAGE_NAMESPACE_EVENT_MAPPER,
STORAGE_KEY_TOOL_CALL_ID_TO_MESSAGE_ID_MAP,
tool_call_id_to_message_id_map,
tool_call_map,
)

return events

async def map_tool_message_to_events(
self, message: ToolMessage
) -> list[UiPathConversationMessageEvent]:
# Look up the AI message ID using the tool_call_id
message_id, is_last_tool_call = await self.get_message_id_for_tool_call(
message_id, tool_call, is_last_tool_call = await self.get_message_id_for_tool_call(
message.tool_call_id
)
if message_id is None:
@@ -296,6 +287,7 @@ async def map_tool_message_to_events(
pass

events = [
self.map_tool_call_to_tool_call_start_event(message_id, tool_call),
UiPathConversationMessageEvent(
message_id=message_id,
tool_call=UiPathConversationToolCallEvent(
@@ -315,47 +307,53 @@

async def get_message_id_for_tool_call(
self, tool_call_id: str
) -> tuple[str | None, bool]:
) -> tuple[str | None, ToolCall | None, bool]:
if self.storage is None:
logger.error(
f"attempt to lookup tool call id {tool_call_id} when no storage provided"
)
return None, False
return None, None, False

async with self._storage_lock:
tool_call_id_to_message_id_map: dict[
str, str
tool_call_map: dict[
str, dict[str, Any] # tool_call_id -> {message_id, tool_call}
] = await self.storage.get_value(
self.runtime_id,
STORAGE_NAMESPACE_EVENT_MAPPER,
STORAGE_KEY_TOOL_CALL_ID_TO_MESSAGE_ID_MAP,
)

if tool_call_id_to_message_id_map is None:
if tool_call_map is None:
logger.error(
f"attempt to lookup tool call id {tool_call_id} when no map present in storage"
)
return None, False
return None, None, False

message_id = tool_call_id_to_message_id_map.get(tool_call_id)
if message_id is None:
tool_call_info = tool_call_map.get(tool_call_id)
if tool_call_info is None:
logger.error(
f"tool call to message map does not contain tool call id {tool_call_id}"
)
return None, False
return None, None, False

del tool_call_id_to_message_id_map[tool_call_id]
message_id = tool_call_info["message_id"]
tool_call = tool_call_info["tool_call"]

del tool_call_map[tool_call_id]

await self.storage.set_value(
self.runtime_id,
STORAGE_NAMESPACE_EVENT_MAPPER,
STORAGE_KEY_TOOL_CALL_ID_TO_MESSAGE_ID_MAP,
tool_call_id_to_message_id_map,
tool_call_map,
)

is_last = message_id not in tool_call_id_to_message_id_map.values()
is_last = not any(
info["message_id"] == message_id
for info in tool_call_map.values()
)

return message_id, is_last
return message_id, tool_call, is_last

def map_tool_call_to_tool_call_start_event(
self, message_id: str, tool_call: ToolCall
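A minimal sketch of the new storage shape introduced above, using a plain dict in place of the UiPathRuntimeStorageProtocol backend; the entry layout ({tool_call_id: {"message_id": ..., "tool_call": ...}}) mirrors the diff, while the helper names and the in-memory framing are illustrative only:

from typing import Any

# Stand-in for the persisted map keyed by STORAGE_KEY_TOOL_CALL_ID_TO_MESSAGE_ID_MAP.
tool_call_map: dict[str, dict[str, Any]] = {}

def record_tool_calls(message_id: str, tool_calls: list[dict[str, Any]]) -> None:
    # When the AI message finishes streaming, only remember the calls;
    # start/end tool-call events are deferred until the ToolMessage arrives.
    for tool_call in tool_calls:
        tool_call_map[tool_call["id"]] = {
            "message_id": message_id,
            "tool_call": tool_call,
        }

def pop_tool_call(tool_call_id: str) -> tuple[str, dict[str, Any], bool]:
    # When the ToolMessage arrives, recover the originating message id and the
    # stored tool call, then check whether sibling calls are still pending.
    info = tool_call_map.pop(tool_call_id)
    is_last = not any(
        entry["message_id"] == info["message_id"]
        for entry in tool_call_map.values()
    )
    return info["message_id"], info["tool_call"], is_last
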
106 changes: 103 additions & 3 deletions src/uipath_langchain/runtime/runtime.py
@@ -3,14 +3,27 @@
from typing import Any, AsyncGenerator
from uuid import uuid4

from langchain.agents.middleware.human_in_the_loop import (
Action,
ActionRequest,
ApproveDecision,
EditDecision,
HITLRequest,
HITLResponse,
RejectDecision,
)
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.runnables.config import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph.errors import EmptyInputError, GraphRecursionError, InvalidUpdateError
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode
from langgraph.types import Command, Interrupt, StateSnapshot
from uipath.runtime import (
UiPathBreakpointResult,
UiPathExecuteOptions,
UiPathResumeTrigger,
UiPathResumeTriggerType,
UiPathRuntimeResult,
UiPathRuntimeStatus,
UiPathRuntimeStorageProtocol,
@@ -23,10 +36,20 @@
UiPathRuntimeStateEvent,
)
from uipath.runtime.schema import UiPathRuntimeSchema
from uipath.core.chat import (
InterruptTypeEnum,
UiPathConversationGenericInterruptStart,
UiPathConversationToolCallConfirmationInterruptStart,
UiPathConversationToolCallConfirmationValue,
)

from uipath_langchain.runtime.errors import LangGraphErrorCode, LangGraphRuntimeError
from uipath_langchain.runtime.messages import UiPathChatMessagesMapper
from uipath_langchain.runtime.schema import get_entrypoints_schema, get_graph_schema
from uipath_langchain.runtime.schema import (
_unwrap_runnable_callable,
get_entrypoints_schema,
get_graph_schema,
)

from ._serialize import serialize_output

@@ -60,6 +83,8 @@ def __init__(
self.callbacks: list[BaseCallbackHandler] = callbacks or []
self.chat = UiPathChatMessagesMapper(self.runtime_id, storage)
self._middleware_node_names: set[str] = self._detect_middleware_nodes()
self._tools_by_name: dict[str, BaseTool] = self._build_tools_map()
self._last_hitl_tool_name: str | None = None

async def execute(
self,
@@ -229,9 +254,33 @@ async def _get_graph_input(
if messages and isinstance(messages, list):
graph_input["messages"] = self.chat.map_messages(messages)
if options and options.resume:
return Command(resume=graph_input)
resume_data = self._transform_interrupt_resume(graph_input)
return Command(resume=resume_data)
return graph_input

def _transform_interrupt_resume(self, cas_input: dict[str, Any]) -> Any:
"""Transform CAS resume input to LangGraph Command.resume format."""
if "response" not in cas_input:
return cas_input

interrupt_type = cas_input.get("interrupt_type", "")
response = cas_input["response"]

if not isinstance(response, dict):
return response

if interrupt_type == InterruptTypeEnum.TOOL_CALL_CONFIRMATION:
approved = response.get("approved", True)
if not approved:
return HITLResponse(decisions=[RejectDecision(type="reject", message="User rejected")])
edited_args = response.get("input")
if edited_args and self._last_hitl_tool_name:
edited_action = Action(name=self._last_hitl_tool_name, args=edited_args)
return HITLResponse(decisions=[EditDecision(type="edit", edited_action=edited_action)])
return HITLResponse(decisions=[ApproveDecision(type="approve")])

return response
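
For illustration, a hypothetical tool-call confirmation resume payload and the decisions it maps to under _transform_interrupt_resume; the field names (interrupt_type, response, approved, input) come from the code above, while the tool name and argument values are invented:

from uipath.core.chat import InterruptTypeEnum

cas_input = {
    "interrupt_type": InterruptTypeEnum.TOOL_CALL_CONFIRMATION,
    "response": {"approved": True, "input": {"city": "Berlin"}},
}
# Assuming _last_hitl_tool_name == "get_weather", this becomes roughly:
#   HITLResponse(decisions=[EditDecision(
#       type="edit",
#       edited_action=Action(name="get_weather", args={"city": "Berlin"}))])
# With "approved": False the result is a single RejectDecision, and with
# "approved": True and no "input" it is a single ApproveDecision.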

async def _get_graph_state(
self,
graph_config: RunnableConfig,
@@ -323,6 +372,54 @@ async def _create_runtime_result(
# Normal completion
return self._create_success_result(graph_output)

def _format_interrupt_value(self, value: Any) -> dict[str, Any]:
"""Format interrupt value for CAS consumption."""
if isinstance(value, dict) and "action_requests" in value:
request: HITLRequest = value
actions = request["action_requests"]
if actions:
action: ActionRequest = actions[0]
tool_name = action["name"]
self._last_hitl_tool_name = tool_name
input_schema = self._get_tool_input_schema(tool_name)
return UiPathConversationToolCallConfirmationInterruptStart(
type=InterruptTypeEnum.TOOL_CALL_CONFIRMATION,
value=UiPathConversationToolCallConfirmationValue(
tool_call_id=str(uuid4()),
tool_name=tool_name,
input_schema=input_schema,
input_value=action["args"],
),
).model_dump(by_alias=True)
return UiPathConversationGenericInterruptStart(
type="generic",
value=value,
).model_dump(by_alias=True)

def _get_tool_input_schema(self, tool_name: str) -> dict[str, Any]:
tool = self._tools_by_name.get(tool_name)
if not tool:
return {}
try:
return tool.args_schema.model_json_schema()
except Exception:
return {}

def _build_tools_map(self) -> dict[str, BaseTool]:
"""Build a map of tool name -> BaseTool from the graph's ToolNode instances."""
tools_map: dict[str, BaseTool] = {}
try:
graph = self.graph.get_graph(xray=0)
for node in graph.nodes.values():
if node.data is None:
continue
tool_node = _unwrap_runnable_callable(node.data, ToolNode)
if tool_node and hasattr(tool_node, "_tools_by_name"):
tools_map.update(tool_node._tools_by_name)
except Exception:
pass
return tools_map

async def _create_suspended_result(
self,
graph_state: StateSnapshot,
@@ -338,7 +435,9 @@ async def _create_suspended_result(
if task.interrupts and interrupt in task.interrupts:
# Only include if this task is still waiting for interrupt resolution
if task.interrupts and not task.result:
interrupt_map[interrupt.id] = interrupt.value
interrupt_map[interrupt.id] = self._format_interrupt_value(
interrupt.value
)
break

# If we have dynamic interrupts, return suspended with interrupt map
@@ -347,6 +446,7 @@
return UiPathRuntimeResult(
output=interrupt_map,
status=UiPathRuntimeStatus.SUSPENDED,
trigger=UiPathResumeTrigger(trigger_type=UiPathResumeTriggerType.API),
)
else:
# Static interrupt (breakpoint)
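To make the suspend-side payload concrete, a rough sketch of what _format_interrupt_value would place in the interrupt map for a HITL interrupt, using the classes imported in this diff; the tool name, schema, and argument values are invented placeholders:

from uuid import uuid4

from uipath.core.chat import (
    InterruptTypeEnum,
    UiPathConversationToolCallConfirmationInterruptStart,
    UiPathConversationToolCallConfirmationValue,
)

interrupt_payload = UiPathConversationToolCallConfirmationInterruptStart(
    type=InterruptTypeEnum.TOOL_CALL_CONFIRMATION,
    value=UiPathConversationToolCallConfirmationValue(
        tool_call_id=str(uuid4()),
        tool_name="get_weather",  # hypothetical tool
        input_schema={"type": "object", "properties": {"city": {"type": "string"}}},
        input_value={"city": "Berlin"},
    ),
).model_dump(by_alias=True)

# _create_suspended_result then returns this map as the suspended output,
# now paired with an API resume trigger:
#   UiPathRuntimeResult(output={interrupt.id: interrupt_payload},
#                       status=UiPathRuntimeStatus.SUSPENDED,
#                       trigger=UiPathResumeTrigger(trigger_type=UiPathResumeTriggerType.API))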