From f850f7a9dcfc69ca90c6cc73d03d282c9a340ad1 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Fri, 17 Jan 2025 13:44:18 +0000 Subject: [PATCH 1/8] Improve kernel caching Each parloop invocation was doing more than necessary leading to quite poor performance. --- firedrake/scripts/firedrake_clean.py | 7 +- pyop2/caching.py | 3 +- pyop2/global_kernel.py | 111 ++++++++++++++++----------- 3 files changed, 71 insertions(+), 50 deletions(-) diff --git a/firedrake/scripts/firedrake_clean.py b/firedrake/scripts/firedrake_clean.py index f411d498c3..678a3c35fa 100755 --- a/firedrake/scripts/firedrake_clean.py +++ b/firedrake/scripts/firedrake_clean.py @@ -4,10 +4,7 @@ from firedrake.configuration import setup_cache_dirs from pyop2.compilation import clear_compiler_disk_cache as pyop2_clear_cache from firedrake.tsfc_interface import clear_cache as tsfc_clear_cache -try: - import platformdirs as appdirs -except ImportError: - import appdirs +import platformdirs def main(): @@ -20,7 +17,7 @@ def main(): print(f"Removing cached PyOP2 code from {os.environ.get('PYOP2_CACHE_DIR', '???')}") pyop2_clear_cache() - pytools_cache = appdirs.user_cache_dir("pytools", "pytools") + pytools_cache = platformdirs.user_cache_dir("pytools", "pytools") print(f"Removing cached pytools files from {pytools_cache}") if os.path.exists(pytools_cache): shutil.rmtree(pytools_cache, ignore_errors=True) diff --git a/pyop2/caching.py b/pyop2/caching.py index 2948ddede7..7a827d90b0 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -547,7 +547,8 @@ def wrapper(*args, **kwargs): value = local_cache.get(key, CACHE_MISS) if value is CACHE_MISS: - value = func(*args, **kwargs) + with PETSc.Log.Event("pyop2: handle cache miss"): + value = func(*args, **kwargs) return local_cache.setdefault(key, value) return wrapper diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index 7edfed0771..433a2992f4 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -8,12 +8,15 @@ import loopy as lp import numpy as np import pytools +from loopy.codegen.result import process_preambles from petsc4py import PETSc from pyop2 import mpi +from pyop2.caching import parallel_cache, serial_cache from pyop2.compilation import add_profiling_events, load from pyop2.configuration import configuration from pyop2.datatypes import IntType, as_ctypes +from pyop2.codegen.rep2loopy import generate from pyop2.types import IterationRegion, Constant, READ from pyop2.utils import cached_property, get_petsc_dir @@ -326,8 +329,7 @@ def __call__(self, comm, *args): :arg comm: Communicator the execution is collective over. :*args: Arguments to pass to the compiled kernel. """ - # It is unnecessary to cache this call as it is cached in pyop2/compilation.py - func = self.compile(comm) + func = _compile_global_kernel(self, comm) func(*args) @property @@ -364,48 +366,7 @@ def builder(self): @cached_property def code_to_compile(self): """Return the C/C++ source code as a string.""" - from pyop2.codegen.rep2loopy import generate - - with PETSc.Log.Event("GlobalKernel: generate loopy"): - wrapper = generate(self.builder) - - with PETSc.Log.Event("GlobalKernel: generate device code"): - code = lp.generate_code_v2(wrapper) - - if self.local_kernel.cpp: - from loopy.codegen.result import process_preambles - preamble = "".join(process_preambles(getattr(code, "device_preambles", []))) - device_code = "\n\n".join(str(dp.ast) for dp in code.device_programs) - return preamble + "\nextern \"C\" {\n" + device_code + "\n}\n" - return code.device_code() - - @PETSc.Log.EventDecorator() - @mpi.collective - def compile(self, comm): - """Compile the kernel. - - :arg comm: The communicator the compilation is collective over. - :returns: A ctypes function pointer for the compiled function. - """ - extension = "cpp" if self.local_kernel.cpp else "c" - cppargs = ( - tuple("-I%s/include" % d for d in get_petsc_dir()) - + tuple("-I%s" % d for d in self.local_kernel.include_dirs) - + ("-I%s" % os.path.abspath(os.path.dirname(__file__)),) - ) - ldargs = ( - tuple("-L%s/lib" % d for d in get_petsc_dir()) - + tuple("-Wl,-rpath,%s/lib" % d for d in get_petsc_dir()) - + ("-lpetsc", "-lm") - + tuple(self.local_kernel.ldargs) - ) - - dll = load(self.code_to_compile, extension, cppargs=cppargs, ldargs=ldargs, comm=comm) - add_profiling_events(dll, self.local_kernel.events) - fn = getattr(dll, self.name) - fn.argtypes = self.argtypes - fn.restype = ctypes.c_int - return fn + return _generate_code_from_global_kernel(self) @cached_property def argtypes(self): @@ -427,3 +388,65 @@ def num_flops(self, iterset): elif region not in {IterationRegion.TOP, IterationRegion.BOTTOM}: size = layers - 1 return size * self.local_kernel.num_flops + + @cached_property + def _cppargs(self): + cppargs = [f"-I{d}/include" for d in get_petsc_dir()] + cppargs.extend(f"-I{d}" for d in self.local_kernel.include_dirs) + cppargs.append(f"-I{os.path.abspath(os.path.dirname(__file__))}") + return tuple(cppargs) + + @cached_property + def _ldargs(self): + ldargs = [f"-L{d}/lib" for d in get_petsc_dir()] + ldargs.extend(f"-Wl,-rpath,{d}/lib" for d in get_petsc_dir()) + ldargs.extend(["-lpetsc", "-lm"]) + ldargs.extend(self.local_kernel.ldargs) + return tuple(ldargs) + + +@serial_cache(hashkey=lambda knl: knl.cache_key) +def _generate_code_from_global_kernel(kernel): + with PETSc.Log.Event("GlobalKernel: generate loopy"): + wrapper = generate(kernel.builder) + + with PETSc.Log.Event("GlobalKernel: generate device code"): + code = lp.generate_code_v2(wrapper) + + if kernel.local_kernel.cpp: + preamble = "".join(process_preambles(getattr(code, "device_preambles", []))) + device_code = "\n\n".join(str(dp.ast) for dp in code.device_programs) + return preamble + "\nextern \"C\" {\n" + device_code + "\n}\n" + + return code.device_code() + + +@parallel_cache(hashkey=lambda knl, _: knl.cache_key) +@mpi.collective +def _compile_global_kernel(kernel, comm): + """Compile the kernel. + + Parameters + ---------- + kernel : + The global kernel to generate code for. + comm : + The communicator the compilation is collective over. + + Returns + ------- + A ctypes function pointer for the compiled function. + + """ + dll = load( + kernel.code_to_compile, + "cpp" if kernel.local_kernel.cpp else "c", + cppargs=kernel._cppargs, + ldargs=kernel._ldargs, + comm=comm, + ) + add_profiling_events(dll, kernel.local_kernel.events) + fn = getattr(dll, kernel.name) + fn.argtypes = kernel.argtypes + fn.restype = ctypes.c_int + return fn From 7c153ad687cdc5af00685f00b0c08c4e8d643013 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Fri, 17 Jan 2025 13:47:06 +0000 Subject: [PATCH 2/8] fixup --- firedrake/preconditioners/patch.py | 5 +++-- pyop2/global_kernel.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/firedrake/preconditioners/patch.py b/firedrake/preconditioners/patch.py index 5e9d0d4fa0..d79a03b8a1 100644 --- a/firedrake/preconditioners/patch.py +++ b/firedrake/preconditioners/patch.py @@ -27,6 +27,7 @@ from pyop2.codegen.builder import Pack, MatPack, DatPack from pyop2.codegen.representation import Comparison, Literal from pyop2.codegen.rep2loopy import register_petsc_function +from pyop2.global_kernel import compile_global_kernel __all__ = ("PatchPC", "PlaneSmoother", "PatchSNES") @@ -222,7 +223,7 @@ def matrix_funptr(form, state): wrapper_knl_args = tuple(a.global_kernel_arg for a in args) mod = op2.GlobalKernel(kinfo.kernel, wrapper_knl_args, subset=True) - kernels.append(CompiledKernel(mod.compile(iterset.comm), kinfo)) + kernels.append(CompiledKernel(compile_global_kernel(mod, iterset.comm), kinfo)) return cell_kernels, int_facet_kernels @@ -316,7 +317,7 @@ def residual_funptr(form, state): wrapper_knl_args = tuple(a.global_kernel_arg for a in args) mod = op2.GlobalKernel(kinfo.kernel, wrapper_knl_args, subset=True) - kernels.append(CompiledKernel(mod.compile(iterset.comm), kinfo)) + kernels.append(CompiledKernel(compile_global_kernel(mod, iterset.comm), kinfo)) return cell_kernels, int_facet_kernels diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index 433a2992f4..aac7d30c3d 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -329,7 +329,7 @@ def __call__(self, comm, *args): :arg comm: Communicator the execution is collective over. :*args: Arguments to pass to the compiled kernel. """ - func = _compile_global_kernel(self, comm) + func = compile_global_kernel(self, comm) func(*args) @property @@ -423,7 +423,7 @@ def _generate_code_from_global_kernel(kernel): @parallel_cache(hashkey=lambda knl, _: knl.cache_key) @mpi.collective -def _compile_global_kernel(kernel, comm): +def compile_global_kernel(kernel, comm): """Compile the kernel. Parameters From 45e660d92cab17662644018047427427a2a43eac Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Fri, 17 Jan 2025 15:02:23 +0000 Subject: [PATCH 3/8] test fixups --- tests/pyop2/test_caching.py | 39 ++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/tests/pyop2/test_caching.py b/tests/pyop2/test_caching.py index cfd9e6ce7f..5579377ca8 100644 --- a/tests/pyop2/test_caching.py +++ b/tests/pyop2/test_caching.py @@ -309,6 +309,13 @@ def cache(self): int_comm.Set_attr(comm_cache_keyval, _cache_collection) return _cache_collection[default_cache_name] + def code_cache_len_equals(self, expected): + # We need to do this check because different things also get + # put into self.cache + return sum( + 1 for key in self.cache if key[1] == "compile_global_kernel" + ) == expected + @pytest.fixture def a(cls, diterset): return op2.Dat(diterset, list(range(nelems)), numpy.uint32, "a") @@ -328,14 +335,14 @@ def test_same_args(self, iterset, iter2ind1, x, a): a(op2.WRITE), x(op2.READ, iter2ind1)) - assert len(self.cache) == 1 + assert self.code_cache_len_equals(1) op2.par_loop(op2.Kernel(kernel_cpy, "cpy"), iterset, a(op2.WRITE), x(op2.READ, iter2ind1)) - assert len(self.cache) == 1 + assert self.code_cache_len_equals(1) def test_diff_kernel(self, iterset, iter2ind1, x, a): self.cache.clear() @@ -348,7 +355,7 @@ def test_diff_kernel(self, iterset, iter2ind1, x, a): a(op2.WRITE), x(op2.READ, iter2ind1)) - assert len(self.cache) == 1 + assert self.code_cache_len_equals(1) kernel_cpy = "static void cpy(unsigned int* DST, unsigned int* SRC) { *DST = *SRC; }" @@ -357,7 +364,7 @@ def test_diff_kernel(self, iterset, iter2ind1, x, a): a(op2.WRITE), x(op2.READ, iter2ind1)) - assert len(self.cache) == 2 + assert self.code_cache_len_equals(2) def test_invert_arg_similar_shape(self, iterset, iter2ind1, x, y): self.cache.clear() @@ -377,14 +384,14 @@ def test_invert_arg_similar_shape(self, iterset, iter2ind1, x, y): x(op2.RW, iter2ind1), y(op2.RW, iter2ind1)) - assert len(self.cache) == 1 + assert self.code_cache_len_equals(1) op2.par_loop(op2.Kernel(kernel_swap, "swap"), iterset, y(op2.RW, iter2ind1), x(op2.RW, iter2ind1)) - assert len(self.cache) == 1 + assert self.code_cache_len_equals(1) def test_dloop_ignore_scalar(self, iterset, a, b): self.cache.clear() @@ -404,14 +411,14 @@ def test_dloop_ignore_scalar(self, iterset, a, b): a(op2.RW), b(op2.RW)) - assert len(self.cache) == 1 + assert self.code_cache_len_equals(1) op2.par_loop(op2.Kernel(kernel_swap, "swap"), iterset, b(op2.RW), a(op2.RW)) - assert len(self.cache) == 1 + assert self.code_cache_len_equals(1) def test_vector_map(self, iterset, x2, iter2ind2): self.cache.clear() @@ -431,13 +438,13 @@ def test_vector_map(self, iterset, x2, iter2ind2): iterset, x2(op2.RW, iter2ind2)) - assert len(self.cache) == 1 + assert self.code_cache_len_equals(1) op2.par_loop(op2.Kernel(kernel_swap, "swap"), iterset, x2(op2.RW, iter2ind2)) - assert len(self.cache) == 1 + assert self.code_cache_len_equals(1) def test_same_iteration_space_works(self, iterset, x2, iter2ind2): self.cache.clear() @@ -447,12 +454,12 @@ def test_same_iteration_space_works(self, iterset, x2, iter2ind2): op2.par_loop(k, iterset, x2(op2.INC, iter2ind2)) - assert len(self.cache) == 1 + assert self.code_cache_len_equals(1) op2.par_loop(k, iterset, x2(op2.INC, iter2ind2)) - assert len(self.cache) == 1 + assert self.code_cache_len_equals(1) def test_change_dat_dtype_matters(self, iterset, diterset): d = op2.Dat(diterset, list(range(nelems)), numpy.uint32) @@ -463,12 +470,12 @@ def test_change_dat_dtype_matters(self, iterset, diterset): op2.par_loop(k, iterset, d(op2.WRITE)) - assert len(self.cache) == 1 + assert self.code_cache_len_equals(1) d = op2.Dat(diterset, list(range(nelems)), numpy.int32) op2.par_loop(k, iterset, d(op2.WRITE)) - assert len(self.cache) == 2 + assert self.code_cache_len_equals(2) def test_change_global_dtype_matters(self, iterset, diterset): g = op2.Global(1, 0, dtype=numpy.uint32, comm=COMM_WORLD) @@ -479,12 +486,12 @@ def test_change_global_dtype_matters(self, iterset, diterset): op2.par_loop(k, iterset, g(op2.INC)) - assert len(self.cache) == 1 + assert self.code_cache_len_equals(1) g = op2.Global(1, 0, dtype=numpy.float64, comm=COMM_WORLD) op2.par_loop(k, iterset, g(op2.INC)) - assert len(self.cache) == 2 + assert self.code_cache_len_equals(2) class TestSparsityCache: From a778d6800a7cd9d83358675e945b53febf8f9bc6 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 22 Jan 2025 09:58:59 +0000 Subject: [PATCH 4/8] DO NOT MERGE --- .github/workflows/build.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0eb616c24d..7f108a77ec 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -84,6 +84,8 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ + --package-branch fiat pbrubeck/simplify-indexed \ + --package-branch ufl pbrubeck/remove-component-tensors \ || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | From 1749644549c302f14deb7fb48a91a6b25009a049 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sun, 26 Jan 2025 14:52:51 +0000 Subject: [PATCH 5/8] DO NOT MERGE --- .github/workflows/build.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0f17161362..77739f82f9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -87,6 +87,8 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ + --package-branch fiat pbrubeck/simplify-indexed \ + --package-branch ufl pbrubeck/remove-component-tensors \ || (cat firedrake-install.log && /bin/false) - name: Install test dependencies From 729339dfedde3a8c6a0b32880562583b3f4a04e7 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sun, 26 Jan 2025 15:02:52 +0000 Subject: [PATCH 6/8] Lazy basis transformation --- tsfc/fem.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tsfc/fem.py b/tsfc/fem.py index 10c10dc97d..e20244a0a1 100644 --- a/tsfc/fem.py +++ b/tsfc/fem.py @@ -641,15 +641,13 @@ def fiat_to_ufl(fiat_dict, order): @translate.register(Argument) def translate_argument(terminal, mt, ctx): - argument_multiindex = ctx.argument_multiindices[terminal.number()] - sigma = tuple(gem.Index(extent=d) for d in mt.expr.ufl_shape) element = ctx.create_element(terminal.ufl_element(), restriction=mt.restriction) def callback(entity_id): finat_dict = ctx.basis_evaluation(element, mt, entity_id) # Filter out irrelevant derivatives - filtered_dict = {alpha: table - for alpha, table in finat_dict.items() + filtered_dict = {alpha: finat_dict[alpha] + for alpha in finat_dict if sum(alpha) == mt.local_derivatives} # Change from FIAT to UFL arrangement @@ -658,13 +656,16 @@ def callback(entity_id): # A numerical hack that FFC used to apply on FIAT tables still # lives on after ditching FFC and switching to FInAT. return ffc_rounding(square, ctx.epsilon) + table = ctx.entity_selector(callback, mt.restriction) if ctx.use_canonical_quadrature_point_ordering: quad_multiindex = ctx.quadrature_rule.point_set.indices quad_multiindex_permuted = _make_quad_multiindex_permuted(mt, ctx) mapper = gem.node.MemoizerArg(gem.optimise.filtered_replace_indices) table = mapper(table, tuple(zip(quad_multiindex, quad_multiindex_permuted))) - return gem.ComponentTensor(gem.Indexed(table, argument_multiindex + sigma), sigma) + + argument_multiindex = ctx.argument_multiindices[terminal.number()] + return gem.partial_indexed(table, argument_multiindex) @translate.register(TSFCConstantMixin) From 03b83b37dca9f84fd61fef89a14e00d49fcbb4aa Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 18 Mar 2025 12:06:47 +0000 Subject: [PATCH 7/8] Force reinstall --- .github/workflows/build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 92e385f27b..8dfb774588 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -85,8 +85,8 @@ jobs: --no-binary h5py \ --extra-index-url https://download.pytorch.org/whl/cpu \ './firedrake-repo[ci]' - pip install "fenics-ufl @ git+https://github.com/firedrakeproject/ufl.git@pbrubeck/remove-component-tensors" - pip install "fenics-fiat @ git+https://github.com/firedrakeproject/fiat.git@pbrubeck/simplify-indexed" + pip install -I "fenics-ufl @ git+https://github.com/firedrakeproject/ufl.git@pbrubeck/remove-component-tensors" + pip install -I "fenics-fiat @ git+https://github.com/firedrakeproject/fiat.git@pbrubeck/simplify-indexed" firedrake-clean pip list From 65a2e7049f13bf74001721c3a7875a745bc206e4 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 18 Mar 2025 16:02:34 +0000 Subject: [PATCH 8/8] Debug --- tests/firedrake/regression/test_vfs_component_bcs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/firedrake/regression/test_vfs_component_bcs.py b/tests/firedrake/regression/test_vfs_component_bcs.py index 1b636fefe4..9a5871041c 100644 --- a/tests/firedrake/regression/test_vfs_component_bcs.py +++ b/tests/firedrake/regression/test_vfs_component_bcs.py @@ -218,8 +218,8 @@ def test_component_full_bcs(V): A_cmp = assemble(a, bcs=bcs_cmp, mat_type="aij") A_mixed = assemble(a, bcs=bcs_mixed, mat_type="aij") - assert A_full.petscmat.equal(A_cmp.petscmat) - assert A_mixed.petscmat.equal(A_full.petscmat) + assert A_full.petscmat.equal(A_cmp.petscmat), str((A_full.petscmat[:, :]-A_cmp.petscmat[:, :]).tolist()) + assert A_mixed.petscmat.equal(A_full.petscmat), str((A_full.petscmat[:, :]-A_mixed.petscmat[:, :]).tolist()) def test_component_full_bcs_overlap(V):