# Source code for norse.torch.module.lif_adex

import torch

from norse.torch.functional.lif_adex import (
LIFAdExState,
LIFAdExFeedForwardState,
LIFAdExParameters,
lif_adex_step,
lif_adex_feed_forward_step,
)

from norse.torch.module.snn import SNN, SNNCell, SNNRecurrent, SNNRecurrentCell

class LIFAdExCell(SNNCell):
    r"""Computes a single euler-integration step of a feed-forward adaptive
    exponential LIF neuron-model *without* recurrence, adapted from
    http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model.
    It takes as input the input current as generated by an arbitrary torch
    module or function. More specifically it implements one integration step
    of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} \left(v_{\text{leak}} - v + i + \Delta_T \exp\left({{v - v_{\text{th}}} \over {\Delta_T}}\right)\right) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{a} &= 1/\tau_{\text{ada}} \left( a_{current} (v - v_{\text{leak}}) - a \right)
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        i = i + i_{\text{in}}

    where :math:`i_{\text{in}}` is meant to be the result of applying
    an arbitrary pytorch module (such as a convolution) to input spikes.

    Parameters:
        p (LIFAdExParameters): Parameters of the LIFAdEx neuron model.
        dt (float): Time step to use.

    Examples:
        >>> batch_size = 16
        >>> lif_ex = LIFAdExCell()
        >>> data = torch.randn(batch_size, 20, 30)
        >>> output, s0 = lif_ex(data)
    """

    def __init__(self, p: LIFAdExParameters = LIFAdExParameters(), **kwargs):
        super().__init__(
            lif_adex_feed_forward_step,
            self.initial_state,
            p=p,
            **kwargs,
        )

    def initial_state(self, x: torch.Tensor) -> LIFAdExFeedForwardState:
        """Build the state used when the cell is called without one.

        Parameters:
            x (torch.Tensor): Input current of the current step; only its
                shape, device and dtype are used.

        Returns:
            LIFAdExFeedForwardState: ``v`` filled with ``p.v_leak`` and
            ``i``/``a`` filled with zeros, all shaped like ``x`` and placed on
            ``x``'s device with ``x``'s dtype. ``v`` requires gradients.
        """
        # Materialise v_leak into a fresh tensor shaped like ``x`` (as the
        # other initial_state implementations do) instead of handing out a
        # 0-dim detached view that shares storage with the parameter.
        # ``expand`` accepts a 0-dim as well as a per-neuron v_leak.
        v_leak = torch.as_tensor(self.p.v_leak).detach()
        v_leak = v_leak.to(device=x.device, dtype=x.dtype)
        state = LIFAdExFeedForwardState(
            v=v_leak.expand(x.shape).clone(),
            i=torch.zeros(x.shape, device=x.device, dtype=x.dtype),
            a=torch.zeros(x.shape, device=x.device, dtype=x.dtype),
        )
        # clone() yields a leaf tensor, so requires_grad can be set on it.
        state.v.requires_grad = True
        return state

