diff --git a/gusto/time_discretisation/explicit_runge_kutta.py b/gusto/time_discretisation/explicit_runge_kutta.py index aa3214222..09f8689e6 100644 --- a/gusto/time_discretisation/explicit_runge_kutta.py +++ b/gusto/time_discretisation/explicit_runge_kutta.py @@ -282,7 +282,7 @@ def res(self): ) # Set up all-but-last RHS - if self.idx is not None: + if self.idx is not None and self.wrapper is None: # If original function is in mixed function space, then ensure # correct test function in the all-but-last form r_all_but_last = self.residual.label_map( diff --git a/gusto/time_discretisation/time_discretisation.py b/gusto/time_discretisation/time_discretisation.py index b9578f720..6e0a7e05e 100644 --- a/gusto/time_discretisation/time_discretisation.py +++ b/gusto/time_discretisation/time_discretisation.py @@ -17,8 +17,10 @@ from firedrake.utils import cached_property from gusto.core.configuration import EmbeddedDGOptions, RecoveryOptions -from gusto.core.labels import (time_derivative, prognostic, physics_label, - mass_weighted, nonlinear_time_derivative) +from gusto.core.labels import ( + time_derivative, prognostic, physics_label, mass_weighted, + nonlinear_time_derivative, all_but_last +) from gusto.core.logging import logger, DEBUG, logging_ksp_monitor_true_residual from gusto.time_discretisation.wrappers import * from gusto.solvers import mass_parameters @@ -333,6 +335,21 @@ def setup(self, equation, apply_bcs=True, *active_labels): all_terms, map_if_true=replace_test_function(new_test)) + # TODO: roll this out to other wrappers + # Replace test function in any all_but_last label + def replace_test_all_but_last(t): + old_abl_form = t.get(all_but_last) + new_abl_form = old_abl_form.label_map( + all_terms, + replace_test_function(new_test, old_idx=self.idx) + ) + return all_but_last(t, new_abl_form) + + self.residual = self.residual.label_map( + lambda t: t.has_label(all_but_last), + map_if_true=replace_test_all_but_last + ) + self.residual = self.wrapper.label_terms(self.residual) if self.solver_parameters is None: self.solver_parameters = self.wrapper.solver_parameters