Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,774 changes: 699 additions & 2,075 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
# https://github.com/jorenham/optype/issues/430
["py:class", r"optype.*"],
["py:class", r"onp.*"],
# sphinx >= 9.0 errors
["py:class", r"_not_provided"],
["py:class", r"Callable\[.*"],
]


Expand Down Expand Up @@ -81,6 +84,7 @@
"ArrayContainer": "obj:arraycontext.ArrayContainer",
"ArrayOrContainerOrScalar": "obj:arraycontext.ArrayOrContainerOrScalar",
"ArrayOrContainerT": "obj:arraycontext.ArrayOrContainerT",
"arraycontext.typing.ArrayOrContainerT": "obj:arraycontext.ArrayOrContainerT",
"PyOpenCLArrayContext": "class:arraycontext.PyOpenCLArrayContext",
"ScalarLike": "obj:arraycontext.ScalarLike",
# modepy
Expand All @@ -90,6 +94,7 @@
"DOFArray": "class:meshmode.dof_array.DOFArray",
"ElementGroupFactory": "class:meshmode.discretization.ElementGroupFactory",
# boxtree
"ExtentNorm": "obj:boxtree.tree_build.ExtentNorm",
"FromSepSmallerCrit": "obj:boxtree.traversal.FromSepSmallerCrit",
"TimingResult": "class:boxtree.timing.TimingResult",
"TreeKind": "obj:boxtree.tree_build.TreeKind",
Expand All @@ -108,6 +113,7 @@
"DOFGranularity": "data:pytential.symbolic.dof_desc.DOFGranularity",
"DiscretizationStage": "data:pytential.symbolic.dof_desc.DiscretizationStage",
"ExpressionNode": "class:pytential.symbolic.primitives.ExpressionNode",
"FMMBackend": "obj:pytential.qbx.FMMBackend",
"GeometryId": "data:pytential.symbolic.dof_desc.GeometryId",
"KernelArgumentLike": "obj:pytential.symbolic.primitives.KernelArgumentLike",
"KernelArgumentMapping": "obj:pytential.symbolic.primitives.KernelArgumentMapping",
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ exclude = [
]

[tool.ruff.lint.per-file-ignores]
"doc/*.py" = ["I002"]
"doc/*.py" = ["I002", "S102"]
"examples/*.py" = ["I002"]
"test/test_*.py" = ["S102"]

