Few Shot Learning from Scratch

Deep Gan Team
Jul 7, 2021

Wiley Wang, John Inacay, and Mike Wang

One of the emerging concepts in deep learning is Few Shot Learning. If you’ve been studying Machine Learning or Deep Learning, you’ve probably heard the term before. But what is it, and how does it actually work? In this post, we explain the idea and show one way to perform Few Shot Learning: building a One Shot Learning system with Twin Networks. You can find our Twin Network implementation in PyTorch here.

What is Few Shot Learning?

First of all, what is Few Shot Learning, and why should you want to learn it? Deep learning algorithms famously require massive datasets to perform their target task, but obtaining and labeling large amounts of data is often time consuming and expensive. Supervised training demands more data to cover more tasks: a typical classification task needs a large sample of each class. However, if we stop treating classification as a direct, one-to-one mapping from an input to a label and instead learn how to compare inputs, we can still perform well with smaller datasets. The main goal of Few Shot Learning is to enable Deep Learning algorithms to work when only limited data is available for the task at hand.

What is a Twin Network?

A Twin Network uses two identically structured neural networks as backbones, whose two outputs are fed into another function or network to produce a final result. Because each backbone takes its own input, the overall structure acts as a comparison agent: in effect, a learned similarity function. In our example, we use two Convolutional Neural Networks (CNNs) as the backbones. Traditional CNNs are often used as classifiers, taking a single input and sending its output directly to one or more classification layers. Through end-to-end training, a single CNN classifier learns to discern specific classes given enough examples. A Twin Network instead learns, end-to-end, to place images of the same class close together and images of different classes far apart. Twin Networks learn through Contrastive Loss, a loss function that penalizes a large distance between two inputs of the same class and a small distance between two inputs of different classes.
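To make the behavior of Contrastive Loss concrete, here is a minimal sketch of one common formulation in plain Python. It is an illustration under stated assumptions, not the code from our repository: the Euclidean distance, the `margin` value of 1.0, and the function names are all choices made for this example. The label convention (0 for the same class, 1 for different classes) follows Figure 1.

```python
import math


def euclidean_distance(a, b):
    """Euclidean distance between two equal-length embedding vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def contrastive_loss(emb1, emb2, label, margin=1.0):
    """One common form of contrastive loss for a single pair.

    label = 0 means "same class", label = 1 means "different classes".
    The margin is an assumed hyperparameter for this sketch.
    """
    d = euclidean_distance(emb1, emb2)
    # Same-class pairs are penalized for being far apart.
    same_term = (1 - label) * d ** 2
    # Different-class pairs are penalized only while closer than the margin.
    diff_term = label * max(0.0, margin - d) ** 2
    return same_term + diff_term


# Identical embeddings, same class: no penalty.
loss_same_close = contrastive_loss([0.0, 0.0], [0.0, 0.0], label=0)
# Distant embeddings, same class: large penalty.
loss_same_far = contrastive_loss([0.0, 0.0], [3.0, 4.0], label=0)
# Identical embeddings, different classes: penalized up to the margin.
loss_diff_close = contrastive_loss([0.0, 0.0], [0.0, 0.0], label=1)
# Distant embeddings, different classes: no penalty.
loss_diff_far = contrastive_loss([0.0, 0.0], [3.0, 4.0], label=1)
```

The four example pairs cover each combination of label and distance, which shows that the loss only reaches zero when same-class pairs coincide or different-class pairs sit at least `margin` apart.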

As an example, Figure 1 shows an x1 image array, an x2 image array, and the loss array produced by x1 and x2.

Figure 1: An example training batch where each image from x1 is compared with its corresponding image from x2. The Loss Array shows how images of the same fruit have a label of 0 and images of different fruits have a label of 1. Our goal is for the Twin Network to learn to perform this task automatically.

What is N-Way K-Shot Learning?

One way to implement Few Shot Learning is to apply N-Way K-Shot Learning with a Twin Network. We compare an input image against N classes of K images each and choose the class whose examples the network scores as closest to the input. For example, a 5-Way 2-Shot task pairs our input image with 5 classes of 2 examples each, resulting in 10 images to compare against.

With a Twin Network, we generate an embedding for each image. The last layer of the Twin Network compares the embeddings of the two images. If we view the output of the Twin Network as a distance function, N-Way K-Shot Learning is similar to the nearest neighbor algorithm.
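The nearest neighbor view of N-Way K-Shot Learning can be sketched in a few lines of plain Python. This is only an illustration: `classify_n_way_k_shot` and `toy_distance` are hypothetical names, the two-number "images" are made-up stand-ins for real inputs, and `toy_distance` replaces the trained Twin Network, which would supply the distance in practice.

```python
def classify_n_way_k_shot(query, support_set, distance_fn):
    """Return the class whose closest support example is nearest to the query.

    support_set maps each of the N class names to a list of K examples.
    distance_fn stands in for the trained Twin Network (lower = more similar).
    """
    best_label, best_distance = None, float("inf")
    for label, examples in support_set.items():
        for example in examples:
            d = distance_fn(query, example)
            if d < best_distance:
                best_label, best_distance = label, d
    return best_label, best_distance


def toy_distance(a, b):
    """Hypothetical stand-in for the network: sum of absolute differences."""
    return sum(abs(x - y) for x, y in zip(a, b))


# A 3-Way 2-Shot task: 3 classes with 2 examples each, so 6 comparisons.
support = {
    "apple": [[1.0, 0.0], [0.9, 0.1]],
    "banana": [[0.0, 1.0], [0.1, 0.9]],
    "cherry": [[1.0, 1.0], [0.9, 0.9]],
}
label, distance = classify_n_way_k_shot([0.8, 0.2], support, toy_distance)
```

Taking the single closest example is one possible decision rule. Averaging the K distances per class is another reasonable choice when K is greater than 1.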

Our Implementation of Few Shot Learning

To practice Few Shot Learning, we tackled the problem of fruit classification on the Kaggle Fruits 360 dataset. Again, our implementation can be found here.

To start, we preprocess the data for PyTorch consumption. We then define the Twin Network, using the same CNN backbone to generate features for each image in a pair. We calculate the similarity between the two CNN outputs by taking their absolute difference and passing it through a sigmoid function, which saturates the final output towards 0 or 1 to match our ground truth labels. To train, we pass in batches of image pairs and evaluate the model every 250 batches on a 20-way 1-shot task.
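The structure described above might look roughly like the following PyTorch sketch. It is not the code from our repository: the layer sizes, the `embedding_dim` parameter, the 32x32 input size, and the single linear layer placed between the absolute difference and the sigmoid are all assumptions made to keep the example small.

```python
import torch
import torch.nn as nn


class TwinNetwork(nn.Module):
    """Minimal twin network sketch: one shared CNN backbone, two inputs."""

    def __init__(self, embedding_dim=64):
        super().__init__()
        # A single backbone is reused for both inputs, so weights are shared.
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, embedding_dim),
        )
        # Assumed comparison head: maps the difference vector to one score.
        self.head = nn.Linear(embedding_dim, 1)

    def forward(self, x1, x2):
        e1 = self.backbone(x1)
        e2 = self.backbone(x2)
        # Element-wise absolute difference between the two embeddings.
        diff = torch.abs(e1 - e2)
        # Sigmoid squashes the score into the range (0, 1).
        return torch.sigmoid(self.head(diff)).squeeze(1)


model = TwinNetwork()
x1 = torch.randn(4, 3, 32, 32)  # a batch of 4 random stand-in RGB images
x2 = torch.randn(4, 3, 32, 32)
scores = model(x1, x2)  # one score per image pair
```

Because both inputs pass through the same `backbone` module, any gradient update changes how both images are embedded, which is what makes the two branches "twins".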

During training, we apply binary cross entropy loss. One way to interpret the problem is as binary classification, where the similarity function part of the network is the classifier. A subtly different interpretation is that we’re trying to pull images of the same class closer together (difference towards 0) and push images of different classes further apart (difference towards 1). This is the intuition behind Contrastive Loss, and binary cross entropy serves here as one way to implement it.
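As a small worked example, binary cross entropy can be computed by hand in plain Python. The function name, the `eps` clamp, and the sample predictions are assumptions for this sketch. The labels follow the convention from Figure 1, with 0 for the same class and 1 for different classes.

```python
import math


def binary_cross_entropy(predictions, labels, eps=1e-7):
    """Mean binary cross entropy over a batch of pair scores.

    predictions are network outputs in (0, 1); labels are 0 or 1.
    eps clamps predictions away from 0 and 1 to avoid log(0).
    """
    total = 0.0
    for p, y in zip(predictions, labels):
        p = min(max(p, eps), 1 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(predictions)


labels = [0, 1]  # first pair: same class, second pair: different classes
good_loss = binary_cross_entropy([0.1, 0.9], labels)  # scores match labels
bad_loss = binary_cross_entropy([0.9, 0.1], labels)   # scores are reversed
```

The loss is small when the score for each pair lands near its label and grows quickly when the scores are reversed, which is the "pull together, push apart" pressure described above.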

Another Application: Face Recognition

Figure 2: The Twin Network learns to say that images of person A and B belong to different people
Figure 3: The Twin Network also correctly identifies that two different images of person A show the same person

How do you use a Twin Network to perform One Shot Learning? One of the most common use cases of One Shot Learning is Face Recognition. In an example scenario, you might want to develop a system that recognizes a user from just a single reference photo. Importantly, our system has never seen this user before. Assuming that we’ve already trained a Twin Network to compare faces, we can compare different photos of people to our reference photo. We supply the reference photo to the Twin Network’s left input and the other photos to the right input. The Twin Network then says whether the two inputs show the same person or not. Our One Shot Learning system is successful if it recognizes only the target user. In practice, this method is comparable to K nearest neighbors, where we classify an input by comparing it to its closest examples and their labels. (Note that the K in K nearest neighbors counts the neighbors consulted, which is not the same as the K in K-Shot, the number of examples per class.)
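The verification flow can be sketched as a simple loop over candidate photos. Everything here is hypothetical: `verify_against_reference`, `toy_dissimilarity`, the 0.5 threshold, and the two-number "photos" are placeholders, and a trained Twin Network would take the place of `toy_dissimilarity`. As in the rest of this post, a score near 0 means "same" and a score near 1 means "different".

```python
def verify_against_reference(reference, candidates, dissimilarity_fn, threshold=0.5):
    """Return the names of candidate photos judged to match the reference.

    dissimilarity_fn stands in for the trained Twin Network and returns a
    score in [0, 1], where values near 0 mean "same person".
    """
    matches = []
    for name, photo in candidates.items():
        score = dissimilarity_fn(reference, photo)
        if score < threshold:
            matches.append(name)
    return matches


def toy_dissimilarity(a, b):
    """Hypothetical stand-in for the network, capped at 1.0."""
    return min(1.0, sum(abs(x - y) for x, y in zip(a, b)))


reference = [0.2, 0.7]  # the single reference photo of the target user
candidates = {
    "photo_1": [0.25, 0.65],
    "photo_2": [0.9, 0.1],
    "photo_3": [0.2, 0.75],
}
matches = verify_against_reference(reference, candidates, toy_dissimilarity)
```

The threshold trades false accepts against false rejects, so in a real system it would need to be tuned on held-out pairs rather than fixed at 0.5.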

Tips and Tricks

Setting the correct learning rate for your optimizer is important. When we first started training our Twin Network, the loss stayed at a constant value throughout the entire training process. We initially thought that our Twin Network implementation wasn’t learning or that our training data was formatted incorrectly. After digging into the issue more deeply, we found that our optimizer’s learning rate had been set too high, meaning the optimization process was diverging rather than converging. After we lowered the learning rate to 0.0001, the network’s loss started to decrease. Our experience shows how divergence in a neural network can be subtle and difficult to notice.
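The effect of an overly large learning rate is easy to reproduce on a toy problem. The sketch below runs plain gradient descent on f(x) = x^2, which is not our Twin Network, and the learning rates 1.1 and 0.1 are chosen only to make the contrast obvious. In a real network the symptom can look different, such as the flat loss we observed.

```python
def gradient_descent_losses(learning_rate, steps=20, start=1.0):
    """Minimize f(x) = x^2 with plain gradient descent and record the loss."""
    x = start
    losses = []
    for _ in range(steps):
        losses.append(x ** 2)
        gradient = 2 * x  # derivative of x^2
        x -= learning_rate * gradient
    return losses


# Too high: each step overshoots the minimum by more than it corrects.
high_lr_losses = gradient_descent_losses(learning_rate=1.1)
# Low enough: each step moves x closer to the minimum at 0.
low_lr_losses = gradient_descent_losses(learning_rate=0.1)
```

With the high learning rate the recorded loss grows at every step, while with the low one it shrinks steadily. Logging the loss for a handful of learning rates like this is a cheap first check when training appears stuck.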

Conclusion

In conclusion, we’ve learned how to train a Twin Network from scratch in PyTorch on real data, and how to use the trained network to perform One Shot Learning. One Shot Learning allows us to recognize new classes with just one example image rather than hundreds or thousands of examples. Twin Networks are an important step among the approaches researchers have found for Few Shot Learning.


Deep Gan Team

We’re a team of Machine Learning Engineers exploring and researching deep learning technologies