Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,8 +693,6 @@ def _make_parloops(expr, tensor, bcs, diagonal, fc_params, assembly_rank):
domains = expr.ufl_domains()

if isinstance(expr, slate.TensorBase):
if diagonal:
raise NotImplementedError("Diagonal + slate not supported")
kernels = slac.compile_expression(expr, compiler_parameters=form_compiler_parameters)
else:
kernels = tsfc_interface.compile_form(expr, "form", parameters=form_compiler_parameters, diagonal=diagonal)
Expand Down
2 changes: 1 addition & 1 deletion firedrake/slate/slac/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def generate_loopy_kernel(slate_expr, compiler_parameters=None):

# Create a loopy builder for the Slate expression,
# e.g. contains the loopy kernels coming from TSFC
gem_expr, var2terminal = slate_to_gem(slate_expr)
gem_expr, var2terminal = slate_to_gem(slate_expr, compiler_parameters["slate_compiler"])

scalar_type = compiler_parameters["form_compiler"]["scalar_type"]
slate_loopy, output_arg = gem_to_loopy(gem_expr, var2terminal, scalar_type)
Expand Down
79 changes: 76 additions & 3 deletions firedrake/slate/slac/optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ def optimise(expression, parameters):

Returns: An optimised Slate expression
"""
# 1) Block optimisation
# 0) Block optimisation
expression = push_block(expression)

# 1) DiagonalTensor optimisation
expression = push_diag(expression)

# 2) Multiplication optimisation
if expression.rank < 2:
expression = push_mul(expression, parameters)
Expand Down Expand Up @@ -70,6 +73,8 @@ def _push_block_transpose(expr, self, indices):

@_push_block.register(Add)
@_push_block.register(Negative)
@_push_block.register(DiagonalTensor)
@_push_block.register(Reciprocal)
def _push_block_distributive(expr, self, indices):
"""Distributes Blocks for these nodes"""
return type(expr)(*map(self, expr.children, repeat(indices))) if indices else expr
Expand Down Expand Up @@ -111,6 +116,66 @@ def _push_block_block(expr, self, indices):
return block


def push_diag(expression):
"""Executes a Slate compiler optimisation pass.
The optimisation is achieved by pushing DiagonalTensor from the outside to the inside of an expression.

:arg expression: A (potentially unoptimised) Slate expression.

Returns: An optimised Slate expression, where DiagonalTensors are sitting
on terminal tensors whereever possible.
"""
mapper = MemoizerArg(_push_diag)
return mapper(expression, False)


@singledispatch
def _push_diag(expr, self, diag):
raise AssertionError("Cannot handle terminal type: %s" % type(expr))


@_push_diag.register(Transpose)
@_push_diag.register(Add)
@_push_diag.register(Negative)
def _push_diag_distributive(expr, self, diag):
"""Distributes the DiagonalTensors into these nodes"""
return type(expr)(*map(self, expr.children, repeat(diag)))


@_push_diag.register(Factorization)
@_push_diag.register(Inverse)
@_push_diag.register(Solve)
@_push_diag.register(Mul)
@_push_diag.register(Tensor)
def _push_diag_stop(expr, self, diag):
"""Diagonal Tensors cannot be pushed further into this set of nodes."""
expr = type(expr)(*map(self, expr.children, repeat(False))) if not expr.terminal else expr
return DiagonalTensor(expr) if diag else expr


@_push_diag.register(Block)
def _push_diag_block(expr, self, diag):
"""Diagonal Tensors cannot be pushed further into this set of nodes."""
expr = type(expr)(*map(self, expr.children, repeat(False)), expr._indices) if not expr.terminal else expr
return DiagonalTensor(expr) if diag else expr


@_push_diag.register(AssembledVector)
@_push_diag.register(Reciprocal)
def _push_diag_vectors(expr, self, diag):
"""DiagonalTensors should not be pushed onto rank-1 tensors."""
if diag:
raise AssertionError("It is not legal to define DiagonalTensors on rank-1 tensors.")
else:
return expr


@_push_diag.register(DiagonalTensor)
def _push_diag_diag(expr, self, diag):
"""DiagonalTensors are either pushed down or ignored when wrapped into another DiagonalTensor."""
return self(*expr.children, not diag)


def push_mul(tensor, options):
"""Executes a Slate compiler optimisation pass.
The optimisation is achieved by pushing coefficients from
Expand Down Expand Up @@ -179,6 +244,8 @@ def _drop_double_transpose_transpose(expr, self):
@_drop_double_transpose.register(Mul)
@_drop_double_transpose.register(Solve)
@_drop_double_transpose.register(Inverse)
@_drop_double_transpose.register(DiagonalTensor)
@_drop_double_transpose.register(Reciprocal)
def _drop_double_transpose_distributive(expr, self):
"""Distribute into the children of the expression. """
return type(expr)(*map(self, expr.children))
Expand All @@ -202,6 +269,8 @@ def _push_mul_tensor(expr, self, state):


