# Source code for norse.torch.module.lif_adex

import torch

from norse.torch.functional.lif_adex import (
    LIFAdExState,
    LIFAdExFeedForwardState,
    LIFAdExParameters,
    lif_adex_step,
    lif_adex_feed_forward_step,
)

from norse.torch.module.snn import SNN, SNNCell, SNNRecurrent, SNNRecurrentCell


class LIFAdExCell(SNNCell):
    r"""Computes a single euler-integration step of a feed-forward adaptive
    exponential LIF neuron-model *without* recurrence, adapted from
    http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model.
    It takes as input the input current as generated by an arbitrary torch
    module or function. More specifically it implements one integration step
    of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} \left(v_{\text{leak}} - v + i + \Delta_T exp\left({{v - v_{\text{th}}} \over {\Delta_T}}\right)\right) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{a} &= 1/\tau_{\text{ada}} \left( a_{current} (v - v_{\text{leak}}) - a \right)
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        i = i + i_{\text{in}}

    where :math:`i_{\text{in}}` is meant to be the result of applying an
    arbitrary pytorch module (such as a convolution) to input spikes.

    Parameters:
        p (LIFAdExParameters): Parameters of the LIFAdEx neuron model.
        dt (float): Time step to use.

    Examples:
        >>> batch_size = 16
        >>> lif_ex = LIFAdExCell()
        >>> data = torch.randn(batch_size, 20, 30)
        >>> output, s0 = lif_ex(data)
    """

    def __init__(self, p: LIFAdExParameters = LIFAdExParameters(), **kwargs):
        super().__init__(
            lif_adex_feed_forward_step,
            self.initial_state,
            p=p,
            **kwargs,
        )

    def initial_state(self, x: torch.Tensor) -> LIFAdExFeedForwardState:
        """Build the state used when no explicit state is supplied.

        All state tensors have the shape, device and dtype of ``x``. The
        membrane voltage starts at ``p.v_leak`` (broadcast over ``x``'s
        shape). The synaptic current and the adaptation variable start at
        zero.

        Parameters:
            x (torch.Tensor): The input tensor of the current step.

        Returns:
            LIFAdExFeedForwardState: The initial neuron state.
        """
        # ``as_tensor`` accepts both plain floats and tensors for v_leak.
        # ``detach`` treats the start value as a constant, with no gradient
        # path back into p.v_leak.
        v_leak = torch.as_tensor(
            self.p.v_leak, device=x.device, dtype=x.dtype
        ).detach()
        state = LIFAdExFeedForwardState(
            # Materialize v with the input's shape, like i and a, instead of
            # handing out the (typically 0-dim) parameter tensor itself.
            # ``clone`` turns the broadcast view into its own storage.
            v=v_leak.expand(x.shape).clone(),
            i=torch.zeros(
                *x.shape,
                device=x.device,
                dtype=x.dtype,
            ),
            a=torch.zeros(
                *x.shape,
                device=x.device,
                dtype=x.dtype,
            ),
        )
        state.v.requires_grad = True
        return state
