from __future__ import annotations
import ctypes
import errno
import logging
import os
import sys
from enum import IntEnum
from pathlib import Path
from typing import Any, ClassVar, cast
import numpy as np
from qualia_core.datamodel.RawDataModel import RawData, RawDataModel, RawDataSets
from qualia_core.dataset.RawDataset import RawDataset
from qualia_core.typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Sequence # noqa: TC003
from ctypes import _CDataType
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
logger = logging.getLogger(__name__)
[docs]
class IDXType(IntEnum):
    """Element type codes that an IDX file can declare.

    Each member's value is the numeric code used to identify that element type.
    """

    UINT8 = 0x08
    INT8 = 0x09
    INT16 = 0x0B
    INT32 = 0x0C
    FLOAT32 = 0x0D
    FLOAT64 = 0x0E

    def as_numpy_dtype(self) -> np.dtype[Any]:
        """Return the big-endian numpy.dtype matching this IDX element type.

        Returns:
            numpy.dtype for the corresponding data type, with byte order set to '>'
        """
        # Every member is named after its NumPy scalar type (UINT8 -> np.uint8, ...),
        # so the lower-cased member name is the attribute to fetch from numpy.
        scalar_type = getattr(np, self.name.lower())
        native_dtype = np.dtype(scalar_type)
        return native_dtype.newbyteorder('>')
[docs]
class IDXMagicNumber(ctypes.BigEndianStructure):
    """Big-endian layout of the 4-byte magic number that opens an IDX file.

    Byte layout:

    - bytes 0-1: always 0
    - byte 2: element data type, one of :class:`IDXType`
    - byte 3: number of dimensions that follow the magic number
    """

    _fields_: ClassVar[Sequence[tuple[str, type[_CDataType]] | tuple[str, type[_CDataType], int]]] = [
        # Two leading bytes, expected to be zero
        ('null', ctypes.c_uint16),
        # Element type code
        ('dtype', ctypes.c_uint8),
        # Count of dimension sizes that follow the magic number
        ('n_dims', ctypes.c_uint8),
    ]
