Methods: LSH
LSH
In normal attention we perform a matrix multiply of all keys and queries. This is the same as taking the dot product of every query with every key, and is the source of the quadratic cost of attention. Dot products are large when the vectors are well aligned. We then pass the dot products through a softmax, which makes relatively large inputs larger still. Keys and queries that are close together in our high-dimensional space will therefore end up with a high score after the softmax. If we could figure out upfront which vectors were close, we could disregard the ones that were far apart, since their softmax scores would most likely be small. It turns out that LSH (locality-sensitive hashing; Andoni et al., 2015) is a suitable algorithm for clustering high-dimensional vectors.
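As a point of reference, the following is a minimal NumPy sketch of standard (single-head) scaled dot-product attention, not the code used in our implementation; the full $l \times l$ score matrix is what gives rise to the quadratic cost.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def dot_product_attention(q, k, v):
    """q, k, v: arrays of shape (seq_len, d_k)."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)   # (seq_len, seq_len) matrix -> O(l^2) time and memory
    weights = softmax(scores)         # large dot products dominate after the softmax
    return weights @ v

# Toy example with l = 8 tokens and d_k = 4
q = k = v = np.random.randn(8, 4)
out = dot_product_attention(q, k, v)  # shape (8, 4)
```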
LSH clustering
We could do perfect clustering by comparing each key to each query, but that would still leave us with the quadratic cost. LSH instead approximates clustering probabilistically: its cost is lower, but it cannot guarantee perfect accuracy. The intuition behind LSH is described in the blog https://www.pragmatic.ml/reformer-deep-dive/ and can be summarized as follows:
1. If we draw a random vector $u$ and take the dot product with other vectors, we get positive or negative scores depending on whether the vectors are aligned with $u$ or not.
2. If we take the sign of this dot product, we have effectively created two "buckets" for our vectors: '+' and '-'
3. If we draw more lines ("hash rounds"), we create additional buckets in the same way. E.g. 3 lines would give us 2^3=8 buckets.
4. Vectors that end up in the same bucket have a high probability of being close together, and this probability goes up as the number of hash rounds increases.
Note that the reformer paper does LSH via random projections as described in (Andoni et al., 2015).
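A minimal sketch of the sign-of-dot-product bucketing described in the list above is given below. The Reformer itself uses the random-projection scheme of Andoni et al. (2015); the hyperplane-sign variant and the function name `lsh_buckets` here are our own, meant only to illustrate the intuition.

```python
import numpy as np

def lsh_buckets(vectors, n_hyperplanes, seed=0):
    """vectors: (n, d). Returns one bucket ID in [0, 2**n_hyperplanes) per vector."""
    rng = np.random.default_rng(seed)
    n, d = vectors.shape
    u = rng.standard_normal((d, n_hyperplanes))  # random directions ("lines")
    signs = (vectors @ u) > 0                    # (n, n_hyperplanes) sign pattern
    powers = 2 ** np.arange(n_hyperplanes)       # read the sign pattern as a binary number
    return signs @ powers                        # e.g. 3 hyperplanes -> buckets 0..7

vecs = np.random.randn(16, 64)
bucket_ids = lsh_buckets(vecs, n_hyperplanes=3)  # values in 0..7
```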
LSH-attention
The next step is to add LSH clustering to our new attention mechanism. We refer to figure 2 of the reformer paper, which gives a clear overview of the main steps:
- First we receive an input sequence of tokens in original order. In normal attention our tokens are projected into keys and queries, but here we set keys and queries to be identical (see chapter nn).
- We then use LSH clustering to produce a bucket ID for each item (i.e. key/query). We sort the items first by bucket ID and then by position in the original sequence.
- Group items into equal-size chunks, since the number of items in each bucket may vary.
- Concatenate each chunk with the previous chunk to allow for limited attention across chunk boundaries. In the diagram we also mask items from other buckets, so that only in-bucket attention is allowed.
- Calculate normal dot-product attention within each concatenated chunk.
- Un-chunk and unsort all items.
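To make the data flow concrete, here is a much-simplified single-round sketch of these steps (one head, shared queries/keys, no masking of cross-bucket items inside a chunk, and $l$ assumed divisible by the chunk size); it illustrates the sort/chunk/attend/unsort pipeline, not the actual Reformer implementation.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def lsh_attention(qk, v, bucket_ids, chunk_size):
    """qk: (l, d) shared queries/keys, v: (l, d) values, bucket_ids: (l,) ints."""
    l, d = qk.shape
    # 1. Sort items by bucket ID, breaking ties by original position.
    order = np.lexsort((np.arange(l), bucket_ids))
    qk_s, v_s = qk[order], v[order]

    # 2. Split the sorted sequence into equal-size chunks.
    n_chunks = l // chunk_size
    qk_c = qk_s.reshape(n_chunks, chunk_size, d)
    v_c = v_s.reshape(n_chunks, chunk_size, d)

    # 3. Each chunk attends to itself and (except the first) to the previous chunk.
    out_s = np.zeros_like(v_s)
    for i in range(n_chunks):
        keys = qk_c[i] if i == 0 else np.concatenate([qk_c[i - 1], qk_c[i]])
        vals = v_c[i] if i == 0 else np.concatenate([v_c[i - 1], v_c[i]])
        scores = qk_c[i] @ keys.T / np.sqrt(d)            # (chunk_size, <= 2*chunk_size)
        out_s[i * chunk_size:(i + 1) * chunk_size] = softmax(scores) @ vals

    # 4. Undo the sort so outputs line up with the original token order.
    out = np.empty_like(out_s)
    out[order] = out_s
    return out

# Toy example: l = 16, d = 8, 4 buckets, chunks of 4 items
qk = np.random.randn(16, 8)
v = np.random.randn(16, 8)
buckets = np.random.randint(0, 4, size=16)
out = lsh_attention(qk, v, buckets, chunk_size=4)          # shape (16, 8)
```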
LSH attention complexity
According to the reformer paper, the time and memory complexities of scaled dot-product and LSH attention are:
Attention Type | Memory Complexity | Time Complexity |
---|---|---|
Scaled Dot-Product | $\max(bn_h l d_k, bn_h l^2)$ | $\max(bn_h l d_k, bn_h l^2)$ |
LSH Attention | $\max(bn_h l d_k, bn_h l n_r (4l/n_c)^2)$ | $\max(bn_h l d_k, bn_h n_r l (4l/n_c)^2)$ |
Where:
- $l$ is the length of the input sequence
- $b$ is the batch size
- $n_h$ is the number of heads in multihead attention
- $n_c$ is the number of LSH chunks
- $n_r$ is the number of hash rounds
- $d_k$ is the model dimension
Language models typically benefit from long input sequences ($l$) to give the model a longer context (ref), so we often want to make $l$ large. This means that the cost is usually dominated by the latter part of the max in the table above, i.e. $l$ is usually larger than $d_k$.
In this case dot-product attention has a quadratic cost in both time and memory, driven by the $l^2$ term. If we want to compare LSH-attention to dot-product attention, we can compare $ln_r(4l/n_c)^2$ to $l^2$, since $b$ and $n_h$ are common factors. This means that the cost of LSH relative to normal attention is linear in $n_r$, but decreases quadratically with $n_c$. In our implementation we did not set $n_c$ directly, but rather defined a bucket size $b_s$, where $n_c = l/b_s$.
We can therefore rewrite the LSH complexity as $n_r(4b_s)^2l$, which is linearly dependent on $l$. It also makes intuitive sense that the complexity of LSH-attention grows quadratically with $b_s$, since we do normal dot-product attention within each chunk.
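For completeness, substituting $n_c = l/b_s$ into the dominant LSH term works out as:

$$
n_r\,l\left(\frac{4l}{n_c}\right)^2 = n_r\,l\left(\frac{4l}{l/b_s}\right)^2 = n_r\,l\,(4b_s)^2 = 16\,n_r\,b_s^2\,l
$$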