MLOps Blog

Implementing Content-Based Image Retrieval With Siamese Networks in PyTorch

30th August, 2023

Image retrieval is the task of finding images related to a given query. By content-based image retrieval, we mean the task of finding images that contain certain attributes which are not in the image metadata but are present in its visual content.

In this post we:

  • explain the theoretical concepts behind content-based image retrieval,
  • show step by step how to build a content-based image retrieval system with PyTorch, addressing a specific application: finding face images with a given set of face attributes (e.g. male, blond, smiling).

Concepts explained that might be of interest:

Ranking Loss, Contrastive Loss, Siamese Nets, Triplet Nets, Triplet Loss, Image Retrieval


Content-based image retrieval: how to build it at a high level

In order to find the images closest to a given query, an image retrieval system needs to:

  • compute a similarity score between the query and every image in the test set (often called the retrieval set),
  • rank all those images by their similarity to the query,
  • return the top ones.

A common strategy to learn those similarities is to learn representations (often called embeddings) of images and queries in the same vectorial space (often called embedding space). 

In our example, that would be learning embeddings of face images and vectors encoding face attributes in the same space.

[Figure: CelebA high-level retrieval]

Neural networks are used to learn the aforementioned embeddings. In our case, a Convolutional Neural Network (CNN) is used to learn the image embeddings, and a Multilayer Perceptron (MLP), which is a set of fully connected layers, is used to learn the attribute vector embeddings.

These networks are set up in a siamese fashion and trained with a ranking loss (triplet loss in our case). We explain these concepts in detail next.

Architectures and losses

Ranking losses: triplet loss

Ranking losses aim to learn relative distances between samples, a task which is often called metric learning.

To do so, they compute a distance (e.g. Euclidean distance) between sample representations and optimize the model to minimize it for similar samples and maximize it for dissimilar samples. Therefore, the model ends up learning similar representations for the samples you have defined as similar, and distant representations for the samples you have defined as dissimilar.

Triplet Loss is the most commonly used ranking loss. It works with triplets of samples which consist of:

  • an anchor sample (the reference sample of the triplet).
  • a positive sample (which is similar to the anchor).
  • a negative sample (which is dissimilar to the anchor). 

Triplet Loss optimizes the model such that the distance between the anchor and negative sample representations is larger than the distance between the anchor and positive sample representations by at least a margin.

The purpose of the margin is that, once the model sufficiently distinguishes between the positive and the negative samples of a triplet, it does not waste effort enlarging that separation and can focus on more difficult triplets.

In other words, the margin establishes when the network's performance on a given triplet is already optimal.

The triplet loss is formally defined as:

$$L(a, p, n) = \max\left(0,\; m + d(a, p) - d(a, n)\right)$$

Where:

  • d is the distance function used (e.g. Euclidean distance),
  • m is the margin, 
  • a, p and n are the representations of the anchor, the positive, and the negative samples respectively.
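As a quick sanity check of the formula, here is a minimal sketch in plain Python, assuming Euclidean distance for d and toy 2-D vectors chosen only for illustration (the helper names `euclidean` and `triplet_loss` are hypothetical):

```python
import math

def euclidean(u, v):
    """Euclidean distance between two equal-length sequences."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(u, v)))

def triplet_loss(a, p, n, margin):
    """L(a, p, n) = max(0, m + d(a, p) - d(a, n)) with Euclidean d."""
    return max(0.0, margin + euclidean(a, p) - euclidean(a, n))

# Toy embeddings: positive at distance 1.0 from the anchor, negative at 2.0.
a, p, n = (0.0, 0.0), (1.0, 0.0), (0.0, 2.0)
print(triplet_loss(a, p, n, margin=0.5))  # 0.0: d(a,n) already exceeds d(a,p) + m
print(triplet_loss(a, p, n, margin=1.5))  # 0.5: the margin is not yet satisfied
```

The same triplet gives zero loss or a positive loss depending only on the margin, which is why the margin decides when a triplet stops contributing gradients.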

In our example of retrieving faces by attributes:

  • the anchor sample is an image,
  • the positive sample is a vector encoding its attributes,
  • the negative sample is a vector with attributes that do not match.

Therefore, the model will be optimized to embed each image close to its attribute vector in the embedding space, and far from other attribute vectors. In the testing phase, when you query with an attribute vector to retrieve images containing those attributes, you will find images embedded close to it.

[Figure: CelebA high-level pipeline]

Easy, hard, and semi-hard triplets

Depending on the distances between triplet samples, we can have three different types of triplets during the loss computation:

  • Easy triplets: The negative sample is sufficiently far from the anchor compared to the positive one. The loss is 0, and so are the gradients.

$$d(a,n) > d(a,p) + m$$

  • Hard triplets: The negative sample representation is closer to the anchor representation than the positive one. The net is not correctly distinguishing between the positive and negative samples of this triplet.

$$d(a,n) < d(a,p)$$

  • Semi-hard triplets: The negative sample representation is farther from the anchor than the positive one, but the difference between the distances is smaller than the margin. Therefore, the net still has to push them further apart.

$$d(a,p) < d(a,n) < d(a,p) + m$$

[Figure: negatives mining]
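The three cases can be told apart directly from the embedding distances. The sketch below is a hypothetical helper (`triplet_types` is not part of any library), assuming batched PyTorch embeddings and Euclidean distance computed with `torch.nn.functional.pairwise_distance`:

```python
import torch
import torch.nn.functional as F

def triplet_types(anchor, positive, negative, margin):
    """Label each triplet in a batch as 'easy', 'semi-hard' or 'hard'.

    anchor, positive, negative: float tensors of shape (batch, dim).
    """
    d_ap = F.pairwise_distance(anchor, positive)  # Euclidean distance per triplet
    d_an = F.pairwise_distance(anchor, negative)
    labels = []
    for ap, an in zip(d_ap.tolist(), d_an.tolist()):
        if an > ap + margin:
            labels.append("easy")       # loss is 0
        elif an > ap:
            labels.append("semi-hard")  # correct order, but inside the margin
        else:
            labels.append("hard")       # negative at least as close as the positive
    return labels

anchor = torch.zeros(3, 2)
positive = torch.tensor([[1.0, 0.0]] * 3)
negative = torch.tensor([[3.0, 0.0], [1.2, 0.0], [0.5, 0.0]])
print(triplet_types(anchor, positive, negative, margin=0.5))
```

Ties (equal distances) are assigned to the harder category here; that choice is arbitrary.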

Siamese Network

Ranking losses are often used with Siamese network architectures. Siamese networks are neural networks that share parameters, that is, they share weights. In practice, this means that during training we optimize a single neural network even though it processes different samples.

In our face images by attributes retrieval example, each triplet contains an image (anchor), and two attribute vectors (positive and negative). 

The networks processing the attribute vectors will be siamese, which means that we'll use the same network to forward both of them. This is because we want to learn the same feature extractor for both positive and negative attribute vectors.

[Figure: siamese network]
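To see what weight sharing means in code, here is a minimal sketch with a single `nn.Linear` layer standing in for the attribute MLP (the layer size and the toy objective are illustrative assumptions). Both attribute vectors go through the same module, so gradients from both branches accumulate in the same parameters:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy stand-in for the attribute MLP; the sizes here are illustrative only.
shared_mlp = nn.Linear(40, 8)

att_positive = torch.randint(0, 2, (4, 40)).float()
att_negative = torch.randint(0, 2, (4, 40)).float()

# Both branches call the SAME module: there is a single set of weights.
pos_emb = shared_mlp(att_positive)
neg_emb = shared_mlp(att_negative)
print(sum(p.numel() for p in shared_mlp.parameters()))  # 328 = 40*8 weights + 8 biases

# A toy objective that depends on both branches.
(pos_emb.sum() - neg_emb.sum()).backward()

# The shared weight's gradient is the sum of both branches' contributions:
# each row receives sum(att_positive) - sum(att_negative) over the batch.
expected_grad = (att_positive.sum(dim=0) - att_negative.sum(dim=0)).expand(8, 40)
print(torch.allclose(shared_mlp.weight.grad, expected_grad))
```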

Triplet Loss in PyTorch

PyTorch provides an implementation of the Triplet Loss called Triplet Margin Loss (nn.TripletMarginLoss).

[Figure: triplet margin loss]

The documentation uses the same nomenclature as this article. By default, it uses the Euclidean distance (p=2) to compute distances between the input tensors. You can switch to other distances, either through the p argument or with nn.TripletMarginWithDistanceLoss, but in my experience it doesn't make much of a difference.

You can instantiate it by importing PyTorch's neural network module (torch.nn) and setting the margin:

import torch.nn as nn
criterion = nn.TripletMarginLoss(margin=0.1)

The exact value of the margin isn't critical for the model to converge.

That said, you need to make sure that the differences between the distances computed by the loss are on the order of the margin you set. A good practice is to set the margin such that, at the beginning of training, about half of the triplets are correct by chance and half are not. A commonly used value is 0.1.

To compute the loss in the training loop after the model has computed the representations for the triplet samples, you should call:

loss = criterion(anchor, positive, negative)
loss.backward()

Always pass the triplet sample representations in that order: anchor, positive, negative.
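If you want to convince yourself that `nn.TripletMarginLoss` computes the formula given earlier, a sketch like the following compares it with a hand-written version on random tensors (the batch size of 16 and the 300-D size are arbitrary). The hand-written version uses `F.pairwise_distance` with its default settings, which is how I understand the built-in loss computes distances, so the two should agree up to floating-point tolerance:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
anchor = torch.randn(16, 300, requires_grad=True)
positive = torch.randn(16, 300, requires_grad=True)
negative = torch.randn(16, 300, requires_grad=True)

criterion = nn.TripletMarginLoss(margin=0.1)  # p=2 (Euclidean) by default
loss = criterion(anchor, positive, negative)

# max(0, m + d(a,p) - d(a,n)), averaged over the batch (default reduction='mean')
manual = torch.clamp(
    0.1 + F.pairwise_distance(anchor, positive) - F.pairwise_distance(anchor, negative),
    min=0,
).mean()

print(torch.allclose(loss, manual))  # expected: True
loss.backward()  # gradients flow back to all three inputs
```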

Using triplet loss in PyTorch for face image retrieval

Goal: find face images with certain attributes

The goal in this example task is to find face images with certain attributes. To do that, we will use a dataset of face images with annotated binary attributes.  

The dataset used is the Large-scale CelebFaces Attributes (CelebA) Dataset which contains around 200k celebrity face images with 40 annotated binary attributes. Attributes are encoded as 40-dimensional multi-hot vectors, which contain ones to indicate the positive face image attributes and zeros for the rest.
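A query is simply another multi-hot vector. A minimal sketch of building one from attribute names, assuming a hypothetical, truncated `ATTRIBUTE_NAMES` list (the real dataset defines 40 attributes in a fixed order, which you would take from its annotation file):

```python
import torch

# Hypothetical, truncated attribute list for illustration only.
ATTRIBUTE_NAMES = ["Bags_Under_Eyes", "Blond_Hair", "Male", "Mustache", "Smiling"]

def multi_hot(present, attribute_names=ATTRIBUTE_NAMES):
    """Return a float multi-hot vector with 1.0 at each attribute listed in `present`."""
    index = {name: i for i, name in enumerate(attribute_names)}
    vec = torch.zeros(len(attribute_names))
    for name in present:
        vec[index[name]] = 1.0
    return vec

query = multi_hot(["Male", "Blond_Hair", "Smiling"])
print(query)  # tensor([0., 1., 1., 0., 1.])
```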

To give you a feel for what the annotated attributes are, let me list a few:

  • Bags Under Eyes
  • Pointy Nose
  • Mustache
  • Wearing Earrings
  • Wearing Necklace
  • Oval Face

The resulting image retrieval system aims to find face images with a given combination of those attributes. Here you can see some examples of the images in the dataset:

How to create the training triplets

The triplets to train this system to retrieve face images given face attributes will be composed of:

  1. An image (the anchor),
  2. Its attribute vector (the positive),
  3. A negative attribute vector (the negative).

Therefore, to create the anchor and positive samples of a triplet, you just need to load an image and its attributes and put them in tensors.

Next, you need to select the negative attributes for our triplet. You could just:

  • Create a random multi-hot 40-D vector,
  • Check that it is different from the positive vector,
  • Use it as the negative.

But then you would end up with too many easy triplets, since many random attribute vectors are unlikely combinations, e.g. ("woman" + "blond" + "with beard"). It would be easy for the model to tell which one is the positive attribute vector, and it would not learn to discriminate properly between likely attribute vectors.

A common strategy, which preserves the dataset statistics, is to:

  • Sample a random attribute vector from the training data,
  • Check that it is different from our positive vector,
  • Use it as the negative.

If self.attribute_vectors contains all the training set attribute vectors (as NumPy arrays):

import random
import numpy as np

while True:
    att_negative = random.choice(self.attribute_vectors)
    if not np.array_equal(att_negative, att_positive):
        break

A strategy to create hard negatives, which can be combined with the previous one, is to change some attributes from the positive attributes vector. As an example:

import random
import numpy as np

num_changes = random.randint(1, 3)
att_negative = np.copy(att_positive)

# Flip num_changes distinct attributes (random.sample never picks the same index twice)
for att_idx in random.sample(range(len(att_negative)), num_changes):
    att_negative[att_idx] = 1 - att_positive[att_idx]

Training with these hard negative triplets forces the model to discriminate between attribute vectors that are similar in all attributes but 1, 2 or 3.

[Figure: CelebA triplets]
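The two negative-sampling strategies described above can be mixed in a single sampler. This is a sketch under assumptions: `sample_negative` and its `hard_prob` mixing probability are hypothetical, attribute vectors are 0/1 NumPy arrays, and the training set is assumed to contain at least one vector different from the positive (otherwise the rejection loop would never end):

```python
import random
import numpy as np

def sample_negative(att_positive, attribute_vectors, hard_prob=0.5, max_changes=3, rng=random):
    """Return a negative attribute vector for `att_positive`.

    With probability `hard_prob`, flip 1..max_changes distinct attributes of the
    positive (hard negative). Otherwise draw a different vector from the training
    set, which preserves the dataset statistics.
    """
    if rng.random() < hard_prob:
        att_negative = np.copy(att_positive)
        num_changes = rng.randint(1, max_changes)
        for idx in rng.sample(range(len(att_positive)), num_changes):
            att_negative[idx] = 1 - att_positive[idx]
        return att_negative
    while True:  # rejection sampling from the training vectors
        att_negative = rng.choice(attribute_vectors)
        if not np.array_equal(att_negative, att_positive):
            return att_negative
```

Tuning `hard_prob` is one way to control the share of hard triplets per batch without losing the realistic negatives drawn from the data.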

How to design the model and train it

You need a neural network architecture that learns image and attribute vector embeddings in the same vectorial space (embedding space), such that we can compute distances between them with the triplet loss. 

The first thing you need to do is to choose the dimensionality of that embedding space. That depends on the complexity of the similarities you want to learn, but typical dimensionalities range from 100 to 500. We will choose 300.

The following figure shows the architecture we'll use:

To learn the image embeddings, we use a CNN (here, ResNet-50) with the same number of outputs as the embedding space dimensionality. That CNN takes as input the image tensor (img_a, the anchor of our triplet) [224x224x3] and outputs a 300-D vector (a).

To learn the attributes vector embeddings, we use an MLP consisting of two 300-D linear layers with batch normalization and ReLU activations. Both the positive and negative samples of our triplets are attribute vectors, and we will use the same layers to process them, which is known as a siamese architecture (shown in yellow in the figure). 

Once the embeddings of the image and the attribute vectors are computed, we L2-normalize them so that they all have the same (unit) Euclidean norm, and feed them to the Triplet Loss to compute the distances between them. In PyTorch, we can write our training loop as:

import torch.nn.functional as F

# L2-normalize the embeddings (dim=1 is the feature dimension)
anchor = F.normalize(self.CNN(img_anchor), dim=1)
positive = F.normalize(self.MLP(att_positive), dim=1)
negative = F.normalize(self.MLP(att_negative), dim=1)
loss = criterion(anchor, positive, negative)
optimizer.zero_grad()
loss.backward()
optimizer.step()

Note that att_positive and att_negative are processed by the same MLP: we use siamese networks for them.
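To make the training step runnable end to end, here is a self-contained sketch. The layer order inside the attribute MLP (Linear, BatchNorm, ReLU, Linear) is an assumption about the two-layer MLP described above, and `TinyCNN` is a hypothetical small stand-in for ResNet-50 so the example runs quickly on random data. In a real setup you would use `torchvision.models.resnet50()` with its `fc` layer replaced by `nn.Linear(2048, 300)`; the optimizer and learning rate are arbitrary choices as well:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

EMB_DIM = 300
NUM_ATTRIBUTES = 40

class AttributeMLP(nn.Module):
    """Two 300-D linear layers with batch norm and ReLU (layer order is an assumption)."""
    def __init__(self, in_dim=NUM_ATTRIBUTES, emb_dim=EMB_DIM):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, x):
        return self.net(x)

class TinyCNN(nn.Module):
    """Hypothetical small stand-in for ResNet-50, only to keep the sketch fast."""
    def __init__(self, emb_dim=EMB_DIM):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.fc = nn.Linear(16, emb_dim)

    def forward(self, x):
        return self.fc(self.features(x))

def train_step(cnn, mlp, criterion, optimizer, img_anchor, att_positive, att_negative):
    """One optimization step on a batch of (image, positive attrs, negative attrs) triplets."""
    anchor = F.normalize(cnn(img_anchor), dim=1)      # L2-normalized image embeddings
    positive = F.normalize(mlp(att_positive), dim=1)  # the same MLP embeds both attribute
    negative = F.normalize(mlp(att_negative), dim=1)  # vectors: the siamese branch
    loss = criterion(anchor, positive, negative)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

torch.manual_seed(0)
cnn, mlp = TinyCNN(), AttributeMLP()
criterion = nn.TripletMarginLoss(margin=0.1)
optimizer = torch.optim.Adam(list(cnn.parameters()) + list(mlp.parameters()), lr=1e-3)

imgs = torch.randn(8, 3, 64, 64)  # random data; real inputs would be 224x224 face images
att_pos = torch.randint(0, 2, (8, NUM_ATTRIBUTES)).float()
att_neg = torch.randint(0, 2, (8, NUM_ATTRIBUTES)).float()
loss = train_step(cnn, mlp, criterion, optimizer, imgs, att_pos, att_neg)
print(loss)
```

The image CNN and the attribute MLP have separate weights, while the MLP is shared by the positive and negative branches; a single optimizer updates both networks.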

How to monitor the training 

When using a Triplet Loss to train an image retrieval model, it is harder to monitor training than in other scenarios, such as training a net for image classification. That's because evaluating image retrieval requires computing the embeddings of the whole retrieval set.

Besides directly monitoring the training and validation losses, which of course will show if the model is properly learning, it is useful to monitor the number of correct triplets per batch, during both training and validation. 

The % of correct triplets (easy triplets) per batch should increase during learning, as the loss decreases. That % also gives you a hint of how hard the triplets are for the net.

If a high percentage of the triplets are already correct, the net is not learning anything from them, so forwarding them is a computational waste, and maybe you should modify your triplets creation pipeline to create harder triplets.
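Counting correct triplets costs only two distance computations per batch. A minimal sketch of an accumulator, assuming Euclidean distance and treating a triplet as correct when d(a,n) > d(a,p) + m, i.e. when its loss is zero (`CorrectTripletMeter` is a hypothetical name):

```python
import torch
import torch.nn.functional as F

class CorrectTripletMeter:
    """Tracks the fraction of triplets with d(a,n) > d(a,p) + margin."""

    def __init__(self, margin):
        self.margin = margin
        self.correct = 0
        self.total = 0

    @torch.no_grad()
    def update(self, anchor, positive, negative):
        """Add one batch of embeddings; return that batch's fraction of correct triplets."""
        d_ap = F.pairwise_distance(anchor, positive)
        d_an = F.pairwise_distance(anchor, negative)
        batch_correct = int((d_an > d_ap + self.margin).sum().item())
        self.correct += batch_correct
        self.total += anchor.shape[0]
        return batch_correct / anchor.shape[0]

    def value(self):
        """Fraction of correct triplets over all batches seen so far."""
        return self.correct / max(self.total, 1)
```

You would keep one meter for training and one for validation, create fresh ones every epoch, and log both the per-batch value returned by `update` and the epoch-level `value()`.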

Use the trained model for image retrieval

Once the model is trained, using it for image retrieval is straightforward. You should: 

  1. Compute the embeddings of all the retrieval set images and save them to disk.
  2. Compute the embedding of the query (an attribute vector).
  3. Compute the distances between the query embedding and all the retrieval set image embeddings.
  4. Return the closest ones.
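Steps 3 and 4 boil down to a distance computation and a top-k selection. A minimal sketch, assuming the retrieval set embeddings were already computed with the trained CNN and loaded into one tensor, and that the same L2 normalization and Euclidean distance as in training are used (`retrieve` is a hypothetical helper):

```python
import torch
import torch.nn.functional as F

def retrieve(query_emb, gallery_embs, k=5):
    """Return indices and distances of the k retrieval set embeddings closest to the query.

    query_emb: (dim,) tensor; gallery_embs: (num_images, dim) tensor.
    """
    # Normalize exactly as during training (a no-op for vectors already of unit norm).
    query = F.normalize(query_emb, dim=0).unsqueeze(0)  # (1, dim)
    gallery = F.normalize(gallery_embs, dim=1)          # (num_images, dim)
    dists = torch.cdist(query, gallery).squeeze(0)      # Euclidean, as in the loss
    top = torch.topk(dists, k=min(k, dists.numel()), largest=False)  # smallest first
    return top.indices, top.values

gallery = torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.7071, 0.7071]])
idx, dist = retrieve(torch.tensor([2.0, 0.0]), gallery, k=2)
print(idx.tolist())  # [0, 3]
```

For large retrieval sets, moving `gallery_embs` to the GPU once (e.g. `gallery_embs.to("cuda")`) and passing a query on the same device keeps the whole computation in CUDA.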

Results?

Let's take a look at some results of the proposed model for different attribute queries:

Looks pretty good!

How to evaluate an image retrieval model

Content-based image retrieval systems are evaluated by checking whether the images they retrieve for a set of queries contain the desired attributes.

There are different metrics used in the literature, depending on the task and which features of the system you want to evaluate. This can be confusing, so let me explain the two most commonly used ones:

  • Precision@K (P@K): It only takes into account the top K (usually 1, 5, 10 or 100) retrieved images. It is used when you are only interested in correctly retrieving a limited number of images (e.g. in recommendation systems), or when the retrieval set is too large to evaluate the whole ranking. A retrieved image is considered correct if it contains the query attributes.

$$P@K = \frac{\text{correct results in top } K}{K}$$

When reporting the average P@K for a set of queries, this metric is sometimes written as AP@K.

  • Mean Average Precision (mAP): It takes into account the ranking of the whole retrieval set. For each query, it computes the precision at the rank of each of the N retrieval set images that have the query attributes, and averages those N values (the Average Precision, AP). mAP is then the mean of AP over all queries. This metric is more expensive to compute and less intuitive, but it evaluates a retrieval system thoroughly.
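Both metrics are easy to compute once each query's ranked results are marked as relevant (1) or not (0). A minimal sketch, assuming that for AP the relevance list covers the full ranking of the retrieval set, so that every relevant image appears in it (the function names are hypothetical):

```python
def precision_at_k(relevant, k):
    """relevant: 0/1 flags for the ranked results (1 = image has the query attributes)."""
    return sum(relevant[:k]) / k

def average_precision(relevant):
    """Mean of the precision values at the ranks where a relevant image appears."""
    hits, precisions = 0, []
    for k, rel in enumerate(relevant, start=1):
        if rel:
            hits += 1
            precisions.append(hits / k)
    return sum(precisions) / len(precisions) if precisions else 0.0

def mean_average_precision(rankings):
    """mAP: average of AP over a set of queries, each given as its ranked relevance flags."""
    return sum(average_precision(r) for r in rankings) / len(rankings)

ranked = [1, 0, 1, 1, 0]  # relevance of the ranked results for one query
print(precision_at_k(ranked, 5))  # 0.6
print(average_precision(ranked))  # (1/1 + 2/3 + 3/4) / 3, about 0.806
```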

What can go wrong: things to avoid and best practices

As always, there are a ton of things that can go wrong, so let me share my lessons learned:

  • Setting the margin: If at the beginning of training the loss decreases but the number of correct triplets does not increase, you may have set the margin too large. A good practice is to set the margin so that, when you start training, some of the triplets are already correct by chance.
  • Monitor the correct triplets and use hard negatives: If most of the triplets in your batches are already correct, modify your triplet mining strategy to use harder negatives, because your model is not learning anything from them. However, be careful to preserve the statistics of the dataset.
    Maybe you just want to use a given % of hard triplets in your batches: this is highly dependent on your problem, but if your triplets are not realistic, you might end up training your network to distinguish between samples that you won't find in any realistic scenario.
  • Computing distances during retrieval: Be sure to use the same distance function as the one used in the loss. For example, if the loss uses Euclidean distance, you should also use it during retrieval.
  • Retrieval efficiency: If you load the retrieval set images embeddings into your GPU, and compute the distances with the query in CUDA, the retrieval will be much faster for large datasets.

Final thoughts

  • Content-based image retrieval is the task of finding images by their content.
  • Ranking Losses allow a neural network to learn distances between samples by comparing their representation in a vectorial space.
  • Siamese Networks are the most common architecture for training with ranking losses. In that architecture, different samples are forwarded through the same layers at each iteration.
  • The negative selection strategy is important for efficient training and better performance.
  • Monitoring the number of correct triplets per batch is helpful to ensure the model is being trained properly.

