131 lines
4.7 KiB
Python
131 lines
4.7 KiB
Python
"""Smoke: look_at_image(豆包 Seed 2.0 Lite 视觉)端到端走通 + OCR 验证。
|
|
|
|
跑法: .venv/Scripts/python.exe scripts/smoke_look_at_image.py
|
|
依赖 .env 里 ARK_API_KEY / ZCBOT_DB_URL。**会真的调豆包 vision API,产生 < ¥0.01 费用**。
|
|
|
|
校验:
|
|
1. ArkConfig.load() + yaml vision 段存在
|
|
2. 合成一张含已知文字 "ZCBOT-VISION-8848" 的 PNG → LookAtImageTool.execute 能 OCR 出该串
|
|
3. 返回串首行 banner 含 model/tokens/cost
|
|
4. usage_events 多出一行 kind="vision",units 含 tokens_in/out + 单价 snapshot
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import sys
|
|
import uuid
|
|
from pathlib import Path
|
|
|
|
ROOT = Path(__file__).resolve().parent.parent
# Put the repository root first on sys.path so the project-local imports
# further down this script resolve when it is run directly.
sys.path.insert(0, str(ROOT))

# The Windows console defaults to GBK, so printing "¥" or Chinese results can
# crash. Force UTF-8 on stdout (this only affects this script's own printing).
try:
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")  # type: ignore[attr-defined]
except (AttributeError, OSError, ValueError):
    # Best effort only: stdout may be missing or replaced by a stream that
    # cannot be reconfigured; printing then falls back to the default encoding.
    pass


def load_env_file(path: Path) -> None:
    """Load KEY=VALUE pairs from a dotenv-style file into ``os.environ``.

    Blank lines, ``#`` comment lines and lines without ``=`` are skipped.
    Keys and values are stripped of surrounding whitespace, and one pair of
    matching single or double quotes around a value is removed, so
    ``KEY="abc"`` yields ``abc``. Variables already present in the
    environment are left untouched. A missing file is silently ignored.

    Args:
        path: Location of the env file.
    """
    if not path.is_file():
        return
    for raw_line in path.read_text(encoding="utf-8").splitlines():
        entry = raw_line.strip()
        if not entry or entry.startswith("#") or "=" not in entry:
            continue
        key, _, value = entry.partition("=")
        key = key.strip()
        value = value.strip()
        if not key:
            # "=value" carries no variable name, so there is nothing to set.
            continue
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            # Drop one pair of surrounding quotes so the consumer does not
            # receive them as part of the value.
            value = value[1:-1]
        # setdefault: values already exported in the real environment win.
        os.environ.setdefault(key, value)


# Read .env from the repository root.
env_file = ROOT / ".env"
load_env_file(env_file)
|
|
|
|
from PIL import Image, ImageDraw
|
|
from sqlalchemy import text
|
|
|
|
from core.ark_client import ArkConfig
|
|
from core.storage import session_scope
|
|
from core.storage.models import Task, User
|
|
from tools.look_at_image import LookAtImageTool
|
|
|
|
# Known marker text drawn into the synthetic test image; main() looks for it
# in the string returned by the tool.
MAGIC = "ZCBOT-VISION-8848"
|
|
|
|
|
|
def make_text_png(dest: Path) -> None:
    """Draw MAGIC in black on a white canvas and write the image to ``dest``.

    The text is rendered with ImageDraw's default font on a 260x60 canvas,
    which is then enlarged 4x with nearest-neighbour resampling so the small
    default glyphs stay crisp enough to OCR. Missing parent directories of
    ``dest`` are created.
    """
    scale = 4
    base_w, base_h = 260, 60
    white, black = (255, 255, 255), (0, 0, 0)

    canvas = Image.new("RGB", (base_w, base_h), white)
    ImageDraw.Draw(canvas).text((10, 22), MAGIC, fill=black)
    enlarged = canvas.resize((base_w * scale, base_h * scale), Image.NEAREST)

    dest.parent.mkdir(parents=True, exist_ok=True)
    enlarged.save(dest)
|
|
|
|
|
|
def main() -> int:
    """Run the look_at_image smoke test end to end.

    Steps: load the Ark config and take the first entry of its ``vision``
    section, synthesize a test image containing MAGIC, insert a throwaway
    User and Task, call ``LookAtImageTool.execute`` on the image, then check
    the ``usage_events`` rows recorded for that task.

    A missing MAGIC string in the tool result only prints a warning; it does
    not fail the run.

    Returns:
        0 when the checks pass, or when the test is skipped because
        ``ArkConfig.load()`` returned None or the config has no ``vision``
        section; 2 when the tool result starts with "[Error]".

    Raises:
        AssertionError: if the ``usage_events`` rows for the task do not match
            expectations. Raised explicitly rather than via ``assert`` so the
            checks still run under ``python -O``.
    """
    cfg = ArkConfig.load()
    if cfg is None:
        print("[SKIP] ARK_API_KEY 未设(或 doubao.yaml 缺失),无法测真接口")
        return 0
    vision_cfg = cfg.raw.get("vision") or {}
    if not vision_cfg:
        print("[SKIP] doubao.yaml 无 vision 段")
        return 0
    # Use the first variant listed under the vision section.
    variant_key, variant_cfg = next(iter(vision_cfg.items()))
    print(f"[setup] variant={variant_key} model={variant_cfg.get('model_id')} "
          f"price_in={variant_cfg.get('price_cny_per_mtoken_input')} "
          f"price_out={variant_cfg.get('price_cny_per_mtoken_output')}")

    # Fresh ids per run, so the usage_events query below only sees this run.
    uid = uuid.uuid4()
    tid = uuid.uuid4()
    ws_user = ROOT / "workspace" / "users" / str(uid)
    wd = ws_user / "smoke_vision"
    img = wd / "figures" / "magic.png"
    make_text_png(img)
    print(f"[setup] 合成测试图 {img.name}(含文字 {MAGIC!r})")

    # Persist the User in its own transaction before creating the Task
    # (keeps the foreign-key insertion order safe).
    with session_scope() as s:
        s.add(User(user_id=uid))
    with session_scope() as s:
        s.add(Task(task_id=tid, user_id=uid, name="smoke_vision", working_dir=str(wd)))

    tool = LookAtImageTool(
        ark_cfg=cfg,
        vision_variant_cfg=variant_cfg,
        variant_key=variant_key,
        working_dir=wd,
        task_id=tid,
        user_id=uid,
        base_dir=wd,
        user_root=ws_user,
    )

    print("[call] question='把图中的文字逐字 OCR 出来'")
    result = tool.execute(image="figures/magic.png", question="把图中的文字逐字 OCR 出来。")
    print(f"[tool result]\n{result}\n")
    if result.startswith("[Error]"):
        print("[FAIL] tool 返回错误")
        return 2

    # OCR hit check: tolerate added spaces / newlines / case differences from
    # the model, and compare with the "-" separators removed on both sides.
    norm = result.replace(" ", "").replace("\n", "").upper()
    if MAGIC.replace("-", "") in norm.replace("-", ""):
        print(f"[OK] OCR 命中魔术串 {MAGIC}")
    else:
        print(f"[WARN] 未在结果里精确匹配 {MAGIC} —— 人工核对上面 result(模型可能换了排版)")

    with session_scope() as s:
        rows = s.execute(text(
            "SELECT kind, model_profile, units, cost_cny FROM usage_events "
            "WHERE task_id = :tid"
        ), {"tid": str(tid)}).all()

    # Explicit raises instead of `assert`: assert statements are removed under
    # `python -O`, which would let the script report PASS without checking.
    if len(rows) != 1:
        raise AssertionError(f"usage_events 行数应 1,实际 {len(rows)}")
    row = rows[0]
    if row.kind != "vision":
        raise AssertionError(f"kind 应 vision,实际 {row.kind}")
    if row.model_profile != f"doubao.{variant_key}":
        raise AssertionError(f"model_profile={row.model_profile}")
    for k in ("tokens_in", "tokens_out", "input_cny_per_mtoken", "output_cny_per_mtoken"):
        if k not in row.units:
            raise AssertionError(f"units 缺 {k}")
    print(f"[OK] usage_events: kind={row.kind} model={row.model_profile} "
          f"cost_cny={row.cost_cny} units={row.units}")

    print("\n[PASS] smoke_look_at_image 全部通过")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    raise SystemExit(main())
|