1045 lines
42 KiB
Python
1045 lines
42 KiB
Python
"""FastAPI app: 纯 /v1 JSON API(2026-05-15 切换 — 详见 DESIGN §7.9)。
|
||
|
||
设计要点:
|
||
- 所有路由 `/v1/*` 前缀,响应 JSON;模板 / HTMX / 服务端 markdown 渲染全删
|
||
- SSE 事件 payload 是 JSON dict 而非 HTML 片段(`event: <type>` + `data: <json>`)
|
||
- Auth: PLATFORM_KEY → JWT 兑换(§7 D' 过渡形态,见 web/auth.py);OIDC 替换时只动 /v1/auth/login 内部
|
||
- 所有 /v1/tasks* 路由 Depends(require_user),按 user_id 隔离数据
|
||
- 豁免:/healthz、/docs、/openapi.json、/、/v1/auth/login、/static/*
|
||
- CORS allow_origins=["*"] 本地宽松;真发布按 platform 域名收紧
|
||
- `GET /` 302 → /static/dev.html(本地 dev SPA)
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import os
|
||
import tempfile
|
||
from contextlib import asynccontextmanager
|
||
from datetime import datetime as _dt
|
||
from pathlib import Path
|
||
from typing import Any, Optional
|
||
from uuid import UUID, uuid4
|
||
|
||
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import FileResponse, RedirectResponse, StreamingResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from pydantic import BaseModel
|
||
from sqlalchemy import func, select, update
|
||
from starlette.background import BackgroundTask
|
||
|
||
from core.paths import to_db_path
|
||
from core.storage import (
|
||
NoSubtaskError,
|
||
check_no_subtask,
|
||
session_scope,
|
||
)
|
||
from core.storage.models import Message, Task
|
||
from core.storage.utils import ensure_local_task_row
|
||
|
||
from .auth import AuthConfig, ensure_user_row, make_require_user, mint_token
|
||
from .broker import broker
|
||
from .sinks import WebEventSink
|
||
|
||
|
||
# Task statuses accepted by the `status` query filter.
STATUS_FILTERS = ("active", "completed", "abandoned")
STATUS_WRITABLE = ("completed", "abandoned")  # the web API never switches a task back to active (that goes through the CLI)
# Allowlist of Task columns that the `ordering` parameter may sort by.
ORDER_FIELDS = ("created_at", "updated_at", "name", "status")
# Ordering spec used when the client supplies none.
ORDER_DEFAULT = "-created_at"
|
||
|
||
|
||
# ─────────────────────────── helpers ───────────────────────────
|
||
|
||
def _norm_path(p: str) -> str:
|
||
"""跨 OS 显示归一:backslash → forward slash。"""
|
||
return (p or "").replace("\\", "/")
|
||
|
||
|
||
def _iso(dt: Optional[Any]) -> Optional[str]:
|
||
return dt.isoformat() if dt else None
|
||
|
||
|
||
def _parse_ordering(s: Optional[str]) -> list:
    """Parse a DRF-style ``ordering`` string into SQLAlchemy ``order_by`` clauses.

    The spec is a comma-separated list of field names; a leading ``-`` means
    descending. Only names listed in ``ORDER_FIELDS`` are honoured; anything
    else is silently dropped. A missing/blank spec falls back to
    ``ORDER_DEFAULT``, and a spec made only of invalid fields falls back to
    ``created_at`` descending.

    Returns:
        A non-empty list of column ordering expressions, ready to be expanded
        with ``*`` into ``order_by``.
    """
    raw_spec = (s or "").strip() or ORDER_DEFAULT
    clauses = []
    for token in (piece.strip() for piece in raw_spec.split(",")):
        if not token:
            continue
        descending = token.startswith("-")
        field = token[1:] if descending else token
        if field not in ORDER_FIELDS:
            continue
        column = getattr(Task, field)
        clauses.append(column.desc() if descending else column.asc())
    # Every requested field was invalid -> fall back to the default ordering.
    return clauses or [Task.created_at.desc()]
|
||
|
||
|
||
def _task_dict(row: Any, *, n_messages: Optional[int] = None) -> dict:
    """Convert a Task ORM row into the JSON-ready dict returned by the API.

    Args:
        row: Object exposing the Task attributes read below.
        n_messages: When given, included as ``n_messages``; omitted otherwise.

    Returns:
        A dict where missing text fields become ``""``, missing token counters
        become ``0``, and a missing ``run_status`` becomes ``"idle"``.
    """
    prompt_tokens = row.tokens_prompt or 0
    completion_tokens = row.tokens_completion or 0
    payload = dict(
        task_id=str(row.task_id),
        name=row.name or "",
        description=row.description or "",
        working_dir=_norm_path(row.working_dir or ""),
        status=row.status,
        skill=row.skill or "",
        model=row.model or "",
        model_profile=row.model_profile or "",
        tokens_prompt=prompt_tokens,
        tokens_completion=completion_tokens,
        tokens=prompt_tokens + completion_tokens,
        # Current run state (the original comment notes that schema 0004
        # merged the former runs table into the task row).
        run_status=row.run_status or "idle",
        run_error=row.run_error or None,
        created_at=_iso(getattr(row, "created_at", None)),
        updated_at=_iso(getattr(row, "updated_at", None)),
    )
    if n_messages is not None:
        payload["n_messages"] = n_messages
    return payload
|
||
|
||
|
||
# ─────────────────────── files helpers ───────────────────────
|
||
|
||
def _load_user_root(user_id: UUID) -> Path:
    """Return the per-user root directory that bounds every files API call.

    Delegates to ``core.agent_builder.user_root`` with the default workspace.
    NOTE(review): the original docstring says the directory is
    ``<workspace>/users/<user_id>/`` and is created on first access; both
    happen inside ``user_root`` and are not visible here — confirm.
    """
    from core.agent_builder import resolve_workspace, user_root

    return user_root(resolve_workspace(None), user_id)
|
||
|
||
|
||
def _safe_join(root: Path, rel: str) -> Path:
|
||
"""归一用户路径到 absolute,并校验仍在 root 内。防 `../` / 绝对 path / symlink 越界。"""
|
||
rel = (rel or "").strip()
|
||
if not rel:
|
||
return root.resolve()
|
||
if rel[0] in ("/", "\\"):
|
||
raise HTTPException(400, f"absolute-style path not allowed: {rel!r}")
|
||
if Path(rel).is_absolute():
|
||
raise HTTPException(400, f"absolute path not allowed: {rel!r}")
|
||
target = (root / rel).resolve()
|
||
try:
|
||
target.relative_to(root.resolve())
|
||
except ValueError:
|
||
raise HTTPException(400, f"path escapes user_root: {rel!r}")
|
||
return target
|
||
|
||
|
||
def _rel_to(root: Path, target: Path) -> str:
|
||
try:
|
||
rel = target.resolve().relative_to(root.resolve()).as_posix()
|
||
except ValueError:
|
||
return ""
|
||
return "" if rel == "." else rel
|
||
|
||
|
||
def _enumerate_files(root: Path, current: Path) -> tuple[list[dict], list[dict], bool]:
    """List the entries of ``current`` and build breadcrumbs back to ``root``.

    Entries are sorted directories-first, then by case-insensitive name.
    ``size`` is raw bytes (``None`` for non-files) and ``mtime`` is an ISO
    string with second precision, derived from ``datetime.fromtimestamp``.
    Names starting with ``.`` are always hidden. Entries whose ``stat`` fails
    are skipped, and a directory that cannot be listed yields no entries.

    Returns:
        ``(entries, crumbs, exists)`` where ``exists`` tells whether
        ``current`` exists on disk.
    """
    listing: list[dict] = []
    exists = current.exists()
    if exists and current.is_dir():
        try:
            children = list(current.iterdir())
            children.sort(key=lambda child: (child.is_file(), child.name.lower()))
        except OSError:
            children = []
        for child in children:
            if child.name.startswith("."):
                continue
            try:
                info = child.stat()
            except OSError:
                continue
            modified = _dt.fromtimestamp(info.st_mtime).isoformat(timespec="seconds")
            listing.append({
                "name": child.name,
                "is_dir": child.is_dir(),
                "size": info.st_size if child.is_file() else None,
                "mtime": modified,
                "rel": _rel_to(root, child),
            })

    crumbs = [{"label": "/", "rel": ""}]
    here = _rel_to(root, current)
    # "" (or ".") means we are at root: no extra crumb beyond the "/" one.
    if here not in ("", "."):
        trail: list[str] = []
        for segment in here.split("/"):
            trail.append(segment)
            crumbs.append({"label": segment, "rel": "/".join(trail)})
    return listing, crumbs, exists
|
||
|
||
|
||
# ─────────────────── BG run + SSE 帧格式 ───────────────────
|
||
|
||
def _run_agent_bg(task_id: UUID, user_id: UUID, user_message: str) -> None:
    """Worker-thread entry point: build the agent, run it, persist the run state.

    Steps: emit ``run_start`` -> ``build_agent(resume=True)`` -> attach a
    ``WebEventSink`` and a cancel check backed by ``broker.is_cancelled`` ->
    ``agent.run`` -> ``sync_task_tokens`` -> write ``tasks.run_status``.

    A run that returns normally (per the original notes this includes a
    cooperative cancel) resets ``run_status`` to ``idle``; only an exception
    is persisted as ``error``. The broker's cancel flag is always cleared and
    the task's channel closed at the end.

    Args:
        task_id: Task to run; also passed to ``build_agent`` as the session id.
        user_id: Forwarded to ``build_agent`` (the original notes say it
            selects the per-user memory subtree).
        user_message: Text handed to ``agent.run``.
    """
    from core.agent_builder import build_agent, sync_task_tokens

    def _store_run_state(status: str, error: Optional[str]) -> None:
        # Persist the run outcome on the task row.
        with session_scope() as s:
            s.execute(
                update(Task)
                .where(Task.task_id == task_id)
                .values(run_status=status, run_error=error)
            )

    def _cancel_requested(tid=task_id):
        return broker.is_cancelled(tid)

    try:
        broker.emit(task_id, {"type": "run_start"})
        agent, _session, _sid, task_state, _task_dir = build_agent(
            session_id=str(task_id), resume=True, user_id=user_id,
        )
        agent.sink = WebEventSink(broker, task_id)
        agent.cancel_check = _cancel_requested
        agent.run(user_message)
        sync_task_tokens(task_state, agent.llm)
        # Normal return -> back to idle; only errors are persisted.
        _store_run_state("idle", None)
    except Exception as e:
        err = f"{type(e).__name__}: {e}"
        broker.emit(task_id, {"type": "error", "msg": err})
        try:
            _store_run_state("error", err)
        except Exception:
            # The error was already emitted to the client; a failed DB write
            # is deliberately not escalated.
            pass
    finally:
        broker.clear_cancel(task_id)
        broker.close(task_id)
|
||
|
||
|
||
def _sse_event(event_type: str, payload: dict) -> bytes:
|
||
"""格式化 SSE 一帧:`event: <type>` + `data: <json single-line>`。"""
|
||
body = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
|
||
return f"event: {event_type}\ndata: {body}\n\n".encode("utf-8")
|
||
|
||
|
||
# ────────────────────── Pydantic 请求体 ──────────────────────
|
||
|
||
class TaskCreateRequest(BaseModel):
    # Request body of POST /v1/tasks.
    name: str  # task display name (required; the DB column is NOT NULL)
    working_dir: str = ""  # working directory name (optional; blank -> the name is used as directory name)
    description: str = ""
    skill: str = ""
|
||
|
||
|
||
class TaskPatchRequest(BaseModel):
    # Request body of PATCH /v1/tasks/{task_id}; fields left as None are not updated.
    status: Optional[str] = None
    description: Optional[str] = None
    name: Optional[str] = None
    skill: Optional[str] = None
|
||
|
||
|
||
class MessageRequest(BaseModel):
    # Request body of POST /v1/tasks/{task_id}/messages.
    content: str
|
||
|
||
|
||
class FileDeleteRequest(BaseModel):
    # Request body of POST /v1/files/delete.
    path: str
|
||
|
||
|
||
class FileRenameRequest(BaseModel):
    # Request body of POST /v1/files/rename.
    path: str  # file or directory to rename, relative to user_root
    new_name: str  # new leaf name (not a path); no slash, backslash or ".."
|
||
|
||
|
||
class LoginRequest(BaseModel):
    # Request body of POST /v1/auth/login.
    user_id: str
    platform_key: str
|
||
|
||
|
||
# ────────────────────── App 工厂 ──────────────────────
|
||
|
||
# Path of the web/static directory — mounted at /static; dev.html lives here too.
_STATIC_DIR = Path(__file__).parent / "static"
|
||
|
||
|
||
def create_app() -> FastAPI:
    """Build the FastAPI application and register every route on it."""
    # Fail fast: a missing env setting raises here instead of running without a secret.
    auth_cfg = AuthConfig.from_env()
    require_user = make_require_user(auth_cfg)
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup hook: bind the broker loop, load skills, reap stale runs."""
    broker.bind_loop(asyncio.get_running_loop())

    # The skill registry is scanned once at startup; /v1/skills reads it directly.
    from core.agent_builder import load_config
    from core.paths import ROOT
    from core.skills import SkillRegistry

    config = load_config()
    skills_dir = ROOT / config.get("skills_dir", "skills")
    app.state.skill_registry = SkillRegistry(skills_dir)

    # Stale-run reaper: a previous process crash may leave tasks in
    # "running" / "cancelling" with no worker thread behind them. Mark them
    # as error so the task accepts new messages again (otherwise the
    # single-active-run gate would block forever).
    # TODO: for real multi-worker production use a heartbeat / lease and only
    # reap this worker's own orphans.
    stale_states = ("running", "cancelling")
    with session_scope() as s:
        reaped = s.execute(
            update(Task)
            .where(Task.run_status.in_(stale_states))
            .values(
                run_status="error",
                run_error="server restarted before run finished",
            )
        ).rowcount
    if reaped:
        print(f"[startup] reaped {reaped} stale active run(s)")
    yield
