Wackpropagation
Direct feedback alignment, but the weight update is modulated by a local rule that preserves signal from the previous layer.
Created on March 8 | Last edited on March 8
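If y is the error signal that DFA delivers to a neuron and x is that neuron's input from the previous layer, my weight update is dw = y(x - yw) instead of the typical dw = yx. This is Oja's rule with the neuron's output replaced by the DFA error. The extra -y²w term acts like a weight decay scaled by the squared error. My theory is that the rule pushes the loss down while also encouraging each layer of a deep neural network to preserve the signal it receives from the previous layer.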
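To make the two rules concrete for a whole layer, here is a small NumPy sketch. It assumes `W` has shape (outputs, inputs), `y` is the per-neuron error signal from DFA, and `x` is the input vector. The function names are hypothetical. It shows that the only difference from the plain Hebbian update is the `-y^2 * W` decay term.

```python
import numpy as np

def hebbian_update(y, x):
    # Plain Hebbian/DFA-style update: dW[i, j] = y[i] * x[j]
    return np.outer(y, x)

def oja_style_update(y, x, W):
    # Oja-style update: dW[i, j] = y[i] * (x[j] - y[i] * W[i, j])
    # The -y[i]**2 * W[i, j] term shrinks row i in proportion to y[i]**2.
    return y[:, None] * (x[None, :] - y[:, None] * W)

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))  # 3 neurons, 4 inputs each
x = rng.normal(size=4)       # input activations from the previous layer
y = rng.normal(size=3)       # error signal delivered to each neuron by DFA

dw_hebb = hebbian_update(y, x)
dw_oja = oja_style_update(y, x, W)

# The two updates differ exactly by the decay term
print(np.allclose(dw_oja, dw_hebb - (y ** 2)[:, None] * W))  # True
```

If a neuron receives no error (y = 0), both updates are zero. The decay only kicks in when there is error to correct.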
My hope is that this signal preservation explains the strange learning behavior I saw when training an RNN. Backpropagation through time (BPTT) quickly outpaces the new rule, only to be overtaken later on. Perhaps the well-documented vanishing gradient problem is what slows the traditional learning rule down. Direct feedback alignment sends the error directly to each layer. If my rule really does preserve signal across layers, it can make more effective use of the recurrent connection.
It's entirely possible that the advantage my rule provides vanishes with scale, though. It does seem like backprop is inexorably marching lower, and it's not entirely clear whether wackprop does, too.
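Here is a rough NumPy sketch of what "sending the error directly" means for a recurrent layer. At every time step, the output error is projected straight to the hidden state through a fixed random matrix, so nothing is unrolled through time. This is hypothetical: I'm not claiming the RNN in these runs was wired exactly this way. The cell (`W_x`, `W_h`, `W_o`, feedback matrix `B`), the sizes, and the sign convention `e = target - out` (chosen so that adding the update moves toward the target) are all assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_hidden, n_out, T = 4, 8, 3, 5

W_x = rng.normal(scale=0.3, size=(n_hidden, n_in))      # input -> hidden
W_h = rng.normal(scale=0.3, size=(n_hidden, n_hidden))  # hidden -> hidden (recurrent)
W_o = rng.normal(scale=0.3, size=(n_out, n_hidden))     # hidden -> output
B = rng.normal(size=(n_out, n_hidden))                  # fixed random feedback matrix

xs = rng.normal(size=(T, n_in))
targets = rng.normal(size=(T, n_out))
lr = 1e-3

h_prev = np.zeros(n_hidden)
for t in range(T):
    h = np.maximum(0.0, W_x @ xs[t] + W_h @ h_prev)  # ReLU recurrent cell
    out = W_o @ h
    e = targets[t] - out                              # error at this step only
    # DFA: project the output error straight to the hidden units, gated by ReLU'
    y_h = (e @ B) * (h > 0)
    # Oja-style updates that use only quantities available at this time step
    W_h += lr * y_h[:, None] * (h_prev[None, :] - y_h[:, None] * W_h)
    W_x += lr * y_h[:, None] * (xs[t][None, :] - y_h[:, None] * W_x)
    W_o += lr * e[:, None] * (h[None, :] - e[:, None] * W_o)
    h_prev = h
```

The recurrent weights `W_h` get an update at every step from the current error. BPTT would instead have to carry the error back through every earlier step.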
Rules and Learning Rates
[Run set: 6 runs]
Here's how I got to these runs. I first benchmarked at a learning rate of 0.0001. Notice the erratic jumps in loss before each improved dip. It's wacky.
[Run set: 2 runs]
I tried out the 0.001 learning rate:
[Run set: 2 runs]
I then tested a smaller learning rate, hoping that my fancy new rule just needed a little more precision to beat backprop much later in training.
[Run set: 2 runs]
It succeeds, at the expense of my sanity. The unoptimized algorithm takes hours to get close to backprop and stays erratic throughout training. I suspect I could speed training up significantly with GPUs and more parallelization. DFA has the advantage that layers don't have to be updated one at a time, because each layer's update depends only on the global error and that layer's own stored activations.
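A small NumPy sketch illustrates why the layers can be updated in parallel. After one forward pass, each layer's update is computed from the global error plus that layer's stored input and output, and from nothing else. Computing the updates in forward or reverse order therefore gives identical results. The network sizes, the feedback matrices `Bs`, the linear output layer, and the helper name `layer_update` are illustrative assumptions. They are not the actual model from these runs.

```python
import numpy as np

rng = np.random.default_rng(1)
sizes = [4, 16, 16, 3]  # input, two hidden layers, output
Ws = [rng.normal(scale=0.3, size=(m, n)) for n, m in zip(sizes[:-1], sizes[1:])]
# One fixed random feedback matrix per hidden layer, shape (n_out, n_hidden)
Bs = [rng.normal(size=(sizes[-1], m)) for m in sizes[1:-1]]

x = rng.normal(size=sizes[0])
target = rng.normal(size=sizes[-1])

# Forward pass, storing each layer's input and output
inputs, outputs = [], []
a = x
for i, W in enumerate(Ws):
    inputs.append(a)
    z = W @ a
    a = z if i == len(Ws) - 1 else np.maximum(0.0, z)  # linear output, ReLU hidden
    outputs.append(a)

e = target - a  # global error, available as soon as the forward pass ends

def layer_update(i):
    """Update for layer i from the global error and layer i's stored activations only."""
    if i == len(Ws) - 1:
        y = e
    else:
        y = (e @ Bs[i]) * (outputs[i] > 0)
    return y[:, None] * (inputs[i][None, :] - y[:, None] * Ws[i])

# No layer waits on another, so the order (or a parallel map) doesn't matter
forward_order = [layer_update(i) for i in range(len(Ws))]
reverse_order = [layer_update(i) for i in reversed(range(len(Ws)))][::-1]
```

Backprop cannot do this, because layer i's gradient needs layer i+1's gradient first.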
Speaking of optimization, neither backprop nor wackprop is using Adam. Maybe backprop + Adam would wipe the floor with wackprop? Or maybe wackprop + Adam would balance the scales once more. For this experiment I left Adam out, both for simplicity and because it wouldn't have worked without enabling gradient calculations earlier in development.
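One way to try wackprop + Adam without autograd is to feed Adam the negative of the local update as if it were a gradient. Adam then steps in the direction the local rule proposes, with per-weight adaptive scaling. The NumPy sketch below writes out the standard Adam equations by hand. The class name `AdamOnUpdates` is hypothetical, and I haven't tested this pairing in these runs.

```python
import numpy as np

class AdamOnUpdates:
    """Hypothetical helper: Adam moment estimates driven by a local update rule.

    Adam normally consumes gradients; here it is fed g = -update, so each
    step moves the weights in the direction the local rule proposes.
    """

    def __init__(self, shape, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = np.zeros(shape)  # first-moment (mean) estimate
        self.v = np.zeros(shape)  # second-moment (uncentered variance) estimate
        self.t = 0

    def step(self, W, update):
        g = -update  # treat the local rule's proposal as a negative gradient
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * g
        self.v = self.beta2 * self.v + (1 - self.beta2) * g ** 2
        m_hat = self.m / (1 - self.beta1 ** self.t)  # bias correction
        v_hat = self.v / (1 - self.beta2 ** self.t)
        return W - self.lr * m_hat / (np.sqrt(v_hat) + self.eps)

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))
opt = AdamOnUpdates(W.shape, lr=1e-3)
x, y = rng.normal(size=4), rng.normal(size=3)
update = y[:, None] * (x[None, :] - y[:, None] * W)  # Oja-style DFA update
W_new = opt.step(W, update)
```

On the first step, Adam's bias correction makes the step roughly `lr * sign(update)` for every weight, so the raw size of the local update stops mattering. In PyTorch, the same idea should work by writing `-update` into each parameter's `.grad` and calling the optimizer's `step()`.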
I also don't use batching, so some of the wackiness is probably reducible.
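If batching were added, one option is to average the per-example Oja-style updates over a mini-batch. For the rule above, that average has a closed form: the mean Hebbian term `Y.T @ X / batch` minus each neuron's mean squared error times its weights. The NumPy sketch below is an assumption about how batching could work (averaging rather than summing). The function name is hypothetical. The sketch checks the vectorized version against an explicit loop.

```python
import numpy as np

def oja_update_batched(Y, X, W):
    """Mean of per-example Oja-style updates over a batch.

    Y: (batch, n_out) error signals, X: (batch, n_in) inputs, W: (n_out, n_in).
    Per example: dW[i, j] = y[i] * (x[j] - y[i] * W[i, j]).
    """
    batch = X.shape[0]
    hebbian = Y.T @ X / batch                    # mean of outer(y, x)
    decay = (Y ** 2).mean(axis=0)[:, None] * W   # mean of y[i]**2, times row i of W
    return hebbian - decay

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))
X = rng.normal(size=(8, 4))
Y = rng.normal(size=(8, 3))

# Reference: average the single-example updates one at a time
looped = np.mean(
    [y[:, None] * (x[None, :] - y[:, None] * W) for x, y in zip(X, Y)], axis=0
)
print(np.allclose(oja_update_batched(Y, X, W), looped))  # True
```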
dataset
The dataset is roneneldan/tinystories on Hugging Face.
I chose it for its much smaller vocabulary, hoping it would be easier to train on than a general raw-text dataset. The Microsoft Research paper that introduced it showed that even small transformers trained on it can write fluent, coherent stories.
code
Here's the learning rule:
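```python
# Runs inside each layer after the forward pass. Assumes `global_error`,
# `input` (this layer's stored input activations) and `learning_rate` are in scope.
if self.is_last_layer:
    projected_error = global_error
else:
    # DFA: project the output error through a fixed random matrix
    projected_error = global_error @ self.feedback_weights

# Apply the ReLU derivative: 1 for activated neurons, 0 otherwise.
# Out-of-place multiply so global_error itself is never modified.
relu_derivative = (self.out_traces.data.squeeze() > 0).float()
projected_error = projected_error * relu_derivative

# Only apply past candidate updates with the current error signal.
# clone() so the in-place decay and accumulation below don't leak into this
# snapshot (`.data` alone would alias the same storage).
imprint_update = self.candidate_weights.data.clone()

# Decay candidate_weights (0.9 here; 0.5 would halve them each step)
self.candidate_weights.data *= 0.9

# If dropout was applied during the forward pass, apply the same mask here:
# if self.dropout_mask is not None:
#     projected_error = projected_error * self.dropout_mask

# Oja's rule, reused: dW[i, j] = y[i] * (x[j] - y[i] * W[i, j])
candidate_update = projected_error * (input.T - projected_error * self.weight.data.T)
candidate_update = candidate_update.T
self.candidate_weights.data += candidate_update

# Apply the (delayed) imprint to the actual weights
imprint_update = imprint_update.squeeze()
update = learning_rate * imprint_update
self.weight.data += update
```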
The code runs after every forward pass. Each layer stores its input activations as `input`. I also added a `candidate_weights` buffer, just to complicate things. It delays each weight update and spreads its effect across later iterations.
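To see what the `candidate_weights` buffer does over time, here is a scalar sketch in plain Python. It uses one weight, a single candidate update of size 1 at step 0, a decay of 0.9, and a learning rate of 1. It assumes the snapshot `imprint` holds only past candidates, taken before the decay and the new update are added, as the code's own comment describes. The variable names are illustrative.

```python
decay, lr = 0.9, 1.0
candidate = 0.0
applied = []  # amount actually added to the weight at each step

for step in range(60):
    imprint = candidate                      # snapshot of past candidates only
    candidate *= decay                       # decay the buffer
    candidate += 1.0 if step == 0 else 0.0   # one unit candidate update, at step 0
    applied.append(lr * imprint)             # weight += learning_rate * imprint
```

The update lands one step late and then trails off geometrically: 0, 1, 0.9, 0.81, and so on. Each candidate update therefore ends up applied with a total weight of about 1 / (1 - 0.9) = 10 times its size. That also means the effective learning rate is roughly ten times the nominal `learning_rate`.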
`global_error` is the result of autograd on `torch.nn.MSELoss()`, where I've only allowed gradients on the final layer. With the default mean reduction, the gradient with respect to the network's output is `2 * (output - expected_output) / N`, with N the number of output elements. That is simply the output error scaled by a constant. It's a vector the same size as the network's output. For every layer except the last, we therefore project it through `self.feedback_weights`, which is just a fixed random matrix with the right dimensions.
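A quick NumPy check of that gradient formula follows. `mse` reproduces the formula of `torch.nn.MSELoss()` with its default `reduction='mean'`. The analytic gradient is compared against central finite differences. The output size (5) and hidden width (7) are arbitrary, and `feedback_weights` here is just a random stand-in for the layer attribute.

```python
import numpy as np

rng = np.random.default_rng(0)
output = rng.normal(size=5)
expected_output = rng.normal(size=5)

def mse(out):
    # Same formula as torch.nn.MSELoss() with the default reduction='mean'
    return np.mean((out - expected_output) ** 2)

analytic = 2.0 * (output - expected_output) / output.size

# Central finite differences as an independent check
h = 1e-6
numeric = np.array([
    (mse(output + h * np.eye(5)[k]) - mse(output - h * np.eye(5)[k])) / (2 * h)
    for k in range(5)
])
print(np.allclose(analytic, numeric, atol=1e-6))  # True

# Project the global error into a hidden layer of width 7 with a fixed random matrix
feedback_weights = rng.normal(size=(5, 7))
projected_error = analytic @ feedback_weights  # shape (7,)
```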
inspiration
The whole thing is a combination of DFA, Oja's rule, and my own idea of stretching weight updates through time.