Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/a2a/client/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging

from collections.abc import Callable
from typing import Any
from typing import Any, cast

import httpx

Expand All @@ -20,6 +20,8 @@
AgentInterface,
)
from a2a.utils.constants import (
PROTOCOL_VERSION_CURRENT,
VERSION_HEADER,
TransportProtocol,
)

Expand Down Expand Up @@ -65,18 +67,24 @@ def __init__(
):
if consumers is None:
consumers = []

client = config.httpx_client or httpx.AsyncClient()
client.headers.setdefault(VERSION_HEADER, PROTOCOL_VERSION_CURRENT)
config.httpx_client = client

self._config = config
self._consumers = consumers
self._registry: dict[str, TransportProducer] = {}
self._register_defaults(config.supported_protocol_bindings)

def _register_defaults(self, supported: list[str]) -> None:
# Empty support list implies JSON-RPC only.

if TransportProtocol.JSONRPC in supported or not supported:
self.register(
TransportProtocol.JSONRPC,
lambda card, url, config, interceptors: JsonRpcTransport(
config.httpx_client or httpx.AsyncClient(),
cast('httpx.AsyncClient', config.httpx_client),
card,
url,
interceptors,
Expand All @@ -87,7 +95,7 @@ def _register_defaults(self, supported: list[str]) -> None:
self.register(
TransportProtocol.HTTP_JSON,
lambda card, url, config, interceptors: RestTransport(
config.httpx_client or httpx.AsyncClient(),
cast('httpx.AsyncClient', config.httpx_client),
card,
url,
interceptors,
Expand Down
12 changes: 8 additions & 4 deletions src/a2a/client/transports/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
Task,
TaskPushNotificationConfig,
)
from a2a.utils.constants import PROTOCOL_VERSION_CURRENT, VERSION_HEADER
from a2a.utils.telemetry import SpanKind, trace_class


Expand Down Expand Up @@ -303,11 +304,14 @@ async def close(self) -> None:
def _get_grpc_metadata(
self,
extensions: list[str] | None = None,
) -> list[tuple[str, str]] | None:
) -> list[tuple[str, str]]:
"""Creates gRPC metadata for extensions."""
metadata = [(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT)]

extensions_to_use = extensions or self.extensions
if extensions_to_use:
return [
metadata.append(
(HTTP_EXTENSION_HEADER.lower(), ','.join(extensions_to_use))
]
return None
)

return metadata
4 changes: 4 additions & 0 deletions src/a2a/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ class TransportProtocol(str, Enum):

DEFAULT_MAX_CONTENT_LENGTH = 10 * 1024 * 1024 # 10MB
JSONRPC_PARSE_ERROR_CODE = -32700
VERSION_HEADER = 'A2A-Version'

PROTOCOL_VERSION_1_0 = '1.0'
PROTOCOL_VERSION_CURRENT = PROTOCOL_VERSION_1_0
61 changes: 44 additions & 17 deletions tests/client/transports/test_grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from a2a.client.transports.grpc import GrpcTransport
from a2a.extensions.common import HTTP_EXTENSION_HEADER
from a2a.utils.constants import VERSION_HEADER, PROTOCOL_VERSION_CURRENT
from a2a.types import a2a_pb2
from a2a.types.a2a_pb2 import (
AgentCapabilities,
Expand Down Expand Up @@ -217,10 +218,11 @@ async def test_send_message_task_response(
mock_grpc_stub.SendMessage.assert_awaited_once()
_, kwargs = mock_grpc_stub.SendMessage.call_args
assert kwargs['metadata'] == [
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v3',
)
),
]
assert response.HasField('task')
assert response.task.id == sample_task.id
Expand Down Expand Up @@ -266,10 +268,11 @@ async def test_send_message_message_response(
mock_grpc_stub.SendMessage.assert_awaited_once()
_, kwargs = mock_grpc_stub.SendMessage.call_args
assert kwargs['metadata'] == [
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
),
]
assert response.HasField('message')
assert response.message.message_id == sample_message.message_id
Expand Down Expand Up @@ -315,10 +318,11 @@ async def test_send_message_streaming( # noqa: PLR0913
mock_grpc_stub.SendStreamingMessage.assert_called_once()
_, kwargs = mock_grpc_stub.SendStreamingMessage.call_args
assert kwargs['metadata'] == [
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
),
]
# Responses are StreamResponse proto objects
assert responses[0].HasField('message')
Expand Down Expand Up @@ -350,10 +354,11 @@ async def test_get_task(
mock_grpc_stub.GetTask.assert_awaited_once_with(
a2a_pb2.GetTaskRequest(id=f'{sample_task.id}', history_length=None),
metadata=[
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
),
],
)
assert response.id == sample_task.id
Expand All @@ -378,10 +383,11 @@ async def test_list_tasks(
mock_grpc_stub.ListTasks.assert_awaited_once_with(
params,
metadata=[
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
),
],
)
assert result.total_size == 2
Expand All @@ -405,10 +411,11 @@ async def test_get_task_with_history(
id=f'{sample_task.id}', history_length=history_len
),
metadata=[
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
),
],
)

