Why Cross Entropy Loss?




While solving classification problems using deep learning models, we use cross entropy to tell the model how good or bad its predictions are during training. What is this cross entropy loss?

Cross entropy can, in a way, be looked at as the difference between two probability distributions, in the case of supervised learning with one-hot encoded labels. Let’s say we are trying to classify an input into one of 3 categories. Our model outputs a probability distribution m, and the label is l.

It is okay if you don’t understand this next piece of code; it is just here to show us the cross entropy value.

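import torch

# x: predicted probability distribution, y: one-hot label (both 1-D tensors)
def cross_entropy(x, y):
    return -torch.log(x[y.argmax()])  # -log of the probability given to the correct class

m, l = torch.tensor([0.6, 0.2, 0.2]), torch.tensor([0, 0, 1])
print(f"Cross entropy loss: {cross_entropy(m, l).item():.4f}")
# OUTPUT:
# Cross entropy loss: 1.6094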

Comparison with Mean Squared Error (L2 loss)

If we treat the two probability distributions as vectors and compute the squared distance between them (averaged over the classes), we get the mean squared error, or L2 loss. Why don’t we use this loss for classification problems?

Loss comparisons: difference between MSE and cross entropy loss

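The cross entropy loss (log-loss) punishes the model much more heavily than MSE for confident wrong predictions. As the probability assigned to the correct class approaches 0, the log-loss grows without bound, while MSE on probability vectors stays bounded. For wrong predictions it also produces larger gradients, which tends to give faster convergence. Our model parameters learn using these gradients.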
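To make the comparison concrete, here is a small sketch in plain Python (no PyTorch needed). It assumes a 3-class problem where the correct class is index 2 and the remaining probability is split evenly between the other two classes. These predictions are made up for illustration.

Under that setup, the two losses have simple closed forms, where p is the probability of the correct class:
- MSE works out to (1 - p)^2 / 2, with derivative -(1 - p).
- Cross entropy is -log p, with derivative -1/p.

```python
import math

label = [0.0, 0.0, 1.0]  # one-hot label: the correct class is index 2

def mse(pred, target):
    # mean squared error between two probability vectors
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def cross_entropy(pred, target):
    # -sum target_i * log(pred_i); only the correct class has a nonzero target
    return sum(t * math.log(1 / p) for p, t in zip(pred, target) if t > 0)

for p_correct in [0.9, 0.5, 0.1, 0.01]:
    rest = (1 - p_correct) / 2      # hypothetical: remaining mass split evenly
    pred = [rest, rest, p_correct]
    mse_grad = -(1 - p_correct)     # d/dp of (1 - p)^2 / 2 under this setup
    ce_grad = -1 / p_correct        # d/dp of -log p
    print(f"p={p_correct:<5} MSE={mse(pred, label):.4f} (grad {mse_grad:.2f})  "
          f"CE={cross_entropy(pred, label):.4f} (grad {ce_grad:.2f})")
```

As p drops toward 0, the MSE in this setup never exceeds 0.5 and its slope stays between -1 and 0. The cross entropy and its slope, by contrast, blow up.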

Information Theory and Likelihood: Two ways to think about cross entropy loss

Information Theory

The term cross-entropy comes from Information Theory. To understand cross-entropy, we need to first understand what information content and entropy are.

If something surprises you, it holds some new information: information that you were not expecting. For example, rain in a desert region during summer is surprising. It gives you new information about weather that is typically dry. Why is this considered new? Because the chance of such an event occurring is very low. This new information is quantified as the Information Content (IC) of the event. It grows as the probability of the event shrinks:

$$IC(x) = -\log p(x) = \log \frac{1}{p(x)}$$

Here x is the event, p(x) is the probability of the event, and log is the natural logarithm, which matches the numeric values in this article. A certain event, with p(x) = 1, has IC = 0. An event with p(x) = 0 never occurs, so by convention it contributes 0 to the entropy.
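A minimal sketch of the information content formula in Python, using the natural logarithm (so IC is measured in nats). The event names and probabilities below are hypothetical. They are chosen only to show that rarer events carry more information.

```python
import math

def information_content(p):
    # IC(x) = log(1 / p(x)) = -log p(x), in nats; not defined for p(x) = 0
    if not 0 < p <= 1:
        raise ValueError("probability must be in (0, 1]")
    return math.log(1 / p)

# Hypothetical events: a certain one, a common one, and a rare one
for name, p in [("certain", 1.0), ("common", 0.9), ("rare", 0.01)]:
    print(f"{name:<8} p={p:<5} IC={information_content(p):.4f}")
```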

Entropy is simply how surprised we are, on average, when we observe outcomes drawn from a distribution. In more mathematical terms, it is the expected value of IC:

$$H(p) = \mathbb{E}_{x \sim p}[IC(x)] = -\sum_{x} p(x) \log p(x)$$

The probability distribution and information content of the team winning

The entropy for the above system is: 1.0732

Entropy is a nice measure of uncertainty in a distribution. The more uncertain the distribution, the more surprised its observer is on average, and the higher the entropy.
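A small sketch of the entropy formula, again in nats. The four 3-outcome distributions are hypothetical. They are not the team-winning distribution from the figure, whose exact values are not given here. They range from fully certain to fully uncertain.

```python
import math

def entropy(p):
    # H(p) = sum p(x) * log(1 / p(x)), in nats; terms with p(x) = 0 contribute 0
    return sum(px * math.log(1 / px) for px in p if px > 0)

# Hypothetical 3-outcome distributions, from fully certain to fully uncertain
distributions = {
    "one-hot":   [0.0, 0.0, 1.0],
    "confident": [0.9, 0.05, 0.05],
    "uncertain": [0.4, 0.3, 0.3],
    "uniform":   [1 / 3, 1 / 3, 1 / 3],
}
for name, p in distributions.items():
    print(f"{name:<10} H = {entropy(p):.4f}")
```

The one-hot distribution has zero entropy, and the uniform one reaches the maximum for 3 outcomes, log 3.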

Now that we understand entropy, what is cross entropy? Cross entropy measures information content using the predicted (or estimated) probability distribution instead of the actual one. Since the data is drawn from the actual population, the expectation is still taken over the actual distribution.

Let the estimated distribution be q and the actual distribution be p. Then the cross entropy is

$$H(p, q) = -\sum_{x} p(x) \log q(x)$$

Why does this make sense for classification? Our labels for the classification task are one-hot encoded vectors. These vectors can be thought of as probability distributions in which one class has a probability of 1, as seen in the following figure (left). This serves as our actual distribution p. The entropy of this distribution (left) is 0.

(left) Input labels as one-hot encoded vectors, p. (right) Predicted distribution from the model, q.

Our model predicts a probability for each class on an input. This is our q. The cross entropy now, in a way, tells us the difference between p and q. In this case, the cross entropy also equals the KL divergence, because $D_{KL}(p \parallel q) = H(p, q) - H(p)$ and the entropy of the one-hot p is 0. Since p is 0 for all the incorrect classes and 1 for the correct class c, the cross entropy reduces to

$$H(p, q) = -\log q(c)$$

which is the predicted information content of the correct event.

The cross entropy for the above p and q is: 0.6931
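The reduction above can be checked numerically. This sketch implements three quantities side by side:
- the general formula H(p, q) = -Σ p(x) log q(x), written as p(x) log(1/q(x));
- the entropy H(p);
- the KL divergence.

The prediction q is hypothetical. It assigns 0.5 to the correct class, and any prediction that does so gives ln 2 ≈ 0.6931, whatever it assigns to the other classes.

```python
import math

def entropy(p):
    # H(p) = sum p(x) * log(1 / p(x)); zero-probability terms contribute 0
    return sum(px * math.log(1 / px) for px in p if px > 0)

def cross_entropy(p, q):
    # H(p, q) = sum p(x) * log(1 / q(x)); terms where p(x) = 0 are skipped
    return sum(px * math.log(1 / qx) for px, qx in zip(p, q) if px > 0)

def kl_divergence(p, q):
    # D_KL(p || q) = H(p, q) - H(p)
    return cross_entropy(p, q) - entropy(p)

p = [0.0, 1.0, 0.0]  # one-hot label: the correct class c is index 1
q = [0.3, 0.5, 0.2]  # hypothetical model prediction

print(cross_entropy(p, q))  # general formula
print(-math.log(q[1]))      # reduced form -log q(c); both are ln 2
print(kl_divergence(p, q))  # same value, since H(p) = 0 for a one-hot p
```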

This serves as our loss. Our goal is to use gradient descent to update the model parameters so that the output distribution q gets as close to p as possible, which minimizes the loss.

Another reason we don’t need to worry about the predicted probabilities of the incorrect classes: since all of them must sum to 1, increasing the probability of the correct class necessarily decreases the others.
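A rough sketch of that training loop, in plain Python. It makes a few assumptions that are not stated above:
- The output layer is a softmax, which is a common choice.
- The three logits themselves are treated as the parameters being learned; a real model would produce them from its weights.
- The starting logits and learning rate are arbitrary.

For softmax followed by cross entropy, the gradient with respect to the logits is q - p. Because softmax always sums to 1, pushing up the correct class's probability pulls the others down.

```python
import math

def softmax(z):
    # subtract the max logit for numerical stability; outputs are positive and sum to 1
    m = max(z)
    exps = [math.exp(zi - m) for zi in z]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(p, q):
    # H(p, q) = sum p(x) * log(1 / q(x)), skipping terms where p(x) = 0
    return sum(px * math.log(1 / qx) for px, qx in zip(p, q) if px > 0)

p = [0.0, 0.0, 1.0]   # one-hot label: the correct class is index 2
z = [1.0, 0.5, -1.0]  # hypothetical starting logits (stand-ins for model parameters)
lr = 0.5              # hypothetical learning rate

losses = []
for step in range(50):
    q = softmax(z)
    losses.append(cross_entropy(p, q))
    # for softmax followed by cross entropy, dLoss/dz = q - p
    grad = [qx - px for qx, px in zip(q, p)]
    z = [zx - lr * g for zx, g in zip(z, grad)]
    if step % 10 == 0:
        print(f"step {step:2d}  loss={losses[-1]:.4f}  q={[round(x, 3) for x in q]}")

print("final q:", [round(x, 3) for x in softmax(z)])
```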

Likelihood

To give you an intuitive definition, likelihood is a measure of how well the model, with its current parameters, fits the dataset (the dataset is given and fixed). This is different from probability, which measures the chance of producing a data point given the model.

For our problem, we are interested in maximizing this likelihood. In other words, the model has to learn to make good predictions. This is maximum likelihood estimation in an intuitive sense: the estimator’s task is to find the parameters that make the model fit the data best.
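To make this concrete, here is a sketch that scores two hypothetical models on the same small labeled dataset. It assumes the examples are independent, so the likelihood of the labels is the product of the probabilities each model assigns to the correct class. The numbers are invented for illustration.

```python
import math

# Hypothetical probabilities each model assigns to the correct class
# of the same 4 labeled examples (invented numbers, not real model outputs)
model_a = [0.9, 0.8, 0.7, 0.95]
model_b = [0.6, 0.4, 0.5, 0.3]

def likelihood(correct_class_probs):
    # assuming independent examples, the likelihood of the labels is the product
    return math.prod(correct_class_probs)

print("model A:", likelihood(model_a))
print("model B:", likelihood(model_b))
```

Model A puts more probability on the correct labels, so it has the higher likelihood. Maximum likelihood estimation searches for the parameters that push this number up.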

Now instead of using likelihood, we use log-likelihood. Why?

  1. If x > y, then log x > log y, because log is a strictly increasing (monotonic) function. So maximizing the log-likelihood gives the same parameters as maximizing the likelihood.
  2. Log converts a chain of multiplications into a sum. Multiplying many probabilities produces extremely small numbers that floating-point formats cannot represent (they underflow to 0). Summing their logs avoids this.
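A quick sketch of both points, assuming a hypothetical dataset of 1,000 examples where the model gives the correct class probability 0.1 each time. The raw product underflows in 64-bit floats, while the sum of logs stays representable.

```python
import math

# Hypothetical dataset: 1,000 examples, each given probability 0.1 for its correct class
probs = [0.1] * 1000

likelihood = math.prod(probs)                     # true value 1e-1000 underflows to 0.0
log_likelihood = sum(math.log(p) for p in probs)  # roughly -2302.6, easy to represent

print(likelihood, log_likelihood)

# Point 1: log is strictly increasing, so it preserves comparisons
print(0.3 > 0.2, math.log(0.3) > math.log(0.2))
```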

There are a few other reasons too, visible in the log-loss plots from the MSE comparison above.

In deep neural networks, we use gradient descent to minimize the loss. Since gradient descent minimizes rather than maximizes, we convert the likelihood maximization problem into a minimization problem by negating the log-likelihood. This gives us the negative log-likelihood (NLL), which serves as the loss function we need to minimize. The NLL loss is low when the predicted probabilities for the correct category are high, and vice versa. When the likelihood of the correct class is high, the log-likelihood is high, so the negative log-likelihood is low. By minimizing it, we train the model to get better at making predictions.
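The two views meet here. With one-hot labels, the cross entropy for one example is -log q(c), which is exactly that example's negative log-likelihood. Averaging either over a dataset therefore gives the same loss. The sketch below checks this on a made-up batch of three predictions with integer class labels.

```python
import math

# Hypothetical batch: predicted distributions q (one row per example) and class labels
preds = [
    [0.7, 0.2, 0.1],
    [0.1, 0.6, 0.3],
    [0.2, 0.2, 0.6],
]
labels = [0, 1, 2]

def nll(preds, labels):
    # negative log-likelihood of the correct classes, averaged over the batch
    return -sum(math.log(q[y]) for q, y in zip(preds, labels)) / len(labels)

def mean_cross_entropy(preds, labels):
    # cross entropy H(p, q) with a one-hot p built from each label, averaged over the batch
    total = 0.0
    for q, y in zip(preds, labels):
        p = [1.0 if k == y else 0.0 for k in range(len(q))]
        total += sum(pk * math.log(1 / qk) for pk, qk in zip(p, q) if pk > 0)
    return total / len(labels)

print(nll(preds, labels), mean_cross_entropy(preds, labels))
```

The same identity explains the input conventions of PyTorch, which the first snippet uses. `torch.nn.functional.nll_loss` expects log-probabilities, while `torch.nn.functional.cross_entropy` takes raw logits and applies log-softmax internally before the NLL step.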

Thanks for reading!



Trending AI/ML Article Identified & Digested via Granola by Ramsey Elbasheer; a Machine-Driven RSS Bot