class LIFAdExRecurrentCell(SNNRecurrentCell):
    r"""Computes a single euler-integration step of an adaptive exponential
    LIF neuron-model *with* recurrence, adapted from
    http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model.
    More specifically it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} \left(v_{\text{leak}} - v + i + \Delta_T \exp\left({{v - v_{\text{th}}} \over {\Delta_T}}\right)\right) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{a} &= 1/\tau_{\text{ada}} \left( a_{current} (v - v_{\text{leak}}) - a \right)
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + w_{\text{input}} z_{\text{in}} \\
            i &= i + w_{\text{rec}} z_{\text{rec}}
        \end{align*}

    where :math:`z_{\text{rec}}` and :math:`z_{\text{in}}` are the
    recurrent and input spikes respectively.

    Examples:
        >>> batch_size = 16
        >>> lif = LIFAdExRecurrentCell(10, 20)
        >>> input = torch.randn(batch_size, 10)
        >>> output, s0 = lif(input)

    Parameters:
        input_size (int): Size of the input. Also known as the number of input features.
        hidden_size (int): Size of the hidden state. Also known as the number of hidden neurons.
        p (LIFAdExParameters): Parameters of the LIFAdEx neuron model.
        input_weights (torch.Tensor): Weights used for input tensors. Defaults to a random
            matrix normalized to the number of hidden neurons.
        recurrent_weights (torch.Tensor): Weights used for the recurrent spikes. Defaults to
            a random matrix normalized to the number of hidden neurons.
        autapses (bool): Allow self-connections in the recurrence? Defaults to False. Will also
            remove autapses in custom recurrent weights, if set above.
        dt (float): Time step to use.

    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        p: LIFAdExParameters = LIFAdExParameters(),
        **kwargs,
    ):
        super().__init__(
            activation=lif_adex_step,
            state_fallback=self.initial_state,
            p=p,
            input_size=input_size,
            hidden_size=hidden_size,
            **kwargs,
        )

    def initial_state(self, input_tensor: torch.Tensor) -> LIFAdExState:
        """Build the state used when the cell is called without one.

        Parameters:
            input_tensor (torch.Tensor): Input of the current step; all but its
                last dimension, plus its device and dtype, are used.

        Returns:
            LIFAdExState: tensors of shape
            ``(*input_tensor.shape[:-1], hidden_size)``; ``v`` is filled with
            ``p.v_leak``, ``z``/``i``/``a`` with zeros. ``v`` requires gradients.
        """
        dims = (*input_tensor.shape[:-1], self.hidden_size)
        like = {"device": input_tensor.device, "dtype": input_tensor.dtype}
        # expand() (unlike torch.full) also accepts a per-neuron v_leak;
        # clone() gives an independent leaf tensor.
        v_leak = torch.as_tensor(self.p.v_leak).detach().to(**like)
        state = LIFAdExState(
            z=torch.zeros(dims, **like),
            v=v_leak.expand(dims).clone(),
            i=torch.zeros(dims, **like),
            a=torch.zeros(dims, **like),
        )
        state.v.requires_grad = True
        return state

class LIFAdEx(SNN):
    r"""A neuron layer that wraps a feed-forward LIFAdExCell (*without*
    recurrence) in time such that the layer keeps track of temporal sequences
    of spikes.
    After application, the layer returns a tuple containing
    (spikes from all timesteps, state from the last timestep).

    Example:
        >>> data = torch.zeros(10, 5, 2) # 10 timesteps, 5 batches, 2 neurons
        >>> l = LIFAdEx()
        >>> l(data) # Returns tuple of (Tensor(10, 5, 2), LIFAdExFeedForwardState)

    Parameters:
        p (LIFAdExParameters): The neuron parameters as a torch Module, which allows the module
            to configure neuron parameters as optimizable.
        dt (float): Time step to use in integration. Defaults to 0.001.
    """

    def __init__(self, p: LIFAdExParameters = LIFAdExParameters(), **kwargs):
        super().__init__(
            activation=lif_adex_feed_forward_step,
            state_fallback=self.initial_state,
            p=p,
            **kwargs,
        )

    def initial_state(self, input_tensor: torch.Tensor) -> LIFAdExFeedForwardState:
        """Build the state used when the layer is called without one.

        Parameters:
            input_tensor (torch.Tensor): The full input sequence; the first
                dimension is assumed to be time and is dropped.

        Returns:
            LIFAdExFeedForwardState: tensors of shape ``input_tensor.shape[1:]``;
            ``v`` is filled with ``p.v_leak``, ``i``/``a`` with zeros.
            ``v`` requires gradients.
        """
        dims = input_tensor.shape[1:]  # Assume first dimension is time
        like = {"device": input_tensor.device, "dtype": input_tensor.dtype}
        # expand() (unlike torch.full) also accepts a per-neuron v_leak;
        # clone() gives an independent leaf tensor.
        v_leak = torch.as_tensor(self.p.v_leak).detach().to(**like)
        state = LIFAdExFeedForwardState(
            v=v_leak.expand(dims).clone(),
            i=torch.zeros(dims, **like),
            a=torch.zeros(dims, **like),
        )
        state.v.requires_grad = True
        return state

class LIFAdExRecurrent(SNNRecurrent):
    r"""A neuron layer that wraps a recurrent LIFAdExRecurrentCell in time (*with*
    recurrence) such that the layer keeps track of temporal sequences of spikes.
    After application, the layer returns a tuple containing
    (spikes from all timesteps, state from the last timestep).

    Example:
        >>> data = torch.zeros(10, 5, 2) # 10 timesteps, 5 batches, 2 neurons
        >>> l = LIFAdExRecurrent(2, 4)
        >>> l(data) # Returns tuple of (Tensor(10, 5, 4), LIFAdExState)

    Parameters:
        input_size (int): The number of input neurons
        hidden_size (int): The number of hidden neurons
        p (LIFAdExParameters): The neuron parameters as a torch Module, which allows the module
            to configure neuron parameters as optimizable.
        input_weights (torch.Tensor): Weights used for input tensors. Defaults to a random
            matrix normalized to the number of hidden neurons.
        recurrent_weights (torch.Tensor): Weights used for the recurrent spikes. Defaults to
            a random matrix normalized to the number of hidden neurons.
        autapses (bool): Allow self-connections in the recurrence? Defaults to False. Will also
            remove autapses in custom recurrent weights, if set above.
        dt (float): Time step to use in integration. Defaults to 0.001.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        p: LIFAdExParameters = LIFAdExParameters(),
        **kwargs,
    ):
        super().__init__(
            activation=lif_adex_step,
            state_fallback=self.initial_state,
            input_size=input_size,
            hidden_size=hidden_size,
            p=p,
            **kwargs,
        )

    def initial_state(self, input_tensor: torch.Tensor) -> LIFAdExState:
        """Build the state used when the layer is called without one.

        Parameters:
            input_tensor (torch.Tensor): The full input sequence; the first
                dimension is assumed to be time and the last to be features.

        Returns:
            LIFAdExState: tensors of shape
            ``(*input_tensor.shape[1:-1], hidden_size)``; ``v`` is filled with
            ``p.v_leak``, ``z``/``i``/``a`` with zeros. ``v`` requires gradients.
        """
        dims = (  # Remove first dimension (time)
            *input_tensor.shape[1:-1],
            self.hidden_size,
        )
        like = {"device": input_tensor.device, "dtype": input_tensor.dtype}
        # expand() (unlike torch.full) also accepts a per-neuron v_leak;
        # clone() gives an independent leaf tensor.
        v_leak = torch.as_tensor(self.p.v_leak).detach().to(**like)
        state = LIFAdExState(
            z=torch.zeros(dims, **like),
            v=v_leak.expand(dims).clone(),
            i=torch.zeros(dims, **like),
            a=torch.zeros(dims, **like),
        )
        state.v.requires_grad = True
        return state