class LIFAdExRecurrentCell(SNNRecurrentCell):
    r"""Computes a single euler-integration step of an adaptive exponential
    LIF neuron-model *with* recurrence, adapted from
    http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model.
    More specifically it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} \left(v_{\text{leak}} - v + i + \Delta_T exp\left({{v - v_{\text{th}}} \over {\Delta_T}}\right)\right) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{a} &= 1/\tau_{\text{ada}} \left( a_{current} (v - v_{\text{leak}}) - a \right)
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + w_{\text{input}} z_{\text{in}} \\
            i &= i + w_{\text{rec}} z_{\text{rec}}
        \end{align*}

    where :math:`z_{\text{rec}}` and :math:`z_{\text{in}}` are the recurrent
    and input spikes respectively.

    Examples:
        >>> batch_size = 16
        >>> lif = LIFAdExRecurrentCell(10, 20)
        >>> input = torch.randn(batch_size, 10)
        >>> output, s0 = lif(input)

    Parameters:
        input_size (int): Size of the input. Also known as the number of input features.
        hidden_size (int): Size of the hidden state. Also known as the number of hidden
            (output) neurons.
        p (LIFAdExParameters): Parameters of the LIFAdEx neuron model.
        input_weights (torch.Tensor): Weights used for input tensors. Defaults to a random
            matrix normalized to the number of hidden neurons.
        recurrent_weights (torch.Tensor): Weights used for the recurrent connections.
            Defaults to a random matrix normalized to the number of hidden neurons.
        autapses (bool): Allow self-connections in the recurrence? Defaults to False. Will
            also remove autapses in custom recurrent weights, if set above.
        dt (float): Time step to use.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        p: LIFAdExParameters = LIFAdExParameters(),
        **kwargs,
    ):
        super().__init__(
            activation=lif_adex_step,
            state_fallback=self.initial_state,
            p=p,
            input_size=input_size,
            hidden_size=hidden_size,
            **kwargs,
        )

    def initial_state(self, input_tensor: torch.Tensor) -> LIFAdExState:
        """Build the state used when no explicit state is supplied.

        The state tensors have shape ``(*input_tensor.shape[:-1], hidden_size)``
        and the input's device and dtype. ``z``, ``i`` and ``a`` start at zero,
        and ``v`` starts at ``p.v_leak`` (broadcast over that shape).

        Parameters:
            input_tensor (torch.Tensor): The input tensor of the current step.

        Returns:
            LIFAdExState: The initial neuron state.
        """
        # Replace the trailing input-feature dimension by the hidden size.
        dims = (*input_tensor.shape[:-1], self.hidden_size)
        # ``as_tensor`` accepts both plain floats and tensors for v_leak.
        # ``expand`` also works for per-neuron (broadcastable) values, where
        # ``torch.full`` only accepts a single scalar.
        v_leak = torch.as_tensor(
            self.p.v_leak, device=input_tensor.device, dtype=input_tensor.dtype
        ).detach()
        state = LIFAdExState(
            z=torch.zeros(
                *dims,
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
            v=v_leak.expand(dims).clone(),
            i=torch.zeros(
                *dims,
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
            a=torch.zeros(
                *dims,
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
        )
        state.v.requires_grad = True
        return state
class LIFAdEx(SNN):
    r"""A neuron layer that wraps a feed-forward :class:`LIFAdExCell` in time
    such that the layer keeps track of temporal sequences of spikes. After
    application, the layer returns a tuple containing (spikes from all
    timesteps, state from the last timestep).

    Example:
        >>> data = torch.zeros(10, 5, 2) # 10 timesteps, 5 batches, 2 neurons
        >>> l = LIFAdEx()
        >>> l(data) # Returns tuple of (Tensor(10, 5, 2), LIFAdExFeedForwardState)

    Parameters:
        p (LIFAdExParameters): The neuron parameters as a torch Module, which allows the
            module to configure neuron parameters as optimizable.
        dt (float): Time step to use in integration. Defaults to 0.001.
    """

    def __init__(self, p: LIFAdExParameters = LIFAdExParameters(), **kwargs):
        super().__init__(
            activation=lif_adex_feed_forward_step,
            state_fallback=self.initial_state,
            p=p,
            **kwargs,
        )

    def initial_state(self, input_tensor: torch.Tensor) -> LIFAdExFeedForwardState:
        """Build the state used when no explicit state is supplied.

        The state tensors have shape ``input_tensor.shape[1:]``, which drops
        the leading time dimension. They also take the input's device and
        dtype. ``v`` starts at ``p.v_leak`` (broadcast over that shape), and
        ``i`` and ``a`` start at zero.

        Parameters:
            input_tensor (torch.Tensor): The full input sequence, time first.

        Returns:
            LIFAdExFeedForwardState: The initial neuron state.
        """
        dims = input_tensor.shape[1:]  # Assume first dimension is time
        # ``as_tensor`` accepts both plain floats and tensors for v_leak.
        # ``expand`` also works for per-neuron (broadcastable) values.
        v_leak = torch.as_tensor(
            self.p.v_leak, device=input_tensor.device, dtype=input_tensor.dtype
        ).detach()
        state = LIFAdExFeedForwardState(
            v=v_leak.expand(dims).clone(),
            i=torch.zeros(
                dims,
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
            a=torch.zeros(
                dims,
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
        )
        state.v.requires_grad = True
        return state
class LIFAdExRecurrent(SNNRecurrent):
    r"""A neuron layer that wraps a :class:`LIFAdExRecurrentCell` in time
    (*with* recurrence) such that the layer keeps track of temporal sequences
    of spikes. After application, the layer returns a tuple containing
    (spikes from all timesteps, state from the last timestep).

    Example:
        >>> data = torch.zeros(10, 5, 2) # 10 timesteps, 5 batches, 2 neurons
        >>> l = LIFAdExRecurrent(2, 4)
        >>> l(data) # Returns tuple of (Tensor(10, 5, 4), LIFAdExState)

    Parameters:
        input_size (int): The number of input neurons
        hidden_size (int): The number of hidden neurons
        p (LIFAdExParameters): The neuron parameters as a torch Module, which allows the
            module to configure neuron parameters as optimizable.
        input_weights (torch.Tensor): Weights used for input tensors. Defaults to a random
            matrix normalized to the number of hidden neurons.
        recurrent_weights (torch.Tensor): Weights used for the recurrent connections.
            Defaults to a random matrix normalized to the number of hidden neurons.
        autapses (bool): Allow self-connections in the recurrence? Defaults to False. Will
            also remove autapses in custom recurrent weights, if set above.
        dt (float): Time step to use in integration. Defaults to 0.001.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        p: LIFAdExParameters = LIFAdExParameters(),
        **kwargs,
    ):
        super().__init__(
            activation=lif_adex_step,
            state_fallback=self.initial_state,
            input_size=input_size,
            hidden_size=hidden_size,
            p=p,
            **kwargs,
        )

    def initial_state(self, input_tensor: torch.Tensor) -> LIFAdExState:
        """Build the state used when no explicit state is supplied.

        The state tensors have shape
        ``(*input_tensor.shape[1:-1], hidden_size)``. They also take the
        input's device and dtype. ``z``, ``i`` and ``a`` start at zero, and
        ``v`` starts at ``p.v_leak`` (broadcast over that shape).

        Parameters:
            input_tensor (torch.Tensor): The full input sequence, time first.

        Returns:
            LIFAdExState: The initial neuron state.
        """
        dims = (  # Remove first dimension (time)
            *input_tensor.shape[1:-1],
            self.hidden_size,
        )
        # ``as_tensor`` accepts both plain floats and tensors for v_leak.
        # ``expand`` also works for per-neuron (broadcastable) values, where
        # ``torch.full`` only accepts a single scalar.
        v_leak = torch.as_tensor(
            self.p.v_leak, device=input_tensor.device, dtype=input_tensor.dtype
        ).detach()
        state = LIFAdExState(
            z=torch.zeros(
                *dims,
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
            v=v_leak.expand(dims).clone(),
            i=torch.zeros(
                *dims,
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
            a=torch.zeros(
                *dims,
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
        )
        state.v.requires_grad = True
        return state