Experiment log book
Created on September 18 | Last edited on April 21
Experiments summary
I ran some exploratory experiments to figure out whether filler tokens / IC tokens / computation tokens / pause tokens (many people have had this idea over time, hence the many names) help transformers. This is far from completed research.
The experiments are all done with small GPT-2 style transformers in the range of 5-40k parameters, to validate whether there's any reason to run this in bigger LLMs. Someone did concurrently run that recently with a 1B parameter LLM: https://arxiv.org/abs/2310.02226.
The experiments I conducted and noted in my log book center on using internal computation (IC) tokens within a general-purpose transformer. I hypothesized that these tokens could help improve the model's performance on tasks that require several computational steps, so I chose the Collatz recursion problem and three-digit multiplication as test tasks. Unlike Fibonacci, the Collatz recursion is not reducible to a closed-form formula.
For the Collatz recursion, the input is n and we try to predict the output of this function:
def collatz(n):
    if n == 1:
        return 1
    elif n % 2 == 0:
        return 1 + collatz(n // 2)
    else:
        return 1 + collatz(3*n + 1)
(the Collatz conjecture is that this always terminates). The input to the model is "{n} [compute][compute]...[compute] --- {collatz(n)}", where "[compute]" is a reserved token that I insert into the tokenizer and the model embeddings.
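For example, collatz(6) = 9 (the chain 6 → 3 → 10 → 5 → 16 → 8 → 4 → 2 → 1 has 9 numbers, counting the final 1), so with three [compute] tokens the training string would be "6 [compute][compute][compute] --- 9".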
I tried models of different depths (6, 32 and 64 layers) and various numbers of IC tokens (0, 5, 50). The model did not show consistent improvement with more IC tokens; in fact, having too many (50) decreases the performance.
The performance was sensitive to depth, but the experiments show that it is mostly correlated with parameter count rather than depth itself. If you keep the parameters constant but vary the depth of the network, the performance doesn't change much (except in the degenerate case of depth 1). This hinted that the task is likely too hard: even in the best case of increasing depth directly (instead of the more limited way of adding depth through IC tokens), the model doesn't seem to be able to use the additional computation.
I additionally tried forcing the model to use computation tokens in two ways:
- Applying dropout in the attention on the tokens before the closest [compute]. For instance, if we have ["enxhell", "is", "[compute]", "[compute]", "18"]: when outputting "18", the attention module can only attend to the closest compute token, and the rest of the tokens are heavily dropped out (0.6+ probability of dropout). The second compute token can only attend to the previous compute token (the closest one), and the rest are heavily dropped out. This forces the model to store as much as possible in the compute tokens. This did not improve or reduce the performance, which is perhaps a sign that the model is already utilizing the compute tokens - it is storing something useful there (even if it's just a summary of the previous tokens), since otherwise the model would be almost blind to the input, rather than taking the easy way out and predicting "[compute]" without doing anything useful. Although the model could in theory still do another degenerate thing: simply predict the output at the first [compute] and then copy it over and over.
- Making sure that at the output of each compute token we also try to predict the first non-compute token that follows the compute block. So in the case above, when we calculate the loss on the first "[compute]" token's output, we combine the loss of predicting the next token (in this case "[compute]") with the loss of predicting the first non-compute token, which is "18". That way the [compute] token is forced to do something useful. Additionally, this extra loss is weighted by how far away the target token is, specifically 1/distance_in_tokens: for the first compute token the weight is 1/2, and for the second it's 1/1 = 1. This also did not impact performance. (A sketch of this loss is below.)
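A minimal sketch of what this weighted forward/jump loss could look like, assuming PyTorch; the function and argument names (jump_loss, compute_token_id) are illustrative, not the actual training code, and the way the extra terms are averaged into the base loss is an assumption.

import torch
import torch.nn.functional as F

def jump_loss(logits, input_ids, compute_token_id):
    # logits: (seq_len, vocab_size); logits[t] is the prediction for input_ids[t+1]
    # Standard next-token loss over the whole sequence.
    loss = F.cross_entropy(logits[:-1], input_ids[1:])

    # Extra term: every [compute] position also tries to predict the first
    # non-compute token after the block, weighted by 1 / distance in tokens.
    seq_len = input_ids.shape[0]
    extra_terms = []
    for t in range(seq_len - 1):
        if input_ids[t] != compute_token_id:
            continue
        j = t + 1
        while j < seq_len and input_ids[j] == compute_token_id:
            j += 1                      # skip over the rest of the compute block
        if j == seq_len:
            continue                    # block runs to the end of the sequence
        distance = j - t                # e.g. 2 for the first compute token in the example above
        extra = F.cross_entropy(logits[t].unsqueeze(0), input_ids[j].unsqueeze(0))
        extra_terms.append(extra / distance)
    if extra_terms:
        loss = loss + torch.stack(extra_terms).mean()  # assumption: simple average
    return loss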
I also tried a slightly easier task: 3-digit multiplication. I used this task specifically because you can generate lots of data, it's still hard, and it needs a varying number of computation steps.
The results are summarized in the graph below:

The results are shown as the average of two runs. Increasing the number of IC tokens tends to help; however, the effect is sensitive to the depth and to the number of IC tokens. The relationship is not as simple as more tokens → better performance.
Another thing to note is that there's quite a bit of variance between runs. Dataset sampling is not to blame for that, but the model is very sensitive to the initialization.
We also analyze why the model fails, starting from manually looking at the data. The model fails mostly on the middle digits, which are exactly the ones that need the largest number of operations (multiplications + additions) to compute correctly. In fact, the more operations needed, the worse the model performs, even within those digits (and generally across all of them). The IC tokens improve performance across the board, but the models with IC tokens also learn to maintain that performance as the number of operations increases (as opposed to simply being better across the board while still dropping when there are more ops).

Performance of the models of various depths and IC tokens grouped by the number of operations required to output that digit correctly
Meanwhile this paper came out: https://arxiv.org/pdf/2310.02226.pdf. They did this with an actual LLM (with some tweaks), which was the ultimate point of this experiment, and it does seem to work. Hence, I stopped the additional experiments I was planning before committing more time and resources to this. However, as you can see from the limited data here, it is crucial to pick the right number of IC tokens, and seemingly even per task. Given that LLMs, as generalists, do many different "tasks", an obvious direction is to figure out a way to let the LLM decide on the fly how many IC tokens it needs.
Log Book
(This is my documentation as I was running the experiments, so it's not polished in writing, but also contains more detail).
The main purpose of these experiments is to validate whether a small-scale transformer can actually benefit from computation/filler tokens before validating it in a big model that takes resources I don't have. The intuition for why this could work is that the model can do more computation before outputting the answer; however, it's not clear whether the model can effectively learn to utilize those tokens. Before running big trainings on Llama or the like, I want to validate that this shows some promise on smaller models and toy datasets, since I don't have much compute power.
So the problem needs to be something that requires many steps of reasoning. Fibonacci, while recursive, is not a good choice because there's a closed-form formula for it - it also grows with the input, which makes it easier to learn a correlation. One problem that is not reducible to a single closed-form formula and is actually quite bounded in range is predicting the number of steps the Collatz recursion takes to go down to 1. The Collatz step function is:
f(n) = n / 2      if n is even
f(n) = 3n + 1     if n is odd
What we'd try to predict is:
def collatz(n):
    if n == 1:
        return 1
    elif n % 2 == 0:
        return 1 + collatz(n // 2)
    else:
        return 1 + collatz(3*n + 1)
Here's a graph of the values:

You can see that for numbers up to 10k, it's bounded between 1 and 260. Some patterns do seem to emerge, though.
Some additional stats:
Mean Steps: ≈85.97
Median Steps: 74
Variance: ≈2170.49
Standard Deviation: ≈46.59
So quite a lot of variance.
For all of the models we use the standard GPT-2 architecture, but with far fewer parameters (around 8k in most cases). Training uses the Adam optimizer.
For the tokenizer I use the same tokenizer as GPT-2, and because that's a BPE tokenizer, some numbers are a single token. So it's not the case that each digit is a token.
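For reference, a minimal sketch of this setup with Hugging Face transformers; the width/head values are illustrative guesses (the log only gives the rough parameter count, which presumably excludes the embedding table or assumes the later digit-level vocab), and the reserved [compute] token is added in the standard way.

from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.add_tokens(["[compute]"])        # reserve the compute token in the tokenizer

config = GPT2Config(
    vocab_size=len(tokenizer),             # includes the new [compute] token
    n_layer=6,                             # "standard" depth used in the log
    n_embd=32,                             # illustrative guess, not the logged config
    n_head=4,
    n_positions=128,
)
model = GPT2LMHeadModel(config)
model.resize_token_embeddings(len(tokenizer))   # keep embeddings in sync with the vocab
print(model.num_parameters())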
Varying number of compute tokens and depth
The task is set up such that the text to train on is something like:
{n}---[compute][compute]...[compute]{collatz(n)}
We vary the number of compute tokens. If our hypothesis is right, then with more compute tokens we would see better accuracy/MSE in the evaluation.
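A minimal sketch of the data generation (the helper name make_example is illustrative, and the training range of n is an assumption; the log only plots values up to 10k):

def collatz(n):
    if n == 1:
        return 1
    elif n % 2 == 0:
        return 1 + collatz(n // 2)
    else:
        return 1 + collatz(3*n + 1)

def make_example(n, num_compute_tokens):
    # e.g. make_example(6, 3) -> "6---[compute][compute][compute]9"
    return f"{n}---{'[compute]' * num_compute_tokens}{collatz(n)}"

examples = [make_example(n, num_compute_tokens=5) for n in range(1, 10_001)]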
We also vary the depth of the network - specifically the number of layers. In theory this should have a similar effect to increasing the number of compute tokens, since the model can do more compute per token. And that's the main reason we think compute tokens might help - the model can do more computation before outputting the final answer.
As you can see from the table below the charts, the standard model has a depth of 6 layers, "deep" is 32 layers, and "deeper" is 64.
Observations:
- The depth doesn't matter
- Number of compute tokens doesn't matter either. In fact, when we add 50 compute tokens, the model destabilizes a bit. But in essence the performance is roughly the same in evaluation accuracy. MSE loss also shows a similar story.
- The train/loss curve, however, does show that a deeper model is able to fit the data better. Note that a fair comparison here is only between runs with the same number of compute tokens but different depths: if you compare, say, 1-IC and 5-IC, the loss for 5-IC will be lower simply because it's measured across all tokens and the IC tokens are easy to predict. Still, the lower train loss doesn't seem to materialize into better eval performance.
Next hypothesis:
- The model is not able to utilize the compute/IC tokens.
- Given that the depth of the model didn't help, maybe the task is just too hard.
Some caveats:
- Tokenization is the standard GPT-2 BPE. So predicting the original tokens adds one more source of failure - e.g. a number like 42 is one token, 139 is two tokens, but some three-digit numbers are a single token, and so on. This could make the task a bit too hard for a small model, and it adds one more variable that isn't needed. I will run experiments without this, though - more on that later.
- Evaluation is done with teacher forcing via Hugging Face. That is a fancy way of saying that for each token, the prediction is made with the model seeing the ground truth for all previous tokens. So it is not generating the whole sequence, only the next token. For example, if the output is 139 and the tokens are "1" and "39", then when predicting "39" the model already knows the first token is "1", even if it got it wrong before. That's not a big deal for accuracy, because if it messes up one token it's wrong already; it should only affect the MSE and, down the line, per-digit accuracy. Will correct this going forward.
However, there's a few more experiments that I ran to see if I could force the model to learn.
Forcing the model to use [compute]
Redoing the experiments but with single digits as tokens and without teacher-forcing
Essentially I'll repeat the experiments above (if needed) after fixing some things.
- Make the tokenizer tokenize each digit separately, so the GPT vocab becomes single digits plus some basic symbols (a sketch follows this list).
- Make the evaluation use generation instead of teacher-forcing.
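A minimal sketch of what the digit-level vocabulary and generation-based evaluation could look like; the exact special tokens and helper names here are assumptions for illustration.

import torch

# Digit-level vocab: single digits, a separator, and the reserved tokens.
VOCAB = [str(d) for d in range(10)] + ["-", "[compute]", "[eos]"]
TOKEN_TO_ID = {tok: i for i, tok in enumerate(VOCAB)}
ID_TO_TOKEN = {i: tok for tok, i in TOKEN_TO_ID.items()}

def encode(tokens):
    # expects a list of tokens, e.g. ["6", "-", "-", "-", "[compute]", ...]
    return torch.tensor([TOKEN_TO_ID[t] for t in tokens])

@torch.no_grad()
def generate_answer(model, prompt_ids, max_new_tokens=10):
    # Feed only the prompt (number, separator, compute tokens) and let the
    # model generate the answer, instead of teacher-forcing the ground truth.
    out = model.generate(
        prompt_ids.unsqueeze(0),
        max_new_tokens=max_new_tokens,
        do_sample=False,                       # greedy decoding
        eos_token_id=TOKEN_TO_ID["[eos]"],
    )
    new_ids = out[0, prompt_ids.shape[0]:]     # keep only the generated part
    return "".join(ID_TO_TOKEN[int(i)] for i in new_ids if int(i) != TOKEN_TO_ID["[eos]"])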
Note that the forward/jump loss is now a bit less meaningful because it only predicts the first digit of the answer. That could perhaps be fixed, but I think it should still give at least a little bit of a boost, so I'm not doing anything more complex until I see a spark there.
First, let's find out if the model can benefit from more compute though. If so, then I will try to run the experiments to see if we can actually use IC tokens or force the model to use them.
Does the task benefit from more compute?
The experiment setup is simple: let's run the model with various depth levels and see if the accuracy improves as you add more layers (compute). If so, then we have hope that IC tokens can benefit the model since more compute can be utilized.
So I ran the 1-compute-token model with 1, 2, 3, 4, 6, 16, and 64 layers.
As you can see from the chart below, there are clear improvements from 1 to 2 to 3 to 4 layers, but beyond 6 the model either doesn't benefit or becomes more prone to overfitting.
Now there's an argument to be made that lowering the depth of the network doesn't only reduce "compute" - it also lowers the number of parameters, which means less "memory" (i.e. the amount of information the model can store in its weights). Compute tokens keep the number of parameters the same and only increase compute, so if the main reason for better performance is more parameters (and I'm sure that's a big part of it), then compute tokens won't help. However, at the very least we know that progress can be made by the low-depth networks if they can find some way to learn from the extra compute tokens. Let's see how these lower-depth networks do when we add compute tokens. The chart below shows that, except for 2 layers, having 5 compute tokens either matches the performance or is slightly worse (compute-token experiments are the dotted lines). The experiments with more tokens (10 and 30) are not shown here, but they don't make the situation any better (same or worse).
So a natural next experiment is to look at networks that keep the parameters constant but increase depth, i.e. repeat the experiments above but make sure that every time we increase depth, we reduce the width (the embedding dimension) so the parameter count stays roughly constant. If compute steps don't matter, we should see the same performance across all depths; if they do, deeper should be better.
In this case I matched all the other depths to the parameter count of the depth-6 network above. So the depth-6 experiment stays as is, and for the rest of the depths I adjusted the width of the network (n_embd) to have roughly the same number of parameters (±10%).
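A minimal sketch of how a matching width can be found for a given depth, assuming Hugging Face transformers; it brute-forces candidate widths and uses the same ±10% tolerance as above (n_head and n_positions are illustrative, and the step of 4 keeps n_embd divisible by n_head).

from transformers import GPT2Config, GPT2LMHeadModel

def params_for(depth, width, vocab_size, n_head=4, n_positions=128):
    # Build a throwaway model just to count parameters (embeddings included).
    config = GPT2Config(vocab_size=vocab_size, n_layer=depth,
                        n_embd=width, n_head=n_head, n_positions=n_positions)
    return GPT2LMHeadModel(config).num_parameters()

def find_width(depth, target_params, vocab_size, tolerance=0.10):
    for width in range(4, 513, 4):
        n = params_for(depth, width, vocab_size)
        if abs(n - target_params) / target_params <= tolerance:
            return width, n
    return None

Here target_params would be the parameter count of the depth-6 model.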
As you can see from the figure above, depths 4 and 6 perform about the same. Then there's a dip for depths 2 and 3, and finally a big dip for depth 1. So the depth of the network does seem to matter beyond just the number of parameters, although depth 1 is more of a degenerate case since it is so shallow.
Depths 2 and 3 are not too far behind though, with a gap of about 10%. While that's not something to ignore, it seems plausible that we picked the wrong task here. We might not be able to close a 10% gap simply by using IC tokens, since adding IC tokens is not strictly equal to adding depth - it's a weaker form of it. So it's probably wise to explore some other tasks before we conclude that IC tokens don't help and abandon the idea.
3-digit multiplication
The dataset is simple: I generated 200k random (non-repeating) pairs of 3-digit numbers and multiplied them. The task is, given the two numbers, predict the product, with the intermediate compute tokens inserted between the inputs and the output.
Here are the results from training the model with various depths and various numbers of IC tokens. The models here have about 8k parameters or fewer (lower depths have proportionally fewer).
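A minimal sketch of that generation, assuming the same separator/compute-token format as the Collatz task (the log doesn't spell out the exact string format); the seed is fixed here, though the log notes the original runs didn't fix it.

import random

def make_multiplication_dataset(num_pairs=200_000, num_compute_tokens=5, seed=0):
    rng = random.Random(seed)
    pairs = set()
    while len(pairs) < num_pairs:               # non-repeating pairs of 3-digit numbers
        pairs.add((rng.randint(100, 999), rng.randint(100, 999)))
    return [f"{a}*{b}---{'[compute]' * num_compute_tokens}{a * b}"
            for a, b in pairs]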


Note that here we're just increasing depth, not keeping the parameters the same. So as we increase depth, the parameters increase. We ran this twice and averaged the results.
The pattern is not entirely straightforward. On the accuracy front there does seem to be some benefit to having more compute tokens; however, the relationship is not at all linear, and the model does not learn to utilize more tokens beyond a certain point. For instance, for depths 3 and 4, performance beyond 5 tokens either decreases or stays the same. For depth 6, there's no clear benefit to adding tokens, but at that depth the model already achieves what seems to be close to ceiling performance.
For depth 2, where the model is underpowered, adding more IC tokens helped - to the point of 30-IC surpassing deeper models (when those are trained with one IC token).
However, note that the MAE (mean absolute error) doesn't tell the same story in a lot of cases. I think that's because the model learns to get closer to the answer digit by digit, but then each remaining mistake costs more: as you'll see later, the model makes mistakes in digits 3 and 4, and those are costlier in absolute terms than getting all digits wrong while guessing something around the right magnitude (say, around 200,000).
What is not shown here, however, is that there's a lot of variance between runs. One explanation could be that we got unlucky with the samples, although the distribution should be roughly the same since we sample 200k examples out of ~800k possible. I'm rerunning the experiments with the dataset fixed as we speak, but I'd say that's unlikely to be the cause. Still, these big variations are something to dig into. Let's take a look at the examples.

So you can see that the cases where we increase the IC tokens have a lot of variance; especially at 20 tokens there is a lot of variance between essentially all runs. Manually looking at the data, it seems like the models really just differ at digits 3 and 4, and the chart confirms it.

Accuracy at each digit level. Digit positions are measured from right to left, so 6 is the hundreds of thousands digit.

As you can see, the models are pretty accurate on the first two and last two digits, but not on the middle two. Clearly, though, for depth 1 the model fails at positions 2 and 5 too.
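For reference, a minimal sketch of how this per-digit accuracy can be computed from generated answers; the zero-padding to 6 digits and the right-to-left positions follow the caption above, while the function and argument names are illustrative.

from collections import defaultdict

def per_digit_accuracy(predictions, targets, num_digits=6):
    # predictions/targets: lists of integers (predicted and true products)
    correct = defaultdict(int)
    total = defaultdict(int)
    for pred, true in zip(predictions, targets):
        p = str(pred).zfill(num_digits)
        t = str(true).zfill(num_digits)
        for position in range(1, num_digits + 1):   # position 1 = rightmost digit
            total[position] += 1
            if p[-position] == t[-position]:
                correct[position] += 1
    return {pos: correct[pos] / total[pos] for pos in sorted(total)}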
Here it's broken down by the number of IC tokens:

You can see that whenever one run improved over the other (dotted vs. non-dotted), most of the time the biggest improvements were in the 3rd digit. This shows up especially in IC-20 and IC-30. More generally, whenever IC tokens improve things, they do so mostly by improving position 3. Even when digit 3 is almost perfect, digit 4 never really improves to perfection (it reaches at most 90%).
Below you can see how many digits the models are off by in each case (it could be that when the model is wrong on digit 3 it's also wrong on digit 4). This instead paints the picture that, except for depth 1, in the majority of cases the model is off by exactly one digit.

If we look at the number of multiplications and additions needed when you multiply two 3-digit numbers by hand, we'd get a plot like this:

where the additions and multiplications are counted as in the code below:
from collections import defaultdict

def count_operations_per_digit_with_carries(a, b):
    # a and b are two 3-digit numbers
    additions = defaultdict(int)
    multiplications = defaultdict(int)
    carry_additions = defaultdict(int)
    for i, digit_a in enumerate(reversed(str(a))):
        for j, digit_b in enumerate(reversed(str(b))):
            position = i + j + 1
            multiplications[position] += 1
            if multiplications[position] > 1:
                additions[position] += 1
            product = int(digit_a) * int(digit_b)
            carry = product // 10
            if carry > 0:
                carry_additions[position + 1] += 1
    for position, count in carry_additions.items():
        additions[position] += count
    return additions, multiplications
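For example, for 123 × 456 this counts the most work at position 3 (the hundreds digit):

additions, multiplications = count_operations_per_digit_with_carries(123, 456)
# dict(additions)        -> {2: 2, 3: 4, 4: 3}
# dict(multiplications)  -> {1: 1, 2: 2, 3: 3, 4: 2, 5: 1}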
You can see that the third digit is the most "complex" one, followed by the 4th. However, it's not clear why the 4th digit rarely improves when it's actually slightly simpler as measured by the number of operations. It could be that, because digit 4 is generated before digit 3 (remember, positions are measured from the right, and the model writes the answer left to right), the model gets one more computation step for digit 3 but has to commit early to digit 4. Perhaps adding an IC token right before generating digit 4 could help?
And you can see that within digit 3, performance drops as the number of operations increases for the models that aren't performing well. This does indicate that the better-performing models are learning to do more operations "in their head" as they train. Again, however, it's not a strict more-IC-tokens, more-performance relationship; some token counts seem to help, some don't.

This trend also holds if you look at the number of operations needed for the full multiplication, not just at the accuracy of digit 3.

You can see that accuracy suffers more on the numbers that require more operations - there's a clear more-operations → less-accurate relationship. The models that improve in performance learn to do those operations better and maintain accuracy as the operation count grows.
So to summarize:
- Do more tokens help? Tending towards yes, but the relationship is not as smooth as more tokens → better performance. Also, only models of a certain size seem able to utilize them (depth 1 can't), and beyond depth 6 in our case, where the model already hit the ceiling, they were counterproductive.
- Why is there variance across runs (answer below)? I don't really have an answer yet. Perhaps the model has trouble consistently making effective use of the extra computation / registers? Or maybe it's an artifact of the dataset changing between runs. We'll see once those experiments finish running.
- What are these models learning? They seem to be progressively closing the gap on examples that need more and more operations, as seen in the per-operation-count performance. Digits 3 and 4 are the most problematic because they require the most operations, which would align with the idea that compute tokens could be useful for storing intermediate computations.
The experiments above are all logged (with the code and data used; for data analysis purposes you can use the data-analysis notebook, since wandb is not as flexible). I'll put a dummy panel here with all the experiments.
The variance between different runs
I wanted to test whether the variance between runs could be explained by any of the following:
- The dataset being slightly different: I was sampling random pairs and the seed wasn't set. I was sampling a lot of data, so that shouldn't make a big difference, but it does also result in a different test set, which might matter.
- Initialization of the network
I ran the experiments and the dataset was not the issue - you still get quite a bit of variance even when you fix it. However, once you fix the initialization, the variance goes away. So the model seems rather sensitive to initialization; it's an interesting question why.
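A minimal sketch of how the two factors can be controlled independently, assuming PyTorch; config and make_dataset are placeholders for the actual model config and data-sampling code.

import random
import torch
from transformers import GPT2LMHeadModel

def build_run(config, init_seed, data_seed, make_dataset):
    # Fix the weight initialization independently of the data sampling.
    torch.manual_seed(init_seed)
    model = GPT2LMHeadModel(config)

    # The data sampling gets its own RNG, so the two sources of variance
    # can be toggled separately across runs.
    data_rng = random.Random(data_seed)
    dataset = make_dataset(data_rng)
    return model, dataset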
Update - Related Work
Someone concurrently did this in an actual LLM of about 1B params. https://arxiv.org/pdf/2310.02226.pdf
So this does seem to work.
I think there are some immediate next actions. It doesn't seem likely that just having a fixed number of tokens is optimal: some tokens are easier to predict than others, and it may be that the extra tokens are not used efficiently when the target tokens are too far away. So it would be interesting to get the model to decide by itself how many think/IC tokens it needs. I think that would benefit the model both in compute (although it will be faster GPU-wise to run with a fixed number of IC tokens than to let the model generate them continuously) and, more importantly, in accuracy.
Additionally, it would be cool to have some lever over how likely the model is to generate the compute token, so I can control the resources spent. That, however, could also be done through an instruction.