"""LiteLLM 封装: capabilities 决定调用参数,自动重试。 `chat()`:同步阻塞,一次性返回完整 response。给 probe / 离线探测用。 `chat_stream()`:流式 generator,yield chunk;调用方累积 + 用 litellm.stream_chunk_builder 拼回完整 response。loop 走这条以便 chunk 之间 poll cancel(同步 LLM call 不可中断; 流式下 cancel 延迟 ~ chunk 间隔 100ms 级,而非整轮 generation 时长几十秒)。 """ from __future__ import annotations import os import time from typing import Any, Iterator, List, Optional # 跳过启动时从 GitHub 拉 model_prices 的网络请求,直接用 litellm 打包的本地副本。 # 必须在 `import litellm` 之前设置,否则 get_model_cost_map() 已经跑过了。 os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True") import litellm # noqa: E402 from litellm.exceptions import ( APIConnectionError, APIError, RateLimitError, ServiceUnavailableError, Timeout, ) from .capabilities import ModelCapabilities class TokenCounter: def __init__(self) -> None: self.prompt_tokens = 0 self.completion_tokens = 0 def add(self, usage: Any) -> None: if not usage: return if hasattr(usage, "model_dump"): usage = usage.model_dump() elif hasattr(usage, "dict"): usage = usage.dict() if isinstance(usage, dict): self.prompt_tokens += int(usage.get("prompt_tokens") or 0) self.completion_tokens += int(usage.get("completion_tokens") or 0) @property def total(self) -> int: return self.prompt_tokens + self.completion_tokens class LLM: def __init__(self, capabilities: ModelCapabilities) -> None: self.caps = capabilities env_name = capabilities.api_key_env or "DEEPSEEK_API_KEY" self.api_key = os.environ.get(env_name) self.api_base = capabilities.api_base or None self.token_counter = TokenCounter() if not self.api_key: raise RuntimeError( f"环境变量 {env_name} 未设置,无法调用 {capabilities.model_id}" ) def _build_kwargs( self, messages: List[dict], tools: Optional[list], parallel_tool_calls: Optional[bool], reasoning_effort: Optional[str], ) -> dict: kwargs: dict = { "model": self.caps.model_id, "messages": messages, "temperature": self.caps.optimal_temperature, "api_key": self.api_key, } if self.api_base: kwargs["api_base"] = self.api_base if tools: kwargs["tools"] = tools if self.caps.parallel_tools and parallel_tool_calls is not False: kwargs["parallel_tool_calls"] = True if self.caps.thinking_mode and reasoning_effort: kwargs["reasoning_effort"] = reasoning_effort if self.caps.prompt_caching: kwargs["extra_headers"] = {"anthropic-beta": "prompt-caching-2024-07-31"} return kwargs def chat( self, messages: List[dict], tools: Optional[list] = None, parallel_tool_calls: Optional[bool] = None, reasoning_effort: Optional[str] = None, max_retries: int = 3, ) -> Any: kwargs = self._build_kwargs(messages, tools, parallel_tool_calls, reasoning_effort) last_err: Optional[Exception] = None for attempt in range(max_retries): try: response = litellm.completion(**kwargs) self.token_counter.add(getattr(response, "usage", None)) return response except (RateLimitError, APIConnectionError, ServiceUnavailableError, Timeout, APIError) as e: last_err = e if attempt == max_retries - 1: break time.sleep(2 ** attempt) raise last_err # type: ignore[misc] def chat_stream( self, messages: List[dict], tools: Optional[list] = None, parallel_tool_calls: Optional[bool] = None, reasoning_effort: Optional[str] = None, max_retries: int = 3, ) -> Iterator[Any]: """流式 chat:yield 每个 chunk。调用方累积 + 用 litellm.stream_chunk_builder 拼回完整 response。 重试语义:连接建立阶段错误(还没拿到第一个 chunk)按 max_retries 退避重试; 开始流之后失败直接抛(半截 partial 没法续)。usage 通过 stream_options.include_usage 让最后一个 chunk 带 usage。 """ kwargs = self._build_kwargs(messages, tools, parallel_tool_calls, reasoning_effort) kwargs["stream"] = True kwargs["stream_options"] = {"include_usage": True} last_err: Optional[Exception] = None for attempt in range(max_retries): try: stream = litellm.completion(**kwargs) break except (RateLimitError, APIConnectionError, ServiceUnavailableError, Timeout, APIError) as e: last_err = e if attempt == max_retries - 1: raise time.sleep(2 ** attempt) else: raise last_err # type: ignore[misc] try: for chunk in stream: yield chunk finally: # 调用方提前 break(cancel) → generator close → 这里关掉底层 httpx 连接 close = getattr(stream, "close", None) if callable(close): try: close() except Exception: pass