256 lines
9.2 KiB
Python
256 lines
9.2 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import sys
|
|
import tempfile
|
|
import unittest
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
|
|
# Put the directory one level above this file's folder at the front of sys.path.
# NOTE(review): presumably this is the project root, so that the `tools.*` and
# `core.*` imports used inside the tests resolve when the file is run directly — confirm.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
|
|
|
|
|
class TestDocumentHostTools(unittest.TestCase):
    """Tests for the document search/download tools with ``doc_client`` patched out."""

    def test_document_search_truncates_content_without_requiring_key_arg(self):
        # One long document comes back from the patched client; the rendered
        # output must mention it and carry the "truncated" marker.
        from tools.documents import DocumentSearchTool

        long_doc = {
            "file_name": "paper.md",
            "kb_name": "mu_1",
            "character_count": 50000,
            "md_content": "A" * 200,
        }
        with patch(
            "tools.documents.doc_client.search", return_value=[long_doc]
        ) as search_mock:
            # Single-query batch: when the queries list holds only one entry,
            # the down-scaling logic leaves the user-supplied parameters untouched.
            tool = DocumentSearchTool()
            rendered = tool.execute(
                queries=["cement hydration"],
                max_documents=3,
                content_chars_per_doc=20,
            )

        search_mock.assert_called_once_with(
            query="cement hydration",
            kb_names=None,
            classification_ids=None,
            max_documents=3,
        )
        for fragment in ("paper.md", "mu_1", "A" * 20, "truncated"):
            self.assertIn(fragment, rendered)

    def test_document_search_batches_queries_concurrently_and_dedups(self):
        # Every query that reaches the patched client is recorded so the
        # de-duplication can be checked afterwards.
        from tools.documents import DocumentSearchTool

        seen_queries: list[str] = []

        def fake_search(query, **kwargs):
            seen_queries.append(query)
            hit = {"file_name": f"{query}.md", "kb_name": "mu_1", "md_content": "x"}
            return [hit]

        with patch("tools.documents.doc_client.search", side_effect=fake_search):
            rendered = DocumentSearchTool().execute(
                queries=["q1", "q2", "q1"],  # contains a duplicate -> deduplicated to q1/q2
            )

        # Only two client calls remain after de-duplication.
        self.assertEqual(sorted(seen_queries), ["q1", "q2"])
        for fragment in ("[1/2]", "[2/2]", "'q1'", "'q2'"):
            self.assertIn(fragment, rendered)

    def test_document_download_uses_constructor_working_dir(self):
        # The working_dir given to the constructor must be the one forwarded
        # (as a string) to the patched download call.
        from tools.documents import DocumentDownloadTool

        with tempfile.TemporaryDirectory() as tmp_root:
            user_root = Path(tmp_root)
            task_dir = user_root / "task"
            task_dir.mkdir()

            download_patch = patch(
                "tools.documents.doc_client.download",
                return_value="documents/paper.pdf",
            )
            with download_patch as download_mock:
                tool = DocumentDownloadTool(
                    working_dir=task_dir,
                    base_dir=task_dir,
                    user_root=user_root,
                )
                wanted = {"file_name": "paper.pdf", "kb_name": "mu_1"}
                rendered = tool.execute(items=[wanted])

            download_mock.assert_called_once_with(
                file_name="paper.pdf",
                kb_name="mu_1",
                working_dir=str(task_dir),
                preview=False,
            )
            self.assertIn("saved: task/documents/paper.pdf", rendered)

    def test_document_download_batches_items_isolating_failure(self):
        # Two items are requested; the patched client raises for one of them.
        from tools.documents import DocumentDownloadTool

        with tempfile.TemporaryDirectory() as tmp_root:
            user_root = Path(tmp_root)
            task_dir = user_root / "task"
            task_dir.mkdir()

            def fake_download(file_name, kb_name, working_dir, preview):
                if file_name != "bad.pdf":
                    return f"documents/{file_name}"
                raise RuntimeError("404")

            with patch("tools.documents.doc_client.download", side_effect=fake_download):
                tool = DocumentDownloadTool(
                    working_dir=task_dir, base_dir=task_dir, user_root=user_root
                )
                requested = [
                    {"file_name": "ok.pdf", "kb_name": "mu_1"},
                    {"file_name": "bad.pdf", "kb_name": "mu_1"},
                ]
                rendered = tool.execute(items=requested)

            # One failing item must not take the other one down with it.
            for fragment in ("saved: task/documents/ok.pdf", "[Error]", "bad.pdf"):
                self.assertIn(fragment, rendered)
