272 lines
11 KiB
Python
272 lines
11 KiB
Python
"""Host-side document_search tools.

These tools intentionally keep DOCUMENT_SEARCH_API_KEY on the host side: the
sandbox sends only business arguments and receives only trimmed results or
saved file paths.
"""
|
|
from __future__ import annotations
|
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from skills.documents import client as doc_client
|
|
|
|
from .base import Tool
|
|
|
|
_MAX_QUERIES = 8 # document_search 单次批量 query 上限
|
|
_MAX_DOWNLOADS = 10 # document_download 单次批量 item 上限
|
|
_CONCURRENCY = 6
|
|
|
|
|
|
def _clip(text: str, max_chars: int) -> tuple[str, bool]:
|
|
max_chars = max(0, int(max_chars))
|
|
if len(text) <= max_chars:
|
|
return text, False
|
|
return text[:max_chars], True
|
|
|
|
|
|
def _dedup_keep_order(items: list[str]) -> list[str]:
|
|
seen: set[str] = set()
|
|
out: list[str] = []
|
|
for it in items:
|
|
key = it.strip()
|
|
if key and key.lower() not in seen:
|
|
seen.add(key.lower())
|
|
out.append(key)
|
|
return out
|
|
|
|
|
|
class DocumentListKbTool(Tool):
    """List the materials knowledge bases returned by ``doc_client.list_kb``.

    Errors raised by ``doc_client.list_kb`` are returned as an ``[Error] ...``
    string instead of being propagated.
    """

    name = "document_list_kb"
    description = (
        "List internal materials knowledge bases available in document_search. "
        "Use before document_search when the user did not specify a materials domain."
    )
    parameters = {"type": "object", "properties": {}}

    def execute(self) -> str:
        """Return one ``- id=... kb_name=... ch_name=... file_count=...`` line per knowledge base."""
        try:
            entries = doc_client.list_kb()
        except Exception as exc:
            return f"[Error] document_list_kb failed: {type(exc).__name__}: {exc}"
        if not entries:
            return "(no knowledge bases returned)"

        # Missing fields are rendered as empty strings.
        rows = [
            f"- id={entry.get('id', '')} kb_name={entry.get('kb_name', '')} "
            f"ch_name={entry.get('ch_name', '')} file_count={entry.get('file_count', '')}"
            for entry in entries
        ]
        return "\n".join(["Knowledge bases:", *rows])
