469 lines
17 KiB
Python
469 lines
17 KiB
Python
"""FastAPI app 工厂。G1 脚手架 → G2 task list → G3 chat 只读 → G4+ 渐进上。
|
||
|
||
设计:
|
||
- 单 FastAPI 进程,模板走 Jinja2,静态走 StaticFiles
|
||
- 模板里 path 显示一律 `replace('\\', '/')`,Win / Linux 看到统一形态
|
||
(`Path.as_posix()` 在 Linux 读 Windows backslash 串时不归一,所以直接 replace)
|
||
- Markdown 渲染走 markdown-it-py(gfm-like)+ pygments syntax highlight
|
||
- SSE 在 G4 加,响应头会带 `X-Accel-Buffering: no`(nginx 反代友好)
|
||
- 本地形态 sentinel user 固定;Phase D 加 OIDC 之后才有真正 user 态
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
from contextlib import asynccontextmanager
|
||
from pathlib import Path
|
||
from typing import Any, Optional
|
||
from uuid import UUID, uuid4
|
||
|
||
from fastapi import FastAPI, Form, HTTPException, Request
|
||
from fastapi.responses import HTMLResponse, StreamingResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from fastapi.templating import Jinja2Templates
|
||
from sqlalchemy import func, select, update
|
||
|
||
from core.storage import session_scope
|
||
from core.storage.models import Message, Run, Task
|
||
|
||
from .broker import broker
|
||
from .sinks import WebEventSink
|
||
|
||
# Directory that contains this module; templates and static assets live beside it.
WEB_ROOT = Path(__file__).resolve().parent
TEMPLATES_DIR = WEB_ROOT.joinpath("templates")
STATIC_DIR = WEB_ROOT.joinpath("static")

# Status values accepted by the task-list filter; anything else is ignored.
STATUS_FILTERS = ("active", "completed", "abandoned")
|
||
|
||
|
||
def _norm_path(p: str) -> str:
|
||
"""跨 OS 显示归一:backslash → forward slash。Win 存 `\\`、Linux 存 `/`,显示统一 `/`。"""
|
||
return (p or "").replace("\\", "/")
|
||
|
||
|
||
# --------------------------- Markdown 渲染 ---------------------------
|
||
|
||
# Lazily created MarkdownIt singleton; populated on the first call to _get_md().
_md_instance = None
|
||
|
||
|
||
def _pygments_highlight(code: str, lang: str, attrs: str) -> str:
|
||
"""markdown-it highlight 回调。lang 未识别 / pygments 异常时返回 '' 让 md 走默认 <pre><code>。"""
|
||
if not lang:
|
||
return ""
|
||
try:
|
||
from pygments import highlight
|
||
from pygments.formatters import HtmlFormatter
|
||
from pygments.lexers import get_lexer_by_name
|
||
from pygments.util import ClassNotFound
|
||
except ImportError:
|
||
return ""
|
||
try:
|
||
lexer = get_lexer_by_name(lang, stripall=False)
|
||
except ClassNotFound:
|
||
return ""
|
||
formatter = HtmlFormatter(nowrap=False, cssclass="codehilite")
|
||
return highlight(code, lexer, formatter)
|
||
|
||
|
||
def _get_md():
    """Return the shared MarkdownIt instance, building it on first use.

    The parser is created with the ``"gfm-like"`` preset and the options
    ``linkify=True``, ``html=False`` (the original author notes raw HTML is
    disabled to prevent XSS), ``breaks=True`` and ``_pygments_highlight`` as
    the highlight callback.
    """
    global _md_instance
    if _md_instance is not None:
        return _md_instance

    # Imported lazily, on the first render only.
    from markdown_it import MarkdownIt

    options = {
        "linkify": True,
        "html": False,
        "breaks": True,
        "highlight": _pygments_highlight,
    }
    _md_instance = MarkdownIt("gfm-like", options)
    return _md_instance
|
||
|
||
|
||
def _render_md(text: str) -> str:
|
||
"""渲染 markdown → HTML。空串返空。"""
|
||
if not text:
|
||
return ""
|
||
return _get_md().render(text)
|
||
|
||
|
||
# --------------------------- 消息块聚合 ---------------------------
|
||
|
||
def _args_preview(args: str, max_len: int = 60) -> str:
|
||
s = (args or "").replace("\n", " ").strip()
|
||
return s if len(s) <= max_len else s[: max_len - 3] + "..."
|
||
|
||
|
||
def _pretty_json(s: str) -> str:
|
||
"""JSON 串美化输出。解析失败返回原串。"""
|
||
try:
|
||
return json.dumps(json.loads(s), indent=2, ensure_ascii=False)
|
||
except Exception:
|
||
return s or ""
|
||
|
||
|
||
def load_chat_messages(task_id: UUID) -> list[dict]:
    """Fetch every message payload of a task, ordered by ``Message.idx`` ascending.

    Each payload is copied into a plain ``dict``. A task without messages
    yields an empty list.
    """
    stmt = (
        select(Message.payload)
        .where(Message.task_id == task_id)
        .order_by(Message.idx)
    )
    with session_scope() as s:
        payloads = s.execute(stmt).scalars().all()
        return [dict(payload) for payload in payloads]
|
||
|
||
|
||
def build_chat_blocks(messages: list[dict]) -> list[dict]:
    """Aggregate a LiteLLM-style message sequence into display blocks.

    - ``system`` and ``tool`` messages produce no block of their own; a tool
      message's content is attached to the assistant tool call whose ``id``
      equals its ``tool_call_id``.
    - ``user`` -> ``{"type": "user", "html": ...}``
    - ``assistant`` -> ``{"type": "assistant", "html": ..., "tool_calls": [...]}``
      where each tool call carries ``name``, ``args_preview``, ``args_pretty``
      and ``result`` (``"[no result]"`` when no tool message matches).

    Messages with any other role are skipped.
    """
    # First pass: index tool outputs by the call id they answer (last one wins).
    results_by_call_id: dict[str, str] = {
        msg["tool_call_id"]: msg.get("content") or ""
        for msg in messages
        if msg.get("role") == "tool" and msg.get("tool_call_id")
    }

    # Second pass: emit one block per user / assistant message, in order.
    blocks: list[dict] = []
    for msg in messages:
        role = msg.get("role")

        if role == "user":
            blocks.append({
                "type": "user",
                "html": _render_md(msg.get("content") or ""),
            })
            continue

        if role != "assistant":
            continue

        text = msg.get("content") or ""
        calls = []
        for call in msg.get("tool_calls") or []:
            fn = call.get("function", {}) or {}
            raw_args = fn.get("arguments", "") or ""
            calls.append({
                "name": fn.get("name", "?"),
                "args_preview": _args_preview(raw_args),
                "args_pretty": _pretty_json(raw_args),
                "result": results_by_call_id.get(call.get("id", ""), "[no result]"),
            })
        blocks.append({
            "type": "assistant",
            "html": _render_md(text),
            "tool_calls": calls,
        })
    return blocks
|
||
|
||
|
||
# --------------------------- Task list 查询 ---------------------------
|
||
|
||
def list_tasks(limit: int = 50, status: Optional[str] = None) -> list[dict[str, Any]]:
    """List tasks ordered by ``updated_at`` descending, with per-task message counts.

    Args:
        limit: Maximum number of tasks to return.
        status: Optional status filter; values not in ``STATUS_FILTERS`` are
            ignored (treated as no filter).

    Returns:
        One dict per task with display-ready fields: ``task_id``,
        ``task_id_short`` (first 8 chars), ``updated_at``, ``created_at``,
        ``status``, ``mode``, ``model_label`` (profile, else model),
        ``tokens`` (prompt + completion), ``n_messages``, ``description`` and
        ``task_dir`` (separator-normalized).
    """
    if status and status not in STATUS_FILTERS:
        status = None

    with session_scope() as s:
        q = (
            select(
                Task.task_id, Task.updated_at, Task.created_at, Task.status,
                Task.mode, Task.model, Task.model_profile,
                Task.tokens_prompt, Task.tokens_completion, Task.description,
                Task.task_dir,
            )
            .order_by(Task.updated_at.desc())
        )
        if status:
            q = q.where(Task.status == status)
        rows_db = s.execute(q.limit(limit)).all()

        # Count messages only for the tasks on this page instead of
        # aggregating the whole messages table on every listing.
        task_ids = [r.task_id for r in rows_db]
        msg_counts: dict = {}
        if task_ids:
            msg_counts = dict(s.execute(
                select(Message.task_id, func.count())
                .where(Message.task_id.in_(task_ids))
                .group_by(Message.task_id)
            ).all())

    result = []
    for r in rows_db:
        tid = r.task_id
        result.append({
            "task_id": str(tid),
            "task_id_short": str(tid)[:8],
            "updated_at": r.updated_at,
            "created_at": r.created_at,
            "status": r.status,
            "mode": r.mode or "",
            "model_label": r.model_profile or r.model or "",
            "tokens": (r.tokens_prompt or 0) + (r.tokens_completion or 0),
            "n_messages": msg_counts.get(tid, 0),
            "description": r.description or "",
            "task_dir": _norm_path(r.task_dir or ""),
        })
    return result
|
||
|
||
|
||
# --------------------------- Run 启动 / SSE event 渲染 ---------------------------
|
||
|
||
def _run_agent_bg(task_id: UUID, run_id: UUID, user_message: str) -> None:
    """Worker-thread entry point for one agent run.

    This runs under ``asyncio.to_thread``, so nothing here may ``await``.

    Flow: ``build_agent(resume=True)`` -> attach a ``WebEventSink`` ->
    ``agent.run`` -> write the outcome to the ``runs`` row. Any failure is
    emitted to subscribers as an ``error`` event, recorded on the run row on a
    best-effort basis, and logged. The broker channel for ``run_id`` is always
    closed on exit.

    Args:
        task_id: Task whose session is resumed.
        run_id: Identifier of the run row / broker channel.
        user_message: Text passed to ``agent.run``.
    """
    import logging

    log = logging.getLogger(__name__)

    try:
        # Imported inside the try so an import failure is reported like any
        # other error and the broker channel is still closed; otherwise
        # subscribers would wait forever on a run stuck in "running".
        from main import build_agent, sync_task_tokens

        broker.emit(run_id, {"type": "run_start"})
        # Original author's note: build_agent calls ensure_local_sentinel /
        # LLM init / Session.load etc. on every POST; not cheap but simple,
        # and the agent may be cached later if needed.
        agent, _session, _sid, task_state, _task_dir = build_agent(
            session_id=str(task_id), resume=True,
        )
        agent.sink = WebEventSink(broker, run_id)
        agent.run(user_message)
        sync_task_tokens(task_state, agent.llm)
        with session_scope() as s:
            s.execute(
                update(Run)
                .where(Run.run_id == run_id)
                .values(
                    status="ok",
                    finished_at=func.now(),
                    tokens_p=agent.llm.token_counter.prompt_tokens,
                    tokens_c=agent.llm.token_counter.completion_tokens,
                )
            )
    except Exception as e:
        # Keep the traceback server-side; the client only gets the summary.
        log.exception("agent run %s failed", run_id)
        err = f"{type(e).__name__}: {e}"
        broker.emit(run_id, {"type": "error", "msg": err})
        try:
            with session_scope() as s:
                s.execute(
                    update(Run)
                    .where(Run.run_id == run_id)
                    .values(status="error", error=err, finished_at=func.now())
                )
        except Exception:
            # The front end has already been told; a failed DB write is
            # logged but deliberately not re-raised.
            log.exception("could not record failure of run %s", run_id)
    finally:
        broker.close(run_id)
|
||
|
||
|
||
def _render_event_fragment(templates: Jinja2Templates, ev: dict, request: Request) -> str:
|
||
"""把一条 event 渲染成 HTML 片段(供 SSE data 推送)。
|
||
|
||
片段类型与 chat.html 静态 block 视觉一致,append 模式追加到 #chat-stream 容器尾。
|
||
text / tool_call / tool_result / error 各有专用块;run_start / llm_start / llm_end /
|
||
done 不出 HTML(用空串当 keep-alive,客户端依然能识别 event type 控制状态)。
|
||
"""
|
||
t = ev.get("type")
|
||
if t == "text":
|
||
content = ev.get("content") or ""
|
||
if not content:
|
||
return ""
|
||
# assistant text 片段:跟 chat.html 静态 assistant body 同形态
|
||
return templates.get_template("_frag_text.html").render(
|
||
request=request, html=_render_md(content)
|
||
)
|
||
if t == "tool_call":
|
||
return templates.get_template("_frag_tool_call.html").render(
|
||
request=request,
|
||
name=ev.get("name", "?"),
|
||
args_preview=_args_preview(ev.get("args_preview", "")),
|
||
args_pretty=_pretty_json(json.dumps(ev.get("args", {}), ensure_ascii=False))
|
||
if ev.get("args") is not None else _pretty_json(ev.get("args_preview", "")),
|
||
)
|
||
if t == "tool_result":
|
||
return templates.get_template("_frag_tool_result.html").render(
|
||
request=request,
|
||
name=ev.get("name", "?"),
|
||
preview=ev.get("preview", ""),
|
||
truncated=ev.get("truncated", False),
|
||
)
|
||
if t == "error":
|
||
return templates.get_template("_frag_error.html").render(
|
||
request=request, msg=ev.get("msg", "")
|
||
)
|
||
# llm_start / llm_end / run_start / done:发空 data,htmx-ext-sse 也会触发 event,
|
||
# 客户端只读 type 控制状态(spinner / close);data 内容不需要 swap。
|
||
return ""
|
||
|
||
|
||
def _sse_format(event_type: str, payload: str) -> bytes:
|
||
"""格式化一帧 SSE。data 多行要每行 `data: ` 前缀(SSE spec)。
|
||
|
||
EventSource API 会自动把 multi-line data 用 \n 拼接还原 — htmx-ext-sse 直接拿来当 HTML swap。
|
||
"""
|
||
parts = [f"event: {event_type}"]
|
||
if payload:
|
||
for line in payload.splitlines() or [""]:
|
||
parts.append(f"data: {line}")
|
||
else:
|
||
parts.append("data: ") # 空 data 也要有,EventSource 才认这帧
|
||
parts.append("") # 终结空行
|
||
parts.append("")
|
||
return ("\n".join(parts)).encode("utf-8")
|
||
|
||
|
||
# --------------------------- App 工厂 ---------------------------
|
||
|
||
def create_app() -> FastAPI:
    """Build and return the FastAPI application.

    Per the original author's note, uvicorn's ``--reload`` mode needs a
    factory signature (``factory=True``).

    Routes registered:
        - ``GET /``: task list page.
        - ``GET /tasks/{task_id}``: chat page for one task.
        - ``POST /tasks/{task_id}/messages``: submit a user message and start
          a background agent run.
        - ``GET /tasks/{task_id}/runs/{run_id}/events``: SSE stream of a run.
        - ``GET /healthz``: liveness probe.
    """
    from html import escape as html_escape

    # Strong references to in-flight background runs: the event loop only
    # keeps weak references to tasks, so an unreferenced task may be
    # garbage-collected before it finishes.
    background_runs: set = set()

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Bind the running asyncio loop to the broker; per the original
        # author's note, emit() from worker threads bridges back to it via
        # call_soon_threadsafe.
        broker.bind_loop(asyncio.get_running_loop())
        yield

    app = FastAPI(title="zcbot web", version="0.4", lifespan=lifespan)
    templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")

    @app.get("/", response_class=HTMLResponse)
    def home(request: Request, status: Optional[str] = None, limit: int = 50):
        """Render the task list, optionally filtered by status."""
        tasks = list_tasks(limit=limit, status=status)
        return templates.TemplateResponse(
            request, "home.html",
            {
                "version": app.version,
                "tasks": tasks,
                "status": status or "",
                "limit": limit,
                "filters": STATUS_FILTERS,
            },
        )

    @app.get("/tasks/{task_id}", response_class=HTMLResponse)
    def task_detail(request: Request, task_id: str):
        """G3: validate the UUID, read task metadata and messages, aggregate, render."""
        try:
            tid = UUID(task_id)
        except ValueError:
            # task_id is attacker-controlled and this response is text/html:
            # escape it so markup in the path is not reflected into the page.
            return HTMLResponse(
                f"invalid task id: {html_escape(repr(task_id))}", status_code=404
            )
        with session_scope() as s:
            row = s.execute(
                select(
                    Task.task_id, Task.description, Task.task_dir, Task.status,
                    Task.mode, Task.model, Task.model_profile,
                    Task.tokens_prompt, Task.tokens_completion,
                    Task.created_at, Task.updated_at,
                ).where(Task.task_id == tid)
            ).first()
        if row is None:
            return HTMLResponse(f"task not found: {tid}", status_code=404)

        messages = load_chat_messages(tid)
        blocks = build_chat_blocks(messages)

        return templates.TemplateResponse(
            request, "chat.html",
            {
                "task_id": str(tid),
                "task_id_short": str(tid)[:8],
                "description": row.description or "",
                "task_dir": _norm_path(row.task_dir or ""),
                "status": row.status,
                "mode": row.mode or "",
                "model_label": row.model_profile or row.model or "",
                "tokens": (row.tokens_prompt or 0) + (row.tokens_completion or 0),
                "n_messages": len(messages),
                "created_at": row.created_at,
                "updated_at": row.updated_at,
                "blocks": blocks,
            },
        )

    @app.post("/tasks/{task_id}/messages", response_class=HTMLResponse)
    async def post_message(request: Request, task_id: str, content: str = Form(...)):
        """G4: accept a user message, start a background run, return the reply scaffold.

        The response contains the user message card, an assistant placeholder
        and an SSE container. Per the original author's notes, the client
        posts this via HTMX ``hx-post`` and swaps the response into
        ``#chat-stream`` (beforeend); the returned HTML carries
        ``sse-connect=/tasks/{id}/runs/{rid}/events`` so htmx-ext-sse opens
        the EventSource automatically.
        """
        try:
            tid = UUID(task_id)
        except ValueError:
            raise HTTPException(404, f"invalid task id: {task_id!r}")
        content = (content or "").strip()
        if not content:
            raise HTTPException(400, "empty message")
        # Make sure the task exists.
        with session_scope() as s:
            row = s.execute(
                select(Task.task_id).where(Task.task_id == tid)
            ).first()
        if row is None:
            raise HTTPException(404, f"task not found: {tid}")

        run_id = uuid4()
        with session_scope() as s:
            s.add(Run(run_id=run_id, task_id=tid, status="running", started_at=func.now()))

        # Start the background agent: to_thread runs the sync agent.run, and
        # the sink bridges events back to asyncio through the broker.
        bg_task = asyncio.create_task(
            asyncio.to_thread(_run_agent_bg, tid, run_id, content)
        )
        background_runs.add(bg_task)
        bg_task.add_done_callback(background_runs.discard)

        return templates.TemplateResponse(
            request, "_send_response.html",
            {
                "task_id": str(tid),
                "run_id": str(run_id),
                "user_html": _render_md(content),
            },
        )

    @app.get("/tasks/{task_id}/runs/{run_id}/events")
    async def stream_events(request: Request, task_id: str, run_id: str):
        """G4: SSE stream. Subscribe to ``broker[run_id]``, render fragments, push.

        Per the original author's notes: when the client disconnects (tab
        closed / navigation) asyncio raises CancelledError at the next yield
        and the ``finally`` cleans up; several subscribers of the same run
        (page refresh / multiple tabs) each get their own queue.
        """
        try:
            tid = UUID(task_id)
            rid = UUID(run_id)
        except ValueError:
            raise HTTPException(404, "invalid id")
        # Task existence check (guards against probing / broken links).
        with session_scope() as s:
            ok = s.execute(
                select(Task.task_id).where(Task.task_id == tid)
            ).first()
        if ok is None:
            raise HTTPException(404, f"task not found: {tid}")

        async def gen():
            q = broker.subscribe(rid)
            try:
                # First frame: a comment plus retry hint so the EventSource is
                # established immediately instead of waiting behind a buffer.
                yield b": connected\nretry: 3000\n\n"
                while True:
                    try:
                        ev = await asyncio.wait_for(q.get(), timeout=30.0)
                    except asyncio.TimeoutError:
                        # Idle for 30 s: send a comment frame as a heartbeat.
                        yield b": ping\n\n"
                        continue
                    ev_type = ev.get("type", "msg")
                    frag = _render_event_fragment(templates, ev, request)
                    yield _sse_format(ev_type, frag)
                    if ev_type in ("done", "error"):
                        break
            except asyncio.CancelledError:
                # Client disconnected: exit quietly instead of propagating.
                pass
            finally:
                broker.unsubscribe(rid, q)

        return StreamingResponse(
            gen(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",  # nginx reverse proxy: do not buffer this stream
            },
        )

    @app.get("/healthz", response_class=HTMLResponse)
    def healthz():
        """Liveness probe; always answers ``ok``."""
        return HTMLResponse("ok")

    return app
|