"""Source code for ``norse.torch.module.coba_lif``: the conductance-based LIF cell module."""
import torch
from typing import Optional, Tuple
import numpy as np
from norse.torch.functional.coba_lif import CobaLIFParameters, CobaLIFState
from norse.torch.functional.coba_lif import coba_lif_step
class CobaLIFCell(torch.nn.Module):
    r"""Module that computes a single Euler-integration step of a
    conductance-based LIF neuron model.

    More specifically, it implements one integration step of the following
    ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/c_{\text{mem}} (g_l (v_{\text{leak}} - v)
                + g_e (E_{\text{rev\_e}} - v) + g_i (E_{\text{rev\_i}} - v)) \\
            \dot{g_e} &= -1/\tau_{\text{syn}} g_e \\
            \dot{g_i} &= -1/\tau_{\text{syn}} g_i
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            g_e &= g_e + \text{relu}(w_{\text{input}}) z_{\text{in}} \\
            g_e &= g_e + \text{relu}(w_{\text{rec}}) z_{\text{rec}} \\
            g_i &= g_i + \text{relu}(-w_{\text{input}}) z_{\text{in}} \\
            g_i &= g_i + \text{relu}(-w_{\text{rec}}) z_{\text{rec}}
        \end{align*}

    where :math:`z_{\text{rec}}` and :math:`z_{\text{in}}` are the recurrent
    and input spikes respectively.

    Parameters:
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.
        p (CobaLIFParameters): Parameters of the conductance-based LIF neuron
            model.
        dt (float): Time step to use.

    Examples:
        >>> batch_size = 16
        >>> lif = CobaLIFCell(10, 20)
        >>> input_tensor = torch.randn(batch_size, 10)
        >>> output, s0 = lif(input_tensor)
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        p: CobaLIFParameters = CobaLIFParameters(),
        dt: float = 0.001,
    ):
        """Create the cell and its trainable weight matrices.

        Parameters:
            input_size (int): Size of the input.
            hidden_size (int): Size of the hidden state.
            p (CobaLIFParameters): Neuron parameters, forwarded unchanged to
                ``coba_lif_step`` on every call.
            dt (float): Integration time step, forwarded to ``coba_lif_step``.
        """
        super().__init__()
        # Random normal initialisation scaled by 1/sqrt(fan_in).
        self.input_weights = torch.nn.Parameter(
            torch.randn(hidden_size, input_size) / np.sqrt(input_size)
        )
        self.recurrent_weights = torch.nn.Parameter(
            torch.randn(hidden_size, hidden_size) / np.sqrt(hidden_size)
        )
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.p = p
        self.dt = dt

    def _initial_state(self, input_tensor: torch.Tensor) -> CobaLIFState:
        """Build an all-zero state for the batch in ``input_tensor``."""
        batch_size = input_tensor.shape[0]

        def zeros() -> torch.Tensor:
            # Each field needs its own tensor object, so allocate per call.
            return torch.zeros(
                batch_size,
                self.hidden_size,
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            )

        state = CobaLIFState(z=zeros(), v=zeros(), g_e=zeros(), g_i=zeros())
        # Only done for a freshly created (leaf) tensor: setting requires_grad
        # on a non-leaf v from a previous step would raise a RuntimeError.
        # Presumably this is so gradients w.r.t. the initial membrane
        # potential are tracked; kept from the original implementation.
        state.v.requires_grad = True
        return state

    def forward(
        self, input_tensor: torch.Tensor, state: Optional[CobaLIFState] = None
    ) -> Tuple[torch.Tensor, CobaLIFState]:
        """Advance the cell by one integration step.

        Parameters:
            input_tensor (torch.Tensor): Input for this step. Its first
                dimension is treated as the batch dimension (assumed shape
                ``(batch, input_size)`` — TODO confirm against
                ``coba_lif_step``).
            state (Optional[CobaLIFState]): Previous state. When ``None``, a
                zero state of shape ``(batch, hidden_size)`` is created on
                the same device and with the same dtype as ``input_tensor``.

        Returns:
            Tuple[torch.Tensor, CobaLIFState]: Whatever ``coba_lif_step``
            returns — presumably the output spikes and the updated state.
        """
        if state is None:
            state = self._initial_state(input_tensor)
        return coba_lif_step(
            input_tensor,
            state,
            self.input_weights,
            self.recurrent_weights,
            p=self.p,
            dt=self.dt,
        )