[tool.ruff.lint.flake8-quotes]
docstring-quotes = "double"
Expand Down
2 changes: 1 addition & 1 deletion pytential/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def norm(
norm_op = _norm_2_op(discr, num_components)
return norm_op(integrand=x)**(1/2)

elif p == np.inf or p == "inf":
elif p in {np.inf, "inf"}:
norm_op = _norm_inf_op(discr, num_components)

# FIXME: norm_op (correctly) becomes BoundExpression[Operand], but
Expand Down
4 changes: 1 addition & 3 deletions pytential/linalg/direct_solver_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,10 @@ def _prepare_expr(expr: ArithmeticExpression) -> ArithmeticExpression:
# ensure all IntGs remove all the kernel derivatives
expr = KernelTransformationRemover()(expr)
# ensure all IntGs have their source and targets set
expr = DOFDescriptorReplacer(
return DOFDescriptorReplacer(
default_source=auto_where[0],
default_target=auto_where[1]).rec_arith(expr)

return expr

return obj_array.new_1d([_prepare_expr(expr) for expr in exprs])

# }}}
Expand Down
157 changes: 112 additions & 45 deletions pytential/linalg/gmres.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,64 @@

__doc__ = """
.. autofunction:: gmres

.. autoclass:: GMRESResult
.. autoexception:: GMRESError
.. autoclass:: ResidualPrinter

.. autoclass:: InnerProduct
:members:
:undoc-members:
:special-members: __call__
.. autoclass:: CallableOperator
:members:
:undoc-members:
:special-members: __call__
.. autoclass:: HasMatVec
:members:
:undoc-members:
"""

from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Generic, Protocol

import numpy as np

from arraycontext import ArrayContext, ArrayOrContainerT
from pytools import T


if TYPE_CHECKING:
from collections.abc import Callable, Sequence

from arraycontext import ArrayContainer, ArrayOrContainerT

class InnerProduct(Protocol, Generic[T]):
"""A :class:`~typing.Protocol` for the inner product used by :func:`gmres`."""

def __call__(self, a: T, b: T) -> T: ...


class CallableOperator(Protocol, Generic[T]):
"""A :class:`~typing.Protocol` for the operator used by :func:`gmres`."""

@property
def shape(self) -> tuple[int, int]: ...

def __call__(self, x: T) -> T: ...

def structured_vdot(x, y, array_context=None):

class HasMatVec(Protocol, Generic[T]):
"""A :class:`~typing.Protocol` for the operator used by :func:`gmres`."""

@property
def shape(self) -> tuple[int, int]: ...

def matvec(self, x: T) -> T: ...


def structured_vdot(x: ArrayOrContainerT, y: ArrayOrContainerT,
array_context: ArrayContext | None = None) -> float:
"""vdot() implementation that is aware of scalars and host or
PyOpenCL arrays. It also recurses down nested object arrays.
"""
Expand All @@ -57,6 +96,9 @@ def structured_vdot(x, y, array_context=None):
or (isinstance(x, np.ndarray) and x.dtype.char != "O")):
return np.vdot(x, y)
else:
if array_context is None:
raise ValueError("'array_context' is required for non-scalar inputs")

# actx.np.vdot works on PyOpenCL arrays and arbitrarily nested
# array containers, so this should handle all remaining cases
r = array_context.to_numpy(array_context.np.vdot(x, y))
Expand All @@ -81,41 +123,50 @@ class GMRESError(RuntimeError):
# {{{ main routine

@dataclass(frozen=True)
class GMRESResult:
class GMRESResult(Generic[T]):
"""
.. attribute:: solution
.. attribute:: residual_norms
.. attribute:: iteration_count
.. attribute:: success

A :class:`bool` indicating whether the iteration succeeded.

.. attribute:: state

A description of the outcome.
.. autoattribute:: solution
.. autoattribute:: residual_norms
.. autoattribute:: iteration_count
.. autoattribute:: success
.. autoattribute:: state
"""

solution: ArrayContainer
solution: T
residual_norms: Sequence[float]
iteration_count: int
success: bool
"""A :class:`bool` indicating whether the iteration succeeded."""
state: str
"""A description of the outcome."""


def _gmres(A, b, restart=None, tol=None, x0=None, dot=None,
maxiter=None, hard_failure=None, require_monotonicity=True,
no_progress_factor=None, stall_iterations=None,
callback=None):
def _gmres(
A: CallableOperator[ArrayOrContainerT] | HasMatVec[ArrayOrContainerT],
b: ArrayOrContainerT,
restart: int | None = None,
tol: float | None = None,
x0: ArrayOrContainerT | None = None,
dot: InnerProduct[ArrayOrContainerT] | None = None,
maxiter: int | None = None,
hard_failure: bool | None = None,
require_monotonicity: bool = True,
no_progress_factor: float | None = None,
stall_iterations: int | None = None,
callback: Callable[[ArrayOrContainerT], None] | None = None
) -> GMRESResult[ArrayOrContainerT]:

# {{{ input processing

n, _ = A.shape

if not callable(A):
a_call = A.matvec
else:
a_call = A

if dot is None:
raise ValueError("'dot' not provided")

if restart is None:
restart = min(n, 20)

Expand All @@ -130,32 +181,34 @@ def _gmres(A, b, restart=None, tol=None, x0=None, dot=None,

if stall_iterations is None:
stall_iterations = 10

if no_progress_factor is None:
no_progress_factor = 1.25

# }}}

def norm(x):
def norm(x: ArrayOrContainerT) -> float:
return np.sqrt(abs(dot(x, x)))

if x0 is None:
x = 0*b
x: ArrayOrContainerT = 0*b
r = b
recalc_r = False
else:
x = x0
del x0
recalc_r = True

Ae = [None]*restart
e = [None]*restart
Ae: list[ArrayOrContainerT] = [None]*restart
e: list[ArrayOrContainerT] = [None]*restart

k = 0

norm_b = norm(b)
last_resid_norm = None
residual_norms = []
residual_norms: list[float] = []

iteration = 0
for iteration in range(maxiter):
# restart if required
if k == restart:
Expand All @@ -175,9 +228,11 @@ def norm(x):
callback(r)

if norm_r < tol*norm_b or norm_r == 0:
return GMRESResult(solution=x,
return GMRESResult(
solution=x,
residual_norms=residual_norms,
iteration_count=iteration, success=True,
iteration_count=iteration,
success=True,
state="success")
if last_resid_norm is not None:
if norm_r > 1.25*last_resid_norm:
Expand All @@ -186,9 +241,11 @@ def norm(x):
if hard_failure:
raise GMRESError(state)
else:
return GMRESResult(solution=x,
return GMRESResult(
solution=x,
residual_norms=residual_norms,
iteration_count=iteration, success=False,
iteration_count=iteration,
success=False,
state=state)
else:
print("*** WARNING: non-monotonic residuals in GMRES")
Expand All @@ -203,9 +260,11 @@ def norm(x):
if hard_failure:
raise GMRESError(state)
else:
return GMRESResult(solution=x,
return GMRESResult(
solution=x,
residual_norms=residual_norms,
iteration_count=iteration, success=False,
iteration_count=iteration,
success=False,
state=state)

last_resid_norm = norm_r
Expand All @@ -218,7 +277,7 @@ def norm(x):
rp = r

for _orth_trips in range(2):
for j in range(0, orth_count):
for j in range(orth_count):
d = dot(Ae[j], w)
w = w - d * Ae[j]
rp = rp - d * e[j]
Expand Down Expand Up @@ -248,28 +307,40 @@ def norm(x):
if hard_failure:
raise GMRESError(state)
else:
return GMRESResult(solution=x,
return GMRESResult(
solution=x,
residual_norms=residual_norms,
iteration_count=iteration, success=False,
iteration_count=iteration,
success=False,
state=state)

# }}}


# {{{ progress reporting

class ResidualPrinter:
def __init__(self, inner_product=structured_vdot):
class ResidualPrinter(Generic[ArrayOrContainerT]):
count: int
inner_product: InnerProduct[ArrayOrContainerT]

def __init__(
self,
inner_product: InnerProduct[ArrayOrContainerT] | None = None
) -> None:
if inner_product is None:
inner_product = structured_vdot

self.count = 0
self.inner_product = inner_product

def __call__(self, resid):
def __call__(self, resid: ArrayOrContainerT | None) -> None:
import sys
if resid is not None:
norm = np.sqrt(self.inner_product(resid, resid))
sys.stdout.write(f"IT {self.count:8d} {abs(norm):.8e}\n")
else:
sys.stdout.write(f"IT {self.count:8d}\n")

self.count += 1
sys.stdout.flush()

Expand All @@ -279,20 +350,19 @@ def __call__(self, resid):
# {{{ entrypoint

def gmres(
op: Callable[[ArrayOrContainerT], ArrayOrContainerT],
op: CallableOperator[ArrayOrContainerT] | HasMatVec[ArrayOrContainerT],
rhs: ArrayOrContainerT,
restart: int | None = None,
tol: float | None = None,
x0: ArrayOrContainerT | None = None,
inner_product: (
Callable[[ArrayOrContainerT, ArrayOrContainerT], float] | None) = None,
inner_product: InnerProduct[ArrayOrContainerT] | None = None,
maxiter: int | None = None,
hard_failure: bool | None = None,
no_progress_factor: float | None = None,
stall_iterations: int | None = None,
callback: Callable[[ArrayOrContainerT], None] | None = None,
progress: bool = False,
require_monotonicity: bool = True) -> GMRESResult:
require_monotonicity: bool = True) -> GMRESResult[ArrayOrContainerT]:
"""Solve a linear system :math:`Ax = b` using GMRES with restarts.

:arg op: a callable to evaluate :math:`A(x)`.
Expand All @@ -308,8 +378,6 @@ def gmres(
:arg stall_iterations: number of iterations with residual decrease
below *no_progress_factor* indicates stall. Set to ``0`` to disable
stall detection.

:return: a :class:`GMRESResult`.
"""
if inner_product is None:
from pytential.symbolic.execution import (
Expand All @@ -330,14 +398,13 @@ def gmres(
else:
callback = None

result = _gmres(op, rhs, restart=restart, tol=tol, x0=x0,
return _gmres(op, rhs, restart=restart, tol=tol, x0=x0,
dot=inner_product,
maxiter=maxiter, hard_failure=hard_failure,
no_progress_factor=no_progress_factor,
stall_iterations=stall_iterations, callback=callback,
require_monotonicity=require_monotonicity)

return result

# }}}

Expand Down
Loading
Loading