Classification is one of the most widely applied areas in Machine Learning. As Data Scientists, we have all worked on an ML classification model. Do you remember how many classes the largest classification problem you solved had? At most 100 or 200, perhaps? Will the same model architecture work when the number of classes exceeds 10,000?
Several real-world applications, including product recognition, face matching, and visual product search, are formulated as Multiclass Classification problems. Multiclass Classification is a class of problems in which each data point is assigned to exactly one class from a given list.
Traditional Machine Learning and Deep Learning methods are used to solve Multiclass Classification problems, but the model’s complexity grows with the number of classes. In Deep Learning in particular, the network grows because the output layer needs one unit per class. At some point, it therefore becomes infeasible to solve a Multiclass Classification problem with thousands of classes this way. In this article, we will talk about feasible techniques for dealing with such large-scale ML classification problems.
In this article, you will learn:
- 1 What are some examples of large-scale ML classification models?
- 2 Lesson 1: Mitigating data sparsity problems within ML classification algorithms
- 3 Lesson 2: Deep metric learning for classification
- 4 Lesson 3: Selecting an appropriate loss function to train large-scale classification problems
- 5 Lesson 4: Measuring the performance of a large-scale classification model
- 6 Lesson 5: Understanding the common challenges you will face with classification models in production
- 7 Lesson 6: Best practices to follow while building a pipeline for classification models
What are some examples of large-scale ML classification models?
Think of a visual product search application where, given an image of a product, the system needs to fetch the most relevant product from the catalogue. The catalogue might contain close to a million unique products. To solve this as a classification problem, one needs a model whose number of classes equals the number of unique products, which could be on the order of millions. It’s infeasible to solve this using traditional softmax-based approaches, so these kinds of applications call for a different kind of solution.
It is important to keep a few things in mind while dealing with such problems. Let’s take a look at some of them.
Lesson 1: Mitigating data sparsity problems within ML classification algorithms
What are the most popular algorithms used to solve a multi-class classification problem?
- 1 KNN
- 2 Decision Tree
- 3 Random Forest
- 4 Naive Bayes
- 5 Deep Learning using Cross Entropy Loss
To some extent, Logistic Regression and SVM can also be leveraged to solve a multi-class classification problem by fitting multiple binary classifiers using a one-vs-all or one-vs-one strategy.
The primary issue when solving a large-scale classification problem with any of the above-mentioned algorithms is that, for a given class, the number of collected samples can be very limited, so the model might not have enough data to learn the pattern for each class. In most scenarios, it’s difficult to collect a sufficient number of samples for every class, and most of the time there is a long tail of classes with very few data points.
Traditional softmax-based classification methods might not work well here, as they suffer from data sparsity within such classes. Hence, these kinds of problems need to be solved differently. Instead of optimising a cross-entropy loss over all classes, one can project input features into a high-dimensional vector space (an embedding space) and then perform a k-nearest neighbour (kNN) search in that space. This method is called Metric Learning. Problems with a very large number of classes are also studied under the name Extreme Classification. Zero-shot or few-shot learning could be another way out, as such methods allow a model to recognise classes for which it has observed no samples or only a few.
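To make the idea concrete, here is a minimal NumPy sketch of classification by nearest-neighbour search in an embedding space. It assumes a trained network has already produced the embeddings; the toy vectors, the hypothetical `knn_classify` helper, and the choice of cosine similarity with majority voting are illustrative assumptions, not a prescribed implementation.

```python
import numpy as np
from collections import Counter

def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row to unit length so dot products equal cosine similarity."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def knn_classify(query_emb, ref_embs, ref_labels, k=3):
    """Predict a label for one query embedding by majority vote over its k nearest references."""
    q = l2_normalize(query_emb[None, :])[0]
    refs = l2_normalize(ref_embs)
    sims = refs @ q                       # cosine similarity to every reference embedding
    top_k = np.argsort(-sims)[:k]         # indices of the k most similar references
    votes = Counter(ref_labels[i] for i in top_k)
    return votes.most_common(1)[0][0]

# Hypothetical reference embeddings: two stored samples per class.
ref_embs = np.array([
    [1.0, 0.1, 0.0], [0.9, 0.0, 0.1],    # class "apple"
    [0.0, 1.0, 0.1], [0.1, 0.9, 0.0],    # class "mango"
    [0.0, 0.1, 1.0], [0.1, 0.0, 0.9],    # class "pineapple"
])
ref_labels = ["apple", "apple", "mango", "mango", "pineapple", "pineapple"]

query = np.array([0.05, 0.95, 0.05])     # embedding of an unseen image
print(knn_classify(query, ref_embs, ref_labels, k=3))  # → mango
```

One consequence of this design is that supporting a new class only requires adding its reference embeddings; no output layer has to be resized.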
Lesson 2: Deep metric learning for classification
Metric Learning is a field of research that aims to construct task-specific distance metrics from supervised or weakly supervised data. Classical Metric Learning methods generally learn a linear projection of the data, whereas Deep Metric Learning leverages Deep Learning architectures to capture complex non-linear transformations.
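The link between a linear projection and a distance metric fits in a few lines: if a linear map L projects the inputs, the Euclidean distance in the projected space equals a Mahalanobis distance with matrix M = LᵀL. The sketch below uses a random L purely for illustration; in classical Metric Learning, L (or M) would be learned from labelled data, and in Deep Metric Learning the linear map is replaced by a neural network.

```python
import numpy as np

rng = np.random.default_rng(0)
L = rng.normal(size=(2, 4))          # illustrative (not learned) projection from 4-D inputs to 2-D
x, y = rng.normal(size=4), rng.normal(size=4)

# Distance measured after projecting both points with L.
d_projected = np.linalg.norm(L @ x - L @ y)

# The same distance written as a Mahalanobis distance with M = L^T L.
M = L.T @ L
diff = x - y
d_mahalanobis = np.sqrt(diff @ M @ diff)

print(np.isclose(d_projected, d_mahalanobis))   # True
```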
Metric Learning problems fall into two main categories. In the supervised setting, each data point has a class label as its target (as in a standard classification problem). In the weakly supervised setting, supervision is given at the tuple level (pairs, triplets, or quadruplets). In the weakly supervised case, the goal is to learn a distance metric that minimizes the distance between positive pairs and maximizes the distance between negative pairs. This approach can also be leveraged to solve non-classification problems.
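Class labels from the supervised setting can be turned into tuple-level supervision. The hypothetical `make_pairs` helper below, written only for illustration, enumerates every pair of samples and marks it as positive (same label) or negative (different labels). In practice you would sample from these pairs rather than enumerate them all, since their number grows quadratically with the dataset size.

```python
from itertools import combinations

def make_pairs(labels):
    """Return (i, j, is_positive) for every unordered pair of sample indices."""
    pairs = []
    for i, j in combinations(range(len(labels)), 2):
        pairs.append((i, j, labels[i] == labels[j]))
    return pairs

labels = ["apple", "apple", "mango", "pineapple"]
pairs = make_pairs(labels)
positives = [(i, j) for i, j, pos in pairs if pos]
negatives = [(i, j) for i, j, pos in pairs if not pos]
print(positives)        # [(0, 1)]
print(len(negatives))   # 5
```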
As this method works on distance metrics, the success of these networks depends on how well they capture the similarity relationships among samples. Beyond that, Data Preparation, the Data Sampling Strategy, the selection of an appropriate Distance Metric, the selection of an appropriate Loss function, and the structure of the network all determine the performance of these models.
Lesson 3: Selecting an appropriate loss function to train large-scale classification problems
In this section, you will learn about a set of loss functions that can be used to train a large-scale classification model. Selecting the right loss function plays a pivotal role in the success of the algorithm.
Contrastive Loss considers pairs of similar (called positive) and dissimilar (called negative) samples and learns to tell them apart by contrasting the pairs. The objective is to learn a distance under which positive pairs are closer together than negative pairs.
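The loss for a pair of samples is commonly defined as:

$$L(W, Y, X_1, X_2) = (1 - Y)\,\frac{1}{2}\,D_W^2 + Y\,\frac{1}{2}\,\big(\max(0,\ m - D_W)\big)^2$$

In this equation: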
- Y is 0 when input samples are similar pairs, and Y is 1 when input samples are dissimilar pairs.
- m is the margin.
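- D_W is the distance between the embeddings of the two input samples computed by the network with parameters W (typically the Euclidean distance).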
Here’s an example of how you can define the contrastive loss function:
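A minimal NumPy sketch of the contrastive loss, following the convention that Y = 0 marks a similar pair and Y = 1 a dissimilar pair. It assumes Euclidean distance for D_W, uses the commonly used form with a factor of 1/2 on each term, and averages over a batch of precomputed embedding pairs. In a real training loop you would write the same computation in your deep learning framework so that gradients flow back into the network.

```python
import numpy as np

def contrastive_loss(emb1, emb2, y, margin=1.0):
    """Mean contrastive loss over a batch of embedding pairs.

    emb1, emb2: arrays of shape (batch, dim); y: array of shape (batch,)
    with 0 for similar pairs and 1 for dissimilar pairs.
    """
    d = np.linalg.norm(emb1 - emb2, axis=1)                        # D_W: Euclidean distance per pair
    similar_term = (1 - y) * 0.5 * d ** 2                          # pulls similar pairs together
    dissimilar_term = y * 0.5 * np.maximum(0.0, margin - d) ** 2   # pushes dissimilar pairs past the margin
    return float(np.mean(similar_term + dissimilar_term))

emb1 = np.array([[0.0, 0.0], [0.0, 0.0]])
emb2 = np.array([[0.6, 0.8], [0.3, 0.4]])
y = np.array([0, 1])   # first pair similar, second pair dissimilar
print(contrastive_loss(emb1, emb2, y, margin=1.0))
```

Here the similar pair at distance 1.0 contributes 0.5 and the dissimilar pair at distance 0.5 contributes 0.125, so the mean is about 0.3125. A dissimilar pair that is already farther apart than the margin contributes nothing.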
The Triplet Loss works on a set of triplets, where each triplet has an anchor, a positive, and a negative sample. It follows the same principle as Contrastive Loss: it minimizes the distance between the anchor and the positive and maximizes the distance between the anchor and the negative.
The Loss function for Triplet Loss is as follows:
$$L(a, p, n) = \max\big(0,\ D(a, p) - D(a, n) + \text{margin}\big)$$
where D(x, y) is the distance between the learned vector representations of x and y; the L2 distance is a common choice. The objective of this function is to make the distance between the anchor and the positive smaller than the distance between the anchor and the negative by at least the margin.
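A minimal NumPy sketch of the triplet loss formula, using the L2 distance; the margin of 0.2 and the toy embeddings are illustrative assumptions. PyTorch users can also reach for the built-in `torch.nn.TripletMarginLoss`.

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=0.2):
    """Mean of max(0, D(a, p) - D(a, n) + margin) over a batch, with D = L2 distance."""
    d_ap = np.linalg.norm(anchor - positive, axis=1)
    d_an = np.linalg.norm(anchor - negative, axis=1)
    return float(np.mean(np.maximum(0.0, d_ap - d_an + margin)))

anchor   = np.array([[0.0, 0.0], [0.0, 0.0]])
positive = np.array([[0.1, 0.0], [0.5, 0.0]])
negative = np.array([[1.0, 0.0], [0.6, 0.0]])
print(triplet_loss(anchor, positive, negative, margin=0.2))
```

The first triplet already satisfies the margin and contributes zero loss. The second, whose negative is nearly as close to the anchor as its positive, contributes 0.1, so the batch mean is about 0.05.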
Vanilla softmax is not a viable option
One needs to look beyond vanilla softmax when the number of classes reaches hundreds of thousands or a million. Because the softmax normalisation sums over every class, computing the loss becomes expensive, and the output layer’s cost at inference also grows linearly with the number of classes. So, instead of a vanilla softmax, approximations such as hierarchical softmax or noise contrastive estimation (NCE) are used.
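To see why hierarchical softmax is cheaper, consider a two-level variant: classes are grouped into clusters, and the probability of a class is the probability of its cluster times the probability of the class within that cluster. Scoring one class then needs only the cluster logits plus the logits of one cluster, roughly O(√C) work for C classes split into √C clusters of √C classes, instead of O(C). The sketch uses random weights and a fixed, contiguous class-to-cluster assignment purely for illustration; real implementations learn the weights and build the hierarchy more carefully, for example from class frequencies or a class taxonomy.

```python
import numpy as np

def softmax(z):
    z = z - z.max()                       # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

rng = np.random.default_rng(0)
dim, n_clusters, per_cluster = 8, 4, 5    # 4 clusters x 5 classes = 20 classes in total
W_cluster = rng.normal(size=(n_clusters, dim))               # scores which cluster a sample falls into
W_class = rng.normal(size=(n_clusters, per_cluster, dim))    # scores classes within each cluster

def class_probability(h, class_id):
    """P(class | h) = P(cluster | h) * P(class | cluster, h), touching only one cluster's weights."""
    cluster, within = divmod(class_id, per_cluster)
    p_cluster = softmax(W_cluster @ h)[cluster]
    p_within = softmax(W_class[cluster] @ h)[within]
    return p_cluster * p_within

h = rng.normal(size=dim)                  # hidden representation produced by the network
probs = [class_probability(h, c) for c in range(n_clusters * per_cluster)]
print(round(sum(probs), 6))               # 1.0: still a valid distribution over all 20 classes
```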
Sampling hard negatives helps the model learn more precise features
As mentioned earlier, the Data Sampling Strategy is key to success when building these models. If we are using Triplet Loss, the dataset is a set of triplets. Each data sample is a triplet of an anchor, a positive, and a negative, where the anchor and the positive form a similar pair and the anchor and the negative form a dissimilar pair.
During learning, the model learns a projection where the positive is placed closer to the anchor than the negative. If the negatives are easy, the model can satisfy the margin without learning much, and such triplets contribute little or no loss. So, it’s important to also sample hard negatives that push the model to learn more specific features.
As an example, in a product recognition system, it’s relatively easy for the model to learn from a pair of apples as a positive pair and an apple and a pineapple as a negative pair. But the model will learn more precise features if we use a mango instead of a pineapple as the negative, as a mango’s shape is much more similar to an apple’s. So, during data preparation, we need a way to identify easy negatives, soft negatives, and hard negatives and, depending on the requirement, choose samples from each of these categories.
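One common way to find hard negatives, sketched below in NumPy, is to mine them from a batch of embeddings: for each sample, pick the closest sample that has a different label. The hypothetical `hardest_negatives` helper and the toy 2-D embeddings are illustrative assumptions. Practical systems often mix in softer (semi-hard) negatives as well, because always choosing the very hardest negative can make training unstable.

```python
import numpy as np

def hardest_negatives(embs, labels):
    """For each sample, return the index of the closest sample that has a different label."""
    labels = np.asarray(labels)
    dists = np.linalg.norm(embs[:, None, :] - embs[None, :, :], axis=-1)  # pairwise L2 distances
    same_label = labels[:, None] == labels[None, :]
    dists = np.where(same_label, np.inf, dists)   # exclude positives (and the sample itself)
    return dists.argmin(axis=1)

# Toy embeddings: two apples, a mango close to the apples, and a distant pineapple.
embs = np.array([[0.0, 0.0], [0.1, 0.0], [0.5, 0.0], [5.0, 5.0]])
labels = ["apple", "apple", "mango", "pineapple"]
print(hardest_negatives(embs, labels))   # [2 2 1 2]
```

As in the text, the mango (index 2), not the pineapple, is selected as the hard negative for both apples.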
Lesson 4: Measuring the performance of a large-scale classification model
To evaluate a classification model, one generally uses accuracy, precision, recall, and F1-score, where the metrics are calculated per data sample and then aggregated into overall metrics. All of these are 0-1 metrics: a prediction counts as correct only if the top predicted class is the relevant one.
For a product recognition system, the objective is to identify the SKU (Stock Keeping Unit) of an item given a product image as input. As an example, say there are three flavours (raspberry, chocolate, and pistachio) of Cornetto ice cream. Predicting a “raspberry Cornetto ice cream” as a “chocolate Cornetto ice cream” is a less costly misclassification than identifying it as a “rose flower vase”, yet a 0-1 metric penalises both mistakes equally.
Even if the top prediction is not the most relevant, we are interested in whether relevant predictions appear in the top k. That’s why, along with precision and recall, precision@k and recall@k are also important metrics for evaluating a large-scale classification system.
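A minimal sketch of precision@k and recall@k for a single query, using the Cornetto example; the SKU names are hypothetical. With one relevant class per image, recall@k is simply 1 if the correct class appears in the top k and 0 otherwise. Over a test set you would average these values across queries.

```python
def precision_at_k(ranked, relevant, k):
    """Fraction of the top-k predictions that are relevant."""
    hits = sum(1 for item in ranked[:k] if item in relevant)
    return hits / k

def recall_at_k(ranked, relevant, k):
    """Fraction of the relevant items that appear in the top-k predictions."""
    hits = sum(1 for item in ranked[:k] if item in relevant)
    return hits / len(relevant)

# Hypothetical ranked SKU predictions for one raspberry Cornetto image.
ranked = ["chocolate_cornetto", "raspberry_cornetto", "pistachio_cornetto", "rose_flower_vase"]
relevant = {"raspberry_cornetto"}

print(precision_at_k(ranked, relevant, k=1))  # 0.0: the top-1 prediction is wrong
print(recall_at_k(ranked, relevant, k=3))     # 1.0: the correct SKU is in the top 3
print(precision_at_k(ranked, relevant, k=3))  # about 0.333
```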
We can also leverage visualisation to evaluate the model. Similar classes often form a group or cluster in the embedding space, so we can project the high-dimensional vectors into two or three dimensions and evaluate the integrity of the embeddings at the cluster level.
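A minimal sketch of the projection step, using PCA computed with NumPy’s SVD; t-SNE and UMAP are common non-linear alternatives. The synthetic clustered embeddings stand in for the model’s output, and plotting the returned 2-D points (for example with matplotlib, coloured by class) is left out to keep the example dependency-free.

```python
import numpy as np

def pca_2d(embs):
    """Project embeddings onto their top two principal components."""
    centered = embs - embs.mean(axis=0)
    # Rows of vt are the principal directions, ordered by explained variance.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T

rng = np.random.default_rng(0)
centers = rng.normal(scale=5.0, size=(3, 64))          # three synthetic class clusters in 64-D
labels = np.repeat(np.arange(3), 50)
embs = centers[labels] + rng.normal(size=(150, 64))    # 50 noisy samples around each center

points = pca_2d(embs)
print(points.shape)   # (150, 2)
for c in range(3):
    print(c, points[labels == c].mean(axis=0).round(2))  # per-class centroid in the 2-D view
```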
Lesson 5: Understanding the common challenges you will face with classification models in production
So far, we have talked about ways to build the model by choosing an appropriate sampling strategy, algorithm, and loss function and by choosing the right evaluation metrics. In this section, let’s discuss the issues that you might face with such models in production.
In a traditional classification model, we pass the input feature vector through the hypothesis function and get a prediction in a single forward pass, in O(1) time with respect to the size of any stored reference set. Here, however, we approach the problem differently: we learn a vector representation of the input and then perform a k-nearest neighbour search in the high-dimensional vector space.
The k-nearest neighbour search can be expensive in terms of computation time, as an exact search computes the distance between the query and every representative sample (we need one or more representative samples for each class). So, the running time grows linearly with the number of representatives.
As computing the distance to all representative samples is expensive, we use approximate nearest neighbour (ANN) search algorithms, which trade a small loss in accuracy for speed; graph-based methods such as HNSW find the most similar representative in roughly O(log N) time.
Approximate Nearest Neighbour (ANN) search algorithms can broadly be categorised into three buckets:
- 1 Hash-based
- 2 Tree-based
- 3 Graph-based
In the following section, we will talk briefly about some of the popular ANN algorithms.
Hierarchical Navigable Small World (HNSW) graphs are one of the top choices for vector similarity search. HNSW is a robust graph-based algorithm for approximate nearest neighbour search. It builds a proximity graph in which vertices that are close to each other are linked, combining two fundamental ideas: the probability skip list and the Navigable Small World (NSW) graph.
- From the probability skip list, HNSW borrows a layered, hierarchical structure: the highest layers have longer edges that enable fast search, whereas the lower layers have shorter edges that enable accurate search.
- For the Navigable Small World (NSW) graph, the idea is to build a proximity graph with both short-range and long-range links, so that search time is reduced to polylogarithmic complexity. To search an NSW graph, we start at a pre-defined entry vertex and greedily traverse the graph towards the vertices closest to the query vector.
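The greedy routing that NSW search relies on can be sketched in plain Python: starting from an entry vertex, repeatedly move to the neighbour closest to the query until no neighbour is closer. The hand-built graph and the hypothetical `greedy_search` helper are illustrative assumptions. Real HNSW implementations keep a candidate list of size ef rather than a single current vertex, and run the search layer by layer, using each layer’s result as the entry point for the next.

```python
import math

def greedy_search(graph, vectors, query, entry):
    """Walk the proximity graph, always moving to the neighbour nearest the query.

    Stops at a local minimum: a vertex none of whose neighbours is closer to the query.
    """
    current = entry
    current_dist = math.dist(vectors[current], query)
    while True:
        best, best_dist = current, current_dist
        for neighbour in graph[current]:
            d = math.dist(vectors[neighbour], query)
            if d < best_dist:
                best, best_dist = neighbour, d
        if best == current:          # no neighbour improves on the current vertex
            return current, current_dist
        current, current_dist = best, best_dist

# Tiny hand-built graph: short-range links between neighbours plus one long-range link (0 <-> 4).
vectors = {0: (0.0, 0.0), 1: (1.0, 0.0), 2: (2.0, 0.0), 3: (3.0, 0.0), 4: (4.0, 0.0), 5: (5.0, 0.0)}
graph = {0: [1, 4], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5, 0], 5: [4]}

print(greedy_search(graph, vectors, query=(4.9, 0.2), entry=0))
```

Starting from vertex 0, the long-range link lets the search jump straight to vertex 4 and then step to vertex 5, instead of walking through every intermediate vertex.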
HNSW is one of the most straightforward approaches to building a graph for nearest neighbour search, but it’s not the best indexing scheme in terms of memory utilisation, since it stores the full vectors along with the graph links. Techniques like Product Quantization (PQ) improve memory utilisation by compressing the vectors.
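To illustrate how PQ saves memory, here is a compact NumPy sketch: each vector is split into m sub-vectors, a small codebook is learned for each sub-space with a few iterations of plain k-means, and each sub-vector is stored as the one-byte index of its nearest centroid. The parameter values and the simple k-means loop are illustrative assumptions; FAISS provides production implementations (for example, `faiss.IndexPQ`).

```python
import numpy as np

def kmeans(x, k, iters=10, seed=0):
    """Plain k-means returning k centroids (illustrative, not optimised)."""
    rng = np.random.default_rng(seed)
    centroids = x[rng.choice(len(x), size=k, replace=False)]
    for _ in range(iters):
        assign = np.argmin(((x[:, None, :] - centroids[None]) ** 2).sum(-1), axis=1)
        for j in range(k):
            if np.any(assign == j):               # keep the old centroid if a cluster empties
                centroids[j] = x[assign == j].mean(axis=0)
    return centroids

def train_pq(data, m, k):
    """Learn one codebook of k centroids for each of the m sub-spaces."""
    subs = np.split(data, m, axis=1)
    return [kmeans(s, k, seed=i) for i, s in enumerate(subs)]

def encode(data, codebooks):
    """Replace each sub-vector with the index of its nearest centroid (one uint8 per sub-space)."""
    subs = np.split(data, len(codebooks), axis=1)
    codes = [np.argmin(((s[:, None, :] - cb[None]) ** 2).sum(-1), axis=1)
             for s, cb in zip(subs, codebooks)]
    return np.stack(codes, axis=1).astype(np.uint8)

def decode(codes, codebooks):
    """Approximately reconstruct vectors by concatenating the selected centroids."""
    return np.hstack([cb[codes[:, i]] for i, cb in enumerate(codebooks)])

rng = np.random.default_rng(0)
data = rng.normal(size=(2000, 32)).astype(np.float32)   # 32 float32 values = 128 bytes per vector
codebooks = train_pq(data, m=8, k=16)                    # 8 sub-spaces, 16 centroids each
codes = encode(data, codebooks)                          # 8 bytes per vector instead of 128

error = np.mean(np.sum((data - decode(codes, codebooks)) ** 2, axis=1))
print(codes.shape, codes.dtype, round(float(error), 2))
```

The codes take 16 times less memory than the raw vectors (plus the small shared codebooks), at the cost of an approximate reconstruction.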
Facebook AI Similarity Search (FAISS) is a library for efficient similarity search that implements several algorithms for vector similarity search, including (but not limited to) Product Quantisation, Hierarchical Navigable Small World (HNSW), Additive Quantisation, and search with inverted-file indexing. Its core is implemented in C++ (with Python bindings), and it is optimised for memory utilisation and speed.
Following is a code snippet for FAISS with HNSW algorithm.
Scalable Nearest Neighbors (ScaNN), from Google Research, is another library for vector similarity search in large-scale problems. It includes search-space pruning and quantization for Maximum Inner Product Search, and its authors claim the best performance in terms of search speed.
Following is a code snippet to build and search using ScaNN.
So far, we have seen multiple techniques to improve turnaround time during inference, but all of them are limited by the capacity of physical memory. In reality, an inference server has limited RAM, so we may only be able to keep part of the index in memory. As a result, the running time may increase with the number of classes, which is not the case for traditional softmax-based classification approaches. Looking at the end-to-end flow of the Metric Learning approach, we find multiple areas that can be optimised to improve the runtime.
- Optimisation in Approximate Nearest Neighbour Search
The runtime of the nearest neighbour search depends primarily on the size of the search space, so it pays to formulate the problem in a way that reduces it. As we are working with a very large number of classes, sets of similar classes might form clusters, and we can use additional information to build a taxonomy of classes. During inference, we first identify the cluster the sample belongs to and then run a nearest neighbour search only within that cluster. ANN search algorithms like HNSW apply a similar idea directly in the vector space.
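The cluster-first idea can be sketched with NumPy: class representatives are grouped by a hypothetical coarse category, a query is first routed to the nearest category centroid, and exact nearest-neighbour search then runs only over that category’s representatives. This is essentially the idea behind inverted-file (IVF) indexes. The synthetic taxonomy and the single-probe routing are simplifying assumptions; probing only one cluster can miss the true neighbour when a query falls near a cluster boundary.

```python
import numpy as np

rng = np.random.default_rng(0)
n_categories, per_category, dim = 10, 1_000, 32

# Hypothetical taxonomy: every class representative belongs to one coarse category.
category_centers = rng.normal(scale=10.0, size=(n_categories, dim))
category_of = np.repeat(np.arange(n_categories), per_category)
reps = category_centers[category_of] + rng.normal(size=(n_categories * per_category, dim))

# Precompute one centroid per category and the member indices of each category.
centroids = np.stack([reps[category_of == c].mean(axis=0) for c in range(n_categories)])
members = [np.flatnonzero(category_of == c) for c in range(n_categories)]

def two_stage_search(query):
    """Route the query to its nearest category, then search only that category's members."""
    c = np.argmin(np.linalg.norm(centroids - query, axis=1))   # stage 1: 10 distance computations
    candidates = members[c]
    d = np.linalg.norm(reps[candidates] - query, axis=1)       # stage 2: 1,000 instead of 10,000
    return candidates[np.argmin(d)]

query = reps[4_321] + 0.01 * rng.normal(size=dim)
print(two_stage_search(query))   # 4321
```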
- Optimisation of the model architecture
Optimising the runtime of the model architecture depends on the nature of the model. Some optimisation techniques that apply not only to large-scale classification but to Deep Learning models in general are the following:
- If it’s a PyTorch model, export it to ONNX and serve it with ONNX Runtime.
- Deploy the model on NVIDIA Triton Inference Server.
- Scale horizontally and batch multiple inputs per inference request.
Lesson 6: Best practices to follow while building a pipeline for classification models
Versioning and an automated training pipeline are essential for a model built using the Metric Learning approach. Whenever the set of classes changes, the model needs to be re-trained and the index of embeddings refreshed.
This is where an MLOps tool like neptune.ai could be super helpful. Using neptune.ai, you will be able to perform versioning of models and their artifacts. Neptune also allows you to easily manage ML metadata via its customizable UI for operations like searching and comparing. It is able to integrate easily with a variety of data science tools.
While neptune.ai helps you to seamlessly track and version model artifacts, you might also want to leverage workflow and orchestration tools like Kubeflow and Apache Airflow to build and automate the indexing pipeline.
This article discussed the difficulties one faces when building a large-scale classification model. We have seen how a large-scale classification system differs from a traditional softmax-based classifier. Approximate Nearest Neighbour search is the key to building an end-to-end system, and the choice of ANN algorithm is a driving factor in the success of the system. The popularity of Deep Learning and its ability to generate rich representations have made this topic an active field of research, and we hope to see more work in this domain from the research community.