gem/optimise.py: 33 changes (23 additions & 10 deletions)
@@ -308,21 +308,25 @@ def select_expression(expressions, index):
     return ComponentTensor(selected, alpha)


-def delta_elimination(sum_indices, factors):
+def delta_elimination(sum_indices, factors, index_replacer=None):
     """IndexSum-Delta cancellation.

     :arg sum_indices: free indices for contractions
     :arg factors: product factors
+    :kwarg index_replacer: MemoizerArg(filtered_replace_indices)

     :returns: optimised (sum_indices, factors)
     """
+    if index_replacer is None:
+        index_replacer = MemoizerArg(filtered_replace_indices)
+
     sum_indices = list(sum_indices)  # copy for modification

     def substitute(expression, from_, to_):
         if from_ not in expression.free_indices:
             return expression
         elif isinstance(expression, Delta):
-            mapper = MemoizerArg(filtered_replace_indices)
-            return mapper(expression, ((from_, to_),))
+            return index_replacer(expression, ((from_, to_),))
         else:
             return Indexed(ComponentTensor(expression, (from_,)), (to_,))

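The hunk above lets callers hand `delta_elimination` a prebuilt memoizer, so the `Delta` branch of `substitute` reuses one shared cache instead of constructing a fresh `MemoizerArg(filtered_replace_indices)` (with an empty cache) on every call. A rough sketch of that caching pattern, using hypothetical `CachedReplacer`/`replace_index` stand-ins rather than gem's actual classes:

```python
# Minimal sketch of the argument-keyed memoization threaded through
# delta_elimination; CachedReplacer and replace_index are illustrative
# stand-ins, not gem's MemoizerArg / filtered_replace_indices.
class CachedReplacer:
    """Memoize function(expr, subst) on the (expr, subst) pair."""

    def __init__(self, function):
        self.function = function
        self.cache = {}

    def __call__(self, expr, subst):
        key = (expr, subst)  # assumes hashable expressions, as gem nodes are
        if key not in self.cache:
            self.cache[key] = self.function(expr, subst)
        return self.cache[key]


def replace_index(expr, subst):
    """Toy index substitution on string 'expressions'."""
    for old, new in subst:
        expr = expr.replace(old, new)
    return expr


index_replacer = CachedReplacer(replace_index)
print(index_replacer("Delta(i, j) * A[i]", (("i", "k"),)))  # computed once
print(index_replacer("Delta(i, j) * A[i]", (("i", "k"),)))  # served from the cache
```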
@@ -490,9 +494,9 @@ def applier(expr):
     return partial(_renamer, rename_map, set())


-def traverse_product(expression, stop_at=None, rename_map=None):
+def traverse_product(expression, stop_at=None, rename_map=None, index_replacer=None):
     """Traverses a product tree and collects factors, also descending into
-    tensor contractions (IndexSum). The nominators of divisions are
+    tensor contractions (IndexSum). The numerators of divisions are
     also broken up, but not the denominators.

     :arg expression: a GEM expression
@@ -501,13 +505,17 @@ def traverse_product(expression, stop_at=None, rename_map=None):
                   subexpression is not broken into further factors
                   even if it is a product-like expression.
     :arg rename_map: an rename map for consistent index renaming
+    :kwarg index_replacer: MemoizerArg(filtered_replace_indices)

     :returns: (sum_indices, terms)
               - sum_indices: list of indices to sum over
               - terms: list of product terms
     """
     if rename_map is None:
         rename_map = make_rename_map()
     renamer = make_renamer(rename_map)
+    if index_replacer is None:
+        index_replacer = MemoizerArg(filtered_replace_indices)

     sum_indices = []
     terms = []
@@ -520,7 +528,7 @@
         elif isinstance(expr, IndexSum):
             indices, applier = renamer(expr.multiindex)
             sum_indices.extend(indices)
-            stack.extend(remove_componenttensors(map(applier, expr.children)))
+            stack.extend(index_replacer(applier(c), ()) for c in expr.children)
         elif isinstance(expr, Product):
             stack.extend(reversed(expr.children))
         elif isinstance(expr, Division):
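The docstring of `traverse_product` above describes flattening a product tree while descending into `IndexSum` contractions, collecting the summation indices and the leaf factors. A toy sketch of that traversal, with made-up `Prod`/`Sum` stand-ins (not gem's `Product`/`IndexSum` node classes):

```python
# Toy illustration of the flattening traverse_product's docstring describes:
# walk a nested product with an explicit stack, descend into index sums,
# collect summation indices and leaf factors.
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Prod:
    children: Tuple["Expr", ...]


@dataclass(frozen=True)
class Sum:  # stands in for an IndexSum: sum over `indices` of `child`
    indices: Tuple[str, ...]
    child: "Expr"


Expr = Union[str, Prod, Sum]


def flatten_product(expression):
    sum_indices, terms = [], []
    stack = [expression]
    while stack:
        expr = stack.pop()
        if isinstance(expr, Sum):
            sum_indices.extend(expr.indices)
            stack.append(expr.child)
        elif isinstance(expr, Prod):
            stack.extend(reversed(expr.children))
        else:
            terms.append(expr)  # leaf factor
    return sum_indices, terms


# (sum_i a * b_i) * c  ->  (['i'], ['a', 'b_i', 'c'])
print(flatten_product(Prod((Sum(("i",), Prod(("a", "b_i"))), "c"))))
```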
@@ -575,13 +583,18 @@ def contraction(expression, ignore=None):
     This routine was designed with finite element coefficient
     evaluation in mind.
     """
+
+    # Common memoizer to remove ComponentTensors
+    index_replacer = MemoizerArg(filtered_replace_indices)
+
     # Eliminate annoying ComponentTensors
-    expression, = remove_componenttensors([expression])
+    expression = index_replacer(expression, ())

     # Flatten product tree, eliminate deltas, sum factorise
     def rebuild(expression):
-        sum_indices, factors = delta_elimination(*traverse_product(expression))
-        factors = remove_componenttensors(factors)
+        sum_indices, factors = traverse_product(expression, index_replacer=index_replacer)
+        sum_indices, factors = delta_elimination(sum_indices, factors, index_replacer=index_replacer)
+        factors = [index_replacer(f, ()) for f in factors]
         if ignore is not None:
             # TODO: This is a really blunt instrument and one might
             # plausibly want the ignored indices to be contracted on
@@ -610,7 +623,7 @@ def rebuild(expression):
     # Rebuild each split component
     tensor = ComponentTensor(expression, lt_fis)
     entries = [Indexed(tensor, zeta) for zeta in numpy.ndindex(tensor.shape)]
-    entries = remove_componenttensors(entries)
+    entries = [index_replacer(e, ()) for e in entries]
     return Indexed(ListTensor(
         numpy.array(list(map(rebuild, entries))).reshape(tensor.shape)
     ), lt_fis)
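The final hunk keeps the same split-and-rebuild shape handling but routes the ComponentTensor removal through the shared memoizer: the shaped expression is split into scalar entries, each entry is rebuilt independently, and the results are reassembled into a tensor of the original shape. A toy sketch of that split and reassembly, with placeholder strings and `str.upper` standing in for gem nodes and `rebuild()`:

```python
# Toy sketch of the split-and-rebuild step: enumerate every entry of a
# shaped expression with numpy.ndindex, "optimise" each entry independently,
# then reshape the results back to the original shape.
import numpy

shape = (2, 3)
entries = [f"entry{zeta}" for zeta in numpy.ndindex(shape)]  # like Indexed(tensor, zeta)
rebuilt = numpy.array([e.upper() for e in entries], dtype=object).reshape(shape)
print(rebuilt)
```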