diff --git a/cpex/framework/base.py b/cpex/framework/base.py index 0a29be2..8d38bb4 100644 --- a/cpex/framework/base.py +++ b/cpex/framework/base.py @@ -65,6 +65,11 @@ def __init__( ) -> None: """Initialize a plugin with a configuration and context. + The plugin receives the config directly. When the plugin is + registered with the Manager, the PluginRef retains the + authoritative config and gives the plugin a defensive copy, + so the Manager never trusts config read back from the plugin. + Args: config: The plugin configuration hook_payloads: optional mapping of hookpoints to payloads for the plugin. @@ -267,11 +272,18 @@ class PluginRef: ['ref', 'test'] """ - def __init__(self, plugin: Plugin): + def __init__(self, plugin: Plugin, trusted_config: PluginConfig | None = None): """Initialize a plugin reference. + Stores the authoritative config separately from the plugin. + The Manager reads policy-sensitive fields (capabilities, mode, + on_error) from the trusted config, never from the plugin. + Args: plugin: The plugin to reference. + trusted_config: The authoritative config retained by the + Manager. If not provided, falls back to plugin.config + (for backward compatibility in tests). Examples: >>> from cpex.framework import PluginConfig @@ -293,6 +305,7 @@ def __init__(self, plugin: Plugin): True """ self._plugin = plugin + self._trusted_config = trusted_config or plugin.config self._uuid = uuid.uuid4() @property @@ -304,6 +317,15 @@ def plugin(self) -> Plugin: """ return self._plugin + @property + def trusted_config(self) -> PluginConfig: + """Return the authoritative config held by the Manager. + + Returns: + The trusted PluginConfig (not the plugin's copy). + """ + return self._trusted_config + @property def uuid(self) -> str: """Return the plugin's UUID. @@ -320,7 +342,7 @@ def priority(self) -> int: Returns: Plugin's priority. 
""" - return self._plugin.priority + return self._trusted_config.priority @property def name(self) -> str: @@ -329,7 +351,7 @@ def name(self) -> str: Returns: Plugin's name. """ - return self._plugin.name + return self._trusted_config.name @property def hooks(self) -> list[str]: @@ -338,7 +360,7 @@ def hooks(self) -> list[str]: Returns: Plugin's configured hooks. """ - return self._plugin.hooks + return self._trusted_config.hooks @property def tags(self) -> list[str]: @@ -347,7 +369,7 @@ def tags(self) -> list[str]: Returns: Plugin's tags. """ - return self._plugin.tags + return self._trusted_config.tags @property def conditions(self) -> list[PluginCondition] | None: @@ -356,7 +378,7 @@ def conditions(self) -> list[PluginCondition] | None: Returns: Plugin's conditions for operation. """ - return self._plugin.conditions + return self._trusted_config.conditions @property def mode(self) -> PluginMode: @@ -365,7 +387,7 @@ def mode(self) -> PluginMode: Returns: Plugin's mode. """ - return self.plugin.mode + return self._trusted_config.mode @property def on_error(self) -> OnError: @@ -374,7 +396,16 @@ def on_error(self) -> OnError: Returns: Plugin's on_error behavior. """ - return self.plugin.config.on_error + return self._trusted_config.on_error + + @property + def capabilities(self) -> frozenset[str]: + """Return the plugin's declared capabilities. + + Returns: + The authoritative capability set from the trusted config. 
+ """ + return self._trusted_config.capabilities class HookRef: @@ -439,13 +470,16 @@ def __init__(self, hook: str, plugin_ref: PluginRef): ) # Validate hook method signature (parameter count and async) - self._validate_hook_signature(hook, self._func, plugin_ref.plugin.name) + param_count = self._validate_hook_signature(hook, self._func, plugin_ref.plugin.name) + + # Store whether the plugin accepts extensions as a third argument + self._accepts_extensions = param_count == 3 - def _validate_hook_signature(self, hook: str, func: Callable, plugin_name: str) -> None: + def _validate_hook_signature(self, hook: str, func: Callable, plugin_name: str) -> int: """Validate that the hook method has the correct signature. Checks: - 1. Method accepts correct number of parameters (self, payload, context) + 1. Method accepts 2 parameters (payload, context) or 3 (payload, context, extensions) 2. Method is async (returns coroutine) Args: @@ -453,6 +487,9 @@ def _validate_hook_signature(self, hook: str, func: Callable, plugin_name: str) func: The hook method to validate plugin_name: Name of the plugin (for error messages) + Returns: + The number of parameters (2 or 3). + Raises: PluginError: If the signature is invalid """ @@ -462,14 +499,16 @@ def _validate_hook_signature(self, hook: str, func: Callable, plugin_name: str) sig = inspect.signature(func) params = list(sig.parameters.values()) - # Check parameter count (should be: payload, context) + # Check parameter count (should be: payload, context[, extensions]) # Note: 'self' is not included in bound method signatures - if len(params) != 2: + if len(params) not in (2, 3): raise PluginError( error=PluginErrorModel( message=f"Plugin '{plugin_name}' hook '{hook}' has invalid signature. " - f"Expected 2 parameters (payload, context), got {len(params)}: {list(sig.parameters.keys())}. 
" - f"Correct signature: async def {hook}(self, payload: PayloadType, context: PluginContext) -> ResultType", + f"Expected 2 or 3 parameters (payload, context[, extensions]), " + f"got {len(params)}: {list(sig.parameters.keys())}. " + f"Correct signature: async def {hook}(self, payload: PayloadType, " + f"context: PluginContext[, extensions: Extensions]) -> ResultType", plugin_name=plugin_name, ) ) @@ -485,6 +524,8 @@ def _validate_hook_signature(self, hook: str, func: Callable, plugin_name: str) ) ) + return len(params) + # ========== OPTIONAL: Type Hint Validation ========== # Uncomment to enable strict type checking of payload and return types. # This validates that type hints match the expected types from the hook registry. @@ -608,7 +649,16 @@ def name(self) -> str: return self._hook @property - def hook(self) -> Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] | None: + def accepts_extensions(self) -> bool: + """Whether the hook method accepts extensions as a third argument. + + Returns: + True if the hook signature has 3 parameters (payload, context, extensions). + """ + return self._accepts_extensions + + @property + def hook(self) -> Callable[..., Awaitable[PluginResult]] | None: """The hooking function that can be invoked within the reference. Returns: diff --git a/cpex/framework/cmf/__init__.py b/cpex/framework/cmf/__init__.py new file mode 100644 index 0000000..f446938 --- /dev/null +++ b/cpex/framework/cmf/__init__.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +"""Location: ./cpex/framework/cmf/__init__.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Common Message Format (CMF) Package. +Provides the canonical, provider-agnostic message representation +for interactions between users, agents, tools, and language models. 
+""" diff --git a/cpex/framework/cmf/message.py b/cpex/framework/cmf/message.py new file mode 100644 index 0000000..74e4673 --- /dev/null +++ b/cpex/framework/cmf/message.py @@ -0,0 +1,945 @@ +# -*- coding: utf-8 -*- +"""Location: ./cpex/framework/cmf/message.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Common Message Format (CMF) message models. +This module implements the canonical message representation for interactions +between users, agents, tools, and language models. All models are frozen +(immutable) and require model_copy() for modification, supporting the CMF's +copy-on-write semantics and mutability tier enforcement. + +Domain objects (ToolCall, ImageSource, etc.) are standalone frozen models +reusable across contexts. ContentPart wrappers (ToolCallContentPart, etc.) +compose them into the typed content-part hierarchy for message serialization. +""" + +# Standard +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING, Annotated, Any, Iterator, Literal, Union + +# Third-Party +from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, model_validator + +# First-Party +from cpex.framework.extensions.extensions import Extensions + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + + +class Role(str, Enum): + """Closed-set enumeration of message roles. + + Identifies WHO is speaking in a conversation turn. + + Attributes: + SYSTEM: System-level instructions. + DEVELOPER: Developer-provided instructions. + USER: Human user input. + ASSISTANT: LLM/agent response. + TOOL: Tool execution result. 
+ + Examples: + >>> Role.USER + <Role.USER: 'user'> + >>> Role.USER.value + 'user' + >>> Role("assistant") + <Role.ASSISTANT: 'assistant'> + """ + + SYSTEM = "system" + DEVELOPER = "developer" + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + + +class Channel(str, Enum): + """Closed-set enumeration of output channel types. + + Classifies the kind of output a message represents, allowing + pipelines to route or filter messages by output type without + inspecting content. + + Attributes: + ANALYSIS: Intermediate analytical output not intended as final response. + COMMENTARY: Meta-level observations about the task or process. + FINAL: Terminal response intended for delivery to the end consumer. + + Examples: + >>> Channel.FINAL + <Channel.FINAL: 'final'> + >>> Channel("analysis") + <Channel.ANALYSIS: 'analysis'> + """ + + ANALYSIS = "analysis" + COMMENTARY = "commentary" + FINAL = "final" + + +class ContentType(str, Enum): + """Closed-set enumeration of content part types. + + Discriminator for the typed ContentPart hierarchy, identifying + the kind of content carried by each part of a multimodal message. + + Attributes: + TEXT: Plain text content. + THINKING: Chain-of-thought reasoning. + TOOL_CALL: Tool/function invocation request. + TOOL_RESULT: Result from tool execution. + RESOURCE: Embedded resource with content (MCP). + RESOURCE_REF: Lightweight resource reference without embedded content. + PROMPT_REQUEST: Prompt template invocation request (MCP). + PROMPT_RESULT: Rendered prompt template result. + IMAGE: Image content (URL or base64). + VIDEO: Video content (URL or base64). + AUDIO: Audio content (URL or base64). + DOCUMENT: Document content (PDF, Word, etc.). 
+ + Examples: + >>> ContentType.TOOL_CALL + <ContentType.TOOL_CALL: 'tool_call'> + >>> ContentType("text") + <ContentType.TEXT: 'text'> + """ + + TEXT = "text" + THINKING = "thinking" + TOOL_CALL = "tool_call" + TOOL_RESULT = "tool_result" + RESOURCE = "resource" + RESOURCE_REF = "resource_ref" + PROMPT_REQUEST = "prompt_request" + PROMPT_RESULT = "prompt_result" + IMAGE = "image" + VIDEO = "video" + AUDIO = "audio" + DOCUMENT = "document" + + +class ResourceType(str, Enum): + """Closed-set enumeration of resource types. + + Attributes: + FILE: File-system resource. + BLOB: Binary large object. + URI: Generic URI-addressable resource. + DATABASE: Database entity. + API: API endpoint. + MEMORY: In-memory or ephemeral resource. + ARTIFACT: Produced artifact (generated output, build result). + + Examples: + >>> ResourceType.FILE + <ResourceType.FILE: 'file'> + >>> ResourceType("database") + <ResourceType.DATABASE: 'database'> + """ + + FILE = "file" + BLOB = "blob" + URI = "uri" + DATABASE = "database" + API = "api" + MEMORY = "memory" + ARTIFACT = "artifact" + + +# --------------------------------------------------------------------------- +# Domain Objects (standalone, reusable across contexts) +# --------------------------------------------------------------------------- + + +class ToolCall(BaseModel): + """Normalized tool/function invocation request. + + Standalone domain object reusable outside of message content parts. + + Attributes: + tool_call_id: Unique request correlation ID. + name: Tool name. + arguments: Arguments as a JSON-serializable dict. + namespace: Optional namespace for namespaced tools. + + Examples: + >>> call = ToolCall( + ... tool_call_id="tc_001", + ... name="get_user", + ... arguments={"user_id": "123"}, + ... 
) + >>> call.name + 'get_user' + >>> call.arguments + {'user_id': '123'} + """ + + model_config = ConfigDict(frozen=True) + + tool_call_id: str = Field(description="Unique request correlation ID.") + name: str = Field(description="Tool name.") + arguments: dict[str, Any] = Field(default_factory=dict, description="Arguments as a JSON-serializable dict.") + namespace: str | None = Field(default=None, description="Namespace for namespaced tools.") + + +class ToolResult(BaseModel): + """Result from tool execution. + + Standalone domain object reusable outside of message content parts. + + Attributes: + tool_call_id: Correlation ID linking to the corresponding tool call. + tool_name: Name of the tool that was executed. + content: Result content, any JSON-serializable value. + is_error: Whether the result represents an error. + + Examples: + >>> result = ToolResult( + ... tool_call_id="tc_001", + ... tool_name="get_user", + ... content={"name": "Alice"}, + ... ) + >>> result.is_error + False + >>> result.tool_name + 'get_user' + """ + + model_config = ConfigDict(frozen=True) + + tool_call_id: str = Field(description="Correlation ID linking to the corresponding tool call.") + tool_name: str = Field(description="Name of the tool that was executed.") + content: Any = Field(default=None, description="Result content, any JSON-serializable value.") + is_error: bool = Field(default=False, description="Whether the result represents an error.") + + +class Resource(BaseModel): + """Embedded resource with content (MCP). + + Standalone domain object reusable outside of message content parts. + + Attributes: + resource_request_id: Unique request correlation ID. + uri: Unique identifier in URI format. + name: Human-readable name. + description: What this resource contains. + resource_type: The kind of resource. + content: Text content if embedded. + blob: Binary content if embedded. + mime_type: MIME type of content. + size_bytes: Size information. 
+ annotations: Metadata (classification, retention, etc.). + version: Version tracking. + + Examples: + >>> res = Resource( + ... resource_request_id="rr_001", + ... uri="file:///data/report.csv", + ... name="Q4 Report", + ... resource_type=ResourceType.FILE, + ... content="col1,col2\\n1,2", + ... mime_type="text/csv", + ... ) + >>> res.uri + 'file:///data/report.csv' + """ + + model_config = ConfigDict(frozen=True) + + resource_request_id: str = Field(description="Unique request correlation ID.") + uri: str = Field(description="Unique identifier in URI format.") + name: str | None = Field(default=None, description="Human-readable name.") + description: str | None = Field(default=None, description="What this resource contains.") + resource_type: ResourceType = Field(description="The kind of resource.") + content: str | None = Field(default=None, description="Text content if embedded.") + blob: bytes | None = Field(default=None, description="Binary content if embedded.") + + @model_validator(mode="after") + def _check_content_blob_exclusion(self) -> Resource: + """Ensure content and blob are mutually exclusive. + + Returns: + The validated Resource instance. + + Raises: + ValueError: If both content and blob are set. + """ + if self.content is not None and self.blob is not None: + raise ValueError("Resource cannot have both 'content' and 'blob' set") + return self + + mime_type: str | None = Field(default=None, description="MIME type of content.") + size_bytes: int | None = Field(default=None, description="Size information.") + annotations: dict[str, Any] = Field(default_factory=dict, description="Metadata (classification, retention, etc.).") + version: str | None = Field(default=None, description="Version tracking.") + + +class ResourceReference(BaseModel): + """Lightweight resource reference without embedded content. + + Standalone domain object reusable outside of message content parts. 
+ + Attributes: + resource_request_id: Correlation ID linking to the originating resource request. + uri: Resource URI. + name: Human-readable name. + resource_type: Type of resource. + range_start: Line number or byte offset for partial references. + range_end: End of range. + selector: CSS/XPath/JSONPath selector. + + Examples: + >>> ref = ResourceReference( + ... resource_request_id="rr_002", + ... uri="db://users/42", + ... resource_type=ResourceType.DATABASE, + ... ) + >>> ref.uri + 'db://users/42' + """ + + model_config = ConfigDict(frozen=True) + + resource_request_id: str = Field(description="Correlation ID linking to the originating resource request.") + uri: str = Field(description="Resource URI.") + name: str | None = Field(default=None, description="Human-readable name.") + resource_type: ResourceType = Field(description="Type of resource.") + range_start: int | None = Field(default=None, description="Line number or byte offset for partial references.") + range_end: int | None = Field(default=None, description="End of range.") + selector: str | None = Field(default=None, description="CSS/XPath/JSONPath selector.") + + @model_validator(mode="after") + def _check_range_consistency(self) -> ResourceReference: + """Ensure range_end is not less than range_start. + + Returns: + The validated ResourceReference instance. + + Raises: + ValueError: If range_end < range_start. + """ + if self.range_start is not None and self.range_end is not None: + if self.range_end < self.range_start: + raise ValueError(f"range_end ({self.range_end}) must be >= range_start ({self.range_start})") + return self + + +class PromptRequest(BaseModel): + """Prompt template invocation request (MCP). + + Standalone domain object reusable outside of message content parts. + + Attributes: + prompt_request_id: Request ID for correlation. + name: Prompt template name. + arguments: Arguments to pass to the template. + server_id: Source server for multi-server scenarios. 
+ + Examples: + >>> req = PromptRequest( + ... prompt_request_id="pr_001", + ... name="summarize", + ... arguments={"text": "Long document..."}, + ... ) + >>> req.name + 'summarize' + """ + + model_config = ConfigDict(frozen=True) + + prompt_request_id: str = Field(description="Request ID for correlation.") + name: str = Field(description="Prompt template name.") + arguments: dict[str, Any] = Field(default_factory=dict, description="Arguments to pass to the template.") + server_id: str | None = Field(default=None, description="Source server for multi-server scenarios.") + + +class PromptResult(BaseModel): + """Rendered prompt template result. + + Standalone domain object reusable outside of message content parts. + + Attributes: + prompt_request_id: ID of the corresponding prompt request. + prompt_name: Name of the prompt that was rendered. + messages: Rendered messages (prompts produce messages). + content: Single text result for simple prompts. + is_error: Whether rendering failed. + error_message: Error details if rendering failed. + + Examples: + >>> result = PromptResult( + ... prompt_request_id="pr_001", + ... prompt_name="summarize", + ... content="This document discusses...", + ... ) + >>> result.is_error + False + """ + + model_config = ConfigDict(frozen=True) + + prompt_request_id: str = Field(description="ID of the corresponding prompt request.") + prompt_name: str = Field(description="Name of the prompt that was rendered.") + messages: list[Message] = Field( + default_factory=list, + description="Rendered messages (prompts produce messages).", + ) + content: str | None = Field(default=None, description="Single text result for simple prompts.") + is_error: bool = Field(default=False, description="Whether rendering failed.") + error_message: str | None = Field(default=None, description="Error details if rendering failed.") + + +class ImageSource(BaseModel): + """Image source data. + + Standalone domain object reusable outside of message content parts. 
+ + Attributes: + type: Source type, either URL or base64-encoded. + data: URL or base64-encoded string. + media_type: MIME type (e.g., image/jpeg). + + Examples: + >>> img = ImageSource(type="url", data="https://example.com/photo.jpg") + >>> img.type + 'url' + """ + + model_config = ConfigDict(frozen=True) + + type: Literal["url", "base64"] = Field(description="Source type: 'url' or 'base64'.") + data: str = Field(description="URL or base64-encoded string.") + media_type: str | None = Field(default=None, description="MIME type (e.g., image/jpeg).") + + +class VideoSource(BaseModel): + """Video source data. + + Standalone domain object reusable outside of message content parts. + + Attributes: + type: Source type, either URL or base64-encoded. + data: URL or base64-encoded string. + media_type: MIME type (e.g., video/mp4). + duration_ms: Duration in milliseconds. + + Examples: + >>> vid = VideoSource(type="url", data="https://example.com/clip.mp4") + >>> vid.type + 'url' + """ + + model_config = ConfigDict(frozen=True) + + type: Literal["url", "base64"] = Field(description="Source type: 'url' or 'base64'.") + data: str = Field(description="URL or base64-encoded string.") + media_type: str | None = Field(default=None, description="MIME type (e.g., video/mp4).") + duration_ms: int | None = Field(default=None, description="Duration in milliseconds.") + + +class AudioSource(BaseModel): + """Audio source data. + + Standalone domain object reusable outside of message content parts. + + Attributes: + type: Source type, either URL or base64-encoded. + data: URL or base64-encoded string. + media_type: MIME type (e.g., audio/mp3). + duration_ms: Duration in milliseconds. 
+ + Examples: + >>> aud = AudioSource(type="url", data="https://example.com/track.mp3") + >>> aud.type + 'url' + """ + + model_config = ConfigDict(frozen=True) + + type: Literal["url", "base64"] = Field(description="Source type: 'url' or 'base64'.") + data: str = Field(description="URL or base64-encoded string.") + media_type: str | None = Field(default=None, description="MIME type (e.g., audio/mp3).") + duration_ms: int | None = Field(default=None, description="Duration in milliseconds.") + + +class DocumentSource(BaseModel): + """Document source data (PDF, Word, etc.). + + Standalone domain object reusable outside of message content parts. + + Attributes: + type: Source type, either URL or base64-encoded. + data: URL or base64-encoded string. + media_type: MIME type (e.g., application/pdf). + title: Document title. + + Examples: + >>> doc = DocumentSource( + ... type="base64", + ... data="JVBERi0xLjQ...", + ... media_type="application/pdf", + ... title="Annual Report", + ... ) + >>> doc.title + 'Annual Report' + """ + + model_config = ConfigDict(frozen=True) + + type: Literal["url", "base64"] = Field(description="Source type: 'url' or 'base64'.") + data: str = Field(description="URL or base64-encoded string.") + media_type: str | None = Field(default=None, description="MIME type (e.g., application/pdf).") + title: str | None = Field(default=None, description="Document title.") + + +# --------------------------------------------------------------------------- +# Content Parts (ContentPart base + wrappers) +# --------------------------------------------------------------------------- + + +class ContentPart(BaseModel): + """Base class for all content parts in a CMF message. + + Frozen by design — subclasses inherit immutability. Consumers must + use model_copy(update={...}) to create modified copies. + + Attributes: + content_type: Discriminator identifying the concrete content type. 
+ + Examples: + >>> part = TextContent(text="hello") + >>> isinstance(part, ContentPart) + True + >>> part.content_type + <ContentType.TEXT: 'text'> + """ + + model_config = ConfigDict(frozen=True) + + content_type: ContentType = Field(description="Content type discriminator.") + + +class TextContent(ContentPart): + """Plain text content part. + + Attributes: + content_type: Discriminator, always ContentType.TEXT. + text: The text content. + + Examples: + >>> part = TextContent(text="Hello, world!") + >>> part.content_type + <ContentType.TEXT: 'text'> + >>> part.text + 'Hello, world!' + >>> modified = part.model_copy(update={"text": "Updated"}) + >>> (part.text, modified.text) + ('Hello, world!', 'Updated') + """ + + content_type: Literal[ContentType.TEXT] = Field(default=ContentType.TEXT, description="Content type discriminator.") + text: str = Field(description="The text content.") + + +class ThinkingContent(ContentPart): + """Chain-of-thought reasoning content part. + + Attributes: + content_type: Discriminator, always ContentType.THINKING. + text: The reasoning text. + + Examples: + >>> part = ThinkingContent(text="Let me analyze this...") + >>> part.content_type + <ContentType.THINKING: 'thinking'> + """ + + content_type: Literal[ContentType.THINKING] = Field( + default=ContentType.THINKING, description="Content type discriminator." + ) + text: str = Field(description="The reasoning text.") + + +class ToolCallContentPart(ContentPart): + """Content part wrapping a ToolCall domain object. + + Attributes: + content_type: Discriminator, always ContentType.TOOL_CALL. + content: The wrapped ToolCall. + + Examples: + >>> part = ToolCallContentPart( + ... content=ToolCall(tool_call_id="tc_001", name="search", arguments={"q": "test"}), + ... ) + >>> part.content.name + 'search' + """ + + content_type: Literal[ContentType.TOOL_CALL] = Field( + default=ContentType.TOOL_CALL, description="Content type discriminator." 
+ ) + content: ToolCall = Field(description="The wrapped ToolCall.") + + +class ToolResultContentPart(ContentPart): + """Content part wrapping a ToolResult domain object. + + Attributes: + content_type: Discriminator, always ContentType.TOOL_RESULT. + content: The wrapped ToolResult. + + Examples: + >>> part = ToolResultContentPart( + ... content=ToolResult(tool_call_id="tc_001", tool_name="search", content="Found 10 results"), + ... ) + >>> part.content.tool_name + 'search' + """ + + content_type: Literal[ContentType.TOOL_RESULT] = Field( + default=ContentType.TOOL_RESULT, description="Content type discriminator." + ) + content: ToolResult = Field(description="The wrapped ToolResult.") + + +class ResourceContentPart(ContentPart): + """Content part wrapping a Resource domain object. + + Attributes: + content_type: Discriminator, always ContentType.RESOURCE. + content: The wrapped Resource. + + Examples: + >>> part = ResourceContentPart( + ... content=Resource(resource_request_id="rr_001", uri="file:///data.txt", resource_type=ResourceType.FILE), + ... ) + >>> part.content.uri + 'file:///data.txt' + """ + + content_type: Literal[ContentType.RESOURCE] = Field( + default=ContentType.RESOURCE, description="Content type discriminator." + ) + content: Resource = Field(description="The wrapped Resource.") + + +class ResourceRefContentPart(ContentPart): + """Content part wrapping a ResourceReference domain object. + + Attributes: + content_type: Discriminator, always ContentType.RESOURCE_REF. + content: The wrapped ResourceReference. + + Examples: + >>> part = ResourceRefContentPart( + ... content=ResourceReference(resource_request_id="rr_002", uri="db://users/42", resource_type=ResourceType.DATABASE), + ... ) + >>> part.content.uri + 'db://users/42' + """ + + content_type: Literal[ContentType.RESOURCE_REF] = Field( + default=ContentType.RESOURCE_REF, description="Content type discriminator." 
+ ) + content: ResourceReference = Field(description="The wrapped ResourceReference.") + + +class PromptRequestContentPart(ContentPart): + """Content part wrapping a PromptRequest domain object. + + Attributes: + content_type: Discriminator, always ContentType.PROMPT_REQUEST. + content: The wrapped PromptRequest. + + Examples: + >>> part = PromptRequestContentPart( + ... content=PromptRequest(prompt_request_id="pr_001", name="summarize"), + ... ) + >>> part.content.name + 'summarize' + """ + + content_type: Literal[ContentType.PROMPT_REQUEST] = Field( + default=ContentType.PROMPT_REQUEST, description="Content type discriminator." + ) + content: PromptRequest = Field(description="The wrapped PromptRequest.") + + +class PromptResultContentPart(ContentPart): + """Content part wrapping a PromptResult domain object. + + Attributes: + content_type: Discriminator, always ContentType.PROMPT_RESULT. + content: The wrapped PromptResult. + + Examples: + >>> part = PromptResultContentPart( + ... content=PromptResult(prompt_request_id="pr_001", prompt_name="summarize"), + ... ) + >>> part.content.prompt_name + 'summarize' + """ + + content_type: Literal[ContentType.PROMPT_RESULT] = Field( + default=ContentType.PROMPT_RESULT, description="Content type discriminator." + ) + content: PromptResult = Field(description="The wrapped PromptResult.") + + +class ImageContentPart(ContentPart): + """Content part wrapping an ImageSource domain object. + + Attributes: + content_type: Discriminator, always ContentType.IMAGE. + content: The wrapped ImageSource. + + Examples: + >>> part = ImageContentPart( + ... content=ImageSource(type="url", data="https://example.com/photo.jpg"), + ... ) + >>> part.content.type + 'url' + """ + + content_type: Literal[ContentType.IMAGE] = Field( + default=ContentType.IMAGE, description="Content type discriminator." 
+ ) + content: ImageSource = Field(description="The wrapped ImageSource.") + + +class VideoContentPart(ContentPart): + """Content part wrapping a VideoSource domain object. + + Attributes: + content_type: Discriminator, always ContentType.VIDEO. + content: The wrapped VideoSource. + + Examples: + >>> part = VideoContentPart( + ... content=VideoSource(type="url", data="https://example.com/clip.mp4"), + ... ) + >>> part.content.type + 'url' + """ + + content_type: Literal[ContentType.VIDEO] = Field( + default=ContentType.VIDEO, description="Content type discriminator." + ) + content: VideoSource = Field(description="The wrapped VideoSource.") + + +class AudioContentPart(ContentPart): + """Content part wrapping an AudioSource domain object. + + Attributes: + content_type: Discriminator, always ContentType.AUDIO. + content: The wrapped AudioSource. + + Examples: + >>> part = AudioContentPart( + ... content=AudioSource(type="url", data="https://example.com/track.mp3"), + ... ) + >>> part.content.type + 'url' + """ + + content_type: Literal[ContentType.AUDIO] = Field( + default=ContentType.AUDIO, description="Content type discriminator." + ) + content: AudioSource = Field(description="The wrapped AudioSource.") + + +class DocumentContentPart(ContentPart): + """Content part wrapping a DocumentSource domain object. + + Attributes: + content_type: Discriminator, always ContentType.DOCUMENT. + content: The wrapped DocumentSource. + + Examples: + >>> part = DocumentContentPart( + ... content=DocumentSource(type="base64", data="JVBERi0xLjQ...", media_type="application/pdf"), + ... ) + >>> part.content.media_type + 'application/pdf' + """ + + content_type: Literal[ContentType.DOCUMENT] = Field( + default=ContentType.DOCUMENT, description="Content type discriminator." 
+ ) + content: DocumentSource = Field(description="The wrapped DocumentSource.") + + +# --------------------------------------------------------------------------- +# ContentPart Discriminated Union +# --------------------------------------------------------------------------- + + +def _content_type_discriminator(v: Any) -> str: + """Extract the content_type discriminator value from a content part. + + Supports both dict (during deserialization) and model instance access. + + Args: + v: A content part as a dict or model instance. + + Returns: + The content_type string value for discriminator routing. + """ + if isinstance(v, dict): + ct = v.get("content_type") + if ct is None: + raise ValueError("Missing 'content_type' discriminator in content part dict") + return ct + if not hasattr(v, "content_type"): + raise ValueError(f"Content part {type(v).__name__} missing 'content_type' attribute") + return v.content_type.value + + +ContentPartUnion = Annotated[ + Union[ + Annotated[TextContent, Tag("text")], + Annotated[ThinkingContent, Tag("thinking")], + Annotated[ToolCallContentPart, Tag("tool_call")], + Annotated[ToolResultContentPart, Tag("tool_result")], + Annotated[ResourceContentPart, Tag("resource")], + Annotated[ResourceRefContentPart, Tag("resource_ref")], + Annotated[PromptRequestContentPart, Tag("prompt_request")], + Annotated[PromptResultContentPart, Tag("prompt_result")], + Annotated[ImageContentPart, Tag("image")], + Annotated[VideoContentPart, Tag("video")], + Annotated[AudioContentPart, Tag("audio")], + Annotated[DocumentContentPart, Tag("document")], + ], + Discriminator(_content_type_discriminator), +] +"""Discriminated union of all content part types. + +Pydantic uses the content_type field to resolve the correct subclass +during validation and deserialization. 
class Message(BaseModel):
    """Canonical CMF message representing a single turn in a conversation.

    A Message is the storage and wire format. It preserves the structure
    exactly as the LLM or framework sent it. For policy evaluation and
    inspection, use MessageView (via iter_views()) which decomposes the
    message into individually addressable, uniformly accessible parts.

    All Message instances are frozen. To create a modified copy, use
    model_copy(update={...}).

    Attributes:
        schema_version: Message schema version.
        role: Who is speaking.
        content: List of typed content parts (multimodal).
        channel: Optional output classification.
        extensions: Optional contextual metadata (identity, security, governance).

    Examples:
        >>> msg = Message(
        ...     role=Role.USER,
        ...     content=[TextContent(text="What is the weather?")],
        ... )
        >>> msg.role is Role.USER
        True
        >>> msg.content[0].text
        'What is the weather?'
        >>> msg.schema_version
        '2.0'

        >>> # Frozen: modifications require model_copy
        >>> updated = msg.model_copy(update={"channel": Channel.FINAL})
        >>> updated.channel is Channel.FINAL
        True
        >>> msg.channel is None
        True

        >>> # Multi-part assistant message
        >>> assistant_msg = Message(
        ...     role=Role.ASSISTANT,
        ...     content=[
        ...         ThinkingContent(text="I should check the weather API."),
        ...         TextContent(text="Let me look that up."),
        ...         ToolCallContentPart(
        ...             content=ToolCall(
        ...                 tool_call_id="tc_001",
        ...                 name="get_weather",
        ...                 arguments={"city": "London"},
        ...             ),
        ...         ),
        ...     ],
        ... )
        >>> len(assistant_msg.content)
        3
        >>> assistant_msg.content[2].content.name
        'get_weather'
    """

    # Instances are immutable; use model_copy(update=...) to derive a changed copy.
    model_config = ConfigDict(frozen=True)

    schema_version: str = Field(default="2.0", description="Message schema version.")
    role: Role = Field(description="Who is speaking.")
    content: list[ContentPartUnion] = Field(default_factory=list, description="List of typed content parts.")
    channel: Channel | None = Field(default=None, description="Optional output classification.")
    extensions: Extensions | None = Field(
        default=None,
        description="Contextual metadata (identity, security, governance, etc.).",
    )

    def iter_views(self, hook: str | None = None) -> Iterator[MessageView]:
        """Decompose this message into individually addressable MessageViews.

        Yields one MessageView per content part. Each view provides a
        uniform interface for policy evaluation regardless of content type.

        Args:
            hook: Optional hook location string (e.g., "llm_input",
                "tool_post_invoke") to attach to each view.

        Returns:
            An iterator of MessageView objects.

        Examples:
            >>> msg = Message(
            ...     role=Role.ASSISTANT,
            ...     content=[
            ...         TextContent(text="Let me check."),
            ...         ToolCallContentPart(
            ...             content=ToolCall(
            ...                 tool_call_id="tc_001",
            ...                 name="get_weather",
            ...                 arguments={"city": "London"},
            ...             ),
            ...         ),
            ...     ],
            ... )
            >>> views = list(msg.iter_views())
            >>> len(views)
            2
            >>> views[0].kind.value
            'text'
            >>> views[1].name
            'get_weather'
        """
        # Imported at call time: the view module itself imports this module.
        from cpex.framework.cmf.view import iter_views  # pylint: disable=import-outside-toplevel

        return iter_views(self, hook=hook)
class ViewAction(str, Enum):
    """Closed-set enumeration of semantic actions.

    Attributes:
        READ: Reading/accessing data.
        WRITE: Writing/modifying data.
        EXECUTE: Executing a tool or command.
        INVOKE: Invoking a prompt template.
        SEND: Sending content outbound.
        RECEIVE: Receiving content inbound.
        GENERATE: Generating content (LLM output).

    Examples:
        >>> ViewAction.EXECUTE
        <ViewAction.EXECUTE: 'execute'>
        >>> ViewAction.EXECUTE.value
        'execute'
    """

    READ = "read"
    WRITE = "write"
    EXECUTE = "execute"
    INVOKE = "invoke"
    SEND = "send"
    RECEIVE = "receive"
    GENERATE = "generate"
ViewAction] = { + ViewKind.TOOL_CALL: ViewAction.EXECUTE, + ViewKind.TOOL_RESULT: ViewAction.RECEIVE, + ViewKind.RESOURCE: ViewAction.READ, + ViewKind.RESOURCE_REF: ViewAction.READ, + ViewKind.PROMPT_REQUEST: ViewAction.INVOKE, + ViewKind.PROMPT_RESULT: ViewAction.RECEIVE, +} + +# Kinds whose action depends on message direction (role) +_DIRECTION_DEPENDENT_KINDS = frozenset( + { + ViewKind.TEXT, + ViewKind.THINKING, + ViewKind.IMAGE, + ViewKind.VIDEO, + ViewKind.AUDIO, + ViewKind.DOCUMENT, + } +) + +# Sensitive headers stripped during serialization +_SENSITIVE_HEADERS = frozenset({"authorization", "cookie", "x-api-key"}) + + +# --------------------------------------------------------------------------- +# MessageView +# --------------------------------------------------------------------------- + + +class MessageView: + """Read-only, zero-copy view over a single content part for policy evaluation. + + A MessageView provides a uniform interface for inspecting any content + part of a message — regardless of whether it's text, a tool call, a + resource, or media. Properties are computed on-demand from the + underlying content part and message extensions without copying data. + + For wrapped content parts (tool calls, resources, media, etc.), the + domain object is accessed via the wrapper's .content field. The _inner + property provides convenient access to the wrapped domain object. + + MessageViews are produced by Message.iter_views() or the standalone + iter_views() function. A single Message with multiple content parts + yields one view per part. + + Attributes: + kind: The type of content this view represents. + role: The role of the parent message. + raw: Direct access to the underlying content part. + + Examples: + >>> from cpex.framework.cmf.message import ( + ... Message, Role, TextContent, ToolCall, ToolCallContentPart, + ... ) + >>> msg = Message( + ... role=Role.ASSISTANT, + ... content=[ + ... TextContent(text="Let me look that up."), + ... 
ToolCallContentPart( + ... content=ToolCall( + ... tool_call_id="tc_001", + ... name="get_user", + ... arguments={"id": "123"}, + ... ), + ... ), + ... ], + ... ) + >>> views = list(iter_views(msg)) + >>> len(views) + 2 + >>> views[0].kind + + >>> views[1].kind + + >>> views[1].name + 'get_user' + >>> views[1].uri + 'tool://_/get_user' + >>> views[1].is_pre + True + """ + + __slots__ = ("_part", "_kind", "_message", "_hook") + + def __init__( + self, + part: ContentPart, + kind: ViewKind, + message: Message, + hook: str | None = None, + ) -> None: + """Initialize a MessageView. + + Args: + part: The underlying content part. + kind: The kind of content. + message: The parent message (for role and extensions access). + hook: The hook location where this view is being evaluated + (e.g., "llm_input", "tool_post_invoke"). None if unset. + """ + self._part = part + self._kind = kind + self._message = message + self._hook = hook + + # ========================================================================= + # Internal Helpers + # ========================================================================= + + @property + def _inner(self) -> Any: + """Get the wrapped domain object for composite content parts. + + For TextContent/ThinkingContent (which have no wrapper), returns + the part itself. For all other types, returns the .content field + which holds the domain object (ToolCall, Resource, etc.). + + Returns: + The domain object for this content part. + """ + if self._kind in (ViewKind.TEXT, ViewKind.THINKING): + return self._part + return self._part.content # type: ignore[union-attr] + + # ========================================================================= + # Core Properties + # ========================================================================= + + @property + def kind(self) -> ViewKind: + """The type of content this view represents. + + Returns: + The ViewKind for this view. 
@property
def content(self) -> str | None:
    """Text that a scanner can inspect for this view.

    Text and thinking parts yield their text. Resources and prompt
    results yield their ``content`` field. Tool calls and prompt requests
    yield their arguments serialized as JSON. Tool results yield their
    content as-is when it is a string (or None), otherwise serialized as
    JSON. Values that json cannot encode are rendered with ``str()``.

    Returns:
        Scannable text, or None when this kind carries no text.
    """
    payload = self._inner
    kind = self._kind

    def _as_json(value: Any) -> str:
        # Best-effort serialization: fall back to str() on unencodable values.
        try:
            return json.dumps(value)
        except (TypeError, ValueError):
            return str(value)

    if kind in (ViewKind.TEXT, ViewKind.THINKING):
        return payload.text

    if kind in (ViewKind.RESOURCE, ViewKind.PROMPT_RESULT):
        return payload.content

    if kind in (ViewKind.TOOL_CALL, ViewKind.PROMPT_REQUEST):
        return _as_json(payload.arguments)

    if kind == ViewKind.TOOL_RESULT:
        body = payload.content
        if body is None or isinstance(body, str):
            return body
        return _as_json(body)

    return None
