zcbot/scripts/smoke_seedream.py

115 lines
4.0 KiB
Python

"""Smoke: 豆包 Seedream 图像生成 tool 端到端走通。
跑法: .venv/Scripts/python.exe scripts/smoke_seedream.py
依赖 .env 里 ARK_API_KEY / ZCBOT_DB_URL。**会真的调豆包 API,产生 ~¥0.22 费用**。
校验:
1. ArkConfig.load() 拿到 cfg
2. SeedreamTool.execute(prompt=...) 返回 [image saved: ...] 文案
3. figures/<ts>-<rand>.png 落盘且大于 0 字节
4. 同名 .meta.json 存在 + 含 prompt/model_id/size/cost_cny/ts 字段
5. usage_events 多出一行 kind="image",单价 snapshot 在 units jsonb
"""
from __future__ import annotations
import json
import os
import sys
import uuid
from pathlib import Path
ROOT = Path(__file__).resolve().parent.parent
# Make the project packages (core/, tools/) importable when run as a script.
sys.path.insert(0, str(ROOT))


def _load_dotenv(path: Path) -> None:
    """Load simple ``KEY=VALUE`` lines from *path* into ``os.environ``.

    Existing environment variables win (``setdefault``), so values exported
    in the shell override the file. Blank lines, ``#`` comments and lines
    without ``=`` are skipped. A missing file is a no-op.

    Args:
        path: location of the dotenv-style file.
    """
    if not path.exists():
        return
    # utf-8-sig drops a leading BOM (common for files saved by Windows
    # editors); with plain utf-8 it would be glued onto the first key name.
    for raw in path.read_text(encoding="utf-8-sig").splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        value = value.strip()
        # Strip one pair of matching surrounding quotes: KEY="v" -> v.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        os.environ.setdefault(key.strip(), value)


# Read .env before the project imports below, in case any of them read
# environment variables at import time (not verified here).
env_file = ROOT / ".env"
_load_dotenv(env_file)
from sqlalchemy import text
from core.ark_client import ArkConfig
from core.storage import session_scope
from core.storage.models import Task, User
from tools.seedream import SeedreamTool
def _require(cond: object, msg: str) -> None:
    """Raise ``AssertionError(msg)`` unless *cond* is truthy.

    Used instead of bare ``assert`` so the smoke checks still run when the
    interpreter is started with ``-O`` (which strips assert statements).
    """
    if not cond:
        raise AssertionError(msg)


def main() -> int:
    """Run the Seedream image tool end to end against the real Ark API.

    Creates a throwaway user + task row and working dir, calls
    ``SeedreamTool.execute`` once, then verifies the saved PNG, its
    ``.meta.json`` sidecar and the single ``usage_events`` row.

    Returns:
        0 on success, or when no Ark config is available (skip);
        2 when no image variant is configured or the tool returns ``[Error]``.

    Raises:
        AssertionError: when any post-condition check fails.
    """
    cfg = ArkConfig.load()
    if cfg is None:
        print("[SKIP] ARK_API_KEY 未设(或 config/media/doubao.yaml 缺失),无法测真接口")
        return 0

    image_cfg = cfg.raw.get("image") or {}
    if not image_cfg:
        # Without this guard next(iter(...)) raises a bare StopIteration.
        print("[FAIL] doubao 配置里没有 image 变体,无法测试")
        return 2
    # Only the first configured image variant is exercised.
    variant_key, variant_cfg = next(iter(image_cfg.items()))
    model_id = variant_cfg.get("model_id")
    price = variant_cfg.get("price_cny_per_image")
    print(f"[setup] variant={variant_key} model={model_id} price={price}")

    # One-off user + task rows so the usage_events foreign keys resolve.
    # NOTE(review): these rows and the workspace dir are never cleaned up.
    uid = uuid.uuid4()
    tid = uuid.uuid4()
    ws_user = ROOT / "workspace" / "users" / str(uid)
    wd = ws_user / "smoke_seedream"
    wd.mkdir(parents=True, exist_ok=True)
    with session_scope() as s:
        s.add(User(user_id=uid))
        s.add(Task(task_id=tid, user_id=uid, name="smoke_seedream", working_dir=str(wd)))

    tool = SeedreamTool(
        ark_cfg=cfg,
        image_variant_cfg=variant_cfg,
        variant_key=variant_key,
        working_dir=wd,
        task_id=tid,
        user_id=uid,
        base_dir=wd,
        user_root=ws_user,
    )
    prompt = "一只橙色的小猫坐在窗台上望向远方,水彩风格"
    print(f"[call] prompt='{prompt}'")
    result = tool.execute(prompt=prompt)  # real API call, costs money
    print(f"[tool result]\n{result}\n")
    if result.startswith("[Error]"):
        print("[FAIL] tool 返回错误")
        return 2

    # Check the figures directory and the meta sidecar.
    figs = list((wd / "figures").glob("*.png"))
    _require(len(figs) == 1, f"figures/*.png 应当 1 个,实际 {len(figs)}")
    png = figs[0]
    size_bytes = png.stat().st_size
    _require(size_bytes > 0, f"{png} 大小为 0")
    print(f"[OK] png 落盘 {png.name} ({size_bytes} bytes)")

    # foo.png -> foo.meta.json
    meta_path = png.with_suffix(".meta.json")
    _require(meta_path.exists(), f"meta 文件不存在 {meta_path}")
    meta = json.loads(meta_path.read_text(encoding="utf-8"))
    for k in ("prompt", "model_id", "size", "cost_cny", "ts"):
        _require(k in meta, f"meta 缺字段 {k}")
    print(f"[OK] meta 字段齐全: {list(meta.keys())}")

    # Check usage_events: exactly one "image" row for this task.
    with session_scope() as s:
        rows = s.execute(text(
            "SELECT kind, model_profile, units, cost_cny FROM usage_events "
            "WHERE task_id = :tid"
        ), {"tid": str(tid)}).all()
    _require(len(rows) == 1, f"usage_events 行数应 1,实际 {len(rows)}")
    row = rows[0]
    _require(row.kind == "image", f"kind 应 image,实际 {row.kind}")
    _require(row.model_profile == f"doubao.{variant_key}", f"model_profile = {row.model_profile}")
    # Guard against a NULL units column, which would make `in` raise TypeError.
    _require("price_cny_per_image" in (row.units or {}), "units 缺 price_cny_per_image snapshot")
    print(f"[OK] usage_events: kind={row.kind} model={row.model_profile} cost_cny={row.cost_cny}")
    print(f" units snapshot: {row.units}")
    print("\n[PASS] smoke_seedream 全部通过")
    return 0
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    raise SystemExit(main())