221 lines
9.2 KiB
Python
221 lines
9.2 KiB
Python
"""seedream: 调豆包 Seedream 图像生成 API,产物落 working_dir/figures/。
|
|
|
|
模型 ID + 单价 + 默认参数全在 `config/media/doubao.yaml`,本 tool 只装配。
|
|
完成后:
|
|
- 图片落到 `<working_dir>/figures/<YYYYMMDD-HHMMSS>-<rand6>.png`
|
|
- 同名 `.meta.json` 写 prompt / model_id / size / watermark / search / cost_cny / elapsed_s / response_id / ts
|
|
- usage_events 写 kind="image" 一行(单价 snapshot 进 units → 跨调价对账)
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import secrets
|
|
import time
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any, Optional
|
|
from uuid import UUID
|
|
|
|
from core.ark_client import ArkClient, ArkConfig, ArkError
|
|
from core.storage.usage import check_daily_quota, record_image_usage
|
|
|
|
from .base import Tool
|
|
|
|
|
|
class SeedreamTool(Tool):
    """Generate one image with Doubao Seedream and save it under ``<working_dir>/figures/``.

    Model-specific settings are read from ``image_variant_cfg``; this class only
    assembles the request. Keys read: ``model_id`` (required), ``default_size``,
    ``default_watermark``, ``default_search``, ``request_timeout_s``,
    ``price_cny_per_image`` and ``endpoint``.

    On success a call:

    - downloads the image to ``figures/<YYYYMMDD-HHMMSS>-<rand6>.png``;
    - records one usage row via ``record_image_usage`` (best-effort);
    - writes a sibling ``.meta.json`` with the request parameters (best-effort);
    - returns a multi-line summary whose first line is a machine-parsed banner.

    Expected failures are returned as ``"[Error] ..."`` strings rather than
    raised, because the tool's return value is fed back into the LLM context.
    These are: empty prompt, daily quota exhausted, figures dir not creatable,
    ``ArkError`` from the request or download, and no image URL in the response.
    """

    name = "seedream"
    description = (
        "Generate an image with Doubao Seedream 5.0 and save to working_dir/figures/. "
        "Use when the user explicitly asks for an image / illustration / cover. "
        "Each call costs ¥0.22 (¥0.05 extra if search=true). Don't generate decoratively — "
        "only when the user actually wants an image. Returns the saved relative path."
    )
    parameters = {
        "type": "object",
        "properties": {
            "prompt": {
                "type": "string",
                "description": "中文或英文都行,详尽描述画面(主体/风格/光线/构图)。直接传用户意图即可,模型自己理解。",
            },
            "size": {
                "type": "string",
                "description": "Image size like '2048x2048' / '1024x1024' / '3072x3072'. Defaults to config (2048x2048).",
            },
            "watermark": {
                "type": "boolean",
                "description": "是否打豆包水印。默认 false(申报/PPT 场景不需要)。",
            },
            "search": {
                "type": "boolean",
                "description": "是否启用联网搜索辅助生成(适合时事/特定品牌等)。默认 false,启用会加价约 ¥0.05/张。",
            },
        },
        "required": ["prompt"],
    }

    # Rough per-image surcharge (CNY) when search is enabled. Only used for the
    # user-facing cost estimate and the ``search_extra_cny`` usage unit; the
    # base unit price comes from ``price_cny_per_image`` in the config.
    _SEARCH_EXTRA_CNY = 0.05

    def __init__(
        self,
        *,
        ark_cfg: ArkConfig,
        image_variant_cfg: dict,
        variant_key: str,
        working_dir: Path,
        task_id: UUID,
        user_id: UUID,
        base_dir: Optional[Path] = None,
        user_root: Optional[Path] = None,
        daily_limit: int = 0,
    ) -> None:
        """Store the per-task configuration.

        Args:
            ark_cfg: Connection settings handed to ``ArkClient``.
            image_variant_cfg: Image-variant config dict (keys listed on the class).
            variant_key: Variant name such as ``'seedream_5'``; usage rows are
                recorded with ``model_profile = "doubao.<variant_key>"``.
            working_dir: Task working directory; images go to ``working_dir/figures``.
            task_id: Task the usage row is attributed to.
            user_id: User the usage row and the daily quota are attributed to.
            base_dir: Forwarded to ``Tool.__init__``.
            user_root: Forwarded to ``Tool.__init__``.
            daily_limit: Images allowed per user per day; ``0`` or negative means
                unlimited. Passed through from the yaml quotas by agent_builder.
        """
        super().__init__(base_dir, user_root=user_root)
        self.ark_cfg = ark_cfg
        self.cfg = image_variant_cfg
        self.variant_key = variant_key
        self.working_dir = Path(working_dir)
        self.task_id = task_id
        self.user_id = user_id
        self.daily_limit = int(daily_limit)

    def execute(
        self,
        prompt: str,
        size: Optional[str] = None,
        watermark: Optional[bool] = None,
        search: Optional[bool] = None,
    ) -> str:
        """Generate one image for *prompt* and save it under ``working_dir/figures``.

        Args:
            prompt: Free-form image description; must be a non-blank string.
            size: ``"WxH"`` string; falls back to ``cfg["default_size"]``, then ``"2048x2048"``.
            watermark: Add the Doubao watermark; ``None`` uses ``cfg["default_watermark"]`` (False).
            search: Enable search-assisted generation; ``None`` uses ``cfg["default_search"]`` (False).

        Returns:
            A summary string on success, or an ``"[Error] ..."`` string on an
            expected failure. Exceptions other than ``ArkError`` from the client
            are not caught here.
        """
        # Guard against non-string input too: the LLM does not always honour the schema.
        if not isinstance(prompt, str) or not prompt.strip():
            return "[Error] prompt 不能为空"

        quota_error = self._check_quota()
        if quota_error is not None:
            return quota_error

        cfg = self.cfg
        model_id = cfg["model_id"]
        chosen_size = size or cfg.get("default_size", "2048x2048")
        chosen_watermark = bool(cfg.get("default_watermark", False)) if watermark is None else bool(watermark)
        chosen_search = bool(cfg.get("default_search", False)) if search is None else bool(search)
        timeout_s = float(cfg.get("request_timeout_s", 60))
        price = float(cfg.get("price_cny_per_image", 0))
        endpoint = cfg.get("endpoint", "/images/generations")

        body: dict[str, Any] = {
            "model": model_id,
            "prompt": prompt,
            "size": chosen_size,
            "response_format": "url",
            "watermark": chosen_watermark,
        }
        if chosen_search:
            # Doubao's search flag is passed through as-is (it carries a per-image surcharge).
            body["search"] = True

        # Create the output directory *before* calling the API, so a filesystem
        # problem fails fast instead of after the image has already been paid for.
        figures_dir = self.working_dir / "figures"
        try:
            figures_dir.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            return f"[Error] seedream cannot create figures dir: {e}"

        t0 = time.monotonic()
        try:
            with ArkClient(self.ark_cfg, timeout_s=timeout_s) as client:
                resp = client.post_json(endpoint, body, timeout_s=timeout_s)
                image_url, response_id = self._extract_url(resp)
                if not image_url:
                    return f"[Error] seedream response 缺 image url: {json.dumps(resp, ensure_ascii=False)[:300]}"

                # File name: figures/<YYYYMMDD-HHMMSS>-<6 hex chars>.png
                # NOTE(review): the extension is always .png regardless of the
                # format the returned URL actually serves — confirm.
                ts = datetime.now().strftime("%Y%m%d-%H%M%S")
                dest_png = figures_dir / f"{ts}-{secrets.token_hex(3)}.png"
                client.download(image_url, dest_png, timeout_s=120.0)
        except ArkError as e:
            return f"[Error] seedream API: {e}"
        # Wall-clock time covering both the generation request and the download.
        elapsed = time.monotonic() - t0

        # Cost estimate shown to the user; the unit price snapshot itself is
        # stored by record_image_usage.
        extra_cny = self._SEARCH_EXTRA_CNY if chosen_search else 0.0
        cost_cny = price + extra_cny

        # Billing first: the image has been generated and paid for, so a failure
        # writing the sidecar file must not prevent the usage row from landing.
        self._record_usage(size=chosen_size, price=price, search=chosen_search, extra_cny=extra_cny)

        self._write_meta(
            dest_png.with_suffix(".meta.json"),
            {
                "prompt": prompt,
                "model_id": model_id,
                "size": chosen_size,
                "watermark": chosen_watermark,
                "search": chosen_search,
                "cost_cny": cost_cny,
                "elapsed_s": round(elapsed, 2),
                "response_id": response_id,
                "ts": datetime.now().isoformat(timespec="seconds"),
            },
        )

        disp = self._display(dest_png)
        # The first line is a banner. When name == "seedream", the frontend SPA
        # lifts it next to the <details> summary, so users see model / size /
        # cost / elapsed without expanding. Keep the format strictly
        # `key=value` separated by " · "; the frontend parses it with the regex
        # `key=([^·\n]+)`.
        return (
            f"[seedream] model={model_id} · size={chosen_size} · "
            f"cost=¥{cost_cny:.2f} · elapsed={elapsed:.1f}s\n"
            f"saved: {disp}\n"
            f"prompt={prompt!r}\n"
            f"watermark={chosen_watermark} search={chosen_search}"
        )

    def _check_quota(self) -> Optional[str]:
        """Return a user-facing error if today's image quota is used up, else ``None``.

        Failed calls do not count, because ``record_image_usage`` only runs after
        a successful generation and download. The message goes into the LLM
        context and is relayed to the user. It therefore exposes only used/limit
        and the reset time, and no internal config paths.
        """
        if self.daily_limit <= 0:
            return None
        used, over = check_daily_quota(user_id=self.user_id, kind="image", limit=self.daily_limit)
        if not over:
            return None
        return (
            f"[Error] 已达每日图片生成上限({used}/{self.daily_limit} 张),"
            "次日 00:00 重置。"
        )

    def _record_usage(self, *, size: str, price: float, search: bool, extra_cny: float) -> None:
        """Write one usage row for a generated image; best-effort.

        A failure is printed rather than raised, so it never turns a delivered
        image into a tool error. The tool layer has no sink reference to emit a
        proper warning; injecting one would be a later improvement.
        """
        try:
            record_image_usage(
                task_id=self.task_id,
                user_id=self.user_id,
                model_profile=f"doubao.{self.variant_key}",
                n_images=1,
                size=size,
                price_cny_per_image=price,
                search=search,
                extra_units={"search_extra_cny": extra_cny} if search else None,
            )
        except Exception as e:
            print(f"[seedream] record_image_usage failed: {type(e).__name__}: {e}", flush=True)

    @staticmethod
    def _write_meta(meta_path: Path, meta: dict[str, Any]) -> None:
        """Write the sidecar ``.meta.json`` as pretty UTF-8 JSON; best-effort.

        An ``OSError`` is printed rather than raised: the image already exists
        and has been billed, so the call should still report success.
        """
        try:
            meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
        except OSError as e:
            print(f"[seedream] write meta failed: {type(e).__name__}: {e}", flush=True)

    @staticmethod
    def _extract_url(resp: Any) -> tuple[str, str]:
        """Extract ``(image_url, response_id)`` from an Ark images/generations response.

        Known shapes:

        - OpenAI-compatible: ``{"data": [{"url": "..."}], "id": "..."}``
        - Doubao-native:     ``{"data": {"images": [{"url": "..."}]}}``
        - Fallback: the first ``url`` / ``image_url`` string starting with
          ``http`` found anywhere in the nested response.

        Either element is ``""`` when not found. The response id is read only
        from the top-level ``id`` or ``request_id`` of a dict response.
        """

        def _url_of(item: Any) -> Optional[str]:
            # Preferred URL of one image entry: ``url``, else ``image_url``, if it is a string.
            if not isinstance(item, dict):
                return None
            u = item.get("url") or item.get("image_url")
            return u if isinstance(u, str) else None

        def _find_url(o: Any) -> Optional[str]:
            # Depth-first search for the first http(s) URL under a url-like key.
            if isinstance(o, dict):
                for k, v in o.items():
                    if k in ("url", "image_url") and isinstance(v, str) and v.startswith("http"):
                        return v
                    r = _find_url(v)
                    if r:
                        return r
            elif isinstance(o, list):
                for x in o:
                    r = _find_url(x)
                    if r:
                        return r
            return None

        rid = ""
        data: Any = None
        if isinstance(resp, dict):
            rid = str(resp.get("id") or resp.get("request_id") or "")
            data = resp.get("data")

        if isinstance(data, list) and data:
            u = _url_of(data[0])
            if u:
                return u, rid
        if isinstance(data, dict):
            imgs = data.get("images")
            if isinstance(imgs, list) and imgs:
                u = _url_of(imgs[0])
                if u:
                    return u, rid

        return (_find_url(resp) or ""), rid
|