1.1. Spiking Neural Networks

Spiking neural networks are not fundamentally different from the artificial neural networks in common use today. The main difference is that communication between units in the network happens exclusively through spikes, represented as binary values (one or zero), and that time is an explicit dimension of the computation.
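
To make this concrete, here is a small sketch (independent of the tutorial script below) of one common way to represent a spike train as a tensor: a binary tensor with a leading time dimension, where each entry records whether a given channel spiked at a given timestep.

import torch

# 100 timesteps of 5 channels; each entry spikes (is 1.0) with
# probability 0.1 per timestep - a simple Poisson-like rate code
timesteps, channels = 100, 5
spikes = torch.bernoulli(torch.full((timesteps, channels), 0.1))
print(spikes.shape)  # torch.Size([100, 5])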

1.1.1. How to define a Network

The spiking neural network primitives in Norse are designed to fit as seamlessly as possible into a traditional deep learning pipeline.
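
For instance, a single LIF cell is called much like any other torch.nn.Module; the only addition is the neuron state, which is passed in and returned at every call. A minimal sketch, using the same constructor arguments as the network below:

import torch
from norse.torch.functional.lif import LIFParameters
from norse.torch.module.lif import LIFFeedForwardCell

# One LIF cell applied to a single timestep for a batch of 8 two-feature inputs
cell = LIFFeedForwardCell(p=LIFParameters(method="super", alpha=100.0))
x = torch.randn(8, 2)
state = None  # None lets the cell initialise its own default state
spikes, state = cell(x, state)  # binary spike tensor plus the updated state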

import torch
from norse.torch.functional.lif import LIFParameters

from norse.torch.module.leaky_integrator import LICell
from norse.torch.module.lif import LIFFeedForwardCell


class Net(torch.nn.Module):
    def __init__(
        self,
        num_channels=1,
        feature_size=32,
        model="super",
        dtype=torch.float,
    ):
        super(Net, self).__init__()
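        # Each (5x5 convolution -> 2x2 max-pool) stage shrinks the spatial size:
        # with feature_size=32 this gives 32 -> 28 -> 14 -> 10 -> 5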
        self.features = int(((feature_size - 4) / 2 - 4) / 2)

        self.conv1 = torch.nn.Conv2d(num_channels, 32, 5, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 5, 1)
        self.fc1 = torch.nn.Linear(self.features * self.features * 64, 1024)
        self.lif0 = LIFFeedForwardCell(
            p=LIFParameters(method=model, alpha=100.0),
        )
        self.lif1 = LIFFeedForwardCell(
            p=LIFParameters(method=model, alpha=100.0),
        )
        self.lif2 = LIFFeedForwardCell(p=LIFParameters(method=model, alpha=100.0))
        self.out = LICell(1024, 10)
        self.dtype = dtype
        # One would normally also define the device here
        # However, Norse has been built to infer the device type from the input data
        # It is still possible to enforce the device type on initialisation
        # More details are available in our documentation:
        #   https://norse.github.io/norse/hardware.html

    def forward(self, x):
        seq_length = x.shape[0]
        seq_batch_size = x.shape[1]

        # Initialize state variables; passing None lets each cell
        # create its default state at the first call
        s0 = None
        s1 = None
        s2 = None
        so = None

        voltages = torch.zeros(seq_length, seq_batch_size, 10, dtype=self.dtype)

        for ts in range(seq_length):
            z = self.conv1(x[ts, :])
            z, s0 = self.lif0(z, s0)
            z = torch.nn.functional.max_pool2d(z, 2, 2)
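            # Amplify the convolution output so that the downstream LIF
            # layer receives enough input current to emit spikes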
            z = 10 * self.conv2(z)
            z, s1 = self.lif1(z, s1)
            z = torch.nn.functional.max_pool2d(z, 2, 2)
            z = z.view(-1, self.features ** 2 * 64)
            z = self.fc1(z)
            z, s2 = self.lif2(z, s2)
            v, so = self.out(torch.nn.functional.relu(z), so)
            voltages[ts, :, :] = v
        return voltages


if __name__ == "__main__":
    net = Net()
    print(net)

    ########################################################################
    # We can evaluate the network we just defined on an input of size 1x32x32.
    # Note that, in contrast to typical spiking neural network simulators, time
    # is just another dimension in the input tensor. Here we chose to evaluate
    # the network on 16 timesteps, and there is an explicit batch dimension
    # (the number of concurrently evaluated inputs with identical model parameters).

    timesteps = 16
    batch_size = 1
    data = torch.abs(torch.randn(timesteps, batch_size, 1, 32, 32))
    out = net(data)
    print(out)
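
    ########################################################################
    # The returned tensor contains the membrane voltages of the leaky
    # integrator readout at every timestep. As a sketch (not part of the
    # original tutorial), one common way to decode a class prediction is to
    # take the maximum voltage over time and pick the largest entry.

    prediction = out.max(dim=0).values.argmax(dim=-1)  # shape: (batch_size,)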

    ##########################################################################
    # Since the spiking neural network is implemented as a PyTorch module, we
    # can use the usual PyTorch primitives for optimizing it. Note that the
    # backward computation expects a gradient for each timestep.

    net.zero_grad()
    out.backward(torch.randn(timesteps, batch_size, 10))
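
    ########################################################################
    # As a sketch (not part of the original tutorial), a complete optimization
    # step with a standard PyTorch optimizer could look as follows; the random
    # target tensor is a hypothetical stand-in for real labels.

    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    target = torch.randn(timesteps, batch_size, 10)
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(net(data), target)
    loss.backward()
    optimizer.step()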

    ########################################################################
    # .. note::
    #
    #     ``norse``, like PyTorch, only supports mini-batches. This means that,
    #     contrary to most other spiking neural network simulators, ``norse``
    #     always integrates several independent sets of spiking neural
    #     networks at once.

Out:

Net(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=1600, out_features=1024, bias=True)
  (lif0): LIFFeedForwardCell(p=LIFParameters(tau_syn_inv=tensor(200.), tau_mem_inv=tensor(100.), v_leak=tensor(0.), v_th=tensor(1.), v_reset=tensor(0.), method='super', alpha=100.0), dt=0.001)
  (lif1): LIFFeedForwardCell(p=LIFParameters(tau_syn_inv=tensor(200.), tau_mem_inv=tensor(100.), v_leak=tensor(0.), v_th=tensor(1.), v_reset=tensor(0.), method='super', alpha=100.0), dt=0.001)
  (lif2): LIFFeedForwardCell(p=LIFParameters(tau_syn_inv=tensor(200.), tau_mem_inv=tensor(100.), v_leak=tensor(0.), v_th=tensor(1.), v_reset=tensor(0.), method='super', alpha=100.0), dt=0.001)
  (out): LICell()
)
tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000]],

        [[-0.0032, -0.0009,  0.0035,  0.0024,  0.0064,  0.0012,  0.0057,
          -0.0010,  0.0011,  0.0044]]], grad_fn=<CopySlices>)
