# zcbot/tools/web_fetch.py
#
# 150 lines
# 5.3 KiB
# Python

"""Web Fetch: 批量抓取多个 URL,各自返回 markdown 文本。
一次调用接受 URL 列表并发抓取 —— 避免「一个 URL 一轮 tool call」的高轮数循环
(每轮重发整段上下文,轮数是 token 体量的线性乘数)。单条失败不连坐整批。
"""
from __future__ import annotations
import ipaddress
import re
import socket
from concurrent.futures import ThreadPoolExecutor
import html2text
import httpx
from .base import Tool
# Networks that must never be fetched on behalf of a caller (SSRF guard).
# Membership is tested with `ip in net`; an address of the other IP version
# simply never matches, so IPv4 and IPv6 ranges are both listed explicitly.
_SSRF_BLOCKED = {
    ipaddress.ip_network(n)
    for n in (
        "127.0.0.0/8",     # IPv4 loopback
        "10.0.0.0/8",      # RFC 1918 private
        "172.16.0.0/12",   # RFC 1918 private
        "192.168.0.0/16",  # RFC 1918 private
        "169.254.0.0/16",  # IPv4 link-local (includes common cloud metadata address)
        "100.64.0.0/10",   # RFC 6598 shared address space (carrier/cloud-internal)
        "0.0.0.0/8",       # "this network" / unspecified
        "::1/128",         # IPv6 loopback
        "::/128",          # IPv6 unspecified, the counterpart of 0.0.0.0
        "fc00::/7",        # IPv6 unique-local
        "fe80::/10",       # IPv6 link-local
    )
}
# --- Batch sizing -----------------------------------------------------------
# Hard cap on URLs per call; anything beyond it is dropped and reported.
_MAX_URLS = 10
# Body-text budget shared by the whole batch (aligned with the loop's
# tool-result size limit).
_TOTAL_CHARS = 16_000
# Ceiling for a single page (with one URL this matches the previous behaviour).
_PER_URL_CAP = 8_000
# Floor for a single page, so a large batch is not squeezed to almost nothing.
_MIN_PER_URL = 1_500

# --- HTTP settings ----------------------------------------------------------
# Upper bound on worker threads used for one batch.
_CONCURRENCY = 6
# Request timeout handed to httpx, in seconds.
_TIMEOUT = 15.0
# Desktop-browser User-Agent sent with every request.
_UA = (
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
    "(KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
)
# Module-level html2text converter and its settings, used by `_fetch_one`
# when turning fetched HTML into markdown text.
# NOTE(review): the option meanings below follow html2text's option names and
# documentation; confirm against the installed library version.
_h2t = html2text.HTML2Text()
_h2t.ignore_links = False  # keep hyperlinks in the converted output
_h2t.ignore_images = True  # leave images out of the converted output
_h2t.body_width = 0  # 0 presumably disables hard line wrapping
_h2t.skip_internal_links = True  # presumably drops in-page anchor links
def _check_ssrf(url: str) -> str | None:
    """Vet *url* against internal/private destinations before it is fetched.

    The host is extracted from the URL.  An IP literal is checked directly; a
    hostname is resolved with ``socket.getaddrinfo`` and *every* returned
    address is checked, so a name that maps to both a public and a private
    address is rejected.  IPv4-mapped IPv6 addresses (``::ffff:a.b.c.d``) are
    unwrapped first, because an IPv6 address never matches an IPv4 network.

    This is a pre-flight check only: the HTTP client resolves the name again
    when it connects, so DNS rebinding between the two lookups is not covered.

    Args:
        url: Absolute URL supplied by the caller (untrusted).

    Returns:
        ``None`` when the URL may be fetched, otherwise an error message
        (without the ``[Error]`` prefix).  Never raises for malformed URLs.
    """
    import urllib.parse

    try:
        host = urllib.parse.urlparse(url).hostname
    except ValueError as exc:
        # urlparse raises on malformed bracketed hosts such as "http://[::1".
        # Report it instead of letting one bad URL abort a whole batch.
        return f"invalid URL: {exc}"
    if not host:
        return f"invalid URL: no host in {url!r}"

    try:
        candidates = [ipaddress.ip_address(host)]
    except ValueError:
        # Not an IP literal: resolve the name and vet every address it maps to.
        try:
            infos = socket.getaddrinfo(
                host, None, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP
            )
            # Drop any "%scope" suffix on link-local IPv6 results before parsing.
            candidates = [
                ipaddress.ip_address(info[4][0].split("%", 1)[0]) for info in infos
            ]
        except (OSError, IndexError, ValueError):
            return f"cannot resolve host: {host!r}"
        if not candidates:
            return f"cannot resolve host: {host!r}"

    for ip in candidates:
        mapped = getattr(ip, "ipv4_mapped", None)
        if mapped is not None:
            # ::ffff:127.0.0.1 must be judged as 127.0.0.1.
            ip = mapped
        if (
            ip.is_loopback
            or ip.is_private
            or ip.is_link_local
            or ip.is_unspecified
            or ip.is_multicast
            or ip.is_reserved
            or any(ip in net for net in _SSRF_BLOCKED)
        ):
            return f"blocked internal/private host: {host} ({ip})"
    return None
