Source code for norse.torch.module.lif

"""
A very popular neuron model that combines a :mod:`norse.torch.module.leaky_integrator` with
spike thresholds to produce events (spikes).

See :mod:`norse.torch.functional.lif` for more information.
"""
import torch

from norse.torch.functional.lif import (
    LIFState,
    LIFFeedForwardState,
    LIFParameters,
    lif_step,
    lif_step_sparse,
    lif_feed_forward_step,
    lif_feed_forward_step_sparse,
)
from norse.torch.module.snn import SNN, SNNCell, SNNRecurrent, SNNRecurrentCell


class LIFCell(SNNCell):
    """Module that computes a single Euler integration step of a
    leaky integrate-and-fire (LIF) neuron model *without* recurrence and
    *without* time.

    More specifically it implements one integration step of the following ODE

    .. math::
        \\begin{align*}
            \\dot{v} &= 1/\\tau_{\\text{mem}} (v_{\\text{leak}} - v + i) \\\\
            \\dot{i} &= -1/\\tau_{\\text{syn}} i
        \\end{align*}

    together with the jump condition

    .. math::
        z = \\Theta(v - v_{\\text{th}})

    and transition equations

    .. math::
        v = (1-z) v + z v_{\\text{reset}}

    Example:
        >>> data = torch.zeros(5, 2)  # 5 batches, 2 neurons
        >>> l = LIFCell()
        >>> l(data)  # Returns tuple of (Tensor(5, 2), LIFFeedForwardState)

    Arguments:
        p (LIFParameters): Parameters of the LIF neuron model.
        sparse (bool): Use sparse spike tensors and step functions.
            Defaults to False.
        dt (float): Time step to use. Defaults to 0.001.
    """

    def __init__(self, p: LIFParameters = LIFParameters(), sparse=False, **kwargs):
        if sparse:
            super().__init__(
                activation=lif_feed_forward_step_sparse,
                state_fallback=self.initial_state,
                p=p,
                **kwargs,
            )
        else:
            super().__init__(
                activation=lif_feed_forward_step,
                state_fallback=self.initial_state,
                p=p,
                **kwargs,
            )
    def initial_state(self, input_tensor: torch.Tensor) -> LIFFeedForwardState:
        state = LIFFeedForwardState(
            v=torch.full(
                input_tensor.shape,
                self.p.v_leak.detach(),
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
            i=torch.zeros(
                *input_tensor.shape,
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
        )
        state.v.requires_grad = True
        return state
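
# A minimal usage sketch (not part of the module's public API), assuming the
# SNNCell call convention ``cell(input, state)``: the caller drives the time
# loop and threads the returned state through by hand. All names and shapes
# below are illustrative.
def _example_lif_cell():
    cell = LIFCell()
    data = torch.ones(16, 5, 2)  # 16 timesteps, 5 batches, 2 neurons
    state = None  # the cell falls back to initial_state on the first call
    spikes = []
    for ts in range(data.shape[0]):
        z, state = cell(data[ts], state)
        spikes.append(z)
    return torch.stack(spikes)  # shape (16, 5, 2)
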
class LIFRecurrentCell(SNNRecurrentCell):
    """Module that computes a single Euler integration step of a
    leaky integrate-and-fire (LIF) neuron model *with* recurrence but
    *without* time.

    More specifically it implements one integration step of the following ODE

    .. math::
        \\begin{align*}
            \\dot{v} &= 1/\\tau_{\\text{mem}} (v_{\\text{leak}} - v + i) \\\\
            \\dot{i} &= -1/\\tau_{\\text{syn}} i
        \\end{align*}

    together with the jump condition

    .. math::
        z = \\Theta(v - v_{\\text{th}})

    and transition equations

    .. math::
        \\begin{align*}
            v &= (1-z) v + z v_{\\text{reset}} \\\\
            i &= i + w_{\\text{input}} z_{\\text{in}} \\\\
            i &= i + w_{\\text{rec}} z_{\\text{rec}}
        \\end{align*}

    where :math:`z_{\\text{rec}}` and :math:`z_{\\text{in}}` are the
    recurrent and input spikes respectively.

    Example:
        >>> data = torch.zeros(5, 2)  # 5 batches, 2 neurons
        >>> l = LIFRecurrentCell(2, 4)
        >>> l(data)  # Returns tuple of (Tensor(5, 4), LIFState)

    Parameters:
        input_size (int): Size of the input. Also known as the number of
            input features.
        hidden_size (int): Size of the hidden state. Also known as the
            number of hidden features.
        sparse (bool): Use sparse spike tensors and step functions.
            Defaults to False.
        p (LIFParameters): Parameters of the LIF neuron model.
        input_weights (torch.Tensor): Weights used for input tensors.
            Defaults to a random matrix normalized to the number of hidden
            neurons.
        recurrent_weights (torch.Tensor): Weights used for the recurrent
            connections. Defaults to a random matrix normalized to the
            number of hidden neurons.
        autapses (bool): Allow self-connections in the recurrence? Defaults
            to False. Will also remove autapses from custom recurrent
            weights, if set above.
        dt (float): Time step to use. Defaults to 0.001.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        sparse: bool = False,
        p: LIFParameters = LIFParameters(),
        **kwargs,
    ):
        if sparse:
            super().__init__(
                activation=lif_step_sparse,
                state_fallback=self.initial_state,
                p=p,
                input_size=input_size,
                hidden_size=hidden_size,
                **kwargs,
            )
        else:
            super().__init__(
                activation=lif_step,
                state_fallback=self.initial_state,
                p=p,
                input_size=input_size,
                hidden_size=hidden_size,
                **kwargs,
            )
        self.sparse = sparse
    def initial_state(self, input_tensor: torch.Tensor) -> LIFState:
        dims = (*input_tensor.shape[:-1], self.hidden_size)
        z = torch.zeros(
            *dims,
            device=input_tensor.device,
            dtype=input_tensor.dtype,
        )
        state = LIFState(
            # Only sparsify the spike tensor when sparse mode is requested
            z=z.to_sparse() if self.sparse else z,
            v=torch.full(
                dims,
                self.p.v_leak.detach(),
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
            i=torch.zeros(
                *dims,
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
        )
        state.v.requires_grad = True
        return state
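
# Illustrative sketch (not part of the module): the recurrent cell maps
# input_size features to hidden_size neurons, so output spikes and state
# live in the hidden dimension. Assumes the call convention
# ``cell(input, state)``.
def _example_lif_recurrent_cell():
    cell = LIFRecurrentCell(2, 4)
    data = torch.ones(5, 2)  # 5 batches, 2 input features
    z, state = cell(data)  # z: (5, 4); state: LIFState in the hidden dim
    z, state = cell(data, state)  # later steps reuse the evolved state
    return z, state
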
class LIF(SNN):
    """A neuron layer that wraps a :class:`LIFCell` in time such that the
    layer keeps track of temporal sequences of spikes. After application,
    the layer returns a tuple containing
    (spikes from all timesteps, state from the last timestep).

    Example:
        >>> data = torch.zeros(10, 5, 2)  # 10 timesteps, 5 batches, 2 neurons
        >>> l = LIF()
        >>> l(data)  # Returns tuple of (Tensor(10, 5, 2), LIFFeedForwardState)

    Parameters:
        p (LIFParameters): The neuron parameters as a torch Module, which
            allows the module to treat the neuron parameters as learnable.
        sparse (bool): Use sparse spike tensors and step functions.
            Defaults to False.
        dt (float): Time step to use in integration. Defaults to 0.001.
    """

    def __init__(self, p: LIFParameters = LIFParameters(), sparse=False, **kwargs):
        if sparse:
            super().__init__(
                activation=lif_feed_forward_step_sparse,
                state_fallback=self.initial_state,
                p=p,
                **kwargs,
            )
        else:
            super().__init__(
                activation=lif_feed_forward_step,
                state_fallback=self.initial_state,
                p=p,
                **kwargs,
            )
    def initial_state(self, input_tensor: torch.Tensor) -> LIFFeedForwardState:
        state = LIFFeedForwardState(
            v=torch.full(
                input_tensor.shape[1:],  # Assume first dimension is time
                self.p.v_leak.detach(),
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
            i=torch.zeros(
                *input_tensor.shape[1:],
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
        )
        state.v.requires_grad = True
        return state
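
# Sketch of the lifted module (not part of the original source): LIF consumes
# a whole (time, batch, features) tensor and returns spikes for every
# timestep plus the final state, so no manual time loop is needed. Shapes
# below are illustrative.
def _example_lif():
    layer = LIF()
    data = torch.ones(10, 5, 2)  # 10 timesteps, 5 batches, 2 neurons
    spikes, state = layer(data)  # spikes: (10, 5, 2); LIFFeedForwardState
    return spikes, state
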
class LIFRecurrent(SNNRecurrent):
    """A neuron layer that wraps a :class:`LIFRecurrentCell` in time such
    that the layer keeps track of temporal sequences of spikes. After
    application, the module returns a tuple containing
    (spikes from all timesteps, state from the last timestep).

    Example:
        >>> data = torch.zeros(10, 5, 2)  # 10 timesteps, 5 batches, 2 neurons
        >>> l = LIFRecurrent(2, 4)
        >>> l(data)  # Returns tuple of (Tensor(10, 5, 4), LIFState)

    Parameters:
        input_size (int): The number of input neurons
        hidden_size (int): The number of hidden neurons
        sparse (bool): Use sparse spike tensors and step functions.
            Defaults to False.
        p (LIFParameters): The neuron parameters as a torch Module, which
            allows the module to treat the neuron parameters as learnable.
        input_weights (torch.Tensor): Weights used for input tensors.
            Defaults to a random matrix normalized to the number of hidden
            neurons.
        recurrent_weights (torch.Tensor): Weights used for the recurrent
            connections. Defaults to a random matrix normalized to the
            number of hidden neurons.
        autapses (bool): Allow self-connections in the recurrence? Defaults
            to False. Will also remove autapses from custom recurrent
            weights, if set above.
        dt (float): Time step to use in integration. Defaults to 0.001.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        sparse: bool = False,
        p: LIFParameters = LIFParameters(),
        **kwargs,
    ):
        self.sparse = sparse
        if sparse:
            super().__init__(
                activation=lif_step_sparse,
                state_fallback=self.initial_state,
                input_size=input_size,
                hidden_size=hidden_size,
                p=p,
                **kwargs,
            )
        else:
            super().__init__(
                activation=lif_step,
                state_fallback=self.initial_state,
                input_size=input_size,
                hidden_size=hidden_size,
                p=p,
                **kwargs,
            )
    def initial_state(self, input_tensor: torch.Tensor) -> LIFState:
        dims = (
            # Remove first dimension (time)
            *input_tensor.shape[1:-1],
            self.hidden_size,
        )
        z = torch.zeros(
            *dims,
            device=input_tensor.device,
            dtype=input_tensor.dtype,
        )
        state = LIFState(
            # Only sparsify the spike tensor when sparse mode is requested
            z=z.to_sparse() if self.sparse else z,
            v=torch.full(
                dims,
                self.p.v_leak.detach(),
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
            i=torch.zeros(
                *dims,
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            ),
        )
        state.v.requires_grad = True
        return state
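
# Hypothetical end-to-end sketch (placeholder data and a toy objective, not
# a training recipe): norse's default surrogate gradient makes the spike
# nonlinearity differentiable, so a plain backward pass should work.
def _example_lif_recurrent_training_step():
    layer = LIFRecurrent(2, 4)
    data = torch.ones(10, 5, 2)  # 10 timesteps, 5 batches, 2 input features
    optimizer = torch.optim.Adam(layer.parameters())
    spikes, _ = layer(data)  # spikes: (10, 5, 4)
    loss = spikes.sum()  # toy objective: total spike count
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()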