diff --git a/sdks/python/src/agent_control/__init__.py b/sdks/python/src/agent_control/__init__.py index 33658fb4..c03a3f66 100644 --- a/sdks/python/src/agent_control/__init__.py +++ b/sdks/python/src/agent_control/__init__.py @@ -78,11 +78,7 @@ async def handle_input(user_message: str) -> str: from ._control_registry import ( clear as clear_step_registry, ) - -# Import client and operations modules from .client import AgentControlClient - -# Import control decorator from .control_decorators import ControlSteerError, ControlViolationError, control from .evaluation import check_evaluation_with_local, evaluate_controls from .observability import ( @@ -98,8 +94,14 @@ async def handle_input(user_message: str) -> str: shutdown_observability, sync_shutdown_observability, ) - -# Import tracing and observability +from .telemetry import ( + clear_control_event_sink, + clear_trace_context_provider, + emit_control_events, + get_trace_context_from_provider, + set_control_event_sink, + set_trace_context_provider, +) from .tracing import ( get_current_span_id, get_current_trace_id, @@ -1305,6 +1307,12 @@ async def main(): "get_current_span_id", "with_trace", "is_otel_available", + "set_trace_context_provider", + "get_trace_context_from_provider", + "clear_trace_context_provider", + "set_control_event_sink", + "emit_control_events", + "clear_control_event_sink", # Observability "init_observability", "add_event", diff --git a/sdks/python/src/agent_control/evaluation.py b/sdks/python/src/agent_control/evaluation.py index 55f5efc1..d76af177 100644 --- a/sdks/python/src/agent_control/evaluation.py +++ b/sdks/python/src/agent_control/evaluation.py @@ -20,6 +20,7 @@ from ._state import state from .client import AgentControlClient from .observability import add_event, get_logger, is_observability_enabled +from .tracing import get_trace_and_span_ids from .validation import ensure_agent_name _logger = get_logger(__name__) @@ -291,6 +292,13 @@ async def check_evaluation_with_local( httpx.HTTPError: If server request fails """ normalized_name = ensure_agent_name(agent_name) + resolved_trace_id = trace_id + resolved_span_id = span_id + if trace_id is None or span_id is None: + current_trace_id, current_span_id = get_trace_and_span_ids() + resolved_trace_id = trace_id or current_trace_id + resolved_span_id = span_id or current_span_id + # Partition controls by local flag local_controls: list[_ControlAdapter] = [] parse_errors: list[ControlMatch] = [] @@ -389,8 +397,8 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: local_result, request, applicable_local_controls, - trace_id, - span_id, + resolved_trace_id, + resolved_span_id, agent_name=event_agent_name, ) @@ -409,10 +417,10 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: if _has_applicable_prefiltered_server_controls(server_control_payloads, request): request_payload = request.model_dump(mode="json", exclude_none=True) headers: dict[str, str] = {} - if trace_id: - headers["X-Trace-Id"] = trace_id - if span_id: - headers["X-Span-Id"] = span_id + if resolved_trace_id: + headers["X-Trace-Id"] = resolved_trace_id + if resolved_span_id: + headers["X-Span-Id"] = resolved_span_id response = await client.http_client.post( "/api/v1/evaluation", diff --git a/sdks/python/src/agent_control/telemetry/__init__.py b/sdks/python/src/agent_control/telemetry/__init__.py new file mode 100644 index 00000000..8933553d --- /dev/null +++ b/sdks/python/src/agent_control/telemetry/__init__.py @@ -0,0 +1,27 @@ +"""Telemetry interfaces for provider-agnostic tracing and event emission.""" + +from .event_sink import ( + ControlEventSink, + clear_control_event_sink, + emit_control_events, + set_control_event_sink, +) +from .trace_context import ( + TraceContext, + TraceContextProvider, + clear_trace_context_provider, + get_trace_context_from_provider, + set_trace_context_provider, +) + +__all__ = [ + "ControlEventSink", + "TraceContext", + "TraceContextProvider", + "clear_control_event_sink", + "clear_trace_context_provider", + "emit_control_events", + "get_trace_context_from_provider", + "set_control_event_sink", + "set_trace_context_provider", +] diff --git a/sdks/python/src/agent_control/telemetry/event_sink.py b/sdks/python/src/agent_control/telemetry/event_sink.py new file mode 100644 index 00000000..b36e9c13 --- /dev/null +++ b/sdks/python/src/agent_control/telemetry/event_sink.py @@ -0,0 +1,33 @@ +"""Provider-agnostic sink for merged control execution events.""" + +from collections.abc import Callable + +from agent_control_models import ControlExecutionEvent + +ControlEventSink = Callable[[list[ControlExecutionEvent]], None] + +_control_event_sink: ControlEventSink | None = None + + +def set_control_event_sink(sink: ControlEventSink | None) -> None: + """Register a sink for merged control execution events.""" + global _control_event_sink + _control_event_sink = sink + + +def emit_control_events(events: list[ControlExecutionEvent]) -> None: + """Emit merged control execution events to the registered sink.""" + if not events or _control_event_sink is None: + return + + try: + _control_event_sink(events) + except Exception: + # Sink failures should not break control evaluation. + pass + + +def clear_control_event_sink() -> None: + """Clear the registered control event sink.""" + global _control_event_sink + _control_event_sink = None diff --git a/sdks/python/src/agent_control/telemetry/trace_context.py b/sdks/python/src/agent_control/telemetry/trace_context.py new file mode 100644 index 00000000..a871fb29 --- /dev/null +++ b/sdks/python/src/agent_control/telemetry/trace_context.py @@ -0,0 +1,55 @@ +"""Provider-agnostic trace context interface for external tracing systems.""" + +from collections.abc import Callable +from typing import TypedDict + + +class TraceContext(TypedDict): + """Resolved trace context for a control evaluation.""" + + trace_id: str + span_id: str + + +TraceContextProvider = Callable[[], TraceContext | None] + +_trace_context_provider: TraceContextProvider | None = None + + +def set_trace_context_provider(provider: TraceContextProvider | None) -> None: + """Register a provider that returns the current trace context.""" + global _trace_context_provider + _trace_context_provider = provider + + +def get_trace_context_from_provider() -> TraceContext | None: + """Return trace context from the registered provider, if any.""" + if _trace_context_provider is None: + return None + + try: + trace_context = _trace_context_provider() + except Exception: + # Provider failures should not break control evaluation. + return None + + if trace_context is None: + return None + + trace_id = trace_context.get("trace_id") + span_id = trace_context.get("span_id") + if not isinstance(trace_id, str) or not isinstance(span_id, str): + return None + if not trace_id or not span_id: + return None + + return { + "trace_id": trace_id, + "span_id": span_id, + } + + +def clear_trace_context_provider() -> None: + """Clear the registered trace context provider.""" + global _trace_context_provider + _trace_context_provider = None diff --git a/sdks/python/src/agent_control/tracing.py b/sdks/python/src/agent_control/tracing.py index 473b5633..47696b15 100644 --- a/sdks/python/src/agent_control/tracing.py +++ b/sdks/python/src/agent_control/tracing.py @@ -31,6 +31,8 @@ from contextlib import contextmanager from contextvars import ContextVar, Token +from .telemetry.trace_context import get_trace_context_from_provider + # Context variables for trace/span propagation _trace_id_var: ContextVar[str | None] = ContextVar("trace_id", default=None) _span_id_var: ContextVar[str | None] = ContextVar("span_id", default=None) @@ -94,8 +96,9 @@ def get_trace_and_span_ids() -> tuple[str, str]: Priority: 1. Context variable (set by with_trace or explicitly) - 2. OpenTelemetry context (if OTEL is installed and active) - 3. Generate new OTEL-compatible IDs + 2. External provider + 3. OpenTelemetry context (if OTEL is installed and active) + 4. Generate new OTEL-compatible IDs Returns: Tuple of (trace_id, span_id) - both are hex strings @@ -114,6 +117,11 @@ def get_trace_and_span_ids() -> tuple[str, str]: if trace_id is not None and span_id is not None: return trace_id, span_id + # Try external provider + trace_context = get_trace_context_from_provider() + if trace_context: + return trace_context["trace_id"], trace_context["span_id"] + # Try OpenTelemetry context otel_trace_id, otel_span_id = _get_otel_ids() @@ -136,6 +144,11 @@ def get_current_trace_id() -> str | None: if trace_id is not None: return trace_id + # Try external provider + trace_context = get_trace_context_from_provider() + if trace_context: + return trace_context["trace_id"] + # Try OpenTelemetry otel_trace_id, _ = _get_otel_ids() return otel_trace_id @@ -153,6 +166,11 @@ def get_current_span_id() -> str | None: if span_id is not None: return span_id + # Try external provider + trace_context = get_trace_context_from_provider() + if trace_context: + return trace_context["span_id"] + # Try OpenTelemetry _, otel_span_id = _get_otel_ids() return otel_span_id diff --git a/sdks/python/tests/test_event_sink.py b/sdks/python/tests/test_event_sink.py new file mode 100644 index 00000000..8013f4d6 --- /dev/null +++ b/sdks/python/tests/test_event_sink.py @@ -0,0 +1,59 @@ +"""Tests for the telemetry merged control event sink interface.""" + +from datetime import UTC, datetime + +from agent_control.telemetry.event_sink import ( + clear_control_event_sink, + emit_control_events, + set_control_event_sink, +) +from agent_control_models import ControlExecutionEvent + + +def _event() -> ControlExecutionEvent: + return ControlExecutionEvent( + control_execution_id="ce-1", + trace_id="a" * 32, + span_id="b" * 16, + agent_name="test-agent", + control_id=1, + control_name="pii_check", + check_stage="pre", + applies_to="llm_call", + action="allow", + matched=False, + confidence=0.95, + timestamp=datetime.now(UTC), + metadata={}, + ) + + +def teardown_function() -> None: + clear_control_event_sink() + + +def test_emit_control_events_calls_registered_sink() -> None: + seen: list[list[ControlExecutionEvent]] = [] + + def _sink(events: list[ControlExecutionEvent]) -> None: + seen.append(events) + + event = _event() + set_control_event_sink(_sink) + + emit_control_events([event]) + + assert seen == [[event]] + + +def test_emit_control_events_noops_without_sink() -> None: + emit_control_events([_event()]) + + +def test_emit_control_events_swallows_sink_failures() -> None: + def _sink(_events: list[ControlExecutionEvent]) -> None: + raise RuntimeError("boom") + + set_control_event_sink(_sink) + + emit_control_events([_event()]) diff --git a/sdks/python/tests/test_observability_updates.py b/sdks/python/tests/test_observability_updates.py index cdaaa6ce..bb11a5ae 100644 --- a/sdks/python/tests/test_observability_updates.py +++ b/sdks/python/tests/test_observability_updates.py @@ -10,6 +10,10 @@ _map_applies_to, _merge_results, ) +from agent_control.telemetry.trace_context import ( + clear_trace_context_provider, + set_trace_context_provider, +) from agent_control_models import ControlDefinition # ============================================================================= @@ -326,6 +330,9 @@ def test_fallback_warning_logged_only_once(self): class TestCheckEvaluationWithLocal: """Tests for check_evaluation_with_local event emission and non_matches.""" + def teardown_method(self) -> None: + clear_trace_context_provider() + @pytest.mark.asyncio async def test_emits_events_when_trace_context_provided(self): """Should emit observability events when trace_id and span_id are passed.""" @@ -398,7 +405,7 @@ async def test_emits_events_when_trace_context_provided(self): @pytest.mark.asyncio async def test_emits_events_without_trace_context(self): - """Should still emit events when trace_id/span_id not provided (fallback IDs).""" + """Should resolve trace context from the provider when IDs are omitted.""" from agent_control_models import EvaluationResponse, Step mock_response = EvaluationResponse( @@ -424,6 +431,12 @@ async def test_emits_events_without_trace_context(self): client = MagicMock() client.http_client = AsyncMock() step = Step(type="llm", name="test-step", input="hello") + set_trace_context_provider( + lambda: { + "trace_id": "a" * 32, + "span_id": "b" * 16, + } + ) with patch("agent_control.evaluation.ControlEngine", return_value=mock_engine), \ patch("agent_control.evaluation.list_evaluators", return_value=["regex"]), \ @@ -438,8 +451,8 @@ async def test_emits_events_without_trace_context(self): ) mock_emit.assert_called_once() call_args = mock_emit.call_args - assert call_args[0][3] is None # trace_id passed as None - assert call_args[0][4] is None # span_id passed as None + assert call_args[0][3] == "a" * 32 + assert call_args[0][4] == "b" * 16 @pytest.mark.asyncio async def test_forwards_trace_headers_to_server(self): @@ -492,6 +505,59 @@ async def test_forwards_trace_headers_to_server(self): assert headers["X-Trace-Id"] == "aaaa1111bbbb2222cccc3333dddd4444" assert headers["X-Span-Id"] == "eeee5555ffff6666" + @pytest.mark.asyncio + async def test_forwards_provider_trace_headers_to_server_when_ids_omitted(self): + """Server POST should resolve trace headers from the provider when omitted.""" + from agent_control_models import Step + + controls = [{ + "id": 1, + "name": "server-ctrl", + "control": { + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + "action": {"decision": "deny"}, + "execution": "server", + }, + }] + + mock_http_response = MagicMock() + mock_http_response.json.return_value = { + "is_safe": True, + "confidence": 1.0, + "matches": None, + "errors": None, + "non_matches": None, + } + mock_http_response.raise_for_status = MagicMock() + + client = MagicMock() + client.http_client = AsyncMock() + client.http_client.post = AsyncMock(return_value=mock_http_response) + step = Step(type="llm", name="test-step", input="hello") + set_trace_context_provider( + lambda: { + "trace_id": "c" * 32, + "span_id": "d" * 16, + } + ) + + with patch("agent_control.evaluation.list_evaluators", return_value=["regex"]): + await evaluation.check_evaluation_with_local( + client=client, + agent_name="agent-000000000001", + step=step, + stage="pre", + controls=controls, + ) + + call_kwargs = client.http_client.post.call_args + headers = call_kwargs.kwargs.get("headers", {}) + assert headers["X-Trace-Id"] == "c" * 32 + assert headers["X-Span-Id"] == "d" * 16 + # ============================================================================= # control_decorators non_matches dict conversion diff --git a/sdks/python/tests/test_trace_context.py b/sdks/python/tests/test_trace_context.py new file mode 100644 index 00000000..f08306e0 --- /dev/null +++ b/sdks/python/tests/test_trace_context.py @@ -0,0 +1,65 @@ +"""Tests for the telemetry trace context provider interface.""" + +from agent_control.telemetry.trace_context import ( + clear_trace_context_provider, + get_trace_context_from_provider, + set_trace_context_provider, +) + + +def teardown_function() -> None: + clear_trace_context_provider() + + +def test_get_trace_context_from_provider_returns_registered_context() -> None: + set_trace_context_provider( + lambda: { + "trace_id": "a" * 32, + "span_id": "b" * 16, + } + ) + + assert get_trace_context_from_provider() == { + "trace_id": "a" * 32, + "span_id": "b" * 16, + } + + +def test_get_trace_context_from_provider_returns_none_when_unset() -> None: + assert get_trace_context_from_provider() is None + + +def test_get_trace_context_from_provider_returns_none_when_provider_returns_none() -> None: + set_trace_context_provider(lambda: None) + + assert get_trace_context_from_provider() is None + + +def test_get_trace_context_from_provider_swallows_provider_failures() -> None: + def _raising_provider(): + raise RuntimeError("boom") + + set_trace_context_provider(_raising_provider) + + assert get_trace_context_from_provider() is None + + +def test_get_trace_context_from_provider_returns_none_for_invalid_shape() -> None: + set_trace_context_provider( # type: ignore[arg-type] + lambda: { + "trace_id": "a" * 32, + } + ) + + assert get_trace_context_from_provider() is None + + +def test_get_trace_context_from_provider_returns_none_for_empty_ids() -> None: + set_trace_context_provider( + lambda: { + "trace_id": "", + "span_id": "", + } + ) + + assert get_trace_context_from_provider() is None diff --git a/sdks/python/tests/test_tracing.py b/sdks/python/tests/test_tracing.py index 175cb7c4..97397b8d 100644 --- a/sdks/python/tests/test_tracing.py +++ b/sdks/python/tests/test_tracing.py @@ -2,6 +2,7 @@ import pytest +from agent_control.telemetry.trace_context import clear_trace_context_provider, set_trace_context_provider from agent_control.tracing import ( _generate_span_id, _generate_trace_id, @@ -17,6 +18,10 @@ ) +def teardown_function() -> None: + clear_trace_context_provider() + + class TestIdGeneration: """Tests for trace and span ID generation.""" @@ -132,6 +137,30 @@ def test_get_current_ids_without_context(self): assert trace_id is None or isinstance(trace_id, str) assert span_id is None or isinstance(span_id, str) + def test_get_current_trace_id_uses_provider(self): + """Test that get_current_trace_id uses external provider before OTEL fallback.""" + expected_trace = "a" * 32 + set_trace_context_provider( + lambda: { + "trace_id": expected_trace, + "span_id": "b" * 16, + } + ) + + assert get_current_trace_id() == expected_trace + + def test_get_current_span_id_uses_provider(self): + """Test that get_current_span_id uses external provider before OTEL fallback.""" + expected_span = "b" * 16 + set_trace_context_provider( + lambda: { + "trace_id": "a" * 32, + "span_id": expected_span, + } + ) + + assert get_current_span_id() == expected_span + class TestWithTraceContextManager: """Tests for the with_trace context manager.""" @@ -237,6 +266,23 @@ def test_get_trace_and_span_ids_uses_context(self): assert trace_id == expected_trace assert span_id == expected_span + def test_get_trace_and_span_ids_uses_provider_before_otel(self): + """Test that an external provider is checked before OTEL fallback.""" + expected_trace = "c" * 32 + expected_span = "d" * 16 + + set_trace_context_provider( + lambda: { + "trace_id": expected_trace, + "span_id": expected_span, + } + ) + + trace_id, span_id = get_trace_and_span_ids() + + assert trace_id == expected_trace + assert span_id == expected_span + class TestOtelAvailability: """Tests for OpenTelemetry availability detection."""