# zcbot/scripts/diag_tool_repeat.py
#
# 121 lines
# 4.1 KiB
# Python

"""诊断:找高轮数 task,统计 tool_call 序列里的重复模式。
用法:.venv/Scripts/python.exe scripts/diag_tool_repeat.py [task_id]
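Without task_id -> list the 15 tasks with the most messages, each with a repetition overview.
With task_id -> summarize that task's tool_call sequence (tool name + argument fingerprint): call counts per fingerprint and consecutive repeated runs.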
"""
import hashlib
import json
import os
import sys
from collections import Counter
from pathlib import Path
# Read ZCBOT_DB_URL from the .env file located in the parent of this script's
# directory, so the tool can be run without exporting the variable first.
env = Path(__file__).resolve().parent.parent / ".env"
if env.is_file():
    for line in env.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        # A commented-out entry starts with "#", so it can never match this
        # prefix; no separate comment check is needed.
        if not line.startswith("ZCBOT_DB_URL="):
            continue
        db_url = line.split("=", 1)[1].strip()
        # dotenv-style files commonly quote values; drop one matching pair so
        # the quotes do not end up inside the connection URL.
        if len(db_url) >= 2 and db_url[0] == db_url[-1] and db_url[0] in "\"'":
            db_url = db_url[1:-1]
        # The file value overrides the process environment, and the last
        # matching line in the file wins.
        os.environ["ZCBOT_DB_URL"] = db_url

# Imported after the .env handling above, hence the E402 suppression.
from sqlalchemy import create_engine, text  # noqa: E402

if not os.environ.get("ZCBOT_DB_URL"):
    # Fail with an actionable message instead of a bare KeyError.
    raise SystemExit(
        f"ZCBOT_DB_URL is not set (checked the environment and {env})"
    )

engine = create_engine(os.environ["ZCBOT_DB_URL"])
def fingerprint(payload):
    """Return one fingerprint label per tool call of an assistant message.

    Args:
        payload: The message payload. Normally a dict; a JSON string/bytes
            encoding of such a dict is decoded first.

    Returns:
        A list of label strings, one per entry of ``payload["tool_calls"]``,
        in order. Each label is the tool name, optionally followed by
        ``" | "`` and space-separated key-argument fingerprints:
        ``script_path`` / ``path`` / ``file_path`` / ``old_str`` / ``command``
        (first 60 characters of the value) and ``code`` (first 8 hex digits
        of its SHA-1 plus its length). The list is empty for non-assistant
        messages and for payloads that cannot be interpreted as a dict.
    """
    # Some drivers hand JSON columns back as text rather than decoded objects.
    if isinstance(payload, (str, bytes, bytearray)):
        try:
            payload = json.loads(payload)
        except (ValueError, RecursionError):
            return []
    if not isinstance(payload, dict) or payload.get("role") != "assistant":
        return []

    out = []
    for tc in payload.get("tool_calls") or []:
        fn = tc.get("function") if isinstance(tc, dict) else None
        if not isinstance(fn, dict):
            fn = {}
        name = fn.get("name") or "?"

        raw = fn.get("arguments") or "{}"
        if isinstance(raw, dict):
            # Arguments stored already decoded: use them as they are instead
            # of letting json.loads fail and losing every key argument.
            args = raw
        elif isinstance(raw, (str, bytes, bytearray)):
            try:
                args = json.loads(raw)
            except (ValueError, RecursionError):
                args = {}
        else:
            args = {}
        # Valid JSON that is not an object (string, list, number) carries no
        # named arguments to fingerprint.
        if not isinstance(args, dict):
            args = {}

        key_bits = [
            f"{k}={args[k][:60]}"
            for k in ("script_path", "path", "file_path", "old_str", "command")
            if isinstance(args.get(k), str)
        ]
        code = args.get("code")
        if isinstance(code, str):
            # Code bodies are long; identify them by a short hash and length.
            h = hashlib.sha1(code.encode("utf-8")).hexdigest()[:8]
            key_bits.append(f"code#{h}(len={len(code)})")

        label = str(name) + (" | " + " ".join(key_bits) if key_bits else "")
        out.append(label)
    return out
def overview():
    """Print a repetition overview for the 15 tasks with the most messages.

    Every assistant tool call of a task is reduced to a fingerprint label via
    ``fingerprint`` and one row per task is printed with:

    * ``n_msg``   - number of messages in the task
    * ``repeat%`` - share of tool calls whose fingerprint occurs more than once
    * ``maxdup``  - highest occurrence count of any single fingerprint

    Tasks that contain no tool call at all are skipped. Returns None; all
    output goes to stdout.
    """
    sql = text(
        """
        SELECT t.task_id, t.name, t.skill, count(m.message_id) AS n_msg
        FROM tasks t JOIN messages m ON m.task_id = t.task_id
        GROUP BY t.task_id, t.name, t.skill
        ORDER BY n_msg DESC LIMIT 15
        """
    )
    # Built once; only the bound task id changes between iterations.
    payload_sql = text("SELECT payload FROM messages WHERE task_id=:t ORDER BY idx")

    with engine.connect() as conn:
        rows = conn.execute(sql).fetchall()
        print(f"{'n_msg':>6} {'repeat%':>7} {'maxdup':>6} skill / name / task_id")
        print("-" * 90)
        for r in rows:
            msgs = conn.execute(payload_sql, {"t": r.task_id}).fetchall()
            labels = []
            for (payload,) in msgs:
                labels.extend(fingerprint(payload))
            if not labels:
                continue

            c = Counter(labels)
            maxdup = max(c.values())
            # Count every call whose fingerprint shows up at least twice.
            repeated = sum(v for v in c.values() if v > 1)
            pct = 100 * repeated / len(labels)  # labels is non-empty here
            # name may be NULL just like skill; fall back to "-" instead of
            # failing on None[:28].
            print(
                f"{r.n_msg:>6} {pct:>6.0f}% {maxdup:>6} "
                f"{(r.skill or '-'):<10} {(r.name or '-')[:28]:<28} {str(r.task_id)[:8]}"
            )
def dump(task_id):
    """Print tool-call repetition details for a single task.

    Loads the task's messages ordered by ``idx``, turns them into fingerprint
    labels via ``fingerprint`` and prints the total number of tool calls, the
    20 most frequent fingerprints (those seen more than once are flagged), and
    every stretch where the same fingerprint occurs two or more times in a row.

    Args:
        task_id: Value bound to the ``task_id`` filter of the messages query.
    """
    query = text("SELECT idx, payload FROM messages WHERE task_id=:t ORDER BY idx")
    with engine.connect() as conn:
        rows = conn.execute(query, {"t": task_id}).fetchall()

    # Flatten to the ordered list of labels across all messages.
    labels = [label for _, payload in rows for label in fingerprint(payload)]
    print(f"total tool_calls: {len(labels)}\n")

    print("=== 调用次数 TOP 20(同名同参指纹) ===")
    for label, times in Counter(labels).most_common(20):
        marker = "" if times <= 1 else " <<< 重复"
        print(f" {times:>3}x {label}{marker}")

    print("\n=== 连续重复段(同一指纹连着出现) ===")
    # Collapse the sequence into [label, length] streaks of adjacent equals.
    streaks = []
    for label in labels:
        if streaks and streaks[-1][0] == label:
            streaks[-1][1] += 1
        else:
            streaks.append([label, 1])
    for label, length in streaks:
        if length >= 2:
            print(f" 连调 {length}x: {label}")
if __name__ == "__main__":
    # A first command-line argument selects the single-task report;
    # with no argument the top-task overview is printed.
    cli_args = sys.argv[1:]
    if not cli_args:
        overview()
    else:
        dump(cli_args[0])