diff --git a/docs/source/modules/chop/passes_module.rst b/docs/source/modules/chop/passes_module.rst index 6a1ac7320..2dae348bf 100644 --- a/docs/source/modules/chop/passes_module.rst +++ b/docs/source/modules/chop/passes_module.rst @@ -5,7 +5,7 @@ chop.passes.module Summary of Mase Module Analysis Passes -------------------------------------- -.. list-table:: MASE module-level analysis passes +.. list-table:: MASE module-level analysis passes :widths: 20 40 40 :header-rows: 1 @@ -25,7 +25,7 @@ Summary of Mase Module Analysis Passes Summary of Mase Module Transform Passes --------------------------------------- -.. list-table:: MASE module-level transform passes +.. list-table:: MASE module-level transform passes :widths: 20 40 40 :header-rows: 1 @@ -35,9 +35,13 @@ Summary of Mase Module Transform Passes * - :py:meth:`~chop.passes.module.transforms.quantize.quantize_module_transform_pass` - `test_module_quantize `_ - Apply quantization transformation to the given nn.Module + * - :py:meth:`~chop.passes.module.transforms.onn.optical_transformer_module_transform_pass` + - See :doc:`transform/onn` + - Transform modules to Optical Neural Network (ONN) equivalents .. toctree:: :maxdepth: 2 :caption: Full list of module-level transform passes - module_transform/quantization \ No newline at end of file + module_transform/quantization + transform/onn \ No newline at end of file diff --git a/docs/source/modules/chop/transform/onn.rst b/docs/source/modules/chop/transform/onn.rst new file mode 100644 index 000000000..1750a97ff --- /dev/null +++ b/docs/source/modules/chop/transform/onn.rst @@ -0,0 +1,203 @@ +chop.passes.module.transforms.onn +================================= + +This module provides transformation passes for converting standard neural network +modules into Optical Neural Network (ONN) equivalents. The optical transformer +implementation is based on the `Optical Transformers paper `_. + +Optical neural networks leverage photonic hardware to perform matrix multiplications +with reduced power consumption. This transform simulates the quantization effects +and constraints of optical compute hardware, enabling model development and evaluation +before deployment on physical optical accelerators. + +.. note:: + + This module requires the ``mase-triton`` package to be installed. + Install via: ``pip install mase-triton`` + + +Transform Pass +-------------- + +optical\_transformer\_module\_transform\_pass +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: chop.passes.module.transforms.onn.optical_transformer_module_transform_pass + + +Configuration +------------- + +The transform pass accepts configuration through the ``pass_args`` dictionary. +Layer matching can be done by exact name or regex patterns. + +Example configuration: + +.. code-block:: python + + pass_args = { + "by": "regex_name", # or "name" for exact matching + "default": { + "q_levels": 256, + "q_lut_min": 0.020040, + "q_smooth_factor": 0.9, + "q_init_seed": 0, + "q_bypass": False, + }, + # Override for specific layers using regex + ".*mlp.*": { + "q_levels": 128, + "q_bypass": False, + }, + } + + +Configuration Parameters +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. 
list-table::
+   :header-rows: 1
+   :widths: 20 15 15 50
+
+   * - Parameter
+     - Type
+     - Default
+     - Description
+   * - ``q_levels``
+     - int
+     - 256
+     - Number of quantization levels for optical simulation
+   * - ``q_lut_min``
+     - float
+     - 0.020040
+     - Minimum value for the lookup table used in quantization
+   * - ``q_smooth_factor``
+     - float
+     - 0.9
+     - Exponential moving average factor for updating running statistics
+   * - ``q_init_seed``
+     - int
+     - 0
+     - Random seed for quantization noise initialization
+   * - ``q_bypass``
+     - bool
+     - False
+     - If True, bypass optical quantization (useful for debugging)
+
+
+Layers
+------
+
+OtLinear
+^^^^^^^^
+
+.. py:class:: chop.passes.module.transforms.onn.layers.linear.OtLinear
+
+    Optical Transformer Linear layer.
+
+    This is an alias of ``mase_triton.optical_compute.layers.OpticalTransformerLinear``.
+    It replaces standard ``torch.nn.Linear`` layers with quantized optical transformer
+    equivalents that simulate optical neural network hardware constraints.
+
+    The layer applies quantization to both the input activations and the weights during
+    matrix multiplication, and tracks running min/max statistics for calibration.
+
+    **Class method:**
+
+    .. py:method:: from_linear(linear, **kwargs)
+        :classmethod:
+
+        Create an OtLinear from an existing ``torch.nn.Linear`` layer.
+
+        :param linear: Source linear layer
+        :type linear: torch.nn.Linear
+        :param kwargs: Quantization parameters (q_levels, q_lut_min, etc.)
+        :return: Optical transformer linear layer with copied weights
+
+
+OtLlamaAttention
+^^^^^^^^^^^^^^^^
+
+.. autoclass:: chop.passes.module.transforms.onn.layers.attn.OtLlamaAttention
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+
+Functional API
+--------------
+
+optical\_transformer\_SDPA
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autofunction:: chop.passes.module.transforms.onn.layers.attn.optical_transformer_SDPA
+
+
+Usage Example
+-------------
+
+Basic usage with a LLaMA model:
+
+.. code-block:: python
+
+    from transformers import AutoModelForCausalLM
+    from chop.passes.module.transforms.onn import optical_transformer_module_transform_pass
+
+    # Load a pretrained model
+    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
+
+    # Define transformation configuration
+    pass_args = {
+        "by": "regex_name",
+        "default": {
+            "q_levels": 256,
+            "q_lut_min": 0.020040,
+            "q_smooth_factor": 0.9,
+            "q_init_seed": 0,
+            "q_bypass": False,
+        },
+    }
+
+    # Apply the optical transformer transform pass
+    model = optical_transformer_module_transform_pass(model, pass_args)
+
+    # The model now uses OtLinear and OtLlamaAttention layers
+    # Continue with training or inference as usual
+
+
+Selective Layer Transformation
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Transform only specific layers using regex patterns:
+
+.. code-block:: python
+
+    pass_args = {
+        "by": "regex_name",
+        # Only transform attention layers
+        ".*self_attn.*": {
+            "q_levels": 256,
+            "q_bypass": False,
+        },
+        # Transform MLP layers with different settings
+        ".*mlp.*": {
+            "q_levels": 128,
+            "q_bypass": False,
+        },
+    }
+
+
+Bypass Mode for Debugging
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Use ``q_bypass=True`` to disable quantization while keeping the module structure:
+
+.. code-block:: python
+
+    pass_args = {
+        "by": "regex_name",
+        "default": {
+            "q_levels": 256,
+            "q_bypass": True,  # Disable quantization
+        },
+    }
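+
+
+Verifying the Transformation
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+A quick way to confirm which modules were replaced is plain PyTorch
+introspection. This is a minimal sketch; it only assumes that the pass has
+already run on ``model`` and that ``OtLinear`` is the alias documented above:
+
+.. code-block:: python
+
+    from chop.passes.module.transforms.onn.layers.linear import OtLinear
+
+    # List every linear layer that was swapped for its optical equivalent
+    for name, module in model.named_modules():
+        if isinstance(module, OtLinear):
+            print(f"{name}: {type(module).__name__}")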
diff --git a/docs/tutorials/newcompute/onn/1-transform.ipynb b/docs/tutorials/newcompute/onn/1-transform.ipynb
new file mode 100644
index 000000000..5d04e5ca9
--- /dev/null
+++ b/docs/tutorials/newcompute/onn/1-transform.ipynb
@@ -0,0 +1,411 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "1a79ad55",
+   "metadata": {},
+   "source": [
+    "# Optical Transformer Transform Pass\n",
+    "\n",
+    "This tutorial gives a minimal walkthrough of the Optical Neural Network (ONN) transform pass and its layer classes in MASE.\n",
+    "\n",
+    "The optical transformer implementation is based on the [Optical Transformers paper](https://arxiv.org/abs/2302.10360).\n",
+    "\n",
+    "## Overview\n",
+    "\n",
+    "The ONN transform pass replaces standard PyTorch modules with their optical transformer equivalents:\n",
+    "\n",
+    "| Original Module | Optical Equivalent |\n",
+    "|-----------------|--------------------|\n",
+    "| `torch.nn.Linear` | `OtLinear` |\n",
+    "| `LlamaAttention` | `OtLlamaAttention` |\n",
+    "\n",
+    "## Requirements\n",
+    "\n",
+    "The `mase-triton` package is required for ONN transforms:\n",
+    "\n",
+    "```bash\n",
+    "pip install mase-triton\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "984ca459",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/zz7522/miniconda3/envs/mase/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "from transformers.models.llama.modeling_llama import LlamaAttention, LlamaConfig\n",
+    "\n",
+    "from chop.passes.module.transforms.onn.transform import (\n",
+    "    OtLinear,\n",
+    "    OtLlamaAttention,\n",
+    "    OtTransformConfig,\n",
+    "    optical_transformer_module_transform_pass,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0fc4e439",
+   "metadata": {},
+   "source": [
+    "## Configuration\n",
+    "\n",
+    "Use `OtTransformConfig` to configure the optical transform parameters:\n",
+    "\n",
+    "| Parameter | Type | Default | Description |\n",
+    "|-----------|------|---------|-------------|\n",
+    "| `q_levels` | int | 256 | Number of quantization levels, $2^n$ for n-bit quantization. 
|\n", + "| `q_lut_min` | float | 0.020040 | Minimum LUT value for quantization |\n", + "| `q_smooth_factor` | float | 0.9 | Smoothing factor for statistics updates in the training mode |\n", + "| `q_init_seed` | int | 0 | Random seed for initialization (only used in triton kernels) |\n", + "| `q_bypass` | bool | False | If True, bypass optical quantization |" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9a7c1bbd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Default ONN config: {'q_levels': 256, 'q_lut_min': 0.02004, 'q_smooth_factor': 0.9, 'q_init_seed': 0, 'q_bypass': False}\n", + "Modified ONN config: {'q_levels': 256, 'q_lut_min': 0.02004, 'q_smooth_factor': 0.1, 'q_init_seed': 0, 'q_bypass': False}\n" + ] + } + ], + "source": [ + "# Create default configuration\n", + "onn_config = OtTransformConfig.create_default()\n", + "print(\"Default ONN config:\", onn_config)\n", + "\n", + "# Customize configuration\n", + "onn_config[\"q_levels\"] = 256 # 8-bit quantization\n", + "onn_config[\"q_smooth_factor\"] = 0.1\n", + "print(\"Modified ONN config:\", onn_config)" + ] + }, + { + "cell_type": "markdown", + "id": "de082853", + "metadata": {}, + "source": [ + "## OtLinear: Optical Linear Layer\n", + "\n", + "`OtLinear` is the optical equivalent of `torch.nn.Linear`. It applies quantized matrix multiplication that simulates optical computing behavior." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "7e8b9c1d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original output shape: torch.Size([2, 64])\n", + "Optical output shape: torch.Size([2, 64])\n", + "Max absolute difference: 0.035205\n" + ] + } + ], + "source": [ + "# Create a standard linear layer\n", + "linear = torch.nn.Linear(in_features=32, out_features=64)\n", + "\n", + "# Convert to optical linear layer\n", + "onn_config = OtTransformConfig.create_default()\n", + "linear_onn = OtLinear.from_linear(linear, **onn_config)\n", + "\n", + "# Compare outputs\n", + "x = torch.randn(2, 32)\n", + "y = linear(x)\n", + "y_onn = linear_onn(x)\n", + "\n", + "print(f\"Original output shape: {y.shape}\")\n", + "print(f\"Optical output shape: {y_onn.shape}\")\n", + "print(f\"Max absolute difference: {(y - y_onn).abs().max().item():.6f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "5d4e368f", + "metadata": {}, + "source": [ + "## OtLlamaAttention: Optical Llama Attention\n", + "\n", + "`OtLlamaAttention` replaces the HuggingFace `LlamaAttention` with an optical-aware implementation that uses quantized scaled dot-product attention." 
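,
+    "\n",
+    "The underlying quantized scaled dot-product attention is also exposed as a standalone function, `optical_transformer_SDPA`, in `chop.passes.module.transforms.onn.layers.attn` (see the Functional API section of the module docs)."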
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "6e7a6261",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Original output shape: torch.Size([1, 16, 384])\n",
+      "Optical output shape: torch.Size([1, 16, 384])\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Setup Llama configuration\n",
+    "model_name = \"AICrossSim/clm-60m\"\n",
+    "hf_config = LlamaConfig.from_pretrained(model_name)\n",
+    "\n",
+    "batch_size = 1\n",
+    "seq_len = 16\n",
+    "head_dim = hf_config.hidden_size // hf_config.num_attention_heads\n",
+    "\n",
+    "# Create standard attention layer\n",
+    "attn = LlamaAttention(config=hf_config, layer_idx=0)\n",
+    "\n",
+    "# Convert to optical attention\n",
+    "onn_config = OtTransformConfig.create_default()\n",
+    "onn_config[\"q_levels\"] = 512\n",
+    "attn_onn = OtLlamaAttention.from_pretrained(attn, layer_idx=0, **onn_config)\n",
+    "\n",
+    "# Test forward pass\n",
+    "pos_emb = torch.ones(batch_size, seq_len, head_dim)\n",
+    "x = 3 * torch.randn(batch_size, seq_len, hf_config.hidden_size)\n",
+    "\n",
+    "y, _ = attn(x, (pos_emb, pos_emb), None)\n",
+    "attn_onn.train()  # Enable statistics updates\n",
+    "y_onn, _ = attn_onn(x, (pos_emb, pos_emb), None)\n",
+    "\n",
+    "print(f\"Original output shape: {y.shape}\")\n",
+    "print(f\"Optical output shape: {y_onn.shape}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1ac49d15",
+   "metadata": {},
+   "source": [
+    "## Transform Pass: Network-Level Transformation\n",
+    "\n",
+    "Use `optical_transformer_module_transform_pass` to transform an entire network. The pass replaces modules based on name matching.\n",
+    "\n",
+    "### Pass Arguments\n",
+    "\n",
+    "| Key | Description |\n",
+    "|-----|-------------|\n",
+    "| `by` | Matching mode: `\"name\"` (exact) or `\"regex_name\"` (regex pattern) |\n",
+    "| `<name or pattern>` | Configuration dict for layers matching the name/pattern |\n",
+    "| `default` | Fallback configuration if no pattern matches |"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "bdee282d",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Original network:\n",
+      "SimpleNetwork(\n",
+      "  (attn): LlamaAttention(\n",
+      "    (q_proj): Linear(in_features=384, out_features=384, bias=False)\n",
+      "    (k_proj): Linear(in_features=384, out_features=128, bias=False)\n",
+      "    (v_proj): Linear(in_features=384, out_features=128, bias=False)\n",
+      "    (o_proj): Linear(in_features=384, out_features=384, bias=False)\n",
+      "  )\n",
+      "  (linear): Linear(in_features=384, out_features=384, bias=True)\n",
+      ")\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Define a simple network with attention and linear layers\n",
+    "class SimpleNetwork(torch.nn.Module):\n",
+    "    def __init__(self, hf_config):\n",
+    "        super().__init__()\n",
+    "        self.attn = LlamaAttention(config=hf_config, layer_idx=0)\n",
+    "        self.linear = torch.nn.Linear(\n",
+    "            in_features=hf_config.hidden_size,\n",
+    "            out_features=hf_config.hidden_size,\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, x, pos_emb):\n",
+    "        attn_output, _ = self.attn(x, (pos_emb, pos_emb), None)\n",
+    "        output = self.linear(attn_output)\n",
+    "        return output\n",
+    "\n",
+    "network = SimpleNetwork(hf_config)\n",
+    "print(\"Original network:\")\n",
+    "print(network)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "ffcb8d3d",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 
'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 'seed'] from LlamaAttention(\n", + " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", + " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", + " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", + " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", + ") to OtLlamaAttention(\n", + " (q_proj): Linear(in_features=384, out_features=384, bias=False)\n", + " (k_proj): Linear(in_features=384, out_features=128, bias=False)\n", + " (v_proj): Linear(in_features=384, out_features=128, bias=False)\n", + " (o_proj): Linear(in_features=384, out_features=384, bias=False)\n", + ")\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=True) to OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Transformed network:\n", + "SimpleNetwork(\n", + " (attn): OtLlamaAttention(\n", + " (q_proj): OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + " (k_proj): OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + " (v_proj): 
OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + " (o_proj): OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + " )\n", + " (linear): OpticalTransformerLinear(q_bypass=False, q_levels=512, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + ")\n" + ] + } + ], + "source": [ + "# Configure the transform pass with regex patterns\n", + "onn_config = OtTransformConfig.create_default()\n", + "onn_config[\"q_levels\"] = 512\n", + "\n", + "pass_args = {\n", + " \"by\": \"regex_name\", # Use regex matching\n", + " \"attn\": onn_config, # Transform the attention layer\n", + " \"linear\": onn_config, # Transform the linear layer\n", + " r\"attn\\.(q|k|v|o)_proj\": onn_config, # Transform Q/K/V/O projections inside attention\n", + "}\n", + "\n", + "# Apply the transform\n", + "network_onn = optical_transformer_module_transform_pass(network, pass_args)\n", + "\n", + "print(\"\\nTransformed network:\")\n", + "print(network_onn)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "3c3557ec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Verification:\n", + " attn is OtLlamaAttention: True\n", + " linear is OtLinear: True\n", + " attn.q_proj is OtLinear: True\n", + " attn.k_proj is OtLinear: True\n", + " attn.v_proj is OtLinear: True\n", + " attn.o_proj is OtLinear: True\n" + ] + } + ], + "source": [ + "# Verify the transformation\n", + "print(\"Verification:\")\n", + "print(f\" attn is OtLlamaAttention: {isinstance(network_onn.attn, OtLlamaAttention)}\")\n", + "print(f\" linear is OtLinear: {isinstance(network_onn.linear, OtLinear)}\")\n", + "print(f\" attn.q_proj is OtLinear: {isinstance(network_onn.attn.q_proj, OtLinear)}\")\n", + "print(f\" attn.k_proj is OtLinear: {isinstance(network_onn.attn.k_proj, OtLinear)}\")\n", + "print(f\" attn.v_proj is OtLinear: {isinstance(network_onn.attn.v_proj, OtLinear)}\")\n", + "print(f\" attn.o_proj is OtLinear: {isinstance(network_onn.attn.o_proj, OtLinear)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "355b1f3c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output shape: torch.Size([1, 16, 384])\n", + "Max output error: 0.029137\n", + "Output is finite: True\n" + ] + } + ], + "source": [ + "# Test the transformed network\n", + "network_onn.train() # Enable statistics updates\n", + "\n", + "pos_emb = torch.ones(batch_size, seq_len, head_dim)\n", + "x = 3 * torch.randn(batch_size, seq_len, hf_config.hidden_size)\n", + "\n", + "y = network(x, pos_emb)\n", + "y_onn = network_onn(x, pos_emb)\n", + "print(f\"Output shape: {y_onn.shape}\")\n", + "print(f\"Max output error: {(y - y_onn).abs().max().item():.6f}\")\n", + "print(f\"Output is finite: {y_onn.isfinite().all().item()}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mase", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + 
"mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/tutorials/newcompute/onn/2-finetuning.ipynb b/docs/tutorials/newcompute/onn/2-finetuning.ipynb new file mode 100644 index 000000000..10aaf438f --- /dev/null +++ b/docs/tutorials/newcompute/onn/2-finetuning.ipynb @@ -0,0 +1,1224 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6ba64c6d", + "metadata": {}, + "source": [ + "# Fine-tuning HuggingFace Llama with Optical Transformer Transform\n", + "\n", + "This tutorial demonstrates how to:\n", + "1. Load a pretrained HuggingFace Llama model\n", + "2. Transform it using the optical transformer pass from MASE\n", + "3. Run continual fine-tuning on the transformed model\n", + "\n", + "## Overview\n", + "\n", + "The optical transformer transform replaces standard PyTorch modules with their optical equivalents that simulate optical computing behavior. This enables:\n", + "- Quantized matrix multiplication that models optical hardware\n", + "- Noise-aware training for robust optical neural network deployment\n", + "\n", + "## Requirements\n", + "\n", + "```bash\n", + "pip install mase-triton transformers datasets accelerate\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "943f547b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/zz7522/miniconda3/envs/mase/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import torch\n", + "from transformers import (\n", + " AutoConfig,\n", + " AutoModelForCausalLM,\n", + " AutoTokenizer,\n", + " get_scheduler,\n", + " default_data_collator,\n", + ")\n", + "from datasets import load_dataset\n", + "from torch.utils.data import DataLoader\n", + "from itertools import chain\n", + "from tqdm.auto import tqdm\n", + "\n", + "from chop.passes.module.transforms.onn.transform import (\n", + " OtLinear,\n", + " OtLlamaAttention,\n", + " OtTransformConfig,\n", + " optical_transformer_module_transform_pass,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "dea4c060", + "metadata": {}, + "source": [ + "## 1. Load Pretrained HuggingFace Llama Model\n", + "\n", + "We'll use a small Llama model for demonstration. You can replace this with any Llama-based model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "38b797b7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cuda\n" + ] + } + ], + "source": [ + "# Model configuration\n", + "MODEL_NAME = \"AICrossSim/clm-60m\" # Small Llama model for demo\n", + "BLOCK_SIZE = 128 # Sequence length (use smaller value for demo)\n", + "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "print(f\"Using device: {DEVICE}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8b6ed596", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model type: LlamaForCausalLM\n", + "Number of parameters: 82,101,120\n" + ] + } + ], + "source": [ + "# Load model and tokenizer\n", + "config = AutoConfig.from_pretrained(MODEL_NAME)\n", + "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " MODEL_NAME,\n", + " config=config,\n", + " attn_implementation=\"eager\", # Use eager attention for compatibility\n", + ")\n", + "\n", + "print(f\"Model type: {type(model).__name__}\")\n", + "print(f\"Number of parameters: {sum(p.numel() for p in model.parameters()):,}\")" + ] + }, + { + "cell_type": "markdown", + "id": "e99e16d4", + "metadata": {}, + "source": [ + "## 2. Configure and Apply Optical Transform\n", + "\n", + "We configure the optical transform with quantization parameters and apply it to the model." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "968be7f6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ONN Configuration:\n", + " q_levels: 256\n", + " q_lut_min: 0.02004\n", + " q_smooth_factor: 0.9\n", + " q_init_seed: 0\n", + " q_bypass: False\n" + ] + } + ], + "source": [ + "# Create ONN configuration\n", + "onn_config = OtTransformConfig.create_default()\n", + "\n", + "# Customize configuration (optional)\n", + "onn_config[\"q_levels\"] = 256 # Number of quantization levels\n", + "onn_config[\"q_lut_min\"] = 0.020040 # Minimum LUT value\n", + "onn_config[\"q_smooth_factor\"] = 0.9 # Statistics smoothing factor\n", + "onn_config[\"q_bypass\"] = False # Set to True to bypass optical quantization\n", + "\n", + "print(\"ONN Configuration:\")\n", + "for k, v in onn_config.items():\n", + " print(f\" {k}: {v}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "c4c54884", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transform pass arguments configured\n" + ] + } + ], + "source": [ + "# Configure the transform pass\n", + "# Use regex patterns to match layer names for transformation\n", + "pass_args = {\n", + " \"by\": \"regex_name\",\n", + " # Transform all attention layers\n", + " r\"model\\.layers\\.\\d+\\.self_attn\": onn_config,\n", + " # Transform attention projections (Q, K, V, O)\n", + " r\"model\\.layers\\.\\d+\\.self_attn\\.(q|k|v|o)_proj\": onn_config,\n", + " # Transform MLP layers\n", + " r\"model\\.layers\\.\\d+\\.mlp\\.(gate|up|down)_proj\": onn_config,\n", + "}\n", + "\n", + "print(\"Transform pass arguments configured\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "183195e2", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Missing keys when loading state_dict: ['query_min_max', 'key_min_max', 'qk_min_max', 'attn_min_max', 'value_min_max', 'av_min_max', 
'seed'] from LlamaAttention(\n",
+      "  (q_proj): Linear(in_features=384, out_features=384, bias=False)\n",
+      "  (k_proj): Linear(in_features=384, out_features=128, bias=False)\n",
+      "  (v_proj): Linear(in_features=384, out_features=128, bias=False)\n",
+      "  (o_proj): Linear(in_features=384, out_features=384, bias=False)\n",
+      ") to OtLlamaAttention(\n",
+      "  (q_proj): Linear(in_features=384, out_features=384, bias=False)\n",
+      "  (k_proj): Linear(in_features=384, out_features=128, bias=False)\n",
+      "  (v_proj): Linear(in_features=384, out_features=128, bias=False)\n",
+      "  (o_proj): Linear(in_features=384, out_features=384, bias=False)\n",
+      ")\n",
+      "[... identical warnings for the remaining attention layers omitted ...]\n",
+      "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n",
+      "[... analogous warnings for the remaining q/k/v/o attention projections and gate/up/down MLP projections omitted ...]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Transforming model...\n"
+     ]
+    }
+   ],
'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, 
q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, 
-inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 
'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, 
q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, 
-inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 
'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, 
q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), 
w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: 
['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to 
OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], 
x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=128, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when 
loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=384, out_features=1408, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n", + "WARNING:root:Missing keys when loading state_dict: ['q_min_max_quantile', 'x_min_max', 'w_min_max', 'out_min_max', 'seed'] from Linear(in_features=1408, out_features=384, bias=False) to OpticalTransformerLinear(q_bypass=False, q_levels=256, q_lut_min=0.02004, q_quantiles=[0.0010000000474974513, 0.9990000128746033], x_min_max=tensor([inf, -inf]), w_min_max=tensor([inf, -inf]), out_min_max=tensor([inf, -inf]), seed=0)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model transformed successfully!\n" + ] + } + ], + "source": [ + "# Apply the optical transformer transform\n", + "print(\"Transforming model...\")\n", + "model = optical_transformer_module_transform_pass(model, pass_args)\n", + "print(\"Model transformed successfully!\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "57dd533d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transformed layers:\n", + " OtLinear: 154\n", + " OtLlamaAttention: 22\n" + ] + } + ], + "source": [ + "# Verify the transformation\n", + "def count_transformed_layers(model):\n", + " ot_linear_count = 0\n", + " ot_attn_count = 0\n", + " for name, module in model.named_modules():\n", + " if isinstance(module, OtLinear):\n", + " ot_linear_count += 1\n", + " elif isinstance(module, OtLlamaAttention):\n", + " ot_attn_count += 1\n", + " return ot_linear_count, ot_attn_count\n", + "\n", + "ot_linear, ot_attn = count_transformed_layers(model)\n", + "print(f\"Transformed layers:\")\n", + "print(f\" OtLinear: {ot_linear}\")\n", + "print(f\" OtLlamaAttention: {ot_attn}\")" + ] + }, + { + "cell_type": "markdown", + "id": "066772ea", + "metadata": {}, + "source": [ + "## 3. Prepare Dataset for Fine-tuning\n", + "\n", + "We'll use a small subset of a text dataset for demonstration." 
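,
+    "\n",
+    "As a quick sanity check on the chunking done by `group_texts` below (hypothetical numbers; the real `BLOCK_SIZE` is whatever was set in the setup cells):\n",
+    "\n",
+    "```python\n",
+    "# hypothetical: 1000 concatenated tokens with BLOCK_SIZE = 128\n",
+    "total = (1000 // 128) * 128   # 896 tokens kept; the remainder is dropped\n",
+    "chunks = total // 128         # 7 fixed-length training examples\n",
+    "```\n",
+    "\n",
+    "Labels are a plain copy of `input_ids`; the causal LM shifts them internally when computing the loss."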
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a52002e3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train samples: 36718\n", + "Validation samples: 3760\n" + ] + } + ], + "source": [ + "# Load a small dataset for demonstration\n", + "raw_datasets = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\")\n", + "\n", + "print(f\"Train samples: {len(raw_datasets['train'])}\")\n", + "print(f\"Validation samples: {len(raw_datasets['validation'])}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "a12cb5b5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenization complete\n" + ] + } + ], + "source": [ + "# Tokenize the dataset\n", + "def tokenize_function(examples):\n", + " return tokenizer(examples[\"text\"])\n", + "\n", + "tokenized_datasets = raw_datasets.map(\n", + " tokenize_function,\n", + " batched=True,\n", + " remove_columns=[\"text\"],\n", + " desc=\"Tokenizing\",\n", + ")\n", + "\n", + "print(\"Tokenization complete\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "05df197a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train chunks: 19848\n", + "Validation chunks: 2074\n" + ] + } + ], + "source": [ + "# Group texts into chunks of block_size\n", + "def group_texts(examples):\n", + " # Concatenate all texts\n", + " concatenated_examples = {k: list(chain(*examples[k])) for k in examples}\n", + " total_length = len(concatenated_examples[list(examples.keys())[0]])\n", + " # Drop the remainder\n", + " total_length = (total_length // BLOCK_SIZE) * BLOCK_SIZE\n", + " # Split into chunks\n", + " result = {\n", + " k: [t[i : i + BLOCK_SIZE] for i in range(0, total_length, BLOCK_SIZE)]\n", + " for k, t in concatenated_examples.items()\n", + " }\n", + " result[\"labels\"] = result[\"input_ids\"].copy()\n", + " return result\n", + "\n", + "lm_datasets = tokenized_datasets.map(\n", + " group_texts,\n", + " batched=True,\n", + " desc=f\"Grouping texts in chunks of {BLOCK_SIZE}\",\n", + ")\n", + "\n", + "print(f\"Train chunks: {len(lm_datasets['train'])}\")\n", + "print(f\"Validation chunks: {len(lm_datasets['validation'])}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "cee790dd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train batches: 4962\n", + "Eval batches: 519\n" + ] + } + ], + "source": [ + "# Create DataLoaders\n", + "BATCH_SIZE = 4\n", + "\n", + "train_dataloader = DataLoader(\n", + " lm_datasets[\"train\"],\n", + " shuffle=True,\n", + " collate_fn=default_data_collator,\n", + " batch_size=BATCH_SIZE,\n", + ")\n", + "\n", + "eval_dataloader = DataLoader(\n", + " lm_datasets[\"validation\"],\n", + " collate_fn=default_data_collator,\n", + " batch_size=BATCH_SIZE,\n", + ")\n", + "\n", + "print(f\"Train batches: {len(train_dataloader)}\")\n", + "print(f\"Eval batches: {len(eval_dataloader)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "9796a8b2", + "metadata": {}, + "source": [ + "## 4. Setup Training\n", + "\n", + "Configure optimizer, scheduler, and training parameters." 
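,
+    "\n",
+    "For reference, the `linear` schedule configured below warms the learning rate up over `WARMUP_STEPS` steps and then decays it linearly to zero at `MAX_TRAIN_STEPS`. A sketch of the resulting learning rate (values taken from the training-configuration cell below):\n",
+    "\n",
+    "```python\n",
+    "def lr_at(step, warmup=10, total=100, base=2e-4):\n",
+    "    # linear warmup followed by linear decay to zero\n",
+    "    if step < warmup:\n",
+    "        return base * step / warmup\n",
+    "    return base * max(0.0, (total - step) / (total - warmup))\n",
+    "```"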
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "d69f8a60", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training configuration:\n", + " Learning rate: 0.0002\n", + " Weight decay: 0.01\n", + " Max train steps: 100\n" + ] + } + ], + "source": [ + "# Training hyperparameters\n", + "LEARNING_RATE = 2e-4\n", + "WEIGHT_DECAY = 0.01\n", + "NUM_EPOCHS = 1 # Use 1 epoch for demo\n", + "MAX_TRAIN_STEPS = 100 # Limit steps for demo\n", + "WARMUP_STEPS = 10\n", + "\n", + "print(\"Training configuration:\")\n", + "print(f\" Learning rate: {LEARNING_RATE}\")\n", + "print(f\" Weight decay: {WEIGHT_DECAY}\")\n", + "print(f\" Max train steps: {MAX_TRAIN_STEPS}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "0e2230cd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Trainable parameters: 82,101,120\n" + ] + } + ], + "source": [ + "# Move model to device\n", + "model = model.to(DEVICE)\n", + "\n", + "# Set all parameters trainable\n", + "for param in model.parameters():\n", + " param.requires_grad = True\n", + "\n", + "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "print(f\"Trainable parameters: {trainable_params:,}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "e2dd98fe", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Optimizer configured\n" + ] + } + ], + "source": [ + "# Setup optimizer with weight decay\n", + "no_decay = [\"bias\", \"layer_norm.weight\"]\n", + "optimizer_grouped_parameters = [\n", + " {\n", + " \"params\": [\n", + " p for n, p in model.named_parameters()\n", + " if not any(nd in n for nd in no_decay) and p.requires_grad\n", + " ],\n", + " \"weight_decay\": WEIGHT_DECAY,\n", + " },\n", + " {\n", + " \"params\": [\n", + " p for n, p in model.named_parameters()\n", + " if any(nd in n for nd in no_decay) and p.requires_grad\n", + " ],\n", + " \"weight_decay\": 0.0,\n", + " },\n", + "]\n", + "\n", + "optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE)\n", + "print(\"Optimizer configured\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "0a8803df", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LR scheduler configured\n" + ] + } + ], + "source": [ + "# Setup learning rate scheduler\n", + "lr_scheduler = get_scheduler(\n", + " name=\"linear\",\n", + " optimizer=optimizer,\n", + " num_warmup_steps=WARMUP_STEPS,\n", + " num_training_steps=MAX_TRAIN_STEPS,\n", + ")\n", + "\n", + "print(\"LR scheduler configured\")" + ] + }, + { + "cell_type": "markdown", + "id": "dfe3d417", + "metadata": {}, + "source": [ + "## 5. Training Loop\n", + "\n", + "Run the fine-tuning loop with the transformed optical model." + ] + }, + { + "cell_type": "markdown", + "id": "f58d6dd2", + "metadata": {}, + "source": [ + "### Quantization Statistics Warmup\n", + "\n", + "**Important:** The optical transformer layers require calibration of their quantization statistics (min/max values) before they can work correctly. Without this warmup:\n", + "- The statistics are initialized to `[inf, -inf]`\n", + "- The quantized matmul operations produce NaN values\n", + "- Loss and perplexity become NaN\n", + "\n", + "We run a few forward passes in **training mode** to let the layers collect statistics from the data." 
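,
+    "\n",
+    "Once the warmup cell below has run, a quick check can confirm that calibration took effect; the `x_min_max` buffer referenced here is the one visible in the layer repr printed in the warnings above:\n",
+    "\n",
+    "```python\n",
+    "# sanity check: every OtLinear should report finite input ranges after warmup\n",
+    "for name, module in model.named_modules():\n",
+    "    if isinstance(module, OtLinear) and not torch.isfinite(module.x_min_max).all():\n",
+    "        print(f'still uncalibrated: {name}')\n",
+    "```"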
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "3653ac53", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running warmup to initialize quantization statistics...\n", + " Warmup batch 1/5\n", + " Warmup batch 2/5\n", + " Warmup batch 3/5\n", + " Warmup batch 4/5\n", + " Warmup batch 5/5\n", + "Warmup complete! Quantization statistics initialized.\n" + ] + } + ], + "source": [ + "# Warmup: Run a few forward passes in training mode to initialize quantization statistics\n", + "# This is necessary because the optical transformer layers need to calibrate their\n", + "# min/max statistics before they can perform quantized operations correctly.\n", + "\n", + "print(\"Running warmup to initialize quantization statistics...\")\n", + "model.train() # Must be in training mode to update stats\n", + "num_warmup_batches = 5\n", + "\n", + "with torch.no_grad(): # No need for gradients during warmup\n", + " for i, batch in enumerate(train_dataloader):\n", + " if i >= num_warmup_batches:\n", + " break\n", + " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", + " _ = model(**batch)\n", + " print(f\" Warmup batch {i+1}/{num_warmup_batches}\")\n", + "\n", + "print(\"Warmup complete! Quantization statistics initialized.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "2300b1a1", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "\n", + "def evaluate(model, eval_dataloader, device):\n", + " \"\"\"Evaluate model and return perplexity.\"\"\"\n", + " model.eval()\n", + " losses = []\n", + " for batch in eval_dataloader:\n", + " batch = {k: v.to(device) for k, v in batch.items()}\n", + " with torch.no_grad():\n", + " outputs = model(**batch)\n", + " losses.append(outputs.loss.item())\n", + "\n", + " avg_loss = sum(losses) / len(losses)\n", + " perplexity = math.exp(avg_loss)\n", + " return avg_loss, perplexity" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "dd599352", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluating before training...\n", + "Initial - Loss: 7.0303, Perplexity: 1130.35\n" + ] + } + ], + "source": [ + "# Evaluate before training\n", + "print(\"Evaluating before training...\")\n", + "eval_loss, eval_ppl = evaluate(model, eval_dataloader, DEVICE)\n", + "print(f\"Initial - Loss: {eval_loss:.4f}, Perplexity: {eval_ppl:.2f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "c9132b0a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Starting training...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Training: 100%|██████████| 100/100 [00:24<00:00, 4.15it/s, loss=5.9711]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Training completed! 
Steps: 100\n" + ] + } + ], + "source": [ + "# Training loop\n", + "print(\"\\nStarting training...\")\n", + "model.train()\n", + "completed_steps = 0\n", + "train_losses = []\n", + "\n", + "progress_bar = tqdm(range(MAX_TRAIN_STEPS), desc=\"Training\")\n", + "\n", + "for epoch in range(NUM_EPOCHS):\n", + " for step, batch in enumerate(train_dataloader):\n", + " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n", + "\n", + " # Forward pass\n", + " outputs = model(**batch)\n", + " loss = outputs.loss\n", + " train_losses.append(loss.item())\n", + "\n", + " # Backward pass\n", + " loss.backward()\n", + " optimizer.step()\n", + " lr_scheduler.step()\n", + " optimizer.zero_grad()\n", + "\n", + " progress_bar.update(1)\n", + " completed_steps += 1\n", + "\n", + " # Log progress\n", + " if completed_steps % 20 == 0:\n", + " avg_loss = sum(train_losses[-20:]) / min(20, len(train_losses))\n", + " progress_bar.set_postfix({\"loss\": f\"{avg_loss:.4f}\"})\n", + "\n", + " if completed_steps >= MAX_TRAIN_STEPS:\n", + " break\n", + "\n", + " if completed_steps >= MAX_TRAIN_STEPS:\n", + " break\n", + "\n", + "print(f\"\\nTraining completed! Steps: {completed_steps}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "18251766", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Evaluating after training...\n", + "Final - Loss: 5.9571, Perplexity: 386.50\n" + ] + } + ], + "source": [ + "# Evaluate after training\n", + "print(\"\\nEvaluating after training...\")\n", + "eval_loss, eval_ppl = evaluate(model, eval_dataloader, DEVICE)\n", + "print(f\"Final - Loss: {eval_loss:.4f}, Perplexity: {eval_ppl:.2f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "36f3e09b", + "metadata": {}, + "source": [ + "## 6. Save the Fine-tuned Model (Optional)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "472d0991", + "metadata": {}, + "outputs": [], + "source": [ + "# Uncomment to save the model\n", + "# OUTPUT_DIR = \"./ot-llama-finetuned\"\n", + "# model.save_pretrained(OUTPUT_DIR)\n", + "# tokenizer.save_pretrained(OUTPUT_DIR)\n", + "# print(f\"Model saved to {OUTPUT_DIR}\")" + ] + }, + { + "cell_type": "markdown", + "id": "24de471e", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "This notebook demonstrated:\n", + "\n", + "1. **Loading a HuggingFace Llama model** using `AutoModelForCausalLM`\n", + "2. **Configuring the optical transform** with `OtTransformConfig`\n", + "3. **Applying the transform pass** using `optical_transformer_module_transform_pass`\n", + "4. **Preparing a dataset** for causal language modeling\n", + "5. 
**Running fine-tuning** with the transformed optical model\n", + "\n", + "### Key Points\n", + "\n", + "- Use `attn_implementation=\"eager\"` when loading the model for compatibility\n", + "- The transform pass uses regex patterns to match layer names\n", + "- Training mode (`model.train()`) enables statistics updates in optical layers\n", + "- The optical quantization adds noise that the model learns to be robust against\n", + "\n", + "### References\n", + "\n", + "- [Optical Transformers Paper](https://arxiv.org/abs/2302.10360)\n", + "- MASE ONN Transform: `src/chop/passes/module/transforms/onn/`" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mase", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/tutorials/newcompute/onn/README.md b/docs/tutorials/newcompute/onn/README.md new file mode 100644 index 000000000..a2d813af3 --- /dev/null +++ b/docs/tutorials/newcompute/onn/README.md @@ -0,0 +1,138 @@ +# Optical Neural Network (ONN) Transform API + +This module provides tools for transforming PyTorch neural networks to simulate optical computing behavior, based on the [Optical Transformers paper](https://arxiv.org/abs/2302.10360). + +## Installation + +```bash +pip install mase-triton +``` + +## Quick Start + +```python +from chop.passes.module.transforms.onn.transform import ( + OtTransformConfig, + optical_transformer_module_transform_pass, +) + +# Create configuration +config = OtTransformConfig.create_default() + +# Transform a model +pass_args = { + "by": "regex_name", + r"model\.layers\.\d+\.self_attn": config, + r"model\.layers\.\d+\.mlp\..*_proj": config, +} +model = optical_transformer_module_transform_pass(model, pass_args) +``` + +## API Reference + +### `OtTransformConfig` + +Configuration dictionary for optical transform parameters. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `q_levels` | int | 256 | Number of quantization levels ($2^n$ for n-bit) | +| `q_lut_min` | float | 0.020040 | Minimum LUT value for quantization | +| `q_smooth_factor` | float | 0.9 | Smoothing factor for statistics updates | +| `q_init_seed` | int | 0 | Random seed for Triton kernels | +| `q_bypass` | bool | False | Bypass optical quantization if True | + +```python +# Create default config +config = OtTransformConfig.create_default() + +# Customize +config["q_levels"] = 512 # 9-bit quantization +config["q_smooth_factor"] = 0.1 +``` + +### `optical_transformer_module_transform_pass` + +Transform supported modules in a network to their optical equivalents. 
+
+```python
+optical_transformer_module_transform_pass(network, pass_args) -> torch.nn.Module
+```
+
+**Parameters:**
+- `network`: The PyTorch model to transform
+- `pass_args`: Configuration dictionary with:
+  - `by`: Matching mode - `"name"` (exact) or `"regex_name"` (regex pattern)
+  - Layer patterns mapped to `OtTransformConfig` dicts
+  - `default`: Optional fallback config
+
+**Supported Transformations:**
+
+| Original Module | Optical Equivalent |
+|-----------------|-------------------|
+| `torch.nn.Linear` | `OtLinear` |
+| `LlamaAttention` | `OtLlamaAttention` |
+
+### `OtLinear`
+
+Optical equivalent of `torch.nn.Linear` with quantized matrix multiplication.
+
+```python
+from chop.passes.module.transforms.onn.transform import OtLinear
+
+# Convert from existing linear layer
+linear_onn = OtLinear.from_linear(linear, **config)
+```
+
+### `OtLlamaAttention`
+
+Optical equivalent of HuggingFace's `LlamaAttention` with quantized scaled dot-product attention.
+
+```python
+from chop.passes.module.transforms.onn.transform import OtLlamaAttention
+
+# Convert from existing attention layer; `layer_idx` is a required argument
+attn_onn = OtLlamaAttention.from_pretrained(attn, layer_idx=attn.layer_idx, **config)
+```
+
+## Important Notes
+
+### Quantization Statistics Warmup
+
+Optical layers require calibration before use. Run a few forward passes in **training mode** first:
+
+```python
+model.train()
+with torch.no_grad():
+    for batch in warmup_batches:
+        _ = model(**batch)
+```
+
+Without warmup, statistics are `[inf, -inf]` and outputs will be NaN.
+
+### Training vs Evaluation Mode
+
+- **Training mode** (`model.train()`): Statistics are updated with each forward pass
+- **Evaluation mode** (`model.eval()`): Statistics are frozen
+
+### Attention Implementation
+
+When loading HuggingFace models, use eager attention for compatibility:
+
+```python
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    attn_implementation="eager",
+)
+```
+
+
+## Source Code
+
+- Transform pass: `src/chop/passes/module/transforms/onn/transform.py`
+- Linear layer: `src/chop/passes/module/transforms/onn/layers/linear.py`
+- Attention layer: `src/chop/passes/module/transforms/onn/layers/attn.py`
+
+## References
+
+- [Optical Transformers: End-to-end Optical Training of Transformer Models](https://arxiv.org/abs/2302.10360)
diff --git a/setup.py b/setup.py
index 2de937aaf..baeff73f5 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,7 @@
-from setuptools import setup, find_packages
 import sys
+
+from setuptools import find_packages, setup
 
 
 def is_cuda_available():
     try:
@@ -98,7 +99,7 @@ def get_system():
 ]
 
 if is_cuda_available():
-    requirements += ["mase-triton", "pycuda", "tensorrt"]
+    requirements += ["mase-triton>=0.0.6.post4", "pycuda", "tensorrt"]
 
 setup(
     name="mase-tools",
diff --git a/src/chop/passes/module/transforms/bitflip/__init__.py b/src/chop/passes/module/transforms/bitflip/__init__.py
index 5c705f480..dd03ca3b0 100644
--- a/src/chop/passes/module/transforms/bitflip/__init__.py
+++ b/src/chop/passes/module/transforms/bitflip/__init__.py
@@ -1 +1,3 @@
 from .bitflip_transform import bitflip_module_transform_pass
+
+__all__ = ["bitflip_module_transform_pass"]
diff --git a/src/chop/passes/module/transforms/bitflip/bitflip_transform.py b/src/chop/passes/module/transforms/bitflip/bitflip_transform.py
index a6dc37c59..61d3cb66d 100644
--- a/src/chop/passes/module/transforms/bitflip/bitflip_transform.py
+++ b/src/chop/passes/module/transforms/bitflip/bitflip_transform.py
@@ -1,14 +1,14 @@
 try:
-    import mase_triton
-    import mase_triton.random_bitflip
+    from mase_triton.random_bitflip.layers import RandomBitFlipLinear
 
     MASE_TRITON_AVAILABLE = True
 except ImportError:
     MASE_TRITON_AVAILABLE = False
 
 import torch
-from ...state_dict_map import match_a_pattern
+
 from ...module_modify_helper import replace_by_name
+from ...state_dict_map import match_a_pattern
 
 
 def get_config_by_name(config: dict, name: str):
@@ -44,7 +44,7 @@ def get_layer_config(
 
 if MASE_TRITON_AVAILABLE:
     BITFLIP_CLS_MAP = {
-        torch.nn.Linear: mase_triton.random_bitflip.layers.RandomBitFlipLinear,
+        torch.nn.Linear: RandomBitFlipLinear,
     }
 
     def bitflip_module_transform_pass(
@@ -87,4 +87,4 @@ def bitflip_module_transform_pass(
     def bitflip_module_transform_pass(
         network: torch.nn.Module, pass_args: dict
     ) -> torch.nn.Module:
-        raise RuntimeError("mase_triton is not available, please install it first.")
+        raise RuntimeError("mase-triton is not available, please install it first.")
diff --git a/src/chop/passes/module/transforms/onn/__init__.py b/src/chop/passes/module/transforms/onn/__init__.py
new file mode 100644
index 000000000..19f18a190
--- /dev/null
+++ b/src/chop/passes/module/transforms/onn/__init__.py
@@ -0,0 +1,3 @@
+from .transform import optical_transformer_module_transform_pass
+
+__all__ = ["optical_transformer_module_transform_pass"]
diff --git a/src/chop/passes/module/transforms/onn/layers/__init__.py b/src/chop/passes/module/transforms/onn/layers/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/chop/passes/module/transforms/onn/layers/attn.py b/src/chop/passes/module/transforms/onn/layers/attn.py
new file mode 100644
index 000000000..1f4ccc16e
--- /dev/null
+++ b/src/chop/passes/module/transforms/onn/layers/attn.py
@@ -0,0 +1,358 @@
+from typing import Optional
+
+import torch
+from mase_triton.optical_compute import OpticalTransformerFunctions as OTFunctions
+from mase_triton.optical_compute.layers import OpticalTransformerLinear as OTLinear
+from mase_triton.optical_compute.layers import optical_transformer_update_qstats
+from mase_triton.utils.torch_module import get_layer_name, set_layer_by_name
+from torch import Tensor, nn
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaConfig,
+    LlamaDecoderLayer,
+    LlamaForCausalLM,
+    apply_rotary_pos_emb,
+    eager_attention_forward,
+    repeat_kv,
+)
+
+
+def ot_eager_attention_forward(
+    module: "OtLlamaAttention",
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    """
+    Optical Transformer Scaled Dot-Product Attention.
+
+    Computes scaled dot-product attention with quantized matrix multiplications
+    to simulate optical neural network hardware constraints. This function applies
+    quantization to both the query-key and attention-value matrix products.
+
+    The quantization statistics (min/max values) are updated in-place during training
+    using an exponential moving average controlled by ``q_smooth_factor``.
+
+    Args:
+        module (OtLlamaAttention): Attention module that owns the quantization
+            buffers (running min/max statistics, LUT settings, and seed).
+        query (Tensor): Query tensor of shape ``(batch, heads, seq_len, head_dim)``.
+        key (Tensor): Key tensor of shape ``(batch, kv_heads, seq_len, head_dim)``.
+        value (Tensor): Value tensor of shape ``(batch, kv_heads, seq_len, head_dim)``.
+        attention_mask (Tensor, optional): Attention mask. Default: None.
+        scaling (float): Scaling factor applied to the query-key product,
+            typically ``1/sqrt(head_dim)``.
+        dropout (float): Dropout probability. Default: 0.0.
+
+    Returns:
+        tuple[Tensor, Tensor]: Attention output of shape
+        ``(batch, seq_len, heads, head_dim)`` and the attention weights.
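+
+    Example:
+        A minimal sketch with illustrative shapes; it assumes ``module`` is an
+        ``OtLlamaAttention`` whose min/max buffers have already been warmed up
+        by a few training-mode forward passes, and that the ``mase-triton``
+        kernels (CUDA) are available:
+
+        .. code-block:: python
+
+            import torch
+
+            b, h, s, d = 1, 8, 16, 64  # batch, heads, seq_len, head_dim
+            q = torch.randn(b, h, s, d)
+            k = torch.randn(b, h, s, d)  # kv_heads == heads in this sketch
+            v = torch.randn(b, h, s, d)
+            out, weights = ot_eager_attention_forward(
+                module, q, k, v, attention_mask=None, scaling=d**-0.5
+            )
+            # out has shape (b, s, h, d) after the internal transpose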
+ """ + with torch.no_grad(): + query_min_max_ = optical_transformer_update_qstats( + query, + module.query_min_max, + module.q_min_max_quantiles, + module.stat_smooth_factor, + ) + module.query_min_max.copy_(query_min_max_) + key_min_max_ = optical_transformer_update_qstats( + key, + module.key_min_max, + module.q_min_max_quantiles, + module.stat_smooth_factor, + ) + module.key_min_max.copy_(key_min_max_) + key_states = repeat_kv(key, module.num_key_value_groups) + if not module.qk_min_max.isfinite().all(): + attn_weights = torch.matmul(query, key_states.transpose(-1, -2)) * scaling + qk_min_max_ = optical_transformer_update_qstats( + attn_weights, + module.qk_min_max, + module.q_min_max_quantiles, + module.stat_smooth_factor, + ) + module.qk_min_max.copy_(qk_min_max_) + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + # attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + attn_weights, _ = OTFunctions.quantized_matmul_fn( + a=query.contiguous(), + b=key_states.transpose(2, 3).contiguous(), + a_min=module.query_min_max[0], + a_max=module.query_min_max[1], + b_min=module.key_min_max[0], + b_max=module.key_min_max[1], + b_lut_min=module.q_lut_min, + o_min=module.qk_min_max[0], + o_max=module.qk_min_max[1], + q_levels=module.q_levels, + q_seed=module.seed.item(), + skip_quantize=False, + ) + attn_weights = attn_weights * scaling + + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query.dtype + ) + attn_weights = nn.functional.dropout( + attn_weights, p=dropout, training=module.training + ) + # attn_output = torch.matmul(attn_weights, value_states) + + with torch.no_grad(): + attn_min_max_ = optical_transformer_update_qstats( + attn_weights, + module.attn_min_max, + module.q_min_max_quantiles, + module.stat_smooth_factor, + ) + module.attn_min_max.copy_(attn_min_max_) + value_min_max_ = optical_transformer_update_qstats( + value_states, + module.value_min_max, + module.q_min_max_quantiles, + module.stat_smooth_factor, + ) + module.value_min_max.copy_(value_min_max_) + attn_ = torch.matmul(attn_weights, value_states) + av_min_max_ = optical_transformer_update_qstats( + attn_, + module.av_min_max, + module.q_min_max_quantiles, + module.stat_smooth_factor, + ) + module.av_min_max.copy_(av_min_max_) + + attn_output, _ = OTFunctions.quantized_matmul_fn( + a=attn_weights.contiguous(), + b=value_states.contiguous(), + a_min=module.attn_min_max[0], + a_max=module.attn_min_max[1], + b_min=module.value_min_max[0], + b_max=module.value_min_max[1], + b_lut_min=module.q_lut_min, + o_min=module.av_min_max[0], + o_max=module.av_min_max[1], + q_levels=module.q_levels, + q_seed=module.seed.item(), + skip_quantize=module.bypass, + ) + with torch.no_grad(): + module.seed += 1 + + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class OtLlamaAttention(nn.Module): + """ + Optical Transformer attention module for LLaMA models. + + This module replaces the standard HuggingFace LlamaAttention with an optical + transformer equivalent that simulates quantized matrix multiplications as would + occur in optical neural network hardware. The implementation is based on the + `Optical Transformers paper `_. 
+
+    The attention computation uses optical transformer scaled dot-product attention
+    (SDPA) which applies quantization to the query-key and attention-value matrix
+    multiplications to simulate optical compute constraints.
+
+    Args:
+        config: HuggingFace LLaMA configuration object.
+        layer_idx (int): Index of this attention layer in the model.
+        q_levels (int): Number of quantization levels for optical simulation. Default: 256.
+        q_lut_min (float): Minimum value for the lookup table used in quantization. Default: 0.020040.
+        q_quantiles (tuple[float, float], optional): Quantile range for min/max statistics.
+            If None, uses absolute min/max. Default: None.
+        q_smooth_factor (float): Exponential moving average factor for updating
+            running min/max statistics during training. Default: 0.9.
+        q_init_seed (int): Random seed for quantization noise initialization. Default: 0.
+        q_bypass (bool): If True, bypasses optical quantization and uses standard
+            PyTorch attention. Useful for debugging or comparison. Default: False.
+
+    Attributes:
+        query_min_max (Tensor): Running min/max statistics for query tensors.
+        key_min_max (Tensor): Running min/max statistics for key tensors.
+        value_min_max (Tensor): Running min/max statistics for value tensors.
+        qk_min_max (Tensor): Running min/max statistics for query-key products.
+        attn_min_max (Tensor): Running min/max statistics for the softmax attention
+            weights (initialized to ``[0, 1]``).
+        av_min_max (Tensor): Running min/max statistics for attention-value products.
+        seed (Tensor): Current random seed state for quantization.
+
+    Example:
+        .. code-block:: python
+
+            from chop.passes.module.transforms.onn.layers.attn import OtLlamaAttention
+
+            # Create from existing HuggingFace attention layer
+            ot_attn = OtLlamaAttention.from_pretrained(
+                hf_attention_layer,
+                layer_idx=0,
+                q_levels=256,
+                q_bypass=False,
+            )
+    """
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_idx: int,
+        q_levels: int = 256,
+        q_lut_min: float = 0.020040,
+        q_quantiles: tuple[float, float] | None = None,
+        q_smooth_factor: float = 0.9,
+        q_init_seed: int = 0,
+        q_bypass: bool = False,
+    ):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        self.num_key_value_groups = (
+            config.num_attention_heads // config.num_key_value_heads
+        )
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
+        self.is_causal = True
+
+        self.q_proj = nn.Linear(
+            config.hidden_size,
+            config.num_attention_heads * self.head_dim,
+            bias=config.attention_bias,
+        )
+        self.k_proj = nn.Linear(
+            config.hidden_size,
+            config.num_key_value_heads * self.head_dim,
+            bias=config.attention_bias,
+        )
+        self.v_proj = nn.Linear(
+            config.hidden_size,
+            config.num_key_value_heads * self.head_dim,
+            bias=config.attention_bias,
+        )
+        self.o_proj = nn.Linear(
+            config.num_attention_heads * self.head_dim,
+            config.hidden_size,
+            bias=config.attention_bias,
+        )
+
+        self.q_levels = q_levels
+        self.q_lut_min = q_lut_min
+        if q_quantiles is None:
+            self.q_min_max_quantiles = None
+        else:
+            self.register_buffer("q_min_max_quantiles", torch.tensor(q_quantiles))
+        self.register_buffer(
+            "query_min_max", torch.tensor([float("inf"), float("-inf")])
+        )
+        self.register_buffer("key_min_max", torch.tensor([float("inf"), float("-inf")]))
+        self.register_buffer("qk_min_max", torch.tensor([float("inf"), float("-inf")]))
+        self.register_buffer("attn_min_max", torch.tensor([float(0), float(1)]))
+        self.register_buffer(
+            "value_min_max", torch.tensor([float("inf"), float("-inf")])
+        )
+        self.register_buffer("av_min_max", torch.tensor([float("inf"), float("-inf")]))
+        self.register_buffer("seed", torch.tensor(q_init_seed, dtype=torch.int64))
+        self.stat_smooth_factor = q_smooth_factor
+        self.bypass = q_bypass
+
+        self.query_min_max: Tensor
+        self.key_min_max: Tensor
+        self.qk_min_max: Tensor
+        self.attn_min_max: Tensor
+        self.value_min_max: Tensor
+        self.av_min_max: Tensor
+        self.seed: Tensor
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_value=None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin
+        )
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        if self.bypass:
+            attn_output, attn_weights = eager_attention_forward(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask,
+                dropout=0.0 if not self.training else self.attention_dropout,
+                scaling=self.scaling,
+                **kwargs,
+            )
+        else:
+            attn_output, attn_weights = ot_eager_attention_forward(
+                self,
+                query_states,
+                key_states,
+                value_states,
+                attention_mask,
+                dropout=0.0 if not self.training else self.attention_dropout,
+                scaling=self.scaling,
+                **kwargs,
+            )
+            self.seed += 1
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        attn: LlamaAttention,
+        layer_idx: int,
+        q_levels: int = 256,
+        q_lut_min: float = 0.020040,
+        q_quantiles: tuple[float, float] | None = None,
+        q_smooth_factor: float = 0.9,
+        q_init_seed: int = 0,
+        q_bypass: bool = False,
+    ) -> "OtLlamaAttention":
+        assert isinstance(attn, LlamaAttention)
+        ot_attn = cls(
+            attn.config,
+            layer_idx,
+            q_levels,
+            q_lut_min,
+            q_quantiles,
+            q_smooth_factor,
+            q_init_seed,
+            q_bypass,
+        )
+        ot_attn.to(attn.o_proj.weight.dtype)
+        ot_attn.load_state_dict(attn.state_dict(), strict=False)
+        return ot_attn
diff --git a/src/chop/passes/module/transforms/onn/layers/linear.py b/src/chop/passes/module/transforms/onn/layers/linear.py
new file mode 100644
index 000000000..91dcf57d0
--- /dev/null
+++ b/src/chop/passes/module/transforms/onn/layers/linear.py
@@ -0,0 +1,30 @@
+"""
+Optical Transformer Linear Layer.
+
+This module provides the optical transformer linear layer implementation
+by importing from the mase-triton package.
+"""
+
+from mase_triton.optical_compute import layers as OTLayers
+
+#: Optical Transformer Linear layer.
+#:
+#: This is an alias to ``mase_triton.optical_compute.layers.OpticalTransformerLinear``.
+#: It replaces standard ``torch.nn.Linear`` layers with quantized optical transformer
+#: equivalents that simulate optical neural network hardware constraints.
+#:
+#: The layer applies quantization to both the input activations and weights during
+#: matrix multiplication, and tracks running min/max statistics for calibration.
+#:
+#: Use the ``from_linear`` class method to convert an existing ``torch.nn.Linear``:
+#:
+#: .. code-block:: python
+#:
+#:     from chop.passes.module.transforms.onn.layers.linear import OtLinear
+#:
+#:     ot_linear = OtLinear.from_linear(
+#:         linear_layer,
+#:         q_levels=256,
+#:         q_lut_min=0.020040,
+#:     )
+OtLinear = OTLayers.OpticalTransformerLinear
diff --git a/src/chop/passes/module/transforms/onn/transform.py b/src/chop/passes/module/transforms/onn/transform.py
new file mode 100644
index 000000000..edbb9eff5
--- /dev/null
+++ b/src/chop/passes/module/transforms/onn/transform.py
@@ -0,0 +1,180 @@
+try:
+    import mase_triton
+
+    MASE_TRITON_IS_AVAILABLE = True
+except ImportError:
+    MASE_TRITON_IS_AVAILABLE = False
+
+from typing import TypedDict
+
+import torch
+from transformers.models.llama.modeling_llama import LlamaAttention as HfLlamaAttention
+
+from ...module_modify_helper import replace_by_name
+from ...state_dict_map import match_a_pattern
+from .layers.attn import OtLlamaAttention
+from .layers.linear import OtLinear
+
+
+def get_config_by_name(config: dict, name: str):
+    if name in config:
+        return config[name]
+    else:
+        if "default" in config:
+            return config["default"]
+        else:
+            return None
+
+
+def get_config_by_regex_name(config: dict, name: str):
+    matched_pattern = match_a_pattern(name, config.keys())
+    if matched_pattern is None:
+        if "default" in config:
+            return config["default"]
+        else:
+            return None
+    else:
+        return config[matched_pattern]
+
+
+def get_layer_config(
+    layer_name_to_config: dict[str, dict], use_regex: bool, layer_name: str
+) -> dict | None:
+    if use_regex:
+        config = get_config_by_regex_name(layer_name_to_config, layer_name)
+    else:
+        config = get_config_by_name(layer_name_to_config, layer_name)
+    return config
+
+
+class OtTransformConfig(TypedDict):
+    q_levels: int
+    q_lut_min: float
+    q_smooth_factor: float
+    q_init_seed: int
+    q_bypass: bool
+
+    @classmethod
+    def create_default(cls) -> "OtTransformConfig":
+        return cls(
+            q_levels=256,
+            q_lut_min=0.020040,
+            q_smooth_factor=0.9,
+            q_init_seed=0,
+            q_bypass=False,
+        )
+
+
+if MASE_TRITON_IS_AVAILABLE:
+    _SUPPORTED_MODULE_CLS = (torch.nn.Linear, HfLlamaAttention)
+
+    def optical_transformer_module_transform_pass(
+        network: torch.nn.Module, pass_args: dict
+    ) -> torch.nn.Module:
+        """
+        Transform a neural network by replacing supported modules with their optical transformer equivalents.
+
+        This pass simulates optical neural network (ONN) computation by replacing standard PyTorch
+        modules with quantized optical transformer layers. The optical transformer model is based on
+        the `Optical Transformers paper <https://arxiv.org/abs/2302.10360>`_.
+
+        Supported module replacements:
+
+        - ``torch.nn.Linear`` → ``OtLinear``
+        - ``transformers.models.llama.modeling_llama.LlamaAttention`` → ``OtLlamaAttention``
+
+        Args:
+            network (torch.nn.Module): The input network to be transformed.
+            pass_args (dict): A dictionary containing transformation configurations.
+
+                - ``by`` (str): Layer matching strategy. Either ``'name'`` for exact name matching
+                  or ``'regex_name'`` for regex-based pattern matching. Defaults to ``'regex_name'``.
+                - ``default`` (dict, optional): Default configuration applied to all matching layers.
+                - ``<layer_name_or_regex_pattern>`` (dict): Per-layer configuration. Each layer config
+                  can contain the following keys:
+
+                  - ``q_levels`` (int): Number of quantization levels. Default: 256.
+                  - ``q_lut_min`` (float): Minimum value for lookup table. Default: 0.020040.
+                  - ``q_smooth_factor`` (float): Smoothing factor for running statistics. Default: 0.9.
+                  - ``q_init_seed`` (int): Random seed for quantization initialization. Default: 0.
+                  - ``q_bypass`` (bool): If True, bypass optical quantization. Default: False.
+
+        Returns:
+            torch.nn.Module: The transformed network with optical transformer modules.
+
+        Raises:
+            RuntimeError: If ``mase-triton`` is not installed.
+
+        Example:
+            .. code-block:: python
+
+                from chop.passes.module.transforms.onn import optical_transformer_module_transform_pass
+
+                # Transform all linear layers with default config
+                pass_args = {
+                    "by": "regex_name",
+                    "default": {
+                        "q_levels": 256,
+                        "q_lut_min": 0.020040,
+                        "q_bypass": False,
+                    }
+                }
+                transformed_model = optical_transformer_module_transform_pass(model, pass_args)
+
+        Note:
+            This pass requires the ``mase-triton`` package to be installed.
+            Install via ``pip install mase-triton``.
+        """
+        by = pass_args.pop("by", "regex_name")
+        assert by in [
+            "name",
+            "regex_name",
+        ], f"`by` can be either 'name' or 'regex_name', but got {by}"
+        # replace attn layers if any
+        for m_name, m in network.named_modules():
+            if not isinstance(m, HfLlamaAttention):
+                continue
+            m_config = get_layer_config(
+                pass_args, use_regex=by == "regex_name", layer_name=m_name
+            )
+            if m_config is None:
+                continue
+            if isinstance(m, HfLlamaAttention):
+                new_m = OtLlamaAttention.from_pretrained(
+                    m, layer_idx=m.layer_idx, **m_config
+                )
+            elif isinstance(m, _SUPPORTED_MODULE_CLS):
+                continue
+            else:
+                raise NotImplementedError(
+                    f"ONN transform for type {type(m)} is not supported"
+                )
+            replace_by_name(network, name=m_name, module=new_m)
+        # replace linear layers if any
+        for m_name, m in network.named_modules():
+            if not isinstance(m, torch.nn.Linear):
+                continue
+            m_config = get_layer_config(
+                pass_args, use_regex=by == "regex_name", layer_name=m_name
+            )
+            if m_config is None:
+                continue
+            if isinstance(m, torch.nn.Linear):
+                new_m = OtLinear.from_linear(m, **m_config)
+            elif isinstance(m, _SUPPORTED_MODULE_CLS):
+                continue
+            else:
+                raise NotImplementedError(
+                    f"ONN transform for type {type(m)} is not supported"
+                )
+            replace_by_name(network, name=m_name, module=new_m)
+        return network
+
+else:
+
+    def optical_transformer_module_transform_pass(
+        network: torch.nn.Module, pass_args: dict
+    ) -> torch.nn.Module:
+        raise RuntimeError(
+            "`mase-triton` is needed for ONN transform. Install via `pip install mase-triton`."
+ ) diff --git a/test/passes/module/transforms/onn/test_optical_transformer.py b/test/passes/module/transforms/onn/test_optical_transformer.py new file mode 100644 index 000000000..6e99719d8 --- /dev/null +++ b/test/passes/module/transforms/onn/test_optical_transformer.py @@ -0,0 +1,121 @@ +import unittest + +import torch +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaConfig + +from chop.passes.module.transforms.onn.transform import ( + OtLinear, + OtLlamaAttention, + OtTransformConfig, + optical_transformer_module_transform_pass, +) + + +def _calculate_snr(x, noisy_x): + noise = noisy_x - x + + signal_power = torch.sum(x**2) + noise_power = torch.sum(noise**2) + + snr = signal_power / noise_power + snr_db = 10 * torch.log10(snr) + return snr_db.item() + + +class TestOnnTransform(unittest.TestCase): + def test_ot_linear_layer(self): + linear = torch.nn.Linear(in_features=32, out_features=64) + onn_cfg = OtTransformConfig.create_default() + linear_onn = OtLinear.from_linear(linear, **onn_cfg) + + x = torch.randn(2, 32) + y = linear(x) + y_onn = linear_onn(x) + + snr = _calculate_snr(y, y_onn) + assert snr > 23 + + def test_ot_llama_attn_layer(self): + onn_config = OtTransformConfig.create_default() + onn_config["q_levels"] = 512 + onn_config["q_smooth_factor"] = 0.1 + model_name = "AICrossSim/clm-60m" + hf_config = LlamaConfig.from_pretrained(model_name) + batch_size = 1 + seq_len = 16 + head_dim = hf_config.hidden_size // hf_config.num_attention_heads + + attn = LlamaAttention(config=hf_config, layer_idx=0) + + pos_emb = torch.ones(batch_size, seq_len, head_dim) + x = 3 * torch.randn(batch_size, seq_len, hf_config.hidden_size) + + y, _ = attn(x, (pos_emb, pos_emb), None) + y: torch.Tensor + assert y.isfinite().all() + + attn_onn = OtLlamaAttention.from_pretrained(attn, layer_idx=0, **onn_config) + attn_onn.train() + for _ in range(3): + y_onn, _ = attn_onn(x, (pos_emb, pos_emb), None) + + snr = _calculate_snr(y, y_onn) + print(f"Attn SNR: {snr:.2f} dB") + assert snr > 1 + + def test_optical_transformer_module_transform_pass(self): + onn_config = OtTransformConfig.create_default() + onn_config["q_levels"] = 512 + onn_config["q_smooth_factor"] = 0.1 + model_name = "AICrossSim/clm-60m" + hf_config = LlamaConfig.from_pretrained(model_name) + batch_size = 1 + seq_len = 16 + head_dim = hf_config.hidden_size // hf_config.num_attention_heads + + class Network(torch.nn.Module): + def __init__(self): + super().__init__() + self.attn = LlamaAttention(config=hf_config, layer_idx=0) + self.linear = torch.nn.Linear( + in_features=hf_config.hidden_size, + out_features=hf_config.hidden_size, + ) + + def forward(self, x, pos_emb): + attn_output, _ = self.attn(x, (pos_emb, pos_emb), None) + output = self.linear(attn_output) + return output, None + + network = Network() + pos_emb = torch.ones(batch_size, seq_len, head_dim) + x = 3 * torch.randn(batch_size, seq_len, hf_config.hidden_size) + + y, _ = network(x, pos_emb) + y: torch.Tensor + assert y.isfinite().all() + + pass_args = { + "by": "regex_name", + "attn": onn_config, + "linear": onn_config, + r"attn\.(q|k|v|o)_proj": onn_config, + } + + network_onn = optical_transformer_module_transform_pass(network, pass_args) + assert isinstance(network_onn.attn, OtLlamaAttention) + assert isinstance(network_onn.linear, OtLinear) + assert isinstance(network_onn.attn.q_proj, OtLinear) + assert isinstance(network_onn.attn.k_proj, OtLinear) + assert isinstance(network_onn.attn.v_proj, OtLinear) + assert isinstance(network_onn.attn.o_proj, 
OtLinear) + + print(network_onn) + network_onn.train() + for _ in range(3): + y_onn, _ = network_onn(x, pos_emb) + assert y_onn.isfinite().all() + + snr = _calculate_snr(y, y_onn) + assert snr > 1 + print(f"Network SNR: {snr:.2f} dB")
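+
+
+# Standard unittest entry point so this file can also be run directly
+# with `python test_optical_transformer.py` (in addition to pytest discovery).
+if __name__ == "__main__":
+    unittest.main()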