norse.torch.module.regularization module

This module contains ``torch.nn.Module`` subclasses for regularisation operations on spiking layers, where it can be desirable to regularise spikes, membrane parameters, or other properties accumulated over time.

class norse.torch.module.regularization.RegularizationCell(accumulator=<function spike_accumulator>, state=None)

Bases: torch.nn.modules.module.Module

A regularisation cell that accumulates some state (for instance, the number of spikes) at each forward step, which can later be added to a loss term.

Example

>>> import torch
>>> from norse.torch.module import lif, regularization
>>> cell = lif.LIFCell(2, 4) # 2 -> 4
>>> r = regularization.RegularizationCell() # Defaults to spike counting
>>> data = torch.ones(5, 2)  # Batch size of 5
>>> z, s = cell(data)
>>> z, regularization_term = r(z, s)
>>> ...
>>> loss = ... + 1e-3 * regularization_term
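The example above runs a single step. Assuming the cell keeps its accumulated state between calls, as the description suggests, the regularization term grows as the cell is called once per time step. The sketch below illustrates that pattern with a hypothetical stand-in, `SpikeCountRegularizer`, that mimics the documented call convention (`z, term = r(z, s)`) with a spike-count accumulator. It does not import Norse, and the random binary tensors are placeholders for the output of a spiking layer.

```python
import torch


class SpikeCountRegularizer(torch.nn.Module):
    """Hypothetical stand-in for RegularizationCell with spike counting."""

    def __init__(self):
        super().__init__()
        self.state = None  # running spike count; None until the first call

    def forward(self, z, s):
        # Add this step's spike count to the running total; s is unused here.
        step_count = z.sum()
        self.state = step_count if self.state is None else self.state + step_count
        return z, self.state


torch.manual_seed(0)
regularizer = SpikeCountRegularizer()
timesteps, batch, neurons = 10, 5, 4
for t in range(timesteps):
    # Random binary spikes stand in for the output of a spiking layer.
    z = (torch.rand(batch, neurons) < 0.2).float()
    z, regularization_term = regularizer(z, None)

task_loss = torch.tensor(0.0)  # placeholder for the task-specific loss
loss = task_loss + 1e-3 * regularization_term
print(regularization_term.item(), loss.item())
```

Because the term in this sketch is a tensor built from `z`, gradients can flow through it back to whatever produced `z`. To start counting afresh for a new sequence, one would create a new instance or set `state` back to `None`.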

Parameters
  • accumulator (Accumulator) – The accumulator that aggregates some data (such as spikes) that can later be included in an error term.

  • state (Optional[T]) – The initial regularization state, of any type T, that the accumulator aggregates into. Defaults to None.
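The parameters above do not spell out the accumulator's signature. Assuming it is a callable that takes the spikes `z`, the neuron state `s`, and the previous regularization state, and returns the updated state (an assumption that mirrors how `forward(z, s)` is called), a hypothetical accumulator that penalises membrane voltage instead of spikes could look like this. `LIFStateStub` and its `v` field are placeholders for whatever state object the spiking layer actually returns.

```python
from typing import NamedTuple, Optional

import torch


class LIFStateStub(NamedTuple):
    """Hypothetical neuron state holding a membrane voltage v."""
    v: torch.Tensor


def squared_voltage_accumulator(
    z: torch.Tensor, s: LIFStateStub, state: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Hypothetical accumulator: adds this step's summed squared voltage."""
    step_term = (s.v ** 2).sum()
    return step_term if state is None else state + step_term


state = None
for _ in range(3):
    z = torch.zeros(2, 4)  # spikes are ignored by this accumulator
    s = LIFStateStub(v=torch.full((2, 4), 0.5))
    state = squared_voltage_accumulator(z, s, state)
print(state.item())  # 3 steps * 8 neurons * 0.25 = 6.0
```

If the assumed signature matches, such a function would be passed as `RegularizationCell(accumulator=squared_voltage_accumulator)`.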


forward(z, s)

Applies the accumulator to the spikes z and the neuron state s, and returns a tuple of z and the accumulated regularization term, as in the example above.

Note

Although the forward pass is defined in this method, call the module instance (for example r(z, s)) rather than r.forward(z, s): calling the instance runs any registered hooks, whereas calling forward directly silently skips them.
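The behaviour described in the note can be checked with plain PyTorch: a hook registered with `register_forward_hook` fires when the module instance is called, but not when `forward` is called directly. `PassThrough` is a hypothetical two-argument module standing in for a regularization cell.

```python
import torch


class PassThrough(torch.nn.Module):
    """Hypothetical two-argument module that returns its inputs unchanged."""

    def forward(self, z, s):
        return z, s


calls = []
module = PassThrough()
# The hook records the number of positional inputs each time it runs;
# it returns None, so the module output is left unchanged.
module.register_forward_hook(lambda mod, inputs, output: calls.append(len(inputs)))

z = torch.ones(5, 4)
module(z, None)          # runs forward and then the hook
module.forward(z, None)  # runs forward only; the hook is skipped
print(calls)  # [2]
```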

training: bool