from fastapi import FastAPI, Request, Security, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.gzip import GZipMiddleware
from collections import OrderedDict
from dataclasses import dataclass, field
from pydantic_settings import BaseSettings
import httpx, json, time, logging, hashlib, asyncio, uuid, pickle, pathlib
from datetime import datetime, timedelta

# Wall-clock timestamp taken at import time; /stats computes the uptime from it.
_start_time = time.time()

# ── 配置外部化 ─────────────────────────────────────────────
class Settings(BaseSettings):
    """Proxy configuration.

    The defaults below can be overridden via environment variables or the
    ``.env`` file named in ``Config.env_file``. ``_save_tool_filter_to_env``
    writes ``TOOL_WHITELIST`` / ``TOOL_BLACKLIST`` back to that same file.
    """

    target: str = "https://llm.test.com/v1"  # upstream gateway base URL; GET {target}/models is probed at startup
    api_key: str = "xxx"  # bearer token sent to the upstream gateway (placeholder; override via env)
    proxy_api_key: str = "proxy-secret-key"  # bearer token clients must present, checked by verify_key (placeholder; override via env)
    cache_ttl: int = 10  # response-cache entry lifetime, in seconds
    max_cache_size: int = 100  # maximum number of cached responses (LRU eviction beyond this)
    max_request_size: int = 10 * 1024 * 1024  # 10MB; a larger Content-Length is rejected with 413
    max_retries: int = 2  # extra upstream attempts after the first one fails transiently
    retry_delay: float = 1.0  # seconds; multiplied by the attempt number before each retry
    tool_desc_max_len: int = 100  # tool descriptions longer than this are truncated
    connect_timeout: float = 5.0  # seconds; presumably used for the upstream httpx client (not visible here)
    read_timeout: float = 60.0  # seconds; presumably used for the upstream httpx client (not visible here)
    write_timeout: float = 10.0  # seconds; presumably used for the upstream httpx client (not visible here)
    log_file: str = "proxy.log"  # JSON-lines log file, written in addition to the console

    # Model mapping: Claude Desktop model name -> actual private model name.
    model_sonnet: str = "pri-kimi-26"  # also the fallback for unknown requested models
    model_haiku: str = "pri-glm-51"
    model_opus: str = "pri-deepseek-v4"

    # Tool filtering: the whitelist takes precedence; if both are empty nothing is filtered.
    # Format: comma-separated tool names, e.g. read_file,write_file
    tool_whitelist: str = ""  # keep only the listed tools
    tool_blacklist: str = ""  # drop the listed tools

    class Config:
        # NOTE(review): pydantic-settings v2 prefers
        # `model_config = SettingsConfigDict(env_file=".env")`; this v1-style inner
        # class may emit a deprecation warning. Confirm the installed version.
        env_file = ".env"

# Single module-level settings instance used throughout the module.
settings = Settings()

# Claude Desktop model name -> upstream private model name.
# Mutable at runtime through PUT /admin/models; the proxy route falls back to
# settings.model_sonnet for names that are not present here.
MODEL_MAP = {
    "claude-sonnet-4-5": settings.model_sonnet,
    "claude-haiku-4-5":  settings.model_haiku,
    "claude-opus-4-5":   settings.model_opus,
}

