Source code for qualia_core.dataset.MNIST

from __future__ import annotations

import ctypes
import errno
import logging
import os
import sys
from enum import IntEnum
from pathlib import Path
from typing import TYPE_CHECKING as NATIVE_TYPE_CHECKING
from typing import Any, ClassVar, cast

import numpy as np

from qualia_core.datamodel.RawDataModel import RawData, RawDataModel, RawDataSets
from qualia_core.dataset.RawDataset import RawDataset
from qualia_core.typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Sequence

if NATIVE_TYPE_CHECKING:
    from ctypes import _CDataType

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

logger = logging.getLogger(__name__)

class IDXType(IntEnum):
    """Element type codes stored in the third byte of an IDX magic number."""

    UINT8 = 0x08
    INT8 = 0x09
    INT16 = 0x0B
    INT32 = 0x0C
    FLOAT32 = 0x0D
    FLOAT64 = 0x0E

    def as_numpy_dtype(self) -> np.dtype[Any]:
        """Return the numpy dtype for this type code, in big-endian byte order.

        Returns:
            numpy.dtype for the corresponding data type
        """
        # Array-protocol type strings: '>' requests most-significant-byte-first
        # storage. Byte order is meaningless for single-byte types, so numpy
        # normalises '>u1' / '>i1' to the plain uint8 / int8 dtypes.
        big_endian_codes = {
            IDXType.UINT8: '>u1',
            IDXType.INT8: '>i1',
            IDXType.INT16: '>i2',
            IDXType.INT32: '>i4',
            IDXType.FLOAT32: '>f4',
            IDXType.FLOAT64: '>f8',
        }
        return np.dtype(big_endian_codes[self])
class IDXMagicNumber(ctypes.BigEndianStructure):
    """4-byte header found at the very start of an IDX file.

    Layout, most significant byte first:

    - bytes 0-1 (``null``): always 0
    - byte 2 (``dtype``): element data type, one of :class:`IDXType`
    - byte 3 (``n_dims``): number of dimension sizes that follow the magic number
    """

    _fields_: ClassVar[Sequence[tuple[str, type[_CDataType]] | tuple[str, type[_CDataType], int]]] = [
        ('null', ctypes.c_uint16),
        ('dtype', ctypes.c_uint8),
        ('n_dims', ctypes.c_uint8),
    ]
