Source code for norse.torch.functional.lif_mc_refrac
from typing import Tuple
import torch
from norse.torch.functional.lif_refrac import LIFRefracState, LIFRefracFeedForwardState
from norse.torch.functional.lif_refrac import LIFRefracParameters
from norse.torch.functional.lif import LIFState, LIFFeedForwardState
from norse.torch.functional.threshold import threshold
def lif_mc_refrac_step(
    input_tensor: torch.Tensor,
    state: LIFRefracState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    g_coupling: torch.Tensor,
    p: LIFRefracParameters = LIFRefracParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFRefracState]:
    """Advance a recurrent, coupled LIF population with refractoriness by one step.

    The step reads the previous state (``state.lif.z``, ``state.lif.v``,
    ``state.lif.i`` and ``state.rho``) and performs, in order:

    1. a refractory mask is obtained from ``threshold(state.rho, ...)``;
    2. the membrane voltage is updated by a leak/current term (scaled by
       ``dt * p.lif.tau_mem_inv`` and by ``1 - mask``) plus a coupling term
       ``linear(state.lif.v, g_coupling)``;
    3. the synaptic current decays by ``dt * p.lif.tau_syn_inv``;
    4. spikes are obtained from ``threshold(v - p.lif.v_th, ...)`` and the
       voltage of spiking entries is replaced by ``p.lif.v_reset``;
    5. the decayed current receives the projected input and the projected
       previous spikes;
    6. the refractory counter is set to ``p.rho_reset`` where a spike
       occurred and otherwise to ``relu(state.rho - mask)``.

    The current jumps of this step do not enter the voltage computed in this
    same step; they are only stored in the returned state.

    Parameters:
        input_tensor (torch.Tensor): input projected through ``input_weights``
            with ``torch.nn.functional.linear``
        state (LIFRefracState): previous state; ``state.lif`` must expose
            ``z``, ``v`` and ``i`` and ``state.rho`` the refractory counter
        input_weights (torch.Tensor): weight matrix applied to ``input_tensor``
        recurrent_weights (torch.Tensor): weight matrix applied to the previous
            spikes ``state.lif.z``
        g_coupling (torch.Tensor): weight matrix applied to the previous
            voltage ``state.lif.v``
        p (LIFRefracParameters): neuron parameters; ``p.lif`` supplies
            ``tau_mem_inv``, ``tau_syn_inv``, ``v_leak``, ``v_th``,
            ``v_reset``, ``method`` and ``alpha``; ``p.rho_reset`` the value
            the refractory counter takes after a spike
        dt (float): time step scaling the membrane and synaptic updates

    Returns:
        Tuple[torch.Tensor, LIFRefracState]: the new spikes and the new state
        ``LIFRefracState(LIFState(z_new, v_new, i_new), rho_new)``.
    """
    lif_p = p.lif
    linear = torch.nn.functional.linear
    v_prev, i_prev = state.lif.v, state.lif.i

    # Refractory mask derived from the refractory counter.
    # presumably `threshold` yields a 0/1 tensor -- TODO confirm against
    # norse.torch.functional.threshold
    refrac_mask = threshold(state.rho, lif_p.method, lif_p.alpha)
    non_refractory = 1 - refrac_mask

    # Membrane update: only the leak/current term is gated by the mask.
    # NOTE(review): the coupling term is added as-is, i.e. it is neither
    # multiplied by dt nor by the refractory mask -- confirm this is intended.
    leak_and_current = (lif_p.v_leak - v_prev) + i_prev
    coupling = linear(v_prev, g_coupling)
    delta_v = non_refractory * dt * lif_p.tau_mem_inv * leak_and_current + coupling
    v_decayed = v_prev + delta_v

    # Synaptic current decay.
    delta_i = -dt * lif_p.tau_syn_inv * i_prev
    i_decayed = i_prev + delta_i

    # Spike generation from the distance to the firing threshold.
    z_new = threshold(v_decayed - lif_p.v_th, lif_p.method, lif_p.alpha)
    silent = 1 - z_new

    # Spiking entries take v_reset, the others keep the decayed voltage.
    v_new = silent * v_decayed + z_new * lif_p.v_reset

    # Current jumps from the input and from the previous step's spikes.
    input_jump = linear(input_tensor, input_weights)
    recurrent_jump = linear(state.lif.z, recurrent_weights)
    i_new = i_decayed + input_jump + recurrent_jump

    # Refractory counter: reset on spike, otherwise decrement by the mask
    # and clamp at zero.
    rho_remaining = torch.nn.functional.relu(state.rho - refrac_mask)
    rho_new = silent * rho_remaining + z_new * p.rho_reset

    new_state = LIFRefracState(LIFState(z_new, v_new, i_new), rho_new)
    return z_new, new_state
