265 lines
9.2 KiB
Python
265 lines
9.2 KiB
Python
"""FastAPI app 工厂。G1 脚手架 → G2 task list → G3 chat 只读 → G4+ 渐进上。
|
|
|
|
设计:
|
|
- 单 FastAPI 进程,模板走 Jinja2,静态走 StaticFiles
|
|
- 模板里 path 显示一律 `replace('\\', '/')`,Win / Linux 看到统一形态
|
|
(`Path.as_posix()` 在 Linux 读 Windows backslash 串时不归一,所以直接 replace)
|
|
- Markdown 渲染走 markdown-it-py(gfm-like)+ pygments syntax highlight
|
|
- SSE 在 G4 加,响应头会带 `X-Accel-Buffering: no`(nginx 反代友好)
|
|
- 本地形态 sentinel user 固定;Phase D 加 OIDC 之后才有真正 user 态
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Any, Optional
|
|
from uuid import UUID
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import HTMLResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.templating import Jinja2Templates
|
|
from sqlalchemy import func, select
|
|
|
|
from core.storage import session_scope
|
|
from core.storage.models import Message, Task
|
|
|
|
# Directory that contains this module; templates and static assets sit beside it.
WEB_ROOT = Path(__file__).resolve().parent
# Jinja2 template directory (handed to Jinja2Templates in create_app).
TEMPLATES_DIR = WEB_ROOT.joinpath("templates")
# Static asset directory (mounted at /static in create_app).
STATIC_DIR = WEB_ROOT.joinpath("static")

# Task status values accepted by the task-list filter; list_tasks ignores any other value.
STATUS_FILTERS = ("active", "completed", "abandoned")
|
|
|
|
|
|
def _norm_path(p: str) -> str:
|
|
"""跨 OS 显示归一:backslash → forward slash。Win 存 `\\`、Linux 存 `/`,显示统一 `/`。"""
|
|
return (p or "").replace("\\", "/")
|
|
|
|
|
|
# --------------------------- Markdown 渲染 ---------------------------
|
|
|
|
_md_instance = None
|
|
|
|
|
|
def _pygments_highlight(code: str, lang: str, attrs: str) -> str:
|
|
"""markdown-it highlight 回调。lang 未识别 / pygments 异常时返回 '' 让 md 走默认 <pre><code>。"""
|
|
if not lang:
|
|
return ""
|
|
try:
|
|
from pygments import highlight
|
|
from pygments.formatters import HtmlFormatter
|
|
from pygments.lexers import get_lexer_by_name
|
|
from pygments.util import ClassNotFound
|
|
except ImportError:
|
|
return ""
|
|
try:
|
|
lexer = get_lexer_by_name(lang, stripall=False)
|
|
except ClassNotFound:
|
|
return ""
|
|
formatter = HtmlFormatter(nowrap=False, cssclass="codehilite")
|
|
return highlight(code, lexer, formatter)
|
|
|
|
|
|
def _get_md():
    """Return the module-level MarkdownIt renderer, building it on first call.

    Uses the ``gfm-like`` preset (tables, strikethrough, linkify). Raw HTML is disabled so
    message content cannot inject markup (XSS guard). ``breaks`` is enabled, and fenced
    code is passed to ``_pygments_highlight``.
    """
    global _md_instance
    if _md_instance is not None:
        return _md_instance

    from markdown_it import MarkdownIt

    md_options = {
        "linkify": True,
        "html": False,
        "breaks": True,
        "highlight": _pygments_highlight,
    }
    _md_instance = MarkdownIt("gfm-like", md_options)
    return _md_instance
|
|
|
|
|
|
def _render_md(text: str) -> str:
|
|
"""渲染 markdown → HTML。空串返空。"""
|
|
if not text:
|
|
return ""
|
|
return _get_md().render(text)
|
|
|
|
|
|
# --------------------------- 消息块聚合 ---------------------------
|
|
|
|
def _args_preview(args: str, max_len: int = 60) -> str:
|
|
s = (args or "").replace("\n", " ").strip()
|
|
return s if len(s) <= max_len else s[: max_len - 3] + "..."
|
|
|
|
|
|
def _pretty_json(s: str) -> str:
|
|
"""JSON 串美化输出。解析失败返回原串。"""
|
|
try:
|
|
return json.dumps(json.loads(s), indent=2, ensure_ascii=False)
|
|
except Exception:
|
|
return s or ""
|
|
|
|
|
|
def load_chat_messages(task_id: UUID) -> list[dict]:
    """Return every stored message payload of ``task_id`` as a dict, ordered by ``Message.idx``.

    A task without messages yields an empty list.
    """
    query = (
        select(Message.payload)
        .where(Message.task_id == task_id)
        .order_by(Message.idx)
    )
    with session_scope() as s:
        payloads = s.execute(query).scalars().all()
        return [dict(payload) for payload in payloads]
|
|
|
|
|
|
def build_chat_blocks(messages: list[dict]) -> list[dict]:
|
|
"""把 LiteLLM 消息序列聚合成显示块。
|
|
|
|
- system / tool 不进 blocks(system 不入 DB;tool result 跟随 assistant 的 tool_call 内嵌)
|
|
- user → {type=user, html}
|
|
- assistant → {type=assistant, html, tool_calls=[{name,args_preview,args_pretty,result}]}
|
|
"""
|
|
tool_results: dict[str, str] = {}
|
|
for m in messages:
|
|
if m.get("role") == "tool":
|
|
tcid = m.get("tool_call_id")
|
|
if tcid:
|
|
tool_results[tcid] = m.get("content") or ""
|
|
|
|
blocks: list[dict] = []
|
|
for m in messages:
|
|
role = m.get("role")
|
|
if role in ("system", "tool"):
|
|
continue
|
|
if role == "user":
|
|
blocks.append({
|
|
"type": "user",
|
|
"html": _render_md(m.get("content") or ""),
|
|
})
|
|
elif role == "assistant":
|
|
content = m.get("content") or ""
|
|
tool_calls = m.get("tool_calls") or []
|
|
tc_blocks = []
|
|
for tc in tool_calls:
|
|
fn = tc.get("function", {}) or {}
|
|
args_raw = fn.get("arguments", "") or ""
|
|
tc_blocks.append({
|
|
"name": fn.get("name", "?"),
|
|
"args_preview": _args_preview(args_raw),
|
|
"args_pretty": _pretty_json(args_raw),
|
|
"result": tool_results.get(tc.get("id", ""), "[no result]"),
|
|
})
|
|
blocks.append({
|
|
"type": "assistant",
|
|
"html": _render_md(content),
|
|
"tool_calls": tc_blocks,
|
|
})
|
|
return blocks
|
|
|
|
|
|
# --------------------------- Task list query ---------------------------


def list_tasks(limit: int = 50, status: Optional[str] = None) -> list[dict[str, Any]]:
    """Return up to ``limit`` tasks, most recently updated first, with message counts.

    Args:
        limit: Maximum number of tasks. Negative values are treated as 0; ``None`` is
            passed through to ``.limit()`` unchanged.
        status: Optional status filter. Values outside ``STATUS_FILTERS`` are ignored
            (no filtering) rather than rejected.

    Returns:
        One dict per task for the home page template. Each dict holds the id (full and
        8-char short form), timestamps, status, mode, model label, total tokens, message
        count, description, and a ``/``-normalised task_dir.
    """
    if status and status not in STATUS_FILTERS:
        status = None
    # ``limit`` may come straight from a query string; some databases reject a negative
    # LIMIT, so clamp it instead of letting the query fail.
    if limit is not None and limit < 0:
        limit = 0

    with session_scope() as s:
        q = (
            select(
                Task.task_id, Task.updated_at, Task.created_at, Task.status,
                Task.mode, Task.model, Task.model_profile,
                Task.tokens_prompt, Task.tokens_completion, Task.description,
                Task.task_dir,
            )
            .order_by(Task.updated_at.desc())
        )
        if status:
            q = q.where(Task.status == status)
        rows_db = s.execute(q.limit(limit)).all()

        # Count messages only for the tasks on this page instead of grouping the whole
        # messages table on every request.
        task_ids = [r.task_id for r in rows_db]
        msg_counts: dict[Any, int] = {}
        if task_ids:
            msg_counts = dict(s.execute(
                select(Message.task_id, func.count())
                .where(Message.task_id.in_(task_ids))
                .group_by(Message.task_id)
            ).all())

    return [
        {
            "task_id": str(r.task_id),
            "task_id_short": str(r.task_id)[:8],
            "updated_at": r.updated_at,
            "created_at": r.created_at,
            "status": r.status,
            "mode": r.mode or "",
            "model_label": r.model_profile or r.model or "",
            "tokens": (r.tokens_prompt or 0) + (r.tokens_completion or 0),
            "n_messages": msg_counts.get(r.task_id, 0),
            "description": r.description or "",
            "task_dir": _norm_path(r.task_dir or ""),
        }
        for r in rows_db
    ]
|
|
|
|
|
|
# --------------------------- App factory ---------------------------


def create_app() -> FastAPI:
    """Build the FastAPI application.

    This is a factory so it can be served by uvicorn in factory mode
    (``factory=True``), which ``--reload`` relies on.

    Routes:
        ``GET /``                 task list page (``home.html``); optional ``status`` /
                                  ``limit`` query parameters.
        ``GET /tasks/{task_id}``  read-only chat view of one task (``chat.html``).
        ``GET /healthz``          liveness check returning ``ok``.

    Static assets from ``STATIC_DIR`` are mounted under ``/static``.
    """
    from html import escape

    app = FastAPI(title="zcbot web", version="0.3")
    templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")

    @app.get("/", response_class=HTMLResponse)
    def home(request: Request, status: Optional[str] = None, limit: int = 50):
        """Task list page, most recently updated first, optionally filtered by status."""
        tasks = list_tasks(limit=limit, status=status)
        return templates.TemplateResponse(
            request, "home.html",
            {
                "version": app.version,
                "tasks": tasks,
                "status": status or "",
                "limit": limit,
                "filters": STATUS_FILTERS,
            },
        )

    @app.get("/tasks/{task_id}", response_class=HTMLResponse)
    def task_detail(request: Request, task_id: str):
        """Read-only chat view of one task.

        Validates the UUID, loads the task metadata and messages, then aggregates the
        messages into display blocks and renders them.
        """
        try:
            tid = UUID(task_id)
        except ValueError:
            # task_id is raw, attacker-controlled path text. Escape it before reflecting it
            # into an HTML body; otherwise e.g. /tasks/<img src=x onerror=...> would be a
            # reflected XSS.
            return HTMLResponse(
                f"invalid task id: {escape(repr(task_id))}", status_code=404
            )
        with session_scope() as s:
            row = s.execute(
                select(
                    Task.task_id, Task.description, Task.task_dir, Task.status,
                    Task.mode, Task.model, Task.model_profile,
                    Task.tokens_prompt, Task.tokens_completion,
                    Task.created_at, Task.updated_at,
                ).where(Task.task_id == tid)
            ).first()
        if row is None:
            # tid is a parsed UUID, so its text form is safe to embed.
            return HTMLResponse(f"task not found: {tid}", status_code=404)

        messages = load_chat_messages(tid)
        blocks = build_chat_blocks(messages)

        return templates.TemplateResponse(
            request, "chat.html",
            {
                "task_id": str(tid),
                "task_id_short": str(tid)[:8],
                "description": row.description or "",
                "task_dir": _norm_path(row.task_dir or ""),
                "status": row.status,
                "mode": row.mode or "",
                "model_label": row.model_profile or row.model or "",
                "tokens": (row.tokens_prompt or 0) + (row.tokens_completion or 0),
                "n_messages": len(messages),
                "created_at": row.created_at,
                "updated_at": row.updated_at,
                "blocks": blocks,
            },
        )

    @app.get("/healthz", response_class=HTMLResponse)
    def healthz():
        """Liveness check."""
        return HTMLResponse("ok")

    return app
|