Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -6359,6 +6359,14 @@
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
"startColumn": 4,
"endColumn": 14,
"lineCount": 1
}
},
{
"code": "reportUnannotatedClassAttribute",
"range": {
Expand Down Expand Up @@ -6903,6 +6911,14 @@
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 15,
"endColumn": 66,
"lineCount": 1
}
},
{
"code": "reportOperatorIssue",
"range": {
Expand Down Expand Up @@ -6967,6 +6983,14 @@
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 15,
"endColumn": 69,
"lineCount": 1
}
},
{
"code": "reportOperatorIssue",
"range": {
Expand Down
32 changes: 32 additions & 0 deletions pymbolic/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@
else:
return self.handle_unsupported_expression(expr, *args, **kwargs)
else:
return self.map_foreign(expr, *args, **kwargs)

Check warning on line 208 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

Check warning on line 208 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

Check warning on line 208 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

Check warning on line 208 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

rec = __call__
"""Identical to :meth:`__call__`, but intended for use in recursive dispatch
Expand Down Expand Up @@ -284,6 +284,10 @@
expr: p.Power, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_matmul(self,
expr: p.Matmul, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_constant(self,
expr: object, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
"""Mapper method for constants.
Expand Down Expand Up @@ -596,6 +600,11 @@
self.rec(expr.base, *args, **kwargs),
self.rec(expr.exponent, *args, **kwargs)))

