
Gumbel Softmax Loss Function Guide + How to Implement it in PyTorch

25th August, 2023

Training deep learning models has never been easier. You define the architecture and loss function, sit back, and monitor, at least in simple cases. Some architectures, however, come with inherent random components. This makes the forward pass stochastic, and your model is no longer deterministic.

In deterministic models, the output of the model is fully determined by parameter values and initial conditions.

Stochastic models have inherent randomness. The same set of parameter values and initial conditions will lead to an ensemble of different outputs.

This changes how sampling behaves: evaluating a deterministic function on the same input always returns the same result, whereas a stochastic node adds randomness, so the output of the whole forward pass becomes non-deterministic.

You see, the backpropagation algorithm relies on a chain of continuous, differentiable functions across the layers of the neural network. Many neural networks, however, fundamentally rely on discrete operations. Sampling from a discrete space isn’t the same as sampling from a continuous one, and that’s where the Gumbel-Softmax trick comes to the rescue. It lets sampling from a discrete space behave like sampling from a continuous one, keeping the stochastic nature of the node intact while also keeping the backpropagation step viable.

Let’s explore these operations with examples to gain a better understanding.

Discrete operations in Deep Learning

We use discrete sampling across many areas of deep learning. For example, language models sample sequences of word or character tokens, where each discrete token corresponds to a word or a character. This way, we’re sampling from a discrete space.

A sequence of word tokenizations demonstrating sampling from discrete space | Source

Another popular example is the LSTM recurrent neural network architecture. It has internal gating mechanisms that are used to learn long-term dependencies. Although these gating units are continuous, the operation w.r.t. the whole LSTM cell has some discrete nature.

An LSTM cell | Source

One more popular example of using discrete sampling in deep learning is the seq2seq DNC architecture. Seq2Seq-DNC uses Read / Write (discrete) operations on the external memory to store encoder-decoder states, in order to support long-range dependencies.

Operation structure of DNC memory area | Source

These read/write operations are sampled using another neural network architecture. In a way, this neural network is sampling from a discrete space.

Now let’s take a look at the motivation and purpose behind Gumbel-Softmax.

Understanding Gumbel-Softmax

The problem Gumbel-Softmax addresses is working with discrete data generated from a categorical distribution. Let’s see the inner mechanics behind it.

Gumbel Max trick

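The Gumbel-Max trick is a technique that allows sampling from a categorical distribution during the forward pass of a neural network. It builds on the reparameterization trick, which splits sampling into a deterministic part and an independent noise term; the smooth relaxation that makes the result differentiable comes with Gumbel-Softmax, covered in the next section. Let’s look at how this works.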

Sampling from a categorical distribution by taking the argmax of a combination of class probabilities and Gumbel noise | Source

In this technique, we take the class probabilities and apply the logarithm to each of them. To each of the resulting logits we add Gumbel noise, which can be sampled by taking the negative log of the negative log of a uniform random variable: g = -log(-log(u)), with u drawn from Uniform(0, 1). This step is similar to the one used in the reparameterization trick for continuous variables, where we add normally distributed noise to the mean.

After combining the deterministic and stochastic parts of the sampling process, we use the argmax function to find the class that has the maximum value for each sample. The class or sample is encoded as a one-hot vector for use by the rest of the neural network.
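
The steps above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the code used later in the article: the function name `gumbel_max_sample` and the example probabilities are assumptions made for the demo.

```python
import torch
import torch.nn.functional as F


def gumbel_max_sample(probs: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Draw class indices from a categorical distribution with the Gumbel-Max trick."""
    logits = torch.log(probs)                            # deterministic part
    u = torch.rand(num_samples, probs.numel())           # u ~ Uniform(0, 1)
    gumbel = -torch.log(-torch.log(u + 1e-20) + 1e-20)   # stochastic part
    return torch.argmax(logits + gumbel, dim=-1)         # winning class per sample


torch.manual_seed(0)
probs = torch.tensor([0.1, 0.6, 0.3])  # hypothetical class probabilities
samples = gumbel_max_sample(probs, num_samples=20000)

# Empirical class frequencies should be close to the class probabilities
freqs = torch.bincount(samples, minlength=3).float() / samples.numel()

# One-hot encoding of the first few samples, as fed to the rest of the network
one_hot = F.one_hot(samples[:4], num_classes=3)
print(freqs)
print(one_hot)
```

With enough samples, `freqs` approaches `probs`, which is a quick way to convince yourself that adding Gumbel noise and taking the argmax really samples from the categorical distribution.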

Now we have a way of sampling from a categorical distribution, as opposed to a continuous distribution. However, we still can’t backpropagate through argmax: its output is piecewise constant, so its gradient is 0 almost everywhere and undefined where the maximum switches class, which gives backpropagation no useful signal.

The paper [3] proposed a technique that replaces argmax with softmax. Let’s look into this.

Gumbel Softmax

Replacing argmax with softmax because softmax is differentiable (required by backpropagation) | Source: Author

In this approach, we still combine the log probabilities with Gumbel noise, but now we take the softmax over the samples instead of the argmax. 

Lambda (λ) is the softmax temperature parameter, which controls how closely the Gumbel-Softmax distribution approximates the categorical distribution. If λ is very small, the samples get very close to one-hot (quantized) categorical samples; conversely, as λ increases, the Gumbel-Softmax samples become more uniform across classes.
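
To see the effect of the temperature, the sketch below applies the softmax to the same noisy logits at three temperatures. The probabilities, the temperature values, and the helper `relaxed_sample` are illustrative assumptions, not part of the article’s implementation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.6, 0.3])  # hypothetical class probabilities
logits = torch.log(probs)

# Reuse a single Gumbel noise draw so that only the temperature changes
u = torch.rand(3)
gumbel = -torch.log(-torch.log(u + 1e-20) + 1e-20)


def relaxed_sample(temperature: float) -> torch.Tensor:
    """Gumbel-Softmax sample: temperature-scaled softmax over the noisy logits."""
    return F.softmax((logits + gumbel) / temperature, dim=-1)


cold = relaxed_sample(0.1)    # small temperature: close to one-hot
warm = relaxed_sample(1.0)    # moderate temperature
hot = relaxed_sample(100.0)   # large temperature: close to uniform
print(cold, warm, hot, sep="\n")
```

PyTorch also ships a built-in `torch.nn.functional.gumbel_softmax(logits, tau, hard)` that performs the noise sampling and the softmax in one call; the hand-written helpers in the next section make the individual steps explicit.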

Implementation of Gumbel Softmax

In this section, we’ll train a Variational Auto-Encoder on the MNIST dataset to reconstruct images. We’ll apply Gumbel-softmax in sampling from the encoder states. Let’s code!

Note: We’ll use PyTorch as our framework of choice for this implementation.



First, let’s import the required dependencies.

import numpy as np
from typing import Union, Optional, List, Tuple, Text, BinaryIO
import io
import pathlib
import math
irange = range

import torch
import torch.nn.functional as F
from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Logging metadata
import neptune
from neptune.types import File
run = neptune.init_run(
    project='common/pytorch-integration',
    # any further arguments (e.g. an API token) are truncated in the source
)

It’s always handy to define some hyper-parameters early on.

batch_size = 100
epochs = 10
temperature = 1.0
no_cuda = False
seed = 2020
log_interval = 10
hard = False # Nature of Gumbel-softmax

As mentioned earlier, we’ll utilize MNIST for this implementation. Let’s import it.

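is_cuda = not no_cuda and torch.cuda.is_available()

# Seed the random number generators for reproducibility.
# The body of the CUDA branch is missing in the source; seeding with `seed` is assumed.
torch.manual_seed(seed)
if is_cuda:
    torch.cuda.manual_seed(seed)

kwargs = {'num_workers': 1, 'pin_memory': True} if is_cuda else {}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data/MNIST', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data/MNIST', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)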

Now, we’ll define the Gumbel-softmax sampling helper functions.

def sample_gumbel(shape, eps=1e-20):
    # Gumbel(0, 1) noise: g = -log(-log(U)), with U ~ Uniform(0, 1)
    U = torch.rand(shape)
    if is_cuda:
        U = U.cuda()
    return -torch.log(-torch.log(U + eps) + eps)


def gumbel_softmax_sample(logits, temperature):
    # Perturb the logits with Gumbel noise, then apply a temperature-scaled softmax
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)


def gumbel_softmax(logits, temperature, hard=False):
    """
    input: [*, n_class]
    return: flattened to [-1, latent_dim * categorical_dim];
            one-hot per latent variable if hard=True, a soft sample otherwise
    """
    y = gumbel_softmax_sample(logits, temperature)

    if not hard:
        return y.view(-1, latent_dim * categorical_dim)

    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # Straight-through estimator: the forward pass uses y_hard,
    # the backward pass uses the gradients w.r.t. y
    y_hard = (y_hard - y).detach() + y
    return y_hard.view(-1, latent_dim * categorical_dim)
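
The `hard` branch above is a straight-through estimator: the forward pass emits one-hot vectors, while gradients flow through the soft sample. The sketch below checks this behaviour with PyTorch’s built-in `F.gumbel_softmax` rather than the helper above; the tensor shapes and the downstream loss are arbitrary assumptions chosen only to trigger backpropagation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)  # 4 samples, 10 classes (illustrative)

# hard=True: (approximately) one-hot values in the forward pass,
# gradients of the soft sample in the backward pass
y_hard = F.gumbel_softmax(logits, tau=1.0, hard=True)

# Arbitrary downstream loss, used only to run backpropagation
weights = torch.arange(10, dtype=torch.float32)
loss = (y_hard * weights).sum()
loss.backward()

print(y_hard)        # each row has a single entry equal to 1 (up to rounding)
print(logits.grad)   # non-zero gradients despite the discrete forward pass
```

A plain `argmax` followed by one-hot encoding would leave `logits.grad` empty, which is exactly the problem the relaxation is meant to solve.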

Next, let’s define the VAE architecture and loss function.

# Reconstruction + KL divergence losses, summed over all elements and averaged over the batch
def loss_function(recon_x, x, qy):
    # reduction='sum' replaces the deprecated size_average=False
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum') / x.shape[0]

    log_ratio = torch.log(qy * categorical_dim + 1e-20)
    KLD = torch.sum(qy * log_ratio, dim=-1).mean()

    return BCE + KLD

class VAE_gumbel(nn.Module):
    def __init__(self, temp):
        super(VAE_gumbel, self).__init__()

        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, latent_dim * categorical_dim)

        self.fc4 = nn.Linear(latent_dim * categorical_dim, 256)
        self.fc5 = nn.Linear(256, 512)
        self.fc6 = nn.Linear(512, 784)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc2(h1))
        return self.relu(self.fc3(h2))

    def decode(self, z):
        h4 = self.relu(self.fc4(z))
        h5 = self.relu(self.fc5(h4))
        return self.sigmoid(self.fc6(h5))

    def forward(self, x, temp, hard):
        q = self.encode(x.view(-1, 784))
        q_y = q.view(q.size(0), latent_dim, categorical_dim)
        z = gumbel_softmax(q_y, temp, hard)
        return self.decode(z), F.softmax(q_y, dim=-1).reshape(*q.size())
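
The KL term in `loss_function` measures how far each categorical posterior is from a uniform prior over the `categorical_dim` classes, summed over the latent variables. The following sketch compares that expression with `torch.distributions.kl_divergence` on random posteriors; the batch size and the random inputs are assumptions for the demo, and the uniform prior is inferred from the `qy * categorical_dim` ratio in the code.

```python
import torch
from torch.distributions import Categorical, kl_divergence

torch.manual_seed(0)
batch, latent_dim, categorical_dim = 5, 30, 10  # illustrative sizes
qy = torch.softmax(torch.randn(batch, latent_dim, categorical_dim), dim=-1)

# KL term as written in loss_function, applied to the flattened posteriors
flat = qy.reshape(batch, -1)
kld_manual = torch.sum(flat * torch.log(flat * categorical_dim + 1e-20), dim=-1).mean()

# Reference: KL(q || uniform) per latent variable, summed over the latent variables
uniform = Categorical(probs=torch.full((categorical_dim,), 1.0 / categorical_dim))
kld_reference = kl_divergence(Categorical(probs=qy), uniform).sum(dim=-1).mean()

print(kld_manual.item(), kld_reference.item())
```

The two values agree up to floating-point error, since log(q * K) = log(q) - log(1/K).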

Time for some more hyper-parameters.

latent_dim = 30
categorical_dim = 10  # one-of-K vector

temp_min = 0.5

ANNEAL_RATE = 0.00003

model = VAE_gumbel(temperature)

if is_cuda:
    model.cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

We’ll define separate functions for training and testing.

In the testing function, we’ll reconstruct images from the held-out test set to check how well the sampling and the model work on unseen data.

def train(epoch):
    model.train()
    train_loss = 0
    temp = temperature
    for batch_idx, (data, _) in enumerate(train_loader):
        if is_cuda:
            data = data.cuda()
        optimizer.zero_grad()
        recon_batch, qy = model(data, temp, hard)
        loss = loss_function(recon_batch, data, qy)
        loss.backward()
        train_loss += loss.item() * len(data)
        optimizer.step()
        # Anneal the temperature every 100 batches, down to temp_min
        if batch_idx % 100 == 1:
            temp = np.maximum(temp * np.exp(-ANNEAL_RATE * batch_idx), temp_min)

        # Upload the reconstructions of the first batch to Neptune
        if batch_idx == 0:
            reconstructed_image = recon_batch.view(batch_size, 1, 28, 28)
            grid_array = get_grid(reconstructed_image)

            run["train_reconstructed_images/{}".format('training_reconstruction_' + str(epoch))].upload(File.as_image(grid_array))
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item()))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
    run['metrics/avg_train_loss'].log(train_loss / len(train_loader.dataset))


def test(epoch):
    model.eval()
    test_loss = 0
    temp = temperature
    with torch.no_grad():  # no gradients are needed for evaluation
        for i, (data, _) in enumerate(test_loader):
            if is_cuda:
                data = data.cuda()
            recon_batch, qy = model(data, temp, hard)
            test_loss += loss_function(recon_batch, data, qy).item() * len(data)
            if i % 100 == 1:
                temp = np.maximum(temp * np.exp(-ANNEAL_RATE * i), temp_min)
            if i == 0:
                n = min(data.size(0), 8)
                # torch.cat is assumed here; the call is garbled in the source
                comparison = torch.cat([data[:n],
                                        recon_batch.view(batch_size, 1, 28, 28)[:n]])
                grid_array = get_grid(comparison)

                run["test_reconstructed_images/{}".format('test_reconstruction_' + str(epoch))].upload(File.as_image(grid_array))

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

Note: Please find the utility functions used in the code excerpt above in the notebook here.
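
Note that `train` resets `temp` to `temperature` at the start of every epoch, so the temperature only decays within an epoch. A common variant ties the decay to a global step counter instead, so the temperature keeps falling across epochs until it reaches `temp_min`. The sketch below illustrates that variant under stated assumptions: the helper `annealed_temperature` and the `steps_per_epoch` value are hypothetical and not part of the article’s code.

```python
import math

temperature = 1.0      # initial temperature, as in the hyper-parameters above
temp_min = 0.5         # lower bound on the temperature
ANNEAL_RATE = 0.00003  # decay rate per training step


def annealed_temperature(global_step: int) -> float:
    """Exponentially decay the temperature with the global step, floored at temp_min."""
    return max(temperature * math.exp(-ANNEAL_RATE * global_step), temp_min)


steps_per_epoch = 600  # 60,000 MNIST training images / batch size of 100

# Temperature at the start of epochs 0, 10, 20, 30, 40 and 50
schedule = [annealed_temperature(epoch * steps_per_epoch) for epoch in range(0, 51, 10)]
print(schedule)
```

With these values the temperature reaches its floor of 0.5 somewhere between epochs 30 and 40, which is far longer than the 10 epochs used here, so whether such a schedule matters depends on how long you train.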

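Finally, we’ll define the training loop that runs the individual functions together.

for epoch in range(1, epochs + 1):
    # The loop body is missing in the source; calling train and test per epoch is assumed
    train(epoch)
    test(epoch)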

At the end of successful execution, you will get a reconstructed image of the MNIST samples like this:

MNIST samples

The contrast between the reconstructed part and the original part tells us how well the sampling from Gumbel-Softmax worked. We can see the training and test loss convergence in the plot below:

You can access the complete experiment bundled with reconstructed images here and the code above here

You’ve reached the end!

The Gumbel-Softmax trick can prove super useful in discrete sampling tasks, which used to be handled in other ways. For example, NLP tasks are almost necessarily discrete – like the sampling of words, characters, or phonemes.

Future prospects

The Gumbel-Softmax paper also mentioned its usefulness in Variational Autoencoders, but it’s certainly not limited to that.

You can apply the same technique to Binary Autoencoders and other complex neural networks like Generative Adversarial Networks (GANs), wherever a model needs to sample discrete values while staying trainable with backpropagation.

That’s it for now, stay tuned for more! Adios!


