TL;DR
Vanishing or exploding gradients are common training instabilities observed in foundation models.
Real-time gradient-norm monitoring using experiment trackers like neptune.ai enables early detection and mitigation.
Implementing stabilization techniques such as gradient clipping, and optimizing weight initialization and learning rate schedules, improves training convergence and stability.
As foundation models scale to billions or even trillions of parameters, they often exhibit training instabilities, particularly vanishing and exploding gradients. During the initial training phase (pre-training), it is common to observe loss spikes, which can degrade the model’s performance or render pre-training ineffective.
In this article, we investigate the underlying causes of these instabilities and cover the following questions:
- Why do gradients explode or vanish during foundation model training?
- Why are foundation models especially prone to vanishing or exploding gradients?
- How can we efficiently track gradients across layers during training?
- What are the most effective techniques to prevent the gradients from vanishing or exploding?
- How does the learning rate affect gradient stability and model convergence?
What gradient issues occur during foundation model training?
Foundation models are trained using adaptive gradient descent optimization techniques like Adam that update parameters (weights and biases) iteratively to minimize a loss function (e.g., cross-entropy).
The general update rule for gradient descent is:

θ_{t+1} = θ_t − η ∇_θ L(θ_t)

where θ represents the model parameters, η is the learning rate, and ∇_θ L is the gradient of the loss function L with respect to the parameters.
During training, gradient descent updates model parameters by computing the gradients of the loss function via forward and backward passes. During the forward pass, the inputs are passed through the model’s hidden layers to compute the predicted output and the loss with respect to the true label. During the backward pass, gradients are computed recursively using the chain rule to update model parameters.
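To make this concrete, here is a minimal PyTorch sketch (a toy two-layer model, not the BERT setup used later in this article) that runs a forward pass, a backward pass, and a plain parameter update following the rule above; adaptive optimizers like Adam refine this basic step:

import torch
import torch.nn as nn

# Toy two-layer model and a single (input, label) pair -- illustrative only
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
loss_fn = nn.CrossEntropyLoss()
x, y = torch.randn(1, 4), torch.tensor([1])

# Forward pass: compute the predictions and the loss L(theta)
loss = loss_fn(model(x), y)

# Backward pass: populate param.grad with dL/dtheta via the chain rule
loss.backward()

# Plain gradient descent update: theta <- theta - eta * grad
eta = 0.1
with torch.no_grad():
    for param in model.parameters():
        param -= eta * param.grad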
As models scale in depth and complexity, two major issues arise during their training: vanishing and exploding gradients.
Vanishing gradients
The vanishing gradient problem occurs during backpropagation when the gradients become progressively smaller as they are propagated backward through the model's layers.
The gradients of earlier layers are computed through repeated multiplications. For instance, based on the chain rule, the gradient of the loss with respect to the input layer depends on the chain of derivatives from the output layer to the input layer:

∂L/∂W_1 = ∂L/∂a_L · ∂a_L/∂a_{L−1} · … · ∂a_2/∂a_1 · ∂a_1/∂W_1

where a_l denotes the activations of layer l.
As the depth of the model increases, these multiplications shrink the gradients’ magnitude, causing the gradients of the initial weights to be exponentially smaller compared to the later ones. This difference in gradient magnitude causes slow convergence or halts the training process entirely, as earlier weights remain unchanged.
To understand how the gradients propagate in deep neural networks, we can examine the derivatives involving the weight matrices (W) and activation functions (Φ(z)):

a_l = Φ_l(z_l),   z_l = W_l a_{l−1} + b_l   ⇒   ∂a_l/∂a_{l−1} = Φ′_l(z_l) W_l
Using the chain rule, the gradient of the loss with respect to the first layer becomes:

∂L/∂W_1 = ∂L/∂a_L · ( ∏_{l=2}^{L} Φ′_l(z_l) W_l ) · ∂a_1/∂W_1
In the case of an activation function like ReLU, the derivative is 1 for active neurons (z_l > 0) and 0 for inactive neurons (z_l < 0), so the gradient flow stops for inactive neurons. In other words, the gradients vanish wherever z_l < 0.
Even if the majority of the neurons are active (z_l > 0), if the norms of the weight matrices ||W_l|| are less than 1, then the product ∏_{l=2}^{L} Φ′_l(z_l) W_l shrinks exponentially as the number of layers increases. Thus, the gradients of the initial layers (∂L/∂W_1) will be close to zero, and those layers will not be updated. This behavior is very common when using ReLU as an activation function in very deep neural networks.
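The following toy sketch illustrates this effect: a deep ReLU MLP whose weights are deliberately initialized with small values (an illustrative setup, not the article's model) shows gradient norms that shrink sharply toward the first layers:

import torch
import torch.nn as nn

# Deep MLP with ReLU and deliberately small initial weights,
# so the product of ||Phi'_l(z_l) W_l|| terms shrinks layer by layer
torch.manual_seed(0)
depth, width = 30, 64
layers = []
for _ in range(depth):
    linear = nn.Linear(width, width)
    nn.init.normal_(linear.weight, std=0.05)  # keeps ||W_l|| < 1 on purpose
    layers += [linear, nn.ReLU()]
model = nn.Sequential(*layers)

loss = model(torch.randn(8, width)).pow(2).mean()
loss.backward()

# Gradient norms decay exponentially toward the first layers
for i in (0, depth // 2, depth - 1):
    print(f"layer {i:2d} grad norm: {model[2 * i].weight.grad.norm().item():.2e}")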
Exploding gradients
The exploding gradient problem is the opposite of the vanishing gradient issue. It occurs when the gradient grows exponentially during backpropagation, resulting in large changes in model parameters. This manifests as loss spikes and fluctuations, particularly in the early stages of training.
The primary causes of exploding gradients are the repeated multiplication of large weight matrices and the choice of the activation function. When the norms of the weight matrices ||W_l|| and of the activation function's derivatives ||Φ′_l(z_l)|| are greater than 1, their product across layers causes the gradient to grow exponentially with the model depth. As a consequence, the model may diverge or oscillate, but never converge to a minimum.
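The same toy construction as in the sketch above, but with deliberately large initial weights, shows the opposite behavior: the gradient norm of the first layer is many orders of magnitude larger than that of the last layer (again, purely illustrative):

import torch
import torch.nn as nn

# Same deep ReLU MLP, but now with ||W_l|| > 1: gradients blow up
torch.manual_seed(0)
layers = []
for _ in range(30):
    linear = nn.Linear(64, 64)
    nn.init.normal_(linear.weight, std=0.5)  # large weights on purpose
    layers += [linear, nn.ReLU()]
model = nn.Sequential(*layers)

loss = model(torch.randn(8, 64)).pow(2).mean()
loss.backward()

print(f"first-layer grad norm: {model[0].weight.grad.norm().item():.2e}")
print(f"last-layer grad norm:  {model[58].weight.grad.norm().item():.2e}")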
How does foundation model training benefit from tracking layer-wise gradients?
Effectively addressing vanishing and exploding gradients in foundation model training involves three stages:
- Discovery: The first step is to discover whether there is an issue with the gradients of the foundation models during training. This is done by monitoring the norm of the gradients for each layer throughout the training process. This allows us to observe if the magnitude of the gradients is becoming very small (vanishing) or very large (exploding).
- Identifying the root cause: Once we know that there is an issue, the next step is to understand where in the model these problems originate. By tracking the evolution of the gradient norms across layers, we gain insightful information into which layer or block of layers is responsible for the gradients to diminish or explode.
- Implementing and validating solutions: Based on the insights gained from monitoring, we can make the necessary adjustments to the hyperparameters, like learning rate, or employ techniques like gradient clipping. Once implemented, we can assess the solution’s effectiveness.
Step-by-step guide to gradient-norm tracking in PyTorch
Gradient norm tracking calculates the norm of the gradients for each model layer during backpropagation. The L2 norm is a common choice because it provides a smooth and differentiable measure of the gradient magnitude per layer, making it ideal for detecting the extreme values seen in vanishing and exploding gradients.
Here, we will show a step-by-step guide on implementing gradient norm tracking in a BERT sequence classification model in PyTorch using neptune.ai for monitoring and visualization.
You can find the full implementation and the required dependencies in this GitHub repository.
For the experimental setup, we used the transformers and datasets libraries from Hugging Face. We selected the MRPC (Microsoft Research Paraphrase Corpus) task from the GLUE benchmark, which involves determining whether two sentences are semantically equivalent. To simulate a pre-training scenario, we initialize the BERT model with random weights.
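The full setup lives in the repository linked above; a minimal sketch of how it could look is shown below. The tokenizer checkpoint (bert-base-uncased), the padding collator, and the preprocessing details are illustrative assumptions and may differ from the repository code:

import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import (BertConfig, BertForSequenceClassification,
                          BertTokenizerFast, DataCollatorWithPadding)

# Load the MRPC task from GLUE and tokenize the sentence pairs
raw = load_dataset("glue", "mrpc")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

def tokenize(batch):
    return tokenizer(batch["sentence1"], batch["sentence2"], truncation=True)

tokenized = raw.map(tokenize, batched=True)
tokenized = tokenized.remove_columns(["sentence1", "sentence2", "idx"])
tokenized = tokenized.rename_column("label", "labels")
tokenized.set_format("torch")

# Randomly initialized BERT (built from a config, not from pretrained weights)
# to mimic a pre-training-like scenario
model = BertForSequenceClassification(BertConfig(num_labels=2)).to("cuda")

collator = DataCollatorWithPadding(tokenizer)
train_dataloader = DataLoader(tokenized["train"], batch_size=1,
                              shuffle=True, collate_fn=collator)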
Step 1: Initialize Neptune for logging
For detailed instructions on installing and configuring Neptune for logging metadata, please refer to the documentation.
When initializing the Neptune run, we add descriptive tags. Tags make it easier to search and organize the experiments when tracking multiple models, datasets, or configurations.
Here, we use three tags:
- "gradient_tracking" indicates that this experiment includes gradient monitoring
- "pytorch" refers to the framework used
- "transformers" specifies the type of model architecture
import os
from random import random
from neptune_scale import Run
from getpass import getpass
os.environ["NEPTUNE_API_TOKEN"] = getpass("Enter your Neptune API token: ")
os.environ["NEPTUNE_PROJECT"] = "workspace-name/project-name"
custom_id = random()
run = Run(
    experiment_name="gradient_tracking",
    run_id=f"gradient-{custom_id}",
)

run.log_configs({
    "learning_rate": 1e-1,
    "batch_size": 1,
    "optimizer": "Adam",
})

run.add_tags(["gradient_tracking", "pytorch", "transformers"])
Step 2: Define the gradient-norm logging function
Next, we define a function for tracking the gradient norm for each layer of the model.
The function calculates the L2 norm of the gradients for each named parameter (weight and bias tensor) that has a gradient, capturing the overall gradient magnitude per parameter. This helps to identify layers where the gradients are very small (potentially vanishing) or very large (potentially exploding).
import torch

def log_gradient_norms(model, step, log_every_n_steps=1):
    """
    Logs the L2 norm of the gradients of the model parameters every n steps.

    Args:
        model (torch.nn.Module): The neural network model.
        step (int): The current training step or epoch, for tracking.
        log_every_n_steps (int): Log only every n steps to reduce overhead.
    """
    if step % log_every_n_steps != 0:
        return  # Skip logging for this step
    with torch.no_grad():  # Prevent building a computation graph during norm computation
        for name, param in model.named_parameters():
            if param.grad is not None:
                # Optional: skip small/irrelevant layers if needed, e.g.,
                # if not name.startswith("encoder.layer."): continue
                grad_norm = param.grad.norm().item()
                run.log_metrics({f"gradients/{name}": grad_norm}, step=step)
While computing the L2 norm is inexpensive, logging the gradient norm for each parameter in foundation models with billions of parameters can consume memory and slow down training. In practice, it is advisable to monitor only selected layers (e.g., key components such as attention weights, embeddings, or layer outputs), aggregate norms at the layer or block level, and reduce logging frequency (e.g., logging norms every n steps instead of every step).
Asynchronous logging tools like Neptune allow logging the metrics in parallel with the training process without holding up the main computation pipeline. This allows you to be quite liberal with what you log. Neptune’s backend is tuned for very high-throughput ingestion (millions of data points per second), so even per-parameter or per-token gradient streams won’t throttle your run.
Additionally, wrapping the gradient norm calculations within a torch.no_grad() context avoids unnecessary memory allocation and reduces the computational cost of gradient tracking, as it prevents PyTorch from keeping track of these computations for backpropagation.
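As a sketch of this advice, the hypothetical variant below (not part of the repository code) aggregates per-parameter norms into one value per top-level block of the BERT model and logs only every 50 steps; it reuses the global run object from Step 1:

from collections import defaultdict
import torch

def log_block_gradient_norms(model, step, log_every_n_steps=50):
    """Aggregate per-parameter gradient norms into one value per block.

    Parameter names such as 'bert.encoder.layer.3.attention...' are grouped
    under the prefix 'bert.encoder.layer.3' before logging.
    """
    if step % log_every_n_steps != 0:
        return
    sq_norms = defaultdict(float)
    with torch.no_grad():
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            # Group by the first four name components, e.g., bert.encoder.layer.3
            block = ".".join(name.split(".")[:4])
            sq_norms[block] += param.grad.norm().item() ** 2
    run.log_metrics(
        {f"gradients/{blk}": sq ** 0.5 for blk, sq in sq_norms.items()},
        step=step,
    )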
Step 3: Train the model and track gradients
In this step, we train the BERT model and log training metrics such as gradient norms and the model loss using Neptune:
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=1e-1)
model.train()

for epoch in range(10):
    for step, batch in enumerate(train_dataloader):
        inputs = {k: v.to('cuda') for k, v in batch.items() if k in tokenizer.model_input_names}
        labels = batch['labels'].to('cuda')

        optimizer.zero_grad()
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        loss.backward()

        # Log gradient norms before the optimizer step
        global_step = step + epoch * len(train_dataloader)
        log_gradient_norms(model, global_step)

        optimizer.step()

        # Log the loss to Neptune
        run.log_metrics({"loss": loss.item()}, step=global_step)

run.close()
Here, we trained the model with the Adam optimizer using two different learning rates, 0.1 and 10. As expected, with a learning rate of 10 the model diverges in the very first steps, and the loss quickly explodes to NaN values, as shown in the plot below. Although the loss does not explode for a learning rate of 0.1, it remains high, and the model still does not learn anything meaningful during training.
💡 You can find the complete code on GitHub.
Using gradient tracking to diagnose training issues
Once we have implemented gradient tracking, the next step is to interpret the collected data to diagnose and address training instabilities.
Let’s revisit the example from the previous section. We trained a BERT model and logged the L2 norm of gradients across model layers using Neptune. When we used a relatively large learning rate (LR = 10), the model diverged in the first steps of training. For a smaller learning rate (LR = 0.1), we observed that the loss did not fluctuate but remained high.
💡 Explore this data in a live project on Neptune.
When we now further reduce the learning rate to 0.001, the loss and the gradient norm of the last layer (classifier) do not decrease. This means that the model is not converging, and a likely cause is vanishing gradients. To validate this hypothesis, we decrease the learning rate further to 0.00005 and observe a decrease in both the loss and the gradient norm of the last layer.
Another insight we gain by observing the pooler layer is that for both choices of the learning rate (0.001 and 0.00005), its gradient norm is decreasing. This once again highlights the benefit of tracking gradients for each layer: we can investigate what is happening in every layer and find out which ones are not being updated during training.
Techniques for gradient stabilization
Monitoring gradient norms and training loss provides insights into the learning dynamics of the foundation models. Real-time tracking of these metrics helps diagnose issues such as vanishing or exploding gradients, convergence issues, and layers that are not learning effectively (e.g., their gradient norm is not decreasing).
By analyzing how the gradient norm behaves for each layer and how the loss evolves over time, we can identify such issues early in the training. This enables us to incorporate techniques that stabilize and improve training.
Some of these techniques are listed below, followed by a combined PyTorch sketch:
- Gradient clipping: Gradient clipping caps the norm of the gradients during backpropagation, rescaling them whenever they exceed a chosen threshold. This prevents extremely large updates and the loss spikes caused by exploding gradients.
- Layer normalization: Layer normalization is a standard component in foundation models, playing an important role in stabilizing training. It normalizes activations across features (values in the embedding vector of the token) within each token, helping to maintain consistent activation scales and improving convergence. In doing so, it indirectly mitigates issues like vanishing or exploding gradients. Although it is not manually tuned, understanding its behavior is crucial when diagnosing training issues or developing foundation models from scratch.
- Weight initialization: In deep architectures such as foundation models, weight initialization plays a critical role in the stability and convergence speed of training. Poor weight initialization can cause the gradients to vanish or explode as they propagate through many layers. To address this, several initialization strategies have been proposed:
- Xavier (Glorot) initialization aims to maintain a consistent variance of activations and gradients across layers by scaling the weights based on the number of input and output units. The idea is that the variance of each layer’s outputs should match the variance of its inputs for the model to learn effectively.
- He initialization takes into account the nonlinearity of activation functions such as ReLU, which zero out negative inputs and thereby reduce the variance of the activations. To compensate, He initialization sets the variance of the weights higher than Xavier (Glorot) initialization does, enabling more effective training with ReLU-like activations.
Although foundation models may use weight initialization methods tailored to their specific architecture (often modifications of Xavier or He initialization), understanding these schemes is important when designing or debugging such models. For instance, BERT uses a truncated normal (Gaussian) initialization with a small standard deviation.
- Activation functions: Choosing the right activation function is crucial for the effective and stable training of foundation models. ReLU is the most widely used activation function due to its simplicity and computational efficiency. However, it can lead to the “dying neuron” problem, where a neuron’s gradient becomes zero and it stops learning.
To address this, other activation functions are used in foundation models:
- GELU (Gaussian Error Linear Unit) provides a smoother activation and better empirical performance by weighting inputs by the Gaussian cumulative distribution function. It has been used in models like BERT and GPT.
- Swish, proposed by Google researchers, is a self-gated activation function that performs better than ReLU in very deep neural networks. It is designed to smoothly interpolate between a linear function and the ReLU function.
- LeakyReLU is an extension of ReLU that addresses the “dying neuron” issue by allowing a small, non-zero gradient for negative values, preventing neurons from becoming inactive.
- Learning rate schedules: During the early stages of training, the model weights are randomly initialized, and optimization is sensitive to the choice of learning rate. A warmup phase is commonly used to avoid unstable loss spikes caused by large gradient updates. In this phase, the learning rate is very small and gradually increases over a few initial steps.
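The sketch below shows how these techniques could be combined with the training loop from Step 3. It is illustrative rather than a reference implementation: the Xavier re-initialization of the linear layers, the clipping threshold (max_norm=1.0), the warmup length (100 steps), and the learning rate (5e-5) are assumptions, and BERT’s own initializer is the truncated normal mentioned above.

import torch
import torch.nn as nn
import torch.optim as optim
from transformers import get_linear_schedule_with_warmup

# Weight initialization: re-initialize linear layers (illustrative choice)
def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)   # or nn.init.kaiming_uniform_ for ReLU
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model.apply(init_weights)

# Warmup schedule: the learning rate ramps up over the first steps, then decays
optimizer = optim.Adam(model.parameters(), lr=5e-5)
num_training_steps = 10 * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=num_training_steps
)

for epoch in range(10):
    for step, batch in enumerate(train_dataloader):
        inputs = {k: v.to("cuda") for k, v in batch.items() if k in tokenizer.model_input_names}
        labels = batch["labels"].to("cuda")

        optimizer.zero_grad()
        loss = model(**inputs, labels=labels).loss
        loss.backward()

        # Gradient clipping: rescale gradients whose global norm exceeds 1.0
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        log_gradient_norms(model, step + epoch * len(train_dataloader))
        optimizer.step()
        scheduler.step()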
Wrapping up
Training instabilities in large-scale models can prevent them from learning. Monitoring gradient norms across layers helps identify root causes and evaluate the effectiveness of mitigation measures.
Efficiently analyzing gradients in foundation models requires an experiment tracker that can handle a high throughput of metrics data. Neptune can not only ingest millions of data points per second but also comes with efficient visualization utilities.
Gradient clipping, layer normalization, and optimizing the learning rate and weight initialization are key methods for addressing vanishing and exploding gradients. In very deep models, where vanishing gradients are the prime concern, specialized activation functions prevent neurons from becoming inactive.