Source code for qualia_core.datamodel.DataModel

from __future__ import annotations

import logging
from dataclasses import dataclass, fields
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Generic, TypeVar

if TYPE_CHECKING:
    import sys
    from collections.abc import Iterator

    if sys.version_info >= (3, 10):
        from typing import TypeGuard
    else:
        from typing_extensions import TypeGuard

    if sys.version_info >= (3, 11):
        from typing import Self
    else:
        from typing_extensions import Self

logger = logging.getLogger(__name__)

# Type of the data object stored in each set of a ``DataModel``.
T = TypeVar('T')
# Type of the data object stored in each field of a ``DataModel.Sets``.
U = TypeVar('U')


class DataModel(Generic[T]):
    """Named container of data sets (``train``, ``valid``, ``test`` by default).

    The sets are held in a :class:`DataModel.Sets` instance; each set is either
    a data object of type ``T`` or ``None`` when the set does not exist.
    """

    sets: DataModel.Sets[T]
    name: str

    @dataclass
    class Sets(Generic[U]):
        """Group of optional data sets, one per dataclass field."""

        train: U | None = None
        valid: U | None = None
        test: U | None = None

        @classmethod
        def fieldnames(cls) -> tuple[str, ...]:
            """Return the names of the declared sets, in declaration order."""
            return tuple(field.name for field in fields(cls))

        def asdict(self) -> dict[str, U | None]:
            """Return the instance attributes as a ``name -> data`` mapping.

            The returned mapping is the instance's own attribute dictionary:
            it is neither copied nor recursed into, so missing sets appear
            with a ``None`` value.
            """
            # Explicitly use vars rather than dataclasses.asdict() since we
            # don't want to recurse nor copy
            return vars(self)

        def __iter__(self) -> Iterator[tuple[str, U]]:
            """Iterate over ``(name, data)`` pairs of the sets that are not ``None``."""
            # Build the list of existing sets up front so that the iteration
            # works on a snapshot taken at call time.
            existing = [(sname, sdata) for sname, sdata in self.asdict().items() if sdata is not None]
            return iter(existing)

        def export(self, path: Path) -> None:
            """Export every existing set into its own sub-directory of ``path``.

            For each set that is not ``None``, the directory ``path/<set name>``
            is created (including parents) and passed to the set's own
            ``export()`` method. Missing sets are only reported in the log.

            :param path: Directory under which one sub-directory per set is created
            """
            for sname, sdata in self.asdict().items():
                # Compare against None, like __iter__ does, instead of relying
                # on truthiness: a data object may legitimately be falsy
                # (e.g. define __len__ == 0) or have no usable truth value.
                if sdata is None:
                    logger.info('No "%s" set to export', sname)
                    continue

                target = path / sname
                target.mkdir(parents=True, exist_ok=True)
                sdata.export(target)

    def __init__(self, name: str, sets: Sets[T] | None = None) -> None:
        """Create a data model.

        :param name: Name of the data model, also used as its directory name under ``out/data``
        :param sets: Initial sets; an empty ``Sets()`` is created when ``None``
        """
        super().__init__()
        self.sets = sets if sets is not None else self.Sets()
        self.name = name

    def __iter__(self) -> Iterator[tuple[str, T]]:
        """Iterate over ``(name, data)`` pairs of the sets that are not ``None``."""
        return iter(self.sets)

    def export(self) -> None:
        """Export all existing sets under ``out/data/<name>``."""
        self.sets.export(Path('out') / 'data' / self.name)

    @classmethod
    def _import_data_sets(cls,
                          name: str,
                          set_names: list[str],
                          importer: Callable[[Path], T | None]) -> dict[str, T] | None:
        """Import the requested sets from ``out/data/<name>/<set name>``.

        ``importer`` is called once per entry of ``set_names`` with the
        directory of that set.

        :param name: Name of the data model, used as directory name under ``out/data``
        :param set_names: Names of the sets to import
        :param importer: Callable loading one set from its directory, returning ``None`` on failure
        :return: ``set name -> data`` mapping, or ``None`` if ``importer`` returned ``None`` for any set
        """
        base_path = Path('out') / 'data' / name
        sets_dict: dict[str, T | None] = {sname: importer(base_path / sname) for sname in set_names}

        # TypeGuard helper so that static type checkers narrow the value type
        # from ``T | None`` to ``T`` once all sets are known to be present.
        def no_none_in_sets(sets_dict: dict[str, T | None]) -> TypeGuard[dict[str, T]]:
            return all(s is not None for s in sets_dict.values())

        if not no_none_in_sets(sets_dict):
            logger.error('Could not import data.')
            return None

        logger.info('Imported %s for %s', ', '.join(sets_dict.keys()), name)
        return sets_dict

    def import_sets(self,
                    set_names: list[str],
                    sets_cls: type[DataModel.Sets[T]],
                    importer: Callable[[Path], T | None]) -> None:
        """Import the requested sets and store them in :attr:`sets`.

        When the import fails (at least one set could not be imported),
        :attr:`sets` is left unchanged.

        :param set_names: Names of the sets to import, passed as keyword arguments to ``sets_cls``
        :param sets_cls: Class instantiated with the imported sets
        :param importer: Callable loading one set from its directory, returning ``None`` on failure
        """
        sets_dict = self._import_data_sets(name=self.name, set_names=set_names, importer=importer)

        if sets_dict is not None:
            self.sets = sets_cls(**sets_dict)