from __future__ import annotations
import itertools
from typing import Any, NamedTuple
import colorful as cf # type: ignore[import-untyped]
from qualia_core.qualia import gen_tag, instantiate_model, prepare_deploy
from qualia_core.typing import TYPE_CHECKING
from qualia_core.utils.logger import CSVLogger, Logger
if TYPE_CHECKING:
from types import ModuleType # noqa: TCH003
from qualia_core.datamodel.RawDataModel import RawDataModel # noqa: TCH001
from qualia_core.learningframework.LearningFramework import LearningFramework # noqa: TCH001
from qualia_core.postprocessing.Converter import Converter # noqa: TCH001
from qualia_core.typing import ConfigDict
from qualia_core.utils.plugin import QualiaComponent # noqa: TCH001
[docs]
class PrepareDeployLoggerFields(NamedTuple):
name: str
quantize: str
optimize: str
compress: int
[docs]
class PrepareDeploy:
def __call__(self, # noqa: PLR0913
qualia: QualiaComponent,
learningframework: LearningFramework[Any],
converter: type[Converter[Any]],
deployers: ModuleType,
data: RawDataModel,
config: ConfigDict) -> dict[str, Logger[PrepareDeployLoggerFields]]:
loggers: dict[str, Logger[PrepareDeployLoggerFields]] = {}
log: CSVLogger[PrepareDeployLoggerFields] = CSVLogger('prepare_deploy')
loggers['prepare_deploy'] = log
# Write column names
log.fields = PrepareDeployLoggerFields
for i in range(config['bench']['first_run'], config['bench']['last_run']+1):
for m, q, o, c in itertools.product(config['model'],
config['deploy']['quantize'],
config['deploy'].get('optimize', ['']),
config['deploy'].get('compress', [1])):
if m.get('disabled', False):
continue
# Postprocessings can change model name
model_name = m['name']
for postprocessing in config.get('postprocessing', []):
# Workaround tomlkit bug where some nested dict would lose their items
ppp = {k: v for k,v in postprocessing.get('params', {}).items()}
pp = getattr(qualia.postprocessing, postprocessing['kind'])(**ppp)
model_name = pp.process_name(model_name)
learningframework = pp.process_framework(learningframework)
# Instantiate model
model = instantiate_model(dataset=data.sets.test,
framework=learningframework,
model=getattr(learningframework.learningmodels, m['kind']),
model_params=m.get('params', {}),
model_name=model_name,
iteration=i,
load=False, # Model params will be loaded after postprocessings
)
# Postprocessings can change model topology with PyTorch, needs to be done after instantiating model with new name
for postprocessing in config.get('postprocessing', []):
pp = getattr(qualia.postprocessing, postprocessing['kind'])(**postprocessing.get('params', {}))
model, m = pp.process_model(model, m, framework=learningframework)
# Show model architecture
learningframework.summary(model)
# Load weights after topology optionally changed
model = learningframework.load(f'{model_name}_r{i}', model)
print(f'{cf.bold}Preparing {cf.blue}{model_name}{cf.close_fg_color}, run {cf.red}{i}{cf.reset}')
r = prepare_deploy(
datamodel=data,
model_kind=m['kind'],
model_name=model_name,
model=model,
framework=learningframework,
iteration=i,
quantize=q,
optimize=o,
compress=c,
deploy_target=config['deploy']['target'],
tag=gen_tag(model_name, q, o, i, c),
converter=converter,
converter_params=config['deploy'].get('converter', {}).get('params', {}),
deployers=deployers,
deployer_params=config['deploy'].get('deployer', {}).get('params', {}),
representative_dataset=data.sets.train.x)
if not r:
continue
log(PrepareDeployLoggerFields(name=model_name, quantize=q, optimize=o, compress=c))
return loggers