Skip to main content

Clashluke's group workspace

Timestamps visible
2023-05-14 13:58:35
    return lax.scan(lambda *x: (fn(*x[:-1]), None), fn_input, None, steps, unroll=unroll)[0]
2023-05-14 13:58:35
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 172, in <lambda>
2023-05-14 13:58:35
    return lax.scan(lambda *x: (fn(*x[:-1]), None), fn_input, None, steps, unroll=unroll)[0]
2023-05-14 13:58:35
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 29, in train_step
2023-05-14 13:58:35
    scalars, grads = grad_fn(params, data_slice)
2023-05-14 13:58:35
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/main.py", line 68, in compute
2023-05-14 13:58:35
    return body_ctx(ctx, src, tgt)
2023-05-14 13:58:35
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 35, in _fn
2023-05-14 13:58:35
    return fn(ctx, *args, **kwargs)
2023-05-14 13:58:35
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/main.py", line 58, in body_ctx
2023-05-14 13:58:35
    carry, _ = block(ctx)(carry, (src[i], tgt[i], jnp.full((), i, dtype=jnp.int32)))
2023-05-14 13:58:35
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/main.py", line 26, in _fn
2023-05-14 13:58:35
    src, loss = loss_fn(ctx, src[1:], inp)
2023-05-14 13:58:35
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 35, in _fn
2023-05-14 13:58:35
    return fn(ctx, *args, **kwargs)
2023-05-14 13:58:35
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/loss.py", line 101, in loss_fn
2023-05-14 13:58:35
    src, loss, acc = _fn(src, tgt, param)
2023-05-14 13:58:35
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/loss.py", line 92, in _grad
2023-05-14 13:58:35
    (pdx0, x0, pdx1, x1, dsp, sp), x, d_loss, d_acc = dy
2023-05-14 13:58:35
ValueError: not enough values to unpack (expected 4, got 3)