|
|
|
|
|
|
class TestMaterialsProjectHostTools(unittest.TestCase):
    """Tests for the Materials Project summary-search tool with ``MPRester`` replaced."""

    def test_mp_search_summary_uses_host_key_and_returns_json(self):
        # The real MPRester is swapped for a small stand-in whose
        # ``materials.summary.search`` records the keyword arguments it receives.
        from tools.materials_project import MaterialsProjectSearchSummaryTool

        search_kwargs = {}

        class FakeDoc:
            material_id = "mp-1"
            formula_pretty = "Ca3SiO5"
            energy_above_hull = 0.0123

        class FakeSummary:
            def search(self, **kwargs):
                search_kwargs.update(kwargs)
                return [FakeDoc()]

        class FakeMaterials:
            def __init__(self):
                self.summary = FakeSummary()

        class FakeMPRester:
            def __init__(self, api_key):
                self.api_key = api_key
                self.materials = FakeMaterials()

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                # Never suppress an exception raised inside the ``with`` body.
                return False

        host_env = {"MP_API_KEY": "host-secret"}
        with patch.dict("os.environ", host_env, clear=False):
            with patch("tools.materials_project.MPRester", FakeMPRester):
                tool = MaterialsProjectSearchSummaryTool()
                out = tool.execute(
                    formula="Ca3SiO5",
                    fields=["material_id", "formula_pretty", "energy_above_hull"],
                    limit=2,
                )

        first_record = json.loads(out)[0]
        expected_fields = {
            "material_id": "mp-1",
            "formula_pretty": "Ca3SiO5",
            "energy_above_hull": 0.0123,
        }
        for field_name, expected_value in expected_fields.items():
            self.assertEqual(first_record[field_name], expected_value)

        # The host API key must not appear in the tool output.
        self.assertNotIn("host-secret", out)

        # Server-side limiting: a single page of `limit` records, never the
        # default chunk_size=1000 full-database pagination that gets the IP banned.
        self.assertEqual(search_kwargs["num_chunks"], 1)
        self.assertEqual(search_kwargs["chunk_size"], 2)
|
|
|
|
|
|
class TestHostFileToolPathResolution(unittest.TestCase):
    """send_email / wechat_push read attachments in the host process.

    Relative paths and container ``/workspace`` paths supplied by the agent must
    be translated back to the host task directory, while out-of-bounds paths are
    still rejected. Regression test for the docker-mode bug where files could
    not be sent because of an "attachment path out of bounds" error.
    """

    def _mk(self):
        """Build a temporary ``users/uid/wechat-abc`` tree containing ``report.txt``.

        The temporary directory is registered with ``addCleanup`` so it is
        removed when the test finishes, even if a later step of this fixture
        or the test body raises.

        Returns:
            A ``(tmp, user_root, wd)`` tuple: the ``TemporaryDirectory`` object,
            the ``users/uid`` directory and the ``wechat-abc`` task directory.
        """
        tmp = tempfile.TemporaryDirectory()
        # Register cleanup immediately so nothing below can leak the directory.
        self.addCleanup(tmp.cleanup)
        root = Path(tmp.name).resolve()
        user_root = root / "users" / "uid"
        wd = user_root / "wechat-abc"  # host task directory (= container /workspace/wechat-abc)
        wd.mkdir(parents=True)
        (wd / "report.txt").write_text("x", encoding="utf-8")
        return tmp, user_root, wd

    def test_resolve_user_file_translates_and_bounds(self):
        from tools.base import FileOutOfBounds, Tool

        class _T(Tool):
            # Minimal concrete Tool; only _resolve_user_file is exercised here.
            def execute(self, **k):
                return ""

        _tmp, user_root, wd = self._mk()
        t = _T(base_dir=wd, user_root=user_root)
        expected = (wd / "report.txt").resolve()

        # Relative path -> host task directory
        # (original bug: joined onto cwd and judged out of bounds).
        self.assertEqual(t._resolve_user_file("report.txt"), expected)

        # Container absolute path /workspace/... -> mapped back under user_root
        # (original bug: resolved on the host and judged out of bounds).
        self.assertEqual(
            t._resolve_user_file("/workspace/wechat-abc/report.txt"),
            expected,
        )

        # Out-of-bounds paths are still blocked; subTest reports each path separately.
        for bad in ("../../../etc/passwd", "/workspace/../../etc/passwd"):
            with self.subTest(path=bad):
                with self.assertRaises(FileOutOfBounds):
                    t._resolve_user_file(bad)

    def test_send_email_attaches_container_path(self):
        from tools.send_email import SendEmailTool

        _tmp, user_root, wd = self._mk()
        tool = SendEmailTool(base_dir=wd, user_root=user_root)
        captured = {}

        def fake_send(to, subject, body, attachments=None, **k):
            # Record what the tool hands to the SMTP layer.
            captured["attachments"] = list(attachments or [])

        with patch("tools.send_email.smtp_configured", return_value=True), patch(
            "tools.send_email.send_email_smtp", side_effect=fake_send
        ):
            out = tool.execute(
                to=["a@b.com"],
                subject="s",
                body="b",
                attachments=["/workspace/wechat-abc/report.txt"],
            )

        self.assertIn("[ok]", out)
        self.assertIn("含 1 个附件", out)
        self.assertEqual(captured["attachments"], [(wd / "report.txt").resolve()])

    def test_wechat_push_resolves_relative_file(self):
        from tools.wechat_bot import WechatPushTool
        from uuid import uuid4

        _tmp, user_root, wd = self._mk()
        tool = WechatPushTool(uuid4(), base_dir=wd, user_root=user_root)
        captured = {}

        class _Report:
            # Stand-in for the object returned by the patched send_to_user.
            delivered = True
            results = []

        def fake_send(uid, text, fpath):
            # Record the file path the tool passes on.
            captured["fpath"] = fpath
            return _Report()

        with patch("core.wechat.service.send_to_user", side_effect=fake_send):
            out = tool.execute(text="给你文件", file="report.txt")

        self.assertIn("[ok]", out)
        self.assertEqual(captured["fpath"], str((wd / "report.txt").resolve()))
|
|
|
|
|
|
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|