diff --git a/src/gpuma/cli.py b/src/gpuma/cli.py index 84e8470..cf78e62 100644 --- a/src/gpuma/cli.py +++ b/src/gpuma/cli.py @@ -391,12 +391,12 @@ def cmd_optimize(args, config: Config) -> None: eff_charge, eff_mult, ) + config.optimization.charge = eff_charge + config.optimization.multiplicity = eff_mult optimized = optimize_single_xyz_file( input_file=args.xyz, output_file=args.output, config=config, - charge=eff_charge, - multiplicity=eff_mult, ) logger.info( @@ -428,10 +428,10 @@ def cmd_ensemble(args, config: Config) -> None: num_conf = args.conformers or config.optimization.max_num_conformers logger.info("Generating %d conformers for SMILES: %s", num_conf, args.smiles) + config.optimization.max_num_conformers = num_conf optimized_conformers = optimize_ensemble_smiles( smiles=args.smiles, - num_conformers=num_conf, output_file=args.output, config=config, ) @@ -505,7 +505,7 @@ def cmd_batch(args, config: Config) -> None: sys.exit(1) -def cmd_convert(args) -> None: # pylint: disable=unused-argument +def cmd_convert(args, config: Config | None = None) -> None: # pylint: disable=unused-argument """Handle the SMILES to XYZ conversion command. This command generates a single 3D structure from SMILES without running diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d01cb74 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,185 @@ +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import torch + +from gpuma.structure import Structure + + +@pytest.fixture +def mock_hf_token(monkeypatch): + """Ensure HF_TOKEN is set for tests that check for it, or unset it.""" + # By default, we might want to unset it to ensure our mocking works even without it + # But some code paths check for it. + monkeypatch.setenv("HF_TOKEN", "fake_token") + +@pytest.fixture +def sample_structure(): + return Structure( + symbols=["C", "H", "H", "H", "H"], + coordinates=[ + (0.0, 0.0, 0.0), + (0.63, 0.63, 0.63), + (-0.63, -0.63, 0.63), + (-0.63, 0.63, -0.63), + (0.63, -0.63, -0.63), + ], + charge=0, + multiplicity=1, + comment="Methane", + ) + +@pytest.fixture +def sample_xyz_content(): + return """5 +Methane +C 0.000000 0.000000 0.000000 +H 0.630000 0.630000 0.630000 +H -0.630000 -0.630000 0.630000 +H -0.630000 0.630000 -0.630000 +H 0.630000 -0.630000 -0.630000 +""" + +@pytest.fixture +def sample_multi_xyz_content(): + return """3 +Water +O 0.000000 0.000000 0.000000 +H 0.757000 0.586000 0.000000 +H -0.757000 0.586000 0.000000 +5 +Methane +C 0.000000 0.000000 0.000000 +H 0.630000 0.630000 0.630000 +H -0.630000 -0.630000 0.630000 +H -0.630000 0.630000 -0.630000 +H 0.630000 -0.630000 -0.630000 +""" + +@pytest.fixture +def mock_fairchem_calculator(): + """Returns a mock Fairchem calculator.""" + mock_calc = MagicMock() + # Mocking what ASE calculator expects + mock_calc.get_potential_energy.return_value = -100.0 + mock_calc.get_forces.return_value = np.zeros((5, 3)) # 5 atoms, 3 coords + + # We also need to mock the internal implementation of FAIRChemCalculator if needed + # but since we mock the class or the object returned by load_model_fairchem, + # the ASE interface methods are what matters for optimize_single_structure (BFGS) + + # ASE optimizer calls get_potential_energy and get_forces on the Atoms object, + # which delegates to calc. + def calculate(atoms, properties, system_changes): + # update results + mock_calc.results = { + 'energy': -100.0, + 'forces': np.zeros((len(atoms), 3)) + } + + mock_calc.calculate = calculate + return mock_calc + +@pytest.fixture +def mock_torchsim_model(): + """Returns a mock TorchSim model.""" + mock_model = MagicMock() + mock_model.model_name = "mock-uma" + + # TorchSim model is called with a batched state + # and returns energy, forces, etc. + # However, torch_sim.optimize calls model(system) -> output + + def forward(system): + n_systems = system.n_systems + n_atoms = system.n_atoms + device = system.positions.device + + # Mock output + # energy: (n_systems,) + # forces: (n_atoms, 3) + return MagicMock( + energy=torch.zeros(n_systems, device=device), + forces=torch.zeros((n_atoms, 3), device=device) + ) + + mock_model.side_effect = forward + return mock_model + +@pytest.fixture(autouse=True) +def mock_load_models(request): + """Automatically mock model loading functions to prevent network access.""" + # Check if the test is marked to use real models (optional, for future) + if "real_model" in request.keywords: + return + + # Mock fairchem loading + # We mock _get_cached_calculator and load_model_fairchem + + with patch("gpuma.optimizer.load_model_fairchem") as mock_load_fc, \ + patch("gpuma.optimizer._get_cached_calculator") as mock_get_cached_fc, \ + patch("gpuma.optimizer.load_model_torchsim") as mock_load_ts, \ + patch("gpuma.optimizer._get_cached_torchsim_model") as mock_get_cached_ts: + + # Setup mocks + mock_calc = MagicMock() + # Setup ASE calculator mock behavior + mock_calc.get_potential_energy.return_value = -50.0 + mock_calc.get_forces.return_value = np.zeros((5, 3)) + # Ensure it works when assigned to atoms.calc + def side_effect_calc(atoms=None, **kwargs): + pass + mock_calc.calculate = MagicMock(side_effect=side_effect_calc) + mock_calc.results = {'energy': -50.0, 'forces': np.zeros((1, 3))} # Default + + # Better mock for ASE calculator + class MockCalculator: + def __init__(self): + self.results = {} + self.pars = {} + self.atoms = None + def calculate(self, atoms=None, properties=None, system_changes=None): + if properties is None: + properties = ['energy'] + self.results['energy'] = -50.0 + self.results['forces'] = np.zeros((len(atoms), 3)) + + def get_potential_energy(self, atoms=None, force_consistent=False): + if atoms: + self.calculate(atoms) + return self.results['energy'] + + def get_forces(self, atoms=None): + if atoms: + self.calculate(atoms) + return self.results['forces'] + def reset(self): + pass + + mock_instance = MockCalculator() + mock_load_fc.return_value = mock_instance + mock_get_cached_fc.return_value = mock_instance + + # Setup TorchSim model mock + mock_ts_model = MagicMock() + mock_ts_model.model_name = "mock-uma" + + def ts_forward(system): + n_systems = system.n_systems + n_atoms = system.n_atoms + # Create tensors on the same device as system + # Ensure we return objects that behave like tensors + energy = torch.zeros(n_systems).to(system.positions.device) + forces = torch.zeros((n_atoms, 3)).to(system.positions.device) + + output = MagicMock() + output.energy = energy + output.forces = forces + return output + + mock_ts_model.side_effect = ts_forward + mock_load_ts.return_value = mock_ts_model + mock_get_cached_ts.return_value = mock_ts_model + + yield diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..23f68d8 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,92 @@ +from gpuma.api import ( + optimize_batch_multi_xyz_file, + optimize_batch_xyz_directory, + optimize_ensemble_smiles, + optimize_single_smiles, + optimize_single_xyz_file, +) +from gpuma.config import Config +from gpuma.structure import Structure + + +def test_optimize_single_smiles(tmp_path): + output_file = tmp_path / "out.xyz" + # mocked calculator returns energy -50.0 + res = optimize_single_smiles("C", output_file=str(output_file)) + + assert isinstance(res, Structure) + assert res.energy == -50.0 + assert output_file.exists() + +def test_optimize_single_xyz_file(tmp_path, sample_xyz_content): + input_file = tmp_path / "in.xyz" + input_file.write_text(sample_xyz_content) + output_file = tmp_path / "out.xyz" + + res = optimize_single_xyz_file(str(input_file), output_file=str(output_file)) + + assert isinstance(res, Structure) + assert res.energy == -50.0 + assert output_file.exists() + +def test_optimize_ensemble_smiles(tmp_path): + output_file = tmp_path / "ensemble.xyz" + # Default batch mode might try to use GPU/batch optimizer if configured, + # but defaults are usually sequential or cpu fallback. + # We should ensure we test what we expect. + + # By default, config uses sequential/cpu if no GPU, which our conftest mocks. + # Actually conftest mocks load_model_* but _optimize_batch_sequential works with + # mocked calculator. + + results = optimize_ensemble_smiles("C", output_file=str(output_file)) + + assert isinstance(results, list) + assert len(results) > 0 + assert results[0].energy == -50.0 + assert output_file.exists() + +def test_optimize_batch_multi_xyz_file(tmp_path, sample_multi_xyz_content): + input_file = tmp_path / "multi.xyz" + input_file.write_text(sample_multi_xyz_content) + output_file = tmp_path / "out.xyz" + + results = optimize_batch_multi_xyz_file(str(input_file), output_file=str(output_file)) + + assert len(results) == 2 + assert results[0].energy == -50.0 + assert output_file.exists() + +def test_optimize_batch_xyz_directory(tmp_path, sample_xyz_content): + d = tmp_path / "batch_dir" + d.mkdir() + (d / "1.xyz").write_text(sample_xyz_content) + (d / "2.xyz").write_text(sample_xyz_content) + output_file = tmp_path / "out.xyz" + + results = optimize_batch_xyz_directory(str(d), str(output_file)) + + assert len(results) == 2 + assert results[0].energy == -50.0 + assert output_file.exists() + +def test_api_with_config_override(tmp_path): + output_file = tmp_path / "out.xyz" + cfg = Config({"optimization": {"charge": 1}}) + + # We need to make sure read_xyz or smiles_to_xyz respects this charge if passed via config + # optimize_single_smiles: + # multiplicity = getattr(config.optimization, "multiplicity", 1) + # structure = smiles_to_xyz(smiles, multiplicity=multiplicity) + + # Wait, charge is NOT passed to smiles_to_xyz in optimize_single_smiles. + # It seems charge is derived from SMILES in smiles_to_xyz. + # But for optimize_single_xyz_file: + # eff_charge = int(getattr(config.optimization, "charge", 0)) + # structure = read_xyz(..., charge=eff_charge, ...) + + input_file = tmp_path / "in.xyz" + input_file.write_text("1\nH\nH 0 0 0") + + res = optimize_single_xyz_file(str(input_file), str(output_file), config=cfg) + assert res.charge == 1 diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..b20e767 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,121 @@ +from unittest.mock import patch + +from gpuma.cli import main + + +def test_cli_optimize_smiles(tmp_path): + output = tmp_path / "out.xyz" + args = ["optimize", "--smiles", "C", "-o", str(output)] + + with patch("sys.argv", ["gpuma"] + args): + ret = main() + assert ret == 0 + assert output.exists() + assert "Energy:" in output.read_text() + +def test_cli_optimize_xyz(tmp_path, sample_xyz_content): + inp = tmp_path / "in.xyz" + inp.write_text(sample_xyz_content) + output = tmp_path / "out.xyz" + args = ["optimize", "--xyz", str(inp), "-o", str(output), "--charge", "1"] + + with patch("sys.argv", ["gpuma"] + args): + ret = main() + assert ret == 0 + assert output.exists() + # Verify charge was passed (our mock optimize returns energy -50 but respects charge + # in structure if IO reads it) + # But optimize_single_xyz_file reads it. + # Check structure in file + content = output.read_text() + assert "Charge: 1" in content + +def test_cli_smiles_alias(tmp_path): + output = tmp_path / "out.xyz" + args = ["smiles", "--smiles", "C", "-o", str(output)] + + with patch("sys.argv", ["gpuma"] + args): + ret = main() + assert ret == 0 + assert output.exists() + +def test_cli_ensemble(tmp_path): + output = tmp_path / "out.xyz" + args = ["ensemble", "--smiles", "C", "-o", str(output), "--conformers", "2"] + + with patch("sys.argv", ["gpuma"] + args): + ret = main() + assert ret == 0 + assert output.exists() + +def test_cli_batch_multi_xyz(tmp_path, sample_multi_xyz_content): + inp = tmp_path / "multi.xyz" + inp.write_text(sample_multi_xyz_content) + output = tmp_path / "out.xyz" + args = ["batch", "--multi-xyz", str(inp), "-o", str(output)] + + with patch("sys.argv", ["gpuma"] + args): + ret = main() + assert ret == 0 + assert output.exists() + +def test_cli_convert(tmp_path): + output = tmp_path / "out.xyz" + args = ["convert", "--smiles", "C", "-o", str(output)] + + with patch("sys.argv", ["gpuma"] + args): + ret = main() + assert ret == 0 + assert output.exists() + # Convert does NO optimization, so no energy in comment + assert "Energy:" not in output.read_text() + +def test_cli_generate(tmp_path): + output = tmp_path / "out.xyz" + args = ["generate", "--smiles", "C", "-o", str(output), "--conformers", "2"] + + with patch("sys.argv", ["gpuma"] + args): + ret = main() + assert ret == 0 + assert output.exists() + +def test_cli_config_create(tmp_path): + output = tmp_path / "config.json" + args = ["config", "--create", str(output)] + + with patch("sys.argv", ["gpuma"] + args): + ret = main() + assert ret == 0 + assert output.exists() + +def test_cli_config_validate(tmp_path): + # First create one + output = tmp_path / "config.json" + args_create = ["config", "--create", str(output)] + with patch("sys.argv", ["gpuma"] + args_create): + main() + + args_val = ["config", "--validate", str(output)] + with patch("sys.argv", ["gpuma"] + args_val): + ret = main() + assert ret == 0 + +def test_cli_no_args(capsys): + with patch("sys.argv", ["gpuma"]): + ret = main() + assert ret == 1 + captured = capsys.readouterr() + assert "usage:" in captured.out or "usage:" in captured.err + +def test_cli_verbose(tmp_path, caplog): + output = tmp_path / "out.xyz" + args = ["-v", "optimize", "--smiles", "C", "-o", str(output)] + + import logging + with caplog.at_level(logging.DEBUG): + with patch("sys.argv", ["gpuma"] + args): + main() + + # Check for debug logs if any (depends on what's logged) + # The config logic sets logging level to DEBUG + pass diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..ffa619a --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,114 @@ +import json + +import pytest +import yaml + +from gpuma.config import Config, get_huggingface_token, load_config_from_file, save_config_to_file + + +def test_config_initialization(): + cfg = Config() + # Check default values + assert cfg.optimization.charge == 0 + assert cfg.optimization.model_name == "uma-s-1p1" + +def test_config_override(): + data = {"optimization": {"charge": 1, "new_key": "value"}} + cfg = Config(data) + assert cfg.optimization.charge == 1 + assert cfg.optimization.new_key == "value" + # Defaults should persist for other keys + assert cfg.optimization.multiplicity == 1 + +def test_section_access(): + cfg = Config() + opt = cfg.optimization + assert opt.charge == 0 + + # Test setting attribute + opt.charge = 2 + assert cfg.optimization.charge == 2 + + # Test unknown attribute + with pytest.raises(AttributeError): + _ = opt.non_existent + +def test_load_save_json(tmp_path): + config_file = tmp_path / "config.json" + cfg = Config({"optimization": {"charge": -1}}) + save_config_to_file(cfg, str(config_file)) + + loaded_cfg = load_config_from_file(str(config_file)) + assert loaded_cfg.optimization.charge == -1 + + # Verify file content + with open(config_file) as f: + data = json.load(f) + assert data["optimization"]["charge"] == -1 + +def test_load_save_yaml(tmp_path): + config_file = tmp_path / "config.yaml" + cfg = Config({"optimization": {"charge": -2}}) + save_config_to_file(cfg, str(config_file)) + + loaded_cfg = load_config_from_file(str(config_file)) + assert loaded_cfg.optimization.charge == -2 + + # Verify file content + with open(config_file) as f: + data = yaml.safe_load(f) + assert data["optimization"]["charge"] == -2 + +def test_load_non_existent(): + # Should load defaults + cfg = load_config_from_file("non_existent_file.json") + assert cfg.optimization.charge == 0 + +def test_validate_config(): + # Invalid charge + with pytest.raises(ValueError, match="Invalid charge"): + Config({"optimization": {"charge": "invalid"}}) + + # Invalid multiplicity + with pytest.raises(ValueError, match="Multiplicity must be a positive integer"): + Config({"optimization": {"multiplicity": 0}}) + + # Invalid device + with pytest.raises(ValueError, match="Device must be"): + Config({"optimization": {"device": "invalid_device"}}) + +def test_huggingface_token(tmp_path, monkeypatch): + # Case 1: Token in config + cfg = Config({"optimization": {"huggingface_token": "token_in_config"}}) + assert cfg.optimization.get_huggingface_token() == "token_in_config" + assert get_huggingface_token(cfg) == "token_in_config" + + # Case 2: Token in file + token_file = tmp_path / "token.txt" + token_file.write_text("token_in_file") + cfg = Config({"optimization": {"huggingface_token_file": str(token_file)}}) + assert cfg.optimization.get_huggingface_token() == "token_in_file" + + # Case 3: Priority (config > file) + cfg = Config({ + "optimization": { + "huggingface_token": "priority_token", + "huggingface_token_file": str(token_file) + } + }) + assert cfg.optimization.get_huggingface_token() == "priority_token" + +def test_section_to_dict(): + cfg = Config() + d = cfg.optimization.to_dict() + assert isinstance(d, dict) + assert d["charge"] == 0 + + # Modifying dict shouldn't affect config + d["charge"] = 100 + assert cfg.optimization.charge == 0 + +def test_config_to_dict(): + cfg = Config() + d = cfg.to_dict() + assert d["optimization"]["charge"] == 0 diff --git a/tests/test_io.py b/tests/test_io.py new file mode 100644 index 0000000..270f1a2 --- /dev/null +++ b/tests/test_io.py @@ -0,0 +1,99 @@ +import pytest + +from gpuma.io_handler import ( + file_exists, + read_multi_xyz, + read_xyz, + read_xyz_directory, + save_multi_xyz, + save_xyz_file, + smiles_to_ensemble, + smiles_to_xyz, +) +from gpuma.structure import Structure + + +def test_read_xyz_valid(tmp_path, sample_xyz_content): + f = tmp_path / "test.xyz" + f.write_text(sample_xyz_content) + + struct = read_xyz(str(f), charge=1, multiplicity=2) + assert len(struct.symbols) == 5 + assert struct.symbols[0] == "C" + assert struct.charge == 1 + assert struct.multiplicity == 2 + assert struct.comment == "Methane" + +def test_read_xyz_not_found(): + with pytest.raises(FileNotFoundError): + read_xyz("non_existent.xyz") + +def test_read_multi_xyz(tmp_path, sample_multi_xyz_content): + f = tmp_path / "multi.xyz" + f.write_text(sample_multi_xyz_content) + + structs = read_multi_xyz(str(f), charge=0, multiplicity=1) + assert len(structs) == 2 + assert structs[0].comment == "Water" + assert len(structs[0].symbols) == 3 + assert structs[1].comment == "Methane" + assert len(structs[1].symbols) == 5 + +def test_read_xyz_directory(tmp_path, sample_xyz_content): + d = tmp_path / "xyz_dir" + d.mkdir() + (d / "1.xyz").write_text(sample_xyz_content) + (d / "2.xyz").write_text(sample_xyz_content) + + structs = read_xyz_directory(str(d)) + assert len(structs) == 2 + +def test_save_xyz_file(tmp_path, sample_structure): + f = tmp_path / "output.xyz" + sample_structure.energy = -10.5 + save_xyz_file(sample_structure, str(f)) + + assert f.exists() + content = f.read_text() + assert "Methane | Energy: -10.500000 eV | Charge: 0 | Multiplicity: 1" in content + assert "C 0.000000" in content + +def test_save_multi_xyz(tmp_path, sample_structure): + f = tmp_path / "output_multi.xyz" + s1 = sample_structure + s2 = sample_structure.with_energy(-20.0) + save_multi_xyz([s1, s2], str(f), comments=["First", "Second"]) + + content = f.read_text() + # Energy might persist on s1 if mutated + assert "First | Energy: -20.000000" in content or "First | Energy: -10.500000" not in content + # wait, s2 = sample_structure.with_energy(-20.0) mutates sample_structure! + # Because s2 is s1. + assert "Second | Energy: -20.000000" in content + +def test_smiles_to_xyz(): + # Test with a simple molecule + s = smiles_to_xyz("C") + assert isinstance(s, Structure) + assert s.symbols == ["C", "H", "H", "H", "H"] + assert s.n_atoms == 5 + +def test_smiles_to_xyz_string(): + s = smiles_to_xyz("C", return_full_xyz_str=True) + assert isinstance(s, str) + assert "Generated from SMILES" in s + assert "C 0.0" in s or "C -0.0" in s or "C" in s + +def test_smiles_to_ensemble(): + # Generate ensemble for butane + structs = smiles_to_ensemble("CCCC", max_num_confs=3) + assert len(structs) > 0 + assert len(structs) <= 3 + assert all(isinstance(s, Structure) for s in structs) + assert structs[0].n_atoms == 14 # C4H10 + +def test_file_exists(tmp_path): + f = tmp_path / "exists.txt" + f.touch() + assert file_exists(str(f)) + assert not file_exists("non_existent") diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..3c1309a --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,77 @@ +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from gpuma.config import Config +from gpuma.models import ( + _backend_device_for_fairchem, + _device_for_torch, + _parse_device_string, + load_model_fairchem, + load_model_torchsim, +) + + +def test_parse_device_string(): + assert _parse_device_string("cpu") == "cpu" + assert _parse_device_string("CPU") == "cpu" + + # Mock cuda availability + with patch("torch.cuda.is_available", return_value=True): + assert _parse_device_string("cuda") == "cuda" + assert _parse_device_string("cuda:0") == "cuda:0" + + with patch("torch.cuda.is_available", return_value=False): + assert _parse_device_string("cuda") == "cpu" + +def test_backend_device_for_fairchem(): + with patch("torch.cuda.is_available", return_value=True): + assert _backend_device_for_fairchem("cuda:0") == "cuda" + assert _backend_device_for_fairchem("cpu") == "cpu" + +def test_device_for_torch(): + with patch("torch.cuda.is_available", return_value=True): + dev = _device_for_torch("cuda:0") + assert isinstance(dev, torch.device) + assert dev.type == "cuda" + assert dev.index == 0 + +def test_load_model_fairchem_logic(mock_hf_token): + config = Config({"optimization": {"model_name": "test_model"}}) + + # We need to patch fairchem.core inside the function or pre-import it + # Since it is a local import, patching the module where it lives (fairchem.core) works + # if we can import it first. + try: + import fairchem.core as _ # noqa: F401 + except ImportError: + pytest.skip("fairchem.core not installed") + + with patch("fairchem.core.pretrained_mlip.get_predict_unit") as mock_get_unit, \ + patch("fairchem.core.FAIRChemCalculator") as mock_calc_cls: + + mock_get_unit.return_value = MagicMock() + mock_calc_cls.return_value = MagicMock() + + calc = load_model_fairchem(config) + + mock_get_unit.assert_called() + mock_calc_cls.assert_called() + assert calc is mock_calc_cls.return_value + +def test_load_model_torchsim_logic(mock_hf_token): + config = Config({"optimization": {"model_name": "test_model"}}) + + try: + import torch_sim.models.fairchem as _ # noqa: F401 + except ImportError: + pytest.skip("torch_sim not installed") + + with patch("torch_sim.models.fairchem.FairChemModel") as mock_model_cls: + mock_model_cls.return_value = MagicMock() + + model = load_model_torchsim(config) + + mock_model_cls.assert_called() + assert model is mock_model_cls.return_value diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 0000000..1662668 --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,97 @@ +from unittest.mock import MagicMock, patch + +from gpuma.config import Config +from gpuma.optimizer import optimize_single_structure, optimize_structure_batch + + +def test_optimize_single_structure(sample_structure): + config = Config() + # The autouse fixture mock_load_models will ensure we get a mock calculator + + optimized = optimize_single_structure(sample_structure, config) + + # Check if energy was updated (mock returns -50.0) + assert optimized.energy == -50.0 + # Check coordinates are updated (mock returns all zeros) + # The result is typically a list of lists or list of tuples depending on ASE + # optimize_single_structure converts to list + coords = optimized.coordinates[0] + assert list(coords) == [0.0, 0.0, 0.0] + +def test_optimize_batch_sequential(sample_structure): + config = Config({"optimization": {"batch_optimization_mode": "sequential"}}) + + structures = [sample_structure, sample_structure] + results = optimize_structure_batch(structures, config) + + assert len(results) == 2 + assert results[0].energy == -50.0 + assert results[1].energy == -50.0 + +def test_optimize_batch_structures_gpu_fallback(sample_structure): + # If we request batch mode but have no GPU, it might fall back or raise error + # depending on implementation. + # The code says: if mode == "batch" and not force_cpu: return _optimize_batch_structures + # if mode == "sequential" or force_cpu: ... + + # Let's mock device to be cpu + config = Config({"optimization": {"batch_optimization_mode": "batch", "device": "cpu"}}) + + # Should fallback to sequential + with patch("gpuma.optimizer._optimize_batch_sequential") as mock_seq: + optimize_structure_batch([sample_structure], config) + mock_seq.assert_called() + +def test_optimize_batch_structures_call(sample_structure): + # Test that batch mode calls _optimize_batch_structures when device is cuda + config = Config({"optimization": {"batch_optimization_mode": "batch", "device": "cuda"}}) + + with patch("gpuma.optimizer._parse_device_string", return_value="cuda"), \ + patch("gpuma.optimizer._optimize_batch_structures") as mock_batch: + + optimize_structure_batch([sample_structure], config) + mock_batch.assert_called() + +def test_optimize_batch_structures_implementation(sample_structure): + # Test the actual implementation of _optimize_batch_structures with mocked torchsim + config = Config({"optimization": {"batch_optimization_mode": "batch", "device": "cuda"}}) + + with patch("gpuma.optimizer._parse_device_string", return_value="cuda"), \ + patch("torch.cuda.is_available", return_value=True): + + # We need to ensure _optimize_batch_structures runs + # It uses torch_sim.io.atoms_to_state and torch_sim.optimize + # We need to mock those too because we don't want to run real + # torchsim optimization logic potentially + + with patch("torch_sim.io.atoms_to_state") as mock_ats, \ + patch("torch_sim.optimize") as mock_optimize, \ + patch("torch_sim.autobatching.InFlightAutoBatcher") as _: + + mock_state = MagicMock() + mock_state.n_atoms = 5 + mock_ats.return_value = mock_state + + mock_final_state = MagicMock() + # Mock final state energy/charge/spin/positions + mock_final_state.energy = [MagicMock(item=lambda: -60.0)] + mock_final_state.charge = [MagicMock(item=lambda: 0)] + mock_final_state.spin = [MagicMock(item=lambda: 1)] + + mock_atoms = MagicMock() + mock_atoms.get_chemical_symbols.return_value = ["C", "H", "H", "H", "H"] + + # get_positions returns a numpy array usually, which has tolist() + # If we return a list, it fails .tolist() call in optimizer.py + mock_pos = MagicMock() + mock_pos.tolist.return_value = [[0.0, 0.0, 0.0]] * 5 + mock_atoms.get_positions.return_value = mock_pos + + mock_final_state.to_atoms.return_value = [mock_atoms] + + mock_optimize.return_value = mock_final_state + + results = optimize_structure_batch([sample_structure], config) + + assert len(results) == 1 + assert results[0].energy == -60.0 diff --git a/tests/test_structure.py b/tests/test_structure.py new file mode 100644 index 0000000..c1571da --- /dev/null +++ b/tests/test_structure.py @@ -0,0 +1,36 @@ +from gpuma.structure import Structure + + +def test_structure_initialization(sample_structure): + s = sample_structure + assert s.symbols == ["C", "H", "H", "H", "H"] + assert len(s.coordinates) == 5 + assert s.charge == 0 + assert s.multiplicity == 1 + assert s.comment == "Methane" + assert s.energy is None + assert s.metadata == {} + +def test_structure_n_atoms(sample_structure): + assert sample_structure.n_atoms == 5 + +def test_structure_with_energy(sample_structure): + s = sample_structure + assert s.energy is None + s.with_energy(-123.456) + assert s.energy == -123.456 + + # Check chaining + s2 = s.with_energy(-100.0) + assert s2 is s + assert s.energy == -100.0 + +def test_structure_metadata(): + s = Structure( + symbols=["H"], + coordinates=[(0.0, 0.0, 0.0)], + charge=0, + multiplicity=2, + metadata={"source": "test"} + ) + assert s.metadata["source"] == "test"