zcbot/tools/seedream.py

337 lines
15 KiB
Python

"""seedream: 调豆包 Seedream 图像生成 API,产物落 working_dir/figures/。
模型 ID + 单价 + 默认参数全在 `config/media/doubao.yaml`,本 tool 只装配。
完成后:
- 图片落到 `<working_dir>/figures/<YYYYMMDD-HHMMSS>-<rand6>.png`
- 同名 `.meta.json` 写 prompt / model / size / search / cost_cny / response_id / ts
- usage_events 写 kind="image" 一行(单价 snapshot 进 units → 跨调价对账)
"""
from __future__ import annotations
import json
import secrets
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Optional
from uuid import UUID
from core.ark_client import ArkClient, ArkConfig, ArkError
from core.storage.usage import check_daily_quota, record_image_usage
from .base import Tool
from .image_ref import load_image_as_data_url
class SeedreamTool(Tool):
name = "seedream"
description = (
"Generate (text-to-image) OR edit (image-to-image) an image with Doubao Seedream 5.0, "
"saved to working_dir/figures/. Text-to-image: describe what to draw. "
"Image-to-image (改图): pass `reference_images` with an existing image path to modify it "
"at pixel level — use this when the user wants to tweak an already-generated/uploaded image "
"(e.g. '把刚才那张图的天空改成黄昏'), NOT a fresh text-to-image (which would lose the original). "
"Use when the user explicitly asks for / to change an image. "
"Each call costs ¥0.22 (¥0.05 extra if search=true). Don't generate decoratively — "
"only when the user actually wants an image. Returns the saved relative path."
)
parameters = {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "中文或英文都行,详尽描述画面(主体/风格/光线/构图)。改图(reference_images)时只描述「要改成什么」即可。",
},
"reference_images": {
"type": "array",
"items": {"type": "string"},
"description": (
"改图(image-to-image):传 1 张已存在图片的相对路径(task_dir 内,如 "
"'figures/xxx.png',或 seedream 上次返回的 saved 路径)做像素级修改。"
"不传 = 从零文生图。**v1 只支持 1 张**(传多张会报错)。"
"基于刚生成/用户上传的图做局部修改,务必走这里指那张图,不要重新文生图。"
),
},
"size": {
"type": "string",
"description": "Image size like '2048x2048' / '1024x1024' / '3072x3072'. Defaults to config (2048x2048). 改图时建议保持 ≥1920²(ARK i2i 最小输出约束)。",
},
"watermark": {
"type": "boolean",
"description": "是否打豆包水印。默认 false(申报/PPT 场景不需要)。",
},
"search": {
"type": "boolean",
"description": "是否启用联网搜索辅助生成(适合时事/特定品牌等)。默认 false,启用会加价约 ¥0.05/张。",
},
},
"required": ["prompt"],
}
def __init__(
self,
*,
ark_cfg: ArkConfig,
image_variant_cfg: dict,
variant_key: str,
working_dir: Path,
task_id: UUID,
user_id: UUID,
base_dir: Optional[Path] = None,
user_root: Optional[Path] = None,
daily_limit: int = 0,
) -> None:
super().__init__(base_dir, user_root=user_root)
self.ark_cfg = ark_cfg
self.cfg = image_variant_cfg
self.variant_key = variant_key # 'seedream_5' → usage_events.model_profile = "doubao.seedream_5"
self.working_dir = Path(working_dir)
self.task_id = task_id
self.user_id = user_id
self.daily_limit = int(daily_limit) # 0 / 负 = 不限;由 agent_builder 从 yaml quotas 透传
def execute(
self,
prompt: str,
reference_images: Optional[list] = None,
size: Optional[str] = None,
watermark: Optional[bool] = None,
search: Optional[bool] = None,
) -> str:
if not (prompt or "").strip():
return "[Error] prompt 不能为空"
# 改图(i2i)分支:把参考图读成 base64 data URL → ARK body image_urls。
# 不传 / 空 → 走文生图(t2i),与历史行为完全一致(向后兼容)。
refs = [str(r).strip() for r in (reference_images or []) if str(r).strip()]
ref_data_urls: list[str] = []
ref_disp: list[str] = []
if len(refs) > 1:
return (
"[Error] reference_images v1 仅支持单张参考图(传了 "
f"{len(refs)} 张)。多图合成/角色定义留 v2,当前请只传 1 张。"
)
if refs:
data_url, disp, err = load_image_as_data_url(
refs[0],
working_dir=self.working_dir,
user_root=self.user_root,
display_fn=self._display,
)
if err:
return err
ref_data_urls.append(data_url)
ref_disp.append(disp)
is_i2i = bool(ref_data_urls)
# 每账号每日配额(yaml quotas.images_per_day)。失败 retry 不计,因为
# record_image_usage 只在成功+下载完才落库。tool 返串会进 LLM 上下文,
# 模型据此向用户解释,所以**只暴露用户该看的部分**(已用/上限 + 重置时间),
# 内部 yaml 路径不进对话(管理员要改的地方读代码/yaml 自己找)。
if self.daily_limit > 0:
used, over = check_daily_quota(user_id=self.user_id, kind="image", limit=self.daily_limit)
if over:
return (
f"[Error] 已达每日图片生成上限({used}/{self.daily_limit} 张),"
f"次日 00:00 重置。"
)
cfg = self.cfg
model_id = cfg["model_id"]
requested_size = size or cfg.get("default_size", "2048x2048")
# ARK 硬门:输出面积必须落在 [min_pixels, max_pixels],否则 400 InvalidParameter。
# 模型自选 16:9 之类小尺寸(1920x1080=2.07M < 3.69M)会被打回,这里等比钳到合法区间,
# 静默纠错省一轮往返;已合规的尺寸原样透传。归一化时给用户一行提示。
chosen_size, size_note = self._normalize_size(
requested_size,
min_pixels=int(cfg.get("min_pixels", 0)),
max_pixels=int(cfg.get("max_pixels", 0)),
)
chosen_watermark = bool(cfg.get("default_watermark", False)) if watermark is None else bool(watermark)
chosen_search = bool(cfg.get("default_search", False)) if search is None else bool(search)
timeout_s = float(cfg.get("request_timeout_s", 60))
price = float(cfg.get("price_cny_per_image", 0))
body: dict[str, Any] = {
"model": model_id,
"prompt": prompt,
"size": chosen_size,
"response_format": "url",
"watermark": chosen_watermark,
}
if is_i2i:
# ARK /images/generations 接受 base64 data URL 作 image_urls(probe 2026-05-29 实测通)
body["image_urls"] = ref_data_urls
if chosen_search:
# 豆包 search 参数透传(YAML 注释里说明加价 ~¥0.05/张)
body["search"] = True
endpoint = cfg.get("endpoint", "/images/generations")
t0 = time.monotonic()
try:
with ArkClient(self.ark_cfg, timeout_s=timeout_s) as client:
resp = client.post_json(endpoint, body, timeout_s=timeout_s)
image_url, response_id = self._extract_url(resp)
if not image_url:
return f"[Error] seedream response 缺 image url: {json.dumps(resp, ensure_ascii=False)[:300]}"
# 落盘 figures/<ts>-<rand>.png + .meta.json
ts = datetime.now().strftime("%Y%m%d-%H%M%S")
short = secrets.token_hex(3)
figures_dir = self.working_dir / "figures"
dest_png = figures_dir / f"{ts}-{short}.png"
client.download(image_url, dest_png, timeout_s=120.0)
except ArkError as e:
return f"[Error] seedream API: {e}"
elapsed = time.monotonic() - t0
# 估算成本(单价 snapshot 在 record_image_usage 里同步落库)
extra_cny = 0.05 if chosen_search else 0.0 # 搜索加价的粗略值,仅供 user 提示
cost_cny = float(price) + extra_cny
meta = {
"prompt": prompt,
"model_id": model_id,
"size": chosen_size,
"requested_size": requested_size, # 归一化前模型/用户请求的原始尺寸(=chosen_size 表示未钳)
"watermark": chosen_watermark,
"search": chosen_search,
"mode": "i2i" if is_i2i else "t2i",
"reference_images": ref_disp, # 改图时记录参考图(可追溯派生链),t2i 为空
"cost_cny": cost_cny,
"elapsed_s": round(elapsed, 2),
"response_id": response_id,
"ts": datetime.now().isoformat(timespec="seconds"),
}
meta_path = dest_png.with_suffix(".meta.json")
meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
# usage_events 记账;失败不阻塞 tool 返回,但 emit 一条 warn 给 sink 的事走不到这里
# (tool 层没 sink 引用),先 print 兜底;后续可改成 sink 注入。
try:
record_image_usage(
task_id=self.task_id,
user_id=self.user_id,
model_profile=f"doubao.{self.variant_key}",
n_images=1,
size=chosen_size,
price_cny_per_image=float(price),
search=chosen_search,
extra_units={"search_extra_cny": extra_cny} if chosen_search else None,
)
except Exception as e:
print(f"[seedream] record_image_usage failed: {type(e).__name__}: {e}", flush=True)
disp = self._display(dest_png)
# 第一行 banner:前端 SPA 把这行(name===seedream 时)单独提到 details summary
# 旁边显示,用户不展开就能看到 model / size / cost / 耗时 —— 透明性的关键。
# 格式严格 key=value · 分隔,parse 用正则 `key=([^·\n]+)` 抓。
mode_seg = " · mode=i2i" if is_i2i else ""
ref_line = f"\nreference={ref_disp[0]}" if is_i2i else ""
note_line = f"\n{size_note}" if size_note else ""
return (
f"[seedream] model={model_id} · size={chosen_size} · "
f"cost=¥{cost_cny:.2f} · elapsed={elapsed:.1f}s{mode_seg}\n"
f"saved: {disp}{ref_line}\n"
f"prompt={prompt!r}\n"
f"watermark={chosen_watermark} search={chosen_search}{note_line}"
)
@staticmethod
def _normalize_size(
requested: str, *, min_pixels: int = 0, max_pixels: int = 0
) -> tuple[str, str]:
"""把请求尺寸钳进 ARK 面积约束 [min_pixels, max_pixels],保持宽高比。
返回 (chosen_size, note):note 非空表示发生了钳制(用于提示用户 + 记账用真实尺寸)。
- 无法解析成 "WxH" / 任一边 <= 0 → 原样返回,不阻塞(交给 API 自己报错,行为不回退)。
- min/max 传 0 → 视为不设该侧约束(向后兼容:旧 yaml 无这两个键时不改变行为)。
- 面积 < min:按 s=sqrt(min/area) 等比放大,两边向上取整到 8 的倍数,复核达标(不够再 +8)。
- 面积 > max:按 s=sqrt(max/area) 等比缩小,两边向下取整到 8 的倍数,复核达标(超了再 -8)。
- 已在区间内 → 原样透传,note 为空。
"""
raw = (requested or "").strip().lower().replace(" ", "")
parts = raw.split("x")
if len(parts) != 2:
return requested, ""
try:
w, h = int(parts[0]), int(parts[1])
except ValueError:
return requested, ""
if w <= 0 or h <= 0:
return requested, ""
import math
def _round8(v: float, *, up: bool) -> int:
n = math.ceil(v / 8) if up else math.floor(v / 8)
return max(8, n * 8)
area = w * h
if min_pixels > 0 and area < min_pixels:
s = math.sqrt(min_pixels / area)
nw, nh = _round8(w * s, up=True), _round8(h * s, up=True)
# 取整可能把面积压回下限之下,补到达标为止(沿较长边加 8,尽量不破坏比例)
while nw * nh < min_pixels:
if nw >= nh:
nh += 8
else:
nw += 8
chosen = f"{nw}x{nh}"
return chosen, (
f"[note] 请求尺寸 {w}x{h}({area:,}px)低于模型最小面积 {min_pixels:,}px,"
f"已等比放大到 {chosen} 出图。"
)
if max_pixels > 0 and area > max_pixels:
s = math.sqrt(max_pixels / area)
nw, nh = _round8(w * s, up=False), _round8(h * s, up=False)
while nw * nh > max_pixels:
if nw >= nh:
nw -= 8
else:
nh -= 8
chosen = f"{nw}x{nh}"
return chosen, (
f"[note] 请求尺寸 {w}x{h}({area:,}px)超过模型最大面积 {max_pixels:,}px,"
f"已等比缩小到 {chosen} 出图。"
)
return requested, ""
@staticmethod
def _extract_url(resp: dict) -> tuple[str, str]:
"""ark images/generations 响应解析,容忍几种已知 shape:
- OpenAI 兼容: {"data":[{"url":"..."}], "id":"..."}
- 豆包自有: {"data":{"images":[{"url":"..."}]}}
- 兜底: 任意位置出现的第一个 .url 字符串
"""
rid = str(resp.get("id") or resp.get("request_id") or "")
data = resp.get("data")
if isinstance(data, list) and data:
first = data[0]
if isinstance(first, dict):
u = first.get("url") or first.get("image_url")
if isinstance(u, str):
return u, rid
if isinstance(data, dict):
imgs = data.get("images")
if isinstance(imgs, list) and imgs:
u = imgs[0].get("url") if isinstance(imgs[0], dict) else None
if isinstance(u, str):
return u, rid
# 兜底:递归搜
def _find_url(o: Any) -> Optional[str]:
if isinstance(o, dict):
for k, v in o.items():
if k in ("url", "image_url") and isinstance(v, str) and v.startswith("http"):
return v
r = _find_url(v)
if r:
return r
elif isinstance(o, list):
for x in o:
r = _find_url(x)
if r:
return r
return None
return (_find_url(resp) or ""), rid