Skip to content

Commit 34ed5bf

Browse files
committed
Refactor FBP and FDK implementations to utilize new weighting and filtering functions
- Updated `fbp_fan.py` to replace manual ramp filtering with `ramp_filter_1d` and added cosine weighting using `fan_cosine_weights`. - Integrated Parker weighting option in fan-beam reconstruction. - Refactored `fbp_parallel.py` to use `ramp_filter_1d` and added angular integration weights. - Enhanced `fdk_cone.py` to implement cosine weighting with `cone_cosine_weights` and replaced 3D ramp filtering with `ramp_filter_1d`. - Introduced new tests for CUDA functionality and weight calculations in `test_cuda_smoke.py` and `test_weights.py`. - Added `pytest.ini` for test configuration.
1 parent dce5030 commit 34ed5bf

9 files changed

Lines changed: 919 additions & 180 deletions

File tree

diffct/__init__.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,17 @@
11
# diffct/__init__.py
2-
from .differentiable import ParallelProjectorFunction, ParallelBackprojectorFunction, FanProjectorFunction, FanBackprojectorFunction, ConeProjectorFunction, ConeBackprojectorFunction
2+
from .differentiable import (
3+
ParallelProjectorFunction,
4+
ParallelBackprojectorFunction,
5+
FanProjectorFunction,
6+
FanBackprojectorFunction,
7+
ConeProjectorFunction,
8+
ConeBackprojectorFunction,
9+
detector_coordinates_1d,
10+
angular_integration_weights,
11+
fan_cosine_weights,
12+
cone_cosine_weights,
13+
parker_weights,
14+
ramp_filter_1d,
15+
fan_weighted_backproject,
16+
cone_weighted_backproject,
17+
)

diffct/differentiable.py

Lines changed: 650 additions & 73 deletions
Large diffs are not rendered by default.

examples/fbp_fan.py

Lines changed: 48 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@
33
import torch
44
import matplotlib.pyplot as plt
55
import torch.nn.functional as F
6-
from diffct.differentiable import FanProjectorFunction, FanBackprojectorFunction
6+
from diffct.differentiable import (
7+
FanProjectorFunction,
8+
angular_integration_weights,
9+
fan_cosine_weights,
10+
fan_weighted_backproject,
11+
parker_weights,
12+
ramp_filter_1d,
13+
)
714

815

916
def shepp_logan_2d(Nx, Ny):
@@ -34,19 +41,6 @@ def shepp_logan_2d(Nx, Ny):
3441
phantom = np.clip(phantom, 0.0, 1.0)
3542
return phantom
3643

37-
def ramp_filter(sinogram_tensor):
38-
device = sinogram_tensor.device
39-
num_views, num_det = sinogram_tensor.shape
40-
freqs = torch.fft.fftfreq(num_det, device=device)
41-
omega = 2.0 * torch.pi * freqs
42-
ramp = torch.abs(omega)
43-
ramp_2d = ramp.reshape(1, num_det)
44-
sino_fft = torch.fft.fft(sinogram_tensor, dim=1)
45-
filtered_fft = sino_fft * ramp_2d
46-
filtered = torch.real(torch.fft.ifft(filtered_fft, dim=1))
47-
48-
return filtered
49-
5044
def main():
5145
Nx, Ny = 256, 256
5246
phantom = shepp_logan_2d(Nx, Ny)
@@ -55,45 +49,61 @@ def main():
5549

5650
num_detectors = 600
5751
detector_spacing = 1.0
52+
detector_offset = 0.0
5853
voxel_spacing = 1.0
5954
sdd = 800.0
6055
sid = 500.0
56+
apply_parker = False
6157

6258
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
63-
image_torch = torch.tensor(phantom, device=device, dtype=torch.float32, requires_grad=True)
59+
image_torch = torch.tensor(phantom, device=device, dtype=torch.float32)
6460
angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)
6561

6662
sinogram = FanProjectorFunction.apply(image_torch, angles_torch, num_detectors,
6763
detector_spacing, sdd, sid, voxel_spacing)
6864

