Source code for qualia_core.postprocessing.PostProcessing

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Callable, Generic

from qualia_core.learningframework.LearningFramework import LearningFramework, T
from qualia_core.typing import TYPE_CHECKING, ModelConfigDict

if TYPE_CHECKING:
    from qualia_core.qualia import TrainResult  # noqa: TC001

[docs] class PostProcessing(ABC, Generic[T]): @abstractmethod def __call__(self, trainresult: TrainResult, model_conf: ModelConfigDict) -> tuple[TrainResult, ModelConfigDict]: ...
[docs] def process_name(self, name: str) -> str: return name
[docs] def process_framework(self, framework: LearningFramework[T]) -> LearningFramework[Any]: return framework
[docs] def process_mem_params(self, mem_params: int) -> Callable[[LearningFramework[T], T], int]: def f(_: LearningFramework[T], __: T) -> int: return mem_params return f
[docs] def process_model(self, model: T, model_conf: ModelConfigDict, framework: LearningFramework[T]) -> tuple[T, ModelConfigDict]: return model, model_conf