Training a classifier on the event-based POKER-DVS dataset


When working with Spiking Neural Networks (SNNs), we inevitably encounter the notion of time in both the network and the data flow. The classic MNIST handwritten digits dataset consists of static images, much like snapshots in time. Deep learning has shown impressive results on such purely spatial compositions, but SNNs might be able to extract meaning from temporal features and/or do so with less power than classical networks.

An event camera such as the Dynamic Vision Sensor (DVS) is loosely inspired by the functional principle of the human retina. Such a camera can record a scene much more efficiently than a conventional camera by encoding changes in the visual scene rather than absolute illuminance values. The output is a spike train of change-detection events for each pixel. Whereas previously we had to use encoders to give static image data a temporal dimension, the POKER-DVS dataset contains recordings of poker cards shown to an event camera in rapid succession, so the data is temporal from the start.
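To make the change-detection principle concrete, here is a deliberately idealized sketch, not a model of any real sensor or library: each pixel emits an event whenever its log intensity has moved by more than a contrast threshold since that pixel's last event. The helper name `frames_to_events`, the threshold of 0.2, and the polarity encoding (1 for brighter, 0 for darker) are assumptions chosen for illustration.

```python
import numpy as np


def frames_to_events(intensity_frames, timestamps, threshold=0.2):
    """Hypothetical, idealized DVS model for illustration only.

    intensity_frames: array of shape (T, H, W) with positive intensities.
    timestamps: array of shape (T,), one timestamp per frame.
    Returns a list of (t, x, y, p) tuples, p = 1 (brighter) or 0 (darker).
    Simplification: at most one event per pixel per frame.
    """
    log_frames = np.log(intensity_frames)
    reference = log_frames[0].copy()  # log intensity at each pixel's last event
    events = []
    for t, log_frame in zip(timestamps[1:], log_frames[1:]):
        diff = log_frame - reference
        on = diff >= threshold
        off = diff <= -threshold
        for polarity, mask in ((1, on), (0, off)):
            ys, xs = np.nonzero(mask)  # row index is y, column index is x
            events.extend((int(t), int(x), int(y), polarity) for x, y in zip(xs, ys))
        # only pixels that fired update their reference level
        fired = on | off
        reference[fired] = log_frame[fired]
    return events


# A 2x2 "scene": pixel (0, 0) brightens, then returns; pixel (1, 1) darkens
intensity_frames = np.ones((3, 2, 2))
intensity_frames[1, 0, 0] = 2.0
intensity_frames[2, 1, 1] = 0.5
dvs_events = frames_to_events(intensity_frames, timestamps=np.array([0, 1000, 2000]))
print(dvs_events)  # [(1000, 0, 0, 1), (2000, 0, 0, 0), (2000, 1, 1, 0)]
```

Pixels whose brightness does not change produce no output at all, which is where the efficiency of event cameras comes from.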

Warning! This notebook uses a large dataset and can take a significant amount of time to execute.

import torch
import numpy as np
import matplotlib.pyplot as plt

We can simply install Norse through pip:

!pip install norse --quiet

For this tutorial we are going to make use of a package that handles event-based datasets called Tonic. It is based on PyTorch Vision, so you should already have most of its dependencies installed.

!pip install tonic --quiet

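Let’s start by loading the POKER-DVS dataset and defining a ToFrame transform, which bins the events of each sample into frames of 1000 µs as the sample is loaded. For now, the transform is applied only to the test set, so that we can inspect the raw events of the training set first.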

import tonic
import torchvision

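sensor_size = tonic.datasets.POKERDVS.sensor_size  # (width, height, polarities)
# Sum the events into frames, one frame per 1000 us (= 1 ms) time window
frame_transform = tonic.transforms.ToFrame(sensor_size=sensor_size, time_window=1000)

# No transform on the training set yet, so that we can look at the raw events
trainset = tonic.datasets.POKERDVS(save_to="./data", train=True)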
testset = tonic.datasets.POKERDVS(
    save_to="./data", transform=frame_transform, train=False
)
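Conceptually, a time-window frame transform such as `ToFrame(time_window=1000)` counts the events that fall into consecutive 1000 µs windows, producing one (polarity, height, width) image per window. The NumPy sketch below is our own simplified version, not Tonic's implementation, and edge handling may differ. It assumes the events are a structured array with fields `t`, `x`, `y`, `p` (timestamps in µs, polarity 0 or 1) and a sensor size given as (width, height, polarities); `events_to_frames` is a hypothetical helper.

```python
import numpy as np


def events_to_frames(events, sensor_size, time_window):
    """Hypothetical helper: bin events into frames of fixed duration.

    events: structured array with integer fields "t", "x", "y", "p".
    sensor_size: (width, height, polarities).
    Returns an array of shape (n_frames, polarities, height, width).
    """
    width, height, n_pol = sensor_size
    t = events["t"] - events["t"].min()  # time relative to the first event
    frame_idx = t // time_window         # window index of every event
    n_frames = int(frame_idx.max()) + 1
    frames = np.zeros((n_frames, n_pol, height, width), dtype=np.int64)
    # np.add.at accumulates correctly even when the same index repeats
    np.add.at(frames, (frame_idx, events["p"], events["y"], events["x"]), 1)
    return frames


# Made-up events on a 4x4 sensor with two polarities
dtype = np.dtype([("t", np.int64), ("x", np.int64), ("y", np.int64), ("p", np.int64)])
toy_events = np.array(
    [(0, 1, 2, 1), (500, 1, 2, 1), (1200, 0, 0, 0), (2999, 3, 3, 1)],
    dtype=dtype,
)
toy_frames = events_to_frames(toy_events, sensor_size=(4, 4, 2), time_window=1000)
print(toy_frames.shape)  # (3, 2, 4, 4)
```

The first two events land in the same pixel and window, so that frame entry counts 2.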

Let’s have a look at what a single sample looks like. The event camera’s output is encoded as events that have x/y coordinates, a timestamp and a polarity that indicates whether the brightness at that pixel increased or decreased. The events of one recording are provided as an array with one entry per event. Let’s look at the first example in the dataset: every entry holds one event’s timestamp, x and y coordinates, and polarity.

events = trainset[0][0]
events
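Because every event carries its own timestamp and polarity, simple statistics such as the recording duration and the balance of ON and OFF events can be read straight off the array. These numbers are useful later when choosing time constants. The sketch assumes a NumPy structured array with fields `t`, `x`, `y`, `p` and treats `p == 1` as ON (check `events.dtype` for your Tonic version); it uses a small hypothetical recording so that it runs standalone.

```python
import numpy as np

# Hypothetical recording using the assumed field layout (t in microseconds)
dtype = np.dtype([("t", np.int64), ("x", np.int64), ("y", np.int64), ("p", np.int64)])
toy_recording = np.array(
    [(10, 5, 5, 1), (250, 6, 5, 0), (900, 7, 4, 1), (4010, 5, 6, 1)],
    dtype=dtype,
)

duration_us = int(toy_recording["t"].max() - toy_recording["t"].min())
n_on = int((toy_recording["p"] == 1).sum())
n_off = int((toy_recording["p"] == 0).sum())
# Number of 1000 us windows needed to cover the recording
n_windows = duration_us // 1000 + 1

print(f"{len(toy_recording)} events over {duration_us} us ({n_on} ON, {n_off} OFF)")
print(f"{n_windows} frames at time_window=1000")
```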

When the events are accumulated over time into 3 bins, the images show one of the 4 card symbols:

tonic.utils.plot_event_grid(events)

And this one is the target class:

trainset[0][1]

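We wrap the training and test sets in PyTorch DataLoaders, which take care of batching and shuffling. Note also the collate function tonic.collation.PadTensors: because recordings differ in length, it pads all frame tensors in a batch to the same number of time steps. With batch_first=False, each batch has the time dimension first, followed by the batch dimension and the frame dimensions, which is the layout the network below expects.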

# reduce this number if you run out of GPU memory
BATCH_SIZE = 32

# add the frame transform to the trainset; we omitted it earlier to look at the raw events
trainset.transform = frame_transform

train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    collate_fn=tonic.collation.PadTensors(batch_first=False),
    shuffle=False,
)
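The idea behind time-first padding can be illustrated with PyTorch's `torch.nn.utils.rnn.pad_sequence`, which zero-pads a list of tensors along their first dimension. This is only an analogy for what a padding collate function does, not Tonic's `PadTensors` implementation, and the recording shapes below are made up.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two hypothetical recordings with different numbers of 1 ms frames,
# each frame shaped (polarity, height, width)
short_rec = torch.ones(3, 2, 35, 35)
long_rec = torch.ones(5, 2, 35, 35)

# batch_first=False stacks along dim 1: (time, batch, polarity, height, width)
padded = pad_sequence([short_rec, long_rec], batch_first=False)
print(padded.shape)  # torch.Size([5, 2, 2, 35, 35])

# The shorter recording is zero-padded at the end: no events after frame 3
print(padded[3:, 0].sum().item())  # 0.0
```

Zero frames simply mean "no events", so padding does not inject spurious input, although the network state keeps evolving (decaying) during the padded steps.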

Defining a Network

Once the data is encoded into spikes, a spiking neural network can be constructed in the same way as one would construct a recurrent neural network. Here we define a spiking neural network with one recurrently connected layer of hidden_features LIF neurons and a readout layer of output_features leaky integrators. As you can see, we can freely combine spiking neural network primitives with ordinary torch.nn.Module layers such as torch.nn.Linear.
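A LIF neuron integrates a synaptic current `i` into a membrane voltage `v`, emits a spike `z` when `v` crosses the threshold `v_th`, and then resets; the leaky integrator (LI) used for the readout has the same dynamics without threshold and reset. Below is a forward-Euler sketch of a single time step in plain PyTorch. It follows the general form of the LIF equations in Norse's documentation, but it is a simplification: `lif_step` and `li_step` are hypothetical helpers, a hard step function replaces Norse's surrogate gradient, and the parameter values only echo this tutorial (1/1e-2 time constants, `v_th=0.3`).

```python
import torch


def lif_step(input_current, v, i, dt=1e-3, tau_mem_inv=100.0, tau_syn_inv=100.0,
             v_leak=0.0, v_th=0.3, v_reset=0.0):
    """One forward-Euler step of a LIF neuron (illustrative, no surrogate gradient)."""
    v_new = v + dt * tau_mem_inv * ((v_leak - v) + i)  # membrane integrates the current
    i_new = i - dt * tau_syn_inv * i                   # synaptic current decays
    z = (v_new > v_th).to(v.dtype)                     # spike where v crosses threshold
    v_new = (1 - z) * v_new + z * v_reset              # reset the neurons that spiked
    i_new = i_new + input_current                      # input adds to the current
    return z, v_new, i_new


def li_step(input_current, v, i, dt=1e-3, tau_mem_inv=100.0, tau_syn_inv=100.0,
            v_leak=0.0):
    """Leaky integrator: same dynamics, but no threshold and no reset."""
    v_new = v + dt * tau_mem_inv * ((v_leak - v) + i)
    i_new = i - dt * tau_syn_inv * i + input_current
    return v_new, i_new


# Drive one neuron of each type with a constant input for 20 steps
v, i = torch.zeros(1), torch.zeros(1)
v_li, i_li = torch.zeros(1), torch.zeros(1)
spikes = []
for _ in range(20):
    z, v, i = lif_step(torch.tensor([0.5]), v, i)
    v_li, i_li = li_step(torch.tensor([0.5]), v_li, i_li)
    spikes.append(int(z.item()))

print("LIF spikes:", spikes)
print("LI voltage after 20 steps:", v_li.item())
```

The LIF neuron turns its input into a spike train, while the LI voltage keeps rising past the LIF threshold; this continuous voltage is what the readout layer hands to the decoder.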

from norse.torch import LIFParameters, LIFState
from norse.torch.module.lif import LIFCell, LIFRecurrentCell

# Notice the difference between "LIF" (leaky integrate-and-fire) and "LI" (leaky integrator)
from norse.torch import LICell, LIState

from typing import NamedTuple


class SNNState(NamedTuple):
    lif0: LIFState
    readout: LIState


class SNN(torch.nn.Module):
    def __init__(
        self,
        input_features,
        hidden_features,
        output_features,
        tau_syn_inv,
        tau_mem_inv,
        record=False,
        dt=1e-3,
    ):
        super(SNN, self).__init__()
        self.l1 = LIFRecurrentCell(
            input_features,
            hidden_features,
            p=LIFParameters(
                alpha=100,
                v_th=torch.as_tensor(0.3),
                tau_syn_inv=tau_syn_inv,
                tau_mem_inv=tau_mem_inv,
            ),
            dt=dt,
        )
        self.input_features = input_features
        self.fc_out = torch.nn.Linear(hidden_features, output_features, bias=False)
        self.out = LICell(dt=dt)

        self.hidden_features = hidden_features
        self.output_features = output_features
        self.record = record

    def forward(self, x):
        seq_length, batch_size, _, _, _ = x.shape
        s1 = so = None
        voltages = []

        if self.record:
            self.recording = SNNState(
                LIFState(
                    z=torch.zeros(seq_length, batch_size, self.hidden_features),
                    v=torch.zeros(seq_length, batch_size, self.hidden_features),
                    i=torch.zeros(seq_length, batch_size, self.hidden_features),
                ),
                LIState(
                    v=torch.zeros(seq_length, batch_size, self.output_features),
                    i=torch.zeros(seq_length, batch_size, self.output_features),
                ),
            )

        for ts in range(seq_length):
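            # flatten (polarity, height, width) into one feature vector per sample;
            # .float() guards against integer event counts from the frame transform
            z = x[ts].reshape(-1, self.input_features).float()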
            z, s1 = self.l1(z, s1)
            z = self.fc_out(z)
            vo, so = self.out(z, so)
            if self.record:
                self.recording.lif0.z[ts, :] = s1.z
                self.recording.lif0.v[ts, :] = s1.v
                self.recording.lif0.i[ts, :] = s1.i
                self.recording.readout.v[ts, :] = so.v
                self.recording.readout.i[ts, :] = so.i
            voltages += [vo]

        return torch.stack(voltages)

It’s a good idea to test the network’s response to the time-constant parameters, which should be chosen with the duration of the recordings and the average number of events in mind. We use dt=1e-3 because each frame produced by the ToFrame transform covers a 1000 µs (1 ms) time window, even though the events themselves have microsecond resolution.

example_snn = SNN(
    np.prod(trainset.sensor_size),
    100,
    len(trainset.classes),
    tau_syn_inv=torch.tensor(1 / 1e-2),
    tau_mem_inv=torch.tensor(1 / 1e-2),
    record=True,
    dt=1e-3,
)
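To see why the time constants must fit the data, it helps to compute how much the membrane voltage decays per simulation step. Assuming a forward-Euler update, the leak term is multiplied by `1 - dt * tau_mem_inv` every step; with `dt = 1e-3` and `tau_mem_inv = 1 / 1e-2` (a 10 ms time constant) that is 0.9. This is a back-of-the-envelope calculation of our own, not output reported by Norse.

```python
import math

dt = 1e-3                # 1 ms per frame, matching time_window=1000 (us)
tau_mem_inv = 1 / 1e-2   # i.e. a 10 ms membrane time constant

# Per-step decay factor of the leak term under forward Euler
euler_factor = 1 - dt * tau_mem_inv
# Exact exponential decay over one step, for comparison
exact_factor = math.exp(-dt * tau_mem_inv)

print(f"Euler: {euler_factor:.3f}, exact: {exact_factor:.3f}")  # Euler: 0.900, exact: 0.905
# Fraction of a voltage deflection left after 10 steps without input
print(f"after 10 steps: {euler_factor ** 10:.3f}")  # after 10 steps: 0.349
```

Roughly two thirds of a deflection is gone within 10 frames, so the network's memory spans on the order of tens of milliseconds; if recordings are much longer than that, evidence from early frames fades before the end.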

frames, target = next(iter(train_loader))

frames[:, :1].shape

Note that we only feed the first sample of the batch through the network; the slice frames[:, :1] keeps all time steps and a batch dimension of size 1.

example_readout_voltages = example_snn(frames[:, :1])
voltages = example_readout_voltages.squeeze(1).detach().numpy()

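plt.plot(voltages)
plt.ylabel("Readout voltage [a.u.]")
plt.xlabel("Time [ms]")
plt.show()

plt.plot(example_snn.recording.lif0.v.squeeze(1).detach().numpy())
plt.ylabel("LIF membrane voltage [a.u.]")
plt.xlabel("Time [ms]")
plt.show()

plt.plot(example_snn.recording.lif0.i.squeeze(1).detach().numpy())
plt.ylabel("LIF synaptic current [a.u.]")
plt.xlabel("Time [ms]")
plt.show()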

Training the Network

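The final model is then simply the sequential composition of our network and a decoding step. The decoder takes the maximum voltage of each readout neuron over time and applies a log-softmax across classes, producing log-probabilities for the negative log-likelihood loss used during training.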

def decode(x):
    x, _ = torch.max(x, 0)
    log_p_y = torch.nn.functional.log_softmax(x, dim=1)
    return log_p_y


class Model(torch.nn.Module):
    def __init__(self, snn, decoder):
        super(Model, self).__init__()
        self.snn = snn
        self.decoder = decoder

    def forward(self, x):
        x = self.snn(x)
        log_p_y = self.decoder(x)
        return log_p_y


LR = 0.002
INPUT_FEATURES = np.prod(trainset.sensor_size)
HIDDEN_FEATURES = 100
OUTPUT_FEATURES = len(trainset.classes)

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

model = Model(
    snn=SNN(
        input_features=INPUT_FEATURES,
        hidden_features=HIDDEN_FEATURES,
        output_features=OUTPUT_FEATURES,
        tau_syn_inv=torch.tensor(1 / 1e-2),
        tau_mem_inv=torch.tensor(1 / 1e-2),
    ),
    decoder=decode,
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
model
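The `decode` function keeps, for every output neuron, its largest readout voltage over the whole sequence and turns these values into log-probabilities. A toy example with made-up voltages (3 time steps, a batch of 1, 4 classes) shows the shapes involved; it repeats the two operations of `decode` inline so that it runs on its own.

```python
import torch

# Hypothetical readout voltages: (time steps, batch, classes)
toy_voltages = torch.tensor(
    [
        [[0.1, 0.0, 0.2, 0.0]],
        [[0.4, 0.1, 0.3, 0.0]],
        [[0.2, 0.5, 0.1, 0.0]],
    ]
)

# Same steps as decode(): max over time, then log-softmax over classes
max_voltage, _ = torch.max(toy_voltages, 0)  # shape (batch, classes)
log_p_y = torch.nn.functional.log_softmax(max_voltage, dim=1)

print(max_voltage)               # per-class peaks: 0.4, 0.5, 0.3, 0.0
print(log_p_y.argmax(dim=1))     # predicted class: 1
print(log_p_y.exp().sum(dim=1))  # probabilities sum to 1
```

With the maximum, a class wins if its readout neuron reached the highest voltage at any point in the recording, regardless of when; other reductions, such as the mean over time, are possible but are not what this tutorial uses.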

What remains is to set up the training and test code. This code is completely independent of the fact that we are training a spiking neural network and has in fact been largely copied from the PyTorch tutorials.

from tqdm.notebook import tqdm, trange


def train(model, device, train_loader, optimizer, epoch):
    model.train()
    losses = []

    for (data, target) in tqdm(train_loader, leave=False):
        data = data.to(device)
        target = torch.as_tensor(target, dtype=torch.long, device=device)
        optimizer.zero_grad()
        output = model(data)
        loss = torch.nn.functional.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    mean_loss = np.mean(losses)
    return losses, mean_loss
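Because `decode` already returns log-probabilities, the training loop uses `nll_loss` rather than `cross_entropy`. Applying `nll_loss` to `log_softmax` outputs gives the same value as `cross_entropy` on the raw scores, which a quick check with random, made-up numbers confirms:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(8, 4)           # hypothetical max-over-time voltages (batch, classes)
targets = torch.randint(0, 4, (8,))  # hypothetical class labels

loss_nll = F.nll_loss(F.log_softmax(scores, dim=1), targets)
loss_ce = F.cross_entropy(scores, targets)

print(torch.allclose(loss_nll, loss_ce))  # True
```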

Just like the training function, the test function is standard boilerplate, common to any other supervised learning task.

def test(model, device, test_loader, epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = torch.as_tensor(target, dtype=torch.long, device=device)
            output = model(data)
            test_loss += torch.nn.functional.nll_loss(
                output, target, reduction="sum"
            ).item()  # sum up batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    accuracy = 100.0 * correct / len(test_loader.dataset)

    return test_loss, accuracy


training_losses = []
mean_losses = []
test_losses = []
accuracies = []

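# Anomaly detection helps trace NaN/inf gradients back to their source,
# but it slows training down considerably; disable it once training is stable
torch.autograd.set_detect_anomaly(True)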

EPOCHS = 10

for epoch in trange(EPOCHS):
    training_loss, mean_loss = train(model, DEVICE, train_loader, optimizer, epoch)
    test_loss, accuracy = test(model, DEVICE, test_loader, epoch)
    training_losses += training_loss
    mean_losses.append(mean_loss)
    test_losses.append(test_loss)
    accuracies.append(accuracy)

print(f"final accuracy: {accuracies[-1]}")

We can visualize the readout voltages of the trained network on the same example input:

trained_snn = model.snn
trained_readout_voltages = trained_snn(frames[:, :1].to(DEVICE))
plt.plot(trained_readout_voltages.squeeze(1).cpu().detach().numpy())

plt.ylabel("Voltage [a.u.]")
plt.xlabel("Time [ms]")
plt.show()