TabNet: The End of Gradient Boosting?

Each Step is a block of components. The number of Steps is a hyperparameter set when training the model. Increasing it raises the model’s learning capacity, but also increases training time, memory usage and the chance of overfitting.

Each Step gets its own, equally weighted vote in the final classification, mimicking an ensemble of classifiers.

Feature Transformer

The Feature Transformer is a network which has an architecture of its own.

It has multiple layers, some of which are shared across every Step while others are unique to each Step. Each layer contains a fully connected layer, batch normalisation and a Gated Linear Unit activation. If you aren’t familiar with these terms, Google’s ML Glossary is a good place to start.

TabNet Feature Transformer Model Architecture. Image by Author. Inspired by

The authors of the TabNet paper state that sharing some layers between decision Steps leads to “parameter-efficient and robust learning with high capacity” and that normalization with √0.5 “helps to stabilize learning by ensuring that the variance throughout does not change dramatically”. The output of the Feature Transformer passes through a ReLU activation function.
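As a rough sketch of this idea (not Dreamquark’s implementation; the dimensions, layer count and class names are illustrative, and the shared/Step-specific split is collapsed into one module for simplicity), a Feature Transformer built from FC → batch norm → GLU layers with √0.5-scaled residuals might look like:

```python
import torch

class GLULayer(torch.nn.Module):
    """One Feature Transformer layer: fully connected -> batch norm -> GLU."""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # GLU halves its input width, so the FC layer outputs 2 * out_dim
        self.fc = torch.nn.Linear(in_dim, 2 * out_dim)
        self.bn = torch.nn.BatchNorm1d(2 * out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.glu(self.bn(self.fc(x)))

class FeatureTransformer(torch.nn.Module):
    """Stack of GLU layers joined by sqrt(0.5)-scaled residual connections.
    In TabNet the first layers are shared across Steps and the rest are
    Step-specific; here all layers are local to keep the sketch short."""
    def __init__(self, in_dim: int, hidden_dim: int, n_layers: int = 4):
        super().__init__()
        self.first = GLULayer(in_dim, hidden_dim)
        self.rest = torch.nn.ModuleList(
            GLULayer(hidden_dim, hidden_dim) for _ in range(n_layers - 1)
        )
        self.scale = 0.5 ** 0.5  # sqrt(0.5) keeps the variance stable

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.first(x)
        for layer in self.rest:
            x = (x + layer(x)) * self.scale  # residual, then rescale
        return x

out = FeatureTransformer(in_dim=10, hidden_dim=8)(torch.randn(32, 10))
print(out.shape)  # torch.Size([32, 8])
```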

Feature Selection

Once features have been transformed, they are passed to the Attentive Transformer and the Mask for feature selection.

The Attentive Transformer comprises a fully connected layer, batch normalisation and Sparsemax normalisation. It also uses prior scales, which track how much each feature has been used by the previous Steps. These are combined with the processed features from the preceding Feature Transformer to derive the Mask.

TabNet Attentive Transformer Model Architecture. Image by Author. Inspired by

The Mask ensures the model focuses on the most important features and is also used to derive explainability. It essentially covers up features, meaning the model is only able to use those that have been considered important by the Attentive Transformer.
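To make this concrete, here is a minimal sketch (again not Dreamquark’s code; dimensions and names are illustrative) of sparsemax and an Attentive Transformer that applies the prior scales before producing the mask:

```python
import torch

def sparsemax(z: torch.Tensor) -> torch.Tensor:
    """Sparsemax (Martins & Astudillo, 2016): like softmax, but it can
    assign exactly zero weight, which is what makes the mask sparse."""
    z_sorted, _ = torch.sort(z, dim=-1, descending=True)
    k = torch.arange(1, z.size(-1) + 1, dtype=z.dtype, device=z.device)
    z_cumsum = z_sorted.cumsum(dim=-1)
    support = (1 + k * z_sorted) > z_cumsum            # features kept
    k_z = support.to(z.dtype).sum(dim=-1, keepdim=True)
    tau = (z_cumsum.gather(-1, k_z.long() - 1) - 1) / k_z
    return torch.clamp(z - tau, min=0.0)

class AttentiveTransformer(torch.nn.Module):
    """FC -> batch norm -> prior scaling -> sparsemax, yielding a mask."""
    def __init__(self, attn_dim: int, n_features: int):
        super().__init__()
        self.fc = torch.nn.Linear(attn_dim, n_features)
        self.bn = torch.nn.BatchNorm1d(n_features)

    def forward(self, a: torch.Tensor, prior: torch.Tensor) -> torch.Tensor:
        # prior < 1 for features heavily used in earlier Steps,
        # discouraging the model from selecting them again
        return sparsemax(self.bn(self.fc(a)) * prior)

att = AttentiveTransformer(attn_dim=8, n_features=5)
mask = att(torch.randn(32, 8), prior=torch.ones(32, 5))
print(mask.sum(dim=-1))  # each row sums to 1; many entries are exactly 0
```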

We can also gauge feature importance by looking at how much a feature has been masked across all decisions, as well as for an individual prediction.

TabNet employs soft feature selection with controllable sparsity in end-to-end learning

This means one model jointly performs feature selection and output mapping, which leads to better performance.

TabNet uses instance-wise feature selection, which means features are selected for each input and each prediction can use different features.

This feature selection is essential: it allows decision boundaries to be generalised to a linear combination of features, where coefficients determine the contribution of each feature, and this is ultimately what gives the model its interpretability.

Implementation in PyTorch

The best way to use TabNet is with Dreamquark’s PyTorch implementation. It uses a scikit-learn style wrapper and is GPU compatible. The repo has plenty of examples of the model in use so I’d highly recommend checking it out.

Training the model is actually really simple and can be done in a few lines of code; TabNet also doesn’t have too many hyperparameters.


Dreamquark also provides some really great notebooks that show how to implement TabNet while also working to validate the original authors’ claims about the model’s accuracy on certain benchmarks.



Both these examples are reproducible and include an XGBoost model to compare to TabNet’s performance.


A key benefit of TabNet over Boosted Trees is that it is more explainable. We cannot dissect predictions in gradient boosting without using something like SHAP or LIME. Because of the masks, we can get an idea of the features our TabNet model used both globally (across the whole dataset) and locally (for a single prediction).

To explore this, I’m going to use the classification example above, which uses a census income dataset.

Feature Importances

We can view the importances of our individual features, which conveniently add up to 1. When we get this data out of a tree-based model, it can be skewed towards one variable, or towards categorical variables with a large number of unique values. In some cases, this can misrepresent what the model is actually doing.

In this example, we see a much greater spread of importance when TabNet is used, meaning it is using features more equally. This may not necessarily be better and there could be flaws in the TabNet process. However, the original paper’s authors did compare feature importances to synthetic data examples and found that TabNet was using the features they expected.

Feature Importances from TabNet and XGBoost models trained on a census dataset. Image by Author.

Note: the features with numbers as feature names (e.g. 2174) appear to be anonymised features.


By using the masks, we can understand which features are being used at a prediction level; we can look at the aggregate of all the Masks or at an individual Mask.

So for row 0, the first row of our test data, it seems Mask 1 prioritises the 4th feature in the dataset, while the other Masks use different features.

This can give us some understanding of which features the model used to make its prediction. It gives us more confidence, as we can work out the ‘whys’ behind a model’s predictions, and it may help us understand how the model handles unseen data.

However, it isn’t clear how this relates to the actual feature value — we don’t know if the model uses the feature because it is high or low. More importantly, we can’t readily understand interaction terms.

Heatmaps of the Masks in the Census TabNet model. A lighter color indicates the feature is being used. Image by Author.

Improving Results with Self-Supervised Learning

The TabNet paper also proposes self-supervised learning as a way to pretrain the model weights and reduce the amount of training data.

To do this, features within the dataset are masked and the model tries to predict them; a decoder is used to reconstruct the masked values.

This can also be done in Dreamquark’s package.

Using self-supervised learning should yield better results with less training data.


TabNet is a deep learning model for tabular learning. It uses sequential attention to choose a subset of meaningful features to process at each decision Step. Instance-wise feature selection allows the model’s learning capacity to be focused on the most important features, and visualisation of the model’s masks provides explainability.

Hopefully you can see that TabNet lets us achieve state-of-the-art results while maintaining interpretability. With AI regulation becoming tougher, understanding how our models work is only going to become more important. I would highly recommend giving TabNet a try in your next project or Kaggle competition!

