diff --git a/.env.example b/.env.example index 51052fc..86e4527 100644 --- a/.env.example +++ b/.env.example @@ -82,4 +82,29 @@ LEMATERIALFETCHER_DEST_TABLE_NAME=optimade_materials # Transformer processing settings # LEMATERIALFETCHER_BATCH_SIZE=500 # LEMATERIALFETCHER_OFFSET=0 -# LEMATERIALFETCHER_LOG_EVERY=1000 +# LEMATERIALFETCHER_LOG_EVERY=1000 + +# ------------------------------------------------------------------------------ +# LeMatRho Configuration (charge density pipeline) +# ------------------------------------------------------------------------------ + +# AWS credentials for authenticated S3 access (LeMatRho bucket) +# AWS_ACCESS_KEY_ID=your_access_key +# AWS_SECRET_ACCESS_KEY=your_secret_key +# AWS_DEFAULT_REGION=us-east-1 + +# LeMatRho S3 bucket name +# LEMATERIALFETCHER_LEMATRHO_BUCKET_NAME=lemat-rho + +# VASP pseudopotential directory (required for Bader/DDEC6 analysis) +# PMG_VASP_PSP_DIR=/path/to/vasp/pseudopotentials + +# External tool paths (optional, auto-detected on PATH if not set) +# LEMATERIALFETCHER_BADER_PATH=/path/to/bader +# LEMATERIALFETCHER_CHARGEMOL_PATH=/path/to/chargemol +# LEMATERIALFETCHER_CHGSUM_SCRIPT_PATH=/path/to/chgsum.pl +# LEMATERIALFETCHER_ATOMIC_DENSITIES_PATH=/path/to/atomic_densities + +# HuggingFace (for pushing dataset after pipeline completes) +# HF_REPO_ID=your-org/lematrho-dataset +# HF_TOKEN=hf_your_token diff --git a/.gitignore b/.gitignore index a682896..21d486a 100644 --- a/.gitignore +++ b/.gitignore @@ -127,6 +127,8 @@ celerybeat.pid # Environments .env +.env.* +!.env.example .venv # env/ venv/ diff --git a/pyproject.toml b/pyproject.toml index cb18c18..a9b7855 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "moyopy>=0.4.2", "ase>=3.24.0", "material-hasher", + "mp-pyrho>=0.3.1", ] [project.scripts] @@ -58,5 +59,10 @@ dev-dependencies = [ material-hasher = { git = "https://github.com/LeMaterial/lematerial-hasher.git" } +[tool.pytest.ini_options] +markers = [ + "integration: 
tests that require real AWS credentials and S3 access (deselect with '-m \"not integration\"')", +] + [tool.ruff.lint] extend-select = ["I"] diff --git a/src/lematerial_fetcher/cli.py b/src/lematerial_fetcher/cli.py index 95ec89c..ce61aa8 100644 --- a/src/lematerial_fetcher/cli.py +++ b/src/lematerial_fetcher/cli.py @@ -28,6 +28,9 @@ MPTrajectoryTransformer, MPTransformer, ) +from lematerial_fetcher.fetcher.lematrho.fetch import LeMatRhoFetcher +from lematerial_fetcher.fetcher.lematrho.pipeline import LeMatRhoDirectPipeline +from lematerial_fetcher.fetcher.lematrho.transform import LeMatRhoTransformer from lematerial_fetcher.fetcher.oqmd.fetch import OQMDFetcher from lematerial_fetcher.fetcher.oqmd.transform import ( OQMDTrajectoryTransformer, @@ -37,6 +40,9 @@ from lematerial_fetcher.utils.cli import ( add_common_options, add_fetch_options, + add_lematrho_direct_options, + add_lematrho_fetch_options, + add_lematrho_transform_options, add_mp_fetch_options, add_mysql_options, add_push_options, @@ -44,6 +50,7 @@ get_default_mp_bucket_name, ) from lematerial_fetcher.utils.config import ( + load_direct_pipeline_config, load_fetcher_config, load_push_config, load_transformer_config, @@ -114,9 +121,17 @@ def oqmd_cli(ctx): pass +@click.group(name="lematrho") +@click.pass_context +def lematrho_cli(ctx): + """Commands for fetching charge density data from LeMatRho.""" + pass + + cli.add_command(mp_cli) cli.add_command(alexandria_cli) cli.add_command(oqmd_cli) +cli.add_command(lematrho_cli) # ------------------------------------------------------------------------------ # MP commands @@ -341,6 +356,85 @@ def oqmd_transform(ctx, traj, **config_kwargs): logger.fatal("\nAborted.", exit=1) +# ------------------------------------------------------------------------------ +# LeMatRho commands +# ------------------------------------------------------------------------------ + + +@lematrho_cli.command(name="fetch") +@click.pass_context +@add_common_options +@add_fetch_options 
+@add_lematrho_fetch_options +def lematrho_fetch(ctx, **config_kwargs): + """Fetch charge density data from the LeMatRho S3 bucket. + + Downloads CHGCAR/AECCAR files, compresses charge densities via pyrho, + and stores compressed grids in the raw_structures database table. + + Requires AWS credentials (AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY). + """ + config_kwargs["base_url"] = "DUMMY_BASE_URL" # Not needed for LeMatRho + config_kwargs["mp_bucket_name"] = "" # Not needed for LeMatRho + config_kwargs["mp_bucket_prefix"] = "" # Not needed for LeMatRho + + # Convert grid_shape tuple from Click to proper format + if "grid_shape" in config_kwargs: + config_kwargs["lematrho_grid_shape"] = config_kwargs.pop("grid_shape") + + config = load_fetcher_config(**config_kwargs) + try: + fetcher = LeMatRhoFetcher(config=config, debug=ctx.obj["debug"]) + fetcher.fetch() + except KeyboardInterrupt: + logger.fatal("\nAborted.", exit=1) + + +@lematrho_cli.command(name="transform") +@click.pass_context +@add_common_options +@add_transformer_options +@add_lematrho_transform_options +def lematrho_transform(ctx, **config_kwargs): + """Transform raw LeMatRho structures into OPTIMADE format. + + Optionally runs Bader and DDEC6 charge analysis using external tools. 
+ + External tool requirements: + - bader executable (--bader-path or on PATH) + - perl + chgsum.pl script (--chgsum-script-path) + - chargemol executable (--chargemol-path or on PATH) + - PMG_VASP_PSP_DIR environment variable for POTCAR generation + - Atomic densities directory (--atomic-densities-path) for DDEC6 + """ + config = load_transformer_config(**config_kwargs) + try: + transformer = LeMatRhoTransformer(config=config, debug=ctx.obj["debug"]) + transformer.transform() + except KeyboardInterrupt: + logger.fatal("\nAborted.", exit=1) + + +@lematrho_cli.command(name="run") +@click.pass_context +@add_lematrho_direct_options +def lematrho_run(ctx, **config_kwargs): + """Run complete LeMatRho pipeline: S3 -> Parquet -> HuggingFace. + + Downloads charge density data, compresses via pyrho, optionally runs + Bader and DDEC6 analysis, and writes Parquet files directly. + No PostgreSQL required. + + Requires AWS credentials (AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY). + """ + config = load_direct_pipeline_config(**config_kwargs) + try: + pipeline = LeMatRhoDirectPipeline(config=config, debug=ctx.obj["debug"]) + pipeline.run() + except KeyboardInterrupt: + logger.fatal("\nAborted.", exit=1) + + # ------------------------------------------------------------------------------ # Push commands # ------------------------------------------------------------------------------ diff --git a/src/lematerial_fetcher/database/postgres.py b/src/lematerial_fetcher/database/postgres.py index df2416f..1094f81 100644 --- a/src/lematerial_fetcher/database/postgres.py +++ b/src/lematerial_fetcher/database/postgres.py @@ -494,6 +494,14 @@ def columns(cls) -> dict[str, str]: "space_group_it_number": "INTEGER", "cross_compatibility": "BOOLEAN", "bawl_fingerprint": "TEXT", + "compressed_charge_density": "JSONB", + "compressed_aeccar0": "JSONB", + "compressed_aeccar1": "JSONB", + "compressed_aeccar2": "JSONB", + "charge_density_grid_shape": "INTEGER[]", + "bader_charges": "FLOAT[]", + 
"bader_atomic_volume": "FLOAT[]", + "ddec6_charges": "FLOAT[]", } def _prepare_species_data(self, species: list[dict[str, Any]]) -> list[Json]: @@ -572,6 +580,14 @@ def insert_data(self, structure: OptimadeStructure) -> None: structure.space_group_it_number, structure.cross_compatibility, structure.bawl_fingerprint, + Json(structure.compressed_charge_density), + Json(structure.compressed_aeccar0), + Json(structure.compressed_aeccar1), + Json(structure.compressed_aeccar2), + structure.charge_density_grid_shape, + structure.bader_charges, + structure.bader_atomic_volume, + structure.ddec6_charges, ) cur.execute(query, input_data) self.conn.commit() @@ -639,6 +655,14 @@ def batch_insert_data( structure.space_group_it_number, structure.cross_compatibility, structure.bawl_fingerprint, + Json(structure.compressed_charge_density), + Json(structure.compressed_aeccar0), + Json(structure.compressed_aeccar1), + Json(structure.compressed_aeccar2), + structure.charge_density_grid_shape, + structure.bader_charges, + structure.bader_atomic_volume, + structure.ddec6_charges, ) ) @@ -740,6 +764,14 @@ def insert_data(self, structure: Trajectory) -> None: structure.space_group_it_number, structure.cross_compatibility, structure.bawl_fingerprint, + Json(structure.compressed_charge_density), + Json(structure.compressed_aeccar0), + Json(structure.compressed_aeccar1), + Json(structure.compressed_aeccar2), + structure.charge_density_grid_shape, + structure.bader_charges, + structure.bader_atomic_volume, + structure.ddec6_charges, # trajectory-specific fields structure.relaxation_step, structure.relaxation_number, @@ -810,6 +842,14 @@ def batch_insert_data( structure.space_group_it_number, structure.cross_compatibility, structure.bawl_fingerprint, + Json(structure.compressed_charge_density), + Json(structure.compressed_aeccar0), + Json(structure.compressed_aeccar1), + Json(structure.compressed_aeccar2), + structure.charge_density_grid_shape, + structure.bader_charges, + 
structure.bader_atomic_volume, + structure.ddec6_charges, # trajectory-specific fields structure.relaxation_step, structure.relaxation_number, diff --git a/src/lematerial_fetcher/fetch.py b/src/lematerial_fetcher/fetch.py index 9f7184c..fc1b1a7 100644 --- a/src/lematerial_fetcher/fetch.py +++ b/src/lematerial_fetcher/fetch.py @@ -67,7 +67,7 @@ def db(self) -> StructuresDatabase: Returns ------- - StructuresDatabase + StructuresDatabase Database connection """ if self._db is None: diff --git a/src/lematerial_fetcher/fetcher/lematrho/__init__.py b/src/lematerial_fetcher/fetcher/lematrho/__init__.py new file mode 100644 index 0000000..7dd6f34 --- /dev/null +++ b/src/lematerial_fetcher/fetcher/lematrho/__init__.py @@ -0,0 +1 @@ +# Copyright 2025 Entalpic diff --git a/src/lematerial_fetcher/fetcher/lematrho/fetch.py b/src/lematerial_fetcher/fetcher/lematrho/fetch.py new file mode 100644 index 0000000..b6a4376 --- /dev/null +++ b/src/lematerial_fetcher/fetcher/lematrho/fetch.py @@ -0,0 +1,196 @@ +# Copyright 2025 Entalpic +from datetime import datetime +from multiprocessing import Manager +from typing import Optional + +from lematerial_fetcher.database.postgres import StructuresDatabase +from lematerial_fetcher.fetch import BaseFetcher, ItemsInfo +from lematerial_fetcher.fetcher.lematrho.utils import ( + DEFAULT_MAX_WORKERS, + GRID_KEY_MAP, + RELAX_CALC_TYPE, + STATIC_CALC_TYPE, + STATIC_FILES, + VALID_PREFIXES, + build_raw_structure, + compress_chgcar, + download_gz_file_from_s3, + parse_vasprun_structure, +) +from lematerial_fetcher.utils.aws import get_authenticated_aws_client +from lematerial_fetcher.utils.config import FetcherConfig, load_fetcher_config +from lematerial_fetcher.utils.logging import logger + + +class LeMatRhoFetcher(BaseFetcher): + """Fetcher for LeMatRho charge density data from an authenticated S3 bucket. + + Downloads CHGCAR/AECCAR files, compresses charge densities via pyrho, + and stores compressed grids in the raw_structures PostgreSQL table. 
+ + Args: + config: Fetcher configuration. If ``None``, loads from defaults. + debug: If ``True``, process sequentially for debugging. + """ + + def __init__(self, config: FetcherConfig = None, debug: bool = False): + super().__init__(config or load_fetcher_config(), debug) + self.aws_client = None + self.manager = Manager() + self.manager_dict = self.manager.dict() + self.manager_dict["occurred"] = False + + def setup_resources(self) -> None: + """Set up authenticated AWS client and database connection.""" + self.aws_client = get_authenticated_aws_client() + self.setup_database() + + def get_items_to_process(self) -> ItemsInfo: + """List material folder prefixes from S3, filtered by valid ID prefixes. + + Only folders starting with ``VALID_PREFIXES`` (``oqmd-``, ``mp-``, + ``agm``) are included. + + Returns: + ``ItemsInfo`` containing material folder names to process. + + Raises: + ValueError: If ``lematrho_bucket_name`` is not set in config. + """ + bucket = self.config.lematrho_bucket_name + if not bucket: + raise ValueError("lematrho_bucket_name must be set in config") + + client = self.aws_client + paginator = client.get_paginator("list_objects_v2") + + material_folders = [] + for page in paginator.paginate(Bucket=bucket, Delimiter="/"): + for prefix_info in page.get("CommonPrefixes", []): + prefix = prefix_info["Prefix"] + # Extract folder name (remove trailing slash) + folder_name = prefix.rstrip("/") + if folder_name.startswith(VALID_PREFIXES): + material_folders.append(folder_name) + else: + logger.debug(f"Skipping folder with unknown prefix: {folder_name}") + + logger.info( + f"Found {len(material_folders)} material folders to process in bucket {bucket}" + ) + + return ItemsInfo( + start_offset=0, + total_count=len(material_folders), + items=material_folders, + ) + + @staticmethod + def _process_batch( + batch: str, config: FetcherConfig, manager_dict: dict, worker_id: int = 0 + ) -> bool: + """Process a single material folder from S3. 
+ + Downloads vasprun.xml.gz (for structure), then CHGCAR/AECCAR files, + compresses charge densities, and inserts into PostgreSQL. + + Args: + batch: Material folder name, e.g. ``"agm000001"``. + config: Fetcher configuration. + manager_dict: Shared dict for inter-process communication. + worker_id: Worker process identifier. + + Returns: + ``True`` if successful, ``False`` if failed. + """ + material_id = batch + bucket = config.lematrho_bucket_name + grid_shape = config.lematrho_grid_shape or (15, 15, 15) + + try: + # Fresh clients per worker (multiprocessing safety) + aws_client = get_authenticated_aws_client() + db = StructuresDatabase(config.db_conn_str, config.table_name) + + # Step 1: Download and parse vasprun.xml.gz for structure + vasprun_key = f"{material_id}/{RELAX_CALC_TYPE}/vasprun.xml.gz" + try: + vasprun_bytes = download_gz_file_from_s3( + aws_client, bucket, vasprun_key + ) + structure = parse_vasprun_structure(vasprun_bytes) + del vasprun_bytes + except Exception as e: + logger.warning( + f"Failed to parse vasprun.xml.gz for {material_id}: {e}" + ) + return False + + # Step 2: Download and compress charge density files sequentially + compressed_grids: dict[str, Optional[list]] = { + "charge_density": None, + "aeccar0": None, + "aeccar1": None, + "aeccar2": None, + } + + for filename in STATIC_FILES: + s3_key = f"{material_id}/{STATIC_CALC_TYPE}/{filename}" + grid_name = GRID_KEY_MAP[filename] + try: + raw_bytes = download_gz_file_from_s3(aws_client, bucket, s3_key) + compressed = compress_chgcar(raw_bytes, grid_shape) + compressed_grids[grid_name] = compressed + del raw_bytes, compressed + except Exception as e: + logger.warning( + f"Failed to process {filename} for {material_id}: {e}" + ) + + # Step 3: Build and insert raw structure + raw_structure = build_raw_structure( + material_id=material_id, + structure=structure, + compressed_grids=compressed_grids, + grid_shape=grid_shape, + s3_prefix=material_id, + ) + db.insert_data(raw_structure) + + 
logger.debug(f"Successfully processed {material_id}") + return True + + except Exception as e: + logger.error(f"Failed to process material {material_id}: {e}") + if BaseFetcher.is_critical_error(e): + if manager_dict is not None: + manager_dict["occurred"] = True + return False + + def cleanup_resources(self) -> None: + """Clean up AWS client, database connection, and process manager.""" + if self.aws_client: + self.aws_client = None + + if hasattr(self, "manager"): + self.manager.shutdown() + + super().cleanup_resources() + + def get_new_version(self) -> str: + """Get version identifier for this fetch run. + + Returns: + Current date as ``YYYY-MM-DD`` string. + """ + return datetime.now().strftime("%Y-%m-%d") + + +def fetch(): + """Fetch charge density data from the LeMatRho S3 bucket.""" + fetcher = LeMatRhoFetcher() + fetcher.fetch() + + +if __name__ == "__main__": + fetch() diff --git a/src/lematerial_fetcher/fetcher/lematrho/pipeline.py b/src/lematerial_fetcher/fetcher/lematrho/pipeline.py new file mode 100644 index 0000000..3881f48 --- /dev/null +++ b/src/lematerial_fetcher/fetcher/lematrho/pipeline.py @@ -0,0 +1,863 @@ +# Copyright 2025 Entalpic +"""Direct S3-to-Parquet pipeline for LeMatRho charge density data. + +Bypasses PostgreSQL entirely: downloads from S3, compresses charge densities, +runs Bader/DDEC6 analysis, and writes Parquet files directly. 
+""" + +import concurrent.futures +import gc +import json +import os +import shutil +import subprocess +import tempfile +from datetime import datetime +from glob import glob +from typing import Optional + +import pyarrow as pa +import pyarrow.parquet as pq +from pymatgen.core import Structure + +from lematerial_fetcher.fetcher.lematrho.transform import ( + get_cross_compatibility, + parse_acf_dat, + parse_ddec6_charges, + read_potcar_zval, +) +from lematerial_fetcher.fetcher.lematrho.utils import ( + BADER_TIMEOUT, + CHARGEMOL_TIMEOUT, + CHGSUM_TIMEOUT, + GRID_KEY_MAP, + RELAX_CALC_TYPE, + STATIC_CALC_TYPE, + STATIC_FILES, + VALID_PREFIXES, + compress_chgcar, + download_gz_file_from_s3, + parse_vasprun_structure, + write_potcar, +) +from lematerial_fetcher.models.optimade import Functional, OptimadeStructure +from lematerial_fetcher.utils.aws import get_authenticated_aws_client +from lematerial_fetcher.utils.config import DirectPipelineConfig +from lematerial_fetcher.utils.logging import logger +from lematerial_fetcher.utils.structure import get_optimade_from_pymatgen + +# Columns in the output Parquet files (matches HuggingFace Features schema) +PARQUET_COLUMNS = [ + "elements", + "nsites", + "chemical_formula_anonymous", + "chemical_formula_reduced", + "chemical_formula_descriptive", + "nelements", + "dimension_types", + "nperiodic_dimensions", + "lattice_vectors", + "immutable_id", + "cartesian_site_positions", + "species", + "species_at_sites", + "last_modified", + "elements_ratios", + "stress_tensor", + "energy", + "energy_corrected", + "magnetic_moments", + "forces", + "total_magnetization", + "charges", + "dos_ef", + "functional", + "cross_compatibility", + "bawl_fingerprint", + "space_group_it_number", + "compressed_charge_density", + "compressed_aeccar0", + "compressed_aeccar1", + "compressed_aeccar2", + "charge_density_grid_shape", + "bader_charges", + "bader_atomic_volume", + "ddec6_charges", +] + +# PyArrow schema matching the HuggingFace Features 
+PARQUET_SCHEMA = pa.schema( + [ + ("elements", pa.list_(pa.string())), + ("nsites", pa.int32()), + ("chemical_formula_anonymous", pa.string()), + ("chemical_formula_reduced", pa.string()), + ("chemical_formula_descriptive", pa.string()), + ("nelements", pa.int8()), + ("dimension_types", pa.list_(pa.int8())), + ("nperiodic_dimensions", pa.int8()), + ("lattice_vectors", pa.list_(pa.list_(pa.float64()))), + ("immutable_id", pa.string()), + ("cartesian_site_positions", pa.list_(pa.list_(pa.float64()))), + ("species", pa.string()), # JSON-serialized + ("species_at_sites", pa.list_(pa.string())), + ("last_modified", pa.string()), + ("elements_ratios", pa.list_(pa.float64())), + ("stress_tensor", pa.list_(pa.list_(pa.float64()))), + ("energy", pa.float64()), + ("energy_corrected", pa.float64()), + ("magnetic_moments", pa.list_(pa.float64())), + ("forces", pa.list_(pa.list_(pa.float64()))), + ("total_magnetization", pa.float64()), + ("charges", pa.list_(pa.float64())), + ("dos_ef", pa.float64()), + ("functional", pa.string()), + ("cross_compatibility", pa.bool_()), + ("bawl_fingerprint", pa.string()), + ("space_group_it_number", pa.int32()), + ("compressed_charge_density", pa.string()), # JSON-serialized + ("compressed_aeccar0", pa.string()), # JSON-serialized + ("compressed_aeccar1", pa.string()), # JSON-serialized + ("compressed_aeccar2", pa.string()), # JSON-serialized + ("charge_density_grid_shape", pa.list_(pa.int32())), + ("bader_charges", pa.list_(pa.float64())), + ("bader_atomic_volume", pa.list_(pa.float64())), + ("ddec6_charges", pa.list_(pa.float64())), + ] +) + +# Files needed for Bader analysis (must keep raw bytes) +_BADER_FILES = {"CHGCAR.gz", "AECCAR0.gz", "AECCAR2.gz"} +# Files needed for DDEC6 analysis +_DDEC6_FILES = {"CHGCAR.gz"} + + + +def _run_bader_from_bytes( + structure: Structure, + raw_files: dict[str, bytes], + tool_paths: dict, + material_id: str, +) -> tuple[Optional[list[float]], Optional[list[float]]]: + """Run Bader charge analysis from 
raw decompressed file bytes. + + Writes raw VASP files to a temp directory, generates POTCAR, runs + ``chgsum.pl`` and ``bader``, and parses the resulting ACF.dat. + + Args: + structure: Pymatgen Structure for POTCAR generation. + raw_files: Mapping of VASP filenames to their raw bytes, + e.g. ``{"CHGCAR": b"...", "AECCAR0": b"...", "AECCAR2": b"..."}``. + tool_paths: Tool path configuration dict from ``_validate_tools()``. + material_id: Material identifier, used for logging. + + Returns: + Tuple of ``(net_charges, atomic_volumes)`` or ``(None, None)`` on failure. + """ + try: + with tempfile.TemporaryDirectory() as tmpdir: + # Write raw decompressed files (bader expects plain-text VASP format) + for name in ["CHGCAR", "AECCAR0", "AECCAR2"]: + with open(os.path.join(tmpdir, name), "wb") as f: + f.write(raw_files[name]) + + # Generate POTCAR + write_potcar(structure, tmpdir) + + # Sum AECCAR0 + AECCAR2 -> CHGCAR_sum + subprocess.run( + [ + tool_paths["perl_path"], + tool_paths["chgsum_script_path"], + "AECCAR0", + "AECCAR2", + ], + cwd=tmpdir, + timeout=CHGSUM_TIMEOUT, + check=True, + capture_output=True, + ) + + # Run Bader + subprocess.run( + [tool_paths["bader_path"], "CHGCAR", "-ref", "CHGCAR_sum"], + cwd=tmpdir, + timeout=BADER_TIMEOUT, + check=True, + capture_output=True, + ) + + # Parse results + electron_counts, atomic_volumes = parse_acf_dat( + os.path.join(tmpdir, "ACF.dat") + ) + zval = read_potcar_zval(os.path.join(tmpdir, "POTCAR")) + + net_charges = [] + for site, ec in zip(structure.sites, electron_counts): + element = str(site.specie) + valence = zval.get(element, 0) + net_charges.append(valence - ec) + + return net_charges, atomic_volumes + + except subprocess.TimeoutExpired: + logger.warning(f"Bader analysis timed out for {material_id}") + return None, None + except subprocess.CalledProcessError as e: + logger.warning( + f"Bader subprocess failed for {material_id}: " + f"exit code {e.returncode}, stderr: {e.stderr}" + ) + return None, None + 
except Exception as e: + logger.warning(f"Bader analysis failed for {material_id}: {e}") + return None, None + + +def _run_ddec6_from_bytes( + structure: Structure, + raw_files: dict[str, bytes], + tool_paths: dict, + material_id: str, +) -> Optional[list[float]]: + """Run DDEC6 charge analysis from raw decompressed file bytes. + + Writes CHGCAR and POTCAR to a temp directory, runs chargemol, and + parses the DDEC6 net atomic charges. + + Args: + structure: Pymatgen Structure for POTCAR generation. + raw_files: Mapping with at least ``{"CHGCAR": b"..."}``. + tool_paths: Tool path configuration dict from ``_validate_tools()``. + material_id: Material identifier, used for logging. + + Returns: + DDEC6 net charges per site, or ``None`` on failure. + """ + try: + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, "CHGCAR"), "wb") as f: + f.write(raw_files["CHGCAR"]) + + write_potcar(structure, tmpdir) + + # Write chargemol job control file + config_content = ( + "\n" + "0.0\n" + "\n" + "\n" + ".true.\n" + ".true.\n" + ".true.\n" + "\n" + "\n" + f"{tool_paths['atomic_densities_path']}\n" + "\n" + "\n" + "DDEC6\n" + "\n" + "\n" + "CHGCAR\n" + "\n" + ) + with open(os.path.join(tmpdir, "job_control.txt"), "w") as f: + f.write(config_content) + + # Run chargemol + env = os.environ.copy() + env["DDEC6_ATOMIC_DENSITIES_DIR"] = tool_paths["atomic_densities_path"] + subprocess.run( + [tool_paths["chargemol_path"]], + cwd=tmpdir, + timeout=CHARGEMOL_TIMEOUT, + check=True, + capture_output=True, + env=env, + ) + + return parse_ddec6_charges(tmpdir) + + except subprocess.TimeoutExpired: + logger.warning(f"DDEC6 analysis timed out for {material_id}") + return None + except subprocess.CalledProcessError as e: + logger.warning( + f"DDEC6 subprocess failed for {material_id}: " + f"exit code {e.returncode}, stderr: {e.stderr}" + ) + return None + except Exception as e: + logger.warning(f"DDEC6 analysis failed for {material_id}: {e}") + return None + + +def 
_structure_to_row( + optimade_structure: OptimadeStructure, +) -> dict: + """Convert an OptimadeStructure to a flat dict matching the Parquet schema. + + JSON-serializes species and compressed charge density fields. + Converts ``Functional`` enums to their string value and ``datetime`` + to ISO format. + + Args: + optimade_structure: Validated ``OptimadeStructure`` instance. + + Returns: + Flat dict with one key per ``PARQUET_COLUMNS`` entry, ready for + ``pyarrow.Table.from_pydict()``. + """ + row = {} + for col in PARQUET_COLUMNS: + row[col] = getattr(optimade_structure, col, None) + + # Convert datetime to ISO string + if row["last_modified"] is not None: + row["last_modified"] = row["last_modified"].isoformat() + + # Convert Functional enum to string + if row["functional"] is not None: + row["functional"] = row["functional"].value + + # JSON-serialize complex fields + if row["species"] is not None: + row["species"] = json.dumps(row["species"]) + + for col in [ + "compressed_charge_density", + "compressed_aeccar0", + "compressed_aeccar1", + "compressed_aeccar2", + ]: + if row[col] is not None: + row[col] = json.dumps(row[col]) + + return row + + +class LeMatRhoDirectPipeline: + """Direct S3-to-Parquet pipeline for LeMatRho charge density data. + + Downloads from S3, compresses charge densities via pyrho, optionally runs + Bader and DDEC6 charge analysis, and writes Parquet files directly. + No PostgreSQL required. + + Args: + config: Pipeline configuration. + debug: If ``True``, process sequentially in the main process. 
+ """ + + def __init__(self, config: DirectPipelineConfig, debug: bool = False): + self.config = config + self.debug = debug + self._checkpoint_path = os.path.join(config.output_dir, ".checkpoint.txt") + self._failures_path = os.path.join(config.output_dir, ".failures.txt") + self._processed_ids: set[str] = set() + self._failed_ids: set[str] = set() + + # Validate external tools + self._tool_paths = self._validate_tools() + + # Create output directory + os.makedirs(config.output_dir, exist_ok=True) + + def _validate_tools(self) -> dict: + """Check availability of external tools for Bader/DDEC6 analysis. + + Returns: + Dict with keys ``bader_path``, ``chargemol_path``, + ``chgsum_script_path``, ``perl_path``, ``atomic_densities_path``, + ``can_generate_potcar``, ``can_run_bader``, ``can_run_ddec6``. + """ + tools = {} + + tools["bader_path"] = self.config.bader_path or shutil.which("bader") + if not tools["bader_path"]: + logger.warning( + "bader executable not found. Bader charges will not be computed." + ) + + tools["chargemol_path"] = ( + self.config.chargemol_path + or shutil.which("Chargemol_09_26_2017_linux_serial") + or shutil.which("chargemol") + ) + if not tools["chargemol_path"]: + logger.warning( + "chargemol executable not found. DDEC6 charges will not be computed." + ) + + tools["chgsum_script_path"] = self.config.chgsum_script_path + if tools["chgsum_script_path"] and not os.path.isfile( + tools["chgsum_script_path"] + ): + logger.warning( + f"chgsum.pl not found at {tools['chgsum_script_path']}. " + "Bader analysis requires this script." + ) + tools["chgsum_script_path"] = None + + tools["perl_path"] = shutil.which("perl") + if not tools["perl_path"]: + logger.warning( + "perl not found. Bader analysis requires perl for chgsum.pl." 
+ ) + + tools["atomic_densities_path"] = self.config.atomic_densities_path + if tools["atomic_densities_path"] and not os.path.isdir( + tools["atomic_densities_path"] + ): + logger.warning( + f"Atomic densities directory not found: {tools['atomic_densities_path']}. " + "DDEC6 analysis requires this directory." + ) + tools["atomic_densities_path"] = None + + tools["can_generate_potcar"] = bool(os.environ.get("PMG_VASP_PSP_DIR")) + if not tools["can_generate_potcar"]: + logger.warning( + "PMG_VASP_PSP_DIR not set. POTCAR generation will fail. " + "Bader and DDEC6 analysis require POTCAR." + ) + + tools["can_run_bader"] = bool( + tools["bader_path"] + and tools["chgsum_script_path"] + and tools["perl_path"] + and tools["can_generate_potcar"] + ) + tools["can_run_ddec6"] = bool( + tools["chargemol_path"] + and tools["atomic_densities_path"] + and tools["can_generate_potcar"] + ) + + return tools + + def run(self) -> None: + """Run the full pipeline: list materials, process, write Parquet, optionally push. + + In debug mode, materials are processed sequentially in the main process. + Otherwise, uses a ``ProcessPoolExecutor`` with a work-stealing pattern. + Writes Parquet chunks of ``config.parquet_chunk_size`` rows using atomic + rename. Appends each processed ID to a checkpoint file for crash recovery. + """ + # 1. List material folders from S3 + logger.info("Listing material folders from S3...") + material_ids = self._list_materials() + logger.info(f"Found {len(material_ids)} materials in S3") + + # 2. 
Load checkpoint and failures, filter already-handled + self._processed_ids = self._load_checkpoint() + self._failed_ids = self._load_failures() + remaining = [ + m + for m in material_ids + if m not in self._processed_ids and m not in self._failed_ids + ] + + # Apply limit if set + if self.config.limit is not None and len(remaining) > self.config.limit: + remaining = remaining[: self.config.limit] + + logger.info( + f"Already processed: {len(self._processed_ids)}, " + f"previously failed: {len(self._failed_ids)}, " + f"remaining: {len(remaining)}" + ) + + if not remaining: + logger.info("All materials already processed.") + if self.config.hf_repo_id: + self._push_to_huggingface() + return + + # 3. Process materials + buffer = [] + buffer_ids = [] + chunk_index = self._get_next_chunk_index() + processed_count = 0 + failed_count = 0 + + if self.debug: + for material_id in remaining: + result = self._process_material( + material_id, self.config, self._tool_paths + ) + if result is not None: + buffer.append(result) + buffer_ids.append(material_id) + processed_count += 1 + else: + self._append_failure(material_id) + failed_count += 1 + + if len(buffer) >= self.config.parquet_chunk_size: + self._write_parquet_chunk(buffer, chunk_index) + self._batch_checkpoint(buffer_ids) + buffer.clear() + buffer_ids.clear() + chunk_index += 1 + + total = processed_count + failed_count + if total % self.config.log_every == 0 and total > 0: + logger.info( + f"Progress: {processed_count} processed, {failed_count} failed" + ) + else: + with concurrent.futures.ProcessPoolExecutor( + max_workers=self.config.num_workers + ) as executor: + remaining_iter = iter(remaining) + futures = {} + + # Submit initial batch (2x workers for pipeline saturation) + initial_count = min(self.config.num_workers * 2, len(remaining)) + for _ in range(initial_count): + mid = next(remaining_iter) + future = executor.submit( + self._process_material, mid, self.config, self._tool_paths + ) + futures[future] = mid + 
+ while futures: + done, _ = concurrent.futures.wait( + futures, + return_when=concurrent.futures.FIRST_COMPLETED, + ) + for future in done: + material_id = futures.pop(future) + + try: + result = future.result() + if result is not None: + buffer.append(result) + buffer_ids.append(material_id) + processed_count += 1 + else: + self._append_failure(material_id) + failed_count += 1 + except Exception as e: + logger.warning( + f"Worker exception for {material_id}: {e}" + ) + self._append_failure(material_id) + failed_count += 1 + + # Write chunk if buffer is full + if len(buffer) >= self.config.parquet_chunk_size: + self._write_parquet_chunk(buffer, chunk_index) + self._batch_checkpoint(buffer_ids) + buffer.clear() + buffer_ids.clear() + chunk_index += 1 + + # Submit replacement (work-stealing) + try: + next_id = next(remaining_iter) + f = executor.submit( + self._process_material, + next_id, + self.config, + self._tool_paths, + ) + futures[f] = next_id + except StopIteration: + pass + + total = processed_count + failed_count + if total % self.config.log_every == 0 and total > 0: + logger.info( + f"Progress: {processed_count} processed, " + f"{failed_count} failed" + ) + + # Write remaining buffer + if buffer: + self._write_parquet_chunk(buffer, chunk_index) + self._batch_checkpoint(buffer_ids) + + logger.info( + f"Done. {processed_count} processed, {failed_count} failed." + ) + + # 4. Push if configured + if self.config.hf_repo_id: + self._push_to_huggingface() + + def _list_materials(self) -> list[str]: + """List material folder prefixes from S3, filtered by ``VALID_PREFIXES``. + + Returns: + List of material IDs (e.g. ``["agm000001", "mp-123", ...]``). + Order depends on S3 listing order (typically lexicographic). 
+ """ + client = get_authenticated_aws_client() + bucket = self.config.lematrho_bucket_name + paginator = client.get_paginator("list_objects_v2") + + material_folders = [] + for page in paginator.paginate(Bucket=bucket, Delimiter="/"): + for prefix_info in page.get("CommonPrefixes", []): + folder_name = prefix_info["Prefix"].rstrip("/") + if folder_name.startswith(VALID_PREFIXES): + material_folders.append(folder_name) + + return material_folders + + @staticmethod + def _process_material( + material_id: str, + config: DirectPipelineConfig, + tool_paths: dict, + ) -> Optional[dict]: + """Process a single material: download, compress, analyze, return row dict. + + Designed to run in a worker process. Creates a fresh AWS client per + invocation (boto3 clients are not multiprocess-safe). Calls + ``gc.collect()`` after each material to free memory from large CHGCAR + arrays. + + Args: + material_id: Material folder name, e.g. ``"agm000001"``. + config: Pipeline configuration. + tool_paths: Tool path dict from ``_validate_tools()``. + + Returns: + Flat dict matching ``PARQUET_COLUMNS``, or ``None`` on failure. + """ + bucket = config.lematrho_bucket_name + grid_shape = config.lematrho_grid_shape + + try: + # Fresh client per worker (boto3 clients are NOT multiprocess-safe) + aws_client = get_authenticated_aws_client() + + # Step 1: Download and parse vasprun.xml.gz for structure + vasprun_key = f"{material_id}/{RELAX_CALC_TYPE}/vasprun.xml.gz" + try: + vasprun_bytes = download_gz_file_from_s3( + aws_client, bucket, vasprun_key + ) + structure = parse_vasprun_structure(vasprun_bytes) + del vasprun_bytes + except Exception as e: + logger.warning( + f"Failed to parse vasprun.xml.gz for {material_id}: {e}" + ) + return None + + # Step 2: Download and compress charge density files + compressed_grids = {} + # Memory trade-off: raw decompressed bytes are kept in memory for + # Bader/DDEC6 analysis to avoid a second S3 download. 
Each CHGCAR
+            # can be 100-500 MB, so peak RSS per worker ≈ sum of needed files.
+            raw_files = {}
+
+            # Determine which raw files to keep
+            need_raw = set()
+            if tool_paths["can_run_bader"]:
+                need_raw |= _BADER_FILES
+            if tool_paths["can_run_ddec6"]:
+                need_raw |= _DDEC6_FILES
+
+            for filename in STATIC_FILES:
+                s3_key = f"{material_id}/{STATIC_CALC_TYPE}/{filename}"
+                grid_name = GRID_KEY_MAP[filename]
+                try:
+                    raw_bytes = download_gz_file_from_s3(aws_client, bucket, s3_key)
+
+                    # Compress via pyrho
+                    compressed = compress_chgcar(raw_bytes, grid_shape)
+                    compressed_grids[grid_name] = compressed
+                    del compressed
+
+                    # Keep raw bytes if needed for analysis
+                    if filename in need_raw:
+                        vasp_name = filename.replace(".gz", "")
+                        raw_files[vasp_name] = raw_bytes
+                    del raw_bytes
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to process {filename} for {material_id}: {e}"
+                    )
+
+            # Step 3: Bader analysis (if tools available and files downloaded)
+            bader_charges = None
+            bader_atomic_volume = None
+            if tool_paths["can_run_bader"] and all(
+                k in raw_files for k in ["CHGCAR", "AECCAR0", "AECCAR2"]
+            ):
+                bader_charges, bader_atomic_volume = _run_bader_from_bytes(
+                    structure, raw_files, tool_paths, material_id
+                )
+
+            # Step 4: DDEC6 analysis (if tools available and CHGCAR downloaded)
+            ddec6_charges = None
+            if tool_paths["can_run_ddec6"] and "CHGCAR" in raw_files:
+                ddec6_charges = _run_ddec6_from_bytes(
+                    structure, raw_files, tool_paths, material_id
+                )
+
+            # Free raw file bytes
+            del raw_files
+
+            # Step 5: Build OptimadeStructure (Pydantic validation)
+            optimade_dict = get_optimade_from_pymatgen(structure)
+            cross_compatibility = get_cross_compatibility(optimade_dict["elements"])
+
+            optimade_structure = OptimadeStructure(
+                id=material_id,
+                source="lematrho",
+                immutable_id=material_id,
+                last_modified=datetime.now(),
+                **optimade_dict,
+                functional=Functional.PBE,
+                cross_compatibility=cross_compatibility,
+
compressed_charge_density=compressed_grids.get("charge_density"), + compressed_aeccar0=compressed_grids.get("aeccar0"), + compressed_aeccar1=compressed_grids.get("aeccar1"), + compressed_aeccar2=compressed_grids.get("aeccar2"), + charge_density_grid_shape=list(grid_shape), + bader_charges=bader_charges, + bader_atomic_volume=bader_atomic_volume, + ddec6_charges=ddec6_charges, + compute_space_group=True, + compute_bawl_hash=True, + ) + + # Step 6: Convert to flat dict for Parquet + row = _structure_to_row(optimade_structure) + del optimade_structure, compressed_grids + + # Step 7: Force garbage collection in worker + gc.collect() + + return row + + except Exception as e: + logger.error(f"Failed to process material {material_id}: {e}") + gc.collect() + return None + + def _load_checkpoint(self) -> set[str]: + """Load processed material IDs from checkpoint file. + + Returns: + Set of already-processed material IDs. Empty set if no checkpoint exists. + """ + if not os.path.exists(self._checkpoint_path): + return set() + with open(self._checkpoint_path, "r") as f: + return {line.strip() for line in f if line.strip()} + + def _append_checkpoint(self, material_id: str) -> None: + """Append a material ID to the checkpoint file and flush to disk. + + Args: + material_id: ID to record as processed. + """ + with open(self._checkpoint_path, "a") as f: + f.write(material_id + "\n") + f.flush() + os.fsync(f.fileno()) + + def _batch_checkpoint(self, material_ids: list[str]) -> None: + """Append multiple material IDs to the checkpoint file atomically. + + Called after a Parquet chunk is successfully flushed so that + checkpoint and Parquet stay in sync. + + Args: + material_ids: IDs to record as processed. + """ + with open(self._checkpoint_path, "a") as f: + for mid in material_ids: + f.write(mid + "\n") + f.flush() + os.fsync(f.fileno()) + + def _load_failures(self) -> set[str]: + """Load failed material IDs from failures file. 
+ + Returns: + Set of material IDs that previously failed. Empty set if no file. + """ + if not os.path.exists(self._failures_path): + return set() + with open(self._failures_path, "r") as f: + return {line.strip() for line in f if line.strip()} + + def _append_failure(self, material_id: str) -> None: + """Record a failed material ID so it is skipped on resume. + + Args: + material_id: ID that failed processing. + """ + with open(self._failures_path, "a") as f: + f.write(material_id + "\n") + f.flush() + os.fsync(f.fileno()) + + def _get_next_chunk_index(self) -> int: + """Determine next chunk index from existing ``chunk_*.parquet`` files. + + Ignores ``.tmp`` files left by interrupted writes. + + Returns: + Next available chunk index (0 if no existing chunks). + """ + existing = glob(os.path.join(self.config.output_dir, "chunk_*.parquet")) + if not existing: + return 0 + indices = [] + for path in existing: + basename = os.path.basename(path) + try: + idx = int(basename.replace("chunk_", "").replace(".parquet", "")) + indices.append(idx) + except ValueError: + pass + return max(indices) + 1 if indices else 0 + + def _write_parquet_chunk(self, rows: list[dict], chunk_index: int) -> None: + """Write rows to a Parquet file atomically (write ``.tmp``, then rename). + + Args: + rows: List of flat dicts matching ``PARQUET_COLUMNS``. + chunk_index: Chunk sequence number (zero-padded in filename). 
+ """ + final_path = os.path.join( + self.config.output_dir, f"chunk_{chunk_index:06d}.parquet" + ) + tmp_path = final_path + ".tmp" + + # Build column-oriented data from row-oriented dicts + columns = {col: [row.get(col) for row in rows] for col in PARQUET_COLUMNS} + table = pa.table(columns, schema=PARQUET_SCHEMA) + pq.write_table(table, tmp_path) + + # Atomic rename + os.rename(tmp_path, final_path) + logger.info( + f"Wrote chunk {chunk_index} ({len(rows)} rows) to {final_path}" + ) + + def _push_to_huggingface(self) -> None: + """Load all Parquet files and push to HuggingFace as a private dataset.""" + from datasets import load_dataset + + parquet_files = os.path.join(self.config.output_dir, "chunk_*.parquet") + logger.info(f"Loading Parquet files from {parquet_files}") + + dataset = load_dataset("parquet", data_files=parquet_files) + + logger.info(f"Pushing to HuggingFace repo: {self.config.hf_repo_id}") + dataset["train"].push_to_hub( + self.config.hf_repo_id, + token=self.config.hf_token, + private=True, + ) + logger.info("Push complete.") diff --git a/src/lematerial_fetcher/fetcher/lematrho/transform.py b/src/lematerial_fetcher/fetcher/lematrho/transform.py new file mode 100644 index 0000000..fe6081e --- /dev/null +++ b/src/lematerial_fetcher/fetcher/lematrho/transform.py @@ -0,0 +1,493 @@ +# Copyright 2025 Entalpic +import os +import shutil +import subprocess +import tempfile +from datetime import datetime +from typing import Optional + +from pymatgen.core import Structure + +from lematerial_fetcher.database.postgres import OptimadeDatabase, StructuresDatabase +from lematerial_fetcher.fetcher.lematrho.utils import ( + BADER_TIMEOUT, + CHARGEMOL_TIMEOUT, + CHGSUM_TIMEOUT, + STATIC_CALC_TYPE, + download_gz_file_from_s3, + write_potcar, +) +from lematerial_fetcher.models.models import RawStructure +from lematerial_fetcher.models.optimade import Functional, OptimadeStructure +from lematerial_fetcher.transform import BaseTransformer +from 
lematerial_fetcher.utils.aws import get_authenticated_aws_client +from lematerial_fetcher.utils.logging import logger +from lematerial_fetcher.utils.structure import get_optimade_from_pymatgen + + +def get_cross_compatibility(elements: list[str]) -> bool: + """Determine cross-compatibility for LeMatRho structures. + + Yb-containing structures are excluded (same policy as Alexandria). + + Args: + elements: List of element symbols in the structure. + + Returns: + ``True`` if the structure is cross-compatible, ``False`` otherwise. + """ + return "Yb" not in elements + + +def parse_acf_dat(filepath: str) -> tuple[list[float], list[float]]: + """Parse Bader ACF.dat file for electron counts and atomic volumes. + + Args: + filepath: Path to the ACF.dat file produced by bader. + + Returns: + Tuple of ``(electron_counts, atomic_volumes)`` lists, one entry per atom. + """ + electron_counts = [] + atomic_volumes = [] + with open(filepath) as f: + lines = f.readlines() + + # Skip header (first 2 lines), parse data rows until separator line + for line in lines[2:]: + stripped = line.strip() + if stripped.startswith("-") or not stripped: + break + parts = stripped.split() + if len(parts) >= 7: + electron_counts.append(float(parts[4])) # CHARGE column + atomic_volumes.append(float(parts[6])) # ATOMIC VOL column + + return electron_counts, atomic_volumes + + +def read_potcar_zval(filepath: str) -> dict[str, float]: + """Read valence electron counts from a POTCAR file. + + Parses TITEL and ZVAL lines to build an element-to-valence-electrons mapping. + + Args: + filepath: Path to the POTCAR file. + + Returns: + Dict mapping element symbols to their number of valence electrons. 
+ """ + zval = {} + current_element = None + with open(filepath) as f: + for line in f: + if "TITEL" in line: + parts = line.split() + if len(parts) >= 4: + # Handle element names like 'Si_d' -> 'Si' + current_element = parts[3].split("_")[0] + elif "ZVAL" in line and current_element: + parts = line.split("=") + if len(parts) >= 2: + try: + zval[current_element] = float(parts[1].split()[0]) + current_element = None + except (ValueError, IndexError): + pass + return zval + + +def parse_ddec6_charges(tmpdir: str) -> list[float]: + """Parse DDEC6 net atomic charges from chargemol output. + + Reads ``DDEC6_even_tempered_net_atomic_charges.xyz`` from *tmpdir*. + + Args: + tmpdir: Directory containing chargemol output files. + + Returns: + List of net DDEC6 charges, one per atom. + """ + filepath = os.path.join(tmpdir, "DDEC6_even_tempered_net_atomic_charges.xyz") + charges = [] + with open(filepath) as f: + lines = f.readlines() + + n_atoms = int(lines[0].strip()) + for line in lines[2 : 2 + n_atoms]: + parts = line.split() + if len(parts) >= 5: + charges.append(float(parts[4])) + + return charges + + +class LeMatRhoTransformer(BaseTransformer): + """Transformer for LeMatRho charge density data. + + Transforms raw structures (with compressed charge densities from the fetch step) + into ``OptimadeStructure`` objects. Optionally runs Bader and DDEC6 charge + analysis using external tools. + + External tool requirements: + - ``bader``: Bader charge analysis executable + - ``perl`` + ``chgsum.pl``: For summing AECCAR0 + AECCAR2 + - ``chargemol``: DDEC6 charge partitioning executable + - ``PMG_VASP_PSP_DIR``: Env var for POTCAR generation + - atomic densities directory: For DDEC6/chargemol analysis + + Args: + config: Transformer configuration. + database_class: Database class for storing results. + structure_class: Pydantic model class for validated structures. + debug: If ``True``, process sequentially for debugging. 
+ """ + + def __init__( + self, + config=None, + database_class=OptimadeDatabase, + structure_class=OptimadeStructure, + debug=False, + ): + super().__init__(config, database_class, structure_class, debug) + self._aws_client = None + self._bader_path = None + self._chargemol_path = None + self._chgsum_script_path = None + self._perl_path = None + self._atomic_densities_path = None + self._can_generate_potcar = False + self._validate_tools() + + def _validate_tools(self) -> None: + """Check availability of external tools and log warnings for missing ones. + + Sets instance attributes ``_bader_path``, ``_chargemol_path``, + ``_chgsum_script_path``, ``_perl_path``, ``_atomic_densities_path``, + and ``_can_generate_potcar``. + """ + self._bader_path = getattr(self.config, "bader_path", None) or shutil.which( + "bader" + ) + if not self._bader_path: + logger.warning( + "bader executable not found. Bader charges will not be computed." + ) + + self._chargemol_path = getattr( + self.config, "chargemol_path", None + ) or shutil.which("Chargemol_09_26_2017_linux_serial") + if not self._chargemol_path: + self._chargemol_path = shutil.which("chargemol") + if not self._chargemol_path: + logger.warning( + "chargemol executable not found. DDEC6 charges will not be computed." + ) + + self._chgsum_script_path = getattr( + self.config, "chgsum_script_path", None + ) + if self._chgsum_script_path and not os.path.isfile(self._chgsum_script_path): + logger.warning( + f"chgsum.pl not found at {self._chgsum_script_path}. " + "Bader analysis requires this script." + ) + self._chgsum_script_path = None + + self._perl_path = shutil.which("perl") + if not self._perl_path: + logger.warning( + "perl not found. Bader analysis requires perl for chgsum.pl." 
+ ) + + self._atomic_densities_path = getattr( + self.config, "atomic_densities_path", None + ) + if self._atomic_densities_path and not os.path.isdir( + self._atomic_densities_path + ): + logger.warning( + f"Atomic densities directory not found: {self._atomic_densities_path}. " + "DDEC6 analysis requires this directory." + ) + self._atomic_densities_path = None + + if not os.environ.get("PMG_VASP_PSP_DIR"): + logger.warning( + "PMG_VASP_PSP_DIR not set. POTCAR generation will fail. " + "Bader and DDEC6 analysis require POTCAR." + ) + self._can_generate_potcar = False + else: + self._can_generate_potcar = True + + @property + def aws_client(self): + """Lazy-initialized authenticated S3 client.""" + if self._aws_client is None: + self._aws_client = get_authenticated_aws_client() + return self._aws_client + + @property + def can_run_bader(self) -> bool: + """Check if all Bader analysis prerequisites are met.""" + return bool( + self._bader_path + and self._chgsum_script_path + and self._perl_path + and self._can_generate_potcar + ) + + @property + def can_run_ddec6(self) -> bool: + """Check if all DDEC6 analysis prerequisites are met.""" + return bool( + self._chargemol_path + and self._atomic_densities_path + and self._can_generate_potcar + ) + + def transform_row( + self, + raw_structure: RawStructure, + source_db: Optional[StructuresDatabase] = None, + task_table_name: Optional[str] = None, + ) -> list[OptimadeStructure]: + """Transform a raw LeMatRho structure into an OptimadeStructure. + + Args: + raw_structure: Raw structure from the fetch step, with charge + density data in its attributes dict. + source_db: Not used for LeMatRho. + task_table_name: Not used for LeMatRho. + + Returns: + Single-element list containing the transformed ``OptimadeStructure``. 
+ """ + attrs = raw_structure.attributes + material_id = raw_structure.id + + # Extract pymatgen Structure from raw data + structure = Structure.from_dict(attrs["structure"]) + + # Get base OPTIMADE fields from structure + optimade_dict = get_optimade_from_pymatgen(structure) + + # Extract pre-computed compressed grids from fetch step + compressed_charge_density = attrs.get("compressed_charge_density") + compressed_aeccar0 = attrs.get("compressed_aeccar0") + compressed_aeccar1 = attrs.get("compressed_aeccar1") + compressed_aeccar2 = attrs.get("compressed_aeccar2") + grid_shape = attrs.get("grid_shape") + s3_prefix = attrs.get("s3_prefix") + + # Cross-compatibility (exclude Yb, same policy as Alexandria) + cross_compatibility = get_cross_compatibility(optimade_dict["elements"]) + + # Bader analysis (independent from DDEC6) + bader_charges = None + bader_atomic_volume = None + if self.can_run_bader and s3_prefix: + bader_charges, bader_atomic_volume = self._run_bader_analysis( + structure, s3_prefix, material_id + ) + + # DDEC6 analysis (independent from Bader) + ddec6_charges = None + if self.can_run_ddec6 and s3_prefix: + ddec6_charges = self._run_ddec6_analysis( + structure, s3_prefix, material_id + ) + + optimade_structure = OptimadeStructure( + id=material_id, + source="lematrho", + immutable_id=material_id, + last_modified=raw_structure.last_modified or datetime.now(), + **optimade_dict, + functional=Functional.PBE, + cross_compatibility=cross_compatibility, + compressed_charge_density=compressed_charge_density, + compressed_aeccar0=compressed_aeccar0, + compressed_aeccar1=compressed_aeccar1, + compressed_aeccar2=compressed_aeccar2, + charge_density_grid_shape=grid_shape, + bader_charges=bader_charges, + bader_atomic_volume=bader_atomic_volume, + ddec6_charges=ddec6_charges, + compute_space_group=True, + compute_bawl_hash=True, + ) + + return [optimade_structure] + + def _run_bader_analysis( + self, structure: Structure, s3_prefix: str, material_id: str + ) -> 
tuple[Optional[list[float]], Optional[list[float]]]:
+        """Run Bader charge analysis.
+
+        Downloads CHGCAR, AECCAR0, AECCAR2 from S3, runs ``chgsum.pl`` to
+        create the reference charge density, then runs bader and parses results.
+
+        Args:
+            structure: Pymatgen Structure for POTCAR generation.
+            s3_prefix: S3 folder prefix for this material.
+            material_id: Material identifier, used for logging.
+
+        Returns:
+            Tuple of ``(net_charges, atomic_volumes)`` or ``(None, None)``
+            on failure.
+        """
+        try:
+            bucket = self.config.lematrho_bucket_name
+            with tempfile.TemporaryDirectory() as tmpdir:
+                # Download raw VASP charge density files
+                for filename in ["CHGCAR.gz", "AECCAR0.gz", "AECCAR2.gz"]:
+                    key = f"{s3_prefix}/{STATIC_CALC_TYPE}/{filename}"
+                    data = download_gz_file_from_s3(self.aws_client, bucket, key)
+                    outname = filename.replace(".gz", "")
+                    with open(os.path.join(tmpdir, outname), "wb") as f:
+                        f.write(data)
+                    del data
+
+                # Generate POTCAR
+                write_potcar(structure, tmpdir)
+
+                # Sum AECCAR0 + AECCAR2 -> CHGCAR_sum
+                subprocess.run(
+                    [
+                        self._perl_path,
+                        self._chgsum_script_path,
+                        "AECCAR0",
+                        "AECCAR2",
+                    ],
+                    cwd=tmpdir,
+                    timeout=CHGSUM_TIMEOUT,
+                    check=True,
+                    capture_output=True,
+                )
+
+                # Run Bader analysis with reference charge density
+                subprocess.run(
+                    [self._bader_path, "CHGCAR", "-ref", "CHGCAR_sum"],
+                    cwd=tmpdir,
+                    timeout=BADER_TIMEOUT,
+                    check=True,
+                    capture_output=True,
+                )
+
+                # Parse ACF.dat for electron counts and atomic volumes
+                electron_counts, atomic_volumes = parse_acf_dat(
+                    os.path.join(tmpdir, "ACF.dat")
+                )
+
+                # Compute net charges: valence_electrons - bader_electron_count
+                zval = read_potcar_zval(os.path.join(tmpdir, "POTCAR"))
+                net_charges = []
+                for site, electron_count in zip(structure.sites, electron_counts):
+                    element = str(site.specie)
+                    valence = zval.get(element, 0)
+                    net_charges.append(valence - electron_count)
+
+                return net_charges, atomic_volumes
+
+        except subprocess.TimeoutExpired:
+            logger.warning(f"Bader analysis timed out for {material_id}")
+            return None, None
+        except subprocess.CalledProcessError as e:
+            logger.warning(
+                f"Bader subprocess failed for {material_id}: "
+                f"exit code {e.returncode}, stderr: {e.stderr}"
+            )
+            return None, None
+        except Exception as e:
+            logger.warning(f"Bader analysis failed for {material_id}: {e}")
+            return None, None
+
+    def _run_ddec6_analysis(
+        self, structure: Structure, s3_prefix: str, material_id: str
+    ) -> Optional[list[float]]:
+        """Run DDEC6 charge analysis via chargemol.
+
+        Downloads CHGCAR from S3, generates POTCAR, writes chargemol config,
+        runs chargemol, and parses DDEC6 net charges.
+
+        Args:
+            structure: Pymatgen Structure for POTCAR generation.
+            s3_prefix: S3 folder prefix for this material.
+            material_id: Material identifier, used for logging.
+
+        Returns:
+            DDEC6 net charges per site, or ``None`` on failure.
+        """
+        try:
+            bucket = self.config.lematrho_bucket_name
+            with tempfile.TemporaryDirectory() as tmpdir:
+                # Download CHGCAR
+                key = f"{s3_prefix}/{STATIC_CALC_TYPE}/CHGCAR.gz"
+                data = download_gz_file_from_s3(self.aws_client, bucket, key)
+                with open(os.path.join(tmpdir, "CHGCAR"), "wb") as f:
+                    f.write(data)
+                del data
+
+                # Generate POTCAR
+                write_potcar(structure, tmpdir)
+
+                # Write chargemol job control file
+                self._write_chargemol_config(tmpdir)
+
+                # Run chargemol
+                env = os.environ.copy()
+                env["DDEC6_ATOMIC_DENSITIES_DIR"] = self._atomic_densities_path
+                subprocess.run(
+                    [self._chargemol_path],
+                    cwd=tmpdir,
+                    timeout=CHARGEMOL_TIMEOUT,
+                    check=True,
+                    capture_output=True,
+                    env=env,
+                )
+
+                return parse_ddec6_charges(tmpdir)
+
+        except subprocess.TimeoutExpired:
+            logger.warning(f"DDEC6 analysis timed out for {material_id}")
+            return None
+        except subprocess.CalledProcessError as e:
+            logger.warning(
+                f"DDEC6 subprocess failed for {material_id}: "
+                f"exit code {e.returncode}, stderr: {e.stderr}"
+            )
+            return None
+        except Exception as e:
+            logger.warning(f"DDEC6 analysis failed for {material_id}: {e}")
+            return None
+
+    def _write_chargemol_config(self, tmpdir: str) -> None:
+        """Write ``job_control.txt`` for chargemol DDEC6 analysis.
+
+        Args:
+            tmpdir: Directory where the config file will be written.
+        """
+        config_content = (
+            "<net charge>\n"
+            "0.0\n"
+            "</net charge>\n"
+            "<periodicity along A, B, and C vectors>\n"
+            ".true.\n"
+            ".true.\n"
+            ".true.\n"
+            "</periodicity along A, B, and C vectors>\n"
+            "<atomic densities directory complete path>\n"
+            f"{self._atomic_densities_path}\n"
+            "</atomic densities directory complete path>\n"
+            "<charge type>\n"
+            "DDEC6\n"
+            "</charge type>\n"
+            "<input filename>\n"
+            "CHGCAR\n"
+            "</input filename>\n"
+        )
+        with open(os.path.join(tmpdir, "job_control.txt"), "w") as f:
+            f.write(config_content)
diff --git a/src/lematerial_fetcher/fetcher/lematrho/utils.py b/src/lematerial_fetcher/fetcher/lematrho/utils.py
new file mode 100644
index 0000000..df49b67
--- /dev/null
+++ b/src/lematerial_fetcher/fetcher/lematrho/utils.py
@@ -0,0 +1,170 @@
+# Copyright 2025 Entalpic
+import gzip
+import os
+import tempfile
+from datetime import datetime
+from typing import Any, Optional
+
+from pymatgen.core import Structure
+from pymatgen.io.vasp import Chgcar, Vasprun
+
+from lematerial_fetcher.models.models import RawStructure
+from lematerial_fetcher.utils.logging import logger
+
+# ── S3 folder structure constants ──────────────────────────────────────────────
+STATIC_CALC_TYPE = "LeMatRhoStaticMaker"
+RELAX_CALC_TYPE = "LeMatRhoRelaxMaker_1"
+STATIC_FILES = ["CHGCAR.gz", "AECCAR0.gz", "AECCAR1.gz", "AECCAR2.gz"]
+RELAX_FILES = ["vasprun.xml.gz"]
+
+# Only process materials with these ID prefixes
+VALID_PREFIXES = ("oqmd-", "mp-", "agm")
+
+# Conservative default due to high memory usage per CHGCAR (~hundreds of MB)
+DEFAULT_MAX_WORKERS = 4
+
+# Map from S3 filename to compressed grid key name
+GRID_KEY_MAP = {
+    "CHGCAR.gz": "charge_density",
+    "AECCAR0.gz": "aeccar0",
+    "AECCAR1.gz": "aeccar1",
+    "AECCAR2.gz": "aeccar2",
+}
+
+# Subprocess timeout constants (seconds)
+BADER_TIMEOUT = 600
+CHGSUM_TIMEOUT = 300
+CHARGEMOL_TIMEOUT = 600
+
+
+def download_gz_file_from_s3(client: Any, bucket: str, key: str) -> bytes:
+    """Download and decompress a gzipped 
file from S3. + + Args: + client: Boto3 S3 client. + bucket: S3 bucket name. + key: S3 object key. + + Returns: + Decompressed file contents as raw bytes. + """ + response = client.get_object(Bucket=bucket, Key=key) + body = response["Body"] + try: + compressed = body.read() + decompressed = gzip.decompress(compressed) + del compressed + return decompressed + finally: + body.close() + + +def parse_vasprun_structure(vasprun_bytes: bytes) -> Structure: + """Parse a vasprun.xml to extract the final relaxed structure. + + Writes bytes to a temporary file because pymatgen's ``Vasprun`` requires + a filesystem path, not a file-like object. + + Args: + vasprun_bytes: Raw vasprun.xml content. + + Returns: + The final relaxed pymatgen Structure. + """ + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "vasprun.xml") + with open(path, "wb") as f: + f.write(vasprun_bytes) + vasprun = Vasprun( + path, + parse_dos=False, + parse_eigen=False, + parse_potcar_file=False, + ) + return vasprun.final_structure + + +def compress_chgcar(chgcar_bytes: bytes, grid_shape: tuple[int, int, int]) -> list: + """Parse a CHGCAR file and compress its charge density using pyrho. + + Writes bytes to a temporary file because pymatgen's ``Chgcar.from_file`` + requires a filesystem path, not a file-like object. + + Args: + chgcar_bytes: Raw CHGCAR file content (uncompressed VASP format). + grid_shape: Target grid shape for lossy compression, e.g. ``(15, 15, 15)``. + + Returns: + Compressed charge density grid as a nested Python list. 
+ """ + from pyrho.charge_density import ChargeDensity + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "CHGCAR") + with open(path, "wb") as f: + f.write(chgcar_bytes) + chgcar = Chgcar.from_file(path) + charge_density = ChargeDensity.from_pmg(chgcar) + compressed = charge_density.pgrids["total"].lossy_smooth_compression(grid_shape) + result = compressed.tolist() + del chgcar, charge_density, compressed + return result + + +def build_raw_structure( + material_id: str, + structure: Structure, + compressed_grids: dict[str, Optional[list]], + grid_shape: tuple[int, int, int], + s3_prefix: str, +) -> RawStructure: + """Build a RawStructure from parsed charge density data. + + Args: + material_id: Material identifier, e.g. ``"agm000001"``. + structure: Pymatgen Structure parsed from vasprun.xml. + compressed_grids: Dict mapping grid names (``"charge_density"``, + ``"aeccar0"``, ``"aeccar1"``, ``"aeccar2"``) to compressed + grid lists or ``None``. + grid_shape: Grid shape used for compression. + s3_prefix: S3 prefix path for the material folder. + + Returns: + A ``RawStructure`` ready for database insertion. + """ + attributes = { + "structure": structure.as_dict(), + "compressed_charge_density": compressed_grids.get("charge_density"), + "compressed_aeccar0": compressed_grids.get("aeccar0"), + "compressed_aeccar1": compressed_grids.get("aeccar1"), + "compressed_aeccar2": compressed_grids.get("aeccar2"), + "grid_shape": list(grid_shape), + "s3_prefix": s3_prefix, + } + + return RawStructure( + id=material_id, + type="lematrho", + attributes=attributes, + last_modified=datetime.now(), + ) + + +def write_potcar(structure: Structure, tmpdir: str) -> None: + """Generate a POTCAR file for the given structure. + + Uses ``MatPESStaticSet`` to select pseudopotentials consistent with + Materials Project settings and writes the resulting POTCAR to *tmpdir*. + + Args: + structure: Pymatgen Structure for which to generate the POTCAR. 
+ tmpdir: Directory where ``POTCAR`` will be written. + + Raises: + OSError: If ``PMG_VASP_PSP_DIR`` is not set or the pseudopotential + files cannot be found. + """ + from pymatgen.io.vasp.sets import MatPESStaticSet + + input_set = MatPESStaticSet(structure) + input_set.potcar.write_file(os.path.join(tmpdir, "POTCAR")) diff --git a/src/lematerial_fetcher/models/optimade.py b/src/lematerial_fetcher/models/optimade.py index ac07685..065cce9 100644 --- a/src/lematerial_fetcher/models/optimade.py +++ b/src/lematerial_fetcher/models/optimade.py @@ -184,6 +184,45 @@ class OptimadeStructure(BaseModel): description="BAWL fingerprint hash", ) + # Charge density fields (LeMatRho) + compressed_charge_density: Optional[list] = Field( + None, + description="Compressed charge density grid from pyrho lossy compression", + ) + compressed_aeccar0: Optional[list] = Field( + None, + description="Compressed AECCAR0 (all-electron core charge density) grid", + ) + compressed_aeccar1: Optional[list] = Field( + None, + description="Compressed AECCAR1 (pseudo valence charge density) grid", + ) + compressed_aeccar2: Optional[list] = Field( + None, + description="Compressed AECCAR2 (pseudo core charge density) grid", + ) + charge_density_grid_shape: Optional[list[int]] = Field( + None, + min_length=3, + max_length=3, + description="Shape of the compressed charge density grid [nx, ny, nz]", + ) + bader_charges: Optional[list[float]] = Field( + None, + min_length=1, + description="Bader charges per site", + ) + bader_atomic_volume: Optional[list[float]] = Field( + None, + min_length=1, + description="Bader atomic volumes per site", + ) + ddec6_charges: Optional[list[float]] = Field( + None, + min_length=1, + description="DDEC6 charges per site", + ) + def __init__( self, compute_space_group: bool = True, @@ -513,6 +552,15 @@ def check_consistency(self): self.charges = self._validate_with_number_of_sites( self.charges, nsites, "charges" ) + self.bader_charges = 
self._validate_with_number_of_sites( + self.bader_charges, nsites, "bader_charges" + ) + self.bader_atomic_volume = self._validate_with_number_of_sites( + self.bader_atomic_volume, nsites, "bader_atomic_volume" + ) + self.ddec6_charges = self._validate_with_number_of_sites( + self.ddec6_charges, nsites, "ddec6_charges" + ) # Validation using the Pymatgen structure structure = Structure( diff --git a/src/lematerial_fetcher/models/utils/enums.py b/src/lematerial_fetcher/models/utils/enums.py index 7525b5a..828aa89 100644 --- a/src/lematerial_fetcher/models/utils/enums.py +++ b/src/lematerial_fetcher/models/utils/enums.py @@ -13,3 +13,4 @@ class Source(str, Enum): ALEXANDRIA = "alexandria" MP = "mp" OQMD = "oqmd" + LEMATRHO = "lematrho" diff --git a/src/lematerial_fetcher/push.py b/src/lematerial_fetcher/push.py index c8f01cc..df43a56 100644 --- a/src/lematerial_fetcher/push.py +++ b/src/lematerial_fetcher/push.py @@ -135,6 +135,14 @@ def _get_optimade_features(self) -> Features: "cross_compatibility": Value("bool"), "bawl_fingerprint": Value("string"), "space_group_it_number": Value("int32"), + "compressed_charge_density": Value("string"), + "compressed_aeccar0": Value("string"), + "compressed_aeccar1": Value("string"), + "compressed_aeccar2": Value("string"), + "charge_density_grid_shape": Sequence(Value("int32")), + "bader_charges": Sequence(Value("float64")), + "bader_atomic_volume": Sequence(Value("float64")), + "ddec6_charges": Sequence(Value("float64")), } ) @@ -171,6 +179,14 @@ def _get_trajectories_features(self) -> Features: del features["charges"] del features["total_magnetization"] del features["bawl_fingerprint"] + del features["compressed_charge_density"] + del features["compressed_aeccar0"] + del features["compressed_aeccar1"] + del features["compressed_aeccar2"] + del features["charge_density_grid_shape"] + del features["bader_charges"] + del features["bader_atomic_volume"] + del features["ddec6_charges"] convert_features_dict.update( { @@ -467,6 
+483,34 @@ def convert_species(batch): desc="Converting species column to string", ) + # Convert compressed charge density fields from nested lists to JSON strings + json_serialized_columns = [ + "compressed_charge_density", + "compressed_aeccar0", + "compressed_aeccar1", + "compressed_aeccar2", + ] + columns_to_convert = [ + col + for col in json_serialized_columns + if col in dataset["train"].column_names + ] + if columns_to_convert: + + def convert_charge_density(batch): + for col in columns_to_convert: + batch[col] = [ + json.dumps(v) if v is not None else None for v in batch[col] + ] + return batch + + dataset = dataset.map( + convert_charge_density, + batched=True, + num_proc=self.config.num_workers, + desc="Converting charge density columns to string", + ) + for split in dataset.keys(): dataset[split] = dataset[split].cast( features=self.features, num_proc=self.config.num_workers diff --git a/src/lematerial_fetcher/utils/aws.py b/src/lematerial_fetcher/utils/aws.py index 154e464..402606a 100644 --- a/src/lematerial_fetcher/utils/aws.py +++ b/src/lematerial_fetcher/utils/aws.py @@ -27,6 +27,32 @@ def get_aws_client(region_name: str = "us-east-1"): return s3_client +def get_authenticated_aws_client(region_name: str = "us-east-1"): + """Returns a configured S3 client using the default AWS credential chain. + + Uses AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN, + or IAM roles for authentication. Includes adaptive retry configuration. + + Parameters + ---------- + region_name: str, default='us-east-1' + The AWS region for the S3 client. 
+ + Returns + ------- + s3_client: boto3.client + A configured S3 client with authenticated credentials + """ + s3_client = boto3.client( + "s3", + config=Config( + retries={"max_attempts": 3, "mode": "adaptive"}, + region_name=region_name, + ), + ) + return s3_client + + def get_latest_collection_version_prefix( client, bucket_name: str, bucket_prefix: str, collections_prefix: str ) -> str: diff --git a/src/lematerial_fetcher/utils/cli.py b/src/lematerial_fetcher/utils/cli.py index 20c2f15..8608cba 100644 --- a/src/lematerial_fetcher/utils/cli.py +++ b/src/lematerial_fetcher/utils/cli.py @@ -250,6 +250,165 @@ def add_mp_fetch_options(f): return f +def add_lematrho_fetch_options(f): + """Add LeMatRho fetch options to a command.""" + decorators = [ + click.option( + "--lematrho-bucket-name", + type=str, + default="lemat-rho", + envvar="LEMATERIALFETCHER_LEMATRHO_BUCKET_NAME", + help="LeMatRho S3 bucket name.", + ), + click.option( + "--grid-shape", + type=(int, int, int), + default=(15, 15, 15), + envvar="LEMATERIALFETCHER_LEMATRHO_GRID_SHAPE", + help="Grid shape for pyrho charge density compression (nx ny nz).", + ), + ] + for decorator in reversed(decorators): + f = decorator(f) + return f + + +def add_lematrho_transform_options(f): + """Add LeMatRho transform options to a command.""" + decorators = [ + click.option( + "--lematrho-bucket-name", + type=str, + default="lemat-rho", + envvar="LEMATERIALFETCHER_LEMATRHO_BUCKET_NAME", + help="LeMatRho S3 bucket name for re-downloading raw files during transform.", + ), + click.option( + "--bader-path", + type=str, + envvar="LEMATERIALFETCHER_BADER_PATH", + help="Path to the bader executable. If not provided, will search PATH.", + ), + click.option( + "--chargemol-path", + type=str, + envvar="LEMATERIALFETCHER_CHARGEMOL_PATH", + help="Path to the chargemol executable. 
If not provided, will search PATH.", + ), + click.option( + "--chgsum-script-path", + type=str, + envvar="LEMATERIALFETCHER_CHGSUM_SCRIPT_PATH", + help="Path to the chgsum.pl perl script for Bader charge summation.", + ), + click.option( + "--atomic-densities-path", + type=str, + envvar="LEMATERIALFETCHER_ATOMIC_DENSITIES_PATH", + help="Path to atomic densities directory for DDEC6/chargemol analysis.", + ), + ] + for decorator in reversed(decorators): + f = decorator(f) + return f + + +def add_lematrho_direct_options(f): + """Add options for the direct LeMatRho S3-to-Parquet pipeline.""" + decorators = [ + click.option( + "--output-dir", + type=str, + default="./lematrho_output", + envvar="LEMATERIALFETCHER_LEMATRHO_OUTPUT_DIR", + help="Directory to write Parquet output files.", + ), + click.option( + "--parquet-chunk-size", + type=int, + default=1000, + envvar="LEMATERIALFETCHER_LEMATRHO_PARQUET_CHUNK_SIZE", + help="Number of rows per Parquet chunk file.", + ), + click.option( + "--num-workers", + type=int, + default=4, + envvar="LEMATERIALFETCHER_NUM_WORKERS", + help="Number of parallel worker processes.", + ), + click.option( + "--log-every", + type=int, + default=100, + envvar="LEMATERIALFETCHER_LOG_EVERY", + help="Log progress every N materials.", + ), + click.option( + "--limit", + type=int, + default=None, + envvar="LEMATERIALFETCHER_LEMATRHO_LIMIT", + help="Max number of materials to process (for testing). Default: no limit.", + ), + click.option( + "--lematrho-bucket-name", + type=str, + default="lemat-rho", + envvar="LEMATERIALFETCHER_LEMATRHO_BUCKET_NAME", + help="LeMatRho S3 bucket name.", + ), + click.option( + "--grid-shape", + type=(int, int, int), + default=(15, 15, 15), + envvar="LEMATERIALFETCHER_LEMATRHO_GRID_SHAPE", + help="Grid shape for pyrho charge density compression (nx ny nz).", + ), + click.option( + "--hf-repo-id", + type=str, + default=None, + envvar="LEMATERIALFETCHER_HF_REPO_ID", + help="HuggingFace repository ID for pushing. 
If not set, skip push.", + ), + click.option( + "--hf-token", + type=str, + default=None, + envvar="LEMATERIALFETCHER_HF_TOKEN", + help="HuggingFace token for pushing.", + ), + click.option( + "--bader-path", + type=str, + envvar="LEMATERIALFETCHER_BADER_PATH", + help="Path to the bader executable. If not provided, will search PATH.", + ), + click.option( + "--chargemol-path", + type=str, + envvar="LEMATERIALFETCHER_CHARGEMOL_PATH", + help="Path to the chargemol executable. If not provided, will search PATH.", + ), + click.option( + "--chgsum-script-path", + type=str, + envvar="LEMATERIALFETCHER_CHGSUM_SCRIPT_PATH", + help="Path to the chgsum.pl perl script for Bader charge summation.", + ), + click.option( + "--atomic-densities-path", + type=str, + envvar="LEMATERIALFETCHER_ATOMIC_DENSITIES_PATH", + help="Path to atomic densities directory for DDEC6/chargemol analysis.", + ), + ] + for decorator in reversed(decorators): + f = decorator(f) + return f + + def add_push_options(f): """Add push options to a command.""" decorators = [ diff --git a/src/lematerial_fetcher/utils/config.py b/src/lematerial_fetcher/utils/config.py index f7c5746..2478e83 100644 --- a/src/lematerial_fetcher/utils/config.py +++ b/src/lematerial_fetcher/utils/config.py @@ -30,6 +30,8 @@ class FetcherConfig(BaseConfig): mp_bucket_prefix: str mysql_config: Optional[dict] = None oqmd_download_dir: Optional[str] = None + lematrho_bucket_name: Optional[str] = None + lematrho_grid_shape: Optional[tuple[int, int, int]] = None @dataclass @@ -43,6 +45,35 @@ class TransformerConfig(BaseConfig): db_fetch_batch_size: Optional[int] = None mp_task_table_name: Optional[str] = None mysql_config: Optional[dict] = None + lematrho_bucket_name: Optional[str] = None + bader_path: Optional[str] = None + chargemol_path: Optional[str] = None + chgsum_script_path: Optional[str] = None + atomic_densities_path: Optional[str] = None + + +@dataclass +class DirectPipelineConfig: + """Config for the direct S3-to-Parquet 
pipeline (no PostgreSQL).""" + + # S3 source + lematrho_bucket_name: str = "lemat-rho" + lematrho_grid_shape: tuple[int, int, int] = (15, 15, 15) + # Output + output_dir: str = "./lematrho_output" + parquet_chunk_size: int = 1000 + # Processing + num_workers: int = 4 + log_every: int = 100 + limit: Optional[int] = None + # HuggingFace (optional) + hf_repo_id: Optional[str] = None + hf_token: Optional[str] = None + # External tools (all optional — missing tools result in None fields) + bader_path: Optional[str] = None + chargemol_path: Optional[str] = None + chgsum_script_path: Optional[str] = None + atomic_densities_path: Optional[str] = None @dataclass @@ -156,6 +187,8 @@ def load_fetcher_config( mp_bucket_name: Optional[str] = None, mp_bucket_prefix: Optional[str] = None, oqmd_download_dir: Optional[str] = None, + lematrho_bucket_name: Optional[str] = None, + lematrho_grid_shape: Optional[tuple[int, int, int]] = None, mysql_host: str = "localhost", mysql_user: Optional[str] = None, # No MySQL password parameter @@ -189,6 +222,8 @@ def load_fetcher_config( "mp_bucket_name": mp_bucket_name, "mp_bucket_prefix": mp_bucket_prefix, "oqmd_download_dir": oqmd_download_dir, + "lematrho_bucket_name": lematrho_bucket_name, + "lematrho_grid_shape": lematrho_grid_shape, } # Validate required fields @@ -245,6 +280,12 @@ def load_transformer_config( mysql_user: Optional[str] = None, mysql_database: str = "lematerial", mysql_cert_path: Optional[str] = None, + # LeMatRho-specific params + lematrho_bucket_name: Optional[str] = None, + bader_path: Optional[str] = None, + chargemol_path: Optional[str] = None, + chgsum_script_path: Optional[str] = None, + atomic_densities_path: Optional[str] = None, **base_config_kwargs: Any, ) -> TransformerConfig: """Loads transformer config from passed arguments. 
@@ -337,6 +378,11 @@ def load_transformer_config( **base_config, **config, mysql_config=mysql_config, + lematrho_bucket_name=lematrho_bucket_name, + bader_path=bader_path, + chargemol_path=chargemol_path, + chgsum_script_path=chgsum_script_path, + atomic_densities_path=atomic_densities_path, ) @@ -399,3 +445,41 @@ def load_push_config( **base_config, **config, ) + + +def load_direct_pipeline_config( + lematrho_bucket_name: str = "lemat-rho", + grid_shape: tuple[int, int, int] = (15, 15, 15), + output_dir: str = "./lematrho_output", + parquet_chunk_size: int = 1000, + num_workers: int = 4, + log_every: int = 100, + limit: Optional[int] = None, + hf_repo_id: Optional[str] = None, + hf_token: Optional[str] = None, + bader_path: Optional[str] = None, + chargemol_path: Optional[str] = None, + chgsum_script_path: Optional[str] = None, + atomic_densities_path: Optional[str] = None, + **_kwargs: Any, +) -> DirectPipelineConfig: + """Load config for the direct S3-to-Parquet pipeline. + + The common workflow is that arguments are passed by Click. + No database credentials needed — this pipeline writes Parquet directly. 
+ """ + return DirectPipelineConfig( + lematrho_bucket_name=lematrho_bucket_name, + lematrho_grid_shape=grid_shape, + output_dir=output_dir, + parquet_chunk_size=parquet_chunk_size, + num_workers=num_workers, + log_every=log_every, + limit=limit, + hf_repo_id=hf_repo_id, + hf_token=hf_token, + bader_path=bader_path, + chargemol_path=chargemol_path, + chgsum_script_path=chgsum_script_path, + atomic_densities_path=atomic_densities_path, + ) diff --git a/tests/fetcher/lematrho/test_lematrho_fetch.py b/tests/fetcher/lematrho/test_lematrho_fetch.py new file mode 100644 index 0000000..fc7f5bd --- /dev/null +++ b/tests/fetcher/lematrho/test_lematrho_fetch.py @@ -0,0 +1,616 @@ +# Copyright 2025 Entalpic +import gzip +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from lematerial_fetcher.database.postgres import DatasetVersions, StructuresDatabase +from lematerial_fetcher.fetch import ItemsInfo +from lematerial_fetcher.fetcher.lematrho.fetch import ( + DEFAULT_MAX_WORKERS, + RELAX_CALC_TYPE, + STATIC_CALC_TYPE, + STATIC_FILES, + VALID_PREFIXES, + LeMatRhoFetcher, +) +from lematerial_fetcher.fetcher.lematrho.utils import ( + build_raw_structure, + download_gz_file_from_s3, +) +from lematerial_fetcher.utils.config import FetcherConfig + + +@pytest.fixture +def mock_aws_client(): + return MagicMock() + + +@pytest.fixture +def mock_db(): + return MagicMock(spec=StructuresDatabase) + + +@pytest.fixture +def mock_version_db(): + return MagicMock(spec=DatasetVersions) + + +@pytest.fixture +def mock_config(): + return FetcherConfig( + base_url="https://api.test.com", + db_conn_str="postgresql://test:test@localhost:5432/test", + table_name="test_lematrho_raw", + page_limit=10, + page_offset=0, + mp_bucket_name="", + mp_bucket_prefix="", + log_dir="./logs", + max_retries=3, + num_workers=2, + retry_delay=2, + log_every=100, + lematrho_bucket_name="lemat-rho", + lematrho_grid_shape=(15, 15, 15), + ) + + +# 
--------------------------------------------------------------------------- +# get_items_to_process tests +# --------------------------------------------------------------------------- + + +class TestGetItemsToProcess: + def test_filters_by_valid_prefix(self, mock_aws_client, mock_config, mock_version_db): + """Only folders with oqmd-, mp-, or agm prefixes should be returned.""" + mock_paginator = MagicMock() + mock_aws_client.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + { + "CommonPrefixes": [ + {"Prefix": "agm000001/"}, + {"Prefix": "mp-12345/"}, + {"Prefix": "oqmd-678/"}, + {"Prefix": "test-invalid/"}, + {"Prefix": "random-folder/"}, + ] + } + ] + + with ( + patch("lematerial_fetcher.fetch.DatasetVersions") as mock_ver_cls, + patch("lematerial_fetcher.fetch.StructuresDatabase"), + ): + mock_ver_cls.return_value = mock_version_db + mock_version_db.get_last_synced_version.return_value = None + + fetcher = LeMatRhoFetcher(config=mock_config, debug=True) + fetcher.aws_client = mock_aws_client + + items = fetcher.get_items_to_process() + + assert items.total_count == 3 + assert set(items.items) == {"agm000001", "mp-12345", "oqmd-678"} + + def test_ignores_unknown_prefixes(self, mock_aws_client, mock_config, mock_version_db): + """Folders like 'test-123/' or 'data/' should be excluded.""" + mock_paginator = MagicMock() + mock_aws_client.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + { + "CommonPrefixes": [ + {"Prefix": "test-123/"}, + {"Prefix": "data/"}, + {"Prefix": "backup-20240101/"}, + ] + } + ] + + with ( + patch("lematerial_fetcher.fetch.DatasetVersions") as mock_ver_cls, + patch("lematerial_fetcher.fetch.StructuresDatabase"), + ): + mock_ver_cls.return_value = mock_version_db + mock_version_db.get_last_synced_version.return_value = None + + fetcher = LeMatRhoFetcher(config=mock_config, debug=True) + fetcher.aws_client = mock_aws_client + + items = 
fetcher.get_items_to_process() + + assert items.total_count == 0 + assert items.items == [] + + def test_handles_empty_bucket(self, mock_aws_client, mock_config, mock_version_db): + """Empty bucket should return zero items.""" + mock_paginator = MagicMock() + mock_aws_client.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [{}] + + with ( + patch("lematerial_fetcher.fetch.DatasetVersions") as mock_ver_cls, + patch("lematerial_fetcher.fetch.StructuresDatabase"), + ): + mock_ver_cls.return_value = mock_version_db + mock_version_db.get_last_synced_version.return_value = None + + fetcher = LeMatRhoFetcher(config=mock_config, debug=True) + fetcher.aws_client = mock_aws_client + + items = fetcher.get_items_to_process() + + assert items.total_count == 0 + assert items.items == [] + + def test_handles_pagination(self, mock_aws_client, mock_config, mock_version_db): + """Should handle multiple pages of results.""" + mock_paginator = MagicMock() + mock_aws_client.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + {"CommonPrefixes": [{"Prefix": "agm000001/"}]}, + {"CommonPrefixes": [{"Prefix": "agm000002/"}]}, + ] + + with ( + patch("lematerial_fetcher.fetch.DatasetVersions") as mock_ver_cls, + patch("lematerial_fetcher.fetch.StructuresDatabase"), + ): + mock_ver_cls.return_value = mock_version_db + mock_version_db.get_last_synced_version.return_value = None + + fetcher = LeMatRhoFetcher(config=mock_config, debug=True) + fetcher.aws_client = mock_aws_client + + items = fetcher.get_items_to_process() + + assert items.total_count == 2 + assert items.items == ["agm000001", "agm000002"] + + def test_raises_without_bucket_name(self, mock_aws_client, mock_version_db): + """Should raise ValueError if bucket name is not configured.""" + config = FetcherConfig( + base_url="https://api.test.com", + db_conn_str="postgresql://test:test@localhost:5432/test", + table_name="test_table", + page_limit=10, + 
page_offset=0, + mp_bucket_name="", + mp_bucket_prefix="", + log_dir="./logs", + max_retries=3, + num_workers=2, + retry_delay=2, + log_every=100, + lematrho_bucket_name=None, + ) + + with ( + patch("lematerial_fetcher.fetch.DatasetVersions") as mock_ver_cls, + patch("lematerial_fetcher.fetch.StructuresDatabase"), + ): + mock_ver_cls.return_value = mock_version_db + mock_version_db.get_last_synced_version.return_value = None + + fetcher = LeMatRhoFetcher(config=config, debug=True) + fetcher.aws_client = mock_aws_client + + with pytest.raises(ValueError, match="lematrho_bucket_name"): + fetcher.get_items_to_process() + + +# --------------------------------------------------------------------------- +# _process_batch tests +# --------------------------------------------------------------------------- + + +class TestProcessBatch: + def test_happy_path(self, mock_config): + """Successful processing: downloads vasprun + all 4 charge files, inserts to DB.""" + mock_client = MagicMock() + mock_db_instance = MagicMock(spec=StructuresDatabase) + mock_structure = MagicMock() + mock_structure.as_dict.return_value = {"lattice": {}, "sites": []} + + with ( + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.get_authenticated_aws_client" + ) as mock_auth, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.StructuresDatabase" + ) as mock_db_cls, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.download_gz_file_from_s3" + ) as mock_download, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.parse_vasprun_structure" + ) as mock_parse, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.compress_chgcar" + ) as mock_compress, + ): + mock_auth.return_value = mock_client + mock_db_cls.return_value = mock_db_instance + mock_download.return_value = b"fake data" + mock_parse.return_value = mock_structure + mock_compress.return_value = [[[1.0, 2.0]]] + + result = LeMatRhoFetcher._process_batch( + "agm000001", mock_config, {"occurred": False} + ) + + assert result is True + 
mock_db_instance.insert_data.assert_called_once() + # vasprun + 4 charge files = 5 downloads + assert mock_download.call_count == 5 + # 4 CHGCAR/AECCAR compressions + assert mock_compress.call_count == 4 + + def test_missing_vasprun_returns_false(self, mock_config): + """If vasprun.xml.gz is missing, should return False.""" + mock_client = MagicMock() + + with ( + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.get_authenticated_aws_client" + ) as mock_auth, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.StructuresDatabase" + ), + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.download_gz_file_from_s3" + ) as mock_download, + ): + mock_auth.return_value = mock_client + mock_download.side_effect = Exception("NoSuchKey: vasprun.xml.gz") + + result = LeMatRhoFetcher._process_batch( + "agm000001", mock_config, {"occurred": False} + ) + + assert result is False + + def test_missing_aeccar1_still_processes_others(self, mock_config): + """If AECCAR1.gz is missing, other files should still be processed.""" + mock_client = MagicMock() + mock_db_instance = MagicMock(spec=StructuresDatabase) + mock_structure = MagicMock() + mock_structure.as_dict.return_value = {"lattice": {}, "sites": []} + + def download_side_effect(client, bucket, key): + if "AECCAR1.gz" in key: + raise Exception("NoSuchKey") + return b"fake data" + + with ( + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.get_authenticated_aws_client" + ) as mock_auth, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.StructuresDatabase" + ) as mock_db_cls, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.download_gz_file_from_s3" + ) as mock_download, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.parse_vasprun_structure" + ) as mock_parse, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.compress_chgcar" + ) as mock_compress, + ): + mock_auth.return_value = mock_client + mock_db_cls.return_value = mock_db_instance + mock_download.side_effect = download_side_effect + 
mock_parse.return_value = mock_structure + mock_compress.return_value = [[[1.0]]] + + result = LeMatRhoFetcher._process_batch( + "agm000001", mock_config, {"occurred": False} + ) + + assert result is True + mock_db_instance.insert_data.assert_called_once() + # vasprun + 3 successful charge files (AECCAR1 failed) = 4 downloads succeed + # But download is called 5 times (1 vasprun + 4 charge files, 1 raises) + assert mock_download.call_count == 5 + # Only 3 compressions (AECCAR1 failed before compression) + assert mock_compress.call_count == 3 + + def test_all_charge_files_missing_still_inserts(self, mock_config): + """If all charge files fail but vasprun succeeds, still insert with None grids.""" + mock_client = MagicMock() + mock_db_instance = MagicMock(spec=StructuresDatabase) + mock_structure = MagicMock() + mock_structure.as_dict.return_value = {"lattice": {}, "sites": []} + + call_count = {"n": 0} + + def download_side_effect(client, bucket, key): + call_count["n"] += 1 + if "vasprun" in key: + return b"fake vasprun" + raise Exception("NoSuchKey") + + with ( + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.get_authenticated_aws_client" + ) as mock_auth, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.StructuresDatabase" + ) as mock_db_cls, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.download_gz_file_from_s3" + ) as mock_download, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.parse_vasprun_structure" + ) as mock_parse, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.compress_chgcar" + ) as mock_compress, + ): + mock_auth.return_value = mock_client + mock_db_cls.return_value = mock_db_instance + mock_download.side_effect = download_side_effect + mock_parse.return_value = mock_structure + + result = LeMatRhoFetcher._process_batch( + "agm000001", mock_config, {"occurred": False} + ) + + assert result is True + mock_db_instance.insert_data.assert_called_once() + # No compressions since all charge file downloads failed + 
mock_compress.assert_not_called() + + # Verify the inserted structure has None grids + inserted = mock_db_instance.insert_data.call_args[0][0] + assert inserted.attributes["compressed_charge_density"] is None + assert inserted.attributes["compressed_aeccar0"] is None + assert inserted.attributes["compressed_aeccar1"] is None + assert inserted.attributes["compressed_aeccar2"] is None + + def test_critical_error_sets_manager_flag(self, mock_config): + """A connection error should flag the manager_dict for shutdown.""" + manager_dict = {"occurred": False} + + with ( + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.get_authenticated_aws_client" + ) as mock_auth, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.StructuresDatabase" + ), + ): + mock_auth.side_effect = Exception("Connection refused") + + result = LeMatRhoFetcher._process_batch( + "agm000001", mock_config, manager_dict + ) + + assert result is False + assert manager_dict["occurred"] is True + + def test_correct_s3_keys(self, mock_config): + """Verify the exact S3 keys constructed for downloads.""" + mock_client = MagicMock() + mock_db_instance = MagicMock(spec=StructuresDatabase) + mock_structure = MagicMock() + mock_structure.as_dict.return_value = {} + + download_calls = [] + + def capture_downloads(client, bucket, key): + download_calls.append(key) + return b"data" + + with ( + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.get_authenticated_aws_client" + ) as mock_auth, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.StructuresDatabase" + ) as mock_db_cls, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.download_gz_file_from_s3" + ) as mock_download, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.parse_vasprun_structure" + ) as mock_parse, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.compress_chgcar" + ) as mock_compress, + ): + mock_auth.return_value = mock_client + mock_db_cls.return_value = mock_db_instance + mock_download.side_effect = capture_downloads + 
mock_parse.return_value = mock_structure + mock_compress.return_value = [] + + LeMatRhoFetcher._process_batch( + "agm000001", mock_config, {"occurred": False} + ) + + expected_keys = [ + f"agm000001/{RELAX_CALC_TYPE}/vasprun.xml.gz", + f"agm000001/{STATIC_CALC_TYPE}/CHGCAR.gz", + f"agm000001/{STATIC_CALC_TYPE}/AECCAR0.gz", + f"agm000001/{STATIC_CALC_TYPE}/AECCAR1.gz", + f"agm000001/{STATIC_CALC_TYPE}/AECCAR2.gz", + ] + assert download_calls == expected_keys + + def test_uses_config_grid_shape(self, mock_config): + """Verify that the configured grid shape is passed to compress_chgcar.""" + mock_client = MagicMock() + mock_db_instance = MagicMock(spec=StructuresDatabase) + mock_structure = MagicMock() + mock_structure.as_dict.return_value = {} + + mock_config.lematrho_grid_shape = (20, 20, 20) + + with ( + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.get_authenticated_aws_client" + ) as mock_auth, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.StructuresDatabase" + ) as mock_db_cls, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.download_gz_file_from_s3" + ) as mock_download, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.parse_vasprun_structure" + ) as mock_parse, + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.compress_chgcar" + ) as mock_compress, + ): + mock_auth.return_value = mock_client + mock_db_cls.return_value = mock_db_instance + mock_download.return_value = b"data" + mock_parse.return_value = mock_structure + mock_compress.return_value = [] + + LeMatRhoFetcher._process_batch( + "agm000001", mock_config, {"occurred": False} + ) + + # All 4 compress calls should use (20, 20, 20) + for call in mock_compress.call_args_list: + assert call[0][1] == (20, 20, 20) + + +# --------------------------------------------------------------------------- +# Utility function tests +# --------------------------------------------------------------------------- + + +class TestDownloadGzFile: + def test_decompresses_gzipped_content(self): + 
"""Verify gzip decompression works correctly.""" + import io + + original = b"hello world test content" + compressed = gzip.compress(original) + + mock_client = MagicMock() + mock_client.get_object.return_value = { + "Body": MagicMock(read=MagicMock(return_value=compressed)) + } + + result = download_gz_file_from_s3(mock_client, "bucket", "key.gz") + assert result == original + + def test_propagates_s3_errors(self): + """S3 download errors should propagate.""" + mock_client = MagicMock() + mock_client.get_object.side_effect = Exception("NoSuchKey") + + with pytest.raises(Exception, match="NoSuchKey"): + download_gz_file_from_s3(mock_client, "bucket", "key.gz") + + +class TestBuildRawStructure: + def test_builds_correct_structure(self): + """Verify the RawStructure has correct fields and type.""" + from pymatgen.core import Lattice, Structure + + structure = Structure( + Lattice.cubic(3.0), + ["Si", "Si"], + [[0, 0, 0], [0.5, 0.5, 0.5]], + ) + compressed_grids = { + "charge_density": [[[1.0]]], + "aeccar0": [[[2.0]]], + "aeccar1": None, + "aeccar2": [[[4.0]]], + } + + raw = build_raw_structure( + material_id="agm000001", + structure=structure, + compressed_grids=compressed_grids, + grid_shape=(15, 15, 15), + s3_prefix="agm000001", + ) + + assert raw.id == "agm000001" + assert raw.type == "lematrho" + assert raw.attributes["compressed_charge_density"] == [[[1.0]]] + assert raw.attributes["compressed_aeccar0"] == [[[2.0]]] + assert raw.attributes["compressed_aeccar1"] is None + assert raw.attributes["compressed_aeccar2"] == [[[4.0]]] + assert raw.attributes["grid_shape"] == [15, 15, 15] + assert raw.attributes["s3_prefix"] == "agm000001" + assert raw.attributes["structure"] is not None + assert raw.last_modified is not None + + +# --------------------------------------------------------------------------- +# Constants tests +# --------------------------------------------------------------------------- + + +def test_valid_prefixes(): + """Verify VALID_PREFIXES covers 
expected material ID patterns.""" + assert "oqmd-123".startswith(VALID_PREFIXES) + assert "mp-456".startswith(VALID_PREFIXES) + assert "agm000001".startswith(VALID_PREFIXES) + assert not "test-789".startswith(VALID_PREFIXES) + assert not "random".startswith(VALID_PREFIXES) + + +def test_static_files_list(): + """Verify STATIC_FILES contains all expected charge density files.""" + assert "CHGCAR.gz" in STATIC_FILES + assert "AECCAR0.gz" in STATIC_FILES + assert "AECCAR1.gz" in STATIC_FILES + assert "AECCAR2.gz" in STATIC_FILES + assert len(STATIC_FILES) == 4 + + +def test_calc_type_constants(): + """Verify S3 subfolder constants.""" + assert STATIC_CALC_TYPE == "LeMatRhoStaticMaker" + assert RELAX_CALC_TYPE == "LeMatRhoRelaxMaker_1" + + +# --------------------------------------------------------------------------- +# Fetcher lifecycle tests +# --------------------------------------------------------------------------- + + +class TestLeMatRhoFetcher: + def test_setup_resources(self, mock_config, mock_aws_client, mock_version_db): + """Test that setup_resources initializes the authenticated AWS client.""" + with ( + patch( + "lematerial_fetcher.fetcher.lematrho.fetch.get_authenticated_aws_client" + ) as mock_auth, + patch("lematerial_fetcher.fetch.StructuresDatabase") as mock_db_cls, + patch("lematerial_fetcher.fetch.DatasetVersions") as mock_ver_cls, + ): + mock_auth.return_value = mock_aws_client + mock_ver_cls.return_value = mock_version_db + + fetcher = LeMatRhoFetcher(config=mock_config) + fetcher.setup_resources() + + mock_auth.assert_called_once() + assert fetcher.aws_client is mock_aws_client + + def test_get_new_version_returns_today(self, mock_config, mock_version_db): + """Version should be today's date.""" + with ( + patch("lematerial_fetcher.fetch.DatasetVersions") as mock_ver_cls, + patch("lematerial_fetcher.fetch.StructuresDatabase"), + ): + mock_ver_cls.return_value = mock_version_db + + fetcher = LeMatRhoFetcher(config=mock_config, debug=True) + 
version = fetcher.get_new_version() + + assert version == datetime.now().strftime("%Y-%m-%d") diff --git a/tests/fetcher/lematrho/test_lematrho_pipeline.py b/tests/fetcher/lematrho/test_lematrho_pipeline.py new file mode 100644 index 0000000..b377119 --- /dev/null +++ b/tests/fetcher/lematrho/test_lematrho_pipeline.py @@ -0,0 +1,1536 @@ +# Copyright 2025 Entalpic +import json +import os +import subprocess +import tempfile +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pyarrow.parquet as pq +import pytest + +from lematerial_fetcher.fetcher.lematrho.pipeline import ( + PARQUET_COLUMNS, + PARQUET_SCHEMA, + LeMatRhoDirectPipeline, + _run_bader_from_bytes, + _run_ddec6_from_bytes, + _structure_to_row, +) +from lematerial_fetcher.models.optimade import Functional, OptimadeStructure +from lematerial_fetcher.utils.config import DirectPipelineConfig + +# Minimal pymatgen Structure dict for testing +_MOCK_STRUCTURE_DICT = { + "lattice": { + "matrix": [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]], + "a": 3.0, + "b": 3.0, + "c": 3.0, + "alpha": 90.0, + "beta": 90.0, + "gamma": 90.0, + }, + "sites": [ + { + "species": [{"element": "Si", "occu": 1}], + "abc": [0.0, 0.0, 0.0], + "xyz": [0.0, 0.0, 0.0], + "label": "Si", + }, + { + "species": [{"element": "O", "occu": 1}], + "abc": [0.5, 0.5, 0.5], + "xyz": [1.5, 1.5, 1.5], + "label": "O", + }, + ], +} + + +def _make_mock_optimade_dict(): + """Create a mock OPTIMADE dict like get_optimade_from_pymatgen would return.""" + return { + "elements": ["O", "Si"], + "nelements": 2, + "elements_ratios": [0.5, 0.5], + "nsites": 2, + "cartesian_site_positions": [[0.0, 0.0, 0.0], [1.5, 1.5, 1.5]], + "species_at_sites": ["Si", "O"], + "species": [ + { + "mass": None, + "name": "O", + "attached": None, + "nattached": None, + "concentration": [1], + "original_name": None, + "chemical_symbols": ["O"], + }, + { + "mass": None, + "name": "Si", + "attached": None, + "nattached": None, + "concentration": 
[1], + "original_name": None, + "chemical_symbols": ["Si"], + }, + ], + "chemical_formula_anonymous": "AB", + "chemical_formula_descriptive": "Si1 O1", + "chemical_formula_reduced": "OSi", + "dimension_types": [1, 1, 1], + "nperiodic_dimensions": 3, + "lattice_vectors": [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]], + } + + +@pytest.fixture +def tmp_output_dir(): + """Create a temp directory for pipeline output.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield tmpdir + + +@pytest.fixture +def mock_config(tmp_output_dir): + return DirectPipelineConfig( + lematrho_bucket_name="test-bucket", + lematrho_grid_shape=(10, 10, 10), + output_dir=tmp_output_dir, + parquet_chunk_size=3, + num_workers=2, + log_every=10, + ) + + +@pytest.fixture +def no_tools(): + """Tool paths dict where no tools are available.""" + return { + "bader_path": None, + "chargemol_path": None, + "chgsum_script_path": None, + "perl_path": None, + "atomic_densities_path": None, + "can_generate_potcar": False, + "can_run_bader": False, + "can_run_ddec6": False, + } + + +# --------------------------------------------------------------------------- +# TestListMaterials +# --------------------------------------------------------------------------- + + +class TestListMaterials: + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client") + def test_filters_by_valid_prefix(self, mock_get_client, mock_config): + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + mock_paginator = MagicMock() + mock_client.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + { + "CommonPrefixes": [ + {"Prefix": "mp-123/"}, + {"Prefix": "agm000001/"}, + {"Prefix": "oqmd-456/"}, + {"Prefix": "unknown-789/"}, + {"Prefix": "test-data/"}, + ] + } + ] + + with patch.object(LeMatRhoDirectPipeline, "_validate_tools", return_value={}): + pipeline = LeMatRhoDirectPipeline(config=mock_config) + result = pipeline._list_materials() + + 
assert result == ["mp-123", "agm000001", "oqmd-456"] + + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client") + def test_excludes_processed_ids(self, mock_get_client, mock_config): + """Checkpoint filtering removes already-processed materials.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + mock_paginator = MagicMock() + mock_client.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + { + "CommonPrefixes": [ + {"Prefix": "mp-1/"}, + {"Prefix": "mp-2/"}, + {"Prefix": "mp-3/"}, + ] + } + ] + + # Write checkpoint with mp-1 already processed + checkpoint_path = os.path.join(mock_config.output_dir, ".checkpoint.txt") + with open(checkpoint_path, "w") as f: + f.write("mp-1\n") + + with patch.object(LeMatRhoDirectPipeline, "_validate_tools", return_value={}): + pipeline = LeMatRhoDirectPipeline(config=mock_config) + all_materials = pipeline._list_materials() + pipeline._processed_ids = pipeline._load_checkpoint() + remaining = [m for m in all_materials if m not in pipeline._processed_ids] + + assert remaining == ["mp-2", "mp-3"] + + +# --------------------------------------------------------------------------- +# TestProcessMaterial +# --------------------------------------------------------------------------- + + +class TestProcessMaterial: + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_optimade_from_pymatgen") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.compress_chgcar") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.parse_vasprun_structure") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.download_gz_file_from_s3") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client") + def test_happy_path_no_tools( + self, + mock_get_client, + mock_download, + mock_parse_vasprun, + mock_compress, + mock_get_optimade, + mock_config, + no_tools, + ): + """Full processing without external tools — charge analysis fields 
are None.""" + from pymatgen.core import Lattice, Structure + + mock_get_client.return_value = MagicMock() + mock_download.return_value = b"mock_bytes" + + structure = Structure( + Lattice.cubic(3.0), + ["Si", "O"], + [[0, 0, 0], [0.5, 0.5, 0.5]], + ) + mock_parse_vasprun.return_value = structure + mock_compress.return_value = [[[1.0] * 10] * 10] * 10 + mock_get_optimade.return_value = _make_mock_optimade_dict() + + result = LeMatRhoDirectPipeline._process_material( + "mp-123", mock_config, no_tools + ) + + assert result is not None + assert isinstance(result, dict) + + # Check key fields + assert result["immutable_id"] == "mp-123" + assert result["functional"] == "pbe" + assert result["cross_compatibility"] is True + assert result["nsites"] == 2 + assert result["charge_density_grid_shape"] == [10, 10, 10] + + # Compressed grids should be JSON strings + assert isinstance(result["compressed_charge_density"], str) + parsed = json.loads(result["compressed_charge_density"]) + assert isinstance(parsed, list) + + # No tools — charge analysis fields are None + assert result["bader_charges"] is None + assert result["bader_atomic_volume"] is None + assert result["ddec6_charges"] is None + + # Species should be JSON string + assert isinstance(result["species"], str) + + # All Parquet columns present + for col in PARQUET_COLUMNS: + assert col in result + + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_optimade_from_pymatgen") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.compress_chgcar") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.parse_vasprun_structure") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.download_gz_file_from_s3") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client") + def test_missing_vasprun_returns_none( + self, + mock_get_client, + mock_download, + mock_parse_vasprun, + mock_compress, + mock_get_optimade, + mock_config, + no_tools, + ): + mock_get_client.return_value = MagicMock() + 
mock_download.side_effect = Exception("NoSuchKey: vasprun.xml.gz") + + result = LeMatRhoDirectPipeline._process_material( + "mp-999", mock_config, no_tools + ) + assert result is None + + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_optimade_from_pymatgen") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.compress_chgcar") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.parse_vasprun_structure") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.download_gz_file_from_s3") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client") + def test_partial_charge_files( + self, + mock_get_client, + mock_download, + mock_parse_vasprun, + mock_compress, + mock_get_optimade, + mock_config, + no_tools, + ): + """Missing AECCAR1 — other files still processed.""" + from pymatgen.core import Lattice, Structure + + mock_get_client.return_value = MagicMock() + + def download_side_effect(client, bucket, key): + if "AECCAR1.gz" in key: + raise Exception("NoSuchKey") + return b"mock_bytes" + + mock_download.side_effect = download_side_effect + + structure = Structure( + Lattice.cubic(3.0), ["Si", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]] + ) + mock_parse_vasprun.return_value = structure + mock_compress.return_value = [[[1.0]]] + mock_get_optimade.return_value = _make_mock_optimade_dict() + + result = LeMatRhoDirectPipeline._process_material( + "mp-123", mock_config, no_tools + ) + + assert result is not None + # CHGCAR, AECCAR0, AECCAR2 should be present + assert result["compressed_charge_density"] is not None + assert result["compressed_aeccar0"] is not None + assert result["compressed_aeccar2"] is not None + # AECCAR1 should be None (failed to download) + assert result["compressed_aeccar1"] is None + + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_optimade_from_pymatgen") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.compress_chgcar") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.parse_vasprun_structure") 
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.download_gz_file_from_s3") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client") + def test_cross_compatibility_excludes_yb( + self, + mock_get_client, + mock_download, + mock_parse_vasprun, + mock_compress, + mock_get_optimade, + mock_config, + no_tools, + ): + from pymatgen.core import Lattice, Structure + + mock_get_client.return_value = MagicMock() + mock_download.return_value = b"mock_bytes" + + structure = Structure( + Lattice.cubic(3.0), ["Yb", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]] + ) + mock_parse_vasprun.return_value = structure + mock_compress.return_value = [[[1.0]]] + + optimade_dict = _make_mock_optimade_dict() + optimade_dict["elements"] = ["O", "Yb"] + optimade_dict["species_at_sites"] = ["Yb", "O"] + optimade_dict["species"] = [ + { + "mass": None, + "name": "O", + "attached": None, + "nattached": None, + "concentration": [1], + "original_name": None, + "chemical_symbols": ["O"], + }, + { + "mass": None, + "name": "Yb", + "attached": None, + "nattached": None, + "concentration": [1], + "original_name": None, + "chemical_symbols": ["Yb"], + }, + ] + mock_get_optimade.return_value = optimade_dict + + result = LeMatRhoDirectPipeline._process_material( + "mp-yb", mock_config, no_tools + ) + + assert result is not None + assert result["cross_compatibility"] is False + + @patch("lematerial_fetcher.fetcher.lematrho.pipeline._run_bader_from_bytes") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_optimade_from_pymatgen") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.compress_chgcar") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.parse_vasprun_structure") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.download_gz_file_from_s3") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client") + def test_bader_failure_still_returns_result( + self, + mock_get_client, + mock_download, + mock_parse_vasprun, + mock_compress, 
+ mock_get_optimade, + mock_bader, + mock_config, + ): + """When Bader fails, bader fields are None but row is still returned.""" + from pymatgen.core import Lattice, Structure + + mock_get_client.return_value = MagicMock() + mock_download.return_value = b"mock_bytes" + + structure = Structure( + Lattice.cubic(3.0), ["Si", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]] + ) + mock_parse_vasprun.return_value = structure + mock_compress.return_value = [[[1.0]]] + mock_get_optimade.return_value = _make_mock_optimade_dict() + mock_bader.return_value = (None, None) + + tools = { + "bader_path": "/usr/bin/bader", + "chargemol_path": None, + "chgsum_script_path": "/opt/chgsum.pl", + "perl_path": "/usr/bin/perl", + "atomic_densities_path": None, + "can_generate_potcar": True, + "can_run_bader": True, + "can_run_ddec6": False, + } + + result = LeMatRhoDirectPipeline._process_material( + "mp-123", mock_config, tools + ) + + assert result is not None + assert result["bader_charges"] is None + assert result["bader_atomic_volume"] is None + mock_bader.assert_called_once() + + +# --------------------------------------------------------------------------- +# TestCheckpointing +# --------------------------------------------------------------------------- + + +class TestCheckpointing: + def test_load_empty_checkpoint(self, mock_config): + """No checkpoint file -> empty set.""" + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value={} + ): + pipeline = LeMatRhoDirectPipeline(config=mock_config) + ids = pipeline._load_checkpoint() + + assert ids == set() + + def test_load_existing_checkpoint(self, mock_config): + """Checkpoint file with IDs -> returns set of those IDs.""" + checkpoint_path = os.path.join(mock_config.output_dir, ".checkpoint.txt") + with open(checkpoint_path, "w") as f: + f.write("mp-1\nmp-2\nagm000001\n") + + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value={} + ): + pipeline = LeMatRhoDirectPipeline(config=mock_config) + ids = 
pipeline._load_checkpoint() + + assert ids == {"mp-1", "mp-2", "agm000001"} + + def test_append_checkpoint(self, mock_config): + """Appending to checkpoint writes ID and persists.""" + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value={} + ): + pipeline = LeMatRhoDirectPipeline(config=mock_config) + pipeline._append_checkpoint("mp-100") + pipeline._append_checkpoint("mp-200") + + checkpoint_path = os.path.join(mock_config.output_dir, ".checkpoint.txt") + with open(checkpoint_path, "r") as f: + lines = [line.strip() for line in f if line.strip()] + + assert lines == ["mp-100", "mp-200"] + + def test_checkpoint_skips_blank_lines(self, mock_config): + """Blank lines in checkpoint file are ignored.""" + checkpoint_path = os.path.join(mock_config.output_dir, ".checkpoint.txt") + with open(checkpoint_path, "w") as f: + f.write("mp-1\n\n\nmp-2\n\n") + + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value={} + ): + pipeline = LeMatRhoDirectPipeline(config=mock_config) + ids = pipeline._load_checkpoint() + + assert ids == {"mp-1", "mp-2"} + + def test_batch_checkpoint(self, mock_config): + """Batch checkpoint writes multiple IDs atomically.""" + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value={} + ): + pipeline = LeMatRhoDirectPipeline(config=mock_config) + pipeline._batch_checkpoint(["mp-1", "mp-2", "mp-3"]) + + checkpoint_path = os.path.join(mock_config.output_dir, ".checkpoint.txt") + with open(checkpoint_path, "r") as f: + lines = [line.strip() for line in f if line.strip()] + + assert lines == ["mp-1", "mp-2", "mp-3"] + + +# --------------------------------------------------------------------------- +# TestFailureTracking +# --------------------------------------------------------------------------- + + +class TestFailureTracking: + def test_load_empty_failures(self, mock_config): + """No failures file -> empty set.""" + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", 
return_value={} + ): + pipeline = LeMatRhoDirectPipeline(config=mock_config) + ids = pipeline._load_failures() + + assert ids == set() + + def test_load_existing_failures(self, mock_config): + """Failures file with IDs -> returns set.""" + failures_path = os.path.join(mock_config.output_dir, ".failures.txt") + with open(failures_path, "w") as f: + f.write("mp-bad1\nmp-bad2\n") + + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value={} + ): + pipeline = LeMatRhoDirectPipeline(config=mock_config) + ids = pipeline._load_failures() + + assert ids == {"mp-bad1", "mp-bad2"} + + def test_append_failure(self, mock_config): + """Appending failure records ID on disk.""" + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value={} + ): + pipeline = LeMatRhoDirectPipeline(config=mock_config) + pipeline._append_failure("mp-fail1") + pipeline._append_failure("mp-fail2") + + failures_path = os.path.join(mock_config.output_dir, ".failures.txt") + with open(failures_path, "r") as f: + lines = [line.strip() for line in f if line.strip()] + + assert lines == ["mp-fail1", "mp-fail2"] + + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client") + def test_resume_skips_failures(self, mock_get_client, mock_config, no_tools): + """Pipeline skips previously failed materials on resume.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + mock_paginator = MagicMock() + mock_client.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + { + "CommonPrefixes": [ + {"Prefix": "mp-0/"}, + {"Prefix": "mp-1/"}, + {"Prefix": "mp-2/"}, + ] + } + ] + + # mp-0 already processed, mp-1 previously failed + checkpoint_path = os.path.join(mock_config.output_dir, ".checkpoint.txt") + with open(checkpoint_path, "w") as f: + f.write("mp-0\n") + failures_path = os.path.join(mock_config.output_dir, ".failures.txt") + with open(failures_path, "w") as f: + f.write("mp-1\n") + + row = 
{col: None for col in PARQUET_COLUMNS} + row.update( + { + "elements": ["Si"], + "nsites": 1, + "chemical_formula_anonymous": "A", + "chemical_formula_reduced": "Si", + "chemical_formula_descriptive": "Si1", + "nelements": 1, + "dimension_types": [1, 1, 1], + "nperiodic_dimensions": 3, + "lattice_vectors": [[3, 0, 0], [0, 3, 0], [0, 0, 3]], + "immutable_id": "mp-2", + "cartesian_site_positions": [[0, 0, 0]], + "species": json.dumps([{"name": "Si"}]), + "species_at_sites": ["Si"], + "last_modified": datetime.now().isoformat(), + "elements_ratios": [1.0], + "functional": "pbe", + "cross_compatibility": True, + } + ) + + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value=no_tools + ): + pipeline = LeMatRhoDirectPipeline(config=mock_config, debug=True) + + processed_ids = [] + + def mock_process(material_id, config, tool_paths): + processed_ids.append(material_id) + r = dict(row) + r["immutable_id"] = material_id + return r + + with patch.object( + LeMatRhoDirectPipeline, "_process_material", side_effect=mock_process + ): + pipeline.run() + + # Only mp-2 should be processed (mp-0 checkpointed, mp-1 failed) + assert processed_ids == ["mp-2"] + + +# --------------------------------------------------------------------------- +# TestParquetWriting +# --------------------------------------------------------------------------- + + +class TestParquetWriting: + def _make_row(self, material_id="mp-1"): + """Create a minimal valid row dict matching PARQUET_COLUMNS.""" + row = {col: None for col in PARQUET_COLUMNS} + row.update( + { + "elements": ["O", "Si"], + "nsites": 2, + "chemical_formula_anonymous": "AB", + "chemical_formula_reduced": "OSi", + "chemical_formula_descriptive": "Si1 O1", + "nelements": 2, + "dimension_types": [1, 1, 1], + "nperiodic_dimensions": 3, + "lattice_vectors": [[3.0, 0, 0], [0, 3.0, 0], [0, 0, 3.0]], + "immutable_id": material_id, + "cartesian_site_positions": [[0, 0, 0], [1.5, 1.5, 1.5]], + "species": json.dumps([{"name": 
"O"}, {"name": "Si"}]), + "species_at_sites": ["Si", "O"], + "last_modified": datetime.now().isoformat(), + "elements_ratios": [0.5, 0.5], + "functional": "pbe", + "cross_compatibility": True, + "charge_density_grid_shape": [10, 10, 10], + "compressed_charge_density": json.dumps([[[1.0]]]), + } + ) + return row + + def test_write_chunk(self, mock_config): + """Verify Parquet file is created with correct schema.""" + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value={} + ): + pipeline = LeMatRhoDirectPipeline(config=mock_config) + rows = [self._make_row(f"mp-{i}") for i in range(3)] + pipeline._write_parquet_chunk(rows, 0) + + path = os.path.join(mock_config.output_dir, "chunk_000000.parquet") + assert os.path.exists(path) + + table = pq.read_table(path) + assert table.num_rows == 3 + assert set(table.column_names) == set(PARQUET_COLUMNS) + + def test_atomic_write_no_tmp_file_remains(self, mock_config): + """After writing, no .tmp file should remain.""" + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value={} + ): + pipeline = LeMatRhoDirectPipeline(config=mock_config) + rows = [self._make_row()] + pipeline._write_parquet_chunk(rows, 0) + + tmp_files = [ + f + for f in os.listdir(mock_config.output_dir) + if f.endswith(".tmp") + ] + assert len(tmp_files) == 0 + + def test_chunk_index_resume(self, mock_config): + """Next chunk index should be max existing + 1.""" + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value={} + ): + pipeline = LeMatRhoDirectPipeline(config=mock_config) + + # Write chunks 0, 1, 2 + for i in range(3): + rows = [self._make_row(f"mp-{i}")] + pipeline._write_parquet_chunk(rows, i) + + assert pipeline._get_next_chunk_index() == 3 + + def test_tmp_files_ignored_on_resume(self, mock_config): + """Stale .tmp files don't affect chunk indexing.""" + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value={} + ): + pipeline = 
LeMatRhoDirectPipeline(config=mock_config) + + # Write one real chunk + rows = [self._make_row()] + pipeline._write_parquet_chunk(rows, 0) + + # Create a stale .tmp file + tmp_path = os.path.join( + mock_config.output_dir, "chunk_000001.parquet.tmp" + ) + with open(tmp_path, "w") as f: + f.write("stale") + + assert pipeline._get_next_chunk_index() == 1 + + def test_chunk_index_empty_dir(self, mock_config): + """Empty output dir -> chunk index 0.""" + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value={} + ): + pipeline = LeMatRhoDirectPipeline(config=mock_config) + assert pipeline._get_next_chunk_index() == 0 + + +# --------------------------------------------------------------------------- +# TestStructureToRow +# --------------------------------------------------------------------------- + + +class TestStructureToRow: + def test_all_columns_present(self): + """Row dict should have exactly the PARQUET_COLUMNS keys.""" + optimade_dict = _make_mock_optimade_dict() + structure = OptimadeStructure( + id="mp-1", + source="lematrho", + immutable_id="mp-1", + last_modified=datetime.now(), + **optimade_dict, + functional=Functional.PBE, + cross_compatibility=True, + compute_space_group=True, + compute_bawl_hash=True, + ) + + row = _structure_to_row(structure) + assert set(row.keys()) == set(PARQUET_COLUMNS) + + def test_species_json_serialized(self): + """Species field should be a JSON string.""" + optimade_dict = _make_mock_optimade_dict() + structure = OptimadeStructure( + id="mp-1", + source="lematrho", + immutable_id="mp-1", + last_modified=datetime.now(), + **optimade_dict, + functional=Functional.PBE, + cross_compatibility=True, + compute_space_group=True, + compute_bawl_hash=True, + ) + + row = _structure_to_row(structure) + assert isinstance(row["species"], str) + parsed = json.loads(row["species"]) + assert isinstance(parsed, list) + + def test_functional_is_string(self): + """Functional enum should be converted to string value.""" + 
optimade_dict = _make_mock_optimade_dict() + structure = OptimadeStructure( + id="mp-1", + source="lematrho", + immutable_id="mp-1", + last_modified=datetime.now(), + **optimade_dict, + functional=Functional.PBE, + cross_compatibility=True, + compute_space_group=True, + compute_bawl_hash=True, + ) + + row = _structure_to_row(structure) + assert row["functional"] == "pbe" + + def test_last_modified_is_iso_string(self): + """last_modified should be an ISO format string.""" + optimade_dict = _make_mock_optimade_dict() + now = datetime.now() + structure = OptimadeStructure( + id="mp-1", + source="lematrho", + immutable_id="mp-1", + last_modified=now, + **optimade_dict, + functional=Functional.PBE, + cross_compatibility=True, + compute_space_group=True, + compute_bawl_hash=True, + ) + + row = _structure_to_row(structure) + # Model validator strips time to date-only (YYYY-MM-DD -> YYYY-MM-DDT00:00:00) + assert row["last_modified"] == structure.last_modified.isoformat() + + def test_compressed_fields_json_serialized(self): + """Compressed charge density fields should be JSON strings when present.""" + optimade_dict = _make_mock_optimade_dict() + grid = [[[1.0, 2.0], [3.0, 4.0]]] + structure = OptimadeStructure( + id="mp-1", + source="lematrho", + immutable_id="mp-1", + last_modified=datetime.now(), + **optimade_dict, + functional=Functional.PBE, + cross_compatibility=True, + compressed_charge_density=grid, + charge_density_grid_shape=[1, 2, 2], + compute_space_group=True, + compute_bawl_hash=True, + ) + + row = _structure_to_row(structure) + assert isinstance(row["compressed_charge_density"], str) + assert json.loads(row["compressed_charge_density"]) == grid + + +# --------------------------------------------------------------------------- +# TestRunIntegration +# --------------------------------------------------------------------------- + + +class TestRunIntegration: + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client") + def 
test_full_pipeline_debug_mode( + self, mock_get_client, mock_config, no_tools + ): + """Integration test: process 5 materials in debug mode, verify chunks + checkpoint.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + # Mock S3 listing + mock_paginator = MagicMock() + mock_client.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + { + "CommonPrefixes": [ + {"Prefix": f"mp-{i}/"} + for i in range(5) + ] + } + ] + + # Create mock results + mock_results = [] + for i in range(5): + row = {col: None for col in PARQUET_COLUMNS} + row.update( + { + "elements": ["Si"], + "nsites": 1, + "chemical_formula_anonymous": "A", + "chemical_formula_reduced": "Si", + "chemical_formula_descriptive": "Si1", + "nelements": 1, + "dimension_types": [1, 1, 1], + "nperiodic_dimensions": 3, + "lattice_vectors": [[3, 0, 0], [0, 3, 0], [0, 0, 3]], + "immutable_id": f"mp-{i}", + "cartesian_site_positions": [[0, 0, 0]], + "species": json.dumps([{"name": "Si"}]), + "species_at_sites": ["Si"], + "last_modified": datetime.now().isoformat(), + "elements_ratios": [1.0], + "functional": "pbe", + "cross_compatibility": True, + "charge_density_grid_shape": [10, 10, 10], + } + ) + mock_results.append(row) + + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value=no_tools + ): + pipeline = LeMatRhoDirectPipeline(config=mock_config, debug=True) + + # Mock _process_material to return our pre-built rows + call_count = [0] + + def mock_process(material_id, config, tool_paths): + idx = call_count[0] + call_count[0] += 1 + return mock_results[idx] + + with patch.object( + LeMatRhoDirectPipeline, "_process_material", side_effect=mock_process + ): + pipeline.run() + + # With chunk_size=3 and 5 materials: should write 2 chunks (3 + 2) + parquet_files = sorted( + f + for f in os.listdir(mock_config.output_dir) + if f.endswith(".parquet") + ) + assert len(parquet_files) == 2 + + # Check row counts + total_rows = 0 + for f in 
parquet_files: + table = pq.read_table(os.path.join(mock_config.output_dir, f)) + total_rows += table.num_rows + assert total_rows == 5 + + # Check checkpoint + checkpoint_path = os.path.join(mock_config.output_dir, ".checkpoint.txt") + with open(checkpoint_path) as f: + checkpoint_ids = {line.strip() for line in f if line.strip()} + assert checkpoint_ids == {f"mp-{i}" for i in range(5)} + + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client") + def test_resume_skips_processed(self, mock_get_client, mock_config, no_tools): + """Pipeline resumes from checkpoint, skipping already-processed materials.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + mock_paginator = MagicMock() + mock_client.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + { + "CommonPrefixes": [ + {"Prefix": "mp-0/"}, + {"Prefix": "mp-1/"}, + {"Prefix": "mp-2/"}, + ] + } + ] + + # Pre-existing checkpoint (mp-0 already done) + checkpoint_path = os.path.join(mock_config.output_dir, ".checkpoint.txt") + with open(checkpoint_path, "w") as f: + f.write("mp-0\n") + + # Pre-existing Parquet chunk + row = {col: None for col in PARQUET_COLUMNS} + row.update( + { + "elements": ["Si"], + "nsites": 1, + "chemical_formula_anonymous": "A", + "chemical_formula_reduced": "Si", + "chemical_formula_descriptive": "Si1", + "nelements": 1, + "dimension_types": [1, 1, 1], + "nperiodic_dimensions": 3, + "lattice_vectors": [[3, 0, 0], [0, 3, 0], [0, 0, 3]], + "immutable_id": "mp-0", + "cartesian_site_positions": [[0, 0, 0]], + "species": json.dumps([{"name": "Si"}]), + "species_at_sites": ["Si"], + "last_modified": datetime.now().isoformat(), + "elements_ratios": [1.0], + "functional": "pbe", + "cross_compatibility": True, + } + ) + + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value=no_tools + ): + pipeline = LeMatRhoDirectPipeline(config=mock_config, debug=True) + + processed_ids = [] + + def 
mock_process(material_id, config, tool_paths): + processed_ids.append(material_id) + r = dict(row) + r["immutable_id"] = material_id + return r + + with patch.object( + LeMatRhoDirectPipeline, "_process_material", side_effect=mock_process + ): + pipeline.run() + + # Only mp-1 and mp-2 should be processed (mp-0 skipped) + assert "mp-0" not in processed_ids + assert set(processed_ids) == {"mp-1", "mp-2"} + + +# --------------------------------------------------------------------------- +# TestBaderFromBytes +# --------------------------------------------------------------------------- + + +class TestBaderFromBytes: + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.write_potcar") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.subprocess.run") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.read_potcar_zval") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.parse_acf_dat") + def test_happy_path( + self, mock_parse_acf, mock_read_zval, mock_subprocess, mock_potcar + ): + from pymatgen.core import Lattice, Structure + + structure = Structure( + Lattice.cubic(3.0), ["Si", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]] + ) + raw_files = { + "CHGCAR": b"chgcar_data", + "AECCAR0": b"aeccar0_data", + "AECCAR2": b"aeccar2_data", + } + tools = { + "bader_path": "/usr/bin/bader", + "perl_path": "/usr/bin/perl", + "chgsum_script_path": "/opt/chgsum.pl", + } + + mock_parse_acf.return_value = ([4.0, 6.0], [10.0, 12.0]) + mock_read_zval.return_value = {"Si": 4.0, "O": 6.0} + + charges, volumes = _run_bader_from_bytes( + structure, raw_files, tools, "mp-test" + ) + + assert charges == [0.0, 0.0] # valence - electron_count + assert volumes == [10.0, 12.0] + assert mock_subprocess.call_count == 2 # chgsum + bader + + def test_subprocess_timeout(self): + from pymatgen.core import Lattice, Structure + + structure = Structure( + Lattice.cubic(3.0), ["Si"], [[0, 0, 0]] + ) + raw_files = {"CHGCAR": b"x", "AECCAR0": b"x", "AECCAR2": b"x"} + tools = { + "bader_path": 
"/usr/bin/bader", + "perl_path": "/usr/bin/perl", + "chgsum_script_path": "/opt/chgsum.pl", + } + + with patch( + "lematerial_fetcher.fetcher.lematrho.pipeline.write_potcar" + ): + with patch( + "lematerial_fetcher.fetcher.lematrho.pipeline.subprocess.run", + side_effect=subprocess.TimeoutExpired("bader", 600), + ): + charges, volumes = _run_bader_from_bytes( + structure, raw_files, tools, "mp-test" + ) + + assert charges is None + assert volumes is None + + +# --------------------------------------------------------------------------- +# TestDdec6FromBytes +# --------------------------------------------------------------------------- + + +class TestDdec6FromBytes: + def test_subprocess_failure(self): + from pymatgen.core import Lattice, Structure + + structure = Structure( + Lattice.cubic(3.0), ["Si"], [[0, 0, 0]] + ) + raw_files = {"CHGCAR": b"x"} + tools = { + "chargemol_path": "/usr/bin/chargemol", + "atomic_densities_path": "/opt/densities", + } + + with patch( + "lematerial_fetcher.fetcher.lematrho.pipeline.write_potcar" + ): + with patch( + "lematerial_fetcher.fetcher.lematrho.pipeline.subprocess.run", + side_effect=subprocess.CalledProcessError(1, "chargemol"), + ): + result = _run_ddec6_from_bytes( + structure, raw_files, tools, "mp-test" + ) + + assert result is None + + def test_happy_path(self): + """DDEC6 returns charges when subprocess succeeds.""" + from pymatgen.core import Lattice, Structure + + structure = Structure( + Lattice.cubic(3.0), ["Si", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]] + ) + raw_files = {"CHGCAR": b"chgcar_data"} + tools = { + "chargemol_path": "/usr/bin/chargemol", + "atomic_densities_path": "/opt/densities", + } + + with patch( + "lematerial_fetcher.fetcher.lematrho.pipeline.write_potcar" + ): + with patch( + "lematerial_fetcher.fetcher.lematrho.pipeline.subprocess.run" + ): + with patch( + "lematerial_fetcher.fetcher.lematrho.pipeline.parse_ddec6_charges", + return_value=[0.5, -0.5], + ): + result = _run_ddec6_from_bytes( + 
structure, raw_files, tools, "mp-test" + ) + + assert result == [0.5, -0.5] + + def test_timeout(self): + """DDEC6 returns None on timeout.""" + from pymatgen.core import Lattice, Structure + + structure = Structure( + Lattice.cubic(3.0), ["Si"], [[0, 0, 0]] + ) + raw_files = {"CHGCAR": b"x"} + tools = { + "chargemol_path": "/usr/bin/chargemol", + "atomic_densities_path": "/opt/densities", + } + + with patch( + "lematerial_fetcher.fetcher.lematrho.pipeline.write_potcar" + ): + with patch( + "lematerial_fetcher.fetcher.lematrho.pipeline.subprocess.run", + side_effect=subprocess.TimeoutExpired("chargemol", 600), + ): + result = _run_ddec6_from_bytes( + structure, raw_files, tools, "mp-test" + ) + + assert result is None + + +# --------------------------------------------------------------------------- +# TestValidateTools +# --------------------------------------------------------------------------- + + +class TestValidateTools: + def test_all_tools_available(self, mock_config): + """All tools present -> can_run_bader and can_run_ddec6 are True.""" + with patch("shutil.which", side_effect=lambda x: f"/usr/bin/{x}"): + with patch.dict(os.environ, {"PMG_VASP_PSP_DIR": "/opt/psp"}): + config = DirectPipelineConfig( + lematrho_bucket_name="test-bucket", + output_dir=mock_config.output_dir, + bader_path="/usr/bin/bader", + chargemol_path="/usr/bin/chargemol", + chgsum_script_path=__file__, # use this test file as a file that exists + atomic_densities_path=os.path.dirname(__file__), # dir that exists + ) + pipeline = LeMatRhoDirectPipeline(config=config) + + assert pipeline._tool_paths["can_run_bader"] is True + assert pipeline._tool_paths["can_run_ddec6"] is True + assert pipeline._tool_paths["can_generate_potcar"] is True + + def test_no_tools_available(self, mock_config): + """No tools on PATH -> can_run_bader and can_run_ddec6 are False.""" + with patch("shutil.which", return_value=None): + with patch.dict(os.environ, {}, clear=True): + config = DirectPipelineConfig( + 
lematrho_bucket_name="test-bucket", + output_dir=mock_config.output_dir, + ) + pipeline = LeMatRhoDirectPipeline(config=config) + + assert pipeline._tool_paths["can_run_bader"] is False + assert pipeline._tool_paths["can_run_ddec6"] is False + assert pipeline._tool_paths["can_generate_potcar"] is False + + def test_bader_but_no_chgsum(self, mock_config): + """Bader on PATH but chgsum not set -> can_run_bader False.""" + with patch("shutil.which", side_effect=lambda x: f"/usr/bin/{x}"): + with patch.dict(os.environ, {"PMG_VASP_PSP_DIR": "/opt/psp"}): + config = DirectPipelineConfig( + lematrho_bucket_name="test-bucket", + output_dir=mock_config.output_dir, + bader_path="/usr/bin/bader", + # no chgsum_script_path + ) + pipeline = LeMatRhoDirectPipeline(config=config) + + assert pipeline._tool_paths["can_run_bader"] is False + assert pipeline._tool_paths["bader_path"] == "/usr/bin/bader" + + def test_missing_pmg_vasp_psp_dir(self, mock_config): + """No PMG_VASP_PSP_DIR -> can_generate_potcar False, both analyses disabled.""" + with patch("shutil.which", side_effect=lambda x: f"/usr/bin/{x}"): + env = os.environ.copy() + env.pop("PMG_VASP_PSP_DIR", None) + with patch.dict(os.environ, env, clear=True): + config = DirectPipelineConfig( + lematrho_bucket_name="test-bucket", + output_dir=mock_config.output_dir, + bader_path="/usr/bin/bader", + chargemol_path="/usr/bin/chargemol", + chgsum_script_path=__file__, + atomic_densities_path=os.path.dirname(__file__), + ) + pipeline = LeMatRhoDirectPipeline(config=config) + + assert pipeline._tool_paths["can_generate_potcar"] is False + assert pipeline._tool_paths["can_run_bader"] is False + assert pipeline._tool_paths["can_run_ddec6"] is False + + +# --------------------------------------------------------------------------- +# TestStructureToRowNoneFields +# --------------------------------------------------------------------------- + + +class TestStructureToRowNoneFields: + def test_all_charge_fields_none(self): + """Structure 
with no charge density fields -> all charge columns None.""" + optimade_dict = _make_mock_optimade_dict() + structure = OptimadeStructure( + id="mp-1", + source="lematrho", + immutable_id="mp-1", + last_modified=datetime.now(), + **optimade_dict, + functional=Functional.PBE, + cross_compatibility=True, + compute_space_group=True, + compute_bawl_hash=True, + ) + + row = _structure_to_row(structure) + assert row["compressed_charge_density"] is None + assert row["compressed_aeccar0"] is None + assert row["compressed_aeccar1"] is None + assert row["compressed_aeccar2"] is None + assert row["charge_density_grid_shape"] is None + assert row["bader_charges"] is None + assert row["bader_atomic_volume"] is None + assert row["ddec6_charges"] is None + + def test_partial_charge_fields(self): + """Structure with only some charge fields -> only those are populated.""" + optimade_dict = _make_mock_optimade_dict() + structure = OptimadeStructure( + id="mp-1", + source="lematrho", + immutable_id="mp-1", + last_modified=datetime.now(), + **optimade_dict, + functional=Functional.PBE, + cross_compatibility=True, + compressed_charge_density=[[[1.0]]], + charge_density_grid_shape=[1, 1, 1], + bader_charges=[0.1, -0.1], + compute_space_group=True, + compute_bawl_hash=True, + ) + + row = _structure_to_row(structure) + assert row["compressed_charge_density"] is not None + assert row["compressed_aeccar0"] is None + assert row["bader_charges"] == [0.1, -0.1] + assert row["ddec6_charges"] is None + + +# --------------------------------------------------------------------------- +# TestPushToHuggingface +# --------------------------------------------------------------------------- + + +class TestPushToHuggingface: + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client") + def test_push_called_when_configured(self, mock_get_client, mock_config, no_tools): + """Pipeline calls push_to_hub when hf_repo_id is configured.""" + mock_client = MagicMock() + 
mock_get_client.return_value = mock_client + mock_paginator = MagicMock() + mock_client.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + {"CommonPrefixes": [{"Prefix": "mp-1/"}]} + ] + + config = DirectPipelineConfig( + lematrho_bucket_name="test-bucket", + output_dir=mock_config.output_dir, + hf_repo_id="test-org/test-dataset", + hf_token="hf_test_token", + ) + + mock_row = {col: None for col in PARQUET_COLUMNS} + mock_row.update({ + "elements": ["Si"], + "nsites": 1, + "chemical_formula_anonymous": "A", + "chemical_formula_reduced": "Si", + "chemical_formula_descriptive": "Si1", + "nelements": 1, + "dimension_types": [1, 1, 1], + "nperiodic_dimensions": 3, + "lattice_vectors": [[3, 0, 0], [0, 3, 0], [0, 0, 3]], + "immutable_id": "mp-1", + "cartesian_site_positions": [[0, 0, 0]], + "species": json.dumps([{"name": "Si"}]), + "species_at_sites": ["Si"], + "last_modified": datetime.now().isoformat(), + "elements_ratios": [1.0], + "functional": "pbe", + "cross_compatibility": True, + }) + + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value=no_tools + ): + pipeline = LeMatRhoDirectPipeline(config=config, debug=True) + + mock_dataset = MagicMock() + with patch.object( + LeMatRhoDirectPipeline, "_process_material", return_value=mock_row + ): + with patch( + "datasets.load_dataset", + return_value={"train": mock_dataset}, + ): + pipeline.run() + + mock_dataset.push_to_hub.assert_called_once_with( + "test-org/test-dataset", + token="hf_test_token", + private=True, + ) + + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client") + def test_push_not_called_without_repo_id( + self, mock_get_client, mock_config, no_tools + ): + """Pipeline skips push when hf_repo_id is None.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + mock_paginator = MagicMock() + mock_client.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + 
{"CommonPrefixes": [{"Prefix": "mp-1/"}]} + ] + + mock_row = {col: None for col in PARQUET_COLUMNS} + mock_row.update({ + "elements": ["Si"], + "nsites": 1, + "chemical_formula_anonymous": "A", + "chemical_formula_reduced": "Si", + "chemical_formula_descriptive": "Si1", + "nelements": 1, + "dimension_types": [1, 1, 1], + "nperiodic_dimensions": 3, + "lattice_vectors": [[3, 0, 0], [0, 3, 0], [0, 0, 3]], + "immutable_id": "mp-1", + "cartesian_site_positions": [[0, 0, 0]], + "species": json.dumps([{"name": "Si"}]), + "species_at_sites": ["Si"], + "last_modified": datetime.now().isoformat(), + "elements_ratios": [1.0], + "functional": "pbe", + "cross_compatibility": True, + }) + + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value=no_tools + ): + pipeline = LeMatRhoDirectPipeline(config=mock_config, debug=True) + + with patch.object( + LeMatRhoDirectPipeline, "_process_material", return_value=mock_row + ): + with patch.object( + LeMatRhoDirectPipeline, "_push_to_huggingface" + ) as mock_push: + pipeline.run() + + mock_push.assert_not_called() + + +# --------------------------------------------------------------------------- +# TestProcessMaterialWithDdec6 +# --------------------------------------------------------------------------- + + +class TestProcessMaterialWithDdec6: + @patch("lematerial_fetcher.fetcher.lematrho.pipeline._run_ddec6_from_bytes") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_optimade_from_pymatgen") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.compress_chgcar") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.parse_vasprun_structure") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.download_gz_file_from_s3") + @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client") + def test_ddec6_populates_charges( + self, + mock_get_client, + mock_download, + mock_parse_vasprun, + mock_compress, + mock_get_optimade, + mock_ddec6, + mock_config, + ): + """When DDEC6 tools 
available and succeed, ddec6_charges is populated.""" + from pymatgen.core import Lattice, Structure + + mock_get_client.return_value = MagicMock() + mock_download.return_value = b"mock_bytes" + + structure = Structure( + Lattice.cubic(3.0), ["Si", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]] + ) + mock_parse_vasprun.return_value = structure + mock_compress.return_value = [[[1.0] * 10] * 10] * 10 + mock_get_optimade.return_value = _make_mock_optimade_dict() + mock_ddec6.return_value = [0.3, -0.3] + + tools = { + "bader_path": None, + "chargemol_path": "/usr/bin/chargemol", + "chgsum_script_path": None, + "perl_path": None, + "atomic_densities_path": "/opt/densities", + "can_generate_potcar": True, + "can_run_bader": False, + "can_run_ddec6": True, + } + + result = LeMatRhoDirectPipeline._process_material( + "mp-123", mock_config, tools + ) + + assert result is not None + assert result["ddec6_charges"] == [0.3, -0.3] + assert result["bader_charges"] is None + mock_ddec6.assert_called_once() + + +# --------------------------------------------------------------------------- +# TestIntegrationS3 (requires credentials, skipped in normal runs) +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +class TestIntegrationS3: + """Integration tests that pull real data from the LeMatRho S3 bucket. + + To run these tests: + 1. Create a .env.integration file with AWS credentials: + AWS_ACCESS_KEY_ID=... + AWS_SECRET_ACCESS_KEY=... + AWS_DEFAULT_REGION=us-east-1 + 2. 
Run: pytest -m integration tests/fetcher/lematrho/test_lematrho_pipeline.py + """ + + @pytest.fixture(autouse=True) + def _load_integration_env(self): + """Load .env.integration if available, skip otherwise.""" + env_path = os.path.join( + os.path.dirname(__file__), "..", "..", "..", ".env.integration" + ) + env_path = os.path.normpath(env_path) + if not os.path.exists(env_path): + pytest.skip( + ".env.integration not found — set AWS credentials to run integration tests" + ) + from dotenv import load_dotenv + + load_dotenv(env_path, override=True) + + def test_list_materials_from_real_bucket(self): + """Verify we can list at least 1 material from the real S3 bucket.""" + config = DirectPipelineConfig( + lematrho_bucket_name="lemat-rho", + output_dir=tempfile.mkdtemp(), + ) + with patch.object(LeMatRhoDirectPipeline, "_validate_tools", return_value={}): + pipeline = LeMatRhoDirectPipeline(config=config) + materials = pipeline._list_materials() + assert len(materials) > 0 + # All should start with valid prefixes + for m in materials[:10]: + assert m.startswith(("oqmd-", "mp-", "agm")) + + def test_process_single_material(self): + """Fetch and process a single real material end-to-end (no Bader/DDEC6).""" + output_dir = tempfile.mkdtemp() + config = DirectPipelineConfig( + lematrho_bucket_name="lemat-rho", + lematrho_grid_shape=(10, 10, 10), + output_dir=output_dir, + ) + no_tools = { + "bader_path": None, + "chargemol_path": None, + "chgsum_script_path": None, + "perl_path": None, + "atomic_densities_path": None, + "can_generate_potcar": False, + "can_run_bader": False, + "can_run_ddec6": False, + } + with patch.object( + LeMatRhoDirectPipeline, "_validate_tools", return_value=no_tools + ): + pipeline = LeMatRhoDirectPipeline(config=config) + + materials = pipeline._list_materials() + assert len(materials) > 0 + material_id = materials[0] + + result = LeMatRhoDirectPipeline._process_material( + material_id, config, no_tools + ) + assert result is not None + assert 
result["immutable_id"] == material_id + assert result["functional"] == "pbe" + for col in PARQUET_COLUMNS: + assert col in result diff --git a/tests/fetcher/lematrho/test_lematrho_transform.py b/tests/fetcher/lematrho/test_lematrho_transform.py new file mode 100644 index 0000000..eb3b460 --- /dev/null +++ b/tests/fetcher/lematrho/test_lematrho_transform.py @@ -0,0 +1,728 @@ +# Copyright 2025 Entalpic +import datetime +import os +import subprocess +import tempfile +from unittest.mock import MagicMock, patch + +import pytest +from pymatgen.core import Lattice, Structure + +from lematerial_fetcher.fetcher.lematrho.transform import ( + LeMatRhoTransformer, + get_cross_compatibility, + parse_acf_dat, + parse_ddec6_charges, + read_potcar_zval, +) +from lematerial_fetcher.models.models import RawStructure +from lematerial_fetcher.models.optimade import Functional +from lematerial_fetcher.utils.config import TransformerConfig + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def si_structure(): + """A simple Si diamond structure.""" + return Structure( + Lattice.cubic(5.43), + ["Si", "Si"], + [[0, 0, 0], [0.25, 0.25, 0.25]], + ) + + +@pytest.fixture +def yb_structure(): + """A structure containing Yb (not cross-compatible).""" + return Structure( + Lattice.cubic(5.0), + ["Yb", "O"], + [[0, 0, 0], [0.5, 0.5, 0.5]], + ) + + +@pytest.fixture +def raw_structure(si_structure): + """Raw structure from the fetch step with charge density data.""" + return RawStructure( + id="agm000001", + type="lematrho", + attributes={ + "structure": si_structure.as_dict(), + "compressed_charge_density": [[[1.0, 2.0]]], + "compressed_aeccar0": [[[0.1]]], + "compressed_aeccar1": [[[0.01]]], + "compressed_aeccar2": [[[0.3]]], + "grid_shape": [15, 15, 15], + "s3_prefix": "agm000001", + }, + last_modified=datetime.datetime(2025, 1, 1), + ) + + 
+@pytest.fixture +def raw_structure_yb(yb_structure): + """Raw structure containing Yb.""" + return RawStructure( + id="agm000002", + type="lematrho", + attributes={ + "structure": yb_structure.as_dict(), + "compressed_charge_density": [[[1.0]]], + "compressed_aeccar0": None, + "compressed_aeccar1": None, + "compressed_aeccar2": None, + "grid_shape": [15, 15, 15], + "s3_prefix": "agm000002", + }, + last_modified=datetime.datetime(2025, 1, 1), + ) + + +@pytest.fixture +def mock_config(): + """TransformerConfig for testing.""" + return TransformerConfig( + source_db_conn_str="mock://source", + dest_db_conn_str="mock://dest", + source_table_name="test_source", + dest_table_name="test_dest", + batch_size=100, + page_offset=0, + log_every=100, + log_dir="./logs", + max_retries=3, + page_limit=10, + num_workers=1, + retry_delay=2, + lematrho_bucket_name="lemat-rho", + bader_path="/usr/bin/bader", + chargemol_path="/usr/bin/chargemol", + chgsum_script_path="/usr/bin/chgsum.pl", + atomic_densities_path="/path/to/atomic_densities", + ) + + +@pytest.fixture +def transformer_with_tools(mock_config): + """Transformer with all external tools available (mocked).""" + with patch.object(LeMatRhoTransformer, "_validate_tools"): + transformer = LeMatRhoTransformer(config=mock_config, debug=True) + transformer._bader_path = "/usr/bin/bader" + transformer._chargemol_path = "/usr/bin/chargemol" + transformer._chgsum_script_path = "/usr/bin/chgsum.pl" + transformer._perl_path = "/usr/bin/perl" + transformer._atomic_densities_path = "/path/to/atomic_densities" + transformer._can_generate_potcar = True + return transformer + + +@pytest.fixture +def transformer_no_tools(mock_config): + """Transformer with no external tools available.""" + with patch.object(LeMatRhoTransformer, "_validate_tools"): + transformer = LeMatRhoTransformer(config=mock_config, debug=True) + # All tool paths remain None (set in __init__ before _validate_tools) + transformer._bader_path = None + 
transformer._chargemol_path = None + transformer._chgsum_script_path = None + transformer._perl_path = None + transformer._atomic_densities_path = None + transformer._can_generate_potcar = False + return transformer + + +# --------------------------------------------------------------------------- +# transform_row tests +# --------------------------------------------------------------------------- + + +class TestTransformRow: + def test_happy_path(self, transformer_with_tools, raw_structure): + """Full transform with Bader and DDEC6 results populated.""" + with ( + patch.object( + transformer_with_tools, "_run_bader_analysis" + ) as mock_bader, + patch.object( + transformer_with_tools, "_run_ddec6_analysis" + ) as mock_ddec6, + ): + mock_bader.return_value = ([0.5, -0.5], [10.0, 12.0]) + mock_ddec6.return_value = [0.3, -0.3] + + result = transformer_with_tools.transform_row(raw_structure) + + assert len(result) == 1 + s = result[0] + assert s.id == "agm000001" + assert s.source == "lematrho" + assert s.immutable_id == "agm000001" + assert s.functional == Functional.PBE + assert s.cross_compatibility is True + assert s.bader_charges == [0.5, -0.5] + assert s.bader_atomic_volume == [10.0, 12.0] + assert s.ddec6_charges == [0.3, -0.3] + assert s.compressed_charge_density == [[[1.0, 2.0]]] + assert s.compressed_aeccar0 == [[[0.1]]] + assert s.compressed_aeccar1 == [[[0.01]]] + assert s.compressed_aeccar2 == [[[0.3]]] + assert s.charge_density_grid_shape == [15, 15, 15] + assert s.space_group_it_number is not None + + def test_bader_fails_ddec6_succeeds(self, transformer_with_tools, raw_structure): + """When Bader fails, DDEC6 should still succeed independently.""" + with ( + patch.object( + transformer_with_tools, "_run_bader_analysis" + ) as mock_bader, + patch.object( + transformer_with_tools, "_run_ddec6_analysis" + ) as mock_ddec6, + ): + mock_bader.return_value = (None, None) + mock_ddec6.return_value = [0.3, -0.3] + + result = 
transformer_with_tools.transform_row(raw_structure) + + s = result[0] + assert s.bader_charges is None + assert s.bader_atomic_volume is None + assert s.ddec6_charges == [0.3, -0.3] + + def test_both_analyses_fail(self, transformer_with_tools, raw_structure): + """When both Bader and DDEC6 fail, structure should still be created.""" + with ( + patch.object( + transformer_with_tools, "_run_bader_analysis" + ) as mock_bader, + patch.object( + transformer_with_tools, "_run_ddec6_analysis" + ) as mock_ddec6, + ): + mock_bader.return_value = (None, None) + mock_ddec6.return_value = None + + result = transformer_with_tools.transform_row(raw_structure) + + s = result[0] + assert s.bader_charges is None + assert s.bader_atomic_volume is None + assert s.ddec6_charges is None + # Compressed grids should still be present + assert s.compressed_charge_density == [[[1.0, 2.0]]] + + def test_no_bader_binary_skips_bader(self, transformer_no_tools, raw_structure): + """When bader is not available, _run_bader_analysis should not be called.""" + with ( + patch.object( + transformer_no_tools, "_run_bader_analysis" + ) as mock_bader, + patch.object( + transformer_no_tools, "_run_ddec6_analysis" + ) as mock_ddec6, + ): + mock_ddec6.return_value = None + + result = transformer_no_tools.transform_row(raw_structure) + + mock_bader.assert_not_called() + mock_ddec6.assert_not_called() + s = result[0] + assert s.bader_charges is None + assert s.ddec6_charges is None + + def test_functional_is_pbe(self, transformer_no_tools, raw_structure): + """Functional should always be PBE for LeMatRho (MP settings).""" + result = transformer_no_tools.transform_row(raw_structure) + assert result[0].functional == Functional.PBE + + def test_cross_compatibility_excludes_yb( + self, transformer_no_tools, raw_structure_yb + ): + """Yb-containing structures should not be cross-compatible.""" + result = transformer_no_tools.transform_row(raw_structure_yb) + assert result[0].cross_compatibility is False + + def 
test_cross_compatibility_normal(self, transformer_no_tools, raw_structure): + """Non-Yb structures should be cross-compatible.""" + result = transformer_no_tools.transform_row(raw_structure) + assert result[0].cross_compatibility is True + + def test_missing_s3_prefix_skips_analyses( + self, transformer_with_tools, si_structure + ): + """If s3_prefix is missing from attributes, skip Bader and DDEC6.""" + raw = RawStructure( + id="agm000003", + type="lematrho", + attributes={ + "structure": si_structure.as_dict(), + "compressed_charge_density": None, + "grid_shape": [15, 15, 15], + # No s3_prefix + }, + last_modified=datetime.datetime(2025, 1, 1), + ) + with ( + patch.object( + transformer_with_tools, "_run_bader_analysis" + ) as mock_bader, + patch.object( + transformer_with_tools, "_run_ddec6_analysis" + ) as mock_ddec6, + ): + result = transformer_with_tools.transform_row(raw) + + mock_bader.assert_not_called() + mock_ddec6.assert_not_called() + + +# --------------------------------------------------------------------------- +# _run_bader_analysis tests +# --------------------------------------------------------------------------- + + +class TestRunBaderAnalysis: + def test_subprocess_timeout(self, transformer_with_tools, si_structure): + """Bader analysis should return (None, None) on subprocess timeout.""" + with ( + patch( + "lematerial_fetcher.fetcher.lematrho.transform.download_gz_file_from_s3" + ) as mock_dl, + patch( + "lematerial_fetcher.fetcher.lematrho.transform.subprocess.run" + ) as mock_run, + patch("lematerial_fetcher.fetcher.lematrho.transform.write_potcar"), + patch.object( + type(transformer_with_tools), + "aws_client", + new_callable=lambda: property(lambda self: MagicMock()), + ), + ): + mock_dl.return_value = b"fake chgcar data" + mock_run.side_effect = subprocess.TimeoutExpired("bader", 600) + + charges, volumes = transformer_with_tools._run_bader_analysis( + si_structure, "agm000001", "agm000001" + ) + + assert charges is None + assert 
volumes is None + + def test_chgsum_nonzero_exit(self, transformer_with_tools, si_structure): + """chgsum.pl returning non-zero exit should be handled gracefully.""" + with ( + patch( + "lematerial_fetcher.fetcher.lematrho.transform.download_gz_file_from_s3" + ) as mock_dl, + patch( + "lematerial_fetcher.fetcher.lematrho.transform.subprocess.run" + ) as mock_run, + patch("lematerial_fetcher.fetcher.lematrho.transform.write_potcar"), + patch.object( + type(transformer_with_tools), + "aws_client", + new_callable=lambda: property(lambda self: MagicMock()), + ), + ): + mock_dl.return_value = b"fake data" + mock_run.side_effect = subprocess.CalledProcessError( + 1, "perl chgsum.pl", stderr=b"chgsum error" + ) + + charges, volumes = transformer_with_tools._run_bader_analysis( + si_structure, "agm000001", "agm000001" + ) + + assert charges is None + assert volumes is None + + def test_s3_download_failure(self, transformer_with_tools, si_structure): + """S3 download failure should be handled gracefully.""" + with ( + patch( + "lematerial_fetcher.fetcher.lematrho.transform.download_gz_file_from_s3" + ) as mock_dl, + patch.object( + type(transformer_with_tools), + "aws_client", + new_callable=lambda: property(lambda self: MagicMock()), + ), + ): + mock_dl.side_effect = Exception("NoSuchKey") + + charges, volumes = transformer_with_tools._run_bader_analysis( + si_structure, "agm000001", "agm000001" + ) + + assert charges is None + assert volumes is None + + def test_potcar_generation_failure(self, transformer_with_tools, si_structure): + """POTCAR generation failure should be handled gracefully.""" + with ( + patch( + "lematerial_fetcher.fetcher.lematrho.transform.download_gz_file_from_s3" + ) as mock_dl, + patch( + "lematerial_fetcher.fetcher.lematrho.transform.write_potcar", + side_effect=Exception("No PSP"), + ), + patch.object( + type(transformer_with_tools), + "aws_client", + new_callable=lambda: property(lambda self: MagicMock()), + ), + ): + mock_dl.return_value = 
b"fake data" + + charges, volumes = transformer_with_tools._run_bader_analysis( + si_structure, "agm000001", "agm000001" + ) + + assert charges is None + assert volumes is None + + +# --------------------------------------------------------------------------- +# _run_ddec6_analysis tests +# --------------------------------------------------------------------------- + + +class TestRunDdec6Analysis: + def test_subprocess_timeout(self, transformer_with_tools, si_structure): + """DDEC6 analysis should return None on subprocess timeout.""" + with ( + patch( + "lematerial_fetcher.fetcher.lematrho.transform.download_gz_file_from_s3" + ) as mock_dl, + patch( + "lematerial_fetcher.fetcher.lematrho.transform.subprocess.run" + ) as mock_run, + patch("lematerial_fetcher.fetcher.lematrho.transform.write_potcar"), + patch.object(transformer_with_tools, "_write_chargemol_config"), + patch.object( + type(transformer_with_tools), + "aws_client", + new_callable=lambda: property(lambda self: MagicMock()), + ), + ): + mock_dl.return_value = b"fake chgcar data" + mock_run.side_effect = subprocess.TimeoutExpired("chargemol", 600) + + result = transformer_with_tools._run_ddec6_analysis( + si_structure, "agm000001", "agm000001" + ) + + assert result is None + + def test_chargemol_nonzero_exit(self, transformer_with_tools, si_structure): + """chargemol returning non-zero exit should be handled gracefully.""" + with ( + patch( + "lematerial_fetcher.fetcher.lematrho.transform.download_gz_file_from_s3" + ) as mock_dl, + patch( + "lematerial_fetcher.fetcher.lematrho.transform.subprocess.run" + ) as mock_run, + patch("lematerial_fetcher.fetcher.lematrho.transform.write_potcar"), + patch.object(transformer_with_tools, "_write_chargemol_config"), + patch.object( + type(transformer_with_tools), + "aws_client", + new_callable=lambda: property(lambda self: MagicMock()), + ), + ): + mock_dl.return_value = b"fake data" + mock_run.side_effect = subprocess.CalledProcessError( + 1, "chargemol", 
stderr=b"chargemol error" + ) + + result = transformer_with_tools._run_ddec6_analysis( + si_structure, "agm000001", "agm000001" + ) + + assert result is None + + +# --------------------------------------------------------------------------- +# Temp directory cleanup tests +# --------------------------------------------------------------------------- + + +class TestTempDirectoryCleanup: + def test_cleanup_on_success(self, transformer_with_tools, si_structure): + """Temp directory should be cleaned up after successful analysis.""" + created_tmpdir = [None] + + original_tempdir = tempfile.TemporaryDirectory + + class TrackingTempDir: + def __init__(self, *args, **kwargs): + self._real = original_tempdir(*args, **kwargs) + created_tmpdir[0] = self._real.name + + def __enter__(self): + return self._real.__enter__() + + def __exit__(self, *args): + return self._real.__exit__(*args) + + with ( + patch( + "lematerial_fetcher.fetcher.lematrho.transform.download_gz_file_from_s3" + ) as mock_dl, + patch( + "lematerial_fetcher.fetcher.lematrho.transform.subprocess.run" + ), + patch("lematerial_fetcher.fetcher.lematrho.transform.write_potcar"), + patch( + "lematerial_fetcher.fetcher.lematrho.transform.parse_acf_dat" + ) as mock_parse, + patch( + "lematerial_fetcher.fetcher.lematrho.transform.read_potcar_zval" + ) as mock_zval, + patch( + "lematerial_fetcher.fetcher.lematrho.transform.tempfile.TemporaryDirectory", + TrackingTempDir, + ), + patch.object( + type(transformer_with_tools), + "aws_client", + new_callable=lambda: property(lambda self: MagicMock()), + ), + ): + mock_dl.return_value = b"fake data" + mock_parse.return_value = ([4.0, 4.0], [10.0, 12.0]) + mock_zval.return_value = {"Si": 4.0} + + transformer_with_tools._run_bader_analysis( + si_structure, "agm000001", "agm000001" + ) + + assert created_tmpdir[0] is not None + assert not os.path.exists(created_tmpdir[0]) + + def test_cleanup_on_failure(self, transformer_with_tools, si_structure): + """Temp directory should 
be cleaned up even on failure.""" + created_tmpdir = [None] + + original_tempdir = tempfile.TemporaryDirectory + + class TrackingTempDir: + def __init__(self, *args, **kwargs): + self._real = original_tempdir(*args, **kwargs) + created_tmpdir[0] = self._real.name + + def __enter__(self): + return self._real.__enter__() + + def __exit__(self, *args): + return self._real.__exit__(*args) + + with ( + patch( + "lematerial_fetcher.fetcher.lematrho.transform.download_gz_file_from_s3" + ) as mock_dl, + patch( + "lematerial_fetcher.fetcher.lematrho.transform.tempfile.TemporaryDirectory", + TrackingTempDir, + ), + patch.object( + type(transformer_with_tools), + "aws_client", + new_callable=lambda: property(lambda self: MagicMock()), + ), + ): + mock_dl.side_effect = Exception("S3 error") + + transformer_with_tools._run_bader_analysis( + si_structure, "agm000001", "agm000001" + ) + + assert created_tmpdir[0] is not None + assert not os.path.exists(created_tmpdir[0]) + + +# --------------------------------------------------------------------------- +# Parsing function unit tests +# --------------------------------------------------------------------------- + + +class TestParseAcfDat: + def test_parse_standard_format(self, tmp_path): + """Parse a standard ACF.dat file.""" + acf_content = ( + " # X Y Z CHARGE MIN DIST ATOMIC VOL\n" + " -----------------------------------------------------------------------\n" + " 1 0.000000 0.000000 0.000000 3.112903 1.235698 19.382956\n" + " 2 1.357500 1.357500 1.357500 4.887097 1.235698 13.617044\n" + " -----------------------------------------------------------------------\n" + " VACUUM CHARGE: 0.0000\n" + " VACUUM VOLUME: 0.0000\n" + " NUMBER OF ELECTRONS: 8.0000\n" + ) + acf_file = tmp_path / "ACF.dat" + acf_file.write_text(acf_content) + + counts, volumes = parse_acf_dat(str(acf_file)) + + assert len(counts) == 2 + assert len(volumes) == 2 + assert abs(counts[0] - 3.112903) < 1e-6 + assert abs(counts[1] - 4.887097) < 1e-6 + assert 
abs(volumes[0] - 19.382956) < 1e-6 + assert abs(volumes[1] - 13.617044) < 1e-6 + + def test_parse_single_atom(self, tmp_path): + """Parse ACF.dat with a single atom.""" + acf_content = ( + " # X Y Z CHARGE MIN DIST ATOMIC VOL\n" + " -----------------------------------------------------------------------\n" + " 1 0.000000 0.000000 0.000000 8.000000 2.715000 40.000000\n" + " -----------------------------------------------------------------------\n" + ) + acf_file = tmp_path / "ACF.dat" + acf_file.write_text(acf_content) + + counts, volumes = parse_acf_dat(str(acf_file)) + + assert counts == [8.0] + assert volumes == [40.0] + + +class TestReadPotcarZval: + def test_parse_potcar(self, tmp_path): + """Parse POTCAR for valence electron counts.""" + potcar_content = ( + " TITEL = PAW_PBE Si 05Jan2001\n" + " ZVAL = 4.00000\n" + " END of PSCTR\n" + " TITEL = PAW_PBE O 08Apr2002\n" + " ZVAL = 6.00000\n" + " END of PSCTR\n" + ) + potcar_file = tmp_path / "POTCAR" + potcar_file.write_text(potcar_content) + + zval = read_potcar_zval(str(potcar_file)) + + assert zval["Si"] == 4.0 + assert zval["O"] == 6.0 + + def test_parse_potcar_with_underscore_element(self, tmp_path): + """Handle element names like Si_d in POTCAR.""" + potcar_content = ( + " TITEL = PAW_PBE Si_d 05Jan2001\n" + " ZVAL = 4.00000\n" + " END of PSCTR\n" + ) + potcar_file = tmp_path / "POTCAR" + potcar_file.write_text(potcar_content) + + zval = read_potcar_zval(str(potcar_file)) + + assert zval["Si"] == 4.0 + + +class TestParseDdec6Charges: + def test_parse_standard_output(self, tmp_path): + """Parse DDEC6 output file.""" + ddec6_content = ( + " 2\n" + " Charge analysis\n" + " Si 0.000000 0.000000 0.000000 0.123456\n" + " Si 1.357500 1.357500 1.357500 -0.123456\n" + ) + ddec6_file = tmp_path / "DDEC6_even_tempered_net_atomic_charges.xyz" + ddec6_file.write_text(ddec6_content) + + charges = parse_ddec6_charges(str(tmp_path)) + + assert len(charges) == 2 + assert abs(charges[0] - 0.123456) < 1e-6 + assert 
abs(charges[1] - (-0.123456)) < 1e-6 + + +# --------------------------------------------------------------------------- +# Cross-compatibility function tests +# --------------------------------------------------------------------------- + + +class TestGetCrossCompatibility: + def test_normal_elements(self): + assert get_cross_compatibility(["Si", "O"]) is True + + def test_yb_excluded(self): + assert get_cross_compatibility(["Yb", "O"]) is False + + def test_yb_in_larger_set(self): + assert get_cross_compatibility(["Fe", "Yb", "O"]) is False + + def test_empty_elements(self): + assert get_cross_compatibility([]) is True + + +# --------------------------------------------------------------------------- +# Tool validation tests +# --------------------------------------------------------------------------- + + +class TestValidateTools: + def test_all_tools_available(self, mock_config): + """When all tools are found, can_run_bader and can_run_ddec6 should be True.""" + with ( + patch("shutil.which", return_value="/usr/bin/tool"), + patch("os.path.isfile", return_value=True), + patch("os.path.isdir", return_value=True), + patch.dict(os.environ, {"PMG_VASP_PSP_DIR": "/path/to/psp"}), + ): + transformer = LeMatRhoTransformer(config=mock_config, debug=True) + + assert transformer.can_run_bader is True + assert transformer.can_run_ddec6 is True + + def test_no_tools_available(self): + """When no tools are found, can_run_bader and can_run_ddec6 should be False.""" + config = TransformerConfig( + source_db_conn_str="mock://source", + dest_db_conn_str="mock://dest", + source_table_name="test_source", + dest_table_name="test_dest", + batch_size=100, + page_offset=0, + log_every=100, + log_dir="./logs", + max_retries=3, + page_limit=10, + num_workers=1, + retry_delay=2, + # No tool paths set + ) + with ( + patch("shutil.which", return_value=None), + patch.dict(os.environ, {}, clear=True), + ): + transformer = LeMatRhoTransformer(config=config, debug=True) + + assert 
transformer.can_run_bader is False + assert transformer.can_run_ddec6 is False + + def test_no_pmg_vasp_psp_dir(self, mock_config): + """Without PMG_VASP_PSP_DIR, both analyses should be disabled.""" + with ( + patch("shutil.which", return_value="/usr/bin/tool"), + patch("os.path.isfile", return_value=True), + patch("os.path.isdir", return_value=True), + patch.dict(os.environ, {}, clear=True), + ): + transformer = LeMatRhoTransformer(config=mock_config, debug=True) + + assert transformer._can_generate_potcar is False + assert transformer.can_run_bader is False + assert transformer.can_run_ddec6 is False diff --git a/tests/models/test_optimade_model.py b/tests/models/test_optimade_model.py index 39ecf53..bcfcca9 100644 --- a/tests/models/test_optimade_model.py +++ b/tests/models/test_optimade_model.py @@ -3,7 +3,9 @@ import pytest +from lematerial_fetcher.database.postgres import OptimadeDatabase, TrajectoriesDatabase from lematerial_fetcher.models.optimade import Functional, OptimadeStructure +from lematerial_fetcher.models.utils.enums import Source # Test data for a valid structure VALID_STRUCTURE_DATA = { @@ -249,3 +251,129 @@ def test_functional_enum(): data["functional"] = func structure = OptimadeStructure(**data) assert structure.functional == func + + +# ------------------------------------------------------------------- +# LeMatRho charge density field tests +# ------------------------------------------------------------------- + + +def test_source_lematrho_is_valid(): + """Test that LEMATRHO is a valid Source enum value.""" + assert Source.LEMATRHO == "lematrho" + data = VALID_STRUCTURE_DATA.copy() + data["source"] = "lematrho" + structure = OptimadeStructure(**data) + assert structure.source == Source.LEMATRHO + + +def test_charge_density_none_fields(): + """Regression: existing VALID_STRUCTURE_DATA still passes with new optional fields as None.""" + structure = OptimadeStructure(**VALID_STRUCTURE_DATA) + assert structure.compressed_charge_density is None 
+ assert structure.compressed_aeccar0 is None + assert structure.compressed_aeccar1 is None + assert structure.compressed_aeccar2 is None + assert structure.charge_density_grid_shape is None + assert structure.bader_charges is None + assert structure.bader_atomic_volume is None + assert structure.ddec6_charges is None + + +def test_structure_with_charge_density_fields(): + """Test structure creation with all charge density fields populated.""" + data = VALID_STRUCTURE_DATA.copy() + # nsites=2 so per-site lists must have length 2 + data.update( + { + "compressed_charge_density": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], + "compressed_aeccar0": [[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]], + "compressed_aeccar1": [[[0.01, 0.02]]], + "compressed_aeccar2": [[[0.01, 0.02]]], + "charge_density_grid_shape": [2, 2, 2], + "bader_charges": [1.5, -1.5], + "bader_atomic_volume": [10.0, 12.0], + "ddec6_charges": [0.8, -0.8], + } + ) + structure = OptimadeStructure(**data) + assert structure.charge_density_grid_shape == [2, 2, 2] + assert structure.bader_charges == [1.5, -1.5] + assert structure.bader_atomic_volume == [10.0, 12.0] + assert structure.ddec6_charges == [0.8, -0.8] + assert structure.compressed_charge_density is not None + + +def test_bader_charges_wrong_length(): + """Test that bader_charges with wrong length raises ValueError.""" + data = VALID_STRUCTURE_DATA.copy() + data["bader_charges"] = [1.0, 2.0, 3.0] # nsites=2, but 3 values + with pytest.raises(ValueError, match="bader_charges"): + OptimadeStructure(**data) + + +def test_ddec6_charges_wrong_length(): + """Test that ddec6_charges with wrong length raises ValueError.""" + data = VALID_STRUCTURE_DATA.copy() + data["ddec6_charges"] = [1.0] # nsites=2, but 1 value + with pytest.raises(ValueError, match="ddec6_charges"): + OptimadeStructure(**data) + + +def test_bader_atomic_volume_wrong_length(): + """Test that bader_atomic_volume with wrong length raises ValueError.""" + data = 
VALID_STRUCTURE_DATA.copy() + data["bader_atomic_volume"] = [10.0, 12.0, 14.0] # nsites=2, but 3 values + with pytest.raises(ValueError, match="bader_atomic_volume"): + OptimadeStructure(**data) + + +def test_charge_density_grid_shape_validation(): + """Test that charge_density_grid_shape must be exactly 3 elements.""" + data = VALID_STRUCTURE_DATA.copy() + + # Too short + data["charge_density_grid_shape"] = [15, 15] + with pytest.raises(ValueError): + OptimadeStructure(**data) + + # Too long + data["charge_density_grid_shape"] = [15, 15, 15, 15] + with pytest.raises(ValueError): + OptimadeStructure(**data) + + +def test_optimade_db_columns_include_charge_density_fields(): + """Test that OptimadeDatabase.columns() includes the new charge density columns.""" + cols = OptimadeDatabase.columns() + assert "compressed_charge_density" in cols + assert "compressed_aeccar0" in cols + assert "compressed_aeccar1" in cols + assert "compressed_aeccar2" in cols + assert "charge_density_grid_shape" in cols + assert "bader_charges" in cols + assert "bader_atomic_volume" in cols + assert "ddec6_charges" in cols + + +def test_trajectories_db_columns_inherit_charge_density_fields(): + """Test that TrajectoriesDatabase.columns() inherits the charge density columns.""" + cols = TrajectoriesDatabase.columns() + assert "compressed_charge_density" in cols + assert "bader_charges" in cols + assert "ddec6_charges" in cols + # Also still has trajectory-specific columns + assert "relaxation_step" in cols + assert "relaxation_number" in cols + + +def test_optimade_db_column_count_matches_insert_tuple(): + """Guard test: verify that the number of columns matches what insert_data expects. + + This prevents silent data corruption from tuple/column ordering mismatches + across the 4 manually maintained tuple definitions in postgres.py. 
+ """ + optimade_col_count = len(OptimadeDatabase.columns()) + traj_col_count = len(TrajectoriesDatabase.columns()) + # TrajectoriesDatabase should have exactly 2 more columns (relaxation_step, relaxation_number) + assert traj_col_count == optimade_col_count + 2 diff --git a/tests/test_cli.py b/tests/test_cli.py index e6feb0c..faf9723 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -7,6 +7,7 @@ from lematerial_fetcher.cli import cli from lematerial_fetcher.utils.config import ( + DirectPipelineConfig, FetcherConfig, TransformerConfig, ) @@ -231,6 +232,147 @@ def test_cli_args_override_env_vars(): assert call_kwargs["db_host"] == "cli.host" +def test_lematrho_subcommands(): + """Test that lematrho subcommands are correctly registered.""" + runner = CliRunner() + + result = runner.invoke(cli, ["lematrho", "--help"]) + assert result.exit_code == 0 + assert "Commands for fetching charge density data from LeMatRho" in result.output + assert "fetch" in result.output + assert "transform" in result.output + + +def test_lematrho_fetch_help(): + """Test that lematrho fetch --help returns 0 and shows expected options.""" + runner = CliRunner() + result = runner.invoke(cli, ["lematrho", "fetch", "--help"]) + assert result.exit_code == 0 + assert "--lematrho-bucket-name" in result.output + assert "--grid-shape" in result.output + + +def test_lematrho_transform_help(): + """Test that lematrho transform --help returns 0 and shows expected options.""" + runner = CliRunner() + result = runner.invoke(cli, ["lematrho", "transform", "--help"]) + assert result.exit_code == 0 + assert "--bader-path" in result.output + assert "--chargemol-path" in result.output + assert "--chgsum-script-path" in result.output + assert "--atomic-densities-path" in result.output + assert "--lematrho-bucket-name" in result.output + + +@patch("lematerial_fetcher.cli.LeMatRhoFetcher") +@patch("lematerial_fetcher.cli.load_fetcher_config") +def test_lematrho_fetch_passes_cli_args(mock_load_config, 
mock_fetcher): + """Test that lematrho fetch command passes CLI args to the config loader.""" + mock_config = FetcherConfig( + log_dir="./logs", + max_retries=3, + num_workers=2, + retry_delay=2, + log_every=1000, + page_offset=0, + page_limit=10, + base_url="DUMMY_BASE_URL", + table_name="test_table", + db_conn_str="db_conn_string", + mp_bucket_name="", + mp_bucket_prefix="", + lematrho_bucket_name="lemat-rho", + lematrho_grid_shape=(15, 15, 15), + ) + mock_load_config.return_value = mock_config + mock_fetcher_instance = mock_fetcher.return_value + + runner = CliRunner() + result = runner.invoke( + cli, + [ + "lematrho", + "fetch", + "--db-user", + "test_user", + "--table-name", + "test_raw", + "--lematrho-bucket-name", + "my-bucket", + "--grid-shape", + "20", + "20", + "20", + ], + ) + + assert result.exit_code == 0 + mock_load_config.assert_called_once() + call_kwargs = mock_load_config.call_args[1] + assert call_kwargs["db_user"] == "test_user" + assert call_kwargs["table_name"] == "test_raw" + assert call_kwargs["lematrho_bucket_name"] == "my-bucket" + assert call_kwargs["lematrho_grid_shape"] == (20, 20, 20) + + mock_fetcher.assert_called_once_with(config=mock_config, debug=False) + mock_fetcher_instance.fetch.assert_called_once() + + +@patch("lematerial_fetcher.cli.LeMatRhoTransformer") +@patch("lematerial_fetcher.cli.load_transformer_config") +def test_lematrho_transform_passes_cli_args(mock_load_config, mock_transformer): + """Test that lematrho transform command passes CLI args to the config loader.""" + mock_config = TransformerConfig( + log_dir="./logs", + max_retries=3, + num_workers=2, + retry_delay=2, + log_every=1000, + page_offset=0, + page_limit=10, + source_db_conn_str="source_conn_string", + dest_db_conn_str="dest_conn_string", + source_table_name="source_table", + dest_table_name="dest_table", + batch_size=500, + lematrho_bucket_name="lemat-rho", + bader_path="/usr/bin/bader", + ) + mock_load_config.return_value = mock_config + 
mock_transformer_instance = mock_transformer.return_value + + runner = CliRunner() + result = runner.invoke( + cli, + [ + "lematrho", + "transform", + "--db-user", + "src_user", + "--table-name", + "src_table", + "--dest-table-name", + "dest_table", + "--bader-path", + "/opt/bader", + "--lematrho-bucket-name", + "my-bucket", + ], + ) + + assert result.exit_code == 0 + mock_load_config.assert_called_once() + call_kwargs = mock_load_config.call_args[1] + assert call_kwargs["db_user"] == "src_user" + assert call_kwargs["table_name"] == "src_table" + assert call_kwargs["dest_table_name"] == "dest_table" + assert call_kwargs["bader_path"] == "/opt/bader" + assert call_kwargs["lematrho_bucket_name"] == "my-bucket" + + mock_transformer.assert_called_once_with(config=mock_config, debug=False) + mock_transformer_instance.transform.assert_called_once() + + def test_env_vars_pass_to_config(): """Test that environment variables are passed to the config loader""" @@ -260,3 +402,99 @@ def test_env_vars_pass_to_config(): assert call_kwargs["db_user"] == "src_user" assert call_kwargs["table_name"] == "src_table" assert call_kwargs["dest_table_name"] == "dest_table" + + +def test_lematrho_run_help(): + """Test that lematrho run --help returns 0 and shows expected options.""" + runner = CliRunner() + result = runner.invoke(cli, ["lematrho", "run", "--help"]) + assert result.exit_code == 0 + assert "--output-dir" in result.output + assert "--parquet-chunk-size" in result.output + assert "--lematrho-bucket-name" in result.output + assert "--grid-shape" in result.output + assert "--hf-repo-id" in result.output + assert "--hf-token" in result.output + assert "--bader-path" in result.output + assert "--chargemol-path" in result.output + assert "--chgsum-script-path" in result.output + assert "--atomic-densities-path" in result.output + assert "--num-workers" in result.output + + +def test_lematrho_run_in_subcommands(): + """Test that 'run' appears in lematrho subcommands alongside fetch 
and transform.""" + runner = CliRunner() + result = runner.invoke(cli, ["lematrho", "--help"]) + assert result.exit_code == 0 + assert "run" in result.output + assert "fetch" in result.output + assert "transform" in result.output + + +@patch("lematerial_fetcher.cli.LeMatRhoDirectPipeline") +@patch("lematerial_fetcher.cli.load_direct_pipeline_config") +def test_lematrho_run_passes_cli_args(mock_load_config, mock_pipeline): + """Test that lematrho run command passes CLI args to the config loader.""" + mock_config = DirectPipelineConfig( + lematrho_bucket_name="my-bucket", + lematrho_grid_shape=(20, 20, 20), + output_dir="/tmp/test_output", + parquet_chunk_size=500, + num_workers=2, + ) + mock_load_config.return_value = mock_config + mock_pipeline_instance = mock_pipeline.return_value + + runner = CliRunner() + result = runner.invoke( + cli, + [ + "lematrho", + "run", + "--output-dir", + "/tmp/test_output", + "--parquet-chunk-size", + "500", + "--num-workers", + "2", + "--lematrho-bucket-name", + "my-bucket", + "--grid-shape", + "20", + "20", + "20", + "--bader-path", + "/opt/bader", + ], + ) + + assert result.exit_code == 0 + mock_load_config.assert_called_once() + call_kwargs = mock_load_config.call_args[1] + assert call_kwargs["output_dir"] == "/tmp/test_output" + assert call_kwargs["parquet_chunk_size"] == 500 + assert call_kwargs["num_workers"] == 2 + assert call_kwargs["lematrho_bucket_name"] == "my-bucket" + assert call_kwargs["grid_shape"] == (20, 20, 20) + assert call_kwargs["bader_path"] == "/opt/bader" + + mock_pipeline.assert_called_once_with(config=mock_config, debug=False) + mock_pipeline_instance.run.assert_called_once() + + +@patch("lematerial_fetcher.cli.LeMatRhoDirectPipeline") +@patch("lematerial_fetcher.cli.load_direct_pipeline_config") +def test_lematrho_run_debug_flag(mock_load_config, mock_pipeline): + """Test that --debug flag is passed through to the pipeline.""" + mock_config = DirectPipelineConfig() + mock_load_config.return_value = 
mock_config + + runner = CliRunner() + result = runner.invoke( + cli, + ["--debug", "lematrho", "run"], + ) + + assert result.exit_code == 0 + mock_pipeline.assert_called_once_with(config=mock_config, debug=True) diff --git a/tests/utils/test_aws.py b/tests/utils/test_aws.py index 6309aae..375a4fc 100644 --- a/tests/utils/test_aws.py +++ b/tests/utils/test_aws.py @@ -8,6 +8,7 @@ from lematerial_fetcher.utils.aws import ( download_s3_object, + get_authenticated_aws_client, get_aws_client, get_latest_collection_version_prefix, list_s3_objects, @@ -22,6 +23,39 @@ def test_get_aws_client(): assert client._client_config.region_name == "us-east-1" +def test_get_aws_client_unchanged(): + """Regression: anonymous client still uses UNSIGNED after adding authenticated client""" + client = get_aws_client() + assert client._client_config.signature_version == UNSIGNED + + +def test_get_authenticated_aws_client_no_unsigned(): + """Test that authenticated client does not use UNSIGNED signature""" + client = get_authenticated_aws_client() + assert client._client_config.signature_version != UNSIGNED + + +def test_get_authenticated_aws_client_has_retry_config(): + """Test that authenticated client has adaptive retry configuration""" + client = get_authenticated_aws_client() + retry_config = client._client_config.retries + # boto3 converts max_attempts=3 to total_max_attempts=4 (initial + retries) + assert retry_config["total_max_attempts"] == 4 + assert retry_config["mode"] == "adaptive" + + +def test_get_authenticated_aws_client_default_region(): + """Test that authenticated client defaults to us-east-1""" + client = get_authenticated_aws_client() + assert client._client_config.region_name == "us-east-1" + + +def test_get_authenticated_aws_client_custom_region(): + """Test that authenticated client accepts a custom region""" + client = get_authenticated_aws_client(region_name="eu-west-1") + assert client._client_config.region_name == "eu-west-1" + + @pytest.fixture def 
mock_s3_client(): """Fixture to create a stubbed S3 client""" diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py index ff94319..df5201c 100644 --- a/tests/utils/test_config.py +++ b/tests/utils/test_config.py @@ -7,6 +7,8 @@ from dotenv import load_dotenv from lematerial_fetcher.utils.config import ( + DirectPipelineConfig, + load_direct_pipeline_config, load_fetcher_config, load_push_config, load_transformer_config, @@ -724,3 +726,110 @@ def test_load_push_config_missing_required(): assert "db credentials" in str(excinfo.value) assert "table_name" in str(excinfo.value) assert "hf_repo_id" in str(excinfo.value) + + +# --------------------------------------------------------------------------- +# DirectPipelineConfig tests +# --------------------------------------------------------------------------- + + +class TestDirectPipelineConfig: + def test_defaults(self): + """All defaults should produce a valid config.""" + config = DirectPipelineConfig() + assert config.lematrho_bucket_name == "lemat-rho" + assert config.lematrho_grid_shape == (15, 15, 15) + assert config.output_dir == "./lematrho_output" + assert config.parquet_chunk_size == 1000 + assert config.num_workers == 4 + assert config.log_every == 100 + assert config.hf_repo_id is None + assert config.hf_token is None + assert config.bader_path is None + assert config.chargemol_path is None + assert config.chgsum_script_path is None + assert config.atomic_densities_path is None + + def test_custom_values(self): + """Config should accept custom values for all fields.""" + config = DirectPipelineConfig( + lematrho_bucket_name="my-bucket", + lematrho_grid_shape=(20, 20, 20), + output_dir="/tmp/output", + parquet_chunk_size=500, + num_workers=2, + log_every=50, + hf_repo_id="org/repo", + hf_token="hf_abc123", + bader_path="/usr/bin/bader", + chargemol_path="/usr/bin/chargemol", + chgsum_script_path="/opt/chgsum.pl", + atomic_densities_path="/opt/atomic_densities", + ) + assert 
config.lematrho_bucket_name == "my-bucket" + assert config.lematrho_grid_shape == (20, 20, 20) + assert config.output_dir == "/tmp/output" + assert config.parquet_chunk_size == 500 + assert config.num_workers == 2 + assert config.hf_repo_id == "org/repo" + assert config.bader_path == "/usr/bin/bader" + assert config.chargemol_path == "/usr/bin/chargemol" + assert config.chgsum_script_path == "/opt/chgsum.pl" + assert config.atomic_densities_path == "/opt/atomic_densities" + + def test_not_a_base_config(self): + """DirectPipelineConfig should NOT inherit from BaseConfig.""" + from lematerial_fetcher.utils.config import BaseConfig + + assert not issubclass(DirectPipelineConfig, BaseConfig) + + +class TestLoadDirectPipelineConfig: + def test_defaults(self): + """Loader with no args should return config with all defaults.""" + config = load_direct_pipeline_config() + assert config.lematrho_bucket_name == "lemat-rho" + assert config.lematrho_grid_shape == (15, 15, 15) + assert config.output_dir == "./lematrho_output" + assert config.parquet_chunk_size == 1000 + assert config.num_workers == 4 + + def test_passes_arguments_through(self): + """Loader should pass all arguments to the config.""" + config = load_direct_pipeline_config( + lematrho_bucket_name="custom-bucket", + grid_shape=(10, 10, 10), + output_dir="/data/output", + parquet_chunk_size=2000, + num_workers=8, + log_every=200, + hf_repo_id="org/dataset", + hf_token="token123", + bader_path="/bin/bader", + chargemol_path="/bin/chargemol", + chgsum_script_path="/scripts/chgsum.pl", + atomic_densities_path="/data/densities", + ) + assert config.lematrho_bucket_name == "custom-bucket" + assert config.lematrho_grid_shape == (10, 10, 10) + assert config.output_dir == "/data/output" + assert config.parquet_chunk_size == 2000 + assert config.num_workers == 8 + assert config.log_every == 200 + assert config.hf_repo_id == "org/dataset" + assert config.hf_token == "token123" + assert config.bader_path == "/bin/bader" + + 
def test_ignores_unknown_kwargs(self): + """Loader should silently ignore unknown kwargs (from Click spillover).""" + config = load_direct_pipeline_config( + debug=True, + cache_dir="/tmp/cache", + some_random_kwarg="value", + ) + assert isinstance(config, DirectPipelineConfig) + + def test_grid_shape_kwarg_maps_to_config(self): + """Loader uses 'grid_shape' (Click name) mapped to 'lematrho_grid_shape'.""" + config = load_direct_pipeline_config(grid_shape=(25, 25, 25)) + assert config.lematrho_grid_shape == (25, 25, 25)