675 lines
25 KiB
Python
675 lines
25 KiB
Python
"""FastAPI app: pure /v1 JSON API (switched over on 2026-05-15 — see DESIGN §7.9).

Design notes:

- Every route lives under the `/v1/*` prefix and responds with JSON; templates,
  HTMX and server-side markdown rendering were all removed.
- SSE event payloads are JSON dicts rather than HTML fragments
  (`event: <type>` + `data: <json>`).
- Auth: PLATFORM_KEY → JWT exchange (the §7 D' transitional form, see
  web/auth.py); replacing it with OIDC only touches the internals of
  /v1/auth/login.
- All /v1/tasks* routes use Depends(require_user) and isolate data by user_id.
- Exempt from auth: /healthz, /docs, /openapi.json, /, /v1/auth/login, /static/*.
- CORS allow_origins=["*"] is permissive for local use; tighten it to the
  platform's domains for a real release.
- `GET /` redirects (302) to /static/dev.html (the local dev SPA).
"""
|
|
from __future__ import annotations

import asyncio
import hmac
import json
import logging
import os
import stat
import tempfile
from contextlib import asynccontextmanager
from datetime import datetime as _dt
from pathlib import Path
from typing import Any, Optional
from uuid import UUID, uuid4

from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, RedirectResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from sqlalchemy import func, select, update
from starlette.background import BackgroundTask

from core.paths import from_db_path, to_db_path
from core.storage import (
    NoSubtaskError,
    check_no_subtask,
    session_scope,
)
from core.storage.models import Message, Run, Task
from core.storage.utils import ensure_local_task_row

from .auth import AuthConfig, ensure_user_row, make_require_user, mint_token
from .broker import broker
from .sinks import WebEventSink
|
|
|
|
|
|
# Values accepted by the `status` query filter of GET /v1/tasks.
STATUS_FILTERS: tuple[str, ...] = ("active", "completed", "abandoned")

# Statuses a web client may set through PATCH; moving a task back to
# "active" is only possible from the CLI.
STATUS_WRITABLE: tuple[str, ...] = ("completed", "abandoned")
|
|
|
|
|
|
# ─────────────────────────── helpers ───────────────────────────
|
|
|
|
def _norm_path(p: str) -> str:
|
|
"""跨 OS 显示归一:backslash → forward slash。"""
|
|
return (p or "").replace("\\", "/")
|
|
|
|
|
|
def _iso(dt: Optional[Any]) -> Optional[str]:
|
|
return dt.isoformat() if dt else None
|
|
|
|
|
|
def _task_dict(row: Any, *, n_messages: Optional[int] = None) -> dict:
    """Serialise a ``Task`` ORM row into the JSON dict returned by /v1/tasks*.

    NULL text columns become ``""`` and NULL token counters become ``0``.
    ``task_dir`` is displayed with forward slashes, and ``tokens`` is the
    prompt + completion total. Timestamps are ISO strings, or ``None`` when
    the attribute is missing or unset. ``n_messages`` is included only when
    the caller supplies it.
    """
    prompt_tokens = row.tokens_prompt or 0
    completion_tokens = row.tokens_completion or 0
    payload = dict(
        task_id=str(row.task_id),
        description=row.description or "",
        task_dir=_norm_path(row.task_dir or ""),
        status=row.status,
        mode=row.mode or "",
        model=row.model or "",
        model_profile=row.model_profile or "",
        tokens_prompt=prompt_tokens,
        tokens_completion=completion_tokens,
        tokens=prompt_tokens + completion_tokens,
        created_at=_iso(getattr(row, "created_at", None)),
        updated_at=_iso(getattr(row, "updated_at", None)),
    )
    if n_messages is None:
        return payload
    payload["n_messages"] = n_messages
    return payload
