From be01104f503cc5b7ee0c3a66d73265d308b1fe31 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 3 Mar 2026 16:40:53 +0000 Subject: [PATCH 01/14] feat: handle tenant in Client --- src/a2a/client/client_factory.py | 20 ++- src/a2a/client/transports/__init__.py | 3 +- src/a2a/client/transports/base.py | 162 +++++++++++++++++++ src/a2a/client/transports/rest.py | 81 ++++++++-- tests/client/test_base_client.py | 23 ++- tests/client/test_client_factory.py | 22 ++- tests/client/transports/test_rest_client.py | 164 ++++++++++++++++++++ tests/integration/test_tenant.py | 160 +++++++++++++++++++ 8 files changed, 606 insertions(+), 29 deletions(-) create mode 100644 tests/integration/test_tenant.py diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index e7dd48689..0177cba9e 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -11,7 +11,7 @@ from a2a.client.card_resolver import A2ACardResolver from a2a.client.client import Client, ClientConfig, Consumer from a2a.client.middleware import ClientCallInterceptor -from a2a.client.transports.base import ClientTransport +from a2a.client.transports.base import ClientTransport, TenantTransportDecorator from a2a.client.transports.jsonrpc import JsonRpcTransport from a2a.client.transports.rest import RestTransport from a2a.types.a2a_pb2 import ( @@ -208,10 +208,10 @@ def create( TransportProtocol.JSONRPC ] transport_protocol = None - transport_url = None + selected_interface = None if self._config.use_client_preference: for protocol_binding in client_set: - supported_interface = next( + selected_interface = next( ( si for si in card.supported_interfaces @@ -219,17 +219,16 @@ def create( ), None, ) - if supported_interface: + if selected_interface: transport_protocol = protocol_binding - transport_url = supported_interface.url break else: for supported_interface in card.supported_interfaces: if supported_interface.protocol_binding in client_set: transport_protocol = supported_interface.protocol_binding - transport_url = supported_interface.url + selected_interface = supported_interface break - if not transport_protocol or not transport_url: + if not transport_protocol or not selected_interface: raise ValueError('no compatible transports found.') if transport_protocol not in self._registry: raise ValueError(f'no client available for {transport_protocol}') @@ -244,9 +243,14 @@ def create( self._config.extensions = all_extensions transport = self._registry[transport_protocol]( - card, transport_url, self._config, interceptors or [] + card, selected_interface.url, self._config, interceptors or [] ) + if selected_interface.tenant: + transport = TenantTransportDecorator( + transport, selected_interface.tenant + ) + return BaseClient( card, self._config, diff --git a/src/a2a/client/transports/__init__.py b/src/a2a/client/transports/__init__.py index af7c60f62..8c17ac04b 100644 --- a/src/a2a/client/transports/__init__.py +++ b/src/a2a/client/transports/__init__.py @@ -1,6 +1,6 @@ """A2A Client Transports.""" -from a2a.client.transports.base import ClientTransport +from a2a.client.transports.base import ClientTransport, TenantTransportDecorator from a2a.client.transports.jsonrpc import JsonRpcTransport from a2a.client.transports.rest import RestTransport @@ -16,4 +16,5 @@ 'GrpcTransport', 'JsonRpcTransport', 'RestTransport', + 'TenantTransportDecorator', ] diff --git a/src/a2a/client/transports/base.py b/src/a2a/client/transports/base.py index f578ba3e3..28ce9257f 100644 --- a/src/a2a/client/transports/base.py +++ b/src/a2a/client/transports/base.py @@ -3,6 +3,7 @@ from types import TracebackType from typing_extensions import Self +from google.protobuf.message import Message from a2a.client.middleware import ClientCallContext from a2a.types.a2a_pb2 import ( @@ -158,3 +159,164 @@ async def get_extended_agent_card( @abstractmethod async def close(self) -> None: """Closes the transport.""" + + +class TenantTransportDecorator(ClientTransport): + """A transport decorator that attaches a tenant to all requests.""" + + def __init__(self, base: ClientTransport, tenant: str): + self._base = base + self._tenant = tenant + + def update_tenant(self, request: Message) -> str | None: + """Ensures the tenant is set on the request if provided and not already set. + + Returns: + The tenant used for the request. + """ + current_tenant = getattr(request, 'tenant', None) + if current_tenant: + return current_tenant + + if self._tenant and hasattr(request, 'tenant'): + request.tenant = self._tenant + return self._tenant + return None + + async def send_message( + self, + request: SendMessageRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> SendMessageResponse: + self.update_tenant(request) + return await self._base.send_message( + request, context=context, extensions=extensions + ) + + async def send_message_streaming( + self, + request: SendMessageRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> AsyncGenerator[StreamResponse]: + self.update_tenant(request) + async for event in self._base.send_message_streaming( + request, context=context, extensions=extensions + ): + yield event + + async def get_task( + self, + request: GetTaskRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> Task: + self.update_tenant(request) + return await self._base.get_task( + request, context=context, extensions=extensions + ) + + async def list_tasks( + self, + request: ListTasksRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> ListTasksResponse: + self.update_tenant(request) + return await self._base.list_tasks( + request, context=context, extensions=extensions + ) + + async def cancel_task( + self, + request: CancelTaskRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> Task: + self.update_tenant(request) + return await self._base.cancel_task( + request, context=context, extensions=extensions + ) + + async def set_task_callback( + self, + request: CreateTaskPushNotificationConfigRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> TaskPushNotificationConfig: + self.update_tenant(request) + return await self._base.set_task_callback( + request, context=context, extensions=extensions + ) + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> TaskPushNotificationConfig: + self.update_tenant(request) + return await self._base.get_task_callback( + request, context=context, extensions=extensions + ) + + async def list_task_callback( + self, + request: ListTaskPushNotificationConfigsRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> ListTaskPushNotificationConfigsResponse: + self.update_tenant(request) + return await self._base.list_task_callback( + request, context=context, extensions=extensions + ) + + async def delete_task_callback( + self, + request: DeleteTaskPushNotificationConfigRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> None: + self.update_tenant(request) + await self._base.delete_task_callback( + request, context=context, extensions=extensions + ) + + async def subscribe( + self, + request: SubscribeToTaskRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> AsyncGenerator[StreamResponse]: + self.update_tenant(request) + async for event in self._base.subscribe( + request, context=context, extensions=extensions + ): + yield event + + async def get_extended_agent_card( + self, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, + ) -> AgentCard: + return await self._base.get_extended_agent_card( + context=context, + extensions=extensions, + signature_verifier=signature_verifier, + ) + + async def close(self) -> None: + await self._base.close() diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 3699f9feb..26c0ca8f1 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -96,6 +96,14 @@ async def _prepare_send_message( ) return payload, modified_kwargs + def _get_path(self, base_path: str, tenant: str | None) -> str: + """Returns the full path, prepending the tenant if provided.""" + return f'/{tenant}{base_path}' if tenant else base_path + + def _pop_tenant(self, data: dict[str, Any]) -> str | None: + """Pops and returns the tenant from the dictionary if it exists.""" + return data.pop('tenant', None) + async def send_message( self, request: SendMessageRequest, @@ -107,8 +115,10 @@ async def send_message( payload, modified_kwargs = await self._prepare_send_message( request, context, extensions ) + tenant = self._pop_tenant(payload) or request.tenant + path = self._get_path('/v1/message:send', tenant) response_data = await self._send_post_request( - '/v1/message:send', payload, modified_kwargs + path, payload, modified_kwargs ) response: SendMessageResponse = ParseDict( response_data, SendMessageResponse() @@ -126,13 +136,15 @@ async def send_message_streaming( payload, modified_kwargs = await self._prepare_send_message( request, context, extensions ) + tenant = self._pop_tenant(payload) or request.tenant + path = self._get_path('/v1/message:stream', tenant) modified_kwargs.setdefault('timeout', None) async with aconnect_sse( self.httpx_client, 'POST', - f'{self.url}/v1/message:stream', + f'{self.url}{path}', json=payload, **modified_kwargs, ) as event_source: @@ -239,8 +251,11 @@ async def get_task( if 'id' in params: del params['id'] # id is part of the URL path, not query params + tenant = self._pop_tenant(params) or request.tenant + path = self._get_path(f'/v1/tasks/{request.id}', tenant) + response_data = await self._send_get_request( - f'/v1/tasks/{request.id}', + path, params, modified_kwargs, ) @@ -255,7 +270,7 @@ async def list_tasks( extensions: list[str] | None = None, ) -> ListTasksResponse: """Retrieves tasks for an agent.""" - _, modified_kwargs = await self._apply_interceptors( + payload, modified_kwargs = await self._apply_interceptors( MessageToDict(request, preserving_proto_field_name=True), self._get_http_args(context), context, @@ -264,9 +279,13 @@ async def list_tasks( modified_kwargs, extensions if extensions is not None else self.extensions, ) + + tenant = self._pop_tenant(payload) or request.tenant + path = self._get_path('/v1/tasks', tenant) + response_data = await self._send_get_request( - '/v1/tasks', - _model_to_query_params(request), + path, + payload, modified_kwargs, ) response: ListTasksResponse = ParseDict( @@ -292,8 +311,12 @@ async def cancel_task( modified_kwargs, context, ) + + tenant = self._pop_tenant(payload) or request.tenant + path = self._get_path(f'/v1/tasks/{request.id}:cancel', tenant) + response_data = await self._send_post_request( - f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs + path, payload, modified_kwargs ) response: Task = ParseDict(response_data, Task()) return response @@ -314,8 +337,14 @@ async def set_task_callback( payload, modified_kwargs = await self._apply_interceptors( payload, modified_kwargs, context ) + + tenant = self._pop_tenant(payload) or request.tenant + path = self._get_path( + f'/v1/tasks/{request.task_id}/pushNotificationConfigs', tenant + ) + response_data = await self._send_post_request( - f'/v1/tasks/{request.task_id}/pushNotificationConfigs', + path, payload, modified_kwargs, ) @@ -346,8 +375,15 @@ async def get_task_callback( del params['id'] if 'task_id' in params: del params['task_id'] - response_data = await self._send_get_request( + + tenant = self._pop_tenant(params) or request.tenant + path = self._get_path( f'/v1/tasks/{request.task_id}/pushNotificationConfigs/{request.id}', + tenant, + ) + + response_data = await self._send_get_request( + path, params, modified_kwargs, ) @@ -376,8 +412,14 @@ async def list_task_callback( ) if 'task_id' in params: del params['task_id'] + + tenant = self._pop_tenant(params) or request.tenant + path = self._get_path( + f'/v1/tasks/{request.task_id}/pushNotificationConfigs', tenant + ) + response_data = await self._send_get_request( - f'/v1/tasks/{request.task_id}/pushNotificationConfigs', + path, params, modified_kwargs, ) @@ -408,8 +450,15 @@ async def delete_task_callback( del params['id'] if 'task_id' in params: del params['task_id'] - await self._send_delete_request( + + tenant = self._pop_tenant(params) or request.tenant + path = self._get_path( f'/v1/tasks/{request.task_id}/pushNotificationConfigs/{request.id}', + tenant, + ) + + await self._send_delete_request( + path, params, modified_kwargs, ) @@ -428,10 +477,13 @@ async def subscribe( ) modified_kwargs.setdefault('timeout', None) + tenant = request.tenant + path = self._get_path(f'/v1/tasks/{request.id}:subscribe', tenant) + async with aconnect_sse( self.httpx_client, 'GET', - f'{self.url}/v1/tasks/{request.id}:subscribe', + f'{self.url}{path}', **modified_kwargs, ) as event_source: try: @@ -493,11 +545,6 @@ async def close(self) -> None: await self.httpx_client.aclose() -def _model_to_query_params(instance: Message) -> dict[str, str]: - data = MessageToDict(instance, preserving_proto_field_name=True) - return _json_to_query_params(data) - - def _json_to_query_params(data: dict[str, Any]) -> dict[str, str]: query_dict = {} for key, value in data.items(): diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index ce47b7ac1..14d9ef532 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -1,18 +1,20 @@ +from collections.abc import AsyncGenerator from unittest.mock import AsyncMock, MagicMock, patch import pytest from a2a.client.base_client import BaseClient from a2a.client.client import ClientConfig -from a2a.client.transports.base import ClientTransport +from a2a.client.transports.base import ClientTransport, TenantTransportDecorator from a2a.types.a2a_pb2 import ( AgentCapabilities, - AgentInterface, AgentCard, + AgentInterface, Message, Part, Role, SendMessageConfiguration, + SendMessageRequest, SendMessageResponse, StreamResponse, Task, @@ -276,3 +278,20 @@ async def create_stream(*args, **kwargs): assert params.configuration.history_length == 0 assert params.configuration.blocking is True assert params.configuration.accepted_output_modes == ['text/plain'] + + +@pytest.mark.asyncio +async def test_tenant_transport_decorator(mock_transport: AsyncMock) -> None: + tenant_id = 'test-tenant' + decorator = TenantTransportDecorator(mock_transport, tenant_id) + + request = SendMessageRequest( + message=Message(parts=[Part(text='hello')]), + ) + mock_transport.send_message.return_value = SendMessageResponse() + + await decorator.send_message(request) + + mock_transport.send_message.assert_called_once() + called_request = mock_transport.send_message.call_args[0][0] + assert called_request.tenant == tenant_id diff --git a/tests/client/test_client_factory.py b/tests/client/test_client_factory.py index 246406f2b..8ccc4e0b8 100644 --- a/tests/client/test_client_factory.py +++ b/tests/client/test_client_factory.py @@ -1,5 +1,6 @@ """Tests for the ClientFactory.""" +from collections.abc import AsyncGenerator from unittest.mock import AsyncMock, MagicMock, patch import typing @@ -8,7 +9,11 @@ from a2a.client import ClientConfig, ClientFactory from a2a.client.client_factory import TransportProducer -from a2a.client.transports import JsonRpcTransport, RestTransport +from a2a.client.transports import ( + JsonRpcTransport, + RestTransport, + TenantTransportDecorator, +) from a2a.types.a2a_pb2 import ( AgentCapabilities, AgentCard, @@ -284,3 +289,18 @@ async def test_client_factory_connect_with_consumers_and_interceptors( call_args = mock_base_client.call_args[0] assert call_args[3] == [consumer1] assert call_args[4] == [interceptor1] + + +def test_client_factory_applies_tenant_decorator(base_agent_card: AgentCard): + """Verify that the factory applies TenantTransportDecorator when tenant is present.""" + base_agent_card.supported_interfaces[0].tenant = 'my-tenant' + config = ClientConfig( + httpx_client=httpx.AsyncClient(), + supported_protocol_bindings=[TransportProtocol.JSONRPC], + ) + factory = ClientFactory(config) + client = factory.create(base_agent_card) + + assert isinstance(client._transport, TenantTransportDecorator) # type: ignore[attr-defined] + assert client._transport._tenant == 'my-tenant' # type: ignore[attr-defined] + assert isinstance(client._transport._base, JsonRpcTransport) # type: ignore[attr-defined] diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index 663d13284..566a44a16 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -15,10 +15,17 @@ AgentCapabilities, AgentCard, AgentInterface, + CancelTaskRequest, + CreateTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, ListTaskPushNotificationConfigsRequest, ListTaskPushNotificationConfigsResponse, + ListTasksRequest, + Message, SendMessageRequest, + SubscribeToTaskRequest, TaskPushNotificationConfig, ) from a2a.utils.constants import TransportProtocol @@ -362,3 +369,160 @@ async def test_delete_task_callback_success( f'/v1/tasks/{task_id}/pushNotificationConfigs/config-1' in call_args[0][1] ) + + +class TestRestTransportTenant: + """Tests for tenant path prepending in RestTransport.""" + + @pytest.mark.parametrize( + 'method_name, request_obj, expected_path', + [ + ( + 'send_message', + SendMessageRequest( + tenant='my-tenant', + message=create_text_message_object(content='hi'), + ), + '/my-tenant/v1/message:send', + ), + ( + 'list_tasks', + ListTasksRequest(tenant='my-tenant'), + '/my-tenant/v1/tasks', + ), + ( + 'get_task', + GetTaskRequest(tenant='my-tenant', id='task-123'), + '/my-tenant/v1/tasks/task-123', + ), + ( + 'cancel_task', + CancelTaskRequest(tenant='my-tenant', id='task-123'), + '/my-tenant/v1/tasks/task-123:cancel', + ), + ( + 'set_task_callback', + CreateTaskPushNotificationConfigRequest( + tenant='my-tenant', task_id='task-123' + ), + '/my-tenant/v1/tasks/task-123/pushNotificationConfigs', + ), + ( + 'get_task_callback', + GetTaskPushNotificationConfigRequest( + tenant='my-tenant', task_id='task-123', id='cfg-1' + ), + '/my-tenant/v1/tasks/task-123/pushNotificationConfigs/cfg-1', + ), + ( + 'list_task_callback', + ListTaskPushNotificationConfigsRequest( + tenant='my-tenant', task_id='task-123' + ), + '/my-tenant/v1/tasks/task-123/pushNotificationConfigs', + ), + ( + 'delete_task_callback', + DeleteTaskPushNotificationConfigRequest( + tenant='my-tenant', task_id='task-123', id='cfg-1' + ), + '/my-tenant/v1/tasks/task-123/pushNotificationConfigs/cfg-1', + ), + ], + ) + @pytest.mark.asyncio + async def test_rest_methods_prepend_tenant( + self, + method_name, + request_obj, + expected_path, + mock_httpx_client, + mock_agent_card, + ): + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + url='http://agent.example.com/api', + ) + + # 1. Get the method dynamically + method = getattr(client, method_name) + + # 2. Setup mocks + mock_httpx_client.build_request.return_value = MagicMock( + spec=httpx.Request + ) + mock_httpx_client.send.return_value = AsyncMock( + spec=httpx.Response, + status_code=200, + json=MagicMock(return_value={}), + ) + + # 3. Call the method + await method(request=request_obj) + + # 4. Verify the URL + args, _ = mock_httpx_client.build_request.call_args + assert args[1] == f'http://agent.example.com/api{expected_path}' + + @pytest.mark.parametrize( + 'method_name, request_obj, expected_path', + [ + ( + 'subscribe', + SubscribeToTaskRequest(tenant='my-tenant', id='task-123'), + '/my-tenant/v1/tasks/task-123:subscribe', + ), + ( + 'send_message_streaming', + SendMessageRequest( + tenant='my-tenant', + message=create_text_message_object(content='hi'), + ), + '/my-tenant/v1/message:stream', + ), + ], + ) + @pytest.mark.asyncio + @patch('a2a.client.transports.rest.aconnect_sse') + async def test_rest_streaming_methods_prepend_tenant( + self, + mock_aconnect_sse, + method_name, + request_obj, + expected_path, + mock_httpx_client, + mock_agent_card, + ): + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + url='http://agent.example.com/api', + ) + + # 1. Get the method dynamically + method = getattr(client, method_name) + + # 2. Setup mocks + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.response = MagicMock(spec=httpx.Response) + mock_event_source.response.raise_for_status.return_value = None + + async def empty_aiter(): + if False: + yield + + mock_event_source.aiter_sse.return_value = empty_aiter() + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + # 3. Call the method + async for _ in method(request=request_obj): + pass + + # 4. Verify the URL + mock_aconnect_sse.assert_called_once() + args, _ = mock_aconnect_sse.call_args + # url is 3rd positional argument in aconnect_sse(client, method, url, ...) + assert args[2] == f'http://agent.example.com/api{expected_path}' diff --git a/tests/integration/test_tenant.py b/tests/integration/test_tenant.py new file mode 100644 index 000000000..54c97ebec --- /dev/null +++ b/tests/integration/test_tenant.py @@ -0,0 +1,160 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +import httpx +from a2a.types.a2a_pb2 import ( + AgentCard, + AgentInterface, + SendMessageRequest, + Message, + GetTaskRequest, + AgentCapabilities, +) +from a2a.client.transports import RestTransport, JsonRpcTransport, GrpcTransport +from a2a.client.transports.base import TenantTransportDecorator +from a2a.client import ClientConfig, ClientFactory +from a2a.utils.constants import TransportProtocol + + +@pytest.fixture +def agent_card(): + return AgentCard( + supported_interfaces=[ + AgentInterface( + url='http://example.com/rest', + protocol_binding=TransportProtocol.HTTP_JSON, + tenant='tenant-1', + ), + AgentInterface( + url='http://example.com/jsonrpc', + protocol_binding=TransportProtocol.JSONRPC, + tenant='tenant-2', + ), + AgentInterface( + url='http://example.com/grpc', + protocol_binding=TransportProtocol.GRPC, + tenant='tenant-3', + ), + ], + capabilities=AgentCapabilities(streaming=True), + ) + + +@pytest.mark.asyncio +async def test_tenant_decorator_rest(agent_card): + mock_httpx = AsyncMock(spec=httpx.AsyncClient) + mock_httpx.build_request.return_value = MagicMock() + mock_httpx.send.return_value = MagicMock( + status_code=200, json=lambda: {'message': {}} + ) + + config = ClientConfig( + httpx_client=mock_httpx, + supported_protocol_bindings=[TransportProtocol.HTTP_JSON], + ) + factory = ClientFactory(config) + client = factory.create(agent_card) + + assert isinstance(client._transport, TenantTransportDecorator) + assert client._transport._tenant == 'tenant-1' + + # Test SendMessage (POST) - Use transport directly to avoid streaming complexity in mock + request = SendMessageRequest(message=Message(parts=[{'text': 'hi'}])) + await client._transport.send_message(request) + + # Check that tenant was populated in request + assert request.tenant == 'tenant-1' + + # Check that path was prepended in the underlying transport + mock_httpx.build_request.assert_called() + send_call = next( + c + for c in mock_httpx.build_request.call_args_list + if 'v1/message:send' in c.args[1] + ) + args, kwargs = send_call + assert args[1] == 'http://example.com/rest/tenant-1/v1/message:send' + # tenant should NOT be in JSON body as it was popped + assert 'tenant' not in kwargs['json'] + + +@pytest.mark.asyncio +async def test_tenant_decorator_jsonrpc(agent_card): + mock_httpx = AsyncMock(spec=httpx.AsyncClient) + mock_httpx.post.return_value = MagicMock( + status_code=200, + json=lambda: {'result': {'message': {}}, 'id': '1', 'jsonrpc': '2.0'}, + ) + + config = ClientConfig( + httpx_client=mock_httpx, + supported_protocol_bindings=[TransportProtocol.JSONRPC], + ) + factory = ClientFactory(config) + client = factory.create(agent_card) + + assert isinstance(client._transport, TenantTransportDecorator) + assert client._transport._tenant == 'tenant-2' + + request = SendMessageRequest(message=Message(parts=[{'text': 'hi'}])) + await client._transport.send_message(request) + + mock_httpx.post.assert_called() + _, kwargs = mock_httpx.post.call_args + assert kwargs['json']['params']['tenant'] == 'tenant-2' + + +@pytest.mark.asyncio +async def test_tenant_decorator_grpc(agent_card): + mock_channel = MagicMock() + config = ClientConfig( + grpc_channel_factory=lambda url: mock_channel, + supported_protocol_bindings=[TransportProtocol.GRPC], + ) + + with patch('a2a.types.a2a_pb2_grpc.A2AServiceStub') as mock_stub_class: + mock_stub = mock_stub_class.return_value + mock_stub.SendMessage = AsyncMock(return_value={'message': {}}) + + factory = ClientFactory(config) + client = factory.create(agent_card) + + assert isinstance(client._transport, TenantTransportDecorator) + assert client._transport._tenant == 'tenant-3' + + await client._transport.send_message( + SendMessageRequest(message=Message(parts=[{'text': 'hi'}])) + ) + + call_args = mock_stub.SendMessage.call_args + assert call_args[0][0].tenant == 'tenant-3' + + +@pytest.mark.asyncio +async def test_tenant_decorator_explicit_override(agent_card): + mock_httpx = AsyncMock(spec=httpx.AsyncClient) + mock_httpx.build_request.return_value = MagicMock() + mock_httpx.send.return_value = MagicMock( + status_code=200, json=lambda: {'message': {}} + ) + + config = ClientConfig( + httpx_client=mock_httpx, + supported_protocol_bindings=[TransportProtocol.HTTP_JSON], + ) + factory = ClientFactory(config) + client = factory.create(agent_card) + + request = SendMessageRequest( + message=Message(parts=[{'text': 'hi'}]), tenant='explicit-tenant' + ) + await client._transport.send_message(request) + + assert request.tenant == 'explicit-tenant' + + send_call = next( + c + for c in mock_httpx.build_request.call_args_list + if 'v1/message:send' in c.args[1] + ) + args, _ = send_call + assert args[1] == 'http://example.com/rest/explicit-tenant/v1/message:send' From 0599baf51bdc72bc3a4c678bc2f7c63721f9b52c Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 3 Mar 2026 17:04:29 +0000 Subject: [PATCH 02/14] docs: Add docstrings to base transport methods. --- src/a2a/client/transports/base.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/a2a/client/transports/base.py b/src/a2a/client/transports/base.py index 28ce9257f..4301bf5e7 100644 --- a/src/a2a/client/transports/base.py +++ b/src/a2a/client/transports/base.py @@ -190,6 +190,7 @@ async def send_message( context: ClientCallContext | None = None, extensions: list[str] | None = None, ) -> SendMessageResponse: + """Sends a streaming message request to the agent and yields responses as they arrive.""" self.update_tenant(request) return await self._base.send_message( request, context=context, extensions=extensions @@ -202,6 +203,7 @@ async def send_message_streaming( context: ClientCallContext | None = None, extensions: list[str] | None = None, ) -> AsyncGenerator[StreamResponse]: + """Sends a streaming message request to the agent and yields responses.""" self.update_tenant(request) async for event in self._base.send_message_streaming( request, context=context, extensions=extensions @@ -215,6 +217,7 @@ async def get_task( context: ClientCallContext | None = None, extensions: list[str] | None = None, ) -> Task: + """Retrieves the current state and history of a specific task.""" self.update_tenant(request) return await self._base.get_task( request, context=context, extensions=extensions @@ -227,6 +230,7 @@ async def list_tasks( context: ClientCallContext | None = None, extensions: list[str] | None = None, ) -> ListTasksResponse: + """Retrieves tasks for an agent.""" self.update_tenant(request) return await self._base.list_tasks( request, context=context, extensions=extensions @@ -239,6 +243,7 @@ async def cancel_task( context: ClientCallContext | None = None, extensions: list[str] | None = None, ) -> Task: + """Requests the agent to cancel a specific task.""" self.update_tenant(request) return await self._base.cancel_task( request, context=context, extensions=extensions @@ -251,6 +256,7 @@ async def set_task_callback( context: ClientCallContext | None = None, extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task.""" self.update_tenant(request) return await self._base.set_task_callback( request, context=context, extensions=extensions @@ -263,6 +269,7 @@ async def get_task_callback( context: ClientCallContext | None = None, extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task.""" self.update_tenant(request) return await self._base.get_task_callback( request, context=context, extensions=extensions @@ -275,6 +282,7 @@ async def list_task_callback( context: ClientCallContext | None = None, extensions: list[str] | None = None, ) -> ListTaskPushNotificationConfigsResponse: + """Lists push notification configurations for a specific task.""" self.update_tenant(request) return await self._base.list_task_callback( request, context=context, extensions=extensions @@ -287,6 +295,7 @@ async def delete_task_callback( context: ClientCallContext | None = None, extensions: list[str] | None = None, ) -> None: + """Deletes the push notification configuration for a specific task.""" self.update_tenant(request) await self._base.delete_task_callback( request, context=context, extensions=extensions @@ -299,6 +308,7 @@ async def subscribe( context: ClientCallContext | None = None, extensions: list[str] | None = None, ) -> AsyncGenerator[StreamResponse]: + """Reconnects to get task updates.""" self.update_tenant(request) async for event in self._base.subscribe( request, context=context, extensions=extensions @@ -312,6 +322,7 @@ async def get_extended_agent_card( extensions: list[str] | None = None, signature_verifier: Callable[[AgentCard], None] | None = None, ) -> AgentCard: + """Retrieves the Extended AgentCard.""" return await self._base.get_extended_agent_card( context=context, extensions=extensions, @@ -319,4 +330,5 @@ async def get_extended_agent_card( ) async def close(self) -> None: + """Closes the transport.""" await self._base.close() From a4f7d91b3be136788aaacf42914d038f22b67fb8 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 3 Mar 2026 17:14:14 +0000 Subject: [PATCH 03/14] fix: merging errors --- src/a2a/client/transports/base.py | 20 ++++++++++---------- src/a2a/client/transports/rest.py | 1 - tests/client/transports/test_rest_client.py | 8 ++++---- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/src/a2a/client/transports/base.py b/src/a2a/client/transports/base.py index b7f255383..3a0774ef5 100644 --- a/src/a2a/client/transports/base.py +++ b/src/a2a/client/transports/base.py @@ -2,8 +2,8 @@ from collections.abc import AsyncGenerator, Callable from types import TracebackType -from typing_extensions import Self from google.protobuf.message import Message +from typing_extensions import Self from a2a.client.middleware import ClientCallContext from a2a.types.a2a_pb2 import ( @@ -179,7 +179,7 @@ def update_tenant(self, request: Message) -> str | None: return current_tenant if self._tenant and hasattr(request, 'tenant'): - request.tenant = self._tenant + setattr(request, 'tenant', self._tenant) return self._tenant return None @@ -249,7 +249,7 @@ async def cancel_task( request, context=context, extensions=extensions ) - async def set_task_callback( + async def create_task_push_notification_config( self, request: CreateTaskPushNotificationConfigRequest, *, @@ -258,11 +258,11 @@ async def set_task_callback( ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" self.update_tenant(request) - return await self._base.set_task_callback( + return await self._base.create_task_push_notification_config( request, context=context, extensions=extensions ) - async def get_task_callback( + async def get_task_push_notification_config( self, request: GetTaskPushNotificationConfigRequest, *, @@ -271,11 +271,11 @@ async def get_task_callback( ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" self.update_tenant(request) - return await self._base.get_task_callback( + return await self._base.get_task_push_notification_config( request, context=context, extensions=extensions ) - async def list_task_callback( + async def list_task_push_notification_configs( self, request: ListTaskPushNotificationConfigsRequest, *, @@ -284,11 +284,11 @@ async def list_task_callback( ) -> ListTaskPushNotificationConfigsResponse: """Lists push notification configurations for a specific task.""" self.update_tenant(request) - return await self._base.list_task_callback( + return await self._base.list_task_push_notification_configs( request, context=context, extensions=extensions ) - async def delete_task_callback( + async def delete_task_push_notification_config( self, request: DeleteTaskPushNotificationConfigRequest, *, @@ -297,7 +297,7 @@ async def delete_task_callback( ) -> None: """Deletes the push notification configuration for a specific task.""" self.update_tenant(request) - await self._base.delete_task_callback( + await self._base.delete_task_push_notification_config( request, context=context, extensions=extensions ) diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 721ae9623..76086b1fd 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -7,7 +7,6 @@ import httpx from google.protobuf.json_format import MessageToDict, Parse, ParseDict -from google.protobuf.message import Message from httpx_sse import SSEError, aconnect_sse from a2a.client.errors import ( diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index bf9beddc1..362a17e17 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -401,28 +401,28 @@ class TestRestTransportTenant: '/my-tenant/v1/tasks/task-123:cancel', ), ( - 'set_task_callback', + 'create_task_push_notification_config', CreateTaskPushNotificationConfigRequest( tenant='my-tenant', task_id='task-123' ), '/my-tenant/v1/tasks/task-123/pushNotificationConfigs', ), ( - 'get_task_callback', + 'get_task_push_notification_config', GetTaskPushNotificationConfigRequest( tenant='my-tenant', task_id='task-123', id='cfg-1' ), '/my-tenant/v1/tasks/task-123/pushNotificationConfigs/cfg-1', ), ( - 'list_task_callback', + 'list_task_push_notification_configs', ListTaskPushNotificationConfigsRequest( tenant='my-tenant', task_id='task-123' ), '/my-tenant/v1/tasks/task-123/pushNotificationConfigs', ), ( - 'delete_task_callback', + 'delete_task_push_notification_config', DeleteTaskPushNotificationConfigRequest( tenant='my-tenant', task_id='task-123', id='cfg-1' ), From c6e22adbe5ee7e9df973c2a84d20bfad122e894b Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 3 Mar 2026 17:30:45 +0000 Subject: [PATCH 04/14] refactor: simplify tenant resolution logic in base --- src/a2a/client/transports/base.py | 36 +++++++++++++------------------ 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/src/a2a/client/transports/base.py b/src/a2a/client/transports/base.py index 3a0774ef5..4e6f8bde7 100644 --- a/src/a2a/client/transports/base.py +++ b/src/a2a/client/transports/base.py @@ -2,7 +2,6 @@ from collections.abc import AsyncGenerator, Callable from types import TracebackType -from google.protobuf.message import Message from typing_extensions import Self from a2a.client.middleware import ClientCallContext @@ -168,20 +167,15 @@ def __init__(self, base: ClientTransport, tenant: str): self._base = base self._tenant = tenant - def update_tenant(self, request: Message) -> str | None: - """Ensures the tenant is set on the request if provided and not already set. + def update_tenant(self, tenant: str) -> str: + """If tenant is not provided, use the default tenant. Returns: The tenant used for the request. """ - current_tenant = getattr(request, 'tenant', None) - if current_tenant: - return current_tenant - - if self._tenant and hasattr(request, 'tenant'): - setattr(request, 'tenant', self._tenant) - return self._tenant - return None + if tenant != '': + return tenant + return self._tenant or '' async def send_message( self, @@ -191,7 +185,7 @@ async def send_message( extensions: list[str] | None = None, ) -> SendMessageResponse: """Sends a streaming message request to the agent and yields responses as they arrive.""" - self.update_tenant(request) + request.tenant = self.update_tenant(request.tenant) return await self._base.send_message( request, context=context, extensions=extensions ) @@ -204,7 +198,7 @@ async def send_message_streaming( extensions: list[str] | None = None, ) -> AsyncGenerator[StreamResponse]: """Sends a streaming message request to the agent and yields responses.""" - self.update_tenant(request) + request.tenant = self.update_tenant(request.tenant) async for event in self._base.send_message_streaming( request, context=context, extensions=extensions ): @@ -218,7 +212,7 @@ async def get_task( extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" - self.update_tenant(request) + request.tenant = self.update_tenant(request.tenant) return await self._base.get_task( request, context=context, extensions=extensions ) @@ -231,7 +225,7 @@ async def list_tasks( extensions: list[str] | None = None, ) -> ListTasksResponse: """Retrieves tasks for an agent.""" - self.update_tenant(request) + request.tenant = self.update_tenant(request.tenant) return await self._base.list_tasks( request, context=context, extensions=extensions ) @@ -244,7 +238,7 @@ async def cancel_task( extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task.""" - self.update_tenant(request) + request.tenant = self.update_tenant(request.tenant) return await self._base.cancel_task( request, context=context, extensions=extensions ) @@ -257,7 +251,7 @@ async def create_task_push_notification_config( extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" - self.update_tenant(request) + request.tenant = self.update_tenant(request.tenant) return await self._base.create_task_push_notification_config( request, context=context, extensions=extensions ) @@ -270,7 +264,7 @@ async def get_task_push_notification_config( extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" - self.update_tenant(request) + request.tenant = self.update_tenant(request.tenant) return await self._base.get_task_push_notification_config( request, context=context, extensions=extensions ) @@ -283,7 +277,7 @@ async def list_task_push_notification_configs( extensions: list[str] | None = None, ) -> ListTaskPushNotificationConfigsResponse: """Lists push notification configurations for a specific task.""" - self.update_tenant(request) + request.tenant = self.update_tenant(request.tenant) return await self._base.list_task_push_notification_configs( request, context=context, extensions=extensions ) @@ -296,7 +290,7 @@ async def delete_task_push_notification_config( extensions: list[str] | None = None, ) -> None: """Deletes the push notification configuration for a specific task.""" - self.update_tenant(request) + request.tenant = self.update_tenant(request.tenant) await self._base.delete_task_push_notification_config( request, context=context, extensions=extensions ) @@ -309,7 +303,7 @@ async def subscribe( extensions: list[str] | None = None, ) -> AsyncGenerator[StreamResponse]: """Reconnects to get task updates.""" - self.update_tenant(request) + request.tenant = self.update_tenant(request.tenant) async for event in self._base.subscribe( request, context=context, extensions=extensions ): From f1cc5a941e693130eea8ed58211d90b43d816baa Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 3 Mar 2026 18:12:43 +0000 Subject: [PATCH 05/14] refactor: group base client tests into a class and add new task and notification-related imports. --- tests/client/test_base_client.py | 559 +++++++++++++++++++------------ 1 file changed, 347 insertions(+), 212 deletions(-) diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index 14d9ef532..e4d5d834a 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -10,14 +10,26 @@ AgentCapabilities, AgentCard, AgentInterface, + CancelTaskRequest, + CreateTaskPushNotificationConfigRequest, + DeleteTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTaskPushNotificationConfigsResponse, + ListTasksRequest, + ListTasksResponse, Message, Part, + PushNotificationConfig, Role, SendMessageConfiguration, SendMessageRequest, SendMessageResponse, StreamResponse, + SubscribeToTaskRequest, Task, + TaskPushNotificationConfig, TaskState, TaskStatus, ) @@ -67,231 +79,354 @@ def base_client( ) -@pytest.mark.asyncio -async def test_transport_async_context_manager() -> None: - with ( - patch.object(ClientTransport, '__abstractmethods__', set()), - patch.object(ClientTransport, 'close', new_callable=AsyncMock), - ): - transport = ClientTransport() - async with transport as t: - assert t is transport - transport.close.assert_not_awaited() - transport.close.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_transport_async_context_manager_on_exception() -> None: - with ( - patch.object(ClientTransport, '__abstractmethods__', set()), - patch.object(ClientTransport, 'close', new_callable=AsyncMock), - ): - transport = ClientTransport() +class TestClientTransport: + @pytest.mark.asyncio + async def test_transport_async_context_manager(self) -> None: + with ( + patch.object(ClientTransport, '__abstractmethods__', set()), + patch.object(ClientTransport, 'close', new_callable=AsyncMock), + ): + transport = ClientTransport() + async with transport as t: + assert t is transport + transport.close.assert_not_awaited() + transport.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_transport_async_context_manager_on_exception(self) -> None: + with ( + patch.object(ClientTransport, '__abstractmethods__', set()), + patch.object(ClientTransport, 'close', new_callable=AsyncMock), + ): + transport = ClientTransport() + with pytest.raises(RuntimeError, match='boom'): + async with transport: + raise RuntimeError('boom') + transport.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_base_client_async_context_manager( + self, base_client: BaseClient, mock_transport: AsyncMock + ) -> None: + async with base_client as client: + assert client is base_client + mock_transport.close.assert_not_awaited() + mock_transport.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_base_client_async_context_manager_on_exception( + self, base_client: BaseClient, mock_transport: AsyncMock + ) -> None: with pytest.raises(RuntimeError, match='boom'): - async with transport: + async with base_client: raise RuntimeError('boom') - transport.close.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_base_client_async_context_manager( - base_client: BaseClient, mock_transport: AsyncMock -) -> None: - async with base_client as client: - assert client is base_client - mock_transport.close.assert_not_awaited() - mock_transport.close.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_base_client_async_context_manager_on_exception( - base_client: BaseClient, mock_transport: AsyncMock -) -> None: - with pytest.raises(RuntimeError, match='boom'): - async with base_client: - raise RuntimeError('boom') - mock_transport.close.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_send_message_streaming( - base_client: BaseClient, mock_transport: MagicMock, sample_message: Message -) -> None: - async def create_stream(*args, **kwargs): + mock_transport.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_send_message_streaming( + self, + base_client: BaseClient, + mock_transport: MagicMock, + sample_message: Message, + ) -> None: + async def create_stream(*args, **kwargs): + task = Task( + id='task-123', + context_id='ctx-456', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + stream_response = StreamResponse() + stream_response.task.CopyFrom(task) + yield stream_response + + mock_transport.send_message_streaming.return_value = create_stream() + + meta = {'test': 1} + stream = base_client.send_message(sample_message, request_metadata=meta) + events = [event async for event in stream] + + mock_transport.send_message_streaming.assert_called_once() + assert ( + mock_transport.send_message_streaming.call_args[0][0].metadata + == meta + ) + assert not mock_transport.send_message.called + assert len(events) == 1 + # events[0] is (StreamResponse, Task) tuple + stream_response, tracked_task = events[0] + assert stream_response.task.id == 'task-123' + assert tracked_task is not None + assert tracked_task.id == 'task-123' + + @pytest.mark.asyncio + async def test_send_message_non_streaming( + self, + base_client: BaseClient, + mock_transport: MagicMock, + sample_message: Message, + ) -> None: + base_client._config.streaming = False task = Task( - id='task-123', - context_id='ctx-456', + id='task-456', + context_id='ctx-789', status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), ) - stream_response = StreamResponse() - stream_response.task.CopyFrom(task) - yield stream_response - - mock_transport.send_message_streaming.return_value = create_stream() - - meta = {'test': 1} - stream = base_client.send_message(sample_message, request_metadata=meta) - events = [event async for event in stream] - - mock_transport.send_message_streaming.assert_called_once() - assert ( - mock_transport.send_message_streaming.call_args[0][0].metadata == meta - ) - assert not mock_transport.send_message.called - assert len(events) == 1 - # events[0] is (StreamResponse, Task) tuple - stream_response, tracked_task = events[0] - assert stream_response.task.id == 'task-123' - assert tracked_task is not None - assert tracked_task.id == 'task-123' - - -@pytest.mark.asyncio -async def test_send_message_non_streaming( - base_client: BaseClient, mock_transport: MagicMock, sample_message: Message -) -> None: - base_client._config.streaming = False - task = Task( - id='task-456', - context_id='ctx-789', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - response = SendMessageResponse() - response.task.CopyFrom(task) - mock_transport.send_message.return_value = response - - meta = {'test': 1} - stream = base_client.send_message(sample_message, request_metadata=meta) - events = [event async for event in stream] - - mock_transport.send_message.assert_called_once() - assert mock_transport.send_message.call_args[0][0].metadata == meta - assert not mock_transport.send_message_streaming.called - assert len(events) == 1 - stream_response, tracked_task = events[0] - assert stream_response.task.id == 'task-456' - assert tracked_task is not None - assert tracked_task.id == 'task-456' - - -@pytest.mark.asyncio -async def test_send_message_non_streaming_agent_capability_false( - base_client: BaseClient, mock_transport: MagicMock, sample_message: Message -) -> None: - base_client._card.capabilities.streaming = False - task = Task( - id='task-789', - context_id='ctx-101', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - response = SendMessageResponse() - response.task.CopyFrom(task) - mock_transport.send_message.return_value = response - - events = [event async for event in base_client.send_message(sample_message)] - - mock_transport.send_message.assert_called_once() - assert not mock_transport.send_message_streaming.called - assert len(events) == 1 - stream_response, tracked_task = events[0] - assert stream_response is not None - assert tracked_task is not None - assert tracked_task.id == 'task-789' - - -@pytest.mark.asyncio -async def test_send_message_callsite_config_overrides_non_streaming( - base_client: BaseClient, mock_transport: MagicMock, sample_message: Message -): - base_client._config.streaming = False - task = Task( - id='task-cfg-ns-1', - context_id='ctx-cfg-ns-1', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - response = SendMessageResponse() - response.task.CopyFrom(task) - mock_transport.send_message.return_value = response - - cfg = SendMessageConfiguration( - history_length=2, - blocking=False, - accepted_output_modes=['application/json'], - ) - events = [ - event - async for event in base_client.send_message( - sample_message, configuration=cfg + response = SendMessageResponse() + response.task.CopyFrom(task) + mock_transport.send_message.return_value = response + + meta = {'test': 1} + stream = base_client.send_message(sample_message, request_metadata=meta) + events = [event async for event in stream] + + mock_transport.send_message.assert_called_once() + assert mock_transport.send_message.call_args[0][0].metadata == meta + assert not mock_transport.send_message_streaming.called + assert len(events) == 1 + stream_response, tracked_task = events[0] + assert stream_response.task.id == 'task-456' + assert tracked_task is not None + assert tracked_task.id == 'task-456' + + @pytest.mark.asyncio + async def test_send_message_non_streaming_agent_capability_false( + self, + base_client: BaseClient, + mock_transport: MagicMock, + sample_message: Message, + ) -> None: + base_client._card.capabilities.streaming = False + task = Task( + id='task-789', + context_id='ctx-101', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), ) - ] - - mock_transport.send_message.assert_called_once() - assert not mock_transport.send_message_streaming.called - assert len(events) == 1 - stream_response, _ = events[0] - assert stream_response.task.id == 'task-cfg-ns-1' - - params = mock_transport.send_message.call_args[0][0] - assert params.configuration.history_length == 2 - assert params.configuration.blocking is False - assert params.configuration.accepted_output_modes == ['application/json'] - - -@pytest.mark.asyncio -async def test_send_message_callsite_config_overrides_streaming( - base_client: BaseClient, mock_transport: MagicMock, sample_message: Message -): - base_client._config.streaming = True - base_client._card.capabilities.streaming = True - - async def create_stream(*args, **kwargs): + response = SendMessageResponse() + response.task.CopyFrom(task) + mock_transport.send_message.return_value = response + + events = [ + event async for event in base_client.send_message(sample_message) + ] + + mock_transport.send_message.assert_called_once() + assert not mock_transport.send_message_streaming.called + assert len(events) == 1 + stream_response, tracked_task = events[0] + assert stream_response is not None + assert tracked_task is not None + assert tracked_task.id == 'task-789' + + @pytest.mark.asyncio + async def test_send_message_callsite_config_overrides_non_streaming( + self, + base_client: BaseClient, + mock_transport: MagicMock, + sample_message: Message, + ): + base_client._config.streaming = False task = Task( - id='task-cfg-s-1', - context_id='ctx-cfg-s-1', + id='task-cfg-ns-1', + context_id='ctx-cfg-ns-1', status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), ) - stream_response = StreamResponse() - stream_response.task.CopyFrom(task) - yield stream_response - - mock_transport.send_message_streaming.return_value = create_stream() - - cfg = SendMessageConfiguration( - history_length=0, - blocking=True, - accepted_output_modes=['text/plain'], + response = SendMessageResponse() + response.task.CopyFrom(task) + mock_transport.send_message.return_value = response + + cfg = SendMessageConfiguration( + history_length=2, + blocking=False, + accepted_output_modes=['application/json'], + ) + events = [ + event + async for event in base_client.send_message( + sample_message, configuration=cfg + ) + ] + + mock_transport.send_message.assert_called_once() + assert not mock_transport.send_message_streaming.called + assert len(events) == 1 + stream_response, _ = events[0] + assert stream_response.task.id == 'task-cfg-ns-1' + + params = mock_transport.send_message.call_args[0][0] + assert params.configuration.history_length == 2 + assert params.configuration.blocking is False + assert params.configuration.accepted_output_modes == [ + 'application/json' + ] + + @pytest.mark.asyncio + async def test_send_message_callsite_config_overrides_streaming( + self, + base_client: BaseClient, + mock_transport: MagicMock, + sample_message: Message, + ): + base_client._config.streaming = True + base_client._card.capabilities.streaming = True + + async def create_stream(*args, **kwargs): + task = Task( + id='task-cfg-s-1', + context_id='ctx-cfg-s-1', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + stream_response = StreamResponse() + stream_response.task.CopyFrom(task) + yield stream_response + + mock_transport.send_message_streaming.return_value = create_stream() + + cfg = SendMessageConfiguration( + history_length=0, + blocking=True, + accepted_output_modes=['text/plain'], + ) + events = [ + event + async for event in base_client.send_message( + sample_message, configuration=cfg + ) + ] + + mock_transport.send_message_streaming.assert_called_once() + assert not mock_transport.send_message.called + assert len(events) == 1 + stream_response, _ = events[0] + assert stream_response.task.id == 'task-cfg-s-1' + + params = mock_transport.send_message_streaming.call_args[0][0] + assert params.configuration.history_length == 0 + assert params.configuration.blocking is True + assert params.configuration.accepted_output_modes == ['text/plain'] + + +class TestTenantTransportDecorator: + @pytest.mark.asyncio + async def test_update_tenant_logic(self, mock_transport: AsyncMock) -> None: + tenant_id = 'test-tenant' + decorator = TenantTransportDecorator(mock_transport, tenant_id) + + # Case 1: Tenant already set on request + assert decorator.update_tenant('existing-tenant') == 'existing-tenant' + + # Case 2: Tenant not set (empty string) + assert decorator.update_tenant('') == tenant_id + + @pytest.mark.parametrize( + 'method_name, request_obj, response_obj', + [ + ( + 'get_task', + GetTaskRequest(id='t1'), + Task(id='t1'), + ), + ( + 'list_tasks', + ListTasksRequest(), + ListTasksResponse(), + ), + ( + 'cancel_task', + CancelTaskRequest(id='t1'), + Task(id='t1'), + ), + ( + 'create_task_push_notification_config', + CreateTaskPushNotificationConfigRequest(task_id='t1'), + TaskPushNotificationConfig(task_id='t1'), + ), + ( + 'get_task_push_notification_config', + GetTaskPushNotificationConfigRequest(task_id='t1', id='c1'), + TaskPushNotificationConfig(task_id='t1'), + ), + ( + 'list_task_push_notification_configs', + ListTaskPushNotificationConfigsRequest(task_id='t1'), + ListTaskPushNotificationConfigsResponse(), + ), + ], ) - events = [ - event - async for event in base_client.send_message( - sample_message, configuration=cfg + @pytest.mark.asyncio + async def test_methods( + self, mock_transport: AsyncMock, method_name, request_obj, response_obj + ) -> None: + tenant_id = 'test-tenant' + decorator = TenantTransportDecorator(mock_transport, tenant_id) + mock_method = getattr(mock_transport, method_name) + mock_method.return_value = response_obj + + result = await getattr(decorator, method_name)(request_obj) + + assert result == response_obj + mock_method.assert_called_once() + assert request_obj.tenant == tenant_id + + @pytest.mark.asyncio + async def test_send_message(self, mock_transport: AsyncMock) -> None: + tenant_id = 'test-tenant' + decorator = TenantTransportDecorator(mock_transport, tenant_id) + request = SendMessageRequest( + message=Message(parts=[Part(text='hello')]) ) - ] + mock_transport.send_message.return_value = SendMessageResponse() - mock_transport.send_message_streaming.assert_called_once() - assert not mock_transport.send_message.called - assert len(events) == 1 - stream_response, _ = events[0] - assert stream_response.task.id == 'task-cfg-s-1' + await decorator.send_message(request) - params = mock_transport.send_message_streaming.call_args[0][0] - assert params.configuration.history_length == 0 - assert params.configuration.blocking is True - assert params.configuration.accepted_output_modes == ['text/plain'] + mock_transport.send_message.assert_called_once() + assert request.tenant == tenant_id + @pytest.mark.asyncio + async def test_delete_config(self, mock_transport: AsyncMock) -> None: + tenant_id = 'test-tenant' + decorator = TenantTransportDecorator(mock_transport, tenant_id) + request = DeleteTaskPushNotificationConfigRequest(task_id='t1', id='c1') -@pytest.mark.asyncio -async def test_tenant_transport_decorator(mock_transport: AsyncMock) -> None: - tenant_id = 'test-tenant' - decorator = TenantTransportDecorator(mock_transport, tenant_id) + await decorator.delete_task_push_notification_config(request) - request = SendMessageRequest( - message=Message(parts=[Part(text='hello')]), - ) - mock_transport.send_message.return_value = SendMessageResponse() - - await decorator.send_message(request) - - mock_transport.send_message.assert_called_once() - called_request = mock_transport.send_message.call_args[0][0] - assert called_request.tenant == tenant_id + mock_transport.delete_task_push_notification_config.assert_called_once_with( + request, context=None, extensions=None + ) + assert request.tenant == tenant_id + + @pytest.mark.asyncio + async def test_streaming_methods(self, mock_transport: AsyncMock) -> None: + tenant_id = 'test-tenant' + decorator = TenantTransportDecorator(mock_transport, tenant_id) + + async def mock_stream(*args, **kwargs): + yield StreamResponse() + + # Test subscribe + mock_transport.subscribe.return_value = mock_stream() + request_sub = SubscribeToTaskRequest(id='t1') + async for _ in decorator.subscribe(request_sub): + pass + assert request_sub.tenant == tenant_id + + # Test send_message_streaming + mock_transport.send_message_streaming.return_value = mock_stream() + request_msg = SendMessageRequest() + async for _ in decorator.send_message_streaming(request_msg): + pass + assert request_msg.tenant == tenant_id + + @pytest.mark.asyncio + async def test_misc(self, mock_transport: AsyncMock) -> None: + tenant_id = 'test-tenant' + decorator = TenantTransportDecorator(mock_transport, tenant_id) + + # Test get_extended_agent_card + card = MagicMock(spec=AgentCard) + mock_transport.get_extended_agent_card.return_value = card + res = await decorator.get_extended_agent_card() + assert res is card + + # Test close + await decorator.close() + mock_transport.close.assert_called_once() From 603ba8c5b71879e2a88f44be992cd5b9681c8fa6 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 4 Mar 2026 10:20:17 +0000 Subject: [PATCH 06/14] refactor: update pyproject.py to not include src/a2a/compat/*/*_pb2*.py files for coverage testing --- pyproject.toml | 1 + tests/client/test_base_client.py | 52 ++++++++------------------------ 2 files changed, 13 insertions(+), 40 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index dffb43a71..0814a70e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,6 +176,7 @@ omit = [ "*/__init__.py", "src/a2a/types/a2a_pb2.py", "src/a2a/types/a2a_pb2_grpc.py", + "src/a2a/compat/*/*_pb2*.py", ] [tool.coverage.report] diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index e4d5d834a..9f5f7b5c1 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -1,4 +1,3 @@ -from collections.abc import AsyncGenerator from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -21,7 +20,6 @@ ListTasksResponse, Message, Part, - PushNotificationConfig, Role, SendMessageConfiguration, SendMessageRequest, @@ -318,82 +316,56 @@ async def test_update_tenant_logic(self, mock_transport: AsyncMock) -> None: assert decorator.update_tenant('') == tenant_id @pytest.mark.parametrize( - 'method_name, request_obj, response_obj', + 'method_name, request_obj', [ + ( + 'send_message', + SendMessageRequest(message=Message(parts=[Part(text='hello')])), + ), ( 'get_task', GetTaskRequest(id='t1'), - Task(id='t1'), ), ( 'list_tasks', ListTasksRequest(), - ListTasksResponse(), ), ( 'cancel_task', CancelTaskRequest(id='t1'), - Task(id='t1'), ), ( 'create_task_push_notification_config', CreateTaskPushNotificationConfigRequest(task_id='t1'), - TaskPushNotificationConfig(task_id='t1'), ), ( 'get_task_push_notification_config', GetTaskPushNotificationConfigRequest(task_id='t1', id='c1'), - TaskPushNotificationConfig(task_id='t1'), ), ( 'list_task_push_notification_configs', ListTaskPushNotificationConfigsRequest(task_id='t1'), - ListTaskPushNotificationConfigsResponse(), + ), + ( + 'delete_task_push_notification_config', + DeleteTaskPushNotificationConfigRequest(task_id='t1', id='c1'), ), ], ) @pytest.mark.asyncio async def test_methods( - self, mock_transport: AsyncMock, method_name, request_obj, response_obj + self, mock_transport: AsyncMock, method_name, request_obj ) -> None: tenant_id = 'test-tenant' decorator = TenantTransportDecorator(mock_transport, tenant_id) mock_method = getattr(mock_transport, method_name) - mock_method.return_value = response_obj - result = await getattr(decorator, method_name)(request_obj) + await getattr(decorator, method_name)(request_obj) - assert result == response_obj mock_method.assert_called_once() + assert mock_transport.mock_calls[0][0] == method_name assert request_obj.tenant == tenant_id - @pytest.mark.asyncio - async def test_send_message(self, mock_transport: AsyncMock) -> None: - tenant_id = 'test-tenant' - decorator = TenantTransportDecorator(mock_transport, tenant_id) - request = SendMessageRequest( - message=Message(parts=[Part(text='hello')]) - ) - mock_transport.send_message.return_value = SendMessageResponse() - - await decorator.send_message(request) - - mock_transport.send_message.assert_called_once() - assert request.tenant == tenant_id - - @pytest.mark.asyncio - async def test_delete_config(self, mock_transport: AsyncMock) -> None: - tenant_id = 'test-tenant' - decorator = TenantTransportDecorator(mock_transport, tenant_id) - request = DeleteTaskPushNotificationConfigRequest(task_id='t1', id='c1') - - await decorator.delete_task_push_notification_config(request) - - mock_transport.delete_task_push_notification_config.assert_called_once_with( - request, context=None, extensions=None - ) - assert request.tenant == tenant_id - @pytest.mark.asyncio async def test_streaming_methods(self, mock_transport: AsyncMock) -> None: tenant_id = 'test-tenant' From 08befc376300c0bb62ac075f159eee0096e92e73 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 4 Mar 2026 12:14:47 +0000 Subject: [PATCH 07/14] refactor: put tenant back in requests in rest --- src/a2a/client/transports/rest.py | 43 ++++++++++++++----------------- tests/integration/test_tenant.py | 3 +-- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 76086b1fd..fef5a0144 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -7,6 +7,7 @@ import httpx from google.protobuf.json_format import MessageToDict, Parse, ParseDict +from google.protobuf.message import Message from httpx_sse import SSEError, aconnect_sse from a2a.client.errors import ( @@ -99,10 +100,6 @@ def _get_path(self, base_path: str, tenant: str | None) -> str: """Returns the full path, prepending the tenant if provided.""" return f'/{tenant}{base_path}' if tenant else base_path - def _pop_tenant(self, data: dict[str, Any]) -> str | None: - """Pops and returns the tenant from the dictionary if it exists.""" - return data.pop('tenant', None) - async def send_message( self, request: SendMessageRequest, @@ -114,8 +111,7 @@ async def send_message( payload, modified_kwargs = await self._prepare_send_message( request, context, extensions ) - tenant = self._pop_tenant(payload) or request.tenant - path = self._get_path('/v1/message:send', tenant) + path = self._get_path('/v1/message:send', request.tenant) response_data = await self._send_post_request( path, payload, modified_kwargs ) @@ -135,8 +131,7 @@ async def send_message_streaming( payload, modified_kwargs = await self._prepare_send_message( request, context, extensions ) - tenant = self._pop_tenant(payload) or request.tenant - path = self._get_path('/v1/message:stream', tenant) + path = self._get_path('/v1/message:stream', request.tenant) modified_kwargs.setdefault('timeout', None) @@ -250,8 +245,7 @@ async def get_task( if 'id' in params: del params['id'] # id is part of the URL path, not query params - tenant = self._pop_tenant(params) or request.tenant - path = self._get_path(f'/v1/tasks/{request.id}', tenant) + path = self._get_path(f'/v1/tasks/{request.id}', request.tenant) response_data = await self._send_get_request( path, @@ -269,7 +263,7 @@ async def list_tasks( extensions: list[str] | None = None, ) -> ListTasksResponse: """Retrieves tasks for an agent.""" - payload, modified_kwargs = await self._apply_interceptors( + _, modified_kwargs = await self._apply_interceptors( MessageToDict(request, preserving_proto_field_name=True), self._get_http_args(context), context, @@ -279,12 +273,11 @@ async def list_tasks( extensions if extensions is not None else self.extensions, ) - tenant = self._pop_tenant(payload) or request.tenant - path = self._get_path('/v1/tasks', tenant) + path = self._get_path('/v1/tasks', request.tenant) response_data = await self._send_get_request( path, - payload, + _model_to_query_params(request), modified_kwargs, ) response: ListTasksResponse = ParseDict( @@ -311,8 +304,7 @@ async def cancel_task( context, ) - tenant = self._pop_tenant(payload) or request.tenant - path = self._get_path(f'/v1/tasks/{request.id}:cancel', tenant) + path = self._get_path(f'/v1/tasks/{request.id}:cancel', request.tenant) response_data = await self._send_post_request( path, payload, modified_kwargs @@ -337,9 +329,9 @@ async def create_task_push_notification_config( payload, modified_kwargs, context ) - tenant = self._pop_tenant(payload) or request.tenant path = self._get_path( - f'/v1/tasks/{request.task_id}/pushNotificationConfigs', tenant + f'/v1/tasks/{request.task_id}/pushNotificationConfigs', + request.tenant, ) response_data = await self._send_post_request( @@ -375,10 +367,9 @@ async def get_task_push_notification_config( if 'task_id' in params: del params['task_id'] - tenant = self._pop_tenant(params) or request.tenant path = self._get_path( f'/v1/tasks/{request.task_id}/pushNotificationConfigs/{request.id}', - tenant, + request.tenant, ) response_data = await self._send_get_request( @@ -412,9 +403,9 @@ async def list_task_push_notification_configs( if 'task_id' in params: del params['task_id'] - tenant = self._pop_tenant(params) or request.tenant path = self._get_path( - f'/v1/tasks/{request.task_id}/pushNotificationConfigs', tenant + f'/v1/tasks/{request.task_id}/pushNotificationConfigs', + request.tenant, ) response_data = await self._send_get_request( @@ -450,10 +441,9 @@ async def delete_task_push_notification_config( if 'task_id' in params: del params['task_id'] - tenant = self._pop_tenant(params) or request.tenant path = self._get_path( f'/v1/tasks/{request.task_id}/pushNotificationConfigs/{request.id}', - tenant, + request.tenant, ) await self._send_delete_request( @@ -544,6 +534,11 @@ async def close(self) -> None: await self.httpx_client.aclose() +def _model_to_query_params(instance: Message) -> dict[str, str]: + data = MessageToDict(instance, preserving_proto_field_name=True) + return _json_to_query_params(data) + + def _json_to_query_params(data: dict[str, Any]) -> dict[str, str]: query_dict = {} for key, value in data.items(): diff --git a/tests/integration/test_tenant.py b/tests/integration/test_tenant.py index 54c97ebec..134110bef 100644 --- a/tests/integration/test_tenant.py +++ b/tests/integration/test_tenant.py @@ -73,8 +73,7 @@ async def test_tenant_decorator_rest(agent_card): ) args, kwargs = send_call assert args[1] == 'http://example.com/rest/tenant-1/v1/message:send' - # tenant should NOT be in JSON body as it was popped - assert 'tenant' not in kwargs['json'] + assert 'tenant' in kwargs['json'] @pytest.mark.asyncio From 64dd6dbd6cee18141331faa014d201e4182e7b83 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 4 Mar 2026 12:16:26 +0000 Subject: [PATCH 08/14] refactor: small change to make code consistent --- src/a2a/client/transports/rest.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index fef5a0144..a628284c5 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -466,8 +466,9 @@ async def subscribe( ) modified_kwargs.setdefault('timeout', None) - tenant = request.tenant - path = self._get_path(f'/v1/tasks/{request.id}:subscribe', tenant) + path = self._get_path( + f'/v1/tasks/{request.id}:subscribe', request.tenant + ) async with aconnect_sse( self.httpx_client, From 97469c3a112ce03d7949ba5d3583f5f036b7844d Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 4 Mar 2026 16:32:11 +0000 Subject: [PATCH 09/14] refactor: remove TenantTransportDecorator and update transport imports --- src/a2a/client/client_factory.py | 3 +- src/a2a/client/transports/__init__.py | 3 +- src/a2a/client/transports/base.py | 168 ---------------- src/a2a/client/transports/rest.py | 82 ++++---- src/a2a/client/transports/tenant_decorator.py | 188 ++++++++++++++++++ tests/client/test_base_client.py | 103 +--------- tests/client/test_client_factory.py | 3 +- tests/client/transports/test_rest_client.py | 4 +- .../transports/test_tenant_decorator.py | 140 +++++++++++++ tests/integration/test_tenant.py | 9 +- 10 files changed, 375 insertions(+), 328 deletions(-) create mode 100644 src/a2a/client/transports/tenant_decorator.py create mode 100644 tests/client/transports/test_tenant_decorator.py diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index 0177cba9e..ed3395d85 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -11,9 +11,10 @@ from a2a.client.card_resolver import A2ACardResolver from a2a.client.client import Client, ClientConfig, Consumer from a2a.client.middleware import ClientCallInterceptor -from a2a.client.transports.base import ClientTransport, TenantTransportDecorator +from a2a.client.transports.base import ClientTransport from a2a.client.transports.jsonrpc import JsonRpcTransport from a2a.client.transports.rest import RestTransport +from a2a.client.transports.tenant_decorator import TenantTransportDecorator from a2a.types.a2a_pb2 import ( AgentCapabilities, AgentCard, diff --git a/src/a2a/client/transports/__init__.py b/src/a2a/client/transports/__init__.py index 8c17ac04b..af7c60f62 100644 --- a/src/a2a/client/transports/__init__.py +++ b/src/a2a/client/transports/__init__.py @@ -1,6 +1,6 @@ """A2A Client Transports.""" -from a2a.client.transports.base import ClientTransport, TenantTransportDecorator +from a2a.client.transports.base import ClientTransport from a2a.client.transports.jsonrpc import JsonRpcTransport from a2a.client.transports.rest import RestTransport @@ -16,5 +16,4 @@ 'GrpcTransport', 'JsonRpcTransport', 'RestTransport', - 'TenantTransportDecorator', ] diff --git a/src/a2a/client/transports/base.py b/src/a2a/client/transports/base.py index 4e6f8bde7..2d2c29873 100644 --- a/src/a2a/client/transports/base.py +++ b/src/a2a/client/transports/base.py @@ -158,171 +158,3 @@ async def get_extended_agent_card( @abstractmethod async def close(self) -> None: """Closes the transport.""" - - -class TenantTransportDecorator(ClientTransport): - """A transport decorator that attaches a tenant to all requests.""" - - def __init__(self, base: ClientTransport, tenant: str): - self._base = base - self._tenant = tenant - - def update_tenant(self, tenant: str) -> str: - """If tenant is not provided, use the default tenant. - - Returns: - The tenant used for the request. - """ - if tenant != '': - return tenant - return self._tenant or '' - - async def send_message( - self, - request: SendMessageRequest, - *, - context: ClientCallContext | None = None, - extensions: list[str] | None = None, - ) -> SendMessageResponse: - """Sends a streaming message request to the agent and yields responses as they arrive.""" - request.tenant = self.update_tenant(request.tenant) - return await self._base.send_message( - request, context=context, extensions=extensions - ) - - async def send_message_streaming( - self, - request: SendMessageRequest, - *, - context: ClientCallContext | None = None, - extensions: list[str] | None = None, - ) -> AsyncGenerator[StreamResponse]: - """Sends a streaming message request to the agent and yields responses.""" - request.tenant = self.update_tenant(request.tenant) - async for event in self._base.send_message_streaming( - request, context=context, extensions=extensions - ): - yield event - - async def get_task( - self, - request: GetTaskRequest, - *, - context: ClientCallContext | None = None, - extensions: list[str] | None = None, - ) -> Task: - """Retrieves the current state and history of a specific task.""" - request.tenant = self.update_tenant(request.tenant) - return await self._base.get_task( - request, context=context, extensions=extensions - ) - - async def list_tasks( - self, - request: ListTasksRequest, - *, - context: ClientCallContext | None = None, - extensions: list[str] | None = None, - ) -> ListTasksResponse: - """Retrieves tasks for an agent.""" - request.tenant = self.update_tenant(request.tenant) - return await self._base.list_tasks( - request, context=context, extensions=extensions - ) - - async def cancel_task( - self, - request: CancelTaskRequest, - *, - context: ClientCallContext | None = None, - extensions: list[str] | None = None, - ) -> Task: - """Requests the agent to cancel a specific task.""" - request.tenant = self.update_tenant(request.tenant) - return await self._base.cancel_task( - request, context=context, extensions=extensions - ) - - async def create_task_push_notification_config( - self, - request: CreateTaskPushNotificationConfigRequest, - *, - context: ClientCallContext | None = None, - extensions: list[str] | None = None, - ) -> TaskPushNotificationConfig: - """Sets or updates the push notification configuration for a specific task.""" - request.tenant = self.update_tenant(request.tenant) - return await self._base.create_task_push_notification_config( - request, context=context, extensions=extensions - ) - - async def get_task_push_notification_config( - self, - request: GetTaskPushNotificationConfigRequest, - *, - context: ClientCallContext | None = None, - extensions: list[str] | None = None, - ) -> TaskPushNotificationConfig: - """Retrieves the push notification configuration for a specific task.""" - request.tenant = self.update_tenant(request.tenant) - return await self._base.get_task_push_notification_config( - request, context=context, extensions=extensions - ) - - async def list_task_push_notification_configs( - self, - request: ListTaskPushNotificationConfigsRequest, - *, - context: ClientCallContext | None = None, - extensions: list[str] | None = None, - ) -> ListTaskPushNotificationConfigsResponse: - """Lists push notification configurations for a specific task.""" - request.tenant = self.update_tenant(request.tenant) - return await self._base.list_task_push_notification_configs( - request, context=context, extensions=extensions - ) - - async def delete_task_push_notification_config( - self, - request: DeleteTaskPushNotificationConfigRequest, - *, - context: ClientCallContext | None = None, - extensions: list[str] | None = None, - ) -> None: - """Deletes the push notification configuration for a specific task.""" - request.tenant = self.update_tenant(request.tenant) - await self._base.delete_task_push_notification_config( - request, context=context, extensions=extensions - ) - - async def subscribe( - self, - request: SubscribeToTaskRequest, - *, - context: ClientCallContext | None = None, - extensions: list[str] | None = None, - ) -> AsyncGenerator[StreamResponse]: - """Reconnects to get task updates.""" - request.tenant = self.update_tenant(request.tenant) - async for event in self._base.subscribe( - request, context=context, extensions=extensions - ): - yield event - - async def get_extended_agent_card( - self, - *, - context: ClientCallContext | None = None, - extensions: list[str] | None = None, - signature_verifier: Callable[[AgentCard], None] | None = None, - ) -> AgentCard: - """Retrieves the Extended AgentCard.""" - return await self._base.get_extended_agent_card( - context=context, - extensions=extensions, - signature_verifier=signature_verifier, - ) - - async def close(self) -> None: - """Closes the transport.""" - await self._base.close() diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index bfba3876e..dd545bbb7 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -66,10 +66,6 @@ def __init__( self._needs_extended_card = agent_card.capabilities.extended_agent_card self.extensions = extensions - def _get_path(self, base_path: str, tenant: str | None) -> str: - """Returns the full path, prepending the tenant if provided.""" - return f'/{tenant}{base_path}' if tenant else base_path - async def send_message( self, request: SendMessageRequest, @@ -81,9 +77,8 @@ async def send_message( payload, modified_kwargs = await self._prepare_send_message( request, context, extensions ) - path = self._get_path('/v1/message:send', request.tenant) response_data = await self._send_post_request( - path, payload, modified_kwargs + '/v1/message:send', request.tenant, payload, modified_kwargs ) response: SendMessageResponse = ParseDict( response_data, SendMessageResponse() @@ -101,11 +96,10 @@ async def send_message_streaming( payload, modified_kwargs = await self._prepare_send_message( request, context, extensions ) - path = self._get_path('/v1/message:stream', request.tenant) - async for event in self._send_stream_request( 'POST', - path, + '/v1/message:stream', + request.tenant, http_kwargs=modified_kwargs, json=payload, ): @@ -133,10 +127,9 @@ async def get_task( if 'id' in params: del params['id'] # id is part of the URL path, not query params - path = self._get_path(f'/v1/tasks/{request.id}', request.tenant) - response_data = await self._send_get_request( - path, + f'/v1/tasks/{request.id}', + request.tenant, params, modified_kwargs, ) @@ -161,10 +154,9 @@ async def list_tasks( extensions if extensions is not None else self.extensions, ) - path = self._get_path('/v1/tasks', request.tenant) - response_data = await self._send_get_request( - path, + '/v1/tasks', + request.tenant, _model_to_query_params(request), modified_kwargs, ) @@ -192,10 +184,11 @@ async def cancel_task( context, ) - path = self._get_path(f'/v1/tasks/{request.id}:cancel', request.tenant) - response_data = await self._send_post_request( - path, payload, modified_kwargs + f'/v1/tasks/{request.id}:cancel', + request.tenant, + payload, + modified_kwargs, ) response: Task = ParseDict(response_data, Task()) return response @@ -217,13 +210,9 @@ async def create_task_push_notification_config( payload, modified_kwargs, context ) - path = self._get_path( + response_data = await self._send_post_request( f'/v1/tasks/{request.task_id}/pushNotificationConfigs', request.tenant, - ) - - response_data = await self._send_post_request( - path, payload, modified_kwargs, ) @@ -255,13 +244,9 @@ async def get_task_push_notification_config( if 'task_id' in params: del params['task_id'] - path = self._get_path( + response_data = await self._send_get_request( f'/v1/tasks/{request.task_id}/pushNotificationConfigs/{request.id}', request.tenant, - ) - - response_data = await self._send_get_request( - path, params, modified_kwargs, ) @@ -291,13 +276,9 @@ async def list_task_push_notification_configs( if 'task_id' in params: del params['task_id'] - path = self._get_path( + response_data = await self._send_get_request( f'/v1/tasks/{request.task_id}/pushNotificationConfigs', request.tenant, - ) - - response_data = await self._send_get_request( - path, params, modified_kwargs, ) @@ -329,13 +310,9 @@ async def delete_task_push_notification_config( if 'task_id' in params: del params['task_id'] - path = self._get_path( + await self._send_delete_request( f'/v1/tasks/{request.task_id}/pushNotificationConfigs/{request.id}', request.tenant, - ) - - await self._send_delete_request( - path, params, modified_kwargs, ) @@ -353,13 +330,10 @@ async def subscribe( extensions if extensions is not None else self.extensions, ) - path = self._get_path( - f'/v1/tasks/{request.id}:subscribe', request.tenant - ) - async for event in self._send_stream_request( 'GET', - f'{path}', + f'/v1/tasks/{request.id}:subscribe', + request.tenant, http_kwargs=modified_kwargs, ): yield event @@ -387,7 +361,7 @@ async def get_extended_agent_card( context, ) response_data = await self._send_get_request( - '/v1/card', {}, modified_kwargs + '/v1/card', '', {}, modified_kwargs ) response: AgentCard = ParseDict(response_data, AgentCard()) @@ -403,6 +377,10 @@ async def close(self) -> None: """Closes the httpx client.""" await self.httpx_client.aclose() + def _get_path(self, base_path: str, tenant: str) -> str: + """Returns the full path, prepending the tenant if provided.""" + return f'/{tenant}{base_path}' if tenant else base_path + async def _apply_interceptors( self, request_payload: dict[str, Any], @@ -465,16 +443,18 @@ async def _send_stream_request( self, method: str, target: str, + tenant: str, http_kwargs: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamResponse]: final_kwargs = dict(http_kwargs or {}) final_kwargs.update(kwargs) + path = self._get_path(target, tenant) async for sse_data in send_http_stream_request( self.httpx_client, method, - f'{self.url}{target}', + f'{self.url}{path}', self._handle_http_error, **final_kwargs, ): @@ -489,13 +469,15 @@ async def _send_request(self, request: httpx.Request) -> dict[str, Any]: async def _send_post_request( self, target: str, + tenant: str, rpc_request_payload: dict[str, Any], http_kwargs: dict[str, Any] | None = None, ) -> dict[str, Any]: + path = self._get_path(target, tenant) return await self._send_request( self.httpx_client.build_request( 'POST', - f'{self.url}{target}', + f'{self.url}{path}', json=rpc_request_payload, **(http_kwargs or {}), ) @@ -504,13 +486,15 @@ async def _send_post_request( async def _send_get_request( self, target: str, + tenant: str, query_params: dict[str, str], http_kwargs: dict[str, Any] | None = None, ) -> dict[str, Any]: + path = self._get_path(target, tenant) return await self._send_request( self.httpx_client.build_request( 'GET', - f'{self.url}{target}', + f'{self.url}{path}', params=query_params, **(http_kwargs or {}), ) @@ -519,13 +503,15 @@ async def _send_get_request( async def _send_delete_request( self, target: str, + tenant: str, query_params: dict[str, Any], http_kwargs: dict[str, Any] | None = None, ) -> dict[str, Any]: + path = self._get_path(target, tenant) return await self._send_request( self.httpx_client.build_request( 'DELETE', - f'{self.url}{target}', + f'{self.url}{path}', params=query_params, **(http_kwargs or {}), ) diff --git a/src/a2a/client/transports/tenant_decorator.py b/src/a2a/client/transports/tenant_decorator.py new file mode 100644 index 000000000..52ab64574 --- /dev/null +++ b/src/a2a/client/transports/tenant_decorator.py @@ -0,0 +1,188 @@ +from collections.abc import AsyncGenerator, Callable + +from a2a.client.middleware import ClientCallContext +from a2a.client.transports.base import ClientTransport +from a2a.types.a2a_pb2 import ( + AgentCard, + CancelTaskRequest, + CreateTaskPushNotificationConfigRequest, + DeleteTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTaskPushNotificationConfigsResponse, + ListTasksRequest, + ListTasksResponse, + SendMessageRequest, + SendMessageResponse, + StreamResponse, + SubscribeToTaskRequest, + Task, + TaskPushNotificationConfig, +) + + +class TenantTransportDecorator(ClientTransport): + """A transport decorator that attaches a tenant to all requests.""" + + def __init__(self, base: ClientTransport, tenant: str): + self._base = base + self._tenant = tenant + + def _resolve_tenant(self, tenant: str) -> str: + """If tenant is not provided, use the default tenant. + + Returns: + The tenant used for the request. + """ + return tenant or self._tenant + + async def send_message( + self, + request: SendMessageRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> SendMessageResponse: + """Sends a streaming message request to the agent and yields responses as they arrive.""" + request.tenant = self._resolve_tenant(request.tenant) + return await self._base.send_message( + request, context=context, extensions=extensions + ) + + async def send_message_streaming( + self, + request: SendMessageRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> AsyncGenerator[StreamResponse]: + """Sends a streaming message request to the agent and yields responses.""" + request.tenant = self._resolve_tenant(request.tenant) + async for event in self._base.send_message_streaming( + request, context=context, extensions=extensions + ): + yield event + + async def get_task( + self, + request: GetTaskRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task.""" + request.tenant = self._resolve_tenant(request.tenant) + return await self._base.get_task( + request, context=context, extensions=extensions + ) + + async def list_tasks( + self, + request: ListTasksRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> ListTasksResponse: + """Retrieves tasks for an agent.""" + request.tenant = self._resolve_tenant(request.tenant) + return await self._base.list_tasks( + request, context=context, extensions=extensions + ) + + async def cancel_task( + self, + request: CancelTaskRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> Task: + """Requests the agent to cancel a specific task.""" + request.tenant = self._resolve_tenant(request.tenant) + return await self._base.cancel_task( + request, context=context, extensions=extensions + ) + + async def create_task_push_notification_config( + self, + request: CreateTaskPushNotificationConfigRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task.""" + request.tenant = self._resolve_tenant(request.tenant) + return await self._base.create_task_push_notification_config( + request, context=context, extensions=extensions + ) + + async def get_task_push_notification_config( + self, + request: GetTaskPushNotificationConfigRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task.""" + request.tenant = self._resolve_tenant(request.tenant) + return await self._base.get_task_push_notification_config( + request, context=context, extensions=extensions + ) + + async def list_task_push_notification_configs( + self, + request: ListTaskPushNotificationConfigsRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> ListTaskPushNotificationConfigsResponse: + """Lists push notification configurations for a specific task.""" + request.tenant = self._resolve_tenant(request.tenant) + return await self._base.list_task_push_notification_configs( + request, context=context, extensions=extensions + ) + + async def delete_task_push_notification_config( + self, + request: DeleteTaskPushNotificationConfigRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> None: + """Deletes the push notification configuration for a specific task.""" + request.tenant = self._resolve_tenant(request.tenant) + await self._base.delete_task_push_notification_config( + request, context=context, extensions=extensions + ) + + async def subscribe( + self, + request: SubscribeToTaskRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> AsyncGenerator[StreamResponse]: + """Reconnects to get task updates.""" + request.tenant = self._resolve_tenant(request.tenant) + async for event in self._base.subscribe( + request, context=context, extensions=extensions + ): + yield event + + async def get_extended_agent_card( + self, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, + ) -> AgentCard: + """Retrieves the Extended AgentCard.""" + return await self._base.get_extended_agent_card( + context=context, + extensions=extensions, + signature_verifier=signature_verifier, + ) + + async def close(self) -> None: + """Closes the transport.""" + await self._base.close() diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index 9f5f7b5c1..384b18fb0 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -4,7 +4,7 @@ from a2a.client.base_client import BaseClient from a2a.client.client import ClientConfig -from a2a.client.transports.base import ClientTransport, TenantTransportDecorator +from a2a.client.transports.base import ClientTransport from a2a.types.a2a_pb2 import ( AgentCapabilities, AgentCard, @@ -301,104 +301,3 @@ async def create_stream(*args, **kwargs): assert params.configuration.history_length == 0 assert params.configuration.blocking is True assert params.configuration.accepted_output_modes == ['text/plain'] - - -class TestTenantTransportDecorator: - @pytest.mark.asyncio - async def test_update_tenant_logic(self, mock_transport: AsyncMock) -> None: - tenant_id = 'test-tenant' - decorator = TenantTransportDecorator(mock_transport, tenant_id) - - # Case 1: Tenant already set on request - assert decorator.update_tenant('existing-tenant') == 'existing-tenant' - - # Case 2: Tenant not set (empty string) - assert decorator.update_tenant('') == tenant_id - - @pytest.mark.parametrize( - 'method_name, request_obj', - [ - ( - 'send_message', - SendMessageRequest(message=Message(parts=[Part(text='hello')])), - ), - ( - 'get_task', - GetTaskRequest(id='t1'), - ), - ( - 'list_tasks', - ListTasksRequest(), - ), - ( - 'cancel_task', - CancelTaskRequest(id='t1'), - ), - ( - 'create_task_push_notification_config', - CreateTaskPushNotificationConfigRequest(task_id='t1'), - ), - ( - 'get_task_push_notification_config', - GetTaskPushNotificationConfigRequest(task_id='t1', id='c1'), - ), - ( - 'list_task_push_notification_configs', - ListTaskPushNotificationConfigsRequest(task_id='t1'), - ), - ( - 'delete_task_push_notification_config', - DeleteTaskPushNotificationConfigRequest(task_id='t1', id='c1'), - ), - ], - ) - @pytest.mark.asyncio - async def test_methods( - self, mock_transport: AsyncMock, method_name, request_obj - ) -> None: - tenant_id = 'test-tenant' - decorator = TenantTransportDecorator(mock_transport, tenant_id) - mock_method = getattr(mock_transport, method_name) - - await getattr(decorator, method_name)(request_obj) - - mock_method.assert_called_once() - assert mock_transport.mock_calls[0][0] == method_name - assert request_obj.tenant == tenant_id - - @pytest.mark.asyncio - async def test_streaming_methods(self, mock_transport: AsyncMock) -> None: - tenant_id = 'test-tenant' - decorator = TenantTransportDecorator(mock_transport, tenant_id) - - async def mock_stream(*args, **kwargs): - yield StreamResponse() - - # Test subscribe - mock_transport.subscribe.return_value = mock_stream() - request_sub = SubscribeToTaskRequest(id='t1') - async for _ in decorator.subscribe(request_sub): - pass - assert request_sub.tenant == tenant_id - - # Test send_message_streaming - mock_transport.send_message_streaming.return_value = mock_stream() - request_msg = SendMessageRequest() - async for _ in decorator.send_message_streaming(request_msg): - pass - assert request_msg.tenant == tenant_id - - @pytest.mark.asyncio - async def test_misc(self, mock_transport: AsyncMock) -> None: - tenant_id = 'test-tenant' - decorator = TenantTransportDecorator(mock_transport, tenant_id) - - # Test get_extended_agent_card - card = MagicMock(spec=AgentCard) - mock_transport.get_extended_agent_card.return_value = card - res = await decorator.get_extended_agent_card() - assert res is card - - # Test close - await decorator.close() - mock_transport.close.assert_called_once() diff --git a/tests/client/test_client_factory.py b/tests/client/test_client_factory.py index 8ccc4e0b8..d101594e9 100644 --- a/tests/client/test_client_factory.py +++ b/tests/client/test_client_factory.py @@ -12,8 +12,9 @@ from a2a.client.transports import ( JsonRpcTransport, RestTransport, - TenantTransportDecorator, + ClientTransport, ) +from a2a.client.transports.tenant_decorator import TenantTransportDecorator from a2a.types.a2a_pb2 import ( AgentCapabilities, AgentCard, diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index 75f959aaf..4a403066f 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -315,7 +315,7 @@ async def test_get_card_with_extended_card_support_with_extensions( await client.get_extended_agent_card(extensions=extensions) mock_send_get_request.assert_called_once() - _, _, mock_kwargs = mock_send_get_request.call_args[0] + _, _, _, mock_kwargs = mock_send_get_request.call_args[0] _assert_extensions_header( mock_kwargs, @@ -524,7 +524,7 @@ async def test_rest_methods_prepend_tenant( ], ) @pytest.mark.asyncio - @patch('a2a.client.transports.rest.aconnect_sse') + @patch('a2a.client.transports.http_helpers.aconnect_sse') async def test_rest_streaming_methods_prepend_tenant( self, mock_aconnect_sse, diff --git a/tests/client/transports/test_tenant_decorator.py b/tests/client/transports/test_tenant_decorator.py new file mode 100644 index 000000000..ecde7f935 --- /dev/null +++ b/tests/client/transports/test_tenant_decorator.py @@ -0,0 +1,140 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock + +from a2a.client.transports.base import ClientTransport +from a2a.client.transports.tenant_decorator import TenantTransportDecorator +from a2a.types.a2a_pb2 import ( + AgentCard, + CancelTaskRequest, + CreateTaskPushNotificationConfigRequest, + DeleteTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTasksRequest, + Message, + Part, + SendMessageRequest, + StreamResponse, + SubscribeToTaskRequest, +) + + +@pytest.fixture +def mock_transport() -> AsyncMock: + return AsyncMock(spec=ClientTransport) + + +class TestTenantTransportDecorator: + @pytest.mark.asyncio + async def test_resolve_tenant_logic( + self, mock_transport: AsyncMock + ) -> None: + tenant_id = 'test-tenant' + decorator = TenantTransportDecorator(mock_transport, tenant_id) + + # Case 1: Tenant already set on request + assert decorator._resolve_tenant('existing-tenant') == 'existing-tenant' + + # Case 2: Tenant not set (empty string) + assert decorator._resolve_tenant('') == tenant_id + + @pytest.mark.asyncio + async def test_resolve_tenant_logic_empty_tenant( + self, mock_transport: AsyncMock + ) -> None: + decorator = TenantTransportDecorator(mock_transport, '') + + # Case 1: Tenant already set on request + assert decorator._resolve_tenant('existing-tenant') == 'existing-tenant' + + # Case 2: Tenant not set (empty string) + assert decorator._resolve_tenant('') == '' + + @pytest.mark.parametrize( + 'method_name, request_obj', + [ + ( + 'send_message', + SendMessageRequest(message=Message(parts=[Part(text='hello')])), + ), + ( + 'get_task', + GetTaskRequest(id='t1'), + ), + ( + 'list_tasks', + ListTasksRequest(), + ), + ( + 'cancel_task', + CancelTaskRequest(id='t1'), + ), + ( + 'create_task_push_notification_config', + CreateTaskPushNotificationConfigRequest(task_id='t1'), + ), + ( + 'get_task_push_notification_config', + GetTaskPushNotificationConfigRequest(task_id='t1', id='c1'), + ), + ( + 'list_task_push_notification_configs', + ListTaskPushNotificationConfigsRequest(task_id='t1'), + ), + ( + 'delete_task_push_notification_config', + DeleteTaskPushNotificationConfigRequest(task_id='t1', id='c1'), + ), + ], + ) + @pytest.mark.asyncio + async def test_methods( + self, mock_transport: AsyncMock, method_name, request_obj + ) -> None: + tenant_id = 'test-tenant' + decorator = TenantTransportDecorator(mock_transport, tenant_id) + mock_method = getattr(mock_transport, method_name) + + await getattr(decorator, method_name)(request_obj) + + mock_method.assert_called_once() + assert mock_transport.mock_calls[0][0] == method_name + assert request_obj.tenant == tenant_id + + @pytest.mark.asyncio + async def test_streaming_methods(self, mock_transport: AsyncMock) -> None: + tenant_id = 'test-tenant' + decorator = TenantTransportDecorator(mock_transport, tenant_id) + + async def mock_stream(*args, **kwargs): + yield StreamResponse() + + # Test subscribe + mock_transport.subscribe.return_value = mock_stream() + request_sub = SubscribeToTaskRequest(id='t1') + async for _ in decorator.subscribe(request_sub): + pass + assert request_sub.tenant == tenant_id + + # Test send_message_streaming + mock_transport.send_message_streaming.return_value = mock_stream() + request_msg = SendMessageRequest() + async for _ in decorator.send_message_streaming(request_msg): + pass + assert request_msg.tenant == tenant_id + + @pytest.mark.asyncio + async def test_misc(self, mock_transport: AsyncMock) -> None: + tenant_id = 'test-tenant' + decorator = TenantTransportDecorator(mock_transport, tenant_id) + + # Test get_extended_agent_card + card = MagicMock(spec=AgentCard) + mock_transport.get_extended_agent_card.return_value = card + res = await decorator.get_extended_agent_card() + assert res is card + + # Test close + await decorator.close() + mock_transport.close.assert_called_once() diff --git a/tests/integration/test_tenant.py b/tests/integration/test_tenant.py index 134110bef..b3e6a0753 100644 --- a/tests/integration/test_tenant.py +++ b/tests/integration/test_tenant.py @@ -10,7 +10,7 @@ AgentCapabilities, ) from a2a.client.transports import RestTransport, JsonRpcTransport, GrpcTransport -from a2a.client.transports.base import TenantTransportDecorator +from a2a.client.transports.tenant_decorator import TenantTransportDecorator from a2a.client import ClientConfig, ClientFactory from a2a.utils.constants import TransportProtocol @@ -79,7 +79,8 @@ async def test_tenant_decorator_rest(agent_card): @pytest.mark.asyncio async def test_tenant_decorator_jsonrpc(agent_card): mock_httpx = AsyncMock(spec=httpx.AsyncClient) - mock_httpx.post.return_value = MagicMock( + mock_httpx.build_request.return_value = MagicMock() + mock_httpx.send.return_value = MagicMock( status_code=200, json=lambda: {'result': {'message': {}}, 'id': '1', 'jsonrpc': '2.0'}, ) @@ -97,8 +98,8 @@ async def test_tenant_decorator_jsonrpc(agent_card): request = SendMessageRequest(message=Message(parts=[{'text': 'hi'}])) await client._transport.send_message(request) - mock_httpx.post.assert_called() - _, kwargs = mock_httpx.post.call_args + mock_httpx.build_request.assert_called() + _, kwargs = mock_httpx.build_request.call_args assert kwargs['json']['params']['tenant'] == 'tenant-2' From abdddb0f37ea3f2fc99b750064eb264724115aec Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 4 Mar 2026 16:42:07 +0000 Subject: [PATCH 10/14] fix: remove `v1/` from expected paths in tests --- tests/client/transports/test_rest_client.py | 20 ++++++++++---------- tests/integration/test_tenant.py | 8 ++++---- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index fc15dbc0e..42eaaf45c 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -423,50 +423,50 @@ class TestRestTransportTenant: tenant='my-tenant', message=create_text_message_object(content='hi'), ), - '/my-tenant/v1/message:send', + '/my-tenant/message:send', ), ( 'list_tasks', ListTasksRequest(tenant='my-tenant'), - '/my-tenant/v1/tasks', + '/my-tenant/tasks', ), ( 'get_task', GetTaskRequest(tenant='my-tenant', id='task-123'), - '/my-tenant/v1/tasks/task-123', + '/my-tenant/tasks/task-123', ), ( 'cancel_task', CancelTaskRequest(tenant='my-tenant', id='task-123'), - '/my-tenant/v1/tasks/task-123:cancel', + '/my-tenant/tasks/task-123:cancel', ), ( 'create_task_push_notification_config', CreateTaskPushNotificationConfigRequest( tenant='my-tenant', task_id='task-123' ), - '/my-tenant/v1/tasks/task-123/pushNotificationConfigs', + '/my-tenant/tasks/task-123/pushNotificationConfigs', ), ( 'get_task_push_notification_config', GetTaskPushNotificationConfigRequest( tenant='my-tenant', task_id='task-123', id='cfg-1' ), - '/my-tenant/v1/tasks/task-123/pushNotificationConfigs/cfg-1', + '/my-tenant/tasks/task-123/pushNotificationConfigs/cfg-1', ), ( 'list_task_push_notification_configs', ListTaskPushNotificationConfigsRequest( tenant='my-tenant', task_id='task-123' ), - '/my-tenant/v1/tasks/task-123/pushNotificationConfigs', + '/my-tenant/tasks/task-123/pushNotificationConfigs', ), ( 'delete_task_push_notification_config', DeleteTaskPushNotificationConfigRequest( tenant='my-tenant', task_id='task-123', id='cfg-1' ), - '/my-tenant/v1/tasks/task-123/pushNotificationConfigs/cfg-1', + '/my-tenant/tasks/task-123/pushNotificationConfigs/cfg-1', ), ], ) @@ -511,7 +511,7 @@ async def test_rest_methods_prepend_tenant( ( 'subscribe', SubscribeToTaskRequest(tenant='my-tenant', id='task-123'), - '/my-tenant/v1/tasks/task-123:subscribe', + '/my-tenant/tasks/task-123:subscribe', ), ( 'send_message_streaming', @@ -519,7 +519,7 @@ async def test_rest_methods_prepend_tenant( tenant='my-tenant', message=create_text_message_object(content='hi'), ), - '/my-tenant/v1/message:stream', + '/my-tenant/message:stream', ), ], ) diff --git a/tests/integration/test_tenant.py b/tests/integration/test_tenant.py index b3e6a0753..aef0289db 100644 --- a/tests/integration/test_tenant.py +++ b/tests/integration/test_tenant.py @@ -69,10 +69,10 @@ async def test_tenant_decorator_rest(agent_card): send_call = next( c for c in mock_httpx.build_request.call_args_list - if 'v1/message:send' in c.args[1] + if 'message:send' in c.args[1] ) args, kwargs = send_call - assert args[1] == 'http://example.com/rest/tenant-1/v1/message:send' + assert args[1] == 'http://example.com/rest/tenant-1/message:send' assert 'tenant' in kwargs['json'] @@ -154,7 +154,7 @@ async def test_tenant_decorator_explicit_override(agent_card): send_call = next( c for c in mock_httpx.build_request.call_args_list - if 'v1/message:send' in c.args[1] + if 'message:send' in c.args[1] ) args, _ = send_call - assert args[1] == 'http://example.com/rest/explicit-tenant/v1/message:send' + assert args[1] == 'http://example.com/rest/explicit-tenant/message:send' From 593d5bfdf6f6e08bbc053e87afebe0efd9b17b9e Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 4 Mar 2026 17:22:54 +0000 Subject: [PATCH 11/14] test: add async test for get_task with empty tenant --- tests/client/transports/test_rest_client.py | 31 +++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index 42eaaf45c..ecd8918fa 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -505,6 +505,37 @@ async def test_rest_methods_prepend_tenant( args, _ = mock_httpx_client.build_request.call_args assert args[1] == f'http://agent.example.com/api{expected_path}' + @pytest.mark.asyncio + async def test_rest_get_task_prepend_empty_tenant( + self, + mock_httpx_client, + mock_agent_card, + ): + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + url='http://agent.example.com/api', + ) + + request = GetTaskRequest(tenant='', id='task-123') + + # 1. Setup mocks + mock_httpx_client.build_request.return_value = MagicMock( + spec=httpx.Request + ) + mock_httpx_client.send.return_value = AsyncMock( + spec=httpx.Response, + status_code=200, + json=MagicMock(return_value={}), + ) + + # 2. Call the method + await client.get_task(request=request) + + # 3. Verify the URL + args, _ = mock_httpx_client.build_request.call_args + assert args[1] == f'http://agent.example.com/api/tasks/task-123' + @pytest.mark.parametrize( 'method_name, request_obj, expected_path', [ From d981e0c44cce0c204e84e55a535f6ad69105a27b Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 5 Mar 2026 09:45:00 +0000 Subject: [PATCH 12/14] feat: prepend tenant to the extended agent card endpoint and add a corresponding test. --- src/a2a/client/transports/rest.py | 2 +- tests/client/transports/test_rest_client.py | 32 +++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index e52d17b4f..936603b72 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -363,7 +363,7 @@ async def get_extended_agent_card( context, ) response_data = await self._send_get_request( - '/card', '', {}, modified_kwargs + '/card', request.tenant, {}, modified_kwargs ) response: AgentCard = ParseDict(response_data, AgentCard()) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index 506ed63ec..19bce5fa1 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -506,6 +506,38 @@ async def test_rest_methods_prepend_tenant( # 4. Verify the URL args, _ = mock_httpx_client.build_request.call_args assert args[1] == f'http://agent.example.com/api{expected_path}' + + @pytest.mark.asyncio + async def test_rest_get_extended_agent_card_prepend_tenant( + self, + mock_httpx_client, + mock_agent_card, + ): + mock_agent_card.capabilities.extended_agent_card = True + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + url='http://agent.example.com/api', + ) + + request = GetExtendedAgentCardRequest(tenant='my-tenant') + + # 1. Setup mocks + mock_httpx_client.build_request.return_value = MagicMock( + spec=httpx.Request + ) + mock_httpx_client.send.return_value = AsyncMock( + spec=httpx.Response, + status_code=200, + json=MagicMock(return_value={}), + ) + + # 2. Call the method + await client.get_extended_agent_card(request=request) + + # 3. Verify the URL + args, _ = mock_httpx_client.build_request.call_args + assert args[1] == 'http://agent.example.com/api/my-tenant/card' @pytest.mark.asyncio async def test_rest_get_task_prepend_empty_tenant( From c7d157e2fa9a61af05d7300c1d57fe258ffbe2c8 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 5 Mar 2026 09:48:32 +0000 Subject: [PATCH 13/14] fix: format --- tests/client/transports/test_rest_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index 19bce5fa1..5d801aa91 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -506,7 +506,7 @@ async def test_rest_methods_prepend_tenant( # 4. Verify the URL args, _ = mock_httpx_client.build_request.call_args assert args[1] == f'http://agent.example.com/api{expected_path}' - + @pytest.mark.asyncio async def test_rest_get_extended_agent_card_prepend_tenant( self, From fac3d8b84581e226fd9dd1170028c0a6db4b9dda Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 5 Mar 2026 09:58:47 +0000 Subject: [PATCH 14/14] feat: Add tenant resolution to `GetExtendedAgentCardRequest` in `TenantTransportDecorator` and extend test coverage. --- src/a2a/client/transports/tenant_decorator.py | 4 ++++ .../transports/test_tenant_decorator.py | 19 ++++--------------- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/src/a2a/client/transports/tenant_decorator.py b/src/a2a/client/transports/tenant_decorator.py index 52ab64574..0335bd093 100644 --- a/src/a2a/client/transports/tenant_decorator.py +++ b/src/a2a/client/transports/tenant_decorator.py @@ -7,6 +7,7 @@ CancelTaskRequest, CreateTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, GetTaskRequest, ListTaskPushNotificationConfigsRequest, @@ -171,13 +172,16 @@ async def subscribe( async def get_extended_agent_card( self, + request: GetExtendedAgentCardRequest, *, context: ClientCallContext | None = None, extensions: list[str] | None = None, signature_verifier: Callable[[AgentCard], None] | None = None, ) -> AgentCard: """Retrieves the Extended AgentCard.""" + request.tenant = self._resolve_tenant(request.tenant) return await self._base.get_extended_agent_card( + request, context=context, extensions=extensions, signature_verifier=signature_verifier, diff --git a/tests/client/transports/test_tenant_decorator.py b/tests/client/transports/test_tenant_decorator.py index ecde7f935..f544d6762 100644 --- a/tests/client/transports/test_tenant_decorator.py +++ b/tests/client/transports/test_tenant_decorator.py @@ -8,6 +8,7 @@ CancelTaskRequest, CreateTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, GetTaskPushNotificationConfigRequest, GetTaskRequest, ListTaskPushNotificationConfigsRequest, @@ -86,12 +87,14 @@ async def test_resolve_tenant_logic_empty_tenant( 'delete_task_push_notification_config', DeleteTaskPushNotificationConfigRequest(task_id='t1', id='c1'), ), + ('get_extended_agent_card', GetExtendedAgentCardRequest()), ], ) @pytest.mark.asyncio async def test_methods( self, mock_transport: AsyncMock, method_name, request_obj ) -> None: + """Test that tenant is set on the request for all methods.""" tenant_id = 'test-tenant' decorator = TenantTransportDecorator(mock_transport, tenant_id) mock_method = getattr(mock_transport, method_name) @@ -104,6 +107,7 @@ async def test_methods( @pytest.mark.asyncio async def test_streaming_methods(self, mock_transport: AsyncMock) -> None: + """Test that tenant is set on the request for streaming methods.""" tenant_id = 'test-tenant' decorator = TenantTransportDecorator(mock_transport, tenant_id) @@ -123,18 +127,3 @@ async def mock_stream(*args, **kwargs): async for _ in decorator.send_message_streaming(request_msg): pass assert request_msg.tenant == tenant_id - - @pytest.mark.asyncio - async def test_misc(self, mock_transport: AsyncMock) -> None: - tenant_id = 'test-tenant' - decorator = TenantTransportDecorator(mock_transport, tenant_id) - - # Test get_extended_agent_card - card = MagicMock(spec=AgentCard) - mock_transport.get_extended_agent_card.return_value = card - res = await decorator.get_extended_agent_card() - assert res is card - - # Test close - await decorator.close() - mock_transport.close.assert_called_once()