"""Host-side Materials Project tools. MP_API_KEY stays on the host. The sandbox can use offline pymatgen on files written to task_dir, but it must not receive the Materials Project key. """ from __future__ import annotations import json import os from pathlib import Path from typing import Any, Optional from .base import Tool try: # patched in tests; missing dependency should produce a clean tool error. from mp_api.client import MPRester # type: ignore except Exception: # pragma: no cover - exercised when mp-api is not installed MPRester = None # type: ignore _DEFAULT_SUMMARY_FIELDS = [ "material_id", "formula_pretty", "symmetry", "energy_above_hull", ] def _mp_key() -> str: key = os.environ.get("MP_API_KEY", "").strip() if not key: raise RuntimeError("MP_API_KEY env 未设置,无法查询 Materials Project") return key def _to_plain(obj: Any) -> Any: if obj is None or isinstance(obj, (str, int, float, bool)): return obj if isinstance(obj, (list, tuple)): return [_to_plain(x) for x in obj] if isinstance(obj, dict): return {str(k): _to_plain(v) for k, v in obj.items()} if hasattr(obj, "model_dump"): return _to_plain(obj.model_dump()) if hasattr(obj, "as_dict"): return _to_plain(obj.as_dict()) # mp-api documents expose fields as attributes. out: dict[str, Any] = {} for name in _DEFAULT_SUMMARY_FIELDS: if hasattr(obj, name): out[name] = _to_plain(getattr(obj, name)) return out or str(obj) def _mpr(): if MPRester is None: raise RuntimeError("mp-api 未安装,请在宿主环境安装 mp-api 后再启用 MP tool") return MPRester(_mp_key()) class MaterialsProjectSearchSummaryTool(Tool): name = "mp_search_summary" description = ( "Search Materials Project summary data using the host MP_API_KEY. " "Returns trimmed JSON; use mp_get_structure to save a CIF for offline pymatgen analysis." ) parameters = { "type": "object", "properties": { "formula": {"type": "string", "description": "Optional formula such as Ca3SiO5."}, "material_ids": {"type": "array", "items": {"type": "string"}, "description": "Optional Materials Project ids such as mp-123."}, "elements": {"type": "array", "items": {"type": "string"}, "description": "Optional element symbols for a chemical system search."}, "fields": {"type": "array", "items": {"type": "string"}, "description": "Fields to return; defaults to material_id/formula_pretty/symmetry/energy_above_hull. Do NOT put the search formula here — use formula_pretty (not 'formula'). Common fields: formula_pretty, symmetry, energy_above_hull, band_gap, density, volume, is_stable, structure."}, "limit": {"type": "integer", "default": 10, "description": "Maximum records returned, 1-50. Server-side bounded — keep small to stay within MP fair-use."}, }, } def execute( self, formula: str = "", material_ids: Optional[list[str]] = None, elements: Optional[list[str]] = None, fields: Optional[list[str]] = None, limit: int = 10, ) -> str: limit = min(max(int(limit), 1), 50) chosen_fields = fields or _DEFAULT_SUMMARY_FIELDS kwargs: dict[str, Any] = {"fields": chosen_fields} if formula: kwargs["formula"] = formula if material_ids: kwargs["material_ids"] = material_ids if elements: kwargs["elements"] = elements if len(kwargs) == 1: return "[Error] formula / material_ids / elements 至少传一个" try: with _mpr() as mpr: # num_chunks=1 + chunk_size=limit 让服务端单次只回 limit 条、不翻页。 # 否则 mp-api 默认 chunk_size=1000 且自动拉完所有页 —— 整库级下载, # 会被 MP 判定为 abusive traffic 并封 IP/ASN。 docs = mpr.materials.summary.search( num_chunks=1, chunk_size=limit, **kwargs ) except Exception as e: return f"[Error] mp_search_summary failed: {type(e).__name__}: {e}" plain = [_to_plain(d) for d in list(docs)[:limit]] return json.dumps(plain, ensure_ascii=False, indent=2) class MaterialsProjectGetStructureTool(Tool): name = "mp_get_structure" description = ( "Download a Materials Project structure by material_id and save it as CIF in task_dir/materials/." ) parameters = { "type": "object", "properties": { "material_id": {"type": "string", "description": "Materials Project id, e.g. mp-123."}, "filename": {"type": "string", "description": "Optional CIF filename. Defaults to .cif."}, }, "required": ["material_id"], } def __init__( self, *, working_dir: Path, base_dir: Optional[Path] = None, user_root: Optional[Path] = None, ) -> None: super().__init__(base_dir=base_dir, user_root=user_root) self.working_dir = Path(working_dir) def execute(self, material_id: str, filename: str = "") -> str: material_id = (material_id or "").strip() if not material_id: return "[Error] material_id 不可为空" safe_name = (filename or f"{material_id}.cif").replace("/", "_").replace("\\", "_").replace("..", "_") if not safe_name.lower().endswith(".cif"): safe_name += ".cif" dest = self.working_dir / "materials" / safe_name try: with _mpr() as mpr: struct = mpr.get_structure_by_material_id(material_id) dest.parent.mkdir(parents=True, exist_ok=True) struct.to(filename=str(dest)) except Exception as e: return f"[Error] mp_get_structure failed: {type(e).__name__}: {e}" return f"saved: {self._display(dest)}" class MaterialsProjectGetEntriesTool(Tool): name = "mp_get_entries" description = ( "Fetch Materials Project computed entries for a chemical system and save trimmed JSON to task_dir/materials/. " "Downloads the FULL chemical system (all sub-systems) — volume grows fast with element count; call sparingly and reuse the saved file rather than re-querying." ) parameters = { "type": "object", "properties": { "elements": {"type": "array", "items": {"type": "string"}, "description": "Chemical system elements, e.g. ['Ca','Si','O','H']. More elements = much heavier download."}, "filename": {"type": "string", "description": "Optional JSON filename. Defaults to mp_entries_.json."}, "limit": {"type": "integer", "default": 200, "description": "Max entries SAVED to disk, 1-1000. NOTE: this only trims the saved JSON — the full chemsys is still fetched from MP. It does not reduce network traffic."}, }, "required": ["elements"], } def __init__( self, *, working_dir: Path, base_dir: Optional[Path] = None, user_root: Optional[Path] = None, ) -> None: super().__init__(base_dir=base_dir, user_root=user_root) self.working_dir = Path(working_dir) def execute( self, elements: list[str], filename: str = "", limit: int = 200, ) -> str: elems = [e.strip() for e in (elements or []) if str(e).strip()] if not elems: return "[Error] elements 不可为空" limit = min(max(int(limit), 1), 1000) chemsys = "-".join(elems) safe_name = filename or f"mp_entries_{chemsys}.json" safe_name = safe_name.replace("/", "_").replace("\\", "_").replace("..", "_") if not safe_name.lower().endswith(".json"): safe_name += ".json" dest = self.working_dir / "materials" / safe_name try: with _mpr() as mpr: entries = mpr.get_entries_in_chemsys(elems) payload = [_to_plain(e) for e in list(entries)[:limit]] dest.parent.mkdir(parents=True, exist_ok=True) dest.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") except Exception as e: return f"[Error] mp_get_entries failed: {type(e).__name__}: {e}" return f"saved: {self._display(dest)}"