# Source code for norse.torch.functional.test.test_lif

import torch
import pytest

import norse
from norse.torch.functional.lif import (
    LIFState,
    LIFFeedForwardState,
    LIFParameters,
    LIFParametersJIT,
    lif_step,
    lif_step_integral,
    lif_feed_forward_step,
    lif_feed_forward_integral,
    _lif_feed_forward_step_jit,
    lif_current_encoder,
)


@pytest.fixture(autouse=True)
def cpp_fixture():
    """Enable the C++ ops implementation for every test by default.

    Saves the previous value of the global flag and restores it on
    teardown so the mutation cannot leak into other test modules.
    """
    previous = norse.utils.IS_OPS_LOADED
    norse.utils.IS_OPS_LOADED = True  # Enable cpp
    yield
    norse.utils.IS_OPS_LOADED = previous
@pytest.fixture()
def jit_fixture():
    """Disable the C++ ops implementation (fall back to TorchScript/JIT).

    Saves the previous value of the global flag and restores it on
    teardown so the mutation cannot leak into other tests.
    """
    previous = norse.utils.IS_OPS_LOADED
    norse.utils.IS_OPS_LOADED = False  # Disable cpp
    yield
    norse.utils.IS_OPS_LOADED = previous
def test_lif_cpp_and_jit_step():
    """Check that the C++ and JIT implementations of ``lif_step`` agree.

    Runs ten recurrent steps with the C++ ops enabled, then repeats the
    same sequence from a fresh state with the C++ ops disabled, and
    asserts identical spikes and state (values and dtype) at every step.
    """
    assert norse.utils.IS_OPS_LOADED
    x = torch.ones(20)
    s = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
    input_weights = torch.linspace(0, 0.5, 200).view(10, 20)
    recurrent_weights = torch.linspace(0, -2, 100).view(10, 10)
    # Expected spike pattern per timestep (one row per step).
    results = [
        torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        torch.as_tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]),
        torch.as_tensor([0, 0, 1, 1, 1, 1, 1, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 0, 1, 1, 0, 0, 0, 0]),
        torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 1, 0, 1, 1, 0, 1, 1]),
        torch.as_tensor([0, 1, 1, 0, 1, 0, 0, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 0, 0, 1, 1, 0, 0, 1]),
    ]
    cpp_results = []
    cpp_states = []
    # First pass: C++ implementation, one step per expected timestep.
    # (Loop variable is unused; we only need len(results) iterations.)
    for _ in results:
        z, s = lif_step(x, s, input_weights, recurrent_weights)
        cpp_results.append(z)
        cpp_states.append(s)
    # Second pass: reset the state and rerun with the JIT implementation.
    s = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
    norse.utils.IS_OPS_LOADED = False  # Disable cpp
    for i, result in enumerate(results):
        z, s = lif_step(x, s, input_weights, recurrent_weights)
        assert torch.equal(z, result.float())
        assert torch.equal(z, cpp_results[i])
        assert torch.equal(s.v, cpp_states[i].v)
        assert s.v.dtype == torch.float32
        assert torch.equal(s.z, cpp_states[i].z)
        assert torch.equal(s.i, cpp_states[i].i)
        assert s.i.dtype == torch.float32
def test_lif_cpp_back(cpp_fixture):
    """Gradients flow back through two recurrent C++ ``lif_step`` calls."""
    data = torch.ones(2)
    state = LIFState(z=torch.zeros(1), v=torch.zeros(1), i=torch.zeros(1))
    state.v.requires_grad = True
    in_weights = torch.ones(2)
    rec_weights = torch.ones(1)
    # Two steps so the recurrent path is exercised before backprop.
    for _ in range(2):
        spikes, state = lif_step(data, state, in_weights, rec_weights)
    spikes.sum().backward()
