diff --git a/docs/user-tutorial/benchmarks/micro-benchmarks.md b/docs/user-tutorial/benchmarks/micro-benchmarks.md
index aa3aa965b..dbec712dc 100644
--- a/docs/user-tutorial/benchmarks/micro-benchmarks.md
+++ b/docs/user-tutorial/benchmarks/micro-benchmarks.md
@@ -273,14 +273,14 @@ Measure the memory bandwidth of GPU using the STREAM benchmark. The benchmark te
 | Metric Name | Unit | Description |
 |------------------------------------------------------------|------------------|-----------------------------------------------------------------------------------------------------------------------------------------|
-| STREAM\_COPY\_double\_gpu\_[0-9]\_buffer\_[0-9]+\_block\_[0-9]+\_bw | bandwidth (GB/s) | The fp64 memory bandwidth of the GPU for the copy operation with specified buffer size and block size. |
-| STREAM\_SCALE\_double\_gpu\_[0-9]\_buffer\_[0-9]+\_block\_[0-9]+\_bw | bandwidth (GB/s) | The fp64 memory bandwidth of the GPU for the scale operation with specified buffer size and block size. |
-| STREAM\_ADD\_double\_gpu\_[0-9]\_buffer\_[0-9]+\_block\_[0-9]+\_bw | bandwidth (GB/s) | The fp64 memory bandwidth of the GPU for the add operation with specified buffer size and block size. |
-| STREAM\_TRIAD\_double\_gpu\_[0-9]\_buffer\_[0-9]+\_block\_[0-9]+\_bw | bandwidth (GB/s) | The fp64 memory bandwidth of the GPU for the triad operation with specified buffer size and block size. |
-| STREAM\_COPY\_double\_gpu\_[0-9]\_buffer\_[0-9]+\_block\_[0-9]+\_ratio | Efficiency (%) | The fp64 memory bandwidth efficiency of the GPU for the copy operation with specified buffer size and block size. |
-| STREAM\_SCALE\_double\_gpu\_[0-9]\_buffer\_[0-9]+\_block\_[0-9]+\_ratio | Efficiency (%) | The fp64 memory bandwidth efficiency of the GPU for the scale operation with specified buffer size and block size. |
-| STREAM\_ADD\_double\_gpu\_[0-9]\_buffer\_[0-9]+\_block\_[0-9]+\_ratio | Efficiency (%) | The fp64 memory bandwidth efficiency of the GPU for the add operation with specified buffer size and block size. |
-| STREAM\_TRIAD\_double\_gpu\_[0-9]\_buffer\_[0-9]+\_block\_[0-9]+\_ratio | Efficiency (%) | The fp64 memory bandwidth efficiency of the GPU for the triad operation with specified buffer size and block size. |
+| STREAM\_COPY\_(double\|float)\_buffer\_[0-9]+\_block\_[0-9]+\_bw | bandwidth (GB/s) | The memory bandwidth of the GPU for the copy operation with the specified data type (fp32 or fp64), buffer size, and block size. |
+| STREAM\_SCALE\_(double\|float)\_buffer\_[0-9]+\_block\_[0-9]+\_bw | bandwidth (GB/s) | The memory bandwidth of the GPU for the scale operation with the specified data type (fp32 or fp64), buffer size, and block size. |
+| STREAM\_ADD\_(double\|float)\_buffer\_[0-9]+\_block\_[0-9]+\_bw | bandwidth (GB/s) | The memory bandwidth of the GPU for the add operation with the specified data type (fp32 or fp64), buffer size, and block size. |
+| STREAM\_TRIAD\_(double\|float)\_buffer\_[0-9]+\_block\_[0-9]+\_bw | bandwidth (GB/s) | The memory bandwidth of the GPU for the triad operation with the specified data type (fp32 or fp64), buffer size, and block size. |
+| STREAM\_COPY\_(double\|float)\_buffer\_[0-9]+\_block\_[0-9]+\_ratio | Efficiency (%) | The memory bandwidth efficiency of the GPU for the copy operation with the specified data type (fp32 or fp64), buffer size, and block size. |
+| STREAM\_SCALE\_(double\|float)\_buffer\_[0-9]+\_block\_[0-9]+\_ratio | Efficiency (%) | The memory bandwidth efficiency of the GPU for the scale operation with the specified data type (fp32 or fp64), buffer size, and block size. |
+| STREAM\_ADD\_(double\|float)\_buffer\_[0-9]+\_block\_[0-9]+\_ratio | Efficiency (%) | The memory bandwidth efficiency of the GPU for the add operation with the specified data type (fp32 or fp64), buffer size, and block size. |
+| STREAM\_TRIAD\_(double\|float)\_buffer\_[0-9]+\_block\_[0-9]+\_ratio | Efficiency (%) | The memory bandwidth efficiency of the GPU for the triad operation with the specified data type (fp32 or fp64), buffer size, and block size. |
 
 ### `ib-loopback`
 
diff --git a/examples/benchmarks/gpu_stream.py b/examples/benchmarks/gpu_stream.py
index 88c789efb..1aa67b15d 100644
--- a/examples/benchmarks/gpu_stream.py
+++ b/examples/benchmarks/gpu_stream.py
@@ -12,7 +12,7 @@
 if __name__ == '__main__':
     context = BenchmarkRegistry.create_benchmark_context(
-        'gpu-stream', platform=Platform.CUDA, parameters='--num_warm_up 1 --num_loops 10'
+        'gpu-stream', platform=Platform.CUDA, parameters='--num_warm_up 1 --num_loops 10 --data_type double'
     )
     # For ROCm environment, please specify the benchmark name and the platform as the following.
     # context = BenchmarkRegistry.create_benchmark_context(
diff --git a/superbench/benchmarks/micro_benchmarks/gpu_stream.py b/superbench/benchmarks/micro_benchmarks/gpu_stream.py
index ecc90951f..44908e6b6 100644
--- a/superbench/benchmarks/micro_benchmarks/gpu_stream.py
+++ b/superbench/benchmarks/micro_benchmarks/gpu_stream.py
@@ -51,6 +51,15 @@ def add_parser_arguments(self):
             help='Number of data buffer copies performed.',
         )
 
+        self._parser.add_argument(
+            '--data_type',
+            type=str,
+            default='double',
+            choices=['float', 'double'],
+            required=False,
+            help='Data type of the buffer elements.',
+        )
+
         self._parser.add_argument(
             '--check_data',
             action='store_true',
@@ -68,8 +77,8 @@ def _preprocess(self):
 
         self.__bin_path = os.path.join(self._args.bin_dir, self._bin_name)
 
-        args = '--size %d --num_warm_up %d --num_loops %d ' % (
-            self._args.size, self._args.num_warm_up, self._args.num_loops
+        args = '--size %d --num_warm_up %d --num_loops %d --data_type %s' % (
+            self._args.size, self._args.num_warm_up, self._args.num_loops, self._args.data_type
         )
 
         if self._args.check_data:
diff --git a/superbench/benchmarks/micro_benchmarks/gpu_stream/CMakeLists.txt b/superbench/benchmarks/micro_benchmarks/gpu_stream/CMakeLists.txt
index 2c856f32a..ce15d10c7 100644
--- a/superbench/benchmarks/micro_benchmarks/gpu_stream/CMakeLists.txt
+++ b/superbench/benchmarks/micro_benchmarks/gpu_stream/CMakeLists.txt
@@ -15,7 +15,7 @@ find_package(CUDAToolkit QUIET)
 
 # Source files
 set(SOURCES
-    gpu_stream_test.cpp
+    gpu_stream_main.cpp
     gpu_stream_utils.cpp
     gpu_stream.cu
     gpu_stream_kernels.cu
diff --git a/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream.cu b/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream.cu
index 617b8338a..5e5ac90e5 100644
--- a/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream.cu
+++ b/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream.cu
@@ -235,15 +235,15 @@ template <typename T> int GpuStream<T>::PrepareBufAndStream(std::unique_ptr<BenchArgs<T>> &args) {
     if (args->check_data) {
-        // Generate data to copy
-        args->sub.data_buf = static_cast<T *>(numa_alloc_onnode(args->size * sizeof(T), args->numa_id));
+        // Generate data to copy - use local NUMA node for best CPU access
+        args->sub.data_buf = static_cast<T *>(numa_alloc_local(args->size * sizeof(T)));
         for (int j = 0; j < args->size / sizeof(T); j++) {
             args->sub.data_buf[j] = static_cast<T>(j % kUInt8Mod);
         }
-        // Allocate check buffer
-        args->sub.check_buf = static_cast<T *>(numa_alloc_onnode(args->size * sizeof(T), args->numa_id));
+        // Allocate check buffer on local NUMA node
args->sub.check_buf = static_cast(numa_alloc_local(args->size * sizeof(T))); } // Allocate buffers @@ -420,8 +420,10 @@ int GpuStream::RunStreamKernel(std::unique_ptr> &args, Kernel kerne int size_factor = 2; // Validate data size - uint64_t num_elements_in_thread_block = kNumLoopUnroll * num_threads_per_block; - uint64_t num_bytes_in_thread_block = num_elements_in_thread_block * sizeof(T); + // Each thread processes 128 bits (16 bytes) for optimal memory bandwidth. + // For double: uses double2 (16 bytes). For float: would use float4 (16 bytes). + constexpr uint64_t kBytesPerThread = 16; // 128-bit aligned access + uint64_t num_bytes_in_thread_block = num_threads_per_block * kBytesPerThread; if (args->size % num_bytes_in_thread_block) { std::cerr << "RunCopy: Data size should be multiple of " << num_bytes_in_thread_block << std::endl; return -1; @@ -448,30 +450,30 @@ int GpuStream::RunStreamKernel(std::unique_ptr> &args, Kernel kerne switch (kernel) { case Kernel::kCopy: - CopyKernel<<sub.stream>>>( - reinterpret_cast(args->sub.gpu_buf_ptrs[2].get()), - reinterpret_cast(args->sub.gpu_buf_ptrs[0].get())); + CopyKernel<<sub.stream>>>( + reinterpret_cast *>(args->sub.gpu_buf_ptrs[2].get()), + reinterpret_cast *>(args->sub.gpu_buf_ptrs[0].get())); args->sub.kernel_name = "COPY"; break; case Kernel::kScale: - ScaleKernel<<sub.stream>>>( - reinterpret_cast(args->sub.gpu_buf_ptrs[2].get()), - reinterpret_cast(args->sub.gpu_buf_ptrs[0].get()), static_cast(scalar)); + ScaleKernel<<sub.stream>>>( + reinterpret_cast *>(args->sub.gpu_buf_ptrs[2].get()), + reinterpret_cast *>(args->sub.gpu_buf_ptrs[0].get()), static_cast(scalar)); args->sub.kernel_name = "SCALE"; break; case Kernel::kAdd: - AddKernel<<sub.stream>>>( - reinterpret_cast(args->sub.gpu_buf_ptrs[2].get()), - reinterpret_cast(args->sub.gpu_buf_ptrs[0].get()), - reinterpret_cast(args->sub.gpu_buf_ptrs[1].get())); + AddKernel<<sub.stream>>>( + reinterpret_cast *>(args->sub.gpu_buf_ptrs[2].get()), + reinterpret_cast *>(args->sub.gpu_buf_ptrs[0].get()), + reinterpret_cast *>(args->sub.gpu_buf_ptrs[1].get())); size_factor = 3; args->sub.kernel_name = "ADD"; break; case Kernel::kTriad: - TriadKernel<<sub.stream>>>( - reinterpret_cast(args->sub.gpu_buf_ptrs[2].get()), - reinterpret_cast(args->sub.gpu_buf_ptrs[0].get()), - reinterpret_cast(args->sub.gpu_buf_ptrs[1].get()), static_cast(scalar)); + TriadKernel<<sub.stream>>>( + reinterpret_cast *>(args->sub.gpu_buf_ptrs[2].get()), + reinterpret_cast *>(args->sub.gpu_buf_ptrs[0].get()), + reinterpret_cast *>(args->sub.gpu_buf_ptrs[1].get()), static_cast(scalar)); size_factor = 3; args->sub.kernel_name = "TRIAD"; break; @@ -583,10 +585,9 @@ int GpuStream::RunStream(std::unique_ptr> &args, const std::string // output formatted results to stdout // Tags are of format: - // STREAM__datatype_gpu__buffer__block_ + // STREAM__datatype_buffer__block_ for (int i = 0; i < args->sub.times_in_ms.size(); i++) { - std::string tag = "STREAM_" + KernelToString(i) + "_" + data_type + "_gpu_" + std::to_string(args->gpu_id) + - "_buffer_" + std::to_string(args->size); + std::string tag = "STREAM_" + KernelToString(i) + "_" + data_type + "_buffer_" + std::to_string(args->size); for (int j = 0; j < args->sub.times_in_ms[i].size(); j++) { // Calculate and display bandwidth double bw = args->size * args->num_loops / args->sub.times_in_ms[i][j] / 1e6; @@ -608,9 +609,9 @@ int GpuStream::RunStream(std::unique_ptr> &args, const std::string /** * @brief Runs the Stream benchmark. 
* - * @details This function processes the input args, validates and composes the BenchArgs structure for the - availavble - * GPUs, and runs the benchmark. + * @details This function processes the input args, validates and composes the BenchArgs structure for + * the first visible GPU (CUDA device 0). When running under Superbench's default_local_mode, + * CUDA_VISIBLE_DEVICES is set per process, so device 0 maps to the assigned physical GPU. * * @return int The status code indicating success or failure of the benchmark execution. * */ @@ -631,21 +632,29 @@ int GpuStream::Run() { return ret; } - // find all GPUs and compose the Benchmarking data structure - for (int j = 0; j < gpu_count; j++) { - auto args = std::make_unique>(); - args->numa_id = 0; - args->gpu_id = j; - cudaGetDeviceProperties(&args->gpu_device_prop, j); + if (gpu_count < 1) { + std::cerr << "Run::No GPU available" << std::endl; + return -1; + } + // Run on CUDA device 0 (the visible GPU assigned by CUDA_VISIBLE_DEVICES). + if (opts_.data_type == "float") { + auto args = std::make_unique>(); + args->gpu_id = 0; + cudaGetDeviceProperties(&args->gpu_device_prop, 0); + args->num_warm_up = opts_.num_warm_up; + args->num_loops = opts_.num_loops; + args->size = opts_.size; + args->check_data = opts_.check_data; + bench_args_.emplace_back(std::move(args)); + } else { + auto args = std::make_unique>(); + args->gpu_id = 0; + cudaGetDeviceProperties(&args->gpu_device_prop, 0); args->num_warm_up = opts_.num_warm_up; args->num_loops = opts_.num_loops; args->size = opts_.size; args->check_data = opts_.check_data; - args->numa_id = 0; - args->gpu_id = j; - - // add data to vector bench_args_.emplace_back(std::move(args)); } @@ -668,14 +677,6 @@ int GpuStream::Run() { // Print device info with both the memory clock and peak bandwidth PrintCudaDeviceInfo(curr_args->gpu_id, curr_args->gpu_device_prop, memory_clock_mhz, peak_bw); - // Set the NUMA node - ret = numa_run_on_node(curr_args->numa_id); - if (ret != 0) { - std::cerr << "Run::numa_run_on_node error: " << errno << std::endl; - has_error = true; - return; - } - // Run the stream benchmark for the configured data, passing the peak bandwidth if constexpr (std::is_same_v, BenchArgs>) { ret = RunStream(curr_args, "float", peak_bw); diff --git a/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream.hpp b/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream.hpp index 473a78839..754888339 100644 --- a/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream.hpp +++ b/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream.hpp @@ -34,7 +34,7 @@ class GpuStream { int Run(); private: - using BenchArgsVariant = std::variant>>; + using BenchArgsVariant = std::variant>, std::unique_ptr>>; std::vector bench_args_; Opts opts_; diff --git a/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_kernels.cu b/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_kernels.cu index 548fc8ba3..e40237b83 100644 --- a/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_kernels.cu +++ b/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_kernels.cu @@ -1,155 +1,33 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -#include "gpu_stream_kernels.hpp" - -/** - * @brief Fetches a value from source memory and writes it to a register. - * - * @details This inline device function fetches a value from the specified source memory - * location and writes it to the provided register. 
The implementation references the following: - * 1) NCCL: - * https://github.com/NVIDIA/nccl/blob/7e515921295adaab72adf56ea71a0fafb0ecb5f3/src/collectives/device/common_kernel.h#L483 - * 2) RCCL: - * https://github.com/ROCmSoftwarePlatform/rccl/blob/5c8380ff5b5925cae4bce00b1879a5f930226e8d/src/collectives/device/common_kernel.h#L268 - * - * @tparam T The type of the value to fetch. - * @param[out] v The register to write the fetched value to. - * @param[in] p The source memory location to fetch the value from. - */ -template inline __device__ void Fetch(T &v, const T *p) { -#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) - v = *p; -#else - if constexpr (std::is_same::value) { - asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(v) : "l"(p) : "memory"); - } else if constexpr (std::is_same::value) { - asm volatile("ld.volatile.global.f64 %0, [%1];" : "=d"(v) : "l"(p) : "memory"); - } -#endif -} - /** - * @brief Stores a value from register and writes it to target memory. + * @file gpu_stream_kernels.cu + * @brief CUDA kernel compilation unit for GPU stream benchmark. * - * @details This inline device function stores a value from the provided register - * and writes it to the specified target memory location. The implementation references the following: - * 1) NCCL: - * https://github.com/NVIDIA/nccl/blob/7e515921295adaab72adf56ea71a0fafb0ecb5f3/src/collectives/device/common_kernel.h#L486 - * 2) RCCL: - * https://github.com/ROCmSoftwarePlatform/rccl/blob/5c8380ff5b5925cae4bce00b1879a5f930226e8d/src/collectives/device/common_kernel.h#L276 + * All template kernel implementations (CopyKernel, ScaleKernel, AddKernel, TriadKernel) + * are defined in gpu_stream_kernels.hpp rather than here. This is required because: * - * @tparam T The type of the value to store. - * @param[out] p The target memory location to write the value to. - * @param[in] v The register containing the value to be stored. - */ -template inline __device__ void Store(T *p, const T &v) { -#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) - *p = v; -#else - if constexpr (std::is_same::value) { - asm volatile("st.volatile.global.f32 [%0], %1;" ::"l"(p), "f"(v) : "memory"); - } else if constexpr (std::is_same::value) { - asm volatile("st.volatile.global.f64 [%0], %1;" ::"l"(p), "d"(v) : "memory"); - } -#endif -} - -/** - * @brief Performs COPY, a simple copy operation from source to target. b = a + * 1. **C++ Template Instantiation Model**: Templates are not compiled until they are + * instantiated with concrete types. The compiler needs to see the full template + * definition (not just declaration) at the point of instantiation. * - * @details This CUDA kernel performs a simple copy operation, copying data from the source array - * to the target array. This is used to measure transfer rates without any arithmetic operations. + * 2. **Separate Compilation Units**: When gpu_stream.cu calls `CopyKernel<<<...>>>`, + * nvcc needs the full kernel implementation visible in that translation unit. + * If implementations were only in this .cu file, gpu_stream.cu would only see + * declarations, causing "undefined reference" linker errors. * - * @param[out] tgt The target array where data will be copied to. - * @param[in] src The source array from which data will be copied. 
- */ -__global__ void CopyKernel(double *tgt, const double *src) { - uint64_t index = blockIdx.x * blockDim.x * kNumLoopUnrollAlias + threadIdx.x; - double val[kNumLoopUnrollAlias]; -#pragma unroll - for (uint64_t i = 0; i < kNumLoopUnrollAlias; i++) - Fetch(val[i], src + index + i * blockDim.x); -#pragma unroll - for (uint64_t i = 0; i < kNumLoopUnrollAlias; i++) - Store(tgt + index + i * blockDim.x, val[i]); -} - -/** - * @brief Performs SCALE, a scaling operation on the source data. b = x * a - * - * @details This CUDA kernel performs a simple arithmetic operation by scaling the source data - * with a given scalar value and storing the result in the target array. - * - * @param[out] tgt The target array where the scaled data will be stored. - * @param[in] src The source array containing the data to be scaled. - * @param[in] scalar The scalar value used to scale the source data. - */ -__global__ void ScaleKernel(double *tgt, const double *src, const double scalar) { - uint64_t index = blockIdx.x * blockDim.x * kNumLoopUnrollAlias + threadIdx.x; - double val[kNumLoopUnrollAlias]; -#pragma unroll - for (uint64_t i = 0; i < kNumLoopUnrollAlias; i++) - Fetch(val[i], src + index + i * blockDim.x); -#pragma unroll - for (uint64_t i = 0; i < kNumLoopUnrollAlias; i++) { - val[i] *= scalar; - Store(tgt + index + i * blockDim.x, val[i]); - } -} - -/** - * @brief Performs ADD, an addition operation on two source arrays. c = a + b - * - * @details This CUDA kernel adds corresponding elements from two source arrays and stores the result - * in the target array. This operation is used to measure transfer rates with a simple arithmetic addition. - * - * @param[out] tgt The target array where the result of the addition will be stored. - * @param[in] src_a The first source array containing the first set of operands. - * @param[in] src_b The second source array containing the second set of operands. - */ -__global__ void AddKernel(double *tgt, const double *src_a, const double *src_b) { - uint64_t index = blockIdx.x * blockDim.x * kNumLoopUnrollAlias + threadIdx.x; - double val_a[kNumLoopUnrollAlias]; - double val_b[kNumLoopUnrollAlias]; - -#pragma unroll - for (uint64_t i = 0; i < kNumLoopUnrollAlias; i++) { - Fetch(val_a[i], src_a + index + i * blockDim.x); - Fetch(val_b[i], src_b + index + i * blockDim.x); - } -#pragma unroll - for (uint64_t i = 0; i < kNumLoopUnrollAlias; i++) { - val_a[i] += val_b[i]; - Store(tgt + index + i * blockDim.x, val_a[i]); - } -} - -/** - * @brief Performs TRIAD, fused multiply/add operations on source arrays. a = b + x * c + * 3. **CUDA-Specific Consideration**: Unlike regular C++ where explicit template + * instantiation in a .cpp file can work, CUDA kernel launches require the kernel + * code to be visible to nvcc when compiling the caller. This is because nvcc + * generates device code at compile time, not link time. * - * @details This CUDA kernel performs a fused multiply/add operation by multiplying elements from - * the second source array with a scalar value, adding the result to corresponding elements from - * the first source array, and storing the result in the target array. + * 4. **Header Guards for Mixed Compilation**: The header uses `#ifdef __CUDACC__` to + * protect CUDA-specific code (blockIdx, threadIdx, __global__, etc.) from g++ + * when the header is indirectly included by .cpp files (e.g., via gpu_stream.hpp). * - * @param[out] tgt The target array where the result of the fused multiply/add operation will be stored. 
- * @param[in] src_a The first source array containing the first set of operands. - * @param[in] src_b The second source array containing the second set of operands to be multiplied by the scalar. - * @param[in] scalar The scalar value used in the multiply/add operation. + * This file remains as the compilation unit that ensures the header is processed + * by nvcc, and can host any future non-template helper functions if needed. */ -__global__ void TriadKernel(double *tgt, const double *src_a, const double *src_b, const double scalar) { - uint64_t index = blockIdx.x * blockDim.x * kNumLoopUnrollAlias + threadIdx.x; - double val_a[kNumLoopUnrollAlias]; - double val_b[kNumLoopUnrollAlias]; -#pragma unroll - for (uint64_t i = 0; i < kNumLoopUnrollAlias; i++) { - Fetch(val_a[i], src_a + index + i * blockDim.x); - Fetch(val_b[i], src_b + index + i * blockDim.x); - } -#pragma unroll - for (uint64_t i = 0; i < kNumLoopUnrollAlias; i++) { - val_b[i] += (val_a[i] * scalar); - Store(tgt + index + i * blockDim.x, val_b[i]); - } -} \ No newline at end of file +#include "gpu_stream_kernels.hpp" \ No newline at end of file diff --git a/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_kernels.hpp b/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_kernels.hpp index cfe9f2052..b5ba6a43f 100644 --- a/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_kernels.hpp +++ b/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_kernels.hpp @@ -7,13 +7,195 @@ #include #include "gpu_stream_utils.hpp" -constexpr auto kNumLoopUnrollAlias = stream_config::kNumLoopUnroll; -// Function declarations -template inline __device__ void Fetch(T &v, const T *p); -template inline __device__ void Store(T *p, const T &v); +/** + * @brief Type trait mapping scalar types to their 128-bit aligned vector types. + * + * @details For optimal memory bandwidth, we use 128-bit (16 byte) vector loads/stores: + * - double -> double2 (2 x 64-bit = 128-bit) + * - float -> float4 (4 x 32-bit = 128-bit) + */ +template struct VectorType; +template <> struct VectorType { using type = double2; }; +template <> struct VectorType { using type = float4; }; -__global__ void CopyKernel(double *, const double *); -__global__ void ScaleKernel(double *, const double *, const double); -__global__ void AddKernel(double *, const double *, const double *); -__global__ void TriadKernel(double *, const double *, const double *, const double); \ No newline at end of file +template using VecT = typename VectorType::type; + +// Kernel declarations (visible to all compilers for function pointer usage) +template __global__ void CopyKernel(VecT *tgt, const VecT *src); +template __global__ void ScaleKernel(VecT *tgt, const VecT *src, const T scalar); +template __global__ void AddKernel(VecT *tgt, const VecT *src_a, const VecT *src_b); +template +__global__ void TriadKernel(VecT *tgt, const VecT *src_a, const VecT *src_b, const T scalar); + +// Implementation section - only compiled by nvcc +#ifdef __CUDACC__ + +/** + * @brief Fetches a value from source memory and writes it to a register. + * + * @details This inline device function fetches a value from the specified source memory + * location and writes it to the provided register. 
The implementation references the following: + * 1) NCCL: + * https://github.com/NVIDIA/nccl/blob/7e515921295adaab72adf56ea71a0fafb0ecb5f3/src/collectives/device/common_kernel.h#L483 + * 2) RCCL: + * https://github.com/ROCmSoftwarePlatform/rccl/blob/5c8380ff5b5925cae4bce00b1879a5f930226e8d/src/collectives/device/common_kernel.h#L268 + * + * @tparam T The type of the value to fetch. + * @param[out] v The register to write the fetched value to. + * @param[in] p The source memory location to fetch the value from. + */ +template inline __device__ void Fetch(T &v, const T *p) { +#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) + v = *p; +#else + if constexpr (std::is_same::value) { + asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(v) : "l"(p) : "memory"); + } else if constexpr (std::is_same::value) { + asm volatile("ld.volatile.global.f64 %0, [%1];" : "=d"(v) : "l"(p) : "memory"); + } else if constexpr (std::is_same::value) { + asm volatile("ld.volatile.global.v2.f64 {%0,%1}, [%2];" : "=d"(v.x), "=d"(v.y) : "l"(p) : "memory"); + } else if constexpr (std::is_same::value) { + asm volatile("ld.volatile.global.v4.f32 {%0,%1,%2,%3}, [%4];" + : "=f"(v.x), "=f"(v.y), "=f"(v.z), "=f"(v.w) + : "l"(p) + : "memory"); + } +#endif +} + +/** + * @brief Stores a value from register and writes it to target memory. + * + * @details This inline device function stores a value from the provided register + * and writes it to the specified target memory location. The implementation references the following: + * 1) NCCL: + * https://github.com/NVIDIA/nccl/blob/7e515921295adaab72adf56ea71a0fafb0ecb5f3/src/collectives/device/common_kernel.h#L486 + * 2) RCCL: + * https://github.com/ROCmSoftwarePlatform/rccl/blob/5c8380ff5b5925cae4bce00b1879a5f930226e8d/src/collectives/device/common_kernel.h#L276 + * + * @tparam T The type of the value to store. + * @param[out] p The target memory location to write the value to. + * @param[in] v The register containing the value to be stored. + */ +template inline __device__ void Store(T *p, const T &v) { +#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) + *p = v; +#else + if constexpr (std::is_same::value) { + asm volatile("st.volatile.global.f32 [%0], %1;" ::"l"(p), "f"(v) : "memory"); + } else if constexpr (std::is_same::value) { + asm volatile("st.volatile.global.f64 [%0], %1;" ::"l"(p), "d"(v) : "memory"); + } else if constexpr (std::is_same::value) { + asm volatile("st.volatile.global.v2.f64 [%0], {%1,%2};" ::"l"(p), "d"(v.x), "d"(v.y) : "memory"); + } else if constexpr (std::is_same::value) { + asm volatile("st.volatile.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(p), "f"(v.x), "f"(v.y), "f"(v.z), "f"(v.w) + : "memory"); + } +#endif +} + +/** + * @brief Performs COPY, a simple copy operation from source to target. b = a + * + * @details This CUDA kernel performs a simple copy operation, copying data from the source array + * to the target array. This is used to measure transfer rates without any arithmetic operations. + * + * @param[out] tgt The target array where data will be copied to (128-bit aligned). + * @param[in] src The source array from which data will be copied (128-bit aligned). + */ +template __global__ void CopyKernel(VecT *tgt, const VecT *src) { + uint64_t index = blockIdx.x * blockDim.x + threadIdx.x; + VecT val; + Fetch(val, src + index); + Store(tgt + index, val); +} + +/** + * @brief Performs SCALE, a scaling operation on the source data. 
b = x * a + * + * @details This CUDA kernel performs a simple arithmetic operation by scaling the source data + * with a given scalar value and storing the result in the target array. + * + * @param[out] tgt The target array where the scaled data will be stored (128-bit aligned). + * @param[in] src The source array containing the data to be scaled (128-bit aligned). + * @param[in] scalar The scalar value used to scale the source data. + */ +template __global__ void ScaleKernel(VecT *tgt, const VecT *src, const T scalar) { + uint64_t index = blockIdx.x * blockDim.x + threadIdx.x; + VecT val; + Fetch(val, src + index); + if constexpr (std::is_same::value) { + val.x *= scalar; + val.y *= scalar; + } else if constexpr (std::is_same::value) { + val.x *= scalar; + val.y *= scalar; + val.z *= scalar; + val.w *= scalar; + } + Store(tgt + index, val); +} + +/** + * @brief Performs ADD, an addition operation on two source arrays. c = a + b + * + * @details This CUDA kernel adds corresponding elements from two source arrays and stores the result + * in the target array. This operation is used to measure transfer rates with a simple arithmetic addition. + * + * @param[out] tgt The target array where the result of the addition will be stored (128-bit aligned). + * @param[in] src_a The first source array containing the first set of operands (128-bit aligned). + * @param[in] src_b The second source array containing the second set of operands (128-bit aligned). + */ +template __global__ void AddKernel(VecT *tgt, const VecT *src_a, const VecT *src_b) { + uint64_t index = blockIdx.x * blockDim.x + threadIdx.x; + VecT val_a; + VecT val_b; + Fetch(val_a, src_a + index); + Fetch(val_b, src_b + index); + if constexpr (std::is_same::value) { + val_a.x += val_b.x; + val_a.y += val_b.y; + } else if constexpr (std::is_same::value) { + val_a.x += val_b.x; + val_a.y += val_b.y; + val_a.z += val_b.z; + val_a.w += val_b.w; + } + Store(tgt + index, val_a); +} + +/** + * @brief Performs TRIAD, fused multiply/add operations on source arrays. a = b + x * c + * + * @details This CUDA kernel performs a fused multiply/add operation by multiplying elements from + * the second source array with a scalar value, adding the result to corresponding elements from + * the first source array, and storing the result in the target array. + * + * @param[out] tgt The target array where the result of the fused multiply/add operation will be stored (128-bit + * aligned). + * @param[in] src_a The first source array containing the first set of operands (128-bit aligned). + * @param[in] src_b The second source array containing the second set of operands to be multiplied by the scalar + * (128-bit aligned). + * @param[in] scalar The scalar value used in the multiply/add operation. 
+ */ +template +__global__ void TriadKernel(VecT *tgt, const VecT *src_a, const VecT *src_b, const T scalar) { + uint64_t index = blockIdx.x * blockDim.x + threadIdx.x; + VecT val_a; + VecT val_b; + Fetch(val_a, src_a + index); + Fetch(val_b, src_b + index); + if constexpr (std::is_same::value) { + val_b.x += (val_a.x * scalar); + val_b.y += (val_a.y * scalar); + } else if constexpr (std::is_same::value) { + val_b.x += (val_a.x * scalar); + val_b.y += (val_a.y * scalar); + val_b.z += (val_a.z * scalar); + val_b.w += (val_a.w * scalar); + } + Store(tgt + index, val_b); +} + +#endif // __CUDACC__ \ No newline at end of file diff --git a/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_test.cpp b/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_main.cpp similarity index 100% rename from superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_test.cpp rename to superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_main.cpp diff --git a/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_utils.cpp b/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_utils.cpp index 6ced0fdd5..fd0dfb913 100644 --- a/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_utils.cpp +++ b/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_utils.cpp @@ -43,6 +43,7 @@ void PrintUsage() { << "--size " << "--num_warm_up " << "--num_loops " + << "[--data_type ] " << "[--check_data]" << std::endl; } @@ -60,6 +61,7 @@ void PrintInputInfo(Opts &opts) { std::cout << "Buffer size(bytes): " << opts.size << std::endl; std::cout << "Number of warm up runs: " << opts.num_warm_up << std::endl; std::cout << "Number of loops: " << opts.num_loops << std::endl; + std::cout << "Data type: " << opts.data_type << std::endl; std::cout << "Check data: " << (opts.check_data ? "Yes" : "No") << std::endl; } @@ -75,11 +77,12 @@ void PrintInputInfo(Opts &opts) { * @return int The status code. * */ int ParseOpts(int argc, char **argv, Opts *opts) { - enum class OptIdx { kSize, kNumWarmUp, kNumLoops, kEnableCheckData }; + enum class OptIdx { kSize, kNumWarmUp, kNumLoops, kEnableCheckData, kDataType }; const struct option options[] = {{"size", required_argument, nullptr, static_cast(OptIdx::kSize)}, {"num_warm_up", required_argument, nullptr, static_cast(OptIdx::kNumWarmUp)}, {"num_loops", required_argument, nullptr, static_cast(OptIdx::kNumLoops)}, - {"check_data", no_argument, nullptr, static_cast(OptIdx::kEnableCheckData)}}; + {"check_data", no_argument, nullptr, static_cast(OptIdx::kEnableCheckData)}, + {"data_type", required_argument, nullptr, static_cast(OptIdx::kDataType)}}; int getopt_ret = 0; int opt_idx = 0; bool size_specified = true; @@ -126,6 +129,13 @@ int ParseOpts(int argc, char **argv, Opts *opts) { case static_cast(OptIdx::kEnableCheckData): opts->check_data = true; break; + case static_cast(OptIdx::kDataType): + opts->data_type = optarg; + if (opts->data_type != "float" && opts->data_type != "double") { + std::cerr << "Invalid data_type: " << optarg << ". Must be 'float' or 'double'." 
<< std::endl; + parse_err = true; + } + break; default: parse_err = true; } diff --git a/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_utils.hpp b/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_utils.hpp index 0c648514b..907d05ef2 100644 --- a/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_utils.hpp +++ b/superbench/benchmarks/micro_benchmarks/gpu_stream/gpu_stream_utils.hpp @@ -31,7 +31,6 @@ unsigned long long getCurrentTimestampInMicroseconds(); namespace stream_config { constexpr std::array kThreadsPerBlock = {128, 256, 512, 1024}; // Threads per block constexpr uint64_t kDefaultBufferSizeInBytes = 4294967296; // Default buffer size 4GB -constexpr int kNumLoopUnroll = 2; // Unroll depth in SM copy kernel constexpr int kNumBuffers = 3; // Number of buffers for triad, add kernel constexpr int kNumValidationBuffers = 4; // Number of validation buffers, one for each kernel constexpr int kUInt8Mod = 256; // Modulo for unsigned long data type @@ -83,10 +82,7 @@ template struct SubBenchArgs { // Arguments for each benchmark run. template struct BenchArgs { - // NUMA node under which the benchmark is done. - uint64_t numa_id = 0; - - // GPU ID for device. + // GPU ID for device (always 0 - actual GPU determined by CUDA_VISIBLE_DEVICES). int gpu_id = 0; // GPU device info @@ -121,6 +117,9 @@ struct Opts { // Whether check data after copy. bool check_data = false; + + // Data type for the benchmark ("float" or "double"). + std::string data_type = "double"; }; std::string KernelToString(int); // Function to convert enum to string diff --git a/tests/benchmarks/micro_benchmarks/test_gpu_stream.py b/tests/benchmarks/micro_benchmarks/test_gpu_stream.py index e6dea64e8..7f58fe461 100644 --- a/tests/benchmarks/micro_benchmarks/test_gpu_stream.py +++ b/tests/benchmarks/micro_benchmarks/test_gpu_stream.py @@ -31,7 +31,7 @@ def _test_gpu_stream_command_generation(self, platform): num_loops = 10 size = 25769803776 - parameters = '--num_warm_up %d --num_loops %d --size %d ' \ + parameters = '--num_warm_up %d --num_loops %d --size %d --data_type double ' \ '--check_data' % \ (num_warm_up, num_loops, size) benchmark = benchmark_class(benchmark_name, parameters=parameters) @@ -49,6 +49,7 @@ def _test_gpu_stream_command_generation(self, platform): assert (benchmark._args.num_warm_up == num_warm_up) assert (benchmark._args.num_loops == num_loops) assert (benchmark._args.check_data) + assert (benchmark._args.data_type == 'double') # Check command assert (1 == len(benchmark._commands)) @@ -56,6 +57,7 @@ def _test_gpu_stream_command_generation(self, platform): assert ('--size %d' % size in benchmark._commands[0]) assert ('--num_warm_up %d' % num_warm_up in benchmark._commands[0]) assert ('--num_loops %d' % num_loops in benchmark._commands[0]) + assert ('--data_type double' in benchmark._commands[0]) assert ('--check_data' in benchmark._commands[0]) @decorator.cuda_test diff --git a/tests/data/gpu_stream.log b/tests/data/gpu_stream.log index c3d6f2390..a3dcf2b01 100644 --- a/tests/data/gpu_stream.log +++ b/tests/data/gpu_stream.log @@ -2,40 +2,23 @@ STREAM Benchmark Buffer size(bytes): 4294967296 Number of warm up runs: 10 Number of loops: 40 +Data type: double Check data: No Device 0: "NVIDIA Graphics Device" 152 SMs(10.0) Memory: 4000MHz x 8192-bit = 8192 GB/s PEAK ECC is ON -STREAM_COPY_double_gpu_0_buffer_4294967296_block_128 6711.67 81.93 -STREAM_COPY_double_gpu_0_buffer_4294967296_block_256 6549.50 79.95 -STREAM_COPY_double_gpu_0_buffer_4294967296_block_512 6195.43 
75.63 -STREAM_COPY_double_gpu_0_buffer_4294967296_block_1024 5721.52 69.84 -STREAM_SCALE_double_gpu_0_buffer_4294967296_block_128 6680.42 81.55 -STREAM_SCALE_double_gpu_0_buffer_4294967296_block_256 6515.51 79.54 -STREAM_SCALE_double_gpu_0_buffer_4294967296_block_512 6106.69 74.54 -STREAM_SCALE_double_gpu_0_buffer_4294967296_block_1024 5626.68 68.69 -STREAM_ADD_double_gpu_0_buffer_4294967296_block_128 7379.25 90.08 -STREAM_ADD_double_gpu_0_buffer_4294967296_block_256 7407.27 90.42 -STREAM_ADD_double_gpu_0_buffer_4294967296_block_512 7309.59 89.23 -STREAM_ADD_double_gpu_0_buffer_4294967296_block_1024 6788.64 82.87 -STREAM_TRIAD_double_gpu_0_buffer_4294967296_block_128 7378.19 90.07 -STREAM_TRIAD_double_gpu_0_buffer_4294967296_block_256 7414.01 90.50 -STREAM_TRIAD_double_gpu_0_buffer_4294967296_block_512 7295.50 89.06 -STREAM_TRIAD_double_gpu_0_buffer_4294967296_block_1024 6730.42 82.16 - -Device 1: "NVIDIA Graphics Device" 152 SMs(10.0) Memory: 4000.00MHz x 8192-bit = 8192.00 GB/s PEAK ECC is ON -STREAM_COPY_double_gpu_1_buffer_4294967296_block_128 6708.74 81.89 -STREAM_COPY_double_gpu_1_buffer_4294967296_block_256 6549.47 79.95 -STREAM_COPY_double_gpu_1_buffer_4294967296_block_512 6195.39 75.63 -STREAM_COPY_double_gpu_1_buffer_4294967296_block_1024 5725.07 69.89 -STREAM_SCALE_double_gpu_1_buffer_4294967296_block_128 6678.56 81.53 -STREAM_SCALE_double_gpu_1_buffer_4294967296_block_256 6514.05 79.52 -STREAM_SCALE_double_gpu_1_buffer_4294967296_block_512 6103.80 74.51 -STREAM_SCALE_double_gpu_1_buffer_4294967296_block_1024 5630.41 68.73 -STREAM_ADD_double_gpu_1_buffer_4294967296_block_128 7377.74 90.06 -STREAM_ADD_double_gpu_1_buffer_4294967296_block_256 7410.97 90.47 -STREAM_ADD_double_gpu_1_buffer_4294967296_block_512 7310.80 89.24 -STREAM_ADD_double_gpu_1_buffer_4294967296_block_1024 6789.91 82.88 -STREAM_TRIAD_double_gpu_1_buffer_4294967296_block_128 7379.03 90.08 -STREAM_TRIAD_double_gpu_1_buffer_4294967296_block_256 7414.04 90.50 -STREAM_TRIAD_double_gpu_1_buffer_4294967296_block_512 7298.26 89.09 -STREAM_TRIAD_double_gpu_1_buffer_4294967296_block_1024 6732.15 82.18 \ No newline at end of file +STREAM_COPY_double_buffer_4294967296_block_128 6711.67 81.93 +STREAM_COPY_double_buffer_4294967296_block_256 6549.50 79.95 +STREAM_COPY_double_buffer_4294967296_block_512 6195.43 75.63 +STREAM_COPY_double_buffer_4294967296_block_1024 5721.52 69.84 +STREAM_SCALE_double_buffer_4294967296_block_128 6680.42 81.55 +STREAM_SCALE_double_buffer_4294967296_block_256 6515.51 79.54 +STREAM_SCALE_double_buffer_4294967296_block_512 6106.69 74.54 +STREAM_SCALE_double_buffer_4294967296_block_1024 5626.68 68.69 +STREAM_ADD_double_buffer_4294967296_block_128 7379.25 90.08 +STREAM_ADD_double_buffer_4294967296_block_256 7407.27 90.42 +STREAM_ADD_double_buffer_4294967296_block_512 7309.59 89.23 +STREAM_ADD_double_buffer_4294967296_block_1024 6788.64 82.87 +STREAM_TRIAD_double_buffer_4294967296_block_128 7378.19 90.07 +STREAM_TRIAD_double_buffer_4294967296_block_256 7414.01 90.50 +STREAM_TRIAD_double_buffer_4294967296_block_512 7295.50 89.06 +STREAM_TRIAD_double_buffer_4294967296_block_1024 6730.42 82.16 \ No newline at end of file
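
Reviewer note: the core of this change is that each STREAM kernel now moves one 128-bit vector (`double2` for `--data_type double`, `float4` for `--data_type float`) per thread, and the result tags drop the `gpu_<id>` segment because each process runs on a single visible device. The sketch below is illustrative only, not code from this PR: all identifiers (`RunCopy`, `threads_per_block`, and so on) are made up, it uses plain vector assignments instead of the volatile PTX `Fetch`/`Store` helpers, and it omits NUMA-aware host buffers, data checking, and the sweep over `kThreadsPerBlock = {128, 256, 512, 1024}`. It only shows the vectorized-copy pattern and the bandwidth formula under those assumptions.

```cuda
// Minimal sketch of the 128-bit vectorized STREAM COPY pattern (assumed identifiers).
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

template <typename T> struct VectorType;                          // scalar -> 128-bit vector
template <> struct VectorType<double> { using type = double2; };  // 2 x fp64 = 16 bytes
template <> struct VectorType<float> { using type = float4; };    // 4 x fp32 = 16 bytes
template <typename T> using VecT = typename VectorType<T>::type;

// One thread copies one 16-byte vector, so a block of B threads covers 16 * B bytes.
template <typename T> __global__ void CopyKernel(VecT<T> *tgt, const VecT<T> *src) {
    uint64_t i = blockIdx.x * static_cast<uint64_t>(blockDim.x) + threadIdx.x;
    tgt[i] = src[i]; // the PR uses volatile PTX loads/stores here instead
}

template <typename T> void RunCopy(uint64_t bytes, int threads_per_block) {
    // Assumes bytes is a multiple of 16 * threads_per_block, mirroring the size check
    // in RunStreamKernel.
    VecT<T> *src = nullptr, *tgt = nullptr;
    cudaMalloc(&src, bytes);
    cudaMalloc(&tgt, bytes);
    uint64_t num_vectors = bytes / sizeof(VecT<T>);
    uint64_t num_blocks = num_vectors / threads_per_block;

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start);
    CopyKernel<T><<<num_blocks, threads_per_block>>>(tgt, src);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);

    // COPY reads and writes the buffer once each (size factor 2); ADD/TRIAD touch three
    // buffers (size factor 3). GB/s = bytes_moved / (elapsed_ms * 1e6).
    double gbps = 2.0 * bytes / (ms * 1e6);
    std::printf("STREAM_COPY_%s_buffer_%llu_block_%d %.2f\n", sizeof(T) == 8 ? "double" : "float",
                static_cast<unsigned long long>(bytes), threads_per_block, gbps);

    cudaFree(src);
    cudaFree(tgt);
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
}

int main() {
    RunCopy<double>(1ULL << 30, 256); // 1 GiB buffer, 256 threads per block
    RunCopy<float>(1ULL << 30, 256);
    return 0;
}
```

The printed tags follow the new `STREAM_<op>_<datatype>_buffer_<size>_block_<threads>` shape shown in tests/data/gpu_stream.log; the real benchmark additionally reports the efficiency ratio against the device's theoretical peak bandwidth.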