diff --git a/python/dolfinx/fem/forms.py b/python/dolfinx/fem/forms.py index 3c994db79c..1de7765544 100644 --- a/python/dolfinx/fem/forms.py +++ b/python/dolfinx/fem/forms.py @@ -247,6 +247,7 @@ def mixed_topology_form( dtype: npt.DTypeLike = default_scalar_type, form_compiler_options: dict | None = None, jit_options: dict | None = None, + jit_comm: MPI.IntraComm | None = None, entity_maps: Sequence[_EntityMap] | None = None, ): """Create a mixed-topology from from an array of Forms. @@ -262,6 +263,8 @@ def mixed_topology_form( dtype: Scalar type to use for the compiled form. form_compiler_options: See :func:`ffcx_jit ` jit_options: See :func:`ffcx_jit `. + jit_comm: MPI communicator used when compiling the form. If + ``None``, then ``form.mesh.comm``. entity_maps: If any trial functions, test functions, or coefficients in the form are not defined over the same mesh as the integration domain (the domain associated with the @@ -287,13 +290,16 @@ def mixed_topology_form( assert all([d is data[0] for d in data if d is not None]) mesh = domain.ufl_cargo() + if mesh is None: + raise RuntimeError("Expecting to find a Mesh in the form.") + comm = mesh.comm if jit_comm is None else jit_comm ufcx_forms = [] modules = [] codes = [] for form in forms: ufcx_form, module, code = jit.ffcx_jit( - mesh.comm, + comm, form, form_compiler_options=form_compiler_options, jit_options=jit_options, @@ -324,6 +330,7 @@ def form( dtype: npt.DTypeLike = default_scalar_type, form_compiler_options: dict | None = None, jit_options: dict | None = None, + jit_comm: MPI.IntraComm | None = None, entity_maps: Sequence[_EntityMap] | None = None, ): """Create a Form or list of Forms. @@ -333,6 +340,8 @@ def form( dtype: Scalar type to use for the compiled form. form_compiler_options: See :func:`ffcx_jit ` jit_options: See :func:`ffcx_jit `. + jit_comm: MPI communicator used when compiling the form. If + `None`, then `form.mesh.comm`. entity_maps: If any trial functions, test functions, or coefficients in the form are not defined over the same mesh as the integration domain (the domain associated with the @@ -369,8 +378,10 @@ def _form(form): msh = domain.ufl_cargo() if msh is None: raise RuntimeError("Expecting to find a Mesh in the form.") + comm = msh.comm if jit_comm is None else jit_comm + ufcx_form, module, code = jit.ffcx_jit( - msh.comm, form, form_compiler_options=form_compiler_options, jit_options=jit_options + comm, form, form_compiler_options=form_compiler_options, jit_options=jit_options ) # For each argument in form extract its function space