Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion vedadet/ops/dcn/src/cuda/deform_pool_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,33 @@
// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/cuda/deform_psroi_pooling_cuda.cu

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/util/Half.h>
#include <stdio.h>
#include <math.h>
#include <algorithm>

using namespace at;

// Software atomicAdd for c10::Half, for GPUs/toolchains without a native
// half-precision atomicAdd. The 16-bit half lives in the low or high lane of
// the enclosing 4-byte-aligned word; we CAS the whole 32-bit word until our
// lane's update lands without interference.
__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
  // Round the address down to the containing 4-byte-aligned word.
  unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  do {
    assumed = old;
    // Extract the 16 bits holding our half (high or low lane of the word)
    // and reinterpret them as a half via the raw bit field. Adding `val`
    // directly to the raw unsigned short would treat the bit pattern as an
    // integer value and produce garbage sums.
    c10::Half hsum;
    hsum.x = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
    // Perform the addition in float (c10::Half converts implicitly), then
    // narrow back to half so hsum.x carries the updated bit pattern.
    hsum = static_cast<c10::Half>(static_cast<float>(hsum) + static_cast<float>(val));
    // Splice the updated 16 bits back into the correct lane of the word.
    old = reinterpret_cast<size_t>(address) & 2
        ? (old & 0xffff) | (static_cast<unsigned int>(hsum.x) << 16)
        : (old & 0xffff0000) | hsum.x;
    old = atomicCAS(address_as_ui, assumed, old);

    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);
}

#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
Expand Down
24 changes: 14 additions & 10 deletions vedadet/ops/nms/src/cuda/nms_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/DeviceGuard.h>

#include <THC/THC.h>
#include <THC/THCDeviceUtils.cuh>
//#include <THC/THC.h>
//#include <THC/THCDeviceUtils.cuh>

#include <ATen/cuda/ThrustAllocator.h>
#include <ATen/ceil_div.h>


#include <vector>
#include <iostream>
Expand Down Expand Up @@ -62,7 +66,7 @@ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
t |= 1ULL << i;
}
}
const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
const int col_blocks = at::ceil_div(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
Expand All @@ -81,28 +85,28 @@ at::Tensor nms_cuda_forward(const at::Tensor boxes, float nms_overlap_thresh) {

int boxes_num = boxes.size(0);

const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
const int col_blocks = at::ceil_div(boxes_num, threadsPerBlock);

scalar_t* boxes_dev = boxes_sorted.data_ptr<scalar_t>();

THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
//THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState

unsigned long long* mask_dev = NULL;
//THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
// boxes_num * col_blocks * sizeof(unsigned long long)));

mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
mask_dev = (unsigned long long*) c10::cuda::CUDACachingAllocator::raw_alloc(boxes_num * col_blocks * sizeof(unsigned long long));

dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
THCCeilDiv(boxes_num, threadsPerBlock));
dim3 blocks(at::ceil_div(boxes_num, threadsPerBlock),
at::ceil_div(boxes_num, threadsPerBlock));
dim3 threads(threadsPerBlock);
nms_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(boxes_num,
nms_overlap_thresh,
boxes_dev,
mask_dev);

std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
THCudaCheck(cudaMemcpyAsync(
C10_CUDA_CHECK(cudaMemcpyAsync(
&mask_host[0],
mask_dev,
sizeof(unsigned long long) * boxes_num * col_blocks,
Expand Down Expand Up @@ -130,7 +134,7 @@ at::Tensor nms_cuda_forward(const at::Tensor boxes, float nms_overlap_thresh) {
}
}

THCudaFree(state, mask_dev);
c10::cuda::CUDACachingAllocator::raw_delete( mask_dev);
// TODO improve this part
return order_t.index({
keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
Expand Down
18 changes: 10 additions & 8 deletions vedadet/ops/sigmoid_focal_loss/src/cuda/sigmoid_focal_loss_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THC.h>
/*#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>
#include <THC/THCDeviceUtils.cuh>*/
#include <ATen/cuda/ThrustAllocator.h>
#include <ATen/ceil_div.h>

#include <cfloat>

Expand Down Expand Up @@ -112,11 +114,11 @@ at::Tensor SigmoidFocalLoss_forward_cuda(const at::Tensor &logits,
auto losses_size = num_samples * logits.size(1);

dim3 grid(
std::min(THCCeilDiv((int64_t)losses_size, (int64_t)512), (int64_t)4096));
std::min(at::ceil_div((int64_t)losses_size, (int64_t)512), (int64_t)4096));
dim3 block(512);

if (losses.numel() == 0) {
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
return losses;
}

Expand All @@ -128,7 +130,7 @@ at::Tensor SigmoidFocalLoss_forward_cuda(const at::Tensor &logits,
targets.contiguous().data_ptr<int64_t>(), num_classes, gamma,
alpha, num_samples, losses.data_ptr<scalar_t>());
});
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
return losses;
}

Expand All @@ -151,12 +153,12 @@ at::Tensor SigmoidFocalLoss_backward_cuda(const at::Tensor &logits,
auto d_logits = at::zeros({num_samples, num_classes}, logits.options());
auto d_logits_size = num_samples * logits.size(1);

dim3 grid(std::min(THCCeilDiv((int64_t)d_logits_size, (int64_t)512),
dim3 grid(std::min(at::ceil_div((int64_t)d_logits_size, (int64_t)512),
(int64_t)4096));
dim3 block(512);

if (d_logits.numel() == 0) {
THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
return d_logits;
}

Expand All @@ -170,6 +172,6 @@ at::Tensor SigmoidFocalLoss_backward_cuda(const at::Tensor &logits,
alpha, num_samples, d_logits.data_ptr<scalar_t>());
});

THCudaCheck(cudaGetLastError());
C10_CUDA_CHECK(cudaGetLastError());
return d_logits;
}