From a90b0c65a7d1ab35f7c54281359472e24bfc6c70 Mon Sep 17 00:00:00 2001 From: Max Hoffmann Date: Tue, 30 Dec 2025 15:25:47 -0300 Subject: [PATCH 1/6] Fixed magic commands for kmos shell --- kmos/cli.py | 43 ++++++++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/kmos/cli.py b/kmos/cli.py index e9d55922..3badf2de 100644 --- a/kmos/cli.py +++ b/kmos/cli.py @@ -521,24 +521,37 @@ def sh(banner): """ import IPython - - if hasattr(IPython, "release"): - try: - from IPython.terminal.embed import InteractiveShellEmbed - - InteractiveShellEmbed(banner1=banner)() - - except ImportError: + import sys + + # Get the calling frame's namespace to preserve variables like 'model' + frame = sys._getframe(1) + user_ns = {} + user_ns.update(frame.f_globals) + user_ns.update(frame.f_locals) + + # Use IPython.embed() for modern IPython (>= 0.11) + # This properly supports magic commands like %time, %timeit, etc. + try: + IPython.embed(banner1=banner, user_ns=user_ns) + except AttributeError: + # Fallback for older IPython versions + if hasattr(IPython, "release"): try: - from IPython.frontend.terminal.embed import InteractiveShellEmbed + from IPython.terminal.embed import InteractiveShellEmbed - InteractiveShellEmbed(banner1=banner)() + InteractiveShellEmbed(banner1=banner, user_ns=user_ns)() except ImportError: - from IPython.Shell import IPShellEmbed + try: + from IPython.frontend.terminal.embed import InteractiveShellEmbed - IPShellEmbed(banner=banner)() - else: - from IPython.Shell import IPShellEmbed + InteractiveShellEmbed(banner1=banner, user_ns=user_ns)() + + except ImportError: + from IPython.Shell import IPShellEmbed + + IPShellEmbed(banner=banner, user_ns=user_ns)() + else: + from IPython.Shell import IPShellEmbed - IPShellEmbed(banner=banner)() + IPShellEmbed(banner=banner, user_ns=user_ns)() From 70e72a1df159b753a4d2bd33553e0af754c835bc Mon Sep 17 00:00:00 2001 From: Max Hoffmann Date: Tue, 30 Dec 2025 15:37:00 -0300 Subject: [PATCH 2/6] Fix test-coverage --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 9dc7979f..1141b8aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ kmos = "kmos.cli:main" [project.optional-dependencies] dev = [ "pytest>=8.4.2", + "pytest-cov>=4.0.0", "bump-my-version", "ruff>=0.8.0", "mypy>=1.0.0", From 98b504fe1d485262c37b77b6fb2643385d6cabff Mon Sep 17 00:00:00 2001 From: Max Hoffmann Date: Tue, 30 Dec 2025 15:46:41 -0300 Subject: [PATCH 3/6] Replace mypy with ty for faster type checking --- Makefile | 4 ++-- pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index a042e631..716ae89e 100644 --- a/Makefile +++ b/Makefile @@ -44,8 +44,8 @@ format: ## Format code with ruff format-check: ## Check code formatting without modifying files uv run ruff format --check kmos/ tests/ -type-check: ## Run type checking with mypy - uv run mypy kmos/ +type-check: ## Run type checking with ty + uv run ty check docs: ## Build documentation cd doc && uv run make html diff --git a/pyproject.toml b/pyproject.toml index 1141b8aa..47058de4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ dev = [ "pytest-cov>=4.0.0", "bump-my-version", "ruff>=0.8.0", - "mypy>=1.0.0", + "ty", "coverage[toml]>=7.0.0", "pre-commit>=3.0.0", "sphinx>=7.0.0", From b4f9ed9923b7f35b8b02f31a90ef566f245eeb2b Mon Sep 17 00:00:00 2001 From: Max Hoffmann Date: Tue, 30 Dec 2025 15:52:33 -0300 Subject: [PATCH 4/6] Add tests for IPython shell magic command support --- tests/test_shell.py | 135 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 tests/test_shell.py diff --git a/tests/test_shell.py b/tests/test_shell.py new file mode 100644 index 00000000..9c7b46de --- /dev/null +++ b/tests/test_shell.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python +"""Test that kmos shell properly supports IPython magic commands.""" + +from unittest.mock import patch + + +def test_shell_ipython_embed(): + """Test that sh() function properly calls IPython.embed with user namespace.""" + from kmos.cli import sh + + # Mock IPython.embed to capture the call + with patch("IPython.embed") as mock_embed: + # Create some test variables that should be in the namespace + test_model = "mock_model" + test_var = 42 + + # Call sh() - it should capture these local variables + try: + sh(banner="Test banner") + except SystemExit: + # IPython.embed might raise SystemExit when mocked + pass + + # Verify IPython.embed was called + assert mock_embed.called, "IPython.embed should be called" + + # Get the call arguments + call_kwargs = mock_embed.call_args.kwargs + + # Verify banner1 parameter is passed + assert "banner1" in call_kwargs, "banner1 should be in kwargs" + assert call_kwargs["banner1"] == "Test banner" + + # Verify user_ns parameter is passed + assert "user_ns" in call_kwargs, "user_ns should be in kwargs" + + # Verify user_ns is a dict + user_ns = call_kwargs["user_ns"] + assert isinstance(user_ns, dict), "user_ns should be a dict" + + # Verify it contains expected variables from the calling scope + assert "test_model" in user_ns, "user_ns should contain test_model" + assert "test_var" in user_ns, "user_ns should contain test_var" + assert user_ns["test_model"] == test_model + assert user_ns["test_var"] == test_var + + +def test_shell_with_complex_objects(): + """Test that sh() preserves complex objects in the namespace.""" + from unittest.mock import patch, MagicMock + from kmos.cli import sh + import numpy as np + + # Create mock objects that simulate what would be in kmos shell + mock_model = MagicMock() + mock_model.__class__.__name__ = "KMC_Model" + np_array = np.array([1, 2, 3]) + + # Mock IPython.embed to verify the call + with patch("IPython.embed") as mock_embed: + # Variables that would be in the shell namespace + model = mock_model + data = np_array + + try: + sh(banner="Test with complex objects") + except SystemExit: + pass + + # Verify IPython.embed was called with user_ns + assert mock_embed.called, "IPython.embed should be called" + call_kwargs = mock_embed.call_args.kwargs + assert "user_ns" in call_kwargs, "user_ns should be in kwargs" + + user_ns = call_kwargs["user_ns"] + + # Verify objects are in the namespace + assert "model" in user_ns, "model should be in namespace" + assert "data" in user_ns, "data should be in namespace" + assert "np" in user_ns, "np should be in namespace" + + # Verify they're the actual objects (use the local variables to satisfy linter) + assert user_ns["model"] is model, "model should be the same object" + assert user_ns["data"] is data, "data should be the same object" + assert user_ns["np"] is np, "np should be numpy module" + + # With this setup, magic commands like %time would work on these objects + # For example: %time model.do_steps(1000) + # The namespace is properly configured for all IPython features + + +def test_ipython_magic_compatibility(): + """ + Test that verifies the shell setup is compatible with IPython magic commands. + + This test doesn't actually execute magic commands (that would require an + interactive session), but it verifies that the shell is set up correctly + to support them. + """ + from unittest.mock import patch + from kmos.cli import sh + + # Mock IPython.embed to verify the namespace is properly passed + with patch("IPython.embed") as mock_embed: + # Call sh with some variables in scope + test_value = 123 + + try: + sh(banner="Test compatibility") + except SystemExit: + pass + + # Verify embed was called + assert mock_embed.called, "IPython.embed should be called" + + # Verify it was called with banner1 and user_ns + call_kwargs = mock_embed.call_args.kwargs + + # Check that user_ns was provided + assert "user_ns" in call_kwargs, "user_ns should be provided to IPython.embed" + + # Verify test_value is in the namespace + assert "test_value" in call_kwargs["user_ns"], "test_value should be in user_ns" + assert call_kwargs["user_ns"]["test_value"] == test_value + + # The fact that user_ns is being passed means magic commands will work, + # because IPython.embed() with user_ns properly sets up the interactive + # namespace with all IPython features including magic commands + + +if __name__ == "__main__": + test_shell_ipython_embed() + test_shell_with_complex_objects() + test_ipython_magic_compatibility() + print("All tests passed!") From 64130192e3612d207bbc2e3f627a2185aa5831ea Mon Sep 17 00:00:00 2001 From: Max Hoffmann Date: Tue, 30 Dec 2025 16:46:54 -0300 Subject: [PATCH 5/6] Add unit tests for kmos.io utility functions - Test _flatten, _chop_line, _most_common, _print_dict, _casetree_dict - Coverage improved from 588 to 574 missing lines (14 lines covered) - All 45 tests passing --- tests/test_io_utils.py | 152 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 tests/test_io_utils.py diff --git a/tests/test_io_utils.py b/tests/test_io_utils.py new file mode 100644 index 00000000..8edc4af7 --- /dev/null +++ b/tests/test_io_utils.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python +"""Unit tests for kmos.io utility functions.""" + +import io as stdio + +import pytest + + +def test_flatten(): + """Test _flatten function flattens nested lists.""" + from kmos.io import _flatten + + # Test empty list + assert _flatten([]) == [] + + # Test single list + assert _flatten([[1, 2, 3]]) == [1, 2, 3] + + # Test multiple lists + assert _flatten([[1, 2], [3, 4], [5]]) == [1, 2, 3, 4, 5] + + # Test mixed types + assert _flatten([["a", "b"], ["c"]]) == ["a", "b", "c"] + + # Test empty sublists + assert _flatten([[], [1, 2], []]) == [1, 2] + + +def test_chop_line(): + """Test _chop_line function splits long lines at commas.""" + from kmos.io import _chop_line + + # Test short line (no chopping needed) + short = "short line" + assert _chop_line(short) == short + + # Test line exactly at length + line = "a" * 99 + assert _chop_line(line, line_length=100) == line + + # Test long line with commas + # The function finds the next comma AFTER line_length and breaks there + long_line = "a" * 50 + "," + "b" * 60 + result = _chop_line(long_line, line_length=100) + assert "&\n" in result + assert result.endswith("&\n") + # Verify it was split + parts = result.split("&\n") + assert len(parts) > 1 + + # Test line with multiple commas - breaks at comma after line_length + multi_comma = ", ".join(["x" * 30 for _ in range(10)]) + result = _chop_line(multi_comma, line_length=100) + # Should have multiple parts + parts = [p for p in result.split("&\n") if p] + assert len(parts) > 1 + # Each part should end with comma (except possibly last) + for part in parts[:-1]: + assert part.strip().endswith(",") + + # Test line without commas after line_length + no_comma = "a" * 150 + result = _chop_line(no_comma, line_length=100) + # Should return whole line with &\n at end + assert result == no_comma + "&\n" + + +def test_most_common(): + """Test _most_common function finds most frequent element.""" + from kmos.io import _most_common + + # Test simple list + assert _most_common([1, 2, 2, 3]) == 2 + + # Test tie - should return earliest + assert _most_common([1, 2, 1, 2]) == 1 + + # Test strings + assert _most_common(["a", "b", "b", "c"]) == "b" + + # Test single element + assert _most_common([1]) == 1 + + # Test all different (should return first) + assert _most_common([1, 2, 3, 4]) == 1 + + # Test frequency tie but different positions + assert _most_common([1, 2, 3, 1, 2]) == 1 # 1 appears first + + +def test_print_dict(capsys): + """Test _print_dict function prints nested dictionaries.""" + from kmos.io import _print_dict + + # Test simple dict + simple = {"key1": "value1", "key2": "value2"} + _print_dict(simple) + captured = capsys.readouterr() + assert "key1 = value1" in captured.out + assert "key2 = value2" in captured.out + + # Test nested dict + nested = {"outer": {"inner": "value"}} + _print_dict(nested) + captured = capsys.readouterr() + assert "outer:" in captured.out + assert " inner = value" in captured.out + + # Test with custom indent + _print_dict({"key": "value"}, indent=" ") + captured = capsys.readouterr() + assert " key = value" in captured.out + + +def test_casetree_dict(): + """Test _casetree_dict function generates conditional assignments.""" + from kmos.io import _casetree_dict + + # Create a string buffer to capture output + out = stdio.StringIO() + + # Test simple dictionary with non-dict values + # These are output as "key = value; return" statements + simple = {1: "case_1", 2: "case_2"} + _casetree_dict(simple, indent=" ", out=out) + output = out.getvalue() + + # Should contain assignment statements + assert "1 = case_1; return" in output + assert "2 = case_2; return" in output + + # Test nested dictionary (creates Fortran case structure) + out = stdio.StringIO() + # String keys with dict values create case() statements + nested = { + "species1": {"state1": "action1", "state2": "action2"}, + "default": {"state3": "action3"}, + } + _casetree_dict(nested, indent=" ", out=out) + output = out.getvalue() + + # Should have case structure for nested dicts + assert "case(species1)" in output + assert "state1 = action1; return" in output + assert "state2 = action2; return" in output + # Default with nested dict creates "case default" + assert "case default" in output + assert "state3 = action3; return" in output + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 842edb36f3f344266ded02a329affafd0efa6740 Mon Sep 17 00:00:00 2001 From: Max Hoffmann Date: Tue, 30 Dec 2025 17:35:17 -0300 Subject: [PATCH 6/6] Factor out acf function into kmos.io.acf --- kmos/{io.py => io/__init__.py} | 109 +-- kmos/io/acf.py | 1321 ++++++++++++++++++++++++++++++++ kmos/run/__init__.py | 4 +- 3 files changed, 1337 insertions(+), 97 deletions(-) rename kmos/{io.py => io/__init__.py} (97%) create mode 100644 kmos/io/acf.py diff --git a/kmos/io.py b/kmos/io/__init__.py similarity index 97% rename from kmos/io.py rename to kmos/io/__init__.py index 266b1882..e5739c56 100644 --- a/kmos/io.py +++ b/kmos/io/__init__.py @@ -95,26 +95,6 @@ def _chop_line(outstr, line_length=100): return "".join(outstr_list) -def compact_deladd_init(modified_process, out): - n = len(modified_processes) # noqa: F821 - TODO: should be modified_process - out.write("integer :: n\n") - out.write("integer, dimension(%s, 4) :: sites, cells\n\n" % n) - - -def compact_deladd_statements(modified_processes, out, action): - n = len(modified_processes) - sites = np.zeros((n, 4), int) - cells = np.zeros((n, 4), int) - - for i, (process, offset) in enumerate(modified_procs): # noqa: F821 - TODO: should be modified_processes - cells[i, :] = np.array(offset + [0]) - sites[i, :] = np.array(offset + [1]) - - out.write("do n = 1, %s\n" % (n + 1)) - out.write(" call %s_proc(nli_%s(cell + %s), cell + %s)\n" % ()) # noqa: F507 - TODO: fix format arguments - out.write("enddo\n") - - def _most_common(L): # thanks go to Alex Martelli for this function # get an iterable of (item, iterable) pairs @@ -151,7 +131,7 @@ def write_template(self, filename, target=None, options=None): with open( os.path.join( - os.path.dirname(__file__), + os.path.dirname(os.path.dirname(__file__)), "fortran_src", "{filename}.mpy".format(**locals()), ) @@ -202,7 +182,7 @@ def write_proclist(self, smart=True, code_generator="local_smart"): # write the proclist_constant module from the template with open( os.path.join( - os.path.dirname(__file__), + os.path.dirname(os.path.dirname(__file__)), "fortran_src", "proclist_constants_otf.mpy", ) @@ -232,77 +212,14 @@ def write_proclist(self, smart=True, code_generator="local_smart"): def write_proclist_acf(self, smart=True, code_generator="local_smart"): """Write the proclist_acf.f90 module, i.e. the routines to run the calculation of the autocorrelation function or to record the displacment.. - """ - # make long lines a little shorter - data = self.data - - # write header section and module imports - out = open("%s/proclist_acf.f90" % self.dir, "w") - out.write( - ( - "module proclist_acf\n" - "use kind_values\n" - "use base, only: &\n" - " update_accum_rate, &\n" - " update_integ_rate, &\n" - " determine_procsite, &\n" - " update_clocks, &\n" - " avail_sites, &\n" - " null_species, &\n" - " increment_procstat\n\n" - "use base_acf, only: &\n" - " assign_particle_id, &\n" - " update_id_arr, &\n" - " update_displacement, &\n" - " update_config_bin, &\n" - " update_buffer_acf, &\n" - " update_property_and_buffer_acf, &\n" - " drain_process, &\n" - " source_process, &\n" - " update_kmc_step_acf, &\n" - " get_kmc_step_acf, &\n" - " update_trajectory, &\n" - " update_displacement, &\n" - " nr_of_annhilations, &\n" - " wrap_count, &\n" - " update_after_wrap_acf\n\n" - "use lattice\n\n" - "use proclist\n" - ) - ) - - out.write("\nimplicit none\n") - - out.write("\n\ncontains\n\n") - - if code_generator == "local_smart": - self.write_proclist_generic_subroutines_acf( - data, out, code_generator=code_generator - ) - self.write_proclist_get_diff_sites_acf_smart(data, out) - self.write_proclist_get_diff_sites_displacement_smart(data, out) - self.write_proclist_acf_end(out) - - elif code_generator == "lat_int": - self.write_proclist_generic_subroutines_acf( - data, out, code_generator=code_generator - ) - self.write_proclist_get_diff_sites_acf_otf(data, out) - self.write_proclist_get_diff_sites_displacement_otf(data, out) - self.write_proclist_acf_end(out) - - elif code_generator == "otf": - self.write_proclist_generic_subroutines_acf( - data, out, code_generator=code_generator - ) - self.write_proclist_get_diff_sites_acf_otf(data, out) - self.write_proclist_get_diff_sites_displacement_otf(data, out) - self.write_proclist_acf_end(out) - else: - raise Exception("Don't know this code generator '%s'" % code_generator) + Note: This method now delegates to kmos.io.acf for the actual implementation. + It is maintained for backward compatibility. + """ + from kmos.io.acf import get_acf_writer - out.close() + writer = get_acf_writer(self.data, self.dir, code_generator) + writer.write_proclist_acf() def write_proclist_constants( self, @@ -314,7 +231,9 @@ def write_proclist_constants( ): with open( os.path.join( - os.path.dirname(__file__), "fortran_src", "proclist_constants.mpy" + os.path.dirname(os.path.dirname(__file__)), + "fortran_src", + "proclist_constants.mpy", ) ) as infile: template = infile.read() @@ -344,7 +263,7 @@ def write_proclist_generic_subroutines( with open( os.path.join( - os.path.dirname(__file__), + os.path.dirname(os.path.dirname(__file__)), "fortran_src", "proclist_generic_subroutines.mpy", ) @@ -367,7 +286,7 @@ def write_proclist_generic_subroutines_acf( with open( os.path.join( - os.path.dirname(__file__), + os.path.dirname(os.path.dirname(__file__)), "fortran_src", "proclist_generic_subroutines_acf.mpy", ) @@ -2089,7 +2008,7 @@ def write_proclist_lat_int_nli_casetree(self, data, lat_int_groups, progress_bar out.write("contains\n") fname = "nli_%s" % lat_int_group if data.meta.debug > 0: - out.write("function %(cell)\n" % (fname)) # noqa: F509 - TODO: fix format string + out.write("function %s(cell)\n" % (fname)) else: # DEBUGGING # out.write('function nli_%s(cell)\n' diff --git a/kmos/io/acf.py b/kmos/io/acf.py new file mode 100644 index 00000000..4565980a --- /dev/null +++ b/kmos/io/acf.py @@ -0,0 +1,1321 @@ +"""ACF Code Generation Module. + +This module handles the generation of Fortran code for autocorrelation +function (ACF) and mean squared displacement (MSD) calculations in kinetic +Monte Carlo simulations. + +The module uses pure inheritance to support multiple kMC backends: +- local_smart: Standard backend with local site tracking +- otf: On-the-fly backend with coordinate offsets +- lat_int: Lattice interpolation backend (uses OTF implementation) + +Usage Example +------------- + +>>> from kmos.io.acf import get_acf_writer +>>> writer = get_acf_writer(project_tree, export_dir, code_generator='local_smart') +>>> writer.write_proclist_acf() + +This generates proclist_acf.f90 containing: +- do_kmc_steps_acf(): Main ACF calculation loop +- get_diff_sites_acf(): Site tracking for diffusion +- get_diff_sites_displacement(): Displacement tracking for MSD + +Backend Differences +------------------- + +Smart Backend: + - Direct lattice site indexing + - No coordinate offsets + - Simpler conditional logic + +OTF Backend: + - Coordinate offset: lsite + (/0,0,0,-1/) + - More complex conditional branching + - Handles executing coordinate differences + +Design Pattern +-------------- +Uses pure inheritance with the Template Method pattern: +- ACFWriterBase: Abstract base class with shared orchestration +- SmartACFWriter: local_smart implementation +- OTFACFWriter: otf/lat_int implementation + +Migration from Old API +---------------------- + +Old way (still supported): + from kmos.io import ProcListWriter + writer = ProcListWriter(project, export_dir) + writer.write_proclist_acf(code_generator='local_smart') + +New way (recommended): + from kmos.io.acf import get_acf_writer + writer = get_acf_writer(project, export_dir, code_generator='local_smart') + writer.write_proclist_acf() + +The old API is maintained for backward compatibility. +""" + +import os +from kmos.utils import evaluate_template + + +class ACFWriterBase: + """Abstract base class for ACF code generation. + + This class provides the template method pattern for generating Fortran ACF + code. Subclasses must implement backend-specific methods for site tracking + and displacement calculations. + + Attributes: + data: Project data tree containing process and species information + dir: Output directory for generated Fortran files + """ + + def __init__(self, data, dir): + """Initialize the ACF writer. + + Args: + data: Project data tree + dir: Output directory path + """ + self.data = data + self.dir = dir + + def write_proclist_acf(self): + """Write the proclist_acf.f90 module. + + This is the main entry point that orchestrates the generation of the + complete ACF module. It creates the file, writes the header, generates + backend-specific subroutines, and finalizes the module. + """ + out = open(f"{self.dir}/proclist_acf.f90", "w") + self._write_header(out) + self._write_generic_subroutines(out) + self.write_diff_sites_acf(out) + self.write_diff_sites_displacement(out) + self._write_end(out) + out.close() + + def _write_header(self, out): + """Write the module header and imports. + + Args: + out: File handle to write to + """ + out.write( + ( + "module proclist_acf\n" + "use kind_values\n" + "use base, only: &\n" + " update_accum_rate, &\n" + " update_integ_rate, &\n" + " determine_procsite, &\n" + " update_clocks, &\n" + " avail_sites, &\n" + " null_species, &\n" + " increment_procstat\n\n" + "use base_acf, only: &\n" + " assign_particle_id, &\n" + " update_id_arr, &\n" + " update_displacement, &\n" + " update_config_bin, &\n" + " update_buffer_acf, &\n" + " update_property_and_buffer_acf, &\n" + " drain_process, &\n" + " source_process, &\n" + " update_kmc_step_acf, &\n" + " get_kmc_step_acf, &\n" + " update_trajectory, &\n" + " update_displacement, &\n" + " nr_of_annhilations, &\n" + " wrap_count, &\n" + " update_after_wrap_acf\n\n" + "use lattice\n\n" + "use proclist\n" + ) + ) + out.write("\nimplicit none\n") + out.write("\n\ncontains\n\n") + + def _write_generic_subroutines(self, out): + """Write generic ACF subroutines using template. + + Args: + out: File handle to write to + """ + with open( + os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "fortran_src", + "proclist_generic_subroutines_acf.mpy", + ) + ) as infile: + template = infile.read() + + # Determine code_generator from class type + if isinstance(self, SmartACFWriter): + code_generator = "local_smart" + elif isinstance(self, OTFACFWriter): + code_generator = "otf" # Will work for both otf and lat_int + else: + code_generator = "local_smart" # Fallback + + out.write( + evaluate_template( + template, + self=self, + data=self.data, + code_generator=code_generator, + ) + ) + + def _write_end(self, out): + """Write the module end statement. + + Args: + out: File handle to write to + """ + out.write("end module proclist_acf\n") + + def write_diff_sites_acf(self, out): + """Write the get_diff_sites_acf subroutine (backend-specific). + + This method must be implemented by subclasses to generate + backend-specific code for tracking initial and final sites + during diffusion processes. + + Args: + out: File handle to write to + + Raises: + NotImplementedError: If not implemented by subclass + """ + raise NotImplementedError("Subclasses must implement write_diff_sites_acf()") + + def write_diff_sites_displacement(self, out): + """Write the get_diff_sites_displacement subroutine (backend-specific). + + This method must be implemented by subclasses to generate + backend-specific code for tracking displacements during + diffusion processes. + + Args: + out: File handle to write to + + Raises: + NotImplementedError: If not implemented by subclass + """ + raise NotImplementedError( + "Subclasses must implement write_diff_sites_displacement()" + ) + + +class SmartACFWriter(ACFWriterBase): + """ACF writer for local_smart backend. + + This implementation uses direct lattice site indexing without coordinate + offsets, providing simpler conditional logic for site tracking. + """ + + def write_diff_sites_acf(self, out): + """Write get_diff_sites_acf subroutine for smart backend. + + Args: + out: File handle to write to + """ + data = self.data + + # Write function header and documentation + out.write( + "subroutine get_diff_sites_acf(proc,nr_site,init_site,fin_site)\n\n" + "!****f* proclist_acf/get_diff_sites_acf\n" + "! FUNCTION\n" + "! get_diff_sites_acf gives the site ``init_site``, which is occupied by the particle before the diffusion process \n" + "! and also the site ``fin_site`` after the diffusion process.\n" + "!\n" + "! ARGUMENTS\n" + "!\n" + "! * ``proc`` integer representing the process number\n" + "! * ``nr_site`` integer representing the site\n" + "! * ``init_site`` integer representing the site, which is occupied by the particle before the diffusion process takes place\n" + "! * ``fin_site`` integer representing the site, which is occupied by the particle after the diffusion process\n" + "!******\n" + " integer(kind=iint), intent(in) :: proc\n" + " integer(kind=iint), intent(in) :: nr_site\n" + " integer(kind=iint), intent(out) :: init_site, fin_site\n\n" + " integer(kind=iint), dimension(4) :: lsite\n" + " integer(kind=iint), dimension(4) :: lsite_new\n" + " integer(kind=iint), dimension(4) :: lsite_old\n" + " integer(kind=iint) :: exit_site, entry_site\n\n" + " lsite = nr2lattice(nr_site, :)\n\n" + " select case(proc)\n" + ) + + # Iterate over all processes + for process in data.process_list: + out.write(" case(%s)\n" % process.name) + source_species = 0 + if data.meta.debug > 0: + out.write( + ( + 'print *,"PROCLIST/RUN_PROC_NR/NAME","%s"\n' + 'print *,"PROCLIST/RUN_PROC_NR/LSITE","lsite"\n' + 'print *,"PROCLIST/RUN_PROC_NR/SITE","site"\n' + ) + % process.name + ) + + # First pass: determine source species + for action in process.action_list: + try: + previous_species = list( + filter( + lambda x: x.coord.ff() == action.coord.ff(), + process.condition_list, + ) + )[0].species + except IndexError: + import warnings + + warnings.warn( + """Process %s seems to be ill-defined. + Every action needs a corresponding condition + for the same site.""" + % process.name, + UserWarning, + ) + previous_species = None + if action.species == previous_species: + source_species = action.species + + # Second pass: generate action code + for action in process.action_list: + if action.coord == process.executing_coord(): + relative_coord = "lsite" + else: + relative_coord = ( + "lsite%s" % (action.coord - process.executing_coord()).radd_ff() + ) + + try: + previous_species = list( + filter( + lambda x: x.coord.ff() == action.coord.ff(), + process.condition_list, + ) + )[0].species + except IndexError: + import warnings + + warnings.warn( + """Process %s seems to be ill-defined. + Every action needs a corresponding condition + for the same site.""" + % process.name, + UserWarning, + ) + previous_species = None + + if action.species[0] == "^": + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","create %s_%s"\n' + % (action.coord.layer, action.coord.name) + ) + out.write( + " call create_%s_%s(%s, %s)\n" + % ( + action.coord.layer, + action.coord.name, + relative_coord, + action.species[1:], + ) + ) + elif action.species[0] == "$": + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","annihilate %s_%s"\n' + % (action.coord.layer, action.coord.name) + ) + out.write( + " call annihilate_%s_%s(%s, %s)\n" + % ( + action.coord.layer, + action.coord.name, + relative_coord, + action.species[1:], + ) + ) + elif ( + action.species == data.species_list.default_species + and not action.species == previous_species + and source_species == 0 + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","take %s_%s %s"\n' + % (action.coord.layer, action.coord.name, previous_species) + ) + out.write(" lsite_old = (%s)\n" % (relative_coord)) + out.write( + " init_site = lattice2nr(lsite_old(1),lsite_old(2),lsite_old(3),lsite_old(4))\n" + ) + elif ( + action.species == data.species_list.default_species + and not action.species == previous_species + and not source_species == 0 + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","take %s_%s %s"\n' + % (action.coord.layer, action.coord.name, previous_species) + ) + + out.write(" lsite_old = (%s)\n" % (relative_coord)) + + out.write( + " exit_site = lattice2nr(lsite_old(1),lsite_old(2),lsite_old(3),lsite_old(4))\n" + ) + out.write( + " call drain_process(exit_site,init_site,fin_site)\n" + ) + + else: + if not previous_species == action.species: + if not previous_species == data.species_list.default_species: + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","take %s_%s %s"\n' + % ( + action.coord.layer, + action.coord.name, + previous_species, + ) + ) + out.write( + " call take_%s_%s_%s(%s)\n" + % ( + previous_species, + action.coord.layer, + action.coord.name, + relative_coord, + ) + ) + if source_species == 0: + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","put %s_%s %s"\n' + % ( + action.coord.layer, + action.coord.name, + action.species, + ) + ) + out.write(" lsite_new = (%s)\n" % (relative_coord)) + out.write( + " fin_site = lattice2nr(lsite_new(1),lsite_new(2),lsite_new(3),lsite_new(4))\n" + ) + if not source_species == 0: + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","put %s_%s %s"\n' + % ( + action.coord.layer, + action.coord.name, + action.species, + ) + ) + out.write(" lsite_new = (%s)\n" % (relative_coord)) + out.write( + " entry_site = lattice2nr(lsite_new(1),lsite_new(2),lsite_new(3),lsite_new(4))\n" + ) + out.write( + " call source_process(entry_site,init_site,fin_site)\n" + ) + + out.write("\n") + + out.write(" end select\n\n") + out.write("end subroutine get_diff_sites_acf\n\n") + + def write_diff_sites_displacement(self, out): + """Write get_diff_sites_displacement subroutine for smart backend. + + Args: + out: File handle to write to + """ + data = self.data + + # Write function header and documentation + out.write( + "subroutine get_diff_sites_displacement(proc,nr_site,init_site,fin_site,displace_coord)\n\n" + "!****f* proclist_acf/get_diff_sites_displacement\n" + "! FUNCTION\n" + "! get_diff_sites_displacement gives the site ``init_site``, which is occupied by the particle before the diffusion process \n" + "! and also the site ``fin_site`` after the diffusion process.\n" + "! Additionally, the displacement of the jumping particle will be saved.\n" + "!\n" + "! ARGUMENTS\n" + "!\n" + "! * ``proc`` integer representing the process number\n" + "! * ``nr_site`` integer representing the site\n" + "! * ``init_site`` integer representing the site, which is occupied by the particle before the diffusion process takes place\n" + "! * ``fin_site`` integer representing the site, which is occupied by the particle after the diffusion process\n" + "! * ``displace_coord`` writeable 3 dimensional array, in which the displacement of the jumping particle will be stored.\n" + "!******\n" + " integer(kind=iint), intent(in) :: proc\n" + " integer(kind=iint), intent(in) :: nr_site\n" + " integer(kind=iint), intent(out) :: init_site, fin_site\n\n" + " integer(kind=iint), dimension(4) :: lsite\n" + " integer(kind=iint), dimension(4) :: lsite_new\n" + " integer(kind=iint), dimension(4) :: lsite_old\n" + " integer(kind=iint) :: exit_site, entry_site\n" + " real(kind=rdouble), dimension(3), intent(out) :: displace_coord\n\n" + " lsite = nr2lattice(nr_site, :)\n\n" + " select case(proc)\n" + ) + + # Iterate over all processes + for process in data.process_list: + out.write(" case(%s)\n" % process.name) + source_species = 0 + if data.meta.debug > 0: + out.write( + ( + 'print *,"PROCLIST/RUN_PROC_NR/NAME","%s"\n' + 'print *,"PROCLIST/RUN_PROC_NR/LSITE","lsite"\n' + 'print *,"PROCLIST/RUN_PROC_NR/SITE","site"\n' + ) + % process.name + ) + + # First pass: determine source species + for action in process.action_list: + try: + previous_species = list( + filter( + lambda x: x.coord.ff() == action.coord.ff(), + process.condition_list, + ) + )[0].species + except IndexError: + import warnings + + warnings.warn( + """Process %s seems to be ill-defined. + Every action needs a corresponding condition + for the same site.""" + % process.name, + UserWarning, + ) + previous_species = None + if action.species == previous_species: + source_species = action.species + + # Second pass: generate action code + for action in process.action_list: + if action.coord == process.executing_coord(): + relative_coord = "lsite" + else: + relative_coord = ( + "lsite%s" % (action.coord - process.executing_coord()).radd_ff() + ) + + try: + previous_species = list( + filter( + lambda x: x.coord.ff() == action.coord.ff(), + process.condition_list, + ) + )[0].species + except IndexError: + import warnings + + warnings.warn( + """Process %s seems to be ill-defined. + Every action needs a corresponding condition + for the same site.""" + % process.name, + UserWarning, + ) + previous_species = None + + if action.species[0] == "^": + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","create %s_%s"\n' + % (action.coord.layer, action.coord.name) + ) + out.write( + " call create_%s_%s(%s, %s)\n" + % ( + action.coord.layer, + action.coord.name, + relative_coord, + action.species[1:], + ) + ) + elif action.species[0] == "$": + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","annihilate %s_%s"\n' + % (action.coord.layer, action.coord.name) + ) + out.write( + " call annihilate_%s_%s(%s, %s)\n" + % ( + action.coord.layer, + action.coord.name, + relative_coord, + action.species[1:], + ) + ) + elif ( + action.species == data.species_list.default_species + and not action.species == previous_species + and source_species == 0 + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","take %s_%s %s"\n' + % (action.coord.layer, action.coord.name, previous_species) + ) + out.write(" lsite_old = (%s)\n" % (relative_coord)) + out.write( + " init_site = lattice2nr(lsite_old(1),lsite_old(2),lsite_old(3),lsite_old(4))\n" + ) + elif ( + action.species == data.species_list.default_species + and not action.species == previous_species + and not source_species == 0 + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","take %s_%s %s"\n' + % (action.coord.layer, action.coord.name, previous_species) + ) + + out.write(" lsite_old = (%s)\n" % (relative_coord)) + + out.write( + " exit_site = lattice2nr(lsite_old(1),lsite_old(2),lsite_old(3),lsite_old(4))\n" + ) + out.write( + " call drain_process(exit_site,init_site,fin_site)\n" + ) + + else: + if not previous_species == action.species: + if not previous_species == data.species_list.default_species: + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","take %s_%s %s"\n' + % ( + action.coord.layer, + action.coord.name, + previous_species, + ) + ) + out.write( + " call take_%s_%s_%s(%s)\n" + % ( + previous_species, + action.coord.layer, + action.coord.name, + relative_coord, + ) + ) + if source_species == 0: + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","put %s_%s %s"\n' + % ( + action.coord.layer, + action.coord.name, + action.species, + ) + ) + out.write(" lsite_new = (%s)\n" % (relative_coord)) + out.write( + " fin_site = lattice2nr(lsite_new(1),lsite_new(2),lsite_new(3),lsite_new(4))\n" + ) + + if not source_species == 0: + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","put %s_%s %s"\n' + % ( + action.coord.layer, + action.coord.name, + action.species, + ) + ) + out.write(" lsite_new = (%s)\n" % (relative_coord)) + out.write( + " entry_site = lattice2nr(lsite_new(1),lsite_new(2),lsite_new(3),lsite_new(4))\n" + ) + out.write( + " call source_process(entry_site,init_site,fin_site)\n" + ) + + # Add displacement calculation (this is the key difference from write_diff_sites_acf) + out.write( + " displace_coord = matmul(unit_cell_size,(/(lsite_new(1)-lsite_old(1)),(lsite_new(2)-lsite_old(2)),(lsite_new(3)-lsite_old(3))/) + (site_positions(lsite_new(4),:) - site_positions(lsite_old(4),:)))\n" + ) + + out.write("\n") + + out.write(" end select\n\n") + out.write("end subroutine get_diff_sites_displacement\n\n") + + +class OTFACFWriter(ACFWriterBase): + """ACF writer for otf and lat_int backends. + + This implementation uses coordinate offsets (lsite + (/0,0,0,-1/)) and + more complex conditional logic for handling executing coordinate differences. + """ + + def write_diff_sites_acf(self, out): + """Write get_diff_sites_acf subroutine for OTF backend. + + Args: + out: File handle to write to + """ + data = self.data + + # Write function header and documentation + out.write( + "subroutine get_diff_sites_acf(proc,nr_site,init_site,fin_site)\n\n" + "!****f* proclist_acf/get_diff_sites_acf\n" + "! FUNCTION\n" + "! get_diff_sites_acf gives the site ``init_site``, which is occupied by the particle before the diffusion process \n" + "! and also the site ``fin_site`` after the diffusion process.\n" + "!\n" + "! ARGUMENTS\n" + "!\n" + "! * ``proc`` integer representing the process number\n" + "! * ``nr_site`` integer representing the site\n" + "! * ``init_site`` integer representing the site, which is occupied by the particle before the diffusion process takes place\n" + "! * ``fin_site`` integer representing the site, which is occupied by the particle after the diffusion process\n" + "!******\n" + " integer(kind=iint), intent(in) :: proc\n" + " integer(kind=iint), intent(in) :: nr_site\n" + " integer(kind=iint), intent(out) :: init_site, fin_site\n\n" + " integer(kind=iint), dimension(4) :: lsite\n" + " integer(kind=iint), dimension(4) :: lsite_new\n" + " integer(kind=iint), dimension(4) :: lsite_old\n" + " integer(kind=iint) :: exit_site, entry_site\n\n" + " lsite = nr2lattice(nr_site, :) + (/0,0,0,-1/)\n\n" + " select case(proc)\n" + ) + + # Iterate over all processes + for process in data.process_list: + out.write(" case(%s)\n" % process.name) + source_species = 0 + if data.meta.debug > 0: + out.write( + ( + 'print *,"PROCLIST/RUN_PROC_NR/NAME","%s"\n' + 'print *,"PROCLIST/RUN_PROC_NR/LSITE","lsite"\n' + 'print *,"PROCLIST/RUN_PROC_NR/SITE","site"\n' + ) + % process.name + ) + + # First pass: determine source species + for action in process.action_list: + try: + previous_species = list( + filter( + lambda x: x.coord.ff() == action.coord.ff(), + process.condition_list, + ) + )[0].species + except IndexError: + import warnings + + warnings.warn( + """Process %s seems to be ill-defined. + Every action needs a corresponding condition + for the same site.""" + % process.name, + UserWarning, + ) + previous_species = None + if action.species == previous_species: + source_species = action.species + + # Second pass: generate action code (with enumeration for OTF) + for i_action, action in enumerate(process.action_list): + if action.coord == process.executing_coord(): + relative_coord = "lsite" + else: + relative_coord = ( + "lsite%s" % (action.coord - process.executing_coord()).radd_ff() + ) + + action_coord = process.action_list[i_action].coord.radd_ff() + process_exec = process.action_list[1 - i_action].coord.radd_ff() + + try: + previous_species = list( + filter( + lambda x: x.coord.ff() == action.coord.ff(), + process.condition_list, + ) + )[0].species + except IndexError: + import warnings + + warnings.warn( + """Process %s seems to be ill-defined. + Every action needs a corresponding condition + for the same site.""" + % process.name, + UserWarning, + ) + previous_species = None + + if action.species[0] == "^": + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","create %s_%s"\n' + % (action.coord.layer, action.coord.name) + ) + out.write( + " call create_%s_%s(%s, %s)\n" + % ( + action.coord.layer, + action.coord.name, + relative_coord, + action.species[1:], + ) + ) + elif action.species[0] == "$": + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","annihilate %s_%s"\n' + % (action.coord.layer, action.coord.name) + ) + out.write( + " call annihilate_%s_%s(%s, %s)\n" + % ( + action.coord.layer, + action.coord.name, + relative_coord, + action.species[1:], + ) + ) + elif ( + action.species == data.species_list.default_species + and not action.species == previous_species + and source_species == 0 + and action.coord == process.executing_coord() + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","take %s_%s %s"\n' + % (action.coord.layer, action.coord.name, previous_species) + ) + out.write(" lsite_new = lsite%s\n" % (process_exec)) + out.write( + " fin_site = lattice2nr(lsite_new(1),lsite_new(2),lsite_new(3),lsite_new(4))\n" + ) + elif ( + action.species == data.species_list.default_species + and not action.species == previous_species + and source_species == 0 + and not action.coord == process.executing_coord() + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","take %s_%s %s"\n' + % (action.coord.layer, action.coord.name, previous_species) + ) + out.write(" lsite_old = lsite%s\n" % (action_coord)) + out.write( + " init_site = lattice2nr(lsite_old(1),lsite_old(2),lsite_old(3),lsite_old(4))\n" + ) + elif ( + action.species == data.species_list.default_species + and not action.species == previous_species + and not source_species == 0 + and action.coord == process.executing_coord() + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","take %s_%s %s"\n' + % (action.coord.layer, action.coord.name, previous_species) + ) + + out.write(" lsite_new = lsite%s\n" % (process_exec)) + + out.write( + " entry_site = lattice2nr(lsite_new(1),lsite_new(2),lsite_new(3),lsite_new(4))\n" + ) + out.write( + " call source_process(entry_site,init_site,fin_site)\n" + ) + elif ( + action.species == data.species_list.default_species + and not action.species == previous_species + and not source_species == 0 + and not action.coord == process.executing_coord() + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","take %s_%s %s"\n' + % (action.coord.layer, action.coord.name, previous_species) + ) + + out.write(" lsite_old = lsite%s\n" % (action_coord)) + + out.write( + " exit_site = lattice2nr(lsite_old(1),lsite_old(2),lsite_old(3),lsite_old(4))\n" + ) + out.write( + " call drain_process(exit_site,init_site,fin_site)\n" + ) + + else: + if not previous_species == action.species: + if not previous_species == data.species_list.default_species: + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","take %s_%s %s"\n' + % ( + action.coord.layer, + action.coord.name, + previous_species, + ) + ) + out.write( + " call take_%s_%s_%s(%s)\n" + % ( + previous_species, + action.coord.layer, + action.coord.name, + relative_coord, + ) + ) + if ( + source_species == 0 + and action.coord == process.executing_coord() + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","put %s_%s %s"\n' + % ( + action.coord.layer, + action.coord.name, + action.species, + ) + ) + out.write(" lsite_new = lsite%s\n" % (action_coord)) + out.write( + " fin_site = lattice2nr(lsite_new(1),lsite_new(2),lsite_new(3),lsite_new(4))\n" + ) + if ( + source_species == 0 + and not action.coord == process.executing_coord() + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","put %s_%s %s"\n' + % ( + action.coord.layer, + action.coord.name, + action.species, + ) + ) + out.write(" lsite_old = lsite%s\n" % (process_exec)) + out.write( + " init_site = lattice2nr(lsite_old(1),lsite_old(2),lsite_old(3),lsite_old(4))\n" + ) + if ( + not source_species == 0 + and action.coord == process.executing_coord() + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","put %s_%s %s"\n' + % ( + action.coord.layer, + action.coord.name, + action.species, + ) + ) + out.write(" lsite_new = lsite%s\n" % (action_coord)) + out.write( + " entry_site = lattice2nr(lsite_new(1),lsite_new(2),lsite_new(3),lsite_new(4))\n" + ) + out.write( + " call source_process(entry_site,init_site,fin_site)\n" + ) + if ( + not source_species == 0 + and not action.coord == process.executing_coord() + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","put %s_%s %s"\n' + % ( + action.coord.layer, + action.coord.name, + action.species, + ) + ) + out.write(" lsite_new = lsite%s\n" % (action_coord)) + out.write( + " entry_site = lattice2nr(lsite_new(1),lsite_new(2),lsite_new(3),lsite_new(4))\n" + ) + out.write( + " call source_process(entry_site,init_site,fin_site)\n" + ) + + out.write("\n") + + out.write(" end select\n\n") + out.write("end subroutine get_diff_sites_acf\n\n") + + def write_diff_sites_displacement(self, out): + """Write get_diff_sites_displacement subroutine for OTF backend. + + Args: + out: File handle to write to + """ + data = self.data + + # Write function header and documentation + out.write( + "subroutine get_diff_sites_displacement(proc,nr_site,init_site,fin_site,displace_coord)\n\n" + "!****f* proclist_acf/get_diff_sites_displacement\n" + "! FUNCTION\n" + "! get_diff_sites_displacement gives the site ``init_site``, which is occupied by the particle before the diffusion process \n" + "! and also the site ``fin_site`` after the diffusion process.\n" + "! Additionally, the displacement of the jumping particle will be saved.\n" + "!\n" + "! ARGUMENTS\n" + "!\n" + "! * ``proc`` integer representing the process number\n" + "! * ``nr_site`` integer representing the site\n" + "! * ``init_site`` integer representing the site, which is occupied by the particle before the diffusion process takes place\n" + "! * ``fin_site`` integer representing the site, which is occupied by the particle after the diffusion process\n" + "! * ``displace_coord`` writeable 3 dimensional array, in which the displacement of the jumping particle will be stored.\n" + "!******\n" + " integer(kind=iint), intent(in) :: proc\n" + " integer(kind=iint), intent(in) :: nr_site\n" + " integer(kind=iint), intent(out) :: init_site, fin_site\n\n" + " integer(kind=iint), dimension(4) :: lsite\n" + " integer(kind=iint), dimension(4) :: lsite_new\n" + " integer(kind=iint), dimension(4) :: lsite_old\n" + " integer(kind=iint) :: exit_site, entry_site\n" + " real(kind=rdouble), dimension(3), intent(out) :: displace_coord\n\n" + " lsite = nr2lattice(nr_site, :) + (/0,0,0,-1/)\n\n" + " select case(proc)\n" + ) + + # Iterate over all processes + for process in data.process_list: + out.write(" case(%s)\n" % process.name) + source_species = 0 + if data.meta.debug > 0: + out.write( + ( + 'print *,"PROCLIST/RUN_PROC_NR/NAME","%s"\n' + 'print *,"PROCLIST/RUN_PROC_NR/LSITE","lsite"\n' + 'print *,"PROCLIST/RUN_PROC_NR/SITE","site"\n' + ) + % process.name + ) + + # First pass: determine source species + for action in process.action_list: + try: + previous_species = list( + filter( + lambda x: x.coord.ff() == action.coord.ff(), + process.condition_list, + ) + )[0].species + except IndexError: + import warnings + + warnings.warn( + """Process %s seems to be ill-defined. + Every action needs a corresponding condition + for the same site.""" + % process.name, + UserWarning, + ) + previous_species = None + if action.species == previous_species: + source_species = action.species + + # Second pass: generate action code (with enumeration for OTF) + for i_action, action in enumerate(process.action_list): + if action.coord == process.executing_coord(): + relative_coord = "lsite" + else: + relative_coord = ( + "lsite%s" % (action.coord - process.executing_coord()).radd_ff() + ) + + action_coord = process.action_list[i_action].coord.radd_ff() + process_exec = process.action_list[1 - i_action].coord.radd_ff() + + try: + previous_species = list( + filter( + lambda x: x.coord.ff() == action.coord.ff(), + process.condition_list, + ) + )[0].species + except IndexError: + import warnings + + warnings.warn( + """Process %s seems to be ill-defined. + Every action needs a corresponding condition + for the same site.""" + % process.name, + UserWarning, + ) + previous_species = None + + if action.species[0] == "^": + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","create %s_%s"\n' + % (action.coord.layer, action.coord.name) + ) + out.write( + " call create_%s_%s(%s, %s)\n" + % ( + action.coord.layer, + action.coord.name, + relative_coord, + action.species[1:], + ) + ) + elif action.species[0] == "$": + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","annihilate %s_%s"\n' + % (action.coord.layer, action.coord.name) + ) + out.write( + " call annihilate_%s_%s(%s, %s)\n" + % ( + action.coord.layer, + action.coord.name, + relative_coord, + action.species[1:], + ) + ) + elif ( + action.species == data.species_list.default_species + and not action.species == previous_species + and source_species == 0 + and action.coord == process.executing_coord() + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","take %s_%s %s"\n' + % (action.coord.layer, action.coord.name, previous_species) + ) + out.write(" lsite_new = lsite%s\n" % (process_exec)) + out.write( + " fin_site = lattice2nr(lsite_new(1),lsite_new(2),lsite_new(3),lsite_new(4))\n" + ) + elif ( + action.species == data.species_list.default_species + and not action.species == previous_species + and source_species == 0 + and not action.coord == process.executing_coord() + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","take %s_%s %s"\n' + % (action.coord.layer, action.coord.name, previous_species) + ) + out.write(" lsite_old = lsite%s\n" % (action_coord)) + out.write( + " init_site = lattice2nr(lsite_old(1),lsite_old(2),lsite_old(3),lsite_old(4))\n" + ) + elif ( + action.species == data.species_list.default_species + and not action.species == previous_species + and not source_species == 0 + and action.coord == process.executing_coord() + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","take %s_%s %s"\n' + % (action.coord.layer, action.coord.name, previous_species) + ) + + out.write(" lsite_new = lsite%s\n" % (process_exec)) + + out.write( + " entry_site = lattice2nr(lsite_new(1),lsite_new(2),lsite_new(3),lsite_new(4))\n" + ) + out.write( + " call source_process(entry_site,init_site,fin_site)\n" + ) + elif ( + action.species == data.species_list.default_species + and not action.species == previous_species + and not source_species == 0 + and not action.coord == process.executing_coord() + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","take %s_%s %s"\n' + % (action.coord.layer, action.coord.name, previous_species) + ) + + out.write(" lsite_old = lsite%s\n" % (action_coord)) + + out.write( + " exit_site = lattice2nr(lsite_old(1),lsite_old(2),lsite_old(3),lsite_old(4))\n" + ) + out.write( + " call drain_process(exit_site,init_site,fin_site)\n" + ) + + else: + if not previous_species == action.species: + if not previous_species == data.species_list.default_species: + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","take %s_%s %s"\n' + % ( + action.coord.layer, + action.coord.name, + previous_species, + ) + ) + out.write( + " call take_%s_%s_%s(%s)\n" + % ( + previous_species, + action.coord.layer, + action.coord.name, + relative_coord, + ) + ) + if ( + source_species == 0 + and action.coord == process.executing_coord() + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","put %s_%s %s"\n' + % ( + action.coord.layer, + action.coord.name, + action.species, + ) + ) + out.write(" lsite_new = lsite%s\n" % (action_coord)) + out.write( + " fin_site = lattice2nr(lsite_new(1),lsite_new(2),lsite_new(3),lsite_new(4))\n" + ) + if ( + source_species == 0 + and not action.coord == process.executing_coord() + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","put %s_%s %s"\n' + % ( + action.coord.layer, + action.coord.name, + action.species, + ) + ) + out.write(" lsite_old = lsite%s\n" % (process_exec)) + out.write( + " init_site = lattice2nr(lsite_old(1),lsite_old(2),lsite_old(3),lsite_old(4))\n" + ) + if ( + not source_species == 0 + and action.coord == process.executing_coord() + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","put %s_%s %s"\n' + % ( + action.coord.layer, + action.coord.name, + action.species, + ) + ) + out.write(" lsite_new = lsite%s\n" % (action_coord)) + out.write( + " entry_site = lattice2nr(lsite_new(1),lsite_new(2),lsite_new(3),lsite_new(4))\n" + ) + out.write( + " call source_process(entry_site,init_site,fin_site)\n" + ) + if ( + not source_species == 0 + and not action.coord == process.executing_coord() + ): + if data.meta.debug > 0: + out.write( + 'print *,"PROCLIST/RUN_PROC_NR/ACTION","put %s_%s %s"\n' + % ( + action.coord.layer, + action.coord.name, + action.species, + ) + ) + out.write(" lsite_new = lsite%s\n" % (action_coord)) + out.write( + " entry_site = lattice2nr(lsite_new(1),lsite_new(2),lsite_new(3),lsite_new(4))\n" + ) + out.write( + " call source_process(entry_site,init_site,fin_site)\n" + ) + + # Add displacement calculation (key difference from write_diff_sites_acf) + out.write( + " displace_coord = matmul(unit_cell_size,(/(lsite_new(1)-lsite_old(1)),(lsite_new(2)-lsite_old(2)),(lsite_new(3)-lsite_old(3))/) + (site_positions(lsite_new(4),:) - site_positions(lsite_old(4),:)))\n" + ) + + out.write("\n") + + out.write(" end select\n\n") + out.write("end subroutine get_diff_sites_displacement\n\n") + + +def get_acf_writer(data, dir, code_generator="local_smart"): + """Factory function to create appropriate ACF writer. + + Args: + data: Project data tree + dir: Output directory path + code_generator: Backend type ('local_smart', 'otf', or 'lat_int') + + Returns: + ACFWriter instance (SmartACFWriter or OTFACFWriter) + + Raises: + ValueError: If code_generator is not recognized + + Examples: + >>> writer = get_acf_writer(data, "/tmp/export", "local_smart") + >>> writer.write_proclist_acf() + """ + if code_generator == "local_smart": + return SmartACFWriter(data, dir) + elif code_generator in ("otf", "lat_int"): + return OTFACFWriter(data, dir) + else: + raise ValueError(f"Unknown code_generator: {code_generator}") diff --git a/kmos/run/__init__.py b/kmos/run/__init__.py index d8b20f94..0777c1d7 100644 --- a/kmos/run/__init__.py +++ b/kmos/run/__init__.py @@ -40,6 +40,7 @@ __all__ = ["base", "lattice", "proclist", "KMC_Model"] from ase.atoms import Atoms +from ase.io import write from copy import deepcopy from fnmatch import fnmatch from kmos import evaluate_rate_expression @@ -482,7 +483,6 @@ def run(self): elif signal.upper() == "WRITEOUT": atoms = self.get_atoms() step = self.base.get_kmc_step() - from ase.io import write filename = "%s_%s.traj" % (self.settings.model_name, step) print("Wrote snapshot to %s" % filename) @@ -654,7 +654,7 @@ def export_movie( colors=colors2, ) elif suffix == "traj": - write(filename, atoms) # noqa: F821 - TODO: import write from ase.io + write(filename, atoms) else: writer = kmos.run.png.MyPNG( atoms, show_unit_cell=True, scale=20, model=self, **kwargs