Performing Image Augmentation using PyTorch

A Detailed Guide on How to Use Image Augmentation in PyTorch to Give Your Models a Data Boost.

In the last few years, there have been major breakthroughs in the field of Deep Learning. Constant research and rapid development have made Deep Learning an industry standard in AI and a main topic of discussion at almost every AI and Data Science conference, overshadowing its parent field and predecessor: traditional Machine Learning.

When it comes to Computer Vision, which deals with image and video data and problems like object detection, body pose detection, and image segmentation, Deep Learning has proved to be a much more reliable option than traditional Machine Learning.

The reason is that Deep Learning specializes in tackling high-dimensional problems. While traditional machine learning works well when you only have a few hundred features to train your model on, its performance starts to deteriorate as the dimensionality of your data increases. With the evolution of Data Science and Big Data over the years, the complexity of the problems and the kinds of data that data scientists have to work with have increased a lot.

To give you an idea of this massive increase in the scale of data, we will consider an example here. Let’s say that you are working on an image dataset, where you have to deal with 3000×4000 px RGB images.

Treating each pixel value in each color channel as a feature, every single data instance (i.e., every image) will have 3000 × 4000 × 3 = 36,000,000 features. That’s right: the model has to learn from tens of millions of features per image, which, frankly speaking, is impractical for most traditional machine learning algorithms.

Deep Learning, on the other hand, performs exceptionally well on high-dimensional data like the images in the example above. This makes Deep Learning the ideal choice for Computer Vision problems.

This ability to deal with high-dimensional data makes Deep Learning seem very powerful, right? And now that GPUs don’t cost an arm and a leg (and you can even access high-speed GPUs for free through services like Google Colab), it might seem like there’s no need for traditional Machine Learning at all!

However, there’s a catch. The performance of a Deep Learning model generally improves as the size of the dataset, i.e., the total number of data instances it contains, grows.

In the graph above, you will notice an interesting pattern. When the dataset is small, traditional Machine Learning tends to perform slightly better. However, as the size of the training dataset increases, Deep Learning models start to outperform their Machine Learning counterparts by a huge margin.

The reason? Deep Learning architectures generally have millions of parameters that must be trained before the model can pick up the patterns in the data, and training that many parameters well requires a very large amount of data. If there isn’t enough data to train on, the model tends to overfit, its performance on new data takes a huge hit, and you might not get the results you expected.

Therefore, if you are working on a Computer Vision problem, say image segmentation, you need large amounts of image data to get good performance out of a Deep Learning model. One solution is to collect more images to train your model on. The downside is that data collection can be very expensive, both financially and technically.

A more economical option is a technique known as Image Augmentation. If you are just getting started with Deep Learning, this might be an entirely new term for you, in which case you have ended up in the right place. In this article, we will see what Image Augmentation is and how to apply it to training data in Python using PyTorch.

So, let’s get started.

Image Augmentation

Image Augmentation is the process of generating new images by applying randomized variations to existing image data. The technique increases the size of your dataset by creating additional data instances for your model to train on. For an image classification model, this often translates into better performance, especially when the original dataset is small.

I think the definition will become clearer once you see an example. In the example given below, we have the original image of an SUV on a street. 

In the first augmented image, we got a new image by zooming in and increasing the brightness. The second augmented image was generated by tweaking the hue and temperature of the original image. In the third augmented image, the original image was vertically flipped. Thus, just by tweaking the color and orientation of the original, we were able to create three more data instances for a model to train on.

To human eyes, all these images look like the same SUV. But a Deep Learning model sees an image as individual pixel values (ranging from 0 to 255) across the 3 color channels (RGB), so to the model all these images are different, since their pixel values are different. Thus, image augmentation lets us generate new training data for a deep learning model without going to the extra lengths of collecting it manually.

Another advantage of image augmentation is that, by introducing randomness into the image data, it reduces the risk of the model overfitting on the training data. This helps the model generalize better and, hence, improves its accuracy on unseen data.

Image Augmentation Using PyTorch

Now that we know what the image augmentation technique is used for, let us have a look at how you can implement a variety of image augmentations in PyTorch.

For this tutorial, first, we will understand the use and the effect of different image augmentation methods individually on a single image. Once we are done with that, we will see how to perform image augmentations in a Deep Learning project for a real-world dataset.

Let us begin by importing all the necessary PyData modules and PyTorch.

Now, before we start performing the transformations, let us have a look at our original image.
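A minimal setup sketch, assuming NumPy, Pillow, and torchvision's classic `torchvision.transforms` API are installed. Since the article's SUV photo isn't available here, a random 300 × 400 RGB image stands in for it; the commented-out filename is hypothetical.

```python
import numpy as np
from PIL import Image

# Random pixel values stand in for the article's photo (shape: height x width x RGB).
rng = np.random.default_rng(seed=0)
pixels = rng.integers(0, 256, size=(300, 400, 3), dtype=np.uint8)
orig_img = Image.fromarray(pixels)

# To load a real photo instead (hypothetical filename):
# orig_img = Image.open("suv.jpg").convert("RGB")

print(orig_img.size, orig_img.mode)  # PIL reports (width, height) → (400, 300) RGB
```

Keep the two conventions apart: PIL's `.size` is (width, height), while torchvision's `size` arguments take (height, width).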

Now, let us have a look at some of the most used image augmentation techniques in PyTorch and the purpose they are used for.

  • CenterCrop – The CenterCrop transform crops the input image at the center. The size of the crop is set by the ‘size’ argument. A single integer value performs a square crop of dimension size x size. For a rectangular crop, pass a tuple instead; note that torchvision interprets it as size = (height, width), not (width, height).

Here’s how to implement CenterCrop in PyTorch:

  • ColorJitter – The ColorJitter augmentation randomly changes the brightness, contrast, saturation, and hue of the image. Unlike the CenterCrop transform that we saw earlier, ColorJitter doesn’t have a fixed behavior: the adjustment factors are drawn at random every time it is applied, so each call produces a different color variation.

Here’s how to implement ColorJitter in PyTorch:

  • Grayscale – The Grayscale transform converts a multi-channel (e.g., RGB) image into either a single-channel grayscale image or a three-channel one in which r == g == b. The number of output channels is set by the ‘num_output_channels’ argument (1 or 3).

Here’s how to implement Grayscale in PyTorch:

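  • Pad – The Pad transform pads the given image on all sides. The thickness of the padding is determined by the ‘padding’ argument: a single integer pads every side equally, a pair pads left/right and top/bottom respectively, and a 4-tuple sets the left, top, right, and bottom padding. Like CenterCrop, Pad is deterministic.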

Here’s how to implement Pad in PyTorch:

  • RandomCrop – The RandomCrop augmentation works like CenterCrop, except that it crops the original image at a random location rather than at the center. Again, the size of the crop is determined by the ‘size’ argument.

Here’s how to implement RandomCrop in PyTorch:

  • RandomHorizontalFlip – The RandomHorizontalFlip augmentation flips the image horizontally (mirroring it left to right). The probability of the flip is controlled by the ‘p’ argument, with 0 <= p <= 1 (default 0.5).

Here’s how to implement RandomHorizontalFlip in PyTorch:

  • RandomVerticalFlip – Like the horizontal flip that we saw earlier, RandomVerticalFlip also flips the image; the only difference is that it flips the image upside down, i.e., across the horizontal (x) axis. The probability of the flip is controlled by the ‘p’ argument, with 0 <= p <= 1 (default 0.5).

Here’s how to implement RandomVerticalFlip in PyTorch:

  • RandomPerspective – The RandomPerspective augmentation applies a random perspective transformation to the image, as if it were viewed from a different angle. The probability of the distortion is controlled by the ‘p’ argument, with 0 <= p <= 1, and the strength of the distortion by the ‘distortion_scale’ argument, which also ranges from 0 to 1.

Here’s how to implement RandomPerspective in PyTorch:

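  • RandomRotation – The RandomRotation augmentation rotates the image by a random angle. The range of angles is set by the ‘degrees’ argument: a single number d means an angle drawn from (-d, +d), while a pair (min, max) gives the range explicitly.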

Here’s how to implement RandomRotation in PyTorch:

  • RandomErasing – The RandomErasing augmentation randomly selects a rectangular region in the image and erases all the pixels in it. The probability of the erase operation is controlled by the ‘p’ argument, with 0 <= p <= 1. Unlike the other transforms above, RandomErasing operates on tensor images, so the image must first be converted with ToTensor().

Here’s how to implement RandomErasing in PyTorch:

Now that we have seen some of the most used image augmentation techniques in PyTorch, let us look at how to apply them in a real-world project. In practice, several augmentations are chained together and applied to each image one after another. For this, we use the torchvision.transforms.Compose class: the augmentations to be performed are passed to Compose as a list, in the order in which they should be applied.

Let us see how to implement this using PyTorch:

Up until now, we saw how to apply the transformations/augmentations on a single image. But in real-world problems, the datasets may have thousands of images. 

Unlike the Pandas DataFrames that we see in many traditional machine learning problems, it is generally not practical to hold all the images in memory (RAM) at once. Instead, PyTorch loads images on demand through its Dataset classes. To apply the transforms to an entire dataset, all you need to do is pass the torchvision.transforms.Compose object (or a single augmentation object, if you prefer) as the ‘transform’ argument of the dataset. torchvision provides several Dataset classes, but as an example, we will see how to apply image augmentation to an ImageFolder dataset.

This brings us to the end of the tutorial, in which we learned why image augmentation is necessary for Deep Learning and how to apply different image augmentations in PyTorch. By applying random combinations of augmentations to different batches of data and training your model on this augmented and original data over several epochs (training cycles), you can overcome much of the barrier that a small image dataset poses for your model.

If you know any other image augmentation technique that you like to use often, let us know about that in the comments!

Happy learning!

