Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ add_subdirectory(test)
add_subdirectory(tools/triton-shared-opt)

# The TritonShared plugin is only declared when TRITON_SHARED_BUILD_CPU_BACKEND
# evaluates to true.
if(TRITON_SHARED_BUILD_CPU_BACKEND)
  # Declare the plugin exactly once: a second add_triton_plugin() call for the
  # same target name would collide with the first. The plugin links against
  # TritonSharedAnalysis and TritonTilingExtIR only.
  add_triton_plugin(TritonShared
    "${CMAKE_CURRENT_SOURCE_DIR}/triton_shared.cc"
    LINK_LIBS
      TritonSharedAnalysis
      TritonTilingExtIR
  )

  # Python3::Module and pybind11::headers are needed to build the plugin itself,
  # so they are linked PRIVATE and not propagated to consumers.
  target_link_libraries(TritonShared
    PRIVATE
      Python3::Module
      pybind11::headers
  )
endif()

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ As part of the conversion process, there are three important analyses:

### Conversion strategy

We introduce the `TritonToLinalg` pass that converts the `triton` dialect to the `linalg` dialect on *tensors*. This means the resulting IR is fully compatible with `linalg` tiling and fusion transformation passes. As mentioned in the `Pointer analysis`'s description, we do however have to deal with memref instructions at the load and store boundaries and have to convert them to tensors using `bufferization.to_tensor`. Here's a simple example of what the IR looks like:
We introduce the `TritonToLinalgExperimental` pass that converts the `triton` dialect to the `linalg` dialect on *tensors*. This means the resulting IR is fully compatible with `linalg` tiling and fusion transformation passes. As mentioned in the description of `Pointer analysis`, we do, however, have to deal with memref instructions at the load and store boundaries and convert them to tensors using `bufferization.to_tensor`. Here's a simple example of what the IR looks like:

