"""DVS128-Gesture event-based dataset import module based on SpikingJelly including preprocessing to frames and timesteps."""
from __future__ import annotations
import logging
import math
import os
import sys
import time
from concurrent.futures import Future, ProcessPoolExecutor
from multiprocessing.shared_memory import SharedMemory
from pathlib import Path
from typing import Any, Callable
import numpy as np
import numpy.typing as npt
from qualia_codegen_core.typing import TYPE_CHECKING
from qualia_core.datamodel import RawDataModel
from qualia_core.datamodel.RawDataModel import RawData
from qualia_core.dataset.RawDataset import RawDataset
from qualia_core.utils.process.init_process import init_process
from qualia_core.utils.process.SharedMemoryManager import SharedMemoryManager
from spikingjelly.datasets import integrate_events_by_fixed_duration # type: ignore[import-untyped]
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture # type: ignore[import-untyped]
if TYPE_CHECKING:
if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Description of one array stored in shared memory: (segment name, array shape, array dtype).
SharedMemoryArrayReturnT = tuple[str, tuple[int, ...], npt.DTypeLike]
# Result of one worker process: shared memory descriptions of (data, labels).
LoadFramesReturnT = tuple[SharedMemoryArrayReturnT, SharedMemoryArrayReturnT]
class DVSGestureWithPreprocessing(RawDataset):
    """DVS128 Gesture event-based data loading based on SpikingJelly including preprocessing to frames and timesteps.

    The events of each sample are integrated into frames of fixed duration, then consecutive frames are grouped
    by ``timesteps``. Samples are split into chunks processed by worker processes, which hand their results back
    to the calling process through shared memory.
    """

    def __init__(self,
                 path: str = '',
                 data_type: str = 'frame',
                 duration: int = 0,
                 timesteps: int = 0) -> None:
        """Instantiate the DVS128 Gesture dataset loader with preprocessing.

        :param path: Dataset source path
        :param data_type: Only ``'frame'`` is supported
        :param duration: Frame integration duration
        :param timesteps: Number of timesteps to group frames by, must be strictly positive when loading the data
        """
        super().__init__()
        self.__path = Path(path)
        self.__data_type = data_type
        self.__duration = duration
        self.__timesteps = timesteps
        self.sets.remove('valid')  # __call__ only produces the train and test sets

    def __load_dvs128gesture(self, *, train: bool) -> DVS128Gesture:
        """Call SpikingJelly loader implementation for DVS128 Gesture.

        The dataset directory is created beforehand if it does not exist.

        :param train: Load train data if ``True``, otherwise load test data
        :return: SpikingJelly DVS128Gesture loader configured with ``data_type='event'``
        """
        self.__path.mkdir(parents=True, exist_ok=True)
        return DVS128Gesture(str(self.__path),
                             train=train,
                             data_type='event')

    def _shared_memory_array(self,
                             smm: SharedMemoryManager,
                             data_array: npt.NDArray[np.float32] |
                                         npt.NDArray[np.int32]) -> SharedMemoryArrayReturnT:
        """Copy an array into a new shared memory segment allocated through the manager.

        :param smm: Shared memory manager used to allocate the segment
        :param data_array: Array to copy into shared memory
        :return: Segment name, array shape and array dtype, to rebuild the array from the segment
        :raise RuntimeError: If the SharedMemory buffer is invalid
        """
        data_buffer = smm.SharedMemory(size=data_array.nbytes)
        if data_buffer.buf is None:
            logger.error('Shared memory buffer is invalid')
            raise RuntimeError

        data_shared = np.frombuffer(data_buffer.buf, count=data_array.size, dtype=data_array.dtype).reshape(data_array.shape)
        np.copyto(data_shared, data_array)
        del data_shared  # Drop the view on the buffer before closing this handle on the segment

        ret = (data_buffer.name, data_array.shape, data_array.dtype)
        data_buffer.close()
        return ret

    def _load_frames(self,
                     smm_address: str | tuple[str, int],
                     i: int,
                     dvs128gesture: DVS128Gesture,
                     chunks: npt.NDArray[np.int32]) -> LoadFramesReturnT:
        """Subprocess entry point to load and process data for a set of samples.

        :param smm_address: Address of the shared memory manager to connect to
        :param i: Process number
        :param dvs128gesture: SpikingJelly DVS128Gesture loader
        :param chunks: List of samples to load
        :return: Shared memory descriptions (name, shape, dtype) of the frames over timesteps and of the labels
            for selected samples
        """
        start = time.time()

        smm = SharedMemoryManager(address=smm_address)
        smm.connect()

        logger.info('Process %s loading frames for chunks %s...', i, chunks)

        h: int
        w: int
        h, w = dvs128gesture.get_H_W()

        data_list: list[npt.NDArray[np.float32]] = []
        labels_list: list[npt.NDArray[np.int32]] = []
        for j in chunks:
            # Index the loader once per sample instead of once for the events and once more for the label
            sample = dvs128gesture[j]
            label: int = sample[1]

            data64: npt.NDArray[np.float64] = integrate_events_by_fixed_duration(events=sample[0],
                                                                                 duration=self.__duration,
                                                                                 H=h,
                                                                                 W=w)
            data = data64.astype(np.float32)
            data = data.transpose((0, 2, 3, 1))  # N, C, H, W → N, H, W, C

            frame_chunks: int = data.shape[0] // self.__timesteps
            data = data[:frame_chunks * self.__timesteps]  # Truncate excessive frames
            data = data.reshape((frame_chunks, self.__timesteps, *data.shape[1:]))  # N, T, H, W, C

            data_list.append(data)
            # One label per group of timesteps
            labels_list.append(np.full(data.shape[0], label, dtype=np.int32))

        data_array = np.concatenate(data_list)
        labels_array = np.concatenate(labels_list)
        del data_list
        del labels_list

        data_ret = self._shared_memory_array(smm, data_array)
        labels_ret = self._shared_memory_array(smm, labels_array)

        logger.info('Process %s finished in %s s.', i, time.time() - start)
        return data_ret, labels_ret

    @staticmethod
    def __is_no_bufs_none(bufs: list[memoryview[Any] | None]) -> TypeGuard[list[memoryview[Any]]]:
        """Check that no buffer of the list is ``None``.

        :param bufs: Buffers to check
        :return: ``True`` if all buffers are valid, ``False`` if at least one is ``None``
        """
        return not any(buf is None for buf in bufs)

    def __dvs128gesture_to_data(self, dvs128gesture: DVS128Gesture) -> RawData:
        """Parallel loading and processing of event data to construct frames and timesteps.

        :param dvs128gesture: SpikingJelly DVS128Gesture loader
        :return: Frame and timesteps data and labels
        :raise RuntimeError: If a SharedMemory buffer is invalid (use after close)
        """
        samples = len(dvs128gesture)

        cpus: int | None = os.cpu_count()
        half_cpus: int = cpus // 2 if cpus is not None else 2
        # np.array_split() rejects 0 sections (cpus // 2 is 0 with a single CPU) and a chunk without any sample
        # cannot be concatenated by the worker, so keep 1 <= total_chunks <= samples
        total_chunks: int = max(1, min(half_cpus, samples))
        chunks_list = np.array_split(np.arange(samples, dtype=np.int32), total_chunks)

        with SharedMemoryManager() as smm, ProcessPoolExecutor(initializer=init_process) as executor:
            if smm.address is None:  # After smm is started in context, address is necessarily non-None
                raise RuntimeError

            train_futures = [executor.submit(self._load_frames, smm.address, i, dvs128gesture, chunks)
                             for i, chunks in enumerate(chunks_list)]

            def load_results(futures: list[Future[LoadFramesReturnT]],
                             resloader: Callable[[LoadFramesReturnT],
                                                 SharedMemoryArrayReturnT]) -> np.ndarray[Any, Any]:
                """Gather one of the arrays returned by the workers and concatenate it over all workers.

                :param futures: Futures of the worker processes
                :param resloader: Selects the shared memory description to load from a worker result
                :return: Concatenated copy of the selected arrays, independent from shared memory
                """
                # Wait for each worker and pick its (name, shape, dtype) description only once
                descriptions = [resloader(f.result()) for f in futures]

                bufs = [SharedMemory(name) for name, _, _ in descriptions]
                raw_bufs = [buf.buf for buf in bufs]
                if not self.__is_no_bufs_none(raw_bufs):
                    logger.error('Shared memory buffer is invalid')
                    raise RuntimeError

                data_list = [np.frombuffer(raw_buf, count=math.prod(shape), dtype=dtype).reshape(shape)
                             for (_, shape, dtype), raw_buf in zip(descriptions, raw_bufs)]
                data_array: np.ndarray[Any, Any] = np.concatenate(data_list)
                del data_list  # Drop the views on shared memory so that the segments can be closed

                for buf in bufs:
                    buf.close()  # Release this process' handle on the segment...
                    buf.unlink()  # ...then request destruction of the segment itself
                return data_array

            data = load_results(train_futures, lambda r: r[0])
            labels = load_results(train_futures, lambda r: r[1])

        return RawData(data, labels)

    @override
    def __call__(self) -> RawDataModel:
        """Load DVS128 Gesture data as frames over timesteps.

        :return: Data model structure with train and test sets containing frames with timesteps and labels
        :raise ValueError: If ``data_type`` is not ``'frame'`` or if ``timesteps`` is not strictly positive
        """
        if self.__data_type != 'frame':
            logger.error('Unsupported data_type %s', self.__data_type)
            raise ValueError

        # Frames are grouped by integer division by timesteps: fail early instead of inside a worker process
        if self.__timesteps <= 0:
            logger.error('timesteps must be strictly positive, got %s', self.__timesteps)
            raise ValueError

        train_dvs128gesture = self.__load_dvs128gesture(train=True)
        test_dvs128gesture = self.__load_dvs128gesture(train=False)

        trainset = self.__dvs128gesture_to_data(train_dvs128gesture)
        testset = self.__dvs128gesture_to_data(test_dvs128gesture)

        logger.info('Shapes: train_x=%s, train_y=%s, test_x=%s, test_y=%s',
                    trainset.x.shape if trainset.x is not None else None,
                    trainset.y.shape if trainset.y is not None else None,
                    testset.x.shape if testset.x is not None else None,
                    testset.y.shape if testset.y is not None else None)

        return RawDataModel(sets=RawDataModel.Sets(train=trainset, test=testset), name=self.name)

    @property
    @override
    def name(self) -> str:
        """Dataset name built from the class name, data type, integration duration and timesteps."""
        return f'{self.__class__.__name__}_{self.__data_type}_d{self.__duration}_t{self.__timesteps}'