diff --git a/gusto/core/io.py b/gusto/core/io.py index ebba47815..24c74c69a 100644 --- a/gusto/core/io.py +++ b/gusto/core/io.py @@ -57,7 +57,7 @@ def pick_up_mesh(output, mesh_name, comm=COMM_WORLD): else: dumpdir = path.join("results", output.dirname) chkfile = path.join(dumpdir, "chkpt.h5") - with CheckpointFile(chkfile, 'r', comm=comm) as chk: + with CheckpointFile(chkfile, 'r', comm) as chk: mesh = chk.load_mesh(mesh_name) if dumpdir: @@ -665,7 +665,7 @@ def pick_up_from_checkpoint(self, state_fields, comm=COMM_WORLD): step = chk.read_attribute("/", "step") else: - with CheckpointFile(chkfile, 'r', comm) as chk: + with CheckpointFile(chkfile, 'r', self.domain.mesh.comm) as chk: mesh = self.domain.mesh # Recover compulsory fields from the checkpoint for field_name in self.to_pick_up: @@ -773,7 +773,7 @@ def dump(self, state_fields, time_data): chkpt_mode = 'a' else: chkpt_mode = 'w' - with CheckpointFile(self.chkpt_path, chkpt_mode) as chk: + with CheckpointFile(self.chkpt_path, chkpt_mode, self.mesh.comm) as chk: chk.save_mesh(self.domain.mesh) for field_name in self.to_pick_up: if output.multichkpt: @@ -826,8 +826,11 @@ def create_nc_dump(self, filename, space_names): # we instead save string metadata as char arrays of fixed length. if isinstance(output_value, str): nc_field_file.createVariable(metadata_key, 'S1', ('dim_one', 'dim_string')) - output_char_array = np.array([output_value], dtype='S256') - nc_field_file[metadata_key][:] = stringtochar(output_char_array) + nc_field_file[metadata_key]._Encoding = 'UTF-8' # tell netCDF4 this char var is UTF-8 text + N = nc_field_file.dimensions['dim_string'].size + max_chars = max(1, N // 4) # max number of characters that can be stored + output_char_array = np.array([output_value], dtype=f"U{max_chars}") + nc_field_file[metadata_key][:] = stringtochar(output_char_array, encoding='utf-8') else: nc_field_file.createVariable(metadata_key, type(output_value), ('dim_one',)) nc_field_file[metadata_key][0] = output_value @@ -981,7 +984,7 @@ def make_nc_dataset(filename, access, comm): """ try: - nc_field_file = Dataset(filename, access, parallel=True) + nc_field_file = Dataset(filename, access, parallel=True, comm=comm) nc_supports_parallel = True except ValueError: # parallel netCDF not available, use the serial version instead diff --git a/gusto/equations/common_forms.py b/gusto/equations/common_forms.py index b8caf2e5c..360866d8f 100644 --- a/gusto/equations/common_forms.py +++ b/gusto/equations/common_forms.py @@ -367,7 +367,7 @@ def split_continuity_form(equation): u_trial = TrialFunctions(W)[u_idx] qbar = split(equation.X_ref)[idx] # Add linearisation to adv_term - linear_adv_term = linear_advection_form(test, qbar, u_trial) + linear_adv_term = linear_advection_form(test, qbar, u_trial, qbar, uadv) adv_term = linearisation(adv_term, linear_adv_term) # Add linearisation to div_term linear_div_term = transporting_velocity(qbar*test*div(u_trial)*dx, u_trial) @@ -446,7 +446,7 @@ def split_linear_advection_form(test, qbar, ubar, ubar_full): :class:`LabelledForm`: a labelled transport form. """ - L = test*dot(ubar, grad(qbar))*dx + L = inner(test, dot(ubar, grad(qbar)))*dx form = transporting_velocity(L, ubar_full) return transport(form, TransportEquationType.advective) diff --git a/gusto/time_discretisation/__init__.py b/gusto/time_discretisation/__init__.py index 20b428449..e26959ad2 100644 --- a/gusto/time_discretisation/__init__.py +++ b/gusto/time_discretisation/__init__.py @@ -4,4 +4,5 @@ from gusto.time_discretisation.imex_runge_kutta import * # noqa from gusto.time_discretisation.multi_level_schemes import * # noqa from gusto.time_discretisation.wrappers import * # noqa -from gusto.time_discretisation.sdc import * # noqa \ No newline at end of file +from gusto.time_discretisation.deferred_correction import * # noqa +from gusto.time_discretisation.parallel_dc import * # noqa \ No newline at end of file diff --git a/gusto/time_discretisation/sdc.py b/gusto/time_discretisation/deferred_correction.py similarity index 51% rename from gusto/time_discretisation/sdc.py rename to gusto/time_discretisation/deferred_correction.py index 7c419cb21..cd692e2bc 100644 --- a/gusto/time_discretisation/sdc.py +++ b/gusto/time_discretisation/deferred_correction.py @@ -1,46 +1,76 @@ -u""" -Objects for discretising time derivatives using Spectral Deferred Correction -Methods. +""" +Objects for discretising time derivatives using Deferred Correction (DC) +Methods. This includes Spectral Deferred Correction (SDC) and Serial Revisionist +Integral Deferred Correction (RIDC) methods. -SDC objects discretise ∂y/∂t = F(y), for variable y, time t and -operator F. +These methods discretise ∂y/∂t = F(y), for variable y, time t, and operator F. -Written in Picard integral form this equation is -y(t) = y_n + int[t_n,t] F(y(s)) ds +In Picard integral form, this equation is: +y(t) = y_n + ∫[t_n, t] F(y(s)) ds -Using some quadrature rule, we can evaluate y on a temporal quadrature node as -y_m = y_n + sum[j=1,M] q_mj*F(y_j) -where q_mj can be found by integrating Lagrange polynomials. This is similar to -how Runge-Kutta methods are formed. +================================================================================ +Spectral Deferred Correction (SDC) Formulation: +================================================================================ -In matrix form this equation is: -(I - dt*Q*F)(y)=y_n +SDC methods integrate the function F(y) over the interval [t_n, t_n+1] using +quadrature. Evaluating y on temporal quadrature nodes gives: +y_m = y_n + Σ[j=1,M] q_mj * F(y_j) +where q_mj are derived from integrating Lagrange polynomials, similar to how +Runge-Kutta methods are constructed. -Computing y by Picard iteration through k we get: -y^(k+1)=y^k + (y_n - (I - dt*Q*F)(y^k)) +In matrix form: +(I - dt * Q * F)(y) = y_n -Finally, to get our SDC method we precondition this system, using some approximation -of Q Q_delta: -(I - dt*Q_delta*F)(y^(k+1)) = y_n + dt*(Q - Q_delta)F(y^k) +Using Picard iteration: +y^(k+1) = y^k + (y_n - (I - dt * Q * F)(y^k)) -The zero-to-node (Z2N) formulation is then: -y_m^(k+1) = y_n + sum(j=1,M) q'_mj*(F(y_j^(k+1)) - F(y_j^k)) - + sum(j=1,M) q_mj*F(y_(m-1)^k) -for entires q_mj in Q and q'_mj in Q_delta. +Preconditioning this system with an approximation Q_delta gives: +(I - dt * Q_delta * F)(y^(k+1)) = y_n + dt * (Q - Q_delta) * F(y^k) -Node-wise from previous quadrature node (N2N formulation), the implicit SDC calculation is: -y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k)) - + sum(j=1,M) s_mj*F(y_(m-1)^k) -where s_mj = q_mj - q_(m-1)j for entires q_ik in Q. +Two formulations are commonly used: +1. Zero-to-node (Z2N): + y_m^(k+1) = y_n + Σ[j=1,M] q'_mj * (F(y_j^(k+1)) - F(y_j^k)) + + Σ[j=1,M] q_mj * F(y_(j)^k) + where q_mj are entries in Q and q'_mj are entries in Q_delta. +2. Node-to-node (N2N): + y_m^(k+1) = y_(m-1)^(k+1) + dtau_m * (F(y_(m)^(k+1)) - F(y_(m)^k)) + + Σ[j=1,M] s_mj * F(y_(j)^k) + where s_mj = q_mj - q_(m-1)j for entries q_ik in Q. -Key choices in our SDC method are: -- Choice of quadrature node type (e.g. gauss-lobatto) +Key choices in SDC: +- Quadrature node type (e.g., Gauss-Lobatto) - Number of quadrature nodes -- Number of iterations - each iteration increases the order of accuracy up to - the order of the underlying quadrature -- Choice of Q_delta (e.g. Forward Euler, Backward Euler, LU-trick) -- How to get initial solution on quadrature nodes +- Number of iterations (each iteration increases accuracy up to the quadrature order) +- Choice of Q_delta (e.g., Forward Euler, Backward Euler, LU-trick) +- Initial solution on quadrature nodes + +================================================================================ +Revisionist Integral Deferred Correction (RIDC) Formulation: +================================================================================ + +RIDC methods are similar to SDC but use equidistant nodes and a different +formulation for the error equation. The process involves: +1. Using a low-order method (predictor) to compute an initial solution: + y_m^(0) = y_(m-1)^(0) + dt * F(y_(m)^(0)) + +2. Performing K correction steps: + y_m^(k+1) = y_(m-1)^(k+1) + dt * (F(y_(m)^(k+1)) - F(y_(m)^k)) + + Σ[j=1,M] s_mj * F(y_(j)^k) +We solve on N equispaced nodes on the interval [0, T] divided into J intervals, +each further divided into M subintervals: + + 0 * * * * * | * * * * * | * * * * * | * * * * * | * * * * * T + | J intervals, each with M subintervals | + +Here, M >> K, and M must be at least K * (K+1) / 2 for the reduced stencil RIDC method. +dt = T / N, N = J * M. +Each correction sweep increases accuracy up to the quadrature order. + +Key choices in RIDC: +- Number of subintervals J +- Number of quadrature nodes M + 1 +- Number of correction iterations K """ from abc import ABCMeta @@ -54,10 +84,9 @@ from firedrake.utils import cached_property from gusto.time_discretisation.time_discretisation import wrapper_apply from gusto.core.labels import (time_derivative, implicit, explicit, source_label) - from qmat import genQCoeffs, genQDeltaCoeffs -__all__ = ["SDC"] +__all__ = ["SDC", "RIDC"] class SDC(object, metaclass=ABCMeta): @@ -66,7 +95,7 @@ class SDC(object, metaclass=ABCMeta): def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_imp, qdelta_exp, formulation="N2N", field_name=None, linear_solver_parameters=None, nonlinear_solver_parameters=None, final_update=True, - limiter=None, options=None, initial_guess="base"): + limiter=None, initial_guess="base"): """ Initialise SDC object Args: @@ -96,10 +125,6 @@ def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_im quadrature value. Defaults to True limiter (:class:`Limiter` object, optional): a limiter to apply to the evolving field to enforce monotonicity. Defaults to None. - options (:class:`AdvectionOptions`, optional): an object containing - options to either be passed to the spatial discretisation, or - to control the "wrapper" methods, such as Embedded DG or a - recovery method. Defaults to None. initial_guess (str, optional): Initial guess to be base timestepper, or copy """ # Check the configuration options @@ -108,6 +133,7 @@ def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_im # Initialise parameters self.base = base_scheme + self.base.dt = domain.dt self.field_name = field_name self.domain = domain self.dt_coarse = domain.dt @@ -125,11 +151,6 @@ def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_im quadType=quad_type, form=formulation) - # Rescale to be over [0,dt] rather than [0,1] - self.nodes = float(self.dt_coarse)*self.nodes - self.dtau = np.diff(np.append(0, self.nodes)) - self.Q = float(self.dt_coarse)*self.Q - self.Qfin = float(self.dt_coarse)*self.weights self.qdelta_imp_type = qdelta_imp self.formulation = formulation self.node_type = node_type @@ -137,10 +158,18 @@ def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_im # Get Q_delta matrices self.Qdelta_imp = genQDeltaCoeffs(qdelta_imp, form=formulation, - nodes=self.nodes, Q=self.Q, nNodes=M, nodeType=node_type, quadType=quad_type) + nodes=self.nodes, Q=self.Q, nNodes=M, nodeType=node_type, quadType=quad_type, k=1) self.Qdelta_exp = genQDeltaCoeffs(qdelta_exp, form=formulation, nodes=self.nodes, Q=self.Q, nNodes=M, nodeType=node_type, quadType=quad_type) + # Rescale to be over [0,dt] rather than [0,1] + self.nodes = float(self.dt_coarse)*self.nodes + self.dtau = np.diff(np.append(0, self.nodes)) + self.Q = float(self.dt_coarse)*self.Q + self.Qfin = float(self.dt_coarse)*self.weights + self.Qdelta_imp = float(self.dt_coarse)*self.Qdelta_imp + self.Qdelta_exp = float(self.dt_coarse)*self.Qdelta_exp + # Set default linear and nonlinear solver options if none passed in if linear_solver_parameters is None: self.linear_solver_parameters = {'snes_type': 'ksponly', @@ -206,7 +235,7 @@ def setup(self, equation, apply_bcs=True, *active_labels): self.quad = [Function(W) for _ in range(self.M)] self.source_Uk = [Function(W) for _ in range(self.M+1)] self.source_Ukp1 = [Function(W) for _ in range(self.M+1)] - self.U_SDC = Function(W) + self.U_DC = Function(W) self.U_start = Function(W) self.Un = Function(W) self.Q_ = Function(W) @@ -283,7 +312,7 @@ def res(self, m): lambda t: t.has_label(time_derivative), map_if_false=drop) residual = mass_form.label_map(all_terms, - map_if_true=replace_subject(self.U_SDC, old_idx=self.idx)) + map_if_true=replace_subject(self.U_DC, old_idx=self.idx)) residual -= mass_form.label_map(all_terms, map_if_true=replace_subject(self.U_start, old_idx=self.idx)) # Loop through nodes up to m-1 and calcualte @@ -349,7 +378,7 @@ def res(self, m): # Qdelta_imp[m,m]*(F(y_(m)^(k+1)) - F(y_(m)^k)) r_imp_kp1 = self.residual.label_map( lambda t: t.has_label(implicit), - map_if_true=replace_subject(self.U_SDC, old_idx=self.idx), + map_if_true=replace_subject(self.U_DC, old_idx=self.idx), map_if_false=drop) r_imp_kp1 = r_imp_kp1.label_map( all_terms, @@ -379,7 +408,7 @@ def solvers(self): solvers = [] for m in range(self.M): # setup solver using residual defined in derived class - problem = NonlinearVariationalProblem(self.res(m), self.U_SDC, bcs=self.bcs) + problem = NonlinearVariationalProblem(self.res(m), self.U_DC, bcs=self.bcs) solver_name = self.field_name+self.__class__.__name__ + "%s" % (m) solvers.append(NonlinearVariationalSolver(problem, solver_parameters=self.nonlinear_solver_parameters, options_prefix=solver_name)) return solvers @@ -429,7 +458,7 @@ def apply(self, x_out, x_in): if self.qdelta_imp_type == "MIN-SR-FLEX": # Recompute Implicit Q_delta matrix for each iteration k - self.Qdelta_imp = genQDeltaCoeffs( + self.Qdelta_imp = float(self.dt_coarse)*genQDeltaCoeffs( self.qdelta_imp_type, form=self.formulation, nodes=self.nodes, @@ -463,7 +492,7 @@ def apply(self, x_out, x_in): if (self.formulation == "N2N"): self.U_start.assign(self.Unodes1[m-1]) self.solver = solver_list[m-1] - self.U_SDC.assign(self.Unodes[m]) + self.U_DC.assign(self.Unodes[m]) # Compute # for N2N: @@ -474,7 +503,7 @@ def apply(self, x_out, x_in): # y_m^(k+1) = y^n + sum(j=1,m) Qdelta_imp[m,j]*(F(y_(m)^(k+1)) - F(y_(m)^k)) # + sum(j=1,M) Q_delta_exp[m,j]*(S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) self.solver.solve() - self.Unodes1[m].assign(self.U_SDC) + self.Unodes1[m].assign(self.U_DC) # Evaluate source terms for evaluate in self.evaluate_source: @@ -509,3 +538,375 @@ def apply(self, x_out, x_in): x_out.assign(self.Unodes[-1]) else: x_out.assign(self.Unodes[-1]) + + +class RIDC(object, metaclass=ABCMeta): + """Class for Revisionist Integral Deferred Correction schemes.""" + + def __init__(self, base_scheme, domain, M, K, field_name=None, + linear_solver_parameters=None, nonlinear_solver_parameters=None, + limiter=None, reduced=True): + """ + Initialise RIDC object + Args: + base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on + quadrature nodes. + domain (:class:`Domain`): the model's domain object, containing the + mesh and the compatible function spaces. + M (int): Number of subintervals + K (int): Max number of correction interations + field_name (str, optional): name of the field to be evolved. + Defaults to None. + linear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying linear solver. Defaults to None. + nonlinear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying nonlinear solver. Defaults to None. + limiter (:class:`Limiter` object, optional): a limiter to apply to + the evolving field to enforce monotonicity. Defaults to None. + reduced (bool, optional): whether to use reduced stencils for RIDC. Defaults to True. + """ + self.base = base_scheme + self.field_name = field_name + self.domain = domain + self.dt_coarse = domain.dt + self.limiter = limiter + self.augmentation = self.base.augmentation + self.wrapper = self.base.wrapper + self.K = K + self.M = M + self.reduced = reduced + self.dt = Constant(float(self.dt_coarse)/(self.M)) + + if reduced: + self.Q = [] + for l in range(1, self.K + 1): + _, _, Q = genQCoeffs( + "Collocation", + nNodes=l + 1, + nodeType="EQUID", + quadType="LOBATTO", + form="N2N" + ) + Q = l * float(self.dt) * Q + self.Q.append(Q) + else: + # Get integration weights + _, _, self.Q = genQCoeffs( + "Collocation", + nNodes=self.K + 1, + nodeType="EQUID", + quadType="LOBATTO", + form="N2N" + ) + self.Q = self.K * float(self.dt) * self.Q + + # Set default linear and nonlinear solver options if none passed in + if linear_solver_parameters is None: + self.linear_solver_parameters = {'snes_type': 'ksponly', + 'ksp_type': 'cg', + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu'} + else: + self.linear_solver_parameters = linear_solver_parameters + + if nonlinear_solver_parameters is None: + self.nonlinear_solver_parameters = {'snes_type': 'newtonls', + 'ksp_type': 'gmres', + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu'} + else: + self.nonlinear_solver_parameters = nonlinear_solver_parameters + + def setup(self, equation, apply_bcs=True, *active_labels): + """ + Set up the RIDC time discretisation based on the equation. + + Args: + equation (:class:`PrognosticEquation`): the model's equation. + apply_bcs (bool, optional): whether to apply the equation's boundary + conditions. Defaults to True. + *active_labels (:class:`Label`): labels indicating which terms of + the equation to include. + """ + # Inherit from base time discretisation + self.base.setup(equation, apply_bcs, *active_labels) + self.equation = self.base.equation + self.residual = self.base.residual + self.evaluate_source = self.base.evaluate_source + + for t in self.residual: + # Check all terms are labeled implicit or explicit + if ((not t.has_label(implicit)) and (not t.has_label(explicit)) + and (not t.has_label(time_derivative)) and (not t.has_label(source_label))): + raise NotImplementedError("Non time-derivative or source terms must be labeled as implicit or explicit") + + # Set up bcs + self.bcs = self.base.bcs + + # Set up RIDC variables + if self.field_name is not None and hasattr(equation, "field_names"): + self.idx = equation.field_names.index(self.field_name) + W = equation.spaces[self.idx] + else: + self.field_name = equation.field_name + W = equation.function_space + self.idx = None + self.W = W + self.Unodes = [Function(W) for _ in range(self.M+1)] + self.Unodes1 = [Function(W) for _ in range(self.M+1)] + self.fUnodes = [Function(W) for _ in range(self.M+1)] + self.quad = [Function(W) for _ in range(self.M+1)] + self.source_Uk = [Function(W) for _ in range(self.M+1)] + self.source_Ukp1 = [Function(W) for _ in range(self.M+1)] + self.U_DC = Function(W) + self.U_start = Function(W) + self.Un = Function(W) + self.Q_ = Function(W) + self.quad_final = Function(W) + self.U_fin = Function(W) + self.Urhs = Function(W) + self.Uin = Function(W) + self.source_in = Function(W) + self.source_Ukp1_m = Function(W) + self.source_Uk_m = Function(W) + self.Uk_mp1 = Function(W) + self.Uk_m = Function(W) + self.Ukp1_m = Function(W) + + @property + def nlevels(self): + return 1 + + def compute_quad(self, Q, fUnodes, m): + """ + Computes integration of F(y) on quadrature nodes + """ + quad = Function(self.W) + quad.assign(0.) + for k in range(0, np.shape(Q)[1]): + quad += float(Q[m, k])*fUnodes[k] + return quad + + def compute_quad_final(self, Q, fUnodes, m): + """ + Computes final integration of F(y) on quadrature nodes + """ + quad = Function(self.W) + quad.assign(0.) + if self.reduced: + l = np.shape(Q)[0] - 1 + else: + l = self.K + for k in range(0, l+1): + quad += float(Q[-1, k])*fUnodes[m - l + k] + return quad + + @property + def res_rhs(self): + """Set up the residual for the calculation of F(y).""" + a = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Urhs, old_idx=self.idx), + drop) + # F(y) + L = self.residual.label_map(lambda t: any(t.has_label(time_derivative, source_label)), + drop, + replace_subject(self.Uin, old_idx=self.idx)) + L_source = self.residual.label_map(lambda t: t.has_label(source_label), + replace_subject(self.source_in, old_idx=self.idx), + drop) + residual_rhs = a - (L + L_source) + return residual_rhs.form + + @property + def res(self): + """Set up the discretisation's residual.""" + # Add time derivative terms y^(k+1)_m - y_n + mass_form = self.residual.label_map( + lambda t: t.has_label(time_derivative), + map_if_false=drop) + residual = mass_form.label_map(all_terms, + map_if_true=replace_subject(self.U_DC, old_idx=self.idx)) + residual -= mass_form.label_map(all_terms, + map_if_true=replace_subject(self.U_start, old_idx=self.idx)) + + # Calculate source terms + r_source_kp1 = self.residual.label_map( + lambda t: t.has_label(source_label), + map_if_true=replace_subject(self.source_Ukp1_m, old_idx=self.idx), + map_if_false=drop) + r_source_kp1 = r_source_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_source_kp1 + + r_source_k = self.residual.label_map( + lambda t: t.has_label(source_label), + map_if_true=replace_subject(self.source_Uk_m, old_idx=self.idx), + map_if_false=drop) + r_source_k = r_source_k.label_map( + all_terms, + map_if_true=lambda t: Constant(self.dt)*t) + residual -= r_source_k + + # Add on final implicit terms + # dt*(F(y_(m)^(k+1)) - F(y_(m)^k)) + r_imp_kp1 = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.U_DC, old_idx=self.idx), + map_if_false=drop) + r_imp_kp1 = r_imp_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_imp_kp1 + r_imp_k = self.residual.label_map( + lambda t: t.has_label(implicit), + map_if_true=replace_subject(self.Uk_mp1, old_idx=self.idx), + map_if_false=drop) + r_imp_k = r_imp_k.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual -= r_imp_k + + r_exp_kp1 = self.residual.label_map( + lambda t: t.has_label(explicit), + map_if_true=replace_subject(self.Ukp1_m, old_idx=self.idx), + map_if_false=drop) + r_exp_kp1 = r_exp_kp1.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual += r_exp_kp1 + r_exp_k = self.residual.label_map( + lambda t: t.has_label(explicit), + map_if_true=replace_subject(self.Uk_m, old_idx=self.idx), + map_if_false=drop) + r_exp_k = r_exp_k.label_map( + all_terms, + lambda t: Constant(self.dt)*t) + residual -= r_exp_k + + # Add on sum(j=1,M) s_mj*F(y_m^k), where s_mj = q_mj-q_m-1j + # and s1j = q1j. + Q = self.residual.label_map(lambda t: t.has_label(time_derivative), + replace_subject(self.Q_, old_idx=self.idx), + drop) + residual += Q + return residual.form + + @cached_property + def solver(self): + """Set up the problem and the solver for the nonlinear solve.""" + # setup solver using residual defined in derived class + problem = NonlinearVariationalProblem(self.res, self.U_DC, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__ + solver = NonlinearVariationalSolver(problem, solver_parameters=self.nonlinear_solver_parameters, options_prefix=solver_name) + return solver + + @cached_property + def solver_rhs(self): + """Set up the problem and the solver for mass matrix inversion.""" + # setup linear solver using rhs residual defined in derived class + prob_rhs = NonlinearVariationalProblem(self.res_rhs, self.Urhs, bcs=self.bcs) + solver_name = self.field_name+self.__class__.__name__+"_rhs" + return NonlinearVariationalSolver(prob_rhs, solver_parameters=self.linear_solver_parameters, + options_prefix=solver_name) + + @wrapper_apply + def apply(self, x_out, x_in): + self.Un.assign(x_in) + + # Compute initial guess on quadrature nodes with low-order + # base timestepper + self.Unodes[0].assign(self.Un) + self.M1 = self.K + + for m in range(self.M): + self.base.dt = float(self.dt) + self.base.apply(self.Unodes[m+1], self.Unodes[m]) + + for m in range(self.M+1): + for evaluate in self.evaluate_source: + evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) + + # Iterate through correction sweeps + for k in range(1, self.K+1): + # Compute: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) + for m in range(self.M+1): + self.Uin.assign(self.Unodes[m]) + # Include source terms + for evaluate in self.evaluate_source: + evaluate(self.Uin, self.base.dt, x_out=self.source_in) + self.solver_rhs.solve() + self.fUnodes[m].assign(self.Urhs) + + # Loop through quadrature nodes and solve + self.Unodes1[0].assign(self.Unodes[0]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) + if self.reduced: + self.M1 = k + for m in range(0, self.M1): + # Set integration matrix + if self.reduced: + self.Q_.assign(self.compute_quad(self.Q[k-1], self.fUnodes, m+1)) + else: + self.Q_.assign(self.compute_quad(self.Q, self.fUnodes, m+1)) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_DC.assign(self.Unodes[m+1]) + + # Compute: + # y_m^(k+1) = y_(m-1)^(k+1) + dt*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + self.solver.solve() + self.Unodes1[m+1].assign(self.U_DC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m+1]) + for m in range(self.M1, self.M): + # Set integration matrix + if self.reduced: + self.Q_.assign(self.compute_quad_final(self.Q[k-1], self.fUnodes, m+1)) + else: + self.Q_.assign(self.compute_quad_final(self.Q, self.fUnodes, m+1)) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_DC.assign(self.Unodes[m+1]) + + # Compute: + # y_m^(k+1) = y_(m-1)^(k+1) + dt*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + self.solver.solve() + self.Unodes1[m+1].assign(self.U_DC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m+1]) + + for m in range(self.M+1): + self.Unodes[m].assign(self.Unodes1[m]) + self.source_Uk[m].assign(self.source_Ukp1[m]) + + x_out.assign(self.Unodes[-1]) diff --git a/gusto/time_discretisation/explicit_runge_kutta.py b/gusto/time_discretisation/explicit_runge_kutta.py index a5793399b..aa3214222 100644 --- a/gusto/time_discretisation/explicit_runge_kutta.py +++ b/gusto/time_discretisation/explicit_runge_kutta.py @@ -321,7 +321,6 @@ def solve_stage(self, x0, stage): if self.rk_formulation == RungeKuttaFormulation.increment: self.x1.assign(x0) - for i in range(stage): self.x1.assign(self.x1 + self.dt*self.butcher_matrix[stage-1, i]*self.k[i]) for evaluate in self.evaluate_source: @@ -340,8 +339,6 @@ def solve_stage(self, x0, stage): self.x1.assign(x0) for i in range(self.nStages): self.x1.assign(self.x1 + self.dt*self.butcher_matrix[stage, i]*self.k[i]) - self.x1.assign(self.x1) - if self.limiter is not None: self.limiter.apply(self.x1) diff --git a/gusto/time_discretisation/imex_runge_kutta.py b/gusto/time_discretisation/imex_runge_kutta.py index ba3911fd4..82e2a5ab9 100644 --- a/gusto/time_discretisation/imex_runge_kutta.py +++ b/gusto/time_discretisation/imex_runge_kutta.py @@ -93,6 +93,15 @@ def __init__(self, domain, butcher_imp, butcher_exp, field_name=None, self.butcher_exp = butcher_exp self.nStages = int(np.shape(self.butcher_imp)[1]) + # Some butcher tableaus have zero first stage, if so, we don't need to do an + # initial solve and can copy across x_in to x_s[0] + self.zero_first_stage = True + self.solver_start_stage = 1 + for value in self.butcher_imp[0]: + if value != 0.0: + self.zero_first_stage = False + self.solver_start_stage = 0 + # Set default linear and nonlinear solver options if none passed in if linear_solver_parameters is None: self.linear_solver_parameters = {'snes_type': 'ksponly', @@ -231,7 +240,7 @@ def final_res(self): def solvers(self): """Set up a list of solvers for each problem at a stage.""" solvers = [] - for stage in range(self.nStages): + for stage in range(self.solver_start_stage, self.nStages): # setup solver using residual defined in derived class problem = NonlinearVariationalProblem(self.res(stage), self.x_out, bcs=self.bcs) solver_name = self.field_name+self.__class__.__name__ + "%s" % (stage) @@ -251,15 +260,16 @@ def apply(self, x_out, x_in): self.x1.assign(x_in) self.x_out.assign(x_in) solver_list = self.solvers + self.xs[0].assign(x_in) + + for stage in range(self.solver_start_stage, self.nStages): - for stage in range(self.nStages): - self.solver = solver_list[stage] + self.solver = solver_list[stage-self.solver_start_stage] # Set initial solver guess - if (stage > 0): - self.x_out.assign(self.xs[stage-1]) - # Evaluate source terms - for evaluate in self.evaluate_source: - evaluate(self.xs[stage-1], self.dt, x_out=self.source[stage-1]) + self.x_out.assign(self.xs[stage-1]) + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.xs[stage-1], self.dt, x_out=self.source[stage-1]) self.solver.solve() # Apply limiter diff --git a/gusto/time_discretisation/parallel_dc.py b/gusto/time_discretisation/parallel_dc.py new file mode 100644 index 000000000..0dfd52ef6 --- /dev/null +++ b/gusto/time_discretisation/parallel_dc.py @@ -0,0 +1,404 @@ +""" +Objects for discretising time derivatives using time-parallel Deferred Correction +Methods. + +This module inherits from the serial SDC and RIDC classes, and implements the +parallelisation of the SDC and RIDC methods using MPI. + +SDC parallelises across the quadrature nodes by using diagonal QDelta matrices, +while RIDC parallelises across the correction iterations by using a reduced stencil +and pipelining. +""" + +from firedrake import ( + Function +) +from gusto.time_discretisation.time_discretisation import wrapper_apply +from qmat import genQDeltaCoeffs +from gusto.time_discretisation.deferred_correction import SDC, RIDC +from gusto.core.logging import logger + +__all__ = ["Parallel_RIDC", "Parallel_SDC"] + + +class Parallel_RIDC(RIDC): + """Class for Parallel Revisionist Integral Deferred Correction schemes.""" + + def __init__(self, base_scheme, domain, M, K, J, output_freq, flush_freq=None, field_name=None, + linear_solver_parameters=None, nonlinear_solver_parameters=None, + limiter=None, communicator=None): + """ + Initialise RIDC object + Args: + base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on + quadrature nodes. + domain (:class:`Domain`): the model's domain object, containing the + mesh and the compatible function spaces. + M (int): Number of subintervals + K (int): Max number of correction interations + J (int): Number of intervals + output_freq (int): Frequency at which output is done + flush_freq (int): Frequency at which to flush the pipeline + field_name (str, optional): name of the field to be evolved. + Defaults to None. + linear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying linear solver. Defaults to None. + nonlinear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying nonlinear solver. Defaults to None. + limiter (:class:`Limiter` object, optional): a limiter to apply to + the evolving field to enforce monotonicity. Defaults to None. + communicator (MPI communicator, optional): communicator for parallel execution. Defaults to None. + """ + + super(Parallel_RIDC, self).__init__(base_scheme, domain, M, K, field_name, + linear_solver_parameters, nonlinear_solver_parameters, + limiter, reduced=True) + self.comm = communicator + self.TAG_EXCHANGE_FIELD = 11 # Tag for sending nodal fields (Firedrake Functions) + self.TAG_EXCHANGE_SOURCE = self.TAG_EXCHANGE_FIELD + J # Tag for sending nodal source fields (Firedrake Functions) + self.TAG_FLUSH_PIPE = self.TAG_EXCHANGE_SOURCE + J # Tag for flushing pipe and restarting + self.TAG_FINAL_OUT = self.TAG_FLUSH_PIPE + J # Tag for the final broadcast and output + self.TAG_END_INTERVAL = self.TAG_FINAL_OUT + J # Tag for telling the rank above you that you have ended interval j + + if flush_freq is None: + self.flush_freq = 1 + else: + self.flush_freq = flush_freq + + self.J = J + self.step = 1 + self.output_freq = output_freq + + if self.flush_freq == 0 or (self.flush_freq != 0 and self.output_freq % self.flush_freq != 0): + logger.warn("Output on all parallel in time ranks will not be the same until end of run!") + + # Checks for parallel RIDC + if self.comm is None: + raise ValueError("No communicator provided. Please provide a valid MPI communicator.") + if self.comm.ensemble_comm.size != self.K + 1: + raise ValueError("Number of ranks must be equal to K+1 for Parallel RIDC.") + if self.M < self.K*(self.K+1)//2: + raise ValueError("Number of subintervals M must be greater than K*(K+1)/2 for Parallel RIDC.") + + def setup(self, equation, apply_bcs=True, *active_labels): + """ + Set up the RIDC time discretisation based on the equation. + + Args: + equation (:class:`PrognosticEquation`): the model's equation. + apply_bcs (bool, optional): whether to apply the equation's boundary + conditions. Defaults to True. + *active_labels (:class:`Label`): labels indicating which terms of + the equation to include. + """ + super(Parallel_RIDC, self).setup(equation, apply_bcs, *active_labels) + + self.Uk_mp1 = Function(self.W) + self.Uk_m = Function(self.W) + self.Ukp1_m = Function(self.W) + + @wrapper_apply + def apply(self, x_out, x_in): + # Set up varibles on this rank + x_out.assign(x_in) + self.kval = self.comm.ensemble_comm.rank + self.Un.assign(x_in) + self.Unodes[0].assign(self.Un) + # Loop through quadrature nodes and solve + self.Unodes1[0].assign(self.Unodes[0]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) + self.Uin.assign(self.Unodes[0]) + self.solver_rhs.solve() + self.fUnodes[0].assign(self.Urhs) + + # On first communicator, we do the predictor step + if (self.comm.ensemble_comm.rank == 0): + # Base timestepper + for m in range(self.M): + self.base.dt = float(self.dt) + self.base.apply(self.Unodes[m+1], self.Unodes[m]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[m+1], self.base.dt, x_out=self.source_Uk[m+1]) + + # Send base guess to k+1 correction + self.comm.send(self.Unodes[m+1], dest=self.kval+1, tag=self.TAG_EXCHANGE_FIELD + self.step) + self.comm.send(self.source_Uk[m+1], dest=self.kval+1, tag=self.TAG_EXCHANGE_SOURCE + self.step) + else: + for m in range(1, self.kval + 1): + # Receive and evaluate the stencil of guesses we need to correct + self.comm.recv(self.Unodes[m], source=self.kval-1, tag=self.TAG_EXCHANGE_FIELD + self.step) + self.comm.recv(self.source_Uk[m], source=self.kval-1, tag=self.TAG_EXCHANGE_SOURCE + self.step) + self.Uin.assign(self.Unodes[m]) + for evaluate in self.evaluate_source: + evaluate(self.Uin, self.base.dt, x_out=self.source_in) + self.solver_rhs.solve() + self.fUnodes[m].assign(self.Urhs) + for m in range(0, self.kval): + # Set S matrix + self.Q_.assign(self.compute_quad(self.Q[self.kval-1], self.fUnodes, m+1)) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_DC.assign(self.Unodes[m+1]) + + # Compute + # y_m^(k+1) = y_(m-1)^(k+1) + dt*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y_j^k) + self.solver.solve() + self.Unodes1[m+1].assign(self.U_DC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m+1]) + # Send our updated value to next communicator + if self.kval < self.K: + self.comm.send(self.Unodes1[m+1], dest=self.kval+1, tag=self.TAG_EXCHANGE_FIELD + self.step) + self.comm.send(self.source_Ukp1[m+1], dest=self.kval+1, tag=self.TAG_EXCHANGE_SOURCE + self.step) + + for m in range(self.kval, self.M): + # Receive the guess we need to correct and evaluate the rhs + self.comm.recv(self.Unodes[m+1], source=self.kval-1, tag=self.TAG_EXCHANGE_FIELD + self.step) + self.comm.recv(self.source_Uk[m+1], source=self.kval-1, tag=self.TAG_EXCHANGE_SOURCE + self.step) + self.Uin.assign(self.Unodes[m+1]) + for evaluate in self.evaluate_source: + evaluate(self.Uin, self.base.dt, x_out=self.source_in) + self.solver_rhs.solve() + self.fUnodes[m+1].assign(self.Urhs) + + # Set S matrix + self.Q_.assign(self.compute_quad_final(self.Q[self.kval-1], self.fUnodes, m+1)) + + # Set initial guess for solver, and pick correct solver + self.U_start.assign(self.Unodes1[m]) + self.Ukp1_m.assign(self.Unodes1[m]) + self.Uk_mp1.assign(self.Unodes[m+1]) + self.Uk_m.assign(self.Unodes[m]) + self.source_Ukp1_m.assign(self.source_Ukp1[m]) + self.source_Uk_m.assign(self.source_Uk[m]) + self.U_DC.assign(self.Unodes[m+1]) + + # y_m^(k+1) = y_(m-1)^(k+1) + dt*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + self.solver.solve() + self.Unodes1[m+1].assign(self.U_DC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[m+1], self.base.dt, x_out=self.source_Ukp1[m+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[m+1]) + + # Send our updated value to next communicator + if self.kval < self.K: + self.comm.send(self.Unodes1[m+1], dest=self.kval+1, tag=self.TAG_EXCHANGE_FIELD + self.step) + self.comm.send(self.source_Ukp1[m+1], dest=self.kval+1, tag=self.TAG_EXCHANGE_SOURCE + self.step) + + for m in range(self.M+1): + self.Unodes[m].assign(self.Unodes1[m]) + self.source_Uk[m].assign(self.source_Ukp1[m]) + + if (self.flush_freq > 0 and self.step % self.flush_freq == 0) or self.step == self.J: + # Flush the pipe to ensure all ranks have the same data + if (self.kval == self.K): + x_out.assign(self.Unodes[-1]) + for i in range(self.K): + self.comm.send(x_out, dest=i, tag=self.TAG_FLUSH_PIPE + self.step) + else: + self.comm.recv(x_out, source=self.K, tag=self.TAG_FLUSH_PIPE + self.step) + else: + x_out.assign(self.Unodes[-1]) + + self.step += 1 + + +class Parallel_SDC(SDC): + """Class for Spectral Deferred Correction schemes.""" + + def __init__(self, base_scheme, domain, M, maxk, quad_type, node_type, qdelta_imp, qdelta_exp, + field_name=None, + linear_solver_parameters=None, nonlinear_solver_parameters=None, final_update=True, + limiter=None, options=None, initial_guess="base", communicator=None): + """ + Initialise SDC object + Args: + base_scheme (:class:`TimeDiscretisation`): Base time stepping scheme to get first guess of solution on + quadrature nodes. + domain (:class:`Domain`): the model's domain object, containing the + mesh and the compatible function spaces. + M (int): Number of quadrature nodes to compute spectral integration over + maxk (int): Max number of correction interations + quad_type (str): Type of quadrature to be used. Options are + GAUSS, RADAU-LEFT, RADAU-RIGHT and LOBATTO + node_type (str): Node type to be used. Options are + EQUID, LEGENDRE, CHEBY-1, CHEBY-2, CHEBY-3 and CHEBY-4 + qdelta_imp (str): Implicit Qdelta matrix to be used. Options are + BE, LU, TRAP, EXACT, PIC, OPT, WEIRD, MIN-SR-NS, MIN-SR-S + qdelta_exp (str): Explicit Qdelta matrix to be used. Options are + FE, EXACT, PIC + field_name (str, optional): name of the field to be evolved. + Defaults to None. + linear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying linear solver. Defaults to None. + nonlinear_solver_parameters (dict, optional): dictionary of parameters to + pass to the underlying nonlinear solver. Defaults to None. + final_update (bool, optional): Whether to compute final update, or just take last + quadrature value. Defaults to True + limiter (:class:`Limiter` object, optional): a limiter to apply to + the evolving field to enforce monotonicity. Defaults to None. + initial_guess (str, optional): Initial guess to be base timestepper, or copy + communicator (MPI communicator, optional): communicator for parallel execution. Defaults to None. + """ + super().__init__(base_scheme, domain, M, maxk, quad_type, node_type, qdelta_imp, qdelta_exp, + formulation="Z2N", field_name=field_name, + linear_solver_parameters=linear_solver_parameters, nonlinear_solver_parameters=nonlinear_solver_parameters, + final_update=final_update, + limiter=limiter, initial_guess=initial_guess) + self.comm = communicator + + # Checks for parallel SDC + if self.comm is None: + raise ValueError("No communicator provided. Please provide a valid MPI communicator.") + if self.comm.ensemble_comm.size != self.M: + raise ValueError("Number of ranks must be equal to the number of nodes M for Parallel SDC.") + + def compute_quad(self): + """ + Computes integration of F(y) on quadrature nodes + """ + x = Function(self.W) + for j in range(self.M): + x.assign(float(self.Q[j, self.comm.ensemble_comm.rank])*self.fUnodes[self.comm.ensemble_comm.rank]) + self.comm.reduce(x, self.quad[j], root=j) + + def compute_quad_final(self): + """ + Computes final integration of F(y) on quadrature nodes + """ + x = Function(self.W) + x.assign(float(self.Qfin[self.comm.ensemble_comm.rank])*self.fUnodes[self.comm.ensemble_comm.rank]) + self.comm.allreduce(x, self.quad_final) + + @wrapper_apply + def apply(self, x_out, x_in): + self.Un.assign(x_in) + self.U_start.assign(self.Un) + solver_list = self.solvers + + # Compute initial guess on quadrature nodes with low-order + # base timestepper + self.Unodes[0].assign(self.Un) + if (self.base_flag): + for m in range(self.M): + self.base.dt = float(self.dtau[m]) + self.base.apply(self.Unodes[m+1], self.Unodes[m]) + else: + for m in range(self.M): + self.Unodes[m+1].assign(self.Un) + for m in range(self.M+1): + for evaluate in self.evaluate_source: + evaluate(self.Unodes[m], self.base.dt, x_out=self.source_Uk[m]) + + # Iterate through correction sweeps + k = 0 + while k < self.maxk: + k += 1 + + if self.qdelta_imp_type == "MIN-SR-FLEX": + # Recompute Implicit Q_delta matrix for each iteration k + self.Qdelta_imp = float(self.dt_coarse)*genQDeltaCoeffs( + self.qdelta_imp_type, + form=self.formulation, + nodes=self.nodes, + Q=self.Q, + nNodes=self.M, + nodeType=self.node_type, + quadType=self.quad_type, + k=k + ) + + # Compute for N2N: sum(j=1,M) (s_mj*F(y_m^k) + s_mj*S(y_m^k)) + # for Z2N: sum(j=1,M) (q_mj*F(y_m^k) + q_mj*S(y_m^k)) + self.Uin.assign(self.Unodes[self.comm.ensemble_comm.rank+1]) + # Include source terms + for evaluate in self.evaluate_source: + evaluate(self.Uin, self.base.dt, x_out=self.source_in) + self.solver_rhs.solve() + self.fUnodes[self.comm.ensemble_comm.rank].assign(self.Urhs) + + self.compute_quad() + + # Loop through quadrature nodes and solve + self.Unodes1[0].assign(self.Unodes[0]) + for evaluate in self.evaluate_source: + evaluate(self.Unodes[0], self.base.dt, x_out=self.source_Uk[0]) + + # Set Q or S matrix + self.Q_.assign(self.quad[self.comm.ensemble_comm.rank]) + + # Set initial guess for solver, and pick correct solver + self.solver = solver_list[self.comm.ensemble_comm.rank] + self.U_DC.assign(self.Unodes[self.comm.ensemble_comm.rank+1]) + + # Compute + # for N2N: + # y_m^(k+1) = y_(m-1)^(k+1) + dtau_m*(F(y_(m)^(k+1)) - F(y_(m)^k) + # + S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + # + sum(j=1,M) s_mj*(F+S)(y^k) + # for Z2N: + # y_m^(k+1) = y^n + sum(j=1,m) Qdelta_imp[m,j]*(F(y_(m)^(k+1)) - F(y_(m)^k)) + # + sum(j=1,M) Q_delta_exp[m,j]*(S(y_(m-1)^(k+1)) - S(y_(m-1)^k)) + self.solver.solve() + self.Unodes1[self.comm.ensemble_comm.rank+1].assign(self.U_DC) + + # Evaluate source terms + for evaluate in self.evaluate_source: + evaluate(self.Unodes1[self.comm.ensemble_comm.rank+1], self.base.dt, x_out=self.source_Ukp1[self.comm.ensemble_comm.rank+1]) + + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.Unodes1[self.comm.ensemble_comm.rank+1]) + + self.Unodes[self.comm.ensemble_comm.rank+1].assign(self.Unodes1[self.comm.ensemble_comm.rank+1]) + self.source_Uk[self.comm.ensemble_comm.rank+1].assign(self.source_Ukp1[self.comm.ensemble_comm.rank+1]) + + if self.maxk > 0: + # Compute value at dt rather than final quadrature node tau_M + if self.final_update: + self.Uin.assign(self.Unodes1[self.comm.ensemble_comm.rank+1]) + self.source_in.assign(self.source_Ukp1[self.comm.ensemble_comm.rank+1]) + self.solver_rhs.solve() + self.fUnodes[self.comm.ensemble_comm.rank].assign(self.Urhs) + self.compute_quad_final() + # Compute y_(n+1) = y_n + sum(j=1,M) q_j*F(y_j) + if self.comm.ensemble_comm.rank == self.M-1: + self.U_fin.assign(self.Unodes[-1]) + self.comm.bcast(self.U_fin, self.M-1) + self.solver_fin.solve() + # Apply limiter if required + if self.limiter is not None: + self.limiter.apply(self.U_fin) + x_out.assign(self.U_fin) + else: + # Take value at final quadrature node dtau_M + if self.comm.ensemble_comm.rank == self.M-1: + x_out.assign(self.Unodes[-1]) + self.comm.bcast(x_out, self.M-1) + else: + # Take value at final quadrature node dtau_M + if self.comm.ensemble_comm.rank == self.M-1: + x_out.assign(self.Unodes[-1]) + self.comm.bcast(x_out, self.M-1) diff --git a/gusto/time_discretisation/time_discretisation.py b/gusto/time_discretisation/time_discretisation.py index 377479b83..b9578f720 100644 --- a/gusto/time_discretisation/time_discretisation.py +++ b/gusto/time_discretisation/time_discretisation.py @@ -23,6 +23,7 @@ from gusto.time_discretisation.wrappers import * from gusto.solvers import mass_parameters + __all__ = ["TimeDiscretisation", "ExplicitTimeDiscretisation", "BackwardEuler", "ThetaMethod", "TrapeziumRule", "TR_BDF2"] diff --git a/gusto/timestepping/timestepper.py b/gusto/timestepping/timestepper.py index 81e52e047..a16950749 100644 --- a/gusto/timestepping/timestepper.py +++ b/gusto/timestepping/timestepper.py @@ -358,6 +358,8 @@ def setup_scheme(self): self.setup_equation(self.equation) self.scheme.setup(self.equation) self.setup_transporting_velocity(self.scheme) + if hasattr(self.scheme, 'base'): + self.setup_transporting_velocity(self.scheme.base) if self.io.output.log_courant: self.scheme.courant_max = self.io.courant_max diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py index 02b1addfc..7fce74377 100644 --- a/integration-tests/conftest.py +++ b/integration-tests/conftest.py @@ -17,9 +17,13 @@ def tracer_sphere(tmpdir, degree, small_dt): radius = 1 - mesh = IcosahedralSphereMesh(radius=radius, - refinement_level=3, - degree=1) + + dirname = str(tmpdir) + mesh = IcosahedralSphereMesh( + radius=radius, + refinement_level=3, + degree=1 + ) x = SpatialCoordinate(mesh) # Parameters chosen so that dt != 1 @@ -30,7 +34,7 @@ def tracer_sphere(tmpdir, degree, small_dt): else: dt = pi/3. * 0.02 - output = OutputParameters(dirname=str(tmpdir), dumpfreq=15) + output = OutputParameters(dirname=dirname, dumpfreq=15) domain = Domain(mesh, dt, family="BDM", degree=degree) io = IO(domain, output) @@ -49,6 +53,8 @@ def tracer_sphere(tmpdir, degree, small_dt): def tracer_slice(tmpdir, degree, small_dt): n = 30 if degree == 0 else 15 + + dirname = str(tmpdir) m = PeriodicIntervalMesh(n, 1.) mesh = ExtrudedMesh(m, layers=n, layer_height=1./n) @@ -61,7 +67,7 @@ def tracer_slice(tmpdir, degree, small_dt): else: dt = 0.01 tmax = 0.75 - output = OutputParameters(dirname=str(tmpdir), dumpfreq=25) + output = OutputParameters(dirname=dirname, dumpfreq=25) domain = Domain(mesh, dt, family="CG", degree=degree) io = IO(domain, output) @@ -89,10 +95,11 @@ def tracer_blob_slice(tmpdir, degree, small_dt): else: dt = 0.01 L = 10. + dirname = str(tmpdir) m = PeriodicIntervalMesh(10, L) mesh = ExtrudedMesh(m, layers=10, layer_height=1.) - output = OutputParameters(dirname=str(tmpdir), dumpfreq=25) + output = OutputParameters(dirname=dirname, dumpfreq=25) domain = Domain(mesh, dt, family="CG", degree=degree) io = IO(domain, output) diff --git a/integration-tests/model/test_sdc.py b/integration-tests/model/test_deferred_correction.py similarity index 68% rename from integration-tests/model/test_sdc.py rename to integration-tests/model/test_deferred_correction.py index fd1dba068..7db9051f1 100644 --- a/integration-tests/model/test_sdc.py +++ b/integration-tests/model/test_deferred_correction.py @@ -1,11 +1,14 @@ """ -This runs a simple transport test on the sphere using the SDC time discretisations to +This runs a simple transport test on the sphere using the DC time discretisations to test whether the errors are within tolerance. The test is run for the following schemes: - IMEX_SDC_Le(1,1) - IMEX SDC with 1 quadrature node of Gauss type (2nd order scheme) - IMEX_SDC_R(2,2) - IMEX SDC with 2 qaudrature nodes of Radau type (3rd order scheme) using LU decomposition for the implicit update - BE_SDC_Lo(3,3) - Implicit SDC with 3 quadrature nodes of Lobatto type (4th order scheme). - FE_SDC_Le(3,5) - Explicit SDC with 3 quadrature nodes of Gauss type (6th order scheme). +- IMEX_RIDC_R(3) - IMEX RIDC with 4 quadrature nodes of equidistant type, reduced stencils (3rd order scheme). +- BE_RIDC(4) - Implicit RIDC with 3 quadrature nodes of equidistant type, full stencils (4th order scheme). +- FE_RIDC(4) - Explicit RIDC with 3 quadrature nodes of equidistant type, full stencils (4th order scheme). """ from firedrake import norm @@ -19,8 +22,9 @@ def run(timestepper, tmax, f_end): @pytest.mark.parametrize( - "scheme", ["IMEX_SDC_Le(1,1)", "IMEX_SDC_R(2,2)", "BE_SDC_Lo(3,3)", "FE_SDC_Le(3,5)"]) -def test_sdc(tmpdir, scheme, tracer_setup): + "scheme", ["IMEX_SDC_Le(1,1)", "IMEX_SDC_R(2,2)", "BE_SDC_Lo(3,3)", "FE_SDC_Le(3,5)", "IMEX_RIDC_R(3)", + "BE_RIDC(4)", "FE_RIDC(4)"]) +def test_dc(tmpdir, scheme, tracer_setup): geometry = "sphere" setup = tracer_setup(tmpdir, geometry) domain = setup.domain @@ -66,11 +70,33 @@ def test_sdc(tmpdir, scheme, tracer_setup): elif scheme == "FE_SDC_Le(3,5)": quad_type = "GAUSS" M = 3 - k = 4 + k = 5 eqn.label_terms(lambda t: not t.has_label(time_derivative), explicit) base_scheme = ForwardEuler(domain) scheme = SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, qdelta_exp, final_update=True, initial_guess="base") + elif scheme == "IMEX_RIDC_R(3)": + k = 2 + M = k*(k+1)//2 + 1 + eqn = ContinuityEquation(domain, V, "f") + # Split continuity term + eqn = split_continuity_form(eqn) + eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) + eqn.label_terms(lambda t: t.has_label(transport), explicit) + base_scheme = IMEX_Euler(domain) + scheme = RIDC(base_scheme, domain, M, k, reduced=True) + elif scheme == "BE_RIDC(4)": + k = 3 + M = 3 + eqn.label_terms(lambda t: not t.has_label(time_derivative), implicit) + base_scheme = BackwardEuler(domain) + scheme = RIDC(base_scheme, domain, M, k, reduced=False) + elif scheme == "FE_RIDC(4)": + M = 3 + k = 3 + eqn.label_terms(lambda t: not t.has_label(time_derivative), explicit) + base_scheme = ForwardEuler(domain) + scheme = RIDC(base_scheme, domain, M, k, reduced=False) transport_method = DGUpwind(eqn, 'f') diff --git a/integration-tests/model/test_nc_outputting.py b/integration-tests/model/test_nc_outputting.py index 55d63264d..8905de324 100644 --- a/integration-tests/model/test_nc_outputting.py +++ b/integration-tests/model/test_nc_outputting.py @@ -12,9 +12,10 @@ ZComponent, MeridionalComponent, ZonalComponent, RadialComponent, DGUpwind) from mpi4py import MPI -from netCDF4 import Dataset, chartostring +from netCDF4 import Dataset import pytest from pytest_mpi import parallel_assert +import numpy as np def make_dirname(test_name, suffix=""): @@ -169,6 +170,17 @@ def assertion(): return output_data[metadata_key][0] - output_value < 1e-14 else: def assertion(): - return str(chartostring(output_data[metadata_key][0])) == output_value + var = output_data[metadata_key] + row = var[0, :] if getattr(var, 'ndim', None) == 2 else var[:] + arr = np.array(row) + + if arr.dtype.kind == 'S': # bytes-like chars + decoded = b''.join(arr.tolist()).decode('utf-8').rstrip('\x00') + elif arr.dtype.kind in ('U', 'O'): # already text + decoded = ''.join(arr.tolist()).rstrip('\x00') + else: # fallback + decoded = b''.join(arr.view('S1').tolist()).decode('utf-8').rstrip('\x00') + + return decoded == output_value parallel_assert(assertion, participating=output_data is not None, msg=error_message) diff --git a/integration-tests/model/test_parallel_dc.py b/integration-tests/model/test_parallel_dc.py new file mode 100644 index 000000000..f1b3fe585 --- /dev/null +++ b/integration-tests/model/test_parallel_dc.py @@ -0,0 +1,111 @@ +""" +This runs a simple transport test on the sphere using the parallel DC time discretisations to +test whether the errors are within tolerance. The test is run for the following schemes: +- IMEX_SDC(2,2) - IMEX SDC with 2 qaudrature nodes of Radau type and 2 correction sweeps (2nd order scheme) +- IMEX_RIDC(2,1) - IMEX RIDC with 3 quadrature nodes of equidistant type, 1 correction sweep, reduced stencils (2nd order scheme). + Has a pipeline flush frequency of 1 (every timestep). +- IMEX_RIDC(2,5) - IMEX RIDC with 3 quadrature nodes of equidistant type, 1 correction sweep, reduced stencils (2nd order scheme). + Has a pipeline flush frequency of 5 (every 5 timesteps). +""" + +from firedrake import (norm, Ensemble, COMM_WORLD, SpatialCoordinate, + as_vector, pi, exp, IcosahedralSphereMesh) + +from gusto import * +import pytest +from pytest_mpi.parallel_assert import parallel_assert + + +def run(timestepper, tmax, f_end): + timestepper.run(0, tmax) + print(norm(timestepper.fields("f") - f_end) / norm(f_end)) + return norm(timestepper.fields("f") - f_end) / norm(f_end) + + +@pytest.mark.parallel(nprocs=[2]) +@pytest.mark.parametrize( + "scheme", ["IMEX_RIDC(2,1)", "IMEX_RIDC(2,5)", "IMEX_SDC(2,2)"]) +def test_parallel_dc(tmpdir, scheme): + + if scheme == "IMEX_SDC(2,2)": + M = 2 + k = 2 + ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(M)) + elif scheme == "IMEX_RIDC(2,1)" or scheme == "IMEX_RIDC(2,5)": + k = 1 + ensemble = Ensemble(COMM_WORLD, COMM_WORLD.size//(k+1)) + + # Get the tracer setup + radius = 1 + dirname = str(tmpdir) + mesh = IcosahedralSphereMesh( + radius=radius, + refinement_level=3, + degree=1, + comm=ensemble.comm + ) + x = SpatialCoordinate(mesh) + + # Parameters chosen so that dt != 1 + # Gaussian is translated from (lon=pi/2, lat=0) to (lon=0, lat=0) + # to demonstrate that transport is working correctly + + dt = pi/3. * 0.02 + dumpfreq = 15 + output = OutputParameters(dirname=dirname, dump_vtus=False, dump_nc=True, dumpfreq=dumpfreq) + domain = Domain(mesh, dt, family="BDM", degree=1) + io = IO(domain, output) + + umax = 1.0 + uexpr = as_vector([- umax * x[1] / radius, umax * x[0] / radius, 0.0]) + + tmax = pi/2 + f_init = exp(-x[2]**2 - x[0]**2) + f_end = exp(-x[2]**2 - x[1]**2) + + tol = 0.05 + + domain = domain + V = domain.spaces("DG") + eqn = ContinuityEquation(domain, V, "f") + + if scheme == "IMEX_SDC(2,2)": + eqn.label_terms(lambda t: not t.has_label(time_derivative), implicit) + + quad_type = "RADAU-RIGHT" + node_type = "LEGENDRE" + qdelta_imp = "MIN-SR-FLEX" + qdelta_exp = "MIN-SR-NS" + base_scheme = IMEX_Euler(domain) + time_scheme = Parallel_SDC(base_scheme, domain, M, k, quad_type, node_type, qdelta_imp, + qdelta_exp, final_update=True, initial_guess="base", communicator=ensemble) + elif scheme == "IMEX_RIDC(2,1)": + eqn = split_continuity_form(eqn) + eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) + eqn.label_terms(lambda t: t.has_label(transport), explicit) + + M = 5 + J = int(tmax/dt) + base_scheme = IMEX_Euler(domain) + time_scheme = Parallel_RIDC(base_scheme, domain, M, k, J, output_freq=dumpfreq, flush_freq=1, communicator=ensemble) + elif scheme == "IMEX_RIDC(2,5)": + eqn = split_continuity_form(eqn) + eqn.label_terms(lambda t: not any(t.has_label(time_derivative, transport)), implicit) + eqn.label_terms(lambda t: t.has_label(transport), explicit) + + M = 5 + J = int(tmax/dt) + base_scheme = IMEX_Euler(domain) + time_scheme = Parallel_RIDC(base_scheme, domain, M, k, J, output_freq=dumpfreq, flush_freq=5, communicator=ensemble) + + transport_method = DGUpwind(eqn, 'f') + + time_varying_velocity = False + timestepper = PrescribedTransport( + eqn, time_scheme, io, time_varying_velocity, transport_method + ) + + timestepper.fields("f").interpolate(f_init) + timestepper.fields("u").project(uexpr) + error = run(timestepper, tmax, f_end) + parallel_assert(error < tol, f"Error too large, Error: {error}, tol: {tol}")