📉
Tutorials
  • Computer History
  • Function
    • Finance
      • Calculate
    • Manage Data
    • Date&Time
    • Strings and Character
  • Snippets
    • Web Application
      • Hugo
      • JavaScript
        • Stopwatch using JavaScript?
    • Note
    • Start Project
      • GitHub
      • GitLab
    • Python Programming
      • Strings and Character Data
      • List
      • Dictionaries
    • Data Science
      • Setting Option
      • Get Data
  • Link Center
    • Next Articles
    • Google
    • Excel VBA
    • Python
      • Notebook
    • WebApp
      • Vue.js
    • Finance
    • Project
      • Kids
        • Scratch
      • Finance
        • Plotly.js
        • Portfolio
      • Mini Lab
        • Systems Administration
        • Auto Adjust Image
      • Sending Emails
      • ECS
        • Knowledge Base
        • ระบบผู้เชี่ยวชาญ (Expert System)
        • Check product
        • Compare two SQL databases
      • e-Library
        • Knowledge base
        • การจัดหมวดหมู่ห้องสมุด
        • Temp
      • AppSheet
        • บัญชีรายรับรายจ่าย
      • Weather App
      • COVID-19
  • Tutorials
    • Data Science
      • Data Science IPython notebooks
    • UX & UI
      • 7 กฎการออกแบบ UI
    • Web Scraping
      • Scrape Wikipedia Articles
      • Quick Start
    • GUI
      • pysimple
        • Create a GUI
      • Tkinter
        • Python Tkinter Tutorial
      • PyQt
        • PyQt Tutorial
    • MachineLearning
      • การพัฒนา Chat Bot
      • AI ผู้ช่วยใหม่ในการทำ Customer Segmentation
      • Customer Segmentation
      • ตัดคำภาษาไทย ด้วย PyThaiNLP API
    • Excel & VBA
      • INDEX กับ MATCH
      • รวมสูตร Excel ปี 2020
      • How to Write Code in a Spreadsheet
    • Visualization
      • Bokeh
        • Part I: Getting Started
        • Data visualization
        • Plotting a Line Graph
        • Panel Document
        • Interactive Data Visualization
    • VueJS
      • VueJS - Quick Guide
    • Django
      • Customize the Django Admin
      • พัฒนาเว็บด้วย Django
    • Git
      • วิธีสร้าง SSH Key
      • Git คืออะไร
      • เริ่มต้นใช้งาน Git
      • การใช้งาน Git และ Github
      • รวม 10 คำสั่ง Git
      • GIT Push and Pull
    • Finance
      • Stock Analysis using Pandas (Series)
      • Building Investment AI for fintech
      • Resampling Time Series
      • Python for Finance (Series)
      • Stock Data Analysis (Second Edition)
      • Get Stock Data Using Python
      • Stock Price Trend Analysis
      • Calculate Stock Returns
      • Quantitative Trading
      • Backtrader for Backtesting
      • Binance Python API
      • Pine Script (TradingView)
      • Stocks Analysis with Pandas and Scikit-Learn
      • Yahoo Finance API
      • Sentiment Analysis
      • yfinance Library
      • Stock Data Analysis
      • YAHOO_FIN
      • Algorithmic Trading
    • JavaScript
      • Split a number
      • Callback Function
      • The Best JavaScript Examples
      • File and FileReader
      • JavaScript Tutorial
      • Build Reusable HTML Components
      • Developing JavaScript components
      • JavaScript - Quick Guide
      • JavaScript Style Guide()
      • Beginner's Handbook
      • Date Now
    • Frontend
      • HTML
        • File Path
      • Static Site Generators.
        • Creating a New Theme
    • Flask
      • Flask - Quick Guide
      • Flask Dashboards
        • Black Dashboard
        • Light Blue
        • Flask Dashboard Argon
      • Create Flask App
        • Creating First Application
        • Rendering Pages Using Jinja
      • Jinja Templates
        • Primer on Jinja Templating
        • Jinja Template Document
      • Learning Flask
        • Ep.1 Your first Flask app
        • Ep.2 Flask application structure
        • Ep.3 Serving HTML files
        • Ep.4 Serving static files
        • Ep.5 Jinja template inheritance
        • Ep.6 Jinja template design
        • Ep.7 Working with forms in Flask
        • Ep.8 Generating dynamic URLs in Flask
        • Ep.9 Working with JSON data
        • Ep.23 Deploying Flask to a VM
        • Ep.24 Flask and Docker
        • Ep. 25: uWSGI Introduction
        • Ep. 26 Flask before and after request
        • Ep. 27 uWSGI Decorators
        • Ep. 28 uWSGI Decorators
        • Ep. 29 Flask MethodView
        • Ep. 30 Application factory pattern
      • The Flask Mega-Tutorial
        • Chapter 2: Templates
      • Building Flask Apps
      • Practical Flask tutorial series
      • Compiling SCSS to CSS
      • Flask application structure
    • Database
      • READING FROM DATABASES
      • SQLite
        • Data Management
        • Fast subsets of large datasets
      • Pickle Module
        • How to Persist Objects
      • Python SQL Libraries
        • Create Python apps using SQL Server
    • Python
      • Python vs JavaScript
      • Python Pillow – Adjust Image
      • Python Library for Google Search
      • Python 3 - Quick Guide
      • Regular Expressions
        • Python Regular Expressions
        • Regular Expression (RegEx)
        • Validate ZIP Codes
        • Regular Expression Tutorial
      • Python Turtle
      • Python Beginner's Handbook
      • From Beginner to Pro
      • Standard Library
      • Datetime Tutorial
        • Manipulate Times, Dates, and Time Spans
      • Work With a PDF
      • geeksforgeeks.org
        • Python Tutorial
      • Class
      • Modules
        • Modules List
        • pickle Module
      • Working With Files
        • Open, Read, Append, and Other File Handling
        • File Manipulation
        • Reading & Writing to text files
      • Virtual Environments
        • Virtual Environments made easy
        • Virtual Environmen
        • A Primer
        • for Beginners
      • Functions
        • Function Guide
        • Inner Functions
      • Learning Python
        • Pt. 4 Python Strings
        • Pt. 3 Python Variables
      • Zip Function
      • Iterators
      • Try and Except
        • Exceptions: Introduction
        • Exceptions Handling
        • try and excep
        • Errors and Exceptions
        • Errors & Exceptions
      • Control Flow
      • Lambda Functions
        • Lambda Expression คืออะไร
        • map() Function
      • Date and Time
        • Python datetime()
        • Get Current Date and Time
        • datetime in Python
      • Awesome Python
      • Dictionary
        • Dictionary Comprehension
        • ALL ABOUT DICTIONARIES
        • DefaultDict Type for Handling Missing Keys
        • The Definitive Guide
        • Why Functions Modify Lists and Dictionaries
      • Python Structures
      • Variable & Data Types
      • List
        • Lists Explained
        • List Comprehensions
          • Python List Comprehension
          • List Comprehensions in 5-minutes
          • List Comprehension
        • Python List
      • String
        • Strings and Character Data
        • Splitting, Concatenating, and Joining Strings
      • String Formatting
        • Improved String Formatting Syntax
        • String Formatting Best Practices
        • Remove Space
        • Add Spaces
      • Important basic syntax
      • List all the packages
      • comment
    • Pandas
      • Tutorial (GeeksforGeeks)
      • 10 minutes to pandas
      • Options and settings
      • เริ่มต้น Set Up Kaggle.com
      • Pandas - Quick Guide
      • Cookbook
      • NumPy
        • NumPy Package for Scientific
      • IO tools (text, CSV, …)
      • pandas.concat
      • Excel & Google Sheets
        • A Guide to Excel
        • Quickstart to the Google Sheets
        • Python Excel Tutorial: The Definitive Guide
      • Working With Text Data
        • Quickstart
      • API Reference
      • Groupby
      • DateTime Methods
      • DataFrame
      • sort_values()
      • Pundit: Accessing Data in DataFrames
      • datatable
        • DataFrame: to_json()
        • pydatatable
      • Read and Write Files
      • Data Analysis with Pandas
      • Pandas and Python: Top 10
      • 10 minutes to pandas
      • Getting Started with Pandas in Python
    • Markdown
      • Create Responsive HTML Emails
      • Using Markup Languages with Hugo
    • AngularJS
      • Learn AngularJS
    • CSS
      • The CSS Handbook
      • Box Shadow
      • Image Center
      • The CSS Handbook
      • The CSS Handbook
      • Loading Animation
      • CSS Grid Layout
      • Background Image Size
      • Flexbox
  • Series
    • จาวาสคริปต์เบื้องต้น
      • 1: รู้จักกับจาวาสคริปต์
  • Articles
    • Visualization
      • Dash
        • Introducing Dash
    • Finance
      • PyPortfolioOpt
      • Best Libraries for Finance
      • Detection of price support
      • Portfolio Optimization
      • Python Packages For Finance
    • Django
      • เริ่มต้น Django RestFramework
    • General
      • Heroku คืออะไร
      • How to Crack Passwords
    • Notebook
      • IPython Documentation
      • Importing Notebooks
      • Google Colab for Data Analytics
      • Creating Interactive Dashboards
      • The Definitive Guide
      • A gallery of interesting Jupyter Notebooks
      • Advanced Jupyter Notebooks
      • Converting HTML to Notebook
    • Pandas
      • Pandas_UI
      • Pandas Style API
      • Difference Between two Dataframes
      • 19 Essential Snippets in Pandas
      • Time Series Analysis
      • Selecting Columns in a DataFrame
      • Cleaning Up Currency Data
      • Combine Multiple Excel Worksheets
      • Stylin’ with Pandas
      • Pythonic Data Cleaning
      • Make Excel Faster
      • Reading Excel (xlsx) Files
      • How to use iloc and loc for Indexing
      • The Easiest Data Cleaning Method
    • Python
      • pip install package
      • Automating your daily tasks
      • Convert Speech to Text
      • Tutorial, Project Ideas, and Tips
      • Image Handling and Processing
        • Image Processing Part I
        • Image Processing Part II
        • Image tutorial
        • Image Processing with Numpy
        • Converts PIL Image to Numpy Array
      • Convert Dictionary To JSON
      • JSON Dump
      • Speech-to-Text Model
      • Convert Text to Speech
      • Tips & Tricks
        • Fundamentals for Data Science
        • Best Python Code Examples
        • Top 50 Tips & Tricks
        • 11 Beginner Tips
        • 10 Tips & Tricks
      • Password hashing
      • psutil
      • Lambda Expressions
    • Web Scraping
      • Web Scraping using Python
      • Build a Web Scraper
      • Web Scraping for beginner
      • Beautiful Soup
      • Scrape Websites
      • Python Web Scraping
        • Web Scraping Part 1
        • Web Scraping Part 2
        • Web Scraping Part 3
        • Web Scraping Part 4
      • Web Scraper
    • Frontend
      • Book Online with GitBook
      • Progressive Web App คืออะไร
      • self-host a Hugo web app
  • Examples
    • Django
      • Build a Portfolio App
      • SchoolManagement
    • Flask
      • Flask Stock Visualizer
      • Flask by Example
      • Building Flask Apps
      • Flask 101
    • OpenCV
      • Build a Celebrity Look-Alike
      • Face Detection-OpenCV
    • Python
      • Make Game FLASH CARD
      • Sending emails using Google
      • ตรวจหาภาพซ้ำด้วย Perceptual hashing
        • Sending Emails in Python
      • Deck of Cards
      • Extract Wikipedia Data
      • Convert Python File to EXE
      • Business Machine Learning
      • python-business-analytics
      • Simple Blackjack Game
      • Python Turtle Clock
      • Countdown
      • 3D Animation : Moon Phases
      • Defragmentation Algorithm
      • PDF File
        • จัดการข้อความ และรูป จากไฟล์ PDF ด้วย PDFBox
      • Reading and Generating QR codes
      • Generating Password
        • generate one-time password (OTP)
        • Random Password Generator
        • Generating Strong Password
      • PyQt: Building Calculator
      • List Files in a Directory
      • [Project] qID – โปรแกรมแต่งรูปง่ายๆ เพื่อการอัพลงเว็บ
      • Python and Google Docs to Build Books
      • Tools for Record Linking
      • Create Responsive HTML Email
      • psutil()
      • Transfer Learning for Deep Learning
      • ดึงข้อมูลคุณภาพอากาศประเทศไทย
        • Image Classification
    • Web Scraper
      • Scrape Wikipedia Articles
        • Untitled
      • How Scrape Websites with Python 3
    • Finance
      • Algorithmic Trading for Beginners
      • Parse TradingView Stock
      • Creating a stock price database with MariaDB and python
      • Source Code
        • stocks-list
      • Visualizing with D3
      • Real Time Stock in Excel using Python
      • Create Stock Quote Module
      • The Magic Formula Lost Its Sparkle?
      • Stock Market Analysis
      • Stock Portfolio Analyses Part 1
      • Stock Portfolio Analyses Part 2
      • Build A Dashboard In Python
      • Stock Market Predictions with LSTM
      • Trading example
      • Algorithmic Trading Strategies
      • DOWNLOAD FUNDAMENTALS DATA
      • Algorithmic Trading
      • numfin
      • Financial Machine Learning
      • Algorithm To Predict Stock Direction
      • Interactive Brokers API Code
      • The (Artificially) Intelligent Investor
      • Create Auto-Updating Excel of Stock Market
      • Stock Market Predictions
      • Automate Your Stock Portfolio
      • create an analytics dashboard
      • Bitcoin Price Notifications
      • Portfolio Management
    • WebApp
      • CSS
        • The Best CSS Examples
      • JavaScript
        • Memory Game
      • School Clock
      • Frontend Tutorials & Example
      • Side Menu Bar with sub-menu
      • Create Simple CPU Monitor App
      • Vue.js building a converter app
      • jQuery
        • The Best jQuery Examples
      • Image Slideshow
      • Handle Timezones
      • Text to Speech with Javascript
      • Building Blog for Your Portfolio
      • Responsive Website Layout
      • Maths Homework Generator
  • Books
    • Finance
      • Python for Finance (O'Reilly)
    • Website
      • Hugo
        • Go Bootcamp
        • Hugo in Action.
          • About this MEAP
          • Welcome
          • 1. The JAM stack with Hugo
          • 2. Live in 30 minutes
          • 3. Using Markup for content
          • 4. Content Management with Hugo
          • 5. Custom Pages and Customized Content
          • 6. Structuring web pages
          • A Appendix A.
          • B Appendix B.
          • C Appendix C.
    • Python
      • ภาษาไพธอนเบื้องต้น
      • Python Cheatsheet
        • Python Cheatsheet
      • Beginning Python
      • IPython Cookbook
      • The Quick Python Book
        • Case study
        • Part 1. Starting out
          • 1. About Python
          • 2. Getting started
          • 3. The Quick Python overview
        • Part 2. The essentials
          • 14. Exceptions
          • 13. Reading and writing files
          • 12. Using the filesystem
          • 11. Python programs
          • 10. Modules and scoping rules
          • 9. Functions
          • 8. Control flow
          • 4. The absolute basics
          • 5. Lists, tuples, and sets
          • 6. Strings
          • 7. Dictionaries
        • Part 3. Advanced language features
          • 19. Using Python libraries
          • 18. Packages
          • 17. Data types as objects
          • 16. Regular expressions
          • 15. Classes and OOP
        • Part 4. Working with data
          • Appendix B. Exercise answers
          • Appendix A. Python’s documentation
          • 24. Exploring data
          • 23. Saving data
          • 20. Basic file wrangling
          • 21. Processing data files
          • 22. Data over the network
      • The Hitchhiker’s Guide to Python
      • A Whirlwind Tour of Python
        • 9. Defining Functions
      • Automate the Boring Stuff
        • 4. Lists
        • 5. Dictionaries
        • 12. Web Scraping
        • 13. Excel
        • 14. Google Sheets
        • 15. PDF and Word
        • 16. CSV and JSON
    • IPython
    • Pandas
      • จัดการข้อมูลด้วย pandas เบื้องต้น
      • Pandas Tutorial
  • Link Center
    • Temp
  • เทควันโด
    • รวมเทคนิค
    • Help and Documentation
  • Image
    • Logistics
Powered by GitBook
On this page

Was this helpful?

Last updated 5 years ago

Was this helpful?

Image Classification with PyTorch

One of the popular methods to learn the basics of deep learning is with the MNIST dataset. It is the "Hello World" in deep learning. The dataset contains handwritten numbers from 0 - 9 with the total of 60,000 training samples and 10,000 test samples that are already labeled with the size of 28x28 pixels.

Step 1) Preprocess the Data

