# Streaming chat client for a Mistral LLM with MCP tool-call support.
import json
|
|
from typing import Generator
|
|
|
|
from mistralai.client import Mistral
|
|
from . import config
|
|
|
|
# Single shared Mistral API client, configured from project settings.
_client = Mistral(api_key=config.MISTRAL_API_KEY)

# Rolling conversation history as Mistral-style message dicts
# ({"role": ..., "content": ...}); mutated in place by chat_stream()
# and emptied by reset_history().
_history: list[dict] = []
|
|
|
|
|
|
def reset_history() -> None:
    """Clear the conversation history, starting a fresh session."""
    del _history[:]
|
|
|
|
|
|
def chat_stream(user_message: str) -> Generator[str, None, None]:
    """Stream the LLM's reply to *user_message* token by token.

    MCP tool calls requested by the model are executed internally (without
    streaming); only the textual response is yielded chunk by chunk.

    Side effects: appends the user turn, any assistant/tool turns, and the
    final assistant reply to the module-level ``_history``.

    Yields:
        str: successive text fragments of the assistant's reply.
    """
    # Imported lazily to avoid a circular import at module load time.
    from . import mcp_client

    _history.append({"role": "user", "content": user_message})

    manager = mcp_client.get_manager()
    # Advertise MCP tools to the model only if any are registered.
    tools = manager.get_mistral_tools() if manager.has_tools else None

    # Loop: each iteration is one model round-trip. A round that requests
    # tool calls executes them and loops again; a pure-text round breaks.
    while True:
        # Rebuild the full prompt each round: system prompt + whole history.
        messages = [{"role": "system", "content": config.SYSTEM_PROMPT}] + _history
        kwargs: dict = {"model": config.LLM_MODEL, "messages": messages}
        if tools:
            kwargs["tools"] = tools

        accumulated_content = ""
        tool_calls_received = None

        for event in _client.chat.stream(**kwargs):
            ch = event.data.choices[0]
            delta = ch.delta

            # Yield text chunks (isinstance check guards against Unset sentinel)
            if isinstance(delta.content, str) and delta.content:
                accumulated_content += delta.content
                yield delta.content

            # NOTE(review): each tool_calls delta REPLACES the previous one —
            # assumes the SDK delivers the complete tool-call list in a single
            # delta rather than split across chunks; confirm against the
            # Mistral streaming API before relying on multi-chunk tool calls.
            if delta.tool_calls:
                tool_calls_received = delta.tool_calls

        if tool_calls_received:
            # Append assistant turn with tool calls to history
            _history.append({
                "role": "assistant",
                "content": accumulated_content or "",
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in tool_calls_received
                ],
            })

            # Execute each tool and append results
            for tc in tool_calls_received:
                tool_name = tc.function.name
                try:
                    # Arguments may arrive as a JSON string or already-parsed
                    # mapping depending on the SDK; normalize to a dict.
                    args = (
                        json.loads(tc.function.arguments)
                        if isinstance(tc.function.arguments, str)
                        else tc.function.arguments
                    )
                    print(f"\n 🔧 [MCP] {tool_name}({_short_args(args)})", flush=True)
                    result = manager.call_tool(tool_name, args)
                except Exception as e:
                    # Best-effort: surface the failure to the model as the
                    # tool result instead of aborting the conversation.
                    result = f"Erreur lors de l'appel à {tool_name} : {e}"

                # presumably call_tool returns a string — TODO confirm,
                # since it is stored directly as the tool message content.
                _history.append({
                    "role": "tool",
                    "content": result,
                    "tool_call_id": tc.id,
                })
            # Loop to get the next (final) response

        else:
            # Pure text response — already yielded chunk by chunk; save to history
            _history.append({"role": "assistant", "content": accumulated_content})
            break
|
|
|
|
|
|
def chat(user_message: str) -> str:
    """Send a message to the LLM and return the full reply (non-streaming).

    Convenience wrapper over :func:`chat_stream` that collects every
    streamed chunk into a single string.
    """
    pieces: list[str] = []
    for piece in chat_stream(user_message):
        pieces.append(piece)
    return "".join(pieces)
|
|
|
|
|
|
def _short_args(args: dict) -> str:
|
|
text = json.dumps(args, ensure_ascii=False)
|
|
return text[:80] + "…" if len(text) > 80 else text
|