Skip to main content

data transfer optimization

Created on February 27|Last edited on February 27
It's better to accumulate the data on the device into a single bulk array per field (e.g., one `obs` array rather than a list of per-step arrays) before sending it to the queue.

Expand 280 lines ...
281
281
    return b_obs, b_actions, b_logprobs, b_advantages, b_returns
282
282
283
283
284
+
@jax.jit
def make_bulk_array(
    obs: list,
    values: list,
    actions: list,
    logprobs: list,
):
    """Convert the accumulated rollout lists into single JAX arrays.

    Each argument is a list (presumably one entry per rollout step — confirm
    against the caller). Each list is turned into one array with
    ``jnp.asarray`` inside a jitted function. The caller can then forward a
    single bulk array per field instead of a list of small arrays.

    Returns:
        A 4-tuple ``(obs, values, actions, logprobs)`` of JAX arrays, in the
        same order as the arguments.
    """
    # The tuple order must match how the caller unpacks the result:
    # obs, values, actions, logprobs.
    return tuple(
        jnp.asarray(field) for field in (obs, values, actions, logprobs)
    )
296
+
297
+
284
298
def rollout(
285
299
    key: jax.random.PRNGKey,
286
300
    args,
Expand 110 lines ...
397
411
        writer.add_scalar("stats/inference_time", inference_time, global_step)
398
412
        writer.add_scalar("stats/storage_time", storage_time, global_step)
399
413
        writer.add_scalar("stats/env_send_time", env_send_time, global_step)
414
+
        # `make_bulk_array` is actually important. It accumulates the data from the lists
415
+
        # into single bulk arrays, which later makes transferring the data to the learner's
416
+
        # device slightly faster. See https://wandb.ai/costa-huang/cleanRL/reports/data-transfer-optimization--VmlldzozNjU5MTg1.
417
+
        obs, values, actions, logprobs = make_bulk_array(
418
+
            obs,
419
+
            values,
420
+
            actions,
421
+
            logprobs,
422
+
        )
423
+
400
424
        payload = (
401
425
            global_step,
402
426
            actor_policy_version,
403
427
            update,
404
428
            obs,
405
-
            dones,
406
429
            values,
407
430
            actions,
408
431
            logprobs,
432
+
            dones,
409
433
            env_ids,
410
434
            rewards,
411
435
            np.mean(params_queue_get_time),
Expand 305 lines ...
717
741
                actor_policy_version,
718
742
                update,
719
743
                obs,
720
-
                dones,
721
744
                values,
722
745
                actions,
723
746
                logprobs,
747
+
                dones,
724
748
                env_ids,
725
749
                rewards,
726
750
                avg_params_queue_get_time,
Expand 106 lines ...
[Chart — x-axis: Time (seconds), ticks 20, 40, 60, 80, 100, 120; y-axis ticks: 5000, 10000, 15000]
baseline
1
make_bulk_array
1
Run set 3
1



baseline
1
slight optimization
1
slight optimization (looks like it doesn't work)
1