Skip to content

Commit 9f2541c

Browse files
committed
Remove unnecessary CPU-GPU data transfers and update the examples accordingly.
1 parent 2051622 commit 9f2541c

8 files changed

Lines changed: 287 additions & 390 deletions

File tree

diffct/differentiable.py

Lines changed: 262 additions & 366 deletions
Large diffs are not rendered by default.

examples/fbp_fan.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def main():
5858
isocenter_distance = 500.0
5959

6060
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
61-
image_torch = torch.tensor(phantom, device=device, requires_grad=True)
62-
angles_torch = torch.tensor(angles_np, device=device)
61+
image_torch = torch.tensor(phantom, device=device, dtype=torch.float32, requires_grad=True)
62+
angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)
6363

6464
sinogram = FanProjectorFunction.apply(image_torch, angles_torch, num_detectors,
6565
detector_spacing, source_distance, isocenter_distance)
@@ -73,7 +73,7 @@ def main():
7373

7474
# Apply weights before filtering
7575
sino_weighted = sinogram * weights
76-
sinogram_filt = ramp_filter(sino_weighted).detach().requires_grad_(True).contiguous()
76+
sinogram_filt = ramp_filter(sino_weighted)
7777

7878
reconstruction = FanBackprojectorFunction.apply(sinogram_filt, angles_torch,
7979
detector_spacing, Nx, Ny,

examples/fbp_parallel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ def main():
5555
detector_spacing = 1.0
5656

5757
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
58-
image_torch = torch.tensor(phantom, device=device, requires_grad=True)
59-
angles_torch = torch.tensor(angles_np, device=device, requires_grad=False)
58+
image_torch = torch.tensor(phantom, device=device, dtype=torch.float32, requires_grad=True)
59+
angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)
6060

6161
sinogram = ParallelProjectorFunction.apply(image_torch, angles_torch,
6262
num_detectors, detector_spacing)
6363

64-
sinogram_filt = ramp_filter(sinogram).detach().requires_grad_(True).contiguous()
64+
sinogram_filt = ramp_filter(sinogram)
6565

6666
reconstruction = ParallelBackprojectorFunction.apply(sinogram_filt, angles_torch,
6767
detector_spacing, Nx, Ny)

examples/fdk_cone.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def main():
8484
isocenter_distance = 600.0
8585

8686
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
87-
phantom_torch = torch.tensor(phantom_cpu, device=device, requires_grad=True)
88-
angles_torch = torch.tensor(angles_np, device=device)
87+
phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32, requires_grad=True)
88+
angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)
8989

9090
sinogram = ConeProjectorFunction.apply(phantom_torch, angles_torch,
9191
det_u, det_v, du, dv,
@@ -106,7 +106,7 @@ def main():
106106

107107
# Apply weights and then filter
108108
sino_weighted = sinogram * weights
109-
sinogram_filt = ramp_filter_3d(sino_weighted).detach().requires_grad_(True).contiguous()
109+
sinogram_filt = ramp_filter_3d(sino_weighted)
110110

111111
reconstruction = ConeBackprojectorFunction.apply(sinogram_filt, angles_torch, Nx, Ny, Nz,
112112
du, dv, source_distance, isocenter_distance)

examples/iterative_reco_cone.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -123,20 +123,21 @@ def main():
123123
isocenter_distance = 400.0
124124

125125
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
126-
phantom_torch = torch.tensor(phantom_cpu, device=device)
126+
phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32)
127127

128128
# Generate the "real" sinogram
129-
real_sinogram = ConeProjectorFunction.apply(phantom_torch, angles_np,
129+
angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)
130+
real_sinogram = ConeProjectorFunction.apply(phantom_torch, angles_torch,
130131
det_u, det_v, du, dv,
131132
source_distance, isocenter_distance)
132133

133134
pipeline_instance = Pipeline(lr=1e-1,
134-
volume_shape=(Nz,Ny,Nx),
135-
angles=angles_np,
136-
det_u=det_u, det_v=det_v,
137-
du=du, dv=dv,
138-
source_distance=source_distance,
139-
isocenter_distance=isocenter_distance,
135+
volume_shape=(Nz,Ny,Nx),
136+
angles=angles_torch,
137+
det_u=det_u, det_v=det_v,
138+
du=du, dv=dv,
139+
source_distance=source_distance,
140+
isocenter_distance=isocenter_distance,
140141
device=device, epoches=1000)
141142

142143
ini_guess = torch.zeros_like(phantom_torch)

examples/iterative_reco_fan.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,13 @@ def main():
9393
angles_np = np.linspace(0, 2 * math.pi, num_views, endpoint=False).astype(np.float32)
9494

9595
num_detectors = 256
96-
detector_spacing = 0.75
96+
detector_spacing = 1.0
9797
source_distance = 600.0
9898
isocenter_distance = 400.0
9999

100100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
101-
phantom_torch = torch.tensor(phantom_cpu, device=device)
102-
angles_torch = torch.tensor(angles_np, device=device)
101+
phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32)
102+
angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)
103103

104104
# Generate the "real" sinogram
105105
real_sinogram = FanProjectorFunction.apply(phantom_torch, angles_torch,

examples/iterative_reco_parallel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,11 @@ def main():
8484
angles_np = np.linspace(0, 2 * math.pi, num_views, endpoint=False).astype(np.float32)
8585

8686
num_detectors = 256
87-
detector_spacing = 0.75
87+
detector_spacing = 0.5
8888

8989
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
90-
phantom_torch = torch.tensor(phantom_cpu, device=device)
91-
angles_torch = torch.tensor(angles_np, device=device)
90+
phantom_torch = torch.tensor(phantom_cpu, device=device, dtype=torch.float32)
91+
angles_torch = torch.tensor(angles_np, device=device, dtype=torch.float32)
9292

9393
# Generate the "real" sinogram
9494
real_sinogram = ParallelProjectorFunction.apply(phantom_torch, angles_torch,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "diffct"
7-
version = "1.1.7"
7+
version = "1.2.0"
88
description = "A CUDA-based library for computed tomography (CT) projection and reconstruction with differentiable operators"
99
readme = "README.md"
1010
authors = [

0 commit comments

Comments
 (0)