# zcbot/core/loop.py (Python; 158 lines, 5.1 KiB in the original listing)

"""主 agent loop: ReAct 风格,LLM ↔ Tool 反复直到无 tool_call。"""
from __future__ import annotations
import json
import threading
import time
from contextlib import contextmanager
from typing import Any, Dict, Optional, Tuple
from rich.console import Console
from rich.markdown import Markdown
from .capabilities import ModelCapabilities
from .llm import LLM
from .session import Session
def _extract_usage(usage: Any) -> Tuple[int, int]:
"""从 litellm response.usage 提 (prompt_tokens, completion_tokens)。"""
if not usage:
return 0, 0
if hasattr(usage, "model_dump"):
usage = usage.model_dump()
elif hasattr(usage, "dict"):
usage = usage.dict()
if isinstance(usage, dict):
return int(usage.get("prompt_tokens") or 0), int(usage.get("completion_tokens") or 0)
return 0, 0
class AgentLoop:
    """ReAct-style agent loop: call the LLM, run requested tools, repeat.

    The loop stops when a model reply carries no tool calls, or after
    ``max_iterations`` LLM calls.
    """

    # Max characters of a tool result handed back to the model, so a single
    # tool call cannot blow up the context.
    _MAX_RESULT_CHARS = 16_000
    # Max characters of a tool result echoed to the console.
    _RESULT_PREVIEW_CHARS = 400
    # Max characters of the tool-call arguments echoed to the console.
    _ARGS_PREVIEW_CHARS = 200

    def __init__(
        self,
        llm: LLM,
        tools: Dict[str, Any],
        session: Session,
        capabilities: ModelCapabilities,
        console: Optional[Console] = None,
        max_iterations: Optional[int] = None,
    ) -> None:
        """Store the collaborators used by the loop.

        Args:
            llm: Client whose ``chat(...)`` is called once per iteration.
            tools: Mapping of tool name to tool object; each tool is used
                through its ``schema`` attribute and ``execute(**kwargs)``.
            session: Conversation store; read via ``messages`` and extended
                via ``append``.
            capabilities: Supplies ``default_reasoning_effort`` and the
                fallback ``max_iterations``.
            console: Rich console for output; a new ``Console()`` if omitted.
            max_iterations: Cap on LLM calls per ``run``; when falsy,
                ``capabilities.max_iterations`` is used.
        """
        self.llm = llm
        self.tools = tools
        self.session = session
        self.caps = capabilities
        self.max_iterations = max_iterations or capabilities.max_iterations
        self.console = console or Console()

    @contextmanager
    def _thinking(self):
        """Show a spinner that live-updates elapsed time and context tokens.

        Yields an object whose ``elapsed`` attribute (seconds) is filled in
        when the ``with`` block exits, so read it after the block.
        """
        start = time.monotonic()
        stop = threading.Event()

        def fmt() -> str:
            elapsed = time.monotonic() - start
            total = self.llm.token_counter.total
            tail = f" ctx {total:,} tok" if total else ""
            return f"[dim]thinking... {elapsed:.1f}s{tail}[/dim]"

        class Ctx:
            elapsed: float = 0.0

        ctx = Ctx()
        status = self.console.status(fmt(), spinner="dots")

        def tick() -> None:
            # Refresh the status text every 0.1 s until asked to stop.
            while not stop.wait(0.1):
                try:
                    status.update(fmt())
                except Exception:
                    # Best effort: a failed refresh just ends the ticker
                    # instead of breaking the LLM call in progress.
                    return

        with status:
            th = threading.Thread(target=tick, daemon=True)
            th.start()
            try:
                yield ctx
            finally:
                stop.set()
                th.join(timeout=0.5)
                ctx.elapsed = time.monotonic() - start

    def run(self, user_message: str) -> str:
        """Run the loop for one user message and return the final reply.

        Appends the user message, every model message and every tool result
        to the session.

        Returns:
            The content of the first model message without tool calls
            (``""`` if it has no content), or ``"[reached max iterations]"``
            when the iteration cap is hit first.
        """
        self.session.append({"role": "user", "content": user_message})
        for _ in range(self.max_iterations):
            with self._thinking() as timer:
                response = self.llm.chat(
                    messages=self.session.messages,
                    tools=[tool.schema for tool in self.tools.values()],
                    reasoning_effort=self.caps.default_reasoning_effort or None,
                )
            msg = response.choices[0].message
            self.session.append(msg)

            pt, ct = _extract_usage(getattr(response, "usage", None))
            # Printed with markup disabled: "[in ...]" starts with a lowercase
            # letter, so Rich would parse it as a style tag and drop the text.
            self.console.print(
                f"[in {pt:,} out {ct:,} t {timer.elapsed:.1f}s]",
                style="dim",
                markup=False,
            )

            tool_calls = getattr(msg, "tool_calls", None) or []
            content = getattr(msg, "content", None)
            if content:
                self.console.print("[cyan]assistant>[/cyan]")
                self.console.print(Markdown(content))
            if not tool_calls:
                return content or ""

            for tc in tool_calls:
                result = self._execute_tool_call(tc)
                self.session.append(
                    {
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": result,
                    }
                )
        return "[reached max iterations]"

    def _execute_tool_call(self, tc: Any) -> str:
        """Execute one tool call and return its result as a string.

        Never raises for tool-level problems: invalid arguments, unknown
        tools and exceptions raised by the tool are returned as ``"[Error…"``
        strings so they can be fed back to the model.

        Args:
            tc: Tool call exposing ``function.name`` and
                ``function.arguments`` (a JSON string, or an already decoded
                dict).

        Returns:
            The tool result converted to ``str`` and truncated to
            ``_MAX_RESULT_CHARS`` characters plus a truncation note.
        """
        name = tc.function.name
        raw_args = tc.function.arguments or "{}"
        if isinstance(raw_args, dict):
            # Already decoded; json.loads would raise TypeError on a dict.
            args = raw_args
        else:
            try:
                args = json.loads(raw_args)
            except (json.JSONDecodeError, TypeError) as e:
                return f"[Error] invalid JSON arguments for {name}: {e}"

        preview = json.dumps(args, ensure_ascii=False)
        if len(preview) > self._ARGS_PREVIEW_CHARS:
            preview = preview[: self._ARGS_PREVIEW_CHARS] + "..."
        # The arguments are model-generated text: print them with markup
        # disabled so stray "[...]" sequences cannot be parsed as Rich tags.
        self.console.print("[yellow]tool>[/yellow]", end=" ")
        self.console.print(f"{name}({preview})", markup=False)

        tool = self.tools.get(name)
        if tool is None:
            return f"[Error] unknown tool: {name}"
        if not isinstance(args, dict):
            # Valid JSON but not an object (e.g. a list or null): it cannot
            # be expanded into keyword arguments.
            return (
                f"[Error] arguments for {name} must be a JSON object, "
                f"got {type(args).__name__}"
            )

        try:
            result = tool.execute(**args)
        except TypeError as e:
            return f"[Error] bad arguments to {name}: {e}"
        except Exception as e:
            return f"[Error executing {name}] {type(e).__name__}: {e}"

        if not isinstance(result, str):
            result = str(result)

        # Cap what goes back to the model to avoid blowing up the context.
        max_len = self._MAX_RESULT_CHARS
        if len(result) > max_len:
            result = result[:max_len] + f"\n[... truncated, {len(result) - max_len} chars ...]"

        # Short preview for the user. Tool output is arbitrary text, so it is
        # printed with markup disabled: a sequence such as "[/tmp]" would
        # otherwise raise MarkupError and abort the whole loop.
        limit = self._RESULT_PREVIEW_CHARS
        preview = result if len(result) < limit else result[:limit] + "..."
        self.console.print(preview, style="dim", markup=False)
        return result