Transformers for Tabular Data: TabTransformer Deep Dive

Photo by Samule Sun on Unsplash


Today, Transformers are the key building blocks of most state-of-the-art Natural Language Processing (NLP) and Computer Vision (CV) architectures. Nevertheless, the tabular domain is still dominated by gradient boosted decision trees (GBDT), so it was only logical that someone would attempt to bridge this gap. One of the first transformer-based tabular models was introduced by Huang et al. (2020) in their paper TabTransformer: Tabular Data Modeling Using Contextual Embeddings.

This post aims to provide an overview of the paper, take a deep dive into the model details, and show you how to use the TabTransformer with your data.

Paper Overview

The main idea in the paper is that the performance of a regular Multi-layer Perceptron (MLP) can be significantly improved if we use Transformers to transform regular categorical embeddings into contextual ones. Let’s digest this statement a bit.

Categorical Embeddings

A classical way to use categorical features in deep learning models is to train their embeddings. This means that each categorical value gets a unique dense vector representation which can be passed on to the next layers. For instance, in the image below each categorical feature is represented using a 4-dimensional vector. These embeddings are then concatenated with the numerical features and used as inputs to the MLP.

MLP with categorical embeddings. Image by author.
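To make the lookup-and-concatenate step concrete, here is a minimal NumPy sketch. The feature names, vocabularies and randomly initialised embedding tables are illustrative assumptions; in a real model the tables are trained together with the rest of the network.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy vocabularies for two categorical features.
vocabularies = {
    "relationship": ["Husband", "Wife", "Not-in-family"],
    "marital_status": ["Married", "Never-married", "Divorced"],
}
EMBEDDING_DIM = 4  # matches the 4-dimensional example in the figure

# One lookup table per feature: one dense row per category.
# Random values stand in for trained embeddings.
embedding_tables = {
    name: rng.normal(size=(len(vocab), EMBEDDING_DIM))
    for name, vocab in vocabularies.items()
}

def embed_row(categorical, numerical):
    """Look up each categorical value and concatenate with the numerical features."""
    parts = []
    for name, vocab in vocabularies.items():
        index = vocab.index(categorical[name])       # category -> integer id
        parts.append(embedding_tables[name][index])  # id -> dense vector
    parts.append(np.asarray(numerical, dtype=float))
    return np.concatenate(parts)  # this vector is the MLP input

x = embed_row({"relationship": "Wife", "marital_status": "Married"}, [0.5, -1.2])
print(x.shape)  # (10,): 2 features * 4 dims + 2 numerical values
```

Note that each embedding depends only on its own category: "Wife" always maps to the same vector, whatever the other features of the row are.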

Contextual Embeddings

The authors argue that simple categorical embeddings lack context, meaning that they don’t encode any interactions or relationships between the categorical variables. To contextualise the embeddings, they propose using Transformers, which are already used in NLP for exactly this purpose.

Contextual Embeddings in TabTransformer. Image by author.

To visualise the motivation, consider the image of trained contextual embeddings below. Two categorical features are highlighted: relationship (black) and marital status (blue). These features are related, so the values “Married”, “Husband” and “Wife” should be close to each other in the vector space even though they come from different variables.

Example of trained TabTransformer Embeddings. Image by author.

With trained contextual embeddings, we can indeed see that the marital status “Married” is closer to the relationship levels “Husband” and “Wife”, whereas the “non-married” categorical values form separate clusters to the right. This type of context makes the embeddings more useful, and it would not have been possible with simple categorical embeddings.

TabTransformer Architecture

With the motivation above in mind, the authors propose the following architecture:

TabTransformer architecture. Adapted from Huang et al. (2020)

We can break down this architecture into 5 steps:

  1. Numerical features are normalised and passed forward
  2. Categorical features are embedded
  3. Embeddings are passed through Transformer blocks N times to get contextual embeddings
  4. Contextual categorical embeddings are concatenated with numerical features
  5. Concatenation gets passed through MLP to get the required prediction
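The five steps above can be sketched as a single NumPy forward pass. This is a simplified illustration under several assumptions: all sizes are arbitrary, random weights stand in for trained ones, and each Transformer block uses one attention head, no biases and post-layer-norm for brevity. It shows the data flow, not the paper's or any library's exact implementation.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical sizes: 3 categorical features (10 levels each), 2 numerical features.
N_CAT, N_NUM, VOCAB_SIZE, D, DEPTH, HIDDEN = 3, 2, 10, 8, 2, 16

def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def transformer_block(x, p):
    """Simplified encoder block: one attention head, no biases, post-norm."""
    q, k, v = x @ p["wq"], x @ p["wk"], x @ p["wv"]
    x = layer_norm(x + softmax(q @ k.T / np.sqrt(D)) @ v)  # attention + residual
    ff = np.maximum(0.0, x @ p["w1"]) @ p["w2"]            # feed-forward (ReLU)
    return layer_norm(x + ff)                              # residual + norm

# Random weights stand in for trained parameters.
embedding_tables = rng.normal(size=(N_CAT, VOCAB_SIZE, D))
blocks = [{name: rng.normal(scale=D ** -0.5, size=(D, D))
           for name in ("wq", "wk", "wv", "w1", "w2")} for _ in range(DEPTH)]
w_hidden = rng.normal(scale=0.1, size=(N_CAT * D + N_NUM, HIDDEN))
w_out = rng.normal(scale=0.1, size=(HIDDEN, 1))

def tabtransformer_forward(cat_ids, numerical, num_mean, num_std):
    num = (numerical - num_mean) / num_std             # 1. normalise numerical features
    emb = embedding_tables[np.arange(N_CAT), cat_ids]  # 2. embed categoricals -> (N_CAT, D)
    for p in blocks:                                   # 3. DEPTH Transformer blocks
        emb = transformer_block(emb, p)
    features = np.concatenate([emb.ravel(), num])      # 4. concatenate with numerical
    hidden = np.maximum(0.0, features @ w_hidden)      # 5. MLP -> probability
    return 1.0 / (1.0 + np.exp(-(hidden @ w_out)))

prob = tabtransformer_forward(np.array([1, 4, 7]), np.array([35.0, 40.0]),
                              np.array([38.0, 40.0]), np.array([10.0, 12.0]))
```

Only the categorical embeddings pass through the Transformer blocks; the normalised numerical features join at step 4.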

While the model architecture is quite simple, the authors show that adding the Transformer layers can boost performance quite significantly. The magic, of course, happens inside these Transformer blocks, so let’s explore them in a bit more detail.


Transformer architecture. Adapted from Vaswani et al. (2017)

You’ve probably seen the Transformer architecture before (and if you haven’t, I highly recommend this annotated notebook), but for a quick recap, remember that it consists of an encoder and a decoder (see above). For the TabTransformer, we only care about the encoder, which contextualises the input embeddings (the decoder transforms these embeddings into the final output). But how exactly does it do it? The answer is the multi-headed attention mechanism.

Multi-Headed Attention

Quoting my favourite article about the attention mechanism:

The key concept behind self attention is that it allows the network to learn how best to route information between pieces of an input sequence.

In other words, self-attention helps the model figure out which parts of the input are more important and which are less important when representing a certain word or category. I highly recommend reading the article quoted above to get a good intuition for why it works so well.

Multi-Headed Attention. Adapted from Vaswani et al. (2017)

Attention is calculated using three matrices, Q, K and V, which stand for Query, Key and Value and are obtained by multiplying the input embeddings by three learned weight matrices. First, we multiply Q by the transpose of K to get the attention matrix. This matrix is scaled by 1/√d_k (where d_k is the dimension of the keys) and passed through a softmax. Finally, we multiply it by V to get our output values, i.e. Attention(Q, K, V) = softmax(QKᵀ/√d_k)V. For a more intuitive understanding, consider the image below, which shows how we get from input embeddings to contextual embeddings using matrices Q, K and V.

Self-Attention flow visualised. Image by author.
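Here is a minimal NumPy sketch of the single-head computation described above (scaled dot-product self-attention). The projection matrices are random stand-ins for learned weights, and the sizes are arbitrary assumptions.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, w_q, w_k, w_v):
    """Scaled dot-product self-attention over a set of embeddings.

    x: (n_features, d_model) input embeddings, one row per categorical feature.
    w_q, w_k, w_v: (d_model, d_k) learned projection matrices.
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v      # queries, keys, values
    scores = q @ k.T / np.sqrt(k.shape[-1])  # (n, n): how much row i attends to row j
    weights = softmax(scores, axis=-1)       # each row sums to 1
    return weights @ v, weights              # contextual embeddings, attention matrix

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))  # e.g. 3 categorical features, 4-dim embeddings
w_q, w_k, w_v = (rng.normal(size=(4, 4)) for _ in range(3))
contextual, attn = self_attention(x, w_q, w_k, w_v)
```

Each row of `attn` describes how strongly one categorical feature draws on every other feature when building its contextual embedding, which is how a value like “Married” can end up informed by “Husband” or “Wife”.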

By repeating this procedure h times in parallel (with different learned Q, K and V projections for each head), we get h sets of contextual embeddings. These are concatenated and linearly projected to form the final Multi-Headed Attention output.
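A minimal NumPy sketch of multi-headed attention follows. It assumes the common convention of splitting d_model evenly across heads (d_head = d_model / h); the weights and sizes are random, illustrative assumptions rather than values from the paper.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(x, heads, w_o):
    """x: (n, d_model); heads: list of (w_q, w_k, w_v), each (d_model, d_head);
    w_o: (h * d_head, d_model) output projection."""
    outputs = []
    for w_q, w_k, w_v in heads:  # each head has its own projections
        q, k, v = x @ w_q, x @ w_k, x @ w_v
        weights = softmax(q @ k.T / np.sqrt(k.shape[-1]))
        outputs.append(weights @ v)  # (n, d_head) per head
    return np.concatenate(outputs, axis=-1) @ w_o  # concatenate heads, project to (n, d_model)

rng = np.random.default_rng(1)
n, d_model, h = 5, 8, 2
d_head = d_model // h  # assumed convention: split d_model evenly across heads
heads = [tuple(rng.normal(size=(d_model, d_head)) for _ in range(3)) for _ in range(h)]
w_o = rng.normal(size=(h * d_head, d_model))
out = multi_head_attention(rng.normal(size=(n, d_model)), heads, w_o)
```

Because every head has its own projections, different heads are free to pick up different interactions between the same set of features.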

Short Recap

I know it was a lot, so let’s summarise everything stated above.

  • Simple categorical embeddings don’t include contextual information
  • By passing categorical embeddings through Transformer Encoder we are able to contextualise the embeddings
  • Transformer architecture can contextualise embeddings because it uses Multi-Headed Attention mechanism
  • Multi-Headed Attention uses matrices Q, K and V to find useful interactions and correlations while encoding the variables
  • In TabTransformer, contextualised embeddings are concatenated with numerical input and passed through a simple MLP to output a prediction

While the idea behind TabTransformer is quite simple, the mechanism of attention might take some time to grasp, so I highly encourage you to re-read the explanations above and follow all the suggested links if you feel lost. It gets easier, I promise!


Results section. Adapted from Huang et al. (2020)

According to the reported results, TabTransformer outperforms the other deep learning tabular models it was compared against (notably TabNet, which I’ve covered here). Furthermore, it comes close to the performance level of GBDTs, which is quite encouraging. The model is also relatively robust to missing and noisy data, and outperforms other models in the semi-supervised setting. However, these datasets are clearly not exhaustive, and as later papers showed (e.g. this), there’s still a lot of room for improvement.


Now, let’s finally find out how to apply the model to your own data. Example data is taken from the Tabular Playground Kaggle competition. To make TabTransformer easy to use, I’ve created the tabtransformertf package. It can be installed with pip install tabtransformertf and lets us use the model without extensive pre-processing. The main steps required to train the model are outlined below, but make sure to look at the supplementary notebook for more details.

Data pre-processing

The first step is to set appropriate data types and transform our training and validation data into TF Datasets. The previously installed package has a nice utility to do just that.

The next step is to prepare pre-processing layers for categorical data which we’ll pass on to the main model.
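The package handles these steps for you, so the following is only a conceptual sketch of what categorical and numerical pre-processing typically involves, similar in spirit to Keras’s StringLookup and Normalization layers. The helper names, the reserved out-of-vocabulary index 0 and the toy columns are hypothetical assumptions, not the tabtransformertf API.

```python
import numpy as np

OOV_INDEX = 0  # assumed: index reserved for categories unseen during training

def build_vocabulary(train_values):
    """Map each category seen in training to an integer id, starting at 1."""
    categories = sorted(set(str(v) for v in train_values))
    return {cat: i + 1 for i, cat in enumerate(categories)}

def encode(values, vocab):
    # Cast to str first so that e.g. 3 and "3" map to the same category.
    return np.array([vocab.get(str(v), OOV_INDEX) for v in values], dtype=np.int64)

def fit_normaliser(train_values):
    x = np.asarray(train_values, dtype=np.float32)
    std = float(x.std())
    return float(x.mean()), (std if std > 0 else 1.0)  # avoid division by zero

def normalise(values, mean, std):
    return (np.asarray(values, dtype=np.float32) - mean) / std

# Fit on the training split only, then apply to both splits.
train_colour = ["red", "blue", "red", "green"]
valid_colour = ["blue", "purple"]  # "purple" was never seen in training
vocab = build_vocabulary(train_colour)
train_ids, valid_ids = encode(train_colour, vocab), encode(valid_colour, vocab)

train_age = [20, 30, 40, 50]
mean, std = fit_normaliser(train_age)
valid_age_scaled = normalise([35, 60], mean, std)
```

Fitting the vocabulary and the normalisation statistics on the training data only keeps information from the validation set from leaking into the model.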

And that’s it for pre-processing! Now, we can move to building a model.

TabTransformer Model

Initialising the model is quite easy. There are a few parameters to specify, but the most important ones are embedding_dim, depth and heads. All of the parameters were selected after hyperparameter tuning, so check out the notebook to see the procedure.
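To see what these three parameters control, here is a hypothetical config object (not the package’s actual class). It assumes the common convention that embedding_dim is split evenly across heads, and that the flattened contextual embeddings are concatenated with the numerical features before the MLP.

```python
from dataclasses import dataclass

@dataclass
class TabTransformerConfig:
    """Hypothetical config; field names mirror the parameters mentioned above."""
    n_categorical: int
    n_numerical: int
    embedding_dim: int  # size of each categorical embedding
    depth: int          # number of stacked Transformer blocks
    heads: int          # attention heads per block

    def __post_init__(self):
        # Assumed convention: each head gets embedding_dim / heads dimensions.
        if self.embedding_dim % self.heads != 0:
            raise ValueError("embedding_dim must be divisible by heads")

    @property
    def head_dim(self) -> int:
        return self.embedding_dim // self.heads

    @property
    def mlp_input_dim(self) -> int:
        # Flattened contextual embeddings + numerical features.
        return self.n_categorical * self.embedding_dim + self.n_numerical

cfg = TabTransformerConfig(n_categorical=5, n_numerical=3,
                           embedding_dim=16, depth=4, heads=4)
```

Here `cfg.head_dim` is 4 and `cfg.mlp_input_dim` is 83 (5 × 16 + 3). Changing depth or heads alters the Transformer blocks but not the size of the MLP input, while embedding_dim affects both.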

With the model initialised, we can fit it like any other Keras model. Training parameters can be adjusted as well, so feel free to play around with the learning rate and early stopping.
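In Keras, early stopping is usually configured with the tf.keras.callbacks.EarlyStopping callback (e.g. monitor="val_loss", patience and restore_best_weights). The plain-Python sketch below only illustrates the patience logic behind it; the function and the example loss curve are hypothetical.

```python
def early_stopping(val_losses, patience=3, min_delta=0.0):
    """Return (best_epoch, stop_epoch) for a sequence of validation losses.

    Training stops once the loss has failed to improve by more than
    min_delta for `patience` consecutive epochs.
    """
    best_loss, best_epoch, wait = float("inf"), -1, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            best_loss, best_epoch, wait = loss, epoch, 0  # new best: reset counter
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    return best_epoch, len(val_losses) - 1  # never triggered: ran all epochs

best, stopped = early_stopping([0.70, 0.60, 0.55, 0.56, 0.57, 0.58, 0.50], patience=3)
```

Here the best epoch is 2 and training stops at epoch 5, so the later improvement to 0.50 is never seen. That trade-off is why patience is worth tuning alongside the learning rate.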


The competition metric is ROC AUC, so let’s use it together with PR AUC to evaluate the model’s performance.
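In practice you would likely use scikit-learn’s roc_auc_score and average_precision_score for this. To show what the two metrics measure, here is a small dependency-free sketch: ROC AUC computed as the probability that a random positive is scored above a random negative, and PR AUC estimated by average precision.

```python
def roc_auc(y_true, y_score):
    """Probability that a random positive outranks a random negative
    (ties count as half), which equals the area under the ROC curve."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def average_precision(y_true, y_score):
    """Mean of the precision values at the rank of each positive example,
    a common estimate of PR AUC (tied scores are not handled specially)."""
    ranked = sorted(zip(y_score, y_true), key=lambda t: -t[0])
    hits, precisions = 0, []
    for rank, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            precisions.append(hits / rank)  # precision at this positive's rank
    return sum(precisions) / len(precisions)

y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
auc, ap = roc_auc(y_true, y_score), average_precision(y_true, y_score)
```

For this toy example ROC AUC is 0.75 and average precision is 5/6 ≈ 0.83. PR AUC is the more informative of the two when positives are rare, since it ignores the large pool of easy negatives.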

You can also score the test set yourself and submit it to Kaggle. This solution placed me in the top 35%, which is not bad, but not great either. Why does TabTransformer underperform? There might be a few reasons:

  • The dataset is too small, and deep learning models are notoriously data-hungry
  • TabTransformer overfits very easily on toy datasets like the Tabular Playground
  • There are not enough categorical features for the contextual embeddings to make a difference


This article explored the main ideas behind the TabTransformer and showed how to apply it using the tabtransformertf package.

TabTransformer is an interesting architecture that outperformed many of the deep tabular models available at the time. Its main advantage is that it contextualises categorical embeddings, which increases their expressive power. It achieves this by applying multi-headed attention to the categorical features, which was one of the first applications of Transformers to tabular data.

One obvious disadvantage of the architecture is that numerical features are simply passed forward to the final MLP layer. Hence, they are not contextualised, nor are their values accounted for in the categorical embeddings. In the next article, I’ll explore how we can fix this flaw and further improve the performance.

