Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions xla/stream_executor/gpu/gpu_test_kernels_lib.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,33 +27,34 @@ namespace stream_executor::gpu {
// name, therefore we switch off name mangling
extern "C" {

__global__ void AddI32(int32_t* a, int32_t* b, int32_t* c) {
inline __global__ void AddI32(int32_t* a, int32_t* b, int32_t* c) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
c[index] = a[index] + b[index];
}

__global__ void IncI32(int32_t a, int32_t* b, int32_t* c) {
inline __global__ void IncI32(int32_t a, int32_t* b, int32_t* c) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
c[index] = a + b[index];
}

__global__ void MulI32(int32_t* a, int32_t* b, int32_t* c) {
inline __global__ void MulI32(int32_t* a, int32_t* b, int32_t* c) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
c[index] = a[index] * b[index];
}

__global__ void IncAndCmp(int32_t* counter, bool* pred, int32_t* value) {
inline __global__ void IncAndCmp(int32_t* counter, bool* pred, int32_t* value) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
pred[index] = counter[index] < *value;
counter[index] += 1;
}

__global__ void AddI32Ptrs3(Ptrs3<int32_t> ptrs) {
inline __global__ void AddI32Ptrs3(Ptrs3<int32_t> ptrs) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
ptrs.c[index] = ptrs.a[index] + ptrs.b[index];
}

__global__ void CopyKernel(std::byte* dst, std::array<std::byte, 16> byval) {
inline __global__ void CopyKernel(std::byte* dst,
std::array<std::byte, 16> byval) {
if (threadIdx.x == 0) {
for (int i = 0; i < byval.size(); i++) {
dst[i] = byval[i];
Expand Down
7 changes: 3 additions & 4 deletions xla/stream_executor/gpu/redzone_allocator_kernel_lib.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ limitations under the License.

namespace stream_executor::gpu {

__global__ void RedzoneAllocatorKernelImpl(uint8_t* input_buffer,
uint8_t redzone_pattern,
uint64_t buffer_length,
uint32_t* out_mismatched_ptr) {
inline __global__ void RedzoneAllocatorKernelImpl(
uint8_t* input_buffer, uint8_t redzone_pattern, uint64_t buffer_length,
uint32_t* out_mismatched_ptr) {
const uint64_t block_dim_x = static_cast<uint64_t>(blockDim.x),
stride = block_dim_x * gridDim.x;
for (uint64_t idx = threadIdx.x + blockIdx.x * block_dim_x;
Expand Down
4 changes: 2 additions & 2 deletions xla/stream_executor/gpu/repeat_buffer_kernel.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ namespace stream_executor::gpu {
// Populate the last `buffer_size - repeat_size` bytes of `buffer` by repeating
// the first `repeat_size` bytes. This should be launched with at least
// `repeat_size` threads in total.
__global__ void RepeatBufferKernelImpl(char* buffer, int64_t repeat_size,
int64_t buffer_size) {
inline __global__ void RepeatBufferKernelImpl(char* buffer, int64_t repeat_size,
int64_t buffer_size) {
int64_t global_index = blockDim.x * blockIdx.x + threadIdx.x;
if (global_index >= repeat_size) {
return;
Expand Down
Loading