Source code for norse.torch.utils.test.test_tensorboard

import torch
import norse.torch as snn

import norse.torch.utils.tensorboard as tensorboard


class MockWriter:
    """Test double exposing ``add_histogram``, ``add_image`` and ``add_scalar``.

    Every ``add_*`` method stores its three arguments on the instance as
    ``key``, ``spikes`` and ``index``, overwriting the previous call, so only
    the most recent write is observable. Nothing is logged or plotted.
    Before the first call the attributes resolve to the class-level ``None``
    defaults below.
    """

    # Class-level defaults; the assignments in ``_record`` create instance
    # attributes that shadow these without mutating the class.
    key = None
    spikes = None
    index = None

    def _record(self, key, spikes, index):
        """Store the arguments of the latest ``add_*`` call on the instance."""
        self.key = key
        self.spikes = spikes
        self.index = index

    def add_histogram(self, key, spikes, index):
        """Record a histogram write by storing its arguments."""
        self._record(key, spikes, index)

    def add_image(self, key, spikes, index):
        """Record an image write by storing its arguments."""
        self._record(key, spikes, index)

    def add_scalar(self, key, spikes, index):
        """Record a scalar write by storing its arguments."""
        self._record(key, spikes, index)
def test_activity_hook():
    """Check the mean- and sum-activity hooks forward spike statistics to the writer.

    The two hooks are tested in isolation. The mean hook is removed via its
    handle, and a fresh ``MockWriter`` is used before the sum hook is
    registered. Otherwise both hooks would fire on every step, and the second
    assertion would only pass because the sum hook happens to run last.
    """
    steps = 7
    cell = snn.LIFCell()

    # --- mean activity hook -------------------------------------------------
    writer = MockWriter()
    handle = cell.register_forward_hook(
        tensorboard.hook_spike_activity_mean("lif", writer)
    )
    s = None
    for _ in range(steps):
        # Constant input; the z.max() > 0 check below guards that the
        # cell actually spiked within `steps` iterations.
        z, s = cell(torch.ones(2), s)
    assert z.max() > 0
    assert writer.spikes is not None  # the hook must have written something
    assert torch.eq(writer.spikes, z.mean())
    handle.remove()  # detach so only the sum hook fires below

    # --- sum activity hook --------------------------------------------------
    writer = MockWriter()
    handle = cell.register_forward_hook(
        tensorboard.hook_spike_activity_sum("lif", writer)
    )
    s = None
    for _ in range(steps):
        z, s = cell(torch.ones(2), s)
    assert z.max() > 0
    assert writer.spikes is not None
    assert torch.eq(writer.spikes, z.sum())
    handle.remove()
def test_image_hook():
    """Verify the spike-image hook hands the cell's spike output to ``add_image``.

    A LIF cell is driven for seven steps with constant input. The test then
    checks that it spiked and that the writer received a tensor equal to the
    last spike output.
    """
    neuron = snn.LIFCell()
    recorder = MockWriter()
    neuron.register_forward_hook(tensorboard.hook_spike_image("lif", recorder))

    state = None
    for _step in range(7):
        spikes, state = neuron(torch.ones(2), state)

    assert spikes.max() > 0
    assert torch.eq(recorder.spikes, spikes).all()
def test_histogram_hook():
    """Check the mean- and sum-histogram hooks forward spike statistics to the writer.

    The two hooks are tested in isolation. The mean hook is removed via its
    handle, and a fresh ``MockWriter`` is used before the sum hook is
    registered. Otherwise both hooks would fire on every step, and the second
    assertion would only pass because the sum hook happens to run last.
    """
    steps = 7
    cell = snn.LIFCell()

    # --- mean histogram hook ------------------------------------------------
    writer = MockWriter()
    handle = cell.register_forward_hook(
        tensorboard.hook_spike_histogram_mean("lif", writer)
    )
    s = None
    for _ in range(steps):
        # Constant input; the z.max() > 0 check below guards that the
        # cell actually spiked within `steps` iterations.
        z, s = cell(torch.ones(2), s)
    assert z.max() > 0
    assert writer.spikes is not None  # the hook must have written something
    assert torch.eq(writer.spikes, z.mean())
    handle.remove()  # detach so only the sum hook fires below

    # --- sum histogram hook -------------------------------------------------
    writer = MockWriter()
    handle = cell.register_forward_hook(
        tensorboard.hook_spike_histogram_sum("lif", writer)
    )
    s = None
    for _ in range(steps):
        z, s = cell(torch.ones(2), s)
    assert z.max() > 0
    assert writer.spikes is not None
    assert torch.eq(writer.spikes, z.sum())
    handle.remove()