232 lines
8.3 KiB
Python
232 lines
8.3 KiB
Python
"""装配入口: 读 config → 加载 capabilities/skills → 构造 LLM/tools/session/loop。
|
|
|
|
存储布局(§7 B Step 3 后):
|
|
PG tasks / messages ← Task 元数据 + Session 消息
|
|
workspace/tasks/<task_id>/ ← task_dir,只承担 skill 产物
|
|
task_id 用 UUID,state.json 已删除(元数据全在 PG)。
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Optional, Tuple
|
|
from uuid import UUID, uuid4
|
|
|
|
import yaml
|
|
from rich.console import Console
|
|
|
|
from core.capabilities import ModelCapabilities
|
|
from core.llm import LLM
|
|
from core.loop import AgentLoop
|
|
from core.memory import memory_block
|
|
from core.session import Session
|
|
from core.sinks import ConsoleEventSink
|
|
from core.skills import SkillRegistry
|
|
from core.storage import ensure_local_sentinel
|
|
from core.task import TaskState
|
|
from tools.fs import EditTool, GlobTool, GrepTool, ReadTool, WriteTool
|
|
from tools.run_python import RunPythonTool
|
|
from tools.shell import ShellTool
|
|
from tools.skill_tool import LoadSkillTool
|
|
|
|
# Directory containing this file. The config file, system-prompt template,
# models directory and skills directory used below are all resolved against it.
ROOT = Path(__file__).resolve().parent
|
|
|
|
|
|
def load_config() -> dict:
    """Parse ``config/agent.yaml`` under ``ROOT`` and return its content.

    A document that parses to a falsy value (for example an empty file)
    yields an empty dict instead.
    """
    config_path = ROOT / "config" / "agent.yaml"
    raw_text = config_path.read_text(encoding="utf-8")
    parsed = yaml.safe_load(raw_text)
    return parsed or {}
|
|
|
|
|
|
def resolve_workspace(workspace: Optional[str], cfg: Optional[dict] = None) -> Path:
    """Return the workspace directory, creating it when it does not exist.

    Args:
        workspace: Explicit workspace path. When truthy it is used as given
            and the configuration is not consulted at all.
        cfg: Already-loaded configuration. Only read when ``workspace`` is
            not given; loaded through ``load_config()`` when it is ``None``.

    Returns:
        The workspace path. Without an explicit ``workspace`` this is
        ``ROOT / cfg["workspace_dir"]``, falling back to ``ROOT / "workspace"``.
    """
    if workspace:
        path = Path(workspace)
    else:
        # Load the config lazily: an explicit workspace must not depend on
        # config/agent.yaml being present or parseable. ``is None`` (rather
        # than truthiness) keeps an empty dict passed by the caller from
        # triggering a second read of the same file.
        if cfg is None:
            cfg = load_config()
        path = ROOT / cfg.get("workspace_dir", "workspace")
    path.mkdir(parents=True, exist_ok=True)
    return path
|
|
|
|
|
|
def tasks_dir(workspace_dir: Path) -> Path:
    """Return ``<workspace_dir>/tasks``, creating the directory when missing."""
    target = workspace_dir.joinpath("tasks")
    target.mkdir(parents=True, exist_ok=True)
    return target
|
|
|
|
|
|
def resolve_task_id(
    workspace_dir: Path, task_id_arg: Optional[str], resume: bool
) -> Tuple[UUID, Path]:
    """Pick the task to work on and return ``(task_id, task_dir)``.

    New task (``resume`` false): a fresh ``uuid4()``. Only the parent
    ``tasks/`` directory is created (by ``tasks_dir``); the per-task
    directory ``workspace/tasks/<uuid>/`` is not created here.

    Resume: a ``task_id_arg`` of ``None``, ``""`` or ``"last"`` selects the
    row with the greatest ``Task.updated_at``; any other value is handed to
    ``_resolve_uuid_or_prefix`` (full UUID or prefix).

    Raises:
        FileNotFoundError: "last" was requested but the tasks table is empty.
    """
    root = tasks_dir(workspace_dir)

    if not resume:
        fresh_id = uuid4()
        return fresh_id, root / str(fresh_id)

    # Storage imports are kept local to the resume path, as in the original.
    from sqlalchemy import select
    from core.storage import session_scope
    from core.storage.models import Task

    if task_id_arg not in (None, "", "last"):
        # Accepts a full UUID or a prefix.
        resolved = _resolve_uuid_or_prefix(task_id_arg)
        return resolved, root / str(resolved)

    newest_first = select(Task.task_id).order_by(Task.updated_at.desc()).limit(1)
    with session_scope() as db:
        latest_id = db.execute(newest_first).scalar_one_or_none()
    if latest_id is None:
        raise FileNotFoundError("no recoverable task: PG tasks 表为空")
    return latest_id, root / str(latest_id)
|
|
|
|
|
|
def _resolve_uuid_or_prefix(s: str) -> UUID:
|
|
"""完整 UUID 字符串直接解析;否则当前缀,从 tasks 表精确匹配一个。"""
|
|
try:
|
|
return UUID(s)
|
|
except ValueError:
|
|
pass
|
|
from sqlalchemy import cast, String, select
|
|
from core.storage import session_scope
|
|
from core.storage.models import Task
|
|
|
|
with session_scope() as sess:
|
|
matches = sess.execute(
|
|
select(Task.task_id).where(cast(Task.task_id, String).like(f"{s}%"))
|
|
).scalars().all()
|
|
if not matches:
|
|
raise FileNotFoundError(f"no task matching prefix: {s}")
|
|
if len(matches) > 1:
|
|
raise ValueError(f"ambiguous prefix {s!r}, matched {len(matches)} tasks")
|
|
return matches[0]
|
|
|
|
|
|
def _build_system_prompt(
    cfg: dict,
    skills: SkillRegistry,
    workspace_dir: Path,
    tool_base: Path,
    task_dir: Path,
) -> str:
    """Assemble the system prompt text.

    Sections, in order: the template file at ``ROOT / cfg["system_prompt"]``,
    the skill discovery list (only when the registry holds skills), the
    output of ``memory_block(workspace_dir)``, and a working-directory
    section naming ``tool_base`` and the resolved absolute ``task_dir``.

    ``build_agent`` calls this for new and resumed tasks alike, so the memory
    section is rebuilt on every call.
    """
    template_path = ROOT / cfg["system_prompt"]
    sections = [template_path.read_text(encoding="utf-8")]

    if skills.skills:
        sections.append(
            f"\n\n## 可用 skill (用 load_skill 加载完整指引)\n{skills.discovery_block()}"
        )

    sections.append(memory_block(workspace_dir))

    # The prompt always shows the absolute task directory.
    out_dir = task_dir.resolve()
    sections.append(
        "\n\n## 工作目录\n"
        f"- cwd(用户启动时所在目录,只读用): `{tool_base}`\n"
        f"- **task_dir(所有产物写到这里)**: `{out_dir}`\n\n"
        "SKILL 文档里出现的 `<task_dir>` 占位符,一律指上面这个绝对路径。"
        f"产物示例: `{out_dir}/spec_lock.md`、"
        f"`{out_dir}/sections/01_summary.md`、"
        f"`{out_dir}/slides/`、最终 .docx/.pptx。\n"
        "⛔ 不要把产物写到 cwd / `skills/` / repo 根 —— 只写到 task_dir。"
    )
    return "".join(sections)
|
|
|
|
|
|
def build_agent(
    model_name: Optional[str] = None,
    workspace: Optional[str] = None,
    console: Optional[Console] = None,
    session_id: Optional[str] = None,
    resume: bool = False,
    tool_base: Optional[Path] = None,
    mode: str = "",
    description: str = "",
) -> Tuple[AgentLoop, Session, str, TaskState, Path]:
    """Wire up everything needed to run one task.

    Args:
        model_name: Model profile name; defaults to ``cfg["default_model"]``.
        workspace: Explicit workspace path, passed to ``resolve_workspace``.
        console: Rich console. When given, events are reported through a
            ``ConsoleEventSink``; otherwise the loop gets no sink.
        session_id: Task id (or prefix / ``"last"``) used when resuming.
        resume: Load an existing session instead of starting a new one.
        tool_base: Base directory handed to every tool; defaults to the
            current working directory.
        mode: Stored in the session meta and the task state.
        description: Stored in the session meta and the task state.

    Returns:
        ``(agent, session, task_id_str, task_state, task_dir)``.
    """
    cfg = load_config()
    model = model_name or cfg["default_model"]

    # Register the local sentinel user (idempotent); build_agent is the entry
    # point of every task operation.
    ensure_local_sentinel()

    caps = ModelCapabilities.load(model, ROOT / cfg["models_dir"])
    llm = LLM(caps)

    workspace_dir = resolve_workspace(workspace, cfg)
    task_id, task_dir = resolve_task_id(workspace_dir, session_id, resume)
    sid = str(task_id)

    tool_base = Path(tool_base) if tool_base else Path.cwd()

    skills = SkillRegistry(ROOT / cfg.get("skills_dir", "skills"))
    system_prompt = _build_system_prompt(cfg, skills, workspace_dir, tool_base, task_dir)

    reasoning_effort = caps.default_reasoning_effort or ""
    now_iso = datetime.now().isoformat(timespec="seconds")
    meta = {
        "id": sid,
        "created_at": now_iso,
        "cwd": str(tool_base),
        "task_dir": str(task_dir),
        "model": caps.model_id,
        "model_profile": model,
        "mode": mode,
        "description": description,
        "reasoning_effort": reasoning_effort,
    }

    if resume:
        session = Session.load(task_id, system_prompt=system_prompt, meta=meta)
        # The load result is None when there is no tasks row. That can happen
        # if the row was deleted concurrently, or if a full UUID was passed
        # (resolve_task_id does not look a full UUID up in the database).
        task_state = TaskState.load(task_id)
    else:
        session = Session(task_id=task_id, system_prompt=system_prompt, meta=meta)
        task_state = None

    if task_state is None:
        # In-memory state only; nothing is saved here. Per the original
        # author's note, the tasks row is inserted as a placeholder by
        # ensure_local_task_row when the first user message is appended, and
        # the first sync_task_tokens or /done /desc upsert overwrites it.
        # The new-task path and the resume fallback share this construction,
        # so both carry the same reasoning_effort that goes into ``meta``.
        task_state = TaskState(
            task_id=sid,
            task_dir=str(task_dir),
            mode=mode,
            description=description,
            status="active",
            model=caps.model_id,
            model_profile=model,
            reasoning_effort=reasoning_effort,
        )

    tools = {}
    for tool_cls in (ReadTool, WriteTool, EditTool, GlobTool, GrepTool, ShellTool):
        tool = tool_cls(base_dir=tool_base)
        tools[tool.name] = tool

    # load_skill is only offered when there is at least one skill to load.
    if skills.skills:
        skill_loader = LoadSkillTool(registry=skills, base_dir=tool_base)
        tools[skill_loader.name] = skill_loader

    if caps.enable_run_python:
        python_runner = RunPythonTool(base_dir=tool_base)
        tools[python_runner.name] = python_runner

    sink = (
        ConsoleEventSink(console, token_counter=lambda: llm.token_counter.total)
        if console
        else None
    )
    agent = AgentLoop(llm, tools, session, caps, sink=sink)
    return agent, session, sid, task_state, task_dir
|
|
|
|
|
|
def sync_task_tokens(task_state: TaskState, llm: LLM) -> None:
    """Copy the LLM's cumulative token counts onto the task and persist them.

    Intended to be called after each ``agent.run`` round. The counts are
    written to ``task_state`` and then sent through ``update_task`` with only
    the two token fields, instead of a full ``task_state.save()``. Per the
    original author's note, that ORM-level update also refreshes
    ``updated_at`` (not visible from this file).
    """
    from core.storage import update_task

    counter = llm.token_counter
    prompt_used = counter.prompt_tokens
    completion_used = counter.completion_tokens

    task_state.tokens_prompt = prompt_used
    task_state.tokens_completion = completion_used

    # task_state.task_id is kept as a string; update_task is given a UUID.
    update_task(
        UUID(task_state.task_id),
        tokens_prompt=prompt_used,
        tokens_completion=completion_used,
    )
|