33import torch
44import matplotlib .pyplot as plt
55import torch .nn .functional as F
6- from diffct .differentiable import ConeProjectorFunction , ConeBackprojectorFunction
6+ from diffct .differentiable import (
7+ ConeProjectorFunction ,
8+ angular_integration_weights ,
9+ cone_cosine_weights ,
10+ cone_weighted_backproject ,
11+ ramp_filter_1d ,
12+ )
713
814
915def shepp_logan_3d (shape ):
@@ -59,19 +65,6 @@ def shepp_logan_3d(shape):
5965 shepp_logan = np .clip (shepp_logan , 0 , 1 )
6066 return shepp_logan
6167
62- def ramp_filter_3d (sinogram_tensor ):
63- device = sinogram_tensor .device
64- num_views , num_det_u , num_det_v = sinogram_tensor .shape
65- freqs = torch .fft .fftfreq (num_det_u , device = device )
66- omega = 2.0 * torch .pi * freqs
67- ramp = torch .abs (omega )
68- ramp_3d = ramp .reshape (1 , num_det_u , 1 )
69- sino_fft = torch .fft .fft (sinogram_tensor , dim = 1 )
70- filtered_fft = sino_fft * ramp_3d
71- filtered = torch .real (torch .fft .ifft (filtered_fft , dim = 1 ))
72-
73- return filtered
74-
7568def main ():
7669 Nx , Ny , Nz = 128 , 128 , 128
7770 phantom_cpu = shepp_logan_3d ((Nz , Ny , Nx ))
@@ -81,53 +74,65 @@ def main():
8174
8275 det_u , det_v = 256 , 256
8376 du , dv = 1.0 , 1.0
77+ detector_offset_u = 0.0
78+ detector_offset_v = 0.0
8479 sdd = 900.0
8580 sid = 600.0
8681
8782 voxel_spacing = 1.0
8883
8984 device = torch .device ('cuda' if torch .cuda .is_available () else 'cpu' )
90- phantom_torch = torch .tensor (phantom_cpu , device = device , dtype = torch .float32 , requires_grad = True ).contiguous ()
85+ phantom_torch = torch .tensor (phantom_cpu , device = device , dtype = torch .float32 ).contiguous ()
9186 angles_torch = torch .tensor (angles_np , device = device , dtype = torch .float32 )
9287
9388 sinogram = ConeProjectorFunction .apply (phantom_torch , angles_torch ,
9489 det_u , det_v , du , dv ,
9590 sdd , sid , voxel_spacing )
9691
9792 # --- FDK weighting and filtering ---
98- # For FDK, projections must be weighted before filtering.
99- # Weight = D / sqrt(D^2 + u^2 + v^2), where D is source_distance
100- # and (u,v) are detector coordinates.
101- u_coords = (torch .arange (det_u , dtype = phantom_torch .dtype , device = device ) - (det_u - 1 ) / 2 ) * du
102- v_coords = (torch .arange (det_v , dtype = phantom_torch .dtype , device = device ) - (det_v - 1 ) / 2 ) * dv
103-
104- # Reshape for broadcasting over sinogram of shape (views, u, v)
105- u_coords = u_coords .view (1 , det_u , 1 )
106- v_coords = v_coords .view (1 , 1 , det_v )
107-
108- weights = sdd / torch .sqrt (sdd ** 2 + u_coords ** 2 + v_coords ** 2 )
109-
110- # Apply weights and then filter
93+ # 1) FDK cosine pre-weighting
94+ weights = cone_cosine_weights (
95+ det_u ,
96+ det_v ,
97+ du ,
98+ dv ,
99+ sdd ,
100+ detector_offset_u = detector_offset_u ,
101+ detector_offset_v = detector_offset_v ,
102+ device = device ,
103+ dtype = phantom_torch .dtype ,
104+ ).unsqueeze (0 )
111105 sino_weighted = sinogram * weights
112- sinogram_filt = ramp_filter_3d (sino_weighted ).contiguous ()
113-
114- reconstruction = F .relu (ConeBackprojectorFunction .apply (sinogram_filt , angles_torch , Nz , Ny , Nx ,
115- du , dv , sdd , sid , voxel_spacing )) # ReLU to ensure non-negativity
116-
117- # --- FDK normalization ---
118- # The backprojection is a sum over all angles. To approximate the integral,
119- # we need to multiply by the angular step d_beta.
120- # The FDK formula also includes a factor of 1/2 when integrating over [0, 2*pi].
121- # d_beta = 2 * pi / num_views
122- # Normalization factor = (1/2) * d_beta = pi / num_views
123- reconstruction = reconstruction * (math .pi / num_views )
106+
107+ # 2) Ramp filter along detector-u rows
108+ sinogram_filt = ramp_filter_1d (sino_weighted , dim = 1 ).contiguous ()
109+
110+ # 3) Angle-integration weights
111+ d_beta = angular_integration_weights (angles_torch , redundant_full_scan = True ).view (- 1 , 1 , 1 )
112+ sinogram_filt = sinogram_filt * d_beta
113+
114+ # 4) Weighted cone-beam backprojection
115+ reconstruction = F .relu (
116+ cone_weighted_backproject (
117+ sinogram_filt ,
118+ angles_torch ,
119+ Nz ,
120+ Ny ,
121+ Nx ,
122+ du ,
123+ dv ,
124+ sdd ,
125+ sid ,
126+ voxel_spacing = voxel_spacing ,
127+ detector_offset_u = detector_offset_u ,
128+ detector_offset_v = detector_offset_v ,
129+ )
130+ )
124131
125132 loss = torch .mean ((reconstruction - phantom_torch )** 2 )
126- loss .backward ()
127133
128134 print ("Cone Beam Example with user-defined geometry:" )
129135 print ("Loss:" , loss .item ())
130- print ("Volume center voxel gradient:" , phantom_torch .grad [Nz // 2 , Ny // 2 , Nx // 2 ].item ())
131136 print ("Reconstruction shape:" , reconstruction .shape )
132137
133138 reconstruction_cpu = reconstruction .detach ().cpu ().numpy ()
@@ -155,4 +160,4 @@ def main():
155160 print ("Reco data range:" , reconstruction_cpu .min (), reconstruction_cpu .max ())
156161
157162if __name__ == "__main__" :
158- main ()
163+ main ()
0 commit comments