MLOps Blog

Visualizing Machine Learning Models: How to Guide and Tools

14 min
Abhishek Jha
11th May, 2023

Why do we need to visualize Machine Learning models?

“If you refuse to trust decision-making to something whose process you don’t entirely understand, then why even hire people to work? No one knows how the human brain (with its hundred billion neurons!) makes decisions.” – Cassie Kozyrkov

This quote has been used by some people to criticize the recent push for explainable AI. It sounds like a valid point at first, right? But it fails to consider that we don’t want to replicate the human mind. We want to build something better.

Machine learning models are being trained with terabytes of data with the goal of increasing efficiency while making good decisions at the same time, something that humans do pretty well.

The responsibility that we’re putting on ML models means that we need to be able to make them as transparent as possible because otherwise, we can’t trust them.

To do so, we need to visualize ML models. To understand this, let’s get into the 5 W’s of visualization: Why, Who, What, When, and Where.

The 5 W’s of model visualization in Machine Learning

1. Why do we want to visualize models?

Although we have discussed this a little already in the overview, let’s try to get into the specifics.


We need to understand the model’s decision-making process. The extent of this problem becomes especially clear in the case of Neural Networks.

Real-world Neural Network models have millions of parameters and extreme internal complexity, as they use many non-linear transformations during training. Visualizing such complicated models would help us build trust in self-driving cars, medical imaging models that help doctors diagnose, or satellite imagery models, which can be crucial in relief planning or security efforts.

Debugging & improvements

Building machine learning models is an iterative process full of experimentation. Finding the optimal combination of hyperparameters can be quite challenging. Visualization can accelerate this process.

In turn, this can speed up the whole development process even if the model runs into some problems along the way.

Comparison & selection

The act of choosing the best model out of an ensemble of well-performing models can be simply reduced to visualizing parts of the model which offer the highest accuracy or lowest loss while ensuring the model doesn’t overfit.

Frameworks can be designed to compare different snapshots of a single model as it trains over time, i.e., comparing a model after n1 epochs and the same model after n2 epochs of training time.

Might be useful

When debugging, comparing, and visualizing models, it is really useful to use an experiment tracker.

Media intelligence company Hypefactors is using Neptune for that.

“We use Neptune for most of our tracking tasks, from experiment tracking to uploading the artifacts. A very useful part of tracking was monitoring the metrics, now we could easily see and compare those F-scores and other metrics.” – Andrea Duque, Data Scientist at Hypefactors

Teaching concepts

Teaching is perhaps where visualization is most useful: it helps novice users grasp the fundamental concepts of machine learning.

Interactive platforms can be designed where users play around with multiple datasets and toggle parameters to observe the effects on the model’s intermediate states and outputs. This could seriously help to build intuition about how models work.

2. Who should use visualization?

Data Scientists / Machine Learning Engineers

People who mainly focus on developing, experimenting with, and deploying models are the ones who will benefit the most from visualization.

Well-known tools that many practitioners already use include TensorBoard, DeepEyes, and Blocks. These tools give users extended control over things like hyperparameter tuning and pruning unnecessary layers, which helps their models achieve better performance.

Model users

Visualization can also help other stakeholders, who may have some technical background but mainly consume the model’s services via an API.

One example is ActiVis, a visual analytics system developed by Facebook for their own engineers to explore in-house deployed neural networks.

Such visualization tools are really useful for those who only wish to use pretrained models to get predictions for their own tasks.

Novice users

In the “Why” section, I mentioned how visualization could help new students to learn what machine learning is—this point stands true here as well.

This group can also be extended further to include inquisitive consumers who hesitate to use ML-powered applications due to fear of privacy invasion.

A couple of web-based JavaScript frameworks like ConvNetJS & TensorFlow.js have enabled developers to create highly interactive explorable explanations for models.

3. What can we visualize?

Model architecture

The first and main thing you can visualize is the model architecture. It tells us how many layers there are, in what order they’re positioned, and more.

This part also includes the computational graph that defines how a model would train, test, save to the disk and checkpoint after epoch iterations.

All of this can help developers get a better grasp of what’s going on inside their model.

Learned parameters

Parameters that are adjusted during the backpropagation phase of training fall under this category.

Visualizing weights and biases could be worthwhile to understand what the model has learned. Similarly, in Convolutional Neural Networks, we can take a look at the learned filters to see what kind of image features the model has learned.

Model metrics

Summary statistics such as loss, accuracy, and other measures of error, which are computed every epoch, can be represented as a time series over the model’s training course.

Representing a model through a set of numbers may seem a bit abstract. However, these numbers help keep track of the model’s progress while training.

Not only do these metrics describe a single model’s performance, but they’re also crucial for comparing multiple models at once.
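As a rough sketch of how per-epoch metrics support model comparison, the snippet below keeps each model’s validation loss as a plain list and finds the best epoch per model. The model names and loss values are made up for illustration, and `best_epoch` and `summarize` are hypothetical helpers, not part of any library.

```python
# Hypothetical per-epoch validation losses for two models (illustrative numbers).
histories = {
    "model_a": [0.92, 0.61, 0.48, 0.45, 0.47],
    "model_b": [0.88, 0.70, 0.55, 0.43, 0.40],
}


def best_epoch(losses):
    """Return (epoch, loss) for the lowest loss; epochs are 1-indexed."""
    index = min(range(len(losses)), key=lambda i: losses[i])
    return index + 1, losses[index]


def summarize(histories):
    """Map each model name to its best (epoch, loss) pair."""
    return {name: best_epoch(losses) for name, losses in histories.items()}


summary = summarize(histories)
# The model whose best epoch has the lowest validation loss.
best_model = min(summary, key=lambda name: summary[name][1])

for name, (epoch, loss) in summary.items():
    print(f"{name}: best val_loss={loss:.2f} at epoch {epoch}")
print("lowest validation loss:", best_model)
```

In this toy data, the first model’s loss starts rising again after epoch 4, which is the kind of pattern a loss curve makes easy to spot.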

4. When is visualization most relevant?

During training

Using visualization while training is a good way to monitor and track model performance. There are many tools built precisely for this (Weights & Biases, for example); we’ll discuss them in a moment.

For example, Deep View visualizes models using its own metrics for monitoring such as discriminability and density metrics, which help detect overfitting by simply observing the neuron density early in the training phase.

Another tool, DeepEyes, identifies stable and unstable layers and neurons so users can prune their models to speed up training.

After training

There are techniques like attribution visualization, to regenerate an image with important regions highlighted, and feature visualization, to generate an entirely new image that supposedly is representative of the same class. They’re usually performed in the computer vision domain after training the model.
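To make the idea of attribution a bit more concrete, here is a minimal sketch of one simple attribution technique, occlusion sensitivity: hide one pixel at a time and record how much the model’s score drops. The “model” here is a hypothetical stand-in function (`toy_score`) rather than a trained network, and the 4x4 image is synthetic.

```python
def toy_score(image):
    """Stand-in for a trained model: responds only to the centre 2x2 patch."""
    return sum(image[r][c] for r in (1, 2) for c in (1, 2))


def occlusion_map(image, score_fn, fill=0.0):
    """Score drop when each pixel is replaced by `fill` (larger = more important)."""
    base = score_fn(image)
    heat = []
    for r, row in enumerate(image):
        heat_row = []
        for c, _ in enumerate(row):
            occluded = [list(x) for x in image]  # copy so the input stays intact
            occluded[r][c] = fill
            heat_row.append(base - score_fn(occluded))
        heat.append(heat_row)
    return heat


image = [[1.0] * 4 for _ in range(4)]
heat = occlusion_map(image, toy_score)
for row in heat:
    print(" ".join(f"{value:.1f}" for value in row))
```

The resulting grid is non-zero only where the stand-in model actually looks, which is what an attribution heatmap overlaid on an image is meant to convey.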

Some tools, like Embedding Projector, specialize in visualizing 2D and 3D embeddings produced by trained neural networks.

Similarly, as discussed earlier, tools like ActiVis, RNNVis, LSTMVis are also used after training to visualize and even compare different models.

5. Where is visualization applied?

Application domains & models

Visualization has been used a lot in domains like autonomous driving, urban planning, and medical imaging to increase user trust in the models.

Visual analytics systems are being developed to understand more about harder kinds of networks like GANs, which have only been around for a couple of years but have produced remarkable results for data generation.

Examples include DGMTracker and GANViz, which focus on understanding the training dynamics of GANs to help model developers better train these complex models.

Research & development

Combining visualization with research has led to the creation of tools and frameworks for model interpretability and democratization. Another consequence of this rapidly developing area is that new work is immediately publicized and open-sourced without waiting for it to be “officially” published at some conference.

For example, the most popular libraries for implementing neural networks are open source and have consistent contributions for improving all areas of the codebase.

So far, we’ve covered the theoretical aspects of visualization. Now let’s take a look at the most important question.

How can we visualize models?

When we talk about visualizing models, we’re really talking about drawing up a picture of the key components that allow models to learn and draw inference. We can get a good look inside if we visualize:

1. Model architecture

The design of a model gives a pretty good idea of how data flows through it. Visualizing it helps keep track of what manipulations are being applied at what stage.

One popular way to do it, particularly in neural networks, is with a node-link diagram where neurons are shown as nodes and edge weights as links. This method is also becoming the standard due to the increasing popularity of TensorBoard.

Apart from this, certain machine learning algorithms have inbuilt provisions if you want to peek inside. We will take a look at examples of this in the next section.

2. Model training

Monitoring a number of metrics computed epoch after epoch (like loss and accuracy) helps keep track of model progression during the training phase.

This can be implemented by treating the metrics as time series and plotting them in line charts. You don’t need external assistance for this step.
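A minimal sketch of that do-it-yourself approach with matplotlib could look like the following. The metric values are invented for illustration, `plot_history` is a hypothetical helper, and the figure is written to the system’s temporary directory so the script also runs without a display.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend so the script also runs without a display
import matplotlib.pyplot as plt

# Hypothetical per-epoch metrics (illustrative numbers, not real results).
history = {
    "loss": [0.92, 0.61, 0.48, 0.41, 0.37],
    "val_loss": [0.95, 0.70, 0.60, 0.58, 0.59],
}


def plot_history(history, path):
    """Draw one line per metric against the epoch number and save the figure."""
    fig, ax = plt.subplots(figsize=(6, 4))
    for name, values in history.items():
        epochs = range(1, len(values) + 1)
        ax.plot(epochs, values, marker="o", label=name)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    fig.savefig(path)
    plt.close(fig)
    return path


out_path = plot_history(
    history, os.path.join(tempfile.gettempdir(), "training_curves.png")
)
print("saved to", out_path)
```

If you log with Keras, the `history.history` dictionary returned by `model.fit` has the same shape as the `history` dictionary assumed here.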

Another way is to use sophisticated tools like TensorBoard, designed specifically for this purpose. The benefit of going with such a framework is that it is very flexible and interactive, and it saves a whole lot of your time.

3. Model inference

Inference is the process of drawing conclusions out of a trained model. Visualizing the results helps in interpreting and retracing how the model generates its estimates. A few ways to do this are:

  • Visualizing instance-level observation, where intensive analysis and scrutiny are placed on a single data instance’s transformation process throughout the network and ultimately its final output.
  • This can further be extended to identify and analyze misclassified instances with the help of confusion matrices or heatmaps. This would allow us to understand when a specific instance can fail and how it fails.
  • Visualization at this stage is a good way to do interactive experimentation: change the input data or hyperparameters and see how the outcome is affected. TensorFlow Playground is a great example of this.
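The misclassification analysis mentioned in the list above can be sketched in a few lines of plain Python. The labels and predictions below are invented for illustration, and `confusion_matrix` and `misclassified` are hypothetical helpers written for this example (scikit-learn ships its own, more complete versions).

```python
def confusion_matrix(y_true, y_pred, labels):
    """Rows are true classes, columns are predicted classes."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


def misclassified(y_true, y_pred):
    """Indices of the instances the model got wrong."""
    return [i for i, (t, p) in enumerate(zip(y_true, y_pred)) if t != p]


labels = ["cat", "dog"]
y_true = ["cat", "cat", "dog", "dog", "dog"]
y_pred = ["cat", "dog", "dog", "dog", "cat"]

matrix = confusion_matrix(y_true, y_pred, labels)
errors = misclassified(y_true, y_pred)

for label, row in zip(labels, matrix):
    print(label, row)
print("misclassified instances:", errors)
```

The matrix shows which classes get confused with each other, while the index list points at the individual instances worth inspecting one by one.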

So far, we’ve looked at all the prerequisites needed to enter the world of visualization. Now, it’s time to expand the ‘How’ part a little into practicality and check out a few tools for the job.

Tools and frameworks for model visualization

Model architecture visualization

1. sklearn/dTreeViz

Decision trees are easily interpretable models, thanks to their tree-like structure. You can simply examine the conditions on the branches and trace the flow as predictions come out of the model.

There are several ways to visualize a decision tree. Let’s begin with the one offered by sklearn itself.

Making the required imports.

import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris

As you can see, we’re going to use the famous iris dataset for this example.

The next step is to define and fit the tree to the data.

iris = load_iris()
X, y = iris.data, iris.target  # feature matrix and class labels
clf = tree.DecisionTreeClassifier(max_depth=4)
clf = clf.fit(X, y)

Now let’s plot the fitted tree.

tree.plot_tree(clf, filled=True, fontsize=10)

This is how it comes out of the other end.

Decision tree - model visualization
  • We have 4 features in our dataset: sepal length, sepal width, petal length, and petal width, in that order. The root node splits the entire population based on petal length.
  • This results in a leaf node with classified samples, while the remaining samples get split again on petal width because the associated Gini impurity index is still high.
  • This cycle goes on until we achieve homogeneous nodes with a low Gini impurity index, or until max_depth is reached.
  • To sum it up, we get a pretty decent idea of the architecture of our classification decision tree model.
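The Gini impurity shown in each node of the plot is straightforward to reproduce by hand. The sketch below assumes the standard definition (one minus the sum of squared class proportions); `gini_impurity` is a helper written for this example, and the `mixed` class counts are made up, while the root counts reflect the 50 samples per class in the iris dataset.

```python
def gini_impurity(class_counts):
    """Gini impurity of a node, given the number of samples per class."""
    total = sum(class_counts)
    if total == 0:
        return 0.0
    return 1.0 - sum((count / total) ** 2 for count in class_counts)


root = gini_impurity([50, 50, 50])  # iris root node: all three classes balanced
pure = gini_impurity([50, 0, 0])    # a leaf that holds a single class
mixed = gini_impurity([0, 49, 5])   # hypothetical node that is mostly one class

print(f"root={root:.3f} pure={pure:.3f} mixed={mixed:.3f}")
```

A value of zero means the node is homogeneous, which is why the tree stops splitting there.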

Another way to visualize decision trees is to use the dTreeViz library. Not only does it work for scikit-learn trees, but it also supports XGBoost, Spark MLlib, and LightGBM trees.

Let’s see how it differs from sklearn’s functions.

First, the library needs to be installed via pip or conda; the project’s documentation has the installation instructions.

Once done with the installation, time to make the required imports.

from sklearn import tree
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
from dtreeviz.trees import *

Defining, fitting, and plotting the tree.

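classifier = tree.DecisionTreeClassifier(max_depth=4)
iris = load_iris()
classifier.fit(iris.data, iris.target)  # the tree must be fitted before plotting

# dtreeviz 1.x-style call; target_name is a label assumed for this example
viz = dtreeviz(classifier, iris.data, iris.target,
               target_name="species",
               feature_names=iris.feature_names,
               class_names=["setosa", "versicolor", "virginica"])
viz.view()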


This is what we get.

Decision tree - model visualization
  • The obtained plot conveys a very similar meaning as the sklearn one, however it’s much more descriptive with histograms for every decision node.
  • You can switch off these plots by setting the parameter fancy=False.
  • Similarly, you can also visualize regression trees, feature-target space heatmaps, and decision boundaries with the help of this library.

2. ANN Visualizer

If you’re working on a Neural Network model which you need to visualize, this could be one of the ways to go. Let’s take a look at how we can use it to our advantage.

Similar to dTreeViz, this library also relies on graphviz, which needs to be installed first. The installation instructions are in the project’s documentation.

Making the required imports.

import keras
from keras.models import Sequential
from keras.layers import Dense
from ann_visualizer.visualize import ann_viz

Now let’s define our neural network.

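network = Sequential()
network.add(Dense(units=6, activation='relu',
                  kernel_initializer='uniform', input_dim=7))
# kernel_initializer on the next two layers is assumed to match the first layer
network.add(Dense(units=4, activation='relu',
                  kernel_initializer='uniform'))
network.add(Dense(units=1, activation='sigmoid',
                  kernel_initializer='uniform'))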

Plotting the network.

ann_viz(network, view=True, title='Example ANN')

This comes as the output.

Neural network - model visualization
  • This gives a pretty good overview of the architecture of our defined neural network model.
  • We can tally the number of neurons in each layer with the code and see that it came out just as we wanted.
  • The only drawback to this library is that it only works with Keras.
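When tallying the diagram against the code, it can also help to count the trainable parameters by hand. The sketch below assumes plain fully connected layers with a bias term, as in the network defined above (7 inputs, then 6, 4, and 1 units); `dense_param_counts` is a helper written for this example.

```python
def dense_param_counts(input_dim, layer_units):
    """Trainable parameters per dense layer: weights (fan_in * units) plus biases."""
    counts = []
    fan_in = input_dim
    for units in layer_units:
        counts.append(fan_in * units + units)
        fan_in = units  # this layer's output feeds the next layer
    return counts


counts = dense_param_counts(input_dim=7, layer_units=[6, 4, 1])
total = sum(counts)

print("parameters per layer:", counts)
print("total parameters:", total)
```

These numbers should match what Keras reports from `network.summary()` for the same layer sizes.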

We just saw an example of how to visualize the architecture of an Artificial Neural Network, but this is not all that this library can do. We can also use this library for visualizing a Convolutional Neural Network. Let’s see how.

First, as usual, let’s define the CNN.

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

def build_cnn_model():
    model = Sequential()

    model.add(Conv2D(32, (3, 3), padding="same",
                     input_shape=(32, 32, 3), activation="relu"))
    model.add(Conv2D(64, (3, 3), padding="same", activation="relu"))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    # Flatten is assumed here so the dense layers receive a 1D vector
    model.add(Flatten())
    model.add(Dense(512, activation="relu"))
    model.add(Dense(10, activation="softmax"))

    return model

For visualization purposes, we keep the network small in size. Now let’s plot it and see what we get.

CNN - model visualization
  • This sure paints a pretty good picture of the design of our CNN, with all the layers described pretty well.
  • You only need to modify the network code to get the new visualization.

3. Netron

As described by its creators, Netron is a viewer tool for deep learning and machine learning models which can generate pretty descriptive visualization for the model’s architecture.

It’s a cross-platform tool that works on Mac, Linux, and Windows, and it supports a wide variety of frameworks and formats, like Keras, TensorFlow, PyTorch, Caffe, etc. How can we make use of this tool?

Being an OS-independent tool, it can either be installed on your machine or used through its web app, which is what we’ll be using here.

Let’s visualize the CNN we defined for the last tool. All we need to do is save the model and upload the saved file in .h5 format or any other supported format. This is what we get:

CNN - model visualization
  • At first, it might seem similar to what we got with ANN visualizer, but there’s one big difference between the two—Netron is much more interactive.
  • We can change the orientation of the diagram to horizontal or vertical as per our own convenience.
  • In addition, all the colored nodes are expandable: clicking on one shows its properties, which helps build a better understanding. For example, when we click max_pooling2d, we get this:
CNN - model visualization
  • We can read off a number of properties of the clicked node, such as datatype, stride size, and trainable, which makes Netron a bit better than our previous tool.


4. NN-SVG

This tool is mainly for illustrating Neural Networks (NN) parametrically and exporting those drawings to Scalable Vector Graphics (SVG), hence the name NN-SVG.

The tool can generate figures of three types: 

  • classic Fully-Connected Neural Network (FCNN) figures, 
  • Convolutional Neural Network (CNN) figures, and 
  • Deep Neural Network figures following the style introduced in the AlexNet paper

This tool is hosted, so there’s no need for any installation. The following is an example of simple neural network architecture created with the help of this tool:

Neural network 2 - model visualization
  • We have a number of options here, like:
  1. Edge width proportional to edge weights,
  2. Edge opacity proportional to edge weights,
  3. Edge colour proportional to edge weights,
  4. Layer spacing,
  5. Manipulating architecture and weights.
  • All of these options let you create intuitive illustrations pretty quickly.

As discussed above, we can also create designs of Convolutional Neural Networks with this tool, which can look like the example below:

CNN - model visualization
  • We get a whole range of options in CNN as well, just like we did with Neural Networks.
  • You can manipulate the architecture right there to get a new output. This tool is quite interactive and a really good choice.
  • This tool has been quite popular for creating diagrams of networks for research and publications.

5. TensorBoard

No model visualization tutorial can be complete without TensorFlow’s open-source visualization toolkit, TensorBoard. I saved it for last because it’s huge. We’ll only be discussing its use for the model architecture here, and we’ll pick it up again in the next section.

So, firstly, TensorBoard installation can be done via either of these commands.

pip install tensorboard
conda install -c conda-forge tensorboard

Let’s now load TensorBoard into our notebook via running this command in a cell.

%load_ext tensorboard

Once it’s loaded, we have to create a log directory where TensorBoard will store all its logs and will read from them in order to display various visualizations, after which TensorBoard will have to be reloaded with the changes.

%reload_ext tensorboard

Now let’s make the required imports and define our model. For this exercise, we’ll be fitting our model to the MNIST dataset.

import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard

mnist = tf.keras.datasets.mnist

(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(512, activation='relu'),
  tf.keras.layers.Dense(10, activation='softmax')])


Now we need to create a TensorBoard callback, which is responsible for logging all the events, and then specify our created log directory.

log_folder = 'logs'  # assumed name for the log directory TensorBoard reads from

callbacks = [TensorBoard(log_dir=log_folder, histogram_freq=1,
                         write_graph=True, write_images=True,
                         update_freq='epoch', profile_batch=2)]

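Finally, we compile the model, fit it to our data, and pass in the callbacks so that everything can be visualized later.

# compile settings are assumed; any optimizer and loss suited to integer labels would do
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(X_train, y_train, epochs=5,
          validation_split=0.15, callbacks=callbacks)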

Now let’s load the TensorBoard window right into our Jupyter notebook by running the following command:

%tensorboard --logdir={log_folder}

Next, we’ll see the following output if we navigate to the graphs tab:

Tensorboard - model visualization

So, this is how we can see our model’s architecture with TensorBoard.

We’ve explored tools and frameworks for visualizing our model architecture, now let’s move on to the next part – training visualization.

Model training visualization

1. TensorBoard

Picking up where we left off: in the last section, we fitted our Neural Network to the MNIST dataset and checked out the ‘Graphs’ tab. It turns out that there are a lot of other tabs to explore in TensorBoard.

Let’s begin with ‘Scalars’.

As the image below shows, this tab deals with the plots of loss and accuracy computed epoch after epoch.

Tensorboard - model visualization

The next tab is ‘Images’.

This tab visualizes weights and biases. Each image has a slider that we can adjust to display the parameters at different epochs.

Next, the ‘Distributions’ tab.

It shows the distribution of weights and biases for a certain dense layer over each epoch.

Tensorboard - model visualization

The ‘Histograms’ tab does a similar thing as the ‘Distributions’ tab, just with the help of histograms.

Tensorboard - model visualization

The most interesting tab is the ‘Projector’. It can visualize any kind of vector representation, be it word embeddings or numpy array representations of images.

By default, it uses Principal Component Analysis (PCA) to project the vectors into 3D space, but other dimensionality reduction methods like UMAP or t-SNE are available too. In fact, you can define your own custom method.

With PCA, it looks like this.

Tensorboard - model visualization
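For intuition about what such a PCA projection does, here is a small NumPy sketch that reduces high-dimensional vectors to three coordinates. It is not TensorBoard’s implementation; the random `embeddings` array is a stand-in for real learned embeddings, and `pca_project` is a helper written for this example.

```python
import numpy as np


def pca_project(vectors, n_components=3):
    """Project row vectors onto their top principal components."""
    centered = vectors - vectors.mean(axis=0)
    # Rows of vt are the principal directions, ordered by explained variance.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T


rng = np.random.default_rng(0)
embeddings = rng.normal(size=(100, 16))  # stand-in for learned embeddings
points = pca_project(embeddings, n_components=3)

print(points.shape)
```

Each row of `points` is one embedding placed in 3D space, ready to be drawn as a scatter plot.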

We’ve covered all the tabs. There are still tons of things that we can do with TensorBoard, like visualizing training data, plotting a confusion matrix, or tuning hyperparameters, but they’re beyond the scope of this article.

Let’s check out another tool.

2. Neptune

Neptune is a metadata store that is free to use for personal projects and can be used as an API service. 

While the coding part can be done anywhere, as long as the code is connected to Neptune, constant tracking and logging will keep going on inside Neptune’s UI, thus simplifying project management.

Let’s begin by installing the required stuff to get started:

!pip install -q neptune-client
!pip install -q neptune-contrib

We’ll be using a similar model and MNIST dataset as in the examples above, so we won’t be going over that part again.

To log metrics after every batch and epoch, let’s create a NeptuneLogger callback. This part is similar to the TensorBoard callback that we’ve created in the previous example.

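from tensorflow.keras.callbacks import Callback

class NeptuneLogger(Callback):
    # Assumes `run` is the Neptune run object created in the next step,
    # and the legacy neptune.new API where fields are appended to with .log()
    def on_batch_end(self, batch, logs=None):
        for log_name, log_value in (logs or {}).items():
            run["batch/{}".format(log_name)].log(log_value)

    def on_epoch_end(self, epoch, logs=None):
        for log_name, log_value in (logs or {}).items():
            run["epoch/{}".format(log_name)].log(log_value)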

To connect our code to the Neptune application, we need an API token. To get that API token, you need to sign up with Neptune and create a project. The name of that project is passed as the project parameter, and the corresponding API token as api_token.

Let’s initialize the API now.

import neptune.new as neptune  # legacy neptune-client import path

run = neptune.init(project=YOUR_PROJECT_NAME, api_token=YOUR_API_TOKEN)

Now let’s deal with whatever we want to log.


# log params
run["parameters/epochs"] = EPOCHS
run["parameters/batch_size"] = BATCH_SIZE

# log name and append tag
run["sys/name"] = "metrics"
run["sys/tags"].add("demo")

Great! Now all that’s left to do is pass our NeptuneLogger as a Keras callback.

history = model.fit(x=X_train, y=y_train,
                    epochs=EPOCHS, batch_size=BATCH_SIZE,
                    validation_data=(X_test, y_test),
                    callbacks=[NeptuneLogger()])

Once this last code cell has been executed, we can head over to the Neptune application’s UI to visualize whatever we logged.

Training accuracy/loss plotted against batches.

Training accuracy/loss plotted against epochs.

Validation accuracy/loss plotted against epochs.

Along with simple metrics like loss and accuracy, you can also easily plot other things, like a confusion matrix or an AUC-ROC curve.

Though this was just a demo with limited logging, you can imagine how easy this tool makes visualization of different aspects of a model when dealing with a project that involves consistent retraining and updating.

Compare tools

If you’re running a few experiments on your own and you’re looking for a visualization tool, TensorBoard is a good choice. Neptune is more suitable for researchers that are looking for a sophisticated tool that would allow them to dive deeper into the experimentation process. It also provides team collaboration features.

3. Weights & Biases

Just like Neptune, this tool also helps to track, monitor, and visualize ML models and projects. 

To get started, sign up on their website and then install & login via the following commands:

pip install wandb
wandb login

After entering the API key, you should be all set. Let’s now make the required imports for our Keras model.

import wandb
from wandb.keras import WandbCallback

Let’s initialize wandb and get our project started.

wandb.init(project="model-visualization")  # placeholder project name

Now all we need to do is train the model we’ve been using so far and pass the WandbCallback to log metrics.

model.fit(X_train, y_train, validation_data=(X_test, y_test),
          callbacks=[WandbCallback()], epochs=5)

Now we can head to the UI and see what has been logged.

Wandb - model visualization

This is just a look at what it could be like. However, just like Neptune, this can be extended to plot a number of different things, as well as compare different models based on these logged metrics.

4. TensorWatch

TensorWatch is a debugging and visualization tool for data science offered by Microsoft Research. Most of the currently available tools follow a “what you see is what you log” (WYSIWYL) approach, which represents the results using a number of predefined visualizations. 

This can cause an issue with constantly changing models. TensorWatch solves this by treating everything as a stream. A simple description of how it works is:

  • When something is written to a TensorWatch stream, the values get serialised and sent to a TCP/IP socket as well as to the file you have specified.
  • From the Jupyter Notebook, the previously logged values are loaded from the file, and the TCP/IP socket then listens for any future values.
  • The visualiser then listens to the stream and renders the values as they arrive.
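The stream idea described in the list above can be illustrated with a toy, in-memory sketch. This is not TensorWatch’s API: the `Stream` class and `mapped` function are hypothetical, and the sketch leaves out sockets and files entirely. It only shows how a visualizer can react to values as they arrive instead of reading a fixed log format.

```python
class Stream:
    """Toy event stream: subscribers receive every value as it is written."""

    def __init__(self):
        self._subscribers = []

    def subscribe(self, callback):
        self._subscribers.append(callback)

    def write(self, value):
        for callback in self._subscribers:
            callback(value)


def mapped(source, expr):
    """Derive a new stream by applying `expr` to each event of `source`."""
    derived = Stream()
    source.subscribe(lambda event: derived.write(expr(event)))
    return derived


events = Stream()
# Decide what to extract after the fact, without changing the training loop.
loss_stream = mapped(events, lambda d: d["batch_loss"])

received = []
loss_stream.subscribe(received.append)  # a real visualizer would redraw here

for step, loss in enumerate([0.9, 0.7, 0.5]):
    events.write({"step": step, "batch_loss": loss})

print(received)
```

Because the expression is attached to the stream rather than baked into the logging code, a new view of the same events can be added later without touching the training script.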

The only drawback of using TensorWatch is that currently it only supports the PyTorch framework.

Let’s get started by installing it.

pip install tensorwatch

Next, we need to install a Python package called regim which allows us to take PyTorch model files and run training and test epochs on specified dataset with a small amount of code.

With the regim package, we can use the train dataset to train a model over epochs, maintain few metrics, and make callbacks on events. After each epoch, it can run the trained model so far on the test dataset and maintain metrics on it as well. 

It maintains separate Watchers for train and test cycles, so we can see metrics for each separately. The port parameter specifies offset to the baseline port for its socket.

git clone
cd regim
pip install -e .

Then, run your training script from the folder where regim was installed.


import tensorwatch as tw

train = tw.WatcherClient(port=0)
test = tw.WatcherClient(port=1)

Now let’s plot a couple of metrics like train loss, train accuracy, test loss and test accuracy.

loss_stream = train.create_stream(
    expr='lambda d: (d.metrics.epochf, d.metrics.batch_loss)', event_name='batch')
loss_plot = tw.Visualizer(loss_stream, vis_type='line',
                          xtitle='Epoch', ytitle='Train Loss')

acc_stream = train.create_stream(
    expr='lambda d: (d.metrics.epochf, d.metrics.batch_accuracy)', event_name='batch')
acc_plot = tw.Visualizer(acc_stream, vis_type='line',
                         host=loss_plot, xtitle='Epoch', ytitle='Train Accuracy', yrange=(0,))

test_loss_stream = test.create_stream(
    expr='lambda d: (d.metrics.epochf, d.metrics.batch_loss)', event_name='batch')
test_loss_plot = tw.Visualizer(test_loss_stream, vis_type='line',
                               host=loss_plot, xtitle='Epoch', ytitle='Test Loss', yrange=(0,))

test_acc_stream = test.create_stream(
    expr='lambda d: (d.metrics.epochf, d.metrics.batch_accuracy)', event_name='batch')
test_acc_plot = tw.Visualizer(test_acc_stream, vis_type='line',
                              host=loss_plot, xtitle='Epoch', ytitle='Test Accuracy', yrange=(0,))

This is what we get:

Tensorwatch - model visualization

Similarly, we can also plot average of weight gradients in each layer via this:

grads_stream = train.create_stream(expr='lambda
                                 event_name='batch', throttle=1)

grads_plot = tw.Visualizer(grads_stream, vis_type='line',
                           title="Weight Gradients",
                           ytitle="Abs Mean Gradient", history_len=20)

Below is what the gradients look like. These were just a handful of the things we can do with this tool, and there’s more to explore. As the tool is still under development, we’ll surely see a lot more of it once it starts supporting a wider range of frameworks.

Tensorwatch - model visualization

5. Neural Network Playground

This tool is an honorable mention, as it is mostly useful for teaching non-expert users about the internal mechanics of a neural network.

It’s a TensorFlow-based, open-source tool offered by Google. You can simulate small neural networks right in your browser and observe results as you play around with them.

It looks like this:

Neural Network Playground - model visualization
  • You can toggle the problem type between Classification and Regression and modify nearly every aspect of the network and the problem, from the activation function and the number of hidden layers to whether the dataset is clean or noisy.
  • Once everything has been set, training can be started by just hitting the play button after which you can observe the formed decision boundary (how the model is segregating the different classes).
  • Model training can be quite fun to watch, and gives you intuition behind this black box.

We’ve covered a considerable number of tools for visualizing training. Let’s move on to the last part of the visualization.

Model inference visualization

This mainly deals with interpreting the predictions generated by the model and gathering an idea about how and why they were reached in the first place. 

I have discussed this extensively in another article on this blog. It covers a considerable number of tools that can be a really good addition to your MLOps tool arsenal.

Wrapping up

We covered a lot of ground in this article, starting from finding answers for why we need to visualize models in the first place, to getting hands-on experience with a number of tools that can help us create visualizations. 

I hope that this article helped you understand visualization better and that next time you get stuck in a problem where visualization is the answer, you will make use of the discussed tools and knowledge gained here.

The tools and frameworks discussed here are just the most popular of the entire bunch. With research happening at such a rapid pace, you should always keep an eye out for new tools. 

That’s all for now. Thanks for reading!