"""Host-side document_search tools. These tools intentionally keep DOCUMENT_SEARCH_API_KEY on the host side. The sandbox receives only business arguments and trimmed results / saved paths. """ from __future__ import annotations from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Optional from skills.documents import client as doc_client from .base import Tool _MAX_QUERIES = 8 # document_search 单次批量 query 上限 _MAX_DOWNLOADS = 10 # document_download 单次批量 item 上限 _CONCURRENCY = 6 def _clip(text: str, max_chars: int) -> tuple[str, bool]: max_chars = max(0, int(max_chars)) if len(text) <= max_chars: return text, False return text[:max_chars], True def _dedup_keep_order(items: list[str]) -> list[str]: seen: set[str] = set() out: list[str] = [] for it in items: key = it.strip() if key and key.lower() not in seen: seen.add(key.lower()) out.append(key) return out class DocumentListKbTool(Tool): name = "document_list_kb" description = ( "List internal materials knowledge bases available in document_search. " "Use before document_search when the user did not specify a materials domain." ) parameters = {"type": "object", "properties": {}} def execute(self) -> str: try: kbs = doc_client.list_kb() except Exception as e: return f"[Error] document_list_kb failed: {type(e).__name__}: {e}" if not kbs: return "(no knowledge bases returned)" lines = ["Knowledge bases:"] for kb in kbs: lines.append( "- id={id} kb_name={kb_name} ch_name={ch_name} file_count={file_count}".format( id=kb.get("id", ""), kb_name=kb.get("kb_name", ""), ch_name=kb.get("ch_name", ""), file_count=kb.get("file_count", ""), ) ) return "\n".join(lines) class DocumentSearchTool(Tool): name = "document_search" description = ( "Search the internal materials document knowledge base with one OR MORE queries at once. " "Pass every distinct query you want in a single `queries` list instead of calling this tool " "repeatedly — searches run concurrently and one failing query does not abort the others. " "When many queries are batched, per-query documents and per-document content shrink automatically " "to keep the result compact; use a single query when you need maximum depth on one topic. " "Avoid firing many near-identical reworded queries (diminishing returns) — plan a deduplicated set first." ) parameters = { "type": "object", "properties": { "queries": { "type": "array", "items": {"type": "string"}, "description": f"Search queries (1-{_MAX_QUERIES}), Chinese or English; technical terms are usually better in English. Batch distinct queries together.", }, "kb_names": { "type": "array", "items": {"type": "string"}, "description": "Optional knowledge-base names from document_list_kb (applies to all queries).", }, "classification_ids": { "type": "array", "items": {"type": "integer"}, "description": "Optional materials domain ids, 1-7 (applies to all queries).", }, "max_documents": { "type": "integer", "default": 6, "description": "Documents per query, 1-20 (auto-reduced when many queries are batched).", }, "content_chars_per_doc": { "type": "integer", "default": 1200, "description": "Maximum markdown characters per document, 0-5000 (auto-reduced when many queries are batched).", }, }, "required": ["queries"], } def _search_one( self, query: str, kb_names: Optional[list[str]], classification_ids: Optional[list[int]], max_documents: int, content_chars_per_doc: int, ) -> str: """搜单个 query,返回格式化文本块或 [Error ...];绝不抛异常(供并发安全调用)。""" try: docs = doc_client.search( query=query, kb_names=kb_names or None, classification_ids=classification_ids or None, max_documents=max_documents, ) except Exception as e: return f"[Error] document_search failed: {type(e).__name__}: {e}" if not docs: return f"(no documents found for query: {query!r})" lines = [f"Document search results for: {query!r}"] for i, d in enumerate(docs, 1): content = d.get("md_content") or "" snippet, truncated = _clip(str(content), content_chars_per_doc) lines.append("") lines.append(f"{i}. file_name={d.get('file_name') or ''}") lines.append(f" kb_name={d.get('kb_name') or ''}") lines.append(f" character_count={d.get('character_count') or 0}") if d.get("md_filename"): lines.append(f" md_filename={d.get('md_filename')}") if snippet: suffix = " ...(truncated)" if truncated else "" lines.append(f" md_content[:{content_chars_per_doc}]={snippet}{suffix}") return "\n".join(lines) def execute( self, queries: list[str] | str, kb_names: Optional[list[str]] = None, classification_ids: Optional[list[int]] = None, max_documents: int = 6, content_chars_per_doc: int = 1200, ) -> str: if isinstance(queries, str): queries = [queries] queries = _dedup_keep_order([q for q in (queries or []) if isinstance(q, str)]) if not queries: return "[Error] queries 不能为空" dropped = 0 if len(queries) > _MAX_QUERIES: dropped = len(queries) - _MAX_QUERIES queries = queries[:_MAX_QUERIES] n = len(queries) max_documents = min(max(int(max_documents), 1), 20) content_chars_per_doc = min(max(int(content_chars_per_doc), 0), 5000) # 批量时自动缩量,bound 总输出(单 query 时保持用户给定值不动) if n > 1: max_documents = min(max_documents, max(2, 12 // n)) content_chars_per_doc = min(content_chars_per_doc, max(400, 6000 // n)) with ThreadPoolExecutor(max_workers=min(_CONCURRENCY, n)) as pool: results = list(pool.map( lambda q: self._search_one(q, kb_names, classification_ids, max_documents, content_chars_per_doc), queries, )) if n == 1: out = results[0] return out if not dropped else out + f"\n\n[note] 多余 {dropped} 个 query 被丢弃(单次上限 {_MAX_QUERIES})" blocks = [] for i, (q, text) in enumerate(zip(queries, results), 1): blocks.append(f"=== [{i}/{n}] {q!r} ===\n{text}") out = "\n\n".join(blocks) if dropped: out += f"\n\n[note] 多余 {dropped} 个 query 被丢弃(单次上限 {_MAX_QUERIES})" return out class DocumentDownloadTool(Tool): name = "document_download" description = ( "Download one OR MORE original documents from document_search into task_dir/documents/. " "Pass every document you want in a single `items` list instead of calling this tool repeatedly — " "downloads run concurrently and one failing item does not abort the others. " "Use the file_name and kb_name returned by document_search." ) parameters = { "type": "object", "properties": { "items": { "type": "array", "items": { "type": "object", "properties": { "file_name": {"type": "string", "description": "Original file_name or md_filename returned by document_search."}, "kb_name": {"type": "string", "description": "Knowledge-base name returned by document_search."}, "preview": {"type": "boolean", "description": "Request inline preview disposition. Usually false."}, }, "required": ["file_name", "kb_name"], }, "description": f"Documents to download (1-{_MAX_DOWNLOADS}). Batch every document into one call.", }, }, "required": ["items"], } def __init__( self, *, working_dir: Path, base_dir: Optional[Path] = None, user_root: Optional[Path] = None, ) -> None: super().__init__(base_dir=base_dir, user_root=user_root) self.working_dir = Path(working_dir) def _download_one(self, item: dict) -> str: """下载单个 item,返回 'saved: ...' 或 [Error ...];绝不抛异常(供并发安全调用)。""" if not isinstance(item, dict): return f"[Error] 非法 item(应为对象): {item!r}" file_name = str(item.get("file_name") or "").strip() kb_name = str(item.get("kb_name") or "").strip() if not file_name or not kb_name: return f"[Error] file_name / kb_name 不可为空: {item!r}" try: rel = doc_client.download( file_name=file_name, kb_name=kb_name, working_dir=str(self.working_dir), preview=bool(item.get("preview", False)), ) except Exception as e: return f"[Error] download {file_name!r} failed: {type(e).__name__}: {e}" return f"saved: {self._display(self.working_dir / rel)}" def execute(self, items: list[dict] | dict) -> str: if isinstance(items, dict): items = [items] items = [it for it in (items or []) if isinstance(it, dict)] if not items: return "[Error] items 不能为空" dropped = 0 if len(items) > _MAX_DOWNLOADS: dropped = len(items) - _MAX_DOWNLOADS items = items[:_MAX_DOWNLOADS] with ThreadPoolExecutor(max_workers=min(_CONCURRENCY, len(items))) as pool: results = list(pool.map(self._download_one, items)) if len(items) == 1: out = results[0] return out if not dropped else out + f"\n[note] 多余 {dropped} 个被丢弃(单次上限 {_MAX_DOWNLOADS})" lines = [f"{i}. {r}" for i, r in enumerate(results, 1)] out = "\n".join(lines) if dropped: out += f"\n[note] 多余 {dropped} 个被丢弃(单次上限 {_MAX_DOWNLOADS})" return out