diff --git a/pyproject.toml b/pyproject.toml index 29814be..32820f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "sap-cloud-sdk" -version = "0.18.1" +version = "0.18.2" description = "SAP Cloud SDK for Python" readme = "README.md" license = "Apache-2.0" diff --git a/src/sap_cloud_sdk/agentgateway/__init__.py b/src/sap_cloud_sdk/agentgateway/__init__.py index e85fda5..555cae0 100644 --- a/src/sap_cloud_sdk/agentgateway/__init__.py +++ b/src/sap_cloud_sdk/agentgateway/__init__.py @@ -52,7 +52,7 @@ ] """ -from sap_cloud_sdk.agentgateway._models import MCPTool +from sap_cloud_sdk.agentgateway._models import AuthResult, MCPTool from sap_cloud_sdk.agentgateway.agw_client import create_client, AgentGatewayClient from sap_cloud_sdk.agentgateway.exceptions import ( AgentGatewaySDKError, @@ -66,6 +66,7 @@ # Client class "AgentGatewayClient", # Data models + "AuthResult", "MCPTool", # Exceptions "AgentGatewaySDKError", diff --git a/src/sap_cloud_sdk/agentgateway/_customer.py b/src/sap_cloud_sdk/agentgateway/_customer.py index a14ed7e..13c425c 100644 --- a/src/sap_cloud_sdk/agentgateway/_customer.py +++ b/src/sap_cloud_sdk/agentgateway/_customer.py @@ -8,7 +8,6 @@ - Tool invocation: mTLS + jwt-bearer grant → user-scoped token (principal propagation) """ -import asyncio import json import logging import os @@ -427,16 +426,16 @@ async def _list_server_tools( async def get_mcp_tools_customer( credentials: CustomerCredentials, - app_tid: str | None = None, + system_token: str, ) -> list[MCPTool]: """List all MCP tools from servers defined in credentials. Iterates over all integrationDependencies in the credentials file and - discovers tools from each MCP server using mTLS client credentials. + discovers tools from each MCP server using a pre-fetched system token. Args: credentials: Customer credentials with integrationDependencies. - app_tid: BTP Application Tenant ID of subscriber (optional). + system_token: Pre-fetched raw system token for authentication. Returns: List of MCPTool objects from all servers. @@ -453,12 +452,6 @@ async def get_mcp_tools_customer( logger.info("Discovering tools from %d MCP server(s)", len(dependencies)) - # Get system token for discovery - loop = asyncio.get_running_loop() - system_token = await loop.run_in_executor( - None, get_system_token_mtls, credentials, app_tid - ) - tools: list[MCPTool] = [] for dep in dependencies: @@ -484,23 +477,18 @@ async def get_mcp_tools_customer( async def call_mcp_tool_customer( - credentials: CustomerCredentials, tool: MCPTool, - user_token: str | None, - app_tid: str | None = None, + auth_token: str, **kwargs, ) -> str: """Invoke an MCP tool using customer flow. - If user_token is provided, exchanges it for an AGW-scoped token to preserve - user identity for principal propagation. Otherwise, falls back to system token. + Uses a pre-fetched token (either user-scoped or system-scoped) for + authentication against the MCP server. Args: - credentials: Customer credentials. tool: MCPTool to invoke. - user_token: User's JWT token for principal propagation (optional). - If None, system token is used instead (no principal propagation). - app_tid: BTP Application Tenant ID of subscriber (optional). + auth_token: Pre-fetched raw access token for authentication. **kwargs: Tool input parameters. Returns: @@ -508,28 +496,9 @@ async def call_mcp_tool_customer( """ logger.info("Calling tool '%s' on server '%s'", tool.name, tool.server_name) - loop = asyncio.get_running_loop() - - if user_token: - # Exchange user token for AGW-scoped token (with principal propagation) - agw_token = await loop.run_in_executor( - None, exchange_user_token, credentials, user_token, app_tid - ) - else: - # TODO: IBD workaround - use system token when user_token is not available. - # This bypasses principal propagation. Remove this fallback once IBD - # supports proper user token flow. - logger.warning( - "No user_token provided - using system token for tool invocation. " - "Principal propagation will NOT work." - ) - agw_token = await loop.run_in_executor( - None, get_system_token_mtls, credentials, app_tid - ) - async with httpx.AsyncClient( headers={ - "Authorization": f"Bearer {agw_token}", + "Authorization": f"Bearer {auth_token}", "x-correlation-id": str(uuid.uuid4()), }, timeout=_HTTP_TIMEOUT, diff --git a/src/sap_cloud_sdk/agentgateway/_lob.py b/src/sap_cloud_sdk/agentgateway/_lob.py index 5533525..ed4526c 100644 --- a/src/sap_cloud_sdk/agentgateway/_lob.py +++ b/src/sap_cloud_sdk/agentgateway/_lob.py @@ -1,11 +1,12 @@ """LoB agent flow - BTP Destination Service based. LoB agents use BTP Destination Service for credential management: -- Phase 1 (discovery): Client credentials from destination -- Phase 2 (execution): Token exchange with user_token for principal propagation +- Phase 1 (discovery): Client credentials from destination (subscriber.ias fragment) +- Phase 2 (execution): Token exchange with user_token (subscriber.ias.user fragment) """ import asyncio +import base64 import logging import os import uuid @@ -33,6 +34,7 @@ # Label values for fragment discovery _MCP_LABEL_VALUE = "agw.mcp.server" _IAS_LABEL_VALUE = "subscriber.ias" +_IAS_USER_LABEL_VALUE = "subscriber.ias.user" _DESTINATION_INSTANCE = "default" @@ -61,8 +63,12 @@ def _fetch_auth_token( dest_name: str, tenant_subdomain: str, options: ConsumptionOptions | None = None, -) -> str: - """Fetch auth token from destination service. +) -> tuple[str, str]: + """Fetch raw access token and gateway URL from destination service. + + Extracts the raw JWT by base64-decoding the token value field + from the destination service response, and the gateway URL from + the destination's URL property. Args: dest_name: Destination name. @@ -70,7 +76,7 @@ def _fetch_auth_token( options: Consumption options (fragment_name, user_token). Returns: - Authorization header value. + Tuple of (raw_access_token, gateway_url). Raises: MCPServerNotFoundError: If no auth token is returned. @@ -88,13 +94,14 @@ def _fetch_auth_token( f"No auth token returned for destination '{dest_name}'" ) - auth = dest.auth_tokens[0].http_header.get("value", "") - if not auth: - raise MCPServerNotFoundError( - f"Empty Authorization header for destination '{dest_name}'" - ) + token_value = dest.auth_tokens[0].value + if not token_value: + raise MCPServerNotFoundError(f"Empty token value for destination '{dest_name}'") + + token = base64.b64decode(token_value).decode("utf-8") + gateway_url = (dest.url or "").rstrip("/") - return auth + return token, gateway_url def list_mcp_fragments(tenant_subdomain: str) -> list: @@ -146,10 +153,40 @@ def get_ias_fragment_name(tenant_subdomain: str) -> str: return fragments[0].name -async def get_system_auth( +def get_ias_user_fragment_name(tenant_subdomain: str) -> str: + """Get the IAS user fragment name for token exchange (principal propagation). + + Looks up the IAS user fragment created during subscription by the + sap-managed-runtime-type=subscriber.ias.user label. + + Args: + tenant_subdomain: Tenant subdomain for multi-tenant lookup. + + Returns: + IAS user fragment name. + + Raises: + MCPServerNotFoundError: If no IAS user fragment is found. + """ + client = create_fragment_client(instance=_DESTINATION_INSTANCE) + fragments = client.list_instance_fragments( + filter=ListOptions( + filter_labels=[Label(key=_LABEL_KEY, values=[_IAS_USER_LABEL_VALUE])] + ), + tenant=tenant_subdomain, + ) + if not fragments: + raise MCPServerNotFoundError( + f"No IAS user fragment found (label {_LABEL_KEY}={_IAS_USER_LABEL_VALUE}) " + f"for tenant '{tenant_subdomain}'" + ) + return fragments[0].name + + +async def fetch_system_auth( tenant_subdomain: str, -) -> str: - """Get system-scoped auth (Phase 1 - client credentials). +) -> tuple[str, str]: + """Fetch system-scoped auth (Phase 1 - client credentials). Looks up the IAS fragment (subscriber.ias label) and uses it to acquire a client-credentials token via BTP Destination Service. @@ -158,7 +195,7 @@ async def get_system_auth( tenant_subdomain: Tenant subdomain for multi-tenant lookup. Returns: - Authorization header value (e.g., "Bearer xxx"). + Tuple of (raw_access_token, gateway_url). Raises: MCPServerNotFoundError: If no IAS fragment or auth token is found. @@ -185,39 +222,42 @@ def _fetch_system_auth_sync(): return await loop.run_in_executor(None, _fetch_system_auth_sync) -async def get_user_auth( - mcp_fragment_name: str, +async def fetch_user_auth( user_token: str, tenant_subdomain: str, -) -> str: - """Get user-scoped auth (Phase 2 - token exchange). +) -> tuple[str, str]: + """Fetch user-scoped auth (Phase 2 - token exchange). + + Looks up the IAS user fragment (subscriber.ias.user label) and uses it + together with the user_token to perform a token exchange via BTP + Destination Service. Args: - mcp_fragment_name: MCP fragment name for token exchange. user_token: User's JWT for principal propagation. tenant_subdomain: Tenant subdomain for multi-tenant lookup. Returns: - Authorization header value with user identity embedded. + Tuple of (raw_access_token, gateway_url). Raises: - MCPServerNotFoundError: If no auth token is returned. + MCPServerNotFoundError: If no IAS user fragment or auth token is found. """ loop = asyncio.get_running_loop() def _fetch_user_auth_sync(): + ias_user_fragment_name = get_ias_user_fragment_name(tenant_subdomain) dest_name = _ias_dest_name() logger.info( "Exchanging user auth — destination: '%s', fragment: '%s', tenant: '%s'", dest_name, - mcp_fragment_name, + ias_user_fragment_name, tenant_subdomain, ) options = ConsumptionOptions( user_token=user_token, - fragment_name=mcp_fragment_name, + fragment_name=ias_user_fragment_name, fragment_level=ConsumptionLevel.INSTANCE, ) @@ -227,20 +267,23 @@ def _fetch_user_auth_sync(): async def list_server_tools( - dest_url: str, system_auth: str, fragment_name: str + dest_url: str, auth_token: str, fragment_name: str ) -> list[MCPTool]: """List tools from a single MCP server. Args: dest_url: MCP endpoint URL. - system_auth: Authorization header for the request. + auth_token: Raw access token for the request. fragment_name: Fragment name for reference. Returns: List of MCPTool objects from this server. """ async with httpx.AsyncClient( - headers={"Authorization": system_auth, "x-correlation-id": str(uuid.uuid4())}, + headers={ + "Authorization": f"Bearer {auth_token}", + "x-correlation-id": str(uuid.uuid4()), + }, timeout=_HTTP_TIMEOUT, ) as http_client: async with streamable_http_client(dest_url, http_client=http_client) as ( @@ -273,13 +316,15 @@ async def list_server_tools( async def get_mcp_tools_lob( tenant_subdomain: str, + system_token: str, ) -> list[MCPTool]: """List all MCP tools using LoB flow (destination-based). - Uses Phase 1 auth (client-scoped) via BTP Destination Service. + Uses a pre-fetched system token for authentication against MCP servers. Args: tenant_subdomain: Tenant subdomain for multi-tenant lookup. + system_token: Pre-fetched raw system token (from get_system_auth). Returns: List of MCPTool objects from all MCP servers. @@ -308,8 +353,7 @@ async def get_mcp_tools_lob( continue try: - system_auth = await get_system_auth(tenant_subdomain) - server_tools = await list_server_tools(mcp_url, system_auth, fragment_name) + server_tools = await list_server_tools(mcp_url, system_token, fragment_name) tools.extend(server_tools) logger.debug( "Loaded %d tool(s) from fragment '%s'", @@ -328,35 +372,31 @@ async def get_mcp_tools_lob( async def call_mcp_tool_lob( tool: MCPTool, - user_token: str, - tenant_subdomain: str, + user_auth_token: str, **kwargs, ) -> str: """Invoke an MCP tool using LoB flow (destination-based). - Uses Phase 2 auth (user-scoped) via token exchange. - Principal propagation ensures LoB systems see user identity. + Uses a pre-fetched user token for principal propagation. Args: tool: MCPTool object (from list_mcp_tools). - user_token: User's JWT for principal propagation. - tenant_subdomain: Tenant subdomain for token exchange. + user_auth_token: Pre-fetched raw user token (from get_user_auth). **kwargs: Tool input parameters. Returns: Tool execution result as string. - - Raises: - MCPServerNotFoundError: If destination/auth fails. """ if not tool.fragment_name: raise MCPServerNotFoundError( f"Tool '{tool.name}' missing fragment_name for LoB invocation" ) - user_auth = await get_user_auth(tool.fragment_name, user_token, tenant_subdomain) async with httpx.AsyncClient( - headers={"Authorization": user_auth, "x-correlation-id": str(uuid.uuid4())}, + headers={ + "Authorization": f"Bearer {user_auth_token}", + "x-correlation-id": str(uuid.uuid4()), + }, timeout=_HTTP_TIMEOUT, ) as http_client: async with streamable_http_client(tool.url, http_client=http_client) as ( diff --git a/src/sap_cloud_sdk/agentgateway/_models.py b/src/sap_cloud_sdk/agentgateway/_models.py index 138d278..6d15f02 100644 --- a/src/sap_cloud_sdk/agentgateway/_models.py +++ b/src/sap_cloud_sdk/agentgateway/_models.py @@ -6,6 +6,32 @@ from typing import Any +@dataclass +class AuthResult: + """Authentication result from Agent Gateway. + + Contains the access token and the Agent Gateway URL. + + Attributes: + access_token: Raw JWT access token string. + gateway_url: Agent Gateway base URL (no trailing slash). + + Example: + ```python + from sap_cloud_sdk.agentgateway import create_client + + agw_client = create_client(tenant_subdomain="my-tenant") + + auth = await agw_client.get_system_auth() + print(auth.access_token) # raw JWT + print(auth.gateway_url) # "https://agw.example.com" + ``` + """ + + access_token: str + gateway_url: str + + @dataclass class MCPTool: """MCP tool discovered from Agent Gateway. diff --git a/src/sap_cloud_sdk/agentgateway/agw_client.py b/src/sap_cloud_sdk/agentgateway/agw_client.py index f75d614..39af551 100644 --- a/src/sap_cloud_sdk/agentgateway/agw_client.py +++ b/src/sap_cloud_sdk/agentgateway/agw_client.py @@ -7,17 +7,25 @@ - Customer agents: Use file-based credentials mounted on pod with mTLS auth """ +import asyncio import logging from typing import Callable -from sap_cloud_sdk.agentgateway._models import MCPTool from sap_cloud_sdk.agentgateway._customer import ( + call_mcp_tool_customer, detect_customer_agent_credentials, - load_customer_credentials, + exchange_user_token, get_mcp_tools_customer, - call_mcp_tool_customer, + get_system_token_mtls, + load_customer_credentials, ) -from sap_cloud_sdk.agentgateway._lob import get_mcp_tools_lob, call_mcp_tool_lob +from sap_cloud_sdk.agentgateway._lob import ( + call_mcp_tool_lob, + fetch_system_auth, + fetch_user_auth, + get_mcp_tools_lob, +) +from sap_cloud_sdk.agentgateway._models import AuthResult, MCPTool from sap_cloud_sdk.agentgateway.exceptions import AgentGatewaySDKError from sap_cloud_sdk.core.telemetry import Module, Operation, record_metrics @@ -67,6 +75,23 @@ class AgentGatewayClient: cost_center="1000", ) ``` + + Example (auth for external use): + ```python + from sap_cloud_sdk.agentgateway import create_client + + agw_client = create_client(tenant_subdomain="my-tenant") + + # Get system-scoped auth (token + gateway URL) + auth = await agw_client.get_system_auth() + print(auth.access_token) # raw JWT + print(auth.gateway_url) # "https://agw.example.com" + + # Get user-scoped auth (token exchange + gateway URL) + auth = await agw_client.get_user_auth(user_token="user-jwt") + print(auth.access_token) # exchanged JWT with user identity + print(auth.gateway_url) # "https://agw.example.com" + ``` """ def __init__( @@ -113,6 +138,134 @@ def _resolve_tenant_subdomain(self) -> str: "tenant_subdomain is required for LoB agent flow.", ) + @record_metrics(Module.AGENTGATEWAY, Operation.AGENTGATEWAY_GET_SYSTEM_AUTH) + async def get_system_auth(self, app_tid: str | None = None) -> AuthResult: + """Get system-scoped authentication (client_credentials flow). + + Automatically detects agent type (LoB vs Customer) based on + credential file presence. + + Args: + app_tid: BTP Application Tenant ID of the subscriber. + Only used for customer agents. This is passed to the token + service for tenant-scoped token requests. + + Returns: + AuthResult with raw access token (JWT) and Agent Gateway URL. + + Raises: + AgentGatewaySDKError: If tenant_subdomain is required but not + provided (LoB), or if token acquisition fails. + + Example: + ```python + auth = await agw_client.get_system_auth() + headers = {"Authorization": f"Bearer {auth.access_token}"} + # auth.gateway_url is the Agent Gateway base URL + ``` + """ + try: + credentials_path = detect_customer_agent_credentials() + if credentials_path: + logger.info( + "Customer agent credentials detected at '%s'", credentials_path + ) + credentials = load_customer_credentials(credentials_path) + loop = asyncio.get_running_loop() + token = await loop.run_in_executor( + None, get_system_token_mtls, credentials, app_tid + ) + return AuthResult( + access_token=token, + gateway_url=credentials.gateway_url, + ) + + # LoB flow + if app_tid: + logger.warning("app_tid parameter ignored for LoB agent flow") + + tenant = self._resolve_tenant_subdomain() + token, gateway_url = await fetch_system_auth(tenant) + return AuthResult(access_token=token, gateway_url=gateway_url) + + except AgentGatewaySDKError: + raise + except Exception as e: + logger.exception("Unexpected error during system auth acquisition") + raise AgentGatewaySDKError(f"System auth acquisition failed: {e}") from e + + @record_metrics(Module.AGENTGATEWAY, Operation.AGENTGATEWAY_GET_USER_AUTH) + async def get_user_auth( + self, + user_token: str | Callable[[], str] | None, + app_tid: str | None = None, + ) -> AuthResult: + """Exchange a user token for AGW-scoped authentication (token exchange). + + Automatically detects agent type (LoB vs Customer) based on + credential file presence. + + Args: + user_token: User's JWT for principal propagation. + Can be a string or a callable returning a string. + app_tid: BTP Application Tenant ID of the subscriber. + Only used for customer agents. This is passed to the token + service for tenant-scoped token exchange. + + Returns: + AuthResult with raw access token (JWT, user identity embedded) + and Agent Gateway URL. + + Raises: + AgentGatewaySDKError: If user_token is empty, or tenant_subdomain + is required but not provided (LoB), or if token exchange fails. + + Example: + ```python + auth = await agw_client.get_user_auth(user_token="user-jwt") + headers = {"Authorization": f"Bearer {auth.access_token}"} + # auth.gateway_url is the Agent Gateway base URL + ``` + """ + try: + resolved_user_token = self._resolve_value( + user_token, + "user_token is required for token exchange.", + ) + + credentials_path = detect_customer_agent_credentials() + if credentials_path: + logger.info( + "Customer agent credentials detected at '%s'", credentials_path + ) + credentials = load_customer_credentials(credentials_path) + loop = asyncio.get_running_loop() + token = await loop.run_in_executor( + None, + exchange_user_token, + credentials, + resolved_user_token, + app_tid, + ) + return AuthResult( + access_token=token, + gateway_url=credentials.gateway_url, + ) + + # LoB flow + if app_tid: + logger.warning("app_tid parameter ignored for LoB agent flow") + + tenant = self._resolve_tenant_subdomain() + token, gateway_url = await fetch_user_auth(resolved_user_token, tenant) + return AuthResult(access_token=token, gateway_url=gateway_url) + + except AgentGatewaySDKError: + raise + except Exception as e: + logger.exception("Unexpected error during user auth exchange") + raise AgentGatewaySDKError(f"User auth exchange failed: {e}") from e + @record_metrics(Module.AGENTGATEWAY, Operation.AGENTGATEWAY_LIST_MCP_TOOLS) async def list_mcp_tools( self, @@ -153,14 +306,16 @@ async def list_mcp_tools( "Customer agent credentials detected at '%s'", credentials_path ) credentials = load_customer_credentials(credentials_path) - return await get_mcp_tools_customer(credentials, app_tid) + auth = await self.get_system_auth(app_tid=app_tid) + return await get_mcp_tools_customer(credentials, auth.access_token) # LoB flow - requires tenant_subdomain if app_tid: logger.warning("app_tid parameter ignored for LoB agent flow") tenant = self._resolve_tenant_subdomain() - return await get_mcp_tools_lob(tenant) + auth = await self.get_system_auth() + return await get_mcp_tools_lob(tenant, auth.access_token) except AgentGatewaySDKError: # Re-raise SDK errors as-is @@ -228,32 +383,26 @@ async def call_mcp_tool( ) # Resolve user_token if provided (optional for customer flow) - resolved_user_token = None if user_token: - resolved_user_token = ( - user_token() - if not isinstance(user_token, str) and callable(user_token) - else user_token + auth = await self.get_user_auth(user_token, app_tid) + else: + # TODO: IBD workaround - use system token when user_token + # is not available. This bypasses principal propagation. + # Remove this fallback once IBD supports proper user token flow. + logger.warning( + "No user_token provided - using system token for tool " + "invocation. Principal propagation will NOT work." ) - if resolved_user_token: - resolved_user_token = resolved_user_token.strip() or None + auth = await self.get_system_auth(app_tid) - credentials = load_customer_credentials(credentials_path) - return await call_mcp_tool_customer( - credentials, tool, resolved_user_token, app_tid, **kwargs - ) + return await call_mcp_tool_customer(tool, auth.access_token, **kwargs) # LoB flow - requires user_token and tenant_subdomain - resolved_user_token = self._resolve_value( - user_token, - "user_token is required for LoB agent tool invocation.", - ) - if app_tid: logger.warning("app_tid parameter ignored for LoB agent flow") - tenant = self._resolve_tenant_subdomain() - return await call_mcp_tool_lob(tool, resolved_user_token, tenant, **kwargs) + auth = await self.get_user_auth(user_token, app_tid) + return await call_mcp_tool_lob(tool, auth.access_token, **kwargs) except AgentGatewaySDKError: # Re-raise SDK errors as-is @@ -318,5 +467,16 @@ def create_client( cost_center="1000", # example tool-specific parameter ) ``` + + Example (auth fetching): + ```python + from sap_cloud_sdk.agentgateway import create_client + + agw_client = create_client(tenant_subdomain="my-tenant") + + # Get auth for external use + auth = await agw_client.get_system_auth() + user_auth = await agw_client.get_user_auth(user_token="user-jwt") + ``` """ return AgentGatewayClient(tenant_subdomain=tenant_subdomain) diff --git a/src/sap_cloud_sdk/core/telemetry/operation.py b/src/sap_cloud_sdk/core/telemetry/operation.py index 8619145..f56e702 100644 --- a/src/sap_cloud_sdk/core/telemetry/operation.py +++ b/src/sap_cloud_sdk/core/telemetry/operation.py @@ -107,6 +107,8 @@ class Operation(str, Enum): # Agent Gateway Operations AGENTGATEWAY_LIST_MCP_TOOLS = "list_mcp_tools" AGENTGATEWAY_CALL_MCP_TOOL = "call_mcp_tool" + AGENTGATEWAY_GET_SYSTEM_AUTH = "get_system_auth" + AGENTGATEWAY_GET_USER_AUTH = "get_user_auth" # Agent Memory Operations AGENT_MEMORY_ADD_MEMORY = "add_memory" diff --git a/src/sap_cloud_sdk/extensibility/_ums_transport.py b/src/sap_cloud_sdk/extensibility/_ums_transport.py index 437161a..f529bf9 100644 --- a/src/sap_cloud_sdk/extensibility/_ums_transport.py +++ b/src/sap_cloud_sdk/extensibility/_ums_transport.py @@ -436,7 +436,8 @@ class UmsTransport: 1. ``config.destination_name`` (explicit config override). 2. ``APPFND_UMS_DESTINATION_NAME`` environment variable. 3. ``sap-managed-runtime-ums-{APPFND_CONHOS_LANDSCAPE}`` (constructed). - 4. ``EXTENSIBILITY_SERVICE`` (fallback with warning). + + If none of the above are available, resolution fails with a warning. Args: agent_ord_id: ORD ID of the agent. diff --git a/src/sap_cloud_sdk/extensibility/config.py b/src/sap_cloud_sdk/extensibility/config.py index 649a700..c881492 100644 --- a/src/sap_cloud_sdk/extensibility/config.py +++ b/src/sap_cloud_sdk/extensibility/config.py @@ -23,8 +23,8 @@ class ExtensibilityConfig: When ``None`` (the default), the destination name is resolved automatically in order: (1) ``APPFND_UMS_DESTINATION_NAME`` environment variable, - (2) ``sap-managed-runtime-ums-{APPFND_CONHOS_LANDSCAPE}``, - (3) fallback to ``"EXTENSIBILITY_SERVICE"`` with a warning. + (2) ``sap-managed-runtime-ums-{APPFND_CONHOS_LANDSCAPE}``. + If neither is available, resolution fails with a warning. Set this only when the destination follows a non-standard naming convention that cannot be expressed via environment variables. diff --git a/src/sap_cloud_sdk/extensibility/user-guide.md b/src/sap_cloud_sdk/extensibility/user-guide.md index 043e539..ed7590d 100644 --- a/src/sap_cloud_sdk/extensibility/user-guide.md +++ b/src/sap_cloud_sdk/extensibility/user-guide.md @@ -714,7 +714,7 @@ Validation issues produce log warnings but never prevent output generation. The module resolves the extensibility service URL and credentials through the SAP BTP Destination Service. The destination is looked up at the subaccount level. -- **Default destination name resolution**: (1) `APPFND_UMS_DESTINATION_NAME` env var, (2) `sap-managed-runtime-ums-{APPFND_CONHOS_LANDSCAPE}`, (3) `EXTENSIBILITY_SERVICE` fallback. +- **Default destination name resolution**: (1) `APPFND_UMS_DESTINATION_NAME` env var, (2) `sap-managed-runtime-ums-{APPFND_CONHOS_LANDSCAPE}`. If neither is available, resolution fails with a warning. - **Default destination instance**: `default` - Override via `ExtensibilityConfig(destination_name=...)` when the destination uses a non-standard name. diff --git a/tests/agentgateway/unit/test_agw_client.py b/tests/agentgateway/unit/test_agw_client.py index 6a7bcc4..0f50620 100644 --- a/tests/agentgateway/unit/test_agw_client.py +++ b/tests/agentgateway/unit/test_agw_client.py @@ -1,12 +1,13 @@ """Unit tests for Agent Gateway client.""" -from unittest.mock import patch, AsyncMock +from unittest.mock import patch, AsyncMock, MagicMock import pytest from sap_cloud_sdk.agentgateway import ( create_client, AgentGatewayClient, + AuthResult, MCPTool, AgentGatewaySDKError, ) @@ -96,6 +97,221 @@ def test_raises_on_callable_returning_empty(self): AgentGatewayClient._resolve_value(get_empty, "test error") +# ============================================================ +# Test: get_system_auth +# ============================================================ + + +class TestGetSystemAuth: + """Tests for get_system_auth async method.""" + + @pytest.mark.asyncio + async def test_lob_flow_returns_auth_result(self): + """Return AuthResult from LoB flow.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_system_auth", + new_callable=AsyncMock, + return_value=("raw-system-jwt-token", "https://agw.example.com"), + ) as mock_auth: + agw_client = create_client(tenant_subdomain="my-tenant") + + result = await agw_client.get_system_auth() + + assert isinstance(result, AuthResult) + assert result.access_token == "raw-system-jwt-token" + assert result.gateway_url == "https://agw.example.com" + mock_auth.assert_called_once_with("my-tenant") + + @pytest.mark.asyncio + async def test_customer_flow_returns_auth_result(self): + """Return AuthResult from customer flow.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value="/path/to/credentials", + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.load_customer_credentials", + ) as mock_load, patch( + "sap_cloud_sdk.agentgateway.agw_client.get_system_token_mtls", + return_value="customer-system-token", + ) as mock_mtls: + mock_creds = MagicMock() + mock_creds.gateway_url = "https://agw.customer.com" + mock_load.return_value = mock_creds + + agw_client = create_client() + + result = await agw_client.get_system_auth(app_tid="test-tid") + + assert isinstance(result, AuthResult) + assert result.access_token == "customer-system-token" + assert result.gateway_url == "https://agw.customer.com" + mock_load.assert_called_once_with("/path/to/credentials") + mock_mtls.assert_called_once_with(mock_creds, "test-tid") + + @pytest.mark.asyncio + async def test_missing_tenant_raises_for_lob(self): + """Raise AgentGatewaySDKError when tenant_subdomain is missing for LoB.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ): + agw_client = create_client() + + with pytest.raises(AgentGatewaySDKError, match="tenant_subdomain is required"): + await agw_client.get_system_auth() + + @pytest.mark.asyncio + async def test_callable_tenant_subdomain(self): + """Accept callable for tenant_subdomain in LoB flow.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_system_auth", + new_callable=AsyncMock, + return_value=("token", "https://agw.example.com"), + ) as mock_auth: + get_tenant = lambda: "dynamic-tenant" + agw_client = create_client(tenant_subdomain=get_tenant) + + await agw_client.get_system_auth() + + mock_auth.assert_called_once_with("dynamic-tenant") + + @pytest.mark.asyncio + async def test_wraps_unexpected_errors(self): + """Wrap unexpected errors in AgentGatewaySDKError.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_system_auth", + new_callable=AsyncMock, + side_effect=RuntimeError("unexpected"), + ): + agw_client = create_client(tenant_subdomain="my-tenant") + + with pytest.raises(AgentGatewaySDKError, match="System auth acquisition failed"): + await agw_client.get_system_auth() + + +# ============================================================ +# Test: get_user_auth +# ============================================================ + + +class TestGetUserAuth: + """Tests for get_user_auth async method.""" + + @pytest.mark.asyncio + async def test_lob_flow_returns_auth_result(self): + """Return AuthResult from LoB flow.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_user_auth", + new_callable=AsyncMock, + return_value=("raw-user-jwt-token", "https://agw.example.com"), + ) as mock_auth: + agw_client = create_client(tenant_subdomain="my-tenant") + + result = await agw_client.get_user_auth(user_token="user-jwt") + + assert isinstance(result, AuthResult) + assert result.access_token == "raw-user-jwt-token" + assert result.gateway_url == "https://agw.example.com" + mock_auth.assert_called_once_with("user-jwt", "my-tenant") + + @pytest.mark.asyncio + async def test_customer_flow_exchanges_token(self): + """Exchange token via customer flow and return AuthResult.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value="/path/to/credentials", + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.load_customer_credentials", + ) as mock_load, patch( + "sap_cloud_sdk.agentgateway.agw_client.exchange_user_token", + return_value="exchanged-token", + ) as mock_exchange: + mock_creds = MagicMock() + mock_creds.gateway_url = "https://agw.customer.com" + mock_load.return_value = mock_creds + + agw_client = create_client() + + result = await agw_client.get_user_auth( + user_token="user-jwt", app_tid="test-tid" + ) + + assert isinstance(result, AuthResult) + assert result.access_token == "exchanged-token" + assert result.gateway_url == "https://agw.customer.com" + mock_exchange.assert_called_once_with(mock_creds, "user-jwt", "test-tid") + + @pytest.mark.asyncio + async def test_missing_user_token_raises(self): + """Raise AgentGatewaySDKError when user_token is empty.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ): + agw_client = create_client(tenant_subdomain="my-tenant") + + with pytest.raises(AgentGatewaySDKError, match="user_token is required"): + await agw_client.get_user_auth(user_token="") + + @pytest.mark.asyncio + async def test_callable_user_token(self): + """Accept callable for user_token.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_user_auth", + new_callable=AsyncMock, + return_value=("token", "https://agw.example.com"), + ) as mock_auth: + agw_client = create_client(tenant_subdomain="my-tenant") + get_token = lambda: "dynamic-user-jwt" + + await agw_client.get_user_auth(user_token=get_token) + + mock_auth.assert_called_once_with("dynamic-user-jwt", "my-tenant") + + @pytest.mark.asyncio + async def test_missing_tenant_raises_for_lob(self): + """Raise AgentGatewaySDKError when tenant_subdomain is missing for LoB.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ): + agw_client = create_client() + + with pytest.raises(AgentGatewaySDKError, match="tenant_subdomain is required"): + await agw_client.get_user_auth(user_token="user-jwt") + + @pytest.mark.asyncio + async def test_wraps_unexpected_errors(self): + """Wrap unexpected errors in AgentGatewaySDKError.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value=None, + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_user_auth", + new_callable=AsyncMock, + side_effect=RuntimeError("unexpected"), + ): + agw_client = create_client(tenant_subdomain="my-tenant") + + with pytest.raises(AgentGatewaySDKError, match="User auth exchange failed"): + await agw_client.get_user_auth(user_token="user-jwt") + + # ============================================================ # Test: list_mcp_tools # ============================================================ @@ -154,6 +370,11 @@ async def test_with_callable_tenant(self): "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", return_value=None, ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_system_auth", + new_callable=AsyncMock, + return_value=("system-token", "https://agw.example.com"), + ), patch( "sap_cloud_sdk.agentgateway.agw_client.get_mcp_tools_lob", new_callable=AsyncMock, @@ -165,16 +386,21 @@ async def test_with_callable_tenant(self): await agw_client.list_mcp_tools() - mock_lob.assert_called_once_with("my-tenant") + mock_lob.assert_called_once_with("my-tenant", "system-token") @pytest.mark.asyncio - async def test_calls_lob_flow(self): - """list_mcp_tools should call LoB flow with correct parameters.""" + async def test_calls_lob_flow_with_system_token(self): + """list_mcp_tools should call LoB flow with system token.""" with ( patch( "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", return_value=None, ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_system_auth", + new_callable=AsyncMock, + return_value=("system-token-xyz", "https://agw.example.com"), + ), patch( "sap_cloud_sdk.agentgateway.agw_client.get_mcp_tools_lob", new_callable=AsyncMock, @@ -185,7 +411,7 @@ async def test_calls_lob_flow(self): await agw_client.list_mcp_tools() - mock_lob.assert_called_once_with("my-tenant") + mock_lob.assert_called_once_with("my-tenant", "system-token-xyz") @pytest.mark.asyncio async def test_returns_tools_from_lob_flow(self): @@ -206,6 +432,11 @@ async def test_returns_tools_from_lob_flow(self): "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", return_value=None, ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_system_auth", + new_callable=AsyncMock, + return_value=("token", "https://agw.example.com"), + ), patch( "sap_cloud_sdk.agentgateway.agw_client.get_mcp_tools_lob", new_callable=AsyncMock, @@ -220,6 +451,34 @@ async def test_returns_tools_from_lob_flow(self): assert len(result) == 1 assert result[0].name == "tool1" + @pytest.mark.asyncio + async def test_customer_flow_passes_system_token(self): + """Customer flow passes pre-fetched system token to get_mcp_tools_customer.""" + with patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value="/path/to/credentials", + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.load_customer_credentials", + ) as mock_load, patch( + "sap_cloud_sdk.agentgateway.agw_client.get_system_token_mtls", + return_value="customer-system-token", + ), patch( + "sap_cloud_sdk.agentgateway.agw_client.get_mcp_tools_customer", + new_callable=AsyncMock, + return_value=[], + ) as mock_customer: + mock_creds = MagicMock() + mock_creds.gateway_url = "https://agw.customer.com" + mock_load.return_value = mock_creds + + agw_client = create_client() + + await agw_client.list_mcp_tools(app_tid="tid") + + mock_customer.assert_called_once_with( + mock_creds, "customer-system-token" + ) + # ============================================================ # Test: call_mcp_tool @@ -289,6 +548,11 @@ async def test_with_callable_user_token(self, mock_tool): "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", return_value=None, ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_user_auth", + new_callable=AsyncMock, + return_value=("exchanged-token", "https://agw.example.com"), + ), patch( "sap_cloud_sdk.agentgateway.agw_client.call_mcp_tool_lob", new_callable=AsyncMock, @@ -306,7 +570,7 @@ async def test_with_callable_user_token(self, mock_tool): assert result == "result" mock_lob.assert_called_once_with( - mock_tool, "my-jwt", "my-tenant", param1="value1" + mock_tool, "exchanged-token", param1="value1" ) @pytest.mark.asyncio @@ -317,6 +581,11 @@ async def test_with_callable_tenant_subdomain(self, mock_tool): "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", return_value=None, ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_user_auth", + new_callable=AsyncMock, + return_value=("exchanged-token", "https://agw.example.com"), + ), patch( "sap_cloud_sdk.agentgateway.agw_client.call_mcp_tool_lob", new_callable=AsyncMock, @@ -332,7 +601,7 @@ async def test_with_callable_tenant_subdomain(self, mock_tool): ) assert result == "result" - mock_lob.assert_called_once_with(mock_tool, "my-jwt", "my-tenant") + mock_lob.assert_called_once_with(mock_tool, "exchanged-token") @pytest.mark.asyncio async def test_customer_credentials_calls_customer_flow(self, mock_tool): @@ -345,12 +614,20 @@ async def test_customer_credentials_calls_customer_flow(self, mock_tool): patch( "sap_cloud_sdk.agentgateway.agw_client.load_customer_credentials", ) as mock_load, + patch( + "sap_cloud_sdk.agentgateway.agw_client.exchange_user_token", + return_value="exchanged-token", + ), patch( "sap_cloud_sdk.agentgateway.agw_client.call_mcp_tool_customer", new_callable=AsyncMock, return_value="customer result", ) as mock_customer, ): + mock_creds = MagicMock() + mock_creds.gateway_url = "https://agw.customer.com" + mock_load.return_value = mock_creds + agw_client = create_client(tenant_subdomain="my-tenant") result = await agw_client.call_mcp_tool( @@ -359,17 +636,62 @@ async def test_customer_credentials_calls_customer_flow(self, mock_tool): ) assert result == "customer result" + # load_customer_credentials is called once in get_user_auth() mock_load.assert_called_once_with("/path/to/credentials") - mock_customer.assert_called_once() + mock_customer.assert_called_once_with( + mock_tool, "exchanged-token" + ) + + @pytest.mark.asyncio + async def test_customer_flow_falls_back_to_system_token(self, mock_tool): + """Customer flow falls back to system token when user_token is None.""" + with ( + patch( + "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", + return_value="/path/to/credentials", + ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.load_customer_credentials", + ) as mock_load, + patch( + "sap_cloud_sdk.agentgateway.agw_client.get_system_token_mtls", + return_value="system-token", + ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.call_mcp_tool_customer", + new_callable=AsyncMock, + return_value="result with system token", + ) as mock_customer, + ): + mock_creds = MagicMock() + mock_creds.gateway_url = "https://agw.customer.com" + mock_load.return_value = mock_creds + + agw_client = create_client() + + result = await agw_client.call_mcp_tool( + tool=mock_tool, + user_token=None, + ) + + assert result == "result with system token" + mock_customer.assert_called_once_with( + mock_tool, "system-token" + ) @pytest.mark.asyncio - async def test_calls_lob_flow(self, mock_tool): - """call_mcp_tool should call LoB flow with correct parameters.""" + async def test_calls_lob_flow_with_exchanged_token(self, mock_tool): + """call_mcp_tool should exchange user token and pass to LoB flow.""" with ( patch( "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", return_value=None, ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_user_auth", + new_callable=AsyncMock, + return_value=("exchanged-user-token", "https://agw.example.com"), + ), patch( "sap_cloud_sdk.agentgateway.agw_client.call_mcp_tool_lob", new_callable=AsyncMock, @@ -386,7 +708,7 @@ async def test_calls_lob_flow(self, mock_tool): assert result == "tool result" mock_lob.assert_called_once_with( - mock_tool, "jwt-token", "my-tenant", order_id="12345" + mock_tool, "exchanged-user-token", order_id="12345" ) @pytest.mark.asyncio @@ -397,6 +719,11 @@ async def test_returns_result_from_lob_flow(self, mock_tool): "sap_cloud_sdk.agentgateway.agw_client.detect_customer_agent_credentials", return_value=None, ), + patch( + "sap_cloud_sdk.agentgateway.agw_client.fetch_user_auth", + new_callable=AsyncMock, + return_value=("token", "https://agw.example.com"), + ), patch( "sap_cloud_sdk.agentgateway.agw_client.call_mcp_tool_lob", new_callable=AsyncMock, diff --git a/tests/agentgateway/unit/test_customer.py b/tests/agentgateway/unit/test_customer.py index 9e1f7cf..c5f48b8 100644 --- a/tests/agentgateway/unit/test_customer.py +++ b/tests/agentgateway/unit/test_customer.py @@ -449,11 +449,11 @@ async def test_raises_when_empty_dependencies(self): with pytest.raises( AgentGatewaySDKError, match="integrationDependencies is empty" ): - await get_mcp_tools_customer(credentials) + await get_mcp_tools_customer(credentials, "system-token") @pytest.mark.asyncio async def test_discovers_tools_from_credentials(self, credentials): - """Discover tools from integrationDependencies in credentials.""" + """Discover tools from integrationDependencies using pre-fetched token.""" mock_tools = [ MCPTool( name="list_cost_centers", @@ -465,21 +465,22 @@ async def test_discovers_tools_from_credentials(self, credentials): ] with ( - patch( - "sap_cloud_sdk.agentgateway._customer.get_system_token_mtls", - return_value="system-token", - ), patch( "sap_cloud_sdk.agentgateway._customer._list_server_tools", new_callable=AsyncMock, return_value=mock_tools, ) as mock_list, ): - result = await get_mcp_tools_customer(credentials) + result = await get_mcp_tools_customer( + credentials, "pre-fetched-system-token" + ) assert len(result) == 1 assert result[0].name == "list_cost_centers" mock_list.assert_called_once() + # Verify the pre-fetched token was passed + call_args = mock_list.call_args[0] + assert call_args[1] == "pre-fetched-system-token" @pytest.mark.asyncio async def test_handles_server_error_gracefully(self): @@ -514,16 +515,14 @@ async def mock_list_tools(*args, **kwargs): return [mock_tool] with ( - patch( - "sap_cloud_sdk.agentgateway._customer.get_system_token_mtls", - return_value="system-token", - ), patch( "sap_cloud_sdk.agentgateway._customer._list_server_tools", side_effect=mock_list_tools, ), ): - result = await get_mcp_tools_customer(credentials) + result = await get_mcp_tools_customer( + credentials, "system-token" + ) # Should still return tools from server2 assert len(result) == 1 @@ -570,13 +569,9 @@ def mock_tool(self): ) @pytest.mark.asyncio - async def test_exchanges_user_token_before_call(self, credentials, mock_tool): - """Exchange user token before making tool call.""" + async def test_calls_tool_with_pre_fetched_token(self, credentials, mock_tool): + """Call tool using pre-fetched auth token.""" with ( - patch( - "sap_cloud_sdk.agentgateway._customer.exchange_user_token", - return_value="exchanged-token", - ) as mock_exchange, patch( "httpx.AsyncClient", ) as mock_client_class, @@ -613,25 +608,19 @@ async def test_exchanges_user_token_before_call(self, credentials, mock_tool): mock_session_class.return_value = mock_session_ctx result = await call_mcp_tool_customer( - credentials, mock_tool, "user-jwt", order_id="12345" + mock_tool, "pre-fetched-token", order_id="12345" ) assert result == "Order created successfully" - mock_exchange.assert_called_once_with(credentials, "user-jwt", None) + # Verify the token was used in the Authorization header + mock_client_class.assert_called_once() + call_kwargs = mock_client_class.call_args.kwargs + assert call_kwargs["headers"]["Authorization"] == "Bearer pre-fetched-token" @pytest.mark.asyncio - async def test_uses_system_token_when_user_token_not_provided( - self, credentials, mock_tool - ): - """Fall back to system token when user_token is None (IBD workaround).""" + async def test_returns_empty_string_when_no_content(self, credentials, mock_tool): + """Return empty string when tool returns no content.""" with ( - patch( - "sap_cloud_sdk.agentgateway._customer.get_system_token_mtls", - return_value="system-token", - ) as mock_system_token, - patch( - "sap_cloud_sdk.agentgateway._customer.exchange_user_token", - ) as mock_exchange, patch( "httpx.AsyncClient", ) as mock_client_class, @@ -642,7 +631,6 @@ async def test_uses_system_token_when_user_token_not_provided( "sap_cloud_sdk.agentgateway._customer.ClientSession", ) as mock_session_class, ): - # Set up mock chain mock_client = AsyncMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=None) @@ -658,21 +646,15 @@ async def test_uses_system_token_when_user_token_not_provided( mock_session = AsyncMock() mock_session.initialize = AsyncMock() mock_result = MagicMock() - mock_content = MagicMock() - mock_content.text = "Result with system token" - mock_result.content = [mock_content] + mock_result.content = [] mock_session.call_tool = AsyncMock(return_value=mock_result) mock_session_ctx = AsyncMock() mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session) mock_session_ctx.__aexit__ = AsyncMock(return_value=None) mock_session_class.return_value = mock_session_ctx - # Call without user_token (None) result = await call_mcp_tool_customer( - credentials, mock_tool, None, order_id="12345" + mock_tool, "auth-token" ) - assert result == "Result with system token" - # Should use system token, not exchange - mock_system_token.assert_called_once_with(credentials, None) - mock_exchange.assert_not_called() + assert result == "" diff --git a/tests/agentgateway/unit/test_lob.py b/tests/agentgateway/unit/test_lob.py index 4f9a6f3..b3071c9 100644 --- a/tests/agentgateway/unit/test_lob.py +++ b/tests/agentgateway/unit/test_lob.py @@ -1,5 +1,6 @@ """Unit tests for LoB agent flow.""" +import base64 import os from unittest.mock import patch, MagicMock, AsyncMock @@ -10,13 +11,15 @@ _fetch_auth_token, list_mcp_fragments, get_ias_fragment_name, - get_system_auth, - get_user_auth, + get_ias_user_fragment_name, + fetch_system_auth, + fetch_user_auth, get_mcp_tools_lob, call_mcp_tool_lob, _LABEL_KEY, _MCP_LABEL_VALUE, _IAS_LABEL_VALUE, + _IAS_USER_LABEL_VALUE, ) from sap_cloud_sdk.agentgateway._models import MCPTool from sap_cloud_sdk.agentgateway.exceptions import MCPServerNotFoundError @@ -61,11 +64,13 @@ def test_raises_when_env_not_set(self): class TestFetchAuthToken: """Tests for _fetch_auth_token function.""" - def test_fetches_token_successfully(self): - """Fetch auth token from destination service.""" + def test_fetches_and_decodes_token_and_url(self): + """Fetch token and base64-decode the value field, return tuple with gateway URL.""" + raw_token = "my-raw-jwt-token-123" mock_dest = MagicMock() mock_dest.auth_tokens = [MagicMock()] - mock_dest.auth_tokens[0].http_header = {"value": "Bearer test-token"} + mock_dest.auth_tokens[0].value = base64.b64encode(raw_token.encode()).decode() + mock_dest.url = "https://agw.example.com/" with patch( "sap_cloud_sdk.agentgateway._lob.create_destination_client" @@ -74,7 +79,7 @@ def test_fetches_token_successfully(self): result = _fetch_auth_token("dest-name", "tenant-sub") - assert result == "Bearer test-token" + assert result == (raw_token, "https://agw.example.com") mock_client.return_value.get_destination.assert_called_once_with( "dest-name", level=ConsumptionLevel.PROVIDER_SUBACCOUNT, @@ -82,6 +87,21 @@ def test_fetches_token_successfully(self): tenant="tenant-sub", ) + def test_strips_trailing_slashes_from_url(self): + """Strip trailing slashes from gateway URL.""" + raw_token = "token" + mock_dest = MagicMock() + mock_dest.auth_tokens = [MagicMock()] + mock_dest.auth_tokens[0].value = base64.b64encode(raw_token.encode()).decode() + mock_dest.url = "https://agw.example.com/v1/mcp///" + + with patch("sap_cloud_sdk.agentgateway._lob.create_destination_client") as mock_client: + mock_client.return_value.get_destination.return_value = mock_dest + + result = _fetch_auth_token("dest-name", "tenant-sub") + + assert result == (raw_token, "https://agw.example.com/v1/mcp") + def test_raises_when_no_destination(self): """Raise MCPServerNotFoundError when destination is None.""" with patch( @@ -105,25 +125,27 @@ def test_raises_when_no_auth_tokens(self): with pytest.raises(MCPServerNotFoundError, match="No auth token"): _fetch_auth_token("dest-name", "tenant-sub") - def test_raises_when_empty_auth_header(self): - """Raise MCPServerNotFoundError when auth header is empty.""" + def test_raises_when_empty_token_value(self): + """Raise MCPServerNotFoundError when token value is empty.""" mock_dest = MagicMock() mock_dest.auth_tokens = [MagicMock()] - mock_dest.auth_tokens[0].http_header = {"value": ""} + mock_dest.auth_tokens[0].value = "" with patch( "sap_cloud_sdk.agentgateway._lob.create_destination_client" ) as mock_client: mock_client.return_value.get_destination.return_value = mock_dest - with pytest.raises(MCPServerNotFoundError, match="Empty Authorization"): + with pytest.raises(MCPServerNotFoundError, match="Empty token value"): _fetch_auth_token("dest-name", "tenant-sub") def test_passes_options_to_destination(self): """Pass consumption options to get_destination.""" + raw_token = "token" mock_dest = MagicMock() mock_dest.auth_tokens = [MagicMock()] - mock_dest.auth_tokens[0].http_header = {"value": "Bearer token"} + mock_dest.auth_tokens[0].value = base64.b64encode(raw_token.encode()).decode() + mock_dest.url = "https://agw.example.com" mock_options = MagicMock() with patch( @@ -242,16 +264,65 @@ def test_raises_when_no_fragment_found(self): # ============================================================ -# Test: get_system_auth +# Test: get_ias_user_fragment_name # ============================================================ -class TestGetSystemAuth: - """Tests for get_system_auth async function.""" +class TestGetIasUserFragmentName: + """Tests for get_ias_user_fragment_name function.""" + + def test_returns_fragment_name(self): + """Return name of first IAS user fragment found.""" + fragment = MagicMock() + fragment.name = "sap-managed-runtime-agw-subscriber-ias-user-abc123" + + with patch("sap_cloud_sdk.agentgateway._lob.create_fragment_client") as mock_client: + mock_client.return_value.list_instance_fragments.return_value = [fragment] + + result = get_ias_user_fragment_name("tenant-sub") + + assert result == "sap-managed-runtime-agw-subscriber-ias-user-abc123" + + def test_uses_correct_filter_labels(self): + """Use correct label filter for IAS user fragments.""" + fragment = MagicMock() + fragment.name = "ias-user-fragment" + + with patch("sap_cloud_sdk.agentgateway._lob.create_fragment_client") as mock_client: + mock_client.return_value.list_instance_fragments.return_value = [fragment] + + get_ias_user_fragment_name("tenant-sub") + + call_args = mock_client.return_value.list_instance_fragments.call_args + filter_opt = call_args.kwargs.get("filter") + assert filter_opt is not None + assert len(filter_opt.filter_labels) == 1 + assert filter_opt.filter_labels[0].key == _LABEL_KEY + assert filter_opt.filter_labels[0].values == [_IAS_USER_LABEL_VALUE] + + def test_raises_when_no_fragment_found(self): + """Raise MCPServerNotFoundError when no IAS user fragment exists.""" + with patch("sap_cloud_sdk.agentgateway._lob.create_fragment_client") as mock_client: + mock_client.return_value.list_instance_fragments.return_value = [] + + with pytest.raises(MCPServerNotFoundError, match="No IAS user fragment found"): + get_ias_user_fragment_name("tenant-sub") + + +# ============================================================ +# Test: fetch_system_auth +# ============================================================ + + +class TestFetchSystemAuth: + """Tests for fetch_system_auth async function.""" @pytest.mark.asyncio async def test_fetches_system_auth(self): - """Fetch system auth using IAS fragment looked up by label.""" + """Fetch system auth using IAS fragment and return tuple (token, url).""" + raw_token = "system-jwt-token-xyz" + gateway_url = "https://agw.example.com" + with patch.dict(os.environ, {"APPFND_CONHOS_LANDSCAPE": "eu10"}): with ( patch( @@ -262,11 +333,11 @@ async def test_fetches_system_auth(self): ) as mock_fetch, ): mock_ias.return_value = "sap-managed-runtime-agw-subscriber-ias-abc" - mock_fetch.return_value = "Bearer system-token" + mock_fetch.return_value = (raw_token, gateway_url) - result = await get_system_auth("tenant-sub") + result = await fetch_system_auth("tenant-sub") - assert result == "Bearer system-token" + assert result == (raw_token, gateway_url) mock_ias.assert_called_once_with("tenant-sub") mock_fetch.assert_called_once() call_args = mock_fetch.call_args @@ -280,32 +351,38 @@ async def test_fetches_system_auth(self): # ============================================================ -# Test: get_user_auth +# Test: fetch_user_auth # ============================================================ -class TestGetUserAuth: - """Tests for get_user_auth async function.""" +class TestFetchUserAuth: + """Tests for fetch_user_auth async function.""" @pytest.mark.asyncio - async def test_fetches_user_auth_with_token_exchange(self): - """Fetch user auth with token exchange.""" + async def test_fetches_user_auth_with_ias_user_fragment(self): + """Fetch user auth using IAS user fragment and user_token, return tuple.""" + raw_token = "exchanged-user-jwt-token" + gateway_url = "https://agw.example.com" + with patch.dict(os.environ, {"APPFND_CONHOS_LANDSCAPE": "eu10"}): - with patch( - "sap_cloud_sdk.agentgateway._lob._fetch_auth_token" - ) as mock_fetch: - mock_fetch.return_value = "Bearer user-token" + with ( + patch("sap_cloud_sdk.agentgateway._lob.get_ias_user_fragment_name") as mock_ias_user, + patch("sap_cloud_sdk.agentgateway._lob._fetch_auth_token") as mock_fetch, + ): + mock_ias_user.return_value = "sap-managed-runtime-agw-subscriber-ias-user-abc" + mock_fetch.return_value = (raw_token, gateway_url) - result = await get_user_auth("mcp-fragment", "user-jwt", "tenant-sub") + result = await fetch_user_auth("user-jwt", "tenant-sub") - assert result == "Bearer user-token" + assert result == (raw_token, gateway_url) + mock_ias_user.assert_called_once_with("tenant-sub") mock_fetch.assert_called_once() call_args = mock_fetch.call_args assert call_args[0][0] == "sap-managed-runtime-ias-eu10" assert call_args[0][1] == "tenant-sub" options = call_args[0][2] assert options.user_token == "user-jwt" - assert options.fragment_name == "mcp-fragment" + assert options.fragment_name == "sap-managed-runtime-agw-subscriber-ias-user-abc" assert options.fragment_level == ConsumptionLevel.INSTANCE @@ -323,7 +400,7 @@ async def test_returns_empty_when_no_fragments(self): with patch("sap_cloud_sdk.agentgateway._lob.list_mcp_fragments") as mock_list: mock_list.return_value = [] - result = await get_mcp_tools_lob("tenant-sub") + result = await get_mcp_tools_lob("tenant-sub", "system-token") assert result == [] @@ -337,13 +414,13 @@ async def test_skips_fragments_without_url(self): with patch("sap_cloud_sdk.agentgateway._lob.list_mcp_fragments") as mock_list: mock_list.return_value = [fragment] - result = await get_mcp_tools_lob("tenant-sub") + result = await get_mcp_tools_lob("tenant-sub", "system-token") assert result == [] @pytest.mark.asyncio - async def test_uses_fragment_name_directly(self): - """Use fragment name as-is (no -technical stripping).""" + async def test_uses_pre_fetched_system_token(self): + """Use the pre-fetched system token for MCP server calls.""" fragment = MagicMock() fragment.name = "mcp-server-a" fragment.properties = {"URL": "https://example.com/mcp"} @@ -359,27 +436,20 @@ async def test_uses_fragment_name_directly(self): with ( patch("sap_cloud_sdk.agentgateway._lob.list_mcp_fragments") as mock_list, - patch( - "sap_cloud_sdk.agentgateway._lob.get_system_auth", - new_callable=AsyncMock, - ) as mock_auth, patch( "sap_cloud_sdk.agentgateway._lob.list_server_tools", new_callable=AsyncMock, ) as mock_tools, ): mock_list.return_value = [fragment] - mock_auth.return_value = "Bearer token" mock_tools.return_value = [mock_tool] - await get_mcp_tools_lob("tenant-sub") + await get_mcp_tools_lob("tenant-sub", "pre-fetched-token") - # Verify get_system_auth called with just tenant_subdomain - mock_auth.assert_called_once_with("tenant-sub") - # Verify list_server_tools called with the unchanged fragment name - mock_tools.assert_called_once() - call_args = mock_tools.call_args[0] - assert call_args[2] == "mcp-server-a" + # Verify list_server_tools called with the pre-fetched token + mock_tools.assert_called_once_with( + "https://example.com/mcp", "pre-fetched-token", "mcp-server-a" + ) @pytest.mark.asyncio async def test_handles_exception_for_single_fragment(self): @@ -401,24 +471,25 @@ async def test_handles_exception_for_single_fragment(self): fragment_name="mcp-server2", ) + call_count = 0 + + async def mock_list_tools_fn(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise Exception("Server connection failed") + return [mock_tool] + with ( patch("sap_cloud_sdk.agentgateway._lob.list_mcp_fragments") as mock_list, - patch( - "sap_cloud_sdk.agentgateway._lob.get_system_auth", - new_callable=AsyncMock, - ) as mock_auth, patch( "sap_cloud_sdk.agentgateway._lob.list_server_tools", - new_callable=AsyncMock, - ) as mock_tools, + side_effect=mock_list_tools_fn, + ), ): mock_list.return_value = [fragment1, fragment2] - # First fragment fails, second succeeds - mock_auth.side_effect = [Exception("Auth failed"), "Bearer token"] - mock_tools.return_value = [mock_tool] - - result = await get_mcp_tools_lob("tenant-sub") + result = await get_mcp_tools_lob("tenant-sub", "system-token") # Should still get tools from second fragment assert len(result) == 1 @@ -434,8 +505,8 @@ class TestCallMcpToolLob: """Tests for call_mcp_tool_lob async function.""" @pytest.mark.asyncio - async def test_calls_tool_with_user_auth(self): - """Call tool using user authentication.""" + async def test_calls_tool_with_pre_fetched_token(self): + """Call tool using pre-fetched user auth token.""" tool = MCPTool( name="test-tool", server_name="test-server", @@ -450,17 +521,12 @@ async def test_calls_tool_with_user_auth(self): mock_result.content[0].text = "Tool result" with ( - patch( - "sap_cloud_sdk.agentgateway._lob.get_user_auth", new_callable=AsyncMock - ) as mock_auth, patch("sap_cloud_sdk.agentgateway._lob.httpx.AsyncClient") as mock_http, patch( "sap_cloud_sdk.agentgateway._lob.streamable_http_client" ) as mock_stream, patch("sap_cloud_sdk.agentgateway._lob.ClientSession") as mock_session, ): - mock_auth.return_value = "Bearer user-token" - # Setup async context managers mock_http_instance = AsyncMock() mock_http.return_value.__aenter__.return_value = mock_http_instance @@ -477,15 +543,19 @@ async def test_calls_tool_with_user_auth(self): mock_session.return_value.__aenter__.return_value = mock_session_instance result = await call_mcp_tool_lob( - tool, "user-jwt", "tenant-sub", param1="value1" + tool, "user-auth-token", param1="value1" ) assert result == "Tool result" - mock_auth.assert_called_once_with("test-fragment", "user-jwt", "tenant-sub") mock_session_instance.call_tool.assert_called_once_with( "test-tool", {"param1": "value1"} ) + # Verify the Authorization header uses Bearer + raw token + mock_http.assert_called_once() + call_kwargs = mock_http.call_args.kwargs + assert call_kwargs["headers"]["Authorization"] == "Bearer user-auth-token" + @pytest.mark.asyncio async def test_returns_empty_string_when_no_content(self): """Return empty string when tool returns no content.""" @@ -502,17 +572,12 @@ async def test_returns_empty_string_when_no_content(self): mock_result.content = [] with ( - patch( - "sap_cloud_sdk.agentgateway._lob.get_user_auth", new_callable=AsyncMock - ) as mock_auth, patch("sap_cloud_sdk.agentgateway._lob.httpx.AsyncClient") as mock_http, patch( "sap_cloud_sdk.agentgateway._lob.streamable_http_client" ) as mock_stream, patch("sap_cloud_sdk.agentgateway._lob.ClientSession") as mock_session, ): - mock_auth.return_value = "Bearer user-token" - mock_http_instance = AsyncMock() mock_http.return_value.__aenter__.return_value = mock_http_instance @@ -527,6 +592,6 @@ async def test_returns_empty_string_when_no_content(self): mock_session_instance.call_tool = AsyncMock(return_value=mock_result) mock_session.return_value.__aenter__.return_value = mock_session_instance - result = await call_mcp_tool_lob(tool, "user-jwt", "tenant-sub") + result = await call_mcp_tool_lob(tool, "user-auth-token") assert result == "" diff --git a/tests/core/unit/telemetry/test_operation.py b/tests/core/unit/telemetry/test_operation.py index 1205626..651b3ec 100644 --- a/tests/core/unit/telemetry/test_operation.py +++ b/tests/core/unit/telemetry/test_operation.py @@ -180,5 +180,5 @@ def test_operation_count(self): """Test that we have the expected number of operations.""" all_operations = list(Operation) # 3 auditlog + 11 destination + 10 certificate + 10 fragment + 8 objectstore - # + 2 extensibility + 2 aicore + 23 dms + 2 agentgateway + 13 agent_memory = 84 - assert len(all_operations) == 84 + # + 2 extensibility + 2 aicore + 23 dms + 4 agentgateway + 13 agent_memory = 86 + assert len(all_operations) == 86