# Same ceiling httpx applies by default when it follows redirects itself.
_MAX_REDIRECTS = 20
# Status codes that carry a Location the client is expected to follow.
_REDIRECT_CODES = (301, 302, 303, 307, 308)


def _new_html_converter():
    """Return a fresh html2text converter configured like the module-level ``_h2t``.

    ``HTML2Text`` is a stateful parser, and ``_fetch_one`` runs on several
    worker threads at once, so each conversion gets its own instance instead
    of sharing ``_h2t``.  ``_h2t`` remains the single source of the settings.
    """
    conv = html2text.HTML2Text()
    for option in ("ignore_links", "ignore_images", "body_width", "skip_internal_links"):
        setattr(conv, option, getattr(_h2t, option))
    return conv


def _get_with_checked_redirects(url: str):
    """GET *url*, re-running the SSRF check on every redirect hop.

    Redirects are followed manually so that a public URL cannot bounce the
    request to an internal address unchecked.

    Returns:
        ``(response, None)`` on success, or ``(None, message)`` where
        *message* is an error description without the ``[Error]`` prefix.
    """
    import urllib.parse

    current = url
    # One client for all hops keeps cookies set during a redirect chain.
    with httpx.Client(
        headers={"User-Agent": _UA},
        timeout=_TIMEOUT,
        follow_redirects=False,
    ) as client:
        # The first iteration is the original request; the rest are redirects.
        for _ in range(_MAX_REDIRECTS + 1):
            try:
                err = _check_ssrf(current)
            except ValueError as e:
                # Malformed URL rejected by the URL parser.
                err = f"invalid URL: {e}"
            if err:
                if current != url:
                    err = f"redirect to {current!r} rejected: {err}"
                return None, err
            try:
                resp = client.get(current)
            except httpx.TimeoutException:
                return None, f"request timed out after {_TIMEOUT:.0f}s"
            except (httpx.HTTPError, httpx.InvalidURL) as e:
                # InvalidURL is not an HTTPError subclass; catch it explicitly.
                return None, f"request failed: {e}"
            location = resp.headers.get("location")
            if resp.status_code not in _REDIRECT_CODES or not location:
                return resp, None
            # Location may be relative; resolve it against the current URL.
            current = urllib.parse.urljoin(str(resp.url), location)
    return None, f"too many redirects (more than {_MAX_REDIRECTS})"