6965
# --- FBP weighting and filtering ---
70-
# For fan-beam FBP, projections must be weighted before filtering.
71-
# Weight = cos(gamma), where gamma is the fan angle for each detector.
72-
u = (torch.arange(num_detectors, dtype=image_torch.dtype, device=device) - (num_detectors - 1) / 2) * detector_spacing
73-
gamma = torch.atan(u / sdd)
74-
weights = torch.cos(gamma).unsqueeze(0) # Shape (1, num_detectors) for broadcasting
75-
76-
# Apply weights before filtering
66+
# 1) Optional Parker redundancy weighting for short-scan trajectories
67+
if apply_parker:
68+
parker = parker_weights(angles_torch, num_detectors, detector_spacing, sdd, detector_offset)
69+
sinogram = sinogram * parker
70+
71+
# 2) Fan-beam cosine pre-weighting
72+
weights = fan_cosine_weights(
73+
num_detectors,
74+
detector_spacing,
75+
sdd,
76+
detector_offset=detector_offset,
77+
device=device,
78+
dtype=image_torch.dtype,
79+
).unsqueeze(0)
7780
sino_weighted = sinogram * weights
78-
sinogram_filt = ramp_filter(sino_weighted)
79-
80-
reconstruction = F.relu(FanBackprojectorFunction.apply(sinogram_filt, angles_torch,
81-
detector_spacing, Ny, Nx,
82-
sdd, sid, voxel_spacing)) # ReLU to ensure non-negativity
83-
84-
# --- FBP normalization ---
85-
# The backprojection is a sum over all angles. To approximate the integral,
86-
# we need to multiply by the angular step d_beta.
87-
# The fan-beam FBP formula also includes a factor of 1/2 when integrating over [0, 2*pi].
88-
# d_beta = 2 * pi / num_angles
89-
# Normalization factor = (1/2) * d_beta = pi / num_angles
90-
reconstruction = reconstruction * (math.pi / num_angles)
81+
82+
# 3) Ramp filter along detector axis
83+
sinogram_filt = ramp_filter_1d(sino_weighted, dim=1)
84+
85+
# 4) Angle-integration weights
86+
d_beta = angular_integration_weights(angles_torch, redundant_full_scan=(not apply_parker)).view(-1, 1)
87+
sinogram_filt = sinogram_filt * d_beta
88+
89+
# 5) Weighted fan-beam backprojection
90+
reconstruction = F.relu(
91+
fan_weighted_backproject(
92+
sinogram_filt,
93+
angles_torch,
94+
detector_spacing,
95+
Ny,
96+
Nx,
97+
sdd,
98+
sid,
99+
voxel_spacing=voxel_spacing,
100+
detector_offset=detector_offset,
101+
)
102+
)
91103

92104
loss = torch.mean((reconstruction - image_torch)**2)
93-
loss.backward()
94105

