zcbot/skills/documents/client.py

159 lines
4.8 KiB
Python

"""document_search API 客户端 helper。
base_url / api_key 走 env:
DOCUMENT_SEARCH_URL 默认 https://ai.ctc-zc.com:8100/api
DOCUMENT_SEARCH_API_KEY 必填,缺失时调用立即抛 RuntimeError(而不是裸 401)
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Optional
import httpx
# Service root; overridable through the DOCUMENT_SEARCH_URL environment
# variable. Trailing slashes are dropped so endpoint paths can be appended
# with a single "/".
_BASE_URL = os.getenv(
    "DOCUMENT_SEARCH_URL",
    "https://ai.ctc-zc.com:8100/api",
).rstrip("/")
# Common prefix shared by every document_search endpoint.
_API = _BASE_URL + "/document_search"
# Values handed to httpx as `timeout`: one for ordinary JSON calls and a
# longer one for file downloads.
_TIMEOUT = 30.0
_DOWNLOAD_TIMEOUT = 120.0
# Fields kept from each search hit. Rarely used ones (file_version,
# document_loader, text_splitter, custom_docs, ...) are dropped, but
# md_content is kept: it is the most valuable part of the API.
_LIST_FIELDS = (
    # identity
    "id", "kb_name", "file_name", "file_ext",
    # timestamps and size statistics
    "create_time", "file_mtime", "file_size", "docs_count", "character_count",
    # Markdown rendition
    "md_filename", "md_content",
    # classification and location
    "classification_ids", "url",
)
def _api_key() -> str:
    """Return the API key read from the DOCUMENT_SEARCH_API_KEY env variable.

    Surrounding whitespace is stripped from the value.

    Raises:
        RuntimeError: if the variable is unset or blank, so callers fail
            before any request is sent.
    """
    value = os.environ.get("DOCUMENT_SEARCH_API_KEY", "").strip()
    if value:
        return value
    raise RuntimeError(
        "DOCUMENT_SEARCH_API_KEY env 未设置 —— 配置后再调 documents skill"
    )
def _auth_headers(extra: Optional[dict] = None) -> dict:
    """Build request headers carrying the Bearer token.

    Args:
        extra: optional additional headers merged on top of the
            Authorization header (an entry with the same key overrides it).

    Returns:
        A new dict; `extra` itself is not modified.

    Raises:
        RuntimeError: propagated from `_api_key()` when no key is configured.
    """
    headers = {"Authorization": f"Bearer {_api_key()}"}
    headers.update(extra or {})
    return headers
def _safe_name(name: str) -> str:
    """Flatten *name* into a single path component (guards against traversal).

    Forward slashes, backslashes and ".." sequences are each replaced by "_";
    everything else, including the file extension, is kept as-is.

    A result of "" or "." is replaced by "_": joined onto a directory, either
    one would resolve to the directory itself rather than to a file inside it.

    Args:
        name: file name as supplied by the caller or the server.

    Returns:
        The sanitised name, never empty and never ".".
    """
    flattened = name.replace("/", "_").replace("\\", "_").replace("..", "_")
    # ".." cannot survive the replacements above, but a lone "." (or an
    # empty string) still can, and both alias the parent directory.
    if flattened in ("", "."):
        return "_"
    return flattened
def list_kb() -> list[dict]:
    """List all valid knowledge bases (GET /list_knowledge_bases).

    Per the server contract, each entry carries id / kb_name / ch_name /
    kb_info / customer_id / classification_type / create_time / file_count,
    and only knowledge bases valid in the ID mapping (classifications 1-7)
    are returned.

    Returns:
        The `data.knowledge_bases` list from the response, or an empty list
        when `data` or `knowledge_bases` is missing or empty.

    Raises:
        RuntimeError: if no API key is configured.
        httpx.HTTPStatusError: on a 4xx/5xx response.
    """
    response = httpx.get(
        f"{_API}/list_knowledge_bases",
        headers=_auth_headers(),
        timeout=_TIMEOUT,
    )
    response.raise_for_status()
    data = response.json().get("data") or {}
    return list(data.get("knowledge_bases") or [])
def search(
    query: str,
    kb_names: Optional[list[str]] = None,
    classification_ids: Optional[list[int]] = None,
    max_documents: int = 6,
) -> list[dict]:
    """Search documents (POST /search) and return a trimmed list of hits.

    Each hit is reduced to the keys in `_LIST_FIELDS` (missing keys become
    None) and includes md_content, the full Markdown of the document.

    SKILL.md anti-pattern: do not print the whole md_content (often tens of
    KB); print only the first 200-400 characters to judge relevance.

    Args:
        query: search terms; must be non-empty.
        kb_names: knowledge-base whitelist (pick from `list_kb()`); when
            None or empty the field is omitted and the server default applies.
        classification_ids: classification-ID whitelist (1-7); when None or
            empty no filter is sent.
        max_documents: number of documents requested, clamped to 1-20
            (default 6).

    Returns:
        A list of dicts, one per document; empty when the response has no
        `data.documents`.

    Raises:
        ValueError: if `query` is empty.
        RuntimeError: if no API key is configured.
        httpx.HTTPStatusError: on a 4xx/5xx response.
    """
    if not query:
        raise ValueError("query 不可为空")

    # Clamp into the 1..20 range instead of rejecting out-of-range values.
    limit = min(max(max_documents, 1), 20)

    request_body: dict[str, Any] = {"query": query, "max_documents": limit}
    # Optional filters are sent only when they actually contain something.
    if kb_names:
        request_body["kb_names"] = kb_names
    if classification_ids:
        request_body["classification_ids"] = classification_ids

    response = httpx.post(
        f"{_API}/search",
        headers=_auth_headers({"Content-Type": "application/json"}),
        json=request_body,
        timeout=_TIMEOUT,
    )
    response.raise_for_status()

    hits = (response.json().get("data") or {}).get("documents") or []
    return [{field: hit.get(field) for field in _LIST_FIELDS} for hit in hits]
def download(
    file_name: str,
    kb_name: str,
    working_dir: str,
    preview: bool = False,
) -> str:
    """Download a document to <working_dir>/documents/<safe_file_name>.

    `file_name` may be the original file name (example.pdf) or the Markdown
    name (example.md); per the API contract the server falls back to a DB
    lookup for the original name. A non-empty file already present at the
    destination is reused without contacting the server.

    The body is streamed into a sibling "<name>.part" file and renamed onto
    the destination only after the transfer completes. An interrupted
    download therefore never leaves a truncated file that a later call would
    mistake for a valid cached copy.

    Args:
        file_name: name of the document to fetch; must be non-empty.
        kb_name: knowledge base that holds the document; must be non-empty.
        working_dir: directory under which `documents/` is created.
        preview: passed through as the `preview` query parameter. True asks
            for browser inline-preview headers (Content-Disposition: inline);
            agents normally do not need it since the file is saved to disk.

    Returns:
        The path of the saved file relative to `working_dir`.

    Raises:
        ValueError: if `file_name` or `kb_name` is empty.
        RuntimeError: if no API key is configured.
        httpx.HTTPStatusError: on a 4xx/5xx response.
    """
    if not file_name or not kb_name:
        raise ValueError("file_name / kb_name 不可为空")

    rel = f"documents/{_safe_name(file_name)}"
    dest = Path(working_dir) / rel
    # Only a regular, non-empty file counts as a cache hit; a directory at
    # this path must not be reported as a downloaded document.
    if dest.is_file() and dest.stat().st_size > 0:
        return rel
    dest.parent.mkdir(parents=True, exist_ok=True)

    params = {
        "knowledge_base_name": kb_name,
        "file_name": file_name,
        "preview": "true" if preview else "false",
    }
    partial = dest.with_name(dest.name + ".part")
    try:
        with httpx.stream(
            "GET",
            f"{_API}/download_doc",
            headers=_auth_headers(),
            params=params,
            timeout=_DOWNLOAD_TIMEOUT,
        ) as resp:
            resp.raise_for_status()
            with open(partial, "wb") as out:
                for chunk in resp.iter_bytes(chunk_size=64 * 1024):
                    out.write(chunk)
        # Publish the finished file in one step.
        partial.replace(dest)
    finally:
        # After a successful rename this is a no-op; after a failure it
        # removes the incomplete temp file.
        partial.unlink(missing_ok=True)
    return rel
def health() -> dict:
    """Health check (public endpoint; the request is sent without auth headers).

    Returns:
        The `data` object of the response, or an empty dict when it is
        missing or empty.

    Raises:
        httpx.HTTPStatusError: on a 4xx/5xx response.
    """
    response = httpx.get(f"{_API}/health", timeout=_TIMEOUT)
    response.raise_for_status()
    body = response.json()
    return body.get("data") or {}