From 9397d0b2e6bdcbaa6ab1302a7ca4fdfc3d963e06 Mon Sep 17 00:00:00 2001 From: xla authors Date: Fri, 3 Apr 2026 01:28:29 -0700 Subject: [PATCH] Automated Code Change PiperOrigin-RevId: 893948985 --- xla/stream_executor/gpu/gpu_test_kernels_lib.cu.h | 13 +++++++------ .../gpu/redzone_allocator_kernel_lib.cu.h | 7 +++---- xla/stream_executor/gpu/repeat_buffer_kernel.cu.h | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/xla/stream_executor/gpu/gpu_test_kernels_lib.cu.h b/xla/stream_executor/gpu/gpu_test_kernels_lib.cu.h index d92bc5c536b00..dd1b76e02b884 100644 --- a/xla/stream_executor/gpu/gpu_test_kernels_lib.cu.h +++ b/xla/stream_executor/gpu/gpu_test_kernels_lib.cu.h @@ -27,33 +27,34 @@ namespace stream_executor::gpu { // name, therefore we switch off name mangling extern "C" { -__global__ void AddI32(int32_t* a, int32_t* b, int32_t* c) { +inline __global__ void AddI32(int32_t* a, int32_t* b, int32_t* c) { int index = threadIdx.x + blockIdx.x * blockDim.x; c[index] = a[index] + b[index]; } -__global__ void IncI32(int32_t a, int32_t* b, int32_t* c) { +inline __global__ void IncI32(int32_t a, int32_t* b, int32_t* c) { int index = threadIdx.x + blockIdx.x * blockDim.x; c[index] = a + b[index]; } -__global__ void MulI32(int32_t* a, int32_t* b, int32_t* c) { +inline __global__ void MulI32(int32_t* a, int32_t* b, int32_t* c) { int index = threadIdx.x + blockIdx.x * blockDim.x; c[index] = a[index] * b[index]; } -__global__ void IncAndCmp(int32_t* counter, bool* pred, int32_t* value) { +inline __global__ void IncAndCmp(int32_t* counter, bool* pred, int32_t* value) { int index = threadIdx.x + blockIdx.x * blockDim.x; pred[index] = counter[index] < *value; counter[index] += 1; } -__global__ void AddI32Ptrs3(Ptrs3 ptrs) { +inline __global__ void AddI32Ptrs3(Ptrs3 ptrs) { int index = threadIdx.x + blockIdx.x * blockDim.x; ptrs.c[index] = ptrs.a[index] + ptrs.b[index]; } -__global__ void CopyKernel(std::byte* dst, std::array byval) { +inline __global__ void CopyKernel(std::byte* dst, + std::array byval) { if (threadIdx.x == 0) { for (int i = 0; i < byval.size(); i++) { dst[i] = byval[i]; diff --git a/xla/stream_executor/gpu/redzone_allocator_kernel_lib.cu.h b/xla/stream_executor/gpu/redzone_allocator_kernel_lib.cu.h index 1920618f0b34d..e9c3b0df0c612 100644 --- a/xla/stream_executor/gpu/redzone_allocator_kernel_lib.cu.h +++ b/xla/stream_executor/gpu/redzone_allocator_kernel_lib.cu.h @@ -20,10 +20,9 @@ limitations under the License. namespace stream_executor::gpu { -__global__ void RedzoneAllocatorKernelImpl(uint8_t* input_buffer, - uint8_t redzone_pattern, - uint64_t buffer_length, - uint32_t* out_mismatched_ptr) { +inline __global__ void RedzoneAllocatorKernelImpl( + uint8_t* input_buffer, uint8_t redzone_pattern, uint64_t buffer_length, + uint32_t* out_mismatched_ptr) { const uint64_t block_dim_x = static_cast(blockDim.x), stride = block_dim_x * gridDim.x; for (uint64_t idx = threadIdx.x + blockIdx.x * block_dim_x; diff --git a/xla/stream_executor/gpu/repeat_buffer_kernel.cu.h b/xla/stream_executor/gpu/repeat_buffer_kernel.cu.h index 2f5f042e20cb8..361af7c35cabb 100644 --- a/xla/stream_executor/gpu/repeat_buffer_kernel.cu.h +++ b/xla/stream_executor/gpu/repeat_buffer_kernel.cu.h @@ -23,8 +23,8 @@ namespace stream_executor::gpu { // Populate the last `buffer_size - repeat_size` bytes of `buffer` by repeating // the first `repeat_size` bytes. This should be launched with at least // `repeat_size` threads in total. -__global__ void RepeatBufferKernelImpl(char* buffer, int64_t repeat_size, - int64_t buffer_size) { +inline __global__ void RepeatBufferKernelImpl(char* buffer, int64_t repeat_size, + int64_t buffer_size) { int64_t global_index = blockDim.x * blockIdx.x + threadIdx.x; if (global_index >= repeat_size) { return;