What is DeepMind’s retrieval-based transformer (RETRO) & how does it work?

https://miro.medium.com/max/1200/0*LoUi3wy3JX1KcSJI

RETRO obtains performance comparable to GPT-3 and Jurassic-1 using 25x fewer parameters.

Photo by Hunter Harritt on Unsplash

About a month ago, DeepMind released a new transformer model called RETRO. What’s special about it? It obtains performance comparable to GPT-3 and Jurassic-1 (which are among the biggest and best state-of-the-art language models) while using 25x fewer parameters. A number of tricks and optimizations make this possible, and this article explains them at a medium-to-high level of detail.

First of all, let’s quickly go over the main components that make up RETRO:

  1. A frozen BERT retriever
  2. A differentiable encoder
  3. A chunked cross-attention mechanism to predict tokens [1]

Several studies have examined retrieval for language modeling and suggested that large-scale language models can use it to improve their performance [1]. One caveat is leakage between the training and test datasets: large language models can memorize large parts of their training data, so if the retrieved text overlaps with the evaluation set, part of the measured improvement may come from copying rather than from better modeling.

“Data leakage refers to accidentally sharing information between the test and training data sets. Typically, when splitting a data-set into testing and training sets, the goal is to ensure that no data is shared between the two.”

Source: Towards Data Science

What is RETRO?

We will get into what those components mean step by step. But first, a bit of context: transformers are powerful language models that use attention [1] to take past context into account during training and prediction. RETRO adds a connection to a large text database containing 2 trillion tokens. It starts by splitting the input sequence into chunks and then retrieves [1] the most similar chunks of text from the database, which are used to improve the predictions for the tokens that follow.
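To make the "split, then look up similar text" idea concrete, here is a minimal sketch. It is not RETRO’s implementation: the chunk size, the tiny database, and the Jaccard token-overlap score (a stand-in for the BERT embeddings RETRO actually uses) are all assumptions for illustration.

```python
from typing import List


def split_into_chunks(tokens: List[str], chunk_size: int) -> List[List[str]]:
    """Split a token sequence into consecutive chunks of chunk_size tokens.

    The last chunk may be shorter if len(tokens) is not a multiple of chunk_size.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    return [tokens[i:i + chunk_size] for i in range(0, len(tokens), chunk_size)]


def overlap_score(a: List[str], b: List[str]) -> float:
    """Toy similarity: Jaccard overlap of token sets (stand-in for BERT embeddings)."""
    sa, sb = set(a), set(b)
    if not sa and not sb:
        return 0.0
    return len(sa & sb) / len(sa | sb)


def most_similar_chunk(query: List[str], database: List[List[str]]) -> List[str]:
    """Return the database chunk with the highest overlap with the query chunk."""
    return max(database, key=lambda chunk: overlap_score(query, chunk))


# Hypothetical tiny "database" of text chunks.
database = [
    "the cat sat on the mat".split(),
    "stock prices fell sharply today".split(),
    "a dog chased the cat".split(),
]

tokens = "the cat sat quietly while stock prices fell".split()
for chunk in split_into_chunks(tokens, chunk_size=4):
    print(chunk, "->", most_similar_chunk(chunk, database))
```

The real system retrieves several neighbors per chunk rather than one, but the shape of the loop (chunk the input, then query the database once per chunk) is the same.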

An interesting analysis in the paper is what happens when RETRO is scaled. What we typically see with language models is that the bigger the model, the better. For instance, GPT-3 (probably the most famous language model) has around 175B parameters. But presumably there is a point at which adding more parameters no longer improves performance much (diminishing returns).

Source: Retro paper

Retrieval-enhanced architecture

A key-value database is built from chunks of text. Each key is the frozen BERT [1] embedding of a chunk, and the corresponding value is the chunk itself together with its continuation in the source document. Because BERT is frozen, these keys can be computed once over the entire database instead of being recomputed during training. For each input chunk, RETRO then retrieves its k nearest neighbors according to the distance between BERT embeddings. An encoder-decoder architecture incorporates those retrieved chunks into the model’s predictions [1].
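Here is a small sketch of such a key-value store, assuming a toy embedding: an averaged, hashed bag-of-words vector stands in for frozen BERT, and a brute-force sort stands in for the approximate nearest-neighbor search (SCaNN) the paper uses at the scale of trillions of tokens. `embed`, `RetrievalDatabase`, and `DIM` are hypothetical names for illustration.

```python
import zlib
from typing import List, Tuple

DIM = 16  # assumption: toy embedding size, far smaller than BERT's


def embed(tokens: List[str]) -> List[float]:
    """Stand-in for a frozen BERT encoder: averaged hashed bag-of-words vector."""
    vec = [0.0] * DIM
    for tok in tokens:
        # crc32 is deterministic across runs, unlike Python's built-in hash()
        vec[zlib.crc32(tok.encode("utf-8")) % DIM] += 1.0
    n = max(len(tokens), 1)
    return [v / n for v in vec]


def sq_l2(a: List[float], b: List[float]) -> float:
    """Squared Euclidean distance between two embeddings."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


class RetrievalDatabase:
    """Key = frozen embedding of a chunk, value = (chunk, continuation)."""

    def __init__(self) -> None:
        self.keys: List[List[float]] = []
        self.values: List[Tuple[List[str], List[str]]] = []

    def add(self, chunk: List[str], continuation: List[str]) -> None:
        # The embedder is frozen, so each key is computed once and never updated.
        self.keys.append(embed(chunk))
        self.values.append((chunk, continuation))

    def knn(self, query: List[str], k: int) -> List[Tuple[List[str], List[str]]]:
        """Brute-force k nearest neighbors by squared L2 distance."""
        q = embed(query)
        order = sorted(range(len(self.keys)), key=lambda i: sq_l2(q, self.keys[i]))
        return [self.values[i] for i in order[:k]]


# Build the store from one document: each chunk's value includes the next chunk.
doc = "retro splits text into chunks and retrieves neighbors from a large database of text".split()
chunks = [doc[i:i + 4] for i in range(0, len(doc), 4)]
db = RetrievalDatabase()
for chunk, nxt in zip(chunks, chunks[1:] + [[]]):
    db.add(chunk, nxt)

print(db.knn("retrieves neighbors from".split(), k=2))
```

Storing the continuation alongside each neighbor is what lets the model see not just similar text but also what came after it in the original document.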

Source: Retro paper

An important part of RETRO is the retrieval and grouping mechanism. Because the database is enormous, retrieval has to be done efficiently, so it operates on chunks rather than on individual tokens. An input of n tokens is split into a sequence of chunks of a pre-defined size (64 tokens in the paper) [1], and each chunk is augmented with a set of k neighbors retrieved from the database. The model’s likelihood is then factorized so that the i-th token of the u-th chunk [1] depends only on the previously seen tokens and on the neighbors retrieved for previous chunks.
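For readers who want the formula, here is our transcription of that chunked likelihood in the paper’s notation. It assumes an input X of n = l·m tokens split into chunks C_1, …, C_l of m tokens each, with RET_D(C) the neighbors retrieved from database D for chunk C and ℓ_θ the model’s log-likelihood:

```latex
L(X \mid \theta, \mathcal{D}) \triangleq
  \sum_{u=1}^{l} \sum_{i=1}^{m}
  \ell_\theta\Big( x_{(u-1)m+i} \,\Big|\,
    (x_j)_{j < (u-1)m+i},\;
    \big(\mathrm{RET}_{\mathcal{D}}(C_{u'})\big)_{u' < u} \Big)
```

For u = 1 the set of earlier chunks is empty, so the first chunk is predicted without any retrieved neighbors, and a chunk never conditions on neighbors retrieved using its own tokens.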

The main RETRO model

At its core, RETRO uses an encoder-decoder transformer architecture [1]: an encoder turns the retrieved text into a set of encoded neighbors, and the decoder draws on them while predicting tokens. Before we get into the next part, we first have to explain what “cross-attention” is. In transformers, cross-attention is attention in which the queries are computed from one sequence of embeddings while the keys and values are computed from another. This differs from self-attention, where the queries, keys, and values all come from the same embeddings.
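The difference between the two comes down to which inputs you pass in. The sketch below is single-head scaled dot-product attention in plain Python; as a simplifying assumption it skips the learned query/key/value projection matrices a real transformer applies, and the example vectors are made up.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(xs: List[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries: Matrix, keys: Matrix, values: Matrix) -> Matrix:
    """Single-head scaled dot-product attention without learned projections."""
    d = len(keys[0])
    out = []
    for q in queries:
        # Similarity of this query to every key, scaled by sqrt(d).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        # Weighted average of the value vectors.
        out.append([sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))])
    return out


def self_attention(x: Matrix) -> Matrix:
    # Queries, keys, and values all come from the same sequence.
    return attention(x, x, x)


def cross_attention(x: Matrix, context: Matrix) -> Matrix:
    # Queries come from x; keys and values come from a different sequence.
    return attention(x, context, context)


decoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # hypothetical decoder activations
retrieved = [[0.5, 0.5], [2.0, 0.0]]                   # hypothetical encoded neighbors
print(cross_attention(decoder_states, retrieved))
```

Note that the output of cross-attention has one row per query (here, per decoder position), no matter how many retrieved vectors it attends to.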

The retrieved neighbors of each chunk are then encoded with a bi-directional transformer encoder. The encoder’s outputs are indexed by chunk and by neighbor, and the encoder is conditioned on the activations of the chunk being processed through cross-attention layers [1]. All of the neighbors are encoded in parallel. Note that it is the BERT retriever, not this encoder, that is frozen: keeping BERT fixed avoids re-computing the embeddings over the entire database on each training iteration, while the neighbor encoder is differentiable and trained along with the rest of the model.

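The next part of the process is chunked cross-attention, in which an intermediate activation of the decoder is split into chunks. Each of these attending chunks contains the embedding of the last token of one input chunk together with the embeddings of the first tokens of the next chunk, and it attends to the encoded neighbors retrieved for that earlier input chunk. This one-token offset means the neighbors retrieved for a chunk are only used to predict tokens that come after it, which keeps the model autoregressive. The cross-attention is then computed for all chunks in parallel [1].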
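The offset is easy to get wrong, so here is a small sketch that only computes the alignment: which token positions (0-based) cross-attend to which chunk’s neighbors, for chunk size m. It assumes the shift-by-(m − 1) grouping described above, and also assumes that the final token of the sequence attends to the last chunk’s neighbors. The function names are hypothetical.

```python
from typing import List, Optional


def attended_chunk(position: int, chunk_size: int) -> Optional[int]:
    """0-based index of the chunk whose neighbors this token position attends to.

    Attending chunk u starts at the last token of input chunk u and covers the
    first chunk_size - 1 tokens of chunk u + 1. Positions before chunk_size - 1
    have no earlier chunk to draw on, so they attend to no retrieved neighbors.
    """
    shifted = position - (chunk_size - 1)
    if shifted < 0:
        return None
    return shifted // chunk_size


def attending_chunks(seq_len: int, chunk_size: int) -> List[List[int]]:
    """Group token positions by the chunk of neighbors they cross-attend to."""
    groups: List[List[int]] = []
    for pos in range(seq_len):
        c = attended_chunk(pos, chunk_size)
        if c is None:
            continue
        while len(groups) <= c:
            groups.append([])
        groups[c].append(pos)
    return groups


print(attending_chunks(seq_len=12, chunk_size=4))
# [[3, 4, 5, 6], [7, 8, 9, 10], [11]]
```

With chunks of 4 tokens, position 3 is the last token of chunk 0, so it and the next three positions use chunk 0’s neighbors; positions 0 to 2 use none.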

The authors of the paper also point out that a retrieval database can come in quite handy for updating the data a model draws on without having to retrain it. Large language models are expensive to retrain, so with RETRO simply updating the database may be sufficient. They also suggest this as a way to update or remove content in order to mitigate privacy, safety, and fairness issues, which are quite common in language models.
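As a toy illustration of that idea, the store below is updated in place while nothing about the "model" changes. Everything here is hypothetical: `UpdatableStore` is not from the paper, and token overlap stands in for real embedding-based retrieval.

```python
from typing import Dict, List, Optional


class UpdatableStore:
    """Hypothetical retrieval store; the language model itself is never modified."""

    def __init__(self) -> None:
        self.entries: Dict[str, List[str]] = {}

    def upsert(self, entry_id: str, tokens: List[str]) -> None:
        # Add new text, or overwrite stale text under the same id.
        self.entries[entry_id] = tokens

    def delete(self, entry_id: str) -> None:
        # Remove text entirely, e.g. private or harmful content.
        self.entries.pop(entry_id, None)

    def best_match(self, query: List[str]) -> Optional[str]:
        """Id of the entry sharing the most tokens with the query (toy scoring)."""
        if not self.entries:
            return None
        q = set(query)
        return max(self.entries, key=lambda k: len(q & set(self.entries[k])))


store = UpdatableStore()
store.upsert("club", "the club president is alice".split())
store.upsert("news", "stock prices fell today".split())
query = "who is the club president".split()
print(store.entries[store.best_match(query)])  # ['the', 'club', 'president', 'is', 'alice']

# The facts change: overwrite the entry instead of retraining the model.
store.upsert("club", "the club president is bob".split())
print(store.entries[store.best_match(query)])  # ['the', 'club', 'president', 'is', 'bob']

# Remove the entry altogether; it can no longer be retrieved.
store.delete("club")
```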

Results

I won’t dive too deep into the results since they are laid out in the paper; I will just give a quick overview. The model was evaluated on several well-known text datasets, including C4, Wikitext103, Curation Corpus, LAMBADA and the Pile [1].

As for scaling, the model was trained at sizes from 150 million to 7 billion parameters. RETRO outperforms the baseline transformer across that whole range, and the authors observe that the improvements do not diminish as the model gets bigger.

They also experimented with scaling the retrieval database and found that a larger database improves the model’s performance.

Conclusion

Personally, I see the retrieval database as more of a software engineering trick to improve the underlying system than a machine learning trick, and I like the idea that classic software engineering concepts can be used to boost machine learning models. It also opens up new areas for improvement, for example the speed of retrieval from the database, the size of the database, and the compression of the data stored in it, which I doubt the authors explored in great detail.

It is also great to see smaller models performing so well, since more data scientists in the community can run smaller models, experiment with them, give feedback, and maybe even improve them (which, by the way, is a great addition to a resume)!

References:

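[1] RETRO paper: Borgeaud et al., “Improving language models by retrieving from trillions of tokens,” DeepMind, 2021.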
