Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
0d90c41
feat: make mapper.coefficient positional-only
alexfikl Oct 31, 2025
1052c42
feat: make mapper.evaluator positional-only
alexfikl Nov 3, 2025
f254c2e
feat: make mapper.flattener positional-only
alexfikl Oct 31, 2025
7630578
feat(typing): use ExpressionNode in handle_unsupported_expression
alexfikl Nov 1, 2025
1f4ea3b
feat(typing): add annotations to primitives.quotient
alexfikl Oct 31, 2025
7f8ba20
feat(typing): add annotations to functions
alexfikl Oct 31, 2025
c32bf66
feat(typing): add types to mapper.analysis
alexfikl Oct 31, 2025
8d7ebfe
feat(typing): improve types in mapper.collector
alexfikl Oct 31, 2025
f7b5146
feat(typing): improve types in mapper.constant_folder
alexfikl Oct 31, 2025
f503b9d
feat(typing): improve types in mapper.cse_tagger
alexfikl Oct 31, 2025
c4ff0a7
feat(typing): improve types in mapper.dependency
alexfikl Oct 31, 2025
d6a7a9d
feat(typing): improve types in mapper.graphviz
alexfikl Oct 31, 2025
a8b592f
feat(typing): improve types in mapper.distributor
alexfikl Oct 31, 2025
eb1b4a3
feat(typing): improve types in mapper.differentiator
alexfikl Oct 31, 2025
d87b3fc
feat(typing): improve types in mapper.flop_counter
alexfikl Nov 1, 2025
e5bb8af
feat(typing): improve types in mapper.optimize
alexfikl Nov 1, 2025
98ad668
feat(typing): improve types in mapper.stringifier
alexfikl Nov 1, 2025
e928623
fet(typing): add type annotations to mapper.c_code
alexfikl Nov 1, 2025
73f5b3e
feat: use more f-strings in mapper
alexfikl Nov 1, 2025
23a47b4
feat: improve some exception messages
alexfikl Nov 1, 2025
a344d10
feat: use C complex numbers in CCodeMapper
alexfikl Nov 2, 2025
84c8f22
chore: update baseline
alexfikl Nov 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9,010 changes: 916 additions & 8,094 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

27 changes: 16 additions & 11 deletions pymbolic/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,49 +23,54 @@
THE SOFTWARE.
"""

from typing import TYPE_CHECKING

import pymbolic.primitives as p


def sin(x):
if TYPE_CHECKING:
from pymbolic.typing import ArithmeticExpression


def sin(x: ArithmeticExpression) -> ArithmeticExpression:
return p.Call(p.Lookup(p.Variable("math"), "sin"), (x,))


def cos(x):
def cos(x: ArithmeticExpression) -> ArithmeticExpression:
return p.Call(p.Lookup(p.Variable("math"), "cos"), (x,))


def tan(x):
def tan(x: ArithmeticExpression) -> ArithmeticExpression:
return p.Call(p.Lookup(p.Variable("math"), "tan"), (x,))


def log(x):
def log(x: ArithmeticExpression) -> ArithmeticExpression:
return p.Call(p.Lookup(p.Variable("math"), "log"), (x,))


def exp(x):
def exp(x: ArithmeticExpression) -> ArithmeticExpression:
return p.Call(p.Lookup(p.Variable("math"), "exp"), (x,))


def sinh(x):
def sinh(x: ArithmeticExpression) -> ArithmeticExpression:
return p.Call(p.Lookup(p.Variable("math"), "sinh"), (x,))


def cosh(x):
def cosh(x: ArithmeticExpression) -> ArithmeticExpression:
return p.Call(p.Lookup(p.Variable("math"), "cosh"), (x,))


def tanh(x):
def tanh(x: ArithmeticExpression) -> ArithmeticExpression:
return p.Call(p.Lookup(p.Variable("math"), "tanh"), (x,))


def expm1(x):
def expm1(x: ArithmeticExpression) -> ArithmeticExpression:
return p.Call(p.Lookup(p.Variable("math"), "expm1"), (x,))


def fabs(x):
def fabs(x: ArithmeticExpression) -> ArithmeticExpression:
return p.Call(p.Lookup(p.Variable("math"), "fabs"), (x,))


def sign(x):
def sign(x: ArithmeticExpression) -> ArithmeticExpression:
return p.Call(p.Lookup(p.Variable("math"), "copysign"), (1, x,))
2 changes: 1 addition & 1 deletion pymbolic/geometric_algebra/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@


class MultiVectorVariable(Variable):
mapper_method = "map_multivector_variable"
mapper_method: ClassVar[str] = "map_multivector_variable"


# {{{ geometric calculus
Expand Down
86 changes: 42 additions & 44 deletions pymbolic/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,14 @@
"""