# ── 结构化日志（终端 + 文件）─────────────────────────────
class JSONFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object.

    Output keys:
        time: ISO-8601 local time at which the record was created.
        level: the level name (INFO, WARNING, ...).
        message: the fully formatted log message.
        exception: only present when the record carries exception info;
            the formatted traceback.
    Non-ASCII text is emitted as-is (``ensure_ascii=False``).
    """

    def format(self, record):
        payload = {
            # Use the record's own creation time, not "now", so the timestamp
            # reflects when the event happened rather than when it was formatted.
            "time": datetime.fromtimestamp(record.created).isoformat(),
            "level": record.levelname,
            "message": record.getMessage(),
        }
        if record.exc_info:
            # Without this, logger.exception(...) would silently lose the traceback.
            payload["exception"] = self.formatException(record.exc_info)
        return json.dumps(payload, ensure_ascii=False)

# Module logger that writes JSON lines to both the console and settings.log_file.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
_console = logging.StreamHandler()
_console.setFormatter(JSONFormatter())
logger.addHandler(_console)
# NOTE(review): handlers are attached at import time; importing/reloading this
# module more than once would attach duplicate handlers.
_file = logging.FileHandler(settings.log_file, encoding="utf-8")
_file.setFormatter(JSONFormatter())
logger.addHandler(_file)

app = FastAPI()

# ── Response compression ───────────────────────────────────
# Gzip-compress responses of at least 1000 bytes.
app.add_middleware(GZipMiddleware, minimum_size=1000)

# ── API key authentication ─────────────────────────────────
# auto_error=False lets requests without a Bearer header reach verify_key
# (credentials is then falsy there), which answers with a uniform 401.
security = HTTPBearer(auto_error=False)

def verify_key(credentials: HTTPAuthorizationCredentials = Security(security)):
    """FastAPI dependency that checks the client's Bearer token.

    Returns:
        The credentials, when the token equals ``settings.proxy_api_key``.

    Raises:
        HTTPException: 401 when the Authorization header is missing or the
            token does not match.
    """
    import secrets  # local import: only needed for the constant-time comparison

    # compare_digest runs in time independent of how many leading characters
    # match, so the proxy key cannot be recovered through response timing.
    # Both sides are encoded because compare_digest rejects non-ASCII str input.
    if not credentials or not secrets.compare_digest(
        credentials.credentials.encode("utf-8"),
        settings.proxy_api_key.encode("utf-8"),
    ):
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return credentials

# ── LRU 缓存 ──────────────────────────────────────────────
# Snapshots written by shutdown() and reloaded by startup_check(); paths are
# relative to the current working directory.
CACHE_FILE = pathlib.Path("cache.pkl")   # pickled LRUCache.cache contents
STATS_FILE = pathlib.Path("stats.json")  # JSON dump of the Stats counters

class LRUCache:
    """In-memory LRU cache whose entries additionally expire after ``ttl`` seconds.

    ``self.cache`` maps ``key -> (stored_at, value)`` and is ordered from least
    to most recently used. Other code reads and replaces it directly (it is
    pickled at shutdown and counted on /stats), so its shape is part of the
    interface. Expired entries are removed lazily, only when ``get`` hits them.
    """

    def __init__(self, max_size: int, ttl: int):
        self.cache = OrderedDict()
        self.max_size = max_size  # maximum number of entries kept
        self.ttl = ttl            # entry lifetime, in seconds

    def get(self, key: str):
        """Return the value stored under *key*, or None if absent or expired.

        A hit promotes the entry to most recently used; an expired entry is dropped.
        """
        try:
            stored_at, value = self.cache[key]
        except KeyError:
            return None
        if datetime.now() - stored_at > timedelta(seconds=self.ttl):
            self.cache.pop(key)
            return None
        self.cache.move_to_end(key)
        return value

    def set(self, key: str, value):
        """Store *value* under *key* as the most recently used entry.

        When the cache grows beyond ``max_size``, the least recently used entry is evicted.
        """
        self.cache[key] = (datetime.now(), value)
        # Promote to the newest position; this is a no-op for a freshly inserted key.
        self.cache.move_to_end(key)
        if len(self.cache) > self.max_size:
            self.cache.popitem(last=False)

# Module-level response cache; persisted by shutdown() and restored by startup_check().
cache = LRUCache(max_size=settings.max_cache_size, ttl=settings.cache_ttl)

# ── Concurrent request de-duplication ──────────────────────
# In-flight upstream requests, presumably keyed by cache key so that identical
# concurrent requests can share one Future (the producer is not visible here).
# shutdown() awaits whatever is still pending.
pending_requests: dict[str, asyncio.Future] = {}

# ── 统计（含并发锁）──────────────────────────────────────
@dataclass
class Stats:
    """Running counters shown on /stats and persisted to STATS_FILE.

    Counters are incremented through ``inc_stat`` (under ``stats_lock``) and
    restored field by field in ``startup_check``.
    """

    total_requests: int = 0
    cache_hits: int = 0
    gateway_errors: int = 0
    retries: int = 0
    total_input_tokens: int = 0
    total_output_tokens: int = 0

    @property
    def hit_rate(self) -> str:
        """Cache hit ratio as a percentage string with one decimal, e.g. ``"42.5%"``."""
        requests = self.total_requests
        if not requests:
            return "0%"
        ratio = self.cache_hits / requests * 100
        return f"{ratio:.1f}%"

    @property
    def total_tokens(self) -> int:
        """Sum of input and output tokens over all requests."""
        return sum((self.total_input_tokens, self.total_output_tokens))

# Global counters, plus the lock inc_stat uses to serialise updates from coroutines.
stats = Stats()
stats_lock = asyncio.Lock()

# Tool names carried by the most recent request (shown on the /stats page).
_last_tools: list[str] = []
_tools_lock = asyncio.Lock()

# Runtime tool-filter state: initialised from settings (.env) and changeable from
# the /stats page via POST /admin/tools/filter. A non-empty whitelist takes
# precedence over the blacklist (see _apply_tool_filter).
_tool_filter_lock = asyncio.Lock()
_tool_whitelist: set[str] = set(t.strip() for t in settings.tool_whitelist.split(",") if t.strip())
_tool_blacklist: set[str] = set(t.strip() for t in settings.tool_blacklist.split(",") if t.strip())

# .env file (relative to the working directory) rewritten by _save_tool_filter_to_env.
ENV_FILE_PATH = pathlib.Path(".env")

def _save_tool_filter_to_env():
    """Write the current tool filter state back to the ``.env`` file.

    Existing ``TOOL_WHITELIST=`` / ``TOOL_BLACKLIST=`` lines are replaced in
    place, with values comma-joined in sorted order. Missing keys are appended,
    and every other line is preserved. If ``.env`` does not exist yet, it is
    created. The /stats page tells the user the setting survives a restart,
    which only holds if something is actually written.

    Best-effort: any failure is logged as a warning and never raised, so the
    in-memory filter update still takes effect.
    """
    def _env_safe(names):
        # Names containing line breaks / control characters could inject extra
        # KEY=value lines (e.g. PROXY_API_KEY) into .env, and commas would split
        # one name into several on reload, so such names are not persisted.
        return sorted(n for n in names if n.isprintable() and "," not in n)

    whitelist_line = "TOOL_WHITELIST=" + ",".join(_env_safe(_tool_whitelist))
    blacklist_line = "TOOL_BLACKLIST=" + ",".join(_env_safe(_tool_blacklist))
    try:
        lines = (
            ENV_FILE_PATH.read_text(encoding="utf-8").splitlines()
            if ENV_FILE_PATH.exists()
            else []
        )
        new_lines = []
        whitelist_written = blacklist_written = False
        for line in lines:
            if line.startswith("TOOL_WHITELIST="):
                new_lines.append(whitelist_line)
                whitelist_written = True
            elif line.startswith("TOOL_BLACKLIST="):
                new_lines.append(blacklist_line)
                blacklist_written = True
            else:
                new_lines.append(line)
        if not whitelist_written:
            new_lines.append(whitelist_line)
        if not blacklist_written:
            new_lines.append(blacklist_line)
        ENV_FILE_PATH.write_text("\n".join(new_lines) + "\n", encoding="utf-8")
    except Exception as e:
        logger.warning(json.dumps({"message": f"工具过滤配置写入失败: {str(e)}"}))


def _apply_tool_filter(tools: list) -> list:
    """Filter OpenAI-format tool definitions by the runtime white/blacklist.

    A non-empty whitelist wins: only tools named on it are kept. Otherwise a
    non-empty blacklist removes the tools named on it. When both are empty,
    the very same ``tools`` list object is returned.
    """
    use_whitelist = bool(_tool_whitelist)
    names = _tool_whitelist if use_whitelist else _tool_blacklist
    if not names:
        return tools
    # Keep a tool when its membership matches the list's meaning:
    # present for a whitelist, absent for a blacklist.
    return [
        tool for tool in tools
        if (tool["function"]["name"] in names) == use_whitelist
    ]

async def inc_stat(field: str, delta: int = 1):
    """Add *delta* to the ``stats`` counter whose attribute name is *field*.

    The read-modify-write runs under ``stats_lock``. Note that an asyncio.Lock
    serialises coroutines on one event loop; it is not a thread lock.
    """
    async with stats_lock:
        updated = getattr(stats, field) + delta
        setattr(stats, field, updated)

# ── 工具调用格式转换 ───────────────────────────────────────
# Documentation-oriented JSON Schema keywords that are dropped to save tokens.
_SLIM_SCHEMA_REMOVE_KEYS = frozenset({"title", "examples", "default", "$schema"})
# Keywords whose value maps *user-chosen names* to sub-schemas. Keys inside these
# maps are property/definition names, not keywords, and must never be removed.
_SCHEMA_NAME_MAPS = frozenset({"properties", "patternProperties", "$defs", "definitions", "dependentSchemas"})


def slim_schema(schema, remove_keys=_SLIM_SCHEMA_REMOVE_KEYS) -> dict:
    """Return a copy of a JSON Schema with the keywords in *remove_keys* removed.

    Recurses through dicts and lists; any other value is returned unchanged.
    Keys in *remove_keys* are dropped wherever they appear as schema keywords.
    The names inside ``properties``-like maps are preserved, however: a tool
    parameter literally called ``title`` or ``default`` must survive, since
    ``required`` may still list it.

    Args:
        schema: a JSON-Schema-like structure (dict / list / scalar).
        remove_keys: keywords to strip; defaults to title/examples/default/$schema.
    """
    if isinstance(schema, dict):
        slimmed = {}
        for key, value in schema.items():
            if key in remove_keys:
                continue
            if key in _SCHEMA_NAME_MAPS and isinstance(value, dict):
                # Each key here is a name; only its sub-schema gets slimmed.
                slimmed[key] = {
                    name: slim_schema(sub, remove_keys) for name, sub in value.items()
                }
            else:
                slimmed[key] = slim_schema(value, remove_keys)
        return slimmed
    if isinstance(schema, list):
        return [slim_schema(item, remove_keys) for item in schema]
    return schema


def convert_tools_to_openai(tools: list, max_desc_len: int = 100) -> list:
    """Convert Anthropic tool definitions to OpenAI ``function`` tools.

    A description longer than *max_desc_len* characters is cut to that length
    and suffixed with "...", so the result is at most ``max_desc_len + 3``
    characters. The ``input_schema`` is passed through ``slim_schema`` to drop
    token-heavy keys. A missing or ``null`` description / input_schema is
    treated as empty rather than crashing (``len(None)``) or emitting
    ``"parameters": null``.
    """
    result = []
    for tool in tools:
        # `or` also covers explicit nulls, which .get(key, default) returns as None.
        desc = tool.get("description") or ""
        if len(desc) > max_desc_len:
            desc = desc[:max_desc_len] + "..."
        result.append({
            "type": "function",
            "function": {
                "name": tool.get("name", ""),
                "description": desc,
                "parameters": slim_schema(tool.get("input_schema") or {}),
            },
        })
    return result


def convert_messages_to_openai(messages: list) -> list:
    """Convert Anthropic-style messages to OpenAI chat messages, including tool turns.

    Per input message:
      * string content -> one message with the same role;
      * ``tool_use`` blocks -> one assistant message carrying ``tool_calls``
        (any text blocks become its ``content``);
      * ``tool_result`` blocks -> one ``role: "tool"`` message each (list
        content is flattened to its text parts, null content becomes "");
      * remaining ``text`` blocks -> one message with the original role.
        When a message mixes tool results and text, the text message is
        emitted *after* the tool messages, because OpenAI requires tool
        messages to follow directly the assistant turn that requested them.

    Other block types (e.g. images) are dropped, as are messages whose
    content is neither a string nor a list.
    """
    result = []
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")

        if isinstance(content, str):
            result.append({"role": role, "content": content})
            continue
        if not isinstance(content, list):
            continue

        tool_calls, tool_results, text_parts = [], [], []
        for block in content:
            btype = block.get("type", "")
            if btype == "tool_use":
                tool_calls.append({
                    "id": block.get("id", ""),
                    "type": "function",
                    "function": {
                        "name": block.get("name", ""),
                        "arguments": json.dumps(block.get("input", {}), ensure_ascii=False)
                    }
                })
            elif btype == "tool_result":
                # `or ""` also covers an explicit null content.
                tool_content = block.get("content") or ""
                if isinstance(tool_content, list):
                    tool_content = " ".join(
                        b.get("text", "") for b in tool_content if b.get("type") == "text"
                    )
                tool_results.append({
                    "role": "tool",
                    "tool_call_id": block.get("tool_use_id", ""),
                    "content": tool_content
                })
            elif btype == "text":
                text_parts.append(block.get("text", ""))

        if tool_calls:
            msg_out = {"role": "assistant", "tool_calls": tool_calls}
            if text_parts:
                msg_out["content"] = " ".join(text_parts)
            result.append(msg_out)
        result.extend(tool_results)
        # Text not already attached to a tool_calls message. It is placed after
        # any tool results so that user text sent alongside tool_result blocks
        # reaches the model instead of being discarded.
        if text_parts and not tool_calls:
            result.append({"role": role, "content": " ".join(text_parts)})

    return result


def has_pending_tool_calls(messages: list) -> bool:
    """Return True if the most recent assistant message requested tool calls.

    Only the latest assistant message is inspected. If its content is a list
    of blocks containing a ``tool_use`` block, the conversation is in the
    middle of a tool chain. Plain-string content means that turn made no tool
    call, so older assistant messages are deliberately not consulted.
    Returns False when there is no assistant message at all.
    """
    for msg in reversed(messages):
        if msg.get("role") != "assistant":
            continue
        content = msg.get("content", [])
        return isinstance(content, list) and any(
            isinstance(b, dict) and b.get("type") == "tool_use" for b in content
        )
    return False


def normalize_usage(usage: dict) -> dict:
    """Normalise a token-usage object to Anthropic's field names.

    Accepts OpenAI names (``prompt_tokens`` / ``completion_tokens``) or
    Anthropic names (``input_tokens`` / ``output_tokens``). When both are
    present and non-zero, the OpenAI name wins. Missing or zero counts become 0.
    ``None`` (e.g. an upstream ``"usage": null``) is treated as an empty
    object instead of raising AttributeError.
    """
    usage = usage or {}
    return {
        "input_tokens":  usage.get("prompt_tokens") or usage.get("input_tokens") or 0,
        "output_tokens": usage.get("completion_tokens") or usage.get("output_tokens") or 0,
    }


# Special/control tokens stripped from model output by clean_content.
_ABNORMAL_TOKENS = [
    "<|begin_of_sentence|>", "<|end_of_sentence|>",
    "<|im_start|>", "<|im_end|>",
    "<|endoftext|>", "<s>", "</s>",
]


def clean_content(text: str) -> str:
    """Remove leaked special tokens from model output and trim whitespace.

    Returns "" for falsy input. Otherwise each token in ``_ABNORMAL_TOKENS`` is
    removed in list order (so a removal can expose a later token, which is
    then removed too), and surrounding whitespace is stripped.
    NOTE(review): this also removes legitimate ``<s>``/``</s>`` HTML tags.
    """
    if text:
        cleaned = text
        for marker in _ABNORMAL_TOKENS:
            cleaned = cleaned.replace(marker, "")
        return cleaned.strip()
    return ""


def _parse_tool_arguments(arguments) -> dict:
    """Decode an OpenAI tool-call ``arguments`` value into a dict.

    Accepts the usual JSON string, an already-decoded dict (some backends send
    one), or None. Anything unparseable, or anything that is not a JSON object,
    yields ``{}``, because an Anthropic ``tool_use.input`` must be an object.
    """
    if isinstance(arguments, dict):
        return arguments
    try:
        parsed = json.loads(arguments or "{}")
    except (TypeError, ValueError):
        return {}
    return parsed if isinstance(parsed, dict) else {}


def convert_response_to_anthropic(data: dict, requested_model: str) -> dict:
    """Convert an OpenAI chat-completion response into an Anthropic Messages response.

    Args:
        data: decoded upstream JSON; must contain at least one entry in ``choices``.
        requested_model: model name echoed back to the client (the Claude
            alias, not the upstream model).

    Returns:
        An Anthropic ``message`` dict, built as follows:
          * If the upstream message carries ``tool_calls``, they become
            ``tool_use`` blocks with ``stop_reason`` "tool_use". This applies
            regardless of ``finish_reason``, since some OpenAI-compatible
            backends report "stop" even when calling tools.
          * Otherwise a single text block is returned, and ``finish_reason``
            "length" maps to "max_tokens"; everything else maps to "end_turn".

    Raises:
        KeyError / IndexError: if ``choices`` is missing or empty.
    """
    choice = data["choices"][0]
    message = choice.get("message") or {}
    finish_reason = choice.get("finish_reason") or "stop"
    # `or {}` also covers an explicit "usage": null.
    usage = normalize_usage(data.get("usage") or {})
    text = clean_content(message.get("content", ""))
    tool_calls = message.get("tool_calls") or []

    if tool_calls:
        content_blocks = []
        if text:
            content_blocks.append({"type": "text", "text": text})
        for tool_call in tool_calls:
            fn = tool_call.get("function") or {}
            content_blocks.append({
                "type": "tool_use",
                # Clients match tool_result blocks by this id, so never emit an empty one.
                "id": tool_call.get("id") or f"toolu_{uuid.uuid4().hex[:24]}",
                "name": fn.get("name", ""),
                "input": _parse_tool_arguments(fn.get("arguments")),
            })
        stop_reason = "tool_use"
    else:
        content_blocks = [{"type": "text", "text": text}]
        stop_reason = "max_tokens" if finish_reason == "length" else "end_turn"

    return {
        "id": data.get("id", "msg_001"),
        "type": "message", "role": "assistant",
        "content": content_blocks,
        "model": requested_model,
        "stop_reason": stop_reason,
        "usage": usage
    }


# ── 重试请求 ───────────────────────────────────────────────
async def request_with_retry(client, url, payload, headers, request_id, actual_model):
    """POST *payload* to *url* and return the decoded JSON body, retrying transient failures.

    Makes ``settings.max_retries + 1`` attempts (always at least one).
    Retries are made on:
      * read/connect timeouts;
      * connection failures;
      * connections dropped before a response;
      * HTTP 5xx responses.
    Before each retry it counts the retry in ``stats.retries``, logs a warning,
    and sleeps ``settings.retry_delay * attempt`` seconds.

    Raises:
        httpx.HTTPStatusError: immediately for non-5xx error statuses, or for a
            5xx on the final attempt.
        httpx transport errors: the last transient failure, once attempts run out.
    """
    retryable_transport_errors = (
        httpx.ReadTimeout,
        httpx.ConnectTimeout,
        httpx.ConnectError,          # request never reached the gateway
        httpx.RemoteProtocolError,   # connection dropped before a response
    )
    # Clamp so a negative setting still makes one attempt; otherwise the loop
    # would not run and `raise last_error` would raise None (TypeError).
    attempts = max(settings.max_retries, 0) + 1
    last_error = None
    for attempt in range(attempts):
        if attempt > 0:
            await inc_stat("retries")
            logger.warning(json.dumps({
                "request_id": request_id,
                "message": f"第 {attempt} 次重试",
                "model": actual_model
            }))
            await asyncio.sleep(settings.retry_delay * attempt)
        try:
            resp = await client.post(url, json=payload, headers=headers)
            resp.raise_for_status()
        except retryable_transport_errors as e:
            last_error = e
            continue
        except httpx.HTTPStatusError as e:
            if e.response.status_code >= 500:
                last_error = e
                continue
            raise
        return resp.json()
    raise last_error


def get_cache_key(model: str, messages: list, tools: list) -> str:
    """Derive a deterministic cache key for a (model, messages, tools) request.

    The messages and tools are serialised as canonical JSON (sorted keys,
    non-ASCII kept as-is) and prefixed with ``"<model>:"``. The result is
    hashed with MD5 and the 32-character hex digest is returned. MD5 serves
    only for bucketing here, not for security.
    """
    canonical = json.dumps(
        {"messages": messages, "tools": tools},
        ensure_ascii=False,
        sort_keys=True,
    )
    keyed = f"{model}:{canonical}"
    return hashlib.md5(keyed.encode()).hexdigest()


# ── 启动：恢复缓存和统计 ──────────────────────────────────
@app.on_event("startup")
async def startup_check():
    """Startup hook: reload persisted cache/stats, log the config, probe the gateway.

    Cache/stats restoration and the gateway probe are best-effort: failures
    are logged as warnings and never abort startup.
    """
    # 1) Reload the LRU cache snapshot written by shutdown().
    # NOTE(review): pickle.load can execute arbitrary code from the file; this is
    # only safe while cache.pkl is writable solely by this service.
    if CACHE_FILE.exists():
        try:
            with open(CACHE_FILE, "rb") as fh:
                cache.cache = pickle.load(fh)
            logger.info(json.dumps({"message": f"已恢复缓存 {len(cache.cache)} 条"}))
        except Exception as exc:
            logger.warning(json.dumps({"message": f"缓存恢复失败: {str(exc)}"}))

    # 2) Reload the counters written by shutdown(); missing keys default to 0.
    if STATS_FILE.exists():
        try:
            saved = json.loads(STATS_FILE.read_text())
            for name in ("total_requests", "cache_hits", "gateway_errors",
                         "retries", "total_input_tokens", "total_output_tokens"):
                setattr(stats, name, saved.get(name, 0))
            logger.info(json.dumps({"message": "已恢复历史统计数据"}))
        except Exception as exc:
            logger.warning(json.dumps({"message": f"统计恢复失败: {str(exc)}"}))

    # 3) One-line summary of the effective configuration.
    summary = {
        "message": "代理启动",
        "target": settings.target,
        "models": MODEL_MAP,
        "cache_ttl": settings.cache_ttl,
        "tool_desc_max_len": settings.tool_desc_max_len,
        "read_timeout": settings.read_timeout,
    }
    logger.info(json.dumps(summary))

    # 4) Probe GET {target}/models with a 5 s timeout; a failure only warns.
    auth_headers = {"Authorization": f"Bearer {settings.api_key}"}
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            probe = await client.get(f"{settings.target}/models", headers=auth_headers)
            probe.raise_for_status()
            logger.info(json.dumps({"message": f"网关连通性检查通过: {settings.target}"}))
    except Exception as exc:
        logger.warning(json.dumps({"message": f"网关连通性检查失败: {str(exc)}"}))


# ── 关闭：持久化缓存和统计 ───────────────────────────────
@app.on_event("shutdown")
async def shutdown():
    """Shutdown hook: wait for in-flight requests, then persist cache and counters.

    The hook runs these steps in order:
      1. Waits for every Future in ``pending_requests``; their exceptions are
         collected, not raised.
      2. Pickles ``cache.cache`` to CACHE_FILE.
      3. Writes the Stats counters to STATS_FILE as JSON.
    Each persistence step is best-effort: a failure is logged as a warning.
    """
    logger.info(json.dumps({"message": "代理正在关闭，等待进行中的请求完成..."}))
    in_flight = list(pending_requests.values())
    if in_flight:
        # return_exceptions=True: a failed request must not abort the shutdown sequence.
        await asyncio.gather(*in_flight, return_exceptions=True)

    # Cache snapshot, reloaded by startup_check().
    try:
        with open(CACHE_FILE, "wb") as fh:
            pickle.dump(cache.cache, fh)
        logger.info(json.dumps({"message": f"缓存已持久化 {len(cache.cache)} 条"}))
    except Exception as exc:
        logger.warning(json.dumps({"message": f"缓存持久化失败: {str(exc)}"}))

    # Counter snapshot, reloaded by startup_check().
    counter_names = (
        "total_requests", "cache_hits", "gateway_errors",
        "retries", "total_input_tokens", "total_output_tokens",
    )
    try:
        snapshot = {name: getattr(stats, name) for name in counter_names}
        STATS_FILE.write_text(json.dumps(snapshot))
        logger.info(json.dumps({"message": "统计数据已持久化"}))
    except Exception as exc:
        logger.warning(json.dumps({"message": f"统计持久化失败: {str(exc)}"}))

    logger.info(json.dumps({"message": "代理已安全关闭"}))


# ── 请求日志中间件 ─────────────────────────────────────────
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log method, path, status and duration of every HTTP request as JSON.

    Timing uses ``time.perf_counter`` so that wall-clock adjustments cannot
    skew durations. The log line is emitted in ``finally``, so a request
    whose handler raises is still recorded (with status 500) before the
    exception propagates.
    NOTE(review): for streaming responses this presumably measures the time
    until the response object is returned, not the full stream. Confirm if
    end-to-end timing matters.
    """
    start = time.perf_counter()
    status = 500  # reported if call_next raises before producing a response
    try:
        response = await call_next(request)
        status = response.status_code
        return response
    finally:
        duration = time.perf_counter() - start
        logger.info(json.dumps({
            "method": request.method,
            "path": request.url.path,
            "status": status,
            "duration": f"{duration:.2f}s"
        }))


