Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 36 additions & 13 deletions src/a2a/utils/proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
This module provides helper functions for common proto type operations.
"""

import logging

from typing import TYPE_CHECKING, Any, TypedDict

from google.api.field_behavior_pb2 import FieldBehavior, field_behavior
from google.protobuf import __version__ as protobuf_version
from google.protobuf.descriptor import FieldDescriptor
from google.protobuf.json_format import ParseDict
from google.protobuf.message import Message as ProtobufMessage
Expand All @@ -28,6 +31,35 @@
from a2a.utils.errors import InvalidParamsError


logger = logging.getLogger(__name__)

# FieldDescriptor.is_repeated was introduced in protobuf 4.0; field.label was
# removed in protobuf 7.0. Check once at import time so _is_field_repeated()
# avoids a per-call hasattr probe on a hot path.

_PROTOBUF_HAS_IS_REPEATED: bool = hasattr(FieldDescriptor, 'is_repeated')

logger.debug(
'Protobuf %s: using %s API',
protobuf_version,
'field.is_repeated'
if _PROTOBUF_HAS_IS_REPEATED
else 'deprecated field.label',
)


def _is_field_repeated(field: FieldDescriptor) -> bool:
"""Return True if *field* is a repeated field.

Uses ``field.is_repeated`` (protobuf ≥ 4.0) when available, and falls back
to ``field.label == FieldDescriptor.LABEL_REPEATED`` for older versions.
See https://github.com/a2aproject/a2a-python/issues/1011.
"""
if _PROTOBUF_HAS_IS_REPEATED:
return field.is_repeated # type: ignore[attr-defined]
return field.label == FieldDescriptor.LABEL_REPEATED # type: ignore[attr-defined]


if TYPE_CHECKING:
from starlette.datastructures import QueryParams
else:
Expand All @@ -36,7 +68,7 @@
except ImportError:
QueryParams = Any

from a2a.types.a2a_pb2 import (
from a2a.types.a2a_pb2 import ( # noqa: E402
Message,
StreamResponse,
Task,
Expand Down Expand Up @@ -174,10 +206,7 @@ def parse_params(params: QueryParams, message: ProtobufMessage) -> None:
field = fields[k]
v_list = params.getlist(k)

# TODO(https://github.com/a2aproject/a2a-python/issues/1011): Replace
# deprecated `field.label` with `field.is_repeated` once the minimum
# protobuf version requirement is bumped.
if field.label == FieldDescriptor.LABEL_REPEATED:
if _is_field_repeated(field):
accumulated: list[Any] = []
for v in v_list:
if not v:
Expand Down Expand Up @@ -211,10 +240,7 @@ def _check_required_field_violation(
) -> ValidationDetail | None:
"""Check if a required field is missing or invalid."""
val = getattr(msg, field.name)
# TODO(https://github.com/a2aproject/a2a-python/issues/1011): Replace
# deprecated `field.label` with `field.is_repeated` once the minimum
# protobuf version requirement is bumped.
if field.label == FieldDescriptor.LABEL_REPEATED:
if _is_field_repeated(field):
if not val:
return ValidationDetail(
field=field.name,
Expand Down Expand Up @@ -255,10 +281,7 @@ def _recurse_validation(
return errors

val = getattr(msg, field.name)
# TODO(https://github.com/a2aproject/a2a-python/issues/1011): Replace
# deprecated `field.label` with `field.is_repeated` once the minimum
# protobuf version requirement is bumped.
if field.label != FieldDescriptor.LABEL_REPEATED:
if not _is_field_repeated(field):
if msg.HasField(field.name):
sub_errs = _validate_proto_required_fields_internal(val)
_append_nested_errors(errors, field.name, sub_errs)
Expand Down
129 changes: 128 additions & 1 deletion tests/utils/test_proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
This module tests the proto utilities including to_stream_response and dictionary normalization.
"""

from unittest.mock import patch

import httpx
import pytest

