import asyncio
import json
import re
import time
import uuid

from redis import Redis
from openai import AsyncOpenAI

from agent.memory.memory import RedisMemory
from agent.tools.tools import tool_manager, ToolManager
from agent.memory.memory import BaseMemory

from abc import ABC, abstractmethod
import json


class AsyncBaseAgent(ABC):
    def __init__(
            self,
            api_key,
            base_url,
            model_name,
            tool_manager: ToolManager,
            memory: BaseMemory = None,
            sys_prompt=None,
    ):
        """Store the agent configuration and create the async API client.

        Args:
            api_key: Credential passed to ``AsyncOpenAI``.
            base_url: Endpoint URL passed to ``AsyncOpenAI``.
            model_name: Model identifier sent with every completion request.
            tool_manager: Object used to list and call tools.
            memory: Optional conversation store; ``None`` by default.
            sys_prompt: Optional system prompt text.
        """
        # Collaborators supplied by the caller.
        self.tool_manager = tool_manager
        self.memory = memory

        # Model endpoint configuration.
        self.model = model_name
        self.system_prompt = sys_prompt
        self.client = AsyncOpenAI(base_url=base_url, api_key=api_key)

        # Streaming state: thinking is off by default, no stream group is
        # active yet, and no tool-call fragments have been accumulated.
        self.think = False
        self.cur_stream_id = None
        self.tool_calls_acc = {}

    def build_messages(self, user_input: str = None, system_prompt: str = None, tool_msgs: list[dict] = None):
        """Assemble the message list for the next model call and persist new turns.

        The returned list is, in order: the optional system message, whatever
        ``self.memory.load()`` returns, the optional user message, and then
        ``tool_msgs``. The user message and the tool messages (never the
        system message) are also written to memory.

        Args:
            user_input: Text of the user turn; skipped when empty or ``None``.
            system_prompt: System prompt text. ``/nothink`` is appended to it
                unless ``self.think`` is truthy.
            tool_msgs: Tool-result messages appended after the user turn.

        Returns:
            The assembled list of message dicts.
        """
        # Compare against None instead of relying on truthiness: a memory
        # object that implements __len__/__bool__ is falsy while it is still
        # empty, which previously caused the very first messages to be
        # silently left out of the store.
        has_memory = self.memory is not None

        messages = []
        # NOTE(review): Qwen3's documentation spells the soft switch
        # `/no_think`; confirm `/nothink` is what the served model expects.
        think_prompt = '' if self.think else '/nothink'
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt + think_prompt})

        if has_memory:
            messages.extend(self.memory.load())

        # Messages introduced by this call; these are the ones to persist.
        new_msgs_to_save = []
        if user_input:
            new_msgs_to_save.append({"role": "user", "content": user_input})
        if tool_msgs:
            new_msgs_to_save.extend(tool_msgs)
        messages.extend(new_msgs_to_save)

        if has_memory and new_msgs_to_save:
            # Prefer a bulk ``extend`` when the memory implementation offers
            # one; otherwise fall back to appending one message at a time.
            if hasattr(self.memory, "extend"):
                self.memory.extend(new_msgs_to_save)
            else:
                for msg in new_msgs_to_save:
                    self.memory.append(msg)

        return messages

    def reload_messages(self):
        """Rebuild the message list from the stored system prompt and memory.

        Unlike ``build_messages`` this adds no new turns and writes nothing to
        memory.

        Returns:
            A new list holding the optional system message (with ``/nothink``
            appended unless ``self.think`` is truthy) followed by whatever
            ``self.memory.load()`` returns when a memory is attached.
        """
        suffix = '' if self.think else '/nothink'
        rebuilt = []
        if self.system_prompt:
            rebuilt = [{"role": "system", "content": self.system_prompt + suffix}]
        if self.memory:
            rebuilt += self.memory.load()
        return rebuilt
    @abstractmethod
    def process_stream_chunk(self, delta):
        """Interpret one streamed ``delta``; subclasses must implement this.

        Implementations are expected to return a ``(text, chunk_type)`` pair.
        """

    async def stream(self, user_input: str = None, tool_calls: list[dict] = None):
        """Single entry point that routes to the matching streaming flow.

        Args:
            user_input: When not ``None``, events come from ``chat_stream``.
                This takes precedence over ``tool_calls``.
            tool_calls: When ``user_input`` is ``None`` and this is not
                ``None``, events come from ``continue_tool_response``.

        Yields:
            The events produced by the selected flow, unchanged.

        Raises:
            ValueError: If both arguments are ``None``.
        """
        if user_input is None and tool_calls is None:
            raise ValueError("必须传入 user_input 或 tool_calls 之一")

        if user_input is not None:
            source = self.chat_stream(user_input)
        else:
            source = self.continue_tool_response(tool_calls)

        async for item in source:
            yield item

    # First turn: stream the reply to a new user message
    async def chat_stream(self, user_input: str):
        """Stream the response to a new user message.

        Message construction is deferred: a zero-argument callable wrapping
        ``build_messages`` is handed to ``stream_response_from_messages``,
        which decides when to invoke it.

        Args:
            user_input: Text of the user turn.

        Yields:
            Every event produced by ``stream_response_from_messages``.
        """
        events = self.stream_response_from_messages(
            lambda: self.build_messages(user_input=user_input, system_prompt=self.system_prompt)
        )
        async for item in events:
            yield item

    # Relay: run the requested tools and continue the stream with their results
    async def continue_tool_response(self, tool_calls: list[dict], stream_id=None):
        """Run the requested tools concurrently, then resume the model stream.

        Each entry of ``tool_calls`` is expected to carry ``"id"`` and a
        ``"function"`` dict with ``"name"`` and ``"arguments"``. Missing,
        ``None`` or blank arguments are normalised to ``"{}"`` and written
        back into the caller's dict; an already-decoded dict is used as is;
        unparsable JSON falls back to ``{}``.

        Args:
            tool_calls: Tool-call descriptors to execute.
            stream_id: Accepted for backward compatibility; not used. (The
                previous implementation overwrote it with a fresh UUID that
                was never read.)

        Yields:
            One ``tool_calls_result`` event per tool, in the order of
            ``tool_calls``, followed by every event produced by
            ``stream_response_from_messages``.
        """

        async def timed_call(tool_id, name, args):
            """Call one tool, report its duration, and convert failures to a result."""
            start_time = time.perf_counter()
            try:
                result = await self.tool_manager.call_tool(name, args)
            except Exception as e:
                # Deliberately broad: any tool failure is reported back to the
                # model as an error payload instead of aborting the turn.
                result = {"error": str(e)}
            duration = time.perf_counter() - start_time
            print(f"🛠️ 工具 `{name}`（ID: {tool_id}）执行完成，用时：{duration:.2f} 秒")
            return tool_id, result

        pending = []
        for tool_info in tool_calls:
            function = tool_info["function"]
            raw_args = function.get("arguments")

            if isinstance(raw_args, dict):
                # Arguments were already decoded by the caller.
                args = raw_args
            elif not isinstance(raw_args, str) or not raw_args.strip():
                # Missing / None / blank arguments: normalise to an empty JSON
                # object and sync the value back into the caller's structure.
                function["arguments"] = "{}"
                args = {}
            else:
                try:
                    args = json.loads(raw_args)
                except json.JSONDecodeError as e:
                    print(f"工具参数解析失败: {e}")
                    args = {}

            pending.append(timed_call(tool_info["id"], function["name"], args))

        # Run every tool call concurrently; gather preserves input order.
        results = await asyncio.gather(*pending)

        # Turn each result into a "tool" role message and surface it as an event.
        tool_msgs = []
        for tool_id, result in results:
            tool_msg = {
                "role": "tool",
                "tool_call_id": tool_id,
                # default=str is a deliberate fallback: a tool result that is
                # not JSON-serializable is reported as its string form rather
                # than raising after all the tools have already run.
                "content": json.dumps(result, ensure_ascii=False, default=str),
            }
            tool_msgs.append(tool_msg)
            yield {
                "stream_id": str(uuid.uuid4()),
                "type": "tool_calls_result",
                "stream_group_id": self.cur_stream_id,
                "info": tool_msg,
            }

        def message_factory():
            """Build the follow-up context containing the new tool messages."""
            return self.build_messages(user_input=None, system_prompt=self.system_prompt, tool_msgs=tool_msgs)

        # Continue the streamed response with the tool results in context.
        async for event in self.stream_response_from_messages(message_factory):
            yield event

    async def append_system_reflection_message(self, messages: list[dict]):
        """Run a "thinking" pass that asks the model to plan its next step.

        A system message with planning instructions is appended after the
        conversation, and a leading system prompt (if there is one) is removed
        so the reflection instructions are the only system guidance in this
        request. The caller's ``messages`` list is not modified.

        Previously the first message was dropped unconditionally, which
        discarded a real conversation turn whenever no system prompt was
        present and produced an empty request for an empty ``messages`` list.

        Args:
            messages: Conversation context to reflect on.

        Yields:
            Every event from ``_handle_streaming_response`` with
            ``result['think'] = True`` added. For ``done`` / ``tool_calls``
            events the carried message is re-labelled as a ``system`` message
            and stored through ``_append_to_memory``.
        """
        reflection_msg = {
            "role": "system",
            "content": (
                "🧠【你的任务】：你正在进行**思考阶段**，请根据已有对话和工具调用结果，**用自然语言描述你下一步打算做什么**。\n\n"

                "📌【规则】：\n"
                "- ❌ 禁止直接回答用户问题；\n"
                "- ❌ 禁止介绍或列出你有哪些工具；\n"
                "- ❌ 禁止输出代码、JSON 或结构体；\n\n"

                "📉【失败处理】：\n"
                "- 工具失败时，优先分析失败原因；\n"
                "- 参数问题请说明如何修正；\n"
                "- 信息不足请主动向用户提问；\n\n"

                "🚀【执行策略】：\n"
                "- 多个无依赖任务应并行执行；\n"
                "- 有依赖的请说明执行顺序；\n"
                "- 可合并、去重、优化的，请合理规划。\n\n"

                "✅【输出内容】（仅自然语言）：\n"
                "- 你对当前任务的理解；\n"
                "- 你的下一步计划（调用什么、顺序、是否并行）；\n"
                "- 如果信息不足，你打算问什么。\n\n"

                "⚠️ 请**只输出自然语言描述的计划内容**，禁止包含任何调用结构或代码格式。"
            ) + ('' if self.think else '\n/nothink')
        }

        # Drop only a leading *system* message; any other first message is
        # part of the conversation and must be kept.
        history = list(messages)
        if history and isinstance(history[0], dict) and history[0].get("role") == "system":
            history = history[1:]
        new_messages = history + [reflection_msg]

        async for result in self._handle_streaming_response(new_messages):
            if result["type"] in ("done", "tool_calls"):
                msg = result.get("msg")
                if msg is not None:
                    # The plan is stored as a system message.
                    # NOTE(review): for a "tool_calls" result this stores a
                    # system message that carries `tool_calls`; confirm the
                    # chat endpoint accepts that shape on later requests.
                    msg['role'] = 'system'
                    self._append_to_memory(msg)
            # Mark every event of this pass as belonging to the thinking phase.
            result['think'] = True
            yield result

    async def stream_response_from_messages(self, messages_factory):
        """Stream a reflection pass followed by the actual response pass.

        Args:
            messages_factory: Zero-argument callable returning the message
                list used for the reflection pass.

        Yields:
            First every event from ``append_system_reflection_message``, then
            every event from ``_handle_streaming_response`` run on the context
            returned by ``reload_messages``. For ``done`` / ``tool_calls``
            events of the second pass the final message is stored through
            ``_append_to_memory`` before the event is yielded.
        """
        # Pass 1: reflect on the current context.
        reflection = self.append_system_reflection_message(messages_factory())
        async for event in reflection:
            yield event

        # Pass 2: answer using the context reloaded from memory.
        terminal_types = ("done", "tool_calls")
        async for event in self._handle_streaming_response(self.reload_messages()):
            if event["type"] in terminal_types:
                final_msg = event.get("msg")
                if not final_msg:
                    # No ready-made message on the event: build an assistant
                    # tool-call message from its "info" payload.
                    final_msg = {"role": "assistant", "tool_calls": event.get("info")}
                self._append_to_memory(final_msg)

            yield event

    def _append_to_memory(self, message: dict):
        """Append ``message`` to ``self.memory``.

        Kept as the single place where messages are written to memory after a
        response, so storage behaviour can be extended in one spot.
        """
        store = self.memory
        store.append(message)

    async def _handle_streaming_response(self, messages: list[dict]):
        """Send ``messages`` to the model and translate the stream into events.

        Each chunk's delta is classified by ``process_stream_chunk``; only the
        ``"content"`` and ``"tool_call"`` chunk types produce events here.

        Args:
            messages: Chat messages to send.

        Yields:
            * ``content`` events carrying each text fragment;
            * ``tool_call_stream`` events carrying the current accumulated
              tool calls after every tool-call fragment;
            * finally exactly one of ``tool_calls`` (validated tool calls plus
              the assistant message ``msg``) or ``done`` (assistant text
              message ``msg``).

        All events of one call share the same ``stream_id``.
        """
        stream_id = str(uuid.uuid4())
        tools = await self.tool_manager.list_tools()

        # Start from a clean accumulator so fragments left behind by a stream
        # that was aborted mid-way cannot leak into this response.
        self.tool_calls_acc = {}

        request = {
            "model": self.model,
            "messages": messages,
            "temperature": 0.7,
            "stream": True,
        }
        if tools:
            # Only advertise tools when there are any: an empty `tools` array
            # (and `tool_choice` without tools) is rejected by OpenAI-style
            # chat endpoints.
            request["tools"] = tools
            request["tool_choice"] = "auto"

        response = await self.client.chat.completions.create(**request)

        text_parts = []
        tool_call_active = False
        async for chunk in response:
            # Some chunks carry no choices at all; there is nothing to
            # classify in them.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            text, chunk_type = self.process_stream_chunk(delta)

            if chunk_type == "content":
                text_parts.append(text)
                yield {
                    "stream_id": stream_id,
                    "type": "content",
                    "text": text,
                    "stream_group_id": self.cur_stream_id
                }

            elif chunk_type == "tool_call":
                tool_call_active = True

                # Every fragment reports the full, up-to-date accumulator.
                yield {
                    "stream_id": stream_id,
                    "type": "tool_call_stream",
                    "info": list(self.tool_calls_acc.values()),
                    "stream_group_id": self.cur_stream_id
                }

        # Tool calls finished streaming: validate them and emit the final set.
        if tool_call_active and self.tool_calls_acc:
            tool_calls = list(self.tool_calls_acc.values())
            self.tool_calls_acc = {}

            for tool_call in tool_calls:
                if (
                        tool_call.get("type") == "function"
                        and isinstance(tool_call.get("function"), dict)
                ):
                    raw_args = tool_call["function"].get("arguments", "")
                    try:
                        if not raw_args or not raw_args.strip():
                            raise ValueError("空字符串")
                        json.loads(raw_args)
                    except (json.JSONDecodeError, ValueError) as e:
                        # Blank or malformed arguments are replaced with an
                        # empty JSON object so the stored message stays valid.
                        print(f"arguments 非法，已替换为 '{{}}'：{e}，tool_call id: {tool_call.get('id')}")
                        tool_call["function"]["arguments"] = "{}"

            assistant_tool_msg = {
                "role": "assistant",
                "tool_calls": tool_calls
            }

            yield {
                "stream_id": stream_id,
                "type": "tool_calls",
                "info": tool_calls,
                "stream_group_id": self.cur_stream_id,
                "msg": assistant_tool_msg
            }
            return

        # No tool calls: a plain text reply.
        final_msg = {"role": "assistant", "content": "".join(text_parts)}
        yield {
            "stream_id": stream_id,
            "type": "done",
            "stream_group_id": self.cur_stream_id,
            "msg": final_msg
        }

    def extract_tool_call_info(self, delta):
        """Flatten the first tool call of ``delta`` into a plain dict.

        Returns:
            A dict with ``id``, ``type``, ``function_name`` and ``arguments``.
            Attributes missing on the tool call fall back to ``None``; when the
            call has no (truthy) ``function`` the name is ``None`` and the
            arguments are ``""``.
        """
        first_call = delta.tool_calls[0]
        info = {
            "id": getattr(first_call, "id", None),
            "type": getattr(first_call, "type", None),
        }

        fn = first_call.function
        if fn:
            info["function_name"] = getattr(fn, "name", None)
            info["arguments"] = getattr(fn, "arguments", "")
        else:
            info["function_name"] = None
            info["arguments"] = ""
        return info

    async def _handle_tool_call(self, tool_call_info):
        """Execute one tool call described by a flattened info dict.

        Args:
            tool_call_info: Dict with ``function_name`` and ``arguments`` (a
                JSON string), as produced by ``extract_tool_call_info``.

        Returns:
            The value returned by ``self.tool_manager.call_tool``. (The result
            used to be printed and then discarded.)
        """
        print("\n【触发工具调用】")
        print(f"函数名：{tool_call_info['function_name']}")
        print(f"参数：{tool_call_info['arguments']}")
        try:
            args = json.loads(tool_call_info["arguments"])
        except (json.JSONDecodeError, TypeError) as e:
            # TypeError covers non-string arguments such as None, which
            # json.loads rejects before any parsing takes place.
            print(f"参数解析失败: {e}")
            args = {}

        result = await self.tool_manager.call_tool(tool_call_info["function_name"], args)
        print(f"\n【工具调用结果】：{result}")
        return result

    def handle_special_chunk(self, chunk_type, text):
        """Hook for special chunk types (such as ``<think>`` output).

        Does nothing here; subclasses may override it.
        """
        return None


