From a4eb3bcbf007f856a75af0f52d0533c261ae527e Mon Sep 17 00:00:00 2001
From: Shi Yan
Date: Tue, 9 May 2023 21:44:47 +0000
Subject: [PATCH] Fix build issues with newer pytorch

The build issues are due to the deprecation of the THC lib,
This migration is based on the suggestions from
https://github.com/pytorch/pytorch/issues/72807#issuecomment-1039505288

The torch extension infra seems to enforce no_half_operators, this results
in an issue with the atomicAdd function which does accept c10::Half
The workaround is based on this discussion
https://forums.developer.nvidia.com/t/atomicadd-not-overloaded-for-c10-half/

---
 .../dcn/src/cuda/deform_pool_cuda_kernel.cu | 22 ++++++++++++++++-
 vedadet/ops/nms/src/cuda/nms_kernel.cu | 24 +++++++++++--------
 .../src/cuda/sigmoid_focal_loss_cuda.cu | 18 +++++++-------
 3 files changed, 45 insertions(+), 19 deletions(-)

diff --git a/vedadet/ops/dcn/src/cuda/deform_pool_cuda_kernel.cu b/vedadet/ops/dcn/src/cuda/deform_pool_cuda_kernel.cu
index 18e3a04..c0f01a1 100644
--- a/vedadet/ops/dcn/src/cuda/deform_pool_cuda_kernel.cu
+++ b/vedadet/ops/dcn/src/cuda/deform_pool_cuda_kernel.cu
@@ -9,13 +9,33 @@
 // modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/cuda/deform_psroi_pooling_cuda.cu
 #include <ATen/ATen.h>
-#include <THC/THCAtomics.cuh>
+#include <ATen/ceil_div.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAException.h>
 #include <stdio.h>
 #include <math.h>
 #include <algorithm>
 
 using namespace at;
 
+__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
+    unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
+    unsigned int old = *address_as_ui;
+    unsigned int assumed;
+
+    do {
+        assumed = old;
+        unsigned short hsum = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
+        hsum += val;
+        old = reinterpret_cast<size_t>(address) & 2
+                  ? (old & 0xffff) | (hsum << 16)
+                  : (old & 0xffff0000) | hsum;
+        old = atomicCAS(address_as_ui, assumed, old);
+
+    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
+    } while (assumed != old);
+}
+
 #define CUDA_KERNEL_LOOP(i, n)                          \
   for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \
        i < (n);                                         \
diff --git a/vedadet/ops/nms/src/cuda/nms_kernel.cu b/vedadet/ops/nms/src/cuda/nms_kernel.cu
index bb6d18a..58706c3 100644
--- a/vedadet/ops/nms/src/cuda/nms_kernel.cu
+++ b/vedadet/ops/nms/src/cuda/nms_kernel.cu
@@ -3,8 +3,12 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
-#include <THC/THC.h>
-#include <THC/THCDeviceUtils.cuh>
+//#include <THC/THC.h>
+//#include <THC/THCDeviceUtils.cuh>
+
+#include <ATen/ceil_div.h>
+#include <c10/cuda/CUDACachingAllocator.h>
+
 #include <vector>
 #include <iostream>
 
@@ -62,7 +66,7 @@ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
           t |= 1ULL << i;
         }
       }
-      const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
+      const int col_blocks = at::ceil_div(n_boxes, threadsPerBlock);
       dev_mask[cur_box_idx * col_blocks + col_start] = t;
     }
   }
