Source code for norse.torch.module.test.test_sequential

import pytest

import torch
from torch import nn
import norse.torch as norse


def test_state_sequence():
    """Run a single recurrent LIF layer through ``SequentialState``.

    Checks the shape of the returned output and of the ``v`` field of the
    first entry in the returned state.
    """
    inputs = torch.ones(10, 1, 20)
    layer = norse.LIFRecurrent(20, 6)
    model = norse.SequentialState(layer)

    out, states = model(inputs)

    assert out.shape == (10, 1, 6)
    assert states[0].v.shape == (1, 6)
def test_state_sequence_apply_no_state():
    """Call a ``SequentialState`` model without passing any state.

    Smoke test only: both ``__call__`` and ``forward`` must accept the input
    alone without raising.
    """
    inputs = torch.ones(10, 1, 1)
    model = norse.SequentialState(
        nn.Linear(1, 1),
        norse.LIFCell(),
    )

    model(inputs)
    model.forward(inputs)
def test_state_sequence_apply_with_state():
    """Call a ``SequentialState`` model with an explicit state argument.

    Smoke test only: the model is invoked via ``__call__`` and ``forward``,
    first with ``None`` as the state and then with a per-layer state list.
    """
    inputs = torch.ones(10, 1, 1)
    model = norse.SequentialState(
        nn.Linear(1, 1),
        norse.LIFCell(),
    )

    # Explicit ``None`` state.
    model(inputs, None)
    model.forward(inputs, None)

    # One state entry per layer: none for the linear layer, a feed-forward
    # LIF state for the cell.
    cell_state = norse.LIFFeedForwardState(torch.tensor(0), torch.tensor(0))
    layer_states = [None, cell_state]
    model(inputs, layer_states)
    model.forward(inputs, layer_states)
def test_state_sequence_list():
    """Chain two recurrent LIF layers, passing a list of ``None`` states.

    Checks the output shape and the shape of the ``v`` field of the second
    layer's returned state.
    """
    inputs = torch.ones(10, 1, 20)
    first = norse.LIFRecurrent(20, 6)
    second = norse.LIFRecurrent(6, 2)
    model = norse.SequentialState(first, second)
    initial_states = [None, None]

    out, final_states = model(inputs, initial_states)

    assert out.shape == (10, 1, 2)
    assert final_states[1].v.shape == (1, 2)
def test_state_sequence_norse():
    """Chain a recurrent LIF layer with a recurrent LSNN layer.

    Checks the output shape and the shape of the ``v`` field of each layer's
    returned state.
    """
    inputs = torch.ones(10, 2, 10)
    lif = norse.LIFRecurrent(10, 5)
    lsnn = norse.LSNNRecurrent(5, 1)
    model = norse.SequentialState(lif, lsnn)

    out, states = model(inputs)
    lif_state, lsnn_state = states

    assert out.shape == (10, 2, 1)
    assert lif_state.v.shape == (2, 5)
    assert lsnn_state.v.shape == (2, 1)
def test_state_sequence_mix():
    """Mix norse recurrent layers with a ``torch.nn.RNN`` in one sequence.

    The RNN receives an explicit initial hidden state while both norse layers
    start from ``None``. Checks the output shape and each returned state.
    """
    inputs = torch.ones(10, 3, 20)
    lif = norse.LIFRecurrent(20, 10)
    rnn = torch.nn.RNN(10, 4, 2)  # 2 layers
    lsnn = norse.LSNNRecurrent(4, 1)
    model = norse.SequentialState(lif, rnn, lsnn)
    initial_states = [None, torch.randn(2, 3, 4), None]

    out, states = model(inputs, initial_states)
    lif_state, rnn_state, lsnn_state = states

    assert out.shape == (10, 3, 1)
    assert lif_state.v.shape == (3, 10)
    assert rnn_state.shape == (2, 3, 4)
    assert lsnn_state.v.shape == (3, 1)
