import torch
import torch.jit
import numpy as np
from norse.torch.functional.heaviside import heaviside
from norse.torch.functional.superspike import super_fn
# Module-level alias: makes norse's ``super_fn`` available under the name
# ``superspike_fn``, which is the name ``threshold`` uses for the "super" method.
superspike_fn = super_fn
class HeaviErfc(torch.autograd.Function):
    r"""Heaviside step with a Gaussian-shaped surrogate gradient.

    Approximation of the heaviside step function as

    .. math::
        h(x,k) = \frac{1}{2} + \frac{1}{2} \text{erfc}(k x)

    where erfc is the complementary error function.

    The forward pass returns ``heaviside(x)``; the smooth approximation is
    only used to define the gradient. The backward pass scales the incoming
    gradient by

    .. math::
        \frac{2}{\sqrt{\pi}} e^{-(k x)^2}

    which is the derivative of :math:`\text{erf}(u)` evaluated at
    :math:`u = k x`.

    NOTE(review): the sign and scale of the implemented gradient match
    ``erf`` rather than ``erfc``, and carry no chain-rule factor ``k`` --
    confirm which formula is intended.
    """

    @staticmethod
    def forward(ctx, x, k):
        """Return ``heaviside(x)`` and stash ``x`` and ``k`` for backward.

        Arguments:
            ctx: autograd context object.
            x (torch.Tensor): input tensor.
            k: width parameter of the surrogate gradient.
        """
        # backward() unpacks ctx.saved_tensors, so x must be saved here.
        ctx.save_for_backward(x)
        ctx.k = k
        return heaviside(x)

    @staticmethod
    def backward(ctx, dy):
        """Return the surrogate gradient w.r.t. ``x`` and ``None`` for ``k``."""
        (x,) = ctx.saved_tensors
        derfc = (2 * torch.exp(-(ctx.k * x).pow(2))) / (torch.as_tensor(np.pi).sqrt())
        # One entry per forward() input: k is not differentiated.
        return derfc * dy, None
def heavi_erfc_fn(x: torch.Tensor, k: float):
    """Apply the ``HeaviErfc`` autograd function to ``x`` with parameter ``k``."""
    spikes = HeaviErfc.apply(x, k)
    return spikes
class HeaviTanh(torch.autograd.Function):
    r"""Heaviside step with a tanh-derived surrogate gradient.

    Approximation of the heaviside step function as

    .. math::
        h(x,k) = \frac{1}{2} + \frac{1}{2} \text{tanh}(k x)

    The forward pass returns ``heaviside(x)``; the smooth approximation is
    only used to define the gradient. The backward pass scales the incoming
    gradient by :math:`1 - \text{tanh}^2(k x)`.
    """

    @staticmethod
    def forward(ctx, x, k):
        """Return ``heaviside(x)`` and stash ``x`` and ``k`` for backward.

        Arguments:
            ctx: autograd context object.
            x (torch.Tensor): input tensor.
            k: slope parameter of the tanh surrogate.
        """
        # backward() unpacks ctx.saved_tensors, so x must be saved here.
        ctx.save_for_backward(x)
        ctx.k = k
        return heaviside(x)

    @staticmethod
    def backward(ctx, dy):
        """Return the surrogate gradient w.r.t. ``x`` and ``None`` for ``k``."""
        (x,) = ctx.saved_tensors
        dtanh = 1 - (x * ctx.k).tanh().pow(2)
        # One entry per forward() input: k is not differentiated.
        return dy * dtanh, None
def heavi_tanh_fn(x: torch.Tensor, k: float):
    """Apply the ``HeaviTanh`` autograd function to ``x`` with parameter ``k``."""
    spikes = HeaviTanh.apply(x, k)
    return spikes
