diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index e7dd4868..ff5387ef 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -3,7 +3,7 @@ import logging from collections.abc import Callable -from typing import Any +from typing import Any, cast import httpx @@ -20,6 +20,8 @@ AgentInterface, ) from a2a.utils.constants import ( + PROTOCOL_VERSION_CURRENT, + VERSION_HEADER, TransportProtocol, ) @@ -65,6 +67,11 @@ def __init__( ): if consumers is None: consumers = [] + + client = config.httpx_client or httpx.AsyncClient() + client.headers.setdefault(VERSION_HEADER, PROTOCOL_VERSION_CURRENT) + config.httpx_client = client + self._config = config self._consumers = consumers self._registry: dict[str, TransportProducer] = {} @@ -72,11 +79,12 @@ def __init__( def _register_defaults(self, supported: list[str]) -> None: # Empty support list implies JSON-RPC only. + if TransportProtocol.JSONRPC in supported or not supported: self.register( TransportProtocol.JSONRPC, lambda card, url, config, interceptors: JsonRpcTransport( - config.httpx_client or httpx.AsyncClient(), + cast('httpx.AsyncClient', config.httpx_client), card, url, interceptors, @@ -87,7 +95,7 @@ def _register_defaults(self, supported: list[str]) -> None: self.register( TransportProtocol.HTTP_JSON, lambda card, url, config, interceptors: RestTransport( - config.httpx_client or httpx.AsyncClient(), + cast('httpx.AsyncClient', config.httpx_client), card, url, interceptors, diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index 3815d722..ffae90d8 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -43,6 +43,7 @@ Task, TaskPushNotificationConfig, ) +from a2a.utils.constants import PROTOCOL_VERSION_CURRENT, VERSION_HEADER from a2a.utils.telemetry import SpanKind, trace_class @@ -303,11 +304,14 @@ async def close(self) -> None: def _get_grpc_metadata( self, extensions: list[str] | None = None, - ) -> list[tuple[str, str]] | None: + ) -> list[tuple[str, str]]: """Creates gRPC metadata for extensions.""" + metadata = [(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT)] + extensions_to_use = extensions or self.extensions if extensions_to_use: - return [ + metadata.append( (HTTP_EXTENSION_HEADER.lower(), ','.join(extensions_to_use)) - ] - return None + ) + + return metadata diff --git a/src/a2a/utils/constants.py b/src/a2a/utils/constants.py index b90b390d..65d6598f 100644 --- a/src/a2a/utils/constants.py +++ b/src/a2a/utils/constants.py @@ -22,3 +22,7 @@ class TransportProtocol(str, Enum): DEFAULT_MAX_CONTENT_LENGTH = 10 * 1024 * 1024 # 10MB JSONRPC_PARSE_ERROR_CODE = -32700 +VERSION_HEADER = 'A2A-Version' + +PROTOCOL_VERSION_1_0 = '1.0' +PROTOCOL_VERSION_CURRENT = PROTOCOL_VERSION_1_0 diff --git a/tests/client/transports/test_grpc_client.py b/tests/client/transports/test_grpc_client.py index f6615d17..3fd45b6f 100644 --- a/tests/client/transports/test_grpc_client.py +++ b/tests/client/transports/test_grpc_client.py @@ -5,6 +5,7 @@ from a2a.client.transports.grpc import GrpcTransport from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.utils.constants import VERSION_HEADER, PROTOCOL_VERSION_CURRENT from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import ( AgentCapabilities, @@ -217,10 +218,11 @@ async def test_send_message_task_response( mock_grpc_stub.SendMessage.assert_awaited_once() _, kwargs = mock_grpc_stub.SendMessage.call_args assert kwargs['metadata'] == [ + (VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT), ( HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v3', - ) + ), ] assert response.HasField('task') assert response.task.id == sample_task.id @@ -266,10 +268,11 @@ async def test_send_message_message_response( mock_grpc_stub.SendMessage.assert_awaited_once() _, kwargs = mock_grpc_stub.SendMessage.call_args assert kwargs['metadata'] == [ + (VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT), ( HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', - ) + ), ] assert response.HasField('message') assert response.message.message_id == sample_message.message_id @@ -315,10 +318,11 @@ async def test_send_message_streaming( # noqa: PLR0913 mock_grpc_stub.SendStreamingMessage.assert_called_once() _, kwargs = mock_grpc_stub.SendStreamingMessage.call_args assert kwargs['metadata'] == [ + (VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT), ( HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', - ) + ), ] # Responses are StreamResponse proto objects assert responses[0].HasField('message') @@ -350,10 +354,11 @@ async def test_get_task( mock_grpc_stub.GetTask.assert_awaited_once_with( a2a_pb2.GetTaskRequest(id=f'{sample_task.id}', history_length=None), metadata=[ + (VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT), ( HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', - ) + ), ], ) assert response.id == sample_task.id @@ -378,10 +383,11 @@ async def test_list_tasks( mock_grpc_stub.ListTasks.assert_awaited_once_with( params, metadata=[ + (VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT), ( HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', - ) + ), ], ) assert result.total_size == 2 @@ -405,10 +411,11 @@ async def test_get_task_with_history( id=f'{sample_task.id}', history_length=history_len ), metadata=[ + (VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT), ( HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', - ) + ), ], ) @@ -433,7 +440,8 @@ async def test_cancel_task( mock_grpc_stub.CancelTask.assert_awaited_once_with( a2a_pb2.CancelTaskRequest(id=f'{sample_task.id}'), metadata=[ - (HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v3') + (VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT), + (HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v3'), ], ) assert response.status.state == TaskState.TASK_STATE_CANCELED @@ -462,10 +470,11 @@ async def test_create_task_push_notification_config_with_valid_task( mock_grpc_stub.CreateTaskPushNotificationConfig.assert_awaited_once_with( request, metadata=[ + (VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT), ( HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', - ) + ), ], ) assert response.task_id == sample_task_push_notification_config.task_id @@ -524,10 +533,11 @@ async def test_get_task_push_notification_config_with_valid_task( id=config_id, ), metadata=[ + (VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT), ( HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', - ) + ), ], ) assert response.task_id == sample_task_push_notification_config.task_id @@ -577,10 +587,11 @@ async def test_list_task_push_notification_configs( mock_grpc_stub.ListTaskPushNotificationConfigs.assert_awaited_once_with( a2a_pb2.ListTaskPushNotificationConfigsRequest(task_id='task-1'), metadata=[ + (VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT), ( HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', - ) + ), ], ) assert len(response.configs) == 1 @@ -609,10 +620,11 @@ async def test_delete_task_push_notification_config( id='config-1', ), metadata=[ + (VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT), ( HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', - ) + ), ], ) @@ -623,32 +635,47 @@ async def test_delete_task_push_notification_config( ( None, None, - None, + [(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT)], ), # Case 1: No initial, No input ( ['ext1'], None, - [(HTTP_EXTENSION_HEADER.lower(), 'ext1')], + [ + (VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT), + (HTTP_EXTENSION_HEADER.lower(), 'ext1'), + ], ), # Case 2: Initial, No input ( None, ['ext2'], - [(HTTP_EXTENSION_HEADER.lower(), 'ext2')], + [ + (VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT), + (HTTP_EXTENSION_HEADER.lower(), 'ext2'), + ], ), # Case 3: No initial, Input ( ['ext1'], ['ext2'], - [(HTTP_EXTENSION_HEADER.lower(), 'ext2')], + [ + (VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT), + (HTTP_EXTENSION_HEADER.lower(), 'ext2'), + ], ), # Case 4: Initial, Input (override) ( ['ext1'], ['ext2', 'ext3'], - [(HTTP_EXTENSION_HEADER.lower(), 'ext2,ext3')], + [ + (VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT), + (HTTP_EXTENSION_HEADER.lower(), 'ext2,ext3'), + ], ), # Case 5: Initial, Multiple inputs (override) ( ['ext1', 'ext2'], ['ext3'], - [(HTTP_EXTENSION_HEADER.lower(), 'ext3')], + [ + (VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT), + (HTTP_EXTENSION_HEADER.lower(), 'ext3'), + ], ), # Case 6: Multiple initial, Single input (override) ], ) diff --git a/tests/utils/test_constants.py b/tests/utils/test_constants.py index 4208268d..1c427b3f 100644 --- a/tests/utils/test_constants.py +++ b/tests/utils/test_constants.py @@ -13,3 +13,14 @@ def test_agent_card_constants(): def test_default_rpc_url(): """Test default RPC URL constant.""" assert constants.DEFAULT_RPC_URL == '/' + + +def test_version_header(): + """Test version header constant.""" + assert constants.VERSION_HEADER == 'A2A-Version' + + +def test_protocol_versions(): + """Test protocol version constants.""" + assert constants.PROTOCOL_VERSION_1_0 == '1.0' + assert constants.PROTOCOL_VERSION_CURRENT == '1.0'