Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 10 additions & 26 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -36153,8 +36153,8 @@
"code": "reportAny",
"range": {
"startColumn": 15,
"endColumn": 83,
"lineCount": 1
"endColumn": 46,
"lineCount": 5
}
},
{
Expand Down Expand Up @@ -36229,43 +36229,27 @@
"lineCount": 1
}
},
{
"code": "reportMissingTypeStubs",
"range": {
"startColumn": 13,
"endColumn": 27,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 13,
"endColumn": 20,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 26,
"endColumn": 40,
"startColumn": 8,
"endColumn": 12,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 26,
"endColumn": 48,
"startColumn": 15,
"endColumn": 23,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 26,
"endColumn": 48,
"startColumn": 40,
"endColumn": 47,
"lineCount": 1
}
},
Expand Down Expand Up @@ -37331,7 +37315,7 @@
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 48,
"endColumn": 60,
"endColumn": 63,
"lineCount": 1
}
},
Expand All @@ -37355,7 +37339,7 @@
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 49,
"endColumn": 61,
"endColumn": 64,
"lineCount": 1
}
},
Expand Down
11 changes: 11 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ jobs:
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project-within-miniconda.sh
. ./build-and-test-py-project-within-miniconda.sh

pytest_symengine_loopy_fft:
name: Conda Pytest Symengine with Loopy FFT
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- name: "Main Script"
run: |
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project-within-miniconda.sh
export SUMPY_FFT_BACKEND=loopy
. ./build-and-test-py-project-within-miniconda.sh

pytest_symengine:
name: Conda Pytest Symengine
runs-on: ubuntu-latest
Expand Down
6 changes: 5 additions & 1 deletion sumpy/test/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ def test_fft(actx_factory: ArrayContextFactory, size: int):
inp_dev = actx.from_numpy(inp)
out = fft(inp)

fft_func = loopy_fft(inp.shape, inverse=False, complex_dtype=inp.dtype.type)
fft_func = loopy_fft(
inp.shape[-1],
n_batch_dims=len(inp.shape) - 1,
inverse=False,
complex_dtype=inp.dtype.type)
_evt, (out_dev,) = fft_func(actx.queue, y=inp_dev)

assert np.allclose(actx.to_numpy(out_dev), out)
Expand Down
80 changes: 43 additions & 37 deletions sumpy/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import loopy as lp
import pytools.obj_array as obj_array
from pymbolic.mapper.dependency import DependencyMapper
from pyopencl import MemoryObjectHolder
from pyopencl.characterize import get_pocl_version
from pytools import memoize_method
from pytools.tag import Tag, tag_dataclass

Expand Down Expand Up @@ -752,7 +752,8 @@ def wait(self):


