229 lines
8.3 KiB
Python
229 lines
8.3 KiB
Python
"""装配入口: 读 config → 加载 capabilities/skills → 构造 LLM/tools/session/loop。
|
|
|
|
存储布局(§7 B Step 2 后):
|
|
workspace/tasks/<task_id>/state.json ← TaskState(Step 3 前还在,Step 3 删)
|
|
PG messages ← Session 消息(Step 2 切换)
|
|
task_id 用 UUID,task_dir = workspace/tasks/<task_id>/。
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Optional, Tuple
|
|
from uuid import UUID, uuid4
|
|
|
|
import yaml
|
|
from rich.console import Console
|
|
|
|
from core.capabilities import ModelCapabilities
|
|
from core.llm import LLM
|
|
from core.loop import AgentLoop
|
|
from core.memory import memory_block
|
|
from core.session import Session
|
|
from core.sinks import ConsoleEventSink
|
|
from core.skills import SkillRegistry
|
|
from core.storage import ensure_local_sentinel
|
|
from core.task import TaskState
|
|
from tools.fs import EditTool, GlobTool, GrepTool, ReadTool, WriteTool
|
|
from tools.run_python import RunPythonTool
|
|
from tools.shell import ShellTool
|
|
from tools.skill_tool import LoadSkillTool
|
|
|
|
# Directory containing this module; the config file, system prompt, model
# profiles and skills directory used in this file are resolved relative to it.
ROOT = Path(__file__).resolve().parent
|
|
|
|
|
|
def load_config() -> dict:
    """Parse ``config/agent.yaml`` (located under ``ROOT``) and return it.

    A document that parses to a falsy value (for example an empty file)
    yields an empty dict instead.
    """
    config_path = ROOT / "config" / "agent.yaml"
    raw_text = config_path.read_text(encoding="utf-8")
    parsed = yaml.safe_load(raw_text)
    return parsed if parsed else {}
|
|
|
|
|
|
def resolve_workspace(workspace: Optional[str], cfg: Optional[dict] = None) -> Path:
    """Return the workspace directory, creating it if it does not exist.

    Args:
        workspace: Explicit workspace path. When given (non-empty) it wins
            and the configuration is not consulted at all.
        cfg: Already-loaded configuration. Only used when ``workspace`` is
            not given; a missing or empty ``cfg`` is loaded via
            ``load_config()``.

    Returns:
        The workspace path. Without an explicit ``workspace`` this is
        ``ROOT / cfg["workspace_dir"]``, defaulting to ``ROOT / "workspace"``.
    """
    if workspace:
        # An explicit path needs no config, so do not read (or fail on)
        # the config file in this case.
        target = Path(workspace)
    else:
        if not cfg:
            cfg = load_config()
        target = ROOT / cfg.get("workspace_dir", "workspace")
    target.mkdir(parents=True, exist_ok=True)
    return target
|
|
|
|
|
|
def tasks_dir(workspace_dir: Path) -> Path:
    """Return ``<workspace_dir>/tasks``, creating it (and any parents) if missing."""
    target = workspace_dir.joinpath("tasks")
    target.mkdir(parents=True, exist_ok=True)
    return target
|
|
|
|
|
|
def resolve_task_id(
    workspace_dir: Path, task_id_arg: Optional[str], resume: bool
) -> Tuple[UUID, Path]:
    """Choose the task to work on and return ``(task_id, task_dir)``.

    New task (``resume`` is false): a fresh ``uuid4`` is generated and
    ``task_dir`` is ``<workspace>/tasks/<uuid>``. Only the parent ``tasks``
    directory is ensured here; the per-task directory is not created.

    Resume: when ``task_id_arg`` is ``None``, ``""`` or ``"last"`` the task
    with the greatest ``updated_at`` in the tasks table is selected;
    otherwise ``task_id_arg`` is resolved as a full UUID or a prefix.

    Raises:
        FileNotFoundError: resuming the latest task but the tasks table
            has no rows.
    """
    base = tasks_dir(workspace_dir)

    if not resume:
        fresh_id = uuid4()
        return fresh_id, base / str(fresh_id)

    # Storage-layer imports are deferred to the resume path.
    from sqlalchemy import select
    from core.storage import session_scope
    from core.storage.models import Task

    if task_id_arg not in (None, "", "last"):
        # Accepts a full UUID or a prefix.
        resolved = _resolve_uuid_or_prefix(task_id_arg)
        return resolved, base / str(resolved)

    newest_first = select(Task.task_id).order_by(Task.updated_at.desc()).limit(1)
    with session_scope() as db:
        latest = db.execute(newest_first).scalar_one_or_none()
    if latest is None:
        raise FileNotFoundError("no recoverable task: PG tasks 表为空")
    return latest, base / str(latest)
|
|
|
|
|
|
def _resolve_uuid_or_prefix(s: str) -> UUID:
    """Resolve a task id given either as a full UUID string or as a prefix.

    A string that ``uuid.UUID`` can parse is returned directly without
    touching the database. Anything else is treated as a prefix and must
    match exactly one row of the tasks table.

    Args:
        s: Full UUID string, or the leading characters of a task id.

    Returns:
        The matching task id.

    Raises:
        FileNotFoundError: no task id starts with the prefix.
        ValueError: the prefix is blank, or it matches more than one task.
    """
    try:
        return UUID(s)
    except ValueError:
        pass  # not a complete UUID: fall through to the prefix lookup

    # Normalise before matching: the lookup compares against the textual
    # form of the id (assumed lowercase, as PostgreSQL renders uuid values).
    prefix = s.strip().lower()
    if not prefix:
        # A blank prefix would match every row; reject it up front.
        raise ValueError(f"ambiguous prefix {s!r}: empty task id")

    from sqlalchemy import String, cast, select
    from core.storage import session_scope
    from core.storage.models import Task

    # autoescape=True escapes '%' and '_' in the user-supplied prefix so they
    # are matched literally instead of acting as LIKE wildcards.
    id_as_text = cast(Task.task_id, String)
    stmt = select(Task.task_id).where(id_as_text.startswith(prefix, autoescape=True))

    with session_scope() as sess:
        matches = sess.execute(stmt).scalars().all()

    if not matches:
        raise FileNotFoundError(f"no task matching prefix: {s}")
    if len(matches) > 1:
        raise ValueError(f"ambiguous prefix {s!r}, matched {len(matches)} tasks")
    return matches[0]
|
|
|
|
|
|
def _build_system_prompt(
    cfg: dict,
    skills: SkillRegistry,
    workspace_dir: Path,
    tool_base: Path,
    task_dir: Path,
) -> str:
    """Assemble the system prompt: template + skill list + memory + directories.

    The pieces are concatenated in this order:
      1. the template file at ``ROOT / cfg["system_prompt"]``;
      2. a skill discovery section, only when the registry has skills;
      3. whatever ``memory_block(workspace_dir)`` returns;
      4. a working-directory section naming ``tool_base`` (cwd) and the
         resolved absolute ``task_dir``.

    Both new and resumed tasks are built through this function, so the
    memory section is recomputed on every build.
    """
    template_path = ROOT / cfg["system_prompt"]
    sections = [template_path.read_text(encoding="utf-8")]

    if skills.skills:
        sections.append(
            f"\n\n## 可用 skill (用 load_skill 加载完整指引)\n{skills.discovery_block()}"
        )

    sections.append(memory_block(workspace_dir))

    # The prompt always shows the absolute task directory.
    out_dir = task_dir.resolve()
    sections.append(
        "\n\n## 工作目录\n"
        f"- cwd(用户启动时所在目录,只读用): `{tool_base}`\n"
        f"- **task_dir(所有产物写到这里)**: `{out_dir}`\n\n"
        "SKILL 文档里出现的 `<task_dir>` 占位符,一律指上面这个绝对路径。"
        f"产物示例: `{out_dir}/spec_lock.md`、"
        f"`{out_dir}/sections/01_summary.md`、"
        f"`{out_dir}/slides/`、最终 .docx/.pptx。\n"
        "⛔ 不要把产物写到 cwd / `skills/` / repo 根 —— 只写到 task_dir。"
    )
    return "".join(sections)
|
|
|
|
|
|
def _initial_task_state(
    sid: str,
    mode: str,
    description: str,
    caps: ModelCapabilities,
    model: str,
    tool_base: Path,
    created_at: str,
) -> TaskState:
    """Build an "active" TaskState for a task that has no stored state yet."""
    return TaskState(
        task_id=sid, mode=mode, description=description, status="active",
        model=caps.model_id, model_profile=model,
        reasoning_effort=caps.default_reasoning_effort or "",
        cwd=str(tool_base), created_at=created_at,
    )


def _build_tools(skills: SkillRegistry, caps: ModelCapabilities, tool_base: Path) -> dict:
    """Instantiate the tool set, keyed by each tool's ``name``."""
    tools = {}
    for cls in (ReadTool, WriteTool, EditTool, GlobTool, GrepTool, ShellTool):
        tool = cls(base_dir=tool_base)
        tools[tool.name] = tool

    # load_skill is only offered when there is at least one skill to load.
    if skills.skills:
        loader = LoadSkillTool(registry=skills, base_dir=tool_base)
        tools[loader.name] = loader

    if caps.enable_run_python:
        runner = RunPythonTool(base_dir=tool_base)
        tools[runner.name] = runner
    return tools


