Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ format: ## Format code with ruff
format-check: ## Check code formatting without modifying files
uv run ruff format --check kmos/ tests/

type-check: ## Run type checking with mypy
uv run mypy kmos/
type-check: ## Run type checking with ty
uv run ty check

docs: ## Build documentation
cd doc && uv run make html
Expand Down
43 changes: 28 additions & 15 deletions kmos/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,24 +521,37 @@ def sh(banner):
"""

import IPython

if hasattr(IPython, "release"):
try:
from IPython.terminal.embed import InteractiveShellEmbed

InteractiveShellEmbed(banner1=banner)()

except ImportError:
import sys

# Get the calling frame's namespace to preserve variables like 'model'
frame = sys._getframe(1)
user_ns = {}
user_ns.update(frame.f_globals)
user_ns.update(frame.f_locals)

# Use IPython.embed() for modern IPython (>= 0.11)
# This properly supports magic commands like %time, %timeit, etc.
try:
IPython.embed(banner1=banner, user_ns=user_ns)
except AttributeError:
# Fallback for older IPython versions
if hasattr(IPython, "release"):
try:
from IPython.frontend.terminal.embed import InteractiveShellEmbed
from IPython.terminal.embed import InteractiveShellEmbed

InteractiveShellEmbed(banner1=banner)()
InteractiveShellEmbed(banner1=banner, user_ns=user_ns)()

except ImportError:
from IPython.Shell import IPShellEmbed
try:
from IPython.frontend.terminal.embed import InteractiveShellEmbed

IPShellEmbed(banner=banner)()
else:
from IPython.Shell import IPShellEmbed
InteractiveShellEmbed(banner1=banner, user_ns=user_ns)()

except ImportError:
from IPython.Shell import IPShellEmbed

IPShellEmbed(banner=banner, user_ns=user_ns)()
else:
from IPython.Shell import IPShellEmbed

IPShellEmbed(banner=banner)()
IPShellEmbed(banner=banner, user_ns=user_ns)()
109 changes: 14 additions & 95 deletions kmos/io.py → kmos/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,26 +95,6 @@ def _chop_line(outstr, line_length=100):
return "".join(outstr_list)


def compact_deladd_init(modified_process, out):
n = len(modified_processes) # noqa: F821 - TODO: should be modified_process
out.write("integer :: n\n")
out.write("integer, dimension(%s, 4) :: sites, cells\n\n" % n)


def compact_deladd_statements(modified_processes, out, action):
n = len(modified_processes)
sites = np.zeros((n, 4), int)
cells = np.zeros((n, 4), int)

for i, (process, offset) in enumerate(modified_procs): # noqa: F821 - TODO: should be modified_processes
cells[i, :] = np.array(offset + [0])
sites[i, :] = np.array(offset + [1])

out.write("do n = 1, %s\n" % (n + 1))
out.write(" call %s_proc(nli_%s(cell + %s), cell + %s)\n" % ()) # noqa: F507 - TODO: fix format arguments
out.write("enddo\n")


def _most_common(L):
# thanks go to Alex Martelli for this function
# get an iterable of (item, iterable) pairs
Expand Down Expand Up @@ -151,7 +131,7 @@ def write_template(self, filename, target=None, options=None):

with open(
os.path.join(
os.path.dirname(__file__),
os.path.dirname(os.path.dirname(__file__)),
"fortran_src",
"{filename}.mpy".format(**locals()),
)
Expand Down Expand Up @@ -202,7 +182,7 @@ def write_proclist(self, smart=True, code_generator="local_smart"):
# write the proclist_constant module from the template
with open(
os.path.join(
os.path.dirname(__file__),
os.path.dirname(os.path.dirname(__file__)),
"fortran_src",
"proclist_constants_otf.mpy",
)
Expand Down Expand Up @@ -232,77 +212,14 @@ def write_proclist(self, smart=True, code_generator="local_smart"):
def write_proclist_acf(self, smart=True, code_generator="local_smart"):
"""Write the proclist_acf.f90 module, i.e. the routines to run the
calculation of the autocorrelation function or to record the displacment..
"""
# make long lines a little shorter
data = self.data

# write header section and module imports
out = open("%s/proclist_acf.f90" % self.dir, "w")
out.write(
(
"module proclist_acf\n"
"use kind_values\n"
"use base, only: &\n"
" update_accum_rate, &\n"
" update_integ_rate, &\n"
" determine_procsite, &\n"
" update_clocks, &\n"
" avail_sites, &\n"
" null_species, &\n"
" increment_procstat\n\n"
"use base_acf, only: &\n"
" assign_particle_id, &\n"
" update_id_arr, &\n"
" update_displacement, &\n"
" update_config_bin, &\n"
" update_buffer_acf, &\n"
" update_property_and_buffer_acf, &\n"
" drain_process, &\n"
" source_process, &\n"
" update_kmc_step_acf, &\n"
" get_kmc_step_acf, &\n"
" update_trajectory, &\n"
" update_displacement, &\n"
" nr_of_annhilations, &\n"
" wrap_count, &\n"
" update_after_wrap_acf\n\n"
"use lattice\n\n"
"use proclist\n"
)
)

out.write("\nimplicit none\n")

out.write("\n\ncontains\n\n")

if code_generator == "local_smart":
self.write_proclist_generic_subroutines_acf(
data, out, code_generator=code_generator
)
self.write_proclist_get_diff_sites_acf_smart(data, out)
self.write_proclist_get_diff_sites_displacement_smart(data, out)
self.write_proclist_acf_end(out)

elif code_generator == "lat_int":
self.write_proclist_generic_subroutines_acf(
data, out, code_generator=code_generator
)
self.write_proclist_get_diff_sites_acf_otf(data, out)
self.write_proclist_get_diff_sites_displacement_otf(data, out)
self.write_proclist_acf_end(out)

elif code_generator == "otf":
self.write_proclist_generic_subroutines_acf(
data, out, code_generator=code_generator
)
self.write_proclist_get_diff_sites_acf_otf(data, out)
self.write_proclist_get_diff_sites_displacement_otf(data, out)
self.write_proclist_acf_end(out)

else:
raise Exception("Don't know this code generator '%s'" % code_generator)
Note: This method now delegates to kmos.io.acf for the actual implementation.
It is maintained for backward compatibility.
"""
from kmos.io.acf import get_acf_writer

