import torch

)

from norse.torch.module.snn import SNN, SNNCell, SNNRecurrent, SNNRecurrentCell

r"""Computes a single Euler-integration step of a feed-forward adaptive exponential
LIF neuron model *without* recurrence, adapted from the adaptive exponential
integrate-and-fire (AdEx) model of Brette and Gerstner (2005).
It takes as input the input current as generated by an arbitrary torch
module or function. More specifically it implements one integration step
of the following ODE

.. math::
\begin{align*}
\dot{v} &= 1/\tau_{\text{mem}} \left(v_{\text{leak}} - v + i + \Delta_T \exp\left({{v - v_{\text{th}}} \over {\Delta_T}}\right)\right) \\
\dot{i} &= -1/\tau_{\text{syn}} i \\
\dot{a} &= 1/\tau_{\text{ada}} \left( a_{\text{current}} (v - v_{\text{leak}}) - a \right)
\end{align*}

together with the jump condition

.. math::
z = \Theta(v - v_{\text{th}})

and transition equations

.. math::
i = i + i_{\text{in}}

where :math:`i_{\text{in}}` is meant to be the result of applying
an arbitrary pytorch module (such as a convolution) to input spikes.

Parameters:
p (LIFAdExParameters): Parameters of the LIFAdEx neuron model.
dt (float): Time step to use.

Examples:
>>> batch_size = 16
>>> data = torch.randn(batch_size, 20, 30)
>>> output, s0 = lif_ex(data)
"""

# Forward the state initializer and the neuron parameters to the parent cell.
# NOTE(review): no enclosing "def __init__(...)" header is visible for this
# call, so where `p` and `kwargs` are bound cannot be confirmed here.
# NOTE(review): `self.initial_state` is passed positionally, whereas the other
# constructors in this file pass it as `state_fallback=` -- confirm which
# parent parameter it is meant to bind to.
super().__init__(
self.initial_state,
p=p,
**kwargs,
)

def initial_state(self, x: torch.Tensor) -> LIFAdExFeedForwardState:
    """Create the initial feed-forward AdEx state for the input ``x``.

    Args:
        x (torch.Tensor): Input tensor; its shape, device and dtype are used
            for the synaptic current ``i`` and the adaptation term ``a``.

    Returns:
        LIFAdExFeedForwardState: State with ``v`` set to the detached leak
        potential ``self.p.v_leak`` and ``i``/``a`` set to zeros shaped like
        ``x``.
    """
    # ``v`` is not expanded to ``x.shape``: it keeps the shape of
    # ``self.p.v_leak``. ``detach()`` shares storage with ``self.p.v_leak``,
    # so any in-place write to ``state.v`` would also change the parameter.
    state = LIFAdExFeedForwardState(
        v=self.p.v_leak.detach(),
        i=torch.zeros(*x.shape, device=x.device, dtype=x.dtype),
        a=torch.zeros(*x.shape, device=x.device, dtype=x.dtype),
    )
    return state

r"""Computes a single Euler-integration step of an adaptive exponential
LIF neuron model *with* recurrence, adapted from the adaptive exponential
integrate-and-fire (AdEx) model of Brette and Gerstner (2005).
More specifically it implements one integration step of the following ODE

.. math::
\begin{align*}
\dot{v} &= 1/\tau_{\text{mem}} \left(v_{\text{leak}} - v + i + \Delta_T \exp\left({{v - v_{\text{th}}} \over {\Delta_T}}\right)\right) \\
\dot{i} &= -1/\tau_{\text{syn}} i \\
\dot{a} &= 1/\tau_{\text{ada}} \left( a_{\text{current}} (v - v_{\text{leak}}) - a \right)
\end{align*}

together with the jump condition

.. math::
z = \Theta(v - v_{\text{th}})

and transition equations

.. math::
\begin{align*}
v &= (1-z) v + z v_{\text{reset}} \\
i &= i + w_{\text{input}} z_{\text{in}} \\
i &= i + w_{\text{rec}} z_{\text{rec}}
\end{align*}

where :math:`z_{\text{rec}}` and :math:`z_{\text{in}}` are the
recurrent and input spikes respectively.

Examples:
>>> batch_size = 16
>>> input = torch.randn(batch_size, 10)
>>> output, s0 = lif(input)

Parameters:
input_size (int): Size of the input. Also known as the number of input features.
hidden_size (int): Size of the hidden state. Also known as the number of hidden neurons.
p (LIFAdExParameters): Parameters of the LIF neuron model.
input_weights (torch.Tensor): Weights used for input tensors. Defaults to a random
matrix normalized to the number of hidden neurons.
recurrent_weights (torch.Tensor): Weights used for recurrent tensors. Defaults to a random
matrix normalized to the number of hidden neurons.
autapses (bool): Allow self-connections in the recurrence? Defaults to False. Will also
remove autapses in custom recurrent weights, if set above.
dt (float): Time step to use.

"""

