816 lines
32 KiB
Python
816 lines
32 KiB
Python
"""FastAPI app: 纯 /v1 JSON API(2026-05-15 切换 — 详见 DESIGN §7.9)。
|
||
|
||
设计要点:
|
||
- 所有路由 `/v1/*` 前缀,响应 JSON;模板 / HTMX / 服务端 markdown 渲染全删
|
||
- SSE 事件 payload 是 JSON dict 而非 HTML 片段(`event: <type>` + `data: <json>`)
|
||
- Auth: PLATFORM_KEY → JWT 兑换(§7 D' 过渡形态,见 web/auth.py);OIDC 替换时只动 /v1/auth/login 内部
|
||
- 所有 /v1/tasks* 路由 Depends(require_user),按 user_id 隔离数据
|
||
- 豁免:/healthz、/docs、/openapi.json、/、/v1/auth/login、/static/*
|
||
- CORS allow_origins=["*"] 本地宽松;真发布按 platform 域名收紧
|
||
- `GET /` 302 → /static/dev.html(本地 dev SPA)
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import os
|
||
import tempfile
|
||
from contextlib import asynccontextmanager
|
||
from datetime import datetime as _dt
|
||
from pathlib import Path
|
||
from typing import Any, Optional
|
||
from uuid import UUID, uuid4
|
||
|
||
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import FileResponse, RedirectResponse, StreamingResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from pydantic import BaseModel
|
||
from sqlalchemy import func, select, update
|
||
from starlette.background import BackgroundTask
|
||
|
||
from core.paths import to_db_path
|
||
from core.storage import (
|
||
NoSubtaskError,
|
||
check_no_subtask,
|
||
session_scope,
|
||
)
|
||
from core.storage.models import Message, Run, Task
|
||
from core.storage.utils import ensure_local_task_row
|
||
|
||
from .auth import AuthConfig, ensure_user_row, make_require_user, mint_token
|
||
from .broker import broker
|
||
from .sinks import WebEventSink
|
||
|
||
|
||
STATUS_FILTERS = ("active", "completed", "abandoned")
|
||
STATUS_WRITABLE = ("completed", "abandoned") # web 不让从 web 端切回 active(走 CLI)
|
||
ORDER_FIELDS = ("created_at", "updated_at", "name", "status")
|
||
ORDER_DEFAULT = "-created_at"
|
||
|
||
|
||
# ─────────────────────────── helpers ───────────────────────────
|
||
|
||
def _norm_path(p: str) -> str:
|
||
"""跨 OS 显示归一:backslash → forward slash。"""
|
||
return (p or "").replace("\\", "/")
|
||
|
||
|
||
def _iso(dt: Optional[Any]) -> Optional[str]:
|
||
return dt.isoformat() if dt else None
|
||
|
||
|
||
def _parse_ordering(s: Optional[str]) -> list:
|
||
"""DRF 风格 `ordering` 解析:逗号分隔多字段,`-` 前缀代表 desc。
|
||
|
||
allowlist 见 `ORDER_FIELDS`;非法字段静默丢弃。全部非法或空串 → `ORDER_DEFAULT`(`-created_at`)。
|
||
返回 sqlalchemy `order_by` 列表(可直接 `*expand`)。
|
||
"""
|
||
spec = (s or "").strip() or ORDER_DEFAULT
|
||
cols = []
|
||
for part in spec.split(","):
|
||
p = part.strip()
|
||
if not p:
|
||
continue
|
||
asc = True
|
||
if p.startswith("-"):
|
||
asc = False
|
||
p = p[1:]
|
||
if p in ORDER_FIELDS:
|
||
col = getattr(Task, p)
|
||
cols.append(col.asc() if asc else col.desc())
|
||
if not cols:
|
||
# 用户传了全无效字段 → fallback 默认
|
||
cols = [Task.created_at.desc()]
|
||
return cols
|
||
|
||
|
||
def _task_dict(row: Any, *, n_messages: Optional[int] = None) -> dict:
|
||
"""Task ORM row → API JSON dict。"""
|
||
d = {
|
||
"task_id": str(row.task_id),
|
||
"name": row.name or "",
|
||
"description": row.description or "",
|
||
"working_dir": _norm_path(row.working_dir or ""),
|
||
"status": row.status,
|
||
"skill": row.skill or "",
|
||
"model": row.model or "",
|
||
"model_profile": row.model_profile or "",
|
||
"tokens_prompt": row.tokens_prompt or 0,
|
||
"tokens_completion": row.tokens_completion or 0,
|
||
"tokens": (row.tokens_prompt or 0) + (row.tokens_completion or 0),
|
||
"created_at": _iso(getattr(row, "created_at", None)),
|
||
"updated_at": _iso(getattr(row, "updated_at", None)),
|
||
}
|
||
if n_messages is not None:
|
||
d["n_messages"] = n_messages
|
||
return d
|
||
|
||
|
||
# ─────────────────────── files helpers ───────────────────────
|
||
|
||
def _load_user_root(user_id: UUID) -> Path:
|
||
"""user_root = `<workspace>/users/<user_id>/`,所有 files API 的边界。
|
||
若目录尚未存在自动 mkdir(空 user 首次访问也能拿到根)。
|
||
"""
|
||
from main import resolve_workspace, user_root
|
||
ws = resolve_workspace(None)
|
||
return user_root(ws, user_id)
|
||
|
||
|
||
def _safe_join(root: Path, rel: str) -> Path:
|
||
"""归一用户路径到 absolute,并校验仍在 root 内。防 `../` / 绝对 path / symlink 越界。"""
|
||
rel = (rel or "").strip()
|
||
if not rel:
|
||
return root.resolve()
|
||
if rel[0] in ("/", "\\"):
|
||
raise HTTPException(400, f"absolute-style path not allowed: {rel!r}")
|
||
if Path(rel).is_absolute():
|
||
raise HTTPException(400, f"absolute path not allowed: {rel!r}")
|
||
target = (root / rel).resolve()
|
||
try:
|
||
target.relative_to(root.resolve())
|
||
except ValueError:
|
||
raise HTTPException(400, f"path escapes user_root: {rel!r}")
|
||
return target
|
||
|
||
|
||
def _rel_to(root: Path, target: Path) -> str:
|
||
try:
|
||
rel = target.resolve().relative_to(root.resolve()).as_posix()
|
||
except ValueError:
|
||
return ""
|
||
return "" if rel == "." else rel
|
||
|
||
|
||
def _enumerate_files(root: Path, current: Path) -> tuple[list[dict], list[dict], bool]:
|
||
"""枚举 current 下条目 + 拼面包屑。size raw bytes,mtime ISO 串(前端 humanize)。
|
||
Dotfile 一律隐藏(`.memory/` 等系统区不暴露给 UI,同 `/v1/folders` 约定)。
|
||
"""
|
||
entries: list[dict] = []
|
||
exists = current.exists()
|
||
if exists and current.is_dir():
|
||
try:
|
||
raw = sorted(current.iterdir(), key=lambda p: (p.is_file(), p.name.lower()))
|
||
except OSError:
|
||
raw = []
|
||
for p in raw:
|
||
if p.name.startswith("."):
|
||
continue
|
||
try:
|
||
st = p.stat()
|
||
except OSError:
|
||
continue
|
||
entries.append({
|
||
"name": p.name,
|
||
"is_dir": p.is_dir(),
|
||
"size": st.st_size if p.is_file() else None,
|
||
"mtime": _dt.fromtimestamp(st.st_mtime).isoformat(timespec="seconds"),
|
||
"rel": _rel_to(root, p),
|
||
})
|
||
cur_rel = _rel_to(root, current)
|
||
crumbs = [{"label": "/", "rel": ""}]
|
||
# cur_rel == "." 表示当前就在 root(target.relative_to(root) 返 Path(".")),
|
||
# 不该再追加一个无意义的 "." crumb
|
||
if cur_rel and cur_rel != ".":
|
||
acc = ""
|
||
for part in cur_rel.split("/"):
|
||
acc = f"{acc}/{part}" if acc else part
|
||
crumbs.append({"label": part, "rel": acc})
|
||
return entries, crumbs, exists
|
||
|
||
|
||
# ─────────────────── Run 启动 + SSE 帧格式 ───────────────────
|
||
|
||
def _run_agent_bg(task_id: UUID, run_id: UUID, user_id: UUID, user_message: str) -> None:
|
||
"""工作线程:`build_agent(resume=True)` → 装 WebEventSink → `agent.run` → 写 runs 状态。
|
||
|
||
sink 通过 broker.emit 桥事件回 asyncio loop;agent.run 是 sync,所以在 to_thread 跑。
|
||
user_id 必须从 JWT 那侧透传过来 —— 决定 memory_block 读哪个 per-user 子树。
|
||
"""
|
||
from main import build_agent, sync_task_tokens
|
||
try:
|
||
broker.emit(run_id, {"type": "run_start"})
|
||
agent, session, sid, task_state, task_dir = build_agent(
|
||
session_id=str(task_id), resume=True, user_id=user_id,
|
||
)
|
||
agent.sink = WebEventSink(broker, run_id)
|
||
agent.run(user_message)
|
||
sync_task_tokens(task_state, agent.llm)
|
||
with session_scope() as s:
|
||
s.execute(
|
||
update(Run).where(Run.run_id == run_id).values(
|
||
status="ok",
|
||
finished_at=func.now(),
|
||
tokens_p=agent.llm.token_counter.prompt_tokens,
|
||
tokens_c=agent.llm.token_counter.completion_tokens,
|
||
)
|
||
)
|
||
except Exception as e:
|
||
err = f"{type(e).__name__}: {e}"
|
||
broker.emit(run_id, {"type": "error", "msg": err})
|
||
try:
|
||
with session_scope() as s:
|
||
s.execute(
|
||
update(Run).where(Run.run_id == run_id).values(
|
||
status="error", error=err, finished_at=func.now()
|
||
)
|
||
)
|
||
except Exception:
|
||
pass # 已 emit error 给前端,DB 写失败不放大噪声
|
||
finally:
|
||
broker.close(run_id)
|
||
|
||
|
||
def _sse_event(event_type: str, payload: dict) -> bytes:
|
||
"""格式化 SSE 一帧:`event: <type>` + `data: <json single-line>`。"""
|
||
body = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
|
||
return f"event: {event_type}\ndata: {body}\n\n".encode("utf-8")
|
||
|
||
|
||
# ────────────────────── Pydantic 请求体 ──────────────────────
|
||
|
||
class TaskCreateRequest(BaseModel):
|
||
name: str # 任务显示名(必填,DB 列 NOT NULL)
|
||
working_dir: str = "" # 工作目录名(可选,留空 → 用 name 作目录名)
|
||
description: str = ""
|
||
skill: str = ""
|
||
|
||
|
||
class TaskPatchRequest(BaseModel):
|
||
status: Optional[str] = None
|
||
description: Optional[str] = None
|
||
name: Optional[str] = None
|
||
skill: Optional[str] = None
|
||
|
||
|
||
class MessageRequest(BaseModel):
|
||
content: str
|
||
|
||
|
||
class FileDeleteRequest(BaseModel):
|
||
path: str
|
||
|
||
|
||
class LoginRequest(BaseModel):
|
||
user_id: str
|
||
platform_key: str
|
||
|
||
|
||
# ────────────────────── App 工厂 ──────────────────────
|
||
|
||
# web/static 目录路径 — /static 静态挂载用,dev.html 也放这
|
||
_STATIC_DIR = Path(__file__).parent / "static"
|
||
|
||
|
||
def create_app() -> FastAPI:
|
||
# fail-fast:env 缺失直接抛,不裸跑无密
|
||
auth_cfg = AuthConfig.from_env()
|
||
require_user = make_require_user(auth_cfg)
|
||
|
||
@asynccontextmanager
|
||
async def lifespan(app: FastAPI):
|
||
broker.bind_loop(asyncio.get_running_loop())
|
||
yield
|
||
|
||
app = FastAPI(
|
||
title="zcbot api",
|
||
version="0.8",
|
||
description=(
|
||
"zcbot 后端 — /v1 JSON API + SSE。Auth: PLATFORM_KEY → JWT(§7 D' 过渡)。"
|
||
"本地 dev SPA: /static/dev.html。"
|
||
),
|
||
lifespan=lifespan,
|
||
)
|
||
app.add_middleware(
|
||
CORSMiddleware,
|
||
allow_origins=["*"], # 本地宽松,部署 platform 时按域名收紧
|
||
allow_credentials=False,
|
||
allow_methods=["*"],
|
||
allow_headers=["*"],
|
||
)
|
||
|
||
if _STATIC_DIR.is_dir():
|
||
app.mount("/static", StaticFiles(directory=str(_STATIC_DIR)), name="static")
|
||
|
||
# ───────────── Misc ─────────────
|
||
|
||
@app.get("/", include_in_schema=False)
|
||
def root():
|
||
# 本地 dev SPA;Swagger UI 仍在 /docs
|
||
return RedirectResponse(url="/static/dev.html", status_code=302)
|
||
|
||
@app.get("/healthz", tags=["misc"])
|
||
def healthz():
|
||
return {"status": "ok"}
|
||
|
||
# ───────────── Auth ─────────────
|
||
|
||
@app.post("/v1/auth/login", tags=["auth"])
|
||
def login(body: LoginRequest):
|
||
"""platform_key 校验通过 → 签 JWT(user_id 作为 sub)。
|
||
|
||
platform_key 错 → 403;user_id 非 UUID → 400。
|
||
user_id 未存在则幂等创建 users 行(避免下游 FK 失败)。
|
||
"""
|
||
if body.platform_key != auth_cfg.platform_key:
|
||
raise HTTPException(403, "invalid platform_key")
|
||
try:
|
||
uid = UUID(body.user_id)
|
||
except (ValueError, TypeError):
|
||
raise HTTPException(400, f"invalid user_id (must be UUID): {body.user_id!r}")
|
||
ensure_user_row(uid)
|
||
token, exp = mint_token(auth_cfg, uid)
|
||
return {
|
||
"token": token,
|
||
"expires_at": _dt.fromtimestamp(exp).isoformat(),
|
||
"user_id": str(uid),
|
||
"ttl_seconds": auth_cfg.ttl_seconds,
|
||
}
|
||
|
||
# ───────────── Tasks CRUD ─────────────
|
||
|
||
@app.post("/v1/tasks", status_code=201, tags=["tasks"])
|
||
def create_task(body: TaskCreateRequest, user_id: UUID = Depends(require_user)):
|
||
"""新建 task。
|
||
|
||
- `name` 必填(任务显示名,DB 列 NOT NULL,UI 列表 / 标题用)
|
||
- `working_dir` 可选(留空 → 用 name 作目录名);同 working_dir 多 task 共享同目录(§7.1)
|
||
- name / working_dir 都过 validate_task_name(简单名,无 `/\\..`,非 `.` 起头,≤255)
|
||
- 前缀嵌套(no-subtask,同 user 内)→ 409
|
||
"""
|
||
from main import InvalidTaskName, resolve_workspace, validate_task_name, working_dir_from_name
|
||
try:
|
||
name = validate_task_name(body.name)
|
||
except InvalidTaskName as e:
|
||
raise HTTPException(400, f"name 不合法: {e}")
|
||
# working_dir 留空 → fallback 用 name
|
||
wd_raw = (body.working_dir or "").strip()
|
||
wd_name = wd_raw if wd_raw else name
|
||
try:
|
||
wd_name = validate_task_name(wd_name)
|
||
except InvalidTaskName as e:
|
||
raise HTTPException(400, f"working_dir 不合法: {e}")
|
||
description = body.description.strip()
|
||
skill = body.skill.strip()
|
||
|
||
tid = uuid4()
|
||
ws = resolve_workspace(None)
|
||
fs_dir = working_dir_from_name(ws, user_id, wd_name)
|
||
fs_dir_db = to_db_path(fs_dir)
|
||
|
||
try:
|
||
check_no_subtask(fs_dir_db, user_id=user_id)
|
||
except NoSubtaskError as e:
|
||
raise HTTPException(409, str(e))
|
||
|
||
# 工作目录立刻建出(同 working_dir 多 task 共享,exist_ok=True)
|
||
fs_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
ensure_local_task_row(
|
||
task_id=tid, name=name, working_dir=fs_dir_db, skill=skill,
|
||
description=description, user_id=user_id,
|
||
)
|
||
with session_scope() as s:
|
||
row = s.execute(select(Task).where(Task.task_id == tid)).scalar_one()
|
||
return _task_dict(row, n_messages=0)
|
||
|
||
@app.get("/v1/tasks", tags=["tasks"])
|
||
def list_tasks_route(
|
||
page: int = 1,
|
||
page_size: int = 20,
|
||
status: Optional[str] = None,
|
||
skill: Optional[str] = None,
|
||
working_dir: Optional[str] = None,
|
||
q: Optional[str] = None,
|
||
ordering: Optional[str] = None,
|
||
user_id: UUID = Depends(require_user),
|
||
):
|
||
"""列出当前 user 的 task,分页 + 多维筛选 + 排序。
|
||
|
||
- `page` ≥ 1(1-based);`page_size` 1–100(超界 clamp)
|
||
- `status` 在 active/completed/abandoned;非法值静默忽略
|
||
- `skill` 精确匹配(空忽略)
|
||
- `working_dir` 末段目录名(如 `水泥申报`);后端自动拼 `workspace/users/<uid>/<name>` 比对
|
||
- `q` 模糊搜索 name + description(ILIKE,大小写不敏感)
|
||
- `ordering` DRF 风格,逗号分隔,`-field` 倒序;allowlist `created_at/updated_at/name/status`;
|
||
非法字段静默忽略;**默认 `-created_at`**(创建时间倒序)
|
||
返回标准分页壳 `{page, page_size, count, results}` —— count 供前端算总页数。
|
||
"""
|
||
# clamp + sanitize
|
||
page = max(1, page)
|
||
page_size = max(1, min(page_size, 100))
|
||
status = status if status in STATUS_FILTERS else None
|
||
skill = (skill or "").strip() or None
|
||
wd_name = (working_dir or "").strip() or None
|
||
q_text = (q or "").strip() or None
|
||
|
||
# 组装 WHERE
|
||
conditions = [Task.user_id == user_id]
|
||
if status:
|
||
conditions.append(Task.status == status)
|
||
if skill:
|
||
conditions.append(Task.skill == skill)
|
||
if wd_name:
|
||
# 末段 → 完整 db form。同 working_dir 多 task 共享时,这是命中入口。
|
||
wd_db = f"workspace/users/{user_id}/{wd_name}"
|
||
conditions.append(Task.working_dir == wd_db)
|
||
if q_text:
|
||
pat = f"%{q_text}%"
|
||
conditions.append(Task.name.ilike(pat) | Task.description.ilike(pat))
|
||
|
||
offset = (page - 1) * page_size
|
||
|
||
with session_scope() as s:
|
||
cnt = s.execute(
|
||
select(func.count()).select_from(Task).where(*conditions)
|
||
).scalar_one() or 0
|
||
|
||
rows = s.execute(
|
||
select(Task).where(*conditions)
|
||
.order_by(*_parse_ordering(ordering))
|
||
.limit(page_size).offset(offset)
|
||
).scalars().all()
|
||
|
||
tids = [r.task_id for r in rows]
|
||
msg_counts = (
|
||
dict(s.execute(
|
||
select(Message.task_id, func.count())
|
||
.where(Message.task_id.in_(tids))
|
||
.group_by(Message.task_id)
|
||
).all())
|
||
if tids else {}
|
||
)
|
||
|
||
return {
|
||
"page": page,
|
||
"page_size": page_size,
|
||
"count": int(cnt),
|
||
"results": [
|
||
_task_dict(r, n_messages=msg_counts.get(r.task_id, 0))
|
||
for r in rows
|
||
],
|
||
}
|
||
|
||
@app.get("/v1/tasks/{task_id}", tags=["tasks"])
|
||
def get_task(task_id: str, user_id: UUID = Depends(require_user)):
|
||
"""单 task meta(不含 messages;走 /messages 拿)。跨 user → 404。"""
|
||
try:
|
||
tid = UUID(task_id)
|
||
except ValueError:
|
||
raise HTTPException(404, f"invalid task id: {task_id!r}")
|
||
with session_scope() as s:
|
||
row = s.execute(
|
||
select(Task).where(Task.task_id == tid, Task.user_id == user_id)
|
||
).scalar_one_or_none()
|
||
if row is None:
|
||
raise HTTPException(404, f"task not found: {tid}")
|
||
n = s.execute(
|
||
select(func.count()).select_from(Message).where(Message.task_id == tid)
|
||
).scalar_one()
|
||
return _task_dict(row, n_messages=n)
|
||
|
||
@app.get("/v1/folders", tags=["folders"])
|
||
def list_folders(user_id: UUID = Depends(require_user)):
|
||
"""列出当前 user 的工作目录(`workspace/users/<uid>/` 下非 dotfile 子目录)。
|
||
供新建 task 时自动补全 / 选已有目录用。FS 是 source of truth(也含手动创建
|
||
但还无关联 task 的目录)。每项带 n_tasks(关联 task 数)+ last_used(最近使用 ISO)。
|
||
排序:有 last_used 的按降序,无 last_used 的排最后,同列 by name asc。
|
||
"""
|
||
from main import resolve_workspace, user_root
|
||
ws = resolve_workspace(None)
|
||
root = user_root(ws, user_id)
|
||
|
||
folder_names: list[str] = []
|
||
if root.is_dir():
|
||
for p in sorted(root.iterdir(), key=lambda x: x.name.lower()):
|
||
if p.is_dir() and not p.name.startswith("."):
|
||
folder_names.append(p.name)
|
||
|
||
folders: list[dict] = []
|
||
if folder_names:
|
||
with session_scope() as s:
|
||
for name in folder_names:
|
||
db_form = f"workspace/users/{user_id}/{name}"
|
||
stat = s.execute(
|
||
select(func.count(), func.max(Task.updated_at))
|
||
.where(Task.user_id == user_id, Task.working_dir == db_form)
|
||
).first()
|
||
n = int((stat[0] if stat else 0) or 0)
|
||
lu = stat[1] if stat else None
|
||
folders.append({
|
||
"name": name,
|
||
"n_tasks": n,
|
||
"last_used": _iso(lu),
|
||
})
|
||
|
||
folders.sort(key=lambda f: f["name"])
|
||
folders.sort(key=lambda f: f["last_used"] or "", reverse=True)
|
||
return {"folders": folders}
|
||
|
||
@app.delete("/v1/tasks/{task_id}", status_code=204, tags=["tasks"])
|
||
def delete_task(task_id: str, user_id: UUID = Depends(require_user)):
|
||
"""硬删除:DELETE DB 行(messages / runs CASCADE)。**FS task_dir 不动**
|
||
(同 name 多 task 共享,文件由用户经 /files/delete 单独清)。跨 user → 404。
|
||
"""
|
||
try:
|
||
tid = UUID(task_id)
|
||
except ValueError:
|
||
raise HTTPException(404, f"invalid task id: {task_id!r}")
|
||
from sqlalchemy import delete as _delete
|
||
with session_scope() as s:
|
||
result = s.execute(
|
||
_delete(Task).where(Task.task_id == tid, Task.user_id == user_id)
|
||
)
|
||
if result.rowcount == 0:
|
||
raise HTTPException(404, f"task not found: {tid}")
|
||
return None # 204
|
||
|
||
@app.patch("/v1/tasks/{task_id}", tags=["tasks"])
|
||
def patch_task(
|
||
task_id: str,
|
||
body: TaskPatchRequest,
|
||
user_id: UUID = Depends(require_user),
|
||
):
|
||
"""更新 task 字段。`status` 仅允许 completed/abandoned(active 走 CLI 切回)。"""
|
||
try:
|
||
tid = UUID(task_id)
|
||
except ValueError:
|
||
raise HTTPException(404, f"invalid task id: {task_id!r}")
|
||
updates: dict[str, Any] = {}
|
||
if body.status is not None:
|
||
if body.status not in STATUS_WRITABLE:
|
||
raise HTTPException(
|
||
400, f"invalid status {body.status!r}; allowed: {STATUS_WRITABLE}"
|
||
)
|
||
updates["status"] = body.status
|
||
if body.description is not None:
|
||
updates["description"] = body.description
|
||
if body.skill is not None:
|
||
updates["skill"] = body.skill
|
||
if body.name is not None:
|
||
from main import InvalidTaskName, validate_task_name
|
||
try:
|
||
updates["name"] = validate_task_name(body.name)
|
||
except InvalidTaskName as e:
|
||
raise HTTPException(400, f"name 不合法: {e}")
|
||
if not updates:
|
||
raise HTTPException(400, "no fields to update")
|
||
with session_scope() as s:
|
||
result = s.execute(
|
||
update(Task)
|
||
.where(Task.task_id == tid, Task.user_id == user_id)
|
||
.values(**updates)
|
||
)
|
||
if result.rowcount == 0:
|
||
raise HTTPException(404, f"task not found: {tid}")
|
||
row = s.execute(select(Task).where(Task.task_id == tid)).scalar_one()
|
||
n = s.execute(
|
||
select(func.count()).select_from(Message).where(Message.task_id == tid)
|
||
).scalar_one()
|
||
return _task_dict(row, n_messages=n)
|
||
|
||
# ───────────── Messages ─────────────
|
||
|
||
def _assert_owns_task(s, tid: UUID, user_id: UUID) -> None:
|
||
ok = s.execute(
|
||
select(Task.task_id).where(Task.task_id == tid, Task.user_id == user_id)
|
||
).first()
|
||
if ok is None:
|
||
raise HTTPException(404, f"task not found: {tid}")
|
||
|
||
@app.get("/v1/tasks/{task_id}/messages", tags=["messages"])
|
||
def list_messages(task_id: str, user_id: UUID = Depends(require_user)):
|
||
"""task 历史消息(idx 升序);LiteLLM 原 payload 透传给前端,自行渲染。"""
|
||
try:
|
||
tid = UUID(task_id)
|
||
except ValueError:
|
||
raise HTTPException(404, f"invalid task id: {task_id!r}")
|
||
with session_scope() as s:
|
||
_assert_owns_task(s, tid, user_id)
|
||
rows = s.execute(
|
||
select(
|
||
Message.idx, Message.payload, Message.tokens_in,
|
||
Message.tokens_out, Message.created_at,
|
||
).where(Message.task_id == tid).order_by(Message.idx)
|
||
).all()
|
||
return {
|
||
"messages": [
|
||
{
|
||
"idx": r.idx,
|
||
"payload": dict(r.payload),
|
||
"tokens_in": r.tokens_in,
|
||
"tokens_out": r.tokens_out,
|
||
"created_at": _iso(r.created_at),
|
||
}
|
||
for r in rows
|
||
]
|
||
}
|
||
|
||
@app.post("/v1/tasks/{task_id}/messages", status_code=202, tags=["messages"])
|
||
async def post_message(
|
||
task_id: str,
|
||
body: MessageRequest,
|
||
user_id: UUID = Depends(require_user),
|
||
):
|
||
"""发消息 + 起 BG run。返 `{run_id, events_url}`,客户端立刻订阅 SSE 拿流式。"""
|
||
try:
|
||
tid = UUID(task_id)
|
||
except ValueError:
|
||
raise HTTPException(404, f"invalid task id: {task_id!r}")
|
||
content = (body.content or "").strip()
|
||
if not content:
|
||
raise HTTPException(400, "empty content")
|
||
with session_scope() as s:
|
||
_assert_owns_task(s, tid, user_id)
|
||
|
||
run_id = uuid4()
|
||
with session_scope() as s:
|
||
s.add(Run(run_id=run_id, task_id=tid, status="running", started_at=func.now()))
|
||
# to_thread 跑 sync agent.run;sink 通过 broker 把 event 桥回 asyncio
|
||
asyncio.create_task(asyncio.to_thread(_run_agent_bg, tid, run_id, user_id, content))
|
||
return {
|
||
"run_id": str(run_id),
|
||
"events_url": f"/v1/tasks/{tid}/runs/{run_id}/events",
|
||
}
|
||
|
||
# ───────────── SSE events ─────────────
|
||
|
||
@app.get("/v1/tasks/{task_id}/runs/{run_id}/events", tags=["runs"])
|
||
async def stream_events(
|
||
task_id: str,
|
||
run_id: str,
|
||
user_id: UUID = Depends(require_user),
|
||
):
|
||
"""SSE 流。事件类型:run_start / llm_start / text / tool_call / tool_result /
|
||
llm_end / error / done。data 是 JSON dict(已剔除 `type` 字段,移到 event 名)。
|
||
"""
|
||
try:
|
||
tid = UUID(task_id)
|
||
rid = UUID(run_id)
|
||
except ValueError:
|
||
raise HTTPException(404, "invalid id")
|
||
with session_scope() as s:
|
||
_assert_owns_task(s, tid, user_id)
|
||
|
||
async def gen():
|
||
q = broker.subscribe(rid)
|
||
try:
|
||
# 第一帧 retry 注释 + 心跳:让 EventSource 立即建立(不被 buffer 卡)
|
||
yield b": connected\nretry: 3000\n\n"
|
||
while True:
|
||
try:
|
||
ev = await asyncio.wait_for(q.get(), timeout=30.0)
|
||
except asyncio.TimeoutError:
|
||
yield b": ping\n\n"
|
||
continue
|
||
ev_type = ev.get("type", "msg")
|
||
payload = {k: v for k, v in ev.items() if k != "type"}
|
||
yield _sse_event(ev_type, payload)
|
||
if ev_type in ("done", "error"):
|
||
break
|
||
except asyncio.CancelledError:
|
||
pass # 客户端断开,静默退
|
||
finally:
|
||
broker.unsubscribe(rid, q)
|
||
|
||
return StreamingResponse(
|
||
gen(),
|
||
media_type="text/event-stream",
|
||
headers={
|
||
"Cache-Control": "no-cache",
|
||
"Connection": "keep-alive",
|
||
"X-Accel-Buffering": "no",
|
||
},
|
||
)
|
||
|
||
# ───────────── Files(user-rooted,不绑 task) ─────────────
|
||
|
||
@app.get("/v1/files", tags=["files"])
|
||
def list_files(
|
||
path: str = "",
|
||
user_id: UUID = Depends(require_user),
|
||
):
|
||
"""列 user_root 下子目录条目 + 面包屑。`path` 留空 → user_root;
|
||
`../` / 绝对 → 400。dotfile(`.memory/` 等)一律隐藏。
|
||
"""
|
||
root = _load_user_root(user_id)
|
||
current = _safe_join(root, path)
|
||
entries, crumbs, exists = _enumerate_files(root, current)
|
||
return {
|
||
"root": _norm_path(str(root)),
|
||
"current": _rel_to(root, current),
|
||
"exists": exists,
|
||
"crumbs": crumbs,
|
||
"entries": entries,
|
||
}
|
||
|
||
@app.get("/v1/files/download", tags=["files"])
|
||
def download_file(
|
||
path: str,
|
||
user_id: UUID = Depends(require_user),
|
||
):
|
||
"""下载 user_root 下单个 regular file(目录 → 400 / 不存在 → 404)。"""
|
||
root = _load_user_root(user_id)
|
||
target = _safe_join(root, path)
|
||
if not target.exists():
|
||
raise HTTPException(404, f"file not found: {path}")
|
||
if not target.is_file():
|
||
raise HTTPException(400, f"not a file: {path}")
|
||
return FileResponse(path=str(target), filename=target.name)
|
||
|
||
@app.post("/v1/files/upload", tags=["files"])
|
||
async def upload_files(
|
||
path: str = Form(""),
|
||
files: list[UploadFile] = File(...),
|
||
user_id: UUID = Depends(require_user),
|
||
):
|
||
"""multipart 多文件上传到 `<user_root>/<path>/`。
|
||
路径不存在自动 mkdir(parents=True);重名直接覆盖。
|
||
文件名严格校验(含 `/ \\ ..` 或为空 → 400)。
|
||
"""
|
||
root = _load_user_root(user_id)
|
||
dest_dir = _safe_join(root, path)
|
||
if dest_dir.exists() and not dest_dir.is_dir():
|
||
raise HTTPException(400, f"upload target is a file, not a directory: {path}")
|
||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
saved: list[dict] = []
|
||
for up in files or []:
|
||
raw_name = up.filename or ""
|
||
if (
|
||
not raw_name
|
||
or raw_name in (".", "..")
|
||
or "/" in raw_name or "\\" in raw_name
|
||
or any(part in (".", "..") for part in Path(raw_name).parts)
|
||
):
|
||
raise HTTPException(400, f"invalid filename: {raw_name!r}")
|
||
dest = dest_dir / raw_name
|
||
try:
|
||
dest.resolve().relative_to(root.resolve())
|
||
except ValueError:
|
||
raise HTTPException(400, f"path escapes user_root: {raw_name!r}")
|
||
data = await up.read()
|
||
dest.write_bytes(data)
|
||
saved.append({"name": raw_name, "size": len(data), "rel": _rel_to(root, dest)})
|
||
if not saved:
|
||
raise HTTPException(400, "no files uploaded")
|
||
return {"count": len(saved), "saved": saved}
|
||
|
||
@app.post("/v1/files/delete", tags=["files"])
|
||
def delete_file(
|
||
body: FileDeleteRequest,
|
||
user_id: UUID = Depends(require_user),
|
||
):
|
||
"""删 user_root 下文件或**空**目录。非空目录 → 400(避免误操);root → 400。"""
|
||
root = _load_user_root(user_id)
|
||
target = _safe_join(root, body.path)
|
||
if target.resolve() == root.resolve():
|
||
raise HTTPException(400, "cannot delete user_root")
|
||
if not target.exists():
|
||
raise HTTPException(404, f"path not found: {body.path}")
|
||
try:
|
||
if target.is_dir():
|
||
target.rmdir() # 非空目录会触发 OSError
|
||
else:
|
||
target.unlink()
|
||
except OSError as e:
|
||
raise HTTPException(400, f"delete failed: {e}")
|
||
return {"ok": True, "path": body.path}
|
||
|
||
# ───────────── Export ─────────────
|
||
|
||
@app.get("/v1/tasks/{task_id}/export", tags=["export"])
|
||
def export_task(task_id: str, user_id: UUID = Depends(require_user)):
|
||
"""导出对话为 .docx,临时文件下载完后 BackgroundTask 删 tmp。"""
|
||
try:
|
||
tid = UUID(task_id)
|
||
except ValueError:
|
||
raise HTTPException(404, f"invalid task id: {task_id!r}")
|
||
with session_scope() as s:
|
||
_assert_owns_task(s, tid, user_id)
|
||
has_msg = s.execute(
|
||
select(Message.message_id).where(Message.task_id == tid).limit(1)
|
||
).first()
|
||
if not has_msg:
|
||
raise HTTPException(400, "no messages to export")
|
||
|
||
fd, tmp_str = tempfile.mkstemp(suffix=".docx", prefix="zcbot-export-")
|
||
os.close(fd)
|
||
tmp_path = Path(tmp_str)
|
||
try:
|
||
from core.export_docx import export_chat_to_docx
|
||
export_chat_to_docx(tid, out_path=tmp_path)
|
||
except Exception as e:
|
||
tmp_path.unlink(missing_ok=True)
|
||
raise HTTPException(500, f"export failed: {type(e).__name__}: {e}")
|
||
|
||
return FileResponse(
|
||
path=str(tmp_path),
|
||
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||
filename=f"chat_{str(tid)[:8]}.docx",
|
||
background=BackgroundTask(tmp_path.unlink, missing_ok=True),
|
||
)
|
||
|
||
return app
|