[docs]
class MNISTBase(RawDataset):
    """Base class for MNIST-style datasets (MNIST and Fashion-MNIST).

    This class provides common functionality for loading and processing datasets that
    use the IDX file format. Both MNIST and Fashion-MNIST share the same:

    - File format (IDX)
    - Image dimensions (28x28 pixels)
    - Number of classes (10)
    - Dataset sizes (60,000 training, 10,000 test)

    The IDX file format is a simple format for vectors and multidimensional matrices
    of various numerical types. The files are organized as:

    - magic number (4 bytes) identifying data type and dimensions
    - dimension sizes (4 bytes each)
    - data in row-major order
    """

    def __init__(self, path: str = '', dtype: str = 'float32') -> None:
        """Initialize an MNIST-style dataset.

        Args:
            path: Directory containing the IDX files
            dtype: Data type to convert images to
        """
        super().__init__()
        self.__path = Path(path)
        self.__dtype = dtype
        # MNIST datasets don't use a validation set, so we remove it
        if 'valid' in self.sets:
            self.sets.remove('valid')

    def _read_idx_file(self, filepath: Path) -> np.ndarray[Any, Any]:
        """Read data from an IDX file format.

        The IDX file format begins with a magic number containing:

        - first 2 bytes: zero
        - third byte: data type
        - fourth byte: number of dimensions

        Following this are the dimension sizes (4 bytes each).
        Finally comes the data in row-major order.
        All integers in most significant byte first order.

        Args:
            filepath: Path to the IDX file to read

        Returns:
            numpy.ndarray containing the file's data, with the big-endian dtype declared
            in the header and shaped according to the declared dimensions

        Raises:
            FileNotFoundError: If the file doesn't exist
            ValueError: If the file format is invalid (truncated header, non-zero leading
                bytes, unknown data type code, or data size not matching the dimensions)
        """
        if not filepath.exists():
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), filepath)

        with filepath.open('rb') as f:
            # Decode the first 4 bytes "magic number"
            magic = IDXMagicNumber.from_buffer_copy(f.read(4))
            if magic.null != 0:
                # The expected value must be passed as an int: '%04x' cannot format bytes
                logger.error('First 2 bytes of IDX files expected to be 0x%04x, got 0x%04x', 0, magic.null)
                msg = f'Invalid IDX file {filepath}: first 2 bytes expected to be 0x0000, got 0x{magic.null:04x}'
                raise ValueError(msg)

            # Raises ValueError if the type code is not a known IDXType
            dtype = IDXType(magic.dtype)

            # Dims is an array of n_dims 32-bit unsigned integers in most significant byte first order
            dims_ctype = cast(ctypes.c_uint32, ctypes.c_uint32.__ctype_be__) * magic.n_dims  # type: ignore[attr-defined]

            # Read the dimension sizes
            dims_bytes = f.read(ctypes.sizeof(dims_ctype))
            dims = dims_ctype.from_buffer_copy(dims_bytes)

            # Read all the remaining data with the declared dtype
            data = np.fromfile(f, dtype=dtype.as_numpy_dtype())

        # And reshape to the specified dimensions; hand numpy plain Python ints rather than the ctypes array
        return data.reshape(tuple(dims))

    def _load_data(self,
                   images_file: str,
                   labels_file: str) -> tuple[np.ndarray[Any, np.dtype[Any]], np.ndarray[Any, np.dtype[np.uint8]]]:
        """Load and preprocess a set of images and labels.

        This method:

        1. Reads both the image and label IDX files from the dataset directory
        2. Appends a trailing channel axis to the images ([N, H, W] becomes [N, H, W, 1])
        3. Converts the images to the dtype chosen at construction

        Pixel values are only cast, not rescaled. Labels are returned exactly as read.

        Args:
            images_file: Name of the IDX file containing images
            labels_file: Name of the IDX file containing labels

        Returns:
            Tuple of (images, labels) where:

            - images is the image array with an extra last axis, cast to the configured dtype
            - labels is the label array as stored in the IDX file
        """
        # Load raw data from IDX files
        images = self._read_idx_file(self.__path / images_file)
        labels = self._read_idx_file(self.__path / labels_file)

        # Images need to be:
        # - Reshaped to [N, H, W, C] format (adding channel dimension)
        # - Converted to the chosen dtype
        images = np.expand_dims(images, -1).astype(self.__dtype)

        return images, labels

    @override
    def __call__(self) -> RawDataModel:
        """Load and prepare the complete dataset.

        This method:

        1. Loads both training and test sets
        2. Logs the resulting array shapes
        3. Packages them in Qualia's data structures

        The MNIST datasets use specific file names:

        - train-images-idx3-ubyte: Training images (60,000 x 28 x 28)
        - train-labels-idx1-ubyte: Training labels (60,000)
        - t10k-images-idx3-ubyte: Test images (10,000 x 28 x 28)
        - t10k-labels-idx1-ubyte: Test labels (10,000)

        Returns:
            RawDataModel containing a training set and a test set. Each set has:

            - Images: [N, H, W, 1] arrays cast to the configured dtype (pixel values are not rescaled)
            - Labels: [N] arrays as stored in the label files
        """
        logger.info('Loading MNIST-style dataset from %s', self.__path)

        # Load training and test sets
        train_x, train_y = self._load_data('train-images-idx3-ubyte',
                                           'train-labels-idx1-ubyte')
        test_x, test_y = self._load_data('t10k-images-idx3-ubyte',
                                         't10k-labels-idx1-ubyte')

        # Log shapes to verify loading was correct
        logger.info('Shapes: train_x=%s, train_y=%s, test_x=%s, test_y=%s',
                    train_x.shape, train_y.shape, test_x.shape, test_y.shape)

        # Package everything in Qualia's containers
        return RawDataModel(
            sets=RawDataSets(
                train=RawData(train_x, train_y),
                test=RawData(test_x, test_y),
            ),
            name=self.name,
        )
[docs]
class MNIST(MNISTBase):
    """Original MNIST handwritten digits dataset.

    The MNIST database holds 70,000 grayscale images of handwritten digits (0-9).
    Every image is 28x28 pixels and centered, which reduces preprocessing and gives better results.

    Dataset split:

    - 60,000 training images
    - 10,000 test images

    Labels:

    - 0-9: Corresponding digits
    """
[docs]
class FashionMNIST(MNISTBase):
    """Fashion MNIST clothing dataset.

    A drop-in replacement for MNIST made of 70,000 grayscale images of clothing items.
    Every image is 28x28 pixels and uses the same format as the original MNIST.

    Dataset split:

    - 60,000 training images
    - 10,000 test images

    Labels::

        0: T-shirt/top    5: Sandal
        1: Trouser        6: Shirt
        2: Pullover       7: Sneaker
        3: Dress          8: Bag
        4: Coat           9: Ankle boot
    """