diff --git a/merlin/systems/triton/conversions.py b/merlin/systems/triton/conversions.py
index 5faf58099..a10b63ce0 100644
--- a/merlin/systems/triton/conversions.py
+++ b/merlin/systems/triton/conversions.py
@@ -25,6 +25,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import itertools
+from functools import singledispatch
 from typing import Any, Dict, List
 
 import numpy as np
@@ -33,6 +34,7 @@
 import merlin.dtypes as md
 from merlin.core.compat import cudf
 from merlin.core.compat import cupy as cp
+from merlin.core.compat.torch import torch
 from merlin.core.dispatch import build_cudf_list_column, is_list_dtype
 from merlin.dag import Supports
 from merlin.schema import Schema
@@ -135,19 +137,24 @@ def _from_values_offsets(values, offsets, shape):
     return values.reshape(new_shape)
 
 
-def _to_values_offsets(array):
+@singledispatch
+def _to_values_offsets(values):
     """Convert array to values/offsets representation
 
     Parameters
     ----------
-    array : numpy.ndarray or cupy.ndarray
-        Array to convert
+    values : array or tensor
+        Array or tensor to convert
 
     Returns
     -------
     values, offsets
         Tuple of values and offsets
     """
+    raise NotImplementedError(f"_to_values_offsets not implemented for {type(values)}")
+
+
+def _to_values_offsets_array(array):
     num_rows = array.shape[0]
     row_lengths = [array.shape[1]] * num_rows
     offsets = [0] + list(itertools.accumulate(row_lengths))
@@ -157,6 +164,30 @@ def _to_values_offsets(array):
     return values, offsets
 
 
+@_to_values_offsets.register(np.ndarray)
+def _(array):
+    return _to_values_offsets_array(array)
+
+
+if cp:
+
+    @_to_values_offsets.register(cp.ndarray)
+    def _(array):
+        return _to_values_offsets_array(array)
+
+
+if torch:
+
+    @_to_values_offsets.register(torch.Tensor)
+    def _(tensor):
+        num_rows = tensor.shape[0]
+        row_lengths = [tensor.shape[1]] * num_rows
+        offsets = [0] + list(itertools.accumulate(row_lengths))
+        offsets = torch.tensor(offsets, dtype=torch.int32, device=tensor.device)
+        values = tensor.reshape(-1, *tensor.shape[2:])
+        return values, offsets
+
+
 def triton_request_to_tensor_table(request, schema):
     """
     Turns a Triton request into a TensorTable by extracting individual tensors
diff --git a/requirements/test-cpu.txt b/requirements/test-cpu.txt
index e682e3549..871c5c051 100644
--- a/requirements/test-cpu.txt
+++ b/requirements/test-cpu.txt
@@ -1,6 +1,5 @@
 -r test.txt
 
-merlin-models>=0.6.0
 faiss-cpu==1.7.2
 tensorflow<=2.9.0
 treelite==2.4.0
diff --git a/requirements/test-gpu.txt b/requirements/test-gpu.txt
index ef1b403ac..28cd15508 100644
--- a/requirements/test-gpu.txt
+++ b/requirements/test-gpu.txt
@@ -1,4 +1,3 @@
 -r test.txt
 
-tensorflow
 faiss-gpu==1.7.2
diff --git a/requirements/test.txt b/requirements/test.txt
index ea42bcb12..fc43407f2 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -18,5 +18,7 @@ feast==0.31
 xgboost==1.6.2
 implicit==0.6.0
 
+merlin-models[tensorflow,pytorch,transformers]@git+https://github.com/NVIDIA-Merlin/models.git
+
 # TODO: do we need more of these?
 # https://github.com/NVIDIA-Merlin/Merlin/blob/a1cc48fe23c4dfc627423168436f26ef7e028204/ci/dockerfile.ci#L13-L18
diff --git a/tests/integration/t4r/test_pytorch_backend.py b/tests/integration/t4r/test_pytorch_backend.py
index fe16f05a7..bb23bdf51 100644
--- a/tests/integration/t4r/test_pytorch_backend.py
+++ b/tests/integration/t4r/test_pytorch_backend.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 
+import shutil
+
 import pytest
 
 np = pytest.importorskip("numpy")
@@ -30,9 +32,13 @@
 from merlin.core.dispatch import make_df  # noqa
 from merlin.systems.dag import Ensemble  # noqa
 from merlin.systems.dag.ops.pytorch import PredictPyTorch  # noqa
+from merlin.systems.triton.conversions import match_representations  # noqa
 from merlin.systems.triton.utils import run_ensemble_on_tritonserver  # noqa
 
+TRITON_SERVER_PATH = shutil.which("tritonserver")
+
 
+@pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found")
 def test_serve_t4r_with_torchscript(tmpdir):
     # ===========================================
     # Generate training data
@@ -69,11 +75,12 @@ def test_serve_t4r_with_torchscript(tmpdir):
 
     model.eval()
 
-    traced_model = torch.jit.trace(model, torch_yoochoose_like, strict=True)
+    example_inputs = match_representations(model.input_schema, torch_yoochoose_like)
+    traced_model = torch.jit.trace(model, example_inputs, strict=True)
     assert isinstance(traced_model, torch.jit.TopLevelTracedModule)
     assert torch.allclose(
-        model(torch_yoochoose_like),
-        traced_model(torch_yoochoose_like),
+        model(example_inputs),
+        traced_model(example_inputs),
     )
 
     # ===========================================
diff --git a/tests/integration/tf/test_transformer_model.py b/tests/integration/tf/test_transformer_model.py
index 5e4eb662d..e56b41ea7 100644
--- a/tests/integration/tf/test_transformer_model.py
+++ b/tests/integration/tf/test_transformer_model.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 #
 
+import shutil
+
 import pytest
 
 tf = pytest.importorskip("tensorflow")
@@ -31,7 +33,10 @@
 from merlin.systems.dag.ops.tensorflow import PredictTensorflow  # noqa
 from merlin.systems.triton.utils import run_ensemble_on_tritonserver  # noqa
 
+TRITON_SERVER_PATH = shutil.which("tritonserver")
+
 
+@pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found")
 def test_serve_tf_session_based_with_libtensorflow(tmpdir):
 
     # ===========================================
diff --git a/tox.ini b/tox.ini
index 35e20e9e2..657eac65d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -47,6 +47,8 @@ sitepackages=true
 ; need to add some back.
 setenv =
     TF_GPU_ALLOCATOR=cuda_malloc_async
+passenv =
+    OPAL_PREFIX
 deps =
     -rrequirements/test-gpu.txt
     pytest
@@ -71,6 +73,9 @@ sitepackages=true
 ; need to add some back.
 setenv =
     TF_GPU_ALLOCATOR=cuda_malloc_async
+    LD_LIBRARY_PATH=/opt/tritonserver/backends/pytorch
+passenv =
+    OPAL_PREFIX
 deps =
     pytest
     pytest-cov
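
Note on the conversions.py change: _to_values_offsets is now a functools.singledispatch generic, with the numpy and cupy registrations sharing _to_values_offsets_array and a separate implementation registered for torch.Tensor. A minimal sketch of the torch path, assuming torch is installed; the sample tensor and expected values below are illustrative, not taken from the patch:

    import torch

    from merlin.systems.triton.conversions import _to_values_offsets

    # A padded batch: 4 rows, each of fixed length 3.
    batch = torch.arange(12).reshape(4, 3)

    # singledispatch selects the torch.Tensor registration by input type.
    values, offsets = _to_values_offsets(batch)

    # values flattens the row dimension; offsets marks each row boundary,
    # as int32 on the same device as the input tensor.
    assert values.shape == (12,)
    assert offsets.tolist() == [0, 3, 6, 9, 12]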