Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def _ttir_to_ttsharedir(mod):
subprocess_args.insert(2, "--add-llvm-debug-info")

subprocess.check_call(subprocess_args)
_dump_ir_if_needed([dst_path])
return Path(dst_path).read_text()


Expand Down Expand Up @@ -113,14 +114,15 @@ def _ttsharedir_to_llir(ttsharedir: str):
"--mlir-print-debuginfo",
"-o",
llmlir_path])
_dump_ir_if_needed([llmlir_path])

# LLVM-MLIR to LLVM-IR
mlir_translate_path = _get_llvm_bin_path("mlir-translate")
subprocess.check_call([mlir_translate_path, llmlir_path,
"--mlir-to-llvmir",
"-o",
llir_path])
_dump_ir_if_needed([ttshared_path, llmlir_path, llir_path])
_dump_ir_if_needed([llir_path])
return Path(llir_path).read_text()


Expand Down Expand Up @@ -151,7 +153,7 @@ def _llir_to_bin(llir: str, metadata):
sanitizer_attributes_pass_path = str(next(Path(top_level_triton_path).rglob("libSanitizerAttributes.so"), None))

if not sanitizer_attributes_pass_path:
raise Exception(f"libSanitizerAttributes.so does not exist.")
raise Exception("libSanitizerAttributes.so does not exist.")

subprocess.check_call([opt_path, "-load-pass-plugin", sanitizer_attributes_pass_path,
"-passes=sanitizer-attributes", f"-sanitizer-type={sanitizer_type}", "-S", src_path,
Expand Down Expand Up @@ -194,6 +196,7 @@ class CPUOptions:
allow_fp8e4nv: bool = False
allowed_dot_input_precisions: Tuple[str] = ("ieee", )
sanitize_overflow: bool = True
instrumentation_mode: str = ""

def __post_init__(self):
pass
Expand Down Expand Up @@ -256,7 +259,7 @@ def make_ttir(mod, metadata, options):
passes.common.add_symbol_dce(pm)
passes.ttir.add_loop_unroll(pm)
passes.common.add_cse(pm)
pm.run(mod)
pm.run(mod, 'make_ttir')
return mod

def add_stages(self, stages, options, language):
Expand Down
43 changes: 35 additions & 8 deletions backend/driver.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import hashlib
import tempfile
import sysconfig

import os, subprocess, tempfile, platform
import os
import subprocess
import platform
import importlib.util
import sys

Expand Down Expand Up @@ -63,14 +64,34 @@ def _ty_to_cpp(ty):
"fp64": "double",
}[ty]

def _flatten_signature(sig, output):
# Flatten tuples
if isinstance(sig, tuple):
for x in sig:
_flatten_signature(x, output)
else:
output.append(sig)

def _extracted_type(ty):
if isinstance(ty, tuple):
val = ','.join(map(_extracted_type, ty))
return f"[{val}]"
if ty[0] == '*':
return "PyObject*"
if ty == "constexpr":
return "PyObject*"
return _ty_to_cpp(ty)

def _format_of(ty):
if isinstance(ty, tuple):
val = ''.join(map(_format_of, ty))
return f"({val})"
if ty[0] == '*':
return "O"
if ty == "constexpr":
return "O"
if ty.startswith("tensordesc"):
return "O"
return {
"PyObject*": "O",
"constexpr": "O",
Expand All @@ -85,15 +106,20 @@ def _format_of(ty):
"uint16_t": "H",
"uint32_t": "I",
"uint64_t": "K",
}[ty]
}[_ty_to_cpp(ty)]

def _generate_launcher(constants, signature, kernel_name):
arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
args_format = ''.join([_format_of(_extracted_type(ty)) for ty in signature.values()])
args_format = ''.join([_format_of(ty) for ty in signature.values()])
format = "iiiOOOO" + args_format

flat_signature = []
for sig in signature.values():
_flatten_signature(sig, flat_signature)
signature = {i: s for i, s in enumerate(flat_signature)}
arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''

kernel_arg_decls = ', '.join(_ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if ty != "constexpr")
kernel_arg_decls = ', '.join(_ty_to_cpp(ty) if ty[0] != "*" else "int64_t, void*" for i, ty in signature.items() if ty != "constexpr")
kernel_arg_decls += ', ' if kernel_arg_decls else ''

kernel_parameters = ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" for i, ty in signature.items() if ty != "constexpr")
Expand Down Expand Up @@ -327,7 +353,7 @@ def launch(
libomp_path = next(Path(Path(_get_llvm_bin_path("")).parent).rglob("libomp.so"), None)

if not libomp_path:
raise Exception(f"libomp.so does not exist.")
raise Exception("libomp.so does not exist.")

libomp_path = str(libomp_path.parent)

Expand Down Expand Up @@ -364,7 +390,8 @@ def __init__(self, src, metadata):
kernel_placeholder_name = "KERNEL_NAME_PLACEHOLDER"

constants = src.constants if hasattr(src, "constants") else dict()
cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
def cst_key(i):
return src.fn.arg_names.index(i) if isinstance(i, str) else i
constants = {cst_key(key): value for key, value in constants.items()}
signature = {cst_key(key): value for key, value in src.signature.items()}
launcher_src = _generate_launcher(constants, signature, kernel_placeholder_name)
Expand Down
Loading