def test_lif_jit_back(jit_fixture):
    """Gradients flow back through two recurrent JIT ``lif_step`` calls."""
    data = torch.ones(2)
    state = LIFState(z=torch.zeros(1), v=torch.zeros(1), i=torch.zeros(1))
    state.v.requires_grad = True
    in_weights = torch.ones(2)
    rec_weights = torch.ones(1)
    # Two steps so the recurrent path is exercised before backprop.
    for _ in range(2):
        spikes, state = lif_step(data, state, in_weights, rec_weights)
    spikes.sum().backward()
def test_lif_heavi():
    """``lif_step`` with the heaviside method spikes under strong drive."""
    inp = torch.ones(2, 1)
    state = LIFState(
        z=torch.ones(2, 1), v=torch.zeros(2, 1), i=torch.zeros(2, 1)
    )
    w_in = torch.ones(1, 1) * 10  # large weight to force spiking
    w_rec = torch.ones(1, 1)
    params = LIFParameters(method="heaviside")
    for _ in range(2):
        spikes, state = lif_step(inp, state, w_in, w_rec, params)
    assert spikes.max() > 0
    assert spikes.shape == (2, 1)
def test_lif_integral(jit_fixture):
    """``lif_step_integral`` (JIT path) matches the expected spike trace."""
    x = torch.ones(10, 20)
    state = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
    w_in = torch.linspace(0, 0.5, 200).view(10, 20)
    w_rec = torch.linspace(0, -2, 100).view(10, 10)
    # One row of expected spikes per integrated timestep.
    expected = torch.tensor(
        [
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
            [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
            [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
            [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
            [0, 0, 0, 1, 0, 1, 1, 0, 1, 1],
            [0, 1, 1, 0, 1, 0, 0, 1, 1, 1],
            [0, 0, 0, 0, 0, 1, 1, 0, 0, 1],
        ]
    )
    spikes, state = lif_step_integral(x, state, w_in, w_rec)
    assert state.v.shape == (10,)
    assert state.i.shape == (10,)
    assert torch.equal(spikes, expected.float())
def test_lif_integral_cpp(cpp_fixture):
    """``lif_step_integral`` (C++ path) matches the expected spike trace."""
    x = torch.ones(10, 20)
    state = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
    w_in = torch.linspace(0, 0.5, 200).view(10, 20)
    w_rec = torch.linspace(0, -2, 100).view(10, 10)
    # One row of expected spikes per integrated timestep.
    expected = torch.tensor(
        [
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
            [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
            [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
            [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
            [0, 0, 0, 1, 0, 1, 1, 0, 1, 1],
            [0, 1, 1, 0, 1, 0, 0, 1, 1, 1],
            [0, 0, 0, 0, 0, 1, 1, 0, 0, 1],
        ]
    )
    spikes, state = lif_step_integral(x, state, w_in, w_rec)
    assert state.v.shape == (10,)
    assert state.i.shape == (10,)
    assert torch.equal(spikes, expected.float())
def test_lif_feed_forward_step():
    """Membrane voltage follows the expected trace over ten steps."""
    x = torch.ones(10)
    state = LIFFeedForwardState(v=torch.zeros(10), i=torch.zeros(10))
    voltage_trace = (0.0, 0.1, 0.27, 0.487, 0.7335, 0.9963, 0.0, 0.3951, 0.7717, 0.0)
    for expected_v in voltage_trace:
        _, state = lif_feed_forward_step(x, state)
        assert torch.allclose(torch.as_tensor(expected_v), state.v, atol=1e-4)
def test_lif_feed_forward_step_batch():
    """A batched input keeps its (batch, neuron) shape in the output."""
    batch_input = torch.ones(2, 1)
    state = LIFFeedForwardState(v=torch.zeros(2, 1), i=torch.zeros(2, 1))
    spikes, state = lif_feed_forward_step(batch_input, state)
    assert spikes.shape == (2, 1)
def test_lif_feed_forward_step_cpp(cpp_fixture):
    """Feed-forward step with C++ ops enabled follows the voltage trace.

    NOTE(review): despite the name, this calls ``_lif_feed_forward_step_jit``
    directly, exactly as the jit variant does — confirm that is intended.
    """
    # Idiomatic truthiness check instead of `== True`.
    assert norse.utils.IS_OPS_LOADED
    x = torch.ones(10)
    s = LIFFeedForwardState(v=torch.zeros(10), i=torch.zeros(10))
    # Positional LIFParametersJIT: tau_syn_inv, tau_mem_inv, v_leak,
    # v_th, v_reset, method, alpha — presumably; verify against the class.
    p = LIFParametersJIT(
        torch.as_tensor(1.0 / 5e-3),
        torch.as_tensor(1.0 / 1e-2),
        torch.as_tensor(0.0),
        torch.as_tensor(1.0),
        torch.as_tensor(0.0),
        "super",
        torch.as_tensor(0.0),
    )
    results = [0.0, 0.1, 0.27, 0.487, 0.7335, 0.9963, 0.0, 0.3951, 0.7717, 0.0]
    for result in results:
        _, s = _lif_feed_forward_step_jit(x, s, p)
        assert torch.allclose(torch.as_tensor(result), s.v, atol=1e-4)
def test_lif_feed_forward_step_jit(jit_fixture):
    """Feed-forward JIT step follows the expected voltage trace."""
    # Idiomatic truthiness check instead of `== False`.
    assert not norse.utils.IS_OPS_LOADED
    x = torch.ones(10)
    s = LIFFeedForwardState(v=torch.zeros(10), i=torch.zeros(10))
    # Positional LIFParametersJIT: tau_syn_inv, tau_mem_inv, v_leak,
    # v_th, v_reset, method, alpha — presumably; verify against the class.
    p = LIFParametersJIT(
        torch.as_tensor(1.0 / 5e-3),
        torch.as_tensor(1.0 / 1e-2),
        torch.as_tensor(0.0),
        torch.as_tensor(1.0),
        torch.as_tensor(0.0),
        "super",
        torch.as_tensor(0.0),
    )
    results = [0.0, 0.1, 0.27, 0.487, 0.7335, 0.9963, 0.0, 0.3951, 0.7717, 0.0]
    for result in results:
        _, s = _lif_feed_forward_step_jit(x, s, p)
        assert torch.allclose(torch.as_tensor(result), s.v, atol=1e-4)
def test_lif_feed_forward_integrate_cpp(cpp_fixture):
    """Integrating 9 feed-forward steps (C++ path) reaches the known voltage."""
    # Idiomatic truthiness check instead of `== True`.
    assert norse.utils.IS_OPS_LOADED
    x = torch.ones(9, 2)
    s = LIFFeedForwardState(v=torch.zeros(2), i=torch.zeros(2))
    # Voltage after 9 steps of unit input (matches the per-step trace above).
    expected_v = torch.tensor(0.7717)
    _, s = lif_feed_forward_integral(x, s)
    assert torch.allclose(expected_v, s.v[0], atol=1e-4)
def test_lif_feed_forward_integrate_jit(jit_fixture):
    """Integrating 9 feed-forward steps (JIT path) reaches the known voltage."""
    inputs = torch.ones(9, 2)
    state = LIFFeedForwardState(v=torch.zeros(2), i=torch.zeros(2))
    _, state = lif_feed_forward_integral(inputs, state)
    # Voltage after 9 steps of unit input (matches the per-step trace above).
    target_v = torch.tensor(0.7717)
    assert torch.allclose(target_v, state.v[0], atol=1e-4)
def test_lif_current_encoder():
    """The current encoder's voltage follows the expected trace."""
    currents = torch.ones(10)
    voltage = torch.zeros(10)
    voltage_trace = (
        0.1,
        0.19,
        0.2710,
        0.3439,
        0.4095,
        0.4686,
        0.5217,
        0.5695,
        0.6126,
        0.6513,
    )
    for expected_v in voltage_trace:
        _, voltage = lif_current_encoder(currents, voltage)
        assert torch.allclose(torch.as_tensor(expected_v), voltage, atol=1e-4)