"""FastAPI app 工厂。G1 脚手架 → G2 task list → G3 chat 只读 → G4+ 渐进上。 设计: - 单 FastAPI 进程,模板走 Jinja2,静态走 StaticFiles - 模板里 path 显示一律 `replace('\\', '/')`,Win / Linux 看到统一形态 (`Path.as_posix()` 在 Linux 读 Windows backslash 串时不归一,所以直接 replace) - Markdown 渲染走 markdown-it-py(gfm-like)+ pygments syntax highlight - SSE 在 G4 加,响应头会带 `X-Accel-Buffering: no`(nginx 反代友好) - 本地形态 sentinel user 固定;Phase D 加 OIDC 之后才有真正 user 态 """ from __future__ import annotations import json from pathlib import Path from typing import Any, Optional from uuid import UUID from fastapi import FastAPI, Request from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from sqlalchemy import func, select from core.storage import session_scope from core.storage.models import Message, Task WEB_ROOT = Path(__file__).resolve().parent TEMPLATES_DIR = WEB_ROOT / "templates" STATIC_DIR = WEB_ROOT / "static" STATUS_FILTERS = ("active", "completed", "abandoned") def _norm_path(p: str) -> str: """跨 OS 显示归一:backslash → forward slash。Win 存 `\\`、Linux 存 `/`,显示统一 `/`。""" return (p or "").replace("\\", "/") # --------------------------- Markdown 渲染 --------------------------- _md_instance = None def _pygments_highlight(code: str, lang: str, attrs: str) -> str: """markdown-it highlight 回调。lang 未识别 / pygments 异常时返回 '' 让 md 走默认
。"""
    if not lang:
        return ""
    try:
        from pygments import highlight
        from pygments.formatters import HtmlFormatter
        from pygments.lexers import get_lexer_by_name
        from pygments.util import ClassNotFound
    except ImportError:
        return ""
    try:
        lexer = get_lexer_by_name(lang, stripall=False)
    except ClassNotFound:
        return ""
    formatter = HtmlFormatter(nowrap=False, cssclass="codehilite")
    return highlight(code, lexer, formatter)


def _get_md():
    """Return the shared MarkdownIt renderer, constructing it on first call.

    Uses the "gfm-like" preset with linkify on, raw HTML off (XSS guard),
    ``breaks`` on, and pygments as the fence highlighter.
    """
    global _md_instance
    if _md_instance is not None:
        return _md_instance

    # Imported lazily so the module can be loaded before markdown-it is needed.
    from markdown_it import MarkdownIt

    options = {
        "linkify": True,
        "html": False,
        "breaks": True,
        "highlight": _pygments_highlight,
    }
    _md_instance = MarkdownIt("gfm-like", options)
    return _md_instance


def _render_md(text: str) -> str:
    """渲染 markdown → HTML。空串返空。"""
    if not text:
        return ""
    return _get_md().render(text)


# --------------------------- 消息块聚合 ---------------------------

def _args_preview(args: str, max_len: int = 60) -> str:
    s = (args or "").replace("\n", " ").strip()
    return s if len(s) <= max_len else s[: max_len - 3] + "..."


def _pretty_json(s: str) -> str:
    """JSON 串美化输出。解析失败返回原串。"""
    try:
        return json.dumps(json.loads(s), indent=2, ensure_ascii=False)
    except Exception:
        return s or ""


def load_chat_messages(task_id: UUID) -> list[dict]:
    """Fetch every message payload of a task, ordered by ascending ``idx``.

    Each payload is copied into a plain dict. A task without messages gives [].
    """
    stmt = (
        select(Message.payload)
        .where(Message.task_id == task_id)
        .order_by(Message.idx)
    )
    with session_scope() as session:
        payloads = session.execute(stmt).scalars().all()
    return [dict(payload) for payload in payloads]


def build_chat_blocks(messages: list[dict]) -> list[dict]:
    """Fold a LiteLLM-style message sequence into display blocks.

    - system / tool messages produce no block of their own (a tool result is
      embedded in the matching tool call of an assistant message instead)
    - user → {type=user, html}
    - assistant → {type=assistant, html, tool_calls=[{name,args_preview,args_pretty,result}]}
    - messages with any other role are ignored
    """
    # Index tool outputs by tool_call_id up front; a later duplicate id wins.
    results_by_call_id: dict[str, str] = {
        msg["tool_call_id"]: msg.get("content") or ""
        for msg in messages
        if msg.get("role") == "tool" and msg.get("tool_call_id")
    }

    def _tool_call_block(call: dict) -> dict:
        """Build the display dict for one entry of an assistant's tool_calls."""
        function = call.get("function", {}) or {}
        raw_args = function.get("arguments", "") or ""
        return {
            "name": function.get("name", "?"),
            "args_preview": _args_preview(raw_args),
            "args_pretty": _pretty_json(raw_args),
            "result": results_by_call_id.get(call.get("id", ""), "[no result]"),
        }

    blocks: list[dict] = []
    for msg in messages:
        role = msg.get("role")
        if role == "user":
            user_html = _render_md(msg.get("content") or "")
            blocks.append({"type": "user", "html": user_html})
        elif role == "assistant":
            text = msg.get("content") or ""
            calls = [_tool_call_block(call) for call in msg.get("tool_calls") or []]
            blocks.append({
                "type": "assistant",
                "html": _render_md(text),
                "tool_calls": calls,
            })
    return blocks


