Source code for norse.torch.functional.test.test_lsnn

import torch

from norse.torch.functional.lsnn import (
    LSNNState,
    LSNNFeedForwardState,
    lsnn_feed_forward_step,
    lsnn_step,
    ada_lif_step,
)


[docs]def test_lsnn_step(): x = torch.ones(20) s = LSNNState( z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10), b=torch.zeros(10) ) input_weights = torch.randn(10, 20).float() recurrent_weights = torch.randn(10, 10).float() for _ in range(100): _, s = lsnn_step(x, s, input_weights, recurrent_weights)
[docs]def test_lsnn_step_batch(): x = torch.ones(16, 20) s = LSNNState( z=torch.zeros(16, 10), v=torch.zeros(16, 10), i=torch.zeros(16, 10), b=torch.zeros(16, 10), ) input_weights = torch.randn(10, 20).float() recurrent_weights = torch.randn(10, 10).float() for _ in range(100): _, s = lsnn_step(x, s, input_weights, recurrent_weights)
[docs]def test_ada_lif_step(): x = torch.ones(20) s = LSNNState( z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10), b=torch.zeros(10) ) input_weights = torch.randn(10, 20).float() recurrent_weights = torch.randn(10, 10).float() for _ in range(100): _, s = ada_lif_step(x, s, input_weights, recurrent_weights)
[docs]def test_lsnn_feed_forward_step(): x = torch.ones(10) s = LSNNFeedForwardState(v=torch.zeros(10), i=torch.zeros(10), b=torch.zeros(10)) for _ in range(100): _, s = lsnn_feed_forward_step(x, s)