```mlir
tt.func @kernel(%afloat : !tt.ptr<bf16>, %res : !tt.ptr<bf16>) {
Expand Down
9 changes: 6 additions & 3 deletions backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def _ttir_to_ttsharedir(mod):
subprocess_args.insert(2, "--add-llvm-debug-info")

subprocess.check_call(subprocess_args)
_dump_ir_if_needed([dst_path])
return Path(dst_path).read_text()


Expand Down Expand Up @@ -113,14 +114,15 @@ def _ttsharedir_to_llir(ttsharedir: str):
"--mlir-print-debuginfo",
"-o",
llmlir_path])
_dump_ir_if_needed([llmlir_path])

# LLVM-MLIR to LLVM-IR
mlir_translate_path = _get_llvm_bin_path("mlir-translate")
subprocess.check_call([mlir_translate_path, llmlir_path,
"--mlir-to-llvmir",
"-o",
llir_path])
_dump_ir_if_needed([ttshared_path, llmlir_path, llir_path])
_dump_ir_if_needed([llir_path])
return Path(llir_path).read_text()


Expand Down Expand Up @@ -151,7 +153,7 @@ def _llir_to_bin(llir: str, metadata):
sanitizer_attributes_pass_path = str(next(Path(top_level_triton_path).rglob("libSanitizerAttributes.so"), None))

if not sanitizer_attributes_pass_path:
raise Exception(f"libSanitizerAttributes.so does not exist.")
raise Exception("libSanitizerAttributes.so does not exist.")

subprocess.check_call([opt_path, "-load-pass-plugin", sanitizer_attributes_pass_path,
"-passes=sanitizer-attributes", f"-sanitizer-type={sanitizer_type}", "-S", src_path,
Expand Down Expand Up @@ -194,6 +196,7 @@ class CPUOptions:
allow_fp8e4nv: bool = False
allowed_dot_input_precisions: Tuple[str] = ("ieee", )
sanitize_overflow: bool = True
instrumentation_mode: str = ""

def __post_init__(self):
pass
Expand Down Expand Up @@ -256,7 +259,7 @@ def make_ttir(mod, metadata, options):
passes.common.add_symbol_dce(pm)
passes.ttir.add_loop_unroll(pm)
passes.common.add_cse(pm)
pm.run(mod)
pm.run(mod, 'make_ttir')
return mod

def add_stages(self, stages, options, language):
Expand Down
43 changes: 35 additions & 8 deletions backend/driver.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import hashlib
import tempfile
import sysconfig

import os, subprocess, tempfile, platform
import os
import subprocess
import platform
import importlib.util
import sys

Expand Down Expand Up @@ -63,14 +64,34 @@ def _ty_to_cpp(ty):
"fp64": "double",
}[ty]

def _flatten_signature(sig, output):
# Flatten tuples
if isinstance(sig, tuple):
for x in sig:
_flatten_signature(x, output)
else:
output.append(sig)

def _extracted_type(ty):
if isinstance(ty, tuple):
val = ','.join(map(_extracted_type, ty))
return f"[{val}]"
if ty[0] == '*':
return "PyObject*"
if ty == "constexpr":
return "PyObject*"
return _ty_to_cpp(ty)

def _format_of(ty):
if isinstance(ty, tuple):
val = ''.join(map(_format_of, ty))
return f"({val})"
if ty[0] == '*':
return "O"
if ty == "constexpr":
return "O"
if ty.startswith("tensordesc"):
return "O"
return {
"PyObject*": "O",
"constexpr": "O",
Expand All @@ -85,15 +106,20 @@ def _format_of(ty):
"uint16_t": "H",
"uint32_t": "I",
"uint64_t": "K",
}[ty]
}[_ty_to_cpp(ty)]

def _generate_launcher(constants, signature, kernel_name):
arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
args_format = ''.join([_format_of(_extracted_type(ty)) for ty in signature.values()])
args_format = ''.join([_format_of(ty) for ty in signature.values()])
format = "iiiOOOO" + args_format

flat_signature = []
for sig in signature.values():
_flatten_signature(sig, flat_signature)
signature = {i: s for i, s in enumerate(flat_signature)}
arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''

kernel_arg_decls = ', '.join(_ty_to_cpp(ty) if ty[0] != "*" else f"int64_t, void*" for i, ty in signature.items() if ty != "constexpr")
kernel_arg_decls = ', '.join(_ty_to_cpp(ty) if ty[0] != "*" else "int64_t, void*" for i, ty in signature.items() if ty != "constexpr")
kernel_arg_decls += ', ' if kernel_arg_decls else ''

kernel_parameters = ', '.join(f"static_cast<{_ty_to_cpp(ty)}>(arg{i})" if ty[0] != "*" else f"0, &ptr_arg{i}" for i, ty in signature.items() if ty != "constexpr")
Expand Down Expand Up @@ -327,7 +353,7 @@ def launch(
libomp_path = next(Path(Path(_get_llvm_bin_path("")).parent).rglob("libomp.so"), None)

if not libomp_path:
raise Exception(f"libomp.so does not exist.")
raise Exception("libomp.so does not exist.")

libomp_path = str(libomp_path.parent)

Expand Down Expand Up @@ -364,7 +390,8 @@ def __init__(self, src, metadata):
kernel_placeholder_name = "KERNEL_NAME_PLACEHOLDER"

constants = src.constants if hasattr(src, "constants") else dict()
cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
def cst_key(i):
return src.fn.arg_names.index(i) if isinstance(i, str) else i
constants = {cst_key(key): value for key, value in constants.items()}
signature = {cst_key(key): value for key, value in src.signature.items()}
launcher_src = _generate_launcher(constants, signature, kernel_placeholder_name)
Expand Down
1 change: 0 additions & 1 deletion include/triton-shared/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
add_subdirectory(TritonToLinalg)
add_subdirectory(TritonToLinalgExperimental)
add_subdirectory(TritonToStructured)
add_subdirectory(TritonArithToLinalg)
Expand Down
Loading