zcbot/main.py

286 lines
10 KiB
Python

"""装配入口: 读 config → 加载 capabilities/skills → 构造 LLM/tools/session/loop。
存储布局(§7 B Step 3 后):
PG tasks / messages ← Task 元数据 + Session 消息
workspace/tasks/<task_id>/ ← task_dir,只承担 skill 产物
task_id 用 UUID,state.json 已删除(元数据全在 PG)。
"""
from __future__ import annotations
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple
from uuid import UUID, uuid4
import yaml
from rich.console import Console
from core.capabilities import ModelCapabilities
from core.llm import LLM
from core.loop import AgentLoop
from core.memory import memory_block
from core.paths import ROOT, from_db_path, to_db_path
from core.session import Session
from core.sinks import ConsoleEventSink
from core.skills import SkillRegistry
from core.storage import check_no_subtask, ensure_local_sentinel
from core.task import TaskState
from tools.fs import EditTool, GlobTool, GrepTool, ReadTool, WriteTool
from tools.run_python import RunPythonTool
from tools.shell import ShellTool
from tools.skill_tool import LoadSkillTool
def load_config() -> dict:
    """Read ``config/agent.yaml`` under ROOT with ``yaml.safe_load``.

    Returns:
        The parsed document, or ``{}`` when the document is empty or otherwise
        parses to a falsy value.
    """
    config_path = ROOT / "config" / "agent.yaml"
    parsed = yaml.safe_load(config_path.read_text(encoding="utf-8"))
    return parsed or {}
def resolve_workspace(workspace: Optional[str], cfg: Optional[dict] = None) -> Path:
    """Return the workspace directory, creating it (and parents) if missing.

    Args:
        workspace: Explicit workspace path. When falsy (None or ""), the config's
            ``workspace_dir`` entry (default ``"workspace"``) is joined onto ROOT.
        cfg: Already-loaded config dict. Only ``None`` triggers a fresh
            ``load_config()``; an empty dict is honoured as "all defaults"
            instead of silently re-reading the YAML file.

    A leading ``~`` is expanded in both sources, the same way explicit task_dir
    arguments are expanded. Without that, a literal ``~`` directory would be
    created under the current directory or under ROOT.

    Returns:
        The workspace path. It is not resolved, so a relative ``workspace``
        stays relative to the current working directory.
    """
    if cfg is None:
        cfg = load_config()
    if workspace:
        p = Path(workspace).expanduser()
    else:
        # An absolute configured path wins over ROOT because of Path's "/" semantics.
        p = ROOT / Path(cfg.get("workspace_dir", "workspace")).expanduser()
    p.mkdir(parents=True, exist_ok=True)
    return p
def tasks_dir(workspace_dir: Path) -> Path:
    """Return ``<workspace_dir>/tasks``, creating it and any missing parents first."""
    target = workspace_dir.joinpath("tasks")
    target.mkdir(exist_ok=True, parents=True)
    return target
def _default_task_dir(workspace_dir: Path, task_id: UUID) -> Path:
    """Return the derived default task_dir ``<workspace_dir>/tasks/<task_id>``.

    Only the parent ``tasks`` directory is created (via ``tasks_dir``). The
    per-task directory itself is not created here.
    """
    parent = tasks_dir(workspace_dir)
    return parent.joinpath(f"{task_id}")
def is_managed_task_dir(task_dir: Path, workspace_dir: Path) -> bool:
    """Return True only if ``task_dir`` is exactly ``<workspace>/tasks/<uuid>/``.

    Serves as the safety switch for ``_cleanup_if_empty``: a project directory
    the user specified themselves must never be rmtree'd. Both paths are
    resolved before comparison. ``task_dir`` must sit directly inside the tasks
    directory, and its name must be the canonical ``str(uuid)`` spelling that
    ``_default_task_dir`` produces.

    Side effect: ``tasks_dir`` creates ``<workspace>/tasks`` if it is missing.

    Returns:
        False whenever the path is outside the tasks directory, nested deeper,
        not a UUID, not canonically spelled, or resolution raises OSError.
    """
    try:
        rel = task_dir.resolve().relative_to(tasks_dir(workspace_dir).resolve())
    except (ValueError, OSError):
        return False
    if len(rel.parts) != 1:
        return False
    name = rel.parts[0]
    try:
        parsed = UUID(name)
    except ValueError:
        return False
    # uuid.UUID() also accepts hyphen-less hex, "{...}" braces and "urn:uuid:"
    # prefixes. Derived task dirs are always named str(uuid) (lowercase,
    # hyphenated), so require that exact spelling. A lookalike user directory
    # then never qualifies for deletion.
    return str(parsed) == name
