import torch
from typing import NamedTuple, Tuple
from norse.torch.functional.threshold import threshold
[docs]
class IzhikevichParameters(NamedTuple):
    """Parametrization of an Izhikevich neuron.

    The step functions integrate ``dv = sq * v**2 + mn * v + bias - u + I`` and
    ``du = a * (b * v - u)``, scaling each Euler update by ``tau_inv * dt``.
    After a spike, ``v`` is reset to ``c`` and ``u`` is incremented by ``d``.

    Parameters:
        a (float): time scale of the recovery variable u. Smaller values result
            in slower recovery, in 1/ms
        b (float): sensitivity of the recovery variable u to the subthreshold
            fluctuations of the membrane potential v. Greater values couple v
            and u more strongly, resulting in possible subthreshold oscillations
            and low-threshold spiking dynamics
        c (float): after-spike reset value of the membrane potential in mV
        d (float): after-spike increment of the recovery variable u, caused by
            slow high-threshold Na+ and K+ conductances, in mV
        sq (float): coefficient of the v**2 term in dv (default 0.04)
        mn (float): coefficient of the v term in dv, in 1/ms (default 5)
        bias (float): constant term in dv, in mV/ms (default 140)
        v_th (float): spike threshold potential in mV (default 30)
        tau_inv (float): inverse time constant in 1/ms; each Euler update is
            scaled by ``tau_inv * dt`` (default 250)
        method (str): method passed to ``threshold`` to determine spikes
            (relevant for surrogate gradients; default "super")
        alpha (float): hyper parameter passed to ``threshold`` for the surrogate
            gradient computation (default 100.0)
    """

    a: float
    b: float
    c: float
    d: float
    sq: float = 0.04
    mn: float = 5
    bias: float = 140
    v_th: float = 30
    tau_inv: float = 250
    method: str = "super"
    alpha: float = 100.0
[docs]
class IzhikevichState(NamedTuple):
    """State of an Izhikevich neuron.

    Fields:
        v (torch.Tensor): membrane potential
        u (torch.Tensor): membrane recovery variable
    """

    v: torch.Tensor
    u: torch.Tensor
class IzhikevichRecurrentState(NamedTuple):
    """State of a recurrently connected Izhikevich neuron.

    Fields:
        z (torch.Tensor): spikes emitted in the previous step; fed back through
            the recurrent weights on the next step
        v (torch.Tensor): membrane potential
        u (torch.Tensor): membrane recovery variable
    """

    z: torch.Tensor
    v: torch.Tensor
    u: torch.Tensor
[docs]
class IzhikevichSpikingBehavior(NamedTuple):
    """Spiking behavior of an Izhikevich neuron: its parameters plus an initial state.

    Fields:
        p (IzhikevichParameters): parameters of the Izhikevich neuron model
        s (IzhikevichState): initial state of the Izhikevich neuron model
    """

    p: IzhikevichParameters
    s: IzhikevichState
def create_izhikevich_spiking_behavior(
    a: float,
    b: float,
    c: float,
    d: float,
    v_rest: float,
    u_rest: float,
    tau_inv: float = 250,
) -> IzhikevichSpikingBehavior:
    """
    Create a custom Izhikevich spiking behavior (parameters plus initial state).

    The remaining model constants (``sq``, ``mn``, ``bias``, ``v_th``, ``method``
    and ``alpha``) take the defaults of :class:`IzhikevichParameters`.

    Parameters:
        a (float): time scale of the recovery variable u. Smaller values result
            in slower recovery, in 1/ms
        b (float): sensitivity of the recovery variable u to the subthreshold
            fluctuations of the membrane potential v. Greater values couple v
            and u more strongly, resulting in possible subthreshold oscillations
            and low-threshold spiking dynamics
        c (float): after-spike reset value of the membrane potential in mV
        d (float): after-spike increment of the recovery variable u, caused by
            slow high-threshold Na+ and K+ conductances, in mV
        v_rest (float): initial (resting) value of the membrane potential v in mV
        u_rest (float): value that is multiplied by ``b`` to give the initial
            recovery variable, i.e. ``u = u_rest * b``. The presets in this
            module pass the initial membrane potential here, so that ``u = b * v``.
        tau_inv (float): inverse time constant in 1/ms

    Returns:
        IzhikevichSpikingBehavior: the parameters and an initial state. The
        state's ``v`` is a scalar float tensor with ``requires_grad=True``, and
        ``u`` is a scalar float tensor equal to ``u_rest * b``.
    """
    params = IzhikevichParameters(a=a, b=b, c=c, d=d, tau_inv=tau_inv)
    initial_state = IzhikevichState(
        v=torch.tensor(float(v_rest), requires_grad=True),
        # float() mirrors how v is built. Without it, an int u_rest and an int b
        # yield an int64 tensor, because torch.tensor infers int64 from Python ints.
        u=torch.tensor(float(u_rest)) * params.b,
    )
    return IzhikevichSpikingBehavior(p=params, s=initial_state)
