# Source code for norse.torch.functional.threshold

import torch
import torch.jit
import numpy as np

from norse.torch.functional.heaviside import heaviside
from norse.torch.functional.superspike import super_fn

# Module-level alias for the SuperSpike surrogate imported above; ``threshold``
# dispatches to it under this name for ``method == "super"``.
superspike_fn = super_fn

class HeaviErfc(torch.autograd.Function):
    r"""Approximation of the heaviside step function as

    .. math::
        h(x,k) = \frac{1}{2} + \frac{1}{2} \text{erf}(k x)

    where erf is the error function.

    The forward pass emits the exact step ``heaviside(x)``; only the backward
    pass uses the smooth approximation. The surrogate gradient is the
    derivative of :math:`\text{erf}` with respect to its argument,
    :math:`\frac{2}{\sqrt{\pi}} e^{-(k x)^2}`, i.e. it omits the
    :math:`\frac{k}{2}` factor of the exact derivative of :math:`h`.
    """

    @staticmethod
    def forward(ctx, x, k):
        # Keep x and k around for the surrogate gradient in backward.
        ctx.save_for_backward(x)
        ctx.k = k
        # Exact step in the forward pass; the erf approximation only shapes
        # the gradient.
        return heaviside(x)

    @staticmethod
    def backward(ctx, dy):
        (x,) = ctx.saved_tensors
        # d/du erf(u) = 2/sqrt(pi) * exp(-u^2), evaluated at u = k * x.
        derfc = (2 * torch.exp(-(ctx.k * x).pow(2))) / (torch.as_tensor(np.pi).sqrt())
        # k is a non-tensor hyperparameter, so it receives no gradient.
        return derfc * dy, None

@torch.jit.ignore
def heavi_erfc_fn(x: torch.Tensor, k: float) -> torch.Tensor:
    """Exact heaviside step forward, erf-based surrogate gradient backward.

    Args:
        x: input tensor.
        k: sharpness of the erf surrogate used in the backward pass.

    Returns:
        The result of ``HeaviErfc.apply(x, k)``.
    """
    return HeaviErfc.apply(x, k)

class HeaviTanh(torch.autograd.Function):
    r"""Approximation of the heaviside step function as

    .. math::
        h(x,k) = \frac{1}{2} + \frac{1}{2} \text{tanh}(k x)

    The forward pass emits the exact step ``heaviside(x)``; only the backward
    pass uses the smooth approximation. The surrogate gradient is
    :math:`1 - \tanh^2(k x)`, the derivative of :math:`\tanh` with respect to
    its argument, i.e. it omits the :math:`\frac{k}{2}` factor of the exact
    derivative of :math:`h`.
    """

    @staticmethod
    def forward(ctx, x, k):
        # Keep x and k around for the surrogate gradient in backward.
        ctx.save_for_backward(x)
        ctx.k = k
        # Exact step in the forward pass; tanh only shapes the gradient.
        return heaviside(x)

    @staticmethod
    def backward(ctx, dy):
        (x,) = ctx.saved_tensors
        dtanh = 1 - (x * ctx.k).tanh().pow(2)
        # k is a non-tensor hyperparameter, so it receives no gradient.
        return dy * dtanh, None

@torch.jit.ignore
def heavi_tanh_fn(x: torch.Tensor, k: float) -> torch.Tensor:
    """Exact heaviside step forward, tanh-based surrogate gradient backward.

    Args:
        x: input tensor.
        k: sharpness of the tanh surrogate used in the backward pass.

    Returns:
        The result of ``HeaviTanh.apply(x, k)``.
    """
    return HeaviTanh.apply(x, k)

class Logistic(torch.autograd.Function):
    r"""Probabilistic approximation of the heaviside step function as

    .. math::
        z \sim \text{Bernoulli}\left(\frac{1}{2} + \frac{1}{2} \text{tanh}(k x)\right)

    The forward pass draws an elementwise binary sample with that spike
    probability; the backward pass uses the surrogate gradient
    :math:`1 - \tanh^2(k x)`.
    """

    @staticmethod
    def forward(ctx, x, k):
        ctx.k = k
        ctx.save_for_backward(x)
        # Spike probability in [0, 1]; sampled elementwise below.
        p = 0.5 + 0.5 * torch.tanh(k * x)
        return torch.distributions.Bernoulli(probs=p).sample()

    @staticmethod
    def backward(ctx, dy):
        (x,) = ctx.saved_tensors
        dtanh = 1 - (x * ctx.k).tanh().pow(2)
        # k is a non-tensor hyperparameter, so it receives no gradient.
        return dy * dtanh, None

