Imbalanced Data — Oversampling Using Gaussian Mixture Models




Using GMM as an oversampling technique

If we go back to the classification dataset we generated in the section “Dataset Preparation”, we can try to extract subpopulations/clusters from the data. Let’s extract 5 clusters, for example.
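The original “Dataset Preparation” section is not reproduced here, but as a minimal sketch, a comparable dataset (20 features and roughly 5% positive samples, matching the imbalance described below) could be generated with scikit-learn’s make_classification; the exact parameter values are assumptions.

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Hypothetical stand-in for the original "Dataset Preparation" section:
# 20 features and a ~5% positive class.
X, y = make_classification(n_samples=10_000, n_features=20,
                           weights=[0.95, 0.05], random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42)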

from sklearn.mixture import GaussianMixture

gmm = GaussianMixture(n_components=5, random_state=42)
gmm.fit(X_train)

That was it!

Under the hood, our GMM model has now fitted 5 different clusters, each described by its own multivariate normal distribution over the feature values that cluster tends to take.
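Concretely, the fitted model exposes the learned parameters. With scikit-learn’s default covariance_type="full" and our 20 features, the shapes are:

gmm.weights_.shape       # (5,): mixing proportion of each cluster
gmm.means_.shape         # (5, 20): one mean vector per cluster
gmm.covariances_.shape   # (5, 20, 20): one covariance matrix per cluster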

Below is an example showing the cluster mean values for each of the 20 features.

import pandas as pd
pd.DataFrame(gmm.means_.T)
We can see, for example, that Cluster 0 has a mean value of -0.159613 for Feature 1 (image by author)

More importantly, the fitted GMM gives us two useful capabilities:

1- It can look at the feature values for a particular sample and assign the sample to a cluster.

gmm.predict(X_test)
Example prediction showing that the first 3 samples in X_test belong to clusters 0, 2, and 4, respectively (image by author)
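The assignment above is hard (each sample gets exactly one cluster). The model also exposes soft assignments, i.e. the probability that a sample belongs to each of the 5 clusters:

probs = gmm.predict_proba(X_test)  # shape (n_samples, 5): per-cluster membership probabilities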

2- It can use the fitted normal distributions to generate new samples which we can use for oversampling.

gmm.sample(5)
5 samples were generated with feature values, i.e. X; they fall into clusters 0, 1, 2, 2, and 2, respectively (image by author)
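Note that scikit-learn’s sample() returns a tuple of the generated feature matrix and the component (cluster) each sample was drawn from, with the samples grouped by component, so it is usually unpacked:

X_new, components = gmm.sample(5)  # X_new: (5, 20) feature matrix; components: (5,) cluster labels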

So far, we have clustered our data into 5 different buckets, but our problem is a binary classification problem where the target variable can be either 0 or 1.

One idea is to check the relationship between each cluster and the target variable (y).

import matplotlib.pyplot as plt
import seaborn as sns

cluster_mean = pd.DataFrame(data={
    "Cluster": gmm.predict(X_train),
    "Mean Target Variable (y)": y_train
}).groupby("Cluster").mean().reset_index(drop=False)

plt.figure(figsize=(10, 5))
sns.barplot(data=cluster_mean, x="Cluster",
            y="Mean Target Variable (y)")
plt.show()
Cluster 4 is most strongly associated with a positive target variable (image by author)

We can see that Cluster 4 has the highest mean target value, at just over 40%.

Remember, this is an imbalanced dataset where only 5% of the samples have a target variable value of y = 1, so 40% is a large number.

The last step would be to generate random samples from the GMM model and only keep the ones which belong to Cluster 4. We can label them with a positive target variable (y=1).

# sample() returns the generated points and the cluster each point was drawn from
samples, clusters = gmm.sample(100)
samples_to_keep = samples[clusters == 4]

We can finally add them to our training data!
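A minimal sketch of that final step, assuming X_train and y_train are NumPy arrays:

import numpy as np

# Stack the synthetic Cluster 4 samples under the original training set
X_train_aug = np.vstack([X_train, samples_to_keep])
# Label every synthetic sample as positive (y = 1)
y_train_aug = np.concatenate([y_train, np.ones(len(samples_to_keep), dtype=int)])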

We can similarly draw samples from the top 2–3 clusters that are most strongly associated with y = 1, or draw samples from any cluster where the mean value of y is above a predefined threshold, as sketched below.
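Here is a sketch of the threshold variant, reusing the cluster_mean DataFrame computed earlier; the 0.3 cutoff is a hypothetical value, not from the original article.

import numpy as np

threshold = 0.3  # hypothetical cutoff on the per-cluster mean of y
positive_clusters = cluster_mean.loc[
    cluster_mean["Mean Target Variable (y)"] > threshold, "Cluster"
].values

# Draw a large batch and keep only samples from qualifying clusters
samples, clusters = gmm.sample(1_000)
samples_to_keep = samples[np.isin(clusters, positive_clusters)]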
