"""诊断:找高轮数 task,统计 tool_call 序列里的重复模式。 用法:.venv/Scripts/python.exe scripts/diag_tool_repeat.py [task_id] 不带 task_id → 列出 message 数最高的 15 个 task + 各自重复度概览。 带 task_id → dump 该 task 的完整 tool_call 序列(name + 参数指纹)。 """ import hashlib import json import os import sys from collections import Counter from pathlib import Path # 读 .env 里的 ZCBOT_DB_URL env = Path(__file__).resolve().parent.parent / ".env" if env.is_file(): for line in env.read_text(encoding="utf-8").splitlines(): line = line.strip() if line.startswith("ZCBOT_DB_URL=") and not line.startswith("#"): os.environ["ZCBOT_DB_URL"] = line.split("=", 1)[1].strip() from sqlalchemy import create_engine, text # noqa: E402 engine = create_engine(os.environ["ZCBOT_DB_URL"]) def fingerprint(payload): """返回 (kind, label) —— assistant 的每个 tool_call 拆成一条。 label = tool 名 + 关键参数指纹(script_path / 文件路径 / code 的 hash 前 8 位)。 """ out = [] if payload.get("role") != "assistant": return out for tc in payload.get("tool_calls") or []: fn = (tc.get("function") or {}) name = fn.get("name") or "?" raw = fn.get("arguments") or "{}" try: args = json.loads(raw) except Exception: args = {} key_bits = [] for k in ("script_path", "path", "file_path", "old_str", "command"): if k in args and isinstance(args[k], str): key_bits.append(f"{k}={args[k][:60]}") if "code" in args and isinstance(args["code"], str): h = hashlib.sha1(args["code"].encode("utf-8")).hexdigest()[:8] key_bits.append(f"code#{h}(len={len(args['code'])})") label = name + (" | " + " ".join(key_bits) if key_bits else "") out.append(label) return out def overview(): sql = text( """ SELECT t.task_id, t.name, t.skill, count(m.message_id) AS n_msg FROM tasks t JOIN messages m ON m.task_id = t.task_id GROUP BY t.task_id, t.name, t.skill ORDER BY n_msg DESC LIMIT 15 """ ) with engine.connect() as conn: rows = conn.execute(sql).fetchall() print(f"{'n_msg':>6} {'repeat%':>7} {'maxdup':>6} skill / name / task_id") print("-" * 90) for r in rows: labels = [] msgs = conn.execute( text("SELECT payload FROM messages WHERE task_id=:t ORDER BY idx"), {"t": r.task_id}, ).fetchall() for (payload,) in msgs: labels.extend(fingerprint(payload)) if not labels: continue c = Counter(labels) maxdup = max(c.values()) repeated = sum(v for v in c.values() if v > 1) pct = 100 * repeated / len(labels) if labels else 0 print( f"{r.n_msg:>6} {pct:>6.0f}% {maxdup:>6} " f"{(r.skill or '-'):<10} {r.name[:28]:<28} {str(r.task_id)[:8]}" ) def dump(task_id): with engine.connect() as conn: msgs = conn.execute( text("SELECT idx, payload FROM messages WHERE task_id=:t ORDER BY idx"), {"t": task_id}, ).fetchall() seq = [] for idx, payload in msgs: for label in fingerprint(payload): seq.append((idx, label)) print(f"total tool_calls: {len(seq)}\n") c = Counter(label for _, label in seq) print("=== 调用次数 TOP 20(同名同参指纹) ===") for label, n in c.most_common(20): flag = " <<< 重复" if n > 1 else "" print(f" {n:>3}x {label}{flag}") print("\n=== 连续重复段(同一指纹连着出现) ===") prev, run = None, 0 for idx, label in seq: if label == prev: run += 1 else: if run >= 2: print(f" 连调 {run}x: {prev}") prev, run = label, 1 if run >= 2: print(f" 连调 {run}x: {prev}") if __name__ == "__main__": if len(sys.argv) > 1: dump(sys.argv[1]) else: overview()