zcbot/core/loop.py

204 lines
7.4 KiB
Python

"""主 agent loop: ReAct 风格,LLM ↔ Tool 反复直到无 tool_call。
loop 不直接 print —— 进度通过 sink.emit(event) 上抛。Sink 决定怎么呈现
(本地 console / SSE / 日志)。事件类型见 core/sinks.py 头部说明。
"""
from __future__ import annotations
import json
import time
from typing import Any, Callable, Dict, Optional, Tuple
from uuid import UUID
from .capabilities import ModelCapabilities
from .llm import LLM
from .session import Session
from .storage import record_chat_usage
# Content stored as the tool result for tool calls that were never executed
# because the run was cancelled first (see AgentLoop._fill_cancelled_tool_results).
_CANCELLED_TOOL_PLACEHOLDER = "[cancelled by user]"
def _extract_usage(usage: Any) -> Tuple[int, int]:
    """Return ``(prompt_tokens, completion_tokens)`` from a response's usage.

    Args:
        usage: The ``usage`` value of an LLM response. Accepted shapes:
            a plain dict; an object exposing ``model_dump()`` or ``dict()``
            (converted to a dict first); or any other object that carries
            ``prompt_tokens`` / ``completion_tokens`` as attributes.

    Returns:
        A pair of ints. A falsy ``usage`` yields ``(0, 0)``; a counter that is
        missing, ``None`` or otherwise falsy counts as 0.
    """
    if not usage:
        return 0, 0

    # Normalise model-like objects to a plain dict when they offer a dump method.
    if hasattr(usage, "model_dump"):
        usage = usage.model_dump()
    elif hasattr(usage, "dict"):
        usage = usage.dict()

    if isinstance(usage, dict):
        prompt = usage.get("prompt_tokens")
        completion = usage.get("completion_tokens")
    else:
        # Fallback for objects that only expose the counters as attributes.
        prompt = getattr(usage, "prompt_tokens", None)
        completion = getattr(usage, "completion_tokens", None)

    return int(prompt or 0), int(completion or 0)
