22import numpy as np
33import torch
44from numba import cuda
5- from functools import lru_cache
65
76# ---------------------------------------------------------------------------
87# Global settings & helpers
@@ -108,9 +107,31 @@ def tensor_to_cuda_array(tensor):
108107 raise ValueError ("Tensor must be on CUDA device" )
109108 return cuda .as_cuda_array (tensor .detach ())
110109
# ---------------------------------------------------------------------------
# Stream helper (cached external Numba stream)
# ---------------------------------------------------------------------------
# Single-slot cache describing the most recently wrapped stream. The wrapper
# is identified by the raw CUDA stream handle together with the device the
# PyTorch stream reports, because a raw handle on its own is not unique
# (the default stream presumably reports the same handle on every device).
_cached_stream_ptr = None
_cached_stream_device = None
_cached_numba_stream = None


def _get_numba_external_stream_for(pt_stream=None):
    """Return a Numba external stream wrapping a PyTorch CUDA stream.

    The most recently created wrapper is cached so that back-to-back kernel
    launches on the same stream do not rebuild it on every call.

    Args:
        pt_stream: Object exposing the raw CUDA stream handle as
            ``.cuda_stream`` (for example a ``torch.cuda.Stream``). When
            ``None``, ``torch.cuda.current_stream()`` is used.

    Returns:
        The object produced by ``cuda.external_stream`` for the stream's
        raw handle; the cached instance is returned when both the handle
        and the device match the previous call.
    """
    global _cached_stream_ptr, _cached_stream_device, _cached_numba_stream
    if pt_stream is None:
        pt_stream = torch.cuda.current_stream()

    # Torch exposes the underlying CUDA stream handle via .cuda_stream.
    ptr = int(pt_stream.cuda_stream)
    # Stream-like objects without a .device attribute fall back to None, so
    # they are still cached by handle alone.
    device = getattr(pt_stream, "device", None)

    # NOTE(review): Numba appears to bind an external stream to the CUDA
    # context that is current at creation time, so a wrapper built for one
    # device must not be handed out for another device whose stream happens
    # to share the same raw handle - hence the device check.
    if (
        _cached_numba_stream is not None
        and _cached_stream_ptr == ptr
        and _cached_stream_device == device
    ):
        return _cached_numba_stream

    # Pass the coerced integer handle: it is the exact value the cache is
    # keyed on.
    numba_stream = cuda.external_stream(ptr)
    _cached_stream_ptr = ptr
    _cached_stream_device = device
    _cached_numba_stream = numba_stream
    return numba_stream
