Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mathics/builtin/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ def get_lookup_name(self):
return self.get_name()

def get_sort_key(self):
return self.to_expression().get_sort_key()
return self.to_expression().cached_get_sort_key()

def get_string_value(self):
return "-@" + self.get_head_name() + "@-"
Expand Down Expand Up @@ -938,7 +938,7 @@ def get_match_count(self, vars={}):
return (1, 1)

def get_sort_key(self, pattern_sort=False):
return self.expr.get_sort_key(pattern_sort=pattern_sort)
return self.expr.cached_get_sort_key(pattern_sort=pattern_sort)


class NegativeIntegerException(Exception):
Expand Down
2 changes: 1 addition & 1 deletion mathics/builtin/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def apply(self, expr, positions, evaluation):
if positions.elements[0].has_form("List", None)
else [positions]
)
positions.sort(key=lambda e: e.get_sort_key(pattern_sort=True))
positions.sort(key=lambda e: e.cached_get_sort_key(pattern_sort=True))
newexpr = expr
for position in positions:
pos = [p.get_int_value() for p in position.get_elements()]
Expand Down
6 changes: 3 additions & 3 deletions mathics/builtin/numbers/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def apply(self, m, evaluation, options={}) -> Expression:
eigenvalues = [(from_sympy(v), c) for (v, c) in eigenvalues]

# Sort the eigenvalues by their sort key
eigenvalues.sort(key=lambda v: v[0].get_sort_key())
eigenvalues.sort(key=lambda v: v[0].cached_get_sort_key())

eigenvalues = [v for (v, c) in eigenvalues for _ in range(c)]

Expand Down Expand Up @@ -307,9 +307,9 @@ def apply(self, m, evaluation):
key=lambda v: (abs(v[0]), -re(v[0]), -im(v[0])), reverse=True
)
except TypeError:
eigenvects.sort(key=lambda v: from_sympy(v[0]).get_sort_key())
eigenvects.sort(key=lambda v: from_sympy(v[0]).cached_get_sort_key())
else:
eigenvects.sort(key=lambda v: from_sympy(v[0]).get_sort_key())
eigenvects.sort(key=lambda v: from_sympy(v[0]).cached_get_sort_key())

result = []
for val, count, basis in eigenvects:
Expand Down
2 changes: 1 addition & 1 deletion mathics/builtin/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -1686,7 +1686,7 @@ def __init__(self, rulelist, evaluation):
self._head = SymbolDispatch

def get_sort_key(self):
return self.src.get_sort_key()
return self.src.cached_get_sort_key()

def get_atom_name(self):
return "System`Dispatch"
Expand Down
2 changes: 1 addition & 1 deletion mathics/builtin/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ class PatternsOrderedQ(Builtin):
def apply(self, p1, p2, evaluation):
"PatternsOrderedQ[p1_, p2_]"

if p1.get_sort_key(True) <= p2.get_sort_key(True):
if p1.cached_get_sort_key(True) <= p2.cached_get_sort_key(True):
return SymbolTrue
else:
return SymbolFalse
Expand Down
12 changes: 9 additions & 3 deletions mathics/core/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def __eq__(self, other) -> bool:
self.to_mpmath(), other.to_mpmath(), abs_eps=0, rel_eps=rel_eps
)
else:
return self.get_sort_key() == other.get_sort_key()
return self.cached_get_sort_key() == other.cached_get_sort_key()

def __ne__(self, other) -> bool:
# Real is a total order
Expand Down Expand Up @@ -683,7 +683,13 @@ def get_sort_key(self, pattern_sort=False):
if pattern_sort:
return super().get_sort_key(True)
else:
return [0, 0, self.real.get_sort_key()[2], self.imag.get_sort_key()[2], 1]
return [
0,
0,
self.real.cached_get_sort_key()[2],
self.imag.cached_get_sort_key()[2],
1,
]

def sameQ(self, other) -> bool:
"""Mathics SameQ"""
Expand Down Expand Up @@ -740,7 +746,7 @@ def __eq__(self, other) -> bool:
if isinstance(other, Complex):
return self.real == other.real and self.imag == other.imag
else:
return self.get_sort_key() == other.get_sort_key()
return self.cached_get_sort_key() == other.cached_get_sort_key()

def __getnewargs__(self):
return (self.real, self.imag)
Expand Down
27 changes: 20 additions & 7 deletions mathics/core/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,28 +105,41 @@ class KeyComparable:
def get_sort_key(self):
raise NotImplementedError

def cached_get_sort_key(self, pattern_sort=False):
if pattern_sort:
if hasattr(self, "pattern_sort_key"):
return self.pattern_sort_key
sort_key = self.get_sort_key(True)
self.pattern_sort_key = sort_key
else:
if hasattr(self, "sort_key"):
return self.sort_key
sort_key = self.get_sort_key()
self.sort_key = sort_key
return sort_key

def __eq__(self, other) -> bool:
return (
hasattr(other, "get_sort_key")
and self.get_sort_key() == other.get_sort_key()
hasattr(other, "cached_get_sort_key")
and self.cached_get_sort_key() == other.cached_get_sort_key()
)