def __init__(
    self,
    input_size: int,
    hidden_size: int,
    p=None,
    **kwargs,
):
    """Create the recurrent AdEx cell.

    Args:
        input_size (int): Number of input features.
        hidden_size (int): Number of hidden neurons.
        p (LIFAdExParameters, optional): Neuron parameters. Forwarded to the
            parent constructor as ``p=`` only when given.
        **kwargs: Further keyword arguments forwarded unchanged to the parent
            constructor (e.g. ``input_weights``, ``recurrent_weights``,
            ``autapses``, ``dt``).
    """
    # ``p`` must be bound before it can be forwarded. When it is omitted, the
    # parent class decides how a missing ``p`` is handled.
    if p is not None:
        kwargs["p"] = p
    super().__init__(
        state_fallback=self.initial_state,
        input_size=input_size,
        hidden_size=hidden_size,
        **kwargs,
    )

def initial_state(self, input_tensor: torch.Tensor) -> LIFAdExState:
    """Create the initial recurrent AdEx state for ``input_tensor``.

    The state tensors keep every dimension of ``input_tensor`` except the
    last, which is replaced by ``self.hidden_size``. They are created on the
    input's device with the input's dtype.

    Args:
        input_tensor (torch.Tensor): Input tensor whose leading dimensions
            (all but the last) shape the state.

    Returns:
        LIFAdExState: Spikes ``z``, current ``i`` and adaptation ``a`` set to
        zero, and membrane voltage ``v`` filled with ``self.p.v_leak``.
    """
    dims = (*input_tensor.shape[:-1], self.hidden_size)
    factory_kwargs = {"device": input_tensor.device, "dtype": input_tensor.dtype}
    state = LIFAdExState(
        z=torch.zeros(*dims, **factory_kwargs),
        # torch.full takes a single fill value, so this assumes
        # ``self.p.v_leak`` holds one element -- NOTE(review): confirm.
        v=torch.full(dims, self.p.v_leak.detach(), **factory_kwargs),
        i=torch.zeros(*dims, **factory_kwargs),
        a=torch.zeros(*dims, **factory_kwargs),
    )
    return state

r"""A neuron layer that wraps a feed-forward LIFAdExCell in time such
that the layer keeps track of temporal sequences of spikes.
After application, the layer returns a tuple containing
(spikes from all timesteps, state from the last timestep).

Example:
>>> data = torch.zeros(10, 5, 2) # 10 timesteps, 5 batches, 2 neurons
>>> l(data) # Returns tuple of (Tensor(10, 5, 2), LIFAdExFeedForwardState)

Parameters:
p (LIFAdExParameters): The neuron parameters as a torch Module, which allows the module
to configure neuron parameters as optimizable.
dt (float): Time step to use in integration. Defaults to 0.001.
"""

# Forward the state initializer and the neuron parameters to the parent layer.
# NOTE(review): no enclosing "def __init__(...)" header is visible for this
# call, so where `p` and `kwargs` are bound cannot be confirmed here.
super().__init__(
state_fallback=self.initial_state,
p=p,
**kwargs,
)

