How to Add New Data to a Pretrained Model in Scikit-learn

A step-by-step tutorial on how to use warm_start=True and partial_fit() in scikit-learn

When you build a Machine Learning model from scratch, you usually split your dataset into a training set and a test set, and train the model on the training set. Then you evaluate the model on the test set and, if the performance is decent, you use the model for prediction.

But what if new data becomes available at some point?

In other words, how do you retrain an already trained model? Or, put differently, how do you add new data to an already trained model?

In this article I try to give some answers to this non-trivial question, using the scikit-learn library. To understand when you need to retrain your model, you can check this interesting article by Vidhi Chugh.

One possible (trivial) solution is to train the model from scratch on both the old and the new data. However, this solution does not scale if training takes a long time.
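For reference, the retrain-from-scratch baseline could look like the sketch below. It assumes the iris dataset and a Random Forest purely for illustration; the names X_old, X_new and the choice of n_estimators=10 are hypothetical and not part of the tutorial.

```python
import numpy as np
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = datasets.load_iris(return_X_y=True)

# Assumption: pretend 40% of the data is "old" and 60% arrives later as "new".
X_old, X_new, y_old, y_new = train_test_split(X, y, test_size=0.60, random_state=42)

# Stack the old and new samples and train a fresh model on everything.
X_all = np.vstack([X_old, X_new])
y_all = np.concatenate([y_old, y_new])

model = RandomForestClassifier(n_estimators=10, random_state=0)
model.fit(X_all, y_all)

print(X_all.shape, y_all.shape)
```

The cost of this approach grows with the total amount of data, because every retraining processes the old samples again.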

A better solution is to add the new samples to the already trained model. Scikit-learn allows you to do this in some cases, provided that you take some precautions.

Scikit-learn proposes two strategies: warm start (the warm_start=True parameter) and partial fit (the partial_fit() method).

To illustrate how to add new data to a pre-trained model in Scikit-learn, I will use a practical example, using the well-known iris dataset, provided by the Scikit-learn library.

warm start

warm_start is a parameter provided by some scikit-learn models. If it is set to True, a subsequent call to fit reuses the attributes of the already fitted model to initialize the new fit, instead of discarding them.

For example, you can set warm_start = True in a Random Forest Classifier and fit the model as usual. If you then increase n_estimators and call the fit method again on new data, new trees are added to the ensemble, while the existing trees are left unchanged. If you do not increase n_estimators, the second call to fit does not fit any new tree.
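The behavior described above can be checked with a minimal sketch. The two halves of the iris dataset used here are an arbitrary stand-in for old and new data, and the variable names are hypothetical.

```python
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier

X, y = datasets.load_iris(return_X_y=True)

# Assumption: even-indexed rows play the role of old data, odd-indexed rows of new data.
X_old, y_old = X[::2], y[::2]
X_new, y_new = X[1::2], y[1::2]

model = RandomForestClassifier(n_estimators=1, warm_start=True, random_state=0)
model.fit(X_old, y_old)
first_tree = model.estimators_[0]

# Ask for one more tree, then fit on the new data only.
model.n_estimators += 1
model.fit(X_new, y_new)

print(len(model.estimators_))              # 2
print(model.estimators_[0] is first_tree)  # True: the first tree was kept as-is
```

The second tree has only seen the new data, while the first tree has only seen the old data; the forest combines their votes.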

warm_start = True should not be used for incremental learning on new datasets where there could be concept drift. Concept drift happens when the underlying relationship between the input variables and the output variable changes over time.

To understand how warm_start = True works, I describe an example. The idea is to show that warm_start = True can improve the performance of an algorithm when I add new data that has the same distribution as the original data and maintains the same relationship with the output variable.

Firstly, I load the iris dataset, provided by the Scikit-learn library:

from sklearn import datasets
iris = datasets.load_iris()
X = iris.data
y = iris.target

Then, I split the dataset into three parts:

  • X_train, y_train: training set, 80% of 40% of the data (48 samples)
  • X_test, y_test: test set, 20% of 40% of the data (12 samples)
  • X2, y2: new samples, 60% of the data (90 samples)

from sklearn.model_selection import train_test_split
# First split: 40% of the data for the initial model, 60% kept aside as new samples
X1, X2, y1, y2 = train_test_split(X, y, test_size=0.60, random_state=42)
# Second split: divide the initial 40% into training set and test set
X_train, X_test, y_train, y_test = train_test_split(X1, y1, test_size=0.20, random_state=42)
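As a quick sanity check, the sizes of the three parts can be verified with a short self-contained sketch that repeats the same two splits:

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split

X, y = datasets.load_iris(return_X_y=True)

# Same two-step split as in the tutorial.
X1, X2, y1, y2 = train_test_split(X, y, test_size=0.60, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X1, y1, test_size=0.20, random_state=42)

print(len(X_train), len(X_test), len(X2))  # 48 12 90
```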

I will use X2 and y2 to retrain the model.

Note that the training set is very small (48 samples).

I train the model, with warm_start = False:

from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier(max_depth=2, random_state=0, warm_start=False, n_estimators=1)
model.fit(X_train, y_train)

I calculate the score:

model.score(X_test, y_test)

which gives the following output:

0.75

Now, I fit the model on new data:

model.fit(X2, y2)

Since warm_start = False, this call to fit discards the model already learned and trains a new one from scratch on X2 and y2 only. Then, I calculate the score:

model.score(X_test, y_test)

which gives the following output:

0.8333333333333334

Now I build a new model with warm_start = True, to see if the model score increases.

model = RandomForestClassifier(max_depth=2, random_state=0, warm_start=True, n_estimators=1)
model.fit(X_train, y_train)
model.score(X_test, y_test)

which gives the following output:

0.75

Now I increase n_estimators by one, so that a new tree is added, fit the model on the new data and calculate the score:

model.n_estimators += 1
model.fit(X2, y2)
model.score(X_test, y_test)

which gives the following output:

0.9166666666666666

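The incremental learning has improved the score! Keep in mind, however, that the test set contains only 12 samples, so a single prediction changes the score by about 0.08.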

partial fit

The second strategy provided by Scikit-learn to add new data to a pre-trained model is the use of the partial_fit() method. Not all the models provide this method.

In the Random Forest example, warm_start = True leaves the trees already learned by the model unchanged. By contrast, partial_fit() can change the parameters already learned, because it updates them using the new data.

I consider again the iris dataset.

Now I use an SGDClassifier:

from sklearn.linear_model import SGDClassifier
import numpy as np
model = SGDClassifier()
model.partial_fit(X_train, y_train, classes=np.unique(y))

The first time I call the partial_fit() method, I must also pass all the classes to the method. In this example, I assume that I know all the classes contained in y, even though I may not have enough samples to represent them.
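The role of the classes argument can be illustrated with a minimal sketch. As an assumption made only for this illustration, the first batch contains just two of the three iris classes, and the third class appears in a later batch.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import SGDClassifier

X, y = datasets.load_iris(return_X_y=True)

# Assumption: the first batch only contains classes 0 and 1.
first = y < 2

# Without classes, the first call to partial_fit() fails.
error_raised = False
try:
    SGDClassifier(random_state=0).partial_fit(X[first], y[first])
except ValueError:
    error_raised = True
print(error_raised)

# With classes, the model reserves room for class 2 even though it has not seen it yet.
model = SGDClassifier(random_state=0)
model.partial_fit(X[first], y[first], classes=np.unique(y))
print(model.classes_)

# On later calls, classes can be omitted.
model.partial_fit(X[~first], y[~first])
```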

I calculate the score:

model.score(X_test, y_test)

which gives the following output:

0.4166666666666667

Now, I add new samples to the model:

model.partial_fit(X2, y2)

and I calculate the score:

model.score(X_test, y_test)

which gives the following output:

0.8333333333333334

Adding new data has improved the performance of the algorithm!
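The same idea extends to data that arrives in several small batches. The following is a sketch under stated assumptions: the batch size of 20 and the 80/20 split are hypothetical choices, and the scores it prints may go up or down from one batch to the next, since each partial_fit() call performs a single pass over the batch.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split

X, y = datasets.load_iris(return_X_y=True)
X_stream, X_test, y_stream, y_test = train_test_split(X, y, test_size=0.20, random_state=42)

classes = np.unique(y)  # all classes, known in advance
batch_size = 20         # hypothetical batch size

model = SGDClassifier(random_state=0)
scores = []
for start in range(0, len(X_stream), batch_size):
    stop = start + batch_size
    # Update the model with the next batch only; earlier batches are not revisited.
    model.partial_fit(X_stream[start:stop], y_stream[start:stop], classes=classes)
    scores.append(model.score(X_test, y_test))

print(scores)
```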

Summary

Congratulations! You have just learned how to add new data to a pre-trained model in Scikit-learn! You can use either the warm_start parameter set to True or the partial_fit() method. However, not all the models in the Scikit-learn library provide the possibility to add new data to a pre-trained model. Thus my suggestion is to check the documentation!
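Besides reading the documentation, a quick programmatic check is possible. The helper below, incremental_options, is a hypothetical name introduced for this sketch; it only inspects whether an estimator exposes a warm_start parameter and a partial_fit() method.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.naive_bayes import GaussianNB


def incremental_options(estimator):
    """Return (has_warm_start, has_partial_fit) for a scikit-learn estimator."""
    has_warm_start = "warm_start" in estimator.get_params()
    has_partial_fit = hasattr(estimator, "partial_fit")
    return has_warm_start, has_partial_fit


for estimator in (RandomForestClassifier(), SGDClassifier(), LogisticRegression(), GaussianNB()):
    warm, partial = incremental_options(estimator)
    print(f"{type(estimator).__name__}: warm_start={warm}, partial_fit={partial}")
```

The check tells you which option exists, not how it behaves: the exact meaning of warm_start differs between models, so the documentation remains the reference.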

You can download the code used in this tutorial from my Github repository.

If you have read this far, for me it is already a lot for today. Thanks! You can read my trending articles at this link.
