Abstract

Recurrent Neural Networks (RNNs) are difficult to train on long sequences, owing to the exploding and vanishing gradient problem. Many strategies have been adopted to address this problem, such as gradient clipping, constraining model parameters, and self-attention. In this report I show how attention is leveraged to solve problems with long sequences, using RNNs augmented with self-attention. I propose that attention allows these problems to be solved by creating skip connections between later points in the sequence and earlier points. Two tasks are considered: the copy task and the denoise task.

RNNs with attention use both their hidden state and the attention module to solve these problems. I explore the strategies the networks adopt by examining attention heat maps like the one below, to gain intuition into how attention is leveraged to solve these problems and how attention mitigates the vanishing gradient problem.

[Attention heat map example]

Code can be found here; feel free to try this at home.

Introduction

Sequential Problems

Long-term sequential tasks have long been used by deep learning researchers to understand and address the common issues that arise when models are trained on sequential data, like that found in NLP. RNNs, LSTMs, and memory networks have been developed specifically to deal with these problems. In this report, we'll take a closer look at the learning dynamics of RNNs, RNNs with self-attention mechanisms, and RNNs with orthonormal recurrent weight matrices, using PyTorch.

Models

This report will focus on models which make use of recurrence, but will explore memory-augmented networks as well, as introduced in Neural Machine Translation by Jointly Learning to Align and Translate and Sparse Attentive Backtracking: Temporal Credit Assignment Through Reminding.

RNNs

RNNs are a simple model for considering sequential data. Their general mechanism of action is shown below.

[Diagram: an unrolled RNN]

For example, consider a sequential input $x_1, x_2, ..., x_T$ with $x_t \in \mathbb{R}^{n}$ and an output sequence of the same length, $y_1, y_2, ..., y_T$ with $y_t \in \mathbb{R}^{k}$. At each time step $t = 1$ to $t = T$, an RNN performs the following operations:

$h_t = f(Ux_t + Wh_{t-1} +b)$

$o_t = Vh_t +c$

$\hat{y}_t = softmax(o_t)$


where $h_t \in \mathbb{R}^m$, $W \in \mathbb{R}^{m \times m}$, $U \in \mathbb{R}^{m \times n}$, $b \in \mathbb{R}^{m}$, $V \in \mathbb{R}^{k \times m}$, $c \in \mathbb{R}^{k}$, and $f(\cdot)$ is an activation function. At each time step, the overall function maps $\mathbb{R}^{n} \rightarrow \mathbb{R}^{k}$.

These models are trained using an algorithm called backpropagation through time (BPTT), where the gradient from one time step flows to previous time steps through the hidden state $h_t$, and therefore through the recurrent weight matrix $W$.
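As a concrete reference, here is a minimal PyTorch sketch of the model defined by these equations (the class and variable names are my own illustrative choices, not code from the linked repo):

import torch
import torch.nn as nn

class VanillaRNN(nn.Module):
    # Illustrative sketch of the equations above; names are hypothetical.
    def __init__(self, n_in, n_hidden, n_out):
        super().__init__()
        self.U = nn.Linear(n_in, n_hidden)                   # U x_t + b
        self.W = nn.Linear(n_hidden, n_hidden, bias=False)   # W h_{t-1}
        self.V = nn.Linear(n_hidden, n_out)                  # V h_t + c
        self.f = torch.tanh                                  # activation f()

    def forward(self, x_t, h_prev):
        h_t = self.f(self.U(x_t) + self.W(h_prev))    # h_t = f(U x_t + W h_{t-1} + b)
        o_t = self.V(h_t)                             # o_t = V h_t + c
        y_hat = torch.softmax(o_t, dim=-1)            # y_hat_t = softmax(o_t)
        return y_hat, h_t

In practice one would feed the logits $o_t$ directly into a cross-entropy loss rather than applying the softmax explicitly, but the sketch above stays close to the equations.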

The Exploding and Vanishing Gradient in RNNs

While simple, these models suffer from a common problem called the exploding and vanishing gradient problem (EVGP), which is frequently seen with long sequences. This problem is discussed in some detail here: EVGP. I'll provide a brief explanation of the source of this problem.

Consider an RNN as defined above, and a loss at output time $\tau$ using the cross-entropy loss function, which is common in these classification tasks. It can be shown that the gradient with respect to the hidden state (critical to the flow of information through time) is proportional to a product of terms, each consisting of the transposed recurrent weight matrix $W^T$ times a diagonal matrix $D_s$ (containing the derivatives of the activation function at each hidden unit):

$\nabla_{h_t} L \propto \prod_{s=t+1}^{\tau} \left(W^T D_s\right) = \left(W^T D_{t+1}\right)\left(W^T D_{t+2}\right) \cdots \left(W^T D_{\tau-1}\right)\left(W^T D_{\tau}\right)$

For simplicity, let us consider a linear activation unit, such that $D_s = \mathbb{I}$ (the identity matrix; recall from linear algebra that $A = \mathbb{I} A$). Moreover, let us take the eigendecomposition of the recurrent weight matrix, $W = Q \Lambda Q^T$, where $Q$ is an orthonormal matrix and $\Lambda$ is diagonal with the eigenvalues of $W$ along the diagonal. Then:

$\nabla_{h_t} L \propto (Q \Lambda Q^T)(Q \Lambda Q^T) \cdots (Q \Lambda Q^T)$ ($\tau - t$ factors)

Since $Q$ is orthonormal, $Q^T Q = \mathbb{I}$, and this simplifies to:

$\nabla_{h_t} L \propto Q \Lambda^{\tau-t} Q^T$

Then when $\tau - t$ is large, we can expect this diagonal matrix ($\Lambda^{\tau - t}$) to explode for eigenvalues greater than 1 and approach zero for eigenvalues less than one.


So if all eigenvalues were exactly 1, we'd see gradient propagation which neither explodes nor vanishes (assuming that the true activation function does not cause gradients to vanish too aggressively). This is the case when the recurrent weight matrix is forced to be orthonormal, a model we define as ORNN (our implementation, expRNN, can be found here). However, this significantly limits the forms that the recurrent weight matrix can take, and reduces how expressive the RNN can be.
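To make the eigenvalue argument concrete, a small toy computation (numbers chosen purely for illustration) shows how quickly $\Lambda^{\tau-t}$ explodes or vanishes:

import torch

# Toy illustration: eigenvalues below, at, and above 1 raised to the power tau - t.
eigenvalues = torch.tensor([0.9, 1.0, 1.1])
for gap in [10, 50, 100]:   # tau - t
    print(gap, (eigenvalues ** gap).tolist())
# 10  -> [~0.35,    1.0, ~2.6]
# 50  -> [~0.005,   1.0, ~117]
# 100 -> [~2.7e-05, 1.0, ~13781]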

Another solution to the EVGP is to provide skip connections between later time steps and earlier time steps, for example using self-attention. These skip connections allow gradients to flow from time $\tau$ to time $t$ directly, if the model learns to attend from time $\tau$ to time $t$, which brings us to the topic of self-attention and memory networks.

MemRNN: RNNs with Self-Attention

RNNs can be augmented to contain a memory cell, holding past hidden states, and can use an additional neural network to create connections to these memories. We modify the RNN above as follows:

$h_{t+1} = f(Ux_{t+1} + Ws_t +b)$

where:

$s_t = g(h_t, c_t)$

where:

$c_t = \alpha_{1,t}h_1 + \alpha_{2,t} h_2 + ... + \alpha_{t,t}h_t$

where $\alpha$s are defined as:

$\alpha_{i,t} = \frac{\exp(e_{i,t})}{\sum_{j=1}^{t} \exp(e_{j,t})}$

and $e_{i,t} = a(s_{t-1}, h_i)$

$a$ here is called an alignment function and is often implemented using the aforementioned additional neural network. Additionally, $g()$ can be any of a number of functions, including addition and concatenation (we use addition here). An example of one of these networks, applied to a common sequence-to-sequence problem, is shown below:

[Diagram: an RNN with attention for a sequence-to-sequence problem]

(source: https://ift6135h18.files.wordpress.com/2018/03/10_26_memory_and_attention.pdf)

So the MemRNN works by taking a linear combination of all past hidden states to update the current hidden state, where the strength of the contribution of each past hidden state is determined by the alignment function.
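A single MemRNN attention step can be sketched as follows (an illustrative sketch of the equations above with hypothetical names, not the exact implementation used for the experiments). The `alignment` argument stands in for the alignment network $a$ and is assumed to return one score per stored hidden state:

import torch

def memrnn_step(h_t, s_prev, memory, alignment):
    # memory: list of past hidden states h_1..h_t, each of shape (batch, hidden)
    H = torch.stack(memory, dim=0)                 # (t, batch, hidden)
    e = alignment(s_prev, H)                       # e_{i,t} = a(s_{t-1}, h_i), shape (t, batch)
    alphas = torch.softmax(e, dim=0)               # alpha_{i,t}, normalized over past steps
    c_t = (alphas.unsqueeze(-1) * H).sum(dim=0)    # c_t = sum_i alpha_{i,t} h_i
    s_t = h_t + c_t                                # g(h_t, c_t) as addition, per the text
    return s_t, alphas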

Sparse Attentive Backtracking

Sparse Attentive Backtracking (SAB) is a twist on self-attention-augmented RNNs designed to make them faster while maintaining high performance. The model attends only to every $k$th hidden state, and also uses truncated backpropagation through time (TBPTT). The model additionally makes use of sparse attention, restricting attention to at most a fixed number of hidden states, set by a hyperparameter. The imposed sparsity increases training speed significantly.

[Diagram: Sparse Attentive Backtracking] (source: https://arxiv.org/pdf/1809.03702.pdf)

I adapted the code for the SAB from Rosemary Ke's github repo here.
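To get a feel for the sparsification, here is a toy sketch of attending only to every $k$th stored hidden state and keeping only the largest few attention scores (a hypothetical helper of my own, not the code from the linked repo, which also implements TBPTT and a sparse backward pass):

import torch

def sparse_attention(scores, k_att=5, k_top=3):
    # scores: alignment scores over all past hidden states, shape (t,)
    keep = torch.arange(0, scores.shape[0], k_att)      # only every k_att-th hidden state
    masked = torch.full_like(scores, float('-inf'))
    masked[keep] = scores[keep]
    k = min(k_top, keep.numel())                        # then only the k_top largest of those
    top = torch.topk(masked, k).indices
    final = torch.full_like(scores, float('-inf'))
    final[top] = masked[top]
    return torch.softmax(final, dim=0)                  # zero weight everywhere else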

Tasks

In order to dive into the learning dynamics of these models, we'll need to consider tasks with two fundamental properties: a) long sequences, where there is a large delay between a relevant input and its associated label, and b) simple tasks where the attention mechanism is leveraged in an obvious manner, so we can more easily understand what's happening. We have two candidate tasks that meet these criteria: the copy task and the denoise task.

The Copy Task

Introduced by Sepp Hochreiter and Jürgen Schmidhuber, this task feeds a sequence of characters to the model, followed by a long delay of null inputs, followed by a flag signalling that the model should begin to output the characters in order. An example input and label are shown below:

input = [1, 2, 8, 2, 3, 7, 3, 6, 8, 5, 0,0,...,0,9, 0,0,0,0,0,0,0,0,0,0]

label = [0,0,0,0,...,0,1, 2, 8, 2, 3, 7, 3, 6, 8, 5]

In this case, the 9 is the marker that tells the model to begin outputting the initial input sequence, 0 is the null input, and the characters are the numbers 1 through 8. We'll refer to each non-zero input to the model as a character, the outputs as output labels, and all $0$s input to the model as null inputs.
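For concreteness, a copy task batch of this shape can be generated with a few lines of PyTorch (a sketch; the batch size, alphabet size, sequence length, and delay are just illustrative defaults):

import torch

def copy_task_batch(batch_size=32, n_chars=8, seq_len=10, delay=100):
    # Characters are 1..n_chars, 0 is the null input, and n_chars + 1 (here 9) is the marker.
    chars = torch.randint(1, n_chars + 1, (batch_size, seq_len))
    nulls = torch.zeros(batch_size, delay, dtype=torch.long)
    marker = torch.full((batch_size, 1), n_chars + 1, dtype=torch.long)
    tail = torch.zeros(batch_size, seq_len, dtype=torch.long)
    x = torch.cat([chars, nulls, marker, tail], dim=1)    # input sequence
    y = torch.cat([torch.zeros(batch_size, seq_len + delay + 1, dtype=torch.long), chars], dim=1)
    return x, y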

The Denoise Task

The denoise task (Jing et al.) generalizes the copy task by distributing the input characters throughout the input sequence.

input = [...,0,1,0,...0, 2,...,0,...,0, 8,0,...,0, 2, 0,..., 0,3,0,...,0,7,0,...,0,3,0,...,0,6,0,...,0,8,0,...,0,5,0,...,0,9, 0,0,0,0,0,0,0,0,0,0]

label = [0,0,0,0,...,0,1, 2, 8, 2, 3, 7, 3, 6, 8, 5]

Similarly, the number 9 is a marker that signals the model to begin outputting all input characters in order. Once again, we'll refer to each non-zero input to the model as a character, the outputs as output labels, and all $0$s input to the model as null inputs.
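A corresponding denoise task generator just scatters the characters across the input portion of the sequence (again a sketch with illustrative lengths):

import torch

def denoise_task_batch(batch_size=32, n_chars=8, seq_len=10, total_len=100):
    # Characters 1..n_chars are placed at random positions before the marker (9).
    x = torch.zeros(batch_size, total_len + 1 + seq_len, dtype=torch.long)
    y = torch.zeros_like(x)
    for b in range(batch_size):
        positions = torch.sort(torch.randperm(total_len)[:seq_len]).values
        chars = torch.randint(1, n_chars + 1, (seq_len,))
        x[b, positions] = chars
        y[b, -seq_len:] = chars          # output the characters in input order
    x[:, total_len] = n_chars + 1        # marker signalling the start of output
    return x, y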

Why these Tasks?

The goal in this report is to explore the attention mechanism used with long-term dependencies, something that is fundamental to the design of these two tasks. In the case of the copy task, the dependency between a label and its associated input is consistently the same length (the length of the delay), whereas in the denoise task, this dependency length changes from example to example. The EVGP consistently prevents RNNs (and LSTMs) from performing well on these tasks, as the significant delay between the input characters and the associated output labels means there is little meaningful signal during training through BPTT.

Copy Task General Results

I show the loss and the accuracy for the best runs of each model below. It's also worth noting that these results come from an extensive hyperparameter sweep performed for each model. There has yet to be an implementation of a vanilla RNN that has solved the copy task, except those which make use of orthogonal initializations. On the left the loss for the best model of each type is shown, and on the right the associated accuracy on the predicted sequences is shown.

We can draw a few simple conclusions:

ORNNs

The copy task has been shown to have a simple solution making use of an orthogonal recurrent weight matrix. The solution relies on the consistent delay and placement of the initial input sequence. We see here that the ORNN leverages this to solve the problem very quickly. For more discussion on the topic, see Tunable Efficient Unitary Neural Networks (EUNN) and their application to RNNs.

RNNs

RNNs never learn to solve this problem. This is due to the EVGP discussed above: the loss signal associated with each label does not travel far enough through time for the model to learn a significant relationship between the label and the associated input at the beginning of the sequence.

MemRNN

The MemRNN learns to perform at 100% accuracy after around 1000 training steps. As a side note, there is a sudden drop in accuracy in the MemRNN at one point during training. This can happen when the model has a sudden shift in attention strategy and will be discussed below.

SAB

SAB slowly learns to perform well on this task. However, it takes significantly longer, and while it could eventually also reach 100% accuracy, it takes well over 30,000 training steps. It spends some time stuck performing nearly randomly before continuing to learn.


Denoise Task General Results

I show the loss and the accuracy for the best runs of each model below. On the left the loss for the best model of each type is shown, and on the right the associated accuracy in the predicted sequences is shown.

We can draw a few simple conclusions:

ORNNs

Despite training for 30,000 steps, the ORNN does not move far beyond random performance in terms of loss. While the model propagates gradients well from the output labels back to earlier time steps, it has a difficult time accounting for the random positions of the input characters in the denoise task.

RNNs

RNNs never learn to solve this problem. This is due to the EVGP discussed above: the loss signal associated with each label does not travel far enough through time for the model to learn a significant relationship between the label and the associated input earlier in the sequence.

MemRNNs

The MemRNN learns the denoise task quickly, though its loss occasionally spikes during training. This tends to occur when the loss function is not smooth, and can lead to updates moving the network out of its local minimum. However, the model quickly recovers.

SAB

SAB takes significantly longer to train than MemRNN, but does manage to learn the task quite well.


Attention Visualizations: An Introduction

Despite the fact that attention models have become ubiquitous in modern sequential data processing systems, like machine translation, we often know little about how they are leveraged to improve performance. Attention is tricky to visualize, and while one can see how it is leveraged once a model is trained, watching the mechanism develop throughout training can give us better intuition into how attention is used to solve the problem.

I will visualize attention by creating heat maps showing the strength of the connection from one hidden state to another. In these heat maps, the y-axis indicates the current time step (where the model attends from), while the x-axis indicates where the model attends to. The intensity at each point indicates the strength of the connection.

During each forward pass, I calculate $\alpha_{i,t}$ for all $i$ at each time step as a vector; note that for each $t = 1, \ldots, T$ I generate an alpha vector $\alpha \in \mathbb{R}^t$ and save each vector in a list. Note that here the alignment function takes the form:

$a(s_{t-1}, h_j) = v^T \tanh(V_a s_{t-1} + U_a h_j)$

The code that follows can be found here.

def forward(self, x_t, h_t):
    # Stack all stored hidden states h_1..h_t: (t, batch, hidden)
    all_hs = torch.stack(self.memory)
    # Project each stored hidden state: U_a h_j
    Uahs = self.Ua(all_hs)
    # Alignment scores e_{j,t} = v^T tanh(V_a s_t + U_a h_j)
    energy = torch.matmul(
        torch.tanh(self.Va(self.st).expand_as(Uahs) + Uahs),
        self.v
    ).squeeze(2)
    # Attention weights alpha_{j,t}
    alphas = self.softmax(energy)

I store each of these vectors in a list, then zero-pad all of these different-length alpha vectors to the length of the full sequence, and stack them to create the matrix data for a heat map, as shown in the code below (also in the repo here):

import torch

def construct_heatmap_data(alphas):
    # alphas: a list of attention vectors, one per time step, of increasing length
    tot_length = len(alphas)
    return torch.stack(
        [
            torch.cat(
                # zero-pad each alpha vector up to the full sequence length
                (alpha[:, 0].clone().detach(),
                 alpha.new_zeros(tot_length - alpha.shape[0])),
                dim=0)
            for alpha in alphas
        ],
        dim=0)

Note that it is the zero padding which gives these attention heat maps their characteristic triangular pattern. This could be avoided if the entire sequence is fed into the model prior to performing attention, as in seq2seq models, but the definition of the copy task and denoise task precludes this option, as each time step has an output label (hence it is not seq2seq).
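From there, rendering the heat map is straightforward; a minimal matplotlib sketch (the report's actual figures come from the repo's own plotting code) might look like this:

import matplotlib.pyplot as plt

def plot_heatmap(heatmap_data):
    # Rows: the time step attending from (y); columns: the time step attended to (x).
    plt.imshow(heatmap_data.cpu().numpy(), cmap='viridis')
    plt.xlabel('attended-to time step')
    plt.ylabel('attending-from time step')
    plt.colorbar(label='attention weight')
    plt.show()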

A video of the resulting heatmap evolution is shown below:


Copy Task: Visualized Attention

In this section we compare three distinct models: SAB, the MemRNN with a ReLu nonlinearity, and the MemRNN with a Tanh nonlinearity.

I add the variation in nonlinearity for the MemRNN since we'll see that even small decisions made when selecting your model have a big impact on how attention is used during training.

Each model adopts a different, albeit similar, strategy for solving the copy task. SAB quickly learns to focus its attention within the first 10 hidden states, where the input sequence is located. It then consistently attends to a hidden state just after the sequence until the end of the output sequence. This indicates that the model has stored the relevant input sequence in its hidden state (and not at different points in memory) and simply outputs the stored order when the marker is received.

The MemRNN with the ReLu nonlinearity learns to attend primarily to the $0$th hidden state while receiving input characters. Afterwards, it attends to the fifth hidden state until the output marker is received. Once the output marker is received, it focuses on the $(t-1)$th hidden state (note the high-intensity pixel in the lower right), then begins to slightly attend to the 7th, 8th, and 9th hidden states before attending to the final hidden states at the very end. Note that there are brief periods in this task where attention is suddenly dispersed far beyond the times when the model is receiving input characters, during the delay period. This coincides with a sudden drop in performance in the loss and accuracy plots below, shown in magenta. The similar spikes found for the MemRNN with Tanh likely have similar causes, but these are not captured due to their infrequency.

The MemRNN with the Tanh nonlinearity learns to disperse attention across all previous hidden states during the input of the initial sequence, before focusing attention on an early hidden state until the output sequence is called for. Once output begins, the model focuses on the following time steps in sequence.

All of these strategies indicate that the model stores the sequence of inputs in its hidden state during the delay period, before leveraging attention (in a surprising way) to retrieve these numbers during output. Moreover, each of these attention mechanisms provides a way for the gradient from the loss at an output label to travel with increased signal to the associated input character. Attention provides the network a method for learning a relationship between the output label and its much earlier input character by connecting the two hidden states more directly.

Starting from the beginning of training, where the attention is uniform, the loss gradient travels somewhat to all previous hidden states. Later in training the attention is sparse, but it provides strong connections to the input characters through maintained attention on the earlier hidden states. This hypothesis is difficult to convey in traditional paper journals, but the videos and interactive plots allow one to gain an intuition.


Denoise Task: Visualized Attention

In this section we compare the same three distinct models.

All models learn slightly different strategies to leverage attention to solve the denoise task. As discussed previously, SAB does not save every hidden state in memory. As such, it learns to attend to the closest stored hidden state each time a new character is input, and maintains this attention until another character is input to the model. During output, SAB attends to the previously stored hidden states, with most focus being on the most recently saved hidden state. This pattern indicates that SAB is storing the non-zero inputs it has seen in its hidden state, then retrieving them in order.

Both MemRNN networks learn similar strategies of attending to the hidden state where a non-zero number is input to the model. However, they differ in one important way: when used with a ReLu nonlinearity, the model attends to the first hidden state consistently throughout, with slight attention to each of the non-zero inputs at their initial input times, as shown by the faint vertical lines seen during training. Conversely, when used with the Tanh nonlinearity, the model maintains attention on the hidden state where the non-zero character was input, shown by the brighter vertical lines in the Tanh attention video. In both cases, the model eventually learns to attend to each of the relevant hidden states at the required output time, shown by the high-intensity dots which cascade across the bottom of each heat map. This learned strategy makes sense: it seems likely that the hidden state at the time a non-zero value is input stores the information required to output that value once it is time to output the sequence. However, maintaining attention on this hidden state, as seen in the Tanh model, is more robust and therefore less prone to sudden failure. This can be corroborated by seeing that the loss spikes in the MemRNNTanh plot are lower than the spikes in the MemRNNReLu plot.

Once again, these skip connections provide a link between the loss at an output label and its associated input character. From the initial uniform distribution of attention to the learned sparse attention, the loss signal is directly connected to its relevant input through a minimal number of hidden states and nonlinearities, thus allowing these models to solve the denoise problem better than RNNs without self-attention.


Conclusion

RNNs are powerful models for handling sequential data, but they can be difficult to train on sequences with long-term dependencies. While orthogonal RNNs can mitigate the problems associated with the EVGP, they are less expressive, since the recurrent weight matrix is restricted to the subset of all possible weight matrices that are orthogonal. This problem can be seen in the poor performance of the ORNN on the denoise task, despite gradients flowing very well in these models.

Self-attention can also mitigate these problems by providing dynamic connections that skip over long stretches of the sequence, providing a more direct connection between an input and its associated output loss. This connection allows gradients to flow through it during training, thus providing a way for the relationship between the output and input to be formed through backpropagation.

Attention has become the nuts and bolts of our current state-of-the-art NLP models, but many practitioners don't have a concrete understanding of how attention solves problems. For example, I did not expect there to be a significant difference between the attention strategies employed by the model just by changing the nonlinearity. New attention techniques are important to explore in order to find better ways of solving problems involving sequential data. The recent SAB model provides a significant speedup over the traditional MemRNN architecture: training time was roughly halved and could likely be further optimized, with the trade-off that it can take longer to reach the same performance. If we can find improved attention mechanisms, we could further improve models like GPT and BERT, but designing new attention mechanisms is easier when we understand how the current ones work.

In my next report, I'll show visualizations of gradient flow in these models to bring further intuition to how self-attention solves problems with long term dependencies.