152 lines
5.5 KiB
Python
152 lines
5.5 KiB
Python
"""LiteLLM wrapper: model capabilities decide the call parameters; transient errors are retried automatically.

`chat()`: synchronous and blocking; returns the complete response in one go. Used for
probes / offline checks.

`chat_stream()`: streaming generator that yields chunks; the caller accumulates them and uses
`litellm.stream_chunk_builder` to rebuild the complete response. The loop takes this path so it
can poll for cancellation between chunks (a synchronous LLM call cannot be interrupted; with
streaming, cancel latency is on the order of the inter-chunk gap, ~100 ms, rather than the tens
of seconds a whole generation takes).
"""

from __future__ import annotations

import os
import time
from typing import Any, Iterator, List, Optional

# Skip the startup network request that fetches model_prices from GitHub and use the local
# copy bundled with litellm instead.
# This must be set before `import litellm`; otherwise get_model_cost_map() has already run.
os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True")

import litellm  # noqa: E402
from litellm.exceptions import (
    APIConnectionError,
    APIError,
    RateLimitError,
    ServiceUnavailableError,
    Timeout,
)

from .capabilities import ModelCapabilities
|
|
|
|
|
|
class TokenCounter:
    """Running totals of prompt / completion tokens across LLM calls."""

    def __init__(self) -> None:
        self.prompt_tokens = 0
        self.completion_tokens = 0

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(prompt_tokens={self.prompt_tokens}, "
            f"completion_tokens={self.completion_tokens})"
        )

    def add(self, usage: Any) -> None:
        """Add one response's token usage to the running totals.

        ``usage`` may be:
          * falsy (``None``, empty dict, ...) -> ignored;
          * a pydantic model -> converted via ``model_dump()`` (v2) or ``dict()`` (v1);
          * a ``dict`` with ``prompt_tokens`` / ``completion_tokens`` keys;
          * any other object exposing those names as attributes.
        Missing or ``None`` counts are treated as 0.
        """
        if not usage:
            return
        if hasattr(usage, "model_dump"):
            usage = usage.model_dump()
        elif hasattr(usage, "dict"):
            usage = usage.dict()

        if isinstance(usage, dict):
            prompt = usage.get("prompt_tokens")
            completion = usage.get("completion_tokens")
        else:
            # Attribute-style usage objects (dataclass, SimpleNamespace, ...) would otherwise
            # be silently dropped; read the counts as attributes instead.
            prompt = getattr(usage, "prompt_tokens", None)
            completion = getattr(usage, "completion_tokens", None)

        self.prompt_tokens += int(prompt or 0)
        self.completion_tokens += int(completion or 0)

    @property
    def total(self) -> int:
        """Sum of prompt and completion tokens recorded so far."""
        return self.prompt_tokens + self.completion_tokens
