Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 33 additions & 25 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -1584,8 +1584,8 @@
{
"code": "reportAny",
"range": {
"startColumn": 30,
"endColumn": 31,
"startColumn": 27,
"endColumn": 28,
"lineCount": 1
}
},
Expand Down Expand Up @@ -2955,6 +2955,14 @@
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 20,
"endColumn": 24,
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
Expand Down Expand Up @@ -3039,7 +3047,7 @@
"code": "reportUnknownMemberType",
"range": {
"startColumn": 24,
"endColumn": 45,
"endColumn": 32,
"lineCount": 1
}
},
Expand All @@ -3051,6 +3059,14 @@
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 16,
"endColumn": 20,
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
Expand Down Expand Up @@ -3135,7 +3151,7 @@
"code": "reportUnknownMemberType",
"range": {
"startColumn": 20,
"endColumn": 44,
"endColumn": 31,
"lineCount": 1
}
},
Expand All @@ -3148,34 +3164,34 @@
}
},
{
"code": "reportUnknownVariableType",
"code": "reportUnknownMemberType",
"range": {
"startColumn": 12,
"endColumn": 18,
"startColumn": 8,
"endColumn": 20,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"code": "reportUnknownVariableType",
"range": {
"startColumn": 22,
"endColumn": 35,
"startColumn": 49,
"endColumn": 55,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"code": "reportUnknownMemberType",
"range": {
"startColumn": 36,
"endColumn": 42,
"startColumn": 59,
"endColumn": 72,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 12,
"endColumn": 24,
"startColumn": 73,
"endColumn": 79,
"lineCount": 1
}
},
Expand Down Expand Up @@ -5909,19 +5925,11 @@
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 16,
"endColumn": 22,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 23,
"endColumn": 29,
"endColumn": 52,
"lineCount": 1
}
},
Expand Down
8 changes: 5 additions & 3 deletions pymbolic/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
from typing import TYPE_CHECKING, Protocol, TypeVar, cast, overload
from warnings import warn

from typing_extensions import Self

from pytools import MovedFunctionDeprecationWrapper, memoize


Expand All @@ -62,7 +64,7 @@
# {{{ integer powers

class _CanMultiply(Protocol):
def __mul__(self: _CanMultiplyT, other: _CanMultiplyT, /) -> _CanMultiplyT: ...
def __mul__(self, other: Self, /) -> Self: ...


_CanMultiplyT = TypeVar("_CanMultiplyT", bound=_CanMultiply)
Expand Down Expand Up @@ -112,7 +114,7 @@
`Wikipedia article on the Euclidean algorithm
<https://en.wikipedia.org/wiki/Euclidean_algorithm>`__.
"""
import pymbolic.traits as traits
from pymbolic import traits

# see [Davenport], Appendix, p. 214

Expand Down Expand Up @@ -298,7 +300,7 @@
return x

return NearZeroKiller()(
fft(wrap_intermediate_with_level(0, x), sign=sign,

Check warning on line 303 in pymbolic/algorithm.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

Not supplying complex_dtype is deprecated, falling back to complex128 for now. This will stop working in 2023.

Check warning on line 303 in pymbolic/algorithm.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

Not supplying complex_dtype is deprecated, falling back to complex128 for now. This will stop working in 2023.
wrap_intermediate_with_level=wrap_intermediate_with_level))

# }}}
Expand Down Expand Up @@ -379,7 +381,7 @@
rhs[i], rhs[nonz_row] = \
(rhs[nonz_row].copy(), rhs[i].copy())

for u in range(0, m):
for u in range(m):
if u == i:
continue
if not mat[u, j]:
Expand Down
18 changes: 11 additions & 7 deletions pymbolic/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@


if TYPE_CHECKING:
from collections.abc import Callable, Hashable, Iterable, Sequence, Set
from collections.abc import (
Callable,
Hashable,
Iterable,
Sequence,
Set as AbstractSet,
)

from pymbolic.typing import Expression

Expand Down Expand Up @@ -93,12 +99,12 @@ def map_common_subexpression(self, expr: prim.CommonSubexpression, /,


class CSEMapper(IdentityMapper[[]]):
to_eliminate: Set[Hashable]
to_eliminate: AbstractSet[Hashable]
get_key: Callable[[object], Hashable]
canonical_subexprs: dict[Hashable, Expression]

def __init__(self,
to_eliminate: Set[Hashable],
to_eliminate: AbstractSet[Hashable],
get_key: Callable[[object], Hashable]) -> None:
self.to_eliminate = to_eliminate
self.get_key = get_key
Expand Down Expand Up @@ -126,8 +132,7 @@ def map_sum(self,
) -> Expression:
key = self.get_key(expr)
if key in self.to_eliminate:
result = self.get_cse(expr, key)
return result
return self.get_cse(expr, key)
else:
return getattr(IdentityMapper, expr.mapper_method)(self, expr)

Expand Down Expand Up @@ -177,5 +182,4 @@ def tag_common_subexpressions(exprs: Iterable[Expression]) -> Sequence[Expressio
if count > 1}

cse_mapper = CSEMapper(to_eliminate, get_key)
result = [cse_mapper(expr) for expr in exprs]
return result
return [cse_mapper(expr) for expr in exprs]
43 changes: 21 additions & 22 deletions pymbolic/geometric_algebra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@
import numpy as np
from typing_extensions import Self, override

import pytools.obj_array as obj_array
from pytools import memoize, memoize_method
from pytools import memoize, memoize_method, obj_array
from pytools.obj_array import ObjectArray, ObjectArray1D, ShapeT

from pymbolic.primitives import expr_dataclass, is_zero
Expand Down Expand Up @@ -145,18 +144,18 @@


class _HasArithmetic(Protocol):
def __neg__(self: CoeffT) -> CoeffT: ...
def __abs__(self: CoeffT) -> CoeffT: ...
def __neg__(self) -> Self: ...
def __abs__(self) -> Self: ...

def __add__(self: CoeffT, other: CoeffT, /) -> CoeffT: ...
def __radd__(self: CoeffT, other: int, /) -> CoeffT: ...
def __add__(self, other: Self, /) -> Self: ...
def __radd__(self, other: int, /) -> Self: ...

def __sub__(self: CoeffT, other: CoeffT, /) -> CoeffT: ...
def __sub__(self, other: Self, /) -> Self: ...

def __mul__(self: CoeffT, other: CoeffT, /) -> CoeffT: ...
def __rmul__(self: CoeffT, other: int, /) -> CoeffT: ...
def __mul__(self, other: Self, /) -> Self: ...
def __rmul__(self, other: int, /) -> Self: ...

def __pow__(self: CoeffT, other: CoeffT, /) -> CoeffT: ...
def __pow__(self, other: Self, /) -> Self: ...


CoeffT = TypeVar("CoeffT", bound=_HasArithmetic)
Expand Down Expand Up @@ -1103,10 +1102,10 @@ def project(self, r: int) -> MultiVector[CoeffT]:

Often written :math:`\langle A\rangle_r`.
"""
new_data: dict[int, CoeffT] = {}
for bits, coeff in self.data.items():
if bits.bit_count() == r:
new_data[bits] = coeff
new_data: dict[int, CoeffT] = {
bits: coeff
for bits, coeff in self.data.items()
if bits.bit_count() == r}

