JAXPiracy (Directed by Saurav and Soumik)
Comprehensive set of tutorials in JAX: Basics to Advanced
Welcome to a series of (ongoing) tutorials on JAX + Flax, where we'll work through a series of machine learning recipes and use cases to help you become proficient at writing code in this new ecosystem. But why Flax, you ask? Don't we already have TensorFlow and PyTorch? Don't we have enough? Well, let us try to convince you why you should give the (JAX + Flax) ecosystem a shot.
Most of us can agree that the next era of deep learning will be about scaling. Researchers and industrial labs will continue to push the boundaries of current hardware and scale up. Models are getting larger, they need more data, and they are harder to train stably. None of this can happen on a single accelerator, so we need to look at multi-accelerator systems.
Wouldn't it be nice if you could write your code in an accelerator-agnostic way? Most current frameworks make this hard. In PyTorch you need to split your model, move your tensors onto multiple devices, and then somehow figure out how to sync metrics, log concurrently, and so on. TensorFlow makes it a bit easier with its Strategy submodule, but fine-grained modifications become harder without significant extra code.
Enter (JAX + Flax). With Flax, you write your forward pass for a single machine as a plain function, then convert that function with a JAX transformation (such as jax.pmap) and have it run on any number of devices. With considerably fewer modifications, the same code runs on any setup (CPU, GPU, TPU, etc.).
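To make that concrete, here's a minimal sketch of the idea, assuming a toy linear model; the names `forward` and `loss_fn` are our own illustrative functions, not part of any library. The only multi-device work is wrapping the function in `jax.pmap` and giving every argument a leading device axis:

```python
import jax
import jax.numpy as jnp

# A forward pass written as if for a single device: a tiny linear model.
def forward(params, x):
    return x @ params["w"] + params["b"]

def loss_fn(params, x, y):
    preds = forward(params, x)
    return jnp.mean((preds - y) ** 2)

# jax.pmap transforms the single-device function into one that runs
# in parallel across all locally available devices.
parallel_loss = jax.pmap(loss_fn)

n_devices = jax.local_device_count()

# Toy parameters, replicated once per device.
params = {"w": jnp.ones((4, 1)), "b": jnp.zeros((1,))}
replicated = jax.tree_util.tree_map(
    lambda a: jnp.stack([a] * n_devices), params
)

# Shard a toy batch along a new leading device axis.
x = jnp.ones((n_devices, 8, 4))
y = jnp.ones((n_devices, 8, 1))

print(parallel_loss(replicated, x, y))  # one loss value per device
```

The same script runs unchanged whether `jax.local_device_count()` reports one CPU or eight TPU cores; only the leading axis of the data changes.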
Now, we understand the current public sentiment around JAX, but that is precisely the motivation behind this project. At the moment JAX may seem too nascent to switch to, and it can feel very different from what deep learning engineers are used to, but we plan to address these concerns over this series and maybe motivate you to start thinking about JAX.
<INSERT GENERIC INTRO>
- In the first post, we'll cover how to write a basic training loop in JAX and Flax: a simple end-to-end training and evaluation pipeline in JAX, Flax, and Optax for image classification (see the sketch after this list for a taste). We'll also explore how the Flax-based training and evaluation pipeline differs from its equivalent in existing popular frameworks such as TensorFlow and PyTorch.
- In the second post we'll cover ....
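As a preview of the first post, here's a minimal sketch of a Flax + Optax training step. The `MLP` model, the flattened 784-dimensional input, and the `batch["image"]` / `batch["label"]` layout are illustrative assumptions for this sketch, not the actual pipeline from the post:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

# A toy classifier; the first post builds a proper image model.
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        return nn.Dense(10)(x)

def create_state(rng, learning_rate=1e-3):
    model = MLP()
    # Initialize parameters with a dummy batch of flattened images.
    params = model.init(rng, jnp.ones((1, 784)))["params"]
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx
    )

@jax.jit
def train_step(state, batch):
    # batch["image"]: float array (B, 784); batch["label"]: int array (B,)
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, batch["image"])
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch["label"]
        ).mean()
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss
```

Note how the whole training step is a pure function of `(state, batch)`: that is exactly what lets you later swap `jax.jit` for `jax.pmap` and scale the same code across devices.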
Roll Credits