+ + Tools: tool://namespace/name. Tool results: tool_result://name. + Prompts: prompt://server/name. Prompt results: prompt_result://name. + Resources: the resource's own URI. + + Returns: + URI string or None if not applicable. + """ + inner = self._inner + + if self._kind in (ViewKind.RESOURCE, ViewKind.RESOURCE_REF): + return inner.uri + + if self._kind == ViewKind.TOOL_CALL: + ns = inner.namespace or "_" + return f"tool://{ns}/{inner.name}" + + if self._kind == ViewKind.TOOL_RESULT: + return f"tool_result://{inner.tool_name}" + + if self._kind == ViewKind.PROMPT_REQUEST: + server = inner.server_id or "_" + return f"prompt://{server}/{inner.name}" + + if self._kind == ViewKind.PROMPT_RESULT: + return f"prompt_result://{inner.prompt_name}" + + return None + + @property + def name(self) -> str | None: + """Human-readable name (tool name, resource name, prompt name). + + Returns: + Name string or None if not applicable. + """ + inner = self._inner + + if self._kind in (ViewKind.TOOL_CALL, ViewKind.PROMPT_REQUEST): + return inner.name + + if self._kind in (ViewKind.RESOURCE, ViewKind.RESOURCE_REF): + return inner.name + + if self._kind == ViewKind.TOOL_RESULT: + return inner.tool_name + + if self._kind == ViewKind.PROMPT_RESULT: + return inner.prompt_name + + return None + + @property + def action(self) -> ViewAction: + """The semantic action this view represents. + + For content kinds like text and media, the action depends on + the message role: SEND for user/system/developer input, + GENERATE for assistant output, RECEIVE for tool output. + + Returns: + A ViewAction value. + """ + fixed = _ACTION_MAP.get(self._kind) + if fixed is not None: + return fixed + role = self._message.role + if role == Role.ASSISTANT: + return ViewAction.GENERATE + if role == Role.TOOL: + return ViewAction.RECEIVE + return ViewAction.SEND + + @property + def args(self) -> dict[str, Any] | None: + """Arguments dict for tool calls and prompt requests. 
+ + Returns: + Arguments dict or None for other content types. + """ + inner = self._inner + if self._kind == ViewKind.TOOL_CALL: + return inner.arguments + if self._kind == ViewKind.PROMPT_REQUEST: + return inner.arguments + return None + + @property + def mime_type(self) -> str | None: + """MIME type if applicable. + + Returns: + MIME type string or None. + """ + inner = self._inner + if self._kind == ViewKind.RESOURCE: + return inner.mime_type + if self._kind in (ViewKind.IMAGE, ViewKind.VIDEO, ViewKind.AUDIO, ViewKind.DOCUMENT): + return inner.media_type + return None + + @property + def size_bytes(self) -> int | None: + """Content size in bytes (computed from content). + + Returns: + Size in bytes or None. + """ + if self._kind == ViewKind.RESOURCE: + res: Resource = self._inner + if res.size_bytes is not None: + return res.size_bytes + if res.content: + return len(res.content.encode("utf-8")) + if res.blob: + return len(res.blob) + return None + + text = self.content + if text is not None: + return len(text.encode("utf-8")) + return None + + @property + def properties(self) -> dict[str, Any]: + """Type-specific properties as a dict. + + For single property access, prefer get_property() which + avoids allocating a dict. + + Returns: + Dict of property name to value for this view's kind. 
+ """ + props: dict[str, Any] = {} + inner = self._inner + + if self._kind == ViewKind.RESOURCE: + props["resource_type"] = inner.resource_type.value + props["version"] = inner.version + props["annotations"] = inner.annotations + + elif self._kind == ViewKind.TOOL_CALL: + props["namespace"] = inner.namespace + props["tool_id"] = inner.tool_call_id + + elif self._kind == ViewKind.TOOL_RESULT: + props["is_error"] = inner.is_error + props["tool_name"] = inner.tool_name + + elif self._kind == ViewKind.PROMPT_REQUEST: + props["server_id"] = inner.server_id + + elif self._kind == ViewKind.PROMPT_RESULT: + props["is_error"] = inner.is_error + props["message_count"] = len(inner.messages) if inner.messages else 0 + + return props + + def get_property(self, name: str, default: Any = None) -> Any: + """Get a single type-specific property without allocating a dict. + + Args: + name: Property name to retrieve. + default: Value to return if property doesn't exist. + + Returns: + The property value or default. 
+ """ + inner = self._inner + + if self._kind == ViewKind.RESOURCE: + if name == "resource_type": + return inner.resource_type.value + if name == "version": + return inner.version + if name == "annotations": + return inner.annotations + + elif self._kind == ViewKind.TOOL_CALL: + if name == "namespace": + return inner.namespace + if name == "tool_id": + return inner.tool_call_id + + elif self._kind == ViewKind.TOOL_RESULT: + if name == "is_error": + return inner.is_error + if name == "tool_name": + return inner.tool_name + + elif self._kind == ViewKind.PROMPT_REQUEST: + if name == "server_id": + return inner.server_id + + elif self._kind == ViewKind.PROMPT_RESULT: + if name == "is_error": + return inner.is_error + if name == "message_count": + return len(inner.messages) if inner.messages else 0 + + return default + + # ========================================================================= + # Direction + # ========================================================================= + + @property + def is_pre(self) -> bool: + """True if this represents input/request content (before processing). + + Determined by ViewKind for requests/responses, and by Role + for text, thinking, and media content. + + Returns: + True if this is pre-processing content. + """ + if self._kind in (ViewKind.TOOL_CALL, ViewKind.PROMPT_REQUEST, ViewKind.RESOURCE_REF): + return True + if self._kind in (ViewKind.TOOL_RESULT, ViewKind.PROMPT_RESULT, ViewKind.RESOURCE): + return False + return self._message.role in (Role.USER, Role.SYSTEM, Role.DEVELOPER) + + @property + def is_post(self) -> bool: + """True if this represents output/response content (after processing). + + Returns: + True if this is post-processing content. 
def _ext(self) -> Any:
    """Return the parent message's ``extensions`` attribute.

    Returns:
        Whatever the parent message holds in ``extensions``; may be None.
    """
    parent = self._message
    return parent.extensions
@property
def headers(self) -> Mapping[str, str]:
    """Request HTTP headers exposed as a read-only mapping.

    Capability: read_headers.

    Returns:
        A MappingProxyType over the http extension's headers, or an
        empty proxy when the message has no extensions or no http
        extension.
    """
    extensions = self._ext()
    if not extensions or not extensions.http:
        return MappingProxyType({})
    return MappingProxyType(extensions.http.headers)
+ """ + ext = self._ext() + if ext and ext.security: + return ext.security.labels + return frozenset() + + # --- read_agent --- + + @property + def agent_input(self) -> str | None: + """Original user intent that triggered this action. + + Capability: read_agent. + + Returns: + Input string or None. + """ + ext = self._ext() + if ext and ext.agent: + return ext.agent.input + return None + + @property + def session_id(self) -> str | None: + """Broad session identifier. + + Capability: read_agent. + + Returns: + Session ID string or None. + """ + ext = self._ext() + if ext and ext.agent: + return ext.agent.session_id + return None + + @property + def conversation_id(self) -> str | None: + """Specific dialogue/task within a session. + + Capability: read_agent. + + Returns: + Conversation ID string or None. + """ + ext = self._ext() + if ext and ext.agent: + return ext.agent.conversation_id + return None + + @property + def turn(self) -> int | None: + """Position in conversation (0-indexed). + + Capability: read_agent. + + Returns: + Turn number or None. + """ + ext = self._ext() + if ext and ext.agent: + return ext.agent.turn + return None + + @property + def agent_id(self) -> str | None: + """Identifier of the producing agent. + + Capability: read_agent. + + Returns: + Agent ID string or None. + """ + ext = self._ext() + if ext and ext.agent: + return ext.agent.agent_id + return None + + @property + def parent_agent_id(self) -> str | None: + """Spawning agent's ID (multi-agent lineage). + + Capability: read_agent. + + Returns: + Parent agent ID string or None. + """ + ext = self._ext() + if ext and ext.agent: + return ext.agent.parent_agent_id + return None + + # --- read_objects --- + + @property + def object(self) -> ObjectSecurityProfile | None: + """Access control profile for this view's entity. + + Resolved by view.name from extensions.security.objects. + + Capability: read_objects. + + Returns: + ObjectSecurityProfile or None. 
+ """ + ext = self._ext() + view_name = self.name + if ext and ext.security and view_name: + return ext.security.objects.get(view_name) + return None + + # --- read_data --- + + @property + def data_policy(self) -> DataPolicy | None: + """Data governance policy for this view's entity. + + Resolved by view.name from extensions.security.data. + + Capability: read_data. + + Returns: + DataPolicy or None. + """ + ext = self._ext() + view_name = self.name + if ext and ext.security and view_name: + return ext.security.data.get(view_name) + return None + + # ========================================================================= + # Helper Methods + # ========================================================================= + + def has_role(self, role: str) -> bool: + """Check if subject has a specific role. + + Args: + role: The role to check for. + + Returns: + True if the subject has the role. + """ + return role in self.roles + + def has_permission(self, perm: str) -> bool: + """Check if subject has a specific permission. + + Args: + perm: The permission to check for. + + Returns: + True if the subject has the permission. + """ + return perm in self.permissions + + def has_label(self, label: str) -> bool: + """Check if a security label is present. + + Args: + label: Label to check for (e.g., "PII", "SECRET"). + + Returns: + True if the label is present. + """ + return label in self.labels + + def has_header(self, name: str) -> bool: + """Check if an HTTP header exists (case-insensitive). + + Args: + name: Header name to check. + + Returns: + True if header exists. + """ + return self.get_header(name) is not None + + def get_header(self, name: str, default: str | None = None) -> str | None: + """Get an HTTP header value (case-insensitive). + + Args: + name: Header name. + default: Default value if header not found. + + Returns: + Header value or default. 
def matches_uri_pattern(self, pattern: str) -> bool:
    """Check if URI matches a glob-style pattern.

    Supports * (single segment) and ** (any number of segments)
    wildcards. Every other character in the pattern is matched
    literally, and the whole URI must match the pattern.

    Args:
        pattern: Glob pattern to match against.

    Returns:
        True if URI matches the pattern; False when the view has no URI.
    """
    view_uri = self.uri
    if not view_uri:
        return False
    # Split on ** first, then on * within each chunk, escaping the literal
    # pieces: ** may cross '/' boundaries, * stays within one segment.
    regex = ".*".join(
        "[^/]*".join(re.escape(literal) for literal in chunk.split("*"))
        for chunk in pattern.split("**")
    )
    # fullmatch anchors both ends exactly; re.match with a trailing '$'
    # would also accept a URI that ends with an extra newline.
    return re.fullmatch(regex, view_uri) is not None
