Source code for qualia_plugin_snn.learningmodel.pytorch.layers.quantized_SNN_layers

"""Contains quantized spiking neuron implementations."""

from __future__ import annotations

import sys
from typing import Literal

import torch
from qualia_core.learningmodel.pytorch.layers.QuantizedLayer import QuantizedLayer, QuantizerActProtocol, QuantizerInputProtocol
from qualia_core.learningmodel.pytorch.Quantizer import QuantizationConfig, Quantizer, update_params
from spikingjelly.activation_based.neuron import IFNode, LIFNode  # type: ignore[import-untyped]

from .CustomNode import ATIF, SpikeFunctionSigmoid

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override


class QuantizedLIFNode(LIFNode,  # type: ignore[misc]
                       QuantizerInputProtocol,
                       QuantizerActProtocol,
                       QuantizedLayer):
    """Quantized variant of SpikingJelly's :class:`spikingjelly.activation_based.neuron.LIFNode`.

    Hyperparameters ``v_threshold``, ``v_reset`` and ``tau`` are quantized as well as membrane potential ``v``.
    """

    v: torch.Tensor
    v_threshold: float
    v_reset: float | None
    tau: float

    def __init__(self,  # noqa: PLR0913
                 quant_params: QuantizationConfig,
                 tau: float = 2.,
                 decay_input: bool = True,  # noqa: FBT001, FBT002
                 v_threshold: float = 1.,
                 v_reset: float | None = 0.,
                 detach_reset: bool = False,  # noqa: FBT002, FBT001
                 step_mode: str = 's',
                 backend: str = 'torch') -> None:
        """Construct :class:`QuantizedLIFNode`.

        For more information about spiking neuron parameters, see:
        :meth:`spikingjelly.activation_based.neuron.LIFNode.__init__`

        :param quant_params: Quantization configuration dict,
            see :class:`qualia_core.learningmodel.pytorch.Quantizer.Quantizer`
        :param tau: Membrane time constant
        :param decay_input: Whether the input will decay
        :param v_threshold: Threshold of this neurons layer
        :param v_reset: Reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset``
            after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike
        :param detach_reset: Whether detach the computation graph of reset in backward
        :param step_mode: The step mode, which can be `s` (single-step) or `m` (multi-step)
        :param backend: Backend for this neurons layer, only 'torch' is supported
        """
        self.call_super_init = True  # Support multiple inheritance from nn.Module
        super().__init__(tau=tau,
                         decay_input=decay_input,
                         v_threshold=v_threshold,
                         v_reset=v_reset,
                         detach_reset=detach_reset,
                         step_mode=step_mode,
                         backend=backend)

        # Create the quantizer instances
        quant_params_input = update_params(tensor_type='input', quant_params=quant_params)
        # The membrane potential does not work like weights: it is dynamic, so the global max has to be tracked.
        # 'input' is not used for it since it can be skipped; the potential has its own 'v' tensor type (formerly 'act')
        # which supplies its quantization parameters.
        quant_params_v = update_params(tensor_type='v', quant_params=quant_params)
        quant_params_act = update_params(tensor_type='act', quant_params=quant_params)
        self.quantizer_input = Quantizer(**quant_params_input)
        self.quantizer_v = Quantizer(**quant_params_v)
        self.quantizer_act = Quantizer(**quant_params_act)

    @property
    @override  # type: ignore[misc]
    def supported_backends(self) -> tuple[Literal['torch']]:
        """Supported step_mode and backend.

        Only torch backend is supported.

        :return: Tuple of ``'torch'``
        :raise ValueError: When :attr:`step_mode` is not ``'s'`` or ``'m'``
        """
        if self.step_mode in ['s', 'm']:
            return ('torch',)
        raise ValueError(self.step_mode)

    @override  # type: ignore[misc]
    def neuronal_charge(self, x: torch.Tensor) -> None:
        """Quantized :meth:`spikingjelly.activation_based.neuron.LIFNode.neuronal_charge`.

        Membrane potential and hyperparameters are quantized before and after computation using
        :meth:`quantize_v_and_hyperparams`.

        :param x: Input tensor
        """
        self.quantize_v_and_hyperparams()
        super().neuronal_charge(x)
        self.quantize_v_and_hyperparams()

    @override  # type: ignore[misc]
    def single_step_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Quantized :meth:`spikingjelly.activation_based.neuron.LIFNode.single_step_forward`.

        Input is (optionally) quantized.
        Membrane potential and hyperparameters are quantized before and after computation using
        :meth:`quantize_v_and_hyperparams`.

        :param x: Input tensor
        :return: Output spikes passed through the activation quantizer
        """
        self.v_float_to_tensor(x)

        x = self.quantizer_input(x)

        spike: torch.Tensor

        if self.training:
            # Quantization of v and hyperparameters happens inside the overridden neuronal_charge()
            self.neuronal_charge(x)
            spike = self.neuronal_fire()
            self.neuronal_reset(spike)
        else:
            self.quantize_v_and_hyperparams()

            # Select the evaluation kernel matching the reset mode (soft when v_reset is None) and input decay
            if self.v_reset is None:
                if self.decay_input:
                    spike, self.v = self.jit_eval_single_step_forward_soft_reset_decay_input(
                        x, self.v, self.v_threshold, self.tau)
                else:
                    spike, self.v = self.jit_eval_single_step_forward_soft_reset_no_decay_input(
                        x, self.v, self.v_threshold, self.tau)
            elif self.decay_input:
                spike, self.v = self.jit_eval_single_step_forward_hard_reset_decay_input(
                    x, self.v, self.v_threshold, self.v_reset, self.tau)
            else:
                spike, self.v = self.jit_eval_single_step_forward_hard_reset_no_decay_input(
                    x, self.v, self.v_threshold, self.v_reset, self.tau)

            # NOTE(review): assumed to belong to the evaluation branch, mirroring the quantization done before the
            # kernel call and the before/after quantization of neuronal_charge() -- confirm against upstream.
            self.quantize_v_and_hyperparams()

        return self.quantizer_act(spike)

    @override  # type: ignore[misc]
    def multi_step_forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        """Implement multi-step as loop over single-step for quantized neurons, inefficient but at least it works.

        :param x_seq: Input tensor with timesteps
        :return: Output tensor with timesteps
        """
        out = torch.zeros_like(x_seq)
        # Iterating over a tensor walks its first dimension, one entry per timestep
        for t, d in enumerate(x_seq):
            out[t] = self.single_step_forward(d)
        return out

    @property
    def reciprocal_tau(self) -> float:
        """Return 1 / tau.

        :return: 1 / tau
        """
        return 1 / self.tau

    def quantize_v_and_hyperparams(self) -> None:
        """Quantize potential and hyperparameters (v_threshold, tau, v_reset) in-place with the same quantizer at the same time.

        tau is not quantized directly, instead quantize reciprocal_tau (1 / tau) because this is what is used during inference
        to avoid division.

        :raise ZeroDivisionError: When the quantized reciprocal_tau is 0 and cannot be turned back into tau
        """
        self.v, hyperparams = self.quantizer_v(self.v,
                                               bias_tensor=self.get_hyperparams_tensor(device=self.v.device,
                                                                                       dtype=self.v.dtype))
        self.v_threshold = hyperparams[0].item()

        reciprocal_tau = hyperparams[1].item()
        if reciprocal_tau == 0:
            # Same exception type as the bare division would raise, but with an actionable message
            msg = f'Quantized reciprocal_tau is 0 (tau={self.tau}), cannot turn it back into tau'
            raise ZeroDivisionError(msg)
        self.tau = 1 / reciprocal_tau  # Turn back reciprocal_tau into tau

        if self.v_reset is not None:
            self.v_reset = hyperparams[2].item()

    def get_hyperparams_tensor(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Pack v_threshold, reciprocal_tau and optionally v_reset into the same Tensor.

        :param device: Device to create the tensor on
        :param dtype: Data type for the created tensor
        :return: New tensor with hyperparameters concatenated
        """
        if self.v_reset is None:
            return torch.tensor([self.v_threshold, self.reciprocal_tau], device=device, dtype=dtype)
        return torch.tensor([self.v_threshold, self.reciprocal_tau, self.v_reset], device=device, dtype=dtype)

    @property
    @override
    def weights_q(self) -> int | None:
        """Number of fractional part bits for the membrane potential and hyperparameters in case of fixed-point quantization.

        See :meth:`qualia_core.learningmodel.pytorch.Quantizer.Quantizer.fractional_bits`.

        :return: Fractional part bits for the membrane potential and hyperparameters or ``None`` if not applicable.
        """
        return self.quantizer_v.fractional_bits

    @property
    @override
    def weights_round_mode(self) -> str | None:
        """Rounding mode of the membrane potential and hyperparameters quantizer.

        :return: ``roundtype`` of :attr:`quantizer_v`
        """
        return self.quantizer_v.roundtype
