diff --git a/ramanujantools/cmf/meijer_g.py b/ramanujantools/cmf/meijer_g.py index 485ca93..e009c3b 100644 --- a/ramanujantools/cmf/meijer_g.py +++ b/ramanujantools/cmf/meijer_g.py @@ -29,6 +29,30 @@ def __init__(self, m: int, n: int, p: int, q: int, z: sp.Expr = z): def __repr__(self) -> str: return f"MeijerG({self.m, self.n, self.p, self.q, self.z})" + def __getstate__(self): + state = super().__getstate__().copy() + state.update({ + 'm': self.m, + 'n': self.n, + 'p': self.p, + 'q': self.q, + 'z': self.z + }) + return state + + def __setstate__(self, state): + self.m = state['m'] + self.n = state['n'] + self.p = state['p'] + self.q = state['q'] + self.z = state['z'] + del state['m'] + del state['n'] + del state['p'] + del state['q'] + del state['z'] + super().__setstate__(state) + @staticmethod def a_axes(p) -> list[sp.Symbol]: return sp.symbols(f"a:{p}") diff --git a/ramanujantools/cmf/meijer_g_test.py b/ramanujantools/cmf/meijer_g_test.py index 1395248..072a52d 100644 --- a/ramanujantools/cmf/meijer_g_test.py +++ b/ramanujantools/cmf/meijer_g_test.py @@ -1,4 +1,5 @@ from pytest import approx +import pickle import sympy as sp from sympy.abc import n, z @@ -30,6 +31,24 @@ def test_gamma(): assert limit.as_float() == approx(limit.mp.euler) +def test_serialization(): + original_meijer_g = MeijerG(p=3, n=2, q=1, m=1, z=-1) + + serialized_data = pickle.dumps(original_meijer_g) + unpickled_meijer_g = pickle.loads(serialized_data) + + assert isinstance(unpickled_meijer_g, MeijerG), "Object type mismatch after unpickling." + assert hasattr(unpickled_meijer_g, 'p'), 'expected to have attribute p' + assert hasattr(unpickled_meijer_g, 'q'), 'expected to have attribute q' + assert hasattr(unpickled_meijer_g, 'm'), 'expected to have attribute m' + assert hasattr(unpickled_meijer_g, 'n'), 'expected to have attribute n' + assert hasattr(unpickled_meijer_g, 'z'), 'expected to have attribute z' + assert unpickled_meijer_g.p == original_meijer_g.p, f"p mismatch: {unpickled_meijer_g.p} != {original_meijer_g.p}" + assert unpickled_meijer_g.q == original_meijer_g.q, f"q mismatch: {unpickled_meijer_g.q} != {original_meijer_g.q}" + assert unpickled_meijer_g.m == original_meijer_g.m, f"m mismatch: {unpickled_meijer_g.m} != {original_meijer_g.m}" + assert unpickled_meijer_g.n == original_meijer_g.n, f"n mismatch: {unpickled_meijer_g.n} != {original_meijer_g.n}" + assert unpickled_meijer_g.z == original_meijer_g.z, f"z mismatch: {unpickled_meijer_g.z} != {original_meijer_g.z}" + def test_asymptotics_fail1(): cmf = MeijerG(3, 2, 2, 3, 1) a0, a1 = sp.symbols("a:2") diff --git a/ramanujantools/cmf/pfq.py b/ramanujantools/cmf/pfq.py index d423b69..fc608ee 100644 --- a/ramanujantools/cmf/pfq.py +++ b/ramanujantools/cmf/pfq.py @@ -35,6 +35,24 @@ def __init__( def __repr__(self) -> str: return f"pFq({self.p, self.q, self.z})" + def __getstate__(self): + state = super().__getstate__().copy() + state.update({ + 'p': self.p, + 'q': self.q, + 'z': self.z + }) + return state + + def __setstate__(self, state): + self.p = state['p'] + self.q = state['q'] + self.z = state['z'] + del state['p'] + del state['q'] + del state['z'] + super().__setstate__(state) + @staticmethod def x_axes(p: int) -> list[sp.Symbol]: return sp.symbols(f"x:{p}") diff --git a/ramanujantools/cmf/pfq_test.py b/ramanujantools/cmf/pfq_test.py index f4c3a26..512ad2a 100644 --- a/ramanujantools/cmf/pfq_test.py +++ b/ramanujantools/cmf/pfq_test.py @@ -1,4 +1,5 @@ import pytest +import pickle import sympy as sp from sympy.abc import n, z @@ -247,3 +248,18 @@ def test_hardcoded_determinant_formula(p, q, z, axis): det = pFq_determinant_from_char_poly(p, q, z, axis) calc = pFq.determinant(p, q, z, axis) assert calc == det + + +def test_pfq_serialization(): + original_pfq = pFq(p=2, q=1, z=-1) + + serialized_data = pickle.dumps(original_pfq) + unpickled_pfq = pickle.loads(serialized_data) + + assert isinstance(unpickled_pfq, pFq), "Object type mismatch after unpickling." + assert hasattr(unpickled_pfq, 'p'), 'expected to have attribute p' + assert hasattr(unpickled_pfq, 'q'), 'expected to have attribute q' + assert hasattr(unpickled_pfq, 'z'), 'expected to have attribute z' + assert unpickled_pfq.p == original_pfq.p, f"p mismatch: {unpickled_pfq.p} != {original_pfq.p}" + assert unpickled_pfq.q == original_pfq.q, f"q mismatch: {unpickled_pfq.q} != {original_pfq.q}" + assert unpickled_pfq.z == original_pfq.z, f"z mismatch: {unpickled_pfq.z} != {original_pfq.z}"