Source code for norse.torch.module.test.test_lif_adex

import torch

from norse.torch.module.lif_adex import (
    LIFAdExCell,
    LIFAdExRecurrentCell,
    LIFAdEx,
    LIFAdExRecurrent,
    LIFAdExState,
    LIFAdExFeedForwardState,
)


[docs]def test_lif_adex_cell(): cell = LIFAdExCell() data = torch.randn(5, 2) out, _ = cell(data) assert out.shape == (5, 2)
[docs]def test_lif_adex_cell_autapses(): cell = LIFAdExRecurrentCell(2, 2, autapses=True) assert not torch.allclose( torch.zeros(2), (cell.recurrent_weights * torch.eye(*cell.recurrent_weights.shape)).sum(0), ) s1 = LIFAdExState( z=torch.ones(1, 2), v=torch.zeros(1, 2), i=torch.zeros(1, 2), a=torch.zeros(1, 2), ) z, s_full = cell(torch.zeros(1, 2), s1) s2 = LIFAdExState( z=torch.tensor([[0, 1]], dtype=torch.float32), v=torch.zeros(1, 2), i=torch.zeros(1, 2), a=torch.zeros(1, 2), ) z, s_part = cell(torch.zeros(1, 2), s2) assert not s_full.i[0, 0] == s_part.i[0, 0]
[docs]def test_lif_adex_cell_no_autapses(): cell = LIFAdExRecurrentCell(2, 2, autapses=False) assert ( cell.recurrent_weights * torch.eye(*cell.recurrent_weights.shape) ).sum() == 0 s1 = LIFAdExState( z=torch.ones(1, 2), v=torch.zeros(1, 2), i=torch.zeros(1, 2), a=torch.zeros(1, 2), ) z, s_full = cell(torch.zeros(1, 2), s1) s2 = LIFAdExState( z=torch.tensor([[0, 1]], dtype=torch.float32), v=torch.zeros(1, 2), i=torch.zeros(1, 2), a=torch.zeros(1, 2), ) z, s_part = cell(torch.zeros(1, 2), s2) assert s_full.i[0, 0] == s_part.i[0, 0]
[docs]def test_lif_adex_feedforward_cell_state(): cell = LIFAdExCell() input_tensor = torch.randn(5, 2, 4) state = LIFAdExFeedForwardState( v=cell.p.v_leak, i=torch.zeros( input_tensor.shape, ), a=torch.zeros( input_tensor.shape, ), ) out, _ = cell(input_tensor, state) assert out.shape == (5, 2, 4)
[docs]def test_lif_adex_cell_backward(): cell = LIFAdExCell() data = torch.randn(5, 2) out, _ = cell(data) out.sum().backward()
[docs]def test_lif_adex_recurrent_cell_backward(): cell = LIFAdExRecurrentCell(2, 4) data = torch.randn(5, 2) out, _ = cell(data) out.sum().backward()
[docs]def test_lif_adex(): layer = LIFAdEx() data = torch.randn(10, 5, 4) out, _ = layer(data) assert out.shape == (10, 5, 4)
[docs]def test_lif_adex_recurrent(): layer = LIFAdExRecurrent(2, 4) data = torch.randn(2, 2) out, _ = layer(data) assert out.shape == (2, 4)
[docs]def test_lif_adex_backward(): cell = LIFAdEx() data = torch.randn(5, 2, 4) out, _ = cell(data) out.sum().backward()
[docs]def test_lif_adex_recurrent_backward(): cell = LIFAdExRecurrentCell(2, 4) data = torch.randn(5, 2, 2) out, _ = cell(data) out.sum().backward()