|
||
|
||
# Application object; `lifespan` runs the startup work defined above.
app = FastAPI(
    title="zcbot api",
    version="0.8",
    description=(
        "zcbot 后端 — /v1 JSON API + SSE。Auth: PLATFORM_KEY → JWT(§7 D' 过渡)。"
        "本地 dev SPA: /static/dev.html。"
    ),
    lifespan=lifespan,
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # permissive for local use; tighten to the platform domain on deployment
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve static files only when the directory actually exists.
if _STATIC_DIR.is_dir():
    app.mount("/static", StaticFiles(directory=str(_STATIC_DIR)), name="static")
|
||
|
||
# ───────────── Misc ─────────────
|
||
|
||
@app.get("/", include_in_schema=False)
def root():
    # Send visitors to the local dev SPA; Swagger UI stays at /docs.
    dev_spa = "/static/dev.html"
    return RedirectResponse(url=dev_spa, status_code=302)
|
||
|
||
@app.get("/healthz", tags=["misc"])
def healthz():
    # Liveness probe: always answers with a fixed OK payload.
    return dict(status="ok")
|
||
|
||
# ───────────── Auth ─────────────
|
||
|
||
@app.post("/v1/auth/login", tags=["auth"])
def login(body: LoginRequest):
    """Exchange a valid platform_key for a JWT whose subject is ``user_id``.

    - Wrong ``platform_key`` -> 403.
    - ``user_id`` that is not a UUID -> 400.
    - ``ensure_user_row`` is called before minting (the original notes say it
      idempotently creates the users row so later FK inserts do not fail).

    Returns the token, its expiry as an ISO string, the user id and the TTL.
    """
    import hmac

    # Constant-time comparison: a plain != leaks, through timing, how many
    # leading characters of the secret matched.
    # Assumes auth_cfg.platform_key is a str (it was compared to one before).
    supplied = body.platform_key.encode("utf-8", "surrogatepass")
    expected = auth_cfg.platform_key.encode("utf-8", "surrogatepass")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(403, "invalid platform_key")
    try:
        uid = UUID(body.user_id)
    except (ValueError, TypeError):
        raise HTTPException(
            400, f"invalid user_id (must be UUID): {body.user_id!r}"
        ) from None
    ensure_user_row(uid)
    token, exp = mint_token(auth_cfg, uid)
    return {
        "token": token,
        # NOTE(review): fromtimestamp() without tz yields naive local time.
        "expires_at": _dt.fromtimestamp(exp).isoformat(),
        "user_id": str(uid),
        "ttl_seconds": auth_cfg.ttl_seconds,
    }
|
||
|
||
# ───────────── Tasks CRUD ─────────────
|
||
|
||
@app.post("/v1/tasks", status_code=201, tags=["tasks"])
def create_task(body: TaskCreateRequest, user_id: UUID = Depends(require_user)):
    """Create a task.

    - `name` is required (display name).
    - `working_dir` is optional; blank means the name doubles as the
      directory name. Several tasks may share one working_dir.
    - Both `name` and `working_dir` go through `validate_task_name`;
      a rejected value -> 400.
    - A working_dir rejected by `check_no_subtask` (prefix nesting within
      the same user) -> 409.
    """
    from core.agent_builder import InvalidTaskName, resolve_workspace, validate_task_name, working_dir_from_name

    try:
        name = validate_task_name(body.name)
    except InvalidTaskName as e:
        raise HTTPException(400, f"name 不合法: {e}")

    # Blank working_dir -> fall back to the task name.
    requested_dir = (body.working_dir or "").strip() or name
    try:
        wd_name = validate_task_name(requested_dir)
    except InvalidTaskName as e:
        raise HTTPException(400, f"working_dir 不合法: {e}")

    description = body.description.strip()
    skill = body.skill.strip()

    tid = uuid4()
    workspace = resolve_workspace(None)
    fs_dir = working_dir_from_name(workspace, user_id, wd_name)
    fs_dir_db = to_db_path(fs_dir)

    try:
        check_no_subtask(fs_dir_db, user_id=user_id)
    except NoSubtaskError as e:
        raise HTTPException(409, str(e))

    # Create the directory right away; it may already exist because several
    # tasks can share one working_dir.
    fs_dir.mkdir(parents=True, exist_ok=True)

    ensure_local_task_row(
        task_id=tid,
        name=name,
        working_dir=fs_dir_db,
        skill=skill,
        description=description,
        user_id=user_id,
    )
    with session_scope() as s:
        created = s.execute(select(Task).where(Task.task_id == tid)).scalar_one()
        return _task_dict(created, n_messages=0)
|
||
|
||
@app.get("/v1/tasks", tags=["tasks"])
def list_tasks_route(
    page: int = 1,
    page_size: int = 20,
    status: Optional[str] = None,
    skill: Optional[str] = None,
    working_dir: Optional[str] = None,
    q: Optional[str] = None,
    ordering: Optional[str] = None,
    user_id: UUID = Depends(require_user),
):
    """List the current user's tasks with paging, filters and ordering.

    - `page` >= 1 (1-based); `page_size` 1-100 (out-of-range values are clamped)
    - `status` one of active/completed/abandoned; other values are ignored
    - `skill` exact match (blank is ignored)
    - `working_dir` is the last path segment; it is expanded to
      `workspace/users/<uid>/<name>` before comparison
    - `q` case-insensitive substring search over name and description;
      `%` and `_` in the query are matched literally
    - `ordering` DRF style, comma separated, `-field` for descending;
      allowlist `created_at/updated_at/name/status`; invalid fields are
      ignored; default `-created_at`

    Returns the paging envelope `{page, page_size, count, results}`.
    """
    # Clamp and sanitize the inputs.
    page = max(1, page)
    page_size = max(1, min(page_size, 100))
    status = status if status in STATUS_FILTERS else None
    skill = (skill or "").strip() or None
    wd_name = (working_dir or "").strip() or None
    q_text = (q or "").strip() or None

    # Build the WHERE clause; everything is scoped to the calling user.
    conditions = [Task.user_id == user_id]
    if status:
        conditions.append(Task.status == status)
    if skill:
        conditions.append(Task.skill == skill)
    if wd_name:
        # Last segment -> full DB form of the working directory.
        wd_db = f"workspace/users/{user_id}/{wd_name}"
        conditions.append(Task.working_dir == wd_db)
    if q_text:
        # Escape LIKE metacharacters so user input cannot act as a wildcard
        # (previously searching for "%" or "_" matched every row).
        escaped = (
            q_text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )
        pat = f"%{escaped}%"
        conditions.append(
            Task.name.ilike(pat, escape="\\")
            | Task.description.ilike(pat, escape="\\")
        )

    offset = (page - 1) * page_size

    with session_scope() as s:
        cnt = s.execute(
            select(func.count()).select_from(Task).where(*conditions)
        ).scalar_one() or 0

        rows = s.execute(
            select(Task).where(*conditions)
            .order_by(*_parse_ordering(ordering))
            .limit(page_size).offset(offset)
        ).scalars().all()

        # One grouped query for the message counts of the tasks on this page.
        tids = [r.task_id for r in rows]
        msg_counts = (
            dict(s.execute(
                select(Message.task_id, func.count())
                .where(Message.task_id.in_(tids))
                .group_by(Message.task_id)
            ).all())
            if tids else {}
        )

        return {
            "page": page,
            "page_size": page_size,
            "count": int(cnt),
            "results": [
                _task_dict(r, n_messages=msg_counts.get(r.task_id, 0))
                for r in rows
            ],
        }
|
||
|
||
@app.get("/v1/tasks/{task_id}", tags=["tasks"])
def get_task(task_id: str, user_id: UUID = Depends(require_user)):
    """Return one task's metadata plus its message count (messages themselves
    come from /messages). Unknown id, malformed id or another user's task -> 404.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    owned_task = select(Task).where(Task.task_id == tid, Task.user_id == user_id)
    message_count = (
        select(func.count()).select_from(Message).where(Message.task_id == tid)
    )
    with session_scope() as s:
        task_row = s.execute(owned_task).scalar_one_or_none()
        if task_row is None:
            raise HTTPException(404, f"task not found: {tid}")
        total = s.execute(message_count).scalar_one()
        return _task_dict(task_row, n_messages=total)
|
||
|
||
@app.get("/v1/folders", tags=["folders"])
def list_folders(user_id: UUID = Depends(require_user)):
    """List the current user's working directories (non-dot sub-directories
    of the user root).

    The file system is the source of truth, so folders without any task are
    listed too. Each item carries `n_tasks` (tasks whose working_dir points
    at it) and `last_used` (latest `updated_at` among them, ISO string or
    null). Order: most recently used first, folders never used last, ties
    by name ascending.
    """
    from core.agent_builder import resolve_workspace, user_root

    ws = resolve_workspace(None)
    root = user_root(ws, user_id)

    folder_names: list[str] = []
    if root.is_dir():
        folder_names = [
            p.name
            for p in sorted(root.iterdir(), key=lambda x: x.name.lower())
            if p.is_dir() and not p.name.startswith(".")
        ]

    # One grouped query for all folders instead of one query per folder.
    stats: dict[str, tuple] = {}
    if folder_names:
        name_by_db = {
            f"workspace/users/{user_id}/{name}": name for name in folder_names
        }
        with session_scope() as s:
            rows = s.execute(
                select(Task.working_dir, func.count(), func.max(Task.updated_at))
                .where(
                    Task.user_id == user_id,
                    Task.working_dir.in_(list(name_by_db)),
                )
                .group_by(Task.working_dir)
            ).all()
        stats = {name_by_db[wd]: (int(n or 0), lu) for wd, n, lu in rows}

    folders: list[dict] = []
    for name in folder_names:
        n_tasks, last_used = stats.get(name, (0, None))
        folders.append({
            "name": name,
            "n_tasks": n_tasks,
            "last_used": _iso(last_used),
        })

    # Two stable sorts: name ascending first, then last_used descending, so
    # equal last_used values (including "never") stay ordered by name.
    folders.sort(key=lambda f: f["name"])
    folders.sort(key=lambda f: f["last_used"] or "", reverse=True)
    return {"folders": folders}
|
||
|
||
# ───────────── Skills ─────────────
|
||
|
||
@app.get("/v1/skills", tags=["skills"])
def list_skills(user_id: UUID = Depends(require_user)):
    """List the available skills (agent types) for the new-task dropdown.

    Reads the registry stored on `app.state` at startup and returns each
    skill's name and description in registry order; an absent registry
    yields an empty list.
    """
    registry = getattr(app.state, "skill_registry", None)
    if not registry:
        return {"skills": []}
    listing = []
    for skill in list(registry.skills.values()):
        listing.append({"name": skill.name, "description": skill.description})
    return {"skills": listing}
|
||
|
||
@app.delete("/v1/tasks/{task_id}", status_code=204, tags=["tasks"])
def delete_task(task_id: str, user_id: UUID = Depends(require_user)):
    """Hard-delete the task row (the original notes say messages cascade).

    The task's directory on disk is **not** touched: it may be shared by
    other tasks, and files are removed separately through /files/delete.
    Unknown id, malformed id or another user's task -> 404.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    from sqlalchemy import delete as _delete

    removal = _delete(Task).where(Task.task_id == tid, Task.user_id == user_id)
    with session_scope() as s:
        deleted = s.execute(removal).rowcount
        if deleted == 0:
            raise HTTPException(404, f"task not found: {tid}")
    return None  # 204
