"""Provide a preprocessing module to split data into timesteps."""
import sys
from qualia_core.datamodel import RawDataModel
from qualia_core.preprocessing.Preprocessing import Preprocessing
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
[docs]
class Split2TimeSteps(Preprocessing[RawDataModel, RawDataModel]):
"""Preprocessing module to split 1D input dataset into multiple timesteps."""
[docs]
def __init__(self, chunks: int) -> None:
"""Construct :class:`qualia_plugin_snn.preprocessing.Split2TimeSteps.Split2TimeSteps`.
:param chunks: Number of chunks to split the data into
"""
super().__init__()
self.__chunks = chunks
[docs]
@override
def __call__(self, datamodel: RawDataModel) -> RawDataModel:
"""Split the given :class:`qualia_core.datamodel.RawDataModel.RawDataModel` into multiple timesteps.
Input data should be 1D (+ channel) with [N, S, C] order (channels_last).
Output data has [N, T, S // T, C] dimensions
Extra data that do not fit in a chunk is truncated.
:param datamodel: The input dataset
:return: The dataset with additional timestep dimension
"""
for _, s in datamodel:
truncated_dim = (s.x.shape[1] // self.__chunks) * self.__chunks
s.x = s.x[:,:truncated_dim,:]
s.x = s.x.reshape((s.x.shape[0], self.__chunks, s.x.shape[1] // self.__chunks) + s.x.shape[2:])
return datamodel