Add gitignor
This commit is contained in:
188
assistant/mcp_client.py
Normal file
188
assistant/mcp_client.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""Support des serveurs MCP (Model Context Protocol) pour l'assistant.
|
||||
|
||||
Configure les serveurs dans un profil YAML sous la clé `mcp_servers` :
|
||||
|
||||
mcp_servers:
|
||||
- name: filesystem
|
||||
command: npx
|
||||
args: ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
|
||||
- name: my_sse_server
|
||||
url: "http://localhost:3000/sse"
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import threading
|
||||
from contextlib import AsyncExitStack
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPServerConfig:
|
||||
name: str
|
||||
command: str | None = None
|
||||
args: list[str] = field(default_factory=list)
|
||||
env: dict[str, str] | None = None
|
||||
url: str | None = None
|
||||
|
||||
|
||||
def _sanitize_name(name: str) -> str:
|
||||
"""Transforme un nom en identifiant valide pour l'API Mistral (^[a-zA-Z0-9_-]{1,64}$)."""
|
||||
return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]
|
||||
|
||||
|
||||
class MCPManager:
|
||||
"""Gère les connexions aux serveurs MCP et l'exécution des outils."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
self._sessions: dict[str, Any] = {}
|
||||
self._raw_tools: dict[str, list] = {}
|
||||
# mistral_name -> (server_name, original_tool_name)
|
||||
self._tool_map: dict[str, tuple[str, str]] = {}
|
||||
self._exit_stacks: dict[str, AsyncExitStack] = {}
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._loop.run_forever()
|
||||
|
||||
def _run(self, coro, timeout: int = 30) -> Any:
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
||||
return future.result(timeout=timeout)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Public API #
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def load_servers(self, configs: list[MCPServerConfig]) -> None:
|
||||
self._run(self._load_servers_async(configs), timeout=60)
|
||||
|
||||
def get_mistral_tools(self) -> list[dict]:
|
||||
tools = []
|
||||
for mistral_name, (server_name, tool_name) in self._tool_map.items():
|
||||
raw = {t.name: t for t in self._raw_tools.get(server_name, [])}
|
||||
tool = raw.get(tool_name)
|
||||
if tool is None:
|
||||
continue
|
||||
tools.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": mistral_name,
|
||||
"description": tool.description or "",
|
||||
"parameters": tool.inputSchema or {"type": "object", "properties": {}},
|
||||
},
|
||||
})
|
||||
return tools
|
||||
|
||||
def call_tool(self, mistral_name: str, arguments: dict) -> str:
|
||||
if mistral_name not in self._tool_map:
|
||||
return f"Outil inconnu : {mistral_name}"
|
||||
server_name, tool_name = self._tool_map[mistral_name]
|
||||
if server_name not in self._sessions:
|
||||
return f"Serveur MCP '{server_name}' non disponible"
|
||||
return self._run(self._call_tool_async(server_name, tool_name, arguments))
|
||||
|
||||
def shutdown(self) -> None:
|
||||
try:
|
||||
self._run(self._shutdown_async(), timeout=10)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def summary(self) -> list[tuple[str, int]]:
|
||||
"""Retourne [(server_name, tool_count), ...] pour les serveurs connectés."""
|
||||
return [(name, len(self._raw_tools.get(name, []))) for name in self._sessions]
|
||||
|
||||
@property
|
||||
def has_tools(self) -> bool:
|
||||
return bool(self._tool_map)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Async internals #
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
async def _load_servers_async(self, configs: list[MCPServerConfig]) -> None:
|
||||
for cfg in configs:
|
||||
try:
|
||||
await self._connect_server(cfg)
|
||||
except Exception as e:
|
||||
print(f"[MCP] ❌ Connexion {cfg.name} impossible : {e}")
|
||||
|
||||
async def _connect_server(self, cfg: MCPServerConfig) -> None:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
stack = AsyncExitStack()
|
||||
|
||||
if cfg.command:
|
||||
params = StdioServerParameters(
|
||||
command=cfg.command,
|
||||
args=cfg.args or [],
|
||||
env=cfg.env,
|
||||
)
|
||||
read, write = await stack.enter_async_context(stdio_client(params))
|
||||
else:
|
||||
from mcp.client.sse import sse_client
|
||||
read, write = await stack.enter_async_context(sse_client(cfg.url))
|
||||
|
||||
session = await stack.enter_async_context(ClientSession(read, write))
|
||||
await session.initialize()
|
||||
|
||||
self._sessions[cfg.name] = session
|
||||
self._exit_stacks[cfg.name] = stack
|
||||
|
||||
tools_resp = await session.list_tools()
|
||||
self._raw_tools[cfg.name] = tools_resp.tools
|
||||
|
||||
server_safe = _sanitize_name(cfg.name)
|
||||
for tool in tools_resp.tools:
|
||||
tool_safe = _sanitize_name(tool.name)
|
||||
mistral_name = f"{server_safe}__{tool_safe}"
|
||||
self._tool_map[mistral_name] = (cfg.name, tool.name)
|
||||
|
||||
print(f"[MCP] ✅ {cfg.name} — {len(tools_resp.tools)} outil(s) disponible(s)")
|
||||
|
||||
async def _call_tool_async(self, server_name: str, tool_name: str, arguments: dict) -> str:
|
||||
session = self._sessions[server_name]
|
||||
result = await session.call_tool(tool_name, arguments)
|
||||
parts = []
|
||||
for item in result.content:
|
||||
if hasattr(item, "text"):
|
||||
parts.append(item.text)
|
||||
else:
|
||||
parts.append(str(item))
|
||||
return "\n".join(parts) if parts else "(aucun résultat)"
|
||||
|
||||
async def _shutdown_async(self) -> None:
|
||||
for stack in list(self._exit_stacks.values()):
|
||||
try:
|
||||
await stack.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
self._sessions.clear()
|
||||
self._raw_tools.clear()
|
||||
self._tool_map.clear()
|
||||
self._exit_stacks.clear()
|
||||
|
||||
|
||||
_manager: MCPManager | None = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_manager() -> MCPManager:
|
||||
global _manager
|
||||
with _lock:
|
||||
if _manager is None:
|
||||
_manager = MCPManager()
|
||||
return _manager
|
||||
|
||||
|
||||
def reset_manager() -> None:
|
||||
global _manager
|
||||
with _lock:
|
||||
if _manager is not None:
|
||||
_manager.shutdown()
|
||||
_manager = MCPManager()
|
||||
Reference in New Issue
Block a user