def _explicit_task_dir(task_dir_arg: Optional[str]) -> Optional[Path]:
    """Return the user-supplied task_dir expanded and resolved, or None if absent/blank."""
    if not task_dir_arg or not task_dir_arg.strip():
        return None
    return Path(task_dir_arg).expanduser().resolve()


def resolve_task_id(
    workspace_dir: Path,
    task_id_arg: Optional[str],
    resume: bool,
    task_dir_arg: Optional[str] = None,
) -> Tuple[UUID, Path]:
    """Return ``(task_id, task_dir)`` for a new or resumed task.

    New task (``resume`` False):
        ``task_id_arg`` is ignored and a fresh ``uuid4`` is generated. task_dir is
        ``task_dir_arg`` (expanded and resolved) when it is a non-blank string.
        Otherwise it is the default ``<workspace>/tasks/<uuid>``.

    Resume:
        ``task_id_arg`` of None, "" or "last" selects the most recently updated
        row of the PG ``tasks`` table. Any other value goes through
        ``_resolve_uuid_or_prefix``. The task_dir precedence is:

        1. A non-blank ``task_dir_arg``. This lets the user re-bind the path; the
           caller must UPSERT it themselves.
        2. The stored ``tasks.task_dir``. It is kept in DB form (relative to
           ROOT or absolute) and is converted back with ``from_db_path``.
        3. The default derivation. An empty stored value means the task was
           created without an explicit dir (old / pre-Step-3 data).

    NOTE(review): a syntactically valid full UUID with no tasks row is not
    rejected. ``_resolve_uuid_or_prefix`` returns it without a lookup, the
    task_dir query yields nothing, and the default task_dir is used.

    Raises:
        FileNotFoundError: "last" was requested but the tasks table is empty.
        Errors raised by ``_resolve_uuid_or_prefix`` propagate unchanged.
    """
    if not resume:
        new_id = uuid4()
        chosen = _explicit_task_dir(task_dir_arg)
        if chosen is None:
            chosen = _default_task_dir(workspace_dir, new_id)
        return new_id, chosen

    # DB dependencies are only needed (and only imported) on the resume path.
    from sqlalchemy import select
    from core.storage import session_scope
    from core.storage.models import Task

    if task_id_arg in (None, "", "last"):
        with session_scope() as s:
            latest = s.execute(
                select(Task.task_id, Task.task_dir)
                .order_by(Task.updated_at.desc())
                .limit(1)
            ).first()
        if latest is None:
            raise FileNotFoundError("no recoverable task: PG tasks 表为空")
        tid, stored_dir = latest
    else:
        tid = _resolve_uuid_or_prefix(task_id_arg)
        with session_scope() as s:
            stored_dir = s.execute(
                select(Task.task_dir).where(Task.task_id == tid)
            ).scalar_one_or_none() or ""

    chosen = _explicit_task_dir(task_dir_arg)
    if chosen is not None:
        return tid, chosen
    if stored_dir:
        return tid, from_db_path(stored_dir)
    return tid, _default_task_dir(workspace_dir, tid)
def _resolve_uuid_or_prefix(s: str) -> UUID:
    """Resolve a user-supplied task reference to a task UUID.

    A string that parses as a UUID is returned directly, without a DB lookup.
    Anything else is treated as a prefix of the canonical text form of
    ``tasks.task_id`` (lowercase hex with hyphens) and must match exactly one
    row.

    Raises:
        FileNotFoundError: The prefix is empty, contains characters that can
            never appear in canonical UUID text, or matches no task.
        ValueError: The prefix matches more than one task.
    """
    try:
        return UUID(s)
    except ValueError:
        pass
    # Normalise to canonical UUID text. Lowercasing makes prefixes as
    # case-insensitive as the full-UUID path above. Allowing only hex and '-'
    # also keeps LIKE wildcards ('%', '_') out of the pattern. Without this, an
    # empty prefix or a bare '%' would match arbitrary tasks.
    prefix = s.strip().lower()
    if not prefix or not set(prefix) <= set("0123456789abcdef-"):
        raise FileNotFoundError(f"no task matching prefix: {s}")

    from sqlalchemy import cast, String, select
    from core.storage import session_scope
    from core.storage.models import Task

    with session_scope() as sess:
        matches = sess.execute(
            select(Task.task_id).where(cast(Task.task_id, String).like(f"{prefix}%"))
        ).scalars().all()
    if not matches:
        raise FileNotFoundError(f"no task matching prefix: {s}")
    if len(matches) > 1:
        raise ValueError(f"ambiguous prefix {s!r}, matched {len(matches)} tasks")
    return matches[0]
