Skip to content

Commit bc8d035

Browse files
committed
v0.6.3 release
Signed-off-by: Javier <25750030+SystemPanic@users.noreply.github.com>
1 parent d0096ee, commit bc8d035

10 files changed

Lines changed: 202 additions & 89 deletions

File tree

LICENSE.cutlass.txt

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,27 @@
1+
Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
SPDX-License-Identifier: BSD-3-Clause
3+
4+
Redistribution and use in source and binary forms, with or without
5+
modification, are permitted provided that the following conditions are met:
6+
7+
1. Redistributions of source code must retain the above copyright notice, this
8+
list of conditions and the following disclaimer.
9+
10+
2. Redistributions in binary form must reproduce the above copyright notice,
11+
this list of conditions and the following disclaimer in the documentation
12+
and/or other materials provided with the distribution.
13+
14+
3. Neither the name of the copyright holder nor the names of its
15+
contributors may be used to endorse or promote products derived from
16+
this software without specific prior written permission.
17+
18+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

LICENSE.flashattention3.txt

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
BSD 3-Clause License
2+
3+
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
4+
All rights reserved.
5+
6+
Redistribution and use in source and binary forms, with or without
7+
modification, are permitted provided that the following conditions are met:
8+
9+
* Redistributions of source code must retain the above copyright notice, this
10+
list of conditions and the following disclaimer.
11+
12+
* Redistributions in binary form must reproduce the above copyright notice,
13+
this list of conditions and the following disclaimer in the documentation
14+
and/or other materials provided with the distribution.
15+
16+
* Neither the name of the copyright holder nor the names of its
17+
contributors may be used to endorse or promote products derived from
18+
this software without specific prior written permission.
19+
20+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

LICENSE.fmt.txt

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
Copyright (c) 2012 - present, Victor Zverovich
2+
3+
Permission is hereby granted, free of charge, to any person obtaining
4+
a copy of this software and associated documentation files (the
5+
"Software"), to deal in the Software without restriction, including
6+
without limitation the rights to use, copy, modify, merge, publish,
7+
distribute, sublicense, and/or sell copies of the Software, and to
8+
permit persons to whom the Software is furnished to do so, subject to
9+
the following conditions:
10+
11+
The above copyright notice and this permission notice shall be
12+
included in all copies or substantial portions of the Software.
13+
14+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17+
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18+
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19+
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20+
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21+
22+
--- Optional exception to the license ---
23+
24+
As an exception, if, as a result of your compiling your source code, portions
25+
of this Software are embedded into a machine-executable object form of such
26+
source code, you may redistribute such embedded portions in such object form
27+
without including the above copyright and permission notices.

LICENSE.spdlog.txt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
The MIT License (MIT)
2+
3+
Copyright (c) 2016 Gabi Melman.
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in
13+
all copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21+
THE SOFTWARE.
22+
23+
-- NOTE: Third party dependency used by this software --
24+
This software depends on the fmt lib (MIT License),
25+
and users must comply to its license: https://raw.githubusercontent.com/fmtlib/fmt/master/LICENSE

