zcbot/main.py

232 lines
8.3 KiB
Python

"""装配入口: 读 config → 加载 capabilities/skills → 构造 LLM/tools/session/loop。
存储布局(§7 B Step 3 后):
PG tasks / messages ← Task 元数据 + Session 消息
workspace/tasks/<task_id>/ ← task_dir,只承担 skill 产物
task_id 用 UUID,state.json 已删除(元数据全在 PG)。
"""
from __future__ import annotations
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple
from uuid import UUID, uuid4
import yaml
from rich.console import Console
from core.capabilities import ModelCapabilities
from core.llm import LLM
from core.loop import AgentLoop
from core.memory import memory_block
from core.session import Session
from core.sinks import ConsoleEventSink
from core.skills import SkillRegistry
from core.storage import ensure_local_sentinel
from core.task import TaskState
from tools.fs import EditTool, GlobTool, GrepTool, ReadTool, WriteTool
from tools.run_python import RunPythonTool
from tools.shell import ShellTool
from tools.skill_tool import LoadSkillTool
# Absolute directory holding this module; config/, skills/, the models dir, the
# default workspace and the system-prompt template are all resolved against it.
ROOT = Path(__file__).resolve().parents[0]
def load_config() -> dict:
    """Parse ``<ROOT>/config/agent.yaml`` with ``yaml.safe_load``.

    An empty (or otherwise falsy) YAML document is normalised to ``{}``.
    """
    config_file = ROOT.joinpath("config", "agent.yaml")
    parsed = yaml.safe_load(config_file.read_text(encoding="utf-8"))
    return parsed if parsed else {}
def resolve_workspace(workspace: Optional[str], cfg: Optional[dict] = None) -> Path:
    """Return the workspace directory, creating it (and parents) if missing.

    Args:
        workspace: explicit workspace path. When falsy, the path comes from
            ``cfg["workspace_dir"]`` (default ``"workspace"``), relative to ROOT.
        cfg: already-parsed config. ``None`` means "load it with load_config()";
            an explicitly passed dict -- even an empty one -- is used as-is.

    Returns:
        The workspace directory as a ``Path``.
    """
    if workspace:
        # An explicit workspace needs no config, so don't read (or fail on) it.
        path = Path(workspace)
    else:
        if cfg is None:
            cfg = load_config()
        path = ROOT / cfg.get("workspace_dir", "workspace")
    path.mkdir(parents=True, exist_ok=True)
    return path
def tasks_dir(workspace_dir: Path) -> Path:
    """Return ``<workspace_dir>/tasks``, making sure the directory exists."""
    target = workspace_dir.joinpath("tasks")
    target.mkdir(exist_ok=True, parents=True)
    return target
def resolve_task_id(
    workspace_dir: Path, task_id_arg: Optional[str], resume: bool
) -> Tuple[UUID, Path]:
    """Return ``(task_id, task_dir)`` for a new or resumed task.

    New task: a fresh ``uuid4()`` and ``workspace/tasks/<uuid>/``. Only the
    ``tasks/`` parent is ensured (via ``tasks_dir``); the per-task directory is
    not created here.

    Resume: a ``task_id_arg`` of ``None``, ``""`` or ``"last"`` selects the row
    with the newest ``Task.updated_at`` in the PG ``tasks`` table. Any other
    value is resolved by ``_resolve_uuid_or_prefix`` (full UUID or prefix).

    Raises:
        FileNotFoundError: resuming the latest task while the tasks table is
            empty, or (from ``_resolve_uuid_or_prefix``) a prefix with no match.
        ValueError: (from ``_resolve_uuid_or_prefix``) an ambiguous prefix.
    """
    base = tasks_dir(workspace_dir)
    if not resume:
        fresh = uuid4()
        return fresh, base / str(fresh)

    # DB dependencies are imported lazily, only on the resume path.
    from sqlalchemy import select
    from core.storage import session_scope
    from core.storage.models import Task

    wants_latest = task_id_arg in (None, "", "last")
    if not wants_latest:
        # Full UUID or a prefix (8 characters is enough to disambiguate locally).
        resolved = _resolve_uuid_or_prefix(task_id_arg)
        return resolved, base / str(resolved)

    with session_scope() as db:
        # NOTE(review): PostgreSQL sorts NULLs first under DESC; if updated_at
        # can be NULL, such a row would be chosen here -- confirm NOT NULL.
        latest = db.execute(
            select(Task.task_id).order_by(Task.updated_at.desc()).limit(1)
        ).scalar_one_or_none()
    if latest is None:
        raise FileNotFoundError("no recoverable task: PG tasks 表为空")
    return latest, base / str(latest)
