diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 779d45e4..7d4ac86f 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -1584,8 +1584,8 @@ { "code": "reportAny", "range": { - "startColumn": 30, - "endColumn": 31, + "startColumn": 27, + "endColumn": 28, "lineCount": 1 } }, @@ -2955,6 +2955,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 20, + "endColumn": 24, + "lineCount": 1 + } + }, { "code": "reportUnknownVariableType", "range": { @@ -3039,7 +3047,7 @@ "code": "reportUnknownMemberType", "range": { "startColumn": 24, - "endColumn": 45, + "endColumn": 32, "lineCount": 1 } }, @@ -3051,6 +3059,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 16, + "endColumn": 20, + "lineCount": 1 + } + }, { "code": "reportUnknownVariableType", "range": { @@ -3135,7 +3151,7 @@ "code": "reportUnknownMemberType", "range": { "startColumn": 20, - "endColumn": 44, + "endColumn": 31, "lineCount": 1 } }, @@ -3148,34 +3164,34 @@ } }, { - "code": "reportUnknownVariableType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 12, - "endColumn": 18, + "startColumn": 8, + "endColumn": 20, "lineCount": 1 } }, { - "code": "reportUnknownMemberType", + "code": "reportUnknownVariableType", "range": { - "startColumn": 22, - "endColumn": 35, + "startColumn": 49, + "endColumn": 55, "lineCount": 1 } }, { - "code": "reportUnknownArgumentType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 36, - "endColumn": 42, + "startColumn": 59, + "endColumn": 72, "lineCount": 1 } }, { - "code": "reportUnknownMemberType", + "code": "reportUnknownArgumentType", "range": { - "startColumn": 12, - "endColumn": 24, + "startColumn": 73, + "endColumn": 79, "lineCount": 1 } }, @@ -5909,19 +5925,11 @@ "lineCount": 1 } }, - { - "code": "reportAny", - "range": { - "startColumn": 16, - "endColumn": 22, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { "startColumn": 23, - "endColumn": 29, + "endColumn": 52, "lineCount": 1 } }, diff --git a/pymbolic/algorithm.py b/pymbolic/algorithm.py index 43757f3b..fa8e96af 100644 --- a/pymbolic/algorithm.py +++ b/pymbolic/algorithm.py @@ -46,6 +46,8 @@ from typing import TYPE_CHECKING, Protocol, TypeVar, cast, overload from warnings import warn +from typing_extensions import Self + from pytools import MovedFunctionDeprecationWrapper, memoize @@ -62,7 +64,7 @@ # {{{ integer powers class _CanMultiply(Protocol): - def __mul__(self: _CanMultiplyT, other: _CanMultiplyT, /) -> _CanMultiplyT: ... + def __mul__(self, other: Self, /) -> Self: ... _CanMultiplyT = TypeVar("_CanMultiplyT", bound=_CanMultiply) @@ -112,7 +114,7 @@ def extended_euclidean(q, r): `Wikipedia article on the Euclidean algorithm `__. """ - import pymbolic.traits as traits + from pymbolic import traits # see [Davenport], Appendix, p. 214 @@ -379,7 +381,7 @@ def reduced_row_echelon_form( rhs[i], rhs[nonz_row] = \ (rhs[nonz_row].copy(), rhs[i].copy()) - for u in range(0, m): + for u in range(m): if u == i: continue if not mat[u, j]: diff --git a/pymbolic/cse.py b/pymbolic/cse.py index 5f3efe65..cb7cc83b 100644 --- a/pymbolic/cse.py +++ b/pymbolic/cse.py @@ -32,7 +32,13 @@ if TYPE_CHECKING: - from collections.abc import Callable, Hashable, Iterable, Sequence, Set + from collections.abc import ( + Callable, + Hashable, + Iterable, + Sequence, + Set as AbstractSet, + ) from pymbolic.typing import Expression @@ -93,12 +99,12 @@ def map_common_subexpression(self, expr: prim.CommonSubexpression, /, class CSEMapper(IdentityMapper[[]]): - to_eliminate: Set[Hashable] + to_eliminate: AbstractSet[Hashable] get_key: Callable[[object], Hashable] canonical_subexprs: dict[Hashable, Expression] def __init__(self, - to_eliminate: Set[Hashable], + to_eliminate: AbstractSet[Hashable], get_key: Callable[[object], Hashable]) -> None: self.to_eliminate = to_eliminate self.get_key = get_key @@ -126,8 +132,7 @@ def map_sum(self, ) -> Expression: key = self.get_key(expr) if key in self.to_eliminate: - result = self.get_cse(expr, key) - return result + return self.get_cse(expr, key) else: return getattr(IdentityMapper, expr.mapper_method)(self, expr) @@ -177,5 +182,4 @@ def tag_common_subexpressions(exprs: Iterable[Expression]) -> Sequence[Expressio if count > 1} cse_mapper = CSEMapper(to_eliminate, get_key) - result = [cse_mapper(expr) for expr in exprs] - return result + return [cse_mapper(expr) for expr in exprs] diff --git a/pymbolic/geometric_algebra/__init__.py b/pymbolic/geometric_algebra/__init__.py index 87cfa75c..51a214c1 100644 --- a/pymbolic/geometric_algebra/__init__.py +++ b/pymbolic/geometric_algebra/__init__.py @@ -42,8 +42,7 @@ import numpy as np from typing_extensions import Self, override -import pytools.obj_array as obj_array -from pytools import memoize, memoize_method +from pytools import memoize, memoize_method, obj_array from pytools.obj_array import ObjectArray, ObjectArray1D, ShapeT from pymbolic.primitives import expr_dataclass, is_zero @@ -145,18 +144,18 @@ class _HasArithmetic(Protocol): - def __neg__(self: CoeffT) -> CoeffT: ... - def __abs__(self: CoeffT) -> CoeffT: ... + def __neg__(self) -> Self: ... + def __abs__(self) -> Self: ... - def __add__(self: CoeffT, other: CoeffT, /) -> CoeffT: ... - def __radd__(self: CoeffT, other: int, /) -> CoeffT: ... + def __add__(self, other: Self, /) -> Self: ... + def __radd__(self, other: int, /) -> Self: ... - def __sub__(self: CoeffT, other: CoeffT, /) -> CoeffT: ... + def __sub__(self, other: Self, /) -> Self: ... - def __mul__(self: CoeffT, other: CoeffT, /) -> CoeffT: ... - def __rmul__(self: CoeffT, other: int, /) -> CoeffT: ... + def __mul__(self, other: Self, /) -> Self: ... + def __rmul__(self, other: int, /) -> Self: ... - def __pow__(self: CoeffT, other: CoeffT, /) -> CoeffT: ... + def __pow__(self, other: Self, /) -> Self: ... CoeffT = TypeVar("CoeffT", bound=_HasArithmetic) @@ -1103,10 +1102,10 @@ def project(self, r: int) -> MultiVector[CoeffT]: Often written :math:`\langle A\rangle_r`. """ - new_data: dict[int, CoeffT] = {} - for bits, coeff in self.data.items(): - if bits.bit_count() == r: - new_data[bits] = coeff + new_data: dict[int, CoeffT] = { + bits: coeff + for bits, coeff in self.data.items() + if bits.bit_count() == r} return MultiVector(new_data, self.space) @@ -1161,19 +1160,19 @@ def get_pure_grade(self) -> int | None: def odd(self) -> MultiVector[CoeffT]: """Extract the odd-grade blades.""" - new_data: dict[int, CoeffT] = {} - for bits, coeff in self.data.items(): - if bits.bit_count() % 2: - new_data[bits] = coeff + new_data: dict[int, CoeffT] = { + bits: coeff + for bits, coeff in self.data.items() + if bits.bit_count() % 2} return MultiVector(new_data, self.space) def even(self) -> MultiVector[CoeffT]: """Extract the even-grade blades.""" - new_data: dict[int, CoeffT] = {} - for bits, coeff in self.data.items(): - if bits.bit_count() % 2 == 0: - new_data[bits] = coeff + new_data: dict[int, CoeffT] = { + bits: coeff + for bits, coeff in self.data.items() + if bits.bit_count() % 2 == 0} return MultiVector(new_data, self.space) diff --git a/pymbolic/geometric_algebra/mapper.py b/pymbolic/geometric_algebra/mapper.py index 138df5ec..eccd0e4d 100644 --- a/pymbolic/geometric_algebra/mapper.py +++ b/pymbolic/geometric_algebra/mapper.py @@ -27,12 +27,12 @@ # Consider yourself warned. from abc import ABC, abstractmethod -from collections.abc import Callable, Set +from collections.abc import Callable, Set as AbstractSet from typing import TYPE_CHECKING, ClassVar from typing_extensions import Self, override -import pytools.obj_array as obj_array +from pytools import obj_array import pymbolic.geometric_algebra.primitives as gp from pymbolic.geometric_algebra import MultiVector @@ -93,12 +93,12 @@ def map_derivative_source( class Collector(CollectorBase[CollectedT, P]): def map_nabla(self, expr: gp.Nabla, /, *args: P.args, **kwargs: P.kwargs - ) -> Set[CollectedT]: + ) -> AbstractSet[CollectedT]: return set() def map_nabla_component(self, expr: gp.NablaComponent, /, *args: P.args, **kwargs: P.kwargs - ) -> Set[CollectedT]: + ) -> AbstractSet[CollectedT]: return set() @@ -227,25 +227,26 @@ def map_derivative_source( # {{{ derivative binder class DerivativeSourceAndNablaComponentCollector( - CachedMapper[Set[ArithmeticExpression], []], + CachedMapper[AbstractSet[ArithmeticExpression], []], Collector[ArithmeticExpression, []]): def __init__(self) -> None: Collector.__init__(self) CachedMapper.__init__(self) @override - def map_nabla(self, expr: gp.Nabla, /) -> Set[ArithmeticExpression]: + def map_nabla(self, expr: gp.Nabla, /) -> AbstractSet[ArithmeticExpression]: raise RuntimeError( f"{type(self).__name__} must be invoked after " "Dimensionalizer -- Nabla expression found and not allowed") @override def map_nabla_component( - self, expr: gp.NablaComponent, /) -> Set[ArithmeticExpression]: + self, expr: gp.NablaComponent, /) -> AbstractSet[ArithmeticExpression]: return {expr} def map_derivative_source( - self, expr: gp.DerivativeSource, /) -> Set[ArithmeticExpression]: + self, expr: gp.DerivativeSource, + /) -> AbstractSet[ArithmeticExpression]: return {expr} | self.rec(expr.operand) diff --git a/pymbolic/geometric_algebra/primitives.py b/pymbolic/geometric_algebra/primitives.py index 8e5a17bf..9e4f1bd2 100644 --- a/pymbolic/geometric_algebra/primitives.py +++ b/pymbolic/geometric_algebra/primitives.py @@ -31,7 +31,7 @@ from typing_extensions import override -import pytools.obj_array as obj_array +from pytools import obj_array from pymbolic.geometric_algebra import MultiVector from pymbolic.primitives import ExpressionNode, Variable, expr_dataclass diff --git a/pymbolic/imperative/statement.py b/pymbolic/imperative/statement.py index 505ac11b..6b288fa4 100644 --- a/pymbolic/imperative/statement.py +++ b/pymbolic/imperative/statement.py @@ -34,7 +34,7 @@ if TYPE_CHECKING: - from collections.abc import Callable, Set + from collections.abc import Callable, Set as AbstractSet from pymbolic.primitives import Variable @@ -45,7 +45,7 @@ class BasicStatementLike(Protocol): @property def id(self) -> str: ... @property - def depends_on(self) -> Set[str]: ... + def depends_on(self) -> AbstractSet[str]: ... def copy(self, **kwargs: object) -> Self: ... @@ -54,9 +54,9 @@ def copy(self, **kwargs: object) -> Self: ... class StatementLike(BasicStatementLike, Protocol): - def get_written_variables(self) -> Set[str]: ... + def get_written_variables(self) -> AbstractSet[str]: ... - def get_read_variables(self) -> Set[str]: ... + def get_read_variables(self) -> AbstractSet[str]: ... def map_expressions(self, mapper: Callable[[Expression], Expression], @@ -81,7 +81,7 @@ class Statement: A string, a unique identifier for this instruction. """ - depends_on: Set[str] + depends_on: AbstractSet[str] """A :class:`frozenset` of instruction ids that are required to be executed within this execution context before this instruction can be executed.""" @@ -90,13 +90,13 @@ def __post_init__(self): if isinstance(self.id, str): object.__setattr__(self, "id", intern(self.id)) - def get_written_variables(self) -> Set[str]: + def get_written_variables(self) -> AbstractSet[str]: """Returns a :class:`frozenset` of variables being written by this instruction. """ return frozenset() - def get_read_variables(self) -> Set[str]: + def get_read_variables(self) -> AbstractSet[str]: """Returns a :class:`frozenset` of variables being read by this instruction. """ @@ -151,7 +151,7 @@ def __str__(self): + self._condition_printing_suffix()) @override - def get_read_variables(self) -> Set[str]: + def get_read_variables(self) -> AbstractSet[str]: dep_mapper = self.get_dependency_mapper() return ( super().get_read_variables() @@ -185,16 +185,13 @@ def get_written_variables(self): raise TypeError("unexpected type of LHS") @override - def get_read_variables(self) -> Set[str]: - result = super().get_read_variables() + def get_read_variables(self) -> AbstractSet[str]: get_deps = self.get_dependency_mapper() def get_vars(expr: Expression): return frozenset(cast("Variable", dep).name for dep in get_deps(expr)) - result = get_vars(self.rhs) | get_vars(self.lhs) - - return result + return get_vars(self.rhs) | get_vars(self.lhs) @override def map_expressions(self, @@ -209,11 +206,10 @@ def map_expressions(self, @override def __str__(self): - result = "{assignee} <- {expr}".format( + return "{assignee} <- {expr}".format( assignee=str(self.lhs), expr=str(self.rhs),) - return result # }}} diff --git a/pymbolic/imperative/transform.py b/pymbolic/imperative/transform.py index e80719e2..3c4b16e3 100644 --- a/pymbolic/imperative/transform.py +++ b/pymbolic/imperative/transform.py @@ -60,12 +60,11 @@ def fuse_statement_streams_with_unique_ids( b_unique_statements.append( stmtb.copy(id=new_id)) - for stmtb in b_unique_statements: - new_statements.append( - stmtb.copy( + new_statements.extend(stmtb.copy( depends_on=frozenset( old_b_id_to_new_b_id[dep_id] - for dep_id in stmtb.depends_on))) + for dep_id in stmtb.depends_on)) + for stmtb in b_unique_statements) return new_statements, old_b_id_to_new_b_id diff --git a/pymbolic/imperative/utils.py b/pymbolic/imperative/utils.py index d9afee70..9825ea57 100644 --- a/pymbolic/imperative/utils.py +++ b/pymbolic/imperative/utils.py @@ -109,27 +109,26 @@ def get_node_attrs(stmt): while True: changed_something = False - for stmt_1 in dep_graph: + for stmt_1, deps in dep_graph.items(): for stmt_2 in dep_graph.get(stmt_1, set()).copy(): for stmt_3 in dep_graph.get(stmt_2, set()).copy(): if stmt_3 not in dep_graph.get(stmt_1, set()): changed_something = True - dep_graph[stmt_1].add(stmt_3) + deps.add(stmt_3) if not changed_something: break - for stmt_1 in dep_graph: + for stmt_1, deps in dep_graph.items(): for stmt_2 in dep_graph.get(stmt_1, set()).copy(): for stmt_3 in dep_graph.get(stmt_2, set()).copy(): if stmt_3 in dep_graph.get(stmt_1, set()): - dep_graph[stmt_1].remove(stmt_3) + deps.remove(stmt_3) # }}} for stmt_1 in dep_graph: - for stmt_2 in dep_graph.get(stmt_1, set()): - lines.append(f"{stmt_1} -> {stmt_2}") + lines.extend(f"{stmt_1} -> {stmt_2}" for stmt_2 in dep_graph.get(stmt_1, set())) for (stmt_1, stmt_2), annot in annotation_dep_graph.items(): lines.append(f'{stmt_2} -> {stmt_1} [label="{annot}", style="dashed"]') diff --git a/pymbolic/interop/ast.py b/pymbolic/interop/ast.py index 784ab917..77025810 100644 --- a/pymbolic/interop/ast.py +++ b/pymbolic/interop/ast.py @@ -46,16 +46,16 @@ from pymbolic.geometric_algebra import MultiVector # NOTE: these are removed in Python 3.14 - if sys.version_info < (3, 14): + if sys.version_info >= (3, 14): + AstNum: TypeAlias = Any + AstStr: TypeAlias = Any + AstBytes: TypeAlias = Any + else: from ast import ( Bytes as AstBytes, # pyright: ignore[reportDeprecated] Num as AstNum, # pyright: ignore[reportDeprecated] Str as AstStr, # pyright: ignore[reportDeprecated] ) - else: - AstNum: TypeAlias = Any - AstStr: TypeAlias = Any - AstBytes: TypeAlias = Any __doc__ = r''' An example: diff --git a/pymbolic/interop/common.py b/pymbolic/interop/common.py index 7879e4ae..1ae4335e 100644 --- a/pymbolic/interop/common.py +++ b/pymbolic/interop/common.py @@ -23,9 +23,10 @@ THE SOFTWARE. """ +from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Generic, TypeAlias, cast -from typing_extensions import override +from typing_extensions import Never, override import pymbolic.primitives as prim from pymbolic.mapper import P, ResultT @@ -182,7 +183,7 @@ def map_StrictLessThan(self, expr: sp.StrictLessThan) -> Expression: SympyLikeExpression: TypeAlias = "sp.Basic" -class PymbolicToSympyLikeMapper(EvaluationMapper[SympyLikeExpression]): +class PymbolicToSympyLikeMapper(EvaluationMapper[SympyLikeExpression], ABC): # FIXME(pyright): Returning `Any` here is not great, but we would need a big # protocol or something if we want it to work for both sympy/symengine. @property @@ -194,8 +195,9 @@ def to_expr(self, expr: Expression) -> sp.Expr: assert isinstance(result, self.sym.Expr) # pyright: ignore[reportAny] return cast("sp.Expr", result) - def raise_conversion_error(self, expr: object) -> None: - raise NotImplementedError + @abstractmethod + def raise_conversion_error(self, expr: object) -> Never: + ... @override def map_variable(self, expr: prim.Variable) -> SympyLikeExpression: @@ -221,7 +223,6 @@ def map_call(self, expr: prim.Call) -> SympyLikeExpression: return func(*[self.rec(par) for par in expr.parameters]) else: self.raise_conversion_error(expr) - raise @override def map_subscript(self, expr: prim.Subscript) -> SympyLikeExpression: @@ -232,7 +233,6 @@ def map_subscript(self, expr: prim.Subscript) -> SympyLikeExpression: *(self.rec(idx) for idx in expr.index_tuple)) else: self.raise_conversion_error(expr) - raise @override def map_substitution(self, expr: prim.Substitution) -> SympyLikeExpression: diff --git a/pymbolic/interop/symengine.py b/pymbolic/interop/symengine.py index b7d6d73c..3d6c5c08 100644 --- a/pymbolic/interop/symengine.py +++ b/pymbolic/interop/symengine.py @@ -28,7 +28,7 @@ from typing import TYPE_CHECKING, Any import symengine as sp -from typing_extensions import override +from typing_extensions import Never, override import pymbolic.primitives as prim from pymbolic.interop.common import ( @@ -127,7 +127,7 @@ def sym(self) -> Any: return sp @override - def raise_conversion_error(self, expr: object) -> None: + def raise_conversion_error(self, expr: object) -> Never: raise RuntimeError(f"do not know how to translate '{expr}' to symengine") diff --git a/pymbolic/interop/sympy.py b/pymbolic/interop/sympy.py index 2e9060db..eac2a565 100644 --- a/pymbolic/interop/sympy.py +++ b/pymbolic/interop/sympy.py @@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any import sympy as sp -from typing_extensions import override +from typing_extensions import Never, override import pymbolic.primitives as prim from pymbolic.interop.common import ( @@ -118,7 +118,7 @@ def sym(self) -> Any: return sp @override - def raise_conversion_error(self, expr: object) -> None: + def raise_conversion_error(self, expr: object) -> Never: raise RuntimeError(f"do not know how to translate '{expr}' to sympy") @override diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 44d402d6..0dcf6696 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -24,7 +24,7 @@ """ from abc import ABC, abstractmethod -from collections.abc import Callable, Hashable, Iterable, Mapping, Set +from collections.abc import Callable, Hashable, Iterable, Mapping, Set as AbstractSet from typing import ( TYPE_CHECKING, Concatenate, @@ -192,8 +192,7 @@ def __call__(self, if method_name is not None: method = getattr(self, method_name, None) if method is not None: - result = method(expr, *args, **kwargs) - return result + return method(expr, *args, **kwargs) if isinstance(expr, p.ExpressionNode): for cls in type(expr).__mro__[1:]: @@ -202,8 +201,7 @@ def __call__(self, method = getattr(self, method_name, None) if method: return method(expr, *args, **kwargs) - else: - return self.handle_unsupported_expression(expr, *args, **kwargs) + return self.handle_unsupported_expression(expr, *args, **kwargs) else: return self.map_foreign(expr, *args, **kwargs) @@ -221,8 +219,7 @@ def rec_fallback(self, method = getattr(self, method_name, None) if method: return method(expr, *args, **kwargs) - else: - return self.handle_unsupported_expression(expr, *args, **kwargs) + return self.handle_unsupported_expression(expr, *args, **kwargs) else: return self.map_foreign(expr, *args, **kwargs) @@ -731,7 +728,7 @@ class CachedCombineMapper(CachedMapper[ResultT, P], CombineMapper[ResultT, P]): CollectedT = TypeVar("CollectedT") -class Collector(CombineMapper[Set[CollectedT], P]): +class Collector(CombineMapper[AbstractSet[CollectedT], P]): """A subclass of :class:`CombineMapper` for the common purpose of collecting data derived from an expression in a set that gets 'unioned' across children at each non-leaf node in the expression tree. @@ -742,7 +739,9 @@ class Collector(CombineMapper[Set[CollectedT], P]): """ @override - def combine(self, values: Iterable[Set[CollectedT]], /) -> frozenset[CollectedT]: + def combine(self, + values: Iterable[AbstractSet[CollectedT]], + /) -> frozenset[CollectedT]: it = iter(values) try: first = next(it) @@ -753,36 +752,38 @@ def combine(self, values: Iterable[Set[CollectedT]], /) -> frozenset[CollectedT] @override def map_constant(self, expr: object, /, - *args: P.args, **kwargs: P.kwargs) -> Set[CollectedT]: + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: return set() @override def map_variable(self, expr: p.Variable, /, - *args: P.args, **kwargs: P.kwargs) -> Set[CollectedT]: + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: return set() @override def map_wildcard(self, expr: p.Wildcard, /, - *args: P.args, **kwargs: P.kwargs) -> Set[CollectedT]: + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: return set() @override def map_dot_wildcard(self, expr: p.DotWildcard, /, - *args: P.args, **kwargs: P.kwargs) -> Set[CollectedT]: + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: return set() @override def map_star_wildcard(self, expr: p.StarWildcard, /, - *args: P.args, **kwargs: P.kwargs) -> Set[CollectedT]: + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: return set() @override def map_function_symbol(self, expr: p.FunctionSymbol, /, - *args: P.args, **kwargs: P.kwargs) -> Set[CollectedT]: + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: return set() -class CachedCollector(CachedMapper[Set[CollectedT], P], Collector[CollectedT, P]): +class CachedCollector( + CachedMapper[AbstractSet[CollectedT], P], + Collector[CollectedT, P]): pass # }}} @@ -1441,7 +1442,7 @@ def map_multivector(self, if not self.visit(expr, *args, **kwargs): return - for _bits, coeff in expr.data.items(): + for coeff in expr.data.values(): self.rec(coeff, *args, **kwargs) self.post_visit(expr, *args, **kwargs) diff --git a/pymbolic/mapper/coefficient.py b/pymbolic/mapper/coefficient.py index bb499768..50f3e707 100644 --- a/pymbolic/mapper/coefficient.py +++ b/pymbolic/mapper/coefficient.py @@ -91,8 +91,8 @@ def map_quotient(self, expr: p.Quotient, /) -> CoeffsT: if len(d_den) > 1 or 1 not in d_den: raise RuntimeError("nonlinear expression") val = d_den[1] - for k in d_num: - d_num[k] = p.flattened_product((d_num[k], p.Quotient(1, val))) + for k, coeff in d_num.items(): + d_num[k] = p.flattened_product((coeff, p.Quotient(1, val))) return d_num @override diff --git a/pymbolic/mapper/collector.py b/pymbolic/mapper/collector.py index c977fe7e..f3b46fa9 100644 --- a/pymbolic/mapper/collector.py +++ b/pymbolic/mapper/collector.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: - from collections.abc import Sequence, Set + from collections.abc import Sequence, Set as AbstractSet from pymbolic.mapper.dependency import Dependencies from pymbolic.typing import ArithmeticExpression, Expression @@ -50,9 +50,9 @@ class TermCollector(IdentityMapper[[]]): coefficients and are not used for term collection. """ - parameters: Set[p.AlgebraicLeaf] + parameters: AbstractSet[p.AlgebraicLeaf] - def __init__(self, parameters: Set[p.AlgebraicLeaf] | None = None) -> None: + def __init__(self, parameters: AbstractSet[p.AlgebraicLeaf] | None = None) -> None: if parameters is None: parameters = set() @@ -63,7 +63,7 @@ def get_dependencies(self, expr: Expression) -> Dependencies: return DependencyMapper()(expr) def split_term(self, mul_term: Expression) -> tuple[ - Set[tuple[ArithmeticExpression, ArithmeticExpression]], + AbstractSet[tuple[ArithmeticExpression, ArithmeticExpression]], ArithmeticExpression ]: """Returns a pair consisting of: @@ -124,18 +124,17 @@ def exponent(term: Expression) -> ArithmeticExpression: @override def map_sum(self, expr: p.Sum, /) -> Expression: term2coeff: dict[ - Set[tuple[ArithmeticExpression, ArithmeticExpression]], + AbstractSet[tuple[ArithmeticExpression, ArithmeticExpression]], ArithmeticExpression] = {} for child in expr.children: term, coeff = self.split_term(child) term2coeff[term] = term2coeff.get(term, 0) + coeff def rep2term( - rep: Set[tuple[ArithmeticExpression, ArithmeticExpression]] + rep: AbstractSet[tuple[ArithmeticExpression, ArithmeticExpression]] ) -> ArithmeticExpression: return p.flattened_product([base**exp for base, exp in rep]) - result = p.flattened_sum([ + return p.flattened_sum([ coeff*rep2term(termrep) for termrep, coeff in term2coeff.items() ]) - return result diff --git a/pymbolic/mapper/dependency.py b/pymbolic/mapper/dependency.py index 9cdb38b8..e650f6fc 100644 --- a/pymbolic/mapper/dependency.py +++ b/pymbolic/mapper/dependency.py @@ -30,7 +30,7 @@ THE SOFTWARE. """ -from collections.abc import Set +from collections.abc import Set as AbstractSet from typing import TYPE_CHECKING, Literal, TypeAlias from typing_extensions import override @@ -40,7 +40,7 @@ Dependency: TypeAlias = p.AlgebraicLeaf | p.CommonSubexpression -Dependencies: TypeAlias = Set[Dependency] +Dependencies: TypeAlias = AbstractSet[Dependency] if not TYPE_CHECKING: DependenciesT: TypeAlias = Dependencies diff --git a/pymbolic/mapper/distributor.py b/pymbolic/mapper/distributor.py index d1bb9e0a..ec2e4306 100644 --- a/pymbolic/mapper/distributor.py +++ b/pymbolic/mapper/distributor.py @@ -106,8 +106,7 @@ def dist(prod: ArithmeticExpression) -> ArithmeticExpression: if len(leading) == len(prod.children): # no more sums found - result = p.flattened_product(prod.children) - return result + return p.flattened_product(prod.children) else: sum = prod.children[len(leading)] assert isinstance(sum, p.Sum) diff --git a/pymbolic/mapper/optimize.py b/pymbolic/mapper/optimize.py index 76b49012..9795bf58 100644 --- a/pymbolic/mapper/optimize.py +++ b/pymbolic/mapper/optimize.py @@ -391,7 +391,7 @@ def wrapper(cls: type) -> type: # }}} - exec(compile( + exec(compile( # noqa: S102 code_str, f"<'{_get_file_name_for_module_name(cls.__module__)}' " "modified by optimize_mapper>", diff --git a/pymbolic/mapper/substitutor.py b/pymbolic/mapper/substitutor.py index 32d5bdd7..c84beed3 100644 --- a/pymbolic/mapper/substitutor.py +++ b/pymbolic/mapper/substitutor.py @@ -41,7 +41,7 @@ if TYPE_CHECKING: - from collections.abc import Callable, Set + from collections.abc import Callable, Set as AbstractSet import optype @@ -52,7 +52,7 @@ _VT_co = TypeVar("_VT_co", covariant=True) class CanItems(Protocol[_KT_co, _VT_co]): - def items(self) -> Set[tuple[_KT_co, _VT_co]]: ... + def items(self) -> AbstractSet[tuple[_KT_co, _VT_co]]: ... class SubstitutionMapper(IdentityMapper[[]]): diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index da6fdb8c..3bea04de 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -46,8 +46,7 @@ from constantdict import constantdict from typing_extensions import Self, TypeIs, dataclass_transform, override -import pytools.obj_array as obj_array -from pytools import T, module_getattr_for_deprecations, ndindex +from pytools import T, module_getattr_for_deprecations, ndindex, obj_array from pytools.obj_array import ( ObjectArray, ObjectArray1D, @@ -909,7 +908,7 @@ def __iter__(self) -> NoReturn: (?<=[A-Z]) # preceded by lowercase (?=[A-Z][a-z]) # followed by uppercase, then lowercase """, - re.X, + re.VERBOSE, ) @@ -1048,7 +1047,7 @@ def {cls.__name__}_setstate(self, state): """) exec_dict = {"cls": cls, "_MODULE_SOURCE_CODE": augment_code} - exec(compile(augment_code, + exec(compile(augment_code, # noqa: S102 f"", "exec"), exec_dict) @@ -1123,14 +1122,12 @@ class AlgebraicLeaf(ExpressionNode): """An expression that serves as a leaf for arithmetic evaluation. This may end up having child nodes still, but they're not reached by ways of arithmetic.""" - pass @expr_dataclass() class Leaf(AlgebraicLeaf): """An expression that is irreducible, i.e. has no Expression-type parts whatsoever.""" - pass @expr_dataclass() @@ -1204,7 +1201,7 @@ class CallWithKwargs(AlgebraicLeaf): def __post_init__(self): try: hash(self.kw_parameters) - except Exception: + except Exception: # noqa: BLE001 warn("CallWithKwargs created with non-hashable kw_parameters. " "This is deprecated and will stop working in 2025. " "If you need an immutable mapping, " @@ -1807,8 +1804,6 @@ def quotient(numerator: ArithmeticExpression, # {{{ tool functions -global VALID_CONSTANT_CLASSES -global VALID_OPERANDS VALID_CONSTANT_CLASSES: tuple[type, ...] = (int, float, complex) _BOOL_CLASSES: tuple[type, ...] = (bool,) VALID_OPERANDS = (ExpressionNode,) @@ -1946,7 +1941,7 @@ def make_common_subexpression( scope = cse_scope.EVALUATION if (isinstance(expr, CommonSubexpression) - and (scope == cse_scope.EVALUATION or expr.scope == scope)): + and (scope in (cse_scope.EVALUATION, expr.scope))): return expr # handle MultiVector diff --git a/pymbolic/rational.py b/pymbolic/rational.py index b966af6a..91bdc49b 100644 --- a/pymbolic/rational.py +++ b/pymbolic/rational.py @@ -25,8 +25,7 @@ from sys import intern -import pymbolic.primitives as primitives -import pymbolic.traits as traits +from pymbolic import primitives, traits class Rational(primitives.ExpressionNode): diff --git a/pyproject.toml b/pyproject.toml index 191638ad..9d2efb09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,12 +91,14 @@ extend-ignore = [ "UP031", # use f-strings instead of % "UP032", # use f-strings instead of .format "RUF067", # no code in __init__ *shrug* + "TRY004", ] [tool.ruff.lint.per-file-ignores] "experiments/traversal-benchmark.py" = ["E501"] -"doc/conf.py" = ["I002"] +"doc/conf.py" = ["I002", "S102"] "experiments/*.py" = ["I002"] +"test/*.py" = ["S102"] [tool.ruff.lint.pep8-naming] extend-ignore-names = ["map_*"] diff --git a/test/test_maxima.py b/test/test_maxima.py index eacea8d8..25b5278e 100644 --- a/test/test_maxima.py +++ b/test/test_maxima.py @@ -103,11 +103,11 @@ def test_strict_round_trip(knl: MaximaKernel) -> None: round_trips_correctly = result == expr if not round_trips_correctly: print("ORIGINAL:") - print("") + print() print(expr) - print("") + print() print("POST-MAXIMA:") - print("") + print() print(result) assert round_trips_correctly diff --git a/test/test_pymbolic.py b/test/test_pymbolic.py index bb7bccc5..aa766c77 100644 --- a/test/test_pymbolic.py +++ b/test/test_pymbolic.py @@ -42,7 +42,7 @@ if TYPE_CHECKING: - from collections.abc import Collection, Sequence, Set + from collections.abc import Collection, Sequence, Set as AbstractSet from pymbolic.mapper.unifier import UnificationRecord from pymbolic.typing import Expression @@ -630,7 +630,7 @@ def test_unifier(): def match_found( records: Sequence[UnificationRecord], - eqns: Set[tuple[Expression, Expression]]): + eqns: AbstractSet[tuple[Expression, Expression]]): return any(eqns <= set(record.equations) for record in records) def unify(candidates: Collection[str], expr: Expression, other: Expression):