From 1bf15189e4cc3ab8ca6748a37fd4f43f1b24a6ce Mon Sep 17 00:00:00 2001 From: Jeffrey in Tiamat Date: Tue, 17 Mar 2026 14:40:09 +0000 Subject: [PATCH] chore: migrate snn --- configs/snn/cnv_toy.toml | 31 - src/chop/actions/transform.py | 9 - src/chop/models/vision/snn/__init__.py | 7 - src/chop/models/vision/snn/snn_toy.py | 35 - .../models/vision/snn/spikingResformer.py | 227 --- src/chop/nn/snn/__init__.py | 1 - src/chop/nn/snn/auto_cuda/__init__.py | 0 src/chop/nn/snn/auto_cuda/base.py | 1556 ----------------- src/chop/nn/snn/auto_cuda/cfunction.py | 426 ----- src/chop/nn/snn/auto_cuda/example.py | 44 - src/chop/nn/snn/auto_cuda/generator.py | 667 ------- src/chop/nn/snn/auto_cuda/neuron_kernel.py | 1045 ----------- src/chop/nn/snn/auto_cuda/readme.md | 152 -- src/chop/nn/snn/auto_cuda/ss_neuron_kernel.py | 708 -------- src/chop/nn/snn/base.py | 329 ---- src/chop/nn/snn/configure.py | 77 - src/chop/nn/snn/cuda_utils.py | 326 ---- src/chop/nn/snn/functional/__init__.py | 3 - src/chop/nn/snn/functional/functional.py | 114 -- src/chop/nn/snn/functional/surrogate.py | 70 - src/chop/nn/snn/functional/utils.py | 0 src/chop/nn/snn/modules/__init__.py | 115 -- src/chop/nn/snn/modules/batch_norm1d.py | 54 - src/chop/nn/snn/modules/batch_norm2d.py | 54 - src/chop/nn/snn/modules/batch_norm3d.py | 54 - src/chop/nn/snn/modules/conv1d.py | 64 - src/chop/nn/snn/modules/conv2d.py | 63 - src/chop/nn/snn/modules/conv3d.py | 64 - src/chop/nn/snn/modules/embedding.py | 48 - src/chop/nn/snn/modules/flatten.py | 52 - src/chop/nn/snn/modules/gelu.py | 32 - src/chop/nn/snn/modules/group_norm.py | 49 - src/chop/nn/snn/modules/layernorm.py | 31 - src/chop/nn/snn/modules/linear.py | 107 -- src/chop/nn/snn/modules/modules.py | 151 -- src/chop/nn/snn/modules/neuron/__init__.py | 5 - src/chop/nn/snn/modules/neuron/ifnode.py | 339 ---- src/chop/nn/snn/modules/neuron/lifnode.py | 569 ------ src/chop/nn/snn/modules/neuron/neuron.py | 265 --- .../snn/modules/neuron/parametriclifnode.py | 192 -- src/chop/nn/snn/modules/neuron/st_bifnode.py | 82 - src/chop/nn/snn/modules/pool1d.py | 158 -- src/chop/nn/snn/modules/pool2d.py | 161 -- src/chop/nn/snn/modules/pool3d.py | 161 -- src/chop/nn/snn/modules/roberta/__init__.py | 1 - src/chop/nn/snn/modules/roberta/attention.py | 265 --- src/chop/nn/snn/modules/silu.py | 31 - src/chop/nn/snn/modules/softmax.py | 47 - .../nn/snn/modules/spiking_self_attention.py | 261 --- src/chop/nn/snn/modules/surrogate.py | 233 --- src/chop/nn/snn/modules/upsample.py | 54 - src/chop/nn/snn/modules/utils.py | 0 src/chop/nn/snn/readme.md | 56 - src/chop/passes/__init__.py | 2 - src/chop/passes/graph/__init__.py | 2 - src/chop/passes/graph/transforms/__init__.py | 1 - .../passes/graph/transforms/snn/__init__.py | 1 - .../passes/graph/transforms/snn/ann2snn.py | 240 --- .../passes/module/module_modify_helper.py | 10 +- src/chop/passes/module/state_dict_map.py | 99 -- src/chop/passes/module/transforms/__init__.py | 1 - .../passes/module/transforms/snn/__init__.py | 1 - .../passes/module/transforms/snn/ann2snn.py | 181 -- test/nn/snn/test_ann2snn.py | 218 --- .../ann2snn/test_ann2snn_module_roberta.py | 128 -- 65 files changed, 1 insertion(+), 10528 deletions(-) delete mode 100644 configs/snn/cnv_toy.toml delete mode 100644 src/chop/models/vision/snn/__init__.py delete mode 100644 src/chop/models/vision/snn/snn_toy.py delete mode 100644 src/chop/models/vision/snn/spikingResformer.py delete mode 100644 src/chop/nn/snn/__init__.py delete mode 100644 src/chop/nn/snn/auto_cuda/__init__.py delete mode 100644 src/chop/nn/snn/auto_cuda/base.py delete mode 100644 src/chop/nn/snn/auto_cuda/cfunction.py delete mode 100644 src/chop/nn/snn/auto_cuda/example.py delete mode 100644 src/chop/nn/snn/auto_cuda/generator.py delete mode 100644 src/chop/nn/snn/auto_cuda/neuron_kernel.py delete mode 100644 src/chop/nn/snn/auto_cuda/readme.md delete mode 100644 src/chop/nn/snn/auto_cuda/ss_neuron_kernel.py delete mode 100644 src/chop/nn/snn/base.py delete mode 100644 src/chop/nn/snn/configure.py delete mode 100644 src/chop/nn/snn/cuda_utils.py delete mode 100644 src/chop/nn/snn/functional/__init__.py delete mode 100644 src/chop/nn/snn/functional/functional.py delete mode 100644 src/chop/nn/snn/functional/surrogate.py delete mode 100644 src/chop/nn/snn/functional/utils.py delete mode 100644 src/chop/nn/snn/modules/__init__.py delete mode 100644 src/chop/nn/snn/modules/batch_norm1d.py delete mode 100644 src/chop/nn/snn/modules/batch_norm2d.py delete mode 100644 src/chop/nn/snn/modules/batch_norm3d.py delete mode 100644 src/chop/nn/snn/modules/conv1d.py delete mode 100644 src/chop/nn/snn/modules/conv2d.py delete mode 100644 src/chop/nn/snn/modules/conv3d.py delete mode 100644 src/chop/nn/snn/modules/embedding.py delete mode 100644 src/chop/nn/snn/modules/flatten.py delete mode 100644 src/chop/nn/snn/modules/gelu.py delete mode 100644 src/chop/nn/snn/modules/group_norm.py delete mode 100644 src/chop/nn/snn/modules/layernorm.py delete mode 100644 src/chop/nn/snn/modules/linear.py delete mode 100644 src/chop/nn/snn/modules/modules.py delete mode 100644 src/chop/nn/snn/modules/neuron/__init__.py delete mode 100644 src/chop/nn/snn/modules/neuron/ifnode.py delete mode 100644 src/chop/nn/snn/modules/neuron/lifnode.py delete mode 100644 src/chop/nn/snn/modules/neuron/neuron.py delete mode 100644 src/chop/nn/snn/modules/neuron/parametriclifnode.py delete mode 100644 src/chop/nn/snn/modules/neuron/st_bifnode.py delete mode 100644 src/chop/nn/snn/modules/pool1d.py delete mode 100644 src/chop/nn/snn/modules/pool2d.py delete mode 100644 src/chop/nn/snn/modules/pool3d.py delete mode 100644 src/chop/nn/snn/modules/roberta/__init__.py delete mode 100644 src/chop/nn/snn/modules/roberta/attention.py delete mode 100644 src/chop/nn/snn/modules/silu.py delete mode 100644 src/chop/nn/snn/modules/softmax.py delete mode 100644 src/chop/nn/snn/modules/spiking_self_attention.py delete mode 100644 src/chop/nn/snn/modules/surrogate.py delete mode 100644 src/chop/nn/snn/modules/upsample.py delete mode 100644 src/chop/nn/snn/modules/utils.py delete mode 100644 src/chop/nn/snn/readme.md delete mode 100644 src/chop/passes/graph/transforms/snn/__init__.py delete mode 100644 src/chop/passes/graph/transforms/snn/ann2snn.py delete mode 100644 src/chop/passes/module/transforms/snn/__init__.py delete mode 100644 src/chop/passes/module/transforms/snn/ann2snn.py delete mode 100644 test/nn/snn/test_ann2snn.py delete mode 100644 test/passes/module/transforms/ann2snn/test_ann2snn_module_roberta.py diff --git a/configs/snn/cnv_toy.toml b/configs/snn/cnv_toy.toml deleted file mode 100644 index 1403b78a9..000000000 --- a/configs/snn/cnv_toy.toml +++ /dev/null @@ -1,31 +0,0 @@ -# basics -model = "cnv_toy" -dataset = "cifar10" -# training -training_optimizer = "adam" -learning_rate = 0.01 -max_epochs = 3 -batch_size = 32 -# torch lightning -task = "classification" -num_workers = 0 -num_devices = 1 -accelerator = "gpu" -project_dir = "../mase_output" - -[transform] -style = "graph" - -[passes.ann2snn] -by = "type" -report = true -fuse = true -device = "cuda" - -[passes.ann2snn.default.config] -name = "NA" - -[passes.ann2snn.relu.config] -name = "IFNode" -mode = "99.9%" -mementum = 0.1 diff --git a/src/chop/actions/transform.py b/src/chop/actions/transform.py index 6f217c0c8..68f09a39c 100644 --- a/src/chop/actions/transform.py +++ b/src/chop/actions/transform.py @@ -384,15 +384,6 @@ def transform_graph( PASSES["summarize_quantization"]( graph, {"save_dir": pass_save_dir, "original_graph": ori_graph} ) - case "ann2snn": - input_generator = InputGenerator( - model_info=model_info, - data_module=data_module, - task=task, - which_dataloader="train", - ) - pass_config["train_data_loader"] = input_generator - graph, _ = PASSES[pass_name](graph, pass_args=pass_config) case "profile_statistics": input_generator = InputGenerator( model_info=model_info, diff --git a/src/chop/models/vision/snn/__init__.py b/src/chop/models/vision/snn/__init__.py deleted file mode 100644 index 82cf5131f..000000000 --- a/src/chop/models/vision/snn/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .snn_toy import get_snn_toy -from .spikingResformer import ( - spikingresformer_ti, - spikingresformer_s, - spikingresformer_m, - spikingresformer_l, -) diff --git a/src/chop/models/vision/snn/snn_toy.py b/src/chop/models/vision/snn/snn_toy.py deleted file mode 100644 index 5915f1ec6..000000000 --- a/src/chop/models/vision/snn/snn_toy.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch.nn as nn -from chop.nn.snn import functional -from chop.nn.snn import modules as snn_modules -from chop.nn.snn.modules import neuron as snn_neuron -import torch -from timm.models.registry import register_model -from typing import Any - - -@register_model -class SNN_toy(nn.Module): - def __init__(self, tau, num_classes): - super().__init__() - - self.layer = nn.Sequential( - snn_modules.Flatten(), - snn_modules.Linear(28 * 28, num_classes, bias=False), - snn_neuron.LIFNode( - tau=tau, surrogate_function=snn_modules.surrogate.ATan() - ), - ) - - def forward(self, x: torch.Tensor): - return self.layer(x) - - -# Getters ------------------------------------------------------------------------------ -def get_snn_toy( - info, - pretrained=False, - **kwargs: Any, -): - tau = info["tau"] - num_classes = info.num_classes - return SNN_toy(num_classes) diff --git a/src/chop/models/vision/snn/spikingResformer.py b/src/chop/models/vision/snn/spikingResformer.py deleted file mode 100644 index ff74d76df..000000000 --- a/src/chop/models/vision/snn/spikingResformer.py +++ /dev/null @@ -1,227 +0,0 @@ -from chop.nn.snn.modules.linear import Linear -import torch -import torch.nn as nn -from typing import Any, List, Mapping - -from chop.nn.snn.modules.spiking_self_attention import ( - DSSA, - GWFFN, - BN, - DownsampleLayer, - LIF, - PLIF, -) - -from chop.nn.snn.modules.conv2d import Conv2d - -from chop.nn.snn.modules.pool2d import MaxPool2d, AdaptiveAvgPool2d - -from chop.nn.snn.modules import surrogate - -from timm.models.registry import register_model - - -class SpikingResformer(nn.Module): - def __init__( - self, - layers: List[List[str]], - planes: List[int], - num_heads: List[int], - patch_sizes: List[int], - img_size=224, - T=4, - in_channels=3, - num_classes=1000, - prologue=None, - group_size=64, - activation=LIF, - **kwargs, - ): - super().__init__() - self.T = T - self.skip = ["prologue.0", "classifier"] - assert len(planes) == len(layers) == len(num_heads) == len(patch_sizes) - - if prologue is None: - self.prologue = nn.Sequential( - Conv2d(in_channels, planes[0], 7, 2, 3, bias=False, step_mode="m"), - BN(planes[0]), - MaxPool2d(kernel_size=3, stride=2, padding=1, step_mode="m"), - ) - img_size = img_size // 4 - else: - self.prologue = prologue - - self.layers = nn.Sequential() - for idx in range(len(planes)): - sub_layers = nn.Sequential() - if idx != 0: - sub_layers.append( - DownsampleLayer( - planes[idx - 1], planes[idx], stride=2, activation=activation - ) - ) - img_size = img_size // 2 - for name in layers[idx]: - if name == "DSSA": - sub_layers.append( - DSSA( - planes[idx], - num_heads[idx], - (img_size // patch_sizes[idx]) ** 2, - patch_sizes[idx], - activation=activation, - ) - ) - elif name == "GWFFN": - sub_layers.append( - GWFFN(planes[idx], group_size=group_size, activation=activation) - ) - else: - raise ValueError(name) - self.layers.append(sub_layers) - - self.avgpool = AdaptiveAvgPool2d((1, 1), step_mode="m") - self.classifier = Linear(planes[-1], num_classes, bias=False, step_mode="m") - self.init_weight() - - def init_weight(self): - for m in self.modules(): - if isinstance(m, (nn.Linear, nn.Conv2d)): - nn.init.trunc_normal_(m.weight, std=0.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - def transfer(self, state_dict: Mapping[str, Any]): - _state_dict = {k: v for k, v in state_dict.items() if "classifier" not in k} - return self.load_state_dict(_state_dict, strict=False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if x.dim() != 5: - x = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1) - assert x.dim() == 5 - else: - #### [B, T, C, H, W] -> [T, B, C, H, W] - x = x.transpose(0, 1) - x = self.prologue(x) - x = self.layers(x) - x = self.avgpool(x) - x = torch.flatten(x, 2) - x = self.classifier(x) - return x - - def no_weight_decay(self): - ret = set() - for name, module in self.named_modules(): - if isinstance(module, PLIF): - ret.add(name + ".w") - return ret - - -@register_model -def spikingresformer_ti(**kwargs): - return SpikingResformer( - [ - ["DSSA", "GWFFN"] * 1, - ["DSSA", "GWFFN"] * 2, - ["DSSA", "GWFFN"] * 3, - ], - [64, 192, 384], - [1, 3, 6], - [4, 2, 1], - in_channels=3, - **kwargs, - ) - - -@register_model -def spikingresformer_s(**kwargs): - return SpikingResformer( - [ - ["DSSA", "GWFFN"] * 1, - ["DSSA", "GWFFN"] * 2, - ["DSSA", "GWFFN"] * 3, - ], - [64, 256, 512], - [1, 4, 8], - [4, 2, 1], - in_channels=3, - **kwargs, - ) - - -@register_model -def spikingresformer_m(**kwargs): - return SpikingResformer( - [ - ["DSSA", "GWFFN"] * 1, - ["DSSA", "GWFFN"] * 2, - ["DSSA", "GWFFN"] * 3, - ], - [64, 384, 768], - [1, 6, 12], - [4, 2, 1], - in_channels=3, - **kwargs, - ) - - -@register_model -def spikingresformer_l(**kwargs): - return SpikingResformer( - [ - ["DSSA", "GWFFN"] * 1, - ["DSSA", "GWFFN"] * 2, - ["DSSA", "GWFFN"] * 3, - ], - [128, 512, 1024], - [2, 8, 16], - [4, 2, 1], - in_channels=3, - **kwargs, - ) - - -@register_model -def spikingresformer_dvsg(**kwargs): - return SpikingResformer( - [ - ["DSSA", "GWFFN"] * 1, - ["DSSA", "GWFFN"] * 2, - ["DSSA", "GWFFN"] * 3, - ], - [32, 96, 192], - [1, 3, 6], - [4, 2, 1], - in_channels=3, - prologue=nn.Sequential( - Conv2d(3, 32, 3, 1, 1, bias=False, step_mode="m"), - BN(32), - ), - group_size=32, - activation=PLIF, - **kwargs, - ) - - -@register_model -def spikingresformer_cifar(**kwargs): - return SpikingResformer( - [ - ["DSSA", "GWFFN"] * 1, - ["DSSA", "GWFFN"] * 2, - ["DSSA", "GWFFN"] * 3, - ], - [64, 192, 384], - [1, 3, 6], - [4, 2, 1], - in_channels=3, - prologue=nn.Sequential( - Conv2d(3, 64, 3, 1, 1, bias=False, step_mode="m"), - BN(64), - ), - **kwargs, - ) diff --git a/src/chop/nn/snn/__init__.py b/src/chop/nn/snn/__init__.py deleted file mode 100644 index 940091db9..000000000 --- a/src/chop/nn/snn/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base import StepModule, SingleModule, MultiStepModule, MemoryModule diff --git a/src/chop/nn/snn/auto_cuda/__init__.py b/src/chop/nn/snn/auto_cuda/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/chop/nn/snn/auto_cuda/base.py b/src/chop/nn/snn/auto_cuda/base.py deleted file mode 100644 index 2e36e0c52..000000000 --- a/src/chop/nn/snn/auto_cuda/base.py +++ /dev/null @@ -1,1556 +0,0 @@ -# *************************************************************************************** -# * Title: auto_cuda -# * Reference: These directory is directly sourced from spikingJelly -# * Availability: https://github.com/fangwei123456/spikingjelly/tree/master/spikingjelly/activation_based/auto_cuda -# * Date: 07/11/2024 -# * Code version: 0.0.0.014 -# *************************************************************************************** - -import numpy as np -import logging - -try: - import cupy -except BaseException as e: - logging.info(f"spikingjelly.activation_based.auto_cuda.base: {e}") - cupy = None - -import torch -import torch.nn.functional as F -import sys -import logging -from .. import cuda_utils -from .. import configure - - -def wrap_with_comment(code: str, comment: str): - if logging.DEBUG >= logging.root.level: - return ( - "\n//------" - + comment - + " start------\n" - + code - + "\n//------" - + comment - + " end--------\n\n" - ) - else: - return code - - -def startswiths(x: str, prefixes: tuple): - ret = False - for prefix in prefixes: - if x.startswith(prefix): - ret = True - - return ret - - -class CKernel: - def __init__(self, kernel_name: str): - """ - :param kernel_name: the name of kernel - :type kernel_name: str - - The base python class for simplifying the using of custom CUDA kernel. - - Some critical attributes: - - cparams: - a dict for saving parameters name and type. - - reserved_cnames: - a list for saving reserved variables names, which can not be used to name variable again. - - - Here is an example: - - .. code-block:: python - - from spikingjelly.activation_based.auto_cuda import base - - example_ck = base.CKernel(kernel_name='example_ck') - print(example_ck.full_codes) - - The outputs are: - - .. code-block:: c++ - - #include - extern "C" __global__ - void example_ck( - ) - {} - - - - A ``CKernel`` is composed of three parts: declaration, head, core, and tail. - When setting ``logging level <= DEBUG``, some debug information will be added to cuda codes or printed. - And we can check where is each part. - Here is an example: - - .. code-block:: python - - import logging - logging.basicConfig(level=logging.DEBUG) - from spikingjelly.activation_based.auto_cuda import base - - example_ck = base.CKernel(kernel_name='example_ck') - print(example_ck.full_codes) - - The outputs are: - - .. code-block:: c++ - - //------declaration start------ - - #include - extern "C" __global__ - void example_ck( - ) - - //------declaration end-------- - - - //------head start------ - { - //------head end-------- - - - //------core start------ - - //------core end-------- - - - //------tail start------ - } - //------tail end-------- - - In most cases, ``CKernel`` is used as a base class. Refer to :class:`CKernel1D ` and :class:`CKernel2D ` for more details. - """ - self.cparams = {} - self.reserved_cnames = [] - self.kernel_name = kernel_name - self._core = "" - - def check_attributes(self, **kwargs): - """ - :param kwargs: a dict of attributes - :type kwargs: dict - :return: if all ``value`` in ``kwargs[key]`` is identical to ``self.__getattribute__(key)`` - :rtype: bool - - This function can be used to check if a ``CKernel`` is changed by if any of its attributes changes. - """ - for key, value in kwargs.items(): - if value != self.__getattribute__(key): - return False - - else: - return True - - @property - def core(self): - return self._core - - @core.setter - def core(self, value): - self._core = value - - def set_contiguous(self, py_dict: dict): - """ - :param py_dict: a dict whose value is ``torch.Tensor`` or ``cupy.ndarray`` - :type py_dict: dict - - Check if all values in py_dict are ``torch.Tensor`` or ``cupy.ndarray`` and contiguous. - If not, this function will raise an error. - """ - # get contiguous - for key, value in py_dict.items(): - if isinstance(value, torch.Tensor): - value = value.contiguous() - - elif isinstance(value, cupy.ndarray): - value = cupy.ascontiguousarray(value) - else: - raise TypeError(type(value)) - - py_dict[key] = value - - def get_device(self, py_dict: dict) -> int: - """ - :param py_dict: a dict - :type py_dict: dict - - Traverse the dict and return the device id of the first met ``torch.Tensor``. - If no ``torch.Tensor`` in ``py_dict``, this function will raise an error. - """ - - for item in py_dict.values(): - if isinstance(item, torch.Tensor): - return item.get_device() - - elif isinstance(item, cupy.ndarray): - return item.device.id - - raise ValueError - - def check_device(self, device: int, py_dict: dict): - """ - :param device: the cuda device id - :type device: int - :param py_dict: a dict - :type py_dict: dict - - Check if the device id of each ``torch.Tensor`` or ``cupy.ndarray`` in py_dict is identical to ``device``. - If not, this function will raise an error. - """ - for item in py_dict.values(): - if isinstance(item, torch.Tensor): - assert item.get_device() == device - - elif isinstance(item, cupy.ndarray): - assert item.device.id == device - - def check_keys(self, py_dict: dict): - """ - :param py_dict: a dict - :type py_dict: dict - - Check if keys of ``py_dict`` are identical to keys of ``self.cparams``. - If not, this function will raise an error. - """ - if py_dict.keys() != self.cparams.keys(): - missed_keys = (py_dict.keys() | self.cparams.keys()) - ( - py_dict.keys() & self.cparams.keys() - ) - - if missed_keys.__len__() > 0: - if (missed_keys & py_dict.keys()).__len__() > 0: - msg = f"{missed_keys} is in py_dict but not in cparams!" - else: - msg = f"{missed_keys} is in cparams but not in py_dict!" - raise ValueError(msg) - - def check_ctypes(self, py_dict: dict): - """ - :param py_dict: a dict - :type py_dict: dict - - Check if the value in ``py_dict`` has the corresponding ``ctype`` in ``self.cparams``, which includes: - - ``torch.float`` or ``np.float32``------ ``'const float'`` or ``'float'`` - - ``torch.half`` or ``np.float16`` ------ ``'const half2'`` or ``'half2'`` - - ``np.int_`` ------ ``'const int'`` or ``'int'`` - - If not, this function will raise an error. - """ - for key, value in py_dict.items(): - ctype: str = self.cparams[key] - if isinstance(value, torch.Tensor): - if value.dtype == torch.float: - assert startswiths(ctype, ("const float", "float")) - - elif value.dtype == torch.half: - assert startswiths(ctype, ("const half2", "half2")) - - if isinstance(value, cupy.ndarray): - if value.dtype == np.float32: - assert startswiths(ctype, ("const float", "float")) - - elif value.dtype == np.float16: - assert startswiths(ctype, ("const half2", "half2")) - - elif value.dtype == np.int_: - assert startswiths(ctype, ("const int", "int")) - - def check_half2(self, py_dict: dict): - """ - This function is implemented for sub-class when needed. - """ - raise NotImplementedError - - def get_ptrs(self, py_dict: dict): - """ - :param py_dict: a dict - :type py_dict: dict - :return: a tuple of data ptr - :rtype: tuple - - Get the address of the first element of each ``torch.Tensor`` or ``cupy.ndarray`` in ``py_dict``. - """ - ret_list = [] - for item in py_dict.values(): - if isinstance(item, torch.Tensor): - ret_list.append(item.data_ptr()) - - elif isinstance(item, cupy.ndarray): - ret_list.append(item) - - else: - raise TypeError - return tuple(ret_list) - - def __call__(self, grid: tuple, block: tuple, py_dict: dict, *args_1, **kwargs): - """ - :param grid: the grid number of CUDA kernel - :type grid: tuple - :param block: the block number of CUDA kernel - :type block: tuple - :param py_dict: the dict that contains parameters for CUDA kernel - :type py_dict: dict - - Execute the CUDA kernel. ``*args_1, **kwargs`` are used as ``*args_1, **kwargs`` in :class:`cupy.RawKernel`. - - ``py_dict`` should contain ``key: value`` where ``key`` is the cuda kernel function param name, and ``value`` is - the variable. This dict should be one-to-one correspondence to ``self.cparams``. - - For example, if ``self.cparams`` is - - .. code-block:: python - - { - 'numel': 'const int &', - 'x': 'const float *', - 'y': 'const float *' - } - - - Then ``py_dict`` sould be - - .. code-block:: python - - { - 'numel': numel, - 'x': x, - 'y': y - } - - where ``numel, x, y`` should be ``torch.Tensor`` or ``cupy.ndarray`` with the corresponding data type, e.g., - ``x`` in ``py_dict`` should have data type ``torch.float`` because ``x`` in ``self.cparams`` have value ``'const float *'`` . - - The keys order is arbitrary because this function will sort keys to align formal and actual parameters. - - """ - - device = self.get_device(py_dict) - self.check_device(device, py_dict) - - self.set_contiguous(py_dict) - - self.check_ctypes(py_dict) - - self.check_half2(py_dict) - - py_dict = dict(sorted(py_dict.items())) - self.check_keys(py_dict) - assert sys.version_info.major >= 3 and sys.version_info.minor >= 6 - # 需要使用有序词典 - # python >= 3.6时,字典默认是有序的 - - cp_kernel = cupy.RawKernel( - self.full_codes, - self.kernel_name, - options=configure.cuda_compiler_options, - backend=configure.cuda_compiler_backend, - ) - - with cuda_utils.DeviceEnvironment(device): - cp_kernel(grid, block, self.get_ptrs(py_dict), *args_1, **kwargs) - - def add_param(self, ctype: str, cname: str): - """ - :param ctype: the type of the CUDA param - :type ctype: str - :param cname: the name of the CUDA param - :type cname: str - - Add a param to ``self.cparams``. - - .. admonition:: Note - :class: note - - When calling ``self.__call__``, the params order in the CUDA kernel are sorted by the dictionary order. Thus, - the user do not need to call ``add_param`` by some specific order. - - Here is an example: - - .. code-block:: python - - from spikingjelly.activation_based.auto_cuda import base - - example_ck = base.CKernel(kernel_name='example_ck') - print('origin:') - print(example_ck.full_codes) - - - example_ck.add_param(ctype='const float*', cname='x') - example_ck.add_param(ctype='const float*', cname='y') - example_ck.add_param(ctype='float', cname='z') - - print('after:') - print(example_ck.full_codes) - - .. code-block:: c++ - - origin: - - #include - extern "C" __global__ - void example_ck( - const int & numel - ) - - after: - - #include - extern "C" __global__ - void example_ck( - const int & numel, const float* x, const float* y, float z - ) - - - """ - # example: ctype = 'const float *', cname = 'x' - if cname in self.cparams: - raise ValueError(f"{cname} has been added to cparams!") - - if cname in self.reserved_cnames: - raise ValueError( - f"{cname} is the reserved cname. You should change the name of your variable to avoid conflict." - ) - - self.cparams[cname] = ctype - - @property - def declaration(self): - codes = f""" - #include - extern "C" __global__ - void {self.kernel_name}( - """ - self.cparams = dict(sorted(self.cparams.items())) - params_list = [] - for cname, ctype in self.cparams.items(): - params_list.append(f"{ctype} {cname}") - - codes += ", ".join(params_list) - - codes += """ - ) - """ - return codes - - @property - def head(self): - return "{" - - @property - def tail(self): - return "}" - - @property - def full_codes(self): - """ - :return: the full cuda codes - :rtype: str - - """ - return ( - wrap_with_comment(self.declaration, "declaration") - + wrap_with_comment(self.head, "head") - + wrap_with_comment(self.core, "core") - + wrap_with_comment(self.tail, "tail") - ) - - -class CKernel1D(CKernel): - def __init__(self, *args, **kwargs): - """ - :param kernel_name: the name of kernel - :type kernel_name: str - - The 1D (element-wise) CUDA kernel, which is extended from :class:`CKernel `. - All input/output tensors will be regarded as 1D tensors. - - Some critical attributes: - - cparams: - A dict for saving parameters name and type. - The default value is ``{'numel': 'const int &'}``. - ``numel`` represents the numel of elements for element-wise operations, which is also the numer of cuda - threads. - - reserved_cnames: - A list for saving reserved variables names, which can not be used to name variable again. - The defaule value is ``['index']``. - ``index`` represents the index of element, which is also the cuda thread index. - - Now let us check what the empty 1d kernel looks like: - - .. code-block:: python - - from spikingjelly.activation_based.auto_cuda import base - temp_kernel = base.CKernel1D(kernel_name='temp_kernel') - print(temp_kernel.full_codes) - - The outputs are: - - .. code-block:: c++ - - #include - extern "C" __global__ - void temp_kernel( - const int & numel - ) - - { - const int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < numel) - { - - } - } - - With setting logging level, we can check each part of the kernel: - - .. code-block:: python - - import logging - logging.basicConfig(level=logging.DEBUG) - from spikingjelly.activation_based.auto_cuda import base - temp_kernel = base.CKernel1D(kernel_name='temp_kernel') - print(temp_kernel.full_codes) - - The outputs are: - - .. code-block:: c++ - - //------declaration start------ - - #include - extern "C" __global__ - void temp_kernel( - const int & numel - ) - - //------declaration end-------- - - - //------head start------ - - { - const int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < numel) - { - - //------head end-------- - - - //------core start------ - - //------core end-------- - - - //------tail start------ - - } - } - - //------tail end-------- - - ``self.code`` can be specified by user. - For example, if we want to write a heaviside kernel, we can implement it easily with the cuda code - ``y[index] = x[index] >= 0.0f ? 1.0f: 0.0f;``, and add two params ``x, y``, which are inputs and outputs. - - Here is the example: - - .. code-block:: python - - from spikingjelly.activation_based.auto_cuda import base - - c_heaviside = base.CKernel1D(kernel_name='heaviside') - c_heaviside.add_param(ctype='const float *', cname='x') - c_heaviside.add_param(ctype='float *', cname='y') - c_heaviside.core = ''' - y[index] = x[index] >= 0.0f ? 1.0f: 0.0f; - ''' - print(c_heaviside.full_codes) - - The outputs are: - - .. code-block:: c++ - - #include - extern "C" __global__ - void heaviside( - const int & numel, const float * x, float * y - ) - - { - const int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < numel) - { - - y[index] = x[index] >= 0.0f ? 1.0f: 0.0f; - - } - } - - Here is an example of how to execute the kernel: - - .. code-block:: bash - - import torch - from spikingjelly.activation_based import cuda_utils - - device = 'cuda:0' - x = torch.rand([4, 4], device=device) - 0.5 - y = torch.zeros_like(x) - - numel = x.numel() - threads = 1024 - blocks = cuda_utils.cal_blocks(numel, threads) - print('x=') - print(x) - - with cuda_utils.DeviceEnvironment(device=x.get_device()): - numel = cupy.asarray(numel) - py_dict = { - 'numel': numel, - 'x': x, - 'y': y - } - c_heaviside((blocks, ), (threads, ), py_dict) - - - print('y=') - print(y) - - The outputs are: - - .. code-block:: bash - - x= - tensor([[-0.0423, -0.1383, -0.0238, 0.1018], - [ 0.3422, 0.1449, -0.2938, -0.1858], - [-0.3503, 0.0004, -0.4274, -0.2012], - [-0.0227, 0.2229, -0.0776, 0.2687]], device='cuda:0') - y= - tensor([[0., 0., 0., 1.], - [1., 1., 0., 0.], - [0., 1., 0., 0.], - [0., 1., 0., 1.]], device='cuda:0') - - - - """ - super().__init__(*args, **kwargs) - self.cparams["numel"] = "const int &" - self.reserved_cnames.append("index") - - @property - def head(self): - codes = """ - { - const int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < numel) - { - """ - return codes - - @property - def tail(self): - codes = """ - } - } - """ - return codes - - def check_half2(self, py_dict: dict): - """ - :param py_dict: a dict - :type py_dict: dict - - Check value in ``py_dict``. If the value is ``torch.Tensor`` with ``value.dtype == torch.half`` or - ``cupy.ndarray`` with ``value.dtype == np.float16``, this function will check whether the number of elements of - value is even. - - We assert when using half dtype, the numel should be even because we will use ``half2`` in CUDA kernel. - - .. admonition:: Note - :class: note - - :class:`CKernel1D.__call__ ` will pad half - tensor to even numel before executing the kernel. Thus, the user does not need to worry about padding. - - - """ - for key, value in py_dict.items(): - if isinstance(value, torch.Tensor): - if value.dtype == torch.half: - assert ( - value.numel() % 2 == 0 - ), f"please pad the numel of {key} to assert mod 2 == 0! (for half2)" - - if isinstance(value, cupy.ndarray): - if value.dtype == np.float16: - assert ( - value.size % 2 == 0 - ), f"please pad the numel of {key} to assert mod 2 == 0! (for half2)" - - def __call__(self, grid: tuple, block: tuple, py_dict: dict, *args_1, **kwargs): - """ - :param grid: the grid number of CUDA kernel - :type grid: tuple - :param block: the block number of CUDA kernel - :type block: tuple - :param py_dict: the dict that contains parameters for CUDA kernel - :type py_dict: dict - - Execute the CUDA kernel. ``*args_1, **kwargs`` are used as ``*args_1, **kwargs`` in :class:`cupy.RawKernel`. - - ``py_dict`` should contain ``key: value`` where ``key`` is the cuda kernel function param name, and ``value`` is - the variable. This dict should be one-to-one correspondence to ``self.cparams``. - - For example, if ``self.cparams`` is - - .. code-block:: python - - { - 'numel': 'const int &', - 'x': 'const float *', - 'y': 'const float *' - } - - - Then ``py_dict`` sould be - - .. code-block:: python - - { - 'numel': numel, - 'x': x, - 'y': y - } - - where ``numel, x, y`` should be ``torch.Tensor`` or ``cupy.ndarray`` with the corresponding data type, e.g., - ``x`` in ``py_dict`` should have data type ``torch.float`` because ``x`` in ``self.cparams`` have value ``'const float *'`` . - - The keys order is arbitrary because this function will sort keys to align formal and actual parameters. - - .. admonition:: Note - :class: note - - All tensors in ``py_dict`` will be regarded as 1D. - - - .. admonition:: Note - :class: note - - If any tensor ``x`` in ``py_dict`` with data type ``torch.half`` or ``np.float16`` but odd numel will be - flattened and padded by ``x = [x, x[-1]]`` before executing the CUDA kernel. After execution, padded values - in ``x`` will be removed, and ``x`` will be reshaped to the origin shape. - - - """ - # pad half2 - pad_keys = [] - pad_shapes = [] - for key, value in py_dict.items(): - if isinstance(value, torch.Tensor) and value.dtype == torch.half: - if value.numel() % 2 != 0: - pad_shapes.append(value.shape) - pad_keys.append(key) - value = value.flatten() - - value = torch.cat((value, value[-1].view(1))) - - py_dict[key] = value - - elif isinstance(value, cupy.ndarray) and value.dtype == np.float16: - if value.size % 2 != 0: - pad_shapes.append(value.shape) - pad_keys.append(key) - value = cupy.reshape(value, -1) - - value = cupy.concatenate((value, cupy.reshape(value[-1], 1))) - - py_dict[key] = value - - super().__call__(grid, block, py_dict, *args_1, **kwargs) - - # move pad values - for key, shape in zip(pad_keys, pad_shapes): - value = py_dict[key] - value = value[:-1] - - if isinstance(value, torch.Tensor): - value = value.view(shape) - - elif isinstance(value, cupy.ndarray): - value = cupy.reshape(value, shape) - - py_dict[key] = value - - def simple_call(self, **kwargs): - """ - :param kwargs: the dict that contains parameters for CUDA kernel - :type kwargs: dict - - - The simplified calling function, which is simplified from the standard calling function is :class:`CKernel1D.simple_call `. - - Compared with :class:`CKernel1D.simple_call `, - the device, numel, numbers of CUDA threads and blocks are calculated automatically from tensors in ``kwargs``. - - Here is the example: - - .. code-block:: python - - import torch - from spikingjelly.activation_based import cuda_utils - from spikingjelly.activation_based.auto_cuda import base - - c_heaviside = base.CKernel1D(kernel_name='heaviside') - c_heaviside.add_param(ctype='const float *', cname='x') - c_heaviside.add_param(ctype='float *', cname='y') - c_heaviside.core = ''' - y[index] = x[index] >= 0.0f ? 1.0f: 0.0f; - ''' - device = 'cuda:0' - - x = torch.rand([4, 4], device=device) - 0.5 - y = torch.zeros_like(x) - - print('x=') - print(x) - c_heaviside.simple_call(x=x, y=y) - print('y=') - print(y) - - The outputs are: - - .. code-block:: bash - - x= - tensor([[-0.1706, 0.2063, -0.2077, 0.3335], - [-0.0180, -0.2429, 0.3488, 0.1146], - [ 0.0362, 0.1584, 0.4828, -0.1389], - [-0.2684, 0.1898, 0.0560, 0.2058]], device='cuda:0') - y= - tensor([[0., 1., 0., 1.], - [0., 0., 1., 1.], - [1., 1., 1., 0.], - [0., 1., 1., 1.]], device='cuda:0') - - """ - py_dict = kwargs - device = self.get_device(py_dict) - numel = None - for value in kwargs.values(): - if isinstance(value, torch.Tensor): - numel = value.numel() - elif isinstance(value, cupy.ndarray): - numel = value.size - - if numel is None: - raise ValueError("No torch.Tensor or cupy.ndarray in kwargs!") - - with cuda_utils.DeviceEnvironment(device): - threads = configure.cuda_threads - blocks = cuda_utils.cal_blocks(numel) - numel = cupy.asarray(numel) - py_dict["numel"] = numel - self.__call__((blocks,), (threads,), py_dict) - - -class CKernel2D(CKernel): - def __init__(self, kernel_name: str, reverse: bool = False): - """ - :param kernel_name: the name of kernel - :type kernel_name: str - :param reverse: If ``True``, then the for-loop in kernel is ``for(int t = index; t < numel; t += dt)``. - If ``False``, then the for-loop in kernel is ``for(int t = numel - N + index; t >= 0; t -= dt)``. - :type reverse: bool - - - The 2D CUDA kernel, which is extended from :class:`CKernel `. - - All input/output tensors should have dimensions no more than 2. All 2D tensors will be regarded as ``shape = [T, N]``, - where ``T`` is the sequence length and ``N`` is the elements number of data at one time-step - - Some critical attributes: - - cparams: - A dict for saving parameters name and type. - The default value is ``{'numel': 'const int &', 'N': 'const int &'}``. - - ``N``: the number of elements number of sequence data at one time-step (the numel of 1-th dimension) - - ``numel``: the numel of elements in input/output tensors, which is ``T * N`` - - - reserved_cnames: - A list for saving reserved variables names, which can not be used to name variable again. - The defaule value is ``['index', 'dt', 't']``. - - ``index``: the index in 1-th dimension, which is also the CUDA thread index - - ``t``: the index in 0-th dimension - - ``dt``: used in CUDA kernel as the time-step stride. When ``x[t_py][j]`` in python code is identical to - ``x[t]`` in CUDA code, then ``x[t_py + 1][j]`` in python code is identical to ``x[t + dt]`` in CUDA code. - - Now let us check what the empty 2d kernel looks like: - - .. code-block:: python - - from spikingjelly.activation_based.auto_cuda import base - - temp_kernel = base.CKernel2D(kernel_name='temp_kernel') - print(temp_kernel.full_codes) - - The outputs are: - - .. code-block:: c++ - - #include - extern "C" __global__ - void temp_kernel( - const int & numel, const int & N - ) - - { - const int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < N) - { - const int dt = N; - - for(int t = index; t < numel; t += dt) - { - - } - - } - } - - With setting logging level, we can check each part of the kernel: - - .. code-block:: python - - import logging - logging.basicConfig(level=logging.DEBUG) - from spikingjelly.activation_based.auto_cuda import base - - temp_kernel = base.CKernel2D(kernel_name='temp_kernel') - print(temp_kernel.full_codes) - - The outputs are: - - .. code-block:: c++ - - //------declaration start------ - - #include - extern "C" __global__ - void temp_kernel( - const int & numel, const int & N - ) - - //------declaration end-------- - - - //------head start------ - - { - const int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < N) - { - const int dt = N; - - //------pre_core start------ - - //------pre_core end-------- - - - for(int t = index; t < numel; t += dt) - { - - //------head end-------- - - - //------core start------ - - //------core end-------- - - - //------tail start------ - - } - - //------post_core start------ - - //------post_core end-------- - - - } - } - - //------tail end-------- - - ``self.pre_core, self.post_core, self.core`` can be specified by user. - - Here is the example of how to implement the :class:`cumsum ` operation: - - .. code-block:: python - - import torch - import cupy - from spikingjelly.activation_based.auto_cuda import base - from spikingjelly.activation_based import cuda_utils - - cumsum = base.CKernel2D(kernel_name='cumsum') - cumsum.add_param(ctype='const float *', cname='x') - cumsum.add_param(ctype='float *', cname='y') - - cumsum.core = ''' - if (t - dt < 0) - { - y[t] = x[t]; - } - else - { - y[t] = x[t] + y[t - dt]; - } - ''' - - print(cumsum.full_codes) - - T = 4 - N = 3 - device = 'cuda:0' - - x = torch.randint(low=0, high=4, size=[T, N], device=device).float() - y = torch.zeros_like(x) - - threads = 1024 - blocks = cuda_utils.cal_blocks(N, threads) - - with cuda_utils.DeviceEnvironment(device=x.get_device()): - numel = cupy.asarray(T * N) - N = cupy.asarray(N) - py_dict = { - 'N': N, - 'numel': numel, - 'x': x, - 'y': y - } - cumsum((blocks, ), (threads, ), py_dict) - - print('x=') - print(x) - print('y=') - print(y) - - - The outputs are: - - .. code-block:: c++ - - #include - extern "C" __global__ - void cumsum( - const int & numel, const int & N, const float * x, float * y - ) - - { - const int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < N) - { - const int dt = N; - - for(int t = index; t < numel; t += dt) - { - - if (t - dt < 0) - { - y[t] = x[t]; - } - else - { - y[t] = x[t] + y[t - dt]; - } - - } - - } - } - - .. code-block:: bash - - x= - tensor([[3., 0., 2.], - [2., 0., 0.], - [2., 3., 2.], - [2., 1., 0.]], device='cuda:0') - y= - tensor([[3., 0., 2.], - [5., 0., 2.], - [7., 3., 4.], - [9., 4., 4.]], device='cuda:0') - - """ - super().__init__(kernel_name) - self.cparams["numel"] = "const int &" - self.reverse = reverse - self.cparams["N"] = "const int &" - self.reserved_cnames.append("index") - self.reserved_cnames.append("dt") - self.reserved_cnames.append("t") - - self._pre_core = "" - self._post_core = "" - - @property - def pre_core(self): - return self._pre_core - - @pre_core.setter - def pre_core(self, value: str): - self._pre_core = value - - @property - def post_core(self): - return self._post_core - - @post_core.setter - def post_core(self, value: str): - self._post_core = value - - def check_shape(self, py_dict: dict): - # all tensors should be ndim <= 2 - for value in py_dict.values(): - if isinstance(value, torch.Tensor): - assert value.ndim <= 2 - - elif isinstance(value, cupy.ndarray): - assert value.ndim <= 2 - - def check_half2(self, py_dict: dict): - """ - :param py_dict: a dict - :type py_dict: dict - - Check value in ``py_dict``. If the value is ``torch.Tensor`` with ``value.dtype == torch.half`` or - ``cupy.ndarray`` with ``value.dtype == np.float16``, this function will check whether the number of elements of - value is even. - - If the tensor ``x`` is 1D, it will be padded when ``x.numel() % 2 != 0``. - If the tensor ``x`` is 2D, it will be padded when ``x.shape[1] % 2 != 0``. - - We assert when using half dtype, the numel should be even because we will use ``half2`` in CUDA kernel. - - .. admonition:: Note - :class: note - - :class:`CKernel2D.__call__ ` will pad half - tensor to even numel before executing the kernel. Thus, the user does not need to worry about padding. - - - """ - for key, value in py_dict.items(): - if isinstance(value, torch.Tensor): - - if value.dtype == torch.half: - if value.ndim <= 1: - assert value.numel() % 2 == 0 - elif value.ndim == 2: - assert value.shape[1] % 2 == 0 - - elif isinstance(value, cupy.ndarray): - if value.dtype == np.float16: - if value.ndim <= 1: - assert value.size % 2 == 0 - elif value.ndim == 2: - assert value.shape[1] % 2 == 0 - - def __call__(self, grid: tuple, block: tuple, py_dict: dict, *args_1, **kwargs): - """ - :param grid: the grid number of CUDA kernel - :type grid: tuple - :param block: the block number of CUDA kernel - :type block: tuple - :param py_dict: the dict that contains parameters for CUDA kernel - :type py_dict: dict - - Execute the CUDA kernel. ``*args_1, **kwargs`` are used as ``*args_1, **kwargs`` in :class:`cupy.RawKernel`. - - ``py_dict`` should contain ``key: value`` where ``key`` is the cuda kernel function param name, and ``value`` is - the variable. This dict should be one-to-one correspondence to ``self.cparams``. - - For example, if ``self.cparams`` is - - .. code-block:: python - - { - 'numel': 'const int &', - 'x': 'const float *', - 'y': 'const float *' - } - - - Then ``py_dict`` sould be - - .. code-block:: python - - { - 'numel': numel, - 'x': x, - 'y': y - } - - where ``numel, x, y`` should be ``torch.Tensor`` or ``cupy.ndarray`` with the corresponding data type, e.g., - ``x`` in ``py_dict`` should have data type ``torch.float`` because ``x`` in ``self.cparams`` have value ``'const float *'`` . - - The keys order is arbitrary because this function will sort keys to align formal and actual parameters. - - .. admonition:: Note - :class: note - - All tensors in ``py_dict`` should be 1D or 2D. - - - .. admonition:: Note - :class: note - - If any 1D tensor ``x`` in ``py_dict`` with data type ``torch.half`` or ``np.float16`` but odd numel will be - flattened and padded by ``x = [x, x[-1]]`` before executing the CUDA kernel. - - If any 2D tensor ``x`` with shape ``[T, N]`` in ``py_dict`` with data type ``torch.half`` or ``np.float16`` - but ``N`` is odd, then ``x`` will be padded as ``x = [x, x[:, -1]]``, whose shape is ``[T, N + 1]``. - - After execution, padded values in ``x`` will be removed, and ``x`` will be reshaped to the origin shape. - """ - self.check_shape(py_dict) - - # pad half2 - pad_keys = [] - pad_shapes = [] - for key, value in py_dict.items(): - if isinstance(value, torch.Tensor) and value.dtype == torch.half: - - if value.ndim <= 1: - # 1D tensor - if value.numel() % 2 != 0: - pad_shapes.append(value.shape) - pad_keys.append(key) - value = value.flatten() - - value = torch.cat((value, value[-1].view(1))) - - py_dict[key] = value - - elif value.shape[1] % 2 != 0: - # 2D tensor with shape = [T, N] and N % 2 != 0 - pad_shapes.append(value.shape) - pad_keys.append(key) - - value = torch.cat((value, value[:, -1].view(-1, 1)), dim=1) - # [T, N] -> [T, N + 1] - py_dict[key] = value - - elif isinstance(value, cupy.ndarray) and value.dtype == np.float16: - - if value.ndim <= 1: - # 1D tensor - if value.size % 2 != 0: - pad_shapes.append(value.shape) - pad_keys.append(key) - value = cupy.reshape(value, -1) - - value = cupy.concatenate((value, cupy.reshape(value[-1], 1))) - - py_dict[key] = value - - elif value.shape[1] % 2 != 0: - pad_shapes.append(value.shape) - pad_keys.append(key) - # [T, N] -> [T, N + 1] - - value = cupy.concatenate( - (value, cupy.reshape(value[:, -1], (-1, 1))), axis=1 - ) - py_dict[key] = value - - super().__call__(grid, block, py_dict, *args_1, **kwargs) - - # move pad values - for i, key in enumerate(pad_keys): - value = py_dict[key] - shape = pad_shapes[i] - if isinstance(value, torch.Tensor): - if value.ndim <= 1: - value = value[:-1] - value = value.view(shape) - else: - value = value[:, :-1] - - elif isinstance(value, cupy.ndarray): - if value.ndim <= 1: - value = value[:, -1] - value = cupy.reshape(value, shape) - - else: - value = value[:, :-1] - - py_dict[key] = value - - @property - def head(self): - codes = """ - { - const int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < N) - { - const int dt = N; - """ - - codes += wrap_with_comment(self.pre_core, "pre_core") - - if self.reverse: - codes += """ - for(int t = numel - N + index; t >= 0; t -= dt) - { - """ - else: - codes += """ - for(int t = index; t < numel; t += dt) - { - """ - return codes - - @property - def tail(self): - codes = """ - } - """ - - codes += wrap_with_comment(self.post_core, "post_core") - - codes += """ - } - } - """ - return codes - - def simple_call(self, **kwargs): - """ - :param kwargs: the dict that contains parameters for CUDA kernel - :type kwargs: dict - - - The simplified calling function, which is simplified from the standard calling function is :class:`CKernel2D.simple_call `. - - Compared with :class:`CKernel2D.simple_call `, - the device, N, numel, numbers of CUDA threads and blocks are calculated automatically from tensors in ``kwargs``. - - Here is the example: - - .. code-block:: python - - import torch - import cupy - from spikingjelly.activation_based.auto_cuda import base - from spikingjelly.activation_based import cuda_utils - - cumsum = base.CKernel2D(kernel_name='cumsum') - cumsum.add_param(ctype='const float *', cname='x') - cumsum.add_param(ctype='float *', cname='y') - - cumsum.core = ''' - if (t - dt < 0) - { - y[t] = x[t]; - } - else - { - y[t] = x[t] + y[t - dt]; - } - ''' - - T = 4 - N = 3 - device = 'cuda:0' - - x = torch.randint(low=0, high=4, size=[T, N], device=device).float() - y = torch.zeros_like(x) - - cumsum.simple_call(x=x, y=y) - print('x=') - print(x) - print('y=') - print(y) - - The outputs are: - - .. code-block:: bash - - x= - tensor([[0., 2., 1.], - [1., 3., 1.], - [2., 2., 0.], - [2., 0., 1.]], device='cuda:0') - y= - tensor([[0., 2., 1.], - [1., 5., 2.], - [3., 7., 2.], - [5., 7., 3.]], device='cuda:0') - - """ - py_dict = kwargs - device = self.get_device(py_dict) - - numel = None - N = None - for value in kwargs.values(): - if isinstance(value, torch.Tensor) and value.ndim == 2: - numel = value.numel() - N = value.shape[1] - elif isinstance(value, cupy.ndarray) and value.ndim == 2: - numel = value.size - N = value.shape[1] - - if numel is None or N is None: - raise ValueError("No 2D torch.Tensor or cupy.ndarray in kwargs!") - - with cuda_utils.DeviceEnvironment(device): - threads = configure.cuda_threads - blocks = cuda_utils.cal_blocks(numel) - numel = cupy.asarray(numel) - N = cupy.asarray(N) - py_dict["numel"] = numel - py_dict["N"] = N - self.__call__((blocks,), (threads,), py_dict) - - -class CodeTyper: - def __init__(self, indent_num: int): - """ - :param indent_num: the number of indents - :type indent_num: int - - A CUDA code formatter with adding indents. The full code can be accessed by ``self.codes``. - - Here is an example: - - .. code-block:: python - - from spikingjelly.activation_based.auto_cuda import base, cfunction - - code0 = cfunction.if_else(z='z', x='x', y='y', mask='mask', dtype='float') - code1 = cfunction.sigmoid_backward(y='y', x='x', alpha=2., dtype='float') - - codes = '' - codes += code0 - codes += code1 - - print('// Without CodeTyper:') - print('// ------------------') - print(codes) - print('// ------------------') - - ctyper = base.CodeTyper(4) - ctyper.append(code0) - ctyper.append(code1) - print('// With CodeTyper:') - print('// ------------------') - print(ctyper.codes) - print('// ------------------') - - .. code-block:: c++ - - // Without CodeTyper: - // ------------------ - z = x * mask + y * (1.0f - mask);const float sigmoid_backward__sigmoid_ax = 1.0f / (1.0f + expf(- (2.0f) * x)); - y = (1.0f - sigmoid_backward__sigmoid_ax) * sigmoid_backward__sigmoid_ax * (2.0f); - // ------------------ - // With CodeTyper: - // ------------------ - - z = x * mask + y * (1.0f - mask); - const float sigmoid_backward__sigmoid_ax = 1.0f / (1.0f + expf(- (2.0f) * x)); - y = (1.0f - sigmoid_backward__sigmoid_ax) * sigmoid_backward__sigmoid_ax * (2.0f); - - // ------------------ - - - """ - self.indent = " " * indent_num - self.codes = "\n" - - def append(self, codes: str): - """ - :param codes: cuda codes to be added - :type codes: str - - Append codes in ``self.codes``. - """ - codes = codes.replace("\n", "") - codes = codes.split(";") - for i in range(codes.__len__()): - if codes[i].__len__() > 0: - if codes[i] in ("{", "}"): - self.codes += self.indent + codes[i] + "\n" - else: - self.codes += self.indent + codes[i] + ";\n" - - -class CodeBlock: - def __init__(self, env: CodeTyper): - """ - :param env: a CodeTyper - :type env: CodeTyper - - A tool for adding a CUDA code block in ``CodeTyper.code``. It is helpful when we want to calculate by intermediate variables. - - Here is an example: - - .. code-block:: python - - from spikingjelly.activation_based.auto_cuda import base - - ctyper = base.CodeTyper(4) - with base.CodeBlock(ctyper): - ctyper.append('// swap x and y') - ctyper.append('float temp_var = x;') - ctyper.append('x = y;') - ctyper.append('y = temp_var;') - - print(ctyper.codes) - - The outputs are: - - .. code-block:: c++ - - { - // swap x and y; - float temp_var = x; - x = y; - y = temp_var; - } - - - """ - self.env = env - - def __enter__(self): - self.env.append("{") - self.env.indent += " " - - def __exit__(self, exc_type, exc_val, exc_tb): - self.env.indent = self.env.indent[:-1] - self.env.append("}") diff --git a/src/chop/nn/snn/auto_cuda/cfunction.py b/src/chop/nn/snn/auto_cuda/cfunction.py deleted file mode 100644 index 5b2aa752a..000000000 --- a/src/chop/nn/snn/auto_cuda/cfunction.py +++ /dev/null @@ -1,426 +0,0 @@ -# *************************************************************************************** -# * Title: auto_cuda -# * Reference: These directory is directly sourced from spikingJelly -# * Availability: https://github.com/fangwei123456/spikingjelly/tree/master/spikingjelly/activation_based/auto_cuda -# * Date: 07/11/2024 -# * Code version: 0.0.0.014 -# *************************************************************************************** - -from typing import Optional - - -def wrap_return_codes(y: Optional[str], codes: str): - if y is None: - return f"({codes})" - else: - return f"{y} = {codes};" - - -def float2half2(y: Optional[str], x: str): - codes = f"__float2half2_rn({x})" - return wrap_return_codes(y, codes) - - -def constant(y: Optional[str], x: float, dtype: str): - if dtype == "float": - codes = f"{x}f" - elif dtype == "half2": - codes = f"__float2half2_rn({x}f)" - else: - raise NotImplementedError(dtype) - - return wrap_return_codes(y, codes) - - -def abs(y: Optional[str], x: str, dtype: str): - if dtype == "float": - codes = f"fabsf({x})" - elif dtype == "half2": - codes = f"__habs2({x})" - else: - raise NotImplementedError(dtype) - return wrap_return_codes(y, codes) - - -def power(z: Optional[str], x: str, y: str, dtype: str): - if dtype == "float": - codes = f"__powf({x, y})" - elif dtype == "half2": - # CUDA FP16 does not provide powf function. We use z = 2 ** (log2(x) * y) - codes = f"h2exp(__hmul2(h2log2({x}), {y}))" - else: - raise NotImplementedError(dtype) - return wrap_return_codes(z, codes) - - -def if_else(z: Optional[str], x: str, y: str, mask: str, dtype: str): - # z = x * mask + y * (1. - mask) - if dtype == "float": - codes = f"{x} * {mask} + {y} * (1.0f - {mask})" - elif dtype == "half2": - codes = f"__hfma2({x}, {mask}, __hmul2({y}, __hsub2(__float2half2_rn(1.0f), {mask})))" - else: - raise NotImplementedError(dtype) - - return wrap_return_codes(z, codes) - - -def if_else_else( - w: Optional[str], x: str, y: str, z: str, mask_x: str, mask_y: str, dtype: str -): - # w = mask_x * x + mask_y * y + (1. - mask_x * mask_y) * z - if dtype == "float": - codes = f"{mask_x} * {x} + {mask_y} * {y} + (1. - {mask_x} * {mask_y}) * {z}" - else: - codes = f"__hadd2(__hadd2(__hmul2({mask_x}, {x}), __hmul2({mask_y}, {y})), __hmul2({z}, __hsub2(__float2half_rn(1.0f), __hmul2({mask_x}, {mask_y}))))" - - return wrap_return_codes(w, codes) - - -def greater_equal(z: Optional[str], x: str, y: str, dtype: str): - if dtype == "float": - codes = f"(float) ({x} >= {y})" - elif dtype == "half2": - codes = f"__hgeu2({x}, {y})" - else: - raise NotImplementedError(dtype) - return wrap_return_codes(z, codes) - - -def greater_than(z: Optional[str], x: str, y: str, dtype: str): - if dtype == "float": - codes = f"(float) ({x} > {y})" - elif dtype == "half2": - codes = f"__hgtu2({x}, {y})" - else: - raise NotImplementedError(dtype) - return wrap_return_codes(z, codes) - - -def minimal(z: Optional[str], x: str, y: str, dtype: str): - if dtype == "float": - codes = f"min({x}, {y})" - elif dtype == "half2": - codes = f"__hmin2({x}, {y})" - else: - raise NotImplementedError(dtype) - return wrap_return_codes(z, codes) - - -def maximum(z: Optional[str], x: str, y: str, dtype: str): - if dtype == "float": - codes = f"max({x}, {y})" - elif dtype == "half2": - codes = f"__hmax2({x}, {y})" - else: - raise NotImplementedError(dtype) - return wrap_return_codes(z, codes) - - -def add(z: Optional[str], x: str, y: str, dtype: str): - if dtype == "float": - - if x == "0.0f": - codes = f"{y}" - - elif y == "0.0f": - codes = f"{x}" - - else: - codes = f"{x} + {y}" - - elif dtype == "half2": - if x == "__float2half2_rn(0.0f)": - codes = f"{y}" - - elif y == "__float2half2_rn(0.0f)": - codes = f"{x}" - else: - codes = f"__hadd2({x}, {y})" - else: - raise NotImplementedError(dtype) - - return wrap_return_codes(z, codes) - - -def sub(z: Optional[str], x: str, y: str, dtype: str): - if dtype == "float": - - if y == "0.0f": - codes = f"{x}" - else: - codes = f"{x} - {y}" - - elif dtype == "half2": - - if y == "__float2half2_rn(0.0f)": - codes = f"{x}" - else: - codes = f"__hsub2({x}, {y})" - else: - raise NotImplementedError(dtype) - - return wrap_return_codes(z, codes) - - -def mul(z: Optional[str], x: str, y: str, dtype: str): - if dtype == "float": - - if x == "1.0f": - codes = f"{y}" - - elif y == "1.0f": - codes = f"{x}" - - else: - codes = f"{x} * {y}" - - elif dtype == "half2": - - if x == "__float2half2_rn(1.0f)": - codes = f"{y}" - - elif y == "__float2half2_rn(1.0f)": - codes = f"{x}" - - else: - codes = f"__hmul2({x}, {y})" - - else: - raise NotImplementedError(dtype) - - return wrap_return_codes(z, codes) - - -def div(z: Optional[str], x: str, y: str, dtype: str): - if dtype == "float": - - if y == "1.0f": - codes = f"{x}" - else: - codes = f"{x} / {y}" - elif dtype == "half2": - if y == "__float2half2_rn(1.0f)": - codes = f"{x}" - else: - codes = f"__h2div({x}, {y})" - else: - raise NotImplementedError(dtype) - - return wrap_return_codes(z, codes) - - -def neg(y: Optional[str], x: str, dtype: str): - if dtype == "float": - codes = f"- {x}" - elif dtype == "half2": - codes = f"__hneg2({x})" - else: - raise NotImplementedError(dtype) - return wrap_return_codes(y, codes) - - -def heaviside(y: Optional[str], x: str, dtype: str): - if dtype == "float": - codes = f"{x} >= 0.0f ? 1.0f: 0.0f" - elif dtype == "half2": - codes = f"__hgeu2({x}, __float2half2_rn(0.0f))" - else: - raise NotImplementedError(dtype) - return wrap_return_codes(y, codes) - - -def exp(y: Optional[str], x: str, dtype: str): - if dtype == "float": - codes = f"expf({x})" - elif dtype == "half2": - codes = f"h2exp({x})" - else: - raise NotImplementedError(dtype) - return wrap_return_codes(y, codes) - - -def sigmoid(y: Optional[str], x: str, alpha: float, dtype: str): - alpha = constant(None, alpha, dtype) - if dtype == "float": - codes = f"1.0f / (1.0f + expf(- {alpha} * {x}))" - elif dtype == "half2": - codes = f"__h2div(__float2half2_rn(1.0f), __hadd2(__float2half2_rn(1.0f), h2exp(__hneg2(__hmul2({alpha}, {x})))))" - - else: - raise NotImplementedError(dtype) - - return wrap_return_codes(y, codes) - - -def sigmoid_backward(y: str, x: str, alpha: float, dtype: str): - assert y is not None - codes = ( - sigmoid( - y=f"const {dtype} sigmoid_backward__sigmoid_ax", - x=x, - alpha=alpha, - dtype=dtype, - ) - + "\n" - ) - alpha = constant(None, alpha, dtype) - if dtype == "float": - codes += f"{y} = (1.0f - sigmoid_backward__sigmoid_ax) * sigmoid_backward__sigmoid_ax * {alpha};" - return codes - elif dtype == "half2": - codes += f"{y} = __hmul2(__hmul2(__hsub2(__float2half2_rn(1.0f), sigmoid_backward__sigmoid_ax), sigmoid_backward__sigmoid_ax), {alpha});" - return codes - else: - raise NotImplementedError(dtype) - - -def atan_backward(y: str, x: str, alpha: float, dtype: str): - assert y is not None - alpha = constant(None, alpha, dtype) - if dtype == "float": - codes = f"const float atan_backward__alpha_x = ((float) 1.57079632679489661923) * {alpha} * {x};" - codes += f"{y} = {alpha} / 2.0f / (1.0f + atan_backward__alpha_x * atan_backward__alpha_x);" - return codes - - elif dtype == "half2": - codes = f"const half2 atan_backward__alpha_x = __hmul2(__hmul2(__float2half2_rn((float) 1.57079632679489661923), {alpha}), {x});" - codes += f"{y} = __h2div({alpha}, __hmul2(__float2half2_rn(2.0f), __hfma2(atan_backward__alpha_x, atan_backward__alpha_x, __float2half2_rn(1.0f))));" - return codes - - else: - raise NotImplementedError(dtype) - - -def piecewise_leaky_relu_backward(y: str, x: str, w: float, c: float, dtype: str): - assert y is not None - w_inv = constant(None, 1.0 / w, dtype) - w = constant(None, w, dtype) - c = constant(None, c, dtype) - - codes = greater_equal( - z=f"const {dtype} piecewise_leaky_relu_backward__mask", - x=w, - y=abs(y=None, x=x, dtype=dtype), - dtype=dtype, - ) - - codes += if_else( - z=y, x=w_inv, y=c, mask=f"piecewise_leaky_relu_backward__mask", dtype=dtype - ) - - return codes - - -def s2nn_backward(y: str, x: str, alpha: float, beta: float, dtype: str): - assert y is not None - codes = sigmoid_backward( - y=f"const {dtype} s2nn_backward__sgax", x=x, alpha=alpha, dtype=dtype - ) - codes += greater_than( - z=f"const {dtype} s2nn_backward__mask", - x=constant(None, 0.0, dtype), - y=x, - dtype=dtype, - ) - - codes += if_else( - z=y, - x=f"s2nn_backward__sgax", - y=div( - z=None, - x=constant(None, beta, dtype), - y=add(z=None, x=x, y=constant(None, 1.0, dtype), dtype=dtype), - dtype=dtype, - ), - mask=f"s2nn_backward__mask", - dtype=dtype, - ) - return codes - - -def q_pseudo_spike_backward(y: str, x: str, alpha: float, dtype: str): - assert y is not None - alpha = constant(None, alpha, dtype) - if dtype == "float": - return f"{y} = __powf(2.0f * fabsf({x}) / ({alpha} - 1.0f) + 1.0f, - {alpha});" - elif dtype == "half2": - return power( - z=y, - x=f"__hadd2(__h2div(__hmul2(__float2half2_rn(2.0f), __habs2({x})), __hsub2({alpha}, __float2half2_rn(1.0f))), __float2half2_rn(1.0f))", - y=f"__hneg2({alpha})", - dtype=dtype, - ) - - -def leaky_k_relu_backward(y: str, x: str, leak: float, k: float, dtype: str): - assert y is not None - leak = constant(None, leak, dtype) - k = constant(None, k, dtype) - codes = greater_equal( - z=f"const {dtype} leaky_k_relu_backward__mask", - x=x, - y=constant(None, 0.0, dtype), - dtype=dtype, - ) - codes += if_else(z=y, x=k, y=leak, mask=f"leaky_k_relu_backward__mask", dtype=dtype) - return codes - - -def fake_numerical_gradient_backward(y: str, x: str, alpha: float, dtype: str): - assert y is not None - alpha = constant(None, alpha, dtype) - codes = greater_equal( - z=f"{dtype} fake_numerical_gradient_backward__mask", - x=x, - y=constant(None, 0.0, dtype), - dtype=dtype, - ) - codes += mul( - z="fake_numerical_gradient_backward__mask", - x="fake_numerical_gradient_backward__mask", - y=constant(None, 2.0, dtype), - dtype=dtype, - ) - codes += sub( - z="fake_numerical_gradient_backward__mask", - x="fake_numerical_gradient_backward__mask", - y=constant(None, 1.0, dtype), - dtype=dtype, - ) - codes += div( - z="fake_numerical_gradient_backward__mask", - x="fake_numerical_gradient_backward__mask", - y=x, - dtype=dtype, - ) - codes += minimal( - z=y, x="fake_numerical_gradient_backward__mask", y=alpha, dtype=dtype - ) - return codes - - -def log_tailed_relu_backward(y: str, x: str, alpha: float, dtype: str): - alpha = constant(None, alpha, dtype) - codes = greater_equal( - z=f"const {dtype} log_tailed_relu_backward__mask_le0", - x=constant(None, 0.0, dtype), - y=x, - dtype=dtype, - ) - codes += greater_than( - z=f"const {dtype} log_tailed_relu_backward__mask_gt1", - x=x, - y=constant(None, 1, dtype), - dtype=dtype, - ) - codes += if_else_else( - w=y, - x=alpha, - y=div(z=None, x=constant(None, 1.0, dtype), y=x, dtype=dtype), - z=constant(None, 1.0, dtype), - mask_x=f"const {dtype} log_tailed_relu_backward__mask_le0", - mask_y=f"const {dtype} log_tailed_relu_backward__mask_gt1", - dtype=dtype, - ) - return codes diff --git a/src/chop/nn/snn/auto_cuda/example.py b/src/chop/nn/snn/auto_cuda/example.py deleted file mode 100644 index 48edce05a..000000000 --- a/src/chop/nn/snn/auto_cuda/example.py +++ /dev/null @@ -1,44 +0,0 @@ -# *************************************************************************************** -# * Title: auto_cuda -# * Reference: These directory is directly sourced from spikingJelly -# * Availability: https://github.com/fangwei123456/spikingjelly/tree/master/spikingjelly/activation_based/auto_cuda -# * Date: 07/11/2024 -# * Code version: 0.0.0.014 -# *************************************************************************************** - - -from spikingjelly.activation_based.auto_cuda.generator import ( - analyse_graph, - gen_forward_codes, - gen_backward_codes, -) -from spikingjelly.activation_based import surrogate - -import torch - -if __name__ == "__main__": - - def lif_charge(x: torch.Tensor, v_last: torch.Tensor, tau: float, v_reset: float): - h = v_last + (x - (v_last - v_reset)) / tau - return h - - input_nodes, inter_nodes, output_nodes, cmds = analyse_graph( - lif_charge, requires_grad=(True, True, False, False) - ) - - forward_codes, forward_kernel_name, cuda_cmds = gen_forward_codes( - input_nodes, inter_nodes, output_nodes, cmds, hard_reset=True - ) - - backward_codes, backward_kernel_name, input_bp_vars = gen_backward_codes( - cuda_cmds, - input_nodes, - output_nodes, - cmds, - hard_reset=True, - detach_reset=True, - surrogate_fuction=surrogate.ATan(), - ) - - print(f"forward_codes = \n{forward_codes}") - print(f"backward_codes = \n{backward_codes}") diff --git a/src/chop/nn/snn/auto_cuda/generator.py b/src/chop/nn/snn/auto_cuda/generator.py deleted file mode 100644 index 639cf0379..000000000 --- a/src/chop/nn/snn/auto_cuda/generator.py +++ /dev/null @@ -1,667 +0,0 @@ -# *************************************************************************************** -# * Title: auto_cuda -# * Reference: These directory is directly sourced from spikingJelly -# * Availability: https://github.com/fangwei123456/spikingjelly/tree/master/spikingjelly/activation_based/auto_cuda -# * Date: 07/11/2024 -# * Code version: 0.0.0.014 -# *************************************************************************************** - -import logging -import torch -import torch.nn as nn -import torch.nn.functional as F -import re -import sys -import copy -from typing import Callable -import numpy as np - - -def hash_str(x: object): - hash_code = hash(x) - if hash_code < 0: - return f"_{-hash_code}" - else: - return hash_code - - -class VarNode: - def __init__(self, prefix: str, name: str, instance: object, value=None): - self.debug_name = name # 原始的name形如 %8, v_last.1 - # 将原始的name进行转换 - self.name = prefix + "_" + name.replace(".", "_") - - self.instance = str(instance) - # 中间节点的self.instance,在生成前向传播cuda代码时,若debug_instance为Tensor,self.instance会被修改为float - self.value = value - self.requires_grad = False - self.cu_var_suffix = "" - - @property - def name_bp(self): - return "grad_" + self.name - - @property - def cu_var(self): - # 前向传播时,在cuda代码中的变量名 - - # 如果value非空,表明其是一个常数值,直接返回数值即可,例如 value = 0.1 返回 '0.1f' - if self.value is not None: - if self.instance == "int": - return self.name - elif self.instance == "float": - return self.name + "f" - else: - raise ValueError(self.instance) - - # value空,表示其是一个变量 - - return self.name + self.cu_var_suffix - - @property - def cu_var_bp(self): - - # 反向传播时在cuda代码中的变量名 - if self.value is not None: - raise ValueError - else: - return "grad_" + self.cu_var - - def __repr__(self): - return f"({self.debug_name}, {self.name}, {self.instance}, value={self.value}, rg={self.requires_grad})" - - -def analyse_graph(custom_fun, requires_grad: tuple): - graph: torch.Graph = torch.jit.script(custom_fun).graph - - logging.debug(f"\ngraph = {graph}") - # 生成 输入 中间 输出 节点 - assert sys.version_info.major >= 3 and sys.version_info.minor >= 6 - # python >= 3.6时,字典默认是有序的 - # key是VarNode.debug_name,value是VarNode - input_nodes = {} - output_nodes = {} - inter_nodes = {} - - assert custom_fun.__annotations__.__len__() >= 2 - for i, (item, name) in enumerate( - zip(graph.inputs(), custom_fun.__annotations__.keys()) - ): - # 要求custom_fun一定是custom_fun(x: torch.Tensor, v_last: torch.Tensor, ...)的形式 - if i == 0: - assert str(item.type()) == "Tensor" and name == "x" - elif i == 1: - assert str(item.type()) == "Tensor" and name == "v_last" - - # 用python函数中的name覆盖掉jit自动生成的name - # 仅包括输入。中间变量的命名仍然是jit设置的,不会被更改 - item.setDebugName(name) - - node = VarNode(prefix="input", name=item.debugName(), instance=item.type()) - if node.instance == "Tensor" and requires_grad[i]: - node.requires_grad = True - - logging.debug(f"\ninput node [{i}] = {node}") - assert node not in input_nodes - input_nodes[node.debug_name] = node - - for i, item in enumerate(graph.outputs()): - - if i == 0: - assert str(item.type()) == "Tensor" - item.setDebugName("h") - - elif i > 0: - raise NotImplementedError( - "For the moment, we only support for single output!" - ) - - node = VarNode(prefix="output", name=item.debugName(), instance=item.type()) - - logging.debug(f"\noutput node [{i}] = {node}") - assert node not in output_nodes - output_nodes[node.debug_name] = node - - cmds = [] - # cmds的元素是一个元组,为 (output, fun, inputs) - # 这里的output是VarNode,fun是str,inputs是(VarNode) - for node in graph.nodes(): - # item: torch.Note - fun = node.kind() - if fun == "prim::Constant": - - item = node.output() - assert ( - item.debugName() not in input_nodes - and item.debugName() not in output_nodes - ) - - i_node = VarNode( - prefix="inter", name=item.debugName(), instance=item.type() - ) - value = None - - # 从命令中提取出常数值 - if i_node.instance == "int": - pattern = re.compile(r".*prim::Constant\[value=([0-9]+)\]") - m = pattern.match(str(node)) - value = int(m.groups()[0]) - - elif i_node.instance == "float": - pattern = re.compile(r".*prim::Constant\[value=([0-9\.]+)\]") - m = pattern.match(str(node)) - value = float(m.groups()[0]) - - else: - raise NotImplementedError - - i_node.value = value - assert i_node.debug_name not in input_nodes - assert i_node.debug_name not in output_nodes - if i_node.debug_name not in inter_nodes: - inter_nodes[i_node.debug_name] = i_node - - cmds.append((i_node, fun, ())) - - else: - inputs = [] - - for item in node.inputs(): - - if item.debugName() in input_nodes: - i_node = input_nodes[item.debugName()] - - elif item.debugName() in output_nodes: - i_node = output_nodes[item.debugName()] - - else: - # 只有既不为输入node也不为输出node的node,才会被视作中间node - if item.debugName() in inter_nodes: - i_node = inter_nodes[item.debugName()] - else: - i_node = VarNode( - prefix="inter", name=item.debugName(), instance=item.type() - ) - inter_nodes[i_node.debug_name] = i_node - - inputs.append(i_node) - - item = node.output() - if item.debugName() in input_nodes: - i_node = input_nodes[item.debugName()] - - elif item.debugName() in output_nodes: - i_node = output_nodes[item.debugName()] - - else: - # 只有既不为输入node也不为输出node的node,才会被视作中间node - if item.debugName() in inter_nodes: - i_node = inter_nodes[item.debugName()] - else: - i_node = VarNode( - prefix="inter", name=item.debugName(), instance=item.type() - ) - inter_nodes[i_node.debug_name] = i_node - - cmds.append((i_node, fun, tuple(inputs))) - - for i, node in enumerate(inter_nodes.values()): - logging.debug(f"\ninter node [{i}] = {node}") - - return input_nodes, inter_nodes, output_nodes, cmds - - -def gen_forward_codes( - input_nodes: dict, - inter_nodes: dict, - output_nodes: dict, - cmds: list, - hard_reset: bool, -): - # 暂时只支持单个输出 - assert output_nodes.__len__() == 1 - - # 代码生成 - codes = "\n" - codes += " " - codes += "{\n" - - for node in input_nodes.values(): - # 赋值到代码段的变量 - if node.debug_name == "x": - codes += " " - codes += f"const float {node.cu_var} = x_seq[t];\n" - elif node.debug_name == "v_last": - codes += " " - codes += f"const float {node.cu_var} = v_v_seq[t];\n" - else: - if node.instance == "Tensor": - node.cu_var_suffix = "_t" - codes += " " - codes += f"const float {node.cu_var} = {node.name}[t];\n" - - # instance为float的不需要提前赋值,因为不需要索引(直接从cuda函数的参数中取出即可) - - # 记录在自动生成的cuda代码段中,哪些cu_var是已经声明的 - code_block_nodes = {} - - cuda_cmds = [] - for item in cmds: - output, fun, inputs = item - codes += " " - if fun == "prim::Constant": - gen_cmd = "\n" - elif fun in ["aten::add", "aten::sub"]: - # z = x + y * alpha - x, y, alpha = inputs - z = output - z.requires_grad = x.requires_grad or y.requires_grad - if z.cu_var not in code_block_nodes: - code_block_nodes[z.cu_var] = z - codes += "float " - - if fun == "aten::add": - op = "+" - else: - op = "-" - - if alpha.value == 1: - gen_cmd = f"{z.cu_var} = {x.cu_var} {op} {y.cu_var};\n" - else: - gen_cmd = f"{z.cu_var} = {x.cu_var} {op} {y.cu_var} * {alpha.cu_var};\n" - - elif fun in ["aten::mul", "aten::div"]: - x, y = inputs - z = output - z.requires_grad = x.requires_grad or y.requires_grad - if z.cu_var not in code_block_nodes: - code_block_nodes[z.cu_var] = z - codes += "float " - if fun == "aten::mul": - op = "*" - else: - op = "/" - - gen_cmd = f"{z.cu_var} = {x.cu_var} {op} {y.cu_var};\n" - else: - raise NotImplementedError(fun) - - codes += gen_cmd - cuda_cmds.append(gen_cmd) - - for i, node in enumerate(output_nodes.values()): - # 代码段的变量赋值到输出 - if i == 0: - codes += " " - codes += f"h_seq[t] = {node.name};\n" - - codes += " " - codes += "}\n" - - # CUDA函数的参数 - params = [ - ("x_seq", "const float *"), - ("v_v_seq", "float *"), - ("h_seq", "float *"), - ("spike_seq", "float *"), - ("v_threshold", "const float &"), - ] - if hard_reset: - params.append(("v_reset", "const float &")) - - params.extend( - [ - ("neuron_num", "const int &"), - ("numel", "const int &"), - ] - ) - params_name = [] - for item in params: - params_name.append(item[0]) - - # 在CUDA函数参数中增加参数,同时检测命名冲突 - - for node in inter_nodes.values(): - assert node.name not in params_name - - for node in input_nodes.values(): - if node.debug_name in ["x", "v_last"]: - pass - else: - assert node.name not in params_name - - if node.instance == "Tensor": - param = (node.name, "const float *") - elif node.instance == "float": - param = (node.name, "const float &") - else: - raise NotImplementedError - params.append(param) - - for node in output_nodes.values(): - assert node.name not in params_name - - for i in range(params.__len__()): - param = params[i] - params[i] = param[1] + param[0] - - head = ", ".join(params) - head = "(" + head + ")" - - head += """ - { - const int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < neuron_num) - { - const int dt = neuron_num; - for(int mem_offset = 0; mem_offset < numel; mem_offset += neuron_num) - { - const int t = index + mem_offset; - """ - tail = """ - if (h_seq[t] >= v_threshold) - - { - spike_seq[t] = 1.0f; - v_v_seq[t + dt] = v_reset; - } - - else - { - spike_seq[t] = 0.0f; - v_v_seq[t + dt] = h_seq[t]; - } - } - } - } - """ - - codes = head + codes + tail - - kernel_name = f"forward_kernel_{hash_str(codes)}" - codes = ( - f""" - extern "C" __global__ - void {kernel_name} - """ - + codes - ) - - return codes, kernel_name, cuda_cmds - - -def gen_backward_codes( - cuda_cmds: list, - input_nodes: dict, - output_nodes: dict, - cmds: list, - hard_reset: bool, - detach_reset: bool, - surrogate_fuction, -): - """ - 用户定义的前向传播函数为 - h_seq[t] = fun(x_seq[t], v_v_seq[t], ...) - - 需要计算出 h_seq[t] -> x_seq[t] 的梯度和 h_seq[t] -> v_v_seq[t]的梯度 - 还需要考虑 ... 中如果有tensor,可以增加flag,决定是否计算h_seq[t]对其的梯度 - """ - - input_bp_nodes = {} - """ - 在反向传播时,输入梯度是output_nodes的梯度 - 有些变量的梯度在计算时,需要用到其他变量,例如z = x * y,计算grad_x需要用到y - input_bp_nodes用来记录哪些node要用到 - """ - - # 记录在自动生成的cuda代码段中,哪些cu_var是已经声明的 - code_block_nodes = {} - - codes = "\n" - - for i in range(cmds.__len__()): - output, fun, inputs = cmds[cmds.__len__() - 1 - i] - codes += "\n" - codes += " " - codes += f"// {cuda_cmds[cmds.__len__() - 1 - i]}" - if fun == "prim::Constant": - codes += "\n" - elif fun == "aten::add": - # z = x + y * alpha - x, y, alpha = inputs - z = output - if alpha.value == 1: - if x.requires_grad: - if x.cu_var_bp not in code_block_nodes: - code_block_nodes[x.cu_var_bp] = x - codes += " " - codes += f"float {x.cu_var_bp} = {z.cu_var_bp};\n" - else: - codes += " " - codes += f"{x.cu_var_bp} += {z.cu_var_bp};\n" - if y.requires_grad: - if y.cu_var_bp not in code_block_nodes: - code_block_nodes[y.cu_var_bp] = y - codes += " " - codes += f"float {y.cu_var_bp} = {z.cu_var_bp};\n" - else: - codes += " " - codes += f"{y.cu_var_bp} += {z.cu_var_bp};\n" - else: - if x.requires_grad: - if x.cu_var_bp not in code_block_nodes: - code_block_nodes[x.cu_var_bp] = x - codes += " " - codes += f"float {x.cu_var_bp} = {z.cu_var_bp};\n" - else: - codes += " " - codes += f"{x.cu_var_bp} += {z.cu_var_bp};\n" - if y.requires_grad: - if y.cu_var_bp not in code_block_nodes: - code_block_nodes[y.cu_var_bp] = y - codes += " " - codes += f"float {y.cu_var_bp} = {z.cu_var_bp} * {alpha.cu_var_bp};\n" - else: - codes += " " - codes += ( - f"{y.cu_var_bp} += {z.cu_var_bp} * {alpha.cu_var_bp};\n" - ) - - elif fun == "aten::sub": - # z = x - y * alpha - x, y, alpha = inputs - z = output - if alpha.value == 1: - if x.requires_grad: - if x.cu_var_bp not in code_block_nodes: - code_block_nodes[x.cu_var_bp] = x - codes += " " - codes += f"float {x.cu_var_bp} = {z.cu_var_bp};\n" - else: - codes += " " - codes += f"{x.cu_var_bp} += {z.cu_var_bp};\n" - if y.requires_grad: - if y.cu_var_bp not in code_block_nodes: - code_block_nodes[y.cu_var_bp] = y - codes += " " - codes += f"float {y.cu_var_bp} = - {z.cu_var_bp};\n" - else: - codes += " " - codes += f"{y.cu_var_bp} += - {z.cu_var_bp};\n" - else: - if x.requires_grad: - if x.cu_var_bp not in code_block_nodes: - code_block_nodes[x.cu_var_bp] = x - codes += " " - codes += f"float {x.cu_var_bp} = {z.cu_var_bp};\n" - else: - codes += " " - codes += f"{x.cu_var_bp} += {z.cu_var_bp};\n" - if y.requires_grad: - if y.cu_var_bp not in code_block_nodes: - code_block_nodes[y.cu_var_bp] = y - codes += " " - codes += f"float {y.cu_var_bp} = - {z.cu_var_bp} * {alpha.cu_var_bp};\n" - else: - codes += " " - codes += ( - f"{y.cu_var_bp} += - {z.cu_var_bp} * {alpha.cu_var_bp};\n" - ) - - elif fun == "aten::mul": - # z = x * y - x, y = inputs - z = output - if x.requires_grad: - if x.cu_var_bp not in code_block_nodes: - code_block_nodes[x.cu_var_bp] = x - codes += " " - codes += f"float {x.cu_var_bp} = {z.cu_var_bp} * {y.cu_var};\n" - else: - codes += " " - codes += f"{x.cu_var_bp} += {z.cu_var_bp} * {y.cu_var};\n" - input_bp_nodes[y.name] = y - if y.requires_grad: - if y.cu_var_bp not in code_block_nodes: - code_block_nodes[y.cu_var_bp] = y - codes += " " - codes += f"float {y.cu_var_bp} = {z.cu_var_bp} * {x.cu_var};\n" - else: - codes += " " - codes += f"{y.cu_var_bp} += {z.cu_var_bp} * {x.cu_var};\n" - input_bp_nodes[x.name] = x - - elif fun == "aten::div": - # z = x / y - x, y = inputs - z = output - if x.requires_grad: - if x.cu_var_bp not in code_block_nodes: - code_block_nodes[x.cu_var_bp] = x - codes += " " - codes += f"float {x.cu_var_bp} = {z.cu_var_bp} / {y.cu_var};\n" - else: - codes += " " - codes += f"{x.cu_var_bp} += {z.cu_var_bp} / {y.cu_var};\n" - input_bp_nodes[y.name] = y - if y.requires_grad: - if y.cu_var_bp not in code_block_nodes: - code_block_nodes[y.cu_var_bp] = y - codes += " " - codes += f"float {y.cu_var_bp} = - {z.cu_var_bp} * {x.cu_var} / ({y.cu_var} * {y.cu_var});\n" - else: - codes += " " - codes += f"{y.cu_var_bp} += - {z.cu_var_bp} * {x.cu_var} / ({y.cu_var} * {y.cu_var});\n" - input_bp_nodes[x.name] = x - input_bp_nodes[y.name] = y - - for i, node in enumerate(input_bp_nodes): - logging.debug(f"\ninput bp node [{i}] = {node}") - - # CUDA函数的参数 - cuda_params = { - "grad_spike_seq": "const float *", - "grad_v_seq": "const float *", - "h_seq": "const float *", - "spike_seq": "const float *", - "grad_x_seq": "float *", - "grad_v_init": "float *", - "v_threshold": "const float &", - } - - if hard_reset: - cuda_params["v_reset"] = "const float &" - - cuda_params["neuron_num"] = "const int &" - cuda_params["numel"] = "const int &" - - # 在CUDA函数参数中增加参数,同时检测命名冲突 - - # 这里增加的是用户自定义的除了x和v_last外,其他需要梯度的python函数的参数 - for i, node in enumerate(input_nodes.values()): - if i >= 2: - if node.name_bp not in cuda_params: - if node.requires_grad: - cuda_params[node.name_bp] = "const float *" - - # 这里增加的是反向传播所需要的参数 - for node in input_bp_nodes.values(): - if node.name not in cuda_params: - assert node.debug_name in input_nodes or node.debug_name in output_nodes - - if node.instance == "Tensor": - cuda_params[node.name] = "const float *" - elif node.instance == "float": - cuda_params[node.name] = "const float &" - else: - raise NotImplementedError(node) - - params = [] - for cuda_param, cuda_param_instance in cuda_params.items(): - params.append(cuda_param_instance + cuda_param) - - head = ", ".join(params) - head = "(" + head + ")" - - head += """ - { - const int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < neuron_num) - { - float grad_output_h = 0.0f; // grad_output_h will be used recursively - for(int mem_offset = numel - neuron_num; mem_offset >= 0; mem_offset -= neuron_num) - { - const int t = index + mem_offset; - const float over_th = h_seq[t] - v_threshold; - """ - head += surrogate_fuction.cuda_code(x="over_th", y="grad_s_to_h", dtype="fp32") - - head += " " - if detach_reset: - if hard_reset: - head += "const float grad_v_to_h = 1.0f - spike_seq[t];\n" - else: - head += "const float grad_v_to_h = 1.0f;\n" - else: - if hard_reset: - head += "const float grad_v_to_h = 1.0f - spike_seq[t] + (-h_seq[t] + v_reset) * grad_s_to_h;\n" - else: - head += "const float grad_v_to_h = 1.0f - v_threshold * grad_s_to_h;\n" - - tail = "" - # grad_input_x, grad_input_v_last是自动生成的代码计算出来的 - tail += " " - tail += "grad_output_h = grad_spike_seq[t] * grad_s_to_h + (grad_v_seq[t] + grad_input_v_last) * grad_v_to_h;\n" - - for i, node in enumerate(input_nodes.values()): - if i >= 2: - if node.requires_grad: - tail += " " - tail += f"{node.name_bp}[t] = {node.cu_var_bp};\n" - - tail += """ - } - """ - - tail += codes - # += codes 是为了计算grad_v_init[index] - tail += """ - grad_v_init[index] = grad_input_v_last; - } - } - """ - codes = head + codes + tail - kernel_name = f"backward_kernel_{hash_str(codes)}" - codes = ( - f""" - extern "C" __global__ - void {kernel_name} - """ - + codes - ) - - input_bp_vars = [] - # input_bp_vars记录了python函数中的哪些输入变量,是计算反向传播所需的 - for node in input_bp_nodes.values(): - input_bp_vars.append(node.debug_name) - return codes, kernel_name, input_bp_vars diff --git a/src/chop/nn/snn/auto_cuda/neuron_kernel.py b/src/chop/nn/snn/auto_cuda/neuron_kernel.py deleted file mode 100644 index 1d5c9935d..000000000 --- a/src/chop/nn/snn/auto_cuda/neuron_kernel.py +++ /dev/null @@ -1,1045 +0,0 @@ -# *************************************************************************************** -# * Title: auto_cuda -# * Reference: These directory is directly sourced from spikingJelly -# * Availability: https://github.com/fangwei123456/spikingjelly/tree/master/spikingjelly/activation_based/auto_cuda -# * Date: 07/11/2024 -# * Code version: 0.0.0.014 -# *************************************************************************************** - -from typing import Optional -import torch -import torch.nn.functional as F -import numpy as np -import logging - -try: - import cupy -except BaseException as e: - logging.info(f"spikingjelly.activation_based.auto_cuda.neuronal_kernel: {e}") - cupy = None - - -from .. import cuda_utils -from .. import configure -from typing import Callable, Iterable -from . import base, cfunction -import math - - -def neuronal_hard_reset( - v_next: str, h: str, spike: str, v_reset: str, dtype: str = "float" -): - if dtype == "float": - return f"{v_next} = {h} * (1.0f - {spike}) + {v_reset} * {spike};" - elif dtype == "half2": - return f"{v_next} = __hfma2({h}, __hsub2(__float2half2_rn(1.0f), {spike}), __hmul2(v_reset, {spike}));" - else: - raise NotImplementedError(dtype) - - -def neuronal_soft_reset( - v_next: str, h: str, spike: str, v_th: str, dtype: str = "float" -): - if dtype == "float": - return f"{v_next} = {h} - {v_th} * {spike};" - elif dtype == "half2": - return f"{v_next} = __hsub2({h}, __hmul2({v_th}, {spike}));" - else: - raise NotImplementedError(dtype) - - -def neuronal_fire(spike: str, v: str, v_th: str, dtype: str = "float"): - if dtype == "float": - return cfunction.heaviside(y=spike, x=f"({v} - {v_th})", dtype=dtype) - elif dtype == "half2": - return cfunction.heaviside(y=spike, x=f"__hsub2({v}, {v_th})", dtype=dtype) - else: - raise NotImplementedError(dtype) - - -class NeuronFPTTKernel(base.CKernel2D): - def __init__(self, hard_reset: bool, dtype: str): - super().__init__( - kernel_name=f'{self.__class__.__name__}_{dtype}_{"hard_reset" if hard_reset else "soft_reset"}', - reverse=False, - ) - self.hard_reset = hard_reset - self.dtype = dtype - self.add_param(ctype=f"const {dtype} *", cname="x_seq") - self.add_param(ctype=f"{dtype} *", cname="v_v_seq") - self.add_param(ctype=f"{dtype} *", cname="h_seq") - self.add_param(ctype=f"{dtype} *", cname="spike_seq") - self.add_param(ctype=f"{dtype} &", cname="v_th") - if hard_reset: - self.add_param(ctype=f"{dtype} &", cname="v_reset") - - def neuronal_charge(self) -> str: - """ - :return: CUDA code - :rtype: str - - Returns CUDA code for calculating :math:`H[t] = f(X[t], V[t-1], ...)`. - - This function should define how ``h_seq[t]`` is calculated by ``x_seq[t], v_v_seq[t]`` and other params if - the neuron needs. - - For example, the IF neuron define this function as: - - .. code-block:: python - - def neuronal_charge(self) -> str: - # note that v_v_seq[t] is v_seq[t - dt] - return cfunction.add(z='h_seq[t]', x='x_seq[t]', y='v_v_seq[t]', dtype=self.dtype) - """ - return "// neuronal_charge should be defined here!" - - @property - def core(self): - core_codes = base.CodeTyper(18) - - core_codes.append(self.neuronal_charge()) - - core_codes.append( - neuronal_fire( - spike="spike_seq[t]", v="h_seq[t]", v_th="v_th", dtype=self.dtype - ) - ) - - if self.hard_reset: - core_codes.append( - neuronal_hard_reset( - v_next="v_v_seq[t + dt]", - h="h_seq[t]", - spike="spike_seq[t]", - v_reset="v_reset", - dtype=self.dtype, - ) - ) - else: - core_codes.append( - neuronal_soft_reset( - v_next="v_v_seq[t + dt]", - h="h_seq[t]", - spike="spike_seq[t]", - v_th="v_th", - dtype=self.dtype, - ) - ) - - self._core = core_codes.codes - return self._core - - -class NeuronBPTTKernel(base.CKernel2D): - def __init__( - self, - surrogate_function: Callable, - hard_reset: bool, - detach_reset: bool, - dtype: str, - ): - super().__init__( - kernel_name=f'{self.__class__.__name__}_{dtype}_{"hard_reset" if hard_reset else "soft_reset"}_{"detach_reset" if detach_reset else "nodetach_reset"}', - reverse=True, - ) - self.surrogate_function = surrogate_function - self.hard_reset = hard_reset - self.detach_reset = detach_reset - self.dtype = dtype - self.add_param(ctype=f"const {dtype} *", cname="grad_spike_seq") - self.add_param(ctype=f"const {dtype} *", cname="grad_v_seq") - self.add_param(ctype=f"const {dtype} *", cname="h_seq") - self.add_param(ctype=f"{dtype} *", cname="grad_x_seq") - self.add_param(ctype=f"{dtype} *", cname="grad_v_init") - self.add_param(ctype=f"{dtype} &", cname="v_th") - if hard_reset: - self.add_param(ctype=f"{dtype} &", cname="v_reset") - - @property - def pre_core(self): - codes = base.CodeTyper(16) - if self.dtype == "float": - codes.append("float grad_h = 0.0f;") - elif self.dtype == "half2": - codes.append(cfunction.float2half2(y="half2 grad_h", x="0.0f")) - else: - raise NotImplementedError(self.dtype) - - self._pre_core = codes.codes - return self._pre_core - - @property - def post_core(self): - - codes = base.CodeTyper(16) - codes.append(self.grad_h_next_to_v()) - codes.append( - cfunction.mul( - z="grad_v_init[index]", - x="grad_h", - y="grad_h_next_to_v", - dtype=self.dtype, - ) - ) - self._post_core = codes.codes - return self._post_core - - def grad_h_next_to_v(self) -> str: - """ - :return: CUDA code - :rtype: str - - Returns CUDA code for calculating :math:`\\frac{\\mathrm{d} H[t+1]}{\\mathrm{d} V[t]}`. - - This function should define how ``grad_h_next_to_v`` is calculated. Note that ``grad_h_next_to_v`` has not been - declared. Thus, this function should also declare ``grad_h_next_to_v``. - - For example, the IF neuron define this function as: - - .. code-block:: python - - def grad_h_next_to_v(self) -> str: - return cfunction.constant(y=f'const {self.dtype} grad_h_next_to_v', x=1., dtype=self.dtype) - """ - return "// grad_h_next_to_v should be defined here!" - - def grad_h_to_x(self) -> str: - """ - :return: CUDA code - :rtype: str - - Returns CUDA code for calculating :math:`\\frac{\\mathrm{d} H[t]}{\\mathrm{d} X[t]}`. - - This function should define how ``grad_h_to_x`` is calculated. Note that ``grad_h_to_x`` has not been - declared. Thus, this function should also declare ``grad_h_to_x``. - - For example, the IF neuron define this function as: - - .. code-block:: python - - def grad_h_to_x(self) -> str: - return cfunction.constant(y=f'const {self.dtype} grad_h_to_x', x=1., dtype=self.dtype) - """ - return "// grad_h_to_x should be defined here!" - - @property - def core(self): - core_codes = base.CodeTyper(18) - - core_codes.append( - cfunction.sub( - z=f"const {self.dtype} over_th", - x="h_seq[t]", - y="v_th", - dtype=self.dtype, - ) - ) - core_codes.append( - cfunction.heaviside( - y=f"const {self.dtype} spike_seq_t", x="over_th", dtype=self.dtype - ) - ) - core_codes.append( - self.surrogate_function( - y=f"const {self.dtype} grad_s_to_h", x="over_th", dtype=self.dtype - ) - ) - - if self.hard_reset: - core_codes.append( - cfunction.sub( - z=f"{self.dtype} grad_v_to_h", - x=cfunction.constant(y=None, x=1.0, dtype=self.dtype), - y="spike_seq_t", - dtype=self.dtype, - ) - ) - - if not self.detach_reset: - with base.CodeBlock(core_codes): - core_codes.append( - cfunction.sub( - z=f"{self.dtype} temp_var", - x="v_reset", - y="h_seq[t]", - dtype=self.dtype, - ) - ) - core_codes.append( - cfunction.mul( - z=f"temp_var", - x="temp_var", - y="grad_s_to_h", - dtype=self.dtype, - ) - ) - core_codes.append( - cfunction.add( - z=f"grad_v_to_h", - x="temp_var", - y="grad_v_to_h", - dtype=self.dtype, - ) - ) - - else: - core_codes.append( - f"{self.dtype} grad_v_to_h = {cfunction.constant(None, 1., dtype=self.dtype)}" - ) - - if not self.detach_reset: - with base.CodeBlock(core_codes): - core_codes.append( - cfunction.mul( - z=f"{self.dtype} temp_var", - x="v_th", - y="grad_s_to_h", - dtype=self.dtype, - ) - ) - core_codes.append( - cfunction.sub( - z=f"grad_v_to_h", - x="grad_v_to_h", - y="temp_var", - dtype=self.dtype, - ) - ) - - core_codes.append(self.grad_h_next_to_v()) - core_codes.append( - cfunction.mul( - z=f"grad_h", x="grad_h", y="grad_h_next_to_v", dtype=self.dtype - ) - ) - core_codes.append( - cfunction.add(z="grad_h", x="grad_v_seq[t]", y="grad_h", dtype=self.dtype) - ) - core_codes.append( - cfunction.mul(z="grad_h", x="grad_h", y="grad_v_to_h", dtype=self.dtype) - ) - with base.CodeBlock(core_codes): - core_codes.append( - cfunction.mul( - z=f"{self.dtype} temp_var", - x="grad_spike_seq[t]", - y="grad_s_to_h", - dtype=self.dtype, - ) - ) - core_codes.append( - cfunction.add(z="grad_h", x="grad_h", y="temp_var", dtype=self.dtype) - ) - - core_codes.append(self.grad_h_to_x()) - core_codes.append( - cfunction.mul( - z="grad_x_seq[t]", x="grad_h", y="grad_h_to_x", dtype=self.dtype - ) - ) - - self._core = core_codes.codes - return self._core - - -class IFNodeFPTTKernel(NeuronFPTTKernel): - def neuronal_charge(self) -> str: - return cfunction.add( - z="h_seq[t]", x="x_seq[t]", y="v_v_seq[t]", dtype=self.dtype - ) - - -class IFNodeBPTTKernel(NeuronBPTTKernel): - def grad_h_next_to_v(self) -> str: - return cfunction.constant( - y=f"const {self.dtype} grad_h_next_to_v", x=1.0, dtype=self.dtype - ) - - def grad_h_to_x(self) -> str: - return cfunction.constant( - y=f"const {self.dtype} grad_h_to_x", x=1.0, dtype=self.dtype - ) - - -def if_requires_grad(items: Iterable): - requires_grad = False - for item in items: - if isinstance(item, torch.Tensor): - if item.requires_grad: - requires_grad = True - break - - return requires_grad - - -def scalar_to_cupy(py_dict: dict, ref: str = "x_seq"): - device = py_dict[ref].get_device() - dtype = py_dict[ref].dtype - - with cuda_utils.DeviceEnvironment(device): - for key, value in py_dict.items(): - if isinstance(value, float): - if dtype == torch.float32: - value = cupy.asarray(value, dtype=np.float32) - elif dtype == torch.float16: - value = cupy.asarray([value, value], dtype=np.float16) - else: - raise NotImplementedError(dtype) - py_dict[key] = value - - elif isinstance(value, int): - py_dict[key] = cupy.asarray(value) - - -def new_tensors(news: tuple, py_dict: dict, ref: str = "x_seq"): - ref = py_dict[ref] - zero_shape = list(ref.shape) - zero_shape[0] *= news.__len__() - for i, item in enumerate( - torch.split( - torch.zeros(zero_shape, device=ref.device, dtype=ref.dtype), ref.shape[0] - ) - ): - py_dict[news[i]] = item - - -class NeuronATGFBase: - @staticmethod - def pre_forward(py_dict: dict): - """ - :param py_dict: a dict built from the neuron's forward autograd function. It should at least contain ``x_seq, v_init, v_reset`` - :type py_dict: dict - :return: requires_grad, blocks, threads, py_dict - - requires_grad: bool - if any tensor in ``py_dict`` requires grad, then ``requires_grad = True``;else ``requires_grad = False`` - - blocks: int - CUDA param used in calling CUDA kernel - - threads: int - CUDA param used in calling CUDA kernel. The default value is ``spikingjelly.configure.cuda_threads`` - - py_dict: dict - Compared with the input ``py_dict``, the returned ``py_dict`` will: - - * convert all ``float/int`` scalars in ``py_dict`` to ``cupy.ndarray`` - - * add ``h_seq, spike_seq, v_v_seq`` to ``py_dict``. ``h_seq, spike_seq`` are zero tensors - with the same shape with ``x_seq``. ``v_v_seq`` is concatenated from ``v_init`` and - ``v_seq``, which is zero tensors with the same shape with ``x_seq`` - - * add ``N, numel`` to ``py_dict``. Note that ``x_seq.shape = [T, N]`` and ``numel = T * N``. - A specific case is that ``x_seq.dtype == torch.half``, then ``N = math.ceil(N / 2)``, and - ``numel = N * x_seq.shape[0]``. - Note that ``N, numel`` in the returned ``py_dict`` are ``cupy.ndarray`` - - - :rtype: tuple - """ - device = py_dict["x_seq"].get_device() - requires_grad = if_requires_grad(py_dict.values()) - scalar_to_cupy(py_dict) - - new_tensors(("h_seq", "spike_seq", "v_seq"), py_dict) - py_dict["v_v_seq"] = torch.cat( - (py_dict.pop("v_init").unsqueeze(0), py_dict.pop("v_seq")) - ) - numel = py_dict["x_seq"].numel() - N = py_dict["x_seq"].shape[1] - threads = configure.cuda_threads - if py_dict["x_seq"].dtype == torch.float16: - # we will take two neurons to calculate as one neuron in cuda half2 - # pad will be implemented by the kernel.__call__ - N = math.ceil(N / 2) - numel = N * py_dict["x_seq"].shape[0] - - blocks = cuda_utils.cal_blocks(N) - - with cuda_utils.DeviceEnvironment(device): - numel = cupy.asarray(numel) - N = cupy.asarray(N) - - py_dict["numel"] = numel - py_dict["N"] = N - - return requires_grad, blocks, threads, py_dict - - @staticmethod - def ctx_save(ctx, requires_grad: bool, *args, **kwargs): - """ - :param ctx: ``ctx`` in :class:`torch.autograd.Function` - :param requires_grad: if any tensor in forward params requires grad - :type requires_grad: bool - :param args: tensors that need to be saved by ``ctx.save_for_backward`` - :param kwargs: items that need to be saved by ``ctx.xx = xx`` - - Saves ``*args, **kwargs`` in ``ctx`` by ``ctx.save_for_backward(*args)`` and ``ctx.xx = xx`` for all ``xx`` in ``kwargs.items()``. - """ - if requires_grad: - ctx.save_for_backward(*args) - for key, value in kwargs.items(): - ctx.__setattr__(key, value) - - @staticmethod - def pre_backward(ctx, grad_spike_seq: torch.Tensor, grad_v_seq: torch.Tensor): - """ - :param ctx: ``ctx`` in :class:`torch.autograd.Function` - :param grad_spike_seq: gradients of ``spike_seq`` - :type grad_spike_seq: torch.Tensor - :param grad_v_seq: gradients of ``v_seq`` - :type grad_v_seq: torch.Tensor - :return: backward_kernel, blocks, threads, py_dict - - backward_kernel: NeuronBPTTKernel - The CUDA kernel used for backward. It should be provided in ``ctx.backward_kernel`` - - blocks: int - CUDA param used in calling CUDA kernel. It should be provided in ``ctx.blocks`` - - threads: int - CUDA param used in calling CUDA kernel. It should be provided in ``ctx.threads`` - :rtype: tuple - """ - backward_kernel = ctx.backward_kernel - blocks = ctx.blocks - threads = ctx.threads - - h_seq = ctx.saved_tensors[0] - numel = ctx.numel - N = ctx.N - v_th = ctx.v_th - v_reset = ctx.v_reset - - zero_shape = list(grad_spike_seq.shape) - zero_shape[0] += 1 - zero_data = torch.zeros( - zero_shape, device=grad_spike_seq.device, dtype=grad_spike_seq.dtype - ) - grad_x_seq = zero_data[0:-1] - grad_v_init = zero_data[-1] - - py_dict = { - "numel": numel, - "N": N, - "grad_spike_seq": grad_spike_seq, - "grad_v_seq": grad_v_seq, - "h_seq": h_seq, - "grad_x_seq": grad_x_seq, - "grad_v_init": grad_v_init, - "v_th": v_th, - "v_reset": v_reset, - } - - return backward_kernel, blocks, threads, py_dict - - -class IFNodeATGF(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x_seq: torch.Tensor, - v_init: torch.Tensor, - v_th: float, - v_reset: Optional[float], - forward_kernel: IFNodeFPTTKernel, - backward_kernel: IFNodeBPTTKernel, - ): - py_dict = {"x_seq": x_seq, "v_init": v_init, "v_th": v_th, "v_reset": v_reset} - requires_grad, blocks, threads, py_dict = NeuronATGFBase.pre_forward(py_dict) - - if py_dict["v_reset"] is None: - py_dict.pop("v_reset") - - forward_kernel((blocks,), (threads,), py_dict) - - if "v_reset" not in py_dict: - py_dict["v_reset"] = None - - NeuronATGFBase.ctx_save( - ctx, - requires_grad, - py_dict["h_seq"], - blocks=blocks, - threads=threads, - numel=py_dict["numel"], - N=py_dict["N"], - v_th=py_dict["v_th"], - v_reset=py_dict["v_reset"], - backward_kernel=backward_kernel, - ) - - return py_dict["spike_seq"], py_dict["v_v_seq"][1:,] - - @staticmethod - def backward(ctx, grad_spike_seq: torch.Tensor, grad_v_seq: torch.Tensor): - - backward_kernel, blocks, threads, py_dict = NeuronATGFBase.pre_backward( - ctx, grad_spike_seq, grad_v_seq - ) - - if py_dict["v_reset"] is None: - py_dict.pop("v_reset") - - backward_kernel((blocks,), (threads,), py_dict) - - if "v_reset" not in py_dict: - py_dict["v_reset"] = None - - return py_dict["grad_x_seq"], py_dict["grad_v_init"], None, None, None, None - - -class LIFNodeFPTTKernel(NeuronFPTTKernel): - def __init__(self, decay_input: bool, hard_reset: bool, dtype: str): - super().__init__(hard_reset, dtype) - self.decay_input = decay_input - self.add_param(ctype=f"const {dtype} &", cname="decay") - - def neuronal_charge(self) -> str: - if self.hard_reset: - codes = cfunction.sub( - z=f"{self.dtype} LIFNodeFPTTKernel_temp_var", - x="v_v_seq[t]", - y="v_reset", - dtype=self.dtype, - ) - else: - codes = f"{self.dtype} LIFNodeFPTTKernel_temp_var = v_v_seq[t];" - - if self.decay_input: - codes += cfunction.sub( - z="LIFNodeFPTTKernel_temp_var", - x="x_seq[t]", - y="LIFNodeFPTTKernel_temp_var", - dtype=self.dtype, - ) - codes += cfunction.mul( - z="LIFNodeFPTTKernel_temp_var", - x="decay", - y="LIFNodeFPTTKernel_temp_var", - dtype=self.dtype, - ) - else: - codes += cfunction.mul( - z="LIFNodeFPTTKernel_temp_var", - x="decay", - y="LIFNodeFPTTKernel_temp_var", - dtype=self.dtype, - ) - codes += cfunction.sub( - z="LIFNodeFPTTKernel_temp_var", - x="x_seq[t]", - y="LIFNodeFPTTKernel_temp_var", - dtype=self.dtype, - ) - - codes += cfunction.add( - z="h_seq[t]", - x="LIFNodeFPTTKernel_temp_var", - y="v_v_seq[t]", - dtype=self.dtype, - ) - - return codes - - -class LIFNodeBPTTKernel(NeuronBPTTKernel): - def __init__( - self, - decay_input: bool, - surrogate_function: Callable, - hard_reset: bool, - detach_reset: bool, - dtype: str, - ): - super().__init__(surrogate_function, hard_reset, detach_reset, dtype) - self.decay_input = decay_input - self.add_param(ctype=f"const {dtype} &", cname="decay") - - def grad_h_next_to_v(self) -> str: - return cfunction.sub( - z=f"const {self.dtype} grad_h_next_to_v", - x=cfunction.constant(None, x=1.0, dtype=self.dtype), - y="decay", - dtype=self.dtype, - ) - - def grad_h_to_x(self) -> str: - if not self.decay_input: - return cfunction.constant( - y=f"const {self.dtype} grad_h_to_x", x=1.0, dtype=self.dtype - ) - else: - return f"const {self.dtype} grad_h_to_x = decay;" - - -class LIFNodeATGF(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x_seq: torch.Tensor, - v_init: torch.Tensor, - v_th: float, - v_reset: Optional[float], - decay: float, - forward_kernel: LIFNodeFPTTKernel, - backward_kernel: LIFNodeBPTTKernel, - ): - py_dict = { - "x_seq": x_seq, - "v_init": v_init, - "v_th": v_th, - "v_reset": v_reset, - "decay": decay, - } - requires_grad, blocks, threads, py_dict = NeuronATGFBase.pre_forward(py_dict) - - if py_dict["v_reset"] is None: - py_dict.pop("v_reset") - - forward_kernel((blocks,), (threads,), py_dict) - - if "v_reset" not in py_dict: - py_dict["v_reset"] = None - - NeuronATGFBase.ctx_save( - ctx, - requires_grad, - py_dict["h_seq"], - blocks=blocks, - threads=threads, - numel=py_dict["numel"], - N=py_dict["N"], - v_th=py_dict["v_th"], - v_reset=py_dict["v_reset"], - backward_kernel=backward_kernel, - decay=py_dict["decay"], - ) - - return py_dict["spike_seq"], py_dict["v_v_seq"][1:,] - - @staticmethod - def backward(ctx, grad_spike_seq: torch.Tensor, grad_v_seq: torch.Tensor): - - backward_kernel, blocks, threads, py_dict = NeuronATGFBase.pre_backward( - ctx, grad_spike_seq, grad_v_seq - ) - py_dict["decay"] = ctx.decay - - if py_dict["v_reset"] is None: - py_dict.pop("v_reset") - - backward_kernel((blocks,), (threads,), py_dict) - - if "v_reset" not in py_dict: - py_dict["v_reset"] = None - - return ( - py_dict["grad_x_seq"], - py_dict["grad_v_init"], - None, - None, - None, - None, - None, - ) - - -class ParametricLIFNodeFPTTKernel(NeuronFPTTKernel): - def __init__(self, decay_input: bool, hard_reset: bool, dtype: str): - super().__init__(hard_reset, dtype) - self.decay_input = decay_input - self.add_param(ctype=f"const {dtype} *", cname="decay") - - def neuronal_charge(self) -> str: - if self.hard_reset: - codes = cfunction.sub( - z=f"{self.dtype} LIFNodeFPTTKernel_temp_var", - x="v_v_seq[t]", - y="v_reset", - dtype=self.dtype, - ) - else: - codes = f"{self.dtype} LIFNodeFPTTKernel_temp_var = v_v_seq[t];" - if self.decay_input: - codes += cfunction.sub( - z="LIFNodeFPTTKernel_temp_var", - x="x_seq[t]", - y="LIFNodeFPTTKernel_temp_var", - dtype=self.dtype, - ) - codes += cfunction.mul( - z="LIFNodeFPTTKernel_temp_var", - x="decay[0]", - y="LIFNodeFPTTKernel_temp_var", - dtype=self.dtype, - ) - else: - codes += cfunction.mul( - z="LIFNodeFPTTKernel_temp_var", - x="decay[0]", - y="LIFNodeFPTTKernel_temp_var", - dtype=self.dtype, - ) - codes += cfunction.sub( - z="LIFNodeFPTTKernel_temp_var", - x="x_seq[t]", - y="LIFNodeFPTTKernel_temp_var", - dtype=self.dtype, - ) - - codes += cfunction.add( - z="h_seq[t]", - x="LIFNodeFPTTKernel_temp_var", - y="v_v_seq[t]", - dtype=self.dtype, - ) - - return codes - - -class ParametricLIFNodeBPTTKernel(NeuronBPTTKernel): - def __init__( - self, - decay_input: bool, - surrogate_function: Callable, - hard_reset: bool, - detach_reset: bool, - dtype: str, - ): - super().__init__(surrogate_function, hard_reset, detach_reset, dtype) - self.decay_input = decay_input - self.add_param(ctype=f"const {dtype} *", cname="decay") - self.add_param(ctype=f"float *", cname="grad_decay") - # float to avoid overflow - self.add_param(ctype=f"const {dtype} *", cname="v_v_seq") - - def grad_h_next_to_v(self) -> str: - return cfunction.sub( - z=f"const {self.dtype} grad_h_next_to_v", - x=cfunction.constant(None, x=1.0, dtype=self.dtype), - y="decay[0]", - dtype=self.dtype, - ) - - def grad_h_to_x(self) -> str: - if not self.decay_input: - return cfunction.constant( - y=f"const {self.dtype} grad_h_to_x", x=1.0, dtype=self.dtype - ) - else: - return f"const {self.dtype} grad_h_to_x = decay[0];" - - @property - def head(self): - # override - codes = """ - { - const int index = blockIdx.x * blockDim.x + threadIdx.x; - """ - codes += rf""" - __shared__ float sdata[{configure.cuda_threads}]; - """ - codes += """ - if (index < N) - { - const int dt = N; - """ - - codes += self.pre_core - - if self.reverse: - codes += """ - for(int t = numel - N + index; t >= 0; t -= dt) - { - """ - else: - codes += """ - for(int t = index; t < numel; t += dt) - { - """ - return codes - - @property - def pre_core(self): - codes = base.CodeTyper(16) - # use float to avoid overflow - codes.append("sdata[threadIdx.x] = 0.0f;") - return super().pre_core + "\n" + codes.codes - - @property - def core(self): - core_codes = base.CodeTyper(18) - with base.CodeBlock(core_codes): - if self.decay_input: - - core_codes.append( - cfunction.sub( - z=f"{self.dtype} temp_var", - x="h_seq[t]", - y="v_v_seq[t]", - dtype=self.dtype, - ) - ) - core_codes.append( - cfunction.mul( - z="temp_var", x="temp_var", y="grad_h", dtype=self.dtype - ) - ) - core_codes.append( - cfunction.div( - z="temp_var", x="temp_var", y="decay[0]", dtype=self.dtype - ) - ) - - else: - if self.hard_reset: - core_codes.append( - cfunction.sub( - z=f"{self.dtype} temp_var", - x="v_reset", - y="v_v_seq[t]", - dtype=self.dtype, - ) - ) - core_codes.append( - cfunction.mul( - z="temp_var", x="temp_var", y="grad_h", dtype=self.dtype - ) - ) - else: - core_codes.append( - cfunction.mul( - z=f"{self.dtype} temp_var", - x="grad_h", - y="v_v_seq[t]", - dtype=self.dtype, - ) - ) - core_codes.append( - cfunction.neg(y="temp_var", x="temp_var", dtype=self.dtype) - ) - - if self.dtype == "float": - core_codes.append("sdata[threadIdx.x] += temp_var;") - elif self.dtype == "half2": - core_codes.append( - "sdata[threadIdx.x] += __half2float(__hadd(__low2half(temp_var), __high2half(temp_var)));" - ) - else: - raise NotImplementedError(self.dtype) - - return super().core + "\n" + core_codes.codes - - @property - def tail(self): - codes = """ - } - """ - - codes += self.post_core - - codes += """ - } - else - { - sdata[threadIdx.x] = 0.0f; - } - int threadx = blockDim.x; - #pragma unroll - for (int stride = threadx >> 1; stride > 0; stride = stride >> 1) - { - // Synchronize all thread before next loop - __syncthreads(); - if (threadIdx.x < stride) - { - sdata[threadIdx.x] += sdata[threadIdx.x + stride]; - } - } - __syncthreads(); - if (threadIdx.x == 0) - { - atomicAdd(grad_decay, sdata[0]); - } - } - """ - return codes - - -class ParametricLIFNodeATGF(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x_seq: torch.Tensor, - v_init: torch.Tensor, - v_th: float, - v_reset: Optional[float], - decay: torch.Tensor, - forward_kernel: ParametricLIFNodeFPTTKernel, - backward_kernel: ParametricLIFNodeBPTTKernel, - ): - if x_seq.dtype == torch.float16 and v_init.numel() % 2 != 0: - raise ValueError( - "When using the the PLIF neuron with half2 cupy backend, the numer of neurons should be even to avoid the wrong gradient of tau caused by padding!" - ) - py_dict = { - "x_seq": x_seq, - "v_init": v_init, - "v_th": v_th, - "v_reset": v_reset, - "decay": decay, - } - requires_grad, blocks, threads, py_dict = NeuronATGFBase.pre_forward(py_dict) - - if py_dict["v_reset"] is None: - py_dict.pop("v_reset") - - forward_kernel((blocks,), (threads,), py_dict) - - if "v_reset" not in py_dict: - py_dict["v_reset"] = None - - NeuronATGFBase.ctx_save( - ctx, - requires_grad, - py_dict["h_seq"], - py_dict["v_v_seq"], - blocks=blocks, - threads=threads, - numel=py_dict["numel"], - N=py_dict["N"], - v_th=py_dict["v_th"], - v_reset=py_dict["v_reset"], - backward_kernel=backward_kernel, - decay=py_dict["decay"], - ) - - return py_dict["spike_seq"], py_dict["v_v_seq"][1:,] - - @staticmethod - def backward(ctx, grad_spike_seq: torch.Tensor, grad_v_seq: torch.Tensor): - - backward_kernel, blocks, threads, py_dict = NeuronATGFBase.pre_backward( - ctx, grad_spike_seq, grad_v_seq - ) - py_dict["decay"] = ctx.decay - py_dict["grad_decay"] = torch.zeros_like(ctx.decay, dtype=torch.float) - py_dict["v_v_seq"] = ctx.saved_tensors[1] - - if py_dict["v_reset"] is None: - py_dict.pop("v_reset") - - backward_kernel((blocks,), (threads,), py_dict) - - if "v_reset" not in py_dict: - py_dict["v_reset"] = None - - return ( - py_dict["grad_x_seq"], - py_dict["grad_v_init"], - None, - None, - py_dict["grad_decay"], - None, - None, - ) diff --git a/src/chop/nn/snn/auto_cuda/readme.md b/src/chop/nn/snn/auto_cuda/readme.md deleted file mode 100644 index c66ec01fe..000000000 --- a/src/chop/nn/snn/auto_cuda/readme.md +++ /dev/null @@ -1,152 +0,0 @@ -> *************************************************************************************** -> * Title: auto_cuda -> * Reference: These directory is directly sourced from spikingJelly -> * Availability: https://github.com/fangwei123456/spikingjelly/tree/master/spikingjelly/activation_based/auto_cuda -> * Date: 07/11/2024 -> * Code version: 0.0.0.014 -> *************************************************************************************** - -# Description - -`auto_cuda` is an experimental package for creating CUDA codes from the python function automatically. - -The final goal is after the user defines a new kind of spiking neuron by python, then the neuron can use the `cupy` backend, whose codes are generated by `auto_cuda`. - -For the moment, we have implemented creating CUDA codes from the python `neuronal_charge` function. - -# Example - -Run the following python codes: - -```python -from spikingjelly.activation_based.auto_cuda.generator import analyse_graph, gen_forward_codes, gen_backward_codes -from spikingjelly.activation_based import surrogate - -import torch -if __name__ == '__main__': - - def lif_charge(x: torch.Tensor, v_last: torch.Tensor, tau: float, v_reset: float): - h = v_last + (x - (v_last - v_reset)) / tau - return h - - - input_nodes, inter_nodes, output_nodes, cmds = analyse_graph(lif_charge, requires_grad=(True, True, False, False)) - - forward_codes, forward_kernel_name, cuda_cmds = gen_forward_codes(input_nodes, inter_nodes, output_nodes, cmds, hard_reset=True) - - backward_codes, backward_kernel_name, input_bp_vars = gen_backward_codes(cuda_cmds, input_nodes, output_nodes, cmds, hard_reset=True, detach_reset=True, surrogate_fuction=surrogate.ATan()) - - print(f'forward_codes = \n{forward_codes}') - print(f'backward_codes = \n{backward_codes}') -``` - -Then we will get the output CUDA codes: - -```c++ -forward_codes = - - extern "C" __global__ - void forward_kernel_697806161140619033 - (const float *x_seq, float *v_v_seq, float *h_seq, float *spike_seq, const float &v_threshold, const float &v_reset, const int &neuron_num, const int &numel, const float &input_tau, const float &input_v_reset) - { - const int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < neuron_num) - { - const int dt = neuron_num; - for(int mem_offset = 0; mem_offset < numel; mem_offset += neuron_num) - { - const int t = index + mem_offset; - - { - const float input_x = x_seq[t]; - const float input_v_last = v_v_seq[t]; - - float inter_9 = input_v_last - input_v_reset; - float inter_11 = input_x - inter_9; - float inter_13 = inter_11 / input_tau; - float output_h = input_v_last + inter_13; - h_seq[t] = output_h; - } - - if (h_seq[t] >= v_threshold) - - { - spike_seq[t] = 1.0f; - v_v_seq[t + dt] = v_reset; - } - - else - { - spike_seq[t] = 0.0f; - v_v_seq[t + dt] = h_seq[t]; - } - } - } - } - -backward_codes = - - extern "C" __global__ - void backward_kernel__3595517059288953692 - (const float *grad_spike_seq, const float *grad_v_seq, const float *h_seq, const float *spike_seq, float *grad_x_seq, float *grad_v_init, const float &v_threshold, const float &v_reset, const int &neuron_num, const int &numel, const float &input_tau) - { - const int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < neuron_num) - { - float grad_output_h = 0.0f; // grad_output_h will be used recursively - for(int mem_offset = numel - neuron_num; mem_offset >= 0; mem_offset -= neuron_num) - { - const int t = index + mem_offset; - const float over_th = h_seq[t] - v_threshold; - - // start: spikingjelly.activation_based.surrogate.ATan.cuda_code - - const float sg_ATan_M_PI_2__alpha__x = ((float) 1.57079632679489661923) * 2.0f * over_th; - const float grad_s_to_h = 2.0f / 2.0f / (1.0f + sg_ATan_M_PI_2__alpha__x * sg_ATan_M_PI_2__alpha__x); - - // end: spikingjelly.activation_based.surrogate.ATan.cuda_code - const float grad_v_to_h = 1.0f - spike_seq[t]; - - - // output_h = input_v_last + inter_13; - float grad_input_v_last = grad_output_h; - float grad_inter_13 = grad_output_h; - - // inter_13 = inter_11 / input_tau; - float grad_inter_11 = grad_inter_13 / input_tau; - - // inter_11 = input_x - inter_9; - float grad_input_x = grad_inter_11; - float grad_inter_9 = - grad_inter_11; - - // inter_9 = input_v_last - input_v_reset; - grad_input_v_last += grad_inter_9; - - // - - grad_output_h = grad_spike_seq[t] * grad_s_to_h + (grad_v_seq[t] + grad_input_v_last) * grad_v_to_h; - - } - - - // output_h = input_v_last + inter_13; - float grad_input_v_last = grad_output_h; - float grad_inter_13 = grad_output_h; - - // inter_13 = inter_11 / input_tau; - float grad_inter_11 = grad_inter_13 / input_tau; - - // inter_11 = input_x - inter_9; - float grad_input_x = grad_inter_11; - float grad_inter_9 = - grad_inter_11; - - // inter_9 = input_v_last - input_v_reset; - grad_input_v_last += grad_inter_9; - - // - - - grad_v_init[index] = grad_input_v_last; - } - } -``` \ No newline at end of file diff --git a/src/chop/nn/snn/auto_cuda/ss_neuron_kernel.py b/src/chop/nn/snn/auto_cuda/ss_neuron_kernel.py deleted file mode 100644 index 8ceb6abac..000000000 --- a/src/chop/nn/snn/auto_cuda/ss_neuron_kernel.py +++ /dev/null @@ -1,708 +0,0 @@ -# *************************************************************************************** -# * Title: auto_cuda -# * Reference: These directory is directly sourced from spikingJelly -# * Availability: https://github.com/fangwei123456/spikingjelly/tree/master/spikingjelly/activation_based/auto_cuda -# * Date: 07/11/2024 -# * Code version: 0.0.0.014 -# *************************************************************************************** - -from typing import Optional -import torch -import torch.nn.functional as F -import numpy as np -import logging - -try: - import cupy -except BaseException as e: - logging.info(f"spikingjelly.activation_based.auto_cuda.ss_neuronal_kernel: {e}") - cupy = None - - -from .. import cuda_utils -from .. import configure -from typing import Callable, Iterable -from . import base, cfunction -import math - - -def if_requires_grad(items: Iterable): - requires_grad = False - for item in items: - if isinstance(item, torch.Tensor): - if item.requires_grad: - requires_grad = True - break - - return requires_grad - - -def scalar_to_cupy(py_dict: dict, ref: str = "x"): - device = py_dict[ref].get_device() - dtype = py_dict[ref].dtype - - with cuda_utils.DeviceEnvironment(device): - for key, value in py_dict.items(): - if isinstance(value, float): - if dtype == torch.float32: - value = cupy.asarray(value, dtype=np.float32) - elif dtype == torch.float16: - value = cupy.asarray([value, value], dtype=np.float16) - else: - raise NotImplementedError(dtype) - py_dict[key] = value - - elif isinstance(value, int): - py_dict[key] = cupy.asarray(value) - - -def new_tensors(news: tuple, py_dict: dict, ref: str = "x"): - ref = py_dict[ref] - zero_shape = list(ref.shape) - zero_shape[0] *= news.__len__() - for i, item in enumerate( - torch.split( - torch.zeros(zero_shape, device=ref.device, dtype=ref.dtype), ref.shape[0] - ) - ): - py_dict[news[i]] = item - - -def neuronal_hard_reset( - v_next: str, h: str, spike: str, v_reset: str, dtype: str = "float" -): - if dtype == "float": - return f"{v_next} = {h} * (1.0f - {spike}) + {v_reset} * {spike};" - elif dtype == "half2": - return f"{v_next} = __hfma2({h}, __hsub2(__float2half2_rn(1.0f), {spike}), __hmul2(v_reset, {spike}));" - else: - raise NotImplementedError(dtype) - - -def neuronal_soft_reset( - v_next: str, h: str, spike: str, v_th: str, dtype: str = "float" -): - if dtype == "float": - return f"{v_next} = {h} - {v_th} * {spike};" - elif dtype == "half2": - return f"{v_next} = __hsub2({h}, __hmul2({v_th}, {spike}));" - else: - raise NotImplementedError(dtype) - - -def neuronal_fire(spike: str, v: str, v_th: str, dtype: str = "float"): - if dtype == "float": - return cfunction.heaviside(y=spike, x=f"({v} - {v_th})", dtype=dtype) - elif dtype == "half2": - return cfunction.heaviside(y=spike, x=f"__hsub2({v}, {v_th})", dtype=dtype) - else: - raise NotImplementedError(dtype) - - -class NeuronFPKernel(base.CKernel1D): - def __init__(self, hard_reset: bool, dtype: str): - super().__init__( - kernel_name=f'{self.__class__.__name__}_{dtype}_{"hard_reset" if hard_reset else "soft_reset"}' - ) - self.hard_reset = hard_reset - self.dtype = dtype - self.hard_reset = hard_reset - self.dtype = dtype - self.add_param(ctype=f"const {dtype} *", cname="x") - self.add_param(ctype=f"const {dtype} *", cname="v") - self.add_param(ctype=f"{dtype} *", cname="h") - self.add_param(ctype=f"{dtype} *", cname="v_next") - self.add_param(ctype=f"{dtype} *", cname="spike") - self.add_param(ctype=f"{dtype} &", cname="v_th") - if hard_reset: - self.add_param(ctype=f"{dtype} &", cname="v_reset") - - def neuronal_charge(self) -> str: - """ - :return: CUDA code - :rtype: str - - Returns CUDA code for calculating :math:`H = f(X, V, ...)`. - - This function should define how ``h`` is calculated by ``x[index], v[index]`` and other params if - the neuron needs. - - For example, the IF neuron define this function as: - - .. code-block:: python - - def neuronal_charge(self) -> str: - return cfunction.add(z='h[index]', x='x[index]', y='v[index]', dtype=self.dtype) - """ - return "// neuronal_charge should be defined here!" - - @property - def core(self): - core_codes = base.CodeTyper(18) - - core_codes.append(self.neuronal_charge()) - - core_codes.append( - neuronal_fire( - spike="spike[index]", v="h[index]", v_th="v_th", dtype=self.dtype - ) - ) - - if self.hard_reset: - core_codes.append( - neuronal_hard_reset( - v_next="v_next[index]", - h="h[index]", - spike="spike[index]", - v_reset="v_reset", - dtype=self.dtype, - ) - ) - else: - core_codes.append( - neuronal_soft_reset( - v_next="v_next[index]", - h="h[index]", - spike="spike[index]", - v_th="v_th", - dtype=self.dtype, - ) - ) - - self._core = core_codes.codes - return self._core - - -class NeuronBPKernel(base.CKernel1D): - def __init__( - self, - surrogate_function: Callable, - hard_reset: bool, - detach_reset: bool, - dtype: str, - ): - super().__init__( - kernel_name=f'{self.__class__.__name__}_{dtype}_{"hard_reset" if hard_reset else "soft_reset"}_{"detach_reset" if detach_reset else "nodetach_reset"}' - ) - self.surrogate_function = surrogate_function - self.hard_reset = hard_reset - self.detach_reset = detach_reset - self.dtype = dtype - self.add_param(ctype=f"const {dtype} *", cname="grad_spike") - self.add_param(ctype=f"const {dtype} *", cname="grad_v_next") - self.add_param(ctype=f"const {dtype} *", cname="h") - self.add_param(ctype=f"{dtype} *", cname="grad_x") - self.add_param(ctype=f"{dtype} *", cname="grad_v") - self.add_param(ctype=f"{dtype} &", cname="v_th") - if hard_reset: - self.add_param(ctype=f"{dtype} &", cname="v_reset") - - @property - def post_core(self): - - codes = base.CodeTyper(16) - codes.append(self.grad_h_next_to_v()) - codes.append( - cfunction.mul( - z="grad_v[index]", x="grad_h", y="grad_h_next_to_v", dtype=self.dtype - ) - ) - self._post_core = codes.codes - return self._post_core - - def grad_h_to_v(self) -> str: - """ - :return: CUDA code - :rtype: str - - Returns CUDA code for calculating :math:`\\frac{\\mathrm{d} H}{\\mathrm{d} V}`. - - This function should define how ``grad_h_to_v`` is calculated. Note that ``grad_h_to_v`` has not been - declared. Thus, this function should also declare ``grad_h_to_v``. - - For example, the IF neuron define this function as: - - .. code-block:: python - - def grad_h_to_v(self) -> str: - return cfunction.constant(y=f'const {self.dtype} grad_h_to_v', x=1., dtype=self.dtype) - """ - return "// grad_h_to_v should be defined here!" - - def grad_h_to_x(self) -> str: - """ - :return: CUDA code - :rtype: str - - Returns CUDA code for calculating :math:`\\frac{\\mathrm{d} H[t]}{\\mathrm{d} X[t]}`. - - This function should define how ``grad_h_to_x`` is calculated. Note that ``grad_h_to_x`` has not been - declared. Thus, this function should also declare ``grad_h_to_x``. - - For example, the IF neuron define this function as: - - .. code-block:: python - - def grad_h_to_x(self) -> str: - return cfunction.constant(y=f'const {self.dtype} grad_h_to_x', x=1., dtype=self.dtype) - """ - return "// grad_h_to_x should be defined here!" - - @property - def core(self): - core_codes = base.CodeTyper(18) - - core_codes.append( - cfunction.sub( - z=f"const {self.dtype} over_th", - x="h[index]", - y="v_th", - dtype=self.dtype, - ) - ) - core_codes.append( - cfunction.heaviside( - y=f"const {self.dtype} spike", x="over_th", dtype=self.dtype - ) - ) - core_codes.append( - self.surrogate_function( - y=f"const {self.dtype} grad_s_to_h", x="over_th", dtype=self.dtype - ) - ) - - if self.hard_reset: - core_codes.append( - cfunction.sub( - z=f"{self.dtype} grad_v_next_to_h", - x=cfunction.constant(y=None, x=1.0, dtype=self.dtype), - y="spike", - dtype=self.dtype, - ) - ) - - if not self.detach_reset: - with base.CodeBlock(core_codes): - core_codes.append( - cfunction.sub( - z=f"{self.dtype} temp_var", - x="v_reset", - y="h[index]", - dtype=self.dtype, - ) - ) - core_codes.append( - cfunction.mul( - z=f"temp_var", - x="temp_var", - y="grad_s_to_h", - dtype=self.dtype, - ) - ) - core_codes.append( - cfunction.add( - z=f"grad_v_next_to_h", - x="temp_var", - y="grad_v_next_to_h", - dtype=self.dtype, - ) - ) - - else: - core_codes.append( - f"{self.dtype} grad_v_next_to_h = {cfunction.constant(None, 1., dtype=self.dtype)}" - ) - - if not self.detach_reset: - with base.CodeBlock(core_codes): - core_codes.append( - cfunction.mul( - z=f"{self.dtype} temp_var", - x="v_th", - y="grad_s_to_h", - dtype=self.dtype, - ) - ) - core_codes.append( - cfunction.sub( - z=f"grad_v_next_to_h", - x="grad_v_next_to_h", - y="temp_var", - dtype=self.dtype, - ) - ) - - core_codes.append( - cfunction.mul( - z=f"{self.dtype} grad_h", - x="grad_s_to_h", - y="grad_spike[index]", - dtype=self.dtype, - ) - ) - core_codes.append( - cfunction.add( - z="grad_h", - x=cfunction.mul( - z=None, - x="grad_v_next[index]", - y="grad_v_next_to_h", - dtype=self.dtype, - ), - y="grad_h", - dtype=self.dtype, - ) - ) - - core_codes.append(self.grad_h_to_v()) - core_codes.append( - cfunction.mul( - z="grad_v[index]", x="grad_h", y="grad_h_to_v", dtype=self.dtype - ) - ) - - core_codes.append(self.grad_h_to_x()) - core_codes.append( - cfunction.mul( - z="grad_x[index]", x="grad_h", y="grad_h_to_x", dtype=self.dtype - ) - ) - - self._core = core_codes.codes - return self._core - - -class NeuronATGFBase: - @staticmethod - def pre_forward(py_dict: dict): - """ - :param py_dict: a dict built from the neuron's forward autograd function. It should at least contain ``x, v, v_reset`` - :type py_dict: dict - :return: requires_grad, blocks, threads, py_dict - - requires_grad: bool - if any tensor in ``py_dict`` requires grad, then ``requires_grad = True``;else ``requires_grad = False`` - - blocks: int - CUDA param used in calling CUDA kernel - - threads: int - CUDA param used in calling CUDA kernel. The default value is ``spikingjelly.configure.cuda_threads`` - - py_dict: dict - Compared with the input ``py_dict``, the returned ``py_dict`` will: - - * convert all ``float/int`` scalars in ``py_dict`` to ``cupy.ndarray`` - - * add ``h, spike, v_next`` to ``py_dict``. They are zero tensors - with the same shape with ``x`` or ``v``. - - * add ``numel`` to ``py_dict``. Note that ``x.shape = [numel]``. - A specific case is that ``x.dtype == torch.half``, then ``numel = math.ceil(numel / 2)``. - Note that ``numel`` in the returned ``py_dict`` is ``cupy.ndarray`` - - - :rtype: tuple - """ - device = py_dict["x"].get_device() - requires_grad = if_requires_grad(py_dict.values()) - scalar_to_cupy(py_dict) - - new_tensors(("h", "spike", "v_next"), py_dict) - numel = py_dict["x"].numel() - threads = configure.cuda_threads - if py_dict["x"].dtype == torch.float16: - # we will take two neurons to calculate as one neuron in cuda half2 - # pad will be implemented by the kernel.__call__ - numel = math.ceil(numel / 2) - - blocks = cuda_utils.cal_blocks(numel) - - with cuda_utils.DeviceEnvironment(device): - numel = cupy.asarray(numel) - - py_dict["numel"] = numel - - return requires_grad, blocks, threads, py_dict - - @staticmethod - def ctx_save(ctx, requires_grad: bool, *args, **kwargs): - """ - :param ctx: ``ctx`` in :class:`torch.autograd.Function` - :param requires_grad: if any tensor in forward params requires grad - :type requires_grad: bool - :param args: tensors that need to be saved by ``ctx.save_for_backward`` - :param kwargs: items that need to be saved by ``ctx.xx = xx`` - - Saves ``*args, **kwargs`` in ``ctx`` by ``ctx.save_for_backward(*args)`` and ``ctx.xx = xx`` for all ``xx`` in ``kwargs.items()``. - """ - if requires_grad: - ctx.save_for_backward(*args) - for key, value in kwargs.items(): - ctx.__setattr__(key, value) - - @staticmethod - def pre_backward(ctx, grad_spike: torch.Tensor, grad_v_next: torch.Tensor): - """ - :param ctx: ``ctx`` in :class:`torch.autograd.Function` - :param grad_spike: gradients of ``spike`` - :type grad_spike: torch.Tensor - :param grad_v_next: gradients of ``v_next`` - :type grad_v_next: torch.Tensor - :return: backward_kernel, blocks, threads, py_dict - - backward_kernel: NeuronBPTTKernel - The CUDA kernel used for backward. It should be provided in ``ctx.backward_kernel`` - - blocks: int - CUDA param used in calling CUDA kernel. It should be provided in ``ctx.blocks`` - - threads: int - CUDA param used in calling CUDA kernel. It should be provided in ``ctx.threads`` - :rtype: tuple - """ - backward_kernel = ctx.backward_kernel - blocks = ctx.blocks - threads = ctx.threads - - h = ctx.saved_tensors[0] - numel = ctx.numel - v_th = ctx.v_th - v_reset = ctx.v_reset - - zero_shape = list(grad_spike.shape) - zero_shape[0] *= 2 - zero_data = torch.zeros( - zero_shape, device=grad_spike.device, dtype=grad_spike.dtype - ) - - # For fp16, ctx.numel will be divided by 2 later. Here is a reliable way to divide tensor equally - real_numel = grad_spike.size(0) - grad_x = zero_data[:real_numel] - grad_v = zero_data[real_numel:] - - py_dict = { - "numel": numel, - "grad_spike": grad_spike, - "grad_v_next": grad_v_next, - "h": h, - "grad_x": grad_x, - "grad_v": grad_v, - "v_th": v_th, - "v_reset": v_reset, - } - - return backward_kernel, blocks, threads, py_dict - - -class IFNodeFPKernel(NeuronFPKernel): - def neuronal_charge(self) -> str: - return cfunction.add(z="h[index]", x="x[index]", y="v[index]", dtype=self.dtype) - - -class IFNodeBPKernel(NeuronBPKernel): - def grad_h_to_v(self) -> str: - return cfunction.constant( - y=f"const {self.dtype} grad_h_to_v", x=1.0, dtype=self.dtype - ) - - def grad_h_to_x(self) -> str: - return cfunction.constant( - y=f"const {self.dtype} grad_h_to_x", x=1.0, dtype=self.dtype - ) - - -class IFNodeATGF(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: torch.Tensor, - v: torch.Tensor, - v_th: float, - v_reset: Optional[float], - forward_kernel: IFNodeFPKernel, - backward_kernel: IFNodeBPKernel, - ): - py_dict = {"x": x, "v": v, "v_th": v_th, "v_reset": v_reset} - requires_grad, blocks, threads, py_dict = NeuronATGFBase.pre_forward(py_dict) - - if py_dict["v_reset"] is None: - py_dict.pop("v_reset") - - forward_kernel((blocks,), (threads,), py_dict) - - if "v_reset" not in py_dict: - py_dict["v_reset"] = None - - NeuronATGFBase.ctx_save( - ctx, - requires_grad, - py_dict["h"], - blocks=blocks, - threads=threads, - numel=py_dict["numel"], - v_th=py_dict["v_th"], - v_reset=py_dict["v_reset"], - backward_kernel=backward_kernel, - ) - - return py_dict["spike"], py_dict["v_next"] - - @staticmethod - def backward(ctx, grad_spike: torch.Tensor, grad_v_next: torch.Tensor): - backward_kernel, blocks, threads, py_dict = NeuronATGFBase.pre_backward( - ctx, grad_spike, grad_v_next - ) - if py_dict["v_reset"] is None: - py_dict.pop("v_reset") - - backward_kernel((blocks,), (threads,), py_dict) - - if "v_reset" not in py_dict: - py_dict["v_reset"] = None - - return py_dict["grad_x"], py_dict["grad_v"], None, None, None, None - - -class LIFNodeFPKernel(NeuronFPKernel): - def __init__(self, decay_input: bool, hard_reset: bool, dtype: str): - super().__init__(hard_reset, dtype) - self.decay_input = decay_input - self.add_param(ctype=f"const {dtype} &", cname="decay") - - def neuronal_charge(self) -> str: - if self.hard_reset: - codes = cfunction.sub( - z=f"{self.dtype} LIFNodeFPKernel_temp_var", - x="v[index]", - y="v_reset", - dtype=self.dtype, - ) - else: - codes = f"{self.dtype} LIFNodeFPKernel_temp_var = v[index];" - - if self.decay_input: - codes += cfunction.sub( - z="LIFNodeFPKernel_temp_var", - x="x[index]", - y="LIFNodeFPKernel_temp_var", - dtype=self.dtype, - ) - codes += cfunction.mul( - z="LIFNodeFPKernel_temp_var", - x="decay", - y="LIFNodeFPKernel_temp_var", - dtype=self.dtype, - ) - else: - codes += cfunction.mul( - z="LIFNodeFPKernel_temp_var", - x="decay", - y="LIFNodeFPKernel_temp_var", - dtype=self.dtype, - ) - codes += cfunction.sub( - z="LIFNodeFPKernel_temp_var", - x="x[index]", - y="LIFNodeFPKernel_temp_var", - dtype=self.dtype, - ) - - codes += cfunction.add( - z="h[index]", x="LIFNodeFPKernel_temp_var", y="v[index]", dtype=self.dtype - ) - - return codes - - -class LIFNodeBPKernel(NeuronBPKernel): - def __init__( - self, - decay_input: bool, - surrogate_function: Callable, - hard_reset: bool, - detach_reset: bool, - dtype: str, - ): - super().__init__(surrogate_function, hard_reset, detach_reset, dtype) - self.decay_input = decay_input - self.add_param(ctype=f"const {dtype} &", cname="decay") - - def grad_h_to_v(self) -> str: - return cfunction.sub( - z=f"const {self.dtype} grad_h_to_v", - x=cfunction.constant(None, x=1.0, dtype=self.dtype), - y="decay", - dtype=self.dtype, - ) - - def grad_h_to_x(self) -> str: - if not self.decay_input: - return cfunction.constant( - y=f"const {self.dtype} grad_h_to_x", x=1.0, dtype=self.dtype - ) - else: - return f"const {self.dtype} grad_h_to_x = decay;" - - -class LIFNodeATGF(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: torch.Tensor, - v: torch.Tensor, - v_th: float, - v_reset: Optional[float], - decay: float, - forward_kernel: LIFNodeFPKernel, - backward_kernel: LIFNodeBPKernel, - ): - py_dict = { - "x": x, - "v": v, - "v_th": v_th, - "v_reset": v_reset, - "decay": decay, - } - requires_grad, blocks, threads, py_dict = NeuronATGFBase.pre_forward(py_dict) - - if py_dict["v_reset"] is None: - py_dict.pop("v_reset") - - forward_kernel((blocks,), (threads,), py_dict) - - if "v_reset" not in py_dict: - py_dict["v_reset"] = None - - NeuronATGFBase.ctx_save( - ctx, - requires_grad, - py_dict["h"], - blocks=blocks, - threads=threads, - numel=py_dict["numel"], - v_th=py_dict["v_th"], - v_reset=py_dict["v_reset"], - backward_kernel=backward_kernel, - decay=py_dict["decay"], - ) - - return py_dict["spike"], py_dict["v_next"] - - @staticmethod - def backward(ctx, grad_spike: torch.Tensor, grad_v_next: torch.Tensor): - - backward_kernel, blocks, threads, py_dict = NeuronATGFBase.pre_backward( - ctx, grad_spike, grad_v_next - ) - py_dict["decay"] = ctx.decay - - if py_dict["v_reset"] is None: - py_dict.pop("v_reset") - - backward_kernel((blocks,), (threads,), py_dict) - - if "v_reset" not in py_dict: - py_dict["v_reset"] = None - - return py_dict["grad_x"], py_dict["grad_v"], None, None, None, None, None diff --git a/src/chop/nn/snn/base.py b/src/chop/nn/snn/base.py deleted file mode 100644 index a6b892a14..000000000 --- a/src/chop/nn/snn/base.py +++ /dev/null @@ -1,329 +0,0 @@ -# *************************************************************************************** -# * Title: base.py -# * Reference: This file is directly sourced from spikingJelly -# * Availability: https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/activation_based/base.py -# * Date: 07/11/2024 -# * Code version: 0.0.0.014 -# *************************************************************************************** - -import torch -import torch.nn as nn -import copy -import logging -from abc import abstractmethod - -try: - import cupy -except BaseException as e: - logging.info(f"spikingjelly.activation_based.base: {e}") - cupy = None - -try: - import lava.lib.dl.slayer as slayer -except BaseException as e: - slayer = None - - -def check_backend_library(backend: str): - """ - * :ref:`API in English ` - .. _check_backend_library-en: - - :param backend: ``'torch'``, ``'cupy'`` or ``'lava'`` - :type backend: str - - Check whether the python lib for backend is installed. If not, this function will raise an error. - """ - if backend == "torch": - return - elif backend == "cupy": - if cupy is None: - raise ImportError( - 'CuPy is not installed! You can install it from "https://github.com/cupy/cupy".' - ) - elif backend == "lava": - if slayer is None: - raise ImportError( - "Lava-DL is not installed! You can install it from " - '"https://github.com/lava-nc/lava-dl". ' - ) - else: - pass - - -class StepModule: - def supported_step_mode(self): - """ - * :ref:`API in English ` - .. _StepModule.supported_step_mode-en: - - :return: a tuple that contains the supported backends - :rtype: tuple[str] - - """ - return ("s", "m") - - @property - def step_mode(self): - """ - * :ref:`API in English ` - .. _StepModule.step_mode-en: - - :return: the current step mode of this module - :rtype: str - """ - return self._step_mode - - @step_mode.setter - def step_mode(self, value: str): - """ - * :ref:`API in English ` - .. _StepModule.step_mode-setter-en: - - :param value: the step mode - :type value: str - - Set the step mode of this module to be ``value`` - - """ - if value not in self.supported_step_mode(): - raise ValueError( - f'step_mode can only be {self.supported_step_mode()}, but got "{value}"!' - ) - self._step_mode = value - - -class SingleModule(StepModule): - """ - * :ref:`API in English ` - .. _SingleModule-en: - - The module that only supports for single-step (``step_mode == 's'``) - """ - - def supported_step_mode(self): - return ("s",) - - -class MultiStepModule(StepModule): - """ - * :ref:`API in English ` - .. _MultiStepModule-en: - - The module that only supports for multi-step (``step_mode == 'm'``) - """ - - def supported_step_mode(self): - return ("m",) - - -class MemoryModule(nn.Module, StepModule): - def __init__(self): - """ - * :ref:`API in English ` - .. _MemoryModule.__init__-en: - - ``MemoryModule`` is the base class of all stateful modules in SpikingJelly. - - """ - super().__init__() - self._memories = {} - self._memories_rv = {} - self._backend = "torch" - self.step_mode = "s" - - @property - def supported_backends(self): - """ - * :ref:`API in English ` - .. _MemoryModule.supported_backends-en: - - Return the supported backends. The default return value is `('torch', )` - - :return: supported backends - :rtype: tuple[str] - - """ - return ("torch",) - - @property - def backend(self): - return self._backend - - @backend.setter - def backend(self, value: str): - if value not in self.supported_backends: - raise NotImplementedError( - f"{value} is not a supported backend of {self._get_name()}!" - ) - check_backend_library(value) - self._backend = value - - @abstractmethod - def single_step_forward(self, x: torch.Tensor, *args, **kwargs): - """ - * :ref:`API in English ` - .. _MemoryModule.single_step_forward-en: - - :param x: input tensor with ``shape = [N, *] `` - :type x: torch.Tensor - - The single-step forward function for this module - - """ - pass - - def multi_step_forward(self, x_seq: torch.Tensor, *args, **kwargs): - """ - * :ref:`API in English ` - .. _MemoryModule.multi_step_forward-en: - - :param x_seq: input tensor with ``shape = [T, N, *] `` - :type x_seq: torch.Tensor - - The multi-step forward function for this module, which is implemented by calling ``single_step_forward(x[t], *args, **kwargs)`` over ``T`` times - - """ - - T = x_seq.shape[0] - y_seq = [] - for t in range(T): - y = self.single_step_forward(x_seq[t], *args, **kwargs) - y_seq.append(y.unsqueeze(0)) - - return torch.cat(y_seq, 0) - - def forward(self, *args, **kwargs): - if self.step_mode == "s": - return self.single_step_forward(*args, **kwargs) - elif self.step_mode == "m": - return self.multi_step_forward(*args, **kwargs) - else: - raise ValueError(self.step_mode) - - def extra_repr(self): - return f"step_mode={self.step_mode}, backend={self.backend}" - - def register_memory(self, name: str, value): - """ - * :ref:`API in English ` - .. _MemoryModule.register_memory-en: - - :param name: variable's name - :type name: str - :param value: variable's value - :type value: any - - Register the variable to memory dict, which saves stateful variables (e.g., the membrane potential of a - spiking neuron). The reset value of this variable will be ``value``. ``self.name`` will be set to ``value`` after - each calling of ``self.reset()``. - - """ - assert not hasattr(self, name), f"{name} has been set as a member variable!" - self._memories[name] = value - self.set_reset_value(name, value) - - def reset(self): - """ - * :ref:`API in English ` - .. _MemoryModule.reset-en: - - Reset all stateful variables to their default values. - """ - for key in self._memories.keys(): - self._memories[key] = copy.deepcopy(self._memories_rv[key]) - - def set_reset_value(self, name: str, value): - self._memories_rv[name] = copy.deepcopy(value) - - def __getattr__(self, name: str): - if "_memories" in self.__dict__: - memories = self.__dict__["_memories"] - if name in memories: - return memories[name] - - return super().__getattr__(name) - - def __setattr__(self, name: str, value) -> None: - _memories = self.__dict__.get("_memories") - if _memories is not None and name in _memories: - _memories[name] = value - else: - super().__setattr__(name, value) - - def __delattr__(self, name): - if name in self._memories: - del self._memories[name] - del self._memories_rv[name] - else: - return super().__delattr__(name) - - def __dir__(self): - module_attrs = dir(self.__class__) - attrs = list(self.__dict__.keys()) - parameters = list(self._parameters.keys()) - modules = list(self._modules.keys()) - buffers = list(self._buffers.keys()) - memories = list(self._memories.keys()) - keys = module_attrs + attrs + parameters + modules + buffers + memories - - # Eliminate attrs that are not legal Python variable names - keys = [key for key in keys if not key[0].isdigit()] - - return sorted(keys) - - def memories(self): - """ - * :ref:`API in English ` - .. _MemoryModule.memories-en: - - :return: an iterator over all stateful variables - :rtype: Iterator - """ - for name, value in self._memories.items(): - yield value - - def named_memories(self): - """ - * :ref:`API in English ` - .. _MemoryModule.named_memories-en: - - :return: an iterator over all stateful variables and their names - :rtype: Iterator - """ - - for name, value in self._memories.items(): - yield name, value - - def detach(self): - """ - * :ref:`API in English ` - .. _MemoryModule.detach-en: - - Detach all stateful variables. - - .. admonition:: Tip - :class: tip - - We can use this function to implement TBPTT(Truncated Back Propagation Through Time). - - """ - - for key in self._memories.keys(): - if isinstance(self._memories[key], torch.Tensor): - self._memories[key].detach_() - - def _apply(self, fn): - for key, value in self._memories.items(): - if isinstance(value, torch.Tensor): - self._memories[key] = fn(value) - # do not apply on default values - # for key, value in self._memories_rv.items(): - # if isinstance(value, torch.Tensor): - # self._memories_rv[key] = fn(value) - return super()._apply(fn) - - def _replicate_for_data_parallel(self): - replica = super()._replicate_for_data_parallel() - replica._memories = self._memories.copy() - return replica diff --git a/src/chop/nn/snn/configure.py b/src/chop/nn/snn/configure.py deleted file mode 100644 index 2fe757841..000000000 --- a/src/chop/nn/snn/configure.py +++ /dev/null @@ -1,77 +0,0 @@ -# *************************************************************************************** -# * Title: configuration.py -# * Reference: This file is directly sourced from spikingJelly -# * Availability: https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/configure.py -# * Date: 07/11/2024 -# * Code version: 0.0.0.014 -# *************************************************************************************** - -""" -This py file defines some variables used in SpikingJelly. -Here is an example of how you can change them to make effect in your codes: - - import spikingjelly - spikingjelly.configure.cuda_threads = 512 - -Do not change them in this way, which will not make effect: - - from spikingjelly.configure import cuda_threads - cuda_threads = 512 - -""" - -max_threads_number_for_datasets_preprocess = 16 -""" -`max_threads_number_for_datasets_preprocess` defines the maximum threads for datasets preprocessing, which is -1. reading binary events and saving them to numpy format -2. integrating events to frames. - -Note that a too larger `max_threads_number_for_datasets_preprocess` will overload the disc and slow down the speed. -""" - -cuda_threads = 512 -""" -`cuda_threads` defines the default threads number for CUDA kernel. - -It is recommended that `cuda_threads` is the power of 2. -""" - -cuda_compiler_options = ("-use_fast_math",) -""" -`cuda_compiler_options` defines the compiler options passed to the backend (NVRTC or NVCC). - -For more details, refer to -1. https://docs.nvidia.com/cuda/nvrtc/index.html#group__options -2. https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#command-option-description -3. https://github.com/fangwei123456/spikingjelly/discussions/116 -""" - -cuda_compiler_backend = "nvrtc" -""" -`cuda_compiler_backend` defines the compiler for CUDA(cupy). - -It can be set to 'nvcc' or 'nvrtc'. -""" - -save_datasets_compressed = True -""" -If `save_datasets_compressed == True`, events and frames in spikingjelly.datasets will be saved in compressed npz format. - -The compressed npz file consumes less memory in disk but more time in reading. -""" - -save_spike_as_bool_in_neuron_kernel = False -""" -If `save_spike_as_bool_in_neuron_kernel == True`, the neuron kernel used in the neuron's cupy backend will save the spike as a bool, rather than float/half tensor for backward, which can reduce the memory consumption. -""" - -save_bool_spike_level = 0 -""" -`save_bool_spike_level` take effects on SpikeConv/SpikeLinear, and on neuron's cupy kernel when `save_spike_as_bool_in_cuda_utils == True`. - -If `save_bool_spike_level == 0`, spikes will be saved in bool. Note that bool uses 8-bit, rather than 1-bit. - -If `save_bool_spike_level == 1`, spikes will be saved in uint8 with each 8-bit storing 8 spikes. - -A larger `save_bool_spike_level` means less memory consumption but slower speed. -""" diff --git a/src/chop/nn/snn/cuda_utils.py b/src/chop/nn/snn/cuda_utils.py deleted file mode 100644 index 417b6c318..000000000 --- a/src/chop/nn/snn/cuda_utils.py +++ /dev/null @@ -1,326 +0,0 @@ -# *************************************************************************************** -# * Title: cuda_utils -# * Reference: This file is directly sourced from spikingJelly -# * Availability: https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/activation_based/cuda_utils.py -# * Date: 07/11/2024 -# * Code version: 0.0.0.014 -# *************************************************************************************** - -import logging -import torch -import time -import numpy as np -from . import configure -from typing import Callable, Union - -try: - import cupy -except BaseException as e: - logging.info(f"spikingjelly.activation_based.cuda_utils: {e}") - cupy = None - - -def cpu_timer(f: Callable, *args, **kwargs): - """ - * :ref:`API in English ` - - .. _cpu_timer-cn: - - 计算在CPU上执行 ``f(*args, **kwargs)`` 所需的时间 - - :param f: 函数 - :type f: Callable - :return: 用时,单位是毫秒 - :rtype: float - - * :ref:`中文 API ` - - .. _cpu_timer-en: - - Returns the used time for calling ``f(*args, **kwargs)`` in CPU - - :param f: a function - :type f: Callable - :return: used time in milliseconds - :rtype: float - """ - start = time.perf_counter() - f(*args, **kwargs) - return time.perf_counter() - start - - -def cuda_timer(device: Union[torch.device, int], f: Callable, *args, **kwargs): - """ - * :ref:`API in English ` - - .. _cuda_timer-cn: - - 计算在CUDA上执行 ``f(*args, **kwargs)`` 所需的时间 - - :param device: ``f`` 运行的CUDA设备 - :type device: Union[torch.device, int] - :param f: 函数 - :type f: Callable - :return: 用时,单位是毫秒 - :rtype: float - - * :ref:`中文 API ` - - .. _cuda_timer-en: - - Returns the used time for calling ``f(*args, **kwargs)`` in CUDA - - :param device: on which cuda device that ``f`` is running - :type device: Union[torch.device, int] - :param f: a function - :type f: Callable - :return: used time in milliseconds - :rtype: float - """ - torch.cuda.set_device(device) - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - f(*args, **kwargs) - end.record() - torch.cuda.synchronize(device) - return start.elapsed_time(end) - - -def cal_fun_t( - n: int, device: Union[str, torch.device, int], f: Callable, *args, **kwargs -): - """ - * :ref:`API in English ` - - .. _cal_fun_t-cn: - - 测量在 ``device`` 上执行 ``n`` 次 ``f(*args, **kwargs)`` 的平均用时 - - .. note:: - - 当 ``n > 1`` 时,实际上会执行 ``2n`` 次,然后返回后 ``n`` 次的平均用时,以减小误差。 - - :param n: 重复的次数 - :type n: int - :param device: ``f`` 执行的设备,可以为 'cpu' 或CUDA设备 - :type device: Union[str, torch.device, int] - :param f: 函数 - :type f: Callable - :return: 用时,单位是毫秒 - :rtype: float - - * :ref:`中文 API ` - - .. _cal_fun_t-en: - - Returns the used time averaged by calling ``f(*args, **kwargs)`` over ``n`` times - - .. admonition:: Note - :class: note - - If ``n > 1``, this function will call ``f`` for ``2n`` times and return the average used time by the last ``n`` - times to reduce the measure error. - - :param n: repeat times - :type n: int - :param device: on which cuda device that ``f`` is running. It can be 'cpu' or a cuda deivce - :type device: Union[str, torch.device, int] - :param f: function - :type f: Callable - :return: used time in milliseconds - :rtype: float - - """ - if n == 1: - if device == "cpu": - return cpu_timer(f, *args, **kwargs) - else: - return cuda_timer(device, f, *args, **kwargs) - - # warm up - if device == "cpu": - cpu_timer(f, *args, **kwargs) - else: - cuda_timer(device, f, *args, **kwargs) - - t_list = [] - for _ in range(n * 2): - if device == "cpu": - ti = cpu_timer(f, *args, **kwargs) - else: - ti = cuda_timer(device, f, *args, **kwargs) - t_list.append(ti) - - t_list = np.asarray(t_list) - return t_list[n:].mean() - - -def cal_blocks(numel: int, threads: int = -1): - """ - * :ref:`API in English ` - - .. _cal_blocks-cn: - - :param numel: 并行执行的CUDA内核的数量 - :type numel: int - :param threads: 每个cuda block中threads的数量,默认为-1,表示使用 ``configure.cuda_threads`` - :type threads: int - :return: blocks的数量 - :rtype: int - - 此函数返回 blocks的数量,用来按照 ``kernel((blocks,), (configure.cuda_threads,), ...)`` 调用 :class:`cupy.RawKernel` - - * :ref:`中文 API ` - - .. _cal_blocks-en: - - :param numel: the number of parallel CUDA kernels - :type numel: int - :param threads: the number of threads in each cuda block. - The defaule value is -1, indicating to use ``configure.cuda_threads`` - :type threads: int - :return: the number of blocks - :rtype: int - - Returns the number of blocks to call :class:`cupy.RawKernel` by ``kernel((blocks,), (threads,), ...)`` - - """ - if threads == -1: - threads = configure.cuda_threads - return (numel + threads - 1) // threads - - -def get_contiguous(*args): - """ - * :ref:`API in English ` - - .. _get_contiguous-cn: - - 将 ``*args`` 中所有的 ``torch.Tensor`` 或 ``cupy.ndarray`` 进行连续化。 - - .. note:: - - 连续化的操作无法in-place,因此本函数返回一个新的list。 - - :return: 一个元素全部为连续的 ``torch.Tensor`` 或 ``cupy.ndarray`` 的 ``list`` - :rtype: list - - * :ref:`中文 API ` - - .. _get_contiguous-en: - - :return: a list that contains the contiguous ``torch.Tensor`` or ``cupy.ndarray`` - :rtype: list - - Makes ``torch.Tensor`` or ``cupy.ndarray`` in ``*args`` to be contiguous - - .. admonition:: Note - :class: note - - The making contiguous operation can not be done in-place. Hence, this function will return a new list. - - """ - ret_list = [] - - for item in args: - if isinstance(item, torch.Tensor): - ret_list.append(item.contiguous()) - - elif isinstance(item, cupy.ndarray): - ret_list.append(cupy.ascontiguousarray(item)) - else: - raise TypeError(type(item)) - return ret_list - - -def wrap_args_to_raw_kernel(device: int, *args): - """ - * :ref:`API in English ` - - .. _wrap_args_to_raw_kernel-cn: - - :param device: raw kernel运行的CUDA设备 - :type device: int - :return: 一个包含用来调用 :class:`cupy.RawKernel` 的 ``tuple`` - :rtype: tuple - - 此函数可以包装 ``torch.Tensor`` 和 ``cupy.ndarray`` 并将其作为 :class:`cupy.RawKernel.__call__` 的 ``args`` - - * :ref:`中文 API ` - - .. _wrap_args_to_raw_kernel-en: - - :param device: on which CUDA device the raw kernel will run - :type device: int - :return: a ``tuple`` that contains args to call :class:`cupy.RawKernel` - :rtype: tuple - - This function can wrap ``torch.Tensor`` or ``cupy.ndarray`` to ``args`` in :class:`cupy.RawKernel.__call__` - - """ - # note that the input must be contiguous - # check device and get data_ptr from tensor - ret_list = [] - for item in args: - if isinstance(item, torch.Tensor): - assert item.get_device() == device - assert item.is_contiguous() - ret_list.append(item.data_ptr()) - - elif isinstance(item, cupy.ndarray): - assert item.device.id == device - assert item.flags["C_CONTIGUOUS"] - ret_list.append(item) - - else: - raise TypeError - return tuple(ret_list) - - -class DeviceEnvironment: - def __init__(self, device: int): - """ - * :ref:`API in English ` - - .. _DeviceEnvironment.__init__-cn: - - 这个模块可以被用作在指定的 ``device`` 上执行CuPy函数的上下文,用来避免 `torch.cuda.current_device()` 被CuPy意外改变( https://github.com/cupy/cupy/issues/6569 )。 - - 代码示例: - - .. code-block:: python - - with DeviceEnvironment(device): - kernel((blocks,), (configure.cuda_threads,), ...) - - - * :ref:`中文 API ` - - .. _DeviceEnvironment.__init__-en: - - :param device: the CUDA device - :type device: int - - This module is used as a context to make CuPy use the specific device, and avoids `torch.cuda.current_device()` is changed by CuPy ( https://github.com/cupy/cupy/issues/6569 ). - - Codes example: - - .. code-block:: python - - with DeviceEnvironment(device): - kernel((blocks,), (configure.cuda_threads,), ...) - - """ - self.device = device - self.previous_device = None - - def __enter__(self): - current_device = torch.cuda.current_device() - if current_device != self.device: - torch.cuda.set_device(self.device) - self.previous_device = current_device - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.previous_device is not None: - torch.cuda.set_device(self.previous_device) diff --git a/src/chop/nn/snn/functional/__init__.py b/src/chop/nn/snn/functional/__init__.py deleted file mode 100644 index a2f69860f..000000000 --- a/src/chop/nn/snn/functional/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .surrogate import sigmoid - -from .functional import multi_step_forward, seq_to_ann_forward, reset_net diff --git a/src/chop/nn/snn/functional/functional.py b/src/chop/nn/snn/functional/functional.py deleted file mode 100644 index 829ee358a..000000000 --- a/src/chop/nn/snn/functional/functional.py +++ /dev/null @@ -1,114 +0,0 @@ -# *************************************************************************************** -# * Title: funtional.py -# * Reference: This file is adapted from spikingJelly -# * Availability: https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/activation_based/functional.py -# * Date: 07/11/2024 -# * Code version: 0.0.0.014 -# *************************************************************************************** - -import logging -import copy -from chop.nn.snn import base -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from typing import Callable, Union -from torch import Tensor - - -def reset_net(net: nn.Module): - """ - * :ref:`API in English ` - - .. _reset_net-cn: - - :param net: 任何属于 ``nn.Module`` 子类的网络 - - :return: None - - 将网络的状态重置。做法是遍历网络中的所有 ``Module``,若 ``m `` 为 ``base.MemoryModule`` 函数或者是拥有 ``reset()`` 方法,则调用 ``m.reset()``。 - - * :ref:`中文API ` - - .. _reset_net-en: - - :param net: Any network inherits from ``nn.Module`` - - :return: None - - Reset the whole network. Walk through every ``Module`` as ``m``, and call ``m.reset()`` if this ``m`` is ``base.MemoryModule`` or ``m`` has ``reset()``. - """ - for m in net.modules(): - if hasattr(m, "reset"): - if not isinstance(m, base.MemoryModule): - logging.warning( - f"Trying to call `reset()` of {m}, which is not spikingjelly.activation_based.base" - f".MemoryModule" - ) - m.reset() - - -def multi_step_forward( - x_seq: Tensor, - single_step_module: Union[ - nn.Module, list[nn.Module], tuple[nn.Module], nn.Sequential, Callable - ], -): - """ - * :ref:`API in English ` - - .. _multi_step_forward-en: - - :param x_seq: the input tensor with ``shape=[T, batch_size, ...]`` - :type x_seq: torch.Tensor - :param single_step_module: one or many single-step modules - :type single_step_module: Union[nn.Module, list[nn.Module], tuple[nn.Module], nn.Sequential, Callable] - :return: the output tensor with ``shape=[T, batch_size, ...]`` - :rtype: torch.torch.Tensor - - Applies multi-step forward on ``single_step_module``. - - """ - y_seq = [] - if isinstance(single_step_module, (list, tuple, nn.Sequential)): - for t in range(x_seq.shape[0]): - x_seq_t = x_seq[t] - for m in single_step_module: - x_seq_t = m(x_seq_t) - y_seq.append(x_seq_t) - else: - for t in range(x_seq.shape[0]): - y_seq.append(single_step_module(x_seq[t])) - - return torch.stack(y_seq) - - -def seq_to_ann_forward( - x_seq: Tensor, - stateless_module: Union[nn.Module, list, tuple, nn.Sequential, Callable], -): - """ - * :ref:`API in English ` - - .. _seq_to_ann_forward-en: - - :param x_seq: the input tensor with ``shape=[T, batch_size, ...]`` - :type x_seq: Tensor - :param stateless_module: one or many stateless modules - :type stateless_module: Union[nn.Module, list, tuple, nn.Sequential, Callable] - :return: the output tensor with ``shape=[T, batch_size, ...]`` - :rtype: Tensor - - Applied forward on stateless modules. - - """ - y_shape = [x_seq.shape[0], x_seq.shape[1]] - y = x_seq.flatten(0, 1) - if isinstance(stateless_module, (list, tuple, nn.Sequential)): - for m in stateless_module: - y = m(y) - else: - y = stateless_module(y) - y_shape.extend(y.shape[1:]) - return y.view(y_shape) diff --git a/src/chop/nn/snn/functional/surrogate.py b/src/chop/nn/snn/functional/surrogate.py deleted file mode 100644 index 2b7556479..000000000 --- a/src/chop/nn/snn/functional/surrogate.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from chop.nn.snn.auto_cuda import cfunction - -tab4_str = "\t\t\t\t" # used for aligning code -curly_bracket_l = "{" -curly_bracket_r = "}" - - -@torch.jit.script -def heaviside(x: torch.Tensor): - """ - * :ref:`API in English ` - .. _heaviside.__init__-en: - - :param x: the input tensor - :return: the output tensor - - The heaviside function, which is defined by - - .. math:: - g(x) = - \\begin{cases} - 1, & x \\geq 0 \\\\ - 0, & x < 0 \\\\ - \\end{cases} - - For more information, see `HeavisideStepFunction `_. - - """ - return (x >= 0).to(x) - - -@torch.jit.script -def sigmoid_backward(grad_output: torch.Tensor, x: torch.Tensor, alpha: float): - sgax = (x * alpha).sigmoid_() - return grad_output * (1.0 - sgax) * sgax * alpha, None - - -class sigmoid(torch.autograd.Function): - @staticmethod - def forward(ctx, x, alpha): - if x.requires_grad: - ctx.save_for_backward(x) - ctx.alpha = alpha - return heaviside(x) - - @staticmethod - def backward(ctx, grad_output): - return sigmoid_backward(grad_output, ctx.saved_tensors[0], ctx.alpha) - - -@torch.jit.script -def atan_backward(grad_output: torch.Tensor, x: torch.Tensor, alpha: float): - return alpha / 2 / (1 + (math.pi / 2 * alpha * x).pow_(2)) * grad_output, None - - -class atan(torch.autograd.Function): - @staticmethod - def forward(ctx, x, alpha): - if x.requires_grad: - ctx.save_for_backward(x) - ctx.alpha = alpha - return heaviside(x) - - @staticmethod - def backward(ctx, grad_output): - return atan_backward(grad_output, ctx.saved_tensors[0], ctx.alpha) diff --git a/src/chop/nn/snn/functional/utils.py b/src/chop/nn/snn/functional/utils.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/chop/nn/snn/modules/__init__.py b/src/chop/nn/snn/modules/__init__.py deleted file mode 100644 index 6b6efd229..000000000 --- a/src/chop/nn/snn/modules/__init__.py +++ /dev/null @@ -1,115 +0,0 @@ -from torch import nn - -from .modules import VoltageHook, VoltageScaler, SeqToANNContainer, MultiStepContainer - -# from .converter import Converter - -from .surrogate import Sigmoid, ATan - -from .conv1d import Conv1d - -from .conv2d import Conv2d - -from .conv3d import Conv3d - -from .linear import Linear, LinearUnfoldBias - -from .pool1d import MaxPool1d, AvgPool1d, AdaptiveAvgPool1d - -from .pool2d import MaxPool2d, AvgPool2d, AdaptiveAvgPool2d - -from .pool3d import MaxPool3d, AvgPool3d, AdaptiveAvgPool3d - -from .batch_norm1d import BatchNorm1d - -from .batch_norm2d import BatchNorm2d - -from .batch_norm3d import BatchNorm3d - -from .flatten import Flatten - -from .group_norm import GroupNorm - -from .upsample import Upsample - -from .neuron import ( - IFNode, - LIFNode, - ParametricLIFNode, - ST_BIFNode, -) - -from .layernorm import LayerNormZIPTF - -from .softmax import SoftmaxZIPTF - -from .gelu import GELUZIPTF - -from .silu import SiLUZIPTF - -from .spiking_self_attention import ( - DSSA, - GWFFN, - BN, - DownsampleLayer, - Conv1x1, - LIF, - PLIF, - Conv3x3, - SpikingMatmul, -) - -from .embedding import EmbeddingZIPTF -from .roberta import ( - RobertaSelfAttentionZIPTF, -) - -spiking_basic_module_map = { - "conv1d": Conv1d, - "conv2d": Conv2d, - "conv3d": Conv3d, - "linear": Linear, - "linear_unfold_bias": LinearUnfoldBias, - "max_pool1d": MaxPool1d, - "avg_pool1d": AvgPool1d, - "adaptive_avg_pool1d": AdaptiveAvgPool1d, - "max_pool2d": MaxPool2d, - "avg_pool2d": AvgPool2d, - "adaptive_avg_pool2d": AdaptiveAvgPool2d, - "max_pool3d": MaxPool3d, - "avg_pool3d": AvgPool3d, - "adaptive_avg_pool3d": AdaptiveAvgPool3d, - "batch_norm1d": BatchNorm1d, - "batch_norm2d": BatchNorm2d, - "batch_norm3d": BatchNorm3d, - "flatten": Flatten, - "group_norm": GroupNorm, - "upsample": Upsample, - "identity": nn.Identity, -} - -spiking_varied_module_map = { - "softmax_zip_tf": SoftmaxZIPTF, - "layernorm_zip_tf": LayerNormZIPTF, - "embedding_zip_tf": EmbeddingZIPTF, - "gelu_zip_tf": GELUZIPTF, - "silu_zip_tf": SiLUZIPTF, -} - -spiking_neuron_module_map = { - "if": IFNode, - "lif": LIFNode, - "plif": ParametricLIFNode, - "st_bif": ST_BIFNode, -} - -spiking_roberta_module_map = { - "roberta_self_attention_zip_tf": RobertaSelfAttentionZIPTF, -} - -spiking_module_map = { - **spiking_basic_module_map, - **spiking_neuron_module_map, - **spiking_varied_module_map, - **spiking_roberta_module_map, -} diff --git a/src/chop/nn/snn/modules/batch_norm1d.py b/src/chop/nn/snn/modules/batch_norm1d.py deleted file mode 100644 index 5f7f7145a..000000000 --- a/src/chop/nn/snn/modules/batch_norm1d.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from torch import Tensor -from torch.nn.common_types import ( - _size_any_t, - _size_1_t, - _size_2_t, - _size_3_t, - _ratio_any_t, -) -from typing import Optional, List, Tuple, Union -from typing import Callable -import chop.nn.snn.base as base -import chop.nn.snn.functional as functional - - -class BatchNorm1d(nn.BatchNorm1d, base.StepModule): - def __init__( - self, - num_features, - eps=1e-5, - momentum=0.1, - affine=True, - track_running_stats=True, - step_mode="s", - ): - """ - * :ref:`API in English ` - - .. _BatchNorm1d-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.BatchNorm1d` for other parameters' API - """ - super().__init__(num_features, eps, momentum, affine, track_running_stats) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - return super().forward(x) - - elif self.step_mode == "m": - if x.dim() != 4 and x.dim() != 3: - raise ValueError( - f"expected x with shape [T, N, C, L] or [T, N, C], but got x with shape {x.shape}!" - ) - return functional.seq_to_ann_forward(x, super().forward) diff --git a/src/chop/nn/snn/modules/batch_norm2d.py b/src/chop/nn/snn/modules/batch_norm2d.py deleted file mode 100644 index 1edca2919..000000000 --- a/src/chop/nn/snn/modules/batch_norm2d.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from torch import Tensor -from torch.nn.common_types import ( - _size_any_t, - _size_1_t, - _size_2_t, - _size_3_t, - _ratio_any_t, -) -from typing import Optional, List, Tuple, Union -from typing import Callable -import chop.nn.snn.base as base -import chop.nn.snn.functional as functional - - -class BatchNorm2d(nn.BatchNorm2d, base.StepModule): - def __init__( - self, - num_features, - eps=1e-5, - momentum=0.1, - affine=True, - track_running_stats=True, - step_mode="s", - ): - """ - * :ref:`API in English ` - - .. _BatchNorm2d-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.BatchNorm2d` for other parameters' API - """ - super().__init__(num_features, eps, momentum, affine, track_running_stats) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - return super().forward(x) - - elif self.step_mode == "m": - if x.dim() != 5: - raise ValueError( - f"expected x with shape [T, N, C, H, W], but got x with shape {x.shape}!" - ) - return functional.seq_to_ann_forward(x, super().forward) diff --git a/src/chop/nn/snn/modules/batch_norm3d.py b/src/chop/nn/snn/modules/batch_norm3d.py deleted file mode 100644 index ed61a3a74..000000000 --- a/src/chop/nn/snn/modules/batch_norm3d.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from torch import Tensor -from torch.nn.common_types import ( - _size_any_t, - _size_1_t, - _size_2_t, - _size_3_t, - _ratio_any_t, -) -from typing import Optional, List, Tuple, Union -from typing import Callable -import chop.nn.snn.base as base -import chop.nn.snn.functional as functional - - -class BatchNorm3d(nn.BatchNorm3d, base.StepModule): - def __init__( - self, - num_features, - eps=1e-5, - momentum=0.1, - affine=True, - track_running_stats=True, - step_mode="s", - ): - """ - * :ref:`API in English ` - - .. _BatchNorm3d-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.BatchNorm3d` for other parameters' API - """ - super().__init__(num_features, eps, momentum, affine, track_running_stats) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - return super().forward(x) - - elif self.step_mode == "m": - if x.dim() != 6: - raise ValueError( - f"expected x with shape [T, N, C, D, H, W], but got x with shape {x.shape}!" - ) - return functional.seq_to_ann_forward(x, super().forward) diff --git a/src/chop/nn/snn/modules/conv1d.py b/src/chop/nn/snn/modules/conv1d.py deleted file mode 100644 index 4dc6e2308..000000000 --- a/src/chop/nn/snn/modules/conv1d.py +++ /dev/null @@ -1,64 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from torch import Tensor -from torch.nn.common_types import _size_any_t, _size_1_t, _size_2_t, _size_3_t -from typing import Optional, List, Tuple, Union -from typing import Callable -import chop.nn.snn.base as base -import chop.nn.snn.functional as functional - - -class Conv1d(nn.Conv1d, base.StepModule): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_1_t, - stride: _size_1_t = 1, - padding: Union[str, _size_1_t] = 0, - dilation: _size_1_t = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = "zeros", - step_mode: str = "s", - ) -> None: - """ - * :ref:`API in English ` - - .. _Conv1d-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.Conv1d` for other parameters' API - """ - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - padding_mode, - ) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - x = super().forward(x) - - elif self.step_mode == "m": - if x.dim() != 4: - raise ValueError( - f"expected x with shape [T, N, C, L], but got x with shape {x.shape}!" - ) - x = functional.seq_to_ann_forward(x, super().forward) - - return x diff --git a/src/chop/nn/snn/modules/conv2d.py b/src/chop/nn/snn/modules/conv2d.py deleted file mode 100644 index 9afd75653..000000000 --- a/src/chop/nn/snn/modules/conv2d.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from torch import Tensor -from torch.nn.common_types import _size_any_t, _size_1_t, _size_2_t, _size_3_t -from typing import Optional, List, Tuple, Union -from typing import Callable -import chop.nn.snn.base as base -import chop.nn.snn.functional as functional - - -class Conv2d(nn.Conv2d, base.StepModule): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_2_t, - stride: _size_2_t = 1, - padding: Union[str, _size_2_t] = 0, - dilation: _size_2_t = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = "zeros", - step_mode: str = "s", - ) -> None: - """ - * :ref:`API in English ` - .. _Conv2d-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.Conv2d` for other parameters' API - """ - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - padding_mode, - ) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - x = super().forward(x) - - elif self.step_mode == "m": - if x.dim() != 5: - raise ValueError( - f"expected x with shape [T, N, C, H, W], but got x with shape {x.shape}!" - ) - x = functional.seq_to_ann_forward(x, super().forward) - - return x diff --git a/src/chop/nn/snn/modules/conv3d.py b/src/chop/nn/snn/modules/conv3d.py deleted file mode 100644 index 31ed37ef1..000000000 --- a/src/chop/nn/snn/modules/conv3d.py +++ /dev/null @@ -1,64 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from torch import Tensor -from torch.nn.common_types import _size_any_t, _size_1_t, _size_2_t, _size_3_t -from typing import Optional, List, Tuple, Union -from typing import Callable -import chop.nn.snn.base as base -import chop.nn.snn.functional as functional - - -class Conv3d(nn.Conv3d, base.StepModule): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: _size_3_t, - stride: _size_3_t = 1, - padding: Union[str, _size_3_t] = 0, - dilation: _size_3_t = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = "zeros", - step_mode: str = "s", - ) -> None: - """ - * :ref:`API in English ` - - .. _Conv3d-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.Conv3d` for other parameters' API - """ - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - padding_mode, - ) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - x = super().forward(x) - - elif self.step_mode == "m": - if x.dim() != 6: - raise ValueError( - f"expected x with shape [T, N, C, D, H, W], but got x with shape {x.shape}!" - ) - x = functional.seq_to_ann_forward(x, super().forward) - - return x diff --git a/src/chop/nn/snn/modules/embedding.py b/src/chop/nn/snn/modules/embedding.py deleted file mode 100644 index 6bab58947..000000000 --- a/src/chop/nn/snn/modules/embedding.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -from torch import nn -from typing import Optional, Tuple -from torch import Tensor - - -class EmbeddingZIPTF(nn.Embedding): - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, - norm_type: float = 2.0, - scale_grad_by_freq: bool = False, - sparse: bool = False, - _weight: Optional[Tensor] = None, - _freeze: bool = False, - device=None, - dtype=None, - ) -> None: - super().__init__( - num_embeddings, - embedding_dim, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - _weight=_weight, - _freeze=_freeze, - device=device, - dtype=dtype, - ) - self.T = 0 - self.shape = None - - def reset(self): - self.T = 0 - - def forward(self, x): - if self.T == 0: - output = super().forward(x) - self.shape = output.shape - self.T = self.T + 1 - return output - else: - return torch.zeros(self.shape, device=x.device) diff --git a/src/chop/nn/snn/modules/flatten.py b/src/chop/nn/snn/modules/flatten.py deleted file mode 100644 index a9d5450aa..000000000 --- a/src/chop/nn/snn/modules/flatten.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from torch import Tensor -from torch.nn.common_types import ( - _size_any_t, - _size_1_t, - _size_2_t, - _size_3_t, - _ratio_any_t, -) -from typing import Optional, List, Tuple, Union -from typing import Callable -import chop.nn.snn.base as base -import chop.nn.snn.functional as functional - - -class Flatten(nn.Flatten, base.StepModule): - def __init__(self, start_dim: int = 1, end_dim: int = -1, step_mode="s") -> None: - """ - * :ref:`API in English ` - - .. _Flatten-cn: - - :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) - :type step_mode: str - - 其他的参数API参见 :class:`torch.nn.Flatten` - - * :ref:`中文 API ` - - .. _Flatten-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.Flatten` for other parameters' API - """ - super().__init__(start_dim, end_dim) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - x = super().forward(x) - - elif self.step_mode == "m": - x = functional.seq_to_ann_forward(x, super().forward) - return x diff --git a/src/chop/nn/snn/modules/gelu.py b/src/chop/nn/snn/modules/gelu.py deleted file mode 100644 index 8d9ccdc0b..000000000 --- a/src/chop/nn/snn/modules/gelu.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch -from torch import nn - - -class GELUZIPTF(nn.GELU): - def __init__( - self, - normalized_shape, - eps: float = 1e-5, - elementwise_affine: bool = True, - bias: bool = True, - device=None, - dtype=None, - ) -> None: - super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype) - self.X = 0.0 - self.Y_pre = None - - def reset(self): - self.X = 0.0 - self.Y_pre = None - - def forward(self, input): - # print("input", input) - self.X = self.X + input - Y = super().forward(self.X) - if self.Y_pre is not None: - Y_pre = self.Y_pre.detach().clone() - else: - Y_pre = 0.0 - self.Y_pre = Y - return Y - Y_pre diff --git a/src/chop/nn/snn/modules/group_norm.py b/src/chop/nn/snn/modules/group_norm.py deleted file mode 100644 index b7a7aa3b3..000000000 --- a/src/chop/nn/snn/modules/group_norm.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from torch import Tensor -from torch.nn.common_types import ( - _size_any_t, - _size_1_t, - _size_2_t, - _size_3_t, - _ratio_any_t, -) -from typing import Optional, List, Tuple, Union -from typing import Callable -import chop.nn.snn.base as base -import chop.nn.snn.functional as functional - - -class GroupNorm(nn.GroupNorm, base.StepModule): - def __init__( - self, - num_groups: int, - num_channels: int, - eps: float = 1e-5, - affine: bool = True, - step_mode="s", - ): - """ - * :ref:`API in English ` - - .. _GroupNorm-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.GroupNorm` for other parameters' API - """ - super().__init__(num_groups, num_channels, eps, affine) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - return super().forward(x) - - elif self.step_mode == "m": - return functional.seq_to_ann_forward(x, super().forward) diff --git a/src/chop/nn/snn/modules/layernorm.py b/src/chop/nn/snn/modules/layernorm.py deleted file mode 100644 index ba17d468c..000000000 --- a/src/chop/nn/snn/modules/layernorm.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -from torch import nn - - -class LayerNormZIPTF(nn.LayerNorm): - def __init__( - self, - normalized_shape, - eps: float = 1e-5, - elementwise_affine: bool = True, - bias: bool = True, - device=None, - dtype=None, - ) -> None: - super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype) - self.X = 0.0 - self.Y_pre = None - - def reset(self): - self.X = 0.0 - self.Y_pre = None - - def forward(self, input): - self.X = self.X + input - Y = super().forward(self.X) - if self.Y_pre is not None: - Y_pre = self.Y_pre.detach().clone() - else: - Y_pre = 0.0 - self.Y_pre = Y - return Y - Y_pre diff --git a/src/chop/nn/snn/modules/linear.py b/src/chop/nn/snn/modules/linear.py deleted file mode 100644 index 44cfb4f92..000000000 --- a/src/chop/nn/snn/modules/linear.py +++ /dev/null @@ -1,107 +0,0 @@ -from torch import nn -import chop.nn.snn.base as base -import torch - - -class Linear(nn.Linear, base.StepModule): - def __init__( - self, in_features: int, out_features: int, bias: bool = True, step_mode="s" - ) -> None: - """ - * :ref:`API in English ` - .. _Linear-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.Linear` for other parameters' API - """ - super().__init__(in_features, out_features, bias) - self.step_mode = step_mode - - -# TODO: Merge this with StepModule? -class LinearUnfoldBias(nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - device=None, - dtype=None, - level: int = None, - neuron_type: str = None, - ) -> None: - super().__init__( - in_features, - out_features, - bias, - device, - dtype, - ) - self.is_work = False - self.first = True - self.zero_output = None - self.neuron_type = neuron_type - self.level = level - self.steps = self.level - self.realize_time = self.steps - - def reset(self): - # print("LLLinear reset") - self.is_work = False - self.first = True - self.zero_output = None - self.realize_time = self.steps - - def forward(self, input): - # print("LLLinear.steps",self.steps) - x = input - # if x.ndim == 2: - # B,N = x.shape - # elif x.ndim == 3: - # B,C,N = x.shape - # N = self.out_features - if x.dim() == 3: - B, N, _ = x.shape - D = self.out_features - shape_new = (B, N, D) - elif x.dim() == 2: - B, _ = x.shape - D = self.out_features - shape_new = (B, D) - if self.zero_output is None: - self.zero_output = torch.zeros( - size=shape_new, device=x.device, dtype=x.dtype - ) - - if (not torch.is_tensor(x) and (x == 0.0)) or ((x == 0.0).all()): - self.is_work = False - if self.realize_time > 0: - output = self.zero_output + ( - self.bias.data.unsqueeze(0) / self.steps - if self.bias is not None - else 0.0 - ) - self.realize_time = self.realize_time - 1 - self.is_work = True - return output - return self.zero_output - - output = super().forward(x) - - if self.neuron_type == "IF": - pass - else: - if self.bias is None: - pass - else: - output = output - self.bias.data.unsqueeze(0) - if self.realize_time > 0: - output = output + self.bias.data.unsqueeze(0) / self.steps - self.realize_time = self.realize_time - 1 - - self.is_work = True - self.first = False - - return output diff --git a/src/chop/nn/snn/modules/modules.py b/src/chop/nn/snn/modules/modules.py deleted file mode 100644 index 34865b4e6..000000000 --- a/src/chop/nn/snn/modules/modules.py +++ /dev/null @@ -1,151 +0,0 @@ -# *************************************************************************************** -# * Title: modules.py -# * Reference: This file is adapted from spikingJelly -# * Availability: https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/activation_based/layer.py -# * Availability: https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/activation_based/ann2snn/modules.py -# * Date: 07/11/2024 -# * Code version: 0.0.0.014 -# *************************************************************************************** - -import torch.nn as nn -from torch import Tensor -import torch -import numpy as np -import chop.nn.snn.base as base -import chop.nn.snn.functional as functional -import logging - - -class MultiStepContainer(nn.Sequential, base.MultiStepModule): - def __init__(self, *args): - super().__init__(*args) - for m in self: - assert not hasattr(m, "step_mode") or m.step_mode == "s" - if isinstance(m, base.StepModule): - if "m" in m.supported_step_mode(): - logging.warning( - f"{m} supports for step_mode == 's', which should not be contained by MultiStepContainer!" - ) - - def forward(self, x_seq: Tensor): - """ - :param x_seq: ``shape=[T, batch_size, ...]`` - :type x_seq: Tensor - :return: y_seq with ``shape=[T, batch_size, ...]`` - :rtype: Tensor - """ - return functional.multi_step_forward(x_seq, super().forward) - - -class SeqToANNContainer(nn.Sequential, base.MultiStepModule): - def __init__(self, *args): - super().__init__(*args) - for m in self: - assert not hasattr(m, "step_mode") or m.step_mode == "s" - if isinstance(m, base.StepModule): - if "m" in m.supported_step_mode(): - logging.warning( - f"{m} supports for step_mode == 's', which should not be contained by SeqToANNContainer!" - ) - - def forward(self, x_seq: Tensor): - """ - :param x_seq: shape=[T, batch_size, ...] - :type x_seq: Tensor - :return: y_seq, shape=[T, batch_size, ...] - :rtype: Tensor - """ - return functional.seq_to_ann_forward(x_seq, super().forward) - - -class VoltageHook(nn.Module): - def __init__(self, scale=1.0, momentum=0.1, mode="Max"): - """ - * :ref:`API in English ` - .. _voltageHook.__init__-en: - - :param scale: initial scaling value - :type scale: float - :param momentum: momentum value - :type momentum: float - :param mode: The mode. Value "Max" means recording the maximum value of ANN activation, "99.9%" means recording the 99.9% precentile of ANN activation, and a float of 0-1 means recording the corresponding multiple of the maximum activation value. - :type mode: str, float - - ``VoltageHook`` is placed behind ReLU and used to determine the range of activations in ANN inference. - - """ - super().__init__() - self.register_buffer("scale", torch.tensor(scale)) - self.mode = mode - self.num_batches_tracked = 0 - self.momentum = momentum - - def forward(self, x): - """ - * :ref:`API in English ` - .. _VoltageHook.forward-en: - - :param x: input tensor - :type x: torch.Tensor - :return: original input tensor - :rtype: torch.Tensor - - It doesn't process input tensors, but hooks the activation values of ReLU. - - """ - err_msg = "You have used a non-defined VoltageScale Method." - if isinstance(self.mode, str): - if self.mode[-1] == "%": - try: - s_t = torch.tensor( - np.percentile(x.detach().cpu(), float(self.mode[:-1])) - ) - except ValueError: - raise NotImplementedError(err_msg) - elif self.mode.lower() in ["max"]: - s_t = x.max().detach() - else: - raise NotImplementedError(err_msg) - elif isinstance(self.mode, float) and self.mode <= 1 and self.mode > 0: - s_t = x.max().detach() * self.mode - else: - raise NotImplementedError(err_msg) - - if self.num_batches_tracked == 0: - self.scale = s_t - else: - self.scale = (1 - self.momentum) * self.scale + self.momentum * s_t - self.num_batches_tracked += x.shape[0] - return x - - -class VoltageScaler(nn.Module): - def __init__(self, scale=1.0): - """ - * :ref:`API in English ` - .. _VoltageScaler.__init__-en: - - :param scale: scaling value - :type scale: float - - ``VoltageScaler`` is used for scaling current in SNN inference. - - """ - super().__init__() - self.register_buffer("scale", torch.tensor(scale)) - - def forward(self, x): - """ - * :ref:`API in English ` - .. _VoltageScaler.forward-en: - - :param x: input tensor, or input current - :type x: torch.Tensor - :return: current after scaling - :rtype: torch.Tensor - - """ - return x * self.scale - - def extra_repr(self): - return "%f" % self.scale.item() diff --git a/src/chop/nn/snn/modules/neuron/__init__.py b/src/chop/nn/snn/modules/neuron/__init__.py deleted file mode 100644 index 5c93710a2..000000000 --- a/src/chop/nn/snn/modules/neuron/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .neuron import BaseNode -from .ifnode import IFNode -from .lifnode import LIFNode -from .parametriclifnode import ParametricLIFNode -from .st_bifnode import ST_BIFNode diff --git a/src/chop/nn/snn/modules/neuron/ifnode.py b/src/chop/nn/snn/modules/neuron/ifnode.py deleted file mode 100644 index e5e3c373f..000000000 --- a/src/chop/nn/snn/modules/neuron/ifnode.py +++ /dev/null @@ -1,339 +0,0 @@ -from abc import abstractmethod -from typing import Callable, Optional -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -import numpy as np -import logging - -from ... import base -from .. import surrogate - -from ...auto_cuda import neuron_kernel as ac_neuron_kernel -from ...auto_cuda import ss_neuron_kernel as ss_ac_neuron_kernel - -try: - from ... import neuron_kernel, cuda_utils - -except BaseException as e: - logging.info(f"spikingjelly.activation_based.neuron: {e}") - neuron_kernel = None - cuda_utils = None - -from .neuron import BaseNode, SimpleBaseNode - - -class SimpleIFNode(SimpleBaseNode): - def neuronal_charge(self, x: torch.Tensor): - self.v = self.v + x - - -class IFNode(BaseNode): - def __init__( - self, - v_threshold: float = 1.0, - v_reset: Optional[float] = 0.0, - surrogate_function: Callable = surrogate.Sigmoid(), - detach_reset: bool = False, - step_mode="s", - backend="torch", - store_v_seq: bool = False, - ): - """ - * :ref:`API in English ` - - .. _IFNode.__init__-en: - - :param v_threshold: threshold of this neurons layer - :type v_threshold: float - - :param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset`` - after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike - :type v_reset: Optional[float] - - :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward - :type surrogate_function: Callable - - :param detach_reset: whether detach the computation graph of reset in backward - :type detach_reset: bool - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - :param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can - print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported, - using ``'cupy'`` backend will have the fastest training speed - :type backend: str - - :param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls - whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``, - only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the - memory consumption - :type store_v_seq: bool - - The Integrate-and-Fire neuron, which can be seen as a ideal integrator. The voltage of the IF neuron will not decay - as that of the LIF neuron. The sub-threshold neural dynamics of it is as followed: - - .. math:: - H[t] = V[t-1] + X[t] - - """ - super().__init__( - v_threshold, - v_reset, - surrogate_function, - detach_reset, - step_mode, - backend, - store_v_seq, - ) - - @property - def supported_backends(self): - if self.step_mode == "s": - return ("torch", "cupy") - elif self.step_mode == "m": - return ("torch", "cupy") - else: - raise ValueError(self.step_mode) - - def neuronal_charge(self, x: torch.Tensor): - self.v = self.v + x - - @staticmethod - @torch.jit.script - def jit_eval_single_step_forward_hard_reset( - x: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float - ): - v = v + x - spike = (v >= v_threshold).to(x) - v = v_reset * spike + (1.0 - spike) * v - return spike, v - - @staticmethod - @torch.jit.script - def jit_eval_single_step_forward_soft_reset( - x: torch.Tensor, v: torch.Tensor, v_threshold: float - ): - v = v + x - spike = (v >= v_threshold).to(x) - v = v - spike * v_threshold - return spike, v - - @staticmethod - @torch.jit.script - def jit_eval_multi_step_forward_hard_reset( - x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float - ): - spike_seq = torch.zeros_like(x_seq) - for t in range(x_seq.shape[0]): - v = v + x_seq[t] - spike = (v >= v_threshold).to(x_seq) - v = v_reset * spike + (1.0 - spike) * v - spike_seq[t] = spike - return spike_seq, v - - @staticmethod - @torch.jit.script - def jit_eval_multi_step_forward_hard_reset_with_v_seq( - x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float - ): - spike_seq = torch.zeros_like(x_seq) - v_seq = torch.zeros_like(x_seq) - for t in range(x_seq.shape[0]): - v = v + x_seq[t] - spike = (v >= v_threshold).to(x_seq) - v = v_reset * spike + (1.0 - spike) * v - spike_seq[t] = spike - v_seq[t] = v - return spike_seq, v, v_seq - - @staticmethod - @torch.jit.script - def jit_eval_multi_step_forward_soft_reset( - x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float - ): - spike_seq = torch.zeros_like(x_seq) - for t in range(x_seq.shape[0]): - v = v + x_seq[t] - spike = (v >= v_threshold).to(x_seq) - v = v - spike * v_threshold - spike_seq[t] = spike - return spike_seq, v - - @staticmethod - @torch.jit.script - def jit_eval_multi_step_forward_soft_reset_with_v_seq( - x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float - ): - spike_seq = torch.zeros_like(x_seq) - v_seq = torch.zeros_like(x_seq) - for t in range(x_seq.shape[0]): - v = v + x_seq[t] - spike = (v >= v_threshold).to(x_seq) - v = v - spike * v_threshold - spike_seq[t] = spike - v_seq[t] = v - return spike_seq, v, v_seq - - def multi_step_forward(self, x_seq: torch.Tensor): - if self.training: - if self.backend == "torch": - return super().multi_step_forward(x_seq) - elif self.backend == "cupy": - hard_reset = self.v_reset is not None - - if x_seq.dtype == torch.float: - dtype = "float" - elif x_seq.dtype == torch.half: - dtype = "half2" - else: - raise NotImplementedError(x_seq.dtype) - - if ( - self.forward_kernel is None - or not self.forward_kernel.check_attributes( - hard_reset=hard_reset, dtype=dtype - ) - ): - self.forward_kernel = ac_neuron_kernel.IFNodeFPTTKernel( - hard_reset=hard_reset, dtype=dtype - ) - - if ( - self.backward_kernel is None - or not self.backward_kernel.check_attributes( - surrogate_function=self.surrogate_function.cuda_codes, - hard_reset=hard_reset, - detach_reset=self.detach_reset, - dtype=dtype, - ) - ): - self.backward_kernel = ac_neuron_kernel.IFNodeBPTTKernel( - surrogate_function=self.surrogate_function.cuda_codes, - hard_reset=hard_reset, - detach_reset=self.detach_reset, - dtype=dtype, - ) - - self.v_float_to_tensor(x_seq[0]) - - spike_seq, v_seq = ac_neuron_kernel.IFNodeATGF.apply( - x_seq.flatten(1), - self.v.flatten(0), - self.v_threshold, - self.v_reset, - self.forward_kernel, - self.backward_kernel, - ) - - spike_seq = spike_seq.reshape(x_seq.shape) - v_seq = v_seq.reshape(x_seq.shape) - - if self.store_v_seq: - self.v_seq = v_seq - - self.v = v_seq[-1].clone() - - return spike_seq - else: - raise ValueError(self.backend) - - else: - self.v_float_to_tensor(x_seq[0]) - if self.v_reset is None: - if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_soft_reset_with_v_seq( - x_seq, self.v, self.v_threshold - ) - ) - else: - spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset( - x_seq, self.v, self.v_threshold - ) - else: - if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_hard_reset_with_v_seq( - x_seq, self.v, self.v_threshold, self.v_reset - ) - ) - else: - spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset( - x_seq, self.v, self.v_threshold, self.v_reset - ) - return spike_seq - - def single_step_forward(self, x: torch.Tensor): - if self.training: - if self.backend == "torch": - return super().single_step_forward(x) - elif self.backend == "cupy": - hard_reset = self.v_reset is not None - - if x.dtype == torch.float: - dtype = "float" - elif x.dtype == torch.half: - dtype = "half2" - else: - raise NotImplementedError(x.dtype) - - if ( - self.forward_kernel is None - or not self.forward_kernel.check_attributes( - hard_reset=hard_reset, dtype=dtype - ) - ): - self.forward_kernel = ss_ac_neuron_kernel.IFNodeFPKernel( - hard_reset=hard_reset, dtype=dtype - ) - - if ( - self.backward_kernel is None - or not self.backward_kernel.check_attributes( - surrogate_function=self.surrogate_function.cuda_codes, - hard_reset=hard_reset, - detach_reset=self.detach_reset, - dtype=dtype, - ) - ): - self.backward_kernel = ss_ac_neuron_kernel.IFNodeBPKernel( - surrogate_function=self.surrogate_function.cuda_codes, - hard_reset=hard_reset, - detach_reset=self.detach_reset, - dtype=dtype, - ) - - self.v_float_to_tensor(x) - - spike, v = ss_ac_neuron_kernel.IFNodeATGF.apply( - x.flatten(0), - self.v.flatten(0), - self.v_threshold, - self.v_reset, - self.forward_kernel, - self.backward_kernel, - ) - - spike = spike.reshape(x.shape) - v = v.reshape(x.shape) - - self.v = v - - return spike - else: - raise ValueError(self.backend) - - else: - self.v_float_to_tensor(x) - if self.v_reset is None: - spike, self.v = self.jit_eval_single_step_forward_soft_reset( - x, self.v, self.v_threshold - ) - else: - spike, self.v = self.jit_eval_single_step_forward_hard_reset( - x, self.v, self.v_threshold, self.v_reset - ) - return spike diff --git a/src/chop/nn/snn/modules/neuron/lifnode.py b/src/chop/nn/snn/modules/neuron/lifnode.py deleted file mode 100644 index 90b02289c..000000000 --- a/src/chop/nn/snn/modules/neuron/lifnode.py +++ /dev/null @@ -1,569 +0,0 @@ -from typing import Callable, Optional -import torch -import logging -from .. import surrogate - -from ...auto_cuda import neuron_kernel as ac_neuron_kernel -from ...auto_cuda import ss_neuron_kernel as ss_ac_neuron_kernel - -try: - from ... import neuron_kernel, cuda_utils - -except BaseException as e: - logging.info(f"spikingjelly.activation_based.neuron: {e}") - neuron_kernel = None - cuda_utils = None - -from .neuron import BaseNode, SimpleBaseNode - - -class SimpleLIFNode(SimpleBaseNode): - def __init__( - self, - tau: float, - decay_input: bool, - v_threshold: float = 1.0, - v_reset: float = 0.0, - surrogate_function: Callable = surrogate.Sigmoid(), - detach_reset: bool = False, - step_mode="s", - ): - super().__init__( - v_threshold, v_reset, surrogate_function, detach_reset, step_mode - ) - self.tau = tau - self.decay_input = decay_input - - def neuronal_charge(self, x: torch.Tensor): - if self.decay_input: - self.v = self.v + (self.v_reset - self.v + x) / self.tau - else: - self.v = self.v + (self.v_reset - self.v) / self.tau + x - - -class LIFNode(BaseNode): - def __init__( - self, - tau: float = 2.0, - decay_input: bool = True, - v_threshold: float = 1.0, - v_reset: Optional[float] = 0.0, - surrogate_function: Callable = surrogate.Sigmoid(), - detach_reset: bool = False, - step_mode="s", - backend="torch", - store_v_seq: bool = False, - ): - """ - * :ref:`API in English ` - - .. _LIFNode.__init__-en: - - :param tau: membrane time constant - :type tau: float - - :param decay_input: whether the input will decay - :type decay_input: bool - - :param v_threshold: threshold of this neurons layer - :type v_threshold: float - - :param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset`` - after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike - :type v_reset: Optional[float] - - :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward - :type surrogate_function: Callable - - :param detach_reset: whether detach the computation graph of reset in backward - :type detach_reset: bool - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - :param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can - print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported, - using ``'cupy'`` backend will have the fastest training speed - :type backend: str - - :param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls - whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``, - only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the - memory consumption - :type store_v_seq: bool - - The Leaky Integrate-and-Fire neuron, which can be seen as a leaky integrator. - The subthreshold neural dynamics of it is as followed: - - IF ``decay_input == True``: - - .. math:: - H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset})) - - IF ``decay_input == False``: - - .. math:: - H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t] - - """ - assert isinstance(tau, float) and tau > 1.0 - - super().__init__( - v_threshold, - v_reset, - surrogate_function, - detach_reset, - step_mode, - backend, - store_v_seq, - ) - - self.tau = tau - self.decay_input = decay_input - - @property - def supported_backends(self): - if self.step_mode == "s": - return ("torch", "cupy") - elif self.step_mode == "m": - return ("torch", "cupy") - else: - raise ValueError(self.step_mode) - - def extra_repr(self): - return super().extra_repr() + f", tau={self.tau}" - - def neuronal_charge(self, x: torch.Tensor): - if self.decay_input: - if self.v_reset is None or self.v_reset == 0.0: - self.v = self.neuronal_charge_decay_input_reset0(x, self.v, self.tau) - else: - self.v = self.neuronal_charge_decay_input( - x, self.v, self.v_reset, self.tau - ) - - else: - if self.v_reset is None or self.v_reset == 0.0: - self.v = self.neuronal_charge_no_decay_input_reset0(x, self.v, self.tau) - else: - self.v = self.neuronal_charge_no_decay_input( - x, self.v, self.v_reset, self.tau - ) - - @staticmethod - @torch.jit.script - def neuronal_charge_decay_input_reset0( - x: torch.Tensor, v: torch.Tensor, tau: float - ): - v = v + (x - v) / tau - return v - - @staticmethod - @torch.jit.script - def neuronal_charge_decay_input( - x: torch.Tensor, v: torch.Tensor, v_reset: float, tau: float - ): - v = v + (x - (v - v_reset)) / tau - return v - - @staticmethod - @torch.jit.script - def neuronal_charge_no_decay_input_reset0( - x: torch.Tensor, v: torch.Tensor, tau: float - ): - v = v * (1.0 - 1.0 / tau) + x - return v - - @staticmethod - @torch.jit.script - def neuronal_charge_no_decay_input( - x: torch.Tensor, v: torch.Tensor, v_reset: float, tau: float - ): - v = v - (v - v_reset) / tau + x - return v - - @staticmethod - @torch.jit.script - def jit_eval_single_step_forward_hard_reset_decay_input( - x: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float, tau: float - ): - v = v + (x - (v - v_reset)) / tau - spike = (v >= v_threshold).to(x) - v = v_reset * spike + (1.0 - spike) * v - return spike, v - - @staticmethod - @torch.jit.script - def jit_eval_single_step_forward_hard_reset_no_decay_input( - x: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float, tau: float - ): - v = v - (v - v_reset) / tau + x - spike = (v >= v_threshold).to(x) - v = v_reset * spike + (1.0 - spike) * v - return spike, v - - @staticmethod - @torch.jit.script - def jit_eval_single_step_forward_soft_reset_decay_input( - x: torch.Tensor, v: torch.Tensor, v_threshold: float, tau: float - ): - v = v + (x - v) / tau - spike = (v >= v_threshold).to(x) - v = v - spike * v_threshold - return spike, v - - @staticmethod - @torch.jit.script - def jit_eval_single_step_forward_soft_reset_no_decay_input( - x: torch.Tensor, v: torch.Tensor, v_threshold: float, tau: float - ): - v = v * (1.0 - 1.0 / tau) + x - spike = (v >= v_threshold).to(x) - v = v - spike * v_threshold - return spike, v - - @staticmethod - @torch.jit.script - def jit_eval_multi_step_forward_hard_reset_decay_input( - x_seq: torch.Tensor, - v: torch.Tensor, - v_threshold: float, - v_reset: float, - tau: float, - ): - spike_seq = torch.zeros_like(x_seq) - for t in range(x_seq.shape[0]): - v = v + (x_seq[t] - (v - v_reset)) / tau - spike = (v >= v_threshold).to(x_seq) - v = v_reset * spike + (1.0 - spike) * v - spike_seq[t] = spike - return spike_seq, v - - @staticmethod - @torch.jit.script - def jit_eval_multi_step_forward_hard_reset_decay_input_with_v_seq( - x_seq: torch.Tensor, - v: torch.Tensor, - v_threshold: float, - v_reset: float, - tau: float, - ): - spike_seq = torch.zeros_like(x_seq) - v_seq = torch.zeros_like(x_seq) - for t in range(x_seq.shape[0]): - v = v + (x_seq[t] - (v - v_reset)) / tau - spike = (v >= v_threshold).to(x_seq) - v = v_reset * spike + (1.0 - spike) * v - spike_seq[t] = spike - v_seq[t] = v - return spike_seq, v, v_seq - - @staticmethod - @torch.jit.script - def jit_eval_multi_step_forward_hard_reset_no_decay_input( - x_seq: torch.Tensor, - v: torch.Tensor, - v_threshold: float, - v_reset: float, - tau: float, - ): - spike_seq = torch.zeros_like(x_seq) - for t in range(x_seq.shape[0]): - v = v - (v - v_reset) / tau + x_seq[t] - spike = (v >= v_threshold).to(x_seq) - v = v_reset * spike + (1.0 - spike) * v - spike_seq[t] = spike - return spike_seq, v - - @staticmethod - @torch.jit.script - def jit_eval_multi_step_forward_hard_reset_no_decay_input_with_v_seq( - x_seq: torch.Tensor, - v: torch.Tensor, - v_threshold: float, - v_reset: float, - tau: float, - ): - spike_seq = torch.zeros_like(x_seq) - v_seq = torch.zeros_like(x_seq) - for t in range(x_seq.shape[0]): - v = v - (v - v_reset) / tau + x_seq[t] - spike = (v >= v_threshold).to(x_seq) - v = v_reset * spike + (1.0 - spike) * v - spike_seq[t] = spike - v_seq[t] = v - return spike_seq, v, v_seq - - @staticmethod - @torch.jit.script - def jit_eval_multi_step_forward_soft_reset_decay_input( - x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, tau: float - ): - spike_seq = torch.zeros_like(x_seq) - for t in range(x_seq.shape[0]): - v = v + (x_seq[t] - v) / tau - spike = (v >= v_threshold).to(x_seq) - v = v - spike * v_threshold - spike_seq[t] = spike - return spike_seq, v - - @staticmethod - @torch.jit.script - def jit_eval_multi_step_forward_soft_reset_decay_input_with_v_seq( - x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, tau: float - ): - spike_seq = torch.zeros_like(x_seq) - v_seq = torch.zeros_like(x_seq) - for t in range(x_seq.shape[0]): - v = v + (x_seq[t] - v) / tau - spike = (v >= v_threshold).to(x_seq) - v = v - spike * v_threshold - spike_seq[t] = spike - v_seq[t] = v - return spike_seq, v, v_seq - - @staticmethod - @torch.jit.script - def jit_eval_multi_step_forward_soft_reset_no_decay_input( - x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, tau: float - ): - spike_seq = torch.zeros_like(x_seq) - for t in range(x_seq.shape[0]): - v = v * (1.0 - 1.0 / tau) + x_seq[t] - spike = (v >= v_threshold).to(x_seq) - v = v - spike * v_threshold - spike_seq[t] = spike - return spike_seq, v - - @staticmethod - @torch.jit.script - def jit_eval_multi_step_forward_soft_reset_no_decay_input_with_v_seq( - x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, tau: float - ): - spike_seq = torch.zeros_like(x_seq) - v_seq = torch.zeros_like(x_seq) - for t in range(x_seq.shape[0]): - v = v * (1.0 - 1.0 / tau) + x_seq[t] - spike = (v >= v_threshold).to(x_seq) - v = v - spike * v_threshold - spike_seq[t] = spike - v_seq[t] = v - return spike_seq, v, v_seq - - def single_step_forward(self, x: torch.Tensor): - if self.training: - if self.backend == "torch": - return super().single_step_forward(x) - elif self.backend == "cupy": - hard_reset = self.v_reset is not None - - if x.dtype == torch.float: - dtype = "float" - elif x.dtype == torch.half: - dtype = "half2" - else: - raise NotImplementedError(x.dtype) - - if ( - self.forward_kernel is None - or not self.forward_kernel.check_attributes( - hard_reset=hard_reset, dtype=dtype, decay_input=self.decay_input - ) - ): - self.forward_kernel = ss_ac_neuron_kernel.LIFNodeFPKernel( - decay_input=self.decay_input, hard_reset=hard_reset, dtype=dtype - ) - - if ( - self.backward_kernel is None - or not self.backward_kernel.check_attributes( - surrogate_function=self.surrogate_function.cuda_codes, - hard_reset=hard_reset, - detach_reset=self.detach_reset, - dtype=dtype, - decay_input=self.decay_input, - ) - ): - self.backward_kernel = ss_ac_neuron_kernel.LIFNodeBPKernel( - decay_input=self.decay_input, - surrogate_function=self.surrogate_function.cuda_codes, - hard_reset=hard_reset, - detach_reset=self.detach_reset, - dtype=dtype, - ) - - self.v_float_to_tensor(x) - - spike, v = ss_ac_neuron_kernel.LIFNodeATGF.apply( - x.flatten(0), - self.v.flatten(0), - self.v_threshold, - self.v_reset, - 1.0 / self.tau, - self.forward_kernel, - self.backward_kernel, - ) - - spike = spike.reshape(x.shape) - v = v.reshape(x.shape) - - self.v = v - - return spike - else: - raise ValueError(self.backend) - - else: - self.v_float_to_tensor(x) - if self.v_reset is None: - if self.decay_input: - spike, self.v = ( - self.jit_eval_single_step_forward_soft_reset_decay_input( - x, self.v, self.v_threshold, self.tau - ) - ) - else: - spike, self.v = ( - self.jit_eval_single_step_forward_soft_reset_no_decay_input( - x, self.v, self.v_threshold, self.tau - ) - ) - else: - if self.decay_input: - spike, self.v = ( - self.jit_eval_single_step_forward_hard_reset_decay_input( - x, self.v, self.v_threshold, self.v_reset, self.tau - ) - ) - else: - spike, self.v = ( - self.jit_eval_single_step_forward_hard_reset_no_decay_input( - x, self.v, self.v_threshold, self.v_reset, self.tau - ) - ) - return spike - - def multi_step_forward(self, x_seq: torch.Tensor): - if self.training: - if self.backend == "torch": - return super().multi_step_forward(x_seq) - elif self.backend == "cupy": - - hard_reset = self.v_reset is not None - if x_seq.dtype == torch.float: - dtype = "float" - elif x_seq.dtype == torch.half: - dtype = "half2" - else: - raise NotImplementedError(x_seq.dtype) - - if ( - self.forward_kernel is None - or not self.forward_kernel.check_attributes( - hard_reset=hard_reset, dtype=dtype, decay_input=self.decay_input - ) - ): - self.forward_kernel = ac_neuron_kernel.LIFNodeFPTTKernel( - decay_input=self.decay_input, hard_reset=hard_reset, dtype=dtype - ) - - if ( - self.backward_kernel is None - or not self.backward_kernel.check_attributes( - surrogate_function=self.surrogate_function.cuda_codes, - hard_reset=hard_reset, - detach_reset=self.detach_reset, - dtype=dtype, - decay_input=self.decay_input, - ) - ): - self.backward_kernel = ac_neuron_kernel.LIFNodeBPTTKernel( - decay_input=self.decay_input, - surrogate_function=self.surrogate_function.cuda_codes, - hard_reset=hard_reset, - detach_reset=self.detach_reset, - dtype=dtype, - ) - - self.v_float_to_tensor(x_seq[0]) - - spike_seq, v_seq = ac_neuron_kernel.LIFNodeATGF.apply( - x_seq.flatten(1), - self.v.flatten(0), - self.v_threshold, - self.v_reset, - 1.0 / self.tau, - self.forward_kernel, - self.backward_kernel, - ) - - spike_seq = spike_seq.reshape(x_seq.shape) - v_seq = v_seq.reshape(x_seq.shape) - - if self.store_v_seq: - self.v_seq = v_seq - - self.v = v_seq[-1].clone() - - return spike_seq - else: - raise ValueError(self.backend) - - else: - self.v_float_to_tensor(x_seq[0]) - if self.v_reset is None: - if self.decay_input: - if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_soft_reset_decay_input_with_v_seq( - x_seq, self.v, self.v_threshold, self.tau - ) - ) - else: - spike_seq, self.v = ( - self.jit_eval_multi_step_forward_soft_reset_decay_input( - x_seq, self.v, self.v_threshold, self.tau - ) - ) - else: - if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_soft_reset_no_decay_input_with_v_seq( - x_seq, self.v, self.v_threshold, self.tau - ) - ) - else: - spike_seq, self.v = ( - self.jit_eval_multi_step_forward_soft_reset_no_decay_input( - x_seq, self.v, self.v_threshold, self.tau - ) - ) - else: - if self.decay_input: - if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_hard_reset_decay_input_with_v_seq( - x_seq, self.v, self.v_threshold, self.v_reset, self.tau - ) - ) - else: - spike_seq, self.v = ( - self.jit_eval_multi_step_forward_hard_reset_decay_input( - x_seq, self.v, self.v_threshold, self.v_reset, self.tau - ) - ) - else: - if self.store_v_seq: - spike_seq, self.v, self.v_seq = ( - self.jit_eval_multi_step_forward_hard_reset_no_decay_input_with_v_seq( - x_seq, self.v, self.v_threshold, self.v_reset, self.tau - ) - ) - else: - spike_seq, self.v = ( - self.jit_eval_multi_step_forward_hard_reset_no_decay_input( - x_seq, self.v, self.v_threshold, self.v_reset, self.tau - ) - ) - - return spike_seq diff --git a/src/chop/nn/snn/modules/neuron/neuron.py b/src/chop/nn/snn/modules/neuron/neuron.py deleted file mode 100644 index 95d7c7a86..000000000 --- a/src/chop/nn/snn/modules/neuron/neuron.py +++ /dev/null @@ -1,265 +0,0 @@ -# *************************************************************************************** -# * Title: neuron.py -# * Reference: This file is adapted from spikingJelly -# * Availability: https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/timing_based/neuron.py -# * Date: 07/11/2024 -# * Code version: 0.0.0.014 -# *************************************************************************************** - -from abc import abstractmethod -from typing import Callable, Optional -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -import numpy as np -import logging - -from ... import base -from .. import surrogate - -from ...auto_cuda import neuron_kernel as ac_neuron_kernel -from ...auto_cuda import ss_neuron_kernel as ss_ac_neuron_kernel - -try: - from ... import neuron_kernel, cuda_utils - -except BaseException as e: - logging.info(f"spikingjelly.activation_based.neuron: {e}") - neuron_kernel = None - cuda_utils = None - - -class SimpleBaseNode(base.MemoryModule): - def __init__( - self, - v_threshold: float = 1.0, - v_reset: Optional[float] = 0.0, - surrogate_function: Callable = surrogate.Sigmoid(), - detach_reset: bool = False, - step_mode="s", - ): - """ - A simple version of ``BaseNode``. The user can modify this neuron easily. - """ - super().__init__() - self.v_threshold = v_threshold - self.v_reset = v_reset - self.surrogate_function = surrogate_function - self.detach_reset = detach_reset - self.step_mode = step_mode - self.register_memory(name="v", value=0.0) - - def single_step_forward(self, x: torch.Tensor): - - self.neuronal_charge(x) - spike = self.neuronal_fire() - self.neuronal_reset(spike) - return spike - - def neuronal_charge(self, x: torch.Tensor): - raise NotImplementedError - - def neuronal_fire(self): - return self.surrogate_function(self.v - self.v_threshold) - - def neuronal_reset(self, spike): - if self.detach_reset: - spike_d = spike.detach() - else: - spike_d = spike - - if self.v_reset is None: - # soft reset - self.v = self.v - self.v_threshold * spike_d - - else: - # hard reset - self.v = spike_d * self.v_reset + (1.0 - spike_d) * self.v - - -class BaseNode(base.MemoryModule): - def __init__( - self, - v_threshold: float = 1.0, - v_reset: Optional[float] = 0.0, - surrogate_function: Callable = surrogate.Sigmoid(), - detach_reset: bool = False, - step_mode="s", - backend="torch", - store_v_seq: bool = False, - ): - """ - * :ref:`API in English ` - .. _BaseNode.__init__-en: - - :param v_threshold: threshold of this neurons layer - :type v_threshold: float - - :param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset`` - after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike - :type v_reset: Optional[float] - - :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward - :type surrogate_function: Callable - - :param detach_reset: whether detach the computation graph of reset in backward - :type detach_reset: bool - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - :param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can - print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported, - using ``'cupy'`` backend will have the fastest training speed - :type backend: str - - :param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls - whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``, - only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the - memory consumption - :type store_v_seq: bool - - This class is the base class of differentiable spiking neurons. - """ - assert isinstance(v_reset, float) or v_reset is None - assert isinstance(v_threshold, float) - assert isinstance(detach_reset, bool) - super().__init__() - - if v_reset is None: - self.register_memory("v", 0.0) - else: - self.register_memory("v", v_reset) - - self.v_threshold = v_threshold - self.v_reset = v_reset - - self.detach_reset = detach_reset - self.surrogate_function = surrogate_function - - self.step_mode = step_mode - self.backend = backend - - self.store_v_seq = store_v_seq - - # used in lava_exchange - self.lava_s_cale = 1 << 6 - - # used for cupy backend - self.forward_kernel = None - self.backward_kernel = None - - @property - def store_v_seq(self): - return self._store_v_seq - - @store_v_seq.setter - def store_v_seq(self, value: bool): - self._store_v_seq = value - if value: - if not hasattr(self, "v_seq"): - self.register_memory("v_seq", None) - - @staticmethod - @torch.jit.script - def jit_hard_reset(v: torch.Tensor, spike: torch.Tensor, v_reset: float): - v = (1.0 - spike) * v + spike * v_reset - return v - - @staticmethod - @torch.jit.script - def jit_soft_reset(v: torch.Tensor, spike: torch.Tensor, v_threshold: float): - v = v - spike * v_threshold - return v - - @abstractmethod - def neuronal_charge(self, x: torch.Tensor): - """ - * :ref:`API in English ` - .. _BaseNode.neuronal_charge-en: - - - Define the charge difference equation. The sub-class must implement this function. - """ - raise NotImplementedError - - def neuronal_fire(self): - """ - * :ref:`API in English ` - - .. _BaseNode.neuronal_fire-en: - - - Calculate out spikes of neurons by their current membrane potential and threshold voltage. - """ - - return self.surrogate_function(self.v - self.v_threshold) - - def neuronal_reset(self, spike): - """ - * :ref:`API in English ` - - .. _BaseNode.neuronal_reset-en: - - - Reset the membrane potential according to neurons' output spikes. - """ - if self.detach_reset: - spike_d = spike.detach() - else: - spike_d = spike - - if self.v_reset is None: - # soft reset - self.v = self.jit_soft_reset(self.v, spike_d, self.v_threshold) - - else: - # hard reset - self.v = self.jit_hard_reset(self.v, spike_d, self.v_reset) - - def extra_repr(self): - return f"v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, step_mode={self.step_mode}, backend={self.backend}" - - def single_step_forward(self, x: torch.Tensor): - """ - - * :ref:`API in English ` - - .. _BaseNode.single_step_forward-en: - - :param x: increment of voltage inputted to neurons - :type x: torch.Tensor - - :return: out spikes of neurons - :rtype: torch.Tensor - - Forward by the order of `neuronal_charge`, `neuronal_fire`, and `neuronal_reset`. - - """ - self.v_float_to_tensor(x) - self.neuronal_charge(x) - spike = self.neuronal_fire() - self.neuronal_reset(spike) - return spike - - def multi_step_forward(self, x_seq: torch.Tensor): - T = x_seq.shape[0] - y_seq = [] - if self.store_v_seq: - v_seq = [] - for t in range(T): - y = self.single_step_forward(x_seq[t]) - y_seq.append(y) - if self.store_v_seq: - v_seq.append(self.v) - - if self.store_v_seq: - self.v_seq = torch.stack(v_seq) - - return torch.stack(y_seq) - - def v_float_to_tensor(self, x: torch.Tensor): - if isinstance(self.v, float): - v_init = self.v - self.v = torch.full_like(x.data, v_init) diff --git a/src/chop/nn/snn/modules/neuron/parametriclifnode.py b/src/chop/nn/snn/modules/neuron/parametriclifnode.py deleted file mode 100644 index bf2ac2c30..000000000 --- a/src/chop/nn/snn/modules/neuron/parametriclifnode.py +++ /dev/null @@ -1,192 +0,0 @@ -from typing import Callable, Optional -import torch -import torch.nn as nn -import logging -import math -from .. import surrogate - -from ...auto_cuda import neuron_kernel as ac_neuron_kernel -from ...auto_cuda import ss_neuron_kernel as ss_ac_neuron_kernel - -try: - from ... import neuron_kernel, cuda_utils - -except BaseException as e: - logging.info(f"spikingjelly.activation_based.neuron: {e}") - neuron_kernel = None - cuda_utils = None - -from .neuron import BaseNode, SimpleBaseNode - - -class ParametricLIFNode(BaseNode): - def __init__( - self, - init_tau: float = 2.0, - decay_input: bool = True, - v_threshold: float = 1.0, - v_reset: float = 0.0, - surrogate_function: Callable = surrogate.Sigmoid(), - detach_reset: bool = False, - step_mode="s", - backend="torch", - store_v_seq: bool = False, - ): - """ - * :ref:`API in English ` - - .. _ParametricLIFNode.__init__-en: - - :param init_tau: the initial value of membrane time constant - :type init_tau: float - - :param decay_input: whether the input will decay - :type decay_input: bool - - :param v_threshold: threshold of this neurons layer - :type v_threshold: float - - :param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset`` - after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike - :type v_reset: float - - :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward - :type surrogate_function: Callable - - :param detach_reset: whether detach the computation graph of reset in backward - :type detach_reset: bool - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - :param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can - print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported, - using ``'cupy'`` backend will have the fastest training speed - :type backend: str - - :param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls - whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``, - only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the - memory consumption - :type store_v_seq: bool - - :param cupy_fp32_inference: If `True`, if this module is in `eval` mode, using float32, running on GPU, and `cupy` is installed, then this - module will use `cupy` to accelerate. This option has priority over ``backend`` - :type cupy_fp32_inference: bool - - The Parametric Leaky Integrate-and-Fire (PLIF) neuron, which is proposed by `Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks `_ and can be seen as a leaky integrator. - The subthreshold neural dynamics of it is as followed: - - IF ``decay_input == True``: - - .. math:: - H = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset})) - - IF ``decay_input == False``: - - .. math:: - H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t] - - where :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)`, :math:`w` is a learnable parameter. - """ - - assert isinstance(init_tau, float) and init_tau > 1.0 - super().__init__( - v_threshold, - v_reset, - surrogate_function, - detach_reset, - step_mode, - backend, - store_v_seq, - ) - self.decay_input = decay_input - init_w = -math.log(init_tau - 1.0) - self.w = nn.Parameter(torch.as_tensor(init_w)) - - @property - def supported_backends(self): - if self.step_mode == "s": - return ("torch",) - elif self.step_mode == "m": - return ("torch", "cupy") - else: - raise ValueError(self.step_mode) - - def extra_repr(self): - with torch.no_grad(): - tau = 1.0 / self.w.sigmoid() - return super().extra_repr() + f", tau={tau}" - - def neuronal_charge(self, x: torch.Tensor): - if self.decay_input: - if self.v_reset is None or self.v_reset == 0.0: - self.v = self.v + (x - self.v) * self.w.sigmoid() - else: - self.v = self.v + (x - (self.v - self.v_reset)) * self.w.sigmoid() - else: - if self.v_reset is None or self.v_reset == 0.0: - self.v = self.v * (1.0 - self.w.sigmoid()) + x - else: - self.v = self.v - (self.v - self.v_reset) * self.w.sigmoid() + x - - def multi_step_forward(self, x_seq: torch.Tensor): - if self.backend == "torch": - return super().multi_step_forward(x_seq) - elif self.backend == "cupy": - hard_reset = self.v_reset is not None - if x_seq.dtype == torch.float: - dtype = "float" - elif x_seq.dtype == torch.half: - dtype = "half2" - else: - raise NotImplementedError(x_seq.dtype) - - if self.forward_kernel is None or not self.forward_kernel.check_attributes( - hard_reset=hard_reset, dtype=dtype, decay_input=self.decay_input - ): - self.forward_kernel = ac_neuron_kernel.ParametricLIFNodeFPTTKernel( - decay_input=self.decay_input, hard_reset=hard_reset, dtype=dtype - ) - - if ( - self.backward_kernel is None - or not self.backward_kernel.check_attributes( - surrogate_function=self.surrogate_function.cuda_codes, - hard_reset=hard_reset, - detach_reset=self.detach_reset, - dtype=dtype, - decay_input=self.decay_input, - ) - ): - self.backward_kernel = ac_neuron_kernel.ParametricLIFNodeBPTTKernel( - decay_input=self.decay_input, - surrogate_function=self.surrogate_function.cuda_codes, - hard_reset=hard_reset, - detach_reset=self.detach_reset, - dtype=dtype, - ) - - self.v_float_to_tensor(x_seq[0]) - - spike_seq, v_seq = ac_neuron_kernel.ParametricLIFNodeATGF.apply( - x_seq.flatten(1), - self.v.flatten(0), - self.v_threshold, - self.v_reset, - self.w.sigmoid().to(x_seq), - self.forward_kernel, - self.backward_kernel, - ) - - spike_seq = spike_seq.reshape(x_seq.shape) - v_seq = v_seq.reshape(x_seq.shape) - - if self.store_v_seq: - self.v_seq = v_seq - - self.v = v_seq[-1].clone() - - return spike_seq - else: - raise ValueError(self.backend) diff --git a/src/chop/nn/snn/modules/neuron/st_bifnode.py b/src/chop/nn/snn/modules/neuron/st_bifnode.py deleted file mode 100644 index cbef9162b..000000000 --- a/src/chop/nn/snn/modules/neuron/st_bifnode.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.jit import Final -import math -from copy import deepcopy -from typing import List, Optional, Tuple, Union -import math - - -# TODO: this need to be change to neurons model later if we want to support training -class ST_BIFNode(nn.Module): - def __init__(self, q_threshold=torch.tensor(1), level=32, sym=False): - super(ST_BIFNode, self).__init__() - self.q = 0.0 - self.acc_q = 0.0 - self.q_threshold = ( - torch.tensor(q_threshold) - if not torch.is_tensor(q_threshold) - else q_threshold - ) - self.is_work = False - self.cur_output = 0.0 - self.level = torch.tensor(level) - self.sym = sym - if sym: - self.pos_max = torch.tensor(level // 2 - 1) - self.neg_min = torch.tensor(-level // 2) - else: - self.pos_max = torch.tensor(level - 1) - self.neg_min = torch.tensor(0) - - self.eps = 0 - - def __repr__(self): - return f"ST_BIFNode(level={self.level}, sym={self.sym}, pos_max={self.pos_max}, neg_min={self.neg_min}, q_threshold={self.q_threshold})" - - def reset(self): - self.q = 0.0 - self.cur_output = 0.0 - self.acc_q = 0.0 - # I believe this is some shot of early stopping machanism - self.is_work = False - self.spike_position = None - self.neg_spike_position = None - - def forward(self, input): - x = input / self.q_threshold - if ( - (not torch.is_tensor(x)) - and x == 0.0 - and (not torch.is_tensor(self.cur_output)) - and self.cur_output == 0.0 - ): - self.is_work = False - return x - - if not torch.is_tensor(self.cur_output): - self.cur_output = torch.zeros(x.shape, dtype=x.dtype).to(x.device) - self.acc_q = torch.zeros(x.shape).to(x.device) - self.q = torch.zeros(x.shape).to(x.device) + 0.5 - - self.is_work = True - - self.q = self.q + (x.detach() if torch.is_tensor(x) else x) - self.acc_q = torch.round(self.acc_q) - - spike_position = (self.q - 1 >= 0) & (self.acc_q < self.pos_max) - neg_spike_position = (self.q < -self.eps) & (self.acc_q > self.neg_min) - - self.cur_output[:] = 0 - self.cur_output[spike_position] = 1 - self.cur_output[neg_spike_position] = -1 - - self.acc_q = self.acc_q + self.cur_output - self.q[spike_position] = self.q[spike_position] - 1 - self.q[neg_spike_position] = self.q[neg_spike_position] + 1 - - if (x == 0).all() and (self.cur_output == 0).all(): - self.is_work = False - - return self.cur_output * self.q_threshold diff --git a/src/chop/nn/snn/modules/pool1d.py b/src/chop/nn/snn/modules/pool1d.py deleted file mode 100644 index 18652692c..000000000 --- a/src/chop/nn/snn/modules/pool1d.py +++ /dev/null @@ -1,158 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from torch import Tensor -from torch.nn.common_types import ( - _size_any_t, - _size_1_t, - _size_2_t, - _size_3_t, - _ratio_any_t, -) -from typing import Optional, List, Tuple, Union -from typing import Callable -import chop.nn.snn.base as base -import chop.nn.snn.functional as functional - - -class MaxPool1d(nn.MaxPool1d, base.StepModule): - def __init__( - self, - kernel_size: _size_1_t, - stride: Optional[_size_1_t] = None, - padding: _size_1_t = 0, - dilation: _size_1_t = 1, - return_indices: bool = False, - ceil_mode: bool = False, - step_mode="s", - ) -> None: - """ - * :ref:`API in English ` - - .. _MaxPool1d-cn: - - :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) - :type step_mode: str - - 其他的参数API参见 :class:`torch.nn.MaxPool1d` - - * :ref:`中文 API ` - - .. _MaxPool1d-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.MaxPool1d` for other parameters' API - """ - super().__init__( - kernel_size, stride, padding, dilation, return_indices, ceil_mode - ) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - x = super().forward(x) - - elif self.step_mode == "m": - if x.dim() != 4: - raise ValueError( - f"expected x with shape [T, N, C, L], but got x with shape {x.shape}!" - ) - x = functional.seq_to_ann_forward(x, super().forward) - - return x - - -class AvgPool1d(nn.AvgPool1d, base.StepModule): - def __init__( - self, - kernel_size: _size_1_t, - stride: _size_1_t = None, - padding: _size_1_t = 0, - ceil_mode: bool = False, - count_include_pad: bool = True, - step_mode="s", - ) -> None: - """ - * :ref:`API in English ` - - .. _AvgPool1d-cn: - - :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) - :type step_mode: str - - 其他的参数API参见 :class:`torch.nn.AvgPool1d` - - * :ref:`中文 API ` - - .. _AvgPool1d-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.AvgPool1d` for other parameters' API - """ - super().__init__(kernel_size, stride, padding, ceil_mode, count_include_pad) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - x = super().forward(x) - - elif self.step_mode == "m": - if x.dim() != 4: - raise ValueError( - f"expected x with shape [T, N, C, L], but got x with shape {x.shape}!" - ) - x = functional.seq_to_ann_forward(x, super().forward) - - return x - - -class AdaptiveAvgPool1d(nn.AdaptiveAvgPool1d, base.StepModule): - def __init__(self, output_size, step_mode="s") -> None: - """ - * :ref:`API in English ` - - .. _AdaptiveAvgPool1d-cn: - - :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) - :type step_mode: str - - 其他的参数API参见 :class:`torch.nn.AdaptiveAvgPool1d` - - * :ref:`中文 API ` - - .. _AdaptiveAvgPool1d-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.AdaptiveAvgPool1d` for other parameters' API - """ - super().__init__(output_size) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - x = super().forward(x) - - elif self.step_mode == "m": - if x.dim() != 4: - raise ValueError( - f"expected x with shape [T, N, C, L], but got x with shape {x.shape}!" - ) - x = functional.seq_to_ann_forward(x, super().forward) - - return x diff --git a/src/chop/nn/snn/modules/pool2d.py b/src/chop/nn/snn/modules/pool2d.py deleted file mode 100644 index 11f9c6ac6..000000000 --- a/src/chop/nn/snn/modules/pool2d.py +++ /dev/null @@ -1,161 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from torch import Tensor -from torch.nn.common_types import ( - _size_any_t, - _size_1_t, - _size_2_t, - _size_3_t, - _ratio_any_t, -) -from typing import Optional, List, Tuple, Union -from typing import Callable -import chop.nn.snn.base as base -import chop.nn.snn.functional as functional - - -class MaxPool2d(nn.MaxPool2d, base.StepModule): - def __init__( - self, - kernel_size: _size_2_t, - stride: Optional[_size_2_t] = None, - padding: _size_2_t = 0, - dilation: _size_2_t = 1, - return_indices: bool = False, - ceil_mode: bool = False, - step_mode="s", - ) -> None: - """ - * :ref:`API in English ` - - .. _MaxPool2d-cn: - - :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) - :type step_mode: str - - 其他的参数API参见 :class:`torch.nn.MaxPool2d` - - * :ref:`中文 API ` - - .. _MaxPool2d-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.MaxPool2d` for other parameters' API - """ - super().__init__( - kernel_size, stride, padding, dilation, return_indices, ceil_mode - ) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - x = super().forward(x) - - elif self.step_mode == "m": - if x.dim() != 5: - raise ValueError( - f"expected x with shape [T, N, C, H, W], but got x with shape {x.shape}!" - ) - x = functional.seq_to_ann_forward(x, super().forward) - - return x - - -class AvgPool2d(nn.AvgPool2d, base.StepModule): - def __init__( - self, - kernel_size: _size_2_t, - stride: Optional[_size_2_t] = None, - padding: _size_2_t = 0, - ceil_mode: bool = False, - count_include_pad: bool = True, - divisor_override: Optional[int] = None, - step_mode="s", - ) -> None: - """ - * :ref:`API in English ` - - .. _AvgPool2d-cn: - - :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) - :type step_mode: str - - 其他的参数API参见 :class:`torch.nn.AvgPool2d` - - * :ref:`中文 API ` - - .. _AvgPool2d-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.AvgPool2d` for other parameters' API - """ - super().__init__( - kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override - ) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - x = super().forward(x) - - elif self.step_mode == "m": - if x.dim() != 5: - raise ValueError( - f"expected x with shape [T, N, C, H, W], but got x with shape {x.shape}!" - ) - x = functional.seq_to_ann_forward(x, super().forward) - - return x - - -class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, base.StepModule): - def __init__(self, output_size, step_mode="s") -> None: - """ - * :ref:`API in English ` - - .. _AdaptiveAvgPool2d-cn: - - :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) - :type step_mode: str - - 其他的参数API参见 :class:`torch.nn.AdaptiveAvgPool2d` - - * :ref:`中文 API ` - - .. _AdaptiveAvgPool2d-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.AdaptiveAvgPool2d` for other parameters' API - """ - super().__init__(output_size) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - x = super().forward(x) - - elif self.step_mode == "m": - if x.dim() != 5: - raise ValueError( - f"expected x with shape [T, N, C, H, W], but got x with shape {x.shape}!" - ) - x = functional.seq_to_ann_forward(x, super().forward) - - return x diff --git a/src/chop/nn/snn/modules/pool3d.py b/src/chop/nn/snn/modules/pool3d.py deleted file mode 100644 index 302feb2ff..000000000 --- a/src/chop/nn/snn/modules/pool3d.py +++ /dev/null @@ -1,161 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from torch import Tensor -from torch.nn.common_types import ( - _size_any_t, - _size_1_t, - _size_2_t, - _size_3_t, - _ratio_any_t, -) -from typing import Optional, List, Tuple, Union -from typing import Callable -import chop.nn.snn.base as base -import chop.nn.snn.functional as functional - - -class MaxPool3d(nn.MaxPool3d, base.StepModule): - def __init__( - self, - kernel_size: _size_3_t, - stride: Optional[_size_3_t] = None, - padding: _size_3_t = 0, - dilation: _size_3_t = 1, - return_indices: bool = False, - ceil_mode: bool = False, - step_mode="s", - ) -> None: - """ - * :ref:`API in English ` - - .. _MaxPool3d-cn: - - :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) - :type step_mode: str - - 其他的参数API参见 :class:`torch.nn.MaxPool3d` - - * :ref:`中文 API ` - - .. _MaxPool3d-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.MaxPool3d` for other parameters' API - """ - super().__init__( - kernel_size, stride, padding, dilation, return_indices, ceil_mode - ) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - x = super().forward(x) - - elif self.step_mode == "m": - if x.dim() != 6: - raise ValueError( - f"expected x with shape [T, N, C, D, H, W], but got x with shape {x.shape}!" - ) - x = functional.seq_to_ann_forward(x, super().forward) - - return x - - -class AvgPool3d(nn.AvgPool3d, base.StepModule): - def __init__( - self, - kernel_size: _size_3_t, - stride: Optional[_size_3_t] = None, - padding: _size_3_t = 0, - ceil_mode: bool = False, - count_include_pad: bool = True, - divisor_override: Optional[int] = None, - step_mode="s", - ) -> None: - """ - * :ref:`API in English ` - - .. _AvgPool3d-cn: - - :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) - :type step_mode: str - - 其他的参数API参见 :class:`torch.nn.AvgPool3d` - - * :ref:`中文 API ` - - .. _AvgPool3d-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.AvgPool3d` for other parameters' API - """ - super().__init__( - kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override - ) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - x = super().forward(x) - - elif self.step_mode == "m": - if x.dim() != 6: - raise ValueError( - f"expected x with shape [T, N, C, D, H, W], but got x with shape {x.shape}!" - ) - x = functional.seq_to_ann_forward(x, super().forward) - - return x - - -class AdaptiveAvgPool3d(nn.AdaptiveAvgPool3d, base.StepModule): - def __init__(self, output_size, step_mode="s") -> None: - """ - * :ref:`API in English ` - - .. _AdaptiveAvgPool3d-cn: - - :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) - :type step_mode: str - - 其他的参数API参见 :class:`torch.nn.AdaptiveAvgPool3d` - - * :ref:`中文 API ` - - .. _AdaptiveAvgPool3d-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.AdaptiveAvgPool3d` for other parameters' API - """ - super().__init__(output_size) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor): - if self.step_mode == "s": - x = super().forward(x) - - elif self.step_mode == "m": - if x.dim() != 6: - raise ValueError( - f"expected x with shape [T, N, C, D, H, W], but got x with shape {x.shape}!" - ) - x = functional.seq_to_ann_forward(x, super().forward) - - return x diff --git a/src/chop/nn/snn/modules/roberta/__init__.py b/src/chop/nn/snn/modules/roberta/__init__.py deleted file mode 100644 index cda86a74d..000000000 --- a/src/chop/nn/snn/modules/roberta/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .attention import RobertaSelfAttentionZIPTF diff --git a/src/chop/nn/snn/modules/roberta/attention.py b/src/chop/nn/snn/modules/roberta/attention.py deleted file mode 100644 index b8253d4a1..000000000 --- a/src/chop/nn/snn/modules/roberta/attention.py +++ /dev/null @@ -1,265 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.jit import Final -import math -from copy import deepcopy -from typing import List, Optional, Tuple, Union -import math - - -from chop.nn.snn.modules.linear import LinearUnfoldBias -from chop.nn.snn.modules.neuron import ST_BIFNode -from chop.nn.snn.modules.softmax import SoftmaxZIPTF - - -def multi(x1_t, x2_t, x1_sum_t, x2_sum_t): - """ - SpikeZip-TF multi - """ - return ( - x1_sum_t @ x2_t.transpose(-2, -1) - + x1_t @ x2_sum_t.transpose(-2, -1) - - x1_t @ x2_t.transpose(-2, -1) - ) - - -def multi1(x1_t, x2_t, x1_sum_t, x2_sum_t): - """ - SpikeZip-TF multi - """ - return x1_sum_t @ x2_t + x1_t @ x2_sum_t - x1_t @ x2_t - - -class RobertaSelfAttentionZIPTF(nn.Module): - """ - ST-Spike Transformer Self Attention Module - """ - - def __init__(self, config, q_config, position_embedding_type=None): - super().__init__() - if config.hidden_size % config.num_attention_heads != 0 and not hasattr( - config, "embedding_size" - ): - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads})" - ) - - self.level = q_config["level"] - self.neuron_type = q_config["neuron_type"] - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = LinearUnfoldBias( - config.hidden_size, - self.all_head_size, - level=q_config["level"], - neuron_type=q_config["neuron_type"], - ) - self.query_IF = ST_BIFNode(q_threshold=1.0, level=q_config["level"], sym=True) - self.key = LinearUnfoldBias( - config.hidden_size, - self.all_head_size, - level=q_config["level"], - neuron_type=q_config["neuron_type"], - ) - self.key_IF = ST_BIFNode(q_threshold=1.0, level=q_config["level"], sym=True) - self.value = LinearUnfoldBias( - config.hidden_size, - self.all_head_size, - level=q_config["level"], - neuron_type=q_config["neuron_type"], - ) - self.value_IF = ST_BIFNode(q_threshold=1.0, level=q_config["level"], sym=True) - self.attn_IF = ST_BIFNode(q_threshold=1.0, level=q_config["level"], sym=False) - self.after_attn_IF = ST_BIFNode( - q_threshold=1.0, level=q_config["level"], sym=False - ) - self.Ssoftmax = SoftmaxZIPTF() - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) - if ( - self.position_embedding_type == "relative_key" - or self.position_embedding_type == "relative_key_query" - ): - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding( - 2 * config.max_position_embeddings - 1, self.attention_head_size - ) - - self.is_decoder = config.is_decoder - - def reset(self): - # print("SAttention reset") - self.query_IF.reset() - self.key_IF.reset() - self.value_IF.reset() - self.attn_IF.reset() - self.after_attn_IF.reset() - self.Ssoftmax.reset() - self.query.reset() - self.key.reset() - self.value.reset() - - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads, - self.attention_head_size, - ) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor]: - mixed_query_layer = self.query_IF(self.query(hidden_states)) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores( - self.key_IF(self.key(encoder_hidden_states)) - ) - value_layer = self.transpose_for_scores( - self.value_IF(self.value(encoder_hidden_states)) - ) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key_IF(self.key(hidden_states))) - value_layer = self.transpose_for_scores( - self.value_IF(self.value(hidden_states)) - ) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key_IF(self.key(hidden_states))) - value_layer = self.transpose_for_scores( - self.value_IF(self.value(hidden_states)) - ) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - - attention_scores = multi( - query_layer, - key_layer, - self.transpose_for_scores(self.query_IF.acc_q * self.query_IF.q_threshold), - self.transpose_for_scores(self.key_IF.acc_q * self.key_IF.q_threshold), - ) - - # attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - if ( - self.position_embedding_type == "relative_key" - or self.position_embedding_type == "relative_key_query" - ): - query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: - position_ids_l = torch.tensor( - key_length - 1, dtype=torch.long, device=hidden_states.device - ).view(-1, 1) - else: - position_ids_l = torch.arange( - query_length, dtype=torch.long, device=hidden_states.device - ).view(-1, 1) - position_ids_r = torch.arange( - key_length, dtype=torch.long, device=hidden_states.device - ).view(1, -1) - distance = position_ids_l - position_ids_r - - positional_embedding = self.distance_embedding( - distance + self.max_position_embeddings - 1 - ) - positional_embedding = positional_embedding.to( - dtype=query_layer.dtype - ) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum( - "bhld,lrd->bhlr", query_layer, positional_embedding - ) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum( - "bhld,lrd->bhlr", query_layer, positional_embedding - ) - relative_position_scores_key = torch.einsum( - "bhrd,lrd->bhlr", key_layer, positional_embedding - ) - attention_scores = ( - attention_scores - + relative_position_scores_query - + relative_position_scores_key - ) - - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - # attention_probs = nn.functional.softmax(attention_scores, dim=-1) - attention_probs = self.Ssoftmax(attention_scores) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - attention_probs = self.attn_IF(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - # context_layer = torch.matmul(attention_probs, value_layer) - context_layer = multi1( - attention_probs, - value_layer, - (self.attn_IF.acc_q * self.attn_IF.q_threshold), - self.transpose_for_scores(self.value_IF.acc_q * self.value_IF.q_threshold), - ) - - context_layer = self.after_attn_IF(context_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = ( - (context_layer, attention_probs) if output_attentions else (context_layer,) - ) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs diff --git a/src/chop/nn/snn/modules/silu.py b/src/chop/nn/snn/modules/silu.py deleted file mode 100644 index d8b453f03..000000000 --- a/src/chop/nn/snn/modules/silu.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -from torch import nn - - -class SiLUZIPTF(nn.SiLU): - def __init__( - self, - normalized_shape, - eps: float = 1e-5, - elementwise_affine: bool = True, - bias: bool = True, - device=None, - dtype=None, - ) -> None: - super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype) - self.X = 0.0 - self.Y_pre = None - - def reset(self): - self.X = 0.0 - self.Y_pre = None - - def forward(self, input): - self.X = self.X + input - Y = super().forward(self.X) - if self.Y_pre is not None: - Y_pre = self.Y_pre.detach().clone() - else: - Y_pre = 0.0 - self.Y_pre = Y - return Y - Y_pre diff --git a/src/chop/nn/snn/modules/softmax.py b/src/chop/nn/snn/modules/softmax.py deleted file mode 100644 index 70ab7bf69..000000000 --- a/src/chop/nn/snn/modules/softmax.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.jit import Final -import math -from copy import deepcopy -from typing import List, Optional, Tuple, Union -import math - -import chop.nn.snn.base as base - - -class SoftmaxZIPTF(nn.Softmax, base.StepModule): - """ - Stateful Softmax function - Copied from SpikeZIP-TF - https://arxiv.org/pdf/2406.03470 - """ - - def __init__(self, dim=-1, step_mode="s") -> None: - super().__init__(dim=dim) - self.X = 0.0 - self.Y_pre = 0.0 - self.step_mode = step_mode - - def reset(self): - self.X = 0.0 - self.Y_pre = 0.0 - - def forward(self, input): - if self.step_mode == "s": - self.X = input + self.X - Y = super().forward(self.X) - Y_pre = deepcopy(self.Y_pre) - self.Y_pre = Y - return Y - Y_pre - - elif self.step_mode == "m": - T = input.shape[0] - y_seq = [] - for t in range(T): - self.X = input[t] + self.X - Y = super().forward(self.X) - Y_pre = deepcopy(self.Y_pre) - self.Y_pre = Y - y_seq.append(Y - Y_pre) - return torch.stack(y_seq, dim=0) diff --git a/src/chop/nn/snn/modules/spiking_self_attention.py b/src/chop/nn/snn/modules/spiking_self_attention.py deleted file mode 100644 index d9d2ae32e..000000000 --- a/src/chop/nn/snn/modules/spiking_self_attention.py +++ /dev/null @@ -1,261 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from torch import Tensor -from torch.nn.common_types import _size_any_t, _size_1_t, _size_2_t, _size_3_t -from typing import Optional, List, Tuple, Union -from typing import Callable -import chop.nn.snn.base as base -import chop.nn.snn.functional as functional - -from chop.nn.snn.modules import Conv2d, BatchNorm2d -import chop.nn.snn.modules.surrogate as surrogate -from chop.nn.snn.modules.neuron import LIFNode, ParametricLIFNode - -""" -This file contains the implementation of the Spiking Self-Attention module. Spikeformer -https://arxiv.org/abs/2403.14302 -""" - - -class Conv1x1(Conv2d): - def __init__( - self, - in_channels: int, - out_channels: int, - stride: _size_2_t = 1, - bias: bool = False, - ) -> None: - super().__init__( - in_channels, - out_channels, - kernel_size=1, - stride=stride, - padding=0, - dilation=1, - groups=1, - bias=bias, - padding_mode="zeros", - step_mode="m", - ) - - -class LIF(LIFNode): - def __init__(self): - super().__init__( - tau=2.0, - decay_input=True, - v_threshold=1.0, - v_reset=0.0, - surrogate_function=surrogate.ATan(), - detach_reset=True, - step_mode="m", - backend="cupy", - store_v_seq=False, - ) - - -class PLIF(ParametricLIFNode): - def __init__(self): - super().__init__( - init_tau=2.0, - decay_input=True, - v_threshold=1.0, - v_reset=0.0, - surrogate_function=surrogate.ATan(), - detach_reset=True, - step_mode="m", - backend="cupy", - store_v_seq=False, - ) - - -class BN(BatchNorm2d): - """ - BatchNorm2d with added extra warning message for input shape check. - """ - - def __init__( - self, - num_features, - eps=1e-5, - momentum=0.1, - affine=True, - track_running_stats=True, - step_mode="m", - ): - super().__init__( - num_features, eps, momentum, affine, track_running_stats, step_mode - ) - - def forward(self, x: Tensor): - if x.dim() != 5: - raise ValueError( - f"expected x with shape [T, N, C, H, W], but got x with shape {x.shape}!" - ) - return super().forward(x) - - -class DownsampleLayer(nn.Module): - def __init__(self, in_channels, out_channels, stride=2, activation=LIF) -> None: - super().__init__() - self.conv = Conv3x3(in_channels, out_channels, stride=stride) - self.norm = BN(out_channels) - self.activation = activation() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.activation(x) - x = self.conv(x) - x = self.norm(x) - return x - - -class SpikingMatmul(nn.Module): - def __init__(self, spike: str) -> None: - super().__init__() - assert spike == "l" or spike == "r" or spike == "both" - self.spike = spike - - def forward(self, left: torch.Tensor, right: torch.Tensor): - return torch.matmul(left, right) - - -class Conv3x3(Conv2d): - def __init__( - self, - in_channels: int, - out_channels: int, - stride: _size_2_t = 1, - dilation: _size_2_t = 1, - groups: int = 1, - bias: bool = False, - ) -> None: - super().__init__( - in_channels, - out_channels, - kernel_size=3, - stride=stride, - padding=dilation, - dilation=dilation, - groups=groups, - bias=bias, - padding_mode="zeros", - step_mode="m", - ) - - -class GWFFN(nn.Module): - def __init__(self, in_channels, num_conv=1, ratio=4, group_size=64, activation=LIF): - super().__init__() - inner_channels = in_channels * ratio - self.up = nn.Sequential( - activation(), - Conv1x1(in_channels, inner_channels), - BN(inner_channels), - ) - self.conv = nn.ModuleList() - for _ in range(num_conv): - self.conv.append( - nn.Sequential( - activation(), - Conv3x3( - inner_channels, - inner_channels, - groups=inner_channels // group_size, - ), - BN(inner_channels), - ) - ) - self.down = nn.Sequential( - activation(), - Conv1x1(inner_channels, in_channels), - BN(in_channels), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_feat_out = x.clone() - x = self.up(x) - x_feat_in = x.clone() - for m in self.conv: - x = m(x) - x = x + x_feat_in - x = self.down(x) - x = x + x_feat_out - return x - - -class DSSA(nn.Module): - def __init__(self, dim, num_heads, lenth, patch_size, activation=LIF): - super().__init__() - assert ( - dim % num_heads == 0 - ), f"dim {dim} should be divided by num_heads {num_heads}." - self.dim = dim - self.num_heads = num_heads - self.lenth = lenth - self.register_buffer("firing_rate_x", torch.zeros(1, 1, num_heads, 1, 1)) - self.register_buffer("firing_rate_attn", torch.zeros(1, 1, num_heads, 1, 1)) - self.init_firing_rate_x = False - self.init_firing_rate_attn = False - self.momentum = 0.999 - - self.activation_in = activation() - - self.W = Conv2d(dim, dim * 2, patch_size, patch_size, bias=False, step_mode="m") - self.norm = BN(dim * 2) - self.matmul1 = SpikingMatmul("r") - self.matmul2 = SpikingMatmul("r") - self.activation_attn = activation() - self.activation_out = activation() - - self.Wproj = Conv1x1(dim, dim) - self.norm_proj = BN(dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # X: [T, B, C, H, W] - T, B, C, H, W = x.shape - x_feat = x.clone() - x = self.activation_in(x) - - y = self.W(x) - y = self.norm(y) - y = y.reshape(T, B, self.num_heads, 2 * C // self.num_heads, -1) - y1, y2 = ( - y[:, :, :, : C // self.num_heads, :], - y[:, :, :, C // self.num_heads :, :], - ) - x = x.reshape(T, B, self.num_heads, C // self.num_heads, -1) - - if self.training: - firing_rate_x = x.detach().mean((0, 1, 3, 4), keepdim=True) - if not self.init_firing_rate_x and torch.all(self.firing_rate_x == 0): - self.firing_rate_x = firing_rate_x - self.init_firing_rate_x = True - self.firing_rate_x = self.firing_rate_x * self.momentum + firing_rate_x * ( - 1 - self.momentum - ) - scale1 = 1.0 / torch.sqrt(self.firing_rate_x * (self.dim // self.num_heads)) - attn = self.matmul1(y1.transpose(-1, -2), x) - attn = attn * scale1 - attn = self.activation_attn(attn) - - if self.training: - firing_rate_attn = attn.detach().mean((0, 1, 3, 4), keepdim=True) - if not self.init_firing_rate_attn and torch.all(self.firing_rate_attn == 0): - self.firing_rate_attn = firing_rate_attn - self.init_firing_rate_attn = True - self.firing_rate_attn = ( - self.firing_rate_attn * self.momentum - + firing_rate_attn * (1 - self.momentum) - ) - scale2 = 1.0 / torch.sqrt(self.firing_rate_attn * self.lenth) - out = self.matmul2(y2, attn) - out = out * scale2 - out = out.reshape(T, B, C, H, W) - out = self.activation_out(out) - - out = self.Wproj(out) - out = self.norm_proj(out) - out = out + x_feat - return out diff --git a/src/chop/nn/snn/modules/surrogate.py b/src/chop/nn/snn/modules/surrogate.py deleted file mode 100644 index 78f08495b..000000000 --- a/src/chop/nn/snn/modules/surrogate.py +++ /dev/null @@ -1,233 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math - -from chop.nn.snn.auto_cuda import cfunction - -tab4_str = "\t\t\t\t" # used for aligning code -curly_bracket_l = "{" -curly_bracket_r = "}" - - -from chop.nn.snn.functional.surrogate import ( - sigmoid_backward, - sigmoid, - atan_backward, - atan, -) - - -class SurrogateFunctionBase(nn.Module): - def __init__(self, alpha, spiking=True): - super().__init__() - self.spiking = spiking - self.alpha = alpha - - def set_spiking_mode(self, spiking: bool): - self.spiking = spiking - - def extra_repr(self): - return f"alpha={self.alpha}, spiking={self.spiking}" - - @staticmethod - def spiking_function(x, alpha): - raise NotImplementedError - - @staticmethod - def primitive_function(x, alpha): - raise NotImplementedError - - def cuda_code(self, x: str, y: str, dtype="fp32"): - raise NotImplementedError - - def cuda_code_start_comments(self): - return f"// start: spikingjelly.activation_based.surrogate.{self._get_name()}.cuda_code" - - def cuda_code_end_comments(self): - return f"// end: spikingjelly.activation_based.surrogate.{self._get_name()}.cuda_code" - - def forward(self, x: torch.Tensor): - if self.spiking: - return self.spiking_function(x, self.alpha) - else: - return self.primitive_function(x, self.alpha) - - def cuda_codes(self, y: str, x: str, dtype: str): - # new version - raise NotImplementedError - - -class Sigmoid(SurrogateFunctionBase): - def __init__(self, alpha=4.0, spiking=True): - """ - * :ref:`API in English ` - .. _Sigmoid.__init__-en: - - :param alpha: parameter to control smoothness of gradient - :param spiking: whether output spikes. The default is ``True`` which means that using ``heaviside`` in forward - propagation and using surrogate gradient in backward propagation. If ``False``, in forward propagation, - using the primitive function of the surrogate gradient function used in backward propagation - - The sigmoid surrogate spiking function. The gradient is defined by - - .. math:: - g'(x) = \\alpha * (1 - \\mathrm{sigmoid} (\\alpha x)) \\mathrm{sigmoid} (\\alpha x) - - The primitive function is defined by - - .. math:: - g(x) = \\mathrm{sigmoid}(\\alpha x) = \\frac{1}{1+e^{-\\alpha x}} - - .. image:: ../_static/API/activation_based/surrogate/Sigmoid.* - :width: 100% - - The function is used in [#STBP]_ [#roy2019scaling]_ [#SNNLSTM]_ [#SNU]_ . - """ - super().__init__(alpha, spiking) - - @staticmethod - def spiking_function(x, alpha): - return sigmoid.apply(x, alpha) - - @staticmethod - @torch.jit.script - def primitive_function(x: torch.Tensor, alpha: float): - return (x * alpha).sigmoid() - - @staticmethod - def backward(grad_output, x, alpha): - return sigmoid_backward(grad_output, x, alpha)[0] - - def cuda_code(self, x: str, y: str, dtype="fp32"): - sg_name = "sg_" + self._get_name() - alpha = str(self.alpha) + "f" - code = f""" - {tab4_str}{self.cuda_code_start_comments()} - """ - - if dtype == "fp32": - code += f""" - {tab4_str}const float {sg_name}_sigmoid_ax = 1.0f / (1.0f + expf(- {alpha} * {x})); - {tab4_str}const float {y} = (1.0f - {sg_name}_sigmoid_ax) * {sg_name}_sigmoid_ax * {alpha}; - """ - elif dtype == "fp16": - code += f""" - {tab4_str}const half2 {sg_name}_alpha = __float2half2_rn({alpha}); - {tab4_str}const half2 {sg_name}_sigmoid_ax = __h2div(__float2half2_rn(1.0f), __hadd2(h2exp(__hneg2(__hmul2({sg_name}_alpha, {x}))), __float2half2_rn(1.0f))); - {tab4_str}const half2 {y} = __hmul2(__hmul2(__hsub2(__float2half2_rn(1.0f), {sg_name}_sigmoid_ax), {sg_name}_sigmoid_ax), {sg_name}_alpha); - """ - else: - raise NotImplementedError - code += f""" - {tab4_str}{self.cuda_code_end_comments()} - """ - return code - - def cuda_codes(self, y: str, x: str, dtype: str): - return cfunction.sigmoid_backward(y=y, x=x, alpha=self.alpha, dtype=dtype) - - # plt.style.use(['science', 'muted', 'grid']) - # fig = plt.figure(dpi=200) - # x = torch.arange(-2.5, 2.5, 0.001) - # plt.plot(x.data, surrogate.heaviside(x), label='Heaviside', linestyle='-.') - # surrogate_function = surrogate.Sigmoid(alpha=5, spiking=False) - # y = surrogate_function(x) - # plt.plot(x.data, y.data, label='Primitive, $\\alpha=5$') - - # surrogate_function = surrogate.Sigmoid(alpha=5, spiking=True) - # x.requires_grad_(True) - # y = surrogate_function(x) - # z = y.sum() - # z.backward() - # plt.plot(x.data, x.grad, label='Gradient, $\\alpha=5$') - # plt.xlim(-2, 2) - # plt.legend() - # plt.title('Sigmoid surrogate function') - # plt.xlabel('Input') - # plt.ylabel('Output') - # plt.grid(linestyle='--') - # plt.show() - - -class ATan(SurrogateFunctionBase): - def __init__(self, alpha=2.0, spiking=True): - """ - * :ref:`API in English ` - - .. _ATan.__init__-en: - - The arc tangent surrogate spiking function. The gradient is defined by - - .. math:: - g'(x) = \\frac{\\alpha}{2(1 + (\\frac{\\pi}{2}\\alpha x)^2)} - - The primitive function is defined by - - .. math:: - g(x) = \\frac{1}{\\pi} \\arctan(\\frac{\\pi}{2}\\alpha x) + \\frac{1}{2} - - """ - super().__init__(alpha, spiking) - - @staticmethod - def spiking_function(x, alpha): - return atan.apply(x, alpha) - - @staticmethod - @torch.jit.script - def primitive_function(x: torch.Tensor, alpha: float): - return (math.pi / 2 * alpha * x).atan_() / math.pi + 0.5 - - @staticmethod - def backward(grad_output, x, alpha): - return atan_backward(grad_output, x, alpha)[0] - - def cuda_code(self, x: str, y: str, dtype="fp32"): - sg_name = "sg_" + self._get_name() - alpha = str(self.alpha) + "f" - code = f""" - {tab4_str}{self.cuda_code_start_comments()} - """ - if dtype == "fp32": - code += f""" - {tab4_str}const float {sg_name}_M_PI_2__alpha__x = ((float) 1.57079632679489661923) * {alpha} * {x}; - {tab4_str}const float {y} = {alpha} / 2.0f / (1.0f + {sg_name}_M_PI_2__alpha__x * {sg_name}_M_PI_2__alpha__x); - """ - elif dtype == "fp16": - code += f""" - {tab4_str}const half2 {sg_name}_alpha = __float2half2_rn({alpha}); - {tab4_str}const half2 {sg_name}_M_PI_2__alpha__x = __hmul2(__hmul2(__float2half2_rn((float) 1.57079632679489661923), {sg_name}_alpha), {x}); - {tab4_str}const half2 {y} = __h2div(__h2div({sg_name}_alpha, __float2half2_rn(2.0f)), __hfma2({sg_name}_M_PI_2__alpha__x, {sg_name}_M_PI_2__alpha__x, __float2half2_rn(1.0f))); - """ - else: - raise NotImplementedError - code += f""" - {tab4_str}{self.cuda_code_end_comments()} - """ - return code - - def cuda_codes(self, y: str, x: str, dtype: str): - return cfunction.atan_backward(y=y, x=x, alpha=self.alpha, dtype=dtype) - - # plt.style.use(['science', 'muted', 'grid']) - # fig = plt.figure(dpi=200) - # x = torch.arange(-2.5, 2.5, 0.001) - # plt.plot(x.data, surrogate.heaviside(x), label='Heaviside', linestyle='-.') - # surrogate_function = surrogate.ATan(alpha=3, spiking=False) - # y = surrogate_function(x) - # plt.plot(x.data, y.data, label='Primitive, $\\alpha=3$') - - # surrogate_function = surrogate.ATan(alpha=3, spiking=True) - # x.requires_grad_(True) - # y = surrogate_function(x) - # z = y.sum() - # z.backward() - # plt.plot(x.data, x.grad, label='Gradient, $\\alpha=3$') - # plt.xlim(-2, 2) - # plt.legend() - # plt.title('ATan surrogate function') - # plt.xlabel('Input') - # plt.ylabel('Output') - # plt.grid(linestyle='--') - # plt.show() diff --git a/src/chop/nn/snn/modules/upsample.py b/src/chop/nn/snn/modules/upsample.py deleted file mode 100644 index a51ecb221..000000000 --- a/src/chop/nn/snn/modules/upsample.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from torch import Tensor -from torch.nn.common_types import ( - _size_any_t, - _size_1_t, - _size_2_t, - _size_3_t, - _ratio_any_t, -) -from typing import Optional, List, Tuple, Union -from typing import Callable -import chop.nn.snn.base as base -import chop.nn.snn.functional as functional - - -class Upsample(nn.Upsample, base.StepModule): - def __init__( - self, - size: Optional[_size_any_t] = None, - scale_factor: Optional[_ratio_any_t] = None, - mode: str = "nearest", - align_corners: Optional[bool] = None, - recompute_scale_factor: Optional[bool] = None, - step_mode: str = "s", - ) -> None: - """ - * :ref:`API in English ` - - .. _Upsample-en: - - :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) - :type step_mode: str - - Refer to :class:`torch.nn.Upsample` for other parameters' API - """ - super().__init__( - size, scale_factor, mode, align_corners, recompute_scale_factor - ) - self.step_mode = step_mode - - def extra_repr(self): - return super().extra_repr() + f", step_mode={self.step_mode}" - - def forward(self, x: Tensor) -> Tensor: - if self.step_mode == "s": - x = super().forward(x) - - elif self.step_mode == "m": - x = functional.seq_to_ann_forward(x, super().forward) - - return x diff --git a/src/chop/nn/snn/modules/utils.py b/src/chop/nn/snn/modules/utils.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/chop/nn/snn/readme.md b/src/chop/nn/snn/readme.md deleted file mode 100644 index b8b1be885..000000000 --- a/src/chop/nn/snn/readme.md +++ /dev/null @@ -1,56 +0,0 @@ -## Overview - -This directory contains code designed for transformations in Spiking Neural Networks (SNNs) within the MASE framework. The code adapts and integrates modules from the SpikingJelly library, providing essential functionality for efficient SNN simulations and operations. - -## Directory Structure - -### `\functional` - -- **Description**: This folder contains adapted code from SpikingJelly, offering a functional style interface. -- **Key Features**: - - Contains containers that allow users to set simulation behavior, either as a **single-step** or **multi-step** process. - - Provides functional style implementations for SNN layers, enabling modular and flexible usage in various network architectures. - -### `\module` - -- **Description**: This folder includes module-based code adapted from SpikingJelly. -- **Key Features**: - - Contains modular containers similar to those in the `functional` directory, allowing the specification of single-step or multi-step simulations. - - Provides neural network modules specifically designed for SNNs, allowing users to build structured SNN models with ease. - - Provides neural scaling modules for ann to snn conversion. - -### `\auto_cuda` - -- **Description**: This folder is directly sourced from SpikingJelly without modifications. -- **Key Features**: - - Defines acceleration kernels to optimize SNN operations, particularly useful for using CUDA capabilities, cupy or just pytorch. - - -### `neuron.py` - -- **Description**: This file is directly sourced from SpikingJelly without modifications. -- **Key Features**: - - Defines neuron modules. - -### `base.py` - -- **Description**: This file is directly sourced from SpikingJelly without modifications. -- **Key Features**: - - Defines the base class for neuron modules (modules that has internal memory across inference) and the base class for modular-style containers. - -### `configuration.py` - -- **Description**: This file is directly sourced from SpikingJelly without modifications. -- **Key Features**: - - Defines various configuration variables used in SpikingJelly, such as settings for the number of CUDA threads. - -### `cuda_utils.py` - -- **Description**: This file is directly sourced from SpikingJelly without modifications. -- **Key Features**: - - Contains helper functions for running SNN modules with CUDA. - - -## Additional Information - -This codebase is built on top of SpikingJelly \ No newline at end of file diff --git a/src/chop/passes/__init__.py b/src/chop/passes/__init__.py index 36589a2d9..5c5a58c12 100644 --- a/src/chop/passes/__init__.py +++ b/src/chop/passes/__init__.py @@ -27,13 +27,11 @@ patch_metadata_transform_pass, insert_lora_adapter_transform_pass, fuse_lora_weights_transform_pass, - ann2snn_transform_pass, ) from .module.analysis import calculate_avg_bits_module_analysis_pass from .module.transforms import ( quantize_module_transform_pass, resharding_transform_pass, - ann2snn_module_transform_pass, ) from .onnx.analysis import ( diff --git a/src/chop/passes/graph/__init__.py b/src/chop/passes/graph/__init__.py index 9bb6a6c49..e71a7a1ff 100644 --- a/src/chop/passes/graph/__init__.py +++ b/src/chop/passes/graph/__init__.py @@ -28,7 +28,6 @@ logicnets_fusion_transform_pass, onnx_annotate_transform_pass, raise_granularity_transform_pass, - ann2snn_transform_pass, ) from .interface import ( @@ -115,7 +114,6 @@ "summarize_quantization": summarize_quantization_analysis_pass, "prune": prune_transform_pass, "prune_detach_hook": prune_detach_hook_transform_pass, - "ann2snn": ann2snn_transform_pass, # "remove_prune_wrappers": prune_unwrap_transform_pass, "conv_bn_fusion": conv_bn_fusion_transform_pass, "logicnets_fusion": logicnets_fusion_transform_pass, diff --git a/src/chop/passes/graph/transforms/__init__.py b/src/chop/passes/graph/transforms/__init__.py index be36ca803..6f6f93cff 100644 --- a/src/chop/passes/graph/transforms/__init__.py +++ b/src/chop/passes/graph/transforms/__init__.py @@ -1,6 +1,5 @@ from .pruning import prune_transform_pass, prune_detach_hook_transform_pass from .quantize import quantize_transform_pass, summarize_quantization_analysis_pass -from .snn import ann2snn_transform_pass from .utils import ( conv_bn_fusion_transform_pass, logicnets_fusion_transform_pass, diff --git a/src/chop/passes/graph/transforms/snn/__init__.py b/src/chop/passes/graph/transforms/snn/__init__.py deleted file mode 100644 index 1e98293bd..000000000 --- a/src/chop/passes/graph/transforms/snn/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .ann2snn import ann2snn_transform_pass diff --git a/src/chop/passes/graph/transforms/snn/ann2snn.py b/src/chop/passes/graph/transforms/snn/ann2snn.py deleted file mode 100644 index c7713c7ab..000000000 --- a/src/chop/passes/graph/transforms/snn/ann2snn.py +++ /dev/null @@ -1,240 +0,0 @@ -# ***************************************************************************************/ -# * Title: ann2snn -# * Reference: This code is adapted from spikingJelly cnn_mnist.py -# * Availability: https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/activation_based/ann2snn/examples/cnn_mnist.py -# * Date: 07/11/2024 -# * Code version: 0.0.0.014 -# * -# ***************************************************************************************/ -from copy import copy, deepcopy -import logging -from chop.ir.graph.mase_metadata import MaseMetadata -from chop.nn.snn.modules.neuron import IFNode -from chop.passes.graph.transforms.quantize.quant_parsers.update_node_meta import ( - update_quant_meta_param, -) -import torch -from chop.passes.graph.interface.save_and_load import load_mase_graph_interface_pass -from chop.passes.graph.transforms.utils.conv_bn_fusion import ( - conv_bn_fusion_transform_pass, -) -from chop.nn.snn.modules.modules import VoltageHook, VoltageScaler -from tqdm import tqdm -from typing import Tuple - -from ...utils import ( - deepcopy_mase_graph, - get_mase_op, - get_mase_type, - get_node_actual_target, - get_parent_name, - get_similar_node_actual_target, - match_a_pattern, - get_node_target_by_name, -) - -CONVERTABLE_OP = { - "add", - "bmm", - "conv1d", - "conv2d", - "matmul", - "mul", - "linear", - "relu", - "sub", - "batch_norm2d", - "layer_norm", -} - - -def get_config(config: dict, name: str): - if name in config: - return config[name]["config"] - else: - return config["default"]["config"] - - -def attach_empty_mase_metadata(node): - node.meta["mase"] = MaseMetadata(node=node) - return node - - -def add_module_and_node( - fx_model: torch.fx.GraphModule, - target: str, - after: torch.fx.Node, - m: torch.nn.Module, - args: Tuple, -) -> torch.fx.Node: - """Add a node m with target name after the after_node""" - fx_model.add_submodule(target=target, m=m) - with fx_model.graph.inserting_after(n=after): - new_node = fx_model.graph.call_module(module_name=target, args=args) - return new_node - - -def replace_by_ifnode(graph, config: dict) -> torch.fx.GraphModule: - """ - * :ref:`API in English ` - - .. replace_by_ifnode-en: - - :param fx_model: Original fx_model - :type fx_model: torch.fx.GraphModule - :return: fx_model whose ReLU has been replaced by IF neuron. - :rtype: torch.fx.GraphModule - - ``replace_by_ifnode`` is used to replace ReLU with IF neuron. - - """ - - # TODO: Many of the code here need to be refactored when the spiking mase graph is available - - hook_cnt = -1 - fx_model = graph.model - - # for node in fx_model.graph.nodes: - for node in graph.fx_graph.nodes: - if node.op != "call_module": - continue - - if type(fx_model.get_submodule(node.target)) is VoltageHook: - if type(fx_model.get_submodule(node.args[0].target)) is torch.nn.ReLU: - node_config = get_config(config, get_mase_op(node.args[0])) - - hook_cnt += 1 - hook_node = node - relu_node = node.args[0] - if len(relu_node.args) != 1: - raise NotImplementedError( - "The number of relu_node.args should be 1." - ) - s = fx_model.get_submodule(node.target).scale.item() - target0 = "snn tailor." + str(hook_cnt) + ".0" # voltage_scaler - target1 = "snn tailor." + str(hook_cnt) + ".1" # IF_node - target2 = "snn tailor." + str(hook_cnt) + ".2" # voltage_scaler - m0 = VoltageScaler(1.0 / s) - if node_config.get("name") == "IFNode": - m1 = IFNode(v_threshold=1.0, v_reset=None) - else: - raise NotImplementedError("Not implemented yet.") - m2 = VoltageScaler(s) - - node0 = add_module_and_node( - fx_model, target0, hook_node, m0, relu_node.args - ) - node0 = attach_empty_mase_metadata(node0) - - # parent_name, name = get_parent_name(node.target) - # setattr(graph.modules[parent_name], name, m0) - - node1 = add_module_and_node(fx_model, target1, node0, m1, (node0,)) - node1 = attach_empty_mase_metadata(node1) - - node2 = add_module_and_node(fx_model, target2, node1, m2, args=(node1,)) - node2 = attach_empty_mase_metadata(node2) - - relu_node.replace_all_uses_with(node2) - node2.args = (node1,) - fx_model.graph.erase_node(hook_node) - fx_model.graph.erase_node(relu_node) - fx_model.delete_all_unused_submodules() - fx_model.graph.lint() - fx_model.recompile() - - return graph.model - - -def graph_iterator_ann2snn_by_name(graph, config: dict): - pass - - -def graph_iterator_ann2snn_by_type(graph, config: dict): - fuse_flag = config.get("fuse", False) - dataloader = config.get("train_data_loader") - device = config.get("device", "cpu") - - if fuse_flag: - graph, _ = conv_bn_fusion_transform_pass(graph) - - hook_cnt = -1 - - # Adding hooks to the graph - for node in graph.fx_graph.nodes: - if node.meta["mase"].parameters["common"] == {}: - # spiking node! Ignore for now - continue - node_config = get_config(config, get_mase_op(node)) - - if node.op == "call_module": - # NOTE: if the following list continues to grow, consider moving it to a separate file - if get_mase_op(node) == "relu": - hook_cnt += 1 - target = "snn tailor." + str(hook_cnt) + ".0" # voltage_hook] - - mode = node_config.get("mode", "99.9%") - momentum = node_config.get("momentum", 0.1) - m = VoltageHook(momentum=momentum, mode=mode) - # TODO: check this - new_node = add_module_and_node(graph.model, target, node, m, (node,)) - new_node = attach_empty_mase_metadata(new_node) - - graph.fx_graph.lint() - graph.model.recompile() # TODO: is this necessary? - - # calibrate the scale - for _, imgs in enumerate(tqdm(dataloader)): - graph.model(imgs["x"].to(device)) - - # snn = replace_by_ifnode(ann_with_hook).to(self.device) - graph.model = replace_by_ifnode(graph, config).to(device) - - return graph # return type: GraphModule - - -def ann2snn_transform_pass(graph, pass_args=None): - """ - Transform the graph from ANN to SNN. - - :param graph: The input graph to be transformed. - :type graph: MaseGraph - - :param pass_args: Additional arguments for the transformation. - :type pass_args: dict, optional - - .. code-block: python - - quan_args = { - "by": "type", # quantize by type, name, or regex_name - "default": {"config": {"name": None}}, # default config, this would be used for any node that does not have a specific config - "relu": { - "config": { - "name": "IFNode", # conversion scheme name supported are ["IFNode", "LIFNode"...] - - # Voltage normalization (ensure the output of the activation function is within the range of the neuron model [0,1]) - "mode": "99.9%", # conversion mode supported are ["max", "99.9%", 1.0/2, 1.0/3. 1.0/4, 1.0/5] - "momentum": 0.1, # momentum for the voltage normalization - "fuse": True, # Bool if true: fusing the conv and bn layer, vice versa - "device": "cpu", # device to perform the calibration - } - }, - } - - :return: The transformed graph. - :rtype: tuple - :raises ValueError: If the quantize "by" argument is unsupported. - - """ - by = pass_args.pop("by") - match by: - case "type": - graph = graph_iterator_ann2snn_by_type(graph, pass_args) - case "name": - graph = graph_iterator_ann2snn_by_name(graph, pass_args) - case _: - raise ValueError(f'Unsupported quantize "by": {by}') - - # link the model with graph - graph.model = torch.fx.GraphModule(graph.model, graph.fx_graph) - return graph, {} diff --git a/src/chop/passes/module/module_modify_helper.py b/src/chop/passes/module/module_modify_helper.py index b8bfe6864..05bb91ca0 100644 --- a/src/chop/passes/module/module_modify_helper.py +++ b/src/chop/passes/module/module_modify_helper.py @@ -1,4 +1,3 @@ -from chop.passes.module.state_dict_map import SPECIAL_CONVERT_PATTERNS import torch from functools import reduce, partial @@ -99,14 +98,7 @@ def set_module_by_name( def replace_by_name(network, name, module): original = get_module_by_name(network, name) - - # state_dict replacement - special_replacement = (type(original), type(module)) in SPECIAL_CONVERT_PATTERNS - if special_replacement: - new = SPECIAL_CONVERT_PATTERNS[(type(original), type(module))](original, module) - else: - new = weight_replacement(original, module) - + new = weight_replacement(original, module) network = set_module_by_name(network, name, new) return network diff --git a/src/chop/passes/module/state_dict_map.py b/src/chop/passes/module/state_dict_map.py index 0d35d70ca..1b5571d94 100644 --- a/src/chop/passes/module/state_dict_map.py +++ b/src/chop/passes/module/state_dict_map.py @@ -4,12 +4,6 @@ from copy import deepcopy from typing import Tuple -from chop.nn.quantizers.SNN.LSQ import LSQInteger -from chop.nn.quantized.modules.roberta.attention import RobertaSelfAttentionLSQInteger -from chop.nn.snn.modules.linear import LinearUnfoldBias -from chop.nn.snn.modules.roberta.attention import RobertaSelfAttentionZIPTF - -from chop.nn.snn.modules.neuron.st_bifnode import ST_BIFNode import torch from pathlib import Path from functools import reduce @@ -26,96 +20,3 @@ def match_a_pattern(name: str, patterns: list[str]) -> str | None: def check_is_huggingface_model(model): return isinstance(model, (PreTrainedModel, TFPreTrainedModel)) - - -def attn_convert( - QAttn: RobertaSelfAttentionLSQInteger, SAttn: RobertaSelfAttentionZIPTF -) -> RobertaSelfAttentionZIPTF: - # NOTE: level and neuron_type are configure during the initialization of the module through the config args - level = SAttn.level - neuron_type = SAttn.neuron_type - - SAttn.query = LinearUnfoldBias( - in_features=QAttn.query.in_features, - out_features=QAttn.query.out_features, - bias=QAttn.query.bias is not None, - neuron_type="ST-BIF", - level=level, - ) - SAttn.query.weight.data = QAttn.query.weight.data - SAttn.query.bias.data = QAttn.query.bias.data - - SAttn.key = LinearUnfoldBias( - in_features=QAttn.key.in_features, - out_features=QAttn.key.out_features, - bias=QAttn.key.bias is not None, - neuron_type="ST-BIF", - level=level, - ) - SAttn.key.weight.data = QAttn.key.weight.data - SAttn.key.bias.data = QAttn.key.bias.data - - SAttn.value = LinearUnfoldBias( - in_features=QAttn.value.in_features, - out_features=QAttn.value.out_features, - bias=QAttn.value.bias is not None, - neuron_type="ST-BIF", - level=level, - ) - SAttn.value.weight.data = QAttn.value.weight.data - SAttn.value.bias.data = QAttn.value.bias.data - - SAttn.query_IF.neuron_type = neuron_type - SAttn.query_IF.level = level - SAttn.query_IF.q_threshold = QAttn.query_quan.s.data - SAttn.query_IF.pos_max = QAttn.query_quan.pos_max - SAttn.query_IF.neg_min = QAttn.query_quan.neg_min - SAttn.query_IF.is_init = False - - SAttn.key_IF.neuron_type = neuron_type - SAttn.key_IF.level = level - SAttn.key_IF.q_threshold = QAttn.key_quan.s.data - SAttn.key_IF.pos_max = QAttn.key_quan.pos_max - SAttn.key_IF.neg_min = QAttn.key_quan.neg_min - SAttn.key_IF.is_init = False - - SAttn.value_IF.neuron_type = neuron_type - SAttn.value_IF.level = level - SAttn.value_IF.q_threshold = QAttn.value_quan.s.data - SAttn.value_IF.pos_max = QAttn.value_quan.pos_max - SAttn.value_IF.neg_min = QAttn.value_quan.neg_min - SAttn.value_IF.is_init = False - - SAttn.attn_IF.neuron_type = neuron_type - SAttn.attn_IF.level = level - SAttn.attn_IF.q_threshold = QAttn.attn_quan.s.data - SAttn.attn_IF.pos_max = QAttn.attn_quan.pos_max - SAttn.attn_IF.neg_min = QAttn.attn_quan.neg_min - SAttn.attn_IF.is_init = False - - SAttn.after_attn_IF.neuron_type = neuron_type - SAttn.after_attn_IF.level = level - SAttn.after_attn_IF.q_threshold = QAttn.after_attn_quan.s.data - SAttn.after_attn_IF.pos_max = QAttn.after_attn_quan.pos_max - SAttn.after_attn_IF.neg_min = QAttn.after_attn_quan.neg_min - SAttn.after_attn_IF.is_init = False - - return SAttn - - -def lsqinteger_to_st_bif(LSQ: LSQInteger, ST_BIF: ST_BIFNode) -> ST_BIFNode: - - ST_BIF.q_threshold = LSQ.s.data - ST_BIF.sym = LSQ.sym - ST_BIF.level = LSQ.level - ST_BIF.pos_max = LSQ.pos_max - ST_BIF.neg_min = LSQ.neg_min - ST_BIF.is_init = False - - return ST_BIF - - -SPECIAL_CONVERT_PATTERNS = { - (RobertaSelfAttentionLSQInteger, RobertaSelfAttentionZIPTF): attn_convert, - (LSQInteger, ST_BIFNode): lsqinteger_to_st_bif, -} diff --git a/src/chop/passes/module/transforms/__init__.py b/src/chop/passes/module/transforms/__init__.py index 7e3164b8b..a52eae460 100644 --- a/src/chop/passes/module/transforms/__init__.py +++ b/src/chop/passes/module/transforms/__init__.py @@ -1,4 +1,3 @@ from .autosharding import resharding_transform_pass from .quantize import quantize_module_transform_pass -from .snn import ann2snn_module_transform_pass from .attention import attention_swap_transform_pass diff --git a/src/chop/passes/module/transforms/snn/__init__.py b/src/chop/passes/module/transforms/snn/__init__.py deleted file mode 100644 index c89bd6f54..000000000 --- a/src/chop/passes/module/transforms/snn/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .ann2snn import ann2snn_module_transform_pass diff --git a/src/chop/passes/module/transforms/snn/ann2snn.py b/src/chop/passes/module/transforms/snn/ann2snn.py deleted file mode 100644 index 60932daed..000000000 --- a/src/chop/passes/module/transforms/snn/ann2snn.py +++ /dev/null @@ -1,181 +0,0 @@ -from chop.nn.quantizers.SNN.LSQ import LSQInteger -import torch - -from chop.nn.snn.modules import spiking_module_map -from ...module_modify_helper import ( - manual_instantiate_module, - replace_by_name, - instantiate_module, -) -from ...state_dict_map import match_a_pattern, check_is_huggingface_model - - -def get_config(config: dict, name: str): - if name in config: - return config[name]["config"] - else: - return config["default"]["config"] - - -def convert_by_type(network, pass_args): - is_huggingface_model = check_is_huggingface_model(network) - - for type_name, conversion_config in pass_args.items(): - n_m = {} - for n, m in network.named_modules(): - n_m[n] = m - - if type_name == "linear": - module = torch.nn.Linear - elif type_name == "conv2d": - module = torch.nn.Conv2d - elif type_name == "embedding": - module = torch.nn.Embedding - elif type_name == "layernorm": - module = torch.nn.LayerNorm - elif type_name == "relu": - module = torch.nn.ReLU - elif type_name == "lsqinteger": - module = LSQInteger - else: - raise ValueError(f"{type_name} is not supported!") - - is_manual_instantiate = conversion_config.get("manual_instantiate", False) - conversion_config = conversion_config["config"] - postfix = conversion_config.pop("name") - - for n, m in n_m.items(): - if isinstance(m, module): - # same across all convert methods - additional_module_args = ( - {"config": conversion_config, "network_config": network.config} - if is_huggingface_model - else {"config": conversion_config} - ) - - if is_manual_instantiate: - new_m = manual_instantiate_module( - m, postfix, spiking_module_map, additional_module_args - ) - else: - new_m = instantiate_module( - m, postfix, spiking_module_map, additional_module_args - ) - network = replace_by_name(network, n, new_m) - - return network - - -def convert_by_name(network, pass_args): - is_huggingface_model = check_is_huggingface_model(network) - is_manual_instantiate = pass_args.get("manual_instantiate", False) - - conversion_names = pass_args.keys() - n_m = {} - for n, m in network.named_modules(): - n_m[n] = m - - for n, m in n_m.items(): - if n in conversion_names: - conversion_config = pass_args[n] - - conversion_config = conversion_config["config"] - postfix = conversion_config.pop("name") - - # same across all convert methods - additional_module_args = ( - {"config": conversion_config, "network_config": network.config} - if is_huggingface_model - else {"config": conversion_config} - ) - - if is_manual_instantiate: - new_m = manual_instantiate_module( - m, postfix, spiking_module_map, additional_module_args - ) - else: - new_m = instantiate_module( - m, postfix, spiking_module_map, additional_module_args - ) - network = replace_by_name(network, n, new_m) - - return network - - -def convert_by_regex_name(network, pass_args): - is_huggingface_model = check_is_huggingface_model(network) - is_manual_instantiate = pass_args.get("manual_instantiate", False) - - patterns = list(pass_args.keys()) - n_m = {} - for n, m in network.named_modules(): - n_m[n] = m - - for n, m in n_m.items(): - matched_pattern = match_a_pattern(n, patterns) - if not matched_pattern: - continue - - conversion_config = pass_args[matched_pattern]["config"] - postfix = conversion_config["name"] - - # same across all convert methods - additional_module_args = ( - {"config": conversion_config, "network_config": network.config} - if is_huggingface_model - else {"config": conversion_config} - ) - - if is_manual_instantiate: - new_m = manual_instantiate_module( - m, postfix, spiking_module_map, additional_module_args - ) - else: - new_m = instantiate_module( - m, postfix, spiking_module_map, additional_module_args - ) - network = replace_by_name(network, n, new_m) - - return network - - -def ann2snn_module_transform_pass(network, pass_args): - """ - Apply spike neural network (SNN) transformation to the input network. - - :param network: The input network to be transformed. - :type network: torch.nn.Module - - :param pass_args: Additional arguments for the transformation. - :type pass_args: dict, optional - - Examples pass_args: - - .. code-block:: python - - pass_args = { - "by": "type", # transform by type, name, or regex_name - "default": {"config": {"name": None}}, - "linear": { - "config": { - "name": "unfold_bias", - } - }, - } - - :return: The transformed torch.nn.Module. - :rtype: tuple - :raises ValueError: If the convert "by" argument is unsupported. - - """ - by = pass_args.pop("by") - match by: - case "type": - network = convert_by_type(network, pass_args) - case "name": - network = convert_by_name(network, pass_args) - case "regex_name": - network = convert_by_regex_name(network, pass_args) - case _: - raise ValueError(f'Unsupported conversion "by": {by}') - return network, {} diff --git a/test/nn/snn/test_ann2snn.py b/test/nn/snn/test_ann2snn.py deleted file mode 100644 index 1b8bbbcfa..000000000 --- a/test/nn/snn/test_ann2snn.py +++ /dev/null @@ -1,218 +0,0 @@ -import logging -import sys, os - -sys.path.append(os.path.join(os.path.dirname(__file__), "src")) -# import chop - -from chop.tools.checkpoint_load import load_model -import numpy as np -import torch -import tqdm -from chop.dataset import MaseDataModule, get_dataset_info -from chop.ir.graph.mase_graph import MaseGraph -from chop import models -from chop.tools.get_input import InputGenerator, get_dummy_input -from chop.actions.train import train -from chop.actions.test import test -from lightning.pytorch.loggers.tensorboard import TensorBoardLogger - -from chop.passes.graph import ( - init_metadata_analysis_pass, - add_common_metadata_analysis_pass, - quantize_transform_pass, - summarize_quantization_analysis_pass, - verify_common_metadata_analysis_pass, -) - -from chop.passes.graph.utils import deepcopy_mase_graph -from chop.passes.graph.transforms.snn.ann2snn import ann2snn_transform_pass - -model_name = "cnv-toy" -dataset_name = "cifar10" -BATCH_SIZE = 32 - -import torch.nn as nn - - -def val(net, device, data_loader, T=None): - net.eval().to(device) - correct = 0.0 - total = 0.0 - if T is not None: - corrects = np.zeros(T) - with torch.no_grad(): - for batch, (img, label) in enumerate(data_loader): - img = img.to(device) - if T is None: - out = net(img) - correct += (out.argmax(dim=1) == label.to(device)).float().sum().item() - else: - for m in net.modules(): - if hasattr(m, "reset"): - m.reset() - for t in range(T): - if t == 0: - out = net(img) - else: - out += net(img) - corrects[t] += ( - (out.argmax(dim=1) == label.to(device)).float().sum().item() - ) - total += out.shape[0] - return correct / total if T is None else corrects / total - - -# get dataset information -dataset_info = get_dataset_info(dataset_name) - -# get model information -model_info = models.get_model_info(model_name) - -# get data module -data_module = MaseDataModule( - model_name=model_name, - name=dataset_name, - batch_size=BATCH_SIZE, - num_workers=8, - tokenizer=None, - max_token_len=None, -) -data_module.prepare_data() -data_module.setup() -# NOTE: We only support vision classification models for now. -dummy_input = get_dummy_input(model_info, data_module, "cls", "cpu") - -# get an input generator to calibrate the spiking normalization factor during conversion -input_generator = InputGenerator( - model_info=model_info, - data_module=data_module, - task="cls", - which_dataloader="train", -) - -model = models.get_model(model_name, pretrained=False, dataset_info=dataset_info) - - -# This line transforms a nn.Module to a MaseGraph -mg = MaseGraph(model=model) - -# Apply initialization passes to populate information in the graph -mg, _ = init_metadata_analysis_pass(mg, {}) -mg, _ = add_common_metadata_analysis_pass( - mg, {"dummy_in": dummy_input, "add_value": False} -) - -# ------------------------------------------------------------ -# Training the base ANN -# ------------------------------------------------------------ - -plt_trainer_args = { - "max_epochs": 10, - "devices": 1, - "accelerator": "cuda", -} - -# save_path = "/home/thw20/projects/mase/mase_output/snn/training_ckpts" -# visualizer_save_path = ( -# "/home/thw20/projects/mase/mase_output/snn/software/training_ckpts" -# ) -# visualizer = TensorBoardLogger( -# save_dir=visualizer_save_path, -# ) - -# train( -# model=mg.model, -# model_info=model_info, -# dataset_info=dataset_info, -# weight_decay=1e-4, -# task="cls", -# data_module=data_module, -# optimizer="adam", -# learning_rate=1e-5, -# plt_trainer_args=plt_trainer_args, -# scheduler_args=None, -# save_path=save_path, -# load_name=None, -# load_type="pl", -# visualizer=visualizer, -# auto_requeue=False, -# ) - - -# train( -# model=mg.model, -# model_info=model_info, -# dataset_info=dataset_info, -# weight_decay=1e-4, -# task="cls", -# data_module=data_module, -# optimizer="adam", -# learning_rate=1e-5, -# plt_trainer_args=plt_trainer_args, -# scheduler_args=None, -# save_path=None, -# load_name=None, -# load_type="pl", -# visualizer=None, -# auto_requeue=False, -# ) - -# test( -# model=mg.model, -# model_info=model_info, -# data_module=data_module, -# dataset_info=dataset_info, -# task="cls", -# optimizer="adam", -# learning_rate=1e-5, -# weight_decay=1e-4, -# plt_trainer_args=plt_trainer_args, -# auto_requeue=False, -# save_path=save_path, -# visualizer=visualizer, -# load_name="/home/thw20/projects/mase/mase_output/snn/training_ckpts/best.ckpt", -# load_type='pl', -# ) -# print(val(mg.model, "cuda", data_module.test_dataloader())) - - -# ann_model = load_model( -# load_name="/home/thw20/projects/mase/mase_output/snn/training_ckpts/best.ckpt", -# load_type="pl", -# model=model, -# ) -# print(val(ann_model, "cuda", data_module.test_dataloader())) - -# ------------------------------------------------------------ -# Convert the base ANN to SNN and test -# ------------------------------------------------------------ - -quan_args = { - "by": "type", - "default": {"config": {"name": None}}, - "fuse": True, - "relu": { - "config": { - "name": "IFNode", - "mode": "99.9%", - "momentum": 0.1, - } - }, - "train_data_loader": input_generator, - "device": "cpu", # "device": "cuda", -} - -model.to("cpu") # model.to("gpu") -mg, _ = ann2snn_transform_pass(mg, quan_args) -# print(val(mg.model, "cuda", data_module.test_dataloader(), T=10)) - - -# ------------------------------------------------------------ -# load the SNN mz graph and test -# ------------------------------------------------------------ -# snn_model = load_model( -# load_name="/home/thw20/projects/mase/mase_output/cnv_toy_cls_cifar10_2024-10-23/software/transform/transformed_ckpt/graph_module.mz", -# load_type="mz", -# model=model, -# ) -# print(val(snn_model, "cuda", data_module.test_dataloader(), T=20)) diff --git a/test/passes/module/transforms/ann2snn/test_ann2snn_module_roberta.py b/test/passes/module/transforms/ann2snn/test_ann2snn_module_roberta.py deleted file mode 100644 index b7236b43c..000000000 --- a/test/passes/module/transforms/ann2snn/test_ann2snn_module_roberta.py +++ /dev/null @@ -1,128 +0,0 @@ -#!/usr/bin/env python3 -# This example converts a simple MLP model to Verilog -import logging -import os -import sys - -from chop.passes.module.transforms.snn.ann2snn import ann2snn_module_transform_pass -import torch -import torch.nn as nn - -from pathlib import Path - -sys.path.append(Path(__file__).resolve().parents[5].as_posix()) - - -from chop.passes.module.transforms import quantize_module_transform_pass - - -import torch -from torch import nn -from transformers import RobertaForSequenceClassification, AutoTokenizer - -pretrained = "XianYiyk/roberta-relu-pretrained-sst2" -bert = RobertaForSequenceClassification.from_pretrained(pretrained, num_labels=2) -tokenizer = AutoTokenizer.from_pretrained(pretrained, do_lower_case=True) -for param in bert.parameters(): - param.requires_grad = True # QAT training - - -# def test_ann2snn_module_transform_pass(): -quan_pass_args = { - "by": "regex_name", - "roberta\.encoder\.layer\.\d+\.attention\.self": { - "config": { - "name": "lsqinteger", - "level": 32, - } - }, - "roberta\.encoder\.layer\.\d+\.attention\.output": { - "config": { - "name": "lsqinteger", - "level": 32, - } - }, - "roberta\.encoder\.layer\.\d+\.output": { - "config": { - "name": "lsqinteger", - "level": 32, - } - }, - "roberta\.encoder\.layer\.\d+\.intermediate": { - "config": { - "name": "lsqinteger", - "level": 32, - } - }, - "classifier": { - "config": { - "name": "lsqinteger", - "level": 32, - } - }, -} -mg, _ = quantize_module_transform_pass(bert, quan_pass_args) -# f = open(f"qann_model_arch.txt", "w") -# f.write(str(mg)) -# f.close() - -convert_pass_args = { - "by": "regex_name", - "roberta\.encoder\.layer\.\d+\.attention\.self": { - "config": { - "name": "zip_tf", - "level": 32, - "neuron_type": "ST-BIF", - }, - }, -} -mg, _ = ann2snn_module_transform_pass(mg, convert_pass_args) - -convert_pass_args = { - "by": "type", - "embedding": { - "config": { - "name": "zip_tf", - }, - }, - "linear": { - "config": { - "name": "unfold_bias", - "level": 32, - "neuron_type": "ST-BIF", - }, - }, - "conv2d": { - "config": { - "name": "zip_tf", - "level": 32, - "neuron_type": "ST-BIF", - }, - }, - "layernorm": { - "config": { - "name": "zip_tf", - }, - }, - "relu": { - "manual_instantiate": True, - "config": { - "name": "identity", - }, - }, - "lsqinteger": { - "manual_instantiate": True, - "config": { - "name": "st_bif", - # Default values. These would be replaced by the values from the LSQInteger module, so it has no effect. - # "q_threshold": 1, - # "level": 32, - # "sym": True, - }, - }, -} -mg, _ = ann2snn_module_transform_pass(mg, convert_pass_args) - -# f = open(f"spiking_model_arch.txt", "w") -# f.write(str(mg)) -# f.close()