Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions drudge/drs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,24 @@ class DrsSymbol(_Definable, Symbol):

def __new__(cls, drudge, name):
"""Create a symbol object."""
symb = super().__new__(cls, name)
# Handle the case where drudge is None during unpickling
if drudge is None:
# During unpickling, we just need the name for __new__
symb = super().__new__(cls, name)
else:
symb = super().__new__(cls, name)
return symb

def __init__(self, drudge, name):
"""Initialize the symbol object."""
self._drudge = drudge
self._orig = Symbol(name)
# During unpickling, drudge might be None, __setstate__ will fix it
if drudge is not None:
self._drudge = drudge
self._orig = Symbol(name)
else:
# This will be set properly in __setstate__
self._drudge = None
self._orig = Symbol(name)

def __eq__(self, other):
"""Make equality comparison."""
Expand Down Expand Up @@ -144,7 +155,10 @@ def __setstate__(self, state):
from .drudge import current_drudge
if current_drudge is None:
raise ValueError(_PICKLE_ENV_ERR)
self.__init__(current_drudge, self.name)
self._drudge = current_drudge
# _orig should already be set from __init__, but make sure
if not hasattr(self, '_orig'):
self._orig = Symbol(self.name)

# Better error reporting.
def __getattr__(self, item):
Expand Down
39 changes: 37 additions & 2 deletions drudge/drudge.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,41 @@
_DECR_SUFFIX = '_InternalProxy'


def _has_symbolic_quantum_objects(expr):
"""Check if expression contains symbolic quantum objects that cannot be simplified with doit().

Returns True if the expression contains quantum physics objects like CG, Wigner3j, Wigner6j
that have is_symbolic=True attribute.
"""
try:
# Import quantum physics classes to check against
from sympy.physics.quantum.cg import CG, Wigner3j, Wigner6j

# Check if this is a quantum physics object with is_symbolic=True
if isinstance(expr, (CG, Wigner3j, Wigner6j)) and hasattr(expr, 'is_symbolic') and expr.is_symbolic:
return True

if hasattr(expr, 'args'):
return any(_has_symbolic_quantum_objects(arg) for arg in expr.args)
except (ImportError, AttributeError):
# If there are any import issues, be conservative and return False
pass

return False


def _safe_simplify(expr):
"""Safely simplify expression avoiding doit() on symbolic quantum objects.

If the expression contains symbolic quantum objects, return it unchanged.
Otherwise, apply normal SymPy simplification.
"""
if _has_symbolic_quantum_objects(expr):
return expr

return expr.simplify()


class Tensor:
"""The main tensor class.

Expand Down Expand Up @@ -453,7 +488,7 @@ def _simplify_amps(terms):
"""Get the terms with amplitude simplified by SymPy."""

simplified_terms = terms.map(
lambda term: term.map(lambda x: x.simplify(), skip_vecs=True)
lambda term: term.map(lambda x: _safe_simplify(x), skip_vecs=True)
).filter(_is_nonzero)

return simplified_terms
Expand Down Expand Up @@ -3624,5 +3659,5 @@ def _simplify_symbolic_sum(expr, **_):
assert len(expr.args) == 2

return eval_sum_symbolic(
expr.args[0].simplify(), expr.args[1]
_safe_simplify(expr.args[0]), expr.args[1]
)
10 changes: 7 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@ def spark_ctx():
"""A simple spark context."""

if IF_DUMMY_SPARK:
from dummy_spark import SparkConf, SparkContext
conf = SparkConf()
ctx = SparkContext(master='', conf=conf)
try:
from dummy_spark import SparkConf, SparkContext
conf = SparkConf()
ctx = SparkContext(master='', conf=conf)
except ImportError:
# Fallback to None if dummy_spark is not available
return None
else:
from pyspark import SparkConf, SparkContext
conf = SparkConf().setMaster('local[2]').setAppName('drudge-unittest')
Expand Down
2 changes: 1 addition & 1 deletion tests/nuclear_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def test_wigner3j_sum_to_wigner6j(nuclear: NuclearBogoliubovDrudge):
((-1) ** (j3 - m3) / (2 * j3 + 1))
* KroneckerDelta(j3, jprm3) * KroneckerDelta(m3, mprm3)
* Wigner6j(j1, j2, j3, j4, j5, j6)
).expand().simplify()
).expand()

# For performance reason, just test a random arrangement of the summations.
random.shuffle(sums)
Expand Down
Loading