def __gt__(self, other) -> bool:
return self.get_sort_key() > other.get_sort_key()
return self.cached_get_sort_key() > other.cached_get_sort_key()

def __ge__(self, other) -> bool:
return self.get_sort_key() >= other.get_sort_key()
return self.cached_get_sort_key() >= other.cached_get_sort_key()

def __le__(self, other) -> bool:
return self.get_sort_key() <= other.get_sort_key()
return self.cached_get_sort_key() <= other.cached_get_sort_key()

def __lt__(self, other) -> bool:
return self.get_sort_key() < other.get_sort_key()
return self.cached_get_sort_key() < other.cached_get_sort_key()

def __ne__(self, other) -> bool:
return (
not hasattr(other, "get_sort_key")
) or self.get_sort_key() != other.get_sort_key()
) or self.cached_get_sort_key() != other.cached_get_sort_key()


class BaseElement(KeyComparable):
Expand Down
35 changes: 21 additions & 14 deletions mathics/core/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,6 @@ def get_sort_key(self, pattern_sort=False):
6: leaves / 0 for atoms
7: 0/1: 0 for Condition
"""

head = self._head
pattern = 0
if head is SymbolBlank:
Expand All @@ -744,40 +743,42 @@ def get_sort_key(self, pattern_sort=False):
1,
1,
0,
head.get_sort_key(True),
tuple(element.get_sort_key(True) for element in self._elements),
head.cached_get_sort_key(True),
tuple(
element.cached_get_sort_key(True) for element in self._elements
),
1,
]

if head is SymbolPatternTest:
if len(self._elements) != 2:
return [3, 0, 0, 0, 0, head, self._elements, 1]
sub = self._elements[0].get_sort_key(True)
sub = self._elements[0].cached_get_sort_key(True)
sub[2] = 0
return sub
elif head is SymbolCondition:
if len(self._elements) != 2:
return [3, 0, 0, 0, 0, head, self._elements, 1]
sub = self._elements[0].get_sort_key(True)
sub = self._elements[0].cached_get_sort_key(True)
sub[7] = 0
return sub
elif head is SymbolPattern:
if len(self._elements) != 2:
return [3, 0, 0, 0, 0, head, self._elements, 1]
sub = self._elements[1].get_sort_key(True)
sub = self._elements[1].cached_get_sort_key(True)
sub[3] = 0
return sub
elif head is SymbolOptional:
if len(self._elements) not in (1, 2):
return [3, 0, 0, 0, 0, head, self._elements, 1]
sub = self._elements[0].get_sort_key(True)
sub = self._elements[0].cached_get_sort_key(True)
sub[4] = 1
return sub
elif head is SymbolAlternatives:
min_key = [4]
min = None
for element in self._elements:
key = element.get_sort_key(True)
key = element.cached_get_sort_key(True)
if key < min_key:
min = element
min_key = key
Expand All @@ -788,7 +789,7 @@ def get_sort_key(self, pattern_sort=False):
elif head is SymbolVerbatim:
if len(self._elements) != 1:
return [3, 0, 0, 0, 0, head, self._elements, 1]
return self._elements[0].get_sort_key(True)
return self._elements[0].cached_get_sort_key(True)
elif head is SymbolOptionsPattern:
return [2, 40, 0, 1, 1, 0, head, self._elements, 1]
else:
Expand All @@ -800,10 +801,13 @@ def get_sort_key(self, pattern_sort=False):
1,
1,
0,
head.get_sort_key(True),
head.cached_get_sort_key(True),
tuple(
chain(
(element.get_sort_key(True) for element in self._elements),
(
element.cached_get_sort_key(True)
for element in self._elements
),
([4],),
)
),
Expand Down Expand Up @@ -1305,7 +1309,7 @@ def set_element(self, index: int, value):
def shallow_copy(self) -> "Expression":
# this is a minimal, shallow copy: head, elements are shared with
# the original, only the Expression instance is new.

return self
expr = Expression(
self._head, *self._elements, elements_properties=self.elements_properties
)
Expand Down Expand Up @@ -1438,13 +1442,16 @@ def sort(self, pattern=False):
# list sort method. Another approach would be to use sorted().
elements = self.get_mutable_elements()
if pattern:
elements.sort(key=lambda e: e.get_sort_key(pattern_sort=True))
elements.sort(key=lambda e: e.cached_get_sort_key(pattern_sort=True))
else:
elements.sort()

# update `self._elements` and self._cache with the possible permuted order.
self.elements = elements
self._build_elements_properties()
if self.elements_properties is None:
self._build_elements_properties()
else:
self.elements_properties.is_ordered = True

if self._cache:
self._cache = self._cache.reordered()
Expand Down
2 changes: 1 addition & 1 deletion mathics/core/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def get_elements(self):
get_leaves = get_elements

def get_sort_key(self, pattern_sort=False):
return self.expr.get_sort_key(pattern_sort=pattern_sort)
return self.expr.cached_get_sort_key(pattern_sort=pattern_sort)

def get_lookup_name(self):
return self.expr.get_lookup_name()
Expand Down