class Logistic(torch.autograd.Function):
    r"""Probabilistic approximation of the heaviside step function as

    .. math::
        z \sim p(\frac{1}{2} + \frac{1}{2} \text{tanh}(k x))

    The forward pass draws a Bernoulli sample with success probability
    :math:`\frac{1}{2} + \frac{1}{2} \text{tanh}(k x)`, so its output is
    random. The backward pass scales the incoming gradient by
    :math:`1 - \text{tanh}^2(k x)`.
    """

    @staticmethod
    def forward(ctx, x, k):
        """Sample spikes from the tanh-shaped firing probability.

        Arguments:
            ctx: autograd context object.
            x (torch.Tensor): input tensor.
            k: slope parameter of the tanh probability curve.

        Returns:
            Tensor of Bernoulli samples with the same shape as ``x``.
        """
        # backward() unpacks ctx.saved_tensors, so x must be saved here.
        ctx.save_for_backward(x)
        ctx.k = k
        p = 0.5 + 0.5 * torch.tanh(k * x)
        return torch.distributions.bernoulli.Bernoulli(probs=p).sample()

    @staticmethod
    def backward(ctx, dy):
        """Return the surrogate gradient w.r.t. ``x`` and ``None`` for ``k``."""
        (x,) = ctx.saved_tensors
        dtanh = 1 - (x * ctx.k).tanh().pow(2)
        # One entry per forward() input: k is not differentiated.
        return dy * dtanh, None
def logistic_fn(x: torch.Tensor, k: float):
    """Apply the ``Logistic`` autograd function to ``x`` with parameter ``k``."""
    samples = Logistic.apply(x, k)
    return samples
class HeaviCirc(torch.autograd.Function):
    r"""Heaviside step with a surrogate gradient derived from

    .. math::
        h(x,\alpha) = \frac{1}{2} + \frac{1}{2} \frac{x}{(x^2 + \alpha^2)^{1/2}}

    The forward pass returns ``heaviside(x)``; the smooth approximation is
    only used to define the gradient. The backward pass scales the incoming
    gradient by

    .. math::
        2 \alpha \left( \frac{1}{2 (x^2 + \alpha^2)^{1/2}}
        - \frac{x^2}{2 (x^2 + \alpha^2)^{3/2}} \right)
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``heaviside(x)`` and stash ``x`` and ``alpha`` for backward.

        Arguments:
            ctx: autograd context object.
            x (torch.Tensor): input tensor.
            alpha: width parameter of the surrogate.
        """
        # backward() unpacks ctx.saved_tensors, so x must be saved here.
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        return heaviside(x)

    @staticmethod
    def backward(ctx, dy):
        """Return the surrogate gradient w.r.t. ``x`` and ``None`` for ``alpha``."""
        (x,) = ctx.saved_tensors
        alpha = ctx.alpha
        # One entry per forward() input: alpha is not differentiated.
        return (
            dy
            * (
                -(x.pow(2) / (2 * (alpha**2 + x.pow(2)).pow(1.5)))
                + 1 / (2 * (alpha**2 + x.pow(2)).sqrt())
            )
            * 2
            * alpha,
            None,
        )
def heavi_circ_fn(x: torch.Tensor, k: float):
    """Apply the ``HeaviCirc`` autograd function to ``x``, passing ``k`` as its ``alpha``."""
    spikes = HeaviCirc.apply(x, k)
    return spikes
class CircDist(torch.autograd.Function):
    r"""Probabilistic approximation of the heaviside step function based on

    .. math::
        h(x,\alpha) = 0.5 + 0.5 * \frac{x}{\sqrt{x^2 + \alpha^2}}

    The forward pass draws a Bernoulli sample with success probability
    :math:`h(x,\alpha)`, so its output is random. The backward pass scales
    the incoming gradient by

    .. math::
        2 \alpha \left( \frac{1}{2 (x^2 + \alpha^2)^{1/2}}
        - \frac{x^2}{2 (x^2 + \alpha^2)^{3/2}} \right)
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Sample spikes from the firing probability ``h(x, alpha)``.

        Arguments:
            ctx: autograd context object.
            x (torch.Tensor): input tensor.
            alpha: width parameter of the probability curve.

        Returns:
            Tensor of Bernoulli samples with the same shape as ``x``.
        """
        # backward() unpacks ctx.saved_tensors, so x must be saved here.
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        return torch.distributions.bernoulli.Bernoulli(
            0.5 + 0.5 * (x / (x.pow(2) + alpha**2).sqrt())
        ).sample()

    @staticmethod
    def backward(ctx, dy):
        """Return the surrogate gradient w.r.t. ``x`` and ``None`` for ``alpha``."""
        (x,) = ctx.saved_tensors
        alpha = ctx.alpha
        # One entry per forward() input: alpha is not differentiated.
        return (
            dy
            * (
                -(x.pow(2) / (2 * (alpha**2 + x.pow(2)).pow(1.5)))
                + 1 / (2 * (alpha**2 + x.pow(2)).sqrt())
            )
            * 2
            * alpha,
            None,
        )
