Analytics Vidhya

Analytics Vidhya is a community of Generative AI and Data Science professionals. We are building the next-gen data science ecosystem https://www.analyticsvidhya.com

How to Train a Neural Network Classifier on ImageNet using TensorFlow 2

A sampling of images from the ImageNet dataset, where each image is one of 1000 classes.

Image classification is a classic problem in computer vision. Today, state-of-the-art models for this problem use neural networks, which means that implementing and evaluating these models requires the use of a deep learning library like PyTorch or TensorFlow.

You can easily find PyTorch tutorials for downloading a pretrained model, setting up the ImageNet dataset, and evaluating the model. But I could not find a comprehensive tutorial for doing the same in TensorFlow. In this article, I’ll show you how.

Requirements

  • Python 3
  • TensorFlow 2.3.1 (install with pip3 install tensorflow==2.3.1)
  • TensorFlow Datasets (install with pip3 install tensorflow-datasets==4.1.0)
  • CUDA and cuDNN (since I’m using an NVIDIA GPU)
  • ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar, which you can download from the ImageNet website. Note that these archives are typically kept in read-only shared storage (for multiple users), since they require ~156 GB of storage space.

Overview

There are essentially three steps, which we’ll work through in order: preparing the ImageNet dataset, compiling a pretrained model, and finally, evaluating the accuracy of the model.

First, let’s import some packages:

Now, we’ll download ImageNet labels and specify where our ImageNet archive files are located. In particular, data_dir should be the path where ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar are located, and write_dir should be the directory where we’d like to write extracted image content. Make sure your write_dir directory contains extracted, downloaded, and data directories.
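If you’d rather not create those three directories by hand, a small helper can do it. This is only a sketch using the standard library; make_write_dirs is a hypothetical name of my own, and the path in the usage note is a placeholder.

```python
import os


def make_write_dirs(write_dir):
    """Create the extracted/, downloaded/ and data/ subdirectories under write_dir.

    Returns a dict mapping each subdirectory name to its full path.
    """
    paths = {}
    for name in ("extracted", "downloaded", "data"):
        path = os.path.join(write_dir, name)
        os.makedirs(path, exist_ok=True)  # no error if the directory already exists
        paths[name] = path
    return paths
```

Calling make_write_dirs("/path/to/write_dir") once before loading the dataset is enough; running it again is harmless because existing directories are left untouched.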

The key here is that tfds.load accepts a download_and_prepare_kwargs argument, which it forwards to the download_and_prepare call. That is how we specify where our archive files are located and where extracted records should be placed.

Now, because pretrained classification models such as MobileNet V2 take 224 x 224 images as input by default, we’ll need to do some preprocessing of our data. Here we’ll use mobilenet_v2.preprocess_input(i), where i is an image tensor, but if you’re using a different model, you can replace this call. For example, if I were using a VGG-16, I’d instead call vgg16.preprocess_input(i).
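One way to write that preprocessing step is sketched below. The function names preprocess and prepare are my own, and using tf.image.resize_with_crop_or_pad is an assumption: it center-crops (or zero-pads) to 224 x 224 rather than rescaling the image, so you may prefer a resize followed by a crop.

```python
import tensorflow as tf
from tensorflow.keras.applications import mobilenet_v2

IMG_SIZE = 224  # default input size for MobileNet V2


def preprocess(image, label):
    """Turn one (uint8 image, label) pair into a model-ready (float image, label) pair."""
    image = tf.cast(image, tf.float32)
    # Center-crop or zero-pad to IMG_SIZE x IMG_SIZE; this does not rescale the image.
    image = tf.image.resize_with_crop_or_pad(image, IMG_SIZE, IMG_SIZE)
    # Scale pixel values from [0, 255] to [-1, 1], as MobileNet V2 expects.
    image = mobilenet_v2.preprocess_input(image)
    return image, label


def prepare(ds):
    """Apply preprocess to every element of a dataset of (image, label) pairs."""
    return ds.map(preprocess)
```

For a different architecture, swapping the mobilenet_v2.preprocess_input call for that model’s own preprocess_input should be the only change needed.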

Next, we compile a model of our choice; in my case, a MobileNet V2.
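A sketch of that step follows, with build_model as a hypothetical helper name. The optimizer choice is an assumption and has no effect when the model is only evaluated. Because the Keras MobileNetV2 classifier ends in a softmax by default, the loss is configured for probabilities rather than logits.

```python
import tensorflow as tf


def build_model(weights="imagenet"):
    """Return a compiled MobileNet V2 classifier with frozen weights.

    weights="imagenet" downloads the pretrained weights on first use;
    weights=None builds the same architecture with random weights.
    """
    model = tf.keras.applications.MobileNetV2(include_top=True, weights=weights)
    model.trainable = False  # evaluation only; unfreeze before fine-tuning
    model.compile(
        optimizer="adam",  # assumed; unused during evaluation
        # The top layer already applies softmax, so the outputs are probabilities.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=["accuracy"],
    )
    return model
```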

Finally, because of the way we’ve set up our dataset, we can evaluate this model on the training data and print the accuracy using a few lines of code!
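Those few lines might look like the sketch below. evaluate_accuracy is a hypothetical helper, the batch size of 32 is an assumption, and it expects a model compiled with metrics=["accuracy"] plus a dataset of preprocessed (image, label) pairs.

```python
import tensorflow as tf


def evaluate_accuracy(model, ds, batch_size=32):
    """Evaluate a compiled Keras model on an unbatched dataset and return its accuracy."""
    # return_dict=True gives results keyed by metric name instead of a bare list.
    results = model.evaluate(ds.batch(batch_size), return_dict=True)
    return results["accuracy"]
```

You could then report the result with something like print(f"Accuracy: {evaluate_accuracy(model, ds):.4f}"). Keep in mind that a full pass over the ImageNet training split takes a while, even on a GPU.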

With a functioning pretrained classifier, you can now fine-tune the model to fit the needs of your classification problem. The full implementation is below, but it can also be found on my GitHub.

Full implementation of downloading a pretrained model, setting up the ImageNet dataset, and evaluating the model in TensorFlow 2


Published in Analytics Vidhya

Written by Pedro Sandoval-Segura

PhD Student at University of Maryland