def _resting_behavior(
    p: IzhikevichParameters, v_init: float
) -> IzhikevichSpikingBehavior:
    """Pair ``p`` with the initial state ``v = v_init``, ``u = p.b * v_init``.

    ``v`` is a scalar float tensor created with ``requires_grad=True``; ``u`` is
    a scalar float tensor without gradient tracking.
    """
    return IzhikevichSpikingBehavior(
        p=p,
        s=IzhikevichState(
            v=torch.tensor(float(v_init), requires_grad=True),
            u=torch.tensor(float(v_init)) * p.b,
        ),
    )


# Preset spiking behaviors. Each ``<name>_p`` holds the parameters, and
# ``<name>`` pairs them with an initial state.
# NOTE(review): the values appear to follow the presets in Izhikevich (2004),
# "Which model to use for cortical spiking neurons?" Confirm against the paper
# before editing.
tonic_spiking_p = IzhikevichParameters(a=0.02, b=0.2, c=-65, d=6)
tonic_spiking = _resting_behavior(tonic_spiking_p, -70.0)

phasic_spiking_p = IzhikevichParameters(a=0.02, b=0.25, c=-65, d=6)
phasic_spiking = _resting_behavior(phasic_spiking_p, -64.0)

tonic_bursting_p = IzhikevichParameters(a=0.02, b=0.2, c=-50, d=2)
tonic_bursting = _resting_behavior(tonic_bursting_p, -70.0)

phasic_bursting_p = IzhikevichParameters(a=0.02, b=0.25, c=-55, d=0.05, tau_inv=200)
phasic_bursting = _resting_behavior(phasic_bursting_p, -64.0)

mixed_mode_p = IzhikevichParameters(a=0.02, b=0.2, c=-55, d=4, tau_inv=250)
mixed_mode = _resting_behavior(mixed_mode_p, -70.0)

spike_frequency_adaptation_p = IzhikevichParameters(
    a=0.01, b=0.2, c=-65, d=8, tau_inv=250
)
spike_frequency_adaptation = _resting_behavior(spike_frequency_adaptation_p, -70.0)

# Class 1 excitability (and the integrator below) use 4.1 * v + 108 in dv
# instead of the default 5 * v + 140.
class_1_exc_p = IzhikevichParameters(
    a=0.02, b=-0.1, c=-55, d=6, mn=4.1, bias=108, tau_inv=250
)
class_1_exc = _resting_behavior(class_1_exc_p, -60.0)

class_2_exc_p = IzhikevichParameters(a=0.2, b=0.26, c=-65, d=0, tau_inv=250)
class_2_exc = _resting_behavior(class_2_exc_p, -64.0)

spike_latency_p = IzhikevichParameters(a=0.02, b=0.2, c=-65, d=6, tau_inv=250)
spike_latency = _resting_behavior(spike_latency_p, -70.0)

subthreshold_oscillation_p = IzhikevichParameters(
    a=0.05, b=0.26, c=-60, d=0, tau_inv=250
)
subthreshold_oscillation = _resting_behavior(subthreshold_oscillation_p, -62.0)

resonator_p = IzhikevichParameters(a=0.1, b=0.26, c=-60, d=-1, tau_inv=250)
resonator = _resting_behavior(resonator_p, -62.0)

integrator_p = IzhikevichParameters(
    a=0.02, b=-0.1, c=-55, d=6, mn=4.1, bias=108, tau_inv=250
)
integrator = _resting_behavior(integrator_p, -60.0)

rebound_spike_p = IzhikevichParameters(a=0.03, b=0.25, c=-60, d=4, tau_inv=200)
rebound_spike = _resting_behavior(rebound_spike_p, -64.0)

rebound_burst_p = IzhikevichParameters(a=0.03, b=0.25, c=-52, d=0, tau_inv=200)
rebound_burst = _resting_behavior(rebound_burst_p, -64.0)

threshold_variability_p = IzhikevichParameters(a=0.03, b=0.25, c=-60, d=4, tau_inv=250)
threshold_variability = _resting_behavior(threshold_variability_p, -64.0)

bistability_p = IzhikevichParameters(a=0.1, b=0.26, c=-60, d=0, tau_inv=250)
bistability = _resting_behavior(bistability_p, -61.0)

dap_p = IzhikevichParameters(a=1.0, b=0.2, c=-60, d=-21, tau_inv=100)
dap = _resting_behavior(dap_p, -70.0)

