Source code for qualia_core.command.Train

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any, NamedTuple

import colorful as cf  # type: ignore[import-untyped]

from qualia_core.qualia import train
from qualia_core.typing import TYPE_CHECKING
from qualia_core.utils.logger import CSVLogger, Logger

if TYPE_CHECKING:
    from qualia_core.dataaugmentation.DataAugmentation import DataAugmentation  # noqa: TCH001
    from qualia_core.datamodel.RawDataModel import RawDataModel  # noqa: TCH001
    from qualia_core.learningframework.LearningFramework import LearningFramework  # noqa: TCH001
    from qualia_core.typing import ConfigDict
    from qualia_core.utils.plugin import QualiaComponent  # noqa: TCH001

logger = logging.getLogger(__name__)

class LearningModelLoggerFields(NamedTuple):
    """Column layout of one row of the ``learningmodel`` CSV log written by :class:`Train`.

    The field order defines the CSV column order.
    """

    # Name of the model, taken from the training/postprocessing result.
    name: str
    # Run index (iterates from bench.first_run to bench.last_run).
    i: int
    # Parameter count reported by the training result.
    params: int
    # Parameter memory reported by the training result (units as reported upstream — confirm).
    mem_params: int
    # Accuracy reported by the training result (its ``acc`` attribute).
    accuracy: float
