3042 lines
136 KiB
Python
3042 lines
136 KiB
Python
"""FastAPI app: 纯 /v1 JSON API(2026-05-15 切换 — 详见 DESIGN §7.9)。
|
||
|
||
设计要点:
|
||
- 所有路由 `/v1/*` 前缀,响应 JSON;模板 / HTMX / 服务端 markdown 渲染全删
|
||
- SSE 事件 payload 是 JSON dict 而非 HTML 片段(`event: <type>` + `data: <json>`)
|
||
- Auth: PLATFORM_KEY → JWT 兑换(§7 D' 过渡形态,见 web/auth.py);OIDC 替换时只动 /v1/auth/login 内部
|
||
- 所有 /v1/tasks* 路由 Depends(require_user),按 user_id 隔离数据
|
||
- 豁免:/healthz、/docs、/openapi.json、/、/v1/auth/login、/static/*
|
||
- CORS allow_origins=["*"] 本地宽松;真发布按 platform 域名收紧
|
||
- `GET /` 302 → /static/dev.html(本地 dev SPA)
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import mimetypes
|
||
import os
|
||
import tempfile
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from contextlib import asynccontextmanager
|
||
from datetime import datetime as _dt
|
||
from pathlib import Path
|
||
from typing import Any, Optional
|
||
from uuid import UUID, uuid4
|
||
|
||
try:
|
||
import resource # Unix only;Windows dev 无此模块,RSS 监控自动降级跳过
|
||
except ImportError: # pragma: no cover - Windows
|
||
resource = None
|
||
|
||
from fastapi import Depends, FastAPI, File, Form, HTTPException, Request, UploadFile
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import FileResponse, RedirectResponse, StreamingResponse
|
||
from pydantic import BaseModel
|
||
from sqlalchemy import BigInteger, cast, func, select, update
|
||
from starlette.background import BackgroundTask
|
||
|
||
from core import __version__
|
||
from core.paths import from_db_path, to_db_path
|
||
from core.storage import (
|
||
NoSubtaskError,
|
||
check_no_subtask,
|
||
session_scope,
|
||
)
|
||
from core.storage.models import Message, ScheduledJob, Task, UsageEvent
|
||
from core.storage.utils import ensure_local_task_row
|
||
|
||
from .auth import (
|
||
AuthConfig,
|
||
UserCreateError,
|
||
change_password,
|
||
create_user,
|
||
ensure_user_row,
|
||
get_user_profile,
|
||
make_require_admin,
|
||
make_require_user,
|
||
mint_token,
|
||
resolve_user_by_email,
|
||
)
|
||
from .admin import register_admin_routes
|
||
from .broker import broker
|
||
from .sinks import WebEventSink
|
||
from .static_files import NoCacheStaticFiles
|
||
|
||
|
||
# Task status values; presumably the allowed values of the task-list status filter.
STATUS_FILTERS = ("active", "completed", "abandoned")
STATUS_WRITABLE = ("completed", "abandoned")  # the web client may not switch a task back to active (that goes through the CLI)
# `channel` values of channel-mirror tasks (one permanent read-only conversation per user per
# channel): excluded from the regular task list, fetched separately via /v1/channel_tasks and
# rendered by the frontend as pinned cards. Append new channels here.
CHANNEL_MIRROR_KINDS = ("wechat", "wecom")
# Allowlist of Task columns accepted by `_parse_ordering`.
ORDER_FIELDS = ("created_at", "updated_at", "name", "status")
# Ordering used when `ordering` is empty or contains no valid field ("-" prefix = descending).
ORDER_DEFAULT = "-created_at"
|
||
|
||
# pptx→PDF preview: one asyncio.Lock per resolved absolute pptx path, so concurrent preview
# requests for the same file do not convert it twice (DESIGN §8.3).
_pptx_preview_locks: dict[str, asyncio.Lock] = {}


def _pptx_lock_for(abs_path: str) -> asyncio.Lock:
    """Return the lock guarding conversion of ``abs_path``, creating it on first request.

    Nothing in this helper ever removes an entry from ``_pptx_preview_locks``.
    """
    try:
        return _pptx_preview_locks[abs_path]
    except KeyError:
        fresh = asyncio.Lock()
        _pptx_preview_locks[abs_path] = fresh
        return fresh
|
||
|
||
|
||
# ─────────────────────────── helpers ───────────────────────────
|
||
|
||
def _norm_path(p: str) -> str:
    """Normalize a path for display across OSes: every backslash becomes a forward slash.

    ``None`` or an empty string yields ``""``.
    """
    if not p:
        return ""
    return "/".join(p.split("\\"))
|
||
|
||
|
||
def _iso(dt: Optional[Any]) -> Optional[str]:
    """Return ``dt.isoformat()``, or ``None`` when ``dt`` is ``None`` (or otherwise falsy)."""
    if not dt:
        return None
    return dt.isoformat()
|
||
|
||
|
||
def _outline_snippet(text: Optional[str], maxlen: int = 48) -> str:
    """Turn a user message body into a table-of-contents snippet.

    Takes the first line that is non-blank after stripping, and truncates it to ``maxlen``
    characters. Empty/``None`` input (or input made only of whitespace) yields ``""``.
    """
    if not text:
        return ""
    stripped = (candidate.strip() for candidate in text.splitlines())
    head = next((candidate for candidate in stripped if candidate), None)
    if head is None:
        head = text.strip()
    return head[:maxlen]
|
||
|
||
|
||
def _parse_ordering(s: Optional[str]) -> list:
    """Parse a DRF-style ``ordering`` value into SQLAlchemy ``order_by`` clauses.

    The value is a comma-separated list of field names; a leading ``-`` means descending.
    Only names in ``ORDER_FIELDS`` are honoured and anything else is silently dropped. An
    empty value, or one where every field is invalid, falls back to ``ORDER_DEFAULT``.

    Returns:
        A list of ``Task`` column clauses, ready for ``.order_by(*cols)``.
    """
    def _clauses(spec: str) -> list:
        cols = []
        for part in spec.split(","):
            token = part.strip()
            if not token:
                continue
            descending = token.startswith("-")
            field = token[1:] if descending else token
            if field in ORDER_FIELDS:
                col = getattr(Task, field)
                cols.append(col.desc() if descending else col.asc())
        return cols

    # Derive the fallback from ORDER_DEFAULT itself rather than a hard-coded column, so the
    # "nothing given" and "all fields invalid" paths always agree with the constant.
    return _clauses((s or "").strip()) or _clauses(ORDER_DEFAULT)
|
||
|
||
|
||
def _usage_aggregates(s: Any, tids: list) -> dict:
    """Aggregate usage_events per task_id: real cost, chat tokens and cache hits.

    One GROUP BY query for the whole batch (no N+1), computed on the fly instead of being
    stored in task columns, so it is correct for every historical task without a backfill.

    - ``cost_cny``: summed over every event kind, i.e. the task's real spend.
    - ``tokens_in`` / ``tokens_out`` / ``tokens_cache_hit``: chat events only. All three come
      from usage_events, so ``cache_hit / tokens_in`` cannot exceed 100%. Do not use
      ``tasks.tokens_prompt`` as the denominator: "clear conversation" resets that column but
      not usage_events, and dividing across the two sources can give >100%.

    Args:
        s: an open SQLAlchemy session.
        tids: task ids to aggregate; an empty list short-circuits to ``{}``.

    Returns:
        ``{task_id: {"cost_cny": float, "tokens_in": int, "tokens_out": int,
        "tokens_cache_hit": int}}``. Task ids with no usage events are absent.
    """
    if not tids:
        return {}

    is_chat = UsageEvent.kind == "chat"

    def _chat_unit_sum(unit_key: str):
        # SUM over chat events of one integer field stored in the `units` JSON, NULL -> 0.
        as_int = cast(UsageEvent.units[unit_key].astext, BigInteger)
        return func.coalesce(func.sum(as_int).filter(is_chat), 0)

    stmt = (
        select(
            UsageEvent.task_id,
            func.coalesce(func.sum(UsageEvent.cost_cny), 0),
            _chat_unit_sum("tokens_in"),
            _chat_unit_sum("tokens_out"),
            _chat_unit_sum("cache_hit_tokens"),
        )
        .where(UsageEvent.task_id.in_(tids))
        .group_by(UsageEvent.task_id)
    )

    per_task: dict = {}
    for task_id, cost, t_in, t_out, t_hit in s.execute(stmt).all():
        per_task[task_id] = {
            "cost_cny": float(cost or 0),
            "tokens_in": int(t_in or 0),
            "tokens_out": int(t_out or 0),
            "tokens_cache_hit": int(t_hit or 0),
        }
    return per_task
|
||
|
||
|
||
def _task_dict(
    row: Any,
    *,
    n_messages: Optional[int] = None,
    usage: Optional[dict] = None,
) -> dict:
    """Serialize a Task ORM row into the API's JSON dict.

    ``usage`` (optional) is this task's entry from ``_usage_aggregates``. Every key it carries
    wins: token totals, cache-hit tokens and real cost then come from usage_events. Missing
    keys fall back to the row's own columns (``tokens_prompt``, ``tokens_completion``,
    ``cost_cny``) and to 0 cache-hit tokens; the frontend derives ¥ and the cache hit rate
    from these fields. ``n_messages`` is added to the result only when given.
    """
    agg = usage or {}

    # Token totals prefer the usage_events aggregate: it is the usage source of truth and
    # shares its source with tokens_cache_hit, so the hit rate's denominator matches and stays
    # <= 100%. tasks.tokens_prompt is reset by "clear conversation" and must not be divided
    # into usage_events' cache_hit.
    if "tokens_in" in agg:
        prompt = int(agg["tokens_in"])
    else:
        prompt = row.tokens_prompt or 0
    if "tokens_out" in agg:
        completion = int(agg["tokens_out"])
    else:
        completion = row.tokens_completion or 0

    out = dict(
        task_id=str(row.task_id),
        name=row.name or "",
        description=row.description or "",
        working_dir=_norm_path(row.working_dir or ""),
        status=row.status,
        skill=row.skill or "",
        channel=getattr(row, "channel", None) or "web",
        model=row.model or "",
        model_profile=row.model_profile or "",
        tokens_prompt=prompt,
        tokens_completion=completion,
        tokens=prompt + completion,
        # Chat prefix-cache hit tokens, plus real cost (already discounted for cache hits
        # according to the original note referencing usage.py); fall back to 0 / the column.
        tokens_cache_hit=int(agg.get("tokens_cache_hit", 0)),
        cost_cny=float(agg["cost_cny"]) if "cost_cny" in agg else float(row.cost_cny or 0),
        # Current run state (schema 0004 folded the former runs table into tasks).
        run_status=row.run_status or "idle",
        run_error=row.run_error or None,
        created_at=_iso(getattr(row, "created_at", None)),
        updated_at=_iso(getattr(row, "updated_at", None)),
    )
    if n_messages is not None:
        out["n_messages"] = n_messages
    return out
|
||
|
||
|
||
# ─────────────────────── files helpers ───────────────────────
|
||
|
||
def _load_user_root(user_id: UUID) -> Path:
    """Return the per-user root ``<workspace>/users/<user_id>/``, the boundary of every files API.

    The workspace comes from ``resolve_workspace(None)``. The layout, and (per the original
    note) creating the directory on first access, are delegated to
    ``core.agent_builder.user_root``. Not verified from this module.
    """
    from core.agent_builder import resolve_workspace, user_root

    return user_root(resolve_workspace(None), user_id)
|
||
|
||
|
||
def _safe_join(root: Path, rel: str) -> Path:
    """Resolve an untrusted relative path against ``root`` and make sure it stays inside.

    Rejects absolute / root-anchored paths, NUL bytes, and anything (``../``, symlinks)
    that resolves outside ``root``. Surrounding whitespace in ``rel`` is stripped, and an
    empty result returns the resolved ``root`` itself.

    Returns:
        The fully resolved target path (symlinks followed).

    Raises:
        HTTPException(400): absolute-style path, embedded NUL byte, or escape from ``root``.
    """
    rel = (rel or "").strip()
    if not rel:
        return root.resolve()
    if rel[0] in ("/", "\\"):
        raise HTTPException(400, f"absolute-style path not allowed: {rel!r}")
    if "\x00" in rel:
        # Path.resolve()/stat() raise a bare ValueError on NUL, which would surface as a 500.
        raise HTTPException(400, f"invalid path (NUL byte): {rel!r}")
    if Path(rel).is_absolute():
        raise HTTPException(400, f"absolute path not allowed: {rel!r}")
    root_r = root.resolve()
    target = (root / rel).resolve()
    try:
        target.relative_to(root_r)
    except ValueError:
        # `from None`: the ValueError is an implementation detail, not useful context.
        raise HTTPException(400, f"path escapes user_root: {rel!r}") from None
    return target
|
||
|
||
|
||
def _rel_to(root: Path, target: Path) -> str:
    """POSIX-style path of ``target`` relative to ``root``, both resolved first.

    Returns ``""`` when ``target`` is ``root`` itself or lies outside ``root``.
    """
    try:
        relative = target.resolve().relative_to(root.resolve())
    except ValueError:
        return ""
    posix = relative.as_posix()
    return posix if posix != "." else ""
|
||
|
||
|
||
def _enumerate_files(root: Path, current: Path) -> tuple[list[dict], list[dict], bool]:
    """List the visible entries under ``current`` and build its breadcrumb trail.

    Dotfiles are always hidden (system areas such as ``.memory/`` are never exposed to the
    UI; same convention as ``/v1/folders``). Entries that cannot be stat'ed (vanished,
    dangling symlink, permission denied) are skipped. Directories sort before files, then by
    case-insensitive name.

    Each entry dict has ``name``, ``is_dir``, ``size`` (raw bytes for regular files, else
    ``None``), ``mtime`` (naive local-time ISO string, seconds precision; the frontend
    humanizes it) and ``rel`` (POSIX path relative to ``root``).

    Returns:
        ``(entries, crumbs, exists)``. ``crumbs`` starts with ``{"label": "/", "rel": ""}``
        followed by one crumb per path component of ``current``. ``exists`` says whether
        ``current`` exists; ``entries`` stays empty when it is missing or not a directory.
    """
    import stat as _stat

    entries: list[dict] = []
    exists = current.exists()
    if exists and current.is_dir():
        try:
            children = list(current.iterdir())
        except OSError:
            children = []
        # stat each visible child exactly once; the mode bits then answer is_dir / is_file
        # (and the sort key) without the extra syscalls Path.is_dir()/is_file() each make.
        stated = []
        for p in children:
            if p.name.startswith("."):
                continue
            try:
                stated.append((p, p.stat()))
            except OSError:
                continue
        stated.sort(key=lambda ps: (_stat.S_ISREG(ps[1].st_mode), ps[0].name.lower()))
        for p, st in stated:
            is_file = _stat.S_ISREG(st.st_mode)
            entries.append({
                "name": p.name,
                "is_dir": _stat.S_ISDIR(st.st_mode),
                "size": st.st_size if is_file else None,
                "mtime": _dt.fromtimestamp(st.st_mtime).isoformat(timespec="seconds"),
                "rel": _rel_to(root, p),
            })

    crumbs = [{"label": "/", "rel": ""}]
    # _rel_to already maps root itself to "", so a non-empty value always has real parts.
    cur_rel = _rel_to(root, current)
    if cur_rel:
        acc = ""
        for part in cur_rel.split("/"):
            acc = f"{acc}/{part}" if acc else part
            crumbs.append({"label": part, "rel": acc})
    return entries, crumbs, exists
|
||
|
||
|
||
def _validate_transfer(
    root: Path, paths: list[str], dest_dir: str,
) -> tuple[list[Path], Path]:
    """Pre-validate a batch copy/move; any invalid item aborts the whole batch with no FS change.

    Shared by copy and move (each route adds its own top-level working_dir guard).

    Checks:
    - ``paths`` is non-empty; every source is inside user_root, exists, and is not user_root
    - ``dest_dir`` exists and is a directory (it may be user_root)
    - no source equals ``dest_dir`` (self-move)
    - ``dest_dir`` is not inside a source's subtree (cannot move ``a/`` into ``a/b/``)
    - no source is already a direct child of ``dest_dir`` (in-place move, a no-op)
    - leaf names are unique within the batch (two ``a.txt`` would collide at ``dest/a.txt``)
    - ``dest_dir/<name>`` does not already exist, including as a dangling symlink
      (whole batch gets 409; nothing is ever silently overwritten)

    Returns:
        ``(sources, dest_dir_path)``: the resolved source paths (as produced by
        ``_safe_join``) and the resolved destination directory.

    Raises:
        HTTPException: 400 / 404 / 409 for the conditions above.
    """
    if not paths:
        raise HTTPException(400, "paths is empty")
    dest = _safe_join(root, dest_dir)
    if not dest.exists():
        raise HTTPException(404, f"dest_dir not found: {dest_dir!r}")
    if not dest.is_dir():
        raise HTTPException(400, f"dest_dir is not a directory: {dest_dir!r}")
    root_r = root.resolve()
    dest_r = dest.resolve()
    dest_rel = _rel_to(root, dest)

    sources: list[Path] = []
    seen_names: set[str] = set()
    for p in paths:
        src = _safe_join(root, p)
        if not src.exists():
            raise HTTPException(404, f"source not found: {p!r}")
        src_r = src.resolve()
        if src_r == root_r:
            raise HTTPException(400, "cannot transfer user_root")
        if src_r == dest_r:
            raise HTTPException(400, f"source equals dest_dir: {p!r}")
        # dest inside src's subtree -> would nest a directory into itself
        try:
            dest_r.relative_to(src_r)
        except ValueError:
            pass
        else:
            raise HTTPException(400, f"cannot transfer {p!r} into its own subtree")
        # already a direct child of dest -> no-op
        if src.parent.resolve() == dest_r:
            raise HTTPException(400, f"{p!r} already directly under dest_dir")
        name = src.name
        if name in seen_names:
            raise HTTPException(400, f"duplicate source leaf name in batch: {name!r}")
        seen_names.add(name)
        target = dest / name
        # exists() follows symlinks: a dangling symlink named <name> reports False and would
        # then be replaced by the move/copy, so check the link itself too.
        if target.exists() or target.is_symlink():
            shown = f"{dest_rel}/{name}" if dest_rel else name
            raise HTTPException(409, f"target already exists: {shown!r}")
        sources.append(src)
    return sources, dest
|
||
|
||
|
||
# ─────────────────── BG run + SSE 帧格式 ───────────────────
|
||
|
||
def _run_agent_bg(
    task_id: UUID, user_id: UUID, user_message: str,
    image_variant: str = "", video_variant: str = "",
    scheduled: bool = False,
) -> None:
    """Worker-thread body of one agent run, ending with tasks.run_status written back.

    Steps: ``build_agent(resume=True)`` → attach ``WebEventSink`` → ``agent.run`` → sync
    token counters → write ``tasks.run_status``.

    The sink bridges events back to the asyncio loop through ``broker.emit``. ``agent.run``
    is synchronous, which is why callers run this function via ``asyncio.to_thread``.
    ``user_id`` must be passed through from the JWT side, because it decides which per-user
    subtree the memory block is read from.

    ``cancel_check`` bridges ``broker.is_cancelled``. The agent loop polls it between
    stream chunks and tool calls, so cancel latency is about one chunk interval. Seedance
    polling also reads it for the user's stop button, which is why it must be handed to
    ``build_agent`` itself: the Seedance tool captures it at construction, so it cannot be
    assigned to ``agent.cancel_check`` afterwards.

    A normal finish or an honoured cancel resets ``run_status`` to ``idle`` and leaves no
    persistent marker. Only an error is persisted (``run_status='error'`` + ``run_error``).

    ``image_variant`` / ``video_variant`` select the image/video tool variant for this run
    only (empty → first variant in the yaml). They come with the message POST and are not
    stored in the DB.
    """
    from core.agent_builder import build_agent, sync_task_tokens

    def cancel_check(tid: UUID = task_id):
        return broker.is_cancelled(tid)

    try:
        broker.emit(task_id, {"type": "run_start"})
        agent, _session, _sid, task_state, _task_dir = build_agent(
            session_id=str(task_id), resume=True, user_id=user_id,
            image_variant=image_variant,
            video_variant=video_variant,
            cancel_check=cancel_check,
            scheduled_run=scheduled,
        )
        agent.sink = WebEventSink(broker, task_id)
        agent.run(user_message)
        sync_task_tokens(task_state)
        # Normal completion or an honoured cancel -> back to idle (only errors persist).
        with session_scope() as s:
            s.execute(
                update(Task).where(Task.task_id == task_id).values(
                    run_status="idle", run_error=None,
                )
            )
    except Exception as e:
        err = f"{type(e).__name__}: {e}"
        broker.emit(task_id, {"type": "error", "msg": err})
        try:
            with session_scope() as s:
                s.execute(
                    update(Task).where(Task.task_id == task_id).values(
                        run_status="error", run_error=err,
                    )
                )
        except Exception as db_err:
            # The client already received the error event, so don't raise from the worker
            # thread. Do leave a trace, though: if this write is lost, the row keeps the
            # run_status its caller set before launching (normally 'running'), and the task
            # stays busy until something else resets it (e.g. the startup stale-run reaper).
            print(f"[run] task {task_id}: failed to persist run error "
                  f"({type(db_err).__name__}: {db_err}); original error: {err}")
    finally:
        # Always close the stream, even if clearing the cancel flag fails, so SSE
        # subscribers are released.
        try:
            broker.clear_cancel(task_id)
        finally:
            broker.close(task_id)
|
||
|
||
|
||
def _sse_event(event_type: str, payload: dict) -> bytes:
    """Encode one Server-Sent Events frame as UTF-8 bytes.

    Layout: an ``event: <event_type>`` line, a ``data: <payload>`` line with the payload as
    compact single-line JSON (non-ASCII kept as-is), then the blank-line terminator.
    """
    data = json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
    frame_lines = [f"event: {event_type}", f"data: {data}", "", ""]
    return "\n".join(frame_lines).encode("utf-8")
|
||
|
||
|
||
def _resolve_model_profile(profile: str) -> tuple[str, str]:
    """Validate a ``model_profile`` name and return ``(profile, model_id)``.

    An empty/whitespace ``profile`` means ``cfg["default_model"]``. The profile is loaded with
    ``ModelCapabilities.load`` from ``ROOT / cfg["models_dir"]``. A missing file or a format
    error (``FileNotFoundError`` / ``ValueError``) becomes HTTP 400. Both values are returned
    so ``ensure_local_task_row`` can fill the ``model_profile`` and ``model`` columns together.
    """
    from core.agent_builder import load_config
    from core.capabilities import ModelCapabilities
    from core.paths import ROOT

    cfg = load_config()
    requested = (profile or "").strip()
    chosen = requested if requested else cfg["default_model"]
    try:
        caps = ModelCapabilities.load(chosen, ROOT / cfg["models_dir"])
    except (FileNotFoundError, ValueError) as exc:
        raise HTTPException(400, f"invalid model_profile {chosen!r}: {exc}")
    return chosen, caps.model_id
|
||
|
||
|
||
def _persist_inbound_attachments(tid: UUID, attachments) -> list[str]:
    """Save inbound attachments under ``<working_dir>/inbound/`` of task ``tid``.

    Blocking (one DB read plus file writes); event-loop callers should run it through
    ``asyncio.to_thread``. Attachments with empty ``data`` are skipped. Returns one
    ``"<tag> inbound/<file>"`` line per written file, following the web client's
    pasted-image convention.
    NOTE(review): assumes ``attachment_basename`` yields a safe leaf name (no path
    separators); confirm in core.wechat.ilink.
    """
    from core.wechat.ilink import attachment_basename

    with session_scope() as s:
        wd_db = s.execute(
            select(Task.working_dir).where(Task.task_id == tid)
        ).scalar_one()
    inbound_dir = from_db_path(wd_db) / "inbound"
    inbound_dir.mkdir(parents=True, exist_ok=True)
    ts = _dt.now().strftime("%Y%m%d-%H%M%S")
    lines: list[str] = []
    for i, att in enumerate(attachments):
        if not att.data:
            continue
        name = f"{ts}-{i}-{attachment_basename(att)}"
        (inbound_dir / name).write_bytes(att.data)
        tag = "[用户上传的参考图]" if att.kind == "image" else "[用户上传的文件]"
        lines.append(f"{tag} inbound/{name}")
    return lines


async def _run_channel_conversation(app, uid, text, attachments, *, channel):
    """Channel-agnostic inbound conversation core (§8.7).

    Steps: resolve (or create) this user's permanent task for ``channel`` → save attachments
    → take the task's run lock → run ``_run_agent_bg`` → fetch the assistant's reply text.
    Each channel has its own conversation task, so channels never mix.

    Args:
        app: the FastAPI app; the running agent task is registered in ``app.state.inflight``.
        uid: user id.
        text: inbound message text.
        attachments: already downloaded and decrypted inbound attachments (may be empty;
            wecom currently sends text only).
        channel: ``'wechat'`` (personal WeChat via ClawBot; the binding snapshot holds
            chat_task_id) or ``'wecom'`` (WeCom; the wecom binding row holds chat_task_id).

    Returns:
        The reply text to send back (ClawBot return path / wecom active push). ``""`` when
        no conversation task could be resolved (wechat without a binding). A fixed notice
        string when the task is missing or busy, or when the run ended in error.
    """
    from core.wechat import service as _wx
    from core.wechat.inbound import extract_last_assistant_text

    # Resolve/create the channel's permanent chat task through ensure_channel_chat_task,
    # the same path the push side (send_to_user) uses, so the two never drift.
    # Returns None when wechat has no binding.
    tid = await asyncio.to_thread(_wx.ensure_channel_chat_task, uid, channel)
    if tid is None:
        return ""

    if attachments:
        # Attachment writes can be large; keep them off the event loop so SSE streams of
        # other users are not stalled.
        lines = await asyncio.to_thread(_persist_inbound_attachments, tid, attachments)
        if lines:
            extra = "\n".join(lines)
            text = f"{text}\n\n{extra}" if text.strip() else extra

    # Take the run lock. If busy, ask the user to wait. Runs are serialized per task: the
    # ClawBot loop is serial anyway, and wecom callbacks rely on this to reject concurrency.
    with session_scope() as s:
        row = s.execute(
            select(Task.run_status).where(Task.task_id == tid).with_for_update()
        ).first()
        if row is None:
            return "[出错] 对话 task 不存在"
        if row.run_status in ("running", "cancelling"):
            return "上一条还在处理中,请稍候再发。"
        s.execute(update(Task).where(Task.task_id == tid).values(
            run_status="running", run_error=None))

    broker.start(tid)
    runner = asyncio.create_task(asyncio.to_thread(
        _run_agent_bg, tid, uid, text, "", "", False,
    ))
    # inflight keeps a strong reference (no GC mid-run) and lets shutdown drain the run.
    app.state.inflight[runner] = tid
    runner.add_done_callback(lambda t: app.state.inflight.pop(t, None))
    await runner

    with session_scope() as s:
        st = s.execute(
            select(Task.run_status, Task.run_error).where(Task.task_id == tid)
        ).first()
    if st is not None and st.run_status == "error":
        return f"[出错] {st.run_error}"
    reply = await asyncio.to_thread(extract_last_assistant_text, tid)
    return reply or "(本轮无文本回复)"
|
||
|
||
|
||
def _media_variants_from_config(data: Any, section: str) -> list[tuple[str, dict]]:
    """Extract ``[(variant_key, variant_cfg), ...]`` for ``section`` from parsed doubao.yaml.

    Tolerates a malformed file: if the top level or the section is not a mapping, returns
    ``[]`` instead of raising. Entries whose value is not a dict are skipped.
    """
    if not isinstance(data, dict):
        return []
    section_cfg = data.get(section) or {}
    if not isinstance(section_cfg, dict):
        return []
    return [(k, v) for k, v in section_cfg.items() if isinstance(v, dict)]


def _list_media_variants(section: str) -> list[tuple[str, dict]]:
    """Read ``config/media/doubao.yaml`` and list the variants of one section.

    A missing or unparsable file yields ``[]``. This is metadata only (for the UI
    dropdowns) and does not require ARK_API_KEY. The real tool registration goes through
    ``ArkConfig.load()`` in agent_builder.
    """
    from core.paths import ROOT
    import yaml as _yaml

    p = ROOT / "config" / "media" / "doubao.yaml"
    if not p.exists():
        return []
    try:
        data = _yaml.safe_load(p.read_text(encoding="utf-8")) or {}
    except Exception:
        return []  # unreadable / invalid YAML is treated as "no variants"
    return _media_variants_from_config(data, section)


def _check_media_variant(section: str, field: str, variant: str) -> str:
    """Validate a variant key of ``section``; ``field`` names it in the 400 message.

    Empty/whitespace → ``""`` (agent_builder then falls back to the first variant).
    A non-empty key must exist in the section, otherwise HTTP 400.
    """
    name = (variant or "").strip()
    if not name:
        return ""
    variants = {k for k, _ in _list_media_variants(section)}
    if name not in variants:
        # key=str: YAML may produce non-str keys; a plain sort of mixed types would raise.
        raise HTTPException(
            400, f"invalid {field} {name!r}; available: {sorted(variants, key=str)}"
        )
    return name


def _list_image_variants() -> list[tuple[str, dict]]:
    """List the ``image`` variants of config/media/doubao.yaml (``[]`` if none/unreadable)."""
    return _list_media_variants("image")


def _resolve_image_model(variant: str) -> str:
    """Validate an ``image_model`` variant key: empty → ``""``, unknown → HTTP 400."""
    return _check_media_variant("image", "image_model", variant)


def _list_video_variants() -> list[tuple[str, dict]]:
    """List the ``video`` variants; an empty/commented-out section → ``[]`` (UI hides the dropdown)."""
    return _list_media_variants("video")


def _resolve_video_model(variant: str) -> str:
    """Validate a ``video_model`` variant key: empty → ``""``, unknown → HTTP 400."""
    return _check_media_variant("video", "video_model", variant)
|
||
|
||
|
||
# ────────────────────── Pydantic 请求体 ──────────────────────
|
||
|
||
# Request body for creating a task.
class TaskCreateRequest(BaseModel):
    name: str  # task display name (required; the DB column is NOT NULL)
    working_dir: str = ""  # working-directory name (optional; empty -> `name` is used as the directory name)
    description: str = ""
    skill: str = ""
    model_profile: str = ""  # `family.variant`; empty -> cfg["default_model"]; must exist under config/models/
|
||
|
||
|
||
# Partial-update body for a task; every field is optional (presumably None = leave unchanged).
class TaskPatchRequest(BaseModel):
    status: Optional[str] = None
    description: Optional[str] = None
    name: Optional[str] = None
    skill: Optional[str] = None
    model_profile: Optional[str] = None  # switch model (mode c, task level / granularity A; takes effect from the next send)
|
||
|
||
|
||
# Patch body for a scheduled job. The frontend's read-only view only toggles `enabled`
# (disable/enable); creating and editing jobs goes through the conversation (§8.5), and the
# frontend exposes no edit form.
# NOTE(review): the original comment mentions other fields kept as a fallback for direct
# edits, but `enabled` is the only field defined here; that comment looks stale.
class SchedulePatchRequest(BaseModel):
    enabled: Optional[bool] = None
|
||
|
||
|
||
# Body of a chat message POST.
class MessageRequest(BaseModel):
    content: str
    # Image / video generation model variant key for this message (image/video sections of
    # config/media/doubao.yaml). Empty -> the tool uses the first variant in the yaml;
    # non-empty -> this run is assembled with that variant.
    # Applies to this run only and is not stored in the DB; the UI dropdown choice travels
    # in the message POST body.
    image_model: str = ""
    video_model: str = ""
|
||
|
||
|
||
# Body of a prompt-polishing request.
class OptimizePromptRequest(BaseModel):
    text: str
    # Optionally pass the variant keys currently selected in the UI; the polishing
    # meta-prompt injects the matching model traits, so the LLM knows the downstream tool's
    # preferences and produces a prompt suited to seedream / seedance etc.
    image_model: str = ""
    video_model: str = ""
|
||
|
||
|
||
# File/directory deletion body.
class FileDeleteRequest(BaseModel):
    path: str  # presumably relative to user_root like the other file requests; TODO confirm in the route
    recursive: bool = False  # presumably required to delete a non-empty directory; confirm in the route
|
||
|
||
|
||
# Rename body for a file or directory.
class FileRenameRequest(BaseModel):
    path: str  # the directory / file being renamed, relative to user_root
    new_name: str  # the new leaf name (not a path); must not contain / \ or ..
|
||
|
||
|
||
# Batch copy/move body (validated as a whole by `_validate_transfer`).
class FileTransferRequest(BaseModel):
    paths: list[str]  # multiple sources, all relative to user_root
    dest_dir: str = ""  # destination directory relative to user_root; empty -> user_root
|
||
|
||
|
||
# PLATFORM_KEY -> JWT exchange body (/v1/auth/login).
class LoginRequest(BaseModel):
    user_id: str
    platform_key: str
    # 0016: optional user profile the platform may inject; omitted -> previous behaviour
    # (only user_id is filled), backward compatible with older callers.
    name: Optional[str] = None  # display name / real name
    user_name: Optional[str] = None  # platform account name
|
||
|
||
|
||
# Email + password login body.
class PasswordLoginRequest(BaseModel):
    email: str
    password: str
|
||
|
||
|
||
# Admin "create user" body.
class AdminCreateUserRequest(BaseModel):
    email: str
    password: str
    admin_token: str  # presumably checked against an admin secret by the route (not visible here)
    role: str = "user"  # 'user' / 'admin'; admins can access the /static/admin.html admin console
|
||
|
||
|
||
# Change-password body for the current user.
class ChangePasswordRequest(BaseModel):
    old_password: str
    new_password: str
|
||
|
||
|
||
# ────────────────────── App 工厂 ──────────────────────
|
||
|
||
# web/static directory: served by the /static mount; dev.html lives here too.
_STATIC_DIR = Path(__file__).parent / "static"
|
||
|
||
|
||
def create_app() -> FastAPI:
|
||
# fail-fast:env 缺失直接抛,不裸跑无密
|
||
auth_cfg = AuthConfig.from_env()
|
||
require_user = make_require_user(auth_cfg)
|
||
require_admin = make_require_admin(auth_cfg)
|
||
|
||
@asynccontextmanager
|
||
async def lifespan(app: FastAPI):
|
||
loop = asyncio.get_running_loop()
|
||
broker.bind_loop(loop)
|
||
|
||
# ── 接管默认线程池 executor(§8.4)──────────────────────────────
|
||
# run 走 asyncio.to_thread(用 loop 默认 executor);默认是匿名的,读不到大小、
|
||
# 不可调。显式建一个同尺寸(复刻 Python 默认 min(32, cpu+4))接管,好处:① 监控
|
||
# 能读 max_workers 判断有没有排队 ② 并发不够时改 ZCBOT_RUN_MAX_WORKERS 调大不改码。
|
||
# 注:run 与 disk scan / pptx 转换 / reaper 共享此池(同原默认行为);真要隔离
|
||
# 长任务再另开 run 专用池,那是后话。
|
||
run_max_workers = int(
|
||
os.getenv("ZCBOT_RUN_MAX_WORKERS") or min(32, (os.cpu_count() or 1) + 4)
|
||
)
|
||
run_executor = ThreadPoolExecutor(
|
||
max_workers=run_max_workers, thread_name_prefix="run"
|
||
)
|
||
loop.set_default_executor(run_executor)
|
||
app.state.run_executor = run_executor
|
||
app.state.run_max_workers = run_max_workers
|
||
print(f"[startup] run executor: max_workers={run_max_workers} "
|
||
f"(override via ZCBOT_RUN_MAX_WORKERS)")
|
||
|
||
from core.agent_builder import load_config, resolve_workspace
|
||
_cfg = load_config()
|
||
|
||
# 优雅 drain 状态(SIGTERM / systemctl restart 兜底,见下方 finally):
|
||
# draining 置位后 POST /messages 返 503;inflight 登记在跑的 BG run task,
|
||
# 关停时 await 它们收尾。inflight 同时给 create_task 持强引用,防被 GC 中途回收。
|
||
app.state.draining = asyncio.Event()
|
||
app.state.inflight = {} # dict[asyncio.Task, UUID(task_id)]
|
||
_shutdown_cfg = _cfg.get("shutdown") or {}
|
||
drain_timeout = int(_shutdown_cfg.get("drain_timeout_seconds") or 90)
|
||
cancel_grace = int(_shutdown_cfg.get("cancel_grace_seconds") or 15)
|
||
|
||
# Stale-run reaper:上次进程 crash 留下的 "running" / "cancelling" 已无 BG 线程
|
||
# 继续,启动时标 error,让对应 task 重新可发消息(否则 gate 永挂)。
|
||
# TODO 真生产 multi-worker:换 heartbeat / lease,只 reap 自家 worker 的孤儿。
|
||
with session_scope() as s:
|
||
result = s.execute(
|
||
update(Task)
|
||
.where(Task.run_status.in_(("running", "cancelling")))
|
||
.values(
|
||
run_status="error",
|
||
run_error="server restarted before run finished",
|
||
)
|
||
)
|
||
if result.rowcount:
|
||
print(f"[startup] reaped {result.rowcount} stale active run(s)")
|
||
|
||
# 磁盘配额后台扫描(§7.5 #4 应用层 gate)── 不依赖 docker backend,host
|
||
# backend 也跑(/v1/files/upload 也走配额 gate)。yaml `quotas.disk_scan_interval_seconds`
|
||
# 默 900s = 15min;limit_bytes ≤ 0 视为不限,scan 仍跑(用量统计有用),check 短路放行。
|
||
from core.agent_builder import resolve_workspace
|
||
from core.storage.disk_quota import parse_bytes, scan_all_users
|
||
workspace = resolve_workspace(None, _cfg)
|
||
disk_user_root = workspace / "users"
|
||
quotas_cfg = _cfg.get("quotas") or {}
|
||
disk_scan_interval = int(quotas_cfg.get("disk_scan_interval_seconds") or 900)
|
||
|
||
async def _disk_scanner() -> None:
|
||
loop = asyncio.get_running_loop()
|
||
# 启动时跑一次,后续按 interval。首次扫完 check 才能命中。
|
||
try:
|
||
n = await loop.run_in_executor(None, scan_all_users, disk_user_root)
|
||
if n:
|
||
print(f"[disk_scanner] initial scan: {n} user(s)")
|
||
except Exception as e:
|
||
print(f"[disk_scanner] initial scan error: {type(e).__name__}: {e}")
|
||
while True:
|
||
try:
|
||
await asyncio.sleep(disk_scan_interval)
|
||
n = await loop.run_in_executor(None, scan_all_users, disk_user_root)
|
||
if n:
|
||
print(f"[disk_scanner] scanned {n} user(s)")
|
||
except asyncio.CancelledError:
|
||
raise
|
||
except Exception as e:
|
||
print(f"[disk_scanner] error: {type(e).__name__}: {e}")
|
||
|
||
disk_scanner_task = asyncio.create_task(_disk_scanner(), name="disk-scanner")
|
||
|
||
# ── 并发/线程池监控(§8.4):周期采样,只在有负载/刷新峰值时打,空闲不刷屏 ──
|
||
# active_runs 来自 inflight(已提交未完成的 run,含排队中);逼近 max_workers 即
|
||
# 线程池排队,新 run 的 SSE 会卡着不吐 token。查看:journalctl -u zcbot | grep '\[stats\]'
|
||
def _rss_peak_mb() -> Optional[float]:
|
||
if resource is None:
|
||
return None # Windows dev:降级,不打 rss
|
||
# Linux ru_maxrss 单位 KB,是峰值/high-water(单调不降 —— 看内存涨势够用)
|
||
return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024
|
||
|
||
async def _stats_logger() -> None:
|
||
peak = 0
|
||
while True:
|
||
try:
|
||
await asyncio.sleep(60)
|
||
active = len(app.state.inflight)
|
||
if active > peak:
|
||
peak = active
|
||
warn = " [WARN >= max_workers,已在排队]" if active >= run_max_workers else ""
|
||
print(f"[stats] new peak active_runs={active} "
|
||
f"max_workers={run_max_workers}{warn}")
|
||
if active > 0:
|
||
rss = _rss_peak_mb()
|
||
rss_s = f" rss_peak={rss:.0f}MB" if rss is not None else ""
|
||
print(f"[stats] active_runs={active} "
|
||
f"max_workers={run_max_workers} "
|
||
f"sse_subs={broker.total_subscribers()}{rss_s}")
|
||
except asyncio.CancelledError:
|
||
raise
|
||
except Exception as e:
|
||
print(f"[stats] error: {type(e).__name__}: {e}")
|
||
|
||
stats_logger_task = asyncio.create_task(_stats_logger(), name="stats-logger")
|
||
|
||
# ── 定时任务守护循环(§8.5)── 仿 _disk_scanner 的 plain-asyncio 范式,不引
|
||
# APScheduler/Celery。每 ~10s 认领到点 job(claim+advance next_run 防重复触发),
|
||
# 复用 _run_agent_bg 起 run,跑完确定性兜底投递 + 回写 last_*。间隔只决定最坏延迟
|
||
# (≤1 tick),不决定会不会漏(claim 取 next_run<=now 的全部)。ZCBOT_DISABLE_SCHEDULER=1
|
||
# 整体关掉(对照 Claude Code CLAUDE_CODE_DISABLE_CRON)。
|
||
scheduler_enabled = os.getenv("ZCBOT_DISABLE_SCHEDULER", "").strip() not in ("1", "true", "yes")
|
||
sched_tick = int(os.getenv("ZCBOT_SCHEDULER_TICK_SECONDS", "10") or "10")
|
||
sched_sema = asyncio.Semaphore(int(os.getenv("ZCBOT_SCHEDULER_CONCURRENCY", "4") or "4"))
|
||
|
||
async def _execute_scheduled_job(snap: dict) -> None:
|
||
"""认领后跑一个 job:解析目标 task → 抢 run 锁 → _run_agent_bg → 投递 + 记账。"""
|
||
from core.agent_builder import (
|
||
resolve_workspace, working_dir_from_name, validate_task_name, InvalidTaskName,
|
||
)
|
||
from core.scheduler import build_run_message, deliver_notify, record_result
|
||
from core.storage.utils import ensure_local_task_row
|
||
|
||
job_id = snap["job_id"]
|
||
uid = snap["user_id"]
|
||
async with sched_sema:
|
||
try:
|
||
profile, model_id = _resolve_model_profile(snap.get("model_profile") or "")
|
||
ws = resolve_workspace(None, _cfg)
|
||
# 目标 task:persistent 用绑定 task(缺则新建并回填);isolated 用稳定 per-job 目录
|
||
tid: Optional[UUID] = None
|
||
if snap["mode"] == "persistent" and snap.get("bound_task_id"):
|
||
tid = snap["bound_task_id"]
|
||
# 绑定 task 可能已被删(SET NULL 已处理 None;这里再查实在性)
|
||
with session_scope() as s:
|
||
exists = s.execute(
|
||
select(Task.task_id).where(
|
||
Task.task_id == tid, Task.deleted_at.is_(None)
|
||
)
|
||
).first()
|
||
if exists is None:
|
||
tid = None
|
||
if tid is None:
|
||
tid = uuid4()
|
||
wd_name = f"scheduled-{str(job_id)[:8]}"
|
||
fs_dir = working_dir_from_name(ws, uid, wd_name)
|
||
fs_dir.mkdir(parents=True, exist_ok=True)
|
||
disp = f"{snap['name']}"
|
||
try:
|
||
disp = validate_task_name(disp)
|
||
except InvalidTaskName:
|
||
disp = wd_name # 名字含非法字符 → 退到安全名
|
||
ensure_local_task_row(
|
||
task_id=tid, name=disp, working_dir=to_db_path(fs_dir),
|
||
skill=snap.get("skill") or "", user_id=uid,
|
||
model=model_id, model_profile=profile,
|
||
description="(定时任务自动创建)",
|
||
scheduled_job_id=job_id,
|
||
)
|
||
if snap["mode"] == "persistent":
|
||
with session_scope() as s:
|
||
s.execute(update(ScheduledJob).where(
|
||
ScheduledJob.job_id == job_id
|
||
).values(bound_task_id=tid))
|
||
|
||
# 抢 run 锁(同 post_message):busy → 本次跳过,下个 cron 点再来
|
||
with session_scope() as s:
|
||
row = s.execute(
|
||
select(Task.run_status).where(Task.task_id == tid).with_for_update()
|
||
).first()
|
||
if row is None:
|
||
record_result(job_id, status="error", task_id=tid, error="目标 task 不存在")
|
||
return
|
||
if row.run_status in ("running", "cancelling"):
|
||
record_result(job_id, status="skipped", task_id=tid,
|
||
error="目标 task 正忙,本次跳过")
|
||
print(f"[scheduler] job {str(job_id)[:8]} skipped (task busy)")
|
||
return
|
||
s.execute(update(Task).where(Task.task_id == tid).values(
|
||
run_status="running", run_error=None))
|
||
|
||
message = build_run_message(snap)
|
||
broker.start(tid)
|
||
runner = asyncio.create_task(asyncio.to_thread(
|
||
_run_agent_bg, tid, uid, message, "", "", True,
|
||
))
|
||
app.state.inflight[runner] = tid
|
||
runner.add_done_callback(lambda t: app.state.inflight.pop(t, None))
|
||
|
||
timeout = int(snap.get("timeout_seconds") or 0)
|
||
if timeout > 0:
|
||
done, _pending = await asyncio.wait({runner}, timeout=timeout)
|
||
if not done:
|
||
broker.request_cancel(tid) # 协作式停;loop 在 chunk 间 poll 到即退
|
||
print(f"[scheduler] job {str(job_id)[:8]} timed out ({timeout}s), cancelling")
|
||
await runner
|
||
else:
|
||
await runner
|
||
|
||
# run 终态:_run_agent_bg 收尾把 run_status 写回 idle(ok)/error
|
||
with session_scope() as s:
|
||
st = s.execute(
|
||
select(Task.run_status, Task.run_error).where(Task.task_id == tid)
|
||
).first()
|
||
if st is not None and st.run_status == "error":
|
||
record_result(job_id, status="error", task_id=tid, error=st.run_error)
|
||
print(f"[scheduler] job {str(job_id)[:8]} run error: {st.run_error}")
|
||
return
|
||
|
||
# 第 3 层确定性兜底投递(notify);失败不影响 run 已成功这一事实
|
||
if snap.get("notify"):
|
||
try:
|
||
with session_scope() as s:
|
||
wd_db = s.execute(
|
||
select(Task.working_dir).where(Task.task_id == tid)
|
||
).scalar_one_or_none()
|
||
fs_dir = from_db_path(wd_db) if wd_db else ws
|
||
await asyncio.get_running_loop().run_in_executor(
|
||
None, lambda: deliver_notify(
|
||
snap["notify"], job_name=snap["name"],
|
||
working_dir=fs_dir, tz=snap["tz"],
|
||
user_id=snap["user_id"],
|
||
)
|
||
)
|
||
except Exception as e:
|
||
print(f"[scheduler] job {str(job_id)[:8]} notify failed: {type(e).__name__}: {e}")
|
||
record_result(job_id, status="ok", task_id=tid)
|
||
print(f"[scheduler] job {str(job_id)[:8]} '{snap['name']}' done")
|
||
except Exception as e:
|
||
print(f"[scheduler] job {str(job_id)[:8]} crashed: {type(e).__name__}: {e}")
|
||
try:
|
||
record_result(job_id, status="error", task_id=None, error=f"{type(e).__name__}: {e}")
|
||
except Exception:
|
||
pass
|
||
|
||
async def _scheduler_loop() -> None:
|
||
from core.scheduler import claim_due_jobs
|
||
loop = asyncio.get_running_loop()
|
||
while True:
|
||
try:
|
||
await asyncio.sleep(sched_tick)
|
||
if getattr(app.state, "draining", None) is not None and app.state.draining.is_set():
|
||
continue # 关停 drain 期不起新 job
|
||
due = await loop.run_in_executor(None, claim_due_jobs)
|
||
for snap in due:
|
||
asyncio.create_task(_execute_scheduled_job(snap))
|
||
if due:
|
||
print(f"[scheduler] fired {len(due)} job(s)")
|
||
except asyncio.CancelledError:
|
||
raise
|
||
except Exception as e:
|
||
print(f"[scheduler] loop error: {type(e).__name__}: {e}")
|
||
|
||
scheduler_task = asyncio.create_task(_scheduler_loop(), name="scheduler") if scheduler_enabled else None
|
||
if scheduler_enabled:
|
||
print(f"[scheduler] enabled (tick={sched_tick}s)")
|
||
|
||
# ── 微信(ClawBot)入站长轮询管理器(§8.7)── 仅当 ZCBOT_WECHAT_BOT_ENABLED 在。
|
||
# 每个 active 绑定一条 getupdates 长轮询;收到消息 → 跑用户常驻「微信」task → 回复发回。
|
||
from core.wechat.service import clawbot_enabled
|
||
wechat_stop = asyncio.Event()
|
||
wechat_task = None
|
||
|
||
async def _run_wechat_message(uid: UUID, text: str, attachments=None) -> str:
|
||
"""微信(ClawBot)入站一条消息 → 跑用户常驻「微信」task → 取回复。
|
||
|
||
attachments:已下载解密的入站附件(core.wechat.ilink.InboundAttachment,att.data 已回填)。
|
||
建/复用 task、落盘附件、抢 run 锁、跑 agent 全在渠道无关核心
|
||
`_run_channel_conversation` 里(企业微信回调走同一核心,channel='wecom')。
|
||
"""
|
||
return await _run_channel_conversation(app, uid, text, attachments, channel="wechat")
|
||
|
||
if clawbot_enabled():
|
||
from core.wechat.inbound import run_inbound_manager
|
||
wechat_task = asyncio.create_task(
|
||
run_inbound_manager(_run_wechat_message, wechat_stop),
|
||
name="wechat-inbound",
|
||
)
|
||
print("[wechat] ClawBot inbound manager enabled")
|
||
|
||
# Sandbox pool(§7.5):仅当 ZCBOT_SANDBOX_BACKEND=docker 时启用。
|
||
# 启动钩子:① init_pool(创建 docker network + pool 实例)② shutdown_all 清
|
||
# 前驱孤儿(上次进程留下的 zcbot-sandbox-* 容器,内存 _last_active 为空,
|
||
# 全清重启)③ 后台 reaper task,每 60s 跑 reap_idle。
|
||
sandbox_backend = os.getenv("ZCBOT_SANDBOX_BACKEND", "host").lower()
|
||
sandbox_reaper_task = None
|
||
if sandbox_backend == "docker":
|
||
from core.paths import ROOT
|
||
from core.sandbox import init_pool
|
||
from core.sandbox.check import detect_fs_quota
|
||
workspace = resolve_workspace(None, _cfg)
|
||
user_root_base = workspace / "users"
|
||
# §7.5 #4 fs quota 探测:不阻塞启动(应用层周期扫描已有),仅打 WARN
|
||
# 提醒外部用户开放前必须升级到 xfs prjquota / ext4 project / zfs。
|
||
try:
|
||
level, msg = detect_fs_quota(user_root_base.resolve())
|
||
print(f"[startup] {'[ok]' if level == 'ok' else '[warn]'} {msg}")
|
||
except Exception as e:
|
||
print(f"[startup] [warn] fs quota detect failed: {type(e).__name__}: {e}")
|
||
try:
|
||
# repo_root=ROOT 让 SandboxPool 把 <repo>/skills 只读 mount 进容器
|
||
# (fs 工具进容器后 read SKILL references 需要)
|
||
# sandbox_cfg=yaml `sandbox` 段(memory/cpus/pids_limit 可调)
|
||
pool = init_pool(
|
||
user_root_base, repo_root=ROOT,
|
||
sandbox_cfg=_cfg.get("sandbox") or {},
|
||
)
|
||
removed = pool.shutdown_all()
|
||
if removed:
|
||
print(f"[startup] swept {len(removed)} stale sandbox container(s)")
|
||
|
||
async def _reaper() -> None:
|
||
loop = asyncio.get_running_loop()
|
||
while True:
|
||
try:
|
||
await asyncio.sleep(60)
|
||
removed = await loop.run_in_executor(None, pool.reap_idle)
|
||
if removed:
|
||
print(f"[reaper] reaped {len(removed)} idle sandbox container(s)")
|
||
except asyncio.CancelledError:
|
||
raise
|
||
except Exception as e:
|
||
print(f"[reaper] error: {type(e).__name__}: {e}")
|
||
|
||
sandbox_reaper_task = asyncio.create_task(_reaper(), name="sandbox-reaper")
|
||
app.state.sandbox_pool = pool
|
||
except Exception as e:
|
||
# ensure_network / docker CLI 不可用 → fail-fast。Stage C 协议:任一
|
||
# hardening 缺失视为部署未完成,不退化到 host(否则误以为有沙盒实则在裸跑)。
|
||
raise RuntimeError(
|
||
f"ZCBOT_SANDBOX_BACKEND=docker but sandbox init failed: {e}"
|
||
)
|
||
try:
|
||
yield
|
||
finally:
|
||
# ── 优雅 drain:先拒新 run,等在跑的 run 收尾,超时转协作式 cancel ──
|
||
# 单实例形态下消除"restart 误杀 in-flight run 标 error"。新 POST /messages
|
||
# 期间返 503(客户端退避重试覆盖)。drain_timeout 内自然跑完 → idle 零 error;
|
||
# 超时的 broker.request_cancel → 下个 chunk 间隙退(标 idle);cancel_grace 后仍
|
||
# 没退的留给 systemd SIGKILL,下次启动 reaper 标 error(最坏退化 = 改前行为)。
|
||
# ★ systemd TimeoutStopSec 必须 > drain_timeout + cancel_grace + 余量(见 RUN.md)。
|
||
app.state.draining.set()
|
||
inflight = app.state.inflight
|
||
if inflight:
|
||
print(f"[shutdown] draining {len(inflight)} in-flight run(s), "
|
||
f"timeout={drain_timeout}s")
|
||
_, pending = await asyncio.wait(
|
||
list(inflight.keys()), timeout=drain_timeout
|
||
)
|
||
if pending:
|
||
print(f"[shutdown] {len(pending)} run(s) over drain timeout; "
|
||
f"signalling cooperative cancel")
|
||
for t in pending:
|
||
cid = inflight.get(t)
|
||
if cid is not None:
|
||
broker.request_cancel(cid)
|
||
_, still = await asyncio.wait(pending, timeout=cancel_grace)
|
||
if still:
|
||
print(f"[shutdown] {len(still)} run(s) still active after "
|
||
f"cancel grace; SIGKILL takes over, next start reaps them")
|
||
|
||
disk_scanner_task.cancel()
|
||
try:
|
||
await disk_scanner_task
|
||
except (asyncio.CancelledError, Exception):
|
||
pass
|
||
stats_logger_task.cancel()
|
||
try:
|
||
await stats_logger_task
|
||
except (asyncio.CancelledError, Exception):
|
||
pass
|
||
if scheduler_task is not None:
|
||
scheduler_task.cancel()
|
||
try:
|
||
await scheduler_task
|
||
except (asyncio.CancelledError, Exception):
|
||
pass
|
||
if wechat_task is not None:
|
||
wechat_stop.set()
|
||
wechat_task.cancel()
|
||
try:
|
||
await wechat_task
|
||
except (asyncio.CancelledError, Exception):
|
||
pass
|
||
if sandbox_reaper_task is not None:
|
||
sandbox_reaper_task.cancel()
|
||
try:
|
||
await sandbox_reaper_task
|
||
except (asyncio.CancelledError, Exception):
|
||
pass
|
||
if sandbox_backend == "docker":
|
||
pool = getattr(app.state, "sandbox_pool", None)
|
||
if pool is not None:
|
||
try:
|
||
pool.shutdown_all()
|
||
except Exception as e:
|
||
print(f"[shutdown] sandbox shutdown_all error: {type(e).__name__}: {e}")
|
||
|
||
# drain 已 await inflight 收尾、run 线程退完;非阻塞关池(进程在退出,保守清理)
|
||
run_executor.shutdown(wait=False)
|
||
|
||
# Application object. `lifespan` (defined earlier in this module) owns the startup /
# shutdown background tasks and the graceful drain on exit.
app = FastAPI(
    title="zcbot api",
    description=(
        "zcbot 后端 — /v1 JSON API + SSE。Auth: PLATFORM_KEY → JWT(§7 D' 过渡)。"
        "本地 dev SPA: /static/dev.html。"
    ),
    version=__version__,
    lifespan=lifespan,
)

# Permissive CORS for local development; restrict allow_origins to the platform
# domain(s) for a real deployment. Credentials are not allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=False,
    allow_origins=["*"],  # local only; tighten per platform domain when deploying
)

if _STATIC_DIR.is_dir():
    # On Windows, mimetypes occasionally maps .js to text/plain, which makes browsers
    # refuse to execute <script type="module">. Register the type explicitly so static
    # ES modules are always served with the correct MIME type.
    mimetypes.add_type("text/javascript", ".js")
    app.mount("/static", NoCacheStaticFiles(directory=str(_STATIC_DIR)), name="static")
|
||
|
||
# ───────────── Misc ─────────────
|
||
|
||
@app.get("/", include_in_schema=False)
def root():
    """Redirect the bare origin to the local dev SPA (Swagger UI stays at /docs)."""
    dev_spa = "/static/dev.html"
    return RedirectResponse(url=dev_spa, status_code=302)
|
||
|
||
@app.get("/healthz", tags=["misc"])
def healthz():
    """Liveness probe: confirm the process is serving and report the running version."""
    payload = {"status": "ok"}
    payload["version"] = __version__
    return payload
|
||
|
||
@app.get("/WW_verify_{token}.txt", include_in_schema=False)
def wecom_domain_verify(token: str):
    """Serve a WeCom trusted-domain ownership verification file (public, no login).

    Drop the ``WW_verify_<token>.txt`` file downloaded from the WeCom console into the
    directory named by ``ZCBOT_WECOM_VERIFY_DIR`` (defaults to the repository root);
    this route serves it from the domain root under its original file name.

    Raises:
        HTTPException(404): the token is not alphanumeric, or the file does not exist.
    """
    from fastapi.responses import PlainTextResponse

    from core.paths import ROOT

    # Only alphanumeric tokens are accepted: that excludes path separators and dot
    # segments from the file name built below (path traversal guard).
    if token.isalnum():
        configured_dir = os.getenv("ZCBOT_WECOM_VERIFY_DIR", "").strip()
        verify_dir = Path(configured_dir or str(ROOT))
        target = verify_dir / f"WW_verify_{token}.txt"
        if target.is_file():
            return PlainTextResponse(target.read_text(encoding="utf-8"))
        raise HTTPException(404, "verify file not found")
    raise HTTPException(404, "not found")
|
||
|
||
@app.get("/v1/me", tags=["misc"])
def me(user_id: UUID = Depends(require_user)):
    """Return the caller's identity (JWT -> user_id -> role looked up in the DB).

    The dev SPA calls this once when it restores a session from the token kept in
    localStorage, and shows the admin entry (/static/admin.html) when role == 'admin'.
    Because role is read from the DB on every call, role changes take effect at once.
    ``name`` / ``user_name`` are the profile fields injected by platform login; they may
    be null (email/password accounts, older rows).
    """
    profile = get_user_profile(user_id) or {}
    result = {"user_id": str(user_id), "role": profile.get("role", "user")}
    for field in ("name", "user_name", "email"):
        result[field] = profile.get(field)
    return result
|
||
|
||
# ───────────── 微信接入(ClawBot,§8.7)─────────────
|
||
|
||
@app.post("/v1/wechat/bind/qrcode", tags=["wechat"])
async def wechat_bind_qrcode(user_id: UUID = Depends(require_user)):
    """Start a ClawBot binding QR code and return it as a PNG data URI.

    The frontend shows the image for the user to scan with mobile WeChat. The code lives
    for roughly a minute; when status polling reports ``expired`` the frontend calls this
    endpoint again for a fresh one. ``user_id`` is only required to enforce login.
    """
    import base64 as _b64
    import io as _io

    import segno

    from core.wechat import ilink

    qr = await asyncio.to_thread(ilink.get_bot_qrcode)
    png_buf = _io.BytesIO()
    symbol = segno.make(qr.deeplink, error="m")
    symbol.save(png_buf, kind="png", scale=6, border=3)
    encoded = _b64.b64encode(png_buf.getvalue()).decode()
    return {"qrcode_id": qr.qrcode_id, "qr_png": f"data:image/png;base64,{encoded}"}
|
||
|
||
@app.get("/v1/wechat/bind/status", tags=["wechat"])
async def wechat_bind_status(qrcode_id: str, user_id: UUID = Depends(require_user)):
    """Poll a binding QR code's scan state and store the binding once it is confirmed.

    The upstream poll is a server-side long poll that may hold for tens of seconds, so it
    runs in a worker thread. Returns ``{"status": "wait" | "confirmed" | "expired"}``; on
    ``expired`` the frontend requests a new QR code.
    """
    from core.wechat import ilink
    from core.wechat import service as _wx

    res = await asyncio.to_thread(ilink.poll_qrcode_status, qrcode_id)
    if res.status == "confirmed" and res.bot_token:
        base_url = res.base_url or ilink.DEFAULT_BASE
        await asyncio.to_thread(_wx.upsert_clawbot_binding, user_id, res.bot_token, base_url)
    return {"status": res.status}
|
||
|
||
@app.get("/v1/wechat/bind", tags=["wechat"])
def wechat_bind_get(user_id: UUID = Depends(require_user)):
    """Report the caller's WeChat (ClawBot) binding state without exposing the bot token.

    ``can_push`` is true while the context token is still fresh (the 24h window during
    which proactive pushes are possible); ``last_active`` is derived from
    ``context_token_at``.
    """
    from core.wechat import service as _wx

    snap = _wx.get_binding(user_id)
    is_active = snap is not None and snap.status == "active"
    if not is_active:
        return {"bound": False}
    return dict(
        bound=True,
        user_im_id=snap.user_im_id,
        can_push=bool(_wx._token_fresh(snap)),
        last_active=_iso(snap.context_token_at),
    )
|
||
|
||
@app.delete("/v1/wechat/bind", status_code=204, tags=["wechat"])
def wechat_unbind(user_id: UUID = Depends(require_user)):
    """Remove the caller's WeChat (ClawBot) binding; answers 204 with no body."""
    from core.wechat import service as _wx

    _wx.unbind(user_id)
    return None
