import json
from openai import OpenAI
from redis import Redis

from agent.memory.memory import RedisMemory


class BaseAgent:
    """Streaming chat agent for OpenAI-compatible chat-completions endpoints.

    The running conversation is kept in ``self.messages`` as OpenAI-style
    message dicts. Subclasses may override :meth:`process_stream_chunk` and
    :meth:`chat_stream` to handle model-specific streaming formats.
    """

    def __init__(self, api_key, base_url, model_name, tools=None, memory=None, system_prompt=None):
        """Create the API client and seed the conversation history.

        Args:
            api_key: API key passed to the ``OpenAI`` client.
            base_url: Base URL of the OpenAI-compatible endpoint.
            model_name: Model identifier sent with every request.
            tools: Tool (function) schemas offered to the model; defaults to none.
            memory: Optional history store whose ``load()`` result is appended
                to the conversation (presumably a list of message dicts —
                verify against the memory implementation).
            system_prompt: Optional system message placed first in the history.
        """
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        self.model = model_name
        self.tools = tools or []
        self.memory = memory
        self.messages = []

        # The system prompt goes first so it precedes any restored history.
        if system_prompt:
            self.messages.append({"role": "system", "content": system_prompt})

        # Restore previous conversation turns, if a memory store was given.
        if self.memory:
            self.messages += self.memory.load()

    def add_user_message(self, content):
        """Append a user turn to the conversation."""
        self.messages.append({"role": "user", "content": content})

    def add_assistant_message(self, content=None, tool_call=None):
        """Append an assistant turn.

        ``content`` is included only when truthy. ``tool_call`` is wrapped in a
        single-element ``tool_calls`` list; it presumably must already be in
        OpenAI wire format (``{"id", "type", "function": {...}}``) — confirm.
        """
        msg = {"role": "assistant"}
        if content:
            msg["content"] = content
        if tool_call:
            msg["tool_calls"] = [tool_call]
        self.messages.append(msg)

    def add_tool_message(self, tool_call_id, name, content):
        """Append the result of a tool call to the conversation.

        The chat-completions API identifies tool results with role ``"tool"``
        (singular); ``tool_call_id`` links the result to the assistant's call.
        """
        self.messages.append({
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": name,
            "content": content
        })

    def process_stream_chunk(self, delta):
        """Parse one streamed delta from a plain (non-reasoning) model.

        Returns:
            ``(text, tool_call_triggered, tool_call_info)``. ``text`` is the
            delta's content (or ``""``). ``tool_call_info`` describes only the
            first tool-call entry of this delta, and its ``arguments`` is just
            this chunk's fragment — callers must concatenate fragments across
            chunks (see ``_merge_tool_call``).
        """
        text = ""
        tool_call_triggered = False
        tool_call_info = None

        if hasattr(delta, "content") and delta.content:
            text = delta.content

        if hasattr(delta, "tool_calls") and delta.tool_calls:
            tool_call_triggered = True
            tc = delta.tool_calls[0]
            tool_call_info = {
                "id": getattr(tc, "id", None),
                "type": getattr(tc, "type", None),
                "function_name": getattr(tc.function, "name", None) if tc.function else None,
                "arguments": getattr(tc.function, "arguments", "") if tc.function else ""
            }
        return text, tool_call_triggered, tool_call_info

    @staticmethod
    def _merge_tool_call(acc, fragment):
        """Fold one streamed tool-call fragment into ``acc`` and return it.

        ``id``/``type``/``function_name`` take the latest non-empty value;
        ``arguments`` fragments are concatenated in arrival order.
        NOTE: only a single tool call is tracked (fragments are not keyed by
        their stream ``index``), matching ``process_stream_chunk``, which
        reads only ``tool_calls[0]``.
        """
        if acc is None:
            acc = {"id": None, "type": None, "function_name": None, "arguments": ""}
        if not fragment:
            return acc
        for key in ("id", "type", "function_name"):
            if fragment.get(key):
                acc[key] = fragment[key]
        acc["arguments"] += fragment.get("arguments") or ""
        return acc

    def call_tool(self, tool_call_info):
        """Execute the local tool described by ``tool_call_info``.

        Only ``get_weather`` is supported in this example. Arguments that are
        missing, not valid JSON, or not a JSON object are treated as empty.

        Returns:
            The tool's result, or ``None`` when ``tool_call_info`` is empty or
            names an unknown tool.
        """
        if not tool_call_info:
            return None

        fname = tool_call_info.get("function_name")
        try:
            args = json.loads(tool_call_info.get("arguments") or "")
        except (TypeError, ValueError):  # ValueError covers json.JSONDecodeError
            args = {}
        # Valid JSON that is not an object (e.g. a list) has no named params.
        if not isinstance(args, dict):
            args = {}

        if fname == "get_weather":
            city = args.get("city", "")
            return self.get_weather(city)
        # More tools can be dispatched here.
        return None

    def get_weather(self, city):
        """Return canned weather text for ``city`` (demo data only)."""
        fake_data = {
            "上海": "晴，27℃，无降水",
            "北京": "多云，22℃，轻微降水",
            "广州": "阴，29℃，无降水",
        }
        return fake_data.get(city, "无法获取该城市的天气信息")

    def chat_stream(self, user_input):
        """Send ``user_input`` and print the model's reply as it streams.

        Text is printed as it arrives. Tool-call fragments are merged across
        chunks and reported once the stream ends.

        Returns:
            The accumulated tool-call dict (``id``, ``type``,
            ``function_name``, ``arguments``) if the model requested a tool,
            otherwise ``None``.
        """
        self.add_user_message(user_input)
        tool_call_info = None
        answer_started = False

        print("【模型回应】")

        response = self.client.chat.completions.create(
            model=self.model,
            messages=self.messages,
            tools=self.tools,
            tool_choice="auto",
            temperature=0.7,
            stream=True,
        )

        for chunk in response:
            # Servers may emit chunks without choices (e.g. a trailing usage
            # chunk); there is nothing to parse in them.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            # process_stream_chunk returns exactly three values.
            text, triggered, tcall_info = self.process_stream_chunk(delta)

            if text:
                if not answer_started:
                    # Print the answer header once, not before every token.
                    print("【模型回答】：", end="", flush=True)
                    answer_started = True
                print(text, end="", flush=True)

            if triggered:
                tool_call_info = self._merge_tool_call(tool_call_info, tcall_info)

        if tool_call_info:
            print("\n【触发工具调用】")
            print(f"函数名：{tool_call_info.get('function_name')}")
            print(f"参数：{tool_call_info.get('arguments')}")
        return tool_call_info


