# Source code for norse.torch.functional.lift
"""
A module for lifting neuron activation functions in time.
Similar to the :mod:`.lift` module.
"""
import torch
class _Lifted:
    """
    Helper class for the :func:`lift` function to allow for pickling.
    Used in distributed execution, like PyTorch Lightning.

    An instance wraps a single-step activation function and, when called,
    applies it once per entry along the first dimension of the input,
    threading the state returned by one step into the next.
    """

    def __init__(self, activation, p=None):
        """
        Parameters:
            activation: Callable invoked once per step as
                ``activation(step_input, **kwargs)``; it must return an
                ``(output, state)`` tuple.
            p (Any): Optional parameter object. When not ``None`` it is
                forwarded to ``activation`` as the ``p`` keyword argument
                unless the caller supplies its own ``p``.
        """
        self.activation = activation
        self.p = p

    def __call__(self, input_tensor, **kwargs):
        """
        Applies the wrapped activation to every step of ``input_tensor``.

        Parameters:
            input_tensor: Iterable whose entries are the per-step inputs
                (for a tensor, iteration runs over its first dimension).
            **kwargs: Extra keyword arguments handed to the activation on
                every step. A ``state`` entry is used as the initial state
                and a ``p`` entry overrides the stored parameter object.

        Returns:
            A tuple ``(outputs, state)`` where ``outputs`` is the
            ``torch.stack`` of all per-step outputs and ``state`` is the
            state returned by the last step.
        """
        # A caller-supplied ``p`` always wins over the stored default.
        if self.p is not None:
            kwargs.setdefault("p", self.p)
        # ``state`` is threaded explicitly through the loop, so it must not
        # remain in the keyword arguments forwarded on each step.
        state = kwargs.pop("state", None)

        outputs = []
        for step_input in input_tensor:
            if state is None:
                # No state yet: omit the keyword entirely (presumably so the
                # activation falls back to its own default initial state).
                out, state = self.activation(step_input, **kwargs)
            else:
                out, state = self.activation(step_input, state=state, **kwargs)
            outputs.append(out)
        # NOTE: torch.stack is given an empty list when input_tensor has no
        # steps; that case is not handled specially here.
        return torch.stack(outputs), state
# [docs]
def lift(activation, p=None):
    """
    Lifts a single-step activation function into the temporal domain.

    The result is a callable that can be used like the original activation
    function, except that its input is treated as a sequence over time: the
    activation is evaluated once for every entry along the first dimension
    of the input, and the state returned by each step is fed into the next.

    Parameters:
        activation (Callable[[torch.Tensor, Any, Any], Tuple[torch.Tensor, Any]]):
            The activation function that takes an input tensor, an optional
            state, and an optional parameter object and returns a tuple of
            (spiking output, neuron state).
        p (Any): An optional parameter object to hand to the activation function.

    Returns:
        A callable that, when applied, evaluates the activation function N
        times, where N is the size of the outer (temporal) dimension, and
        returns a tuple of (stacked outputs, final state). The stacked
        outputs form a tensor of shape (time, ...).
    """
    lifted_activation = _Lifted(activation, p=p)
    return lifted_activation