diff --git a/pymbolic/mapper/subst_applier.py b/pymbolic/mapper/subst_applier.py new file mode 100644 index 00000000..2b030399 --- /dev/null +++ b/pymbolic/mapper/subst_applier.py @@ -0,0 +1,45 @@ +from __future__ import annotations + + +__copyright__ = "Copyright (C) 2021 Thomas Gibson" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from pymbolic.mapper import IdentityMapper + + +class SubstitutionApplier(IdentityMapper): + """todo. + """ + + def map_substitution(self, expr, current_substs): + new_substs = current_substs.copy() + new_substs.update( + {variable: self.rec(value, current_substs) + for variable, value in zip(expr.variables, expr.values)}) + return self.rec(expr.child, new_substs) + + def map_variable(self, expr, current_substs): + return current_substs.get(expr.name, expr) + + def __call__(self, expr): + current_substs = {} + return super().__call__(expr, current_substs) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 838bad1a..496a0427 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -1551,7 +1551,24 @@ def get_extra_properties(self): @expr_dataclass() class Substitution(Expression): - """Work-alike of sympy's Subs.""" + """A (deferred) substitution applicable to a subexpression. + + See also sympy's ``Subs``. + + .. attribute:: child + + The sub-:class:`Expression` to which the substitution is to be applied. + + .. attribute:: variables + + A sequence of string identifiers of the variables to be replaced with + their corresponding entry in :attr:`values`. + + .. attribute:: values + + A sequence of sub-:class:`Expression` objects corresponding to each + string identifier in :attr:`variables`. + """ child: ExpressionT variables: tuple[str, ...] diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index eb8ac768..b49333e2 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -1016,6 +1016,40 @@ def test_nodecount(): assert get_num_nodes(expr) == 12 +def test_subst_applier(): + x = prim.Variable("x") + y = prim.Variable("y") + z = prim.Variable("z") + + from pymbolic.mapper.substitutor import substitute as subst_actual + + def subst_deferred(expr, **kwargs): + variables = [] + values = [] + for name, value in kwargs.items(): + variables.append(name) + values.append(value) + return prim.Substitution(expr, variables, values) + + from pymbolic.mapper.subst_applier import SubstitutionApplier + sapp = SubstitutionApplier() + + results = [] + for subst in [subst_actual, subst_deferred]: + expr = subst(x + y, x=5*y) + print(expr) + expr = subst(subst(expr**2, y=z) - subst(expr, y=x), x=y) + print(expr) + expr = sapp(expr) + print(expr) + + results.append(sapp(expr)) + print("--------") + + result_actual, result_deferred = results + assert result_actual == result_deferred + + def test_python_ast_interop_roundtrip(): from pymbolic.interop.ast import ASTToPymbolic, PymbolicToASTMapper