# Source code for norse.torch.functional.test.test_threshold

import torch
from pytest import raises


from norse.torch.functional.threshold import (
    threshold,
    sign,
    heavi_erfc_fn,
    logistic_fn,
    circ_dist_fn,
)


def test_heavi_erfc_fn_forward():
    """heavi_erfc_fn maps all-ones input to ones and all-minus-ones input to zeros."""
    positive = torch.ones(100)
    negative = -1.0 * torch.ones(100)

    assert torch.equal(heavi_erfc_fn(positive, 100.0), torch.ones(100))
    assert torch.equal(heavi_erfc_fn(negative, 100.0), torch.zeros(100))
def test_heavi_erfc_fn_backward():
    """Gradients through heavi_erfc_fn are positive for x and negative for -0.001 * x."""
    leaf = torch.ones(10, requires_grad=True)
    heavi_erfc_fn(leaf, 0.1).backward(torch.ones(10))
    assert (leaf.grad > 0).sum() == 10

    # Scaling the input by a negative factor flips the sign of the chain-rule
    # term, so every gradient entry must come out negative.
    leaf = torch.ones(10, requires_grad=True)
    heavi_erfc_fn(-0.001 * leaf, 0.1).backward(torch.ones(10))
    assert (leaf.grad < 0).sum() == 10
def test_logistic_fn_forward():
    """Every forward output of logistic_fn on a ones input is exactly 0 or 1."""
    result = logistic_fn(torch.ones(10), 0.1)
    # Elementwise OR of two boolean masks marks entries that are binary.
    binary_mask = (result == 1) | (result == 0)
    assert binary_mask.sum() == 10
def test_logistic_fn_backward():
    """Gradients through logistic_fn are positive for x and negative for -0.001 * x."""
    upstream = torch.ones(10)

    inp = torch.ones(10, requires_grad=True)
    logistic_fn(inp, 0.1).backward(upstream)
    assert (inp.grad > 0).sum() == 10

    # A negative input scale makes d(out)/d(inp) negative everywhere.
    inp = torch.ones(10, requires_grad=True)
    logistic_fn(-0.001 * inp, 0.1).backward(upstream)
    assert (inp.grad < 0).sum() == 10
def test_circ_dist_fn_forward():
    """Every forward output of circ_dist_fn on a ones input is exactly 0 or 1."""
    spikes = circ_dist_fn(torch.ones(10), 0.1)
    is_zero_or_one = torch.logical_or(spikes == 0, spikes == 1)
    assert int(is_zero_or_one.sum()) == 10
def test_circ_dist_fn_backward():
    """Gradients through circ_dist_fn are positive for x and negative for -0.001 * x."""
    direct = torch.ones(10, requires_grad=True)
    circ_dist_fn(direct, 0.1).backward(torch.ones(10))
    assert (direct.grad > 0).sum() == 10

    # Same check through a negative scaling, which should invert the gradient sign.
    scaled = torch.ones(10, requires_grad=True)
    circ_dist_fn(-0.001 * scaled, 0.1).backward(torch.ones(10))
    assert (scaled.grad < 0).sum() == 10
def test_threshold_throws():
    """threshold raises ValueError when given an unrecognised method name."""
    inputs = torch.ones(10)
    with raises(ValueError):
        threshold(inputs, "noasd", 10.0)
def test_threshold_backward():
    """Backpropagating through ``threshold`` yields a finite gradient per method.

    For each listed method, inputs of 1.0, 0.1 and -0.1 are pushed through
    ``threshold`` and backpropagated with an all-ones upstream gradient. The
    leaf gradient must exist, match the input shape and contain no NaN/inf.
    Previously this test only ran ``backward()`` without asserting anything.
    """
    alpha = 10.0
    methods = ["super", "tanh", "triangle", "circ", "heavi_erfc"]
    for method in methods:
        # Probe both sides of zero, as the forward test does.
        for value in (1.0, 0.1, -0.1):
            x = torch.full((10,), value, requires_grad=True)
            out = threshold(x, method, alpha)
            out.backward(torch.ones(10))
            assert x.grad is not None, (method, value)
            assert x.grad.shape == x.shape, (method, value)
            assert torch.isfinite(x.grad).all(), (method, value)
def test_threshold():
    """threshold outputs 1 for inputs 1.0 and 0.1 and 0 for -0.1, for every method."""
    alpha = 10.0
    # (input fill value, expected output fill value)
    cases = ((1.0, 1.0), (0.1, 1.0), (-0.1, 0.0))
    for method in ("super", "heaviside", "tanh", "triangle", "circ", "heavi_erfc"):
        for value, expected in cases:
            result = threshold(torch.full((10,), value), method, alpha)
            assert torch.equal(result, torch.full((10,), expected))
def test_sign():
    """sign outputs +1 for inputs 1.0 and 0.1 and -1 for -0.1, for every listed method."""
    alpha = 10.0
    # (input fill value, expected output fill value)
    expectations = ((1.0, 1.0), (0.1, 1.0), (-0.1, -1.0))
    for method in ("super", "heaviside", "tanh", "triangle", "circ"):
        for value, expected in expectations:
            result = sign(torch.full((10,), value), method, alpha)
            assert torch.equal(result, torch.full((10,), expected))