From ed0ac64cc54369cfe12c9e95f642c5304be6c0c3 Mon Sep 17 00:00:00 2001 From: John <1200986+maxwillzq@users.noreply.github.com> Date: Wed, 23 Aug 2023 10:04:22 -0700 Subject: [PATCH 1/2] Created using Colaboratory --- docs/experimental_call_torch_tutorial.ipynb | 521 ++++++++++++++++++++ 1 file changed, 521 insertions(+) create mode 100644 docs/experimental_call_torch_tutorial.ipynb diff --git a/docs/experimental_call_torch_tutorial.ipynb b/docs/experimental_call_torch_tutorial.ipynb new file mode 100644 index 0000000..4d93bd1 --- /dev/null +++ b/docs/experimental_call_torch_tutorial.ipynb @@ -0,0 +1,521 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "Ym4Xu6mWg2Na" + }, + "outputs": [], + "source": [ + "# For tips on running notebooks in Google Colab, see\n", + "# https://pytorch.org/tutorials/beginner/colab\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PHmroXZmg2Nb" + }, + "source": [ + "\n", + "# jaxonnxruntime call_torch Tutorial\n", + "**Author:** John Zhang\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6ez7jdFpg2Nc" + }, + "source": [ + "Here we introduce the call_torch API which can seamlessly translate PyTorch models into JAX functions. This integration unites PyTorch with the extensive JAX software ecosystem and harnesses the power of XLA hardware (TPU/GPU/CPU and openXLA ), enhancing cross-framework collaboration and performance potential\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install git+https://github.com/google/jaxonnxruntime.git\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Kkfn_llxzdDa", + "outputId": "43397188-5fd7-40e9-93cf-212e43b1c1ee" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting git+https://github.com/google/jaxonnxruntime.git\n", + " Cloning https://github.com/google/jaxonnxruntime.git to /tmp/pip-req-build-yj76dq13\n", + " Running command git clone --filter=blob:none --quiet https://github.com/google/jaxonnxruntime.git /tmp/pip-req-build-yj76dq13\n", + " Resolved https://github.com/google/jaxonnxruntime.git to commit 2b95bac67150f865222a4bda87b66c099b64ae59\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from jaxonnxruntime==0.3.0) (1.23.5)\n", + "Requirement already satisfied: jax in /usr/local/lib/python3.10/dist-packages (from jaxonnxruntime==0.3.0) (0.3.25)\n", + "Requirement already satisfied: jaxlib in /usr/local/lib/python3.10/dist-packages (from jaxonnxruntime==0.3.0) (0.3.25)\n", + "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax->jaxonnxruntime==0.3.0) (3.3.0)\n", + "Requirement already satisfied: scipy>=1.5 in /usr/local/lib/python3.10/dist-packages (from jax->jaxonnxruntime==0.3.0) (1.10.1)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from jax->jaxonnxruntime==0.3.0) (4.7.1)\n", + "Building wheels for collected packages: jaxonnxruntime\n", + " Building wheel for jaxonnxruntime (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for jaxonnxruntime: filename=jaxonnxruntime-0.3.0-py3-none-any.whl size=177972 sha256=caba7590e9938cbca5365c19348155ab3d42c130fc1b19c45953963089c7066f\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-jf3xlcq6/wheels/43/d3/a6/40189cf2b24db631a69157a58da1d0fcc6d1df48abf847f8bd\n", + "Successfully built jaxonnxruntime\n", + "Installing collected packages: jaxonnxruntime\n", + "Successfully installed jaxonnxruntime-0.3.0\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install onnx torch" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1GX2-ZCwzmPA", + "outputId": "9439c8de-423a-4108-86f5-394d2da36913" + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting onnx\n", + " Downloading onnx-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.6/14.6 MB\u001b[0m \u001b[31m33.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.0.1+cu118)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from onnx) (1.23.5)\n", + "Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx) (3.20.3)\n", + "Requirement already satisfied: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.10/dist-packages (from onnx) (4.7.1)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.12.2)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n", + "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.0.0)\n", + "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (3.27.2)\n", + "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (16.0.6)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.3)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", + "Installing collected packages: onnx\n", + "Successfully installed onnx-1.14.0\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Yra_JCiTg2Nd" + }, + "source": [ + "## Basic Usage\n", + "\n", + "Generally, we describe all models with format. We use JAX PyTree data structure for any type model parameters and model inputs.Broadly, our approach involves characterizing all models using a standardized format. This entails employing the JAX PyTree data structure to encapsulate model parameters and inputs of varying types.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "4RF7lO5ng2Nd", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "99878430-cc1b-4fd8-ebf6-cb53ce396c83" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch_output: tensor([[ 9.2530e-01, 4.5873e-01, 6.4589e-01, 5.6415e-01, 1.6490e+00,\n", + " 8.2248e-01, 8.9589e-01, -5.7184e-01, 1.0232e+00, 5.1742e-01],\n", + " [ 9.5014e-02, 1.3660e+00, 1.3621e+00, 9.0590e-01, 8.4838e-02,\n", + " 1.2079e-01, 1.7469e+00, 1.1094e+00, 1.1676e+00, 1.6560e+00],\n", + " [ 1.2260e+00, 1.7275e+00, 1.2192e+00, -9.5381e-01, 1.1586e+00,\n", + " 1.3536e-01, -8.1443e-01, 3.7343e-01, 1.5365e+00, 1.2673e+00],\n", + " [ 1.9366e+00, 4.1828e-01, 7.5243e-01, -2.6371e-01, -1.1587e-03,\n", + " 1.8683e+00, 7.5635e-01, -6.5726e-01, 1.7267e+00, 1.4934e+00],\n", + " [ 1.7448e-01, 1.2264e+00, 1.5650e+00, -1.1248e-01, 8.0965e-01,\n", + " 6.4813e-01, 2.7031e-01, -2.6631e-01, 4.7319e-02, 3.2769e-02],\n", + " [ 4.5132e-01, 1.9266e+00, 1.6291e+00, 4.6194e-01, -1.2171e-01,\n", + " 1.8986e+00, 1.3777e-01, 4.9093e-01, -3.5940e-01, 8.6310e-01],\n", + " [ 4.1073e-01, 7.4641e-01, 8.6454e-02, 8.2659e-01, 1.5467e+00,\n", + " 9.5625e-01, -1.6194e-01, 1.4552e+00, 6.8996e-01, 1.0307e-01],\n", + " [ 1.9547e+00, -5.6088e-01, -3.6236e-01, 6.7257e-01, 2.4588e-01,\n", + " 1.6908e-01, 1.6637e+00, 1.3871e+00, 3.9015e-01, 9.6702e-01],\n", + " [ 5.3747e-01, 4.6549e-02, 1.6271e+00, 1.7466e+00, 4.4292e-01,\n", + " -9.8217e-01, 1.4013e-01, 6.7617e-03, -5.3100e-02, -1.1265e+00],\n", + " [ 5.3585e-01, -1.8007e-02, 8.2986e-01, 1.8453e+00, 7.6978e-01,\n", + " -1.0868e-02, 1.4607e+00, 4.8968e-01, 8.1474e-01, 1.1434e+00]])\n" + ] + } + ], + "source": [ + "import torch\n", + "import jax\n", + "from jaxonnxruntime.experimental import call_torch\n", + "\n", + "def foo(x, y):\n", + " a = torch.sin(x)\n", + " b = torch.cos(y)\n", + " return a + b\n", + "\n", + "torch_inputs =(torch.randn(10, 10), torch.randn(10, 10))\n", + "torch_module = torch.jit.trace(foo, torch_inputs)\n", + "\n", + "print(\"torch_output: \", torch_module(*torch_inputs))\n" + ] + }, + { + "cell_type": "code", + "source": [ + "jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)\n", + "jax_inputs = jax.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)\n", + "print(\"jax_output:\", jax_fn(jax_params, jax_inputs))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "kEEVNkLP1Owa", + "outputId": "96063361-a3e3-4b6b-97cb-62609fff171d" + }, + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:825: UserWarning: no signature found for , skipping _decide_input_format\n", + " warnings.warn(f\"{e}, skipping _decide_input_format\")\n", + "WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============\n", + "verbose: False, log level: Level.ERROR\n", + "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n", + "\n", + "jax_output: [DeviceArray([[ 9.2529660e-01, 4.5873073e-01, 6.4588535e-01,\n", + " 5.6415093e-01, 1.6489711e+00, 8.2248122e-01,\n", + " 8.9588922e-01, -5.7184368e-01, 1.0232372e+00,\n", + " 5.1741624e-01],\n", + " [ 9.5014393e-02, 1.3659939e+00, 1.3621219e+00,\n", + " 9.0589929e-01, 8.4837854e-02, 1.2079075e-01,\n", + " 1.7469316e+00, 1.1094404e+00, 1.1675845e+00,\n", + " 1.6560377e+00],\n", + " [ 1.2259831e+00, 1.7274783e+00, 1.2191784e+00,\n", + " -9.5381111e-01, 1.1585734e+00, 1.3535953e-01,\n", + " -8.1442708e-01, 3.7343296e-01, 1.5364575e+00,\n", + " 1.2673373e+00],\n", + " [ 1.9366202e+00, 4.1827971e-01, 7.5243413e-01,\n", + " -2.6371196e-01, -1.1587143e-03, 1.8683007e+00,\n", + " 7.5634706e-01, -6.5725970e-01, 1.7267224e+00,\n", + " 1.4934300e+00],\n", + " [ 1.7448190e-01, 1.2263532e+00, 1.5649725e+00,\n", + " -1.1247611e-01, 8.0965269e-01, 6.4813310e-01,\n", + " 2.7030846e-01, -2.6631176e-01, 4.7319174e-02,\n", + " 3.2769233e-02],\n", + " [ 4.5132229e-01, 1.9265618e+00, 1.6291361e+00,\n", + " 4.6194053e-01, -1.2170833e-01, 1.8985560e+00,\n", + " 1.3776600e-01, 4.9092823e-01, -3.5940361e-01,\n", + " 8.6309528e-01],\n", + " [ 4.1072559e-01, 7.4641228e-01, 8.6453676e-02,\n", + " 8.2658666e-01, 1.5467417e+00, 9.5624632e-01,\n", + " -1.6194339e-01, 1.4552250e+00, 6.8996412e-01,\n", + " 1.0307056e-01],\n", + " [ 1.9546964e+00, -5.6088006e-01, -3.6236224e-01,\n", + " 6.7256629e-01, 2.4587882e-01, 1.6908441e-01,\n", + " 1.6637435e+00, 1.3870629e+00, 3.9014500e-01,\n", + " 9.6701825e-01],\n", + " [ 5.3747320e-01, 4.6549439e-02, 1.6270701e+00,\n", + " 1.7466159e+00, 4.4291818e-01, -9.8216683e-01,\n", + " 1.4013010e-01, 6.7616701e-03, -5.3099811e-02,\n", + " -1.1264541e+00],\n", + " [ 5.3585029e-01, -1.8007398e-02, 8.2986158e-01,\n", + " 1.8452730e+00, 7.6977569e-01, -1.0867842e-02,\n", + " 1.4607116e+00, 4.8967782e-01, 8.1474060e-01,\n", + " 1.1434039e+00]], dtype=float32)]\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N0FJif4_g2Ne" + }, + "source": [ + "*We* can also take ``torch.nn.Module``.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "zR31ReFKg2Ne" + }, + "outputs": [], + "source": [ + "class MyModule(torch.nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.lin = torch.nn.Linear(100, 10)\n", + "\n", + " def forward(self, x):\n", + " return torch.nn.functional.relu(self.lin(x))\n", + "\n", + "torch_module = MyModule()\n", + "torch_inputs = (torch.randn(10, 100), )" + ] + }, + { + "cell_type": "code", + "source": [ + "torch_module.eval()\n", + "jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)\n", + "jax_inputs = jax.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)\n", + "print(\"jax_output:\", jax_fn(jax_params, jax_inputs))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "H--BoG-E1d6w", + "outputId": "6ff9f51c-0c7e-4ab1-b404-81e175587365" + }, + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============\n", + "verbose: False, log level: Level.ERROR\n", + "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n", + "\n", + "jax_output: [DeviceArray([[0.12982486, 0.481035 , 0.3622095 , 0.7990376 , 0.607528 ,\n", + " 0. , 0. , 0.2921514 , 0.56446004, 0. ],\n", + " [0. , 0.57425183, 0. , 0.9224024 , 0. ,\n", + " 0.39311224, 0.11618385, 0. , 0.6319629 , 0.29966408],\n", + " [0.49294975, 0.36862767, 0.2809724 , 0. , 0. ,\n", + " 0.33000746, 0.4940635 , 0.02182353, 0. , 0. ],\n", + " [0. , 0.3759548 , 0. , 0.23886062, 0. ,\n", + " 0. , 0.35155773, 0. , 0.24100977, 0.15047333],\n", + " [0.40354684, 0. , 0.12827398, 0. , 0.07742476,\n", + " 0.6966675 , 0. , 0.00607699, 0. , 0.59710497],\n", + " [0. , 0.5936287 , 0. , 0. , 0. ,\n", + " 0.5815142 , 0.2761725 , 0.47168115, 0. , 0.26667207],\n", + " [0.06918865, 0. , 0.43948644, 0. , 0.5058311 ,\n", + " 0.09885295, 0.40746492, 0.30387542, 0.45276335, 0.10408825],\n", + " [0.24983016, 0.14137569, 0. , 0.72815377, 0.8114418 ,\n", + " 0.5519933 , 0. , 0. , 0.08975451, 0. ],\n", + " [0.4778166 , 0. , 0.8714261 , 0. , 0.3690765 ,\n", + " 0.2697782 , 0.13372795, 0. , 0. , 0. ],\n", + " [0. , 0.88346595, 0.2436127 , 0. , 0. ,\n", + " 0.41404647, 0.61342806, 1.4032906 , 0.00577256, 0.31158522]], dtype=float32)]\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# A real testing model\n", + "\n" + ], + "metadata": { + "id": "g1LoU8pa3-uK" + } + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "from torchvision.models import resnet50\n", + "# Generates random input and targets data for the model, where `b` is\n", + "# batch size.\n", + "\n", + "def generate_data(b):\n", + " return (\n", + " torch.randn(b, 3, 128, 128).to(torch.float32),\n", + " )\n", + "\n", + "torch_inputs = generate_data(1)\n", + "torch_module = resnet50()\n", + "torch_module.eval()\n", + "torch_module = torch.jit.trace(torch_module, torch_inputs)\n", + "torch_outputs = [torch_module(*torch_inputs)]\n", + "\n" + ], + "metadata": { + "id": "p03lb4Ix4CvN" + }, + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from jaxonnxruntime.experimental import call_torch\n", + "import jax\n", + "jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)\n", + "jax_fn = jax.jit(jax_fn)\n", + "jax_inputs = jax.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)\n", + "jax_outputs = jax_fn(jax_params, jax_inputs)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "d1UB2by-4Q8R", + "outputId": "43f18570-8fce-46d3-b456-331e9cccb51a" + }, + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:825: UserWarning: no signature found for , skipping _decide_input_format\n", + " warnings.warn(f\"{e}, skipping _decide_input_format\")\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============\n", + "verbose: False, log level: Level.ERROR\n", + "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n", + "\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "from jaxonnxruntime.experimental.call_torch import CallTorchTestCase\n", + "test_case = CallTorchTestCase()\n", + "test_case.assert_allclose(jax.tree_map(call_torch.torch_tensor_to_np_array,torch_outputs), jax_outputs, rtol=1e-07, atol=1e-03)\n" + ], + "metadata": { + "id": "GrNFj9z27FTC" + }, + "execution_count": 10, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%timeit _ = torch_module(*torch_inputs)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "P-aPak7H9FkV", + "outputId": "930fb187-6fc0-42b8-dbbc-bd277b806a29" + }, + "execution_count": 11, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "93.7 ms ± 23.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "%timeit _ = jax_fn(jax_params, jax_inputs)\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LAEbgDR59LPV", + "outputId": "813556b9-ca57-4f99-858c-ad9e0a234eb2" + }, + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "258 ms ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "aoQ9-ggb97KP" + }, + "execution_count": 12, + "outputs": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "colab": { + "provenance": [], + "include_colab_link": true + }, + "accelerator": "TPU" + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file From 8b556ee4401fc06aed35d3432361927ad129fb89 Mon Sep 17 00:00:00 2001 From: John <1200986+maxwillzq@users.noreply.github.com> Date: Wed, 23 Aug 2023 10:48:57 -0700 Subject: [PATCH 2/2] Created using Colaboratory --- docs/experimental_call_torch_tutorial.ipynb | 296 ++------------------ 1 file changed, 29 insertions(+), 267 deletions(-) diff --git a/docs/experimental_call_torch_tutorial.ipynb b/docs/experimental_call_torch_tutorial.ipynb index 4d93bd1..8e92b44 100644 --- a/docs/experimental_call_torch_tutorial.ipynb +++ b/docs/experimental_call_torch_tutorial.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "id": "Ym4Xu6mWg2Na" }, @@ -51,42 +51,10 @@ "!pip install git+https://github.com/google/jaxonnxruntime.git\n" ], "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Kkfn_llxzdDa", - "outputId": "43397188-5fd7-40e9-93cf-212e43b1c1ee" + "id": "Kkfn_llxzdDa" }, - "execution_count": 2, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting git+https://github.com/google/jaxonnxruntime.git\n", - " Cloning https://github.com/google/jaxonnxruntime.git to /tmp/pip-req-build-yj76dq13\n", - " Running command git clone --filter=blob:none --quiet https://github.com/google/jaxonnxruntime.git /tmp/pip-req-build-yj76dq13\n", - " Resolved https://github.com/google/jaxonnxruntime.git to commit 2b95bac67150f865222a4bda87b66c099b64ae59\n", - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from jaxonnxruntime==0.3.0) (1.23.5)\n", - "Requirement already satisfied: jax in /usr/local/lib/python3.10/dist-packages (from jaxonnxruntime==0.3.0) (0.3.25)\n", - "Requirement already satisfied: jaxlib in /usr/local/lib/python3.10/dist-packages (from jaxonnxruntime==0.3.0) (0.3.25)\n", - "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax->jaxonnxruntime==0.3.0) (3.3.0)\n", - "Requirement already satisfied: scipy>=1.5 in /usr/local/lib/python3.10/dist-packages (from jax->jaxonnxruntime==0.3.0) (1.10.1)\n", - "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from jax->jaxonnxruntime==0.3.0) (4.7.1)\n", - "Building wheels for collected packages: jaxonnxruntime\n", - " Building wheel for jaxonnxruntime (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - " Created wheel for jaxonnxruntime: filename=jaxonnxruntime-0.3.0-py3-none-any.whl size=177972 sha256=caba7590e9938cbca5365c19348155ab3d42c130fc1b19c45953963089c7066f\n", - " Stored in directory: /tmp/pip-ephem-wheel-cache-jf3xlcq6/wheels/43/d3/a6/40189cf2b24db631a69157a58da1d0fcc6d1df48abf847f8bd\n", - "Successfully built jaxonnxruntime\n", - "Installing collected packages: jaxonnxruntime\n", - "Successfully installed jaxonnxruntime-0.3.0\n" - ] - } - ] + "execution_count": null, + "outputs": [] }, { "cell_type": "code", @@ -94,39 +62,10 @@ "!pip install onnx torch" ], "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "1GX2-ZCwzmPA", - "outputId": "9439c8de-423a-4108-86f5-394d2da36913" + "id": "1GX2-ZCwzmPA" }, - "execution_count": 3, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting onnx\n", - " Downloading onnx-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.6 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.6/14.6 MB\u001b[0m \u001b[31m33.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.0.1+cu118)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from onnx) (1.23.5)\n", - "Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx) (3.20.3)\n", - "Requirement already satisfied: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.10/dist-packages (from onnx) (4.7.1)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.12.2)\n", - "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n", - "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.1)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n", - "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.0.0)\n", - "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (3.27.2)\n", - "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (16.0.6)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.3)\n", - "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", - "Installing collected packages: onnx\n", - "Successfully installed onnx-1.14.0\n" - ] - } - ] + "execution_count": null, + "outputs": [] }, { "cell_type": "markdown", @@ -141,42 +80,11 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { - "id": "4RF7lO5ng2Nd", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "99878430-cc1b-4fd8-ebf6-cb53ce396c83" + "id": "4RF7lO5ng2Nd" }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch_output: tensor([[ 9.2530e-01, 4.5873e-01, 6.4589e-01, 5.6415e-01, 1.6490e+00,\n", - " 8.2248e-01, 8.9589e-01, -5.7184e-01, 1.0232e+00, 5.1742e-01],\n", - " [ 9.5014e-02, 1.3660e+00, 1.3621e+00, 9.0590e-01, 8.4838e-02,\n", - " 1.2079e-01, 1.7469e+00, 1.1094e+00, 1.1676e+00, 1.6560e+00],\n", - " [ 1.2260e+00, 1.7275e+00, 1.2192e+00, -9.5381e-01, 1.1586e+00,\n", - " 1.3536e-01, -8.1443e-01, 3.7343e-01, 1.5365e+00, 1.2673e+00],\n", - " [ 1.9366e+00, 4.1828e-01, 7.5243e-01, -2.6371e-01, -1.1587e-03,\n", - " 1.8683e+00, 7.5635e-01, -6.5726e-01, 1.7267e+00, 1.4934e+00],\n", - " [ 1.7448e-01, 1.2264e+00, 1.5650e+00, -1.1248e-01, 8.0965e-01,\n", - " 6.4813e-01, 2.7031e-01, -2.6631e-01, 4.7319e-02, 3.2769e-02],\n", - " [ 4.5132e-01, 1.9266e+00, 1.6291e+00, 4.6194e-01, -1.2171e-01,\n", - " 1.8986e+00, 1.3777e-01, 4.9093e-01, -3.5940e-01, 8.6310e-01],\n", - " [ 4.1073e-01, 7.4641e-01, 8.6454e-02, 8.2659e-01, 1.5467e+00,\n", - " 9.5625e-01, -1.6194e-01, 1.4552e+00, 6.8996e-01, 1.0307e-01],\n", - " [ 1.9547e+00, -5.6088e-01, -3.6236e-01, 6.7257e-01, 2.4588e-01,\n", - " 1.6908e-01, 1.6637e+00, 1.3871e+00, 3.9015e-01, 9.6702e-01],\n", - " [ 5.3747e-01, 4.6549e-02, 1.6271e+00, 1.7466e+00, 4.4292e-01,\n", - " -9.8217e-01, 1.4013e-01, 6.7617e-03, -5.3100e-02, -1.1265e+00],\n", - " [ 5.3585e-01, -1.8007e-02, 8.2986e-01, 1.8453e+00, 7.6978e-01,\n", - " -1.0868e-02, 1.4607e+00, 4.8968e-01, 8.1474e-01, 1.1434e+00]])\n" - ] - } - ], + "outputs": [], "source": [ "import torch\n", "import jax\n", @@ -201,74 +109,10 @@ "print(\"jax_output:\", jax_fn(jax_params, jax_inputs))" ], "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "kEEVNkLP1Owa", - "outputId": "96063361-a3e3-4b6b-97cb-62609fff171d" + "id": "kEEVNkLP1Owa" }, - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:825: UserWarning: no signature found for , skipping _decide_input_format\n", - " warnings.warn(f\"{e}, skipping _decide_input_format\")\n", - "WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============\n", - "verbose: False, log level: Level.ERROR\n", - "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n", - "\n", - "jax_output: [DeviceArray([[ 9.2529660e-01, 4.5873073e-01, 6.4588535e-01,\n", - " 5.6415093e-01, 1.6489711e+00, 8.2248122e-01,\n", - " 8.9588922e-01, -5.7184368e-01, 1.0232372e+00,\n", - " 5.1741624e-01],\n", - " [ 9.5014393e-02, 1.3659939e+00, 1.3621219e+00,\n", - " 9.0589929e-01, 8.4837854e-02, 1.2079075e-01,\n", - " 1.7469316e+00, 1.1094404e+00, 1.1675845e+00,\n", - " 1.6560377e+00],\n", - " [ 1.2259831e+00, 1.7274783e+00, 1.2191784e+00,\n", - " -9.5381111e-01, 1.1585734e+00, 1.3535953e-01,\n", - " -8.1442708e-01, 3.7343296e-01, 1.5364575e+00,\n", - " 1.2673373e+00],\n", - " [ 1.9366202e+00, 4.1827971e-01, 7.5243413e-01,\n", - " -2.6371196e-01, -1.1587143e-03, 1.8683007e+00,\n", - " 7.5634706e-01, -6.5725970e-01, 1.7267224e+00,\n", - " 1.4934300e+00],\n", - " [ 1.7448190e-01, 1.2263532e+00, 1.5649725e+00,\n", - " -1.1247611e-01, 8.0965269e-01, 6.4813310e-01,\n", - " 2.7030846e-01, -2.6631176e-01, 4.7319174e-02,\n", - " 3.2769233e-02],\n", - " [ 4.5132229e-01, 1.9265618e+00, 1.6291361e+00,\n", - " 4.6194053e-01, -1.2170833e-01, 1.8985560e+00,\n", - " 1.3776600e-01, 4.9092823e-01, -3.5940361e-01,\n", - " 8.6309528e-01],\n", - " [ 4.1072559e-01, 7.4641228e-01, 8.6453676e-02,\n", - " 8.2658666e-01, 1.5467417e+00, 9.5624632e-01,\n", - " -1.6194339e-01, 1.4552250e+00, 6.8996412e-01,\n", - " 1.0307056e-01],\n", - " [ 1.9546964e+00, -5.6088006e-01, -3.6236224e-01,\n", - " 6.7256629e-01, 2.4587882e-01, 1.6908441e-01,\n", - " 1.6637435e+00, 1.3870629e+00, 3.9014500e-01,\n", - " 9.6701825e-01],\n", - " [ 5.3747320e-01, 4.6549439e-02, 1.6270701e+00,\n", - " 1.7466159e+00, 4.4291818e-01, -9.8216683e-01,\n", - " 1.4013010e-01, 6.7616701e-03, -5.3099811e-02,\n", - " -1.1264541e+00],\n", - " [ 5.3585029e-01, -1.8007398e-02, 8.2986158e-01,\n", - " 1.8452730e+00, 7.6977569e-01, -1.0867842e-02,\n", - " 1.4607116e+00, 4.8967782e-01, 8.1474060e-01,\n", - " 1.1434039e+00]], dtype=float32)]\n" - ] - } - ] + "execution_count": null, + "outputs": [] }, { "cell_type": "markdown", @@ -282,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "id": "zR31ReFKg2Ne" }, @@ -309,45 +153,10 @@ "print(\"jax_output:\", jax_fn(jax_params, jax_inputs))" ], "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "H--BoG-E1d6w", - "outputId": "6ff9f51c-0c7e-4ab1-b404-81e175587365" + "id": "H--BoG-E1d6w" }, - "execution_count": 7, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============\n", - "verbose: False, log level: Level.ERROR\n", - "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n", - "\n", - "jax_output: [DeviceArray([[0.12982486, 0.481035 , 0.3622095 , 0.7990376 , 0.607528 ,\n", - " 0. , 0. , 0.2921514 , 0.56446004, 0. ],\n", - " [0. , 0.57425183, 0. , 0.9224024 , 0. ,\n", - " 0.39311224, 0.11618385, 0. , 0.6319629 , 0.29966408],\n", - " [0.49294975, 0.36862767, 0.2809724 , 0. , 0. ,\n", - " 0.33000746, 0.4940635 , 0.02182353, 0. , 0. ],\n", - " [0. , 0.3759548 , 0. , 0.23886062, 0. ,\n", - " 0. , 0.35155773, 0. , 0.24100977, 0.15047333],\n", - " [0.40354684, 0. , 0.12827398, 0. , 0.07742476,\n", - " 0.6966675 , 0. , 0.00607699, 0. , 0.59710497],\n", - " [0. , 0.5936287 , 0. , 0. , 0. ,\n", - " 0.5815142 , 0.2761725 , 0.47168115, 0. , 0.26667207],\n", - " [0.06918865, 0. , 0.43948644, 0. , 0.5058311 ,\n", - " 0.09885295, 0.40746492, 0.30387542, 0.45276335, 0.10408825],\n", - " [0.24983016, 0.14137569, 0. , 0.72815377, 0.8114418 ,\n", - " 0.5519933 , 0. , 0. , 0.08975451, 0. ],\n", - " [0.4778166 , 0. , 0.8714261 , 0. , 0.3690765 ,\n", - " 0.2697782 , 0.13372795, 0. , 0. , 0. ],\n", - " [0. , 0.88346595, 0.2436127 , 0. , 0. ,\n", - " 0.41404647, 0.61342806, 1.4032906 , 0.00577256, 0.31158522]], dtype=float32)]\n" - ] - } - ] + "execution_count": null, + "outputs": [] }, { "cell_type": "markdown", @@ -382,7 +191,7 @@ "metadata": { "id": "p03lb4Ix4CvN" }, - "execution_count": 8, + "execution_count": null, "outputs": [] }, { @@ -396,33 +205,10 @@ "jax_outputs = jax_fn(jax_params, jax_inputs)" ], "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "d1UB2by-4Q8R", - "outputId": "43f18570-8fce-46d3-b456-331e9cccb51a" + "id": "d1UB2by-4Q8R" }, - "execution_count": 9, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:825: UserWarning: no signature found for , skipping _decide_input_format\n", - " warnings.warn(f\"{e}, skipping _decide_input_format\")\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============\n", - "verbose: False, log level: Level.ERROR\n", - "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n", - "\n" - ] - } - ] + "execution_count": null, + "outputs": [] }, { "cell_type": "code", @@ -434,7 +220,7 @@ "metadata": { "id": "GrNFj9z27FTC" }, - "execution_count": 10, + "execution_count": null, "outputs": [] }, { @@ -443,22 +229,10 @@ "%timeit _ = torch_module(*torch_inputs)" ], "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "P-aPak7H9FkV", - "outputId": "930fb187-6fc0-42b8-dbbc-bd277b806a29" + "id": "P-aPak7H9FkV" }, - "execution_count": 11, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "93.7 ms ± 23.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" - ] - } - ] + "execution_count": null, + "outputs": [] }, { "cell_type": "code", @@ -466,22 +240,10 @@ "%timeit _ = jax_fn(jax_params, jax_inputs)\n" ], "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "LAEbgDR59LPV", - "outputId": "813556b9-ca57-4f99-858c-ad9e0a234eb2" + "id": "LAEbgDR59LPV" }, - "execution_count": 12, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "258 ms ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" - ] - } - ] + "execution_count": null, + "outputs": [] }, { "cell_type": "code", @@ -489,7 +251,7 @@ "metadata": { "id": "aoQ9-ggb97KP" }, - "execution_count": 12, + "execution_count": null, "outputs": [] } ],