@@ -1402,11 +1402,12 @@ def forward(ctx, image, angles, num_detectors, detector_spacing=1.0, voxel_spaci
14021402 grid , tpb = _grid_2d (n_angles , num_detectors )
14031403 cx , cy = _DTYPE (Nx * 0.5 ), _DTYPE (Ny * 0.5 )
14041404
1405- _parallel_2d_forward_kernel [grid , tpb ](
1405+ pt_stream = torch .cuda .current_stream ()
1406+ numba_stream = cuda .external_stream (pt_stream .cuda_stream )
1407+ _parallel_2d_forward_kernel [grid , tpb , numba_stream ](
14061408 d_image , Nx , Ny , d_sino , n_angles , num_detectors ,
14071409 _DTYPE (detector_spacing ), d_cos_arr , d_sin_arr , cx , cy , _DTYPE (voxel_spacing )
14081410 )
1409- torch .cuda .synchronize ()
14101411
14111412 ctx .save_for_backward (angles )
14121413 ctx .intermediate = (num_detectors , detector_spacing , Ny , Nx , voxel_spacing )
@@ -1435,12 +1436,13 @@ def backward(ctx, grad_sinogram):
14351436 grid , tpb = _grid_2d (n_angles , num_detectors )
14361437 cx , cy = _DTYPE (Nx * 0.5 ), _DTYPE (Ny * 0.5 )
14371438
1438- _parallel_2d_backward_kernel [grid , tpb ](
1439+ pt_stream = torch .cuda .current_stream ()
1440+ numba_stream = cuda .external_stream (pt_stream .cuda_stream )
1441+ _parallel_2d_backward_kernel [grid , tpb , numba_stream ](
14391442 d_grad_sino , n_angles , num_detectors ,
14401443 d_img_grad , Nx , Ny ,
14411444 _DTYPE (detector_spacing ), d_cos_arr , d_sin_arr , cx , cy , _DTYPE (voxel_spacing )
14421445 )
1443- torch .cuda .synchronize ()
14441446
14451447 return grad_image , None , None , None , None
14461448
@@ -1537,11 +1539,12 @@ def forward(ctx, sinogram, angles, detector_spacing=1.0, H=128, W=128, voxel_spa
15371539 grid , tpb = _grid_2d (n_ang , n_det )
15381540 cx , cy = _DTYPE (Nx * 0.5 ), _DTYPE (Ny * 0.5 )
15391541
1540- _parallel_2d_backward_kernel [grid , tpb ](
1542+ pt_stream = torch .cuda .current_stream ()
1543+ numba_stream = cuda .external_stream (pt_stream .cuda_stream )
1544+ _parallel_2d_backward_kernel [grid , tpb , numba_stream ](
15411545 d_sino , n_ang , n_det , d_reco , Nx , Ny ,
15421546 _DTYPE (detector_spacing ), d_cos_arr , d_sin_arr , cx , cy , _DTYPE (voxel_spacing )
15431547 )
1544- torch .cuda .synchronize ()
15451548
15461549 ctx .save_for_backward (angles )
15471550 ctx .intermediate = (H , W , detector_spacing , sinogram .shape [0 ], sinogram .shape [1 ], voxel_spacing )
@@ -1575,11 +1578,12 @@ def backward(ctx, grad_output):
15751578 grid , tpb = _grid_2d (n_ang , n_det )
15761579 cx , cy = _DTYPE (Nx * 0.5 ), _DTYPE (Ny * 0.5 )
15771580
1578- _parallel_2d_forward_kernel [grid , tpb ](
1581+ pt_stream = torch .cuda .current_stream ()
1582+ numba_stream = cuda .external_stream (pt_stream .cuda_stream )
1583+ _parallel_2d_forward_kernel [grid , tpb , numba_stream ](
15791584 d_grad_out , Nx , Ny , d_sino_grad , n_ang , n_det ,
15801585 _DTYPE (detector_spacing ), d_cos , d_sin , cx , cy , _DTYPE (voxel_spacing )
15811586 )
1582- torch .cuda .synchronize ()
15831587
15841588 return grad_sino , None , None , None , None , None
15851589
@@ -1676,12 +1680,13 @@ def forward(ctx, image, angles, num_detectors, detector_spacing, sdd, sid, voxel
16761680 grid , tpb = _grid_2d (n_ang , num_detectors )
16771681 cx , cy = _DTYPE (Nx * 0.5 ), _DTYPE (Ny * 0.5 )
16781682
1679- _fan_2d_forward_kernel [grid , tpb ](
1683+ pt_stream = torch .cuda .current_stream ()
1684+ numba_stream = cuda .external_stream (pt_stream .cuda_stream )
1685+ _fan_2d_forward_kernel [grid , tpb , numba_stream ](
16801686 d_image , Nx , Ny , d_sino , n_ang , num_detectors ,
16811687 _DTYPE (detector_spacing ), d_cos_arr , d_sin_arr ,
16821688 _DTYPE (sdd ), _DTYPE (sid ), cx , cy , _DTYPE (voxel_spacing )
16831689 )
1684- torch .cuda .synchronize ()
16851690
16861691 ctx .save_for_backward (angles )
16871692 ctx .intermediate = (num_detectors , detector_spacing , Ny , Nx ,
@@ -1711,12 +1716,13 @@ def backward(ctx, grad_sinogram):
17111716 grid , tpb = _grid_2d (n_ang , n_det )
17121717 cx , cy = _DTYPE (Nx * 0.5 ), _DTYPE (Ny * 0.5 )
17131718
1714- _fan_2d_backward_kernel [grid , tpb ](
1719+ pt_stream = torch .cuda .current_stream ()
1720+ numba_stream = cuda .external_stream (pt_stream .cuda_stream )
1721+ _fan_2d_backward_kernel [grid , tpb , numba_stream ](
17151722 d_grad_sino , n_ang , n_det , d_img_grad , Nx , Ny ,
17161723 _DTYPE (det_spacing ), d_cos_arr , d_sin_arr ,
17171724 _DTYPE (sdd ), _DTYPE (sid ), cx , cy , _DTYPE (voxel_spacing )
17181725 )
1719- torch .cuda .synchronize ()
17201726
17211727 return grad_img , None , None , None , None , None , None
17221728
@@ -1816,12 +1822,13 @@ def forward(ctx, sinogram, angles, detector_spacing, H, W, sdd, sid, voxel_spaci
18161822 grid , tpb = _grid_2d (n_ang , n_det )
18171823 cx , cy = _DTYPE (Nx * 0.5 ), _DTYPE (Ny * 0.5 )
18181824
1819- _fan_2d_backward_kernel [grid , tpb ](
1825+ pt_stream = torch .cuda .current_stream ()
1826+ numba_stream = cuda .external_stream (pt_stream .cuda_stream )
1827+ _fan_2d_backward_kernel [grid , tpb , numba_stream ](
18201828 d_sino , n_ang , n_det , d_reco , Nx , Ny ,
18211829 _DTYPE (detector_spacing ), d_cos_arr , d_sin_arr ,
18221830 _DTYPE (sdd ), _DTYPE (sid ), cx , cy , _DTYPE (voxel_spacing )
18231831 )
1824- torch .cuda .synchronize ()
18251832
18261833 ctx .save_for_backward (angles )
18271834 ctx .intermediate = (H , W , detector_spacing , n_ang , n_det , sdd , sid , voxel_spacing )
@@ -1851,12 +1858,13 @@ def backward(ctx, grad_output):
18511858 grid , tpb = _grid_2d (n_ang , n_det )
18521859 cx , cy = _DTYPE (Nx * 0.5 ), _DTYPE (Ny * 0.5 )
18531860
1854- _fan_2d_forward_kernel [grid , tpb ](
1861+ pt_stream = torch .cuda .current_stream ()
1862+ numba_stream = cuda .external_stream (pt_stream .cuda_stream )
1863+ _fan_2d_forward_kernel [grid , tpb , numba_stream ](
18551864 d_grad_out , Nx , Ny , d_sino_grad , n_ang , n_det ,
18561865 _DTYPE (det_spacing ), d_cos_arr , d_sin_arr ,
18571866 _DTYPE (sdd ), _DTYPE (sid ), cx , cy , _DTYPE (voxel_spacing )
18581867 )
1859- torch .cuda .synchronize ()
18601868
18611869 return grad_sino , None , None , None , None , None , None , None
18621870
@@ -1961,13 +1969,14 @@ def forward(ctx, volume, angles, det_u, det_v, du, dv, sdd, sid, voxel_spacing=1
19611969 grid , tpb = _grid_3d (n_views , det_u , det_v )
19621970 cx , cy , cz = _DTYPE (W * 0.5 ), _DTYPE (H * 0.5 ), _DTYPE (D * 0.5 )
19631971
1964- _cone_3d_forward_kernel [grid , tpb ](
1972+ pt_stream = torch .cuda .current_stream ()
1973+ numba_stream = cuda .external_stream (pt_stream .cuda_stream )
1974+ _cone_3d_forward_kernel [grid , tpb , numba_stream ](
19651975 d_vol , W , H , D , d_sino , n_views , det_u , det_v ,
19661976 _DTYPE (du ), _DTYPE (dv ), d_cos_arr , d_sin_arr ,
19671977 _DTYPE (sdd ), _DTYPE (sid ),
19681978 cx , cy , cz , _DTYPE (voxel_spacing )
19691979 )
1970- torch .cuda .synchronize ()
19711980
19721981 ctx .save_for_backward (angles )
19731982 ctx .intermediate = (D , H , W , det_u , det_v , du , dv ,
@@ -1999,12 +2008,13 @@ def backward(ctx, grad_sinogram):
19992008 grid , tpb = _grid_3d (n_views , det_u , det_v )
20002009 cx , cy , cz = _DTYPE (W * 0.5 ), _DTYPE (H * 0.5 ), _DTYPE (D * 0.5 )
20012010
2002- _cone_3d_backward_kernel [grid , tpb ](
2011+ pt_stream = torch .cuda .current_stream ()
2012+ numba_stream = cuda .external_stream (pt_stream .cuda_stream )
2013+ _cone_3d_backward_kernel [grid , tpb , numba_stream ](
20032014 d_grad_sino , n_views , det_u , det_v , d_vol_grad , W , H , D ,
20042015 _DTYPE (du ), _DTYPE (dv ), d_cos_arr , d_sin_arr ,
20052016 _DTYPE (sdd ), _DTYPE (sid ), cx , cy , cz , _DTYPE (voxel_spacing )
20062017 )
2007- torch .cuda .synchronize ()
20082018
20092019 grad_vol = grad_vol_perm .permute (2 , 1 , 0 ).contiguous ()
20102020 return grad_vol , None , None , None , None , None , None , None , None
@@ -2119,12 +2129,13 @@ def forward(ctx, sinogram, angles, D, H, W, du, dv, sdd, sid, voxel_spacing=1.0)
21192129 grid , tpb = _grid_3d (n_views , n_u , n_v )
21202130 cx , cy , cz = _DTYPE (W * 0.5 ), _DTYPE (H * 0.5 ), _DTYPE (D * 0.5 )
21212131
2122- _cone_3d_backward_kernel [grid , tpb ](
2132+ pt_stream = torch .cuda .current_stream ()
2133+ numba_stream = cuda .external_stream (pt_stream .cuda_stream )
2134+ _cone_3d_backward_kernel [grid , tpb , numba_stream ](
21232135 d_sino , n_views , n_u , n_v , d_reco , W , H , D ,
21242136 _DTYPE (du ), _DTYPE (dv ), d_cos_arr , d_sin_arr ,
21252137 _DTYPE (sdd ), _DTYPE (sid ), cx , cy , cz , _DTYPE (voxel_spacing )
21262138 )
2127- torch .cuda .synchronize ()
21282139
21292140 ctx .save_for_backward (angles )
21302141 ctx .intermediate = (D , H , W , n_u , n_v , du , dv ,
@@ -2158,11 +2169,12 @@ def backward(ctx, grad_output):
21582169 grid , tpb = _grid_3d (n_views , n_u , n_v )
21592170 cx , cy , cz = _DTYPE (W * 0.5 ), _DTYPE (H * 0.5 ), _DTYPE (D * 0.5 )
21602171
2161- _cone_3d_forward_kernel [grid , tpb ](
2172+ pt_stream = torch .cuda .current_stream ()
2173+ numba_stream = cuda .external_stream (pt_stream .cuda_stream )
2174+ _cone_3d_forward_kernel [grid , tpb , numba_stream ](
21622175 d_grad_out , W , H , D , d_sino_grad , n_views , n_u , n_v ,
21632176 _DTYPE (du ), _DTYPE (dv ), d_cos_arr , d_sin_arr ,
21642177 _DTYPE (sdd ), _DTYPE (sid ), cx , cy , cz , _DTYPE (voxel_spacing )
21652178 )
2166- torch .cuda .synchronize ()
21672179
21682180 return grad_sino , None , None , None , None , None , None , None , None , None
0 commit comments