out.close()
writer = get_acf_writer(self.data, self.dir, code_generator)
writer.write_proclist_acf()

def write_proclist_constants(
self,
Expand All @@ -314,7 +231,9 @@ def write_proclist_constants(
):
with open(
os.path.join(
os.path.dirname(__file__), "fortran_src", "proclist_constants.mpy"
os.path.dirname(os.path.dirname(__file__)),
"fortran_src",
"proclist_constants.mpy",
)
) as infile:
template = infile.read()
Expand Down Expand Up @@ -344,7 +263,7 @@ def write_proclist_generic_subroutines(

with open(
os.path.join(
os.path.dirname(__file__),
os.path.dirname(os.path.dirname(__file__)),
"fortran_src",
"proclist_generic_subroutines.mpy",
)
Expand All @@ -367,7 +286,7 @@ def write_proclist_generic_subroutines_acf(

with open(
os.path.join(
os.path.dirname(__file__),
os.path.dirname(os.path.dirname(__file__)),
"fortran_src",
"proclist_generic_subroutines_acf.mpy",
)
Expand Down Expand Up @@ -2089,7 +2008,7 @@ def write_proclist_lat_int_nli_casetree(self, data, lat_int_groups, progress_bar
out.write("contains\n")
fname = "nli_%s" % lat_int_group
if data.meta.debug > 0:
out.write("function %(cell)\n" % (fname)) # noqa: F509 - TODO: fix format string
out.write("function %s(cell)\n" % (fname))
else:
# DEBUGGING
# out.write('function nli_%s(cell)\n'
Expand Down
Loading
Loading