|
|
|
|
|
|
class LLM:
    """LiteLLM wrapper bound to a single model's ``ModelCapabilities``.

    The capabilities object decides which optional request parameters are sent
    (parallel tool calls, reasoning effort, prompt-caching header). Token usage of
    non-streaming ``chat()`` calls is accumulated in ``self.token_counter``.
    """

    def __init__(self, capabilities: ModelCapabilities) -> None:
        """Bind the wrapper to ``capabilities`` and resolve the API key.

        The key is read from the environment variable named by
        ``capabilities.api_key_env``, falling back to ``DEEPSEEK_API_KEY``.

        Raises:
            RuntimeError: if that environment variable is unset or empty.
        """
        self.caps = capabilities
        env_name = capabilities.api_key_env or "DEEPSEEK_API_KEY"
        self.api_key = os.environ.get(env_name)
        # Normalize an empty api_base to None so it is not forwarded to litellm.
        self.api_base = capabilities.api_base or None
        self.token_counter = TokenCounter()
        if not self.api_key:
            raise RuntimeError(
                f"环境变量 {env_name} 未设置,无法调用 {capabilities.model_id}"
            )

    def _build_kwargs(
        self,
        messages: List[dict],
        tools: Optional[list],
        parallel_tool_calls: Optional[bool],
        reasoning_effort: Optional[str],
    ) -> dict:
        """Assemble the keyword arguments for ``litellm.completion``.

        Optional parameters are included only when the model's capabilities allow them:
          * ``tools`` only when non-empty. If the model supports parallel tool calls,
            ``parallel_tool_calls`` is sent alongside: ``False`` when the caller passed
            ``False`` explicitly, otherwise ``True``.
          * ``reasoning_effort`` only for thinking-mode models and only when given.
          * an ``anthropic-beta`` prompt-caching header when prompt caching is enabled.
        """
        kwargs: dict = {
            "model": self.caps.model_id,
            "messages": messages,
            "temperature": self.caps.optimal_temperature,
            "api_key": self.api_key,
        }
        if self.api_base:
            kwargs["api_base"] = self.api_base
        if tools:
            kwargs["tools"] = tools
            # parallel_tool_calls only makes sense together with tools; OpenAI rejects it
            # when no tools are specified.
            if self.caps.parallel_tools:
                # Forward an explicit False rather than omitting the key: omission leaves the
                # provider default in effect, which for OpenAI is "parallel calls enabled".
                kwargs["parallel_tool_calls"] = parallel_tool_calls is not False
        if self.caps.thinking_mode and reasoning_effort:
            kwargs["reasoning_effort"] = reasoning_effort
        if self.caps.prompt_caching:
            kwargs["extra_headers"] = {"anthropic-beta": "prompt-caching-2024-07-31"}
        return kwargs

    def _completion_with_retry(self, kwargs: dict, max_retries: int) -> Any:
        """Call ``litellm.completion(**kwargs)``, retrying transient errors with backoff.

        ``max_retries`` is the total number of attempts. Between failed attempts this
        sleeps 1s, 2s, 4s, ... The last transient error is re-raised once the attempts
        are exhausted; any other exception propagates immediately.
        """
        # With max_retries <= 0 the loop would never run and there would be no error to
        # re-raise (formerly `raise None` -> TypeError); always make at least one attempt.
        attempts = max(1, max_retries)
        for attempt in range(attempts):
            try:
                return litellm.completion(**kwargs)
            except (RateLimitError, APIConnectionError, ServiceUnavailableError, Timeout, APIError):
                if attempt == attempts - 1:
                    raise
                time.sleep(2 ** attempt)
        # Unreachable: the final iteration either returns or re-raises.
        raise AssertionError("unreachable")

    def chat(
        self,
        messages: List[dict],
        tools: Optional[list] = None,
        parallel_tool_calls: Optional[bool] = None,
        reasoning_effort: Optional[str] = None,
        max_retries: int = 3,
    ) -> Any:
        """Blocking chat completion; returns the complete litellm response.

        ``max_retries`` is the total number of attempts on transient errors (at least
        one attempt is always made). The response's ``usage``, if present, is added to
        ``self.token_counter``.
        """
        kwargs = self._build_kwargs(messages, tools, parallel_tool_calls, reasoning_effort)
        response = self._completion_with_retry(kwargs, max_retries)
        self.token_counter.add(getattr(response, "usage", None))
        return response

    def chat_stream(
        self,
        messages: List[dict],
        tools: Optional[list] = None,
        parallel_tool_calls: Optional[bool] = None,
        reasoning_effort: Optional[str] = None,
        max_retries: int = 3,
    ) -> Iterator[Any]:
        """Streaming chat: yield each chunk as it arrives.

        The caller accumulates the chunks and rebuilds the full response with
        ``litellm.stream_chunk_builder``.

        Retry semantics: errors while opening the stream (before any chunk has been
        yielded) are retried with backoff, up to ``max_retries`` total attempts (at
        least one). Once streaming has started, a failure propagates immediately,
        because a half-finished partial response cannot be resumed. Usage is requested
        via ``stream_options.include_usage`` so that the last chunk carries it.

        This is a generator: nothing, including the retries, happens until the caller
        starts iterating.
        """
        kwargs = self._build_kwargs(messages, tools, parallel_tool_calls, reasoning_effort)
        kwargs["stream"] = True
        kwargs["stream_options"] = {"include_usage": True}
        # NOTE(review): unlike chat(), streamed usage is not added to self.token_counter
        # here; presumably the caller does so after stream_chunk_builder — confirm.
        stream = self._completion_with_retry(kwargs, max_retries)

        try:
            for chunk in stream:
                yield chunk
        finally:
            # The caller stopping early (cancel -> break) closes this generator and lands
            # here; close the underlying stream so its httpx connection is released.
            close = getattr(stream, "close", None)
            if callable(close):
                try:
                    close()
                except Exception:
                    # Deliberately best-effort: a failing close must not mask the original
                    # exception or the GeneratorExit that brought us here.
                    pass
|