def _warn_cwd_drift(console: Optional[Console], saved_cwd: str, tool_base: Path) -> None:
    """Print a notice when the task's recorded cwd differs from the current one."""
    if saved_cwd and console is not None and saved_cwd != str(tool_base):
        console.print(
            f"[warn]提示:[/warn] 当前 cwd 与 task 记录不同 —— "
            f"工具基于 current cwd,不会自动切回。\n"
            f" task cwd: [info]{saved_cwd}[/info]\n"
            f" current cwd: [info]{tool_base}[/info]"
        )


def build_agent(
    model_name: Optional[str] = None,
    workspace: Optional[str] = None,
    console: Optional[Console] = None,
    session_id: Optional[str] = None,
    resume: bool = False,
    tool_base: Optional[Path] = None,
    mode: str = "",
    description: str = "",
) -> Tuple[AgentLoop, Session, str, TaskState, Path]:
    """Assemble an agent and return ``(agent, session, task_id_str, task_state, task_dir)``.

    Args:
        model_name: Model profile name; falls back to ``cfg["default_model"]``.
        workspace: Explicit workspace directory; otherwise taken from config.
        console: Rich console. When given, an event sink is attached to the
            loop and cwd-drift notices are printed on resume.
        session_id: Task id (full UUID, prefix, or ``"last"``) used when
            ``resume`` is true.
        resume: Load an existing task instead of starting a new one.
        tool_base: Base directory handed to every tool; defaults to the
            current working directory.
        mode: Stored on a newly created TaskState.
        description: Stored on a newly created TaskState.
    """
    cfg = load_config()
    model = model_name or cfg["default_model"]

    # Register the local sentinel user (described as idempotent by the
    # original note); build_agent is the entry point for all task operations.
    ensure_local_sentinel()

    caps = ModelCapabilities.load(model, ROOT / cfg["models_dir"])
    llm = LLM(caps)

    workspace_dir = resolve_workspace(workspace, cfg)
    task_id, task_dir = resolve_task_id(workspace_dir, session_id, resume)
    sid = str(task_id)

    tool_base = Path(tool_base) if tool_base else Path.cwd()

    skills = SkillRegistry(ROOT / cfg.get("skills_dir", "skills"))

    system_prompt = _build_system_prompt(cfg, skills, workspace_dir, tool_base, task_dir)

    now_iso = datetime.now().isoformat(timespec="seconds")
    meta = {
        "id": sid,
        "created_at": now_iso,
        "cwd": str(tool_base),
        "task_dir": str(task_dir),
        "model": caps.model_id,
        "model_profile": model,
    }

    if resume:
        session = Session.load(task_id, system_prompt=system_prompt, meta=meta)
        task_state = TaskState.load(task_dir)
        if task_state is None:
            # The task row exists but state.json is missing: rebuild it as a
            # fallback (the original note says this branch goes away after
            # Step 3). Built exactly like a new task's state, so
            # reasoning_effort is populated here too.
            task_state = _initial_task_state(
                sid, mode, description, caps, model, tool_base, now_iso
            )
            task_state.save(task_dir)
        else:
            # The stored state keeps the cwd from when the task was started.
            _warn_cwd_drift(console, task_state.cwd, tool_base)
    else:
        session = Session(task_id=task_id, system_prompt=system_prompt, meta=meta)
        # Lazy creation, per the original note: constructing Session does not
        # write to the DB, the Task row is inserted by ensure_local_task_row
        # when the first user message is appended, and state.json is written
        # on the first task_state.save call. Nothing is saved here.
        task_state = _initial_task_state(
            sid, mode, description, caps, model, tool_base, now_iso
        )

    tools = _build_tools(skills, caps, tool_base)

    sink = ConsoleEventSink(console, token_counter=lambda: llm.token_counter.total) if console else None
    agent = AgentLoop(llm, tools, session, caps, sink=sink)
    return agent, session, sid, task_state, task_dir
|
|
|
|
|
|
def sync_task_tokens(task_state: TaskState, task_dir: Path, llm: LLM) -> None:
    """Copy the LLM's token counts onto ``task_state`` and save it.

    Intended to be called after each ``agent.run`` round: the prompt and
    completion totals from ``llm.token_counter`` are written to
    ``tokens_prompt`` / ``tokens_completion`` and the state is saved to
    ``task_dir``.
    """
    counter = llm.token_counter
    task_state.tokens_prompt, task_state.tokens_completion = (
        counter.prompt_tokens,
        counter.completion_tokens,
    )
    task_state.save(task_dir)
|