diff --git a/.gitignore b/.gitignore index a717b5b5..57912b16 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,5 @@ venv package-lock.json *.mjs site/* -backend/director/downloads \ No newline at end of file +backend/director/downloads +mcp_servers.json diff --git a/backend/director/constants.py b/backend/director/constants.py index f370aac9..17c7b0ad 100644 --- a/backend/director/constants.py +++ b/backend/director/constants.py @@ -29,3 +29,5 @@ class EnvPrefix(str, Enum): ANTHROPIC_ = "ANTHROPIC_" DOWNLOADS_PATH="director/downloads" + +MCP_SERVER_CONFIG_PATH="mcp_servers.json" diff --git a/backend/director/core/mcp_client.py b/backend/director/core/mcp_client.py new file mode 100644 index 00000000..03f71ec7 --- /dev/null +++ b/backend/director/core/mcp_client.py @@ -0,0 +1,167 @@ +import json +import logging +import os +from typing import AsyncGenerator +from contextlib import AsyncExitStack, asynccontextmanager +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client, get_default_environment +from mcp.client.sse import sse_client +from director.agents.base import AgentResponse, AgentStatus +from director.constants import MCP_SERVER_CONFIG_PATH + +logger = logging.getLogger(__name__) + + +class MCPClient: + def __init__(self): + self.config_path = MCP_SERVER_CONFIG_PATH + self.servers = self.load_servers() + self.mcp_tools = [] + self.exit_stack = AsyncExitStack() + + def load_servers(self): + if not os.path.exists(self.config_path): + return {} + + with open(self.config_path, 'r') as file: + return json.load(file).get('mcpServers', {}) + + @asynccontextmanager + async def create_session( + self, + server_name, + config + ) -> AsyncGenerator[ClientSession, None]: + if server_name not in self.servers: + raise ValueError(f"Server '{server_name}' not found in configuration.") + + if config.get("transport") == "stdio": + if not config.get("command") or not config.get("args"): + raise ValueError( + f"Command and args are required for stdio transport: {server_name}" + ) + + server_params = StdioServerParameters( + command=config["command"], + args=config["args"], + env={**get_default_environment(), **config.get("env", {})}, + ) + + logger.info(f"{server_name}: Initializing stdio transport with {server_params}") + + stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) + read_stream, write_stream = stdio_transport + + session = await self.exit_stack.enter_async_context( + ClientSession(read_stream, write_stream) + ) + + await session.initialize() + logger.info(f"{server_name}: Connected to server using stdio transport.") + + try: + yield session + finally: + logger.debug(f"{server_name}: Closing session.") + + elif config.get("transport") == "sse": + if not config.get("url"): + raise ValueError(f"URL is required for SSE transport: {server_name}") + + read_stream, write_stream = await self.exit_stack.enter_async_context(sse_client(config["url"])) + + session = await self.exit_stack.enter_async_context( + ClientSession(read_stream, write_stream) + ) + + await session.initialize() + logger.info(f"{server_name}: Connected to server using sse transport.") + try: + yield session + finally: + logger.debug(f"{server_name}: Closing session.") + + else: + raise ValueError(f"Unsupported transport: {config.get('transport')}") + + async def close(self) -> None: + """Closes all managed sessions and releases resources.""" + await self.exit_stack.aclose() + logger.info("MCPClient closed all sessions.") + + async def connect_to_server(self, name, config): + try: + async with self.create_session(name, config) as session: + if not session: + logger.error(f"Failed to connect to server: {name}") + return [] + + response = await session.list_tools() + tools = response.tools + for tool in tools: + tool.server_name = name + + logger.info(f"Connected to {name} server with {len(tools)} tools.") + self.mcp_tools = tools + return tools + finally: + await self.close() + + async def initialize_all_servers(self): + """Initialize all servers asynchronously.""" + all_tools = [] + for name, config in self.servers.items(): + try: + tools = await self.connect_to_server(name, config) + all_tools.extend(tools) + except Exception as e: + logger.info(f"Could not connect to {name}: {e} \n\n config: {config}") + logger.info(f"Loaded {len(all_tools)} tools from all servers.") + return all_tools + + def mcp_tools_to_llm_format(self): + """Converts MCP tools into an LLM-compatible format.""" + return [{ + "name": tool.name, + "description": tool.description, + "parameters": tool.inputSchema + } for tool in self.mcp_tools] + + def is_mcp_tool_call(self, name): + """Checks if a given tool name exists in the registered MCP tools.""" + return any(tool.name == name for tool in self.mcp_tools) + + async def call_tool(self, tool_name, tool_args): + try: + tool = next((t for t in self.mcp_tools if t.name == tool_name), None) + if not tool: + raise ValueError(f"Tool '{tool_name}' not found in MCP tools.") + + config = self.servers.get(tool.server_name) + if not config: + raise ValueError(f"Server '{tool.server_name}' not found in config.") + + async with self.create_session(tool.server_name, config) as session: + if not session: + raise ValueError(f"Failed to create session for server '{tool.server_name}'.") + + logger.info(f"Calling {tool_name} with args {tool_args}") + result = await session.call_tool(tool_name, tool_args) + logger.info(f"Tool call result: {result}") + + return AgentResponse( + status=AgentStatus.SUCCESS, + message=f"Tool call successful: {tool_name}", + data={"content": result.content} + ) + + except Exception as e: + logger.error(f"Error calling tool '{tool_name}': {e}") + return AgentResponse( + status=AgentStatus.ERROR, + message=f"Error calling tool '{tool_name}': {e}", + data={} + ) + finally: + await self.close() + diff --git a/backend/director/core/reasoning.py b/backend/director/core/reasoning.py index 647705ee..cf7d6329 100644 --- a/backend/director/core/reasoning.py +++ b/backend/director/core/reasoning.py @@ -1,4 +1,5 @@ import logging +import asyncio from typing import List @@ -14,7 +15,7 @@ ) from director.llm.base import LLMResponse from director.llm import get_default_llm - +from director.core.mcp_client import MCPClient logger = logging.getLogger(__name__) @@ -109,7 +110,47 @@ def __init__( self.output_message: OutputMessage = self.session.output_message self.summary_content = None self.failed_agents = [] + self.mcp_tools = [] + self.mcp_client = MCPClient() + self.setup_mcp_servers() + + def _set_mcp_tools(self, tools): + self.mcp_tools = tools + + def setup_mcp_servers(self): + try: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + tools = loop.run_until_complete(self.mcp_client.initialize_all_servers()) + else: + tools = asyncio.run(self.mcp_client.initialize_all_servers()) + + self._set_mcp_tools(tools) + + except Exception as e: + logger.error(f"Failed to initialize MCP servers: {e}") + + def call_mcp_tool_sync(self, tool_name, tool_args): + """Call an MCP tool synchronously.""" + try: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + return loop.run_until_complete(self.mcp_client.call_tool(tool_name, tool_args)) + else: + return asyncio.run(self.mcp_client.call_tool(tool_name, tool_args)) + + except Exception as e: + logger.error(f"Failed to call MCP tool '{tool_name}': {e}") + return None + def register_agents(self, agents: List[BaseAgent]): """Register an agents. @@ -218,12 +259,15 @@ def step(self): [message.to_llm_msg() for message in self.session.reasoning_context], "\n\n", ) + mcp_tools = self.mcp_client.mcp_tools_to_llm_format() + agent_tools = [agent.to_llm_format() for agent in self.agents] + all_tools = mcp_tools + agent_tools llm_response: LLMResponse = self.llm.chat_completions( messages=[ message.to_llm_msg() for message in self.session.reasoning_context ] + temp_messages, - tools=[agent.to_llm_format() for agent in self.agents], + tools=all_tools, ) logger.info(f"LLM Response: {llm_response}") @@ -254,10 +298,23 @@ def step(self): ) ) for tool_call in llm_response.tool_calls: - agent_response: AgentResponse = self.run_agent( - tool_call["tool"]["name"], - **tool_call["tool"]["arguments"], - ) + tool_name = tool_call["tool"]["name"] + tool_args = tool_call["tool"]["arguments"] + if self.mcp_client.is_mcp_tool_call(tool_name): + self.output_message.actions.append(f"Running MCP Tool @{tool_name}") + self.output_message.agents.append(tool_name) + self.output_message.push_update() + agent_response_content = self.call_mcp_tool_sync(tool_name, tool_args) + if agent_response_content: + agent_response = agent_response_content + else: + agent_response = AgentResponse( + status=AgentStatus.ERROR, + message="Method returned null" + ) + + else: + agent_response: AgentResponse = self.run_agent(tool_name, **tool_args) if agent_response.status == AgentStatus.ERROR: self.failed_agents.append(tool_call["tool"]["name"]) self.session.reasoning_context.append( diff --git a/backend/requirements.txt b/backend/requirements.txt index 4d747666..ce7e7d06 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -17,3 +17,4 @@ replicate==1.0.1 yt-dlp==2024.10.7 videodb==0.2.10 slack_sdk==3.33.2 +mcp==1.4.1