Source code for qualia_plugin_snn.postprocessing.OperationCounter

"""Provide the OperationCounter postprocessing module based on Lemaire et al., 2022."""

from __future__ import annotations

import dataclasses
import logging
import math
import sys
from dataclasses import dataclass
from typing import Literal, NamedTuple

from qualia_core.typing import TYPE_CHECKING, ModelConfigDict
from qualia_core.utils.logger import Logger
from qualia_core.utils.logger.CSVFormatter import CSVFormatter

from qualia_plugin_snn.experimenttracking.QualiaDatabase import QualiaDatabase
from qualia_plugin_snn.learningmodel.pytorch.SNN import SNN

from .EnergyEstimationMetric import EnergyEstimationMetric

# We are inside a TYPE_CHECKING block but our custom TYPE_CHECKING constant triggers TCH001-TCH003 so ignore them
if TYPE_CHECKING:
    from qualia_codegen_core.graph import ModelGraph
    from qualia_core.qualia import TrainResult
    from torch.types import Number

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

logger = logging.getLogger(__name__)


[docs] class OperationCounterLoggerFields(NamedTuple): """Interface object for CSV logging. Should contain the same fields as :class:`OperationMetrics` and returned by :meth:`OperationMetrics.asnamedtuple`. :param name: Layer name :param syn_acc: Number of accumulate operations for synaptic computation :param syn_mac: Number of multiply-accumulate operations for synaptic computation :param addr_acc: Number of accumulate operations for addressing :param addr_mac: Number of multiply-accumulate operations for addressing :param total_acc: Total number of accumulate operations :param total_mac: Total number of multiply-accumulate operations :param mem_read: Number of memory write operations :param mem_write: Number of memory read operations :param input_spikerate: Average input spike rate per timestep :param output_spikerate: Average output spike rate per timestep :param input_count: Input count per timestep :param output_count: Output count per timestep :param input_is_binary: If input tensor only contains binary values, i.e., spikes :param output_is_binary: If output tensor only contains binary values, i.e., spikes :param is_sj: If the layer is a SpikingJelly layer and has been processed as part of a Spiking Neural Network """ #: Layer name name: str #: Number of accumulate operations for synaptic computation syn_acc: float #: Number of multiply-accumulate operations for synaptic computation syn_mac: float #: Number of accumulate operations for addressing addr_acc: float #: Number of multiply-accumulate operations for addressing addr_mac: float #: Total number of accumulate operations total_acc: float #: Total number of multiply-accumulate operations total_mac: float #: Number of memory write operations mem_write: float #: Number of memory read operations mem_read: float #: Average input spike rate per timestep input_spikerate: float | None #: Average output spike rate per timestep output_spikerate: float | None #: Input count per timestep input_count: Number | None #: Output count per timestep output_count: Number | None #: If input tensor only contains binary values, i.e., spikes input_is_binary: bool #: If output tensor only contains binary values, i.e., spikes output_is_binary: bool #: If the layer is a SpikingJelly layer and has been processed as part of a Spiking Neural Network is_sj: bool | Literal['Hybrid']
[docs] @dataclass class OperationMetrics: """Holds the computed average operations per inference for each layer. :param name: Layer name :param syn_acc: Number of accumulate operations for synaptic computation :param syn_mac: Number of multiply-accumulate operations for synaptic computation :param addr_acc: Number of accumulate operations for addressing :param addr_mac: Number of multiply-accumulate operations for addressing :param mem_read: Number of memory write operations :param mem_write: Number of memory read operations :param input_spikerate: Average input spike rate per timestep :param output_spikerate: Average output spike rate per timestep :param input_count: Input count per timestep :param output_count: Output count per timestep :param input_is_binary: If input tensor only contains binary values, i.e., spikes :param output_is_binary: If output tensor only contains binary values, i.e., spikes :param is_sj: If the layer is a SpikingJelly layer and has been processed as part of a Spiking Neural Network """ #: Layer name name: str #: Number of accumulate operations for synaptic computation syn_acc: float #: Number of multiply-accumulate operations for synaptic computation syn_mac: float #: Number of accumulate operations for addressing addr_acc: float #: Number of multiply-accumulate operations for addressing addr_mac: float #: Number of memory write operations mem_write: float #: Number of memory read operations mem_read: float #: Average input spike rate per timestep input_spikerate: float | None #: Average output spike rate per timestep output_spikerate: float | None #: Input count per timestep input_count: Number | None #: Output count per timestep output_count: Number | None #: If input tensor only contains binary values, i.e., spikes input_is_binary: bool #: If output tensor only contains binary values, i.e., spikes output_is_binary: bool #: If the layer is a SpikingJelly layer and has been processed as part of a Spiking Neural Network is_sj: bool | Literal['Hybrid'] @property def total_acc(self) -> float: """Total count of accumulate operations.""" return self.syn_acc + self.addr_acc @property def total_mac(self) -> float: """Total count of accumulate operations.""" return self.syn_mac + self.addr_mac
[docs] def asnamedtuple(self) -> OperationCounterLoggerFields: """Return the data from this class as a NamedTuple for use with the CSV logger. Instanciate a :class:`OperationCounterLoggerFields` object with all of this class fields and properties and return it. :return: the :class:`OperationCounterLoggerFields` with all data from this object copied into it """ return OperationCounterLoggerFields(**dataclasses.asdict(self), total_acc=self.total_acc, total_mac=self.total_mac)
[docs] def asdict(self) -> dict[str, str | Number | None]: """Return the data from this class as a dictionary. :return: a dictionary with each attribute and property of this dataclass as keys and the associated values """ return {**dataclasses.asdict(self), 'total_acc': self.total_acc, 'total_mac': self.total_mac}
[docs] class OperationCounter(EnergyEstimationMetric): r"""Operation counter metric. From `An Analytical Estimation of Spiking Neural Networks Energy Efficiency <https://arxiv.org/abs/2210.13107>`_, Lemaire et al. ICONIP2022. .. code-block:: bibtex @inproceedings{EnergyEstimationMetricICONIP2022, title = {An Analytical Estimation of Spiking Neural Networks Energy Efficiency}, author = {Lemaire, Edgar and Cordone, Loïc and Castagnetti, Andrea and Novac, Pierre-Emmanuel and Courtois, Jonathan and Miramond, Benoît}, booktitle = {Proceedings of the 29th International Conference on Neural Information Processing}, pages = {574--587}, year = {2023}, doi = {10.1007/978-3-031-30105-6_48}, series = {ICONIP}, } Supports sequential (non-residual) formal and spiking convolutional neural networks with the following layers: * :class:`torch.nn.Conv1d` * :class:`torch.nn.Conv2d` * :class:`torch.nn.Linear` * :class:`torch.nn.ReLU` for formal neural networks * :class:`spikingjelly.activation_based.neuron.IFNode` for spiking neural networks * :class:`spikingjelly.activation_based.neuron.LIFNode` for spiking neural networks """
[docs] def __init__(self, total_spikerate_exclude_nonbinary: bool = True) -> None: # noqa: FBT001, FBT002 """Construct :class:`qualia_plugin_snn.postprocessing.OperationCounter.OperationCounter`. :param total_spikerate_exclude_nonbinary: If True, exclude non-binary inputs/outputs from total spikerate computation """ super().__init__(mem_width=0, fifo_size=1, total_spikerate_exclude_nonbinary=total_spikerate_exclude_nonbinary)
[docs] def _compute_model_operations_fnn(self, modelgraph: ModelGraph) -> list[OperationMetrics]: """Compute the operations per inference for each layer of a formal neural network. Supports the following layers: * :class:`qualia_codegen_core.graph.layers.TConvLayer.TConvLayer` * :class:`qualia_codegen_core.graph.layers.TDenseLayer.TDenseLayer` * :class:`qualia_codegen_core.graph.layers.TAddLayer.TAddLayer` :meta public: :param modelgraph: Model to compute energy on :return: A list of OperationMetrics for each layer and a total with fields populated with operation estimation """ from qualia_codegen_core.graph.layers import TAddLayer, TConvLayer, TDenseLayer, TFlattenLayer oms: list[OperationMetrics] = [] for node in modelgraph.nodes: if isinstance(node.layer, TFlattenLayer): # Flatten is assumed to not do anything for on-target inference om = OperationMetrics(name=node.layer.name, syn_acc=0, syn_mac=0, addr_acc=0, addr_mac=0, mem_read=0, mem_write=0, input_spikerate=None, output_spikerate=None, input_count=None, output_count=None, input_is_binary=False, output_is_binary=False, is_sj=False) elif isinstance(node.layer, TConvLayer): om = OperationMetrics(name=node.layer.name, syn_mac=self._mac_ops_conv_fnn(node.layer), syn_acc=self._acc_ops_conv_fnn(node.layer), addr_mac=self._mac_addr_conv_fnn(node.layer), addr_acc=self._acc_addr_conv_fnn(node.layer), mem_read=(self._rdin_conv_fnn(node.layer) + self._rdweights_conv_fnn(node.layer) + self._rdbias_conv_fnn(node.layer)), mem_write=self._wrout_conv_fnn(node.layer), input_spikerate=None, output_spikerate=None, input_count=None, output_count=None, input_is_binary=False, output_is_binary=False, is_sj=False) elif isinstance(node.layer, TDenseLayer): om = OperationMetrics(name=node.layer.name, syn_mac=self._mac_ops_fc_fnn(node.layer), syn_acc=self._acc_ops_fc_fnn(node.layer), addr_mac=self._mac_addr_fc_fnn(node.layer), addr_acc=self._acc_addr_fc_fnn(node.layer), mem_read=(self._rdin_fc_fnn(node.layer) + self._rdweights_fc_fnn(node.layer) + self._rdbias_fc_fnn(node.layer)), mem_write=self._wrout_fc_fnn(node.layer), input_spikerate=None, output_spikerate=None, input_count=None, output_count=None, input_is_binary=False, output_is_binary=False, is_sj=False) elif isinstance(node.layer, TAddLayer): # Assume element-wise addition of inputs, meaning we need to read inputs and write outputs om = OperationMetrics(name=node.layer.name, syn_mac=0, syn_acc=self._acc_ops_add_fnn(node.layer), addr_mac=self._mac_addr_add_fnn(node.layer), addr_acc=self._acc_addr_add_fnn(node.layer), mem_read=self._rdin_add_fnn(node.layer), mem_write=self._wrout_add_fnn(node.layer), input_spikerate=None, output_spikerate=None, input_count=None, output_count=None, input_is_binary=False, output_is_binary=False, is_sj=False) else: logger.warning('%s skipped, result may be inaccurate', node.layer.name) continue oms.append(om) om_total = OperationMetrics(name='Total', syn_acc=sum(om.syn_acc for om in oms), syn_mac=sum(om.syn_mac for om in oms), addr_acc=sum(om.addr_acc for om in oms), addr_mac=sum(om.addr_mac for om in oms), mem_read=sum(om.mem_read for om in oms), mem_write=sum(om.mem_write for om in oms), input_spikerate=None, output_spikerate=None, input_count=None, output_count=None, input_is_binary=False, output_is_binary=False, is_sj=False) oms.append(om_total) return oms
[docs] def _compute_model_operations_snn(self, # noqa: PLR0913, C901, PLR0912, PLR0915 modelgraph: ModelGraph, input_spikerates: dict[str, float], output_spikerates: dict[str, float], input_is_binary: dict[str, bool], output_is_binary: dict[str, bool], input_counts: dict[str, Number], output_counts: dict[str, Number], is_module_sj: dict[str, bool], timesteps: int) -> list[OperationMetrics]: """Compute the operations per inference for each layer of a spiking neural network. Supports the following layers: * :class:`qualia_codegen_core.graph.layers.TConvLayer.TConvLayer` * :class:`qualia_codegen_core.graph.layers.TDenseLayer.TDenseLayer` Input spike rates are per-timestep, this function multiplies by the number of timesteps to get the spike rates per infernce which are used by the operation count functions. :meta public: :param modelgraph: Model to computer operations on :param input_spikerate: Dict of layer names and average spike per input per timestep for the layer :param output_spikerate: Dict of layer names and average spike per output per timestep for the layer :param input_is_binary: Dict of layer names and whether its input is binary (spike) or not :param output_is_binary: Dict of layer names and whether its output is binary (spike) or not :param input_counts: Dict of layer names and number of inputs for the layer :param output_counts: Dict of layer names and number of outputs for the layer :param is_module_sj: Whether the layer is a spiking layer (a SpikingJelly module) :param timesteps: Number of timesteps :return: A list of OperationMetrics for each layer and a total with fields populated with operation count """ from qualia_codegen_core.graph.layers import TAddLayer, TConvLayer, TDenseLayer, TFlattenLayer, TInputLayer from qualia_codegen_plugin_snn.graph.layers import TIfLayer, TLifLayer oms: list[OperationMetrics] = [] for node in modelgraph.nodes: # Skip dummy input layer if isinstance(node.layer, TInputLayer): continue # TIfLayer is completely hidden in case the previous layer is Conv/Dense since it already contains the required info if isinstance(node.layer, TIfLayer) and len(node.innodes) > 0 and isinstance(node.innodes[0].layer, (TConvLayer, TDenseLayer)): continue # Account for timesteps here since the spikerate has been averaged over timesteps input_spikerate = input_spikerates[node.layer.name] * timesteps # If no If activation, no output spikes are generated so no reset operation or writing to the output queue output_spikerate = (output_spikerates[node.layer.name] * timesteps if len(node.outnodes) > 0 and isinstance(node.outnodes[0].layer, TIfLayer) else 0) leak = len(node.outnodes) > 0 and isinstance(node.outnodes[0].layer, TLifLayer) om: OperationMetrics | None = None if isinstance(node.layer, TFlattenLayer): # Flatten is assumed to not do anything for on-target inference om = OperationMetrics(name=node.layer.name, syn_acc=0, syn_mac=0, addr_acc=0, addr_mac=0, mem_read=0, mem_write=0, input_spikerate=input_spikerates[node.layer.name], output_spikerate=output_spikerates[node.layer.name], input_count=input_counts[node.layer.name], output_count=output_counts[node.layer.name], input_is_binary=input_is_binary[node.layer.name], output_is_binary=output_is_binary[node.layer.name], is_sj=is_module_sj[node.layer.name]) elif is_module_sj[node.layer.name]: is_sj: bool | Literal['Hybrid'] if isinstance(node.layer, TConvLayer): if not input_is_binary[node.layer.name]: # Non-binary dense input: # Computed as sparse input over a single timestep but with MAC operations for membrane potentials increment syn_acc = (self._acc_ops_conv_fnn(node.layer) # Bias + output_spikerate * math.prod(node.layer.output_shape[0][1:])) # Reset syn_mac = (self._mac_ops_conv_fnn(node.layer) * input_spikerates[node.layer.name] # Input * Weight MACs + (math.prod(node.layer.output_shape[0][1:]) if leak else 0)) # Leak addr_acc = self._acc_addr_conv_snn(node.layer, input_spikerates[node.layer.name]) addr_mac = self._mac_addr_conv_snn(node.layer, input_spikerates[node.layer.name]) mem_read = (self._rdin_snn(node.layer, input_spikerates[node.layer.name]) + self._rdweights_conv_snn(node.layer, input_spikerate) + (self._rdbias_conv_snn(node.layer, timesteps) if node.layer.use_bias else 0) + self._rdpot_conv_snn(node.layer, input_spikerate, timesteps)) is_sj = 'Hybrid' else: syn_acc = (self._acc_ops_conv_snn(node.layer, input_spikerate, output_spikerate, timesteps)) syn_mac = (self._mac_ops_conv_snn(node.layer, timesteps, leak)) addr_acc = self._acc_addr_conv_snn(node.layer, input_spikerate) addr_mac = self._mac_addr_conv_snn(node.layer, input_spikerate) mem_read = (self._rdin_snn(node.layer, input_spikerate) + self._rdweights_conv_snn(node.layer, input_spikerate) + (self._rdbias_conv_snn(node.layer, timesteps) if node.layer.use_bias else 0) + self._rdpot_conv_snn(node.layer, input_spikerate, timesteps)) is_sj = True om = OperationMetrics(name=node.layer.name, syn_acc=syn_acc, syn_mac=syn_mac, addr_acc=addr_acc, addr_mac=addr_mac, mem_read=mem_read, mem_write = (self._wrout_snn(node.layer, output_spikerate) + self._wrpot_conv_snn(node.layer, input_spikerate, timesteps)), input_spikerate=input_spikerates[node.layer.name], output_spikerate=output_spikerates[node.layer.name], input_count=input_counts[node.layer.name], output_count=output_counts[node.layer.name], input_is_binary=input_is_binary[node.layer.name], output_is_binary=output_is_binary[node.layer.name], is_sj=is_sj, ) elif isinstance(node.layer, TDenseLayer): if not input_is_binary[node.layer.name]: # Non-binary dense input: # Computed as sparse input over a single timestep but with MAC operations for membrane potentials increment syn_acc = (self._acc_ops_fc_fnn(node.layer) # Bias + output_spikerate * math.prod(node.layer.output_shape[0][1:])) # Reset syn_mac = (self._mac_ops_fc_fnn(node.layer) * input_spikerates[node.layer.name] # Input * Weight MACs + (math.prod(node.layer.output_shape[0][1:]) if leak else 0)) # Leak addr_acc = self._acc_addr_fc_snn(node.layer, input_spikerates[node.layer.name]) addr_mac = self._mac_addr_fc_snn(node.layer) mem_read = (self._rdin_snn(node.layer, input_spikerates[node.layer.name]) + self._rdweights_fc_snn(node.layer, input_spikerate) + (self._rdbias_fc_snn(node.layer, timesteps) if node.layer.use_bias else 0) + self._rdpot_fc_snn(node.layer, input_spikerate, timesteps)) is_sj = 'Hybrid' else: syn_acc = (self._acc_ops_fc_snn(node.layer, input_spikerate, output_spikerate, timesteps)) syn_mac = (self._mac_ops_fc_snn(node.layer, timesteps, leak)) addr_acc = self._acc_addr_fc_snn(node.layer, input_spikerate) addr_mac = self._mac_addr_fc_snn(node.layer) mem_read = (self._rdin_snn(node.layer, input_spikerate) + self._rdweights_fc_snn(node.layer, input_spikerate) + (self._rdbias_fc_snn(node.layer, timesteps) if node.layer.use_bias else 0) + self._rdpot_fc_snn(node.layer, input_spikerate, timesteps)) is_sj = True om = OperationMetrics(name=node.layer.name, syn_acc=syn_acc, syn_mac=syn_mac, addr_acc=addr_acc, addr_mac=addr_mac, mem_read=mem_read, mem_write = (self._wrout_snn(node.layer, output_spikerate) + self._wrpot_fc_snn(node.layer, input_spikerate, timesteps)), input_spikerate=input_spikerates[node.layer.name], output_spikerate=output_spikerates[node.layer.name], input_count=input_counts[node.layer.name], output_count=output_counts[node.layer.name], input_is_binary=input_is_binary[node.layer.name], output_is_binary=output_is_binary[node.layer.name], is_sj=is_sj, ) elif isinstance(node.layer, TAddLayer): om = OperationMetrics(name=node.layer.name, syn_acc=0, syn_mac=0, addr_acc=0, addr_mac=0, mem_read=0, mem_write=0, input_spikerate=input_spikerates[node.layer.name], output_spikerate=output_spikerates[node.layer.name], input_count=input_counts[node.layer.name], output_count=output_counts[node.layer.name], input_is_binary=input_is_binary[node.layer.name], output_is_binary=output_is_binary[node.layer.name], is_sj=is_module_sj[node.layer.name], ) else: # noqa: PLR5501 keep separate if for clarity and consistency if isinstance(node.layer, TConvLayer): om = OperationMetrics(name=node.layer.name, syn_acc=self._acc_ops_conv_fnn(node.layer), syn_mac=self._mac_ops_conv_fnn(node.layer), addr_acc=self._acc_addr_conv_fnn(node.layer), addr_mac=self._mac_addr_conv_fnn(node.layer), mem_read=(self._rdin_conv_fnn(node.layer) + self._rdweights_conv_fnn(node.layer) + (self._rdbias_conv_fnn(node.layer) if node.layer.use_bias else 0)), mem_write=self._wrout_conv_fnn(node.layer), input_spikerate=input_spikerates[node.layer.name], output_spikerate=output_spikerates[node.layer.name], input_count=input_counts[node.layer.name], output_count=output_counts[node.layer.name], input_is_binary=input_is_binary[node.layer.name], output_is_binary=output_is_binary[node.layer.name], is_sj=is_module_sj[node.layer.name], ) elif isinstance(node.layer, TDenseLayer): om = OperationMetrics(name=node.layer.name, syn_acc=self._acc_ops_fc_fnn(node.layer), syn_mac=self._mac_ops_fc_fnn(node.layer), addr_acc=self._acc_addr_fc_fnn(node.layer), addr_mac=self._mac_addr_fc_fnn(node.layer), mem_read=(self._rdin_fc_fnn(node.layer) + self._rdweights_fc_fnn(node.layer) + (self._rdbias_fc_fnn(node.layer) if node.layer.use_bias else 0)), mem_write=self._wrout_fc_fnn(node.layer), input_spikerate=input_spikerates[node.layer.name], output_spikerate=output_spikerates[node.layer.name], input_count=input_counts[node.layer.name], output_count=output_counts[node.layer.name], input_is_binary=input_is_binary[node.layer.name], output_is_binary=output_is_binary[node.layer.name], is_sj=is_module_sj[node.layer.name], ) elif isinstance(node.layer, TAddLayer): om = OperationMetrics(name=node.layer.name, syn_acc=self._acc_ops_add_fnn(node.layer), syn_mac=self._mac_ops_add_fnn(node.layer), addr_acc=self._acc_addr_add_fnn(node.layer), addr_mac=self._mac_addr_add_fnn(node.layer), mem_read=self._rdin_add_fnn(node.layer), mem_write=self._wrout_add_fnn(node.layer), input_spikerate=input_spikerates[node.layer.name], output_spikerate=output_spikerates[node.layer.name], input_count=input_counts[node.layer.name], output_count=output_counts[node.layer.name], input_is_binary=input_is_binary[node.layer.name], output_is_binary=output_is_binary[node.layer.name], is_sj=is_module_sj[node.layer.name], ) if om is None: # We do not know how to handle this layer, set energy values to 0 logger.warning('%s not handled, result may be inaccurate', node.layer.name) om = OperationMetrics(name=node.layer.name, syn_acc=0, syn_mac=0, addr_acc=0, addr_mac=0, mem_read=0, mem_write=0, input_spikerate=input_spikerates[node.layer.name], output_spikerate=output_spikerates[node.layer.name], input_count=input_counts[node.layer.name], output_count=output_counts[node.layer.name], input_is_binary=input_is_binary[node.layer.name], output_is_binary=output_is_binary[node.layer.name], is_sj=is_module_sj[node.layer.name]) oms.append(om) total_is_sj: bool | Literal['Hybrid'] = (True if all(is_sj is True for is_sj in is_module_sj.values()) else False if all(not is_sj for is_sj in is_module_sj.values()) else 'Hybrid') om_total = OperationMetrics(name='Total', syn_acc=sum(om.syn_acc for om in oms), syn_mac=sum(om.syn_mac for om in oms), addr_acc=sum(om.addr_acc for om in oms), addr_mac=sum(om.addr_mac for om in oms), mem_read=sum(om.mem_read for om in oms), mem_write=sum(om.mem_write for om in oms), input_spikerate=input_spikerates['__TOTAL__'], output_spikerate=output_spikerates['__TOTAL__'], input_count=input_counts['__TOTAL__'], output_count=output_counts['__TOTAL__'], input_is_binary=input_is_binary[modelgraph.nodes[1].layer.name], # First layer after 'input' output_is_binary=output_is_binary[modelgraph.nodes[-1].layer.name], is_sj=total_is_sj, ) oms.append(om_total) return oms
[docs] def _operations_summary(self, oms: list[OperationMetrics]) -> str: """Generate a human-friendly text summary of the operations per layer. :meta public: :param oms: List of OperationMetrics per layer and the total :return: The text summary """ pad = 11 pad_name = max(len(om.name) for om in oms) header = f'{"Layer": <{pad_name}} |' header += f' {"Syn. Acc": <{pad}} | {"Syn. MAc": <{pad}} |' header += f' {"Addr. Acc": <{pad}} | {"Addr. MAc": <{pad}} |' header += f' {"Tot. Acc": <{pad}} | {"Tot. MAc": <{pad}} |' header += f' {"Mem. read": <{pad}} | {"Mem. write": <{pad}} |' header += f' {"SNN": <{pad}} | {"Input Spike Rate": <{pad}} | {"Output Spike Rate": <{pad}}\n' s = '—' * len(header) + '\n' s += header s += '—' * len(header) + '\n' for i, om in enumerate(oms): # Print in nJ, original values are in pJ s += f'{om.name: <{pad_name}.{pad_name}} |' s += f' {float(om.syn_acc): <{pad}.4} | {float(om.syn_mac): <{pad}.4} |' s += f' {float(om.addr_acc): <{pad}.4} | {float(om.addr_mac): <{pad}.4} |' s += f' {float(om.total_acc): <{pad}.4} | {float(om.total_mac): <{pad}.4} |' s += f' {float(om.mem_read): <{pad}.4} | {float(om.mem_write): <{pad}.4} |' s += f' {om.is_sj!s: <{pad}} |' input_spikerate_pad = 16 if om.input_is_binary else 11 if om.input_spikerate is not None: s += f' {om.input_spikerate: <{input_spikerate_pad}.4}' else: s += f' {"N/A": <{input_spikerate_pad}.16}' if not om.input_is_binary: s += ' (NB)' s += ' |' output_spikerate_pad = 16 if om.output_is_binary else 11 if om.output_spikerate is not None: s += f' {om.output_spikerate: <{output_spikerate_pad}.4}' else: s += f' {"N/A": <{output_spikerate_pad}.16}' if not om.output_is_binary: s += ' (NB)' s += '\n' s += ('-' if i < len(oms) - 2 else '—') * len(header) + '\n' s += ' NB = Non binary data' return s
[docs] @override def __call__(self, trainresult: TrainResult, model_conf: ModelConfigDict) -> tuple[TrainResult, ModelConfigDict]: """Compute operation count metric from Lemaire et al, 2022. First process the model to extract the graph and activity in case of SNN using :meth:`_process_model`. Then call either :meth:`_compute_model_operations_snn` or :meth:`_compute_model_operations_fnn` depending on whether the model is an SNN or an FNN. Print the resulting metrics and log them to a CSV file inside the `logs/<bench.name>/OperationCounter` directory. :meta public: :param trainresult: TrainResult containing the SNN or FNN model, the dataset and the training configuration :param model_conf: Unused :return: The unmodified trainresult """ (modelgraph, input_spikerates, output_spikerates, input_is_binary, output_is_binary, input_counts, output_counts, is_module_sj) = self._process_model(trainresult=trainresult) if modelgraph is None: return trainresult, model_conf if getattr(trainresult.model, 'is_snn', False) or isinstance(trainresult.model, SNN): if (input_spikerates is None or output_spikerates is None or input_is_binary is None or output_is_binary is None or input_counts is None or output_counts is None or is_module_sj is None): logger.error('SNN model detected but one of the SNN metric could not be computed') raise RuntimeError oms = self._compute_model_operations_snn(modelgraph, input_spikerates, output_spikerates, input_is_binary, output_is_binary, input_counts, output_counts, is_module_sj, trainresult.model.timesteps) else: oms = self._compute_model_operations_fnn(modelgraph) operationcsvlogger: Logger[OperationCounterLoggerFields] = Logger(name='OperationCounter', suffix=f'_{trainresult.name}.csv', formatter=CSVFormatter()) operationcsvlogger.fields = OperationCounterLoggerFields for om in oms: operationcsvlogger(om.asnamedtuple()) logger.info(('Estimated operation count for one inference and spike rate per neuron per timestep:\n%s'), self._operations_summary(oms)) if trainresult.experimenttracking and isinstance(trainresult.experimenttracking, QualiaDatabase): if not trainresult.model_hash: logger.error('Missing model hash, cannot record OperationCounter in QualiaDatabase') else: trainresult.experimenttracking.log_operationcounter(trainresult.model_hash, oms) return trainresult, model_conf