class AgentLoop:
    """ReAct-style driver: call the LLM, run the tools it asks for, repeat.

    The loop ends when the model answers without tool calls, when the
    iteration budget is exhausted, or when the cancel callback reports True.
    Nothing is printed here; progress is published through ``sink.emit`` as
    dict events whose ``"type"`` is one of ``llm_start``, ``llm_end``,
    ``text``, ``tool_call``, ``tool_result``, ``warn``, ``cancelled``, ``done``.
    """

    # Character limits applied in _execute_tool_call.
    _ARGS_PREVIEW_CHARS = 200      # argument preview carried by "tool_call" events
    _RESULT_MAX_CHARS = 16_000     # tool result size handed back to the model
    _RESULT_PREVIEW_CHARS = 400    # result preview carried by "tool_result" events

    def __init__(
        self,
        llm: LLM,
        tools: Dict[str, Any],
        session: Session,
        capabilities: ModelCapabilities,
        user_id: UUID,
        sink: Optional[Any] = None,
        max_iterations: Optional[int] = None,
        cancel_check: Optional[Callable[[], bool]] = None,
    ) -> None:
        """Store the collaborators used by :meth:`run`.

        Args:
            llm: Object whose ``chat(messages=, tools=, reasoning_effort=)``
                is called once per iteration.
            tools: Mapping of tool name to a tool object; the loop reads
                ``tool.schema`` and calls ``tool.execute(**args)``.
            session: Conversation store; the loop reads ``messages`` and
                ``task_id`` and calls ``append``.
            capabilities: Model profile; supplies ``max_iterations``,
                ``default_reasoning_effort``, ``family`` and ``variant``.
            user_id: Passed to ``record_chat_usage`` so usage is attributed
                to a user.
            sink: Optional event consumer with an ``emit(event)`` method.
                ``None`` silences all events.
            max_iterations: Overrides ``capabilities.max_iterations`` when
                truthy.
            cancel_check: Optional zero-argument callable returning True once
                the run should stop. ``None`` means the run is never
                cancelled.
        """
        self.llm = llm
        self.tools = tools
        self.session = session
        self.caps = capabilities
        self.user_id = user_id  # usage events are aggregated per user
        self.max_iterations = max_iterations or capabilities.max_iterations
        self.sink = sink
        # Cooperative cancellation: the web layer injects something like
        # `lambda: broker.is_cancelled(run_id)`; the CLI path leaves it unset.
        # Per the original author's note the LLM call is a synchronous,
        # uninterruptible litellm call, so the flag is only polled before each
        # LLM call and between tool calls; one LLM call may therefore delay a
        # cancel by up to tens of seconds.
        self.cancel_check = cancel_check

    def _emit(self, event: dict) -> None:
        """Forward ``event`` to the sink, if one was configured."""
        if self.sink is not None:
            self.sink.emit(event)

    def _is_cancelled(self) -> bool:
        """Return True when a cancel callback exists and reports cancellation."""
        return bool(self.cancel_check and self.cancel_check())

    def _fill_cancelled_tool_results(self, remaining: list) -> None:
        """Append a placeholder tool message for every unexecuted tool call.

        Every assistant tool_call must be answered by a tool message;
        otherwise the LLM rejects the history when the session is resumed.
        """
        for tc in remaining:
            self.session.append({
                "role": "tool",
                "tool_call_id": tc.id,
                "content": _CANCELLED_TOOL_PLACEHOLDER,
            })

    def run(self, user_message: str) -> str:
        """Process one user message to completion.

        Args:
            user_message: Text appended to the session as the user turn.

        Returns:
            The content of the first assistant message without tool calls
            (``""`` when it has no content), ``"[cancelled]"`` when the cancel
            callback fired, or ``"[reached max iterations]"`` when the
            iteration budget ran out.
        """
        self.session.append({"role": "user", "content": user_message})
        for _ in range(self.max_iterations):
            if self._is_cancelled():
                self._emit({"type": "cancelled"})
                return "[cancelled]"

            self._emit({"type": "llm_start"})
            start = time.monotonic()
            response = self.llm.chat(
                messages=self.session.messages,
                tools=[t.schema for t in self.tools.values()],
                reasoning_effort=self.caps.default_reasoning_effort or None,
            )
            elapsed = time.monotonic() - start

            msg = response.choices[0].message
            asst_msg_id = self.session.append(msg)
            pt, ct = _extract_usage(getattr(response, "usage", None))

            # Accounting (0006): one usage_event row plus back-filling
            # messages.tokens_in/out and model_profile. Any failure (litellm
            # cost-map miss, DB error) is deliberately swallowed and surfaced
            # only as a "warn" event so it never blocks the main loop; the
            # message itself is already stored in the session.
            model_profile = f"{self.caps.family}.{self.caps.variant}"
            try:
                record_chat_usage(
                    task_id=self.session.task_id,
                    user_id=self.user_id,
                    message_id=asst_msg_id,
                    model_profile=model_profile,
                    prompt_tokens=pt,
                    completion_tokens=ct,
                    response=response,
                )
            except Exception as e:
                self._emit({"type": "warn", "msg": f"record_usage failed: {type(e).__name__}: {e}"})

            self._emit({
                "type": "llm_end",
                "prompt_tokens": pt,
                "completion_tokens": ct,
                "elapsed": elapsed,
            })

            tool_calls = getattr(msg, "tool_calls", None) or []
            content = getattr(msg, "content", None)
            if content:
                self._emit({"type": "text", "content": content})
            if not tool_calls:
                self._emit({"type": "done"})
                return content or ""

            for i, tc in enumerate(tool_calls):
                if self._is_cancelled():
                    # Answer the calls that will not run so the history stays valid.
                    self._fill_cancelled_tool_results(tool_calls[i:])
                    self._emit({"type": "cancelled"})
                    return "[cancelled]"
                result = self._execute_tool_call(tc)
                self.session.append(
                    {
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": result,
                    }
                )

        self._emit({"type": "done"})
        return "[reached max iterations]"

    def _tool_error(self, name: str, err: str) -> str:
        """Emit a ``tool_result`` event describing a failure and return ``err``."""
        self._emit({"type": "tool_result", "name": name, "result": err,
                    "preview": err, "truncated": False})
        return err

    def _execute_tool_call(self, tc: Any) -> str:
        """Run one tool call and return the text to hand back to the model.

        Failures never raise out of this method for the handled cases
        (invalid JSON arguments, unknown tool, ``TypeError`` or any other
        ``Exception`` from ``tool.execute``); they are returned as an
        ``"[Error ...]"`` string instead. Every handled attempt emits one
        ``tool_call`` event followed by one ``tool_result`` event.

        Args:
            tc: Tool call object exposing ``function.name`` and
                ``function.arguments`` (a JSON string, possibly empty).

        Returns:
            The tool output coerced to ``str`` and capped at
            ``_RESULT_MAX_CHARS`` characters plus a truncation marker, or an
            error string.
        """
        name = tc.function.name
        raw_args = tc.function.arguments or "{}"
        try:
            args = json.loads(raw_args)
        except json.JSONDecodeError as e:
            # Announce the attempt with the raw text so sinks see the same
            # tool_call -> tool_result pair as for every other failure.
            raw_preview = str(raw_args)
            if len(raw_preview) > self._ARGS_PREVIEW_CHARS:
                raw_preview = raw_preview[:self._ARGS_PREVIEW_CHARS] + "..."
            self._emit({
                "type": "tool_call",
                "name": name,
                "args": {},
                "args_preview": raw_preview,
            })
            return self._tool_error(name, f"[Error] invalid JSON arguments for {name}: {e}")

        args_preview = json.dumps(args, ensure_ascii=False)
        if len(args_preview) > self._ARGS_PREVIEW_CHARS:
            args_preview = args_preview[:self._ARGS_PREVIEW_CHARS] + "..."
        self._emit({
            "type": "tool_call",
            "name": name,
            "args": args,
            "args_preview": args_preview,
        })

        tool = self.tools.get(name)
        if tool is None:
            return self._tool_error(name, f"[Error] unknown tool: {name}")

        try:
            result = tool.execute(**args)
        except TypeError as e:
            # Covers wrong/missing keyword arguments and non-mapping JSON.
            return self._tool_error(name, f"[Error] bad arguments to {name}: {e}")
        except Exception as e:
            return self._tool_error(name, f"[Error executing {name}] {type(e).__name__}: {e}")

        if not isinstance(result, str):
            result = str(result)

        # Cap what goes back to the model so one tool result cannot blow up
        # the context window.
        truncated = False
        if len(result) > self._RESULT_MAX_CHARS:
            dropped = len(result) - self._RESULT_MAX_CHARS
            result = result[:self._RESULT_MAX_CHARS] + f"\n[... truncated, {dropped} chars ...]"
            truncated = True

        # Only add the ellipsis when the preview really cuts something off.
        if len(result) <= self._RESULT_PREVIEW_CHARS:
            preview = result
        else:
            preview = result[:self._RESULT_PREVIEW_CHARS] + "..."
        self._emit({
            "type": "tool_result",
            "name": name,
            "result": result,
            "preview": preview,
            "truncated": truncated,
        })
        return result