diff --git a/.claude/settings.local.json b/.claude/settings.local.json index 7a3ce119..a97051fd 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -60,7 +60,8 @@ "WebFetch(domain:crates.io)", "Bash(npx -y @modelcontextprotocol/conformance list)", "Bash(target/debug/hello-world:*)", - "Bash(./target/debug/resources-demo:*)" + "Bash(./target/debug/resources-demo:*)", + "Bash(ls:*)" ], "deny": [] } diff --git a/Cargo.lock b/Cargo.lock index ffb7f9ee..c0cbac24 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -526,7 +526,7 @@ dependencies = [ [[package]] name = "conformance-tests" -version = "0.16.0" +version = "0.17.0" dependencies = [ "anyhow", "chrono", @@ -1051,7 +1051,7 @@ dependencies = [ [[package]] name = "hello-world-with-auth" -version = "0.16.0" +version = "0.17.0" dependencies = [ "anyhow", "async-trait", @@ -2107,7 +2107,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-auth" -version = "0.16.0" +version = "0.17.0" dependencies = [ "aes-gcm", "anyhow", @@ -2145,9 +2145,23 @@ dependencies = [ "zeroize", ] +[[package]] +name = "pulseengine-mcp-client" +version = "0.17.0" +dependencies = [ + "async-trait", + "futures", + "pulseengine-mcp-protocol", + "serde", + "serde_json", + "thiserror 2.0.12", + "tokio", + "tracing", +] + [[package]] name = "pulseengine-mcp-external-validation" -version = "0.16.0" +version = "0.17.0" dependencies = [ "anyhow", "arbitrary", @@ -2185,7 +2199,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-integration-tests" -version = "0.16.0" +version = "0.17.0" dependencies = [ "anyhow", "assert_matches", @@ -2211,7 +2225,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-logging" -version = "0.16.0" +version = "0.17.0" dependencies = [ "chrono", "hex", @@ -2229,7 +2243,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-macros" -version = "0.16.0" +version = "0.17.0" dependencies = [ "anyhow", "async-trait", @@ -2255,7 +2269,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-protocol" -version = "0.16.0" +version = "0.17.0" dependencies = [ "async-trait", "chrono", @@ -2272,7 +2286,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-security" -version = "0.16.0" +version = "0.17.0" dependencies = [ "anyhow", "async-trait", @@ -2294,7 +2308,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-security-middleware" -version = "0.16.0" +version = "0.17.0" dependencies = [ "anyhow", "assert_matches", @@ -2327,7 +2341,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-server" -version = "0.16.0" +version = "0.17.0" dependencies = [ "anyhow", "async-trait", @@ -2355,7 +2369,7 @@ dependencies = [ [[package]] name = "pulseengine-mcp-transport" -version = "0.16.0" +version = "0.17.0" dependencies = [ "anyhow", "async-stream", diff --git a/Cargo.toml b/Cargo.toml index 42df8421..1e5e8506 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "mcp-security-middleware", "mcp-transport", "mcp-server", + "mcp-client", "mcp-macros", "mcp-external-validation", "integration-tests", @@ -23,7 +24,7 @@ members = [ resolver = "2" [workspace.package] -version = "0.16.0" +version = "0.17.0" rust-version = "1.88" edition = "2024" license = "MIT OR Apache-2.0" @@ -101,15 +102,16 @@ assert_matches = "1.5" serde_yaml = "0.9" # Framework internal dependencies (published versions) -pulseengine-mcp-protocol = { version = "0.16.0", path = "mcp-protocol" } -pulseengine-mcp-logging = { version = "0.16.0", path = "mcp-logging" } -pulseengine-mcp-auth = { version = "0.16.0", path = "mcp-auth" } -pulseengine-mcp-security = { version = "0.16.0", path = "mcp-security" } -pulseengine-mcp-security-middleware = { version = "0.16.0", path = "mcp-security-middleware" } -pulseengine-mcp-transport = { version = "0.16.0", path = "mcp-transport" } -pulseengine-mcp-server = { version = "0.16.0", path = "mcp-server" } -pulseengine-mcp-macros = { version = "0.16.0", path = "mcp-macros" } -pulseengine-mcp-external-validation = { version = "0.16.0", path = "mcp-external-validation" } +pulseengine-mcp-protocol = { version = "0.17.0", path = "mcp-protocol" } +pulseengine-mcp-logging = { version = "0.17.0", path = "mcp-logging" } +pulseengine-mcp-auth = { version = "0.17.0", path = "mcp-auth" } +pulseengine-mcp-security = { version = "0.17.0", path = "mcp-security" } +pulseengine-mcp-security-middleware = { version = "0.17.0", path = "mcp-security-middleware" } +pulseengine-mcp-transport = { version = "0.17.0", path = "mcp-transport" } +pulseengine-mcp-server = { version = "0.17.0", path = "mcp-server" } +pulseengine-mcp-macros = { version = "0.17.0", path = "mcp-macros" } +pulseengine-mcp-external-validation = { version = "0.17.0", path = "mcp-external-validation" } +pulseengine-mcp-client = { version = "0.17.0", path = "mcp-client" } [profile.release] opt-level = "s" @@ -152,3 +154,4 @@ pulseengine-mcp-transport = { path = "mcp-transport" } pulseengine-mcp-server = { path = "mcp-server" } pulseengine-mcp-macros = { path = "mcp-macros" } pulseengine-mcp-external-validation = { path = "mcp-external-validation" } +pulseengine-mcp-client = { path = "mcp-client" } diff --git a/Dockerfile.validation b/Dockerfile.validation index db53d5c3..db5373fa 100644 --- a/Dockerfile.validation +++ b/Dockerfile.validation @@ -26,6 +26,7 @@ COPY mcp-security-middleware ./mcp-security-middleware/ COPY mcp-transport ./mcp-transport/ COPY mcp-macros ./mcp-macros/ COPY mcp-server ./mcp-server/ +COPY mcp-client ./mcp-client/ COPY mcp-external-validation ./mcp-external-validation/ COPY conformance-tests ./conformance-tests/ COPY examples ./examples/ diff --git a/README.md b/README.md index 46136560..7ade5627 100644 --- a/README.md +++ b/README.md @@ -1,294 +1,70 @@ -# PulseEngine MCP Framework for Rust +# PulseEngine MCP -**Build production-ready Model Context Protocol servers with confidence** +Rust framework for building [Model Context Protocol](https://modelcontextprotocol.io/) servers and clients. -[![License](https://img.shields.io/badge/license-MIT%20OR%20Apache--2.0-blue.svg)](LICENSE) +[![Crates.io](https://img.shields.io/crates/v/pulseengine-mcp-protocol.svg)](https://crates.io/crates/pulseengine-mcp-protocol) [![Documentation](https://docs.rs/pulseengine-mcp-protocol/badge.svg)](https://docs.rs/pulseengine-mcp-protocol) -[![codecov](https://codecov.io/gh/pulseengine/mcp/graph/badge.svg?token=ZGAL6V3SQR)](https://codecov.io/gh/pulseengine/mcp) [![CI](https://github.com/pulseengine/mcp/actions/workflows/pr-validation.yml/badge.svg)](https://github.com/pulseengine/mcp/actions/workflows/pr-validation.yml) +[![codecov](https://codecov.io/gh/pulseengine/mcp/graph/badge.svg?token=ZGAL6V3SQR)](https://codecov.io/gh/pulseengine/mcp) -This framework provides everything you need to build production-ready MCP servers in Rust. It's been developed and proven through a real-world home automation server with 30+ tools that successfully integrates with MCP Inspector, Claude Desktop, and HTTP clients. - -**πŸŽ‰ MCP 2025-11-25 Support** - Full implementation of the latest MCP specification including Tasks, Tool Calling in Sampling, and Enhanced Elicitation! - -**πŸ–ΌοΈ MCP Apps Extension Support** - First production Rust framework supporting [SEP-1865](https://github.com/modelcontextprotocol/modelcontextprotocol/pull/1865) for interactive HTML user interfaces! - -## What is MCP? - -The [Model Context Protocol](https://modelcontextprotocol.io/) enables AI assistants to securely connect to and interact with external systems through tools, resources, and prompts. Instead of AI models having static knowledge, they can dynamically access live data and perform actions through MCP servers. - -## Why This Framework? - -**πŸ—οΈ Production-Proven:** This framework powers a working Loxone home automation server that handles real-world complexity - device control, sensor monitoring, authentication, and concurrent operations. - -**πŸ”§ Complete Infrastructure:** You focus on your domain logic (databases, APIs, file systems) while the framework handles protocol compliance, transport layers, security, and monitoring. - -**πŸ“‘ Multiple Transport Support:** Works with Claude Desktop (stdio), web applications (HTTP), real-time apps (WebSocket), and tools like MCP Inspector. - -## Quick Start - -Add to your `Cargo.toml`: - -```toml -[dependencies] -pulseengine-mcp-server = "0.15" -pulseengine-mcp-protocol = "0.15" -tokio = { version = "1.0", features = ["full"] } -async-trait = "0.1" -``` - -Create your first MCP server: +## Example ```rust -use pulseengine_mcp_server::{McpServer, McpBackend, ServerConfig}; -use pulseengine_mcp_protocol::*; -use async_trait::async_trait; - -#[derive(Clone)] -struct MyBackend; - -#[async_trait] -impl McpBackend for MyBackend { - type Error = Box; - type Config = (); - - async fn initialize(_: Self::Config) -> Result { - Ok(MyBackend) - } - - fn get_server_info(&self) -> ServerInfo { - ServerInfo { - protocol_version: ProtocolVersion::default(), // MCP 2025-11-25 - capabilities: ServerCapabilities::builder() - .enable_tools() - .build(), - server_info: Implementation::with_description( - "My MCP Server", - "1.0.0", - "A simple example server", - ), - instructions: Some("Use 'hello' tool to greet someone".to_string()), - } - } - - async fn list_tools(&self, _: PaginatedRequestParam) -> Result { - Ok(ListToolsResult { - tools: vec![ - Tool { - name: "hello".to_string(), - description: "Say hello to someone".to_string(), - input_schema: serde_json::json!({ - "type": "object", - "properties": { - "name": {"type": "string", "description": "Name to greet"} - }, - "required": ["name"] - }), - output_schema: None, - title: None, - annotations: None, - icons: None, - execution: None, - _meta: None, - } - ], - next_cursor: None, - }) - } +use pulseengine_mcp_macros::{mcp_server, mcp_tools}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; - async fn call_tool(&self, request: CallToolRequestParam) -> Result { - match request.name.as_str() { - "hello" => { - let name = request.arguments - .and_then(|args| args.get("name")) - .and_then(|v| v.as_str()) - .unwrap_or("World"); +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct GreetParams { + pub name: Option, +} - Ok(CallToolResult::text(format!("Hello, {}!", name))) - } - _ => Err("Unknown tool".into()), - } - } +#[mcp_server(name = "My Server")] +#[derive(Default, Clone)] +pub struct MyServer; - // Simple implementations for unused features - async fn list_resources(&self, _: PaginatedRequestParam) -> Result { - Ok(ListResourcesResult { resources: vec![], next_cursor: None }) - } - async fn read_resource(&self, _: ReadResourceRequestParam) -> Result { - Err("No resources".into()) - } - async fn list_prompts(&self, _: PaginatedRequestParam) -> Result { - Ok(ListPromptsResult { prompts: vec![], next_cursor: None }) - } - async fn get_prompt(&self, _: GetPromptRequestParam) -> Result { - Err("No prompts".into()) +#[mcp_tools] +impl MyServer { + /// Greet someone by name + pub async fn greet(&self, params: GreetParams) -> anyhow::Result { + let name = params.name.unwrap_or_else(|| "World".to_string()); + Ok(format!("Hello, {name}!")) } } #[tokio::main] async fn main() -> Result<(), Box> { - let backend = MyBackend::initialize(()).await?; - let config = ServerConfig::default(); - let mut server = McpServer::new(backend, config).await?; - server.run().await?; - Ok(()) + MyServer::configure_stdio_logging(); + MyServer::with_defaults().serve_stdio().await?.run().await } ``` -## Framework Components - -### πŸ”§ [mcp-protocol](mcp-protocol/) - Core Protocol Types - -- MCP request/response types with validation -- JSON-RPC 2.0 support and error handling -- Schema validation for tool parameters -- **MCP Apps Extension support** - `ui://` resources, tool metadata, `text/html+mcp` - -### πŸ—οΈ [mcp-server](mcp-server/) - Server Infrastructure - -- Pluggable backend system via `McpBackend` trait -- Request routing and protocol compliance -- Middleware integration for auth, security, monitoring - -### πŸ“‘ [mcp-transport](mcp-transport/) - Multiple Transports - -- stdio (Claude Desktop), HTTP (web apps), WebSocket (real-time) -- MCP Inspector compatibility with content negotiation -- Session management and CORS support - -### πŸ”‘ [mcp-auth](mcp-auth/) - Authentication Framework - -- API key management with role-based access control -- Rate limiting and IP whitelisting -- Audit logging and security features +The `#[mcp_server]` and `#[mcp_tools]` macros generate the protocol implementation. Tool schemas are derived from your Rust types via `JsonSchema`. -### πŸ›‘οΈ [mcp-security](mcp-security/) - Security Middleware +## Crates -- Input validation and XSS/injection prevention -- Request size limits and parameter validation -- CORS policies and security headers - -### πŸ“ [mcp-logging](mcp-logging/) - Structured Logging - -- JSON logging with correlation IDs -- Automatic credential sanitization -- MCP `logging/setLevel` conformance -- Security audit trails - -### βš™οΈ [mcp-macros](mcp-macros/) - Procedural Macros - -- `#[mcp_server]` - Generate server boilerplate -- `#[mcp_tool]` - Define tools with schema generation -- `#[mcp_resource]` - Define parameterized resources -- `#[mcp_backend]` - Derive backend implementations +| Crate | Description | +| ------------------------------- | -------------------------------------------------- | +| [mcp-protocol](mcp-protocol/) | MCP types, JSON-RPC, schema validation | +| [mcp-server](mcp-server/) | Server infrastructure with `McpBackend` trait | +| [mcp-client](mcp-client/) | Client for connecting to MCP servers | +| [mcp-transport](mcp-transport/) | stdio, HTTP, WebSocket transports | +| [mcp-auth](mcp-auth/) | Authentication, API keys, OAuth 2.1 | +| [mcp-security](mcp-security/) | Input validation, rate limiting | +| [mcp-logging](mcp-logging/) | Structured logging with credential sanitization | +| [mcp-macros](mcp-macros/) | `#[mcp_server]`, `#[mcp_tools]`, `#[mcp_resource]` | ## Examples -### 🌍 [Hello World](examples/hello-world/) - -Complete minimal MCP server demonstrating basic concepts. - -### πŸ” [Hello World with Auth](examples/hello-world-with-auth/) - -MCP server with full authentication and authorization. - -### 🎨 [UI-Enabled Server](examples/ui-enabled-server/) - -**MCP Apps Extension demonstration** with interactive HTML interfaces: - -- Tool with UI resource link -- `ui://` URI scheme usage -- `text/html+mcp` MIME type -- Complete testing guide - -### πŸ“ [Resources Demo](examples/resources-demo/) - -Demonstrates `#[mcp_resource]` macro for parameterized resources with URI templates. - -### ⚑ [Ultra Simple](examples/ultra-simple/) - -The absolute minimum MCP server implementation. - -### 🏠 Real-World Reference: Loxone MCP Server - -The framework was extracted from a production Loxone home automation server that provides: - -- **30+ Tools** - Complete home automation control (lighting, climate, security, energy) -- **Multiple Transports** - Works with Claude Desktop, MCP Inspector, n8n workflows -- **Production Security** - API keys, rate limiting, input validation, audit logging -- **Real-Time Integration** - WebSocket support for live device status updates -- **Proven Reliability** - Handles concurrent operations and error conditions - -## Development Workflow +- [hello-world](examples/hello-world/) - Minimal server +- [hello-world-with-auth](examples/hello-world-with-auth/) - With authentication +- [resources-demo](examples/resources-demo/) - Resource templates with `#[mcp_resource]` +- [ui-enabled-server](examples/ui-enabled-server/) - MCP Apps extension (SEP-1865) -### Building the Framework +## MCP Spec -```bash -# Build all framework crates -cargo build --workspace - -# Test all crates -cargo test --workspace - -# Run examples -cargo run --bin hello-world-server -``` - -### Creating Your MCP Server - -1. **Choose Your Domain** - What system do you want to make accessible via MCP? -2. **Implement McpBackend** - Define your tools, resources, and prompts -3. **Configure Transport** - stdio for Claude Desktop, HTTP for web clients -4. **Add Security** - Authentication, validation, and monitoring as needed -5. **Deploy** - Native binary, Docker container, or WebAssembly - -## Architecture - -``` -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ MCP Clients β”‚ β”‚ Your Backend β”‚ β”‚ External Systemsβ”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β€’ Claude Desktopβ”‚ β”‚ β€’ Tools β”‚ β”‚ β€’ Databases β”‚ -β”‚ β€’ MCP Inspector │◄──►│ β€’ Resources │◄──►│ β€’ APIs β”‚ -β”‚ β€’ Web Apps β”‚ β”‚ β€’ Prompts β”‚ β”‚ β€’ File Systems β”‚ -β”‚ β€’ Custom Clientsβ”‚ β”‚ β”‚ β”‚ β€’ Hardware β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ - β”‚ β”‚ - β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” - β”‚ β”‚ MCP Framework β”‚ - β”‚ β”‚ β”‚ - └──────────────►│ β€’ Protocol β”‚ - β”‚ β€’ Transport β”‚ - β”‚ β€’ Security β”‚ - β”‚ β€’ Monitoring β”‚ - β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -``` - -## Contributing - -This framework grows from real-world usage. The most valuable contributions come from: - -1. **New Backend Examples** - Show how to integrate different types of systems -2. **Production Patterns** - Share patterns from your own MCP server deployments -3. **Client Compatibility** - Test with different MCP clients and report issues -4. **Performance Improvements** - Optimizations based on real usage patterns -5. **Security Enhancements** - Better validation, authentication, or audit capabilities - -## Community - -- **Documentation** - [docs.rs/pulseengine-mcp-protocol](https://docs.rs/pulseengine-mcp-protocol) -- **Issues** - [GitHub Issues](https://github.com/pulseengine/mcp/issues) -- **Discussions** - [GitHub Discussions](https://github.com/pulseengine/mcp/discussions) +Implements MCP 2025-11-25: tools, resources, prompts, completions, sampling, roots, logging, progress, cancellation, tasks, and elicitation. ## License -Licensed under either of: - -- Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE)) -- MIT license ([LICENSE-MIT](LICENSE-MIT)) - -at your option. - ---- - -**Built by developers who needed a robust MCP framework for real production use.** Start building your MCP server today with confidence that the foundation has been proven in demanding real-world scenarios. +MIT OR Apache-2.0 diff --git a/mcp-client/Cargo.toml b/mcp-client/Cargo.toml new file mode 100644 index 00000000..e6002390 --- /dev/null +++ b/mcp-client/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "pulseengine-mcp-client" +description = "MCP client implementation for connecting to MCP servers" +version.workspace = true +rust-version.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true +repository.workspace = true +homepage.workspace = true +documentation = "https://docs.rs/pulseengine-mcp-client" +keywords = ["mcp", "client", "protocol", "ai", "rust"] +categories = ["api-bindings", "development-tools", "asynchronous"] + +[dependencies] +# Core dependencies from workspace +tokio = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +async-trait = { workspace = true } +thiserror = { workspace = true } +tracing = { workspace = true } +futures = { workspace = true } + +# Internal dependencies +pulseengine-mcp-protocol = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["test-util", "macros", "rt-multi-thread"] } + +[lints] +workspace = true diff --git a/mcp-client/src/client.rs b/mcp-client/src/client.rs new file mode 100644 index 00000000..1e02459a --- /dev/null +++ b/mcp-client/src/client.rs @@ -0,0 +1,466 @@ +//! MCP Client implementation +//! +//! The main client struct for interacting with MCP servers. + +use crate::error::{ClientError, ClientResult}; +use crate::transport::{ClientTransport, JsonRpcMessage, next_request_id}; +use pulseengine_mcp_protocol::{ + CallToolRequestParam, CallToolResult, CompleteRequestParam, CompleteResult, + GetPromptRequestParam, GetPromptResult, Implementation, InitializeRequestParam, + InitializeResult, ListPromptsResult, ListResourceTemplatesResult, ListResourcesResult, + ListToolsResult, NumberOrString, PaginatedRequestParam, ReadResourceRequestParam, + ReadResourceResult, Request, Response, +}; +use serde_json::json; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::{Mutex, oneshot}; +use tracing::{debug, info, warn}; + +/// Default timeout for requests +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); + +/// MCP Client for connecting to MCP servers +/// +/// Provides a high-level API for interacting with MCP servers, +/// handling request/response correlation and protocol details. +pub struct McpClient { + transport: Arc, + /// Pending requests waiting for responses + pending: Arc>>>, + /// Server info after initialization + server_info: Option, + /// Default request timeout + timeout: Duration, + /// Client info sent during initialization + client_info: Implementation, +} + +impl McpClient { + /// Create a new MCP client with the given transport + pub fn new(transport: T) -> Self { + Self { + transport: Arc::new(transport), + pending: Arc::new(Mutex::new(HashMap::new())), + server_info: None, + timeout: DEFAULT_TIMEOUT, + client_info: Implementation::new("pulseengine-mcp-client", env!("CARGO_PKG_VERSION")), + } + } + + /// Set the default request timeout + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + /// Set the client info for initialization + pub fn with_client_info(mut self, name: &str, version: &str) -> Self { + self.client_info = Implementation::new(name, version); + self + } + + /// Get the server info (available after initialization) + pub fn server_info(&self) -> Option<&InitializeResult> { + self.server_info.as_ref() + } + + /// Check if the client has been initialized + pub fn is_initialized(&self) -> bool { + self.server_info.is_some() + } + + /// Initialize the connection with the server + /// + /// This must be called before any other methods. + pub async fn initialize( + &mut self, + client_name: &str, + client_version: &str, + ) -> ClientResult { + self.client_info = Implementation::new(client_name, client_version); + + let params = InitializeRequestParam { + protocol_version: pulseengine_mcp_protocol::MCP_VERSION.to_string(), + capabilities: json!({}), // Empty capabilities - server will respond with its capabilities + client_info: self.client_info.clone(), + }; + + let result: InitializeResult = self.request("initialize", params).await?; + + info!( + "Initialized with server: {} v{}", + result.server_info.name, result.server_info.version + ); + + self.server_info = Some(result.clone()); + + // Send initialized notification + self.notify("notifications/initialized", json!({})).await?; + + Ok(result) + } + + // ========================================================================= + // Tools API + // ========================================================================= + + /// List available tools from the server + pub async fn list_tools(&self) -> ClientResult { + self.ensure_initialized()?; + self.request("tools/list", PaginatedRequestParam { cursor: None }) + .await + } + + /// List all tools, automatically handling pagination + pub async fn list_all_tools(&self) -> ClientResult> { + self.ensure_initialized()?; + let mut all_tools = Vec::new(); + let mut cursor = None; + + loop { + let result: ListToolsResult = self + .request("tools/list", PaginatedRequestParam { cursor }) + .await?; + + all_tools.extend(result.tools); + + match result.next_cursor { + Some(next) => cursor = Some(next), + None => break, + } + } + + Ok(all_tools) + } + + /// Call a tool on the server + pub async fn call_tool( + &self, + name: &str, + arguments: serde_json::Value, + ) -> ClientResult { + self.ensure_initialized()?; + self.request( + "tools/call", + CallToolRequestParam { + name: name.to_string(), + arguments: Some(arguments), + }, + ) + .await + } + + // ========================================================================= + // Resources API + // ========================================================================= + + /// List available resources from the server + pub async fn list_resources(&self) -> ClientResult { + self.ensure_initialized()?; + self.request("resources/list", PaginatedRequestParam { cursor: None }) + .await + } + + /// List all resources, automatically handling pagination + pub async fn list_all_resources( + &self, + ) -> ClientResult> { + self.ensure_initialized()?; + let mut all_resources = Vec::new(); + let mut cursor = None; + + loop { + let result: ListResourcesResult = self + .request("resources/list", PaginatedRequestParam { cursor }) + .await?; + + all_resources.extend(result.resources); + + match result.next_cursor { + Some(next) => cursor = Some(next), + None => break, + } + } + + Ok(all_resources) + } + + /// Read a resource from the server + pub async fn read_resource(&self, uri: &str) -> ClientResult { + self.ensure_initialized()?; + self.request( + "resources/read", + ReadResourceRequestParam { + uri: uri.to_string(), + }, + ) + .await + } + + /// List resource templates from the server + pub async fn list_resource_templates(&self) -> ClientResult { + self.ensure_initialized()?; + self.request( + "resources/templates/list", + PaginatedRequestParam { cursor: None }, + ) + .await + } + + // ========================================================================= + // Prompts API + // ========================================================================= + + /// List available prompts from the server + pub async fn list_prompts(&self) -> ClientResult { + self.ensure_initialized()?; + self.request("prompts/list", PaginatedRequestParam { cursor: None }) + .await + } + + /// List all prompts, automatically handling pagination + pub async fn list_all_prompts(&self) -> ClientResult> { + self.ensure_initialized()?; + let mut all_prompts = Vec::new(); + let mut cursor = None; + + loop { + let result: ListPromptsResult = self + .request("prompts/list", PaginatedRequestParam { cursor }) + .await?; + + all_prompts.extend(result.prompts); + + match result.next_cursor { + Some(next) => cursor = Some(next), + None => break, + } + } + + Ok(all_prompts) + } + + /// Get a prompt by name + pub async fn get_prompt( + &self, + name: &str, + arguments: Option>, + ) -> ClientResult { + self.ensure_initialized()?; + self.request( + "prompts/get", + GetPromptRequestParam { + name: name.to_string(), + arguments, + }, + ) + .await + } + + // ========================================================================= + // Completion API + // ========================================================================= + + /// Request completion suggestions + pub async fn complete(&self, params: CompleteRequestParam) -> ClientResult { + self.ensure_initialized()?; + self.request("completion/complete", params).await + } + + // ========================================================================= + // Utility Methods + // ========================================================================= + + /// Send a ping to the server + pub async fn ping(&self) -> ClientResult<()> { + self.ensure_initialized()?; + let _: serde_json::Value = self.request("ping", json!({})).await?; + Ok(()) + } + + /// Close the client connection + pub async fn close(&self) -> ClientResult<()> { + self.transport.close().await + } + + // ========================================================================= + // Notification Methods + // ========================================================================= + + /// Send a progress notification + pub async fn notify_progress( + &self, + progress_token: &str, + progress: f64, + total: Option, + ) -> ClientResult<()> { + self.notify( + "notifications/progress", + json!({ + "progressToken": progress_token, + "progress": progress, + "total": total, + }), + ) + .await + } + + /// Send a cancellation notification + pub async fn notify_cancelled( + &self, + request_id: &str, + reason: Option<&str>, + ) -> ClientResult<()> { + self.notify( + "notifications/cancelled", + json!({ + "requestId": request_id, + "reason": reason, + }), + ) + .await + } + + /// Send a roots list changed notification + pub async fn notify_roots_list_changed(&self) -> ClientResult<()> { + self.notify("notifications/roots/list_changed", json!({})) + .await + } + + // ========================================================================= + // Internal Methods + // ========================================================================= + + /// Ensure the client has been initialized + fn ensure_initialized(&self) -> ClientResult<()> { + if self.server_info.is_none() { + return Err(ClientError::NotInitialized); + } + Ok(()) + } + + /// Send a request and wait for the response + async fn request(&self, method: &str, params: P) -> ClientResult + where + P: serde::Serialize, + R: serde::de::DeserializeOwned, + { + let id = next_request_id(); + let id_str = match &id { + NumberOrString::Number(n) => n.to_string(), + NumberOrString::String(s) => s.to_string(), + }; + + let request = Request { + jsonrpc: "2.0".to_string(), + method: method.to_string(), + params: serde_json::to_value(params)?, + id: Some(id), + }; + + // Create channel for response + let (tx, rx) = oneshot::channel(); + + // Register pending request + { + let mut pending = self.pending.lock().await; + pending.insert(id_str.clone(), tx); + } + + // Send request + self.transport.send(&request).await?; + + debug!("Sent request: method={}, id={}", method, id_str); + + // Wait for response with timeout + let response = tokio::select! { + result = self.wait_for_response(rx) => result?, + _ = tokio::time::sleep(self.timeout) => { + // Remove from pending on timeout + let mut pending = self.pending.lock().await; + pending.remove(&id_str); + return Err(ClientError::Timeout(self.timeout)); + } + }; + + // Check for error response + if let Some(error) = response.error { + return Err(ClientError::from_protocol_error(error)); + } + + // Parse result + let result = response + .result + .ok_or_else(|| ClientError::protocol("Response has no result or error"))?; + + serde_json::from_value(result).map_err(ClientError::from) + } + + /// Wait for a response and handle incoming messages + async fn wait_for_response( + &self, + mut rx: oneshot::Receiver, + ) -> ClientResult { + // In a simple implementation, we just read messages until we get our response + // A more sophisticated implementation would use a background task + loop { + tokio::select! { + biased; + + // Check if response arrived via channel (priority) + result = &mut rx => { + return result.map_err(|_| ClientError::ChannelClosed("Response channel closed".into())); + } + // Read next message from transport + msg = self.transport.recv() => { + match msg? { + JsonRpcMessage::Response(response) => { + // Route response to waiting request + let id_str = response.id.as_ref().map(|id| match id { + NumberOrString::Number(n) => n.to_string(), + NumberOrString::String(s) => s.to_string(), + }); + + if let Some(id) = id_str { + let mut pending = self.pending.lock().await; + if let Some(tx) = pending.remove(&id) { + let _ = tx.send(response); + } else { + warn!("Received response for unknown request: {}", id); + } + } + } + JsonRpcMessage::Request(request) => { + // Handle server-initiated request (sampling, etc.) + // For now, log and continue - could add a handler callback + warn!("Received server request (not yet handled): {}", request.method); + } + JsonRpcMessage::Notification { method, params: _ } => { + // Handle notification from server + debug!("Received notification: {}", method); + } + } + } + } + } + } + + /// Send a notification (no response expected) + async fn notify

(&self, method: &str, params: P) -> ClientResult<()> + where + P: serde::Serialize, + { + let request = Request { + jsonrpc: "2.0".to_string(), + method: method.to_string(), + params: serde_json::to_value(params)?, + id: None, // No ID for notifications + }; + + self.transport.send(&request).await?; + debug!("Sent notification: method={}", method); + Ok(()) + } +} diff --git a/mcp-client/src/client_tests.rs b/mcp-client/src/client_tests.rs new file mode 100644 index 00000000..82aa7d7a --- /dev/null +++ b/mcp-client/src/client_tests.rs @@ -0,0 +1,141 @@ +//! Tests for MCP client + +use crate::client::McpClient; +use crate::error::ClientError; +use crate::transport::{JsonRpcMessage, StdioClientTransport}; +use std::time::Duration; +use tokio::io::{DuplexStream, duplex}; + +/// Create a mock transport for testing +fn create_mock_transport() -> ( + StdioClientTransport, + DuplexStream, + DuplexStream, +) { + let (client_read, server_write) = duplex(1024); + let (server_read, client_write) = duplex(1024); + + let transport = StdioClientTransport::new(client_read, client_write); + + (transport, server_read, server_write) +} + +#[tokio::test] +async fn test_client_creation() { + let (transport, _server_read, _server_write) = create_mock_transport(); + let client = McpClient::new(transport); + + assert!(!client.is_initialized()); + assert!(client.server_info().is_none()); +} + +#[tokio::test] +async fn test_client_not_initialized_error() { + let (transport, _server_read, _server_write) = create_mock_transport(); + let client = McpClient::new(transport); + + // Trying to list tools without initialization should fail + let result = client.list_tools().await; + assert!(matches!(result, Err(ClientError::NotInitialized))); +} + +#[tokio::test] +async fn test_client_with_timeout() { + let (transport, _server_read, _server_write) = create_mock_transport(); + let client = McpClient::new(transport).with_timeout(Duration::from_secs(60)); + + // Timeout is set internally, verify client was created + assert!(!client.is_initialized()); +} + +#[tokio::test] +async fn test_client_with_client_info() { + let (transport, _server_read, _server_write) = create_mock_transport(); + let client = McpClient::new(transport).with_client_info("test-client", "1.0.0"); + + // Client info is set internally, we verify it works by checking the client was created + assert!(!client.is_initialized()); +} + +#[test] +fn test_json_rpc_message_parse_response() { + let json = r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#; + let msg = JsonRpcMessage::parse(json).unwrap(); + assert!(matches!(msg, JsonRpcMessage::Response(_))); + + if let JsonRpcMessage::Response(resp) = msg { + assert!(resp.result.is_some()); + assert!(resp.error.is_none()); + } +} + +#[test] +fn test_json_rpc_message_parse_error_response() { + let json = r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid Request"}}"#; + let msg = JsonRpcMessage::parse(json).unwrap(); + assert!(matches!(msg, JsonRpcMessage::Response(_))); + + if let JsonRpcMessage::Response(resp) = msg { + assert!(resp.result.is_none()); + assert!(resp.error.is_some()); + } +} + +#[test] +fn test_json_rpc_message_parse_request() { + let json = r#"{"jsonrpc":"2.0","method":"sampling/createMessage","params":{},"id":"req-1"}"#; + let msg = JsonRpcMessage::parse(json).unwrap(); + assert!(matches!(msg, JsonRpcMessage::Request(_))); + + if let JsonRpcMessage::Request(req) = msg { + assert_eq!(req.method, "sampling/createMessage"); + } +} + +#[test] +fn test_json_rpc_message_parse_notification() { + let json = r#"{"jsonrpc":"2.0","method":"notifications/progress","params":{"progress":50}}"#; + let msg = JsonRpcMessage::parse(json).unwrap(); + + if let JsonRpcMessage::Notification { method, params } = msg { + assert_eq!(method, "notifications/progress"); + assert_eq!(params["progress"], 50); + } else { + panic!("Expected notification"); + } +} + +#[test] +fn test_json_rpc_message_parse_invalid() { + let json = r#"{"jsonrpc":"2.0"}"#; + let result = JsonRpcMessage::parse(json); + assert!(result.is_err()); +} + +#[test] +fn test_client_error_display() { + let err = ClientError::NotInitialized; + assert_eq!( + err.to_string(), + "Client not initialized - call initialize() first" + ); + + let err = ClientError::Timeout(Duration::from_secs(30)); + assert!(err.to_string().contains("30")); + + let err = ClientError::ServerError { + code: -32600, + message: "Invalid Request".to_string(), + data: None, + }; + assert!(err.to_string().contains("-32600")); + assert!(err.to_string().contains("Invalid Request")); +} + +#[test] +fn test_client_error_is_retryable() { + assert!(ClientError::Timeout(Duration::from_secs(1)).is_retryable()); + assert!(ClientError::Transport("connection lost".to_string()).is_retryable()); + assert!(!ClientError::NotInitialized.is_retryable()); + assert!(!ClientError::Protocol("invalid".to_string()).is_retryable()); +} diff --git a/mcp-client/src/error.rs b/mcp-client/src/error.rs new file mode 100644 index 00000000..a4dd0c0c --- /dev/null +++ b/mcp-client/src/error.rs @@ -0,0 +1,87 @@ +//! Error types for MCP client operations + +use pulseengine_mcp_protocol::Error as ProtocolError; +use thiserror::Error; + +/// Result type alias for client operations +pub type ClientResult = std::result::Result; + +/// Errors that can occur during MCP client operations +#[derive(Debug, Error)] +pub enum ClientError { + /// Transport-level errors (I/O, connection) + #[error("Transport error: {0}")] + Transport(String), + + /// Protocol errors (invalid JSON-RPC, parse errors) + #[error("Protocol error: {0}")] + Protocol(String), + + /// Server returned an error response + #[error("Server error: {message} (code: {code})")] + ServerError { + /// Error code from server + code: i32, + /// Error message from server + message: String, + /// Optional additional data + data: Option, + }, + + /// Request timed out + #[error("Request timed out after {0:?}")] + Timeout(std::time::Duration), + + /// Client not initialized (must call initialize first) + #[error("Client not initialized - call initialize() first")] + NotInitialized, + + /// Response ID mismatch + #[error("Response ID mismatch: expected {expected}, got {actual}")] + IdMismatch { + /// Expected request ID + expected: String, + /// Actual response ID + actual: String, + }, + + /// Channel closed unexpectedly + #[error("Channel closed: {0}")] + ChannelClosed(String), + + /// Serialization/deserialization error + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::Error), +} + +impl ClientError { + /// Create a transport error + pub fn transport(msg: impl Into) -> Self { + Self::Transport(msg.into()) + } + + /// Create a protocol error + pub fn protocol(msg: impl Into) -> Self { + Self::Protocol(msg.into()) + } + + /// Create from a protocol error response + pub fn from_protocol_error(err: ProtocolError) -> Self { + Self::ServerError { + code: err.code as i32, + message: err.message, + data: err.data, + } + } + + /// Check if this is a retryable error + pub fn is_retryable(&self) -> bool { + matches!(self, Self::Timeout(_) | Self::Transport(_)) + } +} + +impl From for ClientError { + fn from(err: std::io::Error) -> Self { + Self::Transport(err.to_string()) + } +} diff --git a/mcp-client/src/lib.rs b/mcp-client/src/lib.rs new file mode 100644 index 00000000..dd618977 --- /dev/null +++ b/mcp-client/src/lib.rs @@ -0,0 +1,81 @@ +//! MCP Client Implementation +//! +//! This crate provides a client for connecting to MCP (Model Context Protocol) servers. +//! It enables programmatic interaction with MCP servers for testing, proxying, and +//! building multi-hop MCP architectures. +//! +//! # Quick Start +//! +//! ```rust,ignore +//! use pulseengine_mcp_client::{McpClient, StdioClientTransport}; +//! use tokio::process::Command; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box> { +//! // Spawn an MCP server as a child process +//! let mut child = Command::new("my-mcp-server") +//! .stdin(std::process::Stdio::piped()) +//! .stdout(std::process::Stdio::piped()) +//! .spawn()?; +//! +//! // Create transport from child process streams +//! let stdin = child.stdin.take().unwrap(); +//! let stdout = child.stdout.take().unwrap(); +//! let transport = StdioClientTransport::new(stdin, stdout); +//! +//! // Create and initialize client +//! let mut client = McpClient::new(transport); +//! let server_info = client.initialize("my-client", "1.0.0").await?; +//! println!("Connected to: {}", server_info.server_info.name); +//! +//! // Use the server +//! let tools = client.list_tools().await?; +//! for tool in tools.tools { +//! println!("Tool: {}", tool.name); +//! } +//! +//! Ok(()) +//! } +//! ``` + +mod client; +mod error; +mod transport; + +#[cfg(test)] +mod client_tests; +#[cfg(test)] +mod transport_tests; + +pub use client::McpClient; +pub use error::{ClientError, ClientResult}; +pub use transport::{ClientTransport, StdioClientTransport}; + +// Re-export protocol types for convenience +pub use pulseengine_mcp_protocol::{ + // Tools + CallToolRequestParam, + CallToolResult, + // Completions + CompleteRequestParam, + CompleteResult, + // Prompts + GetPromptRequestParam, + GetPromptResult, + // Core types + Implementation, + InitializeResult, + ListPromptsResult, + // Resources + ListResourceTemplatesResult, + ListResourcesResult, + ListToolsResult, + Prompt, + ReadResourceRequestParam, + ReadResourceResult, + Resource, + ResourceTemplate, + // Capabilities + ServerCapabilities, + Tool, +}; diff --git a/mcp-client/src/transport.rs b/mcp-client/src/transport.rs new file mode 100644 index 00000000..1ad6ed2a --- /dev/null +++ b/mcp-client/src/transport.rs @@ -0,0 +1,240 @@ +//! Transport layer for MCP client +//! +//! Provides abstractions for bidirectional communication with MCP servers. + +use crate::error::{ClientError, ClientResult}; +use async_trait::async_trait; +use pulseengine_mcp_protocol::{NumberOrString, Request, Response}; +use std::sync::Arc; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::sync::Mutex; +use tracing::{debug, trace}; + +/// Trait for client-side MCP transport +/// +/// This trait abstracts the underlying communication mechanism (stdio, WebSocket, etc.) +/// and provides a simple interface for sending requests and receiving responses. +#[async_trait] +pub trait ClientTransport: Send + Sync { + /// Send a JSON-RPC request to the server + async fn send(&self, request: &Request) -> ClientResult<()>; + + /// Receive the next message from the server + /// + /// This may be a response to a previous request or a server-initiated request. + async fn recv(&self) -> ClientResult; + + /// Close the transport + async fn close(&self) -> ClientResult<()>; +} + +/// A JSON-RPC message that can be either a request or response +#[derive(Debug, Clone)] +pub enum JsonRpcMessage { + /// A response to a previous request + Response(Response), + /// A request from the server (for sampling, roots/list, etc.) + Request(Request), + /// A notification (no response expected) + Notification { + /// The notification method + method: String, + /// The notification parameters + params: serde_json::Value, + }, +} + +impl JsonRpcMessage { + /// Parse a JSON string into a JsonRpcMessage + pub fn parse(json: &str) -> ClientResult { + let value: serde_json::Value = serde_json::from_str(json)?; + + // Check if it's a response (has result or error, no method) + if value.get("result").is_some() || value.get("error").is_some() { + let response: Response = serde_json::from_value(value)?; + return Ok(Self::Response(response)); + } + + // Check if it has a method (request or notification) + if let Some(method) = value.get("method").and_then(|m| m.as_str()) { + // If it has an id, it's a request; otherwise notification + if value.get("id").is_some() && !value.get("id").unwrap().is_null() { + let request: Request = serde_json::from_value(value)?; + return Ok(Self::Request(request)); + } else { + let params = value + .get("params") + .cloned() + .unwrap_or(serde_json::Value::Null); + return Ok(Self::Notification { + method: method.to_string(), + params, + }); + } + } + + Err(ClientError::protocol( + "Invalid JSON-RPC message: no method, result, or error", + )) + } +} + +/// Standard I/O transport for MCP client +/// +/// Communicates with an MCP server via stdin/stdout streams. +/// Typically used with child process spawning. +pub struct StdioClientTransport +where + R: tokio::io::AsyncRead + Unpin + Send, + W: tokio::io::AsyncWrite + Unpin + Send, +{ + reader: Arc>>, + writer: Arc>, +} + +impl StdioClientTransport +where + R: tokio::io::AsyncRead + Unpin + Send, + W: tokio::io::AsyncWrite + Unpin + Send, +{ + /// Create a new stdio transport from read and write streams + /// + /// # Arguments + /// * `reader` - The input stream (typically child process stdout) + /// * `writer` - The output stream (typically child process stdin) + pub fn new(reader: R, writer: W) -> Self { + Self { + reader: Arc::new(Mutex::new(BufReader::new(reader))), + writer: Arc::new(Mutex::new(writer)), + } + } +} + +#[async_trait] +impl ClientTransport for StdioClientTransport +where + R: tokio::io::AsyncRead + Unpin + Send + 'static, + W: tokio::io::AsyncWrite + Unpin + Send + 'static, +{ + async fn send(&self, request: &Request) -> ClientResult<()> { + let json = serde_json::to_string(request)?; + + // Validate: no embedded newlines (MCP spec) + if json.contains('\n') || json.contains('\r') { + return Err(ClientError::protocol( + "Request contains embedded newlines, which is not allowed by MCP spec", + )); + } + + trace!("Sending request: {}", json); + + let mut writer = self.writer.lock().await; + writer + .write_all(json.as_bytes()) + .await + .map_err(|e| ClientError::transport(format!("Failed to write: {e}")))?; + writer + .write_all(b"\n") + .await + .map_err(|e| ClientError::transport(format!("Failed to write newline: {e}")))?; + writer + .flush() + .await + .map_err(|e| ClientError::transport(format!("Failed to flush: {e}")))?; + + debug!( + "Sent request: method={}, id={:?}", + request.method, request.id + ); + Ok(()) + } + + async fn recv(&self) -> ClientResult { + let mut reader = self.reader.lock().await; + let mut line = String::new(); + + loop { + line.clear(); + let bytes_read = reader + .read_line(&mut line) + .await + .map_err(|e| ClientError::transport(format!("Failed to read: {e}")))?; + + if bytes_read == 0 { + return Err(ClientError::transport("EOF: server closed connection")); + } + + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; // Skip empty lines + } + + trace!("Received message: {}", trimmed); + return JsonRpcMessage::parse(trimmed); + } + } + + async fn close(&self) -> ClientResult<()> { + // For stdio, we just flush and let the streams drop + let mut writer = self.writer.lock().await; + writer + .flush() + .await + .map_err(|e| ClientError::transport(format!("Failed to flush on close: {e}")))?; + Ok(()) + } +} + +/// Create a request ID for tracking +pub fn next_request_id() -> NumberOrString { + use std::sync::atomic::{AtomicU64, Ordering}; + static COUNTER: AtomicU64 = AtomicU64::new(1); + NumberOrString::Number(COUNTER.fetch_add(1, Ordering::Relaxed) as i64) +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_parse_response() { + let json = r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#; + let msg = JsonRpcMessage::parse(json).unwrap(); + assert!(matches!(msg, JsonRpcMessage::Response(_))); + } + + #[test] + fn test_parse_error_response() { + let json = r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid"}}"#; + let msg = JsonRpcMessage::parse(json).unwrap(); + assert!(matches!(msg, JsonRpcMessage::Response(_))); + } + + #[test] + fn test_parse_request() { + let json = + r#"{"jsonrpc":"2.0","method":"sampling/createMessage","params":{},"id":"req-1"}"#; + let msg = JsonRpcMessage::parse(json).unwrap(); + assert!(matches!(msg, JsonRpcMessage::Request(_))); + } + + #[test] + fn test_parse_notification() { + let json = + r#"{"jsonrpc":"2.0","method":"notifications/progress","params":{"progress":50}}"#; + let msg = JsonRpcMessage::parse(json).unwrap(); + assert!(matches!(msg, JsonRpcMessage::Notification { .. })); + } + + #[test] + fn test_next_request_id() { + let id1 = next_request_id(); + let id2 = next_request_id(); + + // IDs should be sequential + if let (NumberOrString::Number(n1), NumberOrString::Number(n2)) = (id1, id2) { + assert_eq!(n2, n1 + 1); + } else { + panic!("Expected numeric IDs"); + } + } +} diff --git a/mcp-client/src/transport_tests.rs b/mcp-client/src/transport_tests.rs new file mode 100644 index 00000000..a3b8b284 --- /dev/null +++ b/mcp-client/src/transport_tests.rs @@ -0,0 +1,211 @@ +//! Tests for MCP client transport + +use super::transport::*; +use pulseengine_mcp_protocol::{NumberOrString, Request}; +use serde_json::json; +use tokio::io::{AsyncWriteExt, duplex}; + +#[tokio::test] +async fn test_stdio_transport_send() { + let (client_read, _server_write) = duplex(1024); + let (server_read, client_write) = duplex(1024); + + let transport = StdioClientTransport::new(client_read, client_write); + + let request = Request { + jsonrpc: "2.0".to_string(), + method: "test".to_string(), + params: json!({}), + id: Some(NumberOrString::Number(1)), + }; + + // Send request + transport.send(&request).await.unwrap(); + + // Read from "server" side + let mut reader = tokio::io::BufReader::new(server_read); + use tokio::io::AsyncBufReadExt; + let mut line = String::new(); + reader.read_line(&mut line).await.unwrap(); + + // Verify the message + assert!(line.contains("\"method\":\"test\"")); + assert!(line.contains("\"id\":1")); +} + +#[tokio::test] +async fn test_stdio_transport_recv_response() { + let (client_read, mut server_write) = duplex(1024); + let (_server_read, client_write) = duplex(1024); + + let transport = StdioClientTransport::new(client_read, client_write); + + // Server sends a response + let response_json = r#"{"jsonrpc":"2.0","id":1,"result":{"status":"ok"}}"#; + server_write + .write_all(format!("{response_json}\n").as_bytes()) + .await + .unwrap(); + server_write.flush().await.unwrap(); + + // Client receives it + let msg = transport.recv().await.unwrap(); + + match msg { + JsonRpcMessage::Response(resp) => { + assert_eq!(resp.id, Some(NumberOrString::Number(1))); + assert!(resp.result.is_some()); + } + _ => panic!("Expected Response"), + } +} + +#[tokio::test] +async fn test_stdio_transport_recv_notification() { + let (client_read, mut server_write) = duplex(1024); + let (_server_read, client_write) = duplex(1024); + + let transport = StdioClientTransport::new(client_read, client_write); + + // Server sends a notification + let notification_json = + r#"{"jsonrpc":"2.0","method":"notifications/progress","params":{"progress":50}}"#; + server_write + .write_all(format!("{notification_json}\n").as_bytes()) + .await + .unwrap(); + server_write.flush().await.unwrap(); + + // Client receives it + let msg = transport.recv().await.unwrap(); + + match msg { + JsonRpcMessage::Notification { method, params } => { + assert_eq!(method, "notifications/progress"); + assert_eq!(params["progress"], 50); + } + _ => panic!("Expected Notification"), + } +} + +#[tokio::test] +async fn test_stdio_transport_recv_request() { + let (client_read, mut server_write) = duplex(1024); + let (_server_read, client_write) = duplex(1024); + + let transport = StdioClientTransport::new(client_read, client_write); + + // Server sends a request (e.g., sampling/createMessage) + let request_json = + r#"{"jsonrpc":"2.0","method":"sampling/createMessage","params":{},"id":"srv-1"}"#; + server_write + .write_all(format!("{request_json}\n").as_bytes()) + .await + .unwrap(); + server_write.flush().await.unwrap(); + + // Client receives it + let msg = transport.recv().await.unwrap(); + + match msg { + JsonRpcMessage::Request(req) => { + assert_eq!(req.method, "sampling/createMessage"); + assert!(req.id.is_some()); + } + _ => panic!("Expected Request"), + } +} + +#[tokio::test] +async fn test_stdio_transport_close() { + let (client_read, _server_write) = duplex(1024); + let (_server_read, client_write) = duplex(1024); + + let transport = StdioClientTransport::new(client_read, client_write); + + // Close should succeed + transport.close().await.unwrap(); +} + +#[tokio::test] +async fn test_stdio_transport_skip_empty_lines() { + let (client_read, mut server_write) = duplex(1024); + let (_server_read, client_write) = duplex(1024); + + let transport = StdioClientTransport::new(client_read, client_write); + + // Server sends empty lines then a response + server_write.write_all(b"\n\n").await.unwrap(); + let response_json = r#"{"jsonrpc":"2.0","id":1,"result":{}}"#; + server_write + .write_all(format!("{response_json}\n").as_bytes()) + .await + .unwrap(); + server_write.flush().await.unwrap(); + + // Client should skip empty lines and receive the response + let msg = transport.recv().await.unwrap(); + assert!(matches!(msg, JsonRpcMessage::Response(_))); +} + +#[tokio::test] +async fn test_request_id_generation() { + let id1 = next_request_id(); + let id2 = next_request_id(); + let id3 = next_request_id(); + + // All IDs should be different + let ids: Vec = vec![id1, id2, id3] + .into_iter() + .map(|id| match id { + NumberOrString::Number(n) => n.to_string(), + NumberOrString::String(s) => s.to_string(), + }) + .collect(); + + assert_ne!(ids[0], ids[1]); + assert_ne!(ids[1], ids[2]); + assert_ne!(ids[0], ids[2]); +} + +#[test] +fn test_json_rpc_message_parse_variants() { + // Test all message types + + // Response with result + let msg = JsonRpcMessage::parse(r#"{"jsonrpc":"2.0","id":1,"result":{"ok":true}}"#).unwrap(); + assert!(matches!(msg, JsonRpcMessage::Response(_))); + + // Response with error (use a valid error code) + let msg = JsonRpcMessage::parse( + r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"err"}}"#, + ) + .unwrap(); + assert!(matches!(msg, JsonRpcMessage::Response(_))); + + // Request (has method and id) + let msg = + JsonRpcMessage::parse(r#"{"jsonrpc":"2.0","method":"test","params":{},"id":"x"}"#).unwrap(); + assert!(matches!(msg, JsonRpcMessage::Request(_))); + + // Notification (has method but no id) + let msg = JsonRpcMessage::parse(r#"{"jsonrpc":"2.0","method":"notify","params":{}}"#).unwrap(); + assert!(matches!(msg, JsonRpcMessage::Notification { .. })); + + // Notification with null id (treated as notification) + let msg = JsonRpcMessage::parse(r#"{"jsonrpc":"2.0","method":"notify","params":{},"id":null}"#) + .unwrap(); + assert!(matches!(msg, JsonRpcMessage::Notification { .. })); +} + +#[test] +fn test_json_rpc_message_parse_errors() { + // Invalid JSON + assert!(JsonRpcMessage::parse("not json").is_err()); + + // Missing required fields + assert!(JsonRpcMessage::parse(r#"{"jsonrpc":"2.0"}"#).is_err()); + + // Empty object + assert!(JsonRpcMessage::parse(r#"{}"#).is_err()); +}