Source code for norse.torch.functional.superspike

import torch
from norse.torch.functional.heaviside import heaviside


class SuperSpike(torch.autograd.Function):
    r"""SuperSpike surrogate gradient as described in Section 3.3.2 of

    F. Zenke, S. Ganguli, **"SuperSpike: Supervised Learning in Multilayer Spiking Neural Networks"**,
    Neural Computation 30, 1514–1541 (2018),
    `doi:10.1162/neco_a_01086 <https://www.mitpressjournals.org/doi/full/10.1162/neco_a_01086>`_
    """

    @staticmethod
    def forward(input_tensor: torch.Tensor, alpha: float) -> torch.Tensor:
        return heaviside(input_tensor)

    @staticmethod
    def setup_context(ctx, inputs, output):
        input_tensor, alpha = inputs
        ctx.alpha = alpha
        ctx.save_for_backward(input_tensor)

    @staticmethod
    def backward(ctx, grad_output):
        inp, = ctx.saved_tensors
        grad = None
        if ctx.needs_input_grad[0]:
            grad = grad_output / (torch.abs(inp) + 1.0).pow(
                2
            )  # section 3.3.2 (beta -> alpha)
        return grad, None

[docs] def super_fn(x: torch.Tensor, alpha: float = 100.0) -> torch.Tensor: return SuperSpike.apply(x, alpha)