def _build_system_prompt(
    cfg: dict,
    skills: SkillRegistry,
    workspace_dir: Path,
    tool_base: Path,
    task_dir: Path,
) -> str:
    """Assemble the system prompt shared by new and resumed tasks.

    Pieces, concatenated in this order:
      1. the template file named by ``cfg["system_prompt"]`` (relative to ROOT);
      2. the skill discovery list, only when any skills are registered;
      3. ``memory_block(workspace_dir)``;
      4. a working-directory section naming ``tool_base`` as the read-only cwd and
         the resolved ``task_dir`` as the only place outputs may be written.

    The prompt is rebuilt on every call, including on resume, so memory changes
    take effect immediately.
    """
    sections = [(ROOT / cfg["system_prompt"]).read_text(encoding="utf-8")]
    if skills.skills:
        sections.append(
            f"\n\n## 可用 skill (用 load_skill 加载完整指引)\n{skills.discovery_block()}"
        )
    sections.append(memory_block(workspace_dir))
    out_dir = task_dir.resolve()
    # The literal "<task_dir>" placeholder used by SKILL docs is mapped to out_dir here.
    sections.append(
        f"\n\n## 工作目录\n"
        f"- cwd(用户启动时所在目录,只读用): `{tool_base}`\n"
        f"- **task_dir(所有产物写到这里)**: `{out_dir}`\n\n"
        f"SKILL 文档里出现的 `<task_dir>` 占位符,一律指上面这个绝对路径。"
        f"产物示例: `{out_dir}/spec_lock.md`、"
        f"`{out_dir}/sections/01_summary.md`、"
        f"`{out_dir}/slides/`、最终 .docx/.pptx。\n"
        f"⛔ 不要把产物写到 cwd / `skills/` / repo 根 —— 只写到 task_dir。"
    )
    return "".join(sections)
def _fresh_task_state(
    sid: str,
    task_dir_db: str,
    mode: str,
    description: str,
    caps: ModelCapabilities,
    model: str,
) -> TaskState:
    """Build an unsaved, in-memory TaskState seeded from the resolved model caps."""
    return TaskState(
        task_id=sid,
        task_dir=task_dir_db,
        mode=mode,
        description=description,
        status="active",
        model=caps.model_id,
        model_profile=model,
        reasoning_effort=caps.default_reasoning_effort or "",
    )


def _build_tools(skills: SkillRegistry, caps: ModelCapabilities, tool_base: Path) -> dict:
    """Instantiate the agent's tools rooted at tool_base, keyed by each tool's ``name``."""
    tools = {}
    for cls in (ReadTool, WriteTool, EditTool, GlobTool, GrepTool, ShellTool):
        tool = cls(base_dir=tool_base)
        tools[tool.name] = tool
    # load_skill is only registered when at least one skill exists.
    if skills.skills:
        loader = LoadSkillTool(registry=skills, base_dir=tool_base)
        tools[loader.name] = loader
    # run_python is opt-in per model profile.
    if caps.enable_run_python:
        runner = RunPythonTool(base_dir=tool_base)
        tools[runner.name] = runner
    return tools


