diff --git a/ramanujantools/cmf/meijer_g.py b/ramanujantools/cmf/meijer_g.py index e009c3b..db0ac3b 100644 --- a/ramanujantools/cmf/meijer_g.py +++ b/ramanujantools/cmf/meijer_g.py @@ -31,26 +31,20 @@ def __repr__(self) -> str: def __getstate__(self): state = super().__getstate__().copy() - state.update({ - 'm': self.m, - 'n': self.n, - 'p': self.p, - 'q': self.q, - 'z': self.z - }) + state.update({"m": self.m, "n": self.n, "p": self.p, "q": self.q, "z": self.z}) return state def __setstate__(self, state): - self.m = state['m'] - self.n = state['n'] - self.p = state['p'] - self.q = state['q'] - self.z = state['z'] - del state['m'] - del state['n'] - del state['p'] - del state['q'] - del state['z'] + self.m = state["m"] + self.n = state["n"] + self.p = state["p"] + self.q = state["q"] + self.z = state["z"] + del state["m"] + del state["n"] + del state["p"] + del state["q"] + del state["z"] super().__setstate__(state) @staticmethod @@ -86,8 +80,8 @@ def construct_matrix( is_a = axis.name.startswith("a") index = int(axis.name[1:]) if is_a: - sign = -1 if index > n else 1 + sign = -1 if index >= n else 1 else: - sign = 1 if index > m else -1 + sign = 1 if index >= m else -1 multiplier = axis - 1 if is_a else axis return (theta_matrix - multiplier * eye) * sign diff --git a/ramanujantools/cmf/meijer_g_test.py b/ramanujantools/cmf/meijer_g_test.py index 072a52d..86fc602 100644 --- a/ramanujantools/cmf/meijer_g_test.py +++ b/ramanujantools/cmf/meijer_g_test.py @@ -7,6 +7,9 @@ from ramanujantools import Position, Matrix, LinearRecurrence from ramanujantools.cmf import MeijerG +a0, a1 = sp.symbols("a:2") +b0, b1 = sp.symbols("b:2") + def test_conserving(): for _p in range(1, 3): @@ -16,6 +19,12 @@ def test_conserving(): MeijerG(_m, _n, _p, _q, z).validate_conserving() +def test_m_n_sign(): + cmf = MeijerG(1, 1, 2, 2) + assert cmf.M(a0).subs({a0: a1, a1: a0}) == -1 * cmf.M(a1) + assert cmf.M(b0).subs({b0: b1, b1: b0}) == -1 * cmf.M(b1) + + def test_gamma(): cmf = MeijerG(3, 2, 2, 3, 1) a0, a1 = sp.symbols("a:2") @@ -37,18 +46,31 @@ def test_serialization(): serialized_data = pickle.dumps(original_meijer_g) unpickled_meijer_g = pickle.loads(serialized_data) - assert isinstance(unpickled_meijer_g, MeijerG), "Object type mismatch after unpickling." - assert hasattr(unpickled_meijer_g, 'p'), 'expected to have attribute p' - assert hasattr(unpickled_meijer_g, 'q'), 'expected to have attribute q' - assert hasattr(unpickled_meijer_g, 'm'), 'expected to have attribute m' - assert hasattr(unpickled_meijer_g, 'n'), 'expected to have attribute n' - assert hasattr(unpickled_meijer_g, 'z'), 'expected to have attribute z' - assert unpickled_meijer_g.p == original_meijer_g.p, f"p mismatch: {unpickled_meijer_g.p} != {original_meijer_g.p}" - assert unpickled_meijer_g.q == original_meijer_g.q, f"q mismatch: {unpickled_meijer_g.q} != {original_meijer_g.q}" - assert unpickled_meijer_g.m == original_meijer_g.m, f"m mismatch: {unpickled_meijer_g.m} != {original_meijer_g.m}" - assert unpickled_meijer_g.n == original_meijer_g.n, f"n mismatch: {unpickled_meijer_g.n} != {original_meijer_g.n}" - assert unpickled_meijer_g.z == original_meijer_g.z, f"z mismatch: {unpickled_meijer_g.z} != {original_meijer_g.z}" - + assert isinstance(unpickled_meijer_g, MeijerG), ( + "Object type mismatch after unpickling." + ) + assert hasattr(unpickled_meijer_g, "p"), "expected to have attribute p" + assert hasattr(unpickled_meijer_g, "q"), "expected to have attribute q" + assert hasattr(unpickled_meijer_g, "m"), "expected to have attribute m" + assert hasattr(unpickled_meijer_g, "n"), "expected to have attribute n" + assert hasattr(unpickled_meijer_g, "z"), "expected to have attribute z" + assert unpickled_meijer_g.p == original_meijer_g.p, ( + f"p mismatch: {unpickled_meijer_g.p} != {original_meijer_g.p}" + ) + assert unpickled_meijer_g.q == original_meijer_g.q, ( + f"q mismatch: {unpickled_meijer_g.q} != {original_meijer_g.q}" + ) + assert unpickled_meijer_g.m == original_meijer_g.m, ( + f"m mismatch: {unpickled_meijer_g.m} != {original_meijer_g.m}" + ) + assert unpickled_meijer_g.n == original_meijer_g.n, ( + f"n mismatch: {unpickled_meijer_g.n} != {original_meijer_g.n}" + ) + assert unpickled_meijer_g.z == original_meijer_g.z, ( + f"z mismatch: {unpickled_meijer_g.z} != {original_meijer_g.z}" + ) + + def test_asymptotics_fail1(): cmf = MeijerG(3, 2, 2, 3, 1) a0, a1 = sp.symbols("a:2")