Expand All @@ -433,7 +440,8 @@ async def test_cancel_task(
mock_grpc_stub.CancelTask.assert_awaited_once_with(
a2a_pb2.CancelTaskRequest(id=f'{sample_task.id}'),
metadata=[
(HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v3')
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v3'),
],
)
assert response.status.state == TaskState.TASK_STATE_CANCELED
Expand Down Expand Up @@ -462,10 +470,11 @@ async def test_create_task_push_notification_config_with_valid_task(
mock_grpc_stub.CreateTaskPushNotificationConfig.assert_awaited_once_with(
request,
metadata=[
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
),
],
)
assert response.task_id == sample_task_push_notification_config.task_id
Expand Down Expand Up @@ -524,10 +533,11 @@ async def test_get_task_push_notification_config_with_valid_task(
id=config_id,
),
metadata=[
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
),
],
)
assert response.task_id == sample_task_push_notification_config.task_id
Expand Down Expand Up @@ -577,10 +587,11 @@ async def test_list_task_push_notification_configs(
mock_grpc_stub.ListTaskPushNotificationConfigs.assert_awaited_once_with(
a2a_pb2.ListTaskPushNotificationConfigsRequest(task_id='task-1'),
metadata=[
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
),
],
)
assert len(response.configs) == 1
Expand Down Expand Up @@ -609,10 +620,11 @@ async def test_delete_task_push_notification_config(
id='config-1',
),
metadata=[
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
),
],
)

Expand All @@ -623,32 +635,47 @@ async def test_delete_task_push_notification_config(
(
None,
None,
None,
[(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT)],
), # Case 1: No initial, No input
(
['ext1'],
None,
[(HTTP_EXTENSION_HEADER.lower(), 'ext1')],
[
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(HTTP_EXTENSION_HEADER.lower(), 'ext1'),
],
), # Case 2: Initial, No input
(
None,
['ext2'],
[(HTTP_EXTENSION_HEADER.lower(), 'ext2')],
[
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(HTTP_EXTENSION_HEADER.lower(), 'ext2'),
],
), # Case 3: No initial, Input
(
['ext1'],
['ext2'],
[(HTTP_EXTENSION_HEADER.lower(), 'ext2')],
[
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(HTTP_EXTENSION_HEADER.lower(), 'ext2'),
],
), # Case 4: Initial, Input (override)
(
['ext1'],
['ext2', 'ext3'],
[(HTTP_EXTENSION_HEADER.lower(), 'ext2,ext3')],
[
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(HTTP_EXTENSION_HEADER.lower(), 'ext2,ext3'),
],
), # Case 5: Initial, Multiple inputs (override)
(
['ext1', 'ext2'],
['ext3'],
[(HTTP_EXTENSION_HEADER.lower(), 'ext3')],
[
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
(HTTP_EXTENSION_HEADER.lower(), 'ext3'),
],
), # Case 6: Multiple initial, Single input (override)
],
)
Expand Down
11 changes: 11 additions & 0 deletions tests/utils/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,14 @@ def test_agent_card_constants():
def test_default_rpc_url():
"""Test default RPC URL constant."""
assert constants.DEFAULT_RPC_URL == '/'


def test_version_header():
"""Test version header constant."""
assert constants.VERSION_HEADER == 'A2A-Version'


def test_protocol_versions():
"""Test protocol version constants."""
assert constants.PROTOCOL_VERSION_1_0 == '1.0'
assert constants.PROTOCOL_VERSION_CURRENT == '1.0'
Loading