|
||
|
||
@app.post("/v1/wechat/test", tags=["wechat"])
async def wechat_test(user_id: UUID = Depends(require_user)):
    """Self-test: push a test message to the bound user.

    This only succeeds if the user has spoken to the bot in WeChat within the last 24h.
    """
    from core.wechat import service as _wx

    text = "zcbot 测试消息:绑定成功,这条来自你的 zcbot。"
    outcome = await asyncio.to_thread(_wx.push_clawbot, user_id, text)
    return {"ok": outcome.ok, "reason": outcome.reason}
|
||
|
||
# ───────────── 企业微信接入(渠道 B,纯推送,§8.7)─────────────
|
||
|
||
@app.get("/v1/wecom/oauth/url", tags=["wecom"])
def wecom_oauth_url(request: Request, user_id: UUID = Depends(require_user)):
    """Build the WeCom OAuth web-authorization URL for the caller.

    The frontend opens the returned URL so the user can authorize by scanning. The
    caller's identity travels to the callback inside the signed ``state``. The redirect
    host must be one of the app's trusted web-authorization domains. It comes from
    ``ZCBOT_PUBLIC_BASE_URL`` when set, otherwise from the request's base URL.

    Raises:
        HTTPException(400): WeCom credentials are not configured.
    """
    from core.wechat import wecom

    if not wecom.wecom_configured():
        raise HTTPException(400, "企业微信未配置(需 WECOM_CORPID/AGENTID/SECRET)")
    public_base = os.getenv("ZCBOT_PUBLIC_BASE_URL", "").strip()
    base = public_base if public_base else str(request.base_url).rstrip("/")
    callback_url = base + "/v1/wecom/oauth/callback"
    signed_state = wecom.sign_state(str(user_id))
    return {"authorize_url": wecom.oauth_authorize_url(callback_url, signed_state)}
