Skip to content

Commit 87de57e

Browse files
committed
Add external stream caching for improved CUDA performance and refactor interpolation calculations for clarity
1 parent e618e1d commit 87de57e

1 file changed

Lines changed: 77 additions & 50 deletions

File tree

diffct/differentiable.py

Lines changed: 77 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import numpy as np
33
import torch
44
from numba import cuda
5-
from functools import lru_cache
65

76
# ---------------------------------------------------------------------------
87
# Global settings & helpers
@@ -108,9 +107,31 @@ def tensor_to_cuda_array(tensor):
108107
raise ValueError("Tensor must be on CUDA device")
109108
return cuda.as_cuda_array(tensor.detach())
110109

110+
# ---------------------------------------------------------------------------
111+
# Stream helper (cached external Numba stream)
112+
# ---------------------------------------------------------------------------
113+
_cached_stream_ptr = None
114+
_cached_numba_stream = None
115+
116+
def _get_numba_external_stream_for(pt_stream=None):
117+
"""
118+
Return a cached numba.cuda.external_stream for the current PyTorch CUDA stream.
119+
Caches by the underlying CUDA stream pointer to avoid repeated construction.
120+
"""
121+
global _cached_stream_ptr, _cached_numba_stream
122+
if pt_stream is None:
123+
pt_stream = torch.cuda.current_stream()
124+
# Torch exposes an underlying CUDA stream handle via .cuda_stream
125+
ptr = int(pt_stream.cuda_stream)
126+
if _cached_stream_ptr == ptr and _cached_numba_stream is not None:
127+
return _cached_numba_stream
128+
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
129+
_cached_stream_ptr = ptr
130+
_cached_numba_stream = numba_stream
131+
return numba_stream
111132

