186 lines
6.8 KiB
Python
186 lines
6.8 KiB
Python
"""装配入口: 读 config → 加载 capabilities/skills → 构造 LLM/tools/session/loop。
|
|
|
|
存储布局:
|
|
workspace/tasks/<task_id>/state.json ← TaskState
|
|
workspace/tasks/<task_id>/messages.json ← Session 消息
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Optional, Tuple
|
|
|
|
import yaml
|
|
from rich.console import Console
|
|
|
|
from core.capabilities import ModelCapabilities
|
|
from core.llm import LLM
|
|
from core.loop import AgentLoop
|
|
from core.session import Session
|
|
from core.skills import SkillRegistry
|
|
from core.task import TaskState
|
|
from tools.fs import EditTool, GlobTool, GrepTool, ReadTool, WriteTool
|
|
from tools.run_python import RunPythonTool
|
|
from tools.shell import ShellTool
|
|
from tools.skill_tool import LoadSkillTool
|
|
|
|
# Project root: the directory that contains this module (absolute, symlinks
# resolved). config/, the models and skills directories and the default
# workspace are all looked up relative to it.
ROOT = Path(__file__).resolve().parents[0]
|
|
|
|
|
|
def load_config(path: Optional[Path] = None) -> dict:
    """Load the agent configuration YAML and return it as a dict.

    Args:
        path: Config file to read. Defaults to ``ROOT/config/agent.yaml``,
            which was the only location supported before this parameter
            existed.

    Returns:
        The parsed top-level mapping. An empty file, or any falsy YAML
        document, yields ``{}``.

    Raises:
        OSError: if the file cannot be read.
        yaml.YAMLError: if the file is not valid YAML.
        ValueError: if the document's top level is not a mapping. The callers
            visible in this file index the result like a dict
            (``cfg["default_model"]``, ``cfg.get(...)``), so this fails early
            with a clear message instead of an AttributeError later.
    """
    cfg_path = Path(path) if path is not None else ROOT / "config" / "agent.yaml"
    # safe_load: never construct arbitrary Python objects from the file.
    data = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
    if not isinstance(data, dict):
        raise ValueError(
            f"{cfg_path}: top-level YAML must be a mapping, got {type(data).__name__}"
        )
    return data
|
|
|
|
|
|
def resolve_workspace(workspace: Optional[str], cfg: Optional[dict] = None) -> Path:
    """Return the workspace directory, creating it (and parents) if missing.

    Args:
        workspace: Explicit workspace path.
            - When truthy, it is used as-is.
            - When falsy (``None`` or ``""``), ``cfg["workspace_dir"]`` is used
              instead (default ``"workspace"``), joined onto ``ROOT``. An
              absolute ``workspace_dir`` therefore wins, per pathlib's joining
              rules.
        cfg: An already-loaded config.
            - Only ``None`` means "not given"; an explicitly passed empty dict
              is respected rather than silently replaced by a fresh
              ``load_config()``.
            - The config file is only read when it is actually needed, i.e.
              when no explicit ``workspace`` was passed.

    Returns:
        The workspace directory path, which exists on return.
    """
    if workspace:
        ws = Path(workspace)
    else:
        if cfg is None:
            cfg = load_config()
        ws = ROOT / cfg.get("workspace_dir", "workspace")
    ws.mkdir(parents=True, exist_ok=True)
    return ws
|
|
|
|
|
|
def tasks_dir(workspace_dir: Path) -> Path:
    """Return ``<workspace_dir>/tasks``, creating it (and any parents) on demand.

    Each task gets its own ``tasks/<task_id>/`` sub-directory below this one.
    """
    tasks_root = workspace_dir.joinpath("tasks")
    tasks_root.mkdir(exist_ok=True, parents=True)
    return tasks_root
|
|
|
|
|
|
def _latest_task_messages(tdir: Path) -> Tuple[Path, str]:
    """Return ``(messages.json, task_id)`` of the most recently modified task in *tdir*."""
    candidates = []
    for d in tdir.iterdir():
        mf = d / "messages.json"
        if mf.is_file():
            candidates.append((mf.stat().st_mtime, mf, d.name))
    if not candidates:
        raise FileNotFoundError(f"无可恢复的 task: {tdir} 下无 task")
    # max() keeps the first of equal mtimes, same as the stable reverse sort it replaces.
    _, path, sid = max(candidates, key=lambda c: c[0])
    return path, sid


def _unique_task_id(tdir: Path, base: str) -> str:
    """Return *base*, or the first of ``base_2``, ``base_3``, ... not already present in *tdir*."""
    sid, n = base, 1
    while (tdir / sid).exists():
        n += 1
        sid = f"{base}_{n}"
    return sid


def resolve_task_messages_path(
    workspace_dir: Path, task_id: Optional[str], resume: bool
) -> Tuple[Path, str]:
    """Return ``(messages_file_path, task_id)`` for a new or resumed task.

    Layout: ``<workspace_dir>/tasks/<task_id>/messages.json``.

    Resume (``resume=True``):
        * A ``task_id`` of ``None``, ``""`` or ``"last"`` selects the task whose
          ``messages.json`` was modified most recently.
        * Any other id must already have a ``messages.json``.

    New task (``resume=False``):
        * An explicit ``task_id`` is used as-is. No check is made for an
          existing task with that id.
        * Otherwise a ``YYYYmmdd_HHMMSS`` timestamp id is generated. If a task
          directory with that id already exists (two tasks started within the
          same second), a ``_2``, ``_3``, ... suffix is appended so the earlier
          task's files are not overwritten.

    Raises:
        FileNotFoundError: on resume, when there is no task to resume or the
            requested task has no ``messages.json``.
    """
    tdir = tasks_dir(workspace_dir)

    if resume:
        if task_id in (None, "", "last"):
            return _latest_task_messages(tdir)
        task_msg = tdir / task_id / "messages.json"
        if not task_msg.exists():
            raise FileNotFoundError(f"task 不存在: {task_msg}")
        return task_msg, task_id

    if task_id:
        sid = task_id
    else:
        # NOTE: check-then-use, not atomic; truly concurrent starts can still race.
        sid = _unique_task_id(tdir, datetime.now().strftime("%Y%m%d_%H%M%S"))
    return tdir / sid / "messages.json", sid
|
|
|
|
|
|
def _warn_cwd_mismatch(console: Optional[Console], saved_cwd, tool_base: Path) -> None:
    """Print a notice when a resumed task was recorded under a different cwd.

    Purely informational: tools keep using *tool_base* and nothing is changed.
    """
    if not saved_cwd or console is None or saved_cwd == str(tool_base):
        return
    console.print(
        f"[yellow]提示:[/yellow] 当前 cwd 与 task 记录不同 —— "
        f"工具基于 current cwd,不会自动切回。\n"
        f" task cwd: [dim]{saved_cwd}[/dim]\n"
        f" current cwd: [dim]{tool_base}[/dim]"
    )


def _compose_system_prompt(
    cfg: dict, skills: SkillRegistry, tool_base: Path, task_dir: Path
) -> str:
    """Build the system prompt for a brand-new task.

    The prompt is assembled from these parts, in order:
      - the base prompt file named by ``cfg["system_prompt"]``;
      - the skill discovery list, when any skills are registered;
      - a working-directory section that states the cwd and the absolute
        task_dir, and tells the model to write all artifacts to task_dir.
    """
    pieces = [(ROOT / cfg["system_prompt"]).read_text(encoding="utf-8")]
    if skills.skills:
        pieces.append(
            f"\n\n## 可用 skill (用 load_skill 加载完整指引)\n{skills.discovery_block()}"
        )
    abs_task_dir = task_dir.resolve()
    pieces.append(
        f"\n\n## 工作目录\n"
        f"- cwd(用户启动时所在目录,只读用): `{tool_base}`\n"
        f"- **task_dir(所有产物写到这里)**: `{abs_task_dir}`\n\n"
        f"SKILL 文档里出现的 `<task_dir>` 占位符,一律指上面这个绝对路径。"
        f"产物示例: `{abs_task_dir}/spec_lock.md`、"
        f"`{abs_task_dir}/sections/01_summary.md`、"
        f"`{abs_task_dir}/slides/`、最终 .docx/.pptx。\n"
        f"⛔ 不要把产物写到 cwd / `skills/` / repo 根 —— 只写到 task_dir。"
    )
    return "".join(pieces)


def _make_tools(skills: SkillRegistry, tool_base: Path, caps: ModelCapabilities) -> dict:
    """Instantiate the tool set rooted at *tool_base*, keyed by each tool's ``name``.

    Tools are added in this order:
      - the file/shell tools, always;
      - LoadSkillTool, only when skills are registered;
      - RunPythonTool, only when ``caps.enable_run_python`` is set.
    """
    core = (ReadTool, WriteTool, EditTool, GlobTool, GrepTool, ShellTool)
    instances = [tool_cls(base_dir=tool_base) for tool_cls in core]
    if skills.skills:
        instances.append(LoadSkillTool(registry=skills, base_dir=tool_base))
    if caps.enable_run_python:
        instances.append(RunPythonTool(base_dir=tool_base))
    return {tool.name: tool for tool in instances}


def build_agent(
    model_name: Optional[str] = None,
    workspace: Optional[str] = None,
    console: Optional[Console] = None,
    session_id: Optional[str] = None,
    resume: bool = False,
    tool_base: Optional[Path] = None,
    mode: str = "",
    description: str = "",
) -> Tuple[AgentLoop, Session, str, TaskState, Path]:
    """Assemble an AgentLoop for a new task, or for a resumed one.

    Steps:
      1. Load the config and the model capabilities.
      2. Resolve the workspace and the task's ``messages.json``.
      3. Either restore the Session/TaskState from disk (resume), or create
         and persist fresh ones (new task).
      4. Build the tool set rooted at *tool_base*, which defaults to the
         current working directory.

    Args:
        model_name: Model profile name; falls back to ``cfg["default_model"]``.
        workspace: Workspace directory; see ``resolve_workspace``.
        console: Optional rich console, used for the cwd-mismatch notice and
            passed to the AgentLoop.
        session_id: Task id to create or resume (``"last"`` on resume picks
            the most recent task).
        resume: Restore an existing task instead of starting a new one.
        tool_base: Base directory for the tools.
        mode, description: Stored on a new TaskState. On resume they are only
            used if ``state.json`` is missing and has to be rebuilt.

    Returns:
        ``(agent, session, task_id, task_state, task_dir)``, where *task_dir*
        is the directory that holds ``messages.json``.
    """
    cfg = load_config()
    profile = model_name or cfg["default_model"]

    caps = ModelCapabilities.load(profile, ROOT / cfg["models_dir"])
    llm = LLM(caps)

    workspace_dir = resolve_workspace(workspace, cfg)
    session_path, sid = resolve_task_messages_path(workspace_dir, session_id, resume)

    base = Path(tool_base) if tool_base else Path.cwd()
    skills = SkillRegistry(ROOT / cfg.get("skills_dir", "skills"))
    task_dir = session_path.parent

    if resume:
        session = Session.load(session_path)
        _warn_cwd_mismatch(console, session.meta.get("cwd"), base)
        task_state = TaskState.load(task_dir)
        if task_state is None:
            # messages.json exists but state.json is missing: rebuild it,
            # preferring whatever session.meta recorded.
            meta = session.meta
            task_state = TaskState(
                task_id=sid,
                mode=mode,
                description=description,
                status="active",
                model=meta.get("model", caps.model_id),
                model_profile=meta.get("model_profile", profile),
                cwd=meta.get("cwd", str(base)),
                created_at=meta.get(
                    "created_at", datetime.now().isoformat(timespec="seconds")
                ),
            )
            task_state.save(task_dir)
    else:
        system_prompt = _compose_system_prompt(cfg, skills, base, task_dir)
        created = datetime.now().isoformat(timespec="seconds")
        session_meta = {
            "id": sid,
            "created_at": created,
            "cwd": str(base),
            "model": caps.model_id,
            "model_profile": profile,
        }
        session = Session(system_prompt=system_prompt, path=session_path, meta=session_meta)
        session.save()  # write immediately to claim the file name
        task_state = TaskState(
            task_id=sid,
            mode=mode,
            description=description,
            status="active",
            model=caps.model_id,
            model_profile=profile,
            reasoning_effort=caps.default_reasoning_effort or "",
            cwd=str(base),
            created_at=created,
        )
        task_state.save(task_dir)

    tools = _make_tools(skills, base, caps)
    agent = AgentLoop(llm, tools, session, caps, console=console)
    return agent, session, sid, task_state, task_dir
|
|
|
|
|
|
def sync_task_tokens(task_state: TaskState, task_dir: Path, llm: LLM) -> None:
    """Copy the LLM's cumulative token counts onto *task_state* and persist it.

    Intended to be called after every ``agent.run`` so that ``state.json``
    in *task_dir* reflects the running totals.
    """
    counter = llm.token_counter
    task_state.tokens_prompt, task_state.tokens_completion = (
        counter.prompt_tokens,
        counter.completion_tokens,
    )
    task_state.save(task_dir)
|