Source code for norse.torch.module.test.test_lif_refrac

import torch

from norse.torch.functional.lif import LIFState, LIFFeedForwardState
from norse.torch.functional.lif_refrac import LIFRefracState, LIFRefracFeedForwardState
from norse.torch.module.lif_refrac import LIFRefracCell, LIFRefracRecurrentCell


def test_lif_refrac_cell():
    cell = LIFRefracRecurrentCell(2, 4)
    data = torch.randn(5, 2)
    out, s = cell(data)

    assert s.rho.shape == (5, 4)
    assert s.lif.v.shape == (5, 4)
    assert s.lif.i.shape == (5, 4)
    assert s.lif.z.shape == (5, 4)
    assert out.shape == (5, 4)


def test_lif_refrac_cell_state():
    cell = LIFRefracRecurrentCell(2, 4)
    input_tensor = torch.randn(5, 2)

    state = LIFRefracState(
        lif=LIFState(
            z=torch.zeros(
                input_tensor.shape[0],
                cell.hidden_size,
            ),
            v=cell.p.lif.v_leak
            * torch.ones(
                input_tensor.shape[0],
                cell.hidden_size,
            ),
            i=torch.zeros(
                input_tensor.shape[0],
                cell.hidden_size,
            ),
        ),
        rho=torch.zeros(
            input_tensor.shape[0],
            cell.hidden_size,
        ),
    )

    out, s = cell(input_tensor, state)

    assert s.rho.shape == (5, 4)
    assert s.lif.v.shape == (5, 4)
    assert s.lif.i.shape == (5, 4)
    assert s.lif.z.shape == (5, 4)
    assert out.shape == (5, 4)


def test_lif_refrac_cell_backward():
    cell = LIFRefracRecurrentCell(2, 4)
    data = torch.randn(5, 2)
    out, _ = cell(data)
    out.sum().backward()


def test_lif_refrac_feedforward():
    batch_size = 16
    cell = LIFRefracCell()
    x = torch.randn(batch_size, 20, 30)
    out, s = cell(x)

    assert out.shape == (batch_size, 20, 30)
    assert s.lif.v.shape == (batch_size, 20, 30)
    assert s.lif.i.shape == (batch_size, 20, 30)
    assert s.rho.shape == (batch_size, 20, 30)


def test_lif_refrac_feedforward_cell():
    batch_size = 16
    cell = LIFRefracCell()
    input_tensor = torch.randn(batch_size, 20, 30)

    state = LIFRefracFeedForwardState(
        LIFFeedForwardState(
            v=cell.p.lif.v_leak,
            i=torch.zeros(
                input_tensor.shape,
            ),
        ),
        rho=torch.zeros(
            input_tensor.shape,
        ),
    )

    out, s = cell(input_tensor, state)

    assert out.shape == (batch_size, 20, 30)
    assert s.lif.v.shape == (batch_size, 20, 30)
    assert s.lif.i.shape == (batch_size, 20, 30)
    assert s.rho.shape == (batch_size, 20, 30)
def test_lif_refrac_cell_autapses():
    # With autapses enabled, the diagonal of the recurrent weight matrix is
    # non-zero, so a neuron's own spike feeds back into its input current.
    cell = LIFRefracRecurrentCell(2, 2, autapses=True)
    assert not torch.allclose(
        torch.zeros(2),
        (cell.recurrent_weights * torch.eye(*cell.recurrent_weights.shape)).sum(0),
    )

    # Neuron 0 spiked in s1 but not in s2; with autapses its own input
    # current must differ between the two cases.
    s1 = LIFRefracState(
        rho=torch.zeros(1, 2),
        lif=LIFState(z=torch.ones(1, 2), v=torch.zeros(1, 2), i=torch.zeros(1, 2)),
    )
    z, s_full = cell(torch.zeros(1, 2), s1)

    s2 = LIFRefracState(
        rho=torch.zeros(1, 2),
        lif=LIFState(
            z=torch.tensor([[0, 1]], dtype=torch.float32),
            v=torch.zeros(1, 2),
            i=torch.zeros(1, 2),
        ),
    )
    z, s_part = cell(torch.zeros(1, 2), s2)

    assert not s_full.lif.i[0, 0] == s_part.lif.i[0, 0]


def test_lif_refrac_cell_no_autapses():
    # Without autapses the diagonal of the recurrent weight matrix is zero,
    # so whether neuron 0 spiked must not change its own input current.
    cell = LIFRefracRecurrentCell(2, 2, autapses=False)
    assert (
        cell.recurrent_weights * torch.eye(*cell.recurrent_weights.shape)
    ).sum() == 0

    s1 = LIFRefracState(
        rho=torch.zeros(1, 2),
        lif=LIFState(z=torch.ones(1, 2), v=torch.zeros(1, 2), i=torch.zeros(1, 2)),
    )
    z, s_full = cell(torch.zeros(1, 2), s1)

    s2 = LIFRefracState(
        rho=torch.zeros(1, 2),
        lif=LIFState(
            z=torch.tensor([[0, 1]], dtype=torch.float32),
            v=torch.zeros(1, 2),
            i=torch.zeros(1, 2),
        ),
    )
    z, s_part = cell(torch.zeros(1, 2), s2)

    assert s_full.lif.i[0, 0] == s_part.lif.i[0, 0]
def test_lif_refrac_feedforward_backward():
    batch_size = 16
    cell = LIFRefracCell()
    x = torch.randn(batch_size, 20, 30)
    out, _ = cell(x)
    out.sum().backward()