
Continual Learning: Methods and Application

22nd February, 2024

TL;DR:

In many machine-learning projects, the model has to be retrained frequently to adapt to changing data or to be personalized.

Continual learning is a set of approaches to train machine learning models incrementally, using data samples only once as they arrive.

Methods for continual learning can be categorized as regularization-based, architectural, and memory-based, each with specific advantages and drawbacks.

Adopting continual learning is an incremental process: from carefully identifying the objective, through implementing a simple baseline solution, to selecting and tuning the continual learning method.

The key to continual learning success is identifying the objective, choosing the right tools, selecting a suitable model architecture, incrementally improving the hyperparameters, and using all available data.

At the beginning of my machine learning journey, I was convinced that creating an ML model always follows the same pattern. You start with a business problem, prepare a dataset, and finally train the model, which is evaluated and deployed. Then, you repeat this process until you are satisfied with the results.

But most real-world machine learning (ML) projects are not like that. Many problems make the whole process much more complicated, for example, an insufficient amount of training data, limited computing power, and, of course, running out of time.

What’s more – what if the data distribution changes after model deployment? What if you handle a classification problem and the number of classes increases over time?

These problems keep many ML practitioners awake at night. If you’re part of this group, continual learning is exactly what you need.

What is continual learning?

Continual learning (CL) is a research field focusing on developing practical approaches for effectively training machine learning models incrementally.

Training incrementally means that the model is trained on batches from a data stream without access to a collection of past data. Rather than training on an entire dataset at once, as in traditional machine learning, the model receives many smaller datasets sequentially.

Each smaller dataset, which might contain just one sample, is only used once. The data just appears like a stream, and we don’t know what to expect next.

Consequently, there are no fixed training, validation, and test sets in continual learning. In a classic ML training pipeline, we focus on achieving high performance on the current dataset, which we measure by evaluating the model on the validation and test sets. In CL, we also want to achieve high performance on the current batch of data, but at the same time, we must prevent the model from forgetting what it learned from past data.

Important note:

Continual learning aims to allow the model to effectively learn new concepts while ensuring it does not forget already acquired information.
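The article does not fix a metric for "not forgetting," but a common convention (assumed here) is to evaluate the model on every experience seen so far after each training step and record the results in an accuracy matrix. The minimal sketch below computes two summary numbers from such a matrix: the average accuracy after the last experience, and the average forgetting, i.e., how far each old experience fell from its best earlier accuracy. The function names and the toy matrix are illustrative only, not measured results.

```python
from typing import List


def average_accuracy(acc: List[List[float]]) -> float:
    """Mean accuracy over all experiences, measured after training on the last one.

    acc[i][j] is the accuracy on experience j after training on experience i;
    only entries with j <= i are filled in.
    """
    last = len(acc) - 1
    return sum(acc[last][: last + 1]) / (last + 1)


def average_forgetting(acc: List[List[float]]) -> float:
    """Mean drop from the best earlier accuracy on each old experience to its final accuracy."""
    last = len(acc) - 1
    if last == 0:
        return 0.0  # nothing can be forgotten after a single experience
    drops = []
    for j in range(last):  # every experience except the most recent one
        best_before = max(acc[i][j] for i in range(j, last))
        drops.append(best_before - acc[last][j])
    return sum(drops) / len(drops)


# Toy numbers for illustration only (not measured results).
acc = [
    [0.90],
    [0.70, 0.88],
    [0.60, 0.75, 0.91],
]
print(round(average_accuracy(acc), 3))    # 0.753
print(round(average_forgetting(acc), 3))  # 0.215
```

A model can score well on the newest experience (0.91 here) while still forgetting a lot, which is why both numbers are usually reported together.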

Plenty of CL techniques exist that are useful in various machine-learning scenarios. This article focuses on continual learning for deep learning models because they are widely adopted and adapt well to a broad range of problems.

Use cases and applications

Before we dive into specific approaches and their implementations, let’s take a step back and ask: When exactly do we need continual learning?

Using CL techniques may be the solution when:

  • A model needs to adapt to new data quickly: Some ML models require frequent retraining to be useful. Consider a fraud detection model for bank transfers. If you achieve 99% accuracy on the initial training dataset, there is no guarantee that this accuracy will be maintained after a day, week, or month. New fraud methods are invented daily, so the model needs to be updated (automatically) as quickly as possible to prevent malicious transactions. With CL, you can ensure that the model learns from the latest data and adapts to it as effectively and quickly as possible.
  • A model needs to be personalized: Let’s say you maintain a document classification pipeline, and each of your many users has slightly different data to be processed, for example, documents with different vocabulary and writing styles. With continual learning, you can use each uploaded document to automatically retrain the user’s model, gradually adjusting it to the data that user uploads to the system.
Figure 1. Model personalization via continual learning in a document classification process. The user uploads documents to the system, and the model is retrained after each batch, creating a new personalized model.

In general, continual learning is worth considering when your model needs to adapt to data from a stream quickly. This is often the case when deploying a model in dynamically changing environments.

Continual learning scenarios

Depending on the characteristics of the data stream, continual learning problems can be divided into three scenarios, each with its own common solutions.

Class incremental continual learning

Class Incremental (CI) continual learning is a scenario in which the number of classes in a classification task is not fixed but can increase over time.

For example, say you already have a cat classifier that can distinguish between five different breeds. But now, you need to handle a new breed (in other words, add a sixth class).

Such a scenario is common in real-world ML applications yet is among the most difficult to handle.
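If you want to simulate this scenario with an existing labeled dataset, you can group samples by class and release the groups as separate experiences. The library-free sketch below mirrors the cat example: five breeds arrive in the first experience and a sixth in the second. `split_by_class` and the image IDs are hypothetical; libraries such as Continuum and Avalanche provide ready-made scenario builders for real datasets.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Sample = Tuple[str, int]  # (input identifier, class label)


def split_by_class(samples: Sequence[Sample],
                   class_groups: Sequence[Sequence[int]]) -> List[List[Sample]]:
    """Build a class-incremental stream: experience k contains only the classes in class_groups[k]."""
    by_class: Dict[int, List[Sample]] = defaultdict(list)
    for x, y in samples:
        by_class[y].append((x, y))
    return [[s for c in group for s in by_class[c]] for group in class_groups]


# Hypothetical data: 12 images of 6 cat breeds (labels 0-5), two images per breed.
samples = [(f"img_{i}", i % 6) for i in range(12)]
stream = split_by_class(samples, class_groups=[[0, 1, 2, 3, 4], [5]])
for k, experience in enumerate(stream):
    print(k, sorted({y for _, y in experience}), len(experience))
# 0 [0, 1, 2, 3, 4] 10
# 1 [5] 2
```

The second experience contains only the new class, which is exactly what makes this scenario hard: nothing in it reminds the model of the first five breeds.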

Domain incremental continual learning

Domain Incremental (DI) continual learning comprises all cases where data distribution changes over time.

For example, if you train a machine learning model to extract data from invoices and users start uploading invoices with a different layout, the input data distribution has changed.

This phenomenon is called distribution shift. It is a problem for ML models because their accuracy tends to decrease as the input data distribution deviates from that of their training data.
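Domain-incremental benchmarks are often simulated by applying a different transformation to the same labeled data in each experience, for instance, by rotating images. The sketch below assumes 2-D points as a stand-in for images: each experience rotates the inputs by a different angle while the label set stays fixed, so only the input distribution shifts. The function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def rotate(p: Point, degrees: float) -> Point:
    """Rotate a 2-D point counterclockwise around the origin."""
    a = math.radians(degrees)
    x, y = p
    return (x * math.cos(a) - y * math.sin(a), x * math.sin(a) + y * math.cos(a))


def make_domain_stream(samples: Sequence[Tuple[Point, int]],
                       angles: Sequence[float]) -> List[List[Tuple[Point, int]]]:
    """One experience per angle: the labels never change, only the inputs are transformed."""
    return [[(rotate(x, angle), y) for x, y in samples] for angle in angles]


samples = [((1.0, 0.0), 0), ((0.0, 1.0), 1)]
stream = make_domain_stream(samples, angles=[0, 45, 90])
for angle, experience in zip([0, 45, 90], stream):
    print(angle, [(round(x, 2), round(y, 2), label) for (x, y), label in experience])
```

A model trained only on the 0-degree experience sees inputs in a different region of feature space at 90 degrees, even though the task itself has not changed.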

Task incremental continual learning

Task Incremental (TI) continual learning is classic multi-task learning but in an incremental way.

Multi-task learning is an ML technique where one model is trained to solve multiple tasks. This approach is widespread in NLP, where one model might learn to perform text classification, named entity recognition, and text summarization. Each task will have a separate output layer, but the other model parameters can be shared.

In task incremental continual learning, instead of having a separate model for each task, one model is trained to solve them all. The difficulty in the continual learning setting is that data for each task arrives at a different time, and the number of tasks might not be known beforehand, so the model’s architecture may need to expand over time. Every input example needs a task label that identifies the expected output. For instance, outputs in classification and text summarization problems are different, so the task label tells you whether the current example trains classification or summarization.
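A minimal routing sketch, assuming each example arrives with a task label: the shared backbone computes features once, and the label selects the output head or, for a task never seen before, creates one. The backbone and heads here are fixed toy functions standing in for trainable layers; `MultiHeadModel` and the task names are hypothetical.

```python
from typing import Callable, Dict, List

Vector = List[float]
Head = Callable[[Vector], str]


class MultiHeadModel:
    """Shared backbone plus one output head per task, selected by the task label.

    Heads are created the first time a task label appears, so the model's
    architecture grows as new tasks arrive in the stream.
    """

    def __init__(self, backbone: Callable[[Vector], Vector],
                 make_head: Callable[[str], Head]):
        self.backbone = backbone
        self.make_head = make_head
        self.heads: Dict[str, Head] = {}

    def __call__(self, x: Vector, task: str) -> str:
        if task not in self.heads:
            self.heads[task] = self.make_head(task)  # expand for an unseen task
        return self.heads[task](self.backbone(x))


# Hypothetical stand-ins: a fixed feature extractor and two toy heads with different label sets.
def backbone(x: Vector) -> Vector:
    return [sum(x) / len(x)]


def make_head(task: str) -> Head:
    if task == "sentiment":
        return lambda h: "positive" if h[0] > 0 else "negative"
    return lambda h: "sports" if h[0] > 0 else "politics"


model = MultiHeadModel(backbone, make_head)
print(model([1.0, 3.0], task="sentiment"))  # positive
print(model([-1.0, -3.0], task="topic"))    # politics
print(sorted(model.heads))                  # ['sentiment', 'topic']
```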

Challenges in continual learning

Unfortunately, there is no free lunch.

Training models incrementally is challenging because ML models tend to overfit current data and forget the past. This phenomenon is called “catastrophic forgetting” and remains an open research problem.

The most difficult CL scenario is class-incremental learning, as learning how to discriminate among a wider set of classes is much more demanding than adapting to shifts in data. When a new class appears, it may significantly impact the decision boundary of existing classes. For example, a new class, “labrador retriever,” will have some overlap with an existing class “dog.”

In contrast, task-incremental problems are relatively easy and better researched because they can often be solved simply by freezing part of the model parameters (which prevents forgetting) and training only the output layers.

However, regardless of the scenario, training an ML model incrementally is always much more complex than classic offline training, where all the data is available upfront and you can run hyperparameter optimization. Moreover, each model architecture reacts to incremental training in its own way. It is not easy to find the best (or just a satisfying) solution right away, even for experienced machine learning engineers. Therefore, a good practice is to run and carefully track various experiments. This lets you verify ideas in practice, not just in theory.

Tracking continual learning experiments with neptune.ai

To give you an idea of what experimenting with CL methods looks like, I’ve prepared examples in a GitHub repo.

I used PyTorch and Avalanche to create a simple experimental setup that compares various continual learning methods on an image classification problem in a class-incremental scenario. In these experiments, memory-based methods (Replay, GEM, AGEM) outperformed all other techniques in terms of the final model’s accuracy.


The code is set up to track all experiment metadata in Neptune. If you want, you can see the project and the results of my experiment here in my Neptune account.


Neptune.ai provides a convenient way to track and compare machine learning experiments. Check out my example project or visit the product website to learn more.

Continual learning methods

During the 2010s, continual learning methods advanced rapidly. Researchers proposed many new techniques to prevent catastrophic forgetting and make incremental model training more effective.

These techniques can be divided into architectural, regularization-based, and memory-based approaches.

Architectural approaches

One way to adapt an ML model to new data is to modify its architecture. Methods that focus on this approach are called architectural or parameter-based.

If you serve clients from various countries and need to train a personalized text classifier for each of them (task-incremental scenario), you can use a multilingual LLM (Large Language Model) as the core model and select a different classification layer based on the input text’s language. While the core model’s parameters remain frozen, the classification layers are fine-tuned using the incoming samples.

The idea is to rebuild the model in a way that guarantees the preservation of already acquired knowledge while still allowing it to absorb new data. The model can be rebuilt whenever necessary, for example, when a sample of a new class arrives or after each training batch.

You can implement architectural approaches, for example, by creating dedicated, specialized subnetworks like in Progressive Neural Networks or by simply having multiple model heads (last layers), which are selected based on the input data characteristics (which can be, for example, the task label in task-incremental scenarios). 
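The frozen-core idea can be sketched without any framework. Below, a fixed linear map stands in for the pretrained core model (an assumption; a real system would use a network such as the multilingual LLM above), and a logistic-regression head is the only part updated by stochastic gradient descent on incoming samples. The class and method names are hypothetical.

```python
import copy
import math
from typing import List, Sequence

Vector = Sequence[float]


class FrozenBackboneClassifier:
    """Toy binary classifier: a frozen linear 'backbone' followed by a trainable logistic head."""

    def __init__(self, backbone_weights: List[List[float]]):
        self.backbone_weights = backbone_weights      # frozen: train_step never touches it
        self.head_w = [0.0] * len(backbone_weights)   # trainable head weights
        self.head_b = 0.0                             # trainable head bias

    def features(self, x: Vector) -> List[float]:
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.backbone_weights]

    def predict_proba(self, x: Vector) -> float:
        h = self.features(x)
        z = sum(w * hi for w, hi in zip(self.head_w, h)) + self.head_b
        return 1.0 / (1.0 + math.exp(-z))

    def train_step(self, x: Vector, y: int, lr: float = 0.1) -> None:
        """One SGD step on the logistic loss; only the head's parameters change."""
        h = self.features(x)
        err = self.predict_proba(x) - y  # dLoss/dz for the logistic loss
        self.head_w = [w - lr * err * hi for w, hi in zip(self.head_w, h)]
        self.head_b -= lr * err


model = FrozenBackboneClassifier([[1.0, 0.0], [0.0, 1.0]])
before = copy.deepcopy(model.backbone_weights)
for _ in range(50):  # incoming samples from the stream
    for x, y in [((1.0, 0.0), 1), ((-1.0, 0.0), 0)]:
        model.train_step(x, y)
print(model.backbone_weights == before)       # True
print(model.predict_proba((1.0, 0.0)) > 0.5)  # True
```

In PyTorch, the same effect is usually achieved by setting `requires_grad = False` on the backbone's parameters and passing only the head's parameters to the optimizer.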

Regularization approaches

Regularization-based methods keep the model architecture fixed during incremental training. To make the model learn new data without forgetting the past, they use techniques like knowledge distillation, loss function modification, selection of parameters that should (or should not) be updated, or just a simple regularization (which explains the name).

The general idea is to keep parameter modifications as subtle as possible, which prevents the model from forgetting. Such methods are often relatively quick and easy to implement but, at the same time, less effective than architectural or memory-based methods, especially in difficult class incremental scenarios, due to their inability to learn complex relationships in feature space. Examples of regularization-based methods are Elastic Weight Consolidation and Learning Without Forgetting.

The main advantage of regularization-based methods is that, thanks to their simplicity, they can almost always be implemented. However, when architectural or memory-based approaches are feasible, regularization-based techniques are typically used as quickly delivered baselines rather than final solutions.
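Elastic Weight Consolidation adds a quadratic penalty that pulls each parameter back toward its value after the previous task, weighted by an importance estimate (the diagonal of the Fisher information, commonly estimated from squared gradients of the log-likelihood on the old task's data). The plain-Python sketch below shows only the penalty and its gradient; the parameter vectors, Fisher values, and `lam` are toy inputs, and for real models you would rely on a library implementation.

```python
from typing import List, Sequence


def ewc_penalty(params: Sequence[float], anchor: Sequence[float],
                fisher: Sequence[float], lam: float) -> float:
    """EWC penalty: (lam / 2) * sum_i F_i * (theta_i - theta_old_i) ** 2.

    anchor holds the parameters after the previous task; fisher holds the
    diagonal Fisher information, i.e., how important each parameter was.
    """
    return 0.5 * lam * sum(f * (p - a) ** 2 for p, a, f in zip(params, anchor, fisher))


def ewc_penalty_grad(params: Sequence[float], anchor: Sequence[float],
                     fisher: Sequence[float], lam: float) -> List[float]:
    """Gradient of the penalty with respect to each parameter: lam * F_i * (theta_i - theta_old_i)."""
    return [lam * f * (p - a) for p, a, f in zip(params, anchor, fisher)]


# Both parameters moved away from the old solution, but only the first one mattered for it.
params, anchor, fisher = [1.0, 2.0], [0.0, 0.0], [1.0, 0.0]
print(ewc_penalty(params, anchor, fisher, lam=2.0))       # 1.0
print(ewc_penalty_grad(params, anchor, fisher, lam=2.0))  # [2.0, 0.0]
# While training on the new task: total_loss = new_task_loss + ewc_penalty(...)
```

The zero gradient for the second parameter shows the design choice: unimportant parameters stay free to learn the new task, while important ones are held close to their old values.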

Memory-based approaches

Memory-based continual learning methods involve saving part of the input samples (and their labels in a supervised learning scenario) into a memory buffer during training. The memory can be a database, a local file system, or just an object in RAM.

The idea is to use these examples later for model training along with currently seen data to prevent catastrophic forgetting. For example, a training input batch may consist of current and randomly selected examples from memory.
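The text does not say how the memory buffer decides what to keep. The sketch below assumes reservoir sampling, a standard technique that maintains a uniform random sample of everything seen so far within a fixed capacity, and builds each training batch from the incoming examples plus a few replayed ones. `ReservoirBuffer` and `replay_batch` are hypothetical names, not an API from any library mentioned in this article.

```python
import random
from typing import Generic, List, Optional, TypeVar

T = TypeVar("T")


class ReservoirBuffer(Generic[T]):
    """Fixed-size memory holding a uniform random sample of all items seen so far."""

    def __init__(self, capacity: int, seed: Optional[int] = None):
        self.capacity = capacity
        self.items: List[T] = []
        self.seen = 0
        self.rng = random.Random(seed)

    def add(self, item: T) -> None:
        self.seen += 1
        if len(self.items) < self.capacity:
            self.items.append(item)
        else:
            # Keep the new item with probability capacity / seen, evicting a random slot.
            j = self.rng.randrange(self.seen)
            if j < self.capacity:
                self.items[j] = item

    def sample(self, k: int) -> List[T]:
        return self.rng.sample(self.items, min(k, len(self.items)))


def replay_batch(current: List[T], memory: ReservoirBuffer[T], n_replay: int) -> List[T]:
    """Mix the incoming batch with examples drawn from memory, then store the new ones."""
    batch = current + memory.sample(n_replay)  # sample first, so replayed items are truly past data
    for item in current:
        memory.add(item)
    return batch


memory: ReservoirBuffer = ReservoirBuffer(capacity=4, seed=0)
for step in range(5):
    current = [(step, i) for i in range(3)]  # (arrival step, index) stand-ins for samples
    batch = replay_batch(current, memory, n_replay=2)
    print(step, len(batch))
```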

These methods are very popular in solving various continual learning problems thanks to their effectiveness and simple implementation. It has been empirically shown that memory-based methods are the most effective in all three continual learning scenarios. But, of course, this technique requires constant access to past data, which is impossible in many cases.

For example, some information-extraction procedures in healthcare may require strict data-retention policies, like deleting documents from the system soon after the desired information is extracted and exported. In such a case, we cannot use a memory buffer.

Another example may be a robot vacuum cleaner trying to improve its route through a house. It takes pictures of the environment and uses continual learning to enhance the model responsible for the navigation. Since the photos show the inside of people’s houses, they will inevitably contain sensitive, personal information. Thus, model training must happen on the robot (on-device learning), and the images should not be stored longer than necessary. Moreover, there may simply not be enough space to store a sufficient amount of data on the device to make memory-based methods effective.

How to choose the right continual learning method for your project

Across the three groups of continual learning approaches, many techniques exist. Like with model architectures and training paradigms, a project’s success depends on selecting the right ones. But how do you choose the best approach for your problem?

The rules of thumb are:
 

  1. Always start with a simple regularization-based approach. If the accuracy is sufficient, that’s great – you have a cheap and quick solution. If not, you have a valuable baseline to compare with.
  2. If you can store even a tiny fraction of the historical data, use a memory-based technique no matter what kind of model you are training.
  3. You should try the architectural approach only if you cannot adopt memory-based techniques. Implementing it will be more complicated and time-consuming, but it’s the only feasible way to go at this stage.

You can combine methods from different groups to maximize gains. Various experiments show that mixed approaches can be beneficial in many scenarios. For example, suppose you use a memory-based method but also want to fine-tune personalized models for each user effectively. In that case, nothing prevents you from using a memory buffer together with an interchangeable output layer.

However, identifying a scenario and selecting a proper method is only half the battle. In the next section, we’ll look into implementing it in practice.

Adopting continual learning

Who needs continual learning?

For small companies, using continual learning to make models learn from the data stream is a good practice, but for big companies, it is a necessity: manually updating thousands of models simultaneously is simply not feasible.

Adopting CL in production environments is beneficial but challenging. That’s especially true when you’re starting to train a model from scratch instead of converting an existing classically trained model over to CL. Since initially you do not have access to any data samples, you don’t have a training, test, and validation set that you can use for hyperparameter tuning and model evaluation. Thus, developing an effective continual learning solution this way is often a long and iterative process.

Continual learning development stages

For this reason, a more typical approach is to start with classical training and slowly evolve the training setup towards continual learning. Chip Huyen, in her excellent book “Designing Machine Learning Systems,” distinguishes four stages of advancement:

  1. Manual, stateless retraining: There is no automation. The developer decides when model retraining is needed, and retraining always means training the model from scratch. There is no incremental training and no continual learning.
  2. Automated retraining: The model is trained from scratch every time, but training scheduling is automated (e.g., through Cron), and so is the whole pipeline (data preparation, training). This is not yet a continual learning process, but some crucial prerequisites have been set up.
  3. Automated, stateful training: The model is no longer trained from scratch but fine-tuned on only a fraction of the data on a fixed schedule, e.g., training every day on the data from the previous day. Simple regularization-based CL solutions are adopted at this stage, which can be considered the first, primitive version of continual learning.
  4. Continual learning: The model is trained using a more advanced CL method, achieving satisfying performance. Additional training is performed only when there is a clear need (e.g., data distribution changes or accuracy drops).
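The "clear need" trigger in the last stage can be as simple as comparing recent accuracy with the accuracy measured at deployment. The sketch below assumes that labeled feedback is available to compute accuracy (not always the case in production), and its threshold values are hypothetical; a distribution-shift detector could play the same role.

```python
from typing import Sequence


def needs_retraining(accuracy_history: Sequence[float], baseline: float,
                     tolerance: float = 0.05, window: int = 3) -> bool:
    """Return True when the mean of the last `window` accuracy measurements
    is more than `tolerance` below the accuracy measured at deployment time."""
    if len(accuracy_history) < window:
        return False  # not enough evidence yet
    recent = accuracy_history[-window:]
    return sum(recent) / window < baseline - tolerance


print(needs_retraining([0.91, 0.92, 0.90], baseline=0.92))  # False
print(needs_retraining([0.88, 0.85, 0.84], baseline=0.92))  # True
```

Averaging over a window rather than reacting to a single measurement keeps one noisy evaluation from triggering an unnecessary training run.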

As you can see, there is a significant leap between manual, stateless retraining and CL.

Most ML systems in production today do not fully use continual learning but remain in the lower stages. Getting all the way to the last stage requires gradually improving existing processes. But how can we do that effectively? What common mistakes should you avoid? In the next section, I’ve summarized some best practices to help you build continual learning solutions faster and better.

My top five tips for implementing continual learning

Precisely identify your objective

Do you want the model to adapt to new data quickly? Past knowledge isn’t that important? Or is remembering the past the priority? Does the model accuracy need to be on a certain level? Answers to these questions are fundamental and will shape your approach.

Architectural methods like Progressive Neural Networks could be a good choice if you prioritize preserving past knowledge over learning new concepts, because freezing parameters protects the model from catastrophic forgetting. If the goal is to adapt to new data as quickly as possible, a simple regularization-based method, like increasing weight updates for the most influential model parameters, can do the job.

However, if you want to balance between preserving the past and learning new knowledge, the prompt tuning method (which belongs to the architectural category) can be useful:

First, you use transfer learning to create a strong backbone model. Then, during incremental training, you freeze this model and only fine-tune an additional, tiny part of the parameters. While the backbone model is responsible for keeping past knowledge, the extra parameters allow for the effective learning of new concepts. The main benefit is that the additional parameters can be stripped off at any time, so you can always go back to the bare backbone model and recover the baseline performance when something goes wrong.
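The toy below illustrates the prompt tuning idea under heavy simplifications, all of them assumptions: the "backbone" is a fixed function that mean-pools token vectors and applies a frozen linear readout, and the "soft prompt" is a list of trainable vectors prepended to every input. Only the prompt is updated during training, and `strip_prompt` removes it, which restores the backbone's original outputs exactly. All names are hypothetical.

```python
from typing import List, Optional

Vector = List[float]
READOUT = [0.5, -0.25]  # frozen backbone weights (toy; assumes 2-D token vectors)


def frozen_backbone(tokens: List[Vector]) -> float:
    """Stand-in for a pretrained model: mean-pool the token vectors, then a fixed linear readout."""
    pooled = [sum(t[d] for t in tokens) / len(tokens) for d in range(len(READOUT))]
    return sum(w * p for w, p in zip(READOUT, pooled))


class PromptTunedModel:
    """Prepends trainable 'soft prompt' vectors to the input; the backbone stays frozen."""

    def __init__(self, n_prompt_tokens: int):
        self.prompt: Optional[List[Vector]] = [[0.0] * len(READOUT) for _ in range(n_prompt_tokens)]

    def __call__(self, tokens: List[Vector]) -> float:
        return frozen_backbone((self.prompt or []) + tokens)

    def train_step(self, tokens: List[Vector], target: float, lr: float = 0.5) -> None:
        """One gradient step on the squared error; only the prompt vectors are updated."""
        if self.prompt is None:
            raise ValueError("prompt has been stripped")
        err = self(tokens) - target
        n = len(self.prompt) + len(tokens)
        for vec in self.prompt:
            for d, w in enumerate(READOUT):
                vec[d] -= lr * 2.0 * err * w / n  # d(err**2)/d(vec[d]) = 2 * err * w / n

    def strip_prompt(self) -> None:
        """Drop the extra parameters and fall back to the bare backbone."""
        self.prompt = None


tokens = [[1.0, 0.0], [0.0, 1.0]]
model = PromptTunedModel(n_prompt_tokens=1)
start = model(tokens)
for _ in range(20):
    model.train_step(tokens, target=1.0)
print(model(tokens) > start)                     # True: the prompt adapted toward the new target
model.strip_prompt()
print(model(tokens) == frozen_backbone(tokens))  # True: baseline behavior recovered
```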

Carefully select the model architecture

Deep learning models behave differently under incremental training, even if it seems that they are very similar to each other. For example, convolutional neural networks achieve significantly better accuracy in continual learning when they use batch normalization and skip connections.

Moreover, even models with the same number of parameters may exhibit different performance depending on the layers’ architecture. If a model has many layers with relatively few parameters, we can describe it as “long.” In contrast, if a model has a small number of layers and each of them has numerous parameters, we can call it “wide.” Wider models are better for CL than longer models because the longer models are more difficult to train by the backpropagation algorithm. Small weight corrections in the first layer of the long model may have a bigger impact on the weights of the next layer and, consequently, can strongly influence weights in the last layer (snowball effect). Wider models are also harder to overfit.

Start simple, then improve

Starting a continual learning project is a daunting task. Here is the roadmap I follow in all my projects:

  1. Check if you really need continual learning. It’s crucial to be aware that adopting continual learning is a progressive process, and you might discover that you don’t need it along the way. Do not overthink your solution and only implement CL approaches if they genuinely benefit you. For example, if you have one model that has to be retrained once a year, it is probably not worth it.
  2. First, try a naive, straightforward solution. This will give you two benefits. First, you have a baseline to compare with. Second, when you improve the solution by, for example, implementing regularization or adding memory, it’s much less likely that you will overengineer it.
  3. Choose the right method for your problem. What kind of model do you use? Do you have access to past data? Do you prioritize adapting to new data or remembering the past data? Answers to these questions will shape the choice of the method (see section How to Choose the Right Continual Learning Method for Your Project).
  4. Experiment as much as you can. It is not easy to find the best (or just satisfying) solution right away, even for experienced machine learning engineers. A good habit is to experiment by simulating (production-like) continual-learning scenarios on the available data and try to tune the hyperparameters.
  5. Take the time to understand the problems. Immature continual-learning solutions are often very fragile. Poor performance may be caused by many factors, such as uncalibrated hyperparameters and choosing unsuitable CL methods or training procedures. Always try to carefully understand the problem before taking action.
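One way to run such a simulation (an assumption; the article does not prescribe a protocol) is prequential, or test-then-train, evaluation: replay the historical data in chronological order, score the model on each batch before it trains on that batch, and use every batch for training exactly once, as a CL system would in production. The nearest-class-mean model below is a hypothetical stand-in for your real model.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Sample = Tuple[float, str]  # (1-D feature, label)


class NearestMeanClassifier:
    """Tiny incremental model: keeps a running mean of the feature for each class."""

    def __init__(self) -> None:
        self.sums: Dict[str, float] = {}
        self.counts: Dict[str, int] = {}

    def predict(self, x: float) -> Optional[str]:
        if not self.counts:
            return None  # nothing learned yet
        return min(self.counts, key=lambda c: abs(x - self.sums[c] / self.counts[c]))

    def update(self, batch: Sequence[Sample]) -> None:
        for x, y in batch:
            self.sums[y] = self.sums.get(y, 0.0) + x
            self.counts[y] = self.counts.get(y, 0) + 1


def prequential_eval(model: NearestMeanClassifier,
                     stream: Sequence[Sequence[Sample]]) -> List[float]:
    """Test-then-train: score each batch before the model learns from it."""
    scores = []
    for batch in stream:
        correct = sum(model.predict(x) == y for x, y in batch)
        scores.append(correct / len(batch))
        model.update(batch)  # each batch is used for training exactly once
    return scores


# Hypothetical time-ordered history, e.g., one batch per day.
history = [
    [(0.1, "ok"), (0.9, "fraud")],
    [(0.2, "ok"), (0.8, "fraud")],
    [(0.6, "ok"), (0.95, "fraud")],  # the "ok" class has drifted
]
print(prequential_eval(NearestMeanClassifier(), history))  # [0.0, 1.0, 0.5]
```

The per-batch scores make drift visible as a drop (the last batch here), which is the kind of signal you would tune hyperparameters and CL methods against.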

Choose your tools wisely

Suppose you have decided to adopt continual learning in your system, and the time has come to pick a method to implement.

You’ve seen many methods described in scientific papers that might be worth trying, but they seem time-consuming to implement. Fortunately, in most cases, there is no need to implement the method on your own.

There are a bunch of high-quality libraries out there providing ready-to-use solutions:

  • Avalanche is an end-to-end continual learning library based on PyTorch. The amazing ContinualAI community created it to provide an open-source (MIT licensed) codebase for fast prototyping, training, and evaluation of continual learning methods. Avalanche has ready-to-use methods from different groups (regularization-based, architectural, and memory-based).
  • Continuum is a library providing tools for creating continual learning scenarios from existing datasets. It is designed for PyTorch and can be used in various domains like Computer Vision and Natural Language Processing. Continuum is very mature and easy to use, making it one of the most reliable continual learning libraries.
  • Renate is a library designed by AWS Labs. Renate supports plenty of ready-to-use methods, especially memory-based ones. But its main advantage is the embedded hyperparameter optimization framework, which can be used to increase the overall model performance with minimal effort.

If you have access to old data, don’t hesitate to use it

Memory-based methods are currently the most effective ones for incremental training. Using memory brings significant advantages over other approaches and is relatively simple to implement. So, if you can access even just a fraction of past data and use it for incremental training, do it!

In cases where no past data is available, rather than implementing a sophisticated continual learning method, it may be worth first asking whether a memory-based method could be made applicable some other way. For example, even a memory buffer with artificially generated examples may be beneficial.
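A minimal sketch of that idea, assuming a very simple generator: instead of storing raw examples, keep only per-class summary statistics and sample synthetic examples from a Gaussian per class whenever the replay buffer needs filling. Real generative replay setups typically train a generative model instead; the function names here are hypothetical, and whether even summary statistics may be retained depends on your data-retention policy.

```python
import random
import statistics
from typing import Dict, List, Tuple

Sample = Tuple[float, str]  # (1-D feature, label)


def fit_class_stats(samples: List[Sample]) -> Dict[str, Tuple[float, float]]:
    """Summarize each class by its mean and standard deviation; raw samples can then be deleted."""
    by_class: Dict[str, List[float]] = {}
    for x, y in samples:
        by_class.setdefault(y, []).append(x)
    return {y: (statistics.mean(xs), statistics.pstdev(xs)) for y, xs in by_class.items()}


def generate_pseudo_samples(stats: Dict[str, Tuple[float, float]],
                            n_per_class: int, seed: int = 0) -> List[Sample]:
    """Draw synthetic examples from a Gaussian per class to fill a replay buffer."""
    rng = random.Random(seed)
    return [(rng.gauss(mu, sigma), y)
            for y, (mu, sigma) in stats.items()
            for _ in range(n_per_class)]


stats = fit_class_stats([(0.1, "ok"), (0.3, "ok"), (0.9, "fraud"), (0.7, "fraud")])
memory = generate_pseudo_samples(stats, n_per_class=5)
print(len(memory), sorted({y for _, y in memory}))  # 10 ['fraud', 'ok']
```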

Summary

Continual learning is a fascinating concept that can help you train effective ML models incrementally. Training incrementally is crucial when a model needs to adapt to new data or be personalized.

Achieving the desired model performance is a long journey, and you’ll have to be patient as you progress toward full-scale continual learning. Remember to always precisely identify the objective and take your time carefully selecting the method that is best for your use case.

As I outlined in the Choose Your Tools Wisely section above, plenty of ready-to-use methods can make your model learn from the evolving data stream without forgetting the already acquired knowledge. Different methods fit different use cases, so don’t be afraid to experiment. I hope those tips will help you create the perfect machine-learning model!

If you are interested in the academic aspect of continual learning and want to dive into details, I recommend this excellent review paper.
