Source code for norse.torch.module.test.test_coba
import torch
from norse.torch.module.coba_lif import CobaLIFCell, CobaLIFState
def test_coba():
    """Check output shapes of a stateless forward pass through CobaLIFCell.

    A cell built as ``CobaLIFCell(4, 3)`` is called on a ``(5, 4)`` tensor of
    ones without an explicit state. Both the returned spikes and the ``v``
    field of the returned state must have shape ``(5, 3)``.
    """
    batch_size, input_size, hidden_size = 5, 4, 3
    expected_shape = (batch_size, hidden_size)

    neuron = CobaLIFCell(input_size, hidden_size)
    # No state argument: the cell is called with the input tensor only.
    z, s = neuron(torch.ones(batch_size, input_size))

    assert z.shape == expected_shape
    assert s.v.shape == expected_shape
def test_coba_state():
    """Check output shapes when CobaLIFCell is given an explicit state.

    An all-zero ``CobaLIFState`` (fields ``z``, ``v``, ``g_e``, ``g_i``) is
    built with one row per input sample and ``cell.hidden_size`` columns,
    matching the device and dtype of the input. The cell is then called with
    the input and that state; the returned spikes and the ``v`` field of the
    returned state must both have shape ``(5, 3)``.
    """
    cell = CobaLIFCell(4, 3)
    input_tensor = torch.ones(5, 4)

    def zeros_like_hidden():
        # Each call allocates a fresh tensor so the four state fields never
        # share storage with one another.
        return torch.zeros(
            input_tensor.shape[0],
            cell.hidden_size,
            device=input_tensor.device,
            dtype=input_tensor.dtype,
        )

    initial_state = CobaLIFState(
        z=zeros_like_hidden(),
        v=zeros_like_hidden(),
        g_e=zeros_like_hidden(),
        g_i=zeros_like_hidden(),
    )

    out_spikes, next_state = cell(input_tensor, initial_state)

    expected_shape = (5, 3)
    assert out_spikes.shape == expected_shape
    assert next_state.v.shape == expected_shape
def test_coba_backward():
    """Smoke-test that a backward pass through CobaLIFCell runs.

    The spikes produced by ``CobaLIFCell(4, 3)`` on a ``(5, 4)`` tensor of
    ones are summed to a scalar and ``backward()`` is called on it. The test
    contains no assertions; it passes when no exception is raised.
    """
    neuron = CobaLIFCell(4, 3)
    ones_input = torch.ones(5, 4)

    # Only the spike output is needed; the returned state is discarded.
    z, _ = neuron(ones_input)

    loss = z.sum()
    loss.backward()