@@ -81,20 +85,20 @@ at::Tensor nms_cuda_forward(const at::Tensor boxes, float nms_overlap_thresh) {
 
   int boxes_num = boxes.size(0);
 
-  const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
+  const int col_blocks = at::ceil_div(boxes_num, threadsPerBlock);
 
   scalar_t* boxes_dev = boxes_sorted.data_ptr<scalar_t>();
 
-  THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
+  //THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
 
   unsigned long long* mask_dev = NULL;
   //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
   //                      boxes_num * col_blocks * sizeof(unsigned long long)));
 
-  mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
+  mask_dev = (unsigned long long*) c10::cuda::CUDACachingAllocator::raw_alloc(boxes_num * col_blocks * sizeof(unsigned long long));
 
-  dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
-              THCCeilDiv(boxes_num, threadsPerBlock));
+  dim3 blocks(at::ceil_div(boxes_num, threadsPerBlock),
+
              at::ceil_div(boxes_num, threadsPerBlock));
   dim3 threads(threadsPerBlock);
   nms_kernel<<<blocks, threads>>>(boxes_num,
                                   nms_overlap_thresh,
@@ -102,7 +106,7 @@ at::Tensor nms_cuda_forward(const at::Tensor boxes, float nms_overlap_thresh) {
                                   mask_dev);
 
   std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
-  THCudaCheck(cudaMemcpyAsync(
+  C10_CUDA_CHECK(cudaMemcpyAsync(
       &mask_host[0],
       mask_dev,
       sizeof(unsigned long long) * boxes_num * col_blocks,
@@ -130,7 +134,7 @@ at::Tensor nms_cuda_forward(const at::Tensor boxes, float nms_overlap_thresh) {
     }
   }
 
-  THCudaFree(state, mask_dev);
+  c10::cuda::CUDACachingAllocator::raw_delete( mask_dev);
   // TODO improve this part
   return order_t.index({
       keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
diff --git a/vedadet/ops/sigmoid_focal_loss/src/cuda/sigmoid_focal_loss_cuda.cu b/vedadet/ops/sigmoid_focal_loss/src/cuda/sigmoid_focal_loss_cuda.cu
index 012d01c..1e31b0f 100644
--- a/vedadet/ops/sigmoid_focal_loss/src/cuda/sigmoid_focal_loss_cuda.cu
+++ b/vedadet/ops/sigmoid_focal_loss/src/cuda/sigmoid_focal_loss_cuda.cu
@@ -9,9 +9,11 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
-#include <THC/THC.h>
+/*#include <THC/THC.h>
 #include <THC/THCAtomics.cuh>
-#include <THC/THCDeviceUtils.cuh>
+#include <THC/THCDeviceUtils.cuh> */
+#include <ATen/ceil_div.h>
+#include <c10/cuda/CUDAException.h>
 
 #include <cfloat>
 
@@ -112,11 +114,11 @@ at::Tensor SigmoidFocalLoss_forward_cuda(const at::Tensor &logits,
   auto losses_size = num_samples * logits.size(1);
 
   dim3 grid(
-      std::min(THCCeilDiv((int64_t)losses_size, (int64_t)512), (int64_t)4096));
+      std::min(at::ceil_div((int64_t)losses_size, (int64_t)512), (int64_t)4096));
   dim3 block(512);
 
   if (losses.numel() == 0) {
-    THCudaCheck(cudaGetLastError());
+    C10_CUDA_CHECK(cudaGetLastError());
     return losses;
   }
 
@@ -128,7 +130,7 @@ at::Tensor SigmoidFocalLoss_forward_cuda(const at::Tensor &logits,
         targets.contiguous().data_ptr<int64_t>(), num_classes, gamma, alpha,
         num_samples, losses.data_ptr<scalar_t>());
   });
-  THCudaCheck(cudaGetLastError());
+  C10_CUDA_CHECK(cudaGetLastError());
   return losses;
 }
 
@@ -151,12 +153,12 @@ at::Tensor SigmoidFocalLoss_backward_cuda(const at::Tensor &logits,
   auto d_logits =
       at::zeros({num_samples, num_classes}, logits.options());
   auto d_logits_size = num_samples * logits.size(1);
 
-  dim3 grid(std::min(THCCeilDiv((int64_t)d_logits_size, (int64_t)512),
+  dim3 grid(std::min(at::ceil_div((int64_t)d_logits_size, (int64_t)512),
                      (int64_t)4096));
   dim3 block(512);
 
   if (d_logits.numel() == 0) {
-    THCudaCheck(cudaGetLastError());
+    C10_CUDA_CHECK(cudaGetLastError());
     return d_logits;
   }
 
@@ -170,6 +172,6 @@ at::Tensor SigmoidFocalLoss_backward_cuda(const at::Tensor &logits,
         alpha, num_samples, d_logits.data_ptr<scalar_t>());
   });
 
-  THCudaCheck(cudaGetLastError());
+  C10_CUDA_CHECK(cudaGetLastError());
   return d_logits;
 }