Source code for qualia_plugin_snn.learningmodel.pytorch.QuantizedSMLP

"""Contains the template for a quantized spiking multi-layer perceptron."""

from __future__ import annotations

import math
import sys
from collections import OrderedDict

from qualia_core.typing import TYPE_CHECKING
from torch import nn

from .SNN import SNN

if TYPE_CHECKING:
    import torch
    from qualia_core.learningmodel.pytorch.Quantizer import QuantizationConfig
    from qualia_core.typing import RecursiveConfigDict

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

[docs] class QuantizedSMLP(SNN): """Quantized spiking multi-layer perceptron template. Should have topology identical to :class:`qualia_plugin_snn.learningmodel.pytorch.SMLP.SMLP` but with layers replaced with their quantized equivalent. """
[docs] def __init__(self, # noqa: PLR0913 input_shape: tuple[int, ...], output_shape: tuple[int, ...], units: list[int], quant_params: QuantizationConfig, timesteps: int, neuron: RecursiveConfigDict | None = None) -> None: """Construct :class:`QuantizedSMLP`. :param input_shape: Input shape :param output_shape: Output shape :param units: List of :class:`torch.nn.Linear` layer ``out_features`` to add in the network :param quant_params: Quantization configuration dict, see :class:`qualia_core.learningmodel.pytorch.Quantizer.Quantizer` :param neuron: Spiking neuron configuration, see :meth:`qualia_plugin_snn.learningmodel.pytorch.SNN.SNN.__init__` :param timesteps: Number of timesteps """ # Prepend Quantized to the neuron kind to instantiate quantized spiking neurons if neuron is not None and 'kind' in neuron and isinstance(neuron['kind'], str): neuron['kind'] = 'Quantized' + neuron['kind'] super().__init__(input_shape=input_shape, output_shape=output_shape, timesteps=timesteps, neuron=neuron) from spikingjelly.activation_based.layer import Flatten # type: ignore[import-untyped] from qualia_plugin_snn.learningmodel.pytorch.layers.spikingjelly.quantized_layers import QuantizedLinear layers: OrderedDict[str, nn.Module] = OrderedDict() layers['flatten1'] = Flatten(step_mode=self.step_mode) i = 1 for in_units, out_units in zip([math.prod(input_shape), *units[:-1]], units): layers[f'fc{i}'] = QuantizedLinear(in_units, out_units, quant_params=quant_params, step_mode=self.step_mode) layers[f'neuron{i}'] = self.create_neuron(quant_params=quant_params) i += 1 layers[f'fc{i}'] = QuantizedLinear(units[-1] if len(units) > 1 else math.prod(input_shape), output_shape[0], quant_params=quant_params, step_mode=self.step_mode) self.layers = nn.ModuleDict(layers)
[docs] @override def forward(self, input: torch.Tensor) -> torch.Tensor: """Forward calls each of the MLP :attr:`layers` sequentially. :param input: Input tensor :return: Output tensor """ x = input for layer in self.layers: x = self.layers[layer](x) return x