class QuantizedIFNode(IFNode,  # type: ignore[misc]
                      QuantizerInputProtocol,
                      QuantizerActProtocol,
                      QuantizedLayer):
    """Quantized variant of SpikingJelly's :class:`spikingjelly.activation_based.neuron.IFNode`.

    Hyperparameters ``v_threshold`` and ``v_reset`` are quantized as well as membrane potential ``v``.
    """

    v: torch.Tensor
    v_threshold: float
    v_reset: float | None

    def __init__(self,  # noqa: PLR0913
                 quant_params: QuantizationConfig,
                 v_threshold: float = 1.,
                 v_reset: float | None = 0.,
                 detach_reset: bool = False,  # noqa: FBT001, FBT002
                 step_mode: str = 's',
                 backend: str = 'torch') -> None:
        """Construct :class:`QuantizedIFNode`.

        For more information about spiking neuron parameters, see:
        :meth:`spikingjelly.activation_based.neuron.IFNode.__init__`

        :param quant_params: Quantization configuration dict,
            see :class:`qualia_core.learningmodel.pytorch.Quantizer.Quantizer`
        :param v_threshold: Threshold of this neurons layer
        :param v_reset: Reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset``
            after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike
        :param detach_reset: Whether detach the computation graph of reset in backward
        :param step_mode: The step mode, which can be `s` (single-step) or `m` (multi-step)
        :param backend: Backend for this neurons layer, only 'torch' is supported
        """
        self.call_super_init = True  # Support multiple inheritance from nn.Module
        super().__init__(v_threshold=v_threshold,
                         v_reset=v_reset,
                         detach_reset=detach_reset,
                         step_mode=step_mode,
                         backend=backend)

        # Create the quantizer instances
        quant_params_input = update_params(tensor_type='input', quant_params=quant_params)
        # The membrane potential does not work like weights: it is dynamic, so the global max has to be tracked.
        # 'input' is not used for it since it can be skipped; the potential is configured with the 'v' tensor type.
        quant_params_v = update_params(tensor_type='v', quant_params=quant_params)
        quant_params_act = update_params(tensor_type='act', quant_params=quant_params)
        self.quantizer_input = Quantizer(**quant_params_input)
        self.quantizer_v = Quantizer(**quant_params_v)
        self.quantizer_act = Quantizer(**quant_params_act)

    @property
    @override  # type: ignore[misc]
    def supported_backends(self) -> tuple[Literal['torch']]:
        """Supported step_mode and backend.

        Only torch backend is supported.

        :return: Tuple of ``'torch'``
        :raise ValueError: When :attr:`step_mode` is not ``'s'`` or ``'m'``
        """
        if self.step_mode in ['s', 'm']:
            return ('torch',)
        raise ValueError(self.step_mode)

    @override  # type: ignore[misc]
    def neuronal_charge(self, x: torch.Tensor) -> None:
        """Quantized :meth:`spikingjelly.activation_based.neuron.IFNode.neuronal_charge`.

        Membrane potential and hyperparameters are quantized before and after computation using
        :meth:`quantize_v_and_hyperparams`.

        :param x: Input tensor
        """
        self.quantize_v_and_hyperparams()
        super().neuronal_charge(x)
        self.quantize_v_and_hyperparams()

    @override  # type: ignore[misc]
    def single_step_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Quantized :meth:`spikingjelly.activation_based.neuron.IFNode.single_step_forward`.

        Input is (optionally) quantized.
        Membrane potential and hyperparameters are quantized before and after computation using
        :meth:`quantize_v_and_hyperparams`.

        :param x: Input tensor
        :return: Output spikes passed through the activation quantizer
        """
        self.v_float_to_tensor(x)

        x = self.quantizer_input(x)

        spike: torch.Tensor

        if self.training:
            # Quantization of v and hyperparameters happens inside the overridden neuronal_charge()
            self.neuronal_charge(x)
            spike = self.neuronal_fire()
            self.neuronal_reset(spike)
        else:
            self.quantize_v_and_hyperparams()

            # Soft reset when v_reset is None, hard reset to v_reset otherwise
            if self.v_reset is None:
                spike, self.v = self.jit_eval_single_step_forward_soft_reset(x, self.v, self.v_threshold)
            else:
                spike, self.v = self.jit_eval_single_step_forward_hard_reset(x, self.v, self.v_threshold, self.v_reset)

            # NOTE(review): assumed to belong to the evaluation branch, mirroring the quantization done before the
            # kernel call and the before/after quantization of neuronal_charge() -- confirm against upstream.
            self.quantize_v_and_hyperparams()

        return self.quantizer_act(spike)

    @override  # type: ignore[misc]
    def multi_step_forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        """Implement multi-step as loop over single-step for quantized neurons, inefficient but at least it works.

        :param x_seq: Input tensor with timesteps
        :return: Output tensor with timesteps
        """
        out = torch.zeros_like(x_seq)
        # Iterating over a tensor walks its first dimension, one entry per timestep
        for t, d in enumerate(x_seq):
            out[t] = self.single_step_forward(d)
        return out

    def quantize_v_and_hyperparams(self) -> None:
        """Quantize potential and hyperparameters (v_threshold, v_reset) in-place with the same quantizer at the same time."""
        self.v, hyperparams = self.quantizer_v(self.v,
                                               bias_tensor=self.get_hyperparams_tensor(device=self.v.device,
                                                                                       dtype=self.v.dtype))
        self.v_threshold = hyperparams[0].item()
        # v_reset is only packed into the hyperparameters tensor when it is not None
        if self.v_reset is not None:
            self.v_reset = hyperparams[1].item()

    def get_hyperparams_tensor(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Pack v_threshold and optionally v_reset into the same Tensor.

        :param device: Device to create the tensor on
        :param dtype: Data type for the created tensor
        :return: New tensor with hyperparameters concatenated
        """
        if self.v_reset is None:
            return torch.tensor([self.v_threshold], device=device, dtype=dtype)
        return torch.tensor([self.v_threshold, self.v_reset], device=device, dtype=dtype)

    @property
    @override
    def weights_q(self) -> int | None:
        """Number of fractional part bits for the membrane potential and hyperparameters in case of fixed-point quantization.

        See :meth:`qualia_core.learningmodel.pytorch.Quantizer.Quantizer.fractional_bits`.

        :return: Fractional part bits for the membrane potential and hyperparameters or ``None`` if not applicable.
        """
        return self.quantizer_v.fractional_bits

    @property
    @override
    def weights_round_mode(self) -> str | None:
        """Rounding mode of the membrane potential and hyperparameters quantizer.

        :return: ``roundtype`` of :attr:`quantizer_v`
        """
        return self.quantizer_v.roundtype