csrc/cudnn_sdpa_kernel_launcher.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,8 @@ static void create_packed_tma_desc_kv_prefill(int b, int32_t* actual_seq_lens_kv
341341
std::array<uint64_t, DIMS_QKV - 1> packed_tensor_stride_v = {h_kv * d_vo * BYTES_PER_ELEMENT,
342342
d_vo * BYTES_PER_ELEMENT, 0};
343343

344-
uint16_t* k_ptr = reinterpret_cast<uint16_t*>(k.data_ptr() + batch_offset_k);
345-
uint16_t* v_ptr = reinterpret_cast<uint16_t*>(v.data_ptr() + batch_offset_v);
344+
uint16_t* k_ptr = reinterpret_cast<uint16_t*>((int64_t)k.data_ptr() + batch_offset_k);
345+
uint16_t* v_ptr = reinterpret_cast<uint16_t*>((int64_t)v.data_ptr() + batch_offset_v);
346346

347347
tma::cudaSetTmaTileDescriptor(
348348
&packed_tma_desc_k[i], (void*)k_ptr, DIMS_QKV, packed_tensor_size_k.data(),
@@ -383,8 +383,8 @@ static void create_packed_tma_desc_qo_prefill(int b, int32_t* actual_seq_lens_q_
383383
std::array<uint64_t, DIMS_QKV - 1> packed_tensor_stride_o = {h_qo * d_vo * BYTES_PER_ELEMENT,
384384
d_vo * BYTES_PER_ELEMENT, 0};
385385

386-
uint16_t* q_ptr = reinterpret_cast<uint16_t*>(q.data_ptr() + batch_offset_q);
387-
uint16_t* out_ptr = reinterpret_cast<uint16_t*>(out.data_ptr() + batch_offset_o);
386+
uint16_t* q_ptr = reinterpret_cast<uint16_t*>((int64_t)q.data_ptr() + batch_offset_q);
387+
uint16_t* out_ptr = reinterpret_cast<uint16_t*>((int64_t)out.data_ptr() + batch_offset_o);
388388

389389
tma::cudaSetTmaTileDescriptor(
390390
&packed_tma_desc_q[i], (void*)q_ptr, DIMS_QKV, packed_tensor_size_q.data(),

flashinfer/aot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def copy_built_kernels(
598598
is_windows = platform.system() == "Windows"
599599
for jit_spec in jit_specs:
600600
if is_windows:
601-
src = jit_env.FLASHINFER_JIT_DIR / f"{jit_spec.name}.dll"
601+
src = jit_env.FLASHINFER_JIT_DIR / jit_spec.name / f"{jit_spec.name}.dll"
602602
dst = out_dir / jit_spec.name / f"{jit_spec.name}.dll"
603603
else:
604604
src = jit_env.FLASHINFER_JIT_DIR / jit_spec.name / f"{jit_spec.name}.so"

flashinfer/jit/attention/fmha_v2/generator_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3704,7 +3704,7 @@ def generate_files(specs_names):
37043704
"src",
37053705
"-Xcompiler",
37063706
"-Wno-enum-compare",
3707-
"--std=c++17",
3707+
"--std=c++20",
37083708
"-o",
37093709
"bin/print_traits.exe",
37103710
"generated/print_kernel_traits.cu",

flashinfer/jit/attention/modules.py

Lines changed: 76 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,14 @@ def get_single_decode_uri(
5151
use_logits_soft_cap: bool,
5252
) -> str:
5353
return (
54-
f"single_decode_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
55-
f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_"
56-
f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
57-
f"head_dim_qk_{head_dim_qk}_"
58-
f"head_dim_vo_{head_dim_vo}_"
59-
f"posenc_{pos_encoding_mode}_"
60-
f"use_swa_{use_sliding_window}_"
61-
f"use_logits_cap_{use_logits_soft_cap}"
54+
f"sdkvcd_q_{filename_safe_dtype_map[dtype_q]}_"
55+
f"kv_{filename_safe_dtype_map[dtype_kv]}_"
56+
f"o_{filename_safe_dtype_map[dtype_o]}_"
57+
f"qk_{head_dim_qk}_"
58+
f"vo_{head_dim_vo}_"
59+
f"pe_{pos_encoding_mode}_"
60+
f"swa_{use_sliding_window}_"
61+
f"lc_{use_logits_soft_cap}"
6262
)
6363

6464

@@ -74,15 +74,15 @@ def get_batch_decode_uri(
7474
use_logits_soft_cap: bool,
7575
) -> str:
7676
return (
77-
f"batch_decode_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
78-
f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_"
79-
f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
80-
f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
81-
f"head_dim_qk_{head_dim_qk}_"
82-
f"head_dim_vo_{head_dim_vo}_"
83-
f"posenc_{pos_encoding_mode}_"
84-
f"use_swa_{use_sliding_window}_"
85-
f"use_logits_cap_{use_logits_soft_cap}"
77+
f"bdkvcd_q_{filename_safe_dtype_map[dtype_q]}_"
78+
f"kv_{filename_safe_dtype_map[dtype_kv]}_"
79+
f"o_{filename_safe_dtype_map[dtype_o]}_"
80+
f"idx_{filename_safe_dtype_map[dtype_idx]}_"
81+
f"qk_{head_dim_qk}_"
82+
f"vo_{head_dim_vo}_"
83+
f"pe_{pos_encoding_mode}_"
84+
f"swa_{use_sliding_window}_"
85+
f"lc_{use_logits_soft_cap}"
8686
)
8787

8888

@@ -97,13 +97,13 @@ def get_batch_mla_uri(
9797
use_profiler: bool,
9898
) -> str:
9999
return (
100-
f"batch_mla_attention_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
101-
f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_"
102-
f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
103-
f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
104-
f"head_dim_ckv_{head_dim_ckv}_"
105-
f"head_dim_kpe_{head_dim_kpe}_"
106-
f"profiler_{use_profiler}"
100+
f"bmad_q_{filename_safe_dtype_map[dtype_q]}_"
101+
f"kv_{filename_safe_dtype_map[dtype_kv]}_"
102+
f"o_{filename_safe_dtype_map[dtype_o]}_"
103+
f"idx_{filename_safe_dtype_map[dtype_idx]}_"
104+
f"ckv_{head_dim_ckv}_"
105+
f"kpe_{head_dim_kpe}_"
106+
f"pr_{use_profiler}"
107107
) + ("_sm90" if backend == "fa3" else "")
108108

109109

@@ -214,13 +214,13 @@ def get_batch_decode_mla_uri(
214214
arc: str,
215215
) -> str:
216216
return (
217-
f"batch_decode_mla_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
218-
f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_"
219-
f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
220-
f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
221-
f"head_dim_ckv{head_dim_ckv}_"
222-
f"use_swa_{use_sliding_window}_"
223-
f"use_logits_cap_{use_logits_soft_cap}_"
217+
f"bdmkvcd_q_{filename_safe_dtype_map[dtype_q]}_"
218+
f"kv_{filename_safe_dtype_map[dtype_kv]}_"
219+
f"o_{filename_safe_dtype_map[dtype_o]}_"
220+
f"idx_{filename_safe_dtype_map[dtype_idx]}_"
221+
f"ckv{head_dim_ckv}_"
222+
f"swa_{use_sliding_window}_"
223+
f"lc_{use_logits_soft_cap}_"
224224
f"arc_{arc}"
225225
)
226226

@@ -326,14 +326,14 @@ def get_single_prefill_uri(
326326
use_fp16_qk_reduction: bool,
327327
) -> str:
328328
return (
329-
f"single_prefill_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
330-
f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_"
331-
f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
332-
f"head_dim_qk_{head_dim_qk}_"
333-
f"head_dim_vo_{head_dim_vo}_"
334-
f"posenc_{pos_encoding_mode}_"
335-
f"use_swa_{use_sliding_window}_"
336-
f"use_logits_cap_{use_logits_soft_cap}_"
329+
f"spkvcd_q_{filename_safe_dtype_map[dtype_q]}_"
330+
f"kv_{filename_safe_dtype_map[dtype_kv]}_"
331+
f"o_{filename_safe_dtype_map[dtype_o]}_"
332+
f"qk_{head_dim_qk}_"
333+
f"vo_{head_dim_vo}_"
334+
f"pe_{pos_encoding_mode}_"
335+
f"swa_{use_sliding_window}_"
336+
f"lc_{use_logits_soft_cap}_"
337337
f"f16qk_{use_fp16_qk_reduction}" + ("_sm90" if backend == "fa3" else "")
338338
)
339339

@@ -353,17 +353,17 @@ def get_pod_uri(
353353
use_logits_soft_cap_d: bool,
354354
) -> str:
355355
return (
356-
f"pod_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
357-
f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_"
358-
f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
359-
f"head_dim_{head_dim}_"
360-
f"posenc_p_{pos_encoding_mode_p}_"
361-
f"use_swa_p_{use_sliding_window_p}_"
362-
f"use_logits_cap_p_{use_logits_soft_cap_p}_"
363-
f"posenc_d_{pos_encoding_mode_d}_"
364-
f"use_swa_d_{use_sliding_window_d}_"
365-
f"use_logits_cap_d_{use_logits_soft_cap_d}_"
366-
f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
356+
f"pkvcd_q_{filename_safe_dtype_map[dtype_q]}_"
357+
f"kv_{filename_safe_dtype_map[dtype_kv]}_"
358+
f"o_{filename_safe_dtype_map[dtype_o]}_"
359+
f"hd_{head_dim}_"
360+
f"pe_p_{pos_encoding_mode_p}_"
361+
f"swa_p_{use_sliding_window_p}_"
362+
f"lc_p_{use_logits_soft_cap_p}_"
363+
f"pe_d_{pos_encoding_mode_d}_"
364+
f"swa_d_{use_sliding_window_d}_"
365+
f"lc_d_{use_logits_soft_cap_d}_"
366+
f"idx_{filename_safe_dtype_map[dtype_idx]}_"
367367
f"f16qk_{use_fp16_qk_reduction}"
368368
)
369369

@@ -382,15 +382,15 @@ def get_batch_prefill_uri(
382382
use_fp16_qk_reduction: bool,
383383
) -> str:
384384
return (
385-
f"batch_prefill_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
386-
f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_"
387-
f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
388-
f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
389-
f"head_dim_qk_{head_dim_qk}_"
390-
f"head_dim_vo_{head_dim_vo}_"
391-
f"posenc_{pos_encoding_mode}_"
392-
f"use_swa_{use_sliding_window}_"
393-
f"use_logits_cap_{use_logits_soft_cap}_"
385+
f"bpkvcd_q_{filename_safe_dtype_map[dtype_q]}_"
386+
f"kv_{filename_safe_dtype_map[dtype_kv]}_"
387+
f"o_{filename_safe_dtype_map[dtype_o]}_"
388+
f"idx_{filename_safe_dtype_map[dtype_idx]}_"
389+
f"qk_{head_dim_qk}_"
390+
f"vo_{head_dim_vo}_"
391+
f"pe_{pos_encoding_mode}_"
392+
f"swa_{use_sliding_window}_"
393+
f"lc_{use_logits_soft_cap}_"
394394
f"f16qk_{use_fp16_qk_reduction}" + ("_sm90" if backend == "fa3" else "")
395395
)
396396

@@ -407,13 +407,13 @@ def get_batch_prefill_attention_sink_uri(
407407
use_sliding_window: bool,
408408
) -> str:
409409
return (
410-
f"batch_prefill_with_attention_sink_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
411-
f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_"
412-
f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
413-
f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
414-
f"head_dim_qk_{head_dim_qk}_"
415-
f"head_dim_vo_{head_dim_vo}_"
416-
f"use_swa_{use_sliding_window}_" + ("_sm90" if backend == "fa3" else "")
410+
f"bpaskvcd_q_{filename_safe_dtype_map[dtype_q]}_"
411+
f"kv_{filename_safe_dtype_map[dtype_kv]}_"
412+
f"o_{filename_safe_dtype_map[dtype_o]}_"
413+
f"idx_{filename_safe_dtype_map[dtype_idx]}_"
414+
f"qk_{head_dim_qk}_"
415+
f"vo_{head_dim_vo}_"
416+
f"swa_{use_sliding_window}_" + ("_sm90" if backend == "fa3" else "")
417417
)
418418

419419

@@ -429,15 +429,15 @@ def get_batch_attention_uri(
429429
use_profiler: bool,
430430
) -> str:
431431
return (
432-
f"batch_attention_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_"
433-
f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_"
434-
f"dtype_o_{filename_safe_dtype_map[dtype_o]}_"
435-
f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_"
436-
f"head_dim_qk_{head_dim_qk}_"
437-
f"head_dim_vo_{head_dim_vo}_"
438-
f"posenc_{pos_encoding_mode}_"
439-
f"use_logits_soft_cap_{str(use_logits_soft_cap).lower()}_"
440-
f"use_profiler_{str(use_profiler).lower()}"
432+
f"bakvcd_q_{filename_safe_dtype_map[dtype_q]}_"
433+
f"kv_{filename_safe_dtype_map[dtype_kv]}_"
434+
f"o_{filename_safe_dtype_map[dtype_o]}_"
435+
f"idx_{filename_safe_dtype_map[dtype_idx]}_"
436+
f"qk_{head_dim_qk}_"
437+
f"vo_{head_dim_vo}_"
438+
f"pe_{pos_encoding_mode}_"
439+
f"lc_{str(use_logits_soft_cap).lower()}_"
440+
f"pr_{str(use_profiler).lower()}"
441441
)
442442

443443

@@ -654,7 +654,7 @@ def gen_batch_pod_module(
654654
use_sliding_window_d: bool,
655655
use_logits_soft_cap_d: bool,
656656
) -> JitSpec:
657-
uri = "batch_" + get_pod_uri(
657+
uri = "b" + get_pod_uri(
658658
dtype_q,
659659
dtype_kv,
660660
dtype_o,

flashinfer/jit/core.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,10 @@ def gen_jit_spec(
429429

430430
cflags = ["-O2"] if is_windows else ["-Wno-switch-bool"]
431431
if not cflags_has_std:
432-
cflags.insert(0, "-std=c++17")
432+
if is_windows:
433+
cflags.insert(0, "/std:c++20")
434+
else:
435+
cflags.insert(0, "-std=c++20")
433436

434437
cuda_cflags = [
435438
f"--threads={os.environ.get('FLASHINFER_NVCC_THREADS', '1')}",
@@ -443,7 +446,7 @@ def gen_jit_spec(
443446
cuda_cflags.insert(0, "-O2")
444447

445448
if not cuda_cflags_has_std:
446-
cuda_cflags.insert(0, "-std=c++17")
449+
cuda_cflags.insert(0, "-std=c++20")
447450

448451
if debug:
449452
cflags += ["-O0", "-g"]

0 commit comments

Comments
 (0)