# Unlike the other presets, accommodation starts from a fixed u = -16 rather
# than u = b * v.
accomodation_p = IzhikevichParameters(a=0.02, b=1.0, c=-55, d=4, tau_inv=500)
accomodation = IzhikevichSpikingBehavior(
    p=accomodation_p,
    s=IzhikevichState(
        v=torch.tensor(-65.0, requires_grad=True),
        # Float literal: torch.tensor(-16) would yield an int64 tensor, unlike
        # the float state of every other preset.
        u=torch.tensor(-16.0),
    ),
)
# Correctly spelled aliases. The misspelled names above are kept for backward
# compatibility.
accommodation_p = accomodation_p
accommodation = accomodation

inhibition_induced_spiking_p = IzhikevichParameters(
    a=-0.02, b=-1.0, c=-60, d=8, tau_inv=250
)
inhibition_induced_spiking = _resting_behavior(inhibition_induced_spiking_p, -63.8)

inhibition_induced_bursting_p = IzhikevichParameters(
    a=-0.026, b=-1.0, c=-45, d=-2, tau_inv=250
)
inhibition_induced_bursting = _resting_behavior(inhibition_induced_bursting_p, -63.8)
[docs]
def izhikevich_feed_forward_step(
    input_current: torch.Tensor,
    s: IzhikevichState,
    p: IzhikevichParameters,
    dt: float = 0.001,
) -> Tuple[torch.Tensor, IzhikevichState]:
    """Advance a feed-forward Izhikevich neuron by one explicit Euler step.

    Both ``v`` and ``u`` are integrated from the previous state ``s``, with each
    update scaled by ``p.tau_inv * dt``. Spikes ``z`` come from ``threshold``
    applied to ``v - p.v_th``, using ``p.method`` and ``p.alpha`` for the
    surrogate gradient. The state is then reset by blending with ``z``:
    ``v <- (1 - z) * v + z * c`` and ``u <- (1 - z) * u + z * (u + d)``.

    Parameters:
        input_current (torch.Tensor): input current added to the membrane update
        s (IzhikevichState): current neuron state
        p (IzhikevichParameters): neuron parameters
        dt (float): integration time step

    Returns:
        Tuple of the spike tensor ``z`` and the next :class:`IzhikevichState`.
    """
    step = p.tau_inv * dt
    # Both derivatives read the previous state (s.v, s.u), not the updated v.
    dv = p.sq * s.v**2 + p.mn * s.v + p.bias - s.u + input_current
    v_next = s.v + step * dv
    u_next = s.u + step * p.a * (p.b * s.v - s.u)

    z = threshold(v_next - p.v_th, p.method, p.alpha)

    # Written as blends rather than masking so gradients flow through z.
    v_next = (1 - z) * v_next + z * p.c
    u_next = (1 - z) * u_next + z * (u_next + p.d)
    return z, IzhikevichState(v_next, u_next)
def izhikevich_recurrent_step(
    input_current: torch.Tensor,
    s: IzhikevichRecurrentState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: IzhikevichParameters,
    dt: float = 0.001,
) -> Tuple[torch.Tensor, IzhikevichRecurrentState]:
    """Advance a recurrently connected Izhikevich neuron by one Euler step.

    The membrane receives two currents:
    ``linear(input_current, input_weights)`` and the previous spikes fed back
    as ``linear(s.z, recurrent_weights)``. Both use
    ``torch.nn.functional.linear``, i.e. ``x @ W.T``. Integration, spiking and
    reset are otherwise the same as in the feed-forward step: each update is
    scaled by ``p.tau_inv * dt``, spikes come from
    ``threshold(v - p.v_th, p.method, p.alpha)``, and the state is blended with
    the reset values through ``z``.

    Parameters:
        input_current (torch.Tensor): external input, projected through input_weights
        s (IzhikevichRecurrentState): current state, including the previous spikes z
        input_weights (torch.Tensor): weight matrix for the external input
        recurrent_weights (torch.Tensor): weight matrix for the fed-back spikes
        p (IzhikevichParameters): neuron parameters
        dt (float): integration time step

    Returns:
        Tuple of the spike tensor ``z`` and the next
        :class:`IzhikevichRecurrentState`, which stores ``z`` for the next step.
    """
    step = p.tau_inv * dt
    i_in = torch.nn.functional.linear(input_current, input_weights)
    i_rec = torch.nn.functional.linear(s.z, recurrent_weights)

    dv = p.sq * s.v**2 + p.mn * s.v + p.bias - s.u + i_in + i_rec
    v_next = s.v + step * dv
    u_next = s.u + step * p.a * (p.b * s.v - s.u)

    z = threshold(v_next - p.v_th, p.method, p.alpha)

    v_next = (1 - z) * v_next + z * p.c
    u_next = (1 - z) * u_next + z * (u_next + p.d)
    return z, IzhikevichRecurrentState(z, v_next, u_next)