Few Shot Learning Using SBERT



“If the dataset is really small, both of these techniques might not help us. Imagine a task where we need to build a classification with only one or two samples per class, and each sample is super difficult to find.”

In this article, we implement document classification using only a handful of labeled documents.

Document classification falls into three broad categories, depending on how much labeled data is available:

1. Simple classification: an abundance of data, where we have a large amount of data for training and testing our model.

2. Few-shot classification: very little data for each category, e.g. 10–40 data points per class.

3. One-shot (single-shot) classification: only one data point for each category.

In this article, we cover the second scenario, where only a few labeled data points are available for each category.

Few-Shot Learning:

As in our previous article on BERT, we use sentence embeddings here to get the predictions.

There is one difference between the previous implementation and this one: how the embeddings are created. With BERT we created token embeddings. With SBERT we build a document embedding from sentence embeddings.


Sentence-Transformers is a Python framework for state-of-the-art sentence, text, and image embeddings. These embeddings can then be used for classification or clustering. For example, the cosine similarity between embeddings lets us find sentences with similar meanings. This is useful for semantic textual similarity, semantic search, and paraphrase mining.
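To make the cosine-similarity idea concrete, here is a minimal NumPy sketch of semantic search. Given a query embedding and a matrix of corpus embeddings, it ranks the corpus by cosine similarity. The toy 3-dimensional vectors stand in for real SBERT embeddings. `cosine_similarity` and `top_k_similar` are hypothetical helper names, not part of Sentence-Transformers.

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Cosine similarity between every row of `a` and every row of `b`."""
    a = np.atleast_2d(a)
    b = np.atleast_2d(b)
    # Normalise rows to unit length so the dot product equals the cosine.
    a_unit = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_unit = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a_unit @ b_unit.T

def top_k_similar(query: np.ndarray, corpus: np.ndarray, k: int = 2) -> list[int]:
    """Indices of the k corpus rows most similar to `query`, best first."""
    scores = cosine_similarity(query, corpus)[0]
    return np.argsort(-scores)[:k].tolist()

# Toy stand-ins for sentence embeddings.
corpus = np.array([
    [1.0, 0.0, 0.0],
    [0.9, 0.1, 0.0],
    [0.0, 1.0, 0.0],
])
query = np.array([1.0, 0.05, 0.0])
print(top_k_similar(query, corpus, k=2))  # [0, 1]
```

Cosine similarity ignores vector length and compares only direction. That is why it is the usual choice for comparing sentence embeddings.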


pip install -U sentence-transformers syntok

Implementation of Few-Shot Learning

Suppose we have three classes, A, B, and C, with only 10 labeled data points for each. We need to build a robust classifier from this minimal data.

We create embeddings for our labeled dataset. At inference time, we measure how close the new document's embedding is to the saved embedding of each category and assign the document to the closest one.

To create a document embedding with SBERT, we first break the document into sentences, because SBERT captures context at the sentence level and compares text accordingly.

To split the document into sentences we use the syntok library, which combines regular expressions and rules to break a document into a list of sentences.

Flow to create document embedding
import syntok.segmenter as segmenter

with open('README.txt') as f:
    document = f.read()

# choose the segmentation function you need/prefer

for paragraph in segmenter.process(document):
    for sentence in paragraph:
        for token in sentence:
            # roughly reproduce the input,
            # except for hyphenated word-breaks
            # and replacing "n't" contractions with "not",
            # separating tokens by single spaces
            print(token.value, end=' ')
        print()  # print one sentence per line
    print()  # separate paragraphs with newlines

for paragraph in segmenter.analyze(document):
    for sentence in paragraph:
        for token in sentence:
            # exactly reproduce the input
            # and do not remove "imperfections"
            print(token.spacing, token.value, sep='', end='')
    print("\n")  # reinsert paragraph separators

Two segmentation functions are available, process and analyze. We can use either one, as convenient.

Now that we have the list of sentences for the document, we extract their embeddings with SBERT.

Here we use the pre-trained model ‘stsb-bert-base’, which is tuned for sentence similarity.

from sentence_transformers import SentenceTransformer, util
import torch

model = SentenceTransformer('stsb-bert-base')

# Our sentences we'd like to encode
sentences = ['This framework generates embeddings for each input sentence',
             'Sentences are passed as a list of string.',
             'The quick brown fox jumps over the lazy dog.']

# Sentences are encoded by calling model.encode();
# by default the result is a NumPy array of shape (num_sentences, embedding_dim)
embedding_list = model.encode(sentences)

The snippet above gives one embedding per sentence, three in this example. To get the document embedding, we average all the sentence embeddings in the document.

“document_embed = torch.mean(torch.from_numpy(embedding_list), dim=0)”

This is how we could calculate the document embedding.

We repeat the same process for every document in every category, which gives us 10 document embeddings per class. To get a single embedding for each class, we average its document embeddings.
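Both averaging steps can be sketched with NumPy as follows. The sketch assumes each document's sentence embeddings are already in an array, e.g. the output of `model.encode(sentences)`. `document_embedding` and `class_embeddings` are hypothetical helper names, and the 2-dimensional vectors are toy data.

```python
import numpy as np

def document_embedding(sentence_embeddings: np.ndarray) -> np.ndarray:
    """Mean-pool a (num_sentences, dim) array into one (dim,) document vector."""
    return sentence_embeddings.mean(axis=0)

def class_embeddings(docs_by_class: dict[str, list[np.ndarray]]) -> dict[str, np.ndarray]:
    """Average each class's document vectors into one class vector (a centroid).

    `docs_by_class` maps a label to a list of per-document arrays of
    sentence embeddings.
    """
    centroids = {}
    for label, docs in docs_by_class.items():
        doc_vectors = np.stack([document_embedding(d) for d in docs])
        centroids[label] = doc_vectors.mean(axis=0)
    return centroids

# Toy example: 2-d "embeddings", two documents per class.
docs_by_class = {
    "A": [np.array([[1.0, 0.0], [3.0, 0.0]]),   # document mean -> [2, 0]
          np.array([[0.0, 0.0]])],              # document mean -> [0, 0]
    "B": [np.array([[0.0, 2.0]]),               # document mean -> [0, 2]
          np.array([[0.0, 4.0], [0.0, 0.0]])],  # document mean -> [0, 2]
}
centroids = class_embeddings(docs_by_class)
print(centroids)  # A -> [1, 0], B -> [0, 2]
```

Averaging document means, rather than pooling all sentences of a class at once, gives every document equal weight regardless of its length. That matches the two-step averaging described above.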


To classify a new document, we extract its embedding exactly as we did for the labeled data.

We then compute the cosine similarity between the new document embedding and each class embedding. The document is assigned to the class with the highest similarity (equivalently, the smallest cosine distance).

“similarity = util.pytorch_cos_sim(doc_1, doc_2)”
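The classification step is nearest-centroid classification under cosine similarity. Here is a NumPy sketch that computes the same quantity `util.pytorch_cos_sim` returns, for one document against all class embeddings at once. `classify` is a hypothetical helper, and the centroids are toy values.

```python
import numpy as np

def classify(doc_vector: np.ndarray, centroids: dict[str, np.ndarray]) -> tuple[str, float]:
    """Assign `doc_vector` to the class whose centroid is most cosine-similar."""
    labels = list(centroids)
    matrix = np.stack([centroids[label] for label in labels])  # (num_classes, dim)
    # Cosine similarity of the document with every class centroid.
    sims = matrix @ doc_vector / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(doc_vector))
    best = int(np.argmax(sims))
    return labels[best], float(sims[best])

centroids = {"A": np.array([1.0, 0.0]),
             "B": np.array([0.0, 2.0]),
             "C": np.array([-1.0, -1.0])}
label, score = classify(np.array([0.2, 0.9]), centroids)
print(label, round(score, 3))  # B 0.976
```

Returning the winning similarity alongside the label makes it easy to flag low-confidence predictions.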

We can also use this technique to generate additional labeled data with decent accuracy.
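The article does not spell out how the extra labels are produced. One common approach, assumed here, is pseudo-labeling: classify unlabeled documents and keep only the predictions whose best similarity clears a confidence threshold. The helper `pseudo_label` and the threshold of 0.8 are hypothetical. A sensible threshold for real SBERT embeddings would need to be tuned on held-out data.

```python
import numpy as np

def pseudo_label(unlabeled: np.ndarray, centroids: dict[str, np.ndarray],
                 threshold: float = 0.8) -> list[tuple[int, str, float]]:
    """Label unlabeled document vectors whose best cosine similarity >= threshold.

    Returns (row index, predicted label, similarity) for each accepted row;
    rows below the threshold are left unlabeled.
    """
    labels = list(centroids)
    c = np.stack([centroids[label] for label in labels])
    c = c / np.linalg.norm(c, axis=1, keepdims=True)
    u = unlabeled / np.linalg.norm(unlabeled, axis=1, keepdims=True)
    sims = u @ c.T                      # shape (num_unlabeled, num_classes)
    best = sims.argmax(axis=1)
    accepted = []
    for i, j in enumerate(best):
        score = float(sims[i, j])
        if score >= threshold:
            accepted.append((i, labels[j], score))
    return accepted

centroids = {"A": np.array([1.0, 0.0]), "B": np.array([0.0, 1.0])}
unlabeled = np.array([[0.9, 0.1],     # clearly A
                      [0.5, 0.5],     # ambiguous: similarity ~0.71 to both
                      [0.1, 0.95]])   # clearly B
result = pseudo_label(unlabeled, centroids)
for i, label, score in result:
    print(i, label, round(score, 3))
```

The accepted documents can be added to their predicted classes and the class embeddings recomputed. A higher threshold trades fewer new labels for fewer mistakes.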

Thanks for reading!


1 Nikhil (https://nkhandelwal204.medium.com/)

2 Deepak Saini



