Files
arioch-assistant/assistant/mcp_client.py
2026-04-07 22:06:36 +02:00

249 lines
8.8 KiB
Python

"""Support des serveurs MCP (Model Context Protocol) pour l'assistant.
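
Configurez les serveurs dans un profil YAML sous la clé `mcp_servers`. Chaque
entrée fournit soit `command` (avec `args` et `env` optionnels) pour un serveur
lancé localement via stdio, soit `url` pour un serveur SSE. Dans `args`, les
chemins relatifs se terminant par .js, .py ou .ts sont résolus par rapport à la
racine du projet :

    mcp_servers:
      - name: filesystem
        command: npx
        args: ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
      - name: my_sse_server
        url: "http://localhost:3000/sse"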
"""
from __future__ import annotations
import asyncio
import os
import re
import threading
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
# Racine du projet (le dossier qui contient main.py)
_PROJECT_ROOT = Path(__file__).parent.parent.resolve()
@dataclass
class MCPServerConfig:
    """Settings for one MCP server entry of the profile's `mcp_servers` list.

    A stdio server is described by `command` (plus optional `args` and `env`);
    when `command` is empty, `url` is used for an SSE connection instead.
    """

    name: str
    command: str | None = None
    args: list[str] = field(default_factory=list)
    env: dict[str, str] | None = None
    url: str | None = None

    def resolved_args(self) -> list[str]:
        """Return `args` with relative script paths made absolute.

        Only arguments that look like relative ``.js``/``.py``/``.ts`` files
        are rewritten (anchored at the project root); every other argument is
        returned unchanged, in the original order.
        """
        return [self._anchor_script_path(arg) for arg in self.args]

    @staticmethod
    def _anchor_script_path(arg: str) -> str:
        """Resolve a relative script path against _PROJECT_ROOT; pass others through."""
        path = Path(arg)
        if path.is_absolute() or path.suffix not in (".js", ".py", ".ts"):
            return arg
        return str((_PROJECT_ROOT / path).resolve())
def _sanitize_name(name: str) -> str:
"""Transforme un nom en identifiant valide pour l'API Mistral (^[a-zA-Z0-9_-]{1,64}$)."""
return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]
class MCPManager:
    """Manage MCP server connections and tool execution.

    A private asyncio event loop runs forever in a daemon thread. The public
    methods are synchronous and submit coroutines to that loop with
    ``asyncio.run_coroutine_threadsafe``.

    Each connection runs in a "keeper task" that owns the whole lifetime of
    the stdio_client / sse_client and ClientSession context managers. This
    avoids the anyio error "Attempted to exit cancel scope in a different
    task".
    """

    # Mistral function names must match ^[a-zA-Z0-9_-]{1,64}$.
    _MAX_TOOL_NAME_LEN = 64

    def __init__(self) -> None:
        # server name -> live ClientSession
        self._sessions: dict[str, Any] = {}
        # server name -> tool descriptors returned by session.list_tools()
        self._raw_tools: dict[str, list] = {}
        # Mistral-facing tool name -> (server name, original MCP tool name)
        self._tool_map: dict[str, tuple[str, str]] = {}
        # stop event per server, signalled at shutdown
        self._stop_events: dict[str, asyncio.Event] = {}
        # Strong references to the keeper tasks: asyncio keeps only weak
        # references to tasks, and timeouts/shutdown need them to cancel.
        self._tasks: dict[str, asyncio.Task] = {}
        self._closed = False
        # Start the loop thread last, once every attribute exists.
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()

    def _run_loop(self) -> None:
        """Thread target: run the private event loop until it is stopped."""
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

    def _run(self, coro, timeout: float = 30) -> Any:
        """Run *coro* on the private loop and block until it finishes.

        Re-raises whatever the coroutine raises, or the timeout error of
        ``concurrent.futures.Future.result``. Raises RuntimeError once the
        manager has been shut down.
        """
        if self._closed:
            coro.close()  # avoid a "coroutine was never awaited" warning
            raise RuntimeError("MCPManager arrêté")
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        try:
            return future.result(timeout=timeout)
        except BaseException:
            # On timeout (or Ctrl+C) do not leave the coroutine running in
            # the background; cancel() is a no-op if it already finished.
            future.cancel()
            raise

    # ------------------------------------------------------------------ #
    #  Public API                                                         #
    # ------------------------------------------------------------------ #

    def load_servers(self, configs: list[MCPServerConfig]) -> None:
        """Connect to every server in *configs* (concurrently).

        A server that fails to connect is reported on stdout and does not
        prevent the others from connecting.
        """
        self._run(self._load_servers_async(configs), timeout=60)

    def get_mistral_tools(self) -> list[dict]:
        """Return every known tool as a Mistral ``{"type": "function", ...}`` spec."""
        # Snapshot the shared dicts: the loop thread may mutate them
        # concurrently (keeper teardown), which would break live iteration.
        tool_map = list(self._tool_map.items())
        raw_tools = dict(self._raw_tools)
        # Build one name -> tool lookup per server instead of one per tool.
        lookups: dict[str, dict[str, Any]] = {}
        tools = []
        for mistral_name, (server_name, tool_name) in tool_map:
            if server_name not in lookups:
                lookups[server_name] = {t.name: t for t in raw_tools.get(server_name, [])}
            tool = lookups[server_name].get(tool_name)
            if tool is None:
                continue
            tools.append({
                "type": "function",
                "function": {
                    "name": mistral_name,
                    "description": tool.description or "",
                    "parameters": tool.inputSchema or {"type": "object", "properties": {}},
                },
            })
        return tools

    def call_tool(self, mistral_name: str, arguments: dict) -> str:
        """Execute a tool by its Mistral name and return its text output.

        Unknown tools and disconnected servers yield an explanatory string
        instead of an exception.
        """
        entry = self._tool_map.get(mistral_name)
        if entry is None:
            return f"Outil inconnu : {mistral_name}"
        server_name, tool_name = entry
        if server_name not in self._sessions:
            return f"Serveur MCP '{server_name}' non disponible"
        return self._run(self._call_tool_async(server_name, tool_name, arguments))

    def shutdown(self) -> None:
        """Stop every keeper task, then stop and join the loop thread.

        Keepers are first asked to exit cleanly through their stop events.
        Any keeper still running after 3 s is cancelled.

        The manager cannot be used afterwards (``_run`` raises RuntimeError);
        calling shutdown() again is a no-op. Errors are reported on stdout
        rather than raised (best effort).
        """
        if self._closed:
            return

        async def _stop_all() -> None:
            for ev in list(self._stop_events.values()):
                ev.set()
            running = [t for t in self._tasks.values() if not t.done()]
            if not running:  # asyncio.wait() rejects an empty collection
                return
            _, pending = await asyncio.wait(running, timeout=3)
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.wait(pending, timeout=1)

        try:
            self._run(_stop_all(), timeout=5)
        except Exception as e:
            print(f"[MCP] ⚠️ Arrêt incomplet des serveurs MCP : {e!r}")
        finally:
            self._stop_events.clear()
            self._tasks.clear()
            self._closed = True
            # Stop the loop from its own thread, then reclaim the thread so
            # that repeated reset_manager() calls do not leak loops/threads.
            self._loop.call_soon_threadsafe(self._loop.stop)
            self._thread.join(timeout=2)
            if not self._thread.is_alive():
                self._loop.close()

    def summary(self) -> list[tuple[str, int]]:
        """Return [(server_name, tool_count), ...] for the connected servers."""
        # list() snapshots the keys; the loop thread may modify _sessions.
        return [(name, len(self._raw_tools.get(name, []))) for name in list(self._sessions)]

    @property
    def has_tools(self) -> bool:
        """True when at least one tool is registered."""
        return bool(self._tool_map)

    # ------------------------------------------------------------------ #
    #  Async internals                                                    #
    # ------------------------------------------------------------------ #

    async def _load_servers_async(self, configs: list[MCPServerConfig]) -> None:
        """Connect all servers concurrently and print each failure."""
        # Concurrent connections keep the total time close to one server's
        # 30 s connect timeout, well inside load_servers()' 60 s budget.
        results = await asyncio.gather(
            *(self._connect_server(cfg) for cfg in configs),
            return_exceptions=True,
        )
        for cfg, result in zip(configs, results):
            if isinstance(result, Exception):
                print(f"[MCP] ❌ Connexion {cfg.name} impossible : {result}")

    async def _connect_server(self, cfg: MCPServerConfig) -> None:
        """Start the keeper task and wait until the connection is up (or fails).

        Raises:
            ValueError: a keeper with the same name is still running.
            TimeoutError: the server did not become ready within 30 s.
            Exception: whatever the keeper task hit while connecting.
        """
        existing = self._tasks.get(cfg.name)
        if existing is not None and not existing.done():
            # A second keeper would share and clobber the first one's entries.
            raise ValueError(f"Serveur MCP '{cfg.name}' déjà connecté")
        stop_event = asyncio.Event()
        ready_event = asyncio.Event()
        error_holder: list[Exception] = []
        self._stop_events[cfg.name] = stop_event
        task = asyncio.create_task(
            self._run_server(cfg, ready_event, stop_event, error_holder),
            name=f"mcp-keeper-{cfg.name}",
        )
        self._tasks[cfg.name] = task
        # Wait until the connection is ready (or has failed)
        try:
            await asyncio.wait_for(ready_event.wait(), timeout=30)
        except asyncio.TimeoutError:
            # The keeper may be stuck in initialize() and never see the stop
            # event, so cancel it as well.
            stop_event.set()
            task.cancel()
            raise TimeoutError(f"Timeout lors de la connexion à {cfg.name}") from None
        if error_holder:
            raise error_holder[0]

    @classmethod
    def _mistral_tool_name(cls, server_safe: str, tool_safe: str, taken) -> str:
        """Build a unique ``<server>__<tool>`` name of at most 64 characters.

        Each sanitised part may already be 64 characters long, so the
        combination is truncated. If the result is already in *taken*, a
        ``_2``, ``_3``, ... suffix is used instead of silently overwriting
        the earlier tool.
        """
        limit = cls._MAX_TOOL_NAME_LEN
        base = f"{server_safe}__{tool_safe}"[:limit]
        candidate = base
        counter = 2
        while candidate in taken:
            suffix = f"_{counter}"
            candidate = base[: limit - len(suffix)] + suffix
            counter += 1
        return candidate

    async def _run_server(
        self,
        cfg: MCPServerConfig,
        ready_event: asyncio.Event,
        stop_event: asyncio.Event,
        error_holder: list[Exception],
    ) -> None:
        """Keeper task: own the transport and session context managers end to end.

        On success, the task registers the session and its tools, sets
        *ready_event*, and keeps the connection open until *stop_event* is set.

        On failure, the exception is stored in *error_holder* and
        *ready_event* is set, so that ``_connect_server`` stops waiting.

        All entries for the server are removed on exit, whatever the reason.
        """
        try:
            from mcp import ClientSession, StdioServerParameters
            if cfg.command:
                from mcp.client.stdio import stdio_client
                transport = stdio_client(StdioServerParameters(
                    command=cfg.command,
                    args=cfg.resolved_args(),
                    env=cfg.env,
                ))
            else:
                from mcp.client.sse import sse_client
                transport = sse_client(cfg.url)
            async with transport as (read, write):
                async with ClientSession(read, write) as session:
                    await session.initialize()
                    tools_resp = await session.list_tools()
                    self._sessions[cfg.name] = session
                    self._raw_tools[cfg.name] = tools_resp.tools
                    server_safe = _sanitize_name(cfg.name)
                    for tool in tools_resp.tools:
                        mistral_name = self._mistral_tool_name(
                            server_safe, _sanitize_name(tool.name), self._tool_map
                        )
                        self._tool_map[mistral_name] = (cfg.name, tool.name)
                    print(f"[MCP] ✅ {cfg.name} : {len(tools_resp.tools)} outil(s) disponible(s)")
                    ready_event.set()
                    # Keep the connection open until the stop signal
                    await stop_event.wait()
        except Exception as e:
            if ready_event.is_set():
                # The connection had been established and nobody reads
                # error_holder any more: report the drop.
                print(f"[MCP] ⚠️ {cfg.name} déconnecté : {e}")
            error_holder.append(e)
            ready_event.set()  # unblock _connect_server even on error
        finally:
            self._sessions.pop(cfg.name, None)
            self._raw_tools.pop(cfg.name, None)
            to_remove = [k for k, v in self._tool_map.items() if v[0] == cfg.name]
            for k in to_remove:
                del self._tool_map[k]

    async def _call_tool_async(self, server_name: str, tool_name: str, arguments: dict) -> str:
        """Call *tool_name* on *server_name* and join the text parts of the result."""
        session = self._sessions.get(server_name)
        if session is None:
            # The keeper may have exited between call_tool()'s check and now.
            return f"Serveur MCP '{server_name}' non disponible"
        result = await session.call_tool(tool_name, arguments)
        parts = [item.text if hasattr(item, "text") else str(item) for item in result.content]
        return "\n".join(parts) if parts else "(aucun résultat)"
# Process-wide MCPManager singleton, created lazily; every access goes
# through _lock.
_manager: MCPManager | None = None
_lock = threading.Lock()


def get_manager() -> MCPManager:
    """Return the shared MCPManager, creating it on first use."""
    global _manager
    with _lock:
        current = _manager
        if current is None:
            current = MCPManager()
            _manager = current
        return current


def reset_manager() -> None:
    """Shut down the shared manager (if any) and install a fresh one."""
    global _manager
    with _lock:
        previous = _manager
        if previous is not None:
            previous.shutdown()
        _manager = MCPManager()