from __future__ import annotations
import ctypes
import errno
import logging
import os
import sys
from enum import IntEnum
from pathlib import Path
from typing import TYPE_CHECKING as NATIVE_TYPE_CHECKING
from typing import Any, ClassVar, cast
import numpy as np
from qualia_core.datamodel.RawDataModel import RawData, RawDataModel, RawDataSets
from qualia_core.dataset.RawDataset import RawDataset
from qualia_core.typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Sequence
if NATIVE_TYPE_CHECKING:
from ctypes import _CDataType
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
logger = logging.getLogger(__name__)
[docs]
class IDXType(IntEnum):
    """Data type codes that can appear in the magic number of an IDX file."""

    UINT8 = 0x08
    INT8 = 0x09
    INT16 = 0x0B
    INT32 = 0x0C
    FLOAT32 = 0x0D
    FLOAT64 = 0x0E

    def as_numpy_dtype(self) -> np.dtype[Any]:
        """Return the numpy dtype matching this type code, in big-endian byte order.

        Returns:
            numpy.dtype for the corresponding data type, with its byte order set to ``'>'``
        """
        scalar_by_code = {
            IDXType.UINT8: np.uint8,
            IDXType.INT8: np.int8,
            IDXType.INT16: np.int16,
            IDXType.INT32: np.int32,
            IDXType.FLOAT32: np.float32,
            IDXType.FLOAT64: np.float64,
        }
        native_dtype = np.dtype(scalar_by_code[self])
        # Request the most-significant-byte-first variant of the native dtype
        return native_dtype.newbyteorder('>')
[docs]
class IDXMagicNumber(ctypes.BigEndianStructure):
    """4-byte magic number found at the start of an IDX file.

    Layout, decoded as a big-endian structure:

    - bytes 0-1: always 0
    - byte 2: data type code, one of :class:`IDXType`
    - byte 3: number of dimensions that follow the magic number
    """

    _fields_: ClassVar[Sequence[tuple[str, type[_CDataType]] | tuple[str, type[_CDataType], int]]] = [
        ('null', ctypes.c_uint16),  # leading two bytes, expected to be zero
        ('dtype', ctypes.c_uint8),  # data type code
        ('n_dims', ctypes.c_uint8),  # how many dimension sizes follow the header
    ]
