Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,8 @@ flydsl/src/*.egg-info/

# Wheels / packaging outputs
dist/

# Profiling artifacts (rocprofv3)
.rocprofv3/
R9700-Workstation-SH/
*_results.db
3 changes: 2 additions & 1 deletion flir/include/flir/FlirRocmDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def GfxArchEnum : I32EnumAttr<"GfxArch",
I32EnumAttrCase<"GFX940", 2, "gfx940">,
I32EnumAttrCase<"GFX941", 3, "gfx941">,
I32EnumAttrCase<"GFX942", 4, "gfx942">,
I32EnumAttrCase<"GFX950", 5, "gfx950">
I32EnumAttrCase<"GFX950", 5, "gfx950">,
I32EnumAttrCase<"GFX1201", 6, "gfx1201">
]> {
let cppNamespace = "::mlir::flir::rocm";
}
Expand Down
101 changes: 93 additions & 8 deletions flir/python_bindings/runtime/FlirRocmRuntimeWrappers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

#include <cassert>
#include <cstdio>
#include <mutex>
#include <unordered_map>

#include "hip/hip_runtime.h"

Expand All @@ -36,9 +38,65 @@
fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \
}(expr)

extern "C" FLIR_EXPORT hipModule_t mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) {
// ---------------------------------------------------------------------------
// Caching layer: avoids calling hipModuleLoadData / hipModuleGetFunction /
// hipStreamCreate on every kernel launch. The MLIR gpu-to-llvm lowering
// emits calls to mgpuModuleLoad + mgpuModuleGetFunction + mgpuStreamCreate
// on *every* invocation of the host wrapper, which adds ~0.8 ms of overhead
// per call.
//
// Cache key: we use a content hash of the blob data. The JIT may reuse
// the same address for different blobs (after freeing a previous module),
// so pointer-based keys are unsafe. We compute a fast hash of the blob
// content and use (hash, size) as the cache key. For repeated calls from
// the same JIT'd function, the blob content is identical → cache hit.
// ---------------------------------------------------------------------------

static std::mutex g_cache_mutex;

// FNV-1a over the raw blob bytes; cheap, deterministic content fingerprint
// used as the module-cache key component.
static uint64_t hash_blob(const void *data, size_t size) {
  static constexpr uint64_t kFnvOffsetBasis = 14695981039346656037ULL;
  static constexpr uint64_t kFnvPrime = 1099511628211ULL;
  const uint8_t *bytes = static_cast<const uint8_t *>(data);
  uint64_t state = kFnvOffsetBasis;
  for (const uint8_t *end = bytes + size; bytes != end; ++bytes) {
    state = (state ^ *bytes) * kFnvPrime;
  }
  return state;
}

// Cache key for a GPU code blob: content hash plus byte size. Two blobs
// collide only if both the FNV hash and the size match.
struct BlobKey {
  uint64_t content_hash;
  size_t size;
  friend bool operator==(const BlobKey &lhs, const BlobKey &rhs) {
    return lhs.size == rhs.size && lhs.content_hash == rhs.content_hash;
  }
};

// Hash functor for unordered_map<BlobKey, ...>: boost-style hash_combine of
// the two fields.
struct BlobKeyHash {
  size_t operator()(const BlobKey &k) const {
    const size_t seed = std::hash<uint64_t>()(k.content_hash);
    return seed ^ (std::hash<size_t>()(k.size) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
  }
};

static std::unordered_map<BlobKey, hipModule_t, BlobKeyHash> g_module_cache;
static std::unordered_map<hipModule_t, std::unordered_map<std::string, hipFunction_t>> g_func_cache;
static hipStream_t g_cached_stream = nullptr;
static bool g_stream_initialized = false;

// Load (or fetch from cache) the HIP module for a GPU code blob. Keyed by
// (content hash, size) rather than pointer, since the JIT may reuse the same
// address for different blobs. Thread-safe via g_cache_mutex.
extern "C" FLIR_EXPORT hipModule_t mgpuModuleLoad(void *data, size_t gpuBlobSize) {
  std::lock_guard<std::mutex> lock(g_cache_mutex);
  const BlobKey key{hash_blob(data, gpuBlobSize), gpuBlobSize};
  auto cached = g_module_cache.find(key);
  if (cached != g_module_cache.end())
    return cached->second;
  // Cache miss: load the blob and remember the resulting module.
  hipModule_t module = nullptr;
  HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
  g_module_cache.emplace(key, module);
  return module;
}

Expand All @@ -50,13 +108,26 @@ extern "C" FLIR_EXPORT hipModule_t mgpuModuleLoadJIT(void *data, int optLevel) {
}

// Intentional no-op: modules live in g_module_cache and are reused across
// calls, so the per-launch unload emitted by the MLIR lowering must not
// actually release them.
extern "C" FLIR_EXPORT void mgpuModuleUnload(hipModule_t module) {
  static_cast<void>(module);
}

// Resolve a kernel function in `module` by name, caching the result so
// repeated launches skip hipModuleGetFunction. On failure, logs to stderr and
// returns nullptr WITHOUT caching, so a later call can retry.
//
// NOTE: the previous text contained both sides of a diff (duplicated
// signature and duplicated hipModuleGetFunction call), which would not
// compile; this is the intended post-diff version. It also cached nullptr on
// failure, permanently poisoning the cache entry — now only successful
// lookups are cached.
extern "C" FLIR_EXPORT hipFunction_t mgpuModuleGetFunction(hipModule_t module,
                                                           const char *name) {
  std::lock_guard<std::mutex> lock(g_cache_mutex);
  auto &func_map = g_func_cache[module];
  std::string key(name);
  auto it = func_map.find(key);
  if (it != func_map.end()) {
    return it->second;
  }
  hipFunction_t function = nullptr;
  hipError_t err = hipModuleGetFunction(&function, module, name);
  if (err != hipSuccess) {
    fprintf(stderr, "mgpuModuleGetFunction: failed for name='%s' module=%p err=%d\n",
            name, (void*)module, (int)err);
    return function; // do not cache failed lookups
  }
  func_map[key] = function;
  return function;
}

Expand All @@ -74,13 +145,27 @@ extern "C" FLIR_EXPORT void mgpuLaunchKernel(hipFunction_t function, intptr_t gr
}

// Return a process-wide cached stream, creating it on first use. The MLIR
// gpu-to-llvm lowering calls mgpuStreamCreate on every launch; reusing one
// stream avoids per-call hipStreamCreate overhead. Thread-safe via
// g_cache_mutex.
//
// NOTE: the previous text contained both sides of a diff (the old
// create-every-call body followed by the cached body), leaving duplicate
// locals and unreachable code; this is the intended post-diff version.
extern "C" FLIR_EXPORT hipStream_t mgpuStreamCreate() {
  std::lock_guard<std::mutex> lock(g_cache_mutex);
  if (!g_stream_initialized) {
    g_stream_initialized = true;
    HIP_REPORT_IF_ERROR(hipStreamCreate(&g_cached_stream));
  }
  return g_cached_stream;
}

// Override the cached launch stream from Python. Pass
// torch.cuda.current_stream().cuda_stream to share PyTorch's stream so
// kernels serialize correctly with framework work.
extern "C" FLIR_EXPORT void mgpuSetStream(hipStream_t stream) {
  std::lock_guard<std::mutex> lock(g_cache_mutex);
  // Mark initialized so mgpuStreamCreate won't replace the injected stream.
  g_stream_initialized = true;
  g_cached_stream = stream;
}

// Intentional no-op: the cached stream is reused across calls, and the MLIR
// lowering emits a destroy after every launch_func — actually destroying it
// would tear down the shared stream after the first launch.
//
// NOTE: the previous text retained the removed diff line that still called
// hipStreamDestroy, contradicting the no-op comment below it; this is the
// intended post-diff version.
extern "C" FLIR_EXPORT void mgpuStreamDestroy(hipStream_t stream) {
  (void)stream;
}

extern "C" FLIR_EXPORT void mgpuStreamSynchronize(hipStream_t stream) {
Expand Down
60 changes: 47 additions & 13 deletions flydsl/src/flydsl/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
from flydsl.runtime.device import get_rocm_arch

from .executor import default_shared_libs
from .cache import FileCache, cache_enabled, cache_rebuild_requested, default_key_payload, make_cache_key
from .cache import (
FileCache,
cache_enabled,
cache_rebuild_requested,
default_key_payload,
make_cache_key,
)

if TYPE_CHECKING:
from .executor import ExecutionEngineExecutor as Executor
Expand All @@ -31,6 +37,15 @@ class CompileOptions:
backend: Literal["execution_engine"] = "execution_engine"


def _is_wave32_arch(chip: str) -> bool:
"""Return True if the GPU architecture uses wavefront size 32 (RDNA)."""
# RDNA architectures (gfx10xx, gfx11xx, gfx12xx) use wave32 natively.
# CDNA architectures (gfx9xx) use wave64.
return (
chip.startswith("gfx10") or chip.startswith("gfx11") or chip.startswith("gfx12")
)


def _pipeline_fragments(
*,
chip: str,
Expand All @@ -50,6 +65,7 @@ def _pipeline_fragments(
rocdl_bare_ptr_opt = b2s(use_bare_ptr_memref_call_conv)
llvm_bare_host_opt = b2s(use_bare_pointers_for_host)
llvm_bare_kern_opt = b2s(use_bare_pointers_for_kernels)
wave64_opt = b2s(not _is_wave32_arch(chip))
return [
"flir-to-standard",
"trivial-dce",
Expand All @@ -63,13 +79,13 @@ def _pipeline_fragments(
"gpu.module(reconcile-unrealized-casts)",
# Keep this as a formatted string so the chip is visible in dumps and matches
# the non-dump compilation pipeline.
f"rocdl-attach-target{{O=2 abi=600 chip={chip} correct-sqrt=true daz=false fast=false features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64=true}}",
f"rocdl-attach-target{{O=2 abi=600 chip={chip} correct-sqrt=true daz=false fast=false features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64={wave64_opt}}}",
"gpu-to-llvm{intersperse-sizes-for-kernels=false "
+ f"use-bare-pointers-for-host={llvm_bare_host_opt} "
+ f"use-bare-pointers-for-kernels={llvm_bare_kern_opt}"
+ "}",
"reconcile-unrealized-casts",
"gpu-module-to-binary{format=fatbin opts= section= toolkit=}",
f"gpu-module-to-binary{{format=fatbin opts={os.environ.get('FLYDSL_LLC_OPTS', '')} section= toolkit=}}",
]


Expand Down Expand Up @@ -137,7 +153,9 @@ def _dump_ir(stage: str, *, dump_dir: Path, asm: str) -> Path:
return out


def _dump_isa_from_rocdl_module_asm(*, dump_dir: Path, ctx: ir.Context, asm: str, verify: bool) -> Optional[Path]:
def _dump_isa_from_rocdl_module_asm(
*, dump_dir: Path, ctx: ir.Context, asm: str, verify: bool
) -> Optional[Path]:
"""Best-effort dump final ISA/assembly (.s) for the current GPU module.

This is only used for debug dumps. It intentionally does not affect the main
Expand All @@ -151,7 +169,10 @@ def _dump_isa_from_rocdl_module_asm(*, dump_dir: Path, ctx: ir.Context, asm: str
try:
# Parse a fresh clone so we don't mutate the main compilation module.
mod = ir.Module.parse(asm, context=ctx)
pm = PassManager.parse("builtin.module(gpu-module-to-binary{format=isa opts= section= toolkit=})", context=ctx)
pm = PassManager.parse(
"builtin.module(gpu-module-to-binary{format=isa opts= section= toolkit=})", # ISA dump always without custom opts
context=ctx,
)
pm.enable_verifier(bool(verify))
pm.run(mod.operation)
isa_bytes = get_compile_object_bytes(mod)
Expand Down Expand Up @@ -239,8 +260,10 @@ def _apply_waves_per_eu_hint(mlir_module, waves_per_eu: int):
# Best-effort: if attribute injection fails, log and continue
# This prevents breaking existing functionality
import warnings

warnings.warn(f"Failed to apply waves_per_eu hint: {e}", RuntimeWarning)


def compile(
flir_module_or_ir: Union[object, ir.Module],
*,
Expand Down Expand Up @@ -272,12 +295,16 @@ def compile(
if mlir_module is None:
mlir_module = flir_module_or_ir
if not isinstance(mlir_module, ir.Module):
raise TypeError(f"Expected an MLIR module or flir.lang.MlirModule; got {type(flir_module_or_ir)}")
raise TypeError(
f"Expected an MLIR module or flir.lang.MlirModule; got {type(flir_module_or_ir)}"
)

ctx = mlir_module.context
ensure_flir_python_extensions(ctx)

compile_only = _env_truthy("FLYDSL_COMPILE_ONLY", "0") or _env_truthy("COMPILE_ONLY", "0")
compile_only = _env_truthy("FLYDSL_COMPILE_ONLY", "0") or _env_truthy(
"COMPILE_ONLY", "0"
)
dump_enabled = _env_truthy("FLIR_DUMP_IR", "0")
dump_root_dir = Path(os.environ.get("FLIR_DUMP_DIR", "my_ir_dumps")).resolve()
dump_prefix_base = (
Expand Down Expand Up @@ -311,7 +338,9 @@ def compile(
module = mlir_module

# Allow overriding target arch via env var (useful for cross-compilation or FLYDSL_COMPILE_ONLY mode)
chip = (os.environ.get("FLYDSL_TARGET_ARCH") or os.environ.get("ARCH") or "").strip() or get_rocm_arch()
chip = (
os.environ.get("FLYDSL_TARGET_ARCH") or os.environ.get("ARCH") or ""
).strip() or get_rocm_arch()

pipeline = _build_pipeline_str(
chip=chip,
Expand Down Expand Up @@ -355,12 +384,17 @@ def compile(
print(f"[flir.compile] cache hit key={cache_key}")
if compile_only:
if dump_enabled or print_final_module:
print(f"[flir.compile] FLYDSL_COMPILE_ONLY=1, skipping executor creation (arch={chip})")
print(
f"[flir.compile] FLYDSL_COMPILE_ONLY=1, skipping executor creation (arch={chip})"
)
return None
from .executor import ExecutionEngineExecutor as Executor

if shared_libs is None:
shared_libs = default_shared_libs().as_list()
return Executor(cached_mod, opt_level=opt_level, shared_libs=shared_libs)
return Executor(
cached_mod, opt_level=opt_level, shared_libs=shared_libs
)
except Exception:
# Treat cache parse failures as misses.
pass
Expand Down Expand Up @@ -433,13 +467,13 @@ def compile(
# In compile-only mode, skip executor creation and return None
if compile_only:
if dump_enabled or print_final_module:
print(f"[flir.compile] FLYDSL_COMPILE_ONLY=1, skipping executor creation (arch={chip})")
print(
f"[flir.compile] FLYDSL_COMPILE_ONLY=1, skipping executor creation (arch={chip})"
)
return None

from .executor import ExecutionEngineExecutor as Executor

if shared_libs is None:
shared_libs = default_shared_libs().as_list()
return Executor(module, opt_level=opt_level, shared_libs=shared_libs)


Loading
Loading