diff --git a/.env.example b/.env.example
index 51052fc..86e4527 100644
--- a/.env.example
+++ b/.env.example
@@ -82,4 +82,29 @@ LEMATERIALFETCHER_DEST_TABLE_NAME=optimade_materials
# Transformer processing settings
# LEMATERIALFETCHER_BATCH_SIZE=500
# LEMATERIALFETCHER_OFFSET=0
-# LEMATERIALFETCHER_LOG_EVERY=1000
+# LEMATERIALFETCHER_LOG_EVERY=1000
+
+# ------------------------------------------------------------------------------
+# LeMatRho Configuration (charge density pipeline)
+# ------------------------------------------------------------------------------
+
+# AWS credentials for authenticated S3 access (LeMatRho bucket)
+# AWS_ACCESS_KEY_ID=your_access_key
+# AWS_SECRET_ACCESS_KEY=your_secret_key
+# AWS_DEFAULT_REGION=us-east-1
+
+# LeMatRho S3 bucket name
+# LEMATERIALFETCHER_LEMATRHO_BUCKET_NAME=lemat-rho
+
+# VASP pseudopotential directory (required for Bader/DDEC6 analysis)
+# PMG_VASP_PSP_DIR=/path/to/vasp/pseudopotentials
+
+# External tool paths (optional, auto-detected on PATH if not set)
+# LEMATERIALFETCHER_BADER_PATH=/path/to/bader
+# LEMATERIALFETCHER_CHARGEMOL_PATH=/path/to/chargemol
+# LEMATERIALFETCHER_CHGSUM_SCRIPT_PATH=/path/to/chgsum.pl
+# LEMATERIALFETCHER_ATOMIC_DENSITIES_PATH=/path/to/atomic_densities
+
+# HuggingFace (for pushing dataset after pipeline completes)
+# HF_REPO_ID=your-org/lematrho-dataset
+# HF_TOKEN=hf_your_token
diff --git a/.gitignore b/.gitignore
index a682896..21d486a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,6 +127,8 @@ celerybeat.pid
# Environments
.env
+.env.*
+!.env.example
.venv
# env/
venv/
diff --git a/pyproject.toml b/pyproject.toml
index cb18c18..a9b7855 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
"moyopy>=0.4.2",
"ase>=3.24.0",
"material-hasher",
+ "mp-pyrho>=0.3.1",
]
[project.scripts]
@@ -58,5 +59,10 @@ dev-dependencies = [
material-hasher = { git = "https://github.com/LeMaterial/lematerial-hasher.git" }
+[tool.pytest.ini_options]
+markers = [
+ "integration: tests that require real AWS credentials and S3 access (deselect with '-m \"not integration\"')",
+]
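+# Example: skip tests that need real AWS credentials:
+#   pytest -m "not integration"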
+
[tool.ruff.lint]
extend-select = ["I"]
diff --git a/src/lematerial_fetcher/cli.py b/src/lematerial_fetcher/cli.py
index 95ec89c..ce61aa8 100644
--- a/src/lematerial_fetcher/cli.py
+++ b/src/lematerial_fetcher/cli.py
@@ -28,6 +28,9 @@
MPTrajectoryTransformer,
MPTransformer,
)
+from lematerial_fetcher.fetcher.lematrho.fetch import LeMatRhoFetcher
+from lematerial_fetcher.fetcher.lematrho.pipeline import LeMatRhoDirectPipeline
+from lematerial_fetcher.fetcher.lematrho.transform import LeMatRhoTransformer
from lematerial_fetcher.fetcher.oqmd.fetch import OQMDFetcher
from lematerial_fetcher.fetcher.oqmd.transform import (
OQMDTrajectoryTransformer,
@@ -37,6 +40,9 @@
from lematerial_fetcher.utils.cli import (
add_common_options,
add_fetch_options,
+ add_lematrho_direct_options,
+ add_lematrho_fetch_options,
+ add_lematrho_transform_options,
add_mp_fetch_options,
add_mysql_options,
add_push_options,
@@ -44,6 +50,7 @@
get_default_mp_bucket_name,
)
from lematerial_fetcher.utils.config import (
+ load_direct_pipeline_config,
load_fetcher_config,
load_push_config,
load_transformer_config,
@@ -114,9 +121,17 @@ def oqmd_cli(ctx):
pass
+@click.group(name="lematrho")
+@click.pass_context
+def lematrho_cli(ctx):
+ """Commands for fetching charge density data from LeMatRho."""
+ pass
+
+
cli.add_command(mp_cli)
cli.add_command(alexandria_cli)
cli.add_command(oqmd_cli)
+cli.add_command(lematrho_cli)
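+
+# Sub-command summary (option flags are defined in utils/cli.py, not shown in
+# this diff):
+#   lematrho fetch      S3 -> raw_structures table (PostgreSQL)
+#   lematrho transform  raw_structures -> OPTIMADE table, optional Bader/DDEC6
+#   lematrho run        S3 -> Parquet -> HuggingFace (no PostgreSQL)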
# ------------------------------------------------------------------------------
# MP commands
@@ -341,6 +356,85 @@ def oqmd_transform(ctx, traj, **config_kwargs):
logger.fatal("\nAborted.", exit=1)
+# ------------------------------------------------------------------------------
+# LeMatRho commands
+# ------------------------------------------------------------------------------
+
+
+@lematrho_cli.command(name="fetch")
+@click.pass_context
+@add_common_options
+@add_fetch_options
+@add_lematrho_fetch_options
+def lematrho_fetch(ctx, **config_kwargs):
+ """Fetch charge density data from the LeMatRho S3 bucket.
+
+ Downloads CHGCAR/AECCAR files, compresses charge densities via pyrho,
+ and stores compressed grids in the raw_structures database table.
+
+ Requires AWS credentials (AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY).
+ """
+ config_kwargs["base_url"] = "DUMMY_BASE_URL" # Not needed for LeMatRho
+ config_kwargs["mp_bucket_name"] = "" # Not needed for LeMatRho
+ config_kwargs["mp_bucket_prefix"] = "" # Not needed for LeMatRho
+
+    # Rename Click's grid_shape option to the config's lematrho_grid_shape key
+ if "grid_shape" in config_kwargs:
+ config_kwargs["lematrho_grid_shape"] = config_kwargs.pop("grid_shape")
+
+ config = load_fetcher_config(**config_kwargs)
+ try:
+ fetcher = LeMatRhoFetcher(config=config, debug=ctx.obj["debug"])
+ fetcher.fetch()
+ except KeyboardInterrupt:
+ logger.fatal("\nAborted.", exit=1)
+
+
+@lematrho_cli.command(name="transform")
+@click.pass_context
+@add_common_options
+@add_transformer_options
+@add_lematrho_transform_options
+def lematrho_transform(ctx, **config_kwargs):
+ """Transform raw LeMatRho structures into OPTIMADE format.
+
+ Optionally runs Bader and DDEC6 charge analysis using external tools.
+
+ External tool requirements:
+ - bader executable (--bader-path or on PATH)
+ - perl + chgsum.pl script (--chgsum-script-path)
+ - chargemol executable (--chargemol-path or on PATH)
+ - PMG_VASP_PSP_DIR environment variable for POTCAR generation
+ - Atomic densities directory (--atomic-densities-path) for DDEC6
+ """
+ config = load_transformer_config(**config_kwargs)
+ try:
+ transformer = LeMatRhoTransformer(config=config, debug=ctx.obj["debug"])
+ transformer.transform()
+ except KeyboardInterrupt:
+ logger.fatal("\nAborted.", exit=1)
+
+
+@lematrho_cli.command(name="run")
+@click.pass_context
+@add_lematrho_direct_options
+def lematrho_run(ctx, **config_kwargs):
+ """Run complete LeMatRho pipeline: S3 -> Parquet -> HuggingFace.
+
+ Downloads charge density data, compresses via pyrho, optionally runs
+ Bader and DDEC6 analysis, and writes Parquet files directly.
+ No PostgreSQL required.
+
+ Requires AWS credentials (AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY).
+ """
+ config = load_direct_pipeline_config(**config_kwargs)
+ try:
+ pipeline = LeMatRhoDirectPipeline(config=config, debug=ctx.obj["debug"])
+ pipeline.run()
+ except KeyboardInterrupt:
+ logger.fatal("\nAborted.", exit=1)
+
+
# ------------------------------------------------------------------------------
# Push commands
# ------------------------------------------------------------------------------
diff --git a/src/lematerial_fetcher/database/postgres.py b/src/lematerial_fetcher/database/postgres.py
index df2416f..1094f81 100644
--- a/src/lematerial_fetcher/database/postgres.py
+++ b/src/lematerial_fetcher/database/postgres.py
@@ -494,6 +494,14 @@ def columns(cls) -> dict[str, str]:
"space_group_it_number": "INTEGER",
"cross_compatibility": "BOOLEAN",
"bawl_fingerprint": "TEXT",
+ "compressed_charge_density": "JSONB",
+ "compressed_aeccar0": "JSONB",
+ "compressed_aeccar1": "JSONB",
+ "compressed_aeccar2": "JSONB",
+ "charge_density_grid_shape": "INTEGER[]",
+ "bader_charges": "FLOAT[]",
+ "bader_atomic_volume": "FLOAT[]",
+ "ddec6_charges": "FLOAT[]",
}
def _prepare_species_data(self, species: list[dict[str, Any]]) -> list[Json]:
@@ -572,6 +580,14 @@ def insert_data(self, structure: OptimadeStructure) -> None:
structure.space_group_it_number,
structure.cross_compatibility,
structure.bawl_fingerprint,
+ Json(structure.compressed_charge_density),
+ Json(structure.compressed_aeccar0),
+ Json(structure.compressed_aeccar1),
+ Json(structure.compressed_aeccar2),
+ structure.charge_density_grid_shape,
+ structure.bader_charges,
+ structure.bader_atomic_volume,
+ structure.ddec6_charges,
)
cur.execute(query, input_data)
self.conn.commit()
@@ -639,6 +655,14 @@ def batch_insert_data(
structure.space_group_it_number,
structure.cross_compatibility,
structure.bawl_fingerprint,
+ Json(structure.compressed_charge_density),
+ Json(structure.compressed_aeccar0),
+ Json(structure.compressed_aeccar1),
+ Json(structure.compressed_aeccar2),
+ structure.charge_density_grid_shape,
+ structure.bader_charges,
+ structure.bader_atomic_volume,
+ structure.ddec6_charges,
)
)
@@ -740,6 +764,14 @@ def insert_data(self, structure: Trajectory) -> None:
structure.space_group_it_number,
structure.cross_compatibility,
structure.bawl_fingerprint,
+ Json(structure.compressed_charge_density),
+ Json(structure.compressed_aeccar0),
+ Json(structure.compressed_aeccar1),
+ Json(structure.compressed_aeccar2),
+ structure.charge_density_grid_shape,
+ structure.bader_charges,
+ structure.bader_atomic_volume,
+ structure.ddec6_charges,
# trajectory-specific fields
structure.relaxation_step,
structure.relaxation_number,
@@ -810,6 +842,14 @@ def batch_insert_data(
structure.space_group_it_number,
structure.cross_compatibility,
structure.bawl_fingerprint,
+ Json(structure.compressed_charge_density),
+ Json(structure.compressed_aeccar0),
+ Json(structure.compressed_aeccar1),
+ Json(structure.compressed_aeccar2),
+ structure.charge_density_grid_shape,
+ structure.bader_charges,
+ structure.bader_atomic_volume,
+ structure.ddec6_charges,
# trajectory-specific fields
structure.relaxation_step,
structure.relaxation_number,
diff --git a/src/lematerial_fetcher/fetch.py b/src/lematerial_fetcher/fetch.py
index 9f7184c..fc1b1a7 100644
--- a/src/lematerial_fetcher/fetch.py
+++ b/src/lematerial_fetcher/fetch.py
@@ -67,7 +67,7 @@ def db(self) -> StructuresDatabase:
Returns
-------
- StructuresDatabase
+ StructuresDatabase
Database connection
"""
if self._db is None:
diff --git a/src/lematerial_fetcher/fetcher/lematrho/__init__.py b/src/lematerial_fetcher/fetcher/lematrho/__init__.py
new file mode 100644
index 0000000..7dd6f34
--- /dev/null
+++ b/src/lematerial_fetcher/fetcher/lematrho/__init__.py
@@ -0,0 +1 @@
+# Copyright 2025 Entalpic
diff --git a/src/lematerial_fetcher/fetcher/lematrho/fetch.py b/src/lematerial_fetcher/fetcher/lematrho/fetch.py
new file mode 100644
index 0000000..b6a4376
--- /dev/null
+++ b/src/lematerial_fetcher/fetcher/lematrho/fetch.py
@@ -0,0 +1,196 @@
+# Copyright 2025 Entalpic
+from datetime import datetime
+from multiprocessing import Manager
+from typing import Optional
+
+from lematerial_fetcher.database.postgres import StructuresDatabase
+from lematerial_fetcher.fetch import BaseFetcher, ItemsInfo
+from lematerial_fetcher.fetcher.lematrho.utils import (
+ GRID_KEY_MAP,
+ RELAX_CALC_TYPE,
+ STATIC_CALC_TYPE,
+ STATIC_FILES,
+ VALID_PREFIXES,
+ build_raw_structure,
+ compress_chgcar,
+ download_gz_file_from_s3,
+ parse_vasprun_structure,
+)
+from lematerial_fetcher.utils.aws import get_authenticated_aws_client
+from lematerial_fetcher.utils.config import FetcherConfig, load_fetcher_config
+from lematerial_fetcher.utils.logging import logger
+
+
+class LeMatRhoFetcher(BaseFetcher):
+ """Fetcher for LeMatRho charge density data from an authenticated S3 bucket.
+
+ Downloads CHGCAR/AECCAR files, compresses charge densities via pyrho,
+ and stores compressed grids in the raw_structures PostgreSQL table.
+
+ Args:
+ config: Fetcher configuration. If ``None``, loads from defaults.
+ debug: If ``True``, process sequentially for debugging.
+ """
+
+ def __init__(self, config: FetcherConfig = None, debug: bool = False):
+ super().__init__(config or load_fetcher_config(), debug)
+ self.aws_client = None
+ self.manager = Manager()
+ self.manager_dict = self.manager.dict()
+ self.manager_dict["occurred"] = False
+
+ def setup_resources(self) -> None:
+ """Set up authenticated AWS client and database connection."""
+ self.aws_client = get_authenticated_aws_client()
+ self.setup_database()
+
+ def get_items_to_process(self) -> ItemsInfo:
+ """List material folder prefixes from S3, filtered by valid ID prefixes.
+
+ Only folders starting with ``VALID_PREFIXES`` (``oqmd-``, ``mp-``,
+ ``agm``) are included.
+
+ Returns:
+ ``ItemsInfo`` containing material folder names to process.
+
+ Raises:
+ ValueError: If ``lematrho_bucket_name`` is not set in config.
+ """
+ bucket = self.config.lematrho_bucket_name
+ if not bucket:
+ raise ValueError("lematrho_bucket_name must be set in config")
+
+ client = self.aws_client
+ paginator = client.get_paginator("list_objects_v2")
+
+ material_folders = []
+ for page in paginator.paginate(Bucket=bucket, Delimiter="/"):
+ for prefix_info in page.get("CommonPrefixes", []):
+ prefix = prefix_info["Prefix"]
+ # Extract folder name (remove trailing slash)
+ folder_name = prefix.rstrip("/")
+ if folder_name.startswith(VALID_PREFIXES):
+ material_folders.append(folder_name)
+ else:
+ logger.debug(f"Skipping folder with unknown prefix: {folder_name}")
+
+ logger.info(
+ f"Found {len(material_folders)} material folders to process in bucket {bucket}"
+ )
+
+ return ItemsInfo(
+ start_offset=0,
+ total_count=len(material_folders),
+ items=material_folders,
+ )
+
+ @staticmethod
+ def _process_batch(
+ batch: str, config: FetcherConfig, manager_dict: dict, worker_id: int = 0
+ ) -> bool:
+ """Process a single material folder from S3.
+
+ Downloads vasprun.xml.gz (for structure), then CHGCAR/AECCAR files,
+ compresses charge densities, and inserts into PostgreSQL.
+
+ Args:
+ batch: Material folder name, e.g. ``"agm000001"``.
+ config: Fetcher configuration.
+ manager_dict: Shared dict for inter-process communication.
+ worker_id: Worker process identifier.
+
+ Returns:
+ ``True`` if successful, ``False`` if failed.
+ """
+ material_id = batch
+ bucket = config.lematrho_bucket_name
+ grid_shape = config.lematrho_grid_shape or (15, 15, 15)
+
+ try:
+ # Fresh clients per worker (multiprocessing safety)
+ aws_client = get_authenticated_aws_client()
+ db = StructuresDatabase(config.db_conn_str, config.table_name)
+
+ # Step 1: Download and parse vasprun.xml.gz for structure
+ vasprun_key = f"{material_id}/{RELAX_CALC_TYPE}/vasprun.xml.gz"
+ try:
+ vasprun_bytes = download_gz_file_from_s3(
+ aws_client, bucket, vasprun_key
+ )
+ structure = parse_vasprun_structure(vasprun_bytes)
+ del vasprun_bytes
+ except Exception as e:
+ logger.warning(
+ f"Failed to parse vasprun.xml.gz for {material_id}: {e}"
+ )
+ return False
+
+ # Step 2: Download and compress charge density files sequentially
+ compressed_grids: dict[str, Optional[list]] = {
+ "charge_density": None,
+ "aeccar0": None,
+ "aeccar1": None,
+ "aeccar2": None,
+ }
+
+ for filename in STATIC_FILES:
+ s3_key = f"{material_id}/{STATIC_CALC_TYPE}/{filename}"
+ grid_name = GRID_KEY_MAP[filename]
+ try:
+ raw_bytes = download_gz_file_from_s3(aws_client, bucket, s3_key)
+ compressed = compress_chgcar(raw_bytes, grid_shape)
+ compressed_grids[grid_name] = compressed
+ del raw_bytes, compressed
+ except Exception as e:
+ logger.warning(
+ f"Failed to process {filename} for {material_id}: {e}"
+ )
+
+ # Step 3: Build and insert raw structure
+ raw_structure = build_raw_structure(
+ material_id=material_id,
+ structure=structure,
+ compressed_grids=compressed_grids,
+ grid_shape=grid_shape,
+ s3_prefix=material_id,
+ )
+ db.insert_data(raw_structure)
+
+ logger.debug(f"Successfully processed {material_id}")
+ return True
+
+ except Exception as e:
+ logger.error(f"Failed to process material {material_id}: {e}")
+ if BaseFetcher.is_critical_error(e):
+ if manager_dict is not None:
+ manager_dict["occurred"] = True
+ return False
+
+ def cleanup_resources(self) -> None:
+ """Clean up AWS client, database connection, and process manager."""
+ if self.aws_client:
+ self.aws_client = None
+
+ if hasattr(self, "manager"):
+ self.manager.shutdown()
+
+ super().cleanup_resources()
+
+ def get_new_version(self) -> str:
+ """Get version identifier for this fetch run.
+
+ Returns:
+ Current date as ``YYYY-MM-DD`` string.
+ """
+ return datetime.now().strftime("%Y-%m-%d")
+
+
+def fetch():
+ """Fetch charge density data from the LeMatRho S3 bucket."""
+ fetcher = LeMatRhoFetcher()
+ fetcher.fetch()
+
+
+if __name__ == "__main__":
+ fetch()
diff --git a/src/lematerial_fetcher/fetcher/lematrho/pipeline.py b/src/lematerial_fetcher/fetcher/lematrho/pipeline.py
new file mode 100644
index 0000000..3881f48
--- /dev/null
+++ b/src/lematerial_fetcher/fetcher/lematrho/pipeline.py
@@ -0,0 +1,863 @@
+# Copyright 2025 Entalpic
+"""Direct S3-to-Parquet pipeline for LeMatRho charge density data.
+
+Bypasses PostgreSQL entirely: downloads from S3, compresses charge densities,
+runs Bader/DDEC6 analysis, and writes Parquet files directly.
+"""
+
+import concurrent.futures
+import gc
+import json
+import os
+import shutil
+import subprocess
+import tempfile
+from datetime import datetime
+from glob import glob
+from typing import Optional
+
+import pyarrow as pa
+import pyarrow.parquet as pq
+from pymatgen.core import Structure
+
+from lematerial_fetcher.fetcher.lematrho.transform import (
+ get_cross_compatibility,
+ parse_acf_dat,
+ parse_ddec6_charges,
+ read_potcar_zval,
+)
+from lematerial_fetcher.fetcher.lematrho.utils import (
+ BADER_TIMEOUT,
+ CHARGEMOL_TIMEOUT,
+ CHGSUM_TIMEOUT,
+ GRID_KEY_MAP,
+ RELAX_CALC_TYPE,
+ STATIC_CALC_TYPE,
+ STATIC_FILES,
+ VALID_PREFIXES,
+ compress_chgcar,
+ download_gz_file_from_s3,
+ parse_vasprun_structure,
+ write_potcar,
+)
+from lematerial_fetcher.models.optimade import Functional, OptimadeStructure
+from lematerial_fetcher.utils.aws import get_authenticated_aws_client
+from lematerial_fetcher.utils.config import DirectPipelineConfig
+from lematerial_fetcher.utils.logging import logger
+from lematerial_fetcher.utils.structure import get_optimade_from_pymatgen
+
+# Columns in the output Parquet files (matches HuggingFace Features schema)
+PARQUET_COLUMNS = [
+ "elements",
+ "nsites",
+ "chemical_formula_anonymous",
+ "chemical_formula_reduced",
+ "chemical_formula_descriptive",
+ "nelements",
+ "dimension_types",
+ "nperiodic_dimensions",
+ "lattice_vectors",
+ "immutable_id",
+ "cartesian_site_positions",
+ "species",
+ "species_at_sites",
+ "last_modified",
+ "elements_ratios",
+ "stress_tensor",
+ "energy",
+ "energy_corrected",
+ "magnetic_moments",
+ "forces",
+ "total_magnetization",
+ "charges",
+ "dos_ef",
+ "functional",
+ "cross_compatibility",
+ "bawl_fingerprint",
+ "space_group_it_number",
+ "compressed_charge_density",
+ "compressed_aeccar0",
+ "compressed_aeccar1",
+ "compressed_aeccar2",
+ "charge_density_grid_shape",
+ "bader_charges",
+ "bader_atomic_volume",
+ "ddec6_charges",
+]
+
+# PyArrow schema matching the HuggingFace Features
+PARQUET_SCHEMA = pa.schema(
+ [
+ ("elements", pa.list_(pa.string())),
+ ("nsites", pa.int32()),
+ ("chemical_formula_anonymous", pa.string()),
+ ("chemical_formula_reduced", pa.string()),
+ ("chemical_formula_descriptive", pa.string()),
+ ("nelements", pa.int8()),
+ ("dimension_types", pa.list_(pa.int8())),
+ ("nperiodic_dimensions", pa.int8()),
+ ("lattice_vectors", pa.list_(pa.list_(pa.float64()))),
+ ("immutable_id", pa.string()),
+ ("cartesian_site_positions", pa.list_(pa.list_(pa.float64()))),
+ ("species", pa.string()), # JSON-serialized
+ ("species_at_sites", pa.list_(pa.string())),
+ ("last_modified", pa.string()),
+ ("elements_ratios", pa.list_(pa.float64())),
+ ("stress_tensor", pa.list_(pa.list_(pa.float64()))),
+ ("energy", pa.float64()),
+ ("energy_corrected", pa.float64()),
+ ("magnetic_moments", pa.list_(pa.float64())),
+ ("forces", pa.list_(pa.list_(pa.float64()))),
+ ("total_magnetization", pa.float64()),
+ ("charges", pa.list_(pa.float64())),
+ ("dos_ef", pa.float64()),
+ ("functional", pa.string()),
+ ("cross_compatibility", pa.bool_()),
+ ("bawl_fingerprint", pa.string()),
+ ("space_group_it_number", pa.int32()),
+ ("compressed_charge_density", pa.string()), # JSON-serialized
+ ("compressed_aeccar0", pa.string()), # JSON-serialized
+ ("compressed_aeccar1", pa.string()), # JSON-serialized
+ ("compressed_aeccar2", pa.string()), # JSON-serialized
+ ("charge_density_grid_shape", pa.list_(pa.int32())),
+ ("bader_charges", pa.list_(pa.float64())),
+ ("bader_atomic_volume", pa.list_(pa.float64())),
+ ("ddec6_charges", pa.list_(pa.float64())),
+ ]
+)
+
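+# The "species" and compressed_* columns are stored as JSON strings (see
+# _structure_to_row). A consumer-side sketch for decoding them; column names
+# as defined above, nothing assumed beyond pyarrow and json:
+#
+#     table = pq.read_table("chunk_000000.parquet")
+#     species = [json.loads(s) for s in table.column("species").to_pylist()]
+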
+# Files needed for Bader analysis (must keep raw bytes)
+_BADER_FILES = {"CHGCAR.gz", "AECCAR0.gz", "AECCAR2.gz"}
+# Files needed for DDEC6 analysis
+_DDEC6_FILES = {"CHGCAR.gz"}
+
+
+def _run_bader_from_bytes(
+ structure: Structure,
+ raw_files: dict[str, bytes],
+ tool_paths: dict,
+ material_id: str,
+) -> tuple[Optional[list[float]], Optional[list[float]]]:
+ """Run Bader charge analysis from raw decompressed file bytes.
+
+ Writes raw VASP files to a temp directory, generates POTCAR, runs
+ ``chgsum.pl`` and ``bader``, and parses the resulting ACF.dat.
+
+ Args:
+ structure: Pymatgen Structure for POTCAR generation.
+ raw_files: Mapping of VASP filenames to their raw bytes,
+ e.g. ``{"CHGCAR": b"...", "AECCAR0": b"...", "AECCAR2": b"..."}``.
+ tool_paths: Tool path configuration dict from ``_validate_tools()``.
+ material_id: Material identifier, used for logging.
+
+ Returns:
+ Tuple of ``(net_charges, atomic_volumes)`` or ``(None, None)`` on failure.
+ """
+ try:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Write raw decompressed files (bader expects plain-text VASP format)
+ for name in ["CHGCAR", "AECCAR0", "AECCAR2"]:
+ with open(os.path.join(tmpdir, name), "wb") as f:
+ f.write(raw_files[name])
+
+ # Generate POTCAR
+ write_potcar(structure, tmpdir)
+
+ # Sum AECCAR0 + AECCAR2 -> CHGCAR_sum
+ subprocess.run(
+ [
+ tool_paths["perl_path"],
+ tool_paths["chgsum_script_path"],
+ "AECCAR0",
+ "AECCAR2",
+ ],
+ cwd=tmpdir,
+ timeout=CHGSUM_TIMEOUT,
+ check=True,
+ capture_output=True,
+ )
+
+ # Run Bader
+ subprocess.run(
+ [tool_paths["bader_path"], "CHGCAR", "-ref", "CHGCAR_sum"],
+ cwd=tmpdir,
+ timeout=BADER_TIMEOUT,
+ check=True,
+ capture_output=True,
+ )
+
+ # Parse results
+ electron_counts, atomic_volumes = parse_acf_dat(
+ os.path.join(tmpdir, "ACF.dat")
+ )
+ zval = read_potcar_zval(os.path.join(tmpdir, "POTCAR"))
+
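+            # Net Bader charge = POTCAR valence count (ZVAL) minus electrons
+            # bader assigned to the site, e.g. Si with ZVAL=4 and 4.2 assigned
+            # electrons carries a net charge of -0.2.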
+ net_charges = []
+ for site, ec in zip(structure.sites, electron_counts):
+ element = str(site.specie)
+ valence = zval.get(element, 0)
+ net_charges.append(valence - ec)
+
+ return net_charges, atomic_volumes
+
+ except subprocess.TimeoutExpired:
+ logger.warning(f"Bader analysis timed out for {material_id}")
+ return None, None
+ except subprocess.CalledProcessError as e:
+ logger.warning(
+ f"Bader subprocess failed for {material_id}: "
+ f"exit code {e.returncode}, stderr: {e.stderr}"
+ )
+ return None, None
+ except Exception as e:
+ logger.warning(f"Bader analysis failed for {material_id}: {e}")
+ return None, None
+
+
+def _run_ddec6_from_bytes(
+ structure: Structure,
+ raw_files: dict[str, bytes],
+ tool_paths: dict,
+ material_id: str,
+) -> Optional[list[float]]:
+ """Run DDEC6 charge analysis from raw decompressed file bytes.
+
+ Writes CHGCAR and POTCAR to a temp directory, runs chargemol, and
+ parses the DDEC6 net atomic charges.
+
+ Args:
+ structure: Pymatgen Structure for POTCAR generation.
+ raw_files: Mapping with at least ``{"CHGCAR": b"..."}``.
+ tool_paths: Tool path configuration dict from ``_validate_tools()``.
+ material_id: Material identifier, used for logging.
+
+ Returns:
+ DDEC6 net charges per site, or ``None`` on failure.
+ """
+ try:
+ with tempfile.TemporaryDirectory() as tmpdir:
+ with open(os.path.join(tmpdir, "CHGCAR"), "wb") as f:
+ f.write(raw_files["CHGCAR"])
+
+ write_potcar(structure, tmpdir)
+
+            # Write the chargemol job_control.txt; each value sits inside the
+            # tag pair that chargemol scans for (tag names follow the standard
+            # Chargemol job_control format).
+            config_content = (
+                "<net charge>\n"
+                "0.0\n"
+                "</net charge>\n"
+                "\n"
+                "<periodicity along A, B, and C vectors>\n"
+                ".true.\n"
+                ".true.\n"
+                ".true.\n"
+                "</periodicity along A, B, and C vectors>\n"
+                "\n"
+                "<atomic densities directory complete path>\n"
+                f"{tool_paths['atomic_densities_path']}\n"
+                "</atomic densities directory complete path>\n"
+                "\n"
+                "<charge type>\n"
+                "DDEC6\n"
+                "</charge type>\n"
+                "\n"
+                "<input filename>\n"
+                "CHGCAR\n"
+                "</input filename>\n"
+            )
+ with open(os.path.join(tmpdir, "job_control.txt"), "w") as f:
+ f.write(config_content)
+
+ # Run chargemol
+ env = os.environ.copy()
+ env["DDEC6_ATOMIC_DENSITIES_DIR"] = tool_paths["atomic_densities_path"]
+ subprocess.run(
+ [tool_paths["chargemol_path"]],
+ cwd=tmpdir,
+ timeout=CHARGEMOL_TIMEOUT,
+ check=True,
+ capture_output=True,
+ env=env,
+ )
+
+ return parse_ddec6_charges(tmpdir)
+
+ except subprocess.TimeoutExpired:
+ logger.warning(f"DDEC6 analysis timed out for {material_id}")
+ return None
+ except subprocess.CalledProcessError as e:
+ logger.warning(
+ f"DDEC6 subprocess failed for {material_id}: "
+ f"exit code {e.returncode}, stderr: {e.stderr}"
+ )
+ return None
+ except Exception as e:
+ logger.warning(f"DDEC6 analysis failed for {material_id}: {e}")
+ return None
+
+
+def _structure_to_row(
+ optimade_structure: OptimadeStructure,
+) -> dict:
+ """Convert an OptimadeStructure to a flat dict matching the Parquet schema.
+
+ JSON-serializes species and compressed charge density fields.
+ Converts ``Functional`` enums to their string value and ``datetime``
+ to ISO format.
+
+ Args:
+ optimade_structure: Validated ``OptimadeStructure`` instance.
+
+ Returns:
+ Flat dict with one key per ``PARQUET_COLUMNS`` entry, ready for
+ ``pyarrow.Table.from_pydict()``.
+ """
+ row = {}
+ for col in PARQUET_COLUMNS:
+ row[col] = getattr(optimade_structure, col, None)
+
+ # Convert datetime to ISO string
+ if row["last_modified"] is not None:
+ row["last_modified"] = row["last_modified"].isoformat()
+
+ # Convert Functional enum to string
+ if row["functional"] is not None:
+ row["functional"] = row["functional"].value
+
+ # JSON-serialize complex fields
+ if row["species"] is not None:
+ row["species"] = json.dumps(row["species"])
+
+ for col in [
+ "compressed_charge_density",
+ "compressed_aeccar0",
+ "compressed_aeccar1",
+ "compressed_aeccar2",
+ ]:
+ if row[col] is not None:
+ row[col] = json.dumps(row[col])
+
+ return row
+
+
+class LeMatRhoDirectPipeline:
+ """Direct S3-to-Parquet pipeline for LeMatRho charge density data.
+
+ Downloads from S3, compresses charge densities via pyrho, optionally runs
+ Bader and DDEC6 charge analysis, and writes Parquet files directly.
+ No PostgreSQL required.
+
+ Args:
+ config: Pipeline configuration.
+ debug: If ``True``, process sequentially in the main process.
+ """
+
+ def __init__(self, config: DirectPipelineConfig, debug: bool = False):
+ self.config = config
+ self.debug = debug
+ self._checkpoint_path = os.path.join(config.output_dir, ".checkpoint.txt")
+ self._failures_path = os.path.join(config.output_dir, ".failures.txt")
+ self._processed_ids: set[str] = set()
+ self._failed_ids: set[str] = set()
+
+ # Validate external tools
+ self._tool_paths = self._validate_tools()
+
+ # Create output directory
+ os.makedirs(config.output_dir, exist_ok=True)
+
+ def _validate_tools(self) -> dict:
+ """Check availability of external tools for Bader/DDEC6 analysis.
+
+ Returns:
+ Dict with keys ``bader_path``, ``chargemol_path``,
+ ``chgsum_script_path``, ``perl_path``, ``atomic_densities_path``,
+ ``can_generate_potcar``, ``can_run_bader``, ``can_run_ddec6``.
+ """
+ tools = {}
+
+ tools["bader_path"] = self.config.bader_path or shutil.which("bader")
+ if not tools["bader_path"]:
+ logger.warning(
+ "bader executable not found. Bader charges will not be computed."
+ )
+
+ tools["chargemol_path"] = (
+ self.config.chargemol_path
+ or shutil.which("Chargemol_09_26_2017_linux_serial")
+ or shutil.which("chargemol")
+ )
+ if not tools["chargemol_path"]:
+ logger.warning(
+ "chargemol executable not found. DDEC6 charges will not be computed."
+ )
+
+ tools["chgsum_script_path"] = self.config.chgsum_script_path
+ if tools["chgsum_script_path"] and not os.path.isfile(
+ tools["chgsum_script_path"]
+ ):
+ logger.warning(
+ f"chgsum.pl not found at {tools['chgsum_script_path']}. "
+ "Bader analysis requires this script."
+ )
+ tools["chgsum_script_path"] = None
+
+ tools["perl_path"] = shutil.which("perl")
+ if not tools["perl_path"]:
+ logger.warning(
+ "perl not found. Bader analysis requires perl for chgsum.pl."
+ )
+
+ tools["atomic_densities_path"] = self.config.atomic_densities_path
+ if tools["atomic_densities_path"] and not os.path.isdir(
+ tools["atomic_densities_path"]
+ ):
+ logger.warning(
+ f"Atomic densities directory not found: {tools['atomic_densities_path']}. "
+ "DDEC6 analysis requires this directory."
+ )
+ tools["atomic_densities_path"] = None
+
+ tools["can_generate_potcar"] = bool(os.environ.get("PMG_VASP_PSP_DIR"))
+ if not tools["can_generate_potcar"]:
+ logger.warning(
+ "PMG_VASP_PSP_DIR not set. POTCAR generation will fail. "
+ "Bader and DDEC6 analysis require POTCAR."
+ )
+
+ tools["can_run_bader"] = bool(
+ tools["bader_path"]
+ and tools["chgsum_script_path"]
+ and tools["perl_path"]
+ and tools["can_generate_potcar"]
+ )
+ tools["can_run_ddec6"] = bool(
+ tools["chargemol_path"]
+ and tools["atomic_densities_path"]
+ and tools["can_generate_potcar"]
+ )
+
+ return tools
+
+ def run(self) -> None:
+ """Run the full pipeline: list materials, process, write Parquet, optionally push.
+
+ In debug mode, materials are processed sequentially in the main process.
+        Otherwise, uses a ``ProcessPoolExecutor``, keeping up to
+        ``2 * num_workers`` tasks in flight and submitting a replacement as
+        each one completes.
+ Writes Parquet chunks of ``config.parquet_chunk_size`` rows using atomic
+ rename. Appends each processed ID to a checkpoint file for crash recovery.
+ """
+ # 1. List material folders from S3
+ logger.info("Listing material folders from S3...")
+ material_ids = self._list_materials()
+ logger.info(f"Found {len(material_ids)} materials in S3")
+
+ # 2. Load checkpoint and failures, filter already-handled
+ self._processed_ids = self._load_checkpoint()
+ self._failed_ids = self._load_failures()
+ remaining = [
+ m
+ for m in material_ids
+ if m not in self._processed_ids and m not in self._failed_ids
+ ]
+
+ # Apply limit if set
+ if self.config.limit is not None and len(remaining) > self.config.limit:
+ remaining = remaining[: self.config.limit]
+
+ logger.info(
+ f"Already processed: {len(self._processed_ids)}, "
+ f"previously failed: {len(self._failed_ids)}, "
+ f"remaining: {len(remaining)}"
+ )
+
+ if not remaining:
+ logger.info("All materials already processed.")
+ if self.config.hf_repo_id:
+ self._push_to_huggingface()
+ return
+
+ # 3. Process materials
+ buffer = []
+ buffer_ids = []
+ chunk_index = self._get_next_chunk_index()
+ processed_count = 0
+ failed_count = 0
+
+ if self.debug:
+ for material_id in remaining:
+ result = self._process_material(
+ material_id, self.config, self._tool_paths
+ )
+ if result is not None:
+ buffer.append(result)
+ buffer_ids.append(material_id)
+ processed_count += 1
+ else:
+ self._append_failure(material_id)
+ failed_count += 1
+
+ if len(buffer) >= self.config.parquet_chunk_size:
+ self._write_parquet_chunk(buffer, chunk_index)
+ self._batch_checkpoint(buffer_ids)
+ buffer.clear()
+ buffer_ids.clear()
+ chunk_index += 1
+
+ total = processed_count + failed_count
+ if total % self.config.log_every == 0 and total > 0:
+ logger.info(
+ f"Progress: {processed_count} processed, {failed_count} failed"
+ )
+ else:
+ with concurrent.futures.ProcessPoolExecutor(
+ max_workers=self.config.num_workers
+ ) as executor:
+ remaining_iter = iter(remaining)
+ futures = {}
+
+ # Submit initial batch (2x workers for pipeline saturation)
+ initial_count = min(self.config.num_workers * 2, len(remaining))
+ for _ in range(initial_count):
+ mid = next(remaining_iter)
+ future = executor.submit(
+ self._process_material, mid, self.config, self._tool_paths
+ )
+ futures[future] = mid
+
+ while futures:
+ done, _ = concurrent.futures.wait(
+ futures,
+ return_when=concurrent.futures.FIRST_COMPLETED,
+ )
+ for future in done:
+ material_id = futures.pop(future)
+
+ try:
+ result = future.result()
+ if result is not None:
+ buffer.append(result)
+ buffer_ids.append(material_id)
+ processed_count += 1
+ else:
+ self._append_failure(material_id)
+ failed_count += 1
+ except Exception as e:
+ logger.warning(
+ f"Worker exception for {material_id}: {e}"
+ )
+ self._append_failure(material_id)
+ failed_count += 1
+
+ # Write chunk if buffer is full
+ if len(buffer) >= self.config.parquet_chunk_size:
+ self._write_parquet_chunk(buffer, chunk_index)
+ self._batch_checkpoint(buffer_ids)
+ buffer.clear()
+ buffer_ids.clear()
+ chunk_index += 1
+
+                        # Submit a replacement task to keep the pool saturated
+ try:
+ next_id = next(remaining_iter)
+ f = executor.submit(
+ self._process_material,
+ next_id,
+ self.config,
+ self._tool_paths,
+ )
+ futures[f] = next_id
+ except StopIteration:
+ pass
+
+ total = processed_count + failed_count
+ if total % self.config.log_every == 0 and total > 0:
+ logger.info(
+ f"Progress: {processed_count} processed, "
+ f"{failed_count} failed"
+ )
+
+ # Write remaining buffer
+ if buffer:
+ self._write_parquet_chunk(buffer, chunk_index)
+ self._batch_checkpoint(buffer_ids)
+
+ logger.info(
+ f"Done. {processed_count} processed, {failed_count} failed."
+ )
+
+ # 4. Push if configured
+ if self.config.hf_repo_id:
+ self._push_to_huggingface()
+
+ def _list_materials(self) -> list[str]:
+ """List material folder prefixes from S3, filtered by ``VALID_PREFIXES``.
+
+ Returns:
+ List of material IDs (e.g. ``["agm000001", "mp-123", ...]``).
+ Order depends on S3 listing order (typically lexicographic).
+ """
+ client = get_authenticated_aws_client()
+ bucket = self.config.lematrho_bucket_name
+ paginator = client.get_paginator("list_objects_v2")
+
+ material_folders = []
+ for page in paginator.paginate(Bucket=bucket, Delimiter="/"):
+ for prefix_info in page.get("CommonPrefixes", []):
+ folder_name = prefix_info["Prefix"].rstrip("/")
+ if folder_name.startswith(VALID_PREFIXES):
+ material_folders.append(folder_name)
+
+ return material_folders
+
+ @staticmethod
+ def _process_material(
+ material_id: str,
+ config: DirectPipelineConfig,
+ tool_paths: dict,
+ ) -> Optional[dict]:
+ """Process a single material: download, compress, analyze, return row dict.
+
+ Designed to run in a worker process. Creates a fresh AWS client per
+ invocation (boto3 clients are not multiprocess-safe). Calls
+ ``gc.collect()`` after each material to free memory from large CHGCAR
+ arrays.
+
+ Args:
+ material_id: Material folder name, e.g. ``"agm000001"``.
+ config: Pipeline configuration.
+ tool_paths: Tool path dict from ``_validate_tools()``.
+
+ Returns:
+ Flat dict matching ``PARQUET_COLUMNS``, or ``None`` on failure.
+ """
+ bucket = config.lematrho_bucket_name
+ grid_shape = config.lematrho_grid_shape
+
+ try:
+ # Fresh client per worker (boto3 clients are NOT multiprocess-safe)
+ aws_client = get_authenticated_aws_client()
+
+ # Step 1: Download and parse vasprun.xml.gz for structure
+ vasprun_key = f"{material_id}/{RELAX_CALC_TYPE}/vasprun.xml.gz"
+ try:
+ vasprun_bytes = download_gz_file_from_s3(
+ aws_client, bucket, vasprun_key
+ )
+ structure = parse_vasprun_structure(vasprun_bytes)
+ del vasprun_bytes
+ except Exception as e:
+ logger.warning(
+ f"Failed to parse vasprun.xml.gz for {material_id}: {e}"
+ )
+ return None
+
+ # Step 2: Download and compress charge density files
+ compressed_grids = {}
+ # Memory trade-off: raw decompressed bytes are kept in memory for
+ # Bader/DDEC6 analysis to avoid a second S3 download. Each CHGCAR
+ # can be 100-500 MB, so peak RSS per worker ≈ sum of needed files.
+ raw_files = {}
+
+ # Determine which raw files to keep
+ need_raw = set()
+ if tool_paths["can_run_bader"]:
+ need_raw |= _BADER_FILES
+ if tool_paths["can_run_ddec6"]:
+ need_raw |= _DDEC6_FILES
+
+ for filename in STATIC_FILES:
+ s3_key = f"{material_id}/{STATIC_CALC_TYPE}/{filename}"
+ grid_name = GRID_KEY_MAP[filename]
+ try:
+ raw_bytes = download_gz_file_from_s3(aws_client, bucket, s3_key)
+
+ # Compress via pyrho
+ compressed = compress_chgcar(raw_bytes, grid_shape)
+ compressed_grids[grid_name] = compressed
+ del compressed
+
+ # Keep raw bytes if needed for analysis
+ if filename in need_raw:
+ vasp_name = filename.replace(".gz", "")
+ raw_files[vasp_name] = raw_bytes
+ del raw_bytes
+ except Exception as e:
+ logger.warning(
+ f"Failed to process {filename} for {material_id}: {e}"
+ )
+
+ # Step 3: Bader analysis (if tools available and files downloaded)
+ bader_charges = None
+ bader_atomic_volume = None
+ if tool_paths["can_run_bader"] and all(
+ k in raw_files for k in ["CHGCAR", "AECCAR0", "AECCAR2"]
+ ):
+ bader_charges, bader_atomic_volume = _run_bader_from_bytes(
+ structure, raw_files, tool_paths, material_id
+ )
+
+ # Step 4: DDEC6 analysis (if tools available and CHGCAR downloaded)
+ ddec6_charges = None
+ if tool_paths["can_run_ddec6"] and "CHGCAR" in raw_files:
+ ddec6_charges = _run_ddec6_from_bytes(
+ structure, raw_files, tool_paths, material_id
+ )
+
+ # Free raw file bytes
+ del raw_files
+
+ # Step 5: Build OptimadeStructure (Pydantic validation)
+ optimade_dict = get_optimade_from_pymatgen(structure)
+ cross_compatibility = get_cross_compatibility(optimade_dict["elements"])
+
+ optimade_structure = OptimadeStructure(
+ id=material_id,
+ source="lematrho",
+ immutable_id=material_id,
+ last_modified=datetime.now(),
+ **optimade_dict,
+ functional=Functional.PBE,
+ cross_compatibility=cross_compatibility,
+ compressed_charge_density=compressed_grids.get("charge_density"),
+ compressed_aeccar0=compressed_grids.get("aeccar0"),
+ compressed_aeccar1=compressed_grids.get("aeccar1"),
+ compressed_aeccar2=compressed_grids.get("aeccar2"),
+ charge_density_grid_shape=list(grid_shape),
+ bader_charges=bader_charges,
+ bader_atomic_volume=bader_atomic_volume,
+ ddec6_charges=ddec6_charges,
+ compute_space_group=True,
+ compute_bawl_hash=True,
+ )
+
+ # Step 6: Convert to flat dict for Parquet
+ row = _structure_to_row(optimade_structure)
+ del optimade_structure, compressed_grids
+
+ # Step 7: Force garbage collection in worker
+ gc.collect()
+
+ return row
+
+ except Exception as e:
+ logger.error(f"Failed to process material {material_id}: {e}")
+ gc.collect()
+ return None
+
+ def _load_checkpoint(self) -> set[str]:
+ """Load processed material IDs from checkpoint file.
+
+ Returns:
+ Set of already-processed material IDs. Empty set if no checkpoint exists.
+ """
+ if not os.path.exists(self._checkpoint_path):
+ return set()
+ with open(self._checkpoint_path, "r") as f:
+ return {line.strip() for line in f if line.strip()}
+
+ def _append_checkpoint(self, material_id: str) -> None:
+ """Append a material ID to the checkpoint file and flush to disk.
+
+ Args:
+ material_id: ID to record as processed.
+ """
+ with open(self._checkpoint_path, "a") as f:
+ f.write(material_id + "\n")
+ f.flush()
+ os.fsync(f.fileno())
+
+ def _batch_checkpoint(self, material_ids: list[str]) -> None:
+ """Append multiple material IDs to the checkpoint file atomically.
+
+ Called after a Parquet chunk is successfully flushed so that
+ checkpoint and Parquet stay in sync.
+
+ Args:
+ material_ids: IDs to record as processed.
+ """
+ with open(self._checkpoint_path, "a") as f:
+ for mid in material_ids:
+ f.write(mid + "\n")
+ f.flush()
+ os.fsync(f.fileno())
+
+ def _load_failures(self) -> set[str]:
+ """Load failed material IDs from failures file.
+
+ Returns:
+ Set of material IDs that previously failed. Empty set if no file.
+ """
+ if not os.path.exists(self._failures_path):
+ return set()
+ with open(self._failures_path, "r") as f:
+ return {line.strip() for line in f if line.strip()}
+
+ def _append_failure(self, material_id: str) -> None:
+ """Record a failed material ID so it is skipped on resume.
+
+ Args:
+ material_id: ID that failed processing.
+ """
+ with open(self._failures_path, "a") as f:
+ f.write(material_id + "\n")
+ f.flush()
+ os.fsync(f.fileno())
+
+ def _get_next_chunk_index(self) -> int:
+ """Determine next chunk index from existing ``chunk_*.parquet`` files.
+
+ Ignores ``.tmp`` files left by interrupted writes.
+
+ Returns:
+ Next available chunk index (0 if no existing chunks).
+ """
+ existing = glob(os.path.join(self.config.output_dir, "chunk_*.parquet"))
+ if not existing:
+ return 0
+ indices = []
+ for path in existing:
+ basename = os.path.basename(path)
+ try:
+ idx = int(basename.replace("chunk_", "").replace(".parquet", ""))
+ indices.append(idx)
+ except ValueError:
+ pass
+ return max(indices) + 1 if indices else 0
+
+ def _write_parquet_chunk(self, rows: list[dict], chunk_index: int) -> None:
+ """Write rows to a Parquet file atomically (write ``.tmp``, then rename).
+
+ Args:
+ rows: List of flat dicts matching ``PARQUET_COLUMNS``.
+ chunk_index: Chunk sequence number (zero-padded in filename).
+ """
+ final_path = os.path.join(
+ self.config.output_dir, f"chunk_{chunk_index:06d}.parquet"
+ )
+ tmp_path = final_path + ".tmp"
+
+ # Build column-oriented data from row-oriented dicts
+ columns = {col: [row.get(col) for row in rows] for col in PARQUET_COLUMNS}
+ table = pa.table(columns, schema=PARQUET_SCHEMA)
+ pq.write_table(table, tmp_path)
+
+        # Rename is atomic on POSIX when source and destination are on the
+        # same filesystem, so readers never observe a partially written chunk
+ os.rename(tmp_path, final_path)
+ logger.info(
+ f"Wrote chunk {chunk_index} ({len(rows)} rows) to {final_path}"
+ )
+
+ def _push_to_huggingface(self) -> None:
+ """Load all Parquet files and push to HuggingFace as a private dataset."""
+ from datasets import load_dataset
+
+ parquet_files = os.path.join(self.config.output_dir, "chunk_*.parquet")
+ logger.info(f"Loading Parquet files from {parquet_files}")
+
+ dataset = load_dataset("parquet", data_files=parquet_files)
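+        # With data_files and no split argument, load_dataset returns a
+        # DatasetDict with a single "train" split, hence dataset["train"] below.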
+
+ logger.info(f"Pushing to HuggingFace repo: {self.config.hf_repo_id}")
+ dataset["train"].push_to_hub(
+ self.config.hf_repo_id,
+ token=self.config.hf_token,
+ private=True,
+ )
+ logger.info("Push complete.")
diff --git a/src/lematerial_fetcher/fetcher/lematrho/transform.py b/src/lematerial_fetcher/fetcher/lematrho/transform.py
new file mode 100644
index 0000000..fe6081e
--- /dev/null
+++ b/src/lematerial_fetcher/fetcher/lematrho/transform.py
@@ -0,0 +1,493 @@
+# Copyright 2025 Entalpic
+import os
+import shutil
+import subprocess
+import tempfile
+from datetime import datetime
+from typing import Optional
+
+from pymatgen.core import Structure
+
+from lematerial_fetcher.database.postgres import OptimadeDatabase, StructuresDatabase
+from lematerial_fetcher.fetcher.lematrho.utils import (
+ BADER_TIMEOUT,
+ CHARGEMOL_TIMEOUT,
+ CHGSUM_TIMEOUT,
+ STATIC_CALC_TYPE,
+ download_gz_file_from_s3,
+ write_potcar,
+)
+from lematerial_fetcher.models.models import RawStructure
+from lematerial_fetcher.models.optimade import Functional, OptimadeStructure
+from lematerial_fetcher.transform import BaseTransformer
+from lematerial_fetcher.utils.aws import get_authenticated_aws_client
+from lematerial_fetcher.utils.logging import logger
+from lematerial_fetcher.utils.structure import get_optimade_from_pymatgen
+
+
+def get_cross_compatibility(elements: list[str]) -> bool:
+ """Determine cross-compatibility for LeMatRho structures.
+
+ Yb-containing structures are excluded (same policy as Alexandria).
+
+ Args:
+ elements: List of element symbols in the structure.
+
+ Returns:
+ ``True`` if the structure is cross-compatible, ``False`` otherwise.
+ """
+ return "Yb" not in elements
+
+
+def parse_acf_dat(filepath: str) -> tuple[list[float], list[float]]:
+ """Parse Bader ACF.dat file for electron counts and atomic volumes.
+
+ Args:
+ filepath: Path to the ACF.dat file produced by bader.
+
+ Returns:
+ Tuple of ``(electron_counts, atomic_volumes)`` lists, one entry per atom.
+ """
+ electron_counts = []
+ atomic_volumes = []
+ with open(filepath) as f:
+ lines = f.readlines()
+
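+    # Expected layout (standard ACF.dat as written by the Henkelman group
+    # bader code):
+    #
+    #     #    X       Y       Z      CHARGE    MIN DIST    ATOMIC VOL
+    #    ------------------------------------------------------------
+    #     1   0.0000  0.0000  0.0000  8.9999    0.3645      20.1234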
+ # Skip header (first 2 lines), parse data rows until separator line
+ for line in lines[2:]:
+ stripped = line.strip()
+ if stripped.startswith("-") or not stripped:
+ break
+ parts = stripped.split()
+ if len(parts) >= 7:
+ electron_counts.append(float(parts[4])) # CHARGE column
+ atomic_volumes.append(float(parts[6])) # ATOMIC VOL column
+
+ return electron_counts, atomic_volumes
+
+
+def read_potcar_zval(filepath: str) -> dict[str, float]:
+ """Read valence electron counts from a POTCAR file.
+
+ Parses TITEL and ZVAL lines to build an element-to-valence-electrons mapping.
+
+ Args:
+ filepath: Path to the POTCAR file.
+
+ Returns:
+ Dict mapping element symbols to their number of valence electrons.
+ """
+ zval = {}
+ current_element = None
+ with open(filepath) as f:
+ for line in f:
+ if "TITEL" in line:
+ parts = line.split()
+ if len(parts) >= 4:
+ # Handle element names like 'Si_d' -> 'Si'
+ current_element = parts[3].split("_")[0]
+ elif "ZVAL" in line and current_element:
+ parts = line.split("=")
+ if len(parts) >= 2:
+ try:
+ zval[current_element] = float(parts[1].split()[0])
+ current_element = None
+ except (ValueError, IndexError):
+ pass
+ return zval
+
+
+def parse_ddec6_charges(tmpdir: str) -> list[float]:
+ """Parse DDEC6 net atomic charges from chargemol output.
+
+ Reads ``DDEC6_even_tempered_net_atomic_charges.xyz`` from *tmpdir*.
+
+ Args:
+ tmpdir: Directory containing chargemol output files.
+
+ Returns:
+ List of net DDEC6 charges, one per atom.
+ """
+ filepath = os.path.join(tmpdir, "DDEC6_even_tempered_net_atomic_charges.xyz")
+ charges = []
+ with open(filepath) as f:
+ lines = f.readlines()
+
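+    # The file follows xyz conventions: line 1 is the atom count, line 2 a
+    # comment, then one "element x y z charge ..." row per atom, with the net
+    # atomic charge in the fifth column.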
+ n_atoms = int(lines[0].strip())
+ for line in lines[2 : 2 + n_atoms]:
+ parts = line.split()
+ if len(parts) >= 5:
+ charges.append(float(parts[4]))
+
+ return charges
+
+
+class LeMatRhoTransformer(BaseTransformer):
+ """Transformer for LeMatRho charge density data.
+
+ Transforms raw structures (with compressed charge densities from the fetch step)
+ into ``OptimadeStructure`` objects. Optionally runs Bader and DDEC6 charge
+ analysis using external tools.
+
+ External tool requirements:
+ - ``bader``: Bader charge analysis executable
+ - ``perl`` + ``chgsum.pl``: For summing AECCAR0 + AECCAR2
+ - ``chargemol``: DDEC6 charge partitioning executable
+ - ``PMG_VASP_PSP_DIR``: Env var for POTCAR generation
+ - atomic densities directory: For DDEC6/chargemol analysis
+
+ Args:
+ config: Transformer configuration.
+ database_class: Database class for storing results.
+ structure_class: Pydantic model class for validated structures.
+ debug: If ``True``, process sequentially for debugging.
+ """
+
+ def __init__(
+ self,
+ config=None,
+ database_class=OptimadeDatabase,
+ structure_class=OptimadeStructure,
+ debug=False,
+ ):
+ super().__init__(config, database_class, structure_class, debug)
+ self._aws_client = None
+ self._bader_path = None
+ self._chargemol_path = None
+ self._chgsum_script_path = None
+ self._perl_path = None
+ self._atomic_densities_path = None
+ self._can_generate_potcar = False
+ self._validate_tools()
+
+ def _validate_tools(self) -> None:
+ """Check availability of external tools and log warnings for missing ones.
+
+ Sets instance attributes ``_bader_path``, ``_chargemol_path``,
+ ``_chgsum_script_path``, ``_perl_path``, ``_atomic_densities_path``,
+ and ``_can_generate_potcar``.
+ """
+ self._bader_path = getattr(self.config, "bader_path", None) or shutil.which(
+ "bader"
+ )
+ if not self._bader_path:
+ logger.warning(
+ "bader executable not found. Bader charges will not be computed."
+ )
+
+ self._chargemol_path = getattr(
+ self.config, "chargemol_path", None
+ ) or shutil.which("Chargemol_09_26_2017_linux_serial")
+ if not self._chargemol_path:
+ self._chargemol_path = shutil.which("chargemol")
+ if not self._chargemol_path:
+ logger.warning(
+ "chargemol executable not found. DDEC6 charges will not be computed."
+ )
+
+ self._chgsum_script_path = getattr(
+ self.config, "chgsum_script_path", None
+ )
+ if self._chgsum_script_path and not os.path.isfile(self._chgsum_script_path):
+ logger.warning(
+ f"chgsum.pl not found at {self._chgsum_script_path}. "
+ "Bader analysis requires this script."
+ )
+ self._chgsum_script_path = None
+
+ self._perl_path = shutil.which("perl")
+ if not self._perl_path:
+ logger.warning(
+ "perl not found. Bader analysis requires perl for chgsum.pl."
+ )
+
+ self._atomic_densities_path = getattr(
+ self.config, "atomic_densities_path", None
+ )
+ if self._atomic_densities_path and not os.path.isdir(
+ self._atomic_densities_path
+ ):
+ logger.warning(
+ f"Atomic densities directory not found: {self._atomic_densities_path}. "
+ "DDEC6 analysis requires this directory."
+ )
+ self._atomic_densities_path = None
+
+ if not os.environ.get("PMG_VASP_PSP_DIR"):
+ logger.warning(
+ "PMG_VASP_PSP_DIR not set. POTCAR generation will fail. "
+ "Bader and DDEC6 analysis require POTCAR."
+ )
+ self._can_generate_potcar = False
+ else:
+ self._can_generate_potcar = True
+
+ @property
+ def aws_client(self):
+ """Lazy-initialized authenticated S3 client."""
+ if self._aws_client is None:
+ self._aws_client = get_authenticated_aws_client()
+ return self._aws_client
+
+ @property
+ def can_run_bader(self) -> bool:
+ """Check if all Bader analysis prerequisites are met."""
+ return bool(
+ self._bader_path
+ and self._chgsum_script_path
+ and self._perl_path
+ and self._can_generate_potcar
+ )
+
+ @property
+ def can_run_ddec6(self) -> bool:
+ """Check if all DDEC6 analysis prerequisites are met."""
+ return bool(
+ self._chargemol_path
+ and self._atomic_densities_path
+ and self._can_generate_potcar
+ )
+
+ def transform_row(
+ self,
+ raw_structure: RawStructure,
+ source_db: Optional[StructuresDatabase] = None,
+ task_table_name: Optional[str] = None,
+ ) -> list[OptimadeStructure]:
+ """Transform a raw LeMatRho structure into an OptimadeStructure.
+
+ Args:
+ raw_structure: Raw structure from the fetch step, with charge
+ density data in its attributes dict.
+ source_db: Not used for LeMatRho.
+ task_table_name: Not used for LeMatRho.
+
+ Returns:
+ Single-element list containing the transformed ``OptimadeStructure``.
+ """
+ attrs = raw_structure.attributes
+ material_id = raw_structure.id
+
+ # Extract pymatgen Structure from raw data
+ structure = Structure.from_dict(attrs["structure"])
+
+ # Get base OPTIMADE fields from structure
+ optimade_dict = get_optimade_from_pymatgen(structure)
+
+ # Extract pre-computed compressed grids from fetch step
+ compressed_charge_density = attrs.get("compressed_charge_density")
+ compressed_aeccar0 = attrs.get("compressed_aeccar0")
+ compressed_aeccar1 = attrs.get("compressed_aeccar1")
+ compressed_aeccar2 = attrs.get("compressed_aeccar2")
+ grid_shape = attrs.get("grid_shape")
+ s3_prefix = attrs.get("s3_prefix")
+
+ # Cross-compatibility (exclude Yb, same policy as Alexandria)
+ cross_compatibility = get_cross_compatibility(optimade_dict["elements"])
+
+ # Bader analysis (independent from DDEC6)
+ bader_charges = None
+ bader_atomic_volume = None
+ if self.can_run_bader and s3_prefix:
+ bader_charges, bader_atomic_volume = self._run_bader_analysis(
+ structure, s3_prefix, material_id
+ )
+
+ # DDEC6 analysis (independent from Bader)
+ ddec6_charges = None
+ if self.can_run_ddec6 and s3_prefix:
+ ddec6_charges = self._run_ddec6_analysis(
+ structure, s3_prefix, material_id
+ )
+
+ optimade_structure = OptimadeStructure(
+ id=material_id,
+ source="lematrho",
+ immutable_id=material_id,
+ last_modified=raw_structure.last_modified or datetime.now(),
+ **optimade_dict,
+ functional=Functional.PBE,
+ cross_compatibility=cross_compatibility,
+ compressed_charge_density=compressed_charge_density,
+ compressed_aeccar0=compressed_aeccar0,
+ compressed_aeccar1=compressed_aeccar1,
+ compressed_aeccar2=compressed_aeccar2,
+ charge_density_grid_shape=grid_shape,
+ bader_charges=bader_charges,
+ bader_atomic_volume=bader_atomic_volume,
+ ddec6_charges=ddec6_charges,
+ compute_space_group=True,
+ compute_bawl_hash=True,
+ )
+
+ return [optimade_structure]
+
+ def _run_bader_analysis(
+ self, structure: Structure, s3_prefix: str, material_id: str
+ ) -> tuple[Optional[list[float]], Optional[list[float]]]:
+ """Run Bader charge analysis.
+
+ Downloads CHGCAR, AECCAR0, AECCAR2 from S3, runs ``chgsum.pl`` to
+ create the reference charge density, then runs bader and parses results.
+
+ Args:
+ structure: Pymatgen Structure for POTCAR generation.
+ s3_prefix: S3 folder prefix for this material.
+ material_id: Material identifier, used for logging.
+
+ Returns:
+ Tuple of ``(net_charges, atomic_volumes)`` or ``(None, None)``
+ on failure.
+ """
+ try:
+ bucket = self.config.lematrho_bucket_name
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Download raw VASP charge density files
+ for filename in ["CHGCAR.gz", "AECCAR0.gz", "AECCAR2.gz"]:
+ key = f"{s3_prefix}/{STATIC_CALC_TYPE}/{filename}"
+ data = download_gz_file_from_s3(self.aws_client, bucket, key)
+ outname = filename.replace(".gz", "")
+ with open(os.path.join(tmpdir, outname), "wb") as f:
+ f.write(data)
+ del data
+
+ # Generate POTCAR
+ write_potcar(structure, tmpdir)
+
+ # Sum AECCAR0 + AECCAR2 -> CHGCAR_sum
+ subprocess.run(
+ [
+ self._perl_path,
+ self._chgsum_script_path,
+ "AECCAR0",
+ "AECCAR2",
+ ],
+ cwd=tmpdir,
+ timeout=CHGSUM_TIMEOUT,
+ check=True,
+ capture_output=True,
+ )
+
+ # Run Bader analysis with reference charge density
+ subprocess.run(
+ [self._bader_path, "CHGCAR", "-ref", "CHGCAR_sum"],
+ cwd=tmpdir,
+ timeout=BADER_TIMEOUT,
+ check=True,
+ capture_output=True,
+ )
+
+ # Parse ACF.dat for electron counts and atomic volumes
+ electron_counts, atomic_volumes = parse_acf_dat(
+ os.path.join(tmpdir, "ACF.dat")
+ )
+
+ # Compute net charges: valence_electrons - bader_electron_count
+ zval = read_potcar_zval(os.path.join(tmpdir, "POTCAR"))
+ net_charges = []
+ for site, electron_count in zip(structure.sites, electron_counts):
+ element = str(site.specie)
+ valence = zval.get(element, 0)
+ net_charges.append(valence - electron_count)
+
+ return net_charges, atomic_volumes
+
+ except subprocess.TimeoutExpired:
+ logger.warning(f"Bader analysis timed out for {material_id}")
+ return None, None
+ except subprocess.CalledProcessError as e:
+ logger.warning(
+ f"Bader subprocess failed for {material_id}: "
+ f"exit code {e.returncode}, stderr: {e.stderr}"
+ )
+ return None, None
+ except Exception as e:
+ logger.warning(f"Bader analysis failed for {material_id}: {e}")
+ return None, None
+
+ def _run_ddec6_analysis(
+ self, structure: Structure, s3_prefix: str, material_id: str
+ ) -> Optional[list[float]]:
+ """Run DDEC6 charge analysis via chargemol.
+
+ Downloads CHGCAR from S3, generates POTCAR, writes chargemol config,
+ runs chargemol, and parses DDEC6 net charges.
+
+ Args:
+ structure: Pymatgen Structure for POTCAR generation.
+ s3_prefix: S3 folder prefix for this material.
+ material_id: Material identifier, used for logging.
+
+ Returns:
+ DDEC6 net charges per site, or ``None`` on failure.
+ """
+ try:
+ bucket = self.config.lematrho_bucket_name
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Download CHGCAR
+ key = f"{s3_prefix}/{STATIC_CALC_TYPE}/CHGCAR.gz"
+ data = download_gz_file_from_s3(self.aws_client, bucket, key)
+ with open(os.path.join(tmpdir, "CHGCAR"), "wb") as f:
+ f.write(data)
+ del data
+
+ # Generate POTCAR
+ write_potcar(structure, tmpdir)
+
+ # Write chargemol job control file
+ self._write_chargemol_config(tmpdir)
+
+ # Run chargemol
+ env = os.environ.copy()
+ env["DDEC6_ATOMIC_DENSITIES_DIR"] = self._atomic_densities_path
+ subprocess.run(
+ [self._chargemol_path],
+ cwd=tmpdir,
+ timeout=CHARGEMOL_TIMEOUT,
+ check=True,
+ capture_output=True,
+ env=env,
+ )
+
+ return parse_ddec6_charges(tmpdir)
+
+ except subprocess.TimeoutExpired:
+ logger.warning(f"DDEC6 analysis timed out for {material_id}")
+ return None
+ except subprocess.CalledProcessError as e:
+ logger.warning(
+ f"DDEC6 subprocess failed for {material_id}: "
+ f"exit code {e.returncode}, stderr: {e.stderr}"
+ )
+ return None
+ except Exception as e:
+ logger.warning(f"DDEC6 analysis failed for {material_id}: {e}")
+ return None
+
+ def _write_chargemol_config(self, tmpdir: str) -> None:
+ """Write ``job_control.txt`` for chargemol DDEC6 analysis.
+
+ Args:
+ tmpdir: Directory where the config file will be written.
+ """
+        # Each value sits inside the tag pair chargemol scans for (tag names
+        # follow the standard Chargemol job_control format).
+        config_content = (
+            "<net charge>\n"
+            "0.0\n"
+            "</net charge>\n"
+            "\n"
+            "<periodicity along A, B, and C vectors>\n"
+            ".true.\n"
+            ".true.\n"
+            ".true.\n"
+            "</periodicity along A, B, and C vectors>\n"
+            "\n"
+            "<atomic densities directory complete path>\n"
+            f"{self._atomic_densities_path}\n"
+            "</atomic densities directory complete path>\n"
+            "\n"
+            "<charge type>\n"
+            "DDEC6\n"
+            "</charge type>\n"
+            "\n"
+            "<input filename>\n"
+            "CHGCAR\n"
+            "</input filename>\n"
+        )
+ with open(os.path.join(tmpdir, "job_control.txt"), "w") as f:
+ f.write(config_content)
diff --git a/src/lematerial_fetcher/fetcher/lematrho/utils.py b/src/lematerial_fetcher/fetcher/lematrho/utils.py
new file mode 100644
index 0000000..df49b67
--- /dev/null
+++ b/src/lematerial_fetcher/fetcher/lematrho/utils.py
@@ -0,0 +1,170 @@
+# Copyright 2025 Entalpic
+import gzip
+import os
+import tempfile
+from datetime import datetime
+from typing import Any, Optional
+
+from pymatgen.core import Structure
+from pymatgen.io.vasp import Chgcar, Vasprun
+
+from lematerial_fetcher.models.models import RawStructure
+from lematerial_fetcher.utils.logging import logger
+
+# ── S3 folder structure constants ──────────────────────────────────────────────
+STATIC_CALC_TYPE = "LeMatRhoStaticMaker"
+RELAX_CALC_TYPE = "LeMatRhoRelaxMaker_1"
+STATIC_FILES = ["CHGCAR.gz", "AECCAR0.gz", "AECCAR1.gz", "AECCAR2.gz"]
+RELAX_FILES = ["vasprun.xml.gz"]
+
+# Only process materials with these ID prefixes
+VALID_PREFIXES = ("oqmd-", "mp-", "agm")
+
+# Conservative default due to high memory usage per CHGCAR (~hundreds of MB)
+DEFAULT_MAX_WORKERS = 4
+
+# Map from S3 filename to compressed grid key name
+GRID_KEY_MAP = {
+ "CHGCAR.gz": "charge_density",
+ "AECCAR0.gz": "aeccar0",
+ "AECCAR1.gz": "aeccar1",
+ "AECCAR2.gz": "aeccar2",
+}
+
+# Subprocess timeout constants (seconds)
+BADER_TIMEOUT = 600
+CHGSUM_TIMEOUT = 300
+CHARGEMOL_TIMEOUT = 600
+
+
+def download_gz_file_from_s3(client: Any, bucket: str, key: str) -> bytes:
+ """Download and decompress a gzipped file from S3.
+
+ Args:
+ client: Boto3 S3 client.
+ bucket: S3 bucket name.
+ key: S3 object key.
+
+ Returns:
+ Decompressed file contents as raw bytes.
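+
+    Example:
+        A minimal sketch, assuming valid credentials and an existing gzipped
+        object (bucket and key are hypothetical)::
+
+            client = boto3.client("s3")
+            data = download_gz_file_from_s3(
+                client, "lemat-rho", "mp-1/LeMatRhoStaticMaker/CHGCAR.gz"
+            )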
+ """
+ response = client.get_object(Bucket=bucket, Key=key)
+ body = response["Body"]
+ try:
+ compressed = body.read()
+ decompressed = gzip.decompress(compressed)
+ del compressed
+ return decompressed
+ finally:
+ body.close()
+
+
+def parse_vasprun_structure(vasprun_bytes: bytes) -> Structure:
+ """Parse a vasprun.xml to extract the final relaxed structure.
+
+ Writes bytes to a temporary file because pymatgen's ``Vasprun`` requires
+ a filesystem path, not a file-like object.
+
+ Args:
+ vasprun_bytes: Raw vasprun.xml content.
+
+ Returns:
+ The final relaxed pymatgen Structure.
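+
+    Example:
+        A minimal sketch, assuming ``raw`` holds decompressed vasprun.xml
+        bytes::
+
+            structure = parse_vasprun_structure(raw)
+            print(structure.composition)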
+ """
+ with tempfile.TemporaryDirectory() as tmpdir:
+ path = os.path.join(tmpdir, "vasprun.xml")
+ with open(path, "wb") as f:
+ f.write(vasprun_bytes)
+ vasprun = Vasprun(
+ path,
+ parse_dos=False,
+ parse_eigen=False,
+ parse_potcar_file=False,
+ )
+ return vasprun.final_structure
+
+
+def compress_chgcar(chgcar_bytes: bytes, grid_shape: tuple[int, int, int]) -> list:
+ """Parse a CHGCAR file and compress its charge density using pyrho.
+
+ Writes bytes to a temporary file because pymatgen's ``Chgcar.from_file``
+ requires a filesystem path, not a file-like object.
+
+ Args:
+ chgcar_bytes: Raw CHGCAR file content (uncompressed VASP format).
+ grid_shape: Target grid shape for lossy compression, e.g. ``(15, 15, 15)``.
+
+ Returns:
+ Compressed charge density grid as a nested Python list.
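+
+    Example:
+        A sketch of the shape contract, assuming pyrho returns an array of
+        the requested ``grid_shape``::
+
+            grid = compress_chgcar(chgcar_bytes, (15, 15, 15))
+            assert len(grid) == 15 and len(grid[0][0]) == 15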
+ """
+ from pyrho.charge_density import ChargeDensity
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ path = os.path.join(tmpdir, "CHGCAR")
+ with open(path, "wb") as f:
+ f.write(chgcar_bytes)
+ chgcar = Chgcar.from_file(path)
+ charge_density = ChargeDensity.from_pmg(chgcar)
+ compressed = charge_density.pgrids["total"].lossy_smooth_compression(grid_shape)
+ result = compressed.tolist()
+ del chgcar, charge_density, compressed
+ return result
+
+
+def build_raw_structure(
+ material_id: str,
+ structure: Structure,
+ compressed_grids: dict[str, Optional[list]],
+ grid_shape: tuple[int, int, int],
+ s3_prefix: str,
+) -> RawStructure:
+ """Build a RawStructure from parsed charge density data.
+
+ Args:
+ material_id: Material identifier, e.g. ``"agm000001"``.
+ structure: Pymatgen Structure parsed from vasprun.xml.
+ compressed_grids: Dict mapping grid names (``"charge_density"``,
+ ``"aeccar0"``, ``"aeccar1"``, ``"aeccar2"``) to compressed
+ grid lists or ``None``.
+ grid_shape: Grid shape used for compression.
+ s3_prefix: S3 prefix path for the material folder.
+
+ Returns:
+ A ``RawStructure`` ready for database insertion.
+ """
+ attributes = {
+ "structure": structure.as_dict(),
+ "compressed_charge_density": compressed_grids.get("charge_density"),
+ "compressed_aeccar0": compressed_grids.get("aeccar0"),
+ "compressed_aeccar1": compressed_grids.get("aeccar1"),
+ "compressed_aeccar2": compressed_grids.get("aeccar2"),
+ "grid_shape": list(grid_shape),
+ "s3_prefix": s3_prefix,
+ }
+
+ return RawStructure(
+ id=material_id,
+ type="lematrho",
+ attributes=attributes,
+ last_modified=datetime.now(),
+ )
+
+
+def write_potcar(structure: Structure, tmpdir: str) -> None:
+ """Generate a POTCAR file for the given structure.
+
+ Uses ``MatPESStaticSet`` to select pseudopotentials consistent with
+ Materials Project settings and writes the resulting POTCAR to *tmpdir*.
+
+ Args:
+ structure: Pymatgen Structure for which to generate the POTCAR.
+ tmpdir: Directory where ``POTCAR`` will be written.
+
+ Raises:
+ OSError: If ``PMG_VASP_PSP_DIR`` is not set or the pseudopotential
+ files cannot be found.
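+
+    Example:
+        A minimal sketch; the pseudopotential path is hypothetical::
+
+            os.environ["PMG_VASP_PSP_DIR"] = "/opt/vasp/potpaw_PBE"
+            write_potcar(structure, tmpdir)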
+ """
+ from pymatgen.io.vasp.sets import MatPESStaticSet
+
+ input_set = MatPESStaticSet(structure)
+ input_set.potcar.write_file(os.path.join(tmpdir, "POTCAR"))
diff --git a/src/lematerial_fetcher/models/optimade.py b/src/lematerial_fetcher/models/optimade.py
index ac07685..065cce9 100644
--- a/src/lematerial_fetcher/models/optimade.py
+++ b/src/lematerial_fetcher/models/optimade.py
@@ -184,6 +184,45 @@ class OptimadeStructure(BaseModel):
description="BAWL fingerprint hash",
)
+ # Charge density fields (LeMatRho)
+ compressed_charge_density: Optional[list] = Field(
+ None,
+ description="Compressed charge density grid from pyrho lossy compression",
+ )
+ compressed_aeccar0: Optional[list] = Field(
+ None,
+ description="Compressed AECCAR0 (all-electron core charge density) grid",
+ )
+ compressed_aeccar1: Optional[list] = Field(
+ None,
+ description="Compressed AECCAR1 (pseudo valence charge density) grid",
+ )
+ compressed_aeccar2: Optional[list] = Field(
+ None,
+ description="Compressed AECCAR2 (pseudo core charge density) grid",
+ )
+ charge_density_grid_shape: Optional[list[int]] = Field(
+ None,
+ min_length=3,
+ max_length=3,
+ description="Shape of the compressed charge density grid [nx, ny, nz]",
+ )
+ bader_charges: Optional[list[float]] = Field(
+ None,
+ min_length=1,
+ description="Bader charges per site",
+ )
+ bader_atomic_volume: Optional[list[float]] = Field(
+ None,
+ min_length=1,
+ description="Bader atomic volumes per site",
+ )
+ ddec6_charges: Optional[list[float]] = Field(
+ None,
+ min_length=1,
+ description="DDEC6 charges per site",
+ )
+
def __init__(
self,
compute_space_group: bool = True,
@@ -513,6 +552,15 @@ def check_consistency(self):
self.charges = self._validate_with_number_of_sites(
self.charges, nsites, "charges"
)
+ self.bader_charges = self._validate_with_number_of_sites(
+ self.bader_charges, nsites, "bader_charges"
+ )
+ self.bader_atomic_volume = self._validate_with_number_of_sites(
+ self.bader_atomic_volume, nsites, "bader_atomic_volume"
+ )
+ self.ddec6_charges = self._validate_with_number_of_sites(
+ self.ddec6_charges, nsites, "ddec6_charges"
+ )
# Validation using the Pymatgen structure
structure = Structure(
diff --git a/src/lematerial_fetcher/models/utils/enums.py b/src/lematerial_fetcher/models/utils/enums.py
index 7525b5a..828aa89 100644
--- a/src/lematerial_fetcher/models/utils/enums.py
+++ b/src/lematerial_fetcher/models/utils/enums.py
@@ -13,3 +13,4 @@ class Source(str, Enum):
ALEXANDRIA = "alexandria"
MP = "mp"
OQMD = "oqmd"
+ LEMATRHO = "lematrho"
diff --git a/src/lematerial_fetcher/push.py b/src/lematerial_fetcher/push.py
index c8f01cc..df43a56 100644
--- a/src/lematerial_fetcher/push.py
+++ b/src/lematerial_fetcher/push.py
@@ -135,6 +135,14 @@ def _get_optimade_features(self) -> Features:
"cross_compatibility": Value("bool"),
"bawl_fingerprint": Value("string"),
"space_group_it_number": Value("int32"),
+ "compressed_charge_density": Value("string"),
+ "compressed_aeccar0": Value("string"),
+ "compressed_aeccar1": Value("string"),
+ "compressed_aeccar2": Value("string"),
+ "charge_density_grid_shape": Sequence(Value("int32")),
+ "bader_charges": Sequence(Value("float64")),
+ "bader_atomic_volume": Sequence(Value("float64")),
+ "ddec6_charges": Sequence(Value("float64")),
}
)
@@ -171,6 +179,14 @@ def _get_trajectories_features(self) -> Features:
del features["charges"]
del features["total_magnetization"]
del features["bawl_fingerprint"]
+ del features["compressed_charge_density"]
+ del features["compressed_aeccar0"]
+ del features["compressed_aeccar1"]
+ del features["compressed_aeccar2"]
+ del features["charge_density_grid_shape"]
+ del features["bader_charges"]
+ del features["bader_atomic_volume"]
+ del features["ddec6_charges"]
convert_features_dict.update(
{
@@ -467,6 +483,34 @@ def convert_species(batch):
desc="Converting species column to string",
)
+ # Convert compressed charge density fields from nested lists to JSON strings
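+    # (the Arrow features declare these columns as Value("string"), since
+    # the grids are deeply nested lists)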
+ json_serialized_columns = [
+ "compressed_charge_density",
+ "compressed_aeccar0",
+ "compressed_aeccar1",
+ "compressed_aeccar2",
+ ]
+ columns_to_convert = [
+ col
+ for col in json_serialized_columns
+ if col in dataset["train"].column_names
+ ]
+ if columns_to_convert:
+
+ def convert_charge_density(batch):
+ for col in columns_to_convert:
+ batch[col] = [
+ json.dumps(v) if v is not None else None for v in batch[col]
+ ]
+ return batch
+
+ dataset = dataset.map(
+ convert_charge_density,
+ batched=True,
+ num_proc=self.config.num_workers,
+ desc="Converting charge density columns to string",
+ )
+
for split in dataset.keys():
dataset[split] = dataset[split].cast(
features=self.features, num_proc=self.config.num_workers
diff --git a/src/lematerial_fetcher/utils/aws.py b/src/lematerial_fetcher/utils/aws.py
index 154e464..402606a 100644
--- a/src/lematerial_fetcher/utils/aws.py
+++ b/src/lematerial_fetcher/utils/aws.py
@@ -27,6 +27,32 @@ def get_aws_client(region_name: str = "us-east-1"):
return s3_client
+def get_authenticated_aws_client(region_name: str = "us-east-1"):
+ """Returns a configured S3 client using the default AWS credential chain.
+
+ Uses AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN,
+ or IAM roles for authentication. Includes adaptive retry configuration.
+
+ Parameters
+ ----------
+ region_name: str, default='us-east-1'
+ The AWS region for the S3 client.
+
+ Returns
+ -------
+ s3_client: boto3.client
+        A configured S3 client with authenticated credentials.
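+
+    Examples
+    --------
+    A minimal sketch; credentials are resolved from the environment or an
+    IAM role, so none are passed explicitly::
+
+        client = get_authenticated_aws_client("us-east-1")
+        client.head_bucket(Bucket="lemat-rho")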
+ """
+ s3_client = boto3.client(
+ "s3",
+ config=Config(
+ retries={"max_attempts": 3, "mode": "adaptive"},
+ region_name=region_name,
+ ),
+ )
+ return s3_client
+
+
def get_latest_collection_version_prefix(
client, bucket_name: str, bucket_prefix: str, collections_prefix: str
) -> str:
diff --git a/src/lematerial_fetcher/utils/cli.py b/src/lematerial_fetcher/utils/cli.py
index 20c2f15..8608cba 100644
--- a/src/lematerial_fetcher/utils/cli.py
+++ b/src/lematerial_fetcher/utils/cli.py
@@ -250,6 +250,165 @@ def add_mp_fetch_options(f):
return f
+def add_lematrho_fetch_options(f):
+ """Add LeMatRho fetch options to a command."""
+ decorators = [
+ click.option(
+ "--lematrho-bucket-name",
+ type=str,
+ default="lemat-rho",
+ envvar="LEMATERIALFETCHER_LEMATRHO_BUCKET_NAME",
+ help="LeMatRho S3 bucket name.",
+ ),
+ click.option(
+ "--grid-shape",
+ type=(int, int, int),
+ default=(15, 15, 15),
+ envvar="LEMATERIALFETCHER_LEMATRHO_GRID_SHAPE",
+ help="Grid shape for pyrho charge density compression (nx ny nz).",
+ ),
+ ]
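+    # Decorators apply innermost-first; reversing keeps the options in the
+    # declared order in --help output.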
+ for decorator in reversed(decorators):
+ f = decorator(f)
+ return f
+
+
+def add_lematrho_transform_options(f):
+ """Add LeMatRho transform options to a command."""
+ decorators = [
+ click.option(
+ "--lematrho-bucket-name",
+ type=str,
+ default="lemat-rho",
+ envvar="LEMATERIALFETCHER_LEMATRHO_BUCKET_NAME",
+ help="LeMatRho S3 bucket name for re-downloading raw files during transform.",
+ ),
+ click.option(
+ "--bader-path",
+ type=str,
+ envvar="LEMATERIALFETCHER_BADER_PATH",
+ help="Path to the bader executable. If not provided, will search PATH.",
+ ),
+ click.option(
+ "--chargemol-path",
+ type=str,
+ envvar="LEMATERIALFETCHER_CHARGEMOL_PATH",
+ help="Path to the chargemol executable. If not provided, will search PATH.",
+ ),
+ click.option(
+ "--chgsum-script-path",
+ type=str,
+ envvar="LEMATERIALFETCHER_CHGSUM_SCRIPT_PATH",
+ help="Path to the chgsum.pl perl script for Bader charge summation.",
+ ),
+ click.option(
+ "--atomic-densities-path",
+ type=str,
+ envvar="LEMATERIALFETCHER_ATOMIC_DENSITIES_PATH",
+ help="Path to atomic densities directory for DDEC6/chargemol analysis.",
+ ),
+ ]
+ for decorator in reversed(decorators):
+ f = decorator(f)
+ return f
+
+
+def add_lematrho_direct_options(f):
+ """Add options for the direct LeMatRho S3-to-Parquet pipeline."""
+ decorators = [
+ click.option(
+ "--output-dir",
+ type=str,
+ default="./lematrho_output",
+ envvar="LEMATERIALFETCHER_LEMATRHO_OUTPUT_DIR",
+ help="Directory to write Parquet output files.",
+ ),
+ click.option(
+ "--parquet-chunk-size",
+ type=int,
+ default=1000,
+ envvar="LEMATERIALFETCHER_LEMATRHO_PARQUET_CHUNK_SIZE",
+ help="Number of rows per Parquet chunk file.",
+ ),
+ click.option(
+ "--num-workers",
+ type=int,
+ default=4,
+ envvar="LEMATERIALFETCHER_NUM_WORKERS",
+ help="Number of parallel worker processes.",
+ ),
+ click.option(
+ "--log-every",
+ type=int,
+ default=100,
+ envvar="LEMATERIALFETCHER_LOG_EVERY",
+ help="Log progress every N materials.",
+ ),
+ click.option(
+ "--limit",
+ type=int,
+ default=None,
+ envvar="LEMATERIALFETCHER_LEMATRHO_LIMIT",
+ help="Max number of materials to process (for testing). Default: no limit.",
+ ),
+ click.option(
+ "--lematrho-bucket-name",
+ type=str,
+ default="lemat-rho",
+ envvar="LEMATERIALFETCHER_LEMATRHO_BUCKET_NAME",
+ help="LeMatRho S3 bucket name.",
+ ),
+ click.option(
+ "--grid-shape",
+ type=(int, int, int),
+ default=(15, 15, 15),
+ envvar="LEMATERIALFETCHER_LEMATRHO_GRID_SHAPE",
+ help="Grid shape for pyrho charge density compression (nx ny nz).",
+ ),
+ click.option(
+ "--hf-repo-id",
+ type=str,
+ default=None,
+ envvar="LEMATERIALFETCHER_HF_REPO_ID",
+ help="HuggingFace repository ID for pushing. If not set, skip push.",
+ ),
+ click.option(
+ "--hf-token",
+ type=str,
+ default=None,
+ envvar="LEMATERIALFETCHER_HF_TOKEN",
+ help="HuggingFace token for pushing.",
+ ),
+ click.option(
+ "--bader-path",
+ type=str,
+ envvar="LEMATERIALFETCHER_BADER_PATH",
+ help="Path to the bader executable. If not provided, will search PATH.",
+ ),
+ click.option(
+ "--chargemol-path",
+ type=str,
+ envvar="LEMATERIALFETCHER_CHARGEMOL_PATH",
+ help="Path to the chargemol executable. If not provided, will search PATH.",
+ ),
+ click.option(
+ "--chgsum-script-path",
+ type=str,
+ envvar="LEMATERIALFETCHER_CHGSUM_SCRIPT_PATH",
+ help="Path to the chgsum.pl perl script for Bader charge summation.",
+ ),
+ click.option(
+ "--atomic-densities-path",
+ type=str,
+ envvar="LEMATERIALFETCHER_ATOMIC_DENSITIES_PATH",
+ help="Path to atomic densities directory for DDEC6/chargemol analysis.",
+ ),
+ ]
+ for decorator in reversed(decorators):
+ f = decorator(f)
+ return f
+
+
def add_push_options(f):
"""Add push options to a command."""
decorators = [
diff --git a/src/lematerial_fetcher/utils/config.py b/src/lematerial_fetcher/utils/config.py
index f7c5746..2478e83 100644
--- a/src/lematerial_fetcher/utils/config.py
+++ b/src/lematerial_fetcher/utils/config.py
@@ -30,6 +30,8 @@ class FetcherConfig(BaseConfig):
mp_bucket_prefix: str
mysql_config: Optional[dict] = None
oqmd_download_dir: Optional[str] = None
+ lematrho_bucket_name: Optional[str] = None
+ lematrho_grid_shape: Optional[tuple[int, int, int]] = None
@dataclass
@@ -43,6 +45,35 @@ class TransformerConfig(BaseConfig):
db_fetch_batch_size: Optional[int] = None
mp_task_table_name: Optional[str] = None
mysql_config: Optional[dict] = None
+ lematrho_bucket_name: Optional[str] = None
+ bader_path: Optional[str] = None
+ chargemol_path: Optional[str] = None
+ chgsum_script_path: Optional[str] = None
+ atomic_densities_path: Optional[str] = None
+
+
+@dataclass
+class DirectPipelineConfig:
+ """Config for the direct S3-to-Parquet pipeline (no PostgreSQL)."""
+
+ # S3 source
+ lematrho_bucket_name: str = "lemat-rho"
+ lematrho_grid_shape: tuple[int, int, int] = (15, 15, 15)
+ # Output
+ output_dir: str = "./lematrho_output"
+ parquet_chunk_size: int = 1000
+ # Processing
+ num_workers: int = 4
+ log_every: int = 100
+ limit: Optional[int] = None
+ # HuggingFace (optional)
+ hf_repo_id: Optional[str] = None
+ hf_token: Optional[str] = None
+ # External tools (all optional — missing tools result in None fields)
+ bader_path: Optional[str] = None
+ chargemol_path: Optional[str] = None
+ chgsum_script_path: Optional[str] = None
+ atomic_densities_path: Optional[str] = None
@dataclass
@@ -156,6 +187,8 @@ def load_fetcher_config(
mp_bucket_name: Optional[str] = None,
mp_bucket_prefix: Optional[str] = None,
oqmd_download_dir: Optional[str] = None,
+ lematrho_bucket_name: Optional[str] = None,
+ lematrho_grid_shape: Optional[tuple[int, int, int]] = None,
mysql_host: str = "localhost",
mysql_user: Optional[str] = None,
# No MySQL password parameter
@@ -189,6 +222,8 @@ def load_fetcher_config(
"mp_bucket_name": mp_bucket_name,
"mp_bucket_prefix": mp_bucket_prefix,
"oqmd_download_dir": oqmd_download_dir,
+ "lematrho_bucket_name": lematrho_bucket_name,
+ "lematrho_grid_shape": lematrho_grid_shape,
}
# Validate required fields
@@ -245,6 +280,12 @@ def load_transformer_config(
mysql_user: Optional[str] = None,
mysql_database: str = "lematerial",
mysql_cert_path: Optional[str] = None,
+ # LeMatRho-specific params
+ lematrho_bucket_name: Optional[str] = None,
+ bader_path: Optional[str] = None,
+ chargemol_path: Optional[str] = None,
+ chgsum_script_path: Optional[str] = None,
+ atomic_densities_path: Optional[str] = None,
**base_config_kwargs: Any,
) -> TransformerConfig:
"""Loads transformer config from passed arguments.
@@ -337,6 +378,11 @@ def load_transformer_config(
**base_config,
**config,
mysql_config=mysql_config,
+ lematrho_bucket_name=lematrho_bucket_name,
+ bader_path=bader_path,
+ chargemol_path=chargemol_path,
+ chgsum_script_path=chgsum_script_path,
+ atomic_densities_path=atomic_densities_path,
)
@@ -399,3 +445,41 @@ def load_push_config(
**base_config,
**config,
)
+
+
+def load_direct_pipeline_config(
+ lematrho_bucket_name: str = "lemat-rho",
+ grid_shape: tuple[int, int, int] = (15, 15, 15),
+ output_dir: str = "./lematrho_output",
+ parquet_chunk_size: int = 1000,
+ num_workers: int = 4,
+ log_every: int = 100,
+ limit: Optional[int] = None,
+ hf_repo_id: Optional[str] = None,
+ hf_token: Optional[str] = None,
+ bader_path: Optional[str] = None,
+ chargemol_path: Optional[str] = None,
+ chgsum_script_path: Optional[str] = None,
+ atomic_densities_path: Optional[str] = None,
+ **_kwargs: Any,
+) -> DirectPipelineConfig:
+ """Load config for the direct S3-to-Parquet pipeline.
+
+    Arguments are typically supplied by the Click CLI options.
+ No database credentials needed — this pipeline writes Parquet directly.
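+
+    Example (values hypothetical):
+        >>> config = load_direct_pipeline_config(output_dir="/data/rho", limit=10)
+        >>> config.lematrho_grid_shape
+        (15, 15, 15)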
+ """
+ return DirectPipelineConfig(
+ lematrho_bucket_name=lematrho_bucket_name,
+ lematrho_grid_shape=grid_shape,
+ output_dir=output_dir,
+ parquet_chunk_size=parquet_chunk_size,
+ num_workers=num_workers,
+ log_every=log_every,
+ limit=limit,
+ hf_repo_id=hf_repo_id,
+ hf_token=hf_token,
+ bader_path=bader_path,
+ chargemol_path=chargemol_path,
+ chgsum_script_path=chgsum_script_path,
+ atomic_densities_path=atomic_densities_path,
+ )
diff --git a/tests/fetcher/lematrho/test_lematrho_fetch.py b/tests/fetcher/lematrho/test_lematrho_fetch.py
new file mode 100644
index 0000000..fc7f5bd
--- /dev/null
+++ b/tests/fetcher/lematrho/test_lematrho_fetch.py
@@ -0,0 +1,616 @@
+# Copyright 2025 Entalpic
+import gzip
+from datetime import datetime
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from lematerial_fetcher.database.postgres import DatasetVersions, StructuresDatabase
+from lematerial_fetcher.fetcher.lematrho.fetch import (
+ RELAX_CALC_TYPE,
+ STATIC_CALC_TYPE,
+ STATIC_FILES,
+ VALID_PREFIXES,
+ LeMatRhoFetcher,
+)
+from lematerial_fetcher.fetcher.lematrho.utils import (
+ build_raw_structure,
+ download_gz_file_from_s3,
+)
+from lematerial_fetcher.utils.config import FetcherConfig
+
+
+@pytest.fixture
+def mock_aws_client():
+ return MagicMock()
+
+
+@pytest.fixture
+def mock_db():
+ return MagicMock(spec=StructuresDatabase)
+
+
+@pytest.fixture
+def mock_version_db():
+ return MagicMock(spec=DatasetVersions)
+
+
+@pytest.fixture
+def mock_config():
+ return FetcherConfig(
+ base_url="https://api.test.com",
+ db_conn_str="postgresql://test:test@localhost:5432/test",
+ table_name="test_lematrho_raw",
+ page_limit=10,
+ page_offset=0,
+ mp_bucket_name="",
+ mp_bucket_prefix="",
+ log_dir="./logs",
+ max_retries=3,
+ num_workers=2,
+ retry_delay=2,
+ log_every=100,
+ lematrho_bucket_name="lemat-rho",
+ lematrho_grid_shape=(15, 15, 15),
+ )
+
+
+# ---------------------------------------------------------------------------
+# get_items_to_process tests
+# ---------------------------------------------------------------------------
+
+
+class TestGetItemsToProcess:
+ def test_filters_by_valid_prefix(self, mock_aws_client, mock_config, mock_version_db):
+ """Only folders with oqmd-, mp-, or agm prefixes should be returned."""
+ mock_paginator = MagicMock()
+ mock_aws_client.get_paginator.return_value = mock_paginator
+ mock_paginator.paginate.return_value = [
+ {
+ "CommonPrefixes": [
+ {"Prefix": "agm000001/"},
+ {"Prefix": "mp-12345/"},
+ {"Prefix": "oqmd-678/"},
+ {"Prefix": "test-invalid/"},
+ {"Prefix": "random-folder/"},
+ ]
+ }
+ ]
+
+ with (
+ patch("lematerial_fetcher.fetch.DatasetVersions") as mock_ver_cls,
+ patch("lematerial_fetcher.fetch.StructuresDatabase"),
+ ):
+ mock_ver_cls.return_value = mock_version_db
+ mock_version_db.get_last_synced_version.return_value = None
+
+ fetcher = LeMatRhoFetcher(config=mock_config, debug=True)
+ fetcher.aws_client = mock_aws_client
+
+ items = fetcher.get_items_to_process()
+
+ assert items.total_count == 3
+ assert set(items.items) == {"agm000001", "mp-12345", "oqmd-678"}
+
+ def test_ignores_unknown_prefixes(self, mock_aws_client, mock_config, mock_version_db):
+ """Folders like 'test-123/' or 'data/' should be excluded."""
+ mock_paginator = MagicMock()
+ mock_aws_client.get_paginator.return_value = mock_paginator
+ mock_paginator.paginate.return_value = [
+ {
+ "CommonPrefixes": [
+ {"Prefix": "test-123/"},
+ {"Prefix": "data/"},
+ {"Prefix": "backup-20240101/"},
+ ]
+ }
+ ]
+
+ with (
+ patch("lematerial_fetcher.fetch.DatasetVersions") as mock_ver_cls,
+ patch("lematerial_fetcher.fetch.StructuresDatabase"),
+ ):
+ mock_ver_cls.return_value = mock_version_db
+ mock_version_db.get_last_synced_version.return_value = None
+
+ fetcher = LeMatRhoFetcher(config=mock_config, debug=True)
+ fetcher.aws_client = mock_aws_client
+
+ items = fetcher.get_items_to_process()
+
+ assert items.total_count == 0
+ assert items.items == []
+
+ def test_handles_empty_bucket(self, mock_aws_client, mock_config, mock_version_db):
+ """Empty bucket should return zero items."""
+ mock_paginator = MagicMock()
+ mock_aws_client.get_paginator.return_value = mock_paginator
+ mock_paginator.paginate.return_value = [{}]
+
+ with (
+ patch("lematerial_fetcher.fetch.DatasetVersions") as mock_ver_cls,
+ patch("lematerial_fetcher.fetch.StructuresDatabase"),
+ ):
+ mock_ver_cls.return_value = mock_version_db
+ mock_version_db.get_last_synced_version.return_value = None
+
+ fetcher = LeMatRhoFetcher(config=mock_config, debug=True)
+ fetcher.aws_client = mock_aws_client
+
+ items = fetcher.get_items_to_process()
+
+ assert items.total_count == 0
+ assert items.items == []
+
+ def test_handles_pagination(self, mock_aws_client, mock_config, mock_version_db):
+ """Should handle multiple pages of results."""
+ mock_paginator = MagicMock()
+ mock_aws_client.get_paginator.return_value = mock_paginator
+ mock_paginator.paginate.return_value = [
+ {"CommonPrefixes": [{"Prefix": "agm000001/"}]},
+ {"CommonPrefixes": [{"Prefix": "agm000002/"}]},
+ ]
+
+ with (
+ patch("lematerial_fetcher.fetch.DatasetVersions") as mock_ver_cls,
+ patch("lematerial_fetcher.fetch.StructuresDatabase"),
+ ):
+ mock_ver_cls.return_value = mock_version_db
+ mock_version_db.get_last_synced_version.return_value = None
+
+ fetcher = LeMatRhoFetcher(config=mock_config, debug=True)
+ fetcher.aws_client = mock_aws_client
+
+ items = fetcher.get_items_to_process()
+
+ assert items.total_count == 2
+ assert items.items == ["agm000001", "agm000002"]
+
+ def test_raises_without_bucket_name(self, mock_aws_client, mock_version_db):
+ """Should raise ValueError if bucket name is not configured."""
+ config = FetcherConfig(
+ base_url="https://api.test.com",
+ db_conn_str="postgresql://test:test@localhost:5432/test",
+ table_name="test_table",
+ page_limit=10,
+ page_offset=0,
+ mp_bucket_name="",
+ mp_bucket_prefix="",
+ log_dir="./logs",
+ max_retries=3,
+ num_workers=2,
+ retry_delay=2,
+ log_every=100,
+ lematrho_bucket_name=None,
+ )
+
+ with (
+ patch("lematerial_fetcher.fetch.DatasetVersions") as mock_ver_cls,
+ patch("lematerial_fetcher.fetch.StructuresDatabase"),
+ ):
+ mock_ver_cls.return_value = mock_version_db
+ mock_version_db.get_last_synced_version.return_value = None
+
+ fetcher = LeMatRhoFetcher(config=config, debug=True)
+ fetcher.aws_client = mock_aws_client
+
+ with pytest.raises(ValueError, match="lematrho_bucket_name"):
+ fetcher.get_items_to_process()
+
+
+# ---------------------------------------------------------------------------
+# _process_batch tests
+# ---------------------------------------------------------------------------
+
+
+class TestProcessBatch:
+ def test_happy_path(self, mock_config):
+ """Successful processing: downloads vasprun + all 4 charge files, inserts to DB."""
+ mock_client = MagicMock()
+ mock_db_instance = MagicMock(spec=StructuresDatabase)
+ mock_structure = MagicMock()
+ mock_structure.as_dict.return_value = {"lattice": {}, "sites": []}
+
+ with (
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.get_authenticated_aws_client"
+ ) as mock_auth,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.StructuresDatabase"
+ ) as mock_db_cls,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.download_gz_file_from_s3"
+ ) as mock_download,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.parse_vasprun_structure"
+ ) as mock_parse,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.compress_chgcar"
+ ) as mock_compress,
+ ):
+ mock_auth.return_value = mock_client
+ mock_db_cls.return_value = mock_db_instance
+ mock_download.return_value = b"fake data"
+ mock_parse.return_value = mock_structure
+ mock_compress.return_value = [[[1.0, 2.0]]]
+
+ result = LeMatRhoFetcher._process_batch(
+ "agm000001", mock_config, {"occurred": False}
+ )
+
+ assert result is True
+ mock_db_instance.insert_data.assert_called_once()
+ # vasprun + 4 charge files = 5 downloads
+ assert mock_download.call_count == 5
+ # 4 CHGCAR/AECCAR compressions
+ assert mock_compress.call_count == 4
+
+ def test_missing_vasprun_returns_false(self, mock_config):
+ """If vasprun.xml.gz is missing, should return False."""
+ mock_client = MagicMock()
+
+ with (
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.get_authenticated_aws_client"
+ ) as mock_auth,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.StructuresDatabase"
+ ),
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.download_gz_file_from_s3"
+ ) as mock_download,
+ ):
+ mock_auth.return_value = mock_client
+ mock_download.side_effect = Exception("NoSuchKey: vasprun.xml.gz")
+
+ result = LeMatRhoFetcher._process_batch(
+ "agm000001", mock_config, {"occurred": False}
+ )
+
+ assert result is False
+
+ def test_missing_aeccar1_still_processes_others(self, mock_config):
+ """If AECCAR1.gz is missing, other files should still be processed."""
+ mock_client = MagicMock()
+ mock_db_instance = MagicMock(spec=StructuresDatabase)
+ mock_structure = MagicMock()
+ mock_structure.as_dict.return_value = {"lattice": {}, "sites": []}
+
+ def download_side_effect(client, bucket, key):
+ if "AECCAR1.gz" in key:
+ raise Exception("NoSuchKey")
+ return b"fake data"
+
+ with (
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.get_authenticated_aws_client"
+ ) as mock_auth,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.StructuresDatabase"
+ ) as mock_db_cls,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.download_gz_file_from_s3"
+ ) as mock_download,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.parse_vasprun_structure"
+ ) as mock_parse,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.compress_chgcar"
+ ) as mock_compress,
+ ):
+ mock_auth.return_value = mock_client
+ mock_db_cls.return_value = mock_db_instance
+ mock_download.side_effect = download_side_effect
+ mock_parse.return_value = mock_structure
+ mock_compress.return_value = [[[1.0]]]
+
+ result = LeMatRhoFetcher._process_batch(
+ "agm000001", mock_config, {"occurred": False}
+ )
+
+ assert result is True
+ mock_db_instance.insert_data.assert_called_once()
+            # download is attempted 5 times (vasprun + 4 charge files);
+            # the AECCAR1 attempt raises, so 4 succeed.
+ assert mock_download.call_count == 5
+ # Only 3 compressions (AECCAR1 failed before compression)
+ assert mock_compress.call_count == 3
+
+ def test_all_charge_files_missing_still_inserts(self, mock_config):
+ """If all charge files fail but vasprun succeeds, still insert with None grids."""
+ mock_client = MagicMock()
+ mock_db_instance = MagicMock(spec=StructuresDatabase)
+ mock_structure = MagicMock()
+ mock_structure.as_dict.return_value = {"lattice": {}, "sites": []}
+
+ call_count = {"n": 0}
+
+ def download_side_effect(client, bucket, key):
+ call_count["n"] += 1
+ if "vasprun" in key:
+ return b"fake vasprun"
+ raise Exception("NoSuchKey")
+
+ with (
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.get_authenticated_aws_client"
+ ) as mock_auth,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.StructuresDatabase"
+ ) as mock_db_cls,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.download_gz_file_from_s3"
+ ) as mock_download,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.parse_vasprun_structure"
+ ) as mock_parse,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.compress_chgcar"
+ ) as mock_compress,
+ ):
+ mock_auth.return_value = mock_client
+ mock_db_cls.return_value = mock_db_instance
+ mock_download.side_effect = download_side_effect
+ mock_parse.return_value = mock_structure
+
+ result = LeMatRhoFetcher._process_batch(
+ "agm000001", mock_config, {"occurred": False}
+ )
+
+ assert result is True
+ mock_db_instance.insert_data.assert_called_once()
+ # No compressions since all charge file downloads failed
+ mock_compress.assert_not_called()
+
+ # Verify the inserted structure has None grids
+ inserted = mock_db_instance.insert_data.call_args[0][0]
+ assert inserted.attributes["compressed_charge_density"] is None
+ assert inserted.attributes["compressed_aeccar0"] is None
+ assert inserted.attributes["compressed_aeccar1"] is None
+ assert inserted.attributes["compressed_aeccar2"] is None
+
+ def test_critical_error_sets_manager_flag(self, mock_config):
+ """A connection error should flag the manager_dict for shutdown."""
+ manager_dict = {"occurred": False}
+
+ with (
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.get_authenticated_aws_client"
+ ) as mock_auth,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.StructuresDatabase"
+ ),
+ ):
+ mock_auth.side_effect = Exception("Connection refused")
+
+ result = LeMatRhoFetcher._process_batch(
+ "agm000001", mock_config, manager_dict
+ )
+
+ assert result is False
+ assert manager_dict["occurred"] is True
+
+ def test_correct_s3_keys(self, mock_config):
+ """Verify the exact S3 keys constructed for downloads."""
+ mock_client = MagicMock()
+ mock_db_instance = MagicMock(spec=StructuresDatabase)
+ mock_structure = MagicMock()
+ mock_structure.as_dict.return_value = {}
+
+ download_calls = []
+
+ def capture_downloads(client, bucket, key):
+ download_calls.append(key)
+ return b"data"
+
+ with (
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.get_authenticated_aws_client"
+ ) as mock_auth,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.StructuresDatabase"
+ ) as mock_db_cls,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.download_gz_file_from_s3"
+ ) as mock_download,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.parse_vasprun_structure"
+ ) as mock_parse,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.compress_chgcar"
+ ) as mock_compress,
+ ):
+ mock_auth.return_value = mock_client
+ mock_db_cls.return_value = mock_db_instance
+ mock_download.side_effect = capture_downloads
+ mock_parse.return_value = mock_structure
+ mock_compress.return_value = []
+
+ LeMatRhoFetcher._process_batch(
+ "agm000001", mock_config, {"occurred": False}
+ )
+
+ expected_keys = [
+ f"agm000001/{RELAX_CALC_TYPE}/vasprun.xml.gz",
+ f"agm000001/{STATIC_CALC_TYPE}/CHGCAR.gz",
+ f"agm000001/{STATIC_CALC_TYPE}/AECCAR0.gz",
+ f"agm000001/{STATIC_CALC_TYPE}/AECCAR1.gz",
+ f"agm000001/{STATIC_CALC_TYPE}/AECCAR2.gz",
+ ]
+ assert download_calls == expected_keys
+
+ def test_uses_config_grid_shape(self, mock_config):
+ """Verify that the configured grid shape is passed to compress_chgcar."""
+ mock_client = MagicMock()
+ mock_db_instance = MagicMock(spec=StructuresDatabase)
+ mock_structure = MagicMock()
+ mock_structure.as_dict.return_value = {}
+
+ mock_config.lematrho_grid_shape = (20, 20, 20)
+
+ with (
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.get_authenticated_aws_client"
+ ) as mock_auth,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.StructuresDatabase"
+ ) as mock_db_cls,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.download_gz_file_from_s3"
+ ) as mock_download,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.parse_vasprun_structure"
+ ) as mock_parse,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.compress_chgcar"
+ ) as mock_compress,
+ ):
+ mock_auth.return_value = mock_client
+ mock_db_cls.return_value = mock_db_instance
+ mock_download.return_value = b"data"
+ mock_parse.return_value = mock_structure
+ mock_compress.return_value = []
+
+ LeMatRhoFetcher._process_batch(
+ "agm000001", mock_config, {"occurred": False}
+ )
+
+ # All 4 compress calls should use (20, 20, 20)
+ for call in mock_compress.call_args_list:
+ assert call[0][1] == (20, 20, 20)
+
+
+# ---------------------------------------------------------------------------
+# Utility function tests
+# ---------------------------------------------------------------------------
+
+
+class TestDownloadGzFile:
+ def test_decompresses_gzipped_content(self):
+ """Verify gzip decompression works correctly."""
+ original = b"hello world test content"
+ compressed = gzip.compress(original)
+
+ mock_client = MagicMock()
+ mock_client.get_object.return_value = {
+ "Body": MagicMock(read=MagicMock(return_value=compressed))
+ }
+
+ result = download_gz_file_from_s3(mock_client, "bucket", "key.gz")
+ assert result == original
+
+ def test_propagates_s3_errors(self):
+ """S3 download errors should propagate."""
+ mock_client = MagicMock()
+ mock_client.get_object.side_effect = Exception("NoSuchKey")
+
+ with pytest.raises(Exception, match="NoSuchKey"):
+ download_gz_file_from_s3(mock_client, "bucket", "key.gz")
+
+
+class TestBuildRawStructure:
+ def test_builds_correct_structure(self):
+ """Verify the RawStructure has correct fields and type."""
+ from pymatgen.core import Lattice, Structure
+
+ structure = Structure(
+ Lattice.cubic(3.0),
+ ["Si", "Si"],
+ [[0, 0, 0], [0.5, 0.5, 0.5]],
+ )
+ compressed_grids = {
+ "charge_density": [[[1.0]]],
+ "aeccar0": [[[2.0]]],
+ "aeccar1": None,
+ "aeccar2": [[[4.0]]],
+ }
+
+ raw = build_raw_structure(
+ material_id="agm000001",
+ structure=structure,
+ compressed_grids=compressed_grids,
+ grid_shape=(15, 15, 15),
+ s3_prefix="agm000001",
+ )
+
+ assert raw.id == "agm000001"
+ assert raw.type == "lematrho"
+ assert raw.attributes["compressed_charge_density"] == [[[1.0]]]
+ assert raw.attributes["compressed_aeccar0"] == [[[2.0]]]
+ assert raw.attributes["compressed_aeccar1"] is None
+ assert raw.attributes["compressed_aeccar2"] == [[[4.0]]]
+ assert raw.attributes["grid_shape"] == [15, 15, 15]
+ assert raw.attributes["s3_prefix"] == "agm000001"
+ assert raw.attributes["structure"] is not None
+ assert raw.last_modified is not None
+
+
+# ---------------------------------------------------------------------------
+# Constants tests
+# ---------------------------------------------------------------------------
+
+
+def test_valid_prefixes():
+ """Verify VALID_PREFIXES covers expected material ID patterns."""
+ assert "oqmd-123".startswith(VALID_PREFIXES)
+ assert "mp-456".startswith(VALID_PREFIXES)
+ assert "agm000001".startswith(VALID_PREFIXES)
+ assert not "test-789".startswith(VALID_PREFIXES)
+ assert not "random".startswith(VALID_PREFIXES)
+
+
+def test_static_files_list():
+ """Verify STATIC_FILES contains all expected charge density files."""
+ assert "CHGCAR.gz" in STATIC_FILES
+ assert "AECCAR0.gz" in STATIC_FILES
+ assert "AECCAR1.gz" in STATIC_FILES
+ assert "AECCAR2.gz" in STATIC_FILES
+ assert len(STATIC_FILES) == 4
+
+
+def test_calc_type_constants():
+ """Verify S3 subfolder constants."""
+ assert STATIC_CALC_TYPE == "LeMatRhoStaticMaker"
+ assert RELAX_CALC_TYPE == "LeMatRhoRelaxMaker_1"
+
+
+# ---------------------------------------------------------------------------
+# Fetcher lifecycle tests
+# ---------------------------------------------------------------------------
+
+
+class TestLeMatRhoFetcher:
+ def test_setup_resources(self, mock_config, mock_aws_client, mock_version_db):
+ """Test that setup_resources initializes the authenticated AWS client."""
+ with (
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.fetch.get_authenticated_aws_client"
+ ) as mock_auth,
+ patch("lematerial_fetcher.fetch.StructuresDatabase") as mock_db_cls,
+ patch("lematerial_fetcher.fetch.DatasetVersions") as mock_ver_cls,
+ ):
+ mock_auth.return_value = mock_aws_client
+ mock_ver_cls.return_value = mock_version_db
+
+ fetcher = LeMatRhoFetcher(config=mock_config)
+ fetcher.setup_resources()
+
+ mock_auth.assert_called_once()
+ assert fetcher.aws_client is mock_aws_client
+
+ def test_get_new_version_returns_today(self, mock_config, mock_version_db):
+ """Version should be today's date."""
+ with (
+ patch("lematerial_fetcher.fetch.DatasetVersions") as mock_ver_cls,
+ patch("lematerial_fetcher.fetch.StructuresDatabase"),
+ ):
+ mock_ver_cls.return_value = mock_version_db
+
+ fetcher = LeMatRhoFetcher(config=mock_config, debug=True)
+ version = fetcher.get_new_version()
+
+ assert version == datetime.now().strftime("%Y-%m-%d")
diff --git a/tests/fetcher/lematrho/test_lematrho_pipeline.py b/tests/fetcher/lematrho/test_lematrho_pipeline.py
new file mode 100644
index 0000000..b377119
--- /dev/null
+++ b/tests/fetcher/lematrho/test_lematrho_pipeline.py
@@ -0,0 +1,1536 @@
+# Copyright 2025 Entalpic
+import json
+import os
+import subprocess
+import tempfile
+from datetime import datetime
+from unittest.mock import MagicMock, patch
+
+import pyarrow.parquet as pq
+import pytest
+
+from lematerial_fetcher.fetcher.lematrho.pipeline import (
+ PARQUET_COLUMNS,
+ PARQUET_SCHEMA,
+ LeMatRhoDirectPipeline,
+ _run_bader_from_bytes,
+ _run_ddec6_from_bytes,
+ _structure_to_row,
+)
+from lematerial_fetcher.models.optimade import Functional, OptimadeStructure
+from lematerial_fetcher.utils.config import DirectPipelineConfig
+
+# Minimal pymatgen Structure dict for testing
+_MOCK_STRUCTURE_DICT = {
+ "lattice": {
+ "matrix": [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]],
+ "a": 3.0,
+ "b": 3.0,
+ "c": 3.0,
+ "alpha": 90.0,
+ "beta": 90.0,
+ "gamma": 90.0,
+ },
+ "sites": [
+ {
+ "species": [{"element": "Si", "occu": 1}],
+ "abc": [0.0, 0.0, 0.0],
+ "xyz": [0.0, 0.0, 0.0],
+ "label": "Si",
+ },
+ {
+ "species": [{"element": "O", "occu": 1}],
+ "abc": [0.5, 0.5, 0.5],
+ "xyz": [1.5, 1.5, 1.5],
+ "label": "O",
+ },
+ ],
+}
+
+
+def _make_mock_optimade_dict():
+ """Create a mock OPTIMADE dict like get_optimade_from_pymatgen would return."""
+ return {
+ "elements": ["O", "Si"],
+ "nelements": 2,
+ "elements_ratios": [0.5, 0.5],
+ "nsites": 2,
+ "cartesian_site_positions": [[0.0, 0.0, 0.0], [1.5, 1.5, 1.5]],
+ "species_at_sites": ["Si", "O"],
+ "species": [
+ {
+ "mass": None,
+ "name": "O",
+ "attached": None,
+ "nattached": None,
+ "concentration": [1],
+ "original_name": None,
+ "chemical_symbols": ["O"],
+ },
+ {
+ "mass": None,
+ "name": "Si",
+ "attached": None,
+ "nattached": None,
+ "concentration": [1],
+ "original_name": None,
+ "chemical_symbols": ["Si"],
+ },
+ ],
+ "chemical_formula_anonymous": "AB",
+ "chemical_formula_descriptive": "Si1 O1",
+ "chemical_formula_reduced": "OSi",
+ "dimension_types": [1, 1, 1],
+ "nperiodic_dimensions": 3,
+ "lattice_vectors": [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]],
+ }
+
+
+@pytest.fixture
+def tmp_output_dir():
+ """Create a temp directory for pipeline output."""
+ with tempfile.TemporaryDirectory() as tmpdir:
+ yield tmpdir
+
+
+@pytest.fixture
+def mock_config(tmp_output_dir):
+ return DirectPipelineConfig(
+ lematrho_bucket_name="test-bucket",
+ lematrho_grid_shape=(10, 10, 10),
+ output_dir=tmp_output_dir,
+ parquet_chunk_size=3,
+ num_workers=2,
+ log_every=10,
+ )
+
+
+@pytest.fixture
+def no_tools():
+ """Tool paths dict where no tools are available."""
+ return {
+ "bader_path": None,
+ "chargemol_path": None,
+ "chgsum_script_path": None,
+ "perl_path": None,
+ "atomic_densities_path": None,
+ "can_generate_potcar": False,
+ "can_run_bader": False,
+ "can_run_ddec6": False,
+ }
+
+
+# ---------------------------------------------------------------------------
+# TestListMaterials
+# ---------------------------------------------------------------------------
+
+
+class TestListMaterials:
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client")
+ def test_filters_by_valid_prefix(self, mock_get_client, mock_config):
+ mock_client = MagicMock()
+ mock_get_client.return_value = mock_client
+
+ mock_paginator = MagicMock()
+ mock_client.get_paginator.return_value = mock_paginator
+ mock_paginator.paginate.return_value = [
+ {
+ "CommonPrefixes": [
+ {"Prefix": "mp-123/"},
+ {"Prefix": "agm000001/"},
+ {"Prefix": "oqmd-456/"},
+ {"Prefix": "unknown-789/"},
+ {"Prefix": "test-data/"},
+ ]
+ }
+ ]
+
+ with patch.object(LeMatRhoDirectPipeline, "_validate_tools", return_value={}):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config)
+ result = pipeline._list_materials()
+
+ assert result == ["mp-123", "agm000001", "oqmd-456"]
+
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client")
+ def test_excludes_processed_ids(self, mock_get_client, mock_config):
+ """Checkpoint filtering removes already-processed materials."""
+ mock_client = MagicMock()
+ mock_get_client.return_value = mock_client
+
+ mock_paginator = MagicMock()
+ mock_client.get_paginator.return_value = mock_paginator
+ mock_paginator.paginate.return_value = [
+ {
+ "CommonPrefixes": [
+ {"Prefix": "mp-1/"},
+ {"Prefix": "mp-2/"},
+ {"Prefix": "mp-3/"},
+ ]
+ }
+ ]
+
+ # Write checkpoint with mp-1 already processed
+ checkpoint_path = os.path.join(mock_config.output_dir, ".checkpoint.txt")
+ with open(checkpoint_path, "w") as f:
+ f.write("mp-1\n")
+
+ with patch.object(LeMatRhoDirectPipeline, "_validate_tools", return_value={}):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config)
+ all_materials = pipeline._list_materials()
+ pipeline._processed_ids = pipeline._load_checkpoint()
+ remaining = [m for m in all_materials if m not in pipeline._processed_ids]
+
+ assert remaining == ["mp-2", "mp-3"]
+
+
+# ---------------------------------------------------------------------------
+# TestProcessMaterial
+# ---------------------------------------------------------------------------
+
+
+class TestProcessMaterial:
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_optimade_from_pymatgen")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.compress_chgcar")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.parse_vasprun_structure")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.download_gz_file_from_s3")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client")
+ def test_happy_path_no_tools(
+ self,
+ mock_get_client,
+ mock_download,
+ mock_parse_vasprun,
+ mock_compress,
+ mock_get_optimade,
+ mock_config,
+ no_tools,
+ ):
+ """Full processing without external tools — charge analysis fields are None."""
+ from pymatgen.core import Lattice, Structure
+
+ mock_get_client.return_value = MagicMock()
+ mock_download.return_value = b"mock_bytes"
+
+ structure = Structure(
+ Lattice.cubic(3.0),
+ ["Si", "O"],
+ [[0, 0, 0], [0.5, 0.5, 0.5]],
+ )
+ mock_parse_vasprun.return_value = structure
+ mock_compress.return_value = [[[1.0] * 10] * 10] * 10
+ mock_get_optimade.return_value = _make_mock_optimade_dict()
+
+ result = LeMatRhoDirectPipeline._process_material(
+ "mp-123", mock_config, no_tools
+ )
+
+ assert result is not None
+ assert isinstance(result, dict)
+
+ # Check key fields
+ assert result["immutable_id"] == "mp-123"
+ assert result["functional"] == "pbe"
+ assert result["cross_compatibility"] is True
+ assert result["nsites"] == 2
+ assert result["charge_density_grid_shape"] == [10, 10, 10]
+
+ # Compressed grids should be JSON strings
+ assert isinstance(result["compressed_charge_density"], str)
+ parsed = json.loads(result["compressed_charge_density"])
+ assert isinstance(parsed, list)
+
+ # No tools — charge analysis fields are None
+ assert result["bader_charges"] is None
+ assert result["bader_atomic_volume"] is None
+ assert result["ddec6_charges"] is None
+
+ # Species should be JSON string
+ assert isinstance(result["species"], str)
+
+ # All Parquet columns present
+ for col in PARQUET_COLUMNS:
+ assert col in result
+
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_optimade_from_pymatgen")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.compress_chgcar")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.parse_vasprun_structure")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.download_gz_file_from_s3")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client")
+ def test_missing_vasprun_returns_none(
+ self,
+ mock_get_client,
+ mock_download,
+ mock_parse_vasprun,
+ mock_compress,
+ mock_get_optimade,
+ mock_config,
+ no_tools,
+ ):
+ mock_get_client.return_value = MagicMock()
+ mock_download.side_effect = Exception("NoSuchKey: vasprun.xml.gz")
+
+ result = LeMatRhoDirectPipeline._process_material(
+ "mp-999", mock_config, no_tools
+ )
+ assert result is None
+
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_optimade_from_pymatgen")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.compress_chgcar")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.parse_vasprun_structure")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.download_gz_file_from_s3")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client")
+ def test_partial_charge_files(
+ self,
+ mock_get_client,
+ mock_download,
+ mock_parse_vasprun,
+ mock_compress,
+ mock_get_optimade,
+ mock_config,
+ no_tools,
+ ):
+ """Missing AECCAR1 — other files still processed."""
+ from pymatgen.core import Lattice, Structure
+
+ mock_get_client.return_value = MagicMock()
+
+ def download_side_effect(client, bucket, key):
+ if "AECCAR1.gz" in key:
+ raise Exception("NoSuchKey")
+ return b"mock_bytes"
+
+ mock_download.side_effect = download_side_effect
+
+ structure = Structure(
+ Lattice.cubic(3.0), ["Si", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]]
+ )
+ mock_parse_vasprun.return_value = structure
+ mock_compress.return_value = [[[1.0]]]
+ mock_get_optimade.return_value = _make_mock_optimade_dict()
+
+ result = LeMatRhoDirectPipeline._process_material(
+ "mp-123", mock_config, no_tools
+ )
+
+ assert result is not None
+ # CHGCAR, AECCAR0, AECCAR2 should be present
+ assert result["compressed_charge_density"] is not None
+ assert result["compressed_aeccar0"] is not None
+ assert result["compressed_aeccar2"] is not None
+ # AECCAR1 should be None (failed to download)
+ assert result["compressed_aeccar1"] is None
+
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_optimade_from_pymatgen")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.compress_chgcar")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.parse_vasprun_structure")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.download_gz_file_from_s3")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client")
+ def test_cross_compatibility_excludes_yb(
+ self,
+ mock_get_client,
+ mock_download,
+ mock_parse_vasprun,
+ mock_compress,
+ mock_get_optimade,
+ mock_config,
+ no_tools,
+ ):
+ from pymatgen.core import Lattice, Structure
+
+ mock_get_client.return_value = MagicMock()
+ mock_download.return_value = b"mock_bytes"
+
+ structure = Structure(
+ Lattice.cubic(3.0), ["Yb", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]]
+ )
+ mock_parse_vasprun.return_value = structure
+ mock_compress.return_value = [[[1.0]]]
+
+ optimade_dict = _make_mock_optimade_dict()
+ optimade_dict["elements"] = ["O", "Yb"]
+ optimade_dict["species_at_sites"] = ["Yb", "O"]
+ optimade_dict["species"] = [
+ {
+ "mass": None,
+ "name": "O",
+ "attached": None,
+ "nattached": None,
+ "concentration": [1],
+ "original_name": None,
+ "chemical_symbols": ["O"],
+ },
+ {
+ "mass": None,
+ "name": "Yb",
+ "attached": None,
+ "nattached": None,
+ "concentration": [1],
+ "original_name": None,
+ "chemical_symbols": ["Yb"],
+ },
+ ]
+ mock_get_optimade.return_value = optimade_dict
+
+ result = LeMatRhoDirectPipeline._process_material(
+ "mp-yb", mock_config, no_tools
+ )
+
+ assert result is not None
+ assert result["cross_compatibility"] is False
+
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline._run_bader_from_bytes")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_optimade_from_pymatgen")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.compress_chgcar")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.parse_vasprun_structure")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.download_gz_file_from_s3")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client")
+ def test_bader_failure_still_returns_result(
+ self,
+ mock_get_client,
+ mock_download,
+ mock_parse_vasprun,
+ mock_compress,
+ mock_get_optimade,
+ mock_bader,
+ mock_config,
+ ):
+ """When Bader fails, bader fields are None but row is still returned."""
+ from pymatgen.core import Lattice, Structure
+
+ mock_get_client.return_value = MagicMock()
+ mock_download.return_value = b"mock_bytes"
+
+ structure = Structure(
+ Lattice.cubic(3.0), ["Si", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]]
+ )
+ mock_parse_vasprun.return_value = structure
+ mock_compress.return_value = [[[1.0]]]
+ mock_get_optimade.return_value = _make_mock_optimade_dict()
+ mock_bader.return_value = (None, None)
+
+ tools = {
+ "bader_path": "/usr/bin/bader",
+ "chargemol_path": None,
+ "chgsum_script_path": "/opt/chgsum.pl",
+ "perl_path": "/usr/bin/perl",
+ "atomic_densities_path": None,
+ "can_generate_potcar": True,
+ "can_run_bader": True,
+ "can_run_ddec6": False,
+ }
+
+ result = LeMatRhoDirectPipeline._process_material(
+ "mp-123", mock_config, tools
+ )
+
+ assert result is not None
+ assert result["bader_charges"] is None
+ assert result["bader_atomic_volume"] is None
+ mock_bader.assert_called_once()
+
+
+# ---------------------------------------------------------------------------
+# TestCheckpointing
+# ---------------------------------------------------------------------------
+
+
+class TestCheckpointing:
+ def test_load_empty_checkpoint(self, mock_config):
+ """No checkpoint file -> empty set."""
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value={}
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config)
+ ids = pipeline._load_checkpoint()
+
+ assert ids == set()
+
+ def test_load_existing_checkpoint(self, mock_config):
+ """Checkpoint file with IDs -> returns set of those IDs."""
+ checkpoint_path = os.path.join(mock_config.output_dir, ".checkpoint.txt")
+ with open(checkpoint_path, "w") as f:
+ f.write("mp-1\nmp-2\nagm000001\n")
+
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value={}
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config)
+ ids = pipeline._load_checkpoint()
+
+ assert ids == {"mp-1", "mp-2", "agm000001"}
+
+ def test_append_checkpoint(self, mock_config):
+ """Appending to checkpoint writes ID and persists."""
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value={}
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config)
+ pipeline._append_checkpoint("mp-100")
+ pipeline._append_checkpoint("mp-200")
+
+ checkpoint_path = os.path.join(mock_config.output_dir, ".checkpoint.txt")
+ with open(checkpoint_path, "r") as f:
+ lines = [line.strip() for line in f if line.strip()]
+
+ assert lines == ["mp-100", "mp-200"]
+
+ def test_checkpoint_skips_blank_lines(self, mock_config):
+ """Blank lines in checkpoint file are ignored."""
+ checkpoint_path = os.path.join(mock_config.output_dir, ".checkpoint.txt")
+ with open(checkpoint_path, "w") as f:
+ f.write("mp-1\n\n\nmp-2\n\n")
+
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value={}
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config)
+ ids = pipeline._load_checkpoint()
+
+ assert ids == {"mp-1", "mp-2"}
+
+ def test_batch_checkpoint(self, mock_config):
+ """Batch checkpoint writes multiple IDs atomically."""
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value={}
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config)
+ pipeline._batch_checkpoint(["mp-1", "mp-2", "mp-3"])
+
+ checkpoint_path = os.path.join(mock_config.output_dir, ".checkpoint.txt")
+ with open(checkpoint_path, "r") as f:
+ lines = [line.strip() for line in f if line.strip()]
+
+ assert lines == ["mp-1", "mp-2", "mp-3"]
+
+
+# ---------------------------------------------------------------------------
+# TestFailureTracking
+# ---------------------------------------------------------------------------
+
+
+class TestFailureTracking:
+ def test_load_empty_failures(self, mock_config):
+ """No failures file -> empty set."""
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value={}
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config)
+ ids = pipeline._load_failures()
+
+ assert ids == set()
+
+ def test_load_existing_failures(self, mock_config):
+ """Failures file with IDs -> returns set."""
+ failures_path = os.path.join(mock_config.output_dir, ".failures.txt")
+ with open(failures_path, "w") as f:
+ f.write("mp-bad1\nmp-bad2\n")
+
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value={}
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config)
+ ids = pipeline._load_failures()
+
+ assert ids == {"mp-bad1", "mp-bad2"}
+
+ def test_append_failure(self, mock_config):
+ """Appending failure records ID on disk."""
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value={}
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config)
+ pipeline._append_failure("mp-fail1")
+ pipeline._append_failure("mp-fail2")
+
+ failures_path = os.path.join(mock_config.output_dir, ".failures.txt")
+ with open(failures_path, "r") as f:
+ lines = [line.strip() for line in f if line.strip()]
+
+ assert lines == ["mp-fail1", "mp-fail2"]
+
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client")
+ def test_resume_skips_failures(self, mock_get_client, mock_config, no_tools):
+ """Pipeline skips previously failed materials on resume."""
+ mock_client = MagicMock()
+ mock_get_client.return_value = mock_client
+
+ mock_paginator = MagicMock()
+ mock_client.get_paginator.return_value = mock_paginator
+ mock_paginator.paginate.return_value = [
+ {
+ "CommonPrefixes": [
+ {"Prefix": "mp-0/"},
+ {"Prefix": "mp-1/"},
+ {"Prefix": "mp-2/"},
+ ]
+ }
+ ]
+
+ # mp-0 already processed, mp-1 previously failed
+ checkpoint_path = os.path.join(mock_config.output_dir, ".checkpoint.txt")
+ with open(checkpoint_path, "w") as f:
+ f.write("mp-0\n")
+ failures_path = os.path.join(mock_config.output_dir, ".failures.txt")
+ with open(failures_path, "w") as f:
+ f.write("mp-1\n")
+
+ row = {col: None for col in PARQUET_COLUMNS}
+ row.update(
+ {
+ "elements": ["Si"],
+ "nsites": 1,
+ "chemical_formula_anonymous": "A",
+ "chemical_formula_reduced": "Si",
+ "chemical_formula_descriptive": "Si1",
+ "nelements": 1,
+ "dimension_types": [1, 1, 1],
+ "nperiodic_dimensions": 3,
+ "lattice_vectors": [[3, 0, 0], [0, 3, 0], [0, 0, 3]],
+ "immutable_id": "mp-2",
+ "cartesian_site_positions": [[0, 0, 0]],
+ "species": json.dumps([{"name": "Si"}]),
+ "species_at_sites": ["Si"],
+ "last_modified": datetime.now().isoformat(),
+ "elements_ratios": [1.0],
+ "functional": "pbe",
+ "cross_compatibility": True,
+ }
+ )
+
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value=no_tools
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config, debug=True)
+
+ processed_ids = []
+
+ def mock_process(material_id, config, tool_paths):
+ processed_ids.append(material_id)
+ r = dict(row)
+ r["immutable_id"] = material_id
+ return r
+
+ with patch.object(
+ LeMatRhoDirectPipeline, "_process_material", side_effect=mock_process
+ ):
+ pipeline.run()
+
+ # Only mp-2 should be processed (mp-0 checkpointed, mp-1 failed)
+ assert processed_ids == ["mp-2"]
+
+
+# ---------------------------------------------------------------------------
+# TestParquetWriting
+# ---------------------------------------------------------------------------
+
+
+class TestParquetWriting:
+ def _make_row(self, material_id="mp-1"):
+ """Create a minimal valid row dict matching PARQUET_COLUMNS."""
+ row = {col: None for col in PARQUET_COLUMNS}
+ row.update(
+ {
+ "elements": ["O", "Si"],
+ "nsites": 2,
+ "chemical_formula_anonymous": "AB",
+ "chemical_formula_reduced": "OSi",
+ "chemical_formula_descriptive": "Si1 O1",
+ "nelements": 2,
+ "dimension_types": [1, 1, 1],
+ "nperiodic_dimensions": 3,
+ "lattice_vectors": [[3.0, 0, 0], [0, 3.0, 0], [0, 0, 3.0]],
+ "immutable_id": material_id,
+ "cartesian_site_positions": [[0, 0, 0], [1.5, 1.5, 1.5]],
+ "species": json.dumps([{"name": "O"}, {"name": "Si"}]),
+ "species_at_sites": ["Si", "O"],
+ "last_modified": datetime.now().isoformat(),
+ "elements_ratios": [0.5, 0.5],
+ "functional": "pbe",
+ "cross_compatibility": True,
+ "charge_density_grid_shape": [10, 10, 10],
+ "compressed_charge_density": json.dumps([[[1.0]]]),
+ }
+ )
+ return row
+
+ def test_write_chunk(self, mock_config):
+ """Verify Parquet file is created with correct schema."""
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value={}
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config)
+ rows = [self._make_row(f"mp-{i}") for i in range(3)]
+ pipeline._write_parquet_chunk(rows, 0)
+
+ path = os.path.join(mock_config.output_dir, "chunk_000000.parquet")
+ assert os.path.exists(path)
+
+ table = pq.read_table(path)
+ assert table.num_rows == 3
+ assert set(table.column_names) == set(PARQUET_COLUMNS)
+
+ def test_atomic_write_no_tmp_file_remains(self, mock_config):
+ """After writing, no .tmp file should remain."""
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value={}
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config)
+ rows = [self._make_row()]
+ pipeline._write_parquet_chunk(rows, 0)
+
+ tmp_files = [
+ f
+ for f in os.listdir(mock_config.output_dir)
+ if f.endswith(".tmp")
+ ]
+ assert len(tmp_files) == 0
+
+ def test_chunk_index_resume(self, mock_config):
+ """Next chunk index should be max existing + 1."""
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value={}
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config)
+
+ # Write chunks 0, 1, 2
+ for i in range(3):
+ rows = [self._make_row(f"mp-{i}")]
+ pipeline._write_parquet_chunk(rows, i)
+
+ assert pipeline._get_next_chunk_index() == 3
+
+ def test_tmp_files_ignored_on_resume(self, mock_config):
+ """Stale .tmp files don't affect chunk indexing."""
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value={}
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config)
+
+ # Write one real chunk
+ rows = [self._make_row()]
+ pipeline._write_parquet_chunk(rows, 0)
+
+ # Create a stale .tmp file
+ tmp_path = os.path.join(
+ mock_config.output_dir, "chunk_000001.parquet.tmp"
+ )
+ with open(tmp_path, "w") as f:
+ f.write("stale")
+
+ assert pipeline._get_next_chunk_index() == 1
+
+ def test_chunk_index_empty_dir(self, mock_config):
+ """Empty output dir -> chunk index 0."""
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value={}
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config)
+ assert pipeline._get_next_chunk_index() == 0
+
+
+# ---------------------------------------------------------------------------
+# TestStructureToRow
+# ---------------------------------------------------------------------------
+
+
+class TestStructureToRow:
+ def test_all_columns_present(self):
+ """Row dict should have exactly the PARQUET_COLUMNS keys."""
+ optimade_dict = _make_mock_optimade_dict()
+ structure = OptimadeStructure(
+ id="mp-1",
+ source="lematrho",
+ immutable_id="mp-1",
+ last_modified=datetime.now(),
+ **optimade_dict,
+ functional=Functional.PBE,
+ cross_compatibility=True,
+ compute_space_group=True,
+ compute_bawl_hash=True,
+ )
+
+ row = _structure_to_row(structure)
+ assert set(row.keys()) == set(PARQUET_COLUMNS)
+
+ def test_species_json_serialized(self):
+ """Species field should be a JSON string."""
+ optimade_dict = _make_mock_optimade_dict()
+ structure = OptimadeStructure(
+ id="mp-1",
+ source="lematrho",
+ immutable_id="mp-1",
+ last_modified=datetime.now(),
+ **optimade_dict,
+ functional=Functional.PBE,
+ cross_compatibility=True,
+ compute_space_group=True,
+ compute_bawl_hash=True,
+ )
+
+ row = _structure_to_row(structure)
+ assert isinstance(row["species"], str)
+ parsed = json.loads(row["species"])
+ assert isinstance(parsed, list)
+
+ def test_functional_is_string(self):
+ """Functional enum should be converted to string value."""
+ optimade_dict = _make_mock_optimade_dict()
+ structure = OptimadeStructure(
+ id="mp-1",
+ source="lematrho",
+ immutable_id="mp-1",
+ last_modified=datetime.now(),
+ **optimade_dict,
+ functional=Functional.PBE,
+ cross_compatibility=True,
+ compute_space_group=True,
+ compute_bawl_hash=True,
+ )
+
+ row = _structure_to_row(structure)
+ assert row["functional"] == "pbe"
+
+ def test_last_modified_is_iso_string(self):
+ """last_modified should be an ISO format string."""
+ optimade_dict = _make_mock_optimade_dict()
+ now = datetime.now()
+ structure = OptimadeStructure(
+ id="mp-1",
+ source="lematrho",
+ immutable_id="mp-1",
+ last_modified=now,
+ **optimade_dict,
+ functional=Functional.PBE,
+ cross_compatibility=True,
+ compute_space_group=True,
+ compute_bawl_hash=True,
+ )
+
+ row = _structure_to_row(structure)
+        # The model validator normalizes last_modified to midnight (date only),
+        # so compare against the normalized structure.last_modified, not `now`
+ assert row["last_modified"] == structure.last_modified.isoformat()
+
+ def test_compressed_fields_json_serialized(self):
+ """Compressed charge density fields should be JSON strings when present."""
+ optimade_dict = _make_mock_optimade_dict()
+ grid = [[[1.0, 2.0], [3.0, 4.0]]]
+ structure = OptimadeStructure(
+ id="mp-1",
+ source="lematrho",
+ immutable_id="mp-1",
+ last_modified=datetime.now(),
+ **optimade_dict,
+ functional=Functional.PBE,
+ cross_compatibility=True,
+ compressed_charge_density=grid,
+ charge_density_grid_shape=[1, 2, 2],
+ compute_space_group=True,
+ compute_bawl_hash=True,
+ )
+
+ row = _structure_to_row(structure)
+ assert isinstance(row["compressed_charge_density"], str)
+ assert json.loads(row["compressed_charge_density"]) == grid
+
+
+# ---------------------------------------------------------------------------
+# TestRunIntegration
+# ---------------------------------------------------------------------------
+
+
+class TestRunIntegration:
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client")
+ def test_full_pipeline_debug_mode(
+ self, mock_get_client, mock_config, no_tools
+ ):
+ """Integration test: process 5 materials in debug mode, verify chunks + checkpoint."""
+ mock_client = MagicMock()
+ mock_get_client.return_value = mock_client
+
+ # Mock S3 listing
+ mock_paginator = MagicMock()
+ mock_client.get_paginator.return_value = mock_paginator
+ mock_paginator.paginate.return_value = [
+ {
+ "CommonPrefixes": [
+ {"Prefix": f"mp-{i}/"}
+ for i in range(5)
+ ]
+ }
+ ]
+
+ # Create mock results
+ mock_results = []
+ for i in range(5):
+ row = {col: None for col in PARQUET_COLUMNS}
+ row.update(
+ {
+ "elements": ["Si"],
+ "nsites": 1,
+ "chemical_formula_anonymous": "A",
+ "chemical_formula_reduced": "Si",
+ "chemical_formula_descriptive": "Si1",
+ "nelements": 1,
+ "dimension_types": [1, 1, 1],
+ "nperiodic_dimensions": 3,
+ "lattice_vectors": [[3, 0, 0], [0, 3, 0], [0, 0, 3]],
+ "immutable_id": f"mp-{i}",
+ "cartesian_site_positions": [[0, 0, 0]],
+ "species": json.dumps([{"name": "Si"}]),
+ "species_at_sites": ["Si"],
+ "last_modified": datetime.now().isoformat(),
+ "elements_ratios": [1.0],
+ "functional": "pbe",
+ "cross_compatibility": True,
+ "charge_density_grid_shape": [10, 10, 10],
+ }
+ )
+ mock_results.append(row)
+
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value=no_tools
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config, debug=True)
+
+            # Mock _process_material to return our pre-built rows, keyed by
+            # material ID so the test does not depend on processing order
+            rows_by_id = {r["immutable_id"]: r for r in mock_results}
+
+            def mock_process(material_id, config, tool_paths):
+                return rows_by_id[material_id]
+
+ with patch.object(
+ LeMatRhoDirectPipeline, "_process_material", side_effect=mock_process
+ ):
+ pipeline.run()
+
+ # With chunk_size=3 and 5 materials: should write 2 chunks (3 + 2)
+ parquet_files = sorted(
+ f
+ for f in os.listdir(mock_config.output_dir)
+ if f.endswith(".parquet")
+ )
+ assert len(parquet_files) == 2
+
+ # Check row counts
+ total_rows = 0
+ for f in parquet_files:
+ table = pq.read_table(os.path.join(mock_config.output_dir, f))
+ total_rows += table.num_rows
+ assert total_rows == 5
+
+ # Check checkpoint
+ checkpoint_path = os.path.join(mock_config.output_dir, ".checkpoint.txt")
+ with open(checkpoint_path) as f:
+ checkpoint_ids = {line.strip() for line in f if line.strip()}
+ assert checkpoint_ids == {f"mp-{i}" for i in range(5)}
+
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client")
+ def test_resume_skips_processed(self, mock_get_client, mock_config, no_tools):
+ """Pipeline resumes from checkpoint, skipping already-processed materials."""
+ mock_client = MagicMock()
+ mock_get_client.return_value = mock_client
+
+ mock_paginator = MagicMock()
+ mock_client.get_paginator.return_value = mock_paginator
+ mock_paginator.paginate.return_value = [
+ {
+ "CommonPrefixes": [
+ {"Prefix": "mp-0/"},
+ {"Prefix": "mp-1/"},
+ {"Prefix": "mp-2/"},
+ ]
+ }
+ ]
+
+ # Pre-existing checkpoint (mp-0 already done)
+ checkpoint_path = os.path.join(mock_config.output_dir, ".checkpoint.txt")
+ with open(checkpoint_path, "w") as f:
+ f.write("mp-0\n")
+
+ # Pre-existing Parquet chunk
+ row = {col: None for col in PARQUET_COLUMNS}
+ row.update(
+ {
+ "elements": ["Si"],
+ "nsites": 1,
+ "chemical_formula_anonymous": "A",
+ "chemical_formula_reduced": "Si",
+ "chemical_formula_descriptive": "Si1",
+ "nelements": 1,
+ "dimension_types": [1, 1, 1],
+ "nperiodic_dimensions": 3,
+ "lattice_vectors": [[3, 0, 0], [0, 3, 0], [0, 0, 3]],
+ "immutable_id": "mp-0",
+ "cartesian_site_positions": [[0, 0, 0]],
+ "species": json.dumps([{"name": "Si"}]),
+ "species_at_sites": ["Si"],
+ "last_modified": datetime.now().isoformat(),
+ "elements_ratios": [1.0],
+ "functional": "pbe",
+ "cross_compatibility": True,
+ }
+ )
+
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value=no_tools
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config, debug=True)
+
+ processed_ids = []
+
+ def mock_process(material_id, config, tool_paths):
+ processed_ids.append(material_id)
+ r = dict(row)
+ r["immutable_id"] = material_id
+ return r
+
+ with patch.object(
+ LeMatRhoDirectPipeline, "_process_material", side_effect=mock_process
+ ):
+ pipeline.run()
+
+ # Only mp-1 and mp-2 should be processed (mp-0 skipped)
+ assert "mp-0" not in processed_ids
+ assert set(processed_ids) == {"mp-1", "mp-2"}
+
+
+# ---------------------------------------------------------------------------
+# TestBaderFromBytes
+# ---------------------------------------------------------------------------
+
+
+class TestBaderFromBytes:
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.write_potcar")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.subprocess.run")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.read_potcar_zval")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.parse_acf_dat")
+ def test_happy_path(
+ self, mock_parse_acf, mock_read_zval, mock_subprocess, mock_potcar
+ ):
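+        """Bader helper returns charges (ZVAL minus electron count) and volumes."""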
+ from pymatgen.core import Lattice, Structure
+
+ structure = Structure(
+ Lattice.cubic(3.0), ["Si", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]]
+ )
+ raw_files = {
+ "CHGCAR": b"chgcar_data",
+ "AECCAR0": b"aeccar0_data",
+ "AECCAR2": b"aeccar2_data",
+ }
+ tools = {
+ "bader_path": "/usr/bin/bader",
+ "perl_path": "/usr/bin/perl",
+ "chgsum_script_path": "/opt/chgsum.pl",
+ }
+
+ mock_parse_acf.return_value = ([4.0, 6.0], [10.0, 12.0])
+ mock_read_zval.return_value = {"Si": 4.0, "O": 6.0}
+
+ charges, volumes = _run_bader_from_bytes(
+ structure, raw_files, tools, "mp-test"
+ )
+
+ assert charges == [0.0, 0.0] # valence - electron_count
+ assert volumes == [10.0, 12.0]
+ assert mock_subprocess.call_count == 2 # chgsum + bader
+
+ def test_subprocess_timeout(self):
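+        """Bader helper returns (None, None) when the subprocess times out."""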
+ from pymatgen.core import Lattice, Structure
+
+ structure = Structure(
+ Lattice.cubic(3.0), ["Si"], [[0, 0, 0]]
+ )
+ raw_files = {"CHGCAR": b"x", "AECCAR0": b"x", "AECCAR2": b"x"}
+ tools = {
+ "bader_path": "/usr/bin/bader",
+ "perl_path": "/usr/bin/perl",
+ "chgsum_script_path": "/opt/chgsum.pl",
+ }
+
+ with patch(
+ "lematerial_fetcher.fetcher.lematrho.pipeline.write_potcar"
+ ):
+ with patch(
+ "lematerial_fetcher.fetcher.lematrho.pipeline.subprocess.run",
+ side_effect=subprocess.TimeoutExpired("bader", 600),
+ ):
+ charges, volumes = _run_bader_from_bytes(
+ structure, raw_files, tools, "mp-test"
+ )
+
+ assert charges is None
+ assert volumes is None
+
+
+# ---------------------------------------------------------------------------
+# TestDdec6FromBytes
+# ---------------------------------------------------------------------------
+
+
+class TestDdec6FromBytes:
+ def test_subprocess_failure(self):
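+        """DDEC6 helper returns None when chargemol exits with an error."""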
+ from pymatgen.core import Lattice, Structure
+
+ structure = Structure(
+ Lattice.cubic(3.0), ["Si"], [[0, 0, 0]]
+ )
+ raw_files = {"CHGCAR": b"x"}
+ tools = {
+ "chargemol_path": "/usr/bin/chargemol",
+ "atomic_densities_path": "/opt/densities",
+ }
+
+ with patch(
+ "lematerial_fetcher.fetcher.lematrho.pipeline.write_potcar"
+ ):
+ with patch(
+ "lematerial_fetcher.fetcher.lematrho.pipeline.subprocess.run",
+ side_effect=subprocess.CalledProcessError(1, "chargemol"),
+ ):
+ result = _run_ddec6_from_bytes(
+ structure, raw_files, tools, "mp-test"
+ )
+
+ assert result is None
+
+ def test_happy_path(self):
+ """DDEC6 returns charges when subprocess succeeds."""
+ from pymatgen.core import Lattice, Structure
+
+ structure = Structure(
+ Lattice.cubic(3.0), ["Si", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]]
+ )
+ raw_files = {"CHGCAR": b"chgcar_data"}
+ tools = {
+ "chargemol_path": "/usr/bin/chargemol",
+ "atomic_densities_path": "/opt/densities",
+ }
+
+ with patch(
+ "lematerial_fetcher.fetcher.lematrho.pipeline.write_potcar"
+ ):
+ with patch(
+ "lematerial_fetcher.fetcher.lematrho.pipeline.subprocess.run"
+ ):
+ with patch(
+ "lematerial_fetcher.fetcher.lematrho.pipeline.parse_ddec6_charges",
+ return_value=[0.5, -0.5],
+ ):
+ result = _run_ddec6_from_bytes(
+ structure, raw_files, tools, "mp-test"
+ )
+
+ assert result == [0.5, -0.5]
+
+ def test_timeout(self):
+ """DDEC6 returns None on timeout."""
+ from pymatgen.core import Lattice, Structure
+
+ structure = Structure(
+ Lattice.cubic(3.0), ["Si"], [[0, 0, 0]]
+ )
+ raw_files = {"CHGCAR": b"x"}
+ tools = {
+ "chargemol_path": "/usr/bin/chargemol",
+ "atomic_densities_path": "/opt/densities",
+ }
+
+ with patch(
+ "lematerial_fetcher.fetcher.lematrho.pipeline.write_potcar"
+ ):
+ with patch(
+ "lematerial_fetcher.fetcher.lematrho.pipeline.subprocess.run",
+ side_effect=subprocess.TimeoutExpired("chargemol", 600),
+ ):
+ result = _run_ddec6_from_bytes(
+ structure, raw_files, tools, "mp-test"
+ )
+
+ assert result is None
+
+
+# ---------------------------------------------------------------------------
+# TestValidateTools
+# ---------------------------------------------------------------------------
+
+
+class TestValidateTools:
+ def test_all_tools_available(self, mock_config):
+ """All tools present -> can_run_bader and can_run_ddec6 are True."""
+ with patch("shutil.which", side_effect=lambda x: f"/usr/bin/{x}"):
+ with patch.dict(os.environ, {"PMG_VASP_PSP_DIR": "/opt/psp"}):
+ config = DirectPipelineConfig(
+ lematrho_bucket_name="test-bucket",
+ output_dir=mock_config.output_dir,
+ bader_path="/usr/bin/bader",
+ chargemol_path="/usr/bin/chargemol",
+ chgsum_script_path=__file__, # use this test file as a file that exists
+ atomic_densities_path=os.path.dirname(__file__), # dir that exists
+ )
+ pipeline = LeMatRhoDirectPipeline(config=config)
+
+ assert pipeline._tool_paths["can_run_bader"] is True
+ assert pipeline._tool_paths["can_run_ddec6"] is True
+ assert pipeline._tool_paths["can_generate_potcar"] is True
+
+ def test_no_tools_available(self, mock_config):
+ """No tools on PATH -> can_run_bader and can_run_ddec6 are False."""
+ with patch("shutil.which", return_value=None):
+ with patch.dict(os.environ, {}, clear=True):
+ config = DirectPipelineConfig(
+ lematrho_bucket_name="test-bucket",
+ output_dir=mock_config.output_dir,
+ )
+ pipeline = LeMatRhoDirectPipeline(config=config)
+
+ assert pipeline._tool_paths["can_run_bader"] is False
+ assert pipeline._tool_paths["can_run_ddec6"] is False
+ assert pipeline._tool_paths["can_generate_potcar"] is False
+
+ def test_bader_but_no_chgsum(self, mock_config):
+ """Bader on PATH but chgsum not set -> can_run_bader False."""
+ with patch("shutil.which", side_effect=lambda x: f"/usr/bin/{x}"):
+ with patch.dict(os.environ, {"PMG_VASP_PSP_DIR": "/opt/psp"}):
+ config = DirectPipelineConfig(
+ lematrho_bucket_name="test-bucket",
+ output_dir=mock_config.output_dir,
+ bader_path="/usr/bin/bader",
+ # no chgsum_script_path
+ )
+ pipeline = LeMatRhoDirectPipeline(config=config)
+
+ assert pipeline._tool_paths["can_run_bader"] is False
+ assert pipeline._tool_paths["bader_path"] == "/usr/bin/bader"
+
+ def test_missing_pmg_vasp_psp_dir(self, mock_config):
+ """No PMG_VASP_PSP_DIR -> can_generate_potcar False, both analyses disabled."""
+ with patch("shutil.which", side_effect=lambda x: f"/usr/bin/{x}"):
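+            # Rebuild the environment without PMG_VASP_PSP_DIR rather than
+            # clearing everything, so unrelated variables survive the patch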
+ env = os.environ.copy()
+ env.pop("PMG_VASP_PSP_DIR", None)
+ with patch.dict(os.environ, env, clear=True):
+ config = DirectPipelineConfig(
+ lematrho_bucket_name="test-bucket",
+ output_dir=mock_config.output_dir,
+ bader_path="/usr/bin/bader",
+ chargemol_path="/usr/bin/chargemol",
+ chgsum_script_path=__file__,
+ atomic_densities_path=os.path.dirname(__file__),
+ )
+ pipeline = LeMatRhoDirectPipeline(config=config)
+
+ assert pipeline._tool_paths["can_generate_potcar"] is False
+ assert pipeline._tool_paths["can_run_bader"] is False
+ assert pipeline._tool_paths["can_run_ddec6"] is False
+
+
+# ---------------------------------------------------------------------------
+# TestStructureToRowNoneFields
+# ---------------------------------------------------------------------------
+
+
+class TestStructureToRowNoneFields:
+ def test_all_charge_fields_none(self):
+ """Structure with no charge density fields -> all charge columns None."""
+ optimade_dict = _make_mock_optimade_dict()
+ structure = OptimadeStructure(
+ id="mp-1",
+ source="lematrho",
+ immutable_id="mp-1",
+ last_modified=datetime.now(),
+ **optimade_dict,
+ functional=Functional.PBE,
+ cross_compatibility=True,
+ compute_space_group=True,
+ compute_bawl_hash=True,
+ )
+
+ row = _structure_to_row(structure)
+ assert row["compressed_charge_density"] is None
+ assert row["compressed_aeccar0"] is None
+ assert row["compressed_aeccar1"] is None
+ assert row["compressed_aeccar2"] is None
+ assert row["charge_density_grid_shape"] is None
+ assert row["bader_charges"] is None
+ assert row["bader_atomic_volume"] is None
+ assert row["ddec6_charges"] is None
+
+ def test_partial_charge_fields(self):
+ """Structure with only some charge fields -> only those are populated."""
+ optimade_dict = _make_mock_optimade_dict()
+ structure = OptimadeStructure(
+ id="mp-1",
+ source="lematrho",
+ immutable_id="mp-1",
+ last_modified=datetime.now(),
+ **optimade_dict,
+ functional=Functional.PBE,
+ cross_compatibility=True,
+ compressed_charge_density=[[[1.0]]],
+ charge_density_grid_shape=[1, 1, 1],
+ bader_charges=[0.1, -0.1],
+ compute_space_group=True,
+ compute_bawl_hash=True,
+ )
+
+ row = _structure_to_row(structure)
+ assert row["compressed_charge_density"] is not None
+ assert row["compressed_aeccar0"] is None
+ assert row["bader_charges"] == [0.1, -0.1]
+ assert row["ddec6_charges"] is None
+
+
+# ---------------------------------------------------------------------------
+# TestPushToHuggingface
+# ---------------------------------------------------------------------------
+
+
+class TestPushToHuggingface:
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client")
+ def test_push_called_when_configured(self, mock_get_client, mock_config, no_tools):
+ """Pipeline calls push_to_hub when hf_repo_id is configured."""
+ mock_client = MagicMock()
+ mock_get_client.return_value = mock_client
+ mock_paginator = MagicMock()
+ mock_client.get_paginator.return_value = mock_paginator
+ mock_paginator.paginate.return_value = [
+ {"CommonPrefixes": [{"Prefix": "mp-1/"}]}
+ ]
+
+ config = DirectPipelineConfig(
+ lematrho_bucket_name="test-bucket",
+ output_dir=mock_config.output_dir,
+ hf_repo_id="test-org/test-dataset",
+ hf_token="hf_test_token",
+ )
+
+ mock_row = {col: None for col in PARQUET_COLUMNS}
+ mock_row.update({
+ "elements": ["Si"],
+ "nsites": 1,
+ "chemical_formula_anonymous": "A",
+ "chemical_formula_reduced": "Si",
+ "chemical_formula_descriptive": "Si1",
+ "nelements": 1,
+ "dimension_types": [1, 1, 1],
+ "nperiodic_dimensions": 3,
+ "lattice_vectors": [[3, 0, 0], [0, 3, 0], [0, 0, 3]],
+ "immutable_id": "mp-1",
+ "cartesian_site_positions": [[0, 0, 0]],
+ "species": json.dumps([{"name": "Si"}]),
+ "species_at_sites": ["Si"],
+ "last_modified": datetime.now().isoformat(),
+ "elements_ratios": [1.0],
+ "functional": "pbe",
+ "cross_compatibility": True,
+ })
+
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value=no_tools
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=config, debug=True)
+
+ mock_dataset = MagicMock()
+ with patch.object(
+ LeMatRhoDirectPipeline, "_process_material", return_value=mock_row
+ ):
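+                # run() reloads the written Parquet chunks via datasets.load_dataset
+                # and pushes the "train" split; stub it so no real I/O happens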
+ with patch(
+ "datasets.load_dataset",
+ return_value={"train": mock_dataset},
+ ):
+ pipeline.run()
+
+ mock_dataset.push_to_hub.assert_called_once_with(
+ "test-org/test-dataset",
+ token="hf_test_token",
+ private=True,
+ )
+
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client")
+ def test_push_not_called_without_repo_id(
+ self, mock_get_client, mock_config, no_tools
+ ):
+ """Pipeline skips push when hf_repo_id is None."""
+ mock_client = MagicMock()
+ mock_get_client.return_value = mock_client
+ mock_paginator = MagicMock()
+ mock_client.get_paginator.return_value = mock_paginator
+ mock_paginator.paginate.return_value = [
+ {"CommonPrefixes": [{"Prefix": "mp-1/"}]}
+ ]
+
+ mock_row = {col: None for col in PARQUET_COLUMNS}
+ mock_row.update({
+ "elements": ["Si"],
+ "nsites": 1,
+ "chemical_formula_anonymous": "A",
+ "chemical_formula_reduced": "Si",
+ "chemical_formula_descriptive": "Si1",
+ "nelements": 1,
+ "dimension_types": [1, 1, 1],
+ "nperiodic_dimensions": 3,
+ "lattice_vectors": [[3, 0, 0], [0, 3, 0], [0, 0, 3]],
+ "immutable_id": "mp-1",
+ "cartesian_site_positions": [[0, 0, 0]],
+ "species": json.dumps([{"name": "Si"}]),
+ "species_at_sites": ["Si"],
+ "last_modified": datetime.now().isoformat(),
+ "elements_ratios": [1.0],
+ "functional": "pbe",
+ "cross_compatibility": True,
+ })
+
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value=no_tools
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=mock_config, debug=True)
+
+ with patch.object(
+ LeMatRhoDirectPipeline, "_process_material", return_value=mock_row
+ ):
+ with patch.object(
+ LeMatRhoDirectPipeline, "_push_to_huggingface"
+ ) as mock_push:
+ pipeline.run()
+
+ mock_push.assert_not_called()
+
+
+# ---------------------------------------------------------------------------
+# TestProcessMaterialWithDdec6
+# ---------------------------------------------------------------------------
+
+
+class TestProcessMaterialWithDdec6:
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline._run_ddec6_from_bytes")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_optimade_from_pymatgen")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.compress_chgcar")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.parse_vasprun_structure")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.download_gz_file_from_s3")
+ @patch("lematerial_fetcher.fetcher.lematrho.pipeline.get_authenticated_aws_client")
+ def test_ddec6_populates_charges(
+ self,
+ mock_get_client,
+ mock_download,
+ mock_parse_vasprun,
+ mock_compress,
+ mock_get_optimade,
+ mock_ddec6,
+ mock_config,
+ ):
+ """When DDEC6 tools available and succeed, ddec6_charges is populated."""
+ from pymatgen.core import Lattice, Structure
+
+ mock_get_client.return_value = MagicMock()
+ mock_download.return_value = b"mock_bytes"
+
+ structure = Structure(
+ Lattice.cubic(3.0), ["Si", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]]
+ )
+ mock_parse_vasprun.return_value = structure
+ mock_compress.return_value = [[[1.0] * 10] * 10] * 10
+ mock_get_optimade.return_value = _make_mock_optimade_dict()
+ mock_ddec6.return_value = [0.3, -0.3]
+
+ tools = {
+ "bader_path": None,
+ "chargemol_path": "/usr/bin/chargemol",
+ "chgsum_script_path": None,
+ "perl_path": None,
+ "atomic_densities_path": "/opt/densities",
+ "can_generate_potcar": True,
+ "can_run_bader": False,
+ "can_run_ddec6": True,
+ }
+
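+        # _process_material is called unbound (no pipeline instance), which
+        # matches a staticmethod, presumably so worker processes can pickle it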
+ result = LeMatRhoDirectPipeline._process_material(
+ "mp-123", mock_config, tools
+ )
+
+ assert result is not None
+ assert result["ddec6_charges"] == [0.3, -0.3]
+ assert result["bader_charges"] is None
+ mock_ddec6.assert_called_once()
+
+
+# ---------------------------------------------------------------------------
+# TestIntegrationS3 (requires credentials, skipped in normal runs)
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.integration
+class TestIntegrationS3:
+ """Integration tests that pull real data from the LeMatRho S3 bucket.
+
+ To run these tests:
+ 1. Create a .env.integration file with AWS credentials:
+ AWS_ACCESS_KEY_ID=...
+ AWS_SECRET_ACCESS_KEY=...
+ AWS_DEFAULT_REGION=us-east-1
+ 2. Run: pytest -m integration tests/fetcher/lematrho/test_lematrho_pipeline.py
+ """
+
+ @pytest.fixture(autouse=True)
+ def _load_integration_env(self):
+ """Load .env.integration if available, skip otherwise."""
+ env_path = os.path.join(
+ os.path.dirname(__file__), "..", "..", "..", ".env.integration"
+ )
+ env_path = os.path.normpath(env_path)
+ if not os.path.exists(env_path):
+ pytest.skip(
+ ".env.integration not found — set AWS credentials to run integration tests"
+ )
+ from dotenv import load_dotenv
+
+ load_dotenv(env_path, override=True)
+
+ def test_list_materials_from_real_bucket(self):
+ """Verify we can list at least 1 material from the real S3 bucket."""
+ config = DirectPipelineConfig(
+ lematrho_bucket_name="lemat-rho",
+ output_dir=tempfile.mkdtemp(),
+ )
+ with patch.object(LeMatRhoDirectPipeline, "_validate_tools", return_value={}):
+ pipeline = LeMatRhoDirectPipeline(config=config)
+ materials = pipeline._list_materials()
+ assert len(materials) > 0
+ # All should start with valid prefixes
+ for m in materials[:10]:
+ assert m.startswith(("oqmd-", "mp-", "agm"))
+
+ def test_process_single_material(self):
+ """Fetch and process a single real material end-to-end (no Bader/DDEC6)."""
+ output_dir = tempfile.mkdtemp()
+ config = DirectPipelineConfig(
+ lematrho_bucket_name="lemat-rho",
+ lematrho_grid_shape=(10, 10, 10),
+ output_dir=output_dir,
+ )
+ no_tools = {
+ "bader_path": None,
+ "chargemol_path": None,
+ "chgsum_script_path": None,
+ "perl_path": None,
+ "atomic_densities_path": None,
+ "can_generate_potcar": False,
+ "can_run_bader": False,
+ "can_run_ddec6": False,
+ }
+ with patch.object(
+ LeMatRhoDirectPipeline, "_validate_tools", return_value=no_tools
+ ):
+ pipeline = LeMatRhoDirectPipeline(config=config)
+
+ materials = pipeline._list_materials()
+ assert len(materials) > 0
+ material_id = materials[0]
+
+ result = LeMatRhoDirectPipeline._process_material(
+ material_id, config, no_tools
+ )
+ assert result is not None
+ assert result["immutable_id"] == material_id
+ assert result["functional"] == "pbe"
+ for col in PARQUET_COLUMNS:
+ assert col in result
diff --git a/tests/fetcher/lematrho/test_lematrho_transform.py b/tests/fetcher/lematrho/test_lematrho_transform.py
new file mode 100644
index 0000000..eb3b460
--- /dev/null
+++ b/tests/fetcher/lematrho/test_lematrho_transform.py
@@ -0,0 +1,728 @@
+# Copyright 2025 Entalpic
+import datetime
+import os
+import subprocess
+import tempfile
+from unittest.mock import MagicMock, patch
+
+import pytest
+from pymatgen.core import Lattice, Structure
+
+from lematerial_fetcher.fetcher.lematrho.transform import (
+ LeMatRhoTransformer,
+ get_cross_compatibility,
+ parse_acf_dat,
+ parse_ddec6_charges,
+ read_potcar_zval,
+)
+from lematerial_fetcher.models.models import RawStructure
+from lematerial_fetcher.models.optimade import Functional
+from lematerial_fetcher.utils.config import TransformerConfig
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture
+def si_structure():
+ """A simple Si diamond structure."""
+ return Structure(
+ Lattice.cubic(5.43),
+ ["Si", "Si"],
+ [[0, 0, 0], [0.25, 0.25, 0.25]],
+ )
+
+
+@pytest.fixture
+def yb_structure():
+ """A structure containing Yb (not cross-compatible)."""
+ return Structure(
+ Lattice.cubic(5.0),
+ ["Yb", "O"],
+ [[0, 0, 0], [0.5, 0.5, 0.5]],
+ )
+
+
+@pytest.fixture
+def raw_structure(si_structure):
+ """Raw structure from the fetch step with charge density data."""
+ return RawStructure(
+ id="agm000001",
+ type="lematrho",
+ attributes={
+ "structure": si_structure.as_dict(),
+ "compressed_charge_density": [[[1.0, 2.0]]],
+ "compressed_aeccar0": [[[0.1]]],
+ "compressed_aeccar1": [[[0.01]]],
+ "compressed_aeccar2": [[[0.3]]],
+ "grid_shape": [15, 15, 15],
+ "s3_prefix": "agm000001",
+ },
+ last_modified=datetime.datetime(2025, 1, 1),
+ )
+
+
+@pytest.fixture
+def raw_structure_yb(yb_structure):
+ """Raw structure containing Yb."""
+ return RawStructure(
+ id="agm000002",
+ type="lematrho",
+ attributes={
+ "structure": yb_structure.as_dict(),
+ "compressed_charge_density": [[[1.0]]],
+ "compressed_aeccar0": None,
+ "compressed_aeccar1": None,
+ "compressed_aeccar2": None,
+ "grid_shape": [15, 15, 15],
+ "s3_prefix": "agm000002",
+ },
+ last_modified=datetime.datetime(2025, 1, 1),
+ )
+
+
+@pytest.fixture
+def mock_config():
+ """TransformerConfig for testing."""
+ return TransformerConfig(
+ source_db_conn_str="mock://source",
+ dest_db_conn_str="mock://dest",
+ source_table_name="test_source",
+ dest_table_name="test_dest",
+ batch_size=100,
+ page_offset=0,
+ log_every=100,
+ log_dir="./logs",
+ max_retries=3,
+ page_limit=10,
+ num_workers=1,
+ retry_delay=2,
+ lematrho_bucket_name="lemat-rho",
+ bader_path="/usr/bin/bader",
+ chargemol_path="/usr/bin/chargemol",
+ chgsum_script_path="/usr/bin/chgsum.pl",
+ atomic_densities_path="/path/to/atomic_densities",
+ )
+
+
+@pytest.fixture
+def transformer_with_tools(mock_config):
+ """Transformer with all external tools available (mocked)."""
+ with patch.object(LeMatRhoTransformer, "_validate_tools"):
+ transformer = LeMatRhoTransformer(config=mock_config, debug=True)
+ transformer._bader_path = "/usr/bin/bader"
+ transformer._chargemol_path = "/usr/bin/chargemol"
+ transformer._chgsum_script_path = "/usr/bin/chgsum.pl"
+ transformer._perl_path = "/usr/bin/perl"
+ transformer._atomic_densities_path = "/path/to/atomic_densities"
+ transformer._can_generate_potcar = True
+ return transformer
+
+
+@pytest.fixture
+def transformer_no_tools(mock_config):
+ """Transformer with no external tools available."""
+ with patch.object(LeMatRhoTransformer, "_validate_tools"):
+ transformer = LeMatRhoTransformer(config=mock_config, debug=True)
+        # _validate_tools is patched out above, so null every tool path
+        # explicitly to simulate an environment with no external tools
+ transformer._bader_path = None
+ transformer._chargemol_path = None
+ transformer._chgsum_script_path = None
+ transformer._perl_path = None
+ transformer._atomic_densities_path = None
+ transformer._can_generate_potcar = False
+ return transformer
+
+
+# ---------------------------------------------------------------------------
+# transform_row tests
+# ---------------------------------------------------------------------------
+
+
+class TestTransformRow:
+ def test_happy_path(self, transformer_with_tools, raw_structure):
+ """Full transform with Bader and DDEC6 results populated."""
+ with (
+ patch.object(
+ transformer_with_tools, "_run_bader_analysis"
+ ) as mock_bader,
+ patch.object(
+ transformer_with_tools, "_run_ddec6_analysis"
+ ) as mock_ddec6,
+ ):
+ mock_bader.return_value = ([0.5, -0.5], [10.0, 12.0])
+ mock_ddec6.return_value = [0.3, -0.3]
+
+ result = transformer_with_tools.transform_row(raw_structure)
+
+ assert len(result) == 1
+ s = result[0]
+ assert s.id == "agm000001"
+ assert s.source == "lematrho"
+ assert s.immutable_id == "agm000001"
+ assert s.functional == Functional.PBE
+ assert s.cross_compatibility is True
+ assert s.bader_charges == [0.5, -0.5]
+ assert s.bader_atomic_volume == [10.0, 12.0]
+ assert s.ddec6_charges == [0.3, -0.3]
+ assert s.compressed_charge_density == [[[1.0, 2.0]]]
+ assert s.compressed_aeccar0 == [[[0.1]]]
+ assert s.compressed_aeccar1 == [[[0.01]]]
+ assert s.compressed_aeccar2 == [[[0.3]]]
+ assert s.charge_density_grid_shape == [15, 15, 15]
+ assert s.space_group_it_number is not None
+
+ def test_bader_fails_ddec6_succeeds(self, transformer_with_tools, raw_structure):
+ """When Bader fails, DDEC6 should still succeed independently."""
+ with (
+ patch.object(
+ transformer_with_tools, "_run_bader_analysis"
+ ) as mock_bader,
+ patch.object(
+ transformer_with_tools, "_run_ddec6_analysis"
+ ) as mock_ddec6,
+ ):
+ mock_bader.return_value = (None, None)
+ mock_ddec6.return_value = [0.3, -0.3]
+
+ result = transformer_with_tools.transform_row(raw_structure)
+
+ s = result[0]
+ assert s.bader_charges is None
+ assert s.bader_atomic_volume is None
+ assert s.ddec6_charges == [0.3, -0.3]
+
+ def test_both_analyses_fail(self, transformer_with_tools, raw_structure):
+ """When both Bader and DDEC6 fail, structure should still be created."""
+ with (
+ patch.object(
+ transformer_with_tools, "_run_bader_analysis"
+ ) as mock_bader,
+ patch.object(
+ transformer_with_tools, "_run_ddec6_analysis"
+ ) as mock_ddec6,
+ ):
+ mock_bader.return_value = (None, None)
+ mock_ddec6.return_value = None
+
+ result = transformer_with_tools.transform_row(raw_structure)
+
+ s = result[0]
+ assert s.bader_charges is None
+ assert s.bader_atomic_volume is None
+ assert s.ddec6_charges is None
+ # Compressed grids should still be present
+ assert s.compressed_charge_density == [[[1.0, 2.0]]]
+
+    def test_no_tools_skips_analyses(self, transformer_no_tools, raw_structure):
+        """With no tools available, neither Bader nor DDEC6 analysis is invoked."""
+ with (
+ patch.object(
+ transformer_no_tools, "_run_bader_analysis"
+ ) as mock_bader,
+ patch.object(
+ transformer_no_tools, "_run_ddec6_analysis"
+ ) as mock_ddec6,
+ ):
+ result = transformer_no_tools.transform_row(raw_structure)
+
+ mock_bader.assert_not_called()
+ mock_ddec6.assert_not_called()
+ s = result[0]
+ assert s.bader_charges is None
+ assert s.ddec6_charges is None
+
+ def test_functional_is_pbe(self, transformer_no_tools, raw_structure):
+ """Functional should always be PBE for LeMatRho (MP settings)."""
+ result = transformer_no_tools.transform_row(raw_structure)
+ assert result[0].functional == Functional.PBE
+
+ def test_cross_compatibility_excludes_yb(
+ self, transformer_no_tools, raw_structure_yb
+ ):
+ """Yb-containing structures should not be cross-compatible."""
+ result = transformer_no_tools.transform_row(raw_structure_yb)
+ assert result[0].cross_compatibility is False
+
+ def test_cross_compatibility_normal(self, transformer_no_tools, raw_structure):
+ """Non-Yb structures should be cross-compatible."""
+ result = transformer_no_tools.transform_row(raw_structure)
+ assert result[0].cross_compatibility is True
+
+ def test_missing_s3_prefix_skips_analyses(
+ self, transformer_with_tools, si_structure
+ ):
+ """If s3_prefix is missing from attributes, skip Bader and DDEC6."""
+ raw = RawStructure(
+ id="agm000003",
+ type="lematrho",
+ attributes={
+ "structure": si_structure.as_dict(),
+ "compressed_charge_density": None,
+ "grid_shape": [15, 15, 15],
+ # No s3_prefix
+ },
+ last_modified=datetime.datetime(2025, 1, 1),
+ )
+ with (
+ patch.object(
+ transformer_with_tools, "_run_bader_analysis"
+ ) as mock_bader,
+ patch.object(
+ transformer_with_tools, "_run_ddec6_analysis"
+ ) as mock_ddec6,
+ ):
+ result = transformer_with_tools.transform_row(raw)
+
+ mock_bader.assert_not_called()
+ mock_ddec6.assert_not_called()
+
+
+# ---------------------------------------------------------------------------
+# _run_bader_analysis tests
+# ---------------------------------------------------------------------------
+
+
+class TestRunBaderAnalysis:
+ def test_subprocess_timeout(self, transformer_with_tools, si_structure):
+ """Bader analysis should return (None, None) on subprocess timeout."""
+ with (
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.download_gz_file_from_s3"
+ ) as mock_dl,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.subprocess.run"
+ ) as mock_run,
+ patch("lematerial_fetcher.fetcher.lematrho.transform.write_potcar"),
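+            # aws_client is a read-only property, so patch it on the class
+            # with a stub property rather than assigning on the instance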
+ patch.object(
+ type(transformer_with_tools),
+ "aws_client",
+ new_callable=lambda: property(lambda self: MagicMock()),
+ ),
+ ):
+ mock_dl.return_value = b"fake chgcar data"
+ mock_run.side_effect = subprocess.TimeoutExpired("bader", 600)
+
+ charges, volumes = transformer_with_tools._run_bader_analysis(
+ si_structure, "agm000001", "agm000001"
+ )
+
+ assert charges is None
+ assert volumes is None
+
+ def test_chgsum_nonzero_exit(self, transformer_with_tools, si_structure):
+ """chgsum.pl returning non-zero exit should be handled gracefully."""
+ with (
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.download_gz_file_from_s3"
+ ) as mock_dl,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.subprocess.run"
+ ) as mock_run,
+ patch("lematerial_fetcher.fetcher.lematrho.transform.write_potcar"),
+ patch.object(
+ type(transformer_with_tools),
+ "aws_client",
+ new_callable=lambda: property(lambda self: MagicMock()),
+ ),
+ ):
+ mock_dl.return_value = b"fake data"
+ mock_run.side_effect = subprocess.CalledProcessError(
+ 1, "perl chgsum.pl", stderr=b"chgsum error"
+ )
+
+ charges, volumes = transformer_with_tools._run_bader_analysis(
+ si_structure, "agm000001", "agm000001"
+ )
+
+ assert charges is None
+ assert volumes is None
+
+ def test_s3_download_failure(self, transformer_with_tools, si_structure):
+ """S3 download failure should be handled gracefully."""
+ with (
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.download_gz_file_from_s3"
+ ) as mock_dl,
+ patch.object(
+ type(transformer_with_tools),
+ "aws_client",
+ new_callable=lambda: property(lambda self: MagicMock()),
+ ),
+ ):
+ mock_dl.side_effect = Exception("NoSuchKey")
+
+ charges, volumes = transformer_with_tools._run_bader_analysis(
+ si_structure, "agm000001", "agm000001"
+ )
+
+ assert charges is None
+ assert volumes is None
+
+ def test_potcar_generation_failure(self, transformer_with_tools, si_structure):
+ """POTCAR generation failure should be handled gracefully."""
+ with (
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.download_gz_file_from_s3"
+ ) as mock_dl,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.write_potcar",
+ side_effect=Exception("No PSP"),
+ ),
+ patch.object(
+ type(transformer_with_tools),
+ "aws_client",
+ new_callable=lambda: property(lambda self: MagicMock()),
+ ),
+ ):
+ mock_dl.return_value = b"fake data"
+
+ charges, volumes = transformer_with_tools._run_bader_analysis(
+ si_structure, "agm000001", "agm000001"
+ )
+
+ assert charges is None
+ assert volumes is None
+
+
+# ---------------------------------------------------------------------------
+# _run_ddec6_analysis tests
+# ---------------------------------------------------------------------------
+
+
+class TestRunDdec6Analysis:
+ def test_subprocess_timeout(self, transformer_with_tools, si_structure):
+ """DDEC6 analysis should return None on subprocess timeout."""
+ with (
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.download_gz_file_from_s3"
+ ) as mock_dl,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.subprocess.run"
+ ) as mock_run,
+ patch("lematerial_fetcher.fetcher.lematrho.transform.write_potcar"),
+ patch.object(transformer_with_tools, "_write_chargemol_config"),
+ patch.object(
+ type(transformer_with_tools),
+ "aws_client",
+ new_callable=lambda: property(lambda self: MagicMock()),
+ ),
+ ):
+ mock_dl.return_value = b"fake chgcar data"
+ mock_run.side_effect = subprocess.TimeoutExpired("chargemol", 600)
+
+ result = transformer_with_tools._run_ddec6_analysis(
+ si_structure, "agm000001", "agm000001"
+ )
+
+ assert result is None
+
+ def test_chargemol_nonzero_exit(self, transformer_with_tools, si_structure):
+ """chargemol returning non-zero exit should be handled gracefully."""
+ with (
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.download_gz_file_from_s3"
+ ) as mock_dl,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.subprocess.run"
+ ) as mock_run,
+ patch("lematerial_fetcher.fetcher.lematrho.transform.write_potcar"),
+ patch.object(transformer_with_tools, "_write_chargemol_config"),
+ patch.object(
+ type(transformer_with_tools),
+ "aws_client",
+ new_callable=lambda: property(lambda self: MagicMock()),
+ ),
+ ):
+ mock_dl.return_value = b"fake data"
+ mock_run.side_effect = subprocess.CalledProcessError(
+ 1, "chargemol", stderr=b"chargemol error"
+ )
+
+ result = transformer_with_tools._run_ddec6_analysis(
+ si_structure, "agm000001", "agm000001"
+ )
+
+ assert result is None
+
+
+# ---------------------------------------------------------------------------
+# Temp directory cleanup tests
+# ---------------------------------------------------------------------------
+
+
+class TestTempDirectoryCleanup:
+ def test_cleanup_on_success(self, transformer_with_tools, si_structure):
+ """Temp directory should be cleaned up after successful analysis."""
+ created_tmpdir = [None]
+
+ original_tempdir = tempfile.TemporaryDirectory
+
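+        # Wrap TemporaryDirectory to record the created path so we can assert
+        # afterwards that the context manager removed it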
+ class TrackingTempDir:
+ def __init__(self, *args, **kwargs):
+ self._real = original_tempdir(*args, **kwargs)
+ created_tmpdir[0] = self._real.name
+
+ def __enter__(self):
+ return self._real.__enter__()
+
+ def __exit__(self, *args):
+ return self._real.__exit__(*args)
+
+ with (
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.download_gz_file_from_s3"
+ ) as mock_dl,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.subprocess.run"
+ ),
+ patch("lematerial_fetcher.fetcher.lematrho.transform.write_potcar"),
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.parse_acf_dat"
+ ) as mock_parse,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.read_potcar_zval"
+ ) as mock_zval,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.tempfile.TemporaryDirectory",
+ TrackingTempDir,
+ ),
+ patch.object(
+ type(transformer_with_tools),
+ "aws_client",
+ new_callable=lambda: property(lambda self: MagicMock()),
+ ),
+ ):
+ mock_dl.return_value = b"fake data"
+ mock_parse.return_value = ([4.0, 4.0], [10.0, 12.0])
+ mock_zval.return_value = {"Si": 4.0}
+
+ transformer_with_tools._run_bader_analysis(
+ si_structure, "agm000001", "agm000001"
+ )
+
+ assert created_tmpdir[0] is not None
+ assert not os.path.exists(created_tmpdir[0])
+
+ def test_cleanup_on_failure(self, transformer_with_tools, si_structure):
+ """Temp directory should be cleaned up even on failure."""
+ created_tmpdir = [None]
+
+ original_tempdir = tempfile.TemporaryDirectory
+
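+        # Same tracking wrapper as above; the S3 download raises here, so the
+        # temp directory must still be removed on the failure path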
+ class TrackingTempDir:
+ def __init__(self, *args, **kwargs):
+ self._real = original_tempdir(*args, **kwargs)
+ created_tmpdir[0] = self._real.name
+
+ def __enter__(self):
+ return self._real.__enter__()
+
+ def __exit__(self, *args):
+ return self._real.__exit__(*args)
+
+ with (
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.download_gz_file_from_s3"
+ ) as mock_dl,
+ patch(
+ "lematerial_fetcher.fetcher.lematrho.transform.tempfile.TemporaryDirectory",
+ TrackingTempDir,
+ ),
+ patch.object(
+ type(transformer_with_tools),
+ "aws_client",
+ new_callable=lambda: property(lambda self: MagicMock()),
+ ),
+ ):
+ mock_dl.side_effect = Exception("S3 error")
+
+ transformer_with_tools._run_bader_analysis(
+ si_structure, "agm000001", "agm000001"
+ )
+
+ assert created_tmpdir[0] is not None
+ assert not os.path.exists(created_tmpdir[0])
+
+
+# ---------------------------------------------------------------------------
+# Parsing function unit tests
+# ---------------------------------------------------------------------------
+
+
+class TestParseAcfDat:
+ def test_parse_standard_format(self, tmp_path):
+ """Parse a standard ACF.dat file."""
+ acf_content = (
+ " # X Y Z CHARGE MIN DIST ATOMIC VOL\n"
+ " -----------------------------------------------------------------------\n"
+ " 1 0.000000 0.000000 0.000000 3.112903 1.235698 19.382956\n"
+ " 2 1.357500 1.357500 1.357500 4.887097 1.235698 13.617044\n"
+ " -----------------------------------------------------------------------\n"
+ " VACUUM CHARGE: 0.0000\n"
+ " VACUUM VOLUME: 0.0000\n"
+ " NUMBER OF ELECTRONS: 8.0000\n"
+ )
+ acf_file = tmp_path / "ACF.dat"
+ acf_file.write_text(acf_content)
+
+ counts, volumes = parse_acf_dat(str(acf_file))
+
+ assert len(counts) == 2
+ assert len(volumes) == 2
+ assert abs(counts[0] - 3.112903) < 1e-6
+ assert abs(counts[1] - 4.887097) < 1e-6
+ assert abs(volumes[0] - 19.382956) < 1e-6
+ assert abs(volumes[1] - 13.617044) < 1e-6
+
+ def test_parse_single_atom(self, tmp_path):
+ """Parse ACF.dat with a single atom."""
+ acf_content = (
+ " # X Y Z CHARGE MIN DIST ATOMIC VOL\n"
+ " -----------------------------------------------------------------------\n"
+ " 1 0.000000 0.000000 0.000000 8.000000 2.715000 40.000000\n"
+ " -----------------------------------------------------------------------\n"
+ )
+ acf_file = tmp_path / "ACF.dat"
+ acf_file.write_text(acf_content)
+
+ counts, volumes = parse_acf_dat(str(acf_file))
+
+ assert counts == [8.0]
+ assert volumes == [40.0]
+
+
+class TestReadPotcarZval:
+ def test_parse_potcar(self, tmp_path):
+ """Parse POTCAR for valence electron counts."""
+ potcar_content = (
+ " TITEL = PAW_PBE Si 05Jan2001\n"
+ " ZVAL = 4.00000\n"
+ " END of PSCTR\n"
+ " TITEL = PAW_PBE O 08Apr2002\n"
+ " ZVAL = 6.00000\n"
+ " END of PSCTR\n"
+ )
+ potcar_file = tmp_path / "POTCAR"
+ potcar_file.write_text(potcar_content)
+
+ zval = read_potcar_zval(str(potcar_file))
+
+ assert zval["Si"] == 4.0
+ assert zval["O"] == 6.0
+
+ def test_parse_potcar_with_underscore_element(self, tmp_path):
+ """Handle element names like Si_d in POTCAR."""
+ potcar_content = (
+ " TITEL = PAW_PBE Si_d 05Jan2001\n"
+ " ZVAL = 4.00000\n"
+ " END of PSCTR\n"
+ )
+ potcar_file = tmp_path / "POTCAR"
+ potcar_file.write_text(potcar_content)
+
+ zval = read_potcar_zval(str(potcar_file))
+
+ assert zval["Si"] == 4.0
+
+
+class TestParseDdec6Charges:
+ def test_parse_standard_output(self, tmp_path):
+ """Parse DDEC6 output file."""
+ ddec6_content = (
+ " 2\n"
+ " Charge analysis\n"
+ " Si 0.000000 0.000000 0.000000 0.123456\n"
+ " Si 1.357500 1.357500 1.357500 -0.123456\n"
+ )
+ ddec6_file = tmp_path / "DDEC6_even_tempered_net_atomic_charges.xyz"
+ ddec6_file.write_text(ddec6_content)
+
+ charges = parse_ddec6_charges(str(tmp_path))
+
+ assert len(charges) == 2
+ assert abs(charges[0] - 0.123456) < 1e-6
+ assert abs(charges[1] - (-0.123456)) < 1e-6
+
+
+# ---------------------------------------------------------------------------
+# Cross-compatibility function tests
+# ---------------------------------------------------------------------------
+
+
+class TestGetCrossCompatibility:
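+    """Yb-containing structures are flagged as not cross-compatible."""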
+ def test_normal_elements(self):
+ assert get_cross_compatibility(["Si", "O"]) is True
+
+ def test_yb_excluded(self):
+ assert get_cross_compatibility(["Yb", "O"]) is False
+
+ def test_yb_in_larger_set(self):
+ assert get_cross_compatibility(["Fe", "Yb", "O"]) is False
+
+ def test_empty_elements(self):
+ assert get_cross_compatibility([]) is True
+
+
+# ---------------------------------------------------------------------------
+# Tool validation tests
+# ---------------------------------------------------------------------------
+
+
+class TestValidateTools:
+ def test_all_tools_available(self, mock_config):
+ """When all tools are found, can_run_bader and can_run_ddec6 should be True."""
+ with (
+ patch("shutil.which", return_value="/usr/bin/tool"),
+ patch("os.path.isfile", return_value=True),
+ patch("os.path.isdir", return_value=True),
+ patch.dict(os.environ, {"PMG_VASP_PSP_DIR": "/path/to/psp"}),
+ ):
+ transformer = LeMatRhoTransformer(config=mock_config, debug=True)
+
+ assert transformer.can_run_bader is True
+ assert transformer.can_run_ddec6 is True
+
+ def test_no_tools_available(self):
+ """When no tools are found, can_run_bader and can_run_ddec6 should be False."""
+ config = TransformerConfig(
+ source_db_conn_str="mock://source",
+ dest_db_conn_str="mock://dest",
+ source_table_name="test_source",
+ dest_table_name="test_dest",
+ batch_size=100,
+ page_offset=0,
+ log_every=100,
+ log_dir="./logs",
+ max_retries=3,
+ page_limit=10,
+ num_workers=1,
+ retry_delay=2,
+ # No tool paths set
+ )
+ with (
+ patch("shutil.which", return_value=None),
+ patch.dict(os.environ, {}, clear=True),
+ ):
+ transformer = LeMatRhoTransformer(config=config, debug=True)
+
+ assert transformer.can_run_bader is False
+ assert transformer.can_run_ddec6 is False
+
+ def test_no_pmg_vasp_psp_dir(self, mock_config):
+ """Without PMG_VASP_PSP_DIR, both analyses should be disabled."""
+ with (
+ patch("shutil.which", return_value="/usr/bin/tool"),
+ patch("os.path.isfile", return_value=True),
+ patch("os.path.isdir", return_value=True),
+ patch.dict(os.environ, {}, clear=True),
+ ):
+ transformer = LeMatRhoTransformer(config=mock_config, debug=True)
+
+ assert transformer._can_generate_potcar is False
+ assert transformer.can_run_bader is False
+ assert transformer.can_run_ddec6 is False
diff --git a/tests/models/test_optimade_model.py b/tests/models/test_optimade_model.py
index 39ecf53..bcfcca9 100644
--- a/tests/models/test_optimade_model.py
+++ b/tests/models/test_optimade_model.py
@@ -3,7 +3,9 @@
import pytest
+from lematerial_fetcher.database.postgres import OptimadeDatabase, TrajectoriesDatabase
from lematerial_fetcher.models.optimade import Functional, OptimadeStructure
+from lematerial_fetcher.models.utils.enums import Source
# Test data for a valid structure
VALID_STRUCTURE_DATA = {
@@ -249,3 +251,129 @@ def test_functional_enum():
data["functional"] = func
structure = OptimadeStructure(**data)
assert structure.functional == func
+
+
+# -------------------------------------------------------------------
+# LeMatRho charge density field tests
+# -------------------------------------------------------------------
+
+
+def test_source_lematrho_is_valid():
+ """Test that LEMATRHO is a valid Source enum value."""
+ assert Source.LEMATRHO == "lematrho"
+ data = VALID_STRUCTURE_DATA.copy()
+ data["source"] = "lematrho"
+ structure = OptimadeStructure(**data)
+ assert structure.source == Source.LEMATRHO
+
+
+def test_charge_density_none_fields():
+ """Regression: existing VALID_STRUCTURE_DATA still passes with new optional fields as None."""
+ structure = OptimadeStructure(**VALID_STRUCTURE_DATA)
+ assert structure.compressed_charge_density is None
+ assert structure.compressed_aeccar0 is None
+ assert structure.compressed_aeccar1 is None
+ assert structure.compressed_aeccar2 is None
+ assert structure.charge_density_grid_shape is None
+ assert structure.bader_charges is None
+ assert structure.bader_atomic_volume is None
+ assert structure.ddec6_charges is None
+
+
+def test_structure_with_charge_density_fields():
+ """Test structure creation with all charge density fields populated."""
+ data = VALID_STRUCTURE_DATA.copy()
+    # nsites=2, so each per-site list (bader_charges, bader_atomic_volume,
+    # ddec6_charges) must have length 2
+ data.update(
+ {
+ "compressed_charge_density": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],
+ "compressed_aeccar0": [[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]],
+ "compressed_aeccar1": [[[0.01, 0.02]]],
+ "compressed_aeccar2": [[[0.01, 0.02]]],
+ "charge_density_grid_shape": [2, 2, 2],
+ "bader_charges": [1.5, -1.5],
+ "bader_atomic_volume": [10.0, 12.0],
+ "ddec6_charges": [0.8, -0.8],
+ }
+ )
+ structure = OptimadeStructure(**data)
+ assert structure.charge_density_grid_shape == [2, 2, 2]
+ assert structure.bader_charges == [1.5, -1.5]
+ assert structure.bader_atomic_volume == [10.0, 12.0]
+ assert structure.ddec6_charges == [0.8, -0.8]
+ assert structure.compressed_charge_density is not None
+
+
+def test_bader_charges_wrong_length():
+ """Test that bader_charges with wrong length raises ValueError."""
+ data = VALID_STRUCTURE_DATA.copy()
+ data["bader_charges"] = [1.0, 2.0, 3.0] # nsites=2, but 3 values
+ with pytest.raises(ValueError, match="bader_charges"):
+ OptimadeStructure(**data)
+
+
+def test_ddec6_charges_wrong_length():
+ """Test that ddec6_charges with wrong length raises ValueError."""
+ data = VALID_STRUCTURE_DATA.copy()
+ data["ddec6_charges"] = [1.0] # nsites=2, but 1 value
+ with pytest.raises(ValueError, match="ddec6_charges"):
+ OptimadeStructure(**data)
+
+
+def test_bader_atomic_volume_wrong_length():
+ """Test that bader_atomic_volume with wrong length raises ValueError."""
+ data = VALID_STRUCTURE_DATA.copy()
+ data["bader_atomic_volume"] = [10.0, 12.0, 14.0] # nsites=2, but 3 values
+ with pytest.raises(ValueError, match="bader_atomic_volume"):
+ OptimadeStructure(**data)
+
+
+def test_charge_density_grid_shape_validation():
+ """Test that charge_density_grid_shape must be exactly 3 elements."""
+ data = VALID_STRUCTURE_DATA.copy()
+
+ # Too short
+ data["charge_density_grid_shape"] = [15, 15]
+ with pytest.raises(ValueError):
+ OptimadeStructure(**data)
+
+ # Too long
+ data["charge_density_grid_shape"] = [15, 15, 15, 15]
+ with pytest.raises(ValueError):
+ OptimadeStructure(**data)
+
+
+def test_optimade_db_columns_include_charge_density_fields():
+ """Test that OptimadeDatabase.columns() includes the new charge density columns."""
+ cols = OptimadeDatabase.columns()
+ assert "compressed_charge_density" in cols
+ assert "compressed_aeccar0" in cols
+ assert "compressed_aeccar1" in cols
+ assert "compressed_aeccar2" in cols
+ assert "charge_density_grid_shape" in cols
+ assert "bader_charges" in cols
+ assert "bader_atomic_volume" in cols
+ assert "ddec6_charges" in cols
+
+
+def test_trajectories_db_columns_inherit_charge_density_fields():
+ """Test that TrajectoriesDatabase.columns() inherits the charge density columns."""
+ cols = TrajectoriesDatabase.columns()
+ assert "compressed_charge_density" in cols
+ assert "bader_charges" in cols
+ assert "ddec6_charges" in cols
+ # Also still has trajectory-specific columns
+ assert "relaxation_step" in cols
+ assert "relaxation_number" in cols
+
+
+def test_optimade_db_column_count_matches_insert_tuple():
+ """Guard test: verify that the number of columns matches what insert_data expects.
+
+ This prevents silent data corruption from tuple/column ordering mismatches
+ across the 4 manually maintained tuple definitions in postgres.py.
+ """
+ optimade_col_count = len(OptimadeDatabase.columns())
+ traj_col_count = len(TrajectoriesDatabase.columns())
+ # TrajectoriesDatabase should have exactly 2 more columns (relaxation_step, relaxation_number)
+ assert traj_col_count == optimade_col_count + 2
diff --git a/tests/test_cli.py b/tests/test_cli.py
index e6feb0c..faf9723 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -7,6 +7,7 @@
from lematerial_fetcher.cli import cli
from lematerial_fetcher.utils.config import (
+ DirectPipelineConfig,
FetcherConfig,
TransformerConfig,
)
@@ -231,6 +232,147 @@ def test_cli_args_override_env_vars():
assert call_kwargs["db_host"] == "cli.host"
+def test_lematrho_subcommands():
+ """Test that lematrho subcommands are correctly registered."""
+ runner = CliRunner()
+
+ result = runner.invoke(cli, ["lematrho", "--help"])
+ assert result.exit_code == 0
+ assert "Commands for fetching charge density data from LeMatRho" in result.output
+ assert "fetch" in result.output
+ assert "transform" in result.output
+
+
+def test_lematrho_fetch_help():
+ """Test that lematrho fetch --help returns 0 and shows expected options."""
+ runner = CliRunner()
+ result = runner.invoke(cli, ["lematrho", "fetch", "--help"])
+ assert result.exit_code == 0
+ assert "--lematrho-bucket-name" in result.output
+ assert "--grid-shape" in result.output
+
+
+def test_lematrho_transform_help():
+ """Test that lematrho transform --help returns 0 and shows expected options."""
+ runner = CliRunner()
+ result = runner.invoke(cli, ["lematrho", "transform", "--help"])
+ assert result.exit_code == 0
+ assert "--bader-path" in result.output
+ assert "--chargemol-path" in result.output
+ assert "--chgsum-script-path" in result.output
+ assert "--atomic-densities-path" in result.output
+ assert "--lematrho-bucket-name" in result.output
+
+
+@patch("lematerial_fetcher.cli.LeMatRhoFetcher")
+@patch("lematerial_fetcher.cli.load_fetcher_config")
+def test_lematrho_fetch_passes_cli_args(mock_load_config, mock_fetcher):
+ """Test that lematrho fetch command passes CLI args to the config loader."""
+ mock_config = FetcherConfig(
+ log_dir="./logs",
+ max_retries=3,
+ num_workers=2,
+ retry_delay=2,
+ log_every=1000,
+ page_offset=0,
+ page_limit=10,
+ base_url="DUMMY_BASE_URL",
+ table_name="test_table",
+ db_conn_str="db_conn_string",
+ mp_bucket_name="",
+ mp_bucket_prefix="",
+ lematrho_bucket_name="lemat-rho",
+ lematrho_grid_shape=(15, 15, 15),
+ )
+ mock_load_config.return_value = mock_config
+ mock_fetcher_instance = mock_fetcher.return_value
+
+ runner = CliRunner()
+ result = runner.invoke(
+ cli,
+ [
+ "lematrho",
+ "fetch",
+ "--db-user",
+ "test_user",
+ "--table-name",
+ "test_raw",
+ "--lematrho-bucket-name",
+ "my-bucket",
+ "--grid-shape",
+ "20",
+ "20",
+ "20",
+ ],
+ )
+
+ assert result.exit_code == 0
+ mock_load_config.assert_called_once()
+ call_kwargs = mock_load_config.call_args[1]
+ assert call_kwargs["db_user"] == "test_user"
+ assert call_kwargs["table_name"] == "test_raw"
+ assert call_kwargs["lematrho_bucket_name"] == "my-bucket"
+ assert call_kwargs["lematrho_grid_shape"] == (20, 20, 20)
+
+ mock_fetcher.assert_called_once_with(config=mock_config, debug=False)
+ mock_fetcher_instance.fetch.assert_called_once()
+
+
+@patch("lematerial_fetcher.cli.LeMatRhoTransformer")
+@patch("lematerial_fetcher.cli.load_transformer_config")
+def test_lematrho_transform_passes_cli_args(mock_load_config, mock_transformer):
+ """Test that lematrho transform command passes CLI args to the config loader."""
+ mock_config = TransformerConfig(
+ log_dir="./logs",
+ max_retries=3,
+ num_workers=2,
+ retry_delay=2,
+ log_every=1000,
+ page_offset=0,
+ page_limit=10,
+ source_db_conn_str="source_conn_string",
+ dest_db_conn_str="dest_conn_string",
+ source_table_name="source_table",
+ dest_table_name="dest_table",
+ batch_size=500,
+ lematrho_bucket_name="lemat-rho",
+ bader_path="/usr/bin/bader",
+ )
+ mock_load_config.return_value = mock_config
+ mock_transformer_instance = mock_transformer.return_value
+
+ runner = CliRunner()
+ result = runner.invoke(
+ cli,
+ [
+ "lematrho",
+ "transform",
+ "--db-user",
+ "src_user",
+ "--table-name",
+ "src_table",
+ "--dest-table-name",
+ "dest_table",
+ "--bader-path",
+ "/opt/bader",
+ "--lematrho-bucket-name",
+ "my-bucket",
+ ],
+ )
+
+ assert result.exit_code == 0
+ mock_load_config.assert_called_once()
+ call_kwargs = mock_load_config.call_args[1]
+ assert call_kwargs["db_user"] == "src_user"
+ assert call_kwargs["table_name"] == "src_table"
+ assert call_kwargs["dest_table_name"] == "dest_table"
+ assert call_kwargs["bader_path"] == "/opt/bader"
+ assert call_kwargs["lematrho_bucket_name"] == "my-bucket"
+
+ mock_transformer.assert_called_once_with(config=mock_config, debug=False)
+ mock_transformer_instance.transform.assert_called_once()
+
+
def test_env_vars_pass_to_config():
"""Test that environment variables are passed to the config loader"""
@@ -260,3 +402,99 @@ def test_env_vars_pass_to_config():
assert call_kwargs["db_user"] == "src_user"
assert call_kwargs["table_name"] == "src_table"
assert call_kwargs["dest_table_name"] == "dest_table"
+
+
+def test_lematrho_run_help():
+ """Test that lematrho run --help returns 0 and shows expected options."""
+ runner = CliRunner()
+ result = runner.invoke(cli, ["lematrho", "run", "--help"])
+ assert result.exit_code == 0
+ assert "--output-dir" in result.output
+ assert "--parquet-chunk-size" in result.output
+ assert "--lematrho-bucket-name" in result.output
+ assert "--grid-shape" in result.output
+ assert "--hf-repo-id" in result.output
+ assert "--hf-token" in result.output
+ assert "--bader-path" in result.output
+ assert "--chargemol-path" in result.output
+ assert "--chgsum-script-path" in result.output
+ assert "--atomic-densities-path" in result.output
+ assert "--num-workers" in result.output
+
+
+def test_lematrho_run_in_subcommands():
+ """Test that 'run' appears in lematrho subcommands alongside fetch and transform."""
+ runner = CliRunner()
+ result = runner.invoke(cli, ["lematrho", "--help"])
+ assert result.exit_code == 0
+ assert "run" in result.output
+ assert "fetch" in result.output
+ assert "transform" in result.output
+
+
+@patch("lematerial_fetcher.cli.LeMatRhoDirectPipeline")
+@patch("lematerial_fetcher.cli.load_direct_pipeline_config")
+def test_lematrho_run_passes_cli_args(mock_load_config, mock_pipeline):
+ """Test that lematrho run command passes CLI args to the config loader."""
+ mock_config = DirectPipelineConfig(
+ lematrho_bucket_name="my-bucket",
+ lematrho_grid_shape=(20, 20, 20),
+ output_dir="/tmp/test_output",
+ parquet_chunk_size=500,
+ num_workers=2,
+ )
+ mock_load_config.return_value = mock_config
+ mock_pipeline_instance = mock_pipeline.return_value
+
+ runner = CliRunner()
+ result = runner.invoke(
+ cli,
+ [
+ "lematrho",
+ "run",
+ "--output-dir",
+ "/tmp/test_output",
+ "--parquet-chunk-size",
+ "500",
+ "--num-workers",
+ "2",
+ "--lematrho-bucket-name",
+ "my-bucket",
+ "--grid-shape",
+ "20",
+ "20",
+ "20",
+ "--bader-path",
+ "/opt/bader",
+ ],
+ )
+
+ assert result.exit_code == 0
+ mock_load_config.assert_called_once()
+ call_kwargs = mock_load_config.call_args[1]
+ assert call_kwargs["output_dir"] == "/tmp/test_output"
+ assert call_kwargs["parquet_chunk_size"] == 500
+ assert call_kwargs["num_workers"] == 2
+ assert call_kwargs["lematrho_bucket_name"] == "my-bucket"
+ assert call_kwargs["grid_shape"] == (20, 20, 20)
+ assert call_kwargs["bader_path"] == "/opt/bader"
+
+ mock_pipeline.assert_called_once_with(config=mock_config, debug=False)
+ mock_pipeline_instance.run.assert_called_once()
+
+
+@patch("lematerial_fetcher.cli.LeMatRhoDirectPipeline")
+@patch("lematerial_fetcher.cli.load_direct_pipeline_config")
+def test_lematrho_run_debug_flag(mock_load_config, mock_pipeline):
+ """Test that --debug flag is passed through to the pipeline."""
+ mock_config = DirectPipelineConfig()
+ mock_load_config.return_value = mock_config
+
+ runner = CliRunner()
+ result = runner.invoke(
+ cli,
+ ["--debug", "lematrho", "run"],
+ )
+
+ assert result.exit_code == 0
+ mock_pipeline.assert_called_once_with(config=mock_config, debug=True)
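+
+
+# The wiring these CLI tests assume, sketched for the run command (the real
+# command in cli.py may differ in option handling and defaults):
+#
+#     @lematrho_cli.command(name="run")
+#     @add_lematrho_direct_options
+#     @click.pass_context
+#     def lematrho_run(ctx, **kwargs):
+#         config = load_direct_pipeline_config(**kwargs)
+#         pipeline = LeMatRhoDirectPipeline(config=config, debug=ctx.obj["debug"])
+#         pipeline.run()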
diff --git a/tests/utils/test_aws.py b/tests/utils/test_aws.py
index 6309aae..375a4fc 100644
--- a/tests/utils/test_aws.py
+++ b/tests/utils/test_aws.py
@@ -8,6 +8,7 @@
from lematerial_fetcher.utils.aws import (
download_s3_object,
+ get_authenticated_aws_client,
get_aws_client,
get_latest_collection_version_prefix,
list_s3_objects,
@@ -22,6 +23,39 @@ def test_get_aws_client():
assert client._client_config.region_name == "us-east-1"
+def test_get_aws_client_unchanged():
+ """Regression: anonymous client still uses UNSIGNED after adding authenticated client"""
+ client = get_aws_client()
+ assert client._client_config.signature_version == UNSIGNED
+
+
+def test_get_authenticated_aws_client_no_unsigned():
+ """Test that authenticated client does not use UNSIGNED signature"""
+ client = get_authenticated_aws_client()
+ assert client._client_config.signature_version != UNSIGNED
+
+
+def test_get_authenticated_aws_client_has_retry_config():
+ """Test that authenticated client has adaptive retry configuration"""
+ client = get_authenticated_aws_client()
+ retry_config = client._client_config.retries
+ # boto3 converts max_attempts=3 to total_max_attempts=4 (initial + retries)
+ assert retry_config["total_max_attempts"] == 4
+ assert retry_config["mode"] == "adaptive"
+
+
+def test_get_authenticated_aws_client_default_region():
+ """Test that authenticated client defaults to us-east-1"""
+ client = get_authenticated_aws_client()
+ assert client._client_config.region_name == "us-east-1"
+
+
+def test_get_authenticated_aws_client_custom_region():
+ """Test that authenticated client accepts a custom region"""
+ client = get_authenticated_aws_client(region_name="eu-west-1")
+ assert client._client_config.region_name == "eu-west-1"
+
+
@pytest.fixture
def mock_s3_client():
"""Fixture to create a stubbed S3 client"""
diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py
index ff94319..df5201c 100644
--- a/tests/utils/test_config.py
+++ b/tests/utils/test_config.py
@@ -7,6 +7,8 @@
from dotenv import load_dotenv
from lematerial_fetcher.utils.config import (
+ DirectPipelineConfig,
+ load_direct_pipeline_config,
load_fetcher_config,
load_push_config,
load_transformer_config,
@@ -724,3 +726,110 @@ def test_load_push_config_missing_required():
assert "db credentials" in str(excinfo.value)
assert "table_name" in str(excinfo.value)
assert "hf_repo_id" in str(excinfo.value)
+
+
+# ---------------------------------------------------------------------------
+# DirectPipelineConfig tests
+# ---------------------------------------------------------------------------
+
+
+class TestDirectPipelineConfig:
+ def test_defaults(self):
+ """All defaults should produce a valid config."""
+ config = DirectPipelineConfig()
+ assert config.lematrho_bucket_name == "lemat-rho"
+ assert config.lematrho_grid_shape == (15, 15, 15)
+ assert config.output_dir == "./lematrho_output"
+ assert config.parquet_chunk_size == 1000
+ assert config.num_workers == 4
+ assert config.log_every == 100
+ assert config.hf_repo_id is None
+ assert config.hf_token is None
+ assert config.bader_path is None
+ assert config.chargemol_path is None
+ assert config.chgsum_script_path is None
+ assert config.atomic_densities_path is None
+
+ def test_custom_values(self):
+ """Config should accept custom values for all fields."""
+ config = DirectPipelineConfig(
+ lematrho_bucket_name="my-bucket",
+ lematrho_grid_shape=(20, 20, 20),
+ output_dir="/tmp/output",
+ parquet_chunk_size=500,
+ num_workers=2,
+ log_every=50,
+ hf_repo_id="org/repo",
+ hf_token="hf_abc123",
+ bader_path="/usr/bin/bader",
+ chargemol_path="/usr/bin/chargemol",
+ chgsum_script_path="/opt/chgsum.pl",
+ atomic_densities_path="/opt/atomic_densities",
+ )
+ assert config.lematrho_bucket_name == "my-bucket"
+ assert config.lematrho_grid_shape == (20, 20, 20)
+ assert config.output_dir == "/tmp/output"
+ assert config.parquet_chunk_size == 500
+ assert config.num_workers == 2
+ assert config.hf_repo_id == "org/repo"
+ assert config.bader_path == "/usr/bin/bader"
+ assert config.chargemol_path == "/usr/bin/chargemol"
+ assert config.chgsum_script_path == "/opt/chgsum.pl"
+ assert config.atomic_densities_path == "/opt/atomic_densities"
+
+ def test_not_a_base_config(self):
+ """DirectPipelineConfig should NOT inherit from BaseConfig."""
+ from lematerial_fetcher.utils.config import BaseConfig
+
+ assert not issubclass(DirectPipelineConfig, BaseConfig)
+
+
+class TestLoadDirectPipelineConfig:
+ def test_defaults(self):
+ """Loader with no args should return config with all defaults."""
+ config = load_direct_pipeline_config()
+ assert config.lematrho_bucket_name == "lemat-rho"
+ assert config.lematrho_grid_shape == (15, 15, 15)
+ assert config.output_dir == "./lematrho_output"
+ assert config.parquet_chunk_size == 1000
+ assert config.num_workers == 4
+
+ def test_passes_arguments_through(self):
+ """Loader should pass all arguments to the config."""
+ config = load_direct_pipeline_config(
+ lematrho_bucket_name="custom-bucket",
+ grid_shape=(10, 10, 10),
+ output_dir="/data/output",
+ parquet_chunk_size=2000,
+ num_workers=8,
+ log_every=200,
+ hf_repo_id="org/dataset",
+ hf_token="token123",
+ bader_path="/bin/bader",
+ chargemol_path="/bin/chargemol",
+ chgsum_script_path="/scripts/chgsum.pl",
+ atomic_densities_path="/data/densities",
+ )
+ assert config.lematrho_bucket_name == "custom-bucket"
+ assert config.lematrho_grid_shape == (10, 10, 10)
+ assert config.output_dir == "/data/output"
+ assert config.parquet_chunk_size == 2000
+ assert config.num_workers == 8
+ assert config.log_every == 200
+ assert config.hf_repo_id == "org/dataset"
+ assert config.hf_token == "token123"
+ assert config.bader_path == "/bin/bader"
+
+ def test_ignores_unknown_kwargs(self):
+ """Loader should silently ignore unknown kwargs (from Click spillover)."""
+ config = load_direct_pipeline_config(
+ debug=True,
+ cache_dir="/tmp/cache",
+ some_random_kwarg="value",
+ )
+ assert isinstance(config, DirectPipelineConfig)
+
+ def test_grid_shape_kwarg_maps_to_config(self):
+ """Loader uses 'grid_shape' (Click name) mapped to 'lematrho_grid_shape'."""
+ config = load_direct_pipeline_config(grid_shape=(25, 25, 25))
+ assert config.lematrho_grid_shape == (25, 25, 25)
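+
+
+# A minimal sketch of the loader behavior these tests pin down, assuming
+# DirectPipelineConfig is a plain dataclass and ignoring any environment
+# variable handling the real loader in lematerial_fetcher.utils.config
+# may add:
+#
+#     from dataclasses import fields
+#
+#     _ALIASES = {"grid_shape": "lematrho_grid_shape"}
+#
+#     def load_direct_pipeline_config(**kwargs):
+#         # Map Click option names onto config field names, then drop any
+#         # unknown kwargs (Click spillover) and None values so defaults apply.
+#         kwargs = {_ALIASES.get(k, k): v for k, v in kwargs.items()}
+#         known = {f.name for f in fields(DirectPipelineConfig)}
+#         return DirectPipelineConfig(
+#             **{k: v for k, v in kwargs.items() if k in known and v is not None}
+#         )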