from a2a.types.a2a_pb2 import (
AgentCard,
AgentSkill,
ListTasksRequest,
Message,
Part,
Role,
SecurityScheme,
StreamResponse,
Task,
TaskArtifactUpdateEvent,
Expand All @@ -24,6 +28,7 @@
from google.protobuf.json_format import MessageToDict, Parse
from google.protobuf.message import Message as ProtobufMessage
from google.protobuf.timestamp_pb2 import Timestamp
from google.rpc import error_details_pb2
from starlette.datastructures import QueryParams


Expand Down Expand Up @@ -236,9 +241,10 @@ def test_repeated_fields_parsing(self, query_string: str):
def _message_to_rest_params(self, message: ProtobufMessage) -> QueryParams:
"""Converts a message to REST query parameters."""
rest_dict = MessageToDict(message)
return httpx.Request(
httpx_params = httpx.Request(
'GET', 'http://api.example.com', params=rest_dict
).url.params
return QueryParams(str(httpx_params))


class TestValidateProtoRequiredFields:
Expand Down Expand Up @@ -276,3 +282,124 @@ def test_nested_required_fields(self):

fields = [e['field'] for e in errors]
assert 'status.state' in fields


class TestIsFieldRepeated:
"""Tests for the _is_field_repeated helper, including the legacy fallback."""

def test_repeated_field_fallback_path(self):
"""Uses the legacy field.label path when is_repeated is unavailable."""
tags_field = AgentSkill.DESCRIPTOR.fields_by_name['tags']
with patch('a2a.utils.proto_utils._PROTOBUF_HAS_IS_REPEATED', False):
assert proto_utils._is_field_repeated(tags_field) is True

def test_non_repeated_field_fallback_path(self):
"""Legacy field.label path returns False for a non-repeated field."""
id_field = AgentSkill.DESCRIPTOR.fields_by_name['id']
with patch('a2a.utils.proto_utils._PROTOBUF_HAS_IS_REPEATED', False):
assert proto_utils._is_field_repeated(id_field) is False


class TestParseParamsEdgeCases:
"""Edge-case tests for parse_params to cover missing branches."""

def test_unknown_key_is_ignored(self):
"""Unknown query param keys are silently ignored; known keys are still parsed."""
msg = ListTasksRequest()
proto_utils.parse_params(QueryParams('unknownKey=value&tenant=t1'), msg)
assert msg.tenant == 't1'

def test_repeated_field_skips_empty_string(self):
"""Empty string values in a repeated field are skipped rather than accumulated."""
msg = AgentSkill()
proto_utils.parse_params(QueryParams('id=s1&tags=&tags=tag1'), msg)
assert list(msg.tags) == ['tag1']

def test_repeated_field_non_string_value(self):
"""Non-string values in a repeated field are appended directly without splitting."""

class _MockParams:
def keys(self):
return ['tags']

def getlist(self, _key):
return ['tag1', 42] # 42 is a non-string

msg = AgentSkill()
with patch('a2a.utils.proto_utils.ParseDict') as mock_parse:
proto_utils.parse_params(_MockParams(), msg) # type: ignore[arg-type]
# 42 should be appended directly (not split as a string)
mock_parse.assert_called_once_with(
{'tags': ['tag1', 42]}, msg, ignore_unknown_fields=True
)


class TestValidationEdgeCases:
"""Additional validation tests to cover missing branches."""

def test_required_message_field_not_set(self):
"""A REQUIRED message field with presence that is not set produces a validation error."""
# Task.status is REQUIRED + has_presence; omitting it hits the branch.
task = Task(id='task-1', context_id='ctx-1')
with pytest.raises(InvalidParamsError) as exc_info:
proto_utils.validate_proto_required_fields(task)

errors = (
exc_info.value.data.get('errors', []) if exc_info.value.data else []
)
fields = [e['field'] for e in errors]
assert 'status' in fields

def test_map_field_recurse_validation(self):
"""Map entry fields are recursively validated when populated."""
# AgentCard.security_schemes is a map<string, SecurityScheme>.
# Populating it causes _recurse_validation to enter the map_entry branch.
card = AgentCard()
card.security_schemes['myScheme'].CopyFrom(SecurityScheme())
# We only need the code path to execute; errors from other required
# fields on AgentCard are expected.
errors = proto_utils._validate_proto_required_fields_internal(card)
# The map branch ran; verify no crash and we got some errors.
assert isinstance(errors, list)


class TestBadRequestConversions:
"""Tests for validation_errors_to_bad_request and bad_request_to_validation_errors."""

def test_validation_errors_to_bad_request(self):
"""Lines 334-339: convert ValidationDetail list to BadRequest proto."""
errors: list[proto_utils.ValidationDetail] = [
proto_utils.ValidationDetail(field='foo', message='required'),
proto_utils.ValidationDetail(field='bar', message='invalid'),
]
bad_request = proto_utils.validation_errors_to_bad_request(errors)

assert isinstance(bad_request, error_details_pb2.BadRequest)
assert len(bad_request.field_violations) == 2
assert bad_request.field_violations[0].field == 'foo'
assert bad_request.field_violations[0].description == 'required'
assert bad_request.field_violations[1].field == 'bar'
assert bad_request.field_violations[1].description == 'invalid'

def test_bad_request_to_validation_errors(self):
"""Converts a BadRequest proto back to a ValidationDetail list."""
bad_request = error_details_pb2.BadRequest()
v = bad_request.field_violations.add()
v.field = 'baz'
v.description = 'must be set'

errors = proto_utils.bad_request_to_validation_errors(bad_request)

assert len(errors) == 1
assert errors[0]['field'] == 'baz'
assert errors[0]['message'] == 'must be set'

def test_bad_request_roundtrip(self):
"""Roundtrip: ValidationDetail -> BadRequest -> ValidationDetail."""
original: list[proto_utils.ValidationDetail] = [
proto_utils.ValidationDetail(field='x', message='err'),
]
restored = proto_utils.bad_request_to_validation_errors(
proto_utils.validation_errors_to_bad_request(original)
)
assert restored == original
Loading