|
||
|
||
@app.get("/v1/wecom/oauth/callback", include_in_schema=False)
async def wecom_oauth_callback(code: str = "", state: str = ""):
    """Browser landing point after WeCom OAuth authorization.

    There is no JWT here; identity comes from the signed ``state``. The handler verifies
    the state, exchanges ``code`` for the WeCom userid, stores the binding, and renders a
    small HTML result page.

    The user is looking at this response in a browser tab. Every failure is therefore
    reported on that page instead of surfacing as a raw JSON 500, including a failed
    binding write or a malformed identity in the state.
    """
    import html as _html

    from fastapi.responses import HTMLResponse

    from core.wechat import service as _wx
    from core.wechat import wecom

    def _page(msg: str, ok: bool) -> HTMLResponse:
        # Render the one-line result page; msg is escaped so text derived from
        # exceptions can never inject markup.
        color = "#1a7f37" if ok else "#cf222e"
        return HTMLResponse(
            f"<!doctype html><meta charset=utf-8><meta name=viewport content='width=device-width,initial-scale=1'>"
            f"<div style='font:16px system-ui;max-width:420px;margin:80px auto;text-align:center;color:{color}'>"
            f"{_html.escape(msg)}</div><div style='text-align:center;color:#888;font:13px system-ui'>可关闭本页返回 zcbot</div>"
        )

    uid = wecom.verify_state(state)
    if not uid:
        return _page("绑定失败:授权已过期或无效,请回 zcbot 重试。", False)
    if not code:
        return _page("绑定失败:未拿到授权码。", False)
    try:
        wecom_userid = await asyncio.to_thread(wecom.get_user_id, code)
    except Exception as e:  # noqa: BLE001
        return _page(f"绑定失败:{type(e).__name__}", False)
    if not wecom_userid:
        return _page("绑定失败:你不是该企业成员(只支持企业内成员)。", False)
    try:
        # UUID() is inside the try too: a state payload that is not a UUID string
        # would otherwise raise ValueError and produce a 500.
        await asyncio.to_thread(_wx.upsert_wecom_binding, UUID(uid), wecom_userid)
    except Exception as e:  # noqa: BLE001
        print(f"[wecom] oauth bind write failed: {type(e).__name__}: {e}")
        return _page(f"绑定失败:{type(e).__name__}", False)
    return _page("✅ 企业微信绑定成功!以后简报 / 结果会推到你的企业微信。", True)
