
CleanRL SAC (JAX)

Adapted from the SBX (SB3 + JAX) implementation.
Created on October 23 | Last edited on October 24

Performance (sample efficiency)

HalfCheetah-v2



[Plot: Episodic Return (0 to 10,000) vs. Timestep (0 to 1M); group metrics computed from the first 10 groups]
exp_name: td3_continuous_action
exp_name: td3_continuous_action_jax
exp_name: ddpg_continuous_action_jax
exp_name: ddpg_continuous_action
exp_name: rpo_continuous_action_alpha_0_1
exp_name: rpo_continuous_action_alpha_0_05
exp_name: rpo_continuous_action_alpha_0_01
exp_name: ppo_continuous_action_8M
exp_name: rpo_continuous_action
exp_name: sac_jax
Run set: 104 runs


Hopper-v2


Run set: 108 runs


Walker2d-v2


Run set: 106 runs


Runtime

Note: the SAC (JAX) runs were executed concurrently on the same machine, so the reported steps per second are a lower bound on what a single isolated run would achieve.
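To make the "lower bound" point concrete, here is a minimal sketch of how a steps-per-second figure is commonly computed in CleanRL-style training loops: total environment steps divided by wall-clock time since the start of training. The helper name `steps_per_second` and its `now` parameter are hypothetical, introduced here only for illustration. The exact logging code behind these runs is not shown in this report.

```python
import time
from typing import Optional


def steps_per_second(global_step: int, start_time: float,
                     now: Optional[float] = None) -> int:
    """Average environment steps per wall-clock second since start_time.

    Hypothetical helper: `now` can be passed explicitly for testing;
    otherwise the current wall-clock time is used.
    """
    if now is None:
        now = time.time()
    elapsed = now - start_time
    if elapsed <= 0:
        # Avoid division by zero at the very first step.
        return 0
    return int(global_step / elapsed)


# Usage inside a training loop (sketch):
# start_time = time.time()
# for global_step in range(1, total_timesteps + 1):
#     ...  # environment step and agent update
#     sps = steps_per_second(global_step, start_time)
```

Because the denominator is wall-clock time, any slowdown from other runs sharing the same CPU or GPU inflates the elapsed time and lowers the measured value, which is why the figure understates the throughput of an isolated run.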
Further optimizations are possible, at the cost of some added complexity.


HalfCheetah-v2


Run set: 104 runs


Hopper-v2


Run set: 108 runs


Walker2d-v2


Run set: 106 runs