# ── 健康检查 ───────────────────────────────────────────────
@app.get("/")
@app.head("/")
async def health():
    """Liveness probe answering ``{"status": "ok"}`` for both GET and HEAD on ``/``."""
    payload = {"status": "ok"}
    return JSONResponse(content=payload)


# ── 统计接口（HTML 页面）──────────────────────────────────
@app.get("/stats")
async def get_stats():
    """Render the HTML status page: counters, token usage, model map and the tool-filter UI.

    The page is served without authentication, so two precautions apply:
      * Tool names (taken from client requests) and model-map entries
        (editable through PUT /admin/models) are HTML-escaped before being
        embedded, so they cannot inject markup or script.
      * Checkbox semantics match the page's JavaScript and what
        ``saveFilter()`` posts to /admin/tools/filter. In whitelist mode a
        checked tool is sent upstream; in blacklist mode a checked tool is
        filtered out. Saving without edits therefore leaves the filter unchanged.
    """
    from html import escape as html_escape
    from fastapi.responses import HTMLResponse
    uptime = int(time.time() - _start_time)
    h, m, s = uptime // 3600, (uptime % 3600) // 60, uptime % 60
    model_rows = "".join(
        f"<tr><td>{html_escape(str(k))}</td><td>{html_escape(str(v))}</td></tr>"
        for k, v in MODEL_MAP.items()
    )

    # Snapshot the last tool list and the current filter state under their locks.
    async with _tools_lock:
        last_tools_snapshot = list(_last_tools)
    async with _tool_filter_lock:
        cur_whitelist = set(_tool_whitelist)
        cur_blacklist = set(_tool_blacklist)

    if cur_whitelist:
        cur_mode = "whitelist"
        cur_filtered = cur_whitelist
    elif cur_blacklist:
        cur_mode = "blacklist"
        cur_filtered = cur_blacklist
    else:
        cur_mode = "none"
        cur_filtered = set()

    # Build the tool checkbox list.
    if last_tools_snapshot:
        item_html = []
        for name in last_tools_snapshot:
            listed = name in cur_filtered
            if cur_mode == "whitelist":
                # Checked = on the whitelist = sent upstream.
                checked = "checked" if listed else ""
                status_color = "#080" if listed else "#aaa"
                status_text = "发送" if listed else "过滤"
            elif cur_mode == "blacklist":
                # Checked = on the blacklist = filtered out, matching the JS labels
                # and the list saveFilter() posts back as the new blacklist.
                checked = "checked" if listed else ""
                status_color = "#c00" if listed else "#080"
                status_text = "过滤" if listed else "发送"
            else:
                checked = "checked"
                status_color = "#080"
                status_text = "发送"
            safe_name = html_escape(str(name), quote=True)
            item_html.append(
                f"<div class='tool-item'>"
                f"<input type='checkbox' class='tool-cb' value='{safe_name}' {checked}>"
                f"<span class='tool-name'>{safe_name}</span>"
                f"<span class='tool-status' style='color:{status_color}'>{status_text}</span>"
                f"</div>"
            )
        tool_items = "".join(item_html)
        tool_count = len(last_tools_snapshot)
        tool_section = f"""
<div class='tool-bar'>
  <span style='color:#666;font-size:13px'>共 {tool_count} 个工具 &nbsp;|&nbsp; 过滤模式：</span>
  <select id='filter-mode' onchange='onModeChange()'>
    <option value='none' {'selected' if cur_mode=='none' else ''}>不过滤（全部发送）</option>
    <option value='whitelist' {'selected' if cur_mode=='whitelist' else ''}>白名单（只发送勾选）</option>
    <option value='blacklist' {'selected' if cur_mode=='blacklist' else ''}>黑名单（过滤勾选）</option>
  </select>
  &nbsp;
  <input type='button' onclick='selectAll(true)' value='全选'>
  <input type='button' onclick='selectAll(false)' value='全不选'>
  <input type='button' class='save-btn' onclick='saveFilter()' value='保存设置'>
  <span id='save-msg' style='color:#080;margin-left:10px;display:none'>已保存</span>
</div>
<div class='tool-list'>{tool_items}</div>"""
    else:
        tool_section = "<p style='color:#aaa;padding:10px'>暂无数据，发送一条消息后刷新页面</p>"

    html = f"""<!DOCTYPE html>
<html><head>
<meta charset="utf-8"><title>Claude Proxy Stats</title>
<style>
  body{{font-family:sans-serif;padding:30px;background:#f5f5f5;color:#333;margin:0}}
  h2{{margin:24px 0 12px;color:#222}}
  table{{border-collapse:collapse;width:460px;background:#fff;border-radius:8px;
        overflow:hidden;box-shadow:0 2px 8px rgba(0,0,0,.08);margin-bottom:20px}}
  td{{padding:11px 20px;border-bottom:1px solid #eee;font-size:14px}}
  td:first-child{{color:#888;width:180px}}
  td:last-child{{font-weight:600}}
  .green{{color:#080}} .red{{color:#c00}}
  .tool-bar{{margin-bottom:10px;display:flex;align-items:center;flex-wrap:wrap;gap:6px}}
  .tool-bar select{{padding:4px 8px;border-radius:4px;border:1px solid #ccc;font-size:13px}}
  .tool-bar button{{padding:4px 10px;border:1px solid #ccc;border-radius:4px;
                    background:#fff;cursor:pointer;font-size:13px}}
  .save-btn{{background:#1a73e8;color:#fff;border-color:#1a73e8}}
  .save-btn:hover{{background:#1557b0}}
  .tool-list{{background:#fff;border-radius:8px;box-shadow:0 2px 8px rgba(0,0,0,.08);
              padding:8px 0;max-height:400px;overflow-y:auto;width:460px;margin-bottom:20px}}
  .tool-item{{display:flex;align-items:center;padding:7px 16px;border-bottom:1px solid #f0f0f0}}
  .tool-item:last-child{{border-bottom:none}}
  .tool-item input{{margin-right:10px;cursor:pointer;width:15px;height:15px}}
  .tool-name{{flex:1;font-size:13px;font-family:monospace}}
  .tool-status{{font-size:12px;width:36px;text-align:right}}
  p{{color:#aaa;font-size:12px;margin:4px 0}}
</style></head><body>
<h2>Claude Proxy 运行状态</h2>
<table>
  <tr><td>运行时长</td><td>{h:02d}:{m:02d}:{s:02d}</td></tr>
  <tr><td>总请求数</td><td>{stats.total_requests}</td></tr>
  <tr><td>缓存命中</td><td class="green">{stats.cache_hits}</td></tr>
  <tr><td>缓存命中率</td><td class="green">{stats.hit_rate}</td></tr>
  <tr><td>当前缓存条数</td><td>{len(cache.cache)}</td></tr>
  <tr><td>网关错误</td><td class="{'red' if stats.gateway_errors else 'green'}">{stats.gateway_errors}</td></tr>
  <tr><td>重试次数</td><td>{stats.retries}</td></tr>
</table>
<h2>Token 使用统计</h2>
<table>
  <tr><td>输入 Token 合计</td><td>{stats.total_input_tokens:,}</td></tr>
  <tr><td>输出 Token 合计</td><td>{stats.total_output_tokens:,}</td></tr>
  <tr><td>总 Token 合计</td><td><b>{stats.total_tokens:,}</b></td></tr>
  <tr><td>平均每次请求</td><td>{stats.total_tokens // max(stats.total_requests, 1):,}</td></tr>
</table>
<h2>模型映射</h2>
<table>
  <tr><td><b>Claude Desktop</b></td><td><b>私有模型</b></td></tr>
  {model_rows}
</table>
<h2>MCP 工具管理</h2>
{tool_section}
<p>* 过滤设置即时生效并自动保存到 .env，重启后依然有效</p>
<script>
function onModeChange() {{
  var mode = document.getElementById('filter-mode').value;
  var items = document.querySelectorAll('.tool-item .tool-status');
  // 切换模式时重新渲染状态标签
  updateStatusLabels();
}}
function updateStatusLabels() {{
  var mode = document.getElementById('filter-mode').value;
  document.querySelectorAll('.tool-item').forEach(function(item) {{
    var cb = item.querySelector('.tool-cb');
    var status = item.querySelector('.tool-status');
    var checked = cb.checked;
    if (mode === 'none') {{
      status.textContent = '发送'; status.style.color = '#080';
    }} else if (mode === 'whitelist') {{
      status.textContent = checked ? '发送' : '过滤';
      status.style.color = checked ? '#080' : '#aaa';
    }} else {{
      status.textContent = checked ? '过滤' : '发送';
      status.style.color = checked ? '#c00' : '#080';
    }}
  }});
}}
function selectAll(v) {{
  document.querySelectorAll('.tool-cb').forEach(function(cb){{ cb.checked = v; }});
  updateStatusLabels();
}}
document.querySelectorAll('.tool-cb').forEach(function(cb) {{
  cb.addEventListener('change', updateStatusLabels);
}});
function saveFilter() {{
  var mode = document.getElementById('filter-mode').value;
  var tools = [];
  document.querySelectorAll('.tool-cb:checked').forEach(function(cb){{ tools.push(cb.value); }});
  fetch('/admin/tools/filter', {{
    method: 'POST',
    headers: {{'Content-Type': 'application/json'}},
    body: JSON.stringify({{mode: mode, tools: tools}})
  }}).then(function(r) {{
    if (r.ok) {{
      var msg = document.getElementById('save-msg');
      msg.style.display = 'inline';
      setTimeout(function(){{ msg.style.display='none'; }}, 2000);
      updateStatusLabels();
    }}
  }});
}}
</script>
</body></html>"""
    return HTMLResponse(html)


