375 lines
16 KiB
Python
375 lines
16 KiB
Python
"""seedance: 调豆包 Seedance 2.0 Fast 视频生成 API,产物落 working_dir/videos/。
|
||
|
||
异步任务:
|
||
1. POST /contents/generations/tasks → 返 `{"id": "cgt-..."}`
|
||
2. 轮询 GET /contents/generations/tasks/<cgt_id>(默认 5s 间隔,最长 10min)
|
||
直到 status ∈ {succeeded, failed, expired, cancelled}
|
||
3. succeeded → 取 content.video_url → download 到本地 + 写 meta + 计 usage_events
|
||
|
||
模型 ID + 单价 + 默认参数全在 `config/media/doubao.yaml`,本 tool 只装配。
|
||
计费按 token 公式 `(in_dur+out_dur) × W × H × fps / 1024`,文生视频 in_dur=0;
|
||
W×H 由 resolution + ratio 推算(横版 height=resolution_num,竖版 width=resolution_num)。
|
||
|
||
完成后:
|
||
- 视频落到 `<wd>/videos/<YYYYMMDD-HHMMSS>-<rand6>.mp4`
|
||
- 同名 `.meta.json` 写 prompt / model / 参数 / cost_cny / tokens / cgt_id / ts
|
||
- usage_events 写 kind="video" 一行(单价 + 分辨率 + 时长 snapshot 进 units)
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import secrets
|
||
import time
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from typing import Any, Callable, Optional
|
||
from uuid import UUID
|
||
|
||
from core.ark_client import ArkClient, ArkConfig, ArkError
|
||
from core.storage.usage import check_daily_quota, record_video_usage
|
||
|
||
from .base import Tool
|
||
|
||
|
||
# resolution → 短边像素;W/H 实际由 ratio 决定(横版短边=H,竖版短边=W)
|
||
_RESOLUTION_TO_SHORT_EDGE: dict[str, int] = {
|
||
"480p": 480,
|
||
"720p": 720,
|
||
"1080p": 1080, # fast 不支持,留给 pro;tool 不在这层挡,让豆包侧拒绝并返 [Error]
|
||
}
|
||
|
||
_VALID_RATIOS: set[str] = {"16:9", "9:16", "1:1", "4:3", "3:4", "21:9", "adaptive"}
|
||
|
||
|
||
def _resolve_dimensions(resolution: str, ratio: str) -> tuple[int, int]:
    """Turn a resolution label plus an aspect ratio into ``(width, height)`` pixels.

    The resolution label fixes the short edge; the ratio decides which side
    that short edge is and how long the other side becomes.

    * An unknown ``resolution`` falls back to a 720-pixel short edge.
    * ``"adaptive"`` (or any ratio outside ``_VALID_RATIOS``) has no fixed
      shape, so a square of the short edge is returned. This is only meant
      for token estimation.
    """
    short_edge = _RESOLUTION_TO_SHORT_EDGE.get(resolution, 720)

    has_numeric_ratio = ratio in _VALID_RATIOS and ratio != "adaptive"
    if not has_numeric_ratio:
        return short_edge, short_edge

    w_part, h_part = (int(piece) for piece in ratio.split(":", 1))

    if w_part < h_part:
        # Portrait: the width is the short edge.
        return short_edge, round(short_edge * h_part / w_part)

    # Landscape or square: the height is the short edge.
    return round(short_edge * w_part / h_part), short_edge
|
||
|
||
|
||
def _estimate_tokens(width: int, height: int, duration_s: int, fps: int, in_dur_s: int = 0) -> int:
|
||
"""火山方舟 token 估算公式:`(in_dur + out_dur) × W × H × fps / 1024`。
|
||
|
||
实测校验(fast, 720p 16:9, 5s, 24fps, 文生视频):
|
||
(0+5) × 1280 × 720 × 24 / 1024 = 108,000 tokens × ¥37/Mtok = ¥3.996 ≈ ¥4.00 ✓
|
||
"""
|
||
return int((in_dur_s + duration_s) * width * height * fps / 1024)
|
||
|
||
|
||
class SeedanceTool(Tool):
    """Tool that renders a short video through the Doubao Seedance task API.

    ``execute`` submits a generation task, polls it until it reaches a
    terminal status, downloads the resulting video into
    ``<working_dir>/videos/``, writes a ``.meta.json`` sidecar next to it and
    records one usage event. Every failure mode that is handled here is
    reported as a returned string starting with ``[Error]`` or
    ``[Cancelled]`` rather than as an exception.
    """

    name = "seedance"
    description = (
        "Generate a short video with Doubao Seedance 2.0 Fast and save to working_dir/videos/. "
        "Use only when the user explicitly asks for a video / 视频 / 动画 / 动起来. "
        "Async: takes 30-90s to render. Costs ~¥1.86 (480p, 5s) ~ ¥4.00 (720p, 5s); "
        "longer / higher-resolution scales up. Returns the saved relative path."
    )
    parameters = {
        "type": "object",
        "properties": {
            "prompt": {
                "type": "string",
                "description": "中文或英文都行,详尽描述画面 + 运动 + 镜头(主体在做什么 / 镜头怎么动 / 场景 / 风格)。",
            },
            "resolution": {
                "type": "string",
                "description": "Video resolution. fast 版仅支持 '480p' / '720p'(默认 720p);1080p+ 仅 pro 可用。",
                "enum": ["480p", "720p", "1080p"],
            },
            "ratio": {
                "type": "string",
                "description": "Aspect ratio. 默认 '16:9'(ppt/横屏配视频)。竖版海报/短视频用 '9:16'。",
                "enum": ["16:9", "9:16", "1:1", "4:3", "3:4", "21:9", "adaptive"],
            },
            "duration": {
                "type": "integer",
                "description": "视频时长(秒),范围 4-15。短=便宜,默认 5。每加 1s 成本约线性上升。",
                "minimum": 4,
                "maximum": 15,
            },
            "watermark": {
                "type": "boolean",
                "description": "是否打豆包水印。默认 false(申报/PPT 场景反需求)。",
            },
            "generate_audio": {
                "type": "boolean",
                "description": (
                    "是否同步生成 AI 背景音 / 对白(Seedance 2.0 旗舰特性)。默认 false 控成本;"
                    "广告 / 短剧 / 角色对白等场景传 true,模型会一并算音轨,cost 高于纯视频。"
                ),
            },
        },
        "required": ["prompt"],
    }

    def __init__(
        self,
        *,
        ark_cfg: ArkConfig,
        video_variant_cfg: dict,
        variant_key: str,
        working_dir: Path,
        task_id: UUID,
        user_id: UUID,
        base_dir: Optional[Path] = None,
        user_root: Optional[Path] = None,
        cancel_check: Optional[Callable[[], bool]] = None,
        daily_limit: int = 0,
    ) -> None:
        """Store the configuration and identifiers used by ``execute``.

        Args:
            ark_cfg: Configuration handed to ``ArkClient``.
            video_variant_cfg: Per-variant settings (model id, defaults,
                timeouts, price, endpoints) read with ``cfg[...]`` /
                ``cfg.get(...)`` in ``execute``.
            variant_key: Variant name; recorded in usage as
                ``"doubao.<variant_key>"``.
            working_dir: Directory under which ``videos/`` is written.
            task_id: Task identifier passed to ``record_video_usage``.
            user_id: User identifier used for the quota check and usage record.
            base_dir: Forwarded to the base ``Tool`` constructor.
            user_root: Forwarded to the base ``Tool`` constructor.
            cancel_check: Optional callable polled while waiting; a truthy
                return aborts the wait.
            daily_limit: Per-day video quota; zero or negative disables the check.
        """
        super().__init__(base_dir, user_root=user_root)
        self.ark_cfg = ark_cfg
        self.cfg = video_variant_cfg
        # e.g. 'seedance_2_fast' -> usage model_profile "doubao.seedance_2_fast"
        self.variant_key = variant_key
        self.working_dir = Path(working_dir)
        self.task_id = task_id
        self.user_id = user_id
        # Consulted on every polling iteration to detect cancellation.
        self.cancel_check = cancel_check
        # 0 / negative = unlimited.
        self.daily_limit = int(daily_limit)

    def execute(
        self,
        prompt: str,
        resolution: Optional[str] = None,
        ratio: Optional[str] = None,
        duration: Optional[int] = None,
        watermark: Optional[bool] = None,
        generate_audio: Optional[bool] = None,
    ) -> str:
        """Submit a video generation task, wait for it, and save the result.

        Args:
            prompt: Text description of the video; must not be blank.
            resolution: Resolution label; falls back to ``default_resolution``
                from the variant config, then ``"720p"``.
            ratio: Aspect ratio; falls back to ``default_ratio``, then ``"16:9"``.
            duration: Length in seconds; falls back to ``default_duration``, then 5.
            watermark: Whether to watermark; falls back to ``default_watermark``.
            generate_audio: Whether to generate audio; falls back to
                ``default_generate_audio``.

        Returns:
            On success, a multi-line string whose first line is a
            ``[seedance] key=value · ...`` banner followed by the saved path,
            the prompt and the task id. Otherwise a string starting with
            ``[Error]`` or ``[Cancelled]``.
        """
        if not (prompt or "").strip():
            return "[Error] prompt 不能为空"

        # Per-user daily quota. Failed / cancelled runs do not count because
        # usage is only recorded after a successful download. The returned
        # string ends up in the LLM context, so it only exposes what the end
        # user should see (used/limit + reset time), not internal config paths.
        if self.daily_limit > 0:
            used, over = check_daily_quota(user_id=self.user_id, kind="video", limit=self.daily_limit)
            if over:
                return (
                    f"[Error] 已达每日视频生成上限({used}/{self.daily_limit} 个),"
                    f"次日 00:00 重置。"
                )

        cfg = self.cfg
        model_id = cfg["model_id"]
        chosen_resolution = resolution or cfg.get("default_resolution", "720p")
        chosen_ratio = ratio or cfg.get("default_ratio", "16:9")
        chosen_duration = int(duration) if duration is not None else int(cfg.get("default_duration", 5))
        chosen_watermark = bool(cfg.get("default_watermark", False)) if watermark is None else bool(watermark)
        chosen_generate_audio = (
            bool(cfg.get("default_generate_audio", False)) if generate_audio is None else bool(generate_audio)
        )
        fps = int(cfg.get("fps", 24))

        submit_timeout = float(cfg.get("request_timeout_s", 60))
        poll_interval = float(cfg.get("poll_interval_s", 5))
        poll_timeout = float(cfg.get("poll_timeout_s", 600))
        price_t2v = float(cfg.get("price_cny_per_mtoken_text2video", 37.0))

        submit_endpoint = cfg.get("endpoint_submit", "/contents/generations/tasks")
        poll_endpoint_base = cfg.get("endpoint_poll", "/contents/generations/tasks")

        body: dict[str, Any] = {
            "model": model_id,
            "content": [{"type": "text", "text": prompt}],
            "ratio": chosen_ratio,
            "resolution": chosen_resolution,
            "duration": chosen_duration,
            "watermark": chosen_watermark,
            "generate_audio": chosen_generate_audio,
        }

        t0 = time.monotonic()
        try:
            with ArkClient(self.ark_cfg, timeout_s=submit_timeout) as client:
                # 1. submit
                submit_resp = client.post_json(submit_endpoint, body, timeout_s=submit_timeout)
                cgt_id = self._extract_task_id(submit_resp)
                if not cgt_id:
                    return f"[Error] seedance submit 响应缺 task id: {json.dumps(submit_resp, ensure_ascii=False)[:300]}"

                # 2. poll until a terminal status, cancellation, or the deadline
                poll_url = f"{poll_endpoint_base}/{cgt_id}"
                deadline = time.monotonic() + poll_timeout
                last_status = ""
                final_resp: dict = {}
                while True:
                    if self.cancel_check is not None and self.cancel_check():
                        rough_cost = self._rough_cost(
                            chosen_resolution, chosen_ratio, chosen_duration, fps, price_t2v
                        )
                        return (
                            f"[Cancelled] seedance task {cgt_id} 用户取消(远端任务可能仍在跑;"
                            f"Volcengine 失败/成功才计费,若仍出片可能产生 ~¥{rough_cost:.2f})"
                        )
                    if time.monotonic() > deadline:
                        return (
                            f"[Error] seedance 轮询超时(>{poll_timeout:.0f}s),最后 status={last_status!r},"
                            f"cgt_id={cgt_id}(24h 内可手工 GET {poll_url} 拉结果)"
                        )
                    # The wait comes before the request, so the first poll
                    # happens one interval after submission.
                    time.sleep(poll_interval)
                    poll_resp = client.get_json(poll_url, timeout_s=submit_timeout)
                    last_status = str(poll_resp.get("status") or "").lower()
                    final_resp = poll_resp
                    if last_status in ("succeeded", "failed", "expired", "cancelled"):
                        break

                if last_status != "succeeded":
                    err = final_resp.get("error")
                    if not isinstance(err, dict):
                        err = {}
                    msg = err.get("message") or final_resp.get("message") or "(无错误描述)"
                    return f"[Error] seedance task {cgt_id} 终态 status={last_status},msg={msg}"

                video_url = self._extract_video_url(final_resp)
                if not video_url:
                    return f"[Error] seedance succeeded 但响应缺 video url: {json.dumps(final_resp, ensure_ascii=False)[:400]}"

                # 3. download
                ts = datetime.now().strftime("%Y%m%d-%H%M%S")
                short = secrets.token_hex(3)
                videos_dir = self.working_dir / "videos"
                # Make sure the target directory exists; harmless if the
                # client's download already creates it.
                videos_dir.mkdir(parents=True, exist_ok=True)
                dest_mp4 = videos_dir / f"{ts}-{short}.mp4"
                client.download(video_url, dest_mp4, timeout_s=300.0)
        except ArkError as e:
            return f"[Error] seedance API: {e}"

        elapsed = time.monotonic() - t0

        # tokens / cost: prefer the usage figure from the response when it is
        # present, otherwise fall back to the formula estimate.
        width, height = _resolve_dimensions(chosen_resolution, chosen_ratio)
        tokens_estimated = _estimate_tokens(width, height, chosen_duration, fps, in_dur_s=0)
        tokens_actual = self._extract_tokens(final_resp) or tokens_estimated
        cost_cny = tokens_actual * price_t2v / 1_000_000.0

        meta = {
            "prompt": prompt,
            "model_id": model_id,
            "resolution": chosen_resolution,
            "ratio": chosen_ratio,
            "width": width,
            "height": height,
            "duration_s": chosen_duration,
            "fps": fps,
            "watermark": chosen_watermark,
            "generate_audio": chosen_generate_audio,
            "tokens": tokens_actual,
            "tokens_estimated": tokens_estimated,
            "price_cny_per_mtoken": price_t2v,
            "cost_cny": round(cost_cny, 4),
            "elapsed_s": round(elapsed, 1),
            "cgt_id": cgt_id,
            "ts": datetime.now().isoformat(timespec="seconds"),
        }
        meta_path = dest_mp4.with_suffix(".meta.json")
        # The sidecar is best-effort: the video is already rendered and
        # downloaded at this point, so a failed sidecar write must not
        # prevent the usage record below from being written.
        try:
            meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
        except OSError as e:
            print(f"[seedance] meta write failed: {type(e).__name__}: {e}", flush=True)

        # Usage recording is also best-effort; a failure is logged, not raised.
        try:
            record_video_usage(
                task_id=self.task_id,
                user_id=self.user_id,
                model_profile=f"doubao.{self.variant_key}",
                resolution=chosen_resolution,
                ratio=chosen_ratio,
                duration_s=chosen_duration,
                fps=fps,
                width=width,
                height=height,
                tokens=tokens_actual,
                price_cny_per_mtoken=price_t2v,
                # Only text-to-video for now; once image/video input is
                # supported this should be derived from the request body.
                has_video_input=False,
                watermark=chosen_watermark,
                extra_units={
                    "cgt_id": cgt_id,
                    "elapsed_s": round(elapsed, 1),
                    "generate_audio": chosen_generate_audio,
                },
            )
        except Exception as e:
            print(f"[seedance] record_video_usage failed: {type(e).__name__}: {e}", flush=True)

        disp = self._display(dest_mp4)
        # Banner protocol matches seedream: the first line is
        # `[tool] key=value · key=value ...`; the frontend's extractMediaBanner
        # whitelists seedance and picks up the key=value pairs as badges.
        return (
            f"[seedance] model={model_id} · resolution={chosen_resolution} · ratio={chosen_ratio} · "
            f"duration={chosen_duration}s · audio={chosen_generate_audio} · "
            f"cost=¥{cost_cny:.2f} · elapsed={elapsed:.1f}s\n"
            f"saved: {disp}\n"
            f"prompt={prompt!r}\n"
            f"watermark={chosen_watermark} cgt_id={cgt_id}"
        )

    @staticmethod
    def _rough_cost(resolution: str, ratio: str, duration_s: int, fps: int, price_per_mtoken: float) -> float:
        """Return the formula-based cost estimate (CNY) for a text-to-video run."""
        w, h = _resolve_dimensions(resolution, ratio)
        tokens = _estimate_tokens(w, h, duration_s, fps)
        return tokens * price_per_mtoken / 1_000_000.0

    @staticmethod
    def _extract_task_id(resp: dict) -> str:
        """Pull the task id out of a submit response, or return ``""``.

        Looks for a non-empty string under top-level ``id`` / ``task_id`` /
        ``request_id`` first, then under ``data.id`` / ``data.task_id``.
        """
        # NOTE(review): a top-level "request_id" wins over a nested data.id;
        # confirm that is intended for wrapped/proxied responses.
        for k in ("id", "task_id", "request_id"):
            v = resp.get(k)
            if isinstance(v, str) and v:
                return v
        data = resp.get("data")
        if isinstance(data, dict):
            for k in ("id", "task_id"):
                v = data.get(k)
                if isinstance(v, str) and v:
                    return v
        return ""

    @staticmethod
    def _extract_video_url(resp: dict) -> str:
        """Pull the video URL out of a succeeded poll response, or return ``""``.

        Lookup order:
          1. ``video_url`` directly under ``content`` / ``data`` / ``output``
             when that value is a dict.
          2. Depth-first search of the whole response for the first
             ``video_url`` key, or a ``url`` key whose value looks like a
             video file (``.mp4`` / ``.webm`` / ``.mov``).

        Only strings starting with ``http`` are accepted.
        """
        for key in ("content", "data", "output"):
            sub = resp.get(key)
            if isinstance(sub, dict):
                v = sub.get("video_url")
                if isinstance(v, str) and v.startswith("http"):
                    return v

        video_exts = (".mp4", ".webm", ".mov")

        def _looks_like_video(url: str) -> bool:
            # Check the extension on the URL without its query string or
            # fragment, so signed links such as ".../a.mp4?sig=..." still match.
            path = url.partition("#")[0].partition("?")[0]
            return path.lower().endswith(video_exts)

        def _find_url(o: Any) -> Optional[str]:
            if isinstance(o, dict):
                for k, v in o.items():
                    if k in ("video_url", "url") and isinstance(v, str) and v.startswith("http"):
                        # "video_url" is trusted as-is; a generic "url" must
                        # look like a video so image URLs etc. are skipped.
                        if k == "video_url" or _looks_like_video(v):
                            return v
                    r = _find_url(v)
                    if r:
                        return r
            elif isinstance(o, list):
                for x in o:
                    r = _find_url(x)
                    if r:
                        return r
            return None

        return _find_url(resp) or ""

    @staticmethod
    def _extract_tokens(resp: dict) -> Optional[int]:
        """Return the token count from ``resp["usage"]`` if present, else ``None``.

        Keys are tried in the order ``total_tokens``, ``completion_tokens``,
        ``tokens``; the first positive numeric value wins.
        """
        usage = resp.get("usage")
        if isinstance(usage, dict):
            for k in ("total_tokens", "completion_tokens", "tokens"):
                v = usage.get(k)
                # bool is a subclass of int; a JSON `true` is not a token count.
                if isinstance(v, (int, float)) and not isinstance(v, bool) and v > 0:
                    return int(v)
        return None
|