Source code for qualia_plugin_snn.learningmodel.pytorch.SMLP

"""Contains the template for a spiking multi-layer perceptron."""

from __future__ import annotations

import math
import sys
from collections import OrderedDict

from qualia_core.typing import TYPE_CHECKING
from torch import nn

from .SNN import SNN

if TYPE_CHECKING:
    import torch
    from qualia_core.typing import RecursiveConfigDict

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

class SMLP(SNN):
    """Spiking multi-layer perceptron template.

    Similar to :class:`qualia_core.learningmodel.pytorch.MLP.MLP` but with spiking neuron activation layers (e.g., IF)
    instead of :class:`torch.nn.ReLU`.

    Last :class:`torch.nn.Linear` layer matching number of output classes is implicitly added.

    Example TOML configuration for a 3-layer spiking MLP over 4 timesteps with soft-reset multi-step IF based on the SMLP
    template:

    .. code-block:: toml

        [[model]]
        kind = "SMLP"
        name = "smlp_128-128-10"
        params.units = [128, 128]
        params.timesteps = 4
        params.neuron.kind = 'IFNode'
        params.neuron.params.v_reset = false # Soft reset
        params.neuron.params.v_threshold = 1.0
        params.neuron.params.detach_reset = true
        params.neuron.params.step_mode = 'm' # Multi-step mode, make sure to use SpikingJellyMultiStep learningframework
        params.neuron.params.backend = 'torch'
    """

    def __init__(self,
                 input_shape: tuple[int, ...],
                 output_shape: tuple[int, ...],
                 units: list[int],
                 timesteps: int,
                 neuron: RecursiveConfigDict | None = None) -> None:
        """Construct :class:`SMLP`.

        The resulting network is ``flatten1``, then for each entry ``u`` of ``units`` a ``fc<i>`` Linear layer with
        ``u`` output features followed by a ``neuron<i>`` spiking neuron, and finally a ``fc<len(units) + 1>`` Linear
        layer with ``output_shape[0]`` output features.

        :param input_shape: Input shape, flattened to ``math.prod(input_shape)`` features for the first Linear layer
        :param output_shape: Output shape, ``output_shape[0]`` is the number of output features of the last layer
        :param units: List of :class:`torch.nn.Linear` layer ``out_features`` to add in the network, may be empty
        :param timesteps: Number of timesteps
        :param neuron: Spiking neuron configuration, see :meth:`qualia_plugin_snn.learningmodel.pytorch.SNN.SNN.__init__`
        """
        super().__init__(input_shape=input_shape, output_shape=output_shape, timesteps=timesteps, neuron=neuron)

        from spikingjelly.activation_based.layer import Flatten, Linear  # type: ignore[import-untyped]

        # Number of features fed to the next Linear layer; starts as the flattened input size and is updated
        # to each hidden layer's width so the final layer is always wired to the last hidden layer's output.
        in_features = math.prod(input_shape)

        layers: OrderedDict[str, nn.Module] = OrderedDict()
        layers['flatten1'] = Flatten(step_mode=self.step_mode)

        for i, out_features in enumerate(units, start=1):
            layers[f'fc{i}'] = Linear(in_features, out_features, step_mode=self.step_mode)
            layers[f'neuron{i}'] = self.create_neuron()
            in_features = out_features

        # Output layer: previously used the flattened input size whenever len(units) == 1, which mismatched
        # the single hidden layer's output width.
        layers[f'fc{len(units) + 1}'] = Linear(in_features, output_shape[0], step_mode=self.step_mode)

        self.layers = nn.ModuleDict(layers)

    @override
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Forward calls each of the MLP :attr:`layers` sequentially.

        :param input: Input tensor
        :return: Output tensor
        """
        x = input
        # ModuleDict preserves insertion order, so layers run in the order they were built.
        for layer in self.layers.values():
            x = layer(x)
        return x