def loopy_fft(
shape: tuple[int, ...],
n: int,
*, n_batch_dims: int,
inverse: bool,
complex_dtype: DTypeLike,
index_dtype: DTypeLike | None = None,
Expand All @@ -766,7 +767,6 @@ def loopy_fft(
complex_dtype = np.dtype(complex_dtype)

sign = 1 if not inverse else -1
n = shape[-1]

m = n
factors = []
Expand All @@ -776,13 +776,13 @@ def loopy_fft(

nfft = n

broadcast_dims = tuple(var(f"j{d}") for d in range(len(shape) - 1))
batch_dims = tuple(var(f"j{d}") for d in range(n_batch_dims))

domains = [
"{[i]: 0<=i<n}",
"{[i2]: 0<=i2<n}",
]
domains += [f"{{[j{d}]: 0<=j{d}<{shape[d]} }}" for d in range(len(shape) - 1)]
domains += [f"{{[j{d}]: 0<=j{d}<Nbatch{d} }}" for d in range(n_batch_dims)]

x = var("x")
y = var("y")
Expand All @@ -792,7 +792,7 @@ def loopy_fft(

fixed_parameters = {"const": complex_dtype.type(sign*(-2j)*pi/n), "n": int(n)}

index = (*broadcast_dims, i2)
index = (*batch_dims, i2)
insns = [
"exp_table[i] = exp(const * i) {id=exp_table}",
lp.Assignment(
Expand All @@ -819,19 +819,19 @@ def loopy_fft(
table_idx = var(f"table_idx_{ilev}")
exp = var(f"exp_{ilev}")

i_bcast = (*broadcast_dims, i)
i2_bcast = (*broadcast_dims, i2)
iN_bcast = (*broadcast_dims, ifft + nfft * (iN1 * N2 + iN2)) # noqa: N806
i_batch = (*batch_dims, i)
i2_batch = (*batch_dims, i2)
iN_batch = (*batch_dims, ifft + nfft * (iN1 * N2 + iN2)) # noqa: N806

insns += [
lp.Assignment(
assignee=temp[i],
expression=x[i_bcast],
expression=x[i_batch],
id=f"copy_{ilev}",
happens_after=frozenset([init_happens_after]),
),
lp.Assignment(
assignee=x[i2_bcast],
assignee=x[i2_batch],
expression=0,
id=f"reset_{ilev}",
happens_after=frozenset([f"copy_{ilev}"])),
Expand All @@ -847,11 +847,11 @@ def loopy_fft(
id=f"exp_{ilev}",
happens_after=frozenset([f"idx_{ilev}"]),
within_inames=frozenset({x.name for x in
[*broadcast_dims, iN1_sum, iN1, iN2]}),
[*batch_dims, iN1_sum, iN1, iN2]}),
temp_var_type=lp.Optional(complex_dtype)),
lp.Assignment(
assignee=x[iN_bcast],
expression=(x[iN_bcast]
assignee=x[iN_batch],
expression=(x[iN_batch]
+ exp * temp[ifft + nfft * (iN2*N1 + iN1_sum)]),
id=f"update_{ilev}",
happens_after=frozenset([f"exp_{ilev}"])),
Expand All @@ -870,6 +870,7 @@ def loopy_fft(
if not dom.startswith("{"):
domains[idom] = "{" + dom + "}"

shape = (*[var(f"Nbatch{i}") for i in range(n_batch_dims)], n)
kernel_data = [
lp.GlobalArg("x", shape=shape, is_input=False, is_output=True,
dtype=complex_dtype),
Expand All @@ -884,7 +885,7 @@ def loopy_fft(

if n == 1:
domains = domains[2:]
index = (*broadcast_dims, 0)
index = (*batch_dims, 0)
insns = [
lp.Assignment(
assignee=x[index],
Expand All @@ -894,7 +895,7 @@ def loopy_fft(
kernel_data = kernel_data[:2]
elif inverse:
domains += ["{[i3]: 0<=i3<n}"]
index = (*broadcast_dims, i3)
index = (*batch_dims, i3)
insns += [
lp.Assignment(
assignee=x[index],
Expand All @@ -915,10 +916,12 @@ def loopy_fft(
index_dtype=index_dtype,
)

if broadcast_dims:
if batch_dims:
knl = lp.split_iname(knl, "j0", 32, inner_tag="l.0", outer_tag="g.0")
knl = lp.add_inames_for_unused_hw_axes(knl)

knl = lp.preprocess_kernel(knl)
knl = lp.linearize(knl)
return knl


Expand Down Expand Up @@ -956,15 +959,22 @@ def _get_fft_backend(queue: pyopencl.CommandQueue) -> FFTBackend:
import platform
import sys

if (sys.platform == "darwin"
and platform.machine() == "x86_64"
and queue.context.devices[0].platform.name
== "Portable Computing Language"):
warnings.warn(
"PoCL miscompiles some VkFFT kernels. "
"See https://github.com/inducer/sumpy/issues/129. "
"Falling back to slower implementation.", stacklevel=3)
return FFTBackend.loopy
pocl_ver = get_pocl_version(queue.device.platform)
if pocl_ver is not None:
if pocl_ver >= (7,):
warnings.warn(
"PoCL>=7 miscompiles VkFFT. "
"See https://github.com/pocl/pocl/issues/2069 for details. "
"Falling back to slower implementation.", stacklevel=3)
return FFTBackend.loopy

if (sys.platform == "darwin"
and platform.machine() == "x86_64"):
warnings.warn(
"PoCL crashes on some VkFFT kernels on MacOS. "
"See https://github.com/inducer/sumpy/issues/129. "
"Falling back to slower implementation.", stacklevel=3)
return FFTBackend.loopy

return FFTBackend.pyvkfft

Expand All @@ -983,7 +993,11 @@ def get_opencl_fft_app(
backend = _get_fft_backend(queue)

if backend == FFTBackend.loopy:
return loopy_fft(shape, inverse=inverse, complex_dtype=dtype.type), backend
return loopy_fft(
shape[-1],
n_batch_dims=len(shape) - 1,
inverse=inverse,
complex_dtype=dtype.type), backend
elif backend == FFTBackend.pyvkfft:
from pyvkfft.opencl import VkFFTApp
app = VkFFTApp(shape=shape, dtype=dtype, queue=queue, ndim=1, inplace=False)
Expand Down Expand Up @@ -1033,17 +1047,9 @@ def run_opencl_fft(
else:
output_vec = cla.empty_like(input_vec, queue)

# FIXME: use the public API once
# https://github.com/vincefn/pyvkfft/pull/17 is in
from pyvkfft.opencl import _vkfft_opencl
if inverse: # noqa: SIM108
meth = _vkfft_opencl.ifft
else:
meth = _vkfft_opencl.fft
meth = app.ifft if inverse else app.fft

assert isinstance(output_vec.data, MemoryObjectHolder)
meth(app.app, int(input_vec.data.int_ptr),
int(output_vec.data.int_ptr), int(queue.int_ptr))
meth(input_vec, output_vec, queue=queue)

if queue.device.platform.name == "NVIDIA CUDA":
end_evt = cl.enqueue_marker(queue)
Expand Down
4 changes: 2 additions & 2 deletions sumpy/toys.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,9 @@ def _m2l(psource, to_center, to_rscale, to_order: int, e2e, expn_class, expn_kwa
from sumpy.tools import get_native_event, get_opencl_fft_app, run_opencl_fft

if toy_ctx.use_fft:
fft_app = get_opencl_fft_app(queue, (expn_size,),
fft_app = get_opencl_fft_app(queue, (1, expn_size,),
dtype=preprocessed_src_expansions.dtype, inverse=False)
ifft_app = get_opencl_fft_app(queue, (expn_size,),
ifft_app = get_opencl_fft_app(queue, (1, expn_size,),
dtype=preprocessed_src_expansions.dtype, inverse=True)

evt, preprocessed_src_expansions = run_opencl_fft(fft_app, queue,
Expand Down
Loading