How to Optimize Learning Rate with TensorFlow — It’s Easier Than You Think




How to optimize learning rate in TensorFlow

Optimizing the learning rate is easy once you get the gist of it. The idea is to start small, say at 0.001, and increase the value every epoch. Accuracy will eventually get poor as the learning rate grows, but that’s expected. Don’t worry about it, as we’re only interested in how the loss changes as the learning rate changes.

Let’s start by importing TensorFlow and setting the seed so you can reproduce the results:

```python
import tensorflow as tf
tf.random.set_seed(42)
```

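We’ll train the model for 100 epochs to test 100 different learning rate/loss combinations. The learning rate starts at 0.001 and is multiplied by 10 every 30 epochs, following lr = 1e-3 * 10 ** (epoch / 30). Here’s the range of learning rate values: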

Image 4 — Range of learning rate values (image by author)

A learning rate of 0.001 is the default for optimizers such as Adam, while values near the top of the range, around 2, are definitely too large.
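To see the numbers behind that range, here is a minimal sketch that evaluates the same formula used in the training code in plain Python. The helper name `lr_schedule` is our own, not part of TensorFlow, and the sketch assumes epochs are counted from 0, which is how Keras’s `LearningRateScheduler` passes them.

```python
# A sketch of the exponential learning rate schedule used in this article.
# `lr_schedule` is a hypothetical helper name; Keras only needs a function
# that maps an epoch index (counted from 0) to a learning rate.

def lr_schedule(epoch: int) -> float:
    """Start at 1e-3 and grow tenfold every 30 epochs."""
    return 1e-3 * 10 ** (epoch / 30)


# One learning rate per epoch for a 100-epoch run (epochs 0..99).
learning_rates = [lr_schedule(epoch) for epoch in range(100)]

print(f"first: {learning_rates[0]:.4f}")      # learning rate at epoch 0
print(f"epoch 30: {learning_rates[30]:.4f}")  # ten times the starting value
print(f"last: {learning_rates[-1]:.4f}")      # learning rate at epoch 99
```

Running it prints 0.0010, 0.0100 and 1.9953. Because the epoch index starts at 0, the final epoch uses a learning rate of about 2.0; plugging in epoch 100 instead would give about 2.15.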

Next, let’s define the neural network architecture, compile the model, and train it. The only new piece here is the LearningRateScheduler callback, which lets us pass this learning rate schedule as a lambda function of the epoch number.

Here’s the entire code:

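```python
# X_train_scaled and y_train are assumed to be defined earlier
# (scaled training features and binary 0/1 labels).
initial_model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')  # probability of the positive class
])

initial_model.compile(
    loss=tf.keras.losses.binary_crossentropy,
    optimizer=tf.keras.optimizers.Adam(),  # default learning rate is 0.001
    metrics=[
        tf.keras.metrics.BinaryAccuracy(name='accuracy')
    ]
)

initial_history = initial_model.fit(
    X_train_scaled,
    y_train,
    epochs=100,
    callbacks=[
        # Called at the start of every epoch with the 0-based epoch index;
        # the learning rate starts at 1e-3 and grows tenfold every 30 epochs
        tf.keras.callbacks.LearningRateScheduler(
            lambda epoch: 1e-3 * 10 ** (epoch / 30)
        )
    ]
)
```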

Once training starts, you’ll see decent accuracy right away, around 75%, but it drops somewhere past epoch 50 because the learning rate becomes too large. After 100 epochs, initial_model had around 60% accuracy.
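Since what we really care about is how the loss responds to the learning rate, a natural next step is to line up each epoch’s learning rate with that epoch’s training loss. Here is a minimal sketch in plain Python. It assumes you pass in `initial_history.history['loss']` (Keras records one training loss per epoch under that key) together with the same schedule; `lr_with_lowest_loss` is a hypothetical helper, not a TensorFlow function, and the loss values below are made up purely for illustration.

```python
# Pair each epoch's learning rate with its training loss and report where
# the loss bottomed out. `lr_with_lowest_loss` is a hypothetical helper.

def lr_with_lowest_loss(learning_rates, losses):
    """Return (learning_rate, loss) for the epoch with the smallest loss."""
    if len(learning_rates) != len(losses):
        raise ValueError("need exactly one loss value per learning rate")
    return min(zip(learning_rates, losses), key=lambda pair: pair[1])


# Same schedule as the LearningRateScheduler lambda (epochs counted from 0).
learning_rates = [1e-3 * 10 ** (epoch / 30) for epoch in range(5)]

# Made-up losses for illustration only; in practice you would use
# initial_history.history['loss'] from the 100-epoch run.
losses = [0.60, 0.55, 0.52, 0.58, 0.71]

best_lr, best_loss = lr_with_lowest_loss(learning_rates, losses)
print(f"lowest loss {best_loss:.2f} at learning rate {best_lr:.5f}")
```

A single minimum is only a rough pointer; plotting the full loss-versus-learning-rate curve gives a better picture of where the loss is still falling steadily.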


Trending AI/ML Article Identified & Digested via Granola by Ramsey Elbasheer; a Machine-Driven RSS Bot
