Source code for qualia_core.dataset.CIFAR

from __future__ import annotations

import logging
import pickle
import sys
import time
from abc import ABC
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np

from qualia_core.datamodel.RawDataModel import (
    RawData,
    RawDataChunks,
    RawDataChunksModel,
    RawDataChunksSets,
    RawDataDType,
    RawDataShape,
)
from qualia_core.typing import TYPE_CHECKING

from .RawDataset import RawDatasetChunks

if TYPE_CHECKING:
    from collections.abc import Generator

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

logger = logging.getLogger(__name__)


@dataclass
class CIFARFile:
    data: np.ndarray[Any, np.dtype[np.uint8]]
    batch_label: bytes
    filenames: list[bytes]
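
# For reference, a pickled CIFAR batch deserialises to a dict with bytes
# keys; per the official CIFAR-10 "python version" layout, e.g.:
#   {b'batch_label': b'training batch 1 of 5',
#    b'data': ...,         # uint8 ndarray, shape (10000, 3072)
#    b'labels': [...],     # 10000 ints in 0..9
#    b'filenames': [...]}  # 10000 bytes filenames
# CIFARFile captures the fields shared by CIFAR-10 and CIFAR-100 batches;
# the label field(s) differ between the two, so they are expected to be
# declared by the dataclass subclass passed to CIFAR as file_cls.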

class CIFAR(RawDatasetChunks, ABC):
    def __init__(self,  # noqa: PLR0913
                 path: str,
                 dtype: str,
                 labels_field: str,
                 train_files: list[str],
                 test_files: list[str],
                 train_shapes: RawDataShape,
                 test_shapes: RawDataShape,
                 file_cls: type[CIFARFile]) -> None:
        super().__init__()
        self.__path = Path(path)
        self.__dtype = dtype
        self.__labels_field = labels_field
        self.__train_files = train_files
        self.__test_files = test_files
        self.__train_shapes = train_shapes
        self.__test_shapes = test_shapes
        self.__dtypes = RawDataDType(x=np.dtype(dtype), y=np.dtype(np.int64))
        self.__file_cls = file_cls
        self.sets.remove('valid')  # CIFAR ships no validation split

    def __load_file(self, file: Path) -> CIFARFile:
        with file.open('rb') as fo:
            raw = pickle.load(fo, encoding='bytes')
        # Pickled batches use bytes keys; decode them (cp437 maps every byte)
        # so the dict can be splatted as keyword arguments.
        content = {k.decode('cp437'): v for k, v in raw.items()}
        return self.__file_cls(**content)

    def __load_batch(self, path: Path) -> RawData:
        d = self.__load_file(path)

        # Each row of d.data is one flat image; recover the channel
        # dimension, then move it last for a channels-last layout.
        x_uint8 = d.data.reshape((d.data.shape[0], 3, 32, 32))  # N, C, H, W
        x_uint8 = x_uint8.transpose((0, 2, 3, 1))  # N, H, W, C
        x = x_uint8.astype(self.__dtype)

        y = np.array(getattr(d, self.__labels_field))

        return RawData(x, y)

    def __load_batches(self, files: list[str]) -> Generator[RawData]:
        for file in files:
            start = time.time()
            batch = self.__load_batch(self.__path / file)
            logger.info('"%s" loaded in %s s', self.__path / file, time.time() - start)
            yield batch

    def __load_set(self, files: list[str], shapes: RawDataShape) -> RawDataChunks:
        return RawDataChunks(chunks=self.__load_batches(files),
                             shapes=shapes,
                             dtypes=self.__dtypes)

    @override
    def __call__(self) -> RawDataChunksModel:
        train = self.__load_set(files=self.__train_files, shapes=self.__train_shapes)
        test = self.__load_set(files=self.__test_files, shapes=self.__test_shapes)
        return RawDataChunksModel(sets=RawDataChunksSets(train=train, test=test),
                                  name=self.name)
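
# ---------------------------------------------------------------------------
# Usage sketch (not part of this module): how a concrete dataset might
# specialise CIFAR for CIFAR-10. The names CIFAR10File and CIFAR10, the
# default dtype, and the RawDataShape constructor arguments below are
# illustrative assumptions, not confirmed by this listing; the actual
# subclasses live elsewhere in qualia_core.

@dataclass
class CIFAR10File(CIFARFile):
    labels: list[int]  # the b'labels' entry of a CIFAR-10 batch


class CIFAR10(CIFAR):
    def __init__(self, path: str, dtype: str = 'float32') -> None:
        super().__init__(
            path=path,
            dtype=dtype,
            labels_field='labels',  # read via getattr() in CIFAR.__load_batch
            train_files=[f'data_batch_{i}' for i in range(1, 6)],
            test_files=['test_batch'],
            train_shapes=RawDataShape(x=(50000, 32, 32, 3), y=(50000,)),
            test_shapes=RawDataShape(x=(10000, 32, 32, 3), y=(10000,)),
            file_cls=CIFAR10File,
        )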