"""主 agent loop: ReAct 风格,LLM ↔ Tool 反复直到无 tool_call。""" from __future__ import annotations import json import threading import time from contextlib import contextmanager from typing import Any, Dict, Optional, Tuple from rich.console import Console from rich.markdown import Markdown from .capabilities import ModelCapabilities from .llm import LLM from .session import Session def _extract_usage(usage: Any) -> Tuple[int, int]: """从 litellm response.usage 提 (prompt_tokens, completion_tokens)。""" if not usage: return 0, 0 if hasattr(usage, "model_dump"): usage = usage.model_dump() elif hasattr(usage, "dict"): usage = usage.dict() if isinstance(usage, dict): return int(usage.get("prompt_tokens") or 0), int(usage.get("completion_tokens") or 0) return 0, 0 class AgentLoop: def __init__( self, llm: LLM, tools: Dict[str, Any], session: Session, capabilities: ModelCapabilities, console: Optional[Console] = None, max_iterations: Optional[int] = None, ) -> None: self.llm = llm self.tools = tools self.session = session self.caps = capabilities self.max_iterations = max_iterations or capabilities.max_iterations self.console = console or Console() @contextmanager def _thinking(self): """spinner 实时刷耗时 + 上下文 token 数。yield 出的 ctx 退出后填 elapsed。""" start = time.monotonic() stop = threading.Event() def fmt() -> str: elapsed = time.monotonic() - start total = self.llm.token_counter.total tail = f" ctx {total:,} tok" if total else "" return f"[dim]thinking... {elapsed:.1f}s{tail}[/dim]" class Ctx: elapsed: float = 0.0 ctx = Ctx() status = self.console.status(fmt(), spinner="dots") def tick() -> None: while not stop.wait(0.1): try: status.update(fmt()) except Exception: return with status: th = threading.Thread(target=tick, daemon=True) th.start() try: yield ctx finally: stop.set() th.join(timeout=0.5) ctx.elapsed = time.monotonic() - start def run(self, user_message: str) -> str: self.session.append({"role": "user", "content": user_message}) for _ in range(self.max_iterations): with self._thinking() as t: response = self.llm.chat( messages=self.session.messages, tools=[t.schema for t in self.tools.values()], reasoning_effort=self.caps.default_reasoning_effort or None, ) msg = response.choices[0].message self.session.append(msg) pt, ct = _extract_usage(getattr(response, "usage", None)) self.console.print( f"[dim][in {pt:,} out {ct:,} t {t.elapsed:.1f}s][/dim]" ) tool_calls = getattr(msg, "tool_calls", None) or [] content = getattr(msg, "content", None) if content: self.console.print("[cyan]assistant>[/cyan]") self.console.print(Markdown(content)) if not tool_calls: return content or "" for tc in tool_calls: result = self._execute_tool_call(tc) self.session.append( { "role": "tool", "tool_call_id": tc.id, "content": result, } ) return "[reached max iterations]" def _execute_tool_call(self, tc: Any) -> str: name = tc.function.name raw_args = tc.function.arguments or "{}" try: args = json.loads(raw_args) except json.JSONDecodeError as e: return f"[Error] invalid JSON arguments for {name}: {e}" preview = json.dumps(args, ensure_ascii=False) if len(preview) > 200: preview = preview[:200] + "..." self.console.print(f"[yellow]tool>[/yellow] {name}({preview})") tool = self.tools.get(name) if tool is None: return f"[Error] unknown tool: {name}" try: result = tool.execute(**args) except TypeError as e: return f"[Error] bad arguments to {name}: {e}" except Exception as e: return f"[Error executing {name}] {type(e).__name__}: {e}" if not isinstance(result, str): result = str(result) # 控制返回给模型的 tool 结果体量,避免炸 context MAX_LEN = 16_000 if len(result) > MAX_LEN: result = result[:MAX_LEN] + f"\n[... truncated, {len(result) - MAX_LEN} chars ...]" # 给用户预览(截短) preview = result if len(result) < 400 else result[:400] + "..." self.console.print(f"[dim]{preview}[/dim]") return result