|
|
|
|
|
|
class DocumentSearchTool(Tool):
    """Run one or more document_search queries concurrently and format the results.

    Output size is bounded: the query count is capped at ``_MAX_QUERIES``,
    and the per-query document count and per-document content length shrink
    when several queries are batched together.
    """

    name = "document_search"
    description = (
        "Search the internal materials document knowledge base with one OR MORE queries at once. "
        "Pass every distinct query you want in a single `queries` list instead of calling this tool "
        "repeatedly — searches run concurrently and one failing query does not abort the others. "
        "When many queries are batched, per-query documents and per-document content shrink automatically "
        "to keep the result compact; use a single query when you need maximum depth on one topic. "
        "Avoid firing many near-identical reworded queries (diminishing returns) — plan a deduplicated set first."
    )
    parameters = {
        "type": "object",
        "properties": {
            "queries": {
                "type": "array",
                "items": {"type": "string"},
                "description": f"Search queries (1-{_MAX_QUERIES}), Chinese or English; technical terms are usually better in English. Batch distinct queries together.",
            },
            "kb_names": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Optional knowledge-base names from document_list_kb (applies to all queries).",
            },
            "classification_ids": {
                "type": "array",
                "items": {"type": "integer"},
                "description": "Optional materials domain ids, 1-7 (applies to all queries).",
            },
            "max_documents": {
                "type": "integer",
                "default": 6,
                "description": "Documents per query, 1-20 (auto-reduced when many queries are batched).",
            },
            "content_chars_per_doc": {
                "type": "integer",
                "default": 1200,
                "description": "Maximum markdown characters per document, 0-5000 (auto-reduced when many queries are batched).",
            },
        },
        "required": ["queries"],
    }

    @staticmethod
    def _as_int(value, default: int) -> int:
        """Coerce a tool argument to int; fall back to *default* if it is None or not numeric.

        Model-generated arguments may be null or malformed; raising here would
        escape the tool instead of producing a usable result.
        """
        if value is None:
            return default
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return default

    @staticmethod
    def _as_list(value) -> Optional[list]:
        """Normalize an optional list argument: falsy -> None, scalar -> [scalar], sequence -> list."""
        if not value:
            return None
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    @staticmethod
    def _format_docs(query: str, docs, content_chars_per_doc: int) -> str:
        """Render the documents returned for *query* as an indented text block."""
        lines = [f"Document search results for: {query!r}"]
        for i, d in enumerate(docs, 1):
            content = d.get("md_content") or ""
            snippet, truncated = _clip(str(content), content_chars_per_doc)
            lines.append("")
            lines.append(f"{i}. file_name={d.get('file_name') or ''}")
            lines.append(f" kb_name={d.get('kb_name') or ''}")
            lines.append(f" character_count={d.get('character_count') or 0}")
            if d.get("md_filename"):
                lines.append(f" md_filename={d.get('md_filename')}")
            if snippet:
                suffix = " ...(truncated)" if truncated else ""
                lines.append(f" md_content[:{content_chars_per_doc}]={snippet}{suffix}")
        return "\n".join(lines)

    def _search_one(
        self,
        query: str,
        kb_names: Optional[list[str]],
        classification_ids: Optional[list[int]],
        max_documents: int,
        content_chars_per_doc: int,
    ) -> str:
        """Search a single query and return a formatted block or an ``[Error] ...`` line.

        Never raises. An exception inside ``ThreadPoolExecutor.map`` would be
        re-raised when the results are collected, discarding every other
        query's output.
        """
        try:
            docs = doc_client.search(
                query=query,
                kb_names=kb_names or None,
                classification_ids=classification_ids or None,
                max_documents=max_documents,
            )
        except Exception as e:
            return f"[Error] document_search failed: {type(e).__name__}: {e}"
        if not docs:
            return f"(no documents found for query: {query!r})"
        try:
            return self._format_docs(query, docs, content_chars_per_doc)
        except Exception as e:
            # Malformed response entry (e.g. not a dict): report it instead of aborting the batch.
            return f"[Error] document_search result formatting failed: {type(e).__name__}: {e}"

    def execute(
        self,
        queries: list[str] | str,
        kb_names: Optional[list[str]] = None,
        classification_ids: Optional[list[int]] = None,
        max_documents: int = 6,
        content_chars_per_doc: int = 1200,
    ) -> str:
        """Search every distinct query concurrently and return the combined text.

        Args:
            queries: One query string or a list of them. Non-strings are ignored,
                and duplicates are removed case-insensitively.
            kb_names: Optional knowledge-base filter applied to every query. A
                single string is accepted.
            classification_ids: Optional domain-id filter applied to every query.
                A single int is accepted.
            max_documents: Documents per query, clamped to 1-20. Invalid values
                fall back to 6.
            content_chars_per_doc: Markdown characters per document, clamped to
                0-5000. Invalid values fall back to 1200.

        Returns:
            One result block for a single query, or ``=== [i/n] 'query' ===``
            sections for a batch. A note is appended if queries beyond
            ``_MAX_QUERIES`` were dropped.
        """
        if isinstance(queries, str):
            queries = [queries]
        queries = _dedup_keep_order([q for q in (queries or []) if isinstance(q, str)])
        if not queries:
            return "[Error] queries 不能为空"

        dropped = 0
        if len(queries) > _MAX_QUERIES:
            dropped = len(queries) - _MAX_QUERIES
            queries = queries[:_MAX_QUERIES]

        n = len(queries)
        kb_names = self._as_list(kb_names)
        classification_ids = self._as_list(classification_ids)
        max_documents = min(max(self._as_int(max_documents, 6), 1), 20)
        content_chars_per_doc = min(max(self._as_int(content_chars_per_doc, 1200), 0), 5000)
        # Shrink per-query limits when batching so the total output stays bounded;
        # a single query keeps the caller's values.
        if n > 1:
            max_documents = min(max_documents, max(2, 12 // n))
            content_chars_per_doc = min(content_chars_per_doc, max(400, 6000 // n))

        def run(query: str) -> str:
            return self._search_one(query, kb_names, classification_ids, max_documents, content_chars_per_doc)

        with ThreadPoolExecutor(max_workers=min(_CONCURRENCY, n)) as pool:
            results = list(pool.map(run, queries))

        if n == 1:
            out = results[0]
        else:
            out = "\n\n".join(
                f"=== [{i}/{n}] {q!r} ===\n{text}"
                for i, (q, text) in enumerate(zip(queries, results), 1)
            )
        if dropped:
            out += f"\n\n[note] 多余 {dropped} 个 query 被丢弃(单次上限 {_MAX_QUERIES})"
        return out
|
|
|
|
|
|
class DocumentDownloadTool(Tool):
    """Download documents found by document_search via ``doc_client.download``.

    Each item is downloaded concurrently with ``working_dir`` passed to the
    client. The path ``doc_client.download`` returns is joined onto
    ``working_dir`` and shown through ``self._display``. Every item gets its
    own ``saved: ...`` or ``[Error] ...`` line.
    """

    name = "document_download"
    description = (
        "Download one OR MORE original documents from document_search into task_dir/documents/. "
        "Pass every document you want in a single `items` list instead of calling this tool repeatedly — "
        "downloads run concurrently and one failing item does not abort the others. "
        "Use the file_name and kb_name returned by document_search."
    )
    parameters = {
        "type": "object",
        "properties": {
            "items": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "file_name": {"type": "string", "description": "Original file_name or md_filename returned by document_search."},
                        "kb_name": {"type": "string", "description": "Knowledge-base name returned by document_search."},
                        "preview": {"type": "boolean", "description": "Request inline preview disposition. Usually false."},
                    },
                    "required": ["file_name", "kb_name"],
                },
                "description": f"Documents to download (1-{_MAX_DOWNLOADS}). Batch every document into one call.",
            },
        },
        "required": ["items"],
    }

    def __init__(
        self,
        *,
        working_dir: Path,
        base_dir: Optional[Path] = None,
        user_root: Optional[Path] = None,
    ) -> None:
        super().__init__(base_dir=base_dir, user_root=user_root)
        self.working_dir = Path(working_dir)

    @staticmethod
    def _as_bool(value) -> bool:
        """Interpret a tool-call flag; strings such as "false"/"0" are False (``bool("false")`` is True)."""
        if isinstance(value, str):
            return value.strip().lower() in {"true", "1", "yes", "y", "on"}
        return bool(value)

    def _download_one(self, item: dict) -> str:
        """Download a single item and return ``saved: ...`` or an ``[Error] ...`` line.

        Never raises. An exception inside ``ThreadPoolExecutor.map`` would be
        re-raised when the results are collected, discarding every other
        item's result.
        """
        if not isinstance(item, dict):
            return f"[Error] 非法 item(应为对象): {item!r}"
        file_name = str(item.get("file_name") or "").strip()
        kb_name = str(item.get("kb_name") or "").strip()
        if not file_name or not kb_name:
            return f"[Error] file_name / kb_name 不可为空: {item!r}"
        try:
            rel = doc_client.download(
                file_name=file_name,
                kb_name=kb_name,
                working_dir=str(self.working_dir),
                preview=self._as_bool(item.get("preview", False)),
            )
        except Exception as e:
            return f"[Error] download {file_name!r} failed: {type(e).__name__}: {e}"
        try:
            return f"saved: {self._display(self.working_dir / rel)}"
        except Exception as e:
            # The download itself succeeded; only resolving the display path failed.
            return f"[Error] downloaded {file_name!r} but could not resolve saved path: {type(e).__name__}: {e}"

    def execute(self, items: list[dict] | dict) -> str:
        """Download every item concurrently and report one line per item.

        Invalid entries (non-objects, or objects missing file_name/kb_name) are
        reported as per-item errors rather than dropped silently. Items beyond
        ``_MAX_DOWNLOADS`` are dropped with a trailing note.
        """
        if isinstance(items, dict):
            items = [items]
        elif items is None:
            items = []
        elif isinstance(items, (list, tuple)):
            items = list(items)
        else:
            # A bare scalar (e.g. a file-name string) becomes one item that _download_one rejects.
            items = [items]
        if not items:
            return "[Error] items 不能为空"

        dropped = 0
        if len(items) > _MAX_DOWNLOADS:
            dropped = len(items) - _MAX_DOWNLOADS
            items = items[:_MAX_DOWNLOADS]

        with ThreadPoolExecutor(max_workers=min(_CONCURRENCY, len(items))) as pool:
            results = list(pool.map(self._download_one, items))

        if len(items) == 1:
            out = results[0]
        else:
            out = "\n".join(f"{i}. {r}" for i, r in enumerate(results, 1))
        if dropped:
            out += f"\n[note] 多余 {dropped} 个被丢弃(单次上限 {_MAX_DOWNLOADS})"
        return out
|