"""FastAPI app: 纯 /v1 JSON API(2026-05-15 切换 — 详见 DESIGN §7.9)。

设计要点:
- 所有路由 `/v1/*` 前缀,响应 JSON;模板 / HTMX / 服务端 markdown 渲染全删
- SSE 事件 payload 是 JSON dict 而非 HTML 片段(`event: <type>` + `data: <json>`)
- Auth: PLATFORM_KEY → JWT 兑换(§7 D' 过渡形态,见 web/auth.py),另有邮箱密码登录
  `/v1/auth/login_password`;OIDC 替换时只动 /v1/auth/login 内部
- 所有 /v1/tasks* 路由 Depends(require_user),按 user_id 隔离数据
- 豁免(无需 JWT):/healthz、/docs、/openapi.json、/、/v1/auth/login、
  /v1/auth/login_password、/v1/auth/admin/create_user(改用 admin_token 校验)、/static/*
- CORS allow_origins=["*"] 本地宽松;真发布按 platform 域名收紧
- `GET /` 302 → /static/dev.html(本地 dev SPA)
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import os
|
||
import tempfile
|
||
from contextlib import asynccontextmanager
|
||
from datetime import datetime as _dt
|
||
from pathlib import Path
|
||
from typing import Any, Optional
|
||
from uuid import UUID, uuid4
|
||
|
||
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import FileResponse, RedirectResponse, StreamingResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from pydantic import BaseModel
|
||
from sqlalchemy import func, select, update
|
||
from starlette.background import BackgroundTask
|
||
|
||
from core.paths import to_db_path
|
||
from core.storage import (
|
||
NoSubtaskError,
|
||
check_no_subtask,
|
||
session_scope,
|
||
)
|
||
from core.storage.models import Message, Task
|
||
from core.storage.utils import ensure_local_task_row
|
||
|
||
from .auth import (
|
||
AuthConfig,
|
||
UserCreateError,
|
||
create_user,
|
||
ensure_user_row,
|
||
make_require_user,
|
||
mint_token,
|
||
resolve_user_by_email,
|
||
)
|
||
from .broker import broker
|
||
from .sinks import WebEventSink
|
||
|
||
|
||
# Task.status values accepted by the GET /v1/tasks `status` filter; anything else is ignored.
STATUS_FILTERS = ("active", "completed", "abandoned")
# Status values the web side may write; switching a task back to "active" is not offered on the web (use the CLI).
STATUS_WRITABLE = ("completed", "abandoned")
# Allowlist of Task columns accepted by the DRF-style `ordering` query parameter.
ORDER_FIELDS = ("created_at", "updated_at", "name", "status")
# Default ordering spec: newest tasks first.
ORDER_DEFAULT = "-created_at"
|
||
|
||
|
||
# ─────────────────────────── helpers ───────────────────────────
|
||
|
||
def _norm_path(p: str) -> str:
|
||
"""跨 OS 显示归一:backslash → forward slash。"""
|
||
return (p or "").replace("\\", "/")
|
||
|
||
|
||
def _iso(dt: Optional[Any]) -> Optional[str]:
|
||
return dt.isoformat() if dt else None
|
||
|
||
|
||
def _parse_ordering(s: Optional[str]) -> list:
    """Parse a DRF-style ``ordering`` query value into SQLAlchemy ``order_by`` terms.

    ``s`` is a comma-separated list of field names; a leading ``-`` means
    descending. Only fields listed in ``ORDER_FIELDS`` are honoured; unknown
    fields and repeats of an already-used field are dropped silently. If no
    valid field remains (including ``None`` / empty / whitespace-only input)
    the spec in ``ORDER_DEFAULT`` is used instead.

    A final ``task_id`` ascending term is always appended as a tie-breaker:
    LIMIT/OFFSET pagination is only deterministic when the ORDER BY is unique,
    and columns such as ``name`` / ``status`` routinely tie.

    Returns:
        A list suitable for ``select(...).order_by(*cols)``.
    """
    cols = _ordering_columns(s) or _ordering_columns(ORDER_DEFAULT)
    if not cols:
        # ORDER_DEFAULT itself is misconfigured; keep the historical default.
        cols = [Task.created_at.desc()]
    # Unique tie-breaker so page boundaries are stable across requests.
    cols.append(Task.task_id.asc())
    return cols


def _ordering_columns(spec: Optional[str]) -> list:
    """Translate one ordering spec into order_by terms, skipping invalid or repeated fields."""
    seen: set[str] = set()
    cols = []
    for part in (spec or "").split(","):
        token = part.strip()
        descending = token.startswith("-")
        field = token[1:] if descending else token
        if field not in ORDER_FIELDS or field in seen:
            continue
        seen.add(field)
        col = getattr(Task, field)
        cols.append(col.desc() if descending else col.asc())
    return cols
|
||
|
||
|
||
def _task_dict(row: Any, *, n_messages: Optional[int] = None) -> dict:
    """Serialize a Task ORM row into the JSON-ready dict returned by the API.

    Null-ish columns are normalised: text columns become ``""``, token counts
    become ``0``, ``run_status`` becomes ``"idle"`` and an empty ``run_error``
    becomes ``None``. ``tokens`` is prompt + completion. Timestamps are read
    with ``getattr`` and rendered as ISO strings (or ``None``).
    ``n_messages`` is added only when the caller supplies it.
    """
    prompt_tokens = row.tokens_prompt or 0
    completion_tokens = row.tokens_completion or 0
    payload = {
        "task_id": str(row.task_id),
        "name": row.name or "",
        "description": row.description or "",
        "working_dir": _norm_path(row.working_dir or ""),
        "status": row.status,
        "skill": row.skill or "",
        "model": row.model or "",
        "model_profile": row.model_profile or "",
        "tokens_prompt": prompt_tokens,
        "tokens_completion": completion_tokens,
        "tokens": prompt_tokens + completion_tokens,
        # Current run state (schema 0004 folded the old runs table into tasks).
        "run_status": row.run_status or "idle",
        "run_error": row.run_error or None,
        "created_at": _iso(getattr(row, "created_at", None)),
        "updated_at": _iso(getattr(row, "updated_at", None)),
    }
    if n_messages is None:
        return payload
    payload["n_messages"] = n_messages
    return payload
|
||
|
||
|
||
# ─────────────────────── files helpers ───────────────────────
|
||
|
||
def _load_user_root(user_id: UUID) -> Path:
    """Return the per-user root ``<workspace>/users/<user_id>/`` — the boundary for all files APIs.

    The workspace comes from ``resolve_workspace(None)``. Whether a missing
    directory is created on first access is decided inside ``user_root``
    (the original note says it is mkdir'ed) — TODO confirm there.
    """
    from core.agent_builder import resolve_workspace, user_root

    return user_root(resolve_workspace(None), user_id)
|
||
|
||
|
||
def _safe_join(root: Path, rel: str) -> Path:
|
||
"""归一用户路径到 absolute,并校验仍在 root 内。防 `../` / 绝对 path / symlink 越界。"""
|
||
rel = (rel or "").strip()
|
||
if not rel:
|
||
return root.resolve()
|
||
if rel[0] in ("/", "\\"):
|
||
raise HTTPException(400, f"absolute-style path not allowed: {rel!r}")
|
||
if Path(rel).is_absolute():
|
||
raise HTTPException(400, f"absolute path not allowed: {rel!r}")
|
||
target = (root / rel).resolve()
|
||
try:
|
||
target.relative_to(root.resolve())
|
||
except ValueError:
|
||
raise HTTPException(400, f"path escapes user_root: {rel!r}")
|
||
return target
|
||
|
||
|
||
def _rel_to(root: Path, target: Path) -> str:
|
||
try:
|
||
rel = target.resolve().relative_to(root.resolve()).as_posix()
|
||
except ValueError:
|
||
return ""
|
||
return "" if rel == "." else rel
|
||
|
||
|
||
def _enumerate_files(root: Path, current: Path) -> tuple[list[dict], list[dict], bool]:
    """List the entries of ``current`` and build breadcrumbs back to ``root``.

    Returns ``(entries, crumbs, exists)``:
      - entries: one dict per visible child with ``name``, ``is_dir``,
        ``size`` (raw bytes for regular files, else ``None``), ``mtime``
        (local-time ISO string, second precision; the frontend humanizes it)
        and ``rel`` (POSIX path relative to ``root``). Non-regular entries
        (directories etc.) come first, then regular files; each group is
        ordered case-insensitively by name.
      - crumbs: ``[{"label": "/", "rel": ""}, ...]`` plus one item per path
        segment of ``current`` below ``root``.
      - exists: whether ``current`` exists at all.

    Dotfiles are always hidden (system areas such as ``.memory/`` are not
    exposed to the UI; same convention as ``/v1/folders``). Children whose
    ``stat()`` fails (e.g. broken symlinks, concurrent deletion) are skipped,
    and an unreadable directory yields no entries.
    """
    import stat as _stat

    entries: list[dict] = []
    exists = current.exists()
    if exists and current.is_dir():
        try:
            children = list(current.iterdir())
        except OSError:
            children = []
        # stat() each visible child exactly once: type, size and mtime all come
        # from the same snapshot (Path.is_dir()/is_file() would each stat again).
        visible = []
        for child in children:
            if child.name.startswith("."):
                continue
            try:
                visible.append((child, child.stat()))
            except OSError:
                continue
        visible.sort(key=lambda item: (_stat.S_ISREG(item[1].st_mode), item[0].name.lower()))
        for child, st in visible:
            is_file = _stat.S_ISREG(st.st_mode)
            entries.append({
                "name": child.name,
                "is_dir": _stat.S_ISDIR(st.st_mode),
                "size": st.st_size if is_file else None,
                "mtime": _dt.fromtimestamp(st.st_mtime).isoformat(timespec="seconds"),
                "rel": _rel_to(root, child),
            })

    crumbs = [{"label": "/", "rel": ""}]
    cur_rel = _rel_to(root, current)
    # _rel_to already maps root itself (and anything outside root) to "";
    # the "." comparison is only a defensive guard.
    if cur_rel and cur_rel != ".":
        acc = ""
        for part in cur_rel.split("/"):
            acc = f"{acc}/{part}" if acc else part
            crumbs.append({"label": part, "rel": acc})
    return entries, crumbs, exists
|
||
|
||
|
||
def _validate_transfer(
    root: Path, paths: list[str], dest_dir: str,
) -> tuple[list[Path], Path]:
    """Pre-flight check for a batch copy/move.

    Every source and the destination are resolved and validated. Any invalid
    item aborts the whole batch before the filesystem is touched. Copy and move
    are not distinguished here; the per-route top-level working_dir guard is
    added by each route.

    Checks:
      - ``paths`` is non-empty; every source is inside user_root, exists and is
        not user_root itself;
      - ``dest_dir`` exists and is a directory (user_root is allowed);
      - no source equals ``dest_dir``;
      - ``dest_dir`` is not inside a source's subtree (no moving ``a/`` into ``a/b/``);
      - no source is already a direct child of ``dest_dir`` (no-op move);
      - source leaf names are unique within the batch (two ``a.txt`` would
        collide at ``dest/a.txt``);
      - ``dest_dir/<name>`` does not already exist (409 for the whole batch,
        never overwrite silently);
      - no source lies inside another source of the same batch. On a move,
        transferring ``a/`` first would make ``a/x`` vanish mid-batch; on a
        copy, ``a/x`` would be transferred twice.

    Returns:
        ``(sources, dest)``: the validated source paths in request order and
        the destination directory path.

    Raises:
        HTTPException: 400 for invalid requests, 404 for a missing source or
            destination, 409 when a target already exists.
    """
    if not paths:
        raise HTTPException(400, "paths is empty")
    dest = _safe_join(root, dest_dir)
    if not dest.exists():
        raise HTTPException(404, f"dest_dir not found: {dest_dir!r}")
    if not dest.is_dir():
        raise HTTPException(400, f"dest_dir is not a directory: {dest_dir!r}")
    dest_r = dest.resolve()
    root_r = root.resolve()

    sources: list[Path] = []
    resolved: list[tuple[str, Path]] = []
    seen_names: set[str] = set()
    for p in paths:
        src = _safe_join(root, p)
        if not src.exists():
            raise HTTPException(404, f"source not found: {p!r}")
        src_r = src.resolve()
        if src_r == root_r:
            raise HTTPException(400, "cannot transfer user_root")
        if src_r == dest_r:
            raise HTTPException(400, f"source equals dest_dir: {p!r}")
        # dest_dir inside src's subtree -> src would be nested into itself.
        if src_r in dest_r.parents:
            raise HTTPException(400, f"cannot transfer {p!r} into its own subtree")
        # Already a direct child of dest_dir -> the transfer would be a no-op.
        if src.parent.resolve() == dest_r:
            raise HTTPException(400, f"{p!r} already directly under dest_dir")
        name = src.name
        if name in seen_names:
            raise HTTPException(400, f"duplicate source leaf name in batch: {name!r}")
        seen_names.add(name)
        target = dest / name
        if target.exists():
            raise HTTPException(
                409, f"target already exists: {_rel_to(root, target)!r}"
            )
        sources.append(src)
        resolved.append((p, src_r))

    # Reject sources nested under another source of the same batch; otherwise
    # the "all-or-nothing, no partial FS side effects" contract breaks.
    batch = {src_r for _, src_r in resolved}
    for p, src_r in resolved:
        if any(parent in batch for parent in src_r.parents):
            raise HTTPException(
                400, f"source {p!r} is nested inside another source in the batch"
            )
    return sources, dest
|
||
|
||
|
||
# ─────────────────── BG run + SSE 帧格式 ───────────────────
|
||
|
||
def _run_agent_bg(
    task_id: UUID, user_id: UUID, user_message: str,
    image_variant: str = "", video_variant: str = "",
) -> None:
    """Worker-thread body for one agent run of ``task_id``.

    Steps: emit ``run_start`` -> ``build_agent(resume=True, ...)`` -> attach a
    ``WebEventSink`` -> ``agent.run(user_message)`` -> ``sync_task_tokens`` ->
    persist ``tasks.run_status``. ``agent.run`` is synchronous, so this is
    meant to run off the event loop (the original design runs it via
    ``to_thread``). Events reach the asyncio loop through ``broker.emit``.

    Args:
        task_id: task being run; also passed as the agent ``session_id``.
        user_id: authenticated user taken from the JWT. It is forwarded to
            ``build_agent`` and, per the original note, selects which per-user
            memory subtree is read.
        user_message: the user turn to run.
        image_variant / video_variant: image/video tool variant keys for this
            run only (empty -> first variant in the yaml). They travel with the
            message POST and are not stored in the DB.

    Cancellation: ``cancel_check`` polls ``broker.is_cancelled(task_id)``. It
    must be handed to ``build_agent`` rather than assigned afterwards, because
    (per the original note) the SeedanceTool constructor keeps a reference to it.

    Outcome:
      - Success or cancellation -> ``run_status="idle"``, ``run_error=None``.
      - Any exception -> ``run_status="error"`` with ``"<Type>: <msg>"``. The
        status is persisted *before* the ``error`` event is emitted, so a
        client reacting to the event reads the final status and a failing
        emit cannot leave the task stuck in ``running``.
      - ``broker.clear_cancel`` and ``broker.close`` always run.
    """
    from core.agent_builder import build_agent, sync_task_tokens

    def cancel_check(tid: UUID = task_id):
        return broker.is_cancelled(tid)

    try:
        broker.emit(task_id, {"type": "run_start"})
        agent, _session, _sid, task_state, _task_dir = build_agent(
            session_id=str(task_id), resume=True, user_id=user_id,
            image_variant=image_variant,
            video_variant=video_variant,
            cancel_check=cancel_check,
        )
        agent.sink = WebEventSink(broker, task_id)
        agent.run(user_message)
        sync_task_tokens(task_state)
        # Cancelled or completed normally -> back to idle (only errors persist).
        with session_scope() as s:
            s.execute(
                update(Task).where(Task.task_id == task_id).values(
                    run_status="idle", run_error=None,
                )
            )
    except Exception as e:
        err = f"{type(e).__name__}: {e}"
        try:
            with session_scope() as s:
                s.execute(
                    update(Task).where(Task.task_id == task_id).values(
                        run_status="error", run_error=err,
                    )
                )
        except Exception as db_err:
            # Best-effort: the client still receives the error event below;
            # leave a server-side trace instead of swallowing silently.
            print(
                f"[run] failed to persist error status for task {task_id}: "
                f"{type(db_err).__name__}: {db_err}"
            )
        broker.emit(task_id, {"type": "error", "msg": err})
    finally:
        broker.clear_cancel(task_id)
        broker.close(task_id)
|
||
|
||
|
||
def _sse_event(event_type: str, payload: dict) -> bytes:
|
||
"""格式化 SSE 一帧:`event: <type>` + `data: <json single-line>`。"""
|
||
body = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
|
||
return f"event: {event_type}\ndata: {body}\n\n".encode("utf-8")
|
||
|
||
|
||
def _resolve_model_profile(profile: str) -> tuple[str, str]:
    """Validate a model profile name and return ``(profile, model_id)``.

    A blank or ``None`` ``profile`` falls back to ``cfg["default_model"]``. The
    name is loaded through ``ModelCapabilities.load`` from
    ``ROOT / cfg["models_dir"]``. ``FileNotFoundError`` / ``ValueError`` from
    the loader become HTTP 400. Both values are returned so callers such as
    ``ensure_local_task_row`` can fill the ``model_profile`` and ``model``
    columns together.

    Raises:
        HTTPException(400): the profile is malformed or its file is missing.
    """
    from core.agent_builder import load_config
    from core.capabilities import ModelCapabilities
    from core.paths import ROOT

    cfg = load_config()
    name = (profile or "").strip() or cfg["default_model"]
    models_dir = ROOT / cfg["models_dir"]
    try:
        caps = ModelCapabilities.load(name, models_dir)
    except (FileNotFoundError, ValueError) as e:
        # Chain the loader error so server-side tracebacks keep the root cause.
        raise HTTPException(400, f"invalid model_profile {name!r}: {e}") from e
    return name, caps.model_id
|
||
|
||
|
||
def _list_image_variants() -> list[tuple[str, dict]]:
|
||
"""扫 config/media/doubao.yaml image 段 → [(variant_key, variant_cfg), ...]。
|
||
|
||
yaml 不存在或 image 段空 / 仅注释 → 返 []。不要求 ARK_API_KEY 已设 —— 仅纯
|
||
元数据列举,UI 拉这个画下拉。真正调用 seedream 时 agent_builder 那边再过
|
||
`ArkConfig.load()`(没 key → tool 不注册)。
|
||
"""
|
||
from core.paths import ROOT
|
||
import yaml as _yaml
|
||
|
||
p = ROOT / "config" / "media" / "doubao.yaml"
|
||
if not p.exists():
|
||
return []
|
||
try:
|
||
data = _yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||
except Exception:
|
||
return []
|
||
image_cfg = data.get("image") or {}
|
||
return [(k, v) for k, v in image_cfg.items() if isinstance(v, dict)]
|
||
|
||
|
||
def _resolve_image_model(variant: str) -> str:
|
||
"""校验 image_model variant key。
|
||
|
||
传空 → 返空(agent_builder fallback 到第一个 variant);传非空 → 必须存在
|
||
于 config/media/doubao.yaml image 段,否则 400。
|
||
"""
|
||
name = (variant or "").strip()
|
||
if not name:
|
||
return ""
|
||
variants = {k for k, _ in _list_image_variants()}
|
||
if name not in variants:
|
||
raise HTTPException(400, f"invalid image_model {name!r}; available: {sorted(variants)}")
|
||
return name
|
||
|
||
|
||
def _list_video_variants() -> list[tuple[str, dict]]:
|
||
"""扫 config/media/doubao.yaml video 段 → [(variant_key, variant_cfg), ...]。
|
||
|
||
与 _list_image_variants 同范式;空 video 段(未上线 / 注释掉)→ 返 [],UI 隐藏下拉。
|
||
"""
|
||
from core.paths import ROOT
|
||
import yaml as _yaml
|
||
|
||
p = ROOT / "config" / "media" / "doubao.yaml"
|
||
if not p.exists():
|
||
return []
|
||
try:
|
||
data = _yaml.safe_load(p.read_text(encoding="utf-8")) or {}
|
||
except Exception:
|
||
return []
|
||
video_cfg = data.get("video") or {}
|
||
return [(k, v) for k, v in video_cfg.items() if isinstance(v, dict)]
|
||
|
||
|
||
def _resolve_video_model(variant: str) -> str:
|
||
"""校验 video_model variant key(同 _resolve_image_model 范式)。"""
|
||
name = (variant or "").strip()
|
||
if not name:
|
||
return ""
|
||
variants = {k for k, _ in _list_video_variants()}
|
||
if name not in variants:
|
||
raise HTTPException(400, f"invalid video_model {name!r}; available: {sorted(variants)}")
|
||
return name
|
||
|
||
|
||
# ────────────────────── Pydantic 请求体 ──────────────────────
|
||
|
||
# Request body for POST /v1/tasks (create a task).
class TaskCreateRequest(BaseModel):
    name: str  # task display name (required; the DB column is NOT NULL)
    working_dir: str = ""  # working-directory name (optional; empty -> `name` is used as the directory name)
    description: str = ""
    skill: str = ""
    model_profile: str = ""  # `family.variant`; empty -> cfg["default_model"]; must exist under config/models/
|
||
|
||
|
||
# Partial-update body for a task; the route consuming it is outside this view.
# Every field defaults to None — presumably meaning "leave unchanged" (TODO confirm in the route).
class TaskPatchRequest(BaseModel):
    status: Optional[str] = None  # presumably restricted to STATUS_WRITABLE — confirm in the route
    description: Optional[str] = None
    name: Optional[str] = None
    skill: Optional[str] = None
    model_profile: Optional[str] = None  # switch model (task level); takes effect on the next message sent
|
||
|
||
|
||
# Body of a user chat message posted to a task.
class MessageRequest(BaseModel):
    content: str
    # Image / video generation variant keys for the run this message triggers
    # (image / video sections of config/media/doubao.yaml).
    # Empty -> the corresponding tool uses the first variant in the yaml; non-empty -> this run wires that variant.
    # Applies to this run only and is not stored in the DB; the UI dropdown choice rides on the message POST body.
    image_model: str = ""
    video_model: str = ""
|
||
|
||
|
||
# Body of a prompt-polishing request; the route consuming it is outside this view.
class OptimizePromptRequest(BaseModel):
    text: str
    # Optionally pass the variant keys currently selected in the UI; the polishing meta-prompt
    # folds in that model's traits so the LLM knows the downstream tool's preferences
    # (yielding prompts better suited to seedream / seedance, etc.).
    image_model: str = ""
    video_model: str = ""
|
||
|
||
|
||
# Body of a file-delete request; the route consuming it is outside this view.
class FileDeleteRequest(BaseModel):
    path: str  # presumably relative to user_root, like the other file requests — TODO confirm
    recursive: bool = False  # presumably required to delete a non-empty directory — TODO confirm
|
||
|
||
|
||
# Body of a file / directory rename request.
class FileRenameRequest(BaseModel):
    path: str  # directory or file being renamed, relative to user_root
    new_name: str  # new leaf name (not a path); must not contain / \ ..
|
||
|
||
|
||
# Body of a batch copy / move request (validated by _validate_transfer).
class FileTransferRequest(BaseModel):
    paths: list[str]  # one or more sources, each relative to user_root
    dest_dir: str = ""  # destination directory relative to user_root; empty -> user_root
|
||
|
||
|
||
# Body for POST /v1/auth/login (platform-key login that mints a JWT for the given user).
class LoginRequest(BaseModel):
    user_id: str  # must parse as a UUID; the route answers 400 otherwise
    platform_key: str  # shared platform secret; a mismatch yields 403
|
||
|
||
|
||
# Body for POST /v1/auth/login_password (email + password login).
class PasswordLoginRequest(BaseModel):
    email: str
    password: str
|
||
|
||
|
||
# Body for POST /v1/auth/admin/create_user (admin-issued account creation).
class AdminCreateUserRequest(BaseModel):
    email: str
    password: str
    admin_token: str  # must equal ZCBOT_ADMIN_TOKEN (auth_cfg.admin_token); otherwise 403
|
||
|
||
|
||
# ────────────────────── App 工厂 ──────────────────────
|
||
|
||
# Bundled static assets directory (``static/`` next to this module). It is mounted
# at /static when it exists and also holds the local dev SPA (dev.html).
_STATIC_DIR = Path(__file__).parent / "static"
|
||
|
||
|
||
def create_app() -> FastAPI:
|
||
# fail-fast:env 缺失直接抛,不裸跑无密
|
||
auth_cfg = AuthConfig.from_env()
|
||
require_user = make_require_user(auth_cfg)
|
||
|
||
@asynccontextmanager
|
||
async def lifespan(app: FastAPI):
|
||
broker.bind_loop(asyncio.get_running_loop())
|
||
# Skill 注册表启动时扫一次 — 文件系统静态,运行中不变;/v1/skills 直接读
|
||
from core.agent_builder import load_config, resolve_workspace
|
||
from core.paths import ROOT
|
||
from core.skills import SkillRegistry
|
||
_cfg = load_config()
|
||
app.state.skill_registry = SkillRegistry(ROOT / _cfg.get("skills_dir", "skills"))
|
||
# Stale-run reaper:上次进程 crash 留下的 "running" / "cancelling" 已无 BG 线程
|
||
# 继续,启动时标 error,让对应 task 重新可发消息(否则 gate 永挂)。
|
||
# TODO 真生产 multi-worker:换 heartbeat / lease,只 reap 自家 worker 的孤儿。
|
||
with session_scope() as s:
|
||
result = s.execute(
|
||
update(Task)
|
||
.where(Task.run_status.in_(("running", "cancelling")))
|
||
.values(
|
||
run_status="error",
|
||
run_error="server restarted before run finished",
|
||
)
|
||
)
|
||
if result.rowcount:
|
||
print(f"[startup] reaped {result.rowcount} stale active run(s)")
|
||
|
||
# Sandbox pool(§7.5):仅当 ZCBOT_SANDBOX_BACKEND=docker 时启用。
|
||
# 启动钩子:① init_pool(创建 docker network + pool 实例)② shutdown_all 清
|
||
# 前驱孤儿(上次进程留下的 zcbot-sandbox-* 容器,内存 _last_active 为空,
|
||
# 全清重启)③ 后台 reaper task,每 60s 跑 reap_idle。
|
||
sandbox_backend = os.getenv("ZCBOT_SANDBOX_BACKEND", "host").lower()
|
||
sandbox_reaper_task = None
|
||
if sandbox_backend == "docker":
|
||
from core.sandbox import init_pool
|
||
from core.sandbox.check import detect_fs_quota
|
||
workspace = resolve_workspace(None, _cfg)
|
||
user_root_base = workspace / "users"
|
||
# §7.5 #4 fs quota 探测:不阻塞启动(应用层周期扫描已有),仅打 WARN
|
||
# 提醒外部用户开放前必须升级到 xfs prjquota / ext4 project / zfs。
|
||
try:
|
||
level, msg = detect_fs_quota(user_root_base.resolve())
|
||
print(f"[startup] {'[ok]' if level == 'ok' else '[warn]'} {msg}")
|
||
except Exception as e:
|
||
print(f"[startup] [warn] fs quota detect failed: {type(e).__name__}: {e}")
|
||
try:
|
||
pool = init_pool(user_root_base)
|
||
removed = pool.shutdown_all()
|
||
if removed:
|
||
print(f"[startup] swept {len(removed)} stale sandbox container(s)")
|
||
|
||
async def _reaper() -> None:
|
||
loop = asyncio.get_running_loop()
|
||
while True:
|
||
try:
|
||
await asyncio.sleep(60)
|
||
removed = await loop.run_in_executor(None, pool.reap_idle)
|
||
if removed:
|
||
print(f"[reaper] reaped {len(removed)} idle sandbox container(s)")
|
||
except asyncio.CancelledError:
|
||
raise
|
||
except Exception as e:
|
||
print(f"[reaper] error: {type(e).__name__}: {e}")
|
||
|
||
sandbox_reaper_task = asyncio.create_task(_reaper(), name="sandbox-reaper")
|
||
app.state.sandbox_pool = pool
|
||
except Exception as e:
|
||
# ensure_network / docker CLI 不可用 → fail-fast。Stage C 协议:任一
|
||
# hardening 缺失视为部署未完成,不退化到 host(否则误以为有沙盒实则在裸跑)。
|
||
raise RuntimeError(
|
||
f"ZCBOT_SANDBOX_BACKEND=docker but sandbox init failed: {e}"
|
||
)
|
||
try:
|
||
yield
|
||
finally:
|
||
if sandbox_reaper_task is not None:
|
||
sandbox_reaper_task.cancel()
|
||
try:
|
||
await sandbox_reaper_task
|
||
except (asyncio.CancelledError, Exception):
|
||
pass
|
||
if sandbox_backend == "docker":
|
||
pool = getattr(app.state, "sandbox_pool", None)
|
||
if pool is not None:
|
||
try:
|
||
pool.shutdown_all()
|
||
except Exception as e:
|
||
print(f"[shutdown] sandbox shutdown_all error: {type(e).__name__}: {e}")
|
||
|
||
app = FastAPI(
|
||
title="zcbot api",
|
||
version="0.8",
|
||
description=(
|
||
"zcbot 后端 — /v1 JSON API + SSE。Auth: PLATFORM_KEY → JWT(§7 D' 过渡)。"
|
||
"本地 dev SPA: /static/dev.html。"
|
||
),
|
||
lifespan=lifespan,
|
||
)
|
||
app.add_middleware(
|
||
CORSMiddleware,
|
||
allow_origins=["*"], # 本地宽松,部署 platform 时按域名收紧
|
||
allow_credentials=False,
|
||
allow_methods=["*"],
|
||
allow_headers=["*"],
|
||
)
|
||
|
||
if _STATIC_DIR.is_dir():
|
||
app.mount("/static", StaticFiles(directory=str(_STATIC_DIR)), name="static")
|
||
|
||
# ───────────── Misc ─────────────
|
||
|
||
@app.get("/", include_in_schema=False)
|
||
def root():
|
||
# 本地 dev SPA;Swagger UI 仍在 /docs
|
||
return RedirectResponse(url="/static/dev.html", status_code=302)
|
||
|
||
@app.get("/healthz", tags=["misc"])
|
||
def healthz():
|
||
return {"status": "ok"}
|
||
|
||
@app.get("/v1/models", tags=["misc"])
|
||
def list_models(user_id: UUID = Depends(require_user)):
|
||
"""列出所有可用 LLM 模型(扫 config/models/*.yaml)。
|
||
|
||
前端顶栏 / 新建对话框的模型下拉拉这个。is_default 标记 cfg["default_model"]
|
||
命中项。开发期不缓存,每次扫一遍(几个文件 IO);改 YAML 立即生效。
|
||
"""
|
||
from core.agent_builder import load_config
|
||
from core.capabilities import ModelCapabilities
|
||
from core.paths import ROOT
|
||
import yaml as _yaml
|
||
cfg = load_config()
|
||
default = cfg["default_model"]
|
||
models_dir = ROOT / cfg["models_dir"]
|
||
|
||
out: list[dict] = []
|
||
if models_dir.is_dir():
|
||
for path in sorted(models_dir.glob("*.yaml")):
|
||
try:
|
||
data = _yaml.safe_load(path.read_text(encoding="utf-8")) or {}
|
||
except Exception:
|
||
continue
|
||
family = data.get("family") or path.stem
|
||
for variant in (data.get("variants") or {}).keys():
|
||
profile = f"{family}.{variant}"
|
||
try:
|
||
caps = ModelCapabilities.load(profile, models_dir)
|
||
except (ValueError, FileNotFoundError):
|
||
continue
|
||
out.append({
|
||
"profile": profile,
|
||
"display_name": caps.display_name or profile,
|
||
"family": caps.family,
|
||
"variant": caps.variant,
|
||
"thinking_mode": caps.thinking_mode,
|
||
"is_default": profile == default,
|
||
})
|
||
return {"models": out}
|
||
|
||
@app.get("/v1/image_models", tags=["misc"])
|
||
def list_image_models(user_id: UUID = Depends(require_user)):
|
||
"""图像生成模型清单(扫 config/media/doubao.yaml image 段)。
|
||
|
||
前端顶栏第二个下拉拉这个;空列表 → 没配 image variant,UI 隐藏下拉。
|
||
`is_default` 标第一个 variant(=agent_builder fallback 目标)。开发期不缓存,
|
||
改 YAML 加新 variant 立即生效。
|
||
"""
|
||
variants = _list_image_variants()
|
||
out: list[dict] = []
|
||
for i, (key, cfg) in enumerate(variants):
|
||
out.append({
|
||
"variant": key,
|
||
"display_name": cfg.get("display_name") or key,
|
||
"model_id": cfg.get("model_id") or "",
|
||
"price_cny_per_image": cfg.get("price_cny_per_image"),
|
||
"is_default": i == 0,
|
||
})
|
||
return {"models": out}
|
||
|
||
@app.get("/v1/video_models", tags=["misc"])
|
||
def list_video_models(user_id: UUID = Depends(require_user)):
|
||
"""视频生成模型清单(扫 config/media/doubao.yaml video 段)。
|
||
|
||
与 /v1/image_models 同范式;空列表 → UI 隐藏第三下拉。展示信息包括默认分辨率
|
||
与 token 单价(¥/Mtok 文生视频路径),方便用户在下拉选项里直接看到 cost 量级。
|
||
"""
|
||
variants = _list_video_variants()
|
||
out: list[dict] = []
|
||
for i, (key, cfg) in enumerate(variants):
|
||
out.append({
|
||
"variant": key,
|
||
"display_name": cfg.get("display_name") or key,
|
||
"model_id": cfg.get("model_id") or "",
|
||
"default_resolution": cfg.get("default_resolution"),
|
||
"default_duration": cfg.get("default_duration"),
|
||
"default_ratio": cfg.get("default_ratio"),
|
||
"price_cny_per_mtoken_text2video": cfg.get("price_cny_per_mtoken_text2video"),
|
||
"price_cny_per_mtoken_video2video": cfg.get("price_cny_per_mtoken_video2video"),
|
||
"is_default": i == 0,
|
||
})
|
||
return {"models": out}
|
||
|
||
# ───────────── Auth ─────────────
|
||
|
||
@app.post("/v1/auth/login", tags=["auth"])
|
||
def login(body: LoginRequest):
|
||
"""platform_key 校验通过 → 签 JWT(user_id 作为 sub)。
|
||
|
||
platform_key 错 → 403;user_id 非 UUID → 400。
|
||
user_id 未存在则幂等创建 users 行(避免下游 FK 失败)。
|
||
platform 服务端用此入口注入指定 user_id;dev SPA 走 /login_password。
|
||
"""
|
||
if body.platform_key != auth_cfg.platform_key:
|
||
raise HTTPException(403, "invalid platform_key")
|
||
try:
|
||
uid = UUID(body.user_id)
|
||
except (ValueError, TypeError):
|
||
raise HTTPException(400, f"invalid user_id (must be UUID): {body.user_id!r}")
|
||
ensure_user_row(uid)
|
||
token, exp = mint_token(auth_cfg, uid)
|
||
return {
|
||
"token": token,
|
||
"expires_at": _dt.fromtimestamp(exp).isoformat(),
|
||
"user_id": str(uid),
|
||
"ttl_seconds": auth_cfg.ttl_seconds,
|
||
}
|
||
|
||
@app.post("/v1/auth/admin/create_user", tags=["auth"])
|
||
def admin_create_user(body: AdminCreateUserRequest):
|
||
"""管理员发用户(dev SPA 登录页右下角入口)。
|
||
|
||
- `ZCBOT_ADMIN_TOKEN` env 未设 → 503,功能关闭
|
||
- `admin_token` 不匹配 → 403(不细分 "未设" / "错了",防探测)
|
||
- email 不合法 / password 太短 → 400
|
||
- email 已存在 → 409
|
||
- 成功 → `{"user_id": ..., "email": ...}`,前端提示 "已创建,请登录"
|
||
|
||
不签 token、不自动登录 —— 管理员发完用户用户自己登,逻辑清晰。
|
||
"""
|
||
if auth_cfg.admin_token is None:
|
||
raise HTTPException(503, "admin create_user disabled (ZCBOT_ADMIN_TOKEN not set)")
|
||
if body.admin_token != auth_cfg.admin_token:
|
||
raise HTTPException(403, "invalid admin_token")
|
||
try:
|
||
uid, email = create_user(email=body.email, password=body.password)
|
||
except UserCreateError as ex:
|
||
if ex.code in ("invalid_email", "weak_password"):
|
||
raise HTTPException(400, ex.message)
|
||
if ex.code == "email_taken":
|
||
raise HTTPException(409, "email already exists")
|
||
raise HTTPException(500, f"create_user failed: {ex.message}")
|
||
return {"user_id": str(uid), "email": email}
|
||
|
||
@app.post("/v1/auth/login_password", tags=["auth"])
|
||
def login_password(body: PasswordLoginRequest):
|
||
"""邮箱密码登录(dev SPA 给同事 / 自己试用)。
|
||
|
||
- users.email 未命中 / password_hash 为空 / bcrypt 校验失败 → 一律 403
|
||
(不细分错因,防探测用户存在性)
|
||
- 命中 → 直接用 DB 里现成 user_id 签 JWT(不 ensure_user_row,行已在 `user add` 时建)
|
||
- 发用户:`.venv/Scripts/python.exe main.py user add --email X --password Y`;
|
||
撤用户:`DELETE FROM users WHERE email=...`(先 DELETE 该 user 的 tasks)
|
||
"""
|
||
hit = resolve_user_by_email(body.email, body.password)
|
||
if hit is None:
|
||
raise HTTPException(403, "invalid email or password")
|
||
uid, email = hit
|
||
token, exp = mint_token(auth_cfg, uid)
|
||
return {
|
||
"token": token,
|
||
"expires_at": _dt.fromtimestamp(exp).isoformat(),
|
||
"user_id": str(uid),
|
||
"email": email,
|
||
"ttl_seconds": auth_cfg.ttl_seconds,
|
||
}
|
||
|
||
# ───────────── Tasks CRUD ─────────────
|
||
|
||
@app.post("/v1/tasks", status_code=201, tags=["tasks"])
|
||
def create_task(body: TaskCreateRequest, user_id: UUID = Depends(require_user)):
|
||
"""新建 task。
|
||
|
||
- `name` 必填(任务显示名,DB 列 NOT NULL,UI 列表 / 标题用)
|
||
- `working_dir` 可选(留空 → 用 name 作目录名);同 working_dir 多 task 共享同目录(§7.1)
|
||
- name / working_dir 都过 validate_task_name(简单名,无 `/\\..`,非 `.` 起头,≤255)
|
||
- 前缀嵌套(no-subtask,同 user 内)→ 409
|
||
"""
|
||
from core.agent_builder import InvalidTaskName, resolve_workspace, validate_task_name, working_dir_from_name
|
||
try:
|
||
name = validate_task_name(body.name)
|
||
except InvalidTaskName as e:
|
||
raise HTTPException(400, f"name 不合法: {e}")
|
||
# working_dir 留空 → fallback 用 name
|
||
wd_raw = (body.working_dir or "").strip()
|
||
wd_name = wd_raw if wd_raw else name
|
||
try:
|
||
wd_name = validate_task_name(wd_name)
|
||
except InvalidTaskName as e:
|
||
raise HTTPException(400, f"working_dir 不合法: {e}")
|
||
description = body.description.strip()
|
||
skill = body.skill.strip()
|
||
|
||
tid = uuid4()
|
||
ws = resolve_workspace(None)
|
||
fs_dir = working_dir_from_name(ws, user_id, wd_name)
|
||
fs_dir_db = to_db_path(fs_dir)
|
||
|
||
try:
|
||
check_no_subtask(fs_dir_db, user_id=user_id)
|
||
except NoSubtaskError as e:
|
||
raise HTTPException(409, str(e))
|
||
|
||
# 工作目录立刻建出(同 working_dir 多 task 共享,exist_ok=True)
|
||
fs_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
profile, model_id = _resolve_model_profile(body.model_profile)
|
||
ensure_local_task_row(
|
||
task_id=tid, name=name, working_dir=fs_dir_db, skill=skill,
|
||
description=description, user_id=user_id,
|
||
model=model_id, model_profile=profile,
|
||
)
|
||
with session_scope() as s:
|
||
row = s.execute(select(Task).where(Task.task_id == tid)).scalar_one()
|
||
return _task_dict(row, n_messages=0)
|
||
|
||
@app.get("/v1/tasks", tags=["tasks"])
|
||
def list_tasks_route(
|
||
page: int = 1,
|
||
page_size: int = 20,
|
||
status: Optional[str] = None,
|
||
skill: Optional[str] = None,
|
||
working_dir: Optional[str] = None,
|
||
q: Optional[str] = None,
|
||
ordering: Optional[str] = None,
|
||
run_status: Optional[str] = None,
|
||
user_id: UUID = Depends(require_user),
|
||
):
|
||
"""列出当前 user 的 task,分页 + 多维筛选 + 排序。
|
||
|
||
- `page` ≥ 1(1-based);`page_size` 1–100(超界 clamp)
|
||
- `status` 在 active/completed/abandoned;非法值静默忽略
|
||
- `skill` 精确匹配(空忽略)
|
||
- `working_dir` 末段目录名(如 `水泥申报`);后端自动拼 `workspace/users/<uid>/<name>` 比对
|
||
- `q` 模糊搜索 name + description(ILIKE,大小写不敏感)
|
||
- `run_status` 逗号分隔,allowlist `idle/running/cancelling/error`;非法值静默忽略
|
||
(dev SPA 拉同 wd 活跃 task 用,通常 `running,cancelling`)
|
||
- `ordering` DRF 风格,逗号分隔,`-field` 倒序;allowlist `created_at/updated_at/name/status`;
|
||
非法字段静默忽略;**默认 `-created_at`**(创建时间倒序)
|
||
返回标准分页壳 `{page, page_size, count, results}` —— count 供前端算总页数。
|
||
"""
|
||
# clamp + sanitize
|
||
page = max(1, page)
|
||
page_size = max(1, min(page_size, 100))
|
||
status = status if status in STATUS_FILTERS else None
|
||
skill = (skill or "").strip() or None
|
||
wd_name = (working_dir or "").strip() or None
|
||
q_text = (q or "").strip() or None
|
||
rs_allowed = ("idle", "running", "cancelling", "error")
|
||
run_status_set = {
|
||
s.strip() for s in (run_status or "").split(",") if s.strip() in rs_allowed
|
||
} or None
|
||
|
||
# 组装 WHERE
|
||
conditions = [Task.user_id == user_id]
|
||
if status:
|
||
conditions.append(Task.status == status)
|
||
if skill:
|
||
conditions.append(Task.skill == skill)
|
||
if wd_name:
|
||
# 末段 → 完整 db form。同 working_dir 多 task 共享时,这是命中入口。
|
||
wd_db = f"workspace/users/{user_id}/{wd_name}"
|
||
conditions.append(Task.working_dir == wd_db)
|
||
if run_status_set:
|
||
conditions.append(Task.run_status.in_(run_status_set))
|
||
if q_text:
|
||
pat = f"%{q_text}%"
|
||
conditions.append(Task.name.ilike(pat) | Task.description.ilike(pat))
|
||
|
||
offset = (page - 1) * page_size
|
||
|
||
with session_scope() as s:
|
||
cnt = s.execute(
|
||
select(func.count()).select_from(Task).where(*conditions)
|
||
).scalar_one() or 0
|
||
|
||
rows = s.execute(
|
||
select(Task).where(*conditions)
|
||
.order_by(*_parse_ordering(ordering))
|
||
.limit(page_size).offset(offset)
|
||
).scalars().all()
|
||
|
||
tids = [r.task_id for r in rows]
|
||
msg_counts = (
|
||
dict(s.execute(
|
||
select(Message.task_id, func.count())
|
||
.where(Message.task_id.in_(tids))
|
||
.group_by(Message.task_id)
|
||
).all())
|
||
if tids else {}
|
||
)
|
||
|
||
return {
|
||
"page": page,
|
||
"page_size": page_size,
|
||
"count": int(cnt),
|
||
"results": [
|
||
_task_dict(r, n_messages=msg_counts.get(r.task_id, 0))
|
||
for r in rows
|
||
],
|
||
}
|
||
|
||
@app.get("/v1/tasks/{task_id}", tags=["tasks"])
|
||
def get_task(task_id: str, user_id: UUID = Depends(require_user)):
|
||
"""单 task meta(不含 messages;走 /messages 拿)。跨 user → 404。"""
|
||
try:
|
||
tid = UUID(task_id)
|
||
except ValueError:
|
||
raise HTTPException(404, f"invalid task id: {task_id!r}")
|
||
with session_scope() as s:
|
||
row = s.execute(
|
||
select(Task).where(Task.task_id == tid, Task.user_id == user_id)
|
||
).scalar_one_or_none()
|
||
if row is None:
|
||
raise HTTPException(404, f"task not found: {tid}")
|
||
n = s.execute(
|
||
select(func.count()).select_from(Message).where(Message.task_id == tid)
|
||
).scalar_one()
|
||
return _task_dict(row, n_messages=n)
|
||
|
||
@app.get("/v1/folders", tags=["folders"])
def list_folders(user_id: UUID = Depends(require_user)):
    """List the caller's working folders (non-dot subdirectories of the user root).

    Used by the "new task" UI for autocompletion / picking an existing folder.
    The filesystem is the source of truth, so folders created by hand that
    have no task yet are included. Each entry carries ``n_tasks`` (number of
    tasks whose ``working_dir`` is that folder) and ``last_used`` (ISO form of
    the newest ``Task.updated_at``, or ``None``).

    Ordering: entries with ``last_used`` come first, newest first; entries
    without it go last; ties are broken by name ascending.
    """
    from core.agent_builder import resolve_workspace, user_root

    ws = resolve_workspace(None)
    root = user_root(ws, user_id)

    folder_names: list[str] = []
    if root.is_dir():
        folder_names = sorted(
            (p.name for p in root.iterdir()
             if p.is_dir() and not p.name.startswith(".")),
            key=lambda n: n.lower(),
        )
    if not folder_names:
        return {"folders": []}

    # Task.working_dir holds the canonical DB form. Build it with to_db_path,
    # as the delete/rename/move file endpoints do, rather than hard-coding
    # the "workspace/users/<uid>/" prefix.
    name_by_db_form = {to_db_path(root / name): name for name in folder_names}

    # One grouped query instead of one query per folder.
    with session_scope() as s:
        stats = s.execute(
            select(Task.working_dir, func.count(), func.max(Task.updated_at))
            .where(
                Task.user_id == user_id,
                Task.working_dir.in_(list(name_by_db_form)),
            )
            .group_by(Task.working_dir)
        ).all()
    per_dir = {wd: (int(n or 0), lu) for wd, n, lu in stats}

    folders: list[dict] = []
    for db_form, name in name_by_db_form.items():
        n_tasks, last_used = per_dir.get(db_form, (0, None))
        folders.append({
            "name": name,
            "n_tasks": n_tasks,
            "last_used": _iso(last_used),
        })

    # Two stable sorts: name asc first, then last_used desc ("" = never used sorts last).
    folders.sort(key=lambda f: f["name"])
    folders.sort(key=lambda f: f["last_used"] or "", reverse=True)
    return {"folders": folders}
|
||
|
||
# ───────────── Skills ─────────────
|
||
|
||
@app.get("/v1/skills", tags=["skills"])
def list_skills(user_id: UUID = Depends(require_user)):
    """List the available skills (agent types) for the "new task" dropdown.

    The registry is read from ``app.state.skill_registry``; if it is missing
    or falsy, an empty list is returned. Each entry exposes only ``name`` and
    ``description``, in the registry's own iteration order.
    """
    registry = getattr(app.state, "skill_registry", None)
    if not registry:
        return {"skills": []}
    return {
        "skills": [
            {"name": skill.name, "description": skill.description}
            for skill in registry.skills.values()
        ]
    }
|
||
|
||
@app.delete("/v1/tasks/{task_id}", status_code=204, tags=["tasks"])
def delete_task(task_id: str, user_id: UUID = Depends(require_user)):
    """Hard-delete a task row (messages / usage_events go with it via CASCADE).

    - Malformed id or a task owned by another user → 404.
    - Task with an active run (``running`` / ``cancelling``) → 409. The
      background thread would otherwise keep writing for a row that no
      longer exists; cancel first (same rule as ``/clear``).
    - If no other task of this user still references the same ``working_dir``
      and it is a ROOT-relative path, a best-effort ``rmdir`` removes the
      orphaned directory after the DB commit. Non-empty / missing /
      permission errors are ignored, because the FS directory is treated as a
      regenerable view. Absolute paths (external ``--working-dir``) are never
      touched.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    from sqlalchemy import delete as _delete

    from core.paths import from_db_path

    with session_scope() as s:
        row = s.execute(
            select(Task.working_dir, Task.run_status)
            .where(Task.task_id == tid, Task.user_id == user_id)
            .with_for_update()
        ).first()
        if row is None:
            raise HTTPException(404, f"task not found: {tid}")
        if row.run_status in ("running", "cancelling"):
            raise HTTPException(
                409,
                f"task has an active run (status={row.run_status}); "
                f"cancel it first",
            )
        wd_db = row.working_dir
        s.execute(
            _delete(Task).where(Task.task_id == tid, Task.user_id == user_id)
        )
        remaining = s.execute(
            select(func.count()).select_from(Task).where(
                Task.user_id == user_id, Task.working_dir == wd_db,
            )
        ).scalar_one() or 0

    # Only after the delete committed: a failed commit must not leave the task
    # row pointing at a directory that was already removed.
    if wd_db and not remaining and not Path(wd_db).is_absolute():
        try:
            from_db_path(wd_db).rmdir()
        except OSError:
            pass  # non-empty / missing / no permission: deliberately best-effort
    return None  # 204
|
||
|
||
@app.patch("/v1/tasks/{task_id}", tags=["tasks"])
def patch_task(
    task_id: str,
    body: TaskPatchRequest,
    user_id: UUID = Depends(require_user),
):
    """Update selected task fields and return the refreshed task.

    ``status`` accepts only the values in ``STATUS_WRITABLE`` (switching a
    task back to ``active`` is done from the CLI). ``name`` goes through
    ``validate_task_name``. ``model_profile`` is resolved to a
    ``(profile, model_id)`` pair and both columns are written together; the
    new model takes effect on the next send, not on an in-flight run.

    Raises 404 for a malformed / unknown / foreign task id, and 400 for an
    invalid status or name, or when no field was supplied.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    changes: dict[str, Any] = {}

    new_status = body.status
    if new_status is not None:
        if new_status not in STATUS_WRITABLE:
            raise HTTPException(
                400, f"invalid status {new_status!r}; allowed: {STATUS_WRITABLE}"
            )
        changes["status"] = new_status

    # Free-form text columns are copied through unchanged.
    for column in ("description", "skill"):
        value = getattr(body, column)
        if value is not None:
            changes[column] = value

    if body.name is not None:
        from core.agent_builder import InvalidTaskName, validate_task_name

        try:
            changes["name"] = validate_task_name(body.name)
        except InvalidTaskName as e:
            raise HTTPException(400, f"name 不合法: {e}")

    if body.model_profile is not None:
        # Profile and concrete model id always move together.
        profile, model_id = _resolve_model_profile(body.model_profile)
        changes.update(model_profile=profile, model=model_id)

    if not changes:
        raise HTTPException(400, "no fields to update")

    stmt = (
        update(Task)
        .where(Task.task_id == tid, Task.user_id == user_id)
        .values(**changes)
    )
    n_query = select(func.count()).select_from(Message).where(Message.task_id == tid)
    with session_scope() as s:
        if s.execute(stmt).rowcount == 0:
            raise HTTPException(404, f"task not found: {tid}")
        task_row = s.execute(select(Task).where(Task.task_id == tid)).scalar_one()
        n_messages = s.execute(n_query).scalar_one()
        return _task_dict(task_row, n_messages=n_messages)
|
||
|
||
# ───────────── Messages ─────────────
|
||
|
||
def _assert_owns_task(s, tid: UUID, user_id: UUID) -> None:
    """Raise HTTP 404 unless task ``tid`` exists and belongs to ``user_id``.

    ``s`` is an open session; only the primary key is selected.
    """
    owned = s.execute(
        select(Task.task_id).where(Task.task_id == tid, Task.user_id == user_id)
    ).first()
    if owned is not None:
        return
    raise HTTPException(404, f"task not found: {tid}")
|
||
|
||
@app.get("/v1/tasks/{task_id}/messages", tags=["messages"])
def list_messages(task_id: str, user_id: UUID = Depends(require_user)):
    """Return the task's message history ordered by ``idx`` ascending.

    The stored LiteLLM ``payload`` is passed through as-is for the frontend to
    render. ``model_profile`` is the per-message model column.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    def _serialize(r) -> dict:
        # One JSON-friendly dict per message row.
        return {
            "idx": r.idx,
            "payload": dict(r.payload),
            "tokens_in": r.tokens_in,
            "tokens_out": r.tokens_out,
            "model_profile": r.model_profile,
            "created_at": _iso(r.created_at),
        }

    history = (
        select(
            Message.idx, Message.payload, Message.tokens_in,
            Message.tokens_out, Message.model_profile, Message.created_at,
        )
        .where(Message.task_id == tid)
        .order_by(Message.idx)
    )
    with session_scope() as s:
        _assert_owns_task(s, tid, user_id)
        rows = s.execute(history).all()
        return {"messages": [_serialize(r) for r in rows]}
|
||
|
||
@app.post("/v1/tasks/{task_id}/messages", status_code=202, tags=["messages"])
async def post_message(
    task_id: str,
    body: MessageRequest,
    user_id: UUID = Depends(require_user),
):
    """Post a user message and start a background agent run.

    Returns ``{"events_url": ...}``; the client should subscribe to that SSE
    stream immediately to receive streamed output.

    Single active run per task: ``SELECT … FOR UPDATE`` on the task row, the
    active-status check and the switch to ``running`` all happen in one
    transaction. This prevents a double-clicked "send" from starting two
    background threads racing for ``messages.idx``. A ``run_status`` of
    ``running`` or ``cancelling`` → 409. ``error`` is cleared when a new run
    starts, so it is restartable like ok / cancelled.

    ``image_model`` / ``video_model`` are validated before the task is marked
    ``running``. A validation error raised after the commit would leave the
    task stuck in ``running`` with no background thread to finish it, and
    every later send would get 409.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    content = (body.content or "").strip()
    if not content:
        raise HTTPException(400, "empty content")

    # Validate in the request, not in the BG thread (where a failure would
    # surface outside the sink and be hard to trace). Empty strings are
    # passed through without a yaml lookup.
    image_variant = _resolve_image_model(body.image_model)
    video_variant = _resolve_video_model(body.video_model)

    with session_scope() as s:
        row = s.execute(
            select(Task.run_status)
            .where(Task.task_id == tid, Task.user_id == user_id)
            .with_for_update()
        ).first()
        if row is None:
            raise HTTPException(404, f"task not found: {tid}")
        if row.run_status in ("running", "cancelling"):
            raise HTTPException(
                409,
                f"task already has an active run (status={row.run_status}); "
                f"wait for it to finish or cancel",
            )
        s.execute(
            update(Task).where(Task.task_id == tid).values(
                run_status="running", run_error=None,
            )
        )

    broker.start(tid)  # reset the previous round's done marker so new subscribers see the stream
    # The commit above released the row lock. The BG thread takes over; its
    # sink bridges events back to the asyncio loop through the broker.
    run = asyncio.create_task(asyncio.to_thread(
        _run_agent_bg, tid, user_id, content, image_variant, video_variant,
    ))
    # The event loop keeps only weak references to tasks; hold a strong one
    # until the run finishes so it cannot be garbage-collected mid-flight.
    pending = getattr(app.state, "bg_runs", None)
    if pending is None:
        pending = set()
        app.state.bg_runs = pending
    pending.add(run)
    run.add_done_callback(pending.discard)
    return {"events_url": f"/v1/tasks/{tid}/events"}
|
||
|
||
# ───────────── Cancel current run ─────────────
|
||
|
||
@app.post("/v1/tasks/{task_id}/cancel", status_code=202, tags=["tasks"])
def cancel_task(
    task_id: str,
    user_id: UUID = Depends(require_user),
):
    """Send a cooperative cancel signal to the task's active run.

    - With one active run per task, "cancel the current run" is unambiguous;
      the client only supplies the task id.
    - The task must belong to the caller, otherwise 404.
    - ``run_status`` other than ``running`` → 409 (idle / cancelling / error
      cannot be cancelled).
    - Sets the transitional state ``cancelling`` and notifies the broker.
      Finalising the terminal state is left to the background run.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    with session_scope() as s:
        current = s.execute(
            select(Task.run_status)
            .where(Task.task_id == tid, Task.user_id == user_id)
            .with_for_update()
        ).first()
        if current is None:
            raise HTTPException(404, f"task not found: {tid}")
        status = current.run_status
        if status != "running":
            raise HTTPException(
                409,
                f"task not running (run_status={status}); cannot cancel",
            )
        s.execute(
            update(Task).where(Task.task_id == tid).values(run_status="cancelling")
        )
    broker.request_cancel(tid)
    return {"ok": True, "task_id": str(tid), "run_status": "cancelling"}
|
||
|
||
# ───────────── Clear conversation ─────────────
|
||
|
||
@app.post("/v1/tasks/{task_id}/clear", tags=["messages"])
def clear_messages(task_id: str, user_id: UUID = Depends(require_user)):
    """Delete all of a task's messages and zero its token / cost counters.

    Files under the task's working directory are left alone, so intermediate
    artefacts survive and a fresh conversation can build on them.
    ``usage_events`` are not touched either; they are account-level usage
    records.

    - Active run (``running`` / ``cancelling``) → 409; cancel first.
    - An ``error`` state can be cleared: ``run_status`` becomes ``idle`` and
      ``run_error`` is reset to ``None``.
    - Task of another user → 404.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    from sqlalchemy import delete as _delete

    reset_values = dict(
        tokens_prompt=0,
        tokens_completion=0,
        cost_cny=0,
        run_status="idle",
        run_error=None,
    )
    with session_scope() as s:
        locked = s.execute(
            select(Task.run_status)
            .where(Task.task_id == tid, Task.user_id == user_id)
            .with_for_update()
        ).first()
        if locked is None:
            raise HTTPException(404, f"task not found: {tid}")
        if locked.run_status in ("running", "cancelling"):
            raise HTTPException(
                409,
                f"task has an active run (status={locked.run_status}); "
                f"cancel it first",
            )
        s.execute(_delete(Message).where(Message.task_id == tid))
        s.execute(update(Task).where(Task.task_id == tid).values(**reset_values))
        refreshed = s.execute(select(Task).where(Task.task_id == tid)).scalar_one()
        return _task_dict(refreshed, n_messages=0)
|
||
|
||
# ───────────── Prompt optimize(辅助 LLM 调用,不入 messages)─────────────
|
||
|
||
@app.post("/v1/tasks/{task_id}/optimize_prompt", tags=["messages"])
def optimize_prompt(
    task_id: str,
    body: OptimizePromptRequest,
    user_id: UUID = Depends(require_user),
):
    """Polish the user's draft prompt with the task's current chat model.

    - Synchronous, non-streaming call (short text).
    - Nothing is written to ``messages``, and the task's
      ``tokens_prompt`` / ``tokens_completion`` are left untouched. Instead a
      separate ``UsageEvent`` row with ``kind="prompt_optimize"`` is recorded.
    - Not mutually exclusive with an active run, since it writes no messages
      and so cannot contend for ``messages.idx``.
    - ``image_model`` / ``video_model`` only shape hints inside the
      meta-prompt; nothing is persisted to the task.

    Raises:
        HTTPException 404: malformed id, or task not owned by the caller.
        HTTPException 400: empty text, or text longer than 4000 characters.
        HTTPException 500: the model profile cannot be loaded.
        HTTPException 502: the LLM call failed or returned no usable content.
    """
    from decimal import Decimal

    from core.agent_builder import load_config
    from core.capabilities import ModelCapabilities
    from core.llm import LLM
    from core.paths import ROOT
    from core.storage.models import UsageEvent
    from core.storage.usage import USD_TO_CNY

    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    text = (body.text or "").strip()
    if not text:
        raise HTTPException(400, "empty text")
    if len(text) > 4000:
        raise HTTPException(400, "text too long (>4000 chars)")

    with session_scope() as s:
        row = s.execute(
            select(Task.model_profile)
            .where(Task.task_id == tid, Task.user_id == user_id)
        ).first()
        if row is None:
            raise HTTPException(404, f"task not found: {tid}")
        task_model_profile = row.model_profile or ""

    cfg = load_config()
    chosen_profile = task_model_profile or cfg["default_model"]
    try:
        caps = ModelCapabilities.load(chosen_profile, ROOT / cfg["models_dir"])
    except (FileNotFoundError, ValueError) as e:
        raise HTTPException(
            500, f"invalid task model_profile {chosen_profile!r}: {e}"
        ) from e

    # Downstream-tool context: chat model display name plus metadata of the
    # currently selected image / video variants.
    chat_model_display = caps.display_name or chosen_profile

    # A single dict lookup replaces a full scan without `break`. dict() keeps
    # the last entry for a repeated key, which matches the old loop.
    image_variant_hint = ""
    img_variant = (body.image_model or "").strip()
    img_cfg = dict(_list_image_variants()).get(img_variant) if img_variant else None
    if img_cfg is not None:
        name = img_cfg.get("display_name") or img_variant
        sz = img_cfg.get("default_size") or "2048x2048"
        image_variant_hint = (
            f"\n下游生图工具:{name}(默认尺寸 {sz},支持中英文 prompt,"
            f"擅长写实/插画/构图描述)。若用户意图涉及画面/封面/插图,"
            f"润色后的文本要给出适合该模型的画面细节(主体/风格/光线/构图)。"
        )

    video_variant_hint = ""
    vid_variant = (body.video_model or "").strip()
    vid_cfg = dict(_list_video_variants()).get(vid_variant) if vid_variant else None
    if vid_cfg is not None:
        name = vid_cfg.get("display_name") or vid_variant
        res = vid_cfg.get("default_resolution") or "720p"
        dur = vid_cfg.get("default_duration") or 5
        video_variant_hint = (
            f"\n下游生视频工具:{name}(默认 {res} / {dur}s / 16:9)。"
            f"若用户意图涉及视频/动画/动起来,润色后的文本要补全 "
            f"主体在做什么(运动)+ 镜头怎么动 + 场景 + 风格,而非静态画面描述。"
        )

    meta_prompt = (
        f"你的任务是润色用户输入的草稿,使之成为一个清晰、完整、可执行的 prompt。\n"
        f"当前对话模型:{chat_model_display}。{image_variant_hint}{video_variant_hint}\n\n"
        f"规则:\n"
        f"1. 只输出润色后的文本本身,不要任何解释、前后缀、引号、markdown 代码块包裹\n"
        f"2. 保留用户原始语言(中文/英文)\n"
        f"3. 补全模糊点(主体、目标、风格、约束),但不要无中生有改变用户意图\n"
        f"4. 长度合理 — 简短诉求润色后也应当简洁,不要堆砌\n\n"
        f"用户草稿:\n{text}"
    )

    llm = LLM(caps)
    try:
        response = llm.chat(
            messages=[{"role": "user", "content": meta_prompt}],
            tools=None,
        )
    except Exception as e:
        # Any provider/transport failure is surfaced as a 502 with the cause chained.
        raise HTTPException(502, f"llm call failed: {type(e).__name__}: {e}") from e

    try:
        optimized = (response.choices[0].message.content or "").strip()
    except (AttributeError, IndexError, KeyError, TypeError) as e:
        raise HTTPException(502, "llm response missing content") from e
    if not optimized:
        raise HTTPException(502, "llm returned empty optimization")

    usage = getattr(response, "usage", None)
    prompt_tokens = int(getattr(usage, "prompt_tokens", 0) or 0)
    completion_tokens = int(getattr(usage, "completion_tokens", 0) or 0)
    try:
        from litellm import completion_cost

        cost_usd_raw = completion_cost(completion_response=response)
        cost_usd = Decimal(str(cost_usd_raw)) if cost_usd_raw else Decimal("0")
    except Exception:
        # Pricing is best-effort: an unpriced model or failure just records 0.
        cost_usd = Decimal("0")
    cost_cny = (cost_usd * USD_TO_CNY).quantize(Decimal("0.000001"))

    try:
        with session_scope() as s:
            s.add(UsageEvent(
                user_id=user_id,
                task_id=tid,
                message_id=None,
                kind="prompt_optimize",
                model_profile=chosen_profile,
                units={
                    "tokens_in": prompt_tokens,
                    "tokens_out": completion_tokens,
                    "usd_to_cny": float(USD_TO_CNY),
                    "image_model_hint": img_variant or "",
                    "video_model_hint": vid_variant or "",
                },
                cost_cny=cost_cny,
            ))
    except Exception as e:
        # A bookkeeping failure must not block the result; the user needs the
        # polished text, and usage can be reconciled manually later.
        print(f"[optimize_prompt] usage record failed: {type(e).__name__}: {e}", flush=True)

    return {
        "optimized": optimized,
        "model_profile": chosen_profile,
        "tokens_in": prompt_tokens,
        "tokens_out": completion_tokens,
        "cost_cny": float(cost_cny),
    }
|
||
|
||
# ───────────── SSE events ─────────────
|
||
|
||
@app.get("/v1/tasks/{task_id}/events", tags=["tasks"])
async def stream_events(
    task_id: str,
    user_id: UUID = Depends(require_user),
):
    """SSE stream of the task's current run (unambiguous with one active run).

    Event names: run_start / llm_start / text / tool_call / tool_result /
    llm_end / cancelled / error / done. Each ``data`` line is a JSON dict with
    the ``type`` key removed (it becomes the event name). The stream closes
    after ``done`` or ``error``, and a ``: ping`` comment is sent after 30 s
    without events.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    with session_scope() as s:
        _assert_owns_task(s, tid, user_id)
        run_status = s.execute(
            select(Task.run_status).where(Task.task_id == tid)
        ).scalar_one()

    # Reconnect guard: if the task is not active (process restart / reaper
    # cleanup / natural finish), emit `done` and close at once. Otherwise the
    # in-process broker queue stays empty and the client would hang on pings.
    is_active = run_status in ("running", "cancelling")

    async def gen():
        yield b": connected\nretry: 3000\n\n"
        if not is_active:
            yield _sse_event("done", {})
            return
        q = broker.subscribe(tid)
        try:
            while True:
                try:
                    ev = await asyncio.wait_for(q.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    yield b": ping\n\n"
                    continue
                ev_type = ev.get("type", "msg")
                payload = {k: v for k, v in ev.items() if k != "type"}
                yield _sse_event(ev_type, payload)
                if ev_type in ("done", "error"):
                    break
        finally:
            # Runs on normal end, on client disconnect (CancelledError) and on
            # aclose() (GeneratorExit). Cancellation is deliberately NOT
            # swallowed, so the server's task group can finish tearing down.
            broker.unsubscribe(tid, q)

    return StreamingResponse(
        gen(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
|
||
|
||
# ───────────── Files(user-rooted,不绑 task) ─────────────
|
||
|
||
@app.get("/v1/files", tags=["files"])
def list_files(
    path: str = "",
    user_id: UUID = Depends(require_user),
):
    """List entries of a directory under the caller's user root, plus breadcrumbs.

    An empty ``path`` means the user root itself. Path validation (e.g.
    rejecting ``../`` or absolute paths) is delegated to ``_safe_join``, and
    enumeration to ``_enumerate_files``.
    """
    user_dir = _load_user_root(user_id)
    here = _safe_join(user_dir, path)
    entries, crumbs, exists = _enumerate_files(user_dir, here)
    listing = dict(
        root=_norm_path(str(user_dir)),
        current=_rel_to(user_dir, here),
        exists=exists,
        crumbs=crumbs,
        entries=entries,
    )
    return listing
|
||
|
||
@app.get("/v1/files/download", tags=["files"])
def download_file(
    path: str,
    user_id: UUID = Depends(require_user),
):
    """Download one regular file under the user root.

    Missing path → 404; existing non-file (e.g. a directory) → 400.
    """
    user_dir = _load_user_root(user_id)
    target = _safe_join(user_dir, path)
    if target.is_file():
        # Workspace files change. Disable heuristic browser caching so the
        # SPA preview never shows stale content. Starlette's FileResponse
        # does not answer 304 and always returns a full 200, which is
        # acceptable for small workspace files.
        return FileResponse(
            path=str(target),
            filename=target.name,
            headers={"Cache-Control": "no-cache"},
        )
    if target.exists():
        raise HTTPException(400, f"not a file: {path}")
    raise HTTPException(404, f"file not found: {path}")
|
||
|
||
@app.post("/v1/files/upload", tags=["files"])
async def upload_files(
    path: str = Form(""),
    files: list[UploadFile] = File(...),
    user_id: UUID = Depends(require_user),
):
    """Upload one or more multipart files into ``<user_root>/<path>/``.

    The target directory is created if needed (``parents=True``), and
    existing files with the same name are overwritten. Filenames are
    validated strictly: empty, ``.``/``..``, or containing ``/``, ``\\`` or
    dot segments → 400.

    All filenames are validated before anything is written or created, so a
    bad name never leaves a partial upload behind. File bodies are streamed
    in chunks rather than loaded fully into memory.
    """
    chunk_size = 1024 * 1024

    root = _load_user_root(user_id)
    dest_dir = _safe_join(root, path)
    if dest_dir.exists() and not dest_dir.is_dir():
        raise HTTPException(400, f"upload target is a file, not a directory: {path}")

    uploads = list(files or [])
    if not uploads:
        raise HTTPException(400, "no files uploaded")

    # Pass 1: validate every name and destination before touching the FS.
    root_resolved = root.resolve()
    planned: list[tuple[UploadFile, str, Path]] = []
    for up in uploads:
        raw_name = up.filename or ""
        if (
            not raw_name
            or raw_name in (".", "..")
            or "/" in raw_name or "\\" in raw_name
            or any(part in (".", "..") for part in Path(raw_name).parts)
        ):
            raise HTTPException(400, f"invalid filename: {raw_name!r}")
        dest = dest_dir / raw_name
        try:
            dest.resolve().relative_to(root_resolved)
        except ValueError:
            raise HTTPException(400, f"path escapes user_root: {raw_name!r}")
        planned.append((up, raw_name, dest))

    # Pass 2: create the directory and stream each upload to disk.
    dest_dir.mkdir(parents=True, exist_ok=True)
    saved: list[dict] = []
    for up, raw_name, dest in planned:
        size = 0
        with dest.open("wb") as fh:
            while chunk := await up.read(chunk_size):
                fh.write(chunk)
                size += len(chunk)
        saved.append({"name": raw_name, "size": size, "rel": _rel_to(root, dest)})
    return {"count": len(saved), "saved": saved}
|
||
|
||
@app.post("/v1/files/delete", tags=["files"])
def delete_file(
    body: FileDeleteRequest,
    user_id: UUID = Depends(require_user),
):
    """Delete a file or directory under the user root.

    - ``recursive=False`` (default): a directory must be empty (``rmdir``);
      a failure such as "not empty" → 400.
    - ``recursive=True``: ``shutil.rmtree``. If the target is a top-level
      directory referenced by some task's ``working_dir`` → 409; delete the
      task first so the DB never points at wiped artefacts.
    - Empty directories (top-level or nested) can always be removed; the
      task's ``working_dir`` column is not modified.
    - Deleting the user root itself → 400; missing path → 404.
    """
    import shutil

    user_dir = _load_user_root(user_id)
    target = _safe_join(user_dir, body.path)
    user_dir_resolved = user_dir.resolve()
    if target.resolve() == user_dir_resolved:
        raise HTTPException(400, "cannot delete user_root")
    if not target.exists():
        raise HTTPException(404, f"path not found: {body.path}")

    wipes_tree = target.is_dir() and body.recursive
    if wipes_tree and target.parent.resolve() == user_dir_resolved:
        db_form = to_db_path(target)
        with session_scope() as s:
            n_refs = s.execute(
                select(func.count()).select_from(Task).where(
                    Task.user_id == user_id,
                    Task.working_dir == db_form,
                )
            ).scalar_one() or 0
        if n_refs:
            raise HTTPException(
                409,
                f"该顶层目录正被 {n_refs} 个 task 引用,不能递归删除;"
                f"请先 DELETE task,再清残留文件",
            )

    try:
        if not target.is_dir():
            target.unlink()
        elif body.recursive:
            shutil.rmtree(target)
        else:
            target.rmdir()  # raises OSError on a non-empty directory
    except OSError as e:
        raise HTTPException(400, f"delete failed: {e}")
    return {"ok": True, "path": body.path}
|
||
|
||
@app.post("/v1/files/rename", tags=["files"])
def rename_path(
    body: FileRenameRequest,
    user_id: UUID = Depends(require_user),
):
    """Rename a file or directory at any depth under the user root.

    - ``path`` names the object to rename; it may not be the user root.
    - ``new_name`` is a new leaf name validated by ``validate_task_name``. It
      is not a path; the parent directory is kept.
    - The sibling ``<parent>/<new_name>`` must not exist yet (no overwrite).
    - Top-level directory (direct child of the user root) → DB-aware:
      * rows of the tasks using it are locked with ``FOR UPDATE``; any task
        in ``running`` / ``cancelling`` → 409;
      * ``check_no_subtask`` rejects a new path that would nest with
        another task's directory;
      * ``working_dir`` is updated first, then the FS rename runs; an FS
        failure raises, which rolls the update back.
    - Anything else (nested dir or file) → plain FS rename, DB untouched.
    """
    from core.agent_builder import InvalidTaskName, validate_task_name

    user_dir = _load_user_root(user_id)
    src = _safe_join(user_dir, body.path)
    if src.resolve() == user_dir.resolve():
        raise HTTPException(400, "cannot rename user_root")
    if not src.exists():
        raise HTTPException(404, f"path not found: {body.path}")

    try:
        leaf = validate_task_name(body.new_name)
    except InvalidTaskName as e:
        raise HTTPException(400, f"new_name 不合法: {e}")
    if leaf == src.name:
        raise HTTPException(400, f"new_name 与原名相同: {leaf!r}")

    dst = src.parent / leaf
    if dst.exists():
        raise HTTPException(
            409, f"target already exists: {_rel_to(user_dir, dst)!r}"
        )

    def _result(n_updated: int) -> dict:
        # Common response shape for both branches.
        return {
            "ok": True,
            "old": body.path,
            "new": _rel_to(user_dir, dst),
            "tasks_updated": n_updated,
        }

    top_level_dir = src.is_dir() and src.parent.resolve() == user_dir.resolve()
    if not top_level_dir:
        try:
            src.rename(dst)
        except OSError as e:
            raise HTTPException(400, f"rename failed: {e}")
        return _result(0)

    # Top-level directory: DB rows and FS move within one transaction.
    old_db, new_db = to_db_path(src), to_db_path(dst)
    with session_scope() as s:
        locked = s.execute(
            select(Task.task_id, Task.run_status)
            .where(Task.user_id == user_id, Task.working_dir == old_db)
            .with_for_update()
        ).all()
        tids = [r.task_id for r in locked]
        busy = [
            str(r.task_id)[:8]
            for r in locked
            if r.run_status in ("running", "cancelling")
        ]
        if busy:
            raise HTTPException(
                409,
                f"folder has active run(s) on task(s) {busy}; "
                f"cancel before renaming",
            )
        try:
            check_no_subtask(new_db, user_id=user_id, exclude_task_ids=tids)
        except NoSubtaskError as e:
            raise HTTPException(409, str(e))
        if tids:
            s.execute(
                update(Task)
                .where(Task.task_id.in_(tids))
                .values(working_dir=new_db)
            )
        try:
            src.rename(dst)
        except OSError as e:
            # Raising inside session_scope rolls back the UPDATE above.
            raise HTTPException(400, f"FS rename failed: {e}")
    return _result(len(tids))
|
||
|
||
@app.post("/v1/files/copy", tags=["files"])
def copy_files(
    body: FileTransferRequest,
    user_id: UUID = Depends(require_user),
):
    """Copy each source into ``dest_dir/<name>``, recursing into directories.

    Validation (no overwrite, no copying into itself or its own subtree) is
    done by ``_validate_transfer``. Top-level task directories may be copied
    because the copy has no task attached and the DB is untouched. Partial
    failure: an ``OSError`` aborts with 500 and copies already made are kept,
    since there is no FS transaction to roll back.
    """
    import shutil

    user_dir = _load_user_root(user_id)
    sources, dest = _validate_transfer(user_dir, body.paths, body.dest_dir)
    done: list[dict] = []
    for src in sources:
        target = dest / src.name
        try:
            copier = shutil.copytree if src.is_dir() else shutil.copy2
            copier(src, target)
        except OSError as e:
            raise HTTPException(
                500,
                f"copy failed at {src.name!r}: {e} "
                f"(已成功 {len(done)} 项,剩余未处理)",
            )
        done.append({
            "old": _rel_to(user_dir, src),
            "new": _rel_to(user_dir, target),
        })
    return {"ok": True, "count": len(done), "transferred": done}
|
||
|
||
@app.post("/v1/files/move", tags=["files"])
def move_files(
    body: FileTransferRequest,
    user_id: UUID = Depends(require_user),
):
    """Move each source into ``dest_dir/<name>``.

    - Same no-overwrite / no-self-nesting validation as ``/copy``.
    - A top-level directory that is some task's ``working_dir`` → 409. This
      keeps the "working_dir is a top-level directory" invariant that the
      DB-aware rename relies on. Users should delete the task first, or copy.
    - Partial failure: like ``/copy``, earlier successful moves are not
      rolled back.
    """
    import shutil

    user_dir = _load_user_root(user_id)
    sources, dest = _validate_transfer(user_dir, body.paths, body.dest_dir)

    # Guard: top-level directories currently used as a task's working_dir.
    user_dir_resolved = user_dir.resolve()
    name_by_form = {
        to_db_path(src): src.name
        for src in sources
        if src.is_dir() and src.parent.resolve() == user_dir_resolved
    }
    if name_by_form:
        with session_scope() as sess:
            in_use = sess.execute(
                select(Task.working_dir, func.count())
                .where(
                    Task.user_id == user_id,
                    Task.working_dir.in_(list(name_by_form)),
                )
                .group_by(Task.working_dir)
            ).all()
        if in_use:
            names = ", ".join(
                f"{name_by_form[wd]!r}({n} 个 task)"
                for wd, n in in_use
            )
            raise HTTPException(
                409,
                f"以下顶层目录正被 task 引用,不能移动:{names};"
                f"请先删 task,或改用复制",
            )

    done: list[dict] = []
    for src in sources:
        target = dest / src.name
        try:
            shutil.move(str(src), str(target))
        except OSError as e:
            raise HTTPException(
                500,
                f"move failed at {src.name!r}: {e} "
                f"(已成功 {len(done)} 项,剩余未处理)",
            )
        done.append({
            "old": _rel_to(user_dir, src),
            "new": _rel_to(user_dir, target),
        })
    return {"ok": True, "count": len(done), "transferred": done}
|
||
|
||
# ───────────── Export ─────────────
|
||
|
||
@app.get("/v1/tasks/{task_id}/export", tags=["export"])
def export_task(task_id: str, user_id: UUID = Depends(require_user)):
    """Export the task's conversation as a .docx download.

    The document is rendered into a temporary file, which is removed by a
    ``BackgroundTask`` after the response has been sent, or right away if
    rendering fails. A task without messages → 400.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    with session_scope() as s:
        _assert_owns_task(s, tid, user_id)
        first_message = s.execute(
            select(Message.message_id).where(Message.task_id == tid).limit(1)
        ).first()
    if first_message is None:
        raise HTTPException(400, "no messages to export")

    handle, tmp_name = tempfile.mkstemp(suffix=".docx", prefix="zcbot-export-")
    os.close(handle)  # only the path is needed; the exporter reopens it
    out_file = Path(tmp_name)
    try:
        from core.export_docx import export_chat_to_docx

        export_chat_to_docx(tid, out_path=out_file)
    except Exception as e:
        out_file.unlink(missing_ok=True)
        raise HTTPException(500, f"export failed: {type(e).__name__}: {e}")

    docx_mime = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    return FileResponse(
        path=str(out_file),
        media_type=docx_mime,
        filename=f"chat_{str(tid)[:8]}.docx",
        background=BackgroundTask(out_file.unlink, missing_ok=True),
    )
|
||
|
||
return app
|