Source code for qualia_plugin_snn.learningframework.SpikingJellyTimeStepsInData

"""Provide the SpikingJelly single-step with timesteps in input data learningframework module."""

from __future__ import annotations

import logging
import sys
from typing import Any

import torch
from qualia_core.typing import TYPE_CHECKING
from spikingjelly.activation_based import functional  # type: ignore[import-untyped]

from qualia_plugin_snn.dataaugmentation.pytorch.DataAugmentationPyTorchTimeStepsInData import (
    DataAugmentationPyTorchTimeStepsInData,
)

from .SpikingJelly import SpikingJelly

if TYPE_CHECKING:
    import numpy as np
    from qualia_core.dataaugmentation.pytorch.DataAugmentationPyTorch import DataAugmentationPyTorch
    from qualia_core.datamodel.RawDataModel import RawData

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

logger = logging.getLogger(__name__)

[docs] class SpikingJellyTimeStepsInData(SpikingJelly): """SpikingJelly single-step with timesteps in data LearningFramework implementation extending SpikingJelly single-step."""
[docs] class TrainerModule(SpikingJelly.TrainerModule): """SpikingJelly single-step with timesteps in data TrainerModule extending SpikingJelly single-step TrainerModule."""
[docs] @override def apply_dataaugmentation(self, batch: tuple[torch.Tensor, torch.Tensor], dataaugmentation: DataAugmentationPyTorch) -> tuple[torch.Tensor, torch.Tensor]: """Call a dataaugmentation module on the current batch. If the dataaugmentation module is not a :class:`qualia_plugin_snn.dataaugmentation.pytorch.DataAugmentationPyTorchTimeStepsInData`, or if :attr:`qualia_plugin_snn.dataaugmentation.pytorch.DataAugmentationPyTorchTimeStepsInData.collapse_timesteps` is ``True``, then the timestep dimension is automatically merged with the batch dimension in order to support existing non-timestep-aware Qualia-Core dataaugmentation modules transparently. :param batch: Current batch of data and targets :param dataaugmentation: dataaugmentation module to apply :return: Augmented batch """ if isinstance(dataaugmentation, DataAugmentationPyTorchTimeStepsInData) and not dataaugmentation.collapse_timesteps: return super().apply_dataaugmentation(batch, dataaugmentation) # Reshape to merge timestep and batch dimensions x, y = batch shape = x.shape x = x.reshape(*(shape[0] * shape[1], *shape[2:])) x, y = super().apply_dataaugmentation((x, y), dataaugmentation) x = x.reshape(*(shape[0], shape[1], *x.shape[1:])) return x, y
[docs] @override def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass for a Spiking Neural Network model with timesteps in input data in single-step mode. First calls SpikingJelly's reset on the model to reset neurons potentials. Call :meth:`qualia_plugin_snn.learningmodel.pytorch.SNN.SNN.forward` for each timestep of the input data. Finally, average the output of the model over the timesteps. :param x: Input data with timestep dimension in [N, T, C, S] or [N, T, C, H, W] order :return: Output predictions :raise ValueError: when the input data does not have the correct number of dimenions or the timestep dimension does not match :attr:`qualia_plugin_snn.learningmodel.pytorch.SNN.SNN.timesteps` """ functional.reset_net(self.model) # Switch timestep dim from 2nd to 1st place x = x.swapaxes(0, 1) input_timesteps = list(x) # Split timestep dims into multiple arrays if len(input_timesteps) != self.model.timesteps: logger.error('Model.timesteps differs from timesteps dimension in data: %s != %s', self.model.timesteps, len(input_timesteps)) raise ValueError x_list = [self.model(t) for t in input_timesteps] x = torch.stack(x_list).sum(0) return x / len(input_timesteps)
[docs] @override @staticmethod def channels_last_to_channels_first(x: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: """Channels last to channels first conversion with consideration of timestep dimension. For 2D data with timestep: [N, T, H, W, C] → [N, T, C, H, W] For 1D data with timestep: [N, T, S, C] → [N, T, C, S] :param x: NumPy array in channels last format :return: NumPy array reordered to channels first format """ if len(x.shape) == 5: # N, T, H, W, C → N, T, C, H, W # noqa: PLR2004 x = x.transpose(0, 1, 4, 2, 3) elif len(x.shape) == 4: # N, T, S, C → N, T, C, S # noqa: PLR2004 x = x.swapaxes(2, 3) else: logger.error('Unsupported number of axes in dataset: %s, must be 4 or 5', len(x.shape)) raise ValueError return x
[docs] @override @staticmethod def channels_first_to_channels_last(x: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: """Channels first to channels last conversion with consideration of timestep dimension. For 2D data with timestep: [N, T, C, H, W] → [N, T, H, W, C] For 1D data with timestep: [N, T, C, S] → [N, T, S, C] :param x: NumPy array in channels first format :return: NumPy array reordered to channels last format """ if len(x.shape) == 5: # N, T, C, H, W → N, T, H, W, C # noqa: PLR2004 x = x.transpose(0, 1, 3, 4, 2) elif len(x.shape) == 4: # N, T, C, S → N, T, S, C # noqa: PLR2004 x = x.swapaxes(3, 2) else: logger.error('Unsupported number of axes in dataset: %s, must be 4 or 5', len(x.shape)) raise ValueError return x
[docs] class DatasetFromArray(SpikingJelly.DatasetFromArray): """Override :class:`qualia_core.learningframework.PyTorch.PyTorch.DatasetFromArray` to reorder data with timestep dim."""
[docs] def __init__(self, dataset: RawData) -> None: """Load a Qualia dataset partition as a PyTorch dataset. Reorders the input data from channels_last to channels_first, with consideration of timestep dimension. :param dataset: Dataset partition to load """ self.x = SpikingJellyTimeStepsInData.channels_last_to_channels_first(dataset.x) self.y = dataset.y