class MNISTBase(RawDataset):
    """Base class for MNIST-style datasets (MNIST and Fashion-MNIST).

    This class provides common functionality for loading and processing datasets
    that use the IDX file format. Both MNIST and Fashion-MNIST share the same:

    - File format (IDX)
    - Image dimensions (28x28 pixels)
    - Number of classes (10)
    - Dataset sizes (60,000 training, 10,000 test)

    The IDX file format is a simple format for vectors and multidimensional
    matrices of various numerical types. The files are organized as:

    - magic number (4 bytes) identifying data type and dimensions
    - dimension sizes (4 bytes each)
    - data in row-major order
    """

    def __init__(self, path: str = '', dtype: str = 'float32') -> None:
        """Initialize an MNIST-style dataset.

        Args:
            path: Directory containing the IDX files
            dtype: Data type to convert images to
        """
        super().__init__()
        self.__path = Path(path)
        self.__dtype = dtype
        # MNIST datasets don't use a validation set, so we remove it
        if 'valid' in self.sets:
            self.sets.remove('valid')

    def _read_idx_file(self, filepath: Path) -> np.ndarray[Any, Any]:
        """Read data from an IDX file format.

        The IDX file format begins with a magic number containing:

        - first 2 bytes: zero
        - third byte: data type
        - fourth byte: number of dimensions

        Following this are the dimension sizes (4 bytes each).
        Finally comes the data in row-major order.
        All integers are in most significant byte first order.

        Args:
            filepath: Path to the IDX file to read

        Returns:
            numpy.ndarray containing the file's data, shaped according to the
            dimension sizes declared in the header

        Raises:
            FileNotFoundError: If the file doesn't exist
            ValueError: If the file format is invalid (non-zero leading bytes,
                unknown data type code, truncated header, or data size that does
                not match the declared dimensions)
        """
        if not filepath.exists():
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), filepath)

        with filepath.open('rb') as f:
            # Decode the first 4 bytes "magic number"
            magic = IDXMagicNumber.from_buffer_copy(f.read(4))

            if magic.null != 0:
                # The expected value must be an int: '%04x' cannot format a bytes
                # object, which would make the logging call itself fail.
                logger.error('First 2 bytes of IDX files expected to be 0x%04x, got 0x%04x', 0, magic.null)
                msg = f'Invalid IDX file {filepath}: first 2 bytes expected to be 0x0000, got 0x{magic.null:04x}'
                raise ValueError(msg)

            # Raises ValueError if the type code is not a known IDXType
            dtype = IDXType(magic.dtype)

            # Dims is an array of n_dims 32-bit unsigned integers in most significant byte first order
            dims_ctype = cast(ctypes.c_uint32, ctypes.c_uint32.__ctype_be__) * magic.n_dims  # type: ignore[attr-defined]

            # Read the dimension sizes
            dims_bytes = f.read(ctypes.sizeof(dims_ctype))
            dims = dims_ctype.from_buffer_copy(dims_bytes)

            # Read all the remaining data with the declared dtype
            data = np.fromfile(f, dtype=dtype.as_numpy_dtype())

        # Hand numpy a plain tuple of Python ints rather than the ctypes array
        shape = tuple(int(dim) for dim in dims)

        # And reshape to the specified dimensions
        return data.reshape(shape)

    def _load_data(self,
                   images_file: str,
                   labels_file: str) -> tuple[np.ndarray[Any, np.dtype[np.float32]],
                                              np.ndarray[Any, np.dtype[np.uint8]]]:
        """Load a set of images and labels.

        This method:

        1. Reads both the image and label IDX files from the dataset directory
        2. Appends a trailing channel dimension to the images ([N, H, W] -> [N, H, W, C])
        3. Converts the images to the dtype chosen at construction

        Pixel values are only cast, not rescaled: no normalization is applied here.
        Labels are returned exactly as read from the file.

        Args:
            images_file: Name of the IDX file containing images
            labels_file: Name of the IDX file containing labels

        Returns:
            Tuple of (images, labels) where images has a trailing channel dimension
            and the dtype chosen at construction ('float32' by default), and labels
            is the array stored in the labels file
        """
        # Load raw data from IDX files
        images = self._read_idx_file(self.__path / images_file)
        labels = self._read_idx_file(self.__path / labels_file)

        # Images need to be:
        # - Reshaped to [N, H, W, C] format (adding channel dimension)
        # - Converted to the chosen dtype
        images = np.expand_dims(images, -1).astype(self.__dtype)

        return images, labels

    @override
    def __call__(self) -> RawDataModel:
        """Load and prepare the complete dataset.

        This method:

        1. Loads both training and test sets
        2. Packages them in Qualia's data structures

        The MNIST datasets use specific file names:

        - train-images-idx3-ubyte: Training images (60,000 x 28 x 28)
        - train-labels-idx1-ubyte: Training labels (60,000)
        - t10k-images-idx3-ubyte: Test images (10,000 x 28 x 28)
        - t10k-labels-idx1-ubyte: Test labels (10,000)

        Returns:
            RawDataModel containing a training set and a test set. In each set the
            images carry a trailing channel dimension and use the dtype chosen at
            construction (pixel values are not rescaled), and the labels are as
            stored in the labels file.
        """
        logger.info('Loading MNIST-style dataset from %s', self.__path)

        # Load training and test sets
        train_x, train_y = self._load_data('train-images-idx3-ubyte', 'train-labels-idx1-ubyte')
        test_x, test_y = self._load_data('t10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')

        # Log shapes to verify loading was correct
        logger.info('Shapes: train_x=%s, train_y=%s, test_x=%s, test_y=%s',
                    train_x.shape,
                    train_y.shape,
                    test_x.shape,
                    test_y.shape)

        # Package everything in Qualia's containers
        return RawDataModel(
            sets=RawDataSets(
                train=RawData(train_x, train_y),
                test=RawData(test_x, test_y),
            ),
            name=self.name,
        )
class MNIST(MNISTBase):
    """Original MNIST handwritten digits dataset.

    The MNIST database holds 70,000 grayscale images of handwritten digits (0-9).
    Every image is 28x28 pixels, with the digit centered to reduce preprocessing
    and get better results.

    Dataset split:

    - 60,000 training images
    - 10,000 test images

    Labels:

    - 0-9: Corresponding digits
    """
class FashionMNIST(MNISTBase):
    """Fashion MNIST clothing dataset.

    A drop-in replacement for MNIST, made of 70,000 grayscale images of clothing
    items. Every image is 28x28 pixels and follows the same format as the original
    MNIST.

    Dataset split:

    - 60,000 training images
    - 10,000 test images

    Labels:

    - 0: T-shirt/top
    - 1: Trouser
    - 2: Pullover
    - 3: Dress
    - 4: Coat
    - 5: Sandal
    - 6: Shirt
    - 7: Sneaker
    - 8: Bag
    - 9: Ankle boot
    """