# Source code for norse.torch.functional.correlation_sensor

import torch
import torch.jit

from typing import NamedTuple

from .heaviside import heaviside


@torch.jit.script
def pre_mask(weights, z):
    """Computes the mask produced by the pre-synaptic spikes on
    the synapse array.

    The mask has the shape of ``weights`` and holds the value of ``z``
    repeated along the last axis of ``weights``.

    Parameters:
        weights: Tensor defining the shape of the synapse array. Only its
            shape (through ``zeros_like``) is used, its values are ignored.
            Axes 1 and 2 are transposed, so it needs at least three axes.
        z: Pre-synaptic spikes. Assumed to be laid out as
            (batch, input_features) when ``weights`` is
            (batch, input_features, hidden_features) -- TODO confirm
            against callers.

    Returns:
        A tensor shaped like ``weights`` containing the broadcast spikes.
    """
    zeros = torch.zeros_like(weights)
    if (
        weights.dim() == 3
        and z.dim() == 2
        and z.size(0) == weights.size(0)
        and z.size(1) == weights.size(1)
    ):
        # One spike value per (batch, pre-synaptic neuron): repeat it along
        # the last axis. The plain transpose-and-add below would align the
        # batch axis of ``z`` with the last axis of ``weights``, which only
        # works for a batch of one.
        return zeros + z.unsqueeze(2)  # pragma: no cover
    # Any other layout keeps the original broadcasting behaviour.
    return torch.transpose(
        torch.transpose(zeros, 1, 2) + z, 1, 2
    )  # pragma: no cover


@torch.jit.script
def post_mask(weights, z):
    """Computes the mask produced by post-synaptic spikes on
    the synapse array.

    The mask has the shape of ``weights`` and holds the value of ``z``
    repeated along axis 1 of ``weights``.

    Parameters:
        weights: Tensor defining the shape of the synapse array. Only its
            shape (through ``zeros_like``) is used, its values are ignored.
        z: Post-synaptic spikes. Assumed to be laid out as
            (batch, hidden_features) when ``weights`` is
            (batch, input_features, hidden_features) -- TODO confirm
            against callers.

    Returns:
        A tensor shaped like ``weights`` containing the broadcast spikes.
    """
    zeros = torch.zeros_like(weights)
    if (
        weights.dim() == 3
        and z.dim() == 2
        and z.size(0) == weights.size(0)
        and z.size(1) == weights.size(2)
    ):
        # One spike value per (batch, post-synaptic neuron): repeat it along
        # axis 1. A plain ``zeros + z`` would align the batch axis of ``z``
        # with axis 1 of ``weights``, which only works for a batch of one.
        return zeros + z.unsqueeze(1)  # pragma: no cover
    # Any other layout keeps the original broadcasting behaviour.
    return zeros + z  # pragma: no cover


@torch.jit.script
def post_pre_update(post_pre, post_spike_mask, pre_spike_mask):
    """Computes which synapses in the synapse array should be updated.

    Parameters:
        post_pre: Current post/pre indicator of every synapse.
        post_spike_mask: Mask of post-synaptic spikes, added to ``post_pre``.
        pre_spike_mask: Mask of pre-synaptic spikes, subtracted afterwards.

    Returns:
        ``heaviside`` applied element-wise to
        ``post_pre + post_spike_mask - pre_spike_mask``.
    """
    # NOTE(review): heaviside is presumably a 0/1 step function; its
    # behaviour at exactly zero is defined in the heaviside module.
    drive = post_pre + post_spike_mask - pre_spike_mask
    return heaviside(drive)  # pragma: no cover


class CorrelationSensorParameters(NamedTuple):
    """Parameters of the correlation sensor.

    Parameters:
        eta_p (torch.Tensor): Amount added to the correlation trace, scaled
            by the pre-synaptic spike mask.
        eta_m (torch.Tensor): Amount added to the anti-correlation trace,
            scaled by the post-synaptic spike mask.
        tau_ac_inv (torch.Tensor): Inverse time constant of the
            anti-correlation trace decay.
        tau_c_inv (torch.Tensor): Inverse time constant of the correlation
            trace decay.
    """

    eta_p: torch.Tensor = torch.as_tensor(1.0)
    eta_m: torch.Tensor = torch.as_tensor(1.0)
    tau_ac_inv: torch.Tensor = torch.as_tensor(1.0 / 100e-3)
    tau_c_inv: torch.Tensor = torch.as_tensor(1.0 / 100e-3)
