qualia_plugin_snn.postprocessing.FuseBatchNorm module

Provide a postprocessing module that fuses BatchNorm layers into the preceding convolution, with support for SpikingJelly layers.
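The arithmetic behind the fusion can be sketched in plain Python. This is a minimal illustration, not the qualia_core implementation. It makes the following assumptions:

- The convolution is a stride-1, unpadded Conv1d.
- BatchNorm runs in evaluation mode, using its running statistics.
- The helpers conv1d, batchnorm_eval and fuse_conv_bn are hypothetical names.

For each output channel o, let s_o = gamma_o / sqrt(var_o + eps). The fused weight is then s_o * W_o, and the fused bias is s_o * (b_o - mean_o) + beta_o.

```python
import math


def conv1d(x, weight, bias):
    """Naive Conv1d (stride 1, no padding).

    x: [in_channels][length], weight: [out_channels][in_channels][kernel], bias: [out_channels].
    """
    in_ch, k = len(weight[0]), len(weight[0][0])
    length_out = len(x[0]) - k + 1
    return [
        [bias[o] + sum(weight[o][i][j] * x[i][t + j] for i in range(in_ch) for j in range(k))
         for t in range(length_out)]
        for o in range(len(weight))
    ]


def batchnorm_eval(y, gamma, beta, mean, var, eps=1e-5):
    """BatchNorm in evaluation mode: a per-channel affine transform using running statistics."""
    return [
        [gamma[c] * (v - mean[c]) / math.sqrt(var[c] + eps) + beta[c] for v in row]
        for c, row in enumerate(y)
    ]


def fuse_conv_bn(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold BatchNorm parameters into the convolution weight and bias."""
    fused_weight, fused_bias = [], []
    for o, kernels in enumerate(weight):
        scale = gamma[o] / math.sqrt(var[o] + eps)
        fused_weight.append([[w * scale for w in kernel] for kernel in kernels])
        fused_bias.append((bias[o] - mean[o]) * scale + beta[o])
    return fused_weight, fused_bias


# 2 input channels, 2 output channels, kernel size 2, input length 4
x = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 0.0]]
weight = [[[0.2, -0.1], [0.3, 0.4]], [[-0.5, 0.1], [0.0, 0.2]]]
bias = [0.1, -0.2]
gamma, beta = [1.5, 0.8], [0.05, -0.3]
mean, var = [0.2, -0.1], [0.9, 1.2]

# Conv followed by BatchNorm, versus a single fused Conv
reference = batchnorm_eval(conv1d(x, weight, bias), gamma, beta, mean, var)
fused = conv1d(x, *fuse_conv_bn(weight, bias, gamma, beta, mean, var))
```

Both paths produce the same output up to floating-point rounding. The fused model therefore runs one layer where the original ran two.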

class qualia_plugin_snn.postprocessing.FuseBatchNorm.GraphModuleStepModule

Bases: GraphModule, StepModule

Mixin of torch.fx.graph_module.GraphModule and spikingjelly.activation_based.base.StepModule.

Used by FuseBatchNorm to return a module that still inherits from spikingjelly.activation_based.base.StepModule.

This is required so that SpikingJelly’s network-wise operations such as spikingjelly.activation_based.functional.set_step_mode() work as expected.
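The effect of the mixin can be sketched with stand-in classes. This is an assumption-laden illustration, not the real code:

- GraphModule and StepModule below are stubs, not the torch or SpikingJelly classes.
- set_step_mode_on_step_modules is a hypothetical simplification of a network-wise operation that dispatches on isinstance(..., StepModule).
- SpikingJelly's actual set_step_mode() is more involved.

```python
class GraphModule:
    """Stand-in for torch.fx.GraphModule (illustration only)."""


class StepModule:
    """Stand-in loosely modeled on spikingjelly.activation_based.base.StepModule."""

    def supported_step_mode(self):
        return ('s', 'm')

    @property
    def step_mode(self):
        return self._step_mode

    @step_mode.setter
    def step_mode(self, value):
        # Reject step modes the module does not support
        if value not in self.supported_step_mode():
            raise ValueError(f'unsupported step_mode {value!r}')
        self._step_mode = value


class GraphModuleStepModule(GraphModule, StepModule):
    """Mixin: a graph module that is also recognized as a StepModule."""


def set_step_mode_on_step_modules(modules, step_mode):
    """Hypothetical helper: set step_mode on every StepModule and return the modules skipped."""
    skipped = []
    for module in modules:
        if isinstance(module, StepModule):
            module.step_mode = step_mode
        else:
            skipped.append(module)
    return skipped


plain = GraphModule()
mixed = GraphModuleStepModule()
skipped = set_step_mode_on_step_modules([plain, mixed], 'm')
```

Here the plain graph module is skipped, while the mixin instance receives step_mode 'm'.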

class qualia_plugin_snn.postprocessing.FuseBatchNorm.FuseBatchNorm

Bases: qualia_core.postprocessing.FuseBatchNorm.FuseBatchNorm

Extend qualia_core.postprocessing.FuseBatchNorm.FuseBatchNorm with support for Spiking Neural Networks.

spikingjelly.activation_based.neuron.BaseNode and spikingjelly.activation_based.base.StepModule are added to qualia_core.postprocessing.FuseBatchNorm.FuseBatchNorm.custom_layers. The tracer then records them as opaque leaf modules instead of tracing inside them.
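The effect of custom_layers on tracing can be sketched with a toy tracer. Module, StepModule and BaseNode below are stand-ins. The function trace is a hypothetical simplification of a tracer whose leaf-module check also accepts instances of custom_layers. The child named "surrogate" is only a placeholder for a neuron's internal submodule.

```python
class Module:
    """Minimal stand-in for torch.nn.Module: a name and a list of child modules."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)


class StepModule:
    """Stand-in for spikingjelly.activation_based.base.StepModule."""


class BaseNode(Module, StepModule):
    """Stand-in for a spiking neuron (spikingjelly.activation_based.neuron.BaseNode)."""


def trace(module, custom_layers):
    """Return the names of the modules the toy tracer records as opaque calls.

    Instances of custom_layers, and modules without children, are kept whole;
    every other module is traced into recursively.
    """
    if isinstance(module, custom_layers) or not module.children:
        return [module.name]
    names = []
    for child in module.children:
        names.extend(trace(child, custom_layers))
    return names


neuron = BaseNode('lif', children=[Module('surrogate')])
net = Module('net', children=[Module('conv'), Module('bn'), neuron])

without_custom = trace(net, custom_layers=())                    # ['conv', 'bn', 'surrogate']
with_custom = trace(net, custom_layers=(BaseNode, StepModule))  # ['conv', 'bn', 'lif']
```

Without custom_layers, the tracer descends into the neuron. With them, the neuron stays a single call.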

SpikingJelly-wrapped layers are added to the lookup patterns:

  • (spikingjelly.activation_based.layer.Conv1d, spikingjelly.activation_based.layer.BatchNorm1d)

  • (spikingjelly.activation_based.layer.Conv2d, spikingjelly.activation_based.layer.BatchNorm2d)

  • (spikingjelly.activation_based.layer.Conv3d, spikingjelly.activation_based.layer.BatchNorm3d)
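Why explicit SpikingJelly patterns might be needed can be sketched under one assumption: exact-type matching. SpikingJelly's wrapped layers subclass the corresponding torch.nn layers. Under exact-type matching, a wrapped Conv2d does not match a plain Conv2d pattern unless its own pair is listed.

In this sketch:

- The classes are stand-ins.
- find_fusable_pairs is a hypothetical helper that scans a flat list of layers rather than a traced graph.

```python
class Conv2d:
    """Stand-in for torch.nn.Conv2d."""


class BatchNorm2d:
    """Stand-in for torch.nn.BatchNorm2d."""


class SJConv2d(Conv2d):
    """Stand-in for spikingjelly.activation_based.layer.Conv2d (a Conv2d subclass)."""


class SJBatchNorm2d(BatchNorm2d):
    """Stand-in for spikingjelly.activation_based.layer.BatchNorm2d (a BatchNorm2d subclass)."""


class ReLU:
    """Stand-in for an unrelated layer."""


def find_fusable_pairs(layers, patterns):
    """Return indices i where (layers[i], layers[i + 1]) exactly matches a (conv, bn) pattern."""
    return [
        i for i in range(len(layers) - 1)
        if any(type(layers[i]) is conv and type(layers[i + 1]) is bn for conv, bn in patterns)
    ]


base_patterns = [(Conv2d, BatchNorm2d)]
snn_patterns = base_patterns + [(SJConv2d, SJBatchNorm2d)]

layers = [Conv2d(), BatchNorm2d(), ReLU(), SJConv2d(), SJBatchNorm2d()]
only_base = find_fusable_pairs(layers, base_patterns)  # [0]
with_snn = find_fusable_pairs(layers, snn_patterns)    # [0, 3]
```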

The extended attributes timesteps and is_snn are copied into the target model.
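The attribute copy can be sketched as a small helper. Copying is needed because a graph module rebuilt from a traced graph generally carries over only the submodules, parameters and buffers that the graph references. Plain Python attributes such as timesteps are not included. The helper copy_extended_attributes is hypothetical, and its behavior of skipping attributes missing from the source is an assumption.

```python
from types import SimpleNamespace

EXTENDED_ATTRIBUTES = ('timesteps', 'is_snn')


def copy_extended_attributes(source, target, names=EXTENDED_ATTRIBUTES):
    """Copy SNN-specific attributes from source to target, skipping any the source lacks."""
    for name in names:
        if hasattr(source, name):
            setattr(target, name, getattr(source, name))
    return target


source = SimpleNamespace(timesteps=4, is_snn=True)
target = copy_extended_attributes(source, SimpleNamespace())
```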

extra_custom_layers: tuple[type[Module | StepModule], ...] = (<class 'spikingjelly.activation_based.neuron.BaseNode'>, <class 'spikingjelly.activation_based.base.StepModule'>)

__init__(evaluate: bool = True) → None

Construct qualia_plugin_snn.postprocessing.FuseBatchNorm.FuseBatchNorm.

Patterns are extended to include SpikingJelly-wrapped layers.

Parameters:

evaluate (bool)

Return type:

None

fuse(model: Module, graphmodule_cls: type[GraphModule], framework: PyTorch, inplace: bool = False) → GraphModule

Fuse BatchNorm layers into the preceding Conv layers, and copy the timesteps and is_snn attributes of the source model to the target model.

Parameters:
  • model (Module) – PyTorch model with Conv-BatchNorm layers to fuse

  • graphmodule_cls (type[GraphModule])

  • framework (PyTorch)

  • inplace (bool) – Modify model in place instead of deep-copying

Returns:

Resulting model with BatchNorm fused to Conv

Return type:

GraphModule
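The inplace flag follows the usual copy-or-mutate convention. The sketch below assumes that inplace=False means a copy.deepcopy of the model is modified rather than the original, as the parameter description states. TinyModel and select_target are hypothetical names used only for illustration.

```python
import copy


class TinyModel:
    """Stand-in model: a list of weights plus the SNN-specific attributes."""

    def __init__(self, weights, timesteps=4, is_snn=True):
        self.weights = list(weights)
        self.timesteps = timesteps
        self.is_snn = is_snn


def select_target(model, inplace=False):
    """Return the object to modify: the model itself, or an independent deep copy."""
    return model if inplace else copy.deepcopy(model)


original = TinyModel([1.0, 2.0])

copied = select_target(original, inplace=False)
copied.weights[0] = 10.0  # only the copy changes

shared = select_target(original, inplace=True)
shared.weights[1] = 20.0  # the original itself changes
```

After these two steps, the original holds [1.0, 20.0] and the copy holds [10.0, 2.0].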