|
||
|
||
@app.patch("/v1/tasks/{task_id}", tags=["tasks"])
def patch_task(
    task_id: str,
    body: TaskPatchRequest,
    user_id: UUID = Depends(require_user),
):
    """Update task fields. `status` only accepts completed/abandoned
    (switching back to active goes through the CLI).

    Invalid status or name, or an empty patch -> 400; unknown id, malformed
    id or another user's task -> 404.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    changes: dict[str, Any] = {}
    new_status = body.status
    if new_status is not None:
        if new_status not in STATUS_WRITABLE:
            raise HTTPException(
                400, f"invalid status {body.status!r}; allowed: {STATUS_WRITABLE}"
            )
        changes["status"] = new_status
    # Free-text fields are stored exactly as sent.
    for field in ("description", "skill"):
        value = getattr(body, field)
        if value is not None:
            changes[field] = value
    if body.name is not None:
        from core.agent_builder import InvalidTaskName, validate_task_name

        try:
            changes["name"] = validate_task_name(body.name)
        except InvalidTaskName as e:
            raise HTTPException(400, f"name 不合法: {e}")
    if not changes:
        raise HTTPException(400, "no fields to update")

    with session_scope() as s:
        outcome = s.execute(
            update(Task)
            .where(Task.task_id == tid, Task.user_id == user_id)
            .values(**changes)
        )
        if outcome.rowcount == 0:
            raise HTTPException(404, f"task not found: {tid}")
        fresh = s.execute(select(Task).where(Task.task_id == tid)).scalar_one()
        total = s.execute(
            select(func.count()).select_from(Message).where(Message.task_id == tid)
        ).scalar_one()
        return _task_dict(fresh, n_messages=total)
|
||
|
||
# ───────────── Messages ─────────────
|
||
|
||
def _assert_owns_task(s, tid: UUID, user_id: UUID) -> None:
    """Raise 404 unless a task with id ``tid`` belonging to ``user_id`` exists."""
    ownership = select(Task.task_id).where(
        Task.task_id == tid, Task.user_id == user_id
    )
    if s.execute(ownership).first() is None:
        raise HTTPException(404, f"task not found: {tid}")
|
||
|
||
@app.get("/v1/tasks/{task_id}/messages", tags=["messages"])
def list_messages(task_id: str, user_id: UUID = Depends(require_user)):
    """Return the task's message history ordered by `idx` ascending.

    Each stored payload is passed through as a dict for the client to render.
    Unknown id, malformed id or another user's task -> 404.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    history = (
        select(
            Message.idx,
            Message.payload,
            Message.tokens_in,
            Message.tokens_out,
            Message.created_at,
        )
        .where(Message.task_id == tid)
        .order_by(Message.idx)
    )
    with session_scope() as s:
        _assert_owns_task(s, tid, user_id)
        records = s.execute(history).all()

    messages = []
    for rec in records:
        messages.append({
            "idx": rec.idx,
            "payload": dict(rec.payload),
            "tokens_in": rec.tokens_in,
            "tokens_out": rec.tokens_out,
            "created_at": _iso(rec.created_at),
        })
    return {"messages": messages}
