diff --git a/pyproject.toml b/pyproject.toml index dffb43a7..0814a70e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,6 +176,7 @@ omit = [ "*/__init__.py", "src/a2a/types/a2a_pb2.py", "src/a2a/types/a2a_pb2_grpc.py", + "src/a2a/compat/*/*_pb2*.py", ] [tool.coverage.report] diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index ff5387ef..30006568 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -14,6 +14,7 @@ from a2a.client.transports.base import ClientTransport from a2a.client.transports.jsonrpc import JsonRpcTransport from a2a.client.transports.rest import RestTransport +from a2a.client.transports.tenant_decorator import TenantTransportDecorator from a2a.types.a2a_pb2 import ( AgentCapabilities, AgentCard, @@ -216,10 +217,10 @@ def create( TransportProtocol.JSONRPC ] transport_protocol = None - transport_url = None + selected_interface = None if self._config.use_client_preference: for protocol_binding in client_set: - supported_interface = next( + selected_interface = next( ( si for si in card.supported_interfaces @@ -227,17 +228,16 @@ def create( ), None, ) - if supported_interface: + if selected_interface: transport_protocol = protocol_binding - transport_url = supported_interface.url break else: for supported_interface in card.supported_interfaces: if supported_interface.protocol_binding in client_set: transport_protocol = supported_interface.protocol_binding - transport_url = supported_interface.url + selected_interface = supported_interface break - if not transport_protocol or not transport_url: + if not transport_protocol or not selected_interface: raise ValueError('no compatible transports found.') if transport_protocol not in self._registry: raise ValueError(f'no client available for {transport_protocol}') @@ -252,9 +252,14 @@ def create( self._config.extensions = all_extensions transport = self._registry[transport_protocol]( - card, transport_url, self._config, interceptors or [] + card, selected_interface.url, self._config, interceptors or [] ) + if selected_interface.tenant: + transport = TenantTransportDecorator( + transport, selected_interface.tenant + ) + return BaseClient( card, self._config, diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 0ebdfcb7..0c51a266 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -79,7 +79,7 @@ async def send_message( request, context, extensions ) response_data = await self._send_post_request( - '/message:send', payload, modified_kwargs + '/message:send', request.tenant, payload, modified_kwargs ) response: SendMessageResponse = ParseDict( response_data, SendMessageResponse() @@ -97,10 +97,10 @@ async def send_message_streaming( payload, modified_kwargs = await self._prepare_send_message( request, context, extensions ) - async for event in self._send_stream_request( 'POST', '/message:stream', + request.tenant, http_kwargs=modified_kwargs, json=payload, ): @@ -130,6 +130,7 @@ async def get_task( response_data = await self._send_get_request( f'/tasks/{request.id}', + request.tenant, params, modified_kwargs, ) @@ -153,8 +154,10 @@ async def list_tasks( modified_kwargs, extensions if extensions is not None else self.extensions, ) + response_data = await self._send_get_request( '/tasks', + request.tenant, _model_to_query_params(request), modified_kwargs, ) @@ -181,8 +184,12 @@ async def cancel_task( modified_kwargs, context, ) + response_data = await self._send_post_request( - f'/tasks/{request.id}:cancel', payload, modified_kwargs + f'/tasks/{request.id}:cancel', + request.tenant, + payload, + modified_kwargs, ) response: Task = ParseDict(response_data, Task()) return response @@ -203,8 +210,10 @@ async def create_task_push_notification_config( payload, modified_kwargs = await self._apply_interceptors( payload, modified_kwargs, context ) + response_data = await self._send_post_request( f'/tasks/{request.task_id}/pushNotificationConfigs', + request.tenant, payload, modified_kwargs, ) @@ -235,8 +244,10 @@ async def get_task_push_notification_config( del params['id'] if 'task_id' in params: del params['task_id'] + response_data = await self._send_get_request( f'/tasks/{request.task_id}/pushNotificationConfigs/{request.id}', + request.tenant, params, modified_kwargs, ) @@ -265,8 +276,10 @@ async def list_task_push_notification_configs( ) if 'task_id' in params: del params['task_id'] + response_data = await self._send_get_request( f'/tasks/{request.task_id}/pushNotificationConfigs', + request.tenant, params, modified_kwargs, ) @@ -297,8 +310,10 @@ async def delete_task_push_notification_config( del params['id'] if 'task_id' in params: del params['task_id'] + await self._send_delete_request( f'/tasks/{request.task_id}/pushNotificationConfigs/{request.id}', + request.tenant, params, modified_kwargs, ) @@ -319,6 +334,7 @@ async def subscribe( async for event in self._send_stream_request( 'GET', f'/tasks/{request.id}:subscribe', + request.tenant, http_kwargs=modified_kwargs, ): yield event @@ -347,7 +363,7 @@ async def get_extended_agent_card( context, ) response_data = await self._send_get_request( - '/extendedAgentCard', {}, modified_kwargs + '/extendedAgentCard', request.tenant, {}, modified_kwargs ) response: AgentCard = ParseDict(response_data, AgentCard()) @@ -363,6 +379,10 @@ async def close(self) -> None: """Closes the httpx client.""" await self.httpx_client.aclose() + def _get_path(self, base_path: str, tenant: str) -> str: + """Returns the full path, prepending the tenant if provided.""" + return f'/{tenant}{base_path}' if tenant else base_path + async def _apply_interceptors( self, request_payload: dict[str, Any], @@ -425,16 +445,18 @@ async def _send_stream_request( self, method: str, target: str, + tenant: str, http_kwargs: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamResponse]: final_kwargs = dict(http_kwargs or {}) final_kwargs.update(kwargs) + path = self._get_path(target, tenant) async for sse_data in send_http_stream_request( self.httpx_client, method, - f'{self.url}{target}', + f'{self.url}{path}', self._handle_http_error, **final_kwargs, ): @@ -449,13 +471,15 @@ async def _send_request(self, request: httpx.Request) -> dict[str, Any]: async def _send_post_request( self, target: str, + tenant: str, rpc_request_payload: dict[str, Any], http_kwargs: dict[str, Any] | None = None, ) -> dict[str, Any]: + path = self._get_path(target, tenant) return await self._send_request( self.httpx_client.build_request( 'POST', - f'{self.url}{target}', + f'{self.url}{path}', json=rpc_request_payload, **(http_kwargs or {}), ) @@ -464,13 +488,15 @@ async def _send_post_request( async def _send_get_request( self, target: str, + tenant: str, query_params: dict[str, str], http_kwargs: dict[str, Any] | None = None, ) -> dict[str, Any]: + path = self._get_path(target, tenant) return await self._send_request( self.httpx_client.build_request( 'GET', - f'{self.url}{target}', + f'{self.url}{path}', params=query_params, **(http_kwargs or {}), ) @@ -479,13 +505,15 @@ async def _send_get_request( async def _send_delete_request( self, target: str, + tenant: str, query_params: dict[str, Any], http_kwargs: dict[str, Any] | None = None, ) -> dict[str, Any]: + path = self._get_path(target, tenant) return await self._send_request( self.httpx_client.build_request( 'DELETE', - f'{self.url}{target}', + f'{self.url}{path}', params=query_params, **(http_kwargs or {}), ) diff --git a/src/a2a/client/transports/tenant_decorator.py b/src/a2a/client/transports/tenant_decorator.py new file mode 100644 index 00000000..0335bd09 --- /dev/null +++ b/src/a2a/client/transports/tenant_decorator.py @@ -0,0 +1,192 @@ +from collections.abc import AsyncGenerator, Callable + +from a2a.client.middleware import ClientCallContext +from a2a.client.transports.base import ClientTransport +from a2a.types.a2a_pb2 import ( + AgentCard, + CancelTaskRequest, + CreateTaskPushNotificationConfigRequest, + DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTaskPushNotificationConfigsResponse, + ListTasksRequest, + ListTasksResponse, + SendMessageRequest, + SendMessageResponse, + StreamResponse, + SubscribeToTaskRequest, + Task, + TaskPushNotificationConfig, +) + + +class TenantTransportDecorator(ClientTransport): + """A transport decorator that attaches a tenant to all requests.""" + + def __init__(self, base: ClientTransport, tenant: str): + self._base = base + self._tenant = tenant + + def _resolve_tenant(self, tenant: str) -> str: + """If tenant is not provided, use the default tenant. + + Returns: + The tenant used for the request. + """ + return tenant or self._tenant + + async def send_message( + self, + request: SendMessageRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> SendMessageResponse: + """Sends a streaming message request to the agent and yields responses as they arrive.""" + request.tenant = self._resolve_tenant(request.tenant) + return await self._base.send_message( + request, context=context, extensions=extensions + ) + + async def send_message_streaming( + self, + request: SendMessageRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> AsyncGenerator[StreamResponse]: + """Sends a streaming message request to the agent and yields responses.""" + request.tenant = self._resolve_tenant(request.tenant) + async for event in self._base.send_message_streaming( + request, context=context, extensions=extensions + ): + yield event + + async def get_task( + self, + request: GetTaskRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task.""" + request.tenant = self._resolve_tenant(request.tenant) + return await self._base.get_task( + request, context=context, extensions=extensions + ) + + async def list_tasks( + self, + request: ListTasksRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> ListTasksResponse: + """Retrieves tasks for an agent.""" + request.tenant = self._resolve_tenant(request.tenant) + return await self._base.list_tasks( + request, context=context, extensions=extensions + ) + + async def cancel_task( + self, + request: CancelTaskRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> Task: + """Requests the agent to cancel a specific task.""" + request.tenant = self._resolve_tenant(request.tenant) + return await self._base.cancel_task( + request, context=context, extensions=extensions + ) + + async def create_task_push_notification_config( + self, + request: CreateTaskPushNotificationConfigRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task.""" + request.tenant = self._resolve_tenant(request.tenant) + return await self._base.create_task_push_notification_config( + request, context=context, extensions=extensions + ) + + async def get_task_push_notification_config( + self, + request: GetTaskPushNotificationConfigRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task.""" + request.tenant = self._resolve_tenant(request.tenant) + return await self._base.get_task_push_notification_config( + request, context=context, extensions=extensions + ) + + async def list_task_push_notification_configs( + self, + request: ListTaskPushNotificationConfigsRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> ListTaskPushNotificationConfigsResponse: + """Lists push notification configurations for a specific task.""" + request.tenant = self._resolve_tenant(request.tenant) + return await self._base.list_task_push_notification_configs( + request, context=context, extensions=extensions + ) + + async def delete_task_push_notification_config( + self, + request: DeleteTaskPushNotificationConfigRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> None: + """Deletes the push notification configuration for a specific task.""" + request.tenant = self._resolve_tenant(request.tenant) + await self._base.delete_task_push_notification_config( + request, context=context, extensions=extensions + ) + + async def subscribe( + self, + request: SubscribeToTaskRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + ) -> AsyncGenerator[StreamResponse]: + """Reconnects to get task updates.""" + request.tenant = self._resolve_tenant(request.tenant) + async for event in self._base.subscribe( + request, context=context, extensions=extensions + ): + yield event + + async def get_extended_agent_card( + self, + request: GetExtendedAgentCardRequest, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, + ) -> AgentCard: + """Retrieves the Extended AgentCard.""" + request.tenant = self._resolve_tenant(request.tenant) + return await self._base.get_extended_agent_card( + request, + context=context, + extensions=extensions, + signature_verifier=signature_verifier, + ) + + async def close(self) -> None: + """Closes the transport.""" + await self._base.close() diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index ce47b7ac..384b18fb 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -7,15 +7,27 @@ from a2a.client.transports.base import ClientTransport from a2a.types.a2a_pb2 import ( AgentCapabilities, - AgentInterface, AgentCard, + AgentInterface, + CancelTaskRequest, + CreateTaskPushNotificationConfigRequest, + DeleteTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTaskPushNotificationConfigsResponse, + ListTasksRequest, + ListTasksResponse, Message, Part, Role, SendMessageConfiguration, + SendMessageRequest, SendMessageResponse, StreamResponse, + SubscribeToTaskRequest, Task, + TaskPushNotificationConfig, TaskState, TaskStatus, ) @@ -65,214 +77,227 @@ def base_client( ) -@pytest.mark.asyncio -async def test_transport_async_context_manager() -> None: - with ( - patch.object(ClientTransport, '__abstractmethods__', set()), - patch.object(ClientTransport, 'close', new_callable=AsyncMock), - ): - transport = ClientTransport() - async with transport as t: - assert t is transport - transport.close.assert_not_awaited() - transport.close.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_transport_async_context_manager_on_exception() -> None: - with ( - patch.object(ClientTransport, '__abstractmethods__', set()), - patch.object(ClientTransport, 'close', new_callable=AsyncMock), - ): - transport = ClientTransport() +class TestClientTransport: + @pytest.mark.asyncio + async def test_transport_async_context_manager(self) -> None: + with ( + patch.object(ClientTransport, '__abstractmethods__', set()), + patch.object(ClientTransport, 'close', new_callable=AsyncMock), + ): + transport = ClientTransport() + async with transport as t: + assert t is transport + transport.close.assert_not_awaited() + transport.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_transport_async_context_manager_on_exception(self) -> None: + with ( + patch.object(ClientTransport, '__abstractmethods__', set()), + patch.object(ClientTransport, 'close', new_callable=AsyncMock), + ): + transport = ClientTransport() + with pytest.raises(RuntimeError, match='boom'): + async with transport: + raise RuntimeError('boom') + transport.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_base_client_async_context_manager( + self, base_client: BaseClient, mock_transport: AsyncMock + ) -> None: + async with base_client as client: + assert client is base_client + mock_transport.close.assert_not_awaited() + mock_transport.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_base_client_async_context_manager_on_exception( + self, base_client: BaseClient, mock_transport: AsyncMock + ) -> None: with pytest.raises(RuntimeError, match='boom'): - async with transport: + async with base_client: raise RuntimeError('boom') - transport.close.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_base_client_async_context_manager( - base_client: BaseClient, mock_transport: AsyncMock -) -> None: - async with base_client as client: - assert client is base_client - mock_transport.close.assert_not_awaited() - mock_transport.close.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_base_client_async_context_manager_on_exception( - base_client: BaseClient, mock_transport: AsyncMock -) -> None: - with pytest.raises(RuntimeError, match='boom'): - async with base_client: - raise RuntimeError('boom') - mock_transport.close.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_send_message_streaming( - base_client: BaseClient, mock_transport: MagicMock, sample_message: Message -) -> None: - async def create_stream(*args, **kwargs): + mock_transport.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_send_message_streaming( + self, + base_client: BaseClient, + mock_transport: MagicMock, + sample_message: Message, + ) -> None: + async def create_stream(*args, **kwargs): + task = Task( + id='task-123', + context_id='ctx-456', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + stream_response = StreamResponse() + stream_response.task.CopyFrom(task) + yield stream_response + + mock_transport.send_message_streaming.return_value = create_stream() + + meta = {'test': 1} + stream = base_client.send_message(sample_message, request_metadata=meta) + events = [event async for event in stream] + + mock_transport.send_message_streaming.assert_called_once() + assert ( + mock_transport.send_message_streaming.call_args[0][0].metadata + == meta + ) + assert not mock_transport.send_message.called + assert len(events) == 1 + # events[0] is (StreamResponse, Task) tuple + stream_response, tracked_task = events[0] + assert stream_response.task.id == 'task-123' + assert tracked_task is not None + assert tracked_task.id == 'task-123' + + @pytest.mark.asyncio + async def test_send_message_non_streaming( + self, + base_client: BaseClient, + mock_transport: MagicMock, + sample_message: Message, + ) -> None: + base_client._config.streaming = False task = Task( - id='task-123', - context_id='ctx-456', + id='task-456', + context_id='ctx-789', status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), ) - stream_response = StreamResponse() - stream_response.task.CopyFrom(task) - yield stream_response - - mock_transport.send_message_streaming.return_value = create_stream() - - meta = {'test': 1} - stream = base_client.send_message(sample_message, request_metadata=meta) - events = [event async for event in stream] - - mock_transport.send_message_streaming.assert_called_once() - assert ( - mock_transport.send_message_streaming.call_args[0][0].metadata == meta - ) - assert not mock_transport.send_message.called - assert len(events) == 1 - # events[0] is (StreamResponse, Task) tuple - stream_response, tracked_task = events[0] - assert stream_response.task.id == 'task-123' - assert tracked_task is not None - assert tracked_task.id == 'task-123' - - -@pytest.mark.asyncio -async def test_send_message_non_streaming( - base_client: BaseClient, mock_transport: MagicMock, sample_message: Message -) -> None: - base_client._config.streaming = False - task = Task( - id='task-456', - context_id='ctx-789', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - response = SendMessageResponse() - response.task.CopyFrom(task) - mock_transport.send_message.return_value = response - - meta = {'test': 1} - stream = base_client.send_message(sample_message, request_metadata=meta) - events = [event async for event in stream] - - mock_transport.send_message.assert_called_once() - assert mock_transport.send_message.call_args[0][0].metadata == meta - assert not mock_transport.send_message_streaming.called - assert len(events) == 1 - stream_response, tracked_task = events[0] - assert stream_response.task.id == 'task-456' - assert tracked_task is not None - assert tracked_task.id == 'task-456' - - -@pytest.mark.asyncio -async def test_send_message_non_streaming_agent_capability_false( - base_client: BaseClient, mock_transport: MagicMock, sample_message: Message -) -> None: - base_client._card.capabilities.streaming = False - task = Task( - id='task-789', - context_id='ctx-101', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - response = SendMessageResponse() - response.task.CopyFrom(task) - mock_transport.send_message.return_value = response - - events = [event async for event in base_client.send_message(sample_message)] - - mock_transport.send_message.assert_called_once() - assert not mock_transport.send_message_streaming.called - assert len(events) == 1 - stream_response, tracked_task = events[0] - assert stream_response is not None - assert tracked_task is not None - assert tracked_task.id == 'task-789' - - -@pytest.mark.asyncio -async def test_send_message_callsite_config_overrides_non_streaming( - base_client: BaseClient, mock_transport: MagicMock, sample_message: Message -): - base_client._config.streaming = False - task = Task( - id='task-cfg-ns-1', - context_id='ctx-cfg-ns-1', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - response = SendMessageResponse() - response.task.CopyFrom(task) - mock_transport.send_message.return_value = response - - cfg = SendMessageConfiguration( - history_length=2, - blocking=False, - accepted_output_modes=['application/json'], - ) - events = [ - event - async for event in base_client.send_message( - sample_message, configuration=cfg + response = SendMessageResponse() + response.task.CopyFrom(task) + mock_transport.send_message.return_value = response + + meta = {'test': 1} + stream = base_client.send_message(sample_message, request_metadata=meta) + events = [event async for event in stream] + + mock_transport.send_message.assert_called_once() + assert mock_transport.send_message.call_args[0][0].metadata == meta + assert not mock_transport.send_message_streaming.called + assert len(events) == 1 + stream_response, tracked_task = events[0] + assert stream_response.task.id == 'task-456' + assert tracked_task is not None + assert tracked_task.id == 'task-456' + + @pytest.mark.asyncio + async def test_send_message_non_streaming_agent_capability_false( + self, + base_client: BaseClient, + mock_transport: MagicMock, + sample_message: Message, + ) -> None: + base_client._card.capabilities.streaming = False + task = Task( + id='task-789', + context_id='ctx-101', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), ) - ] - - mock_transport.send_message.assert_called_once() - assert not mock_transport.send_message_streaming.called - assert len(events) == 1 - stream_response, _ = events[0] - assert stream_response.task.id == 'task-cfg-ns-1' - - params = mock_transport.send_message.call_args[0][0] - assert params.configuration.history_length == 2 - assert params.configuration.blocking is False - assert params.configuration.accepted_output_modes == ['application/json'] - - -@pytest.mark.asyncio -async def test_send_message_callsite_config_overrides_streaming( - base_client: BaseClient, mock_transport: MagicMock, sample_message: Message -): - base_client._config.streaming = True - base_client._card.capabilities.streaming = True - - async def create_stream(*args, **kwargs): + response = SendMessageResponse() + response.task.CopyFrom(task) + mock_transport.send_message.return_value = response + + events = [ + event async for event in base_client.send_message(sample_message) + ] + + mock_transport.send_message.assert_called_once() + assert not mock_transport.send_message_streaming.called + assert len(events) == 1 + stream_response, tracked_task = events[0] + assert stream_response is not None + assert tracked_task is not None + assert tracked_task.id == 'task-789' + + @pytest.mark.asyncio + async def test_send_message_callsite_config_overrides_non_streaming( + self, + base_client: BaseClient, + mock_transport: MagicMock, + sample_message: Message, + ): + base_client._config.streaming = False task = Task( - id='task-cfg-s-1', - context_id='ctx-cfg-s-1', + id='task-cfg-ns-1', + context_id='ctx-cfg-ns-1', status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), ) - stream_response = StreamResponse() - stream_response.task.CopyFrom(task) - yield stream_response - - mock_transport.send_message_streaming.return_value = create_stream() - - cfg = SendMessageConfiguration( - history_length=0, - blocking=True, - accepted_output_modes=['text/plain'], - ) - events = [ - event - async for event in base_client.send_message( - sample_message, configuration=cfg + response = SendMessageResponse() + response.task.CopyFrom(task) + mock_transport.send_message.return_value = response + + cfg = SendMessageConfiguration( + history_length=2, + blocking=False, + accepted_output_modes=['application/json'], + ) + events = [ + event + async for event in base_client.send_message( + sample_message, configuration=cfg + ) + ] + + mock_transport.send_message.assert_called_once() + assert not mock_transport.send_message_streaming.called + assert len(events) == 1 + stream_response, _ = events[0] + assert stream_response.task.id == 'task-cfg-ns-1' + + params = mock_transport.send_message.call_args[0][0] + assert params.configuration.history_length == 2 + assert params.configuration.blocking is False + assert params.configuration.accepted_output_modes == [ + 'application/json' + ] + + @pytest.mark.asyncio + async def test_send_message_callsite_config_overrides_streaming( + self, + base_client: BaseClient, + mock_transport: MagicMock, + sample_message: Message, + ): + base_client._config.streaming = True + base_client._card.capabilities.streaming = True + + async def create_stream(*args, **kwargs): + task = Task( + id='task-cfg-s-1', + context_id='ctx-cfg-s-1', + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + stream_response = StreamResponse() + stream_response.task.CopyFrom(task) + yield stream_response + + mock_transport.send_message_streaming.return_value = create_stream() + + cfg = SendMessageConfiguration( + history_length=0, + blocking=True, + accepted_output_modes=['text/plain'], ) - ] - - mock_transport.send_message_streaming.assert_called_once() - assert not mock_transport.send_message.called - assert len(events) == 1 - stream_response, _ = events[0] - assert stream_response.task.id == 'task-cfg-s-1' - - params = mock_transport.send_message_streaming.call_args[0][0] - assert params.configuration.history_length == 0 - assert params.configuration.blocking is True - assert params.configuration.accepted_output_modes == ['text/plain'] + events = [ + event + async for event in base_client.send_message( + sample_message, configuration=cfg + ) + ] + + mock_transport.send_message_streaming.assert_called_once() + assert not mock_transport.send_message.called + assert len(events) == 1 + stream_response, _ = events[0] + assert stream_response.task.id == 'task-cfg-s-1' + + params = mock_transport.send_message_streaming.call_args[0][0] + assert params.configuration.history_length == 0 + assert params.configuration.blocking is True + assert params.configuration.accepted_output_modes == ['text/plain'] diff --git a/tests/client/test_client_factory.py b/tests/client/test_client_factory.py index a29fa38f..dbfa7cf7 100644 --- a/tests/client/test_client_factory.py +++ b/tests/client/test_client_factory.py @@ -1,5 +1,6 @@ """Tests for the ClientFactory.""" +from collections.abc import AsyncGenerator from unittest.mock import AsyncMock, MagicMock, patch import typing @@ -8,7 +9,12 @@ from a2a.client import ClientConfig, ClientFactory from a2a.client.client_factory import TransportProducer -from a2a.client.transports import JsonRpcTransport, RestTransport +from a2a.client.transports import ( + JsonRpcTransport, + RestTransport, + ClientTransport, +) +from a2a.client.transports.tenant_decorator import TenantTransportDecorator from a2a.types.a2a_pb2 import ( AgentCapabilities, AgentCard, @@ -284,3 +290,18 @@ async def test_client_factory_connect_with_consumers_and_interceptors( call_args = mock_base_client.call_args[0] assert call_args[3] == [consumer1] assert call_args[4] == [interceptor1] + + +def test_client_factory_applies_tenant_decorator(base_agent_card: AgentCard): + """Verify that the factory applies TenantTransportDecorator when tenant is present.""" + base_agent_card.supported_interfaces[0].tenant = 'my-tenant' + config = ClientConfig( + httpx_client=httpx.AsyncClient(), + supported_protocol_bindings=[TransportProtocol.JSONRPC], + ) + factory = ClientFactory(config) + client = factory.create(base_agent_card) + + assert isinstance(client._transport, TenantTransportDecorator) # type: ignore[attr-defined] + assert client._transport._tenant == 'my-tenant' # type: ignore[attr-defined] + assert isinstance(client._transport._base, JsonRpcTransport) # type: ignore[attr-defined] diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index 8d395457..fd6899e6 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -15,10 +15,17 @@ AgentCapabilities, AgentCard, AgentInterface, + CancelTaskRequest, + CreateTaskPushNotificationConfigRequest, DeleteTaskPushNotificationConfigRequest, GetExtendedAgentCardRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, ListTaskPushNotificationConfigsRequest, + ListTasksRequest, + Message, SendMessageRequest, + SubscribeToTaskRequest, ) from a2a.utils.constants import TransportProtocol from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP @@ -310,7 +317,7 @@ async def test_get_card_with_extended_card_support_with_extensions( await client.get_extended_agent_card(request, extensions=extensions) mock_send_get_request.assert_called_once() - _, _, mock_kwargs = mock_send_get_request.call_args[0] + _, _, _, mock_kwargs = mock_send_get_request.call_args[0] _assert_extensions_header( mock_kwargs, @@ -404,3 +411,226 @@ async def test_delete_task_push_notification_config_success( f'/tasks/{task_id}/pushNotificationConfigs/config-1' in call_args[0][1] ) + + +class TestRestTransportTenant: + """Tests for tenant path prepending in RestTransport.""" + + @pytest.mark.parametrize( + 'method_name, request_obj, expected_path', + [ + ( + 'send_message', + SendMessageRequest( + tenant='my-tenant', + message=create_text_message_object(content='hi'), + ), + '/my-tenant/message:send', + ), + ( + 'list_tasks', + ListTasksRequest(tenant='my-tenant'), + '/my-tenant/tasks', + ), + ( + 'get_task', + GetTaskRequest(tenant='my-tenant', id='task-123'), + '/my-tenant/tasks/task-123', + ), + ( + 'cancel_task', + CancelTaskRequest(tenant='my-tenant', id='task-123'), + '/my-tenant/tasks/task-123:cancel', + ), + ( + 'create_task_push_notification_config', + CreateTaskPushNotificationConfigRequest( + tenant='my-tenant', task_id='task-123' + ), + '/my-tenant/tasks/task-123/pushNotificationConfigs', + ), + ( + 'get_task_push_notification_config', + GetTaskPushNotificationConfigRequest( + tenant='my-tenant', task_id='task-123', id='cfg-1' + ), + '/my-tenant/tasks/task-123/pushNotificationConfigs/cfg-1', + ), + ( + 'list_task_push_notification_configs', + ListTaskPushNotificationConfigsRequest( + tenant='my-tenant', task_id='task-123' + ), + '/my-tenant/tasks/task-123/pushNotificationConfigs', + ), + ( + 'delete_task_push_notification_config', + DeleteTaskPushNotificationConfigRequest( + tenant='my-tenant', task_id='task-123', id='cfg-1' + ), + '/my-tenant/tasks/task-123/pushNotificationConfigs/cfg-1', + ), + ], + ) + @pytest.mark.asyncio + async def test_rest_methods_prepend_tenant( + self, + method_name, + request_obj, + expected_path, + mock_httpx_client, + mock_agent_card, + ): + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + url='http://agent.example.com/api', + ) + + # 1. Get the method dynamically + method = getattr(client, method_name) + + # 2. Setup mocks + mock_httpx_client.build_request.return_value = MagicMock( + spec=httpx.Request + ) + mock_httpx_client.send.return_value = AsyncMock( + spec=httpx.Response, + status_code=200, + json=MagicMock(return_value={}), + ) + + # 3. Call the method + await method(request=request_obj) + + # 4. Verify the URL + args, _ = mock_httpx_client.build_request.call_args + assert args[1] == f'http://agent.example.com/api{expected_path}' + + @pytest.mark.asyncio + async def test_rest_get_extended_agent_card_prepend_tenant( + self, + mock_httpx_client, + mock_agent_card, + ): + mock_agent_card.capabilities.extended_agent_card = True + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + url='http://agent.example.com/api', + ) + + request = GetExtendedAgentCardRequest(tenant='my-tenant') + + # 1. Setup mocks + mock_httpx_client.build_request.return_value = MagicMock( + spec=httpx.Request + ) + mock_httpx_client.send.return_value = AsyncMock( + spec=httpx.Response, + status_code=200, + json=MagicMock(return_value={}), + ) + + # 2. Call the method + await client.get_extended_agent_card(request=request) + + # 3. Verify the URL + args, _ = mock_httpx_client.build_request.call_args + assert ( + args[1] + == 'http://agent.example.com/api/my-tenant/extendedAgentCard' + ) + + @pytest.mark.asyncio + async def test_rest_get_task_prepend_empty_tenant( + self, + mock_httpx_client, + mock_agent_card, + ): + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + url='http://agent.example.com/api', + ) + + request = GetTaskRequest(tenant='', id='task-123') + + # 1. Setup mocks + mock_httpx_client.build_request.return_value = MagicMock( + spec=httpx.Request + ) + mock_httpx_client.send.return_value = AsyncMock( + spec=httpx.Response, + status_code=200, + json=MagicMock(return_value={}), + ) + + # 2. Call the method + await client.get_task(request=request) + + # 3. Verify the URL + args, _ = mock_httpx_client.build_request.call_args + assert args[1] == f'http://agent.example.com/api/tasks/task-123' + + @pytest.mark.parametrize( + 'method_name, request_obj, expected_path', + [ + ( + 'subscribe', + SubscribeToTaskRequest(tenant='my-tenant', id='task-123'), + '/my-tenant/tasks/task-123:subscribe', + ), + ( + 'send_message_streaming', + SendMessageRequest( + tenant='my-tenant', + message=create_text_message_object(content='hi'), + ), + '/my-tenant/message:stream', + ), + ], + ) + @pytest.mark.asyncio + @patch('a2a.client.transports.http_helpers.aconnect_sse') + async def test_rest_streaming_methods_prepend_tenant( + self, + mock_aconnect_sse, + method_name, + request_obj, + expected_path, + mock_httpx_client, + mock_agent_card, + ): + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + url='http://agent.example.com/api', + ) + + # 1. Get the method dynamically + method = getattr(client, method_name) + + # 2. Setup mocks + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.response = MagicMock(spec=httpx.Response) + mock_event_source.response.raise_for_status.return_value = None + + async def empty_aiter(): + if False: + yield + + mock_event_source.aiter_sse.return_value = empty_aiter() + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + # 3. Call the method + async for _ in method(request=request_obj): + pass + + # 4. Verify the URL + mock_aconnect_sse.assert_called_once() + args, _ = mock_aconnect_sse.call_args + # url is 3rd positional argument in aconnect_sse(client, method, url, ...) + assert args[2] == f'http://agent.example.com/api{expected_path}' diff --git a/tests/client/transports/test_tenant_decorator.py b/tests/client/transports/test_tenant_decorator.py new file mode 100644 index 00000000..f544d676 --- /dev/null +++ b/tests/client/transports/test_tenant_decorator.py @@ -0,0 +1,129 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock + +from a2a.client.transports.base import ClientTransport +from a2a.client.transports.tenant_decorator import TenantTransportDecorator +from a2a.types.a2a_pb2 import ( + AgentCard, + CancelTaskRequest, + CreateTaskPushNotificationConfigRequest, + DeleteTaskPushNotificationConfigRequest, + GetExtendedAgentCardRequest, + GetTaskPushNotificationConfigRequest, + GetTaskRequest, + ListTaskPushNotificationConfigsRequest, + ListTasksRequest, + Message, + Part, + SendMessageRequest, + StreamResponse, + SubscribeToTaskRequest, +) + + +@pytest.fixture +def mock_transport() -> AsyncMock: + return AsyncMock(spec=ClientTransport) + + +class TestTenantTransportDecorator: + @pytest.mark.asyncio + async def test_resolve_tenant_logic( + self, mock_transport: AsyncMock + ) -> None: + tenant_id = 'test-tenant' + decorator = TenantTransportDecorator(mock_transport, tenant_id) + + # Case 1: Tenant already set on request + assert decorator._resolve_tenant('existing-tenant') == 'existing-tenant' + + # Case 2: Tenant not set (empty string) + assert decorator._resolve_tenant('') == tenant_id + + @pytest.mark.asyncio + async def test_resolve_tenant_logic_empty_tenant( + self, mock_transport: AsyncMock + ) -> None: + decorator = TenantTransportDecorator(mock_transport, '') + + # Case 1: Tenant already set on request + assert decorator._resolve_tenant('existing-tenant') == 'existing-tenant' + + # Case 2: Tenant not set (empty string) + assert decorator._resolve_tenant('') == '' + + @pytest.mark.parametrize( + 'method_name, request_obj', + [ + ( + 'send_message', + SendMessageRequest(message=Message(parts=[Part(text='hello')])), + ), + ( + 'get_task', + GetTaskRequest(id='t1'), + ), + ( + 'list_tasks', + ListTasksRequest(), + ), + ( + 'cancel_task', + CancelTaskRequest(id='t1'), + ), + ( + 'create_task_push_notification_config', + CreateTaskPushNotificationConfigRequest(task_id='t1'), + ), + ( + 'get_task_push_notification_config', + GetTaskPushNotificationConfigRequest(task_id='t1', id='c1'), + ), + ( + 'list_task_push_notification_configs', + ListTaskPushNotificationConfigsRequest(task_id='t1'), + ), + ( + 'delete_task_push_notification_config', + DeleteTaskPushNotificationConfigRequest(task_id='t1', id='c1'), + ), + ('get_extended_agent_card', GetExtendedAgentCardRequest()), + ], + ) + @pytest.mark.asyncio + async def test_methods( + self, mock_transport: AsyncMock, method_name, request_obj + ) -> None: + """Test that tenant is set on the request for all methods.""" + tenant_id = 'test-tenant' + decorator = TenantTransportDecorator(mock_transport, tenant_id) + mock_method = getattr(mock_transport, method_name) + + await getattr(decorator, method_name)(request_obj) + + mock_method.assert_called_once() + assert mock_transport.mock_calls[0][0] == method_name + assert request_obj.tenant == tenant_id + + @pytest.mark.asyncio + async def test_streaming_methods(self, mock_transport: AsyncMock) -> None: + """Test that tenant is set on the request for streaming methods.""" + tenant_id = 'test-tenant' + decorator = TenantTransportDecorator(mock_transport, tenant_id) + + async def mock_stream(*args, **kwargs): + yield StreamResponse() + + # Test subscribe + mock_transport.subscribe.return_value = mock_stream() + request_sub = SubscribeToTaskRequest(id='t1') + async for _ in decorator.subscribe(request_sub): + pass + assert request_sub.tenant == tenant_id + + # Test send_message_streaming + mock_transport.send_message_streaming.return_value = mock_stream() + request_msg = SendMessageRequest() + async for _ in decorator.send_message_streaming(request_msg): + pass + assert request_msg.tenant == tenant_id diff --git a/tests/integration/test_tenant.py b/tests/integration/test_tenant.py new file mode 100644 index 00000000..aef0289d --- /dev/null +++ b/tests/integration/test_tenant.py @@ -0,0 +1,160 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +import httpx +from a2a.types.a2a_pb2 import ( + AgentCard, + AgentInterface, + SendMessageRequest, + Message, + GetTaskRequest, + AgentCapabilities, +) +from a2a.client.transports import RestTransport, JsonRpcTransport, GrpcTransport +from a2a.client.transports.tenant_decorator import TenantTransportDecorator +from a2a.client import ClientConfig, ClientFactory +from a2a.utils.constants import TransportProtocol + + +@pytest.fixture +def agent_card(): + return AgentCard( + supported_interfaces=[ + AgentInterface( + url='http://example.com/rest', + protocol_binding=TransportProtocol.HTTP_JSON, + tenant='tenant-1', + ), + AgentInterface( + url='http://example.com/jsonrpc', + protocol_binding=TransportProtocol.JSONRPC, + tenant='tenant-2', + ), + AgentInterface( + url='http://example.com/grpc', + protocol_binding=TransportProtocol.GRPC, + tenant='tenant-3', + ), + ], + capabilities=AgentCapabilities(streaming=True), + ) + + +@pytest.mark.asyncio +async def test_tenant_decorator_rest(agent_card): + mock_httpx = AsyncMock(spec=httpx.AsyncClient) + mock_httpx.build_request.return_value = MagicMock() + mock_httpx.send.return_value = MagicMock( + status_code=200, json=lambda: {'message': {}} + ) + + config = ClientConfig( + httpx_client=mock_httpx, + supported_protocol_bindings=[TransportProtocol.HTTP_JSON], + ) + factory = ClientFactory(config) + client = factory.create(agent_card) + + assert isinstance(client._transport, TenantTransportDecorator) + assert client._transport._tenant == 'tenant-1' + + # Test SendMessage (POST) - Use transport directly to avoid streaming complexity in mock + request = SendMessageRequest(message=Message(parts=[{'text': 'hi'}])) + await client._transport.send_message(request) + + # Check that tenant was populated in request + assert request.tenant == 'tenant-1' + + # Check that path was prepended in the underlying transport + mock_httpx.build_request.assert_called() + send_call = next( + c + for c in mock_httpx.build_request.call_args_list + if 'message:send' in c.args[1] + ) + args, kwargs = send_call + assert args[1] == 'http://example.com/rest/tenant-1/message:send' + assert 'tenant' in kwargs['json'] + + +@pytest.mark.asyncio +async def test_tenant_decorator_jsonrpc(agent_card): + mock_httpx = AsyncMock(spec=httpx.AsyncClient) + mock_httpx.build_request.return_value = MagicMock() + mock_httpx.send.return_value = MagicMock( + status_code=200, + json=lambda: {'result': {'message': {}}, 'id': '1', 'jsonrpc': '2.0'}, + ) + + config = ClientConfig( + httpx_client=mock_httpx, + supported_protocol_bindings=[TransportProtocol.JSONRPC], + ) + factory = ClientFactory(config) + client = factory.create(agent_card) + + assert isinstance(client._transport, TenantTransportDecorator) + assert client._transport._tenant == 'tenant-2' + + request = SendMessageRequest(message=Message(parts=[{'text': 'hi'}])) + await client._transport.send_message(request) + + mock_httpx.build_request.assert_called() + _, kwargs = mock_httpx.build_request.call_args + assert kwargs['json']['params']['tenant'] == 'tenant-2' + + +@pytest.mark.asyncio +async def test_tenant_decorator_grpc(agent_card): + mock_channel = MagicMock() + config = ClientConfig( + grpc_channel_factory=lambda url: mock_channel, + supported_protocol_bindings=[TransportProtocol.GRPC], + ) + + with patch('a2a.types.a2a_pb2_grpc.A2AServiceStub') as mock_stub_class: + mock_stub = mock_stub_class.return_value + mock_stub.SendMessage = AsyncMock(return_value={'message': {}}) + + factory = ClientFactory(config) + client = factory.create(agent_card) + + assert isinstance(client._transport, TenantTransportDecorator) + assert client._transport._tenant == 'tenant-3' + + await client._transport.send_message( + SendMessageRequest(message=Message(parts=[{'text': 'hi'}])) + ) + + call_args = mock_stub.SendMessage.call_args + assert call_args[0][0].tenant == 'tenant-3' + + +@pytest.mark.asyncio +async def test_tenant_decorator_explicit_override(agent_card): + mock_httpx = AsyncMock(spec=httpx.AsyncClient) + mock_httpx.build_request.return_value = MagicMock() + mock_httpx.send.return_value = MagicMock( + status_code=200, json=lambda: {'message': {}} + ) + + config = ClientConfig( + httpx_client=mock_httpx, + supported_protocol_bindings=[TransportProtocol.HTTP_JSON], + ) + factory = ClientFactory(config) + client = factory.create(agent_card) + + request = SendMessageRequest( + message=Message(parts=[{'text': 'hi'}]), tenant='explicit-tenant' + ) + await client._transport.send_message(request) + + assert request.tenant == 'explicit-tenant' + + send_call = next( + c + for c in mock_httpx.build_request.call_args_list + if 'message:send' in c.args[1] + ) + args, _ = send_call + assert args[1] == 'http://example.com/rest/explicit-tenant/message:send'