
## Using GMM as an oversampling technique

If we go back to the classification dataset generated in the section “Dataset Preparation”, we can try to extract subpopulations (clusters) from the data. As an example, let’s extract 5 clusters.

```python
from sklearn.mixture import GaussianMixture

gmm = GaussianMixture(n_components=5)  # 5 clusters
gmm.fit(X_train)
```

That was it!

Under the hood, the GMM has now fitted 5 clusters (mixture components). Each one is a multivariate normal distribution with its own mean vector and covariance matrix, describing the feature values that samples in that cluster tend to take.

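For example, here is the mean value of each of the 20 features in every cluster (one row per feature, one column per cluster):

```python
import pandas as pd
pd.DataFrame(gmm.means_.T)  # rows: 20 features, columns: 5 clusters
```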
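The means are only part of what the model learned. As a minimal sketch, assuming a synthetic stand-in for the article's dataset (built here with `make_classification`, 20 features and roughly 5% positives, since the “Dataset Preparation” code is not reproduced in this section), the other fitted parameters can be inspected like this:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.mixture import GaussianMixture

# Synthetic stand-in for the article's data (assumption): 20 features, ~5% positives
X_train, y_train = make_classification(
    n_samples=2000, n_features=20, n_informative=5, n_redundant=0,
    weights=[0.95], random_state=0,
)

gmm = GaussianMixture(n_components=5, random_state=0).fit(X_train)

print(gmm.weights_.shape)      # (5,): one mixing weight per cluster
print(gmm.means_.shape)        # (5, 20): one mean vector per cluster
print(gmm.covariances_.shape)  # (5, 20, 20): one covariance matrix per cluster
print(np.isclose(gmm.weights_.sum(), 1.0))  # True: the weights form a distribution
```

With the default `covariance_type="full"`, each cluster gets its own 20-by-20 covariance matrix, which lets it capture correlations between features; `"diag"` or `"spherical"` store fewer parameters at the cost of that flexibility.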

More importantly, a fitted GMM gives us two useful capabilities:

1- Given the feature values of a sample, it can assign the sample to its most likely cluster.

`gmm.predict(X_test)`

2- It can use the fitted normal distributions to **generate new samples** which we can use for **oversampling**.

`gmm.sample(5)`
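Both capabilities can be tried on a small scale. The sketch below is a minimal illustration assuming a synthetic stand-in dataset (`make_classification` with 20 features and about 5% positives) and a 75/25 split in place of the article's own data; it also shows `predict_proba`, which returns soft cluster memberships instead of a single label:

```python
from sklearn.datasets import make_classification
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import train_test_split

# Synthetic stand-in data (assumption), split 75/25 with class proportions preserved
X, y = make_classification(n_samples=2000, n_features=20, n_informative=5,
                           n_redundant=0, weights=[0.95], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y)

gmm = GaussianMixture(n_components=5, random_state=0).fit(X_train)

# 1- Hard assignment: index (0-4) of the most likely cluster for each test sample
labels = gmm.predict(X_test)
# Soft assignment: per-cluster membership probabilities (each row sums to 1)
probs = gmm.predict_proba(X_test)

# 2- Generation: sample() returns a tuple (new feature rows, cluster of each row)
new_X, new_clusters = gmm.sample(5)

print(labels.shape, probs.shape, new_X.shape, new_clusters.shape)
# (500,) (500, 5) (5, 20) (5,)
```

Because `sample()` returns a tuple, `gmm.sample(5)` on its own yields both the generated rows and the cluster each one was drawn from.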

Note that in this example we cluster the data into 5 buckets, while our problem is a binary classification problem where the target variable is either 0 or 1.

One idea is to check how each cluster relates to the target variable (y) by computing the share of positive samples in each cluster:

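```python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Mean of y (i.e. the share of y=1) among training samples assigned to each cluster
cluster_mean = pd.DataFrame(data={
    "Cluster": gmm.predict(X_train),
    "Mean Target Variable (y)": y_train,
}).groupby("Cluster").mean().reset_index(drop=False)

plt.figure(figsize=(10, 5))
sns.barplot(data=cluster_mean, x="Cluster", y="Mean Target Variable (y)")
plt.show()
```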

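In the resulting bar chart, **Cluster 4** has the highest mean target value, at just over 40%. (Cluster indices depend on the fitted model, so the positive-rich cluster may carry a different number on another run.)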

Remember that this is an imbalanced dataset where only 5% of the samples have y = 1, so a 40% positive rate is about eight times the base rate.

The last step is to generate random samples from the GMM, keep only those drawn from Cluster 4, and label them as positive (y = 1).

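```python
# Draw 100 samples; `clusters` holds the cluster each sample was generated from
samples, clusters = gmm.sample(100)
samples_to_keep = samples[clusters == 4]  # keep only rows from Cluster 4
```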
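One caveat: `gmm.sample(100)` picks each sample's cluster at random in proportion to `gmm.weights_`, so only about `100 * gmm.weights_[4]` of the 100 rows will come from Cluster 4, and the exact count varies. If a fixed number of Cluster 4 rows is wanted, one alternative (not the article's method) is to draw directly from that cluster's normal distribution with NumPy. The helper `sample_from_cluster` below is hypothetical, the data is a synthetic stand-in, and the approach assumes the default `covariance_type="full"`; index 4 is used only to mirror the text:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.mixture import GaussianMixture

# Synthetic stand-in data (assumption)
X_train, y_train = make_classification(n_samples=2000, n_features=20, n_informative=5,
                                       n_redundant=0, weights=[0.95], random_state=0)
gmm = GaussianMixture(n_components=5, random_state=0).fit(X_train)

# gmm.sample(n) splits n across clusters with a multinomial draw over gmm.weights_,
# so the number of rows coming from cluster 4 is random
samples, clusters = gmm.sample(100)
print((clusters == 4).sum(), "of 100 rows came from cluster 4")

def sample_from_cluster(gmm, k, n, seed=0):
    """Hypothetical helper: draw n rows from cluster k's Gaussian only.

    Assumes covariance_type="full", so gmm.covariances_[k] is a full matrix.
    """
    rng = np.random.default_rng(seed)
    return rng.multivariate_normal(gmm.means_[k], gmm.covariances_[k], size=n)

samples_to_keep = sample_from_cluster(gmm, k=4, n=100)
print(samples_to_keep.shape)  # (100, 20)
```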

We can finally add them to our training data!
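As a minimal sketch, assuming NumPy arrays and a synthetic stand-in dataset built with `make_classification` (not the article's data), appending the kept samples as positives could look like this:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.mixture import GaussianMixture

# Synthetic stand-in data (assumption)
X_train, y_train = make_classification(n_samples=2000, n_features=20, n_informative=5,
                                       n_redundant=0, weights=[0.95], random_state=0)
gmm = GaussianMixture(n_components=5, random_state=0).fit(X_train)

# Keep only generated rows that came from the chosen cluster (4, as in the text)
samples, clusters = gmm.sample(100)
samples_to_keep = samples[clusters == 4]

# Label every synthetic row as positive (y=1) and append it to the training set
X_aug = np.vstack([X_train, samples_to_keep])
y_aug = np.concatenate([y_train, np.ones(len(samples_to_keep), dtype=y_train.dtype)])

print(X_aug.shape, y_aug.shape, y_train.mean(), y_aug.mean())
```

Only the training data is augmented; the test set is left untouched so that evaluation still reflects real samples.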

Similarly, we can draw samples from the top 2–3 clusters most strongly associated with y = 1, or from any cluster whose mean value of y exceeds a predefined threshold.
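Both selection rules can be expressed with a per-cluster positive rate. The sketch below assumes a synthetic stand-in dataset, computes the positive rate with `np.bincount` (the same quantity as the mean of y plotted per cluster), and uses `k = 2` and a threshold of `0.2` purely as illustrative values:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.mixture import GaussianMixture

# Synthetic stand-in data (assumption)
X_train, y_train = make_classification(n_samples=2000, n_features=20, n_informative=5,
                                       n_redundant=0, weights=[0.95], random_state=0)
gmm = GaussianMixture(n_components=5, random_state=0).fit(X_train)

# Positive rate (mean of y) among training rows assigned to each cluster
train_clusters = gmm.predict(X_train)
counts = np.bincount(train_clusters, minlength=gmm.n_components)
positives = np.bincount(train_clusters, weights=y_train, minlength=gmm.n_components)
pos_rate = np.divide(positives, counts, out=np.zeros(gmm.n_components), where=counts > 0)

# Rule A: the top-k clusters by positive rate (k = 2 is an arbitrary choice)
top_k = np.argsort(pos_rate)[::-1][:2]
# Rule B: every cluster whose positive rate exceeds a threshold (0.2 is hypothetical)
above = np.flatnonzero(pos_rate > 0.2)

selected = top_k  # or: selected = above

# Generate, then keep only rows drawn from the selected clusters, labelled y=1
samples, clusters = gmm.sample(1000)
keep = np.isin(clusters, selected)
new_X = samples[keep]
new_y = np.ones(int(keep.sum()), dtype=int)
```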