|
|
|
|
|
|
# ─────────────────────── files helpers ───────────────────────
|
|
|
|
def _load_task_dir(task_id: str, user_id: UUID) -> tuple[UUID, Path]:
    """Resolve *task_id* to ``(task_uuid, absolute_task_dir)`` for *user_id*.

    The ``task_dir`` column is stored in db form and converted back to a
    filesystem ``Path`` with ``from_db_path``.

    Raises:
        HTTPException(404): *task_id* is not a UUID, or no task with that id
            belongs to *user_id*. Another user's task is reported as not found
            so its existence is not revealed.
        HTTPException(400): the task row has an empty ``task_dir``.
    """
    try:
        tid = UUID(task_id)
    except ValueError:
        raise HTTPException(404, f"invalid task id: {task_id!r}")

    stmt = select(Task.task_dir).where(Task.task_id == tid, Task.user_id == user_id)
    with session_scope() as s:
        found = s.execute(stmt).first()
        if found is None:
            raise HTTPException(404, f"task not found: {tid}")
        stored_dir = found[0] or ""
        if stored_dir == "":
            raise HTTPException(400, f"task {tid} has no task_dir, files browsing unavailable")
        return tid, from_db_path(stored_dir)
|
|
|
|
|
|
def _safe_join(root: Path, rel: str) -> Path:
|
|
"""归一用户路径到 absolute,并校验仍在 root 内。防 `../` / 绝对 path / symlink 越界。"""
|
|
rel = (rel or "").strip()
|
|
if not rel:
|
|
return root.resolve()
|
|
if rel[0] in ("/", "\\"):
|
|
raise HTTPException(400, f"absolute-style path not allowed: {rel!r}")
|
|
if Path(rel).is_absolute():
|
|
raise HTTPException(400, f"absolute path not allowed: {rel!r}")
|
|
target = (root / rel).resolve()
|
|
try:
|
|
target.relative_to(root.resolve())
|
|
except ValueError:
|
|
raise HTTPException(400, f"path escapes task_dir: {rel!r}")
|
|
return target
|
|
|
|
|
|
def _rel_to(root: Path, target: Path) -> str:
|
|
try:
|
|
return target.resolve().relative_to(root.resolve()).as_posix()
|
|
except ValueError:
|
|
return ""
|
|
|
|
|
|
def _enumerate_files(root: Path, current: Path) -> tuple[list[dict], list[dict], bool]:
    """List the entries of *current* and build breadcrumbs from *root* to it.

    Returns:
        ``(entries, crumbs, exists)`` where

        * ``entries``: non-regular-file entries (directories etc.) come first,
          then regular files. Each group is sorted by case-insensitive name.
          Each item has ``name``, ``is_dir``, ``size`` (raw bytes for regular
          files, ``None`` otherwise), ``mtime`` (naive local-time ISO string
          at seconds precision; the frontend humanises it) and ``rel`` (POSIX
          path relative to *root*). Entries whose ``stat()`` fails are
          skipped, and an unreadable directory yields no entries.
        * ``crumbs``: the root crumb ``{"label": "/", "rel": ""}`` followed by
          one ``{"label", "rel"}`` per path component of *current*.
        * ``exists``: whether *current* exists.
    """
    entries: list[dict] = []
    exists = current.exists()
    if exists and current.is_dir():
        try:
            children = sorted(current.iterdir(), key=lambda p: (p.is_file(), p.name.lower()))
        except OSError:
            children = []
        for p in children:
            try:
                st = p.stat()
            except OSError:
                continue
            # Take type and size from this single stat() result, so they are
            # consistent and no extra syscalls are made per entry.
            entries.append({
                "name": p.name,
                "is_dir": stat.S_ISDIR(st.st_mode),
                "size": st.st_size if stat.S_ISREG(st.st_mode) else None,
                "mtime": _dt.fromtimestamp(st.st_mtime).isoformat(timespec="seconds"),
                "rel": _rel_to(root, p),
            })

    cur_rel = _rel_to(root, current)
    crumbs = [{"label": "/", "rel": ""}]
    # Path.relative_to() spells the root itself as "."; the "/" crumb above
    # already covers it, so it must not add a second "." crumb.
    if cur_rel and cur_rel != ".":
        acc = ""
        for part in cur_rel.split("/"):
            acc = f"{acc}/{part}" if acc else part
            crumbs.append({"label": part, "rel": acc})
    return entries, crumbs, exists