+ + Args: + include_content: Include text content (may be large). + include_context: Include extensions context. + + Returns: + JSON-serializable dictionary with view properties. + """ + result: dict[str, Any] = { + "kind": self._kind.value, + "role": self._message.role.value, + "is_pre": self.is_pre, + "is_post": self.is_post, + "action": self.action.value, + } + + if self._hook is not None: + result["hook"] = self._hook + + if self.uri: + result["uri"] = self.uri + if self.name: + result["name"] = self.name + + if include_content: + text = self.content + if text is not None: + result["content"] = text + result["size_bytes"] = len(text.encode("utf-8")) + else: + size = self.size_bytes + if size is not None: + result["size_bytes"] = size + + if self.mime_type: + result["mime_type"] = self.mime_type + + args = self.args + if args is not None: + result["arguments"] = args + + props = self.properties + if props: + result["properties"] = props + + if include_context: + extensions: dict[str, Any] = {} + + # Subject + s = self.subject + if s: + extensions["subject"] = { + "id": s.id, + "type": s.type.value, + "roles": sorted(s.roles), + "permissions": sorted(s.permissions), + "teams": sorted(s.teams), + } + + # Environment + env = self.environment + if env: + extensions["environment"] = env + + # Labels + lbls = self.labels + if lbls: + extensions["labels"] = sorted(lbls) + + # Headers (strip sensitive) + hdrs = self.headers + if hdrs: + safe = {k: v for k, v in hdrs.items() if k.lower() not in _SENSITIVE_HEADERS} + if safe: + extensions["headers"] = safe + + # Object profile (for pre views) + obj = self.object + if obj: + extensions["object"] = { + "managed_by": obj.managed_by, + "permissions": obj.permissions, + "trust_domain": obj.trust_domain, + "data_scope": obj.data_scope, + } + + # Data policy (for post views) + dp = self.data_policy + if dp: + dp_dict: dict[str, Any] = { + "apply_labels": dp.apply_labels, + "denied_actions": dp.denied_actions, + } + if 
def to_opa_input(self, include_content: bool = True) -> dict[str, Any]:
    """Serialize the view into the OPA input envelope.

    The result of ``to_dict`` is placed under a single ``"input"`` key:
    {"input": {...view data...}}.

    Args:
        include_content: Forwarded to ``to_dict``; include text content
            in the input.

    Returns:
        Dict in OPA input format.
    """
    view_data = self.to_dict(include_content=include_content)
    return {"input": view_data}
def iter_views(message: Message, hook: str | None = None) -> Iterator[MessageView]:
    """Lazily decompose a message into one MessageView per content part.

    This is a generator: each view is created only when requested and is
    built from the content part, its resolved view kind, and the parent
    message. ``Message.iter_views()`` delegates to this standalone
    function.

    Args:
        message: The message whose content parts are turned into views.
        hook: Optional hook location string (e.g., "llm_input",
            "tool_post_invoke") passed to every view that is created.

    Yields:
        A MessageView for each content part, in content order.

    Raises:
        ValueError: If a part's ``content_type`` has no entry in the
            content-type-to-view-kind mapping. A warning is logged
            immediately before the exception is raised.

    Examples:
        >>> from cpex.framework.cmf.message import (
        ...     Message, Role, TextContent, ToolCall, ToolCallContentPart,
        ...     ThinkingContent,
        ... )
        >>> msg = Message(
        ...     role=Role.ASSISTANT,
        ...     content=[
        ...         ThinkingContent(text="User wants admin users."),
        ...         TextContent(text="Let me look that up."),
        ...         ToolCallContentPart(
        ...             content=ToolCall(
        ...                 tool_call_id="tc_001",
        ...                 name="execute_sql",
        ...                 arguments={"query": "SELECT * FROM users"},
        ...             ),
        ...         ),
        ...     ],
        ... )
        >>> views = list(iter_views(msg))
        >>> len(views)
        3
        >>> [(v.kind.value, v.is_pre) for v in views]
        [('thinking', False), ('text', False), ('tool_call', True)]
    """
    for content_part in message.content:
        view_kind = _CONTENT_TYPE_TO_VIEW_KIND.get(content_part.content_type)
        if view_kind is not None:
            # Known content type: hand out the view and move to the next part.
            yield MessageView(content_part, view_kind, message, hook=hook)
            continue
        # Unmapped content type: log it, then abort the iteration.
        logger.warning("Unknown content type %r in iter_views", content_part.content_type)
        raise ValueError(f"Unknown content type: {content_part.content_type!r}")
# Public API of the extensions package, kept in alphabetical order so that
# additions have one obvious position and duplicates are easy to spot.
__all__ = [
    "AccessPolicy",
    "AgentExtension",
    "Capability",
    "CompletionExtension",
    "ConversationContext",
    "DataPolicy",
    "Extensions",
    "FrameworkExtension",
    "HttpExtension",
    "LLMExtension",
    "MCPExtension",
    "MutabilityTier",
    "ObjectSecurityProfile",
    "PromptMetadata",
    "ProvenanceExtension",
    "RequestExtension",
    "ResourceMetadata",
    "RetentionPolicy",
    "SecurityExtension",
    "SlotName",
    "StopReason",
    "SubjectExtension",
    "SubjectType",
    "TierViolationError",
    "TokenUsage",
    "ToolMetadata",
]
class ConversationContext(BaseModel):
    """Lightweight window onto the earlier turns of a conversation.

    Gives agent-aware consumers a view of prior dialogue without
    requiring access to the full message store.

    Attributes:
        history: Recent turns kept in the window.
        summary: Condensed description of prior context, if available.
        topics: Topics or intents extracted from the conversation.

    Examples:
        >>> ctx = ConversationContext(
        ...     summary="User asked about quarterly revenue.",
        ...     topics=["revenue", "Q4"],
        ... )
        >>> ctx.summary
        'User asked about quarterly revenue.'
        >>> ctx.topics
        ['revenue', 'Q4']
    """

    # Frozen model: fields cannot be reassigned after construction.
    model_config = ConfigDict(frozen=True)

    history: list[Any] = Field(default_factory=list, description="Windowed message history (recent turns).")
    summary: str | None = Field(
        default=None,
        description="Summarized prior context.",
    )
    topics: list[str] = Field(
        default_factory=list,
        description="Extracted topics or intents.",
    )
class StopReason(str, Enum):
    """Closed-set enumeration of completion stop reasons.

    Members subclass ``str``, so each one compares equal to its raw
    value (``StopReason.END == "end"``) and can be looked up from that
    value with ``StopReason("end")``.

    Attributes:
        END: Natural end of generation.
        RETURN: Model returned a structured result.
        CALL: Model made a tool call.
        MAX_TOKENS: Generation stopped due to token limit.
        STOP_SEQUENCE: Generation stopped at a stop sequence.

    Examples:
        >>> StopReason.END
        <StopReason.END: 'end'>
        >>> StopReason("max_tokens")
        <StopReason.MAX_TOKENS: 'max_tokens'>
    """

    END = "end"
    RETURN = "return"
    CALL = "call"
    MAX_TOKENS = "max_tokens"
    STOP_SEQUENCE = "stop_sequence"
class SlotName(str, Enum):
    """Canonical slot names for extension fields and sub-fields.

    Top-level names correspond to Extensions model attributes.
    Dotted names represent nested sub-fields (e.g., security.subject.roles).

    Members subclass ``str``, so they compare equal to their raw values,
    and ``str(member)`` returns that raw value (see ``__str__``).
    """

    REQUEST = "request"
    PROVENANCE = "provenance"
    COMPLETION = "completion"
    LLM = "llm"
    FRAMEWORK = "framework"
    MCP = "mcp"
    AGENT = "agent"
    HTTP = "http"
    CUSTOM = "custom"

    # Security sub-fields
    SECURITY_SUBJECT = "security.subject"
    SECURITY_SUBJECT_ROLES = "security.subject.roles"
    SECURITY_SUBJECT_TEAMS = "security.subject.teams"
    SECURITY_SUBJECT_CLAIMS = "security.subject.claims"
    SECURITY_SUBJECT_PERMISSIONS = "security.subject.permissions"
    SECURITY_OBJECTS = "security.objects"
    SECURITY_DATA = "security.data"
    SECURITY_LABELS = "security.labels"

    def __str__(self) -> str:
        """Return the enum value as a plain string.

        This class is a ``(str, Enum)`` mixin rather than a ``StrEnum``,
        so without this override it would inherit ``Enum.__str__``,
        which renders a member as ``SlotName.MEMBER``. Overriding it
        makes ``str(member)`` and f-string interpolation produce the
        slot path itself (e.g. ``"security.subject.roles"``).

        Returns:
            The raw string value of the enum member.
        """
        return self.value
# Extensions model fields
# One constant per attribute declared on the Extensions container model.
FIELD_REQUEST: str = "request"
FIELD_PROVENANCE: str = "provenance"
FIELD_COMPLETION: str = "completion"
FIELD_LLM: str = "llm"
FIELD_FRAMEWORK: str = "framework"
FIELD_MCP: str = "mcp"
FIELD_AGENT: str = "agent"
FIELD_HTTP: str = "http"
FIELD_CUSTOM: str = "custom"
FIELD_SECURITY: str = "security"

# SecurityExtension model fields
# NOTE(review): presumably mirrors the SecurityExtension attribute names;
# that model is not visible from here, so confirm against its definition.
FIELD_LABELS: str = "labels"
FIELD_CLASSIFICATION: str = "classification"
FIELD_SUBJECT: str = "subject"
FIELD_OBJECTS: str = "objects"
FIELD_DATA: str = "data"

# SubjectExtension model fields
# Names of the SubjectExtension attributes that have their own SlotName
# sub-field entries (security.subject.roles / teams / claims / permissions).
FIELD_ROLES: str = "roles"
FIELD_TEAMS: str = "teams"
FIELD_CLAIMS: str = "claims"
FIELD_PERMISSIONS: str = "permissions"
class Extensions(BaseModel):
    """Single container holding every typed extension of a message.

    Each slot is optional and defaults to ``None``. Every slot carries
    contextual metadata with a documented mutability tier that the
    processing pipeline enforces during copy-on-write operations.

    The model is frozen, so a modified container is obtained with
    ``model_copy(update={...})`` instead of attribute assignment.

    Attributes:
        request: Execution environment, request ID, timestamp, tracing (immutable).
        agent: Session tracking, multi-agent lineage, user intent (immutable).
        http: HTTP headers with capability-gated access (guarded).
        security: Labels, classification, identity, access control, data policy (monotonic/immutable).
        mcp: Tool, resource, or prompt metadata (immutable).
        completion: Stop reason, token usage, model, latency (immutable).
        provenance: Source, message ID, parent ID (immutable).
        llm: Model identity and capabilities (immutable).
        framework: Agentic framework context (immutable).
        custom: Custom extensions (mutable).

    Examples:
        >>> ext = Extensions(
        ...     request=RequestExtension(
        ...         environment="production",
        ...         request_id="req-001",
        ...     ),
        ...     llm=LLMExtension(
        ...         model_id="gpt-4o",
        ...         provider="openai",
        ...     ),
        ... )
        >>> ext.request.environment
        'production'
        >>> ext.llm.provider
        'openai'
        >>> ext.security is None
        True

        >>> # Frozen: modifications require model_copy
        >>> updated = ext.model_copy(update={"custom": {"trace": True}})
        >>> updated.custom
        {'trace': True}
        >>> ext.custom is None
        True
    """

    # Frozen model: slots cannot be reassigned on an existing instance.
    model_config = ConfigDict(frozen=True)

    request: RequestExtension | None = Field(
        default=None,
        description="Execution environment and tracing.",
    )
    agent: AgentExtension | None = Field(
        default=None,
        description="Agent execution context.",
    )
    http: HttpExtension | None = Field(
        default=None,
        description="HTTP request context.",
    )
    security: SecurityExtension | None = Field(
        default=None,
        description="Security labels and identity.",
    )
    mcp: MCPExtension | None = Field(
        default=None,
        description="MCP entity metadata.",
    )
    completion: CompletionExtension | None = Field(
        default=None,
        description="LLM completion information.",
    )
    provenance: ProvenanceExtension | None = Field(
        default=None,
        description="Origin and threading.",
    )
    llm: LLMExtension | None = Field(
        default=None,
        description="Model identity and capabilities.",
    )
    framework: FrameworkExtension | None = Field(
        default=None,
        description="Agentic framework context.",
    )
    custom: dict[str, Any] | None = Field(
        default=None,
        description="Custom extensions (mutable).",
    )
class FrameworkExtension(BaseModel):
    """Execution context supplied by an agentic framework.

    Holds framework-level metadata for messages that originate from or
    pass through an orchestration layer (LangGraph, CrewAI, AutoGen,
    A2A, etc.). Documented as immutable tier: the processing pipeline
    rejects modifications.

    Attributes:
        framework: Framework identifier (e.g., langgraph, crewai, autogen, a2a).
        framework_version: Framework version.
        node_id: Framework-specific node or step identifier.
        graph_id: Graph or workflow identifier.
        metadata: Framework-specific metadata; defaults to an empty dict.

    Examples:
        >>> ext = FrameworkExtension(
        ...     framework="langgraph",
        ...     framework_version="0.2.0",
        ...     node_id="weather_node",
        ...     graph_id="travel_planner",
        ... )
        >>> ext.framework
        'langgraph'
        >>> ext.node_id
        'weather_node'
    """

    # Frozen model: fields cannot be reassigned after construction.
    model_config = ConfigDict(frozen=True)

    framework: str | None = Field(
        default=None,
        description="Framework identifier.",
    )
    framework_version: str | None = Field(
        default=None,
        description="Framework version.",
    )
    node_id: str | None = Field(
        default=None,
        description="Framework-specific node or step identifier.",
    )
    graph_id: str | None = Field(
        default=None,
        description="Graph or workflow identifier.",
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Framework-specific metadata.",
    )
class HttpExtension(BaseModel):
    """HTTP context of the request that carried the message.

    Documented as guarded tier: reading requires the read_headers
    capability and the processing pipeline rejects modifications unless
    the consumer holds write_headers. Sensitive headers (e.g.
    Authorization, Cookie, X-API-Key) are stripped when the headers are
    serialized for external policy engines.

    Attributes:
        headers: HTTP headers as key-value pairs; defaults to an empty dict.

    Examples:
        >>> ext = HttpExtension(
        ...     headers={"Content-Type": "application/json", "X-Request-ID": "req-123"},
        ... )
        >>> ext.headers["Content-Type"]
        'application/json'

        >>> # Frozen: modifications require model_copy
        >>> updated = ext.model_copy(
        ...     update={"headers": {**ext.headers, "X-Trace-ID": "trace-456"}},
        ... )
        >>> "X-Trace-ID" in updated.headers
        True
    """

    # Frozen model: the headers field cannot be reassigned on an instance.
    model_config = ConfigDict(frozen=True)

    headers: dict[str, str] = Field(
        default_factory=dict,
        description="HTTP headers as key-value pairs.",
    )
class ToolMetadata(BaseModel):
    """Descriptive metadata for a single MCP tool.

    Only ``name`` is required; every other field is optional.

    Attributes:
        name: Unique tool identifier.
        title: Human-readable display name.
        description: Description of tool functionality.
        input_schema: JSON Schema defining expected parameters.
        output_schema: JSON Schema for structured output.
        server_id: ID of the server providing this tool.
        namespace: Tool namespace (server/origin).
        annotations: MCP annotations (e.g., readOnlyHint, destructiveHint);
            defaults to an empty dict.

    Examples:
        >>> meta = ToolMetadata(
        ...     name="get_user",
        ...     description="Retrieve user by ID",
        ...     input_schema={"type": "object", "properties": {"id": {"type": "string"}}},
        ...     server_id="user-service",
        ... )
        >>> meta.name
        'get_user'
        >>> meta.server_id
        'user-service'
    """

    # Frozen model: fields cannot be reassigned after construction.
    model_config = ConfigDict(frozen=True)

    name: str = Field(
        description="Unique tool identifier.",
    )
    title: str | None = Field(
        default=None,
        description="Human-readable display name.",
    )
    description: str | None = Field(
        default=None,
        description="Description of tool functionality.",
    )
    input_schema: dict[str, Any] | None = Field(
        default=None,
        description="JSON Schema defining expected parameters.",
    )
    output_schema: dict[str, Any] | None = Field(
        default=None,
        description="JSON Schema for structured output.",
    )
    server_id: str | None = Field(
        default=None,
        description="ID of the server providing this tool.",
    )
    namespace: str | None = Field(
        default=None,
        description="Tool namespace (server/origin).",
    )
    annotations: dict[str, Any] = Field(
        default_factory=dict,
        description="MCP annotations.",
    )
class PromptMetadata(BaseModel):
    """Descriptive metadata for an MCP prompt template.

    Following the MCP prompt specification, input is described by a
    list of argument definitions rather than a JSON Schema. There is no
    output schema because prompt output is always rendered messages.

    Attributes:
        name: Prompt template name (required).
        description: Prompt description.
        arguments: Argument definitions (each has name, description, required).
        server_id: ID of the server providing this prompt.
        annotations: MCP annotations; defaults to an empty dict.

    Examples:
        >>> meta = PromptMetadata(
        ...     name="summarize",
        ...     description="Summarize a document",
        ...     arguments=[
        ...         {"name": "text", "description": "Text to summarize", "required": True},
        ...     ],
        ... )
        >>> meta.name
        'summarize'
        >>> meta.arguments[0]["name"]
        'text'
    """

    # Frozen model: fields cannot be reassigned after construction.
    model_config = ConfigDict(frozen=True)

    name: str = Field(
        description="Prompt template name.",
    )
    description: str | None = Field(
        default=None,
        description="Prompt description.",
    )
    arguments: list[dict[str, Any]] | None = Field(
        default=None,
        description="Argument definitions.",
    )
    server_id: str | None = Field(
        default=None,
        description="ID of the server providing this prompt.",
    )
    annotations: dict[str, Any] = Field(
        default_factory=dict,
        description="MCP annotations.",
    )
class ProvenanceExtension(BaseModel):
    """Where a message came from and which message it replies to.

    Supports lineage tracking across multi-turn conversations and
    multi-agent systems. Documented as immutable tier: the processing
    pipeline rejects modifications.

    Attributes:
        source: Source identifier (e.g., "user", "agent:xyz", "mcp-server:abc").
        message_id: Unique message identifier.
        parent_id: Parent message ID (threading/replies).

    Examples:
        >>> ext = ProvenanceExtension(
        ...     source="agent:weather-bot",
        ...     message_id="msg-001",
        ...     parent_id="msg-000",
        ... )
        >>> ext.source
        'agent:weather-bot'
        >>> ext.message_id
        'msg-001'
    """

    # Frozen model: fields cannot be reassigned after construction.
    model_config = ConfigDict(frozen=True)

    source: str | None = Field(
        default=None,
        description="Source identifier.",
    )
    message_id: str | None = Field(
        default=None,
        description="Unique message identifier.",
    )
    parent_id: str | None = Field(
        default=None,
        description="Parent message ID (threading/replies).",
    )
) + >>> ext.environment + 'production' + >>> ext.request_id + 'req-abc-123' + + >>> # Frozen: modifications require model_copy + >>> updated = ext.model_copy(update={"span_id": "span-456"}) + >>> updated.span_id + 'span-456' + >>> ext.span_id is None + True + """ + + model_config = ConfigDict(frozen=True) + + environment: str | None = Field(default=None, description="Execution environment (production, staging, dev).") + request_id: str | None = Field(default=None, description="Request correlation ID.") + timestamp: str | None = Field(default=None, description="ISO timestamp of the request.") + trace_id: str | None = Field(default=None, description="Distributed tracing ID (OpenTelemetry).") + span_id: str | None = Field(default=None, description="Distributed tracing span ID.") diff --git a/cpex/framework/extensions/security.py b/cpex/framework/extensions/security.py new file mode 100644 index 0000000..f13e7d5 --- /dev/null +++ b/cpex/framework/extensions/security.py @@ -0,0 +1,247 @@ +# -*- coding: utf-8 -*- +"""Location: ./cpex/framework/extensions/security.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Security extension models. +Carries data classification, security labels, authenticated identity, +access control profiles, and data governance policies. + +The SecurityExtension itself is monotonic tier — labels can only be +added, never removed, during normal message flow. Its nested fields +(subject, objects, data) are immutable tier. +""" + +# Standard +from enum import Enum + +# Third-Party +from pydantic import BaseModel, ConfigDict, Field + +# --------------------------------------------------------------------------- +# Subject +# --------------------------------------------------------------------------- + + +class SubjectType(str, Enum): + """Closed-set enumeration of subject types. + + Attributes: + USER: Human user. + AGENT: Autonomous agent. + SERVICE: Backend service. + SYSTEM: System-level principal. 
class SubjectExtension(BaseModel):
    """Authenticated entity making the request.

    Access to individual fields is controlled by declared capabilities
    on the MessageView. Immutable — the processing pipeline rejects
    any modifications.

    ``id`` and ``type`` are required; every other field defaults to an
    empty collection.

    Attributes:
        id: Unique subject identifier.
        type: Subject kind.
        roles: Assigned roles (developer, admin, viewer, etc.).
        permissions: Granted permissions (tools.execute, db.read, etc.).
        teams: Team memberships (for multi-tenant scoping).
        claims: Raw identity claims (JWT, SAML).

    Examples:
        >>> subject = SubjectExtension(
        ...     id="user-alice",
        ...     type=SubjectType.USER,
        ...     roles={"admin", "developer"},
        ...     permissions={"tools.execute", "db.read"},
        ... )
        >>> subject.id
        'user-alice'
        >>> "admin" in subject.roles
        True
        >>> "db.read" in subject.permissions
        True
    """

    # frozen=True makes instances read-only at the attribute level.
    # NOTE(review): ``claims`` is declared as a plain dict, so freezing the
    # model does not by itself prevent in-place edits of that dict — confirm
    # that consumers treat it as read-only.
    model_config = ConfigDict(frozen=True)

    id: str = Field(description="Unique subject identifier.")
    type: SubjectType = Field(description="Subject kind.")
    roles: frozenset[str] = Field(default_factory=frozenset, description="Assigned roles.")
    permissions: frozenset[str] = Field(default_factory=frozenset, description="Granted permissions.")
    teams: frozenset[str] = Field(default_factory=frozenset, description="Team memberships.")
    claims: dict[str, str] = Field(default_factory=dict, description="Raw identity claims (JWT, SAML).")
class RetentionPolicy(BaseModel):
    """Data retention constraints.

    All fields are optional. ``policy`` defaults to ``"persistent"``; the
    other two default to ``None``. This class declares no validators, so
    ``policy`` and ``delete_after`` are accepted as arbitrary strings.

    Attributes:
        max_age_seconds: Maximum retention duration in seconds.
        policy: Retention class: session, transient, persistent, or none.
        delete_after: ISO timestamp after which data must be deleted.

    Examples:
        >>> ret = RetentionPolicy(policy="session", max_age_seconds=3600)
        >>> ret.policy
        'session'
        >>> ret.max_age_seconds
        3600
    """

    # frozen=True makes instances read-only at the attribute level.
    model_config = ConfigDict(frozen=True)

    max_age_seconds: int | None = Field(default=None, description="Maximum retention duration in seconds.")
    policy: str = Field(default="persistent", description="Retention class: session, transient, persistent, none.")
    delete_after: str | None = Field(default=None, description="ISO timestamp after which data must be deleted.")
class SecurityExtension(BaseModel):
    """Data classification, security labels, and security-relevant context.

    Monotonic tier for labels — labels can only be added, never removed,
    during normal message flow. Removal requires a privileged
    declassification operation that is audited separately. The nested
    fields (subject, objects, data) are immutable.

    This class declares no validator for the add-only rule; it only
    freezes attribute assignment. A copy built with
    ``model_copy(update=...)`` receives whatever label set is passed in,
    so the rule has to be checked by the code that accepts the copy.

    Attributes:
        labels: Security/data labels (PII, CONFIDENTIAL, SECRET, etc.).
        classification: Data classification level.
        subject: Authenticated identity.
        objects: Access control profiles, keyed by entity identifier.
        data: Data governance policies, keyed by entity identifier.

    Examples:
        >>> ext = SecurityExtension(
        ...     labels=frozenset({"PII", "CONFIDENTIAL"}),
        ...     classification="confidential",
        ...     subject=SubjectExtension(
        ...         id="user-alice",
        ...         type=SubjectType.USER,
        ...         roles=frozenset({"admin"}),
        ...     ),
        ... )
        >>> "PII" in ext.labels
        True
        >>> ext.subject.id
        'user-alice'

        >>> # Monotonic label addition via model_copy
        >>> updated = ext.model_copy(update={"labels": ext.labels | frozenset({"financial"})})
        >>> "financial" in updated.labels
        True
        >>> "PII" in updated.labels
        True
    """

    # frozen=True makes instances read-only at the attribute level.
    model_config = ConfigDict(frozen=True)

    labels: frozenset[str] = Field(default_factory=frozenset, description="Security/data labels.")
    classification: str | None = Field(default=None, description="Data classification level.")
    subject: SubjectExtension | None = Field(default=None, description="Authenticated identity.")
    objects: dict[str, ObjectSecurityProfile] = Field(
        default_factory=dict, description="Access control profiles, keyed by entity identifier."
    )
    data: dict[str, DataPolicy] = Field(
        default_factory=dict, description="Data governance policies, keyed by entity identifier."
    )
class MutabilityTier(str, Enum):
    """Mutability tier for an extension slot.

    Members are also ``str`` instances, so each one compares equal to its
    string value.

    Attributes:
        IMMUTABLE: Set once, never changed. Pipeline rejects any delta.
        MONOTONIC: Can only grow (add elements). Pipeline validates
            before <= after.
        MUTABLE: Freely modifiable through COW.

    Examples:
        >>> MutabilityTier("monotonic") is MutabilityTier.MONOTONIC
        True
        >>> MutabilityTier.MUTABLE.value
        'mutable'
    """

    IMMUTABLE = "immutable"
    MONOTONIC = "monotonic"
    MUTABLE = "mutable"
class AccessPolicy(str, Enum):
    """Declares whether an extension slot requires capabilities for visibility.

    Members are also ``str`` instances, so each one compares equal to its
    string value.

    Attributes:
        UNRESTRICTED: Visible to all plugins regardless of capabilities.
        CAPABILITY_GATED: Requires a declared capability for visibility.

    Examples:
        >>> AccessPolicy("capability_gated") is AccessPolicy.CAPABILITY_GATED
        True
        >>> AccessPolicy.UNRESTRICTED.value
        'unrestricted'
    """

    UNRESTRICTED = "unrestricted"
    CAPABILITY_GATED = "capability_gated"
