|
17 | 17 | #include "dynamic_dispatch.h" |
18 | 18 | #include "types.cuh" |
19 | 19 |
|
// Number of output elements each thread block is responsible for.
// Used to derive per-block start/end offsets (see dynamic_dispatch_impl)
// and to size the per-block shared-memory staging buffer.
constexpr uint32_t ELEMENTS_PER_BLOCK = 2048;
21 | 21 |
|
22 | 22 | template <typename T> |
23 | 23 | __device__ __forceinline__ void bitunpack_lane_to_smem(const T *__restrict packed_chunk, T *__restrict smem, |
@@ -45,16 +45,23 @@ __device__ __forceinline__ void dynamic_source_op(const T *__restrict input, T * |
45 | 45 | uint64_t chunk_start, uint32_t chunk_len, |
46 | 46 | const struct SourceOp &source_op) { |
47 | 47 | constexpr uint32_t T_BITS = sizeof(T) * 8; |
48 | | - constexpr uint32_t FL_LANES = FL_CHUNK_SIZE / T_BITS; |
| 48 | + constexpr uint32_t FL_LANES = ELEMENTS_PER_BLOCK / T_BITS; |
49 | 49 |
|
50 | 50 | switch (source_op.op_code) { |
51 | 51 | case SourceOp::BITUNPACK: { |
| 52 | + constexpr uint32_t ELEMENTS_PER_FL_BLOCK = 1024; |
| 53 | + constexpr uint32_t LANES_PER_FL_BLOCK = ELEMENTS_PER_FL_BLOCK / T_BITS; |
52 | 54 | const uint32_t bit_width = source_op.params.bitunpack.bit_width; |
53 | | - const uint32_t packed_words_per_chunk = FL_LANES * bit_width; |
54 | | - const uint64_t chunk_idx = chunk_start / FL_CHUNK_SIZE; |
55 | | - const T *packed_chunk = input + chunk_idx * packed_words_per_chunk; |
56 | | - for (uint32_t lane = threadIdx.x; lane < FL_LANES; lane += blockDim.x) { |
57 | | - bitunpack_lane_to_smem<T>(packed_chunk, smem, lane, bit_width); |
| 55 | + const uint32_t packed_words_per_fl_block = LANES_PER_FL_BLOCK * bit_width; |
| 56 | + const uint64_t first_fl_block = chunk_start / ELEMENTS_PER_FL_BLOCK; |
| 57 | + |
| 58 | + #pragma unroll |
| 59 | + for (uint32_t blk = 0; blk < ELEMENTS_PER_BLOCK / ELEMENTS_PER_FL_BLOCK; ++blk) { |
| 60 | + const T *packed_fl = input + (first_fl_block + blk) * packed_words_per_fl_block; |
| 61 | + T *smem_fl = smem + blk * ELEMENTS_PER_FL_BLOCK; |
| 62 | + for (uint32_t lane = threadIdx.x; lane < LANES_PER_FL_BLOCK; lane += blockDim.x) { |
| 63 | + bitunpack_lane_to_smem<T>(packed_fl, smem_fl, lane, bit_width); |
| 64 | + } |
58 | 65 | } |
59 | 66 | break; |
60 | 67 | } |
@@ -82,66 +89,61 @@ __device__ __forceinline__ T dynamic_scalar_op(T value, const struct ScalarOp &o |
// Per-block driver for the dynamic-dispatch pipeline.
//
// Each thread block processes one ELEMENTS_PER_BLOCK-sized slice of the
// array: it materializes the slice into shared memory via the plan's source
// op, applies the plan's chain of scalar ops in registers, and writes the
// results to `output`.
//
// Expected launch layout: 1-D grid with at least
// ceil(array_len / ELEMENTS_PER_BLOCK) blocks; 1-D thread blocks (only
// threadIdx.x / blockDim.x are used).
// Shared memory: one DynamicDispatchPlan plus ELEMENTS_PER_BLOCK values of T
// (statically declared below).
//
// input     - packed/encoded source data consumed by dynamic_source_op.
// output    - destination array of decoded, transformed values.
// array_len - total number of logical elements across all blocks.
// plan      - host-prepared plan; copied into shared memory once per block.
template <typename T>
__device__ void dynamic_dispatch_impl(const T *__restrict input, T *__restrict output, uint64_t array_len,
                                      const struct DynamicDispatchPlan *__restrict plan) {
  // Each thread handles VALUES_PER_LOOP elements per tile (32 bytes of T).
  constexpr uint32_t VALUES_PER_LOOP = 32 / sizeof(T);

  __shared__ struct DynamicDispatchPlan smem_plan;
  __shared__ T smem_values[ELEMENTS_PER_BLOCK];

  // Cache the plan in shared memory.
  // Thread 0 copies; the barrier below publishes it to the whole block.
  if (threadIdx.x == 0) smem_plan = *plan;
  __syncthreads();

  // This block's half-open element range [block_start, block_end).
  // blockIdx.x is widened to 64 bits before the multiply to avoid overflow
  // on large arrays.
  const uint64_t block_start = static_cast<uint64_t>(blockIdx.x) * ELEMENTS_PER_BLOCK;
  const uint64_t block_end = min(block_start + ELEMENTS_PER_BLOCK, array_len);
  const uint32_t block_len = static_cast<uint32_t>(block_end - block_start);

  // Decode this block's slice into shared memory, then barrier before any
  // thread reads values another thread wrote.
  // NOTE(review): the BITUNPACK path of dynamic_source_op unpacks the full
  // ELEMENTS_PER_BLOCK regardless of block_len — confirm the packed input is
  // padded so the tail block's extra reads stay in bounds.
  dynamic_source_op<T>(input, smem_values, block_start, block_len, smem_plan.source);
  __syncthreads();

  // Full tiles: each tile covers blockDim.x * VALUES_PER_LOOP elements.
  const uint32_t tile_size = blockDim.x * VALUES_PER_LOOP;
  const uint32_t num_full_tiles = block_len / tile_size;

  for (uint32_t tile = 0; tile < num_full_tiles; ++tile) {
    const uint32_t tile_base = tile * tile_size;

    // Operate on values in registers. This is faster than a coalesced
    // one-element-per-thread loop as it enables better instruction-level
    // parallelism.
    T values[VALUES_PER_LOOP];

    // Strided load: lane `threadIdx.x` of each blockDim.x-wide group, so
    // consecutive threads touch consecutive shared-memory addresses.
    #pragma unroll
    for (uint32_t idx = 0; idx < VALUES_PER_LOOP; ++idx) {
      values[idx] = smem_values[tile_base + idx * blockDim.x + threadIdx.x];
    }

    // Apply the scalar-op chain in plan order to every register value.
    for (uint8_t op_idx = 0; op_idx < smem_plan.num_scalar_ops; ++op_idx) {
      const struct ScalarOp &scalar_op = smem_plan.scalar_ops[op_idx];

      #pragma unroll
      for (uint32_t idx = 0; idx < VALUES_PER_LOOP; ++idx) {
        values[idx] = dynamic_scalar_op(values[idx], scalar_op);
      }
    }

    // Store with the same stride pattern used for the load (coalesced).
    #pragma unroll
    for (uint32_t idx = 0; idx < VALUES_PER_LOOP; ++idx) {
      output[block_start + tile_base + idx * blockDim.x + threadIdx.x] = values[idx];
    }
  }

  // Handle remaining elements that were not part of a full tile.
  // One element per thread per iteration; the bounds check makes this safe
  // for the final, possibly short, block.
  const uint32_t rem_start = num_full_tiles * tile_size;
  for (uint32_t elem_idx = rem_start + threadIdx.x; elem_idx < block_len; elem_idx += blockDim.x) {
    T val = smem_values[elem_idx];
    for (uint8_t op_idx = 0; op_idx < smem_plan.num_scalar_ops; ++op_idx) {
      val = dynamic_scalar_op(val, smem_plan.scalar_ops[op_idx]);
    }
    output[block_start + elem_idx] = val;
  }
}
147 | 149 |
|
|
0 commit comments