
Building an Image Classifier with PyTorch

Rogier van der Geer shows how you can use PyTorch to build out a Convolutional Neural Network for image classification:

The tool that we are going to use to make a classifier is called a convolutional neural network, or CNN. You can find a great explanation of what these are right here on Wikipedia.

But we are not going to fully train one ourselves: that would take way more time than I would be willing to spend. Instead, we are going to do transfer learning, where we take a pre-trained CNN and replace only the last layer with a layer of our own. Then we only need to train that single layer, as all the other layers already have weights that are quite sensible. Here we exploit the fact that the images we are interested in have a lot of the same properties as the images that the original network was trained on. You can find a great explanation of transfer learning here.
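
To make the "replace only the last layer" idea concrete, here is a minimal sketch in PyTorch. It is not taken from Rogier's post. `TinyBackbone` and `replace_last_layer` are hypothetical names, and the tiny network only stands in for a real pre-trained CNN (in practice you would load one, for example from torchvision, with its pre-trained weights). The sketch also assumes the model exposes its final classification layer as an attribute called `fc`.

```python
import torch
from torch import nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a pre-trained CNN with a final `fc` layer."""

    def __init__(self, num_classes: int = 1000):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        # The original classification layer, which will be swapped out.
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


def replace_last_layer(model: nn.Module, num_classes: int) -> nn.Module:
    """Freeze all existing weights and attach a fresh, trainable last layer."""
    for param in model.parameters():
        param.requires_grad = False
    # A newly created layer has requires_grad=True by default.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


model = replace_last_layer(TinyBackbone(), num_classes=2)

# Only the parameters of the new layer are left trainable.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]

# Hand just those parameters to the optimizer.
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.01
)

# One training step on random data, purely to show the mechanics.
images = torch.randn(4, 3, 32, 32)
labels = torch.tensor([0, 1, 0, 1])
loss = nn.functional.cross_entropy(model(images), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Because every other parameter is frozen, the backward pass only computes gradients for the new layer, which is what makes this so much cheaper than training the whole network.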

Read on for a detailed example.