Source code for norse.torch.functional.lsnn

r"""
Long short-term memory spiking neural network (LSNN) module, building on the work by
`G. Bellec, D. Salaj, A. Subramoney, R. Legenstein, and W. Maass <https://github.com/IGITUGraz/LSNN-official>`_.

The LSNN dynamics are similar to the :mod:`.lif` equations, but they
add an adaptive threshold :math:`b`:

.. math::
    \begin{align*}
        \dot{v} &= 1/\tau_{\text{mem}} (v_{\text{leak}} - v + i) \\
        \dot{i} &= -1/\tau_{\text{syn}}\; i \\
        \dot{b} &= 1/\tau_{b}\; (v_{\text{th}} - b)
    \end{align*}

The threshold :math:`b` relaxes towards :math:`v_{\text{th}}`, is raised by
:math:`\beta` every time the neuron spikes (:math:`b \leftarrow b + \beta z`),
and takes the place of the fixed threshold in the jump condition:

.. math::
    z = \Theta(v - b)

Contrast this with the regular LIF jump condition:

.. math::
    z = \Theta(v - v_{\text{th}})

In practice, this means that the LSNN neurons *adapt* to fire more or less
given the same input. How quickly the adaptation decays is determined by the
:math:`\tau_b` time constant; how strongly each spike raises the threshold is
determined by :math:`\beta`.
"""
from typing import NamedTuple, Tuple

import torch

from norse.torch.functional.threshold import threshold


class LSNNParameters(NamedTuple):
    r"""Parameters of an LSNN neuron

    Parameters:
        tau_syn_inv (torch.Tensor): inverse synaptic time constant
                                    (:math:`1/\tau_\text{syn}`)
        tau_mem_inv (torch.Tensor): inverse membrane time constant
                                    (:math:`1/\tau_\text{mem}`)
        tau_adapt_inv (torch.Tensor): inverse adaptation time constant
                                      (:math:`1/\tau_b`)
        v_leak (torch.Tensor): leak potential
        v_th (torch.Tensor): threshold potential
        v_reset (torch.Tensor): reset potential
        beta (torch.Tensor): adaptation constant, i.e. the amount added to the
                             adaptation variable :math:`b` on every spike
        method (str): name of the thresholding method forwarded to
                      ``threshold`` (presumably selects the surrogate
                      gradient; default ``"super"``)
        alpha (float): parameter forwarded to ``threshold`` together with
                       ``method``
    """

    tau_syn_inv: torch.Tensor = torch.as_tensor(1.0 / 5e-3)
    tau_mem_inv: torch.Tensor = torch.as_tensor(1.0 / 1e-2)
    # NOTE(review): 1/800 means tau_b = 800 time units, while the other time
    # constants are in the 1e-3..1e-2 range (seconds, given dt=0.001) -- confirm
    # this very slow adaptation is intended.
    tau_adapt_inv: torch.Tensor = torch.as_tensor(1.0 / 800)
    v_leak: torch.Tensor = torch.as_tensor(0.0)
    v_th: torch.Tensor = torch.as_tensor(1.0)
    v_reset: torch.Tensor = torch.as_tensor(0.0)
    beta: torch.Tensor = torch.as_tensor(1.8)
    method: str = "super"
    alpha: float = 100.0
class LSNNState(NamedTuple):
    """State of an LSNN neuron

    Parameters:
        z (torch.Tensor): recurrent spikes
        v (torch.Tensor): membrane potential
        i (torch.Tensor): synaptic input current
        b (torch.Tensor): adaptation variable (the adaptive threshold in
                          :func:`lsnn_step`, an adaptation current in
                          :func:`ada_lif_step`)
    """

    z: torch.Tensor
    v: torch.Tensor
    i: torch.Tensor
    b: torch.Tensor


def lsnn_step(
    input_tensor: torch.Tensor,
    state: LSNNState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LSNNParameters = LSNNParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LSNNState]:
    r"""Euler integration step for LIF Neuron with threshold adaptation.
    More specifically it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} (v_{\text{leak}} - v + i) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{b} &= 1/\tau_{b} (v_{\text{th}} - b)
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - b)

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + w_{\text{input}} z_{\text{in}} \\
            i &= i + w_{\text{rec}} z_{\text{rec}} \\
            b &= b + \beta z
        \end{align*}

    where :math:`z_{\text{rec}}` and :math:`z_{\text{in}}` are the recurrent
    and input spikes respectively. The adaptive threshold :math:`b` relaxes
    towards :math:`v_{\text{th}}` and is raised by :math:`\beta` after every
    spike. The recurrent current is driven by the spikes of the *previous*
    step (``state.z``). The new spikes are detached in the reset and in the
    adaptation jump, so no gradient flows through those two updates.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LSNNState): current state of the lsnn unit
        input_weights (torch.Tensor): synaptic weights for input spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        p (LSNNParameters): parameters of the lsnn unit
        dt (float): Integration timestep to use

    Returns:
        Tuple[torch.Tensor, LSNNState]: the new spikes and the updated state
    """
    # compute voltage decay
    dv = dt * p.tau_mem_inv * ((p.v_leak - state.v) + state.i)
    v_decayed = state.v + dv

    # compute current decay
    di = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + di

    # compute threshold adaptation update: b relaxes towards v_th
    db = dt * p.tau_adapt_inv * (p.v_th - state.b)
    b_decayed = state.b + db

    # compute new spikes against the adaptive threshold
    z_new = threshold(v_decayed - b_decayed, p.method, p.alpha)
    # compute reset (detached: the reset does not contribute gradients)
    v_new = (1 - z_new.detach()) * v_decayed + z_new.detach() * p.v_reset
    # compute current jumps; recurrent input uses the previous step's spikes
    i_new = (
        i_decayed
        + torch.nn.functional.linear(input_tensor, input_weights)
        + torch.nn.functional.linear(state.z, recurrent_weights)
    )

    # raise the threshold by beta for every spike (detached as well)
    b_new = b_decayed + z_new.detach() * p.beta
    return z_new, LSNNState(z_new, v_new, i_new, b_new)