# --------------------------- Task list 查询 ---------------------------

def list_tasks(limit: int = 50, status: Optional[str] = None) -> list[dict[str, Any]]:
    """List tasks, most recently updated first, each with its message count.

    Args:
        limit: Maximum number of tasks to return. Negative values are clamped
            to 0 instead of being handed to the database.
        status: Optional status filter. Values outside ``STATUS_FILTERS`` are
            ignored, i.e. the list is returned unfiltered.

    Returns:
        One dict per task with the keys ``task_id``, ``task_id_short``,
        ``updated_at``, ``created_at``, ``status``, ``mode``, ``model_label``,
        ``tokens``, ``n_messages``, ``description`` and ``task_dir``.
    """
    if status and status not in STATUS_FILTERS:
        status = None
    # `limit` can come straight from a query string. How a negative LIMIT is
    # handled is backend-specific (rejected by some databases, "no limit" on
    # others), so never pass one through.
    limit = max(limit, 0)

    with session_scope() as s:
        q = (
            select(
                Task.task_id, Task.updated_at, Task.created_at, Task.status,
                Task.mode, Task.model, Task.model_profile,
                Task.tokens_prompt, Task.tokens_completion, Task.description,
                Task.task_dir,
            )
            .order_by(Task.updated_at.desc())
        )
        if status:
            q = q.where(Task.status == status)
        rows_db = s.execute(q.limit(limit)).all()

        # Count messages only for the tasks actually listed, rather than
        # aggregating over every task in the table on each page load.
        task_ids = [r.task_id for r in rows_db]
        msg_counts: dict = {}
        if task_ids:
            msg_counts = dict(s.execute(
                select(Message.task_id, func.count())
                .where(Message.task_id.in_(task_ids))
                .group_by(Message.task_id)
            ).all())

    return [
        {
            "task_id": str(r.task_id),
            "task_id_short": str(r.task_id)[:8],
            "updated_at": r.updated_at,
            "created_at": r.created_at,
            "status": r.status,
            "mode": r.mode or "",
            # Prefer the profile name; fall back to the raw model id.
            "model_label": r.model_profile or r.model or "",
            "tokens": (r.tokens_prompt or 0) + (r.tokens_completion or 0),
            "n_messages": msg_counts.get(r.task_id, 0),
            "description": r.description or "",
            "task_dir": _norm_path(r.task_dir or ""),
        }
        for r in rows_db
    ]


# --------------------------- App 工厂 ---------------------------

def create_app() -> FastAPI:
    """FastAPI application factory.

    uvicorn's --reload mode needs a factory signature (factory=True).

    Routes registered:
        GET /                 task list (``home.html``), optional ``status`` / ``limit``
        GET /tasks/{task_id}  read-only chat view of one task (``chat.html``)
        GET /healthz          liveness probe, responds "ok"
    """
    # Function-scope import: only the error page in task_detail needs it.
    from html import escape

    app = FastAPI(title="zcbot web", version="0.3")
    templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")

    @app.get("/", response_class=HTMLResponse)
    def home(request: Request, status: Optional[str] = None, limit: int = 50):
        """Render the task list, optionally filtered by status."""
        tasks = list_tasks(limit=limit, status=status)
        return templates.TemplateResponse(
            request, "home.html",
            {
                "version": app.version,
                "tasks": tasks,
                "status": status or "",
                "limit": limit,
                "filters": STATUS_FILTERS,
            },
        )

    @app.get("/tasks/{task_id}", response_class=HTMLResponse)
    def task_detail(request: Request, task_id: str):
        """G3: validate the UUID, read task metadata and messages, build display blocks, render."""
        try:
            tid = UUID(task_id)
        except ValueError:
            # task_id is client-controlled and this is a text/html response:
            # escape it so markup in the path cannot be reflected into the page.
            # quote=False keeps the body byte-identical for ordinary ids.
            safe_id = escape(repr(task_id), quote=False)
            return HTMLResponse(f"invalid task id: {safe_id}", status_code=404)
        with session_scope() as s:
            row = s.execute(
                select(
                    Task.task_id, Task.description, Task.task_dir, Task.status,
                    Task.mode, Task.model, Task.model_profile,
                    Task.tokens_prompt, Task.tokens_completion,
                    Task.created_at, Task.updated_at,
                ).where(Task.task_id == tid)
            ).first()
        if row is None:
            # `tid` is a parsed UUID here, so its text form contains no markup.
            return HTMLResponse(f"task not found: {tid}", status_code=404)

        messages = load_chat_messages(tid)
        blocks = build_chat_blocks(messages)

        return templates.TemplateResponse(
            request, "chat.html",
            {
                "task_id": str(tid),
                "task_id_short": str(tid)[:8],
                "description": row.description or "",
                "task_dir": _norm_path(row.task_dir or ""),
                "status": row.status,
                "mode": row.mode or "",
                "model_label": row.model_profile or row.model or "",
                "tokens": (row.tokens_prompt or 0) + (row.tokens_completion or 0),
                # Counts every stored message, including ones that yield no block.
                "n_messages": len(messages),
                "created_at": row.created_at,
                "updated_at": row.updated_at,
                "blocks": blocks,
            },
        )

    @app.get("/healthz", response_class=HTMLResponse)
    def healthz():
        """Liveness probe."""
        return HTMLResponse("ok")

    return app