158 lines
5.1 KiB
Python
158 lines
5.1 KiB
Python
"""主 agent loop: ReAct 风格,LLM ↔ Tool 反复直到无 tool_call。"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import threading
|
|
import time
|
|
from contextlib import contextmanager
|
|
from typing import Any, Dict, Optional, Tuple
|
|
|
|
from rich.console import Console
|
|
from rich.markdown import Markdown
|
|
|
|
from .capabilities import ModelCapabilities
|
|
from .llm import LLM
|
|
from .session import Session
|
|
|
|
|
|
def _extract_usage(usage: Any) -> Tuple[int, int]:
|
|
"""从 litellm response.usage 提 (prompt_tokens, completion_tokens)。"""
|
|
if not usage:
|
|
return 0, 0
|
|
if hasattr(usage, "model_dump"):
|
|
usage = usage.model_dump()
|
|
elif hasattr(usage, "dict"):
|
|
usage = usage.dict()
|
|
if isinstance(usage, dict):
|
|
return int(usage.get("prompt_tokens") or 0), int(usage.get("completion_tokens") or 0)
|
|
return 0, 0
|
|
|
|
|
|
class AgentLoop:
    """ReAct-style agent loop: query the LLM, run requested tools, repeat.

    Each turn sends the whole session history plus every tool's ``schema`` to
    ``llm.chat``. Tool results go back into the session as ``role: "tool"``
    messages; the loop ends when the model replies without tool calls or after
    ``max_iterations`` LLM round-trips.
    """

    # Cap on tool-result text returned to the model, so one verbose tool cannot
    # blow up the context window. Can be overridden per subclass or instance.
    MAX_TOOL_RESULT_CHARS = 16_000
    # How many characters of tool arguments / tool results are echoed to the console.
    ARGS_PREVIEW_CHARS = 200
    RESULT_PREVIEW_CHARS = 400

    def __init__(
        self,
        llm: LLM,
        tools: Dict[str, Any],
        session: Session,
        capabilities: ModelCapabilities,
        console: Optional[Console] = None,
        max_iterations: Optional[int] = None,
    ) -> None:
        """Wire up the loop's collaborators.

        Args:
            llm: client whose ``chat(messages=..., tools=..., reasoning_effort=...)``
                is called each turn; ``llm.token_counter.total`` feeds the spinner.
            tools: tool name -> tool object; each tool's ``schema`` is sent to the
                model and ``execute(**kwargs)`` runs it.
            session: conversation store; its ``append`` and ``messages`` are used.
            capabilities: supplies ``max_iterations`` and ``default_reasoning_effort``.
            console: Rich console for output; a fresh ``Console()`` when omitted.
            max_iterations: cap on LLM round-trips per ``run``; ``None`` (or any
                falsy value, including 0) falls back to ``capabilities.max_iterations``.
        """
        self.llm = llm
        self.tools = tools
        self.session = session
        self.caps = capabilities
        self.max_iterations = max_iterations or capabilities.max_iterations
        self.console = console or Console()

    @contextmanager
    def _thinking(self):
        """Show a spinner with live elapsed time and context token count.

        Yields an object whose ``elapsed`` attribute (seconds) is 0.0 while the
        ``with`` body runs and is filled in once the body finishes.
        """
        started = time.monotonic()
        done = threading.Event()

        def label() -> str:
            elapsed = time.monotonic() - started
            total = self.llm.token_counter.total
            tail = f" ctx {total:,} tok" if total else ""
            return f"[dim]thinking... {elapsed:.1f}s{tail}[/dim]"

        class Ctx:
            elapsed: float = 0.0

        ctx = Ctx()
        status = self.console.status(label(), spinner="dots")

        def refresh() -> None:
            # Redraw the label every 100 ms until the body finishes.
            while not done.wait(0.1):
                try:
                    status.update(label())
                except Exception:
                    # Display refresh is best-effort; never let it break the call.
                    return

        with status:
            ticker = threading.Thread(target=refresh, daemon=True)
            ticker.start()
            try:
                yield ctx
            finally:
                done.set()
                ticker.join(timeout=0.5)
                ctx.elapsed = time.monotonic() - started

    def run(self, user_message: str) -> str:
        """Add *user_message* to the session and loop until the model stops calling tools.

        Returns:
            The final assistant text (``""`` if that reply had no content), or
            ``"[reached max iterations]"`` if the budget ran out while the model
            was still requesting tools.
        """
        self.session.append({"role": "user", "content": user_message})

        for _ in range(self.max_iterations):
            with self._thinking() as timer:
                response = self.llm.chat(
                    messages=self.session.messages,
                    # Rebuilt each turn, so changes to self.tools between turns are offered.
                    tools=[tool.schema for tool in self.tools.values()],
                    reasoning_effort=self.caps.default_reasoning_effort or None,
                )
            msg = response.choices[0].message
            self.session.append(msg)

            pt, ct = _extract_usage(getattr(response, "usage", None))
            # "\[" is Rich's escape for a literal "[": unescaped, "[in ... s]" is
            # parsed as a (bogus) style tag and the whole stats line renders blank.
            self.console.print(
                f"[dim]\\[in {pt:,} out {ct:,} t {timer.elapsed:.1f}s][/dim]"
            )

            tool_calls = getattr(msg, "tool_calls", None) or []
            content = getattr(msg, "content", None)
            if content:
                self.console.print("[cyan]assistant>[/cyan]")
                self.console.print(Markdown(content))

            if not tool_calls:
                return content or ""

            for tc in tool_calls:
                result = self._execute_tool_call(tc)
                self.session.append(
                    {
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": result,
                    }
                )

        return "[reached max iterations]"

    @staticmethod
    def _preview(text: str, limit: int) -> str:
        """Return *text* cut to *limit* characters, with ``...`` appended only when cut."""
        return text if len(text) <= limit else text[:limit] + "..."

    def _show_tool_call(self, name: str, args: Dict[str, Any]) -> None:
        """Echo ``name(args-json)`` to the console, the JSON cut to ARGS_PREVIEW_CHARS."""
        # rich is already a dependency of this module; the text is model-supplied,
        # so it must be escaped or "[/...]"-like content raises MarkupError.
        from rich.markup import escape

        args_json = self._preview(
            json.dumps(args, ensure_ascii=False), self.ARGS_PREVIEW_CHARS
        )
        call_text = f"{name}({args_json})"
        self.console.print(f"[yellow]tool>[/yellow] {escape(call_text)}")

    def _execute_tool_call(self, tc: Any) -> str:
        """Run one model-requested tool call and return the text fed back to the model.

        Failures (malformed JSON, non-object arguments, unknown tool, exception
        inside the tool) are returned as ``[Error...]`` strings rather than
        raised, so the model can see them and recover.
        """
        name = tc.function.name
        raw_args = tc.function.arguments or "{}"
        try:
            args = json.loads(raw_args)
        except json.JSONDecodeError as e:
            return f"[Error] invalid JSON arguments for {name}: {e}"
        # Arguments are passed as keywords below, so they must be a JSON object.
        if not isinstance(args, dict):
            return (
                f"[Error] arguments for {name} must be a JSON object, "
                f"got {type(args).__name__}"
            )

        self._show_tool_call(name, args)

        tool = self.tools.get(name)
        if tool is None:
            return f"[Error] unknown tool: {name}"

        try:
            result = tool.execute(**args)
        except TypeError as e:
            # Typically a missing/unexpected keyword from the model (note: a
            # TypeError raised inside the tool itself also lands here).
            return f"[Error] bad arguments to {name}: {e}"
        except Exception as e:
            return f"[Error executing {name}] {type(e).__name__}: {e}"

        if not isinstance(result, str):
            result = str(result)

        limit = self.MAX_TOOL_RESULT_CHARS
        if len(result) > limit:
            result = result[:limit] + f"\n[... truncated, {len(result) - limit} chars ...]"

        # markup=False: tool output is arbitrary text and must not be parsed as Rich markup.
        self.console.print(
            self._preview(result, self.RESULT_PREVIEW_CHARS), style="dim", markup=False
        )
        return result
|