norse.torch.module.sequential module

class norse.torch.module.sequential.SequentialState(*args)[source]

Bases: torch.nn.modules.container.Sequential

A sequential model that works like PyTorch’s Sequential with the addition that it handles neuron states.

Parameters

args (*torch.nn.Module) – A list of modules to sequentially apply in the forward pass

Example

>>> import torch
>>> import norse.torch as snn
>>> data = torch.ones(1, 16, 8, 4)         # Single timestep
>>> model = snn.SequentialState(
>>>   snn.Lift(torch.nn.Conv2d(16, 8, 3)), # (1, 8, 6, 2)
>>>   torch.nn.Flatten(2),                 # (1, 8, 12)
>>>   snn.LIFRecurrent(12, 6),             # (1, 8, 6)
>>>   snn.LIFRecurrent(6, 1)               # (1, 8, 1)
>>> )
>>> model(data)

Example with recurrent layers:

>>> import torch
>>> import norse.torch as snn
>>> data = torch.ones(1, 16, 8, 4)         # Single timestep
>>> model = snn.SequentialState(
>>>   snn.Lift(torch.nn.Conv2d(16, 8, 3)), # (1, 8, 6, 2)
>>>   torch.nn.Flatten(2),                 # (1, 8, 12)
>>>   snn.LSNNRecurrent(12, 6),            # (1, 8, 6)
>>>   torch.nn.RNN(6, 4, 2),               # (1, 8, 4) with 2 recurrent layers
>>>   snn.LIFRecurrent(4, 1)               # (1, 8, 1)
>>> )
>>> model(data)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(input_tensor, state=None)[source]

Feeds the input to the modules with the given state-list. If the state is None, the initial state is set to None for each of the modules.

Parameters
  • input_tensor (Tensor) – The input tensor to feed into the first module

  • state (Optional[list]) – Either a list of states for each module or None. If None, the modules will initialise their own default state

Returns

A tuple of (output tensor, state list)
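The state-threading behaviour described above can be illustrated with a minimal pure-Python sketch. The classes below are hypothetical stand-ins for stateless and stateful modules (they are not norse classes); the sketch only shows the pattern of keeping one state entry per child and passing the state list back in on the next timestep:

```python
class Stateless:
    """Stand-in for a stateless module, e.g. torch.nn.Flatten."""

    def __call__(self, x):
        return x + 1


class Stateful:
    """Stand-in for a stateful module whose forward returns (output, new_state)."""

    def __call__(self, x, state=None):
        # A None state means "use the module's default initial state"
        state = 0 if state is None else state
        return x * 2, state + 1


def sequential_state_forward(modules, x, state=None):
    # With no state given, every module starts from its default (None)
    if state is None:
        state = [None] * len(modules)
    for i, module in enumerate(modules):
        if isinstance(module, Stateful):
            x, state[i] = module(x, state[i])  # thread state through stateful layers
        else:
            x = module(x)  # stateless layers keep a None entry in the state list
    return x, state


modules = [Stateless(), Stateful()]
out, state = sequential_state_forward(modules, 1.0)         # first timestep
out, state = sequential_state_forward(modules, 1.0, state)  # reuse state on the next timestep
```

Feeding the returned state list back into the next call is what lets the stateful layers evolve across timesteps while stateless layers are unaffected.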

register_forward_state_hooks(forward_hook)[source]

Registers hooks for all stateful layers.

Hooks can be removed by calling :meth:`remove_forward_state_hooks`.

Parameters
  • child_hook (Callable) – The hook applied to all children every time they produce an output

  • pre_hook (Optional[Callable]) – An optional hook for the SequentialState module, executed before the input is propagated to the children.

Example

>>> import norse.torch as snn
>>> def my_hook(module, input, output):
>>>     ...
>>> module = snn.SequentialState(...)
>>> module.register_forward_state_hooks(my_hook)
>>> module(...)
remove_forward_state_hooks()[source]

Disables the forward state hooks registered in :meth:`register_forward_state_hooks`.

training: bool