|
||
|
||
@app.get("/v1/wecom/bind", tags=["wecom"])
def wecom_bind_get(user_id: UUID = Depends(require_user)):
    """Report whether WeCom is configured, and the caller's WeCom binding (if any)."""
    from core.wechat import service as _wx
    from core.wechat import wecom

    bound_userid = _wx.get_wecom_userid(user_id)
    return dict(
        configured=wecom.wecom_configured(),
        bound=bool(bound_userid),
        wecom_userid=bound_userid,
    )
|
||
|
||
@app.put("/v1/wecom/bind/userid", tags=["wecom"])
def wecom_bind_userid(payload: dict, user_id: UUID = Depends(require_user)):
    """Bind a manually entered WeCom member userid.

    Used when there is no HTTPS domain or OAuth is not in use. The userid is shown in the
    WeCom admin console under Contacts -> member -> "Account".

    NOTE(review): the userid is accepted on trust and is not verified against the caller.
    Presumably pushes are then delivered to whichever member that userid names.

    Raises:
        HTTPException(400): one of the following:
            - WeCom is not configured;
            - ``wecom_userid`` is not a string;
            - ``wecom_userid`` is empty after trimming.
    """
    from core.wechat import service as _wx
    from core.wechat import wecom

    if not wecom.wecom_configured():
        raise HTTPException(400, "企业微信未配置(需 WECOM_CORPID/AGENTID/SECRET)")
    raw = payload.get("wecom_userid")
    # Reject non-string JSON values explicitly; previously e.g. a number crashed on
    # .strip() and surfaced as a 500.
    if raw is not None and not isinstance(raw, str):
        raise HTTPException(400, "wecom_userid 必须是字符串")
    wuid = (raw or "").strip()
    if not wuid:
        raise HTTPException(400, "wecom_userid 不能为空")
    _wx.upsert_wecom_binding(user_id, wuid)
    return {"bound": True, "wecom_userid": wuid}
|
||
|
||
@app.delete("/v1/wecom/bind", status_code=204, tags=["wecom"])
def wecom_unbind(user_id: UUID = Depends(require_user)):
    """Remove the caller's WeCom binding; answers 204 with no body."""
    from core.wechat import service as _wx

    _wx.unbind_wecom(user_id)
    return None
|
||
|
||
@app.post("/v1/wecom/test", tags=["wecom"])
async def wecom_test(user_id: UUID = Depends(require_user)):
    """Self-test: push a WeCom test message to the bound user (no 24h window applies)."""
    from core.wechat import service as _wx

    text = "zcbot 测试消息:企业微信绑定成功,这条来自你的 zcbot。"
    outcome = await asyncio.to_thread(_wx.push_wecom, user_id, text)
    return {"ok": outcome.ok, "reason": outcome.reason}
|
||
|
||
# ── WeCom "receive messages" callback (inbound conversation, §8.7) ── no JWT; the
# sender's identity is resolved from FromUserName inside the encrypted XML.
# Setup: in the WeCom console, "App -> Receive messages -> Set API receiving", enter this
# URL plus Token + EncodingAESKey, matching env WECOM_CALLBACK_TOKEN / WECOM_CALLBACK_AESKEY.
# Callback URL = <public base>/v1/wecom/callback.
@app.get("/v1/wecom/callback", include_in_schema=False)
def wecom_callback_verify(
    msg_signature: str = "", timestamp: str = "", nonce: str = "", echostr: str = ""
):
    """Answer the URL-validation GET that WeCom sends when callback settings are saved.

    Checks the signature, decrypts ``echostr`` and echoes the plaintext back verbatim.

    Raises:
        HTTPException(404): the callback Token / AESKey are not configured.
        HTTPException(400): signature verification or decryption failed.
    """
    from fastapi.responses import PlainTextResponse

    from core.wechat import wecom, wecom_crypto

    if not wecom_crypto.callback_configured():
        raise HTTPException(404, "wecom callback 未配置(需 WECOM_CALLBACK_TOKEN/AESKEY)")
    try:
        echoed = wecom_crypto.verify_url(
            msg_signature, timestamp, nonce, echostr, corpid=wecom._corpid()
        )
    except Exception as e:  # noqa: BLE001
        raise HTTPException(400, f"verify failed: {type(e).__name__}: {e}")
    else:
        return PlainTextResponse(echoed)
|
||
|
||
@app.post("/v1/wecom/callback", include_in_schema=False)
async def wecom_callback(
    request: Request, msg_signature: str = "", timestamp: str = "", nonce: str = ""
):
    """Receive a WeCom inbound message (encrypted XML POST) and answer it asynchronously.

    The handler decrypts the payload, resolves the sender via FromUserName, and
    acknowledges with ``success`` immediately. The agent run happens in a background
    task, and its reply is pushed back proactively via ``_wx.push_wecom``
    (message/send has no 24h window).

    An agent run takes far longer than the ~5s passive-reply window, and a late
    acknowledgement makes WeCom redeliver. For that reason, everything slow happens in
    the background after the acknowledgement, including downloading image/file media.

    Concurrent or duplicate deliveries for the same user are blocked by the conversation
    task's run lock; the second delivery is told the previous message is still being
    processed.

    Only text, image and file messages are handled. Other types (voice, video, location,
    link, events) and messages from unbound senders are acknowledged and ignored.

    Raises:
        HTTPException(404): the callback Token / AESKey are not configured.
        HTTPException(400): the body is not valid UTF-8, or signature/decryption failed.
    """
    from fastapi.responses import PlainTextResponse

    from core.wechat import service as _wx
    from core.wechat import wecom, wecom_crypto
    from core.wechat.ilink import InboundAttachment

    if not wecom_crypto.callback_configured():
        raise HTTPException(404, "wecom callback 未配置(需 WECOM_CALLBACK_TOKEN/AESKEY)")
    raw_body = await request.body()
    try:
        # Decode inside the try: a non-UTF-8 body is a malformed request (400),
        # not an unhandled server error (500).
        body = raw_body.decode("utf-8")
        msg = wecom_crypto.decrypt_message(
            msg_signature, timestamp, nonce, body, corpid=wecom._corpid()
        )
    except Exception as e:  # noqa: BLE001
        raise HTTPException(400, f"decrypt failed: {type(e).__name__}: {e}")

    msgtype = msg.get("MsgType") or ""
    wuid = msg.get("FromUserName") or ""
    uid = await asyncio.to_thread(_wx.get_user_by_wecom_userid, wuid)
    if uid is None:
        return PlainTextResponse("success")  # unbound sender -> ignore silently

    # Text uses Content. Image/file carry a MediaId that is downloaded later, in the
    # background, into an InboundAttachment. It has the same structure as personal
    # WeChat; only kind/file_name/data are used by _run_channel_conversation.
    content = ""
    media_id = ""
    if msgtype == "text":
        content = (msg.get("Content") or "").strip()
    elif msgtype in ("image", "file"):
        media_id = msg.get("MediaId") or ""
    else:
        return PlainTextResponse("success")  # unsupported type; ack to stop retries
    if not content and not media_id:
        return PlainTextResponse("success")  # empty text / media without a MediaId
    file_name_hint = msg.get("FileName") or ""
    att_kind = "image" if msgtype == "image" else "file"

    async def _fetch_attachments() -> list:
        # Download this message's media (if any); a failed download yields [].
        if not media_id:
            return []
        try:
            data, fname = await asyncio.to_thread(wecom.download_media, media_id)
        except Exception as e:  # noqa: BLE001
            print(f"[wecom] {wuid} download {msgtype} err: {type(e).__name__}: {e}")
            return []
        return [InboundAttachment(
            kind=att_kind,
            media={},
            file_name=(file_name_hint or fname or ""),
            data=data,
        )]

    async def _bg() -> None:
        # Download media, run the conversation, and push the reply; never raises.
        attachments = await _fetch_attachments()
        if not content and not attachments:
            return  # the media download failed -> nothing to run
        try:
            reply = await _run_channel_conversation(
                app, uid, content, attachments, channel="wecom")
        except Exception as e:  # noqa: BLE001
            reply = f"[出错] {type(e).__name__}: {e}"
        if reply and reply.strip():
            try:
                await asyncio.to_thread(_wx.push_wecom, uid, reply)
            except Exception as e:  # noqa: BLE001
                print(f"[wecom] {wuid} push reply err: {type(e).__name__}: {e}")

    # Register in inflight for two reasons. The strong reference stops the task from
    # being garbage-collected mid-run, and shutdown can drain it. The value is None, so
    # the entry is excluded from broker cancel; the inner _run_agent_bg runner has its
    # own inflight entry that handles cancellation.
    bg = asyncio.create_task(_bg(), name=f"wecom-msg-{str(uid)[:8]}")
    app.state.inflight[bg] = None
    bg.add_done_callback(lambda t: app.state.inflight.pop(t, None))
    return PlainTextResponse("success")
