Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gusto/core/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def __call__(self, target, value=None):
# ---------------------------------------------------------------------------- #
implicit = Label("implicit")
explicit = Label("explicit")
horizontal_transport = Label("horizontal_transport")
vertical_transport = Label("vertical_transport")
source_label = Label("source_label")
transporting_velocity = Label("transporting_velocity", validator=lambda value: type(value) in [Function, ufl.tensors.ListTensor, ufl.indexed.Indexed])
prognostic = Label("prognostic", validator=lambda value: type(value) == str)
Expand Down
128 changes: 126 additions & 2 deletions gusto/equations/common_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
from firedrake.fml import subject, drop
from gusto.core.configuration import TransportEquationType
from gusto.core.labels import (transport, transporting_velocity, diffusion,
prognostic, linearisation)
prognostic, linearisation, horizontal_transport,
vertical_transport)

__all__ = ["advection_form", "advection_form_1d", "continuity_form",
"continuity_form_1d", "vector_invariant_form",
"kinetic_energy_form", "advection_equation_circulation_form",
"diffusion_form", "diffusion_form_1d",
"linear_advection_form", "linear_continuity_form",
"split_continuity_form", "tracer_conservative_form"]
"split_continuity_form", "tracer_conservative_form", "split_hv_advective_form"]


def advection_form(test, q, ubar):
Expand Down Expand Up @@ -346,3 +347,126 @@ def tracer_conservative_form(test, q, rho, ubar):
form = transporting_velocity(L, ubar)

return transport(form, TransportEquationType.tracer_conservative)


def split_advection_form(test, q, ubar, ubar_full):
u"""
The form corresponding to the advective transport operator in either horzontal
or vertical directions (dependent on ubar).

This describes either u_h.(∇)q or w dq/dz, for transporting velocity u and transported q.

Args:
test (:class:`TestFunction`): the test function.
q (:class:`ufl.Expr`): the variable to be transported.
ubar (:class:`ufl.Expr`): the transporting velocity in a subset of dimensions.
ubar_full (:class:`ufl.Expr`): the transporting velocity in all dimensions.

Returns:
class:`LabelledForm`: a labelled transport form.
"""

L = inner(test, dot(ubar, grad(q)))*dx
form = transporting_velocity(L, ubar_full)

return transport(form, TransportEquationType.advective)


def split_linear_advection_form(test, qbar, ubar, ubar_full):
"""
The form corresponding to the linearised advective transport operator in
either horzontal or vertical directions (dependent on ubar).

Args:
test (:class:`TestFunction`): the test function.
qbar (:class:`ufl.Expr`): the variable to be transported.
ubar (:class:`ufl.Expr`): the transporting velocity in a subset of dimensions.
ubar_full (:class:`ufl.Expr`): the transporting velocity in all dimensions.

Returns:
:class:`LabelledForm`: a labelled transport form.
"""

L = test*dot(ubar, grad(qbar))*dx
form = transporting_velocity(L, ubar_full)

return transport(form, TransportEquationType.advective)


def split_hv_advective_form(equation, field_name):
u"""
Splits advective term into horizontal and vertical terms.
This describes splitting u.∇(q) terms into u_h.(∇)q and w dq/dz,
for transporting velocity u and transported q.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this all looks correct. It's very cumbersome having to go through the whole of this process -- do you think there are any shortcuts we can take, e.g. in adding the linearisations?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've tried to tidy it a bit by splitting out the advection forms, hopefully this helps!

Args:
equation (:class:`PrognosticEquation`): the model's equation.
Returns:
:class:`PrognosticEquation`: the model's equation.
"""
k = equation.domain.k # vertical unit vector
for t in equation.residual:
if (t.get(transport) == TransportEquationType.advective and t.get(prognostic) == field_name):
# Get fields and test functions
subj = t.get(subject)

# u is either a prognostic or prescribed field
if (hasattr(equation, "field_names")
and 'u' in equation.field_names):
idx = equation.field_names.index(field_name)
W = equation.function_space
test = TestFunctions(W)[idx]
q = split(subj)[idx]
u_idx = equation.field_names.index('u')
uadv = split(equation.X)[u_idx]
elif 'u' in equation.prescribed_fields._field_names:
uadv = equation.prescribed_fields('u')
q = subj
W = equation.function_space
test = TestFunction(W)
else:
raise ValueError('Cannot get velocity field')

