94 lines
3.2 KiB
Python
94 lines
3.2 KiB
Python
"""LiteLLM 封装: capabilities 决定调用参数,自动重试。"""
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import time
|
|
from typing import Any, List, Optional
|
|
|
|
# Skip litellm's start-up network request that downloads model_prices from
# GitHub; use the copy bundled with the litellm package instead.
# This must be set before `import litellm`, otherwise get_model_cost_map()
# will already have run. An explicit value from the environment wins.
if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
|
|
|
|
import litellm # noqa: E402
|
|
from litellm.exceptions import (
|
|
APIConnectionError,
|
|
APIError,
|
|
RateLimitError,
|
|
ServiceUnavailableError,
|
|
Timeout,
|
|
)
|
|
|
|
from .capabilities import ModelCapabilities
|
|
|
|
|
|
class TokenCounter:
    """Running totals of prompt and completion tokens across responses."""

    def __init__(self) -> None:
        self.prompt_tokens = 0
        self.completion_tokens = 0

    def add(self, usage: Any) -> None:
        """Add one response's ``usage`` payload to the running totals.

        *usage* may be a plain dict or an object exposing ``model_dump()`` or
        ``dict()`` (pydantic v2 / v1 style); the first available converter is
        used. Falsy payloads and anything that does not end up as a dict are
        ignored. Missing or ``None`` counts contribute 0.
        """
        if not usage:
            return
        for converter in ("model_dump", "dict"):
            if hasattr(usage, converter):
                usage = getattr(usage, converter)()
                break
        if not isinstance(usage, dict):
            return
        self.prompt_tokens += int(usage.get("prompt_tokens") or 0)
        self.completion_tokens += int(usage.get("completion_tokens") or 0)

    @property
    def total(self) -> int:
        """Prompt plus completion tokens accumulated so far."""
        return self.completion_tokens + self.prompt_tokens
|
|
|
|
|
|
class LLM:
    """Thin wrapper around ``litellm.completion`` driven by model capabilities.

    The ``ModelCapabilities`` object decides which optional request parameters
    are sent (parallel tool calls, reasoning effort, prompt-caching header).
    Transient provider errors are retried with exponential backoff, and the
    token usage of every successful response is accumulated in
    ``self.token_counter``.
    """

    def __init__(self, capabilities: ModelCapabilities) -> None:
        """Store *capabilities* and resolve the API key and base URL.

        The API key is read from the environment variable named by
        ``capabilities.api_key_env`` (``DEEPSEEK_API_KEY`` when that is empty).

        Raises:
            RuntimeError: if that environment variable is unset or empty.
        """
        self.caps = capabilities
        env_name = capabilities.api_key_env or "DEEPSEEK_API_KEY"
        self.api_key = os.environ.get(env_name)
        # An empty api_base is normalised to None so it is not forwarded.
        self.api_base = capabilities.api_base or None
        self.token_counter = TokenCounter()
        if not self.api_key:
            raise RuntimeError(
                f"环境变量 {env_name} 未设置,无法调用 {capabilities.model_id}"
            )

    def _build_kwargs(
        self,
        messages: List[dict],
        tools: Optional[list],
        parallel_tool_calls: Optional[bool],
        reasoning_effort: Optional[str],
    ) -> dict[str, Any]:
        """Assemble the keyword arguments passed to ``litellm.completion``."""
        kwargs: dict[str, Any] = {
            "model": self.caps.model_id,
            "messages": messages,
            "temperature": self.caps.optimal_temperature,
            "api_key": self.api_key,
        }
        if self.api_base:
            kwargs["api_base"] = self.api_base
        if tools:
            kwargs["tools"] = tools
            # Only models declared to support the flag receive it. An explicit
            # False from the caller must be forwarded: omitting the parameter
            # leaves the provider default in place, and OpenAI documents that
            # default as parallel calls *enabled*.
            if self.caps.parallel_tools:
                kwargs["parallel_tool_calls"] = parallel_tool_calls is not False
        if self.caps.thinking_mode and reasoning_effort:
            kwargs["reasoning_effort"] = reasoning_effort
        if self.caps.prompt_caching:
            # Anthropic beta header that opts the request into prompt caching.
            kwargs["extra_headers"] = {"anthropic-beta": "prompt-caching-2024-07-31"}
        return kwargs

    def chat(
        self,
        messages: List[dict],
        tools: Optional[list] = None,
        parallel_tool_calls: Optional[bool] = None,
        reasoning_effort: Optional[str] = None,
        max_retries: int = 3,
    ) -> Any:
        """Send a chat completion request, retrying transient failures.

        Args:
            messages: Chat messages passed through to ``litellm.completion``.
            tools: Tool schemas; omitted from the request when empty or None.
            parallel_tool_calls: ``False`` asks the model not to issue
                parallel tool calls; ``None`` or ``True`` request them. Only
                sent when ``tools`` is given and the model supports it.
            reasoning_effort: Forwarded only for thinking-mode models.
            max_retries: Total number of attempts (including the first call).
                Values below 1 are treated as 1.

        Returns:
            The raw ``litellm.completion`` response object.

        Raises:
            RateLimitError, APIConnectionError, ServiceUnavailableError,
            Timeout, APIError: the error from the final attempt once all
            attempts fail. Any other exception propagates immediately.
        """
        kwargs = self._build_kwargs(messages, tools, parallel_tool_calls, reasoning_effort)
        # Guarantee at least one call; with 0 attempts the original code
        # reached `raise None` and failed with an unrelated TypeError.
        attempts = max(1, max_retries)
        for attempt in range(attempts):
            try:
                response = litellm.completion(**kwargs)
            except (RateLimitError, APIConnectionError, ServiceUnavailableError, Timeout, APIError):
                if attempt == attempts - 1:
                    raise
                # Exponential backoff between attempts: 1s, 2s, 4s, ...
                time.sleep(2 ** attempt)
            else:
                self.token_counter.add(getattr(response, "usage", None))
                return response
|