95106
print("Loss:", loss.item())
96-
print("Center pixel gradient:", image_torch.grad[Ny//2, Nx//2].item())
97107

98108
sinogram_cpu = sinogram.detach().cpu().numpy()
99109
reco_cpu = reconstruction.detach().cpu().numpy()
@@ -119,4 +129,4 @@ def main():
119129
print("Reco range:", reco_cpu.min(), reco_cpu.max())
120130

121131
if __name__ == "__main__":
122-
main()
132+
main()

examples/fbp_parallel.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
import torch
33
import matplotlib.pyplot as plt
44
import torch.nn.functional as F
5-
from diffct.differentiable import ParallelProjectorFunction, ParallelBackprojectorFunction
5+
from diffct.differentiable import (
6+
ParallelProjectorFunction,
7+
ParallelBackprojectorFunction,
8+
angular_integration_weights,
9+
ramp_filter_1d,
10+
)
611

712

813
def shepp_logan_2d(Nx, Ny):
@@ -33,19 +38,6 @@ def shepp_logan_2d(Nx, Ny):
3338
phantom = np.clip(phantom, 0.0, 1.0)
3439
return phantom
3540

36-
def ramp_filter(sinogram_tensor):
37-
device = sinogram_tensor.device
38-
num_views, num_det = sinogram_tensor.shape
39-
freqs = torch.fft.fftfreq(num_det, device=device)
40-
omega = 2.0 * torch.pi * freqs
41-
ramp = torch.abs(omega)
42-
ramp_2d = ramp.reshape(1, num_det)
43-
sino_fft = torch.fft.fft(sinogram_tensor, dim=1)
44-
filtered_fft = sino_fft * ramp_2d
45-
filtered = torch.real(torch.fft.ifft(filtered_fft, dim=1))
46-
47-
return filtered
48-
4941
def main():
5042
Nx, Ny = 256, 256
5143
phantom = shepp_logan_2d(Nx, Ny)
@@ -63,18 +55,12 @@ def main():
6355
sinogram = ParallelProjectorFunction.apply(image_torch, angles_torch,
6456
num_detectors, detector_spacing, voxel_spacing)
6557

66-
sinogram_filt = ramp_filter(sinogram)
58+
sinogram_filt = ramp_filter_1d(sinogram, dim=1)
59+
d_theta = angular_integration_weights(angles_torch, redundant_full_scan=True).view(-1, 1)
60+
sinogram_filt = sinogram_filt * d_theta
6761

6862
reconstruction = F.relu(ParallelBackprojectorFunction.apply(sinogram_filt, angles_torch,
6963
detector_spacing, Ny, Nx, voxel_spacing)) # ReLU to ensure non-negativity
70-
71-
# --- FBP normalization ---
72-
# The backprojection is a sum over all angles. To approximate the integral,
73-
# we need to multiply by the angular step d_theta.
74-
# The FBP formula also includes a factor of 1/2 when integrating over [0, 2*pi].
75-
# d_theta = 2 * pi / num_angles
76-
# Normalization factor = (1/2) * d_theta = pi / num_angles
77-
reconstruction = reconstruction * (np.pi / num_angles)
7864

7965
loss = torch.mean((reconstruction - image_torch)**2)
8066
loss.backward()
@@ -107,4 +93,4 @@ def main():
10793
print("Reco range:", reco_cpu.min(), reco_cpu.max())
10894

10995
if __name__ == "__main__":
110-
main()
96+
main()

examples/fdk_cone.py

Lines changed: 48 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import torch
44
import matplotlib.pyplot as plt
55
import torch.nn.functional as F
6-
from diffct.differentiable import ConeProjectorFunction, ConeBackprojectorFunction
6+
from diffct.differentiable import (
7+
ConeProjectorFunction,
8+
angular_integration_weights,
9+
cone_cosine_weights,
10+
cone_weighted_backproject,
11+
ramp_filter_1d,
12+
)
713

814

915
def shepp_logan_3d(shape):
@@ -59,19 +65,6 @@ def shepp_logan_3d(shape):
5965
shepp_logan = np.clip(shepp_logan, 0, 1)
6066
return shepp_logan
6167

62-
def ramp_filter_3d(sinogram_tensor):
63-
device = sinogram_tensor.device
64-
num_views, num_det_u, num_det_v = sinogram_tensor.shape
65-
freqs = torch.fft.fftfreq(num_det_u, device=device)
66-
omega = 2.0 * torch.pi * freqs
67-
ramp = torch.abs(omega)
68-
ramp_3d = ramp.reshape(1, num_det_u, 1)
69-
sino_fft = torch.fft.fft(sinogram_tensor, dim=1)
70-
filtered_fft = sino_fft * ramp_3d
71-
filtered = torch.real(torch.fft.ifft(filtered_fft, dim=1))
72-
73-
return filtered
74-
7568
def main():
7669
Nx, Ny, Nz = 128, 128, 128
7770
phantom_cpu = shepp_logan_3d((Nz, Ny, Nx))
@@ -81,53 +74,65 @@ def main():
8174

8275
det_u, det_v = 256, 256
8376
du, dv = 1.0, 1.0
77+
detector_offset_u = 0.0
78+
detector_offset_v = 0.0
8479
sdd = 900.0
8580
sid = 600.0
8681

8782
voxel_spacing = 1.0
8883

8984
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
90-
phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32, requires_grad=True).contiguous()
85+
phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32).contiguous()
9186
angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)
9287

