Source code for qualia_core.preprocessing.DatasetSplitterBySubjects


from __future__ import annotations

import logging
import sys
from typing import TYPE_CHECKING, Any

from qualia_core.datamodel.RawDataModel import RawDataModel

from .Preprocessing import Preprocessing

if TYPE_CHECKING:
    from qualia_core.dataset.Dataset import Dataset

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

logger = logging.getLogger(__name__)

class DatasetSplitterBySubjects(Preprocessing[RawDataModel, RawDataModel]):
    """Distribute the items of one set into two sets according to their ``name``.

    Warning: must be applied after Window to get correct split and randomization of windows.
    """

    def __init__(self,
                 source_subjects: list[Any],
                 dest_subjects: list[Any],
                 source: str = 'train',
                 dest: str = 'test') -> None:
        """Store the split configuration.

        :param source_subjects: Names of the items that remain in the ``source`` set
        :param dest_subjects: Names of the items that are placed in the ``dest`` set
        :param source: Attribute name of the set to read from and filter in place
        :param dest: Attribute name of the set that receives the ``dest_subjects`` items
        """
        super().__init__()
        self.__source = source
        self.__dest = dest
        self.__source_subjects = source_subjects
        self.__dest_subjects = dest_subjects

    @override
    def __call__(self, datamodel: RawDataModel) -> RawDataModel:
        """Rebuild the ``source`` and ``dest`` sets of ``datamodel`` from the current ``source`` set.

        Both resulting sets are selected from the original ``source`` content: the previous content
        of ``dest`` is replaced, and items whose ``name`` is in neither list end up in neither set.

        :param datamodel: Data model whose sets are updated in place
        :return: The same ``datamodel`` object, after modification
        """
        sets = datamodel.sets
        candidates = getattr(sets, self.__source)
        # The value is discarded, but the lookup is kept so that a missing ``dest`` set
        # still raises AttributeError before anything is modified.
        _ = getattr(sets, self.__dest)

        # Two separate passes over the source set, each one filtering on its own list of names.
        moved = [item for item in candidates if item.name in self.__dest_subjects]
        kept = [item for item in candidates if item.name in self.__source_subjects]

        setattr(sets, self.__dest, moved)
        setattr(sets, self.__source, kept)

        return datamodel

    @override
    def import_data(self, dataset: Dataset[Any]) -> Dataset[Any]:
        """Make sure the ``dest`` set name is listed in the sets of the dataset being loaded.

        :param dataset: Dataset whose ``sets`` list is extended in place when ``dest`` is absent
        :return: The same ``dataset`` object
        """
        if self.__dest in dataset.sets:
            # Already listed, nothing to add.
            return dataset

        dataset.sets.append(self.__dest)
        return dataset