[docs]
class MNISTBase(RawDataset):
    """Base class for MNIST-style datasets (MNIST and Fashion-MNIST).

    This class provides common functionality for loading datasets stored in the IDX
    file format. Both MNIST and Fashion-MNIST share the same:

    - File format (IDX)
    - Image dimensions (28x28 pixels)
    - Number of classes (10)
    - Dataset sizes (60,000 training, 10,000 test)

    The IDX file format is a simple format for vectors and multidimensional matrices
    of various numerical types. The files are organized as:

    - magic number (4 bytes) identifying data type and dimensions
    - dimension sizes (4 bytes each)
    - data in row-major order
    """

    def __init__(self, path: str = '', dtype: str = 'float32') -> None:
        """Initialize an MNIST-style dataset.

        Args:
            path: Directory containing the IDX files
            dtype: Data type the images are cast to when loaded
        """
        super().__init__()
        self.__path = Path(path)
        self.__dtype = dtype
        # MNIST datasets don't use a validation set, so we remove it
        if 'valid' in self.sets:
            self.sets.remove('valid')

    def _read_idx_file(self, filepath: Path) -> np.ndarray[Any, Any]:
        """Read data from a file in the IDX format.

        The IDX file format begins with a magic number containing:

        - first 2 bytes: zero
        - third byte: data type
        - fourth byte: number of dimensions

        Following this are the dimension sizes (4 bytes each).
        Finally comes the data in row-major order.
        All integers are in most significant byte first order.

        Args:
            filepath: Path to the IDX file to read

        Returns:
            numpy.ndarray with the big-endian dtype declared in the header,
            reshaped to the dimensions declared in the header

        Raises:
            FileNotFoundError: If the file doesn't exist
            ValueError: If the file format is invalid: non-zero leading bytes, unknown
                data type code, truncated header, or a data size that does not match
                the declared dimensions
        """
        if not filepath.exists():
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), filepath)

        with filepath.open('rb') as f:
            # Decode the first 4 bytes "magic number"
            magic = IDXMagicNumber.from_buffer_copy(f.read(4))

            if magic.null != 0:
                # The expected value is passed as an int: a bytes object cannot be formatted with %x,
                # which would make the logging call itself fail and hide the diagnostic
                logger.error('First 2 bytes of IDX files expected to be 0x%04x, got 0x%04x', 0, magic.null)
                msg = f'{filepath}: first 2 bytes of IDX file must be 0x0000, got 0x{magic.null:04x}'
                raise ValueError(msg)

            # Raises ValueError if the code is not a member of IDXType
            dtype = IDXType(magic.dtype)

            # Dims is an array of n_dims 32-bit unsigned integers in most significant byte first order
            dims_ctype = cast(ctypes.c_uint32, ctypes.c_uint32.__ctype_be__) * magic.n_dims  # type: ignore[attr-defined]

            # Read the dimension sizes
            dims_bytes = f.read(ctypes.sizeof(dims_ctype))
            dims = dims_ctype.from_buffer_copy(dims_bytes)

            # Read all the remaining data with the declared dtype
            data = np.fromfile(f, dtype=dtype.as_numpy_dtype())

        # Hand numpy a plain tuple of ints rather than the ctypes array, then reshape to the declared dimensions
        return data.reshape(tuple(dims))

    def _load_data(self,
                   images_file: str,
                   labels_file: str) -> tuple[np.ndarray[Any, Any], np.ndarray[Any, np.dtype[np.uint8]]]:
        """Load a set of images and labels.

        This method:

        1. Reads both the image and label IDX files
        2. Appends a trailing channel axis to the images ([N, H, W] becomes [N, H, W, 1])
        3. Casts the images to the dtype chosen at construction

        Pixel values are not rescaled: they keep the numeric values stored in the file.

        Args:
            images_file: Name of the IDX file containing images
            labels_file: Name of the IDX file containing labels

        Returns:
            Tuple of (images, labels) where:

            - images has the configured dtype and the file's shape plus a trailing axis of size 1
            - labels is returned exactly as read from the labels file
        """
        # Load raw data from IDX files
        images = self._read_idx_file(self.__path / images_file)
        labels = self._read_idx_file(self.__path / labels_file)

        # Images need to be:
        # - Reshaped to [N, H, W, C] format (adding channel dimension)
        # - Converted to the chosen dtype
        images = np.expand_dims(images, -1).astype(self.__dtype)

        return images, labels

    @override
    def __call__(self) -> RawDataModel:
        """Load and prepare the complete dataset.

        This method:

        1. Loads both training and test sets
        2. Packages them in Qualia's data structures

        The MNIST datasets use specific file names:

        - train-images-idx3-ubyte: Training images
        - train-labels-idx1-ubyte: Training labels
        - t10k-images-idx3-ubyte: Test images
        - t10k-labels-idx1-ubyte: Test labels

        Returns:
            RawDataModel containing a training set and a test set. In each set:

            - Images have a trailing channel axis and the dtype chosen at construction,
              with pixel values as stored in the file (no rescaling is applied)
            - Labels are returned as read from the labels file
        """
        logger.info('Loading MNIST-style dataset from %s', self.__path)

        # Load training and test sets
        train_x, train_y = self._load_data('train-images-idx3-ubyte',
                                           'train-labels-idx1-ubyte')
        test_x, test_y = self._load_data('t10k-images-idx3-ubyte',
                                         't10k-labels-idx1-ubyte')

        # Log shapes to verify loading was correct
        logger.info('Shapes: train_x=%s, train_y=%s, test_x=%s, test_y=%s',
                    train_x.shape, train_y.shape, test_x.shape, test_y.shape)

        # Package everything in Qualia's containers
        return RawDataModel(
            sets=RawDataSets(
                train=RawData(train_x, train_y),
                test=RawData(test_x, test_y),
            ),
            name=self.name,
        )
[docs]
class MNIST(MNISTBase):
    """Original MNIST dataset of handwritten digits.

    The MNIST database holds 70,000 grayscale images of handwritten digits (0-9).
    Every image is 28x28 pixels and centered, which reduces preprocessing and gives better results.

    Dataset split:

    - 60,000 training images
    - 10,000 test images

    Labels:

    - 0-9: Corresponding digits
    """
[docs]
class FashionMNIST(MNISTBase):
    """Fashion MNIST dataset of clothing items.

    A drop-in replacement for MNIST made of 70,000 grayscale images of clothing items.
    Every image is 28x28 pixels and follows the same format as the original MNIST.

    Dataset split:

    - 60,000 training images
    - 10,000 test images

    Labels:

    - 0: T-shirt/top
    - 1: Trouser
    - 2: Pullover
    - 3: Dress
    - 4: Coat
    - 5: Sandal
    - 6: Shirt
    - 7: Sneaker
    - 8: Bag
    - 9: Ankle boot
    """