diff --git a/src/scope/core/patches/_startup.py b/src/scope/core/patches/_startup.py
index dfe957d7..31ac1912 100644
--- a/src/scope/core/patches/_startup.py
+++ b/src/scope/core/patches/_startup.py
@@ -10,8 +10,10 @@
 if sys.platform == "win32":
     try:
         from .cudnn import patch_torch_cudnn
+        from .static_cuda_launcher import patch_torch_static_cuda_launcher
 
         patch_torch_cudnn(silent=True)
+        patch_torch_static_cuda_launcher(silent=True)
     except Exception:
         # Never crash Python startup - fail silently
         pass
diff --git a/src/scope/core/patches/_utils.py b/src/scope/core/patches/_utils.py
new file mode 100644
index 00000000..51a2d7ed
--- /dev/null
+++ b/src/scope/core/patches/_utils.py
@@ -0,0 +1,29 @@
+"""Shared utilities for patch modules."""
+
+import importlib.util
+import os
+
+
+def find_package_path(package_name: str) -> str | None:
+    """Find a package's install path WITHOUT importing it.
+
+    This is critical for torch - importing it loads DLLs which then
+    can't be overwritten. Using find_spec() locates the package without
+    executing its __init__.py.
+
+    Handles both regular packages (with __init__.py) and namespace packages.
+    """
+    try:
+        spec = importlib.util.find_spec(package_name)
+        if spec:
+            # Regular package: spec.origin points to __init__.py
+            if spec.origin:
+                return os.path.dirname(spec.origin)
+            # Namespace package: use submodule_search_locations
+            if spec.submodule_search_locations:
+                locations = list(spec.submodule_search_locations)
+                if locations:
+                    return locations[0]
+    except (ImportError, ModuleNotFoundError):
+        pass
+    return None
diff --git a/src/scope/core/patches/cudnn.py b/src/scope/core/patches/cudnn.py
index 3d0ba627..fcbe55da 100644
--- a/src/scope/core/patches/cudnn.py
+++ b/src/scope/core/patches/cudnn.py
@@ -12,35 +12,11 @@
 """
 
 import glob
-import importlib.util
 import os
 import shutil
 import sys
 
-
-def _find_package_path(package_name: str) -> str | None:
-    """Find a package's install path WITHOUT importing it.
-
-    This is critical for torch - importing it loads cuDNN DLLs which then
-    can't be overwritten. Using find_spec() locates the package without
-    executing its __init__.py.
-
-    Handles both regular packages (with __init__.py) and namespace packages.
-    """
-    try:
-        spec = importlib.util.find_spec(package_name)
-        if spec:
-            # Regular package: spec.origin points to __init__.py
-            if spec.origin:
-                return os.path.dirname(spec.origin)
-            # Namespace package: use submodule_search_locations
-            if spec.submodule_search_locations:
-                locations = list(spec.submodule_search_locations)
-                if locations:
-                    return locations[0]
-    except (ImportError, ModuleNotFoundError):
-        pass
-    return None
+from ._utils import find_package_path
 
 
 def patch_torch_cudnn(silent: bool = False):
@@ -61,8 +37,8 @@ def patch_torch_cudnn(silent: bool = False):
         return
 
     # Find package paths WITHOUT importing them (avoids loading/locking DLLs)
-    cudnn_path = _find_package_path("nvidia.cudnn")
-    torch_path = _find_package_path("torch")
+    cudnn_path = find_package_path("nvidia.cudnn")
+    torch_path = find_package_path("torch")
 
     if not cudnn_path:
         if not silent:
diff --git a/src/scope/core/patches/static_cuda_launcher.py b/src/scope/core/patches/static_cuda_launcher.py
new file mode 100644
index 00000000..4f067034
--- /dev/null
+++ b/src/scope/core/patches/static_cuda_launcher.py
@@ -0,0 +1,134 @@
+"""
+Binary patch for torch_python.dll to fix StaticCudaLauncher overflow.
+
+On Windows, torch.compile with reduce-overhead mode can cause an OverflowError
+when CUDA stream values exceed the range of a signed long integer. The fix
+changes a format specifier from 'l' (signed long) to 'K' (unsigned long long).
+
+Fixes: https://github.com/pytorch/pytorch/issues/162430
+Commit: https://github.com/pytorch/pytorch/commit/7d1bcd9aea8f48733ea46d496e945b7f2592a585
+
+This can be removed when upgrading to a PyTorch version with the fix (2.10.0+).
+"""
+
+import os
+import sys
+import tempfile
+
+from ._utils import find_package_path
+
+# Byte sequences for detection and patching
+UNPATCHED_BYTES = b"KiiiiisOl"
+PATCHED_BYTES = b"KiiiiisOK"
+
+
+def patch_torch_static_cuda_launcher(silent: bool = False):
+    """Binary patch torch_python.dll to fix StaticCudaLauncher overflow.
+
+    Searches for the unpatched byte sequence and replaces it with the fixed version.
+    Idempotent: skips if already patched.
+
+    IMPORTANT: This function does NOT import torch, so it can safely
+    modify the DLL before it is loaded.
+
+    Args:
+        silent: If True, suppress all output (for use at Python startup).
+    """
+    if sys.platform != "win32":
+        if not silent:
+            print("Not on Windows, skipping static_cuda_launcher patch")
+        return
+
+    # Find torch package path WITHOUT importing it
+    torch_path = find_package_path("torch")
+
+    if not torch_path:
+        if not silent:
+            print("torch package not found")
+        return
+
+    dll_path = os.path.join(torch_path, "lib", "torch_python.dll")
+
+    if not os.path.isfile(dll_path):
+        if not silent:
+            print(f"torch_python.dll not found: {dll_path}")
+        return
+
+    # Read the DLL contents
+    try:
+        with open(dll_path, "rb") as f:
+            content = f.read()
+    except PermissionError:
+        if not silent:
+            print(f"Permission denied reading: {dll_path}")
+            print("Close any Python/torch processes and retry.")
+        return
+
+    # Check if already patched
+    if PATCHED_BYTES in content:
+        if not silent:
+            print("torch_python.dll: already patched (StaticCudaLauncher fix)")
+        return
+
+    # Check if patch is needed
+    if UNPATCHED_BYTES not in content:
+        if not silent:
+            print("torch_python.dll: byte sequence not found (different version?)")
+        return
+
+    # Count occurrences to ensure we only patch once
+    count = content.count(UNPATCHED_BYTES)
+    if count != 1:
+        if not silent:
+            print(
+                f"torch_python.dll: found {count} occurrences of target sequence, expected 1"
+            )
+        return
+
+    if not silent:
+        print(f"Patching torch_python.dll: {dll_path}")
+
+    # Apply the patch
+    patched_content = content.replace(UNPATCHED_BYTES, PATCHED_BYTES, 1)
+
+    # Write to a temp file first, then swap it in place of the original
+    try:
+        # Create temp file in same directory for same-filesystem rename
+        dll_dir = os.path.dirname(dll_path)
+        fd, temp_path = tempfile.mkstemp(dir=dll_dir, suffix=".dll.tmp")
+        try:
+            os.write(fd, patched_content)
+        finally:
+            os.close(fd)
+
+        # Make original writable if needed
+        if os.path.exists(dll_path):
+            os.chmod(dll_path, 0o666)
+
+        # os.rename on Windows refuses to overwrite, so remove the original first
+        if os.path.exists(dll_path):
+            os.remove(dll_path)
+        os.rename(temp_path, dll_path)
+
+        if not silent:
+            print("Done. Restart Python to use patched torch_python.dll.")
+
+    except PermissionError:
+        if not silent:
+            print("Permission denied. Close any Python/torch processes and retry.")
+            print(f"Or manually patch: {dll_path}")
+        # Clean up temp file if it exists
+        if "temp_path" in locals() and os.path.exists(temp_path):
+            try:
+                os.remove(temp_path)
+            except OSError:
+                pass
+
+
+def main():
+    """Entry point for manual patching."""
+    patch_torch_static_cuda_launcher(silent=False)
+
+
+if __name__ == "__main__":
+    main()
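
Note on the byte sequences: "KiiiiisOl" appears to be the PyArg_ParseTuple format
string of the launcher binding, and the final format code is what overflows - 'l'
parses a C signed long, which is only 32 bits on Windows, while 'K' parses an
unsigned long long. The struct module follows the same width rules, so the failure
mode can be sketched without torch (the stream value below is made up for
illustration):

    import struct

    stream = 0x7FF6_1234_5678  # hypothetical 64-bit CUDA stream handle

    try:
        struct.pack("<l", stream)  # 4-byte signed long, like 'l' on Windows
    except struct.error as exc:
        print(f"signed long overflow: {exc}")  # argument out of range

    print(struct.pack("<Q", stream).hex())  # 8-byte unsigned, like 'K' - fits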
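
For reviewers unfamiliar with find_spec(): the new _utils.find_package_path() leans
on the fact that locating a top-level package does not execute its __init__.py, so
torch's DLLs are never loaded or locked. A minimal sketch of that property, using a
stdlib package as a stand-in for torch:

    import importlib.util
    import os
    import sys

    spec = importlib.util.find_spec("email")  # stand-in for "torch"
    print(os.path.dirname(spec.origin))       # package path, via its __init__.py
    print("email" in sys.modules)             # False in a fresh interpreter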
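
Manual usage sketch, assuming the repo's src layout is installed so that the scope
package is importable (the module can also be run directly via its main() entry
point):

    from scope.core.patches.static_cuda_launcher import (
        patch_torch_static_cuda_launcher,
    )

    # Run before torch is imported in this process: once torch_python.dll
    # is loaded, Windows locks the file and it cannot be replaced.
    patch_torch_static_cuda_launcher(silent=False)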