Transfer Learning for Deep Learning with PyTorch (Alien vs. Predator)
What is Transfer Learning?
Transfer learning is a technique that reuses a model trained on one task to solve another, related task. Training a network from scratch takes a lot of data and time, so it is common to start from the weights of a network that has already been trained (for example, on ImageNet). You keep that network and its weights and replace only the last layer so that it fits your problem. The advantage is that a small dataset is enough to train the new last layer.
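The freeze-and-replace idea can be illustrated with a tiny stand-in network. This is a minimal sketch that assumes no downloaded weights: `backbone` is a hypothetical, randomly initialized model playing the role of a pre-trained one. After one backward pass on a dummy batch, only the new head has gradients, which is why only the last layer gets trained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a pre-trained network body. In practice it would
# come with pre-trained weights (e.g. from torchvision.models).
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())

# 1) Freeze every backbone parameter so the optimizer cannot change it.
for param in backbone.parameters():
    param.requires_grad = False

# 2) Attach a new last layer sized for our 2 classes (trainable by default).
new_head = nn.Linear(16, 2)
model = nn.Sequential(backbone, new_head)

# 3) One forward/backward pass on a dummy batch.
inputs = torch.randn(4, 8)
labels = torch.tensor([0, 1, 0, 1])
loss = nn.CrossEntropyLoss()(model(inputs), labels)
loss.backward()

# Only the new head's weight and bias receive gradients.
for name, param in model.named_parameters():
    print(name, param.requires_grad, param.grad is not None)
```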
Before you start, you need to understand the dataset you are going to use. In this part, you will classify images of Aliens and Predators from a dataset of nearly 700 images. Because transfer learning reuses pre-trained weights, this technique does not need a large amount of training data. You can download the dataset from Kaggle: Alien vs. Predator.
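Before loading anything, it can help to check that the downloaded data has the folder layout the loader expects. This sketch assumes the Kaggle archive is unpacked into `alien_pred/` with `train/` and `validation/` folders, each holding one subfolder per class (such as `alien/` and `predator/`); `count_images` is a hypothetical helper, and the accepted file extensions are an assumption.

```python
import os

def count_images(data_dir, exts=(".jpg", ".jpeg", ".png")):
    """Return {split: {class_name: number_of_images}} for an ImageFolder-style tree."""
    counts = {}
    for split in sorted(os.listdir(data_dir)):
        split_dir = os.path.join(data_dir, split)
        if not os.path.isdir(split_dir):
            continue  # ignore stray files next to the split folders
        counts[split] = {}
        for cls in sorted(os.listdir(split_dir)):
            cls_dir = os.path.join(split_dir, cls)
            if os.path.isdir(cls_dir):
                # Count only files with an image extension (case-insensitive)
                counts[split][cls] = sum(
                    1 for f in os.listdir(cls_dir) if f.lower().endswith(exts)
                )
    return counts
```

Calling `print(count_images("alien_pred"))` should list both splits, each with one entry per class; a missing split or class folder shows up immediately.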
Step 1) Load the Data
The first step is to load the data and transform the images so that they match the network's requirements. You will load the data from a folder with torchvision.datasets.ImageFolder. The data folder must already be split into train and validation subfolders, each containing one subfolder per class; ImageFolder iterates over these subfolders and uses their names as the labels. The transformation process crops the images from the center, performs a random horizontal flip on the training images, converts them to tensors, and finally normalizes them.
```python
from __future__ import print_function, division

import os
import time

import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

data_dir = "alien_pred"  # contains train/ and validation/ subfolders
input_shape = 224
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

# Data transformation: center crop, random flip (training only),
# convert to a tensor, then normalize each channel
data_transforms = {
    'train': transforms.Compose([
        transforms.CenterCrop(input_shape),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ]),
    'validation': transforms.Compose([
        transforms.CenterCrop(input_shape),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ]),
}

# ImageFolder assigns one label per class subfolder
image_datasets = {
    x: datasets.ImageFolder(
        os.path.join(data_dir, x),
        transform=data_transforms[x]
    )
    for x in ['train', 'validation']
}

dataloaders = {
    x: torch.utils.data.DataLoader(
        image_datasets[x], batch_size=32,
        shuffle=True, num_workers=4
    )
    for x in ['train', 'validation']
}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'validation']}
print(dataset_sizes)

class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
```
Let's visualize the dataset. The visualization process takes the next batch of images and labels from the training data loader and displays it with matplotlib.
Next, you will build the model. In this process, you will use VGG19 from the torchvision module: torchvision.models loads vgg19 with its ImageNet pre-trained weights. After that, you freeze the layers so that they are not trainable, and replace the last layer with a Linear layer that fits our problem, which has 2 classes. You use CrossEntropyLoss as the classification loss function, and for the optimizer you use SGD with a learning rate of 0.001 and a momentum of 0.9.
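```python
## Load the model based on VGG19 with ImageNet weights
## (torchvision >= 0.13; older releases use vgg19(pretrained=True))
vgg_based = torchvision.models.vgg19(
    weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1
)

## Freeze the layers so the pre-trained weights are not updated
for param in vgg_based.parameters():
    param.requires_grad = False

# Modify the last layer: replace the 1000-class ImageNet output
# with a new Linear layer that has one output per class
number_features = vgg_based.classifier[6].in_features
features = list(vgg_based.classifier.children())[:-1]  # Remove last layer
features.extend([torch.nn.Linear(number_features, len(class_names))])
vgg_based.classifier = torch.nn.Sequential(*features)

vgg_based = vgg_based.to(device)
print(vgg_based)

criterion = torch.nn.CrossEntropyLoss()
# Only the new last layer is trainable, so pass just those parameters
optimizer_ft = optim.SGD(
    [p for p in vgg_based.parameters() if p.requires_grad],
    lr=0.001, momentum=0.9
)
```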
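To confirm that only the new last layer will be trained, you can count the trainable parameters. The new head is Linear(4096, 2), so it should hold 4096 × 2 weights + 2 biases = 8194 trainable parameters, while everything else stays frozen. `count_parameters` is a hypothetical helper; the block checks it on a small toy model so it runs without downloading VGG19.

```python
import torch.nn as nn

def count_parameters(model):
    """Return (trainable, total) numbers of parameter values in a model."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total

# Toy check: freeze a 10 -> 4 "backbone", then attach a trainable 4 -> 2 head.
backbone = nn.Sequential(nn.Linear(10, 4), nn.ReLU())
for p in backbone.parameters():
    p.requires_grad = False
toy_model = nn.Sequential(backbone, nn.Linear(4, 2))

print(count_parameters(toy_model))  # (10, 54): 4*2 + 2 trainable, 10*4 + 4 frozen
```

Applied to the model above, `count_parameters(vgg_based)[0]` should be 8194 when there are 2 classes.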
Finally, let's start the training process with the number of epochs set to 25, then evaluate the model on the validation data. At each training step, the model takes a batch of inputs and predicts the outputs. The predicted outputs are passed to the criterion to calculate the loss. Calling backward on the loss lets autograd compute the gradients, and the optimizer then uses those gradients to update the trainable parameters.
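The training procedure described above can be sketched as follows. `train_model` is a hypothetical function name; it assumes the `dataloaders` and `dataset_sizes` dictionaries with 'train' and 'validation' keys from step 1, evaluates on the validation set after every epoch, and keeps the weights with the best validation accuracy (a common choice, not necessarily the original author's).

```python
import copy
import time
import torch

def train_model(model, criterion, optimizer, dataloaders, dataset_sizes,
                device, num_epochs=25):
    """Train on 'train', evaluate on 'validation' each epoch, keep the best weights."""
    since = time.time()
    best_weights = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        for phase in ['train', 'validation']:
            if phase == 'train':
                model.train()  # enable dropout etc.
            else:
                model.eval()
            running_loss, running_corrects = 0.0, 0

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()

                # Track gradients only while training
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)            # forward pass
                    loss = criterion(outputs, labels)  # compute the loss
                    preds = outputs.argmax(dim=1)
                    if phase == 'train':
                        loss.backward()                # autograd computes gradients
                        optimizer.step()               # update the weights

                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print(f"Epoch {epoch + 1}/{num_epochs} {phase} "
                  f"loss: {epoch_loss:.4f} acc: {epoch_acc:.4f}")

            if phase == 'validation' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_weights = copy.deepcopy(model.state_dict())

    print(f"Finished in {time.time() - since:.0f}s, best validation acc: {best_acc:.4f}")
    model.load_state_dict(best_weights)
    return model
```

With the objects from the previous steps, a call would look like `vgg_based = train_model(vgg_based, criterion, optimizer_ft, dataloaders, dataset_sizes, device, num_epochs=25)`.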
In the model visualization step, the trained network is tested on a batch of images to predict their labels. The output of the model is then visualized with matplotlib.
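A possible sketch of this step: `predict_batch` and `visualize_predictions` are hypothetical helper names, `num_images=6` is an arbitrary default, and mean/std default to the 0.5 values from step 1. `model.eval()` and `torch.no_grad()` switch off dropout and gradient tracking during inference, and the model's previous mode is restored afterwards.

```python
import matplotlib.pyplot as plt
import numpy as np
import torch

def predict_batch(model, inputs, class_names, device):
    """Return the predicted class name for each input in a batch."""
    was_training = model.training
    model.eval()                   # e.g. disables dropout
    with torch.no_grad():          # inference only: no gradients needed
        preds = model(inputs.to(device)).argmax(dim=1).cpu()
    model.train(was_training)      # restore the previous mode
    return [class_names[int(p)] for p in preds]

def visualize_predictions(model, dataloader, class_names, device, num_images=6,
                          mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Plot up to num_images images from one batch, titled with the predicted class."""
    inputs, _ = next(iter(dataloader))
    inputs = inputs[:num_images]
    names = predict_batch(model, inputs, class_names, device)

    fig, axes = plt.subplots(1, len(names), figsize=(3 * len(names), 3), squeeze=False)
    for ax, img, name in zip(axes[0], inputs, names):
        # Undo Normalize and move channels last for matplotlib
        img = img.numpy().transpose((1, 2, 0)) * np.array(std) + np.array(mean)
        ax.imshow(np.clip(img, 0, 1))
        ax.set_title(f"predicted: {name}")
        ax.axis("off")
    plt.show()
    return names
```

For the trained model, `visualize_predictions(vgg_based, dataloaders['validation'], class_names, device)` plots a few validation images with their predicted labels.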
Summary
So, let's summarize everything! PyTorch is a growing deep learning framework suited both to beginners and to research. It offers fast computation, a dynamic computation graph, GPU support, and a Python-first interface (its core is implemented in C++ and CUDA). You can define your own network modules with ease and write the training process as a simple loop. This makes PyTorch a good way for beginners to learn deep learning, while professional researchers benefit from its speed and from autograd, which computes gradients automatically through the dynamic graph.