zcbot/core/llm.py

90 lines
2.9 KiB
Python

"""LiteLLM 封装: capabilities 决定调用参数,自动重试。"""
from __future__ import annotations
import os
import time
from typing import Any, List, Optional
import litellm
from litellm.exceptions import (
APIConnectionError,
APIError,
RateLimitError,
ServiceUnavailableError,
Timeout,
)
from .capabilities import ModelCapabilities
class TokenCounter:
    """Accumulate token usage reported by successive LLM responses."""

    def __init__(self) -> None:
        # Running totals, updated in place by ``add``.
        self.prompt_tokens = 0
        self.completion_tokens = 0

    def add(self, usage: Any) -> None:
        """Add one response's usage to the running totals.

        ``usage`` may be an object exposing ``model_dump()`` (tried first) or
        ``dict()``, or a plain ``dict``. Falsy values and anything that does
        not end up as a ``dict`` are ignored. Missing or ``None`` counts
        contribute zero.
        """
        if not usage:
            return
        for converter in ("model_dump", "dict"):
            if hasattr(usage, converter):
                usage = getattr(usage, converter)()
                break
        if not isinstance(usage, dict):
            return
        self.prompt_tokens += int(usage.get("prompt_tokens") or 0)
        self.completion_tokens += int(usage.get("completion_tokens") or 0)

    @property
    def total(self) -> int:
        """Sum of prompt and completion tokens seen so far."""
        return sum((self.prompt_tokens, self.completion_tokens))
class LLM:
    """Thin wrapper over ``litellm.completion`` driven by ``ModelCapabilities``.

    The capabilities object decides which optional request parameters are
    sent (parallel tool calls, reasoning effort, prompt-caching header).
    Transient provider errors are retried with exponential backoff. Token
    usage from every successful response is accumulated in
    ``self.token_counter``.
    """

    def __init__(self, capabilities: ModelCapabilities) -> None:
        """Bind to a model and resolve its API key from the environment.

        Args:
            capabilities: Model description. ``api_key_env`` names the
                environment variable holding the key and falls back to
                ``DEEPSEEK_API_KEY``. ``api_base`` optionally overrides
                the endpoint.

        Raises:
            RuntimeError: If the API key environment variable is unset or empty.
        """
        self.caps = capabilities
        env_name = capabilities.api_key_env or "DEEPSEEK_API_KEY"
        self.api_key = os.environ.get(env_name)
        # Normalise an empty base URL to None so it is simply not sent.
        self.api_base = capabilities.api_base or None
        self.token_counter = TokenCounter()
        if not self.api_key:
            raise RuntimeError(
                f"环境变量 {env_name} 未设置,无法调用 {capabilities.model_id}"
            )

    def _build_kwargs(
        self,
        messages: List[dict],
        tools: Optional[list],
        parallel_tool_calls: Optional[bool],
        reasoning_effort: Optional[str],
    ) -> dict:
        """Assemble the ``litellm.completion`` keyword arguments for one request."""
        kwargs: dict = {
            "model": self.caps.model_id,
            "messages": messages,
            "temperature": self.caps.optimal_temperature,
            "api_key": self.api_key,
        }
        if self.api_base:
            kwargs["api_base"] = self.api_base
        if tools:
            kwargs["tools"] = tools
            # Parallel tool calls are opt-out: enabled whenever the model
            # supports them unless the caller explicitly passes False.
            # Only meaningful (and only accepted by providers) with tools.
            if self.caps.parallel_tools and parallel_tool_calls is not False:
                kwargs["parallel_tool_calls"] = True
        if self.caps.thinking_mode and reasoning_effort:
            kwargs["reasoning_effort"] = reasoning_effort
        if self.caps.prompt_caching:
            # Anthropic beta header that opts the request into prompt caching.
            kwargs["extra_headers"] = {"anthropic-beta": "prompt-caching-2024-07-31"}
        return kwargs

    def _complete_once(self, kwargs: dict) -> Any:
        """Issue a single completion request and record its token usage."""
        response = litellm.completion(**kwargs)
        self.token_counter.add(getattr(response, "usage", None))
        return response

    def chat(
        self,
        messages: List[dict],
        tools: Optional[list] = None,
        parallel_tool_calls: Optional[bool] = None,
        reasoning_effort: Optional[str] = None,
        max_retries: int = 3,
    ) -> Any:
        """Send a chat completion request, retrying transient failures.

        Args:
            messages: Chat messages forwarded verbatim to litellm.
            tools: Optional tool schemas. Omitted from the request when empty.
            parallel_tool_calls: Pass ``False`` to suppress parallel tool calls
                on models that support them. ``None``/``True`` leave them on.
                Ignored when no tools are given.
            reasoning_effort: Forwarded only when the model has thinking mode.
            max_retries: Total number of attempts, including the first one.
                Values below 1 are treated as 1, so the request is always
                made at least once.

        Returns:
            The raw litellm response object.

        Raises:
            The last retryable litellm error (rate limit, connection,
            service-unavailable, timeout, generic API error) once all attempts
            are exhausted. Any other exception propagates immediately.
        """
        kwargs = self._build_kwargs(
            messages, tools, parallel_tool_calls, reasoning_effort
        )
        # Guard against max_retries <= 0, which previously ended in `raise None`.
        attempts = max(1, max_retries)
        for attempt in range(attempts - 1):
            try:
                return self._complete_once(kwargs)
            except (
                RateLimitError,
                APIConnectionError,
                ServiceUnavailableError,
                Timeout,
                APIError,
            ):
                # Transient provider error: back off 1s, 2s, 4s, ... and retry.
                time.sleep(2 ** attempt)
        # Final attempt runs outside the handler so its error reaches the caller.
        return self._complete_once(kwargs)