def _fetch_one(url: str, per_url_cap: int) -> str:
    """Fetch one URL and return its text, or an ``[Error] ...`` string.

    Designed to be called from worker threads: failures are reported in the
    return value rather than raised, so one bad URL cannot abort a batch.

    Args:
        url: Absolute URL to fetch (untrusted; vetted by ``_check_ssrf`` on
            the initial request and on every redirect hop).
        per_url_cap: Maximum number of characters of page text to return;
            longer text is cut and a truncation notice is appended.

    Returns:
        Markdown for ``text/html`` pages, the body as-is for ``text/plain``
        pages, or a string starting with ``[Error]``.
    """
    resp, err = _get_with_checked_redirects(url)
    if err:
        return f"[Error] {err}"
    if resp.status_code >= 400:
        return f"[Error] HTTP {resp.status_code}"

    content_type = resp.headers.get("content-type", "")
    # Media types are case-insensitive, so compare in lower case.
    media = content_type.lower()
    is_html = "text/html" in media
    if not is_html and "text/plain" not in media:
        return f"[Error] unsupported content type: {content_type} — only HTML/text pages are supported"

    if is_html:
        try:
            text = _new_html_converter().handle(resp.text)
        except Exception as e:  # deliberate catch-all: this function must not raise
            return f"[Error] failed to convert HTML to text: {e}"
    else:
        # Plain text is already readable; HTML conversion would reflow its lines.
        text = resp.text

    # Collapse runs of blank lines left over from conversion.
    text = re.sub(r"\n{3,}", "\n\n", text).strip()
    if len(text) > per_url_cap:
        text = text[:per_url_cap] + f"\n\n...(truncated, {len(text) - per_url_cap} more chars — fetch this URL alone for the rest)"
    return text
class WebFetchTool(Tool):
    """Tool that fetches a batch of URLs concurrently and returns their text."""

    name = "web_fetch"
    description = (
        "Fetch one OR MORE web pages concurrently and return their content as markdown. "
        "Pass ALL the URLs you want to read in a single `urls` list — do NOT call this tool "
        "repeatedly one URL at a time. Each page is fetched independently; one failing URL "
        "does not abort the others. Per-page content is truncated (smaller when many URLs are "
        "batched); fetch a single URL alone when you need its full text."
    )
    parameters = {
        "type": "object",
        "properties": {
            "urls": {
                "type": "array",
                "items": {"type": "string"},
                "description": f"URLs to fetch (1-{_MAX_URLS}). Batch every URL you need into one call.",
            },
        },
        "required": ["urls"],
    }

    def execute(self, urls: list[str] | str) -> str:
        """Fetch every URL in *urls* concurrently and return the combined text.

        Args:
            urls: A list of URL strings.  A single bare string is accepted
                too, because the calling model occasionally sends one.
                Non-string items, blank strings and repeated URLs are ignored.

        Returns:
            For one URL, that page's text (or its ``[Error] ...`` string).
            For several, one ``=== [i/n] url ===`` section per URL in input
            order, plus a note when URLs beyond ``_MAX_URLS`` were dropped.
            An ``[Error]`` string when no usable URL was supplied.
        """
        # Tolerate a model that passes one bare string instead of a list.
        if isinstance(urls, str):
            urls = [urls]
        try:
            candidates = list(urls or [])
        except TypeError:
            # Not iterable (e.g. a number): treat as "no URLs" rather than crash.
            candidates = []

        stripped = (u.strip() for u in candidates if isinstance(u, str))
        # dict.fromkeys de-duplicates while keeping first-seen order, so a
        # repeated URL is fetched once and does not shrink every page's budget.
        urls = list(dict.fromkeys(u for u in stripped if u))
        if not urls:
            return "[Error] urls 不能为空"

        dropped = max(0, len(urls) - _MAX_URLS)
        urls = urls[:_MAX_URLS]

        # Split the shared budget evenly, clamped to the per-page floor and ceiling.
        per_url_cap = min(_PER_URL_CAP, max(_MIN_PER_URL, _TOTAL_CHARS // len(urls)))
        with ThreadPoolExecutor(max_workers=min(_CONCURRENCY, len(urls))) as pool:
            # pool.map yields results in input order, so they line up with `urls`.
            results = list(pool.map(lambda u: _fetch_one(u, per_url_cap), urls))

        if len(urls) == 1:
            # A single URL can never coincide with dropped extras (dropping
            # only happens when more than _MAX_URLS remain), so no note here.
            return results[0]

        total = len(urls)
        out = "\n\n".join(
            f"=== [{i}/{total}] {url} ===\n{text}"
            for i, (url, text) in enumerate(zip(urls, results), 1)
        )
        if dropped:
            out += f"\n\n[note] 多余 {dropped} 个 URL 被丢弃(单次上限 {_MAX_URLS});需要就再发一批"
        return out