Source code for norse.torch.functional.test.test_superspike

import torch
from norse.torch.functional.superspike import super_fn
from norse.torch.functional.heaviside import heaviside


def test_forward():
    """Check the forward output of ``super_fn`` on constant inputs.

    With the second argument set to 100.0, the output for an all-ones
    input must equal ``heaviside`` applied to the same values, and the
    output for an all-minus-ones input must be all zeros.
    """
    alpha = 100.0
    n_elements = 100

    positive = torch.ones(n_elements)
    expected_positive = heaviside(torch.ones(n_elements))
    assert torch.equal(super_fn(positive, alpha), expected_positive)

    negative = -1.0 * torch.ones(n_elements)
    expected_negative = torch.zeros(n_elements)
    assert torch.equal(super_fn(negative, alpha), expected_negative)
def test_backward():
    """Check the sign of the gradient propagated back through ``super_fn``.

    Case 1: ``super_fn`` is applied directly to ``x``; every element of
    ``x.grad`` must be strictly positive.

    Case 2: ``super_fn`` is applied to ``-0.001 * x``; every element of
    ``x.grad`` must be strictly negative (the negative scale factor flips
    the sign through the chain rule).
    """
    n_elements = 10
    alpha = 100.0

    # Case 1: unscaled input.
    x = torch.ones(n_elements, requires_grad=True)
    super_fn(x, alpha).backward(torch.ones(n_elements))
    positive_count = torch.sum(x.grad > 0)
    assert positive_count == n_elements

    # Case 2: input scaled by a small negative factor. A fresh leaf tensor
    # is created so the gradient from case 1 is not accumulated into it.
    x = torch.ones(n_elements, requires_grad=True)
    scaled = -0.001 * x
    super_fn(scaled, alpha).backward(torch.ones(n_elements))
    negative_count = torch.sum(x.grad < 0)
    assert negative_count == n_elements