Skip to main content

Clashluke's group workspace

Timestamps visible
2023-05-07 09:18:13
--------------------
2023-05-07 09:18:13
The above exception was the direct cause of the following exception:
2023-05-07 09:18:13
Traceback (most recent call last):
2023-05-07 09:18:13
  File "main.py", line 4, in <module>
2023-05-07 09:18:13
    main()
2023-05-07 09:18:13
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 202, in main
2023-05-07 09:18:13
    wctx = step(dat)
2023-05-07 09:18:13
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 112, in __call__
2023-05-07 09:18:13
    self.wctx = WhileTrainContext(self.step(wctx.serialize()))
2023-05-07 09:18:13
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 56, in jitless_step
2023-05-07 09:18:13
    return loop(train_step, wctx.serialize(), steps, training.device_unroll)
2023-05-07 09:18:13
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 172, in loop
2023-05-07 09:18:13
    return lax.scan(lambda *x: (fn(*x[:-1]), None), fn_input, None, steps, unroll=unroll)[0]
2023-05-07 09:18:13
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 172, in <lambda>
2023-05-07 09:18:13
    return lax.scan(lambda *x: (fn(*x[:-1]), None), fn_input, None, steps, unroll=unroll)[0]
2023-05-07 09:18:13
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 29, in train_step
2023-05-07 09:18:13
    scalars, grads = grad_fn(params, data_slice)
2023-05-07 09:18:13
TypeError: Custom VJP rule must produce an output with the same container (pytree) structure as the args tuple of the primal function, and in particular must produce a tuple of length equal to the number of arguments to the primal function, but got VJP output structure PyTreeDef(((*, *), None, *)) for primal input structure PyTreeDef((*, *, *, *)).