From e2aadc314219e924233a63fc7ca8e926bae77867 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Thu, 12 Feb 2026 10:36:16 +0200 Subject: [PATCH 1/5] feat: add support for __matmul__ in ExpressionNode --- pymbolic/primitives.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index d7175ed..1664330 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -503,6 +503,18 @@ def __rpow__(self, other: ArithmeticExpression) -> Power: return Power(other, self) + def __matmul__(self, other: ArithmeticExpression) -> Matmul: + if not is_arithmetic_expression(other): + return NotImplemented + + return Matmul((self, other)) + + def __rmatmul__(self, other: ArithmeticExpression) -> Matmul: + if not is_arithmetic_expression(other): + return NotImplemented + + return Matmul((other, self)) + # }}} # {{{ shifts @@ -1330,6 +1342,16 @@ class Power(ExpressionNode): base: ArithmeticExpression exponent: ArithmeticExpression + +@expr_dataclass() +class Matmul(ExpressionNode): + """Matrix multiplication operator ``@``. + + .. autoattribute:: children + """ + + children: tuple[ArithmeticExpression, ...] + # }}} From 96a81daa39fab12e60e4d1d71671ae41aeb52db2 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Thu, 12 Feb 2026 10:37:25 +0200 Subject: [PATCH 2/5] feat: add support for new Matmul node in mappers --- pymbolic/mapper/__init__.py | 32 ++++++++++++++++++++++++++++++++ pymbolic/mapper/c_code.py | 4 ++++ pymbolic/mapper/cse_tagger.py | 1 + pymbolic/mapper/evaluator.py | 12 ++++++++++-- pymbolic/mapper/flop_counter.py | 12 +++++++++--- pymbolic/mapper/graphviz.py | 12 ++++++++++++ pymbolic/mapper/stringifier.py | 28 ++++++++++++++++++++++++++++ 7 files changed, 96 insertions(+), 5 deletions(-) diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index eb8074d..44d402d 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -284,6 +284,10 @@ def map_power(self, expr: p.Power, /, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}") + def map_matmul(self, + expr: p.Matmul, /, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}") + def map_constant(self, expr: object, /, *args: P.args, **kwargs: P.kwargs) -> ResultT: """Mapper method for constants. @@ -596,6 +600,11 @@ def map_power(self, self.rec(expr.base, *args, **kwargs), self.rec(expr.exponent, *args, **kwargs))) + @override + def map_matmul(self, + expr: p.Matmul, /, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) for child in expr.children) + @override def map_left_shift(self, expr: p.LeftShift, /, *args: P.args, **kwargs: P.kwargs) -> ResultT: @@ -954,6 +963,17 @@ def map_power(self, return expr return type(expr)(base, exponent) + @override + def map_matmul(self, + expr: p.Matmul, /, *args: P.args, **kwargs: P.kwargs + ) -> Expression: + children = [self.rec_arith(child, *args, **kwargs) for child in expr.children] + if all(child is orig_child for child, orig_child in + zip(children, expr.children, strict=True)): + return expr + + return type(expr)(tuple(children)) + @override def map_left_shift(self, expr: p.LeftShift, /, *args: P.args, **kwargs: P.kwargs @@ -1380,6 +1400,17 @@ def map_power(self, expr: p.Power, /, *args: P.args, **kwargs: P.kwargs) -> None self.post_visit(expr, *args, **kwargs) + @override + def map_matmul(self, expr: p.Matmul, /, + *args: P.args, **kwargs: P.kwargs) -> None: + if not self.visit(expr, *args, **kwargs): + return + + for child in expr.children: + self.rec(child, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + @override def map_tuple(self, expr: tuple[Expression, ...], /, *args: P.args, **kwargs: P.kwargs @@ -1640,6 +1671,7 @@ def map_constant(self, expr: object, /, map_floor_div = map_constant map_remainder = map_constant map_power = map_constant + map_matmul = map_constant map_left_shift = map_constant map_right_shift = map_constant diff --git a/pymbolic/mapper/c_code.py b/pymbolic/mapper/c_code.py index 05a65b2..4317aa1 100644 --- a/pymbolic/mapper/c_code.py +++ b/pymbolic/mapper/c_code.py @@ -175,6 +175,10 @@ def map_power(self, expr: p.Power, /, enclosing_prec: int) -> str: self.rec(expr.base, PREC_NONE), self.rec(expr.exponent, PREC_NONE)) + @override + def map_matmul(self, expr: p.Matmul, /, enclosing_prec: int) -> str: + raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}") + @override def map_floor_div(self, expr: p.FloorDiv, /, enclosing_prec: int) -> str: # Let's see how bad of an idea this is--sane people would only diff --git a/pymbolic/mapper/cse_tagger.py b/pymbolic/mapper/cse_tagger.py index 5c1574a..4d9a978 100644 --- a/pymbolic/mapper/cse_tagger.py +++ b/pymbolic/mapper/cse_tagger.py @@ -68,6 +68,7 @@ def _map_subexpr(self, expr: prim.ExpressionNode, /) -> Expression: map_floor_div: Callable[[Self, prim.FloorDiv], Expression] = _map_subexpr map_remainder: Callable[[Self, prim.Remainder], Expression] = _map_subexpr map_power: Callable[[Self, prim.Power], Expression] = _map_subexpr + map_matmul: Callable[[Self, prim.Matmul], Expression] = _map_subexpr map_left_shift: Callable[[Self, prim.LeftShift], Expression] = _map_subexpr map_right_shift: Callable[[Self, prim.RightShift], Expression] = _map_subexpr diff --git a/pymbolic/mapper/evaluator.py b/pymbolic/mapper/evaluator.py index 9ffd135..35cadcd 100644 --- a/pymbolic/mapper/evaluator.py +++ b/pymbolic/mapper/evaluator.py @@ -132,8 +132,11 @@ def map_sum(self, expr: p.Sum, /) -> ResultT: @override def map_product(self, expr: p.Product, /) -> ResultT: - from pytools import product - return cast("ResultT", product(self.rec(child) for child in expr.children)) + if not expr.children: + return cast("ResultT", 1) + + from operator import mul + return reduce(mul, (self.rec(ch) for ch in expr.children)) @override def map_quotient(self, expr: p.Quotient, /) -> ResultT: @@ -151,6 +154,11 @@ def map_remainder(self, expr: p.Remainder, /) -> ResultT: def map_power(self, expr: p.Power, /) -> ResultT: return self.rec(expr.base) ** self.rec(expr.exponent) + @override + def map_matmul(self, expr: p.Matmul, /) -> ResultT: + from operator import matmul + return reduce(matmul, (self.rec(ch) for ch in expr.children)) + @override def map_left_shift(self, expr: p.LeftShift, /) -> ResultT: return self.rec(expr.shiftee) << self.rec(expr.shift) diff --git a/pymbolic/mapper/flop_counter.py b/pymbolic/mapper/flop_counter.py index 0d7ecc0..83732c2 100644 --- a/pymbolic/mapper/flop_counter.py +++ b/pymbolic/mapper/flop_counter.py @@ -49,14 +49,20 @@ def map_constant(self, expr: object) -> ArithmeticExpression: def map_variable(self, expr: p.Variable) -> ArithmeticExpression: return 0 - @override - def map_sum(self, expr: p.Sum | p.Product) -> ArithmeticExpression: + def _count_children( + self, expr: p.Sum | p.Product | p.Matmul + ) -> ArithmeticExpression: if expr.children: return len(expr.children) - 1 + sum(self.rec(ch) for ch in expr.children) else: return 0 - map_product: Callable[[Self, p.Product], ArithmeticExpression] = map_sum + map_sum: Callable[[Self, p.Sum], ArithmeticExpression] = _count_children + map_product: Callable[[Self, p.Product], ArithmeticExpression] = _count_children + + @override + def map_matmul(self, expr: p.Matmul) -> ArithmeticExpression: + raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr)}") @override def map_quotient(self, expr: p.Quotient | p.FloorDiv) -> ArithmeticExpression: diff --git a/pymbolic/mapper/graphviz.py b/pymbolic/mapper/graphviz.py index b33c7f1..4904c00 100644 --- a/pymbolic/mapper/graphviz.py +++ b/pymbolic/mapper/graphviz.py @@ -143,6 +143,18 @@ def map_product(self, expr: prim.Product, /) -> None: self.post_visit(expr) + @override + def map_matmul(self, expr: prim.Matmul, /) -> None: + sid = self.get_id(expr) + self.lines.append(f'{sid} [label="@",shape=circle];') + if not self.visit(expr, node_printed=True): + return + + for child in expr.children: + self.rec(child) + + self.post_visit(expr) + @override def map_variable(self, expr: prim.Variable, /) -> None: # Shared nodes for variables do not make for pretty graphs. diff --git a/pymbolic/mapper/stringifier.py b/pymbolic/mapper/stringifier.py index c85fd48..a57d346 100644 --- a/pymbolic/mapper/stringifier.py +++ b/pymbolic/mapper/stringifier.py @@ -463,6 +463,23 @@ def map_power( PREC_POWER, ) + @override + def map_matmul( + self, expr: p.Matmul, /, enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: + return self.parenthesize_if_needed( + self.join_rec_with_parens_around_types( + " @ ", + expr.children, + PREC_PRODUCT, + self.multiplicative_primitives, + *args, + **kwargs, + ), + enclosing_prec, + PREC_PRODUCT, + ) + @override def map_left_shift( self, expr: p.LeftShift, /, @@ -1133,6 +1150,17 @@ def map_max( self.join_rec(", ", expr.children, PREC_NONE, *args, **kwargs), ) + @override + def map_matmul( + self, expr: p.Matmul, /, + enclosing_prec: int, *args: P.args, **kwargs: P.kwargs + ) -> str: + return self.parenthesize_if_needed( + self.join_rec(r" \cdot ", expr.children, PREC_PRODUCT, *args, **kwargs), + enclosing_prec, + PREC_PRODUCT, + ) + @override def map_floor_div( self, expr: p.FloorDiv, /, From 8c22ba007f51012b60e9e9ed1934521b1988286c Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Thu, 12 Feb 2026 10:41:44 +0200 Subject: [PATCH 3/5] feat: remove some unneeded imports --- pymbolic/mapper/c_code.py | 16 ++++++---------- pymbolic/mapper/coefficient.py | 6 ++---- pymbolic/mapper/collector.py | 10 ++++------ pymbolic/mapper/constant_folder.py | 8 ++------ pymbolic/mapper/distributor.py | 4 +--- pymbolic/mapper/flattener.py | 10 ++-------- pymbolic/mapper/graphviz.py | 6 ++---- pymbolic/mapper/substitutor.py | 4 ++-- pymbolic/mapper/unifier.py | 28 +++++++++++----------------- 9 files changed, 32 insertions(+), 60 deletions(-) diff --git a/pymbolic/mapper/c_code.py b/pymbolic/mapper/c_code.py index 4317aa1..a71c0a7 100644 --- a/pymbolic/mapper/c_code.py +++ b/pymbolic/mapper/c_code.py @@ -37,6 +37,8 @@ PREC_LOGICAL_AND, PREC_LOGICAL_OR, PREC_NONE, + PREC_POWER, + PREC_PRODUCT, PREC_UNARY, SimplifyingSortingStringifyMapper, ) @@ -126,7 +128,6 @@ def copy_with_mapped_cses( @override def map_product(self, expr: p.Product, /, enclosing_prec: int) -> str: - from pymbolic.mapper.stringifier import PREC_PRODUCT return self.parenthesize_if_needed( # Spaces prevent '**z' (times dereference z), which # is hard to read. @@ -161,14 +162,12 @@ def map_call(self, expr: p.Call, /, enclosing_prec: int) -> str: @override def map_power(self, expr: p.Power, /, enclosing_prec: int) -> str: - from pymbolic.mapper.stringifier import PREC_NONE - from pymbolic.primitives import is_constant, is_zero - if is_constant(expr.exponent): - if is_zero(expr.exponent): + if p.is_constant(expr.exponent): + if p.is_zero(expr.exponent): return "1" - elif is_zero(expr.exponent - 1): + elif p.is_zero(expr.exponent - 1): return self.rec(expr.base, enclosing_prec) - elif is_zero(expr.exponent - 2): + elif p.is_zero(expr.exponent - 2): return self.rec(expr.base*expr.base, enclosing_prec) return self.format("pow(%s, %s)", @@ -184,7 +183,6 @@ def map_floor_div(self, expr: p.FloorDiv, /, enclosing_prec: int) -> str: # Let's see how bad of an idea this is--sane people would only # apply this to integers, right? - from pymbolic.mapper.stringifier import PREC_POWER, PREC_PRODUCT return self.format("(%s/%s)", self.rec(expr.numerator, PREC_PRODUCT), self.rec(expr.denominator, PREC_POWER)) # analogous to ^{-1} @@ -212,7 +210,6 @@ def map_common_subexpression( try: cse_name = self.cse_to_name[expr.child] except KeyError: - from pymbolic.mapper.stringifier import PREC_NONE cse_str = self.rec(expr.child, PREC_NONE) if expr.prefix is not None: @@ -245,7 +242,6 @@ def generate_cse_names() -> Iterator[str]: @override def map_if(self, expr: p.If, /, enclosing_prec: int) -> str: - from pymbolic.mapper.stringifier import PREC_NONE return self.format("(%s ? %s : %s)", self.rec(expr.condition, PREC_NONE), self.rec(expr.then, PREC_NONE), diff --git a/pymbolic/mapper/coefficient.py b/pymbolic/mapper/coefficient.py index bfe44cd..bb49976 100644 --- a/pymbolic/mapper/coefficient.py +++ b/pymbolic/mapper/coefficient.py @@ -85,7 +85,6 @@ def map_product(self, expr: p.Product, /) -> CoeffsT: @override def map_quotient(self, expr: p.Quotient, /) -> CoeffsT: - from pymbolic.primitives import Quotient d_num = dict(self.rec(expr.numerator)) d_den = self.rec(expr.denominator) # d_den should look like {1: k} @@ -93,7 +92,7 @@ def map_quotient(self, expr: p.Quotient, /) -> CoeffsT: raise RuntimeError("nonlinear expression") val = d_den[1] for k in d_num: - d_num[k] = p.flattened_product((d_num[k], Quotient(1, val))) + d_num[k] = p.flattened_product((d_num[k], p.Quotient(1, val))) return d_num @override @@ -111,8 +110,7 @@ def map_power(self, expr: p.Power, /) -> CoeffsT: @override def map_constant(self, expr: object, /) -> CoeffsT: assert p.is_arithmetic_expression(expr) - from pymbolic.primitives import is_zero - return {} if is_zero(expr) else {1: expr} + return {} if p.is_zero(expr) else {1: expr} @override def map_variable(self, expr: p.Variable, /) -> CoeffsT: diff --git a/pymbolic/mapper/collector.py b/pymbolic/mapper/collector.py index 3a61ea1..c977fe7 100644 --- a/pymbolic/mapper/collector.py +++ b/pymbolic/mapper/collector.py @@ -74,24 +74,22 @@ def split_term(self, mul_term: Expression) -> tuple[ The argument `product' has to be fully expanded already. """ - from pymbolic.primitives import AlgebraicLeaf, Power, Product - def base(term: Expression) -> ArithmeticExpression: - if isinstance(term, Power): + if isinstance(term, p.Power): return term.base else: assert p.is_arithmetic_expression(term) return term def exponent(term: Expression) -> ArithmeticExpression: - if isinstance(term, Power): + if isinstance(term, p.Power): return term.exponent else: return 1 - if isinstance(mul_term, Product): + if isinstance(mul_term, p.Product): terms: Sequence[Expression] = mul_term.children - elif (isinstance(mul_term, Power | AlgebraicLeaf) + elif (isinstance(mul_term, (p.Power, p.AlgebraicLeaf)) or not bool(self.get_dependencies(mul_term))): terms = [mul_term] else: diff --git a/pymbolic/mapper/constant_folder.py b/pymbolic/mapper/constant_folder.py index e2acae3..e584e23 100644 --- a/pymbolic/mapper/constant_folder.py +++ b/pymbolic/mapper/constant_folder.py @@ -108,9 +108,7 @@ def fold(self, def map_sum(self, expr: prim.Sum, /) -> Expression: import operator - from pymbolic.primitives import flattened_sum - - return self.fold(expr, operator.add, flattened_sum) + return self.fold(expr, operator.add, prim.flattened_sum) class CommutativeConstantFoldingMapperBase(ConstantFoldingMapperBase): @@ -118,9 +116,7 @@ class CommutativeConstantFoldingMapperBase(ConstantFoldingMapperBase): def map_product(self, expr: prim.Product, /) -> Expression: import operator - from pymbolic.primitives import flattened_product - - return self.fold(expr, operator.mul, flattened_product) + return self.fold(expr, operator.mul, prim.flattened_product) class ConstantFoldingMapper( diff --git a/pymbolic/mapper/distributor.py b/pymbolic/mapper/distributor.py index 2307d8d..d1bb9e0 100644 --- a/pymbolic/mapper/distributor.py +++ b/pymbolic/mapper/distributor.py @@ -138,8 +138,6 @@ def map_quotient(self, expr: p.Quotient, /) -> Expression: @override def map_power(self, expr: p.Power, /) -> Expression: - from pymbolic.primitives import Sum - newbase = self.rec(expr.base) if isinstance(newbase, p.Product): return self.rec(p.flattened_product([ @@ -147,7 +145,7 @@ def map_power(self, expr: p.Power, /) -> Expression: ])) if isinstance(expr.exponent, int): - if isinstance(newbase, Sum): + if isinstance(newbase, p.Sum): return self.rec(p.flattened_product(expr.exponent*(newbase,))) else: return IdentityMapper.map_power(self, expr) diff --git a/pymbolic/mapper/flattener.py b/pymbolic/mapper/flattener.py index 43bfde9..621692e 100644 --- a/pymbolic/mapper/flattener.py +++ b/pymbolic/mapper/flattener.py @@ -68,17 +68,11 @@ def is_expr_integer_valued(self, expr: Expression, /) -> bool: @override def map_sum(self, expr: p.Sum, /) -> Expression: - from pymbolic.primitives import flattened_sum - return flattened_sum([ - self.rec_arith(ch) - for ch in expr.children]) + return p.flattened_sum([self.rec_arith(ch) for ch in expr.children]) @override def map_product(self, expr: p.Product, /) -> Expression: - from pymbolic.primitives import flattened_product - return flattened_product([ - self.rec_arith(ch) - for ch in expr.children]) + return p.flattened_product([self.rec_arith(ch) for ch in expr.children]) @override def map_quotient(self, expr: p.Quotient, /) -> Expression: diff --git a/pymbolic/mapper/graphviz.py b/pymbolic/mapper/graphviz.py index 4904c00..cb69afa 100644 --- a/pymbolic/mapper/graphviz.py +++ b/pymbolic/mapper/graphviz.py @@ -31,13 +31,13 @@ from typing_extensions import Self, override +import pymbolic.primitives as prim from pymbolic.mapper import WalkMapper if TYPE_CHECKING: from collections.abc import Callable, Hashable - import pymbolic.primitives as prim from pymbolic.geometric_algebra.primitives import Nabla, NablaComponent @@ -194,9 +194,7 @@ def map_constant(self, expr: object) -> None: @override def map_call(self, expr: prim.Call) -> None: - from pymbolic.primitives import Variable - - if not isinstance(expr.function, Variable): + if not isinstance(expr.function, prim.Variable): return super().map_call(expr) sid = self.get_id(expr) diff --git a/pymbolic/mapper/substitutor.py b/pymbolic/mapper/substitutor.py index c0cb4b2..32d5bdd 100644 --- a/pymbolic/mapper/substitutor.py +++ b/pymbolic/mapper/substitutor.py @@ -102,13 +102,13 @@ def make_subst_func( # e.g. https://github.com/python/typing/issues/445 variable_assignments: optype.CanGetitem[Any, Expression], ) -> Callable[[AlgebraicLeaf], Expression | None]: - import pymbolic.primitives as primitives + from pymbolic.primitives import Variable def subst_func(var: AlgebraicLeaf) -> Expression | None: try: return variable_assignments[var] except KeyError: - if isinstance(var, primitives.Variable): + if isinstance(var, Variable): try: return variable_assignments[var.name] except KeyError: diff --git a/pymbolic/mapper/unifier.py b/pymbolic/mapper/unifier.py index 0395cfc..2b0b490 100644 --- a/pymbolic/mapper/unifier.py +++ b/pymbolic/mapper/unifier.py @@ -25,19 +25,15 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Collection, Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, final +from typing import final from typing_extensions import Self, override +import pymbolic.primitives as p from pymbolic.mapper import Mapper -from pymbolic.primitives import Variable from pymbolic.typing import ArithmeticExpression, Expression -if TYPE_CHECKING: - import pymbolic.primitives as p - - def unify_map( map1: Mapping[str, Expression], map2: Mapping[str, Expression] @@ -78,9 +74,9 @@ def __init__(self, rmap = {} for lhs, rhs in equations: - if isinstance(lhs, Variable): + if isinstance(lhs, p.Variable): lmap[lhs.name] = rhs - if isinstance(rhs, Variable): + if isinstance(rhs, p.Variable): rmap[rhs.name] = lhs self.lmap = lmap @@ -173,8 +169,8 @@ def unification_record_from_equation( return None - lhs_is_var = isinstance(lhs, Variable) - rhs_is_var = isinstance(rhs, Variable) + lhs_is_var = isinstance(lhs, p.Variable) + rhs_is_var = isinstance(rhs, p.Variable) if self.force_var_match and not (lhs_is_var or rhs_is_var): return None @@ -212,7 +208,7 @@ def map_variable( if new_uni_record is None: # Check if the variables match literally--that's ok, too. - if (isinstance(other, Variable) and other.name == expr.name + if (isinstance(other, p.Variable) and other.name == expr.name and ( self.lhs_mapping_candidates is None or expr.name not in self.lhs_mapping_candidates)): @@ -466,10 +462,10 @@ def map_commut_assoc( # Partition expr into terms that are plain (free) variables and those # that are not. - plain_var_candidates: list[Variable] = [] + plain_var_candidates: list[p.Variable] = [] non_var_children: list[Expression] = [] for child in expr.children: - if (isinstance(child, Variable) + if (isinstance(child, p.Variable) and ( self.lhs_mapping_candidates is None or child.name in self.lhs_mapping_candidates)): @@ -566,8 +562,7 @@ def map_sum(self, # pyright: ignore[reportIncompatibleMethodOverride] expr: p.Sum, other: Expression, urecs: Sequence[UnificationRecord]) -> Sequence[UnificationRecord]: - from pymbolic.primitives import flattened_sum - return list(self.map_commut_assoc(expr, other, urecs, flattened_sum)) + return list(self.map_commut_assoc(expr, other, urecs, p.flattened_sum)) @override def map_product( @@ -575,5 +570,4 @@ def map_product( expr: p.Product, other: Expression, urecs: Sequence[UnificationRecord]) -> Sequence[UnificationRecord]: - from pymbolic.primitives import flattened_product - return list(self.map_commut_assoc(expr, other, urecs, flattened_product)) + return list(self.map_commut_assoc(expr, other, urecs, p.flattened_product)) From 9dd86ca1e81a2a3d14c98aa6541ff196f6c0afcd Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Thu, 12 Feb 2026 10:43:37 +0200 Subject: [PATCH 4/5] test: add simple test for Matmul node --- test/test_pymbolic.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index df1cf13..9efe490 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -1156,6 +1156,34 @@ def test_subscript() -> None: # }}} +# {{{ test_matmul + +def test_matmul() -> None: + x = prim.Variable("x") + y = prim.Variable("y") + + r = x @ y + assert isinstance(r, prim.Matmul) + assert str(r) == "x @ y" + + from pymbolic.mapper.stringifier import LaTeXMapper + + assert LaTeXMapper()(r) == r"x \cdot y" + + import numpy as np + + rng = np.random.default_rng(seed=42) + xarr = rng.random((6, 10)) + yarr = rng.random((10, 8)) + + from pymbolic.mapper.evaluator import EvaluationMapper + + assert np.allclose(EvaluationMapper({"x": xarr, "y": yarr})(r), xarr @ yarr) + + +# }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: From 7f9907d104dff38ce50756f92f994532e49a47ad Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Thu, 12 Feb 2026 10:45:43 +0200 Subject: [PATCH 5/5] chore: update baseline --- .basedpyright/baseline.json | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 923fbaa..30a31c7 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -6359,6 +6359,14 @@ "lineCount": 1 } }, + { + "code": "reportUnannotatedClassAttribute", + "range": { + "startColumn": 4, + "endColumn": 14, + "lineCount": 1 + } + }, { "code": "reportUnannotatedClassAttribute", "range": { @@ -6903,6 +6911,14 @@ "lineCount": 1 } }, + { + "code": "reportAny", + "range": { + "startColumn": 15, + "endColumn": 66, + "lineCount": 1 + } + }, { "code": "reportOperatorIssue", "range": { @@ -6967,6 +6983,14 @@ "lineCount": 1 } }, + { + "code": "reportAny", + "range": { + "startColumn": 15, + "endColumn": 69, + "lineCount": 1 + } + }, { "code": "reportOperatorIssue", "range": {