zcbot/core/loop.py

261 lines
10 KiB
Python

"""主 agent loop: ReAct 风格,LLM ↔ Tool 反复直到无 tool_call。
loop 不直接 print —— 进度通过 sink.emit(event) 上抛。Sink 决定怎么呈现
(本地 console / SSE / 日志)。事件类型见 core/sinks.py 头部说明。
LLM 调用走 `chat_stream`(流式),chunk 之间 poll cancel_check 实现快速中断。
content delta 即时 emit `text` 事件让前端打字机渲染;chunks 攒齐后用
`litellm.stream_chunk_builder` 拼回完整 response 给 tool_calls 解析 + usage 记账。
"""
from __future__ import annotations
import json
import time
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple
from uuid import UUID
import litellm
from .capabilities import ModelCapabilities
from .executor import ExecCtx, Executor
from .llm import LLM
from .session import Session
from .storage import record_chat_usage
# Content of the synthetic `role: "tool"` message recorded for every tool_call
# that was never executed because the user cancelled mid-turn. Keeping one such
# message per tool_call preserves the assistant/tool pairing the LLM API expects.
_CANCELLED_TOOL_PLACEHOLDER: str = "[cancelled by user]"
def _extract_delta_content(chunk: Any) -> Optional[str]:
    """Return the text fragment carried by one streaming chunk, or None.

    Reads ``chunk.choices[0].delta.content`` from a litellm stream chunk.
    A chunk without choices or without a delta (such as a trailing usage-only
    chunk), an empty/falsy content value, and any unexpected error while probing
    the chunk all yield None.
    """
    try:
        choices = getattr(chunk, "choices", None)
        head = choices[0] if choices else None
        # getattr(None, ..., None) is None, so missing pieces fall through to None.
        delta = getattr(head, "delta", None)
        text = getattr(delta, "content", None)
        return text or None
    except Exception:
        return None
def _extract_usage(usage: Any) -> Tuple[int, int]:
    """Return ``(prompt_tokens, completion_tokens)`` from a litellm ``response.usage``.

    Accepted shapes:
      * a pydantic model (``model_dump()`` for v2, ``dict()`` for v1);
      * a plain dict;
      * any other object exposing ``prompt_tokens`` / ``completion_tokens``
        as attributes.

    Missing or None counts become 0; a falsy ``usage`` (None, empty dict)
    yields ``(0, 0)``.
    """
    if not usage:
        return 0, 0
    if hasattr(usage, "model_dump"):
        usage = usage.model_dump()
    elif hasattr(usage, "dict"):
        usage = usage.dict()
    if isinstance(usage, dict):
        prompt = usage.get("prompt_tokens")
        completion = usage.get("completion_tokens")
    else:
        # Attribute-style usage objects used to be reported as (0, 0) silently.
        prompt = getattr(usage, "prompt_tokens", None)
        completion = getattr(usage, "completion_tokens", None)
    return int(prompt or 0), int(completion or 0)
