Source code for qualia_plugin_snn.preprocessing.Split2TimeSteps

"""Provide a preprocessing module to split data into timesteps."""

import sys

from qualia_core.datamodel import RawDataModel
from qualia_core.preprocessing.Preprocessing import Preprocessing

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

class Split2TimeSteps(Preprocessing[RawDataModel, RawDataModel]):
    """Preprocessing module to split 1D input dataset into multiple timesteps."""

    def __init__(self, chunks: int) -> None:
        """Construct :class:`qualia_plugin_snn.preprocessing.Split2TimeSteps.Split2TimeSteps`.

        :param chunks: Number of chunks (timesteps) to split the data into, must be at least 1
        :raises ValueError: When ``chunks`` is lower than 1
        """
        # Fail early with an explicit message: a zero value would otherwise only surface
        # as a ZeroDivisionError once the module is called, and a negative value
        # would produce an invalid target shape.
        if chunks < 1:
            raise ValueError(f'chunks must be a positive integer, got {chunks!r}')

        super().__init__()
        self.__chunks = chunks

    @override
    def __call__(self, datamodel: RawDataModel) -> RawDataModel:
        """Split the given :class:`qualia_core.datamodel.RawDataModel.RawDataModel` into multiple timesteps.

        Input data should be 1D (+ channel) with [N, S, C] order (channels_last).
        Output data has [N, T, S // T, C] dimensions, with T the number of chunks.
        Extra data that do not fit in a chunk is truncated.

        The ``x`` attribute of each set yielded by iterating over ``datamodel`` is replaced in-place,
        and the same ``datamodel`` object is returned.

        :param datamodel: The input dataset
        :return: The dataset with additional timestep dimension
        """
        chunks = self.__chunks

        for _, s in datamodel:
            # Number of samples of the S axis that end up in each timestep
            samples_per_chunk = s.x.shape[1] // chunks

            # Drop the trailing samples that cannot fill a complete chunk
            s.x = s.x[:, :samples_per_chunk * chunks, :]

            # [N, S, C] -> [N, T, S // T, C], any dimension after S is kept as-is
            s.x = s.x.reshape((s.x.shape[0], chunks, samples_per_chunk) + s.x.shape[2:])

        return datamodel