Source code for qualia_plugin_snn.learningframework.SpikingJellyMultiStepTimeStepsInData

"""Provide the SpikingJelly multi-step with timesteps in input data learningframework module."""

from __future__ import annotations

import logging
import sys

from qualia_core.typing import TYPE_CHECKING
from spikingjelly.activation_based import functional  # type: ignore[import-untyped]

from .SpikingJellyTimeStepsInData import SpikingJellyTimeStepsInData

if TYPE_CHECKING:
    import torch

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

logger = logging.getLogger(__name__)

class SpikingJellyMultiStepTimeStepsInData(SpikingJellyTimeStepsInData):
    """SpikingJelly multi-step with timesteps in data LearningFramework implementation.

    Extends the SpikingJelly single-step with timesteps in data implementation
    (:class:`SpikingJellyTimeStepsInData`).
    """

    class TrainerModule(SpikingJellyTimeStepsInData.TrainerModule):
        """SpikingJelly multi-step with timesteps in data TrainerModule.

        Extends the SpikingJelly single-step with timesteps in data TrainerModule; only
        :meth:`forward` is overridden.
        """

        @override
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass for a Spiking Neural Network model with timesteps in input data in multi-step mode.

            First calls SpikingJelly's reset on the model to reset neurons potentials.
            Then moves the timestep dimension to the front and calls
            :meth:`qualia_plugin_snn.learningmodel.pytorch.SNN.SNN.forward` once on the whole
            sequence (multi-step mode), instead of once per timestep.
            Finally, averages the output of the model over the timesteps.

            :param x: Input data with timestep dimension in [N, T, C, S] or [N, T, C, H, W] order
            :return: Output predictions averaged over the timestep dimension
            :raise ValueError: when the timestep dimension of the input data does not match
                :attr:`qualia_plugin_snn.learningmodel.pytorch.SNN.SNN.timesteps`
            """
            # Neuron state must not leak from the previous sample/batch.
            functional.reset_net(self.model)

            # Switch timestep dim from 2nd to 1st place, [N, T, C, H, W] → [T, N, C, H, W].
            # NOTE(review): assumes the model was put in SpikingJelly multi-step ('m') mode
            # elsewhere, since that mode expects the time dimension first — confirm.
            x = x.swapaxes(0, 1)

            timesteps = self.model.timesteps
            if x.shape[0] != timesteps:
                logger.error('Model.timesteps differs from timesteps dimension in data: %s != %s',
                             timesteps,
                             x.shape[0])
                msg = f'Model.timesteps differs from timesteps dimension in data: {timesteps} != {x.shape[0]}'
                raise ValueError(msg)

            # Sum over the timestep dimension (dim 0) and divide by the timestep count,
            # i.e. average the model output over time.
            return self.model(x).sum(0) / timesteps