# ── 模型映射管理接口 ───────────────────────────────────────
@app.get("/admin/models")
async def get_model_map(_=Security(verify_key)):
    """Return the current Claude-to-upstream model mapping (requires the proxy API key)."""
    return JSONResponse(content=MODEL_MAP)

@app.put("/admin/models")
async def update_model_map(request: Request, _=Security(verify_key)):
    """Merge a JSON object ``{claude_model: upstream_model, ...}`` into MODEL_MAP.

    Existing keys are overwritten, new keys are added, and keys not mentioned
    are kept. Requires the proxy API key.

    Responds 400 and leaves MODEL_MAP untouched when the body is not valid
    JSON, or is not an object whose keys and values are all strings. The
    values are used as upstream model names, and /stats renders them.
    """
    try:
        new_map = await request.json()
    except ValueError:
        return JSONResponse({"error": "请求体不是合法的 JSON"}, status_code=400)
    if not isinstance(new_map, dict) or not all(
        isinstance(k, str) and isinstance(v, str) for k, v in new_map.items()
    ):
        return JSONResponse({"error": "模型映射必须是字符串到字符串的 JSON 对象"}, status_code=400)
    MODEL_MAP.update(new_map)
    logger.info(json.dumps({"message": "模型映射已更新", "model_map": MODEL_MAP}))
    return JSONResponse({"status": "ok", "model_map": MODEL_MAP})


