Imbalanced Data — Oversampling Using Gaussian Mixture Models


Using GMM as an oversampling technique

Returning to the classification dataset we generated in the section “Dataset Preparation”, we can try to extract subpopulations (clusters) from the data. As an example, let’s extract 5 clusters.

`from sklearn.mixture import GaussianMixture`

`gmm = GaussianMixture(n_components=5)`

`gmm.fit(X_train)`

That was it!
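For readers who want to run the fitting step on its own, here is a minimal sketch. The dataset from “Dataset Preparation” is not reproduced here, so the `make_classification` call and its parameters (20 features, roughly 5% positives) are assumptions chosen to match the description in this article, and `random_state` is added only for reproducibility.

```python
from sklearn.datasets import make_classification
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import train_test_split

# Assumed stand-in for the article's dataset: 20 features, ~5% positive class.
X, y = make_classification(
    n_samples=5000, n_features=20, n_informative=10,
    weights=[0.95, 0.05], random_state=42,
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# Fit 5 Gaussian components on the features only (the target is not used).
gmm = GaussianMixture(n_components=5, random_state=42)
gmm.fit(X_train)

print(gmm.converged_)          # whether EM stopped before hitting max_iter
print(gmm.weights_.round(3))   # share of the data attributed to each component
```

Note that the fit is unsupervised: `y_train` plays no role until we relate the clusters to the target further down.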

Under the hood, our GMM has now fitted 5 clusters, each described by its own normal distribution over the feature values that cluster can take.

Below is an example showing the cluster mean values for each of the 20 features.

`pd.DataFrame(gmm.means_.T)`
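Besides the means, each cluster also carries a weight and a covariance matrix. The sketch below simply prints the shapes of these fitted attributes; `X_demo` is a hypothetical random stand-in for `X_train`, and the shapes shown assume scikit-learn's default `covariance_type="full"`.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(1000, 20))  # hypothetical stand-in for X_train

gmm_demo = GaussianMixture(n_components=5, random_state=0).fit(X_demo)

print(gmm_demo.weights_.shape)      # (5,): one mixing weight per cluster
print(gmm_demo.means_.shape)        # (5, 20): one mean per cluster and feature
print(gmm_demo.covariances_.shape)  # (5, 20, 20): one full covariance per cluster
```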

More importantly, the fitted GMM gives us 2 useful capabilities:

1- It can look at the feature values for a particular sample and assign the sample to a cluster.

`gmm.predict(X_test)`

2- It can use the fitted normal distributions to generate new samples which we can use for oversampling.

`gmm.sample(5)`
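The two capabilities above can be sketched on a toy dataset. The two-blob data and the names `X_demo` and `gmm_demo` are illustrative assumptions, not part of the original example. The point to notice is that `sample` returns a tuple: the generated rows and the index of the component that produced each row.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
# Hypothetical toy data: two well-separated blobs in 2 dimensions.
X_demo = np.vstack([
    rng.normal(-5, 1, size=(200, 2)),
    rng.normal(5, 1, size=(200, 2)),
])
gmm_demo = GaussianMixture(n_components=2, random_state=0).fit(X_demo)

# 1- Assign samples to clusters (hard labels) or get soft memberships.
labels = gmm_demo.predict(X_demo[:3])
probs = gmm_demo.predict_proba(X_demo[:3])  # each row sums to 1

# 2- Generate new samples; the second array holds the generating component.
new_X, new_components = gmm_demo.sample(5)

print(labels.shape, probs.shape, new_X.shape, new_components.shape)
```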

Note that in this example we are clustering our data into 5 different buckets, but our problem is a binary classification problem where the target variable can be either 0 or 1.

One idea is to check for the relationship between every cluster and the target (y) variable.

`cluster_mean = pd.DataFrame(data={"Cluster": gmm.predict(X_train), "Mean Target Variable (y)": y_train}).groupby("Cluster").mean().reset_index(drop=False)`

`plt.figure(figsize=(10, 5))`

`sns.barplot(data=cluster_mean, x="Cluster", y="Mean Target Variable (y)")`

`plt.show()`

We can see that Cluster 4 has the highest mean value for the target variable, at just over 40%.

Remember this is an imbalanced dataset where only 5% of the samples have a target variable (y) value = 1, so 40% is roughly eight times the base rate.

The last step would be to generate random samples from the GMM model and only keep the ones which belong to Cluster 4. We can label them with a positive target variable (y=1).

`samples, clusters = gmm.sample(100)`

`samples_to_keep = samples[clusters == 4]`

We can finally add them to our training data!
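As a rough sketch of this last step, the snippet below labels the kept samples with y=1 and appends them to the training set. The data is an assumed synthetic stand-in, and instead of hard-coding Cluster 4 it looks up whichever cluster has the highest positive rate, since cluster numbering can change from one fit to another. The names `best`, `X_aug` and `y_aug` are hypothetical.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.mixture import GaussianMixture

# Assumed synthetic stand-in for the article's training data.
X_train, y_train = make_classification(
    n_samples=3000, n_features=20, n_informative=10,
    weights=[0.95, 0.05], random_state=0,
)
gmm = GaussianMixture(n_components=5, random_state=0).fit(X_train)

# Find the cluster with the highest share of positives.
train_clusters = gmm.predict(X_train)
rates = np.array([
    y_train[train_clusters == k].mean() if np.any(train_clusters == k) else 0.0
    for k in range(5)
])
best = int(rates.argmax())

# Generate samples and keep those drawn from that cluster.
samples, components = gmm.sample(500)
samples_to_keep = samples[components == best]

# Label the synthetic rows as positives and append them to the training set.
X_aug = np.vstack([X_train, samples_to_keep])
y_aug = np.concatenate(
    [y_train, np.ones(len(samples_to_keep), dtype=y_train.dtype)]
)
print(X_aug.shape, y_aug.mean())
```

One detail worth knowing: the second array returned by `sample` is the component that generated each row, which is not always the cluster that `gmm.predict` would assign to that row when components overlap.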

We can similarly draw samples from the top 2–3 clusters that are most strongly associated with y=1. Alternatively, we can draw samples from any cluster where the mean value of y is above a predefined threshold.
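The threshold variant could look something like the sketch below. The data is again an assumed synthetic stand-in, and the threshold of twice the base rate is a hypothetical choice made for illustration, not a recommendation from the original example.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.mixture import GaussianMixture

# Assumed synthetic stand-in for the article's training data.
X_train, y_train = make_classification(
    n_samples=3000, n_features=20, n_informative=10,
    weights=[0.95, 0.05], random_state=1,
)
n_components = 5
gmm = GaussianMixture(n_components=n_components, random_state=1).fit(X_train)

# Mean of y within each cluster (0 for a cluster with no assigned rows).
train_clusters = gmm.predict(X_train)
positive_rate = np.zeros(n_components)
for k in range(n_components):
    members = train_clusters == k
    if members.any():
        positive_rate[k] = y_train[members].mean()

# Hypothetical threshold: keep clusters above twice the base rate.
threshold = 2 * y_train.mean()
selected = np.flatnonzero(positive_rate > threshold)

# Keep only the samples generated by one of the selected clusters.
samples, components = gmm.sample(1000)
keep = np.isin(components, selected)
synthetic_X = samples[keep]
synthetic_y = np.ones(len(synthetic_X), dtype=y_train.dtype)
print(selected, synthetic_X.shape)
```

If no cluster clears the threshold, `selected` is empty and no synthetic rows are produced, which is a useful signal that the threshold or the number of clusters needs revisiting.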