def handle_unsupported_expression(self,
expr: object, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
expr: p.ExpressionNode, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
"""Mapper method that is invoked for
:class:`pymbolic.ExpressionNode` subclasses for which a mapper
method does not exist in this mapper.
"""

raise UnsupportedExpressionError(
"{} cannot handle expressions of type {}".format(
type(self), type(expr)))
f"{type(self)} cannot handle expressions of type {type(expr)}")

def __call__(self,
expr: Expression, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
Expand Down Expand Up @@ -202,7 +201,7 @@
else:
return self.handle_unsupported_expression(expr, *args, **kwargs)
else:
return self.map_foreign(expr, *args, **kwargs)

Check warning on line 204 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

Check warning on line 204 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

Check warning on line 204 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

Check warning on line 204 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

rec = __call__

Expand All @@ -223,7 +222,7 @@
def map_algebraic_leaf(self,
expr: p.AlgebraicLeaf, /,
*args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_variable(self,
expr: p.Variable, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
Expand All @@ -248,110 +247,110 @@

def map_if(self,
expr: p.If, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_sum(self,
expr: p.Sum, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_product(self,
expr: p.Product, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_rational(self,
expr: Rational, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_quotient(self,
expr: p.Quotient, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_floor_div(self,
expr: p.FloorDiv, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_remainder(self,
expr: p.Remainder, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_power(self,
expr: p.Power, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_constant(self,
expr: object, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_comparison(self,
expr: p.Comparison, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_min(self,
expr: p.Min, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_max(self,
expr: p.Max, /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_list(self,
expr: list[Expression], /, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_tuple(self,
expr: tuple[Expression, ...], /,
*args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_numpy_array(self,
expr: NDArray[np.generic], /,
*args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_left_shift(self,
expr: p.LeftShift, /, *args: P.args, **kwargs: P.kwargs
) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_right_shift(self,
expr: p.RightShift, /, *args: P.args, **kwargs: P.kwargs
) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_bitwise_not(self,
expr: p.BitwiseNot, /, *args: P.args, **kwargs: P.kwargs
) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_bitwise_or(self,
expr: p.BitwiseOr, /, *args: P.args, **kwargs: P.kwargs
) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_bitwise_and(self,
expr: p.BitwiseAnd, /, *args: P.args, **kwargs: P.kwargs
) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_bitwise_xor(self,
expr: p.BitwiseXor, /, *args: P.args, **kwargs: P.kwargs
) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_logical_not(self,
expr: p.LogicalNot, /, *args: P.args, **kwargs: P.kwargs
) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_logical_or(self,
expr: p.LogicalOr, /, *args: P.args, **kwargs: P.kwargs
) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_logical_and(self,
expr: p.LogicalAnd, /, *args: P.args, **kwargs: P.kwargs
) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_nan(self,
expr: p.NaN, /, *args: P.args, **kwargs: P.kwargs
Expand All @@ -361,44 +360,44 @@
def map_wildcard(self,
expr: p.Wildcard, /, *args: P.args, **kwargs: P.kwargs
) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_dot_wildcard(self,
expr: p.DotWildcard, /, *args: P.args, **kwargs: P.kwargs
) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_star_wildcard(self, expr: p.StarWildcard, /,
*args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_function_symbol(self, expr: p.FunctionSymbol, /,
*args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_multivector(self,
expr: MultiVector[ArithmeticExpression], /,
*args: P.args, **kwargs: P.kwargs
) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

# def map_common_subexpression deliberately unimplemented to avoid breaking
# multiple inheritance with CSE-caching mappers

def map_substitution(self,
expr: p.Substitution, /,
*args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_derivative(self,
expr: p.Derivative, /,
*args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_slice(self,
expr: p.Slice, /,
*args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError
raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}")

def map_foreign(self,
expr: object, /,
Expand All @@ -421,8 +420,7 @@
return self.map_list(cast("list[Expression]", expr), *args, **kwargs)
else:
raise ValueError(
"{} encountered invalid foreign object: {}".format(
self.__class__, repr(expr)))
f"{type(self)} encountered invalid foreign object: {expr!r}")


class _NotInCache:
Expand Down Expand Up @@ -503,10 +501,7 @@
the current expression, and then call :meth:`combine` on a tuple of
results.

.. method:: combine(values)

Combine the mapped results of multiple expressions (given in *values*)
into a single result, often by summing or taking set unions.
.. automethod:: combine

The :class:`pymbolic.mapper.flop_counter.FlopCounter` is a very simple
example. (Look at its source for an idea of how to derive from
Expand All @@ -515,7 +510,10 @@
"""

def combine(self, values: Iterable[ResultT], /) -> ResultT:
raise NotImplementedError
"""Combine the mapped results of multiple expressions (given in *values*)
into a single result, often by summing or taking set unions.
"""
raise NotImplementedError(type(self).__name__)

@override
def map_call(self,
Expand Down Expand Up @@ -1652,7 +1650,7 @@
:class:`pymbolic.primitives.CommonSubexpression`,
subclasses should implement the following method:

.. method:: map_common_subexpression_uncached(expr)
.. automethod:: map_common_subexpression_uncached

This method deliberately does not support extra arguments in mapper
dispatch, to avoid spurious dependencies of the cache on these arguments.
Expand Down
24 changes: 18 additions & 6 deletions pymbolic/mapper/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,17 @@
THE SOFTWARE.
"""

from typing import TYPE_CHECKING

from typing_extensions import override

from pymbolic.mapper import CachedWalkMapper


if TYPE_CHECKING:
from pymbolic.typing import Expression


__doc__ = """
.. autoclass:: NodeCountMapper
.. autofunction:: get_num_nodes
Expand All @@ -35,7 +42,7 @@

# {{{ NodeCountMapper

class NodeCountMapper(CachedWalkMapper):
class NodeCountMapper(CachedWalkMapper[[]]):
"""
Counts the number of nodes in an expression tree. Nodes that occur
repeatedly as well as :class:`~pymbolic.primitives.CommonSubexpression`
Expand All @@ -46,18 +53,23 @@ class NodeCountMapper(CachedWalkMapper):
The number of nodes.
"""

count: int

def __init__(self) -> None:
super().__init__()
self.count = 0

def post_visit(self, expr) -> None:
@override
def post_visit(self, expr: object) -> None:
self.count += 1


def get_num_nodes(expr) -> int:
"""Returns the number of nodes in *expr*. Nodes that occur
repeatedly as well as :class:`~pymbolic.primitives.CommonSubexpression`
nodes are only counted once."""
def get_num_nodes(expr: Expression) -> int:
"""
:returns: the number of nodes in *expr*. Nodes that occur
repeatedly as well as :class:`~pymbolic.primitives.CommonSubexpression`
nodes are only counted once.
"""

ncm = NodeCountMapper()
ncm(expr)
Expand Down
Loading
Loading