diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index 96b6d678..785feafb 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -20,18 +20,16 @@ THE SOFTWARE. """ -import pymbolic.primitives as prim import pytest +from functools import reduce + +import pymbolic.primitives as prim from pymbolic import parse from pytools.lex import ParseError - - from pymbolic.mapper import IdentityMapper -try: - reduce -except NameError: - from functools import reduce +import logging +logger = logging.getLogger(__name__) # {{{ utilities @@ -39,9 +37,11 @@ def assert_parsed_same_as_python(expr_str): # makes sure that has only one line expr_str, = expr_str.split("\n") - from pymbolic.interop.ast import ASTToPymbolic + import ast + from pymbolic.interop.ast import ASTToPymbolic ast2p = ASTToPymbolic() + try: expr_parsed_by_python = ast2p(ast.parse(expr_str).body[0].value) except SyntaxError: @@ -53,14 +53,17 @@ def assert_parsed_same_as_python(expr_str): def assert_parse_roundtrip(expr_str): - expr = parse(expr_str) from pymbolic.mapper.stringifier import StringifyMapper + expr = parse(expr_str) strified = StringifyMapper()(expr) + assert strified == expr_str, (strified, expr_str) # }}} +# {{{ test_integer_power + def test_integer_power(): from pymbolic.algorithm import integer_power @@ -72,6 +75,10 @@ def test_integer_power(): ]: assert base**expn == integer_power(base, expn) +# }}} + + +# {{{ test_expand def test_expand(): from pymbolic import var, expand @@ -80,6 +87,10 @@ def test_expand(): u = (x+1)**5 expand(u) +# }}} + + +# {{{ test_substitute def test_substitute(): from pymbolic import parse, substitute, evaluate @@ -87,6 +98,10 @@ def test_substitute(): xmin = parse("x.min") assert evaluate(substitute(u, {xmin: 25})) == 630 +# }}} + + +# {{{ test_no_comparison def test_no_comparison(): from pymbolic import parse @@ -107,13 +122,20 @@ def expect_typeerror(f): expect_typeerror(lambda: x > y) expect_typeerror(lambda: x >= y) +# }}} + + +# {{{ test_structure_preservation def test_structure_preservation(): x = prim.Sum((5, 7)) - from pymbolic.mapper import IdentityMapper x2 = IdentityMapper()(x) assert x == x2 +# }}} + + +# {{{ test_sympy_interaction def test_sympy_interaction(): pytest.importorskip("sympy") @@ -141,6 +163,8 @@ def test_sympy_interaction(): assert sp.ratsimp(s1_expr - s3_expr) == 0 +# }}} + # {{{ fft @@ -181,9 +205,9 @@ def test_fft(): from pymbolic.algorithm import fft, sym_fft vars = numpy.array([var(chr(97+i)) for i in range(16)], dtype=object) - print(vars) + logger.info("vars: %s", vars) - print(fft(vars)) + logger.info("fft: %s", fft(vars)) traced_fft = sym_fft(vars) from pymbolic.mapper.stringifier import PREC_NONE @@ -193,14 +217,16 @@ def test_fft(): code = [ccm(tfi, PREC_NONE) for tfi in traced_fft] for cse_name, cse_str in enumerate(ccm.cse_name_list): - print(f"{cse_name} = {cse_str}") + logger.info("%s = %s", cse_name, cse_str) for i, line in enumerate(code): - print("result[%d] = %s" % (i, line)) + logger.info("result[%d] = %s", i, line) # }}} +# {{{ test_sparse_multiply + def test_sparse_multiply(): numpy = pytest.importorskip("numpy") pytest.importorskip("scipy") @@ -219,6 +245,8 @@ def test_sparse_multiply(): assert la.norm(mat_vec-mat_vec_2) < 1e-14 +# }}} + # {{{ parser @@ -227,25 +255,25 @@ def test_parser(): parse("(2*a[1]*b[1]+2*a[0]*b[0])*(hankel_1(-1,sqrt(a[1]**2+a[0]**2)*k) " "-hankel_1(1,sqrt(a[1]**2+a[0]**2)*k))*k /(4*sqrt(a[1]**2+a[0]**2)) " "+hankel_1(0,sqrt(a[1]**2+a[0]**2)*k)") - print(repr(parse("d4knl0"))) - print(repr(parse("0."))) - print(repr(parse("0.e1"))) + logger.info("%r", parse("d4knl0")) + logger.info("%r", parse("0.")) + logger.info("%r", parse("0.e1")) assert parse("0.e1") == 0 assert parse("1e-12") == 1e-12 - print(repr(parse("a >= 1"))) - print(repr(parse("a <= 1"))) - - print(repr(parse(":"))) - print(repr(parse("1:"))) - print(repr(parse(":2"))) - print(repr(parse("1:2"))) - print(repr(parse("::"))) - print(repr(parse("1::"))) - print(repr(parse(":1:"))) - print(repr(parse("::1"))) - print(repr(parse("3::1"))) - print(repr(parse(":5:1"))) - print(repr(parse("3:5:1"))) + logger.info("%r", parse("a >= 1")) + logger.info("%r", parse("a <= 1")) + + logger.info("%r", parse(":")) + logger.info("%r", parse("1:")) + logger.info("%r", parse(":2")) + logger.info("%r", parse("1:2")) + logger.info("%r", parse("::")) + logger.info("%r", parse("1::")) + logger.info("%r", parse(":1:")) + logger.info("%r", parse("::1")) + logger.info("%r", parse("3::1")) + logger.info("%r", parse(":5:1")) + logger.info("%r", parse("3:5:1")) assert_parse_roundtrip("()") assert_parse_roundtrip("(3,)") @@ -257,17 +285,17 @@ def test_parser(): assert_parse_roundtrip("g[i, k] + 2.0*h[i, k]") parse("g[i,k]+(+2.0)*h[i, k]") - print(repr(parse("a - b - c"))) - print(repr(parse("-a - -b - -c"))) - print(repr(parse("- - - a - - - - b - - - - - c"))) + logger.info("%r", parse("a - b - c")) + logger.info("%r", parse("-a - -b - -c")) + logger.info("%r", parse("- - - a - - - - b - - - - - c")) - print(repr(parse("~(a ^ b)"))) - print(repr(parse("(a | b) | ~(~a & ~b)"))) + logger.info("%r", parse("~(a ^ b)")) + logger.info("%r", parse("(a | b) | ~(~a & ~b)")) - print(repr(parse("3 << 1"))) - print(repr(parse("1 >> 3"))) + logger.info("%r", parse("3 << 1")) + logger.info("%r", parse("1 >> 3")) - print(parse("3::1")) + logger.info(parse("3::1")) assert parse("e1") == prim.Variable("e1") assert parse("d1") == prim.Variable("d1") @@ -295,6 +323,8 @@ def test_parser(): # }}} +# {{{ test_mappers + def test_mappers(): from pymbolic import variables f, x, y, z = variables("f x y z") @@ -310,6 +340,11 @@ def test_mappers(): DependencyMapper()(expr) +# }}} + + +# {{{ test_func_dep_consistency + def test_func_dep_consistency(): from pymbolic import var from pymbolic.mapper.dependency import DependencyMapper @@ -319,6 +354,10 @@ def test_func_dep_consistency(): assert dep_map(f(x)) == {x} assert dep_map(f(x=x)) == {x} +# }}} + + +# {{{ test_conditions def test_conditions(): from pymbolic import var @@ -326,6 +365,10 @@ def test_conditions(): y = var("y") assert str(x.eq(y).and_(x.le(5))) == "x == y and x <= 5" +# }}} + + +# {{{ test_graphviz def test_graphviz(): from pymbolic import parse @@ -336,7 +379,9 @@ def test_graphviz(): from pymbolic.mapper.graphviz import GraphvizMapper gvm = GraphvizMapper() gvm(expr) - print(gvm.get_dot_code()) + logger.info("%s", gvm.get_dot_code()) + +# }}} # {{{ geometric algebra @@ -443,6 +488,8 @@ def test_geometric_algebra(dims): # }}} +# {{{ test_ast_interop + def test_ast_interop(): src = """ def f(): @@ -453,7 +500,7 @@ def f(): import ast mod = ast.parse(src.replace("\n ", "\n")) - print(ast.dump(mod)) + logger.info("%s", ast.dump(mod)) from pymbolic.interop.ast import ASTToPymbolic ast2p = ASTToPymbolic() @@ -470,8 +517,12 @@ def f(): lhs = ast2p(lhs) rhs = ast2p(stmt.value) - print(lhs, rhs) + logger.info("lhs %s rhs %s", lhs, rhs) + +# }}} + +# {{{ test_compile def test_compile(): from pymbolic import parse, compile @@ -483,6 +534,10 @@ def test_compile(): code = pickle.loads(pickle.dumps(code)) assert code(3, 3) == 27 +# }}} + + +# {{{ test_unifier def test_unifier(): from pymbolic import var @@ -521,6 +576,10 @@ def match_found(records, eqns): assert len(recs) == 1 assert match_found(recs, {(a, b), (b, c), (c, d)}) +# }}} + + +# {{{ test_long_sympy_mapping def test_long_sympy_mapping(): sp = pytest.importorskip("sympy") @@ -528,6 +587,10 @@ def test_long_sympy_mapping(): SympyToPymbolicMapper()(sp.sympify(int(10**20))) SympyToPymbolicMapper()(sp.sympify(int(10))) +# }}} + + +# {{{ test_stringifier_preserve_shift_order def test_stringifier_preserve_shift_order(): for expr in [ @@ -536,6 +599,10 @@ def test_stringifier_preserve_shift_order(): ]: assert parse(str(expr)) == expr +# }}} + + +# {{{ test_latex_mapper LATEX_TEMPLATE = r"""\documentclass{article} \usepackage{amsmath} @@ -604,6 +671,10 @@ def add(expr): finally: shutil.rmtree(latex_dir) +# }}} + + +# {{{ test_flop_counter def test_flop_counter(): x = prim.Variable("x") @@ -618,6 +689,10 @@ def test_flop_counter(): assert CSEAwareFlopCounter()(expr) == 4 + 2 +# }}} + + +# {{{ test_make_sym_vector def test_make_sym_vector(): numpy = pytest.importorskip("numpy") @@ -627,6 +702,10 @@ def test_make_sym_vector(): assert len(make_sym_vector("vec", numpy.int32(2))) == 2 assert len(make_sym_vector("vec", [1, 2, 3])) == 3 +# }}} + + +# {{{ test_multiplicative_stringify_preserves_association def test_multiplicative_stringify_preserves_association(): for inner in ["*", " / ", " // ", " % "]: @@ -639,6 +718,10 @@ def test_multiplicative_stringify_preserves_association(): assert_parse_roundtrip("(-1)*(((-1)*x) / 5)") +# }}} + + +# {{{ test_differentiator_flags_for_nonsmooth_and_discontinuous def test_differentiator_flags_for_nonsmooth_and_discontinuous(): import pymbolic.functions as pf @@ -658,6 +741,10 @@ def test_differentiator_flags_for_nonsmooth_and_discontinuous(): result = differentiate(pf.sign(x), x, allowed_nonsmoothness="discontinuous") assert result == 0 +# }}} + + +# {{{ test_diff_cse def test_diff_cse(): from pymbolic.mapper.differentiator import differentiate @@ -686,6 +773,10 @@ def test_diff_cse(): assert err2 < 1.1 * 0.5**2 * err1 +# }}} + + +# {{{ test_coefficient_collector def test_coefficient_collector(): from pymbolic.mapper.coefficient import CoefficientCollector @@ -698,6 +789,10 @@ def test_coefficient_collector(): assert cc(2*x + y - z) == {x: 2, y: 1, 1: -z} assert cc(x/2 + z**2) == {x: prim.Quotient(1, 2), 1: z**2} +# }}} + + +# {{{ test_np_bool_handling def test_np_bool_handling(): from pymbolic.mapper.evaluator import evaluate @@ -705,6 +800,10 @@ def test_np_bool_handling(): expr = prim.LogicalNot(numpy.bool_(False)) assert evaluate(expr) is True +# }}} + + +# {{{ test_mapper_method_of_parent_class def test_mapper_method_of_parent_class(): class SpatialConstant(prim.Variable): @@ -719,6 +818,50 @@ def map_spatial_constant(self, expr): assert MyMapper()(c) == 2*c assert IdentityMapper()(c) == c +# }}} + + +# {{{ test_equality_complexity + +@pytest.mark.xfail +def test_equality_complexity(): + # NOTE: https://github.com/inducer/pymbolic/issues/73 + from numpy.random import default_rng + + def construct_intestine_graph(depth=64, seed=0): + rng = default_rng(seed) + x = prim.Variable("x") + + for _ in range(depth): + coeff1, coeff2 = rng.integers(1, 10, 2) + x = coeff1 * x + coeff2 * x + + return x + + def check_equality(): + graph1 = construct_intestine_graph() + graph2 = construct_intestine_graph() + graph3 = construct_intestine_graph(seed=3) + + assert graph1 == graph2 + assert graph2 == graph1 + assert graph1 != graph3 + assert graph2 != graph3 + + # NOTE: this should finish in a second! + import multiprocessing + p = multiprocessing.Process(target=check_equality) + p.start() + p.join(timeout=1) + + is_alive = p.is_alive() + if p.is_alive(): + p.terminate() + + assert not is_alive + +# }}} + if __name__ == "__main__": import sys