class QuantizedATIF(ATIF,
                    QuantizerInputProtocol,
                    QuantizerActProtocol,
                    QuantizedLayer):
    """Quantized Integrate and Fire soft-reset with learnable Vth and activation scaling, based on spikingjelly."""

    def __init__(self,  # noqa: PLR0913
                 quant_params: QuantizationConfig,
                 v_threshold: float = 1.0,
                 vth_init_l: float = 0.8,
                 vth_init_h: float = 1.,
                 alpha: float = 1.,
                 device: str = 'cpu') -> None:
        """Construct :class:`QuantizedATIF`.

        :param quant_params: Quantization configuration dict,
            see :class:`qualia_core.learningmodel.pytorch.Quantizer.Quantizer`
        :param v_threshold: Factor to apply to the uniform initialization bounds
        :param vth_init_l: Lower bound for uniform initialization of threshold Tensor
        :param vth_init_h: Higher bound for uniform initialization of threshold Tensor
        :param alpha: Sigmoid surrogate scale factor
        :param device: Device to run the computation on
        """
        super().__init__(v_threshold=v_threshold,
                         vth_init_l=vth_init_l,
                         vth_init_h=vth_init_h,
                         alpha=alpha,
                         device=device)

        # Derive one configuration per tensor type first, then build the matching quantizers.
        # The membrane potential is dynamic (unlike weights) so the global max has to be tracked; it does not use
        # 'input' since that one can be skipped, and is configured with the 'v' tensor type.
        input_cfg = update_params(tensor_type='input', quant_params=quant_params)
        potential_cfg = update_params(tensor_type='v', quant_params=quant_params)
        act_cfg = update_params(tensor_type='act', quant_params=quant_params)

        self.quantizer_input = Quantizer(**input_cfg)
        self.quantizer_v = Quantizer(**potential_cfg)
        self.quantizer_act = Quantizer(**act_cfg)

    @override
    def ifsrl_fn(self, x: torch.Tensor) -> torch.Tensor:
        """Quantized :meth:`qualia_plugin_snn.learningmodel.pytorch.layers.CustomNode.ATIF.ifsrl_fn`.

        The input is quantized, the membrane potential is quantized before and after each update,
        and the threshold is quantized through :meth:`get_coeffs`.

        :param x: Input tensor
        :return: Output tensor
        """
        self.v_float_to_tensor(x)
        x = self.quantizer_input(x)

        # Charge: quantize the stored potential, integrate the input, quantize the result
        self.v = self.quantizer_v(self.quantizer_v(self.v) + x)

        # Fire: sigmoid surrogate spike function applied to the quantized distance to the threshold
        distance = self.quantizer_v(self.v - self.vp_th)
        surrogate_scale = self.alpha * torch.ones(1).to(self.device)
        z = SpikeFunctionSigmoid.apply(distance, surrogate_scale).float()

        # Soft reset: keep v where no spike was emitted, subtract the threshold where one was
        kept = (1. - z) * self.v
        lowered = z * (self.v - self.vp_th)
        self.v = self.quantizer_v(kept + lowered)

        # Scale the spikes by the quantized threshold before quantizing the activation
        coeffs = self.get_coeffs()
        return self.quantizer_act(z * coeffs)

    @override
    def get_coeffs(self) -> torch.Tensor:
        """Return the threshold :attr:`v_threshold` passed through the potential quantizer.

        :return: Quantized Tensor of threshold :attr:`v_threshold`
        """
        quantized_threshold = self.quantizer_v(self.v_threshold)
        return quantized_threshold

    @override
    def set_coeffs(self, v_threshold: torch.Tensor) -> None:
        """Quantize the given Tensor and copy it into the threshold :attr:`v_threshold`.

        :param v_threshold: New threshold Tensor, quantized before replacing the content of :attr:`v_threshold`
        """
        quantized_threshold = self.quantizer_v(v_threshold)
        # In-place copy keeps the existing v_threshold Tensor object
        self.v_threshold.copy_(quantized_threshold)

    @property
    @override
    def weights_q(self) -> int | None:
        """Fractional part bits of the potential quantizer, for fixed-point quantization.

        See :meth:`qualia_core.learningmodel.pytorch.Quantizer.Quantizer.fractional_bits`.

        :return: Fractional part bits for the membrane potential and hyperparameters or ``None`` if not applicable.
        """
        return self.quantizer_v.fractional_bits

    @property
    @override
    def weights_round_mode(self) -> str | None:
        """Rounding mode of the potential quantizer.

        :return: ``roundtype`` of :attr:`quantizer_v`
        """
        return self.quantizer_v.roundtype