# norse.torch.module.test.test_encode

"""
Test for the stateful encoder module
"""

import torch

from norse.torch.module.lif import LIFRecurrent
from norse.torch.module.encode import (
    PopulationEncoder,
    ConstantCurrentLIFEncoder,
    SpikeLatencyEncoder,
    PoissonEncoder,
    PoissonEncoderStep,
    SignedPoissonEncoder,
    SignedPoissonEncoderStep,
    SpikeLatencyLIFEncoder,
)

# Silence pylint E1102 (not-callable), a known false positive on torch functions:
# pylint: disable=E1102


def test_encode_population():
    data = torch.as_tensor([0, 0.5, 1])
    out_features = 3
    actual = PopulationEncoder(out_features).forward(data)
    expected = torch.as_tensor(
        [
            [1.0000, 0.8824969, 0.6065307],
            [0.8824969, 1.0000, 0.8824969],
            [0.6065307, 0.8824969, 1.0000],
        ]
    )
    assert torch.allclose(actual, expected)
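

# The constants in test_encode_population match a Gaussian radial basis
# function exp(-d**2 / 2): exp(-0.125) ~= 0.8824969 for inputs 0.5 apart and
# exp(-0.5) ~= 0.6065307 for inputs 1.0 apart. A minimal sketch of that
# reading (the kernel is inferred from the numbers above, not taken from
# PopulationEncoder's documented defaults):
def example_population_gaussian_constants():
    distances = torch.as_tensor([0.0, 0.5, 1.0])
    values = torch.exp(-distances.pow(2) / 2)
    assert torch.allclose(values, torch.as_tensor([1.0, 0.8824969, 0.6065307]))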


def test_constant_current_lif_encode():
    data = torch.as_tensor([0, 0, 0, 0])
    z = ConstantCurrentLIFEncoder(2).forward(data)
    assert torch.equal(z, torch.zeros((2, 4)))
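

# The test above pins down ConstantCurrentLIFEncoder's contract: a (features,)
# input is unrolled into a (seq_length, features) spike train, and a zero
# input current never drives the LIF membrane across threshold. A small
# illustrative sketch (the concrete shapes here are mine, not from the
# original suite):
def example_constant_current_shape():
    out = ConstantCurrentLIFEncoder(5).forward(torch.zeros(3))
    assert out.shape == torch.Size([5, 3])
    assert torch.equal(out, torch.zeros(5, 3))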


def test_spike_latency_encode():
    data = torch.as_tensor([[0, 100, 100], [100, 100, 100]])
    encoder = torch.nn.Sequential(ConstantCurrentLIFEncoder(2), SpikeLatencyEncoder())
    actual = encoder(data)
    # Output is (seq_length, batch, features): the strongly driven neurons
    # spike at t=0 and are then silenced by the latency filter, while the
    # zero-current neuron never spikes.
    expected = torch.zeros((2, 2, 3))
    expected[0] = torch.as_tensor([[0, 1, 1], [1, 1, 1]])
    assert torch.equal(actual, expected)


def test_signed_poisson_encoder():
    seq_length = 100
    encoder = SignedPoissonEncoder(seq_length=seq_length)
    x = torch.randn(10, 10)
    out = encoder(x)
    assert out.shape == torch.Size([seq_length, 10, 10])


def test_signed_poisson_encoder_step():
    encoder = SignedPoissonEncoderStep()
    x = torch.randn(10, 10)
    out = encoder(x)
    assert out.shape == torch.Size([10, 10])


def test_poisson_encoder():
    seq_length = 100
    encoder = PoissonEncoder(seq_length=seq_length)
    x = torch.randn(10, 10)
    out = encoder(x)
    assert out.shape == torch.Size([seq_length, 10, 10])
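

# Hedged sketch: Poisson encoding samples spikes with probability
# proportional to the input value, so an all-zero input should yield an
# all-zero spike train. That zero-rate property is an assumption about the
# sampling scheme, not something the original suite asserts:
def example_poisson_zero_rate():
    out = PoissonEncoder(seq_length=50)(torch.zeros(4))
    assert out.shape == torch.Size([50, 4])
    assert torch.sum(out) == 0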


def test_poisson_encoder_step():
    encoder = PoissonEncoderStep()
    x = torch.randn(10, 10)
    out = encoder(x)
    assert out.shape == torch.Size([10, 10])


def test_spike_latency_lif_encoder():
    seq_length = 100
    encoder = SpikeLatencyLIFEncoder(seq_length=seq_length)
    x = torch.randn(10, 10)
    out = encoder(x)
    assert out.shape == torch.Size([seq_length, 10, 10])


def test_spike_latency_lif_encoder_backward():
    timesteps = 20
    data = torch.randn(7, 5) + 10
    encoder = SpikeLatencyLIFEncoder(timesteps)
    layer = LIFRecurrent(5, 2)
    encoded = encoder(data)
    out, _ = layer(encoded)
    # Gradients must flow back through the encoded spike train without error.
    out.sum().backward()


def test_spike_latency_encode_max_spikes():
    encoder = torch.nn.Sequential(
        ConstantCurrentLIFEncoder(seq_length=128), SpikeLatencyEncoder()
    )
    # Latency encoding keeps only the first spike per neuron, so 10 inputs
    # produce exactly 10 spikes over the 128 steps.
    spikes = encoder(1.1 * torch.ones(10))
    assert torch.sum(spikes) == 10
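

# Reasoning behind the max-spikes test: SpikeLatencyEncoder keeps only each
# neuron's first spike, so the filtered train never holds more spikes than
# neurons, while the raw constant-current train contains at least those first
# spikes. A hedged sketch of that ordering (applying the encoder directly
# rather than via nn.Sequential):
def example_latency_removes_later_spikes():
    raw = ConstantCurrentLIFEncoder(seq_length=128)(1.1 * torch.ones(10))
    latency = SpikeLatencyEncoder()(raw)
    assert torch.sum(latency) == 10
    assert torch.sum(raw) >= torch.sum(latency)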


def test_spike_latency_encode_backward():
    # Anomaly detection surfaces autograd errors (e.g. from in-place ops)
    # as soon as they occur during the backward pass.
    torch.autograd.set_detect_anomaly(True)
    timesteps = 20
    data = torch.randn(7, 5) + 10
    encoder = torch.nn.Sequential(
        ConstantCurrentLIFEncoder(seq_length=timesteps), SpikeLatencyEncoder()
    )
    layer = LIFRecurrent(5, 2)
    encoded = encoder(data)
    out, _ = layer(encoded)
    out = out.sum()
    out.backward()


def test_spike_latency_encode_chain():
    # Smoke test: the encoder chain runs on batched input without raising.
    data = torch.randn(7, 5) + 10
    encoder = torch.nn.Sequential(ConstantCurrentLIFEncoder(2), SpikeLatencyEncoder())
    encoder(data)