# Source code for norse.dataset.memory

from typing import Optional

import torch
import torch.utils.data

from norse.torch.functional.encode import poisson_encode


class MemoryStoreRecallDataset(torch.utils.data.Dataset):
    """
    A memory dataset that generates random one-hot 2-bit data patterns together
    with a 2-bit command pattern (store and recall), i.e. four input channels per
    command period, each encoded by a population of Poisson-spiking neurons.

    Note that you can control the randomness by setting
    `a manual seed in PyTorch <https://pytorch.org/docs/stable/generated/torch.manual_seed.html>`_
    or by passing ``seed``.

    Inspired by Bellec et al.: `Biologically inspired alternatives to backpropagation
    through time for learning in recurrent neural nets <https://arxiv.org/abs/1901.09049>`_.

    Arguments:
        samples (int): Number of samples in the dataset.
        seq_length (int): Number of timesteps to simulate per command. Defaults to 100.
        seq_periods (int): Number of commands in one sample. Must be at least 2 so
            that both the store range and the recall range are non-empty.
            Defaults to 12.
        seq_repetitions (int): Number of times one store/recall pair occurs in a
            single sample. Defaults to 4.
        population_size (int): Number of neurons encoding each input channel
            (data and command). Defaults to 5.
        poisson_rate (int): Poisson rate in Hz for the data channels; the command
            channels fire at ``poisson_rate // 2``. Defaults to 100.
        dt (float): Timestep for the dataset. Defaults to 0.001 (1000Hz).
        seed (Optional[int]): Optional seed for the random generator. When given,
            PyTorch's global RNG is seeded (via ``torch.manual_seed``) *before*
            the store/recall positions are drawn, so datasets created with the
            same seed share the same store/recall positions. Samples themselves
            are generated on every access, so their content also depends on the
            RNG state at access time.

    Raises:
        ValueError: If ``seq_periods`` is smaller than 2.
    """

    def __init__(
        self,
        samples: int,
        seq_length: int = 100,
        seq_periods: int = 12,
        seq_repetitions: int = 4,
        population_size: int = 5,
        poisson_rate: int = 100,
        dt: float = 0.001,
        seed: Optional[int] = None,
    ):
        if seq_periods < 2:
            # Store positions come from [0, seq_periods // 2) and recall positions
            # from [seq_periods // 2, seq_periods); both ranges must be non-empty.
            raise ValueError(f"seq_periods must be at least 2, got {seq_periods}")
        self.samples = samples
        self.seq_length = seq_length
        self.seq_periods = seq_periods
        self.seq_repetitions = seq_repetitions
        self.population_size = population_size
        self.poisson_rate = poisson_rate
        self.dt = dt
        # Seed *before* drawing the store/recall positions; otherwise ``seed``
        # would not make them reproducible. torch.manual_seed seeds the global
        # RNG and returns the default generator.
        self.generator = None if seed is None else torch.manual_seed(seed)
        # Stores always happen in the first half of the periods and recalls in
        # the second half, so every recall follows its store.
        self.store_indices = torch.randint(
            low=0,
            high=seq_periods // 2,
            size=(samples, seq_repetitions),
            generator=self.generator,
        )
        self.recall_indices = torch.randint(
            low=seq_periods // 2,
            high=seq_periods,
            size=(samples, seq_repetitions),
            generator=self.generator,
        )

    def __len__(self):
        return self.samples

    def _generate_sequence(self, idx, rep_idx):
        """Build one store/recall repetition for sample ``idx``.

        Returns:
            A tuple ``(encoded, label_pattern)``. ``label_pattern`` is a uint8
            tensor of shape ``(seq_periods, 2)`` that is zero everywhere except
            at the recall period, where it holds the data word that was shown at
            the store period.
        """
        # One random one-hot 2-bit data word per period.
        data_pattern = torch.stack(
            [
                torch.randperm(2, generator=self.generator)
                for _ in range(self.seq_periods)
            ]
        ).byte()
        store_index = self.store_indices[idx][rep_idx]
        recall_index = self.recall_indices[idx][rep_idx]

        store_pattern = torch.zeros((self.seq_periods, 1), dtype=torch.uint8)
        recall_pattern = store_pattern.clone()
        label_pattern = torch.zeros((self.seq_periods, 2), dtype=torch.uint8)
        store_pattern[store_index] = 1
        recall_pattern[recall_index] = 1
        # The target at recall time is the data word seen at store time.
        label_pattern[recall_index] = data_pattern[store_index]
        # Blank the data during the recall period so it has to be remembered.
        data_pattern[recall_index] = 0

        def encode_pattern(pattern, hz):
            # Each channel is repeated population_size times along dim 1 before
            # Poisson encoding.
            return poisson_encode(
                pattern.repeat_interleave(self.population_size, dim=1),
                seq_length=self.seq_length,
                f_max=hz,
                dt=self.dt,
            )

        encoded_data_pattern = encode_pattern(data_pattern, self.poisson_rate)
        encoded_command_pattern = encode_pattern(
            torch.cat((store_pattern, recall_pattern), dim=1), self.poisson_rate // 2
        )
        # NOTE(review): the dim=2 concat and the period chunking below assume
        # poisson_encode returns (seq_length, *input.shape), i.e. here
        # (seq_length, seq_periods, channels) — confirm against norse.
        encoded_pattern = torch.cat(
            (encoded_data_pattern, encoded_command_pattern), dim=2
        )
        # Lay the periods out one after another along the time axis, then drop
        # only the now-singleton period axis (a bare squeeze() could also drop
        # other size-1 axes).
        encoded = torch.cat(encoded_pattern.chunk(self.seq_periods, dim=1)).squeeze(1)
        return encoded, label_pattern

    def __getitem__(self, idx):
        """Return ``(spikes, labels)`` for sample ``idx``.

        Both tensors are the per-repetition results of ``_generate_sequence``
        concatenated along dim 0. ``labels`` therefore has shape
        ``(seq_repetitions * seq_periods, 2)``.
        """
        repetitions = [
            self._generate_sequence(idx, i) for i in range(self.seq_repetitions)
        ]
        return (
            torch.cat([encoded for encoded, _ in repetitions]),
            torch.cat([label for _, label in repetitions]),
        )