Source code for norse.torch.functional.adjoint.lif_mc_refrac_adjoint

from typing import Tuple

import torch

from ..lif_refrac import LIFRefracState, LIFRefracParameters
from ..lif import LIFState
from ..lif_mc_refrac import lif_mc_refrac_step
from ..heaviside import heaviside


class LIFMCRefracAdjointFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        input_tensor: torch.Tensor,
        z: torch.Tensor,
        v: torch.Tensor,
        i: torch.Tensor,
        rho: torch.Tensor,
        input_weights: torch.Tensor,
        recurrent_weights: torch.Tensor,
        g_coupling: torch.Tensor,
        p: LIFRefracParameters = LIFRefracParameters(),
        dt: float = 0.001,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        ctx.p = p
        ctx.dt = dt
        s = LIFRefracState(LIFState(z, v, i), rho)
        z_new, s_new = lif_mc_refrac_step(
            input_tensor, s, input_weights, recurrent_weights, g_coupling, p, dt
        )

        # dv before spiking
        dv_m = p.lif.tau_mem_inv * ((p.lif.v_leak - s.lif.v) + s.lif.i)
        # dv after spiking
        dv_p = p.lif.tau_mem_inv * ((p.lif.v_leak - s_new.lif.v) + s.lif.i)

        ctx.save_for_backward(
            input_tensor,
            s_new.lif.v,
            s_new.lif.z,
            dv_m,
            dv_p,
            input_weights,
            recurrent_weights,
            g_coupling,
            s_new.rho,
        )
        return z_new, s_new.lif.v, s_new.lif.i, s_new.rho
    @staticmethod
    def backward(ctx, doutput, lambda_v, lambda_i, lambda_rho):
        (
            input_tensor,
            v,
            z,
            dv_m,
            dv_p,
            input_weights,
            recurrent_weights,
            g_coupling,
            refrac_count,
        ) = ctx.saved_tensors
        p = ctx.p
        tau_syn_inv = p.lif.tau_syn_inv
        tau_mem_inv = p.lif.tau_mem_inv
        dt = ctx.dt

        # weight gradients: adjoint synaptic current against inputs and spikes
        dw_input = lambda_i.t().mm(input_tensor)
        dw_rec = lambda_i.t().mm(z)
        refrac_mask = heaviside(refrac_count)

        # update for coupling conductances
        dg_coupling = lambda_v.t().mm(v)

        # lambda_i step
        dlambda_i = tau_syn_inv * (
            (1 - refrac_mask) * lambda_v - lambda_i
        ) + torch.nn.functional.linear(lambda_v, g_coupling.t())
        lambda_i = lambda_i + dt * dlambda_i

        # lambda_v decay (gated off for refractory neurons)
        lambda_v = lambda_v - tau_mem_inv * dt * (1 - refrac_mask) * lambda_v

        # jump conditions at spike times; NaNs from dv_m == 0 are masked to zero
        output_term = z * (1 / dv_m) * doutput
        output_term[output_term != output_term] = 0.0

        jump_term = z * (dv_p / dv_m)
        jump_term[jump_term != jump_term] = 0.0

        lambda_v = (1 - z) * lambda_v + jump_term * lambda_v + output_term

        dinput = lambda_i.mm(input_weights)
        drecurrent = lambda_i.mm(recurrent_weights)
        # one gradient per forward argument (p and dt are non-tensor parameters)
        return (
            dinput,
            drecurrent,
            lambda_v,
            lambda_i,
            lambda_rho,
            dw_input,
            dw_rec,
            dg_coupling,
            None,
            None,
        )
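
# Reading backward() above as the discretised adjoint system (a sketch inferred
# from the code, not quoted from the Norse documentation): the adjoint
# variables evolve per step as
#
#   lambda_i <- lambda_i + dt * ( tau_syn_inv * ((1 - refrac) * lambda_v - lambda_i)
#                                 + lambda_v @ g_coupling )
#   lambda_v <- lambda_v - dt * tau_mem_inv * (1 - refrac) * lambda_v
#
# and at entries where a spike occurred (z == 1) lambda_v instead jumps to
#
#   lambda_v <- (dv_p / dv_m) * lambda_v + doutput / dv_m
#
# with dv_m and dv_p the membrane derivatives saved before and after the spike
# in forward(). Neurons inside the refractory period (refrac_mask == 1) skip
# the decay and the lambda_v-driven part of the lambda_i update.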
def lif_mc_refrac_adjoint(
    input: torch.Tensor,
    s: LIFRefracState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    g_coupling: torch.Tensor,
    p: LIFRefracParameters = LIFRefracParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFRefracState]:
    """Compute a single Euler step of a multi-compartment LIF neuron with an
    absolute refractory period, using the adjoint equations for the backward
    pass."""
    z, v, i, rho = LIFMCRefracAdjointFunction.apply(
        input,
        s.lif.z,
        s.lif.v,
        s.lif.i,
        s.rho,
        input_weights,
        recurrent_weights,
        g_coupling,
        p,
        dt,
    )
    return z, LIFRefracState(LIFState(z, v, i), rho)
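
# Minimal usage sketch (illustration only, not part of the Norse source). The
# tensor shapes below -- a batch of 2 samples, 3 input channels, 4 neurons --
# are assumptions chosen for the example; default LIFRefracParameters are used.
if __name__ == "__main__":
    batch_size, input_features, hidden_features = 2, 3, 4

    x = torch.randn(batch_size, input_features)
    s0 = LIFRefracState(
        LIFState(
            z=torch.zeros(batch_size, hidden_features),
            v=torch.zeros(batch_size, hidden_features, requires_grad=True),
            i=torch.zeros(batch_size, hidden_features),
        ),
        rho=torch.zeros(batch_size, hidden_features),
    )
    input_weights = 0.1 * torch.randn(hidden_features, input_features)
    recurrent_weights = 0.1 * torch.randn(hidden_features, hidden_features)
    g_coupling = 0.1 * torch.randn(hidden_features, hidden_features)

    # One forward step; gradients flow through the adjoint backward() above
    # instead of autograd tracing through lif_mc_refrac_step.
    z, s1 = lif_mc_refrac_adjoint(x, s0, input_weights, recurrent_weights, g_coupling)
    s1.lif.v.sum().backward()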