def circ_dist_fn(x: torch.Tensor, k: float):
    """Apply the ``CircDist`` autograd function to ``x``, passing ``k`` as its ``alpha``."""
    samples = CircDist.apply(x, k)
    return samples
class Triangle(torch.autograd.Function):
    r"""Triangular/piecewise linear surrogate gradient as in

    S.K. Esser et al., **"Convolutional networks for fast, energy-efficient neuromorphic computing"**,
    Proceedings of the National Academy of Sciences 113(41), 11441-11446, (2016),
    `doi:10.1073/pnas.1604850113 <https://doi.org/10.1073/pnas.1604850113>`_

    G. Bellec et al., **"A solution to the learning dilemma for recurrent networks of spiking neurons"**,
    Nature Communications 11(1), 3625, (2020),
    `doi:10.1038/s41467-020-17236-y <https://doi.org/10.1038/s41467-020-17236-y>`_

    The forward pass returns ``heaviside(x)``. The backward pass scales the
    incoming gradient by :math:`\alpha \max(0, 1 - |x|)`.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, alpha: float) -> torch.Tensor:
        """Return ``heaviside(x)`` and stash ``x`` and ``alpha`` for backward.

        Arguments:
            ctx: autograd context object.
            x (torch.Tensor): input tensor.
            alpha (float): scale of the triangular surrogate gradient.
        """
        # backward() unpacks ctx.saved_tensors, so x must be saved here.
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        return heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the surrogate gradient w.r.t. ``x`` and ``None`` for ``alpha``."""
        (x,) = ctx.saved_tensors
        alpha = ctx.alpha
        # The multiplication is out-of-place, so grad_output is left untouched
        # without needing an explicit clone.
        grad = grad_output * alpha * torch.relu(1 - x.abs())
        # One entry per forward() input: alpha is not differentiated.
        return grad, None
def triangle_fn(x: torch.Tensor, alpha: float = 0.3) -> torch.Tensor:
    """Apply the ``Triangle`` autograd function to ``x``; ``alpha`` defaults to 0.3."""
    spikes = Triangle.apply(x, alpha)
    return spikes
def threshold(x: torch.Tensor, method: str, alpha: float) -> torch.Tensor:
    """Apply the threshold function selected by ``method`` to ``x``.

    Arguments:
        x (torch.Tensor): input tensor.
        method (str): one of "heaviside", "super", "triangle", "tanh",
            "circ" or "heavi_erfc".
        alpha (float): parameter forwarded to the selected function; it is
            ignored for "heaviside" and wrapped in a tensor for "super".

    Returns:
        The output of the selected threshold function.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    if method == "heaviside":
        return heaviside(x)
    if method == "super":
        return superspike_fn(x, torch.as_tensor(alpha))
    if method == "triangle":
        return triangle_fn(x, alpha)
    if method == "tanh":
        return heavi_tanh_fn(x, alpha)
    if method == "circ":
        return heavi_circ_fn(x, alpha)
    if method == "heavi_erfc":
        return heavi_erfc_fn(x, alpha)
    raise ValueError(
        f"Attempted to apply threshold function {method}, but no such "
        + "function exist. We currently support heaviside, super, "
        + "tanh, triangle, circ, and heavi_erfc."
    )
def sign(x: torch.Tensor, method: str, alpha: float) -> torch.Tensor:
    """Return ``2 * threshold(x, method, alpha) - 1``.

    This rescales the threshold output affinely, so a threshold value of 0
    becomes -1 and a value of 1 becomes +1.
    """
    spikes = threshold(x, method, alpha)
    return 2 * spikes - 1