Source code for norse.torch.module.test.test_izhikevich

import torch
import pytest

from norse.torch.module.izhikevich import (
    IzhikevichCell,
    IzhikevichRecurrentCell,
    Izhikevich,
    IzhikevichRecurrent,
)
from norse.torch.functional import izhikevich, IzhikevichSpikingBehavior

list_method = [
    izhikevich.tonic_spiking,
    izhikevich.phasic_spiking,
    izhikevich.tonic_bursting,
    izhikevich.phasic_bursting,
    izhikevich.mixed_mode,
    izhikevich.spike_frequency_adaptation,
    izhikevich.class_1_exc,
    izhikevich.class_2_exc,
    izhikevich.spike_latency,
    izhikevich.subthreshold_oscillation,
    izhikevich.resonator,
    izhikevich.integrator,
    izhikevich.rebound_spike,
    izhikevich.rebound_burst,
    izhikevich.threshhold_variability,
    izhikevich.bistability,
    izhikevich.dap,
    izhikevich.accomodation,
    izhikevich.inhibition_induced_spiking,
    izhikevich.inhibition_induced_bursting,
]
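
# Each entry above is an IzhikevichSpikingBehavior preset exposed by
# norse.torch.functional.izhikevich; the tests below parametrize every module
# variant over all of them.
#
# Minimal usage sketch (assumes only the cell API exercised in the tests
# below): a preset parameterizes a cell, and calling the cell returns a spike
# tensor plus the updated neuron state.
#
#     cell = IzhikevichCell(izhikevich.tonic_spiking)
#     spikes, state = cell(torch.randn(5, 2))          # single time step
#     spikes, state = cell(torch.randn(5, 2), state)   # carry the state forward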


class SNNetwork(torch.nn.Module):
    def __init__(self, spiking_method: IzhikevichSpikingBehavior):
        super(SNNetwork, self).__init__()
        self.spiking_method = spiking_method
        self.l0 = Izhikevich(spiking_method)
        self.l1 = Izhikevich(spiking_method)
        self.s0 = self.s1 = None
    def forward(self, spikes):
        spikes, self.s0 = self.l0(spikes, self.s0)
        _, self.s1 = self.l1(spikes, self.s1)
        return self.s1.v.squeeze()
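
# Usage sketch (mirrors test_backward_model at the end of this file): the
# network keeps its layer states on the module, so repeated calls integrate
# over time, and the returned membrane voltages are differentiable.
#
#     net = SNNetwork(izhikevich.tonic_spiking)
#     voltages = net(torch.ones(10, 12))   # voltages of the last layer
#     voltages.sum().backward()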
[docs]@pytest.mark.parametrize("spiking_method", list_method) def test_izhikevich_cell(spiking_method): shape = (5, 2) data = torch.randn(shape) cell = IzhikevichCell(spiking_method) out, s = cell(data) for x in s: assert x.shape == (5, 2) assert out.shape == (5, 2)
[docs]@pytest.mark.parametrize("spiking_method", list_method) def test_izhikevich_recurrent_cell(spiking_method): cell = IzhikevichRecurrentCell(2, 4, spiking_method) data = torch.randn(5, 2) out, s = cell(data) for x in s: assert x.shape == (5, 4) assert out.shape == (5, 4)
[docs]@pytest.mark.parametrize("spiking_method", list_method) def test_izhikevich_recurrent_cell_autapses(spiking_method): cell = IzhikevichRecurrentCell( 2, 2, spiking_method, autapses=True, recurrent_weights=torch.ones(2, 2) * 0.01, dt=0.0001, ) assert not torch.allclose( torch.zeros(2), (cell.recurrent_weights * torch.eye(*cell.recurrent_weights.shape)).sum(0), ) s1 = izhikevich.IzhikevichRecurrentState( z=torch.ones(1, 2), v=torch.zeros(1, 2), u=torch.zeros(1, 2) ) _, s_full = cell(torch.zeros(1, 2), s1) s2 = izhikevich.IzhikevichRecurrentState( z=torch.tensor([[0, 1]], dtype=torch.float32), v=torch.zeros(1, 2), u=torch.zeros(1, 2), ) _, s_part = cell(torch.zeros(1, 2), s2) assert not s_full.v[0, 0] == s_part.v[0, 0]
[docs]@pytest.mark.parametrize("spiking_method", list_method) def test_izhikevich_recurrent_cell_no_autapses(spiking_method): cell = IzhikevichRecurrentCell(2, 2, spiking_method, autapses=False) assert ( cell.recurrent_weights * torch.eye(*cell.recurrent_weights.shape) ).sum() == 0 s1 = izhikevich.IzhikevichRecurrentState( z=torch.ones(1, 2), v=torch.zeros(1, 2), u=torch.zeros(1, 2) ) _, s_full = cell(torch.zeros(1, 2), s1) s2 = izhikevich.IzhikevichRecurrentState( z=torch.tensor([[0, 1]], dtype=torch.float32), v=torch.zeros(1, 2), u=torch.zeros(1, 2), ) _, s_part = cell(torch.zeros(1, 2), s2) assert s_full.v[0, 0] == s_part.v[0, 0]
[docs]@pytest.mark.parametrize("spiking_method", list_method) def test_izhikevich_in_time(spiking_method): layer = Izhikevich(spiking_method) data = torch.randn(10, 5, 2) out, _ = layer(data) assert out.shape == (10, 5, 2)
[docs]@pytest.mark.parametrize("spiking_method", list_method) def test_izhikevich_recurrent_sequence(spiking_method): l1 = IzhikevichRecurrent(8, 6, spiking_method) l2 = IzhikevichRecurrent(6, 4, spiking_method) l3 = IzhikevichRecurrent(4, 1, spiking_method) z = torch.ones(10, 1, 8) z, s1 = l1(z) z, s2 = l2(z) z, s3 = l3(z) assert s1.v.shape == (1, 6) assert s2.v.shape == (1, 4) assert s3.v.shape == (1, 1) assert z.shape == (10, 1, 1)
[docs]@pytest.mark.parametrize("spiking_method", list_method) def test_izhikevich_feedforward_cell_backward(spiking_method): # Tests that gradient variables can be used in subsequent applications cell = IzhikevichCell(spiking_method) data = torch.randn(5, 4) out, s = cell(data) out, _ = cell(out, s) loss = out.sum() loss.backward()
[docs]@pytest.mark.parametrize("spiking_method", list_method) def test_izhikevich_recurrent_cell_backward(spiking_method): # Tests that gradient variables can be used in subsequent applications cell = IzhikevichRecurrentCell(4, 4, spiking_method) data = torch.randn(5, 4) out, s = cell(data) out, _ = cell(out, s) loss = out.sum() loss.backward()
[docs]@pytest.mark.parametrize("spiking_method", list_method) def test_izhikevich_feedforward_layer(spiking_method): layer = Izhikevich(spiking_method) data = torch.randn(10, 5, 4) out, s = layer(data) assert out.shape == (10, 5, 4) for x in s: assert x.shape == (5, 4)
[docs]@pytest.mark.parametrize("spiking_method", list_method) def test_izhikevich_feedforward_layer_backward(spiking_method): model = Izhikevich(spiking_method) data = torch.ones(10, 12) out, _ = model(data) loss = out.sum() loss.backward()
[docs]@pytest.mark.parametrize("spiking_method", list_method) def test_izhikevich_recurrent_layer_backward_iteration(spiking_method): # Tests that gradient variables can be used in subsequent applications model = IzhikevichRecurrent(6, 6, spiking_method) data = torch.ones(10, 6) out, s = model(data) out, _ = model(out, s) loss = out.sum() loss.backward()
[docs]@pytest.mark.parametrize("spiking_method", list_method) def test_izhikevich_recurrent_layer_backward(spiking_method): model = IzhikevichRecurrent(6, 6, spiking_method) data = torch.ones(10, 6) out, _ = model(data) loss = out.sum() loss.backward()
[docs]@pytest.mark.parametrize("spiking_method", list_method) def test_izhikevich_feedforward_layer_backward_iteration(spiking_method): # Tests that gradient variables can be used in subsequent applications model = Izhikevich(spiking_method) data = torch.ones(10, 6) out, s = model(data) out, _ = model(out, s) loss = out.sum() loss.backward()
[docs]@pytest.mark.parametrize("spiking_method", list_method) def test_backward_model(spiking_method): model = SNNetwork(spiking_method) data = torch.ones(10, 12) out = model(data) loss = out.sum() loss.backward()