"""DVS128-Gesture event-based dataset import module based on SpikingJelly."""
from __future__ import annotations
import logging
import sys
import time
from pathlib import Path
from typing import Final, cast
import numpy as np
from qualia_core.datamodel.RawDataModel import RawDataDType, RawDataShape
from qualia_core.typing import TYPE_CHECKING
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture # type: ignore[import-untyped]
from qualia_plugin_snn.datamodel.EventDataModel import (
EventData,
EventDataChunks,
EventDataChunksModel,
EventDataChunksSets,
EventDataInfo,
EventDataInfoRecord,
)
from .EventDataset import EventDatasetChunks
if TYPE_CHECKING:
from collections.abc import Generator
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
# Logger for this module; records errors and loading durations.
logger = logging.getLogger(name=__name__)
class DVSGesture(EventDatasetChunks):
    """DVS128 Gesture event-based data loading based on SpikingJelly.

    Raw events are obtained from SpikingJelly's ``DVS128Gesture`` loader in ``'event'`` mode and converted into
    record arrays with fields ``t`` (timestamp), ``y``/``x`` (coordinates) and ``p`` (polarity).
    Every SpikingJelly sample is emitted as its own chunk, together with a per-event label array.
    """

    h: int = 128
    w: int = 128
    # Structured dtype of one event. x/y are stored as int8 (max 127), which covers a 128-wide axis
    # assuming 0-based coordinates.
    dtype: Final[np.dtype] = np.dtype([('t', np.int64), ('y', np.int8), ('x', np.int8), ('p', np.bool_)])
    # dtype of the per-event label arrays. Used both when building the labels and when declaring
    # the ``y`` dtype to EventDataChunks, so the two always agree.
    label_dtype: Final[np.dtype] = np.dtype(np.uint8)

    def __init__(self,
                 path: str = '',
                 data_type: str = 'frame') -> None:
        """Instantiate the DVS128 Gesture dataset loader.

        :param path: Dataset source path, created if missing when the data is loaded
        :param data_type: Only ``'frame'`` is supported; validated when the dataset is loaded
        """
        super().__init__()
        self.__path = Path(path)
        self.__data_type = data_type
        # Only train and test splits are produced by this loader
        self.sets.remove('valid')

    def __load_dvs128gesture(self, *, train: bool) -> DVS128Gesture:
        """Call SpikingJelly loader implementation for DVS128 Gesture.

        The dataset directory (and its parents) is created first if it does not exist.

        :param train: Load train data if ``True``, otherwise load test data
        :return: SpikingJelly loader configured to return raw events (``data_type='event'``)
        """
        self.__path.mkdir(parents=True, exist_ok=True)
        return DVS128Gesture(str(self.__path),
                             train=train,
                             data_type='event')

    def __dvs128gesture_to_event_data(self, dvs128gesture: DVS128Gesture) -> Generator[EventData]:
        """Load events using SpikingJelly loader and fill event-based data structure.

        Each sample of the loader is yielded as its own chunk, so loading happens lazily as the generator is
        consumed. The total loading time is logged once the loader is exhausted.

        :param dvs128gesture: SpikingJelly DVS128Gesture loader
        :yield: Event data with timestamps, x and y coordinates, polarity, label, and sample indices
        """
        start = time.perf_counter()

        for sample in dvs128gesture:
            events, label = sample[0], sample[1]

            t = events['t'].astype(np.int64)
            y = events['y'].astype(np.int8)
            x = events['x'].astype(np.int8)
            p = events['p'].astype(np.bool_)
            event_count = t.shape[0]

            data = np.rec.fromarrays([t, y, x, p],
                                     dtype=self.dtype)
            # Every event of a sample carries the sample's label
            labels = np.full(event_count, label, dtype=self.label_dtype)

            # Each chunk holds exactly one sample, spanning all of its events: [0, event_count)
            sample_indices = EventDataInfo((1,))
            sample_info = cast('EventDataInfoRecord', sample_indices[0])
            sample_info.begin = np.int64(0)
            sample_info.end = np.int64(event_count)

            yield EventData(data, labels, sample_indices)

        logger.info('Loading finished in %s s.', time.perf_counter() - start)

    @override
    def __call__(self) -> EventDataChunksModel:
        """Load DVS128 Gesture data as events.

        Train and test sets wrap generators, so events are only read when the chunks are consumed.

        :return: Data model structure with train and test sets containing events and labels
        :raises ValueError: If ``data_type`` is not ``'frame'``
        """
        if self.__data_type != 'frame':
            logger.error('Unsupported data_type %s', self.__data_type)
            msg = f"Unsupported data_type '{self.__data_type}', only 'frame' is supported"
            raise ValueError(msg)

        train_dvs128gesture = self.__load_dvs128gesture(train=True)
        test_dvs128gesture = self.__load_dvs128gesture(train=False)

        # NOTE(review): None presumably marks the variable number of events per chunk — confirm with RawDataShape
        shapes = RawDataShape(x=(None,), y=(None,))
        dtypes = RawDataDType(x=self.dtype, y=self.label_dtype)

        trainset = EventDataChunks(chunks=self.__dvs128gesture_to_event_data(train_dvs128gesture),
                                   shapes=shapes,
                                   dtypes=dtypes)
        testset = EventDataChunks(chunks=self.__dvs128gesture_to_event_data(test_dvs128gesture),
                                  shapes=shapes,
                                  dtypes=dtypes)

        return EventDataChunksModel(sets=EventDataChunksSets(train=trainset, test=testset),
                                    name=self.name,
                                    h=self.h,
                                    w=self.w)

    @property
    @override
    def name(self) -> str:
        """Dataset name made of the class name and the requested data type, e.g. ``DVSGesture_frame``."""
        return f'{self.__class__.__name__}_{self.__data_type}'