Source code for qualia_plugin_snn.learningmodel.pytorch.SResNet

"""Contains the template for a residual spiking neural network."""

from __future__ import annotations

import logging
import sys
from typing import Callable, Protocol, cast

import numpy as np
import torch
from qualia_core.learningmodel.pytorch.layers import Add
from qualia_core.typing import TYPE_CHECKING
from torch import nn

from qualia_plugin_snn.learningmodel.pytorch.layers.spikingjelly import layers1d as sjlayers1d
from qualia_plugin_snn.learningmodel.pytorch.layers.spikingjelly import layers2d as sjlayers2d

from .SNN import SNN

if TYPE_CHECKING:
    from types import ModuleType  # noqa: TCH003

    from qualia_core.typing import RecursiveConfigDict

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

logger = logging.getLogger(__name__)

[docs] class BasicBlockBuilder(Protocol): """Signature for basicblockbuilder. Used to bind hyperparameters constant across all the ResNet blocks. """
[docs] def __call__(self, # noqa: PLR0913 in_planes: int, planes: int, kernel_size: int, stride: int, padding: int) -> BasicBlock: """Build a :class:`BasicBlock`. :param in_planes: Number of input channels :param planes: Number of filters (i.e., output channels) in the main branch Conv layers :param kernel_size: ``kernel_size`` for the main branch Conv layers :param stride: ``kernel_size`` for the MaxPool layers, no MaxPool layer added if `1` :param padding: Padding for the main branch Conv layers :return: A :class:`BasicBlock` """ ...
[docs] class BasicBlock(nn.Module): r"""A single ResNetv1 block. Structure is: .. code-block:: | / \ | | Conv | | | BatchNorm | | | MaxPool Conv | | IF BatchNorm | | Conv MaxPool | | BatchNorm IF | | IF | | | \ / | Add Main (left) branch Conv use ``kernel_size=kernel_size``, while residual (right) branch Conv use ``kernel_size=1``. BatchNorm layers will be absent if ``batch_norm == False`` MaxPool layer will be absent if ``stride == 1``. Residual (right) branch Conv layer will be asbent if ``in_planes==planes``, except if ``force_projection_with_stride==True`` and ``stride != 1``. """ expansion: int = 1 #: Unused
[docs] def __init__(self, # noqa: PLR0913 sjlayers_t: ModuleType, in_planes: int, planes: int, kernel_size: int, stride: int, padding: int, batch_norm: bool, # noqa: FBT001 bn_momentum: float, force_projection_with_stride: bool, # noqa: FBT001 create_neuron: Callable[[], nn.Module], step_mode: str) -> None: """Construct :class:`BasicBlock`. :param sjlayers_t: Module containing the aliased layers to use (1D or 2D) :param in_planes: Number of input channels :param planes: Number of filters (i.e., output channels) in the main branch Conv layers :param kernel_size: ``kernel_size`` for the main branch Conv layers :param stride: ``kernel_size`` for the MaxPool layers, no MaxPool layer added if `1` :param padding: Padding for the main branch Conv layers :param batch_norm: If ``True``, add BatchNorm layer after each Conv layer :param bn_momentum: BatchNorm layer ``momentum`` :param force_projection_with_stride: If ``True``, residual Conv layer is kept when ``stride != 1`` even if ``in_planes == planes`` :param create_neuron: :meth:`qualia_plugin_snn.learningmodel.pytorch.SNN.SNN.create_neuron` method to instantiate a spiking neuron :param step_mode: SpikingJelly ``step_mode`` from :attr:`qualia_plugin_snn.learningmodel.pytorch.SNN.SNN.step_mode` """ super().__init__() self.expansion = 1 self.batch_norm = batch_norm self.force_projection_with_stride = force_projection_with_stride self.in_planes = in_planes self.planes = planes self.stride = stride self.conv1 = sjlayers_t.Conv(in_planes, planes, kernel_size=kernel_size, stride=1, padding=padding, bias=not batch_norm, step_mode=step_mode) if batch_norm: self.bn1 = sjlayers_t.BatchNorm(planes, momentum=bn_momentum, step_mode=step_mode) if self.stride != 1: self.pool1 = sjlayers_t.MaxPool(stride, step_mode=step_mode) # QUANTIZED NEURON self.neuron1 = create_neuron() self.conv2 = sjlayers_t.Conv(planes, planes, kernel_size=kernel_size, stride=1, padding=padding, bias=not batch_norm, step_mode=step_mode) if batch_norm: self.bn2 = sjlayers_t.BatchNorm(planes, momentum=bn_momentum, step_mode=step_mode) self.neuron2 = create_neuron() # residual path if self.stride != 1: self.smax = sjlayers_t.MaxPool(stride, step_mode=step_mode) if self.in_planes != self.expansion*self.planes or force_projection_with_stride and self.stride != 1: self.sconv = sjlayers_t.Conv(in_planes, self.expansion*planes, kernel_size=1, stride=1, bias=not batch_norm, step_mode=step_mode) self.neuronr = create_neuron() if batch_norm: self.sbn = sjlayers_t.BatchNorm(self.expansion*planes, momentum=bn_momentum, step_mode=step_mode) self.add = Add()
[docs] @override def forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: A002 """Forward of ResNet block. :param input: Input tensor :return: Output tensor """ x = input out = self.conv1(x) if self.batch_norm: out = self.bn1(out) if self.stride != 1: out = self.pool1(out) out = self.neuron1(out) out = self.conv2(out) if self.batch_norm: out = self.bn2(out) out = self.neuron2(out) # shortcut tmp = x if self.in_planes != self.expansion*self.planes or self.force_projection_with_stride and self.stride != 1: tmp = self.sconv(tmp) if self.batch_norm: tmp = self.sbn(tmp) if self.stride != 1: tmp = self.smax(tmp) if self.in_planes != self.expansion*self.planes or self.force_projection_with_stride and self.stride != 1: tmp = self.neuronr(tmp) return self.add(out, tmp)
[docs] class SResNet(SNN): """Residual spiking neural network template. Similar to :class:`qualia_core.learningmodel.pytorch.ResNet.ResNet` but with spiking neuron activation layers (e.g., IF) instead of :class:`torch.nn.ReLU`. Example TOML configuration for a 2D ResNetv1-18 over 4 timesteps with soft-reset multi-step IF based on the SResNet template: .. code-block:: toml [[model]] name = "SResNetv1-18" params.filters = [64, 64, 128, 256, 512] params.kernel_sizes = [ 7, 3, 3, 3, 3] params.paddings = [ 3, 1, 1, 1, 1] params.strides = [ 2, 1, 1, 1, 1] params.num_blocks = [ 2, 2, 2, 2] params.prepool = 1 params.postpool = 'max' params.batch_norm = true params.dims = 2 params.timesteps = 4 params.neuron.kind = 'IFNode' params.neuron.params.v_reset = false # Soft reset params.neuron.params.v_threshold = 1.0 params.neuron.params.detach_reset = true params.neuron.params.step_mode = 'm' # Multi-step mode, make sure to use SpikingJellyMultiStep learningframework params.neuron.params.backend = 'torch' """
[docs] def __init__(self, # noqa: PLR0913, C901 input_shape: tuple[int, ...], output_shape: tuple[int, ...], filters: list[int], kernel_sizes: list[int], num_blocks: list[int], strides: list[int], paddings: list[int], prepool: int = 1, postpool: str = 'max', batch_norm: bool = False, # noqa: FBT001, FBT002 bn_momentum: float = 0.1, force_projection_with_stride: bool = True, # noqa: FBT001, FBT002 neuron: RecursiveConfigDict | None = None, timesteps: int = 2, dims: int = 1, basicblockbuilder: BasicBlockBuilder | None = None) -> None: """Construct :class:`SResNet`. Structure is: .. code-block:: Input | AvgPool | Conv | BatchNorm | IF | BasicBlock | | BasicBlock | GlobalPool | Flatten | Linear :param input_shape: Input shape :param output_shape: Output shape :param filters: List of ``out_channels`` for Conv layers inside each :class:`BasicBlock` group, must be of the same size as ``num_blocks``, first element is for the first Conv layer at the beginning of the network :param kernel_sizes: List of ``kernel_size`` for Conv layers inside each :class:`BasicBlock` group, must of the same size as ``num_blocks``, first element is for the first Conv layer at the beginning of the network :param num_blocks: List of number of :class:`BasicBlock` in each group, also defines the number of :class:`BasicBlock` groups inside the network :param strides: List of ``kernel_size`` for MaxPool layers inside each :class:`BasicBlock` group, must of the same size as ``num_blocks``, ``stride`` is applied only to the first :class:`BasicBlock` of the group, next :class:`BasicBlock` in the group use a ``stride`` of ``1``, first element is the stride of the first Conv layer at the beginning of the network :param paddings: List of ``padding`` for Conv layer inside each :class:`BasicBlock` group, must of the same size as ``num_blocks``, first element is for the first Conv layer at the beginning of the network :param prepool: AvgPool layer ``kernel_size`` to add at the beginning of the network, no layer added if 0 :param postpool: Global pooling layer type after all :class:`BasicBlock`, either `max` for MaxPool or `avg` for AvgPool :param batch_norm: If ``True``, add a BatchNorm layer after each Conv layer, otherwise no layer added :param bn_momentum: BatchNorm ``momentum`` :param force_projection_with_stride: If ``True``, residual Conv layer is kept when ``stride != 1`` even if ``in_planes == planes`` inside a :class:`BasicBlock` :param neuron: Spiking neuron configuration, see :meth:`qualia_plugin_snn.learningmodel.pytorch.SNN.SNN.__init__` :param timesteps: Number of timesteps :param dims: Either 1 or 2 for 1D or 2D convolutional network. :param basicblockbuilder: Optional function with :meth:`BasicBlockBuilder.__call__` signature to build a basic block after binding constants common across all basic blocks """ super().__init__(input_shape=input_shape, output_shape=output_shape, timesteps=timesteps, neuron=neuron) from spikingjelly.activation_based.layer import Flatten, Linear # type: ignore[import-untyped] sjlayers_t: ModuleType if dims == 1: sjlayers_t = sjlayers1d elif dims == 2: # noqa: PLR2004 sjlayers_t = sjlayers2d else: logger.error('Only dims=1 or dims=2 supported, got: %s', dims) raise ValueError if basicblockbuilder is None: def builder(in_planes: int, planes: int, kernel_size: int, stride: int, padding: int) -> BasicBlock: return BasicBlock( sjlayers_t=sjlayers_t, in_planes=in_planes, planes=planes, kernel_size=kernel_size, stride=stride, padding=padding, batch_norm=batch_norm, bn_momentum=bn_momentum, force_projection_with_stride=force_projection_with_stride, create_neuron=self.create_neuron, step_mode=self.step_mode) basicblockbuilder = builder self.in_planes: int = filters[0] self.batch_norm = batch_norm self.num_blocks = num_blocks if prepool > 1: self.prepool = sjlayers_t.AvgPool(prepool, step_mode=self.step_mode) self.conv1 = sjlayers_t.Conv(input_shape[-1], filters[0], kernel_size=kernel_sizes[0], stride=strides[0], padding=paddings[0], bias=not batch_norm, step_mode=self.step_mode) if self.batch_norm: self.bn1 = sjlayers_t.BatchNorm(self.in_planes, momentum=bn_momentum, step_mode=self.step_mode) self.neuron1 = self.create_neuron() blocks: list[nn.Sequential] = [] for planes, kernel_size, stride, padding, num_block in zip(filters[1:], kernel_sizes[1:], strides[1:], paddings[1:], num_blocks): blocks.append(self._make_layer(basicblockbuilder, num_block, planes, kernel_size, stride, padding)) self.layers = nn.ModuleList(blocks) # GlobalMaxPool kernel_size computation self._fm_dims = np.array(input_shape[:-1]) // np.array(prepool) for _, kernel, stride, padding in zip(filters, kernel_sizes, strides, paddings): self._fm_dims += np.array(padding) * 2 self._fm_dims -= (kernel - 1) self._fm_dims = np.floor(self._fm_dims / stride).astype(int) if postpool == 'avg': self.postpool = sjlayers_t.AvgPool(tuple(self._fm_dims), step_mode=self.step_mode) elif postpool == 'max': self.postpool = sjlayers_t.MaxPool(tuple(self._fm_dims), step_mode=self.step_mode) self.flatten = Flatten(step_mode=self.step_mode) self.linear = Linear(self.in_planes*BasicBlock.expansion, output_shape[0], step_mode=self.step_mode)
def _make_layer(self, # noqa: PLR0913 basicblockbuilder: BasicBlockBuilder, num_blocks: int, planes: int, kernel_size: int, stride: int, padding: int) -> nn.Sequential: strides = [stride] + [1]*(num_blocks-1) layers: list[BasicBlock] = [] for stride in strides: block: BasicBlock = basicblockbuilder(in_planes=self.in_planes, planes=planes, kernel_size=kernel_size, stride=stride, padding=padding) layers.append(block) self.in_planes = planes * block.expansion return nn.Sequential(*layers)
[docs] @override def forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: A002 """Forward of residual spiking neural network. :param input: Input tensor :return: Output tensor """ x = input if hasattr(self, 'prepool'): x = self.prepool(x) out = self.conv1(x) if self.batch_norm: out = self.bn1(out) out = self.neuron1(out) for i in range(len(self.layers)): out = self.layers[i](out) if hasattr(self, 'postpool'): out = self.postpool(out) out = self.flatten(out) return cast(torch.Tensor, self.linear(out))