Skip to content

Commit b399145

Browse files
committed
chore: simplify scalar op loop
1 parent 7763da9 commit b399145

File tree

2 files changed

+61
-61
lines changed

2 files changed

+61
-61
lines changed

vortex-cuda/benches/dynamic_dispatch_cuda.rs

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@ use vortex_error::vortex_err;
3838
use vortex_fastlanes::BitPackedArray;
3939
use vortex_session::VortexSession;
4040

41-
const BENCH_ARGS: &[(usize, &str)] = &[(1_000_000, "1M"), (10_000_000, "10M")];
41+
const BENCH_ARGS: &[(usize, &str)] = &[
42+
(1_000_000, "1M"),
43+
(10_000_000, "10M"),
44+
(100_000_000, "100M"),
45+
];
4246

4347
const REFERENCE_VALUE: u32 = 100_000;
4448

@@ -152,10 +156,6 @@ fn run_dynamic_dispatch_timed(
152156
Ok(Duration::from_secs_f32(elapsed_ms / 1000.0))
153157
}
154158

155-
// ============================================================================
156-
// Benchmark: BitUnpack + FoR — two separate kernel launches
157-
// ============================================================================
158-
159159
/// Run bitunpack then FoR as two separate kernel launches, returning GPU time.
160160
fn run_bitunpack_for_separate_timed(
161161
cuda_ctx: &mut CudaExecutionCtx,
@@ -199,19 +199,17 @@ fn run_bitunpack_for_separate_timed(
199199
.map_err(|e| vortex_err!("failed to record start event: {:?}", e))?;
200200

201201
// --- Kernel 1: BitUnpack ---
202-
{
203-
let output_width = u32::BITS as usize;
204-
let cuda_function = bitpacked_cuda_kernel(bit_width, output_width, cuda_ctx)?;
205-
let mut launch_builder = cuda_ctx.launch_builder(&cuda_function);
206-
launch_builder.arg(&input_view);
207-
launch_builder.arg(&output_view);
208-
209-
let config = bitpacked_cuda_launch_config(output_width, len)?;
210-
unsafe {
211-
launch_builder
212-
.launch(config)
213-
.map_err(|e| vortex_err!("bit_unpack kernel launch failed: {}", e))?;
214-
}
202+
let output_width = u32::BITS as usize;
203+
let cuda_function = bitpacked_cuda_kernel(bit_width, output_width, cuda_ctx)?;
204+
let mut launch_builder = cuda_ctx.launch_builder(&cuda_function);
205+
launch_builder.arg(&input_view);
206+
launch_builder.arg(&output_view);
207+
208+
let config = bitpacked_cuda_launch_config(output_width, len)?;
209+
unsafe {
210+
launch_builder
211+
.launch(config)
212+
.map_err(|e| vortex_err!("bit_unpack kernel launch failed: {}", e))?;
215213
}
216214

217215
// --- Kernel 2: FoR (in-place on output_buf) ---

vortex-cuda/kernels/src/dynamic_dispatch.cu

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#include "dynamic_dispatch.h"
1818
#include "types.cuh"
1919

20-
constexpr uint32_t FL_CHUNK_SIZE = 1024;
20+
constexpr uint32_t ELEMENTS_PER_BLOCK = 2048;
2121

2222
template <typename T>
2323
__device__ __forceinline__ void bitunpack_lane_to_smem(const T *__restrict packed_chunk, T *__restrict smem,
@@ -45,16 +45,23 @@ __device__ __forceinline__ void dynamic_source_op(const T *__restrict input, T *
4545
uint64_t chunk_start, uint32_t chunk_len,
4646
const struct SourceOp &source_op) {
4747
constexpr uint32_t T_BITS = sizeof(T) * 8;
48-
constexpr uint32_t FL_LANES = FL_CHUNK_SIZE / T_BITS;
48+
constexpr uint32_t FL_LANES = ELEMENTS_PER_BLOCK / T_BITS;
4949

5050
switch (source_op.op_code) {
5151
case SourceOp::BITUNPACK: {
52+
constexpr uint32_t ELEMENTS_PER_FL_BLOCK = 1024;
53+
constexpr uint32_t LANES_PER_FL_BLOCK = ELEMENTS_PER_FL_BLOCK / T_BITS;
5254
const uint32_t bit_width = source_op.params.bitunpack.bit_width;
53-
const uint32_t packed_words_per_chunk = FL_LANES * bit_width;
54-
const uint64_t chunk_idx = chunk_start / FL_CHUNK_SIZE;
55-
const T *packed_chunk = input + chunk_idx * packed_words_per_chunk;
56-
for (uint32_t lane = threadIdx.x; lane < FL_LANES; lane += blockDim.x) {
57-
bitunpack_lane_to_smem<T>(packed_chunk, smem, lane, bit_width);
55+
const uint32_t packed_words_per_fl_block = LANES_PER_FL_BLOCK * bit_width;
56+
const uint64_t first_fl_block = chunk_start / ELEMENTS_PER_FL_BLOCK;
57+
58+
#pragma unroll
59+
for (uint32_t blk = 0; blk < ELEMENTS_PER_BLOCK / ELEMENTS_PER_FL_BLOCK; ++blk) {
60+
const T *packed_fl = input + (first_fl_block + blk) * packed_words_per_fl_block;
61+
T *smem_fl = smem + blk * ELEMENTS_PER_FL_BLOCK;
62+
for (uint32_t lane = threadIdx.x; lane < LANES_PER_FL_BLOCK; lane += blockDim.x) {
63+
bitunpack_lane_to_smem<T>(packed_fl, smem_fl, lane, bit_width);
64+
}
5865
}
5966
break;
6067
}
@@ -82,66 +89,61 @@ __device__ __forceinline__ T dynamic_scalar_op(T value, const struct ScalarOp &o
8289
template <typename T>
8390
__device__ void dynamic_dispatch_impl(const T *__restrict input, T *__restrict output, uint64_t array_len,
8491
const struct DynamicDispatchPlan *__restrict plan) {
85-
constexpr uint32_t ELEMENTS_PER_BLOCK = 2048;
8692
constexpr uint32_t VALUES_PER_LOOP = 32 / sizeof(T);
8793

8894
__shared__ struct DynamicDispatchPlan smem_plan;
89-
__shared__ T smem_values[FL_CHUNK_SIZE];
95+
__shared__ T smem_values[ELEMENTS_PER_BLOCK];
9096

9197
// Cache the plan in shared memory.
9298
if (threadIdx.x == 0) smem_plan = *plan;
9399
__syncthreads();
94100

95101
const uint64_t block_start = static_cast<uint64_t>(blockIdx.x) * ELEMENTS_PER_BLOCK;
96102
const uint64_t block_end = min(block_start + ELEMENTS_PER_BLOCK, array_len);
103+
const uint32_t block_len = static_cast<uint32_t>(block_end - block_start);
97104

98-
for (uint64_t chunk_start = block_start; chunk_start < block_end; chunk_start += FL_CHUNK_SIZE) {
99-
const uint32_t chunk_len = min(FL_CHUNK_SIZE, static_cast<uint32_t>(block_end - chunk_start));
100-
dynamic_source_op<T>(input, smem_values, chunk_start, chunk_len, smem_plan.source);
101-
__syncthreads();
102-
103-
const uint32_t tile_size = blockDim.x * VALUES_PER_LOOP;
104-
const uint32_t num_full_tiles = chunk_len / tile_size;
105+
dynamic_source_op<T>(input, smem_values, block_start, block_len, smem_plan.source);
106+
__syncthreads();
105107

106-
for (uint32_t tile = 0; tile < num_full_tiles; ++tile) {
107-
const uint32_t tile_base = tile * tile_size;
108+
const uint32_t tile_size = blockDim.x * VALUES_PER_LOOP;
109+
const uint32_t num_full_tiles = block_len / tile_size;
108110

109-
// Operate on values in registers. This is faster than a coalesced
110-
// one-element-per-thread loop as it enables better instruction-level
111-
// parallelism.
112-
T values[VALUES_PER_LOOP];
111+
for (uint32_t tile = 0; tile < num_full_tiles; ++tile) {
112+
const uint32_t tile_base = tile * tile_size;
113113

114-
#pragma unroll
115-
for (uint32_t idx = 0; idx < VALUES_PER_LOOP; ++idx) {
116-
values[idx] = smem_values[tile_base + idx * blockDim.x + threadIdx.x];
117-
}
114+
// Operate on values in registers. This is faster than a coalesced
115+
// one-element-per-thread loop as it enables better instruction-level
116+
// parallelism.
117+
T values[VALUES_PER_LOOP];
118118

119-
for (uint8_t op_idx = 0; op_idx < smem_plan.num_scalar_ops; ++op_idx) {
120-
const struct ScalarOp &scalar_op = smem_plan.scalar_ops[op_idx];
119+
#pragma unroll
120+
for (uint32_t idx = 0; idx < VALUES_PER_LOOP; ++idx) {
121+
values[idx] = smem_values[tile_base + idx * blockDim.x + threadIdx.x];
122+
}
121123

122-
#pragma unroll
123-
for (uint32_t idx = 0; idx < VALUES_PER_LOOP; ++idx) {
124-
values[idx] = dynamic_scalar_op(values[idx], scalar_op);
125-
}
126-
}
124+
for (uint8_t op_idx = 0; op_idx < smem_plan.num_scalar_ops; ++op_idx) {
125+
const struct ScalarOp &scalar_op = smem_plan.scalar_ops[op_idx];
127126

128127
#pragma unroll
129128
for (uint32_t idx = 0; idx < VALUES_PER_LOOP; ++idx) {
130-
output[chunk_start + tile_base + idx * blockDim.x + threadIdx.x] = values[idx];
129+
values[idx] = dynamic_scalar_op(values[idx], scalar_op);
131130
}
132131
}
133132

134-
// Handle remaining elements that were not part of a full tile.
135-
const uint32_t rem_start = num_full_tiles * tile_size;
136-
for (uint32_t elem_idx = rem_start + threadIdx.x; elem_idx < chunk_len; elem_idx += blockDim.x) {
137-
T val = smem_values[elem_idx];
138-
for (uint8_t op_idx = 0; op_idx < smem_plan.num_scalar_ops; ++op_idx) {
139-
val = dynamic_scalar_op(val, smem_plan.scalar_ops[op_idx]);
140-
}
141-
output[chunk_start + elem_idx] = val;
133+
#pragma unroll
134+
for (uint32_t idx = 0; idx < VALUES_PER_LOOP; ++idx) {
135+
output[block_start + tile_base + idx * blockDim.x + threadIdx.x] = values[idx];
142136
}
137+
}
143138

144-
__syncthreads();
139+
// Handle remaining elements that were not part of a full tile.
140+
const uint32_t rem_start = num_full_tiles * tile_size;
141+
for (uint32_t elem_idx = rem_start + threadIdx.x; elem_idx < block_len; elem_idx += blockDim.x) {
142+
T val = smem_values[elem_idx];
143+
for (uint8_t op_idx = 0; op_idx < smem_plan.num_scalar_ops; ++op_idx) {
144+
val = dynamic_scalar_op(val, smem_plan.scalar_ops[op_idx]);
145+
}
146+
output[block_start + elem_idx] = val;
145147
}
146148
}
147149

0 commit comments

Comments
 (0)