# ── 工具过滤管理接口 ───────────────────────────────────────
@app.post("/admin/tools/filter")
async def update_tool_filter(request: Request):
    """Replace the tool whitelist/blacklist and persist it to .env.

    Deliberately unauthenticated, because the /stats page calls it directly.
    NOTE(review): anyone who can reach the proxy can therefore change tool
    filtering; consider restricting it to localhost or adding auth.

    Body: ``{"mode": "whitelist" | "blacklist" | "none", "tools": [name, ...]}``.
    ``mode`` defaults to "blacklist". Setting one list clears the other, and
    "none" clears both. A malformed body, an unknown mode, or a non-list /
    non-printable tool name gets a 400 and leaves the filter unchanged.
    """
    global _tool_whitelist, _tool_blacklist
    try:
        body = await request.json()
    except ValueError:
        return JSONResponse({"error": "请求体不是合法的 JSON"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"error": "请求体必须是 JSON 对象"}, status_code=400)

    mode = body.get("mode", "blacklist")   # whitelist / blacklist / none
    if mode not in ("whitelist", "blacklist", "none"):
        # Reject typos instead of silently clearing every filter.
        return JSONResponse({"error": f"未知的过滤模式: {mode}"}, status_code=400)

    raw_tools = body.get("tools", [])
    # A bare string would otherwise be iterated character by character, and
    # names with control characters (e.g. newlines) would be written into .env.
    if not isinstance(raw_tools, list) or not all(
        isinstance(t, str) and t.isprintable() for t in raw_tools
    ):
        return JSONResponse({"error": "tools 必须是由可打印字符串组成的列表"}, status_code=400)
    tools = {t.strip() for t in raw_tools if t.strip()}

    async with _tool_filter_lock:
        if mode == "whitelist":
            _tool_whitelist = tools
            _tool_blacklist = set()
        elif mode == "blacklist":
            _tool_whitelist = set()
            _tool_blacklist = tools
        else:  # none: clear all filtering
            _tool_whitelist = set()
            _tool_blacklist = set()
        _save_tool_filter_to_env()

    logger.info(json.dumps({
        "message": "工具过滤已更新",
        "mode": mode,
        "tools": list(tools)
    }))
    return JSONResponse({"status": "ok", "mode": mode, "tools": list(tools)})


