134 lines
4.9 KiB
Python
134 lines
4.9 KiB
Python
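"""FastAPI app factory. G1 scaffolding → G2 task list backed by PG → G3+ rolled out incrementally.

Design:

- Single FastAPI process; templates are rendered with Jinja2, static files are served via StaticFiles.
- Paths shown in templates always go through `replace('\\', '/')`, so Windows and Linux display the same form
  (`Path.as_posix()` does not normalize a Windows backslash string when running on Linux, hence the plain replace).
- SSE arrives in G4; its responses will carry `X-Accel-Buffering: no` (so an nginx reverse proxy does not buffer them).
- Local mode uses a fixed sentinel user; real user state only exists once Phase D adds OIDC.
"""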
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import Any, Optional
|
|
from uuid import UUID
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import HTMLResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.templating import Jinja2Templates
|
|
from sqlalchemy import func, select
|
|
|
|
from core.storage import session_scope
|
|
from core.storage.models import Message, Task
|
|
|
|
# Directory containing this module; the templates/ and static/ folders live beside it.
WEB_ROOT = Path(__file__).resolve().parent
TEMPLATES_DIR = WEB_ROOT.joinpath("templates")
STATIC_DIR = WEB_ROOT.joinpath("static")

# Status values accepted by the task-list ``status`` filter (anything else means "all").
STATUS_FILTERS = ("active", "completed", "abandoned")
|
|
|
|
|
|
def _norm_path(p: str) -> str:
    """Normalize a stored path for display by turning every backslash into ``/``.

    Rows written on Windows store ``\\`` separators while Linux rows store ``/``;
    rendering both with ``/`` keeps the UI uniform. Falsy input (``None``/``""``)
    yields ``""``.
    """
    text = p if p else ""
    return "/".join(text.split("\\"))
|
|
|
|
|
|
def _task_row_to_dict(r: Any, n_messages: int) -> dict[str, Any]:
    """Convert one selected task row into the template-friendly dict returned by ``list_tasks``."""
    tid = r.task_id
    return {
        "task_id": str(tid),
        "task_id_short": str(tid)[:8],
        "updated_at": r.updated_at,
        "created_at": r.created_at,
        "status": r.status,
        "mode": r.mode or "",
        # Prefer the profile name; fall back to the raw model id.
        "model_label": r.model_profile or r.model or "",
        "tokens": (r.tokens_prompt or 0) + (r.tokens_completion or 0),
        "n_messages": n_messages,
        "description": r.description or "",
        "task_dir": _norm_path(r.task_dir or ""),
    }


def list_tasks(limit: int = 50, status: Optional[str] = None) -> list[dict[str, Any]]:
    """Return recent tasks (``updated_at`` descending) together with their message counts.

    Returns plain dicts so templates can consume them directly; ``cli.py``'s
    ``_list_task_rows`` deliberately keeps its own tuple shape (the CLI and Web row
    shapes need not match; extract a shared helper only once keeping them in sync
    has a real cost).

    Args:
        limit: Maximum number of tasks to return. Negative values are treated as 0
            (empty result) instead of being sent to the database as an invalid LIMIT.
        status: Optional status filter; must be one of ``STATUS_FILTERS``. Any other
            value is silently ignored, i.e. behaves like "all".

    Returns:
        One dict per task with keys ``task_id``, ``task_id_short``, ``updated_at``,
        ``created_at``, ``status``, ``mode``, ``model_label``, ``tokens``,
        ``n_messages``, ``description`` and ``task_dir`` (backslashes shown as ``/``).
    """
    if status and status not in STATUS_FILTERS:
        status = None  # unknown status is ignored silently: equivalent to "all"
    # `limit` may come straight from a query string; PostgreSQL rejects a negative
    # LIMIT, which would surface as a server error, so clamp it instead.
    limit = max(0, limit)

    with session_scope() as s:
        q = (
            select(
                Task.task_id, Task.updated_at, Task.created_at, Task.status,
                Task.mode, Task.model, Task.model_profile,
                Task.tokens_prompt, Task.tokens_completion, Task.description,
                Task.task_dir,
            )
            .order_by(Task.updated_at.desc())
        )
        if status:
            q = q.where(Task.status == status)
        rows_db = s.execute(q.limit(limit)).all()

        # Count messages only for the tasks being shown, so the whole messages
        # table is not aggregated on every page load.
        task_ids = [r.task_id for r in rows_db]
        msg_counts: dict[Any, int] = {}
        if task_ids:
            msg_counts = dict(s.execute(
                select(Message.task_id, func.count())
                .where(Message.task_id.in_(task_ids))
                .group_by(Message.task_id)
            ).all())

    return [_task_row_to_dict(r, msg_counts.get(r.task_id, 0)) for r in rows_db]
|
|
|
|
|
|
def create_app() -> FastAPI:
    """Build and return the FastAPI application.

    uvicorn ``--reload`` needs a factory callable (``factory=True``), so routes are
    registered on a fresh app inside this function.

    Routes:
        ``GET /``                - task list rendered with ``home.html``; optional
                                   ``status`` / ``limit`` query parameters.
        ``GET /tasks/{task_id}`` - G2 placeholder page (``task_placeholder.html``);
                                   404 for a malformed UUID or an unknown task.
        ``GET /healthz``         - health check returning ``ok``.
        ``/static``              - files served from ``STATIC_DIR``.
    """
    import html  # local import: only used to escape the echoed task id below

    app = FastAPI(title="zcbot web", version="0.2")
    templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")

    @app.get("/", response_class=HTMLResponse)
    def home(request: Request, status: Optional[str] = None, limit: int = 50):
        tasks = list_tasks(limit=limit, status=status)
        # list_tasks ignores unknown statuses (treats them as "all"); mirror that
        # here so an invalid ?status= is not presented as the active filter.
        active_status = status if status in STATUS_FILTERS else ""
        return templates.TemplateResponse(
            request, "home.html",
            {
                "version": app.version,
                "tasks": tasks,
                "status": active_status,
                "limit": limit,
                "filters": STATUS_FILTERS,
            },
        )

    @app.get("/tasks/{task_id}", response_class=HTMLResponse)
    def task_detail(request: Request, task_id: str):
        """G2 placeholder: validate the UUID and show a "G3 in progress" page.

        To be replaced by real message rendering once G3 lands.
        """
        try:
            tid = UUID(task_id)
        except ValueError:
            # task_id is taken verbatim from the URL; escape it before echoing it
            # into an HTML body to prevent reflected XSS.
            return HTMLResponse(
                f"invalid task id: {html.escape(repr(task_id))}", status_code=404
            )
        # Also check the DB: a missing task gets a 404 rather than a placeholder
        # page that suggests it exists.
        with session_scope() as s:
            row = s.execute(
                select(Task.task_id, Task.description, Task.task_dir, Task.status)
                .where(Task.task_id == tid)
            ).first()
        if row is None:
            return HTMLResponse(f"task not found: {tid}", status_code=404)
        return templates.TemplateResponse(
            request, "task_placeholder.html",
            {
                "task_id": str(tid),
                "task_id_short": str(tid)[:8],
                "description": row.description or "",
                "task_dir": _norm_path(row.task_dir or ""),
                "status": row.status,
            },
        )

    @app.get("/healthz", response_class=HTMLResponse)
    def healthz():
        return HTMLResponse("ok")

    return app
|