+ """ + + tier: MutabilityTier + access: AccessPolicy = AccessPolicy.UNRESTRICTED + read_cap: Capability | None = None + write_cap: Capability | None = None + + +# --------------------------------------------------------------------------- +# Slot Registry — single source of truth (internal only) +# --------------------------------------------------------------------------- + +_SLOT_REGISTRY: dict[str, SlotPolicy] = { + # Unrestricted — always visible, always immutable + SlotName.REQUEST: SlotPolicy(MutabilityTier.IMMUTABLE), + SlotName.PROVENANCE: SlotPolicy(MutabilityTier.IMMUTABLE), + SlotName.COMPLETION: SlotPolicy(MutabilityTier.IMMUTABLE), + SlotName.LLM: SlotPolicy(MutabilityTier.IMMUTABLE), + SlotName.FRAMEWORK: SlotPolicy(MutabilityTier.IMMUTABLE), + SlotName.MCP: SlotPolicy(MutabilityTier.IMMUTABLE), + # Capability-gated, immutable + SlotName.AGENT: SlotPolicy( + MutabilityTier.IMMUTABLE, + access=AccessPolicy.CAPABILITY_GATED, + read_cap=Capability.READ_AGENT, + ), + # Subject — granular sub-field gating + SlotName.SECURITY_SUBJECT: SlotPolicy( + MutabilityTier.IMMUTABLE, + access=AccessPolicy.CAPABILITY_GATED, + read_cap=Capability.READ_SUBJECT, + ), + SlotName.SECURITY_SUBJECT_ROLES: SlotPolicy( + MutabilityTier.IMMUTABLE, + access=AccessPolicy.CAPABILITY_GATED, + read_cap=Capability.READ_ROLES, + ), + SlotName.SECURITY_SUBJECT_TEAMS: SlotPolicy( + MutabilityTier.IMMUTABLE, + access=AccessPolicy.CAPABILITY_GATED, + read_cap=Capability.READ_TEAMS, + ), + SlotName.SECURITY_SUBJECT_CLAIMS: SlotPolicy( + MutabilityTier.IMMUTABLE, + access=AccessPolicy.CAPABILITY_GATED, + read_cap=Capability.READ_CLAIMS, + ), + SlotName.SECURITY_SUBJECT_PERMISSIONS: SlotPolicy( + MutabilityTier.IMMUTABLE, + access=AccessPolicy.CAPABILITY_GATED, + read_cap=Capability.READ_PERMISSIONS, + ), + # Unrestricted — always visible sub-fields + SlotName.SECURITY_OBJECTS: SlotPolicy(MutabilityTier.IMMUTABLE), + SlotName.SECURITY_DATA: SlotPolicy(MutabilityTier.IMMUTABLE), + # Security 
def _has_read_access(policy: SlotPolicy, capabilities: frozenset[str]) -> bool:
    """Check if a plugin has read access to a slot.

    A plugin has read access if:
    - The slot's access policy is UNRESTRICTED (always visible), OR
    - The plugin holds the slot's read_cap, OR
    - The plugin holds a write/append capability that implies the
      read_cap, OR
    - The read_cap is read_subject and the plugin holds any subject
      sub-field capability.

    A capability-gated slot that declares no read_cap is treated as
    hidden: no capability could grant access to it, so the check fails
    closed instead of raising.

    Args:
        policy: The slot policy to evaluate.
        capabilities: The plugin's declared capability strings.

    Returns:
        True if the slot should be visible to the plugin, False otherwise.
    """
    if policy.access == AccessPolicy.UNRESTRICTED:
        return True

    required = policy.read_cap
    if required is None:
        # Gated but nothing names the capability that unlocks it: stay hidden.
        return False

    if required.value in capabilities:
        return True

    # A held write/append capability grants its paired read capability.
    if any(
        implied_read == required and write_cap.value in capabilities
        for write_cap, implied_read in _WRITE_IMPLIES_READ.items()
    ):
        return True

    # Any subject sub-field capability grants base subject visibility.
    if required == Capability.READ_SUBJECT:
        return any(sub_cap.value in capabilities for sub_cap in _SUBJECT_IMPLIES_READ)

    return False
def filter_extensions(
    extensions: Extensions | None,
    capabilities: frozenset[str],
) -> Extensions | None:
    """Build a new Extensions holding only the slots a plugin may read.

    The result is assembled from nothing: a slot appears in it only when
    this function copies it explicitly, so any slot not handled here stays
    at its default in the returned object.

    Unrestricted slots (request, provenance, completion, llm, framework,
    mcp, custom) are copied whenever they are set. ``agent`` and ``http``
    are copied only when ``_has_read_access`` grants them. ``security`` is
    never copied wholesale; it is rebuilt by ``_build_filtered_security``.

    Args:
        extensions: The source Extensions model instance (or None).
        capabilities: Plugin's declared capability strings.

    Returns:
        A new Extensions with only accessible slots populated, or None
        if the input was None.
    """
    if extensions is None:
        return None

    # Slots every plugin may see, in the order they are copied.
    open_slots = (
        (FIELD_REQUEST, extensions.request),
        (FIELD_PROVENANCE, extensions.provenance),
        (FIELD_COMPLETION, extensions.completion),
        (FIELD_LLM, extensions.llm),
        (FIELD_FRAMEWORK, extensions.framework),
        (FIELD_MCP, extensions.mcp),
        (FIELD_CUSTOM, extensions.custom),
    )
    visible: dict[str, Any] = {key: value for key, value in open_slots if value is not None}

    # Slots that additionally need a read capability.
    gated_slots = (
        (FIELD_AGENT, extensions.agent, SlotName.AGENT),
        (FIELD_HTTP, extensions.http, SlotName.HTTP),
    )
    for key, value, slot in gated_slots:
        if value is not None and _has_read_access(_slot_registry[slot], capabilities):
            visible[key] = value

    # Security is filtered sub-field by sub-field.
    source_security = extensions.security
    if source_security is not None:
        visible[FIELD_SECURITY] = _build_filtered_security(source_security, capabilities)

    return Extensions(**visible)
+ + Attributes: + plugin_name: Name of the offending plugin. + slot: The extension slot that was violated. + tier: The mutability tier of the slot. + detail: Description of the violation. + """ + + def __init__( + self, + plugin_name: str, + slot: str, + tier: MutabilityTier, + detail: str, + ) -> None: + """Initialise a tier violation error. + + Args: + plugin_name: Name of the offending plugin. + slot: The extension slot that was violated. + tier: The mutability tier of the slot. + detail: Description of the violation. + """ + self.plugin_name = plugin_name + self.slot = slot + self.tier = tier + self.detail = detail + super().__init__(f"Plugin '{plugin_name}' violated {tier.value} tier on '{slot}': {detail}") + + +def _resolve_slot(ext: Extensions | None, dot_path: str) -> Any: + """Resolve a dot-notation slot path to its value.""" + if ext is None: + return None + obj: Any = ext + for part in dot_path.split("."): + if obj is None: + return None + obj = getattr(obj, part, None) + return obj + + +def _is_monotonic_superset(before: Any, after: Any) -> bool: + """Check that after is a superset of before for monotonic validation.""" + if before is None or (isinstance(before, frozenset) and len(before) == 0): + return True + if after is None: + return isinstance(before, frozenset) and len(before) == 0 + if isinstance(before, frozenset) and isinstance(after, frozenset): + return before <= after + return before == after + + +def validate_tier_constraints( + before: Extensions | None, + after: Extensions | None, + capabilities: frozenset[str], + plugin_name: str, +) -> None: + """Validate that tier constraints were respected after a plugin transform. + + Compares the original (unfiltered) extensions against the modified + extensions. Raises TierViolationError on any violation. + + Args: + before: Original Extensions before plugin execution. + after: Extensions after plugin execution. + capabilities: Plugin's declared capability strings. 
def _merge_security(
    original: SecurityExtension,
    plugin_sec: SecurityExtension | None,
    capabilities: frozenset[str],
    plugin_name: str,
) -> SecurityExtension | None:
    """Fold a plugin's accepted security changes into the original.

    Only ``labels`` is ever read from the plugin's copy; every other
    security field keeps the original value. A label change is accepted
    only when the plugin holds append_labels and the new set passes
    ``_is_monotonic_superset`` against the original.

    Args:
        original: The authoritative security extension.
        plugin_sec: The security extension returned by the plugin.
        capabilities: Plugin's declared capability strings.
        plugin_name: Name of the plugin (for error messages).

    Returns:
        A copy of ``original`` carrying the plugin's labels, or None when
        there is nothing to apply (the caller skips the update).

    Raises:
        TierViolationError: If the plugin holds append_labels and its
            label set drops elements of the original.
    """
    if plugin_sec is None:
        return None
    # Without append_labels the plugin's labels are never consulted.
    if Capability.APPEND_LABELS.value not in capabilities:
        return None

    labels_changed = plugin_sec.labels != original.labels
    if not labels_changed:
        return None

    if _is_monotonic_superset(original.labels, plugin_sec.labels):
        return original.model_copy(update={FIELD_LABELS: plugin_sec.labels})

    raise TierViolationError(
        plugin_name,
        SlotName.SECURITY_LABELS,
        MutabilityTier.MONOTONIC,
        "monotonic slot had elements removed",
    )
+ """ + if original is None: + return None + if plugin_output is None: + return original + + updates: dict[str, Any] = {} + + # HTTP — writable only with write_headers capability + if ( + Capability.WRITE_HEADERS.value in capabilities + and plugin_output.http is not None + and plugin_output.http != original.http + ): + updates[FIELD_HTTP] = plugin_output.http + + # Security — mixed tiers, delegate to helper + if original.security is not None: + merged_sec = _merge_security(original.security, plugin_output.security, capabilities, plugin_name) + if merged_sec is not None: + updates[FIELD_SECURITY] = merged_sec + + # Custom — mutable, no capability gate + if plugin_output.custom != original.custom: + updates[FIELD_CUSTOM] = plugin_output.custom + + if not updates: + return original + + return original.model_copy(update=updates) diff --git a/cpex/framework/external/mcp/client.py b/cpex/framework/external/mcp/client.py index 2060076..6242794 100644 --- a/cpex/framework/external/mcp/client.py +++ b/cpex/framework/external/mcp/client.py @@ -560,6 +560,7 @@ def __init__(self, hook: str, plugin_ref: PluginRef): # pylint: disable=super-i """ self._plugin_ref = plugin_ref self._hook = hook + self._accepts_extensions = False # External plugins use invoke_hook(), not direct method calls if hasattr(plugin_ref.plugin, INVOKE_HOOK): self._func: Callable[[PluginPayload, PluginContext], Awaitable[PluginResult]] = partial( plugin_ref.plugin.invoke_hook, hook diff --git a/cpex/framework/hooks/agents.py b/cpex/framework/hooks/agents.py index f6dc00a..4a1cac4 100644 --- a/cpex/framework/hooks/agents.py +++ b/cpex/framework/hooks/agents.py @@ -10,6 +10,7 @@ """ # Standard +import warnings from enum import Enum from typing import Any, Dict, List, Optional @@ -75,6 +76,20 @@ class AgentPreInvokePayload(PluginPayload): system_prompt: Optional[str] = None parameters: Optional[Dict[str, Any]] = Field(default_factory=dict) + @field_validator("headers", mode="before") + @classmethod + def 
_warn_headers_deprecated(cls, v: object) -> object: + """Emit deprecation warning for headers field.""" + if v is not None: + warnings.warn( + "AgentPreInvokePayload.headers is deprecated; " + "use extensions.http.headers instead. " + "This field will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) + return v + @field_validator("messages", mode="before") @classmethod def _coerce_messages(cls, v: Any) -> Any: diff --git a/cpex/framework/hooks/http.py b/cpex/framework/hooks/http.py index e33f763..ff719c3 100644 --- a/cpex/framework/hooks/http.py +++ b/cpex/framework/hooks/http.py @@ -8,16 +8,17 @@ """ # Standard +import warnings from enum import Enum # Third-Party -from pydantic import RootModel +from pydantic import RootModel, field_validator # First-Party from cpex.framework.models import PluginPayload, PluginResult -class HttpHeaderPayload(RootModel[dict[str, str]], PluginPayload): +class HttpHeaderPayload(RootModel[dict[str, str]]): """An HTTP dictionary of headers used in the pre/post HTTP forwarding hooks.""" def __iter__(self): # type: ignore[no-untyped-def] @@ -103,6 +104,20 @@ class HttpPreRequestPayload(PluginPayload): client_port: int | None = None headers: HttpHeaderPayload + @field_validator("headers", mode="before") + @classmethod + def _warn_headers_deprecated(cls, v: object) -> object: + """Emit deprecation warning for headers field.""" + if v is not None: + warnings.warn( + "HttpPreRequestPayload.headers is deprecated; " + "use extensions.http.headers instead. " + "This field will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) + return v + class HttpPostRequestPayload(HttpPreRequestPayload): """Payload for HTTP post-request hook (middleware layer). 
@@ -139,6 +154,20 @@ class HttpAuthResolveUserPayload(PluginPayload): client_host: str | None = None client_port: int | None = None + @field_validator("headers", mode="before") + @classmethod + def _warn_headers_deprecated(cls, v: object) -> object: + """Emit deprecation warning for headers field.""" + if v is not None: + warnings.warn( + "HttpAuthResolveUserPayload.headers is deprecated; " + "use extensions.http.headers instead. " + "This field will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) + return v + class HttpAuthCheckPermissionPayload(PluginPayload): """Payload for permission checking hook (RBAC layer). diff --git a/cpex/framework/hooks/message.py b/cpex/framework/hooks/message.py new file mode 100644 index 0000000..b34c706 --- /dev/null +++ b/cpex/framework/hooks/message.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +"""Location: ./cpex/framework/hooks/message.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Hook definitions for CMF Message evaluation. + +Provides a unified entry point for policy evaluation on messages +flowing through the system. Plugins receive a MessagePayload +wrapping the CMF Message and can use Message.iter_views() for +granular per-content-part inspection. +""" + +# Standard +from enum import Enum + +# Third-Party +from pydantic import Field + +# First-Party +from cpex.framework.cmf.message import Message +from cpex.framework.models import PluginPayload, PluginResult + + +class MessageHookType(str, Enum): + """Message hook points. + + The hook type indicates *where* in the pipeline the evaluation + is happening, enabling plugins to register for specific locations. + + Attributes: + EVALUATE: Generic message evaluation. + LLM_INPUT: Before model/LLM call (user messages going to LLM). + LLM_OUTPUT: After model/LLM call (LLM response). + TOOL_PRE_INVOKE: Before tool execution (tool call arguments). + TOOL_POST_INVOKE: After tool execution (tool result). 
+ PROMPT_PRE_FETCH: Before prompt template fetch. + PROMPT_POST_FETCH: After prompt template fetch. + RESOURCE_PRE_FETCH: Before resource fetch. + RESOURCE_POST_FETCH: After resource fetch. + + Examples: + >>> MessageHookType.EVALUATE + <MessageHookType.EVALUATE: 'evaluate'> + >>> MessageHookType.LLM_INPUT + <MessageHookType.LLM_INPUT: 'llm_input'> + """ + + EVALUATE = "evaluate" + LLM_INPUT = "llm_input" + LLM_OUTPUT = "llm_output" + TOOL_PRE_INVOKE = "tool_pre_invoke" + TOOL_POST_INVOKE = "tool_post_invoke" + PROMPT_PRE_FETCH = "prompt_pre_fetch" + PROMPT_POST_FETCH = "prompt_post_fetch" + RESOURCE_PRE_FETCH = "resource_pre_fetch" + RESOURCE_POST_FETCH = "resource_post_fetch" + + +class MessagePayload(PluginPayload): + """Payload for message evaluation hooks. + + Wraps a CMF Message for processing through the plugin pipeline. + Plugins access the message and use iter_views() for per-content-part + policy evaluation. + + Attributes: + message: The CMF message to evaluate. + hook: The hook location where this evaluation is happening. + + Examples: + >>> from cpex.framework.cmf.message import Message, Role, TextContent + >>> msg = Message( + ... role=Role.USER, + ... content=[TextContent(text="Hello")], + ... ) + >>> payload = MessagePayload( + ... message=msg, hook=MessageHookType.LLM_INPUT + ... ) + >>> payload.hook + <MessageHookType.LLM_INPUT: 'llm_input'> + """ + + message: Message = Field(description="The CMF message to evaluate.") + hook: MessageHookType = Field( + default=MessageHookType.EVALUATE, + description="The hook location where this evaluation is happening.", + ) + + +MessageResult = PluginResult[MessagePayload] +"""Result type for message evaluation hooks.""" + + +def _register_message_hooks() -> None: + """Register message hooks in the global registry. + + Called at module load time. Idempotent — skips registration + if the hook is already registered. 
+ """ + # First-Party + from cpex.framework.hooks.registry import get_hook_registry # pylint: disable=import-outside-toplevel + + registry = get_hook_registry() + + if not registry.is_registered(MessageHookType.EVALUATE): + registry.register_hook(MessageHookType.EVALUATE, MessagePayload, MessageResult) + + +_register_message_hooks() diff --git a/cpex/framework/hooks/tools.py b/cpex/framework/hooks/tools.py index 8cd5aa6..7b53734 100644 --- a/cpex/framework/hooks/tools.py +++ b/cpex/framework/hooks/tools.py @@ -8,11 +8,12 @@ """ # Standard +import warnings from enum import Enum from typing import Any, Optional # Third-Party -from pydantic import Field +from pydantic import Field, field_validator # First-Party from cpex.framework.hooks.http import HttpHeaderPayload @@ -70,6 +71,20 @@ class ToolPreInvokePayload(PluginPayload): args: Optional[dict[str, Any]] = Field(default_factory=dict) headers: Optional[HttpHeaderPayload] = None + @field_validator("headers", mode="before") + @classmethod + def _warn_headers_deprecated(cls, v: object) -> object: + """Emit deprecation warning for headers field.""" + if v is not None: + warnings.warn( + "ToolPreInvokePayload.headers is deprecated; " + "use extensions.http.headers instead. " + "This field will be removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) + return v + class ToolPostInvokePayload(PluginPayload): """A tool payload for a tool post-invoke hook. diff --git a/cpex/framework/loader/plugin.py b/cpex/framework/loader/plugin.py index 48248f6..7d65d23 100644 --- a/cpex/framework/loader/plugin.py +++ b/cpex/framework/loader/plugin.py @@ -92,6 +92,9 @@ def __register_plugin_type(self, kind: str) -> None: async def load_and_instantiate_plugin(self, config: PluginConfig) -> Plugin | None: """Load and instantiate a plugin, given a configuration. + The plugin receives a defensive copy of the config so it cannot + modify the authoritative config retained by the Manager/PluginRef. 
+ For external plugins, the transport type is determined by the presence of 'mcp', 'grpc', or 'unix_socket' configuration: - If 'grpc' is set, uses GrpcExternalPlugin for gRPC transport @@ -107,6 +110,9 @@ async def load_and_instantiate_plugin(self, config: PluginConfig) -> Plugin | No Raises: ValueError: If an external plugin has no transport configured. """ + # Defensive deep copy — nested lists/dicts must not be shared with the authoritative config + plugin_config = config.model_copy(deep=True) + # Handle external plugins with transport selection if config.kind == EXTERNAL_PLUGIN_TYPE: plugin: Plugin @@ -118,7 +124,7 @@ async def load_and_instantiate_plugin(self, config: PluginConfig) -> Plugin | No GrpcExternalPlugin, ) # pylint: disable=import-outside-toplevel - plugin = GrpcExternalPlugin(config) + plugin = GrpcExternalPlugin(plugin_config) logger.info("Loading external plugin '%s' with gRPC transport", config.name) elif config.unix_socket: # Use raw Unix socket transport (high-performance local IPC) @@ -127,11 +133,11 @@ async def load_and_instantiate_plugin(self, config: PluginConfig) -> Plugin | No UnixSocketExternalPlugin, ) # pylint: disable=import-outside-toplevel - plugin = UnixSocketExternalPlugin(config) + plugin = UnixSocketExternalPlugin(plugin_config) logger.info("Loading external plugin '%s' with Unix socket transport", config.name) elif config.mcp: # Use MCP transport - plugin = ExternalPlugin(config) + plugin = ExternalPlugin(plugin_config) logger.info("Loading external plugin '%s' with MCP transport", config.name) else: # Defensive fallback: PluginConfig validation should prevent this path. 
@@ -147,7 +153,7 @@ async def load_and_instantiate_plugin(self, config: PluginConfig) -> Plugin | No self.__register_plugin_type(config.kind) plugin_type = self._plugin_types[config.kind] if plugin_type: - plugin = plugin_type(config) + plugin = plugin_type(plugin_config) await plugin.initialize() return plugin return None diff --git a/cpex/framework/manager.py b/cpex/framework/manager.py index e480ceb..a130f7e 100644 --- a/cpex/framework/manager.py +++ b/cpex/framework/manager.py @@ -38,7 +38,10 @@ # First-Party from cpex.framework.base import HookRef, Plugin +from cpex.framework.constants import EXTERNAL_PLUGIN_TYPE from cpex.framework.errors import PluginError, PluginViolationError, convert_exception_to_error +from cpex.framework.extensions.extensions import Extensions +from cpex.framework.extensions.tiers import filter_extensions from cpex.framework.hooks.policies import DefaultHookPolicy, HookPayloadPolicy, apply_policy from cpex.framework.loader.config import ConfigLoader from cpex.framework.loader.plugin import PluginLoader @@ -137,6 +140,7 @@ async def execute( hook_type: str, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False, + extensions: Optional[Extensions] = None, ) -> tuple[PluginResult, PluginContextTable | None]: """Execute plugins in priority order with timeout protection. @@ -147,6 +151,7 @@ async def execute( hook_type: The hook type identifier (e.g., "tool_pre_invoke"). local_contexts: Optional existing contexts from previous hook executions. violations_as_exceptions: Raise violations as exceptions rather than as returns. + extensions: Optional extensions to filter and pass to plugins that accept them. 
Returns: A tuple containing: @@ -214,6 +219,7 @@ async def execute( allow_blocking=True, fire_and_forget_refs=fire_and_forget_refs, fire_and_forget_semaphore=fire_and_forget_semaphore, + extensions=extensions, ) if halt_result is not None: return halt_result @@ -236,6 +242,7 @@ async def execute( decision_plugin_name=decision_plugin_name, apply_modifications=True, allow_blocking=False, + extensions=extensions, ) current_payload, decision_plugin_name = self._serial_phase_state @@ -255,6 +262,7 @@ async def execute( decision_plugin_name=decision_plugin_name, apply_modifications=False, allow_blocking=False, + extensions=extensions, ) current_payload, decision_plugin_name = self._serial_phase_state @@ -275,6 +283,7 @@ async def execute( violations_as_exceptions, global_context, combined_metadata, + extensions=extensions, ) if concurrent_semaphore: coro = self._with_semaphore(concurrent_semaphore, coro) @@ -317,11 +326,17 @@ async def execute( fire_and_forget_semaphore, hook_type, decision_plugin_name, + extensions=extensions, ) # FIRE_AND_FORGET: fire-and-forget background tasks (fires last with final payload snapshot) self._fire_and_forget_tasks( - fire_and_forget_refs, payload, global_context, res_local_contexts, fire_and_forget_semaphore + fire_and_forget_refs, + payload, + global_context, + res_local_contexts, + fire_and_forget_semaphore, + extensions=extensions, ) if hook_type == HTTP_AUTH_CHECK_PERMISSION_HOOK and decision_plugin_name: @@ -412,6 +427,7 @@ async def _run_serial_phase( allow_blocking: bool, fire_and_forget_refs: Optional[list[HookRef]] = None, fire_and_forget_semaphore: Optional[asyncio.Semaphore] = None, + extensions: Optional[Extensions] = None, ) -> Optional[tuple[PluginResult, PluginContextTable | None]]: """Run a serial execution phase (SEQUENTIAL, TRANSFORM, or AUDIT). 
@@ -449,6 +465,7 @@ async def _run_serial_phase( violations_as_exceptions, global_context, combined_metadata, + extensions=extensions, ) if result.modified_payload is not None: @@ -494,6 +511,7 @@ async def _run_serial_phase( fire_and_forget_semaphore, hook_type, decision_plugin_name, + extensions=extensions, ) else: logger.warning( @@ -628,10 +646,16 @@ def _build_halt_result( fire_and_forget_semaphore: Optional[asyncio.Semaphore], hook_type: str, decision_plugin_name: Optional[str], + extensions: Optional[Extensions] = None, ) -> tuple[PluginResult, dict]: """Schedule fire-and-forget tasks and build a pipeline-halting result.""" self._fire_and_forget_tasks( - fire_and_forget_refs, payload, global_context, res_local_contexts, fire_and_forget_semaphore + fire_and_forget_refs, + payload, + global_context, + res_local_contexts, + fire_and_forget_semaphore, + extensions=extensions, ) if hook_type == HTTP_AUTH_CHECK_PERMISSION_HOOK and decision_plugin_name: combined_metadata[DECISION_PLUGIN_METADATA_KEY] = decision_plugin_name @@ -664,6 +688,7 @@ def _fire_and_forget_tasks( global_context: GlobalContext, res_local_contexts: dict, semaphore: Optional[asyncio.Semaphore], + extensions: Optional[Extensions] = None, ) -> None: """Schedule all FIRE_AND_FORGET plugins as fire-and-forget background tasks. 
@@ -688,7 +713,9 @@ def _fire_and_forget_tasks( ) local_context = PluginContext(global_context=tmp_gc) res_local_contexts[local_context_key] = local_context - asyncio.create_task(self._run_fire_and_forget_task(ref, task_input, local_context, semaphore)) + asyncio.create_task( + self._run_fire_and_forget_task(ref, task_input, local_context, semaphore, extensions=extensions) + ) async def _run_fire_and_forget_task( self, @@ -696,6 +723,7 @@ async def _run_fire_and_forget_task( payload: PluginPayload, local_context: PluginContext, semaphore: Optional[asyncio.Semaphore], + extensions: Optional[Extensions] = None, ) -> None: """Execute a plugin as a fire-and-forget background task. @@ -705,9 +733,9 @@ async def _run_fire_and_forget_task( try: if semaphore: async with semaphore: - await self._execute_with_timeout(hook_ref, payload, local_context) + await self._execute_with_timeout(hook_ref, payload, local_context, extensions=extensions) else: - await self._execute_with_timeout(hook_ref, payload, local_context) + await self._execute_with_timeout(hook_ref, payload, local_context, extensions=extensions) except Exception: logger.error("Plugin %s failed in fire-and-forget mode (ignored)", hook_ref.plugin_ref.name) if hook_ref.plugin_ref.on_error == OnError.DISABLE: @@ -722,6 +750,7 @@ async def execute_plugin( violations_as_exceptions: bool, global_context: Optional[GlobalContext] = None, combined_metadata: Optional[dict[str, Any]] = None, + extensions: Optional[Extensions] = None, ) -> PluginResult: """Execute a single plugin with timeout protection. @@ -732,6 +761,7 @@ async def execute_plugin( violations_as_exceptions: Raise violations as exceptions rather than as returns. global_context: Shared context for all plugins containing request metadata. combined_metadata: combination of the metadata of all plugins. + extensions: Optional extensions to filter and pass to plugins that accept them. 
Returns: A tuple containing: @@ -745,7 +775,7 @@ async def execute_plugin( """ try: # Execute plugin with timeout protection - result = await self._execute_with_timeout(hook_ref, payload, local_context) + result = await self._execute_with_timeout(hook_ref, payload, local_context, extensions=extensions) # Merge global state for modes that participate in the pipeline chain. # AUDIT and FIRE_AND_FORGET operate on isolated snapshots and should not # mutate shared state. @@ -870,7 +900,11 @@ async def execute_plugin( return PluginResult(continue_processing=True) async def _execute_with_timeout( - self, hook_ref: HookRef, payload: PluginPayload, context: PluginContext + self, + hook_ref: HookRef, + payload: PluginPayload, + context: PluginContext, + extensions: Optional[Extensions] = None, ) -> PluginResult: """Execute a plugin with timeout protection. @@ -878,6 +912,7 @@ async def _execute_with_timeout( hook_ref: Reference to the hook and plugin to execute. payload: Payload to process. context: Plugin execution context. + extensions: Optional extensions to filter and pass if the plugin accepts them. Returns: Result from plugin execution. @@ -916,7 +951,11 @@ async def _execute_with_timeout( # Execute plugin try: - result = await asyncio.wait_for(hook_ref.hook(payload, context), timeout=self.timeout) + if hook_ref.accepts_extensions: + filtered = filter_extensions(extensions, hook_ref.plugin_ref.capabilities) + result = await asyncio.wait_for(hook_ref.hook(payload, context, filtered), timeout=self.timeout) + else: + result = await asyncio.wait_for(hook_ref.hook(payload, context), timeout=self.timeout) except Exception: if span_id is not None: try: @@ -1263,7 +1302,15 @@ async def initialize(self) -> None: # Fully instantiate enabled plugins plugin = await self._loader.load_and_instantiate_plugin(plugin_config) if plugin: - self._registry.register(plugin) + # For external plugins, initialize() merges the remote + # config (mode, hooks, etc.) 
so the post-init config is + # authoritative. For internal plugins the original YAML + # config is already complete. + if plugin_config.kind == EXTERNAL_PLUGIN_TYPE: + trusted = plugin.config.model_copy(deep=True) + else: + trusted = plugin_config + self._registry.register(plugin, trusted_config=trusted) loaded_count += 1 logger.info("Loaded plugin: %s (mode: %s)", plugin_config.name, plugin_config.mode) else: @@ -1330,6 +1377,7 @@ async def invoke_hook( global_context: GlobalContext, local_contexts: Optional[PluginContextTable] = None, violations_as_exceptions: bool = False, + extensions: Optional[Extensions] = None, ) -> tuple[PluginResult, PluginContextTable | None]: """Invoke a set of plugins configured for the hook point in priority order. @@ -1339,6 +1387,7 @@ async def invoke_hook( global_context: Shared context for all plugins with request metadata. local_contexts: Optional existing contexts from previous hook executions. violations_as_exceptions: Raise violations as exceptions rather than as returns. + extensions: Optional extensions to filter and pass to plugins that accept them. Returns: A tuple containing: @@ -1361,7 +1410,13 @@ async def invoke_hook( # Execute plugins result = await self._get_executor().execute( - hook_refs, payload, global_context, hook_type, local_contexts, violations_as_exceptions + hook_refs, + payload, + global_context, + hook_type, + local_contexts, + violations_as_exceptions, + extensions=extensions, ) return result diff --git a/cpex/framework/models.py b/cpex/framework/models.py index c0bd5db..5941cce 100644 --- a/cpex/framework/models.py +++ b/cpex/framework/models.py @@ -1192,6 +1192,8 @@ class PluginConfig(BaseModel): grpc (Optional[GRPCClientConfig]): Client-side gRPC configuration (gateway connecting to plugin). 
+ """ + + model_config = ConfigDict(frozen=True) + name: str description: Optional[str] = None author: Optional[str] = None @@ -1227,11 +1229,48 @@ def _migrate_legacy_modes(cls, data: Any) -> Any: conditions: list[PluginCondition] = Field(default_factory=list) # When to apply applied_to: Optional[AppliedTo] = None # Fields to apply to. + capabilities: frozenset[str] = Field( + default_factory=frozenset, + description="Declared capabilities (e.g., 'read_headers', 'append_labels').", + ) config: Optional[dict[str, Any]] = None mcp: Optional[MCPClientConfig] = None grpc: Optional[GRPCClientConfig] = None unix_socket: Optional[UnixSocketClientConfig] = None + @field_validator("capabilities", mode="before") + @classmethod + def _validate_capabilities(cls, v: Any) -> frozenset[str]: + """Validate that all declared capabilities are known. + + Args: + v: Raw capabilities value from the config (None, list, tuple, set or frozenset). + + Returns: + A validated frozenset of capability strings (empty when v is None). + + Raises: + ValueError: If v is not a list/tuple/set, or an unknown capability is declared. + """ + # First-Party + from cpex.framework.extensions.tiers import Capability # pylint: disable=import-outside-toplevel + + if v is None: + return frozenset() + # Reject scalars/mappings loudly instead of silently granting no capabilities + if not isinstance(v, (list, tuple, set, frozenset)): + raise ValueError(f"capabilities must be a list or set of strings, got {type(v).__name__}") + known = {c.value for c in Capability} + for cap in v: + if cap not in known: + raise ValueError(f"Unknown capability: {cap!r}. 
Known: {sorted(known)}") + return frozenset(v) + + @field_serializer("capabilities") + def serialize_capabilities(self, value: frozenset[str]) -> list[str]: + """Serialize frozenset for JSON compatibility.""" + return sorted(value) + @model_validator(mode="after") def check_url_or_script_filled(self) -> Self: # pylint: disable=bad-classmethod-argument """Checks to see that at least one of url or script are set depending on MCP server configuration. 
diff --git a/cpex/framework/registry.py b/cpex/framework/registry.py index 2507208..5263272 100644 --- a/cpex/framework/registry.py +++ b/cpex/framework/registry.py @@ -16,6 +16,7 @@ # First-Party from cpex.framework.base import HookRef, Plugin, PluginRef from cpex.framework.external.mcp.client import ExternalHookRef +from cpex.framework.models import PluginConfig # Use standard logging to avoid circular imports (plugins -> services -> plugins) logger = logging.getLogger(__name__) @@ -67,11 +68,18 @@ def __init__(self) -> None: self._hooks_by_name: dict[str, dict[str, HookRef]] = {} self._priority_cache: dict[str, list[HookRef]] = {} - def register(self, plugin: Plugin) -> None: + def register( + self, + plugin: Plugin, + trusted_config: PluginConfig | None = None, + ) -> None: """Register a plugin instance. Args: plugin: plugin to be registered. + trusted_config: The authoritative config retained by the + Manager. If provided, PluginRef reads policy fields + from this copy rather than from the plugin. Raises: ValueError: if plugin is already registered. @@ -79,7 +87,7 @@ def register(self, plugin: Plugin) -> None: if plugin.name in self._plugins: raise ValueError(f"Plugin {plugin.name} already registered") - plugin_ref = PluginRef(plugin) + plugin_ref = PluginRef(plugin, trusted_config=trusted_config) self._plugins[plugin.name] = plugin_ref diff --git a/tests/unit/cpex/framework/cmf/__init__.py b/tests/unit/cpex/framework/cmf/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/cpex/framework/cmf/test_message.py b/tests/unit/cpex/framework/cmf/test_message.py new file mode 100644 index 0000000..2d4f482 --- /dev/null +++ b/tests/unit/cpex/framework/cmf/test_message.py @@ -0,0 +1,725 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/cpex/framework/cmf/test_message.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests for CMF message models. 
+""" + +# Standard +from typing import Any + +# Third-Party +import pytest + +# First-Party +from cpex.framework.cmf.message import ( + AudioContentPart, + AudioSource, + Channel, + ContentPart, + ContentType, + DocumentContentPart, + DocumentSource, + ImageContentPart, + ImageSource, + Message, + PromptRequest, + PromptRequestContentPart, + PromptResult, + PromptResultContentPart, + Resource, + ResourceContentPart, + ResourceReference, + ResourceRefContentPart, + ResourceType, + Role, + TextContent, + ThinkingContent, + ToolCall, + ToolCallContentPart, + ToolResult, + ToolResultContentPart, + VideoContentPart, + VideoSource, +) + + +# --------------------------------------------------------------------------- +# Enum Tests +# --------------------------------------------------------------------------- + + +class TestRole: + """Tests for the Role enum.""" + + def test_values(self): + assert Role.SYSTEM.value == "system" + assert Role.DEVELOPER.value == "developer" + assert Role.USER.value == "user" + assert Role.ASSISTANT.value == "assistant" + assert Role.TOOL.value == "tool" + + def test_from_string(self): + assert Role("user") == Role.USER + assert Role("assistant") == Role.ASSISTANT + + def test_invalid_value(self): + with pytest.raises(ValueError): + Role("invalid") + + def test_member_count(self): + assert len(Role) == 5 + + +class TestChannel: + """Tests for the Channel enum.""" + + def test_values(self): + assert Channel.ANALYSIS.value == "analysis" + assert Channel.COMMENTARY.value == "commentary" + assert Channel.FINAL.value == "final" + + def test_from_string(self): + assert Channel("final") == Channel.FINAL + + def test_member_count(self): + assert len(Channel) == 3 + + +class TestContentType: + """Tests for the ContentType enum.""" + + def test_all_types_present(self): + expected = { + "text", "thinking", "tool_call", "tool_result", + "resource", "resource_ref", "prompt_request", "prompt_result", + "image", "video", "audio", "document", + } + assert 
{ct.value for ct in ContentType} == expected + + def test_member_count(self): + assert len(ContentType) == 12 + + +class TestResourceType: + """Tests for the ResourceType enum.""" + + def test_all_types_present(self): + expected = {"file", "blob", "uri", "database", "api", "memory", "artifact"} + assert {rt.value for rt in ResourceType} == expected + + def test_member_count(self): + assert len(ResourceType) == 7 + + +# --------------------------------------------------------------------------- +# ContentPart Base Class Tests +# --------------------------------------------------------------------------- + + +class TestContentPart: + """Tests for the ContentPart base class.""" + + def test_subclass_relationship(self): + part = TextContent(text="hello") + assert isinstance(part, ContentPart) + + def test_wrapper_subclass_relationship(self): + part = ToolCallContentPart( + content=ToolCall(tool_call_id="tc1", name="test"), + ) + assert isinstance(part, ContentPart) + + def test_frozen(self): + part = TextContent(text="hello") + with pytest.raises(Exception): + part.text = "world" + + +# --------------------------------------------------------------------------- +# Domain Object Tests +# --------------------------------------------------------------------------- + + +class TestToolCallDomain: + """Tests for the ToolCall domain object.""" + + def test_creation(self): + call = ToolCall( + tool_call_id="tc_001", + name="get_user", + arguments={"user_id": "123"}, + ) + assert call.tool_call_id == "tc_001" + assert call.name == "get_user" + assert call.arguments == {"user_id": "123"} + + def test_default_arguments(self): + call = ToolCall(tool_call_id="tc_002", name="list_users") + assert call.arguments == {} + + def test_default_namespace(self): + call = ToolCall(tool_call_id="tc_003", name="test") + assert call.namespace is None + + def test_with_namespace(self): + call = ToolCall( + tool_call_id="tc_004", + name="get_user", + namespace="user-service", + ) + assert 
call.namespace == "user-service" + + def test_frozen(self): + call = ToolCall(tool_call_id="tc_005", name="test") + with pytest.raises(Exception): + call.name = "other" + + +class TestToolResultDomain: + """Tests for the ToolResult domain object.""" + + def test_creation(self): + result = ToolResult( + tool_call_id="tc_001", + tool_name="get_user", + content={"name": "Alice"}, + ) + assert result.tool_call_id == "tc_001" + assert result.tool_name == "get_user" + assert result.content == {"name": "Alice"} + assert result.is_error is False + + def test_error_result(self): + result = ToolResult( + tool_call_id="tc_002", + tool_name="fail_tool", + content="Something went wrong", + is_error=True, + ) + assert result.is_error is True + + def test_default_content(self): + result = ToolResult(tool_call_id="tc_003", tool_name="test") + assert result.content is None + assert result.is_error is False + + +class TestImageSourceDomain: + """Tests for the ImageSource domain object.""" + + def test_url_image(self): + img = ImageSource(type="url", data="https://example.com/photo.jpg") + assert img.type == "url" + assert img.media_type is None + + def test_base64_image(self): + img = ImageSource( + type="base64", + data="iVBORw0KGgo...", + media_type="image/png", + ) + assert img.type == "base64" + assert img.media_type == "image/png" + + +class TestVideoSourceDomain: + """Tests for the VideoSource domain object.""" + + def test_creation(self): + vid = VideoSource(type="url", data="https://example.com/clip.mp4") + assert vid.duration_ms is None + + def test_with_duration(self): + vid = VideoSource( + type="url", + data="https://example.com/clip.mp4", + duration_ms=30000, + ) + assert vid.duration_ms == 30000 + + +class TestAudioSourceDomain: + """Tests for the AudioSource domain object.""" + + def test_creation(self): + aud = AudioSource(type="url", data="https://example.com/track.mp3") + assert aud.type == "url" + + +class TestDocumentSourceDomain: + """Tests for the 
DocumentSource domain object.""" + + def test_creation(self): + doc = DocumentSource( + type="base64", + data="JVBERi0xLjQ...", + media_type="application/pdf", + title="Annual Report", + ) + assert doc.title == "Annual Report" + + +class TestResourceDomain: + """Tests for the Resource domain object.""" + + def test_creation(self): + res = Resource( + resource_request_id="rr_001", + uri="file:///data/report.csv", + name="Q4 Report", + resource_type=ResourceType.FILE, + content="col1,col2\n1,2", + mime_type="text/csv", + ) + assert res.uri == "file:///data/report.csv" + assert res.resource_type == ResourceType.FILE + + def test_minimal_creation(self): + res = Resource( + resource_request_id="rr_002", + uri="db://users/42", + resource_type=ResourceType.DATABASE, + ) + assert res.name is None + assert res.content is None + assert res.blob is None + + def test_blob_resource(self): + res = Resource( + resource_request_id="rr_003", + uri="blob://data", + resource_type=ResourceType.BLOB, + blob=b"\x00\x01\x02", + ) + assert res.blob == b"\x00\x01\x02" + + +class TestResourceReferenceDomain: + """Tests for the ResourceReference domain object.""" + + def test_creation(self): + ref = ResourceReference( + resource_request_id="rr_004", + uri="file:///path/to/file.txt", + resource_type=ResourceType.FILE, + ) + assert ref.uri == "file:///path/to/file.txt" + + def test_with_range(self): + ref = ResourceReference( + resource_request_id="rr_005", + uri="file:///code.py", + resource_type=ResourceType.FILE, + range_start=10, + range_end=50, + ) + assert ref.range_start == 10 + assert ref.range_end == 50 + + def test_with_selector(self): + ref = ResourceReference( + resource_request_id="rr_006", + uri="api://data", + resource_type=ResourceType.API, + selector="$.results[0]", + ) + assert ref.selector == "$.results[0]" + + +class TestPromptRequestDomain: + """Tests for the PromptRequest domain object.""" + + def test_creation(self): + req = PromptRequest( + prompt_request_id="pr_001", + 
name="summarize", + arguments={"text": "Long document..."}, + ) + assert req.name == "summarize" + assert req.arguments == {"text": "Long document..."} + + def test_defaults(self): + req = PromptRequest(prompt_request_id="pr_002", name="test") + assert req.arguments == {} + assert req.server_id is None + + +class TestPromptResultDomain: + """Tests for the PromptResult domain object.""" + + def test_creation(self): + result = PromptResult( + prompt_request_id="pr_001", + prompt_name="summarize", + content="This document discusses...", + ) + assert result.prompt_name == "summarize" + assert result.is_error is False + + def test_error_result(self): + result = PromptResult( + prompt_request_id="pr_002", + prompt_name="fail_prompt", + is_error=True, + error_message="Template not found", + ) + assert result.is_error is True + assert result.error_message == "Template not found" + + def test_defaults(self): + result = PromptResult(prompt_request_id="pr_003", prompt_name="test") + assert result.messages == [] + assert result.content is None + + +# --------------------------------------------------------------------------- +# ContentPart Wrapper Tests +# --------------------------------------------------------------------------- + + +class TestTextContent: + """Tests for TextContent.""" + + def test_creation(self): + part = TextContent(text="Hello, world!") + assert part.content_type == ContentType.TEXT + assert part.text == "Hello, world!" 
+ + def test_frozen(self): + part = TextContent(text="original") + with pytest.raises(Exception): + part.text = "modified" + + def test_model_copy(self): + part = TextContent(text="original") + modified = part.model_copy(update={"text": "updated"}) + assert part.text == "original" + assert modified.text == "updated" + + +class TestThinkingContent: + """Tests for ThinkingContent.""" + + def test_creation(self): + part = ThinkingContent(text="Let me analyze this...") + assert part.content_type == ContentType.THINKING + assert part.text == "Let me analyze this..." + + +class TestToolCallContentPart: + """Tests for ToolCallContentPart wrapper.""" + + def test_creation(self): + call = ToolCall(tool_call_id="tc_001", name="get_user", arguments={"user_id": "123"}) + part = ToolCallContentPart(content=call) + assert part.content_type == ContentType.TOOL_CALL + assert part.content.name == "get_user" + assert part.content.arguments == {"user_id": "123"} + + def test_frozen(self): + part = ToolCallContentPart(content=ToolCall(tool_call_id="tc1", name="test")) + with pytest.raises(Exception): + part.content = ToolCall(tool_call_id="tc2", name="other") + + +class TestToolResultContentPart: + """Tests for ToolResultContentPart wrapper.""" + + def test_creation(self): + result = ToolResult(tool_call_id="tc_001", tool_name="get_user", content={"name": "Alice"}) + part = ToolResultContentPart(content=result) + assert part.content_type == ContentType.TOOL_RESULT + assert part.content.tool_name == "get_user" + assert part.content.is_error is False + + +class TestResourceContentPart: + """Tests for ResourceContentPart wrapper.""" + + def test_creation(self): + res = Resource( + resource_request_id="rr_001", + uri="file:///data/report.csv", + resource_type=ResourceType.FILE, + ) + part = ResourceContentPart(content=res) + assert part.content_type == ContentType.RESOURCE + assert part.content.uri == "file:///data/report.csv" + + +class TestImageContentPart: + """Tests for 
ImageContentPart wrapper.""" + + def test_creation(self): + img = ImageSource(type="url", data="https://example.com/photo.jpg") + part = ImageContentPart(content=img) + assert part.content_type == ContentType.IMAGE + assert part.content.type == "url" + + +class TestDocumentContentPart: + """Tests for DocumentContentPart wrapper.""" + + def test_creation(self): + doc = DocumentSource( + type="base64", data="JVBERi0xLjQ...", + media_type="application/pdf", title="Annual Report", + ) + part = DocumentContentPart(content=doc) + assert part.content_type == ContentType.DOCUMENT + assert part.content.title == "Annual Report" + + +# --------------------------------------------------------------------------- +# Message Tests +# --------------------------------------------------------------------------- + + +class TestMessage: + """Tests for the Message model.""" + + def test_simple_message(self): + msg = Message( + role=Role.USER, + content=[TextContent(text="Hello")], + ) + assert msg.role == Role.USER + assert msg.schema_version == "2.0" + assert msg.channel is None + assert msg.extensions is None + assert len(msg.content) == 1 + + def test_empty_content(self): + msg = Message(role=Role.SYSTEM) + assert msg.content == [] + + def test_multi_part_message(self): + msg = Message( + role=Role.ASSISTANT, + content=[ + ThinkingContent(text="Reasoning..."), + TextContent(text="Here is the answer."), + ToolCallContentPart( + content=ToolCall(tool_call_id="tc_001", name="search", arguments={"q": "test"}), + ), + ], + ) + assert len(msg.content) == 3 + assert msg.content[0].content_type == ContentType.THINKING + assert msg.content[1].content_type == ContentType.TEXT + assert msg.content[2].content_type == ContentType.TOOL_CALL + + def test_with_channel(self): + msg = Message( + role=Role.ASSISTANT, + content=[TextContent(text="Final answer.")], + channel=Channel.FINAL, + ) + assert msg.channel == Channel.FINAL + + def test_frozen(self): + msg = Message(role=Role.USER, 
content=[TextContent(text="Hi")]) + with pytest.raises(Exception): + msg.role = Role.ASSISTANT + + def test_model_copy(self): + msg = Message(role=Role.USER, content=[TextContent(text="Hi")]) + updated = msg.model_copy(update={"channel": Channel.FINAL}) + assert msg.channel is None + assert updated.channel == Channel.FINAL + assert updated.role == Role.USER + + def test_deserialization_from_dict(self): + msg = Message.model_validate({ + "role": "user", + "content": [ + {"content_type": "text", "text": "Hello"}, + {"content_type": "tool_call", "content": {"tool_call_id": "tc1", "name": "foo", "arguments": {}}}, + ], + }) + assert msg.role == Role.USER + assert len(msg.content) == 2 + assert isinstance(msg.content[0], TextContent) + assert isinstance(msg.content[1], ToolCallContentPart) + + def test_deserialization_all_content_types(self): + msg = Message.model_validate({ + "role": "assistant", + "content": [ + {"content_type": "text", "text": "hi"}, + {"content_type": "thinking", "text": "hmm"}, + {"content_type": "tool_call", "content": {"tool_call_id": "t1", "name": "x", "arguments": {}}}, + {"content_type": "tool_result", "content": {"tool_call_id": "t1", "tool_name": "x"}}, + {"content_type": "resource", "content": {"resource_request_id": "r1", "uri": "file:///a", "resource_type": "file"}}, + {"content_type": "resource_ref", "content": {"resource_request_id": "r2", "uri": "db://b", "resource_type": "database"}}, + {"content_type": "prompt_request", "content": {"prompt_request_id": "p1", "name": "s"}}, + {"content_type": "prompt_result", "content": {"prompt_request_id": "p1", "prompt_name": "s"}}, + {"content_type": "image", "content": {"type": "url", "data": "http://img"}}, + {"content_type": "video", "content": {"type": "url", "data": "http://vid"}}, + {"content_type": "audio", "content": {"type": "url", "data": "http://aud"}}, + {"content_type": "document", "content": {"type": "url", "data": "http://doc"}}, + ], + }) + assert len(msg.content) == 12 + 
expected_types = [ + TextContent, ThinkingContent, ToolCallContentPart, ToolResultContentPart, + ResourceContentPart, ResourceRefContentPart, PromptRequestContentPart, PromptResultContentPart, + ImageContentPart, VideoContentPart, AudioContentPart, DocumentContentPart, + ] + for part, expected in zip(msg.content, expected_types): + assert isinstance(part, expected), f"Expected {expected.__name__}, got {type(part).__name__}" + + def test_serialization_roundtrip(self): + msg = Message( + role=Role.ASSISTANT, + content=[ + TextContent(text="hello"), + ToolCallContentPart( + content=ToolCall(tool_call_id="tc1", name="test", arguments={"a": 1}), + ), + ], + ) + data = msg.model_dump() + restored = Message.model_validate(data) + assert restored.role == msg.role + assert len(restored.content) == 2 + assert restored.content[0].text == "hello" + assert restored.content[1].content.name == "test" + + def test_iter_views(self): + msg = Message( + role=Role.ASSISTANT, + content=[ + TextContent(text="hello"), + ToolCallContentPart( + content=ToolCall(tool_call_id="tc1", name="test", arguments={}), + ), + ], + ) + views = list(msg.iter_views()) + assert len(views) == 2 + + +# --------------------------------------------------------------------------- +# Validation Tests +# --------------------------------------------------------------------------- + + +class TestResourceValidation: + """Tests for Resource model validators.""" + + def test_content_only(self): + res = Resource( + resource_request_id="r1", uri="file:///a.txt", + resource_type=ResourceType.FILE, content="hello", + ) + assert res.content == "hello" + assert res.blob is None + + def test_blob_only(self): + res = Resource( + resource_request_id="r1", uri="file:///a.bin", + resource_type=ResourceType.FILE, blob=b"\x00\x01", + ) + assert res.blob == b"\x00\x01" + assert res.content is None + + def test_neither_content_nor_blob(self): + res = Resource( + resource_request_id="r1", uri="file:///a.txt", + 
resource_type=ResourceType.FILE, + ) + assert res.content is None + assert res.blob is None + + def test_content_and_blob_raises(self): + with pytest.raises(ValueError, match="cannot have both"): + Resource( + resource_request_id="r1", uri="file:///a.txt", + resource_type=ResourceType.FILE, + content="hello", blob=b"\x00", + ) + + +class TestResourceReferenceValidation: + """Tests for ResourceReference range validators.""" + + def test_valid_range(self): + ref = ResourceReference( + resource_request_id="r1", uri="file:///a.txt", + resource_type=ResourceType.FILE, + range_start=10, range_end=20, + ) + assert ref.range_start == 10 + assert ref.range_end == 20 + + def test_equal_range(self): + ref = ResourceReference( + resource_request_id="r1", uri="file:///a.txt", + resource_type=ResourceType.FILE, + range_start=5, range_end=5, + ) + assert ref.range_start == ref.range_end + + def test_invalid_range_raises(self): + with pytest.raises(ValueError, match="range_end.*must be >= range_start"): + ResourceReference( + resource_request_id="r1", uri="file:///a.txt", + resource_type=ResourceType.FILE, + range_start=20, range_end=10, + ) + + def test_start_only(self): + ref = ResourceReference( + resource_request_id="r1", uri="file:///a.txt", + resource_type=ResourceType.FILE, range_start=5, + ) + assert ref.range_start == 5 + assert ref.range_end is None + + def test_end_only(self): + ref = ResourceReference( + resource_request_id="r1", uri="file:///a.txt", + resource_type=ResourceType.FILE, range_end=10, + ) + assert ref.range_start is None + assert ref.range_end == 10 + + +class TestDiscriminator: + """Tests for content_type discriminator function.""" + + def test_missing_content_type_in_dict_raises(self): + with pytest.raises(Exception): + Message(role=Role.USER, content=[{"text": "hello"}]) + + def test_invalid_content_type_in_dict_raises(self): + with pytest.raises(Exception): + Message(role=Role.USER, content=[{"content_type": "bogus", "text": "hello"}]) + + +class 
TestMediaSourceLiteral: + """Tests that media source type fields enforce Literal['url', 'base64'].""" + + def test_image_source_valid_types(self): + assert ImageSource(type="url", data="https://x.com/a.jpg").type == "url" + assert ImageSource(type="base64", data="abc").type == "base64" + + def test_image_source_invalid_type(self): + with pytest.raises(Exception): + ImageSource(type="ftp", data="abc") + + def test_video_source_invalid_type(self): + with pytest.raises(Exception): + VideoSource(type="file", data="abc") + + def test_audio_source_invalid_type(self): + with pytest.raises(Exception): + AudioSource(type="stream", data="abc") + + def test_document_source_invalid_type(self): + with pytest.raises(Exception): + DocumentSource(type="unknown", data="abc") diff --git a/tests/unit/cpex/framework/cmf/test_view.py b/tests/unit/cpex/framework/cmf/test_view.py new file mode 100644 index 0000000..54db20a --- /dev/null +++ b/tests/unit/cpex/framework/cmf/test_view.py @@ -0,0 +1,1174 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/cpex/framework/cmf/test_view.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests for MessageView. 
+""" + +# Standard +from typing import Any + +# Third-Party +import pytest + +# First-Party +from cpex.framework.cmf.message import ( + AudioContentPart, + AudioSource, + DocumentContentPart, + DocumentSource, + ImageContentPart, + ImageSource, + Message, + PromptRequest, + PromptRequestContentPart, + PromptResult, + PromptResultContentPart, + Resource, + ResourceContentPart, + ResourceReference, + ResourceRefContentPart, + ResourceType, + Role, + TextContent, + ThinkingContent, + ToolCall, + ToolCallContentPart, + ToolResult, + ToolResultContentPart, + VideoContentPart, + VideoSource, +) +from cpex.framework.cmf.view import ( + MessageView, + ViewAction, + ViewKind, + iter_views, +) +from cpex.framework.extensions.agent import AgentExtension +from cpex.framework.extensions.extensions import Extensions +from cpex.framework.extensions.http import HttpExtension +from cpex.framework.extensions.request import RequestExtension +from cpex.framework.extensions.security import ( + DataPolicy, + ObjectSecurityProfile, + RetentionPolicy, + SecurityExtension, + SubjectExtension, + SubjectType, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def simple_assistant_msg(): + """An assistant message with text, thinking, and a tool call.""" + return Message( + role=Role.ASSISTANT, + content=[ + ThinkingContent(text="User wants admin users."), + TextContent(text="Let me look that up."), + ToolCallContentPart( + content=ToolCall( + tool_call_id="tc_001", + name="execute_sql", + arguments={"query": "SELECT * FROM users WHERE role='admin'"}, + ), + ), + ], + ) + + +@pytest.fixture +def full_msg(): + """A message with full extensions populated.""" + return Message( + role=Role.ASSISTANT, + content=[ + ToolCallContentPart( + content=ToolCall( + tool_call_id="tc_001", + name="get_compensation", + namespace="hr-server", + 
arguments={"employee_id": "emp-42"}, + ), + ), + ], + extensions=Extensions( + request=RequestExtension( + environment="production", + request_id="req-001", + ), + agent=AgentExtension( + input="Show me Alice's compensation", + session_id="sess-001", + conversation_id="conv-001", + turn=2, + agent_id="main-agent", + parent_agent_id="orchestrator", + ), + http=HttpExtension( + headers={ + "Authorization": "Bearer secret-token", + "Cookie": "session=abc", + "X-Request-ID": "req-001", + "Content-Type": "application/json", + }, + ), + security=SecurityExtension( + labels=frozenset({"CONFIDENTIAL"}), + classification="confidential", + subject=SubjectExtension( + id="user-alice", + type=SubjectType.USER, + roles=frozenset({"admin", "hr-manager"}), + permissions=frozenset({"read:compensation", "tools.execute"}), + teams=frozenset({"hr-team"}), + ), + objects={ + "get_compensation": ObjectSecurityProfile( + managed_by="tool", + permissions=["read:compensation"], + trust_domain="internal", + data_scope=["salary", "bonus"], + ), + }, + data={ + "get_compensation": DataPolicy( + apply_labels=["PII", "financial"], + denied_actions=["export", "forward", "log_raw"], + retention=RetentionPolicy( + policy="session", + max_age_seconds=3600, + ), + ), + }, + ), + ), + ) + + +# --------------------------------------------------------------------------- +# Enum Tests +# --------------------------------------------------------------------------- + + +class TestViewKind: + """Tests for ViewKind enum.""" + + def test_member_count(self): + assert len(ViewKind) == 12 + + def test_values_match_content_type(self): + from cpex.framework.cmf.message import ContentType + + for ct in ContentType: + assert ct.value in [vk.value for vk in ViewKind] + + +class TestViewAction: + """Tests for ViewAction enum.""" + + def test_member_count(self): + assert len(ViewAction) == 7 + + def test_values(self): + expected = {"read", "write", "execute", "invoke", "send", "receive", "generate"} + assert {va.value 
for va in ViewAction} == expected + + +# --------------------------------------------------------------------------- +# View Iteration +# --------------------------------------------------------------------------- + + +class TestIterViews: + """Tests for iter_views() and Message.iter_views().""" + + def test_standalone_and_method_match(self, simple_assistant_msg): + standalone = list(iter_views(simple_assistant_msg)) + method = list(simple_assistant_msg.iter_views()) + assert len(standalone) == len(method) == 3 + + def test_view_count(self, simple_assistant_msg): + views = list(iter_views(simple_assistant_msg)) + assert len(views) == 3 + + def test_view_kinds(self, simple_assistant_msg): + views = list(iter_views(simple_assistant_msg)) + assert views[0].kind == ViewKind.THINKING + assert views[1].kind == ViewKind.TEXT + assert views[2].kind == ViewKind.TOOL_CALL + + def test_empty_message(self): + msg = Message(role=Role.USER) + views = list(iter_views(msg)) + assert len(views) == 0 + + def test_single_content_part(self): + msg = Message(role=Role.USER, content=[TextContent(text="Hi")]) + views = list(iter_views(msg)) + assert len(views) == 1 + assert views[0].kind == ViewKind.TEXT + + def test_all_content_types(self): + msg = Message( + role=Role.ASSISTANT, + content=[ + TextContent(text="hi"), + ThinkingContent(text="hmm"), + ToolCallContentPart(content=ToolCall(tool_call_id="t1", name="x", arguments={})), + ToolResultContentPart(content=ToolResult(tool_call_id="t1", tool_name="x", content="ok")), + ResourceContentPart(content=Resource(resource_request_id="r1", uri="file:///a", resource_type=ResourceType.FILE, content="data")), + ResourceRefContentPart(content=ResourceReference(resource_request_id="r2", uri="db://b", resource_type=ResourceType.DATABASE)), + PromptRequestContentPart(content=PromptRequest(prompt_request_id="p1", name="summarize")), + PromptResultContentPart(content=PromptResult(prompt_request_id="p1", prompt_name="summarize", content="summary")), + 
ImageContentPart(content=ImageSource(type="url", data="http://img")), + VideoContentPart(content=VideoSource(type="url", data="http://vid")), + AudioContentPart(content=AudioSource(type="url", data="http://aud")), + DocumentContentPart(content=DocumentSource(type="url", data="http://doc")), + ], + ) + views = list(iter_views(msg)) + assert len(views) == 12 + expected_kinds = [ + ViewKind.TEXT, ViewKind.THINKING, ViewKind.TOOL_CALL, ViewKind.TOOL_RESULT, + ViewKind.RESOURCE, ViewKind.RESOURCE_REF, ViewKind.PROMPT_REQUEST, ViewKind.PROMPT_RESULT, + ViewKind.IMAGE, ViewKind.VIDEO, ViewKind.AUDIO, ViewKind.DOCUMENT, + ] + for view, expected in zip(views, expected_kinds): + assert view.kind == expected + + +# --------------------------------------------------------------------------- +# Core Properties +# --------------------------------------------------------------------------- + + +class TestCoreProperties: + """Tests for MessageView core properties.""" + + def test_role(self, simple_assistant_msg): + view = list(iter_views(simple_assistant_msg))[0] + assert view.role == Role.ASSISTANT + + def test_raw_access(self, simple_assistant_msg): + views = list(iter_views(simple_assistant_msg)) + assert isinstance(views[0].raw, ThinkingContent) + assert isinstance(views[2].raw, ToolCallContentPart) + + def test_content_text(self): + msg = Message(role=Role.USER, content=[TextContent(text="Hello")]) + view = list(iter_views(msg))[0] + assert view.content == "Hello" + + def test_content_thinking(self): + msg = Message(role=Role.ASSISTANT, content=[ThinkingContent(text="Reasoning...")]) + view = list(iter_views(msg))[0] + assert view.content == "Reasoning..." 
+ + def test_content_tool_call(self): + msg = Message( + role=Role.ASSISTANT, + content=[ToolCallContentPart(content=ToolCall(tool_call_id="tc1", name="test", arguments={"key": "value"}))], + ) + view = list(iter_views(msg))[0] + assert view.content == '{"key": "value"}' + + def test_content_tool_result(self): + msg = Message( + role=Role.TOOL, + content=[ToolResultContentPart(content=ToolResult(tool_call_id="tc1", tool_name="test", content={"result": 42}))], + ) + view = list(iter_views(msg))[0] + assert view.content == '{"result": 42}' + + def test_content_tool_result_string(self): + msg = Message( + role=Role.TOOL, + content=[ToolResultContentPart(content=ToolResult(tool_call_id="tc1", tool_name="test", content="plain text"))], + ) + view = list(iter_views(msg))[0] + assert view.content == "plain text" + + def test_content_tool_result_none(self): + msg = Message( + role=Role.TOOL, + content=[ToolResultContentPart(content=ToolResult(tool_call_id="tc1", tool_name="test"))], + ) + view = list(iter_views(msg))[0] + assert view.content is None + + def test_content_resource(self): + msg = Message( + role=Role.TOOL, + content=[ResourceContentPart(content=Resource( + resource_request_id="r1", uri="file:///a", + resource_type=ResourceType.FILE, content="file data", + ))], + ) + view = list(iter_views(msg))[0] + assert view.content == "file data" + + def test_content_prompt_request(self): + msg = Message( + role=Role.ASSISTANT, + content=[PromptRequestContentPart(content=PromptRequest(prompt_request_id="p1", name="s", arguments={"text": "hi"}))], + ) + view = list(iter_views(msg))[0] + assert view.content == '{"text": "hi"}' + + def test_content_prompt_result(self): + msg = Message( + role=Role.TOOL, + content=[PromptResultContentPart(content=PromptResult(prompt_request_id="p1", prompt_name="s", content="rendered"))], + ) + view = list(iter_views(msg))[0] + assert view.content == "rendered" + + def test_content_media_none(self): + msg = Message( + role=Role.USER, + 
content=[ImageContentPart(content=ImageSource(type="url", data="http://img"))], + ) + view = list(iter_views(msg))[0] + assert view.content is None + + +# --------------------------------------------------------------------------- +# URI +# --------------------------------------------------------------------------- + + +class TestURI: + """Tests for synthetic URI generation.""" + + def test_tool_call_uri(self): + msg = Message( + role=Role.ASSISTANT, + content=[ToolCallContentPart(content=ToolCall(tool_call_id="tc1", name="get_user", arguments={}))], + ) + view = list(iter_views(msg))[0] + assert view.uri == "tool://_/get_user" + + def test_tool_call_uri_with_namespace(self): + msg = Message( + role=Role.ASSISTANT, + content=[ToolCallContentPart(content=ToolCall(tool_call_id="tc1", name="get_user", namespace="user-svc", arguments={}))], + ) + view = list(iter_views(msg))[0] + assert view.uri == "tool://user-svc/get_user" + + def test_tool_result_uri(self): + msg = Message( + role=Role.TOOL, + content=[ToolResultContentPart(content=ToolResult(tool_call_id="tc1", tool_name="get_user"))], + ) + view = list(iter_views(msg))[0] + assert view.uri == "tool_result://get_user" + + def test_resource_uri(self): + msg = Message( + role=Role.TOOL, + content=[ResourceContentPart(content=Resource( + resource_request_id="r1", uri="file:///data/report.csv", + resource_type=ResourceType.FILE, + ))], + ) + view = list(iter_views(msg))[0] + assert view.uri == "file:///data/report.csv" + + def test_resource_ref_uri(self): + msg = Message( + role=Role.ASSISTANT, + content=[ResourceRefContentPart(content=ResourceReference( + resource_request_id="r1", uri="db://users/42", + resource_type=ResourceType.DATABASE, + ))], + ) + view = list(iter_views(msg))[0] + assert view.uri == "db://users/42" + + def test_prompt_request_uri(self): + msg = Message( + role=Role.ASSISTANT, + content=[PromptRequestContentPart(content=PromptRequest(prompt_request_id="p1", name="summarize", 
server_id="prompt-svc"))], + ) + view = list(iter_views(msg))[0] + assert view.uri == "prompt://prompt-svc/summarize" + + def test_prompt_result_uri(self): + msg = Message( + role=Role.TOOL, + content=[PromptResultContentPart(content=PromptResult(prompt_request_id="p1", prompt_name="summarize"))], + ) + view = list(iter_views(msg))[0] + assert view.uri == "prompt_result://summarize" + + def test_text_uri_is_none(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + view = list(iter_views(msg))[0] + assert view.uri is None + + +# --------------------------------------------------------------------------- +# Name +# --------------------------------------------------------------------------- + + +class TestName: + """Tests for the name property.""" + + def test_tool_call_name(self): + msg = Message(role=Role.ASSISTANT, content=[ToolCallContentPart(content=ToolCall(tool_call_id="tc1", name="get_user", arguments={}))]) + assert list(iter_views(msg))[0].name == "get_user" + + def test_tool_result_name(self): + msg = Message(role=Role.TOOL, content=[ToolResultContentPart(content=ToolResult(tool_call_id="tc1", tool_name="get_user"))]) + assert list(iter_views(msg))[0].name == "get_user" + + def test_prompt_request_name(self): + msg = Message(role=Role.ASSISTANT, content=[PromptRequestContentPart(content=PromptRequest(prompt_request_id="p1", name="summarize"))]) + assert list(iter_views(msg))[0].name == "summarize" + + def test_prompt_result_name(self): + msg = Message(role=Role.TOOL, content=[PromptResultContentPart(content=PromptResult(prompt_request_id="p1", prompt_name="summarize"))]) + assert list(iter_views(msg))[0].name == "summarize" + + def test_text_name_is_none(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + assert list(iter_views(msg))[0].name is None + + +# --------------------------------------------------------------------------- +# Action +# 
--------------------------------------------------------------------------- + + +class TestAction: + """Tests for the action property.""" + + def test_action_mapping(self): + pairs = [ + (TextContent(text="hi"), Role.USER, ViewAction.SEND), + (ThinkingContent(text="hmm"), Role.ASSISTANT, ViewAction.GENERATE), + (ToolCallContentPart(content=ToolCall(tool_call_id="t", name="x", arguments={})), Role.ASSISTANT, ViewAction.EXECUTE), + (ToolResultContentPart(content=ToolResult(tool_call_id="t", tool_name="x")), Role.TOOL, ViewAction.RECEIVE), + (ResourceContentPart(content=Resource(resource_request_id="r", uri="f:///a", resource_type=ResourceType.FILE)), Role.TOOL, ViewAction.READ), + (ResourceRefContentPart(content=ResourceReference(resource_request_id="r", uri="f:///a", resource_type=ResourceType.FILE)), Role.ASSISTANT, ViewAction.READ), + (PromptRequestContentPart(content=PromptRequest(prompt_request_id="p", name="s")), Role.ASSISTANT, ViewAction.INVOKE), + (PromptResultContentPart(content=PromptResult(prompt_request_id="p", prompt_name="s")), Role.TOOL, ViewAction.RECEIVE), + (ImageContentPart(content=ImageSource(type="url", data="http://img")), Role.USER, ViewAction.SEND), + ] + for part, role, expected_action in pairs: + msg = Message(role=role, content=[part]) + view = list(iter_views(msg))[0] + assert view.action == expected_action, f"Expected {expected_action} for {part.content_type}" + + +# --------------------------------------------------------------------------- +# Direction +# --------------------------------------------------------------------------- + + +class TestDirection: + """Tests for is_pre / is_post direction logic.""" + + def test_tool_call_is_pre(self): + msg = Message(role=Role.ASSISTANT, content=[ToolCallContentPart(content=ToolCall(tool_call_id="t", name="x", arguments={}))]) + view = list(iter_views(msg))[0] + assert view.is_pre is True + assert view.is_post is False + + def test_tool_result_is_post(self): + msg = Message(role=Role.TOOL, 
content=[ToolResultContentPart(content=ToolResult(tool_call_id="t", tool_name="x"))]) + view = list(iter_views(msg))[0] + assert view.is_pre is False + assert view.is_post is True + + def test_prompt_request_is_pre(self): + msg = Message(role=Role.ASSISTANT, content=[PromptRequestContentPart(content=PromptRequest(prompt_request_id="p", name="s"))]) + view = list(iter_views(msg))[0] + assert view.is_pre is True + + def test_prompt_result_is_post(self): + msg = Message(role=Role.TOOL, content=[PromptResultContentPart(content=PromptResult(prompt_request_id="p", prompt_name="s"))]) + view = list(iter_views(msg))[0] + assert view.is_post is True + + def test_resource_ref_is_pre(self): + msg = Message(role=Role.ASSISTANT, content=[ + ResourceRefContentPart(content=ResourceReference(resource_request_id="r", uri="f:///a", resource_type=ResourceType.FILE)), + ]) + view = list(iter_views(msg))[0] + assert view.is_pre is True + + def test_resource_is_post(self): + msg = Message(role=Role.TOOL, content=[ + ResourceContentPart(content=Resource(resource_request_id="r", uri="f:///a", resource_type=ResourceType.FILE)), + ]) + view = list(iter_views(msg))[0] + assert view.is_post is True + + def test_user_text_is_pre(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + view = list(iter_views(msg))[0] + assert view.is_pre is True + assert view.is_post is False + + def test_assistant_text_is_post(self): + msg = Message(role=Role.ASSISTANT, content=[TextContent(text="hello")]) + view = list(iter_views(msg))[0] + assert view.is_pre is False + assert view.is_post is True + + def test_system_text_is_pre(self): + msg = Message(role=Role.SYSTEM, content=[TextContent(text="instructions")]) + view = list(iter_views(msg))[0] + assert view.is_pre is True + + def test_developer_text_is_pre(self): + msg = Message(role=Role.DEVELOPER, content=[TextContent(text="hints")]) + view = list(iter_views(msg))[0] + assert view.is_pre is True + + def test_tool_text_is_post(self): + 
msg = Message(role=Role.TOOL, content=[TextContent(text="result text")]) + view = list(iter_views(msg))[0] + assert view.is_pre is False + assert view.is_post is True + + +# --------------------------------------------------------------------------- +# Entity Type Helpers +# --------------------------------------------------------------------------- + + +class TestEntityTypeHelpers: + """Tests for is_tool, is_prompt, is_resource, is_text, is_media.""" + + def test_is_tool(self): + msg = Message(role=Role.ASSISTANT, content=[ToolCallContentPart(content=ToolCall(tool_call_id="t", name="x", arguments={}))]) + assert list(iter_views(msg))[0].is_tool is True + + def test_is_tool_result(self): + msg = Message(role=Role.TOOL, content=[ToolResultContentPart(content=ToolResult(tool_call_id="t", tool_name="x"))]) + assert list(iter_views(msg))[0].is_tool is True + + def test_is_prompt(self): + msg = Message(role=Role.ASSISTANT, content=[PromptRequestContentPart(content=PromptRequest(prompt_request_id="p", name="s"))]) + assert list(iter_views(msg))[0].is_prompt is True + + def test_is_resource(self): + msg = Message(role=Role.TOOL, content=[ + ResourceContentPart(content=Resource(resource_request_id="r", uri="f:///a", resource_type=ResourceType.FILE)), + ]) + assert list(iter_views(msg))[0].is_resource is True + + def test_is_text(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + assert list(iter_views(msg))[0].is_text is True + + def test_is_text_thinking(self): + msg = Message(role=Role.ASSISTANT, content=[ThinkingContent(text="hmm")]) + assert list(iter_views(msg))[0].is_text is True + + def test_is_media(self): + for part in [ + ImageContentPart(content=ImageSource(type="url", data="http://img")), + VideoContentPart(content=VideoSource(type="url", data="http://vid")), + AudioContentPart(content=AudioSource(type="url", data="http://aud")), + DocumentContentPart(content=DocumentSource(type="url", data="http://doc")), + ]: + msg = 
Message(role=Role.USER, content=[part]) + assert list(iter_views(msg))[0].is_media is True + + def test_text_is_not_media(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + assert list(iter_views(msg))[0].is_media is False + + +# --------------------------------------------------------------------------- +# Flat Accessors +# --------------------------------------------------------------------------- + + +class TestFlatAccessors: + """Tests for capability-gated flat accessors.""" + + def test_base_tier(self, full_msg): + view = list(iter_views(full_msg))[0] + assert view.environment == "production" + assert view.request_id == "req-001" + + def test_subject(self, full_msg): + view = list(iter_views(full_msg))[0] + assert view.subject.id == "user-alice" + assert view.subject.type == SubjectType.USER + + def test_roles(self, full_msg): + view = list(iter_views(full_msg))[0] + assert "admin" in view.roles + assert "hr-manager" in view.roles + + def test_permissions(self, full_msg): + view = list(iter_views(full_msg))[0] + assert "read:compensation" in view.permissions + + def test_teams(self, full_msg): + view = list(iter_views(full_msg))[0] + assert "hr-team" in view.teams + + def test_headers(self, full_msg): + view = list(iter_views(full_msg))[0] + assert view.headers["Content-Type"] == "application/json" + + def test_labels(self, full_msg): + view = list(iter_views(full_msg))[0] + assert "CONFIDENTIAL" in view.labels + + def test_agent_accessors(self, full_msg): + view = list(iter_views(full_msg))[0] + assert view.agent_input == "Show me Alice's compensation" + assert view.session_id == "sess-001" + assert view.conversation_id == "conv-001" + assert view.turn == 2 + assert view.agent_id == "main-agent" + assert view.parent_agent_id == "orchestrator" + + def test_object_profile(self, full_msg): + view = list(iter_views(full_msg))[0] + assert view.object is not None + assert view.object.managed_by == "tool" + assert view.object.permissions == 
["read:compensation"] + assert view.object.trust_domain == "internal" + assert view.object.data_scope == ["salary", "bonus"] + + def test_data_policy(self, full_msg): + view = list(iter_views(full_msg))[0] + assert view.data_policy is not None + assert "PII" in view.data_policy.apply_labels + assert "export" in view.data_policy.denied_actions + assert view.data_policy.retention.policy == "session" + + def test_no_extensions(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + view = list(iter_views(msg))[0] + assert view.environment is None + assert view.request_id is None + assert view.subject is None + assert view.roles == frozenset() + assert view.permissions == frozenset() + assert view.teams == frozenset() + assert view.headers == {} + assert view.labels == frozenset() + assert view.agent_input is None + assert view.session_id is None + assert view.conversation_id is None + assert view.turn is None + assert view.agent_id is None + assert view.parent_agent_id is None + assert view.object is None + assert view.data_policy is None + + def test_object_resolves_by_name(self): + msg = Message( + role=Role.ASSISTANT, + content=[ + ToolCallContentPart(content=ToolCall(tool_call_id="tc1", name="tool_a", arguments={})), + ToolCallContentPart(content=ToolCall(tool_call_id="tc2", name="tool_b", arguments={})), + ], + extensions=Extensions( + security=SecurityExtension( + objects={"tool_a": ObjectSecurityProfile(managed_by="host")}, + ), + ), + ) + views = list(iter_views(msg)) + assert views[0].object is not None + assert views[0].object.managed_by == "host" + assert views[1].object is None + + +# --------------------------------------------------------------------------- +# Helper Methods +# --------------------------------------------------------------------------- + + +class TestHelperMethods: + """Tests for helper methods on MessageView.""" + + def test_has_role(self, full_msg): + view = list(iter_views(full_msg))[0] + assert 
view.has_role("admin") is True + assert view.has_role("viewer") is False + + def test_has_permission(self, full_msg): + view = list(iter_views(full_msg))[0] + assert view.has_permission("read:compensation") is True + assert view.has_permission("write:users") is False + + def test_has_label(self, full_msg): + view = list(iter_views(full_msg))[0] + assert view.has_label("CONFIDENTIAL") is True + assert view.has_label("SECRET") is False + + def test_has_header(self, full_msg): + view = list(iter_views(full_msg))[0] + assert view.has_header("Content-Type") is True + assert view.has_header("content-type") is True + assert view.has_header("X-Missing") is False + + def test_get_header_case_insensitive(self, full_msg): + view = list(iter_views(full_msg))[0] + assert view.get_header("content-type") == "application/json" + assert view.get_header("CONTENT-TYPE") == "application/json" + + def test_get_header_default(self, full_msg): + view = list(iter_views(full_msg))[0] + assert view.get_header("X-Missing") is None + assert view.get_header("X-Missing", "fallback") == "fallback" + + def test_get_arg(self, full_msg): + view = list(iter_views(full_msg))[0] + assert view.get_arg("employee_id") == "emp-42" + assert view.get_arg("missing") is None + assert view.get_arg("missing", "default") == "default" + + def test_has_arg(self, full_msg): + view = list(iter_views(full_msg))[0] + assert view.has_arg("employee_id") is True + assert view.has_arg("missing") is False + + def test_has_arg_text_view(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + view = list(iter_views(msg))[0] + assert view.has_arg("anything") is False + + def test_matches_uri_pattern(self, full_msg): + view = list(iter_views(full_msg))[0] + assert view.matches_uri_pattern("tool://hr-server/*") is True + assert view.matches_uri_pattern("tool://hr-server/get_*") is True + assert view.matches_uri_pattern("tool://other/*") is False + assert view.matches_uri_pattern("tool://**") is True + + def 
test_matches_uri_pattern_no_uri(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + view = list(iter_views(msg))[0] + assert view.matches_uri_pattern("*") is False + + def test_has_content(self, simple_assistant_msg): + views = list(iter_views(simple_assistant_msg)) + assert views[0].has_content() is True + assert views[1].has_content() is True + assert views[2].has_content() is True + + +# --------------------------------------------------------------------------- +# Type-Specific Properties +# --------------------------------------------------------------------------- + + +class TestProperties: + """Tests for type-specific properties.""" + + def test_tool_call_properties(self): + msg = Message( + role=Role.ASSISTANT, + content=[ToolCallContentPart(content=ToolCall(tool_call_id="tc1", name="test", namespace="ns", arguments={}))], + ) + view = list(iter_views(msg))[0] + assert view.get_property("namespace") == "ns" + assert view.get_property("tool_id") == "tc1" + props = view.properties + assert props["namespace"] == "ns" + assert props["tool_id"] == "tc1" + + def test_tool_result_properties(self): + msg = Message( + role=Role.TOOL, + content=[ToolResultContentPart(content=ToolResult(tool_call_id="tc1", tool_name="test", is_error=True))], + ) + view = list(iter_views(msg))[0] + assert view.get_property("is_error") is True + assert view.get_property("tool_name") == "test" + + def test_resource_properties(self): + msg = Message( + role=Role.TOOL, + content=[ResourceContentPart(content=Resource( + resource_request_id="r1", uri="f:///a", + resource_type=ResourceType.FILE, version="v2", + annotations={"key": "val"}, + ))], + ) + view = list(iter_views(msg))[0] + assert view.get_property("resource_type") == "file" + assert view.get_property("version") == "v2" + assert view.get_property("annotations") == {"key": "val"} + + def test_prompt_result_properties(self): + msg = Message( + role=Role.TOOL, + 
content=[PromptResultContentPart(content=PromptResult( + prompt_request_id="p1", prompt_name="s", + messages=[ + Message(role=Role.USER, content=[TextContent(text="m1")]), + Message(role=Role.ASSISTANT, content=[TextContent(text="m2")]), + ], + is_error=False, + ))], + ) + view = list(iter_views(msg))[0] + assert view.get_property("is_error") is False + assert view.get_property("message_count") == 2 + + def test_get_property_default(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + view = list(iter_views(msg))[0] + assert view.get_property("anything") is None + assert view.get_property("anything", "fallback") == "fallback" + + def test_empty_properties_for_text(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + view = list(iter_views(msg))[0] + assert view.properties == {} + + +# --------------------------------------------------------------------------- +# Misc Properties +# --------------------------------------------------------------------------- + + +class TestMiscProperties: + """Tests for mime_type, size_bytes, args.""" + + def test_mime_type_resource(self): + msg = Message(role=Role.TOOL, content=[ResourceContentPart(content=Resource( + resource_request_id="r1", uri="f:///a", + resource_type=ResourceType.FILE, mime_type="text/csv", + ))]) + assert list(iter_views(msg))[0].mime_type == "text/csv" + + def test_mime_type_image(self): + msg = Message(role=Role.USER, content=[ + ImageContentPart(content=ImageSource(type="url", data="http://img", media_type="image/png")), + ]) + assert list(iter_views(msg))[0].mime_type == "image/png" + + def test_mime_type_text_none(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + assert list(iter_views(msg))[0].mime_type is None + + def test_size_bytes_text(self): + msg = Message(role=Role.USER, content=[TextContent(text="hello")]) + assert list(iter_views(msg))[0].size_bytes == 5 + + def test_size_bytes_resource_explicit(self): + msg = 
Message(role=Role.TOOL, content=[ResourceContentPart(content=Resource( + resource_request_id="r1", uri="f:///a", + resource_type=ResourceType.FILE, size_bytes=1024, + ))]) + assert list(iter_views(msg))[0].size_bytes == 1024 + + def test_size_bytes_resource_from_content(self): + msg = Message(role=Role.TOOL, content=[ResourceContentPart(content=Resource( + resource_request_id="r1", uri="f:///a", + resource_type=ResourceType.FILE, content="hello", + ))]) + assert list(iter_views(msg))[0].size_bytes == 5 + + def test_args_tool_call(self): + msg = Message( + role=Role.ASSISTANT, + content=[ToolCallContentPart(content=ToolCall(tool_call_id="tc1", name="x", arguments={"a": 1, "b": 2}))], + ) + view = list(iter_views(msg))[0] + assert view.args == {"a": 1, "b": 2} + + def test_args_text_none(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + assert list(iter_views(msg))[0].args is None + + +# --------------------------------------------------------------------------- +# Serialization +# --------------------------------------------------------------------------- + + +class TestSerialization: + """Tests for to_dict() and to_opa_input().""" + + def test_to_dict_basic(self, simple_assistant_msg): + view = list(iter_views(simple_assistant_msg))[2] + d = view.to_dict() + assert d["kind"] == "tool_call" + assert d["role"] == "assistant" + assert d["is_pre"] is True + assert d["is_post"] is False + assert d["action"] == "execute" + assert d["name"] == "execute_sql" + assert d["uri"] == "tool://_/execute_sql" + + def test_to_dict_strips_sensitive_headers(self, full_msg): + view = list(iter_views(full_msg))[0] + d = view.to_dict() + headers = d["extensions"].get("headers", {}) + assert "Authorization" not in headers + assert "Cookie" not in headers + assert "Content-Type" in headers + + def test_to_dict_includes_extensions(self, full_msg): + view = list(iter_views(full_msg))[0] + d = view.to_dict() + ext = d["extensions"] + assert ext["environment"] == 
"production" + assert ext["subject"]["id"] == "user-alice" + assert "CONFIDENTIAL" in ext["labels"] + assert ext["object"]["managed_by"] == "tool" + assert "PII" in ext["data"]["apply_labels"] + assert ext["agent"]["input"] == "Show me Alice's compensation" + + def test_to_dict_exclude_content(self, simple_assistant_msg): + view = list(iter_views(simple_assistant_msg))[0] + d = view.to_dict(include_content=False) + assert "content" not in d + assert "size_bytes" not in d + + def test_to_dict_exclude_context(self, full_msg): + view = list(iter_views(full_msg))[0] + d = view.to_dict(include_context=False) + assert "extensions" not in d + + def test_to_opa_input(self, simple_assistant_msg): + view = list(iter_views(simple_assistant_msg))[2] + opa = view.to_opa_input() + assert "input" in opa + assert opa["input"]["kind"] == "tool_call" + + def test_to_dict_no_extensions(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + view = list(iter_views(msg))[0] + d = view.to_dict() + assert "extensions" not in d + + +# --------------------------------------------------------------------------- +# Repr +# --------------------------------------------------------------------------- + + +class TestRepr: + """Tests for __repr__.""" + + def test_repr_text(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + view = list(iter_views(msg))[0] + r = repr(view) + assert "kind=text" in r + assert "role=user" in r + assert "pre" in r + + def test_repr_tool_call(self): + msg = Message( + role=Role.ASSISTANT, + content=[ToolCallContentPart(content=ToolCall(tool_call_id="tc1", name="test", arguments={}))], + ) + view = list(iter_views(msg))[0] + r = repr(view) + assert "tool_call" in r + assert "tool://_/test" in r + + +# --------------------------------------------------------------------------- +# Hook property +# --------------------------------------------------------------------------- + + +class TestHookProperty: + """Tests for the hook 
property on MessageView.""" + + def test_hook_none_by_default(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + view = list(iter_views(msg))[0] + assert view.hook is None + + def test_hook_passed_through(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + view = list(iter_views(msg, hook="llm_input"))[0] + assert view.hook == "llm_input" + + def test_hook_in_to_dict(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + view = list(iter_views(msg, hook="tool_pre_invoke"))[0] + d = view.to_dict() + assert d["hook"] == "tool_pre_invoke" + + def test_hook_absent_from_to_dict_when_none(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + view = list(iter_views(msg))[0] + d = view.to_dict() + assert "hook" not in d + + +# --------------------------------------------------------------------------- +# Content edge cases +# --------------------------------------------------------------------------- + + +class TestContentEdgeCases: + """Tests for content property edge cases (json fallbacks).""" + + def test_tool_call_non_serializable_args(self): + """Tool call with non-JSON-serializable arguments falls back to str().""" + tc = ToolCall(tool_call_id="tc1", name="test", arguments={"key": "val"}) + msg = Message(role=Role.ASSISTANT, content=[ToolCallContentPart(content=tc)]) + view = list(iter_views(msg))[0] + assert view.content is not None + + def test_prompt_request_content(self): + pr = PromptRequest( + prompt_request_id="pr1", name="test", + arguments={"key": "val"}, + ) + msg = Message(role=Role.USER, content=[PromptRequestContentPart(content=pr)]) + view = list(iter_views(msg))[0] + assert '"key"' in view.content + + def test_prompt_result_content(self): + pr = PromptResult( + prompt_request_id="pr1", prompt_name="test", + content="rendered text", + ) + msg = Message(role=Role.TOOL, content=[PromptResultContentPart(content=pr)]) + view = list(iter_views(msg))[0] + assert 
view.content == "rendered text" + + def test_resource_blob_size(self): + """Resource with blob but no content still reports size_bytes.""" + res = Resource( + resource_request_id="r1", uri="file:///a.bin", + resource_type=ResourceType.FILE, blob=b"\x00\x01\x02", + ) + msg = Message(role=Role.TOOL, content=[ResourceContentPart(content=res)]) + view = list(iter_views(msg))[0] + assert view.content is None + assert view.size_bytes == 3 + + def test_resource_explicit_size(self): + """Resource with explicit size_bytes uses that value.""" + res = Resource( + resource_request_id="r1", uri="file:///a.txt", + resource_type=ResourceType.FILE, content="hello", + size_bytes=999, + ) + msg = Message(role=Role.TOOL, content=[ResourceContentPart(content=res)]) + view = list(iter_views(msg))[0] + assert view.size_bytes == 999 + + def test_to_dict_no_content_with_blob_size(self): + """to_dict includes size_bytes even when content is None (blob path).""" + res = Resource( + resource_request_id="r1", uri="file:///a.bin", + resource_type=ResourceType.FILE, blob=b"\x00\x01", + ) + msg = Message(role=Role.TOOL, content=[ResourceContentPart(content=res)]) + view = list(iter_views(msg))[0] + d = view.to_dict() + assert "content" not in d + assert d["size_bytes"] == 2 + + +# --------------------------------------------------------------------------- +# Properties edge cases +# --------------------------------------------------------------------------- + + +class TestPropertiesEdgeCases: + """Tests for properties on various view kinds.""" + + def test_resource_properties(self): + res = Resource( + resource_request_id="r1", uri="file:///a.txt", + resource_type=ResourceType.FILE, content="hi", + version="v1", annotations={"label": "pii"}, + ) + msg = Message(role=Role.TOOL, content=[ResourceContentPart(content=res)]) + view = list(iter_views(msg))[0] + props = view.properties + assert props["resource_type"] == "file" + assert props["version"] == "v1" + assert props["annotations"] == {"label": 
"pii"} + + def test_tool_call_properties(self): + tc = ToolCall( + tool_call_id="tc1", name="test", + namespace="ns", arguments={}, + ) + msg = Message(role=Role.ASSISTANT, content=[ToolCallContentPart(content=tc)]) + view = list(iter_views(msg))[0] + props = view.properties + assert props["namespace"] == "ns" + assert props["tool_id"] == "tc1" + + def test_tool_result_properties(self): + tr = ToolResult( + tool_call_id="tc1", tool_name="test", + content="result", is_error=True, + ) + msg = Message(role=Role.TOOL, content=[ToolResultContentPart(content=tr)]) + view = list(iter_views(msg))[0] + props = view.properties + assert props["is_error"] is True + assert props["tool_name"] == "test" + + def test_prompt_request_properties(self): + pr = PromptRequest( + prompt_request_id="pr1", name="test", + server_id="srv1", + ) + msg = Message(role=Role.USER, content=[PromptRequestContentPart(content=pr)]) + view = list(iter_views(msg))[0] + props = view.properties + assert props["server_id"] == "srv1" + + +# --------------------------------------------------------------------------- +# Headers immutability +# --------------------------------------------------------------------------- + + +class TestHeadersImmutability: + """Tests that headers returns a read-only mapping.""" + + def test_headers_not_mutable(self): + ext = Extensions(http=HttpExtension(headers={"Authorization": "Bearer tok"})) + msg = Message( + role=Role.USER, + content=[TextContent(text="hi")], + extensions=ext, + ) + view = list(iter_views(msg))[0] + with pytest.raises(TypeError): + view.headers["new_key"] = "val" + + +# --------------------------------------------------------------------------- +# get_arg / has_arg +# --------------------------------------------------------------------------- + + +class TestArgHelpers: + """Tests for get_arg and has_arg.""" + + def test_get_arg_on_non_tool(self): + msg = Message(role=Role.USER, content=[TextContent(text="hi")]) + view = list(iter_views(msg))[0] + assert 
view.get_arg("anything") is None + assert view.get_arg("anything", "fallback") == "fallback" diff --git a/tests/unit/cpex/framework/extensions/__init__.py b/tests/unit/cpex/framework/extensions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/cpex/framework/extensions/test_extensions.py b/tests/unit/cpex/framework/extensions/test_extensions.py new file mode 100644 index 0000000..7c7e4fe --- /dev/null +++ b/tests/unit/cpex/framework/extensions/test_extensions.py @@ -0,0 +1,624 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/cpex/framework/extensions/test_extensions.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests for extension models. +""" + +# Standard +from typing import Any + +# Third-Party +import pytest + +# First-Party +from cpex.framework.extensions.agent import AgentExtension, ConversationContext +from cpex.framework.extensions.completion import ( + CompletionExtension, + StopReason, + TokenUsage, +) +from cpex.framework.extensions.extensions import Extensions +from cpex.framework.extensions.framework import FrameworkExtension +from cpex.framework.extensions.http import HttpExtension +from cpex.framework.extensions.llm import LLMExtension +from cpex.framework.extensions.mcp import ( + MCPExtension, + PromptMetadata, + ResourceMetadata, + ToolMetadata, +) +from cpex.framework.extensions.provenance import ProvenanceExtension +from cpex.framework.extensions.request import RequestExtension +from cpex.framework.extensions.security import ( + DataPolicy, + ObjectSecurityProfile, + RetentionPolicy, + SecurityExtension, + SubjectExtension, + SubjectType, +) + + +# --------------------------------------------------------------------------- +# RequestExtension +# --------------------------------------------------------------------------- + + +class TestRequestExtension: + """Tests for RequestExtension.""" + + def test_creation(self): + ext = RequestExtension( + environment="production", 
+ request_id="req-001", + timestamp="2025-01-15T10:30:00Z", + ) + assert ext.environment == "production" + assert ext.request_id == "req-001" + assert ext.timestamp == "2025-01-15T10:30:00Z" + + def test_defaults(self): + ext = RequestExtension() + assert ext.environment is None + assert ext.request_id is None + assert ext.timestamp is None + assert ext.trace_id is None + assert ext.span_id is None + + def test_frozen(self): + ext = RequestExtension(environment="dev") + with pytest.raises(Exception): + ext.environment = "production" + + def test_model_copy(self): + ext = RequestExtension(environment="dev") + updated = ext.model_copy(update={"environment": "production"}) + assert ext.environment == "dev" + assert updated.environment == "production" + + def test_tracing_fields(self): + ext = RequestExtension( + trace_id="trace-abc", + span_id="span-123", + ) + assert ext.trace_id == "trace-abc" + assert ext.span_id == "span-123" + + +# --------------------------------------------------------------------------- +# AgentExtension +# --------------------------------------------------------------------------- + + +class TestConversationContext: + """Tests for ConversationContext.""" + + def test_creation(self): + ctx = ConversationContext( + summary="User asked about revenue.", + topics=["revenue", "Q4"], + ) + assert ctx.summary == "User asked about revenue." + assert ctx.topics == ["revenue", "Q4"] + + def test_defaults(self): + ctx = ConversationContext() + assert ctx.history == [] + assert ctx.summary is None + assert ctx.topics == [] + + def test_frozen(self): + ctx = ConversationContext(summary="test") + with pytest.raises(Exception): + ctx.summary = "modified" + + +class TestAgentExtension: + """Tests for AgentExtension.""" + + def test_creation(self): + ext = AgentExtension( + input="What is the weather?", + session_id="sess-001", + conversation_id="conv-042", + turn=3, + agent_id="weather-agent", + ) + assert ext.input == "What is the weather?" 
+ assert ext.session_id == "sess-001" + assert ext.turn == 3 + + def test_defaults(self): + ext = AgentExtension() + assert ext.input is None + assert ext.session_id is None + assert ext.conversation_id is None + assert ext.turn is None + assert ext.agent_id is None + assert ext.parent_agent_id is None + assert ext.conversation is None + + def test_multi_agent_lineage(self): + ext = AgentExtension( + agent_id="sub-agent-01", + parent_agent_id="main-agent", + ) + assert ext.parent_agent_id == "main-agent" + + def test_with_conversation(self): + conv = ConversationContext(summary="Prior context", topics=["weather"]) + ext = AgentExtension(conversation=conv) + assert ext.conversation.summary == "Prior context" + + +# --------------------------------------------------------------------------- +# HttpExtension +# --------------------------------------------------------------------------- + + +class TestHttpExtension: + """Tests for HttpExtension.""" + + def test_creation(self): + ext = HttpExtension( + headers={"Content-Type": "application/json", "X-Request-ID": "req-001"}, + ) + assert ext.headers["Content-Type"] == "application/json" + + def test_defaults(self): + ext = HttpExtension() + assert ext.headers == {} + + def test_frozen(self): + ext = HttpExtension(headers={"X-Test": "value"}) + with pytest.raises(Exception): + ext.headers = {} + + def test_model_copy_add_header(self): + ext = HttpExtension(headers={"X-Test": "value"}) + updated = ext.model_copy( + update={"headers": {**ext.headers, "X-New": "added"}}, + ) + assert "X-New" in updated.headers + assert "X-New" not in ext.headers + + +# --------------------------------------------------------------------------- +# SecurityExtension +# --------------------------------------------------------------------------- + + +class TestSubjectType: + """Tests for SubjectType enum.""" + + def test_values(self): + assert SubjectType.USER.value == "user" + assert SubjectType.AGENT.value == "agent" + assert 
SubjectType.SERVICE.value == "service" + assert SubjectType.SYSTEM.value == "system" + + def test_member_count(self): + assert len(SubjectType) == 4 + + +class TestSubjectExtension: + """Tests for SubjectExtension.""" + + def test_creation(self): + subject = SubjectExtension( + id="user-alice", + type=SubjectType.USER, + roles=frozenset({"admin", "developer"}), + permissions=frozenset({"db.read", "tools.execute"}), + ) + assert subject.id == "user-alice" + assert subject.type == SubjectType.USER + assert "admin" in subject.roles + assert "db.read" in subject.permissions + + def test_defaults(self): + subject = SubjectExtension(id="svc-1", type=SubjectType.SERVICE) + assert subject.roles == frozenset() + assert subject.permissions == frozenset() + assert subject.teams == frozenset() + assert subject.claims == {} + + def test_frozen_sets(self): + subject = SubjectExtension( + id="test", + type=SubjectType.USER, + roles=frozenset({"admin"}), + ) + assert isinstance(subject.roles, frozenset) + assert isinstance(subject.permissions, frozenset) + assert isinstance(subject.teams, frozenset) + + +class TestObjectSecurityProfile: + """Tests for ObjectSecurityProfile.""" + + def test_creation(self): + profile = ObjectSecurityProfile( + managed_by="tool", + permissions=["read:compensation"], + trust_domain="internal", + data_scope=["salary", "bonus"], + ) + assert profile.managed_by == "tool" + assert "read:compensation" in profile.permissions + assert profile.trust_domain == "internal" + + def test_defaults(self): + profile = ObjectSecurityProfile() + assert profile.managed_by == "host" + assert profile.permissions == [] + assert profile.trust_domain is None + assert profile.data_scope == [] + + +class TestRetentionPolicy: + """Tests for RetentionPolicy.""" + + def test_creation(self): + ret = RetentionPolicy( + max_age_seconds=3600, + policy="session", + ) + assert ret.max_age_seconds == 3600 + assert ret.policy == "session" + + def test_defaults(self): + ret = 
RetentionPolicy() + assert ret.max_age_seconds is None + assert ret.policy == "persistent" + assert ret.delete_after is None + + +class TestDataPolicy: + """Tests for DataPolicy.""" + + def test_creation(self): + policy = DataPolicy( + apply_labels=["PII", "financial"], + denied_actions=["export", "forward"], + retention=RetentionPolicy(policy="session", max_age_seconds=7200), + ) + assert "PII" in policy.apply_labels + assert "export" in policy.denied_actions + assert policy.retention.policy == "session" + + def test_defaults(self): + policy = DataPolicy() + assert policy.apply_labels == [] + assert policy.allowed_actions is None + assert policy.denied_actions == [] + assert policy.retention is None + + def test_unrestricted_vs_restricted(self): + unrestricted = DataPolicy() + restricted = DataPolicy(allowed_actions=["view", "summarize"]) + assert unrestricted.allowed_actions is None + assert restricted.allowed_actions == ["view", "summarize"] + + +class TestSecurityExtension: + """Tests for SecurityExtension.""" + + def test_creation(self): + ext = SecurityExtension( + labels=frozenset({"PII", "CONFIDENTIAL"}), + classification="confidential", + subject=SubjectExtension(id="user-1", type=SubjectType.USER), + ) + assert "PII" in ext.labels + assert ext.classification == "confidential" + assert ext.subject.id == "user-1" + + def test_defaults(self): + ext = SecurityExtension() + assert ext.labels == frozenset() + assert ext.classification is None + assert ext.subject is None + assert ext.objects == {} + assert ext.data == {} + + def test_monotonic_label_addition(self): + ext = SecurityExtension(labels=frozenset({"PII"})) + updated = ext.model_copy( + update={"labels": ext.labels | frozenset({"CONFIDENTIAL"})}, + ) + assert "PII" in updated.labels + assert "CONFIDENTIAL" in updated.labels + assert ext.labels == frozenset({"PII"}) + + def test_labels_are_frozenset(self): + ext = SecurityExtension(labels=frozenset({"PII"})) + assert isinstance(ext.labels, frozenset) + 
+ def test_with_objects_and_data(self): + ext = SecurityExtension( + objects={ + "get_user": ObjectSecurityProfile( + managed_by="host", + permissions=["users.read"], + ), + }, + data={ + "get_user": DataPolicy( + apply_labels=["PII"], + denied_actions=["export"], + ), + }, + ) + assert ext.objects["get_user"].permissions == ["users.read"] + assert ext.data["get_user"].apply_labels == ["PII"] + + +# --------------------------------------------------------------------------- +# MCPExtension +# --------------------------------------------------------------------------- + + +class TestToolMetadata: + """Tests for ToolMetadata.""" + + def test_creation(self): + meta = ToolMetadata( + name="get_user", + description="Retrieve user by ID", + input_schema={"type": "object", "properties": {"id": {"type": "string"}}}, + ) + assert meta.name == "get_user" + assert meta.input_schema is not None + + def test_defaults(self): + meta = ToolMetadata(name="test") + assert meta.title is None + assert meta.description is None + assert meta.input_schema is None + assert meta.output_schema is None + assert meta.server_id is None + assert meta.namespace is None + assert meta.annotations == {} + + +class TestResourceMetadata: + """Tests for ResourceMetadata.""" + + def test_creation(self): + meta = ResourceMetadata( + uri="file:///data/report.csv", + name="Quarterly Report", + mime_type="text/csv", + ) + assert meta.uri == "file:///data/report.csv" + + +class TestPromptMetadata: + """Tests for PromptMetadata.""" + + def test_creation(self): + meta = PromptMetadata( + name="summarize", + arguments=[{"name": "text", "description": "Text to summarize", "required": True}], + ) + assert meta.name == "summarize" + assert meta.arguments[0]["name"] == "text" + + +class TestMCPExtension: + """Tests for MCPExtension.""" + + def test_with_tool(self): + ext = MCPExtension(tool=ToolMetadata(name="get_user")) + assert ext.tool.name == "get_user" + assert ext.resource is None + assert ext.prompt is None 
+ + def test_with_resource(self): + ext = MCPExtension(resource=ResourceMetadata(uri="file:///test")) + assert ext.resource.uri == "file:///test" + assert ext.tool is None + + def test_with_prompt(self): + ext = MCPExtension(prompt=PromptMetadata(name="summarize")) + assert ext.prompt.name == "summarize" + + def test_defaults(self): + ext = MCPExtension() + assert ext.tool is None + assert ext.resource is None + assert ext.prompt is None + + +# --------------------------------------------------------------------------- +# CompletionExtension +# --------------------------------------------------------------------------- + + +class TestStopReason: + """Tests for StopReason enum.""" + + def test_values(self): + assert StopReason.END.value == "end" + assert StopReason.MAX_TOKENS.value == "max_tokens" + + def test_member_count(self): + assert len(StopReason) == 5 + + +class TestTokenUsage: + """Tests for TokenUsage.""" + + def test_creation(self): + usage = TokenUsage(input_tokens=100, output_tokens=50, total_tokens=150) + assert usage.total_tokens == 150 + + +class TestCompletionExtension: + """Tests for CompletionExtension.""" + + def test_creation(self): + ext = CompletionExtension( + stop_reason=StopReason.END, + tokens=TokenUsage(input_tokens=100, output_tokens=50, total_tokens=150), + model="gpt-4o", + latency_ms=1200, + ) + assert ext.stop_reason == StopReason.END + assert ext.tokens.total_tokens == 150 + assert ext.model == "gpt-4o" + assert ext.latency_ms == 1200 + + def test_defaults(self): + ext = CompletionExtension() + assert ext.stop_reason is None + assert ext.tokens is None + assert ext.model is None + assert ext.raw_format is None + assert ext.created_at is None + assert ext.latency_ms is None + + +# --------------------------------------------------------------------------- +# ProvenanceExtension +# --------------------------------------------------------------------------- + + +class TestProvenanceExtension: + """Tests for ProvenanceExtension.""" + + 
def test_creation(self): + ext = ProvenanceExtension( + source="agent:weather-bot", + message_id="msg-001", + parent_id="msg-000", + ) + assert ext.source == "agent:weather-bot" + assert ext.message_id == "msg-001" + + def test_defaults(self): + ext = ProvenanceExtension() + assert ext.source is None + assert ext.message_id is None + assert ext.parent_id is None + + +# --------------------------------------------------------------------------- +# LLMExtension +# --------------------------------------------------------------------------- + + +class TestLLMExtension: + """Tests for LLMExtension.""" + + def test_creation(self): + ext = LLMExtension( + model_id="claude-sonnet-4-20250514", + provider="anthropic", + capabilities=["vision", "tool_use"], + ) + assert ext.provider == "anthropic" + assert "tool_use" in ext.capabilities + + def test_defaults(self): + ext = LLMExtension() + assert ext.model_id is None + assert ext.provider is None + assert ext.capabilities == [] + + +# --------------------------------------------------------------------------- +# FrameworkExtension +# --------------------------------------------------------------------------- + + +class TestFrameworkExtension: + """Tests for FrameworkExtension.""" + + def test_creation(self): + ext = FrameworkExtension( + framework="langgraph", + framework_version="0.2.0", + node_id="weather_node", + graph_id="travel_planner", + ) + assert ext.framework == "langgraph" + assert ext.node_id == "weather_node" + + def test_defaults(self): + ext = FrameworkExtension() + assert ext.framework is None + assert ext.framework_version is None + assert ext.node_id is None + assert ext.graph_id is None + assert ext.metadata == {} + + +# --------------------------------------------------------------------------- +# Extensions Container +# --------------------------------------------------------------------------- + + +class TestExtensions: + """Tests for the Extensions container.""" + + def test_all_none_by_default(self): + 
ext = Extensions() + assert ext.request is None + assert ext.agent is None + assert ext.http is None + assert ext.security is None + assert ext.mcp is None + assert ext.completion is None + assert ext.provenance is None + assert ext.llm is None + assert ext.framework is None + assert ext.custom is None + + def test_frozen(self): + ext = Extensions() + with pytest.raises(Exception): + ext.request = RequestExtension() + + def test_model_copy(self): + ext = Extensions( + request=RequestExtension(environment="dev"), + ) + updated = ext.model_copy( + update={"request": RequestExtension(environment="production")}, + ) + assert ext.request.environment == "dev" + assert updated.request.environment == "production" + + def test_full_construction(self): + ext = Extensions( + request=RequestExtension(environment="production", request_id="req-001"), + agent=AgentExtension(input="Hello", session_id="sess-001"), + http=HttpExtension(headers={"X-Test": "value"}), + security=SecurityExtension(labels=frozenset({"PII"})), + mcp=MCPExtension(tool=ToolMetadata(name="get_user")), + completion=CompletionExtension(stop_reason=StopReason.END), + provenance=ProvenanceExtension(source="user"), + llm=LLMExtension(model_id="gpt-4o", provider="openai"), + framework=FrameworkExtension(framework="langgraph"), + custom={"debug": True}, + ) + assert ext.request.environment == "production" + assert ext.agent.input == "Hello" + assert ext.http.headers["X-Test"] == "value" + assert "PII" in ext.security.labels + assert ext.mcp.tool.name == "get_user" + assert ext.completion.stop_reason == StopReason.END + assert ext.provenance.source == "user" + assert ext.llm.provider == "openai" + assert ext.framework.framework == "langgraph" + assert ext.custom["debug"] is True + + def test_custom_extension(self): + ext = Extensions(custom={"key": "value", "nested": {"a": 1}}) + assert ext.custom["key"] == "value" + assert ext.custom["nested"]["a"] == 1 diff --git 
a/tests/unit/cpex/framework/extensions/test_tiers.py b/tests/unit/cpex/framework/extensions/test_tiers.py new file mode 100644 index 0000000..b640002 --- /dev/null +++ b/tests/unit/cpex/framework/extensions/test_tiers.py @@ -0,0 +1,748 @@ +# -*- coding: utf-8 -*- +"""Tests for cpex.framework.extensions.tiers module. + +Covers mutability tiers, capability gating, extension filtering, +tier constraint validation, and lockdown (private registry, frozen config). +""" + +# Standard +from __future__ import annotations + +# Third-Party +import pytest + +# First-Party +from cpex.framework.extensions.agent import AgentExtension +from cpex.framework.extensions.extensions import Extensions +from cpex.framework.extensions.http import HttpExtension +from cpex.framework.extensions.request import RequestExtension +from cpex.framework.extensions.security import ( + SecurityExtension, + SubjectExtension, + SubjectType, +) +from cpex.framework.extensions.constants import SlotName +from cpex.framework.extensions.tiers import ( + AccessPolicy, + Capability, + MutabilityTier, + SlotPolicy, + TierViolationError, + _slot_registry, + filter_extensions, + merge_extensions, + validate_tier_constraints, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def subject(): + return SubjectExtension( + id="user-alice", + type=SubjectType.USER, + roles=frozenset({"admin", "developer"}), + teams=frozenset({"platform"}), + claims={"sub": "alice", "iss": "idp"}, + permissions=frozenset({"tools.execute", "db.read"}), + ) + + +@pytest.fixture() +def security_ext(subject): + return SecurityExtension( + labels=frozenset({"PII", "CONFIDENTIAL"}), + classification="confidential", + subject=subject, + ) + + +@pytest.fixture() +def full_extensions(security_ext): + return Extensions( + request=RequestExtension(environment="test", request_id="req-1"), + 
agent=AgentExtension(session_id="sess-1"), + http=HttpExtension(headers={"Authorization": "Bearer tok"}), + security=security_ext, + custom={"key": "value"}, + ) + + +# --------------------------------------------------------------------------- +# Enum / SlotPolicy basics +# --------------------------------------------------------------------------- + + +class TestMutabilityTier: + def test_values(self): + assert MutabilityTier.IMMUTABLE.value == "immutable" + assert MutabilityTier.MONOTONIC.value == "monotonic" + assert MutabilityTier.MUTABLE.value == "mutable" + + def test_string_enum(self): + assert MutabilityTier("immutable") == MutabilityTier.IMMUTABLE + + +class TestAccessPolicy: + def test_values(self): + assert AccessPolicy.UNRESTRICTED.value == "unrestricted" + assert AccessPolicy.CAPABILITY_GATED.value == "capability_gated" + + def test_string_enum(self): + assert AccessPolicy("unrestricted") == AccessPolicy.UNRESTRICTED + + +class TestCapability: + def test_all_values(self): + assert len(Capability) == 10 # noqa: PLR2004 + assert Capability.READ_SUBJECT.value == "read_subject" + assert Capability.APPEND_LABELS.value == "append_labels" + assert Capability.WRITE_HEADERS.value == "write_headers" + + def test_string_enum(self): + assert Capability("read_agent") == Capability.READ_AGENT + + +class TestSlotPolicy: + def test_frozen(self): + policy = SlotPolicy(MutabilityTier.IMMUTABLE) + with pytest.raises(AttributeError): + policy.tier = MutabilityTier.MUTABLE # type: ignore[misc] + + def test_defaults(self): + policy = SlotPolicy(MutabilityTier.IMMUTABLE) + assert policy.access == AccessPolicy.UNRESTRICTED + assert policy.read_cap is None + assert policy.write_cap is None + + def test_capability_gated(self): + policy = SlotPolicy( + MutabilityTier.IMMUTABLE, + access=AccessPolicy.CAPABILITY_GATED, + read_cap=Capability.READ_AGENT, + ) + assert policy.access == AccessPolicy.CAPABILITY_GATED + assert policy.read_cap == Capability.READ_AGENT + + +class 
TestSlotRegistry: + def test_base_tier_slots_unrestricted(self): + for slot in ( + SlotName.REQUEST, + SlotName.PROVENANCE, + SlotName.COMPLETION, + SlotName.LLM, + SlotName.FRAMEWORK, + SlotName.MCP, + ): + policy = _slot_registry[slot] + assert policy.tier == MutabilityTier.IMMUTABLE + assert policy.access == AccessPolicy.UNRESTRICTED + assert policy.read_cap is None + assert policy.write_cap is None + + def test_agent_capability_gated(self): + policy = _slot_registry[SlotName.AGENT] + assert policy.access == AccessPolicy.CAPABILITY_GATED + assert policy.read_cap == Capability.READ_AGENT + assert policy.write_cap is None + + def test_http_capability_gated(self): + policy = _slot_registry[SlotName.HTTP] + assert policy.access == AccessPolicy.CAPABILITY_GATED + assert policy.read_cap == Capability.READ_HEADERS + assert policy.write_cap == Capability.WRITE_HEADERS + + def test_labels_monotonic_capability_gated(self): + policy = _slot_registry[SlotName.SECURITY_LABELS] + assert policy.tier == MutabilityTier.MONOTONIC + assert policy.access == AccessPolicy.CAPABILITY_GATED + assert policy.read_cap == Capability.READ_LABELS + assert policy.write_cap == Capability.APPEND_LABELS + + def test_custom_mutable_unrestricted(self): + policy = _slot_registry[SlotName.CUSTOM] + assert policy.tier == MutabilityTier.MUTABLE + assert policy.access == AccessPolicy.UNRESTRICTED + + def test_security_objects_unrestricted(self): + policy = _slot_registry[SlotName.SECURITY_OBJECTS] + assert policy.access == AccessPolicy.UNRESTRICTED + + def test_security_data_unrestricted(self): + policy = _slot_registry[SlotName.SECURITY_DATA] + assert policy.access == AccessPolicy.UNRESTRICTED + + def test_subject_subfields_capability_gated(self): + for slot in ( + SlotName.SECURITY_SUBJECT, + SlotName.SECURITY_SUBJECT_ROLES, + SlotName.SECURITY_SUBJECT_TEAMS, + SlotName.SECURITY_SUBJECT_CLAIMS, + SlotName.SECURITY_SUBJECT_PERMISSIONS, + ): + policy = _slot_registry[slot] + assert policy.access == 
AccessPolicy.CAPABILITY_GATED, f"{slot} should be capability-gated" + + def test_registry_is_read_only(self): + with pytest.raises(TypeError): + _slot_registry[SlotName.CUSTOM] = SlotPolicy(MutabilityTier.IMMUTABLE) # type: ignore[index] + + def test_registry_not_in_public_exports(self): + import cpex.framework.extensions as ext_pkg + + assert "SLOT_REGISTRY" not in ext_pkg.__all__ + assert "filter_extensions" not in ext_pkg.__all__ + assert "validate_tier_constraints" not in ext_pkg.__all__ + assert "SlotPolicy" not in ext_pkg.__all__ + + +# --------------------------------------------------------------------------- +# filter_extensions +# --------------------------------------------------------------------------- + + +class TestFilterExtensions: + def test_none_input(self): + assert filter_extensions(None, frozenset()) is None + + def test_no_capabilities_hides_gated_slots(self, full_extensions): + filtered = filter_extensions(full_extensions, frozenset()) + # Unrestricted slots pass through + assert filtered.request is not None + assert filtered.custom is not None + # Capability-gated slots hidden + assert filtered.agent is None + assert filtered.http is None + # Security sub-fields: subject hidden, labels hidden + assert filtered.security is not None + assert filtered.security.subject is None + assert filtered.security.labels == frozenset() + + def test_read_agent_makes_agent_visible(self, full_extensions): + caps = frozenset({"read_agent"}) + filtered = filter_extensions(full_extensions, caps) + assert filtered.agent is not None + assert filtered.agent.session_id == "sess-1" + + def test_read_headers_makes_http_visible(self, full_extensions): + caps = frozenset({"read_headers"}) + filtered = filter_extensions(full_extensions, caps) + assert filtered.http is not None + assert filtered.http.headers["Authorization"] == "Bearer tok" + + def test_write_headers_implies_read(self, full_extensions): + caps = frozenset({"write_headers"}) + filtered = 
filter_extensions(full_extensions, caps) + assert filtered.http is not None + + def test_append_labels_implies_read(self, full_extensions): + caps = frozenset({"append_labels"}) + filtered = filter_extensions(full_extensions, caps) + assert filtered.security.labels == frozenset({"PII", "CONFIDENTIAL"}) + + def test_read_labels_makes_labels_visible(self, full_extensions): + caps = frozenset({"read_labels"}) + filtered = filter_extensions(full_extensions, caps) + assert filtered.security.labels == frozenset({"PII", "CONFIDENTIAL"}) + + def test_no_filtering_returns_equal_object(self): + ext = Extensions(request=RequestExtension(environment="test", request_id="r1")) + result = filter_extensions(ext, frozenset()) + assert result == ext # Build-up always creates a new frozen instance + assert result is not ext + + def test_ungated_security_subfields_pass_through(self, full_extensions): + """security.objects and security.data are always visible.""" + filtered = filter_extensions(full_extensions, frozenset()) + assert filtered.security.objects == full_extensions.security.objects + assert filtered.security.data == full_extensions.security.data + + +class TestFilterSubjectGranular: + """Subject sub-field filtering: roles, teams, claims, permissions gated independently.""" + + def test_read_subject_only_hides_subfields(self, full_extensions): + caps = frozenset({"read_subject"}) + filtered = filter_extensions(full_extensions, caps) + subj = filtered.security.subject + assert subj is not None + assert subj.id == "user-alice" + assert subj.type == SubjectType.USER + # Sub-fields hidden + assert subj.roles == frozenset() + assert subj.teams == frozenset() + assert subj.claims == {} + assert subj.permissions == frozenset() + + def test_read_roles_implies_read_subject(self, full_extensions): + caps = frozenset({"read_roles"}) + filtered = filter_extensions(full_extensions, caps) + subj = filtered.security.subject + assert subj is not None + assert subj.id == "user-alice" + assert 
"admin" in subj.roles + # Other sub-fields still hidden + assert subj.teams == frozenset() + assert subj.claims == {} + assert subj.permissions == frozenset() + + def test_read_teams_implies_read_subject(self, full_extensions): + caps = frozenset({"read_teams"}) + filtered = filter_extensions(full_extensions, caps) + subj = filtered.security.subject + assert subj is not None + assert "platform" in subj.teams + assert subj.roles == frozenset() + + def test_read_claims_implies_read_subject(self, full_extensions): + caps = frozenset({"read_claims"}) + filtered = filter_extensions(full_extensions, caps) + subj = filtered.security.subject + assert subj is not None + assert subj.claims == {"sub": "alice", "iss": "idp"} + assert subj.roles == frozenset() + + def test_read_permissions_implies_read_subject(self, full_extensions): + caps = frozenset({"read_permissions"}) + filtered = filter_extensions(full_extensions, caps) + subj = filtered.security.subject + assert subj is not None + assert "tools.execute" in subj.permissions + assert subj.roles == frozenset() + + def test_multiple_subject_caps(self, full_extensions): + caps = frozenset({"read_roles", "read_permissions"}) + filtered = filter_extensions(full_extensions, caps) + subj = filtered.security.subject + assert "admin" in subj.roles + assert "tools.execute" in subj.permissions + assert subj.teams == frozenset() + assert subj.claims == {} + + def test_no_subject_extension_no_error(self): + ext = Extensions( + security=SecurityExtension(labels=frozenset({"PII"})), + ) + filtered = filter_extensions(ext, frozenset({"read_labels"})) + assert filtered.security.labels == frozenset({"PII"}) + assert filtered.security.subject is None + + +# --------------------------------------------------------------------------- +# validate_tier_constraints +# --------------------------------------------------------------------------- + + +class TestValidateTierConstraints: + def test_both_none(self): + validate_tier_constraints(None, 
None, frozenset(), "test-plugin") + + def test_no_change_passes(self, full_extensions): + validate_tier_constraints( + full_extensions, full_extensions, frozenset(), "test-plugin" + ) + + def test_immutable_no_write_cap_rejects_change(self): + before = Extensions( + request=RequestExtension(environment="prod", request_id="r1"), + ) + after = Extensions( + request=RequestExtension(environment="staging", request_id="r1"), + ) + with pytest.raises(TierViolationError) as exc_info: + validate_tier_constraints(before, after, frozenset(), "bad-plugin") + assert exc_info.value.plugin_name == "bad-plugin" + assert exc_info.value.slot == SlotName.REQUEST + assert exc_info.value.tier == MutabilityTier.IMMUTABLE + + def test_immutable_gated_rejects_without_cap(self): + before = Extensions( + http=HttpExtension(headers={"X-Foo": "bar"}), + ) + after = Extensions( + http=HttpExtension(headers={"X-Foo": "baz"}), + ) + with pytest.raises(TierViolationError) as exc_info: + validate_tier_constraints(before, after, frozenset(), "bad-plugin") + assert "write_headers" in exc_info.value.detail + + def test_immutable_gated_allows_with_write_cap(self): + before = Extensions( + http=HttpExtension(headers={"X-Foo": "bar"}), + ) + after = Extensions( + http=HttpExtension(headers={"X-Foo": "baz"}), + ) + caps = frozenset({"write_headers"}) + # Should not raise + validate_tier_constraints(before, after, caps, "good-plugin") + + def test_monotonic_superset_passes(self): + before = Extensions( + security=SecurityExtension(labels=frozenset({"PII"})), + ) + after = Extensions( + security=SecurityExtension(labels=frozenset({"PII", "CONFIDENTIAL"})), + ) + caps = frozenset({"append_labels"}) + validate_tier_constraints(before, after, caps, "good-plugin") + + def test_monotonic_removal_rejects(self): + before = Extensions( + security=SecurityExtension(labels=frozenset({"PII", "CONFIDENTIAL"})), + ) + after = Extensions( + security=SecurityExtension(labels=frozenset({"PII"})), + ) + caps = 
frozenset({"append_labels"}) + with pytest.raises(TierViolationError) as exc_info: + validate_tier_constraints(before, after, caps, "bad-plugin") + assert "monotonic" in str(exc_info.value) + assert exc_info.value.tier == MutabilityTier.MONOTONIC + + def test_monotonic_without_cap_rejects(self): + before = Extensions( + security=SecurityExtension(labels=frozenset({"PII"})), + ) + after = Extensions( + security=SecurityExtension(labels=frozenset({"PII", "SECRET"})), + ) + with pytest.raises(TierViolationError) as exc_info: + validate_tier_constraints(before, after, frozenset(), "bad-plugin") + assert "append_labels" in exc_info.value.detail + + def test_mutable_allows_any_change(self): + before = Extensions(custom={"key": "value"}) + after = Extensions(custom={"key": "changed", "new": "stuff"}) + validate_tier_constraints(before, after, frozenset(), "plugin") + + def test_mutable_allows_deletion(self): + before = Extensions(custom={"key": "value"}) + after = Extensions(custom=None) + validate_tier_constraints(before, after, frozenset(), "plugin") + + +class TestTierViolationError: + def test_attributes(self): + err = TierViolationError("my-plugin", SlotName.REQUEST, MutabilityTier.IMMUTABLE, "changed") + assert err.plugin_name == "my-plugin" + assert err.slot == SlotName.REQUEST + assert err.tier == MutabilityTier.IMMUTABLE + assert err.detail == "changed" + + def test_message(self): + err = TierViolationError("p", SlotName.REQUEST, MutabilityTier.IMMUTABLE, "nope") + assert "p" in str(err) + assert "immutable" in str(err) + assert "request" in str(err) + assert "nope" in str(err) + + +# --------------------------------------------------------------------------- +# merge_extensions +# --------------------------------------------------------------------------- + + +class TestMergeExtensions: + def test_none_original_returns_none(self): + output = Extensions(custom={"key": "val"}) + assert merge_extensions(None, output, frozenset(), "p") is None + + def 
test_none_output_returns_original(self): + original = Extensions(custom={"key": "val"}) + assert merge_extensions(original, None, frozenset(), "p") is original + + def test_no_changes_returns_original(self): + original = Extensions( + request=RequestExtension(environment="prod", request_id="r1"), + custom={"key": "val"}, + ) + output = original.model_copy() + result = merge_extensions(original, output, frozenset(), "p") + assert result is original + + def test_immutable_slots_ignored(self): + original = Extensions( + request=RequestExtension(environment="prod", request_id="r1"), + ) + output = Extensions( + request=RequestExtension(environment="staging", request_id="r1"), + ) + result = merge_extensions(original, output, frozenset(), "p") + assert result is original + assert result.request.environment == "prod" + + def test_immutable_agent_ignored(self): + original = Extensions( + agent=AgentExtension(agent_id="a1", session_id="s1"), + ) + output = Extensions( + agent=AgentExtension(agent_id="hijacked", session_id="s1"), + ) + result = merge_extensions(original, output, frozenset({"read_agent"}), "p") + assert result is original + assert result.agent.agent_id == "a1" + + def test_custom_accepted_without_cap(self): + original = Extensions(custom={"key": "val"}) + output = Extensions(custom={"key": "changed", "new": "stuff"}) + result = merge_extensions(original, output, frozenset(), "p") + assert result.custom == {"key": "changed", "new": "stuff"} + # Immutable slots unchanged + assert result.request is None + + def test_custom_deletion_accepted(self): + original = Extensions(custom={"key": "val"}) + output = Extensions(custom=None) + result = merge_extensions(original, output, frozenset(), "p") + assert result.custom is None + + def test_http_accepted_with_write_cap(self): + original = Extensions( + http=HttpExtension(headers={"X-Foo": "bar"}), + ) + output = Extensions( + http=HttpExtension(headers={"X-Foo": "baz"}), + ) + caps = frozenset({"write_headers"}) + 
result = merge_extensions(original, output, caps, "p") + assert result.http.headers == {"X-Foo": "baz"} + + def test_http_ignored_without_write_cap(self): + original = Extensions( + http=HttpExtension(headers={"X-Foo": "bar"}), + ) + output = Extensions( + http=HttpExtension(headers={"X-Foo": "baz"}), + ) + result = merge_extensions(original, output, frozenset({"read_headers"}), "p") + assert result is original + assert result.http.headers == {"X-Foo": "bar"} + + def test_labels_accepted_with_append_cap(self): + original = Extensions( + security=SecurityExtension(labels=frozenset({"PII"})), + ) + output = Extensions( + security=SecurityExtension(labels=frozenset({"PII", "CONFIDENTIAL"})), + ) + caps = frozenset({"append_labels"}) + result = merge_extensions(original, output, caps, "p") + assert result.security.labels == frozenset({"PII", "CONFIDENTIAL"}) + + def test_labels_ignored_without_cap(self): + original = Extensions( + security=SecurityExtension(labels=frozenset({"PII"})), + ) + output = Extensions( + security=SecurityExtension(labels=frozenset({"PII", "CONFIDENTIAL"})), + ) + result = merge_extensions(original, output, frozenset(), "p") + assert result is original + assert result.security.labels == frozenset({"PII"}) + + def test_labels_removal_rejected(self): + original = Extensions( + security=SecurityExtension(labels=frozenset({"PII", "CONFIDENTIAL"})), + ) + output = Extensions( + security=SecurityExtension(labels=frozenset({"PII"})), + ) + caps = frozenset({"append_labels"}) + with pytest.raises(TierViolationError) as exc_info: + merge_extensions(original, output, caps, "bad-plugin") + assert exc_info.value.tier == MutabilityTier.MONOTONIC + + def test_security_subject_ignored(self): + """Subject is immutable — plugin changes are discarded.""" + original = Extensions( + security=SecurityExtension( + subject=SubjectExtension( + id="alice", type=SubjectType.USER, roles=frozenset({"admin"}), + ), + ), + ) + output = Extensions( + 
security=SecurityExtension( + subject=SubjectExtension( + id="eve", type=SubjectType.USER, roles=frozenset({"root"}), + ), + ), + ) + caps = frozenset({"read_subject", "read_roles"}) + result = merge_extensions(original, output, caps, "p") + assert result is original + assert result.security.subject.id == "alice" + + def test_mixed_changes(self, full_extensions): + """Only writable slots are accepted in a single merge.""" + output = full_extensions.model_copy(update={ + # Immutable — should be ignored + "request": RequestExtension(environment="hijacked", request_id="r1"), + # Mutable — should be accepted + "custom": {"injected": True}, + }) + result = merge_extensions(full_extensions, output, frozenset(), "p") + assert result.request.environment == full_extensions.request.environment + assert result.custom == {"injected": True} + + +# --------------------------------------------------------------------------- +# PluginConfig capabilities and frozen lockdown +# --------------------------------------------------------------------------- + + +class TestPluginConfigCapabilities: + _PLUGIN_BASE = {"name": "test-plugin", "kind": "test.Plugin"} + + def test_valid_capabilities(self): + from cpex.framework.models import PluginConfig + + conf = PluginConfig( + **self._PLUGIN_BASE, + capabilities=["read_headers", "append_labels"], + ) + assert conf.capabilities == frozenset({"read_headers", "append_labels"}) + + def test_unknown_capability_rejected(self): + from cpex.framework.models import PluginConfig + + with pytest.raises(ValueError, match="Unknown capability"): + PluginConfig(**self._PLUGIN_BASE, capabilities=["bogus_cap"]) + + def test_empty_capabilities_default(self): + from cpex.framework.models import PluginConfig + + conf = PluginConfig(**self._PLUGIN_BASE) + assert conf.capabilities == frozenset() + + def test_capabilities_serialization(self): + import orjson + + from cpex.framework.models import PluginConfig + + conf = PluginConfig( + **self._PLUGIN_BASE, + 
capabilities=["read_agent", "read_headers"], + ) + data = orjson.loads(orjson.dumps(conf.model_dump())) + assert sorted(data["capabilities"]) == ["read_agent", "read_headers"] + + def test_frozen_config_prevents_capability_escalation(self): + from pydantic import ValidationError + + from cpex.framework.models import PluginConfig + + conf = PluginConfig(**self._PLUGIN_BASE) + with pytest.raises(ValidationError): + conf.capabilities = frozenset({"write_headers"}) # type: ignore[misc] + + def test_frozen_config_prevents_field_mutation(self): + from pydantic import ValidationError + + from cpex.framework.models import PluginConfig + + conf = PluginConfig(**self._PLUGIN_BASE) + with pytest.raises(ValidationError): + conf.name = "hijacked" # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# Defensive config copy and PluginRef trusted config +# --------------------------------------------------------------------------- + + +class TestDefensiveConfigCopy: + """Verify the trust boundary between Manager and plugins.""" + + _PLUGIN_BASE = {"name": "copy-test", "kind": "test.Plugin"} + + def test_plugin_ref_trusted_config_is_separate_from_plugin(self): + """PluginRef's trusted_config should be a different object than the plugin's config.""" + from cpex.framework.base import Plugin, PluginRef + from cpex.framework.models import PluginConfig + + original = PluginConfig(**self._PLUGIN_BASE, capabilities=["read_headers"]) + copy = original.model_copy() + plugin = Plugin(copy) + ref = PluginRef(plugin, trusted_config=original) + + # The plugin holds the copy, the ref holds the original + assert ref.trusted_config is original + assert plugin.config is copy + assert ref.trusted_config is not plugin.config + + def test_plugin_ref_reads_from_trusted_config(self): + """PluginRef properties should read from trusted_config, not from the plugin.""" + from cpex.framework.base import Plugin, PluginRef + from cpex.framework.models import 
PluginConfig, PluginMode + + original = PluginConfig( + **self._PLUGIN_BASE, + mode=PluginMode.CONCURRENT, + priority=42, + tags=["trusted"], + capabilities=["read_headers"], + ) + # Give the plugin a different copy with different values + plugin_copy = PluginConfig( + name="copy-test", + kind="test.Plugin", + mode=PluginMode.SEQUENTIAL, + priority=99, + tags=["untrusted"], + ) + plugin = Plugin(plugin_copy) + ref = PluginRef(plugin, trusted_config=original) + + # All properties come from trusted_config + assert ref.mode == PluginMode.CONCURRENT + assert ref.priority == 42 + assert ref.tags == ["trusted"] + assert ref.capabilities == frozenset({"read_headers"}) + + def test_plugin_ref_fallback_without_trusted_config(self): + """Without trusted_config, PluginRef falls back to plugin.config.""" + from cpex.framework.base import Plugin, PluginRef + from cpex.framework.models import PluginConfig + + config = PluginConfig(**self._PLUGIN_BASE) + plugin = Plugin(config) + ref = PluginRef(plugin) + + assert ref.trusted_config is plugin.config + + def test_model_copy_produces_equal_but_distinct_config(self): + """model_copy() should produce an equal but distinct PluginConfig.""" + from cpex.framework.models import PluginConfig + + original = PluginConfig(**self._PLUGIN_BASE, capabilities=["append_labels"]) + copy = original.model_copy() + + assert copy == original + assert copy is not original + assert copy.capabilities == original.capabilities + + def test_registry_passes_trusted_config_to_ref(self): + """PluginInstanceRegistry.register() should pass trusted_config to PluginRef.""" + from cpex.framework.base import Plugin + from cpex.framework.models import PluginConfig + from cpex.framework.registry import PluginInstanceRegistry + + original = PluginConfig(**self._PLUGIN_BASE, capabilities=["read_headers"]) + copy = original.model_copy() + plugin = Plugin(copy) + + registry = PluginInstanceRegistry() + registry.register(plugin, trusted_config=original) + + ref = 
registry.get_plugin("copy-test") + assert ref is not None + assert ref.trusted_config is original + assert ref.plugin.config is copy + assert ref.trusted_config is not ref.plugin.config diff --git a/tests/unit/cpex/framework/external/grpc/test_client.py b/tests/unit/cpex/framework/external/grpc/test_client.py index 4bc8c4f..ffe038c 100644 --- a/tests/unit/cpex/framework/external/grpc/test_client.py +++ b/tests/unit/cpex/framework/external/grpc/test_client.py @@ -13,6 +13,7 @@ # Third-Party import pytest +from pydantic import ValidationError # First-Party from cpex.framework import PluginError, ToolPreInvokePayload @@ -84,20 +85,14 @@ def test_init_stores_config(self, mock_plugin_config): class TestGrpcExternalPluginInitialize: """Tests for GrpcExternalPlugin.initialize().""" - @pytest.mark.asyncio - async def test_initialize_missing_grpc_config(self): - """Test initialize raises PluginError when grpc config is missing.""" - config = PluginConfig( - name="TestPlugin", - kind="external", - hooks=["tool_pre_invoke"], - grpc=GRPCClientConfig(target="localhost:50051"), - ) - plugin = GrpcExternalPlugin(config) - plugin._config.grpc = None # Remove grpc config - - with pytest.raises(PluginError, match="grpc section must be defined"): - await plugin.initialize() + def test_initialize_missing_grpc_config(self): + """Test PluginConfig validation rejects external plugin without transport config.""" + with pytest.raises(ValidationError, match="External plugin.*must have"): + PluginConfig( + name="TestPlugin", + kind="external", + hooks=["tool_pre_invoke"], + ) @pytest.mark.asyncio async def test_initialize_creates_channel(self, mock_plugin_config): diff --git a/tests/unit/cpex/framework/external/mcp/test_client_config.py b/tests/unit/cpex/framework/external/mcp/test_client_config.py index 45053d6..5b05a35 100644 --- a/tests/unit/cpex/framework/external/mcp/test_client_config.py +++ b/tests/unit/cpex/framework/external/mcp/test_client_config.py @@ -54,9 +54,9 @@ async def 
test_initialize_missing_mcp_config(): config = ConfigLoader.load_config("tests/unit/cpex/fixtures/configs/valid_stdio_external_plugin.yaml") plugin_config = config.plugins[0] - # Create plugin and temporarily set mcp to None - plugin = ExternalPlugin(plugin_config) - plugin._config.mcp = None + # Create plugin with mcp removed via frozen-safe copy + no_mcp_config = plugin_config.model_copy(update={"mcp": None}) + plugin = ExternalPlugin(no_mcp_config) with pytest.raises(PluginError, match="The mcp section must be defined for external plugin"): await plugin.initialize() @@ -67,10 +67,11 @@ async def test_initialize_stdio_missing_script(): """Test initialize raises ValueError for missing stdio script.""" config = ConfigLoader.load_config("tests/unit/cpex/fixtures/configs/valid_stdio_external_plugin.yaml") plugin_config = config.plugins[0] - plugin = ExternalPlugin(plugin_config) - # Mock the script path to be missing - plugin._config.mcp.script = "/path/to/missing.sh" + # Create plugin with missing script via frozen-safe copy + bad_mcp = plugin_config.mcp.model_copy(update={"script": "/path/to/missing.sh"}) + bad_config = plugin_config.model_copy(update={"mcp": bad_mcp}) + plugin = ExternalPlugin(bad_config) # Cross-platform: Windows uses backslashes, Unix uses forward slashes with pytest.raises(PluginError, match=r"Server script .+[/\\]missing\.sh does not exist\."): diff --git a/tests/unit/cpex/framework/external/unix/test_client.py b/tests/unit/cpex/framework/external/unix/test_client.py index 0c8370f..474c0e2 100644 --- a/tests/unit/cpex/framework/external/unix/test_client.py +++ b/tests/unit/cpex/framework/external/unix/test_client.py @@ -14,6 +14,7 @@ # Third-Party import pytest +from pydantic import ValidationError # First-Party from cpex.framework import ToolPreInvokePayload @@ -79,19 +80,13 @@ def test_init_stores_socket_config(self, mock_plugin_config): assert plugin._reconnect_delay == 0.1 def test_init_missing_unix_socket_config(self): - """Test init 
raises PluginError when unix_socket config is missing.""" - config = PluginConfig( - name="TestPlugin", - kind="external", - hooks=["tool_pre_invoke"], - unix_socket=UnixSocketClientConfig(path="/tmp/test.sock"), - ) - plugin = UnixSocketExternalPlugin(config) - plugin._config.unix_socket = None - - with pytest.raises(PluginError, match="unix_socket section must be defined"): - # Re-initialize to trigger the check - UnixSocketExternalPlugin.__init__(plugin, config) + """Test PluginConfig validation rejects external plugin without transport config.""" + with pytest.raises(ValidationError, match="External plugin.*must have"): + PluginConfig( + name="TestPlugin", + kind="external", + hooks=["tool_pre_invoke"], + ) class TestUnixSocketExternalPluginConnected: diff --git a/tests/unit/cpex/framework/hooks/test_message.py b/tests/unit/cpex/framework/hooks/test_message.py new file mode 100644 index 0000000..a365221 --- /dev/null +++ b/tests/unit/cpex/framework/hooks/test_message.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +"""Location: ./tests/unit/cpex/framework/hooks/test_message.py +Copyright 2025 +SPDX-License-Identifier: Apache-2.0 +Authors: Teryl Taylor + +Unit tests for message evaluation hook definitions. 
+""" + +# Third-Party +import pytest + +# First-Party +from cpex.framework.cmf.message import Message, Role, TextContent +from cpex.framework.hooks.message import ( + MessageHookType, + MessagePayload, + MessageResult, +) +from cpex.framework.hooks.registry import get_hook_registry +from cpex.framework.models import PluginPayload, PluginResult + + +# --------------------------------------------------------------------------- +# MessageHookType Tests +# --------------------------------------------------------------------------- + + +class TestMessageHookType: + """Tests for the MessageHookType enum.""" + + def test_evaluate_value(self): + assert MessageHookType.EVALUATE.value == "evaluate" + + def test_from_string(self): + assert MessageHookType("evaluate") == MessageHookType.EVALUATE + + def test_invalid_value(self): + with pytest.raises(ValueError): + MessageHookType("invalid") + + def test_member_count(self): + assert len(MessageHookType) == 9 + + def test_is_str_enum(self): + assert isinstance(MessageHookType.EVALUATE, str) + assert MessageHookType.EVALUATE == "evaluate" + + +# --------------------------------------------------------------------------- +# MessagePayload Tests +# --------------------------------------------------------------------------- + + +class TestMessagePayload: + """Tests for the MessagePayload model.""" + + def test_subclass_of_plugin_payload(self): + assert issubclass(MessagePayload, PluginPayload) + + def test_creation(self): + msg = Message( + role=Role.USER, + content=[TextContent(text="Hello")], + ) + payload = MessagePayload(message=msg) + assert payload.message is msg + assert payload.message.role == Role.USER + assert payload.message.content[0].text == "Hello" + + def test_message_field_required(self): + with pytest.raises(Exception): + MessagePayload() + + def test_with_multi_part_message(self): + msg = Message( + role=Role.ASSISTANT, + content=[ + TextContent(text="Part one"), + TextContent(text="Part two"), + ], + ) + payload = 
MessagePayload(message=msg) + assert len(payload.message.content) == 2 + + def test_iter_views_through_payload(self): + msg = Message( + role=Role.USER, + content=[ + TextContent(text="First"), + TextContent(text="Second"), + ], + ) + payload = MessagePayload(message=msg) + views = list(payload.message.iter_views()) + assert len(views) == 2 + + +# --------------------------------------------------------------------------- +# MessageResult Tests +# --------------------------------------------------------------------------- + + +class TestMessageResult: + """Tests for the MessageResult type alias.""" + + def test_is_plugin_result_subclass(self): + assert issubclass(MessageResult, PluginResult) + + +# --------------------------------------------------------------------------- +# Hook Registration Tests +# --------------------------------------------------------------------------- + + +class TestMessageHookRegistration: + """Tests for message hook registration in the global registry.""" + + def test_evaluate_hook_registered(self): + registry = get_hook_registry() + assert registry.is_registered(MessageHookType.EVALUATE) + + def test_payload_type(self): + registry = get_hook_registry() + assert registry.get_payload_type(MessageHookType.EVALUATE) is MessagePayload + + def test_result_type(self): + registry = get_hook_registry() + assert registry.get_result_type(MessageHookType.EVALUATE) is MessageResult + + def test_idempotent_registration(self): + """Re-importing or re-calling _register should not raise.""" + # First-Party + from cpex.framework.hooks.message import _register_message_hooks + + _register_message_hooks() + registry = get_hook_registry() + assert registry.is_registered(MessageHookType.EVALUATE) diff --git a/tests/unit/cpex/framework/test_errors.py b/tests/unit/cpex/framework/test_errors.py index 8cc8728..242f782 100644 --- a/tests/unit/cpex/framework/test_errors.py +++ b/tests/unit/cpex/framework/test_errors.py @@ -61,8 +61,9 @@ async def 
test_error_plugin_raise_error_false(): # assert not result.modified_payload await plugin_manager.shutdown() - plugin_manager.config.plugins[0].mode = PluginMode.CONCURRENT - plugin_manager.config.plugins[0].on_error = OnError.IGNORE + plugin_manager.config.plugins[0] = plugin_manager.config.plugins[0].model_copy( + update={"mode": PluginMode.CONCURRENT, "on_error": OnError.IGNORE} + ) await plugin_manager.initialize() result, _ = await plugin_manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, payload, global_context) assert result.continue_processing diff --git a/tests/unit/cpex/framework/test_manager_extended.py b/tests/unit/cpex/framework/test_manager_extended.py index e61569d..2c0a231 100644 --- a/tests/unit/cpex/framework/test_manager_extended.py +++ b/tests/unit/cpex/framework/test_manager_extended.py @@ -116,10 +116,12 @@ async def prompt_pre_fetch(self, payload, context): # assert "timeout" in result.violation.description.lower() # Test with audit mode + on_error=IGNORE (errors are logged and ignored) - plugin_config.mode = PluginMode.AUDIT - plugin_config.on_error = OnError.IGNORE + audit_config = plugin_config.model_copy( + update={"mode": PluginMode.AUDIT, "on_error": OnError.IGNORE} + ) + audit_plugin = TimeoutPlugin(audit_config) with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(timeout_plugin)) + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(audit_plugin)) mock_get.return_value = [hook_ref] result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) @@ -177,10 +179,12 @@ async def prompt_pre_fetch(self, payload, context): # assert "error" in result.violation.description.lower() # Test with audit mode + on_error=IGNORE (errors are logged and ignored) - plugin_config.mode = PluginMode.AUDIT - plugin_config.on_error = OnError.IGNORE + audit_config = plugin_config.model_copy( + update={"mode": 
PluginMode.AUDIT, "on_error": OnError.IGNORE} + ) + audit_plugin = ErrorPlugin(audit_config) with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(audit_plugin)) mock_get.return_value = [hook_ref] result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) @@ -189,41 +193,21 @@ async def prompt_pre_fetch(self, payload, context): assert result.continue_processing assert result.violation is None - plugin_config.mode = PluginMode.CONCURRENT - plugin_config.on_error = OnError.IGNORE - with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) - mock_get.return_value = [hook_ref] - - result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) - - # Should continue with on_error=ignore - assert result.continue_processing - assert result.violation is None - - plugin_config.mode = PluginMode.CONCURRENT - plugin_config.on_error = OnError.IGNORE - with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) - mock_get.return_value = [hook_ref] - - result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) - - # Should continue with on_error=ignore - assert result.continue_processing - assert result.violation is None - - plugin_config.mode = PluginMode.CONCURRENT - plugin_config.on_error = OnError.IGNORE - with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: - hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(error_plugin)) - mock_get.return_value = [hook_ref] + # Test with concurrent mode + on_error=IGNORE (repeated to verify consistency) + 
ignore_config = plugin_config.model_copy( + update={"mode": PluginMode.CONCURRENT, "on_error": OnError.IGNORE} + ) + ignore_plugin = ErrorPlugin(ignore_config) + for _ in range(3): + with patch.object(manager._registry, "get_hook_refs_for_hook") as mock_get: + hook_ref = HookRef(PromptHookType.PROMPT_PRE_FETCH, PluginRef(ignore_plugin)) + mock_get.return_value = [hook_ref] - result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, _ = await manager.invoke_hook(PromptHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) - # Should continue with on_error=ignore - assert result.continue_processing - assert result.violation is None + # Should continue with on_error=ignore + assert result.continue_processing + assert result.violation is None await manager.shutdown() diff --git a/tests/unit/cpex/framework/test_plugin_base_coverage.py b/tests/unit/cpex/framework/test_plugin_base_coverage.py index a268539..090ea4b 100644 --- a/tests/unit/cpex/framework/test_plugin_base_coverage.py +++ b/tests/unit/cpex/framework/test_plugin_base_coverage.py @@ -64,6 +64,13 @@ async def tool_pre_invoke(self, payload: PluginPayload) -> PluginResult: return PluginResult(continue_processing=True) +class ThreeParamPlugin(Plugin): + """Plugin with 3 parameters (accepts extensions).""" + + async def tool_pre_invoke(self, payload: PluginPayload, context: PluginContext, extensions) -> PluginResult: + return PluginResult(continue_processing=True) + + class NoHookPlugin(Plugin): """Plugin with no method matching the hook.""" @@ -155,6 +162,19 @@ def test_wrong_param_count_raises(self): with pytest.raises(PluginError, match="invalid signature"): HookRef("tool_pre_invoke", ref) + def test_three_param_plugin_accepted(self): + plugin = ThreeParamPlugin(_make_config()) + ref = PluginRef(plugin) + hook_ref = HookRef("tool_pre_invoke", ref) + assert hook_ref.hook is not None + assert hook_ref.accepts_extensions is True + + def 
test_two_param_plugin_no_extensions(self): + plugin = ConcretePlugin(_make_config()) + ref = PluginRef(plugin) + hook_ref = HookRef("tool_pre_invoke", ref) + assert hook_ref.accepts_extensions is False + def test_sync_method_raises(self): plugin = SyncPlugin(_make_config()) ref = PluginRef(plugin) diff --git a/tests/unit/cpex/framework/test_resource_hooks.py b/tests/unit/cpex/framework/test_resource_hooks.py index f903a1b..0b0fa4c 100644 --- a/tests/unit/cpex/framework/test_resource_hooks.py +++ b/tests/unit/cpex/framework/test_resource_hooks.py @@ -419,9 +419,12 @@ async def resource_pre_fetch(self, payload, context): assert result.continue_processing is True # Continues despite error # Test with concurrent mode + on_error=FAIL (default) - should raise PluginError - config.mode = PluginMode.CONCURRENT - config.on_error = OnError.FAIL - with patch.object(manager._registry, "get_hook_refs_for_hook", return_value=[hook_ref]): + fail_config = config.model_copy( + update={"mode": PluginMode.CONCURRENT, "on_error": OnError.FAIL} + ) + fail_plugin = ErrorPlugin(fail_config) + fail_ref = HookRef(ResourceHookType.RESOURCE_PRE_FETCH, PluginRef(fail_plugin)) + with patch.object(manager._registry, "get_hook_refs_for_hook", return_value=[fail_ref]): with pytest.raises(PluginError): result, contexts = await manager.invoke_hook( ResourceHookType.RESOURCE_PRE_FETCH, payload, global_context