|
||
|
||
@app.post("/v1/tasks/{task_id}/messages", status_code=202, tags=["messages"])
async def post_message(
    task_id: str,
    body: MessageRequest,
    user_id: UUID = Depends(require_user),
):
    """Post a message and start a background run. Returns `{events_url}` so
    the client can subscribe to the SSE stream right away.

    Single active run: the task row is locked with `SELECT ... FOR UPDATE`,
    the run state is checked and set to `running` in the same transaction,
    so two quick sends cannot start two worker threads for one task.
    `run_status` in (running, cancelling) -> 409; `error` is cleared when a
    new run starts. Blank content -> 400; unknown id, malformed id or
    another user's task -> 404.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    content = (body.content or "").strip()
    if not content:
        raise HTTPException(400, "empty content")

    with session_scope() as s:
        row = s.execute(
            select(Task.run_status)
            .where(Task.task_id == tid, Task.user_id == user_id)
            .with_for_update()
        ).first()
        if row is None:
            raise HTTPException(404, f"task not found: {tid}")
        if row.run_status in ("running", "cancelling"):
            raise HTTPException(
                409,
                f"task already has an active run (status={row.run_status}); "
                f"wait for it to finish or cancel",
            )
        s.execute(
            update(Task).where(Task.task_id == tid).values(
                run_status="running", run_error=None,
            )
        )
        # Reset the previous run's "done" marker so new subscribers see the
        # stream. NOTE(review): kept inside the transaction so a failure here
        # rolls back the "running" mark — confirm against the original layout.
        broker.start(tid)

    # The commit above released the row lock; the worker thread takes over.
    # The event loop only keeps weak references to tasks, so hold a strong
    # one until the run finishes to prevent it being garbage-collected.
    bg_runs = getattr(app.state, "bg_runs", None)
    if bg_runs is None:
        bg_runs = app.state.bg_runs = set()
    run = asyncio.create_task(asyncio.to_thread(_run_agent_bg, tid, user_id, content))
    bg_runs.add(run)
    run.add_done_callback(bg_runs.discard)
    return {"events_url": f"/v1/tasks/{tid}/events"}
|
||
|
||
# ───────────── Cancel current run ─────────────
|
||
|
||
@app.post("/v1/tasks/{task_id}/cancel", status_code=202, tags=["tasks"])
def cancel_task(
    task_id: str,
    user_id: UUID = Depends(require_user),
):
    """Send a cooperative cancel signal to the task's active run.

    - With a single active run per task, the task id alone identifies it.
    - The task must belong to the caller, otherwise 404.
    - `run_status` other than `running` -> 409 (idle / cancelling / error
      cannot be cancelled).
    - The row is marked `cancelling` and the broker's cancel flag is set;
      the worker's cancel check reads that flag and the worker writes the
      final state when it exits.
    - Per the original notes, a synchronous LLM call cannot be interrupted,
      so the cancel takes effect only after the current call returns.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    locked_state = (
        select(Task.run_status)
        .where(Task.task_id == tid, Task.user_id == user_id)
        .with_for_update()
    )
    with session_scope() as s:
        current = s.execute(locked_state).first()
        if current is None:
            raise HTTPException(404, f"task not found: {tid}")
        if current.run_status != "running":
            raise HTTPException(
                409,
                f"task not running (run_status={current.run_status}); cannot cancel",
            )
        s.execute(
            update(Task).where(Task.task_id == tid).values(run_status="cancelling")
        )
        broker.request_cancel(tid)
    return {"ok": True, "task_id": str(tid), "run_status": "cancelling"}