return MultiVector(new_data, self.space)

Expand Down Expand Up @@ -1161,19 +1160,19 @@ def get_pure_grade(self) -> int | None:

def odd(self) -> MultiVector[CoeffT]:
"""Extract the odd-grade blades."""
new_data: dict[int, CoeffT] = {}
for bits, coeff in self.data.items():
if bits.bit_count() % 2:
new_data[bits] = coeff
new_data: dict[int, CoeffT] = {
bits: coeff
for bits, coeff in self.data.items()
if bits.bit_count() % 2}

return MultiVector(new_data, self.space)

def even(self) -> MultiVector[CoeffT]:
"""Extract the even-grade blades."""
new_data: dict[int, CoeffT] = {}
for bits, coeff in self.data.items():
if bits.bit_count() % 2 == 0:
new_data[bits] = coeff
new_data: dict[int, CoeffT] = {
bits: coeff
for bits, coeff in self.data.items()
if bits.bit_count() % 2 == 0}

return MultiVector(new_data, self.space)

Expand Down
17 changes: 9 additions & 8 deletions pymbolic/geometric_algebra/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@
# Consider yourself warned.

from abc import ABC, abstractmethod
from collections.abc import Callable, Set
from collections.abc import Callable, Set as AbstractSet
from typing import TYPE_CHECKING, ClassVar

from typing_extensions import Self, override

import pytools.obj_array as obj_array
from pytools import obj_array

import pymbolic.geometric_algebra.primitives as gp
from pymbolic.geometric_algebra import MultiVector
Expand Down Expand Up @@ -93,12 +93,12 @@ def map_derivative_source(
class Collector(CollectorBase[CollectedT, P]):
def map_nabla(self,
expr: gp.Nabla, /, *args: P.args, **kwargs: P.kwargs
) -> Set[CollectedT]:
) -> AbstractSet[CollectedT]:
return set()

def map_nabla_component(self,
expr: gp.NablaComponent, /, *args: P.args, **kwargs: P.kwargs
) -> Set[CollectedT]:
) -> AbstractSet[CollectedT]:
return set()


Expand Down Expand Up @@ -227,25 +227,26 @@ def map_derivative_source(
# {{{ derivative binder

class DerivativeSourceAndNablaComponentCollector(
CachedMapper[Set[ArithmeticExpression], []],
CachedMapper[AbstractSet[ArithmeticExpression], []],
Collector[ArithmeticExpression, []]):
def __init__(self) -> None:
Collector.__init__(self)
CachedMapper.__init__(self)

@override
def map_nabla(self, expr: gp.Nabla, /) -> Set[ArithmeticExpression]:
def map_nabla(self, expr: gp.Nabla, /) -> AbstractSet[ArithmeticExpression]:
raise RuntimeError(
f"{type(self).__name__} must be invoked after "
"Dimensionalizer -- Nabla expression found and not allowed")

@override
def map_nabla_component(
self, expr: gp.NablaComponent, /) -> Set[ArithmeticExpression]:
self, expr: gp.NablaComponent, /) -> AbstractSet[ArithmeticExpression]:
return {expr}

def map_derivative_source(
self, expr: gp.DerivativeSource, /) -> Set[ArithmeticExpression]:
self, expr: gp.DerivativeSource,
/) -> AbstractSet[ArithmeticExpression]:
return {expr} | self.rec(expr.operand)


Expand Down
2 changes: 1 addition & 1 deletion pymbolic/geometric_algebra/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from typing_extensions import override

import pytools.obj_array as obj_array
from pytools import obj_array

from pymbolic.geometric_algebra import MultiVector
from pymbolic.primitives import ExpressionNode, Variable, expr_dataclass
Expand Down
Loading
Loading