Before you start the training process, you need to understand the data. In the first step, you will load the dataset using torchvision module. Torchvision will load the dataset and transform the images with the appropriate requirement for the network such as the shape and normalizing the images.

The transform function converts the images into tensor and normalizes the value. The function torchvision.transforms.MNIST, will download the dataset (if it's not available) in the directory, set the dataset for training if necessary and do the transformation process.

To visualize the dataset, you use the data_iterator to get the next batch of images and labels. You use matplot to plot these images and their appropriate label. As you can see below our images and their labels.

Step 2) Network Model Configuration

Now you will make a simple neural network for image classification. Here, we introduce you another way to create the Network model in PyTorch. We will use nn.Sequential to make a sequence model instead of making a subclass of nn.Module.

Here is the output of our network model

Network Explanation

  1. The sequence is that the first layer is a Conv2D layer with an input shape of 1 and output shape of 10 with a kernel size of 5

  2. Next, you have a MaxPool2D layer

  3. A ReLU activation function

  4. a Dropout layer to drop low probability values.

  5. Then a second Conv2d with the input shape of 10 from the last layer and the output shape of 20 with a kernel size of 5

  6. Next a MaxPool2d layer

  7. ReLU activation function.

  8. After that, you will flatten the tensor before you feed it into the Linear layer

  9. Linear Layer will map our output at the second Linear layer with softmax activation function

