15B thinker distil: SSM module distillation
Comparing different SSM modules on thinker distillation
Created on July 15 | Last edited on August 18
We replace the same single layer during distillation, using reverse KL.
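For reference, a minimal sketch of the reverse KL (RKL) term used for distillation, assuming student and (frozen) teacher both expose full-vocab logits; the actual loss lives in the training code:

```python
import torch
import torch.nn.functional as F

def reverse_kl_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    # Reverse KL: KL(p_student || p_teacher), computed per token and averaged.
    student_logp = F.log_softmax(student_logits.float(), dim=-1)
    teacher_logp = F.log_softmax(teacher_logits.float().detach(), dim=-1)
    kl = torch.sum(student_logp.exp() * (student_logp - teacher_logp), dim=-1)
    return kl.mean()
```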
Observations:
- smaller state is better (16 > 64)
- MIL initialisation is much better (a rough sketch of the MIL init is given after the run list below)
- `15b-ihyb1lrklm2d64mhwk-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2` -- M2D, state 64 (mhwk)
- `15b-ihyb1lrklm2d64-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2` -- M2D, state 64
- `15b-ihyb1lrklm264xl-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2` -- M2, state 64, expand 2
- `15b-ihyb1lrklm216mil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2` -- M2, state 16, MIL init
- `15b-ihyb1lrklm216ht-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2` -- M2, state 16, random init
- `15b-ihyb1lrklm216xl-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2` -- M2, state 16, expand 2
- `15b-ihyb1lrklm2d16-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2` -- M2D, state 16
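For context, a rough sketch of what a MIL-style init does: the attention projections of the replaced layer are copied into the corresponding SSM projections and only the SSM-specific parameters start fresh. The attribute names below are hypothetical and the exact mapping in our codebase may differ:

```python
import torch
import torch.nn as nn

@torch.no_grad()
def mil_init(ssm_layer: nn.Module, attn_layer: nn.Module) -> None:
    # Hypothetical attribute names; the real modules live in the training repo.
    ssm_layer.c_proj.weight.copy_(attn_layer.q_proj.weight)    # C   <- W_Q
    ssm_layer.b_proj.weight.copy_(attn_layer.k_proj.weight)    # B   <- W_K
    ssm_layer.x_proj.weight.copy_(attn_layer.v_proj.weight)    # x   <- W_V
    ssm_layer.out_proj.weight.copy_(attn_layer.o_proj.weight)  # out <- W_O
    # A, dt, conv1d and norms keep their fresh (random/default) initialization.
```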
LM losses
Run set (7 runs)
Distillation with CE Loss on M2 state 16 w/ and w/o MIL init, 1/50 Hybrid
Observations:
- we note that for this layer (M2, state 16, MIL, 1/50 hybrid: `15b-ihyb1lcem216mil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm6`), the loss does not go down much when distilling with CE (see the sketch after the plot)!
- to debug, we compare against a RAND init in the plot below: the RAND run is still learning, but the MIL init is essentially a lower bound for RAND (the random init does not quite reach the MIL init here, although they get very close if you zoom in)
Run set (2 runs)
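For clarity, a minimal sketch of the CE variant, assuming "CE" here means the standard next-token cross-entropy on the data (no teacher logits); if it instead means soft-target CE against the teacher, it would be the forward analogue of the RKL above:

```python
import torch
import torch.nn.functional as F

def ce_loss(student_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Standard next-token cross-entropy: predict token t+1 from position t.
    return F.cross_entropy(
        student_logits[:, :-1].flatten(0, 1),  # [batch * (seq - 1), vocab]
        labels[:, 1:].flatten(),                # [batch * (seq - 1)]
        ignore_index=-100,
    )
```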
Sanity check
So it seems like MIL is a very good init; how well would a 5/50 Hybrid MIL perform compared to a 5/50 Hybrid MIL + RKL distillation?
- 5/50 Hybrid MIL (avg. over 3 gen tasks) -- 0.77
- trained version -- 0.85
- Conclusion: distillation definitely works and training helps; the initialization is already very good, but it gets much better with training
Here we have the 1/50 Hybrid M1 MIL with CE loss (upper) vs. RKL loss (lower):
- we see that RKL learns much better, i.e. it is more stable and the loss decreases more effectively, despite some instabilities in the middle (though the two losses are not directly comparable)
CE:
Run set (1 run)
RKL:
Run set (1 run)
0-shot distillation on 5/50 with RKL and MIL init:
- will 0-shot distillation of the 5/50 hybrid with MIL init and RKL (`15b-oshyb5lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6`) reach the loss of the iterative 5/50?
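A toy illustration of the two regimes being compared (our reading of 0-shot vs. iterative; the real schedules are defined by the training configs). 0-shot swaps all target layers up front and runs a single distillation phase, iterative swaps and distills one layer at a time:

```python
# Toy stand-in: the model is a list of layer types, distill is a callback.
def zero_shot(layers: list[str], targets: list[int], distill) -> list[str]:
    for i in targets:          # swap all target layers up front ...
        layers[i] = "m2"
    distill(layers)            # ... then run one distillation phase
    return layers

def iterative(layers: list[str], targets: list[int], distill) -> list[str]:
    for i in targets:          # swap and distill one layer at a time,
        layers[i] = "m2"       # keeping earlier swaps in place
        distill(layers)
    return layers

# Example (illustrative layer indices only): a 5/50 hybrid.
layers = zero_shot(["attn"] * 50, targets=[5, 15, 25, 35, 45], distill=lambda m: None)
```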
Run set (4 runs)
45/50 Hybrid
Here we try to distill an ambitious 45/50 Hybrid. We always train all SSM layers here.
Baseline:
- 25/50 Hybrid distillation runs (red and brown): `15b-oshyb25lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti25000_lm6`, `15b-oshyb25lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6`
Options:
- 45/50 MIL init w/ layer importance (standard) -- (grey) `15b-oshyb45lmilli-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6`
- 45/50 MIL init w/ layer importance (standard) + 25/50 init for overlapping layers -- (pink) `15b-oshyb45lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6`
Observations:
- the 25/50 init helps to speed up the beginning, but converges later
- at the 45/50 hybrid level, layer importance just keeps the first 5 layers as transformers and everything else is M2, which motivated the next baselines
Options (continued):
- 45/50 MIL init with a uniform transformer layer distribution with binning (similar to Zebra) -- here we place the transformer layers more or less uniformly across the backbone (see the placement sketch after this list) -- `15b-oshyb45lmilunif-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6`
- 45/50 MIL init with a uniform transformer layer distribution with binning (similar to Zebra) + 25/50 trained init of the overlapping layers -- `15b-oshyb45lmilunif-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6`, `15b-oshyb45lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6`
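A sketch of the uniform-with-binning placement (Zebra-style): split the 50 layers into as many bins as attention layers we keep and retain one attention layer per bin; everything else becomes M2. The in-bin position is a config detail, the middle is used here purely for illustration:

```python
def uniform_attention_layers(n_layers: int = 50, n_attention: int = 5) -> list[int]:
    # One attention layer per bin of size n_layers / n_attention.
    bin_size = n_layers / n_attention
    return [round(i * bin_size + bin_size / 2) for i in range(n_attention)]

# Example: uniform_attention_layers(50, 5) -> [5, 15, 25, 35, 45],
# versus the layer-importance placement at 45/50, which keeps layers [0, 1, 2, 3, 4].
```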
Observations:
- the 45/50 MIL init with a uniform transformer layer distribution with binning (similar to Zebra) seems much better than the 45/50 MIL init w/ layer importance (standard)
- we see loss spikes that do not recover, most likely because they are too large; they happen at the same place in both runs! Most likely this has to do with the learning rate being too large?
- try a lower base lr?
- maybe fewer steps / a lower final lr? (see the schedule sketch below)
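To make the lr knobs concrete, a small sketch assuming a cosine decay schedule (the actual scheduler is whatever the training config uses):

```python
import math

def cosine_lr(step: int, total_steps: int, base_lr: float = 3e-4, final_lr: float = 0.0) -> float:
    # Cosine decay from base_lr down to final_lr over total_steps.
    progress = min(step / total_steps, 1.0)
    return final_lr + 0.5 * (base_lr - final_lr) * (1.0 + math.cos(math.pi * progress))

# "lower base lr"  -> smaller base_lr (e.g. 3e-4 -> 1e-4)
# "lower final lr" -> smaller final_lr
# "fewer steps"    -> smaller total_steps, so a given step sits further along the decay
```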
Run set (5 runs)
TP version of reverse KL loss
- we compare no TP (`15b-notp-bs128-lr0.0003-lrs0-0-0-0-sl2048_ti5000_lm6`), sequence tensor parallel -- yellow, where we split only the sequence dim (`15b-stp-bs128-lr0.0003-lrs0-0-0-0-sl2048_ti5000_lm6`), and TP, where we split the vocab dim (`15b-tp-bs128-lr0.0003-lrs0-0-0-0-sl2048_ti5000_lm6`); a sketch of the vocab-parallel RKL is given at the end of this section
- sequence tensor parallel: we set `model.base_model.parallel_embeddings = false` and `sequence_tensor_parallel = true`
- TP (split vocab dim): we set `model.base_model.parallel_embeddings = true` and `sequence_tensor_parallel = true`
- this was obtained using the mamba2 branch, commit f7f30e13677c6aeae2696190f537615c7da81df6 (same as af2961e68cc62a221d697e1b6bf28d84e920d016)
- Observations:
- stp and no tp are essentially identical in loss and in gradient norm (differences look like numerical noise)
- tp is very different; it is not clear why, since it was the same in the debugging runs
Run set (3 runs)
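For reference, a sketch of why the RKL loss needs care under vocab-parallel TP: each rank only holds a vocabulary shard of the logits, so the log-softmax normalizer must be reduced across the TP group before the per-shard KL contributions are summed. `tp_group` is a hypothetical process-group handle and this is not the implementation from the mamba2 branch; it illustrates the forward math only (a real version needs autograd-aware collectives):

```python
import torch
import torch.distributed as dist

def reverse_kl_vocab_parallel(student_logits, teacher_logits, tp_group=None):
    # logits: [batch, seq, vocab_shard] on every TP rank.
    def sharded_log_softmax(logits):
        # Global max and global sum-exp require reductions over all vocab shards.
        local_max = logits.max(dim=-1, keepdim=True).values.detach()
        dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=tp_group)
        shifted = logits - local_max
        sum_exp = shifted.exp().sum(dim=-1, keepdim=True)
        dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)
        return shifted - sum_exp.log()

    s_logp = sharded_log_softmax(student_logits)
    t_logp = sharded_log_softmax(teacher_logits)
    # Each rank sums its own vocab shard of p_s * (log p_s - log p_t);
    # the shard contributions add up to the full-vocab KL at every position.
    local_kl = torch.sum(s_logp.exp() * (s_logp - t_logp), dim=-1)
    dist.all_reduce(local_kl, op=dist.ReduceOp.SUM, group=tp_group)
    return local_kl.mean()
```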