Building a model


Marie-Hélène Burle

Key to creating neural networks in PyTorch is the torch.nn package, which provides the nn.Module class. A model is defined as a subclass of nn.Module with a forward method that takes some input and returns an output.

Let’s build a neural network to classify the MNIST handwritten digits.

First, we need to define the architecture of the network. There are many types of architectures. For images, convolutional neural networks (CNNs) are well suited.

In Python, you can define a subclass of an existing class with:

class YourSubclass(BaseClass):
    <definition of your subclass>        

The subclass is derived from the base class and inherits its properties. PyTorch contains the class torch.nn.Module which is used as the base class when defining a neural network.
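As a minimal illustration of inheritance in plain Python (the class names Shape and Square are hypothetical and unrelated to PyTorch), a subclass can reuse a method from its base class while supplying its own attributes and methods:

```python
class Shape:
    """Base class with a method that subclasses inherit."""

    def __init__(self, name):
        self.name = name

    def describe(self):
        return f"{self.name} with area {self.area()}"

    def area(self):
        raise NotImplementedError("Subclasses must define area()")


class Square(Shape):
    """Subclass: inherits describe() and overrides area()."""

    def __init__(self, side):
        # Run the base class initializer first
        super().__init__("square")
        self.side = side

    def area(self):
        return self.side ** 2


sq = Square(3)
print(sq.describe())          # describe() is inherited from Shape
print(isinstance(sq, Shape))  # a Square is also a Shape
```

The same pattern applies to nn.Module: the subclass calls the base class initializer and supplies its own forward method.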

# Load packages
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    # Define the architecture of the network
    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels,
        # 5x5 square convolution kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 channels of 5x5 feature maps (for a 32x32 input)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    # Set the flow of data through the network for the forward pass
    # x represents the data
    def forward(self, x):
        # Max pooling over a (2, 2) window
        # F.relu is the rectified-linear activation function
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square, you can specify with a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # flatten all dimensions except the batch dimension
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
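The 16 * 5 * 5 input size of fc1 follows from how each layer changes the height and width of the image. The sketch below traces that arithmetic in plain Python, assuming no padding and a stride of 1 for the convolutions (the nn.Conv2d defaults) and 2x2 max pooling. The helper names conv_out, pool_out, and trace are hypothetical.

```python
def conv_out(size, kernel):
    """Side length after a convolution with stride 1 and no padding."""
    return size - kernel + 1


def pool_out(size, window=2):
    """Side length after max pooling with a stride equal to the window."""
    return size // window


def trace(size):
    """Follow one side length through conv1, pooling, conv2, pooling."""
    size = pool_out(conv_out(size, 5))  # conv1 (5x5 kernel) then 2x2 pooling
    size = pool_out(conv_out(size, 5))  # conv2 (5x5 kernel) then 2x2 pooling
    return size


for side in (32, 28):
    final = trace(side)
    print(f"{side}x{side} input -> {final}x{final} feature maps, "
          f"{16 * final * final} features after flattening")
```

Under these assumptions, a 32x32 input gives the 400 features that fc1 expects, whereas a 28x28 input (the native MNIST image size) gives 256, so such images would need to be padded to 32x32 or fc1 resized to 16 * 4 * 4.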

Let’s create an instance of Net and print its structure:

net = Net()
print(net)
Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

The learnable parameters of the model are returned by net.parameters():

params = list(net.parameters())
print(params[0].size())  # conv1's .weight
torch.Size([6, 1, 5, 5])
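
The size [6, 1, 5, 5] corresponds to 6 output channels, 1 input channel, and a 5x5 kernel. As a hand check, the following sketch recomputes the number of learnable values in each layer from the layer definitions above, without using PyTorch. The helper names conv_params and linear_params are hypothetical, and each layer is assumed to have a bias term (the default).

```python
def conv_params(in_ch, out_ch, kernel):
    """Weights (out_ch x in_ch x kernel x kernel) plus one bias per output channel."""
    return out_ch * in_ch * kernel * kernel + out_ch


def linear_params(in_features, out_features):
    """Weights (out_features x in_features) plus one bias per output feature."""
    return out_features * in_features + out_features


# Layer sizes copied from the Net definition
layers = {
    "conv1": conv_params(1, 6, 5),
    "conv2": conv_params(6, 16, 5),
    "fc1": linear_params(16 * 5 * 5, 120),
    "fc2": linear_params(120, 84),
    "fc3": linear_params(84, 10),
}

for name, count in layers.items():
    print(f"{name}: {count}")
print("total:", sum(layers.values()))
```

The same total could be obtained from the model itself with sum(p.numel() for p in net.parameters()).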