Source code for qualia_plugin_snn.learningframework.SpikingJellyMultiStep

"""Provide the SpikingJellyMultiStep learningframework module."""

import sys

import torch
from spikingjelly.activation_based import functional  # type: ignore[import-untyped]

from .SpikingJelly import SpikingJelly

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

[docs] class SpikingJellyMultiStep(SpikingJelly): """SpikingJelly multi-step LearningFramework implementation extending SpikingJelly single-step."""
[docs] class TrainerModule(SpikingJelly.TrainerModule): """SpikingJelly multi-step TrainerModule implementation extending SpikingJelly single-step TrainerModule."""
[docs] @override def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass for a Spiking Neural Network model with duplicated timesteps in multi-step mode. First calls SpikingJelly's reset on the model to reset neurons potentials. Then duplicate the input to generate the number of timesteps given by :attr:`qualia_plugin_snn.learningmodel.pytorch.SNN.SNN.timesteps`. Timesteps are generated as a new dimension of the input tensor: [N, C, …] → [T, N, C, …] Call :meth:`qualia_plugin_snn.learningmodel.pytorch.SNN.SNN.forward` for the whole tensor. Finally, average the output of the model over the timestep dimension. :param x: Input data :return: Output predictions """ functional.reset_net(self.model) # Duplicate input over timesteps x = x.unsqueeze(0).repeat(self.model.timesteps, *([1] * len(x.shape))) return self.model(x).sum(0) / self.model.timesteps