Source code for qualia_plugin_snn.postprocessing.FuseBatchNorm

"""Provide a postprocessing module to fuse BatchNorm layers to the previous convolution with SpikingJelly layers support."""

from __future__ import annotations

import sys

import spikingjelly.activation_based.base as sjb  # type: ignore[import-untyped]
import spikingjelly.activation_based.layer as sjl  # type: ignore[import-untyped]
import spikingjelly.activation_based.neuron as sjn  # type: ignore[import-untyped]
from qualia_core.postprocessing.FuseBatchNorm import FuseBatchNorm as FuseBatchNormQualiaCore
from qualia_core.typing import TYPE_CHECKING
from torch.fx.graph_module import GraphModule

# We are inside a TYPE_CHECKING block but our custom TYPE_CHECKING constant triggers TCH001-TCH003 so ignore them
if TYPE_CHECKING:
    from qualia_core.learningframework.PyTorch import PyTorch
    from torch import nn

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

class GraphModuleStepModule(GraphModule, sjb.StepModule):  # type: ignore[misc]
    """Mixin of :class:`torch.fx.graph_module.GraphModule` and :class:`spikingjelly.activation_based.base.StepModule`.

    Used by :class:`FuseBatchNorm` to return a module that still inherits from
    :class:`spikingjelly.activation_based.base.StepModule`.

    This is required so that SpikingJelly's network-wise operations such as
    :func:`spikingjelly.activation_based.functional.set_step_mode` work as expected.
    """
class FuseBatchNorm(FuseBatchNormQualiaCore):
    """Extend :class:`qualia_core.postprocessing.FuseBatchNorm.FuseBatchNorm` with support for Spiking Neural Networks.

    :class:`spikingjelly.activation_based.neuron.BaseNode` and :class:`spikingjelly.activation_based.base.StepModule`
    are added to :attr:`qualia_core.postprocessing.FuseBatchNorm.FuseBatchNorm.custom_layers` to avoid tracing inside.

    SpikingJelly-wrapped layers are added to lookup patterns:

    * (:class:`spikingjelly.activation_based.layer.Conv1d`, :class:`spikingjelly.activation_based.layer.BatchNorm1d`)
    * (:class:`spikingjelly.activation_based.layer.Conv2d`, :class:`spikingjelly.activation_based.layer.BatchNorm2d`)
    * (:class:`spikingjelly.activation_based.layer.Conv3d`, :class:`spikingjelly.activation_based.layer.BatchNorm3d`)

    Extended attributes ``timesteps``, ``step_mode`` and ``is_snn`` are copied into target model when the source
    model defines them.
    """

    extra_custom_layers: tuple[type[nn.Module | sjb.StepModule], ...] = (
        sjn.BaseNode,
        sjb.StepModule,  # StepModule not a subclass of nn.Module but still required to avoid parsing sj-wrapped nn layers
    )

    def __init__(self, evaluate: bool = True) -> None:  # noqa: FBT001, FBT002
        """Construct :class:`qualia_plugin_snn.postprocessing.FuseBatchNorm.FuseBatchNorm`.

        Patterns are extended to include SpikingJelly-wrapped layers.

        :param evaluate: Forwarded unchanged to the parent class constructor
        """
        super().__init__(evaluate=evaluate)
        # Rebind to a new list instead of using ``+=``: augmented assignment extends the existing list object
        # in place, which would leak the SpikingJelly patterns into (and keep growing) any list the parent
        # class shares between instances.
        self.patterns = [
            *self.patterns,
            (sjl.Conv1d, sjl.BatchNorm1d),
            (sjl.Conv2d, sjl.BatchNorm2d),
            (sjl.Conv3d, sjl.BatchNorm3d),
        ]

    @override
    def fuse(self,
             model: nn.Module,
             graphmodule_cls: type[GraphModule],
             framework: PyTorch,
             inplace: bool = False) -> GraphModule:
        """Fuse BatchNorm to Conv and copy source model extended attributes to target model.

        The ``timesteps``, ``step_mode`` and ``is_snn`` attributes are copied from ``model`` to the fused model,
        each one only if ``model`` actually has it.

        :param model: PyTorch model with Conv-BatchNorm layers to fuse
        :param graphmodule_cls: Class used to build the resulting graph module, replaced with
            :class:`GraphModuleStepModule` when ``model`` is a :class:`spikingjelly.activation_based.base.StepModule`
        :param framework: Learning framework instance, forwarded unchanged to the parent implementation
        :param inplace: Modify model in place instead of deep-copying
        :return: Resulting model with BatchNorm fused to Conv
        """
        # The fused model must remain a StepModule so that SpikingJelly's network-wise helpers keep handling it.
        if isinstance(model, sjb.StepModule):
            graphmodule_cls = GraphModuleStepModule

        fused_model = super().fuse(model, graphmodule_cls=graphmodule_cls, framework=framework, inplace=inplace)

        # Propagate the SNN-specific attributes that the source model defines.
        for attribute in ('timesteps', 'step_mode', 'is_snn'):
            if hasattr(model, attribute):
                setattr(fused_model, attribute, getattr(model, attribute))

        return fused_model