From 2321bd5c6ba929e196ed6735215f0ed0978e3072 Mon Sep 17 00:00:00 2001 From: Thareesha98 Date: Fri, 12 Jun 2026 17:09:49 +0530 Subject: [PATCH 1/8] feat(auth): add parse-time authentication schema validation (ballerina) Standardize AFM authentication field schemas (issue #35), Phase 1. - parser.bal: validate every authentication block at parse time against its type's field schema, across all three sites (model, MCP transport, webhook subscription). Reject unknown types, enforce required fields, and reject fields that do not belong to the type. bearer/basic/api-key are fully validated; jwt/oauth2 are recognized with field validation deferred to their phase. - agent.bal: add an explicit api-key case to mapToHttpClientAuth with a clear message. Ballerina's http client has no raw-header auth variant, so api-key is not yet supported for MCP/webhook transport (it remains supported for model providers). - tests: cover schema validation, the api-key mapping case, and parse-level rejection. --- ballerina-interpreter/agent.bal | 3 + ballerina-interpreter/parser.bal | 97 +++++++++++++++++++++++ ballerina-interpreter/tests/main_test.bal | 96 ++++++++++++++++++++++ 3 files changed, 196 insertions(+) diff --git a/ballerina-interpreter/agent.bal b/ballerina-interpreter/agent.bal index d60e55a..d24728c 100644 --- a/ballerina-interpreter/agent.bal +++ b/ballerina-interpreter/agent.bal @@ -350,6 +350,9 @@ function mapToHttpClientAuth(ClientAuthentication? auth) returns http:ClientAuth "bearer" => { return rest.cloneWithType(http:BearerTokenConfig); } + "api-key" => { + return error("API key authentication is not yet supported for MCP/webhook transport in the Ballerina interpreter"); + } "oauth2" => { // record {string grantType;}|error oauth2Config = check rest.cloneWithType(); // if oauth2Config is error { diff --git a/ballerina-interpreter/parser.bal b/ballerina-interpreter/parser.bal index 8758272..c4b4279 100644 --- a/ballerina-interpreter/parser.bal +++ b/ballerina-interpreter/parser.bal @@ -31,6 +31,8 @@ function parseAfm(string content) returns AFMRecord|error { body = resolvedContent; } + check validateMetadataAuthentication(metadata); + // Extract Role and Instructions sections string[] bodyLines = splitLines(body); string role = ""; @@ -305,6 +307,101 @@ function authenticationContainsHttpVariable(ClientAuthentication? authentication return false; } +final readonly & string[] RECOGNIZED_AUTH_TYPES = ["bearer", "basic", "api-key", "jwt", "oauth2"]; + +function validateMetadataAuthentication(AgentMetadata? metadata) returns error? { + if metadata is () { + return; + } + + Model? model = metadata.model; + if model is Model { + check validateAuthentication(model.authentication); + } + + Tools? tools = metadata.tools; + if tools is Tools { + MCPServer[]? mcp = tools.mcp; + if mcp is MCPServer[] { + foreach MCPServer server in mcp { + Transport transport = server.transport; + if transport is HttpTransport { + check validateAuthentication(transport.authentication); + } + } + } + } + + Interface[]? interfaces = metadata.interfaces; + if interfaces is Interface[] { + foreach Interface interface in interfaces { + if interface is WebhookInterface { + check validateAuthentication(interface.subscription.authentication); + } + } + } +} + +function validateAuthentication(ClientAuthentication? auth) returns error? { + if auth is () { + return; + } + + string authType = auth.'type.toLowerAscii(); + if RECOGNIZED_AUTH_TYPES.indexOf(authType) is () { + return error(string `unknown authentication type '${auth.'type}'. Supported types: bearer, basic, api-key, jwt, oauth2`); + } + + string[]? allowed = allowedAuthFields(authType); + if allowed is () { + return; + } + + string[] provided = from string key in auth.keys() where key != "type" select key; + + foreach string required in requiredAuthFields(authType) { + if provided.indexOf(required) is () { + return error(string `type '${authType}' requires '${required}' field`); + } + } + + foreach string fieldName in provided { + if allowed.indexOf(fieldName) is () { + return error(string `type '${authType}' does not support '${fieldName}' field`); + } + } +} + +function allowedAuthFields(string authType) returns string[]? { + match authType { + "bearer" => { + return ["token"]; + } + "basic" => { + return ["username", "password"]; + } + "api-key" => { + return ["api_key", "header_name"]; + } + } + return (); +} + +function requiredAuthFields(string authType) returns string[] { + match authType { + "bearer" => { + return ["token"]; + } + "basic" => { + return ["username", "password"]; + } + "api-key" => { + return ["api_key"]; + } + } + return []; +} + function signatureContainsHttpVariable(Signature signature) returns boolean => jsonSchemaContainsHttpVariable(signature.input) || jsonSchemaContainsHttpVariable(signature.output); diff --git a/ballerina-interpreter/tests/main_test.bal b/ballerina-interpreter/tests/main_test.bal index e9ad324..f6935ef 100644 --- a/ballerina-interpreter/tests/main_test.bal +++ b/ballerina-interpreter/tests/main_test.bal @@ -1199,3 +1199,99 @@ function testMapToHttpClientAuthUnsupportedType() { } test:assertEquals(result.message(), "Unsupported authentication type: custom-auth"); } + +@test:Config +function testMapToHttpClientAuthApiKeyNotSupported() { + ClientAuthentication auth = { + 'type: "api-key", + "api_key": "test-key" + }; + + http:ClientAuthConfig|error? result = mapToHttpClientAuth(auth); + if result is http:ClientAuthConfig? { + test:assertFail("Expected error for api-key on HTTP transport"); + } + test:assertEquals(result.message(), + "API key authentication is not yet supported for MCP/webhook transport in the Ballerina interpreter"); +} + + +@test:Config +function testValidateAuthenticationNull() returns error? { + error? result = validateAuthentication(()); + test:assertTrue(result is ()); +} + +@test:Config +function testValidateAuthenticationBearerValid() returns error? { + ClientAuthentication auth = {'type: "bearer", "token": "t"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is ()); +} + +@test:Config +function testValidateAuthenticationBasicValid() returns error? { + ClientAuthentication auth = {'type: "basic", "username": "u", "password": "p"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is ()); +} + +@test:Config +function testValidateAuthenticationApiKeyValid() returns error? { + ClientAuthentication auth = {'type: "api-key", "api_key": "k", "header_name": "X-API-Key"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is ()); +} + +@test:Config +function testValidateAuthenticationCaseInsensitive() returns error? { + ClientAuthentication auth = {'type: "Bearer", "token": "t"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is ()); +} + +@test:Config +function testValidateAuthenticationJwtDeferred() returns error? { + ClientAuthentication auth = {'type: "jwt"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is ()); +} + +@test:Config +function testValidateAuthenticationUnknownType() { + ClientAuthentication auth = {'type: "token", "token": "t"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is error); + test:assertTrue((result).message().includes("unknown authentication type 'token'")); +} + +@test:Config +function testValidateAuthenticationBearerMissingToken() { + ClientAuthentication auth = {'type: "bearer"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is error); + test:assertEquals((result).message(), "type 'bearer' requires 'token' field"); +} + +@test:Config +function testValidateAuthenticationUnknownField() { + ClientAuthentication auth = {'type: "bearer", "token": "t", "username": "u"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is error); + test:assertEquals((result).message(), "type 'bearer' does not support 'username' field"); +} + +@test:Config +function testParseAfmRejectsInvalidAuthentication() { + string content = "---\n" + + "model:\n" + + " provider: openai\n" + + " name: gpt-4\n" + + " authentication:\n" + + " type: bearer\n" + + "---\n\n" + + "# Role\nRole\n\n# Instructions\nInstructions\n"; + AFMRecord|error result = parseAfm(content); + test:assertTrue(result is error); + test:assertEquals((result).message(), "type 'bearer' requires 'token' field"); +} From 6eb3e146c818cd8476a8ce989f2277a8b1d6406b Mon Sep 17 00:00:00 2001 From: Thareesha98 Date: Fri, 12 Jun 2026 17:15:44 +0530 Subject: [PATCH 2/8] feat(auth): add parse-time authentication schema validation (python) Standardize AFM authentication field schemas (issue #35), Phase 1. - models.py: ClientAuthentication now validates at parse time - reject unknown types, enforce per-type required fields, and reject fields that do not belong to the type (catches typos and cross-type fields). Add an optional header_name for api-key. bearer/basic/api-key are fully validated; jwt/oauth2 are recognized with field validation deferred to their phase. - mcp.py: wire header_name into ApiKeyAuth so api-key can target custom headers (e.g. X-API-Key) instead of always Authorization. - tests: cover the validation rules and the custom header_name wiring. --- .../packages/afm-core/src/afm/models.py | 58 +++++++++++--- .../packages/afm-core/tests/test_parser.py | 80 +++++++++++++++++++ .../src/afm_langchain/tools/mcp.py | 2 +- .../packages/afm-langchain/tests/test_mcp.py | 15 ++++ 4 files changed, 142 insertions(+), 13 deletions(-) diff --git a/python-interpreter/packages/afm-core/src/afm/models.py b/python-interpreter/packages/afm-core/src/afm/models.py index 101f6d9..e112da0 100644 --- a/python-interpreter/packages/afm-core/src/afm/models.py +++ b/python-interpreter/packages/afm-core/src/afm/models.py @@ -30,6 +30,21 @@ class Provider(BaseModel): url: str | None = None +RECOGNIZED_AUTH_TYPES = ("bearer", "basic", "api-key", "jwt", "oauth2") + +_AUTH_REQUIRED_FIELDS: dict[str, set[str]] = { + "bearer": {"token"}, + "basic": {"username", "password"}, + "api-key": {"api_key"}, +} +_AUTH_ALLOWED_FIELDS: dict[str, set[str]] = { + "bearer": {"token"}, + "basic": {"username", "password"}, + "api-key": {"api_key", "header_name"}, +} +_AUTH_CREDENTIAL_FIELDS = ("token", "username", "password", "api_key", "header_name") + + class ClientAuthentication(BaseModel): model_config = ConfigDict(extra="allow") @@ -38,21 +53,40 @@ class ClientAuthentication(BaseModel): username: str | None = None password: str | None = None api_key: str | None = None + header_name: str | None = None @model_validator(mode="after") def validate_type_fields(self) -> Self: - match self.type.lower(): - case "bearer": - if self.token is None: - raise ValueError("type 'bearer' requires 'token' field") - case "basic": - if self.username is None or self.password is None: - raise ValueError( - "type 'basic' requires 'username' and 'password' fields" - ) - case "api-key": - if self.api_key is None: - raise ValueError("type 'api-key' requires 'api_key' field") + auth_type = self.type.lower() + + if auth_type not in RECOGNIZED_AUTH_TYPES: + supported = ", ".join(RECOGNIZED_AUTH_TYPES) + raise ValueError( + f"unknown authentication type '{self.type}'. " + f"Supported types: {supported}" + ) + + allowed = _AUTH_ALLOWED_FIELDS.get(auth_type) + if allowed is None: + return self + + provided = { + name for name in _AUTH_CREDENTIAL_FIELDS if getattr(self, name) is not None + } + provided |= set(self.model_extra or {}) + + missing = _AUTH_REQUIRED_FIELDS[auth_type] - provided + if missing: + fields = ", ".join(f"'{name}'" for name in sorted(missing)) + suffix = "field" if len(missing) == 1 else "fields" + raise ValueError(f"type '{auth_type}' requires {fields} {suffix}") + + unknown = provided - allowed + if unknown: + fields = ", ".join(f"'{name}'" for name in sorted(unknown)) + suffix = "field" if len(unknown) == 1 else "fields" + raise ValueError(f"type '{auth_type}' does not support {fields} {suffix}") + return self diff --git a/python-interpreter/packages/afm-core/tests/test_parser.py b/python-interpreter/packages/afm-core/tests/test_parser.py index 0b517ff..3a79d60 100644 --- a/python-interpreter/packages/afm-core/tests/test_parser.py +++ b/python-interpreter/packages/afm-core/tests/test_parser.py @@ -17,9 +17,11 @@ from pathlib import Path import pytest +from pydantic import ValidationError from afm.exceptions import AFMParseError, AFMValidationError, VariableResolutionError from afm.models import ( + ClientAuthentication, ConsoleChatInterface, HttpTransport, StdioTransport, @@ -493,3 +495,81 @@ def test_parse_afm_default_behavior_resolves_env(self, monkeypatch) -> None: assert result.metadata.model is not None assert result.metadata.model.authentication is not None assert result.metadata.model.authentication.token == "secret-token-123" + + +class TestClientAuthenticationValidation: + + def test_bearer_valid(self) -> None: + auth = ClientAuthentication(type="bearer", token="t") + assert auth.type == "bearer" + assert auth.token == "t" + + def test_basic_valid(self) -> None: + auth = ClientAuthentication(type="basic", username="u", password="p") + assert auth.username == "u" + assert auth.password == "p" + + def test_api_key_valid(self) -> None: + auth = ClientAuthentication(type="api-key", api_key="k") + assert auth.api_key == "k" + assert auth.header_name is None + + def test_api_key_with_header_name(self) -> None: + auth = ClientAuthentication(type="api-key", api_key="k", header_name="X-API-Key") + assert auth.header_name == "X-API-Key" + + def test_type_is_case_insensitive(self) -> None: + auth = ClientAuthentication(type="Bearer", token="t") + assert auth.token == "t" + + def test_jwt_recognized_without_field_validation(self) -> None: + auth = ClientAuthentication(type="jwt") + assert auth.type == "jwt" + + def test_oauth2_recognized_without_field_validation(self) -> None: + auth = ClientAuthentication(type="oauth2") + assert auth.type == "oauth2" + + def test_unknown_type_rejected(self) -> None: + with pytest.raises(ValidationError, match="unknown authentication type 'token'"): + ClientAuthentication(type="token", token="t") + + def test_bearer_missing_token_rejected(self) -> None: + with pytest.raises(ValidationError, match="type 'bearer' requires 'token'"): + ClientAuthentication(type="bearer") + + def test_basic_missing_password_rejected(self) -> None: + with pytest.raises(ValidationError, match="type 'basic' requires 'password'"): + ClientAuthentication(type="basic", username="u") + + def test_api_key_missing_key_rejected(self) -> None: + with pytest.raises(ValidationError, match="type 'api-key' requires 'api_key'"): + ClientAuthentication(type="api-key") + + def test_unknown_field_rejected(self) -> None: + with pytest.raises( + ValidationError, match="type 'bearer' does not support 'username'" + ): + ClientAuthentication(type="bearer", token="t", username="u") + + def test_typo_field_rejected(self) -> None: + with pytest.raises(ValidationError): + ClientAuthentication(type="api-key", api_key="k", headername="X") + + def test_invalid_auth_fails_at_parse_time(self) -> None: + content = """--- +model: + provider: openai + name: gpt-4 + authentication: + type: bearer +--- + +# Role +Role + +# Instructions +Instructions +""" + with pytest.raises(AFMValidationError): + parse_afm(content) diff --git a/python-interpreter/packages/afm-langchain/src/afm_langchain/tools/mcp.py b/python-interpreter/packages/afm-langchain/src/afm_langchain/tools/mcp.py index 709a065..c638230 100644 --- a/python-interpreter/packages/afm-langchain/src/afm_langchain/tools/mcp.py +++ b/python-interpreter/packages/afm-langchain/src/afm_langchain/tools/mcp.py @@ -79,7 +79,7 @@ def build_httpx_auth(auth: ClientAuthentication | None) -> httpx.Auth | None: elif auth_type == "api-key": if auth.api_key is None: raise MCPAuthenticationError("API key auth requires 'api_key' field") - return ApiKeyAuth(auth.api_key) + return ApiKeyAuth(auth.api_key, header_name=auth.header_name or "Authorization") elif auth_type in ("oauth2", "jwt"): raise MCPAuthenticationError( diff --git a/python-interpreter/packages/afm-langchain/tests/test_mcp.py b/python-interpreter/packages/afm-langchain/tests/test_mcp.py index fe6000f..fd998b5 100644 --- a/python-interpreter/packages/afm-langchain/tests/test_mcp.py +++ b/python-interpreter/packages/afm-langchain/tests/test_mcp.py @@ -133,6 +133,21 @@ def test_api_key_auth_returns_api_key_auth_instance(self): assert isinstance(result, ApiKeyAuth) assert result.api_key == "my-api-key" + def test_api_key_auth_defaults_to_authorization_header(self): + auth = ClientAuthentication(type="api-key", api_key="my-api-key") + result = build_httpx_auth(auth) + assert isinstance(result, ApiKeyAuth) + assert result.header_name == "Authorization" + + def test_api_key_auth_uses_custom_header_name(self): + auth = ClientAuthentication( + type="api-key", api_key="my-api-key", header_name="X-API-Key" + ) + result = build_httpx_auth(auth) + assert isinstance(result, ApiKeyAuth) + assert result.api_key == "my-api-key" + assert result.header_name == "X-API-Key" + class TestFilterTools: def test_no_filter_returns_all_tools(self): From 046be901d318b538391d6786563480fad678df4b Mon Sep 17 00:00:00 2001 From: Thareesha98 Date: Sat, 13 Jun 2026 18:57:33 +0530 Subject: [PATCH 3/8] feat(auth): add jwt runtime signing (ballerina) Standardize AFM authentication field schemas (issue #35), Phase 2. The runtime now mints and signs a JWT from a key plus claims and sends it as a bearer token, for MCP and webhook transports. - agent.bal: map a jwt authentication block to http:JwtIssuerConfig (issuer->iss, subject->username/sub, key_id->kid header, custom_claims, expiry_seconds->expTime). HMAC algorithms (HS256/384/512) use signing_key as a shared secret; asymmetric algorithms use it as a PEM key file path. - parser.bal: validate jwt fields at parse time (issuer, audience and signing_key required); jwt is no longer treated as deferred. - tests: HMAC and RS256 mapping, plus jwt schema validation. --- ballerina-interpreter/agent.bal | 53 ++++++++++++++++++++++- ballerina-interpreter/parser.bal | 6 +++ ballerina-interpreter/tests/main_test.bal | 48 ++++++++++++++++---- 3 files changed, 96 insertions(+), 11 deletions(-) diff --git a/ballerina-interpreter/agent.bal b/ballerina-interpreter/agent.bal index d24728c..2ad7388 100644 --- a/ballerina-interpreter/agent.bal +++ b/ballerina-interpreter/agent.bal @@ -379,11 +379,60 @@ function mapToHttpClientAuth(ClientAuthentication? auth) returns http:ClientAuth return error("OAuth2 authentication not yet supported"); } "jwt" => { - // return rest.cloneWithType(http:JwtIssuerConfig); - return error("JWT authentication not yet supported"); + JwtAuthConfig|error jwtConfig = rest.cloneWithType(); + if jwtConfig is error { + return error("Invalid JWT authentication configuration", jwtConfig); + } + return buildJwtIssuerConfig(jwtConfig); } _ => { return error(string `Unsupported authentication type: ${'type}`); } } } + +type JwtAuthConfig record {| + string issuer; + string|string[] audience; + string signing_key; + string algorithm = "RS256"; + string key_id?; + string subject?; + map custom_claims?; + decimal expiry_seconds = 300; +|}; + +function buildJwtIssuerConfig(JwtAuthConfig jwtConfig) returns http:JwtIssuerConfig|error { + string algorithm = jwtConfig.algorithm; + boolean isHmac = algorithm == "HS256" || algorithm == "HS384" || algorithm == "HS512"; + json signatureKeyConfig = isHmac ? jwtConfig.signing_key : {keyFile: jwtConfig.signing_key}; + + map issuerConfig = { + issuer: jwtConfig.issuer, + audience: jwtConfig.audience, + expTime: jwtConfig.expiry_seconds, + signatureConfig: { + algorithm, + config: signatureKeyConfig + } + }; + + string? keyId = jwtConfig?.key_id; + if keyId is string { + issuerConfig["keyId"] = keyId; + } + string? subject = jwtConfig?.subject; + if subject is string { + issuerConfig["username"] = subject; + } + map? customClaims = jwtConfig?.custom_claims; + if customClaims is map { + issuerConfig["customClaims"] = customClaims; + } + + http:JwtIssuerConfig|error result = issuerConfig.cloneWithType(); + if result is error { + return error("Invalid JWT authentication configuration", result); + } + return result; +} diff --git a/ballerina-interpreter/parser.bal b/ballerina-interpreter/parser.bal index c4b4279..1a62dd0 100644 --- a/ballerina-interpreter/parser.bal +++ b/ballerina-interpreter/parser.bal @@ -383,6 +383,9 @@ function allowedAuthFields(string authType) returns string[]? { "api-key" => { return ["api_key", "header_name"]; } + "jwt" => { + return ["issuer", "audience", "signing_key", "algorithm", "key_id", "subject", "custom_claims", "expiry_seconds"]; + } } return (); } @@ -398,6 +401,9 @@ function requiredAuthFields(string authType) returns string[] { "api-key" => { return ["api_key"]; } + "jwt" => { + return ["issuer", "audience", "signing_key"]; + } } return []; } diff --git a/ballerina-interpreter/tests/main_test.bal b/ballerina-interpreter/tests/main_test.bal index f6935ef..35bc35a 100644 --- a/ballerina-interpreter/tests/main_test.bal +++ b/ballerina-interpreter/tests/main_test.bal @@ -1175,16 +1175,33 @@ function testMapToHttpClientAuthOAuth2NotSupported() { } @test:Config -function testMapToHttpClientAuthJWTNotSupported() { +function testMapToHttpClientAuthJwtHmac() returns error? { ClientAuthentication auth = { - 'type: "jwt" + 'type: "jwt", + "issuer": "afm-agent", + "audience": "https://api.example.com", + "signing_key": "shared-secret", + "algorithm": "HS256" }; - http:ClientAuthConfig|error? result = mapToHttpClientAuth(auth); - if result is http:ClientAuthConfig? { - test:assertFail("Expected error for JWT authentication"); - } - test:assertEquals(result.message(), "JWT authentication not yet supported"); + http:ClientAuthConfig? result = check mapToHttpClientAuth(auth); + test:assertTrue(result is http:JwtIssuerConfig); + http:JwtIssuerConfig issuerConfig = result; + test:assertEquals(issuerConfig.issuer, "afm-agent"); + test:assertEquals(issuerConfig.audience, "https://api.example.com"); +} + +@test:Config +function testMapToHttpClientAuthJwtRs256() returns error? { + ClientAuthentication auth = { + 'type: "jwt", + "issuer": "afm-agent", + "audience": "https://api.example.com", + "signing_key": "/path/to/key.pem" + }; + + http:ClientAuthConfig? result = check mapToHttpClientAuth(auth); + test:assertTrue(result is http:JwtIssuerConfig); } @test:Config @@ -1251,12 +1268,25 @@ function testValidateAuthenticationCaseInsensitive() returns error? { } @test:Config -function testValidateAuthenticationJwtDeferred() returns error? { - ClientAuthentication auth = {'type: "jwt"}; +function testValidateAuthenticationJwtValid() returns error? { + ClientAuthentication auth = { + 'type: "jwt", + "issuer": "afm-agent", + "audience": "https://api.example.com", + "signing_key": "secret" + }; error? result = validateAuthentication(auth); test:assertTrue(result is ()); } +@test:Config +function testValidateAuthenticationJwtMissingField() { + ClientAuthentication auth = {'type: "jwt", "issuer": "afm-agent"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is error); + test:assertEquals((result).message(), "type 'jwt' requires 'audience' field"); +} + @test:Config function testValidateAuthenticationUnknownType() { ClientAuthentication auth = {'type: "token", "token": "t"}; From 9d2c9f8863f06a9f89168590f679fabc502f03d0 Mon Sep 17 00:00:00 2001 From: Thareesha98 Date: Sat, 13 Jun 2026 18:58:58 +0530 Subject: [PATCH 4/8] feat(auth): add jwt field schema and parse-time validation (python) Standardize AFM authentication field schemas (issue #35), Phase 2. - models.py: add the jwt fields to ClientAuthentication (issuer, audience, signing_key, algorithm, key_id, subject, custom_claims, expiry_seconds) and fold jwt into the parse-time validator (issuer, audience and signing_key required; unknown jwt fields rejected). oauth2 stays deferred. - tests: jwt valid/invalid schema cases, audience list, unknown field. --- .../packages/afm-core/src/afm/models.py | 38 ++++++++++++++++++- .../packages/afm-core/tests/test_parser.py | 30 +++++++++++++-- 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/python-interpreter/packages/afm-core/src/afm/models.py b/python-interpreter/packages/afm-core/src/afm/models.py index e112da0..1d18f38 100644 --- a/python-interpreter/packages/afm-core/src/afm/models.py +++ b/python-interpreter/packages/afm-core/src/afm/models.py @@ -18,7 +18,7 @@ from enum import Enum from pathlib import Path -from typing import Annotated, Literal, Self +from typing import Annotated, Any, Literal, Self from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -32,17 +32,43 @@ class Provider(BaseModel): RECOGNIZED_AUTH_TYPES = ("bearer", "basic", "api-key", "jwt", "oauth2") +_JWT_ALLOWED_FIELDS = { + "issuer", + "audience", + "signing_key", + "algorithm", + "key_id", + "subject", + "custom_claims", + "expiry_seconds", +} _AUTH_REQUIRED_FIELDS: dict[str, set[str]] = { "bearer": {"token"}, "basic": {"username", "password"}, "api-key": {"api_key"}, + "jwt": {"issuer", "audience", "signing_key"}, } _AUTH_ALLOWED_FIELDS: dict[str, set[str]] = { "bearer": {"token"}, "basic": {"username", "password"}, "api-key": {"api_key", "header_name"}, + "jwt": _JWT_ALLOWED_FIELDS, } -_AUTH_CREDENTIAL_FIELDS = ("token", "username", "password", "api_key", "header_name") +_AUTH_CREDENTIAL_FIELDS = ( + "token", + "username", + "password", + "api_key", + "header_name", + "issuer", + "audience", + "signing_key", + "algorithm", + "key_id", + "subject", + "custom_claims", + "expiry_seconds", +) class ClientAuthentication(BaseModel): @@ -54,6 +80,14 @@ class ClientAuthentication(BaseModel): password: str | None = None api_key: str | None = None header_name: str | None = None + issuer: str | None = None + audience: str | list[str] | None = None + signing_key: str | None = None + algorithm: str | None = None + key_id: str | None = None + subject: str | None = None + custom_claims: dict[str, Any] | None = None + expiry_seconds: int | None = None @model_validator(mode="after") def validate_type_fields(self) -> Self: diff --git a/python-interpreter/packages/afm-core/tests/test_parser.py b/python-interpreter/packages/afm-core/tests/test_parser.py index 3a79d60..e5b453a 100644 --- a/python-interpreter/packages/afm-core/tests/test_parser.py +++ b/python-interpreter/packages/afm-core/tests/test_parser.py @@ -522,9 +522,33 @@ def test_type_is_case_insensitive(self) -> None: auth = ClientAuthentication(type="Bearer", token="t") assert auth.token == "t" - def test_jwt_recognized_without_field_validation(self) -> None: - auth = ClientAuthentication(type="jwt") - assert auth.type == "jwt" + def test_jwt_valid(self) -> None: + auth = ClientAuthentication( + type="jwt", + issuer="afm-agent", + audience="https://api.example.com", + signing_key="secret", + ) + assert auth.issuer == "afm-agent" + assert auth.audience == "https://api.example.com" + + def test_jwt_audience_list(self) -> None: + auth = ClientAuthentication( + type="jwt", issuer="i", audience=["a", "b"], signing_key="s" + ) + assert auth.audience == ["a", "b"] + + def test_jwt_missing_signing_key_rejected(self) -> None: + with pytest.raises( + ValidationError, match="type 'jwt' requires 'signing_key'" + ): + ClientAuthentication(type="jwt", issuer="i", audience="a") + + def test_jwt_unknown_field_rejected(self) -> None: + with pytest.raises(ValidationError, match="does not support"): + ClientAuthentication( + type="jwt", issuer="i", audience="a", signing_key="s", token="x" + ) def test_oauth2_recognized_without_field_validation(self) -> None: auth = ClientAuthentication(type="oauth2") From 78c499ef2f27245c5a5f0d6244eaa82cde58b712 Mon Sep 17 00:00:00 2001 From: Thareesha98 Date: Sat, 13 Jun 2026 19:01:53 +0530 Subject: [PATCH 5/8] feat(auth): add jwt runtime signing (python) Standardize AFM authentication field schemas (issue #35), Phase 2. - mcp.py: add JwtAuth (httpx.Auth) that mints and signs a JWT per request via PyJWT and sends it as a bearer token, and wire it into build_httpx_auth. HMAC algorithms use signing_key as a shared secret; asymmetric algorithms read it as a PEM key file. Claims are assembled to match the Ballerina runtime (sets nbf; custom_claims merged last). Only oauth2 remains not yet supported. - pyproject.toml / uv.lock: declare the pyjwt[crypto] dependency. - tests: HMAC sign+decode round-trip, RS256 with a generated key file, default-RS256, and custom header_name. --- .../packages/afm-langchain/pyproject.toml | 1 + .../src/afm_langchain/tools/mcp.py | 88 +++++++++++++++++- .../packages/afm-langchain/tests/test_mcp.py | 93 ++++++++++++++++++- python-interpreter/uv.lock | 2 + 4 files changed, 181 insertions(+), 3 deletions(-) diff --git a/python-interpreter/packages/afm-langchain/pyproject.toml b/python-interpreter/packages/afm-langchain/pyproject.toml index f6d7ffd..6e616ac 100644 --- a/python-interpreter/packages/afm-langchain/pyproject.toml +++ b/python-interpreter/packages/afm-langchain/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "langchain-anthropic>=1.3.1", "mcp>=1.26.0", "langchain-mcp-adapters>=0.2.1", + "pyjwt[crypto]>=2.10.0", ] [project.entry-points."afm.runner"] diff --git a/python-interpreter/packages/afm-langchain/src/afm_langchain/tools/mcp.py b/python-interpreter/packages/afm-langchain/src/afm_langchain/tools/mcp.py index c638230..8150540 100644 --- a/python-interpreter/packages/afm-langchain/src/afm_langchain/tools/mcp.py +++ b/python-interpreter/packages/afm-langchain/src/afm_langchain/tools/mcp.py @@ -17,8 +17,11 @@ from __future__ import annotations import logging +import time +from pathlib import Path import httpx +import jwt from afm.exceptions import ( MCPAuthenticationError, MCPConnectionError, @@ -58,6 +61,69 @@ def auth_flow(self, request: httpx.Request): yield request +_HMAC_JWT_ALGORITHMS = {"HS256", "HS384", "HS512"} + + +class JwtAuth(httpx.Auth): + + def __init__( + self, + *, + issuer: str, + audience: str | list[str], + signing_key: str, + algorithm: str = "RS256", + key_id: str | None = None, + subject: str | None = None, + custom_claims: dict | None = None, + expiry_seconds: int = 300, + ) -> None: + self.issuer = issuer + self.audience = audience + self.signing_key = signing_key + self.algorithm = algorithm + self.key_id = key_id + self.subject = subject + self.custom_claims = custom_claims or {} + self.expiry_seconds = expiry_seconds + + def _resolve_key(self) -> str: + if self.algorithm.upper() in _HMAC_JWT_ALGORITHMS: + return self.signing_key + try: + return Path(self.signing_key).read_text() + except OSError as e: + raise MCPAuthenticationError( + f"Could not read JWT signing key file '{self.signing_key}': {e}" + ) from e + + def sign(self) -> str: + now = int(time.time()) + claims: dict = { + "iss": self.issuer, + "aud": self.audience, + "iat": now, + "nbf": now, + "exp": now + self.expiry_seconds, + } + if self.subject is not None: + claims["sub"] = self.subject + claims.update(self.custom_claims) + headers = {"kid": self.key_id} if self.key_id else None + try: + return jwt.encode( + claims, self._resolve_key(), algorithm=self.algorithm, headers=headers + ) + except MCPAuthenticationError: + raise + except Exception as e: + raise MCPAuthenticationError(f"Failed to sign JWT: {e}") from e + + def auth_flow(self, request: httpx.Request): + request.headers["Authorization"] = f"Bearer {self.sign()}" + yield request + + def build_httpx_auth(auth: ClientAuthentication | None) -> httpx.Auth | None: if auth is None: return None @@ -81,9 +147,27 @@ def build_httpx_auth(auth: ClientAuthentication | None) -> httpx.Auth | None: raise MCPAuthenticationError("API key auth requires 'api_key' field") return ApiKeyAuth(auth.api_key, header_name=auth.header_name or "Authorization") - elif auth_type in ("oauth2", "jwt"): + elif auth_type == "jwt": + if auth.issuer is None or auth.audience is None or auth.signing_key is None: + raise MCPAuthenticationError( + "JWT auth requires 'issuer', 'audience', and 'signing_key' fields" + ) + return JwtAuth( + issuer=auth.issuer, + audience=auth.audience, + signing_key=auth.signing_key, + algorithm=auth.algorithm or "RS256", + key_id=auth.key_id, + subject=auth.subject, + custom_claims=auth.custom_claims, + expiry_seconds=( + auth.expiry_seconds if auth.expiry_seconds is not None else 300 + ), + ) + + elif auth_type == "oauth2": raise MCPAuthenticationError( - f"Authentication type '{auth_type}' not yet supported" + "Authentication type 'oauth2' not yet supported" ) else: diff --git a/python-interpreter/packages/afm-langchain/tests/test_mcp.py b/python-interpreter/packages/afm-langchain/tests/test_mcp.py index fd998b5..3cf0508 100644 --- a/python-interpreter/packages/afm-langchain/tests/test_mcp.py +++ b/python-interpreter/packages/afm-langchain/tests/test_mcp.py @@ -18,6 +18,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import httpx +import jwt import pytest from langchain_core.tools import BaseTool from langchain_mcp_adapters.sessions import StdioConnection @@ -36,6 +37,7 @@ from afm_langchain.tools.mcp import ( ApiKeyAuth, BearerAuth, + JwtAuth, MCPClient, MCPManager, build_httpx_auth, @@ -59,7 +61,13 @@ def make_mcp_server( elif auth_type == "oauth2": auth = ClientAuthentication(type="oauth2") elif auth_type == "jwt": - auth = ClientAuthentication(type="jwt") + auth = ClientAuthentication( + type="jwt", + issuer="afm-agent", + audience="https://api.example.com", + signing_key="secret", + algorithm="HS256", + ) return MCPServer( name=name, @@ -148,6 +156,89 @@ def test_api_key_auth_uses_custom_header_name(self): assert result.api_key == "my-api-key" assert result.header_name == "X-API-Key" + def test_jwt_auth_returns_jwt_auth_instance(self): + auth = ClientAuthentication( + type="jwt", + issuer="afm-agent", + audience="https://api.example.com", + signing_key="secret", + algorithm="HS256", + ) + result = build_httpx_auth(auth) + assert isinstance(result, JwtAuth) + assert result.issuer == "afm-agent" + assert result.algorithm == "HS256" + + def test_jwt_auth_defaults_to_rs256(self): + auth = ClientAuthentication( + type="jwt", issuer="i", audience="a", signing_key="s" + ) + result = build_httpx_auth(auth) + assert isinstance(result, JwtAuth) + assert result.algorithm == "RS256" + + def test_jwt_auth_signs_valid_hmac_token(self): + auth = ClientAuthentication( + type="jwt", + issuer="afm-agent", + audience="https://api.example.com", + signing_key="topsecret-key-that-is-32-bytes-or-more", + algorithm="HS256", + subject="agent-1", + custom_claims={"scope": "read"}, + expiry_seconds=600, + ) + jwt_auth = build_httpx_auth(auth) + assert isinstance(jwt_auth, JwtAuth) + token = jwt_auth.sign() + decoded = jwt.decode( + token, + "topsecret-key-that-is-32-bytes-or-more", + algorithms=["HS256"], + audience="https://api.example.com", + ) + assert decoded["iss"] == "afm-agent" + assert decoded["sub"] == "agent-1" + assert decoded["scope"] == "read" + assert decoded["exp"] - decoded["iat"] == 600 + + def test_jwt_auth_signs_rs256_with_key_file(self, tmp_path): + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import rsa + + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + key_file = tmp_path / "jwt_key.pem" + key_file.write_bytes( + private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + + auth = ClientAuthentication( + type="jwt", + issuer="afm-agent", + audience="https://api.example.com", + signing_key=str(key_file), # asymmetric: signing_key is a file path + algorithm="RS256", + ) + jwt_auth = build_httpx_auth(auth) + assert isinstance(jwt_auth, JwtAuth) + token = jwt_auth.sign() + + public_pem = private_key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + decoded = jwt.decode( + token, + public_pem, + algorithms=["RS256"], + audience="https://api.example.com", + ) + assert decoded["iss"] == "afm-agent" + class TestFilterTools: def test_no_filter_returns_all_tools(self): diff --git a/python-interpreter/uv.lock b/python-interpreter/uv.lock index 4b0a022..f805e1a 100644 --- a/python-interpreter/uv.lock +++ b/python-interpreter/uv.lock @@ -83,6 +83,7 @@ dependencies = [ { name = "langchain-mcp-adapters" }, { name = "langchain-openai" }, { name = "mcp" }, + { name = "pyjwt", extra = ["crypto"] }, ] [package.metadata] @@ -93,6 +94,7 @@ requires-dist = [ { name = "langchain-mcp-adapters", specifier = ">=0.2.1" }, { name = "langchain-openai", specifier = ">=1.1.7" }, { name = "mcp", specifier = ">=1.26.0" }, + { name = "pyjwt", extras = ["crypto"], specifier = ">=2.10.0" }, ] [[package]] From 42861af6f06230b0d47e2fe5b3c1f78172743602 Mon Sep 17 00:00:00 2001 From: Thareesha98 Date: Sun, 14 Jun 2026 20:46:57 +0530 Subject: [PATCH 6/8] feat(auth): add oauth2 grant flows (ballerina) Standardize AFM authentication field schemas (issue #35), Phase 3. - agent.bal: add buildOAuth2GrantConfig, mapping an oauth2 authentication block to the matching http:OAuth2*GrantConfig so the HTTP client performs the token exchange (and refresh) natively. Supports the client_credentials, password, refresh_token and jwt_bearer grants; field names are mapped from the AFM snake_case names to the connector camelCase names per grant. - parser.bal: validate oauth2 at parse time. grant_type is a required discriminator that selects the required/optional fields; unknown grants and fields not allowed for the grant are rejected. All five auth types are now recognized and validated. --- ballerina-interpreter/agent.bal | 91 ++++++++++++++++++++++++-------- ballerina-interpreter/parser.bal | 51 ++++++++++++++++++ 2 files changed, 119 insertions(+), 23 deletions(-) diff --git a/ballerina-interpreter/agent.bal b/ballerina-interpreter/agent.bal index 2ad7388..7b68e95 100644 --- a/ballerina-interpreter/agent.bal +++ b/ballerina-interpreter/agent.bal @@ -354,29 +354,11 @@ function mapToHttpClientAuth(ClientAuthentication? auth) returns http:ClientAuth return error("API key authentication is not yet supported for MCP/webhook transport in the Ballerina interpreter"); } "oauth2" => { - // record {string grantType;}|error oauth2Config = check rest.cloneWithType(); - // if oauth2Config is error { - // return error("OAuth2 authentication requires 'grantType' field", oauth2Config); - // } - - // var {grantType, ...oauth2ConfigRest} = oauth2Config; - - // match grantType.toLowerAscii() { - // "client_credentials" => { - // return oauth2ConfigRest.cloneWithType(http:OAuth2ClientCredentialsGrantConfig); - // } - // "password" => { - // return oauth2ConfigRest.cloneWithType(http:OAuth2PasswordGrantConfig); - // } - // "refresh_token" => { - // return oauth2ConfigRest.cloneWithType(http:OAuth2RefreshTokenGrantConfig); - // } - // "jwt" => { - // return oauth2Config.cloneWithType(http:OAuth2JwtBearerGrantConfig); - // } - // } - // panic error(string `Unsupported OAuth2 grant type: ${grantType}`); - return error("OAuth2 authentication not yet supported"); + OAuth2Config|error oauth2Config = rest.cloneWithType(); + if oauth2Config is error { + return error("Invalid OAuth2 authentication configuration", oauth2Config); + } + return buildOAuth2GrantConfig(oauth2Config); } "jwt" => { JwtAuthConfig|error jwtConfig = rest.cloneWithType(); @@ -436,3 +418,66 @@ function buildJwtIssuerConfig(JwtAuthConfig jwtConfig) returns http:JwtIssuerCon } return result; } + +type OAuth2Config record {| + string grant_type; + string token_url?; + string refresh_url?; + string client_id?; + string client_secret?; + string username?; + string password?; + string refresh_token?; + string assertion?; + string[] scopes?; +|}; + +function buildOAuth2GrantConfig(OAuth2Config cfg) returns http:OAuth2GrantConfig|error { + string grant = cfg.grant_type.toLowerAscii(); + match grant { + "client_credentials" => { + map grantConfig = {tokenUrl: cfg?.token_url, clientId: cfg?.client_id, clientSecret: cfg?.client_secret}; + addScopes(grantConfig, cfg?.scopes); + return wrapOAuth2(grantConfig.cloneWithType(http:OAuth2ClientCredentialsGrantConfig)); + } + "password" => { + map grantConfig = {tokenUrl: cfg?.token_url, username: cfg?.username, password: cfg?.password}; + addOptional(grantConfig, "clientId", cfg?.client_id); + addOptional(grantConfig, "clientSecret", cfg?.client_secret); + addScopes(grantConfig, cfg?.scopes); + return wrapOAuth2(grantConfig.cloneWithType(http:OAuth2PasswordGrantConfig)); + } + "refresh_token" => { + map grantConfig = {refreshUrl: cfg?.refresh_url, refreshToken: cfg?.refresh_token, clientId: cfg?.client_id, clientSecret: cfg?.client_secret}; + addScopes(grantConfig, cfg?.scopes); + return wrapOAuth2(grantConfig.cloneWithType(http:OAuth2RefreshTokenGrantConfig)); + } + "jwt_bearer" => { + map grantConfig = {tokenUrl: cfg?.token_url, assertion: cfg?.assertion}; + addOptional(grantConfig, "clientId", cfg?.client_id); + addOptional(grantConfig, "clientSecret", cfg?.client_secret); + addScopes(grantConfig, cfg?.scopes); + return wrapOAuth2(grantConfig.cloneWithType(http:OAuth2JwtBearerGrantConfig)); + } + } + return error(string `Unsupported OAuth2 grant type: ${cfg.grant_type}`); +} + +function wrapOAuth2(http:OAuth2GrantConfig|error result) returns http:OAuth2GrantConfig|error { + if result is error { + return error("Invalid OAuth2 authentication configuration", result); + } + return result; +} + +function addScopes(map target, string[]? scopes) { + if scopes is string[] { + target["scopes"] = scopes; + } +} + +function addOptional(map target, string key, string? value) { + if value is string { + target[key] = value; + } +} diff --git a/ballerina-interpreter/parser.bal b/ballerina-interpreter/parser.bal index 1a62dd0..90ff7f9 100644 --- a/ballerina-interpreter/parser.bal +++ b/ballerina-interpreter/parser.bal @@ -352,6 +352,10 @@ function validateAuthentication(ClientAuthentication? auth) returns error? { return error(string `unknown authentication type '${auth.'type}'. Supported types: bearer, basic, api-key, jwt, oauth2`); } + if authType == "oauth2" { + return validateOAuth2(auth); + } + string[]? allowed = allowedAuthFields(authType); if allowed is () { return; @@ -408,6 +412,53 @@ function requiredAuthFields(string authType) returns string[] { return []; } +function validateOAuth2(ClientAuthentication auth) returns error? { + anydata grantTypeValue = auth["grant_type"]; + if grantTypeValue !is string { + return error("type 'oauth2' requires 'grant_type' field"); + } + + string grant = grantTypeValue.toLowerAscii(); + [string[], string[]]? grantFields = oauth2GrantFields(grant); + if grantFields is () { + return error("oauth2 grant_type '" + grantTypeValue + "' is not supported. Supported grant types: client_credentials, password, refresh_token, jwt_bearer"); + } + + [string[], string[]] [required, optional] = grantFields; + string[] allowed = ["grant_type", ...required, ...optional]; + string[] provided = from string key in auth.keys() where key != "type" select key; + + foreach string req in required { + if provided.indexOf(req) is () { + return error("oauth2 grant_type '" + grant + "' requires '" + req + "' field"); + } + } + + foreach string fieldName in provided { + if allowed.indexOf(fieldName) is () { + return error("oauth2 grant_type '" + grant + "' does not support '" + fieldName + "' field"); + } + } +} + +function oauth2GrantFields(string grant) returns [string[], string[]]? { + match grant { + "client_credentials" => { + return [["token_url", "client_id", "client_secret"], ["scopes"]]; + } + "password" => { + return [["token_url", "username", "password", "client_id", "client_secret"], ["scopes"]]; + } + "refresh_token" => { + return [["refresh_url", "refresh_token", "client_id", "client_secret"], ["scopes"]]; + } + "jwt_bearer" => { + return [["token_url", "assertion"], ["client_id", "client_secret", "scopes"]]; + } + } + return (); +} + function signatureContainsHttpVariable(Signature signature) returns boolean => jsonSchemaContainsHttpVariable(signature.input) || jsonSchemaContainsHttpVariable(signature.output); From 73d60f2ed37f7b8775efc9ba3bcb4fe9ebcb95f6 Mon Sep 17 00:00:00 2001 From: Thareesha98 Date: Sun, 14 Jun 2026 20:47:17 +0530 Subject: [PATCH 7/8] feat(auth): add oauth2 grant flows (python) Standardize AFM authentication field schemas (issue #35), Phase 3. - models.py: validate oauth2 at parse time. grant_type is a required discriminator selecting the required/optional fields per grant (client_credentials, password, refresh_token, jwt_bearer); unknown grants and fields not allowed for the grant are rejected. - mcp.py: add OAuth2Auth (httpx.Auth) that obtains an access token via a token exchange, caches it until expiry, and sends it as a bearer token. Implements both sync and async flows; client credentials are sent as HTTP Basic to match the Ballerina runtime, and jwt_bearer uses the RFC 7523 grant URN with a user-supplied assertion. No new dependency. All five auth types are now supported (none remain not yet supported). --- .../packages/afm-core/src/afm/models.py | 74 ++++++++- .../src/afm_langchain/tools/mcp.py | 150 +++++++++++++++++- 2 files changed, 217 insertions(+), 7 deletions(-) diff --git a/python-interpreter/packages/afm-core/src/afm/models.py b/python-interpreter/packages/afm-core/src/afm/models.py index 1d18f38..0245f3c 100644 --- a/python-interpreter/packages/afm-core/src/afm/models.py +++ b/python-interpreter/packages/afm-core/src/afm/models.py @@ -32,6 +32,7 @@ class Provider(BaseModel): RECOGNIZED_AUTH_TYPES = ("bearer", "basic", "api-key", "jwt", "oauth2") + _JWT_ALLOWED_FIELDS = { "issuer", "audience", @@ -54,6 +55,25 @@ class Provider(BaseModel): "api-key": {"api_key", "header_name"}, "jwt": _JWT_ALLOWED_FIELDS, } + +_OAUTH2_GRANTS: dict[str, dict[str, set[str]]] = { + "client_credentials": { + "required": {"token_url", "client_id", "client_secret"}, + "optional": {"scopes"}, + }, + "password": { + "required": {"token_url", "username", "password", "client_id", "client_secret"}, + "optional": {"scopes"}, + }, + "refresh_token": { + "required": {"refresh_url", "refresh_token", "client_id", "client_secret"}, + "optional": {"scopes"}, + }, + "jwt_bearer": { + "required": {"token_url", "assertion"}, + "optional": {"client_id", "client_secret", "scopes"}, + }, +} _AUTH_CREDENTIAL_FIELDS = ( "token", "username", @@ -68,6 +88,14 @@ class Provider(BaseModel): "subject", "custom_claims", "expiry_seconds", + "grant_type", + "token_url", + "refresh_url", + "client_id", + "client_secret", + "refresh_token", + "assertion", + "scopes", ) @@ -88,6 +116,14 @@ class ClientAuthentication(BaseModel): subject: str | None = None custom_claims: dict[str, Any] | None = None expiry_seconds: int | None = None + grant_type: str | None = None + token_url: str | None = None + refresh_url: str | None = None + client_id: str | None = None + client_secret: str | None = None + refresh_token: str | None = None + assertion: str | None = None + scopes: list[str] | None = None @model_validator(mode="after") def validate_type_fields(self) -> Self: @@ -100,22 +136,22 @@ def validate_type_fields(self) -> Self: f"Supported types: {supported}" ) - allowed = _AUTH_ALLOWED_FIELDS.get(auth_type) - if allowed is None: - return self - provided = { name for name in _AUTH_CREDENTIAL_FIELDS if getattr(self, name) is not None } provided |= set(self.model_extra or {}) + if auth_type == "oauth2": + self._validate_oauth2_fields(provided) + return self + missing = _AUTH_REQUIRED_FIELDS[auth_type] - provided if missing: fields = ", ".join(f"'{name}'" for name in sorted(missing)) suffix = "field" if len(missing) == 1 else "fields" raise ValueError(f"type '{auth_type}' requires {fields} {suffix}") - unknown = provided - allowed + unknown = provided - _AUTH_ALLOWED_FIELDS[auth_type] if unknown: fields = ", ".join(f"'{name}'" for name in sorted(unknown)) suffix = "field" if len(unknown) == 1 else "fields" @@ -123,6 +159,34 @@ def validate_type_fields(self) -> Self: return self + def _validate_oauth2_fields(self, provided: set[str]) -> None: + if self.grant_type is None: + raise ValueError("type 'oauth2' requires 'grant_type' field") + + grant = self.grant_type.lower() + spec = _OAUTH2_GRANTS.get(grant) + if spec is None: + supported = ", ".join(_OAUTH2_GRANTS) + raise ValueError( + f"oauth2 grant_type '{self.grant_type}' is not supported. " + f"Supported grant types: {supported}" + ) + + missing = spec["required"] - provided + if missing: + fields = ", ".join(f"'{name}'" for name in sorted(missing)) + suffix = "field" if len(missing) == 1 else "fields" + raise ValueError(f"oauth2 grant_type '{grant}' requires {fields} {suffix}") + + allowed = {"grant_type"} | spec["required"] | spec["optional"] + unknown = provided - allowed + if unknown: + fields = ", ".join(f"'{name}'" for name in sorted(unknown)) + suffix = "field" if len(unknown) == 1 else "fields" + raise ValueError( + f"oauth2 grant_type '{grant}' does not support {fields} {suffix}" + ) + class Model(BaseModel): model_config = ConfigDict(extra="forbid") diff --git a/python-interpreter/packages/afm-langchain/src/afm_langchain/tools/mcp.py b/python-interpreter/packages/afm-langchain/src/afm_langchain/tools/mcp.py index 8150540..65e97a5 100644 --- a/python-interpreter/packages/afm-langchain/src/afm_langchain/tools/mcp.py +++ b/python-interpreter/packages/afm-langchain/src/afm_langchain/tools/mcp.py @@ -124,6 +124,141 @@ def auth_flow(self, request: httpx.Request): yield request +_JWT_BEARER_GRANT_URN = "urn:ietf:params:oauth:grant-type:jwt-bearer" + + +class OAuth2Auth(httpx.Auth): + + def __init__( + self, + *, + grant_type: str, + token_url: str | None = None, + refresh_url: str | None = None, + client_id: str | None = None, + client_secret: str | None = None, + username: str | None = None, + password: str | None = None, + refresh_token: str | None = None, + assertion: str | None = None, + scopes: list[str] | None = None, + ) -> None: + self.grant_type = grant_type.lower() + self.token_url = token_url + self.refresh_url = refresh_url + self.client_id = client_id + self.client_secret = client_secret + self.username = username + self.password = password + self.refresh_token = refresh_token + self.assertion = assertion + self.scopes = scopes + self._token: str | None = None + self._expires_at: float = 0.0 + + def _token_request(self) -> tuple[str, dict[str, str], tuple[str, str] | None]: + url: str | None + data: dict[str, str] + if self.grant_type == "client_credentials": + url = self.token_url + data = {"grant_type": "client_credentials"} + elif self.grant_type == "password": + url = self.token_url + data = { + "grant_type": "password", + "username": self.username or "", + "password": self.password or "", + } + elif self.grant_type == "refresh_token": + url = self.refresh_url + data = { + "grant_type": "refresh_token", + "refresh_token": self.refresh_token or "", + } + elif self.grant_type == "jwt_bearer": + url = self.token_url + data = { + "grant_type": _JWT_BEARER_GRANT_URN, + "assertion": self.assertion or "", + } + else: + raise MCPAuthenticationError( + f"Unsupported oauth2 grant_type: {self.grant_type}" + ) + + if url is None: + raise MCPAuthenticationError( + f"oauth2 grant_type '{self.grant_type}' is missing its token URL" + ) + if self.scopes: + data["scope"] = " ".join(self.scopes) + + basic = ( + (self.client_id, self.client_secret) + if self.client_id is not None and self.client_secret is not None + else None + ) + return url, data, basic + + def _store_token(self, payload: dict) -> str: + token = payload.get("access_token") + if not token: + raise MCPAuthenticationError( + "OAuth2 token response did not contain 'access_token'" + ) + expires_in = float(payload.get("expires_in", 3600)) + self._expires_at = time.time() + expires_in - 30 + self._token = token + return token + + def _cached_token(self) -> str | None: + if self._token is not None and time.time() < self._expires_at: + return self._token + return None + + def _fetch_token_sync(self) -> str: + url, data, basic = self._token_request() + try: + resp = httpx.post( + url, + data=data, + auth=basic, + headers={"Accept": "application/json"}, + timeout=30.0, + ) + resp.raise_for_status() + payload = resp.json() + except Exception as e: + raise MCPAuthenticationError(f"OAuth2 token request failed: {e}") from e + return self._store_token(payload) + + async def _fetch_token_async(self) -> str: + url, data, basic = self._token_request() + try: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + url, + data=data, + auth=basic, + headers={"Accept": "application/json"}, + ) + resp.raise_for_status() + payload = resp.json() + except Exception as e: + raise MCPAuthenticationError(f"OAuth2 token request failed: {e}") from e + return self._store_token(payload) + + def sync_auth_flow(self, request: httpx.Request): + token = self._cached_token() or self._fetch_token_sync() + request.headers["Authorization"] = f"Bearer {token}" + yield request + + async def async_auth_flow(self, request: httpx.Request): + token = self._cached_token() or await self._fetch_token_async() + request.headers["Authorization"] = f"Bearer {token}" + yield request + + def build_httpx_auth(auth: ClientAuthentication | None) -> httpx.Auth | None: if auth is None: return None @@ -166,8 +301,19 @@ def build_httpx_auth(auth: ClientAuthentication | None) -> httpx.Auth | None: ) elif auth_type == "oauth2": - raise MCPAuthenticationError( - "Authentication type 'oauth2' not yet supported" + if auth.grant_type is None: + raise MCPAuthenticationError("OAuth2 auth requires 'grant_type' field") + return OAuth2Auth( + grant_type=auth.grant_type, + token_url=auth.token_url, + refresh_url=auth.refresh_url, + client_id=auth.client_id, + client_secret=auth.client_secret, + username=auth.username, + password=auth.password, + refresh_token=auth.refresh_token, + assertion=auth.assertion, + scopes=auth.scopes, ) else: From fc855009b9d816ac2f157ee6311d6808c6d6e4f8 Mon Sep 17 00:00:00 2001 From: Thareesha98 Date: Sun, 14 Jun 2026 20:47:17 +0530 Subject: [PATCH 8/8] feat(auth): add oauth2 auth tests Standardize AFM authentication field schemas (issue #35), Phase 3. - cover parse-time validation for all four grants (valid configs, missing grant_type, unsupported grant, missing required fields, fields not allowed for the grant, case-insensitive grant_type) and the runtime token-exchange mapping, in both runtimes. --- ballerina-interpreter/tests/main_test.bal | 95 +++++++++++++++++-- .../packages/afm-core/tests/test_parser.py | 93 +++++++++++++++++- .../packages/afm-langchain/tests/test_mcp.py | 91 +++++++++++++++++- 3 files changed, 268 insertions(+), 11 deletions(-) diff --git a/ballerina-interpreter/tests/main_test.bal b/ballerina-interpreter/tests/main_test.bal index 35bc35a..fb37d56 100644 --- a/ballerina-interpreter/tests/main_test.bal +++ b/ballerina-interpreter/tests/main_test.bal @@ -1162,16 +1162,39 @@ function testMapToHttpClientAuthBearer() returns error? { } @test:Config -function testMapToHttpClientAuthOAuth2NotSupported() { +function testMapToHttpClientAuthOAuth2ClientCredentials() returns error? { ClientAuthentication auth = { - 'type: "oauth2" + 'type: "oauth2", + "grant_type": "client_credentials", + "token_url": "https://auth.example.com/token", + "client_id": "id", + "client_secret": "secret", + "scopes": ["read", "write"] }; - http:ClientAuthConfig|error? result = mapToHttpClientAuth(auth); - if result is http:ClientAuthConfig? { - test:assertFail("Expected error for OAuth2 authentication"); - } - test:assertEquals(result.message(), "OAuth2 authentication not yet supported"); + http:ClientAuthConfig? result = check mapToHttpClientAuth(auth); + test:assertTrue(result is http:OAuth2ClientCredentialsGrantConfig); + http:OAuth2ClientCredentialsGrantConfig grantConfig = result; + test:assertEquals(grantConfig.tokenUrl, "https://auth.example.com/token"); + test:assertEquals(grantConfig.clientId, "id"); +} + +@test:Config +function testMapToHttpClientAuthOAuth2RefreshToken() returns error? { + ClientAuthentication auth = { + 'type: "oauth2", + "grant_type": "refresh_token", + "refresh_url": "https://auth.example.com/token", + "refresh_token": "rt", + "client_id": "id", + "client_secret": "secret" + }; + + http:ClientAuthConfig? result = check mapToHttpClientAuth(auth); + test:assertTrue(result is http:OAuth2RefreshTokenGrantConfig); + http:OAuth2RefreshTokenGrantConfig grantConfig = result; + test:assertEquals(grantConfig.refreshUrl, "https://auth.example.com/token"); + test:assertEquals(grantConfig.refreshToken, "rt"); } @test:Config @@ -1287,6 +1310,64 @@ function testValidateAuthenticationJwtMissingField() { test:assertEquals((result).message(), "type 'jwt' requires 'audience' field"); } +@test:Config +function testValidateAuthenticationOAuth2Valid() returns error? { + ClientAuthentication auth = { + 'type: "oauth2", + "grant_type": "client_credentials", + "token_url": "https://auth.example.com/token", + "client_id": "id", + "client_secret": "secret", + "scopes": ["read"] + }; + error? result = validateAuthentication(auth); + test:assertTrue(result is ()); +} + +@test:Config +function testValidateAuthenticationOAuth2MissingGrantType() { + ClientAuthentication auth = {'type: "oauth2", "token_url": "u"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is error); + test:assertEquals((result).message(), "type 'oauth2' requires 'grant_type' field"); +} + +@test:Config +function testValidateAuthenticationOAuth2UnknownGrantType() { + ClientAuthentication auth = {'type: "oauth2", "grant_type": "device_code", "token_url": "u"}; + error? result = validateAuthentication(auth); + test:assertTrue(result is error); + test:assertTrue((result).message().includes("grant_type 'device_code' is not supported")); +} + +@test:Config +function testValidateAuthenticationOAuth2MissingRequiredField() { + ClientAuthentication auth = { + 'type: "oauth2", + "grant_type": "client_credentials", + "token_url": "u", + "client_id": "id" + }; + error? result = validateAuthentication(auth); + test:assertTrue(result is error); + test:assertEquals((result).message(), "oauth2 grant_type 'client_credentials' requires 'client_secret' field"); +} + +@test:Config +function testValidateAuthenticationOAuth2FieldNotAllowed() { + ClientAuthentication auth = { + 'type: "oauth2", + "grant_type": "client_credentials", + "token_url": "u", + "client_id": "id", + "client_secret": "secret", + "refresh_token": "rt" + }; + error? result = validateAuthentication(auth); + test:assertTrue(result is error); + test:assertEquals((result).message(), "oauth2 grant_type 'client_credentials' does not support 'refresh_token' field"); +} + @test:Config function testValidateAuthenticationUnknownType() { ClientAuthentication auth = {'type: "token", "token": "t"}; diff --git a/python-interpreter/packages/afm-core/tests/test_parser.py b/python-interpreter/packages/afm-core/tests/test_parser.py index e5b453a..6f294dc 100644 --- a/python-interpreter/packages/afm-core/tests/test_parser.py +++ b/python-interpreter/packages/afm-core/tests/test_parser.py @@ -550,9 +550,96 @@ def test_jwt_unknown_field_rejected(self) -> None: type="jwt", issuer="i", audience="a", signing_key="s", token="x" ) - def test_oauth2_recognized_without_field_validation(self) -> None: - auth = ClientAuthentication(type="oauth2") - assert auth.type == "oauth2" + def test_oauth2_client_credentials_valid(self) -> None: + auth = ClientAuthentication( + type="oauth2", + grant_type="client_credentials", + token_url="https://auth.example.com/token", + client_id="id", + client_secret="secret", + scopes=["read", "write"], + ) + assert auth.grant_type == "client_credentials" + assert auth.scopes == ["read", "write"] + + def test_oauth2_password_valid(self) -> None: + auth = ClientAuthentication( + type="oauth2", + grant_type="password", + token_url="https://auth.example.com/token", + username="u", + password="p", + client_id="id", + client_secret="secret", + ) + assert auth.grant_type == "password" + + def test_oauth2_refresh_token_valid(self) -> None: + auth = ClientAuthentication( + type="oauth2", + grant_type="refresh_token", + refresh_url="https://auth.example.com/token", + refresh_token="rt", + client_id="id", + client_secret="secret", + ) + assert auth.grant_type == "refresh_token" + + def test_oauth2_jwt_bearer_valid(self) -> None: + auth = ClientAuthentication( + type="oauth2", + grant_type="jwt_bearer", + token_url="https://auth.example.com/token", + assertion="signed.jwt.token", + ) + assert auth.grant_type == "jwt_bearer" + + def test_oauth2_grant_type_case_insensitive(self) -> None: + auth = ClientAuthentication( + type="oauth2", + grant_type="Client_Credentials", + token_url="u", + client_id="id", + client_secret="secret", + ) + assert auth.grant_type == "Client_Credentials" + + def test_oauth2_missing_grant_type_rejected(self) -> None: + with pytest.raises( + ValidationError, match="type 'oauth2' requires 'grant_type'" + ): + ClientAuthentication(type="oauth2", token_url="u") + + def test_oauth2_unknown_grant_type_rejected(self) -> None: + with pytest.raises( + ValidationError, match="grant_type 'device_code' is not supported" + ): + ClientAuthentication( + type="oauth2", grant_type="device_code", token_url="u" + ) + + def test_oauth2_missing_required_field_rejected(self) -> None: + with pytest.raises( + ValidationError, + match="grant_type 'client_credentials' requires 'client_secret'", + ): + ClientAuthentication( + type="oauth2", + grant_type="client_credentials", + token_url="u", + client_id="id", + ) + + def test_oauth2_field_not_allowed_for_grant_rejected(self) -> None: + with pytest.raises(ValidationError, match="does not support 'refresh_token'"): + ClientAuthentication( + type="oauth2", + grant_type="client_credentials", + token_url="u", + client_id="id", + client_secret="secret", + refresh_token="rt", + ) def test_unknown_type_rejected(self) -> None: with pytest.raises(ValidationError, match="unknown authentication type 'token'"): diff --git a/python-interpreter/packages/afm-langchain/tests/test_mcp.py b/python-interpreter/packages/afm-langchain/tests/test_mcp.py index 3cf0508..502a989 100644 --- a/python-interpreter/packages/afm-langchain/tests/test_mcp.py +++ b/python-interpreter/packages/afm-langchain/tests/test_mcp.py @@ -14,6 +14,7 @@ # specific language governing permissions and limitations # under the License. +import time from typing import cast from unittest.mock import AsyncMock, MagicMock, patch @@ -40,6 +41,7 @@ JwtAuth, MCPClient, MCPManager, + OAuth2Auth, build_httpx_auth, filter_tools, ) @@ -59,7 +61,13 @@ def make_mcp_server( elif auth_type == "api-key": auth = ClientAuthentication(type="api-key", api_key="test-api-key") elif auth_type == "oauth2": - auth = ClientAuthentication(type="oauth2") + auth = ClientAuthentication( + type="oauth2", + grant_type="client_credentials", + token_url="https://auth.example.com/token", + client_id="id", + client_secret="secret", + ) elif auth_type == "jwt": auth = ClientAuthentication( type="jwt", @@ -239,6 +247,87 @@ def test_jwt_auth_signs_rs256_with_key_file(self, tmp_path): ) assert decoded["iss"] == "afm-agent" + def test_oauth2_returns_oauth2_auth_instance(self): + auth = ClientAuthentication( + type="oauth2", + grant_type="client_credentials", + token_url="https://auth.example.com/token", + client_id="id", + client_secret="secret", + scopes=["read"], + ) + result = build_httpx_auth(auth) + assert isinstance(result, OAuth2Auth) + assert result.grant_type == "client_credentials" + + def test_oauth2_client_credentials_token_request(self): + result = build_httpx_auth( + ClientAuthentication( + type="oauth2", + grant_type="client_credentials", + token_url="https://auth.example.com/token", + client_id="id", + client_secret="secret", + scopes=["read", "write"], + ) + ) + assert isinstance(result, OAuth2Auth) + url, data, basic = result._token_request() + assert url == "https://auth.example.com/token" + assert data["grant_type"] == "client_credentials" + assert data["scope"] == "read write" + assert basic == ("id", "secret") + + def test_oauth2_refresh_token_uses_refresh_url(self): + result = build_httpx_auth( + ClientAuthentication( + type="oauth2", + grant_type="refresh_token", + refresh_url="https://auth.example.com/refresh", + refresh_token="rt", + client_id="id", + client_secret="secret", + ) + ) + assert isinstance(result, OAuth2Auth) + url, data, _basic = result._token_request() + assert url == "https://auth.example.com/refresh" + assert data["grant_type"] == "refresh_token" + assert data["refresh_token"] == "rt" + + def test_oauth2_jwt_bearer_token_request(self): + result = build_httpx_auth( + ClientAuthentication( + type="oauth2", + grant_type="jwt_bearer", + token_url="https://auth.example.com/token", + assertion="signed.jwt", + ) + ) + assert isinstance(result, OAuth2Auth) + _url, data, basic = result._token_request() + assert data["grant_type"] == "urn:ietf:params:oauth:grant-type:jwt-bearer" + assert data["assertion"] == "signed.jwt" + assert basic is None # no client credentials provided + + def test_oauth2_uses_cached_token(self): + result = build_httpx_auth( + ClientAuthentication( + type="oauth2", + grant_type="client_credentials", + token_url="u", + client_id="id", + client_secret="secret", + ) + ) + assert isinstance(result, OAuth2Auth) + result._token = "cached-token" + result._expires_at = time.time() + 1000 + request = httpx.Request("GET", "https://api.example.com/resource") + flow = result.sync_auth_flow(request) + next(flow) + assert request.headers["Authorization"] == "Bearer cached-token" + class TestFilterTools: def test_no_filter_returns_all_tools(self):