TabNet — Deep Neural Network for Structured, Tabular Data


Framingham Heart Study

Today, I’m going to go through an example of how to use TabNet for a classification task. The dataset contains results from the Framingham Heart Study, which began in 1948 and has provided (and is still providing) significant insights into risk factors for cardiovascular disease. For those interested in learning more about this study, please check out this link.

If you’re interested in learning more about TabNet’s architecture, I encourage you to look over the original paper I linked above. Additional resources include this repo where you can see the original TabNet code.

Finally, before we dive in, you can follow along using my notebook found at this repo.

The Data

The data used for this analysis consisted of 16 variables, including the target variable ANYCHD. The descriptions of each can be found below.

Table 1 — Description of the variables found in the data set.

Here’s what our dataframe looks like.

Table 2 — A view of the data in tabular format

Investigating missing values

Next, I wanted to see how much of the data was missing. This is easy to do using df.isnull().sum(), which tells us how many values are missing per variable. Another way is to use the missingno package, which lets us visualize the relationships between missing values very quickly.
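As a minimal sketch of the per-variable count, here is df.isnull().sum() on a small made-up dataframe. The values are invented for illustration and are not the Framingham data.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the real dataframe; values are invented for illustration.
df = pd.DataFrame({
    "TOTCHOL": [195.0, np.nan, 250.0, 260.0],
    "HDLC": [np.nan, np.nan, 50.0, 45.0],
    "LDLC": [np.nan, np.nan, 170.0, 180.0],
    "AGE": [39, 46, 48, 61],
})

# Number of missing values per column
missing_per_column = df.isnull().sum()
print(missing_per_column)

# Total number of missing cells in the whole dataframe
total_missing = int(missing_per_column.sum())
print(total_missing)
```

If missingno is installed, missingno.matrix(df) and missingno.heatmap(df) draw the two kinds of plot discussed in this section.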

Figure 1 shows a matrix representation of the missing values (white) by variable. Each row of the matrix corresponds to a row of the data, which lets us see whether missing values in different variables occur together. For example, the missing values for HDLC and LDLC are identical, suggesting that these values weren’t collected for a portion of the patients in this dataset.

Figure 1 — A matrix of the missing data by variable. Image by author

We can also use a heatmap to look at the relationship between missing values in a different way, as in figure 2. Here we have an easier time seeing the relationship of both HDLC and LDLC with TOTCHOL. A value shown as <1 means that the correlation is slightly less than 1. Since all 3 of these variables are measures of cholesterol, it suggests that cholesterol data was not collected for certain patients in the dataset.
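The numbers in a heatmap like this are correlations between the columns’ missingness indicators. Here is a rough pandas equivalent on invented toy data. It approximates what the plot shows and is not missingno’s own code.

```python
import numpy as np
import pandas as pd

# Toy data, invented for illustration: HDLC and LDLC are always missing
# together, and TOTCHOL is missing in some (not all) of the same rows.
df = pd.DataFrame({
    "TOTCHOL": [np.nan, np.nan, 200.0, 240.0, 210.0, 190.0],
    "HDLC": [np.nan, np.nan, np.nan, 50.0, 45.0, 60.0],
    "LDLC": [np.nan, np.nan, np.nan, 170.0, 180.0, 150.0],
})

# 1 where a value is missing, 0 where it is present
nullity = df.isnull().astype(int)

# Pairwise correlation of the missingness indicators
nullity_corr = nullity.corr()
print(nullity_corr.round(2))
```

In this toy example HDLC and LDLC correlate perfectly, while TOTCHOL correlates positively but less than perfectly with both.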

Figure 2 — A heatmap demonstrating the relationship between missing values. Image by author

Imputing missing values

Now that we have gathered information about our missing values, we need to decide what to do about them. There are many options depending on your data, and you can read more about the various imputation algorithms available on sklearn’s webpage.

I opted for the KNN imputer, which you can implement using the following code. To summarize, in the first block I simply divided the data into features and target.

The second block transforms the features using the KNN imputer. As you can see from the print statements, there were originally 1812 missing values which were imputed.
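For readers without the notebook open, here is a minimal sketch of those two blocks using scikit-learn’s KNNImputer. The dataframe is invented for illustration, and n_neighbors=2 is an arbitrary choice for the toy data, not necessarily the setting used in the notebook.

```python
import numpy as np
import pandas as pd
from sklearn.impute import KNNImputer

# Toy data, invented for illustration; ANYCHD is the target column.
df = pd.DataFrame({
    "AGE": [39, 46, 48, 61, 52],
    "SYSBP": [106.0, 121.0, np.nan, 150.0, 130.0],
    "TOTCHOL": [195.0, 250.0, 245.0, np.nan, 225.0],
    "ANYCHD": [0, 0, 1, 1, 0],
})

# Block 1: separate features and target
X = df.drop(columns=["ANYCHD"])
y = df["ANYCHD"]

# Block 2: fill each missing value using the nearest rows
print("Missing before:", int(X.isnull().sum().sum()))
imputer = KNNImputer(n_neighbors=2)  # assumed value, chosen for the toy data
X_imputed = pd.DataFrame(imputer.fit_transform(X), columns=X.columns)
print("Missing after:", int(X_imputed.isnull().sum().sum()))
```

The sketch follows the order described here and imputes before splitting. Fitting the imputer on the training split only, then applying it to the other splits, avoids leaking information from the validation and test rows.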

Figure 3 — Using the KNN Imputer to deal with missing values in the dataset.

The final step is to split our data. Using the code below, I initially split the data into 70% for the training set and 30% for a held-out set. Then I split the held-out set into two equal parts for the validation and test sets. The print statements provide us with information about the shape of the splits.
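A minimal sketch of this two-step split with scikit-learn’s train_test_split is shown below. The arrays are random toy data, and the random_state values are arbitrary assumptions used only to make the example repeatable.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy arrays, invented for illustration: 100 rows, 15 features
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 15))
y = rng.integers(0, 2, size=100)

# Step 1: 70% train, 30% held out
X_train, X_hold, y_train, y_hold = train_test_split(
    X, y, test_size=0.30, random_state=42
)

# Step 2: split the held-out 30% into equal validation and test sets
X_valid, X_test, y_valid, y_test = train_test_split(
    X_hold, y_hold, test_size=0.50, random_state=42
)

print(X_train.shape, X_valid.shape, X_test.shape)
```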

Figure 4 — Splitting the data into train, validation and test sets


You can be ready to run TabNet in a few simple lines, as shown below. This is a PyTorch implementation of TabNet, so you’ll have to import (or install if you haven’t yet) torch, pytorch_tabnet, and the model you wish to use (binary classifier, multi-classifier, regressor).

You will also need some kind of metric to evaluate your model. Here’s a list of those available from sklearn. I also included a label encoder in case your data is slightly different from mine. My categorical variables are all binary integers, but if you had categories stored as strings you would use this (or an alternative such as one-hot encoding) first.
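To illustrate the label encoder, here is a short sketch on a hypothetical string-valued column. The column name and values are invented, since the Framingham variables used here are already integers.

```python
from sklearn.preprocessing import LabelEncoder

# Hypothetical string-valued column, invented for illustration
smoker = ["no", "yes", "no", "yes", "yes"]

encoder = LabelEncoder()
smoker_encoded = encoder.fit_transform(smoker)

print(list(encoder.classes_))  # classes are stored in sorted order
print(list(smoker_encoded))    # integer code for each original value
```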

Figure 5 — Importing the necessary libraries

Next, we have to define our model, which you can see in the first block of code below. On the first line we define our optimizer, Adam. The next few lines are scheduling a stepwise decay in our learning rate. Let’s unpack what it says:

  • The learning rate is initially set to lr = 0.02
  • After 10 epochs, we will apply a decay rate of 0.9
  • The result is simply the product of our learning rate and decay rate 0.02*0.9, meaning at epoch 10 it will reduce to 0.018
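The arithmetic above can be sketched in plain Python. This assumes the decay is applied again every 10 epochs, which is how a stepwise schedule such as PyTorch’s StepLR behaves. The function name is my own.

```python
# Values taken from the text: initial rate 0.02, decay of 0.9 every 10 epochs
INITIAL_LR = 0.02
DECAY_RATE = 0.9
STEP_SIZE = 10


def learning_rate_at(epoch: int) -> float:
    """Learning rate in effect during the given (0-indexed) epoch."""
    completed_steps = epoch // STEP_SIZE
    return INITIAL_LR * DECAY_RATE ** completed_steps


for epoch in (0, 9, 10, 20, 30):
    print(epoch, round(learning_rate_at(epoch), 5))
```

So the rate stays at 0.02 for epochs 0 to 9, drops to 0.018 at epoch 10, and to 0.0162 at epoch 20.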

In the next block of code, we fit the model to our data. The train and validation sets are evaluated with AUC (area under the ROC curve) and accuracy as metrics, for a maximum of 1,000 epochs.
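To make the two metrics concrete, here is a small sketch using scikit-learn on invented labels and predicted probabilities. The 0.5 threshold used to turn probabilities into class labels is an assumption for this example.

```python
from sklearn.metrics import accuracy_score, roc_auc_score

# Toy labels and predicted probabilities, invented for illustration
y_true = [0, 0, 1, 1]
y_prob = [0.1, 0.4, 0.35, 0.8]

# AUC is computed from the predicted probabilities
auc = roc_auc_score(y_true, y_prob)

# Accuracy needs hard labels; 0.5 is an assumed threshold
y_pred = [int(p >= 0.5) for p in y_prob]
accuracy = accuracy_score(y_true, y_pred)

print(auc, accuracy)
```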

The patience parameter states that if no improvement in the metric is observed for 50 consecutive epochs, training stops and the weights from the best epoch are loaded.
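The early-stopping rule can be sketched in a few lines of plain Python. This illustrates the idea and is not the library’s own implementation. The function name and the validation scores are invented, and a patience of 3 is used to keep the example short.

```python
def best_epoch_with_patience(scores, patience):
    """Return (best_epoch, stopped_epoch) for a list of validation scores.

    Higher scores are better. Training stops once `patience` consecutive
    epochs pass without beating the best score seen so far.
    """
    best_score = float("-inf")
    best_epoch = -1
    epochs_without_improvement = 0
    for epoch, score in enumerate(scores):
        if score > best_score:
            best_score, best_epoch = score, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch
    return best_epoch, len(scores) - 1


# Toy validation scores with a patience of 3 (the post uses 50)
val_scores = [0.60, 0.65, 0.64, 0.66, 0.65, 0.66, 0.64, 0.70]
result = best_epoch_with_patience(val_scores, patience=3)
print(result)
```

Here training stops at epoch 6 and the weights from epoch 3 would be restored, so the later score of 0.70 is never reached.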

The batch size of 256 was selected based on recommendations from TabNet’s paper, where they suggest a batch size of up to 10% of the total data. They also recommend that the virtual batch size be smaller than the batch size and divide it evenly.

The number of workers was left at zero, which means that batches are loaded in the main process as needed rather than by separate worker processes. From what I read, increasing this number is a very memory-hungry process.

Weights can be 0 (no sampling) or 1 (automated sampling). Lastly, drop_last refers to dropping the last batch during training if it is incomplete.

It’s important to note that many of these are default parameters. You can check out the full list of parameters here.
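Putting the settings from this section together, a hedged sketch of the model definition and fit call with pytorch_tabnet might look like the following. The values virtual_batch_size=128, weights=1 and drop_last=False are my assumptions, since the text does not state them. The function name fit_tabnet is hypothetical, and the inputs are expected to be NumPy arrays.

```python
# Settings described in the text. virtual_batch_size, weights and drop_last
# are assumed values: the post does not state which ones were used.
MODEL_PARAMS = {
    "optimizer_params": {"lr": 0.02},
    "scheduler_params": {"step_size": 10, "gamma": 0.9},
}
FIT_PARAMS = {
    "eval_name": ["train", "valid"],
    "eval_metric": ["auc", "accuracy"],
    "max_epochs": 1000,
    "patience": 50,
    "batch_size": 256,
    "virtual_batch_size": 128,  # assumption
    "num_workers": 0,
    "weights": 1,               # assumption
    "drop_last": False,         # assumption
}


def fit_tabnet(X_train, y_train, X_valid, y_valid):
    """Build and fit a TabNet classifier on NumPy arrays (hypothetical helper)."""
    # Imported here so the settings above can be inspected without torch installed
    import torch
    from pytorch_tabnet.tab_model import TabNetClassifier

    clf = TabNetClassifier(
        optimizer_fn=torch.optim.Adam,
        scheduler_fn=torch.optim.lr_scheduler.StepLR,
        **MODEL_PARAMS,
    )
    clf.fit(
        X_train=X_train,
        y_train=y_train,
        eval_set=[(X_train, y_train), (X_valid, y_valid)],
        **FIT_PARAMS,
    )
    return clf


# The virtual batch size should divide the batch size evenly
assert FIT_PARAMS["batch_size"] % FIT_PARAMS["virtual_batch_size"] == 0
```

Calling fit_tabnet(X_train, y_train, X_valid, y_valid) returns the fitted classifier. Early stopping tracks the last metric listed for the last evaluation set, which here is accuracy on the validation set.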

Figure 6 — Defining and fitting our TabNet classifier

The results of this analysis can be seen in figure 8, and they can be reproduced using the code below. The first three blocks of code plot the loss score, accuracy (for the train and validation sets), and feature importance (for the test set).

The final blocks simply compute the best accuracy achieved for the validation and test sets, which were 68% and 63%, respectively.

Figure 7 — Plotting the loss score, accuracy, and feature importance
Figure 8 — (left) loss score; (middle) accuracy for training (blue) and validation (orange) sets; (right) relative feature importance

Unsupervised pretraining

TabNet can also be pretrained as an unsupervised model. Pretraining involves deliberately masking certain cells and learning relationships between these cells and adjacent columns by predicting the masked values. The weights learned can then be saved and used for a supervised task.

Let’s see how using unsupervised pretraining can influence our model’s accuracy!

Although similar, the code has some differences, so I have included it below. Specifically, you have to import the TabNetPretrainer. You can see in the first block of code, the TabNetClassifier is replaced by the TabNetPretrainer.

When you fit the model, note the last line, pretraining_ratio, which is the proportion of feature values that are masked during pretraining. A value of 0.8 indicates that 80% are masked.
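To give a feel for what an 80% masking ratio does to the input, here is a NumPy sketch on random toy data. It assumes each cell is hidden independently with probability pretraining_ratio and that hidden cells are set to zero. It illustrates the idea and is not the library’s code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy feature matrix, invented for illustration: 1,000 rows, 15 features
X = rng.normal(size=(1000, 15))
pretraining_ratio = 0.8

# Assumption: each cell is hidden independently with this probability
mask = rng.random(X.shape) < pretraining_ratio

# Masked cells are zeroed out; the model is trained to predict their original values
X_masked = np.where(mask, 0.0, X)

print(round(float(mask.mean()), 3))  # fraction of cells hidden, close to 0.8
```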

The next block of code refers to the reconstructed features generated from TabNet’s encoded representations. These are saved and can then be used in a separate supervised task.
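A hedged sketch of the whole pretraining workflow with pytorch_tabnet is shown below. The function name pretrain_then_classify is hypothetical. The epoch, patience and batch-size values are carried over from the supervised run as an assumption, and virtual_batch_size=128 is also assumed. Inputs are expected to be NumPy arrays.

```python
# pretraining_ratio comes from the text; the other values are assumptions
# carried over from the supervised run.
PRETRAIN_PARAMS = {
    "max_epochs": 1000,
    "patience": 50,
    "batch_size": 256,
    "virtual_batch_size": 128,
    "pretraining_ratio": 0.8,
}


def pretrain_then_classify(X_train, y_train, X_valid, y_valid):
    """Pretrain without labels, then fit a classifier from those weights."""
    # Imported here so the settings above can be inspected without torch installed
    import torch
    from pytorch_tabnet.pretraining import TabNetPretrainer
    from pytorch_tabnet.tab_model import TabNetClassifier

    pretrainer = TabNetPretrainer(
        optimizer_fn=torch.optim.Adam,
        optimizer_params={"lr": 0.02},
    )
    # No labels are passed: pretraining only sees the features
    pretrainer.fit(X_train=X_train, eval_set=[X_valid], **PRETRAIN_PARAMS)

    # Reconstructed features and encoded representations for the validation set
    reconstructed_X, embedded_X = pretrainer.predict(X_valid)

    clf = TabNetClassifier(
        optimizer_fn=torch.optim.Adam,
        optimizer_params={"lr": 0.02},
    )
    clf.fit(
        X_train=X_train,
        y_train=y_train,
        eval_set=[(X_valid, y_valid)],
        eval_metric=["auc", "accuracy"],
        from_unsupervised=pretrainer,  # start from the pretrained weights
    )
    return clf, reconstructed_X, embedded_X
```

The from_unsupervised argument is what hands the pretrained weights to the supervised model.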

Figure 9 — Unsupervised representation learning with TabNet

When pretraining was used for this dataset, the results were 76% and 71% accuracy for the validation and test sets, respectively. This is a significant improvement! Below, you can see the loss scores, accuracy for training (blue) and validation (orange) sets, and the feature importance determined for the test set.

Figure 10 — (left) loss score; (middle) accuracy for training (blue) and validation (orange) sets; (right) relative feature importance. Image by author.


In this post, we walked through an example of how to implement TabNet for a classification task. We found that unsupervised pretraining with TabNet significantly improved the model’s accuracy.

In figure 11 below, I plotted the feature importance for the supervised (left) and unsupervised (right) models. It’s interesting that unsupervised pretraining was able to improve the model’s accuracy while reducing the number of features.

This makes sense when we think about the relationship of the features to one another. Pretraining with TabNet learns, for example, that blood pressure meds (BPMEDS), systolic blood pressure (SYSBP) and diastolic blood pressure (DIABP) are related. Thus, unsupervised representation learning acts as a superior encoder model for a supervised learning task, with cleaner and more interpretable results.

Figure 11 — Feature importance values for the supervised (left) and unsupervised (right) TabNet models. Image by author.

I hope you enjoyed this post! Give it a try and let me know how it worked for you.