def _resolve_uuid_or_prefix(s: str) -> UUID:
"""完整 UUID 字符串直接解析;否则当前缀,从 tasks 表精确匹配一个。"""
try:
return UUID(s)
except ValueError:
pass
from sqlalchemy import cast, String, select
from core.storage import session_scope
from core.storage.models import Task
with session_scope() as sess:
matches = sess.execute(
select(Task.task_id).where(cast(Task.task_id, String).like(f"{s}%"))
).scalars().all()
if not matches:
raise FileNotFoundError(f"no task matching prefix: {s}")
if len(matches) > 1:
raise ValueError(f"ambiguous prefix {s!r}, matched {len(matches)} tasks")
return matches[0]
def _build_system_prompt(
    cfg: dict,
    skills: SkillRegistry,
    workspace_dir: Path,
    tool_base: Path,
    task_dir: Path,
) -> str:
    """Assemble the system prompt.

    The parts, in order, are: the template file named by ``cfg["system_prompt"]``
    (relative to ROOT), the skill discovery list (only when skills exist), the
    workspace memory block, and a working-directory section.

    New and resumed tasks both go through here, so memory changes are picked
    up every time the prompt is rebuilt.
    """
    sections = [(ROOT / cfg["system_prompt"]).read_text(encoding="utf-8")]
    if skills.skills:
        catalogue = skills.discovery_block()
        sections.append(f"\n\n## 可用 skill (用 load_skill 加载完整指引)\n{catalogue}")
    sections.append(memory_block(workspace_dir))
    # Absolute path, so SKILL docs' `<task_dir>` placeholder is unambiguous.
    out_dir = task_dir.resolve()
    sections.append(
        "\n\n## 工作目录\n"
        f"- cwd(用户启动时所在目录,只读用): `{tool_base}`\n"
        f"- **task_dir(所有产物写到这里)**: `{out_dir}`\n\n"
        "SKILL 文档里出现的 `<task_dir>` 占位符,一律指上面这个绝对路径。"
        f"产物示例: `{out_dir}/spec_lock.md`、"
        f"`{out_dir}/sections/01_summary.md`、"
        f"`{out_dir}/slides/`、最终 .docx/.pptx。\n"
        "⛔ 不要把产物写到 cwd / `skills/` / repo 根 —— 只写到 task_dir。"
    )
    return "".join(sections)
