
Sharpened Cosine Similarity in JAX

New research into a layer that may work better than convolution
In the past few months, there's been some excitement and fast iteration on an idea from a few years ago.
The idea dates back to 2020, when Brandon Rohrer tweeted a proposal to replace convolution with a sharpened cosine similarity.
He followed up with an excellent, detailed discussion of the point, arguing it through a beautiful description of kernels and how they respond to different kinds of signals. He ultimately turned that intuition into a practical suggestion:
$$\mathrm{scd}(s, k) = \left(\dfrac{s \cdot k}{(\lVert s \rVert + q)(\lVert k \rVert + q)}\right)^p$$

calling this sharpened cosine distance. Here $q$ is a small constant that keeps the denominator away from zero, and the exponent $p$ "sharpens" the similarity, suppressing weak matches while emphasizing strong ones.
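To make the formula concrete, here's a minimal JAX sketch of it for a single flattened patch and kernel. The function name `scd` and the default values of `p` and `q` are my own illustrative choices, not from the tweet:

```python
import jax.numpy as jnp

def scd(s, k, p=2.0, q=1e-6):
    # Sharpened cosine distance between a flattened signal patch s and a
    # kernel k, following the formula above: q keeps the denominator away
    # from zero, and p sharpens the response.
    cos = jnp.dot(s, k) / ((jnp.linalg.norm(s) + q) * (jnp.linalg.norm(k) + q))
    # The formula raises the cosine directly to the power p; practical
    # implementations often use jnp.sign(cos) * jnp.abs(cos) ** p instead,
    # so the sign survives non-integer exponents.
    return cos ** p

# A patch pointing in the same direction as the kernel scores ~1,
# regardless of magnitude: that's the appeal over a raw convolution.
print(scd(jnp.array([1.0, 2.0, 3.0]), jnp.array([2.0, 4.0, 6.0])))
```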
Recently, Raphael Pisoni wrote a couple of nice blog posts explaining not only a Keras implementation but a JAX one as well. In the second post, you can find a discussion of applying the layer to MNIST.
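If you want to experiment before reading the posts, here's a rough sketch of how the formula could be slid over an image as a convolution-style layer in JAX. This is my own illustrative code (reusing the `scd` function above), not the code from either post; it assumes a single-channel input, and `scs_layer` is a hypothetical name:

```python
import jax
import jax.numpy as jnp

def scs_layer(image, kernels, p=2.0, q=1e-6):
    # image:   (H, W) array, e.g. a 28x28 MNIST digit.
    # kernels: (n_kernels, kh, kw) array of learnable kernels.
    # Returns: (n_kernels, H - kh + 1, W - kw + 1) feature maps.
    n_kernels, kh, kw = kernels.shape
    # Extract every kh x kw patch of the image as a flattened vector.
    patches = jax.lax.conv_general_dilated_patches(
        image[None, None],       # add batch and channel dims: (1, 1, H, W)
        filter_shape=(kh, kw),
        window_strides=(1, 1),
        padding="VALID",
    )                            # -> (1, kh * kw, out_h, out_w)
    _, _, out_h, out_w = patches.shape
    patches = patches[0].reshape(kh * kw, -1).T          # (n_patches, kh*kw)
    flat_kernels = kernels.reshape(n_kernels, kh * kw)   # (n_kernels, kh*kw)
    # Apply scd to every (patch, kernel) pair via nested vmaps.
    maps = jax.vmap(
        lambda k: jax.vmap(lambda s: scd(s, k, p, q))(patches)
    )(flat_kernels)                                      # (n_kernels, n_patches)
    return maps.reshape(n_kernels, out_h, out_w)

# Example: 8 random 3x3 kernels over a stand-in MNIST-sized image.
image = jnp.zeros((28, 28))
kernels = jax.random.normal(jax.random.PRNGKey(0), (8, 3, 3))
print(scs_layer(image, kernels).shape)  # (8, 26, 26)
```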
It's a fun and exciting story, so I encourage you to check out the work and try applying this yourself!


Tags: ML News