Training T5 model in just 3 lines of code with ONNX Inference




Background

simpleT5 is a Python package built on top of PyTorch Lightning and Hugging Face Transformers that lets you train a T5 model quickly (in just 3 lines of code). I have already written and spoken about T5 quite a bit in the past, so please feel free to check those out (Blog, Video).

In this blog, we will build a Yes/No question answering system using this library and then compare its inference performance with and without converting the model to ONNX format. So let’s go …

Installation

You can install the library with pip, as shown below. (I recommend using the --upgrade flag to get the most recent version.)

> pip install --upgrade simplet5

In the installation logs, you can see that the library takes care of the other necessary dependencies, such as PyTorch Lightning and Transformers, with their relevant versions. (So keep an eye on the logs; I am pretty sure you don’t want to break other projects because of library upgrades.)

In Action

As mentioned above, for this blog we will train a boolean (yes/no) question answering model using the BoolQ dataset. BoolQ is a question answering dataset for yes/no questions containing 15,942 examples. Each example is a triplet of (question, passage, answer), with the title of the page as optional additional context.

Below is a sample snippet from the training data —

Since in JSONL every line is an independent JSON object, we load the files as shown below and print the head to verify the parsing:

Loading train and validation set of BoolQ using Pandas
Head view of BoolQ train dataframe
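The loading step shown in the screenshot can be sketched as follows. Here a hypothetical two-line inline sample stands in for the actual BoolQ `train.jsonl` file (the field names match the triplet format described above; the passage texts are abbreviated placeholders):

```python
import io
import pandas as pd

# Hypothetical two-line stand-in for BoolQ's train.jsonl
# (in the real setup you would pass the file path instead of a StringIO buffer)
sample = io.StringIO(
    '{"question": "does overfitting happen in deep networks", "passage": "Overfitting occurs in complex models.", "answer": true, "title": "Overfitting"}\n'
    '{"question": "do shallow models always overfit", "passage": "Shallow models are simpler.", "answer": false, "title": "Overfitting"}\n'
)

# lines=True tells pandas that each line is an independent JSON object
train_df = pd.read_json(sample, lines=True)
print(train_df.head())
```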

The idea is to concatenate the passage and the question with a separator token and train the model against the answer column (True/False). While going through the code, I found that the input column (passage + separator + question for us) should be named “source_text” and the output column (“answer” for us) should be renamed “target_text”. Since the T5 model requires a “task prefix” to learn, differentiate, and perform better per task, I chose “boolqpassage:” as the prefix and “question:” as the separator token. The snippet below shows this:

Data Preparation
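The preparation step can be sketched like this, with a one-row hypothetical data frame standing in for the loaded BoolQ data (the prefix and separator are the ones chosen above; the column names are what simpleT5 expects):

```python
import pandas as pd

# One-row stand-in for the loaded BoolQ data frame
df = pd.DataFrame({
    "passage": ["Overfitting occurs in complex models like deep neural networks."],
    "question": ["does overfitting happen in deep networks"],
    "answer": [True],
})

# simpleT5 expects the input column to be named "source_text"
# and the output column to be named "target_text"
df["source_text"] = "boolqpassage: " + df["passage"] + " question: " + df["question"]
df["target_text"] = df["answer"].astype(str)   # "True" / "False"
df = df[["source_text", "target_text"]]
```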

Finally, we print the shape of both data frames to get an idea of their size and columns:

training/validation data shapes

We are now ready to train our model (the 3-line claim starts here ;) excluding the import, obviously!)

Model Training

We load the model instance (it currently supports T5 and MT5 as possible options) and feed in our training and validation data frames along with the other necessary training parameters. Once training is done, you will see a model dump per epoch in the “outputs” directory. You can also choose to write it somewhere else by specifying the “outputdir” parameter of the train method. So, yeah, we have our model ready; let’s test it out.
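The training step described above can be sketched as a small helper. The `train_df`/`eval_df` arguments are the prepared frames from the data-preparation step, and the hyperparameter values here are illustrative assumptions, not the ones used in this run:

```python
def train_boolq_model(train_df, eval_df, outputdir="outputs"):
    """Fine-tune T5 on the prepared BoolQ frames (a sketch of the simpleT5 API)."""
    from simplet5 import SimpleT5  # deferred import so the sketch stays self-contained

    model = SimpleT5()
    model.from_pretrained(model_type="t5", model_name="t5-base")
    model.train(
        train_df=train_df,           # columns: source_text, target_text
        eval_df=eval_df,
        source_max_token_len=512,    # passages are long; questions short
        target_max_token_len=16,     # target is just "True" / "False"
        max_epochs=3,
        outputdir=outputdir,
    )
    return model
```

The core of this really is the promised 3 lines: instantiate, `from_pretrained`, `train`.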

I framed a Yes/No question from one of my blogs and made a prediction using the “predict” method, as shown below:

text_to_answer = "boolqpassage: Overfitting is a phenomenon that occurs when a machine learning or statistics model is tailored to a particular dataset and is unable to generalise to other datasets. This usually happens in complex models, like deep neural networks. question: does overfitting happen in deep networks"

%%time
model.predict(text_to_answer)
> True (Yes)

That’s definitely the correct answer 👏 as the last line of the passage says exactly that! Let’s try to trick the model by asking the opposite:

text_to_answer = "boolqpassage: Overfitting is a phenomenon that occurs when a machine learning or statistics model is tailored to a particular dataset and is unable to generalise to other datasets. This usually happens in complex models, like deep neural networks. question: does overfitting happen in shallow networks and easy networks"

%%time
model.predict(text_to_answer)
> False (No)

Pretty good, right?

Talking about inference time, running this one example on the CPU took

CPU times: user 806 ms, sys: 13.9 ms, total: 820 ms

and the time it takes after converting our model to ONNX format is

CPU times: user 314 ms, sys: 6.9 ms, total: 321 ms

As per the documentation, you can easily convert your trained model to ONNX format using the command below:

model.convert_and_load_onnx_model(model_dir="outputs/SimpleT5-epoch-0-train-loss-0.2274")
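Once converted, inference goes through the ONNX-backed path. A small helper for the boolean QA use case might look like this (the helper name and its input formatting are my own; `onnx_predict` is the inference method listed in simpleT5’s README):

```python
def answer_with_onnx(model, passage, question):
    """Format a BoolQ-style input and run it through the ONNX-backed model."""
    # Same prefix and separator convention used during training
    text = f"boolqpassage: {passage} question: {question}"
    # onnx_predict runs generation via onnxruntime instead of plain PyTorch
    return model.onnx_predict(text)
```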

If you want to learn more about ONNX then please read through this, this, this and this. Also, you can find the entire code for this use case here. 👍

So, yeah, it’s pretty cool, and that’s it for this blog. 😃

If you like reading research papers, then you might want to check out some of the research paper summaries that I have written:

Graph-based Text Similarity Method

Grammar Correction System for Mobile Devices

BERT for Extractive Text Summarization

Also, do check out the official GitHub repository to get first-hand updates:

Lastly, in case you enjoyed reading this article, you can buy me a “chai” 🥤 at https://www.buymeacoffee.com/TechvizCoffee — because I don’t actually drink coffee 🙂 Thank you very much! It’s totally optional and voluntary 🙂

I hope the read was worth your time. See you next time.👋 Thank You!