class AgentLoop:
    """ReAct-style agent loop.

    Alternates LLM calls and tool executions until the model replies without
    tool_calls, a cancel is observed, or the iteration budget is used up. The
    loop never prints: every progress step is reported as an event dict through
    ``sink.emit(event)`` when a sink is configured.
    """

    # Max characters of a tool result fed back to the model (keeps the context small).
    _MAX_TOOL_RESULT_LEN = 16_000
    # Max characters of the human-facing previews carried in tool_call / tool_result events.
    _ARGS_PREVIEW_LEN = 200
    _RESULT_PREVIEW_LEN = 400

    def __init__(
        self,
        llm: LLM,
        executor: Executor,
        session: Session,
        capabilities: ModelCapabilities,
        user_id: UUID,
        working_dir: Path,
        sink: Optional[Any] = None,
        max_iterations: Optional[int] = None,
        cancel_check: Optional[Callable[[], bool]] = None,
    ) -> None:
        self.llm = llm
        self.executor = executor
        self.session = session
        self.caps = capabilities
        # usage_events rows are aggregated per user.
        self.user_id = user_id
        # Passed to ExecCtx alongside user_id / task_id; the docker backend
        # (Step 3) uses it to build `--workdir /workspace/<wd_name>` and to
        # namespace temporary files.
        self.working_dir = working_dir
        # Falls back to the model's capability default when not given (or 0).
        self.max_iterations = max_iterations or capabilities.max_iterations
        self.sink = sink
        # Cooperative cancel: the web layer injects
        # `lambda: broker.is_cancelled(task_id)`; the CLI path leaves it None
        # (never cancels). It is polled (1) before each LLM call, (2) between
        # stream chunks and (3) between tool_calls. Polling between chunks cuts
        # cancel latency from a whole generation (tens of seconds) to a single
        # chunk interval (~100 ms).
        self.cancel_check = cancel_check

    def _emit(self, event: dict) -> None:
        """Forward *event* to the sink; no-op when no sink is configured."""
        if self.sink is not None:
            self.sink.emit(event)

    def _is_cancelled(self) -> bool:
        """True when a cancel_check is set and it reports cancellation."""
        return bool(self.cancel_check and self.cancel_check())

    @staticmethod
    def _preview(text: str, limit: int) -> str:
        """Return *text* cut to *limit* chars, appending "..." only if something was cut."""
        return text if len(text) <= limit else text[:limit] + "..."

    def _fill_cancelled_tool_results(self, remaining: list) -> None:
        """Append a placeholder tool message for every tool_call that never ran.

        Each assistant tool_call must be answered by a matching tool message,
        otherwise resuming the session makes the LLM call fail. The message has
        the same shape as the one recorded for an executed tool.
        """
        for tc in remaining:
            self.session.append({
                "role": "tool",
                "tool_call_id": tc.id,
                "name": tc.function.name,
                "content": _CANCELLED_TOOL_PLACEHOLDER,
            })

    def run(self, user_message: str) -> str:
        """Append *user_message* and drive the loop until the model stops calling tools.

        Returns:
            The final assistant content ("" if it had none), "[cancelled]" when
            a cancel is observed, or "[reached max iterations]" when the
            iteration budget runs out.

        Raises:
            RuntimeError: the LLM stream yielded nothing to build a response from.
        """
        self.session.append({"role": "user", "content": user_message})
        for _ in range(self.max_iterations):
            if self._is_cancelled():
                self._emit({"type": "cancelled"})
                return "[cancelled]"

            self._emit({"type": "llm_start"})
            start = time.monotonic()
            response, cancelled_mid_stream = self._stream_llm()
            elapsed = time.monotonic() - start
            if cancelled_mid_stream:
                # Cancelled mid-stream: received chunks are dropped and nothing is
                # persisted or billed (no partial assistant message), so a later
                # resume sees a clean context.
                self._emit({"type": "cancelled"})
                return "[cancelled]"

            msg = response.choices[0].message
            asst_msg_id = self.session.append(msg)
            pt, ct = self._record_usage(response, asst_msg_id)
            self._emit({
                "type": "llm_end",
                "prompt_tokens": pt,
                "completion_tokens": ct,
                "elapsed": elapsed,
            })

            tool_calls = getattr(msg, "tool_calls", None) or []
            # The content was already emitted delta-by-delta while streaming, so
            # no whole-text event is emitted here.
            if not tool_calls:
                self._emit({"type": "done"})
                return getattr(msg, "content", None) or ""

            for i, tc in enumerate(tool_calls):
                if self._is_cancelled():
                    self._fill_cancelled_tool_results(tool_calls[i:])
                    self._emit({"type": "cancelled"})
                    return "[cancelled]"
                result = self._execute_tool_call(tc)
                self.session.append(
                    {
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "name": tc.function.name,
                        "content": result,
                    }
                )
        self._emit({"type": "done"})
        return "[reached max iterations]"

    def _record_usage(self, response: Any, message_id: Any) -> Tuple[int, int]:
        """Bill one LLM turn and return its ``(prompt_tokens, completion_tokens)``.

        Accounting (0006): writes one usage_event and back-fills
        messages.tokens_in/out + model_profile via ``record_chat_usage``. Any
        failure there (litellm cost-map miss, DB error, ...) is reported as a
        ``warn`` event instead of raised, so it never blocks the main loop; the
        message itself is already stored in the session.
        """
        pt, ct = _extract_usage(getattr(response, "usage", None))
        model_profile = f"{self.caps.family}.{self.caps.variant}"
        try:
            record_chat_usage(
                task_id=self.session.task_id,
                user_id=self.user_id,
                message_id=message_id,
                model_profile=model_profile,
                prompt_tokens=pt,
                completion_tokens=ct,
                response=response,
            )
        except Exception as e:
            self._emit({"type": "warn", "msg": f"record_usage failed: {type(e).__name__}: {e}"})
        return pt, ct

    def _stream_llm(self) -> Tuple[Optional[Any], bool]:
        """Stream one LLM turn, polling for cancel between chunks and emitting content deltas.

        Returns ``(response, cancelled_mid_stream)``:

        * normal completion -> ``(response, False)``; *response* is rebuilt by
          ``litellm.stream_chunk_builder`` and has the same shape as a
          non-streaming completion (``choices[0].message`` + ``usage``);
        * cancelled mid-stream -> ``(None, True)``; received chunks are dropped.

        Raises:
            RuntimeError: no response could be rebuilt from the stream (e.g. it
                yielded no chunks at all).
        """
        chunks: List[Any] = []
        stream = self.llm.chat_stream(
            messages=self.session.messages,
            tools=self.executor.schemas(),
            reasoning_effort=self.caps.default_reasoning_effort or None,
        )
        cancelled = False
        try:
            for chunk in stream:
                if self._is_cancelled():
                    cancelled = True
                    break
                chunks.append(chunk)
                # Emit delta.content immediately for typewriter rendering.
                # tool_call deltas are not streamed: they are split across many
                # chunks, so the complete tool_call event is emitted later from
                # _execute_tool_call instead.
                delta_text = _extract_delta_content(chunk)
                if delta_text:
                    self._emit({"type": "text", "delta": delta_text})
        finally:
            # On an early break, close() raises GeneratorExit inside chat_stream
            # so its finally can release the underlying connection.
            stream.close()
        if cancelled:
            return None, True
        # litellm's helper rebuilds the full response (tool_call concatenation +
        # usage). `messages` only serves to estimate prompt tokens as a fallback;
        # normally stream_options.include_usage already put exact usage on the
        # final chunk.
        response = litellm.stream_chunk_builder(chunks, messages=self.session.messages)
        if response is None:
            # stream_chunk_builder is typed Optional (None for an empty chunk
            # list); fail loudly instead of crashing later on `.choices`.
            raise RuntimeError("LLM stream yielded no chunks; cannot build a response")
        return response, False

    def _execute_tool_call(self, tc: Any) -> str:
        """Run one tool_call and return the (possibly truncated) text for the model.

        Malformed arguments (invalid JSON, or JSON that is not an object) are
        reported back to the model as an ``[Error] ...`` string without running
        the tool. Otherwise emits ``tool_call`` before and ``tool_result`` after
        calling ``executor.call_tool``.
        """
        name = tc.function.name
        raw_args = tc.function.arguments or "{}"
        try:
            args = json.loads(raw_args)
        except json.JSONDecodeError as e:
            return f"[Error] invalid JSON arguments for {name}: {e}"
        if not isinstance(args, dict):
            return f"[Error] arguments for {name} must be a JSON object, got {type(args).__name__}"

        args_preview = self._preview(json.dumps(args, ensure_ascii=False), self._ARGS_PREVIEW_LEN)
        self._emit({
            "type": "tool_call",
            "name": name,
            "args": args,
            "args_preview": args_preview,
        })

        ctx = ExecCtx(
            user_id=self.user_id,
            task_id=self.session.task_id,
            working_dir=self.working_dir,
            cancel_check=self.cancel_check,
        )
        result = self.executor.call_tool(name, args, ctx).content

        # Bound the tool output returned to the model so it cannot flood the context.
        limit = self._MAX_TOOL_RESULT_LEN
        truncated = len(result) > limit
        if truncated:
            overflow = len(result) - limit
            result = result[:limit] + f"\n[... truncated, {overflow} chars ...]"

        self._emit({
            "type": "tool_result",
            "name": name,
            "result": result,
            "preview": self._preview(result, self._RESULT_PREVIEW_LEN),
            "truncated": truncated,
        })
        return result