# Source code for norse.torch.functional.test.test_lift

import torch
import pytest

from norse.torch.functional.leaky_integrator import li_step, LIState
from norse.torch.functional.lif import (
    lif_step,
    lif_feed_forward_step,
    LIFState,
    LIFFeedForwardState,
    LIFParameters,
)
from norse.torch.functional.lift import lift


def test_lift_without_state_or_parameters():
    """A lifted step function called without ``state`` must raise ``TypeError``.

    Neither a state nor parameters are passed to the lifted callable; the test
    only pins that the call with a missing state is rejected with ``TypeError``.
    """
    lifted_step = lift(lif_feed_forward_step)
    inputs = torch.ones(3, 2, 1)
    with pytest.raises(TypeError):
        # Deliberately omit ``state``.
        lifted_step(inputs)
def test_lift_with_state_without_parameters():
    """Lift ``lif_feed_forward_step`` with an explicit zero state and no parameters.

    The lifted output keeps the full (3, 2, 1) input shape. Every field of the
    returned state has the shape of a single ``data[0]`` slice, (2, 1).
    """
    data = torch.ones(3, 2, 1)
    # Two independently allocated zero tensors, one per state field.
    initial_state = LIFFeedForwardState(
        v=torch.zeros_like(data[0]),
        i=torch.zeros_like(data[0]),
    )
    outputs, final_state = lift(lif_feed_forward_step)(data, state=initial_state)

    assert outputs.shape == (3, 2, 1)
    for component in (final_state.v, final_state.i):
        assert component.shape == (2, 1)
def test_lift_with_state_and_parameters():
    """Lift ``lif_feed_forward_step`` with a zero state and custom ``LIFParameters``.

    The parameters are bound at lift time through ``p=``, and the state is
    supplied at call time. Shapes are checked as in the parameter-less case:
    (3, 2, 1) for the output and (2, 1) for each state field.
    """
    data = torch.ones(3, 2, 1)
    params = LIFParameters(v_th=torch.as_tensor(0.3), method="tanh")
    lifted = lift(lif_feed_forward_step, p=params)

    v_init = torch.zeros_like(data[0])
    i_init = torch.zeros_like(data[0])
    outputs, final_state = lifted(data, state=LIFFeedForwardState(v_init, i_init))

    assert outputs.shape == (3, 2, 1)
    for component in (final_state.v, final_state.i):
        assert component.shape == (2, 1)
def test_lift_with_lift_step():
    """Lift the recurrent ``lif_step`` and check the output and every state field.

    ``input_weights`` and ``recurrent_weights`` are passed to the lifted
    callable alongside the state. The output keeps the (3, 2, 1) input shape.
    Each field of the returned ``LIFState`` (``v``, ``i``, ``z``) keeps the
    (2, 1) shape of the initial state.

    NOTE(review): the name says ``lift_step`` but the function under test is
    ``lif_step``; the name is kept unchanged so existing test selections still
    match.
    """
    data = torch.ones(3, 2, 1)
    lifted = lift(lif_step)
    z, s = lifted(
        data,
        state=LIFState(
            v=torch.zeros(2, 1),
            i=torch.zeros(2, 1),
            z=torch.zeros(2, 1),
        ),
        input_weights=torch.ones(1, 1),
        recurrent_weights=torch.ones(1, 1),
    )
    assert z.shape == (3, 2, 1)
    # Check every state field, consistent with the feed-forward tests.
    assert s.v.shape == (2, 1)
    assert s.i.shape == (2, 1)
    assert s.z.shape == (2, 1)
def test_lift_with_leaky_integrator():
    """Lift ``li_step`` and check the output and every state field.

    ``input_weights`` is passed to the lifted callable alongside the state.
    The output keeps the (3, 2, 1) input shape. Both fields of the returned
    ``LIState`` (``v``, ``i``) keep the (2, 1) shape of the initial state.
    """
    data = torch.ones(3, 2, 1)
    lifted = lift(li_step)
    outputs, s = lifted(
        data,
        state=LIState(
            v=torch.zeros(2, 1),
            i=torch.zeros(2, 1),
        ),
        input_weights=torch.ones(1, 1),
    )
    assert outputs.shape == (3, 2, 1)
    # Check every state field, consistent with the other lift tests.
    assert s.v.shape == (2, 1)
    assert s.i.shape == (2, 1)