Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8,896 changes: 1,328 additions & 7,568 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@
"Variable": "class:pymbolic.primitives.Variable",
"prim.Subscript": "class:pymbolic.primitives.Subscript",
"prim.Variable": "class:pymbolic.primitives.Variable",
"ExpressionNode": "class:pytential.symbolic.primitives.ExpressionNode",
"ArithmeticExpressionContainerTc":
"obj:pymbolic.typing.ArithmeticExpressionContainerTc",
# arraycontext
"ArrayContainer": "obj:arraycontext.ArrayContainer",
"ArrayOrContainerOrScalar": "obj:arraycontext.ArrayOrContainerOrScalar",
Expand All @@ -100,6 +101,7 @@
"P2PBase": "class:sumpy.p2p.P2PBase",
"FMMLevelToOrder": "class:sumpy.fmm.FMMLevelToOrder",
# pytential
"ExpressionNode": "class:pytential.symbolic.primitives.ExpressionNode",
"DOFDescriptorLike": "data:pytential.symbolic.dof_desc.DOFDescriptorLike",
"DOFGranularity": "data:pytential.symbolic.dof_desc.DOFGranularity",
"DiscretizationStage": "data:pytential.symbolic.dof_desc.DiscretizationStage",
Expand Down
12 changes: 7 additions & 5 deletions pytential/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from arraycontext import Array
from meshmode.discretization import Discretization
from meshmode.dof_array import DOFArray
from pymbolic.geometric_algebra import MultiVector
from pytools.obj_array import ObjectArray1D


Expand Down Expand Up @@ -72,14 +71,14 @@ def _set_up_errors():


@memoize_on_first_arg
def _integral_op(discr):
def _integral_op(discr: Discretization):
from pytential import bind, sym
return bind(discr,
sym.integral(
discr.ambient_dim, discr.dim, sym.var("integrand")))


def integral(discr, x):
def integral(discr: Discretization, x: DOFArray):
return _integral_op(discr)(integrand=x)


Expand Down Expand Up @@ -112,7 +111,7 @@ def _norm_inf_op(discr: Discretization, num_components: int | None):