def build_agent(
    model_name: Optional[str] = None,
    workspace: Optional[str] = None,
    console: Optional[Console] = None,
    session_id: Optional[str] = None,
    resume: bool = False,
    tool_base: Optional[Path] = None,
    mode: str = "",
    description: str = "",
) -> Tuple[AgentLoop, Session, str, TaskState, Path]:
    """Wire up an AgentLoop for a new or resumed task.

    Args:
        model_name: model profile name; defaults to ``cfg["default_model"]``.
        workspace: workspace directory override (see ``resolve_workspace``).
        console: when given, events are rendered through a ``ConsoleEventSink``.
        session_id: task id, id prefix, or "last"; used when ``resume`` is true.
        resume: load the existing Session/TaskState instead of creating new ones.
        tool_base: base directory for the fs/shell/python tools; defaults to cwd.
        mode, description: recorded in the session meta and on a freshly built
            TaskState.

    Returns:
        ``(agent, session, task_id_str, task_state, task_dir)``.
    """
    cfg = load_config()
    model = model_name or cfg["default_model"]
    # Insert the local sentinel user (idempotent); build_agent is the entry
    # point for every task operation.
    ensure_local_sentinel()
    caps = ModelCapabilities.load(model, ROOT / cfg["models_dir"])
    llm = LLM(caps)
    workspace_dir = resolve_workspace(workspace, cfg)
    task_id, task_dir = resolve_task_id(workspace_dir, session_id, resume)
    sid = str(task_id)
    tool_base = Path(tool_base) if tool_base else Path.cwd()
    skills = SkillRegistry(ROOT / cfg.get("skills_dir", "skills"))
    system_prompt = _build_system_prompt(cfg, skills, workspace_dir, tool_base, task_dir)
    reasoning_effort = caps.default_reasoning_effort or ""

    now_iso = datetime.now().isoformat(timespec="seconds")
    # NOTE(review): on resume this meta also carries the current time as
    # created_at; whether Session.load keeps the stored value isn't visible here.
    meta = {
        "id": sid,
        "created_at": now_iso,
        "cwd": str(tool_base),
        "task_dir": str(task_dir),
        "model": caps.model_id,
        "model_profile": model,
        "mode": mode,
        "description": description,
        "reasoning_effort": reasoning_effort,
    }

    def fresh_task_state() -> TaskState:
        """Build an in-memory TaskState for a task with no tasks row (not saved)."""
        # Both the new-task path and the resume fallback use this, so they
        # record the same fields (including reasoning_effort).
        return TaskState(
            task_id=sid, task_dir=str(task_dir),
            mode=mode, description=description, status="active",
            model=caps.model_id, model_profile=model,
            reasoning_effort=reasoning_effort,
        )

    if resume:
        session = Session.load(task_id, system_prompt=system_prompt, meta=meta)
        task_state = TaskState.load(task_id)
        if task_state is None:
            # No tasks row. For "last" or a prefix, resolve_task_id found the
            # row in PG, so it was deleted concurrently. A full UUID, however,
            # is parsed without any DB lookup, so an unknown id also ends up
            # here. Fall back to an unsaved empty state; it is not saved
            # proactively and waits for the next append / command.
            task_state = fresh_task_state()
    else:
        session = Session(task_id=task_id, system_prompt=system_prompt, meta=meta)
        # Lazy creation: the TaskState lives only in memory. The tasks row is
        # inserted as a placeholder by ensure_local_task_row when the first
        # user message is appended; the first sync_task_tokens or /done /desc
        # then overwrites it.
        task_state = fresh_task_state()

    tools = {}
    for tool_cls in (ReadTool, WriteTool, EditTool, GlobTool, GrepTool, ShellTool):
        tool = tool_cls(base_dir=tool_base)
        tools[tool.name] = tool
    if skills.skills:
        skill_tool = LoadSkillTool(registry=skills, base_dir=tool_base)
        tools[skill_tool.name] = skill_tool
    if caps.enable_run_python:
        py_tool = RunPythonTool(base_dir=tool_base)
        tools[py_tool.name] = py_tool

    sink = ConsoleEventSink(console, token_counter=lambda: llm.token_counter.total) if console else None
    agent = AgentLoop(llm, tools, session, caps, sink=sink)
    return agent, session, sid, task_state, task_dir
def sync_task_tokens(task_state: TaskState, llm: LLM) -> None:
    """Copy the LLM's cumulative token counts into task_state and the PG tasks row.

    Call this after every ``agent.run``. It uses ``update_task`` rather than
    ``task_state.save()``, so only the two token columns are written instead of
    a full-row UPSERT; the design note says the ORM-level update also refreshes
    ``updated_at``.
    """
    from core.storage import update_task

    counter = llm.token_counter
    task_state.tokens_prompt = counter.prompt_tokens
    task_state.tokens_completion = counter.completion_tokens
    # build_agent stores task_id as a str. The field type returned by
    # TaskState.load isn't visible here, so go through str(): UUID(str(x))
    # accepts both a str and a UUID, whereas UUID(<UUID>) raises.
    update_task(
        UUID(str(task_state.task_id)),
        tokens_prompt=counter.prompt_tokens,
        tokens_completion=counter.completion_tokens,
    )