diff --git a/gusto/core/io.py b/gusto/core/io.py index 2b1f1ae83..1eea7cbee 100644 --- a/gusto/core/io.py +++ b/gusto/core/io.py @@ -473,10 +473,38 @@ def setup_dump(self, state_fields, t, pick_up=False): self.to_dump_latlon = [] for name in self.output.dumplist_latlon: f = state_fields(name) - field = Function( - functionspaceimpl.WithGeometry.create( - f.function_space(), mesh_ll), - val=f.topological, name=name+'_ll') + V = f.function_space() + try: # firedrake main + from firedrake import MeshSequenceGeometry + + if V.parent and isinstance(V.parent.topological, functionspaceimpl.MixedFunctionSpace): + if not isinstance(V.parent.mesh(), MeshSequenceGeometry): + raise ValueError("Expecting a MeshSequenceGeometry") + if len(set(V.parent.mesh().meshes)) > 1: + raise ValueError("Expecting a single mesh") + parent = functionspaceimpl.WithGeometry.create( + V.parent.topological, + MeshSequenceGeometry( + tuple(mesh_ll for _ in V.parent.mesh().meshes), + ), + ) + else: + parent = None + field = Function( + functionspaceimpl.WithGeometry.create( + V.topological, mesh_ll, parent=parent, + ), + val=f.topological, + name=name+'_ll', + ) + except ImportError: # firedrake release + field = Function( + functionspaceimpl.WithGeometry.create( + V.topological, mesh_ll, + ), + val=f.topological, + name=name+'_ll', + ) self.to_dump_latlon.append(field) # we create new netcdf files to write to, unless pick_up=True and they diff --git a/gusto/solvers/preconditioners.py b/gusto/solvers/preconditioners.py index 722104830..39397774e 100644 --- a/gusto/solvers/preconditioners.py +++ b/gusto/solvers/preconditioners.py @@ -66,6 +66,12 @@ def initialize(self, pc): V = test.function_space() mesh = V.mesh() + try: + from firedrake import MeshSequenceGeometry # noqa: F401 + + unique_mesh = mesh.unique() + except ImportError: + unique_mesh = mesh # Magically determine which spaces are vector and scalar valued for i, Vi in enumerate(V): @@ -96,7 +102,7 @@ def initialize(self, pc): DG = FiniteElement("DG", cell, deg) CG = FiniteElement("CG", interval, 1) Vv_tr_element = TensorProductElement(DG, CG) - Vv_tr = FunctionSpace(mesh, Vv_tr_element) + Vv_tr = FunctionSpace(unique_mesh, Vv_tr_element) # Break the spaces broken_elements = MixedElement([BrokenElement(Vi.ufl_element()) for Vi in V]) @@ -121,7 +127,7 @@ def initialize(self, pc): trial: TrialFunction(V_d)} Atilde = Tensor(replace(self.ctx.a, arg_map)) gammar = TestFunction(Vv_tr) - n = FacetNormal(mesh) + n = FacetNormal(unique_mesh) sigma = TrialFunctions(V_d)[self.vidx] # Again, assumes tensor product structure. Why use this if you @@ -157,7 +163,7 @@ def initialize(self, pc): trace_subdomains.extend(sorted({"top", "bottom"} - extruded_neumann_subdomains)) measures.extend((ds(sd) for sd in sorted(neumann_subdomains))) - markers = [int(x) for x in mesh.exterior_facets.unique_markers] + markers = [int(x) for x in unique_mesh.exterior_facets.unique_markers] dirichlet_subdomains = set(markers) - neumann_subdomains trace_subdomains.extend(sorted(dirichlet_subdomains))