
Under the Hood of Long Short Term Memory (LSTM)

This article explores how LSTMs work, including how to train one with NumPy, how their gradients vanish or explode, and how to visualize their connectivity.
In the previous article on recurrent neural networks (Under the hood of RNNs), we went through the training loop of a vanilla recurrent neural network (RNN). We spent much time understanding the feedforward and the backpropagation algorithm. The critical point was that we built the entire network from scratch with NumPy. The article also covered the cons of training RNNs. The short context understanding and the vanishing/exploding gradient problem were visualized as well.

I hope the reader has gone through the article on RNNs before coming here. The problem statement remains the same as in that article: a character-level text generator.
In this article, we dive deep into the working of Long Short Term Memory (LSTM).

Training the LSTM with NumPy

Data

I began with processing the input data. The input data is picked up from any .txt file provided. The file is read and the vocabulary is formed. The vocabulary is the collection of unique characters found in the text file. The immediate next step is to transform the characters into numbers, because the model needs numbers to process.
import numpy as np

vocab = sorted(set(text))  # `text` holds the raw contents of the .txt file
char_to_ix = {c: ix for ix, c in enumerate(vocab)}
ix_to_char = np.array(vocab)
text_as_int = np.array([char_to_ix[c] for c in text])
text_as_int stores the input text in the form of numbers.
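For completeness, here is a minimal, hedged sketch of how a single character id could be turned into a one-hot column vector before being fed to the network. The helper name one_hot is illustrative and not necessarily the repository's exact code.

def one_hot(ix, vocab_size):
    # encode a character id as a one-hot column vector
    x = np.zeros((vocab_size, 1))
    x[ix] = 1.0
    return x

# example: encode the first character of the text
x_0 = one_hot(text_as_int[0], len(vocab))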

Feedforward

Here is a link to the GitHub Repository. Before diving into LSTMs, let us get one thing straight so that the text below makes sense. In a vanilla RNN, the gradient flow was faulty. As a consequence, an RNN does not learn a longer context. With LSTMs, we want to counter that problem. The architecture of LSTMs provides a better way for the gradients to backpropagate. In the text below, we will put on our detective caps and head off on a journey of intuitive thinking and logic to attain a better gradient flow.

A single LSTM cell

\begin{pmatrix} f\\ i\\ g\\ o \end{pmatrix} = \begin{pmatrix} \sigma\\ \sigma\\ \tanh\\ \sigma \end{pmatrix} W^{l}\begin{pmatrix} h_{t-1}\\ x_{t} \end{pmatrix}

c_{t} = f \odot c_{t-1} + i \odot g\\ h_{t} = o \odot \tanh c_{t}

When I started with LSTMs, the image and formulae were not intuitive to me at all. If you feel lost here, I can assure you things will ease up as we go a little further. At this juncture, I would like to point out that an LSTM cell has two recurrence states: c_{t}, the memory highway, and h_{t}, the hidden state representation.
During feedforward, the inputs h_{t-1} and x_{t} are concatenated into a single matrix z_{t} for simplicity and efficiency of computation. This input z_{t} is then passed through multiple gates. The functions f, i, g, and o are termed the gates of the LSTM architecture. They provide the intuition of how much of a particular piece of data needs to travel through to make a better representation. A gate in the LSTM architecture is a multilayer perceptron with a non-linearity function. The choice of the non-linearity function will make sense as we study the different gates in depth. A full feedforward sketch in NumPy follows the gate descriptions below.

The Gates

  • Forget gate f: This gate is concerned with how much to forget. It takes in the input z_{t} and then decides how much of the previous memory state c_{t-1} should be forgotten. The activation non-linearity is \sigma. This means that the gate's output is in the range 0 to 1: 0 means to forget everything, while 1 means to remember everything. This gate acts as a switch for the memory state circuit.
    f = \sigma W_{f}\begin{pmatrix} h_{t-1}\\ x_{t} \end{pmatrix}
    
  • After forgetting, we have the amount of memory state that we need from the previous step:
    c_{t\_f} = f \odot c_{t-1}
    
  • Input gate i: This gate is used to decide how much of the present input needs to flow. This acts as a switch for the present input circuit. This gate also uses a \sigma non-linearity function.
    i = \sigma W_{i}\begin{pmatrix} h_{t-1}\\ x_{t} \end{pmatrix}
    
  • Gate gate g: This gate closely resembles the recurrence formula of a vanilla RNN. We can say that this gate is the hidden state of the RNN in an LSTM. The resemblance to the RNN formula intensifies upon noticing the non-linearity function. This gate is the only one that uses a \tanh function.
    g = \tanh W_{g}\begin{pmatrix} h_{t-1}\\ x_{t} \end{pmatrix}
    
  • The usage of the input gate and the gate gate will make sense now. The input gate behaves like a switch to the output of the gate gate:
    h_{t\_i} = i \odot g
    
  • Upon pointwise addition of c_{t\_f} and h_{t\_i}, we get the present memory state. The memory state not only holds the past and present information but also holds a definite amount of both, which makes a better representation possible.
    c^{l}_{t} = c_{t\_f} + h_{t\_i}
    
  • Output gate o: This gate is responsible for deciding how much output will flow into making the present hidden state.
    o = \sigma W_{o}\begin{pmatrix} h_{t-1}\\ x_{t} \end{pmatrix}
    
  • Let us pass the memory state through a \tanh first.
    c_{t\_o} = \tanh c_{t}
    
  • Then c_{t\_o} needs to be elementwise multiplied with the output gate to evaluate how much of c_{t\_o} should become a part of the hidden state.
    h_{t} = o \odot c_{t\_o}
    
Feed forward for LSTM
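Putting the gates together, here is a minimal NumPy sketch of a single feedforward step. The function name lstm_step, the separate weight matrices Wf, Wi, Wg, Wo, and the omission of bias terms are simplifications for illustration, not the exact code from the repository.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, Wf, Wi, Wg, Wo):
    # one LSTM time step; biases are omitted for brevity
    z_t = np.vstack((h_prev, x_t))      # concatenate h_{t-1} and x_t
    f = sigmoid(np.matmul(Wf, z_t))     # forget gate
    i = sigmoid(np.matmul(Wi, z_t))     # input gate
    g = np.tanh(np.matmul(Wg, z_t))     # gate gate (candidate memory)
    o = sigmoid(np.matmul(Wo, z_t))     # output gate
    c_t = f * c_prev + i * g            # memory state
    h_t = o * np.tanh(c_t)              # hidden state
    return h_t, c_t, (z_t, f, i, g, o)  # cache the intermediates for backprop

The cached intermediates are exactly the values that show up again in the backpropagation section below.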

Loss Formulation

After projecting the final hidden state h_{t}, we have the un-normalized log probabilities for each of the characters in the vocabulary. These un-normalized log probabilities are the elements of y_{t}.
p_{k} = \frac{e^{y_{k}}}{\sum_{j} e^{y_{j}}}

Here p_{k} is the normalized probability of the correct class k. We then apply a negative \log on this and get the softmax loss of y_{t}.
\boxed{\mathcal{L}_{t} = -\log p_{k}}

We take this loss and back-propagate through the network.
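As a quick illustration, here is a hedged NumPy sketch of the projection and the softmax loss. Wy and hs[final] follow the notation of the snippets below; target_ix, the index of the correct next character, is an assumed name.

y = np.matmul(Wy, hs[final])        # un-normalized log probabilities
p = np.exp(y) / np.sum(np.exp(y))   # softmax: normalized probabilities
loss = -np.log(p[target_ix])        # negative log probability of the correct class
dy = np.copy(p)
dy[target_ix] -= 1                  # gradient of the loss wrt y (softmax gradient)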

Backpropagation

I will assume the reader knows how to compute the gradients of a softmax function; if not, refer to the previous article on RNNs.
Visualization of the backpropagation in LSTM
In this stage, we have to back-propagate the softmax loss. We will walk hand in hand through the reversed time steps and watch the gradients flow.

Back Propagation

Link to the GitHub repository

  • The gradient of the loss \mathcal{L} wrt W_{y}: We have projected the last hidden state h_{final} and computed the softmax loss. This means that the weight matrix W_{y} receives gradients only at the final time step.
    y = W_{y} h_{final}\\ \frac{\partial y}{\partial W_{y}} = h_{final}
    
  • The gradient:
    \frac{\partial \mathcal{L}}{\partial W_{y}} = \frac{\partial \mathcal{L}}{\partial y}\frac{\partial y}{\partial W_{y}}\\ \boxed{\frac{\partial \mathcal{L}}{\partial W_{y}} = \frac{\partial \mathcal{L}}{\partial y} h_{final}}
    
dWy = np.matmul(dy, hs[final].T)
dby = dy
  • The gradient of the loss \mathcal{L} wrt the present hidden state h_{t}: Here we have to take two things into consideration. The final hidden state h_{final} has gradients flowing from the projection head, while all hidden states other than h_{final} have gradients flowing from the next raw hidden state h_{raw\_t+1}.
    y = W_{y} h_{final}\\ \frac{\partial y_{t}}{\partial h_{final}} = W_{y}
    
  • The gradient:
    \boxed{\frac{\partial \mathcal{L}}{\partial h_{final}} = \frac{\partial \mathcal{L}}{\partial y} W_{y} + \partial h_{next}}
    
  • On the final time step, \partial h_{next} is taken to be all zeros.
# dhnext is all zeros
dh[final] = np.matmul(Wy.T, dy)+dhnext
For every other time step, the gradient of the loss wrt the hidden state h_{t} is simply the gradient flowing in from the next time step:
\frac{\partial \mathcal{L}}{\partial h_{t}} = \partial h_{next}

dh[t] = dhnext
  • The gradient of the loss \mathcal{L} wrt the memory state c_{t}: Here we need to consider the upstream gradients that are flowing from the time step t+1. This is added to the present gradient that is computed from the gradient of the hidden state \partial h_{t}.
    h_{t} = o_{t} \odot \tanh c_{t}\\ \frac{\partial h_{t}}{\partial c_{t}} = o_{t}\frac{\partial \tanh c_{t}}{\partial c_{t}}\\ \frac{\partial h_{t}}{\partial c_{t}} = o_{t}\left( 1-\tanh^{2} c_{t}\right)
    
  • The gradient:
    \frac{\partial \mathcal{L}}{\partial c_{t}} = \frac{\partial \mathcal{L}}{\partial h_{t}}\frac{\partial h_{t}}{\partial c_{t}} + \partial c_{next}\\ \boxed{\frac{\partial \mathcal{L}}{\partial c_{t}} = \frac{\partial \mathcal{L}}{\partial h_{t}} o_{t}\left( 1-\tanh^{2} c_{t}\right) + \partial c_{next}}
    
dc = dc_next                            # upstream gradient flowing from time step t+1
dc += dh * o[t] * dtanh(tanh(cs[t]))    # dtanh is assumed to compute 1 - x**2, i.e. 1 - tanh(c_t)**2

Backproping the gates

The complicated gradients have been derived above. The gradients for the gates are comparatively easier, so in the list below I will not derive each equation in full. Feel free to derive them in your own time; a combined NumPy sketch of these gate gradients follows the list.
  • The gradient of the loss \mathcal{L} wrt the output gate o_{t}:
    \boxed{\frac{\partial \mathcal{L}}{\partial o_{t}} = \frac{\partial \mathcal{L}}{\partial h_{t}} \tanh c_{t}}
    
  • The gradient of the loss \mathcal{L} wrt o^{'}_{t}, the pre-activation of the output gate (before the \sigma):
    \boxed{\frac{\partial \mathcal{L}}{\partial o^{'}_{t}} = \frac{\partial \mathcal{L}}{\partial o_{t}} o_{t}( 1-o_{t})}
    
  • The gradient of the loss \mathcal{L} wrt the gate gate g_{t}:
    \boxed{\frac{\partial \mathcal{L}}{\partial g_{t}} = \frac{\partial \mathcal{L}}{\partial c_{t}} i_{t}}
    
  • The gradient of the loss \mathcal{L} wrt g^{'}_{t}:
    \boxed{\frac{\partial \mathcal{L}}{\partial g^{'}_{t}} = \frac{\partial \mathcal{L}}{\partial g_{t}}\left( 1-g^{2}_{t}\right)}
    
  • The gradient of the loss \mathcal{L} wrt the input gate i_{t}:
    \boxed{\frac{\partial \mathcal{L}}{\partial i_{t}} = \frac{\partial \mathcal{L}}{\partial c_{t}} g_{t}}
    
  • The gradient of the loss \mathcal{L} wrt i^{'}_{t}:
    \boxed{\frac{\partial \mathcal{L}}{\partial i^{'}_{t}} = \frac{\partial \mathcal{L}}{\partial i_{t}} i_{t}( 1-i_{t})}
    
  • The gradient of the loss \mathcal{L} wrt the forget gate f_{t}:
    \boxed{\frac{\partial \mathcal{L}}{\partial f_{t}} = \frac{\partial \mathcal{L}}{\partial c_{t}} c_{t-1}}
    
  • The gradient of the loss \mathcal{L} wrt f^{'}_{t}:
    \boxed{\frac{\partial \mathcal{L}}{\partial f^{'}_{t}} = \frac{\partial \mathcal{L}}{\partial f_{t}} f_{t}( 1-f_{t})}
    
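Here is the combined NumPy sketch promised above: the gate gradients at a single time step. It assumes the cached feedforward values f, i, g, o, cs (memory states) and the gradients dh, dc computed earlier; the names ending in _raw denote the pre-activation gradients and are my own naming, not necessarily the repository's.

do = dh * np.tanh(cs[t])          # gradient wrt the output gate o_t
do_raw = do * o[t] * (1 - o[t])   # back through the sigmoid
dg = dc * i[t]                    # gradient wrt the gate gate g_t
dg_raw = dg * (1 - g[t] ** 2)     # back through the tanh
di = dc * g[t]                    # gradient wrt the input gate i_t
di_raw = di * i[t] * (1 - i[t])   # back through the sigmoid
df = dc * cs[t - 1]               # gradient wrt the forget gate f_t
df_raw = df * f[t] * (1 - f[t])   # back through the sigmoid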

Backproping the weights of individual gates

Because all the gates are nothing but multilayer perceptrons, they will have gradients for their respective weight matrices. This section shows the gradients of the individual gate weights; a short NumPy sketch follows the list.
  • The gradient of the loss \mathcal{L} wrt the weight of the output gate W_{o}:
    \boxed{\frac{\partial \mathcal{L}}{\partial W_{o}} = \frac{\partial \mathcal{L}}{\partial o^{'}_{t}} z_{t}}
    
  • The gradient of the loss \mathcal{L} wrt the weight of the gate gate W_{g}:
    \boxed{\frac{\partial \mathcal{L}}{\partial W_{g}} = \frac{\partial \mathcal{L}}{\partial g^{'}_{t}} z_{t}}
    
  • The gradient of the loss \mathcal{L} wrt the weight of the input gate W_{i}:
    \boxed{\frac{\partial \mathcal{L}}{\partial W_{i}} = \frac{\partial \mathcal{L}}{\partial i^{'}_{t}} z_{t}}
    
  • The gradient of the loss \mathcal{L} wrt the weight of the forget gate W_{f}:
    \boxed{\frac{\partial \mathcal{L}}{\partial W_{f}} = \frac{\partial \mathcal{L}}{\partial f^{'}_{t}} z_{t}}
    
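A short sketch of these weight gradients in NumPy: each gate weight accumulates the outer product of its pre-activation gradient with the concatenated input z_t (stored here as zs[t]). The accumulation across time steps is an assumption about how the full backward pass sums the per-step gradients.

dWo += np.matmul(do_raw, zs[t].T)   # output gate weights
dWg += np.matmul(dg_raw, zs[t].T)   # gate gate weights
dWi += np.matmul(di_raw, zs[t].T)   # input gate weights
dWf += np.matmul(df_raw, zs[t].T)   # forget gate weights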

Final gradient of input

In this section, we calculate the gradient of the loss with respect to the input. On close observation, we will conclude that this gradient holds the concatenated gradients of h_{t-1} and x_{t}; a short NumPy sketch follows the equation below.
  • The gradient of the loss \mathcal{L} wrt the input z_{t}:
    \boxed{\frac{\partial \mathcal{L}}{\partial z_{t}} = \frac{\partial \mathcal{L}}{\partial f^{'}_{t}} W_{f} + \frac{\partial \mathcal{L}}{\partial g^{'}_{t}} W_{g} + \frac{\partial \mathcal{L}}{\partial i^{'}_{t}} W_{i} + \frac{\partial \mathcal{L}}{\partial o^{'}_{t}} W_{o}}
    
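And a sketch of this input gradient in NumPy: dz collects the contributions from all four gates, and its top rows (the ones corresponding to h_{t-1}) become dhnext for the previous time step. The hidden_size slicing assumes h_{t-1} was stacked on top of x_t when forming z_t.

dz = (np.matmul(Wf.T, df_raw) + np.matmul(Wg.T, dg_raw)
      + np.matmul(Wi.T, di_raw) + np.matmul(Wo.T, do_raw))
dhnext = dz[:hidden_size, :]   # the part of dz that flows back into h_{t-1}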





Vanish/Explode the Gradient

Here we see why LSTMs are preferred over RNNs.
Backprop in RNN

As we had discussed earlier, the problem with the backpropagation in the vanilla RNN is the tanh\tanh non-linearity and the repeated multiplication of the weight matrix WW. This often leads to the vanishing or the exploding gradient problem. The point was further proved by looking at the histograms of gradients of the RNN at various stages of training.


Backprop in LSTM
Upon looking at the picture above, we notice one thing: the gradient of the memory state \partial c_{t} flows without much perturbation. Apart from the element-wise multiplication with the forget gate, the gradient \partial c_{t} flows freely along the circuit provided for the memory state. This is the reason the memory state is involved in the architecture in the first place. The name gradient highway now makes sense, doesn't it? The architecture has a better flow of gradients than the vanilla RNN.
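As a tiny illustration in the notation of the snippets above, the only thing standing between \partial c_{t} and the previous time step is an element-wise multiplication with the forget gate, so the memory-state gradient is passed back with no weight matrix multiplication and no saturating non-linearity:

dc_next = f[t] * dc   # gradient handed to the memory state of the previous step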
To prove the point, I have created GIFs of histograms. There are two rows: one for the gradient of the hidden state \partial h_{t} and the other for the gradient of the memory state \partial c_{t}. The histograms are taken at each time step, from time-step 25 down to time-step 0. To make the visuals more concrete, they are recorded at different epochs while the model trains: epochs 0, 10, 20, 30, and 40.



Some observations to see here:
  • At the very beginning, both the gradients vanish, which is due to the weight initialization.
  • Later when the model trains, the vanishing problem goes away. This means that the weights are tuned so that the gradients are backpropagated properly.
  • The gradient of the memory state does not vanish, as expected.

Visualize the Connectivity

In the article Visualizing memorization in RNNs, the authors propose a great tool to visualize the contextual understanding of a sequence model: the connectivity between the desired output and all of the inputs. This means the visualization shows which inputs are responsible for producing the desired output.
To visualize the connectivity, the first step is to look at the heat map colors. The heat map I have chosen is shown below. Cold connectivity (not so connected) is transparent and gradually moves from light to dark blue. Hot connectivity (a strong connection) is colored red.
The heatmap colors with their intensities of connection
For this experiment, I have chosen sequences of varying lengths and tried inferring the immediate next character. The inferred character is colored green. The rest of the sequence is colored according to the chosen heat map.
  • Sequence length 20:
    
RNN
LSTM

  • Sequence length 40:
    
RNN
LSTM
  • Sequence length 100:
    
RNN
LSTM
The connectivity shows quite evidently that LSTMs can pick up on long contexts. The reason they can and RNNs cannot lies in the better backpropagation of the loss.

Discussion

This was a tough project for me because LSTMs are not that easy to code from scratch. I purposely did not pursue an analysis of the gates, which might be taken care of in a future article. I would be more than happy to handle doubts about my code and article; please feel free to comment below.
As a final note, I would like to thank Kyle Goyette for his valuable feedback and suggestions.
Reach out to me - @ariG23498. 