Source code for norse.torch.module.test.test_lif_correlation

import torch
import numpy as np

from norse.torch.module.leaky_integrator import LILinearCell
from norse.torch.module.lif_correlation import LIFCorrelation
from norse.torch.functional.correlation_sensor import correlation_based_update


def test_lif_correlation_training():
    def generate_random_data(
        seq_length,
        batch_size,
        input_features,
        device="cpu",
        dtype=torch.float,
        dt=0.001,
    ):
        # Poisson-like spike trains: each input fires with probability
        # freq * dt per time step (5 Hz at dt = 1 ms).
        freq = 5
        prob = freq * dt
        mask = torch.rand(
            (seq_length, batch_size, input_features), device=device, dtype=dtype
        )
        x_data = torch.zeros(
            (seq_length, batch_size, input_features),
            device=device,
            dtype=dtype,
            requires_grad=False,
        )
        x_data[mask < prob] = 1.0
        # Random binary class labels for the NLL loss below.
        y_data = torch.tensor(1 * (np.random.rand(batch_size) < 0.5), device=device)
        return x_data, y_data

    seq_length = 50
    batch_size = 1
    input_features = 10
    hidden_features = 8
    output_features = 2
    device = "cpu"

    x, y_data = generate_random_data(
        seq_length=seq_length,
        batch_size=batch_size,
        input_features=input_features,
        device=device,
    )

    # Weights are kept as explicit tensors (not module parameters) so the
    # correlation-based updates below can replace them between time steps.
    input_weights = (
        torch.randn((input_features, hidden_features), device=device).float().t()
    )
    recurrent_weights = torch.randn(
        (hidden_features, hidden_features), device=device
    ).float()

    lif_correlation = LIFCorrelation(input_features, hidden_features)
    out = LILinearCell(hidden_features, output_features).to(device)
    log_softmax_fn = torch.nn.LogSoftmax(dim=1)
    loss_fn = torch.nn.NLLLoss()

    # Learned update networks: each consumes flattened correlation/weight
    # information, hence the 2 * (in * out) input and (in * out) output sizes.
    linear_update = torch.nn.Linear(
        2 * input_features * hidden_features, input_features * hidden_features
    )
    rec_linear_update = torch.nn.Linear(
        2 * hidden_features * hidden_features, hidden_features * hidden_features
    )

    optimizer = torch.optim.Adam(
        list(linear_update.parameters())
        + [input_weights, recurrent_weights]
        + list(out.parameters()),
        lr=1e-1,
    )

    loss_hist = []
    num_episodes = 3

    for e in range(num_episodes):
        s1 = None
        so = None
        voltages = torch.zeros(seq_length, batch_size, output_features, device=device)
        hidden_voltages = torch.zeros(
            seq_length, batch_size, hidden_features, device=device
        )
        hidden_currents = torch.zeros(
            seq_length, batch_size, hidden_features, device=device
        )
        optimizer.zero_grad()

        for ts in range(seq_length):
            z1, s1 = lif_correlation(
                x[ts, :, :],
                input_weights=input_weights,
                recurrent_weights=recurrent_weights,
                state=s1,
            )
            # Apply the correlation-based plasticity rule to both weight
            # matrices; the trailing arguments set the learning rate (0.01)
            # and the update interval in time steps (10).
            input_weights = correlation_based_update(
                ts,
                linear_update,
                input_weights.detach(),
                s1.input_correlation_state,
                0.01,
                10,
            )
            recurrent_weights = correlation_based_update(
                ts,
                rec_linear_update,
                recurrent_weights.detach(),
                s1.recurrent_correlation_state,
                0.01,
                10,
            )
            vo, so = out(z1, so)
            # Hidden traces are recorded for inspection only.
            hidden_voltages[ts, :, :] = s1.lif_state.v.detach()
            hidden_currents[ts, :, :] = s1.lif_state.i.detach()
            voltages[ts, :, :] = vo

        # Classify from the maximum output voltage over the sequence.
        m, _ = torch.max(voltages, dim=0)
        log_p_y = log_softmax_fn(m)
        loss_val = loss_fn(log_p_y, y_data.long())
        loss_val.backward()
        optimizer.step()
        loss_hist.append(loss_val.item())
        print(f"{e}/{num_episodes}: {loss_val.item()}")
    # assert loss_hist[0] > loss_hist[1]
    # assert loss_hist[1] > loss_hist[2]
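
A minimal sketch of a single forward/update step, reusing only the calls exercised by the test above. The shapes, the 0.01 learning rate, and the 10-step update interval are taken from the test; the snippet itself is illustrative and not part of the norse test suite:

import torch

from norse.torch.module.lif_correlation import LIFCorrelation
from norse.torch.functional.correlation_sensor import correlation_based_update

input_features, hidden_features = 10, 8
lif_correlation = LIFCorrelation(input_features, hidden_features)
linear_update = torch.nn.Linear(
    2 * input_features * hidden_features, input_features * hidden_features
)

x = (torch.rand(1, input_features) < 0.05).float()  # one time step of spikes
input_weights = torch.randn(hidden_features, input_features)
recurrent_weights = torch.randn(hidden_features, hidden_features)

# state=None initializes the LIF and correlation-sensor states, as in the test.
z, s = lif_correlation(
    x, input_weights=input_weights, recurrent_weights=recurrent_weights, state=None
)
input_weights = correlation_based_update(
    0, linear_update, input_weights.detach(), s.input_correlation_state, 0.01, 10
)

Passing the weights explicitly, rather than storing them inside the module, is what lets correlation_based_update return a fresh tensor each step while the optimizer still trains the update network itself.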