Explained - The Sensory Neuron as a Transformer: Permutation-Invariant Neural Networks For Reinforcement Learning
A guide to the NeurIPS 2021 spotlight paper by Google's David Ha and Yujin Tang: The Sensory Neuron as a Transformer: Permutation-Invariant Neural Networks for Reinforcement Learning.
Below we'll be taking a dive into the important paper by Google Research Scientist David Ha and Research Software Engineer Yujin Tang:
The Sensory Neuron as a Transformer: Permutation-Invariant Neural Networks for Reinforcement Learning
The paper is important for a few reasons, but you likely know that already or you wouldn't be here. After all, there's a reason that it was named a spotlight paper at NeurIPS 2021.
This article is written for readers of various skill levels in machine learning. If you are advanced and simply want to understand this paper specifically, I would suggest you jump here.
Before that point, we will cover some important concepts one should understand to grasp the work as a whole.
We Will Be Covering The Following
The Inspiration For The Paper
Important Terms And Concepts
What Is A Policy In Machine Learning?
What Is A Sensory Neuron?
What Is Permutation-Invariance?
What Is Self-Organization?
What Is Meta-Learning?
The Meat Of The Paper
Section 3.1: A Mathematical Definition Of Permutation Invariance
Section 3.2: The Magic
The Results
Cart-Pole
Atari Pong
PyBullet Ant
Car Racing
Applications For The Real World
Credit Where It's Due
Related Reading
The Inspiration For The Paper
The authors did an outstanding job of encapsulating the inspiration for their paper in the opening quote of their writeup on the Google AI Blog.
"The brain is able to use information coming from the skin as if it were coming from the eyes. We don’t see with the eyes or hear with the ears, these are just the receptors, seeing and hearing in fact goes on in the brain.” - Paul Bach-y-Rita

We are all familiar with the human brain and the unique capacities and abilities it offers, and it has long served as inspiration for those in the field of machine learning. In fact, one can argue that it is this inspiration that gave birth to the field of machine learning at all.
The human skill that inspired this paper is our capacity to substitute one sense for another (with enough training), and the authors draw on it to overcome fundamental challenges in reinforcement learning.
Basically, the skill addresses this: consider situations where the inputs may be shuffled or reordered, where the number of inputs may vary, or where noise is added in the form of hidden inputs or extra unrelated, low-value inputs. We will discuss real-world applications for this below, but the current state of the art neither performs these tasks well in zero-shot scenarios nor learns them particularly well.
But people do. We can learn to ride a backward bicycle (see above), see with our tongue (as pictured above), and much more. We do this by retraining our brains to accept, as inputs, signals from sensors trained on another task.
Neural networks, however, do not have this neural plasticity. As the authors note, most reinforcement learning agents require that their inputs arrive in a pre-specified, consistent format. This is highly limiting.
To tackle this shortcoming, the authors explore a model in which their sensory neurons are left to self-organize in order to add meaning and context to the input signals.
One can think of this akin to a human learning to adapt a skill to another task.
For example, I may know how to play blackjack but not poker. While I will need to learn a number of new skills to transition to the new game, I already understand the fundamentals of card values and how they can work together towards achieving a goal.
I'll lose a lot of money while learning poker, but I'm farther ahead than someone who has never played cards at all.
Important Terms And Concepts
Before we dive into the paper, there are a few concepts worth understanding. Once more, I invite those familiar with ML terms and terminology to jump here.
For everyone else, let's first answer a few questions.
What Is A Policy In Machine Learning?
In machine learning, a policy is the rule an agent uses to choose its next action: a mapping from the current state of the environment to an action (or to a probability distribution over actions). The policy is learned to steer the agent towards the highest expected reward, taking into account the available actions, the probability that an action will change the state, and the reward function.
It is typically formalized as part of a Markov Decision Process.
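To make that concrete, here is a minimal sketch (my own, not from the paper) of a stochastic policy in Python: a function that maps the current state to a probability distribution over actions and samples one. The state dimension, number of actions, and weights are illustrative placeholders.

```python
import numpy as np

def softmax(logits):
    """Turn raw scores into a probability distribution over actions."""
    e = np.exp(logits - logits.max())
    return e / e.sum()

def policy(state, weights, rng):
    """Toy stochastic policy: map a state vector to action probabilities
    and sample one action. `weights` stands in for parameters learned to
    maximize expected reward."""
    probs = softmax(state @ weights)
    return rng.choice(len(probs), p=probs)

rng = np.random.default_rng(0)
weights = rng.normal(size=(4, 2))        # 4-dim state, 2 actions (CartPole-like)
state = np.array([0.1, -0.3, 0.02, 0.5])
print(policy(state, weights, rng))       # 0 or 1 (e.g., push left / push right)
```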
What Is A Sensory Neuron?
In a neural network, a sensory neuron (or sensory input neuron) is a node that takes input from "the outside world" and, after processing it through its activation function, passes the resulting value along.
What Is Permutation-Invariance?
Permutation-invariance in machine learning refers to a system in which reordering the inputs does not impact the output. Imagine shaking a box of puzzle pieces. They are reordered but will still produce the same finished product when assembled.
As we will see, permutation-invariance can add a great degree of resilience and flexibility to many machine learning tasks.
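A quick sketch of the idea (my own illustration, not code from the paper): a computation whose weights are tied to input positions changes its output when the inputs are shuffled, while an order-agnostic aggregation such as a sum does not.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=5)            # five sensory inputs
shuffled = rng.permutation(x)     # the same inputs, reordered
w = rng.normal(size=5)            # per-position weights

# Order-sensitive: a weighted sum tied to input positions changes under shuffling.
print(np.dot(w, x), np.dot(w, shuffled))   # two different values

# Permutation-invariant: summing (or averaging) ignores the order entirely.
print(x.sum(), shuffled.sum())             # identical values
```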
What Is Self-Organization?
Self-organization in neural networks describes the ability of a self-supervised system to take local interactions between disorganized parts of itself and build from them a coherent global behavior, such as a policy.
What Is Meta-Learning?
Meta-learning in neural networks, often described as "learning to learn", refers to using a reward and/or error signal to teach a system to solve problems outside the domain it was trained on. Rather than learning directly from the raw data, however, the system learns from the outputs of the learning algorithm itself and trains on making predictions based on those.
The Meat Of The Paper
As the authors describe the problem:
"Modern deep learning systems are generally unable to adapt to a sudden reordering of sensory inputs, unless the model is retrained, or if the user manually corrects the ordering of the inputs for the model."
This makes such systems far less useful in the many applications where the structure, format, or number of inputs may be unknown or variable.
We won't be going through the paper section by section, but two sections are specifically notable and, for ease of reference, are covered individually. They are:
Section 3.1: A Mathematical Definition Of Permutation Invariance
In Section 3.1 we see the formula that makes it all work:
In essence, what we're seeing here is that when a function is applied to the input (a list of tokens), applying a permutation s to that input does not change the result: the output is the same whether or not the permutation is applied.
The formula represents the mathematical definition of permutation invariance.
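The formula itself appears as an image in the original report. Written out (in my own, approximate notation rather than the paper's exact symbols), the property it states is:

$$
f\big([x_{s(1)}, x_{s(2)}, \dots, x_{s(N)}]\big) \;=\; f\big([x_1, x_2, \dots, x_N]\big)
\qquad \text{for every permutation } s \text{ of } \{1, \dots, N\}
$$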
Section 3.2: The Magic
While there is a lot to be taken from much of the paper, Section 3.2 is the one that got a simple scribbled "magic" in the margin.

It is here that we see the introduction of the AttentionNeuron layer, as well as an outline of the methods the system is built on.
Better than any written description is their Figure 2.

So let's look at what's happening in each stage:
1 - The previous action (a_{t-1}) and the current observed state (o_t) are sent to the AttentionNeuron.
2.1 - The various current input token states are sent to the Sensory Neurons, along with the previous action (each state is sent to a different but identical Sensory Neuron, and the same previous action is sent to each).
The information sent is only the local information. So, in the case of an image from a screen in a game of Pong, it would be just a patch of that image.
2.2 - Each Sensory Neuron produces a key and a value vector. For the uninitiated, key-value pairs are simply name-and-value assignments. Here, the key is computed from the neuron's local observation together with the previous action, and the value is computed from the local observation alone.
The query matrix is decoupled from the key-value pairs in this implementation.
2.3 and 2.4 - In these stages, each key's attention coefficient (a value assigned to a feature based on how important it is deemed to be, relative to those around it) is multiplied by the value associated with that key. The results are then added together to create the message passed downstream to the policy.
It is this addition that makes the layer permutation-invariant: regardless of order, adding all the weighted values together results in the same final sum. (A simplified code sketch of this attention step follows the list of stages below.)
3 - While each neuron generates its own message, the attention mechanism aggregates them into one coherent message to be broadcast.
4 - The aggregated message is then passed on to the policy agent, where possible next actions are weighed and the one most likely to maximize success is selected.
5 - And finally the action is taken, and it all starts over again.
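To tie the stages together, below is a minimal numpy sketch of the permutation-invariant attention aggregation described above. It is a deliberate simplification of the paper's AttentionNeuron: the real layer uses learned functions (including an LSTM inside each sensory neuron) to produce keys and values, whereas this sketch substitutes plain linear projections, and all names and dimensions here are illustrative assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention_neuron(obs_patches, prev_action, W_k, W_v, Q):
    """Simplified permutation-invariant aggregation.

    obs_patches : (N, d_obs) local observations, in any order
    prev_action : (d_act,)   previous action, broadcast to every sensory neuron
    W_k         : (d_obs + d_act, d_k) shared key projection
    W_v         : (d_obs, d_k)         shared value projection
    Q           : (M, d_k)   fixed learnable query matrix, independent of the inputs

    The output does not depend on the order of obs_patches because the
    attention step sums over the N inputs.
    """
    n = obs_patches.shape[0]
    act = np.tile(prev_action, (n, 1))                         # same action to each neuron
    keys = np.concatenate([obs_patches, act], axis=1) @ W_k    # (N, d_k)
    values = obs_patches @ W_v                                 # (N, d_k)
    attn = softmax(Q @ keys.T / np.sqrt(keys.shape[1]))        # (M, N) attention coefficients
    return attn @ values                                       # (M, d_k) message for the policy

rng = np.random.default_rng(0)
obs = rng.normal(size=(8, 16))          # 8 unordered observation patches
a_prev = rng.normal(size=4)
W_k = rng.normal(size=(20, 32))
W_v = rng.normal(size=(16, 32))
Q = rng.normal(size=(6, 32))

m1 = attention_neuron(obs, a_prev, W_k, W_v, Q)
m2 = attention_neuron(rng.permutation(obs), a_prev, W_k, W_v, Q)
print(np.allclose(m1, m2))  # True: shuffling the patches does not change the message
```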
The Results
The authors tested the system on four different challenges.
Cart-Pole

The GIF above serves as an excellent illustration of how the system works, and how well it works.
It begins with a standard learning environment, but you will notice that as the observations are shuffled, the agent adapts to the new inputs quickly.
Compared with the current state of the art, we see:

For each experiment, they report the average score and the standard deviation from 1000 test episodes. The agent is trained only in the environment with 5 sensory inputs.
Out of the box, the system described in the paper does not perform as well as the fully-connected neural network (FNN) baseline, because it uses an LSTM and needs to train on the input environment.
Once the inputs are shuffled, however, the FNNs collapse. They are also completely incapable of dealing with a number of inputs that differs from their training.
The new system also handles noise extremely well: it clearly learns to drop the attention coefficients for the noisy inputs to near zero.
Atari Pong

The testing on Pong was very interesting.
Not only did they shuffle the inputs, as illustrated above, but in some instances they hid a percentage of the board during both training and testing.
They wrote about the setup:
"... the role of the AttentionNeuron layer is to take a list of unordered observation patches, and learn to construct a 2D grid representation of the inputs to be used by a downstream policy that expects some form of spatial structure in the codes. Our permutation invariant policy trained with BC can consistently reach a perfect score of 21, even with shuffled screens."
And they go on to add:
"Unlike typical CNN policies, our agent can accept a subset of the screen, since the agent’s input is a variable-length list of patches."
Here is an example of the agent being shown only a subset of the shuffled patches:

And with all of that, they still get great results.

Mean test scores in Atari Pong, and an example of a randomly-shuffled, occluded observation. In the heat map, each value is the average score from 100 test episodes.
We can see that when the system is shown anything above 20% in training, it produces superior performance when the full screen is given to it in testing. In fact, with 40% hidden from it in both training and testing, the system still beats the Atari opponent most of the time.
PyBullet Ant

In this task the system once again does not outperform the standard FNN policy due to the lag in training time, as was expected.

PyBullet Ant Experimental Results
Once the inputs are shuffled, however, the permutation-invariant system shows its resilience.
Car Racing

The approach to this task was similar to that of Pong, where patches were shuffled and hidden from the system.
To add a further challenge, the green grass background was replaced with 4 additional backgrounds, adding a layer of complexity.

And the results were excellent:

An interesting observation the authors made:
"In the CarRacing Test Result (from column 2) shows, our agent generalizes well to most of the test environments with only mild performance drops while the baseline method fails to generalize. We suspect this is because the AttentionNeuron layer has transformed the original RGB space to a useful hidden representation that has eliminated task irrelevant information after observing and reasoning about the sequences of (0t ,at - 1) during training, enabling the downstream hard attention policy to work with an optimized abstract representation tailored for the policy, instead of raw RGB patches."
As we saw in the other experiments, the system starts with a lower score out of the gate, due to the training steps needed to identify the inputs, but performs admirably after that.
In exchange, the system holds its score well when the background changes, and was the only one able to handle shuffled inputs.
Applications For The Real World
By presenting the system with only limited or variable data, the authors build and train it to be resilient.
As the authors write:
"By presenting the agent with shuffled, and even incomplete observations, we encourage it to interpret the meaning of each local sensory input and how they relate to the global context. This could be useful in many real world applications. For example, such policies could avoid errors due to cross-wiring or complex, dynamic input-output mappings when being deployed in real robots. A similar setup to the CartPole experiment with extra noisy channels could enable a system that receives thousands of noisy input channels to identify the small subset of channels with relevant information"
Of course, in all that is good there is a little evil.
The methods and techniques outlined in the paper may help us find UFOs or better adapt to climate change - or they could assist with monitoring communications to stifle freedom of speech.
Credit Where It's Due
It's good to give credit to the resources used. Towards the end of writing this, I stumbled on a great video by Aleksa Gordić, a newly minted Research Engineer over at DeepMind.
If you learn well through video or still have questions about the paper, it comes highly recommended.
Related Reading
Intro to Meta-Learning
Getting started with metalearning for image classification
ByT5: What It Might Mean For SEO
Google has released a paper on ByT5, a token-free NLP model that may well revolutionize how content is understood and presented.
Track and Tune Your Reinforcement Learning Models With Weights & Biases
In this article, we learn how to utilize various tools from Weights & Biases for the GridWorld reinforcement learning task, and also see how to integrate the OpenAI Gym environment with W&B.
Webinar: RL experiments w/ Stable-Baselines3 and W&B
Have you tried Stable-Baselines3 yet? If not, there’s still time to learn.