Skip to main content

Clashluke's group workspace

Timestamps visible
2022-10-30 10:40:33
    main()
2022-10-30 10:40:33
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 211, in main
2022-10-30 10:40:33
    wctx = step(dat)
2022-10-30 10:40:33
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 121, in __call__
2022-10-30 10:40:33
    self.wctx = WhileTrainContext(self.step(wctx.serialize()))
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
2022-10-30 10:40:33
    return fun(*args, **kwargs)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/api.py", line 2136, in cache_miss
2022-10-30 10:40:33
    out_tree, out_flat = f_pmapped_(*args, **kwargs)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/api.py", line 2012, in pmap_f
2022-10-30 10:40:33
    out = pxla.xla_pmap(
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/core.py", line 2066, in bind
2022-10-30 10:40:33
    return map_bind(self, fun, *args, **params)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/core.py", line 2098, in map_bind
2022-10-30 10:40:33
    outs = primitive.process(top_trace, fun, tracers, params)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/core.py", line 2069, in process
2022-10-30 10:40:33
    return trace.process_map(self, fun, tracers, params)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/core.py", line 689, in process_call
2022-10-30 10:40:33
    return primitive.impl(f, *tracers, **params)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 941, in xla_pmap_impl
2022-10-30 10:40:33
    compiled_fun, fingerprint = parallel_callable(
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/linear_util.py", line 295, in memoized_fun
2022-10-30 10:40:33
    ans = call(fun, *args)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1187, in parallel_callable
2022-10-30 10:40:33
    pmap_computation = lower_parallel_callable(
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 313, in wrapper
2022-10-30 10:40:33
    return func(*args, **kwargs)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1355, in lower_parallel_callable
2022-10-30 10:40:33
    jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1262, in stage_parallel_callable
2022-10-30 10:40:33
    jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 313, in wrapper
2022-10-30 10:40:33
    return func(*args, **kwargs)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2013, in trace_to_jaxpr_final
2022-10-30 10:40:33
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1946, in trace_to_subjaxpr_dynamic
2022-10-30 10:40:33
    ans = fun.call_wrapped(*in_tracers_)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
2022-10-30 10:40:33
    ans = self.f(*args, **dict(self.params, **kwargs))
2022-10-30 10:40:33
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 64, in jitless_step
2022-10-30 10:40:33
    return loop(train_step, wctx.serialize(), steps, training.device_unroll)
2022-10-30 10:40:33
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 172, in loop
2022-10-30 10:40:33
    return lax.scan(lambda *x: (fn(*x[:-1]), None), fn_input, None, steps, unroll=unroll)[0]
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
2022-10-30 10:40:33
    return fun(*args, **kwargs)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 259, in scan
2022-10-30 10:40:33
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 245, in _create_jaxpr
2022-10-30 10:40:33
    jaxpr, consts, out_tree = _initial_style_jaxpr(
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/util.py", line 277, in wrapped
2022-10-30 10:40:33
    result = call(weak_arg, *args, **kwargs)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 57, in _initial_style_jaxpr
2022-10-30 10:40:33
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/util.py", line 277, in wrapped
2022-10-30 10:40:33
    result = call(weak_arg, *args, **kwargs)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 51, in _initial_style_open_jaxpr
2022-10-30 10:40:33
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 313, in wrapper
2022-10-30 10:40:33
    return func(*args, **kwargs)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1929, in trace_to_jaxpr_dynamic
2022-10-30 10:40:33
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1946, in trace_to_subjaxpr_dynamic
2022-10-30 10:40:33
    ans = fun.call_wrapped(*in_tracers_)
2022-10-30 10:40:33
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/linear_util.py", line 168, in call_wrapped
2022-10-30 10:40:33
    ans = self.f(*args, **dict(self.params, **kwargs))
2022-10-30 10:40:33
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 172, in <lambda>
2022-10-30 10:40:33
    return lax.scan(lambda *x: (fn(*x[:-1]), None), fn_input, None, steps, unroll=unroll)[0]
2022-10-30 10:40:33
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 31, in train_step
2022-10-30 10:40:33
    grads = {k: v * wctx.ctx.optimizer.hessian_scale for k, v in grads}
2022-10-30 10:40:33
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 31, in <dictcomp>
2022-10-30 10:40:33
    grads = {k: v * wctx.ctx.optimizer.hessian_scale for k, v in grads}
2022-10-30 10:40:33
jax._src.traceback_util.UnfilteredStackTrace: ValueError: too many values to unpack (expected 2)
2022-10-30 10:40:33
The stack trace below excludes JAX-internal frames.
2022-10-30 10:40:33
The preceding is the original exception that occurred, unmodified.
2022-10-30 10:40:33
--------------------
2022-10-30 10:40:33
The above exception was the direct cause of the following exception:
2022-10-30 10:40:33
Traceback (most recent call last):
2022-10-30 10:40:33
  File "main.py", line 4, in <module>
2022-10-30 10:40:33
    main()
2022-10-30 10:40:33
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 211, in main
2022-10-30 10:40:33
    wctx = step(dat)
2022-10-30 10:40:33
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 121, in __call__
2022-10-30 10:40:33
    self.wctx = WhileTrainContext(self.step(wctx.serialize()))
2022-10-30 10:40:33
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 64, in jitless_step
2022-10-30 10:40:33
    return loop(train_step, wctx.serialize(), steps, training.device_unroll)
2022-10-30 10:40:33
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 172, in loop
2022-10-30 10:40:33
    return lax.scan(lambda *x: (fn(*x[:-1]), None), fn_input, None, steps, unroll=unroll)[0]
2022-10-30 10:40:33
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 172, in <lambda>
2022-10-30 10:40:33
    return lax.scan(lambda *x: (fn(*x[:-1]), None), fn_input, None, steps, unroll=unroll)[0]
2022-10-30 10:40:33
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 31, in train_step
2022-10-30 10:40:33
    grads = {k: v * wctx.ctx.optimizer.hessian_scale for k, v in grads}
2022-10-30 10:40:33
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 31, in <dictcomp>
2022-10-30 10:40:33
    grads = {k: v * wctx.ctx.optimizer.hessian_scale for k, v in grads}
2022-10-30 10:40:33
ValueError: too many values to unpack (expected 2)