def initial_state(self, input_tensor: torch.Tensor) -> LIFAdExFeedForwardState:
    """Create the initial feed-forward AdEx state for a time-major input.

    The first dimension of ``input_tensor`` is assumed to be time. The state
    covers one time step, so it takes the remaining dimensions.

    Args:
        input_tensor (torch.Tensor): Input sequence, time first.

    Returns:
        LIFAdExFeedForwardState: State with ``v`` filled with
        ``self.p.v_leak`` and ``i``/``a`` set to zero. All tensors use the
        input's device and dtype.
    """
    dims = input_tensor.shape[1:]  # drop the leading time dimension
    factory_kwargs = {"device": input_tensor.device, "dtype": input_tensor.dtype}
    state = LIFAdExFeedForwardState(
        # torch.full takes a single fill value, so this assumes
        # ``self.p.v_leak`` holds one element -- NOTE(review): confirm.
        v=torch.full(dims, self.p.v_leak.detach(), **factory_kwargs),
        i=torch.zeros(dims, **factory_kwargs),
        a=torch.zeros(dims, **factory_kwargs),
    )
    return state

r"""A neuron layer that wraps a recurrent LIFAdExRecurrentCell in time (*with*
recurrence) such that the layer keeps track of temporal sequences of spikes.
After application, the layer returns a tuple containing
(spikes from all timesteps, state from the last timestep).

Example:
>>> data = torch.zeros(10, 5, 2) # 10 timesteps, 5 batches, 2 neurons
>>> l(data) # Returns tuple of (Tensor(10, 5, 4), LIFAdExState)

Parameters:
input_size (int): The number of input neurons
hidden_size (int): The number of hidden neurons
p (LIFAdExParameters): The neuron parameters as a torch Module, which allows the module
to configure neuron parameters as optimizable.
input_weights (torch.Tensor): Weights used for input tensors. Defaults to a random
matrix normalized to the number of hidden neurons.
recurrent_weights (torch.Tensor): Weights used for recurrent tensors. Defaults to a random
matrix normalized to the number of hidden neurons.
autapses (bool): Allow self-connections in the recurrence? Defaults to False. Will also
remove autapses in custom recurrent weights, if set above.
dt (float): Time step to use in integration. Defaults to 0.001.
"""

def __init__(
    self,
    input_size: int,
    hidden_size: int,
    p=None,
    **kwargs,
):
    """Create the recurrent AdEx layer.

    Args:
        input_size (int): Number of input neurons.
        hidden_size (int): Number of hidden neurons.
        p (LIFAdExParameters, optional): Neuron parameters. Forwarded to the
            parent constructor as ``p=`` only when given.
        **kwargs: Further keyword arguments forwarded unchanged to the parent
            constructor (e.g. ``input_weights``, ``recurrent_weights``,
            ``autapses``, ``dt``).
    """
    # ``p`` must be bound before it can be forwarded. When it is omitted, the
    # parent class decides how a missing ``p`` is handled.
    if p is not None:
        kwargs["p"] = p
    super().__init__(
        state_fallback=self.initial_state,
        input_size=input_size,
        hidden_size=hidden_size,
        **kwargs,
    )

def initial_state(self, input_tensor: torch.Tensor) -> LIFAdExState:
# Build the initial recurrent AdEx state for a time-major input sequence.
# The state shape is input_tensor's shape without the first (time) and last
# (presumably input-feature) dimensions, followed by self.hidden_size.
# Tensors are created on the input's device with the input's dtype.
# z, i and a start at zero; v is filled with the detached self.p.v_leak.
dims = (  # Remove first dimension (time)
*input_tensor.shape[1:-1],
self.hidden_size,
)
# NOTE(review): the keyword arguments below are not preceded by a visible
# constructor call. The sibling initial_state methods suggest a line such as
# "state = LIFAdExState(" is missing here. No `return` is visible before this
# chunk ends either -- confirm both against the full file.
z=torch.zeros(
*dims,
device=input_tensor.device,
dtype=input_tensor.dtype,
),
# NOTE(review): torch.full takes a single fill value, so this assumes
# self.p.v_leak holds one element -- confirm.
v=torch.full(
dims,
self.p.v_leak.detach(),
device=input_tensor.device,
dtype=input_tensor.dtype,
),
i=torch.zeros(
*dims,
device=input_tensor.device,
dtype=input_tensor.dtype,
),
a=torch.zeros(
*dims,
device=input_tensor.device,
dtype=input_tensor.dtype,
),
)