Skip to content

Commit 8342ee2

Browse files
committed
Enhance CUDA kernel integration by utilizing external streams for improved performance in 2D and 3D projector functions
1 parent 49ac06f commit 8342ee2

1 file changed

Lines changed: 36 additions & 24 deletions

File tree

diffct/differentiable.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,11 +1402,12 @@ def forward(ctx, image, angles, num_detectors, detector_spacing=1.0, voxel_spaci
14021402
grid, tpb = _grid_2d(n_angles, num_detectors)
14031403
cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
14041404

1405-
_parallel_2d_forward_kernel[grid, tpb](
1405+
pt_stream = torch.cuda.current_stream()
1406+
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1407+
_parallel_2d_forward_kernel[grid, tpb, numba_stream](
14061408
d_image, Nx, Ny, d_sino, n_angles, num_detectors,
14071409
_DTYPE(detector_spacing), d_cos_arr, d_sin_arr, cx, cy, _DTYPE(voxel_spacing)
14081410
)
1409-
torch.cuda.synchronize()
14101411

14111412
ctx.save_for_backward(angles)
14121413
ctx.intermediate = (num_detectors, detector_spacing, Ny, Nx, voxel_spacing)
@@ -1435,12 +1436,13 @@ def backward(ctx, grad_sinogram):
14351436
grid, tpb = _grid_2d(n_angles, num_detectors)
14361437
cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
14371438

1438-
_parallel_2d_backward_kernel[grid, tpb](
1439+
pt_stream = torch.cuda.current_stream()
1440+
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1441+
_parallel_2d_backward_kernel[grid, tpb, numba_stream](
14391442
d_grad_sino, n_angles, num_detectors,
14401443
d_img_grad, Nx, Ny,
14411444
_DTYPE(detector_spacing), d_cos_arr, d_sin_arr, cx, cy, _DTYPE(voxel_spacing)
14421445
)
1443-
torch.cuda.synchronize()
14441446

14451447
return grad_image, None, None, None, None
14461448

@@ -1537,11 +1539,12 @@ def forward(ctx, sinogram, angles, detector_spacing=1.0, H=128, W=128, voxel_spa
15371539
grid, tpb = _grid_2d(n_ang, n_det)
15381540
cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
15391541

1540-
_parallel_2d_backward_kernel[grid, tpb](
1542+
pt_stream = torch.cuda.current_stream()
1543+
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1544+
_parallel_2d_backward_kernel[grid, tpb, numba_stream](
15411545
d_sino, n_ang, n_det, d_reco, Nx, Ny,
15421546
_DTYPE(detector_spacing), d_cos_arr, d_sin_arr, cx, cy, _DTYPE(voxel_spacing)
15431547
)
1544-
torch.cuda.synchronize()
15451548

15461549
ctx.save_for_backward(angles)
15471550
ctx.intermediate = (H, W, detector_spacing, sinogram.shape[0], sinogram.shape[1], voxel_spacing)
@@ -1575,11 +1578,12 @@ def backward(ctx, grad_output):
15751578
grid, tpb = _grid_2d(n_ang, n_det)
15761579
cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
15771580

1578-
_parallel_2d_forward_kernel[grid, tpb](
1581+
pt_stream = torch.cuda.current_stream()
1582+
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1583+
_parallel_2d_forward_kernel[grid, tpb, numba_stream](
15791584
d_grad_out, Nx, Ny, d_sino_grad, n_ang, n_det,
15801585
_DTYPE(detector_spacing), d_cos, d_sin, cx, cy, _DTYPE(voxel_spacing)
15811586
)
1582-
torch.cuda.synchronize()
15831587

15841588
return grad_sino, None, None, None, None, None
15851589

@@ -1676,12 +1680,13 @@ def forward(ctx, image, angles, num_detectors, detector_spacing, sdd, sid, voxel
16761680
grid, tpb = _grid_2d(n_ang, num_detectors)
16771681
cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
16781682

1679-
_fan_2d_forward_kernel[grid, tpb](
1683+
pt_stream = torch.cuda.current_stream()
1684+
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1685+
_fan_2d_forward_kernel[grid, tpb, numba_stream](
16801686
d_image, Nx, Ny, d_sino, n_ang, num_detectors,
16811687
_DTYPE(detector_spacing), d_cos_arr, d_sin_arr,
16821688
_DTYPE(sdd), _DTYPE(sid), cx, cy, _DTYPE(voxel_spacing)
16831689
)
1684-
torch.cuda.synchronize()
16851690

16861691
ctx.save_for_backward(angles)
16871692
ctx.intermediate = (num_detectors, detector_spacing, Ny, Nx,
@@ -1711,12 +1716,13 @@ def backward(ctx, grad_sinogram):
17111716
grid, tpb = _grid_2d(n_ang, n_det)
17121717
cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
17131718

1714-
_fan_2d_backward_kernel[grid, tpb](
1719+
pt_stream = torch.cuda.current_stream()
1720+
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1721+
_fan_2d_backward_kernel[grid, tpb, numba_stream](
17151722
d_grad_sino, n_ang, n_det, d_img_grad, Nx, Ny,
17161723
_DTYPE(det_spacing), d_cos_arr, d_sin_arr,
17171724
_DTYPE(sdd), _DTYPE(sid), cx, cy, _DTYPE(voxel_spacing)
17181725
)
1719-
torch.cuda.synchronize()
17201726

17211727
return grad_img, None, None, None, None, None, None
17221728

@@ -1816,12 +1822,13 @@ def forward(ctx, sinogram, angles, detector_spacing, H, W, sdd, sid, voxel_spaci
18161822
grid, tpb = _grid_2d(n_ang, n_det)
18171823
cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
18181824

1819-
_fan_2d_backward_kernel[grid, tpb](
1825+
pt_stream = torch.cuda.current_stream()
1826+
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1827+
_fan_2d_backward_kernel[grid, tpb, numba_stream](
18201828
d_sino, n_ang, n_det, d_reco, Nx, Ny,
18211829
_DTYPE(detector_spacing), d_cos_arr, d_sin_arr,
18221830
_DTYPE(sdd), _DTYPE(sid), cx, cy, _DTYPE(voxel_spacing)
18231831
)
1824-
torch.cuda.synchronize()
18251832

18261833
ctx.save_for_backward(angles)
18271834
ctx.intermediate = (H, W, detector_spacing, n_ang, n_det, sdd, sid, voxel_spacing)
@@ -1851,12 +1858,13 @@ def backward(ctx, grad_output):
18511858
grid, tpb = _grid_2d(n_ang, n_det)
18521859
cx, cy = _DTYPE(Nx * 0.5), _DTYPE(Ny * 0.5)
18531860

1854-
_fan_2d_forward_kernel[grid, tpb](
1861+
pt_stream = torch.cuda.current_stream()
1862+
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1863+
_fan_2d_forward_kernel[grid, tpb, numba_stream](
18551864
d_grad_out, Nx, Ny, d_sino_grad, n_ang, n_det,
18561865
_DTYPE(det_spacing), d_cos_arr, d_sin_arr,
18571866
_DTYPE(sdd), _DTYPE(sid), cx, cy, _DTYPE(voxel_spacing)
18581867
)
1859-
torch.cuda.synchronize()
18601868

18611869
return grad_sino, None, None, None, None, None, None, None
18621870

@@ -1961,13 +1969,14 @@ def forward(ctx, volume, angles, det_u, det_v, du, dv, sdd, sid, voxel_spacing=1
19611969
grid, tpb = _grid_3d(n_views, det_u, det_v)
19621970
cx, cy, cz = _DTYPE(W * 0.5), _DTYPE(H * 0.5), _DTYPE(D * 0.5)
19631971

1964-
_cone_3d_forward_kernel[grid, tpb](
1972+
pt_stream = torch.cuda.current_stream()
1973+
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
1974+
_cone_3d_forward_kernel[grid, tpb, numba_stream](
19651975
d_vol, W, H, D, d_sino, n_views, det_u, det_v,
19661976
_DTYPE(du), _DTYPE(dv), d_cos_arr, d_sin_arr,
19671977
_DTYPE(sdd), _DTYPE(sid),
19681978
cx, cy, cz, _DTYPE(voxel_spacing)
19691979
)
1970-
torch.cuda.synchronize()
19711980

19721981
ctx.save_for_backward(angles)
19731982
ctx.intermediate = (D, H, W, det_u, det_v, du, dv,
@@ -1999,12 +2008,13 @@ def backward(ctx, grad_sinogram):
19992008
grid, tpb = _grid_3d(n_views, det_u, det_v)
20002009
cx, cy, cz = _DTYPE(W * 0.5), _DTYPE(H * 0.5), _DTYPE(D * 0.5)
20012010

2002-
_cone_3d_backward_kernel[grid, tpb](
2011+
pt_stream = torch.cuda.current_stream()
2012+
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
2013+
_cone_3d_backward_kernel[grid, tpb, numba_stream](
20032014
d_grad_sino, n_views, det_u, det_v, d_vol_grad, W, H, D,
20042015
_DTYPE(du), _DTYPE(dv), d_cos_arr, d_sin_arr,
20052016
_DTYPE(sdd), _DTYPE(sid), cx, cy, cz, _DTYPE(voxel_spacing)
20062017
)
2007-
torch.cuda.synchronize()
20082018

20092019
grad_vol = grad_vol_perm.permute(2, 1, 0).contiguous()
20102020
return grad_vol, None, None, None, None, None, None, None, None
@@ -2119,12 +2129,13 @@ def forward(ctx, sinogram, angles, D, H, W, du, dv, sdd, sid, voxel_spacing=1.0)
21192129
grid, tpb = _grid_3d(n_views, n_u, n_v)
21202130
cx, cy, cz = _DTYPE(W * 0.5), _DTYPE(H * 0.5), _DTYPE(D * 0.5)
21212131

2122-
_cone_3d_backward_kernel[grid, tpb](
2132+
pt_stream = torch.cuda.current_stream()
2133+
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
2134+
_cone_3d_backward_kernel[grid, tpb, numba_stream](
21232135
d_sino, n_views, n_u, n_v, d_reco, W, H, D,
21242136
_DTYPE(du), _DTYPE(dv), d_cos_arr, d_sin_arr,
21252137
_DTYPE(sdd), _DTYPE(sid), cx, cy, cz, _DTYPE(voxel_spacing)
21262138
)
2127-
torch.cuda.synchronize()
21282139

21292140
ctx.save_for_backward(angles)
21302141
ctx.intermediate = (D, H, W, n_u, n_v, du, dv,
@@ -2158,11 +2169,12 @@ def backward(ctx, grad_output):
21582169
grid, tpb = _grid_3d(n_views, n_u, n_v)
21592170
cx, cy, cz = _DTYPE(W * 0.5), _DTYPE(H * 0.5), _DTYPE(D * 0.5)
21602171

2161-
_cone_3d_forward_kernel[grid, tpb](
2172+
pt_stream = torch.cuda.current_stream()
2173+
numba_stream = cuda.external_stream(pt_stream.cuda_stream)
2174+
_cone_3d_forward_kernel[grid, tpb, numba_stream](
21622175
d_grad_out, W, H, D, d_sino_grad, n_views, n_u, n_v,
21632176
_DTYPE(du), _DTYPE(dv), d_cos_arr, d_sin_arr,
21642177
_DTYPE(sdd), _DTYPE(sid), cx, cy, cz, _DTYPE(voxel_spacing)
21652178
)
2166-
torch.cuda.synchronize()
21672179

21682180
return grad_sino, None, None, None, None, None, None, None, None, None

0 commit comments

Comments (0)