From f8d2fe762c4cf0fd3c850aeb8216552fe4f995c4 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 18 Feb 2026 12:20:42 +0000 Subject: [PATCH] Optimise index replacement --- gem/optimise.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/gem/optimise.py b/gem/optimise.py index 9809cc2a..ccf0e2c1 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -308,21 +308,25 @@ def select_expression(expressions, index): return ComponentTensor(selected, alpha) -def delta_elimination(sum_indices, factors): +def delta_elimination(sum_indices, factors, index_replacer=None): """IndexSum-Delta cancellation. :arg sum_indices: free indices for contractions :arg factors: product factors + :kwarg index_replacer: MemoizerArg(filtered_replace_indices) + :returns: optimised (sum_indices, factors) """ + if index_replacer is None: + index_replacer = MemoizerArg(filtered_replace_indices) + sum_indices = list(sum_indices) # copy for modification def substitute(expression, from_, to_): if from_ not in expression.free_indices: return expression elif isinstance(expression, Delta): - mapper = MemoizerArg(filtered_replace_indices) - return mapper(expression, ((from_, to_),)) + return index_replacer(expression, ((from_, to_),)) else: return Indexed(ComponentTensor(expression, (from_,)), (to_,)) @@ -490,9 +494,9 @@ def applier(expr): return partial(_renamer, rename_map, set()) -def traverse_product(expression, stop_at=None, rename_map=None): +def traverse_product(expression, stop_at=None, rename_map=None, index_replacer=None): """Traverses a product tree and collects factors, also descending into - tensor contractions (IndexSum). The nominators of divisions are + tensor contractions (IndexSum). The numerators of divisions are also broken up, but not the denominators. :arg expression: a GEM expression @@ -501,6 +505,8 @@ def traverse_product(expression, stop_at=None, rename_map=None): subexpression is not broken into further factors even if it is a product-like expression. :arg rename_map: an rename map for consistent index renaming + :kwarg index_replacer: MemoizerArg(filtered_replace_indices) + :returns: (sum_indices, terms) - sum_indices: list of indices to sum over - terms: list of product terms @@ -508,6 +514,8 @@ def traverse_product(expression, stop_at=None, rename_map=None): if rename_map is None: rename_map = make_rename_map() renamer = make_renamer(rename_map) + if index_replacer is None: + index_replacer = MemoizerArg(filtered_replace_indices) sum_indices = [] terms = [] @@ -520,7 +528,7 @@ def traverse_product(expression, stop_at=None, rename_map=None): elif isinstance(expr, IndexSum): indices, applier = renamer(expr.multiindex) sum_indices.extend(indices) - stack.extend(remove_componenttensors(map(applier, expr.children))) + stack.extend(index_replacer(applier(c), ()) for c in expr.children) elif isinstance(expr, Product): stack.extend(reversed(expr.children)) elif isinstance(expr, Division): @@ -575,13 +583,18 @@ def contraction(expression, ignore=None): This routine was designed with finite element coefficient evaluation in mind. """ + + # Common memoizer to remove ComponentTensors + index_replacer = MemoizerArg(filtered_replace_indices) + # Eliminate annoying ComponentTensors - expression, = remove_componenttensors([expression]) + expression = index_replacer(expression, ()) # Flatten product tree, eliminate deltas, sum factorise def rebuild(expression): - sum_indices, factors = delta_elimination(*traverse_product(expression)) - factors = remove_componenttensors(factors) + sum_indices, factors = traverse_product(expression, index_replacer=index_replacer) + sum_indices, factors = delta_elimination(sum_indices, factors, index_replacer=index_replacer) + factors = [index_replacer(f, ()) for f in factors] if ignore is not None: # TODO: This is a really blunt instrument and one might # plausibly want the ignored indices to be contracted on @@ -610,7 +623,7 @@ def rebuild(expression): # Rebuild each split component tensor = ComponentTensor(expression, lt_fis) entries = [Indexed(tensor, zeta) for zeta in numpy.ndindex(tensor.shape)] - entries = remove_componenttensors(entries) + entries = [index_replacer(e, ()) for e in entries] return Indexed(ListTensor( numpy.array(list(map(rebuild, entries))).reshape(tensor.shape) ), lt_fis)