from fastapi import APIRouter, WebSocket, WebSocketDisconnect
import asyncio
import json

from agent.llm.llm import AsyncQwen3Agent
from agent.llm import main


class SessionManager:
    """State for one agent session shared by any number of WebSocket clients.

    Owns the session's agent, the events streamed so far (``history``), and
    the set of currently connected WebSockets. Every event produced by the
    agent is recorded, pushed onto ``queue``, and broadcast to all clients.
    """

    def __init__(self, session_id: str):
        self.session_id = session_id
        # Lifecycle marker: "idle" until start() runs, "running" while
        # streaming, "completed" once streaming ends (success, error, or
        # cancellation).
        self.state = "idle"
        # Every streamed event is also put here. NOTE(review): nothing in this
        # class reads from it — presumably a consumer exists elsewhere; if not,
        # this unbounded queue grows for the lifetime of the session.
        self.queue = asyncio.Queue()
        self.main_agent: AsyncQwen3Agent = main.create_agent(session_id)
        # Events already streamed in this session; replayed to clients that
        # join later via send_history_and_current().
        self.history = []
        self.connections: set[WebSocket] = set()  # multiple concurrent connections

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and register it for broadcasts."""
        await websocket.accept()
        self.connections.add(websocket)

    def disconnect(self, websocket: WebSocket):
        """Unregister ``websocket``; a no-op if it is not registered."""
        self.connections.discard(websocket)

    async def broadcast(self, message: dict):
        """Send ``message`` as JSON text to every connected WebSocket.

        A connection whose send raises is dropped, so one dead client does
        not stop delivery to the others.
        """
        text = json.dumps(message)
        # Iterate over a snapshot because disconnect() mutates the set.
        for conn in self.connections.copy():
            try:
                await conn.send_text(text)
            except Exception as e:
                print(f"移除连接: {e}")
                self.disconnect(conn)

    async def _handle_stream(self, content: str = None, tool_calls: dict = None):
        """Stream agent events for a user message or tool-call result.

        Args:
            content: User input to send to the agent. When falsy, the agent is
                streamed with ``tool_calls`` instead.
            tool_calls: Tool-call payload to continue the conversation with.

        Each event is tagged with ``session_id``, recorded in ``history``, put
        on ``queue`` and broadcast. An event of type ``"tool_calls"`` triggers
        a nested stream with that event's ``info`` before the current stream
        resumes.
        """
        stream = (
            self.main_agent.stream(user_input=content)
            if content else
            self.main_agent.stream(tool_calls=tool_calls)
        )

        async for event in stream:
            event["session_id"] = self.session_id

            # Record the event so late-joining clients can replay it;
            # previously history was never populated.
            self.history.append(event)

            # Push the current event.
            await self.queue.put(event)
            await self.broadcast(event)

            # A tool call continues the conversation in a nested stream.
            if event.get("type") == "tool_calls":
                tool_info = event["info"]
                print(f"\n【触发工具调用】函数：{tool_info}")

                # Recursively process the tool-call stream.
                await self._handle_stream(tool_calls=tool_info)

    async def start(self, content: str = None, tool_calls: dict = None):
        """Run one streaming turn and update ``state`` around it.

        Errors from the stream are printed and swallowed. The state always
        ends as ``"completed"``, including when the task is cancelled
        (``asyncio.CancelledError`` is not an ``Exception`` subclass and is
        re-raised after the state update).
        """
        self.state = "running"
        try:
            await self._handle_stream(content=content, tool_calls=tool_calls)
        except Exception as e:
            print(f"[错误] 处理流时出错: {e}")
        finally:
            # Previously a cancellation left the state stuck at "running".
            self.state = "completed"

    async def send_history_and_current(self, websocket: WebSocket):
        """Replay every recorded event, in order, to ``websocket`` as JSON text."""
        for msg in self.history:
            await websocket.send_text(json.dumps(msg))


