# Source code for norse.torch.functional.test.test_regularization

import torch

from norse.torch.functional.lif import LIFFeedForwardState, lif_feed_forward_step
from norse.torch.functional.regularization import regularize_step, voltage_accumulator


def test_regularisation_spikes():
    """``regularize_step`` with its default accumulator, over two LIF steps.

    After the first step the spikes must come back unchanged and the
    accumulated value must be 0. After the second step the accumulated
    value must be 50 (= 5 x 10, matching the input shape).
    """
    inputs = torch.ones(5, 10)
    neuron_state = LIFFeedForwardState(torch.ones(10), torch.ones(10))

    # First step: spikes pass through untouched, accumulator reports 0.
    spikes, neuron_state = lif_feed_forward_step(inputs, neuron_state)
    passed_spikes, accumulated = regularize_step(spikes, neuron_state)
    assert torch.equal(spikes, passed_spikes)
    assert accumulated == 0

    # Second step: only the accumulated value is checked.
    spikes, neuron_state = lif_feed_forward_step(inputs, neuron_state)
    _, accumulated = regularize_step(spikes, neuron_state)
    assert accumulated == 50

def test_regularisation_voltage():
    """``regularize_step`` with ``voltage_accumulator`` and no explicit state.

    The spikes must be returned unchanged, and the accumulated value must
    equal the ``v`` field of the LIF state produced by the step.
    """
    lif_state = LIFFeedForwardState(torch.ones(10), torch.ones(10))
    spikes, lif_state = lif_feed_forward_step(torch.ones(5, 10), lif_state)

    # Suppress pytype's wrong-arg-types check for this call only.
    # pytype: disable=wrong-arg-types
    out_spikes, accumulated = regularize_step(
        spikes, lif_state, accumulator=voltage_accumulator
    )
    # pytype: enable=wrong-arg-types

    assert torch.equal(spikes, out_spikes)
    assert torch.equal(lif_state.v, accumulated)

def test_regularisation_voltage_state():
    """``regularize_step`` with ``voltage_accumulator`` and an explicit ``state``.

    An all-zero ``state`` cannot distinguish "voltage added onto state" from
    "state ignored" (both yield ``s.v``). The test therefore also checks a
    non-zero ``state`` and expects the result to be ``state + s.v``.
    """
    x = torch.ones(5, 10)
    s = LIFFeedForwardState(torch.ones(10), torch.ones(10))
    z, s = lif_feed_forward_step(x, s)

    # Zero state: the accumulated value must equal the voltage itself.
    state = torch.zeros(10)
    # pytype: disable=wrong-arg-types
    zr, rs = regularize_step(z, s, accumulator=voltage_accumulator, state=state)
    # pytype: enable=wrong-arg-types
    assert torch.equal(z, zr)
    assert torch.equal(s.v, rs)

    # Non-zero state: it must be carried into the result, not discarded.
    # ``expected`` is computed before the call so the comparison still holds
    # even if the accumulator were to update ``state`` in place.
    state = torch.arange(10, dtype=torch.float32)
    expected = state + s.v
    # pytype: disable=wrong-arg-types
    zr, rs = regularize_step(z, s, accumulator=voltage_accumulator, state=state)
    # pytype: enable=wrong-arg-types
    assert torch.equal(z, zr)
    assert torch.equal(expected, rs)