"""FastAPI app: 纯 /v1 JSON API(2026-05-15 切换 — 详见 DESIGN §7.9)。 设计要点: - 所有路由 `/v1/*` 前缀,响应 JSON;模板 / HTMX / 服务端 markdown 渲染全删 - SSE 事件 payload 是 JSON dict 而非 HTML 片段(`event: ` + `data: `) - Auth: PLATFORM_KEY → JWT 兑换(§7 D' 过渡形态,见 web/auth.py);OIDC 替换时只动 /v1/auth/login 内部 - 所有 /v1/tasks* 路由 Depends(require_user),按 user_id 隔离数据 - 豁免:/healthz、/docs、/openapi.json、/、/v1/auth/login、/static/* - CORS allow_origins=["*"] 本地宽松;真发布按 platform 域名收紧 - `GET /` 302 → /static/dev.html(本地 dev SPA) """ from __future__ import annotations import asyncio import json import os import tempfile from contextlib import asynccontextmanager from datetime import datetime as _dt from pathlib import Path from typing import Any, Optional from uuid import UUID, uuid4 from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse, RedirectResponse, StreamingResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel from sqlalchemy import func, select, update from starlette.background import BackgroundTask from core.paths import to_db_path from core.storage import ( NoSubtaskError, check_no_subtask, session_scope, ) from core.storage.models import Message, Run, Task from core.storage.utils import ensure_local_task_row from .auth import AuthConfig, ensure_user_row, make_require_user, mint_token from .broker import broker from .sinks import WebEventSink STATUS_FILTERS = ("active", "completed", "abandoned") STATUS_WRITABLE = ("completed", "abandoned") # web 不让从 web 端切回 active(走 CLI) ORDER_FIELDS = ("created_at", "updated_at", "name", "status") ORDER_DEFAULT = "-created_at" # ─────────────────────────── helpers ─────────────────────────── def _norm_path(p: str) -> str: """跨 OS 显示归一:backslash → forward slash。""" return (p or "").replace("\\", "/") def _iso(dt: Optional[Any]) -> Optional[str]: return dt.isoformat() if dt else None def _parse_ordering(s: Optional[str]) -> list: """DRF 风格 `ordering` 解析:逗号分隔多字段,`-` 前缀代表 desc。 allowlist 见 `ORDER_FIELDS`;非法字段静默丢弃。全部非法或空串 → `ORDER_DEFAULT`(`-created_at`)。 返回 sqlalchemy `order_by` 列表(可直接 `*expand`)。 """ spec = (s or "").strip() or ORDER_DEFAULT cols = [] for part in spec.split(","): p = part.strip() if not p: continue asc = True if p.startswith("-"): asc = False p = p[1:] if p in ORDER_FIELDS: col = getattr(Task, p) cols.append(col.asc() if asc else col.desc()) if not cols: # 用户传了全无效字段 → fallback 默认 cols = [Task.created_at.desc()] return cols def _task_dict(row: Any, *, n_messages: Optional[int] = None) -> dict: """Task ORM row → API JSON dict。""" d = { "task_id": str(row.task_id), "name": row.name or "", "description": row.description or "", "working_dir": _norm_path(row.working_dir or ""), "status": row.status, "skill": row.skill or "", "model": row.model or "", "model_profile": row.model_profile or "", "tokens_prompt": row.tokens_prompt or 0, "tokens_completion": row.tokens_completion or 0, "tokens": (row.tokens_prompt or 0) + (row.tokens_completion or 0), "created_at": _iso(getattr(row, "created_at", None)), "updated_at": _iso(getattr(row, "updated_at", None)), } if n_messages is not None: d["n_messages"] = n_messages return d # ─────────────────────── files helpers ─────────────────────── def _load_user_root(user_id: UUID) -> Path: """user_root = `/users//`,所有 files API 的边界。 若目录尚未存在自动 mkdir(空 user 首次访问也能拿到根)。 """ from main import resolve_workspace, user_root ws = resolve_workspace(None) return user_root(ws, user_id) def _safe_join(root: Path, rel: str) -> Path: """归一用户路径到 absolute,并校验仍在 root 内。防 `../` / 绝对 path / symlink 越界。""" rel = (rel or "").strip() if not rel: return root.resolve() if rel[0] in ("/", "\\"): raise HTTPException(400, f"absolute-style path not allowed: {rel!r}") if Path(rel).is_absolute(): raise HTTPException(400, f"absolute path not allowed: {rel!r}") target = (root / rel).resolve() try: target.relative_to(root.resolve()) except ValueError: raise HTTPException(400, f"path escapes user_root: {rel!r}") return target def _rel_to(root: Path, target: Path) -> str: try: rel = target.resolve().relative_to(root.resolve()).as_posix() except ValueError: return "" return "" if rel == "." else rel def _enumerate_files(root: Path, current: Path) -> tuple[list[dict], list[dict], bool]: """枚举 current 下条目 + 拼面包屑。size raw bytes,mtime ISO 串(前端 humanize)。 Dotfile 一律隐藏(`.memory/` 等系统区不暴露给 UI,同 `/v1/folders` 约定)。 """ entries: list[dict] = [] exists = current.exists() if exists and current.is_dir(): try: raw = sorted(current.iterdir(), key=lambda p: (p.is_file(), p.name.lower())) except OSError: raw = [] for p in raw: if p.name.startswith("."): continue try: st = p.stat() except OSError: continue entries.append({ "name": p.name, "is_dir": p.is_dir(), "size": st.st_size if p.is_file() else None, "mtime": _dt.fromtimestamp(st.st_mtime).isoformat(timespec="seconds"), "rel": _rel_to(root, p), }) cur_rel = _rel_to(root, current) crumbs = [{"label": "/", "rel": ""}] # cur_rel == "." 表示当前就在 root(target.relative_to(root) 返 Path(".")), # 不该再追加一个无意义的 "." crumb if cur_rel and cur_rel != ".": acc = "" for part in cur_rel.split("/"): acc = f"{acc}/{part}" if acc else part crumbs.append({"label": part, "rel": acc}) return entries, crumbs, exists # ─────────────────── Run 启动 + SSE 帧格式 ─────────────────── def _run_agent_bg(task_id: UUID, run_id: UUID, user_id: UUID, user_message: str) -> None: """工作线程:`build_agent(resume=True)` → 装 WebEventSink → `agent.run` → 写 runs 状态。 sink 通过 broker.emit 桥事件回 asyncio loop;agent.run 是 sync,所以在 to_thread 跑。 user_id 必须从 JWT 那侧透传过来 —— 决定 memory_block 读哪个 per-user 子树。 """ from main import build_agent, sync_task_tokens try: broker.emit(run_id, {"type": "run_start"}) agent, session, sid, task_state, task_dir = build_agent( session_id=str(task_id), resume=True, user_id=user_id, ) agent.sink = WebEventSink(broker, run_id) agent.run(user_message) sync_task_tokens(task_state, agent.llm) with session_scope() as s: s.execute( update(Run).where(Run.run_id == run_id).values( status="ok", finished_at=func.now(), tokens_p=agent.llm.token_counter.prompt_tokens, tokens_c=agent.llm.token_counter.completion_tokens, ) ) except Exception as e: err = f"{type(e).__name__}: {e}" broker.emit(run_id, {"type": "error", "msg": err}) try: with session_scope() as s: s.execute( update(Run).where(Run.run_id == run_id).values( status="error", error=err, finished_at=func.now() ) ) except Exception: pass # 已 emit error 给前端,DB 写失败不放大噪声 finally: broker.close(run_id) def _sse_event(event_type: str, payload: dict) -> bytes: """格式化 SSE 一帧:`event: ` + `data: `。""" body = json.dumps(payload, ensure_ascii=False, separators=(",", ":")) return f"event: {event_type}\ndata: {body}\n\n".encode("utf-8") # ────────────────────── Pydantic 请求体 ────────────────────── class TaskCreateRequest(BaseModel): name: str # 任务显示名(必填,DB 列 NOT NULL) working_dir: str = "" # 工作目录名(可选,留空 → 用 name 作目录名) description: str = "" skill: str = "" class TaskPatchRequest(BaseModel): status: Optional[str] = None description: Optional[str] = None name: Optional[str] = None skill: Optional[str] = None class MessageRequest(BaseModel): content: str class FileDeleteRequest(BaseModel): path: str class LoginRequest(BaseModel): user_id: str platform_key: str # ────────────────────── App 工厂 ────────────────────── # web/static 目录路径 — /static 静态挂载用,dev.html 也放这 _STATIC_DIR = Path(__file__).parent / "static" def create_app() -> FastAPI: # fail-fast:env 缺失直接抛,不裸跑无密 auth_cfg = AuthConfig.from_env() require_user = make_require_user(auth_cfg) @asynccontextmanager async def lifespan(app: FastAPI): broker.bind_loop(asyncio.get_running_loop()) # Stale-run reaper:上次进程 crash 留下的 "running" 行已无 BG 线程继续, # 启动时标 error,让对应 task 重新可发消息(否则 409 gate 永挂)。 # TODO 真生产 multi-worker:换 heartbeat / lease,只 reap 自家 worker 的孤儿。 with session_scope() as s: result = s.execute( update(Run) .where(Run.status == "running") .values( status="error", error="server restarted before run finished", finished_at=func.now(), ) ) if result.rowcount: print(f"[startup] reaped {result.rowcount} stale running run(s)") yield app = FastAPI( title="zcbot api", version="0.8", description=( "zcbot 后端 — /v1 JSON API + SSE。Auth: PLATFORM_KEY → JWT(§7 D' 过渡)。" "本地 dev SPA: /static/dev.html。" ), lifespan=lifespan, ) app.add_middleware( CORSMiddleware, allow_origins=["*"], # 本地宽松,部署 platform 时按域名收紧 allow_credentials=False, allow_methods=["*"], allow_headers=["*"], ) if _STATIC_DIR.is_dir(): app.mount("/static", StaticFiles(directory=str(_STATIC_DIR)), name="static") # ───────────── Misc ───────────── @app.get("/", include_in_schema=False) def root(): # 本地 dev SPA;Swagger UI 仍在 /docs return RedirectResponse(url="/static/dev.html", status_code=302) @app.get("/healthz", tags=["misc"]) def healthz(): return {"status": "ok"} # ───────────── Auth ───────────── @app.post("/v1/auth/login", tags=["auth"]) def login(body: LoginRequest): """platform_key 校验通过 → 签 JWT(user_id 作为 sub)。 platform_key 错 → 403;user_id 非 UUID → 400。 user_id 未存在则幂等创建 users 行(避免下游 FK 失败)。 """ if body.platform_key != auth_cfg.platform_key: raise HTTPException(403, "invalid platform_key") try: uid = UUID(body.user_id) except (ValueError, TypeError): raise HTTPException(400, f"invalid user_id (must be UUID): {body.user_id!r}") ensure_user_row(uid) token, exp = mint_token(auth_cfg, uid) return { "token": token, "expires_at": _dt.fromtimestamp(exp).isoformat(), "user_id": str(uid), "ttl_seconds": auth_cfg.ttl_seconds, } # ───────────── Tasks CRUD ───────────── @app.post("/v1/tasks", status_code=201, tags=["tasks"]) def create_task(body: TaskCreateRequest, user_id: UUID = Depends(require_user)): """新建 task。 - `name` 必填(任务显示名,DB 列 NOT NULL,UI 列表 / 标题用) - `working_dir` 可选(留空 → 用 name 作目录名);同 working_dir 多 task 共享同目录(§7.1) - name / working_dir 都过 validate_task_name(简单名,无 `/\\..`,非 `.` 起头,≤255) - 前缀嵌套(no-subtask,同 user 内)→ 409 """ from main import InvalidTaskName, resolve_workspace, validate_task_name, working_dir_from_name try: name = validate_task_name(body.name) except InvalidTaskName as e: raise HTTPException(400, f"name 不合法: {e}") # working_dir 留空 → fallback 用 name wd_raw = (body.working_dir or "").strip() wd_name = wd_raw if wd_raw else name try: wd_name = validate_task_name(wd_name) except InvalidTaskName as e: raise HTTPException(400, f"working_dir 不合法: {e}") description = body.description.strip() skill = body.skill.strip() tid = uuid4() ws = resolve_workspace(None) fs_dir = working_dir_from_name(ws, user_id, wd_name) fs_dir_db = to_db_path(fs_dir) try: check_no_subtask(fs_dir_db, user_id=user_id) except NoSubtaskError as e: raise HTTPException(409, str(e)) # 工作目录立刻建出(同 working_dir 多 task 共享,exist_ok=True) fs_dir.mkdir(parents=True, exist_ok=True) ensure_local_task_row( task_id=tid, name=name, working_dir=fs_dir_db, skill=skill, description=description, user_id=user_id, ) with session_scope() as s: row = s.execute(select(Task).where(Task.task_id == tid)).scalar_one() return _task_dict(row, n_messages=0) @app.get("/v1/tasks", tags=["tasks"]) def list_tasks_route( page: int = 1, page_size: int = 20, status: Optional[str] = None, skill: Optional[str] = None, working_dir: Optional[str] = None, q: Optional[str] = None, ordering: Optional[str] = None, user_id: UUID = Depends(require_user), ): """列出当前 user 的 task,分页 + 多维筛选 + 排序。 - `page` ≥ 1(1-based);`page_size` 1–100(超界 clamp) - `status` 在 active/completed/abandoned;非法值静默忽略 - `skill` 精确匹配(空忽略) - `working_dir` 末段目录名(如 `水泥申报`);后端自动拼 `workspace/users//` 比对 - `q` 模糊搜索 name + description(ILIKE,大小写不敏感) - `ordering` DRF 风格,逗号分隔,`-field` 倒序;allowlist `created_at/updated_at/name/status`; 非法字段静默忽略;**默认 `-created_at`**(创建时间倒序) 返回标准分页壳 `{page, page_size, count, results}` —— count 供前端算总页数。 """ # clamp + sanitize page = max(1, page) page_size = max(1, min(page_size, 100)) status = status if status in STATUS_FILTERS else None skill = (skill or "").strip() or None wd_name = (working_dir or "").strip() or None q_text = (q or "").strip() or None # 组装 WHERE conditions = [Task.user_id == user_id] if status: conditions.append(Task.status == status) if skill: conditions.append(Task.skill == skill) if wd_name: # 末段 → 完整 db form。同 working_dir 多 task 共享时,这是命中入口。 wd_db = f"workspace/users/{user_id}/{wd_name}" conditions.append(Task.working_dir == wd_db) if q_text: pat = f"%{q_text}%" conditions.append(Task.name.ilike(pat) | Task.description.ilike(pat)) offset = (page - 1) * page_size with session_scope() as s: cnt = s.execute( select(func.count()).select_from(Task).where(*conditions) ).scalar_one() or 0 rows = s.execute( select(Task).where(*conditions) .order_by(*_parse_ordering(ordering)) .limit(page_size).offset(offset) ).scalars().all() tids = [r.task_id for r in rows] msg_counts = ( dict(s.execute( select(Message.task_id, func.count()) .where(Message.task_id.in_(tids)) .group_by(Message.task_id) ).all()) if tids else {} ) return { "page": page, "page_size": page_size, "count": int(cnt), "results": [ _task_dict(r, n_messages=msg_counts.get(r.task_id, 0)) for r in rows ], } @app.get("/v1/tasks/{task_id}", tags=["tasks"]) def get_task(task_id: str, user_id: UUID = Depends(require_user)): """单 task meta(不含 messages;走 /messages 拿)。跨 user → 404。""" try: tid = UUID(task_id) except ValueError: raise HTTPException(404, f"invalid task id: {task_id!r}") with session_scope() as s: row = s.execute( select(Task).where(Task.task_id == tid, Task.user_id == user_id) ).scalar_one_or_none() if row is None: raise HTTPException(404, f"task not found: {tid}") n = s.execute( select(func.count()).select_from(Message).where(Message.task_id == tid) ).scalar_one() return _task_dict(row, n_messages=n) @app.get("/v1/folders", tags=["folders"]) def list_folders(user_id: UUID = Depends(require_user)): """列出当前 user 的工作目录(`workspace/users//` 下非 dotfile 子目录)。 供新建 task 时自动补全 / 选已有目录用。FS 是 source of truth(也含手动创建 但还无关联 task 的目录)。每项带 n_tasks(关联 task 数)+ last_used(最近使用 ISO)。 排序:有 last_used 的按降序,无 last_used 的排最后,同列 by name asc。 """ from main import resolve_workspace, user_root ws = resolve_workspace(None) root = user_root(ws, user_id) folder_names: list[str] = [] if root.is_dir(): for p in sorted(root.iterdir(), key=lambda x: x.name.lower()): if p.is_dir() and not p.name.startswith("."): folder_names.append(p.name) folders: list[dict] = [] if folder_names: with session_scope() as s: for name in folder_names: db_form = f"workspace/users/{user_id}/{name}" stat = s.execute( select(func.count(), func.max(Task.updated_at)) .where(Task.user_id == user_id, Task.working_dir == db_form) ).first() n = int((stat[0] if stat else 0) or 0) lu = stat[1] if stat else None folders.append({ "name": name, "n_tasks": n, "last_used": _iso(lu), }) folders.sort(key=lambda f: f["name"]) folders.sort(key=lambda f: f["last_used"] or "", reverse=True) return {"folders": folders} @app.delete("/v1/tasks/{task_id}", status_code=204, tags=["tasks"]) def delete_task(task_id: str, user_id: UUID = Depends(require_user)): """硬删除:DELETE DB 行(messages / runs CASCADE)。**FS task_dir 不动** (同 name 多 task 共享,文件由用户经 /files/delete 单独清)。跨 user → 404。 """ try: tid = UUID(task_id) except ValueError: raise HTTPException(404, f"invalid task id: {task_id!r}") from sqlalchemy import delete as _delete with session_scope() as s: result = s.execute( _delete(Task).where(Task.task_id == tid, Task.user_id == user_id) ) if result.rowcount == 0: raise HTTPException(404, f"task not found: {tid}") return None # 204 @app.patch("/v1/tasks/{task_id}", tags=["tasks"]) def patch_task( task_id: str, body: TaskPatchRequest, user_id: UUID = Depends(require_user), ): """更新 task 字段。`status` 仅允许 completed/abandoned(active 走 CLI 切回)。""" try: tid = UUID(task_id) except ValueError: raise HTTPException(404, f"invalid task id: {task_id!r}") updates: dict[str, Any] = {} if body.status is not None: if body.status not in STATUS_WRITABLE: raise HTTPException( 400, f"invalid status {body.status!r}; allowed: {STATUS_WRITABLE}" ) updates["status"] = body.status if body.description is not None: updates["description"] = body.description if body.skill is not None: updates["skill"] = body.skill if body.name is not None: from main import InvalidTaskName, validate_task_name try: updates["name"] = validate_task_name(body.name) except InvalidTaskName as e: raise HTTPException(400, f"name 不合法: {e}") if not updates: raise HTTPException(400, "no fields to update") with session_scope() as s: result = s.execute( update(Task) .where(Task.task_id == tid, Task.user_id == user_id) .values(**updates) ) if result.rowcount == 0: raise HTTPException(404, f"task not found: {tid}") row = s.execute(select(Task).where(Task.task_id == tid)).scalar_one() n = s.execute( select(func.count()).select_from(Message).where(Message.task_id == tid) ).scalar_one() return _task_dict(row, n_messages=n) # ───────────── Messages ───────────── def _assert_owns_task(s, tid: UUID, user_id: UUID) -> None: ok = s.execute( select(Task.task_id).where(Task.task_id == tid, Task.user_id == user_id) ).first() if ok is None: raise HTTPException(404, f"task not found: {tid}") @app.get("/v1/tasks/{task_id}/messages", tags=["messages"]) def list_messages(task_id: str, user_id: UUID = Depends(require_user)): """task 历史消息(idx 升序);LiteLLM 原 payload 透传给前端,自行渲染。""" try: tid = UUID(task_id) except ValueError: raise HTTPException(404, f"invalid task id: {task_id!r}") with session_scope() as s: _assert_owns_task(s, tid, user_id) rows = s.execute( select( Message.idx, Message.payload, Message.tokens_in, Message.tokens_out, Message.created_at, ).where(Message.task_id == tid).order_by(Message.idx) ).all() return { "messages": [ { "idx": r.idx, "payload": dict(r.payload), "tokens_in": r.tokens_in, "tokens_out": r.tokens_out, "created_at": _iso(r.created_at), } for r in rows ] } @app.post("/v1/tasks/{task_id}/messages", status_code=202, tags=["messages"]) async def post_message( task_id: str, body: MessageRequest, user_id: UUID = Depends(require_user), ): """发消息 + 起 BG run。返 `{run_id, events_url}`,客户端立刻订阅 SSE 拿流式。 同 task 单活 run:`SELECT … FOR UPDATE` 锁 task 行 + 活跃 Run 检查,把所有权 / 活跃 / 插新 run 收进一个事务,挡住"用户连点 send 两条消息"导致两个 BG 线程 争 `messages.idx`(UniqueConstraint 会 race-crash)。已有 running run → 409。 """ try: tid = UUID(task_id) except ValueError: raise HTTPException(404, f"invalid task id: {task_id!r}") content = (body.content or "").strip() if not content: raise HTTPException(400, "empty content") run_id = uuid4() with session_scope() as s: owned = s.execute( select(Task.task_id) .where(Task.task_id == tid, Task.user_id == user_id) .with_for_update() ).first() if owned is None: raise HTTPException(404, f"task not found: {tid}") active = s.execute( select(Run.run_id) .where(Run.task_id == tid, Run.status == "running") .limit(1) ).scalar_one_or_none() if active is not None: raise HTTPException( 409, f"task already has a running run ({active}); wait for it to finish", ) s.add(Run(run_id=run_id, task_id=tid, status="running", started_at=func.now())) # commit 后 lock 释放;BG 线程接管(sink 通过 broker 把 event 桥回 asyncio loop) asyncio.create_task(asyncio.to_thread(_run_agent_bg, tid, run_id, user_id, content)) return { "run_id": str(run_id), "events_url": f"/v1/tasks/{tid}/runs/{run_id}/events", } # ───────────── SSE events ───────────── @app.get("/v1/tasks/{task_id}/runs/{run_id}/events", tags=["runs"]) async def stream_events( task_id: str, run_id: str, user_id: UUID = Depends(require_user), ): """SSE 流。事件类型:run_start / llm_start / text / tool_call / tool_result / llm_end / error / done。data 是 JSON dict(已剔除 `type` 字段,移到 event 名)。 """ try: tid = UUID(task_id) rid = UUID(run_id) except ValueError: raise HTTPException(404, "invalid id") with session_scope() as s: _assert_owns_task(s, tid, user_id) async def gen(): q = broker.subscribe(rid) try: # 第一帧 retry 注释 + 心跳:让 EventSource 立即建立(不被 buffer 卡) yield b": connected\nretry: 3000\n\n" while True: try: ev = await asyncio.wait_for(q.get(), timeout=30.0) except asyncio.TimeoutError: yield b": ping\n\n" continue ev_type = ev.get("type", "msg") payload = {k: v for k, v in ev.items() if k != "type"} yield _sse_event(ev_type, payload) if ev_type in ("done", "error"): break except asyncio.CancelledError: pass # 客户端断开,静默退 finally: broker.unsubscribe(rid, q) return StreamingResponse( gen(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", }, ) # ───────────── Files(user-rooted,不绑 task) ───────────── @app.get("/v1/files", tags=["files"]) def list_files( path: str = "", user_id: UUID = Depends(require_user), ): """列 user_root 下子目录条目 + 面包屑。`path` 留空 → user_root; `../` / 绝对 → 400。dotfile(`.memory/` 等)一律隐藏。 """ root = _load_user_root(user_id) current = _safe_join(root, path) entries, crumbs, exists = _enumerate_files(root, current) return { "root": _norm_path(str(root)), "current": _rel_to(root, current), "exists": exists, "crumbs": crumbs, "entries": entries, } @app.get("/v1/files/download", tags=["files"]) def download_file( path: str, user_id: UUID = Depends(require_user), ): """下载 user_root 下单个 regular file(目录 → 400 / 不存在 → 404)。""" root = _load_user_root(user_id) target = _safe_join(root, path) if not target.exists(): raise HTTPException(404, f"file not found: {path}") if not target.is_file(): raise HTTPException(400, f"not a file: {path}") return FileResponse(path=str(target), filename=target.name) @app.post("/v1/files/upload", tags=["files"]) async def upload_files( path: str = Form(""), files: list[UploadFile] = File(...), user_id: UUID = Depends(require_user), ): """multipart 多文件上传到 `//`。 路径不存在自动 mkdir(parents=True);重名直接覆盖。 文件名严格校验(含 `/ \\ ..` 或为空 → 400)。 """ root = _load_user_root(user_id) dest_dir = _safe_join(root, path) if dest_dir.exists() and not dest_dir.is_dir(): raise HTTPException(400, f"upload target is a file, not a directory: {path}") dest_dir.mkdir(parents=True, exist_ok=True) saved: list[dict] = [] for up in files or []: raw_name = up.filename or "" if ( not raw_name or raw_name in (".", "..") or "/" in raw_name or "\\" in raw_name or any(part in (".", "..") for part in Path(raw_name).parts) ): raise HTTPException(400, f"invalid filename: {raw_name!r}") dest = dest_dir / raw_name try: dest.resolve().relative_to(root.resolve()) except ValueError: raise HTTPException(400, f"path escapes user_root: {raw_name!r}") data = await up.read() dest.write_bytes(data) saved.append({"name": raw_name, "size": len(data), "rel": _rel_to(root, dest)}) if not saved: raise HTTPException(400, "no files uploaded") return {"count": len(saved), "saved": saved} @app.post("/v1/files/delete", tags=["files"]) def delete_file( body: FileDeleteRequest, user_id: UUID = Depends(require_user), ): """删 user_root 下文件或**空**目录。非空目录 → 400(避免误操);root → 400。""" root = _load_user_root(user_id) target = _safe_join(root, body.path) if target.resolve() == root.resolve(): raise HTTPException(400, "cannot delete user_root") if not target.exists(): raise HTTPException(404, f"path not found: {body.path}") try: if target.is_dir(): target.rmdir() # 非空目录会触发 OSError else: target.unlink() except OSError as e: raise HTTPException(400, f"delete failed: {e}") return {"ok": True, "path": body.path} # ───────────── Export ───────────── @app.get("/v1/tasks/{task_id}/export", tags=["export"]) def export_task(task_id: str, user_id: UUID = Depends(require_user)): """导出对话为 .docx,临时文件下载完后 BackgroundTask 删 tmp。""" try: tid = UUID(task_id) except ValueError: raise HTTPException(404, f"invalid task id: {task_id!r}") with session_scope() as s: _assert_owns_task(s, tid, user_id) has_msg = s.execute( select(Message.message_id).where(Message.task_id == tid).limit(1) ).first() if not has_msg: raise HTTPException(400, "no messages to export") fd, tmp_str = tempfile.mkstemp(suffix=".docx", prefix="zcbot-export-") os.close(fd) tmp_path = Path(tmp_str) try: from core.export_docx import export_chat_to_docx export_chat_to_docx(tid, out_path=tmp_path) except Exception as e: tmp_path.unlink(missing_ok=True) raise HTTPException(500, f"export failed: {type(e).__name__}: {e}") return FileResponse( path=str(tmp_path), media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document", filename=f"chat_{str(tid)[:8]}.docx", background=BackgroundTask(tmp_path.unlink, missing_ok=True), ) return app