@_push_mul.register(AssembledVector)
@_push_mul.register(DiagonalTensor)
@_push_mul.register(Reciprocal)
def _push_mul_vector(expr, self, state):
"""Do not push into AssembledVectors."""
return expr
Expand All @@ -220,8 +289,12 @@ def _push_mul_inverse(expr, self, state):
with a coefficient into a Solve via A.inv*b = A.solve(b)
or b*A^{-1}= (A.T.inv*b.T).T = A.T.solve(b.T).T ."""
child, = expr.children
return (Solve(child, state.coeff) if state.pick_op
else Transpose(Solve(Transpose(child), Transpose(state.coeff))))
if expr.diagonal:
# Don't optimise further so that the translation to gem at a later can just spill ]1/a_ii[
return expr * state.coeff if state.pick_op else state.coeff * expr
else:
return (Solve(child, state.coeff) if state.pick_op
else Transpose(Solve(Transpose(child), Transpose(state.coeff))))


@_push_mul.register(Transpose)
Expand Down
3 changes: 2 additions & 1 deletion firedrake/slate/slac/tsfc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def compile_terminal_form(tensor, prefix, *, tsfc_parameters=None, coffee=True):
kernels = tsfc_compile(form,
subkernel_prefix,
parameters=tsfc_parameters,
coffee=coffee, split=False)
coffee=coffee, split=False, diagonal=tensor.diagonal)

if kernels:
cxt_k = ContextKernel(tensor=tensor,
coefficients=form.coefficients(),
Expand Down
39 changes: 34 additions & 5 deletions firedrake/slate/slac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ufl.algorithms.multifunction import MultiFunction

from gem import (Literal, Sum, Product, Indexed, ComponentTensor, IndexSum,
Solve, Inverse, Variable, view)
Solve, Inverse, Variable, view, Delta, Index, Division)
from gem import indices as make_indices
from gem.node import Memoizer
from gem.node import pre_traversal as traverse_dags
Expand Down Expand Up @@ -148,15 +148,15 @@ def visit_Symbol(self, o, *args, **kwargs):
return SymbolWithFuncallIndexing(o.symbol, o.rank, o.offset)


def slate_to_gem(expression):
def slate_to_gem(expression, options):
"""Convert a slate expression to gem.

:arg expression: A slate expression.
:returns: A singleton list of gem expressions and a mapping from
gem variables to UFL "terminal" forms.
"""

mapper, var2terminal = slate2gem(expression)
mapper, var2terminal = slate2gem(expression, options)
return mapper, var2terminal


Expand Down Expand Up @@ -186,9 +186,37 @@ def _slate2gem_block(expr, self):
return view(child, *(slice(idx, idx+extent) for idx, extent in zip(offsets, expr.shape)))


@_slate2gem.register(sl.DiagonalTensor)
def _slate2gem_diagonal(expr, self):
if not self.matfree:
A, = map(self, expr.children)
assert A.shape[0] == A.shape[1]
i, j = (Index(extent=s) for s in A.shape)
return ComponentTensor(Product(Indexed(A, (i, i)), Delta(i, j)), (i, j))
else:
raise NotImplementedError("Diagonals on Slate expressions are \
not implemented in a matrix-free manner yet.")


@_slate2gem.register(sl.Inverse)
def _slate2gem_inverse(expr, self):
return Inverse(*map(self, expr.children))
tensor, = expr.children
if expr.diagonal:
# optimise inverse on diagonal tensor by translating to
# matrix which contains the reciprocal values of the diagonal tensor
A, = map(self, expr.children)
i, j = (Index(extent=s) for s in A.shape)
return ComponentTensor(Product(Division(Literal(1), Indexed(A, (i, i))),
Delta(i, j)), (i, j))
else:
return Inverse(self(tensor))


@_slate2gem.register(sl.Reciprocal)
def _slate2gem_reciprocal(expr, self):
child, = map(self, expr.children)
indices = tuple(make_indices(len(child.shape)))
return ComponentTensor(Division(Literal(1.), Indexed(child, indices)), indices)


@_slate2gem.register(sl.Solve)
Expand Down Expand Up @@ -237,9 +265,10 @@ def _slate2gem_factorization(expr, self):
return A


def slate2gem(expression):
def slate2gem(expression, options):
mapper = Memoizer(_slate2gem)
mapper.var2terminal = OrderedDict()
mapper.matfree = options["replace_mul"]
return mapper(expression), mapper.var2terminal


Expand Down
Loading