# Source code for norse.torch.module.sequential

from typing import Any, Callable, List, NamedTuple, Optional, Union

import torch

from norse.torch.utils.state import _is_module_stateful


class SequentialState(torch.nn.Sequential):
    """
    A sequential model that works like PyTorch's ``Sequential`` with the
    addition that it handles neuron states.

    Arguments:
        args (*torch.nn.Module): A list of modules to sequentially apply in the forward pass
        return_hidden (bool): If ``True``, :meth:`forward` returns the list of the
            outputs of every module instead of only the output of the last module.
            Defaults to ``False``.

    Example:
        >>> import torch
        >>> import norse.torch as snn
        >>> data = torch.ones(1, 1, 16, 8, 4) # Single timestep, Single Batch, 16 channels
        >>> model = snn.SequentialState(
        ...     snn.Lift(torch.nn.Conv2d(16, 8, 3)), # (1, 1, 8, 6, 2)
        ...     torch.nn.Flatten(3),                 # (1, 1, 8, 12)
        ...     snn.LIFRecurrent(12, 6),             # (1, 1, 8, 6)
        ...     snn.LIFRecurrent(6, 1)               # (1, 1, 8, 1)
        ... )
        >>> model(data)

    Example with recurrent layers:
        >>> import torch
        >>> import norse.torch as snn
        >>> data = torch.ones(1, 16, 8, 4) # Single timestep
        >>> model = snn.SequentialState(
        ...     snn.Lift(torch.nn.Conv2d(16, 8, 3)), # (1, 8, 6, 2)
        ...     torch.nn.Flatten(2),                 # (1, 8, 12)
        ...     snn.LSNNRecurrent(12, 6),            # (1, 8, 6)
        ...     torch.nn.RNN(6, 4, 2),               # (1, 8, 4) with 2 recurrent layers
        ...     snn.LIFRecurrent(4, 1)               # (1, 8, 1)
        ... )
        >>> model(data)
    """

    def __init__(self, *args: torch.nn.Module, return_hidden: bool = False):
        super().__init__()
        # One flag per child, in argument order. A True flag means forward()
        # calls the child as ``module(input, state)`` and unpacks (output, state).
        self.stateful_layers = []
        # Handles returned by ``register_forward_hook``, kept so the hooks can
        # later be removed by ``remove_forward_state_hooks``.
        self.forward_state_hooks = []
        self.return_hidden = return_hidden
        for idx, module in enumerate(args):
            self.add_module(str(idx), module)
            # Identify all the stateful layers
            self.stateful_layers.append(_is_module_stateful(module))

    def register_forward_state_hooks(
        self,
        forward_hook: Callable[[torch.nn.Module, torch.Tensor, torch.Tensor], None],
    ):
        """
        Registers a forward hook on every state*ful* child module.
        Hooks can be removed by calling :meth:`remove_forward_state_hooks`.

        Arguments:
            forward_hook (Callable): The hook applied to all stateful children every
                time they produce an output. It is registered with
                :meth:`torch.nn.Module.register_forward_hook` and is therefore
                called as ``forward_hook(module, input, output)``.

        Raises:
            ValueError: If forward state hooks are already registered.

        Example:
            >>> import norse.torch as snn
            >>> def my_hook(module, input, output):
            ...     ...
            >>> module = snn.SequentialState(...)
            >>> module.register_forward_state_hooks(my_hook)
            >>> module(...)
        """
        if self.forward_state_hooks:
            raise ValueError("Forward state hooks already in place")
        # Children are named ``str(index)`` by __init__, so the name indexes
        # directly into ``stateful_layers``.
        for name, module in self.named_children():
            if self.stateful_layers[int(name)]:
                handle = module.register_forward_hook(forward_hook)
                self.forward_state_hooks.append(handle)

    def remove_forward_state_hooks(self):
        """
        Disables and discards the forward state hooks registered in
        :meth:`register_forward_state_hooks`.
        """
        for handle in self.forward_state_hooks:
            handle.remove()
        self.forward_state_hooks.clear()

    def forward(self, input_tensor: torch.Tensor, state: Optional[list] = None):
        """
        Feeds the input through the modules with the given list of states.

        Parameters:
            input_tensor: The input tensor to feed into the first module
            state: Either a sequence with one state per module, or None.
                If None, the modules will initialise their own default state.
                The given sequence is not modified; an updated copy is returned.

        Returns:
            A tuple of (output, state list). The output is the output of the last
            module, or the list of every module's output if ``return_hidden`` is
            set. Entries of the state list that belong to stateless modules are
            passed through unchanged.
        """
        # Copy, so the caller's state (e.g. from a previous timestep) is not
        # overwritten in place; this also accepts tuples.
        state = [None] * len(self) if state is None else list(state)
        hidden = []
        for index, module in enumerate(self):
            if self.stateful_layers[index]:
                input_tensor, state[index] = module(input_tensor, state[index])
            else:
                input_tensor = module(input_tensor)
            if self.return_hidden:
                hidden.append(input_tensor)
        return (hidden if self.return_hidden else input_tensor), state
class RecurrentSequentialState(NamedTuple):
    """State carried by :class:`RecurrentSequential` between timesteps.

    Attributes:
        cache: The recurrent output(s) of the previous timestep, which are summed
            into the next input. ``None`` before the first timestep.
        state: The state list of the wrapped ``SequentialState`` module.
    """

    cache: Optional[Any] = None
    state: Optional[Any] = None


class RecurrentSequential(torch.nn.Module):
    """A sequential module that feeds the output of the underlying modules back
    as input in the following timestep.

    The modules are wrapped in a ``SequentialState`` with ``return_hidden=True``.
    On every call that receives a previous state, the cached output(s) are summed
    element-wise with the new input before it enters the first module. They are
    combined with ``torch.stack``, so they must have the same shape as the input.

    Arguments:
        modules (*torch.nn.Module): The modules to apply sequentially.
        output_modules (int or List[int]): Index, or list of indices, of the modules
            whose outputs are returned and fed back. Defaults to ``-1``
            (the last module).
    """

    def __init__(
        self, *modules: torch.nn.Module, output_modules: Union[int, List[int]] = -1
    ):
        super().__init__()
        self.module = SequentialState(*modules, return_hidden=True)
        self.output_modules = output_modules

    def forward(
        self, x: torch.Tensor, state: Optional[RecurrentSequentialState] = None
    ):
        """Applies the modules to ``x`` and caches the selected outputs.

        Arguments:
            x: The input tensor for this timestep.
            state: The state returned by the previous call, or None to start
                without any feedback.

        Returns:
            A tuple of (output, RecurrentSequentialState). The output is a single
            tensor if ``output_modules`` is an int, otherwise a list of tensors.
        """
        if state is None:
            state = RecurrentSequentialState()
        elif state.cache is not None:
            cache = state.cache
            # With a list of output indices the cache is a list of tensors;
            # feed every one of them back, not just a single tensor.
            cached = cache if isinstance(cache, (list, tuple)) else (cache,)
            x = torch.stack((x, *cached)).sum(0)
        outputs, out_state = self.module(x, state.state)
        if isinstance(self.output_modules, int):
            recurrent = outputs[self.output_modules]
        else:
            recurrent = [outputs[i] for i in self.output_modules]
        return recurrent, RecurrentSequentialState(recurrent, out_state)