# Source code for qualia_codegen_core.Allocator

# Copyright 2021 (c) Pierre-Emmanuel Novac <penovac@unice.fr> Université Côte d'Azur, CNRS, LEAT. All rights reserved.

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING

from .graph.layers import TFlattenLayer

if TYPE_CHECKING:
    import sys
    if sys.version_info >= (3, 11):
        from typing import Self
    else:
        from typing_extensions import Self

    from .graph import ModelGraph
    from .graph.LayerNode import LayerNode

logger = logging.getLogger(__name__)

class Allocator:
    """Assign intermediate layer outputs of a model graph to reusable memory pools.

    Calling an instance walks the graph nodes in order. Each layer output is placed in
    a pool, and a pool is reused once every consumer of the tensor currently stored in
    it has already run.
    """

    @dataclass
    class AllocInfo:
        """Allocation bookkeeping for a single layer node."""

        node: LayerNode
        # AllocInfo entries of the nodes feeding ``node`` (its inputs).
        input_ai: list[Self]
        # Index in ``ModelGraph.nodes`` of the last node consuming this node's output.
        keep_until: int
        # True for Flatten layers: their output is appended to the pool of their input
        # (shares its buffer) instead of being allocated separately.
        overwrite_input: bool

    @staticmethod
    def _build_alloc_info(modelgraph: ModelGraph) -> list[Allocator.AllocInfo] | None:
        """Return one AllocInfo per node of *modelgraph* except the last, or None on error.

        Entry ``k`` of the returned list describes ``modelgraph.nodes[k]``.
        Nodes must be ordered so that each node's inputs come before it. Otherwise the
        lookup into the partially built list raises IndexError.
        """
        nodes = modelgraph.nodes
        alloc_info_list: list[Allocator.AllocInfo] = []
        for node in nodes[:-1]:  # Last layer is allocated by the caller
            inlayersi = [alloc_info_list[nodes.index(innode)] for innode in node.innodes]
            outlayersi = [nodes.index(outnode) for outnode in node.outnodes]
            if not outlayersi:
                # Only the last node may lack consumers; max() would fail here.
                logger.error('Layer %s has no output layer, only the last layer may have no consumer', node)
                return None
            alloc_info_list.append(Allocator.AllocInfo(
                node,
                inlayersi,
                max(outlayersi),
                isinstance(node.layer, TFlattenLayer),
            ))
        return alloc_info_list

    def __call__(self, modelgraph: ModelGraph) -> dict[str, list[list[LayerNode]] | dict[LayerNode, int]] | None:
        """Allocate memory pools for the intermediate layers of *modelgraph*.

        The first node (input layer) and the last node are not placed in any pool;
        both are left to the caller. The node right after the input layer is assumed to
        read its input from outside the model and always goes into the first pool.
        Flatten layers are appended to the pool holding their single input. Every other
        layer goes into the first pool that holds none of its inputs and whose latest
        tensor has no consumer at or after this layer. If no such pool exists, a new
        pool is created.

        :param modelgraph: Graph whose ``nodes`` list inputs before the nodes using them.
        :return: ``{'pools': [[LayerNode, ...], ...], 'index': {LayerNode: pool_number}}``
            where pool numbers start at 1, or ``None`` if allocation fails (an error is
            logged).
        """
        alloc_info_list = self._build_alloc_info(modelgraph)
        if alloc_info_list is None:
            return None

        pools: list[list[Allocator.AllocInfo]] = [[]]
        # alloc_info_list[0] is the input layer and is skipped.
        # i is therefore the node's graph index minus 1.
        for i, a in enumerate(alloc_info_list[1:]):
            if i == 0:
                # First layer after the input layer: assume it takes input from outside the model
                pools[0].append(a)
            elif a.overwrite_input:
                if len(a.input_ai) != 1:
                    logger.error('Need exactly one input layer when overwriting input')
                    return None
                # Reuse the buffer of the single input: append to the pool that holds it
                inp = [p for p in pools if a.input_ai[0] in p]
                if len(inp) != 1:
                    logger.error('Input layer must be allocated in exactly one pool')
                    return None
                inp[0].append(a)
            else:
                # Candidate pools hold none of this node's inputs, so writing its output
                # there cannot clobber a tensor it still reads. This matters after a
                # Flatten, where an input may sit below the pool's last entry.
                # A node without inputs never reuses an existing pool.
                if a.input_ai:
                    candidates = [p for p in pools if all(iai not in p for iai in a.input_ai)]
                else:
                    candidates = []
                # A pool is free when its latest tensor's last consumer ran before this node
                # (keep_until is a graph index; this node's graph index is i + 1).
                free = [p for p in candidates if p[-1].keep_until <= i]
                if free:
                    # Add to first usable pool — maybe possible to optimize allocation size
                    free[0].append(a)
                else:
                    # No free pool, allocate a new one
                    pools.append([a])

        return {
            'pools': [[a.node for a in p] for p in pools],
            'index': {a.node: pool_number for pool_number, p in enumerate(pools, start=1) for a in p},
        }