diff --git a/ballerina-interpreter/agent.bal b/ballerina-interpreter/agent.bal index d60e55a..7b68e95 100644 --- a/ballerina-interpreter/agent.bal +++ b/ballerina-interpreter/agent.bal @@ -350,37 +350,134 @@ function mapToHttpClientAuth(ClientAuthentication? auth) returns http:ClientAuth "bearer" => { return rest.cloneWithType(http:BearerTokenConfig); } + "api-key" => { + return error("API key authentication is not yet supported for MCP/webhook transport in the Ballerina interpreter"); + } "oauth2" => { - // record {string grantType;}|error oauth2Config = check rest.cloneWithType(); - // if oauth2Config is error { - // return error("OAuth2 authentication requires 'grantType' field", oauth2Config); - // } - - // var {grantType, ...oauth2ConfigRest} = oauth2Config; - - // match grantType.toLowerAscii() { - // "client_credentials" => { - // return oauth2ConfigRest.cloneWithType(http:OAuth2ClientCredentialsGrantConfig); - // } - // "password" => { - // return oauth2ConfigRest.cloneWithType(http:OAuth2PasswordGrantConfig); - // } - // "refresh_token" => { - // return oauth2ConfigRest.cloneWithType(http:OAuth2RefreshTokenGrantConfig); - // } - // "jwt" => { - // return oauth2Config.cloneWithType(http:OAuth2JwtBearerGrantConfig); - // } - // } - // panic error(string `Unsupported OAuth2 grant type: ${grantType}`); - return error("OAuth2 authentication not yet supported"); + OAuth2Config|error oauth2Config = rest.cloneWithType(); + if oauth2Config is error { + return error("Invalid OAuth2 authentication configuration", oauth2Config); + } + return buildOAuth2GrantConfig(oauth2Config); } "jwt" => { - // return rest.cloneWithType(http:JwtIssuerConfig); - return error("JWT authentication not yet supported"); + JwtAuthConfig|error jwtConfig = rest.cloneWithType(); + if jwtConfig is error { + return error("Invalid JWT authentication configuration", jwtConfig); + } + return buildJwtIssuerConfig(jwtConfig); } _ => { return error(string `Unsupported authentication type: ${'type}`); } } } + +type JwtAuthConfig record {| + string issuer; + string|string[] audience; + string signing_key; + string algorithm = "RS256"; + string key_id?; + string subject?; + map custom_claims?; + decimal expiry_seconds = 300; +|}; + +function buildJwtIssuerConfig(JwtAuthConfig jwtConfig) returns http:JwtIssuerConfig|error { + string algorithm = jwtConfig.algorithm; + boolean isHmac = algorithm == "HS256" || algorithm == "HS384" || algorithm == "HS512"; + json signatureKeyConfig = isHmac ? jwtConfig.signing_key : {keyFile: jwtConfig.signing_key}; + + map issuerConfig = { + issuer: jwtConfig.issuer, + audience: jwtConfig.audience, + expTime: jwtConfig.expiry_seconds, + signatureConfig: { + algorithm, + config: signatureKeyConfig + } + }; + + string? keyId = jwtConfig?.key_id; + if keyId is string { + issuerConfig["keyId"] = keyId; + } + string? subject = jwtConfig?.subject; + if subject is string { + issuerConfig["username"] = subject; + } + map? customClaims = jwtConfig?.custom_claims; + if customClaims is map { + issuerConfig["customClaims"] = customClaims; + } + + http:JwtIssuerConfig|error result = issuerConfig.cloneWithType(); + if result is error { + return error("Invalid JWT authentication configuration", result); + } + return result; +} + +type OAuth2Config record {| + string grant_type; + string token_url?; + string refresh_url?; + string client_id?; + string client_secret?; + string username?; + string password?; + string refresh_token?; + string assertion?; + string[] scopes?; +|}; + +function buildOAuth2GrantConfig(OAuth2Config cfg) returns http:OAuth2GrantConfig|error { + string grant = cfg.grant_type.toLowerAscii(); + match grant { + "client_credentials" => { + map grantConfig = {tokenUrl: cfg?.token_url, clientId: cfg?.client_id, clientSecret: cfg?.client_secret}; + addScopes(grantConfig, cfg?.scopes); + return wrapOAuth2(grantConfig.cloneWithType(http:OAuth2ClientCredentialsGrantConfig)); + } + "password" => { + map grantConfig = {tokenUrl: cfg?.token_url, username: cfg?.username, password: cfg?.password}; + addOptional(grantConfig, "clientId", cfg?.client_id); + addOptional(grantConfig, "clientSecret", cfg?.client_secret); + addScopes(grantConfig, cfg?.scopes); + return wrapOAuth2(grantConfig.cloneWithType(http:OAuth2PasswordGrantConfig)); + } + "refresh_token" => { + map grantConfig = {refreshUrl: cfg?.refresh_url, refreshToken: cfg?.refresh_token, clientId: cfg?.client_id, clientSecret: cfg?.client_secret}; + addScopes(grantConfig, cfg?.scopes); + return wrapOAuth2(grantConfig.cloneWithType(http:OAuth2RefreshTokenGrantConfig)); + } + "jwt_bearer" => { + map grantConfig = {tokenUrl: cfg?.token_url, assertion: cfg?.assertion}; + addOptional(grantConfig, "clientId", cfg?.client_id); + addOptional(grantConfig, "clientSecret", cfg?.client_secret); + addScopes(grantConfig, cfg?.scopes); + return wrapOAuth2(grantConfig.cloneWithType(http:OAuth2JwtBearerGrantConfig)); + } + } + return error(string `Unsupported OAuth2 grant type: ${cfg.grant_type}`); +} + +function wrapOAuth2(http:OAuth2GrantConfig|error result) returns http:OAuth2GrantConfig|error { + if result is error { + return error("Invalid OAuth2 authentication configuration", result); + } + return result; +} + +function addScopes(map target, string[]? scopes) { + if scopes is string[] { + target["scopes"] = scopes; + } +} + +function addOptional(map target, string key, string? value) { + if value is string { + target[key] = value; + } +} diff --git a/ballerina-interpreter/parser.bal b/ballerina-interpreter/parser.bal index 8758272..90ff7f9 100644 --- a/ballerina-interpreter/parser.bal +++ b/ballerina-interpreter/parser.bal @@ -31,6 +31,8 @@ function parseAfm(string content) returns AFMRecord|error { body = resolvedContent; } + check validateMetadataAuthentication(metadata); + // Extract Role and Instructions sections string[] bodyLines = splitLines(body); string role = ""; @@ -305,6 +307,158 @@ function authenticationContainsHttpVariable(ClientAuthentication? authentication return false; } +final readonly & string[] RECOGNIZED_AUTH_TYPES = ["bearer", "basic", "api-key", "jwt", "oauth2"]; + +function validateMetadataAuthentication(AgentMetadata? metadata) returns error? { + if metadata is () { + return; + } + + Model? model = metadata.model; + if model is Model { + check validateAuthentication(model.authentication); + } + + Tools? tools = metadata.tools; + if tools is Tools { + MCPServer[]? mcp = tools.mcp; + if mcp is MCPServer[] { + foreach MCPServer server in mcp { + Transport transport = server.transport; + if transport is HttpTransport { + check validateAuthentication(transport.authentication); + } + } + } + } + + Interface[]? interfaces = metadata.interfaces; + if interfaces is Interface[] { + foreach Interface interface in interfaces { + if interface is WebhookInterface { + check validateAuthentication(interface.subscription.authentication); + } + } + } +} + +function validateAuthentication(ClientAuthentication? auth) returns error? { + if auth is () { + return; + } + + string authType = auth.'type.toLowerAscii(); + if RECOGNIZED_AUTH_TYPES.indexOf(authType) is () { + return error(string `unknown authentication type '${auth.'type}'. Supported types: bearer, basic, api-key, jwt, oauth2`); + } + + if authType == "oauth2" { + return validateOAuth2(auth); + } + + string[]? allowed = allowedAuthFields(authType); + if allowed is () { + return; + } + + string[] provided = from string key in auth.keys() where key != "type" select key; + + foreach string required in requiredAuthFields(authType) { + if provided.indexOf(required) is () { + return error(string `type '${authType}' requires '${required}' field`); + } + } + + foreach string fieldName in provided { + if allowed.indexOf(fieldName) is () { + return error(string `type '${authType}' does not support '${fieldName}' field`); + } + } +} + +function allowedAuthFields(string authType) returns string[]? { + match authType { + "bearer" => { + return ["token"]; + } + "basic" => { + return ["username", "password"]; + } + "api-key" => { + return ["api_key", "header_name"]; + } + "jwt" => { + return ["issuer", "audience", "signing_key", "algorithm", "key_id", "subject", "custom_claims", "expiry_seconds"]; + } + } + return (); +} + +function requiredAuthFields(string authType) returns string[] { + match authType { + "bearer" => { + return ["token"]; + } + "basic" => { + return ["username", "password"]; + } + "api-key" => { + return ["api_key"]; + } + "jwt" => { + return ["issuer", "audience", "signing_key"]; + } + } + return []; +} + +function validateOAuth2(ClientAuthentication auth) returns error? { + anydata grantTypeValue = auth["grant_type"]; + if grantTypeValue !is string { + return error("type 'oauth2' requires 'grant_type' field"); + } + + string grant = grantTypeValue.toLowerAscii(); + [string[], string[]]? grantFields = oauth2GrantFields(grant); + if grantFields is () { + return error("oauth2 grant_type '" + grantTypeValue + "' is not supported. Supported grant types: client_credentials, password, refresh_token, jwt_bearer"); + } + + [string[], string[]] [required, optional] = grantFields; + string[] allowed = ["grant_type", ...required, ...optional]; + string[] provided = from string key in auth.keys() where key != "type" select key; + + foreach string req in required { + if provided.indexOf(req) is () { + return error("oauth2 grant_type '" + grant + "' requires '" + req + "' field"); + } + } + + foreach string fieldName in provided { + if allowed.indexOf(fieldName) is () { + return error("oauth2 grant_type '" + grant + "' does not support '" + fieldName + "' field"); + } + } +} + +function oauth2GrantFields(string grant) returns [string[], string[]]? { + match grant { + "client_credentials" => { + return [["token_url", "client_id", "client_secret"], ["scopes"]]; + } + "password" => { + return [["token_url", "username", "password", "client_id", "client_secret"], ["scopes"]]; + } + "refresh_token" => { + return [["refresh_url", "refresh_token", "client_id", "client_secret"], ["scopes"]]; + } + "jwt_bearer" => { + return [["token_url", "assertion"], ["client_id", "client_secret", "scopes"]]; + } + } + return (); +} + function signatureContainsHttpVariable(Signature signature) returns boolean => jsonSchemaContainsHttpVariable(signature.input) || jsonSchemaContainsHttpVariable(signature.output); diff --git a/ballerina-interpreter/tests/main_test.bal b/ballerina-interpreter/tests/main_test.bal index e9ad324..fb37d56 100644 --- a/ballerina-interpreter/tests/main_test.bal +++ b/ballerina-interpreter/tests/main_test.bal @@ -1162,29 +1162,69 @@ function testMapToHttpClientAuthBearer() returns error? { } @test:Config -function testMapToHttpClientAuthOAuth2NotSupported() { +function testMapToHttpClientAuthOAuth2ClientCredentials() returns error? { ClientAuthentication auth = { - 'type: "oauth2" + 'type: "oauth2", + "grant_type": "client_credentials", + "token_url": "https://auth.example.com/token", + "client_id": "id", + "client_secret": "secret", + "scopes": ["read", "write"] }; - http:ClientAuthConfig|error? result = mapToHttpClientAuth(auth); - if result is http:ClientAuthConfig? { - test:assertFail("Expected error for OAuth2 authentication"); - } - test:assertEquals(result.message(), "OAuth2 authentication not yet supported"); + http:ClientAuthConfig? result = check mapToHttpClientAuth(auth); + test:assertTrue(result is http:OAuth2ClientCredentialsGrantConfig); + http:OAuth2ClientCredentialsGrantConfig grantConfig = result; + test:assertEquals(grantConfig.tokenUrl, "https://auth.example.com/token"); + test:assertEquals(grantConfig.clientId, "id"); } @test:Config -function testMapToHttpClientAuthJWTNotSupported() { +function testMapToHttpClientAuthOAuth2RefreshToken() returns error? { ClientAuthentication auth = { - 'type: "jwt" + 'type: "oauth2", + "grant_type": "refresh_token", + "refresh_url": "https://auth.example.com/token", + "refresh_token": "rt", + "client_id": "id", + "client_secret": "secret" }; - http:ClientAuthConfig|error? result = mapToHttpClientAuth(auth); - if result is http:ClientAuthConfig? { - test:assertFail("Expected error for JWT authentication"); - } - test:assertEquals(result.message(), "JWT authentication not yet supported"); + http:ClientAuthConfig? result = check mapToHttpClientAuth(auth); + test:assertTrue(result is http:OAuth2RefreshTokenGrantConfig); + http:OAuth2RefreshTokenGrantConfig grantConfig = result; + test:assertEquals(grantConfig.refreshUrl, "https://auth.example.com/token"); + test:assertEquals(grantConfig.refreshToken, "rt"); +} + +@test:Config +function testMapToHttpClientAuthJwtHmac() returns error? { + ClientAuthentication auth = { + 'type: "jwt", + "issuer": "afm-agent", + "audience": "https://api.example.com", + "signing_key": "shared-secret", + "algorithm": "HS256" + }; + + http:ClientAuthConfig? result = check mapToHttpClientAuth(auth); + test:assertTrue(result is http:JwtIssuerConfig); + http:JwtIssuerConfig issuerConfig = result; + test:assertEquals(issuerConfig.issuer, "afm-agent"); + test:assertEquals(issuerConfig.audience, "https://api.example.com"); +} + +@test:Config +function testMapToHttpClientAuthJwtRs256() returns error? { + ClientAuthentication auth = { + 'type: "jwt", + "issuer": "afm-agent", + "audience": "https://api.example.com", + "signing_key": "/path/to/key.pem" + }; + + http:ClientAuthConfig? result = check mapToHttpClientAuth(auth); + test:assertTrue(result is http:JwtIssuerConfig); } @test:Config @@ -1199,3 +1239,170 @@ function testMapToHttpClientAuthUnsupportedType() { } test:assertEquals(result.message(), "Unsupported authentication type: custom-auth"); } + +@test:Config +function testMapToHttpClientAuthApiKeyNotSupported() { + ClientAuthentication auth = { + 'type: "api-key", + "api_key": "test-key" + }; + + http:ClientAuthConfig|error? result = mapToHttpClientAuth(auth); + if result is http:ClientAuthConfig? { + test:assertFail("Expected error for api-key on HTTP transport"); + } + test:assertEquals(result.message(), + "API key authentication is not yet supported for MCP/webhook transport in the Ballerina interpreter"); +} + + +@test:Config +function testValidateAuthenticationNull() returns error? { + error? result = validateAuthentication(()); + test:assertTrue(result is ()); +} + +@test:Config +function testValidateAuthenticationBearerValid() returns error? { + ClientAuthentication auth = {'type: "bearer", "token": "t"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is ()); +} + +@test:Config +function testValidateAuthenticationBasicValid() returns error? { + ClientAuthentication auth = {'type: "basic", "username": "u", "password": "p"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is ()); +} + +@test:Config +function testValidateAuthenticationApiKeyValid() returns error? { + ClientAuthentication auth = {'type: "api-key", "api_key": "k", "header_name": "X-API-Key"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is ()); +} + +@test:Config +function testValidateAuthenticationCaseInsensitive() returns error? { + ClientAuthentication auth = {'type: "Bearer", "token": "t"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is ()); +} + +@test:Config +function testValidateAuthenticationJwtValid() returns error? { + ClientAuthentication auth = { + 'type: "jwt", + "issuer": "afm-agent", + "audience": "https://api.example.com", + "signing_key": "secret" + }; + error? result = validateAuthentication(auth); + test:assertTrue(result is ()); +} + +@test:Config +function testValidateAuthenticationJwtMissingField() { + ClientAuthentication auth = {'type: "jwt", "issuer": "afm-agent"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is error); + test:assertEquals((result).message(), "type 'jwt' requires 'audience' field"); +} + +@test:Config +function testValidateAuthenticationOAuth2Valid() returns error? { + ClientAuthentication auth = { + 'type: "oauth2", + "grant_type": "client_credentials", + "token_url": "https://auth.example.com/token", + "client_id": "id", + "client_secret": "secret", + "scopes": ["read"] + }; + error? result = validateAuthentication(auth); + test:assertTrue(result is ()); +} + +@test:Config +function testValidateAuthenticationOAuth2MissingGrantType() { + ClientAuthentication auth = {'type: "oauth2", "token_url": "u"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is error); + test:assertEquals((result).message(), "type 'oauth2' requires 'grant_type' field"); +} + +@test:Config +function testValidateAuthenticationOAuth2UnknownGrantType() { + ClientAuthentication auth = {'type: "oauth2", "grant_type": "device_code", "token_url": "u"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is error); + test:assertTrue((result).message().includes("grant_type 'device_code' is not supported")); +} + +@test:Config +function testValidateAuthenticationOAuth2MissingRequiredField() { + ClientAuthentication auth = { + 'type: "oauth2", + "grant_type": "client_credentials", + "token_url": "u", + "client_id": "id" + }; + error? result = validateAuthentication(auth); + test:assertTrue(result is error); + test:assertEquals((result).message(), "oauth2 grant_type 'client_credentials' requires 'client_secret' field"); +} + +@test:Config +function testValidateAuthenticationOAuth2FieldNotAllowed() { + ClientAuthentication auth = { + 'type: "oauth2", + "grant_type": "client_credentials", + "token_url": "u", + "client_id": "id", + "client_secret": "secret", + "refresh_token": "rt" + }; + error? result = validateAuthentication(auth); + test:assertTrue(result is error); + test:assertEquals((result).message(), "oauth2 grant_type 'client_credentials' does not support 'refresh_token' field"); +} + +@test:Config +function testValidateAuthenticationUnknownType() { + ClientAuthentication auth = {'type: "token", "token": "t"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is error); + test:assertTrue((result).message().includes("unknown authentication type 'token'")); +} + +@test:Config +function testValidateAuthenticationBearerMissingToken() { + ClientAuthentication auth = {'type: "bearer"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is error); + test:assertEquals((result).message(), "type 'bearer' requires 'token' field"); +} + +@test:Config +function testValidateAuthenticationUnknownField() { + ClientAuthentication auth = {'type: "bearer", "token": "t", "username": "u"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is error); + test:assertEquals((result).message(), "type 'bearer' does not support 'username' field"); +} + +@test:Config +function testParseAfmRejectsInvalidAuthentication() { + string content = "---\n" + + "model:\n" + + " provider: openai\n" + + " name: gpt-4\n" + + " authentication:\n" + + " type: bearer\n" + + "---\n\n" + + "# Role\nRole\n\n# Instructions\nInstructions\n"; + AFMRecord|error result = parseAfm(content); + test:assertTrue(result is error); + test:assertEquals((result).message(), "type 'bearer' requires 'token' field"); +} diff --git a/python-interpreter/packages/afm-core/src/afm/models.py b/python-interpreter/packages/afm-core/src/afm/models.py index 101f6d9..0245f3c 100644 --- a/python-interpreter/packages/afm-core/src/afm/models.py +++ b/python-interpreter/packages/afm-core/src/afm/models.py @@ -18,7 +18,7 @@ from enum import Enum from pathlib import Path -from typing import Annotated, Literal, Self +from typing import Annotated, Any, Literal, Self from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -30,6 +30,75 @@ class Provider(BaseModel): url: str | None = None +RECOGNIZED_AUTH_TYPES = ("bearer", "basic", "api-key", "jwt", "oauth2") + + +_JWT_ALLOWED_FIELDS = { + "issuer", + "audience", + "signing_key", + "algorithm", + "key_id", + "subject", + "custom_claims", + "expiry_seconds", +} +_AUTH_REQUIRED_FIELDS: dict[str, set[str]] = { + "bearer": {"token"}, + "basic": {"username", "password"}, + "api-key": {"api_key"}, + "jwt": {"issuer", "audience", "signing_key"}, +} +_AUTH_ALLOWED_FIELDS: dict[str, set[str]] = { + "bearer": {"token"}, + "basic": {"username", "password"}, + "api-key": {"api_key", "header_name"}, + "jwt": _JWT_ALLOWED_FIELDS, +} + +_OAUTH2_GRANTS: dict[str, dict[str, set[str]]] = { + "client_credentials": { + "required": {"token_url", "client_id", "client_secret"}, + "optional": {"scopes"}, + }, + "password": { + "required": {"token_url", "username", "password", "client_id", "client_secret"}, + "optional": {"scopes"}, + }, + "refresh_token": { + "required": {"refresh_url", "refresh_token", "client_id", "client_secret"}, + "optional": {"scopes"}, + }, + "jwt_bearer": { + "required": {"token_url", "assertion"}, + "optional": {"client_id", "client_secret", "scopes"}, + }, +} +_AUTH_CREDENTIAL_FIELDS = ( + "token", + "username", + "password", + "api_key", + "header_name", + "issuer", + "audience", + "signing_key", + "algorithm", + "key_id", + "subject", + "custom_claims", + "expiry_seconds", + "grant_type", + "token_url", + "refresh_url", + "client_id", + "client_secret", + "refresh_token", + "assertion", + "scopes", +) + + class ClientAuthentication(BaseModel): model_config = ConfigDict(extra="allow") @@ -38,23 +107,86 @@ class ClientAuthentication(BaseModel): username: str | None = None password: str | None = None api_key: str | None = None + header_name: str | None = None + issuer: str | None = None + audience: str | list[str] | None = None + signing_key: str | None = None + algorithm: str | None = None + key_id: str | None = None + subject: str | None = None + custom_claims: dict[str, Any] | None = None + expiry_seconds: int | None = None + grant_type: str | None = None + token_url: str | None = None + refresh_url: str | None = None + client_id: str | None = None + client_secret: str | None = None + refresh_token: str | None = None + assertion: str | None = None + scopes: list[str] | None = None @model_validator(mode="after") def validate_type_fields(self) -> Self: - match self.type.lower(): - case "bearer": - if self.token is None: - raise ValueError("type 'bearer' requires 'token' field") - case "basic": - if self.username is None or self.password is None: - raise ValueError( - "type 'basic' requires 'username' and 'password' fields" - ) - case "api-key": - if self.api_key is None: - raise ValueError("type 'api-key' requires 'api_key' field") + auth_type = self.type.lower() + + if auth_type not in RECOGNIZED_AUTH_TYPES: + supported = ", ".join(RECOGNIZED_AUTH_TYPES) + raise ValueError( + f"unknown authentication type '{self.type}'. " + f"Supported types: {supported}" + ) + + provided = { + name for name in _AUTH_CREDENTIAL_FIELDS if getattr(self, name) is not None + } + provided |= set(self.model_extra or {}) + + if auth_type == "oauth2": + self._validate_oauth2_fields(provided) + return self + + missing = _AUTH_REQUIRED_FIELDS[auth_type] - provided + if missing: + fields = ", ".join(f"'{name}'" for name in sorted(missing)) + suffix = "field" if len(missing) == 1 else "fields" + raise ValueError(f"type '{auth_type}' requires {fields} {suffix}") + + unknown = provided - _AUTH_ALLOWED_FIELDS[auth_type] + if unknown: + fields = ", ".join(f"'{name}'" for name in sorted(unknown)) + suffix = "field" if len(unknown) == 1 else "fields" + raise ValueError(f"type '{auth_type}' does not support {fields} {suffix}") + return self + def _validate_oauth2_fields(self, provided: set[str]) -> None: + if self.grant_type is None: + raise ValueError("type 'oauth2' requires 'grant_type' field") + + grant = self.grant_type.lower() + spec = _OAUTH2_GRANTS.get(grant) + if spec is None: + supported = ", ".join(_OAUTH2_GRANTS) + raise ValueError( + f"oauth2 grant_type '{self.grant_type}' is not supported. " + f"Supported grant types: {supported}" + ) + + missing = spec["required"] - provided + if missing: + fields = ", ".join(f"'{name}'" for name in sorted(missing)) + suffix = "field" if len(missing) == 1 else "fields" + raise ValueError(f"oauth2 grant_type '{grant}' requires {fields} {suffix}") + + allowed = {"grant_type"} | spec["required"] | spec["optional"] + unknown = provided - allowed + if unknown: + fields = ", ".join(f"'{name}'" for name in sorted(unknown)) + suffix = "field" if len(unknown) == 1 else "fields" + raise ValueError( + f"oauth2 grant_type '{grant}' does not support {fields} {suffix}" + ) + class Model(BaseModel): model_config = ConfigDict(extra="forbid") diff --git a/python-interpreter/packages/afm-core/tests/test_parser.py b/python-interpreter/packages/afm-core/tests/test_parser.py index 0b517ff..6f294dc 100644 --- a/python-interpreter/packages/afm-core/tests/test_parser.py +++ b/python-interpreter/packages/afm-core/tests/test_parser.py @@ -17,9 +17,11 @@ from pathlib import Path import pytest +from pydantic import ValidationError from afm.exceptions import AFMParseError, AFMValidationError, VariableResolutionError from afm.models import ( + ClientAuthentication, ConsoleChatInterface, HttpTransport, StdioTransport, @@ -493,3 +495,192 @@ def test_parse_afm_default_behavior_resolves_env(self, monkeypatch) -> None: assert result.metadata.model is not None assert result.metadata.model.authentication is not None assert result.metadata.model.authentication.token == "secret-token-123" + + +class TestClientAuthenticationValidation: + + def test_bearer_valid(self) -> None: + auth = ClientAuthentication(type="bearer", token="t") + assert auth.type == "bearer" + assert auth.token == "t" + + def test_basic_valid(self) -> None: + auth = ClientAuthentication(type="basic", username="u", password="p") + assert auth.username == "u" + assert auth.password == "p" + + def test_api_key_valid(self) -> None: + auth = ClientAuthentication(type="api-key", api_key="k") + assert auth.api_key == "k" + assert auth.header_name is None + + def test_api_key_with_header_name(self) -> None: + auth = ClientAuthentication(type="api-key", api_key="k", header_name="X-API-Key") + assert auth.header_name == "X-API-Key" + + def test_type_is_case_insensitive(self) -> None: + auth = ClientAuthentication(type="Bearer", token="t") + assert auth.token == "t" + + def test_jwt_valid(self) -> None: + auth = ClientAuthentication( + type="jwt", + issuer="afm-agent", + audience="https://api.example.com", + signing_key="secret", + ) + assert auth.issuer == "afm-agent" + assert auth.audience == "https://api.example.com" + + def test_jwt_audience_list(self) -> None: + auth = ClientAuthentication( + type="jwt", issuer="i", audience=["a", "b"], signing_key="s" + ) + assert auth.audience == ["a", "b"] + + def test_jwt_missing_signing_key_rejected(self) -> None: + with pytest.raises( + ValidationError, match="type 'jwt' requires 'signing_key'" + ): + ClientAuthentication(type="jwt", issuer="i", audience="a") + + def test_jwt_unknown_field_rejected(self) -> None: + with pytest.raises(ValidationError, match="does not support"): + ClientAuthentication( + type="jwt", issuer="i", audience="a", signing_key="s", token="x" + ) + + def test_oauth2_client_credentials_valid(self) -> None: + auth = ClientAuthentication( + type="oauth2", + grant_type="client_credentials", + token_url="https://auth.example.com/token", + client_id="id", + client_secret="secret", + scopes=["read", "write"], + ) + assert auth.grant_type == "client_credentials" + assert auth.scopes == ["read", "write"] + + def test_oauth2_password_valid(self) -> None: + auth = ClientAuthentication( + type="oauth2", + grant_type="password", + token_url="https://auth.example.com/token", + username="u", + password="p", + client_id="id", + client_secret="secret", + ) + assert auth.grant_type == "password" + + def test_oauth2_refresh_token_valid(self) -> None: + auth = ClientAuthentication( + type="oauth2", + grant_type="refresh_token", + refresh_url="https://auth.example.com/token", + refresh_token="rt", + client_id="id", + client_secret="secret", + ) + assert auth.grant_type == "refresh_token" + + def test_oauth2_jwt_bearer_valid(self) -> None: + auth = ClientAuthentication( + type="oauth2", + grant_type="jwt_bearer", + token_url="https://auth.example.com/token", + assertion="signed.jwt.token", + ) + assert auth.grant_type == "jwt_bearer" + + def test_oauth2_grant_type_case_insensitive(self) -> None: + auth = ClientAuthentication( + type="oauth2", + grant_type="Client_Credentials", + token_url="u", + client_id="id", + client_secret="secret", + ) + assert auth.grant_type == "Client_Credentials" + + def test_oauth2_missing_grant_type_rejected(self) -> None: + with pytest.raises( + ValidationError, match="type 'oauth2' requires 'grant_type'" + ): + ClientAuthentication(type="oauth2", token_url="u") + + def test_oauth2_unknown_grant_type_rejected(self) -> None: + with pytest.raises( + ValidationError, match="grant_type 'device_code' is not supported" + ): + ClientAuthentication( + type="oauth2", grant_type="device_code", token_url="u" + ) + + def test_oauth2_missing_required_field_rejected(self) -> None: + with pytest.raises( + ValidationError, + match="grant_type 'client_credentials' requires 'client_secret'", + ): + ClientAuthentication( + type="oauth2", + grant_type="client_credentials", + token_url="u", + client_id="id", + ) + + def test_oauth2_field_not_allowed_for_grant_rejected(self) -> None: + with pytest.raises(ValidationError, match="does not support 'refresh_token'"): + ClientAuthentication( + type="oauth2", + grant_type="client_credentials", + token_url="u", + client_id="id", + client_secret="secret", + refresh_token="rt", + ) + + def test_unknown_type_rejected(self) -> None: + with pytest.raises(ValidationError, match="unknown authentication type 'token'"): + ClientAuthentication(type="token", token="t") + + def test_bearer_missing_token_rejected(self) -> None: + with pytest.raises(ValidationError, match="type 'bearer' requires 'token'"): + ClientAuthentication(type="bearer") + + def test_basic_missing_password_rejected(self) -> None: + with pytest.raises(ValidationError, match="type 'basic' requires 'password'"): + ClientAuthentication(type="basic", username="u") + + def test_api_key_missing_key_rejected(self) -> None: + with pytest.raises(ValidationError, match="type 'api-key' requires 'api_key'"): + ClientAuthentication(type="api-key") + + def test_unknown_field_rejected(self) -> None: + with pytest.raises( + ValidationError, match="type 'bearer' does not support 'username'" + ): + ClientAuthentication(type="bearer", token="t", username="u") + + def test_typo_field_rejected(self) -> None: + with pytest.raises(ValidationError): + ClientAuthentication(type="api-key", api_key="k", headername="X") + + def test_invalid_auth_fails_at_parse_time(self) -> None: + content = """--- +model: + provider: openai + name: gpt-4 + authentication: + type: bearer +--- + +# Role +Role + +# Instructions +Instructions +""" + with pytest.raises(AFMValidationError): + parse_afm(content) diff --git a/python-interpreter/packages/afm-langchain/pyproject.toml b/python-interpreter/packages/afm-langchain/pyproject.toml index f6d7ffd..6e616ac 100644 --- a/python-interpreter/packages/afm-langchain/pyproject.toml +++ b/python-interpreter/packages/afm-langchain/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "langchain-anthropic>=1.3.1", "mcp>=1.26.0", "langchain-mcp-adapters>=0.2.1", + "pyjwt[crypto]>=2.10.0", ] [project.entry-points."afm.runner"] diff --git a/python-interpreter/packages/afm-langchain/src/afm_langchain/tools/mcp.py b/python-interpreter/packages/afm-langchain/src/afm_langchain/tools/mcp.py index 709a065..65e97a5 100644 --- a/python-interpreter/packages/afm-langchain/src/afm_langchain/tools/mcp.py +++ b/python-interpreter/packages/afm-langchain/src/afm_langchain/tools/mcp.py @@ -17,8 +17,11 @@ from __future__ import annotations import logging +import time +from pathlib import Path import httpx +import jwt from afm.exceptions import ( MCPAuthenticationError, MCPConnectionError, @@ -58,6 +61,204 @@ def auth_flow(self, request: httpx.Request): yield request +_HMAC_JWT_ALGORITHMS = {"HS256", "HS384", "HS512"} + + +class JwtAuth(httpx.Auth): + + def __init__( + self, + *, + issuer: str, + audience: str | list[str], + signing_key: str, + algorithm: str = "RS256", + key_id: str | None = None, + subject: str | None = None, + custom_claims: dict | None = None, + expiry_seconds: int = 300, + ) -> None: + self.issuer = issuer + self.audience = audience + self.signing_key = signing_key + self.algorithm = algorithm + self.key_id = key_id + self.subject = subject + self.custom_claims = custom_claims or {} + self.expiry_seconds = expiry_seconds + + def _resolve_key(self) -> str: + if self.algorithm.upper() in _HMAC_JWT_ALGORITHMS: + return self.signing_key + try: + return Path(self.signing_key).read_text() + except OSError as e: + raise MCPAuthenticationError( + f"Could not read JWT signing key file '{self.signing_key}': {e}" + ) from e + + def sign(self) -> str: + now = int(time.time()) + claims: dict = { + "iss": self.issuer, + "aud": self.audience, + "iat": now, + "nbf": now, + "exp": now + self.expiry_seconds, + } + if self.subject is not None: + claims["sub"] = self.subject + claims.update(self.custom_claims) + headers = {"kid": self.key_id} if self.key_id else None + try: + return jwt.encode( + claims, self._resolve_key(), algorithm=self.algorithm, headers=headers + ) + except MCPAuthenticationError: + raise + except Exception as e: + raise MCPAuthenticationError(f"Failed to sign JWT: {e}") from e + + def auth_flow(self, request: httpx.Request): + request.headers["Authorization"] = f"Bearer {self.sign()}" + yield request + + +_JWT_BEARER_GRANT_URN = "urn:ietf:params:oauth:grant-type:jwt-bearer" + + +class OAuth2Auth(httpx.Auth): + + def __init__( + self, + *, + grant_type: str, + token_url: str | None = None, + refresh_url: str | None = None, + client_id: str | None = None, + client_secret: str | None = None, + username: str | None = None, + password: str | None = None, + refresh_token: str | None = None, + assertion: str | None = None, + scopes: list[str] | None = None, + ) -> None: + self.grant_type = grant_type.lower() + self.token_url = token_url + self.refresh_url = refresh_url + self.client_id = client_id + self.client_secret = client_secret + self.username = username + self.password = password + self.refresh_token = refresh_token + self.assertion = assertion + self.scopes = scopes + self._token: str | None = None + self._expires_at: float = 0.0 + + def _token_request(self) -> tuple[str, dict[str, str], tuple[str, str] | None]: + url: str | None + data: dict[str, str] + if self.grant_type == "client_credentials": + url = self.token_url + data = {"grant_type": "client_credentials"} + elif self.grant_type == "password": + url = self.token_url + data = { + "grant_type": "password", + "username": self.username or "", + "password": self.password or "", + } + elif self.grant_type == "refresh_token": + url = self.refresh_url + data = { + "grant_type": "refresh_token", + "refresh_token": self.refresh_token or "", + } + elif self.grant_type == "jwt_bearer": + url = self.token_url + data = { + "grant_type": _JWT_BEARER_GRANT_URN, + "assertion": self.assertion or "", + } + else: + raise MCPAuthenticationError( + f"Unsupported oauth2 grant_type: {self.grant_type}" + ) + + if url is None: + raise MCPAuthenticationError( + f"oauth2 grant_type '{self.grant_type}' is missing its token URL" + ) + if self.scopes: + data["scope"] = " ".join(self.scopes) + + basic = ( + (self.client_id, self.client_secret) + if self.client_id is not None and self.client_secret is not None + else None + ) + return url, data, basic + + def _store_token(self, payload: dict) -> str: + token = payload.get("access_token") + if not token: + raise MCPAuthenticationError( + "OAuth2 token response did not contain 'access_token'" + ) + expires_in = float(payload.get("expires_in", 3600)) + self._expires_at = time.time() + expires_in - 30 + self._token = token + return token + + def _cached_token(self) -> str | None: + if self._token is not None and time.time() < self._expires_at: + return self._token + return None + + def _fetch_token_sync(self) -> str: + url, data, basic = self._token_request() + try: + resp = httpx.post( + url, + data=data, + auth=basic, + headers={"Accept": "application/json"}, + timeout=30.0, + ) + resp.raise_for_status() + payload = resp.json() + except Exception as e: + raise MCPAuthenticationError(f"OAuth2 token request failed: {e}") from e + return self._store_token(payload) + + async def _fetch_token_async(self) -> str: + url, data, basic = self._token_request() + try: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + url, + data=data, + auth=basic, + headers={"Accept": "application/json"}, + ) + resp.raise_for_status() + payload = resp.json() + except Exception as e: + raise MCPAuthenticationError(f"OAuth2 token request failed: {e}") from e + return self._store_token(payload) + + def sync_auth_flow(self, request: httpx.Request): + token = self._cached_token() or self._fetch_token_sync() + request.headers["Authorization"] = f"Bearer {token}" + yield request + + async def async_auth_flow(self, request: httpx.Request): + token = self._cached_token() or await self._fetch_token_async() + request.headers["Authorization"] = f"Bearer {token}" + yield request + + def build_httpx_auth(auth: ClientAuthentication | None) -> httpx.Auth | None: if auth is None: return None @@ -79,11 +280,40 @@ def build_httpx_auth(auth: ClientAuthentication | None) -> httpx.Auth | None: elif auth_type == "api-key": if auth.api_key is None: raise MCPAuthenticationError("API key auth requires 'api_key' field") - return ApiKeyAuth(auth.api_key) + return ApiKeyAuth(auth.api_key, header_name=auth.header_name or "Authorization") + + elif auth_type == "jwt": + if auth.issuer is None or auth.audience is None or auth.signing_key is None: + raise MCPAuthenticationError( + "JWT auth requires 'issuer', 'audience', and 'signing_key' fields" + ) + return JwtAuth( + issuer=auth.issuer, + audience=auth.audience, + signing_key=auth.signing_key, + algorithm=auth.algorithm or "RS256", + key_id=auth.key_id, + subject=auth.subject, + custom_claims=auth.custom_claims, + expiry_seconds=( + auth.expiry_seconds if auth.expiry_seconds is not None else 300 + ), + ) - elif auth_type in ("oauth2", "jwt"): - raise MCPAuthenticationError( - f"Authentication type '{auth_type}' not yet supported" + elif auth_type == "oauth2": + if auth.grant_type is None: + raise MCPAuthenticationError("OAuth2 auth requires 'grant_type' field") + return OAuth2Auth( + grant_type=auth.grant_type, + token_url=auth.token_url, + refresh_url=auth.refresh_url, + client_id=auth.client_id, + client_secret=auth.client_secret, + username=auth.username, + password=auth.password, + refresh_token=auth.refresh_token, + assertion=auth.assertion, + scopes=auth.scopes, ) else: diff --git a/python-interpreter/packages/afm-langchain/tests/test_mcp.py b/python-interpreter/packages/afm-langchain/tests/test_mcp.py index fe6000f..502a989 100644 --- a/python-interpreter/packages/afm-langchain/tests/test_mcp.py +++ b/python-interpreter/packages/afm-langchain/tests/test_mcp.py @@ -14,10 +14,12 @@ # specific language governing permissions and limitations # under the License. +import time from typing import cast from unittest.mock import AsyncMock, MagicMock, patch import httpx +import jwt import pytest from langchain_core.tools import BaseTool from langchain_mcp_adapters.sessions import StdioConnection @@ -36,8 +38,10 @@ from afm_langchain.tools.mcp import ( ApiKeyAuth, BearerAuth, + JwtAuth, MCPClient, MCPManager, + OAuth2Auth, build_httpx_auth, filter_tools, ) @@ -57,9 +61,21 @@ def make_mcp_server( elif auth_type == "api-key": auth = ClientAuthentication(type="api-key", api_key="test-api-key") elif auth_type == "oauth2": - auth = ClientAuthentication(type="oauth2") + auth = ClientAuthentication( + type="oauth2", + grant_type="client_credentials", + token_url="https://auth.example.com/token", + client_id="id", + client_secret="secret", + ) elif auth_type == "jwt": - auth = ClientAuthentication(type="jwt") + auth = ClientAuthentication( + type="jwt", + issuer="afm-agent", + audience="https://api.example.com", + signing_key="secret", + algorithm="HS256", + ) return MCPServer( name=name, @@ -133,6 +149,185 @@ def test_api_key_auth_returns_api_key_auth_instance(self): assert isinstance(result, ApiKeyAuth) assert result.api_key == "my-api-key" + def test_api_key_auth_defaults_to_authorization_header(self): + auth = ClientAuthentication(type="api-key", api_key="my-api-key") + result = build_httpx_auth(auth) + assert isinstance(result, ApiKeyAuth) + assert result.header_name == "Authorization" + + def test_api_key_auth_uses_custom_header_name(self): + auth = ClientAuthentication( + type="api-key", api_key="my-api-key", header_name="X-API-Key" + ) + result = build_httpx_auth(auth) + assert isinstance(result, ApiKeyAuth) + assert result.api_key == "my-api-key" + assert result.header_name == "X-API-Key" + + def test_jwt_auth_returns_jwt_auth_instance(self): + auth = ClientAuthentication( + type="jwt", + issuer="afm-agent", + audience="https://api.example.com", + signing_key="secret", + algorithm="HS256", + ) + result = build_httpx_auth(auth) + assert isinstance(result, JwtAuth) + assert result.issuer == "afm-agent" + assert result.algorithm == "HS256" + + def test_jwt_auth_defaults_to_rs256(self): + auth = ClientAuthentication( + type="jwt", issuer="i", audience="a", signing_key="s" + ) + result = build_httpx_auth(auth) + assert isinstance(result, JwtAuth) + assert result.algorithm == "RS256" + + def test_jwt_auth_signs_valid_hmac_token(self): + auth = ClientAuthentication( + type="jwt", + issuer="afm-agent", + audience="https://api.example.com", + signing_key="topsecret-key-that-is-32-bytes-or-more", + algorithm="HS256", + subject="agent-1", + custom_claims={"scope": "read"}, + expiry_seconds=600, + ) + jwt_auth = build_httpx_auth(auth) + assert isinstance(jwt_auth, JwtAuth) + token = jwt_auth.sign() + decoded = jwt.decode( + token, + "topsecret-key-that-is-32-bytes-or-more", + algorithms=["HS256"], + audience="https://api.example.com", + ) + assert decoded["iss"] == "afm-agent" + assert decoded["sub"] == "agent-1" + assert decoded["scope"] == "read" + assert decoded["exp"] - decoded["iat"] == 600 + + def test_jwt_auth_signs_rs256_with_key_file(self, tmp_path): + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import rsa + + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + key_file = tmp_path / "jwt_key.pem" + key_file.write_bytes( + private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + + auth = ClientAuthentication( + type="jwt", + issuer="afm-agent", + audience="https://api.example.com", + signing_key=str(key_file), # asymmetric: signing_key is a file path + algorithm="RS256", + ) + jwt_auth = build_httpx_auth(auth) + assert isinstance(jwt_auth, JwtAuth) + token = jwt_auth.sign() + + public_pem = private_key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + decoded = jwt.decode( + token, + public_pem, + algorithms=["RS256"], + audience="https://api.example.com", + ) + assert decoded["iss"] == "afm-agent" + + def test_oauth2_returns_oauth2_auth_instance(self): + auth = ClientAuthentication( + type="oauth2", + grant_type="client_credentials", + token_url="https://auth.example.com/token", + client_id="id", + client_secret="secret", + scopes=["read"], + ) + result = build_httpx_auth(auth) + assert isinstance(result, OAuth2Auth) + assert result.grant_type == "client_credentials" + + def test_oauth2_client_credentials_token_request(self): + result = build_httpx_auth( + ClientAuthentication( + type="oauth2", + grant_type="client_credentials", + token_url="https://auth.example.com/token", + client_id="id", + client_secret="secret", + scopes=["read", "write"], + ) + ) + assert isinstance(result, OAuth2Auth) + url, data, basic = result._token_request() + assert url == "https://auth.example.com/token" + assert data["grant_type"] == "client_credentials" + assert data["scope"] == "read write" + assert basic == ("id", "secret") + + def test_oauth2_refresh_token_uses_refresh_url(self): + result = build_httpx_auth( + ClientAuthentication( + type="oauth2", + grant_type="refresh_token", + refresh_url="https://auth.example.com/refresh", + refresh_token="rt", + client_id="id", + client_secret="secret", + ) + ) + assert isinstance(result, OAuth2Auth) + url, data, _basic = result._token_request() + assert url == "https://auth.example.com/refresh" + assert data["grant_type"] == "refresh_token" + assert data["refresh_token"] == "rt" + + def test_oauth2_jwt_bearer_token_request(self): + result = build_httpx_auth( + ClientAuthentication( + type="oauth2", + grant_type="jwt_bearer", + token_url="https://auth.example.com/token", + assertion="signed.jwt", + ) + ) + assert isinstance(result, OAuth2Auth) + _url, data, basic = result._token_request() + assert data["grant_type"] == "urn:ietf:params:oauth:grant-type:jwt-bearer" + assert data["assertion"] == "signed.jwt" + assert basic is None # no client credentials provided + + def test_oauth2_uses_cached_token(self): + result = build_httpx_auth( + ClientAuthentication( + type="oauth2", + grant_type="client_credentials", + token_url="u", + client_id="id", + client_secret="secret", + ) + ) + assert isinstance(result, OAuth2Auth) + result._token = "cached-token" + result._expires_at = time.time() + 1000 + request = httpx.Request("GET", "https://api.example.com/resource") + flow = result.sync_auth_flow(request) + next(flow) + assert request.headers["Authorization"] == "Bearer cached-token" + class TestFilterTools: def test_no_filter_returns_all_tools(self): diff --git a/python-interpreter/uv.lock b/python-interpreter/uv.lock index 4b0a022..f805e1a 100644 --- a/python-interpreter/uv.lock +++ b/python-interpreter/uv.lock @@ -83,6 +83,7 @@ dependencies = [ { name = "langchain-mcp-adapters" }, { name = "langchain-openai" }, { name = "mcp" }, + { name = "pyjwt", extra = ["crypto"] }, ] [package.metadata] @@ -93,6 +94,7 @@ requires-dist = [ { name = "langchain-mcp-adapters", specifier = ">=0.2.1" }, { name = "langchain-openai", specifier = ">=1.1.7" }, { name = "mcp", specifier = ">=1.26.0" }, + { name = "pyjwt", extras = ["crypto"], specifier = ">=2.10.0" }, ] [[package]]