|
||
|
||
@app.get("/v1/models", tags=["misc"])
def list_models(user_id: UUID = Depends(require_user)):
    """List every available LLM model profile by scanning config/models/*.yaml.

    Feeds the model dropdowns in the frontend top bar and the new-task dialog.
    ``is_default`` marks the entry equal to ``cfg["default_model"]``. Nothing is cached,
    so every call rescans a handful of files and YAML edits show up immediately.

    A family file is skipped, rather than failing the whole listing, if it:
    - cannot be read or parsed;
    - has a top level that is not a mapping;
    - has a ``variants`` entry that is not a mapping.

    A variant whose capabilities fail to load (ValueError / FileNotFoundError) is skipped.
    """
    from core.agent_builder import load_config
    from core.capabilities import ModelCapabilities
    from core.paths import ROOT

    import yaml as _yaml

    cfg = load_config()
    default = cfg["default_model"]
    models_dir = ROOT / cfg["models_dir"]

    out: list[dict] = []
    if not models_dir.is_dir():
        return {"models": out}
    for path in sorted(models_dir.glob("*.yaml")):
        try:
            data = _yaml.safe_load(path.read_text(encoding="utf-8")) or {}
        except Exception:
            continue  # unreadable / malformed YAML: skip this family
        # A list or scalar document would otherwise raise AttributeError on .get().
        if not isinstance(data, dict):
            continue
        variants = data.get("variants") or {}
        if not isinstance(variants, dict):
            continue
        family = data.get("family") or path.stem
        for variant in variants:
            profile = f"{family}.{variant}"
            try:
                caps = ModelCapabilities.load(profile, models_dir)
            except (ValueError, FileNotFoundError):
                continue
            out.append({
                "profile": profile,
                "display_name": caps.display_name or profile,
                "family": caps.family,
                "variant": caps.variant,
                "thinking_mode": caps.thinking_mode,
                "is_default": profile == default,
            })
    return {"models": out}
|
||
|
||
@app.get("/v1/image_models", tags=["misc"])
def list_image_models(user_id: UUID = Depends(require_user)):
    """List image-generation model variants (the ``image`` section of config/media/doubao.yaml).

    Feeds the second dropdown in the frontend top bar. An empty list means no image
    variant is configured, and the UI hides the dropdown. ``is_default`` marks the first
    variant, which is the agent_builder fallback target. Not cached during development,
    so newly added YAML variants show up immediately.
    """
    models = [
        {
            "variant": key,
            "display_name": vcfg.get("display_name") or key,
            "model_id": vcfg.get("model_id") or "",
            "price_cny_per_image": vcfg.get("price_cny_per_image"),
            "is_default": idx == 0,
        }
        for idx, (key, vcfg) in enumerate(_list_image_variants())
    ]
    return {"models": models}
|
||
|
||
@app.get("/v1/video_models", tags=["misc"])
def list_video_models(user_id: UUID = Depends(require_user)):
    """List video-generation model variants (the ``video`` section of config/media/doubao.yaml).

    Follows the same pattern as /v1/image_models; an empty list makes the UI hide the
    third dropdown. Each entry includes the default resolution/duration/ratio and the
    per-million-token prices (CNY), so users can see the cost scale in the option itself.
    """
    passthrough = (
        "default_resolution",
        "default_duration",
        "default_ratio",
        "price_cny_per_mtoken_text2video",
        "price_cny_per_mtoken_video2video",
    )
    models = []
    for idx, (key, vcfg) in enumerate(_list_video_variants()):
        entry = {
            "variant": key,
            "display_name": vcfg.get("display_name") or key,
            "model_id": vcfg.get("model_id") or "",
        }
        entry.update((field, vcfg.get(field)) for field in passthrough)
        entry["is_default"] = idx == 0
        models.append(entry)
    return {"models": models}
|
||
|
||
# ───────────── Auth ─────────────
|
||
|
||
@app.post("/v1/auth/login", tags=["auth"])
def login(body: LoginRequest):
    """Exchange a valid platform_key for a JWT whose subject is ``user_id``.

    A wrong platform_key -> 403; a user_id that is not a UUID -> 400.

    A missing users row is created idempotently, which avoids downstream FK failures.
    ``name`` / ``user_name`` from the body are upserted too, so platform-side renames
    sync on every login (see ensure_user_row). The platform server uses this entry to
    inject a specific user_id; the dev SPA uses /login_password instead.

    The key is compared in constant time so response timing does not reveal how much of
    it matched.
    """
    import hmac

    expected = auth_cfg.platform_key
    supplied = body.platform_key
    # hmac.compare_digest rejects non-ASCII str, so compare the UTF-8 bytes. A
    # non-string on either side fails closed.
    key_ok = (
        isinstance(expected, str)
        and isinstance(supplied, str)
        and hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))
    )
    if not key_ok:
        raise HTTPException(403, "invalid platform_key")
    try:
        uid = UUID(body.user_id)
    except (ValueError, TypeError):
        raise HTTPException(400, f"invalid user_id (must be UUID): {body.user_id!r}")
    ensure_user_row(uid, name=body.name, user_name=body.user_name)
    token, exp = mint_token(auth_cfg, uid)
    prof = get_user_profile(uid) or {}
    return {
        "token": token,
        # NOTE(review): naive local-time ISO string (no UTC offset); clients must know
        # the server's timezone to interpret it.
        "expires_at": _dt.fromtimestamp(exp).isoformat(),
        "user_id": str(uid),
        "name": prof.get("name"),
        "user_name": prof.get("user_name"),
        "role": prof.get("role", "user"),
        "ttl_seconds": auth_cfg.ttl_seconds,
    }
|
||
|
||
@app.post("/v1/auth/admin/create_user", tags=["auth"])
def admin_create_user(body: AdminCreateUserRequest):
    """Let an administrator provision a user (entry in the dev SPA login page corner).

    - ``ZCBOT_ADMIN_TOKEN`` env not set -> 503; the feature is disabled.
    - ``admin_token`` mismatch -> 403. "Unset" and "wrong" are deliberately not
      distinguished, to resist probing. The comparison is constant-time.
    - Invalid email / too-short password / invalid role -> 400.
    - Email already exists -> 409.
    - Success -> ``{"user_id": ..., "email": ..., "role": ...}``; the frontend then says
      "created, please log in".

    No token is minted and there is no auto-login: the new user logs in themselves.
    """
    import hmac

    expected = auth_cfg.admin_token
    if expected is None:
        raise HTTPException(503, "admin create_user disabled (ZCBOT_ADMIN_TOKEN not set)")
    supplied = body.admin_token
    # Compare UTF-8 bytes in constant time; a non-string on either side fails closed.
    token_ok = (
        isinstance(expected, str)
        and isinstance(supplied, str)
        and hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))
    )
    if not token_ok:
        raise HTTPException(403, "invalid admin_token")
    try:
        uid, email = create_user(
            email=body.email, password=body.password, role=body.role
        )
    except UserCreateError as ex:
        if ex.code in ("invalid_email", "weak_password", "invalid_role"):
            raise HTTPException(400, ex.message)
        if ex.code == "email_taken":
            raise HTTPException(409, "email already exists")
        raise HTTPException(500, f"create_user failed: {ex.message}")
    return {"user_id": str(uid), "email": email, "role": body.role}
|
||
|
||
@app.post("/v1/auth/login_password", tags=["auth"])
def login_password(body: PasswordLoginRequest):
    """Email + password login (the dev SPA path for colleagues / self-testing).

    - Unknown email, empty password_hash, or a bcrypt mismatch all give the same 403.
      Causes are not distinguished, so the response does not reveal whether an account
      exists.
    - On success a JWT is minted for the user_id already stored in the DB. There is no
      ensure_user_row call, because the row was created by ``user add``.
    - Provision users with
      ``.venv/Scripts/python.exe main.py user add --email X --password Y``.
    - Remove users with ``DELETE FROM users WHERE email=...``, after first deleting that
      user's tasks.
    """
    matched = resolve_user_by_email(body.email, body.password)
    if matched is None:
        raise HTTPException(403, "账号或密码错误")
    uid, email = matched
    token, exp = mint_token(auth_cfg, uid)
    profile = get_user_profile(uid) or {}
    response = {
        "token": token,
        "expires_at": _dt.fromtimestamp(exp).isoformat(),
        "user_id": str(uid),
        "email": email,
    }
    response.update(
        name=profile.get("name"),
        user_name=profile.get("user_name"),
        role=profile.get("role", "user"),
        ttl_seconds=auth_cfg.ttl_seconds,
    )
    return response
|
||
|
||
@app.post("/v1/auth/change_password", tags=["auth"])
def change_password_route(
    body: ChangePasswordRequest, user_id: UUID = Depends(require_user)
):
    """Change the caller's password (dev SPA top-bar entry).

    ``user_id`` comes from the JWT; any value sent by the frontend is not trusted.

    - New password shorter than 6 -> 400.
    - Wrong old password, or the account has no password (created via platform_key)
      -> 403. These are not distinguished, to resist probing.
    - The user no longer exists (the JWT is valid but the row is gone) -> 401.

    On success returns ``{"ok": true}``; the frontend shows a notice and clears the form.
    """
    status_by_code = {
        "weak_password": 400,
        "wrong_password": 403,
        "no_password": 403,
        "user_not_found": 401,
    }
    try:
        change_password(user_id, body.old_password, body.new_password)
    except UserCreateError as ex:
        mapped = status_by_code.get(ex.code)
        if mapped is None:
            raise HTTPException(500, f"change_password failed: {ex.message}")
        raise HTTPException(mapped, ex.message)
    return {"ok": True}
|
||
|
||
# ───────────── Tasks CRUD ─────────────
|
||
|
||
@app.post("/v1/tasks", status_code=201, tags=["tasks"])
def create_task(body: TaskCreateRequest, user_id: UUID = Depends(require_user)):
    """Create a task for the caller.

    - ``name`` is required. It is the display name, backed by a NOT NULL DB column and
      used by the list and title UI.
    - ``working_dir`` is optional; blank falls back to ``name`` as the directory name.
      Tasks with the same working_dir share one directory (§7.1).
    - Both values pass through validate_task_name: a simple name with no ``/``, ``\\``
      or ``..``, not starting with ``.``, at most 255 characters.
    - A prefix-nested working directory (the no-subtask rule, within one user) -> 409.

    Raises:
        HTTPException(400): invalid name or working_dir.
        HTTPException(409): the directory violates the no-subtask rule.
    """
    from core.agent_builder import InvalidTaskName, resolve_workspace, validate_task_name, working_dir_from_name

    def _validated(raw, label):
        # Validate one simple name; map InvalidTaskName to a 400 naming the field.
        try:
            return validate_task_name(raw)
        except InvalidTaskName as e:
            raise HTTPException(400, f"{label} 不合法: {e}")

    name = _validated(body.name, "name")
    # A blank working_dir falls back to the (already validated) task name.
    wd_name = _validated((body.working_dir or "").strip() or name, "working_dir")
    description = body.description.strip()
    skill = body.skill.strip()

    new_task_id = uuid4()
    workspace = resolve_workspace(None)
    fs_dir = working_dir_from_name(workspace, user_id, wd_name)
    fs_dir_db = to_db_path(fs_dir)

    try:
        check_no_subtask(fs_dir_db, user_id=user_id)
    except NoSubtaskError as e:
        raise HTTPException(409, str(e))

    # Create the working directory right away; it may already exist because tasks that
    # share a working_dir share it.
    fs_dir.mkdir(parents=True, exist_ok=True)

    profile, model_id = _resolve_model_profile(body.model_profile)
    ensure_local_task_row(
        task_id=new_task_id,
        name=name,
        working_dir=fs_dir_db,
        skill=skill,
        description=description,
        user_id=user_id,
        model=model_id,
        model_profile=profile,
    )
    with session_scope() as s:
        created = s.execute(select(Task).where(Task.task_id == new_task_id)).scalar_one()
        return _task_dict(created, n_messages=0)
|
||
|
||
@app.get("/v1/tasks", tags=["tasks"])
def list_tasks_route(
    page: int = 1,
    page_size: int = 20,
    status: Optional[str] = None,
    skill: Optional[str] = None,
    working_dir: Optional[str] = None,
    q: Optional[str] = None,
    ordering: Optional[str] = None,
    run_status: Optional[str] = None,
    user_id: UUID = Depends(require_user),
):
    """List the caller's tasks with pagination, multi-field filtering, and ordering.

    - ``page`` >= 1 (1-based). ``page_size`` is clamped to 1-100.
    - ``status`` must be in STATUS_FILTERS; any other value is silently ignored.
    - ``skill``: exact match; blank is ignored.
    - ``working_dir``: last path segment (e.g. ``水泥申报``). It is expanded to
      ``workspace/users/<uid>/<name>`` for the comparison.
    - ``q``: case-insensitive substring search over name + description (ILIKE).
      ``%``, ``_`` and ``/`` in ``q`` are matched literally, not as LIKE wildcards.
    - ``run_status``: comma-separated, allowlist ``idle/running/cancelling/error``;
      invalid values are silently ignored. The dev SPA typically sends
      ``running,cancelling`` to find active tasks in the same working_dir.
    - ``ordering``: DRF style, comma-separated, ``-field`` for descending.
      Allowlist ``created_at/updated_at/name/status``; invalid fields are silently
      ignored. The default is ``-created_at``.

    Returns the standard pagination envelope ``{page, page_size, count, results}``;
    ``count`` lets the frontend compute the number of pages.
    """
    # clamp + sanitize
    page = max(1, page)
    page_size = max(1, min(page_size, 100))
    status = status if status in STATUS_FILTERS else None
    skill = (skill or "").strip() or None
    wd_name = (working_dir or "").strip() or None
    q_text = (q or "").strip() or None
    rs_allowed = ("idle", "running", "cancelling", "error")
    run_status_set = {
        part.strip() for part in (run_status or "").split(",") if part.strip() in rs_allowed
    } or None

    # WHERE clause. Soft-deleted tasks never appear here; they come back via /restore.
    # Channel mirror tasks (wechat/wecom resident conversations) are excluded because
    # the left sidebar shows them as pinned cards (GET /v1/channel_tasks); coalesce
    # covers NULL channel values (older web tasks).
    conditions = [
        Task.user_id == user_id,
        Task.deleted_at.is_(None),
        func.coalesce(Task.channel, "web").notin_(CHANNEL_MIRROR_KINDS),
        # Scheduled-job execution tasks (owned via scheduled_job_id) are excluded too.
        # The working_dir LIKE is a fallback for orphaned rows the backfill missed
        # (isolated tasks whose job was physically deleted).
        Task.scheduled_job_id.is_(None),
        ~Task.working_dir.like("%/scheduled-%"),
    ]
    if status:
        conditions.append(Task.status == status)
    if skill:
        conditions.append(Task.skill == skill)
    if wd_name:
        # Last segment -> full DB form; this is how tasks that share a working_dir match.
        wd_db = f"workspace/users/{user_id}/{wd_name}"
        conditions.append(Task.working_dir == wd_db)
    if run_status_set:
        conditions.append(Task.run_status.in_(run_status_set))
    if q_text:
        # Escape LIKE metacharacters (with "/" as the escape char) so input such as
        # "100%" or "a_b" matches literally instead of acting as a wildcard pattern.
        escaped = q_text.replace("/", "//").replace("%", "/%").replace("_", "/_")
        pat = f"%{escaped}%"
        conditions.append(
            Task.name.ilike(pat, escape="/") | Task.description.ilike(pat, escape="/")
        )

    offset = (page - 1) * page_size

    with session_scope() as s:
        cnt = s.execute(
            select(func.count()).select_from(Task).where(*conditions)
        ).scalar_one() or 0

        rows = s.execute(
            select(Task).where(*conditions)
            .order_by(*_parse_ordering(ordering))
            .limit(page_size).offset(offset)
        ).scalars().all()

        tids = [r.task_id for r in rows]
        msg_counts = (
            dict(s.execute(
                select(Message.task_id, func.count())
                .where(Message.task_id.in_(tids))
                .group_by(Message.task_id)
            ).all())
            if tids else {}
        )
        usage = _usage_aggregates(s, tids)

        return {
            "page": page,
            "page_size": page_size,
            "count": int(cnt),
            "results": [
                _task_dict(
                    r,
                    n_messages=msg_counts.get(r.task_id, 0),
                    usage=usage.get(r.task_id),
                )
                for r in rows
            ],
        }
|
||
|
||
@app.get("/v1/channel_tasks", tags=["tasks"])
def list_channel_tasks(user_id: UUID = Depends(require_user)):
    """Return binding state and resident-conversation summary for each channel mirror task.

    The frontend renders these as pinned cards under "New task" in the left sidebar.
    Response shape: ``{wechat: {bound, task}, wecom: {bound, task}}``.

    ``bound`` comes from ``get_binding`` (WeChat binding must be active) and from
    ``get_wecom_userid``. ``task`` has the same shape as a /v1/tasks list item (via
    _task_dict), or null when there is no conversation. The frontend uses this to choose
    one of three cards: unbound (click to bind), bound without a conversation
    (placeholder), or bound with a conversation (click in, plus a manage gear).
    """
    from core.wechat import service as _wx

    binding = _wx.get_binding(user_id)
    wecom_userid = _wx.get_wecom_userid(user_id)
    wechat_active = bool(binding and binding.status == "active")
    task_ids: dict[str, Optional[UUID]] = {
        "wechat": binding.chat_task_id if wechat_active else None,
        "wecom": _wx.get_wecom_chat_task(user_id),
    }

    summaries: dict[str, Optional[dict]] = {"wechat": None, "wecom": None}
    wanted = [tid for tid in task_ids.values() if tid is not None]
    if wanted:
        with session_scope() as s:
            found = s.execute(
                select(Task).where(
                    Task.task_id.in_(wanted),
                    Task.user_id == user_id,
                    Task.deleted_at.is_(None),
                )
            ).scalars().all()
            rows = {row.task_id: row for row in found}

            msg_counts = {}
            if rows:
                msg_counts = dict(
                    s.execute(
                        select(Message.task_id, func.count())
                        .where(Message.task_id.in_(list(rows)))
                        .group_by(Message.task_id)
                    ).all()
                )
            usage = _usage_aggregates(s, list(rows))

            for kind, tid in task_ids.items():
                row = rows.get(tid)
                if row is None:
                    continue
                summaries[kind] = _task_dict(
                    row,
                    n_messages=msg_counts.get(row.task_id, 0),
                    usage=usage.get(row.task_id),
                )

    return {
        "wechat": {"bound": wechat_active, "task": summaries["wechat"]},
        "wecom": {"bound": bool(wecom_userid), "task": summaries["wecom"]},
    }
