"""Diagnostic: find high-turn tasks and count repeated patterns in their tool_call sequences.

Usage: .venv/Scripts/python.exe scripts/diag_tool_repeat.py [task_id]

Without task_id → list the 15 tasks with the most messages, each with a
repetition overview.

With task_id → dump that task's full tool_call sequence (name + argument
fingerprint).
"""

import hashlib
import json
import os
import sys
from collections import Counter
from itertools import groupby
from pathlib import Path

# Load ZCBOT_DB_URL from the .env one directory above this script's folder
# (the repo root when run as scripts/diag_tool_repeat.py). A value found in
# .env overrides one already present in the environment; the last matching
# line wins.
env = Path(__file__).resolve().parent.parent / ".env"
if env.is_file():
    # utf-8-sig: editors on Windows often save .env with a BOM, which would
    # otherwise glue "\ufeff" onto the first key and make it never match.
    for line in env.read_text(encoding="utf-8-sig").splitlines():
        line = line.strip()
        # Comment lines start with "#", so they can never match this prefix.
        if line.startswith("ZCBOT_DB_URL="):
            value = line.split("=", 1)[1].strip()
            # .env values are commonly quoted: ZCBOT_DB_URL="postgresql://..."
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            os.environ["ZCBOT_DB_URL"] = value

from sqlalchemy import create_engine, text  # noqa: E402

db_url = os.environ.get("ZCBOT_DB_URL")
if not db_url:
    # Fail with a readable message instead of a bare KeyError traceback.
    sys.exit("ZCBOT_DB_URL is not set (neither in the environment nor in .env)")
engine = create_engine(db_url)

