Skip to content

Commit 49ac06f

Browse files
committed
Refactor fastmath decorators for backward kernels to ensure gradient correctness
1 parent 1f3e8bf commit 49ac06f

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

diffct/differentiable.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
# Trades numerical precision for performance in ray-tracing calculations
2121
# Safe for CT reconstruction where slight precision loss is acceptable for speed gains
2222
_FASTMATH_DECORATOR = cuda.jit(cache=True, fastmath=True)
23-
# Disable fastmath for backward kernels to ensure gradient correctness
24-
_NON_FASTMATH_DECORATOR = cuda.jit(cache=True, fastmath=False)
23+
2524
_INF = _DTYPE(np.inf)
2625
_EPSILON = _DTYPE(1e-6)
2726
# === Device Management Utilities ===
@@ -485,7 +484,7 @@ def _parallel_2d_forward_kernel(
485484

486485
d_sino[iang, idet] = accum
487486

488-
@_NON_FASTMATH_DECORATOR
487+
@_FASTMATH_DECORATOR
489488
def _parallel_2d_backward_kernel(
490489
d_sino, n_ang, n_det,
491490
d_image, Nx, Ny,
@@ -765,7 +764,7 @@ def _fan_2d_forward_kernel(
765764

766765
d_sino[iang, idet] = accum
767766

768-
@_NON_FASTMATH_DECORATOR
767+
@_FASTMATH_DECORATOR
769768
def _fan_2d_backward_kernel(
770769
d_sino, n_ang, n_det,
771770
d_image, Nx, Ny,
@@ -1115,7 +1114,7 @@ def _cone_3d_forward_kernel(
11151114

11161115
d_sino[iview, iu, iv] = accum
11171116

1118-
@_NON_FASTMATH_DECORATOR
1117+
@_FASTMATH_DECORATOR
11191118
def _cone_3d_backward_kernel(
11201119
d_sino, n_views, n_u, n_v,
11211120
d_vol, Nx, Ny, Nz,

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,4 @@ where = ["."]
3636
[tool.hatch.envs.default]
3737
python = "python"
3838

39-
[tool.hatch.envs.default.env-vars]
40-
PYTHONDONTWRITEBYTECODE = "1"
39+
[tool.hatch.envs.default.env-vars]

0 commit comments

Comments
 (0)