|
||
|
||
@app.get("/v1/tasks/{task_id}", tags=["tasks"])
def get_task(task_id: str, user_id: UUID = Depends(require_user)):
    """Return one task's metadata; messages are excluded (fetch them via ``/messages``).

    A malformed id and a task owned by another user both produce 404, so the caller
    cannot tell whether a foreign task id exists. The lookup has no ``deleted_at``
    filter, so soft-deleted tasks are still returned here.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    owned_task = select(Task).where(Task.task_id == tid, Task.user_id == user_id)
    message_count = (
        select(func.count()).select_from(Message).where(Message.task_id == tid)
    )
    with session_scope() as s:
        task_row = s.execute(owned_task).scalar_one_or_none()
        if task_row is None:
            raise HTTPException(404, f"task not found: {tid}")
        n_messages = s.execute(message_count).scalar_one()
        usage_by_task = _usage_aggregates(s, [tid])
        return _task_dict(task_row, n_messages=n_messages, usage=usage_by_task.get(tid))
|
||
|
||
@app.get("/v1/folders", tags=["folders"])
def list_folders(user_id: UUID = Depends(require_user)):
    """List the current user's working directories.

    These are the non-dot subdirectories of ``workspace/users/<uid>/``, used for
    autocomplete and for picking an existing directory when creating a task. The
    filesystem is the source of truth, so directories created by hand that have no
    task yet are included too.

    Each item has:
    - ``name``: the folder name.
    - ``n_tasks``: number of non-deleted tasks of this user whose ``working_dir`` is
      that folder.
    - ``last_used``: the ISO timestamp of their latest ``updated_at``, or None.

    Ordering: folders with ``last_used`` come first, newest first; folders without it
    come last. Ties are broken by name, ascending.
    """
    from core.agent_builder import resolve_workspace, user_root

    ws = resolve_workspace(None)
    root = user_root(ws, user_id)

    folder_names: list[str] = []
    if root.is_dir():
        folder_names = [
            p.name
            for p in sorted(root.iterdir(), key=lambda x: x.name.lower())
            if p.is_dir() and not p.name.startswith(".")
        ]

    # folder name -> (task count, latest updated_at)
    stats: dict[str, tuple[int, Any]] = {}
    if folder_names:
        # tasks.working_dir stores the db form "workspace/users/<uid>/<name>".
        db_forms = {f"workspace/users/{user_id}/{name}": name for name in folder_names}
        with session_scope() as s:
            # One grouped query for all folders instead of one COUNT/MAX query per
            # folder (the previous N+1 pattern).
            rows = s.execute(
                select(Task.working_dir, func.count(), func.max(Task.updated_at))
                .where(
                    Task.user_id == user_id,
                    Task.working_dir.in_(list(db_forms)),
                    Task.deleted_at.is_(None),
                )
                .group_by(Task.working_dir)
            ).all()
        for working_dir, n_tasks, last_used in rows:
            name = db_forms.get(working_dir)
            if name is not None:
                stats[name] = (int(n_tasks or 0), last_used)

    folders: list[dict] = []
    for name in folder_names:
        n_tasks, last_used = stats.get(name, (0, None))
        folders.append({
            "name": name,
            "n_tasks": n_tasks,
            "last_used": _iso(last_used),
        })

    # Two stable sorts: name asc first, then last_used desc. With reverse=True, Python's
    # sort keeps equal keys in their prior (name) order. Missing last_used sorts as ""
    # and therefore ends up last.
    folders.sort(key=lambda f: f["name"])
    folders.sort(key=lambda f: f["last_used"] or "", reverse=True)
    return {"folders": folders}
|
||
|
||
# ───────────── Skills ─────────────
|
||
|
||
@app.get("/v1/skills", tags=["skills"])
def list_skills(user_id: UUID = Depends(require_user)):
    """List the skills available to the current user (built-in plus the user's own).

    Used for the skill drop-down when creating a task. The registry is rebuilt on every
    request (built-in ``skills/<name>/SKILL.md`` plus the user's
    ``.skills/<name>/SKILL.md``). Adding, editing or removing a skill directory is
    therefore visible in the frontend without a restart.

    Each entry reports:
    - ``source``: ``builtin`` or ``user``.
    - ``overrides_builtin``: True when a user skill shadows a built-in one with the
      same name.

    ``load_errors`` lists user skills that could not be loaded, with the reason, so the
    frontend can surface them.
    """
    from core.agent_builder import build_skill_registry, load_config, resolve_workspace

    config = load_config()
    registry = build_skill_registry(
        config, resolve_workspace(None, config), user_id, docker=False
    )

    skills = []
    for skill in registry.skills.values():
        skills.append({
            "name": skill.name,
            "description": skill.description,
            "source": skill.source,
            "overrides_builtin": skill.name in registry.user_overrides,
        })
    errors = [
        {"name": skill_name, "reason": reason}
        for skill_name, reason in registry.load_errors
    ]
    return {"skills": skills, "load_errors": errors}
|
||
|
||
@app.get("/v1/skills/{name}", tags=["skills"])
def get_skill(name: str, user_id: UUID = Depends(require_user)):
    """Return the full SKILL.md text of one skill, for the frontend detail modal.

    Both built-in and user skills can be fetched. When both exist under the same name,
    the registry's entry (the user's copy, which is also what the agent sees) is used.

    Raises:
        HTTPException: 404 if the skill is unknown; 500 if SKILL.md cannot be read.
    """
    from core.agent_builder import build_skill_registry, load_config, resolve_workspace

    config = load_config()
    workspace = resolve_workspace(None, config)
    registry = build_skill_registry(config, workspace, user_id, docker=False)

    found = registry.get(name)
    if found is None:
        raise HTTPException(404, f"skill not found: {name!r}")
    try:
        body = found.full_content()
    except OSError as e:
        raise HTTPException(500, f"读取 SKILL.md 失败: {e}")

    return {
        "name": found.name,
        "source": found.source,
        "overrides_builtin": found.name in registry.user_overrides,
        "content": body,
    }
|
||
|
||
@app.delete("/v1/skills/{name}", status_code=204, tags=["skills"])
def delete_skill(name: str, user_id: UUID = Depends(require_user)):
    """Delete one of the current user's private skills (the whole ``.skills/<name>/`` dir).

    Only user-sourced skills can be deleted. A built-in skill yields 404, meaning "the
    user has no deletable skill by that name". ``.skills`` is hidden in the file panel,
    so this endpoint is the only UI path for removing one's own skill.

    Raises:
        HTTPException:
        - 404 if there is no user skill with that name, including when the directory
          vanished between lookup and deletion.
        - 400 if the resolved skill directory is outside the user's ``.skills`` tree.
        - 500 if the filesystem refuses the deletion.
    """
    import shutil

    from core.agent_builder import build_skill_registry, load_config, resolve_workspace, user_root

    cfg = load_config()
    ws = resolve_workspace(None, cfg)
    reg = build_skill_registry(cfg, ws, user_id, docker=False)
    skill = reg.get(name)
    if skill is None or skill.source != "user":
        raise HTTPException(404, f"no user skill to delete: {name!r}")

    # Path-traversal guard: the resolved target must live inside this user's .skills
    # subtree. skill_dir comes from the registry scan and should already satisfy this;
    # resolve() also follows symlinks, so a link pointing elsewhere is rejected too.
    user_skills_dir = (user_root(ws, user_id) / ".skills").resolve()
    target = skill.skill_dir.resolve()
    if user_skills_dir not in target.parents:
        raise HTTPException(400, "拒绝删除 .skills 之外的路径")

    try:
        shutil.rmtree(target)
    except FileNotFoundError:
        # Removed concurrently (e.g. a second delete request) — nothing left to delete.
        raise HTTPException(404, f"no user skill to delete: {name!r}") from None
    except OSError as e:
        # Permission / busy-file errors: report them like get_skill reports read errors,
        # instead of letting an unhandled exception escape.
        raise HTTPException(500, f"删除 skill 目录失败: {e}") from e
|
||
|
||
@app.get("/v1/memory", tags=["memory"])
def get_memory(user_id: UUID = Depends(require_user)):
    """Read-only overview of memory: the core.md text plus the extended entry list.

    The frontend "Memory" dialog loads everything in one request. This endpoint is
    deliberately read-only. Memory is edited only through conversation (the agent
    manages it itself; see ``core/memory.py::_CONTRACT``), so the GUI is an "eye",
    not a "hand" (DESIGN §3.7). The filesystem is read on every call, so anything the
    agent just wrote is visible immediately.
    """
    from core.agent_builder import resolve_workspace
    from core.memory import memory_view

    workspace = resolve_workspace(None)
    return memory_view(workspace, user_id)
|
||
|
||
@app.get("/v1/memory/extended/{filename}", tags=["memory"])
def get_memory_extended(filename: str, user_id: UUID = Depends(require_user)):
    """Return the raw text of one extended memory file (fetched when a list item is opened).

    Path-traversal validation lives in ``core/memory.py::read_extended_file``. That
    helper returns None for an illegal or missing file, which maps to 404 here.
    """
    from core.agent_builder import resolve_workspace
    from core.memory import read_extended_file

    text = read_extended_file(resolve_workspace(None), user_id, filename)
    if text is not None:
        return {"filename": filename, "content": text}
    raise HTTPException(404, f"memory file not found: {filename!r}")
|
||
|
||
# ───────────── 定时任务(DESIGN §8.5)─────────────
|
||
# 前端只读展示 + 停用/删除两个便捷动作;建/改全走对话(schedule_* 工具)。
|
||
# 与对话工具共用 core.scheduler 服务层,两条路径不漂移。
|
||
@app.get("/v1/schedules", tags=["schedules"])
def list_schedules(user_id: UUID = Depends(require_user)):
    """List the current user's scheduled jobs (read-only), loaded at once by the "Schedules" panel."""
    from core import scheduler

    jobs = scheduler.list_jobs(user_id)
    return {"results": jobs}
|
||
|
||
@app.patch("/v1/schedules/{job_id}", tags=["schedules"])
def patch_schedule(
    job_id: str, body: SchedulePatchRequest, user_id: UUID = Depends(require_user),
):
    """Update a scheduled job. The frontend only toggles ``enabled``.

    All other edits go through conversation (the schedule_* tools).

    Raises:
        HTTPException: 400 when ``enabled`` is missing; 404 when the scheduler
        rejects the job (``scheduler.JobError``).
    """
    from core import scheduler

    enabled = body.enabled
    if enabled is None:
        raise HTTPException(400, "no fields to update")
    try:
        updated = scheduler.set_enabled(user_id, job_id, enabled)
    except scheduler.JobError as e:
        raise HTTPException(404, str(e))
    return updated
|
||
|
||
@app.delete("/v1/schedules/{job_id}", status_code=204, tags=["schedules"])
def delete_schedule(job_id: str, user_id: UUID = Depends(require_user)):
    """Delete a scheduled job via ``scheduler.cancel_job``.

    Per the scheduler's contract this is a soft delete that stops triggering
    immediately. A ``scheduler.JobError`` maps to 404.
    """
    from core import scheduler

    try:
        scheduler.cancel_job(user_id, job_id)
    except scheduler.JobError as exc:
        raise HTTPException(404, str(exc))
    return None  # 204
|
||
|
||
@app.get("/v1/schedules/{job_id}/tasks", tags=["schedules"])
def list_schedule_tasks(
    job_id: str,
    page: int = 1,
    page_size: int = 20,
    user_id: UUID = Depends(require_user),
):
    """Page through the tasks a scheduled job has run, newest ``created_at`` first.

    Selection is by ``scheduled_job_id``.
    - Isolated-mode jobs create a fresh task on every trigger. Those tasks are excluded
      from the normal ``/v1/tasks`` list and can only be reviewed here.
    - A persistent-mode job only ever has its single bound task.

    Returns the standard page envelope ``{page, page_size, count, results}``; results
    have the same shape as ``/v1/tasks`` items (built by ``_task_dict``). ``page`` is
    clamped to >= 1 and ``page_size`` to [1, 100]. The ``user_id`` filter isolates
    other users' jobs: a foreign or malformed ``job_id`` yields an empty page.
    """
    page = max(1, page)
    page_size = max(1, min(page_size, 100))
    empty_page = {"page": page, "page_size": page_size, "count": 0, "results": []}
    try:
        jid = UUID(str(job_id).strip())
    except (ValueError, AttributeError):
        return empty_page

    filters = (
        Task.user_id == user_id,
        Task.deleted_at.is_(None),
        Task.scheduled_job_id == jid,
    )
    with session_scope() as s:
        total = s.execute(
            select(func.count()).select_from(Task).where(*filters)
        ).scalar_one() or 0
        page_rows = s.execute(
            select(Task)
            .where(*filters)
            .order_by(Task.created_at.desc())
            .limit(page_size)
            .offset((page - 1) * page_size)
        ).scalars().all()

        ids = [task_row.task_id for task_row in page_rows]
        n_by_task: dict = {}
        if ids:
            n_by_task = dict(
                s.execute(
                    select(Message.task_id, func.count())
                    .where(Message.task_id.in_(ids))
                    .group_by(Message.task_id)
                ).all()
            )
        usage = _usage_aggregates(s, ids)
        results = [
            _task_dict(
                task_row,
                n_messages=n_by_task.get(task_row.task_id, 0),
                usage=usage.get(task_row.task_id),
            )
            for task_row in page_rows
        ]

    return {"page": page, "page_size": page_size, "count": int(total), "results": results}
|
||
|
||
@app.delete("/v1/tasks/{task_id}", status_code=204, tags=["tasks"])
def delete_task(task_id: str, user_id: UUID = Depends(require_user)):
    """Soft-delete a task: set ``deleted_at = now()`` so it is hidden from the task list.

    What is kept, and why:
    - The DB row stays, so its messages and usage_events remain (the former ON DELETE
      CASCADE never fires).
    - Working-directory files are untouched; nothing on disk changes.
    - The data stays available as training material, and the task can be restored via
      ``POST /v1/tasks/{id}/restore``.

    Deleting an already soft-deleted task is an idempotent 204. Unknown ids and other
    users' tasks give 404.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    owned = (Task.task_id == tid, Task.user_id == user_id)
    with session_scope() as s:
        existing = s.execute(select(Task.deleted_at).where(*owned)).first()
        if existing is None:
            raise HTTPException(404, f"task not found: {tid}")
        already_deleted = existing.deleted_at is not None
        if not already_deleted:
            s.execute(update(Task).where(*owned).values(deleted_at=func.now()))
    return None  # 204
|
||
|
||
@app.post("/v1/tasks/{task_id}/restore", tags=["tasks"])
def restore_task(task_id: str, user_id: UUID = Depends(require_user)):
    """Undo a soft delete (set ``deleted_at`` back to NULL) so the task shows up again.

    Restoring a task that is not deleted succeeds idempotently. Unknown ids and other
    users' tasks give 404. Returns the task dict.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    with session_scope() as s:
        task_row = s.execute(
            select(Task).where(Task.task_id == tid, Task.user_id == user_id)
        ).scalar_one_or_none()
        if task_row is None:
            raise HTTPException(404, f"task not found: {tid}")
        # Assigning on the ORM object marks it dirty; session_scope persists it on commit.
        task_row.deleted_at = None
        n_messages = s.execute(
            select(func.count()).select_from(Message).where(Message.task_id == tid)
        ).scalar_one()
        usage_by_task = _usage_aggregates(s, [tid])
        return _task_dict(task_row, n_messages=n_messages, usage=usage_by_task.get(tid))
|
||
|
||
@app.patch("/v1/tasks/{task_id}", tags=["tasks"])
def patch_task(
    task_id: str,
    body: TaskPatchRequest,
    user_id: UUID = Depends(require_user),
):
    """Update selected task fields and return the refreshed task dict.

    Only fields that are not None in the body are written.
    - ``status`` must be in ``STATUS_WRITABLE`` (completed / abandoned; switching back
      to active goes through the CLI).
    - ``name`` is validated by ``validate_task_name``.
    - ``model_profile`` is resolved by ``_resolve_model_profile`` and written together
      with ``model``.

    Raises:
        HTTPException:
        - 404 for a malformed id or a task not owned by the user.
        - 400 for a disallowed status, an invalid name, or an empty patch.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    changes: dict[str, Any] = {}
    new_status = body.status
    if new_status is not None:
        if new_status not in STATUS_WRITABLE:
            raise HTTPException(
                400, f"invalid status {new_status!r}; allowed: {STATUS_WRITABLE}"
            )
        changes["status"] = new_status
    if body.description is not None:
        changes["description"] = body.description
    if body.skill is not None:
        changes["skill"] = body.skill
    if body.name is not None:
        from core.agent_builder import InvalidTaskName, validate_task_name

        try:
            changes["name"] = validate_task_name(body.name)
        except InvalidTaskName as e:
            raise HTTPException(400, f"name 不合法: {e}")
    if body.model_profile is not None:
        # Switching model: validate, then update both columns (profile + model id)
        # together. Per the original design note, it takes effect on the next send;
        # an in-flight run is not affected.
        changes["model_profile"], changes["model"] = _resolve_model_profile(
            body.model_profile
        )
    if not changes:
        raise HTTPException(400, "no fields to update")

    with session_scope() as s:
        outcome = s.execute(
            update(Task)
            .where(Task.task_id == tid, Task.user_id == user_id)
            .values(**changes)
        )
        if outcome.rowcount == 0:
            raise HTTPException(404, f"task not found: {tid}")
        refreshed = s.execute(select(Task).where(Task.task_id == tid)).scalar_one()
        n_messages = s.execute(
            select(func.count()).select_from(Message).where(Message.task_id == tid)
        ).scalar_one()
        usage_by_task = _usage_aggregates(s, [tid])
        return _task_dict(refreshed, n_messages=n_messages, usage=usage_by_task.get(tid))
|
||
|
||
# ───────────── Messages ─────────────
|
||
|
||
def _assert_owns_task(s, tid: UUID, user_id: UUID) -> None:
    """Raise HTTP 404 unless task ``tid`` exists and belongs to ``user_id``.

    ``s`` is an open session (from ``session_scope``). There is no ``deleted_at``
    filter, so soft-deleted tasks still pass this check.
    """
    owner_match = s.execute(
        select(Task.task_id).where(Task.task_id == tid, Task.user_id == user_id)
    ).first()
    if owner_match is not None:
        return
    raise HTTPException(404, f"task not found: {tid}")
|
||
|
||
@app.get("/v1/tasks/{task_id}/messages", tags=["messages"])
def list_messages(
    task_id: str,
    limit: Optional[int] = None,
    before_idx: Optional[int] = None,
    after_idx: Optional[int] = None,
    user_id: UUID = Depends(require_user),
):
    """Return a task's message history in ascending ``idx`` order.

    Each message carries its stored LiteLLM payload unchanged; the frontend renders it.

    Pagination (two-way window):
    - No ``limit``: the full history, ascending (legacy behaviour for older frontends).
      ``before_idx`` / ``after_idx`` are ignored and both ``has_more`` flags are False.
    - ``limit`` only: the newest ``limit`` messages (fetched by idx desc, then reversed
      to ascending).
    - ``limit`` + ``before_idx``: the newest ``limit`` messages with
      ``idx < before_idx`` (paging upward).
    - ``limit`` + ``after_idx``: the oldest ``limit`` messages with
      ``idx > after_idx`` (paging downward, e.g. after jumping to an old message from
      the outline). ``after_idx`` wins if both are given.

    The response always has ``has_more`` (older messages exist before the window) and
    ``has_more_after`` (newer messages exist after it).

    Raises:
        HTTPException: 404 for a malformed id or a task not owned by the user; 400 when
        ``limit`` is given but smaller than 1.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    # limit=0 used to return an empty window with has_more=False, which hides existing
    # history from a paging client. A negative LIMIT is rejected by the database and
    # surfaced as a 500. Reject both up front with a clear 400.
    if limit is not None and limit < 1:
        raise HTTPException(400, f"limit must be >= 1, got {limit}")

    with session_scope() as s:
        _assert_owns_task(s, tid, user_id)
        cols = (
            Message.idx, Message.payload, Message.tokens_in,
            Message.tokens_out, Message.model_profile, Message.created_at,
        )
        if limit is None:
            # Legacy behaviour: the whole history, ascending.
            rows = s.execute(
                select(*cols).where(Message.task_id == tid).order_by(Message.idx)
            ).all()
        elif after_idx is not None:
            # Downward window: the earliest `limit` rows with idx > after_idx, ascending.
            rows = list(s.execute(
                select(*cols)
                .where(Message.task_id == tid, Message.idx > after_idx)
                .order_by(Message.idx)
                .limit(limit)
            ).all())
        else:
            # Tail / upward window: take `limit` rows descending, then flip to ascending.
            q = select(*cols).where(Message.task_id == tid)
            if before_idx is not None:
                q = q.where(Message.idx < before_idx)
            rows = list(s.execute(q.order_by(Message.idx.desc()).limit(limit)).all())
            rows.reverse()

        # Is there anything beyond either end of the window? The frontend's top/bottom
        # sentinels use these flags to decide whether to keep loading.
        has_more = False
        has_more_after = False
        if rows:
            first_idx = rows[0].idx
            last_idx = rows[-1].idx
            has_more = s.execute(
                select(Message.idx)
                .where(Message.task_id == tid, Message.idx < first_idx)
                .limit(1)
            ).first() is not None
            has_more_after = s.execute(
                select(Message.idx)
                .where(Message.task_id == tid, Message.idx > last_idx)
                .limit(1)
            ).first() is not None

        return {
            "has_more": has_more,
            "has_more_after": has_more_after,
            "messages": [
                {
                    "idx": r.idx,
                    "payload": dict(r.payload),
                    "tokens_in": r.tokens_in,
                    "tokens_out": r.tokens_out,
                    # Migration 0006: non-null on assistant rows; the model that produced it.
                    "model_profile": r.model_profile,
                    "created_at": _iso(r.created_at),
                }
                for r in rows
            ],
        }
