@@ -95,8 +95,10 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams
9595 bool const dt_softplus = params.dt_softplus ;
9696
9797 // State scale pointer (only used when scaleState == true)
98- [[maybe_unused]] auto * __restrict__ state_scale =
99- reinterpret_cast <state_scale_t *>(params.state_scale );
98+ [[maybe_unused]] state_scale_t * __restrict__ state_scale = nullptr ;
99+ if constexpr (scaleState) {
100+ state_scale = reinterpret_cast <state_scale_t *>(params.state_scale );
101+ }
100102
101103 // Load device-side Philox seed once into a register
102104 [[maybe_unused]] int64_t const rand_seed = params.rand_seed ? *params.rand_seed : 0 ;
@@ -130,10 +132,11 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams
130132 if constexpr (scaleState) {
131133 state_scale += state_batch * params.state_scale_stride_batch + head * DIM;
132134 }
133- [[maybe_unused]] auto * __restrict__ dst_state_scale =
134- scaleState ? reinterpret_cast <state_scale_t *>(params.state_scale ) +
135- dst_state_batch * params.state_scale_stride_batch + head * DIM
136- : nullptr ;
135+ [[maybe_unused]] state_scale_t * __restrict__ dst_state_scale = nullptr ;
136+ if constexpr (scaleState) {
137+ dst_state_scale = reinterpret_cast <state_scale_t *>(params.state_scale ) +
138+ dst_state_batch * params.state_scale_stride_batch + head * DIM;
139+ }
137140
138141 __shared__ SharedStorageSimple<input_t , state_scale_t , ROWS_PER_BLOCK, DSTATE> sram;
139142
@@ -650,8 +653,10 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical(
650653 auto const * __restrict__ z = reinterpret_cast <input_t const *>(params.z );
651654 auto const * __restrict__ state_batch_indices =
652655 reinterpret_cast <stateIndex_t const *>(params.state_batch_indices );
653- [[maybe_unused]] auto * __restrict__ state_scale =
654- reinterpret_cast <state_scale_t *>(params.state_scale );
656+ [[maybe_unused]] state_scale_t * __restrict__ state_scale = nullptr ;
657+ if constexpr (scaleState) {
658+ state_scale = reinterpret_cast <state_scale_t *>(params.state_scale );
659+ }
655660
656661 // Load device-side Philox seed once into a register
657662 [[maybe_unused]] int64_t const rand_seed = params.rand_seed ? *params.rand_seed : 0 ;
0 commit comments