Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CancelTaskRequest,
CreateTaskPushNotificationConfigRequest,
DeleteTaskPushNotificationConfigRequest,
GetExtendedAgentCardRequest,
GetTaskPushNotificationConfigRequest,
GetTaskRequest,
ListTaskPushNotificationConfigsRequest,
Expand Down Expand Up @@ -309,27 +310,30 @@
async for client_event in self._process_stream(stream):
yield client_event

async def get_extended_agent_card(
self,
request: GetExtendedAgentCardRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card.

This will fetch the authenticated card if necessary and update the
client's internal state with the new card.

Args:
request: The `GetExtendedAgentCardRequest` object specifying the request.
context: The client call context.
extensions: List of extensions to be activated.
signature_verifier: A callable used to verify the agent card's signatures.

Returns:
The `AgentCard` for the agent.
"""

Check notice on line 334 in src/a2a/client/base_client.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (326-334)
card = await self._transport.get_extended_agent_card(
request,
context=context,
extensions=extensions,
signature_verifier=signature_verifier,
Expand Down
2 changes: 2 additions & 0 deletions src/a2a/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
CancelTaskRequest,
CreateTaskPushNotificationConfigRequest,
DeleteTaskPushNotificationConfigRequest,
GetExtendedAgentCardRequest,
GetTaskPushNotificationConfigRequest,
GetTaskRequest,
ListTaskPushNotificationConfigsRequest,
Expand Down Expand Up @@ -225,20 +226,21 @@
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent]:
"""Resubscribes to a task's event stream."""
return
yield

@abstractmethod
async def get_extended_agent_card(
self,
request: GetExtendedAgentCardRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""

async def add_event_consumer(self, consumer: Consumer) -> None:

Check notice on line 243 in src/a2a/client/client.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/base.py (146-388)
"""Attaches additional consumers to the `Client`."""
self._consumers.append(consumer)

Expand Down
2 changes: 2 additions & 0 deletions src/a2a/client/transports/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
CancelTaskRequest,
CreateTaskPushNotificationConfigRequest,
DeleteTaskPushNotificationConfigRequest,
GetExtendedAgentCardRequest,
GetTaskPushNotificationConfigRequest,
GetTaskRequest,
ListTaskPushNotificationConfigsRequest,
Expand Down Expand Up @@ -142,19 +143,20 @@
extensions: list[str] | None = None,
) -> AsyncGenerator[StreamResponse]:
"""Reconnects to get task updates."""
return
yield

@abstractmethod
async def get_extended_agent_card(
self,
request: GetExtendedAgentCardRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the Extended AgentCard."""

@abstractmethod

Check notice on line 160 in src/a2a/client/transports/base.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (326-335)
async def close(self) -> None:
"""Closes the transport."""
6 changes: 4 additions & 2 deletions src/a2a/client/transports/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
from a2a.client.optionals import Channel
from a2a.client.transports.base import ClientTransport
from a2a.extensions.common import HTTP_EXTENSION_HEADER
from a2a.types import a2a_pb2, a2a_pb2_grpc
from a2a.types import a2a_pb2_grpc
from a2a.types.a2a_pb2 import (
AgentCard,
CancelTaskRequest,
CreateTaskPushNotificationConfigRequest,
DeleteTaskPushNotificationConfigRequest,
GetExtendedAgentCardRequest,
GetTaskPushNotificationConfigRequest,
GetTaskRequest,
ListTaskPushNotificationConfigsRequest,
Expand Down Expand Up @@ -274,16 +275,17 @@
)

@_handle_grpc_exception
async def get_extended_agent_card(
self,
request: GetExtendedAgentCardRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
card = await self.stub.GetExtendedAgentCard(

Check notice on line 287 in src/a2a/client/transports/grpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (326-388)
a2a_pb2.GetExtendedAgentCardRequest(),
request,
metadata=self._get_grpc_metadata(extensions),
)

Expand Down
2 changes: 1 addition & 1 deletion src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,19 +371,20 @@
context,
)
async for event in self._send_stream_request(
payload,
http_kwargs=modified_kwargs,
):
yield event

async def get_extended_agent_card(
self,
request: GetExtendedAgentCardRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""

Check notice on line 387 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (321-334)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
Expand All @@ -394,7 +395,6 @@
if not card.capabilities.extended_agent_card:
return card

request = GetExtendedAgentCardRequest()
rpc_request = JSONRPC20Request(
method='GetExtendedAgentCard',
params=json_format.MessageToDict(request),
Expand Down
4 changes: 3 additions & 1 deletion src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,34 @@
from google.protobuf.json_format import MessageToDict, Parse, ParseDict
from google.protobuf.message import Message

from a2a.client.errors import A2AClientError
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.client.transports.base import ClientTransport
from a2a.client.transports.http_helpers import (
send_http_request,
send_http_stream_request,
)
from a2a.extensions.common import update_extension_header
from a2a.types.a2a_pb2 import (
AgentCard,
CancelTaskRequest,
CreateTaskPushNotificationConfigRequest,
DeleteTaskPushNotificationConfigRequest,
GetExtendedAgentCardRequest,
GetTaskPushNotificationConfigRequest,
GetTaskRequest,
ListTaskPushNotificationConfigsRequest,
ListTaskPushNotificationConfigsResponse,
ListTasksRequest,
ListTasksResponse,
SendMessageRequest,
SendMessageResponse,
StreamResponse,
SubscribeToTaskRequest,
Task,
TaskPushNotificationConfig,
)
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP, MethodNotFoundError

Check notice on line 39 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (12-40)
from a2a.utils.telemetry import SpanKind, trace_class


Expand Down Expand Up @@ -317,73 +318,74 @@

async for event in self._send_stream_request(
'GET',
f'/tasks/{request.id}:subscribe',
http_kwargs=modified_kwargs,
):
yield event

async def get_extended_agent_card(
self,
request: GetExtendedAgentCardRequest,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the Extended AgentCard."""

Check notice on line 334 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (374-387)
modified_kwargs = update_extension_header(

Check notice on line 335 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/base.py (150-160)
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)

card = self.agent_card

if not card.capabilities.extended_agent_card:
return card
_, modified_kwargs = await self._apply_interceptors(
{},
MessageToDict(request, preserving_proto_field_name=True),
modified_kwargs,
context,
)
response_data = await self._send_get_request(
'/card', {}, modified_kwargs
)
response: AgentCard = ParseDict(response_data, AgentCard())

if signature_verifier:
signature_verifier(response)

# Update the transport's agent_card
self.agent_card = response
self._needs_extended_card = False
return response

async def close(self) -> None:
"""Closes the httpx client."""
await self.httpx_client.aclose()

async def _apply_interceptors(
self,
request_payload: dict[str, Any],
http_kwargs: dict[str, Any] | None,
context: ClientCallContext | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
final_http_kwargs = http_kwargs or {}
final_request_payload = request_payload
# TODO: Implement interceptors for other transports
return final_request_payload, final_http_kwargs

def _get_http_args(
self, context: ClientCallContext | None
) -> dict[str, Any] | None:
return context.state.get('http_kwargs') if context else None

async def _prepare_send_message(
self,
request: SendMessageRequest,
context: ClientCallContext | None,
extensions: list[str] | None = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
payload = MessageToDict(request)

Check notice on line 388 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/grpc.py (278-287)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
Expand Down
4 changes: 3 additions & 1 deletion tests/client/transports/test_jsonrpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
AgentInterface,
CancelTaskRequest,
DeleteTaskPushNotificationConfigRequest,
GetExtendedAgentCardRequest,
GetTaskPushNotificationConfigRequest,
GetTaskRequest,
ListTaskPushNotificationConfigsRequest,
Expand Down Expand Up @@ -647,6 +648,7 @@ async def test_get_card_with_extended_card_support_with_extensions(
extended_card.CopyFrom(agent_card)
extended_card.name = 'Extended'

request = GetExtendedAgentCardRequest()
rpc_response = {
'id': '123',
'jsonrpc': '2.0',
Expand All @@ -656,7 +658,7 @@ async def test_get_card_with_extended_card_support_with_extensions(
client, '_send_request', new_callable=AsyncMock
) as mock_send_request:
mock_send_request.return_value = rpc_response
await client.get_extended_agent_card(extensions=extensions)
await client.get_extended_agent_card(request, extensions=extensions)

mock_send_request.assert_called_once()
_, mock_kwargs = mock_send_request.call_args[0]
Expand Down
4 changes: 3 additions & 1 deletion tests/client/transports/test_rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
AgentCard,
AgentInterface,
DeleteTaskPushNotificationConfigRequest,
GetExtendedAgentCardRequest,
ListTaskPushNotificationConfigsRequest,
SendMessageRequest,
)
Expand Down Expand Up @@ -299,13 +300,14 @@ async def test_get_card_with_extended_card_support_with_extensions(
) # Extended card same for mock
mock_httpx_client.send.return_value = mock_response

request = GetExtendedAgentCardRequest()
with patch.object(
client, '_send_get_request', new_callable=AsyncMock
) as mock_send_get_request:
mock_send_get_request.return_value = json_format.MessageToDict(
agent_card
)
await client.get_extended_agent_card(extensions=extensions)
await client.get_extended_agent_card(request, extensions=extensions)

mock_send_get_request.assert_called_once()
_, _, mock_kwargs = mock_send_get_request.call_args[0]
Expand Down
17 changes: 11 additions & 6 deletions tests/integration/test_client_server_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
AgentCard,
AgentInterface,
CancelTaskRequest,
GetExtendedAgentCardRequest,
GetTaskPushNotificationConfigRequest,
GetTaskRequest,
Message,
Expand Down Expand Up @@ -950,7 +951,9 @@ async def test_http_transport_get_authenticated_card(
agent_card=agent_card,
url=agent_card.supported_interfaces[0].url,
)
result = await transport.get_extended_agent_card()
result = await transport.get_extended_agent_card(
GetExtendedAgentCardRequest()
)
assert result.name == extended_agent_card.name
assert transport.agent_card is not None
assert transport.agent_card.name == extended_agent_card.name
Expand All @@ -976,7 +979,9 @@ def channel_factory(address: str) -> Channel:
# The transport starts with a minimal card, get_extended_agent_card() fetches the full one
assert transport.agent_card is not None
transport.agent_card.capabilities.extended_agent_card = True
result = await transport.get_extended_agent_card()
result = await transport.get_extended_agent_card(
GetExtendedAgentCardRequest()
)

assert result.name == agent_card.name
assert transport.agent_card.name == agent_card.name
Expand Down Expand Up @@ -1160,7 +1165,7 @@ async def test_json_transport_get_signed_extended_card(
create_key_provider(public_key), ['HS384', 'ES256']
)
result = await transport.get_extended_agent_card(
signature_verifier=signature_verifier
GetExtendedAgentCardRequest(), signature_verifier=signature_verifier
)
assert result.name == extended_agent_card.name
assert result.signatures is not None
Expand Down Expand Up @@ -1239,7 +1244,7 @@ async def test_json_transport_get_signed_base_and_extended_cards(

# 3. Fetch extended card via transport
result = await transport.get_extended_agent_card(
signature_verifier=signature_verifier
GetExtendedAgentCardRequest(), signature_verifier=signature_verifier
)
assert result.name == extended_agent_card.name
assert len(result.signatures) == 1
Expand Down Expand Up @@ -1316,7 +1321,7 @@ async def test_rest_transport_get_signed_card(

# 3. Fetch extended card
result = await transport.get_extended_agent_card(
signature_verifier=signature_verifier
GetExtendedAgentCardRequest(), signature_verifier=signature_verifier
)
assert result.name == extended_agent_card.name
assert result.signatures is not None
Expand Down Expand Up @@ -1378,7 +1383,7 @@ def channel_factory(address: str) -> Channel:
create_key_provider(public_key), ['HS384', 'ES256', 'RS256']
)
result = await transport.get_extended_agent_card(
signature_verifier=signature_verifier
GetExtendedAgentCardRequest(), signature_verifier=signature_verifier
)
assert result.signatures is not None
assert len(result.signatures) == 1
Expand Down
Loading