How to Train a Classification Model with TensorFlow in 10 Minutes


Accuracy, precision, and recall increase slightly as we train the model, while loss decreases. All four curves have occasional spikes, which would hopefully smooth out if you trained the model longer.

According to the chart, you could train the model for more epochs, as there's no sign of a plateau yet.
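If you'd rather not judge a plateau by eye, one rough check is whether the loss has improved by more than a small margin over the last few epochs. The sketch below assumes you kept the return value of `model.fit()` (a Keras `History` object, whose `history` dict maps metric names to per-epoch lists). The helper name `has_plateaued` and its `patience` and `min_delta` defaults are hypothetical choices that loosely mirror the idea behind Keras's `EarlyStopping` callback, not something from this article.

```python
def has_plateaued(losses, patience=5, min_delta=1e-3):
    """Return True if the best loss in the last `patience` epochs is not at
    least `min_delta` lower than the best loss seen before that window.

    Hypothetical helper: `patience` and `min_delta` are illustrative defaults.
    """
    if len(losses) <= patience:
        # Not enough epochs yet to judge
        return False
    best_before = min(losses[:-patience])
    best_recent = min(losses[-patience:])
    # A plateau means the recent window brought no meaningful improvement
    return best_before - best_recent < min_delta


# Usage, assuming `history = model.fit(...)` was kept:
# print(has_plateaued(history.history['loss']))
print(has_plateaued([0.60, 0.55, 0.50, 0.46, 0.43, 0.40, 0.38]))  # → False
```

A `False` here agrees with the chart's message: the loss is still dropping, so more epochs could help.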

But are we overfitting? Let’s answer that next.

Making predictions

You can now use the predict() function to get prediction probabilities on the scaled test data:

predictions = model.predict(X_test_scaled)

Here's what they look like:

Image 9 — Prediction probabilities (image by author)

You’ll have to convert them to classes before evaluation. The logic is simple — if the probability is greater than 0.5 we assign 1 (good wine), and 0 (bad wine) otherwise:

import numpy as np
prediction_classes = [
    1 if prob > 0.5 else 0 for prob in np.ravel(predictions)  # 1 = good wine, 0 = bad wine
]
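Assuming the network ends in a single sigmoid unit, `model.predict()` returns a 2D array of shape `(n_samples, 1)`, which is why the snippet flattens it with `np.ravel()`. The same thresholding can be done without a Python loop by comparing the whole array at once. In the sketch below, `to_classes` is a hypothetical helper name and the probabilities are made up to stand in for real model output.

```python
import numpy as np


def to_classes(probabilities, threshold=0.5):
    """Map probabilities to 0/1 labels: 1 if prob > threshold, else 0."""
    # np.ravel flattens an (n, 1) prediction array into shape (n,)
    flat = np.ravel(probabilities)
    # The boolean comparison is cast to int, giving 1 (good wine) or 0 (bad wine)
    return (flat > threshold).astype(int)


# Made-up probabilities standing in for model.predict(X_test_scaled)
fake_predictions = np.array([[0.91], [0.12], [0.50], [0.73]])
print(to_classes(fake_predictions))  # → [1 0 0 1]
```

Note that a probability of exactly 0.5 maps to 0, matching the strict `>` comparison used above.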

Here's what the first 20 look like:

Image 10 — Prediction classes (image by author)

That’s all we need — let’s evaluate the model next.

Model evaluation on test data

Let’s start with the confusion matrix:

from sklearn.metrics import confusion_matrix

print(confusion_matrix(y_test, prediction_classes))

Image 11 — Confusion matrix (image by author)

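There are more false negatives (214) than false positives (99), so recall on the test set will be lower than precision. Both metrics share the same numerator (true positives), but precision divides by TP + FP while recall divides by TP + FN, so the larger false negative count gives recall the bigger denominator.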
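A small pure-Python sketch makes this arithmetic concrete. It counts the four confusion-matrix cells on made-up labels (not the wine data) and returns them in the `tn, fp, fn, tp` order you get from flattening scikit-learn's binary confusion matrix, whose layout is `[[TN, FP], [FN, TP]]` for 0/1 labels. The helper name `confusion_counts` is hypothetical.

```python
def confusion_counts(y_true, y_pred):
    """Return (tn, fp, fn, tp) for binary 0/1 labels."""
    tn = fp = fn = tp = 0
    for actual, predicted in zip(y_true, y_pred):
        if actual == 1 and predicted == 1:
            tp += 1
        elif actual == 0 and predicted == 1:
            fp += 1
        elif actual == 1 and predicted == 0:
            fn += 1
        else:
            tn += 1
    return tn, fp, fn, tp


# Made-up labels: 2 false positives, 3 false negatives
y_true = [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
y_pred = [1, 1, 1, 1, 0, 0, 0, 1, 1, 0]

tn, fp, fn, tp = confusion_counts(y_true, y_pred)
precision = tp / (tp + fp)  # share of predicted positives that are correct
recall = tp / (tp + fn)     # share of actual positives the model found
print(tn, fp, fn, tp)                   # → 1 2 3 4
print(f'{precision:.2f} {recall:.2f}')  # → 0.67 0.57
```

With more false negatives than false positives, recall comes out below precision, just as on the test set.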

The following snippet prints accuracy, precision, and recall on the test set:

from sklearn.metrics import accuracy_score, precision_score, recall_score

print(f'Accuracy: {accuracy_score(y_test, prediction_classes):.2f}')
print(f'Precision: {precision_score(y_test, prediction_classes):.2f}')
print(f'Recall: {recall_score(y_test, prediction_classes):.2f}')

Image 12 — Accuracy, precision, and recall on the test set (image by author)

All values are somewhat lower than in the training set evaluation:

  • Accuracy: 0.82
  • Precision: 0.88
  • Recall: 0.83
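To make the train-versus-test comparison systematic, you can run both splits through the same function and look at the per-metric gaps. The sketch below is a hypothetical NumPy-only helper (`evaluate_split` and `metric_gaps` are not from the article). In practice you might pass it `y_train` with `model.predict(X_train_scaled)` and `y_test` with `model.predict(X_test_scaled)`, assuming those variables exist from earlier steps. The demo below uses made-up labels and probabilities.

```python
import numpy as np


def evaluate_split(y_true, probabilities, threshold=0.5):
    """Return accuracy, precision and recall for one data split (hypothetical helper)."""
    y_true = np.ravel(y_true).astype(int)
    y_pred = (np.ravel(probabilities) > threshold).astype(int)
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))
    # Guard against division by zero when no positives are predicted or present
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return {
        'accuracy': float(np.mean(y_true == y_pred)),
        'precision': precision,
        'recall': recall,
    }


def metric_gaps(train_scores, test_scores):
    """Train minus test for each metric; positive gaps hint at overfitting."""
    return {name: train_scores[name] - test_scores[name] for name in train_scores}


# Made-up labels and probabilities, for illustration only
train = evaluate_split([1, 0, 1, 1], [[0.9], [0.2], [0.8], [0.7]])
test = evaluate_split([1, 0, 1, 1], [[0.9], [0.6], [0.4], [0.7]])
print(metric_gaps(train, test))
```

Small positive gaps across all three metrics are what "overfitting slightly" looks like in numbers; large gaps would suggest the model memorized the training data.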

The model is overfitting slightly, but it's still a decent result for a couple of minutes of work. We'll go over optimization in the following article.

