diff --git a/mathics/builtin/base.py b/mathics/builtin/base.py index 51cbd5af9..6afad2618 100644 --- a/mathics/builtin/base.py +++ b/mathics/builtin/base.py @@ -812,7 +812,7 @@ def get_lookup_name(self): return self.get_name() def get_sort_key(self): - return self.to_expression().get_sort_key() + return self.to_expression().cached_get_sort_key() def get_string_value(self): return "-@" + self.get_head_name() + "@-" @@ -938,7 +938,7 @@ def get_match_count(self, vars={}): return (1, 1) def get_sort_key(self, pattern_sort=False): - return self.expr.get_sort_key(pattern_sort=pattern_sort) + return self.expr.cached_get_sort_key(pattern_sort=pattern_sort) class NegativeIntegerException(Exception): diff --git a/mathics/builtin/lists.py b/mathics/builtin/lists.py index 3d9345911..902dbe227 100644 --- a/mathics/builtin/lists.py +++ b/mathics/builtin/lists.py @@ -382,7 +382,7 @@ def apply(self, expr, positions, evaluation): if positions.elements[0].has_form("List", None) else [positions] ) - positions.sort(key=lambda e: e.get_sort_key(pattern_sort=True)) + positions.sort(key=lambda e: e.cached_get_sort_key(pattern_sort=True)) newexpr = expr for position in positions: pos = [p.get_int_value() for p in position.get_elements()] diff --git a/mathics/builtin/numbers/linalg.py b/mathics/builtin/numbers/linalg.py index c30a714e4..c9753ff81 100644 --- a/mathics/builtin/numbers/linalg.py +++ b/mathics/builtin/numbers/linalg.py @@ -247,7 +247,7 @@ def apply(self, m, evaluation, options={}) -> Expression: eigenvalues = [(from_sympy(v), c) for (v, c) in eigenvalues] # Sort the eigenvalues by their sort key - eigenvalues.sort(key=lambda v: v[0].get_sort_key()) + eigenvalues.sort(key=lambda v: v[0].cached_get_sort_key()) eigenvalues = [v for (v, c) in eigenvalues for _ in range(c)] @@ -307,9 +307,9 @@ def apply(self, m, evaluation): key=lambda v: (abs(v[0]), -re(v[0]), -im(v[0])), reverse=True ) except TypeError: - eigenvects.sort(key=lambda v: from_sympy(v[0]).get_sort_key()) + eigenvects.sort(key=lambda v: from_sympy(v[0]).cached_get_sort_key()) else: - eigenvects.sort(key=lambda v: from_sympy(v[0]).get_sort_key()) + eigenvects.sort(key=lambda v: from_sympy(v[0]).cached_get_sort_key()) result = [] for val, count, basis in eigenvects: diff --git a/mathics/builtin/patterns.py b/mathics/builtin/patterns.py index e85905ea5..5eb631cdc 100644 --- a/mathics/builtin/patterns.py +++ b/mathics/builtin/patterns.py @@ -1686,7 +1686,7 @@ def __init__(self, rulelist, evaluation): self._head = SymbolDispatch def get_sort_key(self): - return self.src.get_sort_key() + return self.src.cached_get_sort_key() def get_atom_name(self): return "System`Dispatch" diff --git a/mathics/builtin/structure.py b/mathics/builtin/structure.py index 77f11cb11..9e53861f0 100644 --- a/mathics/builtin/structure.py +++ b/mathics/builtin/structure.py @@ -288,7 +288,7 @@ class PatternsOrderedQ(Builtin): def apply(self, p1, p2, evaluation): "PatternsOrderedQ[p1_, p2_]" - if p1.get_sort_key(True) <= p2.get_sort_key(True): + if p1.cached_get_sort_key(True) <= p2.cached_get_sort_key(True): return SymbolTrue else: return SymbolFalse diff --git a/mathics/core/atoms.py b/mathics/core/atoms.py index b17825f99..c75fd4667 100644 --- a/mathics/core/atoms.py +++ b/mathics/core/atoms.py @@ -406,7 +406,7 @@ def __eq__(self, other) -> bool: self.to_mpmath(), other.to_mpmath(), abs_eps=0, rel_eps=rel_eps ) else: - return self.get_sort_key() == other.get_sort_key() + return self.cached_get_sort_key() == other.cached_get_sort_key() def __ne__(self, other) -> bool: # Real is a total order @@ -683,7 +683,13 @@ def get_sort_key(self, pattern_sort=False): if pattern_sort: return super().get_sort_key(True) else: - return [0, 0, self.real.get_sort_key()[2], self.imag.get_sort_key()[2], 1] + return [ + 0, + 0, + self.real.cached_get_sort_key()[2], + self.imag.cached_get_sort_key()[2], + 1, + ] def sameQ(self, other) -> bool: """Mathics SameQ""" @@ -740,7 +746,7 @@ def __eq__(self, other) -> bool: if isinstance(other, Complex): return self.real == other.real and self.imag == other.imag else: - return self.get_sort_key() == other.get_sort_key() + return self.cached_get_sort_key() == other.cached_get_sort_key() def __getnewargs__(self): return (self.real, self.imag) diff --git a/mathics/core/element.py b/mathics/core/element.py index 44275e91c..efc294347 100644 --- a/mathics/core/element.py +++ b/mathics/core/element.py @@ -105,28 +105,41 @@ class KeyComparable: def get_sort_key(self): raise NotImplementedError + def cached_get_sort_key(self, pattern_sort=False): + if pattern_sort: + if hasattr(self, "pattern_sort_key"): + return self.pattern_sort_key + sort_key = self.get_sort_key(True) + self.pattern_sort_key = sort_key + else: + if hasattr(self, "sort_key"): + return self.sort_key + sort_key = self.get_sort_key() + self.sort_key = sort_key + return sort_key + def __eq__(self, other) -> bool: return ( - hasattr(other, "get_sort_key") - and self.get_sort_key() == other.get_sort_key() + hasattr(other, "cached_get_sort_key") + and self.cached_get_sort_key() == other.cached_get_sort_key() ) def __gt__(self, other) -> bool: - return self.get_sort_key() > other.get_sort_key() + return self.cached_get_sort_key() > other.cached_get_sort_key() def __ge__(self, other) -> bool: - return self.get_sort_key() >= other.get_sort_key() + return self.cached_get_sort_key() >= other.cached_get_sort_key() def __le__(self, other) -> bool: - return self.get_sort_key() <= other.get_sort_key() + return self.cached_get_sort_key() <= other.cached_get_sort_key() def __lt__(self, other) -> bool: - return self.get_sort_key() < other.get_sort_key() + return self.cached_get_sort_key() < other.cached_get_sort_key() def __ne__(self, other) -> bool: return ( not hasattr(other, "get_sort_key") - ) or self.get_sort_key() != other.get_sort_key() + ) or self.cached_get_sort_key() != other.cached_get_sort_key() class BaseElement(KeyComparable): diff --git a/mathics/core/expression.py b/mathics/core/expression.py index a7dc13fc9..1becd8fde 100644 --- a/mathics/core/expression.py +++ b/mathics/core/expression.py @@ -723,7 +723,6 @@ def get_sort_key(self, pattern_sort=False): 6: leaves / 0 for atoms 7: 0/1: 0 for Condition """ - head = self._head pattern = 0 if head is SymbolBlank: @@ -744,40 +743,42 @@ def get_sort_key(self, pattern_sort=False): 1, 1, 0, - head.get_sort_key(True), - tuple(element.get_sort_key(True) for element in self._elements), + head.cached_get_sort_key(True), + tuple( + element.cached_get_sort_key(True) for element in self._elements + ), 1, ] if head is SymbolPatternTest: if len(self._elements) != 2: return [3, 0, 0, 0, 0, head, self._elements, 1] - sub = self._elements[0].get_sort_key(True) + sub = self._elements[0].cached_get_sort_key(True) sub[2] = 0 return sub elif head is SymbolCondition: if len(self._elements) != 2: return [3, 0, 0, 0, 0, head, self._elements, 1] - sub = self._elements[0].get_sort_key(True) + sub = self._elements[0].cached_get_sort_key(True) sub[7] = 0 return sub elif head is SymbolPattern: if len(self._elements) != 2: return [3, 0, 0, 0, 0, head, self._elements, 1] - sub = self._elements[1].get_sort_key(True) + sub = self._elements[1].cached_get_sort_key(True) sub[3] = 0 return sub elif head is SymbolOptional: if len(self._elements) not in (1, 2): return [3, 0, 0, 0, 0, head, self._elements, 1] - sub = self._elements[0].get_sort_key(True) + sub = self._elements[0].cached_get_sort_key(True) sub[4] = 1 return sub elif head is SymbolAlternatives: min_key = [4] min = None for element in self._elements: - key = element.get_sort_key(True) + key = element.cached_get_sort_key(True) if key < min_key: min = element min_key = key @@ -788,7 +789,7 @@ def get_sort_key(self, pattern_sort=False): elif head is SymbolVerbatim: if len(self._elements) != 1: return [3, 0, 0, 0, 0, head, self._elements, 1] - return self._elements[0].get_sort_key(True) + return self._elements[0].cached_get_sort_key(True) elif head is SymbolOptionsPattern: return [2, 40, 0, 1, 1, 0, head, self._elements, 1] else: @@ -800,10 +801,13 @@ def get_sort_key(self, pattern_sort=False): 1, 1, 0, - head.get_sort_key(True), + head.cached_get_sort_key(True), tuple( chain( - (element.get_sort_key(True) for element in self._elements), + ( + element.cached_get_sort_key(True) + for element in self._elements + ), ([4],), ) ), @@ -1305,7 +1309,7 @@ def set_element(self, index: int, value): def shallow_copy(self) -> "Expression": # this is a minimal, shallow copy: head, elements are shared with # the original, only the Expression instance is new. - + return self expr = Expression( self._head, *self._elements, elements_properties=self.elements_properties ) @@ -1438,13 +1442,16 @@ def sort(self, pattern=False): # list sort method. Another approach would be to use sorted(). elements = self.get_mutable_elements() if pattern: - elements.sort(key=lambda e: e.get_sort_key(pattern_sort=True)) + elements.sort(key=lambda e: e.cached_get_sort_key(pattern_sort=True)) else: elements.sort() # update `self._elements` and self._cache with the possible permuted order. self.elements = elements - self._build_elements_properties() + if self.elements_properties is None: + self._build_elements_properties() + else: + self.elements_properties.is_ordered = True if self._cache: self._cache = self._cache.reordered() diff --git a/mathics/core/pattern.py b/mathics/core/pattern.py index 246f6b0f7..b54c4818e 100644 --- a/mathics/core/pattern.py +++ b/mathics/core/pattern.py @@ -134,7 +134,7 @@ def get_elements(self): get_leaves = get_elements def get_sort_key(self, pattern_sort=False): - return self.expr.get_sort_key(pattern_sort=pattern_sort) + return self.expr.cached_get_sort_key(pattern_sort=pattern_sort) def get_lookup_name(self): return self.expr.get_lookup_name()