159 lines
4.8 KiB
Python
159 lines
4.8 KiB
Python
"""document_search API 客户端 helper。
|
|
|
|
base_url / api_key 走 env:
|
|
DOCUMENT_SEARCH_URL 默认 https://ai.ctc-zc.com:8100/api
|
|
DOCUMENT_SEARCH_API_KEY 必填,缺失时调用立即抛 RuntimeError(而不是裸 401)
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Optional
|
|
|
|
import httpx
|
|
|
|
# Service root, overridable through DOCUMENT_SEARCH_URL; trailing slashes are
# stripped so the endpoint joins below never produce "//".
_BASE_URL = os.environ.get("DOCUMENT_SEARCH_URL", "https://ai.ctc-zc.com:8100/api").rstrip("/")

# Every endpoint used by this module lives under this prefix.
_API = _BASE_URL + "/document_search"

# Seconds; regular JSON calls vs. streamed file downloads.
_TIMEOUT = 30.0
_DOWNLOAD_TIMEOUT = 120.0

# Fields kept from each /search hit. Rarely used fields (file_version,
# document_loader, text_splitter, custom_docs, ...) are dropped, but md_content
# is kept on purpose -- the full Markdown is the most valuable part of the API.
_LIST_FIELDS = (
    "id", "kb_name", "file_name", "file_ext",
    "create_time", "file_mtime", "file_size",
    "docs_count", "character_count",
    "md_filename", "md_content",
    "classification_ids", "url",
)
|
|
|
|
|
|
def _api_key() -> str:
    """Return the API key from ``DOCUMENT_SEARCH_API_KEY`` with whitespace stripped.

    Raises:
        RuntimeError: the variable is unset or blank, so callers fail fast with
            a clear message instead of getting a bare HTTP 401 from the server.
    """
    value = os.environ.get("DOCUMENT_SEARCH_API_KEY", "").strip()
    if value:
        return value
    raise RuntimeError(
        "DOCUMENT_SEARCH_API_KEY env 未设置 —— 配置后再调 documents skill"
    )
|
|
|
|
|
|
def _auth_headers(extra: Optional[dict] = None) -> dict:
    """Build request headers carrying the bearer token, merged with *extra*.

    Entries from *extra* are applied after ``Authorization`` and therefore win
    on a key conflict. Raises RuntimeError (from ``_api_key``) when no API key
    is configured.
    """
    headers = dict(Authorization=f"Bearer {_api_key()}")
    headers.update(extra or {})
    return headers
|
|
|
|
|
|
def _safe_name(name: str) -> str:
    """Sanitize *name* into a single path component to join under ``documents/``.

    Path separators (``/`` and ``\\``) and NUL bytes become ``_``, so the name
    cannot escape the target directory or make ``open()`` raise. Every ``..`` is
    also replaced with ``_`` (kept for compatibility with existing cache paths).
    A name that would collapse to ``""`` or ``"."`` -- which would point at the
    documents directory itself rather than a file inside it -- becomes ``"_"``.
    Ordinary extensions such as ``.pdf`` are preserved.
    """
    safe = name.translate(str.maketrans({"/": "_", "\\": "_", "\x00": "_"}))
    safe = safe.replace("..", "_")
    # "documents/." and "documents/" both resolve to the directory, not a file.
    return "_" if safe in ("", ".") else safe
|
|
|
|
|
|
def list_kb() -> list[dict]:
    """List the valid knowledge bases (``GET /list_knowledge_bases``).

    Each entry carries id / kb_name / ch_name / kb_info / customer_id /
    classification_type / create_time / file_count. Per the API, only knowledge
    bases present in the ID mapping (classifications 1-7) are returned.

    Returns ``[]`` when the response has no ``data.knowledge_bases``.
    Raises RuntimeError when no API key is configured and
    ``httpx.HTTPStatusError`` on a non-2xx response.
    """
    resp = httpx.get(
        _API + "/list_knowledge_bases",
        headers=_auth_headers(),
        timeout=_TIMEOUT,
    )
    resp.raise_for_status()
    bases = (resp.json().get("data") or {}).get("knowledge_bases")
    return list(bases) if bases else []
|
|
|
|
|
|
def search(
    query: str,
    kb_names: Optional[list[str]] = None,
    classification_ids: Optional[list[int]] = None,
    max_documents: int = 6,
) -> list[dict]:
    """Search documents (``POST /search``) and return a trimmed result list.

    Every result dict has exactly the keys in ``_LIST_FIELDS`` (absent ones map
    to ``None``), including ``md_content`` -- the whole document as Markdown.

    Args:
        query: search text; must be non-empty.
        kb_names: knowledge-base whitelist (choose from ``list_kb()``); when
            falsy the field is omitted and the server default applies.
        classification_ids: classification-ID whitelist (1-7); when falsy the
            field is omitted, i.e. no filtering.
        max_documents: clamped into the range 1..20; defaults to 6.

    Raises:
        ValueError: ``query`` is empty.
        RuntimeError: no API key is configured.
        httpx.HTTPStatusError: the server answered with a non-2xx status.

    SKILL.md anti-pattern: never print a whole ``md_content`` (often tens of KB);
    print only the first 200-400 characters to judge relevance.
    """
    if not query:
        raise ValueError("query 不可为空")

    if max_documents < 1:
        limit = 1
    elif max_documents > 20:
        limit = 20
    else:
        limit = max_documents

    body: dict[str, Any] = {"query": query, "max_documents": limit}
    # Optional filters are only sent when non-empty.
    if kb_names:
        body["kb_names"] = kb_names
    if classification_ids:
        body["classification_ids"] = classification_ids

    resp = httpx.post(
        _API + "/search",
        headers=_auth_headers({"Content-Type": "application/json"}),
        json=body,
        timeout=_TIMEOUT,
    )
    resp.raise_for_status()

    hits = (resp.json().get("data") or {}).get("documents") or []
    return [{field: hit.get(field) for field in _LIST_FIELDS} for hit in hits]
|
|
|
|
|
|
def download(
    file_name: str,
    kb_name: str,
    working_dir: str,
    preview: bool = False,
) -> str:
    """Download an original document to ``<working_dir>/documents/<safe_file_name>``.

    Returns the path relative to *working_dir*, i.e. ``documents/<safe_file_name>``.

    ``file_name`` may be the original name (``example.pdf``) or the Markdown name
    (``example.md``); per the API, the server falls back to a DB lookup to find
    the original name. If a non-empty regular file already exists at the
    destination it is reused and nothing is downloaded.

    The response body is streamed into a sibling ``<name>.part`` file and moved
    into place with ``os.replace`` only after the whole body was received. An
    interrupted transfer therefore never leaves a truncated file that a later
    call would mistake for a cached copy.

    Note: the cache key is the sanitized file name only, so the same name in
    two different knowledge bases maps to the same destination.

    Args:
        file_name: document name as known to the server.
        kb_name: knowledge base holding the document.
        working_dir: directory under which ``documents/`` is created.
        preview: sent as ``preview=true``/``false``. Per the API it asks for a
            browser-preview header (``Content-Disposition: inline``); agents
            normally want the file on disk, so this is only passed through.

    Raises:
        ValueError: ``file_name`` or ``kb_name`` is empty.
        RuntimeError: no API key is configured.
        httpx.HTTPStatusError: the server answered with a non-2xx status.
    """
    if not file_name or not kb_name:
        raise ValueError("file_name / kb_name 不可为空")
    rel = f"documents/{_safe_name(file_name)}"
    dest = Path(working_dir) / rel
    # is_file(): a directory at dest must never be treated as a cached download.
    if dest.is_file() and dest.stat().st_size > 0:
        return rel
    dest.parent.mkdir(parents=True, exist_ok=True)
    params = {
        "knowledge_base_name": kb_name,
        "file_name": file_name,
        "preview": "true" if preview else "false",
    }
    tmp = dest.with_name(dest.name + ".part")
    with httpx.stream(
        "GET",
        f"{_API}/download_doc",
        headers=_auth_headers(),
        params=params,
        timeout=_DOWNLOAD_TIMEOUT,
    ) as resp:
        resp.raise_for_status()
        try:
            with open(tmp, "wb") as f:
                for chunk in resp.iter_bytes(chunk_size=64 * 1024):
                    f.write(chunk)
            # Atomic on the same filesystem: dest is either absent or complete.
            os.replace(tmp, dest)
        except BaseException:
            # Discard the partial file so the cache check above never sees it,
            # then propagate the original error (including KeyboardInterrupt).
            try:
                tmp.unlink()
            except OSError:
                pass
            raise
    return rel
|
|
|
|
|
|
def health() -> dict:
    """Health check against the public ``/health`` endpoint (no auth header sent).

    Returns the response's ``data`` object, or ``{}`` when it is missing or
    empty. Raises ``httpx.HTTPStatusError`` on a non-2xx status.
    """
    resp = httpx.get(_API + "/health", timeout=_TIMEOUT)
    resp.raise_for_status()
    payload = resp.json()
    return payload.get("data") or {}
|