zcbot/tests/test_secret_host_tools.py

170 lines
5.8 KiB
Python

from __future__ import annotations
import json
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
# Put the directory one level above this file's folder at the front of the
# import path (presumably where the `tools` package imported by the tests
# lives — confirm against the repository layout).
_PACKAGE_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(_PACKAGE_ROOT))
class TestDocumentHostTools(unittest.TestCase):
    """Tests for the document search/download tools with ``doc_client`` patched."""

    def test_document_search_truncates_content_without_requiring_key_arg(self):
        """Check a single-query search call and the shape of its output.

        ``execute`` is invoked without any key argument. The patched client
        must be called exactly once with the query and ``max_documents`` as
        given, and the output must contain the file name, the KB name, a run
        of 20 content characters and the word ``truncated``.
        """
        from tools.documents import DocumentSearchTool

        hit = {
            "file_name": "paper.md",
            "kb_name": "mu_1",
            "character_count": 50000,
            "md_content": "A" * 200,
        }
        with patch(
            "tools.documents.doc_client.search", return_value=[hit]
        ) as search_mock:
            # Single-query batch: when `queries` holds just one entry, the
            # down-scaling logic leaves the caller-supplied parameters alone.
            rendered = DocumentSearchTool().execute(
                queries=["cement hydration"],
                max_documents=3,
                content_chars_per_doc=20,
            )

        search_mock.assert_called_once_with(
            query="cement hydration",
            kb_names=None,
            classification_ids=None,
            max_documents=3,
        )
        for fragment in ("paper.md", "mu_1", "A" * 20, "truncated"):
            self.assertIn(fragment, rendered)

    def test_document_search_batches_queries_concurrently_and_dedups(self):
        """Check that a repeated query is searched once and both results are labelled."""
        from tools.documents import DocumentSearchTool

        seen_queries: list[str] = []

        def record_and_answer(query, **_unused):
            # Remember which query was requested, then return one fake hit.
            seen_queries.append(query)
            return [{"file_name": f"{query}.md", "kb_name": "mu_1", "md_content": "x"}]

        with patch("tools.documents.doc_client.search", side_effect=record_and_answer):
            rendered = DocumentSearchTool().execute(
                queries=["q1", "q2", "q1"],  # contains a duplicate -> dedups to q1/q2
            )

        # Only two client calls remain after deduplication; sorted because the
        # order in which the calls arrive is not part of the assertion.
        self.assertEqual(sorted(seen_queries), ["q1", "q2"])
        for marker in ("[1/2]", "[2/2]", "'q1'", "'q2'"):
            self.assertIn(marker, rendered)

    def test_document_download_uses_constructor_working_dir(self):
        """Check that the download client receives the constructor's working dir.

        The reported path is expected to be ``task/documents/paper.pdf``.
        """
        from tools.documents import DocumentDownloadTool

        with tempfile.TemporaryDirectory() as scratch:
            task_dir = Path(scratch) / "task"
            task_dir.mkdir()

            with patch(
                "tools.documents.doc_client.download",
                return_value="documents/paper.pdf",
            ) as download_mock:
                downloader = DocumentDownloadTool(
                    working_dir=task_dir,
                    base_dir=task_dir,
                    user_root=Path(scratch),
                )
                rendered = downloader.execute(
                    items=[{"file_name": "paper.pdf", "kb_name": "mu_1"}]
                )

            download_mock.assert_called_once_with(
                file_name="paper.pdf",
                kb_name="mu_1",
                working_dir=str(task_dir),
                preview=False,
            )
            self.assertIn("saved: task/documents/paper.pdf", rendered)

    def test_document_download_batches_items_isolating_failure(self):
        """Check that one failing item does not stop the other from being saved."""
        from tools.documents import DocumentDownloadTool

        def fake_download(file_name, kb_name, working_dir, preview):
            # Every file succeeds except "bad.pdf", which raises.
            if file_name != "bad.pdf":
                return f"documents/{file_name}"
            raise RuntimeError("404")

        with tempfile.TemporaryDirectory() as scratch:
            task_dir = Path(scratch) / "task"
            task_dir.mkdir()

            with patch("tools.documents.doc_client.download", side_effect=fake_download):
                downloader = DocumentDownloadTool(
                    working_dir=task_dir, base_dir=task_dir, user_root=Path(scratch)
                )
                rendered = downloader.execute(
                    items=[
                        {"file_name": "ok.pdf", "kb_name": "mu_1"},
                        {"file_name": "bad.pdf", "kb_name": "mu_1"},
                    ]
                )

            # One failing item must not take the other one down with it.
            for fragment in ("saved: task/documents/ok.pdf", "[Error]", "bad.pdf"):
                self.assertIn(fragment, rendered)
class TestMaterialsProjectHostTools(unittest.TestCase):
    """Tests for the Materials Project search tool with ``MPRester`` faked."""

    def test_mp_search_summary_uses_host_key_and_returns_json(self):
        """Run the summary search against a fake ``MPRester`` and check its output.

        With ``MP_API_KEY`` set in the environment, the tool's output must be
        JSON carrying the requested fields, must not contain the key string,
        and the search must have been limited to a single chunk of ``limit``
        records.
        """
        from tools.materials_project import MaterialsProjectSearchSummaryTool

        # Keyword arguments the tool passes to `summary.search`, recorded by the fake.
        search_kwargs = {}

        class FakeDoc:
            material_id = "mp-1"
            formula_pretty = "Ca3SiO5"
            energy_above_hull = 0.0123

        class FakeSummary:
            def search(self, **kwargs):
                search_kwargs.update(kwargs)
                return [FakeDoc()]

        class FakeMaterials:
            def __init__(self):
                self.summary = FakeSummary()

        class FakeMPRester:
            # NOTE(review): the key passed in is stored but never asserted on
            # by this test; only its absence from the output is checked.
            def __init__(self, api_key):
                self.api_key = api_key
                self.materials = FakeMaterials()

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                return False

        with patch.dict("os.environ", {"MP_API_KEY": "host-secret"}, clear=False):
            with patch("tools.materials_project.MPRester", FakeMPRester):
                raw = MaterialsProjectSearchSummaryTool().execute(
                    formula="Ca3SiO5",
                    fields=["material_id", "formula_pretty", "energy_above_hull"],
                    limit=2,
                )

        record = json.loads(raw)[0]
        expected_fields = {
            "material_id": "mp-1",
            "formula_pretty": "Ca3SiO5",
            "energy_above_hull": 0.0123,
        }
        for field, value in expected_fields.items():
            self.assertEqual(record[field], value)
        self.assertNotIn("host-secret", raw)

        # Server-side limiting: one page of `limit` records, never the default
        # chunk_size=1000 full-database pagination that gets the IP banned.
        self.assertEqual(search_kwargs["num_chunks"], 1)
        self.assertEqual(search_kwargs["chunk_size"], 2)
if __name__ == "__main__":
    # Executed directly rather than imported: hand control to unittest's runner.
    unittest.main()