def build_agent(
    model_name: Optional[str] = None,
    workspace: Optional[str] = None,
    console: Optional[Console] = None,
    session_id: Optional[str] = None,
    resume: bool = False,
    tool_base: Optional[Path] = None,
    mode: str = "",
    description: str = "",
    task_dir_arg: Optional[str] = None,
) -> Tuple[AgentLoop, Session, str, TaskState, Path]:
    """Assemble an AgentLoop for a new or resumed task.

    Args:
        model_name: Model profile name; defaults to ``cfg["default_model"]``.
        workspace: Workspace directory override (see ``resolve_workspace``).
        console: When given, a ConsoleEventSink reporting the LLM token total is
            attached; otherwise the loop gets no sink.
        session_id: Task id, unique prefix, or "last" when resuming. It is
            ignored for new tasks, which always get a fresh uuid4.
        resume: Load an existing task instead of creating one.
        tool_base: Base directory for the tools; defaults to ``Path.cwd()``.
        mode, description: Stored in the session meta and in any TaskState
            built here.
        task_dir_arg: Explicit task_dir (see ``resolve_task_id``).

    Returns:
        ``(agent, session, task_id_str, task_state, task_dir)``, where task_dir
        is the filesystem path. The DB form is used only in meta/TaskState.
    """
    cfg = load_config()
    model = model_name or cfg["default_model"]
    # Insert the local sentinel user (idempotent). build_agent is the entry
    # point for every task operation.
    ensure_local_sentinel()
    caps = ModelCapabilities.load(model, ROOT / cfg["models_dir"])
    llm = LLM(caps)
    workspace_dir = resolve_workspace(workspace, cfg)
    task_id, task_dir = resolve_task_id(workspace_dir, session_id, resume, task_dir_arg)
    sid = str(task_id)
    # §7.4 no-subtask: for new tasks, ensure task_dir does not form a prefix
    # nesting with an existing task of the same user. Resume skips this because
    # the task is already persisted; renames go through the Folder API cascade.
    # NOTE(review): the absolute path is passed here while tasks.task_dir stores
    # to_db_path form; confirm check_no_subtask normalises both sides.
    if not resume:
        check_no_subtask(str(task_dir))
    tool_base = Path(tool_base) if tool_base else Path.cwd()
    skills = SkillRegistry(ROOT / cfg.get("skills_dir", "skills"))
    system_prompt = _build_system_prompt(cfg, skills, workspace_dir, tool_base, task_dir)
    now_iso = datetime.now().isoformat(timespec="seconds")
    # meta["task_dir"] is in DB form (relative to ROOT or absolute). Per the
    # original design, Session.append -> ensure_local_task_row writes it straight
    # to PG tasks.task_dir, so convert it here. Filesystem work keeps using the
    # task_dir path returned by resolve_task_id.
    task_dir_db = to_db_path(task_dir)
    meta = {
        "id": sid,
        "created_at": now_iso,
        "cwd": str(tool_base),
        "task_dir": task_dir_db,
        "model": caps.model_id,
        "model_profile": model,
        "mode": mode,
        "description": description,
        "reasoning_effort": caps.default_reasoning_effort or "",
    }
    if resume:
        session = Session.load(task_id, system_prompt=system_prompt, meta=meta)
        task_state = TaskState.load(task_id)
        if task_state is None:
            # No tasks row to load. The original author expected this only when
            # the row was deleted concurrently, but a full UUID that never had a
            # row also lands here. Fall back to a fresh, unsaved state built the
            # same way as for new tasks, so reasoning_effort is not lost. It is
            # not saved here; the next append/command persists it.
            task_state = _fresh_task_state(sid, task_dir_db, mode, description, caps, model)
    else:
        session = Session(task_id=task_id, system_prompt=system_prompt, meta=meta)
        # Lazy creation: the TaskState lives only in memory. Per the original
        # design, ensure_local_task_row inserts a placeholder tasks row when the
        # first user message is appended. Later writes (sync_task_tokens,
        # /done, /desc) then fill it in.
        task_state = _fresh_task_state(sid, task_dir_db, mode, description, caps, model)
    tools = _build_tools(skills, caps, tool_base)
    sink = (
        ConsoleEventSink(console, token_counter=lambda: llm.token_counter.total)
        if console
        else None
    )
    agent = AgentLoop(llm, tools, session, caps, sink=sink)
    return agent, session, sid, task_state, task_dir
def sync_task_tokens(task_state: TaskState, llm: LLM) -> None:
    """Persist the LLM's cumulative token counts for this task.

    Meant to be called after each ``agent.run``. It first mirrors the prompt and
    completion counters onto ``task_state``. It then writes only those two
    columns through ``update_task``, rather than doing a full
    ``task_state.save()`` UPSERT. The original author notes that this ORM-level
    update also bumps ``updated_at``; that cannot be verified from this file.
    """
    from core.storage import update_task

    counter = llm.token_counter
    prompt_tokens = counter.prompt_tokens
    completion_tokens = counter.completion_tokens
    task_state.tokens_prompt = prompt_tokens
    task_state.tokens_completion = completion_tokens
    # task_state.task_id is held as a string; the storage layer is keyed by UUID.
    update_task(
        UUID(task_state.task_id),
        tokens_prompt=prompt_tokens,
        tokens_completion=completion_tokens,
    )