Matrix Factorization from Scratch in JAX: Regularized SVD for Recommendation Systems
Bayesian Hyperparameter Search with Cross Validation for doubly-regularized Matrix Factorization on MovieLens.
Created on January 24|Last edited on February 7
🟧 What Is Matrix Factorization?
Briefly: given a matrix whose rows are users, whose columns are items, and whose entries are ratings, Matrix Factorization approximates that matrix as the product of two smaller matrices; the entries of that product serve as predictions for the ratings that were never observed.
Said differently: write every movie-lover and every movie title in a grid whose entries are that movie-lover's rating for that movie, hide some of the entries, and try to predict them with machine learning.
Matrix Factorization (MF) is one of the most common approaches to recommender-system problems.
This report won't cover Matrix Factorization in depth. Instead, we'll show how straightforward it is to code up MF in JAX, and how Weights & Biases makes regularization and performance validation easier and more interpretable, all through the lens of the famous MovieLens dataset.
JAX combines the convenience of NumPy's API with XLA compilation and automatic differentiation (autodiff). Together, these make JAX well suited to fitting MF models with stochastic gradient descent (SGD). Let's dig in.
⚙️ Components of Matrix Factorization
Given a user-item matrix, each element is the rating for one (user, item) pair. As you'd expect, these ratings are sparse: simply ask yourself whether you've seen most of the movies on Netflix.
Because of this sparsity, our goal is to find two matrices whose product faithfully reproduces the observed (non-zero) elements, and then to interpret the remaining elements of the product as predictions.
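The key to what makes Matrix Factorization work is that the two factor matrices are of smaller dimension than the original. In particular, $U \in \mathbb{R}^{N \times d}$ (one row per user) and $V \in \mathbb{R}^{M \times d}$ (one row per item), with $d$ much smaller than $N$ and $M$, approximate our user-item matrix $A \in \mathbb{R}^{N \times M}$ as $A \approx U V^\top$. This is an example of an embedding model, with embedding dimension $d$ and two latent spaces that represent similarity among users and among items.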
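In optimization terms, we want to find the two matrices that best approximate our observations; it suffices to minimize the mean squared error over the observed entries:
$$\min_{U, V} \; \frac{1}{|\Omega|} \sum_{(i, j) \in \Omega} \left( A_{ij} - \langle U_i, V_j \rangle \right)^2,$$
where $A$ is the ratings matrix, $U$ and $V$ are the user and item factor matrices, $\Omega$ is the set of observed (user, item) pairs, and $U_i$, $V_j$ are the rows of $U$ and $V$ for user $i$ and item $j$.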
This minimization can be carried out in a variety of ways. For this discussion, we start from two random matrices and use simple stochastic gradient descent (SGD). If you'd like a bit more detail on SGD for matrix factorization, I recommend these notes.
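For intuition, here is the standard gradient computation that SGD follows for a single observed entry $(i, j)$. Here $U_i$ and $V_j$ are the rows of $U$ and $V$ for user $i$ and item $j$, $e_{ij} = A_{ij} - \langle U_i, V_j \rangle$ is the residual, and the learning rate $\eta$ is our own symbol (the constant from averaging over observed entries is folded into $\eta$):

```latex
\begin{aligned}
\frac{\partial\, e_{ij}^2}{\partial U_i} &= -2\, e_{ij}\, V_j, &
\frac{\partial\, e_{ij}^2}{\partial V_j} &= -2\, e_{ij}\, U_i, \\
U_i &\leftarrow U_i + 2\eta\, e_{ij}\, V_j, &
V_j &\leftarrow V_j + 2\eta\, e_{ij}\, U_i .
\end{aligned}
```

Both updates use the values of $U_i$ and $V_j$ from before the step. In JAX we never write these derivatives by hand; autodiff produces them from the loss function.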
JAX
While there are a multitude of MF tutorials in PyTorch and TensorFlow, I haven't seen any for JAX. This is especially surprising considering how well suited JAX is to these methods.
JAX really shines in how it handles arrays and tree-structured data (pytrees, such as nested dicts, lists, and tuples of arrays), and in how it handles automatic differentiation. Autodiff computes derivatives of functions written as code, exactly rather than by finite differences. In your calculus class you likely learned to take derivatives of mathematical functions by hand, but the hallmark of modern optimization-based machine learning is backpropagation, which applies the chain rule to pass derivatives back through a composition of functions. JAX does this for you for a huge variety of Python functions, and compiles the result to make it very fast.
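As a minimal illustration of autodiff over a pytree (a toy function, not from this report), `jax.grad` applied to a function of a dict of arrays returns a dict of gradients with the same keys and shapes:

```python
import jax
import jax.numpy as jnp

# A toy scalar function of a dict ("pytree") of arrays.
def f(params):
    return jnp.sum(params["w"] ** 2) + 3.0 * params["b"]

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}

# Gradients come back in the same dict structure as `params`.
grads = jax.grad(f)(params)
print(grads["w"])  # [2. 4.]
print(grads["b"])  # 3.0
```

This is exactly why storing $U$ and $V$ together in one `params` dict is convenient: one call yields gradients for both.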
Before implementing the loss function above, recall that we only consider entries whose ratings are known. Moreover, because ratings data like this is extremely sparse, storing the entire matrix (mostly zeros) is wasteful; it's common instead to store only the locations and values of the nonzero entries. Concretely, we store the observation matrix as:
sp_a = {'indices': (rows, columns), 'values': values}  # three 1-D arrays of equal length
Additionally, we store the two parameters we wish to learn ($U$ and $V$) as:
params = {'users': u, 'items': v}
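A minimal sketch of building both structures from (user, item, rating) triples. The toy data, the `init_params` helper, and the Gaussian initialization with a default `prior_std` are illustrative assumptions; the `(rows, columns)` tuple layout matches how `sp_mse_loss` unpacks `A['indices']`.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy ratings: (user index, item index, rating) triples.
user_ids = jnp.array([0, 0, 1, 2, 3])
item_ids = jnp.array([1, 4, 2, 0, 3])
ratings = jnp.array([5.0, 3.0, 4.0, 1.0, 2.0])

# Sparse observation matrix: only the observed coordinates and their values.
sp_a = {"indices": (user_ids, item_ids), "values": ratings}

def init_params(key, num_users, num_items, embedding_dim, prior_std=0.1):
    """Hypothetical helper: draw U and V from zero-mean Gaussian 'random priors'."""
    user_key, item_key = jax.random.split(key)
    return {
        "users": prior_std * jax.random.normal(user_key, (num_users, embedding_dim)),
        "items": prior_std * jax.random.normal(item_key, (num_items, embedding_dim)),
    }

params = init_params(jax.random.PRNGKey(0), num_users=4, num_items=5, embedding_dim=2)
print(params["users"].shape, params["items"].shape)  # (4, 2) (5, 2)
```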
Now, the first joy of working with JAX appears. Here is the above loss function (notice how straightforward it is):
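import jax.numpy as jnp

def sp_mse_loss(A, params):
    """Mean squared error over the observed entries of the sparse matrix A."""
    U, V = params['users'], params['items']
    rows, columns = A['indices']
    # Predicted ratings at the observed (row, column) positions only.
    estimator = (U @ V.T)[(rows, columns)]
    square_err = (A['values'] - estimator) ** 2
    return jnp.mean(square_err)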
But we're not done yet. How might you take the derivative of such a loss function? Would you have to pull out a calculus textbook? JAX can help:
jax.jit(jax.value_and_grad(sp_mse_loss, argnums=1))
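Two things that might be mysterious are:
- jit: just-in-time compilation with XLA. The function is traced and compiled the first time it is called with inputs of a given shape and type, which dramatically improves performance when it is called many more times.
- argnums: an argument to jax.value_and_grad indicating which positional argument(s) of your function to differentiate with respect to. Here argnums=1 selects params, so we get gradients for U and V but not for the data A.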
It's worth noting that JAX differentiates this ordinary Python code without requiring any special domain-specific language: I simply wrote the loss the familiar mathematical way, using NumPy-style operations.
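To see the pieces working together, here is a minimal sketch of a training loop on hypothetical toy data. Two assumptions to flag: the loss is restated with a gather, `U[rows] * V[columns]` summed over the embedding axis, which gives the same values as indexing `U @ V.T` without forming the full product; and each step uses the gradient over all observed entries (full-batch), whereas sampling a random subset of entries per step would make it stochastic.

```python
import jax
import jax.numpy as jnp

def sp_mse_loss(A, params):
    U, V = params["users"], params["items"]
    rows, columns = A["indices"]
    # Dot product of each observed user row with its item row.
    estimate = jnp.sum(U[rows] * V[columns], axis=1)
    return jnp.mean((A["values"] - estimate) ** 2)

# Differentiate with respect to argument 1 (params), then compile with XLA.
loss_and_grad = jax.jit(jax.value_and_grad(sp_mse_loss, argnums=1))

@jax.jit
def gd_step(params, grads, learning_rate):
    # Move every leaf of the params pytree against its gradient.
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

# Hypothetical toy problem: 4 users, 5 items, 5 observed ratings.
A = {"indices": (jnp.array([0, 0, 1, 2, 3]), jnp.array([1, 4, 2, 0, 3])),
     "values": jnp.array([5.0, 3.0, 4.0, 1.0, 2.0])}
key_u, key_v = jax.random.split(jax.random.PRNGKey(0))
params = {"users": 0.1 * jax.random.normal(key_u, (4, 2)),
          "items": 0.1 * jax.random.normal(key_v, (5, 2))}

initial_loss, _ = loss_and_grad(A, params)
for _ in range(500):
    loss, grads = loss_and_grad(A, params)
    params = gd_step(params, grads, 0.1)
final_loss, _ = loss_and_grad(A, params)
print(float(initial_loss), float(final_loss))
```

Using `jax.tree_util.tree_map` for the update means the same step function works unchanged if more parameters (for example, bias vectors) are added to the dict.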
The dataset: MovieLens 🎥
The canonical open-source dataset for getting started with recommendation systems is the MovieLens dataset from the GroupLens project. I'll leave their description here:
This data set consists of:
* 100,000 ratings (1-5) from 943 users on 1682 movies.
* Each user has rated at least 20 movies.
* Simple demographic info for the users (age, gender, occupation, zip)
The data was collected through the MovieLens web site (movielens.umn.edu) during the seven-month period from September 19th, 1997 through April 22nd, 1998. This data has been cleaned up - users who had less than 20 ratings or did not have complete demographic information were removed from this data set. Detailed descriptions of the data file can be found at the end of this file.
For the purpose of this work, I don't use any of the demographic data. I do some light cleanup of features and store the dataset as a wandb artifact for convenience. If you're interested in seeing all the sausage-making, it's in this Google Colab.
Regularization 📏
Now we need to look at regularization. As with most ML models, regularization can go a long way here. We'll look at two kinds of regularization and the hyperparameters that weight them. Regularization for matrix factorization will look both familiar and a bit different, but in the end it is again easy to compute thanks to JAX's friendly APIs for matrix operations.
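The first is $\ell_2$ regularization, which simply adds a penalty proportional to the squared Frobenius norm of each factor matrix. The Frobenius norm can be viewed as an $\ell_2$-norm acting either element-wise on the entries of a matrix or on its singular values. This is the same $\ell_2$ (ridge) penalty familiar from regularized linear models.
More plainly, with $N$ users and $M$ items (each norm is normalized by its number of rows):
$$\frac{1}{N}\|U\|_F^2 + \frac{1}{M}\|V\|_F^2 = \frac{1}{N}\sum_{i,k} U_{ik}^2 + \frac{1}{M}\sum_{j,k} V_{jk}^2$$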
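A quick numerical check of the claim that the squared Frobenius norm can be read either element-wise or through the singular values (random test matrix, purely illustrative):

```python
import jax
import jax.numpy as jnp

U = jax.random.normal(jax.random.PRNGKey(0), (6, 3))

# Element-wise view: sum of squared entries.
frob_sq_elementwise = jnp.sum(U * U)
# Spectral view: sum of squared singular values.
singular_values = jnp.linalg.svd(U, compute_uv=False)
frob_sq_spectral = jnp.sum(singular_values ** 2)

print(jnp.allclose(frob_sq_elementwise, frob_sq_spectral, rtol=1e-4))  # True
```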
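The second regularization is applied to the approximating matrix $UV^\top$ as a whole, including its unobserved entries, to keep those entries as close to zero as possible (here $U_i$ and $V_j$ are the rows of $U$ and $V$):
$$\frac{1}{NM}\|UV^\top\|_F^2 = \frac{1}{NM}\sum_{i=1}^{N}\sum_{j=1}^{M}\langle U_i, V_j\rangle^2 = \frac{1}{NM}\sum_{k,l}\left(U^\top U\right)_{kl}\left(V^\top V\right)_{kl}$$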
The second formulation of this term is the "element-wise product of Gramians," which is a popular trick when dealing with matrix products like the above. For more information about the intuition of this formulation, and some inspiration for how to utilize it, consult Krichene et al.
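For our discussion, merely observe that the Gramians $U^\top U$ and $V^\top V$ are $d \times d$, where $d$ is the embedding dimension. Computing the penalty this way costs $O((N + M) d^2)$ operations instead of the $O(NMd)$ needed to form $UV^\top$, and never materializes the full $N \times M$ matrix.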
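The element-wise-product-of-Gramians identity is easy to verify numerically. This sketch uses random matrices with arbitrary sizes and compares the direct form against the Gramian form:

```python
import jax
import jax.numpy as jnp

key_u, key_v = jax.random.split(jax.random.PRNGKey(0))
N, M, d = 50, 40, 4
U = jax.random.normal(key_u, (N, d))
V = jax.random.normal(key_v, (M, d))

# Direct form: materializes the full N x M matrix U V^T.
direct = jnp.sum((U @ V.T) ** 2)
# Gramian form: only two d x d matrices are formed.
gramian = jnp.sum((U.T @ U) * (V.T @ V))

print(jnp.allclose(direct, gramian, rtol=1e-3))  # True
```

The identity follows from $\|UV^\top\|_F^2 = \operatorname{tr}(U^\top U\, V^\top V)$ together with the symmetry of both Gramians.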
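This means that we now hope to optimize:
$$\min_{U,V}\; \frac{1}{|\Omega|}\sum_{(i,j)\in\Omega}\left(A_{ij}-\langle U_i, V_j\rangle\right)^2 + \lambda_r\left(\frac{1}{N}\|U\|_F^2+\frac{1}{M}\|V\|_F^2\right) + \frac{\lambda_g}{NM}\left\|UV^\top\right\|_F^2,$$
where $\Omega$ is the set of observed (user, item) pairs.
As usual, the weights $\lambda_r$ (for the $\ell_2$ term) and $\lambda_g$ (for the Gramian term) are hyperparameters, chosen by tuning to find the best amount of regularization.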
JAX Implementations
Now that we have the mathematical descriptions above, let's see how gently they translate into JAX code. I recommend comparing the math and the code side by side; it will help you appreciate the JAX API.
import jax.numpy as jnp

def ell_two_regularization_term(params, dimensions):
    U, V = params['users'], params['items']
    N, M = dimensions['users'], dimensions['items']
    user_sq = jnp.multiply(U, U)
    item_sq = jnp.multiply(V, V)
    # Squared Frobenius norms, normalized by the number of users and items.
    return jnp.sum(user_sq) / N + jnp.sum(item_sq) / M
Here's that Gramian loss term:
def gramian_regularization_term(params, dimensions):
    U, V = params['users'], params['items']
    N, M = dimensions['users'], dimensions['items']
    gr_user = U.T @ U  # d x d user Gramian
    gr_item = V.T @ V  # d x d item Gramian
    gr_square = jnp.multiply(gr_user, gr_item)
    # Equals ||U V^T||_F^2 / (N * M) without forming the N x M matrix.
    return jnp.sum(gr_square) / (N * M)
And now the entire loss function, which returns the total loss together with a dictionary of its components:
def regularized_omse(A, params, dimensions, hyperparams):
    lam_r, lam_g = hyperparams['ell_2'], hyperparams['gram']
    losses = {
        'omse': sp_mse_loss(A, params),
        'l2_loss': ell_two_regularization_term(params, dimensions),
        'gr_loss': gramian_regularization_term(params, dimensions),
    }
    losses['total_loss'] = losses['omse'] + lam_r * losses['l2_loss'] + lam_g * losses['gr_loss']
    return losses['total_loss'], losses
We log the individual loss components in a dictionary because wandb makes it easy to inspect them later and see how much each component contributes to the total. When fitting hyperparameters, the ability to inspect these components separately is priceless.
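Because this loss returns a `(total, components)` pair, it has to be differentiated with `has_aux=True`, which tells JAX to differentiate only the first element and pass the dictionary through untouched. The sketch below is a compact, self-contained restatement of `regularized_omse` (named `total_loss` here, with the gather-based prediction instead of indexing `U @ V.T`; the values are the same), run on hypothetical toy data and hyperparameter values:

```python
import jax
import jax.numpy as jnp

def total_loss(A, params, dimensions, hyperparams):
    U, V = params["users"], params["items"]
    N, M = dimensions["users"], dimensions["items"]
    rows, columns = A["indices"]
    estimate = jnp.sum(U[rows] * V[columns], axis=1)
    losses = {
        "omse": jnp.mean((A["values"] - estimate) ** 2),
        "l2_loss": jnp.sum(U * U) / N + jnp.sum(V * V) / M,
        "gr_loss": jnp.sum((U.T @ U) * (V.T @ V)) / (N * M),
    }
    losses["total_loss"] = (losses["omse"]
                            + hyperparams["ell_2"] * losses["l2_loss"]
                            + hyperparams["gram"] * losses["gr_loss"])
    return losses["total_loss"], losses

# has_aux=True: differentiate the first output, return the dict alongside it.
loss_fn = jax.jit(jax.value_and_grad(total_loss, argnums=1, has_aux=True))

A = {"indices": (jnp.array([0, 1, 2]), jnp.array([1, 0, 2])),
     "values": jnp.array([4.0, 3.0, 5.0])}
key_u, key_v = jax.random.split(jax.random.PRNGKey(0))
params = {"users": 0.1 * jax.random.normal(key_u, (3, 2)),
          "items": 0.1 * jax.random.normal(key_v, (3, 2))}
dimensions = {"users": 3, "items": 3}
hyperparams = {"ell_2": 0.1, "gram": 0.01}

(total, components), grads = loss_fn(A, params, dimensions, hyperparams)
# `components` is a plain dict of scalars, ready to log (e.g. with wandb.log).
print(sorted(components))  # ['gr_loss', 'l2_loss', 'omse', 'total_loss']
```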
Cross Validation
As always, it's important that our performance estimates be robust to the randomness in training and in how the data is split. If we temporarily ignore the sequential (time-ordered) nature of the ratings and naively use standard k-fold cross-validation, we can build more robust estimates of model performance.
I implemented a JAX version of a k-fold class; note that this keeps things in sparse representations:
import pandas as pd
import jax.numpy as jnp
from jax import random

def sparse_array_concatenate(sparse_array_iterable):
    """Concatenate sparse dicts of the form {'indices': (rows, columns), 'values': values}."""
    # Materialize once: the input is iterated twice below, which would silently fail for a generator.
    sparse_arrays = list(sparse_array_iterable)
    return {
        'indices': tuple(map(jnp.concatenate, zip(*(x['indices'] for x in sparse_arrays)))),
        'values': jnp.concatenate([x['values'] for x in sparse_arrays]),
    }

class jax_df_Kfold(object):
    """Simple class that handles k-fold splitting of a ratings DataFrame and stores each fold as a sparse JAX array."""

    def __init__(self, df: pd.DataFrame, user_dim: int, item_dim: int, k: int = 5, prng_key=random.PRNGKey(0)):
        self._df = df
        self._num_folds = k
        # Shuffle the row labels once, then cut them into k nearly equal folds.
        self._split_idxes = jnp.array_split(
            random.permutation(prng_key, df.index.to_numpy(), axis=0, independent=True),
            self._num_folds)
        self._fold_arrays = dict()
        for fold_index in range(self._num_folds):
            # start_pipeline and ratings_to_sparse_array are the author's own helpers,
            # defined elsewhere (not shown in this report).
            self._fold_arrays[fold_index] = (
                self._df[self._df.index.isin(self._split_idxes[fold_index])]
                .pipe(start_pipeline)
                .pipe(ratings_to_sparse_array, user_dim=user_dim, item_dim=item_dim))

    def get_fold(self, fold_index: int):
        assert self._num_folds > fold_index
        test = self._fold_arrays[fold_index]
        train = sparse_array_concatenate([v for k, v in self._fold_arrays.items() if k != fold_index])
        return train, test
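Since the class above depends on helpers that aren't shown, here is a self-contained sketch of the same idea that splits an existing sparse dict directly, plus the mean/min/max "meta-metrics" across folds. Both `sparse_kfold` and `summarize_folds` are hypothetical helpers written for illustration, not the report's code:

```python
import jax
import jax.numpy as jnp

def _take(sp, idx):
    """Select a subset of observations from a sparse dict."""
    rows, columns = sp["indices"]
    return {"indices": (rows[idx], columns[idx]), "values": sp["values"][idx]}

def sparse_kfold(sp, k, key):
    """Yield (train, test) sparse dicts for each of k folds."""
    n = sp["values"].shape[0]
    # Shuffle observation positions once, then cut them into k nearly equal folds.
    folds = jnp.array_split(jax.random.permutation(key, n), k)
    for i in range(k):
        train_idx = jnp.concatenate([f for j, f in enumerate(folds) if j != i])
        yield _take(sp, train_idx), _take(sp, folds[i])

def summarize_folds(fold_scores):
    """Cross-validation meta-metrics: mean, min, and max of a per-fold score."""
    scores = jnp.asarray(fold_scores)
    return {"mean": float(scores.mean()), "min": float(scores.min()), "max": float(scores.max())}

sp = {"indices": (jnp.arange(10) % 4, jnp.arange(10) % 5),
      "values": jnp.arange(1.0, 11.0)}
sizes = [(tr["values"].shape[0], te["values"].shape[0])
         for tr, te in sparse_kfold(sp, k=5, key=jax.random.PRNGKey(0))]
print(sizes)  # [(8, 2), (8, 2), (8, 2), (8, 2), (8, 2)]
```

Each observation lands in exactly one test fold, so the test folds together cover the whole dataset once.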
I also use a learning-rate schedule, as is customary. I implemented the simplest one I know, a step decay that multiplies the base learning rate by decay_pct once every period_length steps. Using something more clever is an obvious area for improvement:
import math

def lr_decay(step_num, base_learning_rate, decay_pct=0.5, period_length=100.0):
    # Step decay: multiply the base rate by decay_pct once every period_length steps.
    return base_learning_rate * math.pow(decay_pct, math.floor((1 + step_num) / period_length))
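A quick look at the resulting schedule (the base rate of 0.1 is an arbitrary example value). Because of the `1 + step_num` offset, the first halving happens at step 99 rather than step 100:

```python
import math

def lr_decay(step_num, base_learning_rate, decay_pct=0.5, period_length=100.0):
    # Step decay: multiply the base rate by decay_pct once every period_length steps.
    return base_learning_rate * math.pow(decay_pct, math.floor((1 + step_num) / period_length))

for step in [0, 98, 99, 199, 299]:
    print(step, lr_decay(step, base_learning_rate=0.1))
```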
Hyperparameter Optimization
It's hard to find good reference material on hyperparameters for recommendation systems, and hyperparameter tuning is often challenging to carry out at scale. Thankfully, the Weights & Biases Sweeps tool is a powerful mechanism for hyperparameter optimization (HPO).
Somewhat like Galuzzi et al. (2020), we perform Bayesian hyperparameter optimization (BHPO) for our models and datasets. Weights & Biases implements BHPO for us, which means we don't need to get our hands dirty with any Gaussian processes 😅. The Galuzzi paper focuses on three parameters:
the learning rate 𝜂, the regularization factor 𝜆, and the number of latent factors K
In our work we tune five hyperparameters: the two regularization weights (one for each regularization term), the latent dimension, the initial learning rate, and the standard deviation of the random priors.
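For reference, a Bayesian search over these five hyperparameters might be described with a W&B sweep configuration like the sketch below. The parameter names, ranges, candidate embedding sizes, metric name, and project name are all illustrative assumptions, not the settings used in this report:

```python
# Hypothetical W&B sweep configuration for a Bayesian search.
sweep_config = {
    "method": "bayes",
    "metric": {"name": "test/omse_mean", "goal": "minimize"},  # assumed metric name
    "parameters": {
        # Regularization weights: lambda_r ('ell_2') and lambda_g ('gram').
        "ell_2": {"distribution": "log_uniform_values", "min": 1e-4, "max": 10.0},
        "gram": {"distribution": "log_uniform_values", "min": 1e-4, "max": 10.0},
        "embedding_dim": {"values": [8, 16, 32, 64, 128]},
        "learning_rate": {"distribution": "log_uniform_values", "min": 1e-3, "max": 1.0},
        "prior_std": {"distribution": "log_uniform_values", "min": 1e-3, "max": 1.0},
    },
}

# With wandb installed, the sweep could then be launched roughly like this,
# where train_fn is a hypothetical function that reads wandb.config and logs metrics:
#   sweep_id = wandb.sweep(sweep_config, project="movielens-mf")
#   wandb.agent(sweep_id, function=train_fn)
```

Log-uniform distributions are a common choice for scale-type hyperparameters such as regularization weights and learning rates, since plausible values span several orders of magnitude.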
Matrix factorization is usually initialized with "random priors", i.e., the factor matrices are drawn at random. This is done in lieu of better priors; when you have good priors, use those.
Let's get some high-level estimates of a good set of hyperparameters via a big set of cross-validated sweeps!
BHPO Training Curves (Log Scale)
There are quite a lot of plots here, but they let us dig into how performance evolves during training. I've moved the cross-validation meta-metrics to the top so you can see the mean, max, and min across the folds.
We see that the mean error across folds converges to values close to 1. Since this is a mean squared error on a 1-5 rating scale, that corresponds to a root-mean-squared error of about one star. We also observe reasonable convergence in the loss components associated with regularization.
Parameter evaluation
In this section we dig even deeper into the fold metrics to make sure our best models don't behave erratically across folds. We can also see that the Bayesian sweeps are working!
As the sweep progresses, model performance continues to improve and its variance across parameter selections decreases. The parallel coordinates chart gives us a sense of which ranges these hyperparameters should be sampled from to yield a good Test omse_mean. (Make sure you select some intervals on the parallel coordinates chart to see how the hyperparameters affect our loss.)
Sweep: mf-HPO-CV
Refined HPO
Now that the Bayesian sweeps have brought us closer to a good range of parameter values, let's switch to random hyperparameter search within that range to better understand the effects of our hyperparameters.
Within these reduced domains, random sweeps give us a finer-grained estimate of each hyperparameter's importance. (Note: for a different approach to hyperparameter optimization in RecSys, check out this paper.)
Sweep: mf-HPO-CV 1
Sweep: mf-HPO-CV 2
Conclusions
So where do we go from here? Well, there are two major remaining considerations:
- How should we handle the prequential (time-ordered) nature of recommendation system problems?
- What do these model outputs actually look like?
In the next part, we'll address both: not only how to build a more thoughtful validation set, but also how to use Weights & Biases to evaluate a model's performance under various circumstances. We'll see what a latent space looks like, and how to really evaluate the quality of different recsys models.
Additional resources
If you're excited to dig in further: this free course from Google covers the basics of recommendation systems, this series on Kaggle is a great intro to the field, and this project from NVIDIA is a great open-source codebase for building scalable recommender systems.
Appendix
For a deeper dive into the relationship between some of these hyperparameters, I also generated this report to look at correlation analyses:
Of course, if you're excited to track your machine learning experiments and use a powerful system of record for all of your machine learning projects, check out Weights & Biases, free for academic and personal use.