Deep Q-Network, with PyTorch

Explaining the fundamentals of model-free RL algorithms: Deep Q-Network Model (with code!)

Photo by Mathias P.R. Reding on Unsplash

In Q-learning, we represent the Q-values as a table. However, many real-world problems have enormous state and/or action spaces, and a tabular representation is insufficient for them. For instance, computer Go has about 10¹⁷⁰ states, and games like Mario Bros. have a continuous state space. When it is impossible to store the values of all state-action pairs in a 2-D array (the Q-table), we need to use a Deep Q-Network (DQN) instead of the tabular Q-learning algorithm. [1]
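To make the contrast concrete, here is a minimal sketch of the tabular update that DQN replaces. The table has one entry per (state, action) pair, so it needs n_states × n_actions cells; that is manageable for the 16 × 4 toy table assumed here, but not for 10¹⁷⁰ states. The function name `q_learning_update` and the values of `alpha` and `gamma` are illustrative assumptions.

```python
import torch

# Tabular Q-learning: one table entry per (state, action) pair.
n_states, n_actions = 16, 4
Q = torch.zeros(n_states, n_actions)


def q_learning_update(Q, s, a, r, s_next, done, alpha=0.5, gamma=0.99):
    """Q[s, a] += alpha * (target - Q[s, a]), with target = r + gamma * max_a' Q[s', a']."""
    target = r if done else r + gamma * Q[s_next].max().item()
    Q[s, a] += alpha * (target - Q[s, a])


# Hypothetical transition: in state 0, action 1 gives reward 1 and leads to state 4
q_learning_update(Q, s=0, a=1, r=1.0, s_next=4, done=False)
print(Q[0].tolist())  # [0.0, 0.5, 0.0, 0.0]
```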

DQN is also a model-free RL algorithm, one that brings in modern deep learning techniques. It uses Q-learning to learn the best action to take in a given state, and a deep neural network, such as a convolutional neural network, to estimate the Q-value function.

The network takes an 84 × 84 × 4 image as input (four stacked 84 × 84 frames) and passes it through three convolutional layers and two fully connected layers; the final layer has a single output, the Q-value, for each valid action. [1]

An illustration of DQN architecture [1]
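The architecture above can be sketched as a PyTorch module. The filter counts, kernel sizes, strides, and the 512-unit hidden layer are an assumption taken from the Nature DQN paper (Mnih et al., 2015), since the figure's exact values are not reproduced here; the class name `DQNConvNet` is illustrative.

```python
import torch
import torch.nn as nn


class DQNConvNet(nn.Module):
    """3 conv layers + 2 fully connected layers; one Q-value output per action."""

    def __init__(self, n_actions: int):
        super().__init__()
        # Layer sizes assumed from Mnih et al. (2015); spatial size shown per layer
        self.features = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(),   # 84 -> 20
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),  # 20 -> 9
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),  # 9 -> 7
        )
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512), nn.ReLU(),
            nn.Linear(512, n_actions),  # one output (Q-value) per valid action
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # PyTorch is channels-first, so the 84 x 84 x 4 input is (batch, 4, 84, 84)
        return self.head(self.features(x))


net = DQNConvNet(n_actions=6)
frames = torch.rand(2, 4, 84, 84)  # a batch of two random stacked-frame inputs
q_values = net(frames)
print(q_values.shape)  # torch.Size([2, 6])
```

The 64 × 7 × 7 = 3136 input size of the first linear layer follows from the convolution arithmetic: each layer maps a side of length n to (n − kernel) / stride + 1.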

DQN Algorithm

DQN Algorithm [1]
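The algorithm figure is not reproduced here. Assuming it shows the usual deep Q-learning loop with experience replay, each step acts epsilon-greedily, stores the transition, samples a random minibatch, builds targets, and takes one gradient step on the squared error. The sketch below runs that loop on a hypothetical five-cell corridor; the environment (`env_step`), the `encode` helper, and every hyperparameter are made up for illustration.

```python
import random
from collections import deque

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical toy environment (not from the article): a corridor of N_POS cells.
# Action 0 moves left (blocked at cell 0), action 1 moves right; reaching the
# last cell gives reward 1 and ends the episode.
N_POS, N_ACTIONS = 5, 2


def env_step(pos, action):
    pos = max(0, pos - 1) if action == 0 else pos + 1
    done = pos == N_POS - 1
    return pos, (1.0 if done else 0.0), done


def encode(positions):
    """One-hot encode a sequence of cell indices into a (batch, N_POS) float tensor."""
    return F.one_hot(torch.tensor(positions), N_POS).float()


torch.manual_seed(0)
random.seed(0)
q_net = nn.Sequential(nn.Linear(N_POS, 32), nn.ReLU(), nn.Linear(32, N_ACTIONS))
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-2)
replay = deque(maxlen=1_000)  # replay memory D
gamma, epsilon, batch_size = 0.9, 0.2, 32

for episode in range(100):
    pos, done, steps = random.randrange(N_POS - 1), False, 0  # random start cell
    while not done and steps < 30:
        # Epsilon-greedy action selection from the current Q-network
        if random.random() < epsilon:
            action = random.randrange(N_ACTIONS)
        else:
            with torch.no_grad():
                action = int(q_net(encode([pos])).argmax(dim=1).item())
        next_pos, reward, done = env_step(pos, action)
        replay.append((pos, action, reward, next_pos, done))  # store transition
        pos, steps = next_pos, steps + 1

        if len(replay) < batch_size:
            continue
        # Sample a random minibatch of transitions from replay memory
        s, a, r, s2, d = zip(*random.sample(replay, batch_size))
        a = torch.tensor(a)
        r = torch.tensor(r)
        d = torch.tensor(d, dtype=torch.float32)
        # Target y = r for terminal s', else r + gamma * max_a' Q(s', a', w)
        with torch.no_grad():
            y = r + gamma * (1.0 - d) * q_net(encode(s2)).max(dim=1).values
        q_sa = q_net(encode(s)).gather(1, a.unsqueeze(1)).squeeze(1)
        loss = F.mse_loss(q_sa, y)  # squared error (y - Q(s, a, w))^2, averaged
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

with torch.no_grad():
    greedy = q_net(encode(list(range(N_POS - 1)))).argmax(dim=1).tolist()
print("greedy action per non-terminal cell:", greedy)
```

If training goes well, the greedy action should be 1 (move right) in every cell, but this short sketch makes no convergence guarantee.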

Main Component of DQN — 1. Q-value function

In DQN, we represent the Q-value function with a neural network whose weights are w:

Q-value function. Image by Author, derived from [1].
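The figure itself is not reproduced here. Assuming it shows the usual formulation, the network with weights w approximates the optimal action-value function Q*:

```latex
Q(s, a, w) \approx Q^{*}(s, a)
```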
  • The Q-network plays the role of the Q-table in Q-learning when selecting actions. While the states in tabular Q-learning must be countable and finite, the states in DQN can be finite or infinite, discrete or continuous.
  • The Q-network is updated by adjusting its weights w, rather than by overwriting individual table entries.
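The two points above can be sketched in a few lines: a small fully connected Q-network takes a continuous state vector and outputs one Q-value per action, and picking the largest output plays the role of reading a row of the Q-table. The epsilon-greedy rule is an assumption here (a common exploration choice in DQN), and `QNetwork`, `select_action`, and the layer sizes are illustrative names and values.

```python
import random

import torch
import torch.nn as nn


class QNetwork(nn.Module):
    """Maps a (possibly continuous) state vector to one Q-value per discrete action."""

    def __init__(self, state_dim: int, n_actions: int, hidden: int = 64):
        super().__init__()
        self.n_actions = n_actions
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.net(state)


def select_action(q_net: QNetwork, state: torch.Tensor, epsilon: float) -> int:
    """Epsilon-greedy: a random action with probability epsilon, else argmax_a Q(s, a, w)."""
    if random.random() < epsilon:
        return random.randrange(q_net.n_actions)
    with torch.no_grad():
        # Like reading the row for state s in a Q-table and taking its largest entry
        return int(q_net(state.unsqueeze(0)).argmax(dim=1).item())


q_net = QNetwork(state_dim=4, n_actions=2)
state = torch.tensor([0.02, -0.01, 0.03, 0.04])  # an arbitrary continuous state
print(select_action(q_net, state, epsilon=0.0))   # greedy action, 0 or 1
```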

Main Component of DQN — 2. Loss function

Let’s define the objective function as the mean squared error in the Q-values.

Loss function. Image by Author, derived from [1].
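Written out in the usual notation (an assumption about what the figure shows), for a transition (s, a, r, s') and discount factor γ, the term r + γ max_{a'} Q(s', a', w) serves as the target:

```latex
L(w) = \mathbb{E}\left[\left(r + \gamma \max_{a'} Q(s', a', w) - Q(s, a, w)\right)^{2}\right]
```

When s' is terminal, the target reduces to r.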

This loss measures the error in the Q-network's estimates; minimizing it with respect to w is how the weights of the Q-network are updated.
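As a minimal sketch of this loss in PyTorch, assuming a batch of transitions stored as tensors: the target is computed under `torch.no_grad()` so that it acts as a fixed value, and a `dones` flag zeroes the bootstrap term for terminal transitions. The function name `dqn_loss`, the network shape, and `gamma=0.99` are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def dqn_loss(q_net, states, actions, rewards, next_states, dones, gamma=0.99):
    """Mean squared TD error over a batch of transitions (s, a, r, s', done)."""
    # Q(s, a, w) for the action actually taken in each transition
    q_sa = q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    # Target r + gamma * max_a' Q(s', a', w), held constant (no gradient flows
    # through it); for terminal transitions (done = 1) the target is just r
    with torch.no_grad():
        max_next_q = q_net(next_states).max(dim=1).values
        targets = rewards + gamma * (1.0 - dones) * max_next_q
    return F.mse_loss(q_sa, targets)


torch.manual_seed(0)
q_net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))
states, next_states = torch.randn(8, 4), torch.randn(8, 4)
actions = torch.randint(0, 2, (8,))
rewards = torch.ones(8)
dones = torch.zeros(8)
loss = dqn_loss(q_net, states, actions, rewards, next_states, dones)
print(loss.item())  # a scalar; gradients reach w only through Q(s, a, w)
```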

Main Component of DQN — 3. Optimization algorithms

Let’s use stochastic gradient descent to optimize the above objective function, using the gradient ∂L(w)/∂w. Many other optimization algorithms are available in TensorFlow and PyTorch, for example Adam, RMSProp, and Adagrad.
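A minimal sketch of one optimization step in PyTorch, assuming the targets have already been computed: autograd supplies ∂L(w)/∂w, and the optimizer object decides how to use it, so switching between SGD, Adam, RMSProp, and Adagrad is a one-line change. The network shape, learning rates, and the helper name `train_step` are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
q_net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))
# Adam is used here; any torch.optim optimizer can be swapped in, for example
#   torch.optim.SGD(q_net.parameters(), lr=1e-2)
#   torch.optim.RMSprop(q_net.parameters(), lr=2.5e-4)
#   torch.optim.Adagrad(q_net.parameters(), lr=1e-2)
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)


def train_step(states, actions, targets):
    """One gradient-based update on the MSE between Q(s, a, w) and fixed targets."""
    q_sa = q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    loss = F.mse_loss(q_sa, targets)
    optimizer.zero_grad()  # clear gradients from the previous step
    loss.backward()        # autograd fills in dL(w)/dw for every weight
    optimizer.step()       # update the weights using those gradients
    return loss.item()


# Usage on a fixed random batch: the loss should shrink as the weights are updated
states = torch.randn(16, 4)
actions = torch.randint(0, 2, (16,))
targets = torch.randn(16)
losses = [train_step(states, actions, targets) for _ in range(200)]
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```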

Main Component of DQN — 4. Experience replay

Naive Q-learning with neural networks tends to oscillate or diverge, for three reasons:

  • The data is sequential: successive samples are correlated rather than independent and identically distributed (i.i.d.).
  • The policy changes, or oscillates, rapidly with slight changes to the Q-values, so the distribution of the data can shift from one extreme to another.
  • The scale of rewards and Q-values is unknown, so the gradients of naive Q-learning can be highly unstable when backpropagated. [2]

To address these problems, we can store transitions in a replay buffer and sample small random batches of experience from it to update the Q-network. Experience replay breaks the correlation between successive samples, and it also lets the network make better use of its experience, since each stored transition can be reused in many updates. [1]
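A replay buffer can be sketched with a bounded `deque` and uniform random sampling. The class and method names (`ReplayBuffer`, `Transition`, `push`, `sample`) and the assumption that states are 1-D tensors are illustrative choices, not taken from the article.

```python
import random
from collections import deque, namedtuple

import torch

Transition = namedtuple("Transition", ["state", "action", "reward", "next_state", "done"])


class ReplayBuffer:
    """Fixed-size store of past transitions; the oldest are dropped once full."""

    def __init__(self, capacity: int):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size: int):
        # Uniform random sampling breaks up the correlation between consecutive steps
        batch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.stack(states),                      # (batch, state_dim)
            torch.tensor(actions),                    # (batch,) int64
            torch.tensor(rewards, dtype=torch.float32),
            torch.stack(next_states),
            torch.tensor(dones, dtype=torch.float32), # 1.0 marks a terminal s'
        )

    def __len__(self):
        return len(self.memory)


buffer = ReplayBuffer(capacity=10_000)
for t in range(100):  # fill with random placeholder transitions
    buffer.push(torch.randn(4), random.randrange(2), 0.0, torch.randn(4), t % 10 == 9)
states, actions, rewards, next_states, dones = buffer.sample(32)
print(states.shape, actions.shape)  # torch.Size([32, 4]) torch.Size([32])
```

The sampled tensors have the shapes a batched loss function expects, so each gradient step sees a decorrelated mix of old and new experience.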

