Implementing Convolutional AutoEncoders using PyTorch
Continuing from the previous story, in this post we will build a Convolutional AutoEncoder from scratch on the MNIST dataset using PyTorch.
First of all, we import all the required dependencies:
import os
import torch
import numpy as np
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import seaborn as sns
Now we preset some hyper-parameters and load the dataset, which is available through torchvision. If the dataset is not on your local machine, it will be downloaded from the server.
EPOCHS = 100
BATCH_SIZE = 128
LR = 1e-3

IMAGE_TRANSFORMS = transforms.Compose([
    transforms.ToTensor(),
    # One-element tuples: a single mean/std for the single grayscale channel
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Resize((28, 28))
])

DATASET = MNIST('./data', transform=IMAGE_TRANSFORMS, download=True)

DATALOADER = DataLoader(DATASET, batch_size=BATCH_SIZE, shuffle=True)
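To make the effect of the normalization step concrete, here is a minimal sketch that applies the same arithmetic, (x - mean) / std with mean = std = 0.5, to a few hand-picked pixel values. The sample tensor is an assumption made for illustration, not actual MNIST data. The resulting range of [-1, 1] lines up with the output range of the Tanh layer that ends the decoder.

```python
import torch

# Hypothetical pixel values in [0, 1], as ToTensor() would produce them.
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])

# Same per-channel arithmetic as Normalize with mean 0.5 and std 0.5.
mean, std = 0.5, 0.5
normalized = (pixels - mean) / std

print(normalized)  # values now lie in [-1, 1]
```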
Now we define our AutoEncoder class, which inherits from PyTorch's nn.Module. We also define the forward method of the class for a forward pass through the network.
class AUTOENCODER(nn.Module):
    def __init__(self):
        super(AUTOENCODER, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=3, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(16, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=1)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 3, stride=2),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
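It can help to check that the decoder really returns a 28x28 image. The sketch below traces the spatial size through each layer using the standard output-size formulas for convolution/pooling and transposed convolution (assuming dilation 1 and no output_padding, which are the defaults). The helper names conv_out and deconv_out are hypothetical and not part of PyTorch.

```python
def conv_out(n, kernel, stride, padding=0):
    """Output size of Conv2d / MaxPool2d along one spatial dimension (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def deconv_out(n, kernel, stride, padding=0):
    """Output size of ConvTranspose2d along one dimension (no output_padding)."""
    return (n - 1) * stride - 2 * padding + kernel


sizes = [28]  # MNIST images are 28x28
# Encoder
sizes.append(conv_out(sizes[-1], 3, 3, 1))    # Conv2d(1, 16, 3, stride=3, padding=1)
sizes.append(conv_out(sizes[-1], 2, 2))       # MaxPool2d(2, stride=2)
sizes.append(conv_out(sizes[-1], 3, 2, 1))    # Conv2d(16, 8, 3, stride=2, padding=1)
sizes.append(conv_out(sizes[-1], 2, 1))       # MaxPool2d(2, stride=1)
# Decoder
sizes.append(deconv_out(sizes[-1], 3, 2))     # ConvTranspose2d(8, 16, 3, stride=2)
sizes.append(deconv_out(sizes[-1], 5, 3, 1))  # ConvTranspose2d(16, 8, 5, stride=3, padding=1)
sizes.append(deconv_out(sizes[-1], 2, 2, 1))  # ConvTranspose2d(8, 1, 2, stride=2, padding=1)

print(sizes)  # [28, 10, 5, 3, 2, 5, 15, 28]
```

The encoder therefore compresses each 1x28x28 image (784 values) into an 8x2x2 code (32 values), and the decoder expands it back to 28x28.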
To save the images generated by the decoder part of the AutoEncoder, we create a folder:
if not os.path.exists('./dc_img'):
    os.mkdir('./dc_img')
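A slightly more compact alternative, sketched here assuming Python 3.2 or later, is os.makedirs with exist_ok=True, which folds the existence check and the creation into one call.

```python
import os

# Creates ./dc_img if it is missing; does nothing if it already exists.
os.makedirs('./dc_img', exist_ok=True)
```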
We use mean squared error as the loss function to train the network. We store the images generated by the network at every 10th epoch and save them in the folder that we created previously.
Auto_enc = AUTOENCODER().cuda()

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(Auto_enc.parameters(), lr=LR,
                             weight_decay=1e-5)

total_loss = 0
LOSSES = []

for epoch in range(EPOCHS):
    total_loss = 0
    for data in DATALOADER:
        img, _ = data  # labels are not needed for reconstruction
        img = img.cuda()  # tensors can be used directly; no Variable wrapper is required
        output = Auto_enc(img)
        loss = criterion(output, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() yields a plain Python float, so LOSSES can be plotted later
        total_loss += loss.item()

    if epoch % 10 == 0:
        print('epoch [{}/{}], loss:{:.4f}'.format(epoch+1, EPOCHS, total_loss))
        pic = output.cpu().data
        save_image(pic, './dc_img/image_{}.png'.format(epoch))
    LOSSES.append(total_loss)
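One detail worth noting: the inputs were normalized to [-1, 1] and the decoder ends in Tanh, so the tensors handed to save_image contain negative values, which save_image clips to 0 by default. A common remedy is to map the output back to [0, 1] before saving. The helper below is a minimal sketch of that idea; the name to_img and the random stand-in batch are assumptions for illustration and are not part of the code above.

```python
import torch


def to_img(x):
    """Map a batch from [-1, 1] back to [0, 1], shaped as (N, 1, 28, 28)."""
    x = 0.5 * (x + 1)   # inverse of Normalize with mean 0.5 and std 0.5
    x = x.clamp(0, 1)   # guard against small numerical overshoot
    return x.view(x.size(0), 1, 28, 28)


# Stand-in for a decoder output batch: random values in [-1, 1].
fake_output = torch.rand(4, 1, 28, 28) * 2 - 1
pic = to_img(fake_output)
print(pic.shape)
```

In the training loop this would be used as save_image(to_img(output.cpu().data), ...) in place of saving the raw output.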
After running the previous code snippet, we have trained our AutoEncoder. Below is the loss curve of the network.
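A curve like this can be drawn with matplotlib. The sketch below assumes the per-epoch losses are stored as plain Python floats (for example, accumulated with loss.item()); the function name plot_losses, the output file name, and the placeholder values are all hypothetical. The placeholder list only makes the snippet runnable on its own and does not represent results from this training run.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the script runs without a display
import matplotlib.pyplot as plt


def plot_losses(losses, path="loss_curve.png"):
    """Plot one summed loss value per epoch and save the figure to `path`."""
    fig, ax = plt.subplots()
    ax.plot(range(1, len(losses) + 1), losses)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Summed MSE loss")
    ax.set_title("AutoEncoder training loss")
    fig.savefig(path)
    plt.close(fig)
    return path


# Placeholder numbers only, NOT results from the training run.
placeholder_losses = [1.0 / (epoch + 1) for epoch in range(10)]
saved_to = plot_losses(placeholder_losses)
```

After training, calling plot_losses(LOSSES) would plot the values collected in the loop.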
Below are some of the images obtained from the network.
As we can see, the AutoEncoder model started reconstructing the images from the very start of training. After the first epoch the reconstruction was still poor, and it kept improving until the 40th epoch. After complete training, as the image generated after the 90th epoch and the results on testing show, the model reconstructs images that closely match the original inputs.