[docs]
def lif_mc_refrac_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LIFRefracFeedForwardState,
    g_coupling: torch.Tensor,
    p: LIFRefracParameters = LIFRefracParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFRefracFeedForwardState]:
    """Advance a feed-forward, coupled LIF population with refractoriness by one step.

    The step reads the previous state (``state.lif.v``, ``state.lif.i`` and
    ``state.rho``) and performs, in order:

    1. a refractory mask is obtained from ``threshold(state.rho, ...)``;
    2. the membrane voltage is updated by a leak/current term (scaled by
       ``dt * p.lif.tau_mem_inv`` and by ``1 - mask``) plus a coupling term
       ``linear(state.lif.v, g_coupling)``;
    3. the synaptic current decays by ``dt * p.lif.tau_syn_inv``;
    4. spikes are obtained from ``threshold(v - p.lif.v_th, ...)`` and the
       voltage of spiking entries is replaced by ``p.lif.v_reset``;
    5. ``input_tensor`` is added directly (no weight matrix) to the decayed
       current;
    6. the refractory counter is set to ``p.rho_reset`` where a spike
       occurred and otherwise to ``relu(state.rho - mask)``.

    The input of this step does not enter the voltage computed in this same
    step; it is only stored in the returned current.

    Parameters:
        input_tensor (torch.Tensor): input added to the synaptic current
        state (LIFRefracFeedForwardState): previous state; ``state.lif`` must
            expose ``v`` and ``i`` and ``state.rho`` the refractory counter
        g_coupling (torch.Tensor): weight matrix applied to the previous
            voltage ``state.lif.v``
        p (LIFRefracParameters): neuron parameters; ``p.lif`` supplies
            ``tau_mem_inv``, ``tau_syn_inv``, ``v_leak``, ``v_th``,
            ``v_reset``, ``method`` and ``alpha``; ``p.rho_reset`` the value
            the refractory counter takes after a spike
        dt (float): time step scaling the membrane and synaptic updates

    Returns:
        Tuple[torch.Tensor, LIFRefracFeedForwardState]: the new spikes and the
        new state
        ``LIFRefracFeedForwardState(LIFFeedForwardState(v_new, i_new), rho_new)``.
    """
    lif_p = p.lif
    v_prev, i_prev = state.lif.v, state.lif.i

    # Refractory mask derived from the refractory counter.
    # presumably `threshold` yields a 0/1 tensor -- TODO confirm against
    # norse.torch.functional.threshold
    refrac_mask = threshold(state.rho, lif_p.method, lif_p.alpha)
    non_refractory = 1 - refrac_mask

    # Membrane update: only the leak/current term is gated by the mask.
    # NOTE(review): the coupling term is added as-is, i.e. it is neither
    # multiplied by dt nor by the refractory mask -- confirm this is intended.
    leak_and_current = (lif_p.v_leak - v_prev) + i_prev
    coupling = torch.nn.functional.linear(v_prev, g_coupling)
    delta_v = non_refractory * dt * lif_p.tau_mem_inv * leak_and_current + coupling
    v_decayed = v_prev + delta_v

    # Synaptic current decay.
    delta_i = -dt * lif_p.tau_syn_inv * i_prev
    i_decayed = i_prev + delta_i

    # Spike generation from the distance to the firing threshold.
    z_new = threshold(v_decayed - lif_p.v_th, lif_p.method, lif_p.alpha)
    silent = 1 - z_new

    # Spiking entries take v_reset, the others keep the decayed voltage.
    v_new = silent * v_decayed + z_new * lif_p.v_reset

    # Current jump: the input is injected without any projection.
    i_new = i_decayed + input_tensor

    # Refractory counter: reset on spike, otherwise decrement by the mask
    # and clamp at zero.
    rho_remaining = torch.nn.functional.relu(state.rho - refrac_mask)
    rho_new = silent * rho_remaining + z_new * p.rho_reset

    new_state = LIFRefracFeedForwardState(LIFFeedForwardState(v_new, i_new), rho_new)
    return z_new, new_state