9388
sinogram = ConeProjectorFunction.apply(phantom_torch, angles_torch,
9489
det_u, det_v, du, dv,
9590
sdd, sid, voxel_spacing)
9691

9792
# --- FDK weighting and filtering ---
98-
# For FDK, projections must be weighted before filtering.
99-
# Weight = D / sqrt(D^2 + u^2 + v^2), where D is source_distance
100-
# and (u,v) are detector coordinates.
101-
u_coords = (torch.arange(det_u, dtype=phantom_torch.dtype, device=device) - (det_u - 1) / 2) * du
102-
v_coords = (torch.arange(det_v, dtype=phantom_torch.dtype, device=device) - (det_v - 1) / 2) * dv
103-
104-
# Reshape for broadcasting over sinogram of shape (views, u, v)
105-
u_coords = u_coords.view(1, det_u, 1)
106-
v_coords = v_coords.view(1, 1, det_v)
107-
108-
weights = sdd / torch.sqrt(sdd**2 + u_coords**2 + v_coords**2)
109-
110-
# Apply weights and then filter
93+
# 1) FDK cosine pre-weighting
94+
weights = cone_cosine_weights(
95+
det_u,
96+
det_v,
97+
du,
98+
dv,
99+
sdd,
100+
detector_offset_u=detector_offset_u,
101+
detector_offset_v=detector_offset_v,
102+
device=device,
103+
dtype=phantom_torch.dtype,
104+
).unsqueeze(0)
111105
sino_weighted = sinogram * weights
112-
sinogram_filt = ramp_filter_3d(sino_weighted).contiguous()
113-
114-
reconstruction = F.relu(ConeBackprojectorFunction.apply(sinogram_filt, angles_torch, Nz, Ny, Nx,
115-
du, dv, sdd, sid, voxel_spacing)) # ReLU to ensure non-negativity
116-
117-
# --- FDK normalization ---
118-
# The backprojection is a sum over all angles. To approximate the integral,
119-
# we need to multiply by the angular step d_beta.
120-
# The FDK formula also includes a factor of 1/2 when integrating over [0, 2*pi].
121-
# d_beta = 2 * pi / num_views
122-
# Normalization factor = (1/2) * d_beta = pi / num_views
123-
reconstruction = reconstruction * (math.pi / num_views)
106+
107+
# 2) Ramp filter along detector-u rows
108+
sinogram_filt = ramp_filter_1d(sino_weighted, dim=1).contiguous()
109+
110+
# 3) Angle-integration weights
111+
d_beta = angular_integration_weights(angles_torch, redundant_full_scan=True).view(-1, 1, 1)
112+
sinogram_filt = sinogram_filt * d_beta
113+
114+
# 4) Weighted cone-beam backprojection
115+
reconstruction = F.relu(
116+
cone_weighted_backproject(
117+
sinogram_filt,
118+
angles_torch,
119+
Nz,
120+
Ny,
121+
Nx,
122+
du,
123+
dv,
124+
sdd,
125+
sid,
126+
voxel_spacing=voxel_spacing,
127+
detector_offset_u=detector_offset_u,
128+
detector_offset_v=detector_offset_v,
129+
)
130+
)
124131

125132
loss = torch.mean((reconstruction - phantom_torch)**2)
126-
loss.backward()
127133

128134
print("Cone Beam Example with user-defined geometry:")
129135
print("Loss:", loss.item())
130-
print("Volume center voxel gradient:", phantom_torch.grad[Nz//2, Ny//2, Nx//2].item())
131136
print("Reconstruction shape:", reconstruction.shape)
132137

133138
reconstruction_cpu = reconstruction.detach().cpu().numpy()
@@ -155,4 +160,4 @@ def main():
155160
print("Reco data range:", reconstruction_cpu.min(), reconstruction_cpu.max())
156161

157162
if __name__ == "__main__":
158-
main()
163+
main()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "diffct"
7-
version = "1.2.7"
7+
version = "1.2.8"
88
description = "A CUDA-based library for computed tomography (CT) projection and reconstruction with differentiable operators"
99
readme = "README.md"
1010
authors = [

pytest.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[pytest]
2+
markers =
3+
cuda: CUDA-dependent smoke tests

0 commit comments

Comments
 (0)