From ec7291bc228cd0d2fd52926ebbbd1f375f3928fd Mon Sep 17 00:00:00 2001 From: shuofengzhang Date: Thu, 19 Mar 2026 16:43:15 +0000 Subject: [PATCH] Preserve CLI error_during_execution text for initialize failures --- src/claude_agent_sdk/_internal/query.py | 26 ++++++++++- tests/test_query.py | 62 +++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index b21405fc..5e460e0f 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -15,6 +15,7 @@ ListToolsRequest, ) +from .._errors import ProcessError from ..types import ( PermissionMode, PermissionResultAllow, @@ -116,6 +117,10 @@ def __init__( # Track first result for proper stream closure with SDK MCP servers self._first_result_event = anyio.Event() + # Preserve CLI execution error text (from result subtype=error_during_execution) + # so initialize/control callers receive actionable errors instead of generic + # process-exit placeholders. + self._last_execution_error: str | None = None async def initialize(self) -> dict[str, Any] | None: """Initialize control protocol if in streaming mode. @@ -231,6 +236,13 @@ async def _read_messages(self) -> None: # Track results for proper stream closure if msg_type == "result": self._first_result_event.set() + if ( + message.get("subtype") == "error_during_execution" + and message.get("is_error") is True + ): + result_text = message.get("result") + if isinstance(result_text, str) and result_text.strip(): + self._last_execution_error = result_text.strip() # Regular SDK messages go to the stream await self._message_send.send(message) @@ -241,13 +253,23 @@ async def _read_messages(self) -> None: raise # Re-raise to properly handle cancellation except Exception as e: logger.error(f"Fatal error in message reader: {e}") + + # If the CLI emitted an explicit execution error result before exiting, + # prefer that actionable message for control waiters (e.g. initialize) + # over generic process-exit placeholders. + pending_error: Exception = e + if isinstance(e, ProcessError) and self._last_execution_error: + pending_error = Exception(self._last_execution_error) + # Signal all pending control requests so they fail fast instead of timing out for request_id, event in list(self.pending_control_responses.items()): if request_id not in self.pending_control_results: - self.pending_control_results[request_id] = e + self.pending_control_results[request_id] = pending_error event.set() # Put error in stream so iterators can handle it - await self._message_send.send({"type": "error", "error": str(e)}) + await self._message_send.send( + {"type": "error", "error": str(pending_error)} + ) finally: # Unblock any waiters (e.g. string-prompt path waiting for first # result) so they don't stall for the full timeout on early exit. diff --git a/tests/test_query.py b/tests/test_query.py index d0cece38..490c4020 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -16,6 +16,7 @@ from claude_agent_sdk import ( AssistantMessage, ClaudeAgentOptions, + ProcessError, ResultMessage, create_sdk_mcp_server, query, @@ -676,3 +677,64 @@ async def _test(): assert "fast_1" not in q._inflight_requests asyncio.run(_test()) + + +class TestInitializeErrorPropagation: + """Test initialize error propagation for process exit cases.""" + + def test_initialize_uses_error_during_execution_result_text(self): + """When CLI exits after error_during_execution, propagate real error text.""" + + async def _test(): + control_request_received = anyio.Event() + + class FailingInitializeTransport: + async def connect(self): + return None + + async def close(self): + return None + + async def end_input(self): + return None + + def is_ready(self) -> bool: + return True + + async def write(self, data: str): + payload = json.loads(data) + if payload.get("type") == "control_request": + control_request_received.set() + + async def read_messages(self): + await control_request_received.wait() + yield { + "type": "result", + "subtype": "error_during_execution", + "duration_ms": 1, + "duration_api_ms": 0, + "is_error": True, + "num_turns": 0, + "session_id": "session_123", + "result": "No conversation found with session ID ab2c985b", + } + raise ProcessError( + "Command failed with exit code 1", + exit_code=1, + stderr="Check stderr output for details", + ) + + transport = FailingInitializeTransport() + + caught: Exception | None = None + try: + async for _msg in query(prompt="Hello", transport=transport): + pass + except Exception as e: + caught = e + + assert caught is not None + assert "No conversation found with session ID ab2c985b" in str(caught) + assert "Check stderr output for details" not in str(caught) + + anyio.run(_test)