Source code for qualia_core.dataset.CIFAR100

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Literal

from qualia_core.datamodel.RawDataModel import RawDataShape

from .CIFAR10 import CIFAR, CIFARFile

logger = logging.getLogger(__name__)


@dataclass
class CIFAR100File(CIFARFile):
    """File record for CIFAR-100, adding two label lists to :class:`CIFARFile`.

    The inherited fields are defined by ``CIFARFile`` in the sibling
    ``CIFAR10`` module and are not repeated here.
    """

    # NOTE(review): presumably the superclass-level labels of CIFAR-100,
    # one entry per sample — confirm against the dataset files.
    coarse_labels: list[int]
    # NOTE(review): presumably the class-level labels of CIFAR-100,
    # one entry per sample — confirm against the dataset files.
    fine_labels: list[int]
class CIFAR100(CIFAR):
    """CIFAR-100 dataset, configured on top of the generic :class:`CIFAR` base.

    The constructor only fixes the CIFAR-100 specific settings (file names,
    data shapes and file record class) and delegates everything else to
    ``CIFAR.__init__``.
    """

    def __init__(self,
                 path: str,
                 dtype: str = 'float32',
                 labels_field: Literal['coarse_labels', 'fine_labels'] = 'fine_labels') -> None:
        """Configure the base class for CIFAR-100.

        Args:
            path: Forwarded unchanged to ``CIFAR.__init__``.
            dtype: Forwarded unchanged to ``CIFAR.__init__``; defaults to
                ``'float32'``.
            labels_field: Either ``'coarse_labels'`` or ``'fine_labels'``
                (default), forwarded unchanged to ``CIFAR.__init__``.
                NOTE(review): presumably names the ``CIFAR100File`` attribute
                used as labels — confirm in the base class.
        """
        super().__init__(
            path=path,
            dtype=dtype,
            labels_field=labels_field,
            # A single 'train' entry and a single 'test' entry are passed as file lists.
            train_files=['train'],
            test_files=['test'],
            # Separate RawDataShape instances are built for train and test,
            # so the two splits never share one shape object.
            train_shapes=RawDataShape(x=(None, 32, 32, 3), y=(None,)),
            test_shapes=RawDataShape(x=(None, 32, 32, 3), y=(None,)),
            file_cls=CIFAR100File,
        )