Skip to content

Commit 1e03d86

Browse files
committed
v0.6.7.post1 release
Signed-off-by: Javier <25750030+SystemPanic@users.noreply.github.com>
1 parent 6464607 commit 1e03d86

1 file changed

Lines changed: 13 additions & 8 deletions

File tree

include/flashinfer/mamba/kernel_selective_state_update_stp.cuh

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -95,8 +95,10 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams
95 95
bool const dt_softplus = params.dt_softplus;
96 96

97 97
// State scale pointer (only used when scaleState == true)
98-
[[maybe_unused]] auto* __restrict__ state_scale =
99-
reinterpret_cast<state_scale_t*>(params.state_scale);
98+
[[maybe_unused]] state_scale_t* __restrict__ state_scale = nullptr;
99+
if constexpr (scaleState) {
100+
state_scale = reinterpret_cast<state_scale_t*>(params.state_scale);
101+
}
100 102

101 103
// Load device-side Philox seed once into a register
102 104
[[maybe_unused]] int64_t const rand_seed = params.rand_seed ? *params.rand_seed : 0;
@@ -130,10 +132,11 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams
130 132
if constexpr (scaleState) {
131 133
state_scale += state_batch * params.state_scale_stride_batch + head * DIM;
132 134
}
133-
[[maybe_unused]] auto* __restrict__ dst_state_scale =
134-
scaleState ? reinterpret_cast<state_scale_t*>(params.state_scale) +
135-
dst_state_batch * params.state_scale_stride_batch + head * DIM
136-
: nullptr;
135+
[[maybe_unused]] state_scale_t* __restrict__ dst_state_scale = nullptr;
136+
if constexpr (scaleState) {
137+
dst_state_scale = reinterpret_cast<state_scale_t*>(params.state_scale) +
138+
dst_state_batch * params.state_scale_stride_batch + head * DIM;
139+
}
137 140

138 141
__shared__ SharedStorageSimple<input_t, state_scale_t, ROWS_PER_BLOCK, DSTATE> sram;
139 142

@@ -650,8 +653,10 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical(
650 653
auto const* __restrict__ z = reinterpret_cast<input_t const*>(params.z);
651 654
auto const* __restrict__ state_batch_indices =
652 655
reinterpret_cast<stateIndex_t const*>(params.state_batch_indices);
653-
[[maybe_unused]] auto* __restrict__ state_scale =
654-
reinterpret_cast<state_scale_t*>(params.state_scale);
656+
[[maybe_unused]] state_scale_t* __restrict__ state_scale = nullptr;
657+
if constexpr (scaleState) {
658+
state_scale = reinterpret_cast<state_scale_t*>(params.state_scale);
659+
}
655 660

656 661
// Load device-side Philox seed once into a register
657 662
[[maybe_unused]] int64_t const rand_seed = params.rand_seed ? *params.rand_seed : 0;

0 commit comments

Comments
 (0)