Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions mlx/backend/cuda/allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <fmt/format.h>

#include <cassert>
#include <fstream>
#include <string>

namespace mlx::core {

Expand All @@ -26,6 +28,29 @@ constexpr int small_block_size = 8;
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;

// Check if running on Windows or Windows Subsystem for Linux
bool is_windows() {
#if defined(_WIN32)
return true;
#elif defined(__linux__)
// WSL kernels contain "microsoft" or "WSL" in /proc/version
static bool is_wsl = []() {
std::ifstream version("/proc/version");
if (version.is_open()) {
std::string line;
std::getline(version, line);
return line.find("microsoft") != std::string::npos ||
line.find("Microsoft") != std::string::npos ||
line.find("WSL") != std::string::npos;
}
return false;
}();
return is_wsl;
#else
return false;
#endif
}

bool supports_managed_memory() {
static bool managed_memory = []() {
int device_count = gpu::device_count();
Expand All @@ -34,13 +59,11 @@ bool supports_managed_memory() {
if (!d.managed_memory()) {
return false;
}
#if defined(_WIN32)
// Empirically on Windows if there is no concurrentManagedAccess the
// managed memory also does not work.
if (!d.concurrent_managed_access()) {
// Empirically on Windows (and WSL) if there is no concurrentManagedAccess
// the managed memory also does not work.
if (is_windows() && !d.concurrent_managed_access()) {
return false;
}
#endif
}
return true;
}();
Expand Down
Loading