
# Gumbel Softmax Loss Function Guide + How to Implement it in PyTorch

Training deep learning models has never been easier: you define the architecture and the loss function, sit back, and monitor. At least, that is how it goes in simple cases. Some architectures come with inherently random components, which make the forward pass stochastic and the model no longer deterministic.

In deterministic models, the output of the model is fully determined by parameter values and initial conditions.

Stochastic models have inherent randomness. The same set of parameter values and initial conditions will lead to an ensemble of different outputs.

This changes how sampling behaves: evaluating a deterministic function always gives the same result, but once a stochastic component adds randomness that guarantee is gone, and the whole sampling process becomes non-deterministic.
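As a minimal sketch of the difference, the toy "layers" below are hypothetical stand-ins (they are not part of the model built later): one returns the same output for the same input and weight, the other adds noise on every call.

```python
import random


def deterministic_layer(x, weight):
    # Output depends only on the input and the parameter value.
    return weight * x


def stochastic_layer(x, weight, rng):
    # Same computation, plus Gaussian noise drawn on every call.
    return weight * x + rng.gauss(0.0, 1.0)


rng = random.Random(0)  # seeded so the run is repeatable
det_outputs = {deterministic_layer(2.0, 0.5) for _ in range(5)}
sto_outputs = {stochastic_layer(2.0, 0.5, rng) for _ in range(5)}

print(len(det_outputs))  # a single distinct value
print(len(sto_outputs))  # several distinct values
```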

You see, the backpropagation algorithm relies on a chain of continuous, differentiable functions running through the layers of the neural network. Yet many neural networks fundamentally rely on discrete operations. Sampling from a discrete space is not the same as sampling from a continuous one, and that's where the Gumbel-Softmax trick comes to the rescue. It lets sampling from a discrete space behave like a continuous operation, keeping the stochastic nature of the node intact while also keeping the backpropagation step viable.

Let’s explore these operations with examples to gain a better understanding.

## Discrete operations in Deep Learning

We use discrete sampling in many areas of deep learning. For example, language models sample sequences of word or character tokens, where each discrete token corresponds to a word or a character. In other words, we're sampling from a discrete space.
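To make this concrete, here is a small sketch of drawing tokens from a categorical distribution. The vocabulary and probabilities are made up for illustration; in a real language model they would come from the network's output layer.

```python
import random

vocab = ["the", "cat", "sat", "mat"]  # hypothetical toy vocabulary
probs = [0.5, 0.2, 0.2, 0.1]          # hypothetical next-token probabilities

rng = random.Random(0)
# Each draw picks exactly one token: a sample from a discrete space.
tokens = rng.choices(vocab, weights=probs, k=10)
print(tokens)
```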

Another popular example is the LSTM recurrent neural network architecture. It has internal gating mechanisms that are used to learn long-term dependencies. Although these gating units are continuous, the operation w.r.t. the whole LSTM cell has some discrete nature.

One more popular example of using discrete sampling in deep learning is the seq2seq DNC architecture. Seq2Seq-DNC uses Read / Write (discrete) operations on the external memory to store encoder-decoder states, in order to support long-range dependencies.

These read/write operations are sampled using another neural network architecture. In a way, this neural network is sampling from a discrete space.

Now let’s take a look at the motivation and purpose behind Gumbel-Softmax.

## Understanding Gumbel-Softmax

The problem Gumbel-Softmax addresses is working with discrete data generated from a categorical distribution. Let’s see the inner mechanics behind it.

### Gumbel Max trick

The Gumbel-Max trick is a technique for sampling from a categorical distribution during the forward pass of a neural network. It follows the same idea as the reparameterization trick: the sample is split into a deterministic part that depends on the parameters and an independent noise part. Let's look at how this works.

We take the class probabilities, apply the logarithm to each to get the logits, and add Gumbel noise to each logit. Gumbel noise can be sampled by applying the negative logarithm twice to a uniform random variable: g = -log(-log(u)), with u drawn from Uniform(0, 1). This step is similar to the reparameterization trick, where normally distributed noise is added to the mean.

After combining the deterministic and stochastic parts of the sampling process, we use the argmax function to find the class that has the maximum value for each sample. The class or sample is encoded as a one-hot vector for use by the rest of the neural network.
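The steps above can be sketched in plain Python. This is only an illustration under assumed names (`gumbel_max_sample` and `one_hot` are hypothetical helpers, and the probabilities are made up); it checks empirically that the argmax of the perturbed logits follows the original class probabilities.

```python
import math
import random


def sample_gumbel(rng, eps=1e-20):
    # g = -log(-log(u)), with u ~ Uniform(0, 1); eps guards against log(0).
    u = rng.random()
    return -math.log(-math.log(u + eps) + eps)


def gumbel_max_sample(probs, rng):
    # Deterministic part (log-probabilities) plus stochastic part (Gumbel noise).
    scores = [math.log(p) + sample_gumbel(rng) for p in probs]
    # argmax over the perturbed scores gives the sampled class index.
    return max(range(len(probs)), key=scores.__getitem__)


def one_hot(index, size):
    return [1.0 if i == index else 0.0 for i in range(size)]


rng = random.Random(0)
probs = [0.1, 0.6, 0.3]  # hypothetical class probabilities
n = 20000
counts = [0] * len(probs)
for _ in range(n):
    counts[gumbel_max_sample(probs, rng)] += 1

freqs = [c / n for c in counts]
print(freqs)  # empirical frequencies, close to probs
print(one_hot(gumbel_max_sample(probs, rng), len(probs)))
```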

Now we have a way of sampling from a categorical distribution rather than a continuous one. However, we still can't backpropagate through argmax: its output is piecewise constant, so its gradient is zero almost everywhere and gives the network no learning signal.
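A quick way to see the problem in PyTorch is to compare `torch.argmax` with a softmax on the same logits. The logits below are arbitrary values chosen for illustration.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([1.0, 2.0, 0.5], requires_grad=True)

# argmax returns an integer index that is not connected to the autograd graph.
index = torch.argmax(logits)
print(index.requires_grad)  # False

# softmax is smooth, so gradients flow back to the logits.
soft = F.softmax(logits, dim=-1)
soft[1].backward()
print(logits.grad)
```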

The paper [3] proposed a technique that replaces argmax with softmax. Let’s look into this.

### Gumbel Softmax

In this approach, we still combine the log probabilities with Gumbel noise, but now we take the softmax over the samples instead of the argmax.
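Written out, following the formulation in the Gumbel-Softmax paper [3], the relaxed sample can be expressed as below. The symbols are notation assumed here: \pi_i are the class probabilities, g_i is the Gumbel noise, y_i is the i-th component of the relaxed sample over k classes, and \lambda is a temperature parameter.

```latex
g_i = -\log(-\log u_i), \qquad u_i \sim \mathrm{Uniform}(0, 1)

y_i = \frac{\exp\big((\log \pi_i + g_i) / \lambda\big)}
           {\sum_{j=1}^{k} \exp\big((\log \pi_j + g_j) / \lambda\big)},
\qquad i = 1, \dots, k
```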

Lambda (λ) is the softmax temperature parameter, which controls how closely the Gumbel-softmax distribution approximates the categorical distribution. If λ is very small, the samples are very close to one-hot categorical samples; conversely, as λ increases, the Gumbel-softmax samples become more uniform.
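The effect of the temperature is easy to check numerically. In this sketch the perturbed scores are fixed, made-up numbers standing in for log-probabilities plus Gumbel noise, and `softmax_with_temperature` is a hypothetical helper.

```python
import math


def softmax_with_temperature(scores, temperature):
    scaled = [s / temperature for s in scores]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical perturbed scores, fixed here for illustration.
scores = [0.2, 1.5, -0.3]
for temperature in (0.1, 1.0, 10.0):
    y = softmax_with_temperature(scores, temperature)
    print(temperature, [round(v, 3) for v in y])
```

At a temperature of 0.1 almost all of the mass lands on the largest score, while at 10.0 the three components are nearly equal.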

### Implementation of Gumbel Softmax

In this section, we’ll train a Variational Auto-Encoder on the MNIST dataset to reconstruct images. We’ll apply Gumbel-softmax in sampling from the encoder states. Let’s code!

Note: We'll use PyTorch as our framework of choice for this implementation.


First, let’s import the required dependencies.

    import numpy as np
    from typing import Union, Optional, List, Tuple, Text, BinaryIO
    import io
    import pathlib
    import math
    irange = range

    import torch
    import torch.nn.functional as F
    from torch import nn, optim
    from torchvision import datasets, transforms
    from torchvision.utils import save_image

    import neptune.new as neptune
    from neptune.new.types import File

    # Start a Neptune run to log metrics and images
    run = neptune.init(project='common/pytorch-integration',
                       api_token='ANONYMOUS')

It’s always handy to define some hyper-parameters early on.

    batch_size = 100
    epochs = 10
    temperature = 1.0
    no_cuda = False
    seed = 2020
    log_interval = 10
    hard = False  # Nature of Gumbel-softmax

As mentioned earlier, we’ll utilize MNIST for this implementation. Let’s import it.

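    is_cuda = not no_cuda and torch.cuda.is_available()
    torch.manual_seed(seed)
    if is_cuda:
        torch.cuda.manual_seed(seed)

    kwargs = {'num_workers': 1, 'pin_memory': True} if is_cuda else {}

    # Training split (downloaded on first use)
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data/MNIST', train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=batch_size, shuffle=True, **kwargs)
    # Test split
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data/MNIST', train=False, transform=transforms.ToTensor()),
        batch_size=batch_size, shuffle=True, **kwargs)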

Now, we’ll define the Gumbel-softmax sampling helper functions.

    def sample_gumbel(shape, eps=1e-20):
        # g = -log(-log(U)), U ~ Uniform(0, 1); eps avoids log(0)
        U = torch.rand(shape)
        if is_cuda:
            U = U.cuda()
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_softmax_sample(logits, temperature):
        # Perturb the logits with Gumbel noise, then relax argmax into softmax
        y = logits + sample_gumbel(logits.size())
        return F.softmax(y / temperature, dim=-1)

    def gumbel_softmax(logits, temperature, hard=False):
        """
        ST-Gumbel-softmax
        input: [*, n_class]
        return: flattened [-1, latent_dim * categorical_dim] tensor
                (one-hot per latent variable when hard=True)
        """
        y = gumbel_softmax_sample(logits, temperature)

        if not hard:
            return y.view(-1, latent_dim * categorical_dim)

        # Straight-through: one-hot in the forward pass, soft gradients in the backward pass
        shape = y.size()
        _, ind = y.max(dim=-1)
        y_hard = torch.zeros_like(y).view(-1, shape[-1])
        y_hard.scatter_(1, ind.view(-1, 1), 1)
        y_hard = y_hard.view(*shape)
        y_hard = (y_hard - y).detach() + y
        return y_hard.view(-1, latent_dim * categorical_dim)
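For comparison, PyTorch also ships a built-in `torch.nn.functional.gumbel_softmax`. The sketch below is not the helper defined above; it uses the built-in on a made-up batch of logits to show that `hard=True` yields one-hot rows while gradients still reach the logits.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)  # hypothetical batch: 4 samples, 10 classes

# Soft samples: each row is a probability vector.
soft = F.gumbel_softmax(logits, tau=1.0, hard=False, dim=-1)
# Hard samples: one-hot rows (up to floating-point rounding), with soft gradients.
hard_sample = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)

print(soft.sum(dim=-1))
print(hard_sample[0])

# Weight each class differently so the gradient through the softmax is not trivially zero.
weights = torch.arange(10, dtype=torch.float32)
(hard_sample * weights).sum().backward()
print(logits.grad.abs().sum() > 0)
```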

Next, let’s define the VAE architecture and loss function.

    # Reconstruction + KL divergence losses summed over all elements and averaged over the batch
    def loss_function(recon_x, x, qy):
        # reduction='sum' replaces the deprecated size_average=False
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum') / x.shape[0]

        # KL divergence between q(y) and a uniform categorical prior
        log_ratio = torch.log(qy * categorical_dim + 1e-20)
        KLD = torch.sum(qy * log_ratio, dim=-1).mean()

        return BCE + KLD

    class VAE_gumbel(nn.Module):
        def __init__(self, temp):
            super(VAE_gumbel, self).__init__()

            # Encoder
            self.fc1 = nn.Linear(784, 512)
            self.fc2 = nn.Linear(512, 256)
            self.fc3 = nn.Linear(256, latent_dim * categorical_dim)

            # Decoder
            self.fc4 = nn.Linear(latent_dim * categorical_dim, 256)
            self.fc5 = nn.Linear(256, 512)
            self.fc6 = nn.Linear(512, 784)

            self.relu = nn.ReLU()
            self.sigmoid = nn.Sigmoid()

        def encode(self, x):
            h1 = self.relu(self.fc1(x))
            h2 = self.relu(self.fc2(h1))
            return self.relu(self.fc3(h2))

        def decode(self, z):
            h4 = self.relu(self.fc4(z))
            h5 = self.relu(self.fc5(h4))
            return self.sigmoid(self.fc6(h5))

        def forward(self, x, temp, hard):
            q = self.encode(x.view(-1, 784))
            # One set of categorical_dim logits per latent variable
            q_y = q.view(q.size(0), latent_dim, categorical_dim)
            z = gumbel_softmax(q_y, temp, hard)
            return self.decode(z), F.softmax(q_y, dim=-1).reshape(*q.size())
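The KL term in `loss_function` measures how far each categorical distribution is from a uniform prior over `categorical_dim` classes. As a sanity check, the plain-Python sketch below (with a hypothetical helper `kl_to_uniform`) evaluates the same expression for a single distribution: it is about 0 for a uniform distribution and about log(K) for a one-hot one.

```python
import math


def kl_to_uniform(q, eps=1e-20):
    # sum_i q_i * log(q_i * K) == KL(q || Uniform(K)), mirroring the article's KLD term
    k = len(q)
    return sum(p * math.log(p * k + eps) for p in q)


k = 10
uniform = [1.0 / k] * k
one_hot = [1.0] + [0.0] * (k - 1)

print(kl_to_uniform(uniform))  # approximately 0
print(kl_to_uniform(one_hot))  # approximately log(10)
```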

Time for some more hyper-parameters.

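    latent_dim = 30
    categorical_dim = 10  # one-of-K vector

    temp_min = 0.5

    ANNEAL_RATE = 0.00003

    model = VAE_gumbel(temperature)

    if is_cuda:
        model.cuda()

    # The excerpt never defines the optimizer that train() uses;
    # Adam with lr=1e-3 is an assumption, not a value taken from the article.
    optimizer = optim.Adam(model.parameters(), lr=1e-3)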

We'll define separate functions for training and testing.

In the testing function, we reconstruct images from the test set to check how well the sampling and the model perform on unseen data.

    def train(epoch):
        model.train()
        train_loss = 0
        temp = temperature
        for batch_idx, (data, _) in enumerate(train_loader):
            if is_cuda:
                data = data.cuda()
            optimizer.zero_grad()  # clear gradients left over from the previous step
            recon_batch, qy = model(data, temp, hard)
            loss = loss_function(recon_batch, data, qy)
            loss.backward()
            train_loss += loss.item() * len(data)
            optimizer.step()
            # Anneal the temperature every 100 batches, never going below temp_min
            if batch_idx % 100 == 1:
                temp = np.maximum(temp * np.exp(-ANNEAL_RATE * batch_idx), temp_min)

            if batch_idx == 0:
                reconstructed_image = recon_batch.view(batch_size, 1, 28, 28)
                grid_array = get_grid(reconstructed_image)  # get_grid is a notebook utility

            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item()))

        print('====> Epoch: {} Average loss: {:.4f}'.format(
            epoch, train_loss / len(train_loader.dataset)))

    def test(epoch):
        model.eval()
        test_loss = 0
        temp = temperature
        with torch.no_grad():  # no gradients are needed for evaluation
            for i, (data, _) in enumerate(test_loader):
                if is_cuda:
                    data = data.cuda()
                recon_batch, qy = model(data, temp, hard)
                test_loss += loss_function(recon_batch, data, qy).item() * len(data)
                if i % 100 == 1:
                    temp = np.maximum(temp * np.exp(-ANNEAL_RATE * i), temp_min)
                if i == 0:
                    n = min(data.size(0), 8)
                    comparison = torch.cat([data[:n],
                                            recon_batch.view(batch_size, 1, 28, 28)[:n]])
                    grid_array = get_grid(comparison)

        test_loss /= len(test_loader.dataset)  # average loss per test image
        print('====> Test set loss: {:.4f}'.format(test_loss))
        run['metrics/avg_test_loss'].log(test_loss)

Note: Please find the utility functions used in the code excerpt above in the notebook here.
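The temperature annealing inside `train` can also be looked at in isolation. The sketch below replays the same update rule in plain Python for one epoch, assuming 600 batches (60,000 MNIST training images at a batch size of 100); `anneal` is a hypothetical helper and the constants mirror the hyper-parameters above.

```python
import math

temperature = 1.0      # starting temperature
temp_min = 0.5         # lower bound on the temperature
anneal_rate = 0.00003  # ANNEAL_RATE in the article


def anneal(temp, step):
    # Exponential decay, clipped at temp_min.
    return max(temp * math.exp(-anneal_rate * step), temp_min)


temp = temperature
schedule = []
for batch_idx in range(600):   # one epoch of 600 batches (assumed)
    if batch_idx % 100 == 1:   # same update condition as in train()
        temp = anneal(temp, batch_idx)
        schedule.append((batch_idx, round(temp, 4)))

print(schedule)
```

Under these assumptions the temperature is updated six times per epoch and stays well above `temp_min`. Because `train` resets `temp` to `temperature` at the start of every epoch, you may want a larger `ANNEAL_RATE` or a temperature that persists across epochs if a lower final temperature is the goal.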

Finally, we'll define the main loop that runs training and testing for every epoch.

    for epoch in range(1, epochs + 1):
        train(epoch)
        test(epoch)

At the end of successful execution, you will get a reconstructed image of the MNIST samples like this:

The contrast between the reconstructed images and the originals tells us how well sampling with Gumbel-softmax worked. We can see the training and test loss convergence in the plot below:

You can access the complete experiment bundled with reconstructed images here and the code above here

## You’ve reached the end!

The Gumbel-Softmax trick can prove super useful in discrete sampling tasks, which used to be handled in other ways. For example, NLP tasks are almost necessarily discrete – like the sampling of words, characters, or phonemes.

## Future prospects

The Gumbel-softmax paper also mentioned its usefulness in Variational Autoencoders, but it’s certainly not limited to that.

You can apply the same technique to Binary Autoencoders and to other complex neural networks like Generative Adversarial Networks (GANs). It seems limitless.

That’s it for now, stay tuned for more! Adios!