Step 3) Train the Model

Before you start the training process, it is required to set up the criterion and optimizer function. For the criterion, you will use the CrossEntropyLoss. For the Optimizer, you will use the SGD with a learning rate of 0.001 and a momentum of 0.9.

The forward process will take the input shape and pass it to the first conv2d layer. Then from there, it will be feed into the maxpool2d and finally put into the ReLU activation function. The same process will occur in the second conv2d layer. After that, the input will be reshaped into (-1,320) and feed into the fc layer to predict the output.

Now, you will start the training process. You will iterate through our dataset 2 times or with an epoch of 2 and print out the current loss at every 2000 batch.

At each epoch, the enumerator will get the next tuple of input and corresponding labels. Before we feed the input to our network model, we need to clear the previous gradient. This is required because after the backward process (backpropagation process), the gradient will be accumulated instead of being replaced. Then, we will calculate the losses from the predicted output from the expected output. After that, we will do a backpropagation to calculate the gradient, and finally, we will update the parameters.

Here's the output of the training process

Step 4) Test the Model

After you train our model, you need to test or evaluate with other sets of images. We will use an iterator for the test_loader, and it will generate a batch of images and labels that will be passed to the trained model. The predicted output will be displayed and compared with the expected output.

Reference :

import torch
import torchvision
import numpy as np
from torchvision import datasets, models, transforms

