Skip to main content

Clashluke's group workspace

Timestamps visible
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
2023-05-17 13:25:08
    return fun(*args, **kwargs)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/api.py", line 724, in value_and_grad_f
2023-05-17 13:25:08
    ans, vjp_py, aux = _vjp(
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/api.py", line 2182, in _vjp
2023-05-17 13:25:08
    out_primal, out_vjp, aux = ad.vjp(
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 141, in vjp
2023-05-17 13:25:08
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 128, in linearize
2023-05-17 13:25:08
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
2023-05-17 13:25:08
    return func(*args, **kwargs)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 777, in trace_to_jaxpr_nounits
2023-05-17 13:25:08
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 190, in call_wrapped
2023-05-17 13:25:08
    ans = self.f(*args, **dict(self.params, **kwargs))
2023-05-17 13:25:08
  File "/home/lucas/PycharmProjects/Olmax/src/model/main.py", line 61, in compute
2023-05-17 13:25:08
    return body_ctx(ctx, src, tgt)
2023-05-17 13:25:08
  File "/home/lucas/PycharmProjects/Olmax/src/backend.py", line 35, in _fn
2023-05-17 13:25:08
    return fn(ctx, *args, **kwargs)
2023-05-17 13:25:08
  File "/home/lucas/PycharmProjects/Olmax/src/model/main.py", line 52, in body_ctx
2023-05-17 13:25:08
    (out, loss), _ = lax.scan(block(ctx), carry, (src, tgt, jnp.arange(ctx.dims.sequence)))
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
2023-05-17 13:25:08
    return fun(*args, **kwargs)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 262, in scan
2023-05-17 13:25:08
    out = scan_p.bind(*consts, *in_flat,
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 1031, in scan_bind
2023-05-17 13:25:08
    return core.AxisPrimitive.bind(scan_p, *args, **params)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/core.py", line 2633, in bind
2023-05-17 13:25:08
    return self.bind_with_trace(top_trace, args, params)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
2023-05-17 13:25:08
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 315, in process_primitive
2023-05-17 13:25:08
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 510, in _scan_jvp
2023-05-17 13:25:08
    jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr(
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 699, in jvp_jaxpr
2023-05-17 13:25:08
    return _jvp_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate))
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 709, in _jvp_jaxpr
2023-05-17 13:25:08
    jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
2023-05-17 13:25:08
    return func(*args, **kwargs)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2150, in trace_to_jaxpr_dynamic
2023-05-17 13:25:08
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2172, in trace_to_subjaxpr_dynamic
2023-05-17 13:25:08
    ans = fun.call_wrapped(*in_tracers_)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 190, in call_wrapped
2023-05-17 13:25:08
    ans = self.f(*args, **dict(self.params, **kwargs))
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/core.py", line 229, in jaxpr_as_fun
2023-05-17 13:25:08
    return eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *args)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/core.py", line 447, in eval_jaxpr
2023-05-17 13:25:08
    ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/core.py", line 2633, in bind
2023-05-17 13:25:08
    return self.bind_with_trace(top_trace, args, params)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
2023-05-17 13:25:08
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 315, in process_primitive
2023-05-17 13:25:08
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/custom_derivatives.py", line 838, in _custom_vjp_call_jaxpr_jvp
2023-05-17 13:25:08
    fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros)  # consts can be tracers!
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2067, in memoized
2023-05-17 13:25:08
    out = cells[args] = fn(*args)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 1992, in fwd_jaxpr_from_zeros
2023-05-17 13:25:08
    return trace_to_subjaxpr_dynamic(fwd_, main_(), in_avals)[::2]
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2172, in trace_to_subjaxpr_dynamic
2023-05-17 13:25:08
    ans = fun.call_wrapped(*in_tracers_)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 215, in call_wrapped
2023-05-17 13:25:08
    out_store.store(side)
2023-05-17 13:25:08
  File "/home/lucas/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 133, in store
2023-05-17 13:25:08
    raise StoreException("Store occupied with not-equal value") from None
2023-05-17 13:25:08
jax._src.traceback_util.UnfilteredStackTrace: jax._src.linear_util.StoreException: Store occupied with not-equal value
2023-05-17 13:25:08
The stack trace below excludes JAX-internal frames.
2023-05-17 13:25:08
The preceding is the original exception that occurred, unmodified.
2023-05-17 13:25:08
--------------------
2023-05-17 13:25:08
The above exception was the direct cause of the following exception:
2023-05-17 13:25:08
Traceback (most recent call last):
2023-05-17 13:25:08
  File "/home/lucas/PycharmProjects/Olmax/main.py", line 4, in <module>
2023-05-17 13:25:08
    main()
2023-05-17 13:25:08
  File "/home/lucas/PycharmProjects/Olmax/src/main.py", line 233, in main
2023-05-17 13:25:08
    wctx = step(dat)
2023-05-17 13:25:08
  File "/home/lucas/PycharmProjects/Olmax/src/main.py", line 116, in __call__
2023-05-17 13:25:08
    self.wctx = WhileTrainContext(self.step(wctx.serialize()))
2023-05-17 13:25:08
  File "/home/lucas/PycharmProjects/Olmax/src/main.py", line 56, in jitless_step
2023-05-17 13:25:08
    wctx, scalars = lax.scan(train_step, wctx.serialize(), data, unroll=training.device_unroll)
2023-05-17 13:25:08
  File "/home/lucas/PycharmProjects/Olmax/src/main.py", line 27, in train_step
2023-05-17 13:25:08
    scalars, grads = grad_fn(params, data_slice)
2023-05-17 13:25:08
  File "/home/lucas/PycharmProjects/Olmax/src/model/main.py", line 61, in compute
2023-05-17 13:25:08
    return body_ctx(ctx, src, tgt)
2023-05-17 13:25:08
  File "/home/lucas/PycharmProjects/Olmax/src/backend.py", line 35, in _fn
2023-05-17 13:25:08
    return fn(ctx, *args, **kwargs)
2023-05-17 13:25:08
  File "/home/lucas/PycharmProjects/Olmax/src/model/main.py", line 52, in body_ctx
2023-05-17 13:25:08
    (out, loss), _ = lax.scan(block(ctx), carry, (src, tgt, jnp.arange(ctx.dims.sequence)))
2023-05-17 13:25:08
jax._src.linear_util.StoreException: Store occupied with not-equal value