Source code for norse.benchmark.norse_lif

import time
import torch

from norse.torch.functional.lif import (
    LIFFeedForwardState,
    LIFParametersJIT,
    _lif_feed_forward_step_jit,
)
from norse.torch.module.encode import PoissonEncoder

from .benchmark import BenchmarkParameters


[docs]class LIFBenchmark(torch.jit.ScriptModule): def __init__(self, parameters): super(LIFBenchmark, self).__init__() self.fc = torch.nn.Linear(parameters.features, parameters.features, bias=False) self.dt = parameters.dt @torch.jit.script_method def forward( self, input_spikes: torch.Tensor, p: LIFParametersJIT, s: LIFFeedForwardState ): sequence_length, batch_size, features = input_spikes.shape # spikes = torch.jit.annotate(List[Tensor], []) spikes = torch.empty( (sequence_length, batch_size, features), device=input_spikes.device ) for ts in range(sequence_length): x = self.fc(input_spikes[ts]) z, s = _lif_feed_forward_step_jit(input_tensor=x, state=s, p=p, dt=self.dt) spikes[ts] = z return spikes
[docs]def lif_feed_forward_benchmark(parameters: BenchmarkParameters): with torch.no_grad(): model = LIFBenchmark(parameters).to(parameters.device) input_spikes = PoissonEncoder(parameters.sequence_length, dt=parameters.dt)( 0.3 * torch.ones( parameters.batch_size, parameters.features, device=parameters.device ) ).contiguous() p = LIFParametersJIT( tau_syn_inv=torch.as_tensor(1.0 / 5e-3), tau_mem_inv=torch.as_tensor(1.0 / 1e-2), v_leak=torch.as_tensor(0.0), v_th=torch.as_tensor(1.0), v_reset=torch.as_tensor(0.0), method="super", alpha=torch.as_tensor(0.0), ) s = LIFFeedForwardState( v=p.v_leak, i=torch.zeros( parameters.batch_size, parameters.features, device=parameters.device, ), ) start = time.time() model(input_spikes, p, s) end = time.time() duration = end - start return duration