Classification with TabNet: Deep Dive



Photo by Mika Baumeister on Unsplash

Tabular data is the bread and butter for training fraud detection algorithms at Ravelin. We extract transaction, identity, product, and network attributes (read this blog if you’re interested in our network features) and place them into a big table of features that can easily be used by different machine learning models for training and inference. Decision tree-based models (e.g. Random Forest or XGBoost) are the go-to algorithms for tabular data because of their performance, interpretability, speed of training, and robustness.

On the other hand, neural networks are considered state of the art in many fields and perform particularly well on large datasets with minimal feature engineering. Many of our clients have large transaction volumes, and deep learning is a potential path towards improving fraud detection performance.

In this blog, we’re going to take a deep dive into TabNet (Arik & Pfister, 2019), a neural network architecture designed to be interpretable and to work well with tabular data. After explaining the key building blocks and ideas behind it, you’ll see how to implement it in TensorFlow and how it can be applied to a fraud detection dataset. Most of the code was taken from this implementation, so make sure to leave a star there if you use it in your work!

TabNet

TabNet mimics the behaviour of decision trees using the idea of Sequential Attention. Put simply, you can think of it as a multi-step neural network that applies two key operations at each step:

  1. An Attentive Transformer selects the most important features to process at the next step
  2. A Feature Transformer processes the features into a more useful representation

The output of the Feature Transformer is later used in the prediction. Using both the Attentive and Feature Transformers, TabNet is able to simulate the decision-making process of tree-based models. Consider the high-level overview of the model’s prediction on the Adult Census Income dataset below. The model selects and processes the features that are most useful for the task at hand, which improves both interpretability and learning.

TabNet prediction intuition on Income dataset. Source: https://arxiv.org/pdf/1908.07442.pdf
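To make this step loop more concrete, here is a heavily simplified, hypothetical sketch of the idea in TensorFlow. The Attentive and Feature Transformers (described in the following sections) are replaced with plain Dense layers, and a softmax mask stands in for the paper’s sparsemax, so treat this purely as an illustration of the control flow rather than a faithful implementation.

```python
import tensorflow as tf

# Toy sketch of TabNet's sequential-attention loop (illustrative only):
# each step selects features with a mask and processes the masked input.
class TabNetSketch(tf.keras.Model):
    def __init__(self, num_features, n_steps=3, units=8):
        super().__init__()
        # Initial representation fed into the first step's feature selection
        self.initial_transform = tf.keras.layers.Dense(units, activation="relu")
        # Stand-ins for the real Feature Transformers (one per step)
        self.feature_transformers = [
            tf.keras.layers.Dense(units, activation="relu") for _ in range(n_steps)
        ]
        # Stand-ins for the real Attentive Transformers: produce a soft feature mask
        # (softmax here; the paper uses sparsemax with prior scales)
        self.attentive_transformers = [
            tf.keras.layers.Dense(num_features, activation="softmax") for _ in range(n_steps)
        ]
        self.head = tf.keras.layers.Dense(1, activation="sigmoid")

    def call(self, x, training=None):
        decision_out = 0.0
        processed = self.initial_transform(x)
        for attentive, feature in zip(self.attentive_transformers, self.feature_transformers):
            mask = attentive(processed)   # 1. select the most important features
            masked = x * mask             # apply the (soft) feature selection
            processed = feature(masked)   # 2. process them into a useful representation
            decision_out += processed     # aggregate step outputs for the prediction
        return self.head(decision_out)
```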

The key building block of both the Attentive and Feature Transformers is the so-called Feature Block, so let’s explore that first.

Feature Blocks

A Feature Block consists of a Fully-Connected (FC, or Dense) layer followed by Batch Normalisation (BN). In Feature Transformers, the output is additionally passed through a GLU activation layer.

TabNet block with GLU

The main function of the GLU (as opposed to a sigmoid gate) is to allow hidden units to propagate deeper into the model and to help prevent exploding or vanishing gradients.
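As a rough illustration, a feature block of this kind can be sketched in TensorFlow as follows. This is a minimal version assuming the common trick of having the FC layer output twice the number of units so that one half can gate the other through a sigmoid; it uses a standard BatchNormalization layer rather than Ghost BN.

```python
import tensorflow as tf

# Minimal sketch of a feature block: FC -> BN -> (optional) GLU.
class FeatureBlock(tf.keras.layers.Layer):
    def __init__(self, units, apply_glu=True, momentum=0.9):
        super().__init__()
        self.units = units
        self.apply_glu = apply_glu
        # When GLU is applied, the FC layer outputs 2 * units:
        # one half is the signal, the other half becomes the sigmoid gate.
        self.fc = tf.keras.layers.Dense(2 * units if apply_glu else units, use_bias=False)
        self.bn = tf.keras.layers.BatchNormalization(momentum=momentum)

    def call(self, x, training=None):
        x = self.bn(self.fc(x), training=training)
        if self.apply_glu:
            # GLU: signal half multiplied by the sigmoid of the gate half
            return x[:, : self.units] * tf.nn.sigmoid(x[:, self.units :])
        return x
```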

In addition, the original paper uses Ghost Batch Normalisation to improve convergence speed during training. You can find a TensorFlow implementation here if you are interested, but we’ll use a standard Batch Normalisation layer in this tutorial.
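If you do want Ghost Batch Normalisation, note that the Keras BatchNormalization layer supports it out of the box through the virtual_batch_size argument, which normalises each “ghost” sub-batch separately. The value of 128 below is just an example; the training batch size must be divisible by it.

```python
import tensorflow as tf

# Ghost Batch Normalisation via Keras: each virtual (ghost) batch of 128 rows
# is normalised separately. 128 is an arbitrary example value.
ghost_bn = tf.keras.layers.BatchNormalization(virtual_batch_size=128, momentum=0.9)
```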

Feature Transformers

A FeatureTransformer (FT) is basically a collection of feature blocks applied sequentially. In the paper, one FeatureTransformer consists of two shared blocks (i.e. blocks whose weights are reused across decision steps) and two step-dependent blocks, connected through residual connections scaled by √0.5. Shared weights reduce the number of parameters in the model and lead to better generalisation.
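Building on the FeatureBlock sketched above, a Feature Transformer could look roughly like this. The shared blocks are created once and passed into every step’s transformer, while the step-dependent blocks are created locally; the residual connections are scaled by √0.5 to keep the variance stable.

```python
import math
import tensorflow as tf

# Sketch of a Feature Transformer: shared FeatureBlocks (reused across decision
# steps) followed by step-dependent FeatureBlocks, with scaled residuals.
class FeatureTransformer(tf.keras.layers.Layer):
    def __init__(self, units, shared_blocks, n_dependent=2):
        super().__init__()
        self.shared = shared_blocks                      # same objects for every step
        self.dependent = [FeatureBlock(units) for _ in range(n_dependent)]

    def call(self, x, training=None):
        blocks = list(self.shared) + self.dependent
        x = blocks[0](x, training=training)
        for block in blocks[1:]:
            # residual connection scaled by sqrt(0.5)
            x = (block(x, training=training) + x) * math.sqrt(0.5)
        return x

# Example usage: build the shared blocks once and reuse them in every step's
# transformer (units and n_steps are illustrative values).
units, n_steps = 16, 3
shared = [FeatureBlock(units) for _ in range(2)]
step_transformers = [FeatureTransformer(units, shared) for _ in range(n_steps)]
```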
