zcbot/web/app.py

2947 lines
132 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""FastAPI app: 纯 /v1 JSON API(2026-05-15 切换 — 详见 DESIGN §7.9)。
设计要点:
- 所有路由 `/v1/*` 前缀,响应 JSON;模板 / HTMX / 服务端 markdown 渲染全删
- SSE 事件 payload 是 JSON dict 而非 HTML 片段(`event: <type>` + `data: <json>`)
- Auth: PLATFORM_KEY → JWT 兑换(§7 D' 过渡形态,见 web/auth.py);OIDC 替换时只动 /v1/auth/login 内部
- 所有 /v1/tasks* 路由 Depends(require_user),按 user_id 隔离数据
- 豁免:/healthz、/docs、/openapi.json、/、/v1/auth/login、/static/*
- CORS allow_origins=["*"] 本地宽松;真发布按 platform 域名收紧
- `GET /` 302 → /static/dev.html(本地 dev SPA)
"""
from __future__ import annotations
import asyncio
import json
import mimetypes
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime as _dt
from pathlib import Path
from typing import Any, Optional
from uuid import UUID, uuid4
try:
import resource # Unix only;Windows dev 无此模块,RSS 监控自动降级跳过
except ImportError: # pragma: no cover - Windows
resource = None
from fastapi import Depends, FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, RedirectResponse, StreamingResponse
from pydantic import BaseModel
from sqlalchemy import BigInteger, case, cast, func, select, update
from starlette.background import BackgroundTask
from core import __version__
from core.paths import from_db_path, to_db_path
from core.storage import (
NoSubtaskError,
check_no_subtask,
session_scope,
)
from core.storage.models import Message, ScheduledJob, Task, UsageEvent
from core.storage.utils import ensure_local_task_row
from .auth import (
AuthConfig,
UserCreateError,
change_password,
create_user,
ensure_user_row,
get_user_profile,
make_require_admin,
make_require_user,
mint_token,
resolve_user_by_email,
)
from .admin import register_admin_routes
from .broker import broker
from .sinks import WebEventSink
from .static_files import NoCacheStaticFiles
# Task status values; presumably the values accepted by the task-list status filter
# (the consuming routes are outside this view).
STATUS_FILTERS = ("active", "completed", "abandoned")
STATUS_WRITABLE = ("completed", "abandoned")  # the web side may not switch a task back to active (use the CLI for that)
# Task columns allowed in the DRF-style `ordering` query value (see _parse_ordering).
ORDER_FIELDS = ("created_at", "updated_at", "name", "status")
# Ordering used when `ordering` is empty (see _parse_ordering); `-` prefix = descending.
ORDER_DEFAULT = "-created_at"
# pptx→PDF 预览:按解析后的 pptx 绝对路径加锁,防同一文件并发重复转换(DESIGN §8.3)。
_pptx_preview_locks: dict[str, asyncio.Lock] = {}
def _pptx_lock_for(abs_path: str) -> asyncio.Lock:
lock = _pptx_preview_locks.get(abs_path)
if lock is None:
lock = _pptx_preview_locks[abs_path] = asyncio.Lock()
return lock
# ─────────────────────────── helpers ───────────────────────────
def _norm_path(p: str) -> str:
"""跨 OS 显示归一:backslash → forward slash。"""
return (p or "").replace("\\", "/")
def _iso(dt: Optional[Any]) -> Optional[str]:
return dt.isoformat() if dt else None
def _outline_snippet(text: Optional[str], maxlen: int = 48) -> str:
"""user 消息正文 → 目录条目片段:取首个非空行,压平首尾空白,截断到 maxlen。"""
if not text:
return ""
for line in text.splitlines():
line = line.strip()
if line:
return line[:maxlen]
return text.strip()[:maxlen]
def _parse_ordering(s: Optional[str]) -> list:
    """Parse a DRF-style ``ordering`` query value into SQLAlchemy ``order_by`` clauses.

    The value is a comma-separated list of fields; a leading ``-`` means descending.
    Only names in ``ORDER_FIELDS`` are honoured, and anything else is silently dropped.

    An empty or missing value, or one where every field is invalid, falls back to
    ``ORDER_DEFAULT``. That constant is parsed with the same rules, so changing it
    updates both fallbacks together; previously the all-invalid fallback was
    hard-coded to ``created_at`` descending.

    Returns a list usable as ``select(...).order_by(*cols)``.
    """
    def _clauses(spec: str) -> list:
        # One clause per valid field, in the order given.
        clauses = []
        for part in spec.split(","):
            field = part.strip()
            if not field:
                continue
            descending = field.startswith("-")
            if descending:
                field = field[1:]
            if field not in ORDER_FIELDS:
                continue  # not on the allowlist → ignored
            column = getattr(Task, field)
            clauses.append(column.desc() if descending else column.asc())
        return clauses

    return _clauses((s or "").strip()) or _clauses(ORDER_DEFAULT)
def _usage_aggregates(s: Any, tids: list) -> dict:
    """Batch-aggregate usage_events per task_id: real cost, chat tokens and cache hits.

    Issues a single GROUP BY query for all *tids*, the same batch pattern as the
    list endpoint's message counts, so there is no N+1. The values are computed
    on the fly rather than stored on ``tasks``. That keeps them accurate for every
    historical task without a backfill.

    - cost_cny: summed over every event kind (chat + image + video), i.e. the
      task's real spend.
    - tokens_in / tokens_out / tokens_cache_hit: chat events only. All three come
      from usage_events, so ``cache_hit / tokens_in`` stays within 100% (assuming
      each event's cache hits do not exceed its input tokens).
    - Do not use ``tasks.tokens_prompt`` as the denominator. That column is reset
      by "clear conversation" while usage_events is not, so mixing sources can
      give a ratio above 100%.

    Returns ``{task_id: {"cost_cny": float, "tokens_in": int, "tokens_out": int,
    "tokens_cache_hit": int}}``. Tasks without usage rows are absent. An empty
    *tids* returns ``{}`` without querying.
    """
    if not tids:
        return {}
    is_chat = UsageEvent.kind == "chat"

    def _chat_total(unit_key: str):
        # Read one key of the `units` JSON as text, cast it to BIGINT and sum it
        # over chat rows only. `.astext` suggests a PostgreSQL JSON/JSONB column
        # (TODO confirm).
        as_int = cast(UsageEvent.units[unit_key].astext, BigInteger)
        return func.coalesce(func.sum(as_int).filter(is_chat), 0)

    query = (
        select(
            UsageEvent.task_id,
            func.coalesce(func.sum(UsageEvent.cost_cny), 0),
            _chat_total("tokens_in"),
            _chat_total("tokens_out"),
            _chat_total("cache_hit_tokens"),
        )
        .where(UsageEvent.task_id.in_(tids))
        .group_by(UsageEvent.task_id)
    )
    totals: dict = {}
    for task_id, cost, t_in, t_out, cache_hit in s.execute(query).all():
        totals[task_id] = {
            "cost_cny": float(cost or 0),
            "tokens_in": int(t_in or 0),
            "tokens_out": int(t_out or 0),
            "tokens_cache_hit": int(cache_hit or 0),
        }
    return totals
def _task_dict(
    row: Any,
    *,
    n_messages: Optional[int] = None,
    usage: Optional[dict] = None,
) -> dict:
    """Serialise a Task ORM row into the API's JSON dict.

    *usage* (optional) is this task's entry from ``_usage_aggregates``. When it
    carries a value, the token totals, cache hits and cost are taken from it.
    Otherwise the token totals and cost fall back to the ``tasks`` columns, and
    cache hits fall back to 0. The frontend derives ¥ and the cache-hit ratio
    from these fields. ``n_messages`` is only included when given.
    """
    agg = usage or {}
    # Token totals prefer the usage_events aggregate. It is the usage source of
    # truth and the same source as the cache-hit count, so the hit ratio shares
    # its denominator and stays within 100%. Without it, fall back to the tasks
    # summary columns. tasks.tokens_prompt is reset by "clear conversation", so
    # it must never be divided into usage_events cache hits.
    if "tokens_in" in agg:
        prompt_tokens = int(agg["tokens_in"])
    else:
        prompt_tokens = row.tokens_prompt or 0
    if "tokens_out" in agg:
        completion_tokens = int(agg["tokens_out"])
    else:
        completion_tokens = row.tokens_completion or 0
    payload = dict(
        task_id=str(row.task_id),
        name=row.name or "",
        description=row.description or "",
        working_dir=_norm_path(row.working_dir or ""),
        status=row.status,
        skill=row.skill or "",
        channel=getattr(row, "channel", None) or "web",
        model=row.model or "",
        model_profile=row.model_profile or "",
        tokens_prompt=prompt_tokens,
        tokens_completion=completion_tokens,
        tokens=prompt_tokens + completion_tokens,
        # Chat prefix-cache hit tokens plus real cost (already discounted for cache
        # hits, per the original note about usage.py). Both are aggregated on the
        # fly and fall back to the column / 0 when usage is absent.
        tokens_cache_hit=int(agg.get("tokens_cache_hit", 0)),
        cost_cny=float(agg["cost_cny"]) if "cost_cny" in agg else float(row.cost_cny or 0),
        # Current run state (schema 0004 folded the old runs table into tasks).
        run_status=row.run_status or "idle",
        run_error=row.run_error or None,
        created_at=_iso(getattr(row, "created_at", None)),
        updated_at=_iso(getattr(row, "updated_at", None)),
    )
    if n_messages is not None:
        payload["n_messages"] = n_messages
    return payload
# ─────────────────────── files helpers ───────────────────────
def _load_user_root(user_id: UUID) -> Path:
    """Return *user_id*'s root directory, the boundary for every files API.

    The directory comes from ``core.agent_builder.user_root`` applied to the
    default workspace (``resolve_workspace(None)``). According to the original
    notes it is ``<workspace>/users/<user_id>/`` and is created on first access.
    Presumably ``user_root`` does the mkdir, but that is not visible here.
    """
    from core.agent_builder import resolve_workspace, user_root
    return user_root(resolve_workspace(None), user_id)
def _safe_join(root: Path, rel: str) -> Path:
"""归一用户路径到 absolute,并校验仍在 root 内。防 `../` / 绝对 path / symlink 越界。"""
rel = (rel or "").strip()
if not rel:
return root.resolve()
if rel[0] in ("/", "\\"):
raise HTTPException(400, f"absolute-style path not allowed: {rel!r}")
if Path(rel).is_absolute():
raise HTTPException(400, f"absolute path not allowed: {rel!r}")
target = (root / rel).resolve()
try:
target.relative_to(root.resolve())
except ValueError:
raise HTTPException(400, f"path escapes user_root: {rel!r}")
return target
def _rel_to(root: Path, target: Path) -> str:
try:
rel = target.resolve().relative_to(root.resolve()).as_posix()
except ValueError:
return ""
return "" if rel == "." else rel
def _enumerate_files(root: Path, current: Path) -> tuple[list[dict], list[dict], bool]:
    """List the visible children of *current* and build its breadcrumb trail.

    Returns ``(entries, crumbs, exists)``:

    - entries: one dict per non-hidden child of *current*. Non-regular files
      (directories etc.) come first, then regular files; each group is ordered by
      case-insensitive name. Keys:
      - ``name``
      - ``is_dir``
      - ``size``: raw bytes for regular files, else ``None``
      - ``mtime``: local-time ISO string with seconds precision; the frontend
        humanises it
      - ``rel``: path relative to *root*, via ``_rel_to``

      The list is empty when *current* is missing or is not a directory.
    - crumbs: ``[{"label": "/", "rel": ""}, ...]`` from *root* down to *current*.
    - exists: whether *current* exists.

    Dotfiles are always hidden, so system areas such as ``.memory/`` are not
    exposed to the UI (same convention as ``/v1/folders``). A child whose ``stat``
    fails (broken symlink, permission denied, removed mid-listing) is skipped on
    its own; it no longer aborts the whole listing.
    """
    import stat as _stat

    entries: list[dict] = []
    exists = current.exists()
    if exists and current.is_dir():
        try:
            children = list(current.iterdir())
        except OSError:
            children = []
        # Stat each visible child exactly once and take its type from st_mode.
        # Sort order, is_dir and size then agree, and one unreadable child cannot
        # raise out of the sort key and empty the list.
        stated = []
        for p in children:
            if p.name.startswith("."):
                continue
            try:
                st = p.stat()
            except OSError:
                continue
            stated.append((p, st))
        stated.sort(key=lambda ps: (_stat.S_ISREG(ps[1].st_mode), ps[0].name.lower()))
        for p, st in stated:
            is_file = _stat.S_ISREG(st.st_mode)
            entries.append({
                "name": p.name,
                "is_dir": _stat.S_ISDIR(st.st_mode),
                "size": st.st_size if is_file else None,
                "mtime": _dt.fromtimestamp(st.st_mtime).isoformat(timespec="seconds"),
                # NOTE(review): p.stat() follows symlinks, so a symlink pointing
                # outside root exposes the target's size/mtime and gets rel "".
                "rel": _rel_to(root, p),
            })
    cur_rel = _rel_to(root, current)
    crumbs = [{"label": "/", "rel": ""}]
    # _rel_to already maps "current is root" to "", so no "." crumb can appear.
    if cur_rel:
        acc = ""
        for part in cur_rel.split("/"):
            acc = f"{acc}/{part}" if acc else part
            crumbs.append({"label": part, "rel": acc})
    return entries, crumbs, exists
def _validate_transfer(
    root: Path, paths: list[str], dest_dir: str,
) -> tuple[list[Path], Path]:
    """Pre-flight a batch copy/move by resolving every source and the destination.

    The first invalid item aborts the whole batch, and nothing touches the
    filesystem. Returns ``(sources, dest_dir_path)``. Copy and move share this
    check; the top-level working_dir guard is added by each route separately.

    Checks:
    - ``paths`` is non-empty. Each source is inside user_root, exists, and is not
      user_root itself.
    - ``dest_dir`` exists and is a directory (user_root itself is allowed).
    - A source may not equal ``dest_dir`` (self-move).
    - ``dest_dir`` may not lie inside a source's subtree (cannot move a/ into a/b/).
    - A source may not already be a direct child of ``dest_dir`` (in-place no-op).
    - Leaf names within the batch must be unique (two a.txt would collide in dest).
    - ``dest_dir/<name>`` must not already exist, including as a dangling symlink.
      The whole batch fails with 409 rather than silently overwriting.

    Raises:
        HTTPException: 400 / 404 / 409 as listed above, plus 400 from ``_safe_join``.
    """
    if not paths:
        raise HTTPException(400, "paths is empty")
    dest = _safe_join(root, dest_dir)
    if not dest.exists():
        raise HTTPException(404, f"dest_dir not found: {dest_dir!r}")
    if not dest.is_dir():
        raise HTTPException(400, f"dest_dir is not a directory: {dest_dir!r}")
    dest_r = dest.resolve()
    root_r = root.resolve()  # loop-invariant
    sources: list[Path] = []
    seen_names: set[str] = set()
    for p in paths:
        src = _safe_join(root, p)
        if not src.exists():
            raise HTTPException(404, f"source not found: {p!r}")
        src_r = src.resolve()
        if src_r == root_r:
            raise HTTPException(400, "cannot transfer user_root")
        if src_r == dest_r:
            raise HTTPException(400, f"source equals dest_dir: {p!r}")
        # dest lies inside src's subtree → the transfer would nest src into itself
        if src_r in dest_r.parents:
            raise HTTPException(
                400, f"cannot transfer {p!r} into its own subtree"
            )
        # already a direct child of dest → no-op
        if src.parent.resolve() == dest_r:
            raise HTTPException(
                400, f"{p!r} already directly under dest_dir"
            )
        name = src.name
        if name in seen_names:
            raise HTTPException(400, f"duplicate source leaf name in batch: {name!r}")
        seen_names.add(name)
        target = dest / name
        # exists() follows symlinks and reports False for a dangling link, yet the
        # transfer would still clobber it; is_symlink() catches that case.
        if target.exists() or target.is_symlink():
            raise HTTPException(
                409, f"target already exists: {_rel_to(root, target)!r}"
            )
        sources.append(src)
    return sources, dest
# ─────────────────── BG run + SSE 帧格式 ───────────────────
def _run_agent_bg(
    task_id: UUID, user_id: UUID, user_message: str,
    image_variant: str = "", video_variant: str = "",
    scheduled: bool = False,
) -> None:
    """Worker-thread body of one agent run.

    Steps:
    1. Emit ``run_start``.
    2. ``build_agent(resume=True, ...)`` with a cancel_check.
    3. Attach a ``WebEventSink``.
    4. ``agent.run(user_message)``.
    5. ``sync_task_tokens``.
    6. Write ``tasks.run_status``.

    The sink bridges events back to the asyncio loop through ``broker.emit``.
    ``agent.run`` is synchronous, which is why callers run this function via
    ``asyncio.to_thread``.

    ``user_id`` must be passed through from the JWT side. It decides which
    per-user subtree the memory block is read from.

    ``cancel_check`` wraps ``broker.is_cancelled``. Per the original design notes:
    - the agent loop polls it between stream chunks and between tool calls, so
      cancel latency is about one chunk interval (~100 ms);
    - the seedance tool also polls it while waiting;
    - the SeedanceTool constructor keeps it, which is why it is passed into
      ``build_agent`` rather than assigned to the agent afterwards.

    Outcome:
    - A normal return from ``agent.run`` resets the row to ``run_status='idle'``
      with no error, leaving no persistent marker. Presumably a cooperative
      cancel also ends as a normal return — confirm in the agent loop.
    - Any exception emits an ``error`` event and persists ``run_status='error'``
      together with the message. Only this error state is persistent.
    - In every case the broker's cancel flag for the task is cleared and its
      stream closed.

    image_variant / video_variant: the image/video variant used to wire the media
    tools for this run (empty → the first variant in the yaml). They ride on the
    message POST and are not stored in the DB, so a dropdown choice applies to
    that one message.
    scheduled: forwarded to ``build_agent`` as ``scheduled_run``.
    """
    from core.agent_builder import build_agent, sync_task_tokens
    # Default argument binds task_id at definition time for the polling callback.
    cancel_check = lambda tid=task_id: broker.is_cancelled(tid)
    try:
        broker.emit(task_id, {"type": "run_start"})
        agent, session, sid, task_state, task_dir = build_agent(
            session_id=str(task_id), resume=True, user_id=user_id,
            image_variant=image_variant,
            video_variant=video_variant,
            cancel_check=cancel_check,
            scheduled_run=scheduled,
        )
        agent.sink = WebEventSink(broker, task_id)
        agent.run(user_message)
        sync_task_tokens(task_state)
        # Cancel hit or normal completion → run_status back to idle (only error is persisted).
        with session_scope() as s:
            s.execute(
                update(Task).where(Task.task_id == task_id).values(
                    run_status="idle", run_error=None,
                )
            )
    except Exception as e:
        err = f"{type(e).__name__}: {e}"
        # NOTE(review): if this emit raises, the DB update below is skipped and the
        # row stays 'running' until the startup stale-run reaper marks it error.
        broker.emit(task_id, {"type": "error", "msg": err})
        try:
            with session_scope() as s:
                s.execute(
                    update(Task).where(Task.task_id == task_id).values(
                        run_status="error", run_error=err,
                    )
                )
        except Exception:
            pass  # error already emitted to the frontend; don't amplify noise if the DB write fails
    finally:
        broker.clear_cancel(task_id)
        broker.close(task_id)
def _sse_event(event_type: str, payload: dict) -> bytes:
"""格式化 SSE 一帧:`event: <type>` + `data: <json single-line>`。"""
body = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
return f"event: {event_type}\ndata: {body}\n\n".encode("utf-8")
def _resolve_model_profile(profile: str) -> tuple[str, str]:
    """Validate a model profile name and return ``(profile, model_id)``.

    An empty or whitespace-only *profile* falls back to ``cfg["default_model"]``.
    The name is loaded through ``ModelCapabilities.load`` from
    ``ROOT / cfg["models_dir"]``. Callers fill both the ``model_profile`` and
    ``model`` columns from the returned pair (e.g. ``ensure_local_task_row``).

    Raises:
        HTTPException(400): the profile file is missing or malformed
            (``FileNotFoundError`` / ``ValueError`` from the loader).
    """
    from core.agent_builder import load_config
    from core.capabilities import ModelCapabilities
    from core.paths import ROOT
    cfg = load_config()
    requested = (profile or "").strip()
    name = requested if requested else cfg["default_model"]
    models_dir = ROOT / cfg["models_dir"]
    try:
        caps = ModelCapabilities.load(name, models_dir)
    except (FileNotFoundError, ValueError) as e:
        raise HTTPException(400, f"invalid model_profile {name!r}: {e}")
    return name, caps.model_id
async def _run_channel_conversation(app, uid, text, attachments, *, channel):
    """Channel-agnostic core for one inbound conversation turn (§8.7).

    Resolves (or creates) the user's long-lived task for *channel*, saves any
    inbound attachments under ``<working_dir>/inbound/``, takes the per-task run
    lock, runs ``_run_agent_bg`` to completion and returns the assistant's reply
    text. Each channel has its own conversation task, so channels do not interfere.

    Args:
        app: the FastAPI app. The run is registered in ``app.state.inflight``.
        uid: the user id.
        text: the inbound message text. Assumed to be a ``str`` (``.strip()`` is
            called on it when attachments are appended).
        attachments: inbound attachments, already downloaded and decrypted; may be
            empty. Per the original notes, wecom currently sends text only.
        channel: ``'wecom'`` for Enterprise WeChat (the task id comes from the
            wecom binding). Any other value is treated as personal WeChat ClawBot
            (the task id comes from the binding snapshot).

    Returns:
        The reply text for the channel to send back. Special cases:
        - ``""`` when a wechat user has no binding;
        - a "please wait" notice when a run is already in progress;
        - an ``"[出错] ..."`` string when the task row is missing or the run ended
          in error;
        - a placeholder when the run produced no text.

    NOTE(review): session_scope(), ensure_local_task_row(), mkdir/write_bytes and
    _resolve_model_profile() are called synchronously inside this coroutine.
    Presumably they do blocking DB/file I/O on the event-loop thread; consider
    asyncio.to_thread if latency matters.
    """
    from core.agent_builder import resolve_workspace, working_dir_from_name, load_config
    # NOTE(review): redundant with the module-level import of ensure_local_task_row.
    from core.storage.utils import ensure_local_task_row
    from core.wechat import service as _wx
    from core.wechat.ilink import attachment_basename
    from core.wechat.inbound import extract_last_assistant_text
    if channel == "wecom":
        existing_tid = await asyncio.to_thread(_wx.get_wecom_chat_task, uid)
        task_name, slug, desc = "企业微信对话", f"wecom-{str(uid)[:8]}", "(企业微信对话)"
        set_task = _wx.set_wecom_chat_task
    else:
        snap = await asyncio.to_thread(_wx.get_binding, uid)
        if snap is None:
            return ""  # no wechat binding → nothing to reply
        existing_tid = snap.chat_task_id
        task_name, slug, desc = "微信对话", f"wechat-{str(uid)[:8]}", "(微信 ClawBot 对话)"
        set_task = _wx.set_chat_task
    profile, model_id = _resolve_model_profile("")
    ws = resolve_workspace(None, load_config())
    tid = existing_tid
    need_create = tid is None
    if not need_create:
        # The bound task may have been soft-deleted; recreate it in that case.
        with session_scope() as s:
            exists = s.execute(
                select(Task.task_id).where(Task.task_id == tid, Task.deleted_at.is_(None))
            ).first()
        if exists is None:
            need_create = True
    if need_create:
        tid = uuid4()
        fs_dir = working_dir_from_name(ws, uid, slug)
        fs_dir.mkdir(parents=True, exist_ok=True)
        ensure_local_task_row(
            task_id=tid, name=task_name, working_dir=to_db_path(fs_dir),
            skill="", user_id=uid, model=model_id, model_profile=profile,
            description=desc, channel=channel,
        )
        await asyncio.to_thread(set_task, uid, tid)  # remember the new task on the binding
    # Save inbound attachments to <wd>/inbound/ and append "[用户上传的...] <rel>" lines
    # to text (same convention as pasted images on the web side).
    if attachments:
        from datetime import datetime
        # NOTE(review): Path is not used in this function (and is already imported at module level).
        from pathlib import Path
        with session_scope() as s:
            wd_db = s.execute(
                select(Task.working_dir).where(Task.task_id == tid)
            ).scalar_one()
        inbound_dir = from_db_path(wd_db) / "inbound"
        inbound_dir.mkdir(parents=True, exist_ok=True)
        ts = datetime.now().strftime("%Y%m%d-%H%M%S")
        lines: list[str] = []
        for i, att in enumerate(attachments):
            if not att.data:
                continue  # attachment without payload bytes → skip
            # Presumably attachment_basename returns a sanitised leaf name — confirm,
            # since it is joined onto inbound_dir directly.
            base = attachment_basename(att)
            name = f"{ts}-{i}-{base}"
            (inbound_dir / name).write_bytes(att.data)
            rel = f"inbound/{name}"
            tag = "[用户上传的参考图]" if att.kind == "image" else "[用户上传的文件]"
            lines.append(f"{tag} {rel}")
        if lines:
            extra = "\n".join(lines)
            text = f"{text}\n\n{extra}" if text.strip() else extra
    # Take the run lock (SELECT ... FOR UPDATE). If a run is active, ask the user to
    # wait: this serialises runs per user. The ClawBot loop is already serial;
    # wecom callbacks rely on this check to block concurrency.
    with session_scope() as s:
        row = s.execute(
            select(Task.run_status).where(Task.task_id == tid).with_for_update()
        ).first()
        if row is None:
            return "[出错] 对话 task 不存在"
        if row.run_status in ("running", "cancelling"):
            return "上一条还在处理中,请稍候再发。"
        s.execute(update(Task).where(Task.task_id == tid).values(
            run_status="running", run_error=None))
    # NOTE(review): if anything between here and create_task raises, the row stays
    # 'running' until the startup stale-run reaper marks it error.
    broker.start(tid)
    runner = asyncio.create_task(asyncio.to_thread(
        _run_agent_bg, tid, uid, text, "", "", False,
    ))
    # inflight registration: graceful-drain bookkeeping + strong ref against GC.
    app.state.inflight[runner] = tid
    runner.add_done_callback(lambda t: app.state.inflight.pop(t, None))
    await runner
    # _run_agent_bg writes the final run_status (idle on return, error on exception).
    with session_scope() as s:
        st = s.execute(
            select(Task.run_status, Task.run_error).where(Task.task_id == tid)
        ).first()
    if st is not None and st.run_status == "error":
        return f"[出错] {st.run_error}"
    reply = await asyncio.to_thread(extract_last_assistant_text, tid)
    return reply or "(本轮无文本回复)"
def _list_image_variants() -> list[tuple[str, dict]]:
"""扫 config/media/doubao.yaml image 段 → [(variant_key, variant_cfg), ...]。
yaml 不存在或 image 段空 / 仅注释 → 返 []。不要求 ARK_API_KEY 已设 —— 仅纯
元数据列举,UI 拉这个画下拉。真正调用 seedream 时 agent_builder 那边再过
`ArkConfig.load()`(没 key → tool 不注册)。
"""
from core.paths import ROOT
import yaml as _yaml
p = ROOT / "config" / "media" / "doubao.yaml"
if not p.exists():
return []
try:
data = _yaml.safe_load(p.read_text(encoding="utf-8")) or {}
except Exception:
return []
image_cfg = data.get("image") or {}
return [(k, v) for k, v in image_cfg.items() if isinstance(v, dict)]
def _resolve_image_model(variant: str) -> str:
"""校验 image_model variant key。
传空 → 返空(agent_builder fallback 到第一个 variant);传非空 → 必须存在
于 config/media/doubao.yaml image 段,否则 400。
"""
name = (variant or "").strip()
if not name:
return ""
variants = {k for k, _ in _list_image_variants()}
if name not in variants:
raise HTTPException(400, f"invalid image_model {name!r}; available: {sorted(variants)}")
return name
def _list_video_variants() -> list[tuple[str, dict]]:
"""扫 config/media/doubao.yaml video 段 → [(variant_key, variant_cfg), ...]。
与 _list_image_variants 同范式;空 video 段(未上线 / 注释掉)→ 返 [],UI 隐藏下拉。
"""
from core.paths import ROOT
import yaml as _yaml
p = ROOT / "config" / "media" / "doubao.yaml"
if not p.exists():
return []
try:
data = _yaml.safe_load(p.read_text(encoding="utf-8")) or {}
except Exception:
return []
video_cfg = data.get("video") or {}
return [(k, v) for k, v in video_cfg.items() if isinstance(v, dict)]
def _resolve_video_model(variant: str) -> str:
"""校验 video_model variant key(同 _resolve_image_model 范式)。"""
name = (variant or "").strip()
if not name:
return ""
variants = {k for k, _ in _list_video_variants()}
if name not in variants:
raise HTTPException(400, f"invalid video_model {name!r}; available: {sorted(variants)}")
return name
# ────────────────────── Pydantic 请求体 ──────────────────────
class TaskCreateRequest(BaseModel):
    name: str  # task display name (required; the DB column is NOT NULL)
    working_dir: str = ""  # working-directory name (optional; empty → the name is used as the directory name)
    description: str = ""
    skill: str = ""
    model_profile: str = ""  # `family.variant`; empty → cfg["default_model"]; must exist under config/models/


class TaskPatchRequest(BaseModel):
    # Partial update; presumably fields left as None are not changed (confirm in the PATCH route).
    status: Optional[str] = None
    description: Optional[str] = None
    name: Optional[str] = None
    skill: Optional[str] = None
    model_profile: Optional[str] = None  # switch model (mode c, task level / granularity A — takes effect on the next send)


class SchedulePatchRequest(BaseModel):
    # The frontend's read-only schedule view only uses `enabled` (disable / enable).
    # The original note says other fields are kept here for direct fallback edits
    # when a change cannot be made through chat. The frontend exposes no edit form:
    # creating and editing go through chat (§8.5).
    # NOTE(review): only `enabled` is declared, so those "other fields" are not
    # actually present. Either the note is stale or the fields were dropped.
    enabled: Optional[bool] = None


class MessageRequest(BaseModel):
    content: str
    # Image / video model variant key for the run this message triggers
    # (keys of the image / video sections in config/media/doubao.yaml).
    # Empty → the tool uses the yaml's first variant; non-empty → this run wires up that variant.
    # Applies to this run only and is not stored in the DB; the UI dropdown choice rides on the POST body.
    image_model: str = ""
    video_model: str = ""


class OptimizePromptRequest(BaseModel):
    text: str
    # Optionally pass the variant keys currently selected in the UI. The polishing
    # meta-prompt then includes that model's characteristics, so the LLM knows what
    # the downstream tool prefers and can tailor the prompt for seedream, seedance, etc.
    image_model: str = ""
    video_model: str = ""
class FileDeleteRequest(BaseModel):
    path: str  # target path; presumably relative to user_root like the other file bodies — confirm in the route
    recursive: bool = False  # NOTE(review): presumably must be True to delete a non-empty directory; route not visible here


class FileRenameRequest(BaseModel):
    path: str  # directory / file to rename, relative to user_root
    new_name: str  # new leaf name (not a path); must not contain / \ ..


class FileTransferRequest(BaseModel):
    paths: list[str]  # one or more sources, all relative to user_root
    dest_dir: str = ""  # destination directory, relative to user_root; empty → user_root
class LoginRequest(BaseModel):
    # Presumably the body of /v1/auth/login, which exchanges PLATFORM_KEY for a JWT
    # (see the module docstring).
    user_id: str
    platform_key: str
    # Migration 0016: an optional user profile the platform may inject. Omitting it
    # keeps the old behaviour (only user_id is filled), so older callers stay compatible.
    name: Optional[str] = None  # display name / real name
    user_name: Optional[str] = None  # platform account name


class PasswordLoginRequest(BaseModel):
    # Email + password login body.
    email: str
    password: str


class AdminCreateUserRequest(BaseModel):
    # User-creation body; `admin_token` presumably authorises the call (confirm in the route).
    email: str
    password: str
    admin_token: str
    role: str = "user"  # 'user' / 'admin'; admins can open the /static/admin.html console


class ChangePasswordRequest(BaseModel):
    # Password-change body (presumably for the calling user; route not visible here).
    old_password: str
    new_password: str
# ────────────────────── App 工厂 ──────────────────────
# web/static directory — used for the /static static mount; dev.html also lives here.
_STATIC_DIR = Path(__file__).parent / "static"
def create_app() -> FastAPI:
# fail-fast:env 缺失直接抛,不裸跑无密
auth_cfg = AuthConfig.from_env()
require_user = make_require_user(auth_cfg)
require_admin = make_require_admin(auth_cfg)
@asynccontextmanager
async def lifespan(app: FastAPI):
loop = asyncio.get_running_loop()
broker.bind_loop(loop)
# ── 接管默认线程池 executor(§8.4)──────────────────────────────
# run 走 asyncio.to_thread(用 loop 默认 executor);默认是匿名的,读不到大小、
# 不可调。显式建一个同尺寸(复刻 Python 默认 min(32, cpu+4))接管,好处:① 监控
# 能读 max_workers 判断有没有排队 ② 并发不够时改 ZCBOT_RUN_MAX_WORKERS 调大不改码。
# 注:run 与 disk scan / pptx 转换 / reaper 共享此池(同原默认行为);真要隔离
# 长任务再另开 run 专用池,那是后话。
run_max_workers = int(
os.getenv("ZCBOT_RUN_MAX_WORKERS") or min(32, (os.cpu_count() or 1) + 4)
)
run_executor = ThreadPoolExecutor(
max_workers=run_max_workers, thread_name_prefix="run"
)
loop.set_default_executor(run_executor)
app.state.run_executor = run_executor
app.state.run_max_workers = run_max_workers
print(f"[startup] run executor: max_workers={run_max_workers} "
f"(override via ZCBOT_RUN_MAX_WORKERS)")
from core.agent_builder import load_config, resolve_workspace
_cfg = load_config()
# 优雅 drain 状态(SIGTERM / systemctl restart 兜底,见下方 finally):
# draining 置位后 POST /messages 返 503;inflight 登记在跑的 BG run task,
# 关停时 await 它们收尾。inflight 同时给 create_task 持强引用,防被 GC 中途回收。
app.state.draining = asyncio.Event()
app.state.inflight = {} # dict[asyncio.Task, UUID(task_id)]
_shutdown_cfg = _cfg.get("shutdown") or {}
drain_timeout = int(_shutdown_cfg.get("drain_timeout_seconds") or 90)
cancel_grace = int(_shutdown_cfg.get("cancel_grace_seconds") or 15)
# Stale-run reaper:上次进程 crash 留下的 "running" / "cancelling" 已无 BG 线程
# 继续,启动时标 error,让对应 task 重新可发消息(否则 gate 永挂)。
# TODO 真生产 multi-worker:换 heartbeat / lease,只 reap 自家 worker 的孤儿。
with session_scope() as s:
result = s.execute(
update(Task)
.where(Task.run_status.in_(("running", "cancelling")))
.values(
run_status="error",
run_error="server restarted before run finished",
)
)
if result.rowcount:
print(f"[startup] reaped {result.rowcount} stale active run(s)")
# 磁盘配额后台扫描(§7.5 #4 应用层 gate)── 不依赖 docker backend,host
# backend 也跑(/v1/files/upload 也走配额 gate)。yaml `quotas.disk_scan_interval_seconds`
# 默 900s = 15min;limit_bytes ≤ 0 视为不限,scan 仍跑(用量统计有用),check 短路放行。
from core.agent_builder import resolve_workspace
from core.storage.disk_quota import parse_bytes, scan_all_users
workspace = resolve_workspace(None, _cfg)
disk_user_root = workspace / "users"
quotas_cfg = _cfg.get("quotas") or {}
disk_scan_interval = int(quotas_cfg.get("disk_scan_interval_seconds") or 900)
async def _disk_scanner() -> None:
loop = asyncio.get_running_loop()
# 启动时跑一次,后续按 interval。首次扫完 check 才能命中。
try:
n = await loop.run_in_executor(None, scan_all_users, disk_user_root)
if n:
print(f"[disk_scanner] initial scan: {n} user(s)")
except Exception as e:
print(f"[disk_scanner] initial scan error: {type(e).__name__}: {e}")
while True:
try:
await asyncio.sleep(disk_scan_interval)
n = await loop.run_in_executor(None, scan_all_users, disk_user_root)
if n:
print(f"[disk_scanner] scanned {n} user(s)")
except asyncio.CancelledError:
raise
except Exception as e:
print(f"[disk_scanner] error: {type(e).__name__}: {e}")
disk_scanner_task = asyncio.create_task(_disk_scanner(), name="disk-scanner")
# ── 并发/线程池监控(§8.4):周期采样,只在有负载/刷新峰值时打,空闲不刷屏 ──
# active_runs 来自 inflight(已提交未完成的 run,含排队中);逼近 max_workers 即
# 线程池排队,新 run 的 SSE 会卡着不吐 token。查看:journalctl -u zcbot | grep '\[stats\]'
def _rss_peak_mb() -> Optional[float]:
if resource is None:
return None # Windows dev:降级,不打 rss
# Linux ru_maxrss 单位 KB,是峰值/high-water(单调不降 —— 看内存涨势够用)
return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024
async def _stats_logger() -> None:
peak = 0
while True:
try:
await asyncio.sleep(60)
active = len(app.state.inflight)
if active > peak:
peak = active
warn = " [WARN >= max_workers,已在排队]" if active >= run_max_workers else ""
print(f"[stats] new peak active_runs={active} "
f"max_workers={run_max_workers}{warn}")
if active > 0:
rss = _rss_peak_mb()
rss_s = f" rss_peak={rss:.0f}MB" if rss is not None else ""
print(f"[stats] active_runs={active} "
f"max_workers={run_max_workers} "
f"sse_subs={broker.total_subscribers()}{rss_s}")
except asyncio.CancelledError:
raise
except Exception as e:
print(f"[stats] error: {type(e).__name__}: {e}")
stats_logger_task = asyncio.create_task(_stats_logger(), name="stats-logger")
# ── 定时任务守护循环(§8.5)── 仿 _disk_scanner 的 plain-asyncio 范式,不引
# APScheduler/Celery。每 ~10s 认领到点 job(claim+advance next_run 防重复触发),
# 复用 _run_agent_bg 起 run,跑完确定性兜底投递 + 回写 last_*。间隔只决定最坏延迟
# (≤1 tick),不决定会不会漏(claim 取 next_run<=now 的全部)。ZCBOT_DISABLE_SCHEDULER=1
# 整体关掉(对照 Claude Code CLAUDE_CODE_DISABLE_CRON)。
scheduler_enabled = os.getenv("ZCBOT_DISABLE_SCHEDULER", "").strip() not in ("1", "true", "yes")
sched_tick = int(os.getenv("ZCBOT_SCHEDULER_TICK_SECONDS", "10") or "10")
sched_sema = asyncio.Semaphore(int(os.getenv("ZCBOT_SCHEDULER_CONCURRENCY", "4") or "4"))
async def _execute_scheduled_job(snap: dict) -> None:
"""认领后跑一个 job:解析目标 task → 抢 run 锁 → _run_agent_bg → 投递 + 记账。"""
from core.agent_builder import (
resolve_workspace, working_dir_from_name, validate_task_name, InvalidTaskName,
)
from core.scheduler import build_run_message, deliver_notify, record_result
from core.storage.utils import ensure_local_task_row
job_id = snap["job_id"]
uid = snap["user_id"]
async with sched_sema:
try:
profile, model_id = _resolve_model_profile(snap.get("model_profile") or "")
ws = resolve_workspace(None, _cfg)
# 目标 task:persistent 用绑定 task(缺则新建并回填);isolated 用稳定 per-job 目录
tid: Optional[UUID] = None
if snap["mode"] == "persistent" and snap.get("bound_task_id"):
tid = snap["bound_task_id"]
# 绑定 task 可能已被删(SET NULL 已处理 None;这里再查实在性)
with session_scope() as s:
exists = s.execute(
select(Task.task_id).where(
Task.task_id == tid, Task.deleted_at.is_(None)
)
).first()
if exists is None:
tid = None
if tid is None:
tid = uuid4()
wd_name = f"scheduled-{str(job_id)[:8]}"
fs_dir = working_dir_from_name(ws, uid, wd_name)
fs_dir.mkdir(parents=True, exist_ok=True)
disp = f"{snap['name']}"
try:
disp = validate_task_name(disp)
except InvalidTaskName:
disp = wd_name # 名字含非法字符 → 退到安全名
ensure_local_task_row(
task_id=tid, name=disp, working_dir=to_db_path(fs_dir),
skill=snap.get("skill") or "", user_id=uid,
model=model_id, model_profile=profile,
description="(定时任务自动创建)",
)
if snap["mode"] == "persistent":
with session_scope() as s:
s.execute(update(ScheduledJob).where(
ScheduledJob.job_id == job_id
).values(bound_task_id=tid))
# 抢 run 锁(同 post_message):busy → 本次跳过,下个 cron 点再来
with session_scope() as s:
row = s.execute(
select(Task.run_status).where(Task.task_id == tid).with_for_update()
).first()
if row is None:
record_result(job_id, status="error", task_id=tid, error="目标 task 不存在")
return
if row.run_status in ("running", "cancelling"):
record_result(job_id, status="skipped", task_id=tid,
error="目标 task 正忙,本次跳过")
print(f"[scheduler] job {str(job_id)[:8]} skipped (task busy)")
return
s.execute(update(Task).where(Task.task_id == tid).values(
run_status="running", run_error=None))
message = build_run_message(snap)
broker.start(tid)
runner = asyncio.create_task(asyncio.to_thread(
_run_agent_bg, tid, uid, message, "", "", True,
))
app.state.inflight[runner] = tid
runner.add_done_callback(lambda t: app.state.inflight.pop(t, None))
timeout = int(snap.get("timeout_seconds") or 0)
if timeout > 0:
done, _pending = await asyncio.wait({runner}, timeout=timeout)
if not done:
broker.request_cancel(tid) # 协作式停;loop 在 chunk 间 poll 到即退
print(f"[scheduler] job {str(job_id)[:8]} timed out ({timeout}s), cancelling")
await runner
else:
await runner
# run 终态:_run_agent_bg 收尾把 run_status 写回 idle(ok)/error
with session_scope() as s:
st = s.execute(
select(Task.run_status, Task.run_error).where(Task.task_id == tid)
).first()
if st is not None and st.run_status == "error":
record_result(job_id, status="error", task_id=tid, error=st.run_error)
print(f"[scheduler] job {str(job_id)[:8]} run error: {st.run_error}")
return
# 第 3 层确定性兜底投递(notify);失败不影响 run 已成功这一事实
if snap.get("notify"):
try:
with session_scope() as s:
wd_db = s.execute(
select(Task.working_dir).where(Task.task_id == tid)
).scalar_one_or_none()
fs_dir = from_db_path(wd_db) if wd_db else ws
await asyncio.get_running_loop().run_in_executor(
None, lambda: deliver_notify(
snap["notify"], job_name=snap["name"],
working_dir=fs_dir, tz=snap["tz"],
user_id=snap["user_id"],
)
)
except Exception as e:
print(f"[scheduler] job {str(job_id)[:8]} notify failed: {type(e).__name__}: {e}")
record_result(job_id, status="ok", task_id=tid)
print(f"[scheduler] job {str(job_id)[:8]} '{snap['name']}' done")
except Exception as e:
print(f"[scheduler] job {str(job_id)[:8]} crashed: {type(e).__name__}: {e}")
try:
record_result(job_id, status="error", task_id=None, error=f"{type(e).__name__}: {e}")
except Exception:
pass
async def _scheduler_loop() -> None:
from core.scheduler import claim_due_jobs
loop = asyncio.get_running_loop()
while True:
try:
await asyncio.sleep(sched_tick)
if getattr(app.state, "draining", None) is not None and app.state.draining.is_set():
continue # 关停 drain 期不起新 job
due = await loop.run_in_executor(None, claim_due_jobs)
for snap in due:
asyncio.create_task(_execute_scheduled_job(snap))
if due:
print(f"[scheduler] fired {len(due)} job(s)")
except asyncio.CancelledError:
raise
except Exception as e:
print(f"[scheduler] loop error: {type(e).__name__}: {e}")
scheduler_task = asyncio.create_task(_scheduler_loop(), name="scheduler") if scheduler_enabled else None
if scheduler_enabled:
print(f"[scheduler] enabled (tick={sched_tick}s)")
# ── 微信(ClawBot)入站长轮询管理器(§8.7)── 仅当 ZCBOT_WECHAT_BOT_ENABLED 在。
# 每个 active 绑定一条 getupdates 长轮询;收到消息 → 跑用户常驻「微信」task → 回复发回。
from core.wechat.service import clawbot_enabled
wechat_stop = asyncio.Event()
wechat_task = None
async def _run_wechat_message(uid: UUID, text: str, attachments=None) -> str:
"""微信(ClawBot)入站一条消息 → 跑用户常驻「微信」task → 取回复。
attachments:已下载解密的入站附件(core.wechat.ilink.InboundAttachment,att.data 已回填)。
建/复用 task、落盘附件、抢 run 锁、跑 agent 全在渠道无关核心
`_run_channel_conversation` 里(企业微信回调走同一核心,channel='wecom')。
"""
return await _run_channel_conversation(app, uid, text, attachments, channel="wechat")
if clawbot_enabled():
from core.wechat.inbound import run_inbound_manager
wechat_task = asyncio.create_task(
run_inbound_manager(_run_wechat_message, wechat_stop),
name="wechat-inbound",
)
print("[wechat] ClawBot inbound manager enabled")
# Sandbox pool(§7.5):仅当 ZCBOT_SANDBOX_BACKEND=docker 时启用。
# 启动钩子:① init_pool(创建 docker network + pool 实例)② shutdown_all 清
# 前驱孤儿(上次进程留下的 zcbot-sandbox-* 容器,内存 _last_active 为空,
# 全清重启)③ 后台 reaper task,每 60s 跑 reap_idle。
sandbox_backend = os.getenv("ZCBOT_SANDBOX_BACKEND", "host").lower()
sandbox_reaper_task = None
if sandbox_backend == "docker":
from core.paths import ROOT
from core.sandbox import init_pool
from core.sandbox.check import detect_fs_quota
workspace = resolve_workspace(None, _cfg)
user_root_base = workspace / "users"
# §7.5 #4 fs quota 探测:不阻塞启动(应用层周期扫描已有),仅打 WARN
# 提醒外部用户开放前必须升级到 xfs prjquota / ext4 project / zfs。
try:
level, msg = detect_fs_quota(user_root_base.resolve())
print(f"[startup] {'[ok]' if level == 'ok' else '[warn]'} {msg}")
except Exception as e:
print(f"[startup] [warn] fs quota detect failed: {type(e).__name__}: {e}")
try:
# repo_root=ROOT 让 SandboxPool 把 <repo>/skills 只读 mount 进容器
# (fs 工具进容器后 read SKILL references 需要)
# sandbox_cfg=yaml `sandbox` 段(memory/cpus/pids_limit 可调)
pool = init_pool(
user_root_base, repo_root=ROOT,
sandbox_cfg=_cfg.get("sandbox") or {},
)
removed = pool.shutdown_all()
if removed:
print(f"[startup] swept {len(removed)} stale sandbox container(s)")
async def _reaper() -> None:
loop = asyncio.get_running_loop()
while True:
try:
await asyncio.sleep(60)
removed = await loop.run_in_executor(None, pool.reap_idle)
if removed:
print(f"[reaper] reaped {len(removed)} idle sandbox container(s)")
except asyncio.CancelledError:
raise
except Exception as e:
print(f"[reaper] error: {type(e).__name__}: {e}")
sandbox_reaper_task = asyncio.create_task(_reaper(), name="sandbox-reaper")
app.state.sandbox_pool = pool
except Exception as e:
# ensure_network / docker CLI 不可用 → fail-fast。Stage C 协议:任一
# hardening 缺失视为部署未完成,不退化到 host(否则误以为有沙盒实则在裸跑)。
raise RuntimeError(
f"ZCBOT_SANDBOX_BACKEND=docker but sandbox init failed: {e}"
)
try:
yield
finally:
# ── 优雅 drain:先拒新 run,等在跑的 run 收尾,超时转协作式 cancel ──
# 单实例形态下消除"restart 误杀 in-flight run 标 error"。新 POST /messages
# 期间返 503(客户端退避重试覆盖)。drain_timeout 内自然跑完 → idle 零 error;
# 超时的 broker.request_cancel → 下个 chunk 间隙退(标 idle);cancel_grace 后仍
# 没退的留给 systemd SIGKILL,下次启动 reaper 标 error(最坏退化 = 改前行为)。
# ★ systemd TimeoutStopSec 必须 > drain_timeout + cancel_grace + 余量(见 RUN.md)。
app.state.draining.set()
inflight = app.state.inflight
if inflight:
print(f"[shutdown] draining {len(inflight)} in-flight run(s), "
f"timeout={drain_timeout}s")
_, pending = await asyncio.wait(
list(inflight.keys()), timeout=drain_timeout
)
if pending:
print(f"[shutdown] {len(pending)} run(s) over drain timeout; "
f"signalling cooperative cancel")
for t in pending:
cid = inflight.get(t)
if cid is not None:
broker.request_cancel(cid)
_, still = await asyncio.wait(pending, timeout=cancel_grace)
if still:
print(f"[shutdown] {len(still)} run(s) still active after "
f"cancel grace; SIGKILL takes over, next start reaps them")
disk_scanner_task.cancel()
try:
await disk_scanner_task
except (asyncio.CancelledError, Exception):
pass
stats_logger_task.cancel()
try:
await stats_logger_task
except (asyncio.CancelledError, Exception):
pass
if scheduler_task is not None:
scheduler_task.cancel()
try:
await scheduler_task
except (asyncio.CancelledError, Exception):
pass
if wechat_task is not None:
wechat_stop.set()
wechat_task.cancel()
try:
await wechat_task
except (asyncio.CancelledError, Exception):
pass
if sandbox_reaper_task is not None:
sandbox_reaper_task.cancel()
try:
await sandbox_reaper_task
except (asyncio.CancelledError, Exception):
pass
if sandbox_backend == "docker":
pool = getattr(app.state, "sandbox_pool", None)
if pool is not None:
try:
pool.shutdown_all()
except Exception as e:
print(f"[shutdown] sandbox shutdown_all error: {type(e).__name__}: {e}")
# drain 已 await inflight 收尾、run 线程退完;非阻塞关池(进程在退出,保守清理)
run_executor.shutdown(wait=False)
# The ASGI application. `lifespan` is the startup/shutdown context manager that
# starts the background loops and drains in-flight runs on shutdown.
app = FastAPI(
    title="zcbot api",
    version=__version__,
    description=(
        "zcbot 后端 — /v1 JSON API + SSE。Auth: PLATFORM_KEY → JWT(§7 D' 过渡)。"
        "本地 dev SPA: /static/dev.html。"
    ),
    lifespan=lifespan,
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # permissive for local dev; restrict to the platform's domains when deploying
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
if _STATIC_DIR.is_dir():
    # On Windows, mimetypes sometimes maps .js to text/plain, which makes browsers refuse
    # to execute <script type="module">. Register the mapping explicitly so static ES
    # modules are served with the correct MIME type. This must run before the mount serves files.
    mimetypes.add_type("text/javascript", ".js")
    app.mount("/static", NoCacheStaticFiles(directory=str(_STATIC_DIR)), name="static")
# ───────────── Misc ─────────────
@app.get("/", include_in_schema=False)
def root():
    """Redirect the bare origin to the local dev SPA.

    The Swagger UI is still served at /docs.
    """
    target = "/static/dev.html"
    return RedirectResponse(status_code=302, url=target)
@app.get("/healthz", tags=["misc"])
def healthz():
    """Liveness probe: no auth required; reports the running version."""
    payload = dict(status="ok", version=__version__)
    return payload
@app.get("/WW_verify_{token}.txt", include_in_schema=False)
def wecom_domain_verify(token: str):
    """Serve the WeCom "trusted web-auth domain" ownership file (public, no login).

    Put the WW_verify_<token>.txt file downloaded from WeCom into
    ZCBOT_WECOM_VERIFY_DIR (default: the repo root). This route serves it by
    file name at the domain root.
    """
    from fastapi.responses import PlainTextResponse
    from core.paths import ROOT

    # Accept only alphanumeric tokens: no separators or dots, hence no path traversal.
    if not token.isalnum():
        raise HTTPException(404, "not found")
    configured_dir = os.getenv("ZCBOT_WECOM_VERIFY_DIR", "").strip()
    verify_dir = Path(configured_dir if configured_dir else str(ROOT))
    candidate = verify_dir / f"WW_verify_{token}.txt"
    if candidate.is_file():
        return PlainTextResponse(candidate.read_text(encoding="utf-8"))
    raise HTTPException(404, "verify file not found")
@app.get("/v1/me", tags=["misc"])
def me(user_id: UUID = Depends(require_user)):
    """Return the caller's identity (JWT -> user_id -> role from the DB).

    The dev SPA calls this once when it restores a session from the token in
    localStorage. It shows the admin entry (/static/admin.html) only when
    role == 'admin'. The role is read from the DB on each call, so changes take
    effect immediately. name / user_name are the profile injected by platform
    login (0016) and may be null (email/password accounts, older rows).
    """
    profile = get_user_profile(user_id) or {}
    response = {"user_id": str(user_id), "role": profile.get("role", "user")}
    for key in ("name", "user_name", "email"):
        response[key] = profile.get(key)
    return response
# ───────────── 微信接入(ClawBot,§8.7)─────────────
@app.post("/v1/wechat/bind/qrcode", tags=["wechat"])
async def wechat_bind_qrcode(user_id: UUID = Depends(require_user)):
    """Start a ClawBot binding QR code, rendered as a PNG data URI.

    The frontend displays it for the user to scan with mobile WeChat. The code
    lives for about a minute; when polling reports `expired`, the frontend calls
    this endpoint again for a fresh one. `user_id` only enforces authentication
    here; it is not otherwise used.
    """
    import base64 as _b64
    import io as _io
    import segno
    from core.wechat import ilink

    # Run the (presumably blocking) upstream call in a worker thread.
    qr = await asyncio.to_thread(ilink.get_bot_qrcode)
    png = _io.BytesIO()
    symbol = segno.make(qr.deeplink, error="m")
    symbol.save(png, kind="png", scale=6, border=3)
    encoded = _b64.b64encode(png.getvalue()).decode()
    return {"qrcode_id": qr.qrcode_id, "qr_png": "data:image/png;base64," + encoded}
@app.get("/v1/wechat/bind/status", tags=["wechat"])
async def wechat_bind_status(qrcode_id: str, user_id: UUID = Depends(require_user)):
    """Poll a binding QR code's scan status; write the binding once it is confirmed.

    The upstream poll is a server-side long poll (held for tens of seconds).
    Returns {status: wait|confirmed|expired}. On `expired` the frontend requests
    a new QR code.

    NOTE(review): qrcode_id is not tied to the user who requested it. Any
    authenticated caller polling a confirmed qrcode_id binds that bot token to
    their own account, so confirm ilink qrcode ids are unguessable.
    """
    from core.wechat import ilink
    from core.wechat import service as _wx

    result = await asyncio.to_thread(ilink.poll_qrcode_status, qrcode_id)
    if result.status == "confirmed" and result.bot_token:
        base_url = result.base_url or ilink.DEFAULT_BASE
        await asyncio.to_thread(
            _wx.upsert_clawbot_binding, user_id, result.bot_token, base_url,
        )
    return {"status": result.status}
@app.get("/v1/wechat/bind", tags=["wechat"])
def wechat_bind_get(user_id: UUID = Depends(require_user)):
    """Report the caller's WeChat binding state; the bot token is never included."""
    from core.wechat import service as _wx

    binding = _wx.get_binding(user_id)
    is_active = binding is not None and binding.status == "active"
    if not is_active:
        return {"bound": False}
    return {
        "bound": True,
        "user_im_id": binding.user_im_id,
        # proactive pushes are only possible inside the 24h window
        "can_push": bool(_wx._token_fresh(binding)),
        "last_active": _iso(binding.context_token_at),
    }
@app.delete("/v1/wechat/bind", status_code=204, tags=["wechat"])
def wechat_unbind(user_id: UUID = Depends(require_user)):
    """Remove the caller's WeChat (ClawBot) binding; responds 204 with no body."""
    from core.wechat import service as _wx

    _wx.unbind(user_id)
    return None
@app.post("/v1/wechat/test", tags=["wechat"])
async def wechat_test(user_id: UUID = Depends(require_user)):
    """Self-check: push a test message to the bound user.

    Needs the user to have spoken to the bot in WeChat within the last 24h.
    """
    from core.wechat import service as _wx

    text = "zcbot 测试消息:绑定成功,这条来自你的 zcbot。"
    outcome = await asyncio.to_thread(_wx.push_clawbot, user_id, text)
    return {"ok": outcome.ok, "reason": outcome.reason}
# ───────────── 企业微信接入(渠道 B,纯推送,§8.7)─────────────
@app.get("/v1/wecom/oauth/url", tags=["wecom"])
def wecom_oauth_url(request: Request, user_id: UUID = Depends(require_user)):
    """Build the WeCom OAuth web-authorization URL for the caller.

    The frontend opens it and the user authorizes by scanning. The caller's
    identity comes back through the signed `state`. The redirect host must be in
    the app's trusted web-auth domains. It is taken from ZCBOT_PUBLIC_BASE_URL,
    falling back to the request's own base URL.
    """
    from core.wechat import wecom

    if not wecom.wecom_configured():
        raise HTTPException(400, "企业微信未配置(需 WECOM_CORPID/AGENTID/SECRET)")
    public_base = os.getenv("ZCBOT_PUBLIC_BASE_URL", "").strip()
    if not public_base:
        public_base = str(request.base_url).rstrip("/")
    callback_url = public_base + "/v1/wecom/oauth/callback"
    signed_state = wecom.sign_state(str(user_id))
    return {"authorize_url": wecom.oauth_authorize_url(callback_url, signed_state)}
@app.get("/v1/wecom/oauth/callback", include_in_schema=False)
async def wecom_oauth_callback(code: str = "", state: str = ""):
    """Browser landing point after WeCom authorization (no JWT; identity comes from `state`).

    Verifies the signed state, exchanges `code` for the WeCom userid, writes the
    binding and renders a small result page. The user arrives here in a plain
    browser tab, so every outcome is rendered as a page. That includes failures
    while exchanging the code or persisting the binding, which previously
    surfaced as a raw 500. Those failures are also logged.
    """
    from fastapi.responses import HTMLResponse
    from core.wechat import service as _wx
    from core.wechat import wecom

    def _page(msg: str, ok: bool) -> HTMLResponse:
        # msg is always a constant or an exception class name here, never request input.
        color = "#1a7f37" if ok else "#cf222e"
        return HTMLResponse(
            f"<!doctype html><meta charset=utf-8><meta name=viewport content='width=device-width,initial-scale=1'>"
            f"<div style='font:16px system-ui;max-width:420px;margin:80px auto;text-align:center;color:{color}'>"
            f"{msg}</div><div style='text-align:center;color:#888;font:13px system-ui'>可关闭本页返回 zcbot</div>"
        )

    uid = wecom.verify_state(state)
    if not uid:
        return _page("绑定失败:授权已过期或无效,请回 zcbot 重试。", False)
    if not code:
        return _page("绑定失败:未拿到授权码。", False)
    # Parse the user id before spending the one-shot OAuth code.
    try:
        user_uuid = UUID(str(uid))
    except ValueError:
        return _page("绑定失败:授权已过期或无效,请回 zcbot 重试。", False)
    try:
        wecom_userid = await asyncio.to_thread(wecom.get_user_id, code)
    except Exception as e:  # noqa: BLE001
        print(f"[wecom] oauth code exchange failed: {type(e).__name__}: {e}")
        return _page(f"绑定失败:{type(e).__name__}", False)
    if not wecom_userid:
        return _page("绑定失败:你不是该企业成员(只支持企业内成员)。", False)
    try:
        await asyncio.to_thread(_wx.upsert_wecom_binding, user_uuid, wecom_userid)
    except Exception as e:  # noqa: BLE001
        print(f"[wecom] oauth bind write failed for {user_uuid}: {type(e).__name__}: {e}")
        return _page(f"绑定失败:{type(e).__name__}", False)
    return _page("✅ 企业微信绑定成功!以后简报 / 结果会推到你的企业微信。", True)
@app.get("/v1/wecom/bind", tags=["wecom"])
def wecom_bind_get(user_id: UUID = Depends(require_user)):
    """Report whether WeCom is configured and whether the caller has a bound WeCom userid."""
    from core.wechat import service as _wx
    from core.wechat import wecom

    bound_userid = _wx.get_wecom_userid(user_id)
    return {
        "configured": wecom.wecom_configured(),
        "bound": bool(bound_userid),
        "wecom_userid": bound_userid,
    }
@app.put("/v1/wecom/bind/userid", tags=["wecom"])
def wecom_bind_userid(payload: dict, user_id: UUID = Depends(require_user)):
    """Bind a manually entered WeCom member userid (for setups without an HTTPS domain / OAuth).

    The userid is shown in the WeCom admin console under Contacts -> member -> "Account".

    Raises 400 when WeCom is not configured, or when `wecom_userid` is missing,
    blank, or not a string.

    NOTE(review): nothing here verifies the userid with WeCom, so a caller can
    claim any member's userid. Acceptable only for trusted/internal deployments.
    """
    from core.wechat import service as _wx
    from core.wechat import wecom

    if not wecom.wecom_configured():
        raise HTTPException(400, "企业微信未配置(需 WECOM_CORPID/AGENTID/SECRET)")
    raw = payload.get("wecom_userid")
    # The body is an untyped dict: a JSON number/object here would otherwise crash
    # on .strip() and surface as a 500 instead of a client error.
    if raw is not None and not isinstance(raw, str):
        raise HTTPException(400, "wecom_userid 必须是字符串")
    wuid = (raw or "").strip()
    if not wuid:
        raise HTTPException(400, "wecom_userid 不能为空")
    _wx.upsert_wecom_binding(user_id, wuid)
    return {"bound": True, "wecom_userid": wuid}
@app.delete("/v1/wecom/bind", status_code=204, tags=["wecom"])
def wecom_unbind(user_id: UUID = Depends(require_user)):
    """Remove the caller's WeCom binding; responds 204 with no body."""
    from core.wechat import service as _wx

    _wx.unbind_wecom(user_id)
    return None
@app.post("/v1/wecom/test", tags=["wecom"])
async def wecom_test(user_id: UUID = Depends(require_user)):
    """Self-check: push a WeCom test message to the bound user (no 24h window constraint)."""
    from core.wechat import service as _wx

    text = "zcbot 测试消息:企业微信绑定成功,这条来自你的 zcbot。"
    outcome = await asyncio.to_thread(_wx.push_wecom, user_id, text)
    return {"ok": outcome.ok, "reason": outcome.reason}
# ── 企业微信「接收消息」回调(入站对话,§8.7)── 无 JWT;身份从加密 XML 的 FromUserName 反查。
# 配置:企业微信后台「应用 → 接收消息 → 设置 API 接收」填本 URL + Token + EncodingAESKey,
# 对应 env WECOM_CALLBACK_TOKEN / WECOM_CALLBACK_AESKEY。回调 URL = <公网 base>/v1/wecom/callback。
@app.get("/v1/wecom/callback", include_in_schema=False)
def wecom_callback_verify(
    msg_signature: str = "", timestamp: str = "", nonce: str = "", echostr: str = ""
):
    """URL-validation handshake WeCom sends when the callback config is saved.

    Checks the signature, decrypts `echostr` and echoes the plaintext back.
    404 when the callback credentials are not configured; 400 on any verify failure.
    """
    from fastapi.responses import PlainTextResponse
    from core.wechat import wecom, wecom_crypto

    if not wecom_crypto.callback_configured():
        raise HTTPException(404, "wecom callback 未配置(需 WECOM_CALLBACK_TOKEN/AESKEY)")
    try:
        echoed = wecom_crypto.verify_url(
            msg_signature, timestamp, nonce, echostr, corpid=wecom._corpid(),
        )
    except Exception as e:  # noqa: BLE001
        raise HTTPException(400, f"verify failed: {type(e).__name__}: {e}")
    return PlainTextResponse(echoed)
@app.post("/v1/wecom/callback", include_in_schema=False)
async def wecom_callback(
    request: Request, msg_signature: str = "", timestamp: str = "", nonce: str = ""
):
    """Receive an inbound WeCom message (encrypted XML POST) and answer it asynchronously.

    Steps:
    1. Decrypt the message.
    2. Map FromUserName to a zcbot user.
    3. Run the agent in the background.
    4. Push the reply back with `_wx.push_wecom` (active push, no 24h window).

    An agent run takes far longer than the 5s passive-reply window, so the
    handler always answers 'success' immediately to stop WeCom from retrying.
    Concurrent / duplicate deliveries for one user are stopped by the
    conversation task's run lock (the second message gets a "still processing"
    reply).

    Responses: 404 when the callback is not configured; 400 when decryption
    fails; otherwise plain-text 'success'. Unbound senders, unsupported message
    types and empty messages are acknowledged silently.
    """
    from fastapi.responses import PlainTextResponse
    from core.wechat import service as _wx
    from core.wechat import wecom, wecom_crypto
    from core.wechat.ilink import InboundAttachment

    if not wecom_crypto.callback_configured():
        raise HTTPException(404, "wecom callback 未配置(需 WECOM_CALLBACK_TOKEN/AESKEY)")
    body = (await request.body()).decode("utf-8")
    try:
        msg = wecom_crypto.decrypt_message(
            msg_signature, timestamp, nonce, body, corpid=wecom._corpid()
        )
    except Exception as e:  # noqa: BLE001
        raise HTTPException(400, f"decrypt failed: {type(e).__name__}: {e}")
    msgtype = msg.get("MsgType") or ""
    wuid = msg.get("FromUserName") or ""
    uid = await asyncio.to_thread(_wx.get_user_by_wecom_userid, wuid)
    if uid is None:
        return PlainTextResponse("success")  # unbound sender -> ignore silently
    # Text uses Content. Image/file are downloaded via media/get and wrapped in an
    # InboundAttachment (same structure as personal WeChat; only kind/file_name/data
    # are read by _run_channel_conversation). Other types (voice/video/location/
    # link/event) are not handled yet and are acknowledged with 'success' to stop retries.
    content = ""
    attachments: list = []
    if msgtype == "text":
        content = (msg.get("Content") or "").strip()
    elif msgtype in ("image", "file"):
        media_id = msg.get("MediaId") or ""
        if media_id:
            try:
                data, fname = await asyncio.to_thread(wecom.download_media, media_id)
                attachments.append(InboundAttachment(
                    kind=("image" if msgtype == "image" else "file"),
                    media={},
                    file_name=(msg.get("FileName") or fname or ""),
                    data=data,
                ))
            except Exception as e:  # noqa: BLE001
                print(f"[wecom] {wuid} download {msgtype} err: {type(e).__name__}: {e}")
    else:
        return PlainTextResponse("success")
    if not content and not attachments:
        return PlainTextResponse("success")  # empty message / every download failed

    async def _bg(uid=uid, content=content, attachments=attachments):
        try:
            reply = await _run_channel_conversation(
                app, uid, content, attachments, channel="wecom")
        except Exception as e:  # noqa: BLE001
            reply = f"[出错] {type(e).__name__}: {e}"
        if not (reply and reply.strip()):
            return
        # Nothing awaits this task's result (the done-callback only unregisters it), so
        # an exception escaping here would only show up as asyncio's "Task exception was
        # never retrieved" warning. Log push failures explicitly instead.
        try:
            res = await asyncio.to_thread(_wx.push_wecom, uid, reply)
        except Exception as e:  # noqa: BLE001
            print(f"[wecom] {uid} push reply err: {type(e).__name__}: {e}")
            return
        if not getattr(res, "ok", True):
            print(f"[wecom] {uid} push reply not ok: {getattr(res, 'reason', None)}")

    # Register in inflight: holds a strong reference so the task is not garbage-collected
    # mid-run, and lets shutdown drain it. The value None keeps it out of the broker
    # cancel; the inner _run_agent_bg runner has its own inflight entry for cancellation.
    bg = asyncio.create_task(_bg(), name=f"wecom-msg-{str(uid)[:8]}")
    app.state.inflight[bg] = None
    bg.add_done_callback(lambda t: app.state.inflight.pop(t, None))
    return PlainTextResponse("success")
@app.get("/v1/models", tags=["misc"])
def list_models(user_id: UUID = Depends(require_user)):
    """List all available LLM model profiles (scans config/models/*.yaml).

    Feeds the model dropdown in the top bar and the new-task dialog.
    `is_default` marks the entry equal to cfg["default_model"]. Nothing is
    cached: each request rescans the (few) YAML files, so edits take effect
    immediately.

    Files are skipped when they fail to read or parse, when they are not a
    mapping, or when their `variants` section is not a mapping. Variants that
    ModelCapabilities rejects are skipped too. A single bad file therefore
    never turns the whole endpoint into a 500.
    """
    from core.agent_builder import load_config
    from core.capabilities import ModelCapabilities
    from core.paths import ROOT
    import yaml as _yaml

    cfg = load_config()
    default = cfg["default_model"]
    models_dir = ROOT / cfg["models_dir"]
    out: list[dict] = []
    if not models_dir.is_dir():
        return {"models": out}
    for path in sorted(models_dir.glob("*.yaml")):
        try:
            data = _yaml.safe_load(path.read_text(encoding="utf-8")) or {}
        except Exception:
            continue  # unreadable / invalid YAML: skip this file, keep listing the rest
        # A top-level list/scalar, or a list-valued `variants`, would make
        # .get()/.keys() raise below; treat such files as invalid too.
        if not isinstance(data, dict):
            continue
        variants = data.get("variants") or {}
        if not isinstance(variants, dict):
            continue
        family = data.get("family") or path.stem
        for variant in variants:
            profile = f"{family}.{variant}"
            try:
                caps = ModelCapabilities.load(profile, models_dir)
            except (ValueError, FileNotFoundError):
                continue
            out.append({
                "profile": profile,
                "display_name": caps.display_name or profile,
                "family": caps.family,
                "variant": caps.variant,
                "thinking_mode": caps.thinking_mode,
                "is_default": profile == default,
            })
    return {"models": out}
@app.get("/v1/image_models", tags=["misc"])
def list_image_models(user_id: UUID = Depends(require_user)):
    """List image-generation models (the image section of config/media/doubao.yaml).

    Feeds the second dropdown in the top bar. An empty list means no image
    variant is configured, and the UI hides the dropdown. `is_default` marks the
    first variant (the agent_builder fallback). Not cached, so new variants
    added to the YAML appear immediately.
    """
    return {
        "models": [
            {
                "variant": key,
                "display_name": conf.get("display_name") or key,
                "model_id": conf.get("model_id") or "",
                "price_cny_per_image": conf.get("price_cny_per_image"),
                "is_default": idx == 0,
            }
            for idx, (key, conf) in enumerate(_list_image_variants())
        ]
    }
@app.get("/v1/video_models", tags=["misc"])
def list_video_models(user_id: UUID = Depends(require_user)):
    """List video-generation models (the video section of config/media/doubao.yaml).

    Same pattern as /v1/image_models; an empty list makes the UI hide the third
    dropdown. Each entry carries the default resolution/duration/ratio and the
    per-million-token prices, so users can see the cost level right in the
    dropdown.
    """
    models = []
    for idx, (key, conf) in enumerate(_list_video_variants()):
        entry = {
            "variant": key,
            "display_name": conf.get("display_name") or key,
            "model_id": conf.get("model_id") or "",
        }
        for field in (
            "default_resolution",
            "default_duration",
            "default_ratio",
            "price_cny_per_mtoken_text2video",
            "price_cny_per_mtoken_video2video",
        ):
            entry[field] = conf.get(field)
        entry["is_default"] = idx == 0
        models.append(entry)
    return {"models": models}
# ───────────── Auth ─────────────
@app.post("/v1/auth/login", tags=["auth"])
def login(body: LoginRequest):
    """Exchange the platform_key for a JWT whose `sub` is the given user_id.

    - Wrong platform_key -> 403; user_id that is not a UUID -> 400.
    - Creates the users row idempotently when missing, which avoids downstream
      FK failures. name / user_name from the body are upserted too, so
      platform-side renames sync on every login (see ensure_user_row).

    The platform backend uses this entry point to inject a given user_id; the
    dev SPA uses /login_password instead.
    """
    import hmac

    expected = auth_cfg.platform_key
    supplied = body.platform_key
    # Constant-time comparison: `!=` short-circuits on the first differing byte,
    # which leaks through response timing how much of the key was right. A missing
    # (non-string) key on either side never authenticates.
    if not (
        isinstance(expected, str)
        and isinstance(supplied, str)
        and hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))
    ):
        raise HTTPException(403, "invalid platform_key")
    try:
        uid = UUID(body.user_id)
    except (ValueError, TypeError):
        raise HTTPException(400, f"invalid user_id (must be UUID): {body.user_id!r}")
    ensure_user_row(uid, name=body.name, user_name=body.user_name)
    token, exp = mint_token(auth_cfg, uid)
    prof = get_user_profile(uid) or {}
    return {
        "token": token,
        "expires_at": _dt.fromtimestamp(exp).isoformat(),
        "user_id": str(uid),
        "name": prof.get("name"),
        "user_name": prof.get("user_name"),
        "role": prof.get("role", "user"),
        "ttl_seconds": auth_cfg.ttl_seconds,
    }
@app.post("/v1/auth/admin/create_user", tags=["auth"])
def admin_create_user(body: AdminCreateUserRequest):
    """Admin-issued account creation (entry at the bottom-right of the dev SPA login page).

    - `ZCBOT_ADMIN_TOKEN` env unset -> 503 (feature disabled).
    - Missing or wrong `admin_token` -> 403. Both cases get the same response,
      so probing cannot tell them apart.
    - Invalid email / too-short password / invalid role -> 400.
    - Email already exists -> 409.
    - Success -> `{"user_id": ..., "email": ..., "role": ...}`; the frontend then
      shows "created, please log in".

    Mints no token and does not log the user in; the new user logs in themselves.
    """
    import hmac

    expected = auth_cfg.admin_token
    if expected is None:
        raise HTTPException(503, "admin create_user disabled (ZCBOT_ADMIN_TOKEN not set)")
    supplied = body.admin_token
    # Constant-time comparison so response timing does not leak how much of the token matched.
    if not (
        isinstance(expected, str)
        and isinstance(supplied, str)
        and hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))
    ):
        raise HTTPException(403, "invalid admin_token")
    try:
        uid, email = create_user(
            email=body.email, password=body.password, role=body.role
        )
    except UserCreateError as ex:
        if ex.code in ("invalid_email", "weak_password", "invalid_role"):
            raise HTTPException(400, ex.message)
        if ex.code == "email_taken":
            raise HTTPException(409, "email already exists")
        raise HTTPException(500, f"create_user failed: {ex.message}")
    return {"user_id": str(uid), "email": email, "role": body.role}
@app.post("/v1/auth/login_password", tags=["auth"])
def login_password(body: PasswordLoginRequest):
    """Email + password login (dev SPA, for colleagues / self-testing).

    - Unknown email, empty password_hash and a failed bcrypt check all give 403.
      They are deliberately indistinguishable, so account existence cannot be probed.
    - On success the JWT is minted for the user_id already in the DB. There is no
      ensure_user_row call; the row was created by `user add`.
    - Adding users: `.venv/Scripts/python.exe main.py user add --email X --password Y`.
      Removing: `DELETE FROM users WHERE email=...` (delete that user's tasks first).
    """
    hit = resolve_user_by_email(body.email, body.password)
    if hit is None:
        raise HTTPException(403, "账号或密码错误")
    uid, email = hit
    token, exp = mint_token(auth_cfg, uid)
    profile = get_user_profile(uid) or {}
    result = {
        "token": token,
        "expires_at": _dt.fromtimestamp(exp).isoformat(),
        "user_id": str(uid),
        "email": email,
    }
    result.update(
        name=profile.get("name"),
        user_name=profile.get("user_name"),
        role=profile.get("role", "user"),
        ttl_seconds=auth_cfg.ttl_seconds,
    )
    return result
@app.post("/v1/auth/change_password", tags=["auth"])
def change_password_route(
    body: ChangePasswordRequest, user_id: UUID = Depends(require_user)
):
    """Change the caller's password (dev SPA top-bar entry).

    user_id comes from the JWT; nothing the client sends is trusted for identity.
    - New password too weak (< 6 chars) -> 400.
    - Wrong old password, or an account without a password (created via
      platform_key) -> 403. These are not distinguished, to resist probing.
    - JWT valid but the user row is gone -> 401.
    - Any other error code -> 500.
    Success -> `{"ok": true}`; the frontend confirms and clears the form.
    """
    http_status_for = {
        "weak_password": 400,
        "wrong_password": 403,
        "no_password": 403,
        "user_not_found": 401,
    }
    try:
        change_password(user_id, body.old_password, body.new_password)
    except UserCreateError as ex:
        mapped = http_status_for.get(ex.code)
        if mapped is None:
            raise HTTPException(500, f"change_password failed: {ex.message}")
        raise HTTPException(mapped, ex.message)
    return {"ok": True}
# ───────────── Tasks CRUD ─────────────
@app.post("/v1/tasks", status_code=201, tags=["tasks"])
def create_task(body: TaskCreateRequest, user_id: UUID = Depends(require_user)):
    """Create a task.

    - `name` is required (display name; the DB column is NOT NULL and the UI
      lists/titles use it).
    - `working_dir` is optional; blank falls back to `name` as the directory
      name. Several tasks may share one working_dir (same directory, §7.1).
    - Both go through validate_task_name (a simple name: no `/`, backslash or
      `..`, no leading `.`, at most 255 chars).
    - A working dir that nests with another of the same user's (no-subtask
      rule) -> 409.

    All validation, including the model profile, happens before the working
    directory is created. A rejected request therefore never leaves an empty
    orphan directory behind, which would otherwise show up in /v1/folders.
    """
    from core.agent_builder import InvalidTaskName, resolve_workspace, validate_task_name, working_dir_from_name

    try:
        name = validate_task_name(body.name)
    except InvalidTaskName as e:
        raise HTTPException(400, f"name 不合法: {e}")
    # blank working_dir -> fall back to the task name
    wd_raw = (body.working_dir or "").strip()
    wd_name = wd_raw if wd_raw else name
    try:
        wd_name = validate_task_name(wd_name)
    except InvalidTaskName as e:
        raise HTTPException(400, f"working_dir 不合法: {e}")
    description = body.description.strip()
    skill = body.skill.strip()
    # Resolve/validate the model profile before any filesystem side effect.
    profile, model_id = _resolve_model_profile(body.model_profile)
    tid = uuid4()
    ws = resolve_workspace(None)
    fs_dir = working_dir_from_name(ws, user_id, wd_name)
    fs_dir_db = to_db_path(fs_dir)
    try:
        check_no_subtask(fs_dir_db, user_id=user_id)
    except NoSubtaskError as e:
        raise HTTPException(409, str(e))
    # Create the working directory right away; tasks sharing a working_dir reuse it.
    fs_dir.mkdir(parents=True, exist_ok=True)
    ensure_local_task_row(
        task_id=tid, name=name, working_dir=fs_dir_db, skill=skill,
        description=description, user_id=user_id,
        model=model_id, model_profile=profile,
    )
    with session_scope() as s:
        row = s.execute(select(Task).where(Task.task_id == tid)).scalar_one()
        return _task_dict(row, n_messages=0)
@app.get("/v1/tasks", tags=["tasks"])
def list_tasks_route(
    page: int = 1,
    page_size: int = 20,
    status: Optional[str] = None,
    skill: Optional[str] = None,
    working_dir: Optional[str] = None,
    q: Optional[str] = None,
    ordering: Optional[str] = None,
    run_status: Optional[str] = None,
    user_id: UUID = Depends(require_user),
):
    """List the caller's tasks with pagination, filters and ordering.

    - `page` >= 1 (1-based); `page_size` 1..100 (out-of-range values are clamped).
    - `status` must be in STATUS_FILTERS (active/completed/abandoned); other values are ignored.
    - `skill`: exact match (blank ignored).
    - `working_dir`: last path segment (e.g. `水泥申报`). It is expanded to
      `workspace/users/<uid>/<name>` for comparison.
    - `q`: case-insensitive substring search over name + description (ILIKE).
      `%` and `_` in the query are matched literally, not as wildcards.
    - `run_status`: comma-separated, allowlist `idle/running/cancelling/error`;
      other values are ignored. The dev SPA uses it to fetch active tasks of the
      same working dir, usually `running,cancelling`.
    - `ordering`: DRF-style, comma-separated, `-field` for descending; allowlist
      `created_at/updated_at/name/status`; unknown fields are ignored. The
      default is `-created_at`.

    WeChat-channel tasks are always pinned first (stable across pages); the
    requested ordering applies to the rest.
    Returns the standard page envelope `{page, page_size, count, results}`;
    `count` lets the frontend compute the page total.
    """
    # clamp + sanitize
    page = max(1, page)
    page_size = max(1, min(page_size, 100))
    status = status if status in STATUS_FILTERS else None
    skill = (skill or "").strip() or None
    wd_name = (working_dir or "").strip() or None
    q_text = (q or "").strip() or None
    rs_allowed = ("idle", "running", "cancelling", "error")
    run_status_set = {
        part.strip() for part in (run_status or "").split(",") if part.strip() in rs_allowed
    } or None
    # WHERE clause; soft-deleted tasks never appear in the list (see /restore).
    conditions = [Task.user_id == user_id, Task.deleted_at.is_(None)]
    if status:
        conditions.append(Task.status == status)
    if skill:
        conditions.append(Task.skill == skill)
    if wd_name:
        # last segment -> full DB form; this is how tasks sharing a working_dir are found.
        wd_db = f"workspace/users/{user_id}/{wd_name}"
        conditions.append(Task.working_dir == wd_db)
    if run_status_set:
        conditions.append(Task.run_status.in_(run_status_set))
    if q_text:
        # Escape LIKE metacharacters so user text is matched literally. Otherwise "_"
        # matches any single character ("a_b" also hits "axb") and "%" any run of text.
        escaped = (
            q_text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )
        pat = f"%{escaped}%"
        conditions.append(
            Task.name.ilike(pat, escape="\\") | Task.description.ilike(pat, escape="\\")
        )
    offset = (page - 1) * page_size
    with session_scope() as s:
        cnt = s.execute(
            select(func.count()).select_from(Task).where(*conditions)
        ).scalar_one() or 0
        # Pin the WeChat channel's long-lived task to the very top (stable across pages).
        pin = case((Task.channel == "wechat", 0), else_=1).asc()
        rows = s.execute(
            select(Task).where(*conditions)
            .order_by(pin, *_parse_ordering(ordering))
            .limit(page_size).offset(offset)
        ).scalars().all()
        tids = [r.task_id for r in rows]
        msg_counts = (
            dict(s.execute(
                select(Message.task_id, func.count())
                .where(Message.task_id.in_(tids))
                .group_by(Message.task_id)
            ).all())
            if tids else {}
        )
        usage = _usage_aggregates(s, tids)
        return {
            "page": page,
            "page_size": page_size,
            "count": int(cnt),
            "results": [
                _task_dict(
                    r,
                    n_messages=msg_counts.get(r.task_id, 0),
                    usage=usage.get(r.task_id),
                )
                for r in rows
            ],
        }
@app.get("/v1/tasks/{task_id}", tags=["tasks"])
def get_task(task_id: str, user_id: UUID = Depends(require_user)):
    """Return one task's metadata (messages excluded; use /messages for those).

    A malformed id, or a task owned by another user, -> 404.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    with session_scope() as s:
        owned_stmt = select(Task).where(Task.task_id == tid, Task.user_id == user_id)
        task_row = s.execute(owned_stmt).scalar_one_or_none()
        if task_row is None:
            raise HTTPException(404, f"task not found: {tid}")
        count_stmt = select(func.count()).select_from(Message).where(Message.task_id == tid)
        n_msgs = s.execute(count_stmt).scalar_one()
        usage_by_task = _usage_aggregates(s, [tid])
        return _task_dict(task_row, n_messages=n_msgs, usage=usage_by_task.get(tid))
@app.get("/v1/folders", tags=["folders"])
def list_folders(user_id: UUID = Depends(require_user)):
    """List the caller's working directories (non-dot subdirs of `workspace/users/<uid>/`).

    Used for autocomplete / picking an existing directory when creating a task.
    The filesystem is the source of truth, so directories created by hand that
    have no task yet are included. Each item carries n_tasks (number of
    non-deleted tasks using it) and last_used (latest task updated_at, ISO).
    Ordering: items with last_used first, descending; items without it last;
    ties sorted by name ascending.

    Task stats for all folders come from one GROUP BY query rather than one
    query per folder.
    """
    from core.agent_builder import resolve_workspace, user_root

    ws = resolve_workspace(None)
    root = user_root(ws, user_id)
    folder_names: list[str] = []
    if root.is_dir():
        folder_names = sorted(
            (p.name for p in root.iterdir() if p.is_dir() and not p.name.startswith(".")),
            key=str.lower,
        )
    prefix = f"workspace/users/{user_id}/"
    stats: dict[str, tuple[int, Any]] = {}
    if folder_names:
        with session_scope() as s:
            rows = s.execute(
                select(Task.working_dir, func.count(), func.max(Task.updated_at))
                .where(
                    Task.user_id == user_id,
                    Task.working_dir.in_([prefix + name for name in folder_names]),
                    Task.deleted_at.is_(None),
                )
                .group_by(Task.working_dir)
            ).all()
        stats = {wd: (int(n or 0), lu) for wd, n, lu in rows}
    folders: list[dict] = []
    for name in folder_names:
        # folders with no (non-deleted) task: 0 tasks, never used
        n, lu = stats.get(prefix + name, (0, None))
        folders.append({
            "name": name,
            "n_tasks": n,
            "last_used": _iso(lu),
        })
    # Two stable sorts: by name first, then by last_used descending (empty last).
    folders.sort(key=lambda f: f["name"])
    folders.sort(key=lambda f: f["last_used"] or "", reverse=True)
    return {"folders": folders}
# ───────────── Skills ─────────────
@app.get("/v1/skills", tags=["skills"])
def list_skills(user_id: UUID = Depends(require_user)):
    """List the skills available to the caller (built-in + their own) for the new-task dropdown.

    Rescanned on every request: built-in `skills/<name>/SKILL.md` plus the
    user's `.skills/<name>/SKILL.md`, about 3ms in steady state. Adding,
    editing or removing a skill directory therefore shows up without a restart.
    `build_agent` also builds a fresh SkillRegistry each time, so `load_skill`
    and system-prompt discovery are hot as well.

    Each entry carries `source` (builtin/user) and `overrides_builtin` (a user
    skill shadowing a built-in of the same name). `load_errors` lists user
    skills skipped because of frontmatter problems, so the frontend can warn.
    """
    from core.agent_builder import build_skill_registry, load_config, resolve_workspace

    cfg = load_config()
    registry = build_skill_registry(cfg, resolve_workspace(None, cfg), user_id, docker=False)
    skills = [
        {
            "name": sk.name,
            "description": sk.description,
            "source": sk.source,
            "overrides_builtin": sk.name in registry.user_overrides,
        }
        for sk in registry.skills.values()
    ]
    errors = [{"name": n, "reason": r} for n, r in registry.load_errors]
    return {"skills": skills, "load_errors": errors}
@app.get("/v1/skills/{name}", tags=["skills"])
def get_skill(name: str, user_id: UUID = Depends(require_user)):
    """Return a skill's full SKILL.md text (for the frontend's detail modal).

    Both built-in and user skills can be read. On a name clash the user's copy
    wins, matching what the agent sees. Unknown skill -> 404; read error -> 500.
    """
    from core.agent_builder import build_skill_registry, load_config, resolve_workspace

    cfg = load_config()
    registry = build_skill_registry(cfg, resolve_workspace(None, cfg), user_id, docker=False)
    found = registry.get(name)
    if found is None:
        raise HTTPException(404, f"skill not found: {name!r}")
    try:
        text = found.full_content()
    except OSError as e:
        raise HTTPException(500, f"读取 SKILL.md 失败: {e}")
    return {
        "name": found.name,
        "source": found.source,
        "overrides_builtin": found.name in registry.user_overrides,
        "content": text,
    }
@app.delete("/v1/skills/{name}", status_code=204, tags=["skills"])
def delete_skill(name: str, user_id: UUID = Depends(require_user)):
    """Delete one of the caller's private skills (the whole `.skills/<name>/` directory).

    Only user-sourced skills can be deleted. A built-in skill gives 404 ("you
    have no deletable skill by that name"). `.skills` is hidden in the file
    panel, so this is the only UI path for removing your own skill.
    """
    import shutil
    from core.agent_builder import build_skill_registry, load_config, resolve_workspace, user_root

    cfg = load_config()
    ws = resolve_workspace(None, cfg)
    registry = build_skill_registry(cfg, ws, user_id, docker=False)
    found = registry.get(name)
    if found is None or found.source != "user":
        raise HTTPException(404, f"no user skill to delete: {name!r}")
    # Defence in depth: skill_dir comes from a directory scan and should already lie under
    # the user's .skills tree, but refuse anything that resolves elsewhere.
    allowed_root = (user_root(ws, user_id) / ".skills").resolve()
    doomed = found.skill_dir.resolve()
    if allowed_root in doomed.parents:
        shutil.rmtree(doomed)
        return None
    raise HTTPException(400, "拒绝删除 .skills 之外的路径")
@app.get("/v1/memory", tags=["memory"])
def get_memory(user_id: UUID = Depends(require_user)):
    """Read-only overview of memory: raw core.md plus the extended list (filename + description).

    The frontend's memory dialog fetches everything in one call. The view is
    read-only: memory is edited only through conversation, managed by the agent
    (see `core/memory.py::_CONTRACT`). The GUI is the eyes, not the hands
    (DESIGN §3.7). It is read fresh from the filesystem each time, like the
    skills endpoints, so the agent's latest writes are visible immediately.
    """
    from core.agent_builder import resolve_workspace
    from core.memory import memory_view

    return memory_view(resolve_workspace(None), user_id)
@app.get("/v1/memory/extended/{filename}", tags=["memory"])
def get_memory_extended(filename: str, user_id: UUID = Depends(require_user)):
    """Return one extended-memory file's raw text (fetched when a list item is opened).

    An invalid or missing file name -> 404. Path-traversal checks live in
    `core/memory.py::read_extended_file`: only flat `.md` files under
    `.memory/extended/` are allowed, with a resolved-subtree check on top.
    """
    from core.agent_builder import resolve_workspace
    from core.memory import read_extended_file

    text = read_extended_file(resolve_workspace(None), user_id, filename)
    if text is not None:
        return {"filename": filename, "content": text}
    raise HTTPException(404, f"memory file not found: {filename!r}")
# ───────────── 定时任务(DESIGN §8.5)─────────────
# 前端只读展示 + 停用/删除两个便捷动作;建/改全走对话(schedule_* 工具)。
# 与对话工具共用 core.scheduler 服务层,两条路径不漂移。
@app.get("/v1/schedules", tags=["schedules"])
def list_schedules(user_id: UUID = Depends(require_user)):
    """List the caller's scheduled jobs (read-only; the frontend's schedule panel loads them all at once)."""
    from core import scheduler

    jobs = scheduler.list_jobs(user_id)
    return {"results": jobs}
@app.patch("/v1/schedules/{job_id}", tags=["schedules"])
def patch_schedule(
    job_id: str, body: SchedulePatchRequest, user_id: UUID = Depends(require_user),
):
    """Update a scheduled job. The frontend only toggles `enabled`; other edits go through chat.

    No fields -> 400; scheduler.JobError -> 404.
    """
    from core import scheduler

    enabled = body.enabled
    if enabled is None:
        raise HTTPException(400, "no fields to update")
    try:
        updated = scheduler.set_enabled(user_id, job_id, enabled)
    except scheduler.JobError as e:
        raise HTTPException(404, str(e))
    return updated
@app.delete("/v1/schedules/{job_id}", status_code=204, tags=["schedules"])
def delete_schedule(job_id: str, user_id: UUID = Depends(require_user)):
    """Delete a scheduled job (soft delete; it stops firing immediately). JobError -> 404."""
    from core import scheduler

    try:
        scheduler.cancel_job(user_id, job_id)
    except scheduler.JobError as e:
        raise HTTPException(404, str(e))
    return None
@app.delete("/v1/tasks/{task_id}", status_code=204, tags=["tasks"])
def delete_task(task_id: str, user_id: UUID = Depends(require_user)):
    """Soft delete: set deleted_at=now() so the task disappears from the task list.

    The DB row, messages and usage_events are kept (the former CASCADE no longer
    fires), and so are the working-directory files. They are retained as
    training data and can be restored via POST /v1/tasks/{id}/restore. No file
    on disk is touched. Deleting an already soft-deleted task is an idempotent
    204. Another user's task or an unknown id -> 404.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    with session_scope() as s:
        found = s.execute(
            select(Task.deleted_at).where(Task.task_id == tid, Task.user_id == user_id)
        ).first()
        if found is None:
            raise HTTPException(404, f"task not found: {tid}")
        already_deleted = found.deleted_at is not None
        if not already_deleted:
            stmt = (
                update(Task)
                .where(Task.task_id == tid, Task.user_id == user_id)
                .values(deleted_at=func.now())
            )
            s.execute(stmt)
    return None  # 204
@app.post("/v1/tasks/{task_id}/restore", tags=["tasks"])
def restore_task(task_id: str, user_id: UUID = Depends(require_user)):
    """Undo a soft delete (deleted_at = NULL) so the task appears in the list again.

    Restoring a task that was never deleted succeeds idempotently.
    An unknown task, or another user's task, yields 404.
    Returns the task dict with message count and usage aggregates.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    with session_scope() as s:
        task = s.execute(
            select(Task).where(Task.task_id == tid, Task.user_id == user_id)
        ).scalar_one_or_none()
        if task is None:
            raise HTTPException(404, f"task not found: {tid}")
        # Plain attribute assignment marks the ORM object dirty; session_scope
        # writes it to the DB when it commits.
        task.deleted_at = None
        message_count = s.execute(
            select(func.count()).select_from(Message).where(Message.task_id == tid)
        ).scalar_one()
        usage_by_task = _usage_aggregates(s, [tid])
        return _task_dict(task, n_messages=message_count, usage=usage_by_task.get(tid))
@app.patch("/v1/tasks/{task_id}", tags=["tasks"])
def patch_task(
    task_id: str,
    body: TaskPatchRequest,
    user_id: UUID = Depends(require_user),
):
    """Partially update a task and return the refreshed task dict.

    - `status` must be in STATUS_WRITABLE, otherwise 400. Switching back to
      active is done through the CLI.
    - `description` / `skill` are written as given.
    - `name` goes through validate_task_name; a rejected name yields 400.
    - `model_profile` is resolved to (profile, model_id) and both columns are
      written together. The change takes effect on the next send; a run
      already in flight is unaffected (build_agent re-reads on its next start).
    - No updatable field supplied → 400; task missing or not owned → 404.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    changes: dict[str, Any] = {}
    status = body.status
    if status is not None:
        if status not in STATUS_WRITABLE:
            raise HTTPException(
                400, f"invalid status {status!r}; allowed: {STATUS_WRITABLE}"
            )
        changes["status"] = status
    # Free-text fields copied through verbatim.
    for field in ("description", "skill"):
        value = getattr(body, field)
        if value is not None:
            changes[field] = value
    if body.name is not None:
        from core.agent_builder import InvalidTaskName, validate_task_name
        try:
            changes["name"] = validate_task_name(body.name)
        except InvalidTaskName as e:
            raise HTTPException(400, f"name 不合法: {e}")
    if body.model_profile is not None:
        # Profile and concrete model id are always updated as a pair.
        changes["model_profile"], changes["model"] = _resolve_model_profile(
            body.model_profile
        )
    if not changes:
        raise HTTPException(400, "no fields to update")

    with session_scope() as s:
        outcome = s.execute(
            update(Task)
            .where(Task.task_id == tid, Task.user_id == user_id)
            .values(**changes)
        )
        if outcome.rowcount == 0:
            raise HTTPException(404, f"task not found: {tid}")
        task = s.execute(select(Task).where(Task.task_id == tid)).scalar_one()
        total = s.execute(
            select(func.count()).select_from(Message).where(Message.task_id == tid)
        ).scalar_one()
        agg = _usage_aggregates(s, [tid])
        return _task_dict(task, n_messages=total, usage=agg.get(tid))
# ───────────── Messages ─────────────
def _assert_owns_task(s, tid: UUID, user_id: UUID) -> None:
    """Raise HTTP 404 unless task `tid` exists and belongs to `user_id`.

    `s` is an open DB session (as yielded by session_scope). A missing task
    and a task owned by someone else produce the same 404, so the caller
    cannot tell the two apart.
    """
    owner_match = select(Task.task_id).where(
        Task.task_id == tid, Task.user_id == user_id
    )
    if s.execute(owner_match).first() is None:
        raise HTTPException(404, f"task not found: {tid}")
@app.get("/v1/tasks/{task_id}/messages", tags=["messages"])
def list_messages(
    task_id: str,
    limit: Optional[int] = None,
    before_idx: Optional[int] = None,
    after_idx: Optional[int] = None,
    user_id: UUID = Depends(require_user),
):
    """Return the task's message history in ascending idx order.

    The raw LiteLLM payload is passed through; the frontend renders it.

    Paging (two-way window):
    - no `limit` → the full history, ascending (kept for older frontends);
      both has_more flags are false and before_idx / after_idx are ignored.
    - `limit` only → the most recent `limit` messages (tail window).
    - `limit` + `before_idx` → up to `limit` messages with idx < before_idx,
      i.e. the page just above (scrolling up).
    - `limit` + `after_idx` → the first `limit` messages with idx > after_idx,
      i.e. the page just below (used after jumping to an old message from the
      outline). When both are given, after_idx wins.

    The response always carries `has_more` (something older exists before the
    window) and `has_more_after` (something newer exists after it).

    Raises:
        404: malformed task id, or task missing / not owned by the caller.
        400: `limit` < 1. Zero would yield an empty page claiming nothing is
            left, and PostgreSQL rejects a negative LIMIT, which used to
            surface as a 500.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    if limit is not None and limit < 1:
        raise HTTPException(400, f"limit must be >= 1, got {limit}")
    with session_scope() as s:
        _assert_owns_task(s, tid, user_id)
        cols = (
            Message.idx, Message.payload, Message.tokens_in,
            Message.tokens_out, Message.model_profile, Message.created_at,
        )
        if limit is None:
            # Legacy behavior: full history, ascending.
            rows = s.execute(
                select(*cols).where(Message.task_id == tid).order_by(Message.idx)
            ).all()
        elif after_idx is not None:
            # Downward window: earliest `limit` rows with idx > after_idx.
            rows = list(s.execute(
                select(*cols)
                .where(Message.task_id == tid, Message.idx > after_idx)
                .order_by(Message.idx).limit(limit)
            ).all())
        else:
            # Tail / upward window: take `limit` rows newest-first, then flip
            # back to ascending order.
            q = select(*cols).where(Message.task_id == tid)
            if before_idx is not None:
                q = q.where(Message.idx < before_idx)
            rows = list(s.execute(q.order_by(Message.idx.desc()).limit(limit)).all())
            rows.reverse()
        # Is there anything beyond either end of the window? The frontend's
        # top/bottom sentinels use this to decide whether to keep loading.
        # Without a limit the whole history was returned, so skip the probes.
        has_more = False
        has_more_after = False
        if rows and limit is not None:
            first_idx = rows[0].idx
            last_idx = rows[-1].idx
            has_more = s.execute(
                select(Message.idx)
                .where(Message.task_id == tid, Message.idx < first_idx)
                .limit(1)
            ).first() is not None
            has_more_after = s.execute(
                select(Message.idx)
                .where(Message.task_id == tid, Message.idx > last_idx)
                .limit(1)
            ).first() is not None
        return {
            "has_more": has_more,
            "has_more_after": has_more_after,
            "messages": [
                {
                    "idx": r.idx,
                    "payload": dict(r.payload),
                    "tokens_in": r.tokens_in,
                    "tokens_out": r.tokens_out,
                    # Migration 0006: non-null on assistant rows; names the
                    # model that produced the message.
                    "model_profile": r.model_profile,
                    "created_at": _iso(r.created_at),
                }
                for r in rows
            ]
        }
@app.get("/v1/tasks/{task_id}/outline", tags=["messages"])
def task_outline(task_id: str, user_id: UUID = Depends(require_user)):
    """Message outline: one {idx, snippet} per user turn, ascending by idx.

    Feeds the dot rail on the right-hand side. Only the idx and a snippet of
    the content of role=user messages are selected, never the whole payload,
    which keeps it light even for long tasks. The query narrows by task via
    the (task_id, idx) index; the role filter is a residual predicate.

    When a dot is clicked, the frontend either scrolls to an already-loaded
    message or first fetches a centred window with before_idx.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    content_text = Message.payload["content"].astext
    is_user_turn = Message.payload["role"].astext == "user"
    with session_scope() as s:
        _assert_owns_task(s, tid, user_id)
        user_turns = s.execute(
            select(Message.idx, content_text)
            .where(Message.task_id == tid, is_user_turn)
            .order_by(Message.idx)
        ).all()
        items = [
            {"idx": idx, "snippet": _outline_snippet(text)}
            for idx, text in user_turns
        ]
        return {"items": items}
@app.post("/v1/tasks/{task_id}/messages", status_code=202, tags=["messages"])
async def post_message(
    task_id: str,
    body: MessageRequest,
    user_id: UUID = Depends(require_user),
):
    """Accept a user message and start a background run; return `{events_url}`.

    The client should subscribe to the SSE stream at events_url immediately
    to receive the streamed output.

    Single active run per task: the `SELECT … FOR UPDATE` row lock, the
    active-state check and the switch to `running` all happen in one
    transaction. This stops a user double-clicking send from launching two BG
    threads that race on `messages.idx`.
    - run_status in ('running', 'cancelling') → 409.
    - 'error' is cleared when a new run starts; like ok / cancelled, it is
      treated as restartable.

    All request validation (drain state, empty content, image/video variant)
    now runs BEFORE the row is flipped to `running`. Previously the variant
    checks ran after the commit. An invalid image_model / video_model then
    answered 400 but left the task in `running` with no worker, so later
    sends got 409 until something else reset the status.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    # Shutdown drain: refuse new runs and send Retry-After so clients back off
    # and retry (back-pressure during the deploy window).
    if getattr(app.state, "draining", None) is not None and app.state.draining.is_set():
        raise HTTPException(
            503, "server is restarting; retry shortly",
            headers={"Retry-After": "3"},
        )
    content = (body.content or "").strip()
    if not content:
        raise HTTPException(400, "empty content")
    # Resolve image/video variants at request time rather than in the BG
    # thread, where a failure would escape the event sink and be hard to trace.
    # An empty string is passed through without consulting the yaml. This must
    # happen before the DB transition below so a rejected variant cannot leave
    # run_status stuck at 'running'.
    image_variant = _resolve_image_model(body.image_model)
    video_variant = _resolve_video_model(body.video_model)
    with session_scope() as s:
        row = s.execute(
            select(Task.run_status)
            .where(Task.task_id == tid, Task.user_id == user_id)
            .with_for_update()
        ).first()
        if row is None:
            raise HTTPException(404, f"task not found: {tid}")
        if row.run_status in ("running", "cancelling"):
            raise HTTPException(
                409,
                f"task already has an active run (status={row.run_status}); "
                f"wait for it to finish or cancel",
            )
        s.execute(
            update(Task).where(Task.task_id == tid).values(
                run_status="running", run_error=None,
            )
        )
    # Clear the previous run's done marker so new subscribers see the stream.
    broker.start(tid)
    # The commit above released the row lock; the BG thread takes over from
    # here (its sink bridges events back to the asyncio loop via broker).
    # Register the task in app.state.inflight:
    #   1. the shutdown drain awaits it;
    #   2. the strong reference stops the task being garbage-collected
    #      mid-run (asyncio.create_task keeps no reference of its own).
    # The done callback removes the entry again.
    run_task = asyncio.create_task(asyncio.to_thread(
        _run_agent_bg, tid, user_id, content, image_variant, video_variant,
    ))
    app.state.inflight[run_task] = tid
    run_task.add_done_callback(lambda t: app.state.inflight.pop(t, None))
    return {"events_url": f"/v1/tasks/{tid}/events"}
# ───────────── Cancel current run ─────────────
@app.post("/v1/tasks/{task_id}/cancel", status_code=202, tags=["tasks"])
def cancel_task(
    task_id: str,
    user_id: UUID = Depends(require_user),
):
    """Send a cooperative cancel signal to the task's active run.

    - With a single active run per task, "cancel the current activity" is
      unambiguous, so the client only needs the task_id.
    - The task must belong to the caller, otherwise 404.
    - run_status other than `running` → 409; idle / cancelling / error
      cannot be cancelled.
    - Sets the transitional `cancelling` state. The BG run loop polls it
      between stream chunks and tool calls and exits when it sees it. Its
      finally block then writes the terminal state (normal → idle,
      exception → error).
    - The LLM output is streamed, so cancel latency is roughly one chunk
      interval (on the order of 100 ms).
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    with session_scope() as s:
        # FOR UPDATE makes the check-then-set atomic with respect to other
        # writers locking the same task row.
        current = s.execute(
            select(Task.run_status)
            .where(Task.task_id == tid, Task.user_id == user_id)
            .with_for_update()
        ).first()
        if current is None:
            raise HTTPException(404, f"task not found: {tid}")
        status = current.run_status
        if status != "running":
            raise HTTPException(
                409,
                f"task not running (run_status={status}); cannot cancel",
            )
        s.execute(
            update(Task).where(Task.task_id == tid).values(run_status="cancelling")
        )
    broker.request_cancel(tid)
    return {"ok": True, "task_id": str(tid), "run_status": "cancelling"}
# ───────────── Clear conversation ─────────────
@app.post("/v1/tasks/{task_id}/clear", tags=["messages"])
def clear_messages(task_id: str, user_id: UUID = Depends(require_user)):
    """Delete every message of the task and reset token totals, cost and run_error.

    Files under the working_dir are left alone. This follows the same
    "FS view is regenerable" idea as task deletion: intermediate artifacts
    stay, so a fresh conversation can keep building on existing material.

    usage_events are not touched either. They are per-user account-level
    usage bookkeeping and should not be affected by clearing a conversation.

    - An active run (running / cancelling) → 409; cancel it first.
    - An errored task can be cleared; run_status is reset to 'idle' and
      run_error to None.
    - Another user's task → 404.

    Returns the refreshed task dict with n_messages=0.
    """
    from sqlalchemy import delete as sa_delete

    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    with session_scope() as s:
        locked = s.execute(
            select(Task.run_status)
            .where(Task.task_id == tid, Task.user_id == user_id)
            .with_for_update()
        ).first()
        if locked is None:
            raise HTTPException(404, f"task not found: {tid}")
        active_states = ("running", "cancelling")
        if locked.run_status in active_states:
            raise HTTPException(
                409,
                f"task has an active run (status={locked.run_status}); "
                f"cancel it first",
            )
        s.execute(sa_delete(Message).where(Message.task_id == tid))
        reset = {
            "tokens_prompt": 0,
            "tokens_completion": 0,
            "cost_cny": 0,
            "run_status": "idle",
            "run_error": None,
        }
        s.execute(update(Task).where(Task.task_id == tid).values(**reset))
        refreshed = s.execute(select(Task).where(Task.task_id == tid)).scalar_one()
        return _task_dict(refreshed, n_messages=0)
# ───────────── Prompt optimize(辅助 LLM 调用,不入 messages)─────────────
@app.post("/v1/tasks/{task_id}/optimize_prompt", tags=["messages"])
def optimize_prompt(
    task_id: str,
    body: OptimizePromptRequest,
    user_id: UUID = Depends(require_user),
):
    """Polish the user's draft prompt with the task's current model and return it.

    - Synchronous, non-streaming call (short text, a few seconds).
    - Not stored in messages and **not** added to tasks.tokens_prompt /
      tokens_completion, so the header counters reflect only the main chat.
      A separate usage_events row with kind="prompt_optimize" is written
      instead, for reconciliation and for GROUP BY kind to judge whether the
      button is worth keeping.
    - Not mutually exclusive with a main-chat run: it never writes messages,
      so there is no idx contention. Users can polish the next draft while
      the LLM is still streaming a reply.
    - image_model / video_model only add hints about the downstream tools to
      the meta-prompt. They are recorded in the usage row's units but
      otherwise do not touch the DB.

    Raises:
        404: malformed task id, or task missing / not owned by the caller.
        400: empty text or text longer than 4000 characters.
        500: the task's model profile cannot be loaded.
        502: the LLM call failed or returned no usable content.
    """
    from decimal import Decimal
    from core.agent_builder import load_config
    from core.capabilities import ModelCapabilities
    from core.llm import LLM
    from core.paths import ROOT
    from core.storage.usage import USD_TO_CNY
    # UsageEvent is already imported at module level.
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    text = (body.text or "").strip()
    if not text:
        raise HTTPException(400, "empty text")
    if len(text) > 4000:
        raise HTTPException(400, "text too long (>4000 chars)")
    with session_scope() as s:
        row = s.execute(
            select(Task.model_profile)
            .where(Task.task_id == tid, Task.user_id == user_id)
        ).first()
        if row is None:
            raise HTTPException(404, f"task not found: {tid}")
        task_model_profile = row.model_profile or ""
    cfg = load_config()
    chosen_profile = task_model_profile or cfg["default_model"]
    try:
        caps = ModelCapabilities.load(chosen_profile, ROOT / cfg["models_dir"])
    except (FileNotFoundError, ValueError) as e:
        raise HTTPException(
            500, f"invalid task model_profile {chosen_profile!r}: {e}"
        ) from e
    # Downstream tool context: chat model display name plus metadata of the
    # currently selected image / video variant.
    chat_model_display = caps.display_name or chosen_profile
    image_variant_hint = ""
    img_variant = (body.image_model or "").strip()
    if img_variant:
        for k, v in _list_image_variants():
            if k == img_variant:
                name = v.get("display_name") or k
                sz = v.get("default_size") or "2048x2048"
                image_variant_hint = (
                    f"\n下游生图工具:{name}(默认尺寸 {sz},支持中英文 prompt,"
                    f"擅长写实/插画/构图描述)。若用户意图涉及画面/封面/插图,"
                    f"润色后的文本要给出适合该模型的画面细节(主体/风格/光线/构图)。"
                )
    video_variant_hint = ""
    vid_variant = (body.video_model or "").strip()
    if vid_variant:
        for k, v in _list_video_variants():
            if k == vid_variant:
                name = v.get("display_name") or k
                res = v.get("default_resolution") or "720p"
                dur = v.get("default_duration") or 5
                video_variant_hint = (
                    f"\n下游生视频工具:{name}(默认 {res} / {dur}s / 16:9)。"
                    f"若用户意图涉及视频/动画/动起来,润色后的文本要补全 "
                    f"主体在做什么(运动)+ 镜头怎么动 + 场景 + 风格,而非静态画面描述。"
                )
    meta_prompt = (
        f"你的任务是润色用户输入的草稿,使之成为一个清晰、完整、可执行的 prompt。\n"
        f"当前对话模型:{chat_model_display}{image_variant_hint}{video_variant_hint}\n\n"
        f"规则:\n"
        f"1. 只输出润色后的文本本身,不要任何解释、前后缀、引号、markdown 代码块包裹\n"
        f"2. 保留用户原始语言(中文/英文)\n"
        f"3. 补全模糊点(主体、目标、风格、约束),但不要无中生有改变用户意图\n"
        f"4. 长度合理 — 简短诉求润色后也应当简洁,不要堆砌\n\n"
        f"用户草稿:\n{text}"
    )
    llm = LLM(caps)
    try:
        response = llm.chat(
            messages=[{"role": "user", "content": meta_prompt}],
            tools=None,
        )
    except Exception as e:
        # Request boundary: any provider/client failure becomes a 502.
        raise HTTPException(502, f"llm call failed: {type(e).__name__}: {e}") from e
    # Malformed shapes: no choices (IndexError), missing attributes
    # (AttributeError), or non-string content such as None containers / lists
    # (TypeError / AttributeError on .strip()). Anything else is a real bug and
    # should not be disguised as an upstream failure.
    try:
        optimized = (response.choices[0].message.content or "").strip()
    except (AttributeError, IndexError, TypeError) as e:
        raise HTTPException(502, "llm response missing content") from e
    if not optimized:
        raise HTTPException(502, "llm returned empty optimization")
    usage = getattr(response, "usage", None)
    prompt_tokens = int(getattr(usage, "prompt_tokens", 0) or 0)
    completion_tokens = int(getattr(usage, "completion_tokens", 0) or 0)
    # Cost lookup is best effort: unknown models or pricing errors count as 0.
    try:
        from litellm import completion_cost
        cost_usd_raw = completion_cost(completion_response=response)
        cost_usd = Decimal(str(cost_usd_raw)) if cost_usd_raw else Decimal("0")
    except Exception:
        cost_usd = Decimal("0")
    cost_cny = (cost_usd * USD_TO_CNY).quantize(Decimal("0.000001"))
    try:
        with session_scope() as s:
            s.add(UsageEvent(
                user_id=user_id,
                task_id=tid,
                message_id=None,
                kind="prompt_optimize",
                model_profile=chosen_profile,
                units={
                    "tokens_in": prompt_tokens,
                    "tokens_out": completion_tokens,
                    "usd_to_cny": float(USD_TO_CNY),
                    "image_model_hint": img_variant or "",
                    "video_model_hint": vid_variant or "",
                },
                cost_cny=cost_cny,
            ))
    except Exception as e:
        # A failed usage record must not block the response: the user getting
        # the polished text matters more; the record can be patched up manually.
        print(f"[optimize_prompt] usage record failed: {type(e).__name__}: {e}", flush=True)
    return {
        "optimized": optimized,
        "model_profile": chosen_profile,
        "tokens_in": prompt_tokens,
        "tokens_out": completion_tokens,
        "cost_cny": float(cost_cny),
    }
# ───────────── SSE events ─────────────
@app.get("/v1/tasks/{task_id}/events", tags=["tasks"])
async def stream_events(
    task_id: str,
    user_id: UUID = Depends(require_user),
):
    """SSE stream of the task's current run events.

    With a single active run per task, the subscription is unambiguous.

    Event types: run_start / llm_start / text / tool_call / tool_result /
    llm_end / cancelled / error / done. `data` is a JSON dict; the event's
    `type` key is removed and used as the SSE event name instead. The stream
    ends after a `done` or `error` event.

    Client disconnects and server shutdown cancel the generator. The
    resulting asyncio.CancelledError is allowed to propagate, per asyncio's
    cancellation contract, instead of being swallowed. The `finally` clause
    still unsubscribes from the broker in every case.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    with session_scope() as s:
        _assert_owns_task(s, tid, user_id)
        run_status = s.execute(
            select(Task.run_status).where(Task.task_id == tid)
        ).scalar_one()
    # Reconnect guard: if the task is not active (process restart, reaper
    # already finished it, or it ended naturally), emit `done` and close.
    # Otherwise the in-process broker queue stays empty and the client would
    # hang on pings forever.
    is_active = run_status in ("running", "cancelling")

    async def gen():
        yield b": connected\nretry: 3000\n\n"
        if not is_active:
            yield _sse_event("done", {})
            return
        q = broker.subscribe(tid)
        try:
            while True:
                try:
                    ev = await asyncio.wait_for(q.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    # SSE comment line used as a heartbeat while no events arrive.
                    yield b": ping\n\n"
                    continue
                ev_type = ev.get("type", "msg")
                payload = {k: v for k, v in ev.items() if k != "type"}
                yield _sse_event(ev_type, payload)
                if ev_type in ("done", "error"):
                    break
        finally:
            # Runs on normal completion and on cancellation (client gone /
            # shutdown); CancelledError then continues to propagate.
            broker.unsubscribe(tid, q)

    return StreamingResponse(
        gen(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
# ───────────── Files(user-rooted,不绑 task) ─────────────
@app.get("/v1/user/storage", tags=["user"])
def user_storage(user_id: UUID = Depends(require_user)):
    """Report the caller's disk usage and quota.

    Figures come from the user_disk_usage table, which a background scan
    fills (every 15 min by default). They are not live.

    Without a scan record (new user, or before the first scan), bytes_used
    and file_count are 0 and scanned_at is None. limit_bytes is None when no
    positive quota is configured; the frontend then draws no progress bar.
    """
    from core.agent_builder import load_config as _load_cfg
    from core.storage.disk_quota import get_user_usage, parse_bytes

    quotas = _load_cfg().get("quotas") or {}
    limit = parse_bytes(quotas.get("disk_bytes_per_user"))
    snapshot = get_user_usage(user_id)
    bytes_used, file_count, scanned_at = (
        (0, 0, None) if snapshot is None else snapshot
    )
    return {
        "bytes_used": bytes_used,
        "file_count": file_count,
        "limit_bytes": limit if limit and limit > 0 else None,
        "scanned_at": None if not scanned_at else scanned_at.isoformat(),
    }
@app.get("/v1/files", tags=["files"])
def list_files(
    path: str = "",
    user_id: UUID = Depends(require_user),
):
    """List a directory under the caller's user_root, with breadcrumbs.

    An empty `path` means user_root itself. `../` and absolute paths are
    rejected with 400. Dotfiles (`.memory/` and the like) are always hidden.
    """
    user_root = _load_user_root(user_id)
    directory = _safe_join(user_root, path)
    entries, crumbs, exists = _enumerate_files(user_root, directory)
    listing = {
        "root": _norm_path(str(user_root)),
        "current": _rel_to(user_root, directory),
        "exists": exists,
        "crumbs": crumbs,
        "entries": entries,
    }
    return listing
@app.get("/v1/files/download", tags=["files"])
def download_file(
    path: str,
    user_id: UUID = Depends(require_user),
):
    """Serve a single regular file under user_root (directory → 400, missing → 404)."""
    target = _safe_join(_load_user_root(user_id), path)
    if not target.exists():
        raise HTTPException(404, f"file not found: {path}")
    if not target.is_file():
        raise HTTPException(400, f"not a file: {path}")
    # Workspace files change in place. "no-cache" stops browsers from applying
    # heuristic freshness (RFC 7234 allows hours), which would make the SPA
    # preview show stale content after an edit. Per the original note,
    # Starlette's FileResponse answers every hit with a full 200 (no 304);
    # that is acceptable for small workspace files.
    no_cache = {"Cache-Control": "no-cache"}
    return FileResponse(path=str(target), filename=target.name, headers=no_cache)
@app.get("/v1/files/preview_pdf", tags=["files"])
async def preview_pdf(
    path: str,
    user_id: UUID = Depends(require_user),
):
    """Convert a .pptx/.ppt file under user_root to PDF and return it.

    The frontend reuses its PDF iframe to preview the result. Conversion runs
    on the backend host (not inside the sandbox), on demand, and is cached
    under `.preview/` (DESIGN §8.3).

    Raises:
        404: file missing. 400: not a regular file, or not .pptx/.ppt.
        501: soffice not installed. 500: conversion failed or timed out.
        The frontend falls back to a plain download in the error cases.
    """
    from .pptx_render import (
        PptxConvertError,
        SofficeNotFoundError,
        pptx_to_pdf,
    )
    root = _load_user_root(user_id)
    target = _safe_join(root, path)
    if not target.exists():
        raise HTTPException(404, f"file not found: {path}")
    if not target.is_file():
        raise HTTPException(400, f"not a file: {path}")
    if target.suffix.lower() not in (".pptx", ".ppt"):
        raise HTTPException(400, f"not a pptx: {path}")
    abs_path = str(target.resolve())
    # Per-file lock so concurrent previews of the same deck don't run two
    # conversions at once. The blocking conversion is moved off the event loop
    # with asyncio.to_thread (the same pattern post_message uses), replacing
    # asyncio.get_event_loop(), which is discouraged inside coroutines.
    async with _pptx_lock_for(abs_path):
        try:
            pdf_path = await asyncio.to_thread(pptx_to_pdf, target)
        except SofficeNotFoundError as e:
            raise HTTPException(501, str(e)) from e
        except PptxConvertError as e:
            raise HTTPException(500, str(e)) from e
    return FileResponse(
        path=str(pdf_path),
        media_type="application/pdf",
        headers={"Cache-Control": "no-cache"},
    )
@app.post("/v1/files/upload", tags=["files"])
async def upload_files(
    path: str = Form(""),
    files: list[UploadFile] = File(...),
    user_id: UUID = Depends(require_user),
):
    """Multipart multi-file upload into `<user_root>/<path>/`.

    - The target directory is created on demand (parents=True); an existing
      file with the same name is overwritten.
    - Filenames are validated strictly: empty, `.`/`..`, containing `/`,
      `\\` or a NUL byte, or resolving outside user_root → 400. A name that
      collides with an existing directory → 400.
    - Every name is validated BEFORE anything is written. An invalid name
      later in the batch therefore no longer leaves the earlier files behind
      as a partial upload.
    - Contents are streamed to disk in a worker thread, so large uploads are
      neither buffered whole in memory nor written on the event loop.
    - Over the disk quota → 413.

    Returns {"count", "saved": [{"name", "size", "rel"}, ...]}.
    """
    import shutil
    from core.agent_builder import load_config as _load_cfg
    from core.storage.disk_quota import check_disk_quota, parse_bytes

    # Disk quota gate (§7.5 #4): over quota → 413 so the user clears old output.
    _quotas_cfg = (_load_cfg().get("quotas") or {})
    _limit = parse_bytes(_quotas_cfg.get("disk_bytes_per_user"))
    if _limit is not None and _limit > 0:
        _err = check_disk_quota(user_id, _limit)
        if _err is not None:
            raise HTTPException(413, _err)

    root = _load_user_root(user_id)
    dest_dir = _safe_join(root, path)
    if dest_dir.exists() and not dest_dir.is_dir():
        raise HTTPException(400, f"upload target is a file, not a directory: {path}")

    uploads = list(files or [])
    if not uploads:
        raise HTTPException(400, "no files uploaded")

    # Phase 1: validate every name and destination before touching the disk.
    root_resolved = root.resolve()
    planned: list[tuple[UploadFile, str, Path]] = []
    for up in uploads:
        raw_name = up.filename or ""
        if (
            not raw_name
            or raw_name in (".", "..")
            or "/" in raw_name or "\\" in raw_name
            or "\x00" in raw_name
            or any(part in (".", "..") for part in Path(raw_name).parts)
        ):
            raise HTTPException(400, f"invalid filename: {raw_name!r}")
        dest = dest_dir / raw_name
        try:
            # resolve() follows symlinks, so a link pointing outside is caught too.
            dest.resolve().relative_to(root_resolved)
        except ValueError:
            raise HTTPException(400, f"path escapes user_root: {raw_name!r}")
        if dest.is_dir():
            raise HTTPException(400, f"upload target is a directory: {raw_name!r}")
        planned.append((up, raw_name, dest))

    dest_dir.mkdir(parents=True, exist_ok=True)

    def _write_upload(src, out: Path) -> int:
        """Copy file object `src` to `out` in 1 MiB chunks; return bytes written."""
        with out.open("wb") as fh:
            shutil.copyfileobj(src, fh, 1024 * 1024)
            return fh.tell()

    # Phase 2: write.
    saved: list[dict] = []
    for up, raw_name, dest in planned:
        await up.seek(0)
        size = await asyncio.to_thread(_write_upload, up.file, dest)
        saved.append({"name": raw_name, "size": size, "rel": _rel_to(root, dest)})
    return {"count": len(saved), "saved": saved}
@app.post("/v1/files/delete", tags=["files"])
def delete_file(
    body: FileDeleteRequest,
    user_id: UUID = Depends(require_user),
):
    """Delete a file or directory under user_root.

    - `recursive=False` (default): directories must be empty (`rmdir`);
      a non-empty directory → 400.
    - `recursive=True`: `shutil.rmtree`. If the target is a top-level
      directory still referenced by a live (not soft-deleted) task's
      working_dir → 409, asking the user to DELETE the task first. This
      avoids the DB still pointing at a directory whose artifacts are gone.
    - Empty directories, top-level or nested, can always be removed. The
      task.working_dir column is left alone; build_agent recreates the
      directory on demand, since FS directories are treated as regenerable.
    - user_root itself → 400; missing path → 404; OS errors → 400.
    """
    import shutil

    root = _load_user_root(user_id)
    target = _safe_join(root, body.path)
    root_real = root.resolve()
    if target.resolve() == root_real:
        raise HTTPException(400, "cannot delete user_root")
    if not target.exists():
        raise HTTPException(404, f"path not found: {body.path}")
    is_directory = target.is_dir()
    if is_directory and body.recursive and target.parent.resolve() == root_real:
        db_form = to_db_path(target)
        with session_scope() as s:
            n = s.execute(
                select(func.count()).select_from(Task).where(
                    Task.user_id == user_id,
                    Task.working_dir == db_form,
                    Task.deleted_at.is_(None),  # soft-deleted tasks no longer count
                )
            ).scalar_one() or 0
        if n:
            raise HTTPException(
                409,
                f"该顶层目录正被 {n} 个 task 引用,不能递归删除;"
                f"请先 DELETE task,再清残留文件",
            )
    try:
        if not is_directory:
            target.unlink()
        elif body.recursive:
            shutil.rmtree(target)
        else:
            target.rmdir()  # raises OSError when the directory is not empty
    except OSError as e:
        raise HTTPException(400, f"delete failed: {e}")
    return {"ok": True, "path": body.path}
@app.post("/v1/files/rename", tags=["files"])
def rename_path(
    body: FileRenameRequest,
    user_id: UUID = Depends(require_user),
):
    """Rename a file or directory at any depth under user_root.

    - `path` (required) names the object to rename; it may not be user_root.
    - `new_name` is the new leaf name, not a path; the parent is taken from
      the original path. It is checked by validate_task_name: non-empty, no
      `/`, `\\` or `..`, not a dotfile, at most 255 characters.
    - The sibling `<parent>/<new_name>` must not already exist (no overwrite).
    - **Top-level directory** (direct child of user_root and a directory):
      DB-aware handling.
        * In one transaction, `SELECT ... FOR UPDATE` locks every task whose
          working_dir is this directory. If any has run_status running /
          cancelling → 409, so a BG thread cannot keep using the old path
          while the DB already points at the new one.
        * `check_no_subtask(new_db, exclude=renamed tids)` prevents the new
          name from nesting with another task's directory.
        * `UPDATE tasks SET working_dir=new_db WHERE task_id IN (...)` is
          written to the DB first.
        * Then the FS rename. On failure an exception is raised and
          session_scope rolls the DB update back.
        * The only inconsistency window is "FS renamed, then the commit
          fails" (a single-transaction PG commit rarely fails).
    - Anything else (nested directory or any file) → plain FS rename, no DB.

    Returns {"ok", "old", "new", "tasks_updated"}.
    """
    from core.agent_builder import InvalidTaskName, validate_task_name
    root = _load_user_root(user_id)
    target = _safe_join(root, body.path)
    if target.resolve() == root.resolve():
        raise HTTPException(400, "cannot rename user_root")
    if not target.exists():
        raise HTTPException(404, f"path not found: {body.path}")
    try:
        new_name = validate_task_name(body.new_name)
    except InvalidTaskName as e:
        raise HTTPException(400, f"new_name 不合法: {e}")
    if new_name == target.name:
        raise HTTPException(400, f"new_name 与原名相同: {new_name!r}")
    new_target = target.parent / new_name
    # Refuse to clobber an existing sibling.
    if new_target.exists():
        raise HTTPException(
            409, f"target already exists: {_rel_to(root, new_target)!r}"
        )
    # Only a directory directly under user_root can be a task working_dir
    # (the "working_dir = top-level directory" invariant), so only that case
    # needs the DB-aware path below.
    is_top_level_dir = (
        target.is_dir() and target.parent.resolve() == root.resolve()
    )
    if not is_top_level_dir:
        try:
            target.rename(new_target)
        except OSError as e:
            raise HTTPException(400, f"rename failed: {e}")
        return {
            "ok": True,
            "old": body.path,
            "new": _rel_to(root, new_target),
            "tasks_updated": 0,
        }
    # Top-level directory: DB-aware.
    old_db = to_db_path(target)
    new_db = to_db_path(new_target)
    with session_scope() as s:
        # Lock every referencing task row, soft-deleted ones included, so all
        # of them are repointed.
        rows = s.execute(
            select(Task.task_id, Task.run_status)
            .where(Task.user_id == user_id, Task.working_dir == old_db)
            .with_for_update()
        ).all()
        tids = [r.task_id for r in rows]
        # Short (8-char) id prefixes, used only for the error message.
        active = [
            str(r.task_id)[:8] for r in rows
            if r.run_status in ("running", "cancelling")
        ]
        if active:
            raise HTTPException(
                409,
                f"folder has active run(s) on task(s) {active}; "
                f"cancel before renaming",
            )
        try:
            check_no_subtask(new_db, user_id=user_id, exclude_task_ids=tids)
        except NoSubtaskError as e:
            raise HTTPException(409, str(e))
        if tids:
            s.execute(
                update(Task)
                .where(Task.task_id.in_(tids))
                .values(working_dir=new_db)
            )
        # The FS rename happens inside the transaction, after the UPDATE, so
        # a failure here still rolls the DB back.
        try:
            target.rename(new_target)
        except OSError as e:
            # Raising HTTPException also sends session_scope down its except
            # branch, which rolls back the UPDATE.
            raise HTTPException(400, f"FS rename failed: {e}")
    return {
        "ok": True,
        "old": body.path,
        "new": _rel_to(root, new_target),
        "tasks_updated": len(tids),
    }
@app.post("/v1/files/copy", tags=["files"])
def copy_files(
    body: FileTransferRequest,
    user_id: UUID = Depends(require_user),
):
    """Batch-copy `paths` into `dest_dir/<name>` (directories recursively).

    - No overwrite: any existing target → 409.
    - No copying into itself or its own subtree. These pre-checks are
      delegated to _validate_transfer.
    - Top-level directories (possibly some task's working_dir) may be
      copied: the new copy has no task attached, so the DB is untouched.
    - Copying creates new bytes just like an upload, so the same per-user
      disk-quota gate applies (over quota → 413). Previously /copy skipped
      it and could grow a user past the limit /upload enforces.
    - Partial failure: if one FS copy raises, an HTTPException (500) is
      raised and **copies already made are kept**. There is no FS
      transaction to roll back. After the pre-checks pass, failures are
      usually disk-full or permission errors that cannot be recovered anyway.

    Returns {"ok", "count", "transferred": [{"old", "new"}, ...]}.
    """
    import shutil
    from core.agent_builder import load_config as _load_cfg
    from core.storage.disk_quota import check_disk_quota, parse_bytes

    root = _load_user_root(user_id)
    sources, dest = _validate_transfer(root, body.paths, body.dest_dir)
    # Disk quota gate (§7.5 #4), same rule as /v1/files/upload.
    _quotas_cfg = (_load_cfg().get("quotas") or {})
    _limit = parse_bytes(_quotas_cfg.get("disk_bytes_per_user"))
    if _limit is not None and _limit > 0:
        _err = check_disk_quota(user_id, _limit)
        if _err is not None:
            raise HTTPException(413, _err)
    transferred: list[dict] = []
    for src in sources:
        target = dest / src.name
        try:
            if src.is_dir():
                shutil.copytree(src, target)
            else:
                shutil.copy2(src, target)
        except OSError as e:
            raise HTTPException(
                500,
                f"copy failed at {src.name!r}: {e} "
                f"(已成功 {len(transferred)} 项,剩余未处理)",
            )
        transferred.append({
            "old": _rel_to(root, src),
            "new": _rel_to(root, target),
        })
    return {"ok": True, "count": len(transferred), "transferred": transferred}
@app.post("/v1/files/move", tags=["files"])
def move_files(
    body: FileTransferRequest,
    user_id: UUID = Depends(require_user),
):
    """Batch-move `paths` into `dest_dir/<name>`.

    - No overwrite and no self-nesting (same pre-checks as /copy).
    - **A top-level directory that is a live task's working_dir → 409.**
      This keeps the "working_dir = top-level directory" invariant. If such
      a directory could sink into a subdirectory, the DB-aware top-level
      rename logic would stop working. To archive one, DELETE the task first.
    - Soft-deleted tasks are not counted as references, matching
      /v1/files/delete. Previously they were, so a directory whose task had
      already been deleted could never be moved, even though the error told
      the user to delete the task.
    - /copy has no such restriction, because the new copy has no task.
    - Partial failure: as with /copy, moves already done are not rolled
      back. `shutil.move` practically only fails on an interrupted
      cross-volume copy, which is rare since the workspace lives on one disk.

    Returns {"ok", "count", "transferred": [{"old", "new"}, ...]}.
    """
    import shutil
    root = _load_user_root(user_id)
    sources, dest = _validate_transfer(root, body.paths, body.dest_dir)
    # Gate: top-level directories that are some task's working_dir.
    root_resolved = root.resolve()
    top_level_dirs = [
        src for src in sources
        if src.is_dir() and src.parent.resolve() == root_resolved
    ]
    if top_level_dirs:
        # db_form → directory name, used for the lookup and the error message.
        form_to_name = {to_db_path(d): d.name for d in top_level_dirs}
        with session_scope() as s:
            rows = s.execute(
                select(Task.working_dir, func.count())
                .where(
                    Task.user_id == user_id,
                    Task.working_dir.in_(list(form_to_name)),
                    Task.deleted_at.is_(None),  # soft-deleted tasks no longer count
                )
                .group_by(Task.working_dir)
            ).all()
        occupied = {wd: n for wd, n in rows}
        if occupied:
            names = ", ".join(
                f"{form_to_name[wd]!r}({n} 个 task)"
                for wd, n in occupied.items()
            )
            raise HTTPException(
                409,
                f"以下顶层目录正被 task 引用,不能移动:{names};"
                f"请先删 task,或改用复制",
            )
    transferred: list[dict] = []
    for src in sources:
        target = dest / src.name
        try:
            shutil.move(str(src), str(target))
        except OSError as e:
            raise HTTPException(
                500,
                f"move failed at {src.name!r}: {e} "
                f"(已成功 {len(transferred)} 项,剩余未处理)",
            )
        transferred.append({
            "old": _rel_to(root, src),
            "new": _rel_to(root, target),
        })
    return {"ok": True, "count": len(transferred), "transferred": transferred}
# ───────────── Export ─────────────
@app.get("/v1/tasks/{task_id}/export", tags=["export"])
def export_task(task_id: str, user_id: UUID = Depends(require_user)):
    """Export the conversation as a .docx download.

    The document is rendered into a temp file, which a BackgroundTask deletes
    after the response is sent.
    404: bad, unknown or foreign task id. 400: the task has no messages.
    500: rendering failed (the temp file is removed first).
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    with session_scope() as s:
        _assert_owns_task(s, tid, user_id)
        first_msg = s.execute(
            select(Message.message_id).where(Message.task_id == tid).limit(1)
        ).first()
    if not first_msg:
        raise HTTPException(400, "no messages to export")
    handle, raw_tmp = tempfile.mkstemp(suffix=".docx", prefix="zcbot-export-")
    os.close(handle)  # only the path is needed; the exporter reopens it
    out_file = Path(raw_tmp)
    try:
        from core.export_docx import export_chat_to_docx
        export_chat_to_docx(tid, out_path=out_file)
    except Exception as e:
        out_file.unlink(missing_ok=True)
        raise HTTPException(500, f"export failed: {type(e).__name__}: {e}")
    docx_mime = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    cleanup = BackgroundTask(out_file.unlink, missing_ok=True)
    return FileResponse(
        path=str(out_file),
        media_type=docx_mime,
        filename=f"chat_{str(tid)[:8]}.docx",
        background=cleanup,
    )
# ───────────── 管理后台(admin-only)─────────────
register_admin_routes(app, require_admin)
return app