@app.get("/admin/tools/filter")
async def get_tool_filter():
    """Report the active tool filter as ``{"mode": ..., "tools": [...]}``.

    The mode is chosen as follows:
      * "whitelist" when a whitelist is set (it takes precedence);
      * otherwise "blacklist" when a blacklist is set;
      * otherwise "none", with an empty tool list.
    Tool order follows set iteration order and is not guaranteed.
    """
    async with _tool_filter_lock:
        if _tool_whitelist:
            mode, names = "whitelist", list(_tool_whitelist)
        elif _tool_blacklist:
            mode, names = "blacklist", list(_tool_blacklist)
        else:
            mode, names = "none", []
    return JSONResponse({"mode": mode, "tools": names})


# ── 主代理路由 ─────────────────────────────────────────────
def _anthropic_tool_choice_to_openai(tool_choice):
    """Translate an Anthropic ``tool_choice`` into its OpenAI equivalent.

    Anthropic sends an object: ``{"type": "auto"}``, ``{"type": "any"}``,
    ``{"type": "none"}`` or ``{"type": "tool", "name": ...}``. The bare strings
    ``"any"`` / ``"required"`` / ``"none"`` are also accepted.

    Returns ``None`` when ``tool_choice`` is empty, meaning nothing should be
    forwarded. Unrecognised values fall back to ``"auto"``.
    """
    if not tool_choice:
        return None
    if isinstance(tool_choice, dict):
        choice_type = tool_choice.get("type")
        if choice_type == "tool":
            return {"type": "function", "function": {"name": tool_choice.get("name", "")}}
    else:
        choice_type = tool_choice
    if choice_type in ("any", "required"):
        return "required"
    if choice_type == "none":
        return "none"
    return "auto"