class AsyncQwen3Agent(AsyncBaseAgent):
    def __init__(self, *args, **kwargs):
        """Initialise the base agent, then the think-tag tracking state.

        All positional and keyword arguments are forwarded unchanged to the
        parent constructor.
        """
        super().__init__(*args, **kwargs)
        # Tag buffers start out empty.
        self.start_tag_buffer = self.end_tag_buffer = ""
        # Not inside a <think> section until an opening tag is seen.
        self.in_think_phase = False

    def process_stream_chunk(self, delta):
        """Classify one streamed delta as text, thinking output, or tool call.

        Text handling: the chunk type is ``think_start`` / ``think_end`` when
        the chunk contains a ``<think>`` / ``</think>`` tag, ``think_continue``
        while inside a think section, and ``content`` otherwise. When a chunk
        contains both tags, the one that appears last decides the resulting
        state, so a complete ``<think>...</think>`` block arriving in a single
        chunk no longer leaves the agent stuck in the think phase. Tags split
        across two chunks are not detected.

        Tool-call handling: every fragment in ``delta.tool_calls`` is merged
        into ``self.tool_calls_acc`` (keyed by the fragment's ``index``).
        Previously only the first fragment of a delta was merged and the rest
        were dropped.

        Returns:
            ``(text, chunk_type)``:
            * for text chunks, the raw content and one of the types above;
            * for tool-call chunks, ``(snapshot, "tool_call")`` where
              ``snapshot`` describes the first fragment's accumulated call
              and carries a ``stream_status`` of ``start`` / ``streaming`` /
              ``complete``;
            * ``("", None)`` when the delta carries neither.
        """
        # ---- Plain text, possibly with <think> markers ----
        content = getattr(delta, "content", None)
        if content:
            last_open = content.rfind("<think>")
            last_close = content.rfind("</think>")
            if last_open > last_close:
                # An opening tag is the last marker in this chunk.
                self.in_think_phase = True
                chunk_type = "think_start"
            elif last_close != -1:
                # A closing tag is the last marker in this chunk.
                self.in_think_phase = False
                chunk_type = "think_end"
            elif self.in_think_phase:
                chunk_type = "think_continue"
            else:
                chunk_type = "content"
            return content, chunk_type

        # ---- Streamed tool calls ----
        fragments = getattr(delta, "tool_calls", None)
        if not fragments:
            return "", None

        if not hasattr(self, "tool_calls_acc"):
            self.tool_calls_acc = {}

        first_snapshot = None
        for tool_call in fragments:
            index = tool_call.index
            is_new = index not in self.tool_calls_acc

            if is_new:
                self.tool_calls_acc[index] = {
                    "id": tool_call.id or "",
                    "type": tool_call.type or "function",
                    "function": {
                        "name": "",
                        "arguments": ""
                    }
                }

            acc = self.tool_calls_acc[index]

            if tool_call.id:
                acc["id"] = tool_call.id
            if tool_call.type:
                acc["type"] = tool_call.type
            if tool_call.function:
                # Name and arguments arrive in pieces and are concatenated.
                if tool_call.function.name:
                    acc["function"]["name"] += tool_call.function.name
                if tool_call.function.arguments:
                    acc["function"]["arguments"] += tool_call.function.arguments

            # Heuristic status: "complete" only means the accumulated text
            # currently ends with a closing character, not that it is valid.
            if is_new:
                stream_status = "start"
            elif acc["function"]["arguments"].strip().endswith(("}", ")", "\"")):
                stream_status = "complete"
            else:
                stream_status = "streaming"

            if first_snapshot is None:
                # The return value keeps describing the first fragment.
                first_snapshot = {
                    "index": index,
                    "id": acc["id"],
                    "type": acc["type"],
                    "function": {
                        "name": acc["function"]["name"],
                        "arguments": acc["function"]["arguments"]
                    },
                    "stream_status": stream_status
                }

        return first_snapshot, "tool_call"

    def extract_tool_call_info(self, delta):
        """Return the accumulated tool call that ``delta`` refers to, flattened.

        The accumulated calls live in ``self.tool_calls_acc`` keyed by index.
        (This method used to read a nonexistent ``self.tool_call_acc``
        attribute and therefore always raised ``AttributeError``.)

        The entry is chosen by the ``index`` of the first fragment in
        ``delta.tool_calls``; if the delta names no known index, the most
        recently added entry is used.

        Returns:
            A dict with ``id``, ``type``, ``function_name`` and ``arguments``.
            When nothing has been accumulated, ``id``, ``type`` and
            ``function_name`` are ``None`` and ``arguments`` is ``""``.
        """
        accumulated = getattr(self, "tool_calls_acc", None) or {}
        fragments = getattr(delta, "tool_calls", None) or []
        index = getattr(fragments[0], "index", None) if fragments else None

        entry = accumulated.get(index)
        if entry is None and accumulated:
            # Dicts keep insertion order, so the last value is the newest call.
            entry = next(reversed(accumulated.values()))

        if entry is None:
            return {"id": None, "type": None, "function_name": None, "arguments": ""}

        return {
            "id": entry["id"],
            "type": entry["type"],
            "function_name": entry["function"]["name"],
            "arguments": entry["function"]["arguments"]
        }

    def handle_special_chunk(self, chunk_type, text):
        """Print the model's thinking output to stdout as it streams.

        ``think_start`` prints an opening banner, ``think_continue`` prints
        ``text`` without a trailing newline, and ``think_end`` prints ``text``
        followed by a closing banner. Any other chunk type prints nothing.
        """
        if chunk_type == "think_start":
            print("\n【模型思考阶段开始】", end="", flush=True)
            return

        if chunk_type == "think_continue":
            print(text, end="", flush=True)
            return

        if chunk_type == "think_end":
            print(text, flush=True)
            print("【模型思考阶段结束】\n")
