# Source code for norse.torch.module.test.test_training
"""
Tests that the training of Norse modules stays intact, for instance,
that gradients are properly propagated.
"""
import torch
from norse.torch.module.encode import PoissonEncoder
from norse.torch.module.lif import LIFRecurrentCell
class SNNetwork(torch.nn.Module):
    """Small two-layer spiking network used as a fixture for training tests.

    The network chains a ``PoissonEncoder`` with two ``LIFRecurrentCell``
    layers (12 -> 6 -> 1).  The state of each recurrent cell is stored on the
    instance (``s0`` and ``s1``) and therefore carries over between calls to
    :meth:`forward`.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): the positional ``10`` is presumably the number of
        # encoded time steps -- confirm against the PoissonEncoder signature.
        self.encoder = PoissonEncoder(10, f_max=1000)
        self.l0 = LIFRecurrentCell(12, 6)
        self.l1 = LIFRecurrentCell(6, 1)
        # Recurrent states start out unset; the cells receive ``None`` on the
        # first step and hand back the state to use on the next one.
        self.s0 = None
        self.s1 = None

    def forward(self, input_tensor):
        """Encode ``input_tensor`` and run every encoded step through both layers.

        Returns the output of the second layer for the last encoded step, or
        ``None`` if the encoder yields no steps at all.
        """
        output = None
        for step_input in self.encoder(input_tensor):
            hidden, self.s0 = self.l0(step_input, self.s0)
            output, self.s1 = self.l1(hidden, self.s1)
        return output
def test_optimize_model():
    """Check that a single optimizer step updates the first layer's weights.

    A fresh ``SNNetwork`` is run once on an all-ones input, the summed output
    is back-propagated, and one Adam step is applied.  The test passes when
    both the input weights and the recurrent weights of ``model.l0`` differ
    from the copies taken before the step, i.e. when gradients reached them.
    """
    model = SNNetwork()
    optimizer = torch.optim.Adam(model.parameters(), lr=1)
    optimizer.zero_grad()

    # Both snapshots must come from ``l0``, the layer asserted on below.
    # Comparing against a tensor from a different layer would succeed
    # vacuously, because differently initialised tensors are never equal.
    # ``detach`` keeps the copies out of the autograd graph.
    input_weights = model.l0.input_weights.detach().clone()
    recurrent_weights = model.l0.recurrent_weights.detach().clone()

    data = torch.ones(1, 12)
    out = model(data)
    loss = out.sum()
    loss.backward()
    optimizer.step()

    assert not torch.equal(input_weights, model.l0.input_weights)
    # NOTE(review): this relies on a non-zero gradient reaching the recurrent
    # weights of ``l0`` -- confirm it is stable across random initialisations.
    assert not torch.equal(recurrent_weights, model.l0.recurrent_weights)