Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ venv
package-lock.json
*.mjs
site/*
backend/director/downloads
backend/director/downloads
mcp_servers.json
2 changes: 2 additions & 0 deletions backend/director/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,5 @@ class EnvPrefix(str, Enum):
ANTHROPIC_ = "ANTHROPIC_"

DOWNLOADS_PATH="director/downloads"

MCP_SERVER_CONFIG_PATH="mcp_servers.json"
167 changes: 167 additions & 0 deletions backend/director/core/mcp_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import json
import logging
import os
from typing import AsyncGenerator
from contextlib import AsyncExitStack, asynccontextmanager
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client, get_default_environment
from mcp.client.sse import sse_client
from director.agents.base import AgentResponse, AgentStatus
from director.constants import MCP_SERVER_CONFIG_PATH

logger = logging.getLogger(__name__)


class MCPClient:
def __init__(self):
self.config_path = MCP_SERVER_CONFIG_PATH
self.servers = self.load_servers()
self.mcp_tools = []
self.exit_stack = AsyncExitStack()

def load_servers(self):
if not os.path.exists(self.config_path):
return {}

with open(self.config_path, 'r') as file:
return json.load(file).get('mcpServers', {})

@asynccontextmanager
async def create_session(
self,
server_name,
config
) -> AsyncGenerator[ClientSession, None]:
if server_name not in self.servers:
raise ValueError(f"Server '{server_name}' not found in configuration.")

if config.get("transport") == "stdio":
if not config.get("command") or not config.get("args"):
raise ValueError(
f"Command and args are required for stdio transport: {server_name}"
)

server_params = StdioServerParameters(
command=config["command"],
args=config["args"],
env={**get_default_environment(), **config.get("env", {})},
)

logger.info(f"{server_name}: Initializing stdio transport with {server_params}")

stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
read_stream, write_stream = stdio_transport

session = await self.exit_stack.enter_async_context(
ClientSession(read_stream, write_stream)
)

await session.initialize()
logger.info(f"{server_name}: Connected to server using stdio transport.")

try:
yield session
finally:
logger.debug(f"{server_name}: Closing session.")

elif config.get("transport") == "sse":
if not config.get("url"):
raise ValueError(f"URL is required for SSE transport: {server_name}")

read_stream, write_stream = await self.exit_stack.enter_async_context(sse_client(config["url"]))

session = await self.exit_stack.enter_async_context(
ClientSession(read_stream, write_stream)
)

await session.initialize()
logger.info(f"{server_name}: Connected to server using sse transport.")
try:
yield session
finally:
logger.debug(f"{server_name}: Closing session.")

else:
raise ValueError(f"Unsupported transport: {config.get('transport')}")

async def close(self) -> None:
"""Closes all managed sessions and releases resources."""
await self.exit_stack.aclose()
logger.info("MCPClient closed all sessions.")

async def connect_to_server(self, name, config):
try:
async with self.create_session(name, config) as session:
if not session:
logger.error(f"Failed to connect to server: {name}")
return []

response = await session.list_tools()
tools = response.tools
for tool in tools:
tool.server_name = name

logger.info(f"Connected to {name} server with {len(tools)} tools.")
self.mcp_tools = tools
return tools
finally:
await self.close()

async def initialize_all_servers(self):
"""Initialize all servers asynchronously."""
all_tools = []
for name, config in self.servers.items():
try:
tools = await self.connect_to_server(name, config)
all_tools.extend(tools)
except Exception as e:
logger.info(f"Could not connect to {name}: {e} \n\n config: {config}")
logger.info(f"Loaded {len(all_tools)} tools from all servers.")
return all_tools

def mcp_tools_to_llm_format(self):
"""Converts MCP tools into an LLM-compatible format."""
return [{
"name": tool.name,
"description": tool.description,
"parameters": tool.inputSchema
} for tool in self.mcp_tools]

def is_mcp_tool_call(self, name):
"""Checks if a given tool name exists in the registered MCP tools."""
return any(tool.name == name for tool in self.mcp_tools)

async def call_tool(self, tool_name, tool_args):
try:
tool = next((t for t in self.mcp_tools if t.name == tool_name), None)
if not tool:
raise ValueError(f"Tool '{tool_name}' not found in MCP tools.")

config = self.servers.get(tool.server_name)
if not config:
raise ValueError(f"Server '{tool.server_name}' not found in config.")

async with self.create_session(tool.server_name, config) as session:
if not session:
raise ValueError(f"Failed to create session for server '{tool.server_name}'.")

logger.info(f"Calling {tool_name} with args {tool_args}")
result = await session.call_tool(tool_name, tool_args)
logger.info(f"Tool call result: {result}")

return AgentResponse(
status=AgentStatus.SUCCESS,
message=f"Tool call successful: {tool_name}",
data={"content": result.content}
)

except Exception as e:
logger.error(f"Error calling tool '{tool_name}': {e}")
return AgentResponse(
status=AgentStatus.ERROR,
message=f"Error calling tool '{tool_name}': {e}",
data={}
)
finally:
await self.close()

69 changes: 63 additions & 6 deletions backend/director/core/reasoning.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import asyncio
from typing import List


Expand All @@ -14,7 +15,7 @@
)
from director.llm.base import LLMResponse
from director.llm import get_default_llm

from director.core.mcp_client import MCPClient

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -109,7 +110,47 @@ def __init__(
self.output_message: OutputMessage = self.session.output_message
self.summary_content = None
self.failed_agents = []
self.mcp_tools = []
self.mcp_client = MCPClient()
self.setup_mcp_servers()

def _set_mcp_tools(self, tools):
self.mcp_tools = tools

def setup_mcp_servers(self):
try:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None

if loop and loop.is_running():
tools = loop.run_until_complete(self.mcp_client.initialize_all_servers())
else:
tools = asyncio.run(self.mcp_client.initialize_all_servers())

self._set_mcp_tools(tools)

except Exception as e:
logger.error(f"Failed to initialize MCP servers: {e}")

def call_mcp_tool_sync(self, tool_name, tool_args):
"""Call an MCP tool synchronously."""
try:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None

if loop and loop.is_running():
return loop.run_until_complete(self.mcp_client.call_tool(tool_name, tool_args))
else:
return asyncio.run(self.mcp_client.call_tool(tool_name, tool_args))

except Exception as e:
logger.error(f"Failed to call MCP tool '{tool_name}': {e}")
return None

def register_agents(self, agents: List[BaseAgent]):
"""Register an agents.

Expand Down Expand Up @@ -218,12 +259,15 @@ def step(self):
[message.to_llm_msg() for message in self.session.reasoning_context],
"\n\n",
)
mcp_tools = self.mcp_client.mcp_tools_to_llm_format()
agent_tools = [agent.to_llm_format() for agent in self.agents]
all_tools = mcp_tools + agent_tools
llm_response: LLMResponse = self.llm.chat_completions(
messages=[
message.to_llm_msg() for message in self.session.reasoning_context
]
+ temp_messages,
tools=[agent.to_llm_format() for agent in self.agents],
tools=all_tools,
)
logger.info(f"LLM Response: {llm_response}")

Expand Down Expand Up @@ -254,10 +298,23 @@ def step(self):
)
)
for tool_call in llm_response.tool_calls:
agent_response: AgentResponse = self.run_agent(
tool_call["tool"]["name"],
**tool_call["tool"]["arguments"],
)
tool_name = tool_call["tool"]["name"]
tool_args = tool_call["tool"]["arguments"]
if self.mcp_client.is_mcp_tool_call(tool_name):
self.output_message.actions.append(f"Running MCP Tool @{tool_name}")
self.output_message.agents.append(tool_name)
self.output_message.push_update()
agent_response_content = self.call_mcp_tool_sync(tool_name, tool_args)
if agent_response_content:
agent_response = agent_response_content
else:
agent_response = AgentResponse(
status=AgentStatus.ERROR,
message="Method returned null"
)

else:
agent_response: AgentResponse = self.run_agent(tool_name, **tool_args)
if agent_response.status == AgentStatus.ERROR:
self.failed_agents.append(tool_call["tool"]["name"])
self.session.reasoning_context.append(
Expand Down
1 change: 1 addition & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ replicate==1.0.1
yt-dlp==2024.10.7
videodb==0.2.10
slack_sdk==3.33.2
mcp==1.4.1