# Source code for norse.torch.module.test.test_lift
import torch
from norse.torch.module.lif import LIF
from norse.torch.module.lift import Lift
def test_lift_conv():
    """Check the output shape of a ``Conv2d`` wrapped in ``Lift``.

    A 5-D tensor laid out as ``(seq_length, batch_size, channels, height,
    width)`` is passed through the lifted convolution. The test asserts that
    the two leading dimensions are unchanged, while the channel and spatial
    dimensions are those a plain ``Conv2d`` would produce.
    """
    batch_size = 16
    seq_length = 20
    in_channels = 64
    out_channels = 32
    kernel_size = 5
    stride = 1
    height, width = 20, 30

    conv2d = Lift(torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride))
    data = torch.randn(seq_length, batch_size, in_channels, height, width)
    output = conv2d(data)

    # With stride 1 and Conv2d's default padding (0) and dilation (1), each
    # spatial dimension shrinks by kernel_size - 1: 20 -> 16 and 30 -> 26.
    expected_shape = torch.Size(
        [
            seq_length,
            batch_size,
            out_channels,
            height - kernel_size + 1,
            width - kernel_size + 1,
        ]
    )
    assert output.shape == expected_shape
def test_lift_sequential():
    """Check that a lifted ``Conv2d`` composes with ``LIF`` in ``Sequential``.

    The lifted convolution's output is fed straight into ``LIF`` by
    ``torch.nn.Sequential``. The composed module returns a pair; only the
    first element is inspected here, and it must have the same shape the
    lifted convolution alone would produce.
    """
    batch_size = 16
    seq_length = 20
    in_channels = 64
    out_channels = 32
    kernel_size = 5
    stride = 1
    height, width = 20, 30

    data = torch.randn(seq_length, batch_size, in_channels, height, width)
    module = torch.nn.Sequential(
        Lift(torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)),
        LIF(),
    )

    # The second element of the pair is discarded (presumably the LIF neuron
    # state -- confirm against the LIF documentation).
    output, _ = module(data)

    # With stride 1 and Conv2d's default padding (0) and dilation (1), each
    # spatial dimension shrinks by kernel_size - 1: 20 -> 16 and 30 -> 26.
    expected_shape = torch.Size(
        [
            seq_length,
            batch_size,
            out_channels,
            height - kernel_size + 1,
            width - kernel_size + 1,
        ]
    )
    assert output.shape == expected_shape
if __name__ == "__main__":
    # Allow running this module directly as a script, without a test runner.
    # The tests run in definition order; the first failure stops the run.
    for run_test in (test_lift_conv, test_lift_sequential):
        run_test()