diff --git a/.gitignore b/.gitignore index 6c815c1..2ff2911 100644 --- a/.gitignore +++ b/.gitignore @@ -18,7 +18,6 @@ CLAUDE.local.md .scannerwork llms-full.txt aider* -.aider* todo/ *.sarif devskim-results.sarif @@ -35,7 +34,6 @@ token.txt mcpgateway.sbom.xml gateway_service_leader.lock docs/docs/test/ -tmp *.tgz *.gz *.bz @@ -74,9 +72,7 @@ dictionary.dic pdm.lock .pdm-python temp/ -public/ *history.md -htmlcov test_commands.md cover.md build/ @@ -99,28 +95,22 @@ scribeflow.log coverage_re bin/flagged flagged/ -certs/ # VENV .python37/ .python39/ # Byte-compiled / optimized / DLL files __pycache__/ -**/__pycache__/ *.py[cod] *$py.class mcpgateway-wrapper/src/mcp_gateway_wrapper/__pycache__/ -# Bak -*.bak - # C extensions *.so # Distribution / packaging .wily/ .Python -build/ develop-eggs/ dist/ downloads/ @@ -165,7 +155,6 @@ coverage.xml *.pot # Django stuff: -*.log local_settings.py db.sqlite3 @@ -199,8 +188,6 @@ celerybeat-schedule *.sage.py # Environments -.env -.venv env/ venv/ ENV/ @@ -231,9 +218,6 @@ dmypy.json .idea/ -# Sonar -.scannerwork - # vim *.swp *,cover @@ -244,9 +228,6 @@ logging/ .ai* -# downloads -downloads/ - # db_path db_path/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b3a534d..160abd9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -530,5 +530,5 @@ repos: rev: 1.7.0 # or master if you're bold hooks: - id: interrogate - args: [--quiet, --fail-under=100] + args: [--quiet, --fail-under=100, --exclude, cforge/_version.py] files: ^cforge/ diff --git a/README.md b/README.md index 2205345..2b6a6fb 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,8 @@ Here are some examples: cforge tools list [--mcp-server-id ID] [--json] cforge tools get cforge tools create [file.json] +cforge tools execute # Interactive schema prompt +cforge tools execute --data args.json # Use JSON args file cforge tools toggle # Resources diff --git a/cforge/commands/deploy/deploy.py b/cforge/commands/deploy/deploy.py index 
7fc5483..4d1f578 100644 --- a/cforge/commands/deploy/deploy.py +++ b/cforge/commands/deploy/deploy.py @@ -11,7 +11,7 @@ import typer # First-Party -from cforge.common import get_console +from cforge.common.console import get_console def deploy() -> None: diff --git a/cforge/commands/metrics/metrics.py b/cforge/commands/metrics/metrics.py index fd2197e..770b599 100644 --- a/cforge/commands/metrics/metrics.py +++ b/cforge/commands/metrics/metrics.py @@ -11,11 +11,9 @@ import typer # First-Party -from cforge.common import ( - get_console, - make_authenticated_request, - print_json, -) +from cforge.common.console import get_console +from cforge.common.http import make_authenticated_request +from cforge.common.render import print_json def metrics_get( diff --git a/cforge/commands/resources/a2a.py b/cforge/commands/resources/a2a.py index 6307b45..1dde748 100644 --- a/cforge/commands/resources/a2a.py +++ b/cforge/commands/resources/a2a.py @@ -16,14 +16,11 @@ import typer # First-Party -from cforge.common import ( - get_console, - handle_exception, - make_authenticated_request, - print_json, - print_table, - prompt_for_schema, -) +from cforge.common.console import get_console +from cforge.common.errors import handle_exception +from cforge.common.http import make_authenticated_request +from cforge.common.prompting import prompt_for_schema +from cforge.common.render import print_json, print_table from mcpgateway.schemas import A2AAgentCreate, A2AAgentUpdate diff --git a/cforge/commands/resources/mcp_servers.py b/cforge/commands/resources/mcp_servers.py index 6c461e7..8973e26 100644 --- a/cforge/commands/resources/mcp_servers.py +++ b/cforge/commands/resources/mcp_servers.py @@ -16,14 +16,11 @@ import typer # First-Party -from cforge.common import ( - get_console, - handle_exception, - make_authenticated_request, - print_json, - print_table, - prompt_for_schema, -) +from cforge.common.console import get_console +from cforge.common.errors import handle_exception +from 
cforge.common.http import make_authenticated_request +from cforge.common.prompting import prompt_for_schema +from cforge.common.render import print_json, print_table from mcpgateway.schemas import GatewayCreate, GatewayUpdate diff --git a/cforge/commands/resources/plugins.py b/cforge/commands/resources/plugins.py index afd2483..5baa698 100644 --- a/cforge/commands/resources/plugins.py +++ b/cforge/commands/resources/plugins.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- -"""Location: ./cforge/commands/resources/plugins.py -Copyright 2025 +""" SPDX-License-Identifier: Apache-2.0 -Authors: Matthew Grigsby CLI command group: plugins @@ -16,45 +14,19 @@ """ # Standard -from enum import Enum from typing import Any, Dict, Optional # Third-Party import typer # First-Party -from cforge.common import ( - AuthenticationError, - CLIError, - get_console, - handle_exception, - make_authenticated_request, - print_json, - print_table, -) - - -class _CaseInsensitiveEnum(str, Enum): - """Enum that supports case-insensitive parsing for CLI options.""" - - @classmethod - def _missing_(cls, value: object) -> Optional["_CaseInsensitiveEnum"]: - """Resolve unknown values by matching enum values case-insensitively. - - Typer converts CLI strings into Enum members. Implementing `_missing_` - allows `--mode EnFoRcE` to resolve to `PluginMode.ENFORCE`, while still - rejecting unknown values. 
- """ - if not isinstance(value, str): - return None - value_folded = value.casefold() - for member in cls: - if member.value.casefold() == value_folded: - return member - return None +from cforge.common.console import get_console +from cforge.common.errors import AuthenticationError, CaseInsensitiveEnum, CLIError, handle_exception +from cforge.common.http import make_authenticated_request +from cforge.common.render import print_json, print_table -class PluginMode(_CaseInsensitiveEnum): +class PluginMode(CaseInsensitiveEnum): """Valid plugin mode filters supported by the gateway admin API.""" ENFORCE = "enforce" @@ -62,21 +34,39 @@ class PluginMode(_CaseInsensitiveEnum): DISABLED = "disabled" -def _handle_plugins_exception(exception: Exception) -> None: +def _parse_plugin_mode(mode: Optional[str]) -> Optional[PluginMode]: + """Parse plugin mode with case-insensitive enum matching.""" + if mode is None: + return None + try: + return PluginMode(mode) + except ValueError as exc: + choices = ", ".join(member.value for member in PluginMode) + raise CLIError(f"Invalid value for '--mode': {mode!r}. Must be one of: {choices}.") from exc + + +def _handle_plugins_exception(exception: Exception, operation: str, plugin_name: Optional[str] = None) -> None: """Provide plugin-specific hints and raise a CLI error.""" console = get_console() if isinstance(exception, AuthenticationError): console.print("[yellow]Access denied. Requires admin.plugins permission.[/yellow]") - elif isinstance(exception, CLIError) and "(404)" in str(exception): - console.print("[yellow]Admin plugin API unavailable. 
Ensure MCPGATEWAY_ADMIN_API_ENABLED=true and gateway version supports /admin/plugins.[/yellow]") + elif isinstance(exception, CLIError): + error_str = str(exception) + if "(404)" in error_str: + error_str_folded = error_str.casefold() + if operation == "get" and "plugin" in error_str_folded and "not found" in error_str_folded: + plugin_label = plugin_name or "requested plugin" + console.print(f"[yellow]Plugin not found: {plugin_label}[/yellow]") + else: + console.print("[yellow]Admin plugin API unavailable. Ensure MCPGATEWAY_ADMIN_API_ENABLED=true and gateway version supports /admin/plugins.[/yellow]") handle_exception(exception) def plugins_list( search: Optional[str] = typer.Option(None, "--search", help="Search by plugin name, description, or author"), - mode: Optional[PluginMode] = typer.Option(None, "--mode", help="Filter by mode"), + mode: Optional[str] = typer.Option(None, "--mode", help="Filter by mode"), hook: Optional[str] = typer.Option(None, "--hook", help="Filter by hook type"), tag: Optional[str] = typer.Option(None, "--tag", help="Filter by plugin tag"), json_output: bool = typer.Option(False, "--json", help="Output as JSON"), @@ -88,8 +78,9 @@ def plugins_list( params: Dict[str, Any] = {} if search: params["search"] = search - if mode: - params["mode"] = mode.value + parsed_mode = _parse_plugin_mode(mode) + if parsed_mode: + params["mode"] = parsed_mode.value if hook: params["hook"] = hook if tag: @@ -108,7 +99,7 @@ def plugins_list( console.print("[yellow]No plugins found[/yellow]") except Exception as e: - _handle_plugins_exception(e) + _handle_plugins_exception(e, operation="list") def plugins_get( @@ -120,7 +111,7 @@ def plugins_get( print_json(result, f"Plugin {name}") except Exception as e: - _handle_plugins_exception(e) + _handle_plugins_exception(e, operation="get", plugin_name=name) def plugins_stats() -> None: @@ -130,4 +121,4 @@ def plugins_stats() -> None: print_json(result, "Plugin Statistics") except Exception as e: - 
_handle_plugins_exception(e) + _handle_plugins_exception(e, operation="stats") diff --git a/cforge/commands/resources/prompts.py b/cforge/commands/resources/prompts.py index ecc62ff..b3b0d83 100644 --- a/cforge/commands/resources/prompts.py +++ b/cforge/commands/resources/prompts.py @@ -16,14 +16,11 @@ import typer # First-Party -from cforge.common import ( - get_console, - handle_exception, - make_authenticated_request, - print_json, - print_table, - prompt_for_schema, -) +from cforge.common.console import get_console +from cforge.common.errors import handle_exception +from cforge.common.http import make_authenticated_request +from cforge.common.prompting import prompt_for_schema +from cforge.common.render import print_json, print_table from mcpgateway.schemas import PromptCreate, PromptUpdate diff --git a/cforge/commands/resources/resources.py b/cforge/commands/resources/resources.py index 6f84eff..05137cd 100644 --- a/cforge/commands/resources/resources.py +++ b/cforge/commands/resources/resources.py @@ -16,14 +16,11 @@ import typer # First-Party -from cforge.common import ( - get_console, - handle_exception, - make_authenticated_request, - print_json, - print_table, - prompt_for_schema, -) +from cforge.common.console import get_console +from cforge.common.errors import handle_exception +from cforge.common.http import make_authenticated_request +from cforge.common.prompting import prompt_for_schema +from cforge.common.render import print_json, print_table from mcpgateway.schemas import ResourceCreate, ResourceUpdate diff --git a/cforge/commands/resources/tools.py b/cforge/commands/resources/tools.py index b12b277..0c88bb4 100644 --- a/cforge/commands/resources/tools.py +++ b/cforge/commands/resources/tools.py @@ -16,14 +16,11 @@ import typer # First-Party -from cforge.common import ( - get_console, - handle_exception, - make_authenticated_request, - print_json, - print_table, - prompt_for_schema, -) +from cforge.common.console import get_console +from 
cforge.common.errors import CLIError, handle_exception +from cforge.common.http import make_authenticated_request +from cforge.common.prompting import prompt_for_json_schema, prompt_for_schema +from cforge.common.render import print_json, print_table from mcpgateway.schemas import ToolCreate, ToolUpdate @@ -179,3 +176,78 @@ def tools_toggle( except Exception as e: handle_exception(e) + + +def tools_execute( + tool_id: str = typer.Argument(..., help="Tool ID"), + data_file: Optional[Path] = typer.Option(None, "--data", help="JSON file containing tool arguments"), +) -> None: + """Execute a tool by ID using optional dynamic schema prompting.""" + console = get_console() + + try: + if not tool_id.strip(): + raise CLIError("Tool ID must be a non-empty string") + + prefilled_data: Optional[Dict[str, Any]] = None + if data_file: + if not data_file.exists(): + console.print(f"[red]File not found: {data_file}[/red]") + raise typer.Exit(1) + file_data = json.loads(data_file.read_text()) + if not isinstance(file_data, dict): + raise CLIError("Data file must contain a JSON object") + prefilled_data = file_data + prompt_optional = prefilled_data is None + + tool_result = make_authenticated_request("GET", f"/tools/{tool_id}") + assert isinstance(tool_result, dict) + + tool_name = tool_result.get("name") + if not isinstance(tool_name, str) or not tool_name: + raise CLIError(f"Tool '{tool_id}' does not have a valid name") + + raw_schema = tool_result.get("inputSchema") + if raw_schema is None: + raw_schema = tool_result.get("input_schema") + if raw_schema is None: + input_schema = {"type": "object", "properties": {}} + elif isinstance(raw_schema, dict): + input_schema = raw_schema + elif isinstance(raw_schema, str): + try: + parsed_schema = json.loads(raw_schema) + except json.JSONDecodeError as exc: + raise CLIError("Tool input schema must be a JSON object") from exc + if not isinstance(parsed_schema, dict): + raise CLIError("Tool input schema must be a JSON object") + 
input_schema = parsed_schema + else: + raise CLIError("Tool input schema must be a JSON object") + + if not input_schema: + input_schema = {"type": "object", "properties": {}} + + data = prompt_for_json_schema(input_schema, prefilled=prefilled_data, prompt_optional=prompt_optional) + + rpc_payload: Dict[str, Any] = {"jsonrpc": "2.0", "id": f"cforge-tools-{tool_id}", "method": "tools/call", "params": {"name": tool_name, "arguments": data}} + rpc_result = make_authenticated_request("POST", "/rpc", json_data=rpc_payload) + + if isinstance(rpc_result, dict) and "error" in rpc_result: + error = rpc_result["error"] + if isinstance(error, dict): + err_message = error.get("message", "Unknown error") + err_code = error.get("code") + if err_code is not None: + raise CLIError(f"Tool execution failed ({err_code}): {err_message}") + raise CLIError(f"Tool execution failed: {err_message}") + raise CLIError(f"Tool execution failed: {error}") + + console.print("[green]✓ Tool executed successfully![/green]") + if isinstance(rpc_result, dict) and "result" in rpc_result: + print_json(rpc_result["result"], "Tool Result") + else: + print_json(rpc_result, "Tool Result") + + except Exception as e: + handle_exception(e) diff --git a/cforge/commands/resources/virtual_servers.py b/cforge/commands/resources/virtual_servers.py index 22fa0d4..39ecf88 100644 --- a/cforge/commands/resources/virtual_servers.py +++ b/cforge/commands/resources/virtual_servers.py @@ -16,14 +16,11 @@ import typer # First-Party -from cforge.common import ( - get_console, - handle_exception, - make_authenticated_request, - print_json, - print_table, - prompt_for_schema, -) +from cforge.common.console import get_console +from cforge.common.errors import handle_exception +from cforge.common.http import make_authenticated_request +from cforge.common.prompting import prompt_for_schema +from cforge.common.render import print_json, print_table from mcpgateway.schemas import ServerCreate, ServerUpdate diff --git 
a/cforge/commands/server/run.py b/cforge/commands/server/run.py index 8aab2d8..23f1060 100644 --- a/cforge/commands/server/run.py +++ b/cforge/commands/server/run.py @@ -23,7 +23,8 @@ import typer # First-Party -from cforge.common import get_console, make_authenticated_request +from cforge.common.console import get_console +from cforge.common.http import make_authenticated_request def run( diff --git a/cforge/commands/settings/config_schema.py b/cforge/commands/settings/config_schema.py index 2deda5d..1d95797 100644 --- a/cforge/commands/settings/config_schema.py +++ b/cforge/commands/settings/config_schema.py @@ -16,7 +16,9 @@ import typer # First-Party -from cforge.common import get_console, get_settings, print_json +from cforge.common.console import get_console +from cforge.common.render import print_json +from cforge.config import get_settings def config_schema( diff --git a/cforge/commands/settings/export.py b/cforge/commands/settings/export.py index 6db85e1..3937c1e 100644 --- a/cforge/commands/settings/export.py +++ b/cforge/commands/settings/export.py @@ -17,7 +17,8 @@ import typer # First-Party -from cforge.common import get_base_url, get_console, make_authenticated_request +from cforge.common.console import get_console +from cforge.common.http import get_base_url, make_authenticated_request def export( diff --git a/cforge/commands/settings/import_cmd.py b/cforge/commands/settings/import_cmd.py index 3a7aa9b..d02b3e6 100644 --- a/cforge/commands/settings/import_cmd.py +++ b/cforge/commands/settings/import_cmd.py @@ -16,7 +16,8 @@ import typer # First-Party -from cforge.common import get_console, make_authenticated_request +from cforge.common.console import get_console +from cforge.common.http import make_authenticated_request def import_cmd( diff --git a/cforge/commands/settings/login.py b/cforge/commands/settings/login.py index a107d3d..10294e1 100644 --- a/cforge/commands/settings/login.py +++ b/cforge/commands/settings/login.py @@ -12,7 +12,8 @@ import 
typer # First-Party -from cforge.common import get_base_url, get_console, get_token_file, save_token +from cforge.common.console import get_console +from cforge.common.http import get_base_url, get_token_file, save_token def login( diff --git a/cforge/commands/settings/logout.py b/cforge/commands/settings/logout.py index df0ce75..c1d86da 100644 --- a/cforge/commands/settings/logout.py +++ b/cforge/commands/settings/logout.py @@ -8,7 +8,8 @@ """ # First-Party -from cforge.common import get_console, get_token_file +from cforge.common.console import get_console +from cforge.common.http import get_token_file def logout() -> None: diff --git a/cforge/commands/settings/profiles.py b/cforge/commands/settings/profiles.py index 91fd0af..d39f5ae 100644 --- a/cforge/commands/settings/profiles.py +++ b/cforge/commands/settings/profiles.py @@ -9,27 +9,29 @@ # Standard from datetime import datetime -from pathlib import Path -from typing import Optional import json +from pathlib import Path import secrets import string +from typing import Optional # Third-Party import typer # First-Party -from cforge.common import get_console, print_table, print_json, prompt_for_schema +from cforge.common.console import get_console +from cforge.common.prompting import prompt_for_schema +from cforge.common.render import print_json, print_table from cforge.config import get_settings from cforge.profile_utils import ( AuthProfile, - ProfileStore, + get_active_profile, get_all_profiles, get_profile, - get_active_profile, - set_active_profile, load_profile_store, + ProfileStore, save_profile_store, + set_active_profile, ) diff --git a/cforge/commands/settings/support_bundle.py b/cforge/commands/settings/support_bundle.py index 8230e54..da9f930 100644 --- a/cforge/commands/settings/support_bundle.py +++ b/cforge/commands/settings/support_bundle.py @@ -15,7 +15,7 @@ import typer # First-Party -from cforge.common import get_console +from cforge.common.console import get_console def support_bundle( diff 
--git a/cforge/commands/settings/version.py b/cforge/commands/settings/version.py index 4662dce..ae63438 100644 --- a/cforge/commands/settings/version.py +++ b/cforge/commands/settings/version.py @@ -8,7 +8,8 @@ """ # First-Party -from cforge.common import get_console, make_authenticated_request +from cforge.common.console import get_console +from cforge.common.http import make_authenticated_request from mcpgateway import __version__ diff --git a/cforge/commands/settings/whoami.py b/cforge/commands/settings/whoami.py index 94702d4..ed91743 100644 --- a/cforge/commands/settings/whoami.py +++ b/cforge/commands/settings/whoami.py @@ -8,7 +8,9 @@ """ # First-Party -from cforge.common import get_console, get_settings, get_token_file, load_token +from cforge.common.console import get_console +from cforge.common.http import get_token_file, load_token +from cforge.config import get_settings from cforge.profile_utils import get_active_profile diff --git a/cforge/common.py b/cforge/common.py deleted file mode 100644 index 30882c0..0000000 --- a/cforge/common.py +++ /dev/null @@ -1,513 +0,0 @@ -# -*- coding: utf-8 -*- -"""Location: ./cforge/common.py -Copyright 2025 -SPDX-License-Identifier: Apache-2.0 -Authors: Gabe Goodhart - -Common utilities for Context Forge CLI. 
-""" - -# Standard -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union -import json - -# Third-Party -from pydantic import BaseModel -from pydantic_core import PydanticUndefined -from rich.console import Console, ConsoleOptions, RenderResult, RenderableType -from rich.segment import Segment -from rich.measure import Measurement -from rich.table import Table -from rich.panel import Panel -from rich.syntax import Syntax -import requests -import typer - -# First-Party -from cforge.profile_utils import DEFAULT_PROFILE_ID -from cforge.config import get_settings -from cforge.credential_store import load_profile_credentials -from cforge.profile_utils import get_active_profile - -# ------------------------------------------------------------------------------ -# Singletons -# ------------------------------------------------------------------------------ - - -@lru_cache -def get_console() -> Console: - """Get the console singleton. - Returns: - Console singleton - """ - return Console() - - -@lru_cache -def get_app() -> typer.Typer: - """Get the typer singleton. 
- Returns: - typer singleton - """ - return typer.Typer( - name="mcpgateway", - help="MCP Gateway - Production-grade MCP Gateway & Proxy CLI", - add_completion=True, - rich_markup_mode="rich", - ) - - -# ------------------------------------------------------------------------------ -# Error handling -# ------------------------------------------------------------------------------ - - -class CLIError(Exception): - """Base class for CLI-related errors.""" - - -class AuthenticationError(CLIError): - """Raised when authentication fails.""" - - -def split_exception_details(exception: Exception) -> Tuple[str, Any]: - """Try to get parsed details from the exception""" - exc_str = str(exception) - splits = exc_str.split(":", 1) - if len(splits) == 2: - try: - parsed_details = json.loads(splits[1]) - return splits[0], parsed_details - except json.JSONDecodeError: - pass - return exc_str, None - - -def handle_exception(exception: Exception) -> None: - """Handle an exception and print a friendly error message.""" - e_str, e_detail = split_exception_details(exception) - get_console().print(f"[red]Error: {e_str}[/red]") - if e_detail: - print_json(e_detail, "Error details") - raise typer.Exit(1) - - -# ------------------------------------------------------------------------------ -# Auth -# ------------------------------------------------------------------------------ - - -def get_base_url() -> str: - """Get the full base URL for the current profile's server - - TODO: This will need to support https in the future! - - Returns: - The string URL base - """ - return get_active_profile().api_url - - -def get_token_file() -> Path: - """Get the path to the token file in contextforge_home. - - Uses the active profile if available, otherwise returns the default token file. - For the virtual default profile, uses the unsuffixed token file. 
- - Returns: - Path to the token file (profile-specific or default) - """ - profile = get_active_profile() - suffix = "" if profile.id == DEFAULT_PROFILE_ID else f".{profile.id}" - return get_settings().contextforge_home / f"token{suffix}" - - -def save_token(token: str) -> None: - """Save authentication token to contextforge_home/token file. - - Args: - token: The JWT token to save - """ - token_file = get_token_file() - token_file.parent.mkdir(parents=True, exist_ok=True) - token_file.write_text(token, encoding="utf-8") - # Set restrictive permissions (readable only by owner) - token_file.chmod(0o600) - - -def load_token() -> Optional[str]: - """Load authentication token from contextforge_home/token file. - - Returns: - Token string if found, None otherwise - """ - token_file = get_token_file() - if token_file.exists(): - return token_file.read_text(encoding="utf-8").strip() - return None - - -def attempt_auto_login() -> Optional[str]: - """Attempt to automatically login using stored credentials. - - This function tries to login using credentials stored by the desktop app - in the encrypted credential store. If successful, it saves the token - and returns it. 
- - Returns: - Authentication token if auto-login succeeds, None otherwise - """ - # Try to load credentials from the encrypted store - profile = get_active_profile() - credentials = load_profile_credentials(profile.id) - if not credentials or not credentials.get("email") or not credentials.get("password"): - return None - - # Attempt login - try: - gateway_url = get_base_url() - response = requests.post( - f"{gateway_url}/auth/email/login", - json={"email": credentials["email"], "password": credentials["password"]}, - headers={"Content-Type": "application/json"}, - ) - - if response.status_code == 200: - data = response.json() - token = data.get("access_token") - if token: - # Save the token for future use - save_token(token) - return token - except Exception: - # Silently fail - auto-login is best-effort - pass - - return None - - -def get_auth_token() -> Optional[str]: - """Get authentication token from multiple sources in priority order. - - Priority: - 1. MCPGATEWAY_BEARER_TOKEN environment variable - 2. Stored token in contextforge_home/token file - 3. Auto-login using stored credentials (if available) - - Returns: - Authentication token string or None if not configured - """ - # Try environment variable first (highest priority) - token: Optional[str] = get_settings().mcpgateway_bearer_token - if token: - return token - - # Try stored token file - token = load_token() - if token: - return token - - # Try auto-login with stored credentials - token = attempt_auto_login() - if token: - return token - - return None - - -def make_authenticated_request( - method: str, - url: str, - json_data: Optional[Dict[str, Any]] = None, - params: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: - """Make an authenticated HTTP request to the gateway API. - - Supports both authenticated and unauthenticated servers. Will attempt - the request without authentication if no token is configured, and only - fail if the server requires authentication. 
- - Args: - method: HTTP method (GET, POST, etc.) - url: URL path for the request - json_data: Optional JSON data for request body - params: Optional query parameters - - Returns: - JSON response from the API - - Raises: - AuthenticationError: If the server requires authentication but none is configured - CLIError: If the API request fails - """ - token = get_auth_token() - - headers = {"Content-Type": "application/json"} - # Only add Authorization header if a token is available - if token: - if token.startswith("Basic "): - headers["Authorization"] = token - else: - headers["Authorization"] = f"Bearer {token}" - - gateway_url = get_base_url() - full_url = f"{gateway_url}{url}" - - try: - response = requests.request(method=method, url=full_url, json=json_data, params=params, headers=headers) - - # Handle authentication errors specifically - if response.status_code in (401, 403): - raise AuthenticationError("Authentication required but not configured. " "Set MCPGATEWAY_BEARER_TOKEN environment variable or run 'cforge login'.") - - if response.status_code >= 400: - raise CLIError(f"API request failed ({response.status_code}): {response.text}") - - return response.json() - - except requests.RequestException as e: - raise CLIError(f"Failed to connect to gateway at {gateway_url}: {str(e)}") - - -# ------------------------------------------------------------------------------ -# Pretty Printing -# ------------------------------------------------------------------------------ - - -class LineLimit: - """A renderable that limits the number of lines after rich's wrapping.""" - - def __init__(self, renderable: RenderableType, max_lines: int): - """Implement with the wrapped renderable and the max lines to render""" - self.renderable = renderable - self.max_lines = max_lines - - def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: - """Hook the actual rendering to perform the per-line truncation""" - - # Let rich render the content with proper 
wrapping - lines = console.render_lines(self.renderable, options, pad=False) - - # Limit to max_lines - for i, line in enumerate(lines): - if i >= self.max_lines: - # Optionally add an ellipsis indicator - yield Segment("...") - break - yield from line - yield Segment.line() - - def __rich_measure__(self, console: Console, options: ConsoleOptions) -> Measurement: - """Hook the measurement of this entry to pass through to the wrapped - renderable - """ - - return Measurement.get(console, options, self.renderable) - - -def print_json(data: Any, title: Optional[str] = None) -> None: - """Pretty print JSON data with Rich. - - Args: - data: Data to print - title: Optional title for the output - """ - console = get_console() - json_str = json.dumps(data, indent=2, ensure_ascii=False) - syntax = Syntax(json_str, "json", theme="monokai", line_numbers=True) - if title: - console.print(Panel(syntax, title=title, border_style="green")) - else: - console.print(syntax) - - -def print_table( - data: List[Dict], - title: str, - columns: List[str], - col_name_map: Optional[Dict[str, str]] = None, -) -> None: - """Print data as a Rich table. 
- - Args: - data: List of dictionaries to display - title: Title for the table - columns: List of column names to display - col_name_map: Optional mapping of column names to display - """ - console = get_console() - table = Table(title=title, show_header=True, header_style="bold magenta") - col_name_map = col_name_map or {} - max_lines = get_settings().table_max_lines - - for column in columns: - table.add_column(col_name_map.get(column, column), style="cyan") - - for item in data: - row = [str(item.get(col, "")) for col in columns] - if max_lines > 0: - row = [LineLimit(cell, max_lines=max_lines) for cell in row] - table.add_row(*row) - - console.print(table) - - -# ------------------------------------------------------------------------------ -# Structure Guidance -# ------------------------------------------------------------------------------ - -# Very unlikely number for any valid int param -_INT_SENTINEL_DEFAULT = -4231415 - - -def prompt_for_schema(schema_class: type, prefilled: Optional[Dict[str, Any]] = None, indent: str = "") -> Dict[str, Any]: - """Interactively prompt user for fields based on a Pydantic schema. 
- - Args: - schema_class: The Pydantic model class to use for prompting - prefilled: Optional dictionary of pre-filled values to skip prompting for - indent: Indentation string for nested fields - - Returns: - Dictionary with the user's input data (includes prefilled values) - """ - from typing import get_args, get_origin - - def _format_indent(indt: str) -> str: - """Format the indentation as dim""" - return f"[dim]{indt}[/dim]" if indt else indt - - formatted_indent = _format_indent(indent) - next_indent = indent - if not next_indent: - next_indent = "|" - next_indent += "-" - formatted_next_indent = _format_indent(next_indent) - - console = get_console() - console.print(f"\n{formatted_indent}[bold cyan]Creating {schema_class.__name__}[/bold cyan]") - console.print(f"{formatted_indent}[dim]Press Enter to skip optional fields[/dim]\n{formatted_indent}") - - data = prefilled.copy() if prefilled else {} - model_fields = schema_class.model_fields - - for field_name, field_info in model_fields.items(): - # Skip if already provided - if field_name in data: - console.print(f"{formatted_indent}[dim]{field_name}: {data[field_name]} (pre-filled)[/dim]") - continue - - # Skip internal fields - if field_name in ["model_config", "auth_value"]: - continue - - # Get field metadata - annotation = field_info.annotation - description = field_info.description or field_name - is_required = field_info.is_required() - default = field_info.default if field_info.default is not PydanticUndefined else None - - # Get the actual type (handle Optional, Union, etc.) 
- origin = get_origin(annotation) - args = get_args(annotation) - - # Determine the base type - if origin is Union: - # Handle Optional[T] which is Union[T, None] - actual_type = args[0] if len(args) > 0 and type(None) in args else annotation - else: - actual_type = annotation - - # Create prompt text - prompt_text = f"{field_name}" - if description and description != field_name: - prompt_text += f" ({description})" - if default and default != "": - prompt_text += f" [default: {default}]" - if not is_required: - prompt_text += " [optional]" - - # Handle different types - if actual_type is bool or str(actual_type) == "bool": - if not is_required: - console.print(f"{formatted_indent}Include {field_name}?", end="") - if is_required or typer.confirm("", default=False): - console.print(f"{formatted_indent}{prompt_text}", end="") - data[field_name] = typer.prompt("", default=bool(default) if default else False, type=bool) - - elif actual_type is int or str(actual_type) == "int": - default_val = default - if default is None: - default_val = "" if is_required else _INT_SENTINEL_DEFAULT - console.print(f"{formatted_indent}{prompt_text}", end="") - value = typer.prompt("", type=int, default=default_val, show_default=default_val not in ["", _INT_SENTINEL_DEFAULT]) - if value != _INT_SENTINEL_DEFAULT: - data[field_name] = value - - elif isinstance(actual_type, type) and issubclass(actual_type, BaseModel): - console.print(f"{formatted_indent}[yellow]{prompt_text}[/yellow]") - data[field_name] = prompt_for_schema(actual_type, indent=next_indent) - - elif get_origin(actual_type) is list or str(actual_type).startswith("list"): - list_type = get_args(actual_type)[0] - console.print(f"{formatted_indent}[yellow]{prompt_text}[/yellow]") - if isinstance(list_type, type) and issubclass(list_type, BaseModel): - # Loop collecting more arguments until the user wants to stop - entries = [] - while True: - console.print(f"{formatted_indent}[dim]Add an entry?[/dim] ", end="") - if not 
typer.confirm("", default=False): - break - if not indent: - indent = "|" - indent += "-" - entries.append(prompt_for_schema(list_type, indent=indent)) - data[field_name] = entries - - # Assume string - else: - console.print(f"{formatted_indent}[dim]Enter comma-separated values, or press Enter to skip[/dim] ", end="") - value = typer.prompt("", default="", show_default=False) - if value: - # Parse comma-separated values - data[field_name] = [v.strip() for v in value.split(",") if v.strip()] - - elif get_origin(actual_type) is dict: - dict_key_type, dict_value_type = get_args(actual_type) - console.print(f"{formatted_indent}[yellow]{prompt_text}[/yellow]") - assert dict_key_type is str, "Only string keys are supported" - data[field_name] = {} - while True: - console.print(f"{formatted_indent}[dim]Add an entry?[/dim] ", end="") - if not typer.confirm("", default=False): - break - console.print(f"{formatted_next_indent}Enter key", end="") - key = typer.prompt("") - if isinstance(dict_value_type, type) and issubclass(dict_value_type, BaseModel): - val = prompt_for_schema(dict_value_type, indent=next_indent) - else: - console.print(f"{formatted_next_indent}Enter value", end="") - parse_type = dict_value_type if dict_value_type is not Any else str - val = typer.prompt("", type=parse_type) - - # If the value type is Any, try to parse it as JSON - if dict_value_type is Any: - val = json.loads(val) - data[field_name][key] = val - - else: # Treat as string - console.print(f"{formatted_indent}{prompt_text}", end="") - value = typer.prompt( - "", - type=str, - default=default if default is not None else "", - show_default=default is not None and default != "", - ) - if value and value != "": - data[field_name] = value - if is_required and not value: - raise CLIError(f"Field '{field_name}' is required") - - return data diff --git a/cforge/common/__init__.py b/cforge/common/__init__.py new file mode 100644 index 0000000..afcfd08 --- /dev/null +++ b/cforge/common/__init__.py @@ 
-0,0 +1,11 @@ +# -*- coding: utf-8 -*- +"""Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +Shared CLI helper modules. + +This package groups cross-cutting utilities used by command modules: +console/app construction, error handling, HTTP/auth access, Rich rendering, +and schema-driven prompting. Consumers should import from concrete submodules +instead of importing from `cforge.common` directly. +""" diff --git a/cforge/common/console.py b/cforge/common/console.py new file mode 100644 index 0000000..5e3ab60 --- /dev/null +++ b/cforge/common/console.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +""" +SPDX-License-Identifier: Apache-2.0 + +Console and CLI application factories. + +This module centralizes creation of the shared Rich console and Typer app. +Both are cached so commands across the process use consistent output and app +configuration without repeated construction. +""" + +from functools import lru_cache + +from rich.console import Console +import typer + + +@lru_cache +def get_console() -> Console: + """Get the console singleton.""" + return Console() + + +@lru_cache +def get_app() -> typer.Typer: + """Get the typer singleton.""" + return typer.Typer( + name="mcpgateway", + help="MCP Gateway - Production-grade MCP Gateway & Proxy CLI", + add_completion=True, + rich_markup_mode="rich", + ) diff --git a/cforge/common/errors.py b/cforge/common/errors.py new file mode 100644 index 0000000..58e96e7 --- /dev/null +++ b/cforge/common/errors.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +""" +SPDX-License-Identifier: Apache-2.0 + +Error domain for CLI execution. + +This module defines CLI-facing exception types and helpers that normalize +exception payloads into user-friendly terminal output. It is the shared +boundary between internal failures and surfaced command errors. 
+""" + +from enum import Enum +import json +from typing import Any, Optional, Tuple + +import typer + +from cforge.common.console import get_console + + +class CLIError(Exception): + """Base class for CLI-related errors.""" + + +class AuthenticationError(CLIError): + """Raised when authentication fails.""" + + +class CaseInsensitiveEnum(str, Enum): + """Enum that supports case-insensitive parsing for CLI options.""" + + @classmethod + def _missing_(cls, value: object) -> Optional["CaseInsensitiveEnum"]: + """Resolve unknown values by matching enum values case-insensitively.""" + if not isinstance(value, str): + return None + value_folded = value.casefold() + for member in cls: + if member.value.casefold() == value_folded: + return member + return None + + +def split_exception_details(exception: Exception) -> Tuple[str, Any]: + """Try to parse JSON details from an exception string.""" + exc_str = str(exception) + splits = exc_str.split(":", 1) + if len(splits) == 2: + try: + parsed_details = json.loads(splits[1]) + return splits[0], parsed_details + except json.JSONDecodeError: + pass + return exc_str, None + + +def handle_exception(exception: Exception) -> None: + """Handle an exception and print a friendly error message.""" + from cforge.common.render import print_json + + e_str, e_detail = split_exception_details(exception) + get_console().print(f"[red]Error: {e_str}[/red]") + if e_detail: + print_json(e_detail, "Error details") + raise typer.Exit(1) diff --git a/cforge/common/http.py b/cforge/common/http.py new file mode 100644 index 0000000..0de7755 --- /dev/null +++ b/cforge/common/http.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +""" +SPDX-License-Identifier: Apache-2.0 + +HTTP and authentication primitives for gateway access. + +This module owns token persistence, profile-aware auth lookup, optional +auto-login, and authenticated request dispatch. Command modules call these +helpers instead of handling auth headers and base URL resolution themselves. 
+""" + +from pathlib import Path +from typing import Any, Dict, Optional + +import requests + +from cforge.common.errors import AuthenticationError, CLIError +from cforge.config import get_settings +from cforge.credential_store import load_profile_credentials +from cforge.profile_utils import DEFAULT_PROFILE_ID, get_active_profile + + +def get_base_url() -> str: + """Get the full base URL for the current profile's server.""" + return get_active_profile().api_url + + +def get_token_file() -> Path: + """Get the path to the token file in contextforge_home.""" + profile = get_active_profile() + suffix = "" if profile.id == DEFAULT_PROFILE_ID else f".{profile.id}" + return get_settings().contextforge_home / f"token{suffix}" + + +def save_token(token: str) -> None: + """Save authentication token to contextforge_home/token file.""" + token_file = get_token_file() + token_file.parent.mkdir(parents=True, exist_ok=True) + token_file.write_text(token, encoding="utf-8") + token_file.chmod(0o600) + + +def load_token() -> Optional[str]: + """Load authentication token from contextforge_home/token file.""" + token_file = get_token_file() + if token_file.exists(): + return token_file.read_text(encoding="utf-8").strip() + return None + + +def attempt_auto_login() -> Optional[str]: + """Attempt to automatically login using stored credentials.""" + profile = get_active_profile() + credentials = load_profile_credentials(profile.id) + if not credentials or not credentials.get("email") or not credentials.get("password"): + return None + + try: + gateway_url = get_base_url() + response = requests.post( + f"{gateway_url}/auth/email/login", + json={"email": credentials["email"], "password": credentials["password"]}, + headers={"Content-Type": "application/json"}, + ) + + if response.status_code == 200: + data = response.json() + token = data.get("access_token") + if token: + save_token(token) + return token + except Exception: + pass + + return None + + +def get_auth_token() -> 
Optional[str]: + """Get authentication token from environment, token file, or auto-login.""" + token: Optional[str] = get_settings().mcpgateway_bearer_token + if token: + return token + + token = load_token() + if token: + return token + + token = attempt_auto_login() + if token: + return token + + return None + + +def make_authenticated_request( + method: str, + url: str, + json_data: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """Make an authenticated HTTP request to the gateway API.""" + token = get_auth_token() + + headers = {"Content-Type": "application/json"} + if token: + if token.startswith("Basic "): + headers["Authorization"] = token + else: + headers["Authorization"] = f"Bearer {token}" + + gateway_url = get_base_url() + full_url = f"{gateway_url}{url}" + + try: + response = requests.request(method=method, url=full_url, json=json_data, params=params, headers=headers) + + if response.status_code in (401, 403): + raise AuthenticationError("Authentication required but not configured. Set MCPGATEWAY_BEARER_TOKEN environment variable or run 'cforge login'.") + + if response.status_code >= 400: + raise CLIError(f"API request failed ({response.status_code}): {response.text}") + + return response.json() + + except requests.RequestException as exception: + raise CLIError(f"Failed to connect to gateway at {gateway_url}: {str(exception)}") diff --git a/cforge/common/prompting.py b/cforge/common/prompting.py new file mode 100644 index 0000000..0e36455 --- /dev/null +++ b/cforge/common/prompting.py @@ -0,0 +1,783 @@ +# -*- coding: utf-8 -*- +""" +SPDX-License-Identifier: Apache-2.0 + +Schema-driven interactive prompting utilities. + +This module converts Pydantic annotations and JSON Schema definitions into +interactive CLI prompts, including nested objects, arrays, enums, optional +fields, and local `$ref` resolution. It is the shared input pipeline used by +commands that need structured request payloads. 
+""" + +import json +from typing import Annotated, Any, Callable, Dict, get_args, get_origin, get_type_hints, List, Optional, Tuple, Union + +from pydantic import BaseModel +from rich.console import Console +import typer + +from cforge.common.console import get_console +from cforge.common.errors import CLIError +from cforge.common.schema_validation import validate_instance, validate_instance_against_subschema, validate_schema + +_INT_SENTINEL_DEFAULT = -4231415 + + +def _format_prompt_indent(indt: str) -> str: + """Render indentation in a dim style for readability.""" + return f"[dim]{indt}[/dim]" if indt else indt + + +def _next_prompt_indent(indt: str) -> str: + """Return the next indentation level.""" + if not indt: + return "|-" + return f"{indt}-" + + +def _build_prompt_text( + field_name: str, + description: Optional[str], + default: Any, + default_is_set: bool, + is_required: bool, + include_falsy_default: bool = True, +) -> str: + """Build prompt text for a field.""" + prompt_text = field_name + if description and description != field_name: + prompt_text += f" ({description})" + + has_default = default_is_set and default is not None and (include_falsy_default or bool(default)) + if has_default: + try: + default_text = json.dumps(default) + except TypeError: + default_text = str(default) + prompt_text += f" [default: {default_text}]" + + if not is_required: + prompt_text += " [optional]" + return prompt_text + + +def _prompt_include_field(console: Console, field_name: str, field_indent: str) -> bool: + """Prompt whether to include an optional field.""" + formatted_field_indent = _format_prompt_indent(field_indent) + console.print(f"{formatted_field_indent}[dim]Include {field_name}?[/dim] ", end="") + return typer.confirm("", default=False) + + +def _prompt_boolean_value( + console: Console, + field_indent: str, + prompt_text: str, + default: bool, + show_default: bool, +) -> bool: + """Prompt for a boolean field value.""" + formatted_field_indent = 
_format_prompt_indent(field_indent) + console.print(f"{formatted_field_indent}{prompt_text}", end="") + return typer.prompt("", default=default, type=bool, show_default=show_default) + + +def _prompt_integer_value( + console: Console, + field_indent: str, + prompt_text: str, + default: Any, + show_default: bool, +) -> int: + """Prompt for an integer field value.""" + formatted_field_indent = _format_prompt_indent(field_indent) + console.print(f"{formatted_field_indent}{prompt_text}", end="") + return typer.prompt("", type=int, default=default, show_default=show_default) + + +def _prompt_string_value( + console: Console, + field_indent: str, + prompt_text: str, + default: str, + show_default: bool, +) -> str: + """Prompt for a string field value.""" + formatted_field_indent = _format_prompt_indent(field_indent) + console.print(f"{formatted_field_indent}{prompt_text}", end="") + return typer.prompt("", type=str, default=default, show_default=show_default) + + +def _validate_pydantic_schema_dict_key_types(schema_class: type[BaseModel]) -> None: + """Reject dict fields with non-string keys (JSON object keys are always strings).""" + visited_models: set[type[BaseModel]] = set() + + def _unwrap_annotated(annotation: Any) -> Any: + """Return the underlying type for Annotated[T, ...] 
annotations.""" + origin = get_origin(annotation) + if origin is Annotated: + args = get_args(annotation) + if not args: # pragma: no cover - defensive for patched typing helpers + return annotation + return args[0] + return annotation + + def _resolve_type_hints(model_class: type[BaseModel]) -> Dict[str, Any]: + """Resolve type hints for a Pydantic model, keeping Annotated extras.""" + try: + return get_type_hints(model_class, include_extras=True) + except Exception: # pragma: no cover - defensive fallback for complex forward refs + return {} + + def _visit_annotation(annotation: Any) -> None: + """Walk annotation trees to find dict key types and nested models.""" + annotation = _unwrap_annotated(annotation) + origin = get_origin(annotation) + + if origin is Union: + for arg in get_args(annotation): + if arg is type(None): + continue + _visit_annotation(arg) + return + + if origin in {list, set, frozenset, tuple}: + for arg in get_args(annotation): + _visit_annotation(arg) + return + + if origin is dict: + args = get_args(annotation) + dict_key_type = _unwrap_annotated(args[0]) if len(args) > 0 else str + if dict_key_type is not str: + raise CLIError("Only string keys are supported") + if len(args) > 1: + _visit_annotation(args[1]) + return + + if isinstance(annotation, type) and issubclass(annotation, BaseModel): + _visit_model(annotation) + return + + def _visit_model(model_class: type[BaseModel]) -> None: + """Visit model field annotations recursively, avoiding cycles.""" + if model_class in visited_models: + return + visited_models.add(model_class) + + resolved_hints = _resolve_type_hints(model_class) + for field_name, field_info in model_class.model_fields.items(): + annotation = resolved_hints.get(field_name, field_info.annotation) + _visit_annotation(annotation) + + _visit_model(schema_class) + + +def _as_schema_dict(value: Any) -> Dict[str, Any]: + """Return a schema dictionary, coercing unknown values to an empty schema.""" + return value if 
isinstance(value, dict) else {} + + +def _resolve_json_pointer(root_schema: Dict[str, Any], ref: str) -> Dict[str, Any]: + """Resolve a local JSON Pointer reference against the root schema.""" + if not ref.startswith("#/"): + raise CLIError(f"Only local schema references are supported: {ref}") + + pointer_tokens = ref[2:].split("/") + current: Any = root_schema + + for raw_token in pointer_tokens: + token = raw_token.replace("~1", "/").replace("~0", "~") + + if isinstance(current, dict): + if token not in current: + raise CLIError(f"Schema reference not found: {ref}") + current = current[token] + continue + + if isinstance(current, list): + try: + index = int(token) + except ValueError: + raise CLIError(f"Invalid array index in schema reference: {ref}") + if index < 0 or index >= len(current): + raise CLIError(f"Schema reference index out of bounds: {ref}") + current = current[index] + continue + + raise CLIError(f"Schema reference path is invalid: {ref}") + + if not isinstance(current, dict): + raise CLIError(f"Schema reference does not resolve to an object schema: {ref}") + + return current + + +def _resolve_ref_schema(root_schema: Dict[str, Any], field_schema: Dict[str, Any], resolving: Tuple[str, ...] 
= ()) -> Dict[str, Any]: + """Resolve $ref in schema and merge sibling keys as overrides.""" + if not isinstance(field_schema, dict): + return {} + + ref_value = field_schema.get("$ref") + if not isinstance(ref_value, str): + return field_schema + + if ref_value in resolving: + ref_chain = " -> ".join([*resolving, ref_value]) + raise CLIError(f"Cyclic schema reference detected: {ref_chain}") + + resolved_schema = _resolve_ref_schema(root_schema, _resolve_json_pointer(root_schema, ref_value), resolving + (ref_value,)) + merged_schema = resolved_schema.copy() + merged_schema.update({key: value for key, value in field_schema.items() if key != "$ref"}) + return merged_schema + + +def _infer_schema_type(field_schema: Dict[str, Any]) -> Optional[str]: + """Infer schema type from direct JSON Schema keywords.""" + raw_type = field_schema.get("type") + if isinstance(raw_type, str): + return raw_type + if isinstance(raw_type, list): + valid_types = [item for item in raw_type if isinstance(item, str)] + non_null_types = [item for item in valid_types if item != "null"] + if non_null_types: + first_non_null_type = non_null_types[0] + if all(item == first_non_null_type for item in non_null_types): + return first_non_null_type + return "union" + if "null" in valid_types: + return "null" + if isinstance(field_schema.get("properties"), dict): + return "object" + if "required" in field_schema: + return "object" + if isinstance(field_schema.get("items"), dict): + return "array" + if isinstance(field_schema.get("enum"), list): + return "string" + return None + + +def _schema_contains_ref(field_schema: Dict[str, Any], target_ref: str, visited: Optional[set[int]] = None) -> bool: + """Return True when the schema tree contains the target local $ref.""" + if not isinstance(field_schema, dict): + return False + + if visited is None: + visited = set() + schema_id = id(field_schema) + if schema_id in visited: + return False + visited.add(schema_id) + + ref_value = field_schema.get("$ref") + 
if isinstance(ref_value, str) and ref_value == target_ref: + return True + + for value in field_schema.values(): + if isinstance(value, dict) and _schema_contains_ref(value, target_ref, visited): + return True + if isinstance(value, list): + for item in value: + if isinstance(item, dict) and _schema_contains_ref(item, target_ref, visited): + return True + return False + + +def _resolve_effective_schema( + root_schema: Dict[str, Any], + field_schema: Dict[str, Any], + resolving_refs: Tuple[str, ...] = (), + collapse_nullable: bool = True, +) -> Dict[str, Any]: + """Resolve refs and collapse only nullable anyOf/oneOf schemas.""" + if not isinstance(field_schema, dict): + return {} + + ref_value = field_schema.get("$ref") + if isinstance(ref_value, str): + if ref_value in resolving_refs: + return field_schema + resolving_refs = (*resolving_refs, ref_value) + + resolved_schema = _resolve_ref_schema(root_schema, field_schema) + if not collapse_nullable: + return resolved_schema + + for combinator_key in ("anyOf", "oneOf"): + options = resolved_schema.get(combinator_key) + if not isinstance(options, list) or not options: + continue + + resolved_options: List[Tuple[Dict[str, Any], Dict[str, Any]]] = [] + for option in options: + option_schema = _as_schema_dict(option) + resolved_option = _resolve_effective_schema(root_schema, option_schema, resolving_refs, collapse_nullable) + resolved_options.append((option_schema, resolved_option)) + + non_null_options: List[Tuple[Dict[str, Any], Dict[str, Any]]] = [] + has_null_option = False + for option_schema, resolved_option in resolved_options: + option_type = _resolve_schema_type(root_schema, resolved_option, resolving_refs) + if option_type == "null": + has_null_option = True + continue + non_null_options.append((option_schema, resolved_option)) + + if len(non_null_options) == 1 and (has_null_option or len(resolved_options) == 1): + selected_option, selected_resolved_option = non_null_options[0] + option_ref = 
selected_option.get("$ref") + if isinstance(option_ref, str) and _schema_contains_ref(selected_resolved_option, option_ref): + # Keep nullable wrapper so recursive schemas retain a terminating null path. + return resolved_schema + + merged_schema = selected_resolved_option.copy() + merged_schema.update({key: value for key, value in resolved_schema.items() if key not in {"anyOf", "oneOf"}}) + return merged_schema + + return resolved_schema + return resolved_schema + + +def _resolve_schema_type( + root_schema: Dict[str, Any], + field_schema: Dict[str, Any], + resolving_refs: Tuple[str, ...] = (), +) -> str: + """Resolve schema type from JSON Schema type fields and structural hints.""" + if not isinstance(field_schema, dict): + return "string" + + ref_value = field_schema.get("$ref") + if isinstance(ref_value, str): + if ref_value in resolving_refs: + return "union" + resolving_refs = (*resolving_refs, ref_value) + + field_schema = _resolve_ref_schema(root_schema, field_schema) + direct_type = _infer_schema_type(field_schema) + if direct_type is not None: + return direct_type + + for combinator_key in ("anyOf", "oneOf"): + options = field_schema.get(combinator_key) + if not isinstance(options, list) or not options: + continue + + option_results: List[Tuple[Dict[str, Any], str]] = [] + for option in options: + option_schema = _as_schema_dict(option) + option_type = _resolve_schema_type(root_schema, option_schema, resolving_refs) + option_results.append((option_schema, option_type)) + + non_null_results = [(option_schema, option_type) for option_schema, option_type in option_results if option_type != "null"] + has_null_option = any(option_type == "null" for _, option_type in option_results) + + if not non_null_results: + return "null" + + if len(non_null_results) == 1: + non_null_option_schema, non_null_type = non_null_results[0] + option_ref = non_null_option_schema.get("$ref") + if has_null_option and isinstance(option_ref, str): + resolved_non_null_option = 
_resolve_ref_schema(root_schema, non_null_option_schema) + if _schema_contains_ref(resolved_non_null_option, option_ref): + # Keep recursive nullable refs as union prompts to avoid auto-generating empty objects. + return "union" + return non_null_type + return "union" + + return "string" + + +def _prompt_from_json_schema( + schema: Dict[str, Any], + prefilled: Optional[Dict[str, Any]] = None, + indent: str = "", + prompt_optional: bool = True, + default_display_name: str = "Tool Arguments", + validate_payload: bool = True, +) -> Dict[str, Any]: + """Prompt recursively from schema dictionaries shared by both public APIs.""" + if not isinstance(schema, dict): + raise CLIError("Input schema must be a JSON object") + if prefilled is not None and not isinstance(prefilled, dict): + raise CLIError("Prefilled input must be a JSON object") + + if validate_payload: + schema_error = validate_schema(schema) + if schema_error is not None: + raise CLIError(schema_error) + + resolved_root_schema = _resolve_effective_schema(schema, schema) + root_type = _resolve_schema_type(schema, resolved_root_schema) + if root_type != "object": + raise CLIError("Input schema must be an object schema") + + _MISSING = object() + console = get_console() + formatted_indent = _format_prompt_indent(indent) + display_name = resolved_root_schema.get("title", default_display_name) + if not isinstance(display_name, str) or not display_name: + display_name = default_display_name + + console.print(f"\n{formatted_indent}[bold cyan]Creating {display_name}[/bold cyan]") + if prompt_optional: + console.print(f"{formatted_indent}[dim]Press Enter to skip optional fields[/dim]\n{formatted_indent}") + else: + console.print(f"{formatted_indent}[dim]Prompting for missing required fields only[/dim]\n{formatted_indent}") + + def _prompt_field_value( + field_name: str, + field_schema: Dict[str, Any], + is_required: bool, + field_indent: str, + prefilled_value: Any = _MISSING, + ) -> Tuple[Any, bool]: + """Prompt for a 
single field value. + + Returns: + tuple[value, included] where included indicates if field should be set. + """ + field_schema = _resolve_effective_schema(schema, field_schema) + schema_type = _resolve_schema_type(schema, field_schema) + formatted_field_indent = _format_prompt_indent(field_indent) + include_falsy_default = True + prompt_text = _build_prompt_text( + field_name=field_name, + description=field_schema.get("description") if isinstance(field_schema.get("description"), str) else None, + default=field_schema.get("default"), + default_is_set="default" in field_schema, + is_required=is_required, + include_falsy_default=include_falsy_default, + ) + + def _format_string_default(default_value: Any) -> Tuple[str, bool]: + """Return a default prompt value (as text) and whether to show it.""" + if default_value is None: + return "", False + if isinstance(default_value, str): + return default_value, True + return json.dumps(default_value), True + + def _prompt_string_with_default() -> Tuple[Optional[str], bool]: + """Prompt for a string value while honoring schema defaults and required-ness.""" + default_text, show_default = _format_string_default(field_schema.get("default")) + value = _prompt_string_value( + console=console, + field_indent=field_indent, + prompt_text=prompt_text, + default=default_text, + show_default=show_default, + ) + if value == "": + if is_required: + raise CLIError(f"Field '{field_name}' is required") + return None, False + return value, True + + if prefilled_value is not _MISSING: + if schema_type == "object" and isinstance(prefilled_value, dict): + console.print(f"{formatted_field_indent}[dim]{field_name}: (pre-filled object)[/dim]") + return _prompt_object(field_schema, prefilled=prefilled_value, object_indent=_next_prompt_indent(field_indent)), True + if schema_type == "array" and isinstance(prefilled_value, list): + console.print(f"{formatted_field_indent}[dim]{field_name}: (pre-filled array)[/dim]") + array_values, include_array = 
_prompt_array(field_name, field_schema, is_required, field_indent, prefilled=prefilled_value) + return array_values, include_array + console.print(f"{formatted_field_indent}[dim]{field_name}: {prefilled_value} (pre-filled)[/dim]") + return prefilled_value, True + + if not is_required and not prompt_optional: + return None, False + + enum_values = field_schema.get("enum") + if isinstance(enum_values, list) and enum_values: + enum_text = ", ".join(json.dumps(value) for value in enum_values) + console.print(f"{formatted_field_indent}{prompt_text} [choices: {enum_text}]", end="") + + default_value = field_schema.get("default") + default_text = "" + show_default = False + if default_value is not None: + default_text = default_value if isinstance(default_value, str) else json.dumps(default_value) + show_default = True + + raw_value = typer.prompt("", type=str, default=default_text, show_default=show_default) + if raw_value == "": + if is_required: + raise CLIError(f"Field '{field_name}' is required") + return None, False + + parsed_value: Any + try: + parsed_value = json.loads(raw_value) + except json.JSONDecodeError: + parsed_value = raw_value + + if parsed_value in enum_values: + return parsed_value, True + if raw_value in enum_values: + return raw_value, True + + raise CLIError(f"Field '{field_name}' must be one of: {enum_text}") + + if schema_type == "object": + if not is_required and prompt_optional and not _prompt_include_field(console, field_name, field_indent): + return None, False + console.print(f"{formatted_field_indent}[yellow]{prompt_text}[/yellow]") + return _prompt_object(field_schema, prefilled=None, object_indent=_next_prompt_indent(field_indent)), True + + if schema_type == "array": + return _prompt_array(field_name, field_schema, is_required, field_indent) + + if schema_type == "boolean": + if not is_required and prompt_optional and not _prompt_include_field(console, field_name, field_indent): + return None, False + default_val = 
field_schema.get("default") + bool_default = default_val if isinstance(default_val, bool) else False + return ( + _prompt_boolean_value( + console=console, + field_indent=field_indent, + prompt_text=prompt_text, + default=bool_default, + show_default=True, + ), + True, + ) + + if schema_type == "integer": + default_val = field_schema.get("default") + default_prompt: Any = "" if is_required else _INT_SENTINEL_DEFAULT + show_default = False + if isinstance(default_val, int): + default_prompt = default_val + show_default = True + value = _prompt_integer_value( + console=console, + field_indent=field_indent, + prompt_text=prompt_text, + default=default_prompt, + show_default=show_default, + ) + if value == _INT_SENTINEL_DEFAULT: + if is_required: + raise CLIError(f"Field '{field_name}' is required") + return None, False + return value, True + + if schema_type == "number": + default_val = field_schema.get("default") + default_text = "" + show_default = False + if isinstance(default_val, (int, float)): + default_text = str(float(default_val)) + show_default = True + raw_value = _prompt_string_value( + console=console, + field_indent=field_indent, + prompt_text=prompt_text, + default=default_text, + show_default=show_default, + ) + if raw_value == "": + if is_required: + raise CLIError(f"Field '{field_name}' is required") + return None, False + try: + return float(raw_value), True + except ValueError: + raise CLIError(f"Field '{field_name}' must be a number") + + if schema_type == "null": + return None, True + + if schema_type == "union": + raw_value, include_raw = _prompt_string_with_default() + if not include_raw: + return None, False + + parsed_value: Any + try: + parsed_value = json.loads(raw_value) + except json.JSONDecodeError: + parsed_value = raw_value + + validation_error = validate_instance_against_subschema(schema, field_schema, parsed_value) + if validation_error is None: + return parsed_value, True + + raise CLIError(f"Field '{field_name}' is invalid: 
{validation_error}") + + value, include_value = _prompt_string_with_default() + if not include_value: + return None, False + return value, True + + def _prompt_array( + field_name: str, + field_schema: Dict[str, Any], + is_required: bool, + field_indent: str, + prefilled: Optional[List[Any]] = None, + ) -> Tuple[List[Any], bool]: + """Prompt for array values.""" + field_schema = _resolve_effective_schema(schema, field_schema) + formatted_field_indent = _format_prompt_indent(field_indent) + item_schema = _resolve_effective_schema(schema, _as_schema_dict(field_schema.get("items"))) + item_type = _resolve_schema_type(schema, item_schema) + + if prefilled is None and item_type != "string" and not is_required and prompt_optional: + if not _prompt_include_field(console, field_name, field_indent): + return [], False + + values: List[Any] = [] + + if prefilled is not None: + nested_indent = _next_prompt_indent(field_indent) + for idx, entry in enumerate(prefilled): + if item_type == "object" and isinstance(entry, dict): + values.append(_prompt_object(item_schema, prefilled=entry, object_indent=nested_indent)) + elif item_type == "array" and isinstance(entry, list): + nested_values, _ = _prompt_array(f"{field_name}[{idx}]", item_schema, is_required=True, field_indent=nested_indent, prefilled=entry) + values.append(nested_values) + else: + values.append(entry) + return values, True + + if item_type == "string": + console.print(f"{formatted_field_indent}[dim]Enter comma-separated values, or press Enter to skip[/dim] ", end="") + csv_value = typer.prompt("", default="", show_default=False) + if csv_value: + return [value.strip() for value in csv_value.split(",") if value.strip()], True + if is_required: + return [], True + return [], False + + nested_indent = _next_prompt_indent(field_indent) + while True: + console.print(f"{formatted_field_indent}[dim]Add an entry to {field_name}?[/dim] ", end="") + if not typer.confirm("", default=False): + break + entry, _ = 
_prompt_field_value("item", item_schema, is_required=True, field_indent=nested_indent) + values.append(entry) + + if values: + return values, True + return [], True + + def _prompt_additional_properties( + object_indent: str, + assign_value: Callable[[str, str], None], + ) -> None: + """Prompt for additional object properties using a supplied value handler.""" + while True: + formatted_object_indent = _format_prompt_indent(object_indent) + next_indent = _next_prompt_indent(object_indent) + formatted_next_indent = _format_prompt_indent(next_indent) + console.print(f"{formatted_object_indent}[dim]Add an extra field?[/dim] ", end="") + if not typer.confirm("", default=False): + break + console.print(f"{formatted_next_indent}Enter key", end="") + key = typer.prompt("", type=str) + assign_value(key, next_indent) + + def _prompt_object(field_schema: Dict[str, Any], prefilled: Optional[Dict[str, Any]], object_indent: str) -> Dict[str, Any]: + """Prompt for object fields recursively.""" + field_schema = _resolve_effective_schema(schema, field_schema) + data = prefilled.copy() if prefilled else {} + properties = field_schema.get("properties") + required = field_schema.get("required") + properties_dict = properties if isinstance(properties, dict) else {} + required_fields = {field for field in required if isinstance(field, str)} if isinstance(required, list) else set() + + for field_name, field_details in properties_dict.items(): + field_def = _resolve_effective_schema(schema, _as_schema_dict(field_details)) + has_prefilled = field_name in data + current_value = data[field_name] if has_prefilled else _MISSING + value, include_value = _prompt_field_value( + field_name=field_name, + field_schema=field_def, + is_required=field_name in required_fields, + field_indent=object_indent, + prefilled_value=current_value, + ) + if include_value: + data[field_name] = value + + additional_properties = field_schema.get("additionalProperties") + if isinstance(additional_properties, dict) 
and prompt_optional: + additional_properties_schema = _resolve_effective_schema(schema, additional_properties) + + def _assign_typed_value(key: str, next_indent: str) -> None: + """Prompt for and assign a typed additional property value.""" + value, _ = _prompt_field_value(key, additional_properties_schema, is_required=True, field_indent=next_indent) + data[key] = value + + _prompt_additional_properties(object_indent, _assign_typed_value) + elif additional_properties is True and prompt_optional: + + def _assign_json_value(key: str, next_indent: str) -> None: + """Prompt for and assign a JSON additional property value.""" + formatted_next_indent = _format_prompt_indent(next_indent) + console.print(f"{formatted_next_indent}Enter JSON value", end="") + raw_value = typer.prompt("", type=str) + try: + data[key] = json.loads(raw_value) + except json.JSONDecodeError: + data[key] = raw_value + + _prompt_additional_properties(object_indent, _assign_json_value) + + return data + + prompted_payload = _prompt_object(resolved_root_schema, prefilled=prefilled, object_indent=indent) + if validate_payload: + payload_validation_error = validate_instance(schema, prompted_payload) + if payload_validation_error is not None: + raise CLIError(f"Prompted payload is invalid: {payload_validation_error}") + return prompted_payload + + +def _strip_schema_internal_properties(schema: Dict[str, Any], skip_fields: set[str]) -> Dict[str, Any]: + """Return a shallow copy of a root object schema without internal prompt-only fields.""" + schema_copy = schema.copy() + properties = schema_copy.get("properties") + if isinstance(properties, dict): + properties_copy = properties.copy() + for field in skip_fields: + properties_copy.pop(field, None) + schema_copy["properties"] = properties_copy + + required = schema_copy.get("required") + if isinstance(required, list): + schema_copy["required"] = [field for field in required if isinstance(field, str) and field not in skip_fields] + + return schema_copy + + 
def prompt_for_schema(schema_class: type[BaseModel], prefilled: Optional[Dict[str, Any]] = None, indent: str = "") -> Dict[str, Any]:
    """Interactively prompt user for fields based on a Pydantic schema."""
    _validate_pydantic_schema_dict_key_types(schema_class)
    raw_schema = schema_class.model_json_schema()
    prompt_schema = _strip_schema_internal_properties(raw_schema, {"auth_value", "model_config"})
    # The prompt schema is intentionally lossy (e.g., datetime fields become strings),
    # so validating the prompted payload against it can reject values that would be
    # valid after normal Pydantic validation/coercion.
    return _prompt_from_json_schema(
        prompt_schema,
        prefilled=prefilled,
        indent=indent,
        prompt_optional=True,
        default_display_name=schema_class.__name__,
        validate_payload=False,
    )


def prompt_for_json_schema(
    schema: Dict[str, Any],
    prefilled: Optional[Dict[str, Any]] = None,
    indent: str = "",
    prompt_optional: bool = True,
) -> Dict[str, Any]:
    """Interactively prompt user for fields based on a JSON Schema object."""
    return _prompt_from_json_schema(
        schema,
        prefilled=prefilled,
        indent=indent,
        prompt_optional=prompt_optional,
        default_display_name="Tool Arguments",
    )
+""" + +import json +from typing import Any, Dict, List, Optional + +from rich.console import Console, ConsoleOptions, RenderableType, RenderResult +from rich.measure import Measurement +from rich.panel import Panel +from rich.segment import Segment +from rich.syntax import Syntax +from rich.table import Table + +from cforge.common.console import get_console +from cforge.config import get_settings + + +class LineLimit: + """A renderable that limits the number of lines after rich wrapping.""" + + def __init__(self, renderable: RenderableType, max_lines: int): + """Initialize with the wrapped renderable and max lines to render.""" + self.renderable = renderable + self.max_lines = max_lines + + def __rich_console__(self, console: Console, options: ConsoleOptions) -> RenderResult: + """Render with line truncation applied after wrapping.""" + lines = console.render_lines(self.renderable, options, pad=False) + for index, line in enumerate(lines): + if index >= self.max_lines: + yield Segment("...") + break + yield from line + yield Segment.line() + + def __rich_measure__(self, console: Console, options: ConsoleOptions) -> Measurement: + """Measure by delegating to the wrapped renderable.""" + return Measurement.get(console, options, self.renderable) + + +def print_json(data: Any, title: Optional[str] = None) -> None: + """Pretty print JSON data with Rich.""" + console = get_console() + json_str = json.dumps(data, indent=2, ensure_ascii=False) + syntax = Syntax(json_str, "json", theme="monokai", line_numbers=True) + if title: + console.print(Panel(syntax, title=title, border_style="green")) + else: + console.print(syntax) + + +def print_table( + data: List[Dict], + title: str, + columns: List[str], + col_name_map: Optional[Dict[str, str]] = None, +) -> None: + """Print data as a Rich table.""" + console = get_console() + table = Table(title=title, show_header=True, header_style="bold magenta") + col_name_map = col_name_map or {} + max_lines = 
get_settings().table_max_lines + + for column in columns: + table.add_column(col_name_map.get(column, column), style="cyan") + + for item in data: + row = [str(item.get(col, "")) for col in columns] + if max_lines > 0: + row = [LineLimit(cell, max_lines=max_lines) for cell in row] + table.add_row(*row) + + console.print(table) diff --git a/cforge/common/schema_validation.py b/cforge/common/schema_validation.py new file mode 100644 index 0000000..3928c92 --- /dev/null +++ b/cforge/common/schema_validation.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- +"""Copyright 2025 +SPDX-License-Identifier: Apache-2.0 + +JSON Schema validation helpers. + +This module centralizes jsonschema-based validation so callers can validate +full payloads or individual field values against sub-schemas while preserving +root-schema `$ref` resolution. +""" + +# Standard +from typing import Any, Dict, Optional, Tuple + +# Third-Party +from jsonschema import SchemaError +from jsonschema.exceptions import ValidationError +from jsonschema.validators import validator_for + + +def _error_sort_key(error: ValidationError) -> Tuple[Tuple[str, ...], Tuple[str, ...]]: + """Sort errors deterministically by instance path then schema path.""" + return tuple(str(part) for part in error.path), tuple(str(part) for part in error.schema_path) + + +def _format_error_path(error: ValidationError) -> str: + """Format an instance path in a compact jsonpath-like style.""" + if not error.path: + return "$" + + segments = ["$"] + for part in error.path: + if isinstance(part, int): + segments.append(f"[{part}]") + else: + segments.append(f".{part}") + return "".join(segments) + + +def _format_validation_error(error: ValidationError) -> str: + """Format a validation error for user-facing CLI messages.""" + location = _format_error_path(error) + if location == "$": + return error.message + return f"{location}: {error.message}" + + +def _first_validation_error_message(errors: list[ValidationError]) -> Optional[str]: + 
"""Return a formatted first error message when validation fails.""" + if not errors: + return None + return _format_validation_error(errors[0]) + + +def _build_root_validator(schema: Dict[str, Any]) -> Any: + """Build a jsonschema validator from a root schema.""" + validator_cls = validator_for(schema) + validator_cls.check_schema(schema) + return validator_cls(schema) + + +def validate_schema(schema: Dict[str, Any]) -> Optional[str]: + """Validate a JSON Schema without validating a specific instance. + + Returns: + A user-facing error message when invalid, otherwise ``None``. + """ + if not isinstance(schema, dict): + return "Input schema must be a JSON object" + + try: + _build_root_validator(schema) + return None + except SchemaError as exc: + return f"Invalid JSON Schema: {exc}" + except Exception as exc: # pragma: no cover - defensive fallback for validator internals + return f"Schema validation failed: {exc}" + + +def validate_instance(schema: Dict[str, Any], instance: Any) -> Optional[str]: + """Validate an instance against a full schema. + + Returns: + A user-facing error message when invalid, otherwise ``None``. + """ + if not isinstance(schema, dict): + return "Input schema must be a JSON object" + + try: + validator = _build_root_validator(schema) + errors = sorted(validator.iter_errors(instance), key=_error_sort_key) + return _first_validation_error_message(errors) + except SchemaError as exc: + return f"Invalid JSON Schema: {exc}" + except Exception as exc: # pragma: no cover - defensive fallback for validator internals + return f"Schema validation failed: {exc}" + + +def validate_instance_against_subschema(root_schema: Dict[str, Any], subschema: Dict[str, Any], instance: Any) -> Optional[str]: + """Validate an instance against a subschema with root `$ref` context. + + Returns: + A user-facing error message when invalid, otherwise ``None``. 
+ """ + if not isinstance(root_schema, dict): + return "Input schema must be a JSON object" + if not isinstance(subschema, dict): + return "Input schema must be a JSON object" + + try: + root_validator = _build_root_validator(root_schema) + root_validator.__class__.check_schema(subschema) + subschema_validator = root_validator.evolve(schema=subschema) + errors = sorted(subschema_validator.iter_errors(instance), key=_error_sort_key) + return _first_validation_error_message(errors) + except SchemaError as exc: + return f"Invalid JSON Schema: {exc}" + except Exception as exc: # pragma: no cover - defensive fallback for validator internals + return f"Schema validation failed: {exc}" diff --git a/cforge/config.py b/cforge/config.py index 921fc7d..e006c1a 100644 --- a/cforge/config.py +++ b/cforge/config.py @@ -21,7 +21,6 @@ from mcpgateway.config import Settings from mcpgateway.config import get_settings as cf_get_settings - HOME_DIR_NAME = ".contextforge" DEFAULT_HOME = Path.home() / HOME_DIR_NAME diff --git a/cforge/main.py b/cforge/main.py index 0d7bcaa..2d75eb8 100644 --- a/cforge/main.py +++ b/cforge/main.py @@ -28,80 +28,82 @@ # Third-Party import typer -# First-Party -from cforge.common import get_app from cforge.commands.deploy.deploy import deploy -from cforge.commands.server.serve import serve -from cforge.commands.server.run import run -from cforge.commands.settings import profiles -from cforge.commands.settings.login import login -from cforge.commands.settings.logout import logout -from cforge.commands.settings.whoami import whoami -from cforge.commands.settings.export import export -from cforge.commands.settings.import_cmd import import_cmd -from cforge.commands.settings.config_schema import config_schema -from cforge.commands.settings.support_bundle import support_bundle -from cforge.commands.settings.version import version from cforge.commands.metrics.metrics import metrics_get, metrics_reset -from cforge.commands.resources.tools import ( - tools_list, - 
tools_get, - tools_create, - tools_update, - tools_delete, - tools_toggle, +from cforge.commands.resources.a2a import ( + a2a_create, + a2a_delete, + a2a_get, + a2a_invoke, + a2a_list, + a2a_toggle, + a2a_update, ) -from cforge.commands.resources.resources import ( - resources_list, - resources_get, - resources_create, - resources_update, - resources_delete, - resources_toggle, - resources_templates, +from cforge.commands.resources.mcp_servers import ( + mcp_servers_create, + mcp_servers_delete, + mcp_servers_get, + mcp_servers_list, + mcp_servers_toggle, + mcp_servers_update, +) +from cforge.commands.resources.plugins import ( + plugins_get, + plugins_list, + plugins_stats, ) from cforge.commands.resources.prompts import ( - prompts_list, - prompts_get, prompts_create, - prompts_update, prompts_delete, - prompts_toggle, prompts_execute, + prompts_get, + prompts_list, + prompts_toggle, + prompts_update, ) -from cforge.commands.resources.mcp_servers import ( - mcp_servers_list, - mcp_servers_get, - mcp_servers_create, - mcp_servers_update, - mcp_servers_delete, - mcp_servers_toggle, +from cforge.commands.resources.resources import ( + resources_create, + resources_delete, + resources_get, + resources_list, + resources_templates, + resources_toggle, + resources_update, +) +from cforge.commands.resources.tools import ( + tools_create, + tools_delete, + tools_execute, + tools_get, + tools_list, + tools_toggle, + tools_update, ) from cforge.commands.resources.virtual_servers import ( - virtual_servers_list, - virtual_servers_get, virtual_servers_create, - virtual_servers_update, virtual_servers_delete, + virtual_servers_get, + virtual_servers_list, + virtual_servers_prompts, + virtual_servers_resources, virtual_servers_toggle, virtual_servers_tools, - virtual_servers_resources, - virtual_servers_prompts, -) -from cforge.commands.resources.a2a import ( - a2a_list, - a2a_get, - a2a_create, - a2a_update, - a2a_delete, - a2a_toggle, - a2a_invoke, -) -from 
cforge.commands.resources.plugins import ( - plugins_get, - plugins_list, - plugins_stats, + virtual_servers_update, ) +from cforge.commands.server.run import run +from cforge.commands.server.serve import serve +from cforge.commands.settings import profiles +from cforge.commands.settings.config_schema import config_schema +from cforge.commands.settings.export import export +from cforge.commands.settings.import_cmd import import_cmd +from cforge.commands.settings.login import login +from cforge.commands.settings.logout import logout +from cforge.commands.settings.support_bundle import support_bundle +from cforge.commands.settings.version import version +from cforge.commands.settings.whoami import whoami + +# First-Party +from cforge.common.console import get_app # Get the main app singleton app = get_app() @@ -160,6 +162,7 @@ tools_app.command("update")(tools_update) tools_app.command("delete")(tools_delete) tools_app.command("toggle")(tools_toggle) +tools_app.command("execute")(tools_execute) # --------------------------------------------------------------------------- # Resources command group diff --git a/pyproject.toml b/pyproject.toml index 58422d4..dc0c328 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dependencies = [ "typer>=0.20.0", "mcp-contextforge-gateway==1.0.0b2", "cryptography>=44.0.0", + "jsonschema>=4.23.0", ] # ---------------------------------------------------------------- @@ -462,7 +463,7 @@ ignore-nested-functions = false ignore-nested-classes = true ignore-setters = false fail-under = 100 -exclude = ["setup.py", "docs", "build", "tests"] +exclude = ["setup.py", "docs", "build", "tests", "cforge/_version.py", "cforge.egg-info"] ignore-regex = ["^get_", "^post_"] verbose = 0 quiet = false diff --git a/tests/commands/resources/test_plugins.py b/tests/commands/resources/test_plugins.py index dd8e1fc..ad0af73 100644 --- a/tests/commands/resources/test_plugins.py +++ b/tests/commands/resources/test_plugins.py @@ -12,8 +12,9 @@ 
import typer # First-Party -from cforge.commands.resources.plugins import PluginMode, plugins_get, plugins_list, plugins_stats -from cforge.common import AuthenticationError, CLIError +from cforge.commands.resources.plugins import _parse_plugin_mode, PluginMode, plugins_get, plugins_list, plugins_stats +from cforge.common.errors import AuthenticationError, CLIError +from cforge.main import app from tests.conftest import invoke_typer_command, patch_functions @@ -32,6 +33,11 @@ def test_plugin_mode_enum_missing_unknown_value(self) -> None: """Unknown strings should not be coerced into Enum members.""" assert PluginMode._missing_("nope") is None + def test_parse_plugin_mode_invalid_value_raises(self) -> None: + """Invalid mode values should raise a clear CLIError.""" + with pytest.raises(CLIError, match="Invalid value for '--mode'"): + _parse_plugin_mode("invalid") + def test_plugins_list_success(self, mock_console) -> None: """Test plugins list command with table output.""" mock_response = { @@ -87,6 +93,19 @@ def test_plugins_list_with_filters(self, mock_console) -> None: assert call_args[0][1] == "/admin/plugins" assert call_args[1]["params"] == {"search": "pii", "mode": "enforce", "hook": "tool_pre_invoke", "tag": "security"} + def test_plugins_list_mode_case_insensitive_via_cli(self, cli_runner, mock_console) -> None: + """Test mixed-case --mode values work through actual CLI parsing.""" + with patch_functions( + "cforge.commands.resources.plugins", + get_console=mock_console, + make_authenticated_request={"return_value": {"plugins": [], "total": 0, "enabled_count": 0, "disabled_count": 0}}, + print_table=None, + ) as mocks: + result = cli_runner.invoke(app, ["plugins", "list", "--mode", "EnFoRcE"]) + assert result.exit_code == 0 + call_args = mocks.make_authenticated_request.call_args + assert call_args[1]["params"]["mode"] == "enforce" + def test_plugins_list_error(self, mock_console) -> None: """Test plugins list error handling.""" with 
patch_functions("cforge.commands.resources.plugins", get_console=mock_console, make_authenticated_request={"side_effect": Exception("API error")}): @@ -111,6 +130,32 @@ def test_plugins_get_error(self, mock_console) -> None: with pytest.raises(typer.Exit): invoke_typer_command(plugins_get, name="pii_filter") + def test_plugins_get_not_found_shows_plugin_hint(self, mock_console) -> None: + """Test plugins get shows a plugin-not-found hint on plugin-specific 404s.""" + with patch_functions( + "cforge.commands.resources.plugins", + get_console=mock_console, + make_authenticated_request={"side_effect": CLIError("API request failed (404): Plugin 'missing_plugin' not found")}, + ): + with pytest.raises(typer.Exit): + invoke_typer_command(plugins_get, name="missing_plugin") + + assert any("Plugin not found: missing_plugin" in str(call) for call in mock_console.print.call_args_list) + assert not any("Admin plugin API unavailable" in str(call) for call in mock_console.print.call_args_list) + + def test_plugins_get_generic_not_found_shows_admin_api_hint(self, mock_console) -> None: + """Test plugins get shows admin-api hint for generic 404 errors.""" + with patch_functions( + "cforge.commands.resources.plugins", + get_console=mock_console, + make_authenticated_request={"side_effect": CLIError("API request failed (404): Not Found")}, + ): + with pytest.raises(typer.Exit): + invoke_typer_command(plugins_get, name="missing_plugin") + + assert any("Admin plugin API unavailable" in str(call) for call in mock_console.print.call_args_list) + assert not any("Plugin not found: missing_plugin" in str(call) for call in mock_console.print.call_args_list) + def test_plugins_stats_success(self, mock_console) -> None: """Test plugins stats command.""" mock_stats = {"total_plugins": 4, "enabled_plugins": 3, "disabled_plugins": 1, "plugins_by_hook": {"tool_pre_invoke": 3}, "plugins_by_mode": {"enforce": 3, "disabled": 1}} @@ -152,3 +197,15 @@ def 
test_plugins_list_not_found_shows_admin_api_hint(self, mock_console) -> None invoke_typer_command(plugins_list) assert any("Admin plugin API unavailable" in str(call) for call in mock_console.print.call_args_list) + + def test_plugins_list_clierror_without_404_does_not_show_admin_api_hint(self, mock_console) -> None: + """Test non-404 CLI errors do not show an admin-api availability hint.""" + with patch_functions( + "cforge.commands.resources.plugins", + get_console=mock_console, + make_authenticated_request={"side_effect": CLIError("API request failed (500): Internal Server Error")}, + ): + with pytest.raises(typer.Exit): + invoke_typer_command(plugins_list) + + assert not any("Admin plugin API unavailable" in str(call) for call in mock_console.print.call_args_list) diff --git a/tests/commands/resources/test_tools.py b/tests/commands/resources/test_tools.py index 457f944..7a6c3c8 100644 --- a/tests/commands/resources/test_tools.py +++ b/tests/commands/resources/test_tools.py @@ -22,6 +22,7 @@ from cforge.commands.resources.tools import ( tools_create, tools_delete, + tools_execute, tools_get, tools_list, tools_toggle, @@ -151,6 +152,21 @@ def test_tools_update_file_not_found(self, mock_console) -> None: with pytest.raises(typer.Exit): tools_update(tool_id="tool-1", data_file=Path("/nonexistent.json")) + def test_tools_update_interactive(self, mock_console) -> None: + """Test tools update interactive mode.""" + mock_result = {"id": "tool-1", "name": "updated"} + + with patch_functions( + "cforge.commands.resources.tools", + get_console=mock_console, + prompt_for_schema={"return_value": {"description": "updated"}}, + make_authenticated_request={"return_value": mock_result}, + print_json=None, + ) as mocks: + tools_update(tool_id="tool-1", data_file=None) + mocks.make_authenticated_request.assert_called_once_with("PUT", "/tools/tool-1", json_data={"description": "updated"}) + mocks.print_json.assert_called_once() + def test_tools_delete_with_confirmation(self, 
mock_console) -> None: """Test tools delete with confirmation.""" with patch("cforge.commands.resources.tools.get_console", return_value=mock_console): @@ -238,6 +254,246 @@ def test_tools_toggle_detects_current_status(self, mock_console) -> None: assert calls[0][0][0] == "GET" # First call is GET assert calls[1][0][0] == "POST" # Second call is POST + def test_tools_execute_interactive_success(self, mock_console) -> None: + """Test tools execute with interactive schema prompting.""" + tool_response = { + "id": "tool-1", + "name": "search_tool", + "input_schema": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + } + rpc_response = {"jsonrpc": "2.0", "id": "1", "result": {"content": [{"type": "text", "text": "ok"}]}} + + with patch_functions( + "cforge.commands.resources.tools", + get_console=mock_console, + prompt_for_json_schema={"return_value": {"query": "hello"}}, + make_authenticated_request={"side_effect": [tool_response, rpc_response]}, + print_json={}, + ) as mocks: + tools_execute(tool_id="tool-1", data_file=None) + + mocks.prompt_for_json_schema.assert_called_once_with(tool_response["input_schema"], prefilled=None, prompt_optional=True) + mocks.print_json.assert_called_once_with(rpc_response["result"], "Tool Result") + assert mocks.make_authenticated_request.call_count == 2 + rpc_call = mocks.make_authenticated_request.call_args_list[1] + assert rpc_call[0][0] == "POST" + assert rpc_call[0][1] == "/rpc" + payload = rpc_call[1]["json_data"] + assert payload["method"] == "tools/call" + assert payload["params"]["name"] == "search_tool" + assert payload["params"]["arguments"] == {"query": "hello"} + + def test_tools_execute_with_data_file_prompts_missing_required(self, mock_console) -> None: + """Test tools execute merges data file values and prompts only required fields.""" + tool_response = { + "id": "tool-1", + "name": "search_tool", + "inputSchema": { + "type": "object", + "properties": { + "query": 
{"type": "string"}, + "limit": {"type": "integer"}, + }, + "required": ["query"], + }, + } + rpc_response = {"jsonrpc": "2.0", "id": "1", "result": {"ok": True}} + + with tempfile.TemporaryDirectory() as temp_dir: + data_file = Path(temp_dir) / "args.json" + data_file.write_text(json.dumps({"limit": 5})) + + with patch_functions( + "cforge.commands.resources.tools", + get_console=mock_console, + prompt_for_json_schema={"return_value": {"query": "hello", "limit": 5}}, + make_authenticated_request={"side_effect": [tool_response, rpc_response]}, + print_json={}, + ) as mocks: + tools_execute(tool_id="tool-1", data_file=data_file) + + mocks.prompt_for_json_schema.assert_called_once_with(tool_response["inputSchema"], prefilled={"limit": 5}, prompt_optional=False) + rpc_call = mocks.make_authenticated_request.call_args_list[1] + payload = rpc_call[1]["json_data"] + assert payload["params"]["arguments"] == {"query": "hello", "limit": 5} + + def test_tools_execute_falls_back_to_input_schema_when_inputschema_null(self, mock_console) -> None: + """Test tools execute uses input_schema when inputSchema exists but is null.""" + tool_response = { + "id": "tool-1", + "name": "search_tool", + "inputSchema": None, + "input_schema": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + } + rpc_response = {"jsonrpc": "2.0", "id": "1", "result": {"ok": True}} + + with patch("cforge.commands.resources.tools.get_console", return_value=mock_console): + with patch("cforge.commands.resources.tools.prompt_for_json_schema", return_value={"query": "hello"}) as mock_prompt: + with patch("cforge.commands.resources.tools.make_authenticated_request", side_effect=[tool_response, rpc_response]): + with patch("cforge.commands.resources.tools.print_json"): + tools_execute(tool_id="tool-1", data_file=None) + + mock_prompt.assert_called_once_with(tool_response["input_schema"], prefilled=None, prompt_optional=True) + + def 
test_tools_execute_data_file_must_be_object(self, mock_console) -> None: + """Test tools execute rejects non-object JSON data files.""" + tool_response = {"id": "tool-1", "name": "search_tool", "input_schema": {"type": "object", "properties": {}}} + + with tempfile.TemporaryDirectory() as temp_dir: + data_file = Path(temp_dir) / "args.json" + data_file.write_text(json.dumps(["not-an-object"])) + + with patch("cforge.commands.resources.tools.get_console", return_value=mock_console): + with patch("cforge.commands.resources.tools.make_authenticated_request", return_value=tool_response): + with pytest.raises(typer.Exit): + tools_execute(tool_id="tool-1", data_file=data_file) + + def test_tools_execute_data_file_not_found(self, mock_console) -> None: + """Test tools execute with missing data file.""" + with patch("cforge.commands.resources.tools.get_console", return_value=mock_console): + with patch("cforge.commands.resources.tools.make_authenticated_request") as mock_req: + with pytest.raises(typer.Exit): + tools_execute(tool_id="tool-1", data_file=Path("/nonexistent.json")) + mock_req.assert_not_called() + + def test_tools_execute_requires_non_empty_tool_id(self, mock_console) -> None: + """Test tools execute fails fast for empty tool IDs.""" + with patch("cforge.commands.resources.tools.get_console", return_value=mock_console): + with patch("cforge.commands.resources.tools.make_authenticated_request") as mock_req: + with pytest.raises(typer.Exit): + tools_execute(tool_id=" ", data_file=None) + mock_req.assert_not_called() + + def test_tools_execute_requires_valid_tool_name(self, mock_console) -> None: + """Test tools execute validates the fetched tool has a usable name.""" + tool_response = {"id": "tool-1", "name": None, "input_schema": {"type": "object", "properties": {}}} + + with patch("cforge.commands.resources.tools.get_console", return_value=mock_console): + with patch("cforge.commands.resources.tools.make_authenticated_request", return_value=tool_response): + 
with pytest.raises(typer.Exit): + tools_execute(tool_id="tool-1", data_file=None) + + def test_tools_execute_parses_string_schema(self, mock_console) -> None: + """Test tools execute accepts string-encoded schemas when they parse to an object.""" + schema_str = json.dumps({"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}) + tool_response = {"id": "tool-1", "name": "search_tool", "input_schema": schema_str} + rpc_response = {"jsonrpc": "2.0", "id": "1", "result": {"ok": True}} + + with patch_functions( + "cforge.commands.resources.tools", + get_console=mock_console, + prompt_for_json_schema={"return_value": {"query": "hello"}}, + make_authenticated_request={"side_effect": [tool_response, rpc_response]}, + print_json={}, + ) as mocks: + tools_execute(tool_id="tool-1", data_file=None) + mocks.prompt_for_json_schema.assert_called_once_with(json.loads(schema_str), prefilled=None, prompt_optional=True) + + def test_tools_execute_rejects_string_schema_invalid_json(self, mock_console) -> None: + """Test tools execute rejects string-encoded schemas that are not valid JSON.""" + tool_response = {"id": "tool-1", "name": "search_tool", "input_schema": "{invalid json"} + + with patch("cforge.commands.resources.tools.get_console", return_value=mock_console): + with patch("cforge.commands.resources.tools.make_authenticated_request", return_value=tool_response) as mock_req: + with pytest.raises(typer.Exit): + tools_execute(tool_id="tool-1", data_file=None) + mock_req.assert_called_once() + + def test_tools_execute_rejects_invalid_schema_type(self, mock_console) -> None: + """Test tools execute rejects non-object schema containers.""" + tool_response = {"id": "tool-1", "name": "search_tool", "input_schema": ["bad"]} + + with patch("cforge.commands.resources.tools.get_console", return_value=mock_console): + with patch("cforge.commands.resources.tools.make_authenticated_request", return_value=tool_response): + with pytest.raises(typer.Exit): + 
tools_execute(tool_id="tool-1", data_file=None) + + def test_tools_execute_rejects_string_schema_not_object(self, mock_console) -> None: + """Test tools execute rejects schema strings that decode to non-object JSON.""" + tool_response = {"id": "tool-1", "name": "search_tool", "input_schema": json.dumps(["bad"])} + + with patch("cforge.commands.resources.tools.get_console", return_value=mock_console): + with patch("cforge.commands.resources.tools.make_authenticated_request", return_value=tool_response): + with pytest.raises(typer.Exit): + tools_execute(tool_id="tool-1", data_file=None) + + def test_tools_execute_defaults_schema_when_missing(self, mock_console) -> None: + """Test tools execute falls back to an empty object schema when schema is absent.""" + tool_response = {"id": "tool-1", "name": "search_tool"} + rpc_response = {"jsonrpc": "2.0", "id": "1", "result": {"ok": True}} + + with patch("cforge.commands.resources.tools.get_console", return_value=mock_console): + with patch("cforge.commands.resources.tools.prompt_for_json_schema", return_value={}) as mock_prompt: + with patch("cforge.commands.resources.tools.make_authenticated_request", side_effect=[tool_response, rpc_response]): + with patch("cforge.commands.resources.tools.print_json"): + tools_execute(tool_id="tool-1", data_file=None) + mock_prompt.assert_called_once_with({"type": "object", "properties": {}}, prefilled=None, prompt_optional=True) + + def test_tools_execute_defaults_empty_schema_to_object(self, mock_console) -> None: + """Test tools execute normalizes empty schema dictionaries to an object schema.""" + tool_response = {"id": "tool-1", "name": "search_tool", "input_schema": {}} + rpc_response = {"jsonrpc": "2.0", "id": "1", "result": {"ok": True}} + + with patch("cforge.commands.resources.tools.get_console", return_value=mock_console): + with patch("cforge.commands.resources.tools.prompt_for_json_schema", return_value={}) as mock_prompt: + with 
patch("cforge.commands.resources.tools.make_authenticated_request", side_effect=[tool_response, rpc_response]): + with patch("cforge.commands.resources.tools.print_json"): + tools_execute(tool_id="tool-1", data_file=None) + mock_prompt.assert_called_once_with({"type": "object", "properties": {}}, prefilled=None, prompt_optional=True) + + def test_tools_execute_jsonrpc_error_response(self, mock_console) -> None: + """Test tools execute surfaces JSON-RPC errors returned from /rpc.""" + tool_response = {"id": "tool-1", "name": "search_tool", "input_schema": {"type": "object", "properties": {}}} + rpc_error = {"jsonrpc": "2.0", "id": "1", "error": {"code": -32601, "message": "Tool not found"}} + + with patch("cforge.commands.resources.tools.get_console", return_value=mock_console): + with patch("cforge.commands.resources.tools.prompt_for_json_schema", return_value={}): + with patch("cforge.commands.resources.tools.make_authenticated_request", side_effect=[tool_response, rpc_error]): + with pytest.raises(typer.Exit): + tools_execute(tool_id="tool-1", data_file=None) + + def test_tools_execute_jsonrpc_error_without_code(self, mock_console) -> None: + """Test tools execute handles JSON-RPC errors that do not include a code.""" + tool_response = {"id": "tool-1", "name": "search_tool", "input_schema": {"type": "object", "properties": {}}} + rpc_error = {"jsonrpc": "2.0", "id": "1", "error": {"message": "Bad input"}} + + with patch("cforge.commands.resources.tools.get_console", return_value=mock_console): + with patch("cforge.commands.resources.tools.prompt_for_json_schema", return_value={}): + with patch("cforge.commands.resources.tools.make_authenticated_request", side_effect=[tool_response, rpc_error]): + with pytest.raises(typer.Exit): + tools_execute(tool_id="tool-1", data_file=None) + + def test_tools_execute_jsonrpc_error_non_dict(self, mock_console) -> None: + """Test tools execute handles JSON-RPC errors returned as non-dict values.""" + tool_response = {"id": 
"tool-1", "name": "search_tool", "input_schema": {"type": "object", "properties": {}}} + rpc_error = {"jsonrpc": "2.0", "id": "1", "error": "Something went wrong"} + + with patch("cforge.commands.resources.tools.get_console", return_value=mock_console): + with patch("cforge.commands.resources.tools.prompt_for_json_schema", return_value={}): + with patch("cforge.commands.resources.tools.make_authenticated_request", side_effect=[tool_response, rpc_error]): + with pytest.raises(typer.Exit): + tools_execute(tool_id="tool-1", data_file=None) + + def test_tools_execute_prints_raw_rpc_when_result_missing(self, mock_console) -> None: + """Test tools execute prints raw RPC payload when no result field is present.""" + tool_response = {"id": "tool-1", "name": "search_tool", "input_schema": {"type": "object", "properties": {}}} + rpc_response = {"jsonrpc": "2.0", "id": "1", "ok": True} + + with patch("cforge.commands.resources.tools.get_console", return_value=mock_console): + with patch("cforge.commands.resources.tools.prompt_for_json_schema", return_value={}): + with patch("cforge.commands.resources.tools.make_authenticated_request", side_effect=[tool_response, rpc_response]): + with patch("cforge.commands.resources.tools.print_json") as mock_print: + tools_execute(tool_id="tool-1", data_file=None) + mock_print.assert_called_once_with(rpc_response, "Tool Result") + def test_tools_get_error(self, mock_console) -> None: """Test tools get error handling.""" with patch("cforge.commands.resources.tools.get_console", return_value=mock_console): diff --git a/tests/commands/settings/test_login.py b/tests/commands/settings/test_login.py index aac6966..1aa83cb 100644 --- a/tests/commands/settings/test_login.py +++ b/tests/commands/settings/test_login.py @@ -8,8 +8,8 @@ """ # Standard -import tempfile from pathlib import Path +import tempfile from unittest.mock import Mock, patch # Third-Party @@ -19,7 +19,8 @@ # First-Party from cforge.commands.settings.login import login -from 
cforge.common import AuthenticationError, make_authenticated_request +from cforge.common.errors import AuthenticationError +from cforge.common.http import make_authenticated_request class TestLoginCommand: @@ -143,9 +144,10 @@ class TestLoginWithProfiles: def test_login_saves_to_profile_specific_token_file(self, mock_base_url, mock_console, mock_settings) -> None: """Test that login saves token to profile-specific file when profile is active.""" - from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store from datetime import datetime + from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store + mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = {"access_token": "profile_token_123"} @@ -179,9 +181,10 @@ def test_login_saves_to_profile_specific_token_file(self, mock_base_url, mock_co def test_login_with_multiple_profiles(self, mock_base_url, mock_console, mock_settings) -> None: """Test that different profiles can have different tokens.""" - from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store from datetime import datetime + from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store + profile_id1 = "profile-1" profile_id2 = "profile-2" diff --git a/tests/commands/settings/test_logout.py b/tests/commands/settings/test_logout.py index 1a11b2a..b38e860 100644 --- a/tests/commands/settings/test_logout.py +++ b/tests/commands/settings/test_logout.py @@ -8,8 +8,8 @@ """ # Standard -import tempfile from pathlib import Path +import tempfile from unittest.mock import patch # Third-Party @@ -18,7 +18,8 @@ # First-Party from cforge.commands.settings.login import login from cforge.commands.settings.logout import logout -from cforge.common import AuthenticationError, make_authenticated_request +from cforge.common.errors import AuthenticationError +from cforge.common.http import make_authenticated_request class TestLogoutCommand: @@ -92,9 +93,10 @@ class 
TestLogoutWithProfiles: def test_logout_removes_profile_specific_token(self, mock_console, mock_settings) -> None: """Test that logout removes profile-specific token file.""" - from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store from datetime import datetime + from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store + # Create and save an active profile profile_id = "test-profile-logout" profile = AuthProfile( @@ -128,9 +130,10 @@ def test_logout_removes_profile_specific_token(self, mock_console, mock_settings def test_logout_only_removes_active_profile_token(self, mock_console, mock_settings) -> None: """Test that logout only removes the active profile's token, not others.""" - from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store from datetime import datetime + from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store + profile_id1 = "profile-1" profile_id2 = "profile-2" diff --git a/tests/common/test_console.py b/tests/common/test_console.py new file mode 100644 index 0000000..eb5fe43 --- /dev/null +++ b/tests/common/test_console.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +"""Tests for cforge.common.console.""" + +# First-Party +from cforge.common.console import get_app, get_console + + +class TestSingletons: + """Tests for singleton getter functions.""" + + def test_get_console_returns_console(self) -> None: + """Test that get_console returns a Console instance.""" + console = get_console() + assert console is not None + # Should return same instance + assert get_console() is console + + def test_get_app_returns_typer_app(self) -> None: + """Test that get_app returns a Typer instance.""" + app = get_app() + assert app is not None + # Should return same instance + assert get_app() is app diff --git a/tests/common/test_errors.py b/tests/common/test_errors.py new file mode 100644 index 0000000..4884300 --- /dev/null +++ b/tests/common/test_errors.py @@ -0,0 +1,20 @@ +# 
-*- coding: utf-8 -*- +"""Tests for cforge.common.errors.""" + +# First-Party +from cforge.common.errors import AuthenticationError, CLIError + + +class TestErrors: + """Tests for custom error classes.""" + + def test_cli_error(self) -> None: + """Test CLIError exception.""" + error = CLIError("Test error") + assert str(error) == "Test error" + + def test_authentication_error(self) -> None: + """Test AuthenticationError exception.""" + error = AuthenticationError("Auth failed") + assert str(error) == "Auth failed" + assert isinstance(error, CLIError) diff --git a/tests/common/test_http.py b/tests/common/test_http.py new file mode 100644 index 0000000..41c4e91 --- /dev/null +++ b/tests/common/test_http.py @@ -0,0 +1,630 @@ +# -*- coding: utf-8 -*- +"""Tests for cforge.common.http.""" + +# Standard +from pathlib import Path +import stat +import tempfile +from unittest.mock import Mock, patch + +# Third-Party +import pytest +import requests + +# First-Party +from cforge.common.errors import AuthenticationError, CLIError +from cforge.common.http import get_auth_token, get_token_file, load_token, make_authenticated_request, save_token +from tests.conftest import mock_client_login + + +class TestTokenManagement: + """Tests for token management functions.""" + + def test_get_token_file(self, mock_settings) -> None: + """Test getting the token file path.""" + token_file = get_token_file() + assert isinstance(token_file, Path) + assert str(token_file).endswith("token") + assert token_file.parent == mock_settings.contextforge_home + + def test_get_token_file_with_active_profile(self, mock_settings) -> None: + """Test getting the token file path uses active profile when available.""" + from datetime import datetime + + from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store + + # Create and save an active profile + profile_id = "active-profile-456" + profile = AuthProfile( + id=profile_id, + name="Active Profile", + email="active@example.com", + 
apiUrl="https://api.example.com", + isActive=True, + createdAt=datetime.now(), + ) + store = ProfileStore( + profiles={profile_id: profile}, + activeProfileId=profile_id, + ) + save_profile_store(store) + + # get_token_file should use the active profile + token_file = get_token_file() + assert str(token_file).endswith(f"token.{profile_id}") + + def test_save_and_load_token(self) -> None: + """Test saving and loading a token.""" + test_token = "test_token_123" + + with tempfile.NamedTemporaryFile() as temp_token_file: + with patch("cforge.common.http.get_token_file", return_value=Path(temp_token_file.name)): + save_token(test_token) + loaded_token = load_token() + + assert loaded_token == test_token + + def test_save_and_load_token_with_active_profile(self, mock_settings) -> None: + """Test saving and loading a token with an active profile.""" + from datetime import datetime + + from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store + + test_token = "profile_token_456" + profile_id = "test-profile-789" + + # Create and save an active profile + profile = AuthProfile( + id=profile_id, + name="Test Profile", + email="test@example.com", + apiUrl="https://api.example.com", + isActive=True, + createdAt=datetime.now(), + ) + store = ProfileStore( + profiles={profile_id: profile}, + activeProfileId=profile_id, + ) + save_profile_store(store) + + # Save and load token - should use profile-specific file + save_token(test_token) + loaded_token = load_token() + + assert loaded_token == test_token + + # Verify it was saved to profile-specific file + token_file = mock_settings.contextforge_home / f"token.{profile_id}" + assert token_file.exists() + + def test_save_token_different_profiles(self, mock_settings) -> None: + """Test that different profiles have separate token files.""" + from datetime import datetime + + from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store + + token1 = "token_for_profile_1" + token2 = 
"token_for_profile_2" + profile_id1 = "profile-1" + profile_id2 = "profile-2" + + # Save token for profile 1 + profile1 = AuthProfile( + id=profile_id1, + name="Profile 1", + email="user1@example.com", + apiUrl="https://api1.example.com", + isActive=True, + createdAt=datetime.now(), + ) + store1 = ProfileStore( + profiles={profile_id1: profile1}, + activeProfileId=profile_id1, + ) + save_profile_store(store1) + save_token(token1) + + # Save token for profile 2 + profile2 = AuthProfile( + id=profile_id2, + name="Profile 2", + email="user2@example.com", + apiUrl="https://api2.example.com", + isActive=True, + createdAt=datetime.now(), + ) + store2 = ProfileStore( + profiles={profile_id2: profile2}, + activeProfileId=profile_id2, + ) + save_profile_store(store2) + save_token(token2) + + # Verify both tokens exist in separate files + token_file1 = mock_settings.contextforge_home / f"token.{profile_id1}" + token_file2 = mock_settings.contextforge_home / f"token.{profile_id2}" + + assert token_file1.exists() + assert token_file2.exists() + assert token_file1.read_text() == token1 + assert token_file2.read_text() == token2 + assert token1 != token2 + + def test_load_token_nonexistent(self, tmp_path: Path) -> None: + """Test loading a token when file doesn't exist.""" + nonexistent_file = tmp_path / "nonexistent" / "token" + + with patch("cforge.common.http.get_token_file", return_value=nonexistent_file): + token = load_token() + + assert token is None + + def test_load_token_nonexistent_profile(self, mock_settings) -> None: + """Test loading a token for a profile that doesn't have a token file.""" + from datetime import datetime + + from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store + + profile_id = "nonexistent-profile" + + # Create an active profile but don't create a token file + profile = AuthProfile( + id=profile_id, + name="Test Profile", + email="test@example.com", + apiUrl="https://api.example.com", + isActive=True, + 
createdAt=datetime.now(), + ) + store = ProfileStore( + profiles={profile_id: profile}, + activeProfileId=profile_id, + ) + save_profile_store(store) + + # Try to load token - should return None since file doesn't exist + token = load_token() + + assert token is None + + +class TestBaseUrl: + """Tests for get_base_url function.""" + + def test_get_base_url_with_active_profile(self, mock_settings) -> None: + """Test get_base_url returns profile's API URL when active profile exists.""" + from datetime import datetime + + from cforge.common.http import get_base_url + from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store + + # Create and save a profile + profile = AuthProfile( + id="profile-1", + name="Test Profile", + email="test@example.com", + apiUrl="https://custom-api.example.com", + isActive=True, + createdAt=datetime.now(), + ) + store = ProfileStore( + profiles={"profile-1": profile}, + activeProfileId="profile-1", + ) + save_profile_store(store) + + # Get base URL should return the profile's API URL + base_url = get_base_url() + assert base_url == "https://custom-api.example.com" + + def test_get_base_url_without_active_profile(self, mock_settings) -> None: + """Test get_base_url returns default URL when no active profile.""" + from cforge.common.http import get_base_url + + # No profile saved, should use settings + base_url = get_base_url() + assert base_url == f"http://{mock_settings.host}:{mock_settings.port}" + + +class TestAuthentication: + """Tests for authentication functions.""" + + def test_get_auth_token_from_env(self, mock_settings) -> None: + """Test getting auth token from environment variable.""" + # Create a new settings instance with token + mock_settings.mcpgateway_bearer_token = "env_token" + with patch("cforge.common.http.load_token", return_value=None): + token = get_auth_token() + + assert token == "env_token" + + def test_get_auth_token_from_file(self, mock_settings) -> None: + """Test getting auth token from file 
when env var not set.""" + # mock_settings already has mcpgateway_bearer_token=None + with patch("cforge.common.http.load_token", return_value="file_token"): + token = get_auth_token() + + assert token == "file_token" + + def test_get_auth_token_none(self, mock_settings) -> None: + """Test getting auth token when none available.""" + # mock_settings already has mcpgateway_bearer_token=None + with patch("cforge.common.http.load_token", return_value=None): + token = get_auth_token() + + assert token is None + + +class TestAutoLogin: + """Tests for automatic login functionality.""" + + def test_attempt_auto_login_no_profile(self, mock_settings): + """Test auto-login when no profile is active.""" + from cforge.common.http import attempt_auto_login + + token = attempt_auto_login() + assert token is None + + def test_attempt_auto_login_no_credentials(self, mock_settings): + """Test auto-login when credentials are not available.""" + from datetime import datetime + + from cforge.common.http import attempt_auto_login + from cforge.profile_utils import AuthProfile + + mock_profile = AuthProfile( + id="test-profile", + name="Test", + email="test@example.com", + apiUrl="http://localhost:4444", + isActive=True, + createdAt=datetime.now(), + ) + + with patch("cforge.common.http.get_active_profile", return_value=mock_profile): + with patch("cforge.common.http.load_profile_credentials", return_value=None): + token = attempt_auto_login() + assert token is None + + def test_attempt_auto_login_missing_email(self, mock_settings): + """Test auto-login when email is missing from credentials.""" + from datetime import datetime + + from cforge.common.http import attempt_auto_login + from cforge.profile_utils import AuthProfile + + mock_profile = AuthProfile( + id="test-profile", + name="Test", + email="test@example.com", + apiUrl="http://localhost:4444", + isActive=True, + createdAt=datetime.now(), + ) + + with patch("cforge.common.http.get_active_profile", return_value=mock_profile): + 
with patch("cforge.common.http.load_profile_credentials", return_value={"password": "test"}): + token = attempt_auto_login() + assert token is None + + def test_attempt_auto_login_missing_password(self, mock_settings): + """Test auto-login when password is missing from credentials.""" + from datetime import datetime + + from cforge.common.http import attempt_auto_login + from cforge.profile_utils import AuthProfile + + mock_profile = AuthProfile( + id="test-profile", + name="Test", + email="test@example.com", + apiUrl="http://localhost:4444", + isActive=True, + createdAt=datetime.now(), + ) + + with patch("cforge.common.http.get_active_profile", return_value=mock_profile): + with patch("cforge.common.http.load_profile_credentials", return_value={"email": "test@example.com"}): + token = attempt_auto_login() + assert token is None + + @patch("cforge.common.http.requests.post") + def test_attempt_auto_login_success(self, mock_post, mock_settings): + """Test successful auto-login.""" + from datetime import datetime + + from cforge.common.http import attempt_auto_login, load_token + from cforge.profile_utils import AuthProfile + + mock_profile = AuthProfile( + id="test-profile", + name="Test", + email="test@example.com", + apiUrl="http://localhost:4444", + isActive=True, + createdAt=datetime.now(), + ) + + # Mock successful login response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"access_token": "auto-login-token"} + mock_post.return_value = mock_response + + with patch("cforge.common.http.get_active_profile", return_value=mock_profile): + with patch("cforge.common.http.load_profile_credentials", return_value={"email": "test@example.com", "password": "test-pass"}): + token = attempt_auto_login() + assert token == "auto-login-token" + + # Verify token was saved + saved_token = load_token() + assert saved_token == "auto-login-token" + + @patch("cforge.common.http.requests.post") + def 
test_attempt_auto_login_failed_login(self, mock_post, mock_settings): + """Test auto-login when login fails.""" + from datetime import datetime + + from cforge.common.http import attempt_auto_login + from cforge.profile_utils import AuthProfile + + mock_profile = AuthProfile( + id="test-profile", + name="Test", + email="test@example.com", + apiUrl="http://localhost:4444", + isActive=True, + createdAt=datetime.now(), + ) + + # Mock failed login response + mock_response = Mock() + mock_response.status_code = 401 + mock_post.return_value = mock_response + + with patch("cforge.common.http.get_active_profile", return_value=mock_profile): + with patch("cforge.common.http.load_profile_credentials", return_value={"email": "test@example.com", "password": "wrong-pass"}): + token = attempt_auto_login() + assert token is None + + @patch("cforge.common.http.requests.post") + def test_attempt_auto_login_no_token_in_response(self, mock_post, mock_settings): + """Test auto-login when response doesn't contain token.""" + from datetime import datetime + + from cforge.common.http import attempt_auto_login + from cforge.profile_utils import AuthProfile + + mock_profile = AuthProfile( + id="test-profile", + name="Test", + email="test@example.com", + apiUrl="http://localhost:4444", + isActive=True, + createdAt=datetime.now(), + ) + + # Mock response without token + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {} + mock_post.return_value = mock_response + + with patch("cforge.common.http.get_active_profile", return_value=mock_profile): + with patch("cforge.common.http.load_profile_credentials", return_value={"email": "test@example.com", "password": "test-pass"}): + token = attempt_auto_login() + assert token is None + + @patch("cforge.common.http.requests.post") + def test_attempt_auto_login_request_exception(self, mock_post, mock_settings): + """Test auto-login when request raises exception.""" + from datetime import datetime + + from 
cforge.common.http import attempt_auto_login + from cforge.profile_utils import AuthProfile + + mock_profile = AuthProfile( + id="test-profile", + name="Test", + email="test@example.com", + apiUrl="http://localhost:4444", + isActive=True, + createdAt=datetime.now(), + ) + + # Mock request exception + mock_post.side_effect = Exception("Connection error") + + with patch("cforge.common.http.get_active_profile", return_value=mock_profile): + with patch("cforge.common.http.load_profile_credentials", return_value={"email": "test@example.com", "password": "test-pass"}): + token = attempt_auto_login() + assert token is None + + def test_get_auth_token_with_auto_login(self, mock_settings): + """Test that get_auth_token attempts auto-login when no token is available.""" + from cforge.common.http import get_auth_token + + # Mock no env token and no file token, but successful auto-login + with patch("cforge.common.http.load_token", return_value=None): + with patch("cforge.common.http.attempt_auto_login", return_value="auto-token"): + token = get_auth_token() + assert token == "auto-token" + + +class TestMakeAuthenticatedRequest: + """Tests for make_authenticated_request function using a server mock.""" + + def test_request_no_auth_raises_error_when_server_requires_it(self, mock_settings) -> None: + """Test that request without auth raises AuthenticationError when server requires it.""" + # Ensure no token is available + with patch("cforge.common.http.load_token", return_value=None): + # Mock a 401 response from server (authentication required) + mock_response = Mock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + + with patch("cforge.common.http.requests.request", return_value=mock_response): + with pytest.raises(AuthenticationError) as exc_info: + make_authenticated_request("GET", "/test") + + assert "Authentication required but not configured" in str(exc_info.value) + + def test_request_without_auth_succeeds_on_unauthenticated_server(self, 
mock_settings) -> None: + """Test that request without auth succeeds when server doesn't require it.""" + # Ensure no token is available + with patch("cforge.common.http.load_token", return_value=None): + # Mock a successful response from server (no auth required) + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"result": "success"} + + with patch("cforge.common.http.requests.request", return_value=mock_response) as mock_req: + result = make_authenticated_request("GET", "/test") + + # Verify the request was made without Authorization header + call_args = mock_req.call_args + assert "Authorization" not in call_args[1]["headers"] + assert result == {"result": "success"} + + def test_request_with_bearer_token(self, mock_client, mock_settings) -> None: + """Test successful request with Bearer token.""" + mock_client.reset_mock() + with mock_client_login(mock_client): + mock_req = mock_client.request + result = make_authenticated_request("GET", "/tools") + + # Verify request was made correctly + mock_req.assert_called_once() + call_args = mock_req.call_args + assert call_args[1]["method"] == "GET" + assert call_args[1]["url"] == f"http://{mock_client.settings.host}:{mock_client.settings.port}/tools" + assert call_args[1]["headers"]["Authorization"] == f"Bearer {mock_client.settings.mcpgateway_bearer_token}" + assert call_args[1]["headers"]["Content-Type"] == "application/json" + assert isinstance(result, list) + + def test_request_with_basic_auth(self, mock_settings) -> None: + """Test request with Basic auth token.""" + # Set up settings with Basic auth token + mock_settings.mcpgateway_bearer_token = "Basic dGVzdDp0ZXN0" + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"result": "success"} + + with patch("cforge.common.http.requests.request", return_value=mock_response) as mock_req: + make_authenticated_request("POST", "/api/test", json_data={"data": "value"}) + + # Verify 
Basic auth is passed as-is + call_args = mock_req.call_args + assert call_args[1]["headers"]["Authorization"] == "Basic dGVzdDp0ZXN0" + + def test_request_api_error(self, mock_settings) -> None: + """Test that API errors are properly raised.""" + mock_settings.mcpgateway_bearer_token = "test_token" + + mock_response = Mock() + mock_response.status_code = 404 + mock_response.text = "Not found" + + with patch("cforge.common.http.requests.request", return_value=mock_response): + with pytest.raises(CLIError) as exc_info: + make_authenticated_request("GET", "/api/missing") + + assert "API request failed (404)" in str(exc_info.value) + assert "Not found" in str(exc_info.value) + + def test_request_connection_error(self, mock_settings) -> None: + """Test that connection errors are properly raised.""" + mock_settings.mcpgateway_bearer_token = "test_token" + + with patch("cforge.common.http.requests.request", side_effect=requests.ConnectionError("Connection refused")): + with pytest.raises(CLIError) as exc_info: + make_authenticated_request("GET", "/api/test") + + assert "Failed to connect to gateway" in str(exc_info.value) + assert "Connection refused" in str(exc_info.value) + + +class TestTokenFilePermissions: + """Tests for token file permission handling.""" + + def test_save_token_creates_parent_dirs(self) -> None: + """Test that save_token creates parent directories.""" + with tempfile.TemporaryDirectory() as temp_dir: + token_path = Path(temp_dir) / "nested" / "dirs" / "token" + + with patch("cforge.common.http.get_token_file", return_value=token_path): + save_token("test_token") + + assert token_path.exists() + assert token_path.read_text() == "test_token" + + def test_save_token_sets_permissions(self) -> None: + """Test that save_token sets restrictive permissions.""" + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + token_path = Path(temp_file.name) + + try: + with patch("cforge.common.http.get_token_file", return_value=token_path): + 
save_token("test_token") + + # Check permissions are 0o600 (read/write for owner only) + file_stat = token_path.stat() + file_mode = stat.S_IMODE(file_stat.st_mode) + assert file_mode == 0o600 + finally: + token_path.unlink(missing_ok=True) + + +class TestMakeAuthenticatedRequestIntegration: + """Integration tests for make_authenticated_request with real server. + + These tests use the session_settings fixture which provides a real + running mcpgateway server and properly configured settings. This validates + that the client code actually works with the server, not just that it + makes the right calls. + """ + + def test_request_with_bearer_token_to_health_endpoint(self, mock_client) -> None: + """Test successful authenticated request to /health endpoint.""" + + # Make a request to the health endpoint (no auth required) + make_authenticated_request("GET", "/health") + + # Make a request to an authorized endpoint before login + with pytest.raises(CLIError): + make_authenticated_request("GET", "/tools") + + # Log in and try again + with mock_client_login(mock_client): + + # Make a real HTTP request to the session server's health endpoint + result = make_authenticated_request("GET", "/tools") + + # The tools endpoint should return a successful response + assert result is not None + assert isinstance(result, list) + + def test_request_to_nonexistent_endpoint_raises_error(self, authorized_mock_client) -> None: + """Test that requesting a nonexistent endpoint raises CLIError.""" + # Try to request an endpoint that doesn't exist + with pytest.raises(CLIError) as exc_info: + make_authenticated_request("GET", "/api/this/endpoint/does/not/exist") + + # Should get a 404 error + assert "404" in str(exc_info.value) or "not found" in str(exc_info.value).lower() + + def test_request_with_params_and_json_data(self, authorized_mock_client) -> None: + """Test request with query parameters. 
+ + This test verifies that parameters are correctly passed through + to the server in a real HTTP request. + """ + # Test that we can make requests with params + # The health endpoint may not use params, but we can verify the request succeeds + result = make_authenticated_request("GET", "/health", params={"test": "value"}) + + # Should still get a valid response even with unused params + assert result is not None + assert isinstance(result, dict) diff --git a/tests/common/test_prompting.py b/tests/common/test_prompting.py new file mode 100644 index 0000000..216fdb8 --- /dev/null +++ b/tests/common/test_prompting.py @@ -0,0 +1,1776 @@ +# -*- coding: utf-8 -*- +"""Tests for cforge.common.prompting.""" + +# Standard +from datetime import datetime +from typing import Annotated, Any, Dict, List, Optional +from unittest.mock import patch + +# Third-Party +from pydantic import BaseModel, Field +import pytest + +# First-Party +from cforge.common.errors import CLIError +from cforge.common.prompting import ( + _build_prompt_text, + _infer_schema_type, + _INT_SENTINEL_DEFAULT, + prompt_for_json_schema, + prompt_for_schema, + _resolve_effective_schema, + _resolve_ref_schema, + _resolve_schema_type, + _schema_contains_ref, + _strip_schema_internal_properties, +) + + +class TestBuildPromptText: + """Tests for internal prompt text formatting helpers.""" + + def test_build_prompt_text_falls_back_for_unserializable_default(self) -> None: + """Unserializable defaults should still render via string conversion.""" + + class UnserializableDefault: + def __str__(self) -> str: # noqa: D105 - local test helper + return "" + + result = _build_prompt_text( + field_name="field", + description=None, + default=UnserializableDefault(), + default_is_set=True, + is_required=True, + include_falsy_default=True, + ) + assert "" in result + + +class TestPromptForSchema: + """Tests for prompt_for_schema function.""" + + def test_prompt_with_prefilled_values(self, mock_console) -> None: + """Test that 
prefilled values are used and not prompted.""" + + class TestSchema(BaseModel): + name: str + description: str + + prefilled = {"name": "test_name", "description": "test_desc"} + + result = prompt_for_schema(TestSchema, prefilled=prefilled) + + # Should return prefilled values without prompting + assert result == prefilled + # Console should show the prefilled values + assert mock_console.print.call_count >= 3 # Header + 2 fields + + def test_prompt_with_prefilled_datetime_and_none(self, mock_console) -> None: + """Prefilled non-string values should not be rejected by prompt schema validation.""" + + class TestSchema(BaseModel): + name: str + created_at: datetime + last_used: Optional[datetime] = None + + prefilled = {"name": "test", "created_at": datetime.now(), "last_used": None} + + result = prompt_for_schema(TestSchema, prefilled=prefilled) + + assert result == prefilled + + def test_prompt_skips_internal_fields(self, mock_console) -> None: + """Test that internal fields are skipped.""" + + class TestSchema(BaseModel): + name: str + model_config: dict = {} # Should be skipped + auth_value: str = "" # Should be skipped + + prefilled = {"name": "test"} + + result = prompt_for_schema(TestSchema, prefilled=prefilled) + + # Should only have the name field + assert "name" in result + assert "model_config" not in result + assert "auth_value" not in result + + def test_prompt_with_string_field(self, mock_console) -> None: + """Test prompting for string fields.""" + + class TestSchema(BaseModel): + name: str = Field(description="The name") + + with patch("typer.prompt", return_value="user_input"): + result = prompt_for_schema(TestSchema) + + assert result["name"] == "user_input" + + def test_prompt_with_optional_field(self, mock_console) -> None: + """Test prompting for optional fields.""" + + class TestSchema(BaseModel): + required_field: str + optional_field: Optional[str] = None + + with patch("typer.prompt", side_effect=["required_value", ""]): + result = 
prompt_for_schema(TestSchema) + + assert result["required_field"] == "required_value" + # Optional field with empty input should not be in result + assert "optional_field" not in result or result["optional_field"] == "" + + def test_prompt_with_bool_field(self, mock_console) -> None: + """Test prompting for boolean fields.""" + + class TestSchema(BaseModel): + enabled: bool + + with patch("typer.confirm", return_value=True): + with patch("typer.prompt", return_value=True): + result = prompt_for_schema(TestSchema) + + assert result["enabled"] is True + + def test_prompt_with_optional_bool_field_declined(self, mock_console) -> None: + """Test prompting for optional boolean field that is declined.""" + + class TestSchema(BaseModel): + enabled: Optional[bool] = None + + # First confirm returns False (don't include field) + with patch("typer.confirm", return_value=False): + result = prompt_for_schema(TestSchema) + + # Field should not be in result when declined + assert "enabled" not in result + + def test_prompt_with_int_field(self, mock_console) -> None: + """Test prompting for integer fields.""" + + class TestSchema(BaseModel): + count: int + + with patch("typer.prompt", return_value=42): + result = prompt_for_schema(TestSchema) + + assert result["count"] == 42 + + def test_prompt_with_int_field_empty_input(self, mock_console) -> None: + """Test prompting for optional integer field with empty input.""" + + class TestSchema(BaseModel): + count: Optional[int] = None + + # Return sentinel to simulate skipping optional field + with patch("typer.prompt", return_value=_INT_SENTINEL_DEFAULT): + result = prompt_for_schema(TestSchema) + + # Field should not be in result when empty + assert "count" not in result + + def test_prompt_with_list_field(self, mock_console) -> None: + """Test prompting for list fields.""" + + class TestSchema(BaseModel): + tags: List[str] + + with patch("typer.prompt", return_value="tag1, tag2, tag3"): + result = prompt_for_schema(TestSchema) + + 
assert result["tags"] == ["tag1", "tag2", "tag3"] + + def test_prompt_with_list_field_empty(self, mock_console) -> None: + """Test prompting for list fields with empty input.""" + + class TestSchema(BaseModel): + tags: Optional[List[str]] = None + + with patch("typer.prompt", return_value=""): + result = prompt_for_schema(TestSchema) + + # Empty input for list should not add the field + assert "tags" not in result or result.get("tags") is None + + def test_prompt_with_optional_nested_object_declined(self, mock_console) -> None: + """Optional nested objects should be omitted when declined.""" + + class SubSchema(BaseModel): + url: str + + class TestSchema(BaseModel): + config: Optional[SubSchema] = None + + with patch("typer.confirm", return_value=False): + result = prompt_for_schema(TestSchema) + + assert "config" not in result + + def test_prompt_with_optional_nested_object_accepted(self, mock_console) -> None: + """Optional nested objects should prompt when accepted.""" + + class SubSchema(BaseModel): + url: str + + class TestSchema(BaseModel): + config: Optional[SubSchema] = None + + with patch("typer.confirm", return_value=True), patch("typer.prompt", return_value="https://example.com"): + result = prompt_for_schema(TestSchema) + + assert result == {"config": {"url": "https://example.com"}} + + def test_prompt_dict_str_str(self, mock_console) -> None: + """Test prompting for a string to string dict""" + + class TestSchema(BaseModel): + key: Dict[str, str] + + with patch("typer.confirm", side_effect=["y", "y", ""]), patch("typer.prompt", side_effect=["k1", "v1", "k2", "v2"]): + result = prompt_for_schema(TestSchema) + + # Empty input for list should not add the field + assert result == { + "key": {"k1": "v1", "k2": "v2"}, + } + + def test_prompt_with_nested_dicts(self, mock_console) -> None: + """Test prompting for a nested dict with dict values""" + + class SubSchema(BaseModel): + num: int + + class TestSchema(BaseModel): + key: Dict[str, Any] + sub: SubSchema 
+ sub_dict: Dict[str, SubSchema] + + with patch("typer.confirm", side_effect=["y", "y", "", "y", ""]), patch("typer.prompt", side_effect=["k1", '{"foo": 1}', "k2", "[1, 2, 3]", 42, "a-num", 123]): + result = prompt_for_schema(TestSchema) + + # Empty input for list should not add the field + assert result == { + "key": {"k1": {"foo": 1}, "k2": [1, 2, 3]}, + "sub": {"num": 42}, + "sub_dict": {"a-num": {"num": 123}}, + } + + def test_prompt_list_of_sub_models(self, mock_console) -> None: + """Test prompting for a list of sub pydantic models""" + + class SubSchema(BaseModel): + num: int + + class TestSchema(BaseModel): + nums: List[SubSchema] + + with patch("typer.confirm", side_effect=["y", "y", ""]), patch("typer.prompt", side_effect=[1, 2]): + result = prompt_for_schema(TestSchema) + + # Empty input for list should not add the field + assert result == {"nums": [{"num": 1}, {"num": 2}]} + + def test_prompt_with_default(self, mock_console) -> None: + """Test prompting with defaults and make sure prompt string added.""" + + class TestSchema(BaseModel): + name: str = "foobar" + some_val: int = 42 + + with patch("typer.prompt", side_effect=["", 42]) as prompt_mock: + prompt_for_schema(TestSchema) + assert prompt_mock.call_count == 2 + assert prompt_mock.call_args_list[0][1]["default"] == "foobar" + assert prompt_mock.call_args_list[1][1]["default"] == 42 + assert any("foobar" in call[0][0] for call in mock_console.print.call_args_list) + assert any("42" in call[0][0] for call in mock_console.print.call_args_list) + + def test_prompt_missing_required_string(self, mock_console) -> None: + """Test that an exception is raised if a required string is unset.""" + + class TestSchema(BaseModel): + foo: str + + with patch("typer.prompt", return_value=""): + with pytest.raises(CLIError): + prompt_for_schema(TestSchema) + + +class TestPromptForJsonSchema: + """Tests for prompt_for_json_schema function.""" + + def test_prompt_for_json_schema_required_only_with_prefilled(self, 
mock_console) -> None: + """Test prompting only missing required fields when prefilled data exists.""" + schema = { + "type": "object", + "properties": { + "query": {"type": "string"}, + "limit": {"type": "integer"}, + }, + "required": ["query"], + } + prefilled = {"limit": 10} + + with patch("typer.prompt", return_value="search term") as mock_prompt: + result = prompt_for_json_schema(schema, prefilled=prefilled, prompt_optional=False) + + assert result["query"] == "search term" + assert result["limit"] == 10 + assert mock_prompt.call_count == 1 + + def test_prompt_for_json_schema_skips_optional_fields_when_required_only(self, mock_console) -> None: + """Test optional fields are skipped entirely in required-only mode.""" + schema = { + "type": "object", + "properties": { + "query": {"type": "string"}, + }, + } + + with patch("typer.prompt") as mock_prompt: + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {} + mock_prompt.assert_not_called() + + def test_prompt_for_json_schema_prompts_optional_fields(self, mock_console) -> None: + """Test optional fields are prompted in full interactive mode.""" + schema = { + "type": "object", + "properties": { + "query": {"type": "string"}, + }, + } + + with patch("typer.prompt", return_value="search term"): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result["query"] == "search term" + + def test_prompt_for_json_schema_prefilled_nested_object_prompts_missing_required(self, mock_console) -> None: + """Test nested required fields are prompted when parent object is prefilled.""" + schema = { + "type": "object", + "properties": { + "config": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "timeout": {"type": "integer"}, + }, + "required": ["name"], + } + }, + "required": ["config"], + } + prefilled = {"config": {"timeout": 30}} + + with patch("typer.prompt", return_value="tool-name") as mock_prompt: + result = prompt_for_json_schema(schema, 
prefilled=prefilled, prompt_optional=False) + + assert result == {"config": {"timeout": 30, "name": "tool-name"}} + assert mock_prompt.call_count == 1 + + def test_prompt_for_json_schema_requires_object_schema(self, mock_console) -> None: + """Test non-object root schema raises a CLIError.""" + schema = {"type": "string"} + + with pytest.raises(CLIError): + prompt_for_json_schema(schema) + + def test_prompt_for_json_schema_resolves_ref_object(self, mock_console) -> None: + """Test object fields referenced via $ref are prompted as objects.""" + schema = { + "type": "object", + "$defs": { + "NestedArgs": { + "type": "object", + "properties": { + "name": {"type": "string"}, + }, + "required": ["name"], + } + }, + "properties": { + "config": {"$ref": "#/$defs/NestedArgs"}, + }, + "required": ["config"], + } + + with patch("typer.prompt", return_value="nested-name") as mock_prompt: + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"config": {"name": "nested-name"}} + assert mock_prompt.call_count == 1 + + def test_prompt_for_json_schema_resolves_ref_array(self, mock_console) -> None: + """Test array fields referenced via $ref are prompted as arrays.""" + schema = { + "type": "object", + "$defs": { + "TagList": { + "type": "array", + "items": {"type": "string"}, + } + }, + "properties": { + "tags": {"$ref": "#/$defs/TagList"}, + }, + "required": ["tags"], + } + + with patch("typer.prompt", return_value="one, two"): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"tags": ["one", "two"]} + + def test_prompt_for_json_schema_missing_ref_raises(self, mock_console) -> None: + """Test missing local $ref path raises a CLIError.""" + schema = { + "type": "object", + "$defs": {}, + "properties": { + "config": {"$ref": "#/$defs/DoesNotExist"}, + }, + "required": ["config"], + } + + with pytest.raises(CLIError, match="Schema reference not found"): + prompt_for_json_schema(schema, prompt_optional=False) + + 
def test_prompt_for_json_schema_external_ref_raises(self, mock_console) -> None: + """Test non-local $ref values are rejected.""" + schema = { + "type": "object", + "properties": { + "config": {"$ref": "https://example.com/schema.json#/$defs/Config"}, + }, + "required": ["config"], + } + + with pytest.raises(CLIError, match="Only local schema references are supported"): + prompt_for_json_schema(schema, prompt_optional=False) + + def test_prompt_for_json_schema_cyclic_ref_raises(self, mock_console) -> None: + """Test cyclic $ref graphs raise a CLIError.""" + schema = { + "type": "object", + "$defs": { + "Node": {"$ref": "#/$defs/Node"}, + }, + "properties": { + "node": {"$ref": "#/$defs/Node"}, + }, + "required": ["node"], + } + + with pytest.raises(CLIError, match="Cyclic schema reference detected"): + prompt_for_json_schema(schema, prompt_optional=False) + + def test_prompt_for_json_schema_non_object_input_schema_raises(self, mock_console) -> None: + """Test non-dictionary schemas are rejected.""" + with pytest.raises(CLIError, match="Input schema must be a JSON object"): + prompt_for_json_schema(["not", "an", "object"]) # type: ignore[arg-type] + + def test_prompt_for_json_schema_non_object_prefilled_raises(self, mock_console) -> None: + """Test non-dictionary prefilled payloads are rejected.""" + schema = { + "type": "object", + "properties": {"query": {"type": "string"}}, + } + + with pytest.raises(CLIError, match="Prefilled input must be a JSON object"): + prompt_for_json_schema(schema, prefilled=["bad"]) # type: ignore[arg-type] + + def test_prompt_for_json_schema_ref_resolving_to_scalar_raises(self, mock_console) -> None: + """Test $ref pointers resolving to scalar nodes are rejected.""" + schema = { + "type": "object", + "$defs": { + "Config": { + "type": "object", + "properties": {"kind": {"type": "string"}}, + } + }, + "properties": { + "config": {"$ref": "#/$defs/Config/type"}, + }, + "required": ["config"], + } + + with pytest.raises(CLIError, 
match="does not resolve to an object schema"): + prompt_for_json_schema(schema, prompt_optional=False) + + def test_prompt_for_json_schema_ref_invalid_array_index_raises(self, mock_console) -> None: + """Test invalid array indexes in local $ref pointers are rejected.""" + schema = { + "type": "object", + "x-variants": [{"type": "string"}], + "properties": { + "value": {"$ref": "#/x-variants/not-an-index"}, + }, + "required": ["value"], + } + + with pytest.raises(CLIError, match="Invalid array index in schema reference"): + prompt_for_json_schema(schema, prompt_optional=False) + + def test_prompt_for_json_schema_handles_nullable_integer_type_lists(self, mock_console) -> None: + """Test `type` lists like [null, integer] prompt using the concrete type.""" + schema = { + "type": "object", + "properties": { + "limit": {"type": ["null", "integer"]}, + }, + "required": ["limit"], + } + + with patch("typer.prompt", return_value=7): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"limit": 7} + + def test_prompt_for_json_schema_handles_nullable_integer_any_of(self, mock_console) -> None: + """Test anyOf nullable integers prompt using integer type instead of string fallback.""" + schema = { + "type": "object", + "properties": { + "limit": { + "anyOf": [{"type": "null"}, {"type": "integer"}], + }, + }, + "required": ["limit"], + } + + with patch("typer.prompt", return_value=7) as mock_prompt: + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"limit": 7} + assert mock_prompt.call_args.kwargs.get("type") is int + + def test_prompt_for_json_schema_handles_one_of_object_with_null(self, mock_console) -> None: + """Test oneOf object/null schemas prompt nested object fields.""" + schema = { + "type": "object", + "properties": { + "config": { + "oneOf": [ + { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + }, + {"type": "null"}, + ] + }, + }, + "required": 
["config"], + } + + with patch("typer.prompt", return_value="alpha"): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"config": {"name": "alpha"}} + + def test_prompt_for_json_schema_handles_any_of_array_with_null(self, mock_console) -> None: + """Test anyOf array/null schemas are prompted as arrays.""" + schema = { + "type": "object", + "properties": { + "tags": { + "anyOf": [ + {"type": "null"}, + {"type": "array", "items": {"type": "string"}}, + ] + }, + }, + "required": ["tags"], + } + + with patch("typer.confirm", side_effect=[True, False]), patch("typer.prompt", return_value="tag-one"): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"tags": ["tag-one"]} + + def test_prompt_for_json_schema_handles_one_of_integer_or_string_string_value(self, mock_console) -> None: + """Test oneOf with multiple non-null variants accepts valid string input.""" + schema = { + "type": "object", + "properties": { + "value": { + "oneOf": [{"type": "integer"}, {"type": "string"}], + }, + }, + "required": ["value"], + } + + with patch("typer.prompt", return_value="alpha"): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"value": "alpha"} + + def test_prompt_for_json_schema_handles_one_of_integer_or_string_integer_value(self, mock_console) -> None: + """Test oneOf with multiple non-null variants accepts parsed integer input.""" + schema = { + "type": "object", + "properties": { + "value": { + "oneOf": [{"type": "integer"}, {"type": "string"}], + }, + }, + "required": ["value"], + } + + with patch("typer.prompt", return_value="7"): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"value": 7} + + def test_prompt_for_json_schema_handles_type_array_integer_or_string_string_value(self, mock_console) -> None: + """Test type-array unions accept valid string input.""" + schema = { + "type": "object", + "properties": { + "value": { + "type": 
["integer", "string"], + }, + }, + "required": ["value"], + } + + with patch("typer.prompt", return_value="alpha"): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"value": "alpha"} + + def test_prompt_for_json_schema_handles_type_array_integer_or_string_integer_value(self, mock_console) -> None: + """Test type-array unions accept parsed integer input.""" + schema = { + "type": "object", + "properties": { + "value": { + "type": ["integer", "string"], + }, + }, + "required": ["value"], + } + + with patch("typer.prompt", return_value="12"): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"value": 12} + + def test_prompt_for_json_schema_multi_variant_union_rejects_unmatched_type(self, mock_console) -> None: + """Test union prompts reject values that do not match any variant types.""" + schema = { + "type": "object", + "properties": { + "value": { + "oneOf": [{"type": "integer"}, {"type": "boolean"}], + }, + }, + "required": ["value"], + } + + with patch("typer.prompt", return_value="alpha"): + with pytest.raises(CLIError, match="Field 'value' is invalid"): + prompt_for_json_schema(schema, prompt_optional=False) + + def test_prompt_for_json_schema_object_union_requires_branch_fields(self, mock_console) -> None: + """Test object unions reject payloads that satisfy no branch required fields.""" + schema = { + "type": "object", + "properties": { + "value": { + "oneOf": [ + {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}, + {"type": "object", "properties": {"id": {"type": "integer"}}, "required": ["id"]}, + ] + }, + }, + "required": ["value"], + } + + with patch("typer.prompt", return_value="{}"): + with pytest.raises(CLIError, match="Field 'value' is invalid"): + prompt_for_json_schema(schema, prompt_optional=False) + + def test_prompt_for_json_schema_object_union_accepts_matching_branch(self, mock_console) -> None: + """Test object unions accept a payload 
matching one branch schema.""" + schema = { + "type": "object", + "properties": { + "value": { + "oneOf": [ + {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}, + {"type": "object", "properties": {"id": {"type": "integer"}}, "required": ["id"]}, + ] + }, + }, + "required": ["value"], + } + + with patch("typer.prompt", return_value='{"name":"alpha"}'): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"value": {"name": "alpha"}} + + def test_prompt_for_json_schema_discriminated_union_const_accepts_matching_branch(self, mock_console) -> None: + """Test oneOf discriminator branches using const validate correctly.""" + schema = { + "type": "object", + "properties": { + "value": { + "oneOf": [ + { + "type": "object", + "properties": { + "kind": {"const": "a"}, + "count": {"type": "integer"}, + }, + "required": ["kind", "count"], + }, + { + "type": "object", + "properties": { + "kind": {"const": "b"}, + "name": {"type": "string"}, + }, + "required": ["kind", "name"], + }, + ] + }, + }, + "required": ["value"], + } + + with patch("typer.prompt", return_value='{"kind":"a","count":1}'): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"value": {"kind": "a", "count": 1}} + + def test_prompt_for_json_schema_discriminated_union_const_rejects_invalid_discriminator(self, mock_console) -> None: + """Test oneOf discriminator branches reject unmatched const values.""" + schema = { + "type": "object", + "properties": { + "value": { + "oneOf": [ + { + "type": "object", + "properties": { + "kind": {"const": "a"}, + "count": {"type": "integer"}, + }, + "required": ["kind", "count"], + }, + { + "type": "object", + "properties": { + "kind": {"const": "b"}, + "name": {"type": "string"}, + }, + "required": ["kind", "name"], + }, + ] + }, + }, + "required": ["value"], + } + + with patch("typer.prompt", return_value='{"kind":"c","count":1}'): + with pytest.raises(CLIError, 
match="Field 'value' is invalid"): + prompt_for_json_schema(schema, prompt_optional=False) + + def test_prompt_for_json_schema_union_optional_blank_skips_field(self, mock_console) -> None: + """Test optional union fields are skipped when blank input is provided.""" + schema = { + "type": "object", + "properties": { + "value": { + "oneOf": [{"type": "integer"}, {"type": "string"}], + "default": {"nested": 1}, + }, + }, + } + + with patch("typer.prompt", return_value=""): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {} + + def test_prompt_for_json_schema_union_required_blank_raises(self, mock_console) -> None: + """Test required union fields reject blank input.""" + schema = { + "type": "object", + "properties": { + "value": { + "oneOf": [{"type": "integer"}, {"type": "string"}], + "default": {"nested": 1}, + }, + }, + "required": ["value"], + } + + with patch("typer.prompt", return_value=""): + with pytest.raises(CLIError, match="Field 'value' is required"): + prompt_for_json_schema(schema, prompt_optional=False) + + def test_resolve_effective_schema_preserves_multi_variant_union(self) -> None: + """Test effective schema resolver does not collapse multi-variant unions.""" + schema = { + "oneOf": [{"type": "integer"}, {"type": "string"}], + "description": "multi-type", + } + + result = _resolve_effective_schema({}, schema) + + assert "oneOf" in result + assert result["description"] == "multi-type" + + def test_resolve_schema_type_type_list_multiple_non_null_returns_union(self) -> None: + """Test type lists with multiple non-null entries resolve to union.""" + schema = {"type": ["integer", "string", "null"]} + assert _resolve_schema_type({}, schema) == "union" + + def test_resolve_effective_schema_recursive_any_of_ref_does_not_raise(self) -> None: + """Test recursive anyOf refs are handled without recursion errors.""" + schema = { + "type": "object", + "$defs": { + "Node": { + "type": "object", + "properties": { + "next": { + 
"anyOf": [{"type": "null"}, {"$ref": "#/$defs/Node"}], + } + }, + } + }, + "properties": {"node": {"$ref": "#/$defs/Node"}}, + "required": ["node"], + } + next_schema = schema["$defs"]["Node"]["properties"]["next"] + + result = _resolve_effective_schema(schema, next_schema) + assert isinstance(result, dict) + + def test_resolve_effective_schema_preserves_nullable_recursive_ref_branch(self) -> None: + """Test nullable combinator is not collapsed when non-null branch is recursive.""" + schema = { + "type": "object", + "$defs": { + "Node": { + "type": "object", + "properties": { + "next": { + "anyOf": [{"type": "null"}, {"$ref": "#/$defs/Node"}], + } + }, + "required": ["next"], + } + }, + "properties": { + "next": { + "anyOf": [{"type": "null"}, {"$ref": "#/$defs/Node"}], + } + }, + "required": ["next"], + } + + next_schema = schema["properties"]["next"] + result = _resolve_effective_schema(schema, next_schema) + + assert "anyOf" in result + assert isinstance(result["anyOf"], list) + + def test_resolve_schema_type_recursive_nullable_ref_returns_union(self) -> None: + """Test recursive nullable refs resolve to union type for prompting.""" + schema = { + "type": "object", + "$defs": { + "Node": { + "type": "object", + "properties": { + "next": { + "anyOf": [{"type": "null"}, {"$ref": "#/$defs/Node"}], + } + }, + "required": ["next"], + } + }, + "properties": { + "next": { + "anyOf": [{"type": "null"}, {"$ref": "#/$defs/Node"}], + } + }, + "required": ["next"], + } + + next_schema = schema["properties"]["next"] + assert _resolve_schema_type(schema, next_schema) == "union" + + def test_resolve_schema_type_non_recursive_nullable_ref_returns_inner_type(self) -> None: + """Test nullable non-recursive refs resolve to their concrete type.""" + schema = { + "type": "object", + "$defs": { + "Value": {"type": "integer"}, + }, + "properties": { + "value": {"anyOf": [{"type": "null"}, {"$ref": "#/$defs/Value"}]}, + }, + "required": ["value"], + } + + value_schema = 
schema["properties"]["value"] + assert _resolve_schema_type(schema, value_schema) == "integer" + + def test_prompt_for_json_schema_recursive_nullable_required_field_accepts_null(self, mock_console) -> None: + """Test recursive nullable required fields prompt as union and accept explicit null.""" + schema = { + "type": "object", + "$defs": { + "Node": { + "type": "object", + "properties": { + "next": { + "anyOf": [{"type": "null"}, {"$ref": "#/$defs/Node"}], + } + }, + "required": ["next"], + } + }, + "properties": { + "next": { + "anyOf": [{"type": "null"}, {"$ref": "#/$defs/Node"}], + } + }, + "required": ["next"], + } + + with patch("typer.prompt", return_value="null"): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"next": None} + + def test_resolve_effective_schema_non_dict_input_returns_empty_schema(self) -> None: + """Test effective schema resolver returns empty schema for non-dict input.""" + assert _resolve_effective_schema({}, "not-a-schema") == {} # type: ignore[arg-type] + + def test_resolve_schema_type_non_dict_input_defaults_to_string(self) -> None: + """Test schema type resolver defaults to string for non-dict inputs.""" + assert _resolve_schema_type({}, "not-a-schema") == "string" # type: ignore[arg-type] + + def test_resolve_schema_type_ref_updates_resolution_stack(self) -> None: + """Test schema type resolver handles non-cyclic refs via resolving stack.""" + schema = {"$defs": {"Value": {"type": "integer"}}} + ref_schema = {"$ref": "#/$defs/Value"} + assert _resolve_schema_type(schema, ref_schema) == "integer" + + def test_resolve_effective_schema_ref_cycle_guard_returns_original_ref_schema(self) -> None: + """Test effective schema resolver short-circuits on repeated ref in stack.""" + schema = {"$defs": {"Loop": {"type": "string"}}} + ref_schema = {"$ref": "#/$defs/Loop"} + assert _resolve_effective_schema(schema, ref_schema, resolving_refs=("#/$defs/Loop",)) == ref_schema + + def 
test_resolve_schema_type_ref_cycle_guard_returns_union(self) -> None: + """Test schema type resolver short-circuits repeated refs as union.""" + schema = {"$defs": {"Loop": {"type": "string"}}} + ref_schema = {"$ref": "#/$defs/Loop"} + assert _resolve_schema_type(schema, ref_schema, resolving_refs=("#/$defs/Loop",)) == "union" + + def test_infer_schema_type_required_without_properties_returns_object(self) -> None: + """Test direct required-key inference resolves to object.""" + assert _infer_schema_type({"required": ["id"]}) == "object" + + def test_infer_schema_type_type_list_null_only_returns_null(self) -> None: + """Test nullable-only `type` lists resolve to the null type.""" + assert _infer_schema_type({"type": ["null"]}) == "null" + + def test_infer_schema_type_type_list_with_no_valid_entries_falls_back(self) -> None: + """Type lists without valid string entries should fall back to structural hints.""" + assert _infer_schema_type({"type": [1], "items": {"type": "string"}}) == "array" + + def test_strip_schema_internal_properties_without_properties_is_noop(self) -> None: + """Schema stripping should handle missing `properties`/`required` keys.""" + schema = {"type": "object"} + assert _strip_schema_internal_properties(schema, {"auth_value"}) == schema + + def test_resolve_effective_schema_can_skip_nullable_collapse(self) -> None: + """Test caller can preserve nullable combinator wrappers.""" + schema = { + "anyOf": [{"type": "null"}, {"type": "integer"}], + "description": "nullable value", + } + + result = _resolve_effective_schema({}, schema, collapse_nullable=False) + + assert "anyOf" in result + assert result["description"] == "nullable value" + + def test_schema_contains_ref_non_dict_returns_false(self) -> None: + """Test ref search helper returns false for non-dict input.""" + assert not _schema_contains_ref("not-a-schema", "#/$defs/Node") # type: ignore[arg-type] + + def test_schema_contains_ref_handles_self_referential_schema(self) -> None: + """Test ref 
search helper avoids infinite loops on self-referential dicts.""" + recursive_schema: Dict[str, Any] = {} + recursive_schema["self"] = recursive_schema + assert not _schema_contains_ref(recursive_schema, "#/$defs/Node") + + def test_schema_contains_ref_list_without_dicts_continues_search(self) -> None: + """Test ref search helper continues after list entries without dict values.""" + schema = { + "items": [1, "x", True], + "next": {"$ref": "#/$defs/Node"}, + } + assert _schema_contains_ref(schema, "#/$defs/Node") + + def test_prompt_for_json_schema_ref_index_out_of_bounds_raises(self, mock_console) -> None: + """Test out-of-bounds array indexes in local $ref pointers are rejected.""" + schema = { + "type": "object", + "x-variants": [{"type": "string"}], + "properties": { + "value": {"$ref": "#/x-variants/1"}, + }, + "required": ["value"], + } + + with pytest.raises(CLIError, match="Schema reference index out of bounds"): + prompt_for_json_schema(schema, prompt_optional=False) + + def test_prompt_for_json_schema_ref_valid_array_index_resolves(self, mock_console) -> None: + """Test valid array indexes in local $ref pointers resolve correctly.""" + schema = { + "type": "object", + "x-variants": [ + { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + } + ], + "properties": { + "value": {"$ref": "#/x-variants/0"}, + }, + "required": ["value"], + } + + with patch("typer.prompt", return_value="selected"): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"value": {"name": "selected"}} + + def test_prompt_for_json_schema_ref_path_invalid_raises(self, mock_console) -> None: + """Test local $ref traversal fails on invalid scalar path traversal.""" + schema = { + "type": "object", + "x-scalar": 1, + "properties": { + "value": {"$ref": "#/x-scalar/child"}, + }, + "required": ["value"], + } + + with pytest.raises(CLIError, match="Schema reference path is invalid"): + prompt_for_json_schema(schema, 
prompt_optional=False) + + def test_prompt_for_json_schema_nested_with_non_empty_indent(self, mock_console) -> None: + """Test nested prompting works with a non-empty root indent.""" + schema = { + "type": "object", + "properties": { + "config": { + "type": "object", + "properties": { + "name": {"type": "string"}, + }, + "required": ["name"], + } + }, + "required": ["config"], + } + + with patch("typer.prompt", return_value="nested-name"): + result = prompt_for_json_schema(schema, indent="|", prompt_optional=False) + + assert result == {"config": {"name": "nested-name"}} + + def test_prompt_for_json_schema_fallback_title_and_prompt_text_metadata(self, mock_console) -> None: + """Test fallback display title and prompt metadata formatting branches.""" + schema = { + "type": "object", + "title": "", + "properties": { + "query": { + "type": "string", + "description": "Search query", + "default": "seed", + } + }, + "required": ["query"], + } + + with patch("typer.prompt", return_value="manual"): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"query": "manual"} + + def test_prompt_for_json_schema_optional_object_decline(self, mock_console) -> None: + """Test optional object fields can be skipped.""" + schema = { + "type": "object", + "properties": { + "config": { + "properties": { + "name": {"type": "string"}, + }, + } + }, + } + + with patch("typer.confirm", return_value=False): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {} + + def test_prompt_for_json_schema_optional_object_included(self, mock_console) -> None: + """Test optional object fields can be included and prompted.""" + schema = { + "type": "object", + "properties": { + "config": { + "type": "object", + "properties": { + "name": {"type": "string"}, + }, + "required": ["name"], + } + }, + } + + with patch("typer.confirm", return_value=True), patch("typer.prompt", return_value="chosen"): + result = prompt_for_json_schema(schema, 
prompt_optional=True) + + assert result == {"config": {"name": "chosen"}} + + def test_prompt_for_json_schema_object_inferred_from_required_keyword(self, mock_console) -> None: + """Test object type inference from `required` when `type` is missing.""" + schema = { + "type": "object", + "properties": { + "meta": { + "properties": {"id": {"type": "string"}}, + "required": ["id"], + } + }, + "required": ["meta"], + } + + with patch("typer.prompt", return_value="abc"): + result = prompt_for_json_schema(schema, prompt_optional=False) + assert result == {"meta": {"id": "abc"}} + + def test_prompt_for_json_schema_optional_array_inferred_and_declined(self, mock_console) -> None: + """Test optional arrays inferred from `items` can be skipped.""" + schema = { + "type": "object", + "properties": { + "tags": { + "items": {"type": "string"}, + } + }, + } + + with patch("typer.prompt", return_value=""): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {} + + def test_prompt_for_json_schema_optional_integer_array_declined_is_skipped(self, mock_console) -> None: + """Optional non-string arrays should be skipped when declined via include gate.""" + schema = { + "type": "object", + "properties": { + "tags": { + "type": "array", + "items": {"type": "integer"}, + } + }, + } + + with patch("typer.confirm", return_value=False): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {} + + def test_prompt_for_json_schema_type_list_non_string_falls_back_to_items(self, mock_console) -> None: + """Test invalid type-list schemas are rejected by JSON Schema validation.""" + schema = { + "type": "object", + "properties": { + "tags": { + "type": [1], + "items": {"type": "string"}, + } + }, + } + + with patch("typer.confirm", return_value=False): + with pytest.raises(CLIError, match="Invalid JSON Schema"): + prompt_for_json_schema(schema, prompt_optional=True) + + def 
test_prompt_for_json_schema_optional_array_include_no_entries(self, mock_console) -> None: + """Test optional arrays can be included with no entries.""" + schema = { + "type": "object", + "properties": { + "tags": { + "type": "array", + "items": {"type": "integer"}, + } + }, + } + + with patch("typer.confirm", side_effect=[True, False]): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {"tags": []} + + def test_prompt_for_json_schema_optional_array_collects_multiple_entries(self, mock_console) -> None: + """Test array entry prompting loops correctly for multiple values.""" + schema = { + "type": "object", + "properties": { + "tags": { + "type": "array", + "items": {"type": "integer"}, + } + }, + } + + with patch("typer.confirm", side_effect=[True, True, True, False]), patch("typer.prompt", side_effect=[1, 2]): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {"tags": [1, 2]} + + def test_prompt_for_json_schema_prefilled_array_of_objects_validates_entries(self, mock_console) -> None: + """Test prefilled object-array entries recurse and remain schema-valid.""" + schema = { + "type": "object", + "properties": { + "rows": { + "type": "array", + "items": { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + }, + } + }, + "required": ["rows"], + } + prefilled = {"rows": [{"name": "row-one"}, {"name": "row-two"}]} + + result = prompt_for_json_schema(schema, prefilled=prefilled, prompt_optional=False) + assert result == prefilled + + def test_prompt_for_json_schema_prefilled_array_of_objects_invalid_entry_raises(self, mock_console) -> None: + """Test invalid prefilled object-array entries fail final schema validation.""" + schema = { + "type": "object", + "properties": { + "rows": { + "type": "array", + "items": { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + }, + } + }, + "required": ["rows"], + } + prefilled = 
{"rows": [{"name": "row-one"}, "raw-entry"]} + + with pytest.raises(CLIError, match="Prompted payload is invalid"): + prompt_for_json_schema(schema, prefilled=prefilled, prompt_optional=False) + + def test_prompt_for_json_schema_prefilled_array_of_arrays_validates_entries(self, mock_console) -> None: + """Test prefilled nested arrays recurse through nested item schemas.""" + schema = { + "type": "object", + "properties": { + "matrix": { + "type": "array", + "items": { + "type": "array", + "items": {"type": "integer"}, + }, + } + }, + "required": ["matrix"], + } + prefilled = {"matrix": [[1, 2], [3]]} + + result = prompt_for_json_schema(schema, prefilled=prefilled, prompt_optional=False) + assert result == prefilled + + def test_prompt_for_json_schema_prefilled_array_of_arrays_invalid_entry_raises(self, mock_console) -> None: + """Test invalid prefilled nested arrays fail final schema validation.""" + schema = { + "type": "object", + "properties": { + "matrix": { + "type": "array", + "items": { + "type": "array", + "items": {"type": "integer"}, + }, + } + }, + "required": ["matrix"], + } + prefilled = {"matrix": [[1, 2], 3]} + + with pytest.raises(CLIError, match="Prompted payload is invalid"): + prompt_for_json_schema(schema, prefilled=prefilled, prompt_optional=False) + + def test_prompt_for_json_schema_enum_uses_raw_string_match(self, mock_console) -> None: + """Test enum prompts accept raw values when JSON parsing changes type.""" + schema = { + "type": "object", + "properties": { + "mode": { + "enum": ["1"], + "default": "1", + } + }, + "required": ["mode"], + } + + with patch("typer.prompt", return_value="1"): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"mode": "1"} + + def test_prompt_for_json_schema_enum_uses_json_parsed_match(self, mock_console) -> None: + """Test enum prompts accept values that match after JSON parsing.""" + schema = { + "type": "object", + "properties": { + "priority": { + "enum": [1, 2], + } + 
}, + "required": ["priority"], + } + + with patch("typer.prompt", return_value="1"): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"priority": 1} + + def test_prompt_for_json_schema_optional_enum_blank_is_skipped(self, mock_console) -> None: + """Test optional enum fields are skipped when left blank.""" + schema = { + "type": "object", + "properties": { + "mode": {"enum": ["enforce", "disabled"]}, + }, + } + + with patch("typer.prompt", return_value=""): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {} + + def test_prompt_for_json_schema_required_enum_blank_raises(self, mock_console) -> None: + """Test required enum fields reject blank values.""" + schema = { + "type": "object", + "properties": { + "mode": {"enum": ["enforce", "disabled"]}, + }, + "required": ["mode"], + } + + with patch("typer.prompt", return_value=""): + with pytest.raises(CLIError, match="Field 'mode' is required"): + prompt_for_json_schema(schema, prompt_optional=False) + + def test_prompt_for_json_schema_invalid_enum_value_raises(self, mock_console) -> None: + """Test invalid enum values raise a clear error.""" + schema = { + "type": "object", + "properties": { + "mode": {"enum": ["enforce", "disabled"]}, + }, + "required": ["mode"], + } + + with patch("typer.prompt", return_value="invalid"): + with pytest.raises(CLIError, match="must be one of"): + prompt_for_json_schema(schema, prompt_optional=False) + + def test_prompt_for_json_schema_optional_boolean_include_and_prompt(self, mock_console) -> None: + """Test optional booleans prompt when included.""" + schema = { + "type": "object", + "properties": { + "enabled": { + "type": "boolean", + "default": True, + } + }, + } + + with patch("typer.confirm", return_value=True), patch("typer.prompt", return_value=False): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {"enabled": False} + + def 
test_prompt_for_json_schema_optional_boolean_decline(self, mock_console) -> None: + """Test optional booleans can be skipped by declining inclusion.""" + schema = { + "type": "object", + "properties": { + "enabled": { + "type": "boolean", + } + }, + } + + with patch("typer.confirm", return_value=False): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {} + + def test_prompt_for_json_schema_required_boolean_prompts_directly(self, mock_console) -> None: + """Test required booleans prompt without inclusion confirmation.""" + schema = { + "type": "object", + "properties": { + "enabled": { + "type": "boolean", + } + }, + "required": ["enabled"], + } + + with patch("typer.prompt", return_value=True): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"enabled": True} + + def test_prompt_for_json_schema_integer_with_default(self, mock_console) -> None: + """Test integer prompts with integer defaults.""" + schema = { + "type": "object", + "properties": { + "count": { + "type": "integer", + "default": 3, + } + }, + "required": ["count"], + } + + with patch("typer.prompt", return_value=7): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"count": 7} + + def test_prompt_for_json_schema_required_integer_sentinel_raises(self, mock_console) -> None: + """Test required integers reject sentinel-equivalent empty values.""" + schema = { + "type": "object", + "properties": { + "count": {"type": "integer"}, + }, + "required": ["count"], + } + + with patch("typer.prompt", return_value=_INT_SENTINEL_DEFAULT): + with pytest.raises(CLIError, match="Field 'count' is required"): + prompt_for_json_schema(schema, prompt_optional=False) + + def test_prompt_for_json_schema_optional_integer_sentinel_is_skipped(self, mock_console) -> None: + """Test optional integers are skipped when sentinel-equivalent value is returned.""" + schema = { + "type": "object", + "properties": { + "count": 
{"type": "integer"}, + }, + } + + with patch("typer.prompt", return_value=_INT_SENTINEL_DEFAULT): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {} + + def test_prompt_for_json_schema_number_with_default(self, mock_console) -> None: + """Test number prompts parse float values and respect defaults.""" + schema = { + "type": "object", + "properties": { + "score": { + "type": "number", + "default": 1.5, + } + }, + "required": ["score"], + } + + with patch("typer.prompt", return_value="2.25"): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"score": 2.25} + + def test_prompt_for_json_schema_required_number_blank_raises(self, mock_console) -> None: + """Test required numbers reject blank input.""" + schema = { + "type": "object", + "properties": { + "score": {"type": "number"}, + }, + "required": ["score"], + } + + with patch("typer.prompt", return_value=""): + with pytest.raises(CLIError, match="Field 'score' is required"): + prompt_for_json_schema(schema, prompt_optional=False) + + def test_prompt_for_json_schema_optional_number_blank_is_skipped(self, mock_console) -> None: + """Test optional numbers are skipped when blank input is provided.""" + schema = { + "type": "object", + "properties": { + "score": {"type": "number"}, + }, + } + + with patch("typer.prompt", return_value=""): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {} + + def test_prompt_for_json_schema_invalid_number_raises(self, mock_console) -> None: + """Test invalid numeric input raises a number-specific error.""" + schema = { + "type": "object", + "properties": { + "score": {"type": "number"}, + }, + "required": ["score"], + } + + with patch("typer.prompt", return_value="not-a-number"): + with pytest.raises(CLIError, match="Field 'score' must be a number"): + prompt_for_json_schema(schema, prompt_optional=False) + + def test_prompt_for_json_schema_nullable_type_list_with_only_null(self, 
mock_console) -> None: + """Test union-like type lists containing only null resolve to None.""" + schema = { + "type": "object", + "properties": { + "empty_value": {"type": ["null"]}, + }, + "required": ["empty_value"], + } + + result = prompt_for_json_schema(schema, prompt_optional=False) + assert result == {"empty_value": None} + + def test_prompt_for_json_schema_string_default_non_string_value(self, mock_console) -> None: + """Test string-like fallback prompts render non-string defaults.""" + schema = { + "type": "object", + "properties": { + "payload": { + "default": {"kind": "map"}, + } + }, + "required": ["payload"], + } + + with patch("typer.prompt", return_value='{"ok":true}'): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"payload": '{"ok":true}'} + + def test_prompt_for_json_schema_required_fallback_string_blank_raises(self, mock_console) -> None: + """Test required fallback string fields reject blank input.""" + schema = { + "type": "object", + "properties": { + "name": {}, + }, + "required": ["name"], + } + + with patch("typer.prompt", return_value=""): + with pytest.raises(CLIError, match="Field 'name' is required"): + prompt_for_json_schema(schema, prompt_optional=False) + + def test_prompt_for_json_schema_optional_fallback_string_blank_is_skipped(self, mock_console) -> None: + """Test optional fallback string fields are skipped when blank.""" + schema = { + "type": "object", + "properties": { + "name": {}, + }, + } + + with patch("typer.prompt", return_value=""): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {} + + def test_prompt_for_json_schema_optional_object_declined_is_omitted(self, mock_console) -> None: + """Test optional object fields can be omitted via the include gate.""" + schema = { + "type": "object", + "properties": { + "config": { + "type": "object", + "properties": {"url": {"type": "string"}}, + "required": ["url"], + } + }, + } + + with 
patch("typer.confirm", return_value=False): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {} + + def test_prompt_for_json_schema_optional_object_accepted_prompts_nested(self, mock_console) -> None: + """Test optional object fields prompt nested values when accepted.""" + schema = { + "type": "object", + "properties": { + "config": { + "type": "object", + "properties": {"url": {"type": "string"}}, + "required": ["url"], + } + }, + } + + with patch("typer.confirm", return_value=True), patch("typer.prompt", return_value="https://example.com"): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {"config": {"url": "https://example.com"}} + + def test_prompt_for_json_schema_additional_properties_schema(self, mock_console) -> None: + """Test `additionalProperties` schema prompts for typed extra fields.""" + schema = { + "type": "object", + "additionalProperties": {"type": "integer"}, + } + + with patch("typer.confirm", side_effect=[True, False]), patch("typer.prompt", side_effect=["max_items", 10]): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {"max_items": 10} + + def test_prompt_for_json_schema_additional_properties_true_json_and_raw(self, mock_console) -> None: + """Test `additionalProperties: true` parses JSON and falls back to raw strings.""" + schema = { + "type": "object", + "additionalProperties": True, + } + + with patch("typer.confirm", side_effect=[True, True, False]), patch("typer.prompt", side_effect=["alpha", '{"nested": 1}', "beta", "raw-text"]): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {"alpha": {"nested": 1}, "beta": "raw-text"} + + def test_prompt_for_json_schema_required_csv_array_blank_returns_empty_list(self, mock_console) -> None: + """Test required CSV-style arrays return an empty list on blank input.""" + schema = { + "type": "object", + "properties": { + "tags": { + "type": "array", + "items": 
{"type": "string"}, + } + }, + "required": ["tags"], + } + + with patch("typer.prompt", return_value=""): + result = prompt_for_json_schema(schema, prompt_optional=False) + + assert result == {"tags": []} + + def test_prompt_for_json_schema_optional_string_array_blank_is_skipped(self, mock_console) -> None: + """Test optional string arrays skip cleanly when blank.""" + schema = { + "type": "object", + "properties": { + "tags": { + "type": "array", + "items": {"type": "string"}, + } + }, + } + + with patch("typer.prompt", return_value=""): + result = prompt_for_json_schema(schema, prompt_optional=True) + + assert result == {} + + def test_prompt_for_schema_rejects_non_string_dict_keys(self, mock_console) -> None: + """Test dict fields with non-string keys are rejected for prompting.""" + + class TestSchema(BaseModel): + values: dict[int, str] + + with pytest.raises(CLIError, match="Only string keys are supported"): + prompt_for_schema(TestSchema) + + def test_prompt_for_schema_rejects_nested_non_string_dict_keys(self, mock_console) -> None: + """Nested Pydantic models should reject dict fields with non-string keys.""" + + class SubSchema(BaseModel): + values: dict[int, str] + + class OuterSchema(BaseModel): + sub: SubSchema + + with pytest.raises(CLIError, match="Only string keys are supported"): + prompt_for_schema(OuterSchema) + + def test_prompt_for_schema_traverses_annotated_optional_and_dict_values(self, mock_console) -> None: + """Validator should traverse Annotated, Optional[...], and dict value types.""" + + class TestSchema(BaseModel): + annotated_scalar: Annotated[int, "meta"] + optional_scalar: Optional[int] + mapping: Dict[str, str] + untyped_mapping: Dict + bad: dict[int, str] + + with pytest.raises(CLIError, match="Only string keys are supported"): + prompt_for_schema(TestSchema) + + def test_resolve_ref_schema_non_dict_input_returns_empty_schema(self) -> None: + """Test resolve helper defensively returns empty schema for non-dict input.""" + result = 
_resolve_ref_schema({}, "not-a-dict") # type: ignore[arg-type] + assert result == {} + + def test_resolve_schema_type_any_of_all_null_returns_null(self) -> None: + """Test schema type resolver returns null when all anyOf variants are null.""" + schema = {"anyOf": [{"type": "null"}, {"type": "null"}]} + assert _resolve_schema_type({}, schema) == "null" + + def test_resolve_schema_type_one_of_prefers_non_null_variant(self) -> None: + """Test schema type resolver returns first non-null type from oneOf variants.""" + schema = {"oneOf": [{"type": "null"}, {"type": "integer"}]} + assert _resolve_schema_type({}, schema) == "integer" diff --git a/tests/common/test_render.py b/tests/common/test_render.py new file mode 100644 index 0000000..44b1133 --- /dev/null +++ b/tests/common/test_render.py @@ -0,0 +1,338 @@ +# -*- coding: utf-8 -*- +"""Tests for cforge.common.render.""" + +# Third-Party +from rich.panel import Panel +from rich.syntax import Syntax +from rich.table import Table + +# First-Party +from cforge.common.render import LineLimit, print_json, print_table + + +class TestLineLimit: + """Tests for LineLimit class that truncates rendered content.""" + + def test_line_limit_basic_truncation(self) -> None: + """Test that LineLimit truncates content to max_lines.""" + from rich.console import Console + from rich.text import Text + + console = Console() + # Create text with 5 lines + text = Text("Line 1\nLine 2\nLine 3\nLine 4\nLine 5") + limited = LineLimit(text, max_lines=3) + + # Render to string and verify truncation + with console.capture() as capture: + console.print(limited) + + output = capture.get() + # Should contain first 3 lines + assert "Line 1" in output + assert "Line 2" in output + assert "Line 3" in output + # Should NOT contain lines 4 and 5 + assert "Line 4" not in output + assert "Line 5" not in output + # Should have ellipsis + assert "..." 
in output + + def test_line_limit_no_truncation_needed(self) -> None: + """Test that LineLimit doesn't truncate when content is within limit.""" + from rich.console import Console + from rich.text import Text + + console = Console() + # Create text with 2 lines, limit to 5 + text = Text("Line 1\nLine 2") + limited = LineLimit(text, max_lines=5) + + with console.capture() as capture: + console.print(limited) + + output = capture.get() + # Should contain both lines + assert "Line 1" in output + assert "Line 2" in output + # Should NOT have ellipsis since no truncation + assert "..." not in output + + def test_line_limit_exact_match(self) -> None: + """Test LineLimit when content exactly matches max_lines.""" + from rich.console import Console + from rich.text import Text + + console = Console() + # Create text with exactly 3 lines + text = Text("Line 1\nLine 2\nLine 3") + limited = LineLimit(text, max_lines=3) + + with console.capture() as capture: + console.print(limited) + + output = capture.get() + # Should contain all 3 lines + assert "Line 1" in output + assert "Line 2" in output + assert "Line 3" in output + # Should NOT have ellipsis since content fits exactly + assert "..." not in output + + def test_line_limit_zero_lines(self) -> None: + """Test LineLimit with max_lines=0 shows only ellipsis.""" + from rich.console import Console + from rich.text import Text + + console = Console() + text = Text("Line 1\nLine 2") + limited = LineLimit(text, max_lines=0) + + with console.capture() as capture: + console.print(limited) + + output = capture.get() + # Should only show ellipsis, no content + assert "..." 
in output + assert "Line 1" not in output + assert "Line 2" not in output + + def test_line_limit_one_line(self) -> None: + """Test LineLimit with max_lines=1.""" + from rich.console import Console + from rich.text import Text + + console = Console() + text = Text("Line 1\nLine 2\nLine 3") + limited = LineLimit(text, max_lines=1) + + with console.capture() as capture: + console.print(limited) + + output = capture.get() + # Should show only first line and ellipsis + assert "Line 1" in output + assert "..." in output + assert "Line 2" not in output + assert "Line 3" not in output + + def test_line_limit_with_long_single_line(self) -> None: + """Test LineLimit with a single long line that wraps.""" + from rich.console import Console + from rich.text import Text + + console = Console(width=80) # Set fixed width for predictable wrapping + # Create a very long line that will wrap + long_text = "A" * 200 + text = Text(long_text) + limited = LineLimit(text, max_lines=2) + + with console.capture() as capture: + console.print(limited) + + output = capture.get() + # Should contain some A's but be truncated + assert "A" in output + # Should have ellipsis since it wraps to more than 2 lines + assert "..." 
in output + + def test_line_limit_measurement_passthrough(self) -> None: + """Test that LineLimit passes through measurement to wrapped renderable.""" + from rich.console import Console + from rich.text import Text + + console = Console() + text = Text("Test content") + limited = LineLimit(text, max_lines=3) + + # Get measurement using console's options + measurement = console.measure(limited) + + # Should return a valid Measurement + assert measurement is not None + assert hasattr(measurement, "minimum") + assert hasattr(measurement, "maximum") + + def test_line_limit_with_empty_content(self) -> None: + """Test LineLimit with empty content.""" + from rich.console import Console + from rich.text import Text + + console = Console() + text = Text("") + limited = LineLimit(text, max_lines=3) + + with console.capture() as capture: + console.print(limited) + + output = capture.get() + # Empty content should produce minimal output + # Should not have ellipsis since there's nothing to truncate + assert "..." not in output + + def test_line_limit_preserves_styling(self) -> None: + """Test that LineLimit preserves rich styling in truncated content.""" + from rich.console import Console + from rich.text import Text + + console = Console() + # Create styled text + text = Text() + text.append("Line 1\n", style="bold red") + text.append("Line 2\n", style="italic blue") + text.append("Line 3\n", style="underline green") + text.append("Line 4", style="bold yellow") + + limited = LineLimit(text, max_lines=2) + + with console.capture() as capture: + console.print(limited) + + output = capture.get() + # Should contain first 2 lines + assert "Line 1" in output + assert "Line 2" in output + # Should NOT contain lines 3 and 4 + assert "Line 3" not in output + assert "Line 4" not in output + # Should have ellipsis + assert "..." 
in output + + +class TestPrettyPrinting: + """Tests for pretty printing functions.""" + + def test_print_json_with_title(self, mock_console) -> None: + """Test print_json with a title.""" + test_data = {"key": "value", "number": 42} + + print_json(test_data, "Test Title") + + # Verify console.print was called + mock_console.print.assert_called_once() + call_args = mock_console.print.call_args[0] + + # Should be wrapped in a Panel + assert isinstance(call_args[0], Panel) + + def test_print_json_without_title(self, mock_console) -> None: + """Test print_json without a title.""" + test_data = {"key": "value"} + + print_json(test_data) + + # Verify console.print was called + mock_console.print.assert_called_once() + call_args = mock_console.print.call_args[0] + + # Should be Syntax object, not Panel + assert isinstance(call_args[0], Syntax) + + def test_print_table(self, mock_console) -> None: + """Test print_table with data.""" + test_data = [ + {"id": 1, "name": "Item 1", "value": "A"}, + {"id": 2, "name": "Item 2", "value": "B"}, + ] + columns = ["id", "name", "value"] + + print_table(test_data, "Test Table", columns) + + # Verify console.print was called + mock_console.print.assert_called_once() + call_args = mock_console.print.call_args[0] + + # Should be a Table + assert isinstance(call_args[0], Table) + + def test_print_table_missing_columns(self, mock_console) -> None: + """Test print_table handles missing columns gracefully.""" + test_data = [ + {"id": 1, "name": "Item 1"}, # Missing 'value' column + ] + columns = ["id", "name", "value"] + + # Should not raise an error + print_table(test_data, "Test Table", columns) + mock_console.print.assert_called_once() + + def test_print_table_wraps_all_cells_with_line_limit(self) -> None: + """Test that print_table wraps all cell values with LineLimit for truncation.""" + from unittest.mock import patch + + # Create test data with various types + test_data = [ + {"id": 1, "name": "Item 1", "description": "Short text"}, + 
{"id": 2, "name": "Item 2", "description": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"}, + ] + columns = ["id", "name", "description"] + + # Mock Table.add_row to capture what's passed to it + with patch.object(Table, "add_row") as mock_add_row: + print_table(test_data, "Test Table", columns) + + # Verify add_row was called for each data row + assert mock_add_row.call_count == 2 + + # Check that all arguments to add_row are LineLimit instances + for call in mock_add_row.call_args_list: + args = call[0] # Get positional arguments + for arg in args: + assert isinstance(arg, LineLimit), f"Expected LineLimit but got {type(arg)}" + # Verify max_lines is set to 4 + assert arg.max_lines == 4 + + def test_print_table_with_custom_max_lines(self, mock_settings) -> None: + """Test that print_table respects custom table_max_lines configuration.""" + from unittest.mock import patch + + # Configure mock_settings with custom max_lines value + mock_settings.table_max_lines = 2 + + test_data = [ + {"id": 1, "name": "Item 1", "description": "Short text"}, + {"id": 2, "name": "Item 2", "description": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"}, + ] + columns = ["id", "name", "description"] + + # Mock Table.add_row to capture what's passed to it + with patch.object(Table, "add_row") as mock_add_row: + print_table(test_data, "Test Table", columns) + + # Verify add_row was called for each data row + assert mock_add_row.call_count == 2 + + # Check that all arguments to add_row are LineLimit instances with custom max_lines + for call in mock_add_row.call_args_list: + args = call[0] # Get positional arguments + for arg in args: + assert isinstance(arg, LineLimit), f"Expected LineLimit but got {type(arg)}" + # Verify max_lines is set to custom value of 2 + assert arg.max_lines == 2 + + def test_print_table_with_disabled_line_limit(self, mock_settings) -> None: + """Test that print_table skips LineLimit wrapping when table_max_lines is 0 or negative.""" + from unittest.mock import patch + + # 
Configure mock_settings with disabled max_lines value (0) + mock_settings.table_max_lines = 0 + + test_data = [ + {"id": 1, "name": "Item 1", "description": "Short text"}, + {"id": 2, "name": "Item 2", "description": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"}, + ] + columns = ["id", "name", "description"] + + # Mock Table.add_row to capture what's passed to it + with patch.object(Table, "add_row") as mock_add_row: + print_table(test_data, "Test Table", columns) + + # Verify add_row was called for each data row + assert mock_add_row.call_count == 2 + + # Check that arguments to add_row are plain strings, NOT LineLimit instances + for call in mock_add_row.call_args_list: + args = call[0] # Get positional arguments + for arg in args: + assert isinstance(arg, str), f"Expected str but got {type(arg)}" + assert not isinstance(arg, LineLimit), "Should not wrap with LineLimit when disabled" diff --git a/tests/common/test_schema_validation.py b/tests/common/test_schema_validation.py new file mode 100644 index 0000000..4d68e3a --- /dev/null +++ b/tests/common/test_schema_validation.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +"""Tests for cforge.common.schema_validation.""" + +# First-Party +from cforge.common.schema_validation import validate_instance, validate_instance_against_subschema, validate_schema + + +class TestSchemaValidation: + """Tests for JSON Schema validation helpers.""" + + def test_validate_schema_valid_returns_none(self) -> None: + """Valid schemas return no error message.""" + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + assert validate_schema(schema) is None + + def test_validate_schema_requires_object_schema(self) -> None: + """Non-object schemas are rejected by guard clause.""" + assert validate_schema([]) == "Input schema must be a JSON object" # type: ignore[arg-type] + + def test_validate_schema_invalid_schema_returns_error(self) -> None: + """Invalid JSON Schemas return a schema-error message.""" + schema = {"type": 1} 
+ message = validate_schema(schema) + assert isinstance(message, str) + assert message.startswith("Invalid JSON Schema:") + + def test_validate_instance_valid_payload(self) -> None: + """Valid payloads return no error message.""" + schema = { + "type": "object", + "properties": {"age": {"type": "integer"}}, + "required": ["age"], + } + + assert validate_instance(schema, {"age": 2}) is None + + def test_validate_instance_reports_nested_path(self) -> None: + """Invalid nested payloads include a JSON path in the message.""" + schema = { + "type": "object", + "properties": {"age": {"type": "integer"}}, + "required": ["age"], + } + + message = validate_instance(schema, {"age": "two"}) + assert isinstance(message, str) + assert "$.age" in message + assert "integer" in message + + def test_validate_instance_reports_root_error(self) -> None: + """Root-level validation failures report message without a path prefix.""" + schema = { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + } + + message = validate_instance(schema, {}) + assert isinstance(message, str) + assert "required property" in message + assert "$." 
not in message + + def test_validate_instance_requires_object_schema(self) -> None: + """Non-object schemas are rejected by guard clause.""" + assert validate_instance([], {}) == "Input schema must be a JSON object" # type: ignore[arg-type] + + def test_validate_instance_invalid_schema_returns_error(self) -> None: + """Invalid JSON Schemas return a schema-error message.""" + schema = {"type": 1} + + message = validate_instance(schema, "x") + assert isinstance(message, str) + assert message.startswith("Invalid JSON Schema:") + + def test_validate_instance_against_subschema_with_ref_context(self) -> None: + """Subschema validation resolves local refs from the provided root schema.""" + root_schema = { + "type": "object", + "$defs": {"Value": {"type": "integer"}}, + "properties": {"value": {"$ref": "#/$defs/Value"}}, + } + subschema = {"$ref": "#/$defs/Value"} + + assert validate_instance_against_subschema(root_schema, subschema, 9) is None + + message = validate_instance_against_subschema(root_schema, subschema, "nine") + assert isinstance(message, str) + assert "integer" in message + + def test_validate_instance_against_subschema_requires_object_inputs(self) -> None: + """Both root schema and subschema must be JSON objects.""" + assert validate_instance_against_subschema([], {}, 1) == "Input schema must be a JSON object" # type: ignore[arg-type] + assert validate_instance_against_subschema({}, [], 1) == "Input schema must be a JSON object" # type: ignore[arg-type] + + def test_validate_instance_against_subschema_invalid_schema_returns_error(self) -> None: + """Invalid subschemas produce a schema-error message.""" + root_schema = {"type": "object"} + subschema = {"type": 1} + + message = validate_instance_against_subschema(root_schema, subschema, "x") + assert isinstance(message, str) + assert message.startswith("Invalid JSON Schema:") + + def test_validate_instance_against_subschema_formats_array_index_path(self) -> None: + """Validation messages include numeric 
indices for array paths.""" + root_schema = {"type": "object"} + subschema = {"type": "array", "items": {"type": "integer"}} + + message = validate_instance_against_subschema(root_schema, subschema, [1, "bad"]) + assert isinstance(message, str) + assert "$[1]" in message diff --git a/tests/conftest.py b/tests/conftest.py index ed8f8c8..3c21f61 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,30 +8,30 @@ """ # Standard +from contextlib import contextmanager import inspect import logging import os +from pathlib import Path import socket import sys import tempfile -import time import threading -import urllib3 -from contextlib import contextmanager -from pathlib import Path +import time from types import SimpleNamespace from typing import Any, Callable, Generator, List, Union from unittest.mock import Mock, patch -# Third-Party -import pytest -import uvicorn from fastapi.testclient import TestClient from mcp.server.fastmcp import FastMCP from pydantic import SecretStr + +# Third-Party +import pytest from typer.models import OptionInfo from typer.testing import CliRunner - +import urllib3 +import uvicorn # Before importing anything from the core, force the database to use a temp dir # NOTE: In memory results in missing table errors @@ -42,7 +42,6 @@ # First-Party from cforge.config import CLISettings, get_settings # noqa: E402 - # Suppress urllib3 retry warnings during tests logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR) @@ -269,7 +268,7 @@ def test_endpoint(mock_client): client = TestClient(app) mock_client = Mock(wraps=client) - with patch("cforge.common.requests.request", mock_client.request): + with patch("cforge.common.http.requests.request", mock_client.request): yield mock_client diff --git a/tests/test_common.py b/tests/test_common.py deleted file mode 100644 index 3a9f34e..0000000 --- a/tests/test_common.py +++ /dev/null @@ -1,1213 +0,0 @@ -# -*- coding: utf-8 -*- -"""Location: ./tests/test_common.py -Copyright 2025 
-SPDX-License-Identifier: Apache-2.0 -Authors: Gabe Goodhart - -Tests for common utility functions. -""" - -# Standard -from pathlib import Path -from typing import Any, Dict, List, Optional -from unittest.mock import Mock, patch -import stat -import tempfile - -# Third-Party -from pydantic import BaseModel, Field -from rich.panel import Panel -from rich.syntax import Syntax -from rich.table import Table -import pytest -import requests - -# First-Party -from cforge.common import ( - _INT_SENTINEL_DEFAULT, - AuthenticationError, - CLIError, - LineLimit, - get_app, - get_auth_token, - get_console, - get_token_file, - load_token, - make_authenticated_request, - print_json, - print_table, - prompt_for_schema, - save_token, -) -from tests.conftest import mock_client_login - - -class TestSingletons: - """Tests for singleton getter functions.""" - - def test_get_console_returns_console(self) -> None: - """Test that get_console returns a Console instance.""" - console = get_console() - assert console is not None - # Should return same instance - assert get_console() is console - - def test_get_app_returns_typer_app(self) -> None: - """Test that get_app returns a Typer instance.""" - app = get_app() - assert app is not None - # Should return same instance - assert get_app() is app - - -class TestTokenManagement: - """Tests for token management functions.""" - - def test_get_token_file(self, mock_settings) -> None: - """Test getting the token file path.""" - token_file = get_token_file() - assert isinstance(token_file, Path) - assert str(token_file).endswith("token") - assert token_file.parent == mock_settings.contextforge_home - - def test_get_token_file_with_active_profile(self, mock_settings) -> None: - """Test getting the token file path uses active profile when available.""" - from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store - from datetime import datetime - - # Create and save an active profile - profile_id = "active-profile-456" - profile 
= AuthProfile( - id=profile_id, - name="Active Profile", - email="active@example.com", - apiUrl="https://api.example.com", - isActive=True, - createdAt=datetime.now(), - ) - store = ProfileStore( - profiles={profile_id: profile}, - activeProfileId=profile_id, - ) - save_profile_store(store) - - # get_token_file should use the active profile - token_file = get_token_file() - assert str(token_file).endswith(f"token.{profile_id}") - - def test_save_and_load_token(self) -> None: - """Test saving and loading a token.""" - test_token = "test_token_123" - - with tempfile.NamedTemporaryFile() as temp_token_file: - with patch("cforge.common.get_token_file", return_value=Path(temp_token_file.name)): - save_token(test_token) - loaded_token = load_token() - - assert loaded_token == test_token - - def test_save_and_load_token_with_active_profile(self, mock_settings) -> None: - """Test saving and loading a token with an active profile.""" - from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store - from datetime import datetime - - test_token = "profile_token_456" - profile_id = "test-profile-789" - - # Create and save an active profile - profile = AuthProfile( - id=profile_id, - name="Test Profile", - email="test@example.com", - apiUrl="https://api.example.com", - isActive=True, - createdAt=datetime.now(), - ) - store = ProfileStore( - profiles={profile_id: profile}, - activeProfileId=profile_id, - ) - save_profile_store(store) - - # Save and load token - should use profile-specific file - save_token(test_token) - loaded_token = load_token() - - assert loaded_token == test_token - - # Verify it was saved to profile-specific file - token_file = mock_settings.contextforge_home / f"token.{profile_id}" - assert token_file.exists() - - def test_save_token_different_profiles(self, mock_settings) -> None: - """Test that different profiles have separate token files.""" - from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store - from datetime 
import datetime - - token1 = "token_for_profile_1" - token2 = "token_for_profile_2" - profile_id1 = "profile-1" - profile_id2 = "profile-2" - - # Save token for profile 1 - profile1 = AuthProfile( - id=profile_id1, - name="Profile 1", - email="user1@example.com", - apiUrl="https://api1.example.com", - isActive=True, - createdAt=datetime.now(), - ) - store1 = ProfileStore( - profiles={profile_id1: profile1}, - activeProfileId=profile_id1, - ) - save_profile_store(store1) - save_token(token1) - - # Save token for profile 2 - profile2 = AuthProfile( - id=profile_id2, - name="Profile 2", - email="user2@example.com", - apiUrl="https://api2.example.com", - isActive=True, - createdAt=datetime.now(), - ) - store2 = ProfileStore( - profiles={profile_id2: profile2}, - activeProfileId=profile_id2, - ) - save_profile_store(store2) - save_token(token2) - - # Verify both tokens exist in separate files - token_file1 = mock_settings.contextforge_home / f"token.{profile_id1}" - token_file2 = mock_settings.contextforge_home / f"token.{profile_id2}" - - assert token_file1.exists() - assert token_file2.exists() - assert token_file1.read_text() == token1 - assert token_file2.read_text() == token2 - assert token1 != token2 - - def test_load_token_nonexistent(self, tmp_path: Path) -> None: - """Test loading a token when file doesn't exist.""" - nonexistent_file = tmp_path / "nonexistent" / "token" - - with patch("cforge.common.get_token_file", return_value=nonexistent_file): - token = load_token() - - assert token is None - - def test_load_token_nonexistent_profile(self, mock_settings) -> None: - """Test loading a token for a profile that doesn't have a token file.""" - from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store - from datetime import datetime - - profile_id = "nonexistent-profile" - - # Create an active profile but don't create a token file - profile = AuthProfile( - id=profile_id, - name="Test Profile", - email="test@example.com", - 
apiUrl="https://api.example.com", - isActive=True, - createdAt=datetime.now(), - ) - store = ProfileStore( - profiles={profile_id: profile}, - activeProfileId=profile_id, - ) - save_profile_store(store) - - # Try to load token - should return None since file doesn't exist - token = load_token() - - assert token is None - - -class TestBaseUrl: - """Tests for get_base_url function.""" - - def test_get_base_url_with_active_profile(self, mock_settings) -> None: - """Test get_base_url returns profile's API URL when active profile exists.""" - from cforge.common import get_base_url - from cforge.profile_utils import AuthProfile, ProfileStore, save_profile_store - from datetime import datetime - - # Create and save a profile - profile = AuthProfile( - id="profile-1", - name="Test Profile", - email="test@example.com", - apiUrl="https://custom-api.example.com", - isActive=True, - createdAt=datetime.now(), - ) - store = ProfileStore( - profiles={"profile-1": profile}, - activeProfileId="profile-1", - ) - save_profile_store(store) - - # Get base URL should return the profile's API URL - base_url = get_base_url() - assert base_url == "https://custom-api.example.com" - - def test_get_base_url_without_active_profile(self, mock_settings) -> None: - """Test get_base_url returns default URL when no active profile.""" - from cforge.common import get_base_url - - # No profile saved, should use settings - base_url = get_base_url() - assert base_url == f"http://{mock_settings.host}:{mock_settings.port}" - - -class TestAuthentication: - """Tests for authentication functions.""" - - def test_get_auth_token_from_env(self, mock_settings) -> None: - """Test getting auth token from environment variable.""" - # Create a new settings instance with token - mock_settings.mcpgateway_bearer_token = "env_token" - with patch("cforge.common.load_token", return_value=None): - token = get_auth_token() - - assert token == "env_token" - - def test_get_auth_token_from_file(self, mock_settings) -> None: - 
"""Test getting auth token from file when env var not set.""" - # mock_settings already has mcpgateway_bearer_token=None - with patch("cforge.common.load_token", return_value="file_token"): - token = get_auth_token() - - assert token == "file_token" - - def test_get_auth_token_none(self, mock_settings) -> None: - """Test getting auth token when none available.""" - # mock_settings already has mcpgateway_bearer_token=None - with patch("cforge.common.load_token", return_value=None): - token = get_auth_token() - - assert token is None - - -class TestAutoLogin: - """Tests for automatic login functionality.""" - - def test_attempt_auto_login_no_profile(self, mock_settings): - """Test auto-login when no profile is active.""" - from cforge.common import attempt_auto_login - - token = attempt_auto_login() - assert token is None - - def test_attempt_auto_login_no_credentials(self, mock_settings): - """Test auto-login when credentials are not available.""" - from cforge.common import attempt_auto_login - from cforge.profile_utils import AuthProfile - from datetime import datetime - - mock_profile = AuthProfile( - id="test-profile", - name="Test", - email="test@example.com", - apiUrl="http://localhost:4444", - isActive=True, - createdAt=datetime.now(), - ) - - with patch("cforge.common.get_active_profile", return_value=mock_profile): - with patch("cforge.common.load_profile_credentials", return_value=None): - token = attempt_auto_login() - assert token is None - - def test_attempt_auto_login_missing_email(self, mock_settings): - """Test auto-login when email is missing from credentials.""" - from cforge.common import attempt_auto_login - from cforge.profile_utils import AuthProfile - from datetime import datetime - - mock_profile = AuthProfile( - id="test-profile", - name="Test", - email="test@example.com", - apiUrl="http://localhost:4444", - isActive=True, - createdAt=datetime.now(), - ) - - with patch("cforge.common.get_active_profile", return_value=mock_profile): - with 
patch("cforge.common.load_profile_credentials", return_value={"password": "test"}): - token = attempt_auto_login() - assert token is None - - def test_attempt_auto_login_missing_password(self, mock_settings): - """Test auto-login when password is missing from credentials.""" - from cforge.common import attempt_auto_login - from cforge.profile_utils import AuthProfile - from datetime import datetime - - mock_profile = AuthProfile( - id="test-profile", - name="Test", - email="test@example.com", - apiUrl="http://localhost:4444", - isActive=True, - createdAt=datetime.now(), - ) - - with patch("cforge.common.get_active_profile", return_value=mock_profile): - with patch("cforge.common.load_profile_credentials", return_value={"email": "test@example.com"}): - token = attempt_auto_login() - assert token is None - - @patch("cforge.common.requests.post") - def test_attempt_auto_login_success(self, mock_post, mock_settings): - """Test successful auto-login.""" - from cforge.common import attempt_auto_login, load_token - from cforge.profile_utils import AuthProfile - from datetime import datetime - - mock_profile = AuthProfile( - id="test-profile", - name="Test", - email="test@example.com", - apiUrl="http://localhost:4444", - isActive=True, - createdAt=datetime.now(), - ) - - # Mock successful login response - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = {"access_token": "auto-login-token"} - mock_post.return_value = mock_response - - with patch("cforge.common.get_active_profile", return_value=mock_profile): - with patch("cforge.common.load_profile_credentials", return_value={"email": "test@example.com", "password": "test-pass"}): - token = attempt_auto_login() - assert token == "auto-login-token" - - # Verify token was saved - saved_token = load_token() - assert saved_token == "auto-login-token" - - @patch("cforge.common.requests.post") - def test_attempt_auto_login_failed_login(self, mock_post, mock_settings): - """Test 
auto-login when login fails.""" - from cforge.common import attempt_auto_login - from cforge.profile_utils import AuthProfile - from datetime import datetime - - mock_profile = AuthProfile( - id="test-profile", - name="Test", - email="test@example.com", - apiUrl="http://localhost:4444", - isActive=True, - createdAt=datetime.now(), - ) - - # Mock failed login response - mock_response = Mock() - mock_response.status_code = 401 - mock_post.return_value = mock_response - - with patch("cforge.common.get_active_profile", return_value=mock_profile): - with patch("cforge.common.load_profile_credentials", return_value={"email": "test@example.com", "password": "wrong-pass"}): - token = attempt_auto_login() - assert token is None - - @patch("cforge.common.requests.post") - def test_attempt_auto_login_no_token_in_response(self, mock_post, mock_settings): - """Test auto-login when response doesn't contain token.""" - from cforge.common import attempt_auto_login - from cforge.profile_utils import AuthProfile - from datetime import datetime - - mock_profile = AuthProfile( - id="test-profile", - name="Test", - email="test@example.com", - apiUrl="http://localhost:4444", - isActive=True, - createdAt=datetime.now(), - ) - - # Mock response without token - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = {} - mock_post.return_value = mock_response - - with patch("cforge.common.get_active_profile", return_value=mock_profile): - with patch("cforge.common.load_profile_credentials", return_value={"email": "test@example.com", "password": "test-pass"}): - token = attempt_auto_login() - assert token is None - - @patch("cforge.common.requests.post") - def test_attempt_auto_login_request_exception(self, mock_post, mock_settings): - """Test auto-login when request raises exception.""" - from cforge.common import attempt_auto_login - from cforge.profile_utils import AuthProfile - from datetime import datetime - - mock_profile = AuthProfile( - 
id="test-profile", - name="Test", - email="test@example.com", - apiUrl="http://localhost:4444", - isActive=True, - createdAt=datetime.now(), - ) - - # Mock request exception - mock_post.side_effect = Exception("Connection error") - - with patch("cforge.common.get_active_profile", return_value=mock_profile): - with patch("cforge.common.load_profile_credentials", return_value={"email": "test@example.com", "password": "test-pass"}): - token = attempt_auto_login() - assert token is None - - def test_get_auth_token_with_auto_login(self, mock_settings): - """Test that get_auth_token attempts auto-login when no token is available.""" - from cforge.common import get_auth_token - - # Mock no env token and no file token, but successful auto-login - with patch("cforge.common.load_token", return_value=None): - with patch("cforge.common.attempt_auto_login", return_value="auto-token"): - token = get_auth_token() - assert token == "auto-token" - - -class TestErrors: - """Tests for custom error classes.""" - - def test_cli_error(self) -> None: - """Test CLIError exception.""" - error = CLIError("Test error") - assert str(error) == "Test error" - - def test_authentication_error(self) -> None: - """Test AuthenticationError exception.""" - error = AuthenticationError("Auth failed") - assert str(error) == "Auth failed" - assert isinstance(error, CLIError) - - -class TestLineLimit: - """Tests for LineLimit class that truncates rendered content.""" - - def test_line_limit_basic_truncation(self) -> None: - """Test that LineLimit truncates content to max_lines.""" - from rich.text import Text - from rich.console import Console - - console = Console() - # Create text with 5 lines - text = Text("Line 1\nLine 2\nLine 3\nLine 4\nLine 5") - limited = LineLimit(text, max_lines=3) - - # Render to string and verify truncation - with console.capture() as capture: - console.print(limited) - - output = capture.get() - # Should contain first 3 lines - assert "Line 1" in output - assert "Line 2" in 
output - assert "Line 3" in output - # Should NOT contain lines 4 and 5 - assert "Line 4" not in output - assert "Line 5" not in output - # Should have ellipsis - assert "..." in output - - def test_line_limit_no_truncation_needed(self) -> None: - """Test that LineLimit doesn't truncate when content is within limit.""" - from rich.text import Text - from rich.console import Console - - console = Console() - # Create text with 2 lines, limit to 5 - text = Text("Line 1\nLine 2") - limited = LineLimit(text, max_lines=5) - - with console.capture() as capture: - console.print(limited) - - output = capture.get() - # Should contain both lines - assert "Line 1" in output - assert "Line 2" in output - # Should NOT have ellipsis since no truncation - assert "..." not in output - - def test_line_limit_exact_match(self) -> None: - """Test LineLimit when content exactly matches max_lines.""" - from rich.text import Text - from rich.console import Console - - console = Console() - # Create text with exactly 3 lines - text = Text("Line 1\nLine 2\nLine 3") - limited = LineLimit(text, max_lines=3) - - with console.capture() as capture: - console.print(limited) - - output = capture.get() - # Should contain all 3 lines - assert "Line 1" in output - assert "Line 2" in output - assert "Line 3" in output - # Should NOT have ellipsis since content fits exactly - assert "..." not in output - - def test_line_limit_zero_lines(self) -> None: - """Test LineLimit with max_lines=0 shows only ellipsis.""" - from rich.text import Text - from rich.console import Console - - console = Console() - text = Text("Line 1\nLine 2") - limited = LineLimit(text, max_lines=0) - - with console.capture() as capture: - console.print(limited) - - output = capture.get() - # Should only show ellipsis, no content - assert "..." 
in output - assert "Line 1" not in output - assert "Line 2" not in output - - def test_line_limit_one_line(self) -> None: - """Test LineLimit with max_lines=1.""" - from rich.text import Text - from rich.console import Console - - console = Console() - text = Text("Line 1\nLine 2\nLine 3") - limited = LineLimit(text, max_lines=1) - - with console.capture() as capture: - console.print(limited) - - output = capture.get() - # Should show only first line and ellipsis - assert "Line 1" in output - assert "..." in output - assert "Line 2" not in output - assert "Line 3" not in output - - def test_line_limit_with_long_single_line(self) -> None: - """Test LineLimit with a single long line that wraps.""" - from rich.text import Text - from rich.console import Console - - console = Console(width=80) # Set fixed width for predictable wrapping - # Create a very long line that will wrap - long_text = "A" * 200 - text = Text(long_text) - limited = LineLimit(text, max_lines=2) - - with console.capture() as capture: - console.print(limited) - - output = capture.get() - # Should contain some A's but be truncated - assert "A" in output - # Should have ellipsis since it wraps to more than 2 lines - assert "..." 
in output - - def test_line_limit_measurement_passthrough(self) -> None: - """Test that LineLimit passes through measurement to wrapped renderable.""" - from rich.text import Text - from rich.console import Console - - console = Console() - text = Text("Test content") - limited = LineLimit(text, max_lines=3) - - # Get measurement using console's options - measurement = console.measure(limited) - - # Should return a valid Measurement - assert measurement is not None - assert hasattr(measurement, "minimum") - assert hasattr(measurement, "maximum") - - def test_line_limit_with_empty_content(self) -> None: - """Test LineLimit with empty content.""" - from rich.text import Text - from rich.console import Console - - console = Console() - text = Text("") - limited = LineLimit(text, max_lines=3) - - with console.capture() as capture: - console.print(limited) - - output = capture.get() - # Empty content should produce minimal output - # Should not have ellipsis since there's nothing to truncate - assert "..." not in output - - def test_line_limit_preserves_styling(self) -> None: - """Test that LineLimit preserves rich styling in truncated content.""" - from rich.text import Text - from rich.console import Console - - console = Console() - # Create styled text - text = Text() - text.append("Line 1\n", style="bold red") - text.append("Line 2\n", style="italic blue") - text.append("Line 3\n", style="underline green") - text.append("Line 4", style="bold yellow") - - limited = LineLimit(text, max_lines=2) - - with console.capture() as capture: - console.print(limited) - - output = capture.get() - # Should contain first 2 lines - assert "Line 1" in output - assert "Line 2" in output - # Should NOT contain lines 3 and 4 - assert "Line 3" not in output - assert "Line 4" not in output - # Should have ellipsis - assert "..." 
in output - - -class TestMakeAuthenticatedRequest: - """Tests for make_authenticated_request function using a server mock.""" - - def test_request_no_auth_raises_error_when_server_requires_it(self, mock_settings) -> None: - """Test that request without auth raises AuthenticationError when server requires it.""" - # Ensure no token is available - with patch("cforge.common.load_token", return_value=None): - # Mock a 401 response from server (authentication required) - mock_response = Mock() - mock_response.status_code = 401 - mock_response.text = "Unauthorized" - - with patch("cforge.common.requests.request", return_value=mock_response): - with pytest.raises(AuthenticationError) as exc_info: - make_authenticated_request("GET", "/test") - - assert "Authentication required but not configured" in str(exc_info.value) - - def test_request_without_auth_succeeds_on_unauthenticated_server(self, mock_settings) -> None: - """Test that request without auth succeeds when server doesn't require it.""" - # Ensure no token is available - with patch("cforge.common.load_token", return_value=None): - # Mock a successful response from server (no auth required) - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = {"result": "success"} - - with patch("cforge.common.requests.request", return_value=mock_response) as mock_req: - result = make_authenticated_request("GET", "/test") - - # Verify the request was made without Authorization header - call_args = mock_req.call_args - assert "Authorization" not in call_args[1]["headers"] - assert result == {"result": "success"} - - def test_request_with_bearer_token(self, mock_client, mock_settings) -> None: - """Test successful request with Bearer token.""" - mock_client.reset_mock() - with mock_client_login(mock_client): - mock_req = mock_client.request - result = make_authenticated_request("GET", "/tools") - - # Verify request was made correctly - mock_req.assert_called_once() - call_args = 
mock_req.call_args - assert call_args[1]["method"] == "GET" - assert call_args[1]["url"] == f"http://{mock_client.settings.host}:{mock_client.settings.port}/tools" - assert call_args[1]["headers"]["Authorization"] == f"Bearer {mock_client.settings.mcpgateway_bearer_token}" - assert call_args[1]["headers"]["Content-Type"] == "application/json" - assert isinstance(result, list) - - def test_request_with_basic_auth(self, mock_settings) -> None: - """Test request with Basic auth token.""" - # Set up settings with Basic auth token - mock_settings.mcpgateway_bearer_token = "Basic dGVzdDp0ZXN0" - - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = {"result": "success"} - - with patch("cforge.common.requests.request", return_value=mock_response) as mock_req: - make_authenticated_request("POST", "/api/test", json_data={"data": "value"}) - - # Verify Basic auth is passed as-is - call_args = mock_req.call_args - assert call_args[1]["headers"]["Authorization"] == "Basic dGVzdDp0ZXN0" - - def test_request_api_error(self, mock_settings) -> None: - """Test that API errors are properly raised.""" - mock_settings.mcpgateway_bearer_token = "test_token" - - mock_response = Mock() - mock_response.status_code = 404 - mock_response.text = "Not found" - - with patch("cforge.common.requests.request", return_value=mock_response): - with pytest.raises(CLIError) as exc_info: - make_authenticated_request("GET", "/api/missing") - - assert "API request failed (404)" in str(exc_info.value) - assert "Not found" in str(exc_info.value) - - def test_request_connection_error(self, mock_settings) -> None: - """Test that connection errors are properly raised.""" - mock_settings.mcpgateway_bearer_token = "test_token" - - with patch("cforge.common.requests.request", side_effect=requests.ConnectionError("Connection refused")): - with pytest.raises(CLIError) as exc_info: - make_authenticated_request("GET", "/api/test") - - assert "Failed to connect to gateway" in 
str(exc_info.value) - assert "Connection refused" in str(exc_info.value) - - -class TestPrettyPrinting: - """Tests for pretty printing functions.""" - - def test_print_json_with_title(self, mock_console) -> None: - """Test print_json with a title.""" - test_data = {"key": "value", "number": 42} - - print_json(test_data, "Test Title") - - # Verify console.print was called - mock_console.print.assert_called_once() - call_args = mock_console.print.call_args[0] - - # Should be wrapped in a Panel - assert isinstance(call_args[0], Panel) - - def test_print_json_without_title(self, mock_console) -> None: - """Test print_json without a title.""" - test_data = {"key": "value"} - - print_json(test_data) - - # Verify console.print was called - mock_console.print.assert_called_once() - call_args = mock_console.print.call_args[0] - - # Should be Syntax object, not Panel - assert isinstance(call_args[0], Syntax) - - def test_print_table(self, mock_console) -> None: - """Test print_table with data.""" - test_data = [ - {"id": 1, "name": "Item 1", "value": "A"}, - {"id": 2, "name": "Item 2", "value": "B"}, - ] - columns = ["id", "name", "value"] - - print_table(test_data, "Test Table", columns) - - # Verify console.print was called - mock_console.print.assert_called_once() - call_args = mock_console.print.call_args[0] - - # Should be a Table - assert isinstance(call_args[0], Table) - - def test_print_table_missing_columns(self, mock_console) -> None: - """Test print_table handles missing columns gracefully.""" - test_data = [ - {"id": 1, "name": "Item 1"}, # Missing 'value' column - ] - columns = ["id", "name", "value"] - - # Should not raise an error - print_table(test_data, "Test Table", columns) - mock_console.print.assert_called_once() - - def test_print_table_wraps_all_cells_with_line_limit(self) -> None: - """Test that print_table wraps all cell values with LineLimit for truncation.""" - from unittest.mock import patch - - # Create test data with various types - test_data = 
[ - {"id": 1, "name": "Item 1", "description": "Short text"}, - {"id": 2, "name": "Item 2", "description": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"}, - ] - columns = ["id", "name", "description"] - - # Mock Table.add_row to capture what's passed to it - with patch.object(Table, "add_row") as mock_add_row: - print_table(test_data, "Test Table", columns) - - # Verify add_row was called for each data row - assert mock_add_row.call_count == 2 - - # Check that all arguments to add_row are LineLimit instances - for call in mock_add_row.call_args_list: - args = call[0] # Get positional arguments - for arg in args: - assert isinstance(arg, LineLimit), f"Expected LineLimit but got {type(arg)}" - # Verify max_lines is set to 4 - assert arg.max_lines == 4 - - def test_print_table_with_custom_max_lines(self, mock_settings) -> None: - """Test that print_table respects custom table_max_lines configuration.""" - from unittest.mock import patch - - # Configure mock_settings with custom max_lines value - mock_settings.table_max_lines = 2 - - test_data = [ - {"id": 1, "name": "Item 1", "description": "Short text"}, - {"id": 2, "name": "Item 2", "description": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"}, - ] - columns = ["id", "name", "description"] - - # Mock Table.add_row to capture what's passed to it - with patch.object(Table, "add_row") as mock_add_row: - print_table(test_data, "Test Table", columns) - - # Verify add_row was called for each data row - assert mock_add_row.call_count == 2 - - # Check that all arguments to add_row are LineLimit instances with custom max_lines - for call in mock_add_row.call_args_list: - args = call[0] # Get positional arguments - for arg in args: - assert isinstance(arg, LineLimit), f"Expected LineLimit but got {type(arg)}" - # Verify max_lines is set to custom value of 2 - assert arg.max_lines == 2 - - def test_print_table_with_disabled_line_limit(self, mock_settings) -> None: - """Test that print_table skips LineLimit wrapping when table_max_lines 
is 0 or negative.""" - from unittest.mock import patch - - # Configure mock_settings with disabled max_lines value (0) - mock_settings.table_max_lines = 0 - - test_data = [ - {"id": 1, "name": "Item 1", "description": "Short text"}, - {"id": 2, "name": "Item 2", "description": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"}, - ] - columns = ["id", "name", "description"] - - # Mock Table.add_row to capture what's passed to it - with patch.object(Table, "add_row") as mock_add_row: - print_table(test_data, "Test Table", columns) - - # Verify add_row was called for each data row - assert mock_add_row.call_count == 2 - - # Check that arguments to add_row are plain strings, NOT LineLimit instances - for call in mock_add_row.call_args_list: - args = call[0] # Get positional arguments - for arg in args: - assert isinstance(arg, str), f"Expected str but got {type(arg)}" - assert not isinstance(arg, LineLimit), "Should not wrap with LineLimit when disabled" - - -class TestPromptForSchema: - """Tests for prompt_for_schema function.""" - - def test_prompt_with_prefilled_values(self, mock_console) -> None: - """Test that prefilled values are used and not prompted.""" - - class TestSchema(BaseModel): - name: str - description: str - - prefilled = {"name": "test_name", "description": "test_desc"} - - result = prompt_for_schema(TestSchema, prefilled=prefilled) - - # Should return prefilled values without prompting - assert result == prefilled - # Console should show the prefilled values - assert mock_console.print.call_count >= 3 # Header + 2 fields - - def test_prompt_skips_internal_fields(self, mock_console) -> None: - """Test that internal fields are skipped.""" - - class TestSchema(BaseModel): - name: str - model_config: dict = {} # Should be skipped - auth_value: str = "" # Should be skipped - - prefilled = {"name": "test"} - - result = prompt_for_schema(TestSchema, prefilled=prefilled) - - # Should only have the name field - assert "name" in result - assert "model_config" not in 
result - assert "auth_value" not in result - - def test_prompt_with_string_field(self, mock_console) -> None: - """Test prompting for string fields.""" - - class TestSchema(BaseModel): - name: str = Field(description="The name") - - with patch("typer.prompt", return_value="user_input"): - result = prompt_for_schema(TestSchema) - - assert result["name"] == "user_input" - - def test_prompt_with_optional_field(self, mock_console) -> None: - """Test prompting for optional fields.""" - - class TestSchema(BaseModel): - required_field: str - optional_field: Optional[str] = None - - with patch("typer.prompt", side_effect=["required_value", ""]): - result = prompt_for_schema(TestSchema) - - assert result["required_field"] == "required_value" - # Optional field with empty input should not be in result - assert "optional_field" not in result or result["optional_field"] == "" - - def test_prompt_with_bool_field(self, mock_console) -> None: - """Test prompting for boolean fields.""" - - class TestSchema(BaseModel): - enabled: bool - - with patch("typer.confirm", return_value=True): - with patch("typer.prompt", return_value=True): - result = prompt_for_schema(TestSchema) - - assert result["enabled"] is True - - def test_prompt_with_optional_bool_field_declined(self, mock_console) -> None: - """Test prompting for optional boolean field that is declined.""" - - class TestSchema(BaseModel): - enabled: Optional[bool] = None - - # First confirm returns False (don't include field) - with patch("typer.confirm", return_value=False): - result = prompt_for_schema(TestSchema) - - # Field should not be in result when declined - assert "enabled" not in result - - def test_prompt_with_int_field(self, mock_console) -> None: - """Test prompting for integer fields.""" - - class TestSchema(BaseModel): - count: int - - with patch("typer.prompt", return_value=42): - result = prompt_for_schema(TestSchema) - - assert result["count"] == 42 - - def test_prompt_with_int_field_empty_input(self, 
mock_console) -> None: - """Test prompting for optional integer field with empty input.""" - - class TestSchema(BaseModel): - count: Optional[int] = None - - # Return sentinel to simulate skipping optional field - with patch("typer.prompt", return_value=_INT_SENTINEL_DEFAULT): - result = prompt_for_schema(TestSchema) - - # Field should not be in result when empty - assert "count" not in result - - def test_prompt_with_list_field(self, mock_console) -> None: - """Test prompting for list fields.""" - - class TestSchema(BaseModel): - tags: List[str] - - with patch("typer.prompt", return_value="tag1, tag2, tag3"): - result = prompt_for_schema(TestSchema) - - assert result["tags"] == ["tag1", "tag2", "tag3"] - - def test_prompt_with_list_field_empty(self, mock_console) -> None: - """Test prompting for list fields with empty input.""" - - class TestSchema(BaseModel): - tags: Optional[List[str]] = None - - with patch("typer.prompt", return_value=""): - result = prompt_for_schema(TestSchema) - - # Empty input for list should not add the field - assert "tags" not in result or result.get("tags") is None - - def test_prompt_dict_str_str(self, mock_console) -> None: - """Test prompting for a string to string dict""" - - class TestSchema(BaseModel): - key: Dict[str, str] - - with patch("typer.confirm", side_effect=["y", "y", ""]), patch("typer.prompt", side_effect=["k1", "v1", "k2", "v2"]): - result = prompt_for_schema(TestSchema) - - # Empty input for list should not add the field - assert result == { - "key": {"k1": "v1", "k2": "v2"}, - } - - def test_prompt_with_nested_dicts(self, mock_console) -> None: - """Test prompting for a nested dict with dict values""" - - class SubSchema(BaseModel): - num: int - - class TestSchema(BaseModel): - key: Dict[str, Any] - sub: SubSchema - sub_dict: Dict[str, SubSchema] - - with patch("typer.confirm", side_effect=["y", "y", "", "y", ""]), patch("typer.prompt", side_effect=["k1", '{"foo": 1}', "k2", "[1, 2, 3]", 42, "a-num", 123]): - result 
= prompt_for_schema(TestSchema) - - # Empty input for list should not add the field - assert result == { - "key": {"k1": {"foo": 1}, "k2": [1, 2, 3]}, - "sub": {"num": 42}, - "sub_dict": {"a-num": {"num": 123}}, - } - - def test_prompt_list_of_sub_models(self, mock_console) -> None: - """Test prompting for a list of sub pydantic models""" - - class SubSchema(BaseModel): - num: int - - class TestSchema(BaseModel): - nums: List[SubSchema] - - with patch("typer.confirm", side_effect=["y", "y", ""]), patch("typer.prompt", side_effect=[1, 2]): - result = prompt_for_schema(TestSchema) - - # Empty input for list should not add the field - assert result == {"nums": [{"num": 1}, {"num": 2}]} - - def test_prompt_with_default(self, mock_console) -> None: - """Test prompting with defaults and make sure prompt string added.""" - - class TestSchema(BaseModel): - name: str = "foobar" - some_val: int = 42 - - with patch("typer.prompt", side_effect=["", ""]) as prompt_mock: - prompt_for_schema(TestSchema) - assert prompt_mock.call_count == 2 - assert prompt_mock.call_args_list[0][1]["default"] == "foobar" - assert prompt_mock.call_args_list[1][1]["default"] == 42 - assert any("foobar" in call[0][0] for call in mock_console.print.call_args_list) - assert any("42" in call[0][0] for call in mock_console.print.call_args_list) - - def test_prompt_missing_required_string(self, mock_console) -> None: - """Test that an exception is raised if a required string is unset.""" - - class TestSchema(BaseModel): - foo: str - - with patch("typer.prompt", return_value=""): - with pytest.raises(CLIError): - prompt_for_schema(TestSchema) - - -class TestTokenFilePermissions: - """Tests for token file permission handling.""" - - def test_save_token_creates_parent_dirs(self) -> None: - """Test that save_token creates parent directories.""" - with tempfile.TemporaryDirectory() as temp_dir: - token_path = Path(temp_dir) / "nested" / "dirs" / "token" - - with patch("cforge.common.get_token_file", 
return_value=token_path): - save_token("test_token") - - assert token_path.exists() - assert token_path.read_text() == "test_token" - - def test_save_token_sets_permissions(self) -> None: - """Test that save_token sets restrictive permissions.""" - with tempfile.NamedTemporaryFile(delete=False) as temp_file: - token_path = Path(temp_file.name) - - try: - with patch("cforge.common.get_token_file", return_value=token_path): - save_token("test_token") - - # Check permissions are 0o600 (read/write for owner only) - file_stat = token_path.stat() - file_mode = stat.S_IMODE(file_stat.st_mode) - assert file_mode == 0o600 - finally: - token_path.unlink(missing_ok=True) - - -class TestMakeAuthenticatedRequestIntegration: - """Integration tests for make_authenticated_request with real server. - - These tests use the session_settings fixture which provides a real - running mcpgateway server and properly configured settings. This validates - that the client code actually works with the server, not just that it - makes the right calls. 
- """ - - def test_request_with_bearer_token_to_health_endpoint(self, mock_client) -> None: - """Test successful authenticated request to /health endpoint.""" - - # Make a request to the health endpoint (no auth required) - make_authenticated_request("GET", "/health") - - # Make a request to an authorized endpoint before login - with pytest.raises(CLIError): - make_authenticated_request("GET", "/tools") - - # Log in and try again - with mock_client_login(mock_client): - - # Make a real HTTP request to the session server's health endpoint - result = make_authenticated_request("GET", "/tools") - - # The tools endpoint should return a successful response - assert result is not None - assert isinstance(result, list) - - def test_request_to_nonexistent_endpoint_raises_error(self, authorized_mock_client) -> None: - """Test that requesting a nonexistent endpoint raises CLIError.""" - # Try to request an endpoint that doesn't exist - with pytest.raises(CLIError) as exc_info: - make_authenticated_request("GET", "/api/this/endpoint/does/not/exist") - - # Should get a 404 error - assert "404" in str(exc_info.value) or "not found" in str(exc_info.value).lower() - - def test_request_with_params_and_json_data(self, authorized_mock_client) -> None: - """Test request with query parameters. - - This test verifies that parameters are correctly passed through - to the server in a real HTTP request. - """ - # Test that we can make requests with params - # The health endpoint may not use params, but we can verify the request succeeds - result = make_authenticated_request("GET", "/health", params={"test": "value"}) - - # Should still get a valid response even with unused params - assert result is not None - assert isinstance(result, dict) diff --git a/tests/test_config.py b/tests/test_config.py index fae4f80..6ac9f39 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,6 +6,7 @@ Tests for configuration management. 
""" + # Standard from pathlib import Path from unittest import mock