"""document_search API 客户端 helper。 base_url / api_key 走 env: DOCUMENT_SEARCH_URL 默认 https://ai.ctc-zc.com:8100/api DOCUMENT_SEARCH_API_KEY 必填,缺失时调用立即抛 RuntimeError(而不是裸 401) """ from __future__ import annotations import os from pathlib import Path from typing import Any, Optional import httpx _BASE_URL = os.environ.get( "DOCUMENT_SEARCH_URL", "https://ai.ctc-zc.com:8100/api" ).rstrip("/") _API = f"{_BASE_URL}/document_search" _TIMEOUT = 30.0 _DOWNLOAD_TIMEOUT = 120.0 # search 返回字段(剥掉项目里不常用的 file_version / document_loader / text_splitter / custom_docs # 等,但保留 md_content —— 这是接口最大价值) _LIST_FIELDS = ( "id", "kb_name", "file_name", "file_ext", "create_time", "file_mtime", "file_size", "docs_count", "character_count", "md_filename", "md_content", "classification_ids", "url", ) def _api_key() -> str: key = os.environ.get("DOCUMENT_SEARCH_API_KEY", "").strip() if not key: raise RuntimeError( "DOCUMENT_SEARCH_API_KEY env 未设置 —— 配置后再调 documents skill" ) return key def _auth_headers(extra: Optional[dict] = None) -> dict: h = {"Authorization": f"Bearer {_api_key()}"} if extra: h.update(extra) return h def _safe_name(name: str) -> str: # 防目录穿越;保留扩展名 return name.replace("/", "_").replace("\\", "_").replace("..", "_") def list_kb() -> list[dict]: """列出所有有效知识库(对应 GET /list_knowledge_bases)。 返回每条含 id / kb_name / ch_name / kb_info / customer_id / classification_type / create_time / file_count。只返回 ID 映射里有效的(分类 1-7)。 """ r = httpx.get(f"{_API}/list_knowledge_bases", headers=_auth_headers(), timeout=_TIMEOUT) r.raise_for_status() payload = r.json() data = payload.get("data") or {} return list(data.get("knowledge_bases") or []) def search( query: str, kb_names: Optional[list[str]] = None, classification_ids: Optional[list[int]] = None, max_documents: int = 6, ) -> list[dict]: """搜文档(对应 POST /search)。返回精简列表,每条带 md_content(整篇 Markdown)。 query: 搜索词 kb_names: 知识库白名单(从 list_kb() 选);None 走 server 默认值 classification_ids: 分类 ID 白名单(1-7);None 不过滤 max_documents: 1-20,默认 6 SKILL.md 反模式:不要 print 整个 md_content(动辄几十 K),只打前 200-400 字判断切题。 """ if not query: raise ValueError("query 不可为空") if max_documents < 1: max_documents = 1 if max_documents > 20: max_documents = 20 body: dict[str, Any] = {"query": query, "max_documents": max_documents} if kb_names: body["kb_names"] = kb_names if classification_ids: body["classification_ids"] = classification_ids r = httpx.post( f"{_API}/search", headers=_auth_headers({"Content-Type": "application/json"}), json=body, timeout=_TIMEOUT, ) r.raise_for_status() payload = r.json() data = payload.get("data") or {} docs = data.get("documents") or [] return [{k: d.get(k) for k in _LIST_FIELDS} for d in docs] def download( file_name: str, kb_name: str, working_dir: str, preview: bool = False, ) -> str: """下载原始文档到 /documents/,返回相对路径。 file_name 支持原始文件名(example.pdf)或 Markdown 名(example.md)—— server 会 回退查 DB 拿原始名。已存在跳过下载直接复用。 preview=True 给浏览器内预览 header(Content-Disposition: inline),通常 agent 不需要(我们要落盘读取),保留参数透传。 """ if not file_name or not kb_name: raise ValueError("file_name / kb_name 不可为空") rel = f"documents/{_safe_name(file_name)}" dest = Path(working_dir) / rel if dest.exists() and dest.stat().st_size > 0: return rel dest.parent.mkdir(parents=True, exist_ok=True) params = { "knowledge_base_name": kb_name, "file_name": file_name, "preview": "true" if preview else "false", } with httpx.stream( "GET", f"{_API}/download_doc", headers=_auth_headers(), params=params, timeout=_DOWNLOAD_TIMEOUT, ) as resp: resp.raise_for_status() with open(dest, "wb") as f: for chunk in resp.iter_bytes(chunk_size=64 * 1024): f.write(chunk) return rel def health() -> dict: """健康检查(公开,不需要认证)。""" r = httpx.get(f"{_API}/health", timeout=_TIMEOUT) r.raise_for_status() return r.json().get("data") or {}