class CorrelationSensorState(NamedTuple):
    """State of the correlation sensor.

    All three tensors describe the synapse array and are combined
    element-wise by ``correlation_sensor_step``.
    ``correlation_based_update`` unpacks the shape of ``correlation_trace``
    as ``(_, input_features, hidden_features)``.

    Parameters:
        post_pre (torch.Tensor): Per-synapse indicator that selects which of
            the two traces decays during a step.
        correlation_trace (torch.Tensor): Trace incremented by pre-synaptic
            spikes.
        anti_correlation_trace (torch.Tensor): Trace incremented by
            post-synaptic spikes.
    """

    post_pre: torch.Tensor
    correlation_trace: torch.Tensor
    anti_correlation_trace: torch.Tensor
def correlation_sensor_step(
    z_pre: torch.Tensor,
    z_post: torch.Tensor,
    state: CorrelationSensorState,
    p: CorrelationSensorParameters = CorrelationSensorParameters(),
    dt: float = 0.001,
) -> CorrelationSensorState:
    """Euler integration step of an idealized version of the correlation sensor
    as it is present on the BrainScaleS 2 chips.

    Parameters:
        z_pre (torch.Tensor): Pre-synaptic spikes of the current time step.
        z_post (torch.Tensor): Post-synaptic spikes of the current time step.
        state (CorrelationSensorState): State of the sensor before the step.
        p (CorrelationSensorParameters): Parameters of the sensor.
        dt (float): Integration time step.

    Returns:
        A new CorrelationSensorState; the tensors in ``state`` are not
        modified in place.
    """
    # Exponential leak of the correlation trace, weighted by (1 - post_pre).
    dcorrelation_trace = dt * p.tau_c_inv * (-state.correlation_trace)
    correlation_trace_decayed = (
        state.correlation_trace + (1 - state.post_pre) * dcorrelation_trace
    )

    # Exponential leak of the anti-correlation trace, weighted by post_pre.
    danti_correlation_trace = dt * p.tau_ac_inv * (-state.anti_correlation_trace)
    anti_correlation_trace_decayed = (
        state.anti_correlation_trace + state.post_pre * danti_correlation_trace
    )

    # compute the pre and post masks based on the current spikes
    pre_spike_mask = pre_mask(state.post_pre, z_pre)
    post_spike_mask = post_mask(state.post_pre, z_post)

    # The indicator is updated from its previous value and both masks.
    post_pre_new = post_pre_update(state.post_pre, post_spike_mask, pre_spike_mask)

    # Spikes increment the (already decayed) traces.
    correlation_trace_new = correlation_trace_decayed + (p.eta_p * pre_spike_mask)
    anti_correlation_trace_new = (
        anti_correlation_trace_decayed + p.eta_m * post_spike_mask
    )

    return CorrelationSensorState(
        post_pre=post_pre_new,
        correlation_trace=correlation_trace_new,
        anti_correlation_trace=anti_correlation_trace_new,
    )
def correlation_based_update(
    ts: int,
    linear_update: torch.nn.Module,
    weights: torch.Tensor,
    correlation_state: CorrelationSensorState,
    learning_rate: float,
    ts_frequency: int,
) -> torch.Tensor:
    """Apply a correlation-driven weight update every ``ts_frequency`` steps.

    When ``ts`` is a multiple of ``ts_frequency`` the two traces are
    flattened, concatenated and passed through ``linear_update``. The
    detached result is reshaped to ``(hidden_features, input_features)``,
    scaled by ``learning_rate`` and added to ``weights``. Both traces are
    then reset to zero **in place**, so the caller's ``correlation_state``
    is modified. On all other time steps ``weights`` is returned untouched
    and the state is left alone.

    Parameters:
        ts (int): Current time step index.
        linear_update (torch.nn.Module): Module mapping the concatenated,
            flattened traces to a tensor with
            ``hidden_features * input_features`` elements.
        weights (torch.Tensor): Current weights; must be broadcast-compatible
            with ``(hidden_features, input_features)``.
        correlation_state (CorrelationSensorState): Sensor state whose
            ``correlation_trace`` has exactly three axes, unpacked as
            ``(_, input_features, hidden_features)``.
        learning_rate (float): Scale applied to the proposed update.
        ts_frequency (int): Update period in time steps; must be non-zero.

    Returns:
        The updated weights (a new tensor) on update steps, otherwise the
        ``weights`` argument itself.
    """
    if ts % ts_frequency != 0:
        return weights

    _, input_features, hidden_features = correlation_state.correlation_trace.shape

    # proposed weight update
    dw = torch.cat(
        (
            correlation_state.correlation_trace.flatten(),
            correlation_state.anti_correlation_trace.flatten(),
        )
    )
    # detach: the proposal is applied as a plain value, no gradient flows
    # back into linear_update through the returned weights.
    dw = linear_update(dw).detach()
    weights = weights + learning_rate * torch.reshape(
        dw, (hidden_features, input_features)
    )

    # reset correlation traces
    correlation_state.correlation_trace.zero_()
    correlation_state.anti_correlation_trace.zero_()
    return weights