# This is used to transform the images to Tensor and normalize it
transform = transforms.Compose(
   [transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

training = torchvision.datasets.MNIST(root='./data', train=True,
                                       download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(training, batch_size=4,
                                         shuffle=True, num_workers=2)

testing = torchvision.datasets.MNIST(root='./data', train=False,
                                      download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testing, batch_size=4,
                                        shuffle=False, num_workers=2)

classes = ('0', '1', '2', '3',
          '4', '5', '6', '7', '8', '9')
         
import matplotlib.pyplot as plt
import numpy as np

#create an iterator for train_loader
# get random training images
data_iterator = iter(train_loader)
images, labels = data_iterator.next()

#plot 4 images to visualize the data
rows = 2
columns = 2
fig=plt.figure()
for i in range(4):
   fig.add_subplot(rows, columns, i+1)
   plt.title(classes[labels[i]])
   img = images[i] / 2 + 0.5     # this is for unnormalize the image
   img = torchvision.transforms.ToPILImage()(img)
   plt.imshow(img)
plt.show()
import torch.nn as nn

# flatten the tensor into 
class Flatten(nn.Module):
   def forward(self, input):
       return input.view(input.size(0), -1)

#sequential based model
seq_model = nn.Sequential(
           nn.Conv2d(1, 10, kernel_size=5),
           nn.MaxPool2d(2),
           nn.ReLU(),
           nn.Dropout2d(),
           nn.Conv2d(10, 20, kernel_size=5),
           nn.MaxPool2d(2),
           nn.ReLU(),
           Flatten(),
           nn.Linear(320, 50),
           nn.ReLU(),
           nn.Linear(50, 10),
           nn.Softmax(),
         )

net = seq_model
print(net)
Sequential(
  (0): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (2): ReLU()
  (3): Dropout2d(p=0.5)
  (4): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): ReLU()
  (7): Flatten()
  (8): Linear(in_features=320, out_features=50, bias=True)
  (9): ReLU()
  (10): Linear(in_features=50, out_features=10, bias=True)
  (11): Softmax()
)
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for epoch in range(2): 

#set the running loss at each epoch to zero
   running_loss = 0.0
# we will enumerate the train loader with starting index of 0
# for each iteration (i) and the data (tuple of input and labels)
   for i, data in enumerate(train_loader, 0):
       inputs, labels = data

# clear the gradient
       optimizer.zero_grad()

#feed the input and acquire the output from network
       outputs = net(inputs)

#calculating the predicted and the expected loss
       loss = criterion(outputs, labels)

#compute the gradient
       loss.backward()

#update the parameters
       optimizer.step()

       # print statistics
       running_loss += loss.item()
       if i % 1000 == 0:
           print('[%d, %5d] loss: %.3f' %
                 (epoch + 1, i + 1, running_loss / 1000))
           running_loss = 0.0
[1, 	1] loss: 0.002
[1,  1001] loss: 2.302
[1,  2001] loss: 2.295
[1,  3001] loss: 2.204
[1,  4001] loss: 1.930
[1,  5001] loss: 1.791
[1,  6001] loss: 1.756
[1,  7001] loss: 1.744
[1,  8001] loss: 1.696
[1,  9001] loss: 1.650
[1, 10001] loss: 1.640
[1, 11001] loss: 1.631
[1, 12001] loss: 1.631
[1, 13001] loss: 1.624
[1, 14001] loss: 1.616
[2, 	1] loss: 0.001
[2,  1001] loss: 1.604
[2,  2001] loss: 1.607
[2,  3001] loss: 1.602
[2,  4001] loss: 1.596
[2,  5001] loss: 1.608
[2,  6001] loss: 1.589
[2,  7001] loss: 1.610
[2,  8001] loss: 1.596
[2,  9001] loss: 1.598
[2, 10001] loss: 1.603
[2, 11001] loss: 1.596
[2, 12001] loss: 1.587
[2, 13001] loss: 1.596
[2, 14001] loss: 1.603
#make an iterator from test_loader
#Get a batch of training images
test_iterator = iter(test_loader)
images, labels = test_iterator.next()

results = net(images)
_, predicted = torch.max(results, 1)

print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))

fig2 = plt.figure()
for i in range(4):
   fig2.add_subplot(rows, columns, i+1)
   plt.title('truth ' + classes[labels[i]] + ': predict ' + classes[predicted[i]])
   img = images[i] / 2 + 0.5     # this is to unnormalize the image
   img = torchvision.transforms.ToPILImage()(img)
   plt.imshow(img)
plt.show()
  1. Examples
  2. Python
  3. ดึงข้อมูลคุณภาพอากาศประเทศไทย

Image Classification

Previousดึงข้อมูลคุณภาพอากาศประเทศไทยNextWeb Scraper
https://www.guru99.com/pytorch-tutorial.html