Skip to main content

Clashluke's group workspace

Timestamps visible
2023-05-12 06:01:03
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/custom_derivatives.py", line 614, in __call__
2023-05-12 06:01:03
    out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
2023-05-12 06:01:03
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/custom_derivatives.py", line 763, in bind
2023-05-12 06:01:03
    outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
2023-05-12 06:01:03
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/interpreters/ad.py", line 396, in process_custom_vjp_call
2023-05-12 06:01:03
    res_and_primals_out = fwd.call_wrapped(*fwd_in)
2023-05-12 06:01:03
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
2023-05-12 06:01:03
    ans = self.f(*args, **dict(self.params, **kwargs))
2023-05-12 06:01:03
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/custom_derivatives.py", line 969, in fwd
2023-05-12 06:01:03
    ans, rule = fun(*args, **kwargs)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/reversible.py", line 76, in _fn
2023-05-12 06:01:03
    out = base(params, x1, [*(sparse,) * (sparse_access == SparseAccess.read), *inner_args])
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/reversible.py", line 46, in base
2023-05-12 06:01:03
    out = fn(new_ctx, inp, *inner_args)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 35, in _fn
2023-05-12 06:01:03
    return fn(ctx, *args, **kwargs)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/linear.py", line 77, in read
2023-05-12 06:01:03
    offset1, offset0, gates = input_fn(ctx, token, position, dense0, ctx.dims.features, ctx.dims.pointwise_features,
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 35, in _fn
2023-05-12 06:01:03
    return fn(ctx, *args, **kwargs)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/linear.py", line 67, in input_fn
2023-05-12 06:01:03
    return scale_norm_act_linear(ctx, token_embedding + position_embedding + dense, ctx.dims.pointwise_features,
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 35, in _fn
2023-05-12 06:01:03
    return fn(ctx, *args, **kwargs)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/norm.py", line 139, in scale_norm_act_linear
2023-05-12 06:01:03
    out = _fn(inp, scale, weights)
2023-05-12 06:01:03
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
2023-05-12 06:01:03
    return fun(*args, **kwargs)
2023-05-12 06:01:03
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/custom_derivatives.py", line 614, in __call__
2023-05-12 06:01:03
    out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
2023-05-12 06:01:03
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/custom_derivatives.py", line 763, in bind
2023-05-12 06:01:03
    outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
2023-05-12 06:01:03
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 1985, in process_custom_vjp_call
2023-05-12 06:01:03
    fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
2023-05-12 06:01:03
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2172, in trace_to_subjaxpr_dynamic
2023-05-12 06:01:03
    ans = fun.call_wrapped(*in_tracers_)
2023-05-12 06:01:03
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
2023-05-12 06:01:03
    ans = self.f(*args, **dict(self.params, **kwargs))
2023-05-12 06:01:03
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/custom_derivatives.py", line 965, in wrapped_fun
2023-05-12 06:01:03
    ans, _ = fun(*args, **kwargs)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/norm.py", line 137, in _fn
2023-05-12 06:01:03
    return [_mm(fn(out), w) for fn, w in zip(transform_fns, wgt)], _grad
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/norm.py", line 137, in <listcomp>
2023-05-12 06:01:03
    return [_mm(fn(out), w) for fn, w in zip(transform_fns, wgt)], _grad
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/linear.py", line 22, in all2all
2023-05-12 06:01:03
    out = out.reshape(*inp.shape[:-1], jax.device_count(), 1, -1)
2023-05-12 06:01:03
jax._src.traceback_util.UnfilteredStackTrace: UnboundLocalError: local variable 'out' referenced before assignment
2023-05-12 06:01:03
The stack trace below excludes JAX-internal frames.
2023-05-12 06:01:03
The preceding is the original exception that occurred, unmodified.
2023-05-12 06:01:03
--------------------
2023-05-12 06:01:03
The above exception was the direct cause of the following exception:
2023-05-12 06:01:03
Traceback (most recent call last):
2023-05-12 06:01:03
  File "main.py", line 4, in <module>
2023-05-12 06:01:03
    main()
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 202, in main
2023-05-12 06:01:03
    wctx = step(dat)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 112, in __call__
2023-05-12 06:01:03
    self.wctx = WhileTrainContext(self.step(wctx.serialize()))
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 56, in jitless_step
2023-05-12 06:01:03
    return loop(train_step, wctx.serialize(), steps, training.device_unroll)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 172, in loop
2023-05-12 06:01:03
    return lax.scan(lambda *x: (fn(*x[:-1]), None), fn_input, None, steps, unroll=unroll)[0]
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 172, in <lambda>
2023-05-12 06:01:03
    return lax.scan(lambda *x: (fn(*x[:-1]), None), fn_input, None, steps, unroll=unroll)[0]
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/main.py", line 29, in train_step
2023-05-12 06:01:03
    scalars, grads = grad_fn(params, data_slice)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/main.py", line 104, in compute
2023-05-12 06:01:03
    return body_ctx(ctx, src, tgt)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 35, in _fn
2023-05-12 06:01:03
    return fn(ctx, *args, **kwargs)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/main.py", line 94, in body_ctx
2023-05-12 06:01:03
    carry, _ = block(ctx)(carry, (src[i], tgt[i], jnp.full((), i, dtype=jnp.int32)))
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/main.py", line 50, in _fn
2023-05-12 06:01:03
    src = reversible(ctx, read, SparseAccess.read, src, inp, position)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/reversible.py", line 85, in reversible
2023-05-12 06:01:03
    return _fn(src, list(args))
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/reversible.py", line 76, in _fn
2023-05-12 06:01:03
    out = base(params, x1, [*(sparse,) * (sparse_access == SparseAccess.read), *inner_args])
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/reversible.py", line 46, in base
2023-05-12 06:01:03
    out = fn(new_ctx, inp, *inner_args)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 35, in _fn
2023-05-12 06:01:03
    return fn(ctx, *args, **kwargs)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/linear.py", line 77, in read
2023-05-12 06:01:03
    offset1, offset0, gates = input_fn(ctx, token, position, dense0, ctx.dims.features, ctx.dims.pointwise_features,
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 35, in _fn
2023-05-12 06:01:03
    return fn(ctx, *args, **kwargs)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/linear.py", line 67, in input_fn
2023-05-12 06:01:03
    return scale_norm_act_linear(ctx, token_embedding + position_embedding + dense, ctx.dims.pointwise_features,
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/backend.py", line 35, in _fn
2023-05-12 06:01:03
    return fn(ctx, *args, **kwargs)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/norm.py", line 139, in scale_norm_act_linear
2023-05-12 06:01:03
    out = _fn(inp, scale, weights)
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/norm.py", line 137, in _fn
2023-05-12 06:01:03
    return [_mm(fn(out), w) for fn, w in zip(transform_fns, wgt)], _grad
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/norm.py", line 137, in <listcomp>
2023-05-12 06:01:03
    return [_mm(fn(out), w) for fn, w in zip(transform_fns, wgt)], _grad
2023-05-12 06:01:03
  File "/home/ubuntu/HomebrewNLP-Jax/src/model/linear.py", line 22, in all2all
2023-05-12 06:01:03
    out = out.reshape(*inp.shape[:-1], jax.device_count(), 1, -1)
2023-05-12 06:01:03
UnboundLocalError: local variable 'out' referenced before assignment