# Source code for norse.torch.module.lif_correlation

import torch

from typing import Optional, Tuple

from norse.torch.functional.lif_correlation import (
    LIFCorrelationState,
    LIFCorrelationParameters,
    lif_correlation_step,
)
from norse.torch.functional.lif import LIFState
from norse.torch.functional.correlation_sensor import CorrelationSensorState


class LIFCorrelation(torch.nn.Module):
    """LIF layer that also carries input and recurrent correlation-sensor state.

    The module stores no weights of its own: ``input_weights`` and
    ``recurrent_weights`` are supplied on every call and forwarded, together
    with the state, to ``lif_correlation_step``.

    Parameters:
        input_size: Number of input features; sizes the second dimension of
            the input correlation state.
        hidden_size: Number of hidden features; sizes the LIF state and the
            last dimension of both correlation states.
        p: Parameters forwarded unchanged to ``lif_correlation_step``. Its
            ``lif_parameters.v_leak`` is also used as the initial ``v``.
        dt: Time step forwarded unchanged to ``lif_correlation_step``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        p: LIFCorrelationParameters = LIFCorrelationParameters(),
        dt: float = 0.001,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.p = p
        self.dt = dt

    def _initial_state(self, input_tensor: torch.Tensor) -> LIFCorrelationState:
        """Build the all-zero starting state matching ``input_tensor``."""
        # The batch size is taken from the leading dimension of the input.
        batch_size = input_tensor.shape[0]
        hidden_features = self.hidden_size
        input_features = self.input_size

        def zeros(shape: Tuple[int, ...]) -> torch.Tensor:
            # Every buffer is created on the input's device with its dtype.
            return torch.zeros(
                shape, device=input_tensor.device, dtype=input_tensor.dtype
            )

        lif_shape = (batch_size, hidden_features)
        input_shape = (batch_size, input_features, hidden_features)
        recurrent_shape = (batch_size, hidden_features, hidden_features)

        return LIFCorrelationState(
            lif_state=LIFState(
                z=zeros(lif_shape),
                # NOTE(review): v starts as the detached v_leak tensor as-is;
                # it is not expanded to (batch, hidden) nor moved to the
                # input's device/dtype. Presumably the step function
                # broadcasts it -- confirm.
                v=self.p.lif_parameters.v_leak.detach(),
                i=zeros(lif_shape),
            ),
            input_correlation_state=CorrelationSensorState(
                post_pre=zeros(input_shape),
                # NOTE(review): only the two input traces are cast to float32
                # via .float(); post_pre and the recurrent traces keep the
                # input dtype. Kept as in the original -- confirm whether the
                # asymmetry is intentional.
                correlation_trace=zeros(input_shape).float(),
                anti_correlation_trace=zeros(input_shape).float(),
            ),
            recurrent_correlation_state=CorrelationSensorState(
                correlation_trace=zeros(recurrent_shape),
                anti_correlation_trace=zeros(recurrent_shape),
                post_pre=zeros(recurrent_shape),
            ),
        )

    def forward(
        self,
        input_tensor: torch.Tensor,
        input_weights: torch.Tensor,
        recurrent_weights: torch.Tensor,
        state: Optional[LIFCorrelationState] = None,
    ) -> Tuple[torch.Tensor, LIFCorrelationState]:
        """Run one ``lif_correlation_step`` and return its result.

        Parameters:
            input_tensor: Input for this step; its first dimension is used as
                the batch size when a fresh state has to be created.
            input_weights: Forwarded unchanged to ``lif_correlation_step``.
            recurrent_weights: Forwarded unchanged to ``lif_correlation_step``.
            state: State from the previous step. When ``None`` (the default),
                a zero-initialised state is created on the device and with
                the dtype of ``input_tensor``.

        Returns:
            Whatever ``lif_correlation_step`` returns; annotated as a tuple of
            an output tensor and the next ``LIFCorrelationState``.
        """
        if state is None:
            state = self._initial_state(input_tensor)

        return lif_correlation_step(
            input_tensor,
            state,
            input_weights,
            recurrent_weights,
            self.p,
            self.dt,
        )