diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index b191f98e0..b219ed865 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -17,9 +17,12 @@ This module provides helper functions for common proto type operations. """ +import logging + from typing import TYPE_CHECKING, Any, TypedDict from google.api.field_behavior_pb2 import FieldBehavior, field_behavior +from google.protobuf import __version__ as protobuf_version from google.protobuf.descriptor import FieldDescriptor from google.protobuf.json_format import ParseDict from google.protobuf.message import Message as ProtobufMessage @@ -28,6 +31,35 @@ from a2a.utils.errors import InvalidParamsError +logger = logging.getLogger(__name__) + +# FieldDescriptor.is_repeated was introduced in protobuf 4.0; field.label was +# removed in protobuf 7.0. Check once at import time so _is_field_repeated() +# avoids a per-call hasattr probe on a hot path. + +_PROTOBUF_HAS_IS_REPEATED: bool = hasattr(FieldDescriptor, 'is_repeated') + +logger.debug( + 'Protobuf %s: using %s API', + protobuf_version, + 'field.is_repeated' + if _PROTOBUF_HAS_IS_REPEATED + else 'deprecated field.label', +) + + +def _is_field_repeated(field: FieldDescriptor) -> bool: + """Return True if *field* is a repeated field. + + Uses ``field.is_repeated`` (protobuf ≥ 4.0) when available, and falls back + to ``field.label == FieldDescriptor.LABEL_REPEATED`` for older versions. + See https://github.com/a2aproject/a2a-python/issues/1011. + """ + if _PROTOBUF_HAS_IS_REPEATED: + return field.is_repeated # type: ignore[attr-defined] + return field.label == FieldDescriptor.LABEL_REPEATED # type: ignore[attr-defined] + + if TYPE_CHECKING: from starlette.datastructures import QueryParams else: @@ -36,7 +68,7 @@ except ImportError: QueryParams = Any -from a2a.types.a2a_pb2 import ( +from a2a.types.a2a_pb2 import ( # noqa: E402 Message, StreamResponse, Task, @@ -174,10 +206,7 @@ def parse_params(params: QueryParams, message: ProtobufMessage) -> None: field = fields[k] v_list = params.getlist(k) - # TODO(https://github.com/a2aproject/a2a-python/issues/1011): Replace - # deprecated `field.label` with `field.is_repeated` once the minimum - # protobuf version requirement is bumped. - if field.label == FieldDescriptor.LABEL_REPEATED: + if _is_field_repeated(field): accumulated: list[Any] = [] for v in v_list: if not v: @@ -211,10 +240,7 @@ def _check_required_field_violation( ) -> ValidationDetail | None: """Check if a required field is missing or invalid.""" val = getattr(msg, field.name) - # TODO(https://github.com/a2aproject/a2a-python/issues/1011): Replace - # deprecated `field.label` with `field.is_repeated` once the minimum - # protobuf version requirement is bumped. - if field.label == FieldDescriptor.LABEL_REPEATED: + if _is_field_repeated(field): if not val: return ValidationDetail( field=field.name, @@ -255,10 +281,7 @@ def _recurse_validation( return errors val = getattr(msg, field.name) - # TODO(https://github.com/a2aproject/a2a-python/issues/1011): Replace - # deprecated `field.label` with `field.is_repeated` once the minimum - # protobuf version requirement is bumped. - if field.label != FieldDescriptor.LABEL_REPEATED: + if not _is_field_repeated(field): if msg.HasField(field.name): sub_errs = _validate_proto_required_fields_internal(val) _append_nested_errors(errors, field.name, sub_errs) diff --git a/tests/utils/test_proto_utils.py b/tests/utils/test_proto_utils.py index db49dbf05..698a8c73b 100644 --- a/tests/utils/test_proto_utils.py +++ b/tests/utils/test_proto_utils.py @@ -3,15 +3,19 @@ This module tests the proto utilities including to_stream_response and dictionary normalization. """ +from unittest.mock import patch + import httpx import pytest from a2a.types.a2a_pb2 import ( + AgentCard, AgentSkill, ListTasksRequest, Message, Part, Role, + SecurityScheme, StreamResponse, Task, TaskArtifactUpdateEvent, @@ -24,6 +28,7 @@ from google.protobuf.json_format import MessageToDict, Parse from google.protobuf.message import Message as ProtobufMessage from google.protobuf.timestamp_pb2 import Timestamp +from google.rpc import error_details_pb2 from starlette.datastructures import QueryParams @@ -236,9 +241,10 @@ def test_repeated_fields_parsing(self, query_string: str): def _message_to_rest_params(self, message: ProtobufMessage) -> QueryParams: """Converts a message to REST query parameters.""" rest_dict = MessageToDict(message) - return httpx.Request( + httpx_params = httpx.Request( 'GET', 'http://api.example.com', params=rest_dict ).url.params + return QueryParams(str(httpx_params)) class TestValidateProtoRequiredFields: @@ -276,3 +282,124 @@ def test_nested_required_fields(self): fields = [e['field'] for e in errors] assert 'status.state' in fields + + +class TestIsFieldRepeated: + """Tests for the _is_field_repeated helper, including the legacy fallback.""" + + def test_repeated_field_fallback_path(self): + """Uses the legacy field.label path when is_repeated is unavailable.""" + tags_field = AgentSkill.DESCRIPTOR.fields_by_name['tags'] + with patch('a2a.utils.proto_utils._PROTOBUF_HAS_IS_REPEATED', False): + assert proto_utils._is_field_repeated(tags_field) is True + + def test_non_repeated_field_fallback_path(self): + """Legacy field.label path returns False for a non-repeated field.""" + id_field = AgentSkill.DESCRIPTOR.fields_by_name['id'] + with patch('a2a.utils.proto_utils._PROTOBUF_HAS_IS_REPEATED', False): + assert proto_utils._is_field_repeated(id_field) is False + + +class TestParseParamsEdgeCases: + """Edge-case tests for parse_params to cover missing branches.""" + + def test_unknown_key_is_ignored(self): + """Unknown query param keys are silently ignored; known keys are still parsed.""" + msg = ListTasksRequest() + proto_utils.parse_params(QueryParams('unknownKey=value&tenant=t1'), msg) + assert msg.tenant == 't1' + + def test_repeated_field_skips_empty_string(self): + """Empty string values in a repeated field are skipped rather than accumulated.""" + msg = AgentSkill() + proto_utils.parse_params(QueryParams('id=s1&tags=&tags=tag1'), msg) + assert list(msg.tags) == ['tag1'] + + def test_repeated_field_non_string_value(self): + """Non-string values in a repeated field are appended directly without splitting.""" + + class _MockParams: + def keys(self): + return ['tags'] + + def getlist(self, _key): + return ['tag1', 42] # 42 is a non-string + + msg = AgentSkill() + with patch('a2a.utils.proto_utils.ParseDict') as mock_parse: + proto_utils.parse_params(_MockParams(), msg) # type: ignore[arg-type] + # 42 should be appended directly (not split as a string) + mock_parse.assert_called_once_with( + {'tags': ['tag1', 42]}, msg, ignore_unknown_fields=True + ) + + +class TestValidationEdgeCases: + """Additional validation tests to cover missing branches.""" + + def test_required_message_field_not_set(self): + """A REQUIRED message field with presence that is not set produces a validation error.""" + # Task.status is REQUIRED + has_presence; omitting it hits the branch. + task = Task(id='task-1', context_id='ctx-1') + with pytest.raises(InvalidParamsError) as exc_info: + proto_utils.validate_proto_required_fields(task) + + errors = ( + exc_info.value.data.get('errors', []) if exc_info.value.data else [] + ) + fields = [e['field'] for e in errors] + assert 'status' in fields + + def test_map_field_recurse_validation(self): + """Map entry fields are recursively validated when populated.""" + # AgentCard.security_schemes is a map. + # Populating it causes _recurse_validation to enter the map_entry branch. + card = AgentCard() + card.security_schemes['myScheme'].CopyFrom(SecurityScheme()) + # We only need the code path to execute; errors from other required + # fields on AgentCard are expected. + errors = proto_utils._validate_proto_required_fields_internal(card) + # The map branch ran; verify no crash and we got some errors. + assert isinstance(errors, list) + + +class TestBadRequestConversions: + """Tests for validation_errors_to_bad_request and bad_request_to_validation_errors.""" + + def test_validation_errors_to_bad_request(self): + """Lines 334-339: convert ValidationDetail list to BadRequest proto.""" + errors: list[proto_utils.ValidationDetail] = [ + proto_utils.ValidationDetail(field='foo', message='required'), + proto_utils.ValidationDetail(field='bar', message='invalid'), + ] + bad_request = proto_utils.validation_errors_to_bad_request(errors) + + assert isinstance(bad_request, error_details_pb2.BadRequest) + assert len(bad_request.field_violations) == 2 + assert bad_request.field_violations[0].field == 'foo' + assert bad_request.field_violations[0].description == 'required' + assert bad_request.field_violations[1].field == 'bar' + assert bad_request.field_violations[1].description == 'invalid' + + def test_bad_request_to_validation_errors(self): + """Converts a BadRequest proto back to a ValidationDetail list.""" + bad_request = error_details_pb2.BadRequest() + v = bad_request.field_violations.add() + v.field = 'baz' + v.description = 'must be set' + + errors = proto_utils.bad_request_to_validation_errors(bad_request) + + assert len(errors) == 1 + assert errors[0]['field'] == 'baz' + assert errors[0]['message'] == 'must be set' + + def test_bad_request_roundtrip(self): + """Roundtrip: ValidationDetail -> BadRequest -> ValidationDetail.""" + original: list[proto_utils.ValidationDetail] = [ + proto_utils.ValidationDetail(field='x', message='err'), + ] + restored = proto_utils.bad_request_to_validation_errors( + proto_utils.validation_errors_to_bad_request(original) + ) + assert restored == original