zcbot/tools/seedance.py

343 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""seedance: 调豆包 Seedance 2.0 Fast 视频生成 API,产物落 working_dir/videos/。
异步任务:
1. POST /contents/generations/tasks → 返 `{"id": "cgt-..."}`
2. 轮询 GET /contents/generations/tasks/<cgt_id>(默认 5s 间隔,最长 10min)
直到 status ∈ {succeeded, failed, expired, cancelled}
3. succeeded → 取 content.video_url → download 到本地 + 写 meta + 计 usage_events
模型 ID + 单价 + 默认参数全在 `config/media/doubao.yaml`,本 tool 只装配。
计费按 token 公式 `(in_dur+out_dur) × W × H × fps / 1024`,文生视频 in_dur=0;
W×H 由 resolution + ratio 推算(横版 height=resolution_num,竖版 width=resolution_num)。
完成后:
- 视频落到 `<wd>/videos/<YYYYMMDD-HHMMSS>-<rand6>.mp4`
- 同名 `.meta.json` 写 prompt / model / 参数 / cost_cny / tokens / cgt_id / ts
- usage_events 写 kind="video" 一行(单价 + 分辨率 + 时长 snapshot 进 units)
"""
from __future__ import annotations
import json
import secrets
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Optional
from uuid import UUID
from core.ark_client import ArkClient, ArkConfig, ArkError
from core.storage.usage import record_video_usage
from .base import Tool
# Resolution label -> pixel length of the SHORT edge. The actual width/height
# follow from the aspect ratio: for landscape the short edge is the height,
# for portrait it is the width.
# NOTE: "1080p" is kept in the table even though the fast variant does not
# support it (it is meant for pro); this layer does not block it and lets the
# Doubao side reject the request so the tool surfaces an [Error].
_RESOLUTION_TO_SHORT_EDGE: dict[str, int] = {
    f"{short_edge}p": short_edge for short_edge in (480, 720, 1080)
}

# Aspect-ratio strings accepted by the tool; "adaptive" has no fixed shape.
_VALID_RATIOS: set[str] = {
    "16:9",
    "9:16",
    "1:1",
    "4:3",
    "3:4",
    "21:9",
    "adaptive",
}
def _resolve_dimensions(resolution: str, ratio: str) -> tuple[int, int]:
    """Convert a resolution label and an aspect ratio into ``(width, height)`` pixels.

    The resolution label selects the short edge via ``_RESOLUTION_TO_SHORT_EDGE``;
    an unknown label falls back to 720. For ``"adaptive"`` (or any ratio not in
    ``_VALID_RATIOS``) the shape cannot be derived, so a square of the short
    edge is returned -- this is only meant for token estimation.
    """
    short_edge = _RESOLUTION_TO_SHORT_EDGE.get(resolution, 720)

    has_fixed_shape = ratio in _VALID_RATIOS and ratio != "adaptive"
    if not has_fixed_shape:
        return short_edge, short_edge

    w_part, h_part = (int(piece) for piece in ratio.split(":", 1))
    if w_part < h_part:
        # Portrait: the width is the short edge, the height is scaled up.
        return short_edge, round(short_edge * h_part / w_part)
    # Landscape or square: the height is the short edge, the width is scaled up.
    return round(short_edge * w_part / h_part), short_edge
def _estimate_tokens(width: int, height: int, duration_s: int, fps: int, in_dur_s: int = 0) -> int:
    """Estimate billed tokens as ``(in_dur + out_dur) * W * H * fps / 1024``.

    Worked example from the original notes (fast, 720p 16:9, 5 s, 24 fps,
    text-to-video): ``(0 + 5) * 1280 * 720 * 24 / 1024 = 108,000`` tokens,
    which at 37 CNY per million tokens is about 4.00 CNY.

    The quotient is truncated toward zero.
    """
    total_seconds = in_dur_s + duration_s
    raw_tokens = total_seconds * width * height * fps / 1024
    return int(raw_tokens)
class SeedanceTool(Tool):
    """Tool that renders a short text-to-video clip via the Doubao Seedance task API.

    ``execute`` submits a generation task, polls it until it reaches a terminal
    status, downloads the resulting video into ``<working_dir>/videos/``, writes
    a ``.meta.json`` sidecar next to it and records one usage row through
    ``record_video_usage``. Model id, price and defaults are read from the
    variant config dict passed to the constructor.
    """

    name = "seedance"
    description = (
        "Generate a short video with Doubao Seedance 2.0 Fast and save to working_dir/videos/. "
        "Use only when the user explicitly asks for a video / 视频 / 动画 / 动起来. "
        "Async: takes 30-90s to render. Costs ~¥1.86 (480p, 5s) ~ ¥4.00 (720p, 5s); "
        "longer / higher-resolution scales up. Returns the saved relative path."
    )
    parameters = {
        "type": "object",
        "properties": {
            "prompt": {
                "type": "string",
                "description": "中文或英文都行,详尽描述画面 + 运动 + 镜头(主体在做什么 / 镜头怎么动 / 场景 / 风格)。",
            },
            "resolution": {
                "type": "string",
                "description": "Video resolution. fast 版仅支持 '480p' / '720p'(默认 720p);1080p+ 仅 pro 可用。",
                "enum": ["480p", "720p", "1080p"],
            },
            "ratio": {
                "type": "string",
                "description": "Aspect ratio. 默认 '16:9'(ppt/横屏配视频)。竖版海报/短视频用 '9:16'",
                "enum": ["16:9", "9:16", "1:1", "4:3", "3:4", "21:9", "adaptive"],
            },
            "duration": {
                "type": "integer",
                "description": "视频时长(秒),范围 4-15。短=便宜,默认 5。每加 1s 成本约线性上升。",
                "minimum": 4,
                "maximum": 15,
            },
            "watermark": {
                "type": "boolean",
                "description": "是否打豆包水印。默认 false(申报/PPT 场景反需求)。",
            },
        },
        "required": ["prompt"],
    }

    def __init__(
        self,
        *,
        ark_cfg: ArkConfig,
        video_variant_cfg: dict,
        variant_key: str,
        working_dir: Path,
        task_id: UUID,
        user_id: UUID,
        base_dir: Optional[Path] = None,
        user_root: Optional[Path] = None,
        cancel_check: Optional[Callable[[], bool]] = None,
    ) -> None:
        """Store the API config, variant config and per-task identifiers.

        ``cancel_check`` (optional) is called at the top of every poll iteration;
        a truthy result makes ``execute`` stop polling and return a
        ``[Cancelled]`` message.
        """
        super().__init__(base_dir, user_root=user_root)
        self.ark_cfg = ark_cfg
        self.cfg = video_variant_cfg
        # e.g. 'seedance_2_fast' -> recorded as model_profile "doubao.seedance_2_fast"
        self.variant_key = variant_key
        self.working_dir = Path(working_dir)
        self.task_id = task_id
        self.user_id = user_id
        self.cancel_check = cancel_check

    def execute(
        self,
        prompt: str,
        resolution: Optional[str] = None,
        ratio: Optional[str] = None,
        duration: Optional[int] = None,
        watermark: Optional[bool] = None,
    ) -> str:
        """Generate one video from ``prompt`` and return a textual result.

        Arguments left as ``None`` fall back to the variant config defaults
        (``default_resolution``, ``default_ratio``, ``default_duration``,
        ``default_watermark``).

        Returns:
            On success, a multi-line string whose first line is a
            ``[seedance] key=value · ...`` banner followed by the saved path,
            the prompt and the task id. On failure, a string starting with
            ``[Error]``; on user cancellation, one starting with
            ``[Cancelled]``. ``ArkError`` raised by the client is converted
            into an ``[Error]`` string rather than propagated.
        """
        if not (prompt or "").strip():
            return "[Error] prompt 不能为空"

        cfg = self.cfg
        model_id = cfg["model_id"]
        chosen_resolution = resolution or cfg.get("default_resolution", "720p")
        chosen_ratio = ratio or cfg.get("default_ratio", "16:9")
        chosen_duration = int(duration) if duration is not None else int(cfg.get("default_duration", 5))
        chosen_watermark = bool(cfg.get("default_watermark", False)) if watermark is None else bool(watermark)
        fps = int(cfg.get("fps", 24))
        submit_timeout = float(cfg.get("request_timeout_s", 60))
        poll_interval = float(cfg.get("poll_interval_s", 5))
        poll_timeout = float(cfg.get("poll_timeout_s", 600))
        price_t2v = float(cfg.get("price_cny_per_mtoken_text2video", 37.0))
        submit_endpoint = cfg.get("endpoint_submit", "/contents/generations/tasks")
        poll_endpoint_base = cfg.get("endpoint_poll", "/contents/generations/tasks")

        # Make sure the output directory exists BEFORE submitting: if the
        # location is not writable we want to fail here, not after a render
        # has already been started remotely.
        videos_dir = self.working_dir / "videos"
        try:
            videos_dir.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            return f"[Error] seedance 无法创建输出目录 {videos_dir}: {e}"

        body: dict[str, Any] = {
            "model": model_id,
            "content": [{"type": "text", "text": prompt}],
            "ratio": chosen_ratio,
            "resolution": chosen_resolution,
            "duration": chosen_duration,
            "watermark": chosen_watermark,
        }

        t0 = time.monotonic()
        try:
            with ArkClient(self.ark_cfg, timeout_s=submit_timeout) as client:
                # 1. submit
                submit_resp = client.post_json(submit_endpoint, body, timeout_s=submit_timeout)
                cgt_id = self._extract_task_id(submit_resp)
                if not cgt_id:
                    return f"[Error] seedance submit 响应缺 task id: {json.dumps(submit_resp, ensure_ascii=False)[:300]}"

                # 2. poll until a terminal status, cancellation or timeout
                poll_url = f"{poll_endpoint_base}/{cgt_id}"
                deadline = time.monotonic() + poll_timeout
                last_status = ""
                final_resp: dict = {}
                while True:
                    if self.cancel_check is not None and self.cancel_check():
                        rough_cost = self._rough_cost(
                            chosen_resolution, chosen_ratio, chosen_duration, fps, price_t2v
                        )
                        return (
                            f"[Cancelled] seedance task {cgt_id} 用户取消(远端任务可能仍在跑;"
                            f"Volcengine 失败/成功才计费,若仍出片可能产生 ~¥{rough_cost:.2f})"
                        )
                    if time.monotonic() > deadline:
                        return (
                            f"[Error] seedance 轮询超时(>{poll_timeout:.0f}s),最后 status={last_status!r},"
                            f"cgt_id={cgt_id}(24h 内可手工 GET {poll_url} 拉结果)"
                        )
                    time.sleep(poll_interval)
                    poll_resp = client.get_json(poll_url, timeout_s=submit_timeout)
                    last_status = str(poll_resp.get("status") or "").lower()
                    final_resp = poll_resp
                    if last_status in ("succeeded", "failed", "expired", "cancelled"):
                        break

                if last_status != "succeeded":
                    raw_err = final_resp.get("error")
                    err = raw_err if isinstance(raw_err, dict) else {}
                    msg = err.get("message") or final_resp.get("message") or "(无错误描述)"
                    return f"[Error] seedance task {cgt_id} 终态 status={last_status},msg={msg}"

                video_url = self._extract_video_url(final_resp)
                if not video_url:
                    return f"[Error] seedance succeeded 但响应缺 video url: {json.dumps(final_resp, ensure_ascii=False)[:400]}"

                # 3. download to <working_dir>/videos/<timestamp>-<6 hex chars>.mp4
                ts = datetime.now().strftime("%Y%m%d-%H%M%S")
                short = secrets.token_hex(3)
                dest_mp4 = videos_dir / f"{ts}-{short}.mp4"
                client.download(video_url, dest_mp4, timeout_s=300.0)
        except ArkError as e:
            return f"[Error] seedance API: {e}"
        elapsed = time.monotonic() - t0

        # Tokens / cost: prefer the usage figure from the response when present,
        # otherwise fall back to the formula-based estimate.
        width, height = _resolve_dimensions(chosen_resolution, chosen_ratio)
        tokens_estimated = _estimate_tokens(width, height, chosen_duration, fps, in_dur_s=0)
        tokens_actual = self._extract_tokens(final_resp) or tokens_estimated
        cost_cny = tokens_actual * price_t2v / 1_000_000.0

        meta = {
            "prompt": prompt,
            "model_id": model_id,
            "resolution": chosen_resolution,
            "ratio": chosen_ratio,
            "width": width,
            "height": height,
            "duration_s": chosen_duration,
            "fps": fps,
            "watermark": chosen_watermark,
            "tokens": tokens_actual,
            "tokens_estimated": tokens_estimated,
            "price_cny_per_mtoken": price_t2v,
            "cost_cny": round(cost_cny, 4),
            "elapsed_s": round(elapsed, 1),
            "cgt_id": cgt_id,
            "ts": datetime.now().isoformat(timespec="seconds"),
        }
        meta_path = dest_mp4.with_suffix(".meta.json")
        # The sidecar is best-effort: the video is already on disk, so a failed
        # metadata write must not prevent the usage row below from being recorded.
        try:
            meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
        except OSError as e:
            print(f"[seedance] meta write failed: {type(e).__name__}: {e}", flush=True)

        try:
            record_video_usage(
                task_id=self.task_id,
                user_id=self.user_id,
                model_profile=f"doubao.{self.variant_key}",
                resolution=chosen_resolution,
                ratio=chosen_ratio,
                duration_s=chosen_duration,
                fps=fps,
                width=width,
                height=height,
                tokens=tokens_actual,
                price_cny_per_mtoken=price_t2v,
                has_video_input=False,  # phase 1 is text-to-video only; revisit once image/video input is wired in
                watermark=chosen_watermark,
                extra_units={"cgt_id": cgt_id, "elapsed_s": round(elapsed, 1)},
            )
        except Exception as e:
            # Usage accounting is best-effort; the generated video is still returned.
            print(f"[seedance] record_video_usage failed: {type(e).__name__}: {e}", flush=True)

        disp = self._display(dest_mp4)
        # Banner convention (per the original notes, shared with seedream): the
        # first line is `[tool] key=value · key=value ...`, which the frontend
        # parses into badges -- keep its format stable.
        return (
            f"[seedance] model={model_id} · resolution={chosen_resolution} · ratio={chosen_ratio} · "
            f"duration={chosen_duration}s · cost=¥{cost_cny:.2f} · elapsed={elapsed:.1f}s\n"
            f"saved: {disp}\n"
            f"prompt={prompt!r}\n"
            f"watermark={chosen_watermark} cgt_id={cgt_id}"
        )

    @staticmethod
    def _rough_cost(resolution: str, ratio: str, duration_s: int, fps: int, price_per_mtoken: float) -> float:
        """Return the formula-based cost estimate (CNY) for a text-to-video render."""
        w, h = _resolve_dimensions(resolution, ratio)
        tokens = _estimate_tokens(w, h, duration_s, fps)
        return tokens * price_per_mtoken / 1_000_000.0

    @staticmethod
    def _extract_task_id(resp: dict) -> str:
        """Pull the task id out of a submit response, tolerating a few shapes.

        Looks for a non-empty string under ``id`` / ``task_id`` / ``request_id``
        at the top level, then under ``id`` / ``task_id`` inside ``data``.
        Returns ``""`` when nothing matches.
        """
        for k in ("id", "task_id", "request_id"):
            v = resp.get(k)
            if isinstance(v, str) and v:
                return v
        data = resp.get("data")
        if isinstance(data, dict):
            for k in ("id", "task_id"):
                v = data.get(k)
                if isinstance(v, str) and v:
                    return v
        return ""

    @staticmethod
    def _extract_video_url(resp: dict) -> str:
        """Pull the video URL out of a succeeded task response.

        Lookup order:
          1. ``video_url`` directly under ``content`` / ``data`` / ``output``.
          2. Depth-first fallback over the whole response: the first
             ``video_url`` key, or a ``url`` key whose path ends in a known
             video extension.

        Only strings starting with ``http`` are accepted. Returns ``""`` when
        no URL is found.
        """
        for key in ("content", "data", "output"):
            sub = resp.get(key)
            if isinstance(sub, dict):
                v = sub.get("video_url")
                if isinstance(v, str) and v.startswith("http"):
                    return v

        def _find_url(o: Any) -> Optional[str]:
            if isinstance(o, dict):
                for k, v in o.items():
                    if k in ("video_url", "url") and isinstance(v, str) and v.startswith("http"):
                        if k == "video_url":
                            return v
                        # A bare `url` may be an image etc., so require a video
                        # extension. Compare against the path only: a query
                        # string or fragment after the file name must not hide
                        # the extension.
                        url_path = v.partition("?")[0].partition("#")[0]
                        if url_path.lower().endswith((".mp4", ".webm", ".mov")):
                            return v
                    r = _find_url(v)
                    if r:
                        return r
            elif isinstance(o, list):
                for x in o:
                    r = _find_url(x)
                    if r:
                        return r
            return None

        return _find_url(resp) or ""

    @staticmethod
    def _extract_tokens(resp: dict) -> Optional[int]:
        """Return the token count from the response's ``usage`` dict, if any.

        Checks ``total_tokens``, then ``completion_tokens``, then ``tokens`` and
        returns the first positive number as an ``int``; ``None`` otherwise.
        """
        usage = resp.get("usage")
        if isinstance(usage, dict):
            for k in ("total_tokens", "completion_tokens", "tokens"):
                v = usage.get(k)
                if isinstance(v, (int, float)) and v > 0:
                    return int(v)
        return None