112133
# === GPU-aware Trigonometric Table Generation ===
113-
@lru_cache(maxsize=2048)
134+
# Caching removed: torch.Tensor is unhashable for lru_cache
114135
def _trig_tables(angles, dtype=_DTYPE, device=None):
115136
"""Compute cosine and sine tables for input angles.
116137
@@ -465,12 +486,15 @@ def _parallel_2d_forward_kernel(
465486
# Mathematical basis: Bilinear interpolation formula f(x,y) = Σ f(xi,yi) * wi(x,y)
466487
# where wi(x,y) are the bilinear basis functions for each corner voxel
467488
# Weights are products of 1D linear interpolation weights: (1-dx) or dx, (1-dy) or dy
468-
val = (
469-
d_image[iy0, ix0] * (1 - dx) * (1 - dy) +
470-
d_image[iy0, ix0 + 1] * dx * (1 - dy) +
471-
d_image[iy0 + 1, ix0] * (1 - dx) * dy +
472-
d_image[iy0 + 1, ix0 + 1] * dx * dy
473-
)
489+
one_minus_dx = 1.0 - dx
490+
one_minus_dy = 1.0 - dy
491+
v00 = d_image[iy0, ix0]
492+
v10 = d_image[iy0, ix0 + 1]
493+
v01 = d_image[iy0 + 1, ix0]
494+
v11 = d_image[iy0 + 1, ix0 + 1]
495+
row0 = (v00 * one_minus_dx + v10 * dx) * one_minus_dy
496+
row1 = (v01 * one_minus_dx + v11 * dx) * dy
497+
val = row0 + row1
474498
# Accumulate contribution weighted by ray segment length (discrete line integral approximation)
475499
# This implements the Radon transform: integral of f(x,y) along the ray path
476500
accum += val * seg_len
@@ -756,12 +780,15 @@ def _fan_2d_forward_kernel(
756780
iy0 = max(0, min(iy0, Ny - 2))
757781

758782
# Bilinear interpolation (identical to parallel beam)
759-
val = (
760-
d_image[iy0, ix0] * (1 - dx) * (1 - dy) +
761-
d_image[iy0, ix0 + 1] * dx * (1 - dy) +
762-
d_image[iy0 + 1, ix0] * (1 - dx) * dy +
763-
d_image[iy0 + 1, ix0 + 1] * dx * dy
764-
)
783+
one_minus_dx = 1.0 - dx
784+
one_minus_dy = 1.0 - dy
785+
v00 = d_image[iy0, ix0]
786+
v10 = d_image[iy0, ix0 + 1]
787+
v01 = d_image[iy0 + 1, ix0]
788+
v11 = d_image[iy0 + 1, ix0 + 1]
789+
row0 = (v00 * one_minus_dx + v10 * dx) * one_minus_dy
790+
row1 = (v01 * one_minus_dx + v11 * dx) * dy
791+
val = row0 + row1
765792
accum += val * seg_len
766793

767794
# Voxel boundary crossing logic (identical to parallel beam)
@@ -1405,8 +1432,8 @@ def forward(ctx, image, angles, num_detectors, detector_spacing=1.0, voxel_spaci
14051432
angles = DeviceManager.ensure_device(angles, device)
14061433

14071434
# Ensure input is float32 for kernel compatibility
1408-
image = image.to(dtype=torch.float32)
1409-
angles = angles.to(dtype=torch.float32)
1435+
image = image.to(dtype=torch.float32).contiguous()
1436+
angles = angles.to(dtype=torch.float32).contiguous()
14101437

14111438
Ny, Nx = image.shape
14121439
n_angles = angles.shape[0]
@@ -1427,7 +1454,7 @@ def forward(ctx, image, angles, num_detectors, detector_spacing=1.0, voxel_spaci
14271454
cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
14281455

14291456
pt_stream = torch.cuda.current_stream()
1430-
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1457+
numba_stream = _get_numba_external_stream_for(pt_stream)
14311458
_parallel_2d_forward_kernel[grid, tpb, numba_stream](
14321459
d_image, Nx, Ny, d_sino, n_angles, num_detectors,
14331460
_DTYPE(detector_spacing), d_cos_arr, d_sin_arr, cx, cy, _DTYPE(voxel_spacing)
@@ -1445,8 +1472,8 @@ def backward(ctx, grad_sinogram):
14451472
grad_sinogram = DeviceManager.ensure_device(grad_sinogram, device)
14461473
angles = DeviceManager.ensure_device(angles, device)
14471474

1448-
grad_sinogram = grad_sinogram.to(dtype=torch.float32)
1449-
angles = angles.to(dtype=torch.float32)
1475+
grad_sinogram = grad_sinogram.to(dtype=torch.float32).contiguous()
1476+
angles = angles.to(dtype=torch.float32).contiguous()
14501477

14511478
n_angles = angles.shape[0]
14521479
grad_image = torch.zeros((Ny, Nx), dtype=grad_sinogram.dtype, device=device)
@@ -1461,7 +1488,7 @@ def backward(ctx, grad_sinogram):
14611488
cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
14621489

14631490
pt_stream = torch.cuda.current_stream()
1464-
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1491+
numba_stream = _get_numba_external_stream_for(pt_stream)
14651492
_parallel_2d_backward_kernel[grid, tpb, numba_stream](
14661493
d_grad_sino, n_angles, num_detectors,
14671494
d_img_grad, Nx, Ny,
@@ -1542,8 +1569,8 @@ def forward(ctx, sinogram, angles, detector_spacing=1.0, H=128, W=128, voxel_spa
15421569
angles = DeviceManager.ensure_device(angles, device)
15431570

15441571
# Ensure input is float32 for kernel compatibility
1545-
sinogram = sinogram.to(dtype=torch.float32)
1546-
angles = angles.to(dtype=torch.float32)
1572+
sinogram = sinogram.to(dtype=torch.float32).contiguous()
1573+
angles = angles.to(dtype=torch.float32).contiguous()
15471574

15481575
n_ang, n_det = sinogram.shape
15491576
Ny, Nx = H, W
@@ -1564,7 +1591,7 @@ def forward(ctx, sinogram, angles, detector_spacing=1.0, H=128, W=128, voxel_spa
15641591
cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
15651592

15661593
pt_stream = torch.cuda.current_stream()
1567-
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1594+
numba_stream = _get_numba_external_stream_for(pt_stream)
15681595
_parallel_2d_backward_kernel[grid, tpb, numba_stream](
15691596
d_sino, n_ang, n_det, d_reco, Nx, Ny,
15701597
_DTYPE(detector_spacing), d_cos_arr, d_sin_arr, cx, cy, _DTYPE(voxel_spacing)
@@ -1582,8 +1609,8 @@ def backward(ctx, grad_output):
15821609
grad_output = DeviceManager.ensure_device(grad_output, device)
15831610
angles = DeviceManager.ensure_device(angles, device)
15841611

1585-
grad_output = grad_output.to(dtype=torch.float32)
1586-
angles = angles.to(dtype=torch.float32)
1612+
grad_output = grad_output.to(dtype=torch.float32).contiguous()
1613+
angles = angles.to(dtype=torch.float32).contiguous()
15871614

15881615
Ny, Nx = grad_output.shape
15891616

@@ -1603,7 +1630,7 @@ def backward(ctx, grad_output):
16031630
cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
16041631

16051632
pt_stream = torch.cuda.current_stream()
1606-
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1633+
numba_stream = _get_numba_external_stream_for(pt_stream)
16071634
_parallel_2d_forward_kernel[grid, tpb, numba_stream](
16081635
d_grad_out, Nx, Ny, d_sino_grad, n_ang, n_det,
16091636
_DTYPE(detector_spacing), d_cos, d_sin, cx, cy, _DTYPE(voxel_spacing)
@@ -1687,8 +1714,8 @@ def forward(ctx, image, angles, num_detectors, detector_spacing, sdd, sid, voxel
16871714
image = DeviceManager.ensure_device(image, device)
16881715
angles = DeviceManager.ensure_device(angles, device)
16891716

1690-
image = image.to(dtype=torch.float32)
1691-
angles = angles.to(dtype=torch.float32)
1717+
image = image.to(dtype=torch.float32).contiguous()
1718+
angles = angles.to(dtype=torch.float32).contiguous()
16921719

16931720
Ny, Nx = image.shape
16941721
n_ang = angles.shape[0]
@@ -1705,7 +1732,7 @@ def forward(ctx, image, angles, num_detectors, detector_spacing, sdd, sid, voxel
17051732
cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
17061733

17071734
pt_stream = torch.cuda.current_stream()
1708-
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1735+
numba_stream = _get_numba_external_stream_for(pt_stream)
17091736
_fan_2d_forward_kernel[grid, tpb, numba_stream](
17101737
d_image, Nx, Ny, d_sino, n_ang, num_detectors,
17111738
_DTYPE(detector_spacing), d_cos_arr, d_sin_arr,
@@ -1725,8 +1752,8 @@ def backward(ctx, grad_sinogram):
17251752
grad_sinogram = DeviceManager.ensure_device(grad_sinogram, device)
17261753
angles = DeviceManager.ensure_device(angles, device)
17271754

1728-
grad_sinogram = grad_sinogram.to(dtype=torch.float32)
1729-
angles = angles.to(dtype=torch.float32)
1755+
grad_sinogram = grad_sinogram.to(dtype=torch.float32).contiguous()
1756+
angles = angles.to(dtype=torch.float32).contiguous()
17301757

17311758
n_ang = angles.shape[0]
17321759
grad_img = torch.zeros((Ny, Nx), dtype=grad_sinogram.dtype, device=device)
@@ -1741,7 +1768,7 @@ def backward(ctx, grad_sinogram):
17411768
cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
17421769

17431770
pt_stream = torch.cuda.current_stream()
1744-
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1771+
numba_stream = _get_numba_external_stream_for(pt_stream)
17451772
_fan_2d_backward_kernel[grid, tpb, numba_stream](
17461773
d_grad_sino, n_ang, n_det, d_img_grad, Nx, Ny,
17471774
_DTYPE(det_spacing), d_cos_arr, d_sin_arr,
@@ -1829,8 +1856,8 @@ def forward(ctx, sinogram, angles, detector_spacing, H, W, sdd, sid, voxel_spaci
18291856
sinogram = DeviceManager.ensure_device(sinogram, device)
18301857
angles = DeviceManager.ensure_device(angles, device)
18311858

1832-
sinogram = sinogram.to(dtype=torch.float32)
1833-
angles = angles.to(dtype=torch.float32)
1859+
sinogram = sinogram.to(dtype=torch.float32).contiguous()
1860+
angles = angles.to(dtype=torch.float32).contiguous()
18341861

18351862
n_ang, n_det = sinogram.shape
18361863
Ny, Nx = H, W
@@ -1847,7 +1874,7 @@ def forward(ctx, sinogram, angles, detector_spacing, H, W, sdd, sid, voxel_spaci
18471874
cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
18481875

18491876
pt_stream = torch.cuda.current_stream()
1850-
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1877+
numba_stream = _get_numba_external_stream_for(pt_stream)
18511878
_fan_2d_backward_kernel[grid, tpb, numba_stream](
18521879
d_sino, n_ang, n_det, d_reco, Nx, Ny,
18531880
_DTYPE(detector_spacing), d_cos_arr, d_sin_arr,
@@ -1866,8 +1893,8 @@ def backward(ctx, grad_output):
18661893
grad_output = DeviceManager.ensure_device(grad_output, device)
18671894
angles = DeviceManager.ensure_device(angles, device)
18681895

1869-
grad_output = grad_output.to(dtype=torch.float32)
1870-
angles = angles.to(dtype=torch.float32)
1896+
grad_output = grad_output.to(dtype=torch.float32).contiguous()
1897+
angles = angles.to(dtype=torch.float32).contiguous()
18711898

18721899
Ny, Nx = grad_output.shape
18731900

@@ -1883,7 +1910,7 @@ def backward(ctx, grad_output):
18831910
cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
18841911

18851912
pt_stream = torch.cuda.current_stream()
1886-
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1913+
numba_stream = _get_numba_external_stream_for(pt_stream)
18871914
_fan_2d_forward_kernel[grid, tpb, numba_stream](
18881915
d_grad_out, Nx, Ny, d_sino_grad, n_ang, n_det,
18891916
_DTYPE(det_spacing), d_cos_arr, d_sin_arr,
@@ -1972,8 +1999,8 @@ def forward(ctx, volume, angles, det_u, det_v, du, dv, sdd, sid, voxel_spacing=1
19721999
volume = DeviceManager.ensure_device(volume, device)
19732000
angles = DeviceManager.ensure_device(angles, device)
19742001

1975-
volume = volume.to(dtype=torch.float32)
1976-
angles = angles.to(dtype=torch.float32)
2002+
volume = volume.to(dtype=torch.float32).contiguous()
2003+
angles = angles.to(dtype=torch.float32).contiguous()
19772004

19782005
D, H, W = volume.shape
19792006
n_views = angles.shape[0]
@@ -1994,7 +2021,7 @@ def forward(ctx, volume, angles, det_u, det_v, du, dv, sdd, sid, voxel_spacing=1
19942021
cx, cy, cz = _DTYPE(W * 0.5), _DTYPE(H * 0.5), _DTYPE(D * 0.5)
19952022

19962023
pt_stream = torch.cuda.current_stream()
1997-
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
2024+
numba_stream = _get_numba_external_stream_for(pt_stream)
19982025
_cone_3d_forward_kernel[grid, tpb, numba_stream](
19992026
d_vol, W, H, D, d_sino, n_views, det_u, det_v,
20002027
_DTYPE(du), _DTYPE(dv), d_cos_arr, d_sin_arr,
@@ -2016,8 +2043,8 @@ def backward(ctx, grad_sinogram):
20162043
grad_sinogram = DeviceManager.ensure_device(grad_sinogram, device)
20172044
angles = DeviceManager.ensure_device(angles, device)
20182045

2019-
grad_sinogram = grad_sinogram.to(dtype=torch.float32)
2020-
angles = angles.to(dtype=torch.float32)
2046+
grad_sinogram = grad_sinogram.to(dtype=torch.float32).contiguous()
2047+
angles = angles.to(dtype=torch.float32).contiguous()
20212048

20222049
n_views = angles.shape[0]
20232050

@@ -2033,7 +2060,7 @@ def backward(ctx, grad_sinogram):
20332060
cx, cy, cz = _DTYPE(W * 0.5), _DTYPE(H * 0.5), _DTYPE(D * 0.5)
20342061

20352062
pt_stream = torch.cuda.current_stream()
2036-
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
2063+
numba_stream = _get_numba_external_stream_for(pt_stream)
20372064
_cone_3d_backward_kernel[grid, tpb, numba_stream](
20382065
d_grad_sino, n_views, det_u, det_v, d_vol_grad, W, H, D,
20392066
_DTYPE(du), _DTYPE(dv), d_cos_arr, d_sin_arr,
@@ -2134,8 +2161,8 @@ def forward(ctx, sinogram, angles, D, H, W, du, dv, sdd, sid, voxel_spacing=1.0)
21342161
sinogram = DeviceManager.ensure_device(sinogram, device)
21352162
angles = DeviceManager.ensure_device(angles, device)
21362163

2137-
sinogram = sinogram.to(dtype=torch.float32)
2138-
angles = angles.to(dtype=torch.float32)
2164+
sinogram = sinogram.to(dtype=torch.float32).contiguous()
2165+
angles = angles.to(dtype=torch.float32).contiguous()
21392166

21402167
n_views, n_u, n_v = sinogram.shape
21412168

@@ -2154,7 +2181,7 @@ def forward(ctx, sinogram, angles, D, H, W, du, dv, sdd, sid, voxel_spacing=1.0)
21542181
cx, cy, cz = _DTYPE(W * 0.5), _DTYPE(H * 0.5), _DTYPE(D * 0.5)
21552182

21562183
pt_stream = torch.cuda.current_stream()
2157-
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
2184+
numba_stream = _get_numba_external_stream_for(pt_stream)
21582185
_cone_3d_backward_kernel[grid, tpb, numba_stream](
21592186
d_sino, n_views, n_u, n_v, d_reco, W, H, D,
21602187
_DTYPE(du), _DTYPE(dv), d_cos_arr, d_sin_arr,
@@ -2176,8 +2203,8 @@ def backward(ctx, grad_output):
21762203
grad_output = DeviceManager.ensure_device(grad_output, device)
21772204
angles = DeviceManager.ensure_device(angles, device)
21782205

2179-
grad_output = grad_output.to(dtype=torch.float32)
2180-
angles = angles.to(dtype=torch.float32)
2206+
grad_output = grad_output.to(dtype=torch.float32).contiguous()
2207+
angles = angles.to(dtype=torch.float32).contiguous()
21812208

21822209
n_views = angles.shape[0]
21832210

@@ -2194,7 +2221,7 @@ def backward(ctx, grad_output):
21942221
cx, cy, cz = _DTYPE(W * 0.5), _DTYPE(H * 0.5), _DTYPE(D * 0.5)
21952222

21962223
pt_stream = torch.cuda.current_stream()
2197-
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
2224+
numba_stream = _get_numba_external_stream_for(pt_stream)
21982225
_cone_3d_forward_kernel[grid, tpb, numba_stream](
21992226
d_grad_out, W, H, D, d_sino_grad, n_views, n_u, n_v,
22002227
_DTYPE(du), _DTYPE(dv), d_cos_arr, d_sin_arr,

0 commit comments

Comments
 (0)