zcbot/core/llm.py

129 lines
4.8 KiB
Python

"""LiteLLM 封装: capabilities 决定调用参数,自动重试。
`chat()`:同步阻塞,一次性返回完整 response。给 probe / 离线探测用。
`chat_stream()`:流式 generator,yield chunk;调用方累积 + 用 litellm.stream_chunk_builder
拼回完整 response。loop 走这条以便 chunk 之间 poll cancel(同步 LLM call 不可中断;
流式下 cancel 延迟 ~ chunk 间隔 100ms 级,而非整轮 generation 时长几十秒)。
"""
from __future__ import annotations
import os
import time
from typing import Any, Iterator, List, Optional
# 跳过启动时从 GitHub 拉 model_prices 的网络请求,直接用 litellm 打包的本地副本。
# 必须在 `import litellm` 之前设置,否则 get_model_cost_map() 已经跑过了。
os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True")
import litellm # noqa: E402
from litellm.exceptions import (
APIConnectionError,
APIError,
RateLimitError,
ServiceUnavailableError,
Timeout,
)
from .capabilities import ModelCapabilities
class LLM:
    """Thin wrapper around ``litellm.completion`` driven by ``ModelCapabilities``.

    The capabilities object decides which optional request parameters are
    sent (parallel tool calls, reasoning effort, prompt-caching header).
    Transient litellm errors are retried with exponential backoff (1s, 2s, ...).
    """

    def __init__(self, capabilities: ModelCapabilities) -> None:
        """Bind the wrapper to *capabilities* and resolve the API key.

        The key is read from the environment variable named by
        ``capabilities.api_key_env`` (``DEEPSEEK_API_KEY`` when that is unset).

        Raises:
            RuntimeError: if the environment variable is missing or empty.
        """
        self.caps = capabilities
        env_name = capabilities.api_key_env or "DEEPSEEK_API_KEY"
        self.api_key = os.environ.get(env_name)
        # Collapse any falsy api_base (e.g. "") to None so _build_kwargs skips it.
        self.api_base = capabilities.api_base or None
        if not self.api_key:
            raise RuntimeError(
                f"环境变量 {env_name} 未设置,无法调用 {capabilities.model_id}"
            )

    def _build_kwargs(
        self,
        messages: List[dict],
        tools: Optional[list],
        parallel_tool_calls: Optional[bool],
        reasoning_effort: Optional[str],
    ) -> dict:
        """Assemble keyword arguments for ``litellm.completion``.

        Optional parameters are only added when the capabilities flag them
        as supported.

        Args:
            messages: Chat messages, passed through unchanged.
            tools: Tool schemas; left out of the request when empty/None.
            parallel_tool_calls: ``False`` disables parallel tool calls;
                ``None``/``True`` enables them. Only sent when tools are
                present and ``caps.parallel_tools`` is set.
            reasoning_effort: Forwarded only for ``thinking_mode`` models.

        Returns:
            A fresh dict of keyword arguments.
        """
        kwargs: dict = {
            "model": self.caps.model_id,
            "messages": messages,
            "temperature": self.caps.optimal_temperature,
            "api_key": self.api_key,
        }
        if self.api_base:
            kwargs["api_base"] = self.api_base
        if tools:
            kwargs["tools"] = tools
            if self.caps.parallel_tools:
                # Send an explicit False instead of omitting the field: when it
                # is omitted the provider default applies, which for the OpenAI
                # Chat Completions API means parallel tool calls are enabled.
                kwargs["parallel_tool_calls"] = parallel_tool_calls is not False
        if self.caps.thinking_mode and reasoning_effort:
            kwargs["reasoning_effort"] = reasoning_effort
        if self.caps.prompt_caching:
            # Anthropic beta header that opts the request into prompt caching.
            kwargs["extra_headers"] = {"anthropic-beta": "prompt-caching-2024-07-31"}
        return kwargs

    @staticmethod
    def _retryable_errors() -> tuple:
        """Return the litellm exception types treated as transient (retried)."""
        # Built lazily (not a class attribute) so defining the class does not
        # evaluate any litellm names.
        return (RateLimitError, APIConnectionError, ServiceUnavailableError, Timeout, APIError)

    @staticmethod
    def _call_with_retries(call: Any, max_retries: int, retryable: tuple) -> Any:
        """Invoke ``call()`` up to ``max(1, max_retries)`` times and return its result.

        An exception whose type is in *retryable* triggers a backoff sleep of
        ``2 ** attempt`` seconds before the next attempt. The final attempt runs
        outside the ``try``, so its exception (of any type) propagates
        unchanged. ``max_retries <= 0`` still makes exactly one attempt.
        """
        attempts = max(1, max_retries)
        for attempt in range(attempts - 1):
            try:
                return call()
            except retryable:
                time.sleep(2 ** attempt)
        # Last attempt: let any error reach the caller.
        return call()

    @staticmethod
    def _close_quietly(stream: Any) -> None:
        """Call ``stream.close()`` if it exists, ignoring errors raised by it."""
        close = getattr(stream, "close", None)
        if callable(close):
            try:
                close()
            except Exception:
                # Deliberate best-effort: cleanup must not mask the exception
                # (or cancellation) that led here.
                pass

    def chat(
        self,
        messages: List[dict],
        tools: Optional[list] = None,
        parallel_tool_calls: Optional[bool] = None,
        reasoning_effort: Optional[str] = None,
        max_retries: int = 3,
    ) -> Any:
        """Blocking completion call; returns the complete litellm response.

        Rate-limit, connection, timeout, service-unavailable and generic
        ``APIError`` failures are retried, for at most *max_retries* attempts
        in total (minimum one), with exponential backoff. The last error is
        re-raised.
        """
        kwargs = self._build_kwargs(messages, tools, parallel_tool_calls, reasoning_effort)
        return self._call_with_retries(
            lambda: litellm.completion(**kwargs),
            max_retries,
            self._retryable_errors(),
        )

    def chat_stream(
        self,
        messages: List[dict],
        tools: Optional[list] = None,
        parallel_tool_calls: Optional[bool] = None,
        reasoning_effort: Optional[str] = None,
        max_retries: int = 3,
    ) -> Iterator[Any]:
        """Streaming chat: yield each chunk as it arrives.

        Callers accumulate the chunks and rebuild the full response with
        ``litellm.stream_chunk_builder``.

        Retry semantics: any retryable failure before the first chunk is
        obtained is retried with backoff, for at most *max_retries* attempts
        (minimum one). This covers errors raised by ``litellm.completion``
        itself and errors from the first read of the stream; litellm may defer
        the actual request until that first read. Once a chunk has been
        yielded, errors propagate immediately, because a half-finished stream
        cannot be resumed. ``stream_options.include_usage`` asks the provider
        to attach usage to the final chunk.
        """
        kwargs = self._build_kwargs(messages, tools, parallel_tool_calls, reasoning_effort)
        kwargs["stream"] = True
        kwargs["stream_options"] = {"include_usage": True}
        exhausted = object()  # sentinel: stream ended before producing any chunk

        def open_and_prime() -> Any:
            stream = litellm.completion(**kwargs)
            try:
                chunks = iter(stream)
                first = next(chunks, exhausted)
            except BaseException:
                # Release the failed connection before retrying / propagating.
                self._close_quietly(stream)
                raise
            return stream, chunks, first

        stream, chunks, first = self._call_with_retries(
            open_and_prime, max_retries, self._retryable_errors()
        )
        try:
            if first is not exhausted:
                yield first
                # Continue on the same iterator that produced ``first``.
                for chunk in chunks:
                    yield chunk
        finally:
            # The caller stopping early (cancel) closes this generator; close
            # the underlying stream here so its connection is released.
            self._close_quietly(stream)