The evolution of artificial intelligence and machine learning has been rapid and unanticipated. These advanced models paint Van Gogh-inspired pictures of your backyard, name the familiar tune you heard at your high school prom, and finish your meticulously worded email to your client. They are everywhere, and companies are racing to adapt their business models to the growing demand and to mine the potential opportunities.
Major and hasty shifts like these often introduce small gaps that only widen as the scale multiplies. A lack of robustness and slow inference times are among the main issues that surface when models are deployed at scale, yet they are often neglected during training in a greenhouse setting with synthesized and highly supervised datasets.
While challenging, there are methods that are almost guaranteed to help narrow these training and inference-time gaps. This article dives into six ways you can manage and optimize your models for deployment and inference. Specifically, we focus on neural networks – the architecture that is hardest to manage because of its abundant parameters and memory requirements. Each method is accompanied by examples/tutorials on how to apply it to your own problem.
This article assumes the prerequisite of understanding deep learning and neural networks, along with the capability to implement them using frameworks such as PyTorch.
Memory management with knowledge distillation
Memory has been the main drawback of deep learning models. During training, every parameter to be updated has to hold its gradient for the backpropagation computation. While the problem is slightly alleviated during inference, as gradients are no longer required, the high number of parameters remains and can still be a problem if the available compute isn’t powerful enough. Knowledge distillation is a straightforward method to tackle this.
In essence, knowledge distillation aims to transfer the knowledge learned by a complex model to a model with much smaller capacity. With millions of parameters, neural networks often exhibit overparameterization, meaning that many of the intermediate neurons/weights are, in fact, not useful to the final prediction; this makes it possible to adopt a network with far fewer layers/neurons per layer and still capture most of the original network’s behavior.
A knowledge distillation method comprises a teacher network and a student network. The teacher network, usually the more complex model, is trained on the entire training dataset; its high number of parameters makes convergence relatively easy. Afterward, a knowledge transfer module is applied to distill what the teacher learned into the student network.
How is knowledge transfer applied?
Depending on the type of knowledge to be transferred, there are multiple ways of performing the transfer. We introduce two of the more common types of knowledge and their corresponding transfer methods.
- Response-based knowledge: Response-based knowledge focuses on the final output of the network. In other words, the goal is for the student to learn to output predictions similar to the teacher’s. A typical way of performing such knowledge transfer is to compute a loss between the predictions of the teacher and the student instead of comparing the student against the original ground truth. The optimization encourages the student’s weights to move into a space similar to the teacher’s, achieving the knowledge-transfer goal. These soft probabilities from the teacher often allow the student to learn better than training on the ground truths alone (see the sketch after this list).

- Feature-based knowledge: Instead of output similarity, feature-based knowledge emphasizes the differences in intermediate representations between the teacher and the student. Transferring features therefore requires the loss to be computed on the intermediate representations instead of on the final output. This ensures that the student’s processing/extraction is also similar to the teacher’s.
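As an illustration, here is a minimal sketch of a combined distillation loss covering both kinds of knowledge. It assumes you already have the teacher’s logits and an intermediate feature map (computed under torch.no_grad()), that the student’s feature map has been projected to the same shape as the teacher’s, and that y holds the ground-truth labels; the function name and weighting factors are hypothetical.

import torch
import torch.nn.functional as F

T = 4.0  # softmax temperature used to soften the teacher's predictions

def distillation_loss(student_logits, teacher_logits, student_feat, teacher_feat, y,
                      alpha=0.5, beta=0.1):
    # response-based: match the teacher's softened output distribution (KL divergence)
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    # feature-based: match an intermediate representation (a simple MSE here)
    feat_loss = F.mse_loss(student_feat, teacher_feat)
    # optionally keep a standard supervised term on the ground truth
    hard_loss = F.cross_entropy(student_logits, y)
    return alpha * soft_loss + beta * feat_loss + (1 - alpha) * hard_loss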

Speed up inference using model quantization and layer fusion
Deep and wide architectures inherently require numerous matrix operations. Although these multiplications can be computed in parallel with the help of GPUs, inference can still be rather slow once model loading is factored in or when a large number of inferences is required at the same time. To this end, model quantization can be a great tool for subduing the issue.
What is model quantization?
Conventional models built with popular libraries such as PyTorch usually use high floating-point precision, which helps during training to reach the best possible accuracy. Quantization is the technique of computing and storing parameters at a lower bit width to decrease inference time and increase efficiency. Matrix computations thus become more compact and performant under this setting.
PTQ vs. QAT
Quantization methods roughly fall into two categories:
- 1 Post-Training Quantization (PTQ), where the precision decrease occurs only after the model is trained.
- 2 Quantization Aware Training (QAT), where the training takes quantization into account.
PTQ can be further split into PTQ-dynamic and PTQ-static. PTQ-dynamic is the easiest to implement: all the model weights are quantized ahead of time, and only the activations are quantized dynamically during inference. If loading weights takes a significant portion of the inference time rather than the matrix multiplication itself, PTQ-dynamic is the way to go. PTQ-static, on the other hand, uses a representative dataset to find the best way to quantize both the weights and the activations.
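For PTQ-static, a minimal eager-mode sketch in PyTorch could look like the following. It assumes fp32_model is a trained model whose forward pass is wrapped with QuantStub/DeQuantStub and that calibration_loader yields a representative dataset; both names are placeholders.

import torch

fp32_model.eval()
# pick a quantization configuration for the target backend (e.g., x86 'fbgemm')
fp32_model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
# insert observers that record activation statistics
prepared_model = torch.quantization.prepare(fp32_model)
# calibrate on a representative dataset
with torch.no_grad():
    for x, _ in calibration_loader:
        prepared_model(x)
# replace observed modules with statically quantized ones
static_int8_model = torch.quantization.convert(prepared_model)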
PTQ-dynamic implementation
import torch

# create an instance of the original fp32 model (M is your model class)
model_fp32 = M()
# create a dynamically quantized model instance from the original model
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,
    {torch.nn.Linear},   # set of layer types to dynamically quantize
    dtype=torch.qint8,   # the target dtype for quantized weights
)
With this simple method, you can directly create a quantized model after training. To showcase the capability of quantization, we can measure the model size before and after with the following code.
import os

def print_size_of_model(model, label=""):
    # save the state dict to disk and report its size
    torch.save(model.state_dict(), "temp.p")
    size = os.path.getsize("temp.p")
    print("model:", label, "\t", "Size (KB):", size / 1e3)
    os.remove("temp.p")
    return size

# compare the sizes
f = print_size_of_model(model_fp32, "fp32")
q = print_size_of_model(model_int8, "int8")
print("The model is {0:.2f} times smaller".format(f / q))
As an example, a linear layer with 1024 input and 256 output channels becomes 3.97 times smaller after quantization. You can do the same for inference time by measuring how long it takes a random torch tensor to pass through the model, and you will see a significant improvement with quantization too.
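As a rough sanity check, a simple timing sketch could look like this (assuming the quantized linear model from the example above, with 1024 input features; real benchmarks should also account for warm-up and hardware variance):

import time
import torch

def time_inference(model, n_runs=100):
    # average latency of a forward pass on a random input tensor
    x = torch.randn(1, 1024)
    with torch.no_grad():
        start = time.time()
        for _ in range(n_runs):
            model(x)
    return (time.time() - start) / n_runs

print("fp32 latency (s):", time_inference(model_fp32))
print("int8 latency (s):", time_inference(model_int8))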
For more information, please visit the PyTorch documentation.
Trade-off with Quantization
In short, PTQ (especially PTQ-dynamic) is much easier to implement with current libraries, as you can train the model normally at full precision and then opt in to quantization afterward. The drawback is also apparent: the model’s accuracy will likely decrease, since quantization error accumulates through the sequential matrix multiplications.
QAT somewhat alleviates the aforementioned accuracy issue, but applying quantization directly during training is more challenging, and it becomes hard to determine the true potential of the original model. If your compute budget and hardware allow, it is advisable to put quantization on hold until you know the full potential of your model, and then apply an adequate level of quantization with the tradeoffs in mind.
Layer fusion
Another well-explored technique for increasing model efficiency is to compress the network via layer fusion. Approaches to layer fusion can vary dramatically, but here is a straightforward overview of how a sample layer fusion algorithm works.
Essentially, given a network, you compute the pairwise similarity between its layers using some distance metric, which can be as simple as the cosine similarity between the two sets of weights. Afterward, you select the top-k most similar pairs, and for each pair you keep a single shared set of weights, freezing the redundant copy and its gradients. If many layers are similar to one another, this increases the model’s efficiency without jeopardizing too much accuracy.
Check the reference pseudo-algorithm.
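As a toy illustration of the similarity step (not a production algorithm), assuming a list of same-shaped linear layers, the pairwise cosine similarity and top-k selection could be sketched as:

import torch
import torch.nn.functional as F

def most_similar_layer_pairs(layers, top_k=1):
    # rank pairs of same-shaped weight matrices by cosine similarity
    pairs = []
    for i in range(len(layers)):
        for j in range(i + 1, len(layers)):
            w_i = layers[i].weight.detach().flatten()
            w_j = layers[j].weight.detach().flatten()
            sim = F.cosine_similarity(w_i, w_j, dim=0).item()
            pairs.append((sim, i, j))
    return sorted(pairs, reverse=True)[:top_k]

# for each selected pair, one layer's weights could then be tied to the other's
# and frozen, so only a single copy is stored and updated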
Using the ONNX library
Before you decide on any form of deployment, a critical step to optimize your model is to use the ONNX library.
ONNX, short for Open Neural Network Exchange, is an open-source ecosystem that aims to integrate diverse machine learning libraries. Nowadays, there are various deep learning frameworks in which one can write complex networks – the most famous being TensorFlow and PyTorch. While similar in syntax, they are very different from one another, and their models cannot be used interchangeably.
ONNX increases interoperability – models from different libraries can be converted to the ONNX format, which can then be easily optimized for different hardware. Depending on the deployment target (e.g., CPU, GPU), ONNX can help optimize the model accordingly.
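For PyTorch, exporting a trained model to ONNX is a short call. The sketch below assumes model is a trained torch.nn.Module expecting 224x224 RGB images; the file name and tensor names are arbitrary.

import torch

dummy_input = torch.randn(1, 3, 224, 224)  # example input matching the model's expected shape
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},  # allow variable batch size
)
# the exported file can then be run with ONNX Runtime, for example:
# import onnxruntime as ort
# session = ort.InferenceSession("model.onnx")
# outputs = session.run(None, {"input": dummy_input.numpy()})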
This is particularly helpful when your machine learning model is developed by multiple teams within a company. Each team can experiment with models and libraries based on their own preferences and still seamlessly combine the models in the end for production.
For more information, check the article about the benefits of the ONNX library.
Determine the mode of deployment
After applying what you can from the previous points, it is time to put your model to the test under the pressure of an actual use case. This step is rather tricky, as different use cases may require different types of deployment. We list several common deployment modes and their main differences.
Single-sample inference deployment
Perhaps the most common use case of a machine learning model is to expose it to client requests as a service (through an API or a UI), one sample at a time (e.g., using a Google API to perform image segmentation on a particular image). In this case, inference on a single sample with a real-time response is required. Initializing the model only when a client request arrives would lead to slow response times, and could leave too many models co-existing and exceeding the memory capacity of your machine.
A standard approach to avoid this is to deploy the model as a web service and configure it to a constant standby mode. Libraries such as Flask allow such deployments that can be easily incorporated into PyTorch pipelines. In simple words, you should pre-load the existing model and weights on the web server, such that any requests received from clients are sent to the already-loaded models for inference.
A simple implementation of PyTorch models on Flask can be found here.
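A minimal sketch of the pattern looks like this; load_model() is a hypothetical helper returning your trained torch.nn.Module, and the client is assumed to send an already-preprocessed input as JSON.

import torch
from flask import Flask, request, jsonify

app = Flask(__name__)

# load the model and weights once at server start-up, not per request
model = load_model()
model.eval()

@app.route("/predict", methods=["POST"])
def predict():
    # the payload is assumed to contain a preprocessed input tensor as a nested list
    x = torch.tensor(request.get_json()["input"])
    with torch.no_grad():
        prediction = model(x)
    return jsonify({"prediction": prediction.tolist()})

if __name__ == "__main__":
    app.run()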
Note that this pipeline can still suffer from bottlenecks under a high number of requests and large data transmissions. Therefore, there are two major details to consider when designing Flask or other server-based deployments:
- Avoid sending unnecessary data to the server end. Take image recognition as an example: if the model takes an image of size (224,224), it is much more viable to perform the image preprocessing on the client machine instead of sending the image in full resolution. For tasks such as detection, send the predictions back as bounding box coordinates instead of sending back the image directly.
- Open up multiple models concurrently. This avoids congestion during peak traffic hours. However, depending on the computational resources available, too many models residing concurrently on a GPU can also lead to slower inference overall and memory errors.
Batch deployment
On the other hand, if clients require batches of data to be analyzed concurrently, a more feasible alternative is to deploy the model so that it accepts batches of data before processing. An example would be a manufacturing company performing anomaly detection on the production line: batches of product images can be captured simultaneously for prediction, so running inference on each image sequentially would actually slow the pipeline down.
Today, neural network architectures are designed to accept batches of inputs concurrently to expedite the matrix multiplications. Thus, your service should accept multiple data entries and collect them into batches before feeding them into the network. Web service deployments such as Flask are still viable, but the model should be put on hold while entries are being collected from potentially multiple clients.
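Conceptually, the server collects incoming samples into a buffer and runs a single forward pass once the batch is full. The sketch below is simplified (a real service would also flush on a timeout and route results back to the right clients); model is assumed to be a preloaded network, and each sample a tensor of identical shape.

import torch

BATCH_SIZE = 32
buffer = []

def handle_request(sample):
    # accumulate individual samples from (possibly multiple) clients
    buffer.append(sample)
    if len(buffer) >= BATCH_SIZE:
        batch = torch.stack(buffer)   # shape: (BATCH_SIZE, ...)
        buffer.clear()
        with torch.no_grad():
            return model(batch)       # one forward pass for the whole batch
    return None                       # keep waiting until the batch is full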
Deploying models onto edge devices
Contrary to web-service-based models, where the client has to transmit the data to a central server to obtain feedback, it is also possible to deploy identical models onto edge devices (e.g., individual mobile phones). This deployment method allows much faster inference, especially for heavy data (where transmission is the bottleneck). However, accurate models, especially deep-learning-based ones, usually require heavy computation that phones or other edge devices cannot perform. Substantial distillation, lighter architectures, or even some sacrifice in accuracy is usually required for this mode of deployment.
Model pruning in optimizing models
Another step that is important to consider is model pruning, which will greatly improve the model’s efficiency when put into production.
Network pruning essentially overlays a binary mask on the network weights so that fewer weights are used for prediction. Exactly which weights are masked varies with the pruning method, but the most common approach is simply to prune the weights with the smallest magnitudes (i.e., the least effect on the outcome). Pruning can lead to the following effects:
- 1 Reduce latency and inference costs
- 2 Free up memory and reduce power consumption.
Pruning in PyTorch
In libraries like PyTorch, pruning is very straightforward with built-in utilities such as:
import torch.nn.utils.prune as prune
Now, given a model, you can simply call any pruning function you wish, such as:
prune.random_unstructured(module, name="weight", amount=0.3)
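Putting it together as a minimal sketch with a hypothetical linear layer, you can prune 30% of the weights and verify the resulting sparsity:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

module = nn.Linear(1024, 256)

# randomly zero out 30% of the entries in the 'weight' parameter
prune.random_unstructured(module, name="weight", amount=0.3)

# the binary mask is stored as 'weight_mask'; check the achieved sparsity
sparsity = float(torch.sum(module.weight == 0)) / module.weight.nelement()
print(f"Sparsity: {sparsity:.2%}")

# make the pruning permanent by removing the re-parameterization
prune.remove(module, "weight")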
Pruning in Tensorflow
If you are a fan of TensorFlow instead of PyTorch, there is also a dedicated optimization library for you, the TensorFlow Model Optimization Toolkit. This toolkit provides an important method for reducing network weights and increasing efficiency: pruning.
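A minimal sketch with the toolkit might look like the following; the two-layer Keras model and the sparsity schedule values are only illustrative.

import tensorflow as tf
import tensorflow_model_optimization as tfmot

model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation="relu", input_shape=(784,)),
    tf.keras.layers.Dense(10),
])

# wrap the model so that low-magnitude weights are progressively pruned during training
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    model,
    pruning_schedule=tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0, final_sparsity=0.5, begin_step=0, end_step=1000
    ),
)

pruned_model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# the UpdatePruningStep callback advances the pruning schedule during training:
# pruned_model.fit(x_train, y_train, epochs=2,
#                  callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])

# strip the pruning wrappers before exporting the final sparse model
final_model = tfmot.sparsity.keras.strip_pruning(pruned_model)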
For more details, refer to the TensorFlow Model Optimization documentation.
Online deep learning for model optimization
Data characteristics often change over time: a time-series analysis may stop working after a black swan event; photos of an object evolve as cameras improve. As such, much of the work of optimizing models for deployment begins after the model is deployed, through constant monitoring and refinement.
While straightforward retraining that incorporates the newly gathered data solves the problem, retraining from scratch is computationally expensive and often time-consuming (training a model on ImageNet-22k can take days, even with advanced GPUs). This section introduces an alternative to retraining from scratch: online deep learning.
What is online deep learning?
Traditional deep learning is performed with optimization methods such as stochastic gradient descent over the entire dataset. Online deep learning tackles the setting where only newly arriving data is available, performing Online Gradient Descent (OGD) on the currently trained model so it adapts to the new patterns. Additional methods, such as hedge backpropagation, have been introduced on top of OGD to further improve the online deep learning approach by also determining the proper depth of the model.
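A minimal sketch of OGD could look like this, assuming model has already been trained offline and (x, y) pairs arrive as a stream of new data; the learning rate is illustrative.

import torch

optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)  # small step size for online updates
loss_fn = torch.nn.CrossEntropyLoss()

def online_update(x, y):
    # take a single gradient step on each newly arriving (mini-)batch
    # instead of retraining on the full dataset from scratch
    model.train()
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()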
Data privacy issues
Using the newly retrieved data, however, can still cause privacy issues. Customers of your service may not want their data sent back to central servers for further model optimization. To address this, a branch of learning algorithms, namely federated learning, has been introduced.

Federated learning aims to perform the training on the edge devices (the customer end) and use the results to update the central model. A common approach works as follows (a minimal averaging sketch is given after the list):
- 1 The central server copies the model to the edge devices.
- 2 The edge-device models are optimized on local data using online learning algorithms.
- 3 The updated models are sent back to the server to optimize the central model.
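A minimal FedAvg-style sketch of step 3, assuming each edge device returns its updated state_dict (and that all parameters are floating-point), could be:

import copy
import torch

def federated_average(client_state_dicts):
    # average the parameters returned by the edge devices
    avg_state = copy.deepcopy(client_state_dicts[0])
    for key in avg_state:
        for client_state in client_state_dicts[1:]:
            avg_state[key] = avg_state[key] + client_state[key]
        avg_state[key] = avg_state[key] / len(client_state_dicts)
    return avg_state

# the central model is then refreshed with the averaged weights, e.g.:
# central_model.load_state_dict(federated_average(updated_states_from_clients))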
There are multiple variants of how this federated learning procedure can be performed, but all of them avoid transmitting confidential data while still improving the central model.
End note
This article serves as an end-to-end guide to help bring your machine learning models into the best possible state for deployment and inference. We began by introducing knowledge distillation and quantization to improve the network’s memory footprint and speed, as well as ONNX to optimize the network for different hardware.
Afterward, we covered the different modes of deployment to consider and gave a brief introduction to pruning with PyTorch and TensorFlow. Finally, we discussed monitoring and refining the model once it is in production, including online learning and federated learning mechanisms to keep improving your model. Every scenario is different, but hopefully this article makes the daunting process of optimizing deep learning models for deployment slightly smoother.
References
- https://neptune.ai/blog/must-do-error-analysis
- https://www.tensorflow.org/model_optimization
- https://pytorch.org/docs/stable/quantization.html
- https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html
- https://medium.com/trueface-ai/two-benefits-of-the-onnx-library-for-ml-models-4b3e417df52e