Source code for qualia_core.dataset.HD

from __future__ import annotations

import logging
import sys
import time
from pathlib import Path

import numpy as np
import numpy.typing

from qualia_core.datamodel import RawDataModel
from qualia_core.datamodel.Info import Info, Info_dtype
from qualia_core.datamodel.RawDataModel import RawData, RawDataSets
from qualia_core.utils.file import DirectoryReader

from .RawDataset import RawDataset

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

logger = logging.getLogger(__name__)

[docs] class HD(RawDataset): """Heidelberg Digits, raw audio ('hd_audio.tar.gz').""" def __init__(self, path: str, variant: str | None = None, test_subjects: list[int] | None = None) -> None: super().__init__() self.__path = Path(path) self.__variant = variant self.__test_subjects = test_subjects self.__dr = DirectoryReader() self.sets.remove('valid') def __load_wave(self, recording: Path) -> numpy.typing.NDArray[np.float32]: import soundfile as sf data, _ = sf.read(str(recording)) return data.astype(np.float32) def __is_test_subject(self, recording: Path, test_list: list[int] | list[str]) -> bool: if self.__variant == 'by-subject': return int(recording.stem.split('_')[1].replace('speaker-', '')) in test_list return recording.name in test_list def __load(self, path: Path, test_list: list[int] | list[str]) -> RawDataModel: start = time.time() directory = self.__dr.read(path / 'audio', ext='.flac', recursive=True) train_x_list: list[numpy.typing.NDArray[np.float32]] = [] train_y_list: list[int] = [] train_info_list: list[Info] = [] test_x_list: list[numpy.typing.NDArray[np.float32]] = [] test_y_list: list[int] = [] test_info_list: list[Info] = [] for recording in directory: data = self.__load_wave(recording).astype(np.float32) label = int(recording.stem.split('-')[-1]) if self.__is_test_subject(recording, test_list): test_x_list.append(data) test_y_list.append(label) test_info_list.append(Info(subject=int(recording.stem.split('_')[1].replace('speaker-', '')))) else: train_x_list.append(data) train_y_list.append(label) train_info_list.append(Info(subject=int(recording.stem.split('_')[1].replace('speaker-', '')))) # Zero pad to longest sample max_length = max(x.shape[0] for x in train_x_list + test_x_list) train_x_list = [np.pad(x, ((0, max_length - x.shape[0])), 'constant', constant_values=0) for x in train_x_list] test_x_list = [np.pad(x, ((0, max_length - x.shape[0])), 'constant', constant_values=0) for x in test_x_list] train_x = np.expand_dims(np.array(train_x_list), -1) # Add channels dimension train_y = np.array(train_y_list) train_info = np.array(train_info_list, dtype=Info_dtype) test_x = np.expand_dims(np.array(test_x_list), -1) # Add channels dimension test_y = np.array(test_y_list) test_info = np.array(test_info_list, dtype=Info_dtype) logger.info('Shapes: train_x=%s, train_y=%s, test_x=%s, test_y=%s', train_x.shape, train_y.shape, test_x.shape, test_y.shape) train = RawData(train_x, train_y, train_info) test = RawData(test_x, test_y, test_info) logger.info('Elapsed: %s s', time.time() - start) return RawDataModel(sets=RawDataSets(train=train, test=test), name=self.name) @override def __call__(self) -> RawDataModel: test_list: list[int] | list[str] = [] if self.__variant == 'by-subject': if self.__test_subjects is None: logger.error('test_subects required for by-subject variant') raise ValueError test_list = self.__test_subjects else: with (self.__path/'test_filenames.txt').open() as f: test_list = f.read().splitlines() return self.__load(self.__path, test_list=test_list) @property @override def name(self) -> str: if self.__variant: return f'{super().name}_{self.__variant}' return super().name