Introduction

Hello there! In this notebook we explore the advantages of Binary Neural Networks (BNNs) and compare them to networks that use full-precision calculations.

Binary Neural Networks

The main difference between common neural networks and BNNs is the use of 1-bit weights and quantization functions on the input data. This allows us to take advantage of highly optimized binary operations in order to speed up both training and inference of the neural networks.
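
To get an intuition for why this helps, consider the dot product of two vectors whose entries are restricted to +1 and -1. If we encode +1 as bit 1 and -1 as bit 0, the whole dot product collapses into a single XOR followed by a popcount. The toy sketch below (plain Python, not part of bitorch) only illustrates the idea:

# two 8-dimensional vectors with entries in {-1, +1}, packed into bit masks
# (bit 1 encodes +1, bit 0 encodes -1)
a_bits = 0b10110010
b_bits = 0b11010110
n = 8

mismatches = bin(a_bits ^ b_bits).count("1")  # popcount of the XOR
dot_product = n - 2 * mismatches              # matches minus mismatches
print(dot_product)                            # 2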

By default we use the sign function to transform floating-point inputs and weights to their binary values:

\[\begin{split} sign(x) = \left\{ \begin{array}{ll} +1 & \mbox{if } x \geq 0 \\ -1 & \mbox{if } x < 0 \end{array} \right.\end{split}\]
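
As a quick illustration (plain PyTorch, not bitorch's own code), this is all the forward transformation does:

import torch

def binarize(x: torch.Tensor) -> torch.Tensor:
    # +1 for every entry >= 0, -1 otherwise, exactly as defined above
    return torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))

print(binarize(torch.tensor([-1.3, 0.0, 0.7])))  # tensor([-1.,  1.,  1.])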

When computing gradients, we use a so-called Straight-Through Estimator (STE). Gradients are also canceled automatically if the layer input gets too large in magnitude, because changes to a value that is already far from zero have no effect on the output of the sign function above. This gives us the following quantization behaviour for a real number \(r\) passing through our quantized layers (\(q\) is the quantized value, \(c\) a given loss value, and \(t_{clip}\) the gradient cancellation threshold):

\[\begin{split}\text{Forward}: q = sign(r)\\ \text{Backward}: \frac{\partial c}{\partial r} = \frac{\partial c}{\partial q} 1_{|r| \leq t_{clip}}\end{split}\]
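
This backward rule can be expressed as a small custom autograd function. The sketch below is only meant to illustrate the behaviour described above and is not bitorch's actual implementation; the name SignSTE and the threshold value of 1.0 are our own choices for this example:

import torch

class SignSTE(torch.autograd.Function):
    """sign() in the forward pass, straight-through estimator in the backward pass."""

    @staticmethod
    def forward(ctx, r, t_clip):
        ctx.save_for_backward(r)
        ctx.t_clip = t_clip
        return torch.where(r >= 0, torch.ones_like(r), -torch.ones_like(r))

    @staticmethod
    def backward(ctx, grad_q):
        (r,) = ctx.saved_tensors
        # pass the incoming gradient straight through, but cancel it where |r| > t_clip
        grad_r = grad_q * (r.abs() <= ctx.t_clip).to(grad_q.dtype)
        return grad_r, None  # no gradient for the threshold itself

r = torch.randn(5, requires_grad=True)
SignSTE.apply(r, 1.0).sum().backward()
print(r.grad)  # 1.0 wherever |r| <= 1.0, otherwise 0.0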

Now let’s see how our layers compare with full-precision layers when used in a simple model. For this we will build a full-precision LeNet and compare its performance with our binarized LeNet version when trained on the MNIST dataset.

Imports

First we need to import some packages to be ready to go…

[3]:
%matplotlib notebook
import sys
import torch
from torch import nn
import matplotlib.animation
import matplotlib.pyplot as plt
from torchvision.datasets import mnist
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

sys.path.append("../../")
from bitorch.layers import (
    QLinear,
    QActivation,
    QConv2d_NoAct,
    WeightGraphicalDebug
)

print("everything imported!")

everything imported!

We want to train on the MNIST dataset, which contains a collection of handwritten digits. For this we first download the dataset and then create data loaders for training and testing:

[6]:
train_dataset = mnist.MNIST(root='./train', train=True, transform=ToTensor(), download=True)
test_dataset = mnist.MNIST(root='./test', train=False, transform=ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=100)
test_loader = DataLoader(test_dataset, batch_size=100)

Now we want to build a model. First we will create a full-precision LeNet version. Bitorch includes a WeightGraphicalDebug layer that can be used to output weight tensors graphically. Below is the implementation of the model:

[3]:
num_convolution_features = 64
num_linear_nodes = 1000
num_output_nodes = 10

model = nn.Sequential(
    # first convolution block
    nn.Conv2d(1, num_convolution_features, kernel_size=5),
    nn.BatchNorm2d(num_convolution_features),
    nn.Tanh(),
    nn.MaxPool2d(2, 2),

    # second convolution block
    WeightGraphicalDebug(
        nn.Conv2d(num_convolution_features, num_convolution_features, kernel_size=5),
        num_outputs = 10
    ),
    nn.BatchNorm2d(num_convolution_features),
    nn.Tanh(),
    nn.MaxPool2d(2, 2),

    nn.Flatten(),

    nn.Linear(num_convolution_features * 4 * 4, num_output_nodes),
    nn.BatchNorm1d(num_output_nodes),
    nn.Tanh(),

    nn.Linear(num_output_nodes, num_output_nodes),
)

You may notice that LeNet consists of two convolution blocks, each containing a convolution layer, a batch norm layer, a tanh activation function and a max pooling layer, followed by linear layers that classify the features extracted by the prior layers.
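
If you want to double-check the in_features of the first linear layer, a small shape check using only standard PyTorch layers (and assuming the usual 28x28 MNIST input size) confirms the 64 * 4 * 4 = 1024 flattened features:

# 28 -> conv 5x5 -> 24 -> maxpool 2 -> 12 -> conv 5x5 -> 8 -> maxpool 2 -> 4
shape_check = nn.Sequential(
    nn.Conv2d(1, num_convolution_features, kernel_size=5),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(num_convolution_features, num_convolution_features, kernel_size=5),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
)
with torch.no_grad():
    print(shape_check(torch.zeros(1, 1, 28, 28)).shape)  # torch.Size([1, 1024])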

We want to capture the weights of the second convolution block, so we wrap the convolution layer in our weight debug layer. Later we will create and pass matplotlib objects to this layer in order to create graphical output.

Now let’s train our model:

[4]:
criterion = CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=0.001)
num_epochs = 10

# this part is for graphic output (still a bit hacky)
fig = plt.figure()
axes = [plt.subplot(2, 5, i) for i in range(1, 11)]
example_img = torch.zeros((3, 3))
example_img[0][0] = 1.0
images = [ax.imshow(example_img, cmap='gray') for ax in axes]

# set graphic objects in debug layer
debug_layer = model[4]
debug_layer.set_figure(fig)
debug_layer.set_images(images)

# we also want the loss to be plotted in a graph
fig2, ax2 = plt.subplots()
loss_plot = None
losses = []

for epoch in range(num_epochs):
    epoch_loss = 0.0

    model.train()
    for idx, (x_train, y_train) in enumerate(train_loader):
        optimizer.zero_grad()

        y_hat = model(x_train)
        loss = criterion(y_hat, y_train)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        if idx % 100 == 0 and idx > 0:
            print(f"    Loss in epoch {epoch + 1} for batch {idx}: {epoch_loss / idx}")
    epoch_loss /= len(train_loader)

    losses.append(epoch_loss)
    loss_plot = ax2.plot(losses)
    fig2.canvas.draw()
    print("total loss of epoch", epoch+1, ":", epoch_loss)
    Loss in epoch 1 for batch 100: 1.4657597327232361
    Loss in epoch 1 for batch 200: 1.2042358693480493
    Loss in epoch 1 for batch 300: 1.0070937711993853
    Loss in epoch 1 for batch 400: 0.85908417083323
    Loss in epoch 1 for batch 500: 0.7480730555951596
total loss of epoch 1 : 0.6600182635088762
    Loss in epoch 2 for batch 100: 0.1879598069936037
    Loss in epoch 2 for batch 200: 0.16950206831097603
    Loss in epoch 2 for batch 300: 0.1564542181789875
    Loss in epoch 2 for batch 400: 0.1464688781928271
    Loss in epoch 2 for batch 500: 0.13780882103741168
total loss of epoch 2 : 0.12892181511657932
    Loss in epoch 3 for batch 100: 0.0810570726916194
    Loss in epoch 3 for batch 200: 0.07514244790188968
    Loss in epoch 3 for batch 300: 0.07176227282111844
    Loss in epoch 3 for batch 400: 0.06929742332547903
    Loss in epoch 3 for batch 500: 0.06648961989209055
total loss of epoch 3 : 0.06311609602688502
    Loss in epoch 4 for batch 100: 0.045978564508259295
    Loss in epoch 4 for batch 200: 0.042802285542711614
    Loss in epoch 4 for batch 300: 0.04097476026664178
    Loss in epoch 4 for batch 400: 0.04011683324119076
    Loss in epoch 4 for batch 500: 0.038841374116018415
total loss of epoch 4 : 0.036812753922616445
    Loss in epoch 5 for batch 100: 0.025593137433752418
    Loss in epoch 5 for batch 200: 0.02476766913663596
    Loss in epoch 5 for batch 300: 0.02389705753264328
    Loss in epoch 5 for batch 400: 0.023903127894736826
    Loss in epoch 5 for batch 500: 0.023372564477846025
total loss of epoch 5 : 0.02267387439768451
    Loss in epoch 6 for batch 100: 0.017221764298155903
    Loss in epoch 6 for batch 200: 0.01620328177930787
    Loss in epoch 6 for batch 300: 0.015696850900227825
    Loss in epoch 6 for batch 400: 0.015374823551392183
    Loss in epoch 6 for batch 500: 0.015205692081712187
total loss of epoch 6 : 0.014946727050313106
    Loss in epoch 7 for batch 100: 0.01362854138482362
    Loss in epoch 7 for batch 200: 0.012483824003720657
    Loss in epoch 7 for batch 300: 0.012193969532381744
    Loss in epoch 7 for batch 400: 0.012442733142524958
    Loss in epoch 7 for batch 500: 0.012343961025122554
total loss of epoch 7 : 0.01268254928290844
    Loss in epoch 8 for batch 100: 0.017260767875704915
    Loss in epoch 8 for batch 200: 0.017801770525984466
    Loss in epoch 8 for batch 300: 0.017268517272702108
    Loss in epoch 8 for batch 400: 0.016472979491227308
    Loss in epoch 8 for batch 500: 0.01520508725894615
total loss of epoch 8 : 0.014251592132495716
    Loss in epoch 9 for batch 100: 0.010850381379714235
    Loss in epoch 9 for batch 200: 0.00870377257000655
    Loss in epoch 9 for batch 300: 0.008475594795309007
    Loss in epoch 9 for batch 400: 0.008686605037073605
    Loss in epoch 9 for batch 500: 0.008817368844989687
total loss of epoch 9 : 0.008879245344918065
    Loss in epoch 10 for batch 100: 0.007755318194394931
    Loss in epoch 10 for batch 200: 0.006581475574639626
    Loss in epoch 10 for batch 300: 0.006072543915361166
    Loss in epoch 10 for batch 400: 0.005625771405466367
    Loss in epoch 10 for batch 500: 0.005263186836265959
total loss of epoch 10 : 0.0051518319802319945

You can see that, as the loss declines, the visualized full-precision weights vary a bit in shade.

Below we evaluate how well our full-precision network performs on a test dataset by calculating its accuracy when confronted with previously unseen examples:

[5]:
model.eval()
test_loss = 0.0
correct = 0.0
# now validate model with test dataset
with torch.no_grad():
    for x_test, y_test in test_loader:

        y_hat = model(x_test)
        test_loss += criterion(y_hat, y_test).item()

        # determine count of correctly predicted labels
        predictions = torch.argmax(y_hat, dim=1)
        correct += torch.sum(y_test == predictions).item()
test_loss /= len(test_loader)
accuracy = correct / (len(test_loader) * test_loader.batch_size)
print("test loss:", test_loss, "test accuracy:", accuracy)
test loss: 0.040575054759101475 test accuracy: 0.9887

With the full-precision version we achieve an accuracy of roughly 98-99%.

Next, we want to evaluate the performance of our binary layers by first building a model:

[4]:
num_convolution_features = 64
num_linear_nodes = 1000
num_output_nodes = 10

model = nn.Sequential(
    # first convolution block
    nn.Conv2d(1, num_convolution_features, kernel_size=5),
    nn.BatchNorm2d(num_convolution_features),
    nn.Tanh(),
    nn.MaxPool2d(2, 2),

    # second convolution block
    # previously: Conv2d(num_convolution_features, num_convolution_features, kernel_size=5),
    QActivation(),
    WeightGraphicalDebug(
        QConv2d_NoAct(num_convolution_features, num_convolution_features, kernel_size=5),
        num_outputs=10),
    nn.BatchNorm2d(num_convolution_features),
    nn.MaxPool2d(2, 2),

    nn.Flatten(),

    # previously: Linear(num_convolution_features * 4 * 4, num_linear_nodes)
    QActivation(),
    QLinear(num_convolution_features * 4 * 4, num_linear_nodes),
    nn.BatchNorm1d(num_linear_nodes),
    nn.Tanh(),

    nn.Linear(num_linear_nodes, 10),
)

In the quantized LeNet version above we simply replaced the convolution, linear and activation layers, starting from the second convolution block, with our corresponding quantized versions.

Note that we also added an additional QActivation layer in front of the quantized convolution and linear layers. This is necessary because we want to pass already binarized input values to our quantized layers: the network should learn on binarized data, so it would not make much sense to feed full-precision input values into a layer with binarized weights.

The default QConv2d layer from bitorch already includes this activation layer, but for better understanding we used the dedicated QConv2d_NoAct version here.
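
In other words, the second convolution block above could conceptually also be written with the all-in-one layer. The comparison below is only an illustration of that relationship, assuming default quantization settings (QConv2d would additionally have to be imported from bitorch, presumably from bitorch.layers as well):

# explicit quantized input activation followed by a convolution without its own activation:
explicit_block = nn.Sequential(
    QActivation(),
    QConv2d_NoAct(num_convolution_features, num_convolution_features, kernel_size=5),
)
# conceptually equivalent compact form:
# compact_block = QConv2d(num_convolution_features, num_convolution_features, kernel_size=5)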

Now we train this model using the exact same code as before:

[7]:
criterion = CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=0.001)
num_epochs = 10

# this part is for graphic output (still a bit hacky)
fig = plt.figure()
axes = [plt.subplot(2, 5, i) for i in range(1, 11)]
example_img = torch.zeros((3, 3))
example_img[0][0] = 1.0
images = [ax.imshow(example_img, cmap='gray') for ax in axes]

# set graphic objects in debug layer
debug_layer = model[5]
debug_layer.set_figure(fig)
debug_layer.set_images(images)

# we also want the loss to be plotted in a graph
fig2, ax2 = plt.subplots()
loss_plot = None
losses = []

for epoch in range(num_epochs):
    epoch_loss = 0.0

    model.train()
    for idx, (x_train, y_train) in enumerate(train_loader):
        optimizer.zero_grad()

        y_hat = model(x_train)
        loss = criterion(y_hat, y_train)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        if idx % 100 == 0 and idx > 0:
            print(f"    Loss in epoch {epoch + 1} for batch {idx}: {epoch_loss / idx}")
    epoch_loss /= len(train_loader)

    losses.append(epoch_loss)
    loss_plot = ax2.plot(losses)
    fig2.canvas.draw()
    print("total loss of epoch", epoch+1, ":", epoch_loss)
    Loss in epoch 1 for batch 100: 0.3191458823531866
    Loss in epoch 1 for batch 200: 0.22882670355960727
    Loss in epoch 1 for batch 300: 0.1879387534254541
    Loss in epoch 1 for batch 400: 0.16579986134544014
    Loss in epoch 1 for batch 500: 0.1520107238981873
total loss of epoch 1 : 0.138434021412977
    Loss in epoch 2 for batch 100: 0.0782842071680352
    Loss in epoch 2 for batch 200: 0.06797829900635406
    Loss in epoch 2 for batch 300: 0.06449848238068322
    Loss in epoch 2 for batch 400: 0.06472309559467249
    Loss in epoch 2 for batch 500: 0.06367404395993799
total loss of epoch 2 : 0.06163016845831104
    Loss in epoch 3 for batch 100: 0.05601555611938238
    Loss in epoch 3 for batch 200: 0.05255320848664269
    Loss in epoch 3 for batch 300: 0.04967744196668112
    Loss in epoch 3 for batch 400: 0.05016258191855741
    Loss in epoch 3 for batch 500: 0.049570631820824926
total loss of epoch 3 : 0.04757844619802199
    Loss in epoch 4 for batch 100: 0.04511312256800011
    Loss in epoch 4 for batch 200: 0.03883477459778078
    Loss in epoch 4 for batch 300: 0.03895589078854149
    Loss in epoch 4 for batch 400: 0.04154080096050165
    Loss in epoch 4 for batch 500: 0.04224612207897008
total loss of epoch 4 : 0.04081949879558427
    Loss in epoch 5 for batch 100: 0.037817918587243184
    Loss in epoch 5 for batch 200: 0.031362870420853145
    Loss in epoch 5 for batch 300: 0.031488330888872346
    Loss in epoch 5 for batch 400: 0.033103330412122885
    Loss in epoch 5 for batch 500: 0.03294398140651174
total loss of epoch 5 : 0.03274958998075211
    Loss in epoch 6 for batch 100: 0.030703408132540064
    Loss in epoch 6 for batch 200: 0.027845895515347364
    Loss in epoch 6 for batch 300: 0.028419735451655773
    Loss in epoch 6 for batch 400: 0.029805011277421727
    Loss in epoch 6 for batch 500: 0.03155679095414234
total loss of epoch 6 : 0.031322444301282906
    Loss in epoch 7 for batch 100: 0.02826335902354913
    Loss in epoch 7 for batch 200: 0.027949679844023193
    Loss in epoch 7 for batch 300: 0.027713744351252293
    Loss in epoch 7 for batch 400: 0.028659114930414942
    Loss in epoch 7 for batch 500: 0.028272173204110004
total loss of epoch 7 : 0.027773375253121534
    Loss in epoch 8 for batch 100: 0.026254299548309063
    Loss in epoch 8 for batch 200: 0.023044601824731215
    Loss in epoch 8 for batch 300: 0.0225934889868707
    Loss in epoch 8 for batch 400: 0.022798736034001194
    Loss in epoch 8 for batch 500: 0.022899673691936188
total loss of epoch 8 : 0.02252530144642151
    Loss in epoch 9 for batch 100: 0.025631154326256365
    Loss in epoch 9 for batch 200: 0.021320976483111737
    Loss in epoch 9 for batch 300: 0.019907770642263737
    Loss in epoch 9 for batch 400: 0.020488056090616737
    Loss in epoch 9 for batch 500: 0.021229280417726842
total loss of epoch 9 : 0.021772936817239193
    Loss in epoch 10 for batch 100: 0.021776572078233584
    Loss in epoch 10 for batch 200: 0.021522397231747162
    Loss in epoch 10 for batch 300: 0.022618910595289587
    Loss in epoch 10 for batch 400: 0.022459085493155725
    Loss in epoch 10 for batch 500: 0.02168780977021379
total loss of epoch 10 : 0.02128381393169396

And now we also evaluate the performance of our quantized model:

[8]:
model.eval()
test_loss = 0.0
correct = 0.0
# now validate model with test dataset
with torch.no_grad():
    for x_test, y_test in test_loader:

        y_hat = model(x_test)
        test_loss += criterion(y_hat, y_test).item()

        # determine count of correctly predicted labels
        predictions = torch.argmax(y_hat, dim=1)
        correct += torch.sum(y_test == predictions).item()
test_loss /= len(test_loader)
accuracy = correct / (len(test_loader) * test_loader.batch_size)
print("test loss:", test_loss, "test accuracy:", accuracy)
test loss: 0.05959067487354787 test accuracy: 0.9839

As we can see, the binarized version of LeNet performs slightly worse than the full-precision version. But considering that we reduced the number of possible weight and input values from full precision (i.e. 2^32 possible values for 32-bit floats) to binary (2 possible values), the loss of accuracy is astonishingly small.
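
To put this trade-off into perspective, here is a rough back-of-the-envelope calculation for the weights of the large QLinear layer above (1024 x 1000 values). The exact savings depend on how the framework stores and packs the bits, so treat the numbers as an estimate:

num_weights = num_convolution_features * 4 * 4 * num_linear_nodes  # 1024 * 1000
full_precision_bytes = num_weights * 4  # 32-bit floats, 4 bytes each
binary_bytes = num_weights // 8         # 1 bit per weight, tightly packed
print(f"{full_precision_bytes / 1e6:.1f} MB vs. {binary_bytes / 1e3:.0f} KB")  # ~4.1 MB vs. ~128 KB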

With some further tweaks to the binarized LeNet version we are able to keep this accuracy gap below 1-2% while still working with a network that uses only binary weights and activations.

Our bitorch framework does not yet contain the optimized operations needed to fully exploit the binarized feature maps, but it already demonstrates the potential performance gains of binary neural networks while still competing with state-of-the-art full-precision networks.