# Copyright 2021 (c) Pierre-Emmanuel Novac <penovac@unice.fr> Université Côte d'Azur, CNRS, LEAT. All rights reserved.
# April 29, 2021
from __future__ import annotations

import importlib
import importlib.util
import logging
import sys
from collections.abc import Iterable
from itertools import zip_longest
from typing import TYPE_CHECKING, Callable, cast

from .LayerNode import LayerNode
from .layers import TBaseLayer, TInputLayer

if TYPE_CHECKING:
    if sys.version_info >= (3, 10):
        from typing import TypeGuard
    else:
        from typing_extensions import TypeGuard

    from torch import nn  # noqa: I001 # torch must be imported before keras to avoid deadlock
    import keras.Model  # type: ignore[import-untyped] # No stubs for keras package

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

# Module-level logger; ModelGraph reports graph-consistency warnings and errors through it.
logger = logging.getLogger(__name__)
class ModelGraph:
    """Directed graph of :class:`LayerNode` objects describing the layer topology of a model.

    Each node lists its predecessors in ``innodes`` and its successors in ``outnodes``; the
    mutating methods below keep both directions of every edge in sync. The node list keeps
    insertion order (``replace_node`` keeps the replaced node's slot).
    """

    def __init__(self, nodes: list[LayerNode] | None = None) -> None:
        """Create a graph, optionally seeded with existing nodes.

        :param nodes: Initial nodes. A non-empty list is used as-is (not copied); ``None`` or an
            empty list yields a fresh empty list.
        """
        super().__init__()
        self.__nodes = nodes or []

    def add_node(self,
                 node: LayerNode,
                 innodes: Iterable[LayerNode] | None = None,
                 outnodes: Iterable[LayerNode] | None = None) -> None:
        """Append ``node`` to the graph and connect it to the given neighbours.

        Edges are recorded in both directions (``node.innodes`` / ``innode.outnodes`` and
        ``node.outnodes`` / ``outnode.innodes``); an edge already present is not duplicated.
        A warning is logged if ``node`` is already in the graph, but it is appended anyway.

        :param node: Node to add.
        :param innodes: Predecessors of ``node``. Any iterable is accepted, including one-shot
            iterators such as generators.
        :param outnodes: Successors of ``node``; same remarks as ``innodes``.
        """
        # Materialize once: each sequence is traversed twice below, and a one-shot iterator
        # (e.g. a generator) would otherwise be empty on the second pass, leaving the
        # neighbours without the reverse edge.
        innodes = list(innodes) if innodes is not None else []
        outnodes = list(outnodes) if outnodes is not None else []

        # A list rather than a set is used so that the connection order is preserved.
        for innode in innodes:
            if innode not in node.innodes:
                node.innodes.append(innode)
        for outnode in outnodes:
            if outnode not in node.outnodes:
                node.outnodes.append(outnode)

        # Mirror the edges on the neighbour side.
        for innode in innodes:
            if node not in innode.outnodes:
                innode.outnodes.append(node)
        for outnode in outnodes:
            if node not in outnode.innodes:
                outnode.innodes.append(node)

        if node in self.__nodes:
            logger.warning('Node already exists in graph')
        self.__nodes.append(node)

    def delete_node(self, node: LayerNode) -> None:
        """Remove ``node`` from the graph, bypassing it.

        Every predecessor of ``node`` gets ``node``'s successors spliced into its ``outnodes`` at
        the position ``node`` occupied, and every successor gets ``node``'s predecessors spliced
        into its ``innodes`` likewise, so input/output ordering is preserved. No deduplication is
        performed, so the bypass may create repeated entries (e.g. when a successor already
        consumed one of ``node``'s predecessors directly).

        :raises ValueError: if ``node`` is not in the graph or a neighbour does not list it back.
        """
        for innode in node.innodes:
            # Disconnect layer to remove from output of previous layer
            index = innode.outnodes.index(node)
            _ = innode.outnodes.pop(index)
            # Connect outputs from layer to remove to output of previous layer
            # Try to preserve insertion location and ordering
            for i, e in enumerate(node.outnodes):
                innode.outnodes.insert(index + i, e)
        for outnode in node.outnodes:
            # Disconnect layer to remove from input of next layer
            index = outnode.innodes.index(node)
            _ = outnode.innodes.pop(index)
            # Connect inputs from layer to remove to input of next layer
            # Try to preserve insertion location and ordering
            for i, e in enumerate(node.innodes):
                outnode.innodes.insert(index + i, e)
        self.__nodes.remove(node)  # Remove layer from list

    def delete_node_if(self, predicate: Callable[[LayerNode], bool]) -> None:
        """Delete (bypass) every node for which ``predicate`` returns true.

        The matching nodes are collected first, so the predicate sees the graph as it was
        before any deletion.
        """
        to_delete = [n for n in self.nodes if predicate(n)]
        for n in to_delete:
            self.delete_node(n)

    def replace_node(self, oldnode: LayerNode, newnode: LayerNode) -> None:
        """Substitute ``newnode`` for ``oldnode`` in the graph.

        ``newnode`` takes over all of ``oldnode``'s edges, at the same positions in every
        neighbour's ``innodes`` / ``outnodes`` list, and ``oldnode``'s slot in the node list so
        the node ordering is preserved. ``oldnode``'s edges are appended to any edges
        ``newnode`` already has. Replacing a node with itself is a no-op.

        :raises ValueError: if ``oldnode`` is not in the graph; nothing is modified in that case.
        """
        if newnode is oldnode:
            return
        # Look up first so that a missing oldnode raises before any edge is rewired.
        index = self.__nodes.index(oldnode)
        if newnode in self.__nodes:
            logger.warning('Node already exists in graph')

        # Rewire in place rather than add_node() + delete_node(): delete_node() bypasses the
        # removed node, which would add direct edges from oldnode's predecessors to its
        # successors on top of the edges through newnode.
        for innode in oldnode.innodes:
            innode.outnodes[:] = [newnode if n is oldnode else n for n in innode.outnodes]
        for outnode in oldnode.outnodes:
            outnode.innodes[:] = [newnode if n is oldnode else n for n in outnode.innodes]
        newnode.innodes.extend(oldnode.innodes)
        newnode.outnodes.extend(oldnode.outnodes)
        self.__nodes[index] = newnode

    def find_node_from_layer(self, layer: TBaseLayer) -> LayerNode | None:
        """Return the node wrapping ``layer`` (compared by identity), or ``None`` if there is none.

        If several nodes wrap the same layer, a warning is logged and the first one is returned.
        """
        nodes = [node for node in self.nodes if node.layer is layer]
        if not nodes:
            return None
        if len(nodes) > 1:
            logger.warning('More than one node found for layer, returning the first one')
        return nodes[0]

    def get_nodes_for_layers(self, layers: TBaseLayer | Iterable[TBaseLayer]) -> tuple[LayerNode | None, ...]:
        """Look up the node of one layer or of each layer in an iterable (flattened recursively).

        :return: One entry per layer, in order; ``None`` for a layer that has no node.
        """
        if isinstance(layers, Iterable):
            return tuple(n for layer in layers for n in self.get_nodes_for_layers(layer))
        return (self.find_node_from_layer(layers), )

    def no_none_in_nodes(self, nodes: Iterable[LayerNode | None]) -> TypeGuard[Iterable[LayerNode]]:
        """Type guard: return ``True`` if no element of ``nodes`` is ``None``.

        ``nodes`` is iterated, so a one-shot iterator is consumed by this call.
        """
        return all(node is not None for node in nodes)

    def add_layer(self,
                  layer: TBaseLayer,
                  inlayers: list[TBaseLayer] | None = None,
                  outlayers: list[TBaseLayer] | None = None) -> None:
        """Wrap ``layer`` in a new node and connect it to the nodes of its in/out layers.

        Does nothing if ``layer`` already has a node. ``TInputLayer`` entries of ``inlayers``
        that are not in the graph yet are added first, without connections (special case for
        Sequential models that lack an explicit input layer). If any other in/out layer has no
        node, an error is logged and ``layer`` is not added; input layers added by the special
        case above are kept.
        """
        if self.find_node_from_layer(layer) is not None:
            return
        inlayers = inlayers or []
        outlayers = outlayers or []
        # Special case for missing InputLayer in case of Sequential model
        for inlayer in inlayers:
            # If InputLayer does not exist in graph (recomputed each time: the graph may grow here)
            if isinstance(inlayer, TInputLayer) and inlayer not in [n.layer for n in self.nodes]:
                self.add_layer(inlayer)  # No input, linking to output is handled by the next add_node
        innodes = self.get_nodes_for_layers(inlayers)
        if not self.no_none_in_nodes(innodes):
            logger.error('Input node for layer %s not found', layer.name)
            return
        outnodes = self.get_nodes_for_layers(outlayers)
        if not self.no_none_in_nodes(outnodes):
            logger.error('Output node for layer %s not found', layer.name)
            return
        self.add_node(LayerNode(layer), innodes, outnodes)

    @property
    def nodes(self) -> list[LayerNode]:
        """The graph's node list itself (not a copy); mutating it mutates the graph."""
        return self.__nodes

    @override
    def __str__(self) -> str:
        """Render the graph as a fixed-width table.

        Each node takes one row per input/output/shape entry (whichever list is longest):
        input layer names, the layer name, output layer names, input shapes and output shape.
        """
        pad = 48
        header = f'{"Inputs": <{pad}} | {"Layer": <{pad}} | {"Outputs": <{pad}} | {"Input shape": <{pad}} | {"Output shape": <{pad}}\n'  # noqa: E501
        rule = '—' * len(header) + '\n'
        parts = [rule, header, rule]
        for node in self.nodes:
            rows = zip_longest(
                [n.layer.name for n in node.innodes],
                [node.layer.name],
                [n.layer.name for n in node.outnodes],
                [str(shape) for shape in node.layer.input_shape],
                [str(node.layer.output_shape)], fillvalue='')
            parts.extend(
                f'{inlayername: <{pad}} | {layername: <{pad}} | {outlayername: <{pad}} | {inshape: <{pad}} | {outshape: <{pad}}\n'  # noqa: E501
                for inlayername, layername, outlayername, inshape, outshape in rows)
            parts.append('-' * len(header) + '\n')
        return ''.join(parts)

    def graphviz(self) -> str | None:
        """Return the Graphviz DOT source of the graph, or ``None`` if ``graphviz`` is unavailable.

        One edge is emitted per (node, output node) pair, labelled by layer names; a node
        without any edge therefore does not appear in the output.
        """
        try:
            from graphviz import Digraph  # type: ignore[import-untyped] # Graphviz is missing py.typed xflr6/graphviz#180
        except ImportError:
            logger.warning('Graphviz not available')
            return None
        grph = Digraph()
        for node in self.nodes:
            for out in node.outnodes:
                grph.edge(node.layer.name, out.layer.name)
        return cast(str, grph.source)

    @classmethod
    def auto_detect(cls, obj: keras.Model | nn.Module) -> ModelGraph:
        """Build the framework-specific graph for ``obj``.

        Returns a ``TorchModelGraph`` when PyTorch is installed and ``obj`` is an
        ``nn.Module``; otherwise falls back to ``KerasModelGraph``. The framework modules are
        imported lazily so that neither framework is required to import this module.
        """
        if importlib.util.find_spec('torch') is not None:
            from torch import nn
            if isinstance(obj, nn.Module):
                from .TorchModelGraph import TorchModelGraph
                return TorchModelGraph(obj)
        from .KerasModelGraph import KerasModelGraph
        return KerasModelGraph(obj)