Transfer Learning for Deep Learning

Last updated 5 years ago

Transfer Learning for Deep Learning with PyTorch (Alien vs. Predator)

What is Transfer Learning?

Transfer learning is the technique of reusing a model trained on one task to solve a different but related task. Training a network from scratch takes a lot of data and time, so it is common to start from another network's pretrained weights instead. You take the pretrained network, keep its weights, and replace the last layer so that it fits your problem. The advantage is that a small dataset is enough to train the new last layer.

Loading Dataset

Step 1) Load the Data

The first step is to load the data and transform the images so that they match the network's input requirements. You will load the data from a folder with torchvision.datasets.ImageFolder. The code creates one dataset for each of the train and validation sub-folders of data_dir. The transformation pipeline crops each image from the center, converts it to a tensor, and normalizes it.

from __future__ import print_function, division
import os
import time
import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

data_dir = "alien_pred"
input_shape = 224
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

#data transformation
data_transforms = {
   'train': transforms.Compose([
       transforms.CenterCrop(input_shape),
       transforms.ToTensor(),
       transforms.Normalize(mean, std)
   ]),
   'validation': transforms.Compose([
       transforms.CenterCrop(input_shape),
       transforms.ToTensor(),
       transforms.Normalize(mean, std)
   ]),
}

image_datasets = {
   x: datasets.ImageFolder(
       os.path.join(data_dir, x),
       transform=data_transforms[x]
   )
   for x in ['train', 'validation']
}

dataloaders = {
   x: torch.utils.data.DataLoader(
       image_datasets[x], batch_size=32,
       shuffle=True, num_workers=4
   )
   for x in ['train', 'validation']
}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'validation']}

print(dataset_sizes)
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
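ImageFolder reads one sub-folder per class and numbers the classes in sorted order of their folder names. The sketch below mimics that mapping with the standard library only, so it runs without torchvision. The class folder names are assumptions for illustration; check the downloaded dataset for the real ones.

```python
import os
import tempfile

# Hypothetical class folder names; the real dataset's names may differ.
ASSUMED_CLASSES = ["predator", "alien"]
SPLITS = ["train", "validation"]


def build_layout(root):
    """Create root/<split>/<class>/ directories, the layout ImageFolder reads."""
    for split in SPLITS:
        for cls in ASSUMED_CLASSES:
            os.makedirs(os.path.join(root, split, cls), exist_ok=True)


def class_to_index(split_dir):
    """Mimic ImageFolder: sort sub-directory names, then number them from 0."""
    classes = sorted(
        entry for entry in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, entry))
    )
    return {name: index for index, name in enumerate(classes)}


with tempfile.TemporaryDirectory() as root:
    build_layout(root)
    mapping = class_to_index(os.path.join(root, "train"))

print(mapping)
```

With these assumed names, class_names in the tutorial code would list the alien class first, because the order follows the sorted folder names rather than the order in which the folders were created.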

Let's visualize our dataset. The code takes the next batch of images and labels from the training data loader and displays the first 16 images with matplotlib.

images, labels = next(iter(dataloaders['train']))

rows = 4
columns = 4
fig=plt.figure()
for i in range(16):
   fig.add_subplot(rows, columns, i+1)
   plt.title(class_names[labels[i]])
   img = images[i].numpy().transpose((1, 2, 0))
   img = std * img + mean
   plt.imshow(img)
plt.show()
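The line img = std * img + mean undoes transforms.Normalize so that the images display with their original colors. As a minimal sketch in plain Python (one RGB pixel instead of a whole image, with helper names invented for this example), the round trip looks like this:

```python
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]


def normalize(pixel):
    """Apply (value - mean) / std per channel, as transforms.Normalize does."""
    return [(v - m) / s for v, m, s in zip(pixel, mean, std)]


def denormalize(pixel):
    """Invert the normalization: value * std + mean per channel."""
    return [v * s + m for v, m, s in zip(pixel, mean, std)]


original = [0.0, 0.25, 1.0]       # one RGB pixel in ToTensor's [0, 1] range
normalized = normalize(original)  # values now lie in [-1, 1]
restored = denormalize(normalized)

print(normalized)
print(restored)
```

With a mean and std of 0.5, normalization maps the [0, 1] range produced by ToTensor onto [-1, 1], and the inverse maps it back.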

Step 2) Define Model

In this step, you will use VGG19 from the torchvision module. You will load it through torchvision.models.vgg19 with pretrained set to True. After that, you will freeze the layers so that they are not trainable. You then replace the last layer of the classifier with a new Linear layer that has 2 outputs, one per class. You will use CrossEntropyLoss as the classification loss function and SGD as the optimizer, with a learning rate of 0.001 and a momentum of 0.9.

## Load the model based on VGG19
vgg_based = torchvision.models.vgg19(pretrained=True)

## freeze the layers
for param in vgg_based.parameters():
   param.requires_grad = False

# Modify the last layer
number_features = vgg_based.classifier[6].in_features
features = list(vgg_based.classifier.children())[:-1] # Remove last layer
features.extend([torch.nn.Linear(number_features, len(class_names))])
vgg_based.classifier = torch.nn.Sequential(*features)

vgg_based = vgg_based.to(device)

print(vgg_based)

criterion = torch.nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(vgg_based.parameters(), lr=0.001, momentum=0.9)
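To make the loss concrete, here is a small plain-Python sketch of what cross-entropy computes for a single sample: the negative log-softmax of the score for the true class. The function name and the example scores are made up for illustration. torch.nn.CrossEntropyLoss applies the same formula to each sample and, by default, averages over the batch.

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax of the target class for one sample."""
    highest = max(logits)  # subtract the max for numerical stability
    log_sum_exp = highest + math.log(sum(math.exp(z - highest) for z in logits))
    return log_sum_exp - logits[target]


# Raw scores for two classes; the target is the index of the true class.
confident_right = cross_entropy([2.0, 0.0], target=0)
undecided = cross_entropy([0.0, 0.0], target=0)
confident_wrong = cross_entropy([2.0, 0.0], target=1)

print(confident_right, undecided, confident_wrong)
```

The loss is small when the highest score belongs to the true class, equals ln 2 when the two classes receive the same score, and grows when the model is confident in the wrong class.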

The printed model structure:

VGG(
  (features): Sequential(
	(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(1): ReLU(inplace)
	(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(3): ReLU(inplace)
	(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
	(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(6): ReLU(inplace)
	(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(8): ReLU(inplace)
	(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
	(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(11): ReLU(inplace)
	(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(13): ReLU(inplace)
	(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(15): ReLU(inplace)
	(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(17): ReLU(inplace)
	(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
	(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(20): ReLU(inplace)
	(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(22): ReLU(inplace)
	(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(24): ReLU(inplace)
	(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(26): ReLU(inplace)
	(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
	(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(29): ReLU(inplace)
	(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(31): ReLU(inplace)
	(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(33): ReLU(inplace)
	(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
	(35): ReLU(inplace)
	(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
	(0): Linear(in_features=25088, out_features=4096, bias=True)
	(1): ReLU(inplace)
	(2): Dropout(p=0.5)
	(3): Linear(in_features=4096, out_features=4096, bias=True)
	(4): ReLU(inplace)
	(5): Dropout(p=0.5)
	(6): Linear(in_features=4096, out_features=2, bias=True)
  )
)
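The printed structure also shows how little of the network is actually trained. The sketch below counts parameters by hand from the layer shapes listed above, assuming every layer has a bias term and that only the new final Linear layer is trainable, since the freeze loop ran before that layer was created. It is an arithmetic check in plain Python, not a PyTorch API call.

```python
# (in_channels, out_channels) for each 3x3 convolution in the printed structure
CONV_LAYERS = [
    (3, 64), (64, 64),
    (64, 128), (128, 128),
    (128, 256), (256, 256), (256, 256), (256, 256),
    (256, 512), (512, 512), (512, 512), (512, 512),
    (512, 512), (512, 512), (512, 512), (512, 512),
]
# (in_features, out_features) for each Linear layer in the classifier
LINEAR_LAYERS = [(25088, 4096), (4096, 4096), (4096, 2)]


def conv_params(c_in, c_out, kernel=3):
    """Weights plus one bias per output channel."""
    return c_in * c_out * kernel * kernel + c_out


def linear_params(f_in, f_out):
    """Weights plus one bias per output feature."""
    return f_in * f_out + f_out


frozen = sum(conv_params(i, o) for i, o in CONV_LAYERS)
frozen += sum(linear_params(i, o) for i, o in LINEAR_LAYERS[:-1])
trainable = linear_params(*LINEAR_LAYERS[-1])

print("frozen:   ", frozen)
print("trainable:", trainable)
print("trainable share: {:.4%}".format(trainable / (frozen + trainable)))
```

Under these assumptions, the optimizer only updates 8,194 of roughly 139.6 million parameters, which is why a small dataset and a short training run are enough.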

Step 3) Train and Test Model

We will use some of the functions from the PyTorch tutorial to help us train and evaluate the model.

def train_model(model, criterion, optimizer, num_epochs=25):
   since = time.time()

   for epoch in range(num_epochs):
       print('Epoch {}/{}'.format(epoch, num_epochs - 1))
       print('-' * 10)

       # Put the model in training mode (enables dropout)
       model.train()

       train_loss = 0

       # Iterate over data.
       for i, data in enumerate(dataloaders['train']):
           inputs , labels = data
           inputs = inputs.to(device)
           labels = labels.to(device)

           optimizer.zero_grad()
          
           with torch.set_grad_enabled(True):
               outputs  = model(inputs)
               loss = criterion(outputs, labels)

           loss.backward()
           optimizer.step()

           train_loss += loss.item() * inputs.size(0)

           print('{} Loss: {:.4f}'.format(
               'train', train_loss / dataset_sizes['train']))
          
   time_elapsed = time.time() - since
   print('Training complete in {:.0f}m {:.0f}s'.format(
       time_elapsed // 60, time_elapsed % 60))

   return model

def visualize_model(model, num_images=6):
   was_training = model.training
   model.eval()
   images_so_far = 0
   fig = plt.figure()

   with torch.no_grad():
       for i, (inputs, labels) in enumerate(dataloaders['validation']):
           inputs = inputs.to(device)
           labels = labels.to(device)

           outputs = model(inputs)
           _, preds = torch.max(outputs, 1)

           for j in range(inputs.size()[0]):
               images_so_far += 1
               ax = plt.subplot(num_images//2, 2, images_so_far)
               ax.axis('off')
               ax.set_title('predicted: {} truth: {}'.format(class_names[preds[j]], class_names[labels[j]]))
               img = inputs.cpu().data[j].numpy().transpose((1, 2, 0))
               img = std * img + mean
               ax.imshow(img)

               if images_so_far == num_images:
                   model.train(mode=was_training)
                   return
       model.train(mode=was_training)

Finally, start the training process with the number of epochs set to 25 and evaluate the model afterwards. At each training step, the model takes a batch of inputs and predicts the outputs. The predicted outputs are passed to the criterion to compute the loss. Calling loss.backward() then backpropagates the loss, with autograd computing the gradients, and optimizer.step() uses those gradients to update the trainable parameters.
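The training step described above can be sketched without PyTorch on a toy problem: a single weight w fitted so that w * x matches y. The gradient is derived by hand here instead of by autograd, and the update follows the usual SGD-with-momentum rule using the tutorial's learning rate and momentum. The function name, the data, and the number of epochs are assumptions chosen for illustration.

```python
def train_sketch(xs, ys, lr=0.001, momentum=0.9, epochs=200):
    """Fit y = w * x with mean squared error and SGD with momentum."""
    w = 0.0          # the single trainable weight
    velocity = 0.0   # momentum buffer
    history = []
    for _ in range(epochs):
        # Forward pass and loss
        preds = [w * x for x in xs]
        loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
        # Backward pass: d(loss)/dw computed by hand instead of autograd
        grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(xs)
        # Optimizer step: the velocity accumulates gradients, then moves w
        velocity = momentum * velocity + grad
        w -= lr * velocity
        history.append(loss)
    return w, history


w, history = train_sketch([1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0])
print("learned w: {:.4f}".format(w))
print("first loss: {:.4f}, last loss: {:.6f}".format(history[0], history[-1]))
```

With momentum, the weight can overshoot the target before settling, so the loss is not guaranteed to fall at every single step even though it converges overall.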

In visualize_model, the trained network predicts the labels for a batch of validation images. The predictions are then displayed with matplotlib.

vgg_based = train_model(vgg_based, criterion, optimizer_ft, num_epochs=25)

visualize_model(vgg_based)

plt.show()
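The code above displays predictions but does not compute an accuracy figure. One way to do so, sketched here in plain Python with made-up scores, is to take the index of the highest score in each row (the role torch.max(outputs, 1) plays in visualize_model) and compare it with the true labels. The helper names are hypothetical and this is not necessarily how the original accuracy was measured.

```python
def predict(outputs):
    """Index of the largest score in each row, like the preds in visualize_model."""
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


def accuracy(outputs, labels):
    """Fraction of rows whose predicted class matches the true label."""
    preds = predict(outputs)
    correct = sum(1 for p, t in zip(preds, labels) if p == t)
    return correct / len(labels)


# Made-up scores for four images and two classes (illustration only)
outputs = [[2.1, -0.3], [0.2, 1.7], [-1.0, 0.5], [0.9, 0.4]]
labels = [0, 1, 0, 0]

preds = predict(outputs)
score = accuracy(outputs, labels)
print(preds, score)
```

In a real evaluation, the same count would be accumulated over every batch of the validation data loader before dividing by the number of validation images.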

Results

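The model achieves an accuracy of 92%. The log below shows the last two epochs. Each line is the cumulative training loss after one more batch, divided by the size of the training set, so only the last line of each epoch is the loss for the whole epoch.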

Epoch 23/24
----------
train Loss: 0.0044
train Loss: 0.0078
train Loss: 0.0141
train Loss: 0.0221
train Loss: 0.0306
train Loss: 0.0336
train Loss: 0.0442
train Loss: 0.0482
train Loss: 0.0557
train Loss: 0.0643
train Loss: 0.0763
train Loss: 0.0779
train Loss: 0.0843
train Loss: 0.0910
train Loss: 0.0990
train Loss: 0.1063
train Loss: 0.1133
train Loss: 0.1220
train Loss: 0.1344
train Loss: 0.1382
train Loss: 0.1429
train Loss: 0.1500
Epoch 24/24
----------
train Loss: 0.0076
train Loss: 0.0115
train Loss: 0.0185
train Loss: 0.0277
train Loss: 0.0345
train Loss: 0.0420
train Loss: 0.0450
train Loss: 0.0490
train Loss: 0.0644
train Loss: 0.0755
train Loss: 0.0813
train Loss: 0.0868
train Loss: 0.0916
train Loss: 0.0980
train Loss: 0.1008
train Loss: 0.1101
train Loss: 0.1176
train Loss: 0.1282
train Loss: 0.1323
train Loss: 0.1397
train Loss: 0.1436
train Loss: 0.1467
Training complete in 2m 47s
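The values rise within each epoch because train_model adds every batch's loss to a running total and prints that total divided by the full dataset size. A small sketch with made-up batch losses and batch sizes reproduces the pattern:

```python
def running_losses(batch_losses, batch_sizes):
    """Reproduce what train_model prints: cumulative loss / dataset size."""
    dataset_size = sum(batch_sizes)
    total = 0.0
    printed = []
    for loss, size in zip(batch_losses, batch_sizes):
        total += loss * size  # mean batch loss -> summed per-sample loss
        printed.append(total / dataset_size)
    return printed


batch_losses = [0.5, 0.25, 1.0]  # made-up mean loss per batch
batch_sizes = [4, 4, 2]          # the last batch is smaller, as often happens
printed = running_losses(batch_losses, batch_sizes)
print(printed)
```

Only the last printed value is the average loss per sample for the epoch. In the log above, that is 0.1500 for epoch 23 and 0.1467 for epoch 24.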

Finally, the output of the model is visualized with matplotlib.

Summary

So, let's summarize everything! PyTorch is a growing deep learning framework suited to both beginners and research. It offers fast computation, dynamic graphs, GPU support, and a Python API. You can define your own network modules with ease and run the training process with a simple loop. PyTorch is well suited to beginners learning deep learning, and professional researchers benefit from its fast computation and from autograd, which supports dynamic graphs.

Before you start, you need to understand the dataset that you are going to use. In this tutorial, you classify images as Alien or Predator using nearly 700 images. With transfer learning, you don't need a large amount of data to train. You can download the dataset from Kaggle (Alien vs. Predator).

Reference:

Kaggle: Alien vs. Predator
https://www.guru99.com/transfer-learning.html