Parameter sharing, revisited (again)
TL;DR: we evaluate alternatives to classical parameter sharing against tuned LM baselines. Classical sharing fails; the alternatives perform decently, but not spectacularly.
What is parameter sharing?
If you want to train a large transformer with limited memory or network speed, you can sometimes get away with layer-wise sharing: reusing the same set of parameters across multiple model layers.
This trick works well in several NLP and vision tasks (Lan et al., 2019; Xue et al., 2021): the shared models outperform their non-shared variants with an equal number of parameters, though at the cost of more compute. As a typical example, ALBERT-large uses parameter sharing to outperform BERT-base while being roughly 5 times smaller parameter-wise.
Applied to language models, this would let one train a "shared GPT-3" with only ~2.5B parameters that fits into most GPU setups. However, a recent report by bkkaggle hints that parameter sharing performs poorly for language models.
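For illustration, here is a minimal PyTorch sketch of the basic trick: allocate one transformer layer and apply it many times. This is our own toy example (with made-up names, and no causal masking), not the code used in the experiments below.

import torch.nn as nn

class SharedTransformerLM(nn.Module):
    """Toy ALBERT-style sharing: one layer's parameters, applied num_repeats times."""

    def __init__(self, vocab_size=50304, hidden_size=1024, num_repeats=24):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        # only ONE set of layer parameters is allocated ...
        self.shared_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size, nhead=16, dim_feedforward=4 * hidden_size,
            activation="gelu", batch_first=True)
        self.num_repeats = num_repeats
        self.head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.head.weight = self.embed.weight  # tied input/output embeddings

    def forward(self, token_ids):
        x = self.embed(token_ids)
        # ... but it is applied num_repeats times, so the compute cost stays
        # the same as in a non-shared model of the same depth
        for _ in range(self.num_repeats):
            x = self.shared_layer(x)
        return self.head(x)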
Chapter 1: alternative sharing strategies
In this report, we evaluate ALBERT alongside more advanced parameter sharing techniques invented in subsequent research.
- Classic (ALBERT-like) sharing: repeat the entire transformer layer a given number of times. If there are multiple unique layers, apply all copies of layer 1 first, then all copies of layer 2, 3, etc.
- Cyclic (interleaved) sharing: apply layer 1 once, then layer 2, 3, ..., then layer 1 again, 2 again, etc. This ordering was proposed by Takase et al. (2021); see the sketch after this list.
- Share matrices only: let each layer learn personalized LayerNorm scales, biases and other "insignificant" parameters, as proposed by Xue et al. (2021).
- Share with adapters: like the previous strategy, but also keep a personalized low-rank adapter for each replica of the same weight matrix. Inspired by Hu et al. (2021).
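To make the difference between the first two orderings concrete, here is a small sketch (our own illustration with made-up names, not the actual training code):

from torch import nn

def build_layer_schedule(unique_layers: nn.ModuleList, repeats: int, cyclic: bool):
    """Return the order in which the shared layers are applied.
    - sequential / ALBERT-like: [L1]*repeats + [L2]*repeats + [L3]*repeats
    - cyclic / interleaved (Takase et al., 2021): [L1, L2, L3] * repeats
    """
    if cyclic:
        return [layer for _ in range(repeats) for layer in unique_layers]
    return [layer for layer in unique_layers for _ in range(repeats)]

# e.g. 3 unique layers repeated 8 times each => 24 layer applications in total
layers = nn.ModuleList([nn.Identity(), nn.Identity(), nn.Identity()])
assert len(build_layer_schedule(layers, repeats=8, cyclic=True)) == 24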
 
 
Experiment setup
- data: OpenWebText, preprocessed according to the instructions in the data cooking log;
- model: an equivalent of GPT-3 (model configs from Table 2.1) + rotary embeddings (Su et al., 2021); training config from the GPT-3 paper, except for 50k total training steps;
- the baseline with 1024 hidden units / 24 layers (GPT3-Medium) gets validation perplexity 14.5, which is already stronger than most papers report; it can be improved to 14.2 using 2x the training steps;
- detailed config, training scripts and library versions are at the end of the report.
Experiments
Baseline runs:
- baseline-768-tied: GPT-small with rotary embeddings, weak baseline;
- baseline-1024-tied: GPT-medium with rotary embeddings, main reference;
- shared-equivalent-params-baseline: baseline-1024-tied, but with 16 layers instead of 24, to match the parameter count of the largest shared model; a lower bound for shared-model quality;
- baseline-2048-tied: GPT-XL with rotary embeddings; an upper bound for quality.
Main sharing runs, each using 3 unique sets of weights in cyclic/interleaved order (unless stated otherwise):
- shared-2048-tied-3matrix-albert: share entire transformer layers (like ALBERT), 3 unique layers repeated 8 times each.
- shared-2048-tied-3matrix-only-matrices: like the previous run, but each layer trains independent LayerNorm scales and biases; <1% extra parameters.
- shared-2048-tied-3matrix-adapter-32: cyclic sharing with non-shared LayerNorm scales, biases and 32-dimensional adapters; ~25% more parameters than classic ALBERT sharing.
- shared-2048-tied-3matrix-adapter32-outer: like the best sharing run, but uses the traditional layer order instead of cyclic (i.e. 8x layer 1, then 8x layer 2, then 8x layer 3); stopped early due to poor performance.
- shared-2048-tied-3matrix-adapter21-scalable: add a multiplicative component to each low-rank adapter. Project the input to rank 21, then apply two output projections A and B in parallel: y = A(x) * (Wx) + B(x), where * is element-wise multiplication. A(x) and B(x) use the same input-to-lowrank projection L but separate output projections, so the full formula is y = (A L x) * (W x) + B L x. Same number of parameters as the other adapter experiments; see the sketch after this list.
- shared-2048-tied-3matrix-adapter32-flat-scale: instead of multiplying by input-adaptive adapters, use additive adapters and multiply by a flat learned vector.
- baseline-1024-tied-23l-adapter32: an ablation for multiplicative adapters: no sharing, just a regular GPT2-medium, but with multiplicative adapters in each layer. Has 23 layers instead of 24 to match the number of parameters.
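Below is a rough PyTorch sketch of the multiplicative ("scalable") adapter from shared-2048-tied-3matrix-adapter21-scalable, as we understand it from the formula above. Class names and initialization are ours; the actual LeanTransformer implementation may differ.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiplicativeAdapterLinear(nn.Module):
    """Sketch of y = A(x) * (Wx) + B(x): A(x) = A_out(Lx) and B(x) = B_out(Lx)
    share the down-projection L (rank = adapter_dim) but use separate
    up-projections. W is the large matrix that may be shared across replicas,
    while L, A_out and B_out stay personalized per replica. Initialization
    details are not given in the report; the ones below are guesses."""

    def __init__(self, in_features: int, out_features: int, adapter_dim: int = 21):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * in_features ** -0.5)  # W
        self.lowrank_in = nn.Linear(in_features, adapter_dim, bias=False)  # L
        self.scale_out = nn.Linear(adapter_dim, out_features, bias=False)  # A_out
        self.shift_out = nn.Linear(adapter_dim, out_features, bias=False)  # B_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = F.linear(x, self.weight)  # W x
        low = self.lowrank_in(x)         # L x, shared between A(x) and B(x)
        return self.scale_out(low) * base + self.shift_out(low)

In the shared runs, self.weight would point to one of the 3 shared matrices, while the adapter projections remain specific to each replica.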
Experiment results:
Findings (chapter 1):
- Naive (ALBERT-like) sharing indeed falls short of the non-shared baseline; this agrees with bkkaggle's report.
- Cyclic/interleaved sharing is indeed better than sequential sharing; this extends the original findings of Takase et al. (2021).
- Training personalized LayerNorm scales and biases is a nearly free boost to quality; this generalizes the Xue et al. (2021) results from vision to language modeling.
- Adding low-rank adapters to each replica improves quality at the cost of a few extra parameters.
- Multiplicative adapters (y = A(x) * (Wx) + B(x)) yield the best performance so far. They show no special "chemistry" with sharing: the gains are the same for shared and non-shared models, so multiplicative adapters appear to simply work better.
- Flat learned scales do not yield any gains, so we believe the important part is that the weight projections become input-adaptive. See the supplementary section on failed runs for more details.
Chapter 2: intra-layer sharing
Randomized intra-layer sharing. So, here's the idea (source: TimDettmers): the previous runs suggest that adapters can significantly improve shared and non-shared models alike. If adapters are so good, maybe we can rely on them alone?
Perspective 1: this time, the "base" network layers are not trained at all. They contain exclusively random weights - and the adapters learn to shape these random weights into submission. We generate uint8 weights on the fly to save memory. [Tim: we could even generate them inside the GEMM kernel for accelerated training].
These uint8 weights are combined with the following adapters:
- look-up adapters: associate each uint8 value with a learned scalar from a 256-entry look-up table. Separate look-up tables are used either for each tile or for each row of the weight matrix.
- affine adapters: same as the (best-performing) multiplicative adapters from the previous section; see shared-2048-tied-3matrix-adapter21-scalable and baseline-1024-tied-23l-adapter32.
Perspective 2: since we use a look-up table for each row/tile, we can view this as randomized sharing inside that row/tile: all weights that share uint8 code 42 will use the same value from the look-up table. In our runs, each look-up parameter is reused ~16 times on average, but due to the random nature of the uint8 generator, the actual count varies within the [5, 27] range.
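As a rough PyTorch sketch of the per-row variant (our own reconstruction; names and initialization are ours, and the actual runs also stack the affine adapters from chapter 1 on top):

import torch
import torch.nn as nn
import torch.nn.functional as F

class RandomLookupLinear(nn.Module):
    """Sketch of randomized intra-layer sharing with per-row look-up tables.
    The base weight is a fixed random uint8 code matrix (regenerated from a
    seed, never trained or stored); each output row maps its 256 possible
    codes through its own learned 256-entry table."""

    def __init__(self, in_features: int, out_features: int, seed: int = 0):
        super().__init__()
        self.in_features, self.out_features, self.seed = in_features, out_features, seed
        # one 256-entry look-up table per output row (the "per-row" variant)
        self.lookup = nn.Parameter(torch.randn(out_features, 256) * in_features ** -0.5)

    def materialize_weight(self) -> torch.Tensor:
        # deterministic codes from the layer seed; in practice these could be
        # generated tile-by-tile inside the GEMM kernel instead of stored
        device = self.lookup.device
        gen = torch.Generator(device=device).manual_seed(self.seed)
        codes = torch.randint(0, 256, (self.out_features, self.in_features),
                              generator=gen, device=device, dtype=torch.long)
        # weight[i, j] = lookup[i, codes[i, j]]
        return torch.gather(self.lookup, 1, codes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.materialize_weight().to(x.dtype))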
Experiments (chapter 2)
We use the same experiment setup as before.
Baselines:
- baseline-1024-tied: GPT2-medium, 24 layers with 1024 hidden units, 16 heads per layer
- baseline-768-tied: GPT2-small, 12 layers with 768 hidden units, 12 heads per layer
- shared-equivalent-params-baseline: baseline-1024-tied, but with 16 layers instead of 24; carried over from chapter 1
- shared-2048-tied-3matrix-albert: carried over from chapter 1
- shared-2048-tied-3matrix-adapter21-scalable: carried over from chapter 1
Randomized sharing runs:
- random-2048-grid-and-tiles: random uint8 matrix + tile-wise look-up tables (256 values for each 64x64 block of the weight matrix) + the A(x) * (Wx) + B(x) adapters.
- random-2048-grid-and-rows: same as above, but with look-up tables per row instead of per tile. Each output neuron gets 176 look-up values; the number of look-up values was reduced to match the number of parameters.
Findings (chapter 2):
- Procedural weight generation, aka randomized sharing, performs better than the baseline, but worse than the best layer-wise sharing strategies; it is about as good as ALBERT-style sharing.
- Per-row and per-tile look-up tables perform approximately equally, with per-row marginally better. Per-row sharing also has computational advantages (see below).
Why ever use this method? So far we know that randomized sharing performs worse than sharing with A(x) * (Wx) + B(x) adapters. However, this sharing strategy makes it possible to write efficient training & inference code:
- Efficient training on GPU: generate matrix tiles inside fast shared memory (instead of global memory) to better utilize tensor cores.
- Efficient offloading: regular layer-wise sharing does not mix well with offloading: it needs ~100% of the parameters in RAM since they are all used by every layer. In contrast, randomized sharing allows loading one layer at a time, or, more specifically, one set of look-up tables at a time (see the sketch below).
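To make the offloading point concrete, here is a hedged inference-time sketch that reuses the hypothetical RandomLookupLinear from the chapter 2 sketch above: only the small per-layer look-up tables move between CPU and GPU, while the random code matrix is regenerated on the device from the layer seed.

import torch

@torch.no_grad()
def run_layers_with_offloading(x: torch.Tensor, layers) -> torch.Tensor:
    """Inference sketch: `layers` is a list of RandomLookupLinear modules
    (see the chapter 2 sketch) kept in CPU RAM. For each layer, only its
    256-entry-per-row look-up table is copied to the GPU; the uint8 code
    matrix is regenerated from the layer seed on the device, so the full
    weight matrix is never stored in CPU RAM or sent over PCIe."""
    for layer in layers:
        layer.to(x.device)   # moves only the small look-up table
        x = layer(x)         # base codes are regenerated on the GPU
        layer.to("cpu")      # send the look-up table back to CPU RAM
    return x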
Disclaimer: the ideas for efficient computation (as well as randomized sharing in general) were invented by TimDettmers - and then summarized by a hedgehog to the best of its ability. If what you're reading doesn't make sense, go ask Tim for a more technical explanation.
Supplementary materials
Failed experiment branches
This section contains a number of studies that all ended with "and then it performed poorly". You can find the corresponding runs at https://wandb.ai/learning-at-home/LM_OWT
- baseline-1024-tied-23l-adapter32-attn-only: since we know that adapters help both shared and non-shared models, we checked whether they help more in attention or in FFN layers. We found no clear-cut answer: input-adaptive multiplication yields improvements in both attention and FFN, with attention-only adapters recovering around 40% of the full quality gain (from having adapters in both attention and FFN).
- shared-1024-tied-3matrix-adapter4-tilewise: instead of adding adapters to the full weight matrix, have separate adapters for each tile (e.g. 256x256). Massive slowdown, no improvement.
- shared-1024-tied-3matrix-tilegen: generate weights tile-by-tile using a learned 2-layer neural network applied to random inputs generated from a layer-specific seed. Massive slowdown, worse quality.
- random-2048-grid-and-tiles-adapter-rank1: randomized sharing where, instead of the A(x) * (Wx) + B(x) adapters, each tile gets its own rank-1 adapter. For each tile, the weight matrix is constructed as w[i, j] * a[i] * b[j] + c[i] * d[j]. Massive slowdown, worse quality.
- random-1024-grid-iclr20: randomized sharing that adds a sparse non-shared matrix to the shared base matrix. The sparse matrix is guaranteed to have the same number of parameters as adapter32 (64 per row). Based on ideas from "Sparse Networks from Scratch" (https://arxiv.org/abs/1907.04840). Did not improve over simple adapters, but added a layer of complexity to the training procedure.
Model config for LeanTransformer (sharing with adapters)
{
  "model_type": "lean_gpt",
  "architectures": [ "LeanGPTModel" ],
  "num_hidden_layers": 24,
  "num_hidden_groups": 24,
  "share_large_matrices": true,
  "num_inner_matrices": 3,
  "adapter_dim": 32,
  "num_inner_groups": 1,
  "hidden_size": 2048,
  "embedding_size": 2048,
  "intermediate_size": 8192,
  "num_attention_heads": 16,
  "vocab_size": 50308,
  "hidden_act": "gelu_fused",
  "position_embedding_type": "rotary",
  "tie_word_embeddings": true,
  "hidden_dropout_prob": 0,
  "attention_probs_dropout_prob": 0,
  "layer_norm_eps": 1e-12,
  "pad_token_id": 1,
  "bos_token_id": 0,
  "eos_token_id": 2
}
Environment setup script (runs on Ubuntu 16.04/18.04/20.04)
#!/usr/bin/env bash
set -euxo pipefail
############################################################################
# core libraries
############################################################################
apt-get update --allow-unauthenticated --allow-insecure-repositories
apt-get install -y --no-install-recommends \
    build-essential \
    g++ gdb subversion \
    software-properties-common
apt-get install -y --no-install-recommends \
    wget curl vim nano ssh git libssl-dev
apt-get remove -y swig || true
apt-get install -y --no-install-recommends libstdc++6
apt-get install -y --no-install-recommends swig3.0
ln -s /usr/bin/swig3.0 /usr/bin/swig
############################################################################
# install anaconda (because native python stopped working)
############################################################################
wget https://repo.anaconda.com/archive/Anaconda3-2021.11-Linux-x86_64.sh
bash Anaconda3-2021.11-Linux-x86_64.sh -b -p /anaconda3
source /anaconda3/bin/activate
############################################################################
# common python libraries (project-specific libs are installed later)
############################################################################
conda update -y conda
conda install -y python=3.8.12 --strict-channel-priority
conda install -y numpy scipy cython pandas h5py numba
pip install --upgrade setuptools
# common + devops
pip install \
    PyYAML==5.4.1 \
    Pillow==8.3.0 \
    docopt==0.6.2 \
    typer==0.3.2 \
    black==21.6b0 \
    bokeh==2.4.0dev1 \
    isort==5.9.1 \
    icecream==2.1.1 \
    flake8==3.9.2 \
    uvloop==0.15.2 \
    packaging==19.0 \
    msgpack==0.5.6 \
    sortedcontainers==2.4.0 \
    configargparse==1.2.3 \
    tqdm==4.48.2 \
    termcolor==1.0.0
# common data science libs
pip install \
    ninja==1.10.0.post1 \
    tensorboardX==2.4 \
    wandb==0.10.33 \
    matplotlib==3.4.2 \
    seaborn==0.11.1 \
    holoviews==1.14.4 \
    plotly==5.1.0 \
    jupyterlab==3.0.16
# pytorch utils
conda install -y cudatoolkit=11.3 -c pytorch
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install https://github.com/huggingface/transformers/archive/3dc82427166239e2764196c07fa4c5dcc25b1590.zip # 4.18.dev0
pip install datasets==2.0.0
pip install \
    torch_optimizer==0.1.0 \
    revlib==1.7.0 \
    bitsandbytes-cuda113==0.26.0 \
    pytorch-lightning==1.3.8 \
    triton==1.0.0 \
    einops==0.3.2 \
    libzero==0.0.5
# domain-specific ML libs
pip install \
    opencv-python==4.4.0.42 \
    albumentations==1.0.0 \
    scikit-image==0.17.2 \
    lmdb==1.2.1 \
    librosa==0.7.0 \
    sentencepiece==0.1.96 \
    nltk==3.6.2 \
    gensim==4.0.1 \
    sacrebleu==1.5.1 \
    sacremoses==0.0.45 \
    subword-nmt==0.3.7 \
    youtokentome==1.0.6
pip uninstall -y enum34
############################################################################
# Set locale
############################################################################
locale-gen ru_RU.UTF-8
update-locale
############################################################################
# Clean
############################################################################
apt-get autoremove
apt-get clean
apt-get autoclean
rm -rf /var/lib/apt/lists/*
rm -rf /tmp/*
rm -rf /.cache
rm -rf /var/cache/apt/*.bin
find /var/log -iname '*.gz' -delete
find /var/log -iname '*.1' -delete
###########################################################################
# project-specific libraries (aka YOUR CODE HERE)
###########################################################################
# hivemind dependencies
# quote the specifiers so bash does not treat ">=" as output redirection
pip install \
    "prefetch_generator>=1.0.1" \
    "grpcio>=1.33.2" \
    "grpcio-tools>=1.33.2" \
    "multiaddr>=0.0.9" \
    "pymultihash>=0.8.2" \
    "cryptography>=3.4.6" \
    "pydantic>=1.8.1" \
    whatsmyip
pip install razdel==0.5.0
# golang
wget https://golang.org/dl/go1.16.4.linux-amd64.tar.gz
rm -rf /usr/local/go && tar -C /usr/local -xzf go1.16.4.linux-amd64.tar.gz
export PATH=$PATH:/usr/local/go/bin
pip install omegaconf==2.0.5 antlr4-python3-runtime==4.8 hydra-core==1.0.7
Training script (tuned for 8 GPUs; otherwise adjust the batch size to get 2^20 tokens/step)
pip install https://github.com/learning-at-home/lean_transformer/archive/4f0c8beaec17ca35e28929132ad171485bff732b.zip
git clone https://github.com/justheuristic/junk -b fairseq fairseq
cd fairseq
python setup.py develop --prefix=.
PYTHONPATH=`pwd`:$PYTHONPATH python fairseq_cli/train.py \
    $INPUT_PATH/data-bin/openwebtext --task language_modeling --arch lean_lm --hf-model-config $SOURCE_CODE_PATH/model_config.json \
    --max-tokens 16384 --update-freq 8 --max-update 50000 --tokens-per-sample 2048 --sample-break-mode none \
    --ddp-backend pytorch_ddp --distributed-world-size $NUM_GPUS --seed 4 \
    --amp --fp16-no-flatten-grads --min-loss-scale 1e-10 --fp16-scale-window 250 \
    --lr-scheduler cosine --lr 0.0002 --warmup-init-lr 0.0 --warmup-updates 5000 \
    --optimizer adam --weight-decay 0.1 --clip-norm 1.0 --adam-betas "(0.9, 0.95)" --adam-eps 1e-08 \
    --save-dir $SNAPSHOT_PATH --save-interval-updates 1000 --keep-best-checkpoints 0 --no-epoch-checkpoints --keep-interval-updates 2 \
    --valid-subset valid,valid_1b,valid_lambada,valid_ccnews,valid_wiki,valid_wiki2,valid_ptb --validate-interval-updates 1000 \
    --log-format simple --log-interval 50 --wandb-project $WANDB_PROJECT