def test_state_sequence_conv():
    """Run a lifted convolution followed by recurrent layers.

    The trailing comment on each layer gives the expected shape of that
    layer's output. The batch dimension (8) is carried through every layer,
    so the final output is expected to have shape ``(1, 8, 1)``.
    """
    data = torch.ones(1, 8, 16, 4, 4)  # (timestep, minibatch, channels, x, y)
    model = norse.SequentialState(
        norse.Lift(torch.nn.Conv2d(16, 8, 3)),  # (1, 8, 8, 2, 2)
        torch.nn.Flatten(2),  # (1, 8, 32)
        norse.LSNNRecurrent(32, 6),  # (1, 8, 6)
        torch.nn.RNN(6, 4, 2),  # (1, 8, 4) with 2 recurrent layers
        norse.LIFRecurrent(4, 1),  # (1, 8, 1)
    )

    out, _ = model(data)

    # Pin the end-to-end shape so the test verifies more than "does not raise".
    assert out.shape == (1, 8, 1)
def test_backprop_through_time_works():
    """Check that gradients flow back through several stateful steps.

    The model is stepped once per slice of ``data`` with the state carried
    from one step to the next. A single loss on the last output is then
    back-propagated and one optimizer step is taken. The test passes if none
    of this raises.
    """
    layers = [
        norse.LSNNRecurrent(1, 2),
        norse.LSNNRecurrent(2, 3),
        norse.LSNNRecurrent(3, 3),
        norse.LSNNRecurrent(3, 2),
        nn.Flatten(),
        norse.Lift(torch.nn.Linear(4, 2)),
        norse.Lift(torch.nn.Linear(2, 1)),
    ]
    model = norse.SequentialState(*layers)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    loss_func = nn.MSELoss()

    data = torch.ones(3, 1, 2, 1)
    target = torch.ones(1, 1)

    # Clear gradients for this training step.
    optimizer.zero_grad()

    state = None
    for step_input in data:
        out, state = model(step_input, state)

    # Backpropagate once, after all steps, so the gradient spans the
    # whole sequence of carried states.
    loss = loss_func(out, target)
    loss.backward()
    optimizer.step()
def test_sequential_forward_state_hook():
    """Exercise registering and removing forward state hooks.

    The hook records the first two elements of the input it is handed.
    Three phases are checked:

    1. With no initial state, four hook calls are recorded and the first
       recorded state is ``None``.
    2. With the state returned by the first run, the first recorded state
       exposes a ``v`` attribute.
    3. After the hooks are removed, another run records nothing further.
    """
    data = torch.ones(10, 2, 8, 1)
    layers = [
        torch.nn.Linear(1, 10, bias=False),
        norse.LIFRecurrent(10, 10),
        norse.LIF(),
        norse.LIF(),
        norse.LIF(),
    ]
    model = norse.SequentialState(*layers)

    recorded_inputs = []
    recorded_states = []

    def forward_state_hook(mod, inp, output):
        recorded_inputs.append(inp[0])
        recorded_states.append(inp[1])

    model.register_forward_state_hooks(forward_state_hook)

    # Phase 1: run without an initial state.
    _, final_state = model(data)
    stacked = torch.stack(recorded_inputs).detach().cpu()
    assert stacked.shape == (4, 10, 2, 8, 10)
    assert stacked.max() > 0
    assert len(recorded_states) == 4
    assert recorded_states[0] is None

    # Phase 2: start recording afresh and feed the previous state back in.
    recorded_inputs.clear()
    recorded_states.clear()
    model(data, final_state)
    assert hasattr(recorded_states[0], "v")  # isinstance is broken for namedtuple

    # Phase 3: with hooks removed, the recorded count must stay at four.
    model.remove_forward_state_hooks()
    model(data, final_state)
    assert len(recorded_inputs) == 4
def test_sequential_debug_hook_twice():
    """Registering forward state hooks a second time raises ``ValueError``."""

    def noop_hook(mod, inp, output):
        return None

    model = norse.SequentialState(
        torch.nn.Linear(1, 10, bias=False),
        norse.LIFRecurrent(10, 10),
    )

    # The first registration succeeds; the second must be rejected.
    model.register_forward_state_hooks(noop_hook)
    with pytest.raises(ValueError):
        model.register_forward_state_hooks(noop_hook)