Skip to content

Commit e09d10e

Browse files
committed
Fixed non-windows nvshmem import and wrong spacy import
Signed-off-by: Javier <25750030+SystemPanic@users.noreply.github.com>
1 parent ae643ab commit e09d10e

2 files changed

Lines changed: 6 additions & 4 deletions

File tree

flashinfer/jit/cpp_ext.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from typing import List, Optional
1010

1111
import torch
12-
from spacy.compat import is_windows
1312
from torch.utils.cpp_extension import (
1413
_TORCH_PATH,
1514
CUDA_HOME,
@@ -65,6 +64,7 @@ def generate_ninja_build_for_op(
6564
common_cflags.append(f"-I{dir.resolve()}")
6665

6766
is_windows = platform.system() == "Windows"
67+
6868
if is_windows:
6969
for dir in system_includes:
7070
common_cflags.append(f"-I{dir}")
@@ -89,7 +89,6 @@ def generate_ninja_build_for_op(
8989

9090

9191
common_cuda_flags = common_cflags.copy()
92-
is_windows = platform.system() == "Windows"
9392

9493
if is_windows:
9594
common_cuda_flags = [

setup.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,11 @@ def generate_build_meta(aot_build_meta: dict) -> None:
8686
"requests",
8787
"cuda-python",
8888
"pynvml",
89-
"einops",
89+
"einops"
9090
]
91+
if not IS_WINDOWS:
92+
install_requires.append("nvidia-nvshmem-cu12")
93+
9194
generate_build_meta({})
9295

9396
if IS_WINDOWS:
@@ -163,4 +166,4 @@ def get_cuda_version() -> Version:
163166
cmdclass=cmdclass,
164167
install_requires=install_requires,
165168
options={"bdist_wheel": {"py_limited_api": "cp39"}}
166-
)
169+
)

0 commit comments

Comments
 (0)