From 504ab0a941ab47e97a645d6e0960a37a537bd9bd Mon Sep 17 00:00:00 2001 From: Om Gate Date: Mon, 17 Mar 2025 18:58:32 +0530 Subject: [PATCH 1/7] feat: support for MCP Servers inside VideoDB Director --- backend/director/core/mcp_client.py | 118 ++++++++++++++++++++++++++++ backend/director/core/reasoning.py | 74 +++++++++++++++-- 2 files changed, 185 insertions(+), 7 deletions(-) create mode 100644 backend/director/core/mcp_client.py diff --git a/backend/director/core/mcp_client.py b/backend/director/core/mcp_client.py new file mode 100644 index 00000000..9f4470a0 --- /dev/null +++ b/backend/director/core/mcp_client.py @@ -0,0 +1,118 @@ +import json +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client +from director.agents.base import AgentResponse, AgentStatus +import asyncio +from contextlib import AsyncExitStack +import logging +import shutil + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger() + + +class MCPClient: + def __init__(self): + self.config_path = 'mcp_servers.json' + self.servers = self.load_servers() + self.mcp_tools = [] + self.exit_stack = AsyncExitStack() + + def load_servers(self): + with open(self.config_path, 'r') as file: + return json.load(file).get('mcpServers', {}) + + async def create_session(self, config): + """Creates a new session for a given server config.""" + try: + exec_path = shutil.which(config['command']) + server_params = StdioServerParameters( + command=exec_path if exec_path else config['command'], + args=config['args'], + env=config.get('env'), + stderr="pipe", + ) + logger.info(f"Initializing server with params: {server_params}") + + stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) + stdio, write = stdio_transport + session = await self.exit_stack.enter_async_context(ClientSession(stdio, write)) + await session.initialize() + + return session + except Exception as e: + logger.error(f"Failed to create session. Error: {e}") + return None + + async def connect_to_server(self, name, config): + """Connects to an MCP server and retrieves tools.""" + session = await self.create_session(config) + if not session: + logger.error(f"Failed to connect to server: {name}") + return [] + + response = await session.list_tools() + tools = response.tools + for tool in tools: + tool.server_name = name + self.mcp_tools.extend(tools) + + logger.info(f"Connected to {name} server with {len(tools)} tools.") + return tools + + def initialize_all_servers_sync(self): + """Initialize all servers synchronously.""" + asyncio.run(self.initialize_all_servers()) + + async def initialize_all_servers(self): + """Initialize all servers asynchronously.""" + all_tools = [] + for name, config in self.servers.items(): + tools = await self.connect_to_server(name, config) + all_tools.extend(tools) + logger.info(f"Loaded {len(all_tools)} tools from all servers.") + return all_tools + + def mcp_tools_to_llm_format(self): + """Converts MCP tools into an LLM-compatible format.""" + return [{ + "name": tool.name, + "description": tool.description, + "parameters": tool.inputSchema + } for tool in self.mcp_tools] + + def is_mcp_tool_call(self, name): + """Checks if a given tool name exists in the registered MCP tools.""" + return any(tool.name == name for tool in self.mcp_tools) + + async def call_tool(self, tool_name, tool_args): + """Calls an MCP tool by name with provided arguments, creating a new session each time.""" + try: + tool = next((t for t in self.mcp_tools if t.name == tool_name), None) + if not tool: + raise ValueError(f"Tool '{tool_name}' not found in MCP tools.") + + config = self.servers.get(tool.server_name) + if not config: + raise ValueError(f"Server '{tool.server_name}' not found in config.") + + session = await self.create_session(config) + if not session: + raise ValueError(f"Failed to create session for server '{tool.server_name}'.") + + logger.info(f"Calling {tool_name} with args {tool_args}") + result = await session.call_tool(tool_name, tool_args) + + return AgentResponse( + status=AgentStatus.SUCCESS, + message=f"Tool call successful: {tool_name}", + data={"content": result.content} + ) + + except Exception as e: + logger.error(f"Error calling tool '{tool_name}': {e}") + return AgentResponse( + status=AgentStatus.ERROR, + message=f"Error calling tool '{tool_name}': {e}", + data={} + ) diff --git a/backend/director/core/reasoning.py b/backend/director/core/reasoning.py index 647705ee..b4826464 100644 --- a/backend/director/core/reasoning.py +++ b/backend/director/core/reasoning.py @@ -14,7 +14,8 @@ ) from director.llm.base import LLMResponse from director.llm import get_default_llm - +from director.core.mcp_client import MCPClient +import asyncio logger = logging.getLogger(__name__) @@ -109,7 +110,49 @@ def __init__( self.output_message: OutputMessage = self.session.output_message self.summary_content = None self.failed_agents = [] - + self.mcp_tools = [] + self.mcp_client = None + self.mcp_client = MCPClient() + self.setup_mcp_servers() + + def _set_mcp_tools(self, tools): + self.mcp_tools = tools + + def setup_mcp_servers(self): + """Initialize all MCP servers and make them available to agents.""" + try: + logger.info("Setting up MCP Servers") + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + task = loop.create_task(self.mcp_client.initialize_all_servers()) + task.add_done_callback(lambda t: self._set_mcp_tools(t.result())) + else: + tools = asyncio.run(self.mcp_client.initialize_all_servers()) + self._set_mcp_tools(tools) + except Exception as e: + logger.error(f"Failed to initialize MCP servers: {e}") + + def call_mcp_tool_sync(self, tool_name, tool_args): + try: + logger.info(f"Calling MCP tool: {tool_name} with args: {tool_args}") + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + task = loop.create_task(self.mcp_client.call_tool(tool_name, tool_args)) + return task.result() + else: + return asyncio.run(self.mcp_client.call_tool(tool_name, tool_args)) + except Exception as e: + logger.error(f"Failed to call MCP tool '{tool_name}': {e}") + return None + def register_agents(self, agents: List[BaseAgent]): """Register an agents. @@ -218,12 +261,15 @@ def step(self): [message.to_llm_msg() for message in self.session.reasoning_context], "\n\n", ) + mcp_tools = self.mcp_client.mcp_tools_to_llm_format() + agent_tools = [agent.to_llm_format() for agent in self.agents] + all_tools = mcp_tools + agent_tools llm_response: LLMResponse = self.llm.chat_completions( messages=[ message.to_llm_msg() for message in self.session.reasoning_context ] + temp_messages, - tools=[agent.to_llm_format() for agent in self.agents], + tools=all_tools, ) logger.info(f"LLM Response: {llm_response}") @@ -254,10 +300,24 @@ def step(self): ) ) for tool_call in llm_response.tool_calls: - agent_response: AgentResponse = self.run_agent( - tool_call["tool"]["name"], - **tool_call["tool"]["arguments"], - ) + tool_name = tool_call["tool"]["name"] + tool_args = tool_call["tool"]["arguments"] + if self.mcp_client.is_mcp_tool_call(tool_name): + self.output_message.actions.append(f"Running MCP Tool @{tool_name}") + self.output_message.agents.append(tool_name) + self.output_message.push_update() + logger.info(f"Detected MCP tool call for: {tool_name}, executing synchronously.") + agent_response_content = self.call_mcp_tool_sync(tool_name, tool_args) + if agent_response_content: + agent_response = agent_response_content + else: + agent_response = AgentResponse( + status=AgentStatus.ERROR, + message="Method returned null" + ) + + else: + agent_response: AgentResponse = self.run_agent(tool_name, **tool_args) if agent_response.status == AgentStatus.ERROR: self.failed_agents.append(tool_call["tool"]["name"]) self.session.reasoning_context.append( From 8eac93a094690ae3399988f4082d47b5a0cb5d5a Mon Sep 17 00:00:00 2001 From: Om Gate Date: Mon, 17 Mar 2025 18:58:54 +0530 Subject: [PATCH 2/7] feat: sample mcp_servers.json --- backend/mcp_servers.json | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 backend/mcp_servers.json diff --git a/backend/mcp_servers.json b/backend/mcp_servers.json new file mode 100644 index 00000000..f8be2d76 --- /dev/null +++ b/backend/mcp_servers.json @@ -0,0 +1,18 @@ +{ + "mcpServers": { + "github": { + "command": "docker", + "args": [ + "run", + "-i", + "--rm", + "-e", + "GITHUB_PERSONAL_ACCESS_TOKEN", + "mcp/github" + ], + "env": { + "GITHUB_PERSONAL_ACCESS_TOKEN": "your-github-personal-access-token" + } + } + } +} From ed91d7e1ff07c798887db5cd3009896ad7903c35 Mon Sep 17 00:00:00 2001 From: Om Gate Date: Mon, 17 Mar 2025 19:05:46 +0530 Subject: [PATCH 3/7] refactor --- backend/director/core/mcp_client.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/backend/director/core/mcp_client.py b/backend/director/core/mcp_client.py index 9f4470a0..6a804ec5 100644 --- a/backend/director/core/mcp_client.py +++ b/backend/director/core/mcp_client.py @@ -60,10 +60,6 @@ async def connect_to_server(self, name, config): logger.info(f"Connected to {name} server with {len(tools)} tools.") return tools - def initialize_all_servers_sync(self): - """Initialize all servers synchronously.""" - asyncio.run(self.initialize_all_servers()) - async def initialize_all_servers(self): """Initialize all servers asynchronously.""" all_tools = [] From 37027ba0aa1c67a08e2fd96ce1e07d6caa776e74 Mon Sep 17 00:00:00 2001 From: Om Gate Date: Mon, 17 Mar 2025 19:41:36 +0530 Subject: [PATCH 4/7] feat: requirments.txt --- backend/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/requirements.txt b/backend/requirements.txt index 4d747666..ce7e7d06 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -17,3 +17,4 @@ replicate==1.0.1 yt-dlp==2024.10.7 videodb==0.2.10 slack_sdk==3.33.2 +mcp==1.4.1 From fadba5f6317faecee50e18af2fa2d5e474ee92b3 Mon Sep 17 00:00:00 2001 From: Om Gate Date: Tue, 18 Mar 2025 17:03:00 +0530 Subject: [PATCH 5/7] refactor: code review --- .gitignore | 3 ++- backend/director/constants.py | 2 ++ backend/director/core/mcp_client.py | 13 ++++++------- backend/director/core/reasoning.py | 27 ++++----------------------- backend/mcp_servers.json | 18 ------------------ 5 files changed, 14 insertions(+), 49 deletions(-) delete mode 100644 backend/mcp_servers.json diff --git a/.gitignore b/.gitignore index a717b5b5..57912b16 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,5 @@ venv package-lock.json *.mjs site/* -backend/director/downloads \ No newline at end of file +backend/director/downloads +mcp_servers.json diff --git a/backend/director/constants.py b/backend/director/constants.py index f370aac9..17c7b0ad 100644 --- a/backend/director/constants.py +++ b/backend/director/constants.py @@ -29,3 +29,5 @@ class EnvPrefix(str, Enum): ANTHROPIC_ = "ANTHROPIC_" DOWNLOADS_PATH="director/downloads" + +MCP_SERVER_CONFIG_PATH="mcp_servers.json" diff --git a/backend/director/core/mcp_client.py b/backend/director/core/mcp_client.py index 6a804ec5..1755dd52 100644 --- a/backend/director/core/mcp_client.py +++ b/backend/director/core/mcp_client.py @@ -1,11 +1,11 @@ import json +import logging +import shutil +from contextlib import AsyncExitStack from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client from director.agents.base import AgentResponse, AgentStatus -import asyncio -from contextlib import AsyncExitStack -import logging -import shutil +from director.constants import MCP_SERVER_CONFIG_PATH logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger() @@ -13,7 +13,7 @@ class MCPClient: def __init__(self): - self.config_path = 'mcp_servers.json' + self.config_path = MCP_SERVER_CONFIG_PATH self.servers = self.load_servers() self.mcp_tools = [] self.exit_stack = AsyncExitStack() @@ -29,8 +29,7 @@ async def create_session(self, config): server_params = StdioServerParameters( command=exec_path if exec_path else config['command'], args=config['args'], - env=config.get('env'), - stderr="pipe", + env=config.get('env') ) logger.info(f"Initializing server with params: {server_params}") diff --git a/backend/director/core/reasoning.py b/backend/director/core/reasoning.py index b4826464..1d543dfd 100644 --- a/backend/director/core/reasoning.py +++ b/backend/director/core/reasoning.py @@ -1,4 +1,5 @@ import logging +import asyncio from typing import List @@ -15,7 +16,6 @@ from director.llm.base import LLMResponse from director.llm import get_default_llm from director.core.mcp_client import MCPClient -import asyncio logger = logging.getLogger(__name__) @@ -111,7 +111,6 @@ def __init__( self.summary_content = None self.failed_agents = [] self.mcp_tools = [] - self.mcp_client = None self.mcp_client = MCPClient() self.setup_mcp_servers() @@ -122,33 +121,15 @@ def setup_mcp_servers(self): """Initialize all MCP servers and make them available to agents.""" try: logger.info("Setting up MCP Servers") - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - if loop and loop.is_running(): - task = loop.create_task(self.mcp_client.initialize_all_servers()) - task.add_done_callback(lambda t: self._set_mcp_tools(t.result())) - else: - tools = asyncio.run(self.mcp_client.initialize_all_servers()) - self._set_mcp_tools(tools) + tools = asyncio.run(self.mcp_client.initialize_all_servers()) + self._set_mcp_tools(tools) except Exception as e: logger.error(f"Failed to initialize MCP servers: {e}") def call_mcp_tool_sync(self, tool_name, tool_args): try: logger.info(f"Calling MCP tool: {tool_name} with args: {tool_args}") - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - if loop and loop.is_running(): - task = loop.create_task(self.mcp_client.call_tool(tool_name, tool_args)) - return task.result() - else: - return asyncio.run(self.mcp_client.call_tool(tool_name, tool_args)) + return asyncio.run(self.mcp_client.call_tool(tool_name, tool_args)) except Exception as e: logger.error(f"Failed to call MCP tool '{tool_name}': {e}") return None diff --git a/backend/mcp_servers.json b/backend/mcp_servers.json deleted file mode 100644 index f8be2d76..00000000 --- a/backend/mcp_servers.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "mcpServers": { - "github": { - "command": "docker", - "args": [ - "run", - "-i", - "--rm", - "-e", - "GITHUB_PERSONAL_ACCESS_TOKEN", - "mcp/github" - ], - "env": { - "GITHUB_PERSONAL_ACCESS_TOKEN": "your-github-personal-access-token" - } - } - } -} From 7aeb0c46282ca8971828cae9be0ce928159d7943 Mon Sep 17 00:00:00 2001 From: Om Gate Date: Wed, 19 Mar 2025 14:44:16 +0530 Subject: [PATCH 6/7] feat: handling sse and stdio --- backend/director/core/mcp_client.py | 126 ++++++++++++++++++---------- backend/director/core/reasoning.py | 28 ++++++- 2 files changed, 109 insertions(+), 45 deletions(-) diff --git a/backend/director/core/mcp_client.py b/backend/director/core/mcp_client.py index 1755dd52..a70a7784 100644 --- a/backend/director/core/mcp_client.py +++ b/backend/director/core/mcp_client.py @@ -1,9 +1,11 @@ import json import logging import shutil -from contextlib import AsyncExitStack +from typing import AsyncGenerator +from contextlib import AsyncExitStack, asynccontextmanager from mcp import ClientSession, StdioServerParameters -from mcp.client.stdio import stdio_client +from mcp.client.stdio import stdio_client, get_default_environment +from mcp.client.sse import sse_client from director.agents.base import AgentResponse, AgentStatus from director.constants import MCP_SERVER_CONFIG_PATH @@ -22,42 +24,82 @@ def load_servers(self): with open(self.config_path, 'r') as file: return json.load(file).get('mcpServers', {}) - async def create_session(self, config): - """Creates a new session for a given server config.""" - try: - exec_path = shutil.which(config['command']) + @asynccontextmanager + async def create_session( + self, + server_name, + config + ) -> AsyncGenerator[ClientSession, None]: + if server_name not in self.servers: + raise ValueError(f"Server '{server_name}' not found in configuration.") + + if config.get("transport") == "stdio": + if not config.get("command") or not config.get("args"): + raise ValueError( + f"Command and args are required for stdio transport: {server_name}" + ) + server_params = StdioServerParameters( - command=exec_path if exec_path else config['command'], - args=config['args'], - env=config.get('env') + command=config["command"], + args=config["args"], + env={**get_default_environment(), **config.get("env", {})}, ) - logger.info(f"Initializing server with params: {server_params}") + + logger.info(f"{server_name}: Initializing stdio transport with {server_params}") stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) - stdio, write = stdio_transport - session = await self.exit_stack.enter_async_context(ClientSession(stdio, write)) + read_stream, write_stream = stdio_transport + + session = await self.exit_stack.enter_async_context( + ClientSession(read_stream, write_stream) + ) + await session.initialize() - - return session - except Exception as e: - logger.error(f"Failed to create session. Error: {e}") - return None + logger.info(f"{server_name}: Connected to server using stdio transport.") + + try: + yield session + finally: + logger.debug(f"{server_name}: Closing session.") + + elif config.get("transport") == "sse": + if not config.get("url"): + raise ValueError(f"URL is required for SSE transport: {server_name}") + + async with sse_client(config["url"]) as (read_stream, write_stream): + session = await self.exit_stack.enter_async_context( + ClientSession(read_stream, write_stream) + ) + + logger.info(f"{server_name}: Connected to server using SSE transport.") + + try: + yield session + finally: + logger.debug(f"{server_name}: Closing session.") + + else: + raise ValueError(f"Unsupported transport: {config.get('transport')}") + + async def close(self) -> None: + """Closes all managed sessions and releases resources.""" + await self.exit_stack.aclose() + logger.info("MCPClient closed all sessions.") async def connect_to_server(self, name, config): - """Connects to an MCP server and retrieves tools.""" - session = await self.create_session(config) - if not session: - logger.error(f"Failed to connect to server: {name}") - return [] - - response = await session.list_tools() - tools = response.tools - for tool in tools: - tool.server_name = name - self.mcp_tools.extend(tools) - - logger.info(f"Connected to {name} server with {len(tools)} tools.") - return tools + async with self.create_session(name, config) as session: + if not session: + logger.error(f"Failed to connect to server: {name}") + return [] + + response = await session.list_tools() + tools = response.tools + for tool in tools: + tool.server_name = name + + logger.info(f"Connected to {name} server with {len(tools)} tools.") + self.mcp_tools = tools + return tools async def initialize_all_servers(self): """Initialize all servers asynchronously.""" @@ -81,7 +123,6 @@ def is_mcp_tool_call(self, name): return any(tool.name == name for tool in self.mcp_tools) async def call_tool(self, tool_name, tool_args): - """Calls an MCP tool by name with provided arguments, creating a new session each time.""" try: tool = next((t for t in self.mcp_tools if t.name == tool_name), None) if not tool: @@ -91,18 +132,18 @@ async def call_tool(self, tool_name, tool_args): if not config: raise ValueError(f"Server '{tool.server_name}' not found in config.") - session = await self.create_session(config) - if not session: - raise ValueError(f"Failed to create session for server '{tool.server_name}'.") + async with self.create_session(tool.server_name) as session: + if not session: + raise ValueError(f"Failed to create session for server '{tool.server_name}'.") - logger.info(f"Calling {tool_name} with args {tool_args}") - result = await session.call_tool(tool_name, tool_args) + logger.info(f"Calling {tool_name} with args {tool_args}") + result = await session.call_tool(tool_name, tool_args) - return AgentResponse( - status=AgentStatus.SUCCESS, - message=f"Tool call successful: {tool_name}", - data={"content": result.content} - ) + return AgentResponse( + status=AgentStatus.SUCCESS, + message=f"Tool call successful: {tool_name}", + data={"content": result.content} + ) except Exception as e: logger.error(f"Error calling tool '{tool_name}': {e}") @@ -111,3 +152,4 @@ async def call_tool(self, tool_name, tool_args): message=f"Error calling tool '{tool_name}': {e}", data={} ) + diff --git a/backend/director/core/reasoning.py b/backend/director/core/reasoning.py index 1d543dfd..657b4e16 100644 --- a/backend/director/core/reasoning.py +++ b/backend/director/core/reasoning.py @@ -118,18 +118,39 @@ def _set_mcp_tools(self, tools): self.mcp_tools = tools def setup_mcp_servers(self): - """Initialize all MCP servers and make them available to agents.""" try: logger.info("Setting up MCP Servers") - tools = asyncio.run(self.mcp_client.initialize_all_servers()) + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + tools = loop.run_until_complete(self.mcp_client.initialize_all_servers()) + else: + tools = asyncio.run(self.mcp_client.initialize_all_servers()) + self._set_mcp_tools(tools) + except Exception as e: logger.error(f"Failed to initialize MCP servers: {e}") def call_mcp_tool_sync(self, tool_name, tool_args): + """Call an MCP tool synchronously.""" try: logger.info(f"Calling MCP tool: {tool_name} with args: {tool_args}") - return asyncio.run(self.mcp_client.call_tool(tool_name, tool_args)) + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + return loop.run_until_complete(self.mcp_client.call_tool(tool_name, tool_args)) + else: + return asyncio.run(self.mcp_client.call_tool(tool_name, tool_args)) + except Exception as e: logger.error(f"Failed to call MCP tool '{tool_name}': {e}") return None @@ -245,6 +266,7 @@ def step(self): mcp_tools = self.mcp_client.mcp_tools_to_llm_format() agent_tools = [agent.to_llm_format() for agent in self.agents] all_tools = mcp_tools + agent_tools + logger.info(f"MCP_TOOLS -> {mcp_tools}") llm_response: LLMResponse = self.llm.chat_completions( messages=[ message.to_llm_msg() for message in self.session.reasoning_context From 3a71a702c2d829406e141c359b9fe4ae99e43383 Mon Sep 17 00:00:00 2001 From: Om Gate Date: Wed, 19 Mar 2025 17:06:03 +0530 Subject: [PATCH 7/7] feat: support SSE MCP Servers --- backend/director/core/mcp_client.py | 66 +++++++++++++++++------------ backend/director/core/reasoning.py | 6 --- 2 files changed, 39 insertions(+), 33 deletions(-) diff --git a/backend/director/core/mcp_client.py b/backend/director/core/mcp_client.py index a70a7784..03f71ec7 100644 --- a/backend/director/core/mcp_client.py +++ b/backend/director/core/mcp_client.py @@ -1,6 +1,6 @@ import json import logging -import shutil +import os from typing import AsyncGenerator from contextlib import AsyncExitStack, asynccontextmanager from mcp import ClientSession, StdioServerParameters @@ -9,8 +9,7 @@ from director.agents.base import AgentResponse, AgentStatus from director.constants import MCP_SERVER_CONFIG_PATH -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger() +logger = logging.getLogger(__name__) class MCPClient: @@ -21,6 +20,9 @@ def __init__(self): self.exit_stack = AsyncExitStack() def load_servers(self): + if not os.path.exists(self.config_path): + return {} + with open(self.config_path, 'r') as file: return json.load(file).get('mcpServers', {}) @@ -66,17 +68,18 @@ async def create_session( if not config.get("url"): raise ValueError(f"URL is required for SSE transport: {server_name}") - async with sse_client(config["url"]) as (read_stream, write_stream): - session = await self.exit_stack.enter_async_context( - ClientSession(read_stream, write_stream) - ) - - logger.info(f"{server_name}: Connected to server using SSE transport.") + read_stream, write_stream = await self.exit_stack.enter_async_context(sse_client(config["url"])) + + session = await self.exit_stack.enter_async_context( + ClientSession(read_stream, write_stream) + ) - try: - yield session - finally: - logger.debug(f"{server_name}: Closing session.") + await session.initialize() + logger.info(f"{server_name}: Connected to server using sse transport.") + try: + yield session + finally: + logger.debug(f"{server_name}: Closing session.") else: raise ValueError(f"Unsupported transport: {config.get('transport')}") @@ -87,26 +90,32 @@ async def close(self) -> None: logger.info("MCPClient closed all sessions.") async def connect_to_server(self, name, config): - async with self.create_session(name, config) as session: - if not session: - logger.error(f"Failed to connect to server: {name}") - return [] + try: + async with self.create_session(name, config) as session: + if not session: + logger.error(f"Failed to connect to server: {name}") + return [] - response = await session.list_tools() - tools = response.tools - for tool in tools: - tool.server_name = name + response = await session.list_tools() + tools = response.tools + for tool in tools: + tool.server_name = name - logger.info(f"Connected to {name} server with {len(tools)} tools.") - self.mcp_tools = tools - return tools + logger.info(f"Connected to {name} server with {len(tools)} tools.") + self.mcp_tools = tools + return tools + finally: + await self.close() async def initialize_all_servers(self): """Initialize all servers asynchronously.""" all_tools = [] for name, config in self.servers.items(): - tools = await self.connect_to_server(name, config) - all_tools.extend(tools) + try: + tools = await self.connect_to_server(name, config) + all_tools.extend(tools) + except Exception as e: + logger.info(f"Could not connect to {name}: {e} \n\n config: {config}") logger.info(f"Loaded {len(all_tools)} tools from all servers.") return all_tools @@ -132,12 +141,13 @@ async def call_tool(self, tool_name, tool_args): if not config: raise ValueError(f"Server '{tool.server_name}' not found in config.") - async with self.create_session(tool.server_name) as session: + async with self.create_session(tool.server_name, config) as session: if not session: raise ValueError(f"Failed to create session for server '{tool.server_name}'.") logger.info(f"Calling {tool_name} with args {tool_args}") result = await session.call_tool(tool_name, tool_args) + logger.info(f"Tool call result: {result}") return AgentResponse( status=AgentStatus.SUCCESS, @@ -152,4 +162,6 @@ async def call_tool(self, tool_name, tool_args): message=f"Error calling tool '{tool_name}': {e}", data={} ) + finally: + await self.close() diff --git a/backend/director/core/reasoning.py b/backend/director/core/reasoning.py index 657b4e16..cf7d6329 100644 --- a/backend/director/core/reasoning.py +++ b/backend/director/core/reasoning.py @@ -119,8 +119,6 @@ def _set_mcp_tools(self, tools): def setup_mcp_servers(self): try: - logger.info("Setting up MCP Servers") - try: loop = asyncio.get_running_loop() except RuntimeError: @@ -139,8 +137,6 @@ def setup_mcp_servers(self): def call_mcp_tool_sync(self, tool_name, tool_args): """Call an MCP tool synchronously.""" try: - logger.info(f"Calling MCP tool: {tool_name} with args: {tool_args}") - try: loop = asyncio.get_running_loop() except RuntimeError: @@ -266,7 +262,6 @@ def step(self): mcp_tools = self.mcp_client.mcp_tools_to_llm_format() agent_tools = [agent.to_llm_format() for agent in self.agents] all_tools = mcp_tools + agent_tools - logger.info(f"MCP_TOOLS -> {mcp_tools}") llm_response: LLMResponse = self.llm.chat_completions( messages=[ message.to_llm_msg() for message in self.session.reasoning_context @@ -309,7 +304,6 @@ def step(self): self.output_message.actions.append(f"Running MCP Tool @{tool_name}") self.output_message.agents.append(tool_name) self.output_message.push_update() - logger.info(f"Detected MCP tool call for: {tool_name}, executing synchronously.") agent_response_content = self.call_mcp_tool_sync(tool_name, tool_args) if agent_response_content: agent_response = agent_response_content