def ada_lif_step(
    input_tensor: torch.Tensor,
    state: LSNNState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LSNNParameters = LSNNParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LSNNState]:
    r"""Euler integration step for LIF Neuron with adaptation. More specifically
    it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} (v_{\text{leak}} - v + i - b) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{b} &= -1/\tau_{b} b
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + w_{\text{input}} z_{\text{in}} \\
            i &= i + w_{\text{rec}} z_{\text{rec}} \\
            b &= b + \beta z
        \end{align*}

    where :math:`z_{\text{rec}}` and :math:`z_{\text{in}}` are the recurrent
    and input spikes respectively. Here :math:`b` is an adaptation current
    that decays towards zero and is subtracted from the membrane input, while
    the firing threshold stays fixed at :math:`v_{\text{th}}`.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LSNNState): current state of the lsnn unit
        input_weights (torch.Tensor): synaptic weights for input spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        p (LSNNParameters): parameters of the lsnn unit
        dt (float): Integration timestep to use

    Returns:
        Tuple[torch.Tensor, LSNNState]: the new spikes and the updated state
    """
    # compute voltage updates; the adaptation current b counteracts the input
    dv = dt * p.tau_mem_inv * ((p.v_leak - state.v) + state.i - state.b)
    v_decayed = state.v + dv

    # compute current updates
    di = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + di

    # compute adaptation updates (exponential decay towards zero)
    db = -dt * p.tau_adapt_inv * state.b
    b_decayed = state.b + db

    # compute new spikes against the fixed threshold
    z_new = threshold(v_decayed - p.v_th, p.method, p.alpha)
    # compute reset: hard reset to v_reset, as stated in the transition
    # equations above and as done by the other step functions in this module
    # (previously a reset-by-subtraction that contradicted the documentation)
    v_new = (1 - z_new) * v_decayed + z_new * p.v_reset
    # compute b update
    b_new = b_decayed + z_new * p.beta
    # compute current jumps; recurrent input uses the previous step's spikes
    i_new = (
        i_decayed
        + torch.nn.functional.linear(input_tensor, input_weights)
        + torch.nn.functional.linear(state.z, recurrent_weights)
    )

    return z_new, LSNNState(z_new, v_new, i_new, b_new)
class LSNNFeedForwardState(NamedTuple):
    """Integration state of a feed-forward LSNN unit.

    Compared with :class:`LSNNState` it has no spike field, because a
    feed-forward unit has no recurrent connections to feed spikes back into.

    Parameters:
        v (torch.Tensor): membrane potential
        i (torch.Tensor): synaptic input current
        b (torch.Tensor): threshold adaptation
    """

    # membrane potential
    v: torch.Tensor
    # synaptic input current
    i: torch.Tensor
    # adaptive threshold
    b: torch.Tensor
def lsnn_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LSNNFeedForwardState,
    p: LSNNParameters = LSNNParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LSNNFeedForwardState]:
    r"""Euler integration step for LIF Neuron with threshold adaptation.
    More specifically it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} (v_{\text{leak}} - v + i) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{b} &= 1/\tau_{b} (v_{\text{th}} - b)
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - b)

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + \text{input} \\
            b &= b + \beta z
        \end{align*}

    The adaptive threshold :math:`b` relaxes towards :math:`v_{\text{th}}`
    and is raised by :math:`\beta` after every spike. The input tensor is
    added directly to the synaptic current (no weights are applied here).

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LSNNFeedForwardState): current state of the lsnn unit
        p (LSNNParameters): parameters of the lsnn unit
        dt (float): Integration timestep to use

    Returns:
        Tuple[torch.Tensor, LSNNFeedForwardState]: the new spikes and the
        updated state
    """
    # compute voltage updates
    dv = dt * p.tau_mem_inv * ((p.v_leak - state.v) + state.i)
    v_decayed = state.v + dv

    # compute current updates
    di = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + di

    # compute threshold updates: b relaxes towards v_th
    db = dt * p.tau_adapt_inv * (p.v_th - state.b)
    b_decayed = state.b + db

    # compute new spikes against the adaptive threshold
    z_new = threshold(v_decayed - b_decayed, p.method, p.alpha)
    # compute reset (hard reset to v_reset)
    v_new = (1 - z_new) * v_decayed + z_new * p.v_reset
    # compute b update: each spike raises the threshold by beta
    b_new = b_decayed + z_new * p.beta
    # compute current jumps: the input is added unweighted
    i_new = i_decayed + input_tensor

    return z_new, LSNNFeedForwardState(v=v_new, i=i_new, b=b_new)