198 lines
7.4 KiB
Python
198 lines
7.4 KiB
Python
"""Host-side Materials Project tools.

MP_API_KEY stays on the host. The sandbox can use offline pymatgen on files
written to task_dir, but it must not receive the Materials Project key.
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Optional
|
|
|
|
from .base import Tool
|
|
|
|
try: # patched in tests; missing dependency should produce a clean tool error.
|
|
from mp_api.client import MPRester # type: ignore
|
|
except Exception: # pragma: no cover - exercised when mp-api is not installed
|
|
MPRester = None # type: ignore
|
|
|
|
|
|
# Summary fields requested when the caller does not pass ``fields``; also the
# attribute names probed by ``_to_plain`` for attribute-style documents.
_DEFAULT_SUMMARY_FIELDS = [
    "material_id",
    "formula_pretty",
    "symmetry",
    "energy_above_hull",
]
|
|
|
|
|
|
def _mp_key() -> str:
    """Return the Materials Project API key from the host environment.

    Reads the ``MP_API_KEY`` environment variable and strips surrounding
    whitespace.

    Raises:
        RuntimeError: if the variable is unset or blank after stripping.
    """
    api_key = os.environ.get("MP_API_KEY", "").strip()
    if api_key:
        return api_key
    raise RuntimeError("MP_API_KEY env 未设置,无法查询 Materials Project")
|
|
|
|
|
|
def _to_plain(obj: Any) -> Any:
    """Recursively convert *obj* into JSON-friendly builtins.

    Conversion rules, tried in this order:

    * ``None``, ``str``, ``int``, ``float`` and ``bool`` are returned as-is.
    * Lists and tuples become lists of converted items.
    * Dicts become dicts with ``str`` keys and converted values.
    * Objects exposing ``model_dump()`` and then ``as_dict()`` are converted
      from the result of that call.
    * Anything else is probed for the attributes named in
      ``_DEFAULT_SUMMARY_FIELDS``; if none is present, ``str(obj)`` is
      returned instead.
    """
    if obj is None or isinstance(obj, (str, int, float, bool)):
        return obj
    if isinstance(obj, (list, tuple)):
        return list(map(_to_plain, obj))
    if isinstance(obj, dict):
        return {str(key): _to_plain(value) for key, value in obj.items()}

    # Prefer model_dump() over as_dict() when an object offers both.
    for dumper in ("model_dump", "as_dict"):
        if hasattr(obj, dumper):
            return _to_plain(getattr(obj, dumper)())

    # mp-api documents expose fields as attributes.
    picked: dict[str, Any] = {
        field: _to_plain(getattr(obj, field))
        for field in _DEFAULT_SUMMARY_FIELDS
        if hasattr(obj, field)
    }
    return picked or str(obj)
|
|
|
|
|
|
def _mpr():
    """Create an ``MPRester`` client using the host's API key.

    Raises:
        RuntimeError: if ``mp_api`` could not be imported (``MPRester`` is
            ``None``), or, via ``_mp_key``, if ``MP_API_KEY`` is not set.
    """
    if MPRester is not None:
        return MPRester(_mp_key())
    raise RuntimeError("mp-api 未安装,请在宿主环境安装 mp-api 后再启用 MP tool")
|
|
|
|
|
|
class MaterialsProjectSearchSummaryTool(Tool):
    """Search Materials Project summary documents from the host.

    The query runs through ``_mpr()``, so the API key is read on the host.
    The result is returned to the caller as a JSON string.
    """

    name = "mp_search_summary"
    description = (
        "Search Materials Project summary data using the host MP_API_KEY. "
        "Returns trimmed JSON; use mp_get_structure to save a CIF for offline pymatgen analysis."
    )
    parameters = {
        "type": "object",
        "properties": {
            "formula": {
                "type": "string",
                "description": "Optional formula such as Ca3SiO5.",
            },
            "material_ids": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Optional Materials Project ids such as mp-123.",
            },
            "elements": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Optional element symbols for a chemical system search.",
            },
            "fields": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Fields to return; defaults to material_id/formula/symmetry/energy_above_hull.",
            },
            "limit": {
                "type": "integer",
                "default": 10,
                "description": "Maximum records returned, 1-50.",
            },
        },
    }

    def execute(
        self,
        formula: str = "",
        material_ids: Optional[list[str]] = None,
        elements: Optional[list[str]] = None,
        fields: Optional[list[str]] = None,
        limit: int = 10,
    ) -> str:
        """Run a summary search and return the matching documents as JSON.

        Args:
            formula: Formula filter; surrounding whitespace is ignored.
            material_ids: Materials Project ids to look up.
            elements: Element symbols to filter on.
            fields: Summary fields to request; defaults to
                ``_DEFAULT_SUMMARY_FIELDS``.
            limit: Maximum number of records returned, clamped to 1-50.

        Returns:
            A JSON array (as text) of plain documents, or a string starting
            with ``[Error]`` when the arguments are invalid or the query
            fails. This method does not raise for those cases.
        """
        # Validate here so a non-numeric/null limit yields a tool error
        # instead of an uncaught exception.
        try:
            limit = min(max(int(limit), 1), 50)
        except (TypeError, ValueError, OverflowError):
            return f"[Error] limit 必须是整数: {limit!r}"

        if isinstance(formula, str):
            # A whitespace-only formula is not a real search criterion.
            formula = formula.strip()

        # Copy the default so the shared module-level list is never handed
        # to third-party code.
        kwargs: dict[str, Any] = {"fields": fields or list(_DEFAULT_SUMMARY_FIELDS)}
        if formula:
            kwargs["formula"] = formula
        if material_ids:
            kwargs["material_ids"] = material_ids
        if elements:
            kwargs["elements"] = elements
        if not (formula or material_ids or elements):
            return "[Error] formula / material_ids / elements 至少传一个"

        try:
            with _mpr() as mpr:
                docs = mpr.materials.summary.search(**kwargs)
            # Conversion stays inside the error boundary: iterating or
            # dumping the returned documents may itself fail.
            plain = [_to_plain(d) for d in list(docs)[:limit]]
            return json.dumps(plain, ensure_ascii=False, indent=2)
        except Exception as e:
            return f"[Error] mp_search_summary failed: {type(e).__name__}: {e}"
|
|
|
|
|
|
class MaterialsProjectGetStructureTool(Tool):
    """Download one Materials Project structure and save it under ``materials/``.

    The file is written to ``<working_dir>/materials/<filename>``; the API
    call goes through ``_mpr()`` so the key is read on the host.
    """

    name = "mp_get_structure"
    description = (
        "Download a Materials Project structure by material_id and save it as CIF in task_dir/materials/."
    )
    parameters = {
        "type": "object",
        "properties": {
            "material_id": {
                "type": "string",
                "description": "Materials Project id, e.g. mp-123.",
            },
            "filename": {
                "type": "string",
                "description": "Optional CIF filename. Defaults to <material_id>.cif.",
            },
        },
        "required": ["material_id"],
    }

    def __init__(
        self,
        *,
        working_dir: Path,
        base_dir: Optional[Path] = None,
        user_root: Optional[Path] = None,
    ) -> None:
        """Store the task directory and forward the rest to the base ``Tool``.

        Args:
            working_dir: Directory under which ``materials/`` is created.
            base_dir: Passed through to ``Tool.__init__``.
            user_root: Passed through to ``Tool.__init__``.
        """
        super().__init__(base_dir=base_dir, user_root=user_root)
        self.working_dir = Path(working_dir)

    def execute(self, material_id: str, filename: str = "") -> str:
        """Fetch the structure for *material_id* and write it as a ``.cif`` file.

        Args:
            material_id: Materials Project id; surrounding whitespace is
                stripped and non-string values are coerced with ``str``.
            filename: Target file name. Path separators and ``..`` are
                replaced by ``_`` and a ``.cif`` suffix is appended when
                missing. Defaults to ``<material_id>.cif``.

        Returns:
            ``"saved: <path>"`` on success, otherwise a string starting with
            ``[Error]``. This method does not raise for those cases.
        """
        # str() keeps a non-string id (e.g. a JSON number) from raising
        # before the error boundary below.
        material_id = str(material_id or "").strip()
        if not material_id:
            return "[Error] material_id 不可为空"

        safe_name = str(filename or f"{material_id}.cif")
        for token in ("/", "\\", ".."):
            safe_name = safe_name.replace(token, "_")
        if not safe_name.lower().endswith(".cif"):
            safe_name += ".cif"

        materials_dir = self.working_dir / "materials"
        dest = materials_dir / safe_name
        # A drive-qualified name (e.g. "D:x.cif" on Windows) re-anchors the
        # joined path; refuse anything that does not sit directly inside
        # materials/.
        if dest.parent != materials_dir:
            return f"[Error] filename 非法: {safe_name!r}"

        try:
            with _mpr() as mpr:
                struct = mpr.get_structure_by_material_id(material_id)
                dest.parent.mkdir(parents=True, exist_ok=True)
                struct.to(filename=str(dest))
        except Exception as e:
            return f"[Error] mp_get_structure failed: {type(e).__name__}: {e}"
        # _display is not defined in this class; presumably provided by Tool.
        return f"saved: {self._display(dest)}"
|
|
|
|
|
|
class MaterialsProjectGetEntriesTool(Tool):
    """Fetch computed entries for a chemical system and save them as JSON.

    The file is written to ``<working_dir>/materials/<filename>``; the API
    call goes through ``_mpr()`` so the key is read on the host.
    """

    name = "mp_get_entries"
    description = (
        "Fetch Materials Project computed entries for a chemical system and save trimmed JSON to task_dir/materials/."
    )
    parameters = {
        "type": "object",
        "properties": {
            "elements": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Chemical system elements, e.g. ['Ca','Si','O','H'].",
            },
            "filename": {
                "type": "string",
                "description": "Optional JSON filename. Defaults to mp_entries_<chemsys>.json.",
            },
            "limit": {
                "type": "integer",
                "default": 200,
                "description": "Maximum entries saved, 1-1000.",
            },
        },
        "required": ["elements"],
    }

    def __init__(
        self,
        *,
        working_dir: Path,
        base_dir: Optional[Path] = None,
        user_root: Optional[Path] = None,
    ) -> None:
        """Store the task directory and forward the rest to the base ``Tool``.

        Args:
            working_dir: Directory under which ``materials/`` is created.
            base_dir: Passed through to ``Tool.__init__``.
            user_root: Passed through to ``Tool.__init__``.
        """
        super().__init__(base_dir=base_dir, user_root=user_root)
        self.working_dir = Path(working_dir)

    def execute(
        self,
        elements: list[str],
        filename: str = "",
        limit: int = 200,
    ) -> str:
        """Download entries for the given elements and write them to a JSON file.

        Args:
            elements: Element symbols. A single string such as ``"Ca-Si-O"``
                (``-``, ``,`` or space separated) is also accepted. Blank
                and ``None`` items are dropped.
            filename: Target file name. Path separators and ``..`` are
                replaced by ``_`` and a ``.json`` suffix is appended when
                missing. Defaults to ``mp_entries_<chemsys>.json``.
            limit: Maximum number of entries saved, clamped to 1-1000.

        Returns:
            ``"saved: <path>"`` on success, otherwise a string starting with
            ``[Error]``. This method does not raise for those cases.
        """
        if isinstance(elements, str):
            # Iterating a bare string would yield single characters, so
            # split it into symbols first.
            elements = elements.replace(",", "-").replace(" ", "-").split("-")
        # str() on the stored value (not only in the filter) keeps a
        # non-string item from raising AttributeError.
        cleaned = (str(e).strip() for e in (elements or []) if e is not None)
        elems = [symbol for symbol in cleaned if symbol]
        if not elems:
            return "[Error] elements 不可为空"

        # Validate here so a non-numeric/null limit yields a tool error
        # instead of an uncaught exception.
        try:
            limit = min(max(int(limit), 1), 1000)
        except (TypeError, ValueError, OverflowError):
            return f"[Error] limit 必须是整数: {limit!r}"

        chemsys = "-".join(elems)
        safe_name = str(filename or f"mp_entries_{chemsys}.json")
        for token in ("/", "\\", ".."):
            safe_name = safe_name.replace(token, "_")
        if not safe_name.lower().endswith(".json"):
            safe_name += ".json"

        materials_dir = self.working_dir / "materials"
        dest = materials_dir / safe_name
        # A drive-qualified name (e.g. "D:x.json" on Windows) re-anchors the
        # joined path; refuse anything that does not sit directly inside
        # materials/.
        if dest.parent != materials_dir:
            return f"[Error] filename 非法: {safe_name!r}"

        try:
            with _mpr() as mpr:
                entries = mpr.get_entries_in_chemsys(elems)
                payload = [_to_plain(e) for e in list(entries)[:limit]]
                dest.parent.mkdir(parents=True, exist_ok=True)
                dest.write_text(
                    json.dumps(payload, ensure_ascii=False, indent=2),
                    encoding="utf-8",
                )
        except Exception as e:
            return f"[Error] mp_get_entries failed: {type(e).__name__}: {e}"
        # _display is not defined in this class; presumably provided by Tool.
        return f"saved: {self._display(dest)}"
|