111132
112133# === GPU-aware Trigonometric Table Generation ===
113- @ lru_cache ( maxsize = 2048 )
134+ # Caching removed: torch.Tensor hashes by object identity, so lru_cache cannot key on tensor values
114135def _trig_tables (angles , dtype = _DTYPE , device = None ):
115136 """Compute cosine and sine tables for input angles.
116137
@@ -465,12 +486,15 @@ def _parallel_2d_forward_kernel(
465486 # Mathematical basis: Bilinear interpolation formula f(x,y) = Σ f(xi,yi) * wi(x,y)
466487 # where wi(x,y) are the bilinear basis functions for each corner voxel
467488 # Weights are products of 1D linear interpolation weights: (1-dx) or dx, (1-dy) or dy
468- val = (
469- d_image [iy0 , ix0 ] * (1 - dx ) * (1 - dy ) +
470- d_image [iy0 , ix0 + 1 ] * dx * (1 - dy ) +
471- d_image [iy0 + 1 , ix0 ] * (1 - dx ) * dy +
472- d_image [iy0 + 1 , ix0 + 1 ] * dx * dy
473- )
489+ one_minus_dx = 1.0 - dx
490+ one_minus_dy = 1.0 - dy
491+ v00 = d_image [iy0 , ix0 ]
492+ v10 = d_image [iy0 , ix0 + 1 ]
493+ v01 = d_image [iy0 + 1 , ix0 ]
494+ v11 = d_image [iy0 + 1 , ix0 + 1 ]
495+ row0 = (v00 * one_minus_dx + v10 * dx ) * one_minus_dy
496+ row1 = (v01 * one_minus_dx + v11 * dx ) * dy
497+ val = row0 + row1
474498 # Accumulate contribution weighted by ray segment length (discrete line integral approximation)
475499 # This implements the Radon transform: integral of f(x,y) along the ray path
476500 accum += val * seg_len
@@ -756,12 +780,15 @@ def _fan_2d_forward_kernel(
756780 iy0 = max (0 , min (iy0 , Ny - 2 ))
757781
758782 # Bilinear interpolation (identical to parallel beam)
759- val = (
760- d_image [iy0 , ix0 ] * (1 - dx ) * (1 - dy ) +
761- d_image [iy0 , ix0 + 1 ] * dx * (1 - dy ) +
762- d_image [iy0 + 1 , ix0 ] * (1 - dx ) * dy +
763- d_image [iy0 + 1 , ix0 + 1 ] * dx * dy
764- )
783+ one_minus_dx = 1.0 - dx
784+ one_minus_dy = 1.0 - dy
785+ v00 = d_image [iy0 , ix0 ]
786+ v10 = d_image [iy0 , ix0 + 1 ]
787+ v01 = d_image [iy0 + 1 , ix0 ]
788+ v11 = d_image [iy0 + 1 , ix0 + 1 ]
789+ row0 = (v00 * one_minus_dx + v10 * dx ) * one_minus_dy
790+ row1 = (v01 * one_minus_dx + v11 * dx ) * dy
791+ val = row0 + row1
765792 accum += val * seg_len
766793
767794 # Voxel boundary crossing logic (identical to parallel beam)
@@ -1405,8 +1432,8 @@ def forward(ctx, image, angles, num_detectors, detector_spacing=1.0, voxel_spaci
14051432 angles = DeviceManager .ensure_device (angles , device )
14061433
14071434 # Ensure input is float32 for kernel compatibility
1408- image = image .to (dtype = torch .float32 )
1409- angles = angles .to (dtype = torch .float32 )
1435+ image = image .to (dtype = torch .float32 ). contiguous ()
1436+ angles = angles .to (dtype = torch .float32 ). contiguous ()
14101437
14111438 Ny , Nx = image .shape
14121439 n_angles = angles .shape [0 ]
@@ -1427,7 +1454,7 @@ def forward(ctx, image, angles, num_detectors, detector_spacing=1.0, voxel_spaci
14271454 cx , cy = _DTYPE (Nx * 0.5 ), _DTYPE (Ny * 0.5 )
14281455
14291456 pt_stream = torch .cuda .current_stream ()
1430- numba_stream = cuda . external_stream (pt_stream . cuda_stream )
1457+ numba_stream = _get_numba_external_stream_for (pt_stream )
14311458 _parallel_2d_forward_kernel [grid , tpb , numba_stream ](
14321459 d_image , Nx , Ny , d_sino , n_angles , num_detectors ,
14331460 _DTYPE (detector_spacing ), d_cos_arr , d_sin_arr , cx , cy , _DTYPE (voxel_spacing )
@@ -1445,8 +1472,8 @@ def backward(ctx, grad_sinogram):
14451472 grad_sinogram = DeviceManager .ensure_device (grad_sinogram , device )
14461473 angles = DeviceManager .ensure_device (angles , device )
14471474
1448- grad_sinogram = grad_sinogram .to (dtype = torch .float32 )
1449- angles = angles .to (dtype = torch .float32 )
1475+ grad_sinogram = grad_sinogram .to (dtype = torch .float32 ). contiguous ()
1476+ angles = angles .to (dtype = torch .float32 ). contiguous ()
14501477
14511478 n_angles = angles .shape [0 ]
14521479 grad_image = torch .zeros ((Ny , Nx ), dtype = grad_sinogram .dtype , device = device )
@@ -1461,7 +1488,7 @@ def backward(ctx, grad_sinogram):
14611488 cx , cy = _DTYPE (Nx * 0.5 ), _DTYPE (Ny * 0.5 )
14621489
14631490 pt_stream = torch .cuda .current_stream ()
1464- numba_stream = cuda . external_stream (pt_stream . cuda_stream )
1491+ numba_stream = _get_numba_external_stream_for (pt_stream )
14651492 _parallel_2d_backward_kernel [grid , tpb , numba_stream ](
14661493 d_grad_sino , n_angles , num_detectors ,
14671494 d_img_grad , Nx , Ny ,
@@ -1542,8 +1569,8 @@ def forward(ctx, sinogram, angles, detector_spacing=1.0, H=128, W=128, voxel_spa
15421569 angles = DeviceManager .ensure_device (angles , device )
15431570
15441571 # Ensure input is float32 for kernel compatibility
1545- sinogram = sinogram .to (dtype = torch .float32 )
1546- angles = angles .to (dtype = torch .float32 )
1572+ sinogram = sinogram .to (dtype = torch .float32 ). contiguous ()
1573+ angles = angles .to (dtype = torch .float32 ). contiguous ()
15471574
15481575 n_ang , n_det = sinogram .shape
15491576 Ny , Nx = H , W
@@ -1564,7 +1591,7 @@ def forward(ctx, sinogram, angles, detector_spacing=1.0, H=128, W=128, voxel_spa
15641591 cx , cy = _DTYPE (Nx * 0.5 ), _DTYPE (Ny * 0.5 )
15651592
15661593 pt_stream = torch .cuda .current_stream ()
1567- numba_stream = cuda . external_stream (pt_stream . cuda_stream )
1594+ numba_stream = _get_numba_external_stream_for (pt_stream )
15681595 _parallel_2d_backward_kernel [grid , tpb , numba_stream ](
15691596 d_sino , n_ang , n_det , d_reco , Nx , Ny ,
15701597 _DTYPE (detector_spacing ), d_cos_arr , d_sin_arr , cx , cy , _DTYPE (voxel_spacing )
@@ -1582,8 +1609,8 @@ def backward(ctx, grad_output):
15821609 grad_output = DeviceManager .ensure_device (grad_output , device )
15831610 angles = DeviceManager .ensure_device (angles , device )
15841611
1585- grad_output = grad_output .to (dtype = torch .float32 )
1586- angles = angles .to (dtype = torch .float32 )
1612+ grad_output = grad_output .to (dtype = torch .float32 ). contiguous ()
1613+ angles = angles .to (dtype = torch .float32 ). contiguous ()
15871614
15881615 Ny , Nx = grad_output .shape
15891616
@@ -1603,7 +1630,7 @@ def backward(ctx, grad_output):
16031630 cx , cy = _DTYPE (Nx * 0.5 ), _DTYPE (Ny * 0.5 )
16041631
16051632 pt_stream = torch .cuda .current_stream ()
1606- numba_stream = cuda . external_stream (pt_stream . cuda_stream )
1633+ numba_stream = _get_numba_external_stream_for (pt_stream )
16071634 _parallel_2d_forward_kernel [grid , tpb , numba_stream ](
16081635 d_grad_out , Nx , Ny , d_sino_grad , n_ang , n_det ,
16091636 _DTYPE (detector_spacing ), d_cos , d_sin , cx , cy , _DTYPE (voxel_spacing )
@@ -1687,8 +1714,8 @@ def forward(ctx, image, angles, num_detectors, detector_spacing, sdd, sid, voxel
16871714 image = DeviceManager .ensure_device (image , device )
16881715 angles = DeviceManager .ensure_device (angles , device )
16891716
1690- image = image .to (dtype = torch .float32 )
1691- angles = angles .to (dtype = torch .float32 )
1717+ image = image .to (dtype = torch .float32 ). contiguous ()
1718+ angles = angles .to (dtype = torch .float32 ). contiguous ()
16921719
16931720 Ny , Nx = image .shape
16941721 n_ang = angles .shape [0 ]
@@ -1705,7 +1732,7 @@ def forward(ctx, image, angles, num_detectors, detector_spacing, sdd, sid, voxel
17051732 cx , cy = _DTYPE (Nx * 0.5 ), _DTYPE (Ny * 0.5 )
17061733
17071734 pt_stream = torch .cuda .current_stream ()
1708- numba_stream = cuda . external_stream (pt_stream . cuda_stream )
1735+ numba_stream = _get_numba_external_stream_for (pt_stream )
17091736 _fan_2d_forward_kernel [grid , tpb , numba_stream ](
17101737 d_image , Nx , Ny , d_sino , n_ang , num_detectors ,
17111738 _DTYPE (detector_spacing ), d_cos_arr , d_sin_arr ,
@@ -1725,8 +1752,8 @@ def backward(ctx, grad_sinogram):
17251752 grad_sinogram = DeviceManager .ensure_device (grad_sinogram , device )
17261753 angles = DeviceManager .ensure_device (angles , device )
17271754
1728- grad_sinogram = grad_sinogram .to (dtype = torch .float32 )
1729- angles = angles .to (dtype = torch .float32 )
1755+ grad_sinogram = grad_sinogram .to (dtype = torch .float32 ). contiguous ()
1756+ angles = angles .to (dtype = torch .float32 ). contiguous ()
17301757
17311758 n_ang = angles .shape [0 ]
17321759 grad_img = torch .zeros ((Ny , Nx ), dtype = grad_sinogram .dtype , device = device )
@@ -1741,7 +1768,7 @@ def backward(ctx, grad_sinogram):
17411768 cx , cy = _DTYPE (Nx * 0.5 ), _DTYPE (Ny * 0.5 )
17421769
17431770 pt_stream = torch .cuda .current_stream ()
1744- numba_stream = cuda . external_stream (pt_stream . cuda_stream )
1771+ numba_stream = _get_numba_external_stream_for (pt_stream )
17451772 _fan_2d_backward_kernel [grid , tpb , numba_stream ](
17461773 d_grad_sino , n_ang , n_det , d_img_grad , Nx , Ny ,
17471774 _DTYPE (det_spacing ), d_cos_arr , d_sin_arr ,
@@ -1829,8 +1856,8 @@ def forward(ctx, sinogram, angles, detector_spacing, H, W, sdd, sid, voxel_spaci
18291856 sinogram = DeviceManager .ensure_device (sinogram , device )
18301857 angles = DeviceManager .ensure_device (angles , device )
18311858
1832- sinogram = sinogram .to (dtype = torch .float32 )
1833- angles = angles .to (dtype = torch .float32 )
1859+ sinogram = sinogram .to (dtype = torch .float32 ). contiguous ()
1860+ angles = angles .to (dtype = torch .float32 ). contiguous ()
18341861
18351862 n_ang , n_det = sinogram .shape
18361863 Ny , Nx = H , W
@@ -1847,7 +1874,7 @@ def forward(ctx, sinogram, angles, detector_spacing, H, W, sdd, sid, voxel_spaci
18471874 cx , cy = _DTYPE (Nx * 0.5 ), _DTYPE (Ny * 0.5 )
18481875
18491876 pt_stream = torch .cuda .current_stream ()
1850- numba_stream = cuda . external_stream (pt_stream . cuda_stream )
1877+ numba_stream = _get_numba_external_stream_for (pt_stream )
18511878 _fan_2d_backward_kernel [grid , tpb , numba_stream ](
18521879 d_sino , n_ang , n_det , d_reco , Nx , Ny ,
18531880 _DTYPE (detector_spacing ), d_cos_arr , d_sin_arr ,
@@ -1866,8 +1893,8 @@ def backward(ctx, grad_output):
18661893 grad_output = DeviceManager .ensure_device (grad_output , device )
18671894 angles = DeviceManager .ensure_device (angles , device )
18681895
1869- grad_output = grad_output .to (dtype = torch .float32 )
1870- angles = angles .to (dtype = torch .float32 )
1896+ grad_output = grad_output .to (dtype = torch .float32 ). contiguous ()
1897+ angles = angles .to (dtype = torch .float32 ). contiguous ()
18711898
18721899 Ny , Nx = grad_output .shape
18731900
@@ -1883,7 +1910,7 @@ def backward(ctx, grad_output):
18831910 cx , cy = _DTYPE (Nx * 0.5 ), _DTYPE (Ny * 0.5 )
18841911
18851912 pt_stream = torch .cuda .current_stream ()
1886- numba_stream = cuda . external_stream (pt_stream . cuda_stream )
1913+ numba_stream = _get_numba_external_stream_for (pt_stream )
18871914 _fan_2d_forward_kernel [grid , tpb , numba_stream ](
18881915 d_grad_out , Nx , Ny , d_sino_grad , n_ang , n_det ,
18891916 _DTYPE (det_spacing ), d_cos_arr , d_sin_arr ,
@@ -1972,8 +1999,8 @@ def forward(ctx, volume, angles, det_u, det_v, du, dv, sdd, sid, voxel_spacing=1
19721999 volume = DeviceManager .ensure_device (volume , device )
19732000 angles = DeviceManager .ensure_device (angles , device )
19742001
1975- volume = volume .to (dtype = torch .float32 )
1976- angles = angles .to (dtype = torch .float32 )
2002+ volume = volume .to (dtype = torch .float32 ). contiguous ()
2003+ angles = angles .to (dtype = torch .float32 ). contiguous ()
19772004
19782005 D , H , W = volume .shape
19792006 n_views = angles .shape [0 ]
@@ -1994,7 +2021,7 @@ def forward(ctx, volume, angles, det_u, det_v, du, dv, sdd, sid, voxel_spacing=1
19942021 cx , cy , cz = _DTYPE (W * 0.5 ), _DTYPE (H * 0.5 ), _DTYPE (D * 0.5 )
19952022
19962023 pt_stream = torch .cuda .current_stream ()
1997- numba_stream = cuda . external_stream (pt_stream . cuda_stream )
2024+ numba_stream = _get_numba_external_stream_for (pt_stream )
19982025 _cone_3d_forward_kernel [grid , tpb , numba_stream ](
19992026 d_vol , W , H , D , d_sino , n_views , det_u , det_v ,
20002027 _DTYPE (du ), _DTYPE (dv ), d_cos_arr , d_sin_arr ,
@@ -2016,8 +2043,8 @@ def backward(ctx, grad_sinogram):
20162043 grad_sinogram = DeviceManager .ensure_device (grad_sinogram , device )
20172044 angles = DeviceManager .ensure_device (angles , device )
20182045
2019- grad_sinogram = grad_sinogram .to (dtype = torch .float32 )
2020- angles = angles .to (dtype = torch .float32 )
2046+ grad_sinogram = grad_sinogram .to (dtype = torch .float32 ). contiguous ()
2047+ angles = angles .to (dtype = torch .float32 ). contiguous ()
20212048
20222049 n_views = angles .shape [0 ]
20232050
@@ -2033,7 +2060,7 @@ def backward(ctx, grad_sinogram):
20332060 cx , cy , cz = _DTYPE (W * 0.5 ), _DTYPE (H * 0.5 ), _DTYPE (D * 0.5 )
20342061
20352062 pt_stream = torch .cuda .current_stream ()
2036- numba_stream = cuda . external_stream (pt_stream . cuda_stream )
2063+ numba_stream = _get_numba_external_stream_for (pt_stream )
20372064 _cone_3d_backward_kernel [grid , tpb , numba_stream ](
20382065 d_grad_sino , n_views , det_u , det_v , d_vol_grad , W , H , D ,
20392066 _DTYPE (du ), _DTYPE (dv ), d_cos_arr , d_sin_arr ,
@@ -2134,8 +2161,8 @@ def forward(ctx, sinogram, angles, D, H, W, du, dv, sdd, sid, voxel_spacing=1.0)
21342161 sinogram = DeviceManager .ensure_device (sinogram , device )
21352162 angles = DeviceManager .ensure_device (angles , device )
21362163
2137- sinogram = sinogram .to (dtype = torch .float32 )
2138- angles = angles .to (dtype = torch .float32 )
2164+ sinogram = sinogram .to (dtype = torch .float32 ). contiguous ()
2165+ angles = angles .to (dtype = torch .float32 ). contiguous ()
21392166
21402167 n_views , n_u , n_v = sinogram .shape
21412168
@@ -2154,7 +2181,7 @@ def forward(ctx, sinogram, angles, D, H, W, du, dv, sdd, sid, voxel_spacing=1.0)
21542181 cx , cy , cz = _DTYPE (W * 0.5 ), _DTYPE (H * 0.5 ), _DTYPE (D * 0.5 )
21552182
21562183 pt_stream = torch .cuda .current_stream ()
2157- numba_stream = cuda . external_stream (pt_stream . cuda_stream )
2184+ numba_stream = _get_numba_external_stream_for (pt_stream )
21582185 _cone_3d_backward_kernel [grid , tpb , numba_stream ](
21592186 d_sino , n_views , n_u , n_v , d_reco , W , H , D ,
21602187 _DTYPE (du ), _DTYPE (dv ), d_cos_arr , d_sin_arr ,
@@ -2176,8 +2203,8 @@ def backward(ctx, grad_output):
21762203 grad_output = DeviceManager .ensure_device (grad_output , device )
21772204 angles = DeviceManager .ensure_device (angles , device )
21782205
2179- grad_output = grad_output .to (dtype = torch .float32 )
2180- angles = angles .to (dtype = torch .float32 )
2206+ grad_output = grad_output .to (dtype = torch .float32 ). contiguous ()
2207+ angles = angles .to (dtype = torch .float32 ). contiguous ()
21812208
21822209 n_views = angles .shape [0 ]
21832210
@@ -2194,7 +2221,7 @@ def backward(ctx, grad_output):
21942221 cx , cy , cz = _DTYPE (W * 0.5 ), _DTYPE (H * 0.5 ), _DTYPE (D * 0.5 )
21952222
21962223 pt_stream = torch .cuda .current_stream ()
2197- numba_stream = cuda . external_stream (pt_stream . cuda_stream )
2224+ numba_stream = _get_numba_external_stream_for (pt_stream )
21982225 _cone_3d_forward_kernel [grid , tpb , numba_stream ](
21992226 d_grad_out , W , H , D , d_sino_grad , n_views , n_u , n_v ,
22002227 _DTYPE (du ), _DTYPE (dv ), d_cos_arr , d_sin_arr ,
0 commit comments