
15B thinker distil SSM module distillation

Comparing different SSM modules on thinker distillation
Created on July 15 | Last edited on August 18
We replace the same single layer during distillation and train it with a reverse KL loss against the teacher.
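For reference, a minimal sketch of the reverse-KL distillation objective as used here (student distribution q matched to the frozen teacher p by minimizing KL(q || p)); the function name and shapes are illustrative, not the actual training code:

```python
import torch
import torch.nn.functional as F

def reverse_kl_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    """KL(student || teacher), averaged over all token positions.

    student_logits, teacher_logits: [batch, seq_len, vocab]
    """
    log_q = F.log_softmax(student_logits, dim=-1)            # student log-probs
    log_p = F.log_softmax(teacher_logits.detach(), dim=-1)   # teacher is frozen
    # KL(q || p) = sum_v q(v) * (log q(v) - log p(v))
    return (log_q.exp() * (log_q - log_p)).sum(dim=-1).mean()
```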
Observations:
- smaller state is better (16 > 64)
- MIL initialisation is much better than random initialisation



15b-ihyb1lrklm2d64mhwk-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2 -- M2D, 64 state
15b-ihyb1lrklm2d64-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2 -- M2D, 64 state

15b-ihyb1lrklm264xl-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2 -- M2, 64 state, expand 2
15b-ihyb1lrklm216mil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2 -- M2, 16 state, MIL init
15b-ihyb1lrklm216ht-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2 -- M2, 16 state, random init
15b-ihyb1lrklm216xl-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2 -- M2, 16 state, expand 2
15b-ihyb1lrklm2d16-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2 -- M2D, 16 state

LM losses


[Plot: run set of 7 runs]


Distillation with CE Loss on M2 state 16 w/ and w/o MIL init, 1/50 Hybrid

Observations:
  • we note that for this layer (M2, 16 state, MIL init, 1/50 hybrid: 15b-ihyb1lcem216mil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm6) the loss barely goes down when distilling with CE!
  • to debug, we compare against a RAND init in the plot below: the CE run is still learning, but the MIL init acts essentially as a lower bound for RAND (the random init does not reach the MIL init here, though they get very close if you zoom in); a sketch of the CE objective follows this list
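For contrast with the reverse-KL sketch above, a minimal sketch of the CE objective, assuming "CE" here means the standard next-token cross-entropy on the data tokens (if it instead means soft cross-entropy against the teacher's distribution, the target would be the teacher's probabilities rather than the label ids):

```python
import torch
import torch.nn.functional as F

def ce_lm_loss(student_logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy; logits [batch, seq, vocab], targets [batch, seq]."""
    return F.cross_entropy(
        student_logits.reshape(-1, student_logits.size(-1)),  # flatten to [batch*seq, vocab]
        target_ids.reshape(-1),                                # flatten to [batch*seq]
    )
```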


[Plot: run set of 2 runs]

Sanity check
So it seems like MIL is a very good init. How well would a 5/50 Hybrid with MIL init alone perform compared to a 5/50 Hybrid with MIL init + RKL distillation?
  • 5/50 Hybrid MIL (avg. over 3 gen tasks) -- 0.77
  • trained version -- 0.85
  • Conclusion: distillation definitely works and training helps; the initialization is already very good on its own, but it gets much better with training.


Here we have the 1/50 Hybrid M1 MIL with CE loss (upper) vs. RKL loss (lower):

  • we see that RKL learns much better, i.e. it is more stable and the loss goes down more effectively, despite some instabilities in the middle (though the two losses are not directly comparable in absolute terms)
CE:

[Plot: run set of 1 run]

RKL:

[Plot: run set of 1 run]


0-shot distillation on 5/50 with RKL and MIL init:

  • will 0-shot distilling the 5/50 hybrid with MIL init and RKL (15b-oshyb5lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6) reach the loss of the iterative 5/50? (a sketch of the two schedules follows)
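A rough sketch of the two schedules being compared, under my reading of the terms (not the actual pipeline): "0-shot" converts all target layers to SSMs at once and runs a single distillation, while "iterative" converts them in stages and distills after each conversion. `replace_with_ssm` and `distill_fn` are hypothetical helpers.

```python
from typing import Callable, Sequence

def zero_shot_distill(model, target_layers: Sequence[int], distill_fn: Callable, num_steps: int):
    """Convert every target layer to an SSM up front (MIL init), then distill once."""
    for idx in target_layers:
        model.replace_with_ssm(idx, init="mil")   # hypothetical conversion helper
    distill_fn(model, num_steps=num_steps)        # one RKL distillation run

def iterative_distill(model, target_layers: Sequence[int], distill_fn: Callable, steps_per_stage: int):
    """Convert layers stage by stage and re-distill after each conversion."""
    for idx in target_layers:
        model.replace_with_ssm(idx, init="mil")
        distill_fn(model, num_steps=steps_per_stage)
```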

[Plot: run set of 4 runs]


45/50 Hybrid

Here we try to distill an ambitious 45/50 Hybrid. We always train all SSM layers here.
Baseline:
  • 25/50 Hybrid distillation runs (red and brown): 15b-oshyb25lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti25000_lm6 and 15b-oshyb25lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6
Options:
  • 45/50 MIL init w/ layer importance (standard) -- (grey) 15b-oshyb45lmilli-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6
  • 45/50 MIL init w/ layer importance (standard) + 25/50 init for overlapping layers -- (pink) 15b-oshyb45lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6
Observation:
  • the 25/50 init helps to speed up the beginning, but it converges later
  • at the 45/50 hybrid level, layer importance just keeps the first 5 layers as transformers and converts everything else to M2, which motivated the next baselines

Options (continued):
  • 45/50 MIL init with uniform transformer layer distribution with binning (similar to Zebra) -- here we place the transformer layers more-or-less uniformly across the backbone (see the sketch after this list) -- 15b-oshyb45lmilunif-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6
  • 45/50 MIL init with uniform transformer layer distribution with binning (similar to Zebra) + 25/50 trained init for the overlapping layers -- 15b-oshyb45lmilunif-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6, 15b-oshyb45lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6
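A minimal sketch of the two placement schemes, under assumed semantics (not the actual selection code): layer importance keeps the k highest-scoring layers as attention, while uniform binning splits the depth into k bins and keeps one attention layer per bin, which is roughly what "similar to Zebra" suggests here.

```python
from typing import Sequence

def keep_by_importance(importance: Sequence[float], num_attention: int) -> list[int]:
    """Keep the num_attention layers with the highest importance scores as attention."""
    ranked = sorted(range(len(importance)), key=lambda i: importance[i], reverse=True)
    return sorted(ranked[:num_attention])

def keep_uniform_binned(num_layers: int, num_attention: int) -> list[int]:
    """Place one attention layer per bin, spread roughly uniformly over the depth."""
    bin_size = num_layers / num_attention
    return [int(i * bin_size + bin_size / 2) for i in range(num_attention)]

# e.g. for a 45/50 hybrid (5 attention layers kept out of 50):
# keep_uniform_binned(50, 5) -> [5, 15, 25, 35, 45]
```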
Observations:
  • 45/50 MIL init with uniform transformer layer distribution with binning (similar to Zebra) seems much better than 45/50 MIL init w/ layer importance (standard)
  • we see loss spikes from which the runs do not recover, most likely because the spikes are too large; they happen at the same place in both runs! Most likely this has to do with the learning rate being too large?
    • try a lower base lr?
    • maybe fewer steps / a lower final lr?


[Plot: run set of 5 runs]


TP version of reverse KL loss

  • we compare no TP (`15b-notp-bs128-lr0.0003-lrs0-0-0-0-sl2048_ti5000_lm6`), sequence tensor parallel, where we split only the sequence dim -- yellow (`15b-stp-bs128-lr0.0003-lrs0-0-0-0-sl2048_ti5000_lm6`), and TP, where we split the vocab dim (`15b-tp-bs128-lr0.0003-lrs0-0-0-0-sl2048_ti5000_lm6`)
    • sequence tensor parallel: we set model.base_model.parallel_embeddings = false and sequence_tensor_parallel=true
    • TP (split vocab dim): we set model.base_model.parallel_embeddings = true and sequence_tensor_parallel=true
  • this was obtained using the mamba2 branch, commit f7f30e13677c6aeae2696190f537615c7da81df6 (same as af2961e68cc62a221d697e1b6bf28d84e920d016)
  • Observations:
    • STP and no TP are essentially identical (differences look like numerical noise) in both loss and gradient norm
    • TP is very different; it is not clear why, since it was the same in the debugging runs (a sketch of how the vocab-parallel loss has to be assembled follows)
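For context on where the TP (vocab-split) path can diverge, a sketch of how a reverse-KL loss has to be assembled when the vocab dimension is sharded across ranks: the softmax normalizer and the final reduction need all-reduces, which is where extra numerical differences (or bugs) can creep in. Function names and shapes are assumptions, not the actual implementation.

```python
import torch
import torch.distributed as dist

def sharded_log_softmax(logits_shard: torch.Tensor, group=None) -> torch.Tensor:
    """log_softmax over a vocab dimension split across TP ranks.

    logits_shard: [batch, seq, vocab_shard] -- this rank's slice of the vocab.
    Requires an initialized process group.
    """
    # Global max for numerical stability.
    local_max = logits_shard.max(dim=-1, keepdim=True).values
    dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=group)
    shifted = logits_shard - local_max
    # Global normalizer: sum of exp over the full vocab.
    local_sum = shifted.exp().sum(dim=-1, keepdim=True)
    dist.all_reduce(local_sum, op=dist.ReduceOp.SUM, group=group)
    return shifted - local_sum.log()

def reverse_kl_tp(student_shard: torch.Tensor, teacher_shard: torch.Tensor, group=None) -> torch.Tensor:
    """KL(student || teacher) with vocab-sharded logits; summed across ranks."""
    log_q = sharded_log_softmax(student_shard, group)
    log_p = sharded_log_softmax(teacher_shard.detach(), group)
    local_kl = (log_q.exp() * (log_q - log_p)).sum(dim=-1).mean()
    dist.all_reduce(local_kl, op=dist.ReduceOp.SUM, group=group)
    return local_kl
```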



[Plot: run set of 3 runs]