diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py
index a81885699..359f450d5 100644
--- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiled_function.py
@@ -1,6 +1,8 @@
 import dataclasses
+import functools
 import logging
 from typing import Optional, Any, Callable, Dict, List, Sequence, Tuple, Union
+from brt.utils import brt_dtype_to_torch_dtype
 
 import torch
 
@@ -42,56 +44,99 @@ def __init__(self, module_path_or_session, none_indices):
         self._none_indices = none_indices
         self._req = self._session.new_request_context(
             torch.cuda.current_stream()._as_parameter_.value)
+        self.input_arg_offsets = self._session.get_input_arg_offsets()
+        self.output_arg_offsets = self._session.get_output_arg_offsets()
+
+        self._outs_len = len(self.output_arg_offsets)
+        self.static_shape_and_dtype = [
+            (self._session.get_static_shape(offset),
+             brt_dtype_to_torch_dtype(self._session.get_data_type(offset)))
+            for offset in self.output_arg_offsets
+        ]
+
+        self.real_outs_index_map = self._get_outputs_index_map(
+            self._outs_len, self._none_indices)
+        # Indices of inputs that arrived non-contiguous on the first call;
+        # later calls assume the same inputs need a contiguous copy.
+        self.strided_inputs_index = None
+
+    def _get_outputs_index_map(self, out_lens: int, none_indices: List[int]):
+        # Position of each real output in the final result list, skipping
+        # the slots reserved for None results.
+        res = []
+        none_lens = len(none_indices)
+        none_cnt = 0
+        for idx in range(out_lens + none_lens):
+            if none_cnt < none_lens and idx == none_indices[none_cnt]:
+                none_cnt += 1
+                continue
+            res.append(idx)
+
+        return res
+
+    @functools.lru_cache
+    def get_out_tensors(self, device):
+        """
+        The number of outputs can be large, so allocating them with Torch on
+        every call is slow even with a memory pool. A simple per-device cache
+        removes this overhead.
+
+        NB: this assumes the subgraph is called exactly once per training
+        iteration, since the cached output tensors are not reusable within an
+        iteration.
+
+        TODO: implement a memory pool, or reduce the number of tensor
+        allocations, e.g. allocate one large tensor and split it into the
+        output tensors.
+        """
+        outputs_ptr = [None] * self._outs_len
+        results = [None] * (self._outs_len + len(self._none_indices))
+
+        for idx, shape_dty in enumerate(self.static_shape_and_dtype):
+            _out = torch.empty(shape_dty[0], dtype=shape_dty[1], device=device)
+            results[self.real_outs_index_map[idx]] = _out
+            outputs_ptr[idx] = _out.data_ptr()
+
+        return results, outputs_ptr
 
     def __call__(self, *inputs):
-        from brt.utils import brt_dtype_to_torch_dtype
 
         log.debug("***** Run function compiled through byteir ******")
 
         # FIXME. byteir requires all inputs on device side, move host side tensor to device.
         # Preprocess the strided tensor as byteir does not support yet.
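The two new pieces of logic above, the None-aware output index map and the `functools.lru_cache` on `get_out_tensors`, are easiest to see in isolation. Below is a minimal, self-contained sketch with illustrative names (`CachedOutputs` is not part of the codebase, and CPU tensors are used so it runs anywhere):

```python
import functools
from typing import List, Tuple

import torch

class CachedOutputs:
    def __init__(self, shapes_and_dtypes: List[Tuple[tuple, torch.dtype]],
                 none_indices: List[int]):
        self._shapes_and_dtypes = shapes_and_dtypes
        self._none_indices = none_indices
        self._index_map = self._get_outputs_index_map(
            len(shapes_and_dtypes), none_indices)

    @staticmethod
    def _get_outputs_index_map(out_lens: int, none_indices: List[int]) -> List[int]:
        # Position of each real output in the final result list, skipping
        # the slots reserved for None results.
        res, none_cnt = [], 0
        for idx in range(out_lens + len(none_indices)):
            if none_cnt < len(none_indices) and idx == none_indices[none_cnt]:
                none_cnt += 1
                continue
            res.append(idx)
        return res

    @functools.lru_cache  # keyed on (self, device): same tensors every call
    def get_out_tensors(self, device):
        results = [None] * (len(self._shapes_and_dtypes) + len(self._none_indices))
        ptrs = []
        for i, (shape, dtype) in enumerate(self._shapes_and_dtypes):
            t = torch.empty(shape, dtype=dtype, device=device)
            results[self._index_map[i]] = t
            ptrs.append(t.data_ptr())
        return results, ptrs

outs = CachedOutputs([((2,), torch.float32), ((3,), torch.float32)],
                     none_indices=[1])
results, _ = outs.get_out_tensors(torch.device("cpu"))
assert results[1] is None                # None slot preserved at index 1
assert tuple(results[0].shape) == (2,)   # real outputs land at 0 and 2
assert results is outs.get_out_tensors(torch.device("cpu"))[0]  # cache hit
```

Because `lru_cache` keys on `(self, device)`, every call with the same device returns the same list of tensors, which is exactly why the docstring warns that each subgraph must run only once per iteration. Note also that `lru_cache` on a method keeps `self` alive for the lifetime of the cache.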
-        new_inputs = []
-
-        for i in range(0, len(inputs)):
-            _t = inputs[i]
-            if not _t.is_cuda:
-                log.warning(f"device error: type={type(_t)}, {_t.device}")
-                _t = _t.to("cuda")
-            new_inputs.append(_t.contiguous())
-
-        device = new_inputs[0].device
-
-        results = [
-            torch.empty(
-                self._session.get_static_shape(offset),
-                dtype=brt_dtype_to_torch_dtype(
-                    self._session.get_data_type(offset)),
-                device=device,
-            ) for offset in self._session.get_output_arg_offsets()
-        ]
-
-        for offset, input in zip(self._session.get_input_arg_offsets(),
-                                 new_inputs):
-            self._req.bind_arg(offset, input.data_ptr())
-        for offset, output in zip(self._session.get_output_arg_offsets(),
-                                  results):
-            self._req.bind_arg(offset, output.data_ptr())
+        new_inputs_ptr = [None] * len(inputs)
+        # Keep contiguous copies alive until the request has been synced.
+        contiguous_copies = []
+
+        if self.strided_inputs_index is None:
+            self.strided_inputs_index = []
+            for i, _t in enumerate(inputs):
+                if not _t.is_contiguous():
+                    _t = _t.contiguous()
+                    contiguous_copies.append(_t)
+                    self.strided_inputs_index.append(i)
+                new_inputs_ptr[i] = _t.data_ptr()
+        else:
+            for i, _t in enumerate(inputs):
+                new_inputs_ptr[i] = _t.data_ptr()
+            for i in self.strided_inputs_index:
+                _t = inputs[i].contiguous()
+                contiguous_copies.append(_t)
+                new_inputs_ptr[i] = _t.data_ptr()
+
+        device = inputs[0].device
+
+        results, outputs_ptr = self.get_out_tensors(device)
+
+        self._req.bind_args(list(zip(self.input_arg_offsets, new_inputs_ptr)))
+        self._req.bind_args(list(zip(self.output_arg_offsets, outputs_ptr)))
         self._req.finish_io_binding()
         self._req.run()
         self._req.sync()
 
-        # add None results to return values
-        rets = []
-        none_cnt = 0
-        result_cnt = 0
-        for i in range(len(results) + len(self._none_indices)):
-            if none_cnt < len(
-                    self._none_indices) and i == self._none_indices[none_cnt]:
-                rets.append(None)
-                none_cnt += 1
-            else:
-                rets.append(results[result_cnt])
-                result_cnt += 1
+        # None slots were already filled in by get_out_tensors().
+        rets = results
 
         if len(rets) == 1:
             return rets[0]
         return rets
diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiler.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiler.py
index e8ba0283a..537b1c344 100644
--- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiler.py
+++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/byteir_backend/compiler.py
@@ -48,6 +48,7 @@ def byteir_compiler(
         partition_fn=byteir_partition_fn,
         #partition_fn=min_cut_rematerialization_partition,
         #partition_fn=default_partition,
+        keep_inference_input_mutations=False,
     )
 
     fake_mode = detect_fake_mode(
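`keep_inference_input_mutations=False` is a standard AOT Autograd knob: with it disabled, input mutations are functionalized and applied outside the compiled graph rather than kept inside it for the backend to handle. The hunk only shows the keyword being added to an existing call; assuming the enclosing call is torch's `aot_module_simplified` (the kwarg exists there in recent torch releases, but the actual call site is not shown in this diff), a toy backend would pass it like this:

```python
import torch
from torch._functorch.aot_autograd import aot_module_simplified

def toy_backend(gm: torch.fx.GraphModule, example_inputs):
    def fw_compiler(gm, example_inputs):
        # Stand-in for the byteir compile step.
        return gm.forward

    return aot_module_simplified(
        gm,
        example_inputs,
        fw_compiler=fw_compiler,
        # With False, input mutations are applied outside the compiled
        # graph instead of being kept in it.
        keep_inference_input_mutations=False,
    )
```

Such a backend would be wired up via `torch.compile(model, backend=toy_backend)`.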
diff --git a/runtime/python/src/module.cc b/runtime/python/src/module.cc
index 799a812f3..55e3ae1d6 100644
--- a/runtime/python/src/module.cc
+++ b/runtime/python/src/module.cc
@@ -256,6 +256,30 @@ PYBIND11_MODULE(MODULE_NAME, m) {
             THROW_ON_FAIL(
                 req.Context().BindArg(offset, reinterpret_cast<void *>(ptr)));
           })
+      .def("bind_args",
+           [](ReqeustContextWithSession &req, py::list offset_and_args) {
+             for (auto handle : offset_and_args) {
+               PyObject *obj = handle.ptr();
+               if (!PyTuple_Check(obj) || PyTuple_Size(obj) != 2) {
+                 // Throwing lets pybind11 raise a proper TypeError, rather
+                 // than returning with only the error indicator set.
+                 throw py::type_error("expect a pair of (offset, arg)");
+               }
+
+               PyObject *offset = PyTuple_GetItem(obj, 0);
+               PyObject *arg = PyTuple_GetItem(obj, 1);
+               if (!PyLong_Check(offset)) {
+                 throw py::type_error("offset should be an integer");
+               }
+               if (!PyLong_Check(arg)) {
+                 throw py::type_error("arg should be an integer");
+               }
+               THROW_ON_FAIL(req.Context().BindArg(PyLong_AsSize_t(offset),
+                                                   PyLong_AsVoidPtr(arg)));
+             }
+           })
       .def("get_arg",
            [](ReqeustContextWithSession &req, size_t offset) {
              void *ptr = req.Context().GetArg(offset);
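On the Python side, the new entry point replaces a per-argument loop of `bind_arg` calls with a single `bind_args` call, trading N Python-to-C++ crossings for one. A hedged usage sketch (`req`, the offsets, and the tensors are illustrative; the real call site is the compiled function's `__call__` above):

```python
import torch

def run_request(req, offsets, tensors):
    # One bind_args call binds every (offset, data_ptr) pair at once,
    # instead of looping over req.bind_arg(offset, ptr) per argument.
    req.bind_args([(off, t.data_ptr()) for off, t in zip(offsets, tensors)])
    req.finish_io_binding()
    req.run()
    req.sync()
```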