from __future__ import annotations
import copy
from pathlib import Path
from typing import Any, NamedTuple
import colorful as cf # type: ignore[import-untyped]
from qualia_core.typing import TYPE_CHECKING
from qualia_core.utils.merge_dict import merge_dict
if TYPE_CHECKING:
from qualia_core.dataaugmentation.DataAugmentation import DataAugmentation # noqa: TCH001
from qualia_core.datamodel.RawDataModel import RawDataModel # noqa: TCH001
from qualia_core.learningframework.LearningFramework import LearningFramework # noqa: TCH001
from qualia_core.typing import ConfigDict
from qualia_core.utils.logger import Logger # noqa: TCH001
from qualia_core.utils.plugin import QualiaComponent # noqa: TCH001
[docs]
class ParameterResearchLoggerFields(NamedTuple):
name: str
params: int
mem_params: int
accuracy: float
tune_config: str
[docs]
class ParameterResearch:
def __call__(self,
qualia: QualiaComponent,
learningframework: LearningFramework[Any],
dataaugmentations: list[DataAugmentation],
data: RawDataModel,
config: ConfigDict) -> dict[str, Logger[ParameterResearchLoggerFields]]:
import optuna
from qualia_core.qualia import train
from qualia_core.utils.logger import CSVLogger
loggers: dict[str, CSVLogger[ParameterResearchLoggerFields]] = {}
log: CSVLogger[ParameterResearchLoggerFields] = CSVLogger('parameter_research')
loggers['parameter_research'] = log
# Write column names
log.fields = ParameterResearchLoggerFields
(Path('out')/'learningmodel').mkdir(parents=True, exist_ok=True)
experimenttracking = config.get('experimenttracking', None)
if config['bench']['first_run'] != config['bench']['last_run']:
print('Warning: parameter_research does not follow first_run/last_run')
# Filter out disabled models
config['model'] = [m for m in config['model'] if not m.get('disabled', False)]
if isinstance(config['model'], list):
if len(config['model']) > 1:
raise ValueError('parameter_research only runs on a single model definition')
else:
config['model'] = config['model'][0]
if config['model'].get('disabled', False):
raise ValueError('model must not be disabled for parameter_research')
def training_function(trial):
def gen_tune_config(params): # In-place
if isinstance(params, list):
for i, v in enumerate(params):
if 'kind' in v and 'params' in v: # "leaf" to call trial function on
params[i] = getattr(trial, v['kind'])(**v['params'])
else:
gen_tune_config(i)
else:
print(params)
for k, v in params.items():
if 'kind' in v and 'params' in v: # "leaf" to call trial function on
params[k] = getattr(trial, v['kind'])(**v['params'])
else:
gen_tune_config(v)
return params
print(f'{cf.bold}Hyperparameter research for {cf.blue}{config["model"]["name"]}{cf.close_fg_color}, trial {cf.red}{trial.number}{cf.reset}')
tune_config = gen_tune_config(copy.deepcopy(config['parameter_research']['trial']))
print(f'{cf.bold}Trial params:{cf.reset} {tune_config=}')
m = copy.deepcopy(tune_config)
m = merge_dict(m, config['model'])
print(f'{cf.bold}Model params:{cf.reset} {m=}')
et = None
if experimenttracking is not None:
et = getattr(learningframework.experimenttrackings, experimenttracking['kind'])(**experimenttracking.get('params', {}))
et.hyperparameters = {'config': config, 'model': m, 'i': trial.number}
trainresult = train(datamodel=data,
train_epochs=m.get('epochs', None),
iteration=trial.number,
model_name=m["name"],
model=getattr(learningframework.learningmodels, m['kind']),
model_params=m.get('params', {}),
batch_size=m.get('batch_size', None),
optimizer=m.get('optimizer', {}),
framework=learningframework,
load=m.get('load', False),
train=m.get('train', True),
evaluate=m.get('evaluate', True),
dataaugmentations=dataaugmentations,
experimenttracking=et)
if et is not None:
et.stop()
log(ParameterResearchLoggerFields(name=trainresult.name,
params=trainresult.params,
mem_params=trainresult.mem_params,
accuracy=trainresult.acc,
tune_config=str(tune_config)))
for postprocessing in config.get('postprocessing', []):
ppp = {k: v for k,v in postprocessing.get('params', {}).items()} # Workaround tomlkit bug where some nested dict would lose their items
trainresult = getattr(qualia.postprocessing, postprocessing['kind'])(**ppp)(
trainresult=trainresult,
model_conf=m)
if trainresult.log:
log(ParameterResearchLoggerFields(name=trainresult.name,
params=trainresult.params,
mem_params=trainresult.mem_params,
accuracy=trainresult.acc,
tune_config=str(tune_config)))
if postprocessing.get('export', False):
trainresult.framework.export(trainresult.model, f'{trainresult.name}_r{i}')
return trainresult.acc
if config['parameter_research'].get('study', {}).get('load', False):
study = optuna.load_study(**config['parameter_research'].get('study', {}).get('params', {}))
else:
study = optuna.create_study(**config['parameter_research'].get('study', {}).get('params', {}))
study.optimize(training_function, **config['parameter_research'].get('optimize', {}).get('params', {}))
return loggers