@torch.jit.ignore
def logistic_fn(x: torch.Tensor, k: float) -> torch.Tensor:
    """Stochastic (Bernoulli) spiking with a tanh-based surrogate gradient.

    Args:
        x: input tensor.
        k: sharpness of the tanh used for the spike probability and gradient.

    Returns:
        The result of ``Logistic.apply(x, k)``.
    """
    return Logistic.apply(x, k)

class HeaviCirc(torch.autograd.Function):
    r"""Approximation of the heaviside step function as

    .. math::
        h(x,\alpha) = \frac{1}{2} + \frac{1}{2} \frac{x}{(x^2 + \alpha^2)^{1/2}}

    The forward pass emits the exact step ``heaviside(x)``; only the backward
    pass uses the smooth approximation. The surrogate gradient is

    .. math::
        \frac{\alpha^3}{(x^2 + \alpha^2)^{3/2}},

    which is :math:`2\alpha` times the exact derivative of :math:`h`, so it
    peaks at 1 for :math:`x = 0` when :math:`\alpha > 0`.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep x and alpha around for the surrogate gradient in backward.
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        # Exact step in the forward pass; the approximation only shapes the
        # gradient.
        return heaviside(x)

    @staticmethod
    def backward(ctx, dy):
        (x,) = ctx.saved_tensors
        alpha = ctx.alpha
        # Closed form of
        #   2*alpha * (1 / (2*sqrt(alpha^2 + x^2)) - x^2 / (2*(alpha^2 + x^2)^1.5)).
        # Evaluating that difference term by term cancels catastrophically when
        # |x| >> alpha (it can even come out negative in float32); this form is
        # algebraically identical and non-negative for alpha > 0.
        grad = alpha**3 / (x.pow(2) + alpha**2).pow(1.5)
        # alpha is a non-tensor hyperparameter, so it receives no gradient.
        return dy * grad, None

@torch.jit.ignore
def heavi_circ_fn(x: torch.Tensor, k: float) -> torch.Tensor:
    """Exact heaviside step forward, algebraic ("circ") surrogate gradient.

    Args:
        x: input tensor.
        k: width parameter, forwarded as ``alpha`` to :class:`HeaviCirc`.

    Returns:
        The result of ``HeaviCirc.apply(x, k)``.
    """
    return HeaviCirc.apply(x, k)

class CircDist(torch.autograd.Function):
    r"""Probabilistic approximation of the heaviside step function as

    .. math::
        z \sim \text{Bernoulli}\left(\frac{1}{2} + \frac{1}{2}
        \frac{x}{\sqrt{x^2 + \alpha^2}}\right)

    The forward pass draws an elementwise binary sample with that spike
    probability; the backward pass uses the surrogate gradient
    :math:`\frac{\alpha^3}{(x^2 + \alpha^2)^{3/2}}`.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        # Spike probability in [0, 1]; sampled elementwise below.
        p = 0.5 + 0.5 * (x / (x.pow(2) + alpha**2).sqrt())
        return torch.distributions.Bernoulli(probs=p).sample()

    @staticmethod
    def backward(ctx, dy):
        (x,) = ctx.saved_tensors
        alpha = ctx.alpha
        # Closed form of
        #   2*alpha * (1 / (2*sqrt(alpha^2 + x^2)) - x^2 / (2*(alpha^2 + x^2)^1.5)),
        # which avoids catastrophic cancellation when |x| >> alpha.
        grad = alpha**3 / (x.pow(2) + alpha**2).pow(1.5)
        # alpha is a non-tensor hyperparameter, so it receives no gradient.
        return dy * grad, None

