* your likeability not guaranteed, but we will try
** JAX is experimental and not an officially supported Google product. Version lock your code with containerization where possible.

Brief note on terminology

On OkCupid, when you tell us who you like and who you, well, don’t like, we record that on the backend as ‘votes’. For the rest of this, we’ll be talking about ‘voters’ and ‘votees’ to refer to you and the people you’re (not) liking, respectively. Each like or pass is a ‘vote’.

We record a lot of votes!

Motivation

(If you already know about collaborative filtering and SVD approximations, feel free to skip to the Implementation section)

We want to be able to tell, given your history of passing and liking, who you might like that you haven't seen yet. If we have this, we can leverage it to show you people you're more likely to like, meaning that you like them, they get a like, maybe they like you back, and hopefully it all goes great from there!

The first step to understanding how we’re going to handle this is to put all these votes into an ‘interaction matrix’. Essentially, the rows of this matrix are the ‘voters’ and the columns are the ‘votees’. When a voter likes a votee, we put a 1 at the coordinate corresponding to that (voter, votee) pair and we put a 0 if the voter passes on a votee.

So let's say we have voters A, B, and C, and we've recorded some votes from A on B and C, a vote from B on A, and a vote from C on B.

Our interaction matrix now looks like this, with voters as the rows and votees as the columns:

     A    B    C
A   n/a   0    1
B    1   n/a   ?
C    ?    0   n/a
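
To make this concrete, here's one hypothetical way to hold that same matrix in code, using NaN for the votes we haven't seen yet (just an illustration, not how we actually store votes):

import numpy as np

# Rows are voters, columns are votees; NaN marks votes we haven't observed
# (including the diagonal, since you can't vote on yourself).
interactions = np.array([
    [np.nan, 0.0,    1.0],     # A passed on B and liked C
    [1.0,    np.nan, np.nan],  # B liked A, hasn't voted on C
    [np.nan, 0.0,    np.nan],  # C passed on B, hasn't voted on A
])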

Simple enough, and now our task is to try and figure out how to ‘fill in’ the question marks here, e.g. how would C vote on A?

So, how do we do this? Well, we could try looking for a similar voter to C based on who they're liking and passing on. Once we've done this, we can take those similar voters' votes and extrapolate them to the votees our voter hasn't seen yet.

This is known as memory-based collaborative filtering, and while good in theory, it is computationally not very scalable. When we have millions of voters and votees making hundreds of millions of votes each week, it's very difficult to do this when a user needs a list of recommendations pronto!
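
To give a feel for why, here's a toy sketch of that memory-based idea (definitely not what we run in production): for a given voter, scan every other voter, score how often their observed votes agree, and borrow the best match's votes for the votees we haven't seen. The interactions matrix here is the NaN-filled one from above.

import numpy as np

def most_similar_voter(interactions, voter):
    # Toy similarity: count agreeing votes on votees both voters have seen.
    best_voter, best_score = None, -1
    for other in range(interactions.shape[0]):
        if other == voter:
            continue
        both_seen = ~np.isnan(interactions[voter]) & ~np.isnan(interactions[other])
        score = np.sum(interactions[voter][both_seen] == interactions[other][both_seen])
        if score > best_score:
            best_voter, best_score = other, score
    return best_voter

# To 'fill in' a missing (voter, votee) vote, we'd look up how the most similar
# voter voted on that votee - and we'd have to redo this scan constantly as new
# votes come in, for every voter who needs recommendations.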

Learning Representations instead

So, let's take an alternative approach, and approximate these similarities.

To do this, we'll represent each voter and votee with a vector. A vector, for our purposes, is just a fixed-length array of numbers. What we'll want from these vectors is that when we take the dot product of two of them (i.e., multiply each number by the corresponding number in the other vector and add up the results), it corresponds to the outcome we observed (or would observe) when the voter interacts with the votee.

So, we would want e.g. dot(A_vector, B_vector) = 0 and dot(A_vector, C_vector) = 1. An important note here is that, as you'll notice, we want dot(A_vector, B_vector) != dot(B_vector, A_vector). This doesn't actually work (the dot product is commutative), so we're going to record two vectors for each user - a votee vector (A_votee_vector) and a voter vector (A_voter_vector).

With this, we'll have dot(A_voter_vector, B_votee_vector) = 0 and dot(B_voter_vector, A_votee_vector) = 1.

Sorry B, A's just not that into you.
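
If you want to see that asymmetry with some made-up numbers (these aren't learned vectors, just an illustration):

import jax.numpy as jnp

# Made-up two-dimensional vectors, purely to show the asymmetry.
A_voter_vector = jnp.array([1.0, 0.0])
A_votee_vector = jnp.array([0.0, 1.0])
B_voter_vector = jnp.array([0.0, 1.0])
B_votee_vector = jnp.array([0.0, 1.0])

print(jnp.dot(A_voter_vector, B_votee_vector))  # 0.0 -> A passes on B
print(jnp.dot(B_voter_vector, A_votee_vector))  # 1.0 -> B likes A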

Ok, so now we've established that we want vectors. Vectors have uh, numbers in them. But, how do we figure out which numbers, exactly? There are so many numbers we could fill them with!

Well, we can learn the numbers from the votes we’ve observed! We’ll do this using Gradient Descent, which can be described as follows (there's a bare-bones code sketch right after this list):

  1. Randomly initialize everyone’s voter and votee vector
  2. For some number of ‘epochs’ (passes over the observed votes), we'll go through each vote we’ve observed, and
    1. Compute the dot product between the voter's voter vector and the votee's votee vector
    2. Compute the difference (error) between the dot product and the actual outcome
    3. Take the gradient of the error
      • This tells us how much each number in the vector contributed to the error
      • We can use this as a 'direction' to move the vector
    4. Move opposite that direction by subtracting the gradient (scaled by a small learning rate) from the vector
      • This will reduce (descend) the error
      • Notably, we don’t have to do this vote-by-vote, and instead can do a bunch at a time by batching these computations.
  3. Return the vectors we’ve learned for each voter and votee.
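
To make that concrete, here's a bare-bones sketch of a single update, with the gradient worked out by hand (JAX will do that part for us later; VECTOR_SIZE and LEARNING_RATE are just illustrative values):

import numpy as np

VECTOR_SIZE = 32      # illustrative
LEARNING_RATE = 1e-3  # illustrative

# 1. Randomly initialize a voter vector and a votee vector.
voter_vector = np.random.normal(size=VECTOR_SIZE)
votee_vector = np.random.normal(size=VECTOR_SIZE)

def sgd_step(voter_vector, votee_vector, real_outcome):
    # 2.1 Predict the outcome with a dot product.
    prediction = np.dot(voter_vector, votee_vector)
    # 2.2 The error between the prediction and what actually happened.
    error = prediction - real_outcome
    # 2.3 Gradients of 0.5 * error**2 with respect to each vector.
    voter_grad = error * votee_vector
    votee_grad = error * voter_vector
    # 2.4 Step opposite the gradient to reduce the error.
    return (voter_vector - LEARNING_RATE * voter_grad,
            votee_vector - LEARNING_RATE * votee_grad)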

So, once this is done, what can we do with these vectors? Well now, when we want to know if a voter will like a votee, we can simply take the dot product of their respective vectors and find out what the model thinks is going to happen!

This is a very fast operation that we can offload to Vespa to provide the recommendations, and it doesn't get slower the more users we have!

But why this procedure in particular? I’ve skipped a lot of the details, but this whole approach is derived from approximating a Singular Value Decomposition (SVD) to reconstruct the voter-votee interaction matrix. If you’re interested in reading more about that part, I'd recommend the seminal Netflix Prize post that a lot of this work is derived from.

For simplicity, we're omitting things like biases, regularization, shuffling, and so on that we do use in practice.

Scale issues

While this is a lot faster when we want to figure out if a user will like another user, the training process still requires us to loop through and do a gradient update for each vote. While these gradient updates are relatively cheap, they add up with hundreds of millions of votes to train on!

Implementation

So, we need to make this happen quickly. In order to keep everyone’s recommendations fresh, we want to compute (and recompute) these vectors every day. Additionally, we don't want to be stuck with just this simple operation - we want to be able to try out and research different ways to update the vectors using different error functions, training methods, optimizers, and so on.

We’ve encountered quite a few libraries out there that implement this basic algorithm. Below, we’ll introduce our old favorite, and then talk about why we built our own approach using JAX.

The Baseline - Surprise

When we first started this project, we used the popular Surprise library by Nicolas Hug. While Surprise provides many recommendation algorithms, we specifically are interested in its implementation of SVD. Despite being a Python library, it implements all of the critical parts using fancy C-level Python extensions. This makes it a lot faster than pure Python.

However, if we come up with some complicated new error function, we would have to manually compute the corresponding gradient function. Aside from sheer laziness, we want to avoid this as it's a process that can introduce difficult-to-diagnose bugs and hard-to-maintain code.

Plenty of frameworks out there are built to automatically compute gradients (e.g. TensorFlow, PyTorch), but can we attain this flexibility while maintaining the efficiency of Surprise?

What is JAX? Isn't that in Florida?

JAX provides a simple Python interface for expressing (with a numpy-like API) the operations we want, while also allowing us to convert these functions into their gradient-returning versions!

It's much more than that as well, as it provides a number of ways to optimize our code using Just-In-Time-compilation and vectorization, among other neat features. If you're interested, I recommend reading up on their documentation.

import jax
import jax.numpy as jnp

So let's define the function that does our dot product. Simple enough.

def dot_product(voter_vector, votee_vector):
    return jnp.dot(voter_vector, votee_vector)

And now let's define our error function. For now, we'll be using the squared Euclidean distance between the real outcome and the predicted outcome. This is commonly known as L2 loss.

def l2_loss(voter_vector, votee_vector, real_outcome):
    return 0.5 * jnp.power(
        real_outcome - dot_product(voter_vector, votee_vector),
        2
    )

Now, we'll average the loss over our batch of votes that we're training on using the jax.vmap function transformation to get our final loss function for the batch. jax.vmap is used here to allow us to efficiently run the function on every row of the inputs.

def loss(voter_vectors, votee_vectors, real_outcomes):
    return jnp.mean(jax.vmap(l2_loss)(
        voter_vectors,
        votee_vectors,
        real_outcomes
    ))

This allows us to give two matrices, each of shape (batch_size, vector_size) (for the voter and votee vectors), as well as a vector of shape (batch_size,) (for the real outcomes) to the loss function. This function will then return a scalar value that represents, on average, how 'wrong' the vectors currently are when predicting the outcomes.
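
As a quick sanity check, here's what calling it might look like with a made-up batch of 4 votes and vectors of size 8 (the shapes are arbitrary, just for illustration):

key = jax.random.PRNGKey(0)
voter_vectors = jax.random.normal(key, (4, 8))   # (batch_size, vector_size)
votee_vectors = jax.random.normal(key, (4, 8))   # (batch_size, vector_size)
real_outcomes = jnp.array([1.0, 0.0, 1.0, 0.0])  # (batch_size,)

print(loss(voter_vectors, votee_vectors, real_outcomes))  # a single scalar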

So, now how do we get the gradient? We'll use the jax.grad function like so:

grad_fn = jax.jit(jax.grad(loss, argnums=(0, 1)))

This will produce a function that, when given the same arguments as the loss function, returns two matrices of gradients corresponding to the voter_vectors (argnum 0) and votee_vectors (argnum 1). Additionally, we're running all of these functions through jax.jit so that we can take advantage of faster executions once they're compiled!
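
Using the same made-up batch from before, the gradients come back with the same shapes as the vectors they correspond to:

voter_grad, votee_grad = grad_fn(voter_vectors, votee_vectors, real_outcomes)
print(voter_grad.shape)  # (4, 8) - one gradient row per voter vector in the batch
print(votee_grad.shape)  # (4, 8) - one gradient row per votee vector in the batch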

So, now that we have all of these, let's put together our training loop.

For our data, we'll have votes in the form of a tuple like so:

(voter_index, votee_index, real_outcome)

These indices are computed beforehand and map each voter and votee to a row in the associated voter_matrix and votee_matrix.
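
As an illustration of that mapping step (the IDs and variable names here are hypothetical, not our real pipeline):

import numpy as np

# Hypothetical raw votes as (voter_id, votee_id, real_outcome) tuples.
raw_votes = [("A", "B", 0), ("A", "C", 1), ("B", "A", 1), ("C", "B", 0)]

voter_to_row = {uid: i for i, uid in enumerate(sorted({v for v, _, _ in raw_votes}))}
votee_to_row = {uid: i for i, uid in enumerate(sorted({v for _, v, _ in raw_votes}))}

voter_indices = np.array([voter_to_row[v] for v, _, _ in raw_votes])
votee_indices = np.array([votee_to_row[v] for _, v, _ in raw_votes])
real_outcomes = np.array([o for _, _, o in raw_votes], dtype=np.float32)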

So our training loop looks something like this:

LEARNING_RATE = 1e-3

def train_epoch(voter_matrix, votee_matrix,                        
                voter_indices, votee_indices, 
                real_outcomes,    
                batch_size):         
    for voter_ib, votee_ib, outcomes_b in zip(        
            create_batch(voter_indices, size=batch_size),    
            create_batch(votee_indices, size=batch_size),   
            create_batch(real_outcomes, size=batch_size)):                 
        # these are of shape (batch_size, vector_size)
        # index portion
        voter_vectors_batch = voter_matrix[voter_ib]                    
        votee_vectors_batch = votee_matrix[votee_ib]
        
        # gradient computation portion
        voter_grad, votee_grad = grad_fn( 
            voter_vectors_batch,
            votee_vectors_batch,
            outcomes_b
        )
                                                           
        # Now let's take that gradient step!
        # update portion
        voter_matrix[voter_ib] -= LEARNING_RATE * voter_grad               
        votee_matrix[votee_ib] -= LEARNING_RATE * votee_grad       
    return voter_matrix, votee_matrix                              
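
One helper not shown here is create_batch; a minimal version of it (not necessarily the one we use) could just be a generator that yields consecutive batch_size-sized slices:

def create_batch(array, size):
    # Yield consecutive `size`-length slices of `array`.
    for start in range(0, len(array), size):
        yield array[start:start + size]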

Let's call this approach jax_naive. So, how does this do?

[Benchmark chart: training time of jax_naive vs. Surprise. Ignore the shuffle variants; those didn't make the cut.]

Yikes! Even for a small number of votes, we’re spending way more time than Surprise! Can we do better?

Gotta go fast(er)

Ok, well, that clearly did not go well. Thinking back, can we use JAX’s numpy-like API to JIT-compile the indexing and update parts? Would that be faster?

from jax.ops import index_add
LEARNING_RATE = 1e-3

@jax.jit                                                           
def train_batch(voter_matrix, votee_matrix,                        
                voter_ib, votee_ib, outcomes_b):        
                                                    
    # these are (batch_size, vector_size)
    # index portion
    voter_vectors_batch = voter_matrix[voter_ib]                    
    votee_vectors_batch = votee_matrix[votee_ib]
    
    # gradient computation portion
    voter_grad, votee_grad = grad_fn( 
        voter_vectors_batch,
        votee_vectors_batch,
        outcomes_b
    )
                                                                
    # update portion                                               
    new_voter_matrix = index_add(
        voter_matrix,
        voter_ib, 
        -LEARNING_RATE * voter_grad
    )
    
    new_votee_matrix = index_add(
        votee_matrix, 
        votee_ib, 
        -LEARNING_RATE * votee_grad
    )

    return new_voter_matrix, new_votee_matrix

def train_epoch(voter_matrix, votee_matrix,                        
                voter_indices, votee_indices, 
                real_outcomes,    
                batch_size): 
                
    # create_batch just yields batch_size slices of a Python iterable
    for voter_ib, votee_ib, outcomes_b in zip(        
            create_batch(voter_indices, size=batch_size),    
            create_batch(votee_indices, size=batch_size),   
            create_batch(real_outcomes, size=batch_size)):
            
         voter_matrix, votee_matrix = train_batch(
             voter_matrix, votee_matrix,
             voter_ib, votee_ib, outcomes_b
         )
         
    return voter_matrix, votee_matrix

(sorry, this is a lot)

Ok, cool. So, a few notes - what's up with the jax.ops.index_add? Why are we assigning the result of that to new_voter_matrix?

One of the key limitations of JIT-compiling a JAX function is that the function has to be pure: we can't use any in-place operations, so any function that modifies its inputs will not work. For more on why, see their documentation.
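
As a tiny example of what that means in practice (and roughly why the naive version's in-place update had to go):

x = jnp.zeros((3, 2))
# x[0] -= 1.0               # fails: JAX arrays can't be modified in place
x = index_add(x, 0, -1.0)   # instead, build a new array with row 0 updated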

How does this version, which we'll call the amateur_jax_model, measure up?

Welp, we’re still going way too slow. Surprise is beating us handily!

What else can we stuff into a JIT-compiled function?

I should note that we're using a batch_size of 1 for these comparisons. This isn't because that's necessarily a good batch size, but it's because that's what Surprise allows for - so we want to make sure our comparison is fair.

Gotta go fast(est)

Right now we’re copying (remember, no in-place changes) the entire voter matrix and votee matrix into the JIT-compiled function every batch, even though we’re only updating (at most) batch_size rows! When we have millions of voters and votees, this can become expensive.

But what if we could stuff the entire train_epoch into one JIT-compiled function?

from jax.experimental import loops
from jax import lax
from functools import partial

@partial(jax.jit, static_argnums=(5,6))
def train_epoch(voter_matrix, votee_matrix, 
                voter_indices, votee_indices,
                real_outcomes, 
                batched_dataset_size, 
                batch_size):
                
    with loops.Scope() as s:
        s.voter_matrix = voter_matrix
        s.votee_matrix = votee_matrix
        for batch_index in s.range(0, batched_dataset_size):
            batch_start = batch_index * batch_size
            
            # batching part
            voter_ib = lax.dynamic_slice(voter_indices, 
                (batch_start,), (batch_size,)
            )
            
            votee_ib = lax.dynamic_slice(votee_indices, 
                (batch_start,), (batch_size,)
            )
            
            outcomes_b = lax.dynamic_slice(real_outcomes, 
                (batch_start,), (batch_size,)
            )
            
            # update part
            s.voter_matrix, s.votee_matrix = train_batch(
                s.voter_matrix, s.votee_matrix,
                voter_ib, votee_ib,
                outcomes_b
            )
            
    return s.voter_matrix, s.votee_matrix

Wow, so, this looks quite different but at the same time quite similar! This is thanks to the experimental JAX loops API, which allows us to write a stateful-looking loop that JAX turns into a pure (side-effect-free) function. All of our mutable state in that loop has to be placed in the loops.Scope object.

But once that's done, we can compile it using jax.jit.

A few things get complicated here however, and the root of the problem is that batch_start keeps changing. This means that we can't use a typical slice operator like voter_indices[batch_start:batch_end] - we have to use lax.dynamic_slice instead. Additionally, we have to specify that batch_size and batched_dataset_size are static values, which means that JAX will recompile train_epoch if they change. So, when calling train_epoch we now have to compute how many batches will be in our training run. Not a big deal for simple gradient descent!
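
So the call site ends up looking something like this (the variable names and batch size here are just illustrative):

BATCH_SIZE = 1
batched_dataset_size = len(voter_indices) // BATCH_SIZE

voter_matrix, votee_matrix = train_epoch(
    voter_matrix, votee_matrix,
    voter_indices, votee_indices,
    real_outcomes,
    batched_dataset_size,  # static: changing this triggers a recompile
    BATCH_SIZE,            # static: same deal
)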

Ok, so, this is very complicated and took me somewhere around a month of banging my head against the keyboard to get right, but how does this approach (which we'll call the full_jax_model) compare?

[Benchmark chart: training time of the full_jax_model vs. Surprise as the number of votes grows. The earlier approaches aren't even worth plotting at this point.]

Now that’s what I call speed! In fact, we’re seeing that as we hit tens of millions of votes, we’re not slowing down nearly as much as Surprise does! We couldn't have even dreamed of getting to this level of comparison with the earlier versions of the model!

Taking advantage of JIT compilation?

While it's an impressive speedup even with just one epoch, when we start to train for more than one epoch we should see an even more dramatic difference as train_epoch is compiled, and subsequent invocations after the first one should use the already-compiled version.

So, compared to Surprise, we should see JAX take significantly less extra time per epoch, right? Let's look at how many times slower each one is after a few epochs compared to just one.

Well, not quite. As it turns out, relative to one epoch of training, the JAX version slows down at a similar rate to Surprise, though it does a bit better. In absolute terms, however, we're adding orders of magnitude fewer seconds per epoch!

JAX land vs. numpy land

One thing that hasn't been made super clear throughout this is that train_epoch isn't returning numpy.ndarray objects, but instead jax.interpreters.xla.DeviceArray objects. While these have largely similar APIs, when profiling JAX models and their JIT-compiled methods, one has to be careful.

Fundamentally, while it may look like the computation is done when train_epoch (the inner portion) is called, the results you get back are actually futures. So, if we want to work with them in numpy land, or print them, or do anything useful like that, we have to 'extract' them. We can see this when we measure the time it takes to run train_epoch a few times on the same matrices: it looks like the time taken doesn't increase with the number of epochs!

However once we extract these, we can see the amount of time taken jump up. So, be careful when profiling your JAX code! You can read more on this async dispatch stuff here.
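
Concretely, one way to time it honestly (per JAX's async dispatch docs) is to block on the returned arrays before stopping the clock; the setup here is the same illustrative one as above:

import time

start = time.perf_counter()
voter_matrix, votee_matrix = train_epoch(
    voter_matrix, votee_matrix,
    voter_indices, votee_indices,
    real_outcomes,
    batched_dataset_size,
    BATCH_SIZE,
)
# Block on the results; otherwise we'd mostly be timing how long it takes
# to enqueue the work, not how long the work itself takes.
voter_matrix.block_until_ready()
votee_matrix.block_until_ready()
print(f"epoch took {time.perf_counter() - start:.2f}s")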

Why not use TensorFlow or PyTorch?

While TensorFlow and PyTorch are great for a lot of gradient descent driven problems, there are some key parts of this one that they do not handle well.

  • Since we need to train a vector for every voter and every votee, we can’t downsample our training set, and thus need a fast way to iterate through it
  • Neither TensorFlow nor PyTorch offer optimizations for this
    • tf.data is focused on speeding up feeding large feature sets into the training loop, but that’s not the issue we have
  • We’re also interested here in optimizing a lot of disparate parameters
    • Neural networks tend to optimize the same parameters every time
    • We're dealing here with a small number of parameters updated at each step, but from a large matrix of parameters

After considerable research and failed attempts on our end, we determined that there wasn't a viable route for us to achieve this kind of performance in TensorFlow or PyTorch. However, it's been a year since we last looked, so if anyone knows otherwise, feel free to let me know how wrong I am!

Don’t try this at home

Back in December 2019, when COVID-19 was ‘anomalous pneumonia’ and the JAX version was 0.1.55, I’d built out this method and gotten these great results.

This past week, when I was writing this post and realizing I had a few results I wanted to re-run and explore further, I had an awful finding. My timing results could no longer be reproduced! As it turns out, something changed in the JAX library (I think) that slowed down this method significantly.

[Benchmark chart: the same training runs under a newer JAX version, now significantly slower. That doesn't look good.]

We’ve filed a ticket for this, but in the meantime if you want to reproduce these results you can do so with the following versions of jax and numpy.

jax==0.1.55
jaxlib==0.1.37
numpy==1.17.5

Conclusion & Further work

While this was definitely an improvement on our baseline, it is still a rather basic form of collaborative filtering for recommendations. But it’s worked out well for us so far, allowing us to train these vectors for every voter and votee seen in the past week, across the entire OkCupid site, in roughly three hours on average!

We’ve leveraged this to provide significant improvements in recommendations for all of our users this year.

But, where do we go from here?

Alternative Training Loops and loss functions

There are all sorts of interesting ways to order the training set other than sequentially or at random. Commonly, methods like BPR and WARP use the model’s current performance to select ‘hard’ examples that it’s not getting right. These methods may be helpful, and are a direction of future work for us.

Neural Collaborative Filtering

NCF is a promising approach as well, allowing us to take advantage of the oh-so-lauded deep learning to do a non-linear version of all of this. However, we were awaiting mainlined pytree support in loops.Scope, which has only very recently been added.

Without that support, managing all of the disparate state that NCF requires will be nightmarish.

Better optimization methods

The method we have for turning the gradient into new vectors is fairly crude, and there are plenty of alternatives that show promise - but again, ultimately rely on or are made significantly easier by the aforementioned pytrees support.