|
||
|
||
@app.get("/v1/tasks/{task_id}/outline", tags=["messages"])
def task_outline(task_id: str, user_id: UUID = Depends(require_user)):
    """Message outline: ``{idx, snippet}`` for every user turn, ascending by idx.

    Feeds the dot-rail navigator on the right. Only the idx and the payload's
    ``content`` text are selected, and ``_outline_snippet`` condenses the content.
    This keeps the endpoint light even for long tasks. Rows are narrowed by task_id,
    and ``role == "user"`` is filtered from the JSON payload.

    When a dot is clicked, the frontend scrolls to the message if it is already loaded.
    Otherwise it fetches a window around it with ``before_idx`` and then positions.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    user_turns = (
        select(Message.idx, Message.payload["content"].astext)
        .where(
            Message.task_id == tid,
            Message.payload["role"].astext == "user",
        )
        .order_by(Message.idx)
    )
    with session_scope() as s:
        _assert_owns_task(s, tid, user_id)
        found = s.execute(user_turns).all()

    items = []
    for idx, content in found:
        items.append({"idx": idx, "snippet": _outline_snippet(content)})
    return {"items": items}
|
||
|
||
@app.post("/v1/tasks/{task_id}/messages", status_code=202, tags=["messages"])
async def post_message(
    task_id: str,
    body: MessageRequest,
    user_id: UUID = Depends(require_user),
):
    """Send a message: start a background agent run with ``content``.

    Returns ``{events_url}``; the client should subscribe to that SSE stream right away.

    Single active run per task: one transaction locks the task row
    (``SELECT ... FOR UPDATE``), checks the run status and marks it ``running``. This
    stops two quick sends from starting two background threads that race on
    ``messages.idx``. ``run_status`` in ('running', 'cancelling') gives 409. Any other
    status (e.g. idle / error / cancelled) is treated as restartable; ``run_error`` is
    cleared.

    Raises:
        HTTPException:
        - 404 for a malformed id or a task not owned by the user.
        - 503 with ``Retry-After: 3`` while the server drains for shutdown.
        - 400 for empty content, plus whatever the image/video model resolvers raise.
        - 409 when a run is already active.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    # Shutdown drain window: refuse new runs and send Retry-After so clients back off
    # and retry (back-pressure during deploys).
    if getattr(app.state, "draining", None) is not None and app.state.draining.is_set():
        raise HTTPException(
            503, "server is restarting; retry shortly",
            headers={"Retry-After": "3"},
        )
    content = (body.content or "").strip()
    if not content:
        raise HTTPException(400, "empty content")

    # Validate image_model / video_model in the request itself, rather than in the BG
    # thread where an error would escape the event sink. Empty strings pass through
    # without consulting the yaml. This must happen BEFORE the transaction below: once
    # run_status='running' is committed, a validation error would leave the task marked
    # running with no background run to ever reset it, and every later send would 409.
    image_variant = _resolve_image_model(body.image_model)
    video_variant = _resolve_video_model(body.video_model)

    with session_scope() as s:
        row = s.execute(
            select(Task.run_status)
            .where(Task.task_id == tid, Task.user_id == user_id)
            .with_for_update()
        ).first()
        if row is None:
            raise HTTPException(404, f"task not found: {tid}")
        if row.run_status in ("running", "cancelling"):
            raise HTTPException(
                409,
                f"task already has an active run (status={row.run_status}); "
                f"wait for it to finish or cancel",
            )
        s.execute(
            update(Task).where(Task.task_id == tid).values(
                run_status="running", run_error=None,
            )
        )

    # Clear the previous round's "done" marker so new subscribers see this stream.
    broker.start(tid)
    # The commit released the row lock; the BG thread takes over from here (its sink
    # bridges events back to the asyncio loop through the broker). Register the task in
    # app.state.inflight for two reasons:
    # (1) the shutdown drain can await it;
    # (2) the strong reference keeps the task from being garbage-collected mid-run,
    #     since asyncio only keeps weak references to tasks.
    # The done callback removes the entry.
    run_task = asyncio.create_task(asyncio.to_thread(
        _run_agent_bg, tid, user_id, content, image_variant, video_variant,
    ))
    app.state.inflight[run_task] = tid
    run_task.add_done_callback(lambda t: app.state.inflight.pop(t, None))
    return {"events_url": f"/v1/tasks/{tid}/events"}
