478 lines
21 KiB
Python
478 lines
21 KiB
Python
"""装配入口: 读 config → 加载 capabilities/skills → 构造 LLM/tools/session/loop。
|
|
|
|
存储布局(§7.0 / §7.4):本地 + SaaS 共用 `workspace/` 根,只差 user_id:
|
|
|
|
PG tasks / messages ← 元数据 + 消息
|
|
workspace/users/<user_id>/<working_dir>/ ← 工作目录(用户起名,可多 task 共享)
|
|
workspace/users/<user_id>/.memory/{core.md, extended/} ← per-user 记忆(dotfile 隔离)
|
|
|
|
所有入口都走 web `/v1` + JWT(user_id = sub);dev SPA 走邮箱密码登录
|
|
(`users.email/password_hash`,bcrypt)、platform 服务端走 platform_key 登录。task_id / user_id 全 UUID;
|
|
state.json 已删除(元数据全在 PG)。
|
|
|
|
**新建 task 必须给 `name`**(任务显示名,DB 列 NOT NULL);**`working_dir` 可选**
|
|
(留空 → 用 name 作目录名;同 working_dir 多 task 自动共享 §7.1)。name 和 working_dir
|
|
都过同一份 `validate_task_name` 校验(简单名,不含 `/\\..`、不以 `.` 起头)。
|
|
`_cleanup_if_empty` 不 rmtree FS —— 同 working_dir 跨 task 复用,空 task 只删 DB 行。
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Callable, Optional, Tuple
|
|
from uuid import UUID, uuid4
|
|
|
|
import yaml
|
|
from rich.console import Console
|
|
|
|
from core.capabilities import ModelCapabilities
|
|
from core.executor_host import HostExecutor
|
|
from core.llm import LLM
|
|
from core.loop import AgentLoop
|
|
from core.memory import memory_block
|
|
from core.paths import ROOT, from_db_path, to_db_path
|
|
from core.session import Session
|
|
from core.sinks import ConsoleEventSink
|
|
from core.skills import SkillRegistry
|
|
from core.storage import check_no_subtask
|
|
from core.task import TaskState
|
|
from tools.fs import EditTool, GlobTool, GrepTool, ReadTool, WriteTool
|
|
from tools.run_python import RunPythonTool
|
|
from tools.seedance import SeedanceTool
|
|
from tools.seedream import SeedreamTool
|
|
from tools.shell import ShellTool
|
|
from tools.skill_tool import LoadSkillTool
|
|
from tools.web_fetch import WebFetchTool
|
|
from tools.web_search import WebSearchTool
|
|
|
|
from core.ark_client import ArkConfig
|
|
from core.bocha_client import BochaConfig
|
|
|
|
|
|
def load_config() -> dict:
    """Read ``config/agent.yaml`` under ``ROOT`` and return its parsed content.

    Returns:
        The value produced by ``yaml.safe_load``; any falsy result (for
        example an empty file, which parses to ``None``) is replaced by an
        empty dict.
    """
    config_path = ROOT / "config" / "agent.yaml"
    raw_text = config_path.read_text(encoding="utf-8")
    parsed = yaml.safe_load(raw_text)
    return parsed or {}
|
|
|
|
|
|
def resolve_workspace(workspace: Optional[str], cfg: Optional[dict] = None) -> Path:
    """Return the workspace root directory, creating it if necessary.

    Args:
        workspace: Explicit workspace path. When truthy it is used as-is and
            the configuration is not consulted at all.
        cfg: Already-loaded configuration. Only used when ``workspace`` is
            falsy; when ``cfg`` is falsy too, ``load_config()`` is called.

    Returns:
        ``Path(workspace)`` when given, otherwise
        ``ROOT / cfg["workspace_dir"]`` (key defaulting to ``"workspace"``).
        The directory, including missing parents, exists on return.
    """
    if workspace:
        root = Path(workspace)
    else:
        # Load the config lazily: an explicit workspace must not depend on
        # agent.yaml being present or parseable.
        cfg = cfg or load_config()
        root = ROOT / cfg.get("workspace_dir", "workspace")
    root.mkdir(parents=True, exist_ok=True)
    return root
|
|
|
|
|
|
def user_root(workspace_dir: Path, user_id: UUID) -> Path:
    """Return the per-user subtree root ``<workspace>/users/<user_id>/``.

    The directory is created (with parents) when it does not exist yet.
    Working directories and the ``.memory/`` area live underneath it.
    """
    root = workspace_dir.joinpath("users", str(user_id))
    root.mkdir(parents=True, exist_ok=True)
    return root
|
|
|
|
|
|
class InvalidTaskName(ValueError):
    """Raised when a task name or working_dir name is not acceptable.

    Covers empty names, names containing path separators or NUL, names
    starting with a dot, and names that are too long.
    """
|
|
|
|
|
|
def validate_task_name(name: str) -> str:
    """Validate a task name / working_dir name and return it stripped.

    The same rules apply to both kinds of name: non-empty after stripping,
    at most 255 characters, no ``/``, ``\\`` or NUL, and no leading ``.``
    (dot-prefixed names are reserved for system areas such as ``.memory``).
    CJK and other Unicode characters are allowed.

    Raises:
        InvalidTaskName: if any rule is violated.
    """
    cleaned = (name or "").strip()
    if cleaned == "":
        raise InvalidTaskName("name 不能为空")
    if len(cleaned) > 255:
        raise InvalidTaskName(f"name 超长(>255 字符): {cleaned[:40]!r}...")
    for forbidden in ("/", "\\", "\x00"):
        if forbidden in cleaned:
            raise InvalidTaskName(f"name 不能含 `/` `\\` 或 NUL: {cleaned!r}")
    if cleaned[0] == ".":
        raise InvalidTaskName(
            f"name 不能以 `.` 起头(保留给 .memory 等系统区): {cleaned!r}"
        )
    return cleaned
|
|
|
|
|
|
def working_dir_from_name(workspace_dir: Path, user_id: UUID, dir_name: str) -> Path:
    """Return ``<workspace>/users/<user_id>/<dir_name>``.

    ``dir_name`` is expected to have been checked with ``validate_task_name``
    by the caller; no validation happens here. The working directory itself
    is not created by this function. Note that the per-user root is obtained
    through ``user_root``, which does create ``<workspace>/users/<user_id>/``
    when it is missing.
    """
    per_user_root = user_root(workspace_dir, user_id)
    return per_user_root.joinpath(dir_name)
|
|
|
|
|
|
def resolve_task_id(
    workspace_dir: Path,
    task_id_arg: Optional[str],
    resume: bool,
    user_id: UUID,
    working_dir_name: Optional[str] = None,
) -> Tuple[UUID, Path]:
    """Return ``(task_id, absolute working_dir)`` for a new or resumed task.

    New task (``resume`` is false): ``working_dir_name`` is required; it is
    validated with ``validate_task_name`` and the working directory is
    ``<workspace>/users/<user_id>/<working_dir_name>``. A fresh UUID4 is
    generated as the task id.

    Resume: ``task_id_arg`` must be a UUID string. The working directory is
    read from ``tasks.working_dir`` and converted to an absolute path with
    ``from_db_path``; ``working_dir_name`` is ignored.

    Raises:
        InvalidTaskName: new task without a (valid) ``working_dir_name``.
        ValueError: resume without ``task_id_arg``, with a malformed UUID,
            with no matching task row, or with an empty stored working_dir.
    """
    if not resume:
        if not working_dir_name:
            raise InvalidTaskName("new task 必须指定 working_dir(或留空 fallback 用 name)")
        safe = validate_task_name(working_dir_name)
        return uuid4(), working_dir_from_name(workspace_dir, user_id, safe)

    # Validate the arguments before touching the DB layer so that bad input
    # fails fast with ValueError.
    if not task_id_arg:
        raise ValueError("resume 必须指定 task_id")
    tid = UUID(task_id_arg)

    from sqlalchemy import select
    from core.storage import session_scope
    from core.storage.models import Task

    with session_scope() as s:
        row = s.execute(
            select(Task.working_dir).where(Task.task_id == tid)
        ).one_or_none()
        found = row is not None
        db_dir = (row[0] or "") if found else ""

    if not found:
        # Distinguish "no such task" from a corrupt row; the empty
        # working_dir message below would be misleading here.
        raise ValueError(f"task {tid} not found in DB — cannot resume")
    if not db_dir:
        raise ValueError(
            f"task {tid} has empty working_dir in DB — should not happen "
            "(new tasks require name + working_dir; legacy empty data was wiped)"
        )
    # The DB stores the "db form" of the path (relative to ROOT or absolute);
    # from_db_path turns it back into an absolute path.
    return tid, from_db_path(db_dir)
|
|
|
|
|
|
def _build_system_prompt(
    cfg: dict,
    skills: SkillRegistry,
    workspace_dir: Path,
    tool_base: Path,
    working_dir: Path,
    user_id: UUID,
    task_id: UUID,
    task_name: str,
    task_skill: str = "",
) -> str:
    """Assemble the system prompt text.

    The result is the concatenation of: the template file named by
    ``cfg["system_prompt"]`` (relative to ``ROOT``), the skill discovery
    list (only when the registry has skills), the block returned by
    ``memory_block(workspace_dir, user_id)``, a working-directory / task
    context section, and the naming convention for task-level
    "constitution" files.

    The first 8 hex characters of ``task_id`` are injected as the stable
    anchor for those files, together with ``task_name`` (or a placeholder
    when it is empty) and today's date computed at call time.
    """
    pieces = [(ROOT / cfg["system_prompt"]).read_text(encoding="utf-8")]
    if skills.skills:
        pieces.append(
            f"\n\n## 可用 skill (用 load_skill 加载完整指引)\n{skills.discovery_block()}"
        )
    pieces.append(memory_block(workspace_dir, user_id))

    task_dir_abs = working_dir.resolve()
    today = f"{datetime.now():%Y-%m-%d}"
    display_name = task_name if task_name else "<未指定>"
    short_id = task_id.hex[:8]

    # Optional bullet, only present when the task declared a primary skill.
    if task_skill:
        skill_line = f"- **task 预选 skill**: `{task_skill}` — 用户创建时声明的主 skill\n"
    else:
        skill_line = ""

    context_section = (
        "\n\n## 工作目录与 task 上下文\n"
        f"- cwd(用户启动时所在目录,只读用): `{tool_base}`\n"
        f"- **task_dir(所有产物写到这里)**: `{task_dir_abs}`\n"
        f"- **task_short_id**(永不变,「宪法」文件主锚): `{short_id}`\n"
        f"- **task_name**(可变,写进文件名作人类可读说明): `{display_name}`\n"
        f"- **today**(当前日期,用于「宪法」文件命名): `{today}`\n"
        f"{skill_line}"
        "\n"
        "SKILL 文档里出现的 `<task_dir>` 占位符,一律指上面这个绝对路径。"
        "普通产物(sections / slides / 终稿 .docx/.pptx)按 SKILL 文档落路径;"
        "「宪法」性文件(spec 等)按下面《task 级「宪法」文件命名约定》拼路径。\n"
        "⛔ 不要把产物写到 cwd / `skills/` / repo 根 —— 只写到 task_dir。\n"
    )
    pieces.append(context_section)

    naming_section = (
        "\n## task 级「宪法」文件命名约定(跨 skill 通用)\n"
        "任何 skill 产物中,跟 task 1:1 强绑定、阶段二/后续步骤会**反复 read**"
        "的「宪法」性文件(如 proposal/ppt 的 spec、outline 等),**统一按下面格式命名**,"
        "落在 task_dir 根下:\n\n"
        " <YYYY-MM-DD>-<task_short_id>-<task_name>.<base>.md\n\n"
        f"其中 `<YYYY-MM-DD>` = 本会话 today=`{today}`;"
        f"`<task_short_id>` = `{short_id}`(永不变,主锚);"
        f"`<task_name>` = `{display_name}`(可变,人类可读说明,原样用 含 CJK / 空格);"
        "`<base>` 由 skill 定义(如 proposal/ppt 的 `spec`)。\n\n"
        "**取 current 版本规则**:read 时 **按 task_short_id 锚定** glob "
        f"`{task_dir_abs}/*-{short_id}-*.<base>.md` → 按文件名字典序排 → 取最大者"
        "(= 最新日期)。这样即使用户改了 task_name,旧文件仍能定位(`<task_name>` "
        "那段视为「建时快照」,不强求同步)。这是「current 指针」的纯文件名实现,"
        "agent 自己拼即可。\n\n"
        "**重定调场景**:用户阶段一已确认过的「宪法」文件,后续要推翻重写时,"
        f"以 today=`{today}` 为前缀写一份新的,**旧版自然保留为历史快照**(不要 edit "
        "覆盖旧文件)。同日多次重定调可在文件名末尾加 `-v2` / `-v3` 等递增后缀。\n\n"
        "**隔离逻辑**:同 working_dir 多 task → 由 `<task_short_id>` 严格隔离"
        "(8 位 hex,撞概率近 0);同 task 多版本 → 由 `<YYYY-MM-DD>` 隔离。两层隔离"
        "都靠文件名,**无目录嵌套、无 DB 字段、无 cascade rename**。其余产物"
        "(`sections/` / `figures/` / `slides/` / 终稿 .docx/.pptx 等)按 SKILL "
        "文档保留扁平共享,LLM 自行通过 task_short_id / 命名前缀判断归属。"
    )
    pieces.append(naming_section)

    return "".join(pieces)
|
|
|
|
|
|
def _select_variant(variants: dict, requested: str) -> Tuple[str, Optional[dict]]:
    """Pick ``(key, cfg)`` from a variant mapping; ``("", None)`` if none is usable.

    The requested key wins when it maps to a dict; otherwise the first
    dict-valued entry in mapping order is used as the fallback.
    """
    if requested:
        candidate = variants.get(requested)
        if isinstance(candidate, dict):
            return requested, candidate
    for key, candidate in variants.items():
        if isinstance(candidate, dict):
            return key, candidate
    return "", None


def _daily_quota(quotas: dict, key: str) -> int:
    """Read a per-day quota as int; a missing, null or empty value means 0."""
    return int(quotas.get(key) or 0)


def build_agent(
    *,
    user_id: UUID,
    model_name: Optional[str] = None,
    workspace: Optional[str] = None,
    console: Optional[Console] = None,
    session_id: Optional[str] = None,
    resume: bool = False,
    tool_base: Optional[Path] = None,
    skill: str = "",
    description: str = "",
    name: Optional[str] = None,
    working_dir: Optional[str] = None,
    image_variant: str = "",
    video_variant: str = "",
    cancel_check: Optional[Callable[[], bool]] = None,
) -> Tuple[AgentLoop, Session, str, TaskState, Path]:
    """Wire up an agent and return ``(agent, session, task_id_str, task_state, working_dir_path)``.

    New task (``resume`` is false):
      - ``name`` is required and validated with ``validate_task_name``.
      - ``working_dir`` is optional; when blank the validated name is used as
        the directory name. A non-blank value is validated the same way.
    Resume: ``name`` / ``working_dir`` are ignored; the working directory is
    resolved from the DB by ``resolve_task_id`` using ``session_id``.

    Model selection order: ``model_name`` argument, then (resume only) the
    task's stored ``model_profile``, then ``cfg["default_model"]``.

    ``user_id`` is required; it scopes the working directory root, the memory
    block and the no-subtask check.

    Raises:
        InvalidTaskName: new task with a missing or invalid name / working_dir.
        ValueError: resume with a missing or unresolvable ``session_id``
            (propagated from ``resolve_task_id``).
    """
    cfg = load_config()
    uid = user_id

    # Caller-supplied model wins; on resume fall back to the profile stored on
    # the task row; finally fall back to the configured default.
    model = model_name
    if model is None and resume and session_id:
        from sqlalchemy import select as _select
        from core.storage import session_scope as _scope
        from core.storage.models import Task as _Task

        with _scope() as _s:
            model = _s.execute(
                _select(_Task.model_profile).where(_Task.task_id == UUID(session_id))
            ).scalar_one_or_none() or None
    if not model:
        model = cfg["default_model"]

    caps = ModelCapabilities.load(model, ROOT / cfg["models_dir"])
    llm = LLM(caps)

    workspace_dir = resolve_workspace(workspace, cfg)

    # New task: validate name and resolve the directory name (blank falls back
    # to the name). Resume skips this entirely.
    task_name_safe = ""
    wd_name_for_resolve: Optional[str] = None
    if not resume:
        if not name:
            raise InvalidTaskName("new task 必须指定 name(任务显示名)")
        task_name_safe = validate_task_name(name)
        wd_raw = (working_dir or "").strip()
        wd_name = wd_raw if wd_raw else task_name_safe
        wd_name_for_resolve = validate_task_name(wd_name)

    task_id, working_dir_path = resolve_task_id(
        workspace_dir, session_id, resume, uid, wd_name_for_resolve
    )
    sid = str(task_id)

    # No-subtask rule: only checked when creating a task; a resumed task is
    # already persisted.
    if not resume:
        check_no_subtask(str(working_dir_path), user_id=uid)
    # Create the working directory right away, also on resume (it may have
    # been removed in the meantime). exist_ok=True keeps directories shared by
    # several tasks from conflicting.
    working_dir_path.mkdir(parents=True, exist_ok=True)

    tool_base = Path(tool_base) if tool_base else Path.cwd()

    skills = SkillRegistry(ROOT / cfg.get("skills_dir", "skills"))

    now_iso = datetime.now().isoformat(timespec="seconds")
    # meta["working_dir"] carries the DB form of the path; filesystem
    # operations keep using the absolute working_dir_path.
    wd_db = to_db_path(working_dir_path)

    # task_state must exist before the prompt is built because the prompt
    # injects the task name.
    if resume:
        task_state = TaskState.load(task_id)
        if task_state is None:
            # The row could not be loaded although resolve_task_id just found
            # it (e.g. deleted concurrently): fall back to an in-memory state
            # with an empty name.
            task_state = TaskState(
                task_id=sid, user_id=uid, name="", working_dir=wd_db,
                skill=skill, description=description, status="active",
                model=caps.model_id, model_profile=model,
            )
    else:
        # New task: the TaskState is built in memory here; nothing in this
        # function writes it to the DB.
        task_state = TaskState(
            task_id=sid, user_id=uid, name=task_name_safe, working_dir=wd_db,
            skill=skill, description=description, status="active",
            model=caps.model_id, model_profile=model,
            reasoning_effort=caps.default_reasoning_effort or "",
        )

    system_prompt = _build_system_prompt(
        cfg, skills, workspace_dir, tool_base, working_dir_path, uid,
        task_id, task_state.name, task_state.skill,
    )

    meta = {
        "id": sid,
        "created_at": now_iso,
        "cwd": str(tool_base),
        "name": task_state.name,  # empty only in the resume fallback branch above
        "working_dir": wd_db,
        "model": caps.model_id,
        "model_profile": model,
        "skill": skill,
        "description": description,
        "reasoning_effort": caps.default_reasoning_effort or "",
    }

    if resume:
        session = Session.load(task_id, system_prompt=system_prompt, meta=meta)
    else:
        session = Session(task_id=task_id, system_prompt=system_prompt, meta=meta)

    # The per-user root is handed to every tool (presumably so tool output can
    # be rendered relative to it — confirm in the tool implementations).
    ur_path = user_root(workspace_dir, uid)

    tools = {}
    for cls in (ReadTool, WriteTool, EditTool, GlobTool, GrepTool, ShellTool):
        t = cls(base_dir=tool_base, user_root=ur_path)
        tools[t.name] = t

    # web_fetch is always registered.
    wf = WebFetchTool(base_dir=tool_base, user_root=ur_path)
    tools[wf.name] = wf

    if skills.skills:
        ls = LoadSkillTool(registry=skills, base_dir=tool_base, user_root=ur_path)
        tools[ls.name] = ls

    if caps.enable_run_python:
        rp = RunPythonTool(base_dir=tool_base, user_root=ur_path)
        tools[rp.name] = rp

    # Per-account daily quotas from the `quotas` config section; 0, missing or
    # null all mean "unlimited" as far as this function is concerned.
    quotas = cfg.get("quotas") or {}
    images_per_day = _daily_quota(quotas, "images_per_day")
    videos_per_day = _daily_quota(quotas, "videos_per_day")

    # Media generation tools are only registered when ArkConfig.load() yields
    # a config. The chosen variant is fixed for this build; an unknown
    # requested variant silently falls back to the first configured one.
    ark_cfg = ArkConfig.load()
    if ark_cfg is not None:
        image_cfg = (ark_cfg.raw.get("image") or {})
        chosen_key, chosen_cfg = _select_variant(image_cfg, image_variant)
        if chosen_cfg is not None:
            seedream_tool = SeedreamTool(
                ark_cfg=ark_cfg,
                image_variant_cfg=chosen_cfg,
                variant_key=chosen_key,
                working_dir=working_dir_path,
                task_id=task_id,
                user_id=uid,
                base_dir=tool_base,
                user_root=ur_path,
                daily_limit=images_per_day,
            )
            tools[seedream_tool.name] = seedream_tool

        # Same selection scheme for video; cancel_check is forwarded so the
        # tool can react to a user stop request.
        video_cfg = (ark_cfg.raw.get("video") or {})
        v_chosen_key, v_chosen_cfg = _select_variant(video_cfg, video_variant)
        if v_chosen_cfg is not None:
            seedance_tool = SeedanceTool(
                ark_cfg=ark_cfg,
                video_variant_cfg=v_chosen_cfg,
                variant_key=v_chosen_key,
                working_dir=working_dir_path,
                task_id=task_id,
                user_id=uid,
                base_dir=tool_base,
                user_root=ur_path,
                cancel_check=cancel_check,
                daily_limit=videos_per_day,
            )
            tools[seedance_tool.name] = seedance_tool

    # Web search is only registered when BochaConfig.load() yields a config.
    bocha_cfg = BochaConfig.load()
    if bocha_cfg is not None:
        ws = WebSearchTool(cfg=bocha_cfg)
        tools[ws.name] = ws

    sink = ConsoleEventSink(console) if console else None
    # All tools run through the in-process host executor.
    executor = HostExecutor(tools)
    agent = AgentLoop(
        llm, executor, session, caps,
        user_id=uid, working_dir=working_dir_path, sink=sink,
    )
    if cancel_check is not None:
        agent.cancel_check = cancel_check
    return agent, session, sid, task_state, working_dir_path
|
|
|
|
|
|
def sync_task_tokens(task_state: TaskState) -> None:
    """Recompute a task's cumulative token totals and persist them.

    Sums ``Message.tokens_in`` / ``Message.tokens_out`` over all messages of
    the task (0 when there are none), stores the totals on ``task_state``
    (``tokens_prompt`` / ``tokens_completion``) and writes them to the task
    row via ``update_task``.
    """
    from uuid import UUID

    from sqlalchemy import func, select

    from core.storage import update_task
    from core.storage.engine import session_scope
    from core.storage.models import Message

    task_uuid = UUID(task_state.task_id)
    totals_query = select(
        func.coalesce(func.sum(Message.tokens_in), 0),
        func.coalesce(func.sum(Message.tokens_out), 0),
    ).where(Message.task_id == task_uuid)

    with session_scope() as db:
        sum_in, sum_out = db.execute(totals_query).one()

    prompt_total = int(sum_in)
    completion_total = int(sum_out)
    task_state.tokens_prompt = prompt_total
    task_state.tokens_completion = completion_total
    update_task(
        task_uuid,
        tokens_prompt=prompt_total,
        tokens_completion=completion_total,
    )
|