How to Learn Multiple Tasks with a Single Neural Network


In 2016, researchers at DeepMind published a paper that tackles this problem. I especially like their approach because it's not that complicated. All they are really doing is applying a special type of regularization to the network. Let's take a closer look.

Say we have two tasks, A and B. Training happens in blocks: first A is trained on many examples, and then we switch over to training B. The DeepMind researchers propose to first train A normally (i.e. regular gradient descent/backpropagation). Then, during the B block, we keep the weights learned from A and continue with gradient descent. The only difference is that we now include a quadratic regularization term that is unique to each weight. The idea is to use this per-weight regularization to penalize moving away from the weights learned for A. Weights that are deemed more important to A will be penalized more heavily. Mathematically, our cost function during the B training block is L(θ) = L_B(θ) + Σ_i k_i (θ_Ai − θ_i)², where i is an index over all the network weights, and θ_Ai are the weights right after the A training block completed. L_B(θ) is the normal cost function for B, typically squared error or log-loss. Finally, k_i is the importance of weight i for predictions on task A.
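To make this concrete, here is a minimal sketch of that cost function in PyTorch. The names (`ewc_loss`, `theta_A`, `importance`, `lam`) are my own placeholders, not from the paper:

```python
import torch

def ewc_loss(model, task_b_loss, theta_A, importance, lam=1.0):
    """Task-B loss plus the per-weight quadratic penalty.

    theta_A:    dict mapping parameter name -> weights snapshotted
                right after the A training block.
    importance: dict mapping parameter name -> k_i, the importance
                of each weight to task A.
    lam:        overall strength of the penalty (a hyperparameter).
    """
    penalty = 0.0
    for name, theta in model.named_parameters():
        # Sum of k_i * (theta_Ai - theta_i)^2 over every weight
        penalty = penalty + (importance[name] * (theta_A[name] - theta) ** 2).sum()
    return task_b_loss + lam * penalty
```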

There's also a more intuitive way to think about the per-weight regularization. Imagine a physical spring. The further you pull on a spring, the harder it pulls back. Also, some springs are stronger than others. How does this relate to our algorithm? You can imagine a spring attached to each weight in the neural network, anchored at that weight's value after the A block. The relative strengths of the springs are the per-weight regularization coefficients. Springs on weights that are important to A will be extremely strong, so during the training of B, the algorithm will be discouraged from pulling on them, and the corresponding weights won't change much. Instead, the algorithm will pull on the weaker springs, and the weights attached to those will change more.

One more way to think about this algorithm is as an improvement on L2 regularization. With an L2-style penalty anchored at the task-A weights, every weight is discouraged from moving, with a penalty equal to the sum of squared distances from those old weights. However, all the weights are punished equally. In our algorithm, only the weights that are important to A are held in place; the rest are free to adapt to B.
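In code, the difference is just whether the anchoring coefficient is one shared scalar or a per-weight tensor. Continuing the hypothetical names from the sketch above:

```python
# Plain L2-style anchoring: every weight is pulled back equally hard.
l2_penalty = sum(((theta_A[name] - theta) ** 2).sum()
                 for name, theta in model.named_parameters())

# EWC: each weight i gets its own coefficient k_i.
ewc_penalty = sum((importance[name] * (theta_A[name] - theta) ** 2).sum()
                  for name, theta in model.named_parameters())
```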

OK, so we now understand intuitively how this algorithm works. By keeping the weights important to A relatively constant, we preserve performance on A while training successfully on B. However, we still haven't explained how to determine which weights are "important" for A. So let's ask the question: what makes a weight important? A reasonable answer might be: a weight is important if it affects the final prediction more than other weights do. More specifically, we can say a weight is important if the derivative of the prediction (or, more precisely, the log-likelihood) with respect to that weight has a higher magnitude than the derivatives with respect to other weights. We're missing one thing, though: because weights in a neural network affect one another, their derivatives are linked. In other words, we can't just consider the derivative with respect to a given weight on its own; we need to look at the covariance matrix of all the weight derivatives. A more formal version of this object is called the Fisher Information Matrix, which is what the researchers ended up using (in practice, only its diagonal, which yields exactly one importance value k_i per weight).
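The diagonal of the Fisher matrix can be estimated by averaging squared gradients of the log-likelihood over task-A data. Here is a rough sketch, again assuming PyTorch and a classifier; `data_loader_A` is a placeholder, and using the observed labels (rather than labels sampled from the model) is a common simplification:

```python
import torch

def estimate_diagonal_fisher(model, data_loader_A, num_batches=100):
    """Estimate F_ii ~ E[(d log p(y|x, theta) / d theta_i)^2] on task-A data."""
    fisher = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
    model.eval()
    batches = 0
    for x, y in data_loader_A:
        if batches >= num_batches:
            break
        model.zero_grad()
        log_probs = torch.log_softmax(model(x), dim=1)
        # Log-likelihood of the observed labels under the model.
        # (Strictly, the Fisher calls for per-example gradients; averaging
        # over the batch first is a common, cheaper approximation.)
        log_likelihood = log_probs[torch.arange(len(y)), y].mean()
        log_likelihood.backward()
        for name, p in model.named_parameters():
            fisher[name] += p.grad.detach() ** 2
        batches += 1
    return {name: f / max(batches, 1) for name, f in fisher.items()}
```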

So, to summarize: first we train A normally, then we train B with a per-weight quadratic penalty. The per-weight coefficients reflect the relative importance of each weight to A, and can be computed from the (diagonal of the) Fisher Information Matrix. The result is a single neural network that works well on both A and B. The name the researchers gave this method is "Elastic Weight Consolidation" (EWC), a fitting name given the spring analogy above.
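Putting the pieces together, a task-B training loop might look like the following sketch, reusing the hypothetical `ewc_loss` and `estimate_diagonal_fisher` helpers from above (the learning rate and `lam` are arbitrary illustrative values):

```python
import torch
import torch.nn.functional as F

# 1. After training on task A normally, snapshot weights and importances.
theta_A = {name: p.detach().clone() for name, p in model.named_parameters()}
importance = estimate_diagonal_fisher(model, data_loader_A)

# 2. Train on task B with the per-weight penalty preserving A's knowledge.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for x, y in data_loader_B:
    optimizer.zero_grad()
    task_b_loss = F.cross_entropy(model(x), y)
    loss = ewc_loss(model, task_b_loss, theta_A, importance, lam=100.0)
    loss.backward()
    optimizer.step()
```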
