Source code for norse.torch.functional.test.test_encode

"""
Test for the encoder module
"""
import torch

from norse.torch.functional.encode import (
    population_encode,
    constant_current_lif_encode,
    spike_latency_lif_encode,
    spike_latency_encode,
)

# Fixes a linting error:
# pylint: disable=E1102


[docs]def test_encode_population(): data = torch.as_tensor([0, 0.5, 1]) out_features = 3 actual = population_encode(data, out_features) expected = torch.as_tensor( [ [1.0000, 0.8824969, 0.6065307], [0.8824969, 1.0000, 0.8824969], [0.6065307, 0.8824969, 1.0000], ] ) assert torch.allclose(actual, expected)
[docs]def test_encode_population_scale(): data = torch.as_tensor([0, 0.5, 1]) out_features = 3 scale = data.max() actual = population_encode(data, out_features, scale=scale) expected = torch.as_tensor( [ [1.0000, 0.8824969, 0.6065307], [0.8824969, 1.0000, 0.8824969], [0.6065307, 0.8824969, 1.0000], ] ) assert torch.allclose(actual, expected)
[docs]def test_encode_population_batch(): data = torch.as_tensor([[0, 0, 0], [0.5, 0.5, 0.5], [1, 1, 1]]) out_features = 3 actual = population_encode(data, out_features) assert actual.shape == (3, 3, 3)
[docs]def test_constant_current_lif_encode(): data = torch.as_tensor([0, 0, 0, 0]) z = constant_current_lif_encode(data, 2) assert torch.equal(z, torch.zeros((2, 4))) data = torch.as_tensor([[16, 16, 16], [32, 32, 32], [64, 64, 64], [128, 128, 128]]) z = constant_current_lif_encode(data, 10) assert torch.equal(z[-1], torch.ones((4, 3)))
[docs]def test_spike_latency_lif_encode(): spikes = spike_latency_lif_encode(1.1 * torch.ones(10), seq_length=128) assert torch.sum(spikes).data == 10
[docs]def test_spike_latency_encode_with_batch(): data = torch.as_tensor([[100, 100], [100, 100]]) spikes = constant_current_lif_encode(data, 5) actual = spike_latency_encode(spikes) expected = torch.zeros((5, 2, 2)) expected[0] = torch.as_tensor([[1, 1], [1, 1]]) assert torch.equal(actual, expected)
[docs]def test_spike_latency_encode_without_batch(): spikes = torch.as_tensor([[0, 1, 1, 0], [1, 1, 1, 0]]) actual = spike_latency_encode(spikes) assert torch.equal(actual, torch.as_tensor([[0, 1, 1, 0], [1, 0, 0, 0]]))
[docs]def test_spike_latency_encode_without_batch_2(): spikes = torch.as_tensor([[[0, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1]]]) actual = spike_latency_encode(spikes) expected = torch.as_tensor([[[0, 1, 1], [1, 1, 1]], [[1, 0, 0], [0, 0, 0]]]) assert torch.equal(actual, expected)
[docs]def test_spike_latency_encode_without_batch_3(): spikes = torch.as_tensor( [ [ [1.0, 1.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0, 0.0], ], [ [1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0], ], ] ) actual = spike_latency_encode(spikes) expected = spikes.clone() expected[1] = torch.as_tensor( [ [0.0, 0.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 0.0, 1.0], ] ) assert torch.equal(actual, expected)