# Load packages
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    """Small convolutional network: two conv/pool stages and three linear layers.

    The first convolution takes a single-channel image and the last linear
    layer produces 10 output values per sample. ``fc1`` expects
    ``16 * 5 * 5`` flattened features, so the input must be sized such that
    the feature map is 5x5 after the second pooling step (e.g. 32x32 images;
    28x28 images would need to be padded or resized first).
    """

    # Define the architecture of the network
    def __init__(self):
        """Create the convolutional and fully connected layers."""
        super().__init__()
        # 1 input image channel, 6 output channels,
        # 5x5 square convolution kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5*5 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    # Set the flow of data through the network for the forward pass
    # x represents the data
    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Batch of single-channel images, shaped
                ``(batch, 1, height, width)``.

        Returns:
            Tensor of shape ``(batch, 10)`` holding the raw outputs of
            ``fc3`` (no activation is applied to the last layer).
        """
        # Max pooling over a (2, 2) window
        # F.relu is the rectified-linear activation function
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square, you can specify with a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # flatten all dimensions except the batch dimension
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
Building a model
Key to creating neural networks in PyTorch is the torch.nn package, which contains the nn.Module base class.
A network is defined as a subclass of nn.Module with a forward method which returns an output from some input.
Let’s build a neural network to classify images from the MNIST dataset.
First, we need to define the architecture of the network. There are many types of architectures. For images, convolutional neural networks (CNNs) are well suited.
In Python, you can define a subclass of an existing class with:
class YourSubclass(BaseClass):
    <definition of your subclass>
The subclass is derived from the base class and inherits its properties. PyTorch contains the class torch.nn.Module
which is used as the base class when defining a neural network.
Let’s create an instance of Net
and print its structure:
# Instantiate the network and print its layer structure
net = Net()
print(net)
Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
# Collect the learnable parameter tensors of the network
params = list(net.parameters())
print(len(params))
print(params[0].size())  # conv1's .weight
10
torch.Size([6, 1, 5, 5])