From 28e41fbe8ffa4dedc879fde2715d7eec14fc4b59 Mon Sep 17 00:00:00 2001 From: operel Date: Thu, 27 Apr 2023 13:44:40 +0300 Subject: [PATCH] wire coord grads from kernel Signed-off-by: operel --- wisp/csrc/ops/hashgrid_interpolate.cpp | 4 ++-- wisp/csrc/ops/hashgrid_interpolate.h | 2 +- wisp/ops/grid.py | 10 ++++++---- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/wisp/csrc/ops/hashgrid_interpolate.cpp b/wisp/csrc/ops/hashgrid_interpolate.cpp index c1bb3b0..b69e9f1 100644 --- a/wisp/csrc/ops/hashgrid_interpolate.cpp +++ b/wisp/csrc/ops/hashgrid_interpolate.cpp @@ -65,7 +65,7 @@ at::Tensor hashgrid_interpolate_cuda( #endif // WITH_CUDA } -at::Tensor hashgrid_interpolate_backward_cuda( +std::vector hashgrid_interpolate_backward_cuda( at::Tensor coords, at::Tensor grad_output, at::Tensor codebook, @@ -93,7 +93,7 @@ at::Tensor hashgrid_interpolate_backward_cuda( resolution[i], i, num_lods, require_grad_coords, coords, codebook, codebook_first_idx, grad_output, grad_codebook, grad_coords); } - return grad_codebook; + return {grad_codebook, grad_coords}; #else AT_ERROR(__func__); #endif // WITH_CUDA diff --git a/wisp/csrc/ops/hashgrid_interpolate.h b/wisp/csrc/ops/hashgrid_interpolate.h index ad2b95a..5557477 100644 --- a/wisp/csrc/ops/hashgrid_interpolate.h +++ b/wisp/csrc/ops/hashgrid_interpolate.h @@ -22,7 +22,7 @@ at::Tensor hashgrid_interpolate_cuda( std::vector resolution, int32_t codebook_bitwidth); -at::Tensor hashgrid_interpolate_backward_cuda( +std::vector hashgrid_interpolate_backward_cuda( at::Tensor coords, at::Tensor grad_output, at::Tensor codebook, diff --git a/wisp/ops/grid.py b/wisp/ops/grid.py index 9f11db3..f625c74 100644 --- a/wisp/ops/grid.py +++ b/wisp/ops/grid.py @@ -111,7 +111,6 @@ def forward(ctx, coords, resolutions, codebook_bitwidth, lod_idx, codebook, code @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, grad_output): - coords = ctx.saved_tensors[0] codebook = ctx.saved_tensors[1] codebook_first_idx = ctx.saved_tensors[2] @@ -119,12 +118,15 @@ def backward(ctx, grad_output): feature_dim = ctx.feature_dim codebook_bitwidth = ctx.codebook_bitwidth - grad_codebook = wisp_C.ops.hashgrid_interpolate_backward_cuda( + is_needs_grad_by_coords = ctx.needs_input_grad[0] + grad_codebook, grad_coords = wisp_C.ops.hashgrid_interpolate_backward_cuda( coords.float().contiguous(), grad_output.contiguous(), codebook, codebook_first_idx, resolutions, - codebook_bitwidth, feature_dim, ctx.needs_input_grad[0]) - return (None, None, None, None, grad_codebook, None, None) + codebook_bitwidth, feature_dim, is_needs_grad_by_coords) + if not is_needs_grad_by_coords: + grad_coords = None + return grad_coords, None, None, None, grad_codebook, None, None def hashgrid(coords, resolutions, codebook_bitwidth, lod_idx, codebook, codebook_sizes, codebook_first_idx): """A hash-grid query + interpolation function, accelerated with CUDA.