class Train:
    """Train (or load) every enabled model listed in the configuration.

    For each run index from ``config['bench']['first_run']`` to
    ``config['bench']['last_run']`` (inclusive) and each entry of
    ``config['model']`` not marked ``disabled``, the model is trained or loaded
    through :func:`qualia_core.qualia.train`. The result is appended to the
    ``learningmodel`` CSV log, then each entry of ``config['postprocessing']``
    is applied in order.
    """

    def __call__(self,
                 qualia: QualiaComponent,
                 learningframework: LearningFramework[Any],
                 dataaugmentations: list[DataAugmentation],
                 data: RawDataModel,
                 config: ConfigDict) -> dict[str, Logger[LearningModelLoggerFields]]:
        """Run training for all configured models and runs.

        :param qualia: Component registry; its ``postprocessing`` attribute provides postprocessing classes.
        :param learningframework: Framework providing ``learningmodels`` and ``experimenttrackings``.
        :param dataaugmentations: Data augmentations forwarded to :func:`train`.
        :param data: Dataset forwarded to :func:`train`.
        :param config: Configuration; reads ``bench`` and ``model``, and optionally
            ``experimenttracking`` and ``postprocessing``.
        :return: Loggers keyed by name; currently only ``'learningmodel'``.
        :raise ModuleNotFoundError: If a model ``kind`` is not found in ``learningframework.learningmodels``.
        """
        loggers: dict[str, Logger[LearningModelLoggerFields]] = {}
        log: CSVLogger[LearningModelLoggerFields] = CSVLogger('learningmodel')
        loggers['learningmodel'] = log
        log.fields = LearningModelLoggerFields  # Write column names

        experimenttracking = config.get('experimenttracking', None)

        (Path('out') / 'learningmodel').mkdir(parents=True, exist_ok=True)

        for i in range(config['bench']['first_run'], config['bench']['last_run'] + 1):
            for m in config['model']:
                if m.get('disabled', False):
                    continue

                if m.get('load', False):
                    print(f'{cf.bold}Loading {cf.blue}{m["name"]}{cf.close_fg_color}, run {cf.red}{i}{cf.reset}')
                else:
                    print(f'{cf.bold}Training {cf.blue}{m["name"]}{cf.close_fg_color}, run {cf.red}{i}{cf.reset}')
                # Variable must stay named `m`: the `=` specifier prints the expression text.
                print(f'{cf.bold}Params:{cf.reset} {m=}')

                # Resolve the model class before starting experiment tracking.
                # An unknown model kind then raises without leaving a started tracker behind.
                model = self._resolve_model(learningframework, m['kind'], config)

                et = self._start_experiment_tracking(learningframework, experimenttracking, config, m, i)
                try:
                    trainresult = train(datamodel=data,
                                        train_epochs=m.get('epochs', 0),
                                        iteration=i,
                                        model_name=m['name'],
                                        model=model,
                                        model_params=m.get('params', {}),
                                        batch_size=m.get('batch_size', None),
                                        optimizer=m.get('optimizer', None),
                                        framework=learningframework,
                                        load=m.get('load', False),
                                        train=m.get('train', True),
                                        evaluate=m.get('evaluate', True),
                                        dataaugmentations=dataaugmentations,
                                        experimenttracking=et,
                                        use_test_as_valid=config['bench'].get('use_test_as_valid', False))
                finally:
                    # Stop tracking even when training raises, so the tracker is not left open.
                    if et is not None:
                        et.stop()

                self._log_result(log, trainresult, i)
                self._postprocess(qualia, config, log, trainresult, m, i)

        return loggers

    @staticmethod
    def _resolve_model(learningframework: LearningFramework[Any], model_kind: str, config: ConfigDict) -> Any:
        """Return the learning model named ``model_kind`` from ``learningframework.learningmodels``.

        :raise ModuleNotFoundError: If no attribute named ``model_kind`` exists there.
        """
        model = getattr(learningframework.learningmodels, model_kind, None)
        if model is None:
            logger.error("Could not load model.kind '%s' from learningmodels '%s'.",
                         model_kind,
                         learningframework.learningmodels.__name__)
            logger.error("Did you load the necessary plugins (loaded: %s) and use the correct learningframework.kind (in use: '%s')?",
                         config['bench'].get('plugins', []),
                         learningframework.__class__.__name__)
            raise ModuleNotFoundError(
                f"model.kind '{model_kind}' not found in learningmodels '{learningframework.learningmodels.__name__}'")
        return model

    @staticmethod
    def _start_experiment_tracking(learningframework: LearningFramework[Any],
                                   experimenttracking: Any,
                                   config: ConfigDict,
                                   m: Any,
                                   i: int) -> Any:
        """Instantiate, start and annotate the configured experiment tracker.

        :return: The started tracker, or ``None`` when ``experimenttracking`` is ``None``.
        """
        if experimenttracking is None:
            return None
        et = getattr(learningframework.experimenttrackings,
                     experimenttracking['kind'])(**experimenttracking.get('params', {}))
        et.start()
        et.hyperparameters = {'config': config, 'model': m, 'i': i}
        return et

    @staticmethod
    def _log_result(log: CSVLogger[LearningModelLoggerFields], trainresult: Any, i: int) -> None:
        """Append one ``LearningModelLoggerFields`` row built from ``trainresult`` for run ``i``."""
        log(LearningModelLoggerFields(name=trainresult.name,
                                      i=i,
                                      params=trainresult.params,
                                      mem_params=trainresult.mem_params,
                                      accuracy=trainresult.acc))

    @classmethod
    def _postprocess(cls,
                     qualia: QualiaComponent,
                     config: ConfigDict,
                     log: CSVLogger[LearningModelLoggerFields],
                     trainresult: Any,
                     m: Any,
                     i: int) -> None:
        """Apply each ``config['postprocessing']`` entry in order.

        Each postprocessing receives the previous ``trainresult``/model config and returns updated ones.
        A row is logged when the returned result has ``log`` set.
        The model is exported when the entry sets ``export``.
        """
        for postprocessing in config.get('postprocessing', []):
            ppp = {k: v for k, v in postprocessing.get('params', {}).items()}  # noqa: C416 Workaround tomlkit bug where some nested dict would lose their items
            trainresult, m = getattr(qualia.postprocessing, postprocessing['kind'])(**ppp)(
                trainresult=trainresult,
                model_conf=m)
            if trainresult.log:
                cls._log_result(log, trainresult, i)
            if postprocessing.get('export', False):
                trainresult.framework.export(trainresult.model, f'{trainresult.name}_r{i}')