# Source code for norse.torch.module.test.test_regularization

import torch

from norse.torch.module.lif import LIFCell
from norse.torch.module.regularization import RegularizationCell


def test_regularization_module():
    """Check RegularizationCell's output over two LIFCell time steps.

    A LIFCell is driven by a constant input for two consecutive steps. After
    each step, its spikes and state are passed to a RegularizationCell built
    with default arguments. The regularizer returns spikes of the input shape
    and a value of 0 after the first step. After the second step it returns
    10, and its ``state`` attribute is also 10.
    """
    neuron = LIFCell()
    regularizer = RegularizationCell()  # Defaults to spike counting
    # Every entry equals 11 (same values as ones + 10); batch size of 5.
    inputs = torch.full((5, 2), 11.0)

    # Step 1: the LIFCell starts from its own initial state.
    spikes, neuron_state = neuron(inputs)
    spikes, reg_value = regularizer(spikes, neuron_state)
    assert spikes.shape == (5, 2)
    assert reg_value == 0

    # Step 2: continue from the state produced by step 1.
    spikes, neuron_state = neuron(inputs, neuron_state)
    spikes, reg_value = regularizer(spikes, neuron_state)
    assert reg_value == 10
    assert regularizer.state == 10