zcbot/scripts/backfill_tool_message_name.py

103 lines
3.5 KiB
Python

"""Backfill tool 消息 payload 缺失的 `name` 字段。
背景:loop.py 早期 append tool 消息时只写了 role/tool_call_id/content,
没存 name。前端依赖 `payload.name` 判断是不是产物工具(seedream/seedance)
→ 历史 task 重新打开时 banner/chip 不显示。本脚本一次性回填。
策略:按 task 走,先把每条 assistant 消息里 tool_calls[].id → name 收集
成 map;再扫该 task 内 role=tool 的消息,按 tool_call_id 查 map 补 name。
查不到的(罕见:LLM 给的 id 跨 task 不对齐 / 历史脏数据)打 warn 跳过。
跑法: .venv/Scripts/python.exe scripts/backfill_tool_message_name.py
默认 dry-run,加 --apply 真写。
幂等:已经有 name 的消息跳过;再跑一遍 0 改动。
"""
from __future__ import annotations
import argparse
import os
import sys
from collections import defaultdict
from pathlib import Path
# Project root (the directory above scripts/); put it on sys.path so the
# project's own packages (e.g. `core`) are importable when run as a script.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))


def parse_env_line(line: str) -> tuple[str, str] | None:
    """Parse one ``KEY=VALUE`` line of a ``.env`` file.

    Args:
        line: A raw line from the file (surrounding whitespace is ignored).

    Returns:
        ``(key, value)``, or ``None`` for blank lines, ``#`` comments, lines
        without ``=`` and lines whose key is empty (``os.environ`` rejects an
        empty name with ``ValueError``).

    A leading ``export `` on the key is dropped, and one pair of matching
    surrounding quotes (``"..."`` or ``'...'``) is removed from the value, so
    ``KEY="a b"`` yields ``("KEY", "a b")`` rather than a value that still
    contains the quote characters. Everything after the first ``=`` belongs
    to the value, so ``URL=a=b`` yields ``("URL", "a=b")``.
    """
    line = line.strip()
    if not line or line.startswith("#") or "=" not in line:
        return None
    key, _, value = line.partition("=")
    key = key.strip()
    if key.startswith("export "):
        key = key[len("export "):].strip()
    if not key:
        return None
    value = value.strip()
    # Strip exactly one pair of matching quotes; mismatched quotes are kept.
    if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
        value = value[1:-1]
    return key, value


# Load ROOT/.env before the project imports below; presumably core.storage
# reads its configuration from the environment at import time (TODO confirm).
# setdefault: variables already set in the real environment take precedence.
env_file = ROOT / ".env"
if env_file.exists():
    for line in env_file.read_text(encoding="utf-8").splitlines():
        parsed = parse_env_line(line)
        if parsed is not None:
            os.environ.setdefault(*parsed)
from sqlalchemy import select
from sqlalchemy.orm.attributes import flag_modified
from core.storage import session_scope
from core.storage.models import Message
def _tool_call_names(payload: dict | None):
    """Yield ``(tool_call_id, function_name)`` pairs from an assistant payload.

    Payloads that are empty/``None`` or whose ``role`` is not ``"assistant"``
    yield nothing. Tool-call entries that are not dicts, or that lack an id
    or a function name, are skipped (historical data may be dirty).
    """
    p = payload or {}
    if p.get("role") != "assistant":
        return
    for tc in p.get("tool_calls") or []:
        if not isinstance(tc, dict):
            continue
        tcid = tc.get("id")
        fn = (tc.get("function") or {}).get("name")
        if tcid and fn:
            yield tcid, fn


def main() -> int:
    """Backfill ``payload["name"]`` on role=tool messages that lack it.

    Walks every ``Message`` ordered by ``(task_id, idx)``. For each tool
    message without a ``name``, its ``tool_call_id`` is looked up among the
    assistant ``tool_calls`` of the same task, and the function name is
    copied into the payload. Messages that already have a ``name`` are left
    untouched, so a second run changes nothing.

    Resolution order for a tool message:
      1. the most recent *preceding* assistant message in the same task that
         declared the id. An id reused by several turns of one task therefore
         maps to the call this result actually answers.
      2. otherwise, the last declaration of that id anywhere in the task
         (covers out-of-order idx data).
      3. otherwise it is reported with ``[WARN]`` and skipped.

    Dry run by default (the session is rolled back); pass ``--apply`` to
    commit.

    Returns:
        Process exit status (always 0).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--apply", action="store_true", help="真写;默认 dry-run 只打印")
    args = ap.parse_args()

    n_tool = 0  # tool messages scanned
    n_already = 0  # already had a name; skipped
    n_filled = 0  # name filled in by this run
    n_unresolved = 0  # no matching tool_call found; warned and skipped

    with session_scope() as s:
        rows = s.execute(select(Message).order_by(Message.task_id, Message.idx)).scalars().all()

        # Fallback index: task -> {tool_call_id: name}, last declaration in the
        # task wins. Keyed per task so ids cannot collide across tasks.
        by_task: dict[str, dict[str, str]] = defaultdict(dict)
        for m in rows:
            for tcid, fn in _tool_call_names(m.payload):
                by_task[str(m.task_id)][tcid] = fn

        # Second pass in (task_id, idx) order. `preceding` holds only the
        # tool calls declared so far in the current task, so each tool
        # message first resolves against earlier assistant turns.
        cur_task = None
        preceding: dict[str, str] = {}
        for m in rows:
            task_key = str(m.task_id)
            if task_key != cur_task:
                cur_task = task_key
                preceding = {}
            p = m.payload or {}
            role = p.get("role")
            if role == "assistant":
                preceding.update(_tool_call_names(p))
                continue
            if role != "tool":
                continue
            n_tool += 1
            if p.get("name"):
                n_already += 1
                continue
            tcid = p.get("tool_call_id")
            key = tcid or ""
            name = preceding.get(key) or by_task.get(task_key, {}).get(key)
            if not name:
                n_unresolved += 1
                print(f" [WARN] task={m.task_id} idx={m.idx} tool_call_id={tcid!r} 找不到对应 assistant tool_call,跳过")
                continue
            print(f" [FILL] task={m.task_id} idx={m.idx} <- name={name!r}")
            p["name"] = name
            m.payload = p
            # The payload dict is mutated in place. Presumably `payload` is a
            # plain JSON column without mutation tracking, so mark it dirty
            # explicitly or the change would not be flushed.
            flag_modified(m, "payload")
            n_filled += 1

        if args.apply:
            s.commit()
        else:
            s.rollback()

    print()
    print(f"[summary] tool messages={n_tool} | already={n_already} | "
          f"filled={n_filled} | unresolved={n_unresolved}")
    print(f"[mode] {'APPLIED (committed)' if args.apply else 'DRY-RUN (no commit, rerun with --apply)'}")
    return 0
if __name__ == "__main__":
    # Run the backfill and hand main()'s return value to the interpreter
    # as the process exit status.
    raise SystemExit(main())