qualia_plugin_snn.postprocessing.FuseBatchNorm module
Provide a post-processing module that fuses BatchNorm layers into the preceding convolution, with support for SpikingJelly layers.
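Folding an inference-mode BatchNorm into the preceding convolution rescales each output channel. Assume the standard BatchNorm formula y = gamma * (x - mean) / sqrt(var + eps) + beta, evaluated with the running statistics. The fused parameters are then W' = W * gamma / sqrt(var + eps) and b' = (b - mean) * gamma / sqrt(var + eps) + beta.

The pure-Python sketch below checks this identity on a tiny single-input-channel 1D convolution. The helpers `conv1d`, `batchnorm` and `fuse_conv_bn` are hypothetical. They are not part of qualia_core or qualia_plugin_snn, which operate on traced PyTorch models rather than on plain lists.

```python
import math


def conv1d(x, weights, biases):
    """Valid 1D cross-correlation of a single-channel signal (hypothetical helper).

    weights[c] is the kernel of output channel c and biases[c] its bias.
    Returns one list of outputs per channel.
    """
    k = len(weights[0])
    return [[sum(w[j] * x[i + j] for j in range(k)) + b
             for i in range(len(x) - k + 1)]
            for w, b in zip(weights, biases)]


def batchnorm(channels, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BatchNorm using running statistics, applied per channel."""
    return [[g * (v - m) / math.sqrt(s + eps) + bt for v in ch]
            for ch, g, bt, m, s in zip(channels, gamma, beta, mean, var)]


def fuse_conv_bn(weights, biases, gamma, beta, mean, var, eps=1e-5):
    """Fold BatchNorm parameters into the convolution weights and biases."""
    fused_w, fused_b = [], []
    for w, b, g, bt, m, s in zip(weights, biases, gamma, beta, mean, var):
        scale = g / math.sqrt(s + eps)          # per-channel scale factor
        fused_w.append([wi * scale for wi in w])
        fused_b.append((b - m) * scale + bt)    # shifted and rescaled bias
    return fused_w, fused_b


# Two output channels, kernel size 3, arbitrary illustrative values.
x = [0.5, -1.0, 2.0, 3.0, -0.5]
weights = [[0.2, -0.1, 0.4], [1.0, 0.5, -0.3]]
biases = [0.1, -0.2]
gamma, beta = [1.5, 0.8], [0.0, 0.3]
mean, var = [0.2, -0.1], [0.5, 2.0]

reference = batchnorm(conv1d(x, weights, biases), gamma, beta, mean, var)
fused = conv1d(x, *fuse_conv_bn(weights, biases, gamma, beta, mean, var))
max_err = max(abs(a - b) for ra, rb in zip(reference, fused) for a, b in zip(ra, rb))
print(f"max abs difference: {max_err:.2e}")
```

The identity holds only for BatchNorm in evaluation mode, where mean and var are fixed running statistics. In training mode they depend on the batch and cannot be folded into constant weights.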
- class qualia_plugin_snn.postprocessing.FuseBatchNorm.GraphModuleStepModule
Bases:
GraphModule, StepModule
Mixin of torch.fx.graph_module.GraphModule and spikingjelly.activation_based.base.StepModule.
Used by FuseBatchNorm to return a module that still inherits from spikingjelly.activation_based.base.StepModule. This is required so that SpikingJelly's network-wise operations, such as spikingjelly.activation_based.functional.set_step_mode(), work as expected.
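The sketch below illustrates why the mixin matters. It replaces torch.nn.Module, torch.fx.GraphModule, SpikingJelly's StepModule and set_step_mode() with minimal stand-ins, all of which are hypothetical. The stand-in set_step_mode() updates only the modules that are instances of StepModule. SpikingJelly's actual function is more involved, so read this only as a model of the isinstance-based behaviour that the mixin preserves.

```python
class Module:
    """Stand-in for torch.nn.Module (hypothetical, illustration only)."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

    def modules(self):
        # Yield self, then every descendant, like torch.nn.Module.modules().
        yield self
        for child in self.children:
            yield from child.modules()


class GraphModule(Module):
    """Stand-in for torch.fx.GraphModule."""


class StepModule:
    """Stand-in for spikingjelly.activation_based.base.StepModule."""
    step_mode = "s"


class Neuron(Module, StepModule):
    """Stand-in spiking neuron, which is itself a StepModule."""


class GraphModuleStepModule(GraphModule, StepModule):
    """Same mixin idea as the documented class: a graph module that is a StepModule."""


def set_step_mode(net, mode):
    """Simplified network-wise walker (assumption, not SpikingJelly's code).

    Updates every StepModule instance and returns the names it touched.
    """
    updated = []
    for m in net.modules():
        if isinstance(m, StepModule):
            m.step_mode = mode
            updated.append(m.name)
    return updated


plain = GraphModule("traced", [Neuron("lif")])
mixed = GraphModuleStepModule("traced", [Neuron("lif")])
print(set_step_mode(plain, "m"))  # ['lif']: the top-level graph is skipped
print(set_step_mode(mixed, "m"))  # ['traced', 'lif']
print(mixed.step_mode)            # m
```

Without the mixin, the traced top-level module is no longer a StepModule, even though its neuron children still are.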
- class qualia_plugin_snn.postprocessing.FuseBatchNorm.FuseBatchNorm
Bases:
FuseBatchNorm
Extend qualia_core.postprocessing.FuseBatchNorm.FuseBatchNorm with support for Spiking Neural Networks.
spikingjelly.activation_based.neuron.BaseNode and spikingjelly.activation_based.base.StepModule are added to qualia_core.postprocessing.FuseBatchNorm.FuseBatchNorm.custom_layers so that tracing does not descend into them.
SpikingJelly-wrapped layers are added to the lookup patterns:
- (spikingjelly.activation_based.layer.Conv1d, spikingjelly.activation_based.layer.BatchNorm1d)
- (spikingjelly.activation_based.layer.Conv2d, spikingjelly.activation_based.layer.BatchNorm2d)
- (spikingjelly.activation_based.layer.Conv3d, spikingjelly.activation_based.layer.BatchNorm3d)
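Each lookup pattern pairs a convolution type with the BatchNorm type that may directly follow it. The sketch below shows that matching idea on a plain sequence of layers. The layer classes are empty stand-ins named after the SpikingJelly layers, and the function `find_fusable_pairs` is hypothetical. The real post-processing works on a traced torch.fx graph rather than on a list.

```python
# Empty stand-in classes (hypothetical); the real patterns use the
# spikingjelly.activation_based.layer Conv*/BatchNorm* classes.
class Conv1d: pass
class BatchNorm1d: pass
class Conv2d: pass
class BatchNorm2d: pass
class Conv3d: pass
class BatchNorm3d: pass
class LIFNode: pass  # stand-in spiking neuron, never fused

PATTERNS = [
    (Conv1d, BatchNorm1d),
    (Conv2d, BatchNorm2d),
    (Conv3d, BatchNorm3d),
]


def find_fusable_pairs(layers, patterns=PATTERNS):
    """Return index pairs (i, i + 1) where layers[i] is a convolution and
    layers[i + 1] is the matching BatchNorm of one of the patterns."""
    pairs = []
    for i in range(len(layers) - 1):
        for conv_t, bn_t in patterns:
            if isinstance(layers[i], conv_t) and isinstance(layers[i + 1], bn_t):
                pairs.append((i, i + 1))
                break  # at most one pattern can match a given position
    return pairs


net = [Conv2d(), BatchNorm2d(), LIFNode(), Conv2d(), LIFNode(), Conv1d(), BatchNorm1d()]
print(find_fusable_pairs(net))  # [(0, 1), (5, 6)]
```

A Conv2d followed by a neuron is left alone, and the dimensionality of the convolution and the BatchNorm must agree for a pair to match.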
The extended attributes timesteps and is_snn are copied into the target model.
- extra_custom_layers: tuple[type[Module | StepModule], ...] = (<class 'spikingjelly.activation_based.neuron.BaseNode'>, <class 'spikingjelly.activation_based.base.StepModule'>)
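The attribute copy described above is needed because a freshly traced graph module does not, in general, carry over arbitrary Python attributes of its source module. A minimal sketch of such a copy follows. It uses types.SimpleNamespace objects in place of real models, and the helper name `copy_snn_attributes` is hypothetical. The choice to skip attributes that are missing on the source is also an assumption.

```python
from types import SimpleNamespace


def copy_snn_attributes(source, target, names=("timesteps", "is_snn")):
    """Copy SNN-specific attributes from the source model to the fused model.

    Attributes missing on the source are skipped (an assumption made for
    this sketch, not documented library behaviour).
    """
    for name in names:
        if hasattr(source, name):
            setattr(target, name, getattr(source, name))
    return target


source = SimpleNamespace(timesteps=4, is_snn=True, other="not copied")
target = copy_snn_attributes(source, SimpleNamespace())
print(vars(target))  # {'timesteps': 4, 'is_snn': True}
```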
- __init__(evaluate: bool = True) → None
Construct qualia_plugin_snn.postprocessing.FuseBatchNorm.FuseBatchNorm.
Patterns are extended to include SpikingJelly-wrapped layers.
- Parameters:
evaluate (bool)
- Return type:
None
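The constructor extends the base class's patterns and leaf layers, and the sketch below models that subclassing pattern. Both classes (`BaseFuseBatchNorm`, `SNNFuseBatchNorm`) are hypothetical stand-ins, and the layer types are written as string labels instead of real PyTorch and SpikingJelly classes. The actual qualia_core attribute layout may differ.

```python
class BaseFuseBatchNorm:
    """Hypothetical stand-in for the qualia_core base class."""
    extra_custom_layers: tuple = ()

    def __init__(self, evaluate: bool = True) -> None:
        self.evaluate = evaluate
        # Plain PyTorch (Conv, BatchNorm) pairs, written as labels for brevity.
        self.patterns = [(f"nn.Conv{d}d", f"nn.BatchNorm{d}d") for d in (1, 2, 3)]
        # Layers treated as opaque leaves during tracing; subclasses add to this.
        self.custom_layers = tuple(self.extra_custom_layers)


class SNNFuseBatchNorm(BaseFuseBatchNorm):
    """Sketch of an SNN subclass: extra leaf layers and extra patterns."""
    extra_custom_layers = ("neuron.BaseNode", "base.StepModule")

    def __init__(self, evaluate: bool = True) -> None:
        super().__init__(evaluate=evaluate)
        # Append SpikingJelly-wrapped equivalents after the base patterns.
        self.patterns += [(f"layer.Conv{d}d", f"layer.BatchNorm{d}d") for d in (1, 2, 3)]


fuser = SNNFuseBatchNorm()
print(len(fuser.patterns))  # 6
print(fuser.custom_layers)  # ('neuron.BaseNode', 'base.StepModule')
```

Declaring the extra leaf layers as a class attribute lets the base constructor pick them up through self, so the subclass only needs to call super().__init__() and append its own patterns.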
- fuse(model: Module, graphmodule_cls: type[GraphModule], framework: PyTorch, inplace: bool = False) → GraphModule
Fuse BatchNorm into Conv and copy the timesteps and is_snn attributes of the source model to the target model.
- Parameters:
model (Module) – PyTorch model with Conv-BatchNorm layers to fuse
graphmodule_cls (type[GraphModule])
framework (PyTorch)
inplace (bool) – Modify the model in place instead of deep-copying it
- Returns:
Resulting model with BatchNorm fused to Conv
- Return type: