# Source code for qualia_plugin_snn.learningframework.SpikingJelly

"""Provide the SpikingJelly single-step learningframework module."""

from __future__ import annotations

import sys

import spikingjelly.activation_based.base as sjb  # type: ignore[import-untyped]
import spikingjelly.activation_based.neuron as sjn  # type: ignore[import-untyped]
import torch
import torch.fx
from qualia_core.learningframework import PyTorch
from qualia_core.typing import TYPE_CHECKING
from spikingjelly.activation_based import functional

import qualia_plugin_snn
import qualia_plugin_snn.learningmodel.pytorch

# We are inside a TYPE_CHECKING block but our custom TYPE_CHECKING constant triggers TCH001-TCH003 so ignore them
if TYPE_CHECKING:
    from types import ModuleType

    from torch import nn

    from qualia_plugin_snn.learningmodel.pytorch.SNN import SNN

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override


class SpikingJelly(PyTorch):
    """SpikingJelly single-step LearningFramework implementation extending PyTorch.

    :attr:`qualia_core.learningframework.PyTorch.PyTorch.learningmodels` are replaced by the Spiking Neural Networks
    from :mod:`qualia_plugin_snn.learningmodel.pytorch`
    """

    learningmodels: ModuleType = qualia_plugin_snn.learningmodel.pytorch
    """:mod:`qualia_plugin_snn.learningmodel.pytorch` additional learningmodels for Spiking Neural Networks.

    :meta hide-value:
    """

    class TrainerModule(PyTorch.TrainerModule):
        """SpikingJelly single-step TrainerModule implementation extending PyTorch TrainerModule."""

        #: Spiking Neural Network model used by this TrainerModule
        model: SNN

        @override
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Run the Spiking Neural Network over duplicated timesteps and average its outputs.

            SpikingJelly's ``reset_net`` is called on the model first so that neuron state from a previous
            call does not carry over. The same input ``x`` is then fed to the model once per timestep, the
            count being :attr:`qualia_plugin_snn.learningmodel.pytorch.SNN.SNN.timesteps`, and the
            per-timestep outputs are stacked, summed over the timestep axis and divided by the number of
            timesteps.

            :param x: Input data, presented unchanged to the model at every timestep
            :return: Output predictions averaged over the timesteps
            """
            snn = self.model

            # Neurons are stateful: clear them before starting a new sequence of timesteps.
            functional.reset_net(snn)

            # Single-step mode: one model call per timestep, always with the same input.
            per_timestep_outputs = [snn(x) for _ in range(snn.timesteps)]

            # Stacking puts the timesteps on a new leading axis (dim 0), which is then reduced.
            summed_outputs = torch.stack(per_timestep_outputs).sum(0)
            return summed_outputs / snn.timesteps

        @override
        def trace_model(
            self,
            model: nn.Module,
            extra_custom_layers: tuple[type[nn.Module], ...] = (),
        ) -> tuple[torch.fx.Graph, torch.fx.GraphModule]:
            """Trace ``model`` with SpikingJelly node types appended to the custom layers.

            Delegates to the parent ``trace_model`` after extending ``extra_custom_layers`` with
            SpikingJelly's ``BaseNode`` and ``StepModule``.

            :param model: Model to trace
            :param extra_custom_layers: Caller-supplied layer types, forwarded ahead of the SpikingJelly ones
            :return: The graph and graph module produced by the parent ``trace_model``
            """
            # StepModule not a subclass of nn.Module but still required to avoid parsing sj-wrapped nn layers
            sj_extra_custom_layers: tuple[type[nn.Module | sjb.StepModule], ...] = (
                sjn.BaseNode,
                sjb.StepModule,
            )

            # Caller-provided layers keep their position in front of the SpikingJelly-specific ones.
            all_custom_layers = (*extra_custom_layers, *sj_extra_custom_layers)
            return super().trace_model(model=model, extra_custom_layers=all_custom_layers)