127 lines
4.6 KiB
Python
127 lines
4.6 KiB
Python
"""火山方舟 (Ark) 通用 HTTP 客户端,供 seedream 以及未来的 seedance 等媒体工具共用。

litellm 不覆盖豆包的图像/视频生成端点,因此这里直接用 httpx 调用:
OpenAI 兼容的 /images/generations 路径,以及异步任务接口 /contents/generations/tasks。
"""
|
|
from __future__ import annotations

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional

import httpx
import yaml

from core.paths import ROOT
|
|
|
|
|
|
# Default location of the Doubao / Ark media config; ArkConfig.load() reads this when no explicit path is passed.
_DOUBAO_YAML = ROOT / "config" / "media" / "doubao.yaml"
|
|
|
|
|
|
class ArkError(RuntimeError):
    """Single exception type for every failed Ark interaction.

    ArkClient raises it for timeouts, transport errors, HTTP error statuses,
    undecodable response bodies and failed artifact downloads, so callers
    only ever need to catch this one type.
    """
|
|
|
|
|
|
@dataclass
class ArkConfig:
    """Ark connection settings resolved from doubao.yaml plus the environment."""

    # Bearer token. Excluded from repr() so printing/logging a config never leaks it.
    api_key: str = field(repr=False)
    # API root with any trailing "/" stripped.
    base_url: str
    # Full parsed yaml content, so callers can pick their own image/video sub-keys.
    raw: dict

    @classmethod
    def load(cls, path: Optional[Path] = None) -> Optional["ArkConfig"]:
        """Read doubao.yaml and resolve the API key from the environment.

        Args:
            path: yaml file to read; defaults to the module-level ``_DOUBAO_YAML``.

        Returns:
            An ``ArkConfig``; or ``None`` when the yaml file does not exist, or
            when the env var it names (``ark_api_key_env``, default
            ``ARK_API_KEY``) is unset or blank. Callers use ``None`` to decide
            whether to register the tool at all, so users without a key never
            notice anything.

        Raises:
            ValueError: the yaml's top level is not a mapping.
            Errors from reading or parsing the file propagate unchanged.
        """
        p = path if path is not None else _DOUBAO_YAML
        if not p.exists():
            return None

        # An empty file parses to None; treat it as "no settings, use defaults".
        data = yaml.safe_load(p.read_text(encoding="utf-8")) or {}
        if not isinstance(data, dict):
            raise ValueError(
                f"{p}: expected a mapping at the top level, got {type(data).__name__}"
            )

        # str(): os.environ lookups require a str key even if the yaml holds a number.
        env = str(data.get("ark_api_key_env") or "ARK_API_KEY")
        key = os.environ.get(env, "").strip()
        if not key:
            return None

        base_url = str(data.get("ark_base_url") or "https://ark.cn-beijing.volces.com/api/v3")
        return cls(api_key=key, base_url=base_url.rstrip("/"), raw=data)
|
|
|
|
|
|
class ArkClient:
    """Thin httpx wrapper: shared base_url, bearer auth and error translation.

    ``post_json`` / ``get_json`` return the decoded JSON body of a 2xx
    response; non-2xx statuses, network errors, timeouts and undecodable
    bodies all raise ``ArkError``. Use it as a context manager (or call
    ``close()``) so the underlying ``httpx.Client`` is released.
    """

    def __init__(self, cfg: "ArkConfig", timeout_s: float = 60.0) -> None:
        self.cfg = cfg
        # Default per-request timeout; post_json/get_json may override it per call.
        self.timeout_s = timeout_s
        self._client = httpx.Client(
            base_url=cfg.base_url,
            headers={
                "Authorization": f"Bearer {cfg.api_key}",
                "Content-Type": "application/json",
            },
            timeout=timeout_s,
        )

    def post_json(self, path: str, body: dict, *, timeout_s: Optional[float] = None) -> dict:
        """POST ``body`` as JSON to ``path`` (relative to base_url); return the decoded reply.

        Raises:
            ArkError: on timeout, transport error, non-2xx status or non-JSON body.
        """
        return self._request("POST", path, timeout_s, json=body)

    def get_json(self, path: str, *, timeout_s: Optional[float] = None) -> dict:
        """GET ``path`` (relative to base_url); return the decoded JSON reply.

        Raises:
            ArkError: on timeout, transport error, non-2xx status or non-JSON body.
        """
        return self._request("GET", path, timeout_s)

    def _request(self, method: str, path: str, timeout_s: Optional[float], **kwargs: Any) -> dict:
        """Send one request through the shared client and parse the reply."""
        label = f"{method} {path}"
        # Fall back to the client-wide default only when the caller passed nothing.
        timeout = self.timeout_s if timeout_s is None else timeout_s
        try:
            resp = self._client.request(method, path, timeout=timeout, **kwargs)
        except httpx.TimeoutException as e:
            # TimeoutException subclasses HTTPError, so it must be caught first.
            raise ArkError(f"timeout calling {label}: {e}") from e
        except httpx.HTTPError as e:
            raise ArkError(f"network error calling {label}: {e}") from e
        return self._parse(resp, label)

    @staticmethod
    def _parse(resp: "httpx.Response", label: str) -> dict:
        """Return the decoded JSON of a 2xx response; raise ArkError otherwise."""
        if not resp.is_success:
            # Error bodies are not guaranteed to be JSON at all.
            try:
                payload = resp.json()
            except ValueError:
                payload = None
            # Raised outside the except clause so no unrelated decode error gets chained.
            raise ArkError(
                ArkClient._format_http_error(label, resp.status_code, payload, resp.text)
            )
        try:
            return resp.json()
        except ValueError as e:
            raise ArkError(f"{label} → invalid JSON response: {e}") from e

    @staticmethod
    def _format_http_error(label: str, status_code: int, payload: Any, text: str) -> str:
        """Build the ArkError message for a non-2xx response.

        ``payload`` is the decoded JSON body, or ``None`` if it was not JSON.
        Ark error bodies are typically ``{"error": {"code": ..., "message": ...}}``;
        any other shape falls back to the first 300 characters of ``text``.
        """
        snippet = text[:300]
        if not isinstance(payload, dict):
            return f"{label} → HTTP {status_code}: {snippet}"
        err = payload.get("error") or {}
        if isinstance(err, str):
            # Tolerate the flat {"error": "message"} variant.
            err = {"message": err}
        elif not isinstance(err, dict):
            err = {}
        msg = err.get("message") or snippet
        code = err.get("code") or status_code
        return f"{label} → HTTP {status_code} ({code}): {msg}"

    def download(self, url: str, dest: Path, *, timeout_s: float = 120.0) -> None:
        """Stream a generated artifact (image/video URL) to ``dest``.

        The URL points at Volcengine's CDN rather than the Ark API, so this
        deliberately uses a bare ``httpx.stream`` instead of ``self._client``
        (no base_url, no Ark bearer token). Redirects are followed, and only
        a 2xx final response is accepted. Bytes go to a sibling ``.part``
        file that replaces ``dest`` only once the whole body has arrived, so
        a failed download never leaves a truncated ``dest`` behind.

        Raises:
            ArkError: non-2xx final status or any httpx transport error.
        """
        tmp = dest.with_name(dest.name + ".part")
        try:
            with httpx.stream("GET", url, timeout=timeout_s, follow_redirects=True) as r:
                if not r.is_success:
                    raise ArkError(f"download {url} → HTTP {r.status_code}")
                dest.parent.mkdir(parents=True, exist_ok=True)
                with open(tmp, "wb") as f:
                    for chunk in r.iter_bytes(chunk_size=64 * 1024):
                        f.write(chunk)
            tmp.replace(dest)
        except httpx.HTTPError as e:
            raise ArkError(f"download {url} failed: {e}") from e
        finally:
            # After success tmp has been renamed away (no-op); on any failure drop the partial file.
            tmp.unlink(missing_ok=True)

    def close(self) -> None:
        """Release the pooled connections held by the underlying httpx.Client."""
        self._client.close()

    def __enter__(self) -> "ArkClient":
        return self

    def __exit__(self, *_exc: Any) -> None:
        self.close()
|