Lec 4 Building Efficient Spiking Neural Networks with BrainCog

Hello! In this tutorial, we show how to build and train high-performance spiking neural networks (SNNs) with BrainCog.

Creating neuron classes

Import the required packages (nn from PyTorch and rearrange from einops are used directly in the code below, so they are imported explicitly)

from functools import partial

from torch import nn
from einops import rearrange
from timm.models import register_model
from braincog.base.node.node import *
from braincog.base.encoder.encoder import *
from braincog.model_zoo.base_module import BaseModule, BaseConvModule
from braincog.datasets import is_dvs_data
from braincog.datasets.datasets import get_dvsc10_data
from braincog.base.utils.visualization import spike_rate_vis, spike_rate_vis_1d
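The imports above bring in BrainCog's neuron models such as LIFNode. To make the role of a neuron node concrete, here is a minimal, framework-free sketch of a leaky integrate-and-fire (LIF) neuron stepping through time. It is not BrainCog's implementation: the update rule v += (x - v) / tau, tau=2.0, the threshold of 0.5 and the hard reset to zero are assumptions chosen for illustration.

```python
def lif_simulate(inputs, tau=2.0, threshold=0.5):
    """Simulate one LIF neuron over a sequence of input currents.

    Assumed dynamics (illustrative, not BrainCog's exact code):
      v[t] = v[t-1] + (x[t] - v[t-1]) / tau   # leaky integration toward the input
      emit a spike when v[t] >= threshold, then hard-reset v to 0
    """
    v = 0.0
    spikes = []
    for x in inputs:
        v = v + (x - v) / tau      # membrane potential leaks toward the input current
        if v >= threshold:
            spikes.append(1)
            v = 0.0                # hard reset after firing
        else:
            spikes.append(0)
    return spikes


# A constant sub-threshold input needs several steps of integration before each spike.
print(lif_simulate([0.6] * 6))  # [0, 0, 1, 0, 0, 1]
```

The point of the example is that the neuron is stateful: its output at step t depends on the inputs it has integrated so far, which is why SNN models must reset neuron state between samples.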

Defining a new neural network

Note that every SNN must inherit from BaseModule, which implements basic interfaces for controlling neuron states and for visualization.

@register_model  ## register model in timm
class SNN5(BaseModule):
    def __init__(self,
                 num_classes=10,
                 step=8,
                 node_type=LIFNode,
                 encode_type='direct',
                 *args,
                 **kwargs):
        super().__init__(step, encode_type, *args, **kwargs)
        self.num_classes = num_classes

        self.node = node_type
        if issubclass(self.node, BaseNode):
            self.node = partial(self.node, **kwargs, step=step)

        self.dataset = kwargs['dataset'] if 'dataset' in kwargs else 'dvsc10'
        if not is_dvs_data(self.dataset):
            init_channel = 3
        else:
            init_channel = 2

        self.feature = nn.Sequential(
            BaseConvModule(init_channel, 16, kernel_size=(3, 3), padding=(1, 1), node=self.node),
            BaseConvModule(16, 64, kernel_size=(5, 5), padding=(2, 2), node=self.node),
            nn.AvgPool2d(2),
            BaseConvModule(64, 128, kernel_size=(5, 5), padding=(2, 2), node=self.node),
            nn.AvgPool2d(2),
            BaseConvModule(128, 256, kernel_size=(3, 3), padding=(1, 1), node=self.node),
            nn.AvgPool2d(2),
            BaseConvModule(256, 512, kernel_size=(3, 3), padding=(1, 1), node=self.node),
            nn.AvgPool2d(2),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 3 * 3, self.num_classes),
        )

    def forward(self, inputs):
        inputs = self.encoder(inputs)
        self.reset()

        if self.layer_by_layer:
            x = self.feature(inputs)
            x = self.fc(x)
            x = rearrange(x, '(t b) c -> t b c', t=self.step).mean(0)
            return x

        else:
            outputs = []
            for t in range(self.step):
                x = inputs[t]
                x = self.feature(x)
                x = self.fc(x)
                outputs.append(x)

            return sum(outputs) / len(outputs)
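SNN5.forward has two paths: with layer_by_layer=True, all time steps are folded into the batch axis ('(t b) c'), processed in one pass, unfolded with rearrange and averaged over time; otherwise the model loops over time steps and averages the outputs. The plain-Python sketch below mimics both paths for a hypothetical stateless readout standing in for self.fc, and shows that they produce the same time-averaged result. It deliberately does not model the stateful spiking layers, which is where the two modes differ in how state is carried across steps.

```python
# Toy data: T=2 time steps, batch B=2, C=2 features, stored as [T][B][C].
inputs = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
T = len(inputs)


def head(row):
    """Hypothetical stateless readout standing in for self.fc: doubles each value."""
    return [2 * v for v in row]


def mean_over_time(per_step):
    """Average a [T][B][C] nested list over its first (time) axis."""
    t, b, c = len(per_step), len(per_step[0]), len(per_step[0][0])
    return [[sum(per_step[s][i][j] for s in range(t)) / t for j in range(c)]
            for i in range(b)]


# layer_by_layer=True: fold time into the batch axis (time-major '(t b) c'), run once.
flat = [row for step in inputs for row in step]              # [(T*B)][C]
flat_out = [head(row) for row in flat]
b = len(flat_out) // T
unfolded = [flat_out[s * b:(s + 1) * b] for s in range(T)]   # '(t b) c -> t b c'
layer_by_layer = mean_over_time(unfolded)                    # .mean(0)

# layer_by_layer=False: loop over time steps, then sum(outputs) / len(outputs).
outputs = [[head(row) for row in inputs[s]] for s in range(T)]
step_by_step = mean_over_time(outputs)

print(layer_by_layer)  # [[6.0, 8.0], [10.0, 12.0]]
print(step_by_step)    # [[6.0, 8.0], [10.0, 12.0]]
```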

BaseModule provides an easy-to-call interface for controlling the behavior of the neurons in the network. For example, calling self.reset() resets the membrane potentials of all neurons in the network. As a result, building a spiking neural network with BrainCog only requires implementing __init__() and forward().
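Why call self.reset() at the start of forward? Neuron state such as the membrane potential persists between calls, so without a reset one sample's leftover state would leak into the next. The sketch below illustrates the idea with two hypothetical classes, ToyNode and ToyModule (and a hypothetical reset_state method); they are stand-ins for the concept, not BrainCog's BaseNode or BaseModule.

```python
class ToyNode:
    """Hypothetical stateful neuron: integrates input into a membrane potential."""

    def __init__(self):
        self.mem = 0.0

    def step(self, x):
        self.mem += x            # state persists across calls until it is reset
        return self.mem

    def reset_state(self):
        self.mem = 0.0


class ToyModule:
    """Hypothetical stand-in for BaseModule: owns several nodes and resets them together."""

    def __init__(self, n_nodes):
        self.nodes = [ToyNode() for _ in range(n_nodes)]

    def reset(self):
        for node in self.nodes:
            node.reset_state()

    def forward(self, xs):
        self.reset()             # start every sample from a clean state, as SNN5.forward does
        return [node.step(x) for node, x in zip(self.nodes, xs)]


net = ToyModule(2)
print(net.forward([1.0, 2.0]))  # [1.0, 2.0]
print(net.forward([1.0, 2.0]))  # [1.0, 2.0] again; without reset() it would be [2.0, 4.0]
```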

Visualization of input data

Dataset initialization

BrainCog provides interfaces to several datasets, covering both static image datasets and event-based datasets, for example MNIST, CIFAR10, ImageNet, DVS-CIFAR10, DVS-Gesture and N-Caltech101.
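Static images have no time axis, so an SNN must turn them into a sequence of step inputs; event data already arrive with one. SNN5 uses encode_type='direct' by default. The sketch below shows two common encoding schemes on a toy list of pixel intensities: direct encoding (the same analog values at every step) and rate encoding (Bernoulli spikes with probability equal to intensity). These are generic SNN conventions written in plain Python; the function names are hypothetical and BrainCog's Encoder may differ in details.

```python
import random


def direct_encode(pixels, step):
    """Direct encoding: repeat the same analog pixel values at every time step."""
    return [list(pixels) for _ in range(step)]


def rate_encode(pixels, step, seed=0):
    """Rate (Bernoulli) encoding: at each step a pixel fires with probability = its value in [0, 1]."""
    rng = random.Random(seed)           # seeded for reproducibility
    return [[1 if rng.random() < p else 0 for p in pixels] for _ in range(step)]


pixels = [0.0, 0.5, 1.0]                # toy "image" with intensities in [0, 1]
print(direct_encode(pixels, 3))         # [[0.0, 0.5, 1.0], [0.0, 0.5, 1.0], [0.0, 0.5, 1.0]]
print(rate_encode(pixels, 3))           # first column is always 0, last column always 1
```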

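For example, the following code loads DVS-CIFAR10 with a batch size of 1 and 8 time steps, takes one batch, and visualizes the first polarity channel of the sample over time: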

train_loader, _, _, _ = get_dvsc10_data(batch_size=1, step=8)
it = iter(train_loader)
inputs, labels = next(it)
print(inputs.shape, labels.shape)
spike_rate_vis(inputs[0, :, 0])

Out[1]: torch.Size([1, 8, 2, 48, 48]) torch.Size([1])
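The input shape [1, 8, 2, 48, 48] reads as (batch, time steps, polarity channels, height, width): a DVS camera emits ON/OFF events, which are accumulated into 8 frames of 2 channels. The sketch below shows one simple way to bin raw (t, x, y, p) events into such frames by splitting the time range into equal windows and counting events per pixel and polarity. It is an assumed, simplified scheme with a hypothetical function name; BrainCog's actual DVS-CIFAR10 preprocessing (including resizing to 48x48) is not reproduced here.

```python
def events_to_frames(events, step, height, width):
    """Bin (t, x, y, p) events into `step` frames, each shaped [2][height][width].

    Assumes polarity p is 0 (OFF) or 1 (ON). Events are split into `step` equal
    time windows; each frame counts the events of each polarity at each pixel.
    """
    frames = [[[[0] * width for _ in range(height)] for _ in range(2)]
              for _ in range(step)]
    if not events:
        return frames
    t0 = min(e[0] for e in events)
    t1 = max(e[0] for e in events)
    span = (t1 - t0) or 1                       # avoid dividing by zero for one timestamp
    for t, x, y, p in events:
        idx = min(int((t - t0) * step / span), step - 1)  # latest events go in the last bin
        frames[idx][p][y][x] += 1
    return frames


events = [(0, 0, 0, 1), (5, 1, 0, 0), (10, 1, 1, 1), (10, 1, 1, 1)]
frames = events_to_frames(events, step=2, height=2, width=2)
print(frames[1][1])  # [[0, 0], [0, 2]]
```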

Model Inference and Visualization

After defining the model and acquiring the dataset, the model can be used for forward inference.

model = SNN5(layer_by_layer=True, dataset='dvsc10').cuda()
print(model)

Out[2]: 
SNN5(
  (encoder): Encoder()
  (feature): Sequential(
    (0): BaseConvModule(
      (conv): Conv2d(2, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (preact): Identity()
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (node): LIFNode(
        (act_fun): QGateGrad()
      )
      (activation): Identity()
    )
    (1): BaseConvModule(
      (conv): Conv2d(16, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      (preact): Identity()
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (node): LIFNode(
        (act_fun): QGateGrad()
      )
      (activation): Identity()
    )
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (3): BaseConvModule(
      (conv): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      (preact): Identity()
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (node): LIFNode(
        (act_fun): QGateGrad()
      )
      (activation): Identity()
    )
    (4): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (5): BaseConvModule(
      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (preact): Identity()
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (node): LIFNode(
        (act_fun): QGateGrad()
      )
      (activation): Identity()
    )
    (6): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (7): BaseConvModule(
      (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (preact): Identity()
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (node): LIFNode(
        (act_fun): QGateGrad()
      )
      (activation): Identity()
    )
    (8): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (fc): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=4608, out_features=10, bias=True)
  )
)

The code above instantiates SNN5 and prints its structure. The resulting instance can now be used for inference.
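The printed model shows Linear(in_features=4608, ...). That number follows from the layer stack: every conv uses stride 1 with "same" padding (3x3 with padding 1, 5x5 with padding 2), so only the four AvgPool2d(2) layers shrink the feature map, taking a 48x48 input to 3x3, and 512 * 3 * 3 = 4608. The sketch below traces the spatial size with the standard output-size formula; the function names are illustrative.

```python
def conv_out(size, kernel, padding, stride=1):
    """Standard conv/pool output size: floor((size + 2 * padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def snn5_flatten_dim(input_size=48):
    """Trace the spatial size through SNN5.feature and return (Linear input width, final size)."""
    layers = [
        ("conv", 3, 1), ("conv", 5, 2), ("pool", 2, 0),
        ("conv", 5, 2), ("pool", 2, 0),
        ("conv", 3, 1), ("pool", 2, 0),
        ("conv", 3, 1), ("pool", 2, 0),
    ]
    size = input_size
    for kind, kernel, padding in layers:
        stride = 2 if kind == "pool" else 1    # AvgPool2d(2) uses stride = kernel size
        size = conv_out(size, kernel, padding, stride)
    return 512 * size * size, size


print(snn5_flatten_dim(48))  # (4608, 3)
print(snn5_flatten_dim(32))  # (2048, 2)
```

The second call shows that the hard-coded nn.Linear(512 * 3 * 3, ...) assumes 48x48 inputs: 32x32 inputs would produce 2048 features and the head would need to be resized.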

outputs = model(inputs.cuda())
print(outputs)

Out[3]: tensor([[-0.0524,  0.1208, -0.0167, -0.1197, -0.0659, -0.0528, -0.1191, -0.0732, -0.0615,  0.1520]], device='cuda:0', grad_fn=<MeanBackward1>)

Feeding the event data into the model produces one score per class for the 10 DVS-CIFAR10 categories. Since the model has not been trained yet, these scores are essentially arbitrary.

BaseModule in BrainCog also provides visualization methods that make it easy to record the activity of each layer of neurons.
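To turn these scores into a prediction, take the class with the highest score (in PyTorch, outputs.argmax(dim=1)); a softmax converts them into probabilities. The plain-Python sketch below does this on the values printed in Out[3]. Because the model is untrained, the probabilities come out close to uniform and the predicted class carries no meaning.

```python
import math

# Scores copied from Out[3] above (untrained model, so the values are arbitrary).
logits = [-0.0524, 0.1208, -0.0167, -0.1197, -0.0659,
          -0.0528, -0.1191, -0.0732, -0.0615, 0.1520]


def softmax(xs):
    m = max(xs)                          # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax(logits)
pred = max(range(len(probs)), key=probs.__getitem__)   # index of the largest probability
print(pred)  # 9
```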

model.set_requires_fp(True)
outputs = model(inputs.cuda())
feature_map = model.get_fp()

feature = feature_map[0]
feature = rearrange(feature, '1 (t c) ... -> t c ...', t=8)
spike_rate_vis(feature[:, 0].detach())

feature = feature_map[-1]
feature = rearrange(feature, '1 (t c) ... -> t c ...', t=8)
spike_rate_vis(feature[:, 0].detach())

Figure: Neural activity visualization of the second layer
Figure: Neural activity visualization of the last layer
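The figures show how active each layer's neurons are. A common summary of spiking activity is the firing rate: the fraction of time steps on which a neuron emits a spike. The sketch below computes per-neuron and layer-average rates from a toy [T][N] spike record. The function name is hypothetical, and we assume spike_rate_vis visualizes a quantity along these lines; its exact computation is not shown in this tutorial.

```python
def firing_rate(spikes):
    """Per-neuron firing rate: fraction of the T time steps on which each neuron spiked.

    `spikes` is a [T][N] nested list of 0/1 values.
    """
    T = len(spikes)
    return [sum(step[n] for step in spikes) / T for n in range(len(spikes[0]))]


spikes = [[1, 0, 1], [0, 0, 1], [1, 0, 1], [0, 0, 1]]   # T=4 steps, N=3 neurons
rates = firing_rate(spikes)
print(rates)                    # [0.5, 0.0, 1.0]
print(sum(rates) / len(rates))  # 0.5, the layer's mean firing rate
```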

The sections above showed how to build spiking neural networks with BrainCog.

BrainCog also provides a script for training high-performance spiking networks.

It can be found at https://github.com/BrainCog-X/Brain-Cog/blob/main/examples/Perception_and_Learning/img_cls/bp/main.py.

BrainCog also provides several model benchmarks for reference when designing SNNs. The predefined models can be found at https://github.com/BrainCog-X/Brain-Cog/tree/main/braincog/model_zoo.

The benchmark results of the models can be found at https://github.com/BrainCog-X/Brain-Cog/tree/main/examples/Perception_and_Learning/img_cls/bp.