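diff --git a/PROGRESS.md b/PROGRESS.md
index 7617405..695f55f 100644
--- a/PROGRESS.md
+++ b/PROGRESS.md
@@ -23,6 +23,8 @@
 ### 2026-06-08
+- **Add a pathological-repeat guard to the loop (remedy 1, targeting root causes ①② of "keeps calling the same script")**: Follow-up to the batching diagnosis. DB measurements show that the largest source of waste in high-turn tasks is "same name, same args, no output" repetition (`document_search` 122 times, empty `shell{}` 51 times, repeated `glob` on the same nonexistent path), while the main loop in `core/loop.py` had **zero protection** and accepted every call. New `_RepeatGuard` (held by the AgentLoop instance; lives within a single run, never across tasks): it tracks an "unproductive repeat" counter per `(tool name, exact args as canonical JSON)` fingerprint. **The crux is to penalize only unproductive repeats and never hurt normal iteration**: same args with a **different result each time** (rerunning run_python after editing the script, rerunning a build after fixing a bug) counts as productive, resets the counter, and is never blocked; only same args with a **result that is `[Error]` or identical to an earlier one** accumulate. Two tiers: once the count reaches `SOFT` (2), a `[重复调用警告]` soft hint is appended to the tool result (the model sees it in the same turn); once it reaches `HARD` (4), the next call with the same args is intercepted by `should_block` without executing and gets a `[已拦截重复调用]` hard-stop message that forces a change of approach (a stuck call gets through at most ~4 unproductive repeats). **Also plugs the `_malformed_tool_calls` hole**: when large malformed arguments degrade into a valid empty `{}`, the executor returns the same "missing required parameter" message every time, so the call takes the dup branch and is stopped by the same mechanism; no special case for empty `{}` is needed. Wiring in `_execute_tool_call`: `should_block` intercepts before execution; after execution `record` fingerprints the **truncated raw result before any hint is appended** (so identical outputs hash identically); `warn` events surface interceptions and the first soft hint. Changed `core/loop.py`; added `tests/test_loop_repeat_guard.py` (7 cases: identical-error blocking / empty-`{}` hole / identical-result blocking / changing results not blocked / reset after a fix / SOFT threshold / distinct args tracked separately; all pass). **Note**: the thresholds are constants (SOFT/HARD) so they can be tuned against real runs later; remedy 3 (whether the phantom path `/home/ubuntu/zcbot` still reproduces in new tasks) remains uninvestigated.
+- **Batch the retrieval/fetch host tools (cuts token burn in high-turn tasks; the "web_fetch-style" remedy, done first)**: DB diagnosis (`scripts/diag_*.py`) of the tool_call sequences in high-turn tasks: coding task `ff1686b7` called `document_search` **122 times** (104 distinct queries; a carpet search with repeatedly tweaked keywords that never converged) and `document_download` 28 times; documents task `ab063233` had **errors in 64% of tool results** plus a storm of empty-argument `shell{}`/`run_python{}` calls. "Keeps calling scripts" turns out to be three stacked root causes (empty-`{}` storm / retry on error / non-converging retrieval), and the round-trip cost of retrieval and fetching can be cut directly by **reshaping the tools**. This round changes three host tools from "one item per call" to **accepting a list and processing a batch concurrently in one turn** (per "no compatibility layers during development", the signatures are replaced outright with no singular aliases): ① `web_fetch` `url`→`urls` (1-10, ThreadPoolExecutor with 6 workers, a total body budget of 16000 characters for the batch split across items, a single SSRF/timeout/404 does not fail the rest); ② `document_search` `query`→`queries` (1-8, **deduplicated within the batch**, with `max_documents`/`content_chars_per_doc` shrunk automatically when batching to avoid blowing up the context, each query in its own try); ③ `document_download` `file_name+kb_name`→`items=[{...}]` (1-10 concurrent, a failed item is marked `[Error]` without ruining the batch). All three label their output per item as `=== [i/n] ... 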
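===`, and truncation of over-limit input is stated explicitly rather than applied silently. Changed `tools/{web_fetch,documents}.py`; `tests/test_secret_host_tools.py` updated to the new shapes, plus 3 cases for batch concurrency / dedup / failure isolation (5 pass); `skills/documents/SKILL.md` signatures / workflow / anti-patterns updated (two new anti-patterns, "searching one query per turn over and over" and "stacking near-duplicate synonym queries", matching diagnosis ③). `DOCUMENT_SEARCH_API.md` (the upstream HTTP API, which takes a single query per request anyway) is unchanged. **Not touched**: remedy 1 (a `(name,args)` repeat detector in the loop plus plugging the empty-`{}` hole; highest leverage but touches the core loop) and remedy 3 (whether the phantom path `/home/ubuntu/zcbot` still reproduces in new tasks); see memory `project_high_turn_token_burn_root_causes`; left for later.
 - **ppt skill visual system upgraded to a "card style" (fixes "the generated output doesn't look great")**: After studying GitHub `hugohe3/ppt-master` (24.9k★), the root cause was located: what makes its output look good is "draw in SVG → convert to native PPTX", which gives ample design freedom, whereas zcbot's visual ceiling was pinned down by python-pptx **primitives** (flat rectangle + left color bar + round bullets), so results looked like a "2010 office template". Three options at the fork (A: build our own SVG→pptx converter, the highest ceiling but a large project that conflicts with "one script builds everything / fewer round trips"; B: upgrade the python-pptx design system; C: hybrid); **B was chosen** (keeps the single-script batch architecture, stays natively editable, low risk). Implementation: ① `pptx_helpers.py` gains textured components: `add_card` (rounded rectangle, corner radius via `adjustments[0]`, soft drop shadow via `a:outerShdw` XML) / `add_gradient_rect` (`fill.gradient()` + angle) / `add_kpi` (number card) / `add_icon_tile` (icon backing tile) / `add_pill` / `add_eyebrow` / `add_chevron` / `add_notes` (speaker notes); `set_palette` **derives light and dark shades** `PRIMARY_WASH/SOFT/DARK` + `ACCENT_SOFT` from the primary/secondary/accent colors; `apply_brand` switches cover and section pages to **large gradient color blocks**; **every helper now writes `name=` into the shape's `.name`** (previously it was only passed to assert_inside, so quality_check could not see semantic names). ② The 9 layouts in `layouts.md` are rewritten in card style and extended to **13** (adding L10 KPI cards / L11 card grid / L12 process / L13 big-number argument). ③ **quality_check aligned with the new design language** (otherwise every deck drowns in false warnings): the three-color rule now judges **by hue bucket** (shades of the primary color and wash tints do not count as new colors); the small-font and bullet checks exempt label-type shapes by `.name`; large display text (≥40pt) skips overflow estimation; the bullet ≤5 rule is judged **per column** (two columns of 3+3 no longer raise a false alarm, a single column of 6 is still caught). ④ The SKILL.md workflow adds opt-in real images (via imagegen, ¥0.22 per image, marked `[img]` in the outline) plus `add_notes` on every page; `design_principles.md` adds derived shades / KPI cards / card-style charts on transparent backgrounds. Verification: a demo deck covering all 13 layouts was built and passes quality_check in full; the single-column 6-bullet regression still triggers. Changed `skills/ppt/{SKILL.md,scripts/pptx_helpers.py,scripts/quality_check.py,references/layouts.md,references/design_principles.md}` + `SKILL_LIST.md`. **Not touched**: the SVG route (A), live preview, animation; these are larger projects and out of scope for this round.
 - **system prompt gains a global "fewer round trips" principle (broad-spectrum turn reduction)**: Long-tail tasks outside ppt (editing code / crunching data / plotting) have no dedicated skill as a backstop, so a general `工作原则` (working principle) is added: operations that are mutually independent and do not depend on intermediate results (building multi-page artifacts / batch-editing files / generating a whole artifact) should be combined into one script or one turn of concurrent tool calls rather than one call per step (each turn resends the whole context, so the turn count is a linear multiplier on token volume); but when the next step's input depends on the previous step's result (exploratory retrieval / fixing according to an error / needing user confirmation), proceed step by step and do not force batching. The wording is deliberately precise so that "over-batching" does not trample a needed 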
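checkpoint. This is positioned as a cheap supplement (the prompt is served from cache at near-zero cost) and is not expected to move the bulk of the 100+-turn tasks (that takes structural changes). Changed `prompts/system/general_v1.md`.
 - **ppt skill workflow batched (fewer round trips in high-turn tasks)**: Measurements show that high-cost tasks are almost all 100+-turn "step-by-step tool call loops" (rust→PPT 34 turns, literature collection 245 turns); each turn resends the whole context, so the turn count is a linear multiplier on token volume. ppt is the easiest to compress and the lowest-risk pilot: the original §Phase 2 worked **page by page** (per page: `read spec → glob icons → one run_python to add the page → wait for user confirmation → next page`; N pages took ~2N turns). Changes: ① the Phase 1 spec gains a "per-page outline" table (page | layout | title | key points | icons) that serves as an **up-front checkpoint replacing per-page confirmation**, since revising a text outline is cheaper than overturning finished slides; ② Phase 2 becomes **writing one `build_deck.py` that builds the whole deck in one go** (same process: `new_presentation` → loop `add_slide` over the outline → a single `save`, so coordinates are consistent by construction; the modular `pptx_helpers` already removes the original "prevent per-page drift" rationale), with icons **prefetched in bulk for the whole deck** (not pulled page by page); ③ quality_check runs once → edit the script and rerun (do not edit the finished artifact); ④ an optional "style probe" (build the cover plus 1 page first to check the look) covers the risk of visual rework. N pages drop from ~2N turns to ~3-4 turns. Changed `skills/ppt/SKILL.md` (Phases 1/2/3 + anti-patterns + file tree), `references/layouts.md` (§general starter replaced with a "whole deck in a single script" template), `SKILL_LIST.md` (flow description / typical artifacts synced). Smoke test passed: a single script with `new_presentation` + looped `add_slide` + `save` built 2 pages successfully, with API calls matching the template. **Note**: data-collection tasks (which cannot avoid relaying through host tools) are a separate track (splitting the collection and processing phases) and are untouched.
diff --git a/core/loop.py b/core/loop.py
index f39c8d3..1bc062e 100644
--- a/core/loop.py
+++ b/core/loop.py
@@ -9,10 +9,11 @@
 content deltas emit `text` events immediately so the frontend can render typewriter-style; once the chunks are collected
 """
 from __future__ import annotations
+import hashlib
 import json
 import time
 from pathlib import Path
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from uuid import UUID
@@ -29,6 +30,75 @@
 from .storage import record_chat_usage
 _CANCELLED_TOOL_PLACEHOLDER = "[cancelled by user]"
+class _RepeatGuard:
+    """Detect pathological "same name, same args, no output" repeat calls and break the dead loop.
+
+    Background (DB measurements, 2026-06-08): in high-turn, token-burning tasks a single tool is
+    called with **exactly the same arguments** dozens or hundreds of times (`document_search`
+    122 times, empty-argument `shell{}` 51 times, repeated `glob` on the same nonexistent path).
+    The loop had no protection and accepted every call until it hit max_iterations.
+
+    The crux is to penalize only "unproductive" repeats and never hurt normal iteration:
+    - Same args but **a different result each time** (rerunning run_python after editing the
+      script, rerunning a build after fixing a bug) → productive, counter reset, never blocked.
+    - Same args and **a result that is `[Error]` or identical to some earlier one** (empty `{}`
+      missing parameters, hitting the same error over and over) → unproductive, accumulates.
+    Count >= SOFT injects a soft hint (the model sees it in the same turn); >= HARD blocks the
+    call without executing it, forcing a different approach.
+
+    Also plugs the `_malformed_tool_calls` hole: when large malformed arguments degrade into a
+    valid empty `{}`, the executor returns the same "missing required parameter" message every
+    time → the call takes the dup branch and is stopped by this same mechanism, with no special
+    case for empty `{}`.
+
+    State lives within a single task run 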
(held by the AgentLoop instance) and is never shared across tasks.
+    """
+
+    SOFT = 2  # unproductive repeats >= SOFT: append a soft hint to the result
+    HARD = 4  # unproductive repeats >= HARD: the next call with the same args is blocked without executing
+
+    def __init__(self) -> None:
+        # key -> {"hashes": set[str], "unproductive": int, "n": int, "blocked": int}
+        self._h: Dict[str, dict] = {}
+
+    @staticmethod
+    def _key(name: str, args: Any) -> str:
+        try:
+            canon = json.dumps(args, sort_keys=True, ensure_ascii=False)
+        except (TypeError, ValueError):
+            canon = repr(args)
+        return name + "\x00" + canon
+
+    def _state(self, name: str, args: Any) -> dict:
+        return self._h.setdefault(
+            self._key(name, args),
+            {"hashes": set(), "unproductive": 0, "n": 0, "blocked": 0},
+        )
+
+    def should_block(self, name: str, args: Any) -> bool:
+        """Call before executing: this fingerprint has accumulated >= HARD unproductive repeats → block (do not execute)."""
+        st = self._h.get(self._key(name, args))
+        return bool(st and st["unproductive"] >= self.HARD)
+
+    def register_block(self, name: str, args: Any) -> Tuple[int, int]:
+        """Record one block; return (executions so far n, cumulative blocks blocked)."""
+        st = self._state(name, args)
+        st["blocked"] += 1
+        return st["n"], st["blocked"]
+
+    def record(self, name: str, args: Any, result: str) -> int:
+        """Call after executing: register the result and return this fingerprint's current unproductive-repeat count."""
+        st = self._state(name, args)
+        h = hashlib.sha1(result.encode("utf-8", "replace")).hexdigest()
+        is_err = result.lstrip().startswith("[Error")
+        dup = h in st["hashes"]
+        if st["n"] >= 1:
+            if is_err or dup:
+                st["unproductive"] += 1
+            else:
+                # A new non-error result is productive → reset, so normal iteration never accumulates into a block
+                st["unproductive"] = 0
+        st["hashes"].add(h)
+        st["n"] += 1
+        return st["unproductive"]
+
+
 def _extract_delta_content(chunk: Any) -> Optional[str]:
     """Extract delta.content (a text fragment) from a stream chunk. Chunk shape is litellm ModelResponseStream:
     choices[0].delta.content. A usage-only closing chunk (no choices / delta) returns None.
@@ -146,6 +216,8 @@ class AgentLoop:
         # ③ between tool_calls. Polling between chunks cuts cancel latency from "the duration of a whole generation"
         # (tens of seconds) to "a single chunk interval" (~100ms).
         self.cancel_check = cancel_check
+        # Pathological-repeat guard (same name, same args, no output); lives within this run, never across tasks.
+        self._repeat_guard = _RepeatGuard()
     def _emit(self, 
event: dict) -> None: if self.sink is not None: @@ -358,6 +430,24 @@ class AgentLoop: "args_preview": args_preview, }) + # 病理性重复拦截:同参已累计 HARD 次无产出重复 → 不执行,回硬停消息逼模型换路。 + if self._repeat_guard.should_block(name, args): + n, blocked = self._repeat_guard.register_block(name, args) + result = ( + f"[已拦截重复调用] {name} 用完全相同的参数已调用 {n} 次且结果始终未变,本次未执行。" + "这通常意味着思路卡死:① 换不同的参数或方法;② 读一下相关文件/报错重新定位;" + "③ 若确实推进不了,停下来如实告诉用户卡在哪、缺什么。不要再用相同参数重试。" + ) + self._emit({"type": "warn", "msg": f"拦截重复调用 {name}(同参第 {n} 次、结果未变)"}) + self._emit({ + "type": "tool_result", + "name": name, + "result": result, + "preview": result, + "truncated": False, + }) + return result + ctx = ExecCtx( user_id=self.user_id, task_id=self.session.task_id, @@ -373,6 +463,16 @@ class AgentLoop: result = result[:MAX_LEN] + f"\n[... truncated, {len(result) - MAX_LEN} chars ...]" truncated = True + # 登记结果做重复检测(用截断后、未加提示的原始结果算指纹,保证同输出哈希一致)。 + unproductive = self._repeat_guard.record(name, args, result) + if unproductive >= _RepeatGuard.SOFT: + if unproductive == _RepeatGuard.SOFT: + self._emit({"type": "warn", "msg": f"{name} 同参重复且结果未变({unproductive} 次),已提示模型换路"}) + result += ( + f"\n\n[重复调用警告] 你已用完全相同的参数调用 {name} {unproductive + 1} 次、结果没有变化。" + "再原样重调不会有新结果——换参数/换工具/换思路,或停下来向用户说明卡在哪。" + ) + preview = result if len(result) < 400 else result[:400] + "..." 
         self._emit({
             "type": "tool_result",
diff --git a/scripts/diag_error_retry.py b/scripts/diag_error_retry.py
new file mode 100644
index 0000000..45338f2
--- /dev/null
+++ b/scripts/diag_error_retry.py
@@ -0,0 +1,67 @@
+"""For one task, check: ① how many tool results contain [Error] / failures; ② whether repeated calls follow errors."""
+import json
+import os
+import sys
+from collections import Counter
+from pathlib import Path
+
+env = Path(__file__).resolve().parent.parent / ".env"
+for line in env.read_text(encoding="utf-8").splitlines():
+    if line.strip().startswith("ZCBOT_DB_URL="):
+        os.environ["ZCBOT_DB_URL"] = line.split("=", 1)[1].strip()
+from sqlalchemy import create_engine, text  # noqa: E402
+
+engine = create_engine(os.environ["ZCBOT_DB_URL"])
+prefix = sys.argv[1] if len(sys.argv) > 1 else "ab063233"
+
+with engine.connect() as conn:
+    tid = conn.execute(
+        text("select task_id from tasks where task_id::text like :p"), {"p": prefix + "%"}
+    ).fetchone()[0]
+    msgs = conn.execute(
+        text("select idx, payload from messages where task_id=:t order by idx"),
+        {"t": tid},
+    ).fetchall()
+
+# Collect tool result texts (role=tool) and their names
+results = {}  # idx -> (name, text)
+calls = []  # (idx, name, args_fingerprint)
+for idx, payload in msgs:
+    role = payload.get("role")
+    if role == "tool":
+        results[idx] = (payload.get("name"), str(payload.get("content") or ""))
+    elif role == "assistant":
+        for tc in payload.get("tool_calls") or []:
+            fn = tc.get("function") or {}
+            try:
+                args = json.loads(fn.get("arguments") or "{}")
+            except Exception:
+                args = {}
+            fp = (fn.get("name") or "?") + "|" + json.dumps(args, ensure_ascii=False, sort_keys=True)  # a missing name must not raise TypeError
+            calls.append((idx, fn.get("name"), fp))
+
+n_tool = len(results)
+n_err = sum(1 for _, (_, t) in results.items() if "[Error" in t or "Traceback" in t or "exit 1" in t or "[stderr]" in t)
+print(f"task {tid}")
+print(f"tool 结果总数: {n_tool} 含错误/stderr/exit1: {n_err} ({100*n_err/max(n_tool,1):.0f}%)\n")
+
+# Repeats of exactly the same name+args fingerprint
+c = Counter(fp for _, _, fp in calls)
+exactdup = [(fp, n) for fp, n in 
c.most_common() if n > 1] +print(f"完全同名同参(含全部参数)的调用指纹: 重复 {len(exactdup)} 种") +print("=== 同名同参重复 TOP 10(连参数都一字不差) ===") +for fp, n in c.most_common(10): + if n > 1: + name, _, rest = fp.partition("|") + print(f" {n:>3}x {name}: {rest[:70]}") + +# 错误样本 +print("\n=== 前 5 条错误结果样本 ===") +shown = 0 +for idx in sorted(results): + name, t = results[idx] + if any(k in t for k in ("[Error", "Traceback", "exit 1", "[stderr]")): + print(f" [{idx}] {name}: {t[:160].strip()}") + shown += 1 + if shown >= 5: + break diff --git a/scripts/diag_search_args.py b/scripts/diag_search_args.py new file mode 100644 index 0000000..1f9fde5 --- /dev/null +++ b/scripts/diag_search_args.py @@ -0,0 +1,62 @@ +"""看某 task 里 document_search / document_download 的真实参数序列, +判断是「同 query 反复」(病A) 还是「不同 query 地毯式」(病B)。""" +import json +import os +import sys +from pathlib import Path + +env = Path(__file__).resolve().parent.parent / ".env" +for line in env.read_text(encoding="utf-8").splitlines(): + if line.strip().startswith("ZCBOT_DB_URL="): + os.environ["ZCBOT_DB_URL"] = line.split("=", 1)[1].strip() + +from sqlalchemy import create_engine, text # noqa: E402 + +engine = create_engine(os.environ["ZCBOT_DB_URL"]) +prefix = sys.argv[1] if len(sys.argv) > 1 else "ff1686b7" +watch = sys.argv[2] if len(sys.argv) > 2 else "document_search" + +with engine.connect() as conn: + tid = conn.execute( + text("select task_id from tasks where task_id::text like :p"), + {"p": prefix + "%"}, + ).fetchone()[0] + msgs = conn.execute( + text("select idx, payload from messages where task_id=:t order by idx"), + {"t": tid}, + ).fetchall() + +seq = [] +for idx, payload in msgs: + if payload.get("role") != "assistant": + continue + for tc in payload.get("tool_calls") or []: + fn = tc.get("function") or {} + if fn.get("name") != watch: + continue + try: + args = json.loads(fn.get("arguments") or "{}") + except Exception: + args = {"": fn.get("arguments")} + seq.append((idx, args)) + +print(f"task {tid} — {watch}: {len(seq)} 次\n") 
+from collections import Counter  # noqa: E402
+
+# Use the query / key field as the key to look for repeats
+keys = []
+for _, args in seq:
+    k = args.get("query") or args.get("keyword") or args.get("q") or json.dumps(args, ensure_ascii=False)
+    keys.append(k)
+c = Counter(keys)
+dup = [(k, n) for k, n in c.most_common() if n > 1]
+print(f"unique query: {len(c)} / total {len(keys)}")
+print(f"被重复的 query 数: {len(dup)}\n")
+print("=== 重复最多的 query TOP 15 ===")
+for k, n in c.most_common(15):
+    mark = " <<<同一query重复" if n > 1 else ""
+    print(f" {n:>3}x {str(k)[:80]}{mark}")
+print("\n=== 前 40 次调用的 query 顺序(看是不是连着搜同一个) ===")
+for i, (idx, args) in enumerate(seq[:40]):
+    k = args.get("query") or args.get("keyword") or args.get("q") or json.dumps(args, ensure_ascii=False)  # same key order as the loop above
+    print(f" [{idx:>4}] {str(k)[:80]}")
diff --git a/scripts/diag_tool_repeat.py b/scripts/diag_tool_repeat.py
new file mode 100644
index 0000000..46b3729
--- /dev/null
+++ b/scripts/diag_tool_repeat.py
@@ -0,0 +1,120 @@
+"""Diagnosis: find high-turn tasks and count repeat patterns in their tool_call sequences.
+
+Usage: .venv/Scripts/python.exe scripts/diag_tool_repeat.py [task_id]
+Without task_id → list the 15 tasks with the most messages plus a repetition overview for each.
+With task_id → dump that task's full tool_call sequence (name + argument fingerprint).
+"""
+import hashlib
+import json
+import os
+import sys
+from collections import Counter
+from pathlib import Path
+
+# Read ZCBOT_DB_URL from .env
+env = Path(__file__).resolve().parent.parent / ".env"
+if env.is_file():
+    for line in env.read_text(encoding="utf-8").splitlines():
+        line = line.strip()
+        if line.startswith("ZCBOT_DB_URL="):  # a commented-out line starts with "#", so it can never match
+            os.environ["ZCBOT_DB_URL"] = line.split("=", 1)[1].strip()
+
+from sqlalchemy import create_engine, text  # noqa: E402
+
+engine = create_engine(os.environ["ZCBOT_DB_URL"])
+
+
+def fingerprint(payload):
+    """Return a list of labels, one per tool_call of an assistant message (empty for other roles).
+    label = tool name + fingerprint of key arguments (script_path / file paths / first 8 hex chars of the code hash).
+    """
+    out = []
+    if payload.get("role") != "assistant":
+        return out
+    for tc in payload.get("tool_calls") or []:
+        fn = (tc.get("function") or 
{}) + name = fn.get("name") or "?" + raw = fn.get("arguments") or "{}" + try: + args = json.loads(raw) + except Exception: + args = {} + key_bits = [] + for k in ("script_path", "path", "file_path", "old_str", "command"): + if k in args and isinstance(args[k], str): + key_bits.append(f"{k}={args[k][:60]}") + if "code" in args and isinstance(args["code"], str): + h = hashlib.sha1(args["code"].encode("utf-8")).hexdigest()[:8] + key_bits.append(f"code#{h}(len={len(args['code'])})") + label = name + (" | " + " ".join(key_bits) if key_bits else "") + out.append(label) + return out + + +def overview(): + sql = text( + """ + SELECT t.task_id, t.name, t.skill, count(m.message_id) AS n_msg + FROM tasks t JOIN messages m ON m.task_id = t.task_id + GROUP BY t.task_id, t.name, t.skill + ORDER BY n_msg DESC LIMIT 15 + """ + ) + with engine.connect() as conn: + rows = conn.execute(sql).fetchall() + print(f"{'n_msg':>6} {'repeat%':>7} {'maxdup':>6} skill / name / task_id") + print("-" * 90) + for r in rows: + labels = [] + msgs = conn.execute( + text("SELECT payload FROM messages WHERE task_id=:t ORDER BY idx"), + {"t": r.task_id}, + ).fetchall() + for (payload,) in msgs: + labels.extend(fingerprint(payload)) + if not labels: + continue + c = Counter(labels) + maxdup = max(c.values()) + repeated = sum(v for v in c.values() if v > 1) + pct = 100 * repeated / len(labels) if labels else 0 + print( + f"{r.n_msg:>6} {pct:>6.0f}% {maxdup:>6} " + f"{(r.skill or '-'):<10} {r.name[:28]:<28} {str(r.task_id)[:8]}" + ) + + +def dump(task_id): + with engine.connect() as conn: + msgs = conn.execute( + text("SELECT idx, payload FROM messages WHERE task_id=:t ORDER BY idx"), + {"t": task_id}, + ).fetchall() + seq = [] + for idx, payload in msgs: + for label in fingerprint(payload): + seq.append((idx, label)) + print(f"total tool_calls: {len(seq)}\n") + c = Counter(label for _, label in seq) + print("=== 调用次数 TOP 20(同名同参指纹) ===") + for label, n in c.most_common(20): + flag = " <<< 重复" if n > 1 
else "" + print(f" {n:>3}x {label}{flag}") + print("\n=== 连续重复段(同一指纹连着出现) ===") + prev, run = None, 0 + for idx, label in seq: + if label == prev: + run += 1 + else: + if run >= 2: + print(f" 连调 {run}x: {prev}") + prev, run = label, 1 + if run >= 2: + print(f" 连调 {run}x: {prev}") + + +if __name__ == "__main__": + if len(sys.argv) > 1: + dump(sys.argv[1]) + else: + overview() diff --git a/skills/documents/SKILL.md b/skills/documents/SKILL.md index 08525d1..ba17f1d 100644 --- a/skills/documents/SKILL.md +++ b/skills/documents/SKILL.md @@ -29,15 +29,15 @@ description: 查内部材料学科知识库(document_search API,7 个学科:胶 **用途**:用户没指定库 → 先 `document_list_kb` 看有哪些库(中文名 `ch_name` 看分类),再选 `kb_names` / `classification_ids` 缩窄 search 范围。 -### `document_search(query, kb_names=None, classification_ids=None, max_documents=6, content_chars_per_doc=1200)` +### `document_search(queries, kb_names=None, classification_ids=None, max_documents=6, content_chars_per_doc=1200)` -搜文档,返回精简列表,每条带 **截断后的 `md_content`**。默认每篇 1200 字符,需要看更多时调大 `content_chars_per_doc`,上限 5000。 +搜文档,**一次可传多个 query 并发搜**,返回精简列表,每条带 **截断后的 `md_content`**。 -- `query`:搜索词。**中英文均可** —— 文档主体是英文学术论文,但 API 后端有跨语言语义检索;复杂技术术语用**英文**更精准(`cement hydration` > `水泥水化`),日常概念中文 OK -- `kb_names`:知识库白名单(从 `document_list_kb` 选);`None` 走 server 默认(单库 `mu_34_1740625285897` 胶凝)。**多库联查就显式传**,如 `kb_names=["mu_34_1740625285897", "mu_34_1740625303475"]` -- `classification_ids`:分类 ID 白名单(1-7,对应 7 个学科库);`None` 不过滤 -- `max_documents`:1-20,默认 6 -- `content_chars_per_doc`:每篇返回多少 Markdown 字符,默认 1200,最大 5000;不要一上来拉满 +- `queries`:**搜索词列表**(1-8 条)。把你这一轮想搜的所有不同 query 一次性传进来(`queries=["alkali-activated slag strength", "fly ash cement hydration", ...]`),**别一个 query 一轮 tool call** —— 反复来回每轮重发整段上下文,轮数是 token 体量的线性乘数。**中英文均可** —— 文档主体是英文学术论文,但 API 后端有跨语言语义检索;复杂技术术语用**英文**更精准(`cement hydration` > `水泥水化`),日常概念中文 OK。**批内自动去重**;**别堆一堆只差几个词的近义 query**(边际递减),先想清楚一组互不重叠的 query 再批量发 +- `kb_names`:知识库白名单(从 `document_list_kb` 选,对所有 query 生效);`None` 走 server 默认(单库 
`mu_34_1740625285897` 胶凝)。**多库联查就显式传**,如 `kb_names=["mu_34_1740625285897", "mu_34_1740625303475"]` +- `classification_ids`:分类 ID 白名单(1-7,对应 7 个学科库,对所有 query 生效);`None` 不过滤 +- `max_documents`:每个 query 返回几篇,1-20,默认 6(**批量多 query 时自动缩量**控制总输出) +- `content_chars_per_doc`:每篇返回多少 Markdown 字符,默认 1200,最大 5000(**批量多 query 时自动缩量**);不要一上来拉满 **学科库 → kb_name 速查**(`document_list_kb` 拿全量,这里只列常用): @@ -51,18 +51,21 @@ description: 查内部材料学科知识库(document_search API,7 个学科:胶 | 耐火材料 | `mu_34_1740625365079` | | 检验检测 | `mu_34_1740625376621` | -### `document_download(file_name, kb_name, preview=False)` +### `document_download(items)` -下载原始文档(PDF / Word / ...)到 `/documents/`,返回相对路径。已存在跳过下载直接复用。`file_name` 支持原始文件名(`example.pdf`)或 Markdown 名(`example.md`),server 自动回退。 +下载原始文档(PDF / Word / ...)到 `/documents/`,返回各自相对路径。**一次可传多个文档并发下载**,单条失败不连坐其余。已存在跳过下载直接复用。 + +- `items`:文档列表(1-10 条),每条 `{file_name, kb_name, preview?}`,如 `items=[{"file_name":"a.pdf","kb_name":"mu_34_1740625285897"}, {"file_name":"b.pdf","kb_name":"mu_34_1740625285897"}]`。**要下几篇就一次性列进来,别一篇一轮 tool call** +- `file_name` 支持原始文件名(`example.pdf`)或 Markdown 名(`example.md`),server 自动回退 ## 标准工作流 1. **(可选)`document_list_kb`** —— 用户没指定库 / 不确定分类时看一下有哪些 -2. **`document_search(query=..., max_documents=6)`** —— 中英文均可,专业技术术语优先英文 +2. **`document_search(queries=[...])`** —— 先规划好一组互不重叠的 query 一次性批量搜,中英文均可,专业技术术语优先英文 3. **看返回**: - 用 `file_name + character_count + md_content` 判断切题 - - 切题 → 直接用返回的 Markdown 摘要给 LLM 引用;需要更多上下文时提高 `content_chars_per_doc` 重搜 - - 需要看图表 / 表格原貌 / 给用户附件 → `document_download(file_name, kb_name)` 拿原文档,然后用主 agent 的 `read` 工具读(zcbot 已内置 PDF/Word 文本抽取) + - 切题 → 直接用返回的 Markdown 摘要给 LLM 引用;需要更多上下文时对**少数**命中文档提高 `content_chars_per_doc` 单独重搜 + - 需要看图表 / 表格原貌 / 给用户附件 → `document_download(items=[...])` 一次性批量拿原文档,然后用主 agent 的 `read` 工具读(zcbot 已内置 PDF/Word 文本抽取) 4. 
**写产出**:把 md_content 关键段落引到申报书 / 方案里,标注来源文件名 ## md_content 优先 vs 原件下载 @@ -86,8 +89,10 @@ description: 查内部材料学科知识库(document_search API,7 个学科:胶 - 用 `httpx` / `requests` 裸调 API(走 host tool,免得 base_url / auth / 字段名漂移时四处改,也避免 key 进入 sandbox) - `document_search(max_documents=20, content_chars_per_doc=5000)` 一次拉满(20 条直接爆 LLM 上下文)—— 先用默认值判断切题,只对少数命中文档加大 `content_chars_per_doc` +- **一个 query 一轮 tool call 地反复搜**(同一意图换着措辞搜十几遍)—— 这是最烧 token 的反模式:每轮重发整段上下文。改成**先列一组去重 query 一次 `queries=[...]` 批量发**;一批结果看完不够再发下一批,而不是一条一条挤 +- `document_download` 一篇一轮 tool call —— 把要下的都列进 `items=[...]` 一次下完 - 看到 md_content 切题还 `download` 一遍原件(md_content 已是 LLM 友好的 Markdown,大多数引用场景够用) - 凭 `ch_name`("胶凝材料学科知识库")就以为 query 要用中文 —— 文档主体是英文,复杂术语用英文更精准 - 编造 file_name / kb_name —— 不在 `document_list_kb` / `document_search` 返回里就**明确告诉用户"未命中"**,不要瞎传 ID - 把 `document_download` 返回的相对路径当绝对路径用(它是相对 task_dir 的) -- 尝试给 `document_download` 传 `working_dir`(tool 已绑定当前 task_dir,不要让模型指定路径) +- 尝试给 `document_download` 传 `working_dir`(tool 已绑定当前 task_dir,`items` 里只放 `file_name` / `kb_name`,不要让模型指定路径) diff --git a/tests/test_loop_repeat_guard.py b/tests/test_loop_repeat_guard.py new file mode 100644 index 0000000..2aaca5c --- /dev/null +++ b/tests/test_loop_repeat_guard.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import sys +import unittest +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from core.loop import _RepeatGuard # noqa: E402 + + +def _simulate(guard: _RepeatGuard, name: str, args, results: list[str]) -> list[str]: + """模拟 loop 逐次调用:先 should_block,未拦才 record。返回每次的判定标记。 + 'BLOCK' = 被拦截未执行;否则返回 'exec(unprod=N)'。 + """ + out = [] + for r in results: + if guard.should_block(name, args): + guard.register_block(name, args) + out.append("BLOCK") + continue + unprod = guard.record(name, args, r) + out.append(f"exec(unprod={unprod})") + return out + + +class TestRepeatGuard(unittest.TestCase): + def test_identical_error_repeats_get_blocked(self): + g = 
_RepeatGuard() + trace = _simulate(g, "glob", {"path": "/home/ubuntu/zcbot"}, ["[Error] base path not found"] * 8) + # 第一次执行无产出计 0,之后每次 +1;累计到 HARD 后拦截 + self.assertIn("BLOCK", trace) + self.assertTrue(g.should_block("glob", {"path": "/home/ubuntu/zcbot"})) + # 拦截前最多放过 HARD 次无产出重复(共 HARD+1 次执行) + n_exec = sum(1 for t in trace if t.startswith("exec")) + self.assertEqual(n_exec, _RepeatGuard.HARD + 1) + + def test_empty_args_error_storm_blocked(self): + """空 {} 缺参风暴:executor 每次返回同一句错误 → 被同一机制拦下(堵 malformed 洞)。""" + g = _RepeatGuard() + trace = _simulate(g, "shell", {}, ["[Error] 缺少必填参数 [command]"] * 7) + self.assertIn("BLOCK", trace) + + def test_identical_nonerror_result_blocked(self): + """同参且结果一字不差(非错误)也算无产出 → 拦截。""" + g = _RepeatGuard() + trace = _simulate(g, "read", {"path": "a.txt"}, ["same content"] * 8) + self.assertIn("BLOCK", trace) + + def test_changing_results_never_blocked(self): + """同参但每次结果不同(改脚本后重跑)= 有产出 → 永不拦,计数清零。""" + g = _RepeatGuard() + results = [f"[stdout]\nrun {i} output\n[exit 0]" for i in range(10)] + trace = _simulate(g, "run_python", {"script_path": "x.py"}, results) + self.assertNotIn("BLOCK", trace) + self.assertFalse(g.should_block("run_python", {"script_path": "x.py"})) + # 每次都是新结果,无产出计数恒为 0 + self.assertTrue(all(t == "exec(unprod=0)" for t in trace)) + + def test_productive_result_resets_counter(self): + """报错几次后拿到新结果(修好了)→ 计数清零,不会被先前的失败拖去拦截。""" + g = _RepeatGuard() + seq = ["[Error] x", "[Error] x", "[stdout]\nfixed!\n[exit 0]", "[stdout]\nfixed!\n[exit 0]"] + _simulate(g, "shell", {"command": "make"}, seq) + # 中途修好清零,不该进入 block 态 + self.assertFalse(g.should_block("shell", {"command": "make"})) + + def test_soft_threshold_reached_before_hard(self): + g = _RepeatGuard() + unprods = [] + for _ in range(_RepeatGuard.SOFT + 1): + unprods.append(g.record("document_search", {"queries": ["x"]}, "(no documents found)")) + # 累计达到 SOFT(此时应注入软提示),但还没到 HARD 拦截 + self.assertGreaterEqual(max(unprods), _RepeatGuard.SOFT) + 
self.assertFalse(g.should_block("document_search", {"queries": ["x"]})) + + def test_distinct_args_tracked_separately(self): + g = _RepeatGuard() + _simulate(g, "document_search", {"queries": ["a"]}, ["[Error] e"] * 8) + # 不同参数互不影响 + self.assertTrue(g.should_block("document_search", {"queries": ["a"]})) + self.assertFalse(g.should_block("document_search", {"queries": ["b"]})) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_secret_host_tools.py b/tests/test_secret_host_tools.py index f73e95e..f87a9bd 100644 --- a/tests/test_secret_host_tools.py +++ b/tests/test_secret_host_tools.py @@ -23,8 +23,9 @@ class TestDocumentHostTools(unittest.TestCase): } ] with patch("tools.documents.doc_client.search", return_value=docs) as search: + # 单 query 批量:queries 列表只一条时,缩量逻辑不动用户给的参数 out = DocumentSearchTool().execute( - query="cement hydration", + queries=["cement hydration"], max_documents=3, content_chars_per_doc=20, ) @@ -40,6 +41,26 @@ class TestDocumentHostTools(unittest.TestCase): self.assertIn("A" * 20, out) self.assertIn("truncated", out) + def test_document_search_batches_queries_concurrently_and_dedups(self): + from tools.documents import DocumentSearchTool + + calls: list[str] = [] + + def fake_search(query, **kwargs): + calls.append(query) + return [{"file_name": f"{query}.md", "kb_name": "mu_1", "md_content": "x"}] + + with patch("tools.documents.doc_client.search", side_effect=fake_search): + out = DocumentSearchTool().execute( + queries=["q1", "q2", "q1"], # 含重复 → 去重成 q1/q2 + ) + + self.assertEqual(sorted(calls), ["q1", "q2"]) # 去重后只两次 + self.assertIn("[1/2]", out) + self.assertIn("[2/2]", out) + self.assertIn("'q1'", out) + self.assertIn("'q2'", out) + def test_document_download_uses_constructor_working_dir(self): from tools.documents import DocumentDownloadTool @@ -55,7 +76,7 @@ class TestDocumentHostTools(unittest.TestCase): base_dir=working_dir, user_root=Path(tmp), ) - out = tool.execute(file_name="paper.pdf", kb_name="mu_1") + out = 
tool.execute(items=[{"file_name": "paper.pdf", "kb_name": "mu_1"}]) download.assert_called_once_with( file_name="paper.pdf", @@ -65,6 +86,32 @@ class TestDocumentHostTools(unittest.TestCase): ) self.assertIn("saved: task/documents/paper.pdf", out) + def test_document_download_batches_items_isolating_failure(self): + from tools.documents import DocumentDownloadTool + + with tempfile.TemporaryDirectory() as tmp: + working_dir = Path(tmp) / "task" + working_dir.mkdir() + + def fake_download(file_name, kb_name, working_dir, preview): + if file_name == "bad.pdf": + raise RuntimeError("404") + return f"documents/{file_name}" + + with patch("tools.documents.doc_client.download", side_effect=fake_download): + tool = DocumentDownloadTool( + working_dir=working_dir, base_dir=working_dir, user_root=Path(tmp) + ) + out = tool.execute(items=[ + {"file_name": "ok.pdf", "kb_name": "mu_1"}, + {"file_name": "bad.pdf", "kb_name": "mu_1"}, + ]) + + # 一条失败不连坐另一条 + self.assertIn("saved: task/documents/ok.pdf", out) + self.assertIn("[Error]", out) + self.assertIn("bad.pdf", out) + class TestMaterialsProjectHostTools(unittest.TestCase): def test_mp_search_summary_uses_host_key_and_returns_json(self): diff --git a/tools/documents.py b/tools/documents.py index b5fcd30..5d8b42f 100644 --- a/tools/documents.py +++ b/tools/documents.py @@ -5,6 +5,7 @@ sandbox receives only business arguments and trimmed results / saved paths. 
""" from __future__ import annotations +from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Optional @@ -12,6 +13,10 @@ from skills.documents import client as doc_client from .base import Tool +_MAX_QUERIES = 8 # document_search 单次批量 query 上限 +_MAX_DOWNLOADS = 10 # document_download 单次批量 item 上限 +_CONCURRENCY = 6 + def _clip(text: str, max_chars: int) -> tuple[str, bool]: max_chars = max(0, int(max_chars)) @@ -20,6 +25,17 @@ def _clip(text: str, max_chars: int) -> tuple[str, bool]: return text[:max_chars], True +def _dedup_keep_order(items: list[str]) -> list[str]: + seen: set[str] = set() + out: list[str] = [] + for it in items: + key = it.strip() + if key and key.lower() not in seen: + seen.add(key.lower()) + out.append(key) + return out + + class DocumentListKbTool(Tool): name = "document_list_kb" description = ( @@ -51,50 +67,54 @@ class DocumentListKbTool(Tool): class DocumentSearchTool(Tool): name = "document_search" description = ( - "Search the internal materials document knowledge base. " - "Returns file metadata and truncated markdown content; increase content_chars_per_doc only when needed." + "Search the internal materials document knowledge base with one OR MORE queries at once. " + "Pass every distinct query you want in a single `queries` list instead of calling this tool " + "repeatedly — searches run concurrently and one failing query does not abort the others. " + "When many queries are batched, per-query documents and per-document content shrink automatically " + "to keep the result compact; use a single query when you need maximum depth on one topic. " + "Avoid firing many near-identical reworded queries (diminishing returns) — plan a deduplicated set first." 
) parameters = { "type": "object", "properties": { - "query": {"type": "string", "description": "Search query, Chinese or English; technical terms are usually better in English."}, + "queries": { + "type": "array", + "items": {"type": "string"}, + "description": f"Search queries (1-{_MAX_QUERIES}), Chinese or English; technical terms are usually better in English. Batch distinct queries together.", + }, "kb_names": { "type": "array", "items": {"type": "string"}, - "description": "Optional knowledge-base names from document_list_kb.", + "description": "Optional knowledge-base names from document_list_kb (applies to all queries).", }, "classification_ids": { "type": "array", "items": {"type": "integer"}, - "description": "Optional materials domain ids, 1-7.", + "description": "Optional materials domain ids, 1-7 (applies to all queries).", }, "max_documents": { "type": "integer", "default": 6, - "description": "Number of documents to return, 1-20.", + "description": "Documents per query, 1-20 (auto-reduced when many queries are batched).", }, "content_chars_per_doc": { "type": "integer", "default": 1200, - "description": "Maximum markdown characters returned per document, 0-5000.", + "description": "Maximum markdown characters per document, 0-5000 (auto-reduced when many queries are batched).", }, }, - "required": ["query"], + "required": ["queries"], } - def execute( + def _search_one( self, query: str, - kb_names: Optional[list[str]] = None, - classification_ids: Optional[list[int]] = None, - max_documents: int = 6, - content_chars_per_doc: int = 1200, + kb_names: Optional[list[str]], + classification_ids: Optional[list[int]], + max_documents: int, + content_chars_per_doc: int, ) -> str: - query = (query or "").strip() - if not query: - return "[Error] query 不能为空" - max_documents = min(max(int(max_documents), 1), 20) - content_chars_per_doc = min(max(int(content_chars_per_doc), 0), 5000) + """搜单个 query,返回格式化文本块或 [Error ...];绝不抛异常(供并发安全调用)。""" try: docs = 
doc_client.search( query=query, @@ -122,21 +142,78 @@ class DocumentSearchTool(Tool): lines.append(f" md_content[:{content_chars_per_doc}]={snippet}{suffix}") return "\n".join(lines) + def execute( + self, + queries: list[str] | str, + kb_names: Optional[list[str]] = None, + classification_ids: Optional[list[int]] = None, + max_documents: int = 6, + content_chars_per_doc: int = 1200, + ) -> str: + if isinstance(queries, str): + queries = [queries] + queries = _dedup_keep_order([q for q in (queries or []) if isinstance(q, str)]) + if not queries: + return "[Error] queries 不能为空" + + dropped = 0 + if len(queries) > _MAX_QUERIES: + dropped = len(queries) - _MAX_QUERIES + queries = queries[:_MAX_QUERIES] + + n = len(queries) + max_documents = min(max(int(max_documents), 1), 20) + content_chars_per_doc = min(max(int(content_chars_per_doc), 0), 5000) + # 批量时自动缩量,bound 总输出(单 query 时保持用户给定值不动) + if n > 1: + max_documents = min(max_documents, max(2, 12 // n)) + content_chars_per_doc = min(content_chars_per_doc, max(400, 6000 // n)) + + with ThreadPoolExecutor(max_workers=min(_CONCURRENCY, n)) as pool: + results = list(pool.map( + lambda q: self._search_one(q, kb_names, classification_ids, max_documents, content_chars_per_doc), + queries, + )) + + if n == 1: + out = results[0] + return out if not dropped else out + f"\n\n[note] 多余 {dropped} 个 query 被丢弃(单次上限 {_MAX_QUERIES})" + + blocks = [] + for i, (q, text) in enumerate(zip(queries, results), 1): + blocks.append(f"=== [{i}/{n}] {q!r} ===\n{text}") + out = "\n\n".join(blocks) + if dropped: + out += f"\n\n[note] 多余 {dropped} 个 query 被丢弃(单次上限 {_MAX_QUERIES})" + return out + class DocumentDownloadTool(Tool): name = "document_download" description = ( - "Download an original document from document_search into the current task_dir/documents/. " - "Use file_name and kb_name returned by document_search." + "Download one OR MORE original documents from document_search into task_dir/documents/. 
" + "Pass every document you want in a single `items` list instead of calling this tool repeatedly — " + "downloads run concurrently and one failing item does not abort the others. " + "Use the file_name and kb_name returned by document_search." ) parameters = { "type": "object", "properties": { - "file_name": {"type": "string", "description": "Original file_name or md_filename returned by document_search."}, - "kb_name": {"type": "string", "description": "Knowledge-base name returned by document_search."}, - "preview": {"type": "boolean", "default": False, "description": "Request inline preview disposition from the upstream API. Usually false."}, + "items": { + "type": "array", + "items": { + "type": "object", + "properties": { + "file_name": {"type": "string", "description": "Original file_name or md_filename returned by document_search."}, + "kb_name": {"type": "string", "description": "Knowledge-base name returned by document_search."}, + "preview": {"type": "boolean", "description": "Request inline preview disposition. Usually false."}, + }, + "required": ["file_name", "kb_name"], + }, + "description": f"Documents to download (1-{_MAX_DOWNLOADS}). Batch every document into one call.", + }, }, - "required": ["file_name", "kb_name"], + "required": ["items"], } def __init__( @@ -149,16 +226,46 @@ class DocumentDownloadTool(Tool): super().__init__(base_dir=base_dir, user_root=user_root) self.working_dir = Path(working_dir) - def execute(self, file_name: str, kb_name: str, preview: bool = False) -> str: - if not (file_name or "").strip() or not (kb_name or "").strip(): - return "[Error] file_name / kb_name 不可为空" + def _download_one(self, item: dict) -> str: + """下载单个 item,返回 'saved: ...' 
或 [Error ...];绝不抛异常(供并发安全调用)。""" + if not isinstance(item, dict): + return f"[Error] 非法 item(应为对象): {item!r}" + file_name = str(item.get("file_name") or "").strip() + kb_name = str(item.get("kb_name") or "").strip() + if not file_name or not kb_name: + return f"[Error] file_name / kb_name 不可为空: {item!r}" try: rel = doc_client.download( file_name=file_name, kb_name=kb_name, working_dir=str(self.working_dir), - preview=bool(preview), + preview=bool(item.get("preview", False)), ) except Exception as e: - return f"[Error] document_download failed: {type(e).__name__}: {e}" + return f"[Error] download {file_name!r} failed: {type(e).__name__}: {e}" return f"saved: {self._display(self.working_dir / rel)}" + + def execute(self, items: list[dict] | dict) -> str: + if isinstance(items, dict): + items = [items] + items = [it for it in (items or []) if isinstance(it, dict)] + if not items: + return "[Error] items 不能为空" + + dropped = 0 + if len(items) > _MAX_DOWNLOADS: + dropped = len(items) - _MAX_DOWNLOADS + items = items[:_MAX_DOWNLOADS] + + with ThreadPoolExecutor(max_workers=min(_CONCURRENCY, len(items))) as pool: + results = list(pool.map(self._download_one, items)) + + if len(items) == 1: + out = results[0] + return out if not dropped else out + f"\n[note] 多余 {dropped} 个被丢弃(单次上限 {_MAX_DOWNLOADS})" + + lines = [f"{i}. 
{r}" for i, r in enumerate(results, 1)] + out = "\n".join(lines) + if dropped: + out += f"\n[note] 多余 {dropped} 个被丢弃(单次上限 {_MAX_DOWNLOADS})" + return out diff --git a/tools/web_fetch.py b/tools/web_fetch.py index ebcaa91..17ae377 100644 --- a/tools/web_fetch.py +++ b/tools/web_fetch.py @@ -1,9 +1,14 @@ -"""Web Fetch: 抓取任意 URL 并返回 markdown 文本。""" +"""Web Fetch: 批量抓取多个 URL,各自返回 markdown 文本。 + +一次调用接受 URL 列表并发抓取 —— 避免「一个 URL 一轮 tool call」的高轮数循环 +(每轮重发整段上下文,轮数是 token 体量的线性乘数)。单条失败不连坐整批。 +""" from __future__ import annotations import ipaddress import re import socket +from concurrent.futures import ThreadPoolExecutor import html2text import httpx @@ -18,7 +23,11 @@ _SSRF_BLOCKED = { ) } -_MAX_CHARS = 8000 +_MAX_URLS = 10 # 单次批量上限,超出截断并明示 +_TOTAL_CHARS = 16000 # 全批正文总预算(对齐 loop 的 tool result 上限) +_PER_URL_CAP = 8000 # 单条上限(n=1 时与旧行为一致) +_MIN_PER_URL = 1500 # 单条下限(批量大时不至于压到几乎为零) +_CONCURRENCY = 6 _TIMEOUT = 15.0 _UA = ( @@ -53,54 +62,88 @@ def _check_ssrf(url: str) -> str | None: return None +def _fetch_one(url: str, per_url_cap: int) -> str: + """抓单个 URL,返回 markdown 正文或 [Error ...];绝不抛异常(供并发安全调用)。""" + err = _check_ssrf(url) + if err: + return f"[Error] {err}" + try: + resp = httpx.get( + url, + headers={"User-Agent": _UA}, + timeout=_TIMEOUT, + follow_redirects=True, + ) + except httpx.TimeoutException: + return f"[Error] request timed out after {_TIMEOUT:.0f}s" + except httpx.HTTPError as e: + return f"[Error] request failed: {e}" + + if resp.status_code >= 400: + return f"[Error] HTTP {resp.status_code}" + + content_type = resp.headers.get("content-type", "") + if "text/html" not in content_type and "text/plain" not in content_type: + return f"[Error] unsupported content type: {content_type} — only HTML/text pages are supported" + + try: + text = _h2t.handle(resp.text) + except Exception as e: + return f"[Error] failed to convert HTML to text: {e}" + + text = re.sub(r"\n{3,}", "\n\n", text).strip() + if len(text) > per_url_cap: + text = text[:per_url_cap] + 
f"\n\n...(truncated, {len(text) - per_url_cap} more chars — fetch this URL alone for the rest)" + return text + + class WebFetchTool(Tool): name = "web_fetch" description = ( - "Fetch a web page and return its content as markdown text. " - "Use this to read the full content of a URL found in search results or referenced by the user. " - "Results are truncated to 8000 characters." + "Fetch one OR MORE web pages concurrently and return their content as markdown. " + "Pass ALL the URLs you want to read in a single `urls` list — do NOT call this tool " + "repeatedly one URL at a time. Each page is fetched independently; one failing URL " + "does not abort the others. Per-page content is truncated (smaller when many URLs are " + "batched); fetch a single URL alone when you need its full text." ) parameters = { "type": "object", "properties": { - "url": {"type": "string", "description": "The URL to fetch"}, + "urls": { + "type": "array", + "items": {"type": "string"}, + "description": f"URLs to fetch (1-{_MAX_URLS}). 
Batch every URL you need into one call.", + }, }, - "required": ["url"], + "required": ["urls"], } - def execute(self, url: str) -> str: - err = _check_ssrf(url) - if err: - return f"[Error] {err}" + def execute(self, urls: list[str] | str) -> str: + # 容错:模型偶发传单个字符串而非列表 + if isinstance(urls, str): + urls = [urls] + urls = [u.strip() for u in (urls or []) if isinstance(u, str) and u.strip()] + if not urls: + return "[Error] urls 不能为空" - try: - resp = httpx.get( - url, - headers={"User-Agent": _UA}, - timeout=_TIMEOUT, - follow_redirects=True, - ) - except httpx.TimeoutException: - return f"[Error] request timed out after {_TIMEOUT:.0f}s" - except httpx.HTTPError as e: - return f"[Error] request failed: {e}" + dropped = 0 + if len(urls) > _MAX_URLS: + dropped = len(urls) - _MAX_URLS + urls = urls[:_MAX_URLS] - if resp.status_code >= 400: - return f"[Error] HTTP {resp.status_code}" + per_url_cap = min(_PER_URL_CAP, max(_MIN_PER_URL, _TOTAL_CHARS // len(urls))) - content_type = resp.headers.get("content-type", "") - if "text/html" not in content_type and "text/plain" not in content_type: - return f"[Error] unsupported content type: {content_type} — only HTML/text pages are supported" + with ThreadPoolExecutor(max_workers=min(_CONCURRENCY, len(urls))) as pool: + results = list(pool.map(lambda u: _fetch_one(u, per_url_cap), urls)) - try: - text = _h2t.handle(resp.text) - except Exception as e: - return f"[Error] failed to convert HTML to text: {e}" + if len(urls) == 1: + body = results[0] + return body if not dropped else body + f"\n\n[note] 多余 {dropped} 个 URL 被丢弃(单次上限 {_MAX_URLS})" - # 压缩多余空行 - text = re.sub(r"\n{3,}", "\n\n", text).strip() - - if len(text) > _MAX_CHARS: - text = text[:_MAX_CHARS] + f"\n\n...(truncated, {len(text) - _MAX_CHARS} more chars)" - - return text + blocks = [] + for i, (url, text) in enumerate(zip(urls, results), 1): + blocks.append(f"=== [{i}/{len(urls)}] {url} ===\n{text}") + out = "\n\n".join(blocks) + if dropped: + out += f"\n\n[note] 多余 
{dropped} 个 URL 被丢弃(单次上限 {_MAX_URLS});需要就再发一批" + return out
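
Review note: the budgeting and failure-isolation logic in this diff can be checked in isolation. The sketch below is a minimal, standalone approximation and not the code under review. The constants and both formulas are copied from the hunks above; `per_url_cap`, `shrink_search_limits`, `run_batch` and `fake_fetch` are hypothetical names introduced here for illustration only, and `fake_fetch` stands in for `_fetch_one` without touching the network.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Tuple

# Constants copied from the diff above (tools/web_fetch.py).
TOTAL_CHARS = 16000
PER_URL_CAP = 8000
MIN_PER_URL = 1500
CONCURRENCY = 6


def per_url_cap(n: int) -> int:
    """Split the total character budget across n URLs, clamped to [MIN_PER_URL, PER_URL_CAP]."""
    return min(PER_URL_CAP, max(MIN_PER_URL, TOTAL_CHARS // n))


def shrink_search_limits(n: int, max_documents: int = 6, content_chars: int = 1200) -> Tuple[int, int]:
    """Mirror the document_search auto-shrink; a single query keeps the caller's values."""
    if n > 1:
        max_documents = min(max_documents, max(2, 12 // n))
        content_chars = min(content_chars, max(400, 6000 // n))
    return max_documents, content_chars


def run_batch(items: List[str], worker: Callable[[str], str]) -> str:
    """Run worker over a non-empty list concurrently and label each result.

    The worker is expected to return text or an '[Error] ...' string and never raise,
    so one failing item cannot abort the rest of the batch.
    """
    n = len(items)
    with ThreadPoolExecutor(max_workers=min(CONCURRENCY, n)) as pool:
        results = list(pool.map(worker, items))  # map() preserves input order
    blocks = [f"=== [{i}/{n}] {item} ===\n{text}" for i, (item, text) in enumerate(zip(items, results), 1)]
    return "\n\n".join(blocks)


def fake_fetch(url: str) -> str:
    """Hypothetical stand-in for _fetch_one: fails for URLs containing 'bad'."""
    try:
        if "bad" in url:
            raise ValueError("simulated failure")
        return f"content of {url}"
    except Exception as e:  # convert every failure into a result string
        return f"[Error] {type(e).__name__}: {e}"


if __name__ == "__main__":
    print(per_url_cap(1), per_url_cap(4), per_url_cap(10))
    print(shrink_search_limits(8))
    print(run_batch(["a", "bad", "c"], fake_fetch))
```

One observation that follows from the arithmetic, assuming the constants stay as shown: with `_MAX_URLS = 10` the smallest per-URL share is 16000 // 10 = 1600, so the `_MIN_PER_URL` floor of 1500 never takes effect and the batch total stays within `_TOTAL_CHARS`. The floor would only matter if the URL limit were raised above 10, at which point the total could exceed the budget.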