def _as_json_object(raw):
    """Return *raw* as a dict, or ``{}`` when it is not usable as one.

    A dict is returned as-is; str/bytes are decoded as JSON. Anything else,
    including missing values, malformed JSON, and JSON that is not an object
    (``null``, a list, a bare string), maps to ``{}``.
    """
    if isinstance(raw, dict):
        return raw
    if not raw:
        return {}
    try:
        decoded = json.loads(raw)
    except (TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
        return {}
    return decoded if isinstance(decoded, dict) else {}


def _sha8(value):
    """First 8 hex chars of the SHA-1 of *value* (UTF-8; lone surrogates kept)."""
    return hashlib.sha1(value.encode("utf-8", "surrogatepass")).hexdigest()[:8]


def _preview(value, max_len):
    """Cut *value* to *max_len* chars, tagging cut values with their hash.

    The hash keeps two long values that merely share a prefix from collapsing
    into the same fingerprint, which would be reported as a false repeat.
    """
    if len(value) <= max_len:
        return value
    return f"{value[:max_len]}...#{_sha8(value)}"


def fingerprint(payload, max_len=60):
    """Return one fingerprint label per tool call of an assistant message.

    Each label is the tool name, optionally followed by `` | `` and a
    space-separated list of key-argument fingerprints:

    * ``script_path`` / ``path`` / ``file_path`` / ``old_str`` / ``command``
      (string values only) -> ``key=<value>``. Values longer than *max_len*
      are cut and suffixed with ``...#<sha1[:8]>`` of the full value.
    * ``code`` (string) -> ``code#<sha1[:8]>(len=<n>)``.

    Args:
        payload: one message payload, either a dict or its JSON text (with a
            raw ``text()`` query, some DB drivers return the JSON column
            undecoded). Only payloads whose ``role`` is ``"assistant"``
            produce labels.
        max_len: number of characters of each string argument to show.

    Returns:
        A list of label strings in tool-call order. It is empty for
        non-assistant or unusable payloads.
    """
    payload = _as_json_object(payload)
    if payload.get("role") != "assistant":
        return []
    labels = []
    for tc in payload.get("tool_calls") or []:
        fn = tc.get("function") or {}
        name = fn.get("name") or "?"
        # ``arguments`` is normally JSON text but may already be a dict.
        args = _as_json_object(fn.get("arguments"))
        bits = [
            f"{key}={_preview(args[key], max_len)}"
            for key in ("script_path", "path", "file_path", "old_str", "command")
            if isinstance(args.get(key), str)
        ]
        code = args.get("code")
        if isinstance(code, str):
            # Code bodies are long, so they are identified by hash + length.
            bits.append(f"code#{_sha8(code)}(len={len(code)})")
        labels.append(f"{name} | {' '.join(bits)}" if bits else name)
    return labels

def overview(limit=15):
    """Print a repetition overview of the *limit* tasks with the most messages.

    One row per task:
      * ``n_msg`` — the number of messages.
      * ``repeat%`` — the share of the task's tool calls whose fingerprint
        occurs more than once.
      * ``maxdup`` — the highest count of any single fingerprint.
      * skill and name.
      * the full task_id, printed untruncated so it can be passed straight
        to ``dump``.

    Tasks without any tool call are still listed, with ``-`` in the two
    repeat columns.

    Args:
        limit: how many tasks to list (default 15).
    """
    top_tasks_sql = text(
        """
        SELECT t.task_id, t.name, t.skill, count(m.message_id) AS n_msg
        FROM tasks t JOIN messages m ON m.task_id = t.task_id
        GROUP BY t.task_id, t.name, t.skill
        ORDER BY n_msg DESC LIMIT :limit
        """
    )
    payloads_sql = text("SELECT payload FROM messages WHERE task_id=:t ORDER BY idx")
    with engine.connect() as conn:
        rows = conn.execute(top_tasks_sql, {"limit": int(limit)}).fetchall()
        print(f"{'n_msg':>6} {'repeat%':>7} {'maxdup':>6} skill / name / task_id")
        print("-" * 90)
        for r in rows:
            labels = []
            for (payload,) in conn.execute(payloads_sql, {"t": r.task_id}).fetchall():
                labels.extend(fingerprint(payload))
            if labels:
                counts = Counter(labels)
                repeated = sum(n for n in counts.values() if n > 1)
                pct = f"{100 * repeated / len(labels):>6.0f}%"
                maxdup = f"{max(counts.values()):>6}"
            else:
                # No tool calls: the repeat stats are undefined, not zero.
                pct, maxdup = f"{'-':>7}", f"{'-':>6}"
            print(
                f"{r.n_msg:>6} {pct} {maxdup} "
                f"{(r.skill or '-'):<10} {(r.name or '-')[:28]:<28} {r.task_id}"
            )

def dump(task_id):
    """Print one task's tool-call fingerprint sequence plus repeat statistics.

    Sections printed:
      1. the total number of tool calls;
      2. the full sequence, one ``idx  label`` line per call. ``idx`` is the
         message's ``idx`` column, so calls from the same message share it;
      3. the 20 most frequent fingerprints, flagging any seen more than once;
      4. every run of one fingerprint repeated back-to-back (at least 2
         times), with the idx range it spans.

    Args:
        task_id: matched by equality against ``messages.task_id``. It must be
            the full id, not a prefix.
    """
    with engine.connect() as conn:
        msgs = conn.execute(
            text("SELECT idx, payload FROM messages WHERE task_id=:t ORDER BY idx"),
            {"t": task_id},
        ).fetchall()
    if not msgs:
        # Otherwise an unknown/mistyped task_id looks like "0 tool calls".
        print(f"no messages found for task_id={task_id!r}")
        return

    seq = [(idx, label) for idx, payload in msgs for label in fingerprint(payload)]
    print(f"total tool_calls: {len(seq)}\n")

    print("=== 完整序列(idx  指纹) ===")
    for idx, label in seq:
        print(f"  {idx!s:>5}  {label}")

    print()
    counts = Counter(label for _, label in seq)
    print("=== 调用次数 TOP 20(同名同参指纹) ===")
    for label, n in counts.most_common(20):
        flag = " <<< 重复" if n > 1 else ""
        print(f"  {n:>3}x {label}{flag}")

    print("\n=== 连续重复段(同一指纹连着出现) ===")
    # groupby only merges *adjacent* equal labels, which is exactly a run.
    for label, group in groupby(seq, key=lambda item: item[1]):
        run = list(group)
        if len(run) >= 2:
            print(f"  连调 {len(run)}x (idx {run[0][0]}-{run[-1][0]}): {label}")

if __name__ == "__main__":
    # Optional first CLI argument: a task_id to dump. Without it, show the overview.
    cli_args = sys.argv[1:]
    if not cli_args:
        overview()
    else:
        dump(cli_args[0])