diff --git a/ramanujantools/flint_core/numeric_matrix.py b/ramanujantools/flint_core/numeric_matrix.py index 42c5bff..cedaa90 100644 --- a/ramanujantools/flint_core/numeric_matrix.py +++ b/ramanujantools/flint_core/numeric_matrix.py @@ -2,8 +2,9 @@ from typing import Callable +import numpy as np import sympy as sp -from flint import fmpq_mat, fmpq +from flint import fmpq_mat, fmpq, fmpz_mat import ramanujantools as rt from ramanujantools import Position @@ -11,6 +12,283 @@ class NumericMatrix(fmpq_mat): + @staticmethod + def _fmpz_eye(size: int) -> fmpz_mat: + """ + Returns an identity matrix over FLINT integers. + """ + return fmpz_mat( + [ + [1 if row_index == col_index else 0 for col_index in range(size)] + for row_index in range(size) + ] + ) + + @staticmethod + def _to_fmpq(value: sp.Expr | int) -> fmpq: + """ + Converts an exact integer or rational SymPy value into a FLINT rational. + """ + value = sp.S(value) + if isinstance(value, sp.Integer): + return fmpq(int(value)) + if isinstance(value, sp.Rational): + return fmpq(int(value.p), int(value.q)) + raise TypeError(f"Expected an exact rational value, got {value}") + + @staticmethod + def _compile_batched_walk_matrix( + matrix: rt.Matrix, symbol: sp.Symbol + ) -> tuple[np.ndarray, np.ndarray, int]: + """ + Compiles a one-symbol rational matrix into integer coefficient tensors. + + The common denominator is used as one scalar per step, while the inflated + polynomial entries are evaluated together from their coefficient columns. + """ + cache = getattr(matrix, "_numeric_batched_walk_cache", None) + if cache is None: + cache = {} + matrix._numeric_batched_walk_cache = cache + if symbol in cache: + return cache[symbol] + + common_denominator = sp.factor( + sp.lcm_list( + [ + sp.denom(matrix[row_index, col_index]) + for row_index in range(matrix.rows) + for col_index in range(matrix.cols) + ] + ) + ) + + denominator_poly = sp.Poly(common_denominator, symbol) + flattened_denominator_coefficients = np.zeros(1, dtype=object) + flattened_denominator_coefficients[0] = sp.Integer(0) + + coefficients = [] + max_degree = 0 + for row_index in range(matrix.rows): + row_coefficients = [] + for col_index in range(matrix.cols): + polynomial = sp.Poly(matrix[row_index, col_index] * common_denominator, symbol) + if not all(coefficient.is_integer for coefficient in polynomial.all_coeffs()): + raise sp.PolynomialError( + f"Inflated entry does not have integer coefficients: {matrix[row_index, col_index] * common_denominator}" + ) + entry_coefficients = tuple( + int(coefficient) for coefficient in reversed(polynomial.all_coeffs()) + ) or (0,) + max_degree = max(max_degree, len(entry_coefficients) - 1) + row_coefficients.append(entry_coefficients) + coefficients.append(tuple(row_coefficients)) + + flattened_entry_coefficients = np.zeros( + (max_degree + 1, matrix.rows * matrix.cols), dtype=object + ) + for row_index in range(matrix.rows): + for col_index in range(matrix.cols): + entry_index = row_index * matrix.cols + col_index + for degree, coefficient in enumerate(coefficients[row_index][col_index]): + flattened_entry_coefficients[degree, entry_index] = coefficient + + flattened_denominator_coefficients = np.zeros(max_degree + 1, dtype=object) + if not all(coefficient.is_integer for coefficient in denominator_poly.all_coeffs()): + raise sp.PolynomialError( + f"Common denominator does not have integer coefficients: {common_denominator}" + ) + for degree, coefficient in enumerate(reversed(denominator_poly.all_coeffs())): + flattened_denominator_coefficients[degree] = int(coefficient) + compiled = flattened_denominator_coefficients, flattened_entry_coefficients, max_degree + cache[symbol] = compiled + return compiled + + @staticmethod + def _can_use_batched_evaluation( + matrix: rt.Matrix, trajectory: Position, start: Position + ) -> bool: + """ + Returns True when the walk can use one-symbol batched polynomial evaluation. + """ + if len(trajectory) != 1 or len(start) != 1: + return False + + symbol = next(iter(trajectory.keys())) + if not sp.S(trajectory[symbol]).is_integer or not sp.S(start[symbol]).is_integer: + return False + if set(matrix.free_symbols) - {symbol}: + return False + + try: + NumericMatrix._compile_batched_walk_matrix(matrix, symbol) + except sp.PolynomialError: + return False + return True + + @staticmethod + def _batched_step_matrices( + matrix: rt.Matrix, + trajectory: Position, + depth: int, + start: Position, + ) -> list["NumericMatrix"]: + """ + Generates all one-symbol step matrices with a Vandermonde powers table. + """ + if depth == 0: + return [] + + symbol = next(iter(trajectory.keys())) + step_size = int(sp.S(trajectory[symbol])) + start_value = int(sp.S(start[symbol])) + denominator_coefficients, flattened_entry_coefficients, max_degree = NumericMatrix._compile_batched_walk_matrix( + matrix, symbol + ) + + evaluation_points = np.array( + [start_value + offset * step_size for offset in range(depth)], dtype=object + ) + vandermonde = np.empty((depth, max_degree + 1), dtype=object) + vandermonde[:, 0] = 1 + for degree in range(1, max_degree + 1): + vandermonde[:, degree] = vandermonde[:, degree - 1] * evaluation_points + + denominators = vandermonde @ denominator_coefficients + evaluated_entries = vandermonde @ flattened_entry_coefficients + + step_matrices = [] + for depth_index in range(depth): + denominator = int(denominators[depth_index]) + if denominator == 0: + raise ZeroDivisionError( + f"Common denominator vanished at {symbol}={evaluation_points[depth_index]}" + ) + rows = [] + for row_index in range(matrix.rows): + row_values = [] + for col_index in range(matrix.cols): + entry_index = row_index * matrix.cols + col_index + numerator = int(evaluated_entries[depth_index, entry_index]) + row_values.append(fmpq(numerator, denominator)) + rows.append(row_values) + step_matrices.append(NumericMatrix(rows)) + return step_matrices + + @staticmethod + def _batched_integer_step_data( + matrix: rt.Matrix, + trajectory: Position, + depth: int, + start: Position, + ) -> tuple[list[fmpz_mat], list[int]]: + """ + Generates inflated integer step matrices together with their scalar denominators. + """ + if depth == 0: + return [], [] + + symbol = next(iter(trajectory.keys())) + step_size = int(sp.S(trajectory[symbol])) + start_value = int(sp.S(start[symbol])) + denominator_coefficients, flattened_entry_coefficients, max_degree = NumericMatrix._compile_batched_walk_matrix( + matrix, symbol + ) + + evaluation_points = np.array( + [start_value + offset * step_size for offset in range(depth)], dtype=object + ) + vandermonde = np.empty((depth, max_degree + 1), dtype=object) + vandermonde[:, 0] = 1 + for degree in range(1, max_degree + 1): + vandermonde[:, degree] = vandermonde[:, degree - 1] * evaluation_points + + denominators = [int(value) for value in (vandermonde @ denominator_coefficients)] + evaluated_entries = vandermonde @ flattened_entry_coefficients + + step_matrices = [] + for depth_index in range(depth): + if denominators[depth_index] == 0: + raise ZeroDivisionError( + f"Common denominator vanished at {symbol}={evaluation_points[depth_index]}" + ) + rows = [] + for row_index in range(matrix.rows): + row_values = [] + for col_index in range(matrix.cols): + entry_index = row_index * matrix.cols + col_index + row_values.append(int(evaluated_entries[depth_index, entry_index])) + rows.append(row_values) + step_matrices.append(fmpz_mat(rows)) + return step_matrices, denominators + + @staticmethod + def _numeric_from_integer_product(product: fmpz_mat, scalar: int) -> "NumericMatrix": + """ + Converts one exact inflated integer product back into a FLINT rational matrix. + """ + if scalar == 0: + raise ZeroDivisionError("Cannot recover a rational walk from a zero scalar") + return NumericMatrix( + [ + [fmpq(int(product[row_index, col_index]), scalar) for col_index in range(product.ncols())] + for row_index in range(product.nrows()) + ] + ) + + @staticmethod + def _batched_integer_walk( + matrix: rt.Matrix, + iterations: Batchable[int], + trajectory: Position, + start: Position, + ) -> Batchable["NumericMatrix"]: + """ + Runs the reduced walk through inflated integer FLINT matrices and divides only at checkpoints. + """ + if not iterations: + return [] + + step_matrices, step_scalars = NumericMatrix._batched_integer_step_data( + matrix, trajectory, iterations[-1], start + ) + dim = matrix.rows + + sequential_threshold = 8 + + def _product_tree(first: int, last: int) -> tuple[fmpz_mat, int]: + span = last - first + if span == 0: + return step_matrices[first], step_scalars[first] + if span <= sequential_threshold: + result_matrix = step_matrices[first] + result_scalar = step_scalars[first] + for index in range(first + 1, last + 1): + result_matrix = result_matrix * step_matrices[index] + result_scalar *= step_scalars[index] + return result_matrix, result_scalar + midpoint = (first + last) >> 1 + left_matrix, left_scalar = _product_tree(first, midpoint) + right_matrix, right_scalar = _product_tree(midpoint + 1, last) + return left_matrix * right_matrix, left_scalar * right_scalar + + accumulated_matrix = NumericMatrix._fmpz_eye(dim) + accumulated_scalar = 1 + results = [] + previous_depth = 0 + for target_depth in iterations: + if target_depth > previous_depth: + segment_matrix, segment_scalar = _product_tree(previous_depth, target_depth - 1) + accumulated_matrix = accumulated_matrix * segment_matrix + accumulated_scalar *= segment_scalar + results.append( + NumericMatrix._numeric_from_integer_product( + accumulated_matrix, accumulated_scalar + ) + ) + previous_depth = target_depth + return results + @staticmethod def eye(N: int): """ @@ -98,11 +376,14 @@ def walk( iterations: Batchable[int], start: Position, ) -> Batchable[NumericMatrix]: + if NumericMatrix._can_use_batched_evaluation(matrix, trajectory, start): + return NumericMatrix._batched_integer_walk( + matrix, iterations, trajectory, start + ) + N = iterations[-1] - fast_subs = NumericMatrix.lambda_from_rt(matrix) dim = matrix.rows - - # Pre-evaluate all per-step matrices into a flat list. + fast_subs = NumericMatrix.lambda_from_rt(matrix) position = start.copy() step_matrices = [] for _ in range(N): diff --git a/ramanujantools/flint_core/numeric_matrix_test.py b/ramanujantools/flint_core/numeric_matrix_test.py index 05e2192..b149631 100644 --- a/ramanujantools/flint_core/numeric_matrix_test.py +++ b/ramanujantools/flint_core/numeric_matrix_test.py @@ -1,48 +1,126 @@ -from sympy.abc import n +import sympy as sp +from sympy.abc import n, x, y from ramanujantools import Matrix, Position +from ramanujantools.cmf import pFq from ramanujantools.flint_core import NumericMatrix +def _manual_numeric_walk( + matrix: Matrix, trajectory: Position, iterations: list[int], start: Position +) -> list[Matrix]: + """ + Computes checkpoint walk products directly from per-step substitutions. + """ + evaluator = NumericMatrix.lambda_from_rt(matrix) + position = start.copy() + step_matrices = [] + for _ in range(iterations[-1]): + step_matrices.append(evaluator(position)) + position += trajectory + + results = [] + accumulated = NumericMatrix.eye(matrix.rows) + previous_depth = 0 + for target_depth in iterations: + for step_index in range(previous_depth, target_depth): + accumulated = accumulated * step_matrices[step_index] + results.append(accumulated.to_rt()) + previous_depth = target_depth + return results + + def test_conversion(): - m = Matrix( + """Checks that the FLINT evaluator matches direct substitution entrywise.""" + + matrix = Matrix( [ [(n - 1) * (n**2 - n + 1) / n**3, -1 / n**3], [1 / n**3, (n + 1) * (n**2 + n + 1) / n**3], ] ) - for i in range(1, 10): - assert m.subs({n: i}) == NumericMatrix.lambda_from_rt(m)({n: i}).to_rt() + for value in range(1, 10): + assert matrix.subs({n: value}) == NumericMatrix.lambda_from_rt(matrix)({n: value}).to_rt() + +def test_walk_matches_manual_products(): + """Checks the optimized single-symbol walk against manual exact products.""" -def test_walk(): - m = Matrix( + matrix = Matrix( [ [(n - 1) * (n**2 - n + 1) / n**3, -1 / n**3], [1 / n**3, (n + 1) * (n**2 + n + 1) / n**3], ] ) + iterations = [1, 2, 5, 10, 25] + expected = _manual_numeric_walk(matrix, Position({n: 2}), iterations, Position({n: 3})) + actual = NumericMatrix.walk(matrix, Position({n: 2}), iterations, Position({n: 3})) + assert [numeric_matrix.to_rt() for numeric_matrix in actual] == expected - assert ( - m.walk({n: 2}, 100, {n: 3}) - == NumericMatrix.walk(m, Position({n: 2}), 100, Position({n: 3})).to_rt() - ) +def test_walk_preserves_sparse_checkpoint_requests(): + """Checks that checkpoint accumulation stays correct for non-consecutive depths.""" -def test_walk_list(): - m = Matrix( + matrix = Matrix( [ - [(n - 1) * (n**2 - n + 1) / n**3, -1 / n**3], - [1 / n**3, (n + 1) * (n**2 + n + 1) / n**3], + [(n + 2) / (n + 1), (2 * n + 3) / (n + 1)], + [-(n + 5) / (n + 2), (3 * n + 7) / (n + 2)], ] ) + iterations = [3, 11, 17] + expected = _manual_numeric_walk(matrix, Position({n: 1}), iterations, Position({n: 1})) + actual = NumericMatrix.walk(matrix, Position({n: 1}), iterations, Position({n: 1})) + assert [numeric_matrix.to_rt() for numeric_matrix in actual] == expected + + +def test_walk_multisymbol_fallback_matches_manual_products(): + """Checks that unsupported multi-symbol walks still follow the old exact path.""" - rt_walk_matrices = m.walk({n: 2}, list(range(1, 100)), {n: 3}) - numeric_walk_matrices = NumericMatrix.walk( - m, Position({n: 2}), list(range(1, 100)), Position({n: 3}) + matrix = Matrix( + [ + [(x + y + 1) / (x + 1), (x - y) / (y + 2)], + [(2 * x + y) / (x + 2), (x + 3 * y + 1) / (y + 1)], + ] ) - for rt_walk_matrix, numeric_walk_matrix in zip( - rt_walk_matrices, numeric_walk_matrices - ): - assert rt_walk_matrix == numeric_walk_matrix.to_rt() + trajectory = Position({x: 1, y: 2}) + start = Position({x: 2, y: 3}) + iterations = [1, 4, 8] + expected = _manual_numeric_walk(matrix, trajectory, iterations, start) + actual = NumericMatrix.walk(matrix, trajectory, iterations, start) + assert [numeric_matrix.to_rt() for numeric_matrix in actual] == expected + + +def _reduced_pfq_matrix(p_value: int, q_value: int, z_value: sp.Expr) -> Matrix: + """ + Builds the reduced all-ones `pFq` trajectory matrix used in numeric walk benchmarks. + """ + cmf = pFq(p_value, q_value, z_value) + x_axes = sp.symbols(f"x:{p_value}") + y_axes = sp.symbols(f"y:{q_value}") + start = {axis: 1 for axis in x_axes + y_axes} + trajectory = { + **{axis: 1 for axis in x_axes}, + **{axis: 2 for axis in y_axes}, + } + return cmf.trajectory_matrix(trajectory, start) + + +def test_reduced_pfq_2f1_walk_matches_manual_products(): + """Checks the optimized reduced `2F1` walk against manual exact products.""" + + matrix = _reduced_pfq_matrix(2, 1, -1) + iterations = [1, 2, 5, 10, 25] + expected = _manual_numeric_walk(matrix, Position({n: 1}), iterations, Position({n: 1})) + actual = NumericMatrix.walk(matrix, Position({n: 1}), iterations, Position({n: 1})) + assert [numeric_matrix.to_rt() for numeric_matrix in actual] == expected + + +def test_reduced_pfq_3f2_walk_matches_manual_products(): + """Checks the optimized reduced `3F2` walk against manual exact products.""" + + matrix = _reduced_pfq_matrix(3, 2, sp.Rational(1, 4)) + iterations = [1, 2, 5, 10, 25] + expected = _manual_numeric_walk(matrix, Position({n: 1}), iterations, Position({n: 1})) + actual = NumericMatrix.walk(matrix, Position({n: 1}), iterations, Position({n: 1})) + assert [numeric_matrix.to_rt() for numeric_matrix in actual] == expected