@override
def map_matmul(self,
expr: p.Matmul, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.combine(self.rec(child, *args, **kwargs) for child in expr.children)

@override
def map_left_shift(self,
expr: p.LeftShift, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
Expand Down Expand Up @@ -954,6 +963,17 @@
return expr
return type(expr)(base, exponent)

@override
def map_matmul(self,
expr: p.Matmul, /, *args: P.args, **kwargs: P.kwargs
) -> Expression:
children = [self.rec_arith(child, *args, **kwargs) for child in expr.children]
if all(child is orig_child for child, orig_child in
zip(children, expr.children, strict=True)):
return expr

return type(expr)(tuple(children))

@override
def map_left_shift(self,
expr: p.LeftShift, /, *args: P.args, **kwargs: P.kwargs
Expand Down Expand Up @@ -1380,6 +1400,17 @@

self.post_visit(expr, *args, **kwargs)

@override
def map_matmul(self, expr: p.Matmul, /,
*args: P.args, **kwargs: P.kwargs) -> None:
if not self.visit(expr, *args, **kwargs):
return

for child in expr.children:
self.rec(child, *args, **kwargs)

self.post_visit(expr, *args, **kwargs)

@override
def map_tuple(self,
expr: tuple[Expression, ...], /, *args: P.args, **kwargs: P.kwargs
Expand Down Expand Up @@ -1640,6 +1671,7 @@
map_floor_div = map_constant
map_remainder = map_constant
map_power = map_constant
map_matmul = map_constant

map_left_shift = map_constant
map_right_shift = map_constant
Expand Down
20 changes: 10 additions & 10 deletions pymbolic/mapper/c_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
PREC_LOGICAL_AND,
PREC_LOGICAL_OR,
PREC_NONE,
PREC_POWER,
PREC_PRODUCT,
PREC_UNARY,
SimplifyingSortingStringifyMapper,
)
Expand Down Expand Up @@ -126,7 +128,6 @@ def copy_with_mapped_cses(

@override
def map_product(self, expr: p.Product, /, enclosing_prec: int) -> str:
from pymbolic.mapper.stringifier import PREC_PRODUCT
return self.parenthesize_if_needed(
# Spaces prevent '**z' (times dereference z), which
# is hard to read.
Expand Down Expand Up @@ -161,26 +162,27 @@ def map_call(self, expr: p.Call, /, enclosing_prec: int) -> str:

@override
def map_power(self, expr: p.Power, /, enclosing_prec: int) -> str:
from pymbolic.mapper.stringifier import PREC_NONE
from pymbolic.primitives import is_constant, is_zero
if is_constant(expr.exponent):
if is_zero(expr.exponent):
if p.is_constant(expr.exponent):
if p.is_zero(expr.exponent):
return "1"
elif is_zero(expr.exponent - 1):
elif p.is_zero(expr.exponent - 1):
return self.rec(expr.base, enclosing_prec)
elif is_zero(expr.exponent - 2):
elif p.is_zero(expr.exponent - 2):
return self.rec(expr.base*expr.base, enclosing_prec)

return self.format("pow(%s, %s)",
self.rec(expr.base, PREC_NONE),
self.rec(expr.exponent, PREC_NONE))

@override
def map_matmul(self, expr: p.Matmul, /, enclosing_prec: int) -> str:
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

@override
def map_floor_div(self, expr: p.FloorDiv, /, enclosing_prec: int) -> str:
# Let's see how bad of an idea this is--sane people would only
# apply this to integers, right?

from pymbolic.mapper.stringifier import PREC_POWER, PREC_PRODUCT
return self.format("(%s/%s)",
self.rec(expr.numerator, PREC_PRODUCT),
self.rec(expr.denominator, PREC_POWER)) # analogous to ^{-1}
Expand Down Expand Up @@ -208,7 +210,6 @@ def map_common_subexpression(
try:
cse_name = self.cse_to_name[expr.child]
except KeyError:
from pymbolic.mapper.stringifier import PREC_NONE
cse_str = self.rec(expr.child, PREC_NONE)

if expr.prefix is not None:
Expand Down Expand Up @@ -241,7 +242,6 @@ def generate_cse_names() -> Iterator[str]:

@override
def map_if(self, expr: p.If, /, enclosing_prec: int) -> str:
from pymbolic.mapper.stringifier import PREC_NONE
return self.format("(%s ? %s : %s)",
self.rec(expr.condition, PREC_NONE),
self.rec(expr.then, PREC_NONE),
Expand Down
6 changes: 2 additions & 4 deletions pymbolic/mapper/coefficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,14 @@ def map_product(self, expr: p.Product, /) -> CoeffsT:

@override
def map_quotient(self, expr: p.Quotient, /) -> CoeffsT:
from pymbolic.primitives import Quotient
d_num = dict(self.rec(expr.numerator))
d_den = self.rec(expr.denominator)
# d_den should look like {1: k}
if len(d_den) > 1 or 1 not in d_den:
raise RuntimeError("nonlinear expression")
val = d_den[1]
for k in d_num:
d_num[k] = p.flattened_product((d_num[k], Quotient(1, val)))
d_num[k] = p.flattened_product((d_num[k], p.Quotient(1, val)))
return d_num

@override
Expand All @@ -111,8 +110,7 @@ def map_power(self, expr: p.Power, /) -> CoeffsT:
@override
def map_constant(self, expr: object, /) -> CoeffsT:
assert p.is_arithmetic_expression(expr)
from pymbolic.primitives import is_zero
return {} if is_zero(expr) else {1: expr}
return {} if p.is_zero(expr) else {1: expr}

@override
def map_variable(self, expr: p.Variable, /) -> CoeffsT:
Expand Down
10 changes: 4 additions & 6 deletions pymbolic/mapper/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,24 +74,22 @@ def split_term(self, mul_term: Expression) -> tuple[

The argument `product' has to be fully expanded already.
"""
from pymbolic.primitives import AlgebraicLeaf, Power, Product

def base(term: Expression) -> ArithmeticExpression:
if isinstance(term, Power):
if isinstance(term, p.Power):
return term.base
else:
assert p.is_arithmetic_expression(term)
return term

def exponent(term: Expression) -> ArithmeticExpression:
if isinstance(term, Power):
if isinstance(term, p.Power):
return term.exponent
else:
return 1

if isinstance(mul_term, Product):
if isinstance(mul_term, p.Product):
terms: Sequence[Expression] = mul_term.children
elif (isinstance(mul_term, Power | AlgebraicLeaf)
elif (isinstance(mul_term, (p.Power, p.AlgebraicLeaf))
or not bool(self.get_dependencies(mul_term))):
terms = [mul_term]
else:
Expand Down
8 changes: 2 additions & 6 deletions pymbolic/mapper/constant_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,15 @@ def fold(self,
def map_sum(self, expr: prim.Sum, /) -> Expression:
import operator

from pymbolic.primitives import flattened_sum

return self.fold(expr, operator.add, flattened_sum)
return self.fold(expr, operator.add, prim.flattened_sum)


class CommutativeConstantFoldingMapperBase(ConstantFoldingMapperBase):
@override
def map_product(self, expr: prim.Product, /) -> Expression:
import operator

from pymbolic.primitives import flattened_product

return self.fold(expr, operator.mul, flattened_product)
return self.fold(expr, operator.mul, prim.flattened_product)


class ConstantFoldingMapper(
Expand Down
1 change: 1 addition & 0 deletions pymbolic/mapper/cse_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def _map_subexpr(self, expr: prim.ExpressionNode, /) -> Expression:
map_floor_div: Callable[[Self, prim.FloorDiv], Expression] = _map_subexpr
map_remainder: Callable[[Self, prim.Remainder], Expression] = _map_subexpr
map_power: Callable[[Self, prim.Power], Expression] = _map_subexpr
map_matmul: Callable[[Self, prim.Matmul], Expression] = _map_subexpr

map_left_shift: Callable[[Self, prim.LeftShift], Expression] = _map_subexpr
map_right_shift: Callable[[Self, prim.RightShift], Expression] = _map_subexpr
Expand Down
4 changes: 1 addition & 3 deletions pymbolic/mapper/distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,14 @@ def map_quotient(self, expr: p.Quotient, /) -> Expression:

@override
def map_power(self, expr: p.Power, /) -> Expression:
from pymbolic.primitives import Sum

newbase = self.rec(expr.base)
if isinstance(newbase, p.Product):
return self.rec(p.flattened_product([
child**expr.exponent for child in newbase.children
]))

if isinstance(expr.exponent, int):
if isinstance(newbase, Sum):
if isinstance(newbase, p.Sum):
return self.rec(p.flattened_product(expr.exponent*(newbase,)))
else:
return IdentityMapper.map_power(self, expr)
Expand Down
12 changes: 10 additions & 2 deletions pymbolic/mapper/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,11 @@ def map_sum(self, expr: p.Sum, /) -> ResultT:

@override
def map_product(self, expr: p.Product, /) -> ResultT:
from pytools import product
return cast("ResultT", product(self.rec(child) for child in expr.children))
if not expr.children:
return cast("ResultT", 1)

from operator import mul
return reduce(mul, (self.rec(ch) for ch in expr.children))

@override
def map_quotient(self, expr: p.Quotient, /) -> ResultT:
Expand All @@ -151,6 +154,11 @@ def map_remainder(self, expr: p.Remainder, /) -> ResultT:
def map_power(self, expr: p.Power, /) -> ResultT:
return self.rec(expr.base) ** self.rec(expr.exponent)

@override
def map_matmul(self, expr: p.Matmul, /) -> ResultT:
from operator import matmul
return reduce(matmul, (self.rec(ch) for ch in expr.children))

@override
def map_left_shift(self, expr: p.LeftShift, /) -> ResultT:
return self.rec(expr.shiftee) << self.rec(expr.shift)
Expand Down
10 changes: 2 additions & 8 deletions pymbolic/mapper/flattener.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,11 @@ def is_expr_integer_valued(self, expr: Expression, /) -> bool:

@override
def map_sum(self, expr: p.Sum, /) -> Expression:
from pymbolic.primitives import flattened_sum
return flattened_sum([
self.rec_arith(ch)
for ch in expr.children])
return p.flattened_sum([self.rec_arith(ch) for ch in expr.children])

@override
def map_product(self, expr: p.Product, /) -> Expression:
from pymbolic.primitives import flattened_product
return flattened_product([
self.rec_arith(ch)
for ch in expr.children])
return p.flattened_product([self.rec_arith(ch) for ch in expr.children])

@override
def map_quotient(self, expr: p.Quotient, /) -> Expression:
Expand Down
12 changes: 9 additions & 3 deletions pymbolic/mapper/flop_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,20 @@ def map_constant(self, expr: object) -> ArithmeticExpression:
def map_variable(self, expr: p.Variable) -> ArithmeticExpression:
return 0

@override
def map_sum(self, expr: p.Sum | p.Product) -> ArithmeticExpression:
def _count_children(
self, expr: p.Sum | p.Product | p.Matmul
) -> ArithmeticExpression:
if expr.children:
return len(expr.children) - 1 + sum(self.rec(ch) for ch in expr.children)
else:
return 0

map_product: Callable[[Self, p.Product], ArithmeticExpression] = map_sum
map_sum: Callable[[Self, p.Sum], ArithmeticExpression] = _count_children
map_product: Callable[[Self, p.Product], ArithmeticExpression] = _count_children

@override
def map_matmul(self, expr: p.Matmul) -> ArithmeticExpression:
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

@override
def map_quotient(self, expr: p.Quotient | p.FloorDiv) -> ArithmeticExpression:
Expand Down
18 changes: 14 additions & 4 deletions pymbolic/mapper/graphviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@

from typing_extensions import Self, override

import pymbolic.primitives as prim
from pymbolic.mapper import WalkMapper


if TYPE_CHECKING:
from collections.abc import Callable, Hashable

import pymbolic.primitives as prim
from pymbolic.geometric_algebra.primitives import Nabla, NablaComponent


Expand Down Expand Up @@ -143,6 +143,18 @@ def map_product(self, expr: prim.Product, /) -> None:

self.post_visit(expr)

@override
def map_matmul(self, expr: prim.Matmul, /) -> None:
sid = self.get_id(expr)
self.lines.append(f'{sid} [label="@",shape=circle];')
if not self.visit(expr, node_printed=True):
return

for child in expr.children:
self.rec(child)

self.post_visit(expr)

@override
def map_variable(self, expr: prim.Variable, /) -> None:
# Shared nodes for variables do not make for pretty graphs.
Expand Down Expand Up @@ -182,9 +194,7 @@ def map_constant(self, expr: object) -> None:

@override
def map_call(self, expr: prim.Call) -> None:
from pymbolic.primitives import Variable

if not isinstance(expr.function, Variable):
if not isinstance(expr.function, prim.Variable):
return super().map_call(expr)

sid = self.get_id(expr)
Expand Down
Loading
Loading