
Self-Supervised Learning and Its Applications

7th August, 2023

In the past decade, the research and development in AI have skyrocketed, especially after the results of the ImageNet competition in 2012. The focus was largely on supervised learning methods that require huge amounts of labeled data to train systems for specific use cases.

In this article, we will explore self-supervised learning (SSL), a hot research topic in the machine learning community.

What are self-supervised learning (SSL) algorithms?

Self-supervised learning (SSL) is an evolving machine learning technique poised to solve the challenges posed by over-dependence on labeled data. For many years, building intelligent systems using machine learning methods has been largely dependent on good quality labeled data. Consequently, the cost of high-quality annotated data is a major bottleneck in the overall training process.

One of the top priorities of AI researchers is to develop self-learning mechanisms with unstructured data that can scale the research and development of generic AI systems at a low cost. Practically, it is impossible to collect and label all kinds of varied data.

To solve this problem, researchers are working on self-supervised learning (SSL) techniques capable of capturing subtle nuances in data.

Before we jump into self-supervised learning, let’s get some background about popular learning methods used in building intelligent systems.

1. Supervised learning 

Supervised learning is a popular technique for training neural networks on labeled data for a specific task. You can think of supervised learning as a classroom where a student is taught by a teacher with many examples. A typical example is object classification.

Supervised learning | Source

2. Unsupervised learning

Unsupervised learning is a machine learning technique used to find implicit patterns in data without being explicitly trained on labeled data. Unlike supervised learning, it does not require annotations or a feedback loop for training. A typical example is clustering.

Unsupervised learning | Source

3. Semi-supervised learning

Semi-supervised learning is a machine learning method in which only a fraction of the input data is labeled. It is a mix of supervised and unsupervised learning.

Semi-supervised learning | Source

Semi-supervised learning can be useful in cases where we have a small number of labeled data points to train the model. The training process can use a small chunk of labeled data and pseudo-label the rest of the dataset. 

For example, a student is taught a few problems by a teacher and has to figure out the solutions to the rest of the problems on their own.

4. Reinforcement learning

Reinforcement learning is a method used to train AI agents to learn how to behave in an environment in specific contexts, using a reward-based feedback policy.

For example, think of a child who is trying to win a stage in a game.

Reinforcement learning process | Source

What is self-supervised learning?

Self-supervised learning is a machine learning process where the model trains itself to learn one part of the input from another part of the input. It is also known as predictive or pretext learning. 

In this process, the unsupervised problem is transformed into a supervised problem by auto-generating the labels. To make use of the huge quantity of unlabeled data, it is crucial to set the right learning objectives to get supervision from the data itself.

The process of the self-supervised learning method is to identify any hidden part of the input from any unhidden part of the input. 

Self-supervised learning | Source

For example, in natural language processing, if we have a few words, using self-supervised learning we can complete the rest of the sentence. Similarly, in a video, we can predict past or future frames based on available video data. Self-supervised learning uses the structure of the data to make use of a variety of supervisory signals across large data sets – all without relying on labels.
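To make the idea of auto-generated labels concrete, here is a minimal sketch in plain Python. It assumes simple whitespace tokenization and a hypothetical `[MASK]` placeholder; the function name `make_masked_examples` is ours and does not come from any library. Each word of an unlabeled sentence is hidden in turn, and the hidden word becomes the label.

```python
def make_masked_examples(sentence, mask_token="[MASK]"):
    """Turn one unlabeled sentence into (input, label) pairs by hiding one word at a time."""
    words = sentence.split()
    examples = []
    for i, word in enumerate(words):
        # The hidden word is the label; the rest of the sentence is the input.
        masked = words[:i] + [mask_token] + words[i + 1:]
        examples.append((" ".join(masked), word))
    return examples


examples = make_masked_examples("the cat sat on the mat")
for masked_sentence, label in examples:
    print(f"{masked_sentence!r} -> {label!r}")
```

No human annotation is involved: the supervision comes entirely from the structure of the sentence itself.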

What is the difference between self-supervised and unsupervised learning?

Many people confuse the two terms and use them interchangeably. However, the two learning techniques have different objectives.

Self-supervised learning and unsupervised learning can be considered complementary techniques, as neither needs labeled datasets. Unsupervised learning can be considered a superset of self-supervised learning, but in its classic form it has no feedback loop. In contrast, self-supervised learning derives many supervisory signals from the data itself, and these act as feedback in the training process.

An easier way to put it is that unsupervised learning focuses largely on the model rather than the data, whereas self-supervised learning works the other way around. In practice, unsupervised learning methods are good at clustering and dimensionality reduction, while self-supervised learning serves as a pretext method for regression and classification tasks.

Why do we need self-supervised learning?

Self-supervised learning came into existence because of the following issues that persist in other learning procedures:

  • High cost: Labeled data is required by most learning methods. The cost of good-quality labeled data is very high in terms of time and money.
  • Lengthy lifecycle: The data preparation lifecycle is a lengthy process in developing ML models. It requires cleaning, filtering, annotating, reviewing, and restructuring according to the training framework.
  • Generic AI: The self-supervised learning framework is one step closer to embedding human cognition in machines.

Now let’s talk about the utility of self-supervised learning in different domains.

Self-supervised learning applications in Computer Vision

For many years, the focus of learning methods in computer vision has been on perfecting the model architecture while assuming that high-quality data is available. In reality, however, good quality image data is hard to obtain without a high cost in time and effort, which leads to sub-optimally trained models.

Lately, a large part of the research focus has been on developing self-supervised methods in computer vision across different applications. The ability to train models with unlabelled data speeds up the overall training process and empowers the model to learn underlying semantic features without introducing label bias.

For training a self-supervised model there are mainly two stages:

  1. Pretext task

The task we use for pre-training is known as the pretext task. The aim of the pretext task (which is framed as a supervised task on automatically generated labels) is to guide the model to learn intermediate representations of the data. These representations capture the underlying structural meaning that is beneficial for practical downstream tasks.

Generative models can be considered self-supervised models, but with different objectives. For example, the generator in a GAN is trained to produce images realistic enough to fool the discriminator, whereas the aim of self-supervised training is to identify good features that can be used for a variety of tasks, not just to fool the discriminator.

Pretext learning | Source

  2. Downstream tasks

The downstream task is the knowledge transfer process of the pretext model to a specific task. Downstream tasks are typically provided with a smaller quantity of labeled data.

Downstream tasks, also known as target tasks, in the visual domain can be object recognition, object classification, object re-identification, etc., each fine-tuned from the pretext model.

Vanilla SSL approach | Source

Many ideas have been proposed by researchers for different image-based tasks to train using the SSL method. 

Patch localization

Objective: The aim of the pretext task is to identify the relationship between different patches in the image using self-supervised learning. 

Patch localization in image | Source

Training algorithm [Paper]

  1. Sample a random patch from the image.
  2. Nearest Neighbor: Assuming that the first patch is placed in the middle of a 3×3 grid, the second patch is sampled from its 8 neighbouring locations.
  3. Introduce augmentations such as gaps between patches, chromatic aberration, downsampling, and upsampling of patches to handle pixelation and colour jitters. This helps the model not overfit certain low-level signals.
  4. The aim of the task is to identify which of the 8 neighbouring positions is the second patch. The task is framed as a classification problem over 8 classes.
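The sampling steps above can be sketched in plain Python as follows. This is an illustrative sketch, not the authors' implementation: the image is a nested list of pixel values, the colour augmentations are omitted, and the patch size, gap, and position jitter are hypothetical values chosen only for the example.

```python
import random

# (row, col) offsets of the 8 neighbours in a 3x3 grid, in reading order.
# The index into this list is the class label the model has to predict.
NEIGHBOUR_OFFSETS = [(-1, -1), (-1, 0), (-1, 1), (0, -1),
                     (0, 1), (1, -1), (1, 0), (1, 1)]


def crop(image, top, left, size):
    """Cut a size x size patch out of a nested-list image."""
    return [row[left:left + size] for row in image[top:top + size]]


def sample_patch_pair(image, patch=4, gap=1, jitter=1, rng=random):
    """Return (centre patch, neighbouring patch, label in 0..7)."""
    height, width = len(image), len(image[0])
    step = patch + gap          # distance between grid cells, including the gap
    margin = step + jitter      # room needed so every neighbour stays in bounds
    top = rng.randint(margin, height - patch - margin)
    left = rng.randint(margin, width - patch - margin)
    label = rng.randrange(8)
    d_row, d_col = NEIGHBOUR_OFFSETS[label]
    # Jitter the neighbour position slightly so that the model cannot rely
    # on patch borders lining up exactly.
    n_top = top + d_row * step + rng.randint(-jitter, jitter)
    n_left = left + d_col * step + rng.randint(-jitter, jitter)
    return crop(image, top, left, patch), crop(image, n_top, n_left, patch), label


# Toy 24x24 "image" whose pixel value encodes its own position.
image = [[row * 24 + col for col in range(24)] for row in range(24)]
centre, neighbour, label = sample_patch_pair(image, rng=random.Random(0))
```

The pair `(centre, neighbour)` is the model input and `label` is the target for the 8-way classifier.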

While finalizing the pretext task, it is important to make sure the model is not learning trivial patterns instead of the high-level latent features underlying global patterns. For instance, low-level cues like boundary textures between patches can be considered trivial features. For some images, however, there exists another trivial solution. It arises from a camera lens effect called chromatic aberration, which occurs because light of different wavelengths is focused differently.

Convolutional neural networks are capable of learning the relative location of patches by detecting the separation between magenta (blue + red) and green. The nearest-neighbor experiments showed that a few patches retrieved regions from the same absolute location in the image, because those patches displayed similar aberration.

Context-aware pixel prediction

Objective: To predict the pixel values of an unknown patch in the image based on the overall context of the image using encoder-decoders.

Training algorithm [Paper]

  1. The pretext task is trained using a vanilla encoder-decoder architecture.
  2. The encoder (Pathak et al., 2016) produces a latent feature representation of the image using an input image with blacked-out regions.
  3. The decoder uses the latent feature representation from the encoder and estimates the missing image region using the reconstruction loss (MSE).
  4. The channel-wise fully-connected layer between encoder and decoder allows each unit in the decoder to reason about the entire image content.
Context encoder architecture | Source

Loss function

The loss functions used in training were reconstruction loss and adversarial loss.

Reconstruction loss

  • The reconstruction (L2) loss is responsible for capturing the salient features with respect to the context of the full image.
  • The reconstruction loss is defined as the normalized masked L2 distance between the input image x and the output predicted from the masked input.
    • M: Binary mask corresponding to the removed image region, with a value of 1 wherever a pixel was removed and 0 for input pixels.
    • F: The context encoder as a function, which produces the predicted output for the masked input
Reconstruction loss | Source
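A minimal sketch of such a masked L2 distance in plain Python is shown below. It assumes grayscale images stored as nested lists and normalizes by the number of masked pixels; that normalization is our assumption, since the exact constant is not stated here. The function name `masked_l2_loss` is hypothetical.

```python
def masked_l2_loss(image, prediction, mask):
    """Mean squared error over the pixels where mask == 1 (the removed region)."""
    total, count = 0.0, 0
    for image_row, prediction_row, mask_row in zip(image, prediction, mask):
        for x, y, m in zip(image_row, prediction_row, mask_row):
            if m:  # pixels outside the removed region do not contribute
                total += (x - y) ** 2
                count += 1
    return total / count if count else 0.0


image = [[1.0, 2.0], [3.0, 4.0]]
prediction = [[1.0, 0.0], [3.0, 2.0]]
mask = [[0, 1], [0, 1]]          # the right column was removed from the input
loss = masked_l2_loss(image, prediction, mask)
```

Because only the masked pixels are scored, the network gets no credit for simply copying the visible context.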

Adversarial loss

  • The adversarial loss is modelled to make the prediction look real, and to learn the latent space of the input data it is trained on.
  • Only the generator G is conditioned on the input mask, because a conditioned discriminator D would be able to exploit the perceptual discontinuity between the patched regions and the original context.
Adversarial loss | Source

Joint loss

  • The joint loss is obtained by combining the reconstruction and adversarial losses.
  • However, the authors used the adversarial loss only in their inpainting experiments, because training the AlexNet-based architecture diverged with the joint adversarial loss.
Joint loss | Source
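For reference, the three losses can be written out as below. This is a sketch following the formulation in Pathak et al. (2016), using M and F as defined in the reconstruction loss section. The symbol ⊙ denotes element-wise multiplication, and λ_rec and λ_adv are weighting coefficients; these extra symbols are not named in the text above.

```latex
\begin{aligned}
\mathcal{L}_{rec}(x) &= \left\lVert M \odot \bigl(x - F((1 - M) \odot x)\bigr) \right\rVert_2^2 \\
\mathcal{L}_{adv} &= \max_D \; \mathbb{E}_{x \in \mathcal{X}} \bigl[ \log D(x) + \log\bigl(1 - D(F((1 - M) \odot x))\bigr) \bigr] \\
\mathcal{L} &= \lambda_{rec} \, \mathcal{L}_{rec} + \lambda_{adv} \, \mathcal{L}_{adv}
\end{aligned}
```

Note that the discriminator D sees only the real image or the prediction, never the mask, which matches the point that only the generator is conditioned on the context.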

Semantic inpainting is achieved using the SSL method through auxiliary supervision and learning of strong feature representations. Back in 2016, this paper was one of the early pioneers in using the SSL approach in training a competitive image model.

Self-supervised learning applications in Natural Language Processing

Before SSL became part of mainstream computer vision research, it was already responsible for huge strides in the Natural Language Processing (NLP) space. From document processing applications to text suggestion, sentence completion, and more, language models were used almost everywhere.

However, the learning capabilities of these models have evolved since the Word2Vec paper, published in 2013, revolutionized the NLP space. The idea behind word embedding approaches was simple: instead of only asking a model to predict the next word, we can ask it to predict a word based on its surrounding context.

Because of such advancements, we were able to obtain meaningful representation through the distribution of word embeddings that can be used in many scenarios such as sentence completion, word prediction, etc. Today one of the most popular SSL methods used in NLP is BERT.

In the past decade, there has been an inflow of amazing research and development in the field of NLP. Let us distill some of the important ones below.

Next sentence prediction

In Next Sentence Prediction (NSP), we pick two consecutive sentences from a document and a random sentence from the same or a different document, say sentence A, sentence B, and sentence C. Then we ask the model, ‘What is the relative position of sentence A with respect to sentence B?’, and the model outputs either IsNextSentence or IsNotNextSentence. We do this for all combinations.

Consider the following scenario:

  1. After completing the school hours, Mike went home.
  2. After almost 50 years, the manned mission to Moon is finally underway.
  3. Once home, Mike watched Netflix to relax.

If we asked a person to pick any two of these sentences that follow each other logically, they would most likely pick sentence 1 followed by sentence 3.

The main objective of the model here is to predict sentences based on long-term contextual dependencies.

Bidirectional Encoder Representations from Transformers (BERT), introduced in a paper published by researchers at Google AI, has become a gold standard for several NLP tasks such as Natural Language Inference (MNLI), Question Answering (SQuAD), and more.

For such downstream tasks, BERT offers a great way to capture the relationship between sentences, which is not directly captured by standard language modeling. Here’s how it works for Next Sentence Prediction.

BERT | Source
  1. To make BERT handle a variety of downstream tasks, the input representation is able to unambiguously represent a pair of sentences that are packed together in a single sequence. A “sequence” refers to the input token sequence to BERT.
  2. The first token of every sequence is always a special classification token ([CLS]). The final hidden state corresponding to this token is used as the aggregate sequence representation for classification tasks.
  3. We differentiate the sentences in two ways. First, we separate them with a special token ([SEP]). Second, we add a learned embedding to every token indicating whether it belongs to sentence A or sentence B.
  4. We denote the input embedding as E, the final hidden vector of the special [CLS] token as C, and the final hidden vector for the i-th input token as T_i. The vector C is used for Next Sentence Prediction (NSP).

This task can be understood from the following example:
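Below is a minimal sketch of how sentence pairs could be packed into NSP training examples, reusing the Mike sentences from the scenario above. It assumes naive whitespace tokenization (real BERT uses WordPiece), and the helper name `pack_sentence_pair` is hypothetical rather than part of any library.

```python
def pack_sentence_pair(sentence_a, sentence_b):
    """Pack two sentences as [CLS] A [SEP] B [SEP] with segment ids."""
    tokens_a = sentence_a.lower().split()
    tokens_b = sentence_b.lower().split()
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # Segment 0 covers [CLS], sentence A and its [SEP];
    # segment 1 covers sentence B and its [SEP].
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, segment_ids


first = "After completing the school hours, Mike went home."
unrelated = "After almost 50 years, the manned mission to Moon is finally underway."
follow_up = "Once home, Mike watched Netflix to relax."

examples = [
    (*pack_sentence_pair(first, follow_up), "IsNextSentence"),
    (*pack_sentence_pair(first, unrelated), "IsNotNextSentence"),
]
for tokens, segment_ids, label in examples:
    print(label, tokens)
```

The labels come for free from the document order, which is what makes NSP a self-supervised task.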


If you want to utilize the BERT model for this task, you can refer to the Hugging Face documentation on how to do it.

Auto-regressive language modelling

While autoencoding models like BERT utilize self-supervised learning for tasks like sentence classification (next or not), another application of self-supervised approaches lies in the domain of text generation.

Autoregressive models like GPT (Generative Pre-trained Transformer) are pre-trained on the classic language modelling task — predict the next word having read all the previous ones. Such models correspond to the decoder part of the transformer and a mask is used on top of the full sentence so that the attention heads can only see what was before in the text, and not what’s after. 
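The masking idea can be illustrated with a few lines of plain Python. This is only a sketch of the mask itself; in a real transformer the mask is applied to the attention scores before the softmax. The function name `causal_mask` is ours.

```python
def causal_mask(length):
    """Row i marks which positions token i may attend to (1) or not (0)."""
    return [[1 if j <= i else 0 for j in range(length)] for i in range(length)]


for row in causal_mask(4):
    print(row)
```

Each token can attend to itself and to everything on its left, but never to later tokens, so predicting the next word cannot be solved by peeking ahead.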

Let’s take a deeper dive into how these models work by looking at the training framework of GPT.

The training procedure consists of two stages:

  1. Unsupervised Pre-training

The first stage is learning a high-capacity language model on a large corpus of text. 

Given an unsupervised corpus of tokens U = {u1, . . . , un}, we use a standard language modelling objective to maximize the following likelihood: 

$$L_1(\mathcal{U}) = \sum_i \log P(u_i \mid u_{i-k}, \ldots, u_{i-1}; \Theta)$$

where k is the size of the context window, and the conditional probability P is modelled using a neural network with parameters Θ. These parameters are trained using stochastic gradient descent. 
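As a toy illustration of the language modelling objective described above, the sketch below sums log-probabilities over a token sequence with a context window of size k. The probability function is a stand-in for the neural network (here a uniform distribution over a four-word vocabulary), and all names are hypothetical.

```python
import math


def lm_log_likelihood(tokens, k, prob):
    """Sum of log P(u_i | up to k previous tokens) over the whole sequence."""
    total = 0.0
    for i, token in enumerate(tokens):
        # The context is truncated at the start of the sequence.
        context = tuple(tokens[max(0, i - k):i])
        total += math.log(prob(token, context))
    return total


vocab = ["a", "b", "c", "d"]


def uniform(token, context):
    """Stand-in model: every token is equally likely, whatever the context."""
    return 1.0 / len(vocab)


score = lm_log_likelihood(["a", "b", "a", "c"], k=2, prob=uniform)
```

Training raises this score by adjusting the model parameters so that the tokens actually observed receive higher probability.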

The model being trained here is a multi-layer transformer decoder for the language model, which is a variant of the transformer. This model applies a multi-headed self-attention operation over the input context tokens followed by position-wise feedforward layers to produce an output distribution over target tokens: 

$$h_0 = U W_e + W_p$$
$$h_l = \mathrm{transformer\_block}(h_{l-1}) \quad \forall\, l \in [1, n]$$
$$P(u) = \mathrm{softmax}(h_n W_e^T)$$

where U = (u_{−k}, …, u_{−1}) is the context vector of tokens, n is the number of layers, W_e is the token embedding matrix, and W_p is the position embedding matrix. This constrained self-attention, where every token can attend only to the context on its left, is what brings the self-supervised approach into the picture.

  2. Supervised fine-tuning

In this step, we assume a labelled dataset C, where each instance consists of a sequence of input tokens, x^1, …, x^m, along with a label y. The inputs are passed through our pre-trained model to obtain the final transformer block’s activation h_l^m, which is then fed into an added linear output layer with parameters W_y to predict y:

$$P(y \mid x^1, \ldots, x^m) = \mathrm{softmax}(h_l^m W_y)$$

This gives us the following objective to maximize: 

$$L_2(\mathcal{C}) = \sum_{(x, y)} \log P(y \mid x^1, \ldots, x^m)$$

Including language modelling as an auxiliary objective during fine-tuning helped learning by improving generalization of the supervised model and accelerating convergence. Specifically, we optimize the following objective (with weight λ):

$$L_3(\mathcal{C}) = L_2(\mathcal{C}) + \lambda \cdot L_1(\mathcal{C})$$

where L_2 is the supervised fine-tuning objective and L_1 is the language modelling objective from pre-training, both evaluated on C. Overall, the only extra parameters we require during fine-tuning are W_y and the embeddings for delimiter tokens.

(left) Transformer architecture and training objectives used in this work. (right) Input transformations for fine-tuning on different tasks | Source

In the above image, the Transformer architecture and training objectives are on the left, and the input transformations for fine-tuning on different tasks are on the right. We convert all structured inputs into token sequences to be processed by our pre-trained model, followed by a linear+softmax layer. Different tasks require different processing. For textual entailment, for example, we concatenate the token sequences of the premise p (the entailing text) and the hypothesis h (the entailed text), with a delimiter token ($) in between.
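A minimal sketch of the textual entailment transformation might look like this. The start and extract token strings are hypothetical placeholders, whitespace splitting stands in for the real tokenizer, and only the `$` delimiter is taken from the description above.

```python
# Hypothetical special-token strings; only "$" comes from the text above.
START, DELIM, EXTRACT = "<s>", "$", "<e>"


def entailment_input(premise, hypothesis):
    """Build one token sequence: start, premise, delimiter, hypothesis, extract."""
    return [START] + premise.split() + [DELIM] + hypothesis.split() + [EXTRACT]


sequence = entailment_input("a man sleeps", "a person rests")
print(sequence)
```

Because every task is reduced to a single token sequence, the same pre-trained model can be reused with only a new output layer on top.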

There have been many iterations of improvements over the original GPT model and to understand how you can use it for your own use cases, you can refer to this page.

Self-supervised learning applications: industrial case studies

So far we have talked about how popular models have been trained using self-supervised approaches and how you can train one yourself or use one from the available libraries.

Now let’s take a look at how the industry is leveraging this technique to solve critical problems.

1. Hate-speech detection at Facebook

“We believe that self-supervised learning (SSL) is one of the most promising ways to build background knowledge and approximate a form of common sense in AI systems.”

AI Scientists, Facebook

Facebook is not just advancing self-supervised learning techniques across many domains through fundamental, open scientific research, but they are also applying this leading-edge work in production to quickly improve the accuracy of content understanding systems in their products that keep people safe on their platforms.

One such example is XLM, Facebook AI’s method of training language systems across multiple languages without relying on hand-labeled datasets to improve hate speech detection.

Hate-speech detection at Facebook | Source

This application of self-supervised learning has made their models more robust and their platforms much safer. Let’s talk briefly about what XLM is and how it was able to make such a difference.


The model

It is a Transformer-based architecture that is pre-trained using one of three language modelling objectives:

  1. Causal Language Modelling (CLM): to model the probability of a word given the previous words in a sentence, i.e. P(w_t | w_1, …, w_{t−1}, θ).
  2. Masked Language Modelling (MLM): the masked language modelling objective of BERT, i.e. masking randomly chosen tokens with the [MASK] keyword and trying to predict them.
  3. Translation Language Modelling (TLM): a new addition and an extension of MLM, where instead of considering monolingual text streams, parallel sentences are concatenated as illustrated in the following image. Words in both the source and target sentences are masked. To predict a word masked in an English sentence, the model can either attend to surrounding English words or to the French translation, encouraging the model to align the English and French representations. The model can also leverage the French context if the English one is not sufficient to infer the masked English words.
Cross-lingual language model pretraining | Source
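To illustrate TLM, the sketch below concatenates a parallel sentence pair, tags every token with its language, restarts the position index for the target sentence, and masks tokens at random on both sides. It is a simplified illustration rather than the XLM implementation: tokens are whole words, every selected token is replaced by `[MASK]`, and the function name and default mask probability are our assumptions.

```python
import random


def make_tlm_example(source, target, source_lang="en", target_lang="fr",
                     mask_prob=0.15, rng=random):
    """Return (inputs, languages, positions, labels) for one parallel pair."""
    tokens = list(source) + list(target)
    languages = [source_lang] * len(source) + [target_lang] * len(target)
    # Positions restart at 0 for the target sentence.
    positions = list(range(len(source))) + list(range(len(target)))
    inputs, labels = [], {}
    for index, token in enumerate(tokens):
        if rng.random() < mask_prob:
            inputs.append("[MASK]")
            labels[index] = token   # the model must recover the original token
        else:
            inputs.append(token)
    return inputs, languages, positions, labels


src = ["the", "cat", "sleeps"]
tgt = ["le", "chat", "dort"]
inputs, languages, positions, labels = make_tlm_example(
    src, tgt, mask_prob=0.5, rng=random.Random(0))
```

When an English word is masked but its French counterpart is visible, the model can use the translation as a cue, which is what encourages aligned representations across languages.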

Thus, XLM is a cross-lingual language model whose pretraining can be done with CLM, MLM, or MLM used in combination with TLM. Now let’s take a look at the benefits that XLM brings to the table.

Performance analysis

  1. Cross-lingual classification

XLM provides a better initialization of sentence encoders for zero-shot cross-lingual classification and was able to achieve state-of-the-art (SOTA) performance, obtaining 71.5% accuracy on this task with the MLM objective. Combining MLM and TLM improves the performance even further to 75.1%.

  2. Machine Translation Systems

Similar to the first point, it provides a better initialization of supervised and unsupervised neural machine translation systems. Pre-training with MLM objective showed significant improvements in the case of unsupervised systems while the same objective led to SOTA performance in supervised systems with a BLEU score of 38.5.

  3. Language models for low-resource languages

For low-resource languages, it is often beneficial to leverage data in similar but higher-resource languages, especially when they share a significant fraction of their vocabularies. XLM was found to improve the Nepali language model (a low-resource language) by utilizing information from Hindi (a relatively popular one with considerable resources), as they share the same Devanagari script.

  4. Unsupervised cross-lingual word embeddings

XLM outperformed previous works on cross-lingual word embeddings by reaching a SOTA level Pearson correlation score of 0.69 between source words and their translation.

With such advancements, XLM has indeed made a difference in Natural Language Processing.

2. Google’s medical imaging analysis model

In the medical domain, training deep learning models has always been a difficult task owing to limited labeled data and the time-consuming and expensive nature of annotating such data. To tackle this problem Google’s Research Team introduced a novel Multi-Instance Contrastive Learning (MICLe) method that uses multiple images of the underlying pathology per patient case, to construct more informative positive pairs for self-supervised learning. 

Google’s medical imaging analysis model | Source

A couple of things to keep in mind about the illustrated approach:

  • Step 1 is carried out using SimCLR, another framework designed by Google for self-supervised representation learning on images. We will discuss it shortly.
  • Unlike step (1), steps (2) and (3) are task- and dataset-specific.

So let’s take it step by step.

Step 1: the SimCLR framework

SimCLR stands for A Simple Framework for Contrastive Learning of Visual Representations. It significantly advances the state of the art in self-supervised and semi-supervised learning and achieved a new record for image classification with a limited amount of class-labeled data.

  • SimCLR first learns generic representations of images on an unlabelled dataset, and then it can be fine-tuned with a small amount of labelled images to achieve good performance for a given classification task (just like the medical imaging task).
  • The generic representations are learned by simultaneously maximizing agreement between differently transformed views of the same image and minimizing agreement between transformed views of different images, following a method called contrastive learning. Updating the parameters of a neural network using this contrastive objective causes representations of corresponding views to “attract” each other, while representations of non-corresponding views “repel” each other.
  • To begin, SimCLR randomly draws examples from the original dataset, transforming each example twice using a combination of simple augmentations, creating two sets of corresponding views.
  • It then computes the image representation using a CNN, based on ResNet architecture.
  • Finally, SimCLR computes a non-linear projection of the image representation using a fully-connected network (i.e., MLP), which amplifies the invariant features and maximizes the ability of the network to identify different transformations of the same image.
SimCLR framework | Source
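The “attract” and “repel” behaviour described above can be sketched with a contrastive loss over cosine similarities, in the spirit of the NT-Xent loss used by SimCLR. The code below is a pure-Python illustration on tiny hand-written vectors, not the reference implementation; the temperature value and function names are assumptions made for the example.

```python
import math


def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def nt_xent_loss(views_a, views_b, temperature=0.5):
    """views_a[i] and views_b[i] are projections of two views of image i."""
    z = views_a + views_b
    n = len(views_a)
    total = 0.0
    for i in range(2 * n):
        positive = (i + n) % (2 * n)          # index of the matching view
        pos_logit = cosine(z[i], z[positive]) / temperature
        # Every other view in the batch, including the positive, competes
        # in the denominator; the anchor itself is excluded.
        logits = [cosine(z[i], z[j]) / temperature
                  for j in range(2 * n) if j != i]
        total += math.log(sum(math.exp(l) for l in logits)) - pos_logit
    return total / (2 * n)


views_a = [[1.0, 0.0], [0.0, 1.0]]
views_b = [[1.0, 0.1], [0.1, 1.0]]
aligned = nt_xent_loss(views_a, views_b)
mismatched = nt_xent_loss(views_a, list(reversed(views_b)))
```

The loss is lower when matching views point in similar directions (`aligned`) than when the pairing is scrambled (`mismatched`), which is the signal that pulls corresponding views together.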

The trained model not only does well at identifying different transformations of the same image but also learns representations of similar concepts (e.g., chairs vs. dogs), which later can be associated with labels through fine-tuning.

Step 2: MICLe

After the initial pre-training with SimCLR on unlabelled natural images is complete, the model is trained to capture the special characteristics of medical image datasets. This, too, can be done with SimCLR, but this method constructs positive pairs only through augmentation and does not readily leverage patients’ metadata for positive pair construction. Hence MICLe is used here. 

  • Given multiple images of a given patient case, MICLe constructs a positive pair for self-supervised contrastive learning by drawing two crops from two distinct images from the same patient case. Such images may be taken from different viewing angles and show different body parts with the same underlying pathology.
  • This presents a great opportunity for self-supervised learning algorithms to learn representations that are robust to changes of viewpoint, imaging conditions, and other confounding factors in a direct way. 
MICLe | Source
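The pair construction can be sketched as follows. This is a hypothetical illustration: images are represented by string identifiers, `toy_augment` stands in for a real crop and augmentation pipeline, and the fallback to two augmentations of the same image for single-image cases is our assumption about how such cases would be handled.

```python
import random


def micle_positive_pair(case_images, augment, rng=random):
    """Build one positive pair from the images of a single patient case."""
    if len(case_images) >= 2:
        # Two distinct images of the same case form the positive pair.
        first, second = rng.sample(case_images, 2)
    else:
        # Assumed fallback: two augmentations of the only available image.
        first = second = case_images[0]
    return augment(first, rng), augment(second, rng)


def toy_augment(image_id, rng):
    """Stand-in for a random crop: returns the image id and a crop offset."""
    return (image_id, rng.randint(0, 9))


rng = random.Random(0)
multi_pair = micle_positive_pair(["img_a", "img_b", "img_c"], toy_augment, rng)
single_pair = micle_positive_pair(["only"], toy_augment, rng)
```

Unlike augmentation-only pairs, the two halves of `multi_pair` come from different photographs, so the model has to learn features that survive real changes in viewpoint and imaging conditions.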

Step 3: Fine-tuning

  • The model is trained end-to-end during fine-tuning, using the weights of the pre-trained network as initialization for the downstream supervised task dataset.
  • For data augmentation during fine-tuning, random colour augmentation, cropping with resizing, blurring, rotation and flipping were done for the images in both tasks (Dermatology and Chest X-Rays).
  • For every combination of pretraining strategy and downstream fine-tuning task, an extensive hyperparameter search was performed.

Performance analysis

  1. Self-supervised learning utilizes unlabelled domain-specific medical images and significantly outperforms supervised ImageNet pre-training.
Comparison of supervised and self-supervised pre-training, followed by supervised fine-tuning using two architectures on dermatology and chest X-ray classification | Source
  2. Self-supervised pre-trained models can generalize better to distribution shifts with MICLe pre-training leading to the most gains. This is a valuable finding, as generalization under distribution shift is of paramount importance to clinical applications.
Evaluation of models on distribution-shifted datasets | Source
  3. Pre-training using self-supervised models can compensate for low label efficiency for medical image classification, and across the sampled label fractions, self-supervised models consistently outperform the supervised baseline. In fact, MICLe is able to match baselines using only 20% of the training data for ResNet-50 (4x) and 30% of the training data for ResNet-152 (2x).
Top-1 accuracy for dermatology condition classification for MICLe, SimCLR, and supervised models under different unlabeled pretraining dataset and varied sizes of label fractions | Source

Challenges in self-supervised learning

So far we have talked about how self-supervised learning is making strides in almost every sphere of the Machine Learning community but it has some drawbacks too. Self-supervised learning is trying to achieve the ‘one method solves all’ approach but it is far from that realization. Some of the key challenges in the SSL space are:

  • Accuracy: Although the premise of the SSL technique is to not use labelled data, the downside to that approach is you either need huge amounts of data to generate accurate pseudo labels or compromise on accuracy. It is important to note that inaccurate labels generated will be counterproductive while training in the initial steps.
  • Computational Efficiency: Due to multiple stages of training (1. generating pseudo labels, 2. training on pseudo labels), the time taken to train a model is high compared to supervised learning. Also, current SSL approaches require a huge amount of data to achieve accuracy close to their supervised learning counterparts.
  • Pretext Task: It is very important to choose the right pretext task for your use case. For instance, if you choose an autoencoder as your pretext task where the image is compressed and then regenerated, it will also try to mimic the noise of the original image and if your task is generating high-quality images, this pretext task will do more harm than good.

Key takeaways

In this article, we learned what self-supervised learning is, why it is gaining traction, and what risks and challenges are associated with it. We also discussed popular models that have been trained using this approach and took a deep dive into how self-supervised learning is being leveraged by big tech companies to solve some really pressing issues.

To sum up what we have learned so far: 

  • Self-supervised learning is a blessing in use-cases where we deal with challenges related to data. It can range from having low resources for dataset preparation to time-consuming annotation problems.
  • Another area where it shines is downstream tasks, i.e. transfer learning. Models can be pre-trained in a self-supervised manner on unlabelled datasets and then fine-tuned for specific use-cases.
  • As a result of the first two points, self-supervised learning is a strong candidate if you want to build a scalable ML model.
  • At the same time, however, one must be aware of the strings attached to this approach.

While we have tried to cover a lot in this article, it is obvious that what we have discussed isn’t exhaustive. There is still a ton to learn about self-supervised learning. If you want to learn more about its present as well as potential use-cases, you can refer to the following material:

Happy Learning!

