413 changes: 99 additions & 314 deletions openequivariance/openequivariance/_torch/TensorProduct.py

Large diffs are not rendered by default.

758 changes: 217 additions & 541 deletions openequivariance/openequivariance/_torch/TensorProductConv.py

Large diffs are not rendered by default.

14 changes: 1 addition & 13 deletions openequivariance/openequivariance/_torch/extlib/__init__.py
@@ -51,7 +51,7 @@

extra_cflags = ["-O3"]
generic_sources = ["generic_module.cpp"]
torch_sources = ["libtorch_tp_jit.cpp"]
torch_sources = ["libtorch_tp_jit.cpp", "json11/json11.cpp"]

include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"])
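
These flags and source lists feed PyTorch's JIT C++ extension builder. A minimal sketch of how such a configuration is typically passed to `torch.utils.cpp_extension.load` — the module name and paths below are illustrative, not the package's actual build call:

```python
# Hypothetical sketch of a JIT build call using the flags above; the real
# extlib build logic may differ (e.g. it also compiles libtorch_tp_jit.cpp).
from torch.utils import cpp_extension

generic_module = cpp_extension.load(
    name="generic_module",                 # import name of the resulting .so
    sources=["generic_module.cpp"],        # C++ translation units to compile
    extra_cflags=["-O3"],                  # mirrors extra_cflags above
    extra_include_paths=["util"],          # mirrors include_dirs above
    extra_ldflags=["-Wl,--no-as-needed"],  # mirrors extra_link_args above
    verbose=True,
)
```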

@@ -149,22 +149,13 @@ def torch_ext_so_path():

if BUILT_EXTENSION:
from generic_module import (
JITTPImpl,
JITConvImpl,
GroupMM_F32,
GroupMM_F64,
DeviceProp,
DeviceBuffer,
GPUTimer,
)
else:

def JITTPImpl(*args, **kwargs):
_raise_import_error_helper("JITTPImpl")

def JITConvImpl(*args, **kwargs):
_raise_import_error_helper("JITConvImpl")

def GroupMM_F32(*args, **kwargs):
_raise_import_error_helper("GroupMM_F32")

@@ -174,8 +165,5 @@ def GroupMM_F64(*args, **kwargs):
def DeviceProp(*args, **kwargs):
_raise_import_error_helper("DeviceProp")

def DeviceBuffer(*args, **kwargs):
_raise_import_error_helper("DeviceBuffer")

def GPUTimer(*args, **kwargs):
_raise_import_error_helper("GPUTimer")
9 changes: 9 additions & 0 deletions openequivariance/openequivariance/_torch/utils.py
@@ -1,4 +1,5 @@
import torch
import numpy as np
from types import MappingProxyType
from openequivariance.core.utils import DTypeEnum

@@ -66,3 +67,11 @@ def reorder_torch(schedule, weights_in, direction, has_batch_dim):
DTypeEnum.UINT8: torch.uint8,
}
)


def string_to_tensor(text: str) -> torch.Tensor:
bytes_data = text.encode("utf-8")
np_bytes = np.frombuffer(bytes_data, dtype=np.uint8)
result = torch.tensor(np_bytes, device="cpu")
result.requires_grad = False
return result
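
A quick round-trip sketch for `string_to_tensor`; the decode step and the import path (inferred from the repository layout) are illustrative, not part of the patch:

```python
import torch
from openequivariance._torch.utils import string_to_tensor  # path inferred from file layout

t = string_to_tensor('{"kernel": "..."}')
assert t.dtype == torch.uint8 and t.device.type == "cpu" and not t.requires_grad

# Recover the original string from the uint8 byte tensor.
text = bytes(t.tolist()).decode("utf-8")
assert text == '{"kernel": "..."}'
```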
@@ -8,6 +8,7 @@
import openequivariance as oeq
from openequivariance.benchmark.logging_utils import getLogger
from openequivariance.core.ConvolutionBase import CoordGraph
from openequivariance.benchmark.benchmark_utils import NpEncoder

logger = getLogger()

@@ -145,7 +146,7 @@ def run(
f"{output_folder}/{self.exp_count}_{impl.name()}_{graph.name}.json"
)
with open(fname, "w") as f:
json.dump(result, f, indent=2)
json.dump(result, f, indent=2, cls=NpEncoder)
self.exp_count += 1

logger.info(f"Finished {tc_name}, graph {graph.name}")
@@ -21,6 +21,7 @@
benchmark_forward,
benchmark_backward,
benchmark_double_backward,
NpEncoder,
)

logger = getLogger()
@@ -235,10 +236,12 @@ def run(

fname = pathlib.Path(f"{output_folder}/{test_ID}_{impl.name()}.json")

pretty_result = json.dumps(obj=result, indent=2).replace("\\n", "\n")
pretty_result = json.dumps(obj=result, indent=2, cls=NpEncoder).replace(
"\\n", "\n"
)
logger.debug(pretty_result)
with open(fname, "w") as f:
json.dump(result, f, indent=2)
json.dump(result, f, indent=2, cls=NpEncoder)

self.results.append(result)
logger.info(f"Finished Test ID: {test_ID}")
12 changes: 12 additions & 0 deletions openequivariance/openequivariance/benchmark/benchmark_utils.py
@@ -1,3 +1,4 @@
import json
import numpy as np

from openequivariance.benchmark.random_buffer_utils import (
@@ -290,3 +291,14 @@ def benchmark_double_backward(
)

return result


class NpEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, np.floating):
return float(obj)
if isinstance(obj, np.ndarray):
return obj.tolist()
return super(NpEncoder, self).default(obj)
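
A small usage sketch: NumPy integers and arrays are not natively JSON-serializable, so the benchmark suites pass `cls=NpEncoder` to `json.dump`; the values below are made up for illustration.

```python
import json
import numpy as np
from openequivariance.benchmark.benchmark_utils import NpEncoder

result = {
    "batch_size": np.int64(65536),        # np.integer  -> int
    "throughput": np.float64(1.23e12),    # np.floating -> float
    "times_ms": np.array([0.41, 0.39]),   # np.ndarray  -> list
}

# Plain json.dumps would raise TypeError on the np.int64 and np.ndarray values.
print(json.dumps(result, indent=2, cls=NpEncoder))
```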
22 changes: 18 additions & 4 deletions openequivariance/openequivariance/core/LoopUnrollConv.py
@@ -1,14 +1,18 @@
import numpy as np
import json

from openequivariance.core.ConvolutionBase import ConvolutionBase
from openequivariance.core.ComputationSchedule import (
ComputationSchedule,
SMEMCapacityException,
)

from openequivariance.core.utils import dtype_to_enum
from openequivariance.templates.jinja_utils import get_jinja_environment
from openequivariance.core.utils import filter_and_analyze_problem
from openequivariance.core.utils import (
filter_and_analyze_problem,
dtype_to_enum,
hash_str_64,
)


class LoopUnrollConv(ConvolutionBase):
@@ -203,5 +207,15 @@ def generate_double_backward_schedule(warps_per_block):
)
self.jit_kernel = postprocess_kernel(self.jit_kernel)

# with open("scratch.txt", "w") as f:
# f.write(self.jit_kernel)
self.kernel_string = json.dumps(
{
"kernel": self.jit_kernel,
"forward_config": vars(self.forward_schedule.launch_config),
"backward_config": vars(self.backward_schedule.launch_config),
"double_backward_config": vars(
self.double_backward_schedule.launch_config
),
"kernel_prop": self.kernel_prop,
}
)
self.hash = hash_str_64(self.kernel_string)
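
LoopUnrollConv and LoopUnrollTP (below) now bundle the generated kernel source, launch configurations, and kernel properties into one JSON string and derive a 64-bit hash from it, giving each compiled kernel a stable identifier. A hedged sketch of what `hash_str_64` from `core/utils.py` could look like — the actual implementation may differ:

```python
import hashlib

def hash_str_64(text: str) -> int:
    # Illustrative only: one way to map a string to a stable 64-bit integer.
    # The real hash_str_64 in openequivariance.core.utils may use another scheme.
    digest = hashlib.sha256(text.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], byteorder="little", signed=False)

# e.g. kernel_hash = hash_str_64(kernel_string) gives a reproducible kernel ID.
```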
22 changes: 20 additions & 2 deletions openequivariance/openequivariance/core/LoopUnrollTP.py
@@ -1,15 +1,19 @@
import numpy as np
import json

from openequivariance.templates.jinja_utils import get_jinja_environment
from openequivariance.core.ComputationSchedule import ComputationSchedule
from openequivariance.core.TensorProductBase import TensorProductBase
from openequivariance.core.utils import dtype_to_enum
from openequivariance.benchmark.logging_utils import getLogger
from openequivariance.core.utils import dtype_to_enum, hash_str_64

from openequivariance.core.utils import (
filter_and_analyze_problem,
count_cg_non_zero,
)

logger = getLogger()


class LoopUnrollTP(TensorProductBase):
def __init__(self, config, dp, postprocess_kernel, torch_op):
@@ -91,7 +95,7 @@ def generate_double_backward_schedule(warps_per_block):
)
)

self.kernelProp = {
self.kernel_prop = {
"L1_dim": self.L1.dim,
"L2_dim": self.L2.dim,
"L3_dim": self.L3.dim,
@@ -106,6 +110,20 @@
"idx_dtype": 0,
}

self.kernel_string = json.dumps(
{
"kernel": self.jit_kernel,
"forward_config": vars(self.forward_schedule.launch_config),
"backward_config": vars(self.backward_schedule.launch_config),
"double_backward_config": vars(
self.double_backward_schedule.launch_config
),
"kernel_prop": self.kernel_prop,
}
)
self.hash = hash_str_64(self.kernel_string)
logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB")

def calculate_flops_forward(self, batch_size: int) -> dict:
if self.is_uvw:
return super().calculate_flops_forward(batch_size)
2 changes: 1 addition & 1 deletion openequivariance/openequivariance/core/utils.py
@@ -7,9 +7,9 @@

import json
import tempfile
import hashlib

from enum import IntEnum
import hashlib


class DTypeEnum(IntEnum):
84 changes: 0 additions & 84 deletions openequivariance/openequivariance/extension/convolution.hpp
@@ -176,88 +176,4 @@ class __attribute__ ((visibility ("default"))) JITConvImpl {
}

~JITConvImpl() = default;

// Integer pointer versions of the functions above

void exec_conv_rawptrs(
uint64_t L1_in,
uint64_t L2_in,
uint64_t weights,
uint64_t L3_out,
uint64_t rows,
uint64_t cols,
uint64_t nnz,
uint64_t node_count,
uint64_t workspace) {

exec_conv(
reinterpret_cast<void*>(L1_in),
reinterpret_cast<void*>(L2_in),
reinterpret_cast<void*>(weights),
reinterpret_cast<void*>(L3_out),
reinterpret_cast<void*>(rows),
reinterpret_cast<void*>(cols),
nnz,
node_count,
reinterpret_cast<void*>(workspace),
0 // Default Stream
);
}

void backward_rawptrs(
uint64_t L1_in, uint64_t L1_grad,
uint64_t L2_in, uint64_t L2_grad,
uint64_t weight, uint64_t weight_grad,
uint64_t L3_grad,
uint64_t rows, uint64_t cols,
uint64_t nnz, uint64_t node_count,
uint64_t workspace, uint64_t inverse_perm) {

backward(
reinterpret_cast<void*>(L1_in),
reinterpret_cast<void*>(L1_grad),
reinterpret_cast<void*>(L2_in),
reinterpret_cast<void*>(L2_grad),
reinterpret_cast<void*>(weight),
reinterpret_cast<void*>(weight_grad),
reinterpret_cast<void*>(L3_grad),
reinterpret_cast<void*>(rows),
reinterpret_cast<void*>(cols),
nnz,
node_count,
reinterpret_cast<void*>(workspace),
reinterpret_cast<void*>(inverse_perm),
0 // Default Stream
);
}

void double_backward_rawptrs(
uint64_t L1_in, uint64_t L2_in, uint64_t W, uint64_t L3_grad,
uint64_t L1_dgrad, uint64_t L2_dgrad, uint64_t w_dgrad,
uint64_t L1_grad, uint64_t L2_grad, uint64_t W_grad, uint64_t L3_dgrad,
uint64_t rows, uint64_t cols,
uint64_t nnz, uint64_t node_count,
uint64_t wspace, uint64_t transpose_perm) {

double_backward(
reinterpret_cast<void*>(L1_in),
reinterpret_cast<void*>(L2_in),
reinterpret_cast<void*>(W),
reinterpret_cast<void*>(L3_grad),
reinterpret_cast<void*>(L1_dgrad),
reinterpret_cast<void*>(L2_dgrad),
reinterpret_cast<void*>(w_dgrad),
reinterpret_cast<void*>(L1_grad),
reinterpret_cast<void*>(L2_grad),
reinterpret_cast<void*>(W_grad),
reinterpret_cast<void*>(L3_dgrad),
reinterpret_cast<void*>(rows),
reinterpret_cast<void*>(cols),
nnz,
node_count,
reinterpret_cast<void*>(wspace),
reinterpret_cast<void*>(transpose_perm),
0
);
}
};
27 changes: 0 additions & 27 deletions openequivariance/openequivariance/extension/generic_module.cpp
@@ -24,34 +24,13 @@
using GroupMM = GroupMMHIP<T>;
#endif

#include "buffer.hpp"
#include "tensorproducts.hpp"
#include "convolution.hpp"

using namespace std;
namespace py=pybind11;

PYBIND11_MODULE(generic_module, m) {
//=========== Batch tensor products =========
py::class_<JITTPImpl<JITKernel>>(m, "JITTPImpl")
.def(py::init< std::string,
std::unordered_map<string, int64_t>,
std::unordered_map<string, int64_t>,
std::unordered_map<string, int64_t>,
std::unordered_map<string, int64_t>>())
.def("exec_tensor_product_rawptr", &JITTPImpl<JITKernel>::exec_tensor_product_device_rawptrs)
.def("backward_rawptr", &JITTPImpl<JITKernel>::backward_device_rawptrs);

py::class_<JITConvImpl<JITKernel>>(m, "JITConvImpl")
.def(py::init< std::string,
std::unordered_map<string, int64_t>,
std::unordered_map<string, int64_t>,
std::unordered_map<string, int64_t>,
std::unordered_map<string, int64_t>>())
.def("exec_conv_rawptrs", &JITConvImpl<JITKernel>::exec_conv_rawptrs)
.def("backward_rawptrs", &JITConvImpl<JITKernel>::backward_rawptrs)
.def("double_backward_rawptrs", &JITConvImpl<JITKernel>::double_backward_rawptrs);

py::class_<GroupMM<float>>(m, "GroupMM_F32")
.def(py::init<int, int>())
.def("group_gemm", &GroupMM<float>::group_gemm_intptr);
@@ -68,12 +47,6 @@ PYBIND11_MODULE(generic_module, m) {
.def_readonly("multiprocessorCount", &DeviceProp::multiprocessorCount)
.def_readonly("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock);

py::class_<PyDeviceBuffer<GPU_Allocator>>(m, "DeviceBuffer")
.def(py::init<uint64_t>())
.def(py::init<py::buffer>())
.def("copy_to_host", &PyDeviceBuffer<GPU_Allocator>::copy_to_host)
.def("data_ptr", &PyDeviceBuffer<GPU_Allocator>::data_ptr);

py::class_<GPUTimer>(m, "GPUTimer")
.def(py::init<>())
.def("start", &GPUTimer::start)