"""Unit tests for the document and Materials Project host tools.

The backend entry points exercised here (``doc_client.search``,
``doc_client.download`` and ``MPRester``) are replaced with mocks or
hand-written fakes via ``unittest.mock.patch``.
"""

from __future__ import annotations

import json
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

# Put the parent of this file's directory at the front of sys.path
# (presumably so the ``tools`` package imported inside the tests resolves).
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))


class TestDocumentHostTools(unittest.TestCase):
    """Tests for DocumentSearchTool and DocumentDownloadTool with doc_client patched."""

    @staticmethod
    def _make_task_dir(tmp_root):
        # Create and return the "task" sub-directory used as the working dir.
        task_dir = Path(tmp_root) / "task"
        task_dir.mkdir()
        return task_dir

    def test_document_search_truncates_content_without_requiring_key_arg(self):
        # One query with a 20-character per-document limit: the search client
        # must be called exactly once with the arguments below, and the output
        # must carry the file name, KB name, 20 "A"s and a "truncated" marker.
        from tools.documents import DocumentSearchTool

        hit = {
            "file_name": "paper.md",
            "kb_name": "mu_1",
            "character_count": 50000,
            "md_content": "A" * 200,
        }
        search_patch = patch(
            "tools.documents.doc_client.search", return_value=[hit]
        )
        with search_patch as search_mock:
            # Single-query batch: when the queries list holds just one entry,
            # the down-scaling logic leaves the user-supplied parameters alone.
            rendered = DocumentSearchTool().execute(
                queries=["cement hydration"],
                max_documents=3,
                content_chars_per_doc=20,
            )

        search_mock.assert_called_once_with(
            query="cement hydration",
            kb_names=None,
            classification_ids=None,
            max_documents=3,
        )
        for fragment in ("paper.md", "mu_1", "A" * 20, "truncated"):
            self.assertIn(fragment, rendered)

    def test_document_search_batches_queries_concurrently_and_dedups(self):
        # Three queries, one of them repeated: the search client must be hit
        # once per distinct query and the output must be numbered out of 2.
        from tools.documents import DocumentSearchTool

        seen_queries: list[str] = []

        def fake_search(query, **kwargs):
            seen_queries.append(query)
            document = {
                "file_name": f"{query}.md",
                "kb_name": "mu_1",
                "md_content": "x",
            }
            return [document]

        with patch("tools.documents.doc_client.search", side_effect=fake_search):
            rendered = DocumentSearchTool().execute(
                # Contains a duplicate -> deduplicated down to q1/q2.
                queries=["q1", "q2", "q1"],
            )

        # Only two backend calls remain after deduplication.
        self.assertEqual(sorted(seen_queries), ["q1", "q2"])
        for fragment in ("[1/2]", "[2/2]", "'q1'", "'q2'"):
            self.assertIn(fragment, rendered)

    def test_document_download_uses_constructor_working_dir(self):
        # The working_dir given to the constructor must be forwarded (as a
        # string) to doc_client.download, and the reported path must include
        # the "task/" prefix.
        from tools.documents import DocumentDownloadTool

        with tempfile.TemporaryDirectory() as tmp:
            task_dir = self._make_task_dir(tmp)
            download_patch = patch(
                "tools.documents.doc_client.download",
                return_value="documents/paper.pdf",
            )
            with download_patch as download_mock:
                downloader = DocumentDownloadTool(
                    working_dir=task_dir,
                    base_dir=task_dir,
                    user_root=Path(tmp),
                )
                request = {"file_name": "paper.pdf", "kb_name": "mu_1"}
                rendered = downloader.execute(items=[request])

            download_mock.assert_called_once_with(
                file_name="paper.pdf",
                kb_name="mu_1",
                working_dir=str(task_dir),
                preview=False,
            )
            self.assertIn("saved: task/documents/paper.pdf", rendered)

    def test_document_download_batches_items_isolating_failure(self):
        # Two items where the second download raises: the first must still be
        # reported as saved and the failure must show up as an "[Error]" entry.
        from tools.documents import DocumentDownloadTool

        def fake_download(file_name, kb_name, working_dir, preview):
            if file_name == "bad.pdf":
                raise RuntimeError("404")
            return f"documents/{file_name}"

        with tempfile.TemporaryDirectory() as tmp:
            task_dir = self._make_task_dir(tmp)
            with patch(
                "tools.documents.doc_client.download", side_effect=fake_download
            ):
                downloader = DocumentDownloadTool(
                    working_dir=task_dir,
                    base_dir=task_dir,
                    user_root=Path(tmp),
                )
                requests = [
                    {"file_name": "ok.pdf", "kb_name": "mu_1"},
                    {"file_name": "bad.pdf", "kb_name": "mu_1"},
                ]
                rendered = downloader.execute(items=requests)

            # One failing item must not take the other one down with it.
            for fragment in ("saved: task/documents/ok.pdf", "[Error]", "bad.pdf"):
                self.assertIn(fragment, rendered)


class TestMaterialsProjectHostTools(unittest.TestCase):
    """Tests for MaterialsProjectSearchSummaryTool against a fake MPRester."""

    def test_mp_search_summary_uses_host_key_and_returns_json(self):
        # With MP_API_KEY set in the environment and MPRester replaced by a
        # fake, the tool must return JSON for the requested fields, must not
        # echo the key, and must request a single chunk of `limit` records.
        from tools.materials_project import MaterialsProjectSearchSummaryTool

        # Keyword arguments the tool passes to summary.search().
        search_kwargs = {}

        class FakeDoc:
            material_id = "mp-1"
            formula_pretty = "Ca3SiO5"
            energy_above_hull = 0.0123

        class FakeSummary:
            def search(self, **kwargs):
                search_kwargs.update(kwargs)
                return [FakeDoc()]

        class FakeMaterials:
            def __init__(self):
                self.summary = FakeSummary()

        class FakeMPRester:
            def __init__(self, api_key):
                self.api_key = api_key
                self.materials = FakeMaterials()

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                # Never suppress an exception raised inside the with-body.
                return False

        with patch.dict("os.environ", {"MP_API_KEY": "host-secret"}, clear=False):
            with patch("tools.materials_project.MPRester", FakeMPRester):
                raw = MaterialsProjectSearchSummaryTool().execute(
                    formula="Ca3SiO5",
                    fields=["material_id", "formula_pretty", "energy_above_hull"],
                    limit=2,
                )

        first = json.loads(raw)[0]
        expected_fields = {
            "material_id": "mp-1",
            "formula_pretty": "Ca3SiO5",
            "energy_above_hull": 0.0123,
        }
        for key, expected in expected_fields.items():
            self.assertEqual(first[key], expected)
        self.assertNotIn("host-secret", raw)

        # Server-side limiting: a single page of `limit` records, never the
        # default chunk_size=1000 full-database pagination that gets the IP
        # banned.
        self.assertEqual(search_kwargs["num_chunks"], 1)
        self.assertEqual(search_kwargs["chunk_size"], 2)


if __name__ == "__main__":
    unittest.main()