Source code for qualia_core.utils.config

from __future__ import annotations

import logging

from pydantic import TypeAdapter, ValidationError

from qualia_core.typing import TYPE_CHECKING, ConfigDict
from qualia_core.utils.merge_dict import merge_dict

if TYPE_CHECKING:
    from pathlib import Path  # noqa: TCH003

    from qualia_core.typing import RecursiveConfigDict

logger = logging.getLogger(__name__)

def validate_config_dict(config: RecursiveConfigDict) -> ConfigDict | None:
    """Validate a raw configuration dict against the ``ConfigDict`` schema.

    The input is validated in non-strict mode through a pydantic ``TypeAdapter``.
    The top-level keys and values of the validated result are then compared with the
    input. Any key dropped or added by validation, and any top-level value that
    validation altered, is logged as an error and makes the validation fail.

    :param config: Configuration made of built-in Python types.
    :return: The validated configuration, or ``None`` if pydantic raised a
        ``ValidationError`` or if validation changed the top-level keys/values.
    """
    adapter = TypeAdapter(ConfigDict)
    try:
        validated: ConfigDict = adapter.validate_python(config, strict=False)

        # Keys present in the input but absent from the validated result
        missing_keys = [key for key in config if key not in validated]
        # Keys present in the validated result but absent from the input
        added_keys = [key for key in validated if key not in config]
        # Shared top-level keys whose value differs after validation (only top level is compared)
        changed_values = {
            key: value
            for key, value in validated.items()
            if key in config and value != config[key]
        }

        if missing_keys:
            logger.error('Missing keys after validation: %s', missing_keys)
        if added_keys:
            logger.error('Extra keys after validation: %s', added_keys)
        for key, value in changed_values.items():
            logger.error('Different value after validation for key: %s', key)
            logger.error('Before validation:\n%s', config[key])
            logger.error('After validation:\n%s', value)

        has_mismatch = bool(missing_keys or added_keys or changed_values)
    except ValidationError:
        logger.exception('Error when validating configuration.')
        return None

    # Any discrepancy between input and validated output is treated as a failure
    return None if has_mismatch else validated
def parse_config(path: Path) -> tuple[RecursiveConfigDict, str]:
    """Load a TOML configuration file and apply its model template.

    The file is parsed with ``tomlkit`` and converted to built-in Python types.
    If the configuration contains a ``model_template`` entry, each element of the
    ``model`` list is replaced by ``merge_dict(model, model_template)``. The exact
    precedence between a model's own settings and the template is defined by
    ``merge_dict`` (presumably model settings win — confirm in ``merge_dict``).

    :param path: Path to the TOML configuration file.
    :return: A tuple ``(config, name)``, where ``config`` is the parsed configuration
        and ``name`` is the file name without its suffix (``path.stem``).
    :raise KeyError: If ``model_template`` is present but ``model`` is not.
    :raise TypeError: If ``model`` is not a list, ``model_template`` is not a dict,
        or an element of ``model`` is not a dict.
    """
    import tomlkit

    # TOML documents are UTF-8 by specification; don't rely on the platform's locale encoding.
    with path.open(encoding='utf-8') as f:
        toml_config = tomlkit.parse(f.read())

    # Convert to built-in Python types
    config: RecursiveConfigDict = toml_config.unwrap()

    # Merge settings from template into individual models
    if 'model_template' in config:
        if 'model' not in config:
            logger.error('`model_template` is defined but `model` is missing')
            raise KeyError('model')

        models = config['model']
        model_template = config['model_template']

        if not isinstance(models, list):
            logger.error('`model` must be a list, got: %s', type(models))
            raise TypeError(f'`model` must be a list, got: {type(models)}')
        if not isinstance(model_template, dict):
            logger.error('`model_template` must be a dict, got: %s', type(model_template))
            raise TypeError(f'`model_template` must be a dict, got: {type(model_template)}')

        for i, model in enumerate(models):
            if not isinstance(model, dict):
                logger.error('`model[%d]` must be a dict, got: %s', i, type(model))
                raise TypeError(f'`model[{i}]` must be a dict, got: {type(model)}')
            # Replacing by index while iterating is safe: the list length does not change
            models[i] = merge_dict(model, model_template)

    return config, path.stem