import asyncio
import logging
from contextlib import AsyncExitStack
from typing import Any

from mcp import ClientSession, StdioServerParameters, types
from mcp.client.stdio import stdio_client
from .base import BaseMcpClient


class StdioMcpClient(BaseMcpClient):
    """MCP client that talks to a server subprocess over stdio.

    The stdio transport and the ``ClientSession`` are both entered into
    ``self.exit_stack`` by :meth:`initialize` and are torn down together by
    :meth:`cleanup`. The client can also be used as an async context manager::

        async with StdioMcpClient(name, config) as client:
            tools = await client.list_tools()

    ``config`` keys read by :meth:`initialize`:
        command: executable used to launch the server (required).
        args: command-line arguments for the server (optional, default ``[]``).
        env: environment for the subprocess (optional, default ``None``).
    """

    def __init__(self, name: str, config: dict[str, Any]) -> None:
        super().__init__(name, config)
        # Active session; None until initialize() succeeds and again after cleanup().
        self.session: ClientSession | None = None
        # Owns the stdio transport and the session so they are closed together (LIFO).
        self.exit_stack: AsyncExitStack = AsyncExitStack()
        # Serialises cleanup() so concurrent callers do not close the stack simultaneously.
        self._cleanup_lock: asyncio.Lock = asyncio.Lock()

    async def sampling_callback(
        self,
        context: Any = None,
        message: types.CreateMessageRequestParams | None = None,
    ) -> types.CreateMessageResult:
        """Answer a server-initiated sampling request with a fixed mock reply.

        ``ClientSession`` calls its sampling callback as ``callback(context, params)``.
        A one-parameter signature makes every sampling request fail with ``TypeError``.
        Both parameters default to ``None``, so the older call styles
        ``sampling_callback(params)`` and ``sampling_callback(message=params)``
        still work. The arguments are ignored either way.

        This is a stub: no model is consulted and the returned text and model
        name are placeholders.
        """
        return types.CreateMessageResult(
            role="assistant",
            # Placeholder text; it reads "simulated model response".
            content=types.TextContent(type="text", text="模拟模型响应"),
            model="gpt-3.5-turbo",
            stopReason="endTurn",
        )

    async def initialize(self) -> None:
        """Launch the server subprocess and perform the MCP handshake.

        On success ``self.session`` is set to the initialised session.

        Raises:
            ValueError: if ``config`` does not provide a ``command``.
            Exception: any error raised while starting the transport or during
                the handshake. It is logged, resources acquired so far are
                released via :meth:`cleanup`, and the error is re-raised.
        """
        command = self.config.get("command")
        # Servers that take no arguments are valid; default to an empty list.
        args = self.config.get("args") or []
        env = self.config.get("env", None)

        if not command:
            raise ValueError(f"Invalid stdio config for client '{self.name}'")

        try:
            server_params = StdioServerParameters(command=command, args=args, env=env)
            read_stream, write_stream, *_ = await self.exit_stack.enter_async_context(
                stdio_client(server_params)
            )
            session = await self.exit_stack.enter_async_context(
                ClientSession(read_stream, write_stream, sampling_callback=self.sampling_callback)
            )
            await session.initialize()
        except Exception as e:
            logging.error(f"Error initializing Stdio client {self.name}: {e}")
            await self.cleanup()
            raise

        # Publish the session only after the handshake succeeded, so other
        # methods never observe a half-initialised session.
        self.session = session

    async def list_tools(self) -> list[Any]:
        """Return the list of tools advertised by the server.

        Raises:
            RuntimeError: if the client has not been initialised.
        """
        if not self.session:
            raise RuntimeError(f"Client {self.name} not initialized")

        tools_response = await self.session.list_tools()
        # NOTE(review): only the first page is returned. ``nextCursor`` is not
        # followed; add pagination if servers with paginated tool lists must be supported.
        return list(tools_response.tools)

    async def call_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        retries: int = 2,
        delay: float = 1.0,
    ) -> Any:
        """Invoke ``tool_name`` on the server and return its text output.

        Args:
            tool_name: name of the tool to call.
            arguments: arguments forwarded to the tool.
            retries: total number of attempts, not extra retries. Values below 1
                are treated as 1 so the tool is always called at least once.
            delay: seconds to wait between failed attempts.

        Returns:
            The ``text`` of the first content item in the result that has one,
            or ``""`` if the result carries no text content.

        Raises:
            RuntimeError: if the client has not been initialised.
            Exception: the last error raised by ``session.call_tool`` once all
                attempts are exhausted.
        """
        if not self.session:
            raise RuntimeError(f"Client {self.name} not initialized")

        attempts = max(1, retries)
        for attempt in range(1, attempts + 1):
            try:
                result = await self.session.call_tool(tool_name, arguments)
            except Exception as e:
                logging.warning(f"[{self.name}] Error calling tool '{tool_name}': {e} (attempt {attempt})")
                if attempt >= attempts:
                    raise
                await asyncio.sleep(delay)
            else:
                # Parse outside the try block. A call that succeeded must not be
                # retried (tools may have side effects) just because its result
                # has an unexpected shape.
                return self._first_text(result)

    @staticmethod
    def _first_text(result: Any) -> str:
        """Return the ``text`` of the first content item that has one, else ``""``."""
        for item in getattr(result, "content", None) or []:
            text = getattr(item, "text", None)
            if text is not None:
                return text
        return ""

    async def cleanup(self) -> None:
        """Close the session and stop the server subprocess.

        Calls are serialised by ``_cleanup_lock``. Calling this again after a
        previous cleanup finds an empty exit stack and does nothing further.
        Errors are logged instead of raised so teardown stays best-effort.
        """
        async with self._cleanup_lock:
            # Drop the reference first so callers cannot use a session that is
            # being torn down. The session itself is closed by the exit stack.
            self.session = None
            try:
                # Closes the resources managed by the exit stack (session, then transport).
                await self.exit_stack.aclose()
            except asyncio.CancelledError:
                # Deliberately swallowed (as in the original) so teardown completes.
                # NOTE(review): this also swallows a genuine cancellation of the
                # calling task; consider re-raising once the transport's shutdown
                # behaviour is confirmed.
                pass
            except Exception as e:
                logging.error(f"Error during cleanup of client {self.name}: {e}")

    # Async context-manager protocol.
    async def __aenter__(self) -> "StdioMcpClient":
        """Initialise the client on ``async with`` entry and return it."""
        await self.initialize()
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Release all resources on exit; exceptions are not suppressed."""
        await self.cleanup()
