385 lines
16 KiB
Python
385 lines
16 KiB
Python
"""主 agent loop: ReAct 风格,LLM ↔ Tool 反复直到无 tool_call。
|
|
|
|
loop 不直接 print —— 进度通过 sink.emit(event) 上抛。Sink 决定怎么呈现
|
|
(本地 console / SSE / 日志)。事件类型见 core/sinks.py 头部说明。
|
|
|
|
LLM 调用走 `chat_stream`(流式),chunk 之间 poll cancel_check 实现快速中断。
|
|
content delta 即时 emit `text` 事件让前端打字机渲染;chunks 攒齐后用
|
|
`litellm.stream_chunk_builder` 拼回完整 response 给 tool_calls 解析 + usage 记账。
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Callable, List, Optional, Tuple
|
|
|
|
from uuid import UUID
|
|
|
|
import litellm
|
|
|
|
from .capabilities import ModelCapabilities
|
|
from .context import prepare_messages_with_stats
|
|
from .executor import ExecCtx, Executor
|
|
from .llm import LLM
|
|
from .session import Session
|
|
from .storage import record_chat_usage
|
|
|
|
|
|
# Content of the synthetic tool message written for each tool_call that was never
# executed because the user cancelled (used by AgentLoop._fill_cancelled_tool_results).
_CANCELLED_TOOL_PLACEHOLDER: str = "[cancelled by user]"
|
|
|
|
|
|
def _extract_delta_content(chunk: Any) -> Optional[str]:
|
|
"""从 stream chunk 提 delta.content(文本片段)。chunk 形态 litellm ModelResponseStream:
|
|
choices[0].delta.content。usage-only 收尾 chunk(没 choices / delta)返 None。
|
|
"""
|
|
try:
|
|
choices = getattr(chunk, "choices", None)
|
|
if not choices:
|
|
return None
|
|
delta = getattr(choices[0], "delta", None)
|
|
if delta is None:
|
|
return None
|
|
content = getattr(delta, "content", None)
|
|
return content if content else None
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
def _malformed_tool_calls(response: Any) -> List[str]:
|
|
"""检出 arguments 损坏(JSON 解析不了)的 tool_call,返回 [name(len=N), ...]。
|
|
|
|
背景:deepseek-v4-flash 大参数工具调用偶发畸形 —— 流式 delta 错位把别处的内容
|
|
碎片粘到 arguments 开头(如 `].cells[1].merge(...{"path":...}`),拼回来后 JSON
|
|
解析直接失败。这种是上游瞬时抖动,不该入库污染上下文,调用方据此丢弃整轮重 roll。
|
|
|
|
只看「解析失败」;空字符串 / 合法空对象不算畸形(交给 executor 按缺参数处理)。
|
|
"""
|
|
try:
|
|
msg = response.choices[0].message
|
|
except Exception:
|
|
return []
|
|
bad: List[str] = []
|
|
for tc in (getattr(msg, "tool_calls", None) or []):
|
|
raw = (getattr(tc.function, "arguments", None) or "").strip()
|
|
if not raw:
|
|
continue
|
|
try:
|
|
json.loads(raw)
|
|
except (json.JSONDecodeError, ValueError):
|
|
bad.append(f"{tc.function.name}(len={len(raw)})")
|
|
return bad
|
|
|
|
|
|
def _usage_to_dict(usage: Any) -> dict:
|
|
if not usage:
|
|
return {}
|
|
if hasattr(usage, "model_dump"):
|
|
usage = usage.model_dump()
|
|
elif hasattr(usage, "dict"):
|
|
usage = usage.dict()
|
|
if isinstance(usage, dict):
|
|
return usage
|
|
return {}
|
|
|
|
|
|
def _extract_usage_details(usage: Any) -> dict:
    """Extract a provider-agnostic token breakdown from a usage payload.

    DeepSeek reports ``prompt_cache_hit_tokens`` / ``prompt_cache_miss_tokens``
    directly; OpenAI-style providers put cached tokens under
    ``prompt_tokens_details.cached_tokens``. Missing or falsy fields count as 0.

    Returns a dict with int values for ``tokens_in``, ``tokens_out``,
    ``cache_hit_tokens``, ``cache_miss_tokens`` and ``reasoning_tokens``.
    """
    data = _usage_to_dict(usage)

    def _nested(key: str) -> dict:
        # Detail sub-objects are only trusted when they are real dicts.
        value = data.get(key) or {}
        return value if isinstance(value, dict) else {}

    def _count(value: Any) -> int:
        return int(value or 0)

    prompt_details = _nested("prompt_tokens_details")
    completion_details = _nested("completion_tokens_details")
    # DeepSeek's explicit field wins; fall back to the OpenAI-style nested one.
    cache_hit = data.get("prompt_cache_hit_tokens") or prompt_details.get("cached_tokens")

    return {
        "tokens_in": _count(data.get("prompt_tokens")),
        "tokens_out": _count(data.get("completion_tokens")),
        "cache_hit_tokens": _count(cache_hit),
        "cache_miss_tokens": _count(data.get("prompt_cache_miss_tokens")),
        "reasoning_tokens": _count(completion_details.get("reasoning_tokens")),
    }
|
|
|
|
|
|
def _extract_usage(usage: Any) -> Tuple[int, int]:
    """Return ``(prompt_tokens, completion_tokens)`` from a litellm ``response.usage``."""
    breakdown = _extract_usage_details(usage)
    prompt_tokens = breakdown["tokens_in"]
    completion_tokens = breakdown["tokens_out"]
    return prompt_tokens, completion_tokens
|
|
|
|
|
|
class AgentLoop:
    """ReAct-style loop: alternate LLM rounds and tool executions until the model
    answers without tool_calls.

    Nothing is printed; progress is reported via ``sink.emit(event)`` when a sink
    is configured.
    """

    # Maximum number of times a round with unparseable tool_call arguments is
    # discarded and re-rolled. The last of these attempts falls back to a
    # non-streaming call (the provider assembles tool_calls server-side, which
    # sidesteps streaming delta misalignment). Large-argument tool calls have been
    # observed to come back malformed twice in a row, hence the headroom.
    _MAX_MALFORMED_RETRIES = 3

    # Cap on the tool result text returned to the model, to keep context size in check.
    _MAX_TOOL_RESULT_CHARS = 16_000
    # Length of the result preview attached to ``tool_result`` events.
    _TOOL_RESULT_PREVIEW_CHARS = 400

    def __init__(
        self,
        llm: LLM,
        executor: Executor,
        session: Session,
        capabilities: ModelCapabilities,
        user_id: UUID,
        working_dir: Path,
        sink: Optional[Any] = None,
        max_iterations: Optional[int] = None,
        cancel_check: Optional[Callable[[], bool]] = None,
    ) -> None:
        """Wire the loop to its collaborators.

        Args:
            llm: client providing ``chat`` and ``chat_stream``.
            executor: provides tool ``schemas()`` and ``call_tool``.
            session: message history; ``append`` persists a message.
            capabilities: model profile (pricing, reasoning effort, iteration limit).
            user_id: owner of the run; usage events are aggregated per user.
            working_dir: passed into each ``ExecCtx``.
            sink: optional event sink with an ``emit(event)`` method.
            max_iterations: round budget; falsy values fall back to
                ``capabilities.max_iterations``.
            cancel_check: optional callable returning True once the run should stop.
        """
        self.llm = llm
        self.executor = executor
        self.session = session
        self.caps = capabilities
        self.user_id = user_id  # usage_events are aggregated per user when written
        # ExecCtx already carries user_id / task_id; working_dir is passed separately.
        # Intended for the docker backend (Step 3) to build
        # `--workdir /workspace/<wd_name>` and to namespace temp files.
        self.working_dir = working_dir
        self.max_iterations = max_iterations or capabilities.max_iterations
        self.sink = sink
        # Cooperative cancel: the web layer injects `lambda: broker.is_cancelled(task_id)`;
        # the CLI path leaves it unset (None -> never cancelled). It is checked
        # (1) before each LLM round, (2) between stream chunks, (3) between tool calls,
        # and is also handed to tools via ExecCtx. Polling between chunks cuts cancel
        # latency from a whole generation (tens of seconds) to one chunk interval (~100ms).
        self.cancel_check = cancel_check

    def _emit(self, event: dict) -> None:
        """Forward ``event`` to the sink, if one is configured."""
        if self.sink is not None:
            self.sink.emit(event)

    def _is_cancelled(self) -> bool:
        """True when a cancel_check is configured and currently reports cancellation."""
        return bool(self.cancel_check and self.cancel_check())

    def _finish_cancelled(self) -> str:
        """Emit the ``cancelled`` event and return the cancelled sentinel for ``run``."""
        self._emit({"type": "cancelled"})
        return "[cancelled]"

    def _fill_cancelled_tool_results(self, remaining: list) -> None:
        """Append a placeholder tool result for every tool_call that was not executed.

        Every assistant tool_call must have a matching tool message, otherwise the
        LLM rejects the history when the task is resumed.
        """
        for tc in remaining:
            self.session.append({
                "role": "tool",
                "tool_call_id": tc.id,
                "content": _CANCELLED_TOOL_PLACEHOLDER,
            })

    def run(self, user_message: str) -> str:
        """Run one user turn to completion.

        Appends ``user_message`` to the session, then repeatedly: calls the LLM,
        persists and bills the assistant message, and executes its tool_calls,
        appending each result. Stops when the model answers without tool_calls,
        when cancelled, or after ``max_iterations`` rounds.

        Returns:
            The final assistant content (``""`` if none), ``"[cancelled]"`` on
            cancellation, or ``"[reached max iterations]"`` when the budget runs out.
        """
        self.session.append({"role": "user", "content": user_message})

        for _ in range(self.max_iterations):
            if self._is_cancelled():
                return self._finish_cancelled()

            start = time.monotonic()
            response, cancelled_mid_stream = self._stream_llm()
            elapsed = time.monotonic() - start

            if cancelled_mid_stream:
                # Chunks received so far are dropped: nothing is persisted or billed
                # (not even partial assistant content), so a resume starts clean.
                return self._finish_cancelled()

            msg = response.choices[0].message
            asst_msg_id = self.session.append(msg)

            usage_details = _extract_usage_details(getattr(response, "usage", None))
            self._record_usage(asst_msg_id, response, usage_details)
            self._emit({
                "type": "llm_end",
                "prompt_tokens": usage_details["tokens_in"],
                "completion_tokens": usage_details["tokens_out"],
                "cache_hit_tokens": usage_details["cache_hit_tokens"],
                "cache_miss_tokens": usage_details["cache_miss_tokens"],
                "elapsed": elapsed,
            })

            tool_calls = getattr(msg, "tool_calls", None) or []
            # Content was already emitted as deltas while streaming (or as one text
            # event by the non-stream fallback), so no whole-text event is sent here.
            if not tool_calls:
                self._emit({"type": "done"})
                return getattr(msg, "content", None) or ""

            for i, tc in enumerate(tool_calls):
                if self._is_cancelled():
                    self._fill_cancelled_tool_results(tool_calls[i:])
                    return self._finish_cancelled()
                result = self._execute_tool_call(tc)
                self.session.append({
                    "role": "tool",
                    "tool_call_id": tc.id,
                    "name": tc.function.name,
                    "content": result,
                })

        self._emit({"type": "done"})
        return "[reached max iterations]"

    def _record_usage(self, asst_msg_id: Any, response: Any, usage_details: dict) -> None:
        """Bill one LLM round (ref. 0006).

        Writes one usage_event row and back-fills the assistant message's
        tokens_in/out and model_profile.

        Best-effort: any failure (litellm cost-map miss, DB error, ...) is reported
        as a ``warn`` event and swallowed so it never blocks the loop; the message
        is already in the session/DB regardless.
        """
        model_profile = f"{self.caps.family}.{self.caps.variant}"
        try:
            record_chat_usage(
                task_id=self.session.task_id,
                user_id=self.user_id,
                message_id=asst_msg_id,
                model_profile=model_profile,
                prompt_tokens=usage_details["tokens_in"],
                completion_tokens=usage_details["tokens_out"],
                input_cny_per_mtoken=self.caps.input_cny_per_mtoken,
                output_cny_per_mtoken=self.caps.output_cny_per_mtoken,
                cache_hit_tokens=usage_details["cache_hit_tokens"],
                cache_hit_cny_per_mtoken=self.caps.cache_hit_cny_per_mtoken,
                extra_units={
                    k: v for k, v in usage_details.items()
                    if k not in ("tokens_in", "tokens_out") and v
                },
                response=response,
            )
        except Exception as e:
            self._emit({"type": "warn", "msg": f"record_usage failed: {type(e).__name__}: {e}"})

    def _stream_llm(self) -> Tuple[Optional[Any], bool]:
        """Run one LLM round and make sure the returned tool_call arguments parse.

        Returns ``(response, cancelled_mid_stream)``:
          - normal completion -> ``(response, False)``; the response has the same
            shape as a non-streaming ``completion()`` (``choices[0].message`` + ``usage``).
          - cancelled mid-stream -> ``(None, True)``; received chunks are dropped and
            the underlying stream is closed.

        Malformed-argument retry: deepseek-v4-flash occasionally splices misplaced
        fragments into large tool_call arguments, so the reassembled JSON fails to
        parse. Persisting that would resend it every round and nudge the model into
        repeating the corruption. So a round whose arguments don't parse, or a
        stream that produced no response at all, is discarded (not appended, not
        billed) and re-rolled. The final attempt falls back to a non-streaming call.
        Tokens spent on discarded attempts are not billed (rare path).

        NOTE(review): text deltas from a discarded streaming attempt have already
        been emitted to the sink, so a frontend may show that text twice.
        """
        llm_messages, context_stats = prepare_messages_with_stats(self.session.messages)
        self._emit({
            "type": "llm_start",
            **{f"context_{k}": v for k, v in context_stats.items()},
        })

        response: Optional[Any] = None
        for attempt in range(self._MAX_MALFORMED_RETRIES + 1):
            is_fallback = attempt == self._MAX_MALFORMED_RETRIES
            if is_fallback:
                response = self._nonstream_once(llm_messages)
            else:
                response, cancelled = self._collect_stream_once(llm_messages)
                if cancelled:
                    return None, True
                if response is None:
                    # stream_chunk_builder may return None (e.g. no chunks arrived).
                    # Retry instead of letting run() crash on ``None.choices``.
                    self._emit({
                        "type": "warn",
                        "msg": f"流式响应为空,重试 ({attempt + 1}/{self._MAX_MALFORMED_RETRIES})",
                    })
                    continue

            bad = _malformed_tool_calls(response)
            if not bad:
                return response, False
            if is_fallback:
                # Even the non-streaming fallback is malformed (should be very rare).
                # There is no retry left: hand it to _execute_tool_call, whose
                # invalid-JSON branch reports the error back to the model, instead
                # of looping here.
                self._emit({
                    "type": "warn",
                    "msg": f"非流式兜底工具调用参数仍损坏 {bad},交由工具执行阶段返错",
                })
                break
            self._emit({
                "type": "warn",
                "msg": f"工具调用参数损坏 {bad},丢弃本轮重试 ({attempt + 1}/{self._MAX_MALFORMED_RETRIES})",
            })
        return response, False

    def _collect_stream_once(self, llm_messages: List[dict]) -> Tuple[Optional[Any], bool]:
        """Consume one streaming call, forwarding content deltas to the sink as they arrive.

        Returns ``(response, cancelled_mid_stream)``. On cancel, the collected chunks
        are dropped and ``(None, True)`` is returned. Otherwise the chunks are
        reassembled into a full response with ``litellm.stream_chunk_builder``.
        """
        chunks: List[Any] = []
        stream = self.llm.chat_stream(
            messages=llm_messages,
            tools=self.executor.schemas(),
            reasoning_effort=self.caps.default_reasoning_effort or None,
        )
        cancelled = False
        try:
            for chunk in stream:
                if self._is_cancelled():
                    cancelled = True
                    break
                chunks.append(chunk)
                # Content deltas go straight to the sink for typewriter rendering.
                # tool_call deltas are not forwarded: they are fragmented across
                # chunks, and the whole call is emitted later by _execute_tool_call.
                delta_text = _extract_delta_content(chunk)
                if delta_text:
                    self._emit({"type": "text", "delta": delta_text})
        finally:
            # Closing a generator raises GeneratorExit inside chat_stream, whose own
            # finally is expected to close the underlying connection (also on early
            # break). Guarded in case chat_stream returns a plain iterator.
            close = getattr(stream, "close", None)
            if close is not None:
                close()

        if cancelled:
            return None, True

        # litellm's helper rebuilds the full response (tool_call concatenation +
        # usage). ``messages`` is only used to estimate prompt tokens when usage is
        # missing; normally stream_options.include_usage puts exact usage on the
        # last chunk.
        response = litellm.stream_chunk_builder(chunks, messages=llm_messages)
        return response, False

    def _nonstream_once(self, llm_messages: List[dict]) -> Any:
        """Non-streaming fallback.

        The provider assembles tool_calls server-side, which sidesteps streaming
        delta misalignment. There is no chunk-level cancel and no content deltas
        here, so the whole content is emitted as a single ``text`` event.
        """
        response = self.llm.chat(
            messages=llm_messages,
            tools=self.executor.schemas(),
            reasoning_effort=self.caps.default_reasoning_effort or None,
        )
        try:
            text = getattr(response.choices[0].message, "content", None)
        except (AttributeError, IndexError, KeyError, TypeError):
            # Unexpected response shape: skip the text event and return the
            # response as-is. Sink errors are deliberately not caught here.
            text = None
        if text:
            self._emit({"type": "text", "delta": text})
        return response

    def _execute_tool_call(self, tc: Any) -> str:
        """Execute one tool_call and return the result text to feed back to the model.

        Emits ``tool_call`` before and ``tool_result`` after execution. Argument
        problems are returned as an ``[Error] ...`` string without running the
        tool or emitting events. Those problems are JSON that does not parse, or
        JSON that is not an object. Results longer than ``_MAX_TOOL_RESULT_CHARS``
        are truncated with a marker.
        """
        name = tc.function.name
        # Blank / whitespace-only arguments mean "no arguments". This matches
        # _malformed_tool_calls, which does not treat them as malformed.
        raw_args = (tc.function.arguments or "").strip() or "{}"
        try:
            args = json.loads(raw_args)
        except json.JSONDecodeError as e:
            return f"[Error] invalid JSON arguments for {name}: {e}"
        if not isinstance(args, dict):
            # Function-call arguments must be a JSON object; report instead of
            # passing e.g. a list or number into the executor.
            return (
                f"[Error] invalid arguments for {name}: expected a JSON object, "
                f"got {type(args).__name__}"
            )

        args_preview = json.dumps(args, ensure_ascii=False)
        if len(args_preview) > 200:
            args_preview = args_preview[:200] + "..."
        self._emit({
            "type": "tool_call",
            "name": name,
            "args": args,
            "args_preview": args_preview,
        })

        ctx = ExecCtx(
            user_id=self.user_id,
            task_id=self.session.task_id,
            working_dir=self.working_dir,
            cancel_check=self.cancel_check,
        )
        result = self.executor.call_tool(name, args, ctx).content

        # Bound the size of what goes back to the model to avoid blowing the context.
        max_len = self._MAX_TOOL_RESULT_CHARS
        truncated = False
        if len(result) > max_len:
            result = result[:max_len] + f"\n[... truncated, {len(result) - max_len} chars ...]"
            truncated = True

        preview_len = self._TOOL_RESULT_PREVIEW_CHARS
        preview = result if len(result) < preview_len else result[:preview_len] + "..."
        self._emit({
            "type": "tool_result",
            "name": name,
            "result": result,
            "preview": preview,
            "truncated": truncated,
        })
        return result
|