8 Creators and Core Contributors Talk About Their Model Training Libraries From PyTorch Ecosystem

Jakub Czakon

Senior Data Scientist

I started using PyTorch to train my models back in early 2018 with the 0.3.1 release. I got hooked by its Pythonic feel, ease of use, and flexibility.

It was just so much easier to do things in PyTorch than in TensorFlow or Theano. But one thing I missed was a Keras-like high-level interface to PyTorch, and there was not much out there back then.

Fast-forward to 2020, and we have 6 high-level training APIs in the PyTorch Ecosystem.

  • But which one should you choose?
  • What are the pros and cons of using each one?

I thought: who can explain the differences between those libraries better than the authors themselves?

I picked up my proverbial phone and asked them to write an article with me. They all agreed and this is how this post was created!

So I asked the authors to talk about the following aspects of their libraries:

  • Philosophy of the project
  • API structure
  • The learning curve for new users
  • Built-in features (what you get out-of-the-box)
  • Extension capabilities (simplicity of integration in research)
  • Reproducibility
  • Distributed training
  • Productionalization
  • Popularity

… and they really did answer thoroughly 🙂

You can jump to the library that you are interested in or go straight to my subjective comparison at the end.

Benjamin Bossan

Core Contributor

The philosophy behind skorch development can be summarized as follows:

  • follow the sklearn API
  • don’t hide PyTorch
  • don’t reinvent the wheel
  • be hackable

These principles define the design space within which we operate. The scikit-learn API shows up most obviously in how you train and predict:

from skorch import NeuralNetClassifier

net = NeuralNetClassifier(...)
net.fit(X_train, y_train)
net.predict(X_test)

Because skorch uses this simple and well-established API, anyone who knows sklearn should be able to start using it very quickly.

But the sklearn integration goes deeper than calling “fit” and “predict”. You can seamlessly integrate your skorch model within sklearn `Pipeline`s, use sklearn’s numerous metrics (no need to re-implement F1, R², etc.), and use it with GridSearchCV.

When it comes to parameter sweeps, you can use any other hyperparameter search strategy as long as there is a sklearn-compatible implementation of it.

We are especially proud that you can search on almost any hyper-parameter without additional work. For example, if your module has an initialization parameter called num_units, you can grid search that parameter right away.

Here is a list of things you can grid search out-of-the-box:

  • any parameter on your Module (number of units and layers, nonlinearity, dropout rate, …)
  • optimizer (learning rate, momentum…)
  • criterion
  • DataLoader (batch size, shuffling, …)
  • callbacks (any parameter, even on your custom callbacks)

This is how it looks in code:

from sklearn.model_selection import GridSearchCV

params = {
    'lr': [0.01, 0.02],
    'max_epochs': [10, 20],
    'module__num_units': [10, 20],
    'optimizer__momentum': [0.6, 0.9, 0.95],
    'iterator_train__shuffle': [True, False],
    'callbacks__mycallback__someparam': [1, 2, 3],
}

net = NeuralNetClassifier(...)
gs = GridSearchCV(net, params, cv=3, scoring='accuracy')
gs.fit(X, y)

print(gs.best_score_, gs.best_params_)

As far as I’m aware, no other framework provides this flexibility. On top of that, by using the dask parallel backend, you can distribute the hyper-parameter search across your cluster without too much hassle.
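The double-underscore keys in the grid above follow scikit-learn's convention for routing a parameter to a named sub-component, and the grid grows multiplicatively. Here is a stdlib-only sketch of both ideas; `route_params` and `expand_grid` are hypothetical helpers written for illustration, not skorch or sklearn functions. With the grid above there are 2 × 2 × 2 × 3 × 2 × 3 = 144 candidates, so `cv=3` means 432 training runs, which is why distributing the search can pay off.

```python
from itertools import product


def route_params(params):
    """Group flat 'component__param' keys by component (hypothetical helper).

    Keys without '__' belong to the top-level estimator (component '').
    Only the first '__' is split, so 'callbacks__mycallback__someparam'
    is handed to 'callbacks' as 'mycallback__someparam' for further routing.
    """
    routed = {}
    for key, value in params.items():
        component, sep, rest = key.partition('__')
        if not sep:
            component, rest = '', key
        routed.setdefault(component, {})[rest] = value
    return routed


def expand_grid(grid):
    """Yield every combination of a dict of lists, like a parameter grid."""
    keys = list(grid)
    for values in product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


params = {
    'lr': [0.01, 0.02],
    'max_epochs': [10, 20],
    'module__num_units': [10, 20],
    'optimizer__momentum': [0.6, 0.9, 0.95],
    'iterator_train__shuffle': [True, False],
    'callbacks__mycallback__someparam': [1, 2, 3],
}

candidates = list(expand_grid(params))
print(len(candidates))        # 144 parameter combinations
print(len(candidates) * 3)    # 432 fits with cv=3
print(route_params(candidates[0]))
```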

Using the mature sklearn API, skorch users can avoid the boilerplate code that is typically seen when writing train loops, validation loops, and hyper-parameter search in pure PyTorch.

From the PyTorch side, we decided not to hide the backend behind an abstraction layer, as is the case in Keras, for example. Instead, we expose numerous components known from PyTorch. As a user, you can use PyTorch’s Dataset (think torchvision, including TTA), DataLoader, and learning rate schedulers. Most importantly, you can use PyTorch Modules with almost no restrictions.

We thus made a conscious effort to re-use as many existing features from sklearn and PyTorch as possible instead of re-inventing the wheel. This makes skorch easy to use on top of your existing codebase or to remove it after your initial experimentation phase without any lock-in effect.

For instance, you can replace the neural net with any sklearn model or you can extract the PyTorch module and use it without skorch.

On top of re-using existing features, we added some of our own. Most notably, skorch works with many common data types out-of-the-box. In addition to Datasets, you can use:

  • numpy arrays,
  • torch tensors,
  • pandas DataFrames,
  • Python dictionaries holding heterogeneous data,
  • external/custom datasets like ImageFolder from torchvision.

We’ve put extra effort into making these work well with sklearn.

Additionally, we implemented a simple yet powerful callback system, which you can use to adapt most of skorch’s behavior to your liking. Some of the callbacks that we provide are:

  • learning rate schedulers,
  • scoring functions (using custom or sklearn metrics),
  • early stopping,
  • checkpointing,
  • parameter freezing,
  • and TensorBoard and Neptune integration.

If this is not enough to satisfy your customization needs, we took pains to facilitate implementing your own callbacks or your own model trainers. Our documentation contains examples of how to implement custom callbacks and custom trainers, modifying every possible behavior right down to the training step.
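To make the callback idea concrete, here is a framework-free sketch of an early-stopping callback driving a toy training loop. The `Callback`, `EarlyStopping`, and `ToyTrainer` classes and the `on_epoch_end(trainer)` signature are hypothetical stand-ins; skorch ships its own early-stopping callback with its own interface, so consult the skorch docs for the real API.

```python
class Callback:
    """Base class: subclasses override only the hooks they need."""

    def on_epoch_end(self, trainer):
        pass


class EarlyStopping(Callback):
    """Stop when the monitored loss has not improved for `patience` epochs."""

    def __init__(self, patience=3, threshold=1e-4):
        self.patience = patience
        self.threshold = threshold
        self.best = float('inf')
        self.bad_epochs = 0

    def on_epoch_end(self, trainer):
        current = trainer.history[-1]
        if current < self.best - self.threshold:
            self.best = current          # improvement: reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        if self.bad_epochs >= self.patience:
            trainer.stop = True          # ask the trainer to end training


class ToyTrainer:
    """Hypothetical trainer that replays a fixed sequence of validation losses."""

    def __init__(self, losses, callbacks):
        self.losses = losses
        self.callbacks = callbacks
        self.history = []
        self.stop = False

    def fit(self):
        for loss in self.losses:
            self.history.append(loss)    # stand-in for one epoch of train + validation
            for cb in self.callbacks:
                cb.on_epoch_end(self)
            if self.stop:
                break
        return self


trainer = ToyTrainer([0.9, 0.7, 0.6, 0.61, 0.62, 0.63, 0.5],
                     callbacks=[EarlyStopping(patience=3)]).fit()
print(len(trainer.history))  # epochs actually run before stopping
```

Training stops after the sixth epoch, three epochs after the best loss of 0.6, so the final 0.5 is never seen.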

The philosophy of not re-inventing the wheel should make skorch easy to learn for anyone who is familiar with sklearn and PyTorch. And since we designed skorch around customization and flexibility, it shouldn’t be too hard to master. To learn more about skorch, check out these examples and notebooks.

Skorch is geared towards, and used in, production. We addressed some common issues regarding productionalization, specifically:

  • we make sure to be backward compatible and to give a sufficiently long deprecation period where necessary,
  • you can train on GPU and serve on CPU,
  • you can pickle a whole sklearn Pipeline containing the skorch model for later re-use,
  • we provide a helper function that turns your training code into a command line script exposing all your model parameters, including their documentation, as command line arguments, with just three lines of extra code.

That being said, I have implemented, or know people who have implemented, more research-y stuff, like GANs and numerous types of semi-supervised learning techniques. This does require more profound knowledge of skorch, though, so you might have to dig deeper into the docs or ask us for pointers on GitHub.

I personally haven’t come across anyone using skorch with reinforcement learning, but I would like to hear what experience people had with that.

Since our initial release of skorch in the summer of 2017, the project has matured a lot and an active community has grown around it. In a typical week, a handful of issues are opened on GitHub or a question is asked on Stack Overflow. We answer most questions within a day, and if there is a good feature request or bug report, we try to guide the reporter towards implementing it themselves.

This way, we have had more than 20 contributors over the project’s lifetime, with 3 of them being regulars, which means the project’s health is not dependent on a single person.

The big difference between skorch and some other higher-level frameworks, say fastai, is that skorch doesn’t come “batteries-included”. That means it’s up to the user to implement their own modules or to use the modules of one of the many existing collections (say, torchvision). Skorch provides the skeleton, but you have to bring the meat.

When not to use skorch

  • super custom PyTorch code, possibly reinforcement learning
  • backend agnostic code (switch between PyTorch, tensorflow, …)
  • there is no need at all for the sklearn API
  • avoid a very slight performance overhead

When to use skorch

  • gain sklearn API and all associated benefits like hyper-parameter search
  • most PyTorch workflows just work
  • avoid boilerplate, standardize code
  • use some of the many utilities discussed above

Philosophy

The idea behind Catalyst is quite simple:

  • collect all the technical, dev-heavy, Deep Learning stuff in a framework,
  • make it easy to re-use boring day-to-day components,
  • focus on research and hypothesis testing in our projects.

To make that happen we looked at a typical Deep Learning project, which usually has the following structure:

for stage in stages:
    for epoch in epochs:
        for dataloader in dataloaders:
            for batch in dataloader:
                handle(batch)

If you think about it, most of the time all you need to do is specify the handle method for the new model and how batches of data should be fed to that model. Why, then, is so much of our time spent implementing pipelines and debugging training loops rather than developing something new or testing a hypothesis?
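The loop above can be made runnable in a few lines, with `handle` as the only model-specific piece. This is a minimal sketch under our own assumptions (`run_experiment`, the stage names, and the toy loaders are made up for illustration); it is not Catalyst's Runner.

```python
def run_experiment(stages, num_epochs, loaders, handle):
    """Drive the stage -> epoch -> loader -> batch loops and delegate each batch.

    `loaders` maps a loader name (e.g. "train", "valid") to any re-iterable
    of batches; `handle(stage, epoch, loader_name, batch)` is the only
    model-specific code and returns a per-batch metric.
    """
    log = []
    for stage in stages:
        for epoch in range(num_epochs):
            for loader_name, loader in loaders.items():
                for batch in loader:
                    metric = handle(stage, epoch, loader_name, batch)
                    log.append((stage, epoch, loader_name, metric))
    return log


# Toy "model": the metric is just the batch sum, so the loops are easy to check.
def handle(stage, epoch, loader_name, batch):
    return sum(batch)


loaders = {"train": [[1, 2], [3, 4]], "valid": [[5]]}
log = run_experiment(["warmup", "main"], num_epochs=2, loaders=loaders, handle=handle)
print(len(log))  # 2 stages * 2 epochs * 3 batches = 12
```

Swapping in a different `handle` changes the model logic without touching the loops, which is the separation the paragraph above argues for.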

We realized that it is possible to separate the engineering from the research, so that we can invest our time once in a high-quality, reusable engineering backbone and use it across all of our projects.

That is how Catalyst was born: an open-source PyTorch framework that allows you to write compact but full-featured pipelines, abstracts engineering boilerplate away, and lets you focus on the main part of your project.

Our mission at Catalyst.Team is to use our software engineering and deep learning expertise to standardize workflows and enable cross-domain communication between deep learning and reinforcement learning researchers.

We believe that reduced development friction and free flow of ideas will lead to future breakthroughs in DL and such an R&D Ecosystem will help make that happen.

The learning curve

Catalyst can be easily adopted by both DL newcomers and seasoned experts thanks to two APIs:

  • Notebook API, which was developed with a focus on easy experimentation and Jupyter Notebooks usage – to start your path into reproducible DL research.
  • Config API, which mostly focuses on scalability and the command-line interface – to bring the power of DL/RL even to large clusters.

When it comes to PyTorch user experience we really want to keep it as simple as possible:

  • You define your loaders, model, criterion, optimizer, and scheduler as you usually would:
import torch

# data
loaders = {"train": ..., "valid": ...}

# model, criterion, optimizer, scheduler (Net is your own nn.Module)
model = Net()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
  • and then you pass those PyTorch objects to the Catalyst Runner:
from catalyst.dl import SupervisedRunner

# experiment setup
logdir = "./logdir"
num_epochs = 42

# model runner
runner = SupervisedRunner()

# model training
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
    verbose=True,
)

Engineering is clearly decoupled from deep learning, with almost no boilerplate. This is how we feel deep learning code should look.

To get started with both APIs, you can follow our tutorials and pipelines, or, if you don’t want to choose, just check out the most common ones: classification and segmentation.

Design and Architecture

The most interesting part about the Notebook and Config APIs is that they use the same “backend” logic: the Experiment, Runner, State, and Callback abstractions, which are the core features of Catalyst.

  • Experiment: an abstraction that contains information about the experiment – a model, a criterion, an optimizer, a scheduler, and their hyperparameters. It also contains information about the data and transformations used. In general, the Experiment knows what you would like to run.
  • Runner: a class that knows how to run an experiment. It contains all the logic of how to run the experiment, stages (another distinctive feature of Catalyst), epoch and batches.
  • State: intermediate storage between the Experiment and the Runner that saves the current state of the experiment – model, criterion, optimizer, schedulers, metrics, loggers, loaders, etc.
  • Callback: a powerful abstraction that lets you customize your experiment run logic. To give users maximum flexibility and extensibility we allow callback execution anywhere in the training loop:
on_stage_start
    on_epoch_start
        on_loader_start
            on_batch_start
            # ...
            on_batch_end
        on_loader_end
    on_epoch_end
on_stage_end

on_exception

By implementing these methods you can make any additional logic possible.
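A stdlib-only sketch of how a runner might fire these hooks in order, including `on_exception` when something goes wrong. `Runner` and `Recorder` are hypothetical classes that only reuse the hook names listed above; Catalyst's real Runner, State, and Callback classes carry much more information.

```python
class Runner:
    """Hypothetical runner that fires callback hooks around each loop level."""

    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.output = None

    def _fire(self, hook):
        for cb in self.callbacks:
            getattr(cb, hook)(self)

    def run(self, stages, num_epochs, loaders, step):
        try:
            for stage in stages:
                self.stage = stage
                self._fire("on_stage_start")
                for epoch in range(num_epochs):
                    self.epoch = epoch
                    self._fire("on_epoch_start")
                    for name, loader in loaders.items():
                        self.loader_name = name
                        self._fire("on_loader_start")
                        for batch in loader:
                            self._fire("on_batch_start")
                            self.output = step(batch)  # the model-specific work
                            self._fire("on_batch_end")
                        self._fire("on_loader_end")
                    self._fire("on_epoch_end")
                self._fire("on_stage_end")
        except Exception:
            self._fire("on_exception")  # let callbacks react (log, save state, ...)
            raise


class Recorder:
    """Callback that records the name of every hook it receives."""

    def __init__(self):
        self.events = []

    def __getattr__(self, hook):
        # Only reached for names not defined normally, i.e. the on_* hooks.
        if not hook.startswith("on_"):
            raise AttributeError(hook)
        return lambda runner: self.events.append(hook)


rec = Recorder()
Runner([rec]).run(["stage_1"], num_epochs=1, loaders={"train": [1, 2]}, step=lambda b: b * 2)
print(rec.events)
```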

As a result, you can implement any Deep Learning pipeline (and, after the Catalyst.RL 2.0 release, any Reinforcement Learning pipeline) in a few lines of code, combining it from the available primitives (thanks to the community, their number is growing every day).

Everything else (Models, Criterions, Optimizers, Schedulers) consists of pure PyTorch primitives. Catalyst does not create any wrappers or abstractions on top but rather makes it easy to reuse those building blocks between different frameworks and domains.

Extension capabilities / Simplicity of integration in research

Thanks to flexible framework design and Callbacks-mechanism, Catalyst is easily extendable for a large number of DL-based projects. You can check out our Catalyst-powered repositories on awesome-catalyst-list.

If you are interested in Reinforcement Learning, there is also a large number of RL-based repos and competition solutions. To compare Catalyst.RL with other RL frameworks, you can check out the Open Source RL list.

Other built-in features (what you get out of the box)

Knowing that you can extend it easily gives comfort but there are a ton of features that you get out-of-the-box. Some of them include:

  • Thanks to its flexible callback system, Catalyst easily integrates common Deep Learning best practices such as gradient accumulation, gradient clipping, weight decay correction, saving the top-K best checkpoints, TensorBoard integration, and many other useful day-to-day deep learning utilities.
  • Thanks to our contributors and contrib modules, Catalyst has access to all recent SOTA features, like AdamW, OneCycle, SWA, Ranger, LookAhead, and many other research developments.
  • Moreover, we integrate out of the box with popular libraries such as NVIDIA Apex, Albumentations, SMP, transformers, wandb, and neptune.ai to make your research more user-friendly. Thanks to these integrations, Catalyst has full support for test-time augmentations, mixed precision, and distributed training.
  • For industry needs, we also have framework-wide support for PyTorch tracing, which makes putting models into production easier. Furthermore, we publish predefined Catalyst-based Docker images with each release for easier integration.
  • Finally, we support additional solutions for both model serving – ReAction (industry-oriented) and experiments monitoring – Alchemy (research-oriented).

Everything is integrated into the library and covered by CI tests (we have a dedicated GPU server for that). And thanks to Catalyst scripts, you can schedule a large number of experiments and run them in parallel over all available GPUs from the command line (check catalyst-parallel-run for more info).
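Gradient accumulation and clipping are simple to state but easy to get subtly wrong, so here is a framework-agnostic sketch in plain Python. It uses a one-parameter linear model so the numbers can be checked by hand; all names are our own, and this is not Catalyst's callback code. With equal-sized micro-batches, scaling each micro-batch gradient by 1/`accum_steps` makes two micro-batches of two samples produce the same update as one batch of four.

```python
def grad(w, batch):
    """d/dw of the mean squared error of y_hat = w * x over one batch."""
    return sum(2 * x * (w * x - y) for x, y in batch) / len(batch)


def clip_by_norm(g, max_norm):
    """Rescale a (scalar) gradient so its norm |g| is at most max_norm."""
    norm = abs(g)
    return g * (max_norm / norm) if norm > max_norm else g


def train_one_pass(data, w=0.0, lr=0.01, micro_batch=2, accum_steps=2, max_norm=None):
    """One pass of SGD with gradient accumulation and optional clipping.

    Gradients of `accum_steps` consecutive micro-batches are averaged before a
    single weight update; a trailing incomplete group is ignored for simplicity.
    """
    acc, seen = 0.0, 0
    for i in range(0, len(data), micro_batch):
        acc += grad(w, data[i:i + micro_batch]) / accum_steps
        seen += 1
        if seen == accum_steps:
            g = acc if max_norm is None else clip_by_norm(acc, max_norm)
            w -= lr * g
            acc, seen = 0.0, 0
    return w


data = [(1, 2), (2, 4), (3, 6), (4, 8)]  # y = 2x
w_accum = train_one_pass(data, micro_batch=2, accum_steps=2)
w_full = train_one_pass(data, micro_batch=4, accum_steps=1)
print(w_accum == w_full)  # True: 2 micro-batches of 2 give the same step as 1 batch of 4
```

At `w = 0` the full-batch gradient is -30, so one step with `lr=0.01` moves the weight to 0.3; with `max_norm=1.0` the clipped gradient is -1 and the weight moves only to 0.01.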

Reproducibility

We’ve put a lot of work into making experiments that you run with Catalyst reproducible. Thanks to library-wide determinism, Catalyst-based experiments are reproducible not only between runs on the same server but also between runs on different servers and different hardware (with Docker encapsulation, of course). See the experiments here if you are interested.

Moreover, Reinforcement Learning experiments are also reproducibility-oriented (as far as RL can be reproducible). For example, with synchronous experiment runs you can achieve very close performance, thanks to determinism in the sampled trajectories. This is notoriously hard, and as far as I am aware, Catalyst has the most reproducible RL pipelines out there.

To achieve this new level of reproducibility in DL and RL we had to create several additional features:

  • Full source code dumping: thanks to the Experiment, Runner, and Callback abstractions, it’s quite easy to save these primitives for further use.
  • Catalyst source code dumping: with this feature, you can always reproduce experiment results, even when working with the dev version of Catalyst.
  • Environment versioning: Catalyst dumps pip and conda package versions (which can later be used to define your Docker images).
  • Finally, Catalyst supports several monitoring tools, like Alchemy, Neptune.ai, Wandb to store all your experiment metrics and additional info for better research progress tracking and reproducibility.
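Environment versioning can be approximated with the standard library alone. The sketch below records the Python version and every installed distribution's version, similar in spirit to `pip freeze`; `dump_environment` is our own illustrative helper, it does not see conda-only packages, and it is not how Catalyst implements this feature.

```python
import json
import platform
from importlib import metadata


def dump_environment(path=None):
    """Collect installed distribution versions (hypothetical helper).

    Returns a dict; if `path` is given, also writes it as JSON, e.g. next to
    your experiment logs, so the environment can be rebuilt later.
    """
    packages = {}
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        if name and dist.version:  # skip broken installs with missing metadata
            packages[name] = dist.version
    env = {
        "python": platform.python_version(),
        "packages": dict(sorted(packages.items())),
    }
    if path is not None:
        with open(path, "w") as f:
            json.dump(env, f, indent=2)
    return env


env = dump_environment()
print(env["python"], len(env["packages"]))
```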

Thanks to those library-wise solutions, you can be sure that the pipelines you implement in Catalyst are reproducible with all the experiment logs and checkpoints saved for future reference.

Distributed training

Based on our integrations, Catalyst already has native support for distributed training. Moreover, we support Slurm training and are working on better Kubernetes integration for both DL and RL pipelines.

Productionalization

Now that we know how Catalyst helps with deep learning research we can talk about deploying trained models to production.

As was already mentioned, Catalyst supports model tracing out-of-the-box. It lets you convert PyTorch models (that use Python code) to TorchScript model (that has everything integrated). TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency.

Additionally, to help Catalyst users deploy their pipelines into production systems, Catalyst.Team maintains a Docker Hub with pre-built Catalyst-based images (including fp16 support).

Moreover, to help researchers bring their ideas into production and real-world applications, we’ve created Catalyst.Ecosystem:

  • Reaction: our own PyTorch Serving solution with sync/async API, batch mode support, quest, and all other typical backends that you would expect from a well-designed production system.
  • Alchemy: our monitoring tools for experiment tracking, model comparison and research results sharing.

Popularity

Since the first PyPI release 12 months ago, Catalyst has gained ~1.5k stars on GitHub and over 100k downloads. We are proud to be part of such an Open Source Ecosystem and extremely grateful to all our users and contributors for their constant support and feedback.

One of the online communities that was especially helpful was ods.ai, one of the largest Slack communities for Data Scientists and Machine Learning practitioners in the world (40k+ users). Without their ideas and feedback, Catalyst wouldn’t be where it is today.

Special thanks to our early adopters, who make it all worth it.

Since the beginning of Catalyst’s development, a lot of people have influenced it in a lot of different ways. As a token of my appreciation, I want to express my personal thanks to all of them: a HUGE THANK YOU!

Thanks to all that support, Catalyst has become a part of Kaggle docker image, was added to the PyTorch Ecosystem and now we are developing our own DL R&D Ecosystem to accelerate your research and production needs.

To read more about Catalyst.Ecosystem, please check our vision and project manifesto.

Finally, we are always happy to help our Catalyst.Friends: companies/startups/research labs, who are already using Catalyst or are considering using it for their next project.

Thanks for reading, and… break the cycle – use Catalyst!

When to use Catalyst

  • You want a flexible and reusable codebase without boilerplate.
  • You want to share your expertise with other researchers from different Deep Learning areas.
  • You want to boost your research speed with Catalyst.Ecosystem.

When not to use Catalyst

  • You are just starting your deep learning journey; in that case, low-level PyTorch is a great introduction.
  • You want to create very specific, custom, pipelines with a bunch of irreproducible tricks 🙂

Sylvain Gugger

Core Contributor

Note:

What follows is about version 2 of fastai, which will be released in July 2020. You can preview it here, and it is documented here. If you read this post after the release, it will be in the main repository and documented there.

Fastai is a deep learning library which provides:

  • practitioners with high-level components that can quickly and easily provide state-of-the-art results in standard deep learning domains,
  • researchers with low-level components that can be mixed and matched to build new things.

It aims to do both things without substantial compromises in ease of use, flexibility, or performance.

This is possible thanks to a carefully layered architecture. It expresses common underlying patterns of many deep learning and data processing techniques in terms of decoupled abstractions. What is important is that these abstractions can be expressed clearly and concisely which makes fastai approachable and rapidly productive, but also deeply hackable and configurable.

The high-level API offers customizable models with sensible defaults and is built on top of a hierarchy of lower-level building blocks.

This article covers a representative subset of the library’s features. For details, see the fastai paper and the documentation.

API

When talking about the fastai API, one needs to distinguish between the high-level and the mid/low-level APIs. We will talk about both in the following sections.

High-level API

The high-level API is very useful to beginners and practitioners who are mainly interested in applying pre-existing deep learning methods.

It offers concise APIs for main application areas:

  • vision,
  • text,
  • tabular,
  • time-series analysis,
  • recommendation (collaborative filtering)

These APIs choose intelligent default values and behaviors based on all available information.

For instance, fastai provides a Learner class which brings together architecture, optimizer, and data, and automatically chooses an appropriate loss function where possible.

To give another example: generally, a training set should be shuffled and a validation set should not. fastai provides a single DataLoaders class that automatically constructs the training and validation data loaders with these details already handled.
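The two loaders really only differ in a couple of defaults. A stdlib-only sketch of that idea follows; `make_loaders` is a hypothetical helper, and fastai's actual `DataLoaders` handles much more (devices, transforms, and so on).

```python
import random


def make_loaders(train_items, valid_items, bs, seed=0):
    """Hypothetical helper returning two batch generators with different defaults.

    The training loader reshuffles on every call (i.e. every epoch); the
    validation loader always yields items in their original order.
    """
    rng = random.Random(seed)

    def batches(items, shuffle):
        items = list(items)
        if shuffle:
            rng.shuffle(items)
        return [items[i:i + bs] for i in range(0, len(items), bs)]

    def train_dl():
        return batches(train_items, shuffle=True)

    def valid_dl():
        return batches(valid_items, shuffle=False)

    return train_dl, valid_dl


train_dl, valid_dl = make_loaders(range(10), range(5), bs=2)
print(valid_dl())  # [[0, 1], [2, 3], [4]] on every call
print(train_dl())  # a different order each epoch
```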

To see these “clear and concise code” principles in action, let’s fine-tune an ImageNet model on the Oxford-IIIT Pet dataset and achieve close to state-of-the-art accuracy within a couple of minutes of training on a single GPU:

from fastai.vision.all import *

path = untar_data(URLs.PETS)
dls = ImageDataLoaders.from_name_re(path=path, bs=64,
    fnames = get_image_files(path/"images"), pat = r'/([^/]+)_\d+.jpg$',
    item_tfms=RandomResizedCrop(450, min_scale=0.75), 
    batch_tfms=[*aug_transforms(size=224, max_warp=0.), 
                Normalize.from_stats(*imagenet_stats)])

learn = cnn_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(4)

This is not an excerpt. These are all of the lines of code necessary for this task. Each line of code does one important task, allowing the user to focus on what they need to do, rather than minor details:

from fastai.vision.all import * 

imports all the necessary pieces from the library. Note that the library has been carefully designed so that this style of import does not clutter the namespace.

path = untar_data(URLs.PETS)
downloads a standard dataset from the fast.ai datasets collection (if not previously downloaded) to a configurable location, extracts it (if not previously extracted), and returns a pathlib.Path object with the extracted location.
dls = ImageDataLoaders.from_name_re(path=path, bs=64,
    fnames = get_image_files(path/"images"), pat = r'/([^/]+)_\d+.jpg$',
    item_tfms=RandomResizedCrop(450, min_scale=0.75), 
    batch_tfms=[*aug_transforms(size=224, max_warp=0.), 
    Normalize.from_stats(*imagenet_stats)])
sets up the DataLoaders. Note the separation of item-level and batch-level transforms:
  • item transforms are applied to individual images on the CPU
  • batch transforms are applied to a mini batch on the GPU (if available).
aug_transforms() selects a set of data augmentations. As always in fastai, a default that works well across a variety of vision datasets is chosen but can be fully customized if needed.
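The labelling regex can be checked in isolation. This sketch applies the same pattern to a full file path with Python's `re` module; the example paths are made up, and it does not show exactly how fastai applies the pattern internally. Note that the unescaped `.` before `jpg` matches any character, which is harmless here.

```python
import re

# The pattern from the example above: capture everything after the last "/"
# up to an "_<digits>.jpg" suffix.
PAT = re.compile(r'/([^/]+)_\d+.jpg$')


def label_from_path(path):
    """Return the breed label encoded in a pets image path, or None if no match."""
    m = PAT.search(str(path))
    return m.group(1) if m else None


print(label_from_path("/data/oxford-iiit-pet/images/great_pyrenees_173.jpg"))  # great_pyrenees
print(label_from_path("/data/oxford-iiit-pet/images/Abyssinian_1.jpg"))        # Abyssinian
```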
learn = cnn_learner(dls, resnet34, metrics=error_rate)
creates a Learner, which combines an optimizer, a model, and the data to train on. Each application (vision, text, tabular) has a customized function that creates a Learner and automatically handles whatever details it can for the user. For instance, in this image classification problem, it will:
  • download an ImageNet-pretrained model, if not already available,
  • remove the classification head of the model,
  • replace it with a head appropriate for this particular dataset,
  • set appropriate optimizer, weight decay, learning rate, and so forth
learn.fine_tune(4)
fine-tunes the model. In this case, it is using the 1-cycle policy, which is a recent best practice for training deep learning models but is not widely available in other libraries. A lot of things happen under the hood in .fine_tune():
  • annealing both the learning rates and the momentums,
  • printing metrics on the validation set,
  • displaying results in an HTML or console table
  • recording losses and metrics after every batch and so forth.
  • A GPU will be used if one is available.
  • It will first train the head for one epoch while the body of the model is frozen, then fine-tune the whole model for the given number of epochs (here 4) using discriminative learning rates.
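To show what annealing both the learning rate and the momentum looks like, here is a stdlib-only sketch of a 1-cycle schedule with cosine annealing. The values `div=25`, `div_final=1e5`, `pct_start=0.25`, and momentums 0.95 → 0.85 → 0.95 are assumptions chosen for illustration; check the fastai documentation for what `fine_tune` actually uses.

```python
import math


def cos_anneal(start, end, pos):
    """Cosine interpolation from `start` (pos=0) to `end` (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def one_cycle(pos, lr_max=1e-3, div=25.0, div_final=1e5, pct_start=0.25,
              moms=(0.95, 0.85, 0.95)):
    """Return (lr, momentum) at training progress `pos` in [0, 1].

    Phase 1 (pos < pct_start): lr rises lr_max/div -> lr_max while momentum falls.
    Phase 2: lr decays lr_max -> lr_max/div_final while momentum rises again.
    """
    if pos < pct_start:
        p = pos / pct_start
        return cos_anneal(lr_max / div, lr_max, p), cos_anneal(moms[0], moms[1], p)
    p = (pos - pct_start) / (1 - pct_start)
    return cos_anneal(lr_max, lr_max / div_final, p), cos_anneal(moms[1], moms[2], p)


for pos in (0.0, 0.125, 0.25, 0.5, 1.0):
    lr, mom = one_cycle(pos)
    print(f"pos={pos:.3f}  lr={lr:.2e}  mom={mom:.3f}")
```

The learning rate and momentum move in opposite directions: high momentum while the learning rate is small, lower momentum around the peak learning rate.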

One of the strengths of the fastai library is how consistent the API is across applications.

For example, fine-tuning a pretrained model on the IMDB dataset (a text classification task) using ULMFiT can be done in 6 lines of code:

from fastai.text.all import *

path = untar_data(URLs.IMDB)
dls = TextDataLoaders.from_folder(path, valid='test')
learn = text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy)
learn.fine_tune(4, 1e-2)
Users get a very similar experience in other domains like tabular, time series, or recommendation systems. Once a Learner has been trained, you can explore the results with the command learn.show_results(). How those results are presented depends on the application: in vision you get labeled pictures, and in text you get a dataframe summarizing samples, targets, and predictions. In our pets classification example, you would get something like this:
[Figure: fastai vision show_results output]

In the IMDb classification problem, you’d get something like this:

[Figure: fastai text show_results output]

Another important high-level API component is the data block API, an expressive API for data loading. It is the first attempt we are aware of to systematically define all of the steps necessary to prepare data for a deep learning model, and to give users a mix-and-match recipe book for combining these pieces (which we refer to as data blocks).

Here is an example of how to use the data block API to get the MNIST dataset ready for modeling:

mnist = DataBlock(
    blocks=(ImageBlock(cls=PILImageBW), CategoryBlock), 
    get_items=get_image_files, 
    splitter=GrandparentSplitter(),
    get_y=parent_label)
dls = mnist.dataloaders(untar_data(URLs.MNIST_TINY), batch_tfms=Normalize)
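`GrandparentSplitter` and `parent_label` are plain functions of a file path. Here are hypothetical stdlib re-implementations (our own code, not fastai's) that make their behavior explicit, assuming a `train/<digit>/...` and `valid/<digit>/...` folder layout like the one in MNIST_TINY.

```python
from pathlib import PurePosixPath


def parent_label(path):
    """Label = name of the folder containing the file (.../train/3/a.png -> '3')."""
    return PurePosixPath(path).parent.name


def grandparent_splitter(paths, train_name="train", valid_name="valid"):
    """Split item indices by the grandparent folder name (train/... vs valid/...)."""
    train, valid = [], []
    for i, p in enumerate(paths):
        grandparent = PurePosixPath(p).parent.parent.name
        if grandparent == train_name:
            train.append(i)
        elif grandparent == valid_name:
            valid.append(i)
    return train, valid


items = ["mnist_tiny/train/3/a.png", "mnist_tiny/train/7/b.png", "mnist_tiny/valid/3/c.png"]
print(grandparent_splitter(items))       # ([0, 1], [2])
print([parent_label(p) for p in items])  # ['3', '7', '3']
```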

Mid and low-level API

In the previous section, you saw how you can get a lot done quickly with the high-level API, which has a ton of out-of-the-box functionality. However, there are situations when you need to tweak things or extend what is already there.

This is where middle and low-level APIs come into the picture:

  • mid-level API provides the core deep learning and data-processing methods for each of these applications,
  • low-level API provides a library of optimized primitives and functional and object-oriented foundations, which allow the mid-level to be developed and customized.

The training loop can be customized using the Learner’s novel two-way callback system. It allows gradients, data, losses, control flow, and anything else to be read and changed at any point during training.

There is a rich history of using callbacks to allow for customization of numeric software, and today nearly all modern deep learning libraries provide this functionality. However, fastai’s callback system is the first that we are aware of that supports the design principles necessary for complete two-way callbacks:

  • A callback should be available at every single point during training, which gives users full flexibility. Every callback should be able to access every piece of information available at that stage in the training loop, including hyper-parameters, losses, gradients, input and target data, and so forth.
  • Every callback should be able to modify all these pieces of information, at any time before they are used.
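To make the two-way idea concrete, here is a minimal, framework-free sketch. It assumes a toy one-parameter model y = w * x trained with squared error and plain gradient descent; the `TrainLoop`, `ClipGrad` and `Recorder` classes and the event names are illustrative only and are not fastai’s actual implementation. What it shows is that each callback receives the loop object itself, so it can either read state (`Recorder`) or overwrite it before it is used (`ClipGrad` rewrites the gradient before the weight update).

```python
class Callback:
    """Base class: dispatches an event name to a method of the same name, if defined."""
    def __call__(self, event, loop):
        handler = getattr(self, event, None)
        if handler is not None:
            handler(loop)


class TrainLoop:
    """Toy loop fitting y = w * x with squared error; all state lives on `self`."""
    def __init__(self, data, lr, cbs):
        self.data, self.lr, self.cbs = data, lr, cbs
        self.w = 0.0

    def __call__(self, event):
        # Every callback sees the whole loop object and may read or modify it
        for cb in self.cbs:
            cb(event, self)

    def fit(self, epochs):
        self("before_fit")
        for epoch in range(epochs):
            self.epoch = epoch
            for xb, yb in self.data:
                self.xb, self.yb = xb, yb
                self("before_batch")          # callbacks may change self.xb / self.yb
                self.pred = self.w * self.xb
                self.loss = (self.pred - self.yb) ** 2
                self("after_loss")            # callbacks may inspect or rescale the loss
                self.grad = 2 * (self.pred - self.yb) * self.xb
                self("after_backward")        # callbacks may rewrite the gradient
                self.w -= self.lr * self.grad
                self("after_step")
        self("after_fit")


class ClipGrad(Callback):
    """Writes back into the loop: clamps the gradient before the update uses it."""
    def __init__(self, max_abs):
        self.max_abs = max_abs

    def after_backward(self, loop):
        loop.grad = max(-self.max_abs, min(self.max_abs, loop.grad))


class Recorder(Callback):
    """Only reads from the loop: records every batch loss."""
    def before_fit(self, loop):
        self.losses = []

    def after_loss(self, loop):
        self.losses.append(loop.loss)


if __name__ == "__main__":
    rec = Recorder()
    loop = TrainLoop(data=[(1.0, 2.0), (2.0, 4.0)], lr=0.05, cbs=[ClipGrad(1.0), rec])
    loop.fit(epochs=3)
    print(loop.w, rec.losses)
```

Because the loop never hard-codes clipping or logging, the same loop can be reused with any mix of callbacks.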

All the tweaks of the training loop (different schedulers, mixed-precision training, reporting on TensorBoard, wandb, neptune, or equivalent, MixUp, oversampling strategies, distributed training, GAN training…) are implemented in callbacks that the end-user can mix and match with their own, making it easier to experiment with things and do ablation studies. Convenience methods are there to add those callbacks for the user, making training in mixed precision as easy as saying

learn = learn.to_fp16()

or training in a distributed environment as easy as

learn = learn.to_distributed()

fastai also provides a new, generic optimizer abstraction that allows recent optimization techniques, like LAMB, RAdam or AdamW, to be implemented in a few lines of code.

This is possible thanks to refactoring the optimizer abstraction into two basic pieces:

  • stats, which track and aggregate statistics such as gradient moving averages
  • steppers, which combine stats and hyper-parameters to “step” the weights using some function.

This foundation has allowed us to write most of fastai’s optimizers in 2-3 lines of code, while in other popular libraries that would take you 50+ lines.
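The stats/steppers split can be illustrated with a small sketch on scalar parameters in plain Python. The function names loosely echo fastai’s, but the signatures and the `GenericOptimizer` class are a simplification written for this example, not fastai’s API. Stats fold each gradient into per-parameter state; steppers combine that state with hyper-parameters to move the weight; an optimizer is then just a list of each.

```python
from typing import Callable, Dict, List

State = Dict[str, float]
HyperParams = Dict[str, float]
# A stat folds the current gradient into per-parameter state
Stat = Callable[[float, State, HyperParams], State]
# A stepper returns the updated parameter from (param, grad, state, hyper-params)
Stepper = Callable[[float, float, State, HyperParams], float]


def average_grad(grad: float, state: State, hp: HyperParams) -> State:
    """Stat: momentum-style running sum, grad_avg <- mom * grad_avg + grad."""
    state["grad_avg"] = hp["mom"] * state.get("grad_avg", 0.0) + grad
    return state


def weight_decay(p: float, grad: float, state: State, hp: HyperParams) -> float:
    """Stepper: decoupled weight decay, p <- p * (1 - lr * wd)."""
    return p * (1 - hp["lr"] * hp["wd"])


def sgd_step(p: float, grad: float, state: State, hp: HyperParams) -> float:
    """Stepper: plain gradient descent on the raw gradient."""
    return p - hp["lr"] * grad


def momentum_step(p: float, grad: float, state: State, hp: HyperParams) -> float:
    """Stepper: descend along the averaged gradient tracked by `average_grad`."""
    return p - hp["lr"] * state["grad_avg"]


class GenericOptimizer:
    def __init__(self, params: List[float], stats: List[Stat],
                 steppers: List[Stepper], **hp: float):
        self.params = list(params)
        self.stats, self.steppers, self.hp = stats, steppers, hp
        self.state: List[State] = [{} for _ in self.params]

    def step(self, grads: List[float]) -> None:
        for i, g in enumerate(grads):
            for stat in self.stats:        # 1) update the statistics
                self.state[i] = stat(g, self.state[i], self.hp)
            for stepper in self.steppers:  # 2) apply each stepper in order
                self.params[i] = stepper(self.params[i], g, self.state[i], self.hp)


# Each optimizer is now a short combination of stats and steppers
def SGD(params, lr):
    return GenericOptimizer(params, [], [sgd_step], lr=lr)


def Momentum(params, lr, mom=0.9, wd=0.0):
    return GenericOptimizer(params, [average_grad], [weight_decay, momentum_step],
                            lr=lr, mom=mom, wd=wd)
```

For instance, `Momentum([1.0], lr=0.1).step([2.0])` first updates the running gradient average, then applies weight decay and the momentum step, in that order.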

There are many other mid-tier and low-level APIs that make it easy for researchers and developers to build new methods on top of a fast and flexible foundation.

The library is already in wide use in research, industry, and teaching. We have used it to create a complete and very popular deep learning course, Practical Deep Learning for Coders (the first video of the latest iteration has 256k views).

The repository has 16.9k stars and is used in more than 2,000 projects at the time of writing. The community is very active on the fast.ai forum, be it to clarify points of the course that are unclear, help with debugging or team up to tackle a new deep learning project.

When to use fastai

  • The goal is to have something easy enough for beginners but flexible enough for researchers/practitioners.

When not to use fastai

  • The only thing I can think of is that you wouldn’t use fastai to serve, in production, a model you trained in a different framework, since we don’t deal with that aspect.

Victor Fomin

Core Contributor


PyTorch Ignite is a high-level library that helps with training neural networks in PyTorch. Since its beginning in 2018, our goal has been to:

“make the common things easy and the hard things possible”.

Why use Ignite?

Ignite’s high level of abstraction assumes little about the type of model, or the number of models, that the user is training. We only require the user to define the closure to be run in the training loop and, optionally, in the validation loop. This gives users a lot of flexibility and allows them to use Ignite in tasks such as co-training multiple models (e.g. GANs) or tracking multiple losses and metrics in the training loop.

Ignite concepts and API

There are a few core objects in Ignite’s API that you need to learn:

  • Engine: the essence of the library
  • Events & Handlers: interaction with the Engine (e.g. early stopping, checkpoints, logging)
  • Metrics: out-of-the-box metrics for various tasks

We will present some basics to help you understand the main ideas, but feel free to dig deeper into the examples in the repository.

Engine

An Engine simply loops over the provided data, executes a processing function and returns a result.

A Trainer is an Engine whose processing function updates the model’s weights.

from ignite.engine import Engine

def update_model(trainer, batch):
    model.train()
    optimizer.zero_grad()
    x, y = prepare_batch(batch)
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update_model)
trainer.run(data, max_epochs=100)
An Evaluator (an object used to validate the model) is an Engine whose processing function computes metrics on-line.
import torch
from ignite.engine import Engine

total_loss = []
def compute_metrics(_, batch):
    x, y = batch
    model.eval()
    with torch.no_grad():
        y_pred = model(x)
        loss = criterion(y_pred, y)
        total_loss.append(loss.item())

    return loss.item()

evaluator = Engine(compute_metrics)
evaluator.run(data, max_epochs=1)
print(f"Loss: {torch.tensor(total_loss).mean()}")

This code silently trains a model and computes the total loss.

In the next section we will see how to make the training and validation more user-friendly.

Events & Handlers

To improve the flexibility of the Engine and allow users to interact with each step of the run, we introduced events and handlers. The idea is that users can execute custom code inside the training loop as an event handler, similar to callbacks in other libraries.

fire_event(Events.STARTED)

while epoch < max_epochs:
    fire_event(Events.EPOCH_STARTED)
    # run once on data
    for batch in data:
        fire_event(Events.ITERATION_STARTED)

        output = process_function(batch)

        fire_event(Events.ITERATION_COMPLETED)
    fire_event(Events.EPOCH_COMPLETED)
fire_event(Events.COMPLETED)
At each fire_event call, all its event handlers are executed. For example, users may want to set up some run-dependent variables at the beginning of training (Events.STARTED) and update the learning rate on each iteration (Events.ITERATION_COMPLETED). With Ignite the code will look like this:
from ignite.engine import Engine, Events

train_loader = …
model = …
optimizer = …
criterion = ...
lr_scheduler = …

def process_function(engine, batch):
    # … user function to update model weights
    pass

trainer = Engine(process_function)

@trainer.on(Events.STARTED)
def setup_logging_folder(_):
    # create a folder for the run
    # set up some run dependent variables
    pass

@trainer.on(Events.ITERATION_COMPLETED)
def update_lr(engine):
    lr_scheduler.step()

trainer.run(train_loader, max_epochs=50)

The cool thing about handlers (vs. “callback” interfaces) is that a handler can be any function with the correct signature (we only require the first argument to be the engine), e.g. a lambda, a simple function, a class method, etc. Users are not required to inherit from an interface and override its abstract methods.

trainer.add_event_handler(
    Events.STARTED, lambda engine: print("Start training"))

# attach handler with args, kwargs
mydata = [1, 2, 3, 4]


def on_training_ended(engine, data):
    print("Training is ended. mydata={}".format(data))


trainer.add_event_handler(
    Events.COMPLETED, on_training_ended, mydata)

Built-in events filtering

There are cases when users would like to execute the code periodically/once or with a custom rule like:

  • run the validation every 5 epochs,
  • store a checkpoint every 1000 iterations,
  • change a variable on the 20th epoch,
  • log gradients on the first 10 iterations,
  • etc.

Ignite provides such flexibility to separate “the code to execute” from the logic “when to execute the code”.

For example, to run the validation every 5 epochs, you simply write:

@trainer.on(Events.EPOCH_COMPLETED(every=5))
def run_validation(_):
    # run validation
    pass

Similarly, to change some training variable once, on the 20th epoch:

@trainer.on(Events.EPOCH_STARTED(once=20))
def change_training_variable(_):
    # ...
    pass

More generally, users can provide their own event filtering function:

def first_x_iters(_, event):
    # the iteration count passed in starts at 1, so "<= 10" keeps the first 10
    if event <= 10:
        return True
    return False

@trainer.on(Events.ITERATION_COMPLETED(event_filter=first_x_iters))
def log_gradients(_):
    # …
    pass
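To see how the loop, handlers and filters fit together, here is a minimal, self-contained sketch of an engine that dispatches string-named events to handlers and supports `every`, `once` and custom-filter rules. It is a simplified illustration written for this article, not Ignite’s implementation: `MiniEngine` and its methods only mirror the ideas above, and the filter is assumed to be keyed on the current epoch count (for epoch events) or iteration count (for everything else).

```python
from collections import defaultdict
from types import SimpleNamespace


class MiniEngine:
    """Toy engine: runs process_function over data and fires events around it."""

    def __init__(self, process_function):
        self.process_function = process_function
        self.handlers = defaultdict(list)  # event name -> [(handler, filter, args)]
        self.state = SimpleNamespace(epoch=0, iteration=0, output=None)

    def add_event_handler(self, event, handler, *args,
                          every=None, once=None, event_filter=None):
        # "when to run" is turned into a predicate on the event count
        if every is not None:
            event_filter = lambda engine, count: count % every == 0
        elif once is not None:
            event_filter = lambda engine, count: count == once
        self.handlers[event].append((handler, event_filter, args))

    def on(self, event, **filter_kwargs):
        """Decorator form of add_event_handler."""
        def decorator(handler):
            self.add_event_handler(event, handler, **filter_kwargs)
            return handler
        return decorator

    def fire_event(self, event):
        # Epoch events are filtered on the epoch count, all others on the iteration count
        count = self.state.epoch if event.startswith("EPOCH") else self.state.iteration
        for handler, event_filter, args in self.handlers[event]:
            if event_filter is None or event_filter(self, count):
                handler(self, *args)  # any callable taking the engine first works

    def run(self, data, max_epochs=1):
        self.fire_event("STARTED")
        while self.state.epoch < max_epochs:
            self.state.epoch += 1
            self.fire_event("EPOCH_STARTED")
            for batch in data:
                self.state.iteration += 1
                self.fire_event("ITERATION_STARTED")
                self.state.output = self.process_function(self, batch)
                self.fire_event("ITERATION_COMPLETED")
            self.fire_event("EPOCH_COMPLETED")
        self.fire_event("COMPLETED")
        return self.state


if __name__ == "__main__":
    engine = MiniEngine(lambda eng, batch: batch * 2)
    engine.add_event_handler("STARTED", lambda eng: print("Start training"))

    @engine.on("EPOCH_COMPLETED", every=5)
    def run_validation(eng):
        print("validating after epoch", eng.state.epoch)

    engine.run([1, 2, 3], max_epochs=10)
```

Keeping the filter separate from the handler is what lets the same `run_validation` function be reused with a different schedule.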

Out-of-the-box handlers

Ignite provides a list of handlers and metrics to simplify user’s code:

  • Checkpoint: to save training checkpoints (composed of trainer, model(s), optimizer(s), lr scheduler(s), etc.) and to save the best models (by validation score)
  • EarlyStopping: stops the training if no progress is made (by validation score)
  • TerminateOnNan: stops the training if NaN is encountered
  • Optimizer Parameters Scheduling: concatenate, add a warm-up, setup linear or cosine annealing, linear piecewise scheduling of any optimizer parameter (lr, momentum, betas, …)
  • Logging to common platforms: TensorBoard, Visdom, MLflow, Polyaxon or Neptune (batch losses, metrics, GPU memory/utilization, optimizer parameters and more).

Metrics

Ignite also provides a list of out-of-the-box metrics for various tasks: Precision, Recall, Accuracy, Confusion Matrix, IoU, etc., as well as ~20 regression metrics.

For example, below we compute validation accuracy on the validation dataset:

from ignite.engine import Engine
from ignite.metrics import Accuracy

def compute_predictions(_, batch):
    # …
    return y_pred, y_true

evaluator = Engine(compute_predictions)
metric = Accuracy()
metric.attach(evaluator, "val_accuracy")
evaluator.run(val_loader)
evaluator.state.metrics["val_accuracy"]
# > 0.98765

Go here and here to see the full list of available metrics.

Ignite metrics have this cool property that users can compose their own metrics using basic arithmetic operations or torch methods:

from ignite.metrics import Precision, Recall

precision = Precision(average=False)
recall = Recall(average=False)
F1_per_class = (precision * recall * 2 / (precision + recall))
F1_mean = F1_per_class.mean()  # torch mean method
F1_mean.attach(engine, "F1")
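The composition works because arithmetic on metric objects builds a new, lazily evaluated metric instead of a number. Below is a toy, pure-Python sketch of that idea for binary labels, assuming a reset/update/compute interface; the `Metric`, `Lazy`, `Precision` and `Recall` classes here are simplified stand-ins written for illustration, not Ignite’s classes. One detail worth noticing is that `precision` appears twice in the F1 expression, so the composite metric must update each underlying metric only once per batch.

```python
import operator


class Metric:
    """Interface: reset() clears state, update() consumes a batch, compute() returns a value."""
    def reset(self):
        raise NotImplementedError

    def update(self, output):
        raise NotImplementedError

    def compute(self):
        raise NotImplementedError

    # Arithmetic builds a new, lazily evaluated metric instead of a number
    def __add__(self, other):
        return Lazy(operator.add, self, other)

    def __mul__(self, other):
        return Lazy(operator.mul, self, other)

    def __rmul__(self, other):
        return Lazy(operator.mul, other, self)

    def __truediv__(self, other):
        return Lazy(operator.truediv, self, other)


def _base_metrics(x):
    if isinstance(x, Lazy):
        return x.bases
    if isinstance(x, Metric):
        return [x]
    return []  # plain constants carry no state


class Lazy(Metric):
    def __init__(self, op, left, right):
        self.op, self.left, self.right = op, left, right
        # Keep each underlying metric once, so a metric used twice in the
        # expression is not updated twice per batch
        self.bases = []
        for m in _base_metrics(left) + _base_metrics(right):
            if all(m is not b for b in self.bases):
                self.bases.append(m)

    def reset(self):
        for m in self.bases:
            m.reset()

    def update(self, output):
        for m in self.bases:
            m.update(output)

    def compute(self):
        def value(x):
            return x.compute() if isinstance(x, Metric) else x
        return self.op(value(self.left), value(self.right))


class Precision(Metric):
    """Binary precision: true positives / predicted positives."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.tp, self.predicted_pos = 0, 0

    def update(self, output):
        y_pred, y_true = output
        self.tp += sum(p == 1 and t == 1 for p, t in zip(y_pred, y_true))
        self.predicted_pos += sum(p == 1 for p in y_pred)

    def compute(self):
        return self.tp / self.predicted_pos


class Recall(Metric):
    """Binary recall: true positives / actual positives."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.tp, self.actual_pos = 0, 0

    def update(self, output):
        y_pred, y_true = output
        self.tp += sum(p == 1 and t == 1 for p, t in zip(y_pred, y_true))
        self.actual_pos += sum(t == 1 for t in y_true)

    def compute(self):
        return self.tp / self.actual_pos
```

Usage mirrors the snippet above: build `f1 = precision * recall * 2 / (precision + recall)`, call `f1.update((y_pred, y_true))` once per batch, and read `f1.compute()` at the end.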

Library structure

The library is composed of two main modules:

  • Core module contains bases like Engine, metrics, some essential handlers. It has PyTorch as the only dependency.
  • Contrib module may depend on other libraries (e.g. scikit-learn, tensorboardX, visdom, tqdm, etc) and can potentially have backward compatibility breaking changes between versions.

Both modules are largely covered by unit tests.

Extension capabilities / Simplicity of integration in research

We believe that our event/handler system is rather flexible and gives people the ability to interact with every part of the training process. Because of that, we’ve seen Ignite being used to train GANs (we provide two basic examples to train DCGAN and CycleGAN) or Reinforcement Learning models.

According to GitHub’s “Used by” feature, Ignite has been used by researchers in their papers:

  • BatchBALD: Efficient and Diverse Batch Acquisition for Deep Bayesian Active Learning, github
  • A Model to Search for Synthesizable Molecules, github 
  • Localised Generative Flows,  github 
  • Extracting T Cell Function and Differentiation Characteristics from the Biomedical Literature, github 

Because of those (and other research projects) we strongly believe that Ignite gives you enough flexibility to do deep learning research.

Integrations with other libraries/frameworks

Ignite plays nicely with other libraries or frameworks if their features do not overlap. Some cool integrations that we have include:

  • hyperparameter tuning with Ax (Ignite example).
  • hyperparameter tuning with Optuna (Optuna example).
  • logging to TensorBoard, Visdom, MLflow, Polyaxon, Neptune (Ignite’s code), Chainer UI (Chainer’s code).
  • Training with mixed precision using Nvidia Apex (Ignite’s examples).

Reproducibility

We’ve put a lot of effort into making Ignite training reproducible:

  • Ignite’s Engine automatically handles the random states and, when possible, forces the data loaders to provide the same data samples on different runs;
  • Ignite integrates with experiment tracking systems like MLflow, Polyaxon, Neptune. This helps to keep track of software, parameter, and data dependencies of ML experiments;
  • We provide several examples and “references” (inspired from torchvision) of reproducible training on vision tasks (e.g. classification on CIFAR10, ImageNet, and segmentation on Pascal VOC12).

Distributed training

Distributed training is also supported by Ignite, but we leave it up to the user to set up the type of parallelism: model or data.

For example, in a data-parallel configuration, users are required to correctly set up the distributed process group, wrap the model, use a distributed sampler, etc. Ignite handles the metric computation, i.e. the reduction of metric values across all processes.

We provide several examples (e.g. distributed CIFAR10) to display how to use Ignite in a distributed configuration.
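Why the reduction step matters: averaging per-process accuracies is only correct when every process sees the same number of samples. A safer approach, sketched below in plain Python with the processes simulated as a list of counts, is to sum the raw numerators and denominators across processes and divide once at the end. This illustrates the arithmetic under those assumptions and is not Ignite’s code; in a real run the sums would come from a collective such as `torch.distributed.all_reduce`.

```python
from typing import List, Tuple

# (num_correct, num_samples) computed locally on each simulated process
PerProcessCounts = List[Tuple[int, int]]


def naive_mean_of_accuracies(counts: PerProcessCounts) -> float:
    """Averages per-process accuracies; biased when shard sizes differ."""
    accs = [correct / total for correct, total in counts]
    return sum(accs) / len(accs)


def reduced_accuracy(counts: PerProcessCounts) -> float:
    """Sums numerators and denominators over processes (a SUM all-reduce), then divides once."""
    total_correct = sum(correct for correct, _ in counts)
    total_samples = sum(total for _, total in counts)
    return total_correct / total_samples


if __name__ == "__main__":
    # Rank 0 saw 90/100 correct, rank 1 saw 1/10 correct (e.g. a small last shard)
    counts = [(90, 100), (1, 10)]
    print(naive_mean_of_accuracies(counts))
    print(reduced_accuracy(counts))
```

With these counts the naive average is 0.5, while the reduced accuracy is 91/110 (about 0.83), because rank 1 holds only a tenth of the data.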

Popularity

At the moment of writing, Ignite had about 2.5k stars and, according to GitHub’s “Used by” feature, is used by 205 repositories. Some honorable mentions are:

  • Deep-Reinforcement-Learning-Hands-On-Second-Edition by Max Lapan: a book on Deep Reinforcement Learning whose second-edition examples are written with Ignite.
  • Project MONAI: AI Toolkit for Healthcare Imaging. This project, focused on healthcare research and on developing DL models for medical imaging, uses Ignite for end-to-end training.

Thomas Wolf from HuggingFace also left some awesome feedback for the library in one of his blog articles (Thanks, Thomas!):

“Using the awesome PyTorch ignite framework and the new API for Automatic Mixed Precision (FP16/32) provided by NVIDIA’s apex, we were able to distill our +3k lines of competition code in less than 250 lines of training code with distributed and FP16 options!”

For other use-cases, please take a look at Ignite’s github page and its “Used by”.

When to use Ignite

  • Remove boilerplate and standardize your code using highly customizable modules of Ignite’s API.

  • When you require factorized code but don’t want to sacrifice on flexibility to support your complicated training strategies

  • Use the rich array of utilities like metrics, handlers, and loggers available to evaluate/debug your model with ease

When not to use Ignite

  • When you have highly custom PyTorch code for which Ignite’s API would be an overhead.
  • When completely satisfied by pure PyTorch API or another high-level library

Thank you for reading! PyTorch-Ignite is presented to you with love by the PyTorch community!

Philosophy

PyTorch Lightning is a very lightweight wrapper on PyTorch which is more like a coding standard than a framework. The format allows you to get rid of a ton of boilerplate code while keeping it easy to follow.

The use of hooks, standard across every part of the training, means you can override any part of the internal functionality down to how the backward pass is done – it is extremely flexible.

The result is a framework that gives researchers, students, and production teams the ultimate flexibility to try crazy ideas without having to learn yet another framework while automating away all the engineering details.

Lightning has two additional, more ambitious motivations: reproducibility of research and democratization of best practices in the deep learning community.

Notable features

  • Train on CPU, GPU or TPUs without changing your code!
  • Only library to support TPU training (Trainer(num_tpu_cores=8))
  • Trivial multi-node training
  • Trivial multi-GPU training
  • Trivial 16 bit precision support
  • Built-in performance profiler (Trainer(profile=True))
  • Tons of integrations with libraries like tensorboard, comet.ml, neptune.ai, etc… (Trainer(logger=NeptuneLogger(...)))

Team

Lightning has 90+ contributors and a core team of 8 contributors who make sure the project moves forward lightning fast.

Documentation

Lightning documentation is extremely thorough yet simple and easy to use.

API

At the core, Lightning has an API that centers around two objects, the Trainer and the LightningModule.

The Trainer abstracts away all the engineering details and the LightningModule captures all the science/research code. This decoupling makes the research code more readable and allows it to run on arbitrary hardware.


LightningModule

All the research logic goes into LightningModule.

For example, in a cancer detection system, this part would handle the main things like the object detection model, data loaders for medical images etc.

It groups the core ingredients you need to build a deep learning system:

  • The computations (init, forward).
  • What happens in the training loop (training_step).
  • What happens in the validation loop (validation_step).
  • What happens in the testing loop (test_step).
  • The optimizer(s) to use (configure_optimizers).
  • The data to use (train, test, val dataloaders).

Let’s take a look at the example from the docs and unpack what is happening there.

import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

import pytorch_lightning as pl


class MNISTExample(pl.LightningModule):

    def __init__(self):
        super().__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss']
                                for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'test_loss': F.cross_entropy(y_hat, y)}

    def test_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['test_loss']
                                for x in outputs]).mean()
        tensorboard_logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers
        # and learning_rate schedulers
        # (LBFGS is automatically supported,
        # no need for closure function)
        return torch.optim.Adam(self.parameters(), lr=0.02)

    @pl.data_loader
    def train_dataloader(self):
        # REQUIRED
        return DataLoader(
            MNIST(os.getcwd(), train=True, download=True,
                  transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(
            MNIST(os.getcwd(), train=True, download=True,
                  transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(
            MNIST(os.getcwd(), train=False, download=True,
                  transform=transforms.ToTensor()), batch_size=32)

As you can see, the LightningModule builds on top of pure PyTorch code and simply organizes it into nine methods:

  • __init__(): Defines our model or multiple models, and initializes the weights
  • forward(): You can think of it as your standard PyTorch forward method but with additional flexibility to define what you want to happen at the prediction/inference level.
  • training_step(): Defines what happens in the training loop. It combines a forward pass, loss calculation, and any other logic you want to execute during training.
  • validation_step(): Defines what happens in the validation loop. For example, you can calculate the loss or accuracy for each batch and store them in the logs.
  • validation_end(): Everything that you want to happen after the validation loop ends. For example, you may want to calculate the average loss or accuracy over validation batches.
  • test_step(): What you want to happen to each batch at inference time. You can put your Test Time Augmentation logic or other things here.
  • test_end(): Similarly to validation_end, you can use it to aggregate the batch results calculated during test_step
  • configure_optimizers(): initialize an optimizer or multiple optimizers
  • train/val/test_dataloader(): returns your PyTorch DataLoaders for train, validation, and test sets.

Since every PyTorch Lightning system needs to implement those methods, it is really easy to see exactly what is happening in the research.

For example, to understand what a paper is doing, all you have to do is look at the training_step of the LightningModule!

This readability and the close mapping between core research concepts and their implementation lie at the core of Lightning.

Trainer

This is where the engineering part of deep learning happens.

In the cancer detection system, this might mean how many GPUs you use, when you save checkpoints, when you stop training, etc. These are details that make up a lot of the “secret sauce” of research, which are standard best practices across deep learning projects (i.e. not hugely relevant to cancer detection).

Notice that the LightningModule has nothing about GPUs or 16-bit precision or early stopping or logging or anything like that. All of that is automatically handled by the trainer.

from pytorch_lightning import Trainer

model = MNISTExample()

# most basic trainer, uses good defaults
trainer = Trainer()    
trainer.fit(model)

That’s all it takes to train this model! The trainer handles everything for you including:

  • Early stopping
  • Automatic logging to Tensorboard (or comet, mlflow, neptune, etc…)
  • Auto checkpointing
  • And more (we’ll talk about that in the next sections)

All of this is free out of the box!
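The division of labour between the two objects can be sketched without any framework: the “module” only answers research questions (how to compute a loss, which optimizer to use, which data to read) and a generic “trainer” owns the loop. The sketch below is a deliberately tiny, pure-Python illustration with hypothetical `ToyModule`, `ToyTrainer` and `PlainSGD` classes; the hook names are borrowed from the example above, but nothing here is Lightning’s code, and real trainer features such as device placement, logging or checkpointing are omitted.

```python
class PlainSGD:
    """Minimal optimizer over a dict of scalar parameters."""
    def __init__(self, params, lr):
        self.params, self.lr = params, lr

    def step(self, grads):
        for name, g in grads.items():
            self.params[name] -= self.lr * g


class ToyModule:
    """'Research' code: model, loss, optimizer and data for y = w * x."""
    def __init__(self):
        self.params = {"w": 0.0}

    def forward(self, x):
        return self.params["w"] * x

    def training_step(self, batch, batch_idx):
        x, y = batch
        err = self.forward(x) - y
        # Returns the loss plus its analytic gradient (no autograd in this sketch)
        return {"loss": err ** 2, "grads": {"w": 2 * err * x}}

    def configure_optimizers(self):
        return PlainSGD(self.params, lr=0.1)

    def train_dataloader(self):
        return [(1.0, 3.0), (2.0, 6.0)]


class ToyTrainer:
    """'Engineering' code: the loop, reusable with any module exposing these hooks."""
    def __init__(self, max_epochs=1):
        self.max_epochs = max_epochs
        self.history = []

    def fit(self, module):
        optimizer = module.configure_optimizers()
        for _ in range(self.max_epochs):
            for batch_idx, batch in enumerate(module.train_dataloader()):
                out = module.training_step(batch, batch_idx)
                optimizer.step(out["grads"])
                self.history.append(out["loss"])
        return module


if __name__ == "__main__":
    trainer = ToyTrainer(max_epochs=20)
    model = trainer.fit(ToyModule())
    print(model.params["w"])
```

Swapping in a different `ToyModule` changes the research without touching the loop, which is the property the Trainer/LightningModule split is after.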

The learning curve

Since the LightningModule simply reorganizes pure PyTorch objects and everything is “out in the open”, it is trivial to refactor your PyTorch code into the Lightning format.

For more information about making the switch from pure PyTorch to Lightning read this article.

Built-in features (what you get out of the box)

Lightning gives a ton of advanced features out-of-the-box. For instance, it takes a one-liner to use things like:

  • Multi-GPU training
Trainer(gpus=8)
  • TPU training
Trainer(num_tpu_cores=8)
  • Multi-node training
Trainer(gpus=8, num_nodes=8, distributed_backend='ddp')
  • Gradient clipping
Trainer(gradient_clip_val=2.0)
  • Accumulated gradients
Trainer(accumulate_grad_batches=12)
  • 16-bit precision
Trainer(use_amp=True)
  • Truncated back-propagation through time
Trainer(truncated_bptt_steps=3)
  • and a lot more.
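As a rough mental model of what gradient accumulation and gradient clipping do, here is a framework-free sketch, assuming a two-parameter linear model y = a * x + b with squared-error loss: gradients from k consecutive batches are averaged before a single update, and the accumulated gradient is rescaled to a maximum L2 norm before the step. It illustrates the general techniques only; the `train`, `grad_fn` and `clip_by_norm` helpers are hypothetical, and the exact loss scaling and clipping that Lightning applies may differ.

```python
import math
from typing import List, Sequence, Tuple

Batch = Tuple[float, float]


def grad_fn(params: List[float], batch: Batch) -> List[float]:
    """Gradient of (a*x + b - y)^2 with respect to [a, b]."""
    a, b = params
    x, y = batch
    err = a * x + b - y
    return [2 * err * x, 2 * err]


def clip_by_norm(grad: List[float], max_norm: float) -> List[float]:
    """Rescale so the L2 norm is at most max_norm; the direction is preserved."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm > max_norm:
        return [g * max_norm / norm for g in grad]
    return grad


def train(params: List[float], batches: Sequence[Batch], lr: float,
          accumulate: int = 1, clip: float = None):
    params = list(params)
    acc = [0.0] * len(params)
    steps = 0
    for i, batch in enumerate(batches, start=1):
        g = grad_fn(params, batch)
        # Divide by k so the accumulated gradient is the mean over k batches
        acc = [a + gi / accumulate for a, gi in zip(acc, g)]
        if i % accumulate == 0:
            step_grad = clip_by_norm(acc, clip) if clip is not None else acc
            params = [p - lr * sg for p, sg in zip(params, step_grad)]
            acc = [0.0] * len(params)
            steps += 1
    # Any leftover partial accumulation at the end is discarded in this sketch
    return params, steps


if __name__ == "__main__":
    data = [(1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, 5.0)]
    print(train([0.0, 0.0], data, lr=0.05, accumulate=2, clip=1.0))
```

With `accumulate=k`, the optimizer steps once every k batches, so the effective batch size grows by a factor of k without holding more than one batch in memory at a time.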

If you would like to see the full list of free-magic features go here.

Extension capabilities / Simplicity of integration in research

Having a bunch of in-built functionalities is great but for researchers, it’s crucial to not have to learn yet another library, and directly control key parts of research such as data-processing without having other abstractions operate on those.

This flexible format allows for the most freedom in training and validating. This interface should be thought of as a system, not as a model. The system might have multiple models (GANs, seq-2-seq, etc…) or just one model, such as this simple MNIST example.

Thus researchers are free to try as many crazy things as they want, and ONLY have to worry about the LightningModule.

But maybe you need even MORE flexibility. In this case, you can do things like:

  • Change how the backward step is done.
  • Change how 16-bit is initialized.
  • Add your own way of doing distributed training.
  • Add Learning rate schedulers.
  • Use multiple optimizers.
  • Change the frequency of optimizer updates.
  • And many many more things.

Under the hood, everything in Lightning is implemented as hooks that can be overridden by the user. This makes EVERY single aspect of training highly configurable, which is exactly the flexibility a research or production team needs.

But wait you say… this is too simple for your use case? No worries, Lightning was designed while doing research at NYU and Facebook AI Research for my PhD to be as flexible as possible for researchers.

Here are some examples:

  • Need your own backward pass? Override this hook:
def backward(self, use_amp, loss, optimizer):
    if use_amp:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
  • Need your own amp init? Override this hook:
def configure_apex(self, amp, model, optimizers, amp_level):
    model, optimizers = amp.initialize(
        model, optimizers, opt_level=amp_level,
    )

    return model, optimizers
  • Want to go as deep as adding your own DDP implementation? Override these two hooks:
def configure_ddp(self, model, device_ids):
    # Lightning DDP simply routes to test_step, val_step, etc...
    model = LightningDistributedDataParallel(
        model,
        device_ids=device_ids,
        find_unused_parameters=True
    )
    return model

def init_ddp_connection(self):
    # use slurm job id for the port number
    # guarantees unique ports across jobs from same grid search
    try:
        # use the last 4 numbers in the job id as the id
        default_port = os.environ['SLURM_JOB_ID']
        default_port = default_port[-4:]

        # all ports should be in the 10k+ range
        default_port = int(default_port) + 15000

    except Exception as e:
        default_port = 12910

    # if user gave a port number, use that one instead
    try:
        default_port = os.environ['MASTER_PORT']
    except Exception:
        os.environ['MASTER_PORT'] = str(default_port)

    # figure out the root node addr
    try:
        root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
    except Exception:
        root_node = '127.0.0.2'

    root_node = self.trainer.resolve_root_node_address(root_node)
    os.environ['MASTER_ADDR'] = root_node
    dist.init_process_group(
        'nccl',
        rank=self.proc_rank,
        world_size=self.world_size
    )

There are dozens of hooks like these, and we add more as researchers request them.
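The fallback order inside `init_ddp_connection` is easier to follow when the environment lookups are pulled out into a pure function. The sketch below is a hypothetical helper (`resolve_master_addr_port` is not part of Lightning) that takes any mapping of environment variables and returns the address and port the hook would end up using. It deliberately skips the `self.trainer.resolve_root_node_address` step and does not write anything back to `os.environ`.

```python
from typing import Mapping, Tuple


def resolve_master_addr_port(env: Mapping[str, str]) -> Tuple[str, str]:
    """Hypothetical helper mirroring the fallbacks in init_ddp_connection.

    Port:    MASTER_PORT if set, else the last 4 digits of SLURM_JOB_ID
             plus 15000, else 12910.
    Address: first space-separated entry of SLURM_NODELIST, else 127.0.0.2.
    """
    # Derive a default port from the SLURM job id so that jobs from the same
    # grid search are unlikely to collide on one port.
    try:
        default_port = int(env["SLURM_JOB_ID"][-4:]) + 15000
    except (KeyError, ValueError):
        default_port = 12910

    # A user-supplied MASTER_PORT always wins over the derived default.
    port = env.get("MASTER_PORT", str(default_port))

    # Take the first node listed by SLURM; fall back to a loopback address.
    # (The real hook also passes this through
    # trainer.resolve_root_node_address, which is left out here.)
    nodelist = env.get("SLURM_NODELIST")
    root_node = nodelist.split(" ")[0] if nodelist else "127.0.0.2"
    return root_node, port


if __name__ == "__main__":
    print(resolve_master_addr_port({"SLURM_JOB_ID": "4821337"}))
    # ('127.0.0.2', '16337')
```

For example, a SLURM job with id `4821337` and no `MASTER_PORT` lands on port 1337 + 15000 = 16337. Because the helper has no side effects, calling it with `os.environ` is a cheap way to check which values the current shell would produce before launching a job.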

The bottom line is that Lightning is trivial to use for a new user and infinitely extensible if you’re a researcher or production team working on bleeding-edge AI research.

Readability and moving towards Reproducibility

As I mentioned, Lightning was created with a second, more ambitious motivation: reproducibility. True reproducibility requires standard code, standard seeds, standard hardware and so on, but Lightning contributes to reproducible research in two ways:

  • it standardizes the format of ML code,
  • it decouples the engineering from the science, so that the approach can be tested on different systems.

The result is an expressive, powerful API for doing research.

If every research project and paper were implemented using the LightningModule template, it would be very easy to find out what’s going on (but perhaps not easy to understand, haha).

Distributed training

Lightning makes multi-GPU or even multi-GPU multi-node training trivial.

For instance, if you want to train the above example on multiple GPUs, just add the following flags to the trainer:

trainer = Trainer(gpus=4, distributed_backend='dp')    
trainer.fit(model)

Using the above flags will run this model on 4 GPUs. If you want to run on, say, 16 GPUs, where you have 4 machines each with 4 GPUs, change the trainer flags to this:

trainer = Trainer(gpus=4, nb_gpu_nodes=4, distributed_backend='ddp')    
trainer.fit(model)

And submit the following SLURM job:

#!/bin/bash -l

# SLURM SUBMIT SCRIPT
#SBATCH --nodes=4
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=4
#SBATCH --mem=0
#SBATCH --time=0-02:00:00

# activate conda env
source activate $1

# -------------------------
# debugging flags (optional)
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# on your cluster you might need these:
# set the network interface
# export NCCL_SOCKET_IFNAME=^docker0,lo

# might need the latest cuda
# module load NCCL/2.4.7-1-cuda.10.0
# -------------------------

# run script from above
srun python3 mnist_example.py

This is crazy simple, considering how much happens under the hood.
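The 16-GPU example relies on a little bookkeeping: the Trainer flags and the `#SBATCH` directives have to describe the same layout. The sketch below spells out that arithmetic. The `Launch` class and its methods are hypothetical, and the rank numbering (node by node, one process per GPU) is the conventional DDP layout assumed here; it matches the script above, where `--ntasks-per-node=4` sits next to `gpus=4`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Launch:
    """Hypothetical description of a multi-node DDP launch."""

    gpus_per_node: int    # Trainer(gpus=...)
    num_nodes: int        # Trainer(nb_gpu_nodes=...) and #SBATCH --nodes
    ntasks_per_node: int  # #SBATCH --ntasks-per-node

    @property
    def world_size(self) -> int:
        # One training process per GPU, on every node.
        return self.gpus_per_node * self.num_nodes

    def is_consistent(self) -> bool:
        # srun starts ntasks_per_node processes on each node; with one
        # process per GPU this has to equal the number of GPUs per node.
        return self.ntasks_per_node == self.gpus_per_node

    def global_rank(self, node_rank: int, local_rank: int) -> int:
        # Ranks are numbered node by node; with 4 GPUs per node,
        # node 0 gets ranks 0..3, node 1 gets 4..7, and so on.
        return node_rank * self.gpus_per_node + local_rank


if __name__ == "__main__":
    cfg = Launch(gpus_per_node=4, num_nodes=4, ntasks_per_node=4)
    print(cfg.world_size, cfg.is_consistent(), cfg.global_rank(3, 3))
    # 16 True 15
```

With 4 nodes of 4 GPUs, the last GPU on the last node gets rank 15, so the 16 processes cover ranks 0 to 15 exactly once.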

For more information about distributed training with PyTorch Lightning, read the article “How To Train A GAN On 128 GPUs Using PyTorch”.

Productionalization

Lightning models can be easily deployed because they’re still simple PyTorch models under the hood. This means we can leverage all the engineering work the PyTorch community puts into supporting deployment.

Popularity

PyTorch Lightning has over 3800 stars on GitHub and has recently hit 110k downloads. More importantly, the community is growing rapidly, with over 90 contributors, many from the top AI labs in the world, adding new features daily. You can talk to us on GitHub or Slack.

When to use PyTorch Lightning

  • Lightning is made for professional researchers and production teams working on cutting-edge research. It’s great when you know what you need to do. This focus means it adds advanced features for people looking to test and build things very quickly without getting bogged down in the details.

When not to use PyTorch Lightning

  • Although Lightning is made for professional researchers and data scientists, newcomers can still benefit. For newcomers, we recommend building a simple MNIST system from scratch in pure PyTorch. This will show them how to set up a training loop, etc. Once they understand how that works and how the forward and backward passes work, they can move to Lightning.

Ethan Harris

Creator

Matt Painter

Creator

Our part of the blog will be a little different from the others because torchbearer is coming to an end (sort of). In particular, we are joining the PyTorch-Lightning team. The move came about from a meeting with William Falcon at NeurIPS 2019, and was recently announced on the PyTorch blog.

So, instead of trying to sell you torchbearer, we thought we should write about what we did well, what we did wrong, and why we are moving to Lightning.

What we did well

  • The lib got pretty popular and reached 500+ stars on GitHub, which was far more than we had ever imagined.
  • We became a part of the PyTorch ecosystem. It was an important experience for us that allowed us to feel like a valued part of a wider community.
  • We’ve built a comprehensive set of built-in callbacks and metrics. This was one of our key successes; a lot of powerful outcomes can be achieved in a single line of code with torchbearer.
  • An important feature of torchbearer that enables extreme flexibility is the state object. This is a mutable dictionary that houses all of the variables in use by the core training loop. By editing these variables in callbacks at different points in the loop, even highly complex outcomes can be achieved.
  • It was always important to us that torchbearer had good documentation. We focused on example-led docs that can be executed in your browser with Google Colab. The example library has been a success, giving quick information on the more powerful use cases of torchbearer.
  • A final thing to note is that torchbearer has been used by both of us over the past two years for our PhD research. We count this as a success because we have almost never had to change the torchbearer API in order to prototype our ideas, even the ridiculous ones!
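To make the state-object idea concrete, here is a toy training loop in plain Python in which every variable the loop touches lives in one mutable dictionary, and callbacks are free to edit any of it. This is a sketch of the pattern only: the hook names (`on_sample`, `on_end_epoch`), the state keys and the one-weight "model" are chosen for illustration and should not be read as torchbearer's actual API.

```python
from typing import Any, Callable, Dict, List

State = Dict[str, Any]
Callback = Callable[[State], None]


def fit(xs: List[float], epochs: int, callbacks: Dict[str, List[Callback]]) -> State:
    """Toy loop: all loop variables live in one mutable `state` dict."""
    state: State = {"epoch": 0, "lr": 0.1, "weight": 0.0, "loss": None, "stop": False}
    for epoch in range(epochs):
        state["epoch"] = epoch
        for x in xs:
            state["x"] = x
            for cb in callbacks.get("on_sample", []):
                cb(state)  # callbacks may read or overwrite any key
            # One-parameter "model": fit weight * x to a target of 1.0 by SGD
            # on the squared error.
            error = state["weight"] * state["x"] - 1.0
            state["loss"] = error ** 2
            state["weight"] -= state["lr"] * 2 * error * state["x"]
        for cb in callbacks.get("on_end_epoch", []):
            cb(state)
        if state["stop"]:  # any callback can end training by flipping a flag
            break
    return state


def halve_lr(state: State) -> None:
    # A learning-rate schedule implemented purely by editing state.
    state["lr"] *= 0.5


def stop_after_two_epochs(state: State) -> None:
    # Early stopping implemented purely by editing state.
    if state["epoch"] >= 1:
        state["stop"] = True


if __name__ == "__main__":
    final = fit([0.5, 1.0, 2.0], epochs=10,
                callbacks={"on_end_epoch": [halve_lr, stop_after_two_epochs]})
    print(final["epoch"], final["lr"])
    # 1 0.025
```

Neither the schedule nor the early stopping needed a dedicated hook in the loop; both work simply by writing to `state`, which is the flexibility described above.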

What we did wrong

  • The state object, which makes this library so flexible, is also problematic. The ability to access any part of the library from any other lends itself to abuse in the same way that global variables do. In particular, determining how and when a particular variable in the state object was changed is challenging once more than one object is acting on it. Additionally, for state to be effective you need to know what each variable is and in which callbacks you can access it, so the learning curve is steep.
  • By its nature, torchbearer does not lend itself to distributed training, or, to some extent, even to low-precision training. Since every part of state is available at all times, how do you chunk it and distribute it across devices? PyTorch handles this to a degree, in that torchbearer can be used in a distributed setting, but it is unclear exactly what happens to state at those times.
  • Changing the core training loop was non-trivial. Torchbearer offers a way to completely write your own core loop, but you then have to manually add callback points to keep all the built-in torchbearer functionality working. Combined with documentation that was weaker than for other parts of the library, this made custom loops overly complicated and probably unknown to most users.
  • Managing an open-source project while working on our PhDs ended up being more difficult than expected. As a result, some parts of the library were thoroughly tested and stable (since they were important for our PhD work), while others were under-developed and buggy.
  • During our initial growth, we decided to dramatically change the core API. This significantly improved torchbearer, but also meant a lot of effort moving from one version to the next. It felt justified since we had not yet reached a stable 1.0.0 release, but it certainly contributed to some users choosing other libraries.
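The downside of a shared, mutable state can be shown with the same kind of dictionary. In the sketch below, two hypothetical callbacks both write to `lr`, so the final value depends on the order in which they run. `TracedState` is an illustrative helper (not part of torchbearer) that logs which callback wrote which key, one way of answering the "who changed this, and when?" question described above.

```python
from typing import Any, Callable, Iterable, List, Tuple


class TracedState(dict):
    """Hypothetical dict that records (writer, key, value) for every write."""

    def __init__(self, **initial: Any) -> None:
        super().__init__(**initial)  # initial values are not logged
        self.writer = "init"
        self.log: List[Tuple[str, str, Any]] = []

    def __setitem__(self, key: str, value: Any) -> None:
        self.log.append((self.writer, key, value))
        super().__setitem__(key, value)


def boost_lr(state: TracedState) -> None:
    state["lr"] = state["lr"] * 2         # e.g. a warm-up step


def clip_lr(state: TracedState) -> None:
    state["lr"] = min(state["lr"], 0.05)  # e.g. a safety cap


def run(callbacks: Iterable[Callable[[TracedState], None]]) -> TracedState:
    state = TracedState(lr=0.04)
    for cb in callbacks:
        state.writer = cb.__name__  # remember who is about to write
        cb(state)
    return state


if __name__ == "__main__":
    a = run([boost_lr, clip_lr])
    b = run([clip_lr, boost_lr])
    print(a["lr"], b["lr"])  # same callbacks, different order
    # 0.05 0.08
    print(a.log)
    # [('boost_lr', 'lr', 0.08), ('clip_lr', 'lr', 0.05)]
```

Neither callback is wrong on its own; the surprise only appears once both act on the same key, which is the global-variable problem the authors point to.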

Why are we joining PyTorch Lightning?

  • The first key reason for our willingness to move to Lightning is its popularity. With Lightning, we become part of the fastest-growing PyTorch training library, which has already eclipsed many of its competitors.
  • The second key reason for our move, and a key part of the success of Lightning, is that it was built from the ground up to support distributed training and low precision, both challenging to implement in torchbearer. These practical considerations made in the early stages of Lightning’s development are invaluable to the modern deep learning practitioner and would be challenging to retro-fit in torchbearer.
  • In addition, at Lightning we will be part of a larger team of core developers. This will enable us to ensure greater stability and to support a broader range of use cases than is possible with just two developers as we have now.

Ultimately, we have always believed that the best way to move things forward would be to join efforts with another library. This is our chance to do that and help Lightning become the best training library for PyTorch.

(Subjective) Comparison and Final Thoughts

Jakub Czakon

Senior Data Scientist

At this point, I want to give a…

huge THANK YOU to all the authors!

Wow, this is a lot of first-hand info and I hope it will make it easier to choose the library that works for you.

As I was working on this article with them and looking closer at what their libraries have to offer (and creating some Pull Requests), I gained my own personal perspective that I want to share with you here.

Skorch

If you want the sklearn-like API, then Skorch is your lib. It is well tested and documented. It actually gives more flexibility than I had anticipated before working on this article, which was a nice surprise. That said, the focus of this lib is not cutting-edge research but rather production applications. I feel that it really delivers on its promise and does exactly what it was built to do. I really respect tools/libs like that.

Fastai

For a long time, fastai has been a great choice for people getting into deep learning. It can get you state-of-the-art results in 10 lines of almost magical code. But there is another side to the library, perhaps lesser known, that lets you access lower-level APIs and create custom building blocks that give researchers and practitioners the flexibility to implement very complex systems. Maybe it was the uber-popular fastai deep learning course that created a false image of this library in my mind, but I will definitely take it for a spin in the future, especially with the recent v2 pre-release.

PyTorch Ignite

Ignite is an interesting animal. With its engine, event and handler API (a bit exotic for my personal taste), you can do pretty much whatever you want. It has a ton of features out-of-the-box, and I definitely understand why many researchers use it in their daily work. It took me a moment to get familiar with the framework, but you just need to stop thinking in “callback terms” and you’ll be fine. That said, the API doesn’t speak to me as clearly as some other libs do. You should check it out though, as it may be a great choice for you.
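For readers who have not used Ignite, the engine/event/handler idea can be sketched in a few lines of plain Python. This is a toy imitation written for illustration, not Ignite itself: in the real library the classes are `ignite.engine.Engine` and `ignite.engine.Events`, and counters such as `epoch`, `iteration` and `output` live on `engine.state` rather than on the engine. The point is the shift away from "callback terms": instead of a trainer calling a fixed set of callback methods, you attach any function to any named event.

```python
from collections import defaultdict
from enum import Enum
from typing import Any, Callable, DefaultDict, List, Sequence


class Events(Enum):
    EPOCH_STARTED = "epoch_started"
    ITERATION_COMPLETED = "iteration_completed"
    EPOCH_COMPLETED = "epoch_completed"


Handler = Callable[["Engine"], None]


class Engine:
    """Toy engine: runs process_fn on each batch and fires events around it."""

    def __init__(self, process_fn: Callable[["Engine", Any], Any]) -> None:
        self.process_fn = process_fn
        self.handlers: DefaultDict[Events, List[Handler]] = defaultdict(list)
        self.epoch = 0
        self.iteration = 0
        self.output: Any = None

    def on(self, event: Events) -> Callable[[Handler], Handler]:
        # Decorator that attaches a handler to one event.
        def register(handler: Handler) -> Handler:
            self.handlers[event].append(handler)
            return handler
        return register

    def _fire(self, event: Events) -> None:
        for handler in self.handlers[event]:
            handler(self)

    def run(self, data: Sequence[Any], max_epochs: int = 1) -> "Engine":
        for epoch in range(1, max_epochs + 1):
            self.epoch = epoch
            self._fire(Events.EPOCH_STARTED)
            for batch in data:
                self.iteration += 1
                self.output = self.process_fn(self, batch)
                self._fire(Events.ITERATION_COMPLETED)
            self._fire(Events.EPOCH_COMPLETED)
        return self


if __name__ == "__main__":
    trainer = Engine(lambda engine, batch: batch * 2)  # stand-in "training step"
    seen = []

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_epoch(engine: Engine) -> None:
        seen.append((engine.epoch, engine.iteration, engine.output))

    trainer.run([1, 2, 3], max_epochs=2)
    print(seen)
    # [(1, 3, 6), (2, 6, 6)]
```

Logging, checkpointing or validation would each be just another handler attached to the event where it belongs, which is why the design feels unusual at first but is very flexible once it clicks.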

Catalyst

Before looking into Catalyst, I thought it was a heavy(ish) framework for creating deep learning pipelines. Now my view is completely different. It decouples engineering stuff from research in a beautiful way. Pure PyTorch objects go into a trainer that deals with the training. It is very flexible and has a separate module that deals with Reinforcement Learning. It also gives you a lot of features out-of-the-box when it comes to reproducibility and serving models in production. And those multistage pipelines I told you about? You can easily create them with minimal overhead. Overall, I think it is a great project and a lot of people out there could benefit from using it.

PyTorch Lightning

Lightning also wants to separate science from engineering, and I think it does a great job at that. There are just a ton of built-in features that make it even more appealing. But something that makes this library a bit different is that it enables reproducibility by making deep learning research implementations readable. It is really easy to follow the logic inside the LightningModule, where the training step (among other things) is not abstracted away. I think communicating research projects in this way can be extremely effective. It is getting popular very quickly, and with the authors of torchbearer joining the core developer team, I think this project has a bright future in front of it, Lightning bright even 🙂

So which one should you choose? As always, it depends, but I think you now have enough information to make a good decision!

Copyright 2020 Neptune Labs Inc. All Rights Reserved.