Source code for qualia_plugin_snn.preprocessing.IntegrateEventsByFixedFramesNumber

"""Provide a preprocessing module to construct fixed number of frames from a sample of event data."""

from __future__ import annotations

import logging
import sys
import time
from typing import Any

import numpy as np
from qualia_core.datamodel.RawDataModel import RawData, RawDataModel, RawDataSets
from qualia_core.learningframework.PyTorch import PyTorch
from qualia_core.typing import TYPE_CHECKING
from spikingjelly.datasets import cal_fixed_frames_number_segment_index  # type: ignore[import-untyped]

from .IntegrateEventsByFixedDuration import IntegrateEventsByFixedDuration

if TYPE_CHECKING:
    from qualia_core.dataset.Dataset import Dataset

    from qualia_plugin_snn.datamodel.EventDataModel import EventDataModel

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

logger = logging.getLogger(__name__)


class IntegrateEventsByFixedFramesNumber(IntegrateEventsByFixedDuration):
    """Preprocessing module to construct a fixed number of frames from a sample of event data."""
    def __init__(self, split_by: str, frames_num: int) -> None:
        """Construct :class:`qualia_plugin_snn.preprocessing.IntegrateEventsByFixedFramesNumber.IntegrateEventsByFixedFramesNumber`.

        :param split_by: ``'time'`` or ``'number'``, see :meth:`spikingjelly.datasets.cal_fixed_frames_number_segment_index`
        :param frames_num: Number of frames to produce per sample
        """  # noqa: E501
        super().__init__(0)
        self.__split_by = split_by
        self.__frames_num = frames_num
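
    # Usage sketch (illustrative, not part of the original module): split each sample
    # into 16 frames holding a roughly equal number of events, or covering equal time
    # windows, respectively:
    #   IntegrateEventsByFixedFramesNumber(split_by='number', frames_num=16)
    #   IntegrateEventsByFixedFramesNumber(split_by='time', frames_num=16)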

    # Adapted from SpikingJelly
    def __integrate_events_by_fixed_frames_number(self,
                                                  events: np.recarray[Any, Any],
                                                  labels: np.ndarray[Any, Any],
                                                  h: int,
                                                  w: int) -> tuple[np.ndarray[Any, np.dtype[np.float32]],
                                                                   np.ndarray[Any, Any]]:
        """Integrate events into a fixed number of frames.

        Adapted from SpikingJelly :meth:`spikingjelly.datasets.integrate_events_by_fixed_frames_number`

        Additionally reduces labels of each event over the integration window.

        :param events: Collection of events with (t, x, y, p) columns
        :param labels: Labels for each event
        :param h: Height of the frame
        :param w: Width of the frame
        :return: Frame after integration and labels after reduction
        """
        j_l, j_r = cal_fixed_frames_number_segment_index(events.t, self.__split_by, self.__frames_num)

        if not hasattr(events, 'y'):  # 1D data
            frames = np.zeros([self.__frames_num, 2, w], dtype=np.float32)
        else:
            frames = np.zeros([self.__frames_num, 2, h, w], dtype=np.float32)

        reduced_labels: np.ndarray[Any, Any] = np.ndarray((self.__frames_num,), dtype=labels.dtype)

        for i in range(self.__frames_num):
            frames[i] = self._integrate_events_segment_to_frame(events, h, w, j_l[i], j_r[i])
            # Find most common label in frame
            values, counts = np.unique(labels[j_l[i]:j_r[i]], return_counts=True)
            reduced_labels[i] = values[counts.argmax()]

        return frames, reduced_labels
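
    # Note on segment semantics (hedged summary of SpikingJelly's
    # cal_fixed_frames_number_segment_index): split_by='number' yields frames_num
    # windows with a roughly equal event count each, while split_by='time' yields
    # frames_num windows of equal duration, whose event counts can differ widely
    # for bursty event streams.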

    @override
    def __call__(self, datamodel: EventDataModel) -> RawDataModel:
        """Construct frames from events of the same sample of a :class:`qualia_plugin_snn.datamodel.EventDataModel.EventDataModel`.

        Relies on sample indices (begin, end) from the info array to only collect events from the same sample.

        Input data should be 2D event data with (t, x, y, p) columns or 1D event data with (t, x, p) columns.
        Output data has [N, H, W, C] or [N, W, C] dimensions.

        Uses a derivative of SpikingJelly's implementation of
        :meth:`spikingjelly.datasets.integrate_events_by_fixed_frames_number`

        :param datamodel: The input event-based dataset
        :return: The new frame dataset
        """
        start = time.time()

        sets: dict[str, RawData] = {}

        for name, s in datamodel:
            if s.info is None:
                logger.error('Info for %s must be defined', name)
                raise ValueError

            data_list: list[np.ndarray[Any, np.dtype[np.float32]]] = []
            labels_list: list[np.ndarray[Any, Any]] = []
            info_list: list[tuple[int, int]] = []
            first = 0
            last = 0

            for sample in s.info:  # For each input sample
                data, labels = self.__integrate_events_by_fixed_frames_number(events=s.x[sample.begin:sample.end],
                                                                              labels=s.y[sample.begin:sample.end],
                                                                              h=datamodel.h,
                                                                              w=datamodel.w)

                # channels_first to channels_last: N, C, H, W → N, H, W, C or N, C, S → N, S, C
                data = PyTorch.channels_first_to_channels_last(data)

                data_list.append(data)
                labels_list.append(labels)

                # Record new begin and end indices of sample in data array to info array
                last += len(data)
                info_list.append((first, last))
                first = last

            data_array = np.concatenate(data_list)
            labels_array = np.concatenate(labels_list)
            info_array = np.rec.array(info_list, dtype=np.dtype([('begin', np.int64), ('end', np.int64)]))

            logger.info('Shapes: %s_x=%s, %s_y=%s, %s_info=%s',
                        name, data_array.shape,
                        name, labels_array.shape,
                        name, info_array.shape)

            sets[name] = RawData(x=data_array, y=labels_array, info=info_array)

        logger.info('Event integration finished in %s s.', time.time() - start)

        return RawDataModel(sets=RawDataSets(**sets), name=datamodel.name)
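
    # Worked example (illustrative): with frames_num=16 and a 128x128 2D sensor, each
    # input sample yields data of shape [16, 128, 128, 2] after the channels_last
    # conversion, so len(data) == 16 and each (begin, end) info entry advances by 16.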

    @override
    def import_data(self, dataset: Dataset[Any]) -> Dataset[Any]:
        """Make the dataset import its data as a :class:`qualia_core.datamodel.RawDataModel.RawDataModel`.

        The replacement loader builds a RawDataModel named after the dataset and imports its existing sets.
        """
        def func() -> RawDataModel:
            rdm = RawDataModel(name=dataset.name)
            rdm.import_sets(set_names=dataset.sets)
            return rdm

        dataset.import_data = func  # type: ignore[method-assign]
        return dataset
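

# Illustrative self-check (not part of the original module): the per-frame label
# reduction in __integrate_events_by_fixed_frames_number keeps the most frequent
# event label within each integration window.
if __name__ == '__main__':
    window_labels = np.array([0, 0, 1, 1, 1, 2])
    values, counts = np.unique(window_labels, return_counts=True)
    print(values[counts.argmax()])  # prints 1, the majority label of the window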