Source code for norse.torch.functional.lif_correlation
from typing import NamedTuple, Tuple
import torch
import torch.jit
from norse.torch.functional.lif import LIFState, LIFParameters, lif_step
from norse.torch.functional.correlation_sensor import (
CorrelationSensorState,
CorrelationSensorParameters,
correlation_sensor_step,
)
class LIFCorrelationState(NamedTuple):
    """State of a LIF layer bundled with two correlation sensors.

    Parameters:
        lif_state (LIFState): state of the LIF neurons; it is handed to
            ``lif_step`` and replaced by the state ``lif_step`` returns in
            :func:`lif_correlation_step`.
        input_correlation_state (CorrelationSensorState): state of the sensor
            that pairs the layer input (presynaptic side) with the spikes the
            layer emits (postsynaptic side).
        recurrent_correlation_state (CorrelationSensorState): state of the
            sensor that pairs the layer's spikes from the previous step
            (presynaptic side) with the spikes it emits in the current step
            (postsynaptic side).
    """

    lif_state: LIFState
    input_correlation_state: CorrelationSensorState
    recurrent_correlation_state: CorrelationSensorState
class LIFCorrelationParameters(NamedTuple):
    """Parameters consumed by :func:`lif_correlation_step`.

    Parameters:
        lif_parameters (LIFParameters): parameters forwarded to ``lif_step``.
        input_correlation_parameters (CorrelationSensorParameters): parameters
            forwarded to ``correlation_sensor_step`` for the input sensor.
        recurrent_correlation_parameters (CorrelationSensorParameters):
            parameters forwarded to ``correlation_sensor_step`` for the
            recurrent sensor.
    """

    # Each default below is evaluated once, when the class is defined, so every
    # LIFCorrelationParameters that does not override a field shares the same
    # default sub-parameter instance.
    lif_parameters: LIFParameters = LIFParameters()
    input_correlation_parameters: CorrelationSensorParameters = (
        CorrelationSensorParameters()
    )
    recurrent_correlation_parameters: CorrelationSensorParameters = (
        CorrelationSensorParameters()
    )
def lif_correlation_step(
    input_tensor: torch.Tensor,
    state: LIFCorrelationState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LIFCorrelationParameters = LIFCorrelationParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFCorrelationState]:
    """Advance a LIF layer by one step and update its two correlation sensors.

    The neuron update is delegated to ``lif_step``. Afterwards two sensors are
    advanced with ``correlation_sensor_step``, both using the spikes returned
    by ``lif_step`` as their postsynaptic signal:

    * the input sensor uses ``input_tensor`` as its presynaptic signal;
    * the recurrent sensor uses the spikes stored in the *previous* LIF state
      (``state.lif_state.z``) as its presynaptic signal.

    Parameters:
        input_tensor (torch.Tensor): input for this step, forwarded to
            ``lif_step`` and used as the input sensor's presynaptic signal.
        state (LIFCorrelationState): LIF and sensor states from the previous
            step.
        input_weights (torch.Tensor): forwarded to ``lif_step``.
        recurrent_weights (torch.Tensor): forwarded to ``lif_step``.
        p (LIFCorrelationParameters): parameters for the LIF neurons and for
            both sensors.
        dt (float): time step, forwarded to ``lif_step`` and to both sensor
            updates.

    Returns:
        A tuple ``(z, new_state)``: the spikes returned by ``lif_step`` and a
        :class:`LIFCorrelationState` holding all three updated sub-states.
    """
    spikes, lif_state_next = lif_step(
        input_tensor,
        state.lif_state,
        input_weights,
        recurrent_weights,
        p.lif_parameters,
        dt,
    )

    # Input sensor: presynaptic side is the external input of this step.
    input_sensor_next = correlation_sensor_step(
        state=state.input_correlation_state,
        p=p.input_correlation_parameters,
        z_pre=input_tensor,
        z_post=spikes,
        dt=dt,
    )

    # Recurrent sensor: presynaptic side is read from the *old* LIF state,
    # i.e. the layer's spikes from the previous step, not ``lif_state_next``.
    recurrent_sensor_next = correlation_sensor_step(
        state=state.recurrent_correlation_state,
        p=p.recurrent_correlation_parameters,
        z_pre=state.lif_state.z,
        z_post=spikes,
        dt=dt,
    )

    new_state = LIFCorrelationState(
        lif_state=lif_state_next,
        input_correlation_state=input_sensor_next,
        recurrent_correlation_state=recurrent_sensor_next,
    )
    return spikes, new_state