
Data Parallel Training with Flux.jl

Training ImageNet in Flux
Data parallelism is a common technique for distributing the training of a neural network: several copies of the model are created, each with access to only part of the dataset. In theory, since every copy of the model sees a different subset of the data, combining the gradients from their minibatches simulates a larger effective batch size and therefore speeds up training.
By its nature, data-parallel training is usually implemented so that multiple GPUs each compute gradients for a distinct sub-batch of data. This lets neural network training scale many-fold, ultimately making it viable to train over increasingly large datasets. ImageNet, a standard dataset of images spread across a large number of classes, is the dataset used for this demonstration.
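To make the gradient-combining step concrete, here is a minimal single-process sketch of the idea. It is illustrative only: the model, loss, and data below are placeholders, not the code used for the ImageNet runs.
using Flux

# Toy model and optimiser state standing in for the real network.
model = Chain(Dense(4 => 8, relu), Dense(8 => 2))
opt_state = Flux.setup(Adam(1f-3), model)

# Pretend these are the sub-batches held by different workers/GPUs.
subbatches = [(randn(Float32, 4, 16), Flux.onehotbatch(rand(1:2, 16), 1:2)) for _ in 1:4]

# Each replica computes a gradient on its own shard of the data.
grads = map(subbatches) do (x, y)
    Flux.gradient(m -> Flux.logitcrossentropy(m(x), y), model)[1]
end

# Averaging the per-replica gradients approximates one step on the
# concatenated (4x larger) batch, which is what data parallelism exploits.
avg = Flux.fmap((gs...) -> gs[1] === nothing ? nothing : sum(gs) ./ length(gs), grads...)
Flux.update!(opt_state, model, avg)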

Training Metrics

Several models were trained in order to populate the pre-trained weights in Metalhead, a collection of standard computer-vision models implemented in Flux. The trained models include ResNet-18/34/50/101/152, EfficientNet-v2, Inception-v2/v3, SqueezeNet, VGG11/13/16/19, and ResNeXt, among others.
Each model was trained on a machine with a 32-core AMD EPYC 7513, 1 TB of RAM, and four 32 GB NVIDIA V100 GPUs. The results reported here show the accuracy and loss curves from training a ResNet-34.
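Once training finishes, these weights end up behind Metalhead's pretrained-model interface. As a hedged illustration (assuming a Metalhead release that exposes the `pretrain` keyword; check the package docs for your version), using them looks roughly like this:
using Flux, Metalhead

resnet = ResNet(34; pretrain = true)   # ResNet-34 initialised with the trained weights
x = rand(Float32, 224, 224, 3, 1)      # one 224x224 RGB image in WHCN layout
probs = softmax(resnet(x))             # probabilities over the ImageNet classes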



The Approach

For this demo, we are using a vanilla DDP implementation written with [CUDA.jl](https://cuda.juliagpu.org) and Flux. It uses multiple processes and `RemoteChannel`s to communicate between them, so we incur network and serialisation overhead every time we want to update the models. The API is high level and highly configurable, making it easy to run several experiments without rewriting boilerplate.
# Launch data-parallel training with one worker per device in CUDA.devices()
ResNetImageNet.start(loss, nothing,
                     key, model, nworkers(), rcs, updated_grads_channel,
                     class_idx = classes,
                     verbose = false,
                     devices = CUDA.devices(),
                     cycles = 1000,
                     nsamples = 600,
                     batchsize = 48,
                     sched = sched,
                     saveweights = false,
                     opt = opt)
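The call above relies on channels (`rcs`, `updated_grads_channel`) that are set up beforehand. The sketch below shows the general pattern in plain Distributed.jl terms, not the actual ResNetImageNet internals: the driver pushes data shards to workers over one `RemoteChannel` and collects their gradients over another. All names here are illustrative.
using Distributed
addprocs(2)
@everywhere using Flux

# Channels hosted on the driver: one for data shards, one for gradients.
data_channel  = RemoteChannel(() -> Channel{Any}(32))
grads_channel = RemoteChannel(() -> Channel{Any}(32))

@everywhere function worker_loop(data_channel, grads_channel, model)
    while true
        batch = take!(data_channel)
        batch === nothing && break            # sentinel: no more work
        x, y = batch
        gs = Flux.gradient(m -> Flux.logitcrossentropy(m(x), y), model)[1]
        put!(grads_channel, gs)               # serialised back to the driver
    end
end

model = Chain(Dense(4 => 8, relu), Dense(8 => 2))
foreach(p -> remote_do(worker_loop, p, data_channel, grads_channel, model), workers())

# Push a few shards, then tell the workers to stop.
for _ in 1:4
    put!(data_channel, (randn(Float32, 4, 16), Flux.onehotbatch(rand(1:2, 16), 1:2)))
end
foreach(_ -> put!(data_channel, nothing), workers())

# Per-shard gradients, to be averaged and applied to the model on the driver.
grads = [take!(grads_channel) for _ in 1:4]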

Work in Progress

Most frameworks have limited built-in support for data parallelism, and with good reason: DDP is a niche use case, since pre-trained models often remove the need to train over large datasets from scratch. Still, high-quality tools that make DDP easier to implement are desirable; examples of such projects include moolib and PyTorch Lightning. In Julia, we really like composition, so the natural way to scale model training is to marry Flux with a library that can help with parallelisation. One of the prime candidates is Dagger, and the result is a simple glue package called DaggerFlux, which automates the process of spreading a model over multiple GPUs.
Dagger can also keep track of which GPUs are available and send each one a chunk of data to train on. This keeps the approach modular: improvements to the parallelisation, to GPU capabilities, and to features in the ML frameworks can be added incrementally while keeping the user-facing APIs low friction. In FluxML, we pride ourselves on having high-level, simple, and flexible yet performant abstractions that accommodate esoteric cases beyond simple ResNet training. To make use of parallelism, DaggerFlux exposes a simple structure.
using Flux, DaggerFlux

# A plain Flux model...
model = Chain(Dense(3, 3), Dense(3, 3))

# ...wrapped so that Dagger can schedule its execution
dagger_model = DaggerChain(model)
And that's it. We are now ready to scale training automatically based on how many (and which) GPUs are available to Julia, and to schedule work so that training takes advantage of data parallelism.
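To see the scheduling idea in isolation, here is a hedged sketch using plain Dagger.jl rather than DaggerFlux internals: each shard's gradient computation is spawned as a task, and Dagger decides where it runs. The model, loss, and shards are placeholders.
using Dagger, Flux

model  = Chain(Dense(3 => 3), Dense(3 => 3))
shards = [randn(Float32, 3, 32) for _ in 1:4]

toyloss(m, x) = sum(abs2, m(x))                        # stand-in loss
shard_grad(m, x) = Flux.gradient(mm -> toyloss(mm, x), m)[1]

# One task per shard; Dagger's scheduler places them on the available
# processors (CPU threads here, GPUs when so configured).
tasks = [Dagger.@spawn shard_grad(model, x) for x in shards]
grads = fetch.(tasks)                                  # ready to be averaged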

Future Work

There are a lot of interesting techniques out there for improving time to model convergence, such as super-convergence and the ZeRO optimiser, which would be good to have implementations of in Julia. Further, there is ongoing work to incorporate automatic DtoD (device-to-device) transfers between P2P-connected devices in CUDA.jl. Using this, we can simplify the DDP implementation significantly, since we no longer need to route data through the CPU when transferring between devices, a critical step in synchronising gradients to train our models. This eliminates much of the serialisation, which should translate into faster convergence as well.
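As a rough sketch of what that enables (assuming a machine with at least two GPUs; whether the copy below is a true peer-to-peer DtoD transfer depends on the hardware, driver, and CUDA.jl version):
using CUDA

CUDA.device!(0)
grads_gpu0 = CUDA.rand(Float32, 1024)    # gradients living on GPU 0

CUDA.device!(1)
buffer_gpu1 = CUDA.zeros(Float32, 1024)  # destination buffer on GPU 1

# Cross-device copy; with P2P support this avoids staging through the CPU.
copyto!(buffer_gpu1, grads_gpu0)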