def norm(
discr: Discretization,
x: DOFArray | ObjectArray1D[DOFArray] | MultiVector[DOFArray],
x: DOFArray | ObjectArray1D[DOFArray],
p: float | Literal["inf"] = 2):
from pymbolic.geometric_algebra import MultiVector
if isinstance(x, MultiVector):
Expand All @@ -128,7 +127,10 @@ def norm(

elif p == np.inf or p == "inf":
norm_op = _norm_inf_op(discr, num_components)
norm_res = norm_op(arg=x)

# FIXME: norm_op (correctly) becomes BoundExpression[Operand], but
# then none of the overloads fit, hence the type-ignore.
norm_res = norm_op(arg=x) # pyright: ignore[reportCallIssue]
if isinstance(norm_res, obj_array.ObjectArray):
# FIXME: Pyright may have a point: It's not clear how/if this works
return max(cast("ObjectArray1D[Array]", norm_res)) # pyright: ignore[reportArgumentType]
Expand Down
40 changes: 24 additions & 16 deletions pytential/linalg/direct_solver_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@

from typing_extensions import override

from pymbolic.geometric_algebra import componentwise
from pytools import obj_array

from pytential.symbolic.mappers import IdentityMapper, LocationTagger, OperatorCollector


if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Callable, Iterable

from pymbolic.primitives import Product, Sum
from pymbolic.typing import ArithmeticExpression, Scalar
Expand Down Expand Up @@ -82,7 +83,9 @@ def _prepare_expr(expr: ArithmeticExpression) -> ArithmeticExpression:
# ensure all IntGs remove all the kernel derivatives
expr = KernelTransformationRemover()(expr)
# ensure all IntGs have their source and targets set
expr = DOFDescriptorReplacer(auto_where[0], auto_where[1])(expr)
expr = DOFDescriptorReplacer(
default_source=auto_where[0],
default_target=auto_where[1]).rec_arith(expr)

return expr

Expand Down Expand Up @@ -223,14 +226,18 @@ def _default_dofdesc(self, dofdesc: DOFDescriptorLike) -> DOFDescriptor:
return self.default_target

@override
def map_int_g(self, expr: IntG) -> IntG:
def map_int_g(self,
expr: IntG,
rec:
Callable[[ArithmeticExpression], ArithmeticExpression]
| None = None):
return type(expr)(
expr.target_kernel, expr.source_kernels,
densities=self.operand_rec(expr.densities),
densities=tuple(self.rec_arith(d) for d in expr.densities),
qbx_forced_limit=expr.qbx_forced_limit,
source=self.default_source, target=self.default_target,
kernel_arguments={
name: self.operand_rec(arg_expr)
name: componentwise(self.rec_arith, arg_expr)
for name, arg_expr in expr.kernel_arguments.items()
}
)
Expand All @@ -248,17 +255,18 @@ class DOFDescriptorReplacer(_LocationReplacer):
.. automethod:: __init__
"""

operand_rec: _LocationReplacer

def __init__(self, source: DOFDescriptorLike, target: DOFDescriptorLike) -> None:
"""
:param source: a descriptor for all expressions to be evaluated on
the source geometry.
:param target: a descriptor for all expressions to be evaluate on
the target geometry.
"""
super().__init__(target, default_source=source)
self.operand_rec = _LocationReplacer(source, default_source=source)
rec: _LocationReplacer

@override
def map_int_g(self,
expr: IntG,
rec:
Callable[[ArithmeticExpression], ArithmeticExpression]
| None = None):
ltag = _LocationReplacer(
default_source=self.default_source,
default_target=self.default_source)
return super().map_int_g(expr, rec=ltag.rec_arith)

# }}}

Expand Down
13 changes: 9 additions & 4 deletions pytential/qbx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@
from pytential.symbolic.compiler import ComputePotential, PotentialOutput
from pytential.symbolic.dof_desc import DOFDescriptor, GeometryId
from pytential.symbolic.execution import BoundExpression
from pytential.symbolic.primitives import IntG, Operand, QBXForcedLimit
from pytential.symbolic.primitives import (
IntG,
LowLevelQBXForcedLimit,
Operand,
QBXForcedLimit,
)
from pytential.target import TargetOrDiscretization


Expand Down Expand Up @@ -647,7 +652,7 @@ def get_target_discrs_and_qbx_sides(self,
target_name_and_side_to_number: dict[GeometryId, int] = {}
# list of tuples (discr, qbx_side)
target_discrs_and_qbx_sides: list[
tuple[GeometryLike, Literal[-2, -1, +1, +2, "avg", 0]]
tuple[GeometryLike, LowLevelQBXForcedLimit]
] = []

for o in insn.outputs:
Expand Down Expand Up @@ -906,11 +911,11 @@ def _flat_centers(dofdesc, qbx_forced_limit):

from collections import defaultdict
self_outputs: defaultdict[
tuple[DOFDescriptor, Literal[-1, -2, 0, 1, 2]],
tuple[DOFDescriptor, LowLevelQBXForcedLimit],
list[tuple[int, PotentialOutput]]
] = defaultdict(list)
other_outputs: defaultdict[
tuple[DOFDescriptor, Literal[-1, -2, 0, 1, 2]],
tuple[DOFDescriptor, LowLevelQBXForcedLimit],
list[tuple[int, PotentialOutput]]
] = defaultdict(list)

Expand Down
10 changes: 6 additions & 4 deletions pytential/qbx/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,11 +478,13 @@ def flat_expansion_radii(self):
from pytential import bind, sym

actx = self._setup_actx
dd = self.source_dd.to_stage1()
radii = bind(self.places,
sym.expansion_radii(
self.ambient_dim,
granularity=sym.GRANULARITY_CENTER,
dofdesc=self.source_dd.to_stage1()))(actx)
sym.interleave(
sym.expansion_radii(self.ambient_dim, dofdesc=dd),
sym.expansion_radii(self.ambient_dim, dofdesc=dd),
dd,
))(actx)

return actx.freeze(flatten(radii, actx))

Expand Down
11 changes: 7 additions & 4 deletions pytential/qbx/refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"""

import logging
from typing import cast

import numpy as np

Expand Down Expand Up @@ -324,12 +325,14 @@ def check_expansion_disks_undisturbed_by_sources(self,

from pytential import bind, sym
center_danger_zone_radii = flatten(
bind(
cast("DOFArray", bind(
stage1_density_discr,
sym.interpolate(
sym.interleave(
# These are CSE'd anyway
sym.expansion_radii(stage1_density_discr.ambient_dim),
from_dd=None, to_dd=sym.GRANULARITY_CENTER)
)(self.array_context),
sym.expansion_radii(stage1_density_discr.ambient_dim),
)
)(self.array_context)),
self.array_context)

evt = knl(
Expand Down
46 changes: 29 additions & 17 deletions pytential/qbx/target_assoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@
THE SOFTWARE.
"""

import logging
from typing import TYPE_CHECKING

import numpy as np
from cgen import Enum

from arraycontext import Array, PyOpenCLArrayContext, flatten
from boxtree.area_query import AreaQueryElementwiseTemplate
from boxtree.tools import DeviceDataRecord, InlineBinarySearch
from pytools import memoize_in, memoize_method
from pytools import log_process, memoize_in, memoize_method

from pytential.qbx.utils import (
QBX_TREE_C_PREAMBLE,
Expand All @@ -42,18 +45,17 @@
)


unwrap_args = AreaQueryElementwiseTemplate.unwrap_args

import logging
from typing import TYPE_CHECKING

from pytools import log_process


if TYPE_CHECKING:
from numpy.typing import DTypeLike

from boxtree import Tree
from pyopencl import WaitList

from pytential.collection import GeometryCollection
from pytential.symbolic.dof_desc import DOFDescriptor


unwrap_args = AreaQueryElementwiseTemplate.unwrap_args


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -630,10 +632,17 @@ def mark_targets(self,
return actx.to_numpy(actx.np.all(found_target_close_to_element == 1))

@log_process(logger)
def find_centers(self, places, dofdesc,
tree, peer_lists, target_status, target_flags, target_assoc,
target_association_tolerance,
debug, wait_for=None):
def find_centers(self,
places: GeometryCollection,
dofdesc: DOFDescriptor,
tree: Tree,
peer_lists,
target_status,
target_flags,
target_assoc,
target_association_tolerance,
debug: bool,
wait_for: WaitList = None):
from pytential import bind, sym
ambient_dim = places.ambient_dim
actx = self.array_context
Expand All @@ -658,9 +667,12 @@ def find_centers(self, places, dofdesc,
center_slice = actx.thaw(tree.sorted_target_ids[tree.qbx_user_center_slice])
centers = [actx.thaw(axis)[center_slice] for axis in tree.sources]
expansion_radii_by_center = bind(places,
sym.expansion_radii(ambient_dim,
granularity=sym.GRANULARITY_CENTER,
dofdesc=dofdesc)
sym.interleave(
sym.expansion_radii(ambient_dim, dofdesc=dofdesc),
sym.expansion_radii(ambient_dim, dofdesc=dofdesc),
dofdesc
)

)(actx)
expansion_radii_by_center_with_tolerance = flatten(
expansion_radii_by_center * (1 + target_association_tolerance),
Expand All @@ -681,7 +693,7 @@ def find_centers(self, places, dofdesc,
wait_for=wait_for)
wait_for = [evt]

def make_target_field(fill_val, dtype=tree.coord_dtype):
def make_target_field(fill_val, dtype: DTypeLike = tree.coord_dtype):
arr = actx.np.zeros(tree.nqbxtargets, dtype)
arr.fill(fill_val)
wait_for.extend(arr.events)
Expand Down
4 changes: 2 additions & 2 deletions pytential/symbolic/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,12 @@ class PotentialOutput:
target_name: DOFDescriptor
"""A descriptor for the geometry used by the target kernel."""

# This removes "avg" compared to QBXForcedLimit
qbx_forced_limit: Literal[-2, -1, +1, +2] | None
"""The type of the limiting process used by the QBX expansion (``+1`` if the
output is required to originate from a QBX center on the "+" side of the
boundary. ``-1`` for the other side, etc.).
"""
# This removes "avg" and None compared to QBXForcedLimit


@dataclass(frozen=True, eq=False)
Expand Down Expand Up @@ -736,7 +736,7 @@ def map_common_subexpression(
@override
def map_int_g(
self, expr: IntG, name_hint: str | None = None,
) -> Expression:
) -> ArithmeticExpression:
try:
return self.expr_to_var[expr]
except KeyError:
Expand Down
Loading
Loading