Source code for norse.torch.functional.stdp_sensor

from typing import NamedTuple

import torch


class STDPSensorParameters(NamedTuple):
    """Parameters of an STDP sensor as it is used for event driven plasticity rules.

    Parameters:
        eta_p (torch.Tensor): amplitude added to the presynaptic
            (correlation) trace ``a_pre`` per presynaptic spike
        eta_m (torch.Tensor): amplitude added to the postsynaptic
            (anti-correlation) trace ``a_post`` per postsynaptic spike
        tau_ac_inv (torch.Tensor): inverse time constant of the
            anti-correlation sensor (default ``1 / 100e-3``, i.e. a 100 ms
            time constant if time is measured in seconds)
        tau_c_inv (torch.Tensor): inverse time constant of the correlation
            sensor (default ``1 / 100e-3``)

    Note:
        The default tensors are created once, when the class is defined, and
        are shared by every instance constructed with defaults. Do not modify
        them in place.
    """

    eta_p: torch.Tensor = torch.as_tensor(1.0)
    eta_m: torch.Tensor = torch.as_tensor(1.0)
    tau_ac_inv: torch.Tensor = torch.as_tensor(1.0 / 100e-3)
    tau_c_inv: torch.Tensor = torch.as_tensor(1.0 / 100e-3)
class STDPSensorState(NamedTuple):
    """Traces carried by an event-driven STDP sensor from one step to the next.

    Attributes:
        a_pre (torch.Tensor): trace driven by presynaptic spikes.
        a_post (torch.Tensor): trace driven by postsynaptic spikes.
    """

    # Field order is part of the public interface (positional construction
    # and tuple unpacking), so a_pre must stay first.
    a_pre: torch.Tensor
    a_post: torch.Tensor
def stdp_sensor_step(
    z_pre: torch.Tensor,
    z_post: torch.Tensor,
    state: STDPSensorState,
    p: STDPSensorParameters = STDPSensorParameters(),
    dt: float = 0.001,
) -> STDPSensorState:
    """Event driven STDP rule: advance the sensor traces by one time step.

    Each trace first decays by one forward-Euler step of
    ``da/dt = -tau_inv * a``. It is then incremented by the spikes of this
    step, scaled by the matching amplitude.

    Parameters:
        z_pre (torch.Tensor): pre-synaptic spikes
        z_post (torch.Tensor): post-synaptic spikes
        state (STDPSensorState): state of the STDP sensor
        p (STDPSensorParameters): STDP sensor parameters
        dt (float): integration time step

    Returns:
        STDPSensorState: the updated ``(a_pre, a_post)`` traces.
    """
    # Correlation sensor: the presynaptic trace decays with tau_c_inv and is
    # incremented by eta_p (the correlation amplitude) below.
    da_pre = p.tau_c_inv * (-state.a_pre)
    a_pre_decayed = state.a_pre + dt * da_pre

    # Anti-correlation sensor: the postsynaptic trace is incremented by eta_m
    # (the anti-correlation amplitude), so it must decay with tau_ac_inv.
    # Using tau_c_inv here would silently ignore tau_ac_inv. With the default
    # parameters both constants are equal, so default results are unaffected.
    da_post = p.tau_ac_inv * (-state.a_post)
    a_post_decayed = state.a_post + dt * da_post

    a_pre_new = a_pre_decayed + z_pre * p.eta_p
    a_post_new = a_post_decayed + z_post * p.eta_m

    return STDPSensorState(a_pre_new, a_post_new)