Source code for norse.torch.module.lif_mc
from typing import Optional, Tuple
import numpy as np
import torch
from norse.torch.functional.lif import LIFState, LIFParameters
from norse.torch.functional.lif_mc import lif_mc_step
from norse.torch.module.snn import SNNRecurrentCell
class LIFMCRecurrentCell(SNNRecurrentCell):
r"""Computes a single euler-integration step of a LIF multi-compartment
neuron-model.
.. math::
\begin{align*}
\dot{v} &= 1/\tau_{\text{mem}} (v_{\text{leak}} \
- g_{\text{coupling}} v + i) \\
\dot{i} &= -1/\tau_{\text{syn}} i
\end{align*}
together with the jump condition
.. math::
z = \Theta(v - v_{\text{th}})
and transition equations
.. math::
\begin{align*}
v &= (1-z) v + z v_{\text{reset}} \\
i &= i + w_{\text{input}} z_{\text{in}} \\
i &= i + w_{\text{rec}} z_{\text{rec}}
\end{align*}
where :math:`z_{\text{rec}}` and :math:`z_{\text{in}}` are the
recurrent and input spikes respectively.
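
    A single forward Euler step with timestep :math:`dt` therefore updates the
    state as follows (a sketch of the discretization implied by the dynamics
    above, applied together with the threshold and transition rules):

    .. math::
        \begin{align*}
            v &= v + dt / \tau_{\text{mem}} (v_{\text{leak}} - g_{\text{coupling}} v + i) \\
            i &= i - dt / \tau_{\text{syn}} i
        \end{align*}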

    Parameters:
input_size (int): Size of the input. Also known as the number of input features.
        hidden_size (int): Size of the hidden state. Also known as the number of hidden features.
g_coupling (torch.Tensor): conductances between the neuron compartments
p (LIFParameters): neuron parameters
dt (float): Integration timestep to use
autapses (bool): Allow self-connections in the recurrence? Defaults to False.
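
    Example (an illustrative sketch; batch and feature sizes are arbitrary):
        >>> batch_size = 16
        >>> cell = LIFMCRecurrentCell(2, 4)
        >>> input_tensor = torch.randn(batch_size, 2)
        >>> spikes, state = cell(input_tensor)
        >>> spikes.shape
        torch.Size([16, 4])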
"""
def __init__(
self,
input_size: int,
hidden_size: int,
p: LIFParameters = LIFParameters(),
g_coupling: Optional[torch.Tensor] = None,
**kwargs,
):
# pytype: disable=wrong-arg-types
super().__init__(
activation=None,
state_fallback=self.initial_state,
input_size=input_size,
hidden_size=hidden_size,
p=p,
**kwargs,
)
# pytype: enable=wrong-arg-types
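        # If no coupling matrix is given, fall back to random compartment
        # conductances scaled by 1/sqrt(hidden_size), a common
        # variance-preserving initialization.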
self.g_coupling = (
g_coupling
if g_coupling is not None
else torch.nn.Parameter(
torch.randn(hidden_size, hidden_size) / np.sqrt(hidden_size)
)
)

    def initial_state(self, input_tensor: torch.Tensor) -> LIFState:
state = LIFState(
z=torch.zeros(
input_tensor.shape[0],
self.hidden_size,
device=input_tensor.device,
dtype=input_tensor.dtype,
),
v=self.p.v_leak.detach()
* torch.ones(
input_tensor.shape[0],
self.hidden_size,
device=input_tensor.device,
dtype=input_tensor.dtype,
),
i=torch.zeros(
input_tensor.shape[0],
self.hidden_size,
device=input_tensor.device,
dtype=input_tensor.dtype,
),
)
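        # Allow gradients to flow back to the initial membrane potential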
state.v.requires_grad = True
return state

    def forward(
self, input_tensor: torch.Tensor, state: Optional[LIFState] = None
) -> Tuple[torch.Tensor, LIFState]:
if state is None:
state = self.initial_state(input_tensor)
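        # Delegate the actual Euler-integration step to the functional implementation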
return lif_mc_step(
input_tensor,
state,
self.input_weights,
self.recurrent_weights,
self.g_coupling,
p=self.p,
dt=self.dt,
)
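

# Illustrative usage sketch (not part of the original module): unrolling the
# cell over a spike train. The shapes and spike probability are assumptions.
if __name__ == "__main__":
    timesteps, batch_size = 100, 16
    cell = LIFMCRecurrentCell(2, 4)
    # Surrogate Bernoulli spike train of shape (timesteps, batch, input_size)
    xs = (torch.rand(timesteps, batch_size, 2) > 0.9).float()
    state = None
    outputs = []
    for x in xs:
        z, state = cell(x, state)  # one Euler-integration step per call
        outputs.append(z)
    spikes = torch.stack(outputs)  # (timesteps, batch_size, hidden_size)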