From ed0ac64cc54369cfe12c9e95f642c5304be6c0c3 Mon Sep 17 00:00:00 2001
From: John <1200986+maxwillzq@users.noreply.github.com>
Date: Wed, 23 Aug 2023 10:04:22 -0700
Subject: [PATCH 1/2] Created using Colaboratory
---
docs/experimental_call_torch_tutorial.ipynb | 521 ++++++++++++++++++++
1 file changed, 521 insertions(+)
create mode 100644 docs/experimental_call_torch_tutorial.ipynb
diff --git a/docs/experimental_call_torch_tutorial.ipynb b/docs/experimental_call_torch_tutorial.ipynb
new file mode 100644
index 0000000..4d93bd1
--- /dev/null
+++ b/docs/experimental_call_torch_tutorial.ipynb
@@ -0,0 +1,521 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "Ym4Xu6mWg2Na"
+ },
+ "outputs": [],
+ "source": [
+ "# For tips on running notebooks in Google Colab, see\n",
+ "# https://pytorch.org/tutorials/beginner/colab\n",
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PHmroXZmg2Nb"
+ },
+ "source": [
+ "\n",
+ "# jaxonnxruntime call_torch Tutorial\n",
+ "**Author:** John Zhang\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6ez7jdFpg2Nc"
+ },
+ "source": [
+ "Here we introduce the call_torch API which can seamlessly translate PyTorch models into JAX functions. This integration unites PyTorch with the extensive JAX software ecosystem and harnesses the power of XLA hardware (TPU/GPU/CPU and openXLA ), enhancing cross-framework collaboration and performance potential\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install git+https://github.com/google/jaxonnxruntime.git\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Kkfn_llxzdDa",
+ "outputId": "43397188-5fd7-40e9-93cf-212e43b1c1ee"
+ },
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting git+https://github.com/google/jaxonnxruntime.git\n",
+ " Cloning https://github.com/google/jaxonnxruntime.git to /tmp/pip-req-build-yj76dq13\n",
+ " Running command git clone --filter=blob:none --quiet https://github.com/google/jaxonnxruntime.git /tmp/pip-req-build-yj76dq13\n",
+ " Resolved https://github.com/google/jaxonnxruntime.git to commit 2b95bac67150f865222a4bda87b66c099b64ae59\n",
+ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from jaxonnxruntime==0.3.0) (1.23.5)\n",
+ "Requirement already satisfied: jax in /usr/local/lib/python3.10/dist-packages (from jaxonnxruntime==0.3.0) (0.3.25)\n",
+ "Requirement already satisfied: jaxlib in /usr/local/lib/python3.10/dist-packages (from jaxonnxruntime==0.3.0) (0.3.25)\n",
+ "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax->jaxonnxruntime==0.3.0) (3.3.0)\n",
+ "Requirement already satisfied: scipy>=1.5 in /usr/local/lib/python3.10/dist-packages (from jax->jaxonnxruntime==0.3.0) (1.10.1)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from jax->jaxonnxruntime==0.3.0) (4.7.1)\n",
+ "Building wheels for collected packages: jaxonnxruntime\n",
+ " Building wheel for jaxonnxruntime (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for jaxonnxruntime: filename=jaxonnxruntime-0.3.0-py3-none-any.whl size=177972 sha256=caba7590e9938cbca5365c19348155ab3d42c130fc1b19c45953963089c7066f\n",
+ " Stored in directory: /tmp/pip-ephem-wheel-cache-jf3xlcq6/wheels/43/d3/a6/40189cf2b24db631a69157a58da1d0fcc6d1df48abf847f8bd\n",
+ "Successfully built jaxonnxruntime\n",
+ "Installing collected packages: jaxonnxruntime\n",
+ "Successfully installed jaxonnxruntime-0.3.0\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install onnx torch"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "1GX2-ZCwzmPA",
+ "outputId": "9439c8de-423a-4108-86f5-394d2da36913"
+ },
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting onnx\n",
+ " Downloading onnx-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.6 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.6/14.6 MB\u001b[0m \u001b[31m33.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.0.1+cu118)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from onnx) (1.23.5)\n",
+ "Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx) (3.20.3)\n",
+ "Requirement already satisfied: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.10/dist-packages (from onnx) (4.7.1)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.12.2)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.1)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n",
+ "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.0.0)\n",
+ "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (3.27.2)\n",
+ "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (16.0.6)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.3)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n",
+ "Installing collected packages: onnx\n",
+ "Successfully installed onnx-1.14.0\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Yra_JCiTg2Nd"
+ },
+ "source": [
+ "## Basic Usage\n",
+ "\n",
+ "Generally, we describe all models with format. We use JAX PyTree data structure for any type model parameters and model inputs.Broadly, our approach involves characterizing all models using a standardized format. This entails employing the JAX PyTree data structure to encapsulate model parameters and inputs of varying types.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "id": "4RF7lO5ng2Nd",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "99878430-cc1b-4fd8-ebf6-cb53ce396c83"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "torch_output: tensor([[ 9.2530e-01, 4.5873e-01, 6.4589e-01, 5.6415e-01, 1.6490e+00,\n",
+ " 8.2248e-01, 8.9589e-01, -5.7184e-01, 1.0232e+00, 5.1742e-01],\n",
+ " [ 9.5014e-02, 1.3660e+00, 1.3621e+00, 9.0590e-01, 8.4838e-02,\n",
+ " 1.2079e-01, 1.7469e+00, 1.1094e+00, 1.1676e+00, 1.6560e+00],\n",
+ " [ 1.2260e+00, 1.7275e+00, 1.2192e+00, -9.5381e-01, 1.1586e+00,\n",
+ " 1.3536e-01, -8.1443e-01, 3.7343e-01, 1.5365e+00, 1.2673e+00],\n",
+ " [ 1.9366e+00, 4.1828e-01, 7.5243e-01, -2.6371e-01, -1.1587e-03,\n",
+ " 1.8683e+00, 7.5635e-01, -6.5726e-01, 1.7267e+00, 1.4934e+00],\n",
+ " [ 1.7448e-01, 1.2264e+00, 1.5650e+00, -1.1248e-01, 8.0965e-01,\n",
+ " 6.4813e-01, 2.7031e-01, -2.6631e-01, 4.7319e-02, 3.2769e-02],\n",
+ " [ 4.5132e-01, 1.9266e+00, 1.6291e+00, 4.6194e-01, -1.2171e-01,\n",
+ " 1.8986e+00, 1.3777e-01, 4.9093e-01, -3.5940e-01, 8.6310e-01],\n",
+ " [ 4.1073e-01, 7.4641e-01, 8.6454e-02, 8.2659e-01, 1.5467e+00,\n",
+ " 9.5625e-01, -1.6194e-01, 1.4552e+00, 6.8996e-01, 1.0307e-01],\n",
+ " [ 1.9547e+00, -5.6088e-01, -3.6236e-01, 6.7257e-01, 2.4588e-01,\n",
+ " 1.6908e-01, 1.6637e+00, 1.3871e+00, 3.9015e-01, 9.6702e-01],\n",
+ " [ 5.3747e-01, 4.6549e-02, 1.6271e+00, 1.7466e+00, 4.4292e-01,\n",
+ " -9.8217e-01, 1.4013e-01, 6.7617e-03, -5.3100e-02, -1.1265e+00],\n",
+ " [ 5.3585e-01, -1.8007e-02, 8.2986e-01, 1.8453e+00, 7.6978e-01,\n",
+ " -1.0868e-02, 1.4607e+00, 4.8968e-01, 8.1474e-01, 1.1434e+00]])\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import jax\n",
+ "from jaxonnxruntime.experimental import call_torch\n",
+ "\n",
+ "def foo(x, y):\n",
+ " a = torch.sin(x)\n",
+ " b = torch.cos(y)\n",
+ " return a + b\n",
+ "\n",
+ "torch_inputs =(torch.randn(10, 10), torch.randn(10, 10))\n",
+ "torch_module = torch.jit.trace(foo, torch_inputs)\n",
+ "\n",
+ "print(\"torch_output: \", torch_module(*torch_inputs))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)\n",
+ "jax_inputs = jax.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)\n",
+ "print(\"jax_output:\", jax_fn(jax_params, jax_inputs))"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "kEEVNkLP1Owa",
+ "outputId": "96063361-a3e3-4b6b-97cb-62609fff171d"
+ },
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:825: UserWarning: no signature found for , skipping _decide_input_format\n",
+ " warnings.warn(f\"{e}, skipping _decide_input_format\")\n",
+ "WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============\n",
+ "verbose: False, log level: Level.ERROR\n",
+ "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n",
+ "\n",
+ "jax_output: [DeviceArray([[ 9.2529660e-01, 4.5873073e-01, 6.4588535e-01,\n",
+ " 5.6415093e-01, 1.6489711e+00, 8.2248122e-01,\n",
+ " 8.9588922e-01, -5.7184368e-01, 1.0232372e+00,\n",
+ " 5.1741624e-01],\n",
+ " [ 9.5014393e-02, 1.3659939e+00, 1.3621219e+00,\n",
+ " 9.0589929e-01, 8.4837854e-02, 1.2079075e-01,\n",
+ " 1.7469316e+00, 1.1094404e+00, 1.1675845e+00,\n",
+ " 1.6560377e+00],\n",
+ " [ 1.2259831e+00, 1.7274783e+00, 1.2191784e+00,\n",
+ " -9.5381111e-01, 1.1585734e+00, 1.3535953e-01,\n",
+ " -8.1442708e-01, 3.7343296e-01, 1.5364575e+00,\n",
+ " 1.2673373e+00],\n",
+ " [ 1.9366202e+00, 4.1827971e-01, 7.5243413e-01,\n",
+ " -2.6371196e-01, -1.1587143e-03, 1.8683007e+00,\n",
+ " 7.5634706e-01, -6.5725970e-01, 1.7267224e+00,\n",
+ " 1.4934300e+00],\n",
+ " [ 1.7448190e-01, 1.2263532e+00, 1.5649725e+00,\n",
+ " -1.1247611e-01, 8.0965269e-01, 6.4813310e-01,\n",
+ " 2.7030846e-01, -2.6631176e-01, 4.7319174e-02,\n",
+ " 3.2769233e-02],\n",
+ " [ 4.5132229e-01, 1.9265618e+00, 1.6291361e+00,\n",
+ " 4.6194053e-01, -1.2170833e-01, 1.8985560e+00,\n",
+ " 1.3776600e-01, 4.9092823e-01, -3.5940361e-01,\n",
+ " 8.6309528e-01],\n",
+ " [ 4.1072559e-01, 7.4641228e-01, 8.6453676e-02,\n",
+ " 8.2658666e-01, 1.5467417e+00, 9.5624632e-01,\n",
+ " -1.6194339e-01, 1.4552250e+00, 6.8996412e-01,\n",
+ " 1.0307056e-01],\n",
+ " [ 1.9546964e+00, -5.6088006e-01, -3.6236224e-01,\n",
+ " 6.7256629e-01, 2.4587882e-01, 1.6908441e-01,\n",
+ " 1.6637435e+00, 1.3870629e+00, 3.9014500e-01,\n",
+ " 9.6701825e-01],\n",
+ " [ 5.3747320e-01, 4.6549439e-02, 1.6270701e+00,\n",
+ " 1.7466159e+00, 4.4291818e-01, -9.8216683e-01,\n",
+ " 1.4013010e-01, 6.7616701e-03, -5.3099811e-02,\n",
+ " -1.1264541e+00],\n",
+ " [ 5.3585029e-01, -1.8007398e-02, 8.2986158e-01,\n",
+ " 1.8452730e+00, 7.6977569e-01, -1.0867842e-02,\n",
+ " 1.4607116e+00, 4.8967782e-01, 8.1474060e-01,\n",
+ " 1.1434039e+00]], dtype=float32)]\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "N0FJif4_g2Ne"
+ },
+ "source": [
+ "*We* can also take ``torch.nn.Module``.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "id": "zR31ReFKg2Ne"
+ },
+ "outputs": [],
+ "source": [
+ "class MyModule(torch.nn.Module):\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " self.lin = torch.nn.Linear(100, 10)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " return torch.nn.functional.relu(self.lin(x))\n",
+ "\n",
+ "torch_module = MyModule()\n",
+ "torch_inputs = (torch.randn(10, 100), )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "torch_module.eval()\n",
+ "jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)\n",
+ "jax_inputs = jax.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)\n",
+ "print(\"jax_output:\", jax_fn(jax_params, jax_inputs))"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "H--BoG-E1d6w",
+ "outputId": "6ff9f51c-0c7e-4ab1-b404-81e175587365"
+ },
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============\n",
+ "verbose: False, log level: Level.ERROR\n",
+ "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n",
+ "\n",
+ "jax_output: [DeviceArray([[0.12982486, 0.481035 , 0.3622095 , 0.7990376 , 0.607528 ,\n",
+ " 0. , 0. , 0.2921514 , 0.56446004, 0. ],\n",
+ " [0. , 0.57425183, 0. , 0.9224024 , 0. ,\n",
+ " 0.39311224, 0.11618385, 0. , 0.6319629 , 0.29966408],\n",
+ " [0.49294975, 0.36862767, 0.2809724 , 0. , 0. ,\n",
+ " 0.33000746, 0.4940635 , 0.02182353, 0. , 0. ],\n",
+ " [0. , 0.3759548 , 0. , 0.23886062, 0. ,\n",
+ " 0. , 0.35155773, 0. , 0.24100977, 0.15047333],\n",
+ " [0.40354684, 0. , 0.12827398, 0. , 0.07742476,\n",
+ " 0.6966675 , 0. , 0.00607699, 0. , 0.59710497],\n",
+ " [0. , 0.5936287 , 0. , 0. , 0. ,\n",
+ " 0.5815142 , 0.2761725 , 0.47168115, 0. , 0.26667207],\n",
+ " [0.06918865, 0. , 0.43948644, 0. , 0.5058311 ,\n",
+ " 0.09885295, 0.40746492, 0.30387542, 0.45276335, 0.10408825],\n",
+ " [0.24983016, 0.14137569, 0. , 0.72815377, 0.8114418 ,\n",
+ " 0.5519933 , 0. , 0. , 0.08975451, 0. ],\n",
+ " [0.4778166 , 0. , 0.8714261 , 0. , 0.3690765 ,\n",
+ " 0.2697782 , 0.13372795, 0. , 0. , 0. ],\n",
+ " [0. , 0.88346595, 0.2436127 , 0. , 0. ,\n",
+ " 0.41404647, 0.61342806, 1.4032906 , 0.00577256, 0.31158522]], dtype=float32)]\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# A real testing model\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "g1LoU8pa3-uK"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "from torchvision.models import resnet50\n",
+ "# Generates random input and targets data for the model, where `b` is\n",
+ "# batch size.\n",
+ "\n",
+ "def generate_data(b):\n",
+ " return (\n",
+ " torch.randn(b, 3, 128, 128).to(torch.float32),\n",
+ " )\n",
+ "\n",
+ "torch_inputs = generate_data(1)\n",
+ "torch_module = resnet50()\n",
+ "torch_module.eval()\n",
+ "torch_module = torch.jit.trace(torch_module, torch_inputs)\n",
+ "torch_outputs = [torch_module(*torch_inputs)]\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "p03lb4Ix4CvN"
+ },
+ "execution_count": 8,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from jaxonnxruntime.experimental import call_torch\n",
+ "import jax\n",
+ "jax_fn, jax_params = call_torch.call_torch(torch_module, torch_inputs)\n",
+ "jax_fn = jax.jit(jax_fn)\n",
+ "jax_inputs = jax.tree_map(call_torch.torch_tensor_to_np_array, torch_inputs)\n",
+ "jax_outputs = jax_fn(jax_params, jax_inputs)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "d1UB2by-4Q8R",
+ "outputId": "43f18570-8fce-46d3-b456-331e9cccb51a"
+ },
+ "execution_count": 9,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:825: UserWarning: no signature found for , skipping _decide_input_format\n",
+ " warnings.warn(f\"{e}, skipping _decide_input_format\")\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============\n",
+ "verbose: False, log level: Level.ERROR\n",
+ "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n",
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from jaxonnxruntime.experimental.call_torch import CallTorchTestCase\n",
+ "test_case = CallTorchTestCase()\n",
+ "test_case.assert_allclose(jax.tree_map(call_torch.torch_tensor_to_np_array,torch_outputs), jax_outputs, rtol=1e-07, atol=1e-03)\n"
+ ],
+ "metadata": {
+ "id": "GrNFj9z27FTC"
+ },
+ "execution_count": 10,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%timeit _ = torch_module(*torch_inputs)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "P-aPak7H9FkV",
+ "outputId": "930fb187-6fc0-42b8-dbbc-bd277b806a29"
+ },
+ "execution_count": 11,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "93.7 ms ± 23.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%timeit _ = jax_fn(jax_params, jax_inputs)\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "LAEbgDR59LPV",
+ "outputId": "813556b9-ca57-4f99-858c-ad9e0a234eb2"
+ },
+ "execution_count": 12,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "258 ms ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "aoQ9-ggb97KP"
+ },
+ "execution_count": 12,
+ "outputs": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ },
+ "colab": {
+ "provenance": [],
+ "include_colab_link": true
+ },
+ "accelerator": "TPU"
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
From 8b556ee4401fc06aed35d3432361927ad129fb89 Mon Sep 17 00:00:00 2001
From: John <1200986+maxwillzq@users.noreply.github.com>
Date: Wed, 23 Aug 2023 10:48:57 -0700
Subject: [PATCH 2/2] Created using Colaboratory
---
docs/experimental_call_torch_tutorial.ipynb | 296 ++------------------
1 file changed, 29 insertions(+), 267 deletions(-)
diff --git a/docs/experimental_call_torch_tutorial.ipynb b/docs/experimental_call_torch_tutorial.ipynb
index 4d93bd1..8e92b44 100644
--- a/docs/experimental_call_torch_tutorial.ipynb
+++ b/docs/experimental_call_torch_tutorial.ipynb
@@ -12,7 +12,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {
"id": "Ym4Xu6mWg2Na"
},
@@ -51,42 +51,10 @@
"!pip install git+https://github.com/google/jaxonnxruntime.git\n"
],
"metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "Kkfn_llxzdDa",
- "outputId": "43397188-5fd7-40e9-93cf-212e43b1c1ee"
+ "id": "Kkfn_llxzdDa"
},
- "execution_count": 2,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Collecting git+https://github.com/google/jaxonnxruntime.git\n",
- " Cloning https://github.com/google/jaxonnxruntime.git to /tmp/pip-req-build-yj76dq13\n",
- " Running command git clone --filter=blob:none --quiet https://github.com/google/jaxonnxruntime.git /tmp/pip-req-build-yj76dq13\n",
- " Resolved https://github.com/google/jaxonnxruntime.git to commit 2b95bac67150f865222a4bda87b66c099b64ae59\n",
- " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
- " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
- " Installing backend dependencies ... \u001b[?25l\u001b[?25hdone\n",
- " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from jaxonnxruntime==0.3.0) (1.23.5)\n",
- "Requirement already satisfied: jax in /usr/local/lib/python3.10/dist-packages (from jaxonnxruntime==0.3.0) (0.3.25)\n",
- "Requirement already satisfied: jaxlib in /usr/local/lib/python3.10/dist-packages (from jaxonnxruntime==0.3.0) (0.3.25)\n",
- "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax->jaxonnxruntime==0.3.0) (3.3.0)\n",
- "Requirement already satisfied: scipy>=1.5 in /usr/local/lib/python3.10/dist-packages (from jax->jaxonnxruntime==0.3.0) (1.10.1)\n",
- "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from jax->jaxonnxruntime==0.3.0) (4.7.1)\n",
- "Building wheels for collected packages: jaxonnxruntime\n",
- " Building wheel for jaxonnxruntime (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- " Created wheel for jaxonnxruntime: filename=jaxonnxruntime-0.3.0-py3-none-any.whl size=177972 sha256=caba7590e9938cbca5365c19348155ab3d42c130fc1b19c45953963089c7066f\n",
- " Stored in directory: /tmp/pip-ephem-wheel-cache-jf3xlcq6/wheels/43/d3/a6/40189cf2b24db631a69157a58da1d0fcc6d1df48abf847f8bd\n",
- "Successfully built jaxonnxruntime\n",
- "Installing collected packages: jaxonnxruntime\n",
- "Successfully installed jaxonnxruntime-0.3.0\n"
- ]
- }
- ]
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "code",
@@ -94,39 +62,10 @@
"!pip install onnx torch"
],
"metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "1GX2-ZCwzmPA",
- "outputId": "9439c8de-423a-4108-86f5-394d2da36913"
+ "id": "1GX2-ZCwzmPA"
},
- "execution_count": 3,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Collecting onnx\n",
- " Downloading onnx-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.6 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.6/14.6 MB\u001b[0m \u001b[31m33.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.0.1+cu118)\n",
- "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from onnx) (1.23.5)\n",
- "Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx) (3.20.3)\n",
- "Requirement already satisfied: typing-extensions>=3.6.2.1 in /usr/local/lib/python3.10/dist-packages (from onnx) (4.7.1)\n",
- "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.12.2)\n",
- "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n",
- "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.1)\n",
- "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n",
- "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.0.0)\n",
- "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (3.27.2)\n",
- "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (16.0.6)\n",
- "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.3)\n",
- "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n",
- "Installing collected packages: onnx\n",
- "Successfully installed onnx-1.14.0\n"
- ]
- }
- ]
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "markdown",
@@ -141,42 +80,11 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"metadata": {
- "id": "4RF7lO5ng2Nd",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "99878430-cc1b-4fd8-ebf6-cb53ce396c83"
+ "id": "4RF7lO5ng2Nd"
},
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "torch_output: tensor([[ 9.2530e-01, 4.5873e-01, 6.4589e-01, 5.6415e-01, 1.6490e+00,\n",
- " 8.2248e-01, 8.9589e-01, -5.7184e-01, 1.0232e+00, 5.1742e-01],\n",
- " [ 9.5014e-02, 1.3660e+00, 1.3621e+00, 9.0590e-01, 8.4838e-02,\n",
- " 1.2079e-01, 1.7469e+00, 1.1094e+00, 1.1676e+00, 1.6560e+00],\n",
- " [ 1.2260e+00, 1.7275e+00, 1.2192e+00, -9.5381e-01, 1.1586e+00,\n",
- " 1.3536e-01, -8.1443e-01, 3.7343e-01, 1.5365e+00, 1.2673e+00],\n",
- " [ 1.9366e+00, 4.1828e-01, 7.5243e-01, -2.6371e-01, -1.1587e-03,\n",
- " 1.8683e+00, 7.5635e-01, -6.5726e-01, 1.7267e+00, 1.4934e+00],\n",
- " [ 1.7448e-01, 1.2264e+00, 1.5650e+00, -1.1248e-01, 8.0965e-01,\n",
- " 6.4813e-01, 2.7031e-01, -2.6631e-01, 4.7319e-02, 3.2769e-02],\n",
- " [ 4.5132e-01, 1.9266e+00, 1.6291e+00, 4.6194e-01, -1.2171e-01,\n",
- " 1.8986e+00, 1.3777e-01, 4.9093e-01, -3.5940e-01, 8.6310e-01],\n",
- " [ 4.1073e-01, 7.4641e-01, 8.6454e-02, 8.2659e-01, 1.5467e+00,\n",
- " 9.5625e-01, -1.6194e-01, 1.4552e+00, 6.8996e-01, 1.0307e-01],\n",
- " [ 1.9547e+00, -5.6088e-01, -3.6236e-01, 6.7257e-01, 2.4588e-01,\n",
- " 1.6908e-01, 1.6637e+00, 1.3871e+00, 3.9015e-01, 9.6702e-01],\n",
- " [ 5.3747e-01, 4.6549e-02, 1.6271e+00, 1.7466e+00, 4.4292e-01,\n",
- " -9.8217e-01, 1.4013e-01, 6.7617e-03, -5.3100e-02, -1.1265e+00],\n",
- " [ 5.3585e-01, -1.8007e-02, 8.2986e-01, 1.8453e+00, 7.6978e-01,\n",
- " -1.0868e-02, 1.4607e+00, 4.8968e-01, 8.1474e-01, 1.1434e+00]])\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"import torch\n",
"import jax\n",
@@ -201,74 +109,10 @@
"print(\"jax_output:\", jax_fn(jax_params, jax_inputs))"
],
"metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "kEEVNkLP1Owa",
- "outputId": "96063361-a3e3-4b6b-97cb-62609fff171d"
+ "id": "kEEVNkLP1Owa"
},
- "execution_count": 5,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:825: UserWarning: no signature found for , skipping _decide_input_format\n",
- " warnings.warn(f\"{e}, skipping _decide_input_format\")\n",
- "WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
- ]
- },
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============\n",
- "verbose: False, log level: Level.ERROR\n",
- "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n",
- "\n",
- "jax_output: [DeviceArray([[ 9.2529660e-01, 4.5873073e-01, 6.4588535e-01,\n",
- " 5.6415093e-01, 1.6489711e+00, 8.2248122e-01,\n",
- " 8.9588922e-01, -5.7184368e-01, 1.0232372e+00,\n",
- " 5.1741624e-01],\n",
- " [ 9.5014393e-02, 1.3659939e+00, 1.3621219e+00,\n",
- " 9.0589929e-01, 8.4837854e-02, 1.2079075e-01,\n",
- " 1.7469316e+00, 1.1094404e+00, 1.1675845e+00,\n",
- " 1.6560377e+00],\n",
- " [ 1.2259831e+00, 1.7274783e+00, 1.2191784e+00,\n",
- " -9.5381111e-01, 1.1585734e+00, 1.3535953e-01,\n",
- " -8.1442708e-01, 3.7343296e-01, 1.5364575e+00,\n",
- " 1.2673373e+00],\n",
- " [ 1.9366202e+00, 4.1827971e-01, 7.5243413e-01,\n",
- " -2.6371196e-01, -1.1587143e-03, 1.8683007e+00,\n",
- " 7.5634706e-01, -6.5725970e-01, 1.7267224e+00,\n",
- " 1.4934300e+00],\n",
- " [ 1.7448190e-01, 1.2263532e+00, 1.5649725e+00,\n",
- " -1.1247611e-01, 8.0965269e-01, 6.4813310e-01,\n",
- " 2.7030846e-01, -2.6631176e-01, 4.7319174e-02,\n",
- " 3.2769233e-02],\n",
- " [ 4.5132229e-01, 1.9265618e+00, 1.6291361e+00,\n",
- " 4.6194053e-01, -1.2170833e-01, 1.8985560e+00,\n",
- " 1.3776600e-01, 4.9092823e-01, -3.5940361e-01,\n",
- " 8.6309528e-01],\n",
- " [ 4.1072559e-01, 7.4641228e-01, 8.6453676e-02,\n",
- " 8.2658666e-01, 1.5467417e+00, 9.5624632e-01,\n",
- " -1.6194339e-01, 1.4552250e+00, 6.8996412e-01,\n",
- " 1.0307056e-01],\n",
- " [ 1.9546964e+00, -5.6088006e-01, -3.6236224e-01,\n",
- " 6.7256629e-01, 2.4587882e-01, 1.6908441e-01,\n",
- " 1.6637435e+00, 1.3870629e+00, 3.9014500e-01,\n",
- " 9.6701825e-01],\n",
- " [ 5.3747320e-01, 4.6549439e-02, 1.6270701e+00,\n",
- " 1.7466159e+00, 4.4291818e-01, -9.8216683e-01,\n",
- " 1.4013010e-01, 6.7616701e-03, -5.3099811e-02,\n",
- " -1.1264541e+00],\n",
- " [ 5.3585029e-01, -1.8007398e-02, 8.2986158e-01,\n",
- " 1.8452730e+00, 7.6977569e-01, -1.0867842e-02,\n",
- " 1.4607116e+00, 4.8967782e-01, 8.1474060e-01,\n",
- " 1.1434039e+00]], dtype=float32)]\n"
- ]
- }
- ]
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "markdown",
@@ -282,7 +126,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {
"id": "zR31ReFKg2Ne"
},
@@ -309,45 +153,10 @@
"print(\"jax_output:\", jax_fn(jax_params, jax_inputs))"
],
"metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "H--BoG-E1d6w",
- "outputId": "6ff9f51c-0c7e-4ab1-b404-81e175587365"
+ "id": "H--BoG-E1d6w"
},
- "execution_count": 7,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============\n",
- "verbose: False, log level: Level.ERROR\n",
- "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n",
- "\n",
- "jax_output: [DeviceArray([[0.12982486, 0.481035 , 0.3622095 , 0.7990376 , 0.607528 ,\n",
- " 0. , 0. , 0.2921514 , 0.56446004, 0. ],\n",
- " [0. , 0.57425183, 0. , 0.9224024 , 0. ,\n",
- " 0.39311224, 0.11618385, 0. , 0.6319629 , 0.29966408],\n",
- " [0.49294975, 0.36862767, 0.2809724 , 0. , 0. ,\n",
- " 0.33000746, 0.4940635 , 0.02182353, 0. , 0. ],\n",
- " [0. , 0.3759548 , 0. , 0.23886062, 0. ,\n",
- " 0. , 0.35155773, 0. , 0.24100977, 0.15047333],\n",
- " [0.40354684, 0. , 0.12827398, 0. , 0.07742476,\n",
- " 0.6966675 , 0. , 0.00607699, 0. , 0.59710497],\n",
- " [0. , 0.5936287 , 0. , 0. , 0. ,\n",
- " 0.5815142 , 0.2761725 , 0.47168115, 0. , 0.26667207],\n",
- " [0.06918865, 0. , 0.43948644, 0. , 0.5058311 ,\n",
- " 0.09885295, 0.40746492, 0.30387542, 0.45276335, 0.10408825],\n",
- " [0.24983016, 0.14137569, 0. , 0.72815377, 0.8114418 ,\n",
- " 0.5519933 , 0. , 0. , 0.08975451, 0. ],\n",
- " [0.4778166 , 0. , 0.8714261 , 0. , 0.3690765 ,\n",
- " 0.2697782 , 0.13372795, 0. , 0. , 0. ],\n",
- " [0. , 0.88346595, 0.2436127 , 0. , 0. ,\n",
- " 0.41404647, 0.61342806, 1.4032906 , 0.00577256, 0.31158522]], dtype=float32)]\n"
- ]
- }
- ]
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "markdown",
@@ -382,7 +191,7 @@
"metadata": {
"id": "p03lb4Ix4CvN"
},
- "execution_count": 8,
+ "execution_count": null,
"outputs": []
},
{
@@ -396,33 +205,10 @@
"jax_outputs = jax_fn(jax_params, jax_inputs)"
],
"metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "d1UB2by-4Q8R",
- "outputId": "43f18570-8fce-46d3-b456-331e9cccb51a"
+ "id": "d1UB2by-4Q8R"
},
- "execution_count": 9,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py:825: UserWarning: no signature found for , skipping _decide_input_format\n",
- " warnings.warn(f\"{e}, skipping _decide_input_format\")\n"
- ]
- },
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============\n",
- "verbose: False, log level: Level.ERROR\n",
- "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n",
- "\n"
- ]
- }
- ]
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "code",
@@ -434,7 +220,7 @@
"metadata": {
"id": "GrNFj9z27FTC"
},
- "execution_count": 10,
+ "execution_count": null,
"outputs": []
},
{
@@ -443,22 +229,10 @@
"%timeit _ = torch_module(*torch_inputs)"
],
"metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "P-aPak7H9FkV",
- "outputId": "930fb187-6fc0-42b8-dbbc-bd277b806a29"
+ "id": "P-aPak7H9FkV"
},
- "execution_count": 11,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "93.7 ms ± 23.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
- ]
- }
- ]
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "code",
@@ -466,22 +240,10 @@
"%timeit _ = jax_fn(jax_params, jax_inputs)\n"
],
"metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "LAEbgDR59LPV",
- "outputId": "813556b9-ca57-4f99-858c-ad9e0a234eb2"
+ "id": "LAEbgDR59LPV"
},
- "execution_count": 12,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "258 ms ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
- ]
- }
- ]
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "code",
@@ -489,7 +251,7 @@
"metadata": {
"id": "aoQ9-ggb97KP"
},
- "execution_count": 12,
+ "execution_count": null,
"outputs": []
}
],