def _flatten_system_prompt(system):
    """Return the system prompt as plain text.

    Anthropic accepts ``system`` either as a string or as a list of content
    blocks (``{"type": "text", "text": ...}``, possibly with extra keys such
    as ``cache_control``). An OpenAI-style gateway is not guaranteed to accept
    that block form, so text blocks (and bare strings) are joined with
    newlines and other blocks are dropped. Non-list values are returned
    unchanged.
    """
    if not isinstance(system, list):
        return system
    parts = []
    for block in system:
        if isinstance(block, str):
            parts.append(block)
        elif isinstance(block, dict) and block.get("type", "text") == "text":
            parts.append(block.get("text", ""))
    return "\n".join(part for part in parts if part)


def _settle_inflight(cache_key, future, result):
    """Resolve this request's single-flight future and unregister it.

    If ``result`` is not ``None``, waiters receive it. Otherwise the future is
    cancelled, so waiters stop waiting and make their own upstream call. The
    ``pending_requests`` entry is removed only if it still points at *this*
    future, so a newer leader's registration is never dropped.
    """
    if not future.done():
        if result is not None:
            future.set_result(result)
        else:
            future.cancel()
    if cache_key in pending_requests and pending_requests[cache_key] is future:
        del pending_requests[cache_key]


@app.post("/v1/responses")
@app.post("/v1/messages")
async def proxy(request: Request, _=Security(verify_key)):
    """Forward an Anthropic-style request to the OpenAI-compatible gateway.

    Steps:

    1. Map the requested model through MODEL_MAP; unknown models fall back to
       ``settings.model_sonnet``.
    2. Convert messages, system prompt and tools.
    3. Apply the tool black/whitelist.

    The request is then forwarded in one of three modes:

    * **stream + tools**: upstream non-streaming call, replayed to the client
      as SSE events.
    * **stream**: upstream SSE relayed as ``content_block_delta`` events,
      bounded by ``settings.read_timeout``.
    * **non-stream**: upstream call via ``request_with_retry``, converted back
      to Anthropic format. Tool-free, non-tool-chain requests are served from
      ``cache``. Concurrent identical requests share one upstream call
      through ``pending_requests``.

    Returns:

    * 400 when the body is not a JSON object.
    * 413 when the declared Content-Length exceeds
      ``settings.max_request_size``.
    * 504, 502 or 500 for gateway failures on the non-stream path. Streaming
      paths report errors as ``data: {"error": ...}`` events.
    """
    global _last_tools

    # Prefer the request ID sent by Claude Desktop so log lines can be correlated.
    request_id = request.headers.get("x-request-id") or str(uuid.uuid4())[:8]

    # Reject oversized bodies based on the declared Content-Length.
    content_length = request.headers.get("content-length")
    if content_length and int(content_length) > settings.max_request_size:
        logger.warning(json.dumps({"request_id": request_id, "message": "请求体过大", "size": content_length}))
        return JSONResponse({"error": "请求体过大"}, status_code=413)

    # A malformed body is a client error, not a 500.
    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        return JSONResponse({"error": "请求体不是合法的 JSON"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse({"error": "请求体必须是 JSON 对象"}, status_code=400)

    requested_model = body.get("model", "claude-sonnet-4-5")

    # Unknown model: log a warning and fall back to the default model.
    if requested_model not in MODEL_MAP:
        logger.warning(json.dumps({
            "request_id": request_id,
            "message": f"未知模型 {requested_model}，回退到默认模型 {settings.model_sonnet}",
            "available_models": list(MODEL_MAP.keys())
        }))
    actual_model = MODEL_MAP.get(requested_model, settings.model_sonnet)
    is_stream = body.get("stream", False)

    await inc_stat("total_requests")
    logger.info(json.dumps({
        "request_id": request_id,
        "message": "收到请求",
        "requested_model": requested_model,
        "actual_model": actual_model,
        "stream": is_stream
    }))

    # Messages: /v1/responses-style requests may carry "input" instead of "messages".
    messages = body.get("messages", [])
    if not messages:
        input_text = body.get("input", "")
        messages = [{"role": "user", "content": input_text}]

    is_tool_chain = has_pending_tool_calls(messages)
    messages = convert_messages_to_openai(messages)

    # System prompt: Anthropic allows a list of blocks; flatten it to text.
    system = _flatten_system_prompt(body.get("system", ""))
    if system and not any(m.get("role") == "system" for m in messages):
        messages = [{"role": "system", "content": system}] + messages

    # Tool definitions
    raw_tools = body.get("tools", [])
    openai_tools = convert_tools_to_openai(raw_tools, max_desc_len=settings.tool_desc_max_len) if raw_tools else []
    total_tool_count = len(openai_tools)

    # Record the unfiltered tool list globally (used by the stats page).
    if openai_tools:
        all_tool_names = [t["function"]["name"] for t in openai_tools]
        async with _tools_lock:
            _last_tools = all_tool_names

    # Apply the black/whitelist.
    async with _tool_filter_lock:
        openai_tools = _apply_tool_filter(openai_tools)

    has_tools = bool(openai_tools)
    if has_tools:
        filtered_names = [t["function"]["name"] for t in openai_tools]
        logger.info(json.dumps({
            "request_id": request_id,
            "message": "工具调用请求",
            # Use this request's own count; the shared _last_tools may
            # already hold another request's list.
            "total_tools": total_tool_count,
            "filtered_tools": len(filtered_names),
            "tools": filtered_names,
            "desc_max_len": settings.tool_desc_max_len,
            "is_tool_chain": is_tool_chain
        }))

    # Requests with tools, or in the middle of a tool chain, are never cached.
    use_cache = not is_stream and not has_tools and not is_tool_chain
    cache_key = None
    future = None
    if use_cache:
        cache_key = get_cache_key(actual_model, messages, openai_tools)
        cached = cache.get(cache_key)
        if cached:
            await inc_stat("cache_hits")
            logger.info(json.dumps({"request_id": request_id, "message": "命中缓存", "model": actual_model}))
            return JSONResponse(cached)

        if cache_key in pending_requests:
            inflight = pending_requests[cache_key]
            logger.info(json.dumps({"request_id": request_id, "message": "等待并发请求完成", "model": actual_model}))
            # asyncio.wait() returns once the future is done without re-raising
            # its outcome. A failed leader cancels its future. Because
            # CancelledError is a BaseException, a plain `await` guarded by
            # `except Exception` would let it escape and kill this request.
            # Here we just fall through and make our own upstream call.
            # Cancellation of *this* task still propagates normally.
            await asyncio.wait({inflight})
            if not inflight.cancelled() and inflight.exception() is None:
                return JSONResponse(inflight.result())

        future = asyncio.get_running_loop().create_future()
        pending_requests[cache_key] = future

    # Build the upstream payload.
    payload = {
        "model": actual_model,
        "messages": messages,
        "max_tokens": body.get("max_output_tokens") or body.get("max_tokens") or 4096,
        "stream": is_stream,
    }
    if openai_tools:
        payload["tools"] = openai_tools
        openai_tool_choice = _anthropic_tool_choice_to_openai(body.get("tool_choice"))
        if openai_tool_choice is not None:
            payload["tool_choice"] = openai_tool_choice

    gateway_headers = {"Authorization": f"Bearer {settings.api_key}"}
    timeout = httpx.Timeout(
        connect=settings.connect_timeout,
        read=settings.read_timeout,
        write=settings.write_timeout,
        pool=5.0
    )

    # ── Streaming with tools: call upstream non-streaming, then replay as SSE ──
    if is_stream and has_tools:
        payload["stream"] = False
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                data = await request_with_retry(client, f"{settings.target}/chat/completions",
                                                payload, gateway_headers, request_id, actual_model)
        except Exception as e:
            await inc_stat("gateway_errors")
            err_msg = str(e)
            logger.error(json.dumps({"request_id": request_id, "message": f"工具流式请求失败: {err_msg}"}))
            async def err_gen():
                yield f"data: {json.dumps({'error': err_msg})}\n\n"
            return StreamingResponse(err_gen(), media_type="text/event-stream")

        if "choices" not in data or len(data["choices"]) == 0:
            await inc_stat("gateway_errors")
            async def fmt_err_gen():
                yield f"data: {json.dumps({'error': '网关返回格式异常'})}\n\n"
            return StreamingResponse(fmt_err_gen(), media_type="text/event-stream")

        usage = data.get("usage", {})
        finish_reason = data["choices"][0].get("finish_reason", "stop")

        # Accumulate token statistics.
        normalized_tool = normalize_usage(usage)
        async with stats_lock:
            stats.total_input_tokens += normalized_tool.get("input_tokens", 0)
            stats.total_output_tokens += normalized_tool.get("output_tokens", 0)

        logger.info(json.dumps({
            "request_id": request_id, "model": actual_model,
            "finish_reason": finish_reason,
            "input_tokens": normalized_tool.get("input_tokens", 0),
            "output_tokens": normalized_tool.get("output_tokens", 0),
            "total_tokens": normalized_tool.get("input_tokens", 0) + normalized_tool.get("output_tokens", 0),
            "mode": "tool_stream_forced_sync"
        }))
        result = convert_response_to_anthropic(data, requested_model)

        async def tool_stream_gen():
            yield f"data: {json.dumps({'type': 'message_start', 'message': {'id': result['id'], 'type': 'message', 'role': 'assistant', 'model': result['model'], 'content': [], 'stop_reason': None, 'usage': result['usage']}})}\n\n"
            for i, block in enumerate(result.get("content", [])):
                yield f"data: {json.dumps({'type': 'content_block_start', 'index': i, 'content_block': block})}\n\n"
                if block.get("type") == "text":
                    yield f"data: {json.dumps({'type': 'content_block_delta', 'index': i, 'delta': {'type': 'text_delta', 'text': block['text']}})}\n\n"
                yield f"data: {json.dumps({'type': 'content_block_stop', 'index': i})}\n\n"
            yield f"data: {json.dumps({'type': 'message_delta', 'delta': {'stop_reason': result['stop_reason']}, 'usage': result['usage']})}\n\n"
            yield "data: {\"type\": \"message_stop\"}\n\n"

        return StreamingResponse(tool_stream_gen(), media_type="text/event-stream")

    # ── Plain streaming (no tools), bounded by an overall timeout ──
    if is_stream:
        async def generate():
            try:
                async with asyncio.timeout(settings.read_timeout):
                    async with httpx.AsyncClient(timeout=timeout) as client:
                        async with client.stream(
                            "POST", f"{settings.target}/chat/completions",
                            json=payload, headers=gateway_headers,
                        ) as resp:
                            # An upstream error body contains no "data:" lines.
                            # Without this check the client would get a silent,
                            # empty stream.
                            if resp.status_code >= 400:
                                status_err = f"网关返回错误: {resp.status_code}"
                                logger.error(json.dumps({"request_id": request_id, "message": status_err, "model": actual_model}))
                                yield f"data: {json.dumps({'error': status_err})}\n\n"
                                return
                            async for line in resp.aiter_lines():
                                if not line.startswith("data:"):
                                    continue
                                raw = line[5:].strip()
                                if raw == "[DONE]":
                                    yield "data: {\"type\": \"message_stop\"}\n\n"
                                    break
                                try:
                                    chunk = json.loads(raw)
                                    delta = chunk["choices"][0]["delta"].get("content", "")
                                    if delta:
                                        delta = clean_content(delta)
                                    if delta:
                                        yield f"data: {json.dumps({'type': 'content_block_delta', 'delta': {'type': 'text_delta', 'text': delta}})}\n\n"
                                except Exception:
                                    # Best-effort: skip chunks that are not parseable deltas.
                                    continue
            except asyncio.TimeoutError:
                logger.error(json.dumps({"request_id": request_id, "message": "流式响应超时"}))
                yield f"data: {json.dumps({'error': '流式响应超时'})}\n\n"
            except Exception as e:
                stream_err = str(e)
                logger.error(json.dumps({"request_id": request_id, "message": f"流式请求异常: {stream_err}"}))
                yield f"data: {json.dumps({'error': stream_err})}\n\n"

        return StreamingResponse(generate(), media_type="text/event-stream")

    # ── Non-streaming (with retry) ──
    # `future` is only set on the cacheable path. The outer try/finally
    # guarantees it is resolved and unregistered however this section exits:
    # return, unexpected exception, or client disconnect. Otherwise identical
    # requests would wait on it forever.
    result = None
    try:
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                data = await request_with_retry(
                    client, f"{settings.target}/chat/completions",
                    payload, gateway_headers, request_id, actual_model
                )
        except httpx.ConnectTimeout:
            await inc_stat("gateway_errors")
            logger.error(json.dumps({"request_id": request_id, "message": "连接网关超时", "model": actual_model}))
            return JSONResponse({"error": "连接网关超时"}, status_code=504)
        except httpx.ReadTimeout:
            await inc_stat("gateway_errors")
            logger.error(json.dumps({"request_id": request_id, "message": "读取响应超时", "model": actual_model}))
            return JSONResponse({"error": "网关响应超时"}, status_code=504)
        except httpx.HTTPStatusError as e:
            await inc_stat("gateway_errors")
            logger.error(json.dumps({"request_id": request_id, "message": f"网关返回错误: {e.response.status_code}", "model": actual_model}))
            return JSONResponse({"error": f"网关返回错误: {e.response.status_code}"}, status_code=502)
        except Exception as e:
            await inc_stat("gateway_errors")
            logger.error(json.dumps({"request_id": request_id, "message": f"未知错误: {str(e)}"}))
            return JSONResponse({"error": f"未知错误: {str(e)}"}, status_code=500)

        # ── Parse the response ──
        if "choices" in data and len(data["choices"]) > 0:
            usage = data.get("usage", {})
            finish_reason = data["choices"][0].get("finish_reason", "stop")

            # Accumulate token statistics.
            normalized = normalize_usage(usage)
            async with stats_lock:
                stats.total_input_tokens += normalized.get("input_tokens", 0)
                stats.total_output_tokens += normalized.get("output_tokens", 0)

            logger.info(json.dumps({
                "request_id": request_id,
                "model": actual_model,
                "finish_reason": finish_reason,
                "input_tokens": normalized.get("input_tokens", 0),
                "output_tokens": normalized.get("output_tokens", 0),
                "total_tokens": normalized.get("input_tokens", 0) + normalized.get("output_tokens", 0),
            }))

            result = convert_response_to_anthropic(data, requested_model)

            if use_cache:
                cache.set(cache_key, result)

            return JSONResponse(result)

        await inc_stat("gateway_errors")
        logger.error(json.dumps({"request_id": request_id, "message": "网关返回格式异常", "data": data}))
        return JSONResponse({"error": "网关返回格式异常"}, status_code=500)
    finally:
        # On success, waiters get the result. On any failure the future is
        # cancelled so waiters retry on their own.
        if future is not None:
            _settle_inflight(cache_key, future, result)