# Create new advective and divergence terms
u_vertical = k*inner(uadv, k)
u_horizontal = uadv - u_vertical
vertical_adv_term = prognostic(
vertical_transport(
split_advection_form(test, q, u_vertical, uadv)
),
field_name
)
horizontal_adv_term = prognostic(
horizontal_transport(
split_advection_form(test, q, u_horizontal, uadv)
),
field_name
)

# Add linearisations of new terms if required
if (t.has_label(linearisation)):
u_trial = TrialFunctions(W)[u_idx]
u_trial_vert = k*inner(u_trial, k)
u_trial_horiz = u_trial - u_trial_vert
qbar = split(equation.X_ref)[idx]
# Add linearisations
linear_hori_term = horizontal_transport(
split_linear_advection_form(test, qbar, u_trial_horiz, u_trial)
)
adv_horiz_term = linearisation(horizontal_adv_term, linear_hori_term)

linear_vert_term = vertical_transport(
split_linear_advection_form(test, qbar, u_trial_vert, u_trial)
)
adv_vert_term = linearisation(vertical_adv_term, linear_vert_term)
else:
adv_vert_term = vertical_adv_term
adv_horiz_term = horizontal_adv_term
# Drop old term
equation.residual = equation.residual.label_map(
lambda t: t.get(transport) == TransportEquationType.advective and t.get(prognostic) == field_name,
map_if_true=drop)

# Add new terms onto residual
equation.residual += subject(adv_horiz_term, subj) + subject(adv_vert_term, subj)

return equation
49 changes: 36 additions & 13 deletions gusto/spatial_methods/spatial_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from firedrake import split
from firedrake.fml import Term, keep, drop
from firedrake.fml import Term, keep, drop, all_terms
from gusto.core.labels import prognostic

__all__ = ['SpatialMethod']
Expand All @@ -15,19 +15,19 @@ class SpatialMethod(object):
The base object for describing a spatial discretisation of some term.
"""

def __init__(self, equation, variable, term_label):
def __init__(self, equation, variable, *term_labels):
"""
Args:
equation (:class:`PrognosticEquation`): the equation, which includes
the original type of this term.
variable (str): name of the variable to set the method for
term_label (:class:`Label`): the label specifying which type of term
to be discretised.
term_labels (:class:`Label`): One or more labels specifying which type(s)
of terms should be discretized.
"""
self.equation = equation
self.variable = variable
self.domain = self.equation.domain
self.term_label = term_label
self.term_labels = list(term_labels)

if hasattr(equation, "field_names"):
# Equation with multiple prognostic variables
Expand All @@ -38,14 +38,37 @@ def __init__(self, equation, variable, term_label):
self.field = equation.X
self.test = equation.test

# Find the original term to be used
self.original_form = equation.residual.label_map(
lambda t: t.has_label(term_label) and t.get(prognostic) == variable,
map_if_true=keep, map_if_false=drop)

num_terms = len(self.original_form.terms)
assert num_terms == 1, f'Unable to find {term_label.label} term ' \
+ f'for {variable}. {num_terms} found'
if (len(self.term_labels) == 1):
# Most cases only have one term to be replaced
self.term_label = self.term_labels[0]
self.original_form = equation.residual.label_map(
lambda t: t.has_label(self.term_label) and t.get(prognostic) == variable,
map_if_true=keep,
map_if_false=drop
)
# Check that the original form has the correct number of terms
num_terms = len(self.original_form.terms)
assert num_terms == 1, f'Unable to find {self.term_label.label} term ' \
+ f'for {variable}. {num_terms} found'
else:
# Multiple terms to be replaced. Find the original terms to be used
self.term_label = self.term_labels[0]
self.original_form = equation.residual.label_map(
all_terms,
map_if_true=drop
)
for term in self.term_labels:
original_form = equation.residual.label_map(
lambda t: t.has_label(term) and t.get(prognostic) == variable,
map_if_true=keep,
map_if_false=drop
)
# Check that the original form has the correct number of terms
num_terms = len(original_form.terms)
assert num_terms == 1, f'Unable to find {term.label} term ' \
+ f'for {variable}. {num_terms} found'
# Add the terms form to original forms
self.original_form += original_form

def replace_form(self, equation):
"""
Expand Down
Loading