From 5d8f4f0c3e8baf06494006fc26d52aa7cc3cad62 Mon Sep 17 00:00:00 2001
From: zhengxuegui
Date: Fri, 7 Jun 2024 16:07:36 +0800
Subject: [PATCH 1/5] add batched bind args

---
 runtime/python/src/module.cc | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/runtime/python/src/module.cc b/runtime/python/src/module.cc
index 799a812f3..55e3ae1d6 100644
--- a/runtime/python/src/module.cc
+++ b/runtime/python/src/module.cc
@@ -256,6 +256,30 @@ PYBIND11_MODULE(MODULE_NAME, m) {
            THROW_ON_FAIL(
                req.Context().BindArg(offset, reinterpret_cast<void *>(ptr)));
          })
+      .def("bind_args",
+           [](ReqeustContextWithSession &req, py::list offset_and_args) {
+             for (auto handle : offset_and_args) {
+               PyObject *obj = handle.ptr();
+               if (!PyTuple_Check(obj) || PyTuple_Size(obj) != 2) {
+                 PyErr_SetString(PyExc_TypeError,
+                                 "expect pair of offset and arg");
+                 return;
+               }
+
+               PyObject *offset = PyTuple_GetItem(obj, 0);
+               PyObject *arg = PyTuple_GetItem(obj, 1);
+               if (!PyLong_Check(offset)) {
+                 PyErr_SetString(PyExc_TypeError, "offset should be integer");
+                 return;
+               }
+               if (!PyLong_Check(arg)) {
+                 PyErr_SetString(PyExc_TypeError, "arg should be integer");
+                 return;
+               }
+               THROW_ON_FAIL(req.Context().BindArg(PyLong_AsSize_t(offset),
+                                                   PyLong_AsVoidPtr(arg)));
+             }
+           })
       .def("get_arg",
            [](ReqeustContextWithSession &req, size_t offset) {
              void *ptr = req.Context().GetArg(offset);
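From Python, the new entry point turns one FFI crossing per argument into a single crossing for the whole argument list. A minimal usage sketch follows; `session` and `req` stand in for a brt session and its request context, and only `bind_arg`, `bind_args`, and `get_input_arg_offsets` come from this patch:

```python
import torch

# Offsets identify the function's arguments; the runtime only needs raw
# device pointers, which torch exposes as integers via data_ptr().
offsets = session.get_input_arg_offsets()
tensors = [torch.randn(4, 4, device="cuda") for _ in offsets]

# Before this patch: one Python -> C++ call per argument.
for offset, t in zip(offsets, tensors):
    req.bind_arg(offset, t.data_ptr())

# With this patch: one call taking (offset, data_ptr) pairs; elements that
# are not 2-tuples of integers are rejected with a TypeError.
req.bind_args([(o, t.data_ptr()) for o, t in zip(offsets, tensors)])
```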
From 1bd7f9201e5b42efebb879ef1a171a083ddcbaad Mon Sep 17 00:00:00 2001
From: zhengxuegui
Date: Fri, 7 Jun 2024 23:03:22 +0800
Subject: [PATCH 2/5] [frontend] reduce brt overhead

---
 .../byteir_backend/compiled_function.py       | 25 +++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py
index a81885699..693ec2f8a 100644
--- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py
@@ -1,6 +1,7 @@
 import dataclasses
 import logging
 from typing import Optional, Any, Callable, Dict, List, Sequence, Tuple, Union
+from brt.utils import brt_dtype_to_torch_dtype
 
 import torch
 
@@ -43,8 +44,15 @@ def __init__(self, module_path_or_session, none_indices):
         self._req = self._session.new_request_context(
             torch.cuda.current_stream()._as_parameter_.value)
 
+        self.output_shape_and_dtype = [
+            (
+                self._session.get_static_shape(offset),
+                brt_dtype_to_torch_dtype(self._session.get_data_type(offset)),
+            )
+            for offset in self._session.get_output_arg_offsets()
+        ]
+
     def __call__(self, *inputs):
-        from brt.utils import brt_dtype_to_torch_dtype
 
         log.debug(f"***** Run function compiled through byteir ******")
 
         results = [
             torch.empty(
-                self._session.get_static_shape(offset),
-                dtype=brt_dtype_to_torch_dtype(
-                    self._session.get_data_type(offset)),
+                shape,
+                dtype=ty,
                 device=device,
-            ) for offset in self._session.get_output_arg_offsets()
+            ) for shape, ty in self.output_shape_and_dtype
         ]
 
+        inputOffsetAndArg = []
+        outputOffsetAndArg = []
         for offset, input in zip(self._session.get_input_arg_offsets(),
                                  new_inputs):
-            self._req.bind_arg(offset, input.data_ptr())
+            inputOffsetAndArg.append((offset, input.data_ptr()))
         for offset, output in zip(self._session.get_output_arg_offsets(),
                                   results):
-            self._req.bind_arg(offset, output.data_ptr())
+            outputOffsetAndArg.append((offset, output.data_ptr()))
+        self._req.bind_args(inputOffsetAndArg)
+        self._req.bind_args(outputOffsetAndArg)
         self._req.finish_io_binding()
         self._req.run()
         self._req.sync()
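The change is two-fold: the shape/dtype queries, which are invariant across calls, move from `__call__` into `__init__`, and the per-tensor `bind_arg` loop becomes two batched `bind_args` calls. The same pattern distilled into a sketch (everything except the session and request methods used in the patch is illustrative):

```python
import torch
from brt.utils import brt_dtype_to_torch_dtype

class CompiledRunner:
    def __init__(self, session, req):
        self._session = session
        self._req = req
        # Invariant metadata: query the session once, not on every call.
        self.output_shape_and_dtype = [
            (session.get_static_shape(o),
             brt_dtype_to_torch_dtype(session.get_data_type(o)))
            for o in session.get_output_arg_offsets()
        ]

    def __call__(self, *inputs):
        device = inputs[0].device
        results = [torch.empty(shape, dtype=ty, device=device)
                   for shape, ty in self.output_shape_and_dtype]
        # One batched FFI call per direction instead of one per tensor.
        self._req.bind_args(
            [(o, t.data_ptr()) for o, t in
             zip(self._session.get_input_arg_offsets(), inputs)])
        self._req.bind_args(
            [(o, t.data_ptr()) for o, t in
             zip(self._session.get_output_arg_offsets(), results)])
        self._req.finish_io_binding()
        self._req.run()
        self._req.sync()
        return results
```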
From 597d6221547a97500b9d4ff670f2e7bda8e59731 Mon Sep 17 00:00:00 2001
From: "huangchenhui.yellow"
Date: Tue, 11 Jun 2024 15:29:30 +0800
Subject: [PATCH 3/5] refine contiguous inputs && host side overhead

Signed-off-by: huangchenhui.yellow
---
 .../byteir_backend/compiled_function.py       | 110 +++++++++++-------
 .../torch_frontend/byteir_backend/compiler.py |   1 +
 2 files changed, 68 insertions(+), 43 deletions(-)

diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py
index 693ec2f8a..a153cbf18 100644
--- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py
@@ -1,4 +1,5 @@
 import dataclasses
+import functools
 import logging
 from typing import Optional, Any, Callable, Dict, List, Sequence, Tuple, Union
 from brt.utils import brt_dtype_to_torch_dtype
@@ -43,66 +44,89 @@ def __init__(self, module_path_or_session, none_indices):
         self._none_indices = none_indices
         self._req = self._session.new_request_context(
             torch.cuda.current_stream()._as_parameter_.value)
-
-        self.output_shape_and_dtype = [
-            (
-                self._session.get_static_shape(offset),
-                brt_dtype_to_torch_dtype(self._session.get_data_type(offset)),
-            )
-            for offset in self._session.get_output_arg_offsets()
-        ]
+        self.input_arg_offsets = self._session.get_input_arg_offsets()
+        self.output_arg_offsets = self._session.get_output_arg_offsets()
+
+        self.output_shape_and_dtype = [(
+            self._session.get_static_shape(offset),
+            brt_dtype_to_torch_dtype(self._session.get_data_type(offset)),
+        ) for offset in self._session.get_output_arg_offsets()]
+
+        self._outs_len = len(self.output_arg_offsets)
+        self.static_shape_and_dtype = [
+            (self._session.get_static_shape(offset),
+             brt_dtype_to_torch_dtype(self._session.get_data_type(offset)))
+            for offset in self.output_arg_offsets
+        ]
+        self.real_outs_index_map = self._get_outputs_index_map(
+            self._outs_len, self._none_indices)
+        self.strited_inputs_index = None
+
+    def _get_outputs_index_map(self, out_lens: int, none_indices: List[int]):
+        res = []
+        none_lens = len(none_indices)
+        none_cnt = 0
+        for idx in range(out_lens + none_lens):
+            if none_cnt < none_lens and idx == none_indices[none_cnt]:
+                none_cnt += 1
+                continue
+            res.append(idx)
+
+        return res
+
+    @functools.lru_cache
+    def get_out_tensors(self, device):
+        outputs_ptr = [None] * self._outs_len
+        results = [None] * (self._outs_len + len(self._none_indices))
+
+        for idx, shape_dty in enumerate(self.static_shape_and_dtype):
+            _out = torch.empty(shape_dty[0], dtype=shape_dty[1], device=device)
+            results[self.real_outs_index_map[idx]] = _out
+            outputs_ptr[idx] = _out.data_ptr()
+
+        return results, outputs_ptr
+
     def __call__(self, *inputs):
 
         log.debug(f"***** Run function compiled through byteir ******")
 
         # FIXME. byteir requires all inputs on device side, move host side tensor to device.
         # Preprocess the strided tensor as byteir does not support yet.
-        new_inputs = []
-
-        for i in range(0, len(inputs)):
-            _t = inputs[i]
-            if not _t.is_cuda:
-                log.warning(f"device error: type={type(_t)}, {_t.device}")
-                _t = _t.to("cuda")
-            new_inputs.append(_t.contiguous())
-
-        device = new_inputs[0].device
-
-        results = [
-            torch.empty(
-                shape,
-                dtype=ty,
-                device=device,
-            ) for shape, ty in self.output_shape_and_dtype
-        ]
+        new_inputs_ptr = [None] * len(inputs)
+
+        if self.strited_inputs_index is None:
+            self.strited_inputs_index = []
+            for i in range(0, len(inputs)):
+                _t = inputs[i]
+                if not _t.is_contiguous():
+                    _t = _t.contiguous()
+                    self.strited_inputs_index.append(i)
+                new_inputs_ptr[i] = _t.data_ptr()
+        else:
+            for i in range(0, len(inputs)):
+                new_inputs_ptr[i] = inputs[i].data_ptr()
+            for i in self.strited_inputs_index:
+                new_inputs_ptr[i] = inputs[i].contiguous().data_ptr()
+
+        device = inputs[0].device
+
+        results, outputs_ptr = self.get_out_tensors(device)
 
         inputOffsetAndArg = []
         outputOffsetAndArg = []
-        for offset, input in zip(self._session.get_input_arg_offsets(),
-                                 new_inputs):
-            inputOffsetAndArg.append((offset, input.data_ptr()))
-        for offset, output in zip(self._session.get_output_arg_offsets(),
-                                  results):
-            outputOffsetAndArg.append((offset, output.data_ptr()))
+        for offset, input_ptr in zip(self.input_arg_offsets, new_inputs_ptr):
+            inputOffsetAndArg.append((offset, input_ptr))
+        for offset, output_ptr in zip(self.output_arg_offsets, outputs_ptr):
+            outputOffsetAndArg.append((offset, output_ptr))
         self._req.bind_args(inputOffsetAndArg)
         self._req.bind_args(outputOffsetAndArg)
         self._req.finish_io_binding()
         self._req.run()
         self._req.sync()
-        # add None results to return values
-        rets = []
-        none_cnt = 0
-        result_cnt = 0
-        for i in range(len(results) + len(self._none_indices)):
-            if none_cnt < len(
-                    self._none_indices) and i == self._none_indices[none_cnt]:
-                rets.append(None)
-                none_cnt += 1
-            else:
-                rets.append(results[result_cnt])
-                result_cnt += 1
+        rets = results
+
         if len(rets) == 1:
             return rets[0]
         return rets
diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiler.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiler.py
index e8ba0283a..537b1c344 100644
--- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiler.py
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiler.py
@@ -48,6 +48,7 @@ def byteir_compiler(
         partition_fn=byteir_partition_fn,
         #partition_fn=min_cut_rematerialization_partition,
         #partition_fn=default_partition,
+        keep_inference_input_mutations=False,
     )
 
     fake_mode = detect_fake_mode(

From 21b3b6c4ca7a58787592b80ad28bb23dc13c9ac5 Mon Sep 17 00:00:00 2001
From: "huangchenhui.yellow"
Date: Wed, 12 Jun 2024 17:50:31 +0800
Subject: [PATCH 4/5] [dynamo] reduce host overhead

Signed-off-by: huangchenhui.yellow
---
 .../byteir_backend/compiled_function.py       | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py
index a153cbf18..ccc6ecb67 100644
--- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py
@@ -113,12 +113,12 @@ def __call__(self, *inputs):
 
         results, outputs_ptr = self.get_out_tensors(device)
 
-        inputOffsetAndArg = []
-        outputOffsetAndArg = []
-        for offset, input_ptr in zip(self.input_arg_offsets, new_inputs_ptr):
-            inputOffsetAndArg.append((offset, input_ptr))
-        for offset, output_ptr in zip(self.output_arg_offsets, outputs_ptr):
-            outputOffsetAndArg.append((offset, output_ptr))
+        inputOffsetAndArg = [None] * len(new_inputs_ptr)
+        outputOffsetAndArg = [None] * len(outputs_ptr)
+        for idx, (offset, input_ptr) in enumerate(zip(self.input_arg_offsets, new_inputs_ptr)):
+            inputOffsetAndArg[idx] = (offset, input_ptr)
+        for idx, (offset, output_ptr) in enumerate(zip(self.output_arg_offsets, outputs_ptr)):
+            outputOffsetAndArg[idx] = (offset, output_ptr)
         self._req.bind_args(inputOffsetAndArg)
         self._req.bind_args(outputOffsetAndArg)
         self._req.finish_io_binding()
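Two host-side moves in patches 3 and 4 are worth unpacking. First, `_get_outputs_index_map` precomputes once where each real output lands among the `None` placeholders, replacing the counting loop that previously ran on every call. A standalone check of the mapping, with illustrative example values:

```python
from typing import List

def get_outputs_index_map(out_lens: int, none_indices: List[int]) -> List[int]:
    # Body as in the patch: walk the combined result slots and skip the
    # positions reserved for None.
    res = []
    none_lens = len(none_indices)
    none_cnt = 0
    for idx in range(out_lens + none_lens):
        if none_cnt < none_lens and idx == none_indices[none_cnt]:
            none_cnt += 1
            continue
        res.append(idx)
    return res

# Three real outputs, None expected at result slots 1 and 3:
# real output i is stored at slot mapping[i].
assert get_outputs_index_map(3, [1, 3]) == [0, 2, 4]
```

Second, `functools.lru_cache` on `get_out_tensors` caches per `(self, device)`, so after the first call the same output tensors, and therefore the same device pointers, are returned on every call; that is the assumption patch 5 documents below. Note also that `lru_cache` on a method holds a strong reference to `self`, which matters only if compiled functions are expected to be garbage collected.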
From 0b93109d0da13642cb79980b8334bb77fbee53a2 Mon Sep 17 00:00:00 2001
From: "huangchenhui.yellow"
Date: Thu, 13 Jun 2024 19:16:09 +0800
Subject: [PATCH 5/5] [dynamo] add some comments

Signed-off-by: huangchenhui.yellow
---
 .../torch_frontend/byteir_backend/compiled_function.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py
index ccc6ecb67..359f450d5 100644
--- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py
@@ -77,6 +77,16 @@ def _get_outputs_index_map(self, out_lens: int, none_indices: List[int]):
 
     @functools.lru_cache
     def get_out_tensors(self, device):
+        """
+        The number of outputs can be large, which makes Torch spend significant time
+        allocating them even with a memory pool. We use a simple cache to reduce this overhead.
+
+        NB: this assumes the subgraph is called exactly once per training iteration,
+        since the cached output tensors are not reusable within an iteration.
+
+        TODO: implement a memory pool, or reduce the amount of torch tensor allocation,
+        e.g. allocate one large tensor and split it into the output tensors.
+        """
         outputs_ptr = [None] * self._outs_len
         results = [None] * (self._outs_len + len(self._none_indices))
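The TODO's idea of one large tensor could look roughly like the following: size every output up front, make a single allocation, and carve per-output views out of it. This is a sketch only; it assumes brt accepts pointers into a shared allocation, and `ALIGN`, `_round_up`, and `alloc_outputs_from_one_buffer` are hypothetical names, not part of the patch:

```python
import math
import torch

ALIGN = 256  # assumed CUDA-friendly alignment; not taken from the patch

def _round_up(n: int, a: int = ALIGN) -> int:
    return (n + a - 1) // a * a

def alloc_outputs_from_one_buffer(shape_and_dtype, device):
    """One torch.empty for all outputs, then per-output dtype/shape views."""
    offsets, total = [], 0
    for shape, ty in shape_and_dtype:
        nbytes = math.prod(shape) * torch.empty((), dtype=ty).element_size()
        offsets.append((total, nbytes))
        # Keep every view aligned so the dtype reinterpretation below is legal.
        total += _round_up(nbytes)
    buf = torch.empty(total, dtype=torch.uint8, device=device)
    return [
        buf[off:off + nbytes].view(ty).view(list(shape))
        for (off, nbytes), (shape, ty) in zip(offsets, shape_and_dtype)
    ]
```

Whether this beats the `lru_cache` approach depends on where the host time actually goes (allocator calls versus Python object churn), so it would need profiling before replacing the cache.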