* your likeability not guaranteed, but we will try
** JAX is experimental and not an officially supported Google product. Version lock your code with containerization where possible.
Brief note on terminology
On OkCupid, when you tell us who you like and who you, well, don’t like, we record that on the backend as ‘votes’. For the rest of this, we’ll be talking about ‘voters’ and ‘votees’ to refer to you and the people you’re (not) liking, respectively. Each like or pass is a ‘vote’.
We record a lot of votes!
(If you already know about collaborative filtering and SVD approximations, feel free to skip to the Implementation section)
We want to be able to tell, given your history of passing and liking, who you might like that you haven’t seen yet. If we have this, we can leverage it to show you people you’re more likely to like - meaning you like them, they get a like, maybe they like you back, and hopefully it all goes great from there!
The first step to understanding how we’re going to handle this is to put all these votes into an ‘interaction matrix’. Essentially, the rows of this matrix are the ‘voters’ and the columns are the ‘votees’. When a voter likes a votee, we put a 1 at the coordinate corresponding to that (voter, votee) pair and we put a 0 if the voter passes on a votee.
So let's say we have voters A, B, and C, and we've recorded some votes from A on B and C, a vote from B on A, and a vote from C on B.
Our interaction matrix now looks like this, with voters as the rows and votees as the columns:
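The matrix image didn't survive in this text version; a small numpy sketch consistent with the dot products used later (A passes on B, A likes C, B likes A) might look like this, with `nan` marking pairs where no vote has been recorded. The outcome of C's vote on B isn't stated in the post, so the 1 below is just a placeholder:

```python
import numpy as np

# rows: voters A, B, C; columns: votees A, B, C
nan = np.nan
interactions = np.array([
    [nan, 0.0, 1.0],   # A passed on B, liked C
    [1.0, nan, nan],   # B liked A
    [nan, 1.0, nan],   # C voted on B (1.0 is a placeholder value)
])
```

The `nan` entries, like C's vote on A, are exactly the question marks we want to fill in.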
Simple enough, and now our task is to try and figure out how to ‘fill in’ the question marks here, e.g. how would C vote on A?
So, how do we do this? Well, we could try looking for voters similar to C, based on who they're liking and passing on. Once we’ve found them, we can take those similar voters' votes and extrapolate them to the missing votees for our voter.
This is known as memory-based collaborative filtering, and while good in theory, it is computationally not very scalable. When we have millions of voters and votees making hundreds of millions of votes each week, it's very difficult to do this when a user needs a list of recommendations pronto!
Learning Representations instead
So, let's take an alternative approach, and approximate these similarities.
To do this, we'll represent each voter and votee with a vector. A vector, for our purposes, is just a fixed-length array of numbers. What we'll want from these vectors is that when we take the dot product of them (i.e., multiply each number with the corresponding number in the other vector, take the sum) it corresponds to the outcome we observed (or, would observe) when the voter interacts with the votee.
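As a quick illustration of that definition, here's the dot product of two small vectors in numpy:

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([4.0, 5.0, 6.0])

# multiply elementwise, then sum: 1*4 + 2*5 + 3*6 = 32
result = np.dot(a, b)
```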
So, we would want e.g. `dot(A_vector, B_vector) = 0` and `dot(A_vector, C_vector) = 1`. An important note here is that, as you'll notice, we want `dot(A_vector, B_vector) != dot(B_vector, A_vector)`. This doesn't actually work (the dot product is commutative), so we're going to record two vectors for each user - a votee vector (`A_votee_vector`) and a voter vector (`A_voter_vector`). With this, we'll have `dot(A_voter_vector, B_votee_vector) = 0` and `dot(B_voter_vector, A_votee_vector) = 1`.
Sorry B, A's just not that into you.
Ok, so now we've established that we want vectors. Vectors have uh, numbers in them. But, how do we figure out which numbers, exactly? There are so many numbers we could fill them with!
Well, we can learn the numbers from the votes we’ve observed! We’ll do this using Gradient Descent, which can be described as follows:
- Randomly initialize everyone’s voter and votee vector
- For some number of ‘epochs’ (passes over the observed votes), we'll go through each vote we’ve observed, and
- Compute the dot product between the voter's voter vector and the votee's votee vector
- Compute the difference (error) between the dot product and the actual outcome
- Take the gradient of the error
- This tells us how much each number in the vector contributed to the error
- We can use this as a 'direction' to move the vector
- Move opposite that direction by subtracting the gradient from the vector
- This will reduce (descend) the error
- Notably, we don’t have to do this vote-by-vote, and instead can do a bunch at a time by batching these computations.
- Return the vectors we’ve learned for each voter and votee.
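The steps above can be sketched in plain numpy, with the gradient of the squared error derived by hand (the votes here are toy data, not the real format):

```python
import numpy as np

rng = np.random.default_rng(0)
VECTOR_SIZE = 8
LEARNING_RATE = 0.1

# toy (voter_index, votee_index, outcome) triples
votes = [(0, 1, 0.0), (0, 2, 1.0), (1, 0, 1.0)]

# randomly initialize everyone's voter and votee vectors
voter_vectors = rng.normal(scale=0.3, size=(3, VECTOR_SIZE))
votee_vectors = rng.normal(scale=0.3, size=(3, VECTOR_SIZE))

for epoch in range(500):
    for voter, votee, outcome in votes:
        # dot product between voter vector and votee vector
        pred = voter_vectors[voter] @ votee_vectors[votee]
        err = pred - outcome
        # gradient of 0.5 * err**2 with respect to each vector
        g_voter = err * votee_vectors[votee]
        g_votee = err * voter_vectors[voter]
        # move opposite the gradient direction to descend the error
        voter_vectors[voter] -= LEARNING_RATE * g_voter
        votee_vectors[votee] -= LEARNING_RATE * g_votee
```

After training, the dot products recover the observed outcomes, which is exactly what we'll automate with JAX below.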
So, once this is done, what can we do with these vectors? Well now, when we want to know if a voter will like a votee, we can simply take the dot product of their respective vectors and find out what the model thinks is going to happen!
This is a very fast operation that we can offload to Vespa to provide the recommendations, and it doesn't get slower the more users we have!
But why did we do this in particular? I’ve skipped a lot of the derivation, but essentially this whole procedure comes from approximating a Singular Value Decomposition (SVD) to reconstruct the voter-votee interaction matrix. If you’re interested in reading more into that part, I'd recommend the seminal Netflix Prize post that a lot of this work is derived from.
For simplicity, we're omitting things like biases, regularization, shuffling, and so on that we do use in practice.
While this is a lot faster when we want to figure out if a user will like another user, the training process still requires us to loop through and do a gradient update for each vote. While these gradient updates are relatively cheap, they add up with hundreds of millions of votes to train on!
So, we need to make this happen quickly. In order to keep everyone’s recommendations fresh, we want to compute (and recompute) these vectors every day. Additionally, we don't want to be stuck with just this simple operation - we want to be able to try out and research different ways to update the vectors using different error functions, training methods, optimizers, and so on.
We’ve encountered quite a few libraries out there that implement this basic algorithm. Below, we’ll introduce our old favorite, and then talk about why we built our own approach using JAX.
The Baseline - Surprise
When we first started this project, we used the popular Surprise library by Nicolas Hug. While Surprise provides many recommendation algorithms, we're specifically interested in its implementation of SVD. Despite being a Python library, it implements all of the critical parts using Cython-compiled C extensions. This makes it a lot faster than pure Python.
However, if we come up with some complicated new error function, we would have to manually derive and implement the corresponding gradient function. Aside from sheer laziness, we want to avoid this as it's a process that can introduce difficult-to-diagnose bugs and hard-to-maintain code.
Plenty of frameworks out there are built to automatically compute gradients (e.g. TensorFlow, PyTorch), but can we attain this flexibility while maintaining the efficiency of Surprise?
What is JAX? Isn't that in Florida?
JAX provides a simple Python interface for expressing (with a numpy-like API) the operations we want, while also allowing us to convert these functions into their gradient-returning versions!
It's much more than that as well, as it provides a number of ways to optimize our code using Just-In-Time-compilation and vectorization, among other neat features. If you're interested, I recommend reading up on their documentation.
```python
import jax
import jax.numpy as jnp
```
So let's define the function that does our dot product. Simple enough.
```python
def dot_product(voter_vector, votee_vector):
    return jnp.dot(voter_vector, votee_vector)
```
And now let's define our error function. For now, we'll be using the squared Euclidean distance between the real outcome and the predicted outcome. This is commonly known as L2 loss.
```python
def l2_loss(voter_vector, votee_vector, real_outcome):
    return 0.5 * jnp.power(
        real_outcome - dot_product(voter_vector, votee_vector), 2
    )
```
Now, we'll average the loss over the batch of votes that we're training on, using the `jax.vmap` function transformation, to get our final loss function for the batch. `jax.vmap` is used here to let us efficiently run the function on every row of the inputs.
```python
def loss(voter_vectors, votee_vectors, real_outcomes):
    return jnp.mean(jax.vmap(l2_loss)(
        voter_vectors, votee_vectors, real_outcomes
    ))
```
This allows us to give the `loss` function two matrices, each of shape `(batch_size, vector_size)` (for the voter and votee vectors), as well as a vector of shape `(batch_size,)` (for the real outcomes). The function will then return a scalar value that represents, on average, how 'wrong' the vectors currently are when predicting the outcomes.
So, now how do we get the gradient? We'll use the `jax.grad` function like so:
```python
grad_fn = jax.jit(jax.grad(loss, argnums=(0, 1)))
```
This will produce a function that, when given the same arguments as the `loss` function, returns two matrices of gradients: one for the voter vectors (argnum 0) and one for the votee vectors (argnum 1). Additionally, we're running all of these functions through `jax.jit` so that we can take advantage of faster executions once they're compiled!
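Putting the pieces so far together, here's a quick shape check (the functions are redefined so the snippet stands alone, and the vector values are toy numbers):

```python
import jax
import jax.numpy as jnp

def dot_product(voter_vector, votee_vector):
    return jnp.dot(voter_vector, votee_vector)

def l2_loss(voter_vector, votee_vector, real_outcome):
    return 0.5 * jnp.power(
        real_outcome - dot_product(voter_vector, votee_vector), 2
    )

def loss(voter_vectors, votee_vectors, real_outcomes):
    return jnp.mean(jax.vmap(l2_loss)(
        voter_vectors, votee_vectors, real_outcomes
    ))

grad_fn = jax.jit(jax.grad(loss, argnums=(0, 1)))

batch_size, vector_size = 4, 8
voter_vectors = jnp.full((batch_size, vector_size), 0.1)
votee_vectors = jnp.full((batch_size, vector_size), 0.1)
real_outcomes = jnp.array([0.0, 1.0, 1.0, 0.0])

# gradients come back with the same shapes as the input matrices
voter_grad, votee_grad = grad_fn(voter_vectors, votee_vectors, real_outcomes)
```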
So, now that we have all of these, let's put together our training loop.
For our data, we'll have votes in the form of tuples: `(voter_index, votee_index, real_outcome)`. These indices are mapped (beforehand) to rows of the associated voter and votee vector matrices.
So our training loop looks something like this:
```python
LEARNING_RATE = 1e-3

def train_epoch(voter_matrix, votee_matrix,
                voter_indices, votee_indices,
                real_outcomes, batch_size):
    # create_batch is a helper (defined elsewhere) that yields
    # consecutive chunks of `size` elements
    for voter_ib, votee_ib, outcomes_b in zip(
            create_batch(voter_indices, size=batch_size),
            create_batch(votee_indices, size=batch_size),
            create_batch(real_outcomes, size=batch_size)):
        # index portion - these are of shape (batch_size, vector_size)
        voter_vectors_batch = voter_matrix[voter_ib]
        votee_vectors_batch = votee_matrix[votee_ib]

        # gradient computation portion
        voter_grad, votee_grad = grad_fn(
            voter_vectors_batch, votee_vectors_batch, outcomes_b
        )

        # update portion - now let's take that gradient step!
        voter_matrix[voter_ib] -= LEARNING_RATE * voter_grad
        votee_matrix[votee_ib] -= LEARNING_RATE * votee_grad
    return voter_matrix, votee_matrix
```
Let's call this approach `jax_naive`. So, how does this do?
Yikes! Even for a small number of votes we’re spending way more time than Surprise! Can we do better?
Gotta go fast(er)
Ok, well, that clearly did not go well. If we think back to it, can we try and use JAX’s numpy-like API to JIT-compile a lot of the indexing and update parts? Would that be faster?
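The code block for this version appears to be missing from this copy of the post. As a hedged sketch of the idea - JIT-compiling the indexing, gradient, and update steps for one batch - it might look something like the following. Note that this uses the modern `.at[...].add(...)` indexed-update API, while the post's pinned JAX version spelled the same operation `jax.ops.index_add`; the loss here is a compact stand-in for the batched loss defined earlier:

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-3

def loss(voter_vectors, votee_vectors, real_outcomes):
    # batched dot products, then the averaged L2 loss from before
    preds = jnp.sum(voter_vectors * votee_vectors, axis=1)
    return jnp.mean(0.5 * (real_outcomes - preds) ** 2)

grad_fn = jax.grad(loss, argnums=(0, 1))

@jax.jit
def train_batch(voter_matrix, votee_matrix, voter_ib, votee_ib, outcomes_b):
    voter_vectors_batch = voter_matrix[voter_ib]
    votee_vectors_batch = votee_matrix[votee_ib]
    voter_grad, votee_grad = grad_fn(
        voter_vectors_batch, votee_vectors_batch, outcomes_b
    )
    # no in-place updates inside a jitted function: scatter the gradient
    # step into *new* matrices and rebind the names
    # (jax.ops.index_add in the post's pinned JAX version)
    voter_matrix = voter_matrix.at[voter_ib].add(-LEARNING_RATE * voter_grad)
    votee_matrix = votee_matrix.at[votee_ib].add(-LEARNING_RATE * votee_grad)
    return voter_matrix, votee_matrix
```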
Ok, cool. So, a few notes - what's up with `jax.ops.index_add`? And why are we assigning its result back to the matrix rather than updating it in place? One of the key limitations of JIT-compiling a JAX function is that we cannot have any in-place operations. This means that any function that modifies its input will not work. For more on why, see their documentation.
How does this version, which we'll call the `amateur_jax_model`, measure up?
Welp, we’re still going way too slow. Surprise is beating us handily!
What else can we stuff into a JIT-compiled function?
I should note that we're using a `batch_size` of 1 for these comparisons. This isn't because it's necessarily a good batch size - it's what Surprise allows for, so we want to make sure our comparison is fair.
Gotta go fast(est)
Right now we’re copying (remember, no in-place changes) the entire voter matrix and votee matrix into the JIT-compiled function every batch, even though we’re only updating (at most) `batch_size` rows! When we have millions of voters and votees, this gets expensive.
But what if we could stuff the entire `train_epoch` into one JIT-compiled function?
```python
from functools import partial

from jax import lax
from jax.experimental import loops

@partial(jax.jit, static_argnums=(5, 6))
def train_epoch(voter_matrix, votee_matrix,
                voter_indices, votee_indices, real_outcomes,
                batched_dataset_size, batch_size):
    with loops.Scope() as s:
        s.voter_matrix = voter_matrix
        s.votee_matrix = votee_matrix
        for batch_index in s.range(0, batched_dataset_size):
            batch_start = batch_index * batch_size
            # batching part
            voter_ib = lax.dynamic_slice(voter_indices,
                                         (batch_start,), (batch_size,))
            votee_ib = lax.dynamic_slice(votee_indices,
                                         (batch_start,), (batch_size,))
            outcomes_b = lax.dynamic_slice(real_outcomes,
                                           (batch_start,), (batch_size,))
            # update part (train_batch is the JIT-friendly batch
            # update from the previous section)
            s.voter_matrix, s.votee_matrix = train_batch(
                s.voter_matrix, s.votee_matrix,
                voter_ib, votee_ib, outcomes_b
            )
        return s.voter_matrix, s.votee_matrix
```
Wow, so, this looks quite different but at the same time quite similar! This is thanks to the experimental JAX `loops` API, which allows us to write a stateful-looking loop that gets turned into a pure function by JAX. All of our mutable state in that loop has to be placed in the scope object `s`. But once that's done, we can compile the whole epoch with `jax.jit`.
A few things get complicated here, however, and the root of the problem is that `batch_start` keeps changing. This means that we can't use a typical slice operator like `voter_indices[batch_start:batch_end]` - we have to use `lax.dynamic_slice` instead. Additionally, we have to specify that `batch_size` and `batched_dataset_size` are static values, which means that JAX will recompile `train_epoch` if they change. So, when calling `train_epoch` we now have to compute beforehand how many batches will be in our training run. Not a big deal for simple gradient descent!
Ok, so, this is very complicated and took me somewhere around a month of banging my head against the keyboard to get right - but how does this fully JIT-compiled approach measure up?
Now that’s what I call speed! In fact, we’re seeing that as we hit tens of millions of votes, we’re not slowing down nearly as much as Surprise does! We couldn't have even dreamed of getting to this level of comparison with the earlier versions of the model!
Taking advantage of JIT compilation?
While it's an impressive speedup even with just one epoch, when we start to train for more than one epoch we should see an even more dramatic difference: `train_epoch` is compiled on its first invocation, and subsequent invocations should use the already-compiled version.
So, compared to Surprise, we should see JAX take significantly less extra time per epoch, right? Let's look at how many times slower each one is after a few epochs compared to just one.
Well, not quite. As it turns out, relative to one epoch of training, the JAX version here slows down at a similar rate to Surprise, though it does a bit better. In absolute terms, however, we're adding orders of magnitude less time per epoch!
JAX land vs. numpy land
One thing that hasn't been made super clear throughout this is that `train_epoch` isn't returning `numpy.ndarray` objects, but instead `jax.interpreters.xla.DeviceArray` objects. While these have largely similar APIs, when profiling JAX models and their JIT-compiled methods, one has to be careful.
Fundamentally, while it may look like the computation is done when `train_epoch` (the inner portion) is called, the results you get back are actually futures. So, if we want to work with them in numpy land, or print them, or anything useful like that, we have to 'extract' them. In fact, if we measure the time it takes to run `train_epoch` a few times on the same matrices, it looks like the amount of time taken doesn't increase with the number of epochs!
However, once we extract these, we can see the amount of time taken jump up. So, be careful when profiling your JAX code! You can read more about JAX's asynchronous dispatch in their documentation.
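Here's a minimal illustration of this async-dispatch pitfall (the timings are machine-dependent, so treat the numbers as illustrative):

```python
import time

import jax.numpy as jnp

x = jnp.ones((2000, 2000))

start = time.perf_counter()
y = jnp.dot(x, x)            # dispatches the work; y is effectively a future
dispatched = time.perf_counter() - start

y.block_until_ready()        # force the actual computation to finish
completed = time.perf_counter() - start
# `dispatched` can be tiny while `completed` includes the real work,
# so always block (or otherwise extract values) before timing JAX code
```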
Why not use TensorFlow or PyTorch?
While TensorFlow and PyTorch are great for a lot of gradient descent driven problems, there are some key parts of this one that they do not handle well.
- Since we need to train a vector for every voter and every votee, we can’t downsample our training set, and thus need a fast way to iterate through it
- Neither TensorFlow nor PyTorch offer optimizations for this
- `tf.data` is focused on speeding up the flow of large feature sets into the training loop, but that’s not the issue we have
- We’re also interested here in optimizing a lot of disparate parameters
- Neural networks tend to optimize the same parameters every time
- We're dealing here with a small number of parameters updated at each step, but from a large matrix of parameters
After considerable research and failed attempts on our end, we determined that there wasn't a viable route for us to achieve this kind of performance in TensorFlow or PyTorch. However it's been a year since we've looked at that, and if anyone knows otherwise feel free to let me know how wrong I am!
Don’t try this at home
Back in December 2019, when COVID-19 was ‘anomalous pneumonia’ and the JAX version was 0.1.55, I’d built out this method and gotten these great results.
This past week, when I was writing this post and realizing I had a few results I wanted to re-run and explore further, I had an awful finding. My timing results could no longer be reproduced! As it turns out, something changed in the JAX library (I think) that slowed down this method significantly.
We’ve filed a ticket for this, but in the meantime if you want to reproduce these results you can do so with the following pinned versions:

```
jax==0.1.55
jaxlib==0.1.37
numpy==1.17.5
```
Conclusion & Further work
While this was definitely an improvement on our baseline, it is still a rather basic form of collaborative filtering for recommendations. It’s worked out well for us so far, though, allowing us to train these vectors for every voter and votee seen in the past week across the entire OkCupid site, in about three hours on average!
We’ve leveraged this to provide significant improvements in recommendations for all of our users this year.
But, where do we go from here?
Alternative Training Loops and loss functions
There’s all sorts of interesting ways to order the training set other than just one after the other or random. Commonly, methods like BPR and WARP use the model’s current performance to select ‘hard’ examples that it’s not getting right. These methods may be helpful, and are a direction of future work for us.
Neural Collaborative Filtering
NCF is a promising approach as well, allowing us to take advantage of the oh so lauded deep learning to do a non-linear version of all of this. We were awaiting mainlined pytree support in `loops.Scope`, however, which has only very recently been added.
Without that support, managing all of the disparate state that NCF requires will be nightmarish.
Better optimization methods
The method we have for turning the gradient into new vectors is fairly crude, and there are plenty of alternatives that show promise - but again, ultimately rely on or are made significantly easier by the aforementioned pytrees support.