|
||
|
||
# ───────────── SSE events ─────────────
|
||
|
||
@app.get("/v1/tasks/{task_id}/events", tags=["tasks"])
async def stream_events(
    task_id: str,
    user_id: UUID = Depends(require_user),
):
    """SSE stream of the task's current run.

    Event types listed by the original docs: run_start / llm_start / text /
    tool_call / tool_result / llm_end / cancelled / error / done. Each
    frame's data is the event dict as JSON with the `type` key removed
    (it becomes the SSE event name). The stream ends after a `done` or
    `error` event; a comment ping is sent after 30 s of silence.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    with session_scope() as s:
        _assert_owns_task(s, tid, user_id)

    terminal_events = ("done", "error")

    async def gen():
        queue = broker.subscribe(tid)
        try:
            # First frame: a comment plus the retry hint, so EventSource is
            # established immediately instead of waiting behind a buffer.
            yield b": connected\nretry: 3000\n\n"
            while True:
                try:
                    event = await asyncio.wait_for(queue.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    yield b": ping\n\n"
                    continue
                kind = event.get("type", "msg")
                data = dict(event)
                data.pop("type", None)
                yield _sse_event(kind, data)
                if kind in terminal_events:
                    break
        except asyncio.CancelledError:
            pass  # client disconnected; exit quietly
        finally:
            broker.unsubscribe(tid, queue)

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        gen(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
||
|
||
# ───────────── Files(user-rooted,不绑 task) ─────────────
|
||
|
||
@app.get("/v1/files", tags=["files"])
def list_files(
    path: str = "",
    user_id: UUID = Depends(require_user),
):
    """List the entries of a directory under the user root, with breadcrumbs.

    Blank `path` means the user root. Traversal outside the root or an
    absolute path -> 400. Dotfiles are always hidden.
    """
    user_dir = _load_user_root(user_id)
    here = _safe_join(user_dir, path)
    entries, crumbs, exists = _enumerate_files(user_dir, here)
    return dict(
        root=_norm_path(str(user_dir)),
        current=_rel_to(user_dir, here),
        exists=exists,
        crumbs=crumbs,
        entries=entries,
    )
|
||
|
||
@app.get("/v1/files/download", tags=["files"])
def download_file(
    path: str,
    user_id: UUID = Depends(require_user),
):
    """Download one regular file under the user root
    (missing -> 404, not a regular file -> 400).
    """
    user_dir = _load_user_root(user_id)
    wanted = _safe_join(user_dir, path)
    if not wanted.exists():
        raise HTTPException(404, f"file not found: {path}")
    if wanted.is_file():
        return FileResponse(path=str(wanted), filename=wanted.name)
    raise HTTPException(400, f"not a file: {path}")
|
||
|
||
@app.post("/v1/files/upload", tags=["files"])
async def upload_files(
    path: str = Form(""),
    files: list[UploadFile] = File(...),
    user_id: UUID = Depends(require_user),
):
    """Multipart upload of one or more files into `<user_root>/<path>/`.

    - The target directory is created when missing (with parents).
    - An existing file of the same name is overwritten.
    - Every filename is validated **before** anything is written, so a
      rejected name never leaves a partial upload behind. Empty names,
      `.` / `..`, and names containing a slash or backslash -> 400.
    - A name that collides with an existing directory -> 400.
    """
    root = _load_user_root(user_id)
    dest_dir = _safe_join(root, path)
    if dest_dir.exists() and not dest_dir.is_dir():
        raise HTTPException(400, f"upload target is a file, not a directory: {path}")

    # Pass 1: validate all names up front (no disk writes yet).
    root_resolved = root.resolve()
    plan: list[tuple] = []
    for up in files or []:
        raw_name = up.filename or ""
        if (
            not raw_name
            or raw_name in (".", "..")
            or "/" in raw_name
            or "\\" in raw_name
        ):
            raise HTTPException(400, f"invalid filename: {raw_name!r}")
        dest = dest_dir / raw_name
        # Final containment check; also covers symlinks and drive-style names.
        try:
            dest.resolve().relative_to(root_resolved)
        except ValueError:
            raise HTTPException(400, f"path escapes user_root: {raw_name!r}") from None
        if dest.is_dir():
            raise HTTPException(400, f"upload target is a directory: {raw_name!r}")
        plan.append((up, raw_name, dest))
    if not plan:
        raise HTTPException(400, "no files uploaded")

    # Pass 2: everything is valid, now touch the disk.
    dest_dir.mkdir(parents=True, exist_ok=True)
    saved: list[dict] = []
    for up, raw_name, dest in plan:
        data = await up.read()
        dest.write_bytes(data)
        saved.append({"name": raw_name, "size": len(data), "rel": _rel_to(root, dest)})
    return {"count": len(saved), "saved": saved}
|
||
|
||
@app.post("/v1/files/delete", tags=["files"])
def delete_file(
    body: FileDeleteRequest,
    user_id: UUID = Depends(require_user),
):
    """Delete a file or an **empty** directory under the user root.

    The root itself -> 400; a missing path -> 404; a non-empty directory
    (or any other OS-level failure) -> 400.

    When the path is a **top-level directory** (a directory directly under
    the user root), the tasks table is checked first: if any task's
    working_dir points at it -> 409, and those tasks must be deleted first.
    This avoids dangling references, because working_dir is only a string
    in the DB and is not updated when the directory disappears.
    """
    user_dir = _load_user_root(user_id)
    victim = _safe_join(user_dir, body.path)
    root_resolved = user_dir.resolve()
    if victim.resolve() == root_resolved:
        raise HTTPException(400, "cannot delete user_root")
    if not victim.exists():
        raise HTTPException(404, f"path not found: {body.path}")

    if victim.is_dir() and victim.parent.resolve() == user_dir.resolve():
        db_form = to_db_path(victim)
        referencing = select(func.count()).select_from(Task).where(
            Task.user_id == user_id, Task.working_dir == db_form,
        )
        with session_scope() as s:
            cnt = s.execute(referencing).scalar_one() or 0
            if cnt:
                raise HTTPException(
                    409,
                    f"folder {body.path!r} 仍被 {cnt} 个 task 引用(working_dir);"
                    f"请先 DELETE 关联 task 再删目录",
                )

    # rmdir refuses non-empty directories by raising OSError.
    remove = victim.rmdir if victim.is_dir() else victim.unlink
    try:
        remove()
    except OSError as e:
        raise HTTPException(400, f"delete failed: {e}")
    return {"ok": True, "path": body.path}
|
||
|
||
@app.post("/v1/files/rename", tags=["files"])
def rename_path(
    body: FileRenameRequest,
    user_id: UUID = Depends(require_user),
):
    """Rename a file or directory (at any depth) under the user root.

    - `path` is required and names the object to rename; it cannot be the
      user root itself.
    - `new_name` is the new leaf name, validated by `validate_task_name`;
      it is not a path — the parent is taken from the original `path`.
    - The sibling `<parent>/<new_name>` must not exist yet (no overwrite).
    - When **path is a top-level directory** (a directory directly under the
      user root) the rename is DB-aware:
        * the tasks pointing at the directory are locked with
          `SELECT ... FOR UPDATE` in one transaction; if any of them has
          run_status running/cancelling -> 409 (a worker thread would keep
          the old path while the DB already points at the new one)
        * `check_no_subtask(new_db, exclude=renamed task ids)` prevents the
          new name from nesting with other tasks
        * the tasks' working_dir is updated first
        * then the directory is renamed on disk; if that fails, the raised
          error makes the session scope roll the update back
        * the only inconsistency window is "renamed on disk, commit fails"
    - Anything else (sub-directory or file) is a plain FS rename; the DB is
      not touched.
    """
    from core.agent_builder import InvalidTaskName, validate_task_name

    root = _load_user_root(user_id)
    target = _safe_join(root, body.path)
    if target.resolve() == root.resolve():
        raise HTTPException(400, "cannot rename user_root")
    if not target.exists():
        raise HTTPException(404, f"path not found: {body.path}")

    try:
        new_name = validate_task_name(body.new_name)
    except InvalidTaskName as e:
        raise HTTPException(400, f"new_name 不合法: {e}")

    # Renaming to the same name is rejected rather than treated as a no-op.
    if new_name == target.name:
        raise HTTPException(400, f"new_name 与原名相同: {new_name!r}")

    new_target = target.parent / new_name
    if new_target.exists():
        raise HTTPException(
            409, f"target already exists: {_rel_to(root, new_target)!r}"
        )

    is_top_level_dir = (
        target.is_dir() and target.parent.resolve() == root.resolve()
    )

    # Files and nested directories: plain FS rename, no DB involvement.
    if not is_top_level_dir:
        try:
            target.rename(new_target)
        except OSError as e:
            raise HTTPException(400, f"rename failed: {e}")
        return {
            "ok": True,
            "old": body.path,
            "new": _rel_to(root, new_target),
            "tasks_updated": 0,
        }

    # Top-level directory: DB-aware rename.
    old_db = to_db_path(target)
    new_db = to_db_path(new_target)
    with session_scope() as s:
        # Lock every task of this user whose working_dir is the old path.
        rows = s.execute(
            select(Task.task_id, Task.run_status)
            .where(Task.user_id == user_id, Task.working_dir == old_db)
            .with_for_update()
        ).all()
        tids = [r.task_id for r in rows]
        # Short (8-char) ids of tasks that still have a run in flight.
        active = [
            str(r.task_id)[:8] for r in rows
            if r.run_status in ("running", "cancelling")
        ]
        if active:
            raise HTTPException(
                409,
                f"folder has active run(s) on task(s) {active}; "
                f"cancel before renaming",
            )
        try:
            check_no_subtask(new_db, user_id=user_id, exclude_task_ids=tids)
        except NoSubtaskError as e:
            raise HTTPException(409, str(e))
        # DB first, file system second: an FS failure below undoes this update.
        if tids:
            s.execute(
                update(Task)
                .where(Task.task_id.in_(tids))
                .values(working_dir=new_db)
            )
        try:
            target.rename(new_target)
        except OSError as e:
            # Raising HTTPException also sends session_scope down its
            # rollback path, undoing the UPDATE.
            raise HTTPException(400, f"FS rename failed: {e}")

    return {
        "ok": True,
        "old": body.path,
        "new": _rel_to(root, new_target),
        "tasks_updated": len(tids),
    }
|
||
|
||
# ───────────── Export ─────────────
|
||
|
||
@app.get("/v1/tasks/{task_id}/export", tags=["export"])
def export_task(task_id: str, user_id: UUID = Depends(require_user)):
    """Export the conversation as a .docx download.

    The document is written to a temporary file which a background task
    removes once the response has been sent. A task without messages -> 400;
    an export failure -> 500 (the temporary file is removed first).
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    with session_scope() as s:
        _assert_owns_task(s, tid, user_id)
        first_message = s.execute(
            select(Message.message_id).where(Message.task_id == tid).limit(1)
        ).first()
        if not first_message:
            raise HTTPException(400, "no messages to export")

    # mkstemp returns an open descriptor; close it, only the path is needed.
    handle, tmp_name = tempfile.mkstemp(suffix=".docx", prefix="zcbot-export-")
    os.close(handle)
    docx_path = Path(tmp_name)
    try:
        from core.export_docx import export_chat_to_docx

        export_chat_to_docx(tid, out_path=docx_path)
    except Exception as e:
        docx_path.unlink(missing_ok=True)
        raise HTTPException(500, f"export failed: {type(e).__name__}: {e}")

    cleanup = BackgroundTask(docx_path.unlink, missing_ok=True)
    return FileResponse(
        path=str(docx_path),
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        filename=f"chat_{str(tid)[:8]}.docx",
        background=cleanup,
    )
|
||
|
||
return app
|