|
|
|
|
|
|
# ─────────────────── Run 启动 + SSE 帧格式 ───────────────────
|
|
|
|
def _run_agent_bg(task_id: UUID, run_id: UUID, user_message: str) -> None:
    """Run one agent turn for *task_id* on a worker thread and record the outcome.

    Flow:
        1. Emit ``run_start``.
        2. Call ``build_agent(session_id=str(task_id), resume=True)``.
        3. Attach a ``WebEventSink`` so agent events are published on
           *run_id*'s broker channel.
        4. Run ``agent.run(user_message)``, then ``sync_task_tokens``.
        5. Mark the ``runs`` row ``ok`` with its token counts.

    ``agent.run`` is synchronous, so callers schedule this function with
    ``asyncio.to_thread``; ``broker.emit`` carries events back to the loop.

    On any exception, ``"<ExcType>: <message>"`` is emitted as an ``error``
    event and the traceback is logged. The ``runs`` row is then marked
    ``error`` on a best-effort basis. The broker channel is always closed
    on exit.
    """
    # Imported lazily rather than at module top, presumably to avoid an
    # import cycle with the entry module; TODO confirm.
    from main import build_agent, sync_task_tokens

    log = logging.getLogger(__name__)
    try:
        broker.emit(run_id, {"type": "run_start"})
        agent, _session, _sid, task_state, _task_dir = build_agent(
            session_id=str(task_id), resume=True,
        )
        agent.sink = WebEventSink(broker, run_id)
        agent.run(user_message)
        sync_task_tokens(task_state, agent.llm)
        with session_scope() as s:
            s.execute(
                update(Run).where(Run.run_id == run_id).values(
                    status="ok",
                    finished_at=func.now(),
                    tokens_p=agent.llm.token_counter.prompt_tokens,
                    tokens_c=agent.llm.token_counter.completion_tokens,
                )
            )
    except Exception as e:
        err = f"{type(e).__name__}: {e}"
        broker.emit(run_id, {"type": "error", "msg": err})
        # The client only sees the one-line summary; keep the full traceback
        # in the server log.
        log.exception("agent run %s (task %s) failed", run_id, task_id)
        try:
            with session_scope() as s:
                s.execute(
                    update(Run).where(Run.run_id == run_id).values(
                        status="error", error=err, finished_at=func.now()
                    )
                )
        except Exception:
            # Best effort: the client already received the error event, so
            # don't raise again. Still log it instead of dropping it silently.
            log.exception("could not record error status for run %s", run_id)
    finally:
        broker.close(run_id)
|
|
|
|
|
|
def _sse_event(event_type: str, payload: dict) -> bytes:
|
|
"""格式化 SSE 一帧:`event: <type>` + `data: <json single-line>`。"""
|
|
body = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
|
|
return f"event: {event_type}\ndata: {body}\n\n".encode("utf-8")
|
|
|
|
|
|
# ────────────────────── Pydantic 请求体 ──────────────────────
|
|
|
|
# Request body for POST /v1/tasks. The route strips every field and requires
# a non-blank `description` or `task_dir`. (Comments, not a docstring: a
# Pydantic class docstring would change the generated OpenAPI schema.)
class TaskCreateRequest(BaseModel):
    description: str = ""
    mode: str = ""
    # Left empty, the route derives a default directory from the workspace.
    task_dir: str = ""
|
|
|
|
|
|
# Request body for PATCH /v1/tasks/{task_id}. Fields left as None are not
# updated; `status` is checked by the route against STATUS_WRITABLE.
class TaskPatchRequest(BaseModel):
    status: Optional[str] = None
    description: Optional[str] = None
    mode: Optional[str] = None