|
||
|
||
# ───────────── Cancel current run ─────────────
|
||
|
||
@app.post("/v1/tasks/{task_id}/cancel", status_code=202, tags=["tasks"])
def cancel_task(
    task_id: str,
    user_id: UUID = Depends(require_user),
):
    """Send a cooperative cancel signal to the task's active run.

    - There is at most one active run per task, so "cancel the current run" is
      unambiguous and the client only needs the task id.
    - The task must belong to the user, otherwise 404.
    - ``run_status`` other than ``running`` gives 409 (idle / cancelling / error cannot
      be cancelled).
    - The status moves to the transitional ``cancelling`` and ``broker.request_cancel``
      is signalled. Per the run loop's design, it polls between stream chunks and tool
      calls, exits when it sees the signal, and its ``finally`` writes the final state
      (normal: idle; exception: error).
    - Because the LLM call streams, cancel latency is roughly one chunk interval
      (on the order of 100 ms).
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    with session_scope() as s:
        current = s.execute(
            select(Task.run_status)
            .where(Task.task_id == tid, Task.user_id == user_id)
            .with_for_update()
        ).first()
        if current is None:
            raise HTTPException(404, f"task not found: {tid}")
        status = current.run_status
        if status != "running":
            raise HTTPException(
                409,
                f"task not running (run_status={status}); cannot cancel",
            )
        s.execute(
            update(Task).where(Task.task_id == tid).values(run_status="cancelling")
        )
    broker.request_cancel(tid)
    return {"ok": True, "task_id": str(tid), "run_status": "cancelling"}
|
||
|
||
# ───────────── Clear conversation ─────────────
|
||
|
||
@app.post("/v1/tasks/{task_id}/clear", tags=["messages"])
def clear_messages(task_id: str, user_id: UUID = Depends(require_user)):
    """Delete all of the task's messages and reset its token totals, cost and run_error.

    Files in the working_dir are left alone. This matches task deletion ("the FS view
    can be regenerated"): intermediate artifacts stay, so a restarted conversation can
    build on them. usage_events are kept too, since they are per-user account-level
    metering and should not be affected by clearing a conversation.

    - 409 while a run is active (running / cancelling); cancel it first.
    - A task in ``error`` can be cleared; run_status becomes 'idle' and run_error None.
    - Other users' tasks give 404.

    The returned task dict is built with ``n_messages=0`` and no ``usage`` argument.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    from sqlalchemy import delete as _delete

    reset_values = dict(
        tokens_prompt=0,
        tokens_completion=0,
        cost_cny=0,
        run_status="idle",
        run_error=None,
    )
    with session_scope() as s:
        locked = s.execute(
            select(Task.run_status)
            .where(Task.task_id == tid, Task.user_id == user_id)
            .with_for_update()
        ).first()
        if locked is None:
            raise HTTPException(404, f"task not found: {tid}")
        if locked.run_status in ("running", "cancelling"):
            raise HTTPException(
                409,
                f"task has an active run (status={locked.run_status}); "
                f"cancel it first",
            )
        s.execute(_delete(Message).where(Message.task_id == tid))
        s.execute(update(Task).where(Task.task_id == tid).values(**reset_values))
        cleared = s.execute(select(Task).where(Task.task_id == tid)).scalar_one()
        return _task_dict(cleared, n_messages=0)
|
||
|
||
# ───────────── Prompt optimize(辅助 LLM 调用,不入 messages)─────────────
|
||
|
||
@app.post("/v1/tasks/{task_id}/optimize_prompt", tags=["messages"])
def optimize_prompt(
    task_id: str,
    body: OptimizePromptRequest,
    user_id: UUID = Depends(require_user),
):
    """Polish the user's draft prompt with the task's current model; return the new text.

    - Synchronous, non-streaming LLM call (short text).
    - Nothing is written to ``messages``, and ``tasks.tokens_prompt/completion`` are
      not touched, so the header counters keep reflecting only the main conversation.
    - A separate ``usage_events`` row with ``kind="prompt_optimize"`` is written, for
      reconciliation and for judging by kind whether the button is worth its cost.
    - Not mutually exclusive with the main run (no row lock, no run_status check). It
      never writes messages, so there is no idx contention, and the user can polish the
      next draft while a reply is streaming.
    - ``image_model`` / ``video_model`` only add downstream-tool hints to the
      meta-prompt; nothing about them is stored on the task.

    Raises:
        HTTPException:
        - 404 for a malformed id or a task not owned by the user.
        - 400 for empty or over-long (>4000 chars) text.
        - 500 if the model profile cannot be loaded.
        - 502 if the LLM call fails or returns no usable content.
    """
    from decimal import Decimal

    from core.agent_builder import load_config
    from core.capabilities import ModelCapabilities
    from core.llm import LLM
    from core.paths import ROOT
    from core.storage.models import UsageEvent
    from core.storage.usage import USD_TO_CNY

    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    text = (body.text or "").strip()
    if not text:
        raise HTTPException(400, "empty text")
    if len(text) > 4000:
        raise HTTPException(400, "text too long (>4000 chars)")

    with session_scope() as s:
        owned = s.execute(
            select(Task.model_profile)
            .where(Task.task_id == tid, Task.user_id == user_id)
        ).first()
        if owned is None:
            raise HTTPException(404, f"task not found: {tid}")
        task_model_profile = owned.model_profile or ""

    cfg = load_config()
    chosen_profile = task_model_profile or cfg["default_model"]
    try:
        caps = ModelCapabilities.load(chosen_profile, ROOT / cfg["models_dir"])
    except (FileNotFoundError, ValueError) as e:
        raise HTTPException(500, f"invalid task model_profile {chosen_profile!r}: {e}")

    # Downstream-tool context for the meta-prompt: the chat model's display name plus
    # metadata of the selected image / video variant. The variant lists are scanned in
    # full, so if a key appeared twice the last match would win.
    chat_model_display = caps.display_name or chosen_profile

    img_variant = (body.image_model or "").strip()
    image_variant_hint = ""
    if img_variant:
        for key, meta in _list_image_variants():
            if key != img_variant:
                continue
            name = meta.get("display_name") or key
            sz = meta.get("default_size") or "2048x2048"
            image_variant_hint = (
                f"\n下游生图工具:{name}(默认尺寸 {sz},支持中英文 prompt,"
                f"擅长写实/插画/构图描述)。若用户意图涉及画面/封面/插图,"
                f"润色后的文本要给出适合该模型的画面细节(主体/风格/光线/构图)。"
            )

    vid_variant = (body.video_model or "").strip()
    video_variant_hint = ""
    if vid_variant:
        for key, meta in _list_video_variants():
            if key != vid_variant:
                continue
            name = meta.get("display_name") or key
            res = meta.get("default_resolution") or "720p"
            dur = meta.get("default_duration") or 5
            video_variant_hint = (
                f"\n下游生视频工具:{name}(默认 {res} / {dur}s / 16:9)。"
                f"若用户意图涉及视频/动画/动起来,润色后的文本要补全 "
                f"主体在做什么(运动)+ 镜头怎么动 + 场景 + 风格,而非静态画面描述。"
            )

    meta_prompt = (
        f"你的任务是润色用户输入的草稿,使之成为一个清晰、完整、可执行的 prompt。\n"
        f"当前对话模型:{chat_model_display}。{image_variant_hint}{video_variant_hint}\n\n"
        f"规则:\n"
        f"1. 只输出润色后的文本本身,不要任何解释、前后缀、引号、markdown 代码块包裹\n"
        f"2. 保留用户原始语言(中文/英文)\n"
        f"3. 补全模糊点(主体、目标、风格、约束),但不要无中生有改变用户意图\n"
        f"4. 长度合理 — 简短诉求润色后也应当简洁,不要堆砌\n\n"
        f"用户草稿:\n{text}"
    )

    llm = LLM(caps)
    try:
        response = llm.chat(
            messages=[{"role": "user", "content": meta_prompt}],
            tools=None,
        )
    except Exception as e:
        raise HTTPException(502, f"llm call failed: {type(e).__name__}: {e}")

    try:
        optimized = (response.choices[0].message.content or "").strip()
    except Exception:
        raise HTTPException(502, "llm response missing content")
    if not optimized:
        raise HTTPException(502, "llm returned empty optimization")

    usage = getattr(response, "usage", None)
    prompt_tokens = int(getattr(usage, "prompt_tokens", 0) or 0)
    completion_tokens = int(getattr(usage, "completion_tokens", 0) or 0)
    # Cost estimation is best-effort: any failure is recorded as zero cost.
    try:
        from litellm import completion_cost

        raw_cost = completion_cost(completion_response=response)
        cost_usd = Decimal(str(raw_cost)) if raw_cost else Decimal("0")
    except Exception:
        cost_usd = Decimal("0")
    cost_cny = (cost_usd * USD_TO_CNY).quantize(Decimal("0.000001"))

    try:
        with session_scope() as s:
            s.add(UsageEvent(
                user_id=user_id,
                task_id=tid,
                message_id=None,
                kind="prompt_optimize",
                model_profile=chosen_profile,
                units={
                    "tokens_in": prompt_tokens,
                    "tokens_out": completion_tokens,
                    "usd_to_cny": float(USD_TO_CNY),
                    "image_model_hint": img_variant or "",
                    "video_model_hint": vid_variant or "",
                },
                cost_cny=cost_cny,
            ))
    except Exception as e:
        # A metering failure must not block the response. The user needs the polished
        # text; the missing record can be reconciled manually afterwards.
        print(f"[optimize_prompt] usage record failed: {type(e).__name__}: {e}", flush=True)

    return {
        "optimized": optimized,
        "model_profile": chosen_profile,
        "tokens_in": prompt_tokens,
        "tokens_out": completion_tokens,
        "cost_cny": float(cost_cny),
    }
|
||
|
||
# ───────────── SSE events ─────────────
|
||
|
||
@app.get("/v1/tasks/{task_id}/events", tags=["tasks"])
async def stream_events(
    task_id: str,
    user_id: UUID = Depends(require_user),
):
    """SSE stream of the task's live run events (one active run per task, so unambiguous).

    Event types: run_start / llm_start / text / tool_call / tool_result / llm_end /
    cancelled / error / done. ``data`` is a JSON dict with the ``type`` key removed; it
    becomes the SSE event name instead. The stream ends after a ``done`` or ``error``
    event. A ``: ping`` comment is sent after every 30 s without events.

    Raises:
        HTTPException: 404 for a malformed id or a task not owned by the user.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")
    with session_scope() as s:
        _assert_owns_task(s, tid, user_id)
        run_status = s.execute(
            select(Task.run_status).where(Task.task_id == tid)
        ).scalar_one()

    # Reconnect guard: if the task is not active (process restart, reaper already
    # finalized it, or it finished normally), emit done and close right away. Otherwise
    # the broker's in-process queue stays empty and the client would hang on pings
    # forever.
    is_active = run_status in ("running", "cancelling")

    async def gen():
        yield b": connected\nretry: 3000\n\n"
        if not is_active:
            yield _sse_event("done", {})
            return
        q = broker.subscribe(tid)
        try:
            while True:
                try:
                    ev = await asyncio.wait_for(q.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    yield b": ping\n\n"
                    continue
                ev_type = ev.get("type", "msg")
                payload = {k: v for k, v in ev.items() if k != "type"}
                yield _sse_event(ev_type, payload)
                if ev_type in ("done", "error"):
                    break
        finally:
            # Runs on normal completion, on generator close and on cancellation (client
            # disconnect / shutdown), so the broker never keeps a dead subscriber queue.
            # CancelledError is deliberately NOT swallowed: asyncio requires it to
            # propagate so the server's cancel scope / task can finish cancelling.
            broker.unsubscribe(tid, q)

    return StreamingResponse(
        gen(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
|
||
|
||
# ───────────── Files(user-rooted,不绑 task) ─────────────
|
||
|
||
@app.get("/v1/user/storage", tags=["user"])
def user_storage(user_id: UUID = Depends(require_user)):
    """Current user's disk usage and quota.

    The numbers come from ``user_disk_usage``, which a background scan fills in (every
    15 min by default, per the original note), so they are not real-time.
    - Without a scan record (new user / before the first scan), ``bytes_used`` and
      ``file_count`` are 0 and ``scanned_at`` is None.
    - ``limit_bytes`` is None when the configured quota is missing or <= 0 (no limit;
      the frontend then draws no progress bar).
    """
    from core.agent_builder import load_config as _load_cfg
    from core.storage.disk_quota import get_user_usage, parse_bytes

    quotas = _load_cfg().get("quotas") or {}
    quota_limit = parse_bytes(quotas.get("disk_bytes_per_user"))
    snapshot = get_user_usage(user_id)
    bytes_used, file_count, scanned_at = (
        snapshot if snapshot is not None else (0, 0, None)
    )
    return {
        "bytes_used": bytes_used,
        "file_count": file_count,
        "limit_bytes": quota_limit if quota_limit and quota_limit > 0 else None,
        "scanned_at": None if not scanned_at else scanned_at.isoformat(),
    }
|
||
|
||
@app.get("/v1/files", tags=["files"])
def list_files(
    path: str = "",
    user_id: UUID = Depends(require_user),
):
    """List the entries of a directory under user_root, plus breadcrumbs.

    An empty ``path`` means user_root itself. Per the original contract, ``../`` or
    absolute paths give 400 (enforced by ``_safe_join``), and dotfiles such as
    ``.memory/`` are always hidden.
    """
    root = _load_user_root(user_id)
    current = _safe_join(root, path)
    entries, crumbs, exists = _enumerate_files(root, current)
    return dict(
        root=_norm_path(str(root)),
        current=_rel_to(root, current),
        exists=exists,
        crumbs=crumbs,
        entries=entries,
    )
|
||
|
||
@app.get("/v1/files/download", tags=["files"])
def download_file(
    path: str,
    user_id: UUID = Depends(require_user),
):
    """Download one regular file under user_root (missing: 404; not a regular file: 400)."""
    target = _safe_join(_load_user_root(user_id), path)
    if not target.exists():
        raise HTTPException(404, f"file not found: {path}")
    if not target.is_file():
        raise HTTPException(400, f"not a file: {path}")
    # Workspace files are mutable, so turn off browser heuristic caching (RFC 7234
    # allows caching for hours by default). Otherwise the SPA preview would keep
    # showing stale content after a file changes. Starlette's FileResponse does not
    # implement 304, so every request is a full 200; workspace files are small, so
    # that is acceptable.
    no_cache = {"Cache-Control": "no-cache"}
    return FileResponse(path=str(target), filename=target.name, headers=no_cache)
|
||
|
||
@app.get("/v1/files/preview_pdf", tags=["files"])
async def preview_pdf(
    path: str,
    user_id: UUID = Depends(require_user),
):
    """Convert a .pptx/.ppt under user_root to PDF for in-browser preview.

    The frontend reuses its PDF iframe viewer for the result. Conversion runs
    on the backend host (not inside the sandbox). It is triggered on demand
    and cached under `.preview/` (DESIGN §8.3).

    Responses:
      - 404: `path` does not exist
      - 400: not a regular file, or the extension is not .pptx/.ppt
      - 501: soffice is not installed (`SofficeNotFoundError`)
      - 500: conversion failed or timed out (`PptxConvertError`)
    The frontend falls back to a plain download on these errors.
    """
    from .pptx_render import (
        PptxConvertError,
        SofficeNotFoundError,
        pptx_to_pdf,
    )

    root = _load_user_root(user_id)
    target = _safe_join(root, path)
    if not target.exists():
        raise HTTPException(404, f"file not found: {path}")
    if not target.is_file():
        raise HTTPException(400, f"not a file: {path}")
    if target.suffix.lower() not in (".pptx", ".ppt"):
        raise HTTPException(400, f"not a pptx: {path}")

    abs_path = str(target.resolve())
    # Inside a coroutine, get_running_loop() is the idiomatic way to reach the
    # current loop; it never falls back to creating a new one.
    loop = asyncio.get_running_loop()
    # The lock is keyed by the resolved path -- presumably so that concurrent
    # previews of the same file do not run soffice twice; confirm in
    # _pptx_lock_for.
    async with _pptx_lock_for(abs_path):
        try:
            # pptx_to_pdf is synchronous; run it in the default executor so the
            # event loop is not blocked during conversion.
            pdf_path = await loop.run_in_executor(None, pptx_to_pdf, target)
        except SofficeNotFoundError as e:
            raise HTTPException(501, str(e)) from e
        except PptxConvertError as e:
            raise HTTPException(500, str(e)) from e
    return FileResponse(
        path=str(pdf_path),
        media_type="application/pdf",
        headers={"Cache-Control": "no-cache"},
    )
|
||
|
||
@app.post("/v1/files/upload", tags=["files"])
async def upload_files(
    path: str = Form(""),
    files: list[UploadFile] = File(...),
    user_id: UUID = Depends(require_user),
):
    """Multipart upload of one or more files into `<user_root>/<path>/`.

    - Disk quota gate (§7.5 #4): if the user is already over quota -> 413,
      asking them to clean up old artifacts first.
    - `path` that points at an existing file -> 400.
    - Every filename is validated *before* anything is written or created.
      Empty names, `.`/`..`, names containing `/`, `\\` or NUL, and names
      escaping user_root -> 400. A bad name anywhere in the batch therefore
      rejects the whole request instead of leaving earlier files already
      written.
    - Missing target directories are created (`mkdir(parents=True)`); files
      with the same name are overwritten.
    - An OS-level failure while creating the directory or writing a file
      (e.g. the name collides with an existing directory) -> 400. Files
      written before that failure are kept (no FS rollback).

    Returns `{"count": n, "saved": [{"name", "size", "rel"}, ...]}`.
    """
    from core.agent_builder import load_config as _load_cfg
    from core.storage.disk_quota import check_disk_quota, parse_bytes

    _quotas_cfg = (_load_cfg().get("quotas") or {})
    _limit = parse_bytes(_quotas_cfg.get("disk_bytes_per_user"))
    if _limit is not None and _limit > 0:
        _err = check_disk_quota(user_id, _limit)
        if _err is not None:
            raise HTTPException(413, _err)

    root = _load_user_root(user_id)
    dest_dir = _safe_join(root, path)
    if dest_dir.exists() and not dest_dir.is_dir():
        raise HTTPException(400, f"upload target is a file, not a directory: {path}")

    uploads = list(files or [])
    if not uploads:
        raise HTTPException(400, "no files uploaded")

    # Pass 1: validate every name without touching the filesystem.
    root_resolved = root.resolve()
    planned: list[tuple[UploadFile, str, Path]] = []
    for up in uploads:
        raw_name = up.filename or ""
        if (
            not raw_name
            or raw_name in (".", "..")
            or "/" in raw_name or "\\" in raw_name
            # NUL makes every OS path call raise ValueError (-> 500 otherwise).
            or "\x00" in raw_name
            or any(part in (".", "..") for part in Path(raw_name).parts)
        ):
            raise HTTPException(400, f"invalid filename: {raw_name!r}")
        dest = dest_dir / raw_name
        try:
            # Non-strict resolve works even before dest_dir exists.
            dest.resolve().relative_to(root_resolved)
        except ValueError:
            raise HTTPException(400, f"path escapes user_root: {raw_name!r}") from None
        planned.append((up, raw_name, dest))

    # Pass 2: create the directory and write the files.
    try:
        dest_dir.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        raise HTTPException(400, f"cannot create upload directory: {e}") from e

    saved: list[dict] = []
    for up, raw_name, dest in planned:
        data = await up.read()
        try:
            dest.write_bytes(data)
        except OSError as e:
            raise HTTPException(400, f"write failed for {raw_name!r}: {e}") from e
        saved.append({"name": raw_name, "size": len(data), "rel": _rel_to(root, dest)})
    return {"count": len(saved), "saved": saved}
|
||
|
||
@app.post("/v1/files/delete", tags=["files"])
def delete_file(
    body: FileDeleteRequest,
    user_id: UUID = Depends(require_user),
):
    """Delete a file or directory under user_root.

    - `recursive=False` (default): directories must be empty (`rmdir`);
      a non-empty directory -> 400.
    - `recursive=True`: `shutil.rmtree`. If the target is a top-level
      directory still referenced as `working_dir` by a task that is not
      soft-deleted -> 409, asking the caller to DELETE the task first. This
      avoids the DB still pointing at a folder whose files are already gone.
    - Empty directories (top-level or nested) can always be removed, with or
      without `recursive`. The task's `working_dir` column is left untouched,
      and the folder is treated as regenerable: build_agent re-creates it with
      mkdir on demand.
    - Deleting user_root itself -> 400; a missing path -> 404.
    """
    import shutil

    root_dir = _load_user_root(user_id)
    victim = _safe_join(root_dir, body.path)
    if victim.resolve() == root_dir.resolve():
        raise HTTPException(400, "cannot delete user_root")
    if not victim.exists():
        raise HTTPException(404, f"path not found: {body.path}")

    wants_tree = victim.is_dir() and body.recursive
    if wants_tree and victim.parent.resolve() == root_dir.resolve():
        # Refuse rmtree on a top-level folder that a live task still uses.
        working_dir_form = to_db_path(victim)
        with session_scope() as s:
            ref_count = s.execute(
                select(func.count()).select_from(Task).where(
                    Task.user_id == user_id,
                    Task.working_dir == working_dir_form,
                    Task.deleted_at.is_(None),  # soft-deleted tasks no longer count as references
                )
            ).scalar_one() or 0
            if ref_count:
                raise HTTPException(
                    409,
                    f"该顶层目录正被 {ref_count} 个 task 引用,不能递归删除;"
                    f"请先 DELETE task,再清残留文件",
                )

    try:
        if not victim.is_dir():
            victim.unlink()
        elif body.recursive:
            shutil.rmtree(victim)
        else:
            victim.rmdir()  # raises OSError when the directory is not empty
    except OSError as e:
        raise HTTPException(400, f"delete failed: {e}")
    return {"ok": True, "path": body.path}
|
||
|
||
@app.post("/v1/files/rename", tags=["files"])
def rename_path(
    body: FileRenameRequest,
    user_id: UUID = Depends(require_user),
):
    """Rename a file or directory (at any depth) under user_root.

    - `path` is required and names the entry to rename; user_root -> 400.
    - `new_name` is the new leaf name, validated by `validate_task_name`
      (non-empty, no `/`, `\\` or `..`, not a dotfile, at most 255 chars).
      It is not a path: the parent is always taken from `path`.
    - The sibling `<parent>/<new_name>` must not exist yet (no overwrite) -> 409.
    - If `path` is a top-level directory (direct child of user_root), the
      rename is DB-aware:
      * In one transaction, `SELECT ... FOR UPDATE` locks the tasks whose
        working_dir is that folder. If any has run_status running/cancelling
        -> 409, since a background run would keep the old path while the DB
        already points at the new one.
      * `check_no_subtask(new_db, exclude=<those task ids>)` rejects a new
        name that would nest with another task's folder -> 409.
      * `UPDATE tasks SET working_dir=new_db` for those tasks runs first,
        then the FS rename. If the rename fails, the raised HTTPException
        makes `session_scope` roll the UPDATE back.
      * The only remaining inconsistency window is "FS renamed, then commit
        fails", which is rare for a single PG transaction.
    - Anything else (nested directory or file) -> plain FS rename, DB untouched.

    Returns `{"ok", "old", "new", "tasks_updated"}`.
    """
    from core.agent_builder import InvalidTaskName, validate_task_name

    root = _load_user_root(user_id)
    src = _safe_join(root, body.path)
    if src.resolve() == root.resolve():
        raise HTTPException(400, "cannot rename user_root")
    if not src.exists():
        raise HTTPException(404, f"path not found: {body.path}")

    try:
        new_name = validate_task_name(body.new_name)
    except InvalidTaskName as e:
        raise HTTPException(400, f"new_name 不合法: {e}")
    if new_name == src.name:
        raise HTTPException(400, f"new_name 与原名相同: {new_name!r}")

    dst = src.parent / new_name
    if dst.exists():
        raise HTTPException(409, f"target already exists: {_rel_to(root, dst)!r}")

    def _result(updated: int) -> dict:
        # Built only after the rename so `new` reflects the final location.
        return {
            "ok": True,
            "old": body.path,
            "new": _rel_to(root, dst),
            "tasks_updated": updated,
        }

    if not (src.is_dir() and src.parent.resolve() == root.resolve()):
        # Nested entry or file: pure filesystem rename.
        try:
            src.rename(dst)
        except OSError as e:
            raise HTTPException(400, f"rename failed: {e}")
        return _result(0)

    # Top-level directory: keep tasks.working_dir in sync inside one transaction.
    old_db, new_db = to_db_path(src), to_db_path(dst)
    with session_scope() as s:
        locked = s.execute(
            select(Task.task_id, Task.run_status)
            .where(Task.user_id == user_id, Task.working_dir == old_db)
            .with_for_update()
        ).all()
        tids = [row.task_id for row in locked]
        busy = [
            str(row.task_id)[:8]
            for row in locked
            if row.run_status in ("running", "cancelling")
        ]
        if busy:
            raise HTTPException(
                409,
                f"folder has active run(s) on task(s) {busy}; "
                f"cancel before renaming",
            )
        try:
            check_no_subtask(new_db, user_id=user_id, exclude_task_ids=tids)
        except NoSubtaskError as e:
            raise HTTPException(409, str(e))
        if tids:
            s.execute(
                update(Task)
                .where(Task.task_id.in_(tids))
                .values(working_dir=new_db)
            )
        try:
            src.rename(dst)
        except OSError as e:
            # Raising here also sends session_scope down its rollback path.
            raise HTTPException(400, f"FS rename failed: {e}")

    return _result(len(tids))
|
||
|
||
@app.post("/v1/files/copy", tags=["files"])
def copy_files(
    body: FileTransferRequest,
    user_id: UUID = Depends(require_user),
):
    """Copy each of `paths` to `dest_dir/<name>` (directories recursively).

    - No overwriting (any existing target -> 409) and no copying onto itself
      or into its own subtree. Both are pre-checked by `_validate_transfer`
      before anything is copied.
    - Top-level directories (possibly some task's working_dir) may be copied.
      The copy has no task association and the DB is not touched.
    - Symlinks are copied as symlinks and never dereferenced. This endpoint
      runs on the backend host over workspace contents that may have been
      written by agent code. Following a link that points outside user_root
      would materialize outside content as a regular file inside the
      workspace.
    - Partial failure: if any FS copy raises, an HTTPException is raised and
      copies that already succeeded are kept. There is no FS transaction to
      roll back; after the pre-check, failures are typically disk-full or
      permission errors that cannot be recovered anyway.
    """
    import shutil

    root = _load_user_root(user_id)
    sources, dest = _validate_transfer(root, body.paths, body.dest_dir)
    transferred: list[dict] = []
    for src in sources:
        target = dest / src.name
        try:
            if src.is_dir() and not src.is_symlink():
                # symlinks=True: links inside the tree are recreated as links.
                shutil.copytree(src, target, symlinks=True)
            else:
                # follow_symlinks=False: a symlink source is copied as a link.
                shutil.copy2(src, target, follow_symlinks=False)
        except OSError as e:  # shutil.Error is an OSError subclass
            raise HTTPException(
                500,
                f"copy failed at {src.name!r}: {e} "
                f"(已成功 {len(transferred)} 项,剩余未处理)",
            ) from e
        transferred.append({
            "old": _rel_to(root, src),
            "new": _rel_to(root, target),
        })
    return {"ok": True, "count": len(transferred), "transferred": transferred}
|
||
|
||
@app.post("/v1/files/move", tags=["files"])
def move_files(
    body: FileTransferRequest,
    user_id: UUID = Depends(require_user),
):
    """Move each of `paths` to `dest_dir/<name>`.

    - Same no-overwrite / no-self-nesting rules as `/copy`.
    - A top-level directory that is the working_dir of a task that is not
      soft-deleted -> 409. This keeps the "working_dir is a top-level
      directory" invariant: if a working_dir could sink into a subdirectory,
      the DB-aware top-level handling in `/rename` would no longer apply. To
      archive such a folder, DELETE the task first.
    - Soft-deleted tasks (`deleted_at` set) are ignored, matching `/delete`.
      Otherwise a folder left behind after DELETE task (presumably a soft
      delete -- confirm) could never be moved despite the 409 hint.
    - `/copy` has no such restriction because the copy has no task association.
    - Partial failure: as with `/copy`, moves that already succeeded are not
      rolled back. `shutil.move` essentially only fails when a cross-device
      copy is interrupted, which is rare since the workspace sits on one disk.
    """
    import shutil

    root = _load_user_root(user_id)
    sources, dest = _validate_transfer(root, body.paths, body.dest_dir)

    # Gate: top-level directory that is some live task's working_dir.
    root_resolved = root.resolve()
    top_level_dirs = [
        src for src in sources
        if src.is_dir() and src.parent.resolve() == root_resolved
    ]
    if top_level_dirs:
        # DB form of working_dir -> folder name, for the error message.
        form2name = {to_db_path(d): d.name for d in top_level_dirs}
        with session_scope() as session:
            rows = session.execute(
                select(Task.working_dir, func.count())
                .where(
                    Task.user_id == user_id,
                    Task.working_dir.in_(list(form2name)),
                    Task.deleted_at.is_(None),  # soft-deleted tasks no longer count as references
                )
                .group_by(Task.working_dir)
            ).all()
            if rows:
                names = ", ".join(
                    f"{form2name[wd]!r}({n} 个 task)"
                    for wd, n in rows
                )
                raise HTTPException(
                    409,
                    f"以下顶层目录正被 task 引用,不能移动:{names};"
                    f"请先删 task,或改用复制",
                )

    transferred: list[dict] = []
    for src in sources:
        target = dest / src.name
        try:
            shutil.move(str(src), str(target))
        except OSError as e:  # shutil.Error is an OSError subclass
            raise HTTPException(
                500,
                f"move failed at {src.name!r}: {e} "
                f"(已成功 {len(transferred)} 项,剩余未处理)",
            ) from e
        transferred.append({
            "old": _rel_to(root, src),
            "new": _rel_to(root, target),
        })
    return {"ok": True, "count": len(transferred), "transferred": transferred}
|
||
|
||
# ───────────── Export ─────────────
|
||
|
||
@app.get("/v1/tasks/{task_id}/export", tags=["export"])
def export_task(task_id: str, user_id: UUID = Depends(require_user)):
    """Export a task's conversation as a .docx download.

    The document is written to a temp file. A BackgroundTask deletes that
    file after the response has been sent.

    Errors:
      - 404: malformed task id
      - 400: the task has no messages
      - 500: the exporter raises (the temp file is removed first)
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    with session_scope() as s:
        _assert_owns_task(s, tid, user_id)
        first_msg = s.execute(
            select(Message.message_id).where(Message.task_id == tid).limit(1)
        ).first()
        if not first_msg:
            raise HTTPException(400, "no messages to export")

    fd, tmp_name = tempfile.mkstemp(suffix=".docx", prefix="zcbot-export-")
    os.close(fd)  # only the path is handed to the exporter, not the descriptor
    docx_path = Path(tmp_name)
    try:
        from core.export_docx import export_chat_to_docx
        export_chat_to_docx(tid, out_path=docx_path)
    except Exception as e:
        docx_path.unlink(missing_ok=True)
        raise HTTPException(500, f"export failed: {type(e).__name__}: {e}")

    docx_mime = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    cleanup = BackgroundTask(docx_path.unlink, missing_ok=True)
    return FileResponse(
        path=str(docx_path),
        media_type=docx_mime,
        filename=f"chat_{str(tid)[:8]}.docx",
        background=cleanup,
    )
|
||
|
||
# ───────────── 管理后台(admin-only)─────────────
|
||
register_admin_routes(app, require_admin)
|
||
|
||
return app
|