class Qwen3Agent(BaseAgent):
    """Agent for Qwen3-style models that stream reasoning inside ``<think>`` tags.

    While streaming, the thinking phase is printed separately from the answer,
    and streamed tool-call fragments are accumulated into a single call.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._reset_stream_state()

    def _reset_stream_state(self):
        """Clear per-response parsing state (think flag and tool-call accumulator)."""
        # True while streamed content lies between <think> and </think>.
        self.in_think_phase = False
        # Tool-call fields; "arguments" is concatenated across chunks.
        self.tool_call_acc = {
            "id": None,
            "type": None,
            "function": {"name": None, "arguments": ""}
        }

    def process_stream_chunk(self, delta):
        """Classify one streamed delta and update the parsing state.

        ``<think>`` / ``</think>`` tags toggle ``self.in_think_phase`` and are
        stripped from the returned text. Tool-call fragments (first entry
        only) are added to ``self.tool_call_acc``.

        Returns:
            ``(text, chunk_type)`` where ``chunk_type`` is one of
            ``"think_start"``, ``"think_continue"``, ``"think_end"``,
            ``"content"``, ``"tool_call"``, or ``None`` for an empty delta.
            A tool-call fragment overrides the text-based type.
            NOTE: for ``"think_end"`` the text is the whole chunk minus the
            tag, so answer text following ``</think>`` in the same chunk is
            returned together with the trailing thinking text.
        """
        text = ""
        chunk_type = None

        # Text content and think-phase tracking.
        if hasattr(delta, "content") and delta.content:
            c = delta.content

            if "<think>" in c:
                self.in_think_phase = True
                c = c.replace("<think>", "")
                chunk_type = "think_start"

            if "</think>" in c:
                self.in_think_phase = False
                c = c.replace("</think>", "")
                chunk_type = "think_end"

            if chunk_type is None:
                chunk_type = "think_continue" if self.in_think_phase else "content"
            text = c

        # Tool calls: arguments arrive in pieces and are concatenated.
        if hasattr(delta, "tool_calls") and delta.tool_calls:
            tool_call = delta.tool_calls[0]
            if tool_call.id:
                self.tool_call_acc["id"] = tool_call.id
            if tool_call.type:
                self.tool_call_acc["type"] = tool_call.type
            if tool_call.function and tool_call.function.name:
                self.tool_call_acc["function"]["name"] = tool_call.function.name
            if tool_call.function and tool_call.function.arguments:
                self.tool_call_acc["function"]["arguments"] += tool_call.function.arguments
            chunk_type = "tool_call"

        return text, chunk_type

    def chat_stream(self, user_input):
        """Send ``user_input`` and print thinking, answer and any tool call.

        Returns:
            A dict with ``id``, ``type``, ``function_name`` and ``arguments``
            if the model requested a tool, otherwise ``None``.
        """
        self.add_user_message(user_input)
        # Start every response from a clean state; otherwise a previous tool
        # call's arguments would be prepended to this one's.
        self._reset_stream_state()
        tool_call_triggered = False

        print("\n【模型回应】")

        response = self.client.chat.completions.create(
            model=self.model,
            messages=self.messages,
            tools=self.tools,
            tool_choice="auto",
            temperature=0.7,
            stream=True,
        )

        for chunk in response:
            # Servers may emit chunks without choices (e.g. a trailing usage
            # chunk); there is nothing to parse in them.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            text, chunk_type = self.process_stream_chunk(delta)

            if chunk_type == "think_start":
                print("\n【模型思考阶段开始】", end="", flush=True)
                # Text after <think> in the same chunk is thinking content.
                if text:
                    print(text, end="", flush=True)
            elif chunk_type in ("think_continue", "content"):
                print(text, end="", flush=True)
            elif chunk_type == "think_end":
                print(text, flush=True)
                print("【模型思考阶段结束】\n")
            elif chunk_type == "tool_call":
                # Fragments are accumulated in self.tool_call_acc.
                tool_call_triggered = True

        if not tool_call_triggered:
            return None

        tool_call_info = {
            "id": self.tool_call_acc["id"],
            "type": self.tool_call_acc["type"],
            "function_name": self.tool_call_acc["function"]["name"],
            "arguments": self.tool_call_acc["function"]["arguments"]
        }
        print("\n【触发工具调用】")
        print(f"函数名：{tool_call_info['function_name']}")
        print(f"参数：{tool_call_info['arguments']}")
        # TODO: run self.call_tool(tool_call_info), append the assistant/tool
        # messages and request a follow-up completion.
        return tool_call_info

import logging

# Root logger: emit only WARNING and above (also installs a default stderr
# handler if the root logger has none yet).
logging.basicConfig(level=logging.WARNING)

# Raise the threshold of chatty third-party loggers so their INFO/DEBUG
# output is suppressed.
for _noisy_logger_name in ("openai", "httpx", "urllib3"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)
del _noisy_logger_name

if __name__ == "__main__":
    import os

    # Credentials and endpoints come from the environment; never commit real
    # API keys or passwords to source control. The fallbacks are demo values.
    # To target the official OpenAI API instead, set e.g.
    #   AGENT_BASE_URL=https://api.openai.com/v1  AGENT_MODEL_NAME=gpt-4
    API_KEY = os.environ.get("AGENT_API_KEY", "your_api_key")
    BASE_URL = os.environ.get("AGENT_BASE_URL", "http://180.163.119.106:16031/v1")
    MODEL_NAME = os.environ.get("AGENT_MODEL_NAME", "Fusion2-chat-v2.0")  # or a Qwen3 model name

    # Tool schemas offered to the model.
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "获取指定城市的天气情况",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "城市名称，用于查询天气"
                        }
                    },
                    "required": ["city"]
                }
            }
        }
    ]

    # session_id / agent_id are meant to be supplied by the caller. agent_id
    # keeps the value ("weather") that was previously passed to RedisMemory.
    session_id = "session001"
    agent_id = "weather"

    # TODO: drop the hard-coded password fallback once REDIS_PASSWORD is set
    # everywhere this script runs.
    redis_client = Redis(
        host=os.environ.get("REDIS_HOST", "localhost"),
        port=int(os.environ.get("REDIS_PORT", "6379")),
        password=os.environ.get("REDIS_PASSWORD", "21221"),
        decode_responses=True,
    )
    memory = RedisMemory(session_id=session_id, agent_id=agent_id, redis_client=redis_client)

    agent = Qwen3Agent(
        api_key=API_KEY,
        base_url=BASE_URL,
        model_name=MODEL_NAME,
        tools=tools,
        memory=memory
    )

    user_question = "请帮我查询上海的天气"
    # user_question = "写一篇200字的关于人工智能的文章"
    agent.chat_stream(user_question)