|
|
|
|
|
|
# Request body for POST /v1/tasks/{task_id}/messages; the route rejects
# content that is blank after stripping.
class MessageRequest(BaseModel):
    content: str
|
|
|
|
|
|
# Request body for POST /v1/tasks/{task_id}/files/delete; `path` is relative
# to the task's task_dir.
class FileDeleteRequest(BaseModel):
    path: str
|
|
|
|
|
|
# Request body for POST /v1/auth/login. The route parses `user_id` as a UUID
# and compares `platform_key` with the configured PLATFORM_KEY.
class LoginRequest(BaseModel):
    user_id: str
    platform_key: str
|
|
|
|
|
|
# ────────────────────── App 工厂 ──────────────────────
|
|
|
|
# web/static 目录路径 — /static 静态挂载用,dev.html 也放这
|
|
_STATIC_DIR = Path(__file__).parent / "static"
|
|
|
|
|
|
def create_app() -> FastAPI:
    """Build and return the zcbot FastAPI application.

    Setup:
        * Loads auth settings with ``AuthConfig.from_env()`` (fail-fast).
        * Binds the event broker to the running loop at startup.
        * Installs permissive CORS.
        * Mounts ``/static`` when that directory exists.
        * Registers every route.

    All ``/v1/tasks*`` routes depend on ``require_user`` and filter by the
    caller's ``user_id``. Another user's task is reported as 404, so its
    existence is not revealed.
    """
    # Fail fast: a missing env var raises here, so the API never runs keyless.
    auth_cfg = AuthConfig.from_env()
    require_user = make_require_user(auth_cfg)

    # Strong references to in-flight agent runs. The event loop keeps only
    # weak references to tasks, so a fire-and-forget create_task() result
    # can be garbage-collected before the run finishes.
    background_runs: set[asyncio.Task] = set()
    # Uploads are streamed to disk in chunks of this size instead of being
    # buffered whole in memory.
    upload_chunk_bytes = 1024 * 1024

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Worker threads publish through broker.emit; the broker needs the
        # running loop to hand those events to asyncio subscribers.
        broker.bind_loop(asyncio.get_running_loop())
        yield

    app = FastAPI(
        title="zcbot api",
        version="0.8",
        description=(
            "zcbot 后端 — /v1 JSON API + SSE。Auth: PLATFORM_KEY → JWT(§7 D' 过渡)。"
            "本地 dev SPA: /static/dev.html。"
        ),
        lifespan=lifespan,
    )

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],  # permissive for local use; restrict to platform domains on deploy
        allow_credentials=False,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    if _STATIC_DIR.is_dir():
        app.mount("/static", StaticFiles(directory=str(_STATIC_DIR)), name="static")

    # ───────────── Shared helpers ─────────────

    def _parse_task_id(task_id: str) -> UUID:
        """Parse a task-id path parameter; a malformed id is reported as 404."""
        try:
            return UUID(task_id)
        except ValueError:
            raise HTTPException(404, f"invalid task id: {task_id!r}") from None

    def _assert_owns_task(s, tid: UUID, user_id: UUID) -> None:
        """Raise 404 unless task *tid* exists and belongs to *user_id* (session *s*)."""
        ok = s.execute(
            select(Task.task_id).where(Task.task_id == tid, Task.user_id == user_id)
        ).first()
        if ok is None:
            raise HTTPException(404, f"task not found: {tid}")

    def _count_messages(s, tid: UUID) -> int:
        """Number of stored messages for task *tid* (session *s*)."""
        return s.execute(
            select(func.count()).select_from(Message).where(Message.task_id == tid)
        ).scalar_one()

    # ───────────── Misc ─────────────

    @app.get("/", include_in_schema=False)
    def root():
        # Local dev SPA; Swagger UI stays available at /docs.
        return RedirectResponse(url="/static/dev.html", status_code=302)

    @app.get("/healthz", tags=["misc"])
    def healthz():
        return {"status": "ok"}

    # ───────────── Auth ─────────────

    @app.post("/v1/auth/login", tags=["auth"])
    def login(body: LoginRequest):
        """Exchange the platform key for a JWT whose subject is ``user_id``.

        Wrong platform_key → 403; non-UUID user_id → 400.
        A missing users row is created idempotently so later foreign keys hold.
        """
        # Constant-time comparison: `!=` can stop at the first differing
        # character, which leaks how much of the key a guess got right.
        if not hmac.compare_digest(
            body.platform_key.encode("utf-8"),
            auth_cfg.platform_key.encode("utf-8"),
        ):
            raise HTTPException(403, "invalid platform_key")
        try:
            uid = UUID(body.user_id)
        except (ValueError, TypeError):
            raise HTTPException(
                400, f"invalid user_id (must be UUID): {body.user_id!r}"
            ) from None
        ensure_user_row(uid)
        token, exp = mint_token(auth_cfg, uid)
        return {
            "token": token,
            "expires_at": _dt.fromtimestamp(exp).isoformat(),
            "user_id": str(uid),
            "ttl_seconds": auth_cfg.ttl_seconds,
        }

    # ───────────── Tasks CRUD ─────────────

    @app.post("/v1/tasks", status_code=201, tags=["tasks"])
    def create_task(body: TaskCreateRequest, user_id: UUID = Depends(require_user)):
        """Create a task. An empty `task_dir` defaults to `workspace/tasks/<uuid>/`.

        At least one of `description` / `task_dir` is required, otherwise 400.
        A directory nested inside (or containing) another task of the same
        user violates the no-subtask rule → 409.
        """
        description = body.description.strip()
        mode = body.mode.strip()
        task_dir_raw = body.task_dir.strip()
        if not description and not task_dir_raw:
            raise HTTPException(400, "either description or task_dir must be provided")

        tid = uuid4()
        from main import _default_task_dir, resolve_workspace

        ws = resolve_workspace(None)
        if task_dir_raw:
            # NOTE(review): any server path is accepted here, and the /files
            # routes then grant this user read/write access beneath it. Fine
            # for local dev; confine to the workspace before exposing the API
            # to untrusted users.
            fs_dir = Path(task_dir_raw).expanduser().resolve()
        else:
            fs_dir = _default_task_dir(ws, tid)
        fs_dir_db = to_db_path(fs_dir)

        try:
            check_no_subtask(fs_dir_db, user_id=user_id)
        except NoSubtaskError as e:
            raise HTTPException(409, str(e))

        ensure_local_task_row(
            task_id=tid, task_dir=fs_dir_db, mode=mode,
            description=description, user_id=user_id,
        )
        with session_scope() as s:
            row = s.execute(select(Task).where(Task.task_id == tid)).scalar_one()
            return _task_dict(row, n_messages=0)

    @app.get("/v1/tasks", tags=["tasks"])
    def list_tasks_route(
        status: Optional[str] = None,
        limit: int = 50,
        user_id: UUID = Depends(require_user),
    ):
        """List the current user's tasks, most recently updated first.

        An unknown `status` is ignored (no filter). A negative `limit` is
        treated as 0.
        """
        if status and status not in STATUS_FILTERS:
            status = None
        # PostgreSQL rejects a negative LIMIT; clamp instead of returning a 500.
        limit = max(limit, 0)
        with session_scope() as s:
            q = select(Task).where(Task.user_id == user_id).order_by(Task.updated_at.desc())
            if status:
                q = q.where(Task.status == status)
            rows = s.execute(q.limit(limit)).scalars().all()
            tids = [r.task_id for r in rows]
            counts = (
                dict(s.execute(
                    select(Message.task_id, func.count())
                    .where(Message.task_id.in_(tids))
                    .group_by(Message.task_id)
                ).all())
                if tids else {}
            )
            return {
                "tasks": [
                    _task_dict(r, n_messages=counts.get(r.task_id, 0))
                    for r in rows
                ]
            }

    @app.get("/v1/tasks/{task_id}", tags=["tasks"])
    def get_task(task_id: str, user_id: UUID = Depends(require_user)):
        """Task metadata only (fetch messages via /messages). Another user's task → 404."""
        tid = _parse_task_id(task_id)
        with session_scope() as s:
            row = s.execute(
                select(Task).where(Task.task_id == tid, Task.user_id == user_id)
            ).scalar_one_or_none()
            if row is None:
                raise HTTPException(404, f"task not found: {tid}")
            return _task_dict(row, n_messages=_count_messages(s, tid))

    @app.patch("/v1/tasks/{task_id}", tags=["tasks"])
    def patch_task(
        task_id: str,
        body: TaskPatchRequest,
        user_id: UUID = Depends(require_user),
    ):
        """Update task fields. `status` may only be completed/abandoned (reactivation is CLI-only)."""
        tid = _parse_task_id(task_id)
        updates: dict[str, Any] = {}
        if body.status is not None:
            if body.status not in STATUS_WRITABLE:
                raise HTTPException(
                    400, f"invalid status {body.status!r}; allowed: {STATUS_WRITABLE}"
                )
            updates["status"] = body.status
        if body.description is not None:
            updates["description"] = body.description
        if body.mode is not None:
            updates["mode"] = body.mode
        if not updates:
            raise HTTPException(400, "no fields to update")
        with session_scope() as s:
            result = s.execute(
                update(Task)
                .where(Task.task_id == tid, Task.user_id == user_id)
                .values(**updates)
            )
            if result.rowcount == 0:
                raise HTTPException(404, f"task not found: {tid}")
            row = s.execute(select(Task).where(Task.task_id == tid)).scalar_one()
            return _task_dict(row, n_messages=_count_messages(s, tid))

    # ───────────── Messages ─────────────

    @app.get("/v1/tasks/{task_id}/messages", tags=["messages"])
    def list_messages(task_id: str, user_id: UUID = Depends(require_user)):
        """A task's message history in ascending idx order; LiteLLM payloads are passed through for the client to render."""
        tid = _parse_task_id(task_id)
        with session_scope() as s:
            _assert_owns_task(s, tid, user_id)
            rows = s.execute(
                select(
                    Message.idx, Message.payload, Message.tokens_in,
                    Message.tokens_out, Message.created_at,
                ).where(Message.task_id == tid).order_by(Message.idx)
            ).all()
            return {
                "messages": [
                    {
                        "idx": r.idx,
                        "payload": dict(r.payload),
                        "tokens_in": r.tokens_in,
                        "tokens_out": r.tokens_out,
                        "created_at": _iso(r.created_at),
                    }
                    for r in rows
                ]
            }

    @app.post("/v1/tasks/{task_id}/messages", status_code=202, tags=["messages"])
    async def post_message(
        task_id: str,
        body: MessageRequest,
        user_id: UUID = Depends(require_user),
    ):
        """Post a message and start a background run.

        Returns `{run_id, events_url}`; the client should subscribe to the
        SSE URL right away to receive the stream.
        """
        tid = _parse_task_id(task_id)
        content = (body.content or "").strip()
        if not content:
            raise HTTPException(400, "empty content")
        with session_scope() as s:
            _assert_owns_task(s, tid, user_id)

        run_id = uuid4()
        with session_scope() as s:
            s.add(Run(run_id=run_id, task_id=tid, status="running", started_at=func.now()))
        # agent.run is synchronous, so it runs in a worker thread; the sink
        # bridges its events back to asyncio through the broker. Keep a
        # reference until the task completes so it cannot be GC'd mid-run.
        run_task = asyncio.create_task(asyncio.to_thread(_run_agent_bg, tid, run_id, content))
        background_runs.add(run_task)
        run_task.add_done_callback(background_runs.discard)
        return {
            "run_id": str(run_id),
            "events_url": f"/v1/tasks/{tid}/runs/{run_id}/events",
        }

    # ───────────── SSE events ─────────────

    @app.get("/v1/tasks/{task_id}/runs/{run_id}/events", tags=["runs"])
    async def stream_events(
        task_id: str,
        run_id: str,
        user_id: UUID = Depends(require_user),
    ):
        """SSE stream.

        Event types: run_start / llm_start / text / tool_call / tool_result /
        llm_end / error / done. `data` is the JSON dict with its `type` key
        removed; the type becomes the SSE event name.
        """
        try:
            tid = UUID(task_id)
            rid = UUID(run_id)
        except ValueError:
            raise HTTPException(404, "invalid id") from None
        with session_scope() as s:
            _assert_owns_task(s, tid, user_id)
            # The run must belong to this task. Otherwise a user could attach
            # to another user's run by pairing its id with one of their own tasks.
            run_row = s.execute(
                select(Run.run_id).where(Run.run_id == rid, Run.task_id == tid)
            ).first()
            if run_row is None:
                raise HTTPException(404, f"run not found: {rid}")

        async def gen():
            q = broker.subscribe(rid)
            try:
                # Send a comment line plus a retry hint immediately so the
                # EventSource connection is established without buffering delays.
                yield b": connected\nretry: 3000\n\n"
                while True:
                    try:
                        ev = await asyncio.wait_for(q.get(), timeout=30.0)
                    except asyncio.TimeoutError:
                        yield b": ping\n\n"
                        continue
                    ev_type = ev.get("type", "msg")
                    payload = {k: v for k, v in ev.items() if k != "type"}
                    yield _sse_event(ev_type, payload)
                    if ev_type in ("done", "error"):
                        break
            finally:
                # Runs on normal completion and on client disconnect. Cancellation
                # is allowed to propagate so the server's task handling stays intact.
                broker.unsubscribe(rid, q)

        return StreamingResponse(
            gen(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # ───────────── Files ─────────────

    @app.get("/v1/tasks/{task_id}/files", tags=["files"])
    def list_files(
        task_id: str,
        path: str = "",
        user_id: UUID = Depends(require_user),
    ):
        """List a directory's entries plus breadcrumbs.

        An empty `path` means the root; `../` or absolute paths → 400.
        """
        tid, root = _load_task_dir(task_id, user_id)
        current = _safe_join(root, path)
        entries, crumbs, exists = _enumerate_files(root, current)
        return {
            "task_id": str(tid),
            "root": _norm_path(str(root)),
            "current": _rel_to(root, current),
            "exists": exists,
            "crumbs": crumbs,
            "entries": entries,
        }

    @app.get("/v1/tasks/{task_id}/files/download", tags=["files"])
    def download_file(
        task_id: str,
        path: str,
        user_id: UUID = Depends(require_user),
    ):
        """Download one regular file (directory → 400, missing → 404)."""
        tid, root = _load_task_dir(task_id, user_id)
        target = _safe_join(root, path)
        if not target.exists():
            raise HTTPException(404, f"file not found: {path}")
        if not target.is_file():
            raise HTTPException(400, f"not a file: {path}")
        return FileResponse(path=str(target), filename=target.name)

    @app.post("/v1/tasks/{task_id}/files/upload", tags=["files"])
    async def upload_files(
        task_id: str,
        path: str = Form(""),
        files: list[UploadFile] = File(...),
        user_id: UUID = Depends(require_user),
    ):
        """Multipart upload of one or more files into `<task_dir>/<path>/`.

        The target directory is created if missing (parents=True), and
        existing files are overwritten. Filenames are validated strictly:
        empty names, `.`/`..`, or names containing `/` or `\\` → 400. All
        names are validated before anything is written, so a bad name cannot
        leave a partial upload behind.
        """
        tid, root = _load_task_dir(task_id, user_id)
        dest_dir = _safe_join(root, path)
        if dest_dir.exists() and not dest_dir.is_dir():
            raise HTTPException(400, f"upload target is a file, not a directory: {path}")

        uploads = list(files or [])
        if not uploads:
            raise HTTPException(400, "no files uploaded")

        root_resolved = root.resolve()
        planned: list[tuple[UploadFile, str, Path]] = []
        for up in uploads:
            raw_name = up.filename or ""
            if (
                not raw_name
                or raw_name in (".", "..")
                or "/" in raw_name or "\\" in raw_name
                or any(part in (".", "..") for part in Path(raw_name).parts)
            ):
                raise HTTPException(400, f"invalid filename: {raw_name!r}")
            dest = dest_dir / raw_name
            try:
                dest.resolve().relative_to(root_resolved)
            except ValueError:
                raise HTTPException(400, f"path escapes task_dir: {raw_name!r}") from None
            if dest.is_dir():
                # Writing would fail with IsADirectoryError (a 500) otherwise.
                raise HTTPException(400, f"a directory with this name already exists: {raw_name!r}")
            planned.append((up, raw_name, dest))

        dest_dir.mkdir(parents=True, exist_ok=True)
        saved: list[dict] = []
        for up, raw_name, dest in planned:
            size = 0
            with dest.open("wb") as fh:
                while chunk := await up.read(upload_chunk_bytes):
                    fh.write(chunk)
                    size += len(chunk)
            saved.append({"name": raw_name, "size": size, "rel": _rel_to(root, dest)})
        return {"count": len(saved), "saved": saved}

    @app.post("/v1/tasks/{task_id}/files/delete", tags=["files"])
    def delete_file(
        task_id: str,
        body: FileDeleteRequest,
        user_id: UUID = Depends(require_user),
    ):
        """Delete a file or an *empty* directory under task_dir.

        A non-empty directory → 400 (guards against accidents); the root → 400.
        """
        tid, root = _load_task_dir(task_id, user_id)
        # NOTE(review): _safe_join resolves symlinks, so deleting a symlink
        # that points inside task_dir removes its target, not the link itself.
        target = _safe_join(root, body.path)
        if target.resolve() == root.resolve():
            raise HTTPException(400, "cannot delete task_dir root")
        if not target.exists():
            raise HTTPException(404, f"path not found: {body.path}")
        try:
            if target.is_dir():
                target.rmdir()  # a non-empty directory raises OSError
            else:
                target.unlink()
        except OSError as e:
            raise HTTPException(400, f"delete failed: {e}")
        return {"ok": True, "path": body.path}

    # ───────────── Export ─────────────

    @app.get("/v1/tasks/{task_id}/export", tags=["export"])
    def export_task(task_id: str, user_id: UUID = Depends(require_user)):
        """Export the conversation as .docx.

        The temporary file is deleted by a BackgroundTask once the download
        has been sent.
        """
        tid = _parse_task_id(task_id)
        with session_scope() as s:
            _assert_owns_task(s, tid, user_id)
            has_msg = s.execute(
                select(Message.message_id).where(Message.task_id == tid).limit(1)
            ).first()
        if not has_msg:
            raise HTTPException(400, "no messages to export")

        fd, tmp_str = tempfile.mkstemp(suffix=".docx", prefix="zcbot-export-")
        os.close(fd)
        tmp_path = Path(tmp_str)
        try:
            from core.export_docx import export_chat_to_docx
            export_chat_to_docx(tid, out_path=tmp_path)
        except Exception as e:
            tmp_path.unlink(missing_ok=True)
            raise HTTPException(500, f"export failed: {type(e).__name__}: {e}")

        return FileResponse(
            path=str(tmp_path),
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            filename=f"chat_{str(tid)[:8]}.docx",
            background=BackgroundTask(tmp_path.unlink, missing_ok=True),
        )

    return app
|