Source code for norse.torch.functional.adjoint.lif_mc_adjoint

import torch

from ..lif_mc import lif_mc_step
from ..lif import LIFState, LIFParameters

from typing import Tuple


class LIFMCAdjointFunction(torch.autograd.Function):
    """Autograd function that runs ``lif_mc_step`` in the forward pass and
    supplies a hand-written backward pass instead of relying on autograd.

    All tensors handled in ``backward`` are combined with ``.mm`` and ``.t()``,
    so they are expected to be 2-D.
    NOTE(review): presumably laid out as (batch, features) -- confirm against
    the callers.
    """

    @staticmethod
    def forward(
        ctx,
        input_tensor: torch.Tensor,
        z: torch.Tensor,
        v: torch.Tensor,
        i: torch.Tensor,
        input_weights: torch.Tensor,
        recurrent_weights: torch.Tensor,
        g_coupling: torch.Tensor,
        p: LIFParameters = LIFParameters(),
        dt: float = 0.001,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Advance the neuron state by one step and stash what ``backward`` needs.

        Parameters:
            ctx: autograd context used to carry values to ``backward``
            input_tensor (torch.Tensor): input passed on to ``lif_mc_step``
            z (torch.Tensor): ``z`` field of the current ``LIFState``
            v (torch.Tensor): ``v`` field of the current ``LIFState``
            i (torch.Tensor): ``i`` field of the current ``LIFState``
            input_weights (torch.Tensor): weights passed on to ``lif_mc_step``
            recurrent_weights (torch.Tensor): weights passed on to ``lif_mc_step``
            g_coupling (torch.Tensor): coupling passed on to ``lif_mc_step``
            p (LIFParameters): neuron parameters
            dt (float): integration time step

        Returns:
            The tuple ``(z_new, v_new, i_new)`` produced by ``lif_mc_step``.
        """
        # Non-tensor arguments cannot go through save_for_backward.
        ctx.dt = dt
        ctx.p = p

        s = LIFState(z, v, i)
        z_new, s_new = lif_mc_step(
            input_tensor, s, input_weights, recurrent_weights, g_coupling, p, dt
        )

        # dv before spiking
        dv_m = p.tau_mem_inv * ((p.v_leak - s.v) + s.i)
        # dv after spiking (new voltage, but still the pre-step current)
        dv_p = p.tau_mem_inv * ((p.v_leak - s_new.v) + s.i)

        ctx.save_for_backward(
            input_tensor,
            s_new.v,
            z_new,
            dv_m,
            dv_p,
            input_weights,
            recurrent_weights,
            g_coupling,
        )
        return z_new, s_new.v, s_new.i

    @staticmethod
    def backward(ctx, doutput, lambda_v, lambda_i):
        """Compute gradients for every argument of ``forward``.

        Parameters:
            ctx: autograd context populated by ``forward``
            doutput: incoming gradient for the ``z_new`` output
            lambda_v: incoming gradient for the ``v_new`` output
            lambda_i: incoming gradient for the ``i_new`` output

        Returns:
            A 9-tuple aligned with the ``forward`` arguments
            ``(input_tensor, z, v, i, input_weights, recurrent_weights,
            g_coupling, p, dt)``; the last two entries are ``None`` because
            ``p`` and ``dt`` are not differentiated.
        """
        (
            input_tensor,
            v,
            z,
            dv_m,
            dv_p,
            input_weights,
            recurrent_weights,
            g_coupling,
        ) = ctx.saved_tensors
        p = ctx.p
        dt = ctx.dt

        # Weight gradients are formed from the incoming (not yet updated)
        # lambda_i / lambda_v.
        dw_input = lambda_i.t().mm(input_tensor)
        dw_rec = lambda_i.t().mm(z)

        # update for coupling
        dg_coupling = lambda_v.t().mm(v)

        # lambda_i step
        # torch.nn.functional.linear is the public API for x @ W^T; the
        # top-level name ``torch.linear`` is not part of the documented API.
        dlambda_i = p.tau_syn_inv * (
            lambda_v - lambda_i
        ) + torch.nn.functional.linear(lambda_v, g_coupling.t())
        lambda_i = lambda_i + dt * dlambda_i

        # lambda_v decay
        lambda_v = lambda_v - p.tau_mem_inv * dt * lambda_v

        # Where z == 0 and dv_m == 0 the products below evaluate to
        # 0 * inf = NaN; those entries are reset to zero.
        output_term = z * (1 / dv_m) * (doutput)
        output_term[torch.isnan(output_term)] = 0.0

        jump_term = z * (dv_p / dv_m)
        jump_term[torch.isnan(jump_term)] = 0.0

        # Entries with z == 0 keep the decayed lambda_v; entries with z == 1
        # get it rescaled by the jump term plus the output contribution.
        lambda_v = (1 - z) * lambda_v + jump_term * lambda_v + output_term

        dinput = lambda_i.mm(input_weights)
        drecurrent = lambda_i.mm(recurrent_weights)

        return (
            dinput,
            drecurrent,
            lambda_v,
            lambda_i,
            dw_input,
            dw_rec,
            dg_coupling,
            None,
            None,
        )
def lif_mc_adjoint_step(
    input: torch.Tensor,
    s: LIFState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    g_coupling: torch.Tensor,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFState]:
    """Run one step through ``LIFMCAdjointFunction`` and repackage the result.

    The state ``s`` is unpacked into its ``z``, ``v`` and ``i`` fields before
    being handed to ``LIFMCAdjointFunction.apply``, so gradients flow through
    that function's custom ``backward``.

    Parameters:
        input (torch.Tensor): input for this step
        s (LIFState): current state; its ``z``, ``v`` and ``i`` fields are read
        input_weights (torch.Tensor): forwarded unchanged
        recurrent_weights (torch.Tensor): forwarded unchanged
        g_coupling (torch.Tensor): forwarded unchanged
        p (LIFParameters): neuron parameters
        dt (float): integration time step

    Returns:
        A pair ``(z, state)`` where ``state`` is a new ``LIFState`` built from
        the three tensors returned by the autograd function.
    """
    spikes, voltage, current = LIFMCAdjointFunction.apply(
        input,
        s.z,
        s.v,
        s.i,
        input_weights,
        recurrent_weights,
        g_coupling,
        p,
        dt,
    )
    next_state = LIFState(spikes, voltage, current)
    return spikes, next_state