@torch.jit.ignore
def circ_dist_fn(x: torch.Tensor, k: float) -> torch.Tensor:
    """Stochastic (Bernoulli) spiking with an algebraic ("circ") surrogate.

    Args:
        x: input tensor.
        k: width parameter, forwarded as ``alpha`` to :class:`CircDist`.

    Returns:
        The result of ``CircDist.apply(x, k)``.
    """
    return CircDist.apply(x, k)

class Triangle(torch.autograd.Function):
    r"""Triangular/piecewise linear surrogate gradient as in

    S.K. Esser et al., **"Convolutional networks for fast, energy-efficient neuromorphic computing"**,
    Proceedings of the National Academy of Sciences 113(41), 11441-11446, (2016),
    `doi:10.1073/pnas.1604850113 <https://www.pnas.org/content/113/41/11441.short>`_

    G. Bellec et al., **"A solution to the learning dilemma for recurrent networks of spiking neurons"**,
    Nature Communications 11(1), 3625, (2020),
    `doi:10.1038/s41467-020-17236-y <https://www.nature.com/articles/s41467-020-17236-y>`_

    The forward pass emits the exact step ``heaviside(x)``; the backward pass
    uses the surrogate gradient :math:`\alpha \max(0, 1 - |x|)`.
    """

    @staticmethod
    @torch.jit.ignore
    def forward(ctx, x: torch.Tensor, alpha: float) -> torch.Tensor:
        # Keep x and alpha around for the surrogate gradient in backward.
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        return heaviside(x)

    @staticmethod
    @torch.jit.ignore
    def backward(ctx, grad_output: torch.Tensor):
        (x,) = ctx.saved_tensors
        alpha = ctx.alpha
        # NOTE(review): the backward body was missing in this copy; this is the
        # standard triangular surrogate alpha * max(0, 1 - |x|) from the cited
        # papers — confirm against upstream.
        grad_input = grad_output * alpha * torch.relu(1 - x.abs())
        # alpha is a non-tensor hyperparameter, so it receives no gradient.
        return grad_input, None

@torch.jit.ignore
def triangle_fn(x: torch.Tensor, alpha: float = 0.3) -> torch.Tensor:
    """Exact heaviside step forward, triangular surrogate gradient backward.

    Args:
        x: input tensor.
        alpha: peak height of the triangular surrogate gradient.

    Returns:
        The result of ``Triangle.apply(x, alpha)``.
    """
    return Triangle.apply(x, alpha)

def threshold(x: torch.Tensor, method: str, alpha: float) -> torch.Tensor:
    """Apply the threshold function selected by ``method`` to ``x``.

    Args:
        x: input tensor.
        method: one of ``"heaviside"``, ``"super"``, ``"triangle"``,
            ``"tanh"``, ``"circ"`` or ``"heavi_erfc"``.
        alpha: hyperparameter forwarded to the selected surrogate. Ignored
            for ``"heaviside"``; converted to a tensor for ``"super"``.

    Returns:
        The output of the selected threshold function.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    # An explicit if/elif chain (rather than a dict of callables) keeps this
    # function simple for TorchScript.
    if method == "heaviside":
        return heaviside(x)
    elif method == "super":
        return superspike_fn(x, torch.as_tensor(alpha))
    elif method == "triangle":
        return triangle_fn(x, alpha)
    elif method == "tanh":
        return heavi_tanh_fn(x, alpha)
    elif method == "circ":
        return heavi_circ_fn(x, alpha)
    elif method == "heavi_erfc":
        return heavi_erfc_fn(x, alpha)
    else:
        raise ValueError(
            f"Attempted to apply threshold function {method}, but no such "
            + "function exists. We currently support heaviside, super, "
            + "tanh, triangle, circ, and heavi_erfc."
        )

def sign(x: torch.Tensor, method: str, alpha: float) -> torch.Tensor:
    """Rescale the output of ``threshold`` with the affine map t -> 2t - 1.

    An output of 0 becomes -1 and an output of 1 stays 1; ``method`` and
    ``alpha`` are passed straight through to ``threshold``.
    """
    spikes = threshold(x, method, alpha)
    return spikes * 2 - 1