
Fluid Dynamics with JAX

Created on April 16 | Last edited on April 26
If you're into computational fluid dynamics (CFD), you know how important it is to write performant code. However, achieving high performance often means sacrificing readability and ease of use. That's why I was intrigued when I discovered JAX, a Python library with a NumPy-like API that compiles numerical code for CPUs, GPUs and TPUs. JAX's just-in-time (JIT) compilation and automatic differentiation make it particularly well suited to CFD. JIT compilation addresses performance, which is critical in these simulations, and gradients are essential for optimization and uncertainty quantification.
My motivation for this project was to learn how to implement the Lattice Boltzmann Method (LBM) in JAX. I was particularly interested in exploring what JAX brings to CFD simulations, such as performance, modularity, and ease of use. I also want to experiment with reinforcement learning (RL) in the context of fluid dynamics, for example to optimize flow control or to design new turbulence models.
You can check out my GitHub repository. It contains the code I used to generate the simulations in this report, as well as some work in progress on reinforcement learning for fluids.


Methods

Lattice Boltzmann Formulation

Overview

One popular technique for simulating fluids is the Lattice Boltzmann Method (LBM). Rather than solving the Navier-Stokes equations directly, it evolves particle distribution functions on a lattice using principles from statistical mechanics. The collision step is local to each lattice point, and the streaming step only exchanges data with nearest neighbors. For this reason, the LBM is highly parallelizable and maps efficiently onto modern hardware such as GPUs. It also handles complex geometries and boundary conditions easily, which makes it well suited to a wide range of fluid dynamics applications.

Governing Equations

The Boltzmann equation describes the statistical behavior of particles in a gas or fluid. Without external forces, it can be written as:

$$\frac{\partial f}{\partial t} + \boldsymbol{\xi} \cdot \nabla_{\mathbf{x}} f = \Omega$$

Here $f(\boldsymbol{\xi}, \mathbf{x}, t)$ is the distribution function. It gives the density of microscopic fluid particles that move with velocity $\boldsymbol{\xi}$ at position $\mathbf{x}$ and time $t$. The collision term $\Omega$ describes the interactions between the particles. The microscopic particle velocity is written $\boldsymbol{\xi}$ so that $\mathbf{u}$ remains free for the macroscopic flow velocity defined below.
To simulate the fluid with the LBM, we discretize the Boltzmann equation in space, time, and velocity. We define a lattice, which is a regular grid of discrete points $\mathbf{x}_j$. We also define a small set of discrete velocities $\mathbf{e}_i$ that connect each lattice point to the neighboring points a particle can reach in one time step. This gives the discrete distribution functions $f_i(\mathbf{x}_j,t_k)$.
By taking moments of the distribution function, we can extract macroscopic variables such as density, velocity, and pressure. For example, the zeroth moment of the distribution function represents the density at a given lattice point and time:
$$\rho(\mathbf{x}_j,t_k) = \sum_i f_i(\mathbf{x}_j,t_k)$$

Similarly, the first moment of the distribution function represents the momentum density:
$$\rho(\mathbf{x}_j,t_k)\, \mathbf{u}(\mathbf{x}_j,t_k) = \sum_i \mathbf{e}_i f_i(\mathbf{x}_j,t_k)$$
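The two sums above can be evaluated at a single lattice node in a few lines. The sketch below uses the D2Q9 velocity set from the implementation section. The populations are hypothetical values chosen by hand: the rest state plus a small excess moving in +x. The pressure line assumes the usual isothermal LBM equation of state $p = \rho c_s^2$ with $c_s^2 = 1/3$.

```python
import jax.numpy as jnp

# D2Q9 discrete velocities e_i as a (2, 9) array: row 0 = x, row 1 = y
e = jnp.array([[0, 1, 0, -1, 0, 1, -1, -1, 1],
               [0, 0, 1, 0, -1, 1, 1, -1, -1]], dtype=jnp.float32)
cs2 = 1.0 / 3.0  # squared lattice speed of sound for D2Q9

# Hypothetical populations f_i at one node: the rest state (f_i = w_i)
# plus a small excess in direction 1, which points along +x
f = jnp.array([4/9, 1/9 + 0.01, 1/9, 1/9, 1/9, 1/36, 1/36, 1/36, 1/36])

rho = jnp.sum(f)        # zeroth moment: density
momentum = e @ f        # first moment: rho * u, shape (2,)
u = momentum / rho      # macroscopic velocity
p = rho * cs2           # pressure from the isothermal equation of state (assumed)
```

The extra 0.01 in direction 1 raises the density to 1.01 and produces a momentum of 0.01 along x. Every other population is balanced by its opposite and contributes nothing.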

The LBM advances the distribution functions by alternating two steps. In the collision step, a collision operator models the interactions between particles by redistributing the populations $f_i$ at each lattice point. In the streaming (advection) step, each population $f_i$ moves to the neighboring lattice point in the direction of $\mathbf{e}_i$. Repeating these two steps evolves the distribution functions in time, and their moments give the macroscopic density, velocity, and pressure.
With the widely used single-relaxation-time (BGK) collision operator, the combined update reads:
$$f_i(\mathbf{x}_j+\mathbf{e}_i,t_{k+1}) = f_i(\mathbf{x}_j,t_k) - \frac{1}{\tau}\left[f_i(\mathbf{x}_j,t_k) - f_i^{eq}(\mathbf{x}_j,t_k)\right]$$

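Here $\tau$ is a relaxation time that determines how fast the distribution functions approach their equilibrium values, and $f_i^{eq}(\mathbf{x}_j,t_k)$ is the local equilibrium distribution function:

$$f_i^{eq}(\mathbf{x},t) = w_i \rho(\mathbf{x},t)\left[1 + \frac{\mathbf{e}_i \cdot \mathbf{u}(\mathbf{x},t)}{c_s^2} + \frac{(\mathbf{e}_i \cdot \mathbf{u}(\mathbf{x},t))^2}{2c_s^4} - \frac{\mathbf{u}(\mathbf{x},t) \cdot \mathbf{u}(\mathbf{x},t)}{2c_s^2}\right]$$

In this formula, $w_i$ are weights that depend on the discrete velocity $\mathbf{e}_i$, and $c_s$ is the lattice speed of sound. For the D2Q9 lattice used below, the weights are $4/9$ for the rest velocity, $1/9$ for the four axis-aligned velocities and $1/36$ for the four diagonal ones, and $c_s^2 = 1/3$. In lattice units, the relaxation time sets the kinematic viscosity, $\nu = c_s^2(\tau - 1/2)$, so $\tau$ must stay above $1/2$. The pressure follows from the density through $p = \rho c_s^2$.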
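A quick numerical check shows that the equilibrium reproduces the density and momentum it was built from. This works because the D2Q9 weights satisfy $\sum_i w_i = 1$, $\sum_i w_i \mathbf{e}_i = 0$ and $\sum_i w_i e_{i\alpha} e_{i\beta} = c_s^2 \delta_{\alpha\beta}$. The second moment also comes out as $\rho c_s^2 \mathbf{I} + \rho\,\mathbf{u}\mathbf{u}$. The sketch evaluates $f_i^{eq}$ at one node for illustrative values of $\rho$ and $\mathbf{u}$. The helper name f_eq is hypothetical.

```python
import jax.numpy as jnp

# D2Q9 stencil: velocities (2, 9), weights (9,), c_s^2
e = jnp.array([[0, 1, 0, -1, 0, 1, -1, -1, 1],
               [0, 0, 1, 0, -1, 1, 1, -1, -1]], dtype=jnp.float32)
w = jnp.array([4/9] + [1/9] * 4 + [1/36] * 4)
cs2 = 1.0 / 3.0

def f_eq(rho, u):
    """Second-order equilibrium at a single node; rho scalar, u shape (2,)."""
    eu = u @ e                  # e_i . u for every direction i, shape (9,)
    uu = jnp.dot(u, u)          # |u|^2
    return w * rho * (1 + eu / cs2 + eu**2 / (2 * cs2**2) - uu / (2 * cs2))

rho, u = 1.2, jnp.array([0.05, -0.02])      # illustrative values
feq = f_eq(rho, u)

rho_back = jnp.sum(feq)                      # zeroth moment, should equal rho
mom_back = e @ feq                           # first moment, should equal rho * u
Pi = jnp.einsum("ai,bi,i->ab", e, e, feq)    # second moment, rho cs^2 I + rho u u
```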

Boundary Conditions

In the Lattice Boltzmann Method, boundary conditions are imposed on the distribution functions at the lattice points on the boundary of the simulation domain. They ensure that the flow behaves correctly at the boundary. The most commonly used boundary conditions in the LBM are no-slip, inflow, and outflow boundaries.
  • No-Slip
The no-slip boundary condition requires the fluid velocity to be zero at the boundary. It is imposed with the bounce-back method: populations that would stream into a solid lattice point are sent back to the fluid point they came from with their direction reversed. In other words, the population leaving along $\mathbf{e}_i$ returns along the opposite direction $\mathbf{e}_{\bar{i}} = -\mathbf{e}_i$. This mimics a solid wall that fluid cannot flow through.
  • Inflow
The inflow boundary condition prescribes the velocity or pressure at the boundary. A simple way to impose it is to set the distribution functions at the boundary lattice points to their equilibrium values for the prescribed velocity and density (or pressure).
  • Outflow
The outflow boundary condition lets fluid leave the domain without reflecting back into it. It is typically imposed by extrapolating the distribution functions from the interior of the domain to the boundary lattice points. The simplest choice copies the values from the neighboring interior column, which is a zero-gradient condition.
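Here is a minimal sketch of these three conditions on an (X, Y, Q) D2Q9 array, like the one used in the implementation below. It makes three assumptions:

- full-way bounce-back is applied at solid nodes after streaming;
- the inlet is an equilibrium condition on the left column (x = 0);
- the outlet is a zero-gradient condition on the right column.

The function names bounce_back, equilibrium_inlet and zero_gradient_outlet are hypothetical and are not taken from the author's repository.

```python
import jax.numpy as jnp

# D2Q9 stencil, same ordering as the D2Q9 class in this report
E = jnp.array([[0, 1, 0, -1, 0, 1, -1, -1, 1],
               [0, 0, 1, 0, -1, 1, 1, -1, -1]], dtype=jnp.float32)
W = jnp.array([4/9] + [1/9] * 4 + [1/36] * 4)
OPPOSITE = jnp.array([0, 3, 4, 1, 2, 7, 8, 5, 6])   # index of -e_i
CS2 = 1.0 / 3.0

def equilibrium(rho, u):
    """rho: (..., 1), u: (..., 2) -> f_eq: (..., 9)."""
    eu = jnp.einsum("...d,dq->...q", u, E)
    uu = jnp.sum(u**2, axis=-1, keepdims=True)
    return rho * W * (1 + eu / CS2 + eu**2 / (2 * CS2**2) - uu / (2 * CS2))

def bounce_back(df, solid):
    """No-slip (full-way bounce-back): at solid nodes, reverse every population,
    so whatever streamed into the wall streams back out on the next step."""
    return jnp.where(solid[..., None], df[..., OPPOSITE], df)

def equilibrium_inlet(df, rho_in, u_in):
    """Inflow: overwrite the x = 0 column with the equilibrium for (rho_in, u_in)."""
    ny = df.shape[1]
    rho = jnp.full((ny, 1), rho_in)
    u = jnp.broadcast_to(jnp.asarray(u_in, dtype=df.dtype), (ny, 2))
    return df.at[0].set(equilibrium(rho, u))

def zero_gradient_outlet(df):
    """Outflow: copy the last interior column onto the x = -1 boundary column."""
    return df.at[-1].set(df[-2])
```

In a time step, these would typically be applied after streaming. Because the functions only use jnp operations and .at[] updates, they can be called inside a jax.jit-compiled step.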

Implementation in JAX

In this section, we walk through a minimal Lattice Boltzmann Method (LBM) simulation in JAX. It is meant as a starting point for readers who are new to LBM or JAX and want to see a basic example of how the two can be combined.
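We start with a simple class that holds the lattice stencil information, such as the discrete velocities and weights. The D2Q9 stencil (2 dimensions, 9 discrete velocities) is commonly used in 2D LBM simulations. To make the class usable with JAX transformations such as jax.jit, we use the chex.dataclass decorator, which registers the dataclass as a JAX pytree. The stencil constants below are plain class attributes without type annotations, so they are not dataclass fields. JAX therefore treats them as compile-time constants rather than traced inputs.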
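```python
import chex
import jax
import jax.numpy as jnp

@chex.dataclass
class D2Q9:
    r"""
    6   2   5
      \ | /
    3 - 0 - 1
      / | \
    7   4   8
    """
    D = 2              # spatial dimensions
    Q = 9              # number of discrete velocities
    cs = 1 / 3**0.5    # lattice speed of sound, c_s^2 = 1/3
    # Discrete velocities e_i, shape (D, Q): row 0 = x, row 1 = y
    e = jnp.array([[0, 1, 0, -1, 0, 1, -1, -1, 1],
                   [0, 0, 1, 0, -1, 1, 1, -1, -1]])
    # Lattice weights w_i: rest, axis-aligned, diagonal
    w = jnp.array(
        [
            4.0 / 9, 1.0 / 9, 1.0 / 9, 1.0 / 9, 1.0 / 9,
            1.0 / 36, 1.0 / 36, 1.0 / 36, 1.0 / 36,
        ]
    )
    # opposite[i] is the index of -e_i (used for bounce-back)
    opposite = jnp.array([0, 3, 4, 1, 2, 7, 8, 5, 6])
```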
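Before using the stencil, it can be worth checking it numerically. The sketch below copies the arrays from the D2Q9 class above so that it runs on its own, without chex. It then verifies four properties that the collision and bounce-back steps rely on:

- the weights sum to one;
- opposite really maps each velocity to its negative;
- the first moment of the weights vanishes;
- the second moment is isotropic with $c_s^2 = 1/3$.

```python
import jax.numpy as jnp

# Arrays copied from the D2Q9 class above
e = jnp.array([[0, 1, 0, -1, 0, 1, -1, -1, 1],
               [0, 0, 1, 0, -1, 1, 1, -1, -1]])
w = jnp.array([4.0 / 9] + [1.0 / 9] * 4 + [1.0 / 36] * 4)
opposite = jnp.array([0, 3, 4, 1, 2, 7, 8, 5, 6])
cs = 1 / 3**0.5

weights_sum = jnp.sum(w)                              # should be 1
is_opposite = bool(jnp.all(e[:, opposite] == -e))     # e_opp(i) == -e_i for all i
first_moment = jnp.einsum("q,dq->d", w, e)            # should be (0, 0)
second_moment = jnp.einsum("q,aq,bq->ab", w, e, e)    # should be cs^2 * identity
```

A typo in any of these arrays would break mass or momentum conservation silently. Checks like these are therefore cheap insurance to keep in a test suite.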
Next, let's implement the macroscopic variables, the equilibrium, and the collision operator. jnp.einsum lets us express the tensor contractions over the lattice and velocity axes concisely, and XLA can optimize them when the code is JIT-compiled:
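```python
def macroscopic_variables(stencil, df):
    # Zeroth-order moment -> density, shape (X, Y, 1)
    rho = jnp.einsum("XYQ->XY", df)[:, :, jnp.newaxis]
    # First-order moment (momentum) divided by density -> velocity, shape (X, Y, D)
    u = jnp.einsum("XYQ,dQ->XYd", df, stencil.e) / rho
    return rho, u

def equilibrium(stencil, rho, u):
    # |u|^2 as a sum of squares: unlike jnp.linalg.norm(u) ** 2, its gradient
    # stays finite at u = 0, which matters when differentiating through the solver
    u_norm2 = jnp.sum(u**2, axis=-1, keepdims=True)
    # e_i . u for every direction i, shape (X, Y, Q)
    e_dot_u = jnp.einsum("dQ,XYd->XYQ", stencil.e, u)
    df_eq = rho * stencil.w * (
        1
        + e_dot_u / stencil.cs**2
        + 0.5 * e_dot_u**2 / stencil.cs**4
        - 0.5 * u_norm2 / stencil.cs**2
    )
    return df_eq

def collide(stencil, df, tau):
    # BGK collision: relax each population toward its local equilibrium
    rho, u = macroscopic_variables(stencil, df)
    eq = equilibrium(stencil, rho, u)
    df = df - 1.0 / tau * (df - eq)
    return df
```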
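$f^{eq}$ is built from each node's own $\rho$ and $\mathbf{u}$. As a result, the BGK collision should leave density and momentum unchanged at every node and only relax the higher moments. The self-contained sketch below checks this on a random, positive distribution. It uses a standalone variant of the functions above, with the stencil stored in module-level arrays instead of the chex class; that choice is made only to drop the chex dependency.

```python
import jax
import jax.numpy as jnp

# D2Q9 stencil as module-level arrays
E = jnp.array([[0, 1, 0, -1, 0, 1, -1, -1, 1],
               [0, 0, 1, 0, -1, 1, 1, -1, -1]], dtype=jnp.float32)
W = jnp.array([4/9] + [1/9] * 4 + [1/36] * 4)
CS2 = 1.0 / 3.0

def macroscopic_variables(df):
    rho = jnp.sum(df, axis=-1, keepdims=True)             # (X, Y, 1)
    u = jnp.einsum("xyq,dq->xyd", df, E) / rho             # (X, Y, 2)
    return rho, u

def equilibrium(rho, u):
    eu = jnp.einsum("dq,xyd->xyq", E, u)
    uu = jnp.sum(u**2, axis=-1, keepdims=True)
    return rho * W * (1 + eu / CS2 + eu**2 / (2 * CS2**2) - uu / (2 * CS2))

def collide(df, tau):
    rho, u = macroscopic_variables(df)
    return df - (df - equilibrium(rho, u)) / tau            # BGK relaxation

# Random positive populations around the rest state (illustrative input)
key = jax.random.PRNGKey(0)
df = W * (1.0 + 0.1 * jax.random.uniform(key, (16, 16, 9)))
df_post = collide(df, tau=0.8)

rho0, u0 = macroscopic_variables(df)
rho1, u1 = macroscopic_variables(df_post)
```

The populations themselves change, but $\rho$ and $\rho\mathbf{u}$ agree before and after collision up to floating-point round-off.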
In the stream function, we loop over the Q discrete velocities of the stencil. For each one, jnp.roll shifts the corresponding population array by its lattice velocity $\mathbf{e}_i$, which moves the populations to the neighboring lattice points. Because jnp.roll wraps around, the domain is periodic in both directions. The df.at[..., i].set(...) syntax is JAX's functional replacement for in-place assignment: it returns an updated copy of the array, and under jax.jit XLA can usually perform the update in place. Q is a static Python integer, so the loop is unrolled when the function is traced and adds no Python overhead at run time. Streaming is a memory-bound operation whose cost is dominated by moving data, not by the loop.
```python
def stream(stencil, df):
    for i in range(stencil.Q):
        streamed = jnp.roll(
            a=df[..., i],
            shift=stencil.e[:, i],
            axis=[k for k in range(stencil.D)],
        )
        df = df.at[..., i].set(streamed)
    return df
```
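A quick way to confirm that streaming moves each population along its own velocity is to place a single unit of mass in one direction and see where it lands. The sketch below is a standalone version of the loop above. It stores the stencil in a module-level array and converts the shifts to Python integers, both only to keep the snippet self-contained. With periodic jnp.roll, the total mass is conserved exactly.

```python
import jax.numpy as jnp

# D2Q9 velocities, same ordering as the D2Q9 class
E = jnp.array([[0, 1, 0, -1, 0, 1, -1, -1, 1],
               [0, 0, 1, 0, -1, 1, 1, -1, -1]])
Q = 9

def stream(df):
    for i in range(Q):  # unrolled at trace time because Q is a Python int
        shift = tuple(int(s) for s in E[:, i])
        df = df.at[..., i].set(jnp.roll(df[..., i], shift=shift, axis=(0, 1)))
    return df

df = jnp.zeros((8, 8, Q))
df = df.at[2, 3, 5].set(1.0)   # one unit of mass moving with e_5 = (1, 1)
df = df.at[0, 0, 7].set(1.0)   # one unit moving with e_7 = (-1, -1); it wraps around

out = stream(df)
```

After one step, the first unit sits at (3, 4). The second wraps around the periodic boundary to (7, 7).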
With the collision and streaming operators implemented, we can initialize the distribution functions and evolve them by repeatedly applying the two operators. To speed up the simulation, we decorate the LBM step function with jax.jit. This compiles the function with XLA for the hardware being used (CPU, GPU, or TPU) and removes the Python overhead of dispatching each operation separately.
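```python
@jax.jit
def step(stencil, df, tau):
    df = collide(stencil, df, tau)
    df = stream(stencil, df)
    return df

stencil = D2Q9()
# Uniform density and a fluid at rest on a 64 x 64 periodic grid...
rho = jnp.ones((64, 64, 1))
u = jnp.zeros((64, 64, 2))
# ...except for an x-velocity of 0.05 (lattice units) along the row y = 0
u = u.at[:, 0, 0].set(0.05)

# Start from the equilibrium distribution of this initial state
df = equilibrium(stencil, rho, u)

tau = 1.5  # relaxation time; kinematic viscosity nu = (tau - 1/2) / 3
for i in range(10):
    df = step(stencil, df, tau)
```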
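The Python for loop above launches one compiled call per time step, which is fine for 10 steps. For long runs, a common alternative is to put the whole time loop inside one compiled function with jax.lax.fori_loop. This is sketched here as an option and is not necessarily what the author's repository does. The names run and step below are hypothetical standalone variants, and the stencil is stored in module-level arrays so the snippet runs on its own. On this fully periodic domain the total mass should stay constant, which makes a convenient sanity check.

```python
from functools import partial

import jax
import jax.numpy as jnp

# D2Q9 stencil as module-level arrays
E = jnp.array([[0, 1, 0, -1, 0, 1, -1, -1, 1],
               [0, 0, 1, 0, -1, 1, 1, -1, -1]], dtype=jnp.float32)
W = jnp.array([4/9] + [1/9] * 4 + [1/36] * 4)
CS2 = 1.0 / 3.0
SHIFTS = [(int(E[0, i]), int(E[1, i])) for i in range(9)]  # static roll shifts

def equilibrium(rho, u):
    eu = jnp.einsum("dq,xyd->xyq", E, u)
    uu = jnp.sum(u**2, axis=-1, keepdims=True)
    return rho * W * (1 + eu / CS2 + eu**2 / (2 * CS2**2) - uu / (2 * CS2))

def step(df, tau):
    rho = jnp.sum(df, axis=-1, keepdims=True)
    u = jnp.einsum("xyq,dq->xyd", df, E) / rho
    df = df - (df - equilibrium(rho, u)) / tau                # collide (BGK)
    return jnp.stack([jnp.roll(df[..., i], SHIFTS[i], axis=(0, 1))
                      for i in range(9)], axis=-1)            # stream (periodic)

@partial(jax.jit, static_argnums=2)
def run(df, tau, n_steps):
    # fori_loop keeps the entire time loop inside one XLA computation
    return jax.lax.fori_loop(0, n_steps, lambda _, f: step(f, tau), df)

rho = jnp.ones((64, 64, 1))
u = jnp.zeros((64, 64, 2)).at[:, 0, 0].set(0.05)   # same initial state as above
df = run(equilibrium(rho, u), 1.5, 500)
```

Making n_steps static recompiles run for each new step count. Leaving it dynamic would also work, because fori_loop then lowers to a while loop.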
And that's it: we have implemented a basic LBM simulation in JAX. This implementation can be extended in many ways, for example by adding boundary conditions, using multiple-relaxation-time (MRT) collision operators, or coupling several lattices to model different physical processes. The basic principles remain the same. The collision operator relaxes the distribution functions toward equilibrium, and the streaming operator advects them along the lattice velocities.

Simulation Results



Rayleigh-Bénard Convection

Let's look at simulations of the Rayleigh-Bénard convection problem: natural convection of a fluid layer between two horizontal plates, heated from below and cooled from above. Once the temperature difference is large enough, thermal plumes rise from the heated plate and descend from the cooled plate, and the fluid settles into convective motion.
For this simulation, a coupled-lattice approach was used to model the fluid dynamics and the temperature evolution simultaneously. Two separate lattice stencils were used, one for the fluid and one for the temperature. They were coupled through a nonlinear interaction: the temperature field generates a buoyancy force on the fluid, and the fluid velocity in turn advects the temperature. This approach captures the interplay between flow and heat transport in a computationally efficient way.
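The report does not show the coupled solver's code. The following is therefore only a sketch of one common way to couple the two lattices, a double-distribution model with a Boussinesq buoyancy term, and not necessarily the author's approach. It assumes two things:

- a D2Q5 lattice for temperature, whose equilibrium depends on the fluid velocity;
- a buoyancy force $\mathbf{F} = \rho\, g\beta\,(T - T_0)\,\hat{\mathbf{y}}$, which would be fed back into the fluid collision through a forcing scheme of your choice.

The parameters g_beta and T0 and both function names are hypothetical.

```python
import jax.numpy as jnp

# Assumed D2Q5 lattice for the temperature field (rest, +x, +y, -x, -y)
E5 = jnp.array([[0, 1, 0, -1, 0],
                [0, 0, 1, 0, -1]], dtype=jnp.float32)
W5 = jnp.array([1/3, 1/6, 1/6, 1/6, 1/6])
CS2_T = 1.0 / 3.0   # sum_i W5_i * e_ix^2 = 1/3 for these weights

def temperature_equilibrium(T, u):
    """Fluid -> temperature coupling: g_i^eq = w_i T (1 + e_i.u / cs^2).
    T: (X, Y, 1), u: (X, Y, 2) -> (X, Y, 5)."""
    eu = jnp.einsum("dq,xyd->xyq", E5, u)
    return W5 * T * (1 + eu / CS2_T)

def buoyancy_force(rho, T, g_beta, T0):
    """Temperature -> fluid coupling: Boussinesq force rho g beta (T - T0) along +y.
    rho, T: (X, Y, 1) -> (X, Y, 2)."""
    fy = rho * g_beta * (T - T0)
    return jnp.concatenate([jnp.zeros_like(fy), fy], axis=-1)
```

The zeroth moment of temperature_equilibrium returns $T$ and its first moment returns $T\mathbf{u}$, which is what makes the temperature lattice advect heat with the flow. Fluid hotter than T0 is pushed upward, and colder fluid sinks.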
The videos below show the temperature evolution of the fluid for different Rayleigh numbers. The Rayleigh number measures the strength of the buoyancy-driven thermal forcing relative to viscous and thermal diffusion. As it increases, the convective motion becomes more vigorous and the thermal plumes more pronounced, producing richer and more complex temperature patterns.
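The Rayleigh number is $\mathrm{Ra} = g\beta\,\Delta T\,H^3/(\nu\kappa)$. In lattice units ($\Delta x = \Delta t = 1$), the BGK relaxation times set $\nu = c_s^2(\tau_f - 1/2)$ and $\kappa = c_s^2(\tau_g - 1/2)$. A common way to choose parameters fixes the free-fall velocity $U = \sqrt{g\beta\,\Delta T\,H}$ small enough for the low-Mach assumption. It then solves for $\nu$ and $\kappa$ from Ra and the Prandtl number $\mathrm{Pr} = \nu/\kappa$. This is sketched below as an assumed setup, not the settings used for these videos. The function name and the default U = 0.1 are hypothetical.

```python
import math

CS2 = 1.0 / 3.0  # lattice c_s^2 for both the D2Q9 fluid and a D2Q5 thermal lattice

def thermal_lbm_parameters(Ra, Pr, H, U=0.1):
    """Return (tau_f, tau_g, g_beta_dT) for a target Rayleigh and Prandtl number.
    H: layer height in lattice units; U: free-fall velocity sqrt(g beta dT H)."""
    g_beta_dT = U**2 / H                    # from U^2 = g beta dT H
    nu = U * H * math.sqrt(Pr / Ra)         # from Ra = U^2 H^2 Pr / nu^2
    kappa = nu / Pr                         # thermal diffusivity
    tau_f = nu / CS2 + 0.5                  # nu = cs^2 (tau_f - 1/2)
    tau_g = kappa / CS2 + 0.5               # kappa = cs^2 (tau_g - 1/2)
    return tau_f, tau_g, g_beta_dT

tau_f, tau_g, g_beta_dT = thermal_lbm_parameters(Ra=1e5, Pr=0.71, H=100)
```

Higher Rayleigh numbers push both relaxation times toward 1/2, where BGK tends to become unstable. Keeping them comfortably above 1/2 usually requires a taller grid (larger H).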



Von Kármán Vortex Street