"""Tests for norse.torch.module.leaky_integrator (module norse.torch.module.test.test_leaky_integrator)."""

import torch

from norse.torch.module.leaky_integrator import LICell, LILinearCell, LIState


def test_li_linear_cell():
    """Run LILinearCell(2, 4) on a (5, 2) batch without passing a state.

    Every tensor in the returned state, and the output itself, must have
    shape (batch, hidden) == (5, 4).
    """
    batch, n_in, n_hidden = 5, 2, 4
    linear_cell = LILinearCell(n_in, n_hidden)
    inputs = torch.randn(batch, n_in)
    output, state = linear_cell(inputs)
    for tensor in state:
        assert tensor.shape == (batch, n_hidden)
    assert output.shape == (batch, n_hidden)
def test_li_linear_cell_state():
    """Run LILinearCell(2, 4) on a (5, 2) batch with an explicit all-ones LIState.

    The state fields passed in are (5, 4); the returned state tensors and the
    output must keep that (batch, hidden) shape.
    """
    batch, n_in, n_hidden = 5, 2, 4
    linear_cell = LILinearCell(n_in, n_hidden)
    inputs = torch.randn(batch, n_in)
    initial_state = LIState(
        torch.ones(batch, n_hidden),
        torch.ones(batch, n_hidden),
    )
    output, state = linear_cell(inputs, initial_state)
    for tensor in state:
        assert tensor.shape == (batch, n_hidden)
    assert output.shape == (batch, n_hidden)
def test_cell_backward():
    """Backpropagate a summed loss through LILinearCell(12, 1).

    Besides requiring that ``backward()`` does not raise, check that the
    gradient actually reaches the cell: every registered parameter must end
    up with a ``.grad`` of the same shape as the parameter.
    """
    model = LILinearCell(12, 1)
    data = torch.ones(100, 12)
    out, _ = model(data)
    loss = out.sum()
    loss.backward()

    params = list(model.parameters())
    # assumes LILinearCell registers its weights as nn.Parameters (it takes
    # input/hidden sizes, implying a weight matrix) — TODO confirm; without
    # this guard the loop below would pass vacuously.
    assert params
    for param in params:
        assert param.grad is not None
        assert param.grad.shape == param.shape
def test_li_cell():
    """A single LICell call on a (10, 2, 4) tensor returns an output of that shape."""
    shape = (10, 2, 4)
    li_cell = LICell()
    output, _state = li_cell(torch.randn(*shape))
    assert output.shape == shape
def test_li_cell_state():
    """Call LICell on a (2, 4) input with an explicit all-ones LIState.

    All returned state tensors and the output must stay (2, 4).
    """
    shape = (2, 4)
    li_cell = LICell()
    inputs = torch.randn(*shape)
    output, state = li_cell(inputs, LIState(torch.ones(*shape), torch.ones(*shape)))
    for tensor in state:
        assert tensor.shape == shape
    assert output.shape == shape
def test_li_backward():
    """Backpropagate a summed loss through LICell on a constant (10, 12, 1) input.

    The input does not require grad, so this only passes if the cell's own
    computation produces an output attached to the autograd graph.
    """
    li_cell = LICell()
    output, _state = li_cell(torch.ones(10, 12, 1))
    total = output.sum()
    total.backward()