103 lines
3.5 KiB
Python
103 lines
3.5 KiB
Python
"""Backfill tool 消息 payload 缺失的 `name` 字段。
|
|
|
|
背景:loop.py 早期 append tool 消息时只写了 role/tool_call_id/content,
|
|
没存 name。前端依赖 `payload.name` 判断是不是产物工具(seedream/seedance)
|
|
→ 历史 task 重新打开时 banner/chip 不显示。本脚本一次性回填。
|
|
|
|
策略:按 task 走,先把每条 assistant 消息里 tool_calls[].id → name 收集
|
|
成 map;再扫该 task 内 role=tool 的消息,按 tool_call_id 查 map 补 name。
|
|
查不到的(罕见:LLM 给的 id 跨 task 不对齐 / 历史脏数据)打 warn 跳过。
|
|
|
|
跑法: .venv/Scripts/python.exe scripts/backfill_tool_message_name.py
|
|
默认 dry-run,加 --apply 真写。
|
|
|
|
幂等:已经有 name 的消息跳过;再跑一遍 0 改动。
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
ROOT = Path(__file__).resolve().parent.parent
# Make the project root importable so `core.*` resolves when run as a script.
sys.path.insert(0, str(ROOT))


def load_env_file(path: Path) -> None:
    """Load simple ``KEY=VALUE`` lines from a .env file into ``os.environ``.

    Variables already present in the environment are kept (``setdefault``),
    so the real environment overrides the file. Blank lines, ``#`` comments
    and lines without ``=`` are ignored. An optional leading ``export `` on
    the key and one pair of matching surrounding quotes (``"`` or ``'``) on
    the value are stripped. A missing file is silently ignored.
    """
    if not path.exists():
        return
    # utf-8-sig drops a leading BOM if present (e.g. files saved by Windows
    # Notepad); plain utf-8 would glue "\ufeff" onto the first key.
    for raw in path.read_text(encoding="utf-8-sig").splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        key = key.strip()
        if key.startswith("export "):
            key = key[len("export "):].strip()
        if not key:
            # "=value" has no variable name; an empty name is not a valid env var.
            continue
        value = value.strip()
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        os.environ.setdefault(key, value)


env_file = ROOT / ".env"
load_env_file(env_file)
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm.attributes import flag_modified
|
|
|
|
from core.storage import session_scope
|
|
from core.storage.models import Message
|
|
|
|
|
|
def _as_dict(value: object) -> dict:
    """Return *value* unchanged if it is a dict, otherwise an empty dict.

    JSON payload columns can hold any JSON value; this lets malformed
    (non-object) entries be skipped instead of crashing on ``.get``.
    """
    return value if isinstance(value, dict) else {}


def _index_tool_call_names(rows) -> dict[str, dict[str, str]]:
    """Build ``{str(task_id): {tool_call_id: function_name}}`` from assistant messages.

    *rows* is the list of ``Message`` objects loaded by ``main``; only their
    ``payload`` and ``task_id`` are read. The map is keyed per task so equal
    tool_call ids in different tasks cannot collide. Malformed entries
    (non-list ``tool_calls``, non-dict call, missing or non-string ``id`` or
    ``function.name``) are ignored. If an id repeats within a task, the last
    occurrence wins.
    """
    index: dict[str, dict[str, str]] = defaultdict(dict)
    for m in rows:
        payload = _as_dict(m.payload)
        if payload.get("role") != "assistant":
            continue
        tool_calls = payload.get("tool_calls")
        if not isinstance(tool_calls, list):
            continue
        for tc in tool_calls:
            if not isinstance(tc, dict):
                continue
            tcid = tc.get("id")
            fn = _as_dict(tc.get("function")).get("name")
            if isinstance(tcid, str) and tcid and isinstance(fn, str) and fn:
                index[str(m.task_id)][tcid] = fn
    return index


def main() -> int:
    """Backfill the missing ``name`` field on role=tool message payloads.

    For each tool message without a truthy ``name``, its ``tool_call_id`` is
    looked up among the same task's assistant ``tool_calls`` and the function
    name is copied over. Messages that cannot be resolved are reported with a
    WARN line and left untouched. Messages that already have a name are
    skipped, so rerunning produces no further changes.

    Dry-run by default: edits are applied to the loaded objects and then
    rolled back. Pass ``--apply`` to commit them.

    Returns:
        Process exit status (always 0).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--apply", action="store_true", help="真写;默认 dry-run 只打印")
    args = ap.parse_args()

    n_tool = 0        # role=tool messages scanned
    n_already = 0     # already had a name; skipped
    n_filled = 0      # name filled in by this run
    n_unresolved = 0  # no matching assistant tool_call; warned and skipped

    with session_scope() as s:
        # Ordered by (task_id, idx) so the FILL/WARN log follows conversation order.
        rows = s.execute(select(Message).order_by(Message.task_id, Message.idx)).scalars().all()
        names_by_task = _index_tool_call_names(rows)

        for m in rows:
            p = _as_dict(m.payload)
            if p.get("role") != "tool":
                continue
            n_tool += 1
            if p.get("name"):
                n_already += 1
                continue
            tcid = p.get("tool_call_id")
            task_names = names_by_task.get(str(m.task_id), {})
            # The index holds string keys only; checking the type first also
            # avoids a TypeError on unhashable dirty values.
            name = task_names.get(tcid) if isinstance(tcid, str) else None
            if not name:
                n_unresolved += 1
                print(f" [WARN] task={m.task_id} idx={m.idx} tool_call_id={tcid!r} 找不到对应 assistant tool_call,跳过")
                continue
            print(f" [FILL] task={m.task_id} idx={m.idx} <- name={name!r}")
            # Assign a fresh dict instead of mutating the ORM-loaded one in place,
            # so the originally loaded value is not aliased; flag_modified also
            # unconditionally marks the JSON column as changed for the UPDATE.
            m.payload = {**p, "name": name}
            flag_modified(m, "payload")
            n_filled += 1

        if args.apply:
            s.commit()
        else:
            s.rollback()

    print()
    print(f"[summary] tool messages={n_tool} | already={n_already} | "
          f"filled={n_filled} | unresolved={n_unresolved}")
    print(f"[mode] {'APPLIED (committed)' if args.apply else 'DRY-RUN (no commit, rerun with --apply)'}")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status.
    raise SystemExit(main())
|