Image Captioning with an End to End Transformer Network.


This article is my attempt to elaborate on the attention mechanism and the Transformer Network, and on how they solve sequence-to-sequence problems, through image captioning with Transformer Networks.

Transformer Networks are deep learning models that learn context and meaning in sequential data by tracking the relationships between elements of a sequence. Since the introduction of Transformer Networks in 2017 by Google Brain in the revolutionary paper “Attention Is All You Need”, transformers have been outperforming conventional neural networks in various problem domains, like Neural Machine Translation, Text Summarization, Language Understanding, and other Natural Language Processing tasks. They have also proved quite effective in Computer Vision tasks, like Image Classification with Vision Transformers, and in Generative Networks as well.

In this article, I will elaborate on my understanding of the attention mechanism through Vision Transformers, and of sequence-to-sequence tasks through Transformer Networks.

Image Captioning as a Sequence to Sequence Problem

For problems in the image domain, like Image Classification and feature extraction from images, deep Convolutional Neural Network architectures like ResNet and Inception are used. These networks extract features through a series of convolution operations: convolution layers produce feature maps from images that are used to classify and recognize them. Instead of using a CNN architecture, however, we can treat image problems the way we treat text in NLP.

Visualizing the learned features of a convolutional layer

In Natural Language processing, we represent natural language as a sequence of tokens and try to learn the relationship between them. Each token has its representation as an n-dimensional vector also known as an embedding and these embeddings can be learned for language understanding. Similarly, we can represent an image as a sequence of patches, encode them with positional encodings to add information about their relative positions, and then represent each patch as an n-dimensional embedding. This will make the images ready as sequences.

Image representation as a series of patches and their embeddings
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

## Patch Extraction Layer
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

## Patch Encoder Layer
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

The Patches class breaks the image down into patches, and the PatchEncoder layer encodes the patches with positional encodings. The shapes produced with the chosen hyperparameters can be checked as follows:

resized_image = tf.image.resize(
    tf.convert_to_tensor([try_img]), size=(IMAGE_SIZE, IMAGE_SIZE)
)
patches = Patches(PATCH_SIZE)(resized_image)
## Checking the shapes
print(f"Shape of the resized image {resized_image.shape}")
print(f"Shape of the patches: {patches.shape}")
encoded_patches = PatchEncoder(NUM_PATCHES, PROJECTION_DIM)(patches)
print(f"Shape of the Encoded patches: {encoded_patches.shape}")


Shape of the resized image (1, 128, 128, 3)
Shape of the patches: (1, 64, 768)
Shape of the Encoded patches: (1, 64, 512)
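These shapes follow directly from the hyperparameters. Assuming IMAGE_SIZE = 128, PATCH_SIZE = 16, and PROJECTION_DIM = 512 (values inferred from the printed shapes above), the patch count and patch dimension work out as:

```python
# Hyperparameters inferred from the printed shapes above (an assumption).
IMAGE_SIZE = 128
PATCH_SIZE = 16

NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2  # 8 patches per side -> 8 * 8 = 64
PATCH_DIM = PATCH_SIZE * PATCH_SIZE * 3        # 16 * 16 pixels * 3 channels = 768

print(NUM_PATCHES, PATCH_DIM)  # 64 768
```

Each patch is a flattened 16x16x3 block of pixels (768 values), which the Dense projection then maps to the 512-dimensional embedding seen in the encoded output.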

The Transformer Network

The Transformer Network learns context, and thus meaning, in a sequence through the attention mechanism, which means it can extract information even from sequences of variable length. Below is a step-by-step walkthrough of the Transformer Network.

Why Transformers?

Before transformers, sequence-to-sequence problems were solved with Encoder-Decoder architectures built from RNN cells, where LSTM (Long Short-Term Memory) units were a popular choice, as they could retain information over comparatively longer sequences than traditional RNN cells. But even LSTM networks suffered from the vanishing gradient problem for long sequences, typically losing information once sequences grow beyond a length of about 100. Another issue with LSTM cells was that the computations carried out in an LSTM network are difficult to parallelize.

The Vanishing Gradient Problem

Transformers tackle this issue by looking at the entire sequence and assigning attention weights that, intuitively, signify the importance of each part of the sequence. With Transformers, the vanishing gradient problem is also avoided, and the computations can be carried out in parallel.

Transformer Network Architecture

The Transformer Network Architecture

The Transformer Network is an Encoder-Decoder architecture, but unlike other sequence-to-sequence architectures, Transformers rely on the attention mechanism instead of RNN cells to understand the relationship between sequences. The Encoder takes in the input embeddings and outputs a vector representation with the shape (batch size, sequence length, model dimensions). The Decoder takes in the target sequence as its input, first attends to itself for language understanding, and then computes its relationship with the encoder outputs. At a high level, it tries to figure out which part of the input is the relevant information for the corresponding target. The Decoder then represents this information as a vector and finally produces output probabilities.

The Attention Mechanism

Transformers learn the inter-relationship between sequences through the attention mechanism. Intuitively, attention is a weight on how important a token (or patch) is in the sequence considering our problem domain. Considering our problem of Image captioning, at the encoder, attention weights are assigned to each patch of the image. These weights signify what part of the image the model should attend to while performing a task. These weights are learned and updated over time as the model trains.

Below, the attention weights are visualized showing what part of the image each attention head is focusing on.

Visualization of Attention weights on an image.

Scaled Dot Product Attention

In the paper, the authors use Multi-Head Attention, which builds on scaled dot-product attention, a simple but powerful form of the attention mechanism.

The Query, Key, and Value:

Query, key, and value are Transformer terminologies with quite literal meanings. Each key vector is associated with a value vector and is queried by a query vector. In our problem domain, consider the following:

Each patch of the image is a key K, and each key K is associated with a value V. While computing Multi-Head Attention in the Decoder, the patches are queried with caption tokens Q.

If the token to be queried is the word “frisbee”, scaled dot-product attention assigns the highest weight to the keys corresponding to the patches of the image showing the frisbee.

Attention plots for images, taken from the paper “Show, Attend and Tell: Neural Image Caption Generation with Visual Attention”

Mathematically, the scaled dot product attention is calculated as:

Attention(Q, K, V) = softmax(QKᵀ / √dₖ) V

Here Q, K, and V are tensors that result from multiplying the input with the trained weight matrices Wq, Wk, and Wv. As these three weight matrices are learned during training, the query, key, and value vectors end up different despite the identical input. This is why the Multi-Head Attention discussed later makes sense.

In the expression for the scaled dot-product attention, Q and K are tensors, and closer query and key vectors have higher dot products. These dot products are scaled by √dk, where dk is the dimension of the key tensor. Applying the softmax function normalizes the scaled dot products into scores between 0 and 1. Finally, multiplying the softmax results with the value vectors pushes the values for low-scoring query-key pairs close to zero.

Below is the implementation of the scaled dot product attention in code:

def scaled_dot_product_attention(q, k, v, mask):
    """Calculate the attention weights.
    q, k, v must have matching leading dimensions.
    k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
    The mask has different shapes depending on its type (padding or look ahead)
    but it must be broadcastable for addition.
    Args:
        q: query shape == (..., seq_len_q, depth)
        k: key shape == (..., seq_len_k, depth)
        v: value shape == (..., seq_len_v, depth_v)
        mask: Float tensor with shape broadcastable
              to (..., seq_len_q, seq_len_k). Defaults to None.
    Returns:
        output, attention_weights
    """
    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

    # scale matmul_qk
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # add the mask to the scaled tensor.
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)

    # softmax is normalized on the last axis (seq_len_k) so that the scores
    # add up to 1.
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

    return output, attention_weights
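As a quick sanity check, the same computation can be performed step by step on random tensors to verify the output shapes. This is a sketch that inlines the steps of the function above; the shapes chosen here are purely illustrative:

```python
import tensorflow as tf

# Random query/key/value tensors: one batch, 4 query positions, 6 key positions, depth 8.
q = tf.random.normal((1, 4, 8))
k = tf.random.normal((1, 6, 8))
v = tf.random.normal((1, 6, 8))

# The same steps as scaled_dot_product_attention, inlined.
scores = tf.matmul(q, k, transpose_b=True)                    # (1, 4, 6)
scores /= tf.math.sqrt(tf.cast(tf.shape(k)[-1], tf.float32))  # scale by sqrt(dk)
weights = tf.nn.softmax(scores, axis=-1)                      # each row sums to 1
output = tf.matmul(weights, v)                                # (1, 4, 8)

print(output.shape, weights.shape)  # (1, 4, 8) (1, 4, 6)
```

Note that the attention weights have shape (query positions, key positions): every query position distributes a probability mass of 1 over all key positions.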

Multi-Head Attention:

In a Transformer Network, the attention module described earlier repeats its computations multiple times in parallel. Each of these computations is called an attention head, and their results are later combined to produce a final attention score.

As discussed earlier, for each Attention Head, Queries Q, Keys K, and Values V are calculated as results from learnable weight matrices Wq, Wk, and Wv. These weight matrices are different for each head, hence, for n number of heads n different weight matrices Wq, Wk and Wv are learned. Scaled dot product attention is carried out in each of the heads and finally, the results are concatenated together to form a rich representation.

Because different weight matrices are learned for each head, the Transformer can encode multiple relationships and nuances for each element in the sequence. With Multi-Head Attention, in our problem domain, each head can have a different representation for a patch in the Encoder, and each word can carry multiple relationships and nuances in the Decoder.

Mathematical representation of the multi-head Attention
Attention Plots for the same image with different attention heads.

TensorFlow implementation of Multi-Head Attention:

class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, *, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth).
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask)

        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
        concat_attention = tf.reshape(scaled_attention,
                                      (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        return output, attention_weights
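Keras also ships an equivalent built-in layer. The following sketch uses it purely to illustrate the shapes involved when caption tokens query image patches, as happens later in the decoder; the dimensions here are illustrative, not the article's hyperparameters:

```python
import tensorflow as tf

# Keras's built-in layer implements the same multi-head mechanism.
mha = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=64)

query = tf.random.normal((2, 10, 256))  # e.g. 10 caption tokens per sample
kv = tf.random.normal((2, 64, 256))     # e.g. 64 encoded image patches per sample

out, scores = mha(query, value=kv, key=kv, return_attention_scores=True)
print(out.shape)    # (2, 10, 256)  - one context vector per caption token
print(scores.shape) # (2, 4, 10, 64) - one attention map per head
```

The per-head attention maps are exactly what the attention visualizations in this article plot: for each caption token, a distribution over the 64 image patches.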

Image Captioning with Transformers:

The Architecture of the Image Captioning Model. Source: “CPTR: Full transformer network for Image Captioning”

The Transformer for image captioning consists of an Encoder and a Decoder, each built from Encoder layers and Decoder layers. Below, the components of the transformer are elaborated:

The Encoder

The Encoder of the transformer takes in an image of dimensions (batch_size, image_size, image_size, num_channels) as input and breaks it down into patches with the Patches class implemented above; positional embeddings are then added to the patches as discussed, and the result is passed on to the Encoder layer. The Encoder layer has a self-attention layer with residual connections that attends over the patches. The attention layer outputs a context vector, which is then passed to a point-wise feed-forward neural network. Finally, the Encoder outputs a tensor of shape (batch_size, seq_length, d_model).

However, as transformers have an extremely large number of learnable parameters, and their computational cost, though parallelized, is quite significant, a pretrained Vision Transformer from the vit-keras library is used for this implementation.

from vit_keras import vit

vit_model = vit.vit_b32(
    image_size=IMAGE_SIZE,
    activation='softmax',
    pretrained=True,
    include_top=False,
    pretrained_top=False,
)
new_input = vit_model.input
hidden_layer = vit_model.layers[-2].output
vision_transformer_model = tf.keras.Model(new_input, hidden_layer)

### The Encoder Class
class Encoder(tf.keras.layers.Layer):
    def __init__(self, d_model, vision_transformer):
        super(Encoder, self).__init__()
        self.vit = vision_transformer
        self.units = d_model
        self.dense = tf.keras.layers.Dense(self.units, activation=tf.nn.gelu)

    def call(self, x, training, mask):
        ## x: (batch, image_size, image_size, 3)
        x = self.vit(x)
        x = self.dense(x)
        return x

The Decoder

The Decoder takes in the target captions as input, positionally encodes them, and passes them through a masked self-attention layer. The self-attention layer is masked so that the decoder cannot see outputs from the future; this allows the decoder to understand the context of the target captions. The target captions are then used as the query for a Multi-Head Attention layer, with the encoder outputs as the keys and values. This allows the network to learn the relationships between the captions and the relevant parts of the images. During training, these attention weights are learned and backpropagated through the entire network. The Multi-Head Attention layer produces a context vector, which is finally projected to a set of output probabilities for prediction.
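The decoder and model code below reference three helpers that are not shown in this article: positional_encoding, create_padding_mask, and create_look_ahead_mask. Here is a sketch of them, following the standard implementations from the TensorFlow transformer tutorial (an assumption, since the article does not include them):

```python
import numpy as np
import tensorflow as tf

def create_padding_mask(seq):
    # 1.0 where the token id is 0 (padding), shaped to broadcast over attention logits.
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
    return seq[:, tf.newaxis, tf.newaxis, :]  # (batch, 1, 1, seq_len)

def create_look_ahead_mask(size):
    # Strictly upper-triangular mask: position i may not attend to positions > i.
    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)

def positional_encoding(length, depth):
    # Sinusoidal encodings from "Attention Is All You Need" (depth assumed even).
    positions = np.arange(length)[:, np.newaxis]                       # (length, 1)
    div_term = np.exp(np.arange(0, depth, 2) * (-np.log(10000.0) / depth))
    pe = np.zeros((length, depth))
    pe[:, 0::2] = np.sin(positions * div_term)   # even indices: sine
    pe[:, 1::2] = np.cos(positions * div_term)   # odd indices: cosine
    return tf.cast(pe[np.newaxis, ...], tf.float32)                    # (1, length, depth)
```

The look-ahead mask is what prevents the decoder's first attention block from peeking at future tokens: masked positions receive a large negative value before the softmax and thus a near-zero attention weight.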

def point_wise_feed_forward_network(d_model, dff):
    return tf.keras.Sequential([
        tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
        tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
    ])

### Decoder Layer:
class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, *, d_model, num_heads, dff, rate=0.1):
        super(DecoderLayer, self).__init__()
        self.mha1 = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.mha2 = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.dropout3 = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training,
             look_ahead_mask, padding_mask):
        # enc_output.shape == (batch_size, input_seq_len, d_model)
        attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)  # (batch_size, target_seq_len, d_model)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layernorm1(attn1 + x)

        # The encoder output acts as value and key; out1 acts as the query.
        attn2, attn_weights_block2 = self.mha2(
            enc_output, enc_output, out1, None)  # (batch_size, target_seq_len, d_model)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layernorm2(attn2 + out1)  # (batch_size, target_seq_len, d_model)

        ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layernorm3(ffn_output + out2)  # (batch_size, target_seq_len, d_model)

        return out3, attn_weights_block1, attn_weights_block2

### Decoder:
class Decoder(tf.keras.layers.Layer):
    def __init__(self, *, num_layers, d_model, num_heads, dff, target_vocab_size, max_tokens,
                 rate=0.1):
        super(Decoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.max_tokens = max_tokens

        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = positional_encoding(max_tokens, d_model)
        self.dec_layers = [
            DecoderLayer(d_model=d_model, num_heads=num_heads, dff=dff, rate=rate)
            for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training,
             look_ahead_mask, padding_mask):
        seq_len = tf.shape(x)[1]
        attention_weights = {}

        x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)

        for i in range(self.num_layers):
            x, block1, block2 = self.dec_layers[i](x, enc_output, training,
                                                   look_ahead_mask, padding_mask)
            attention_weights[f'decoder_layer{i+1}_block1'] = block1
            attention_weights[f'decoder_layer{i+1}_block2'] = block2

        # x.shape == (batch_size, target_seq_len, d_model)
        return x, attention_weights

For instance, if the caption says “cat”, the network learns, through backpropagation, which parts of the picture show the cat.

Combining the Encoder and decoder, the Transformer model can be created as a TensorFlow Keras Model, as shown below:

class Transformer(tf.keras.Model):
    def __init__(self, *, num_layers, d_model, num_heads, dff,
                 target_vocab_size, vision_transformer, max_tokens, rate=0.1):
        super(Transformer, self).__init__()
        self.vision_transformer = vision_transformer
        self.encoder = Encoder(d_model, vision_transformer)
        self.decoder = Decoder(num_layers=num_layers, d_model=d_model,
                               num_heads=num_heads, dff=dff,
                               target_vocab_size=target_vocab_size, max_tokens=max_tokens, rate=rate)
        self.final_layer = tf.keras.layers.Dense(target_vocab_size)

    def call(self, inputs, training):
        # Keras models prefer if you pass all your inputs in the first argument
        inp, tar = inputs
        padding_mask, look_ahead_mask = self.create_masks(inp, tar)

        enc_output = self.encoder(inp, training, None)  # (batch_size, inp_seq_len, d_model)

        # dec_output.shape == (batch_size, tar_seq_len, d_model)
        dec_output, attention_weights = self.decoder(
            tar, enc_output, training, look_ahead_mask, padding_mask)

        final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size)

        return final_output, attention_weights

    def create_masks(self, inp, tar):
        # Encoder padding mask (used in the 2nd attention block in the decoder too).
        padding_mask = create_padding_mask(inp)

        # Used in the 1st attention block in the decoder.
        # It is used to pad and mask future tokens in the input received by
        # the decoder.
        look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
        dec_target_padding_mask = create_padding_mask(tar)
        look_ahead_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

        return padding_mask, look_ahead_mask


The Transformer module was trained with a custom Sparse Categorical Cross-Entropy loss, a learning rate scheduler, the Adam optimizer, and a batch size of 128 for about 2 epochs (due to limited training time and computational resources). For experimentation, the transformer had 256 model dimensions and 8 layers in total, with 4 heads for Multi-Head Attention. After 2 epochs, the transformer could make decent predictions on images.
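The custom loss and learning rate scheduler are not shown in this article; below is a plausible sketch, assuming the standard warmup schedule from “Attention Is All You Need” and a sparse categorical cross-entropy loss that masks out padding tokens:

```python
import tensorflow as tf

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    # Warmup-then-decay schedule from "Attention Is All You Need".
    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)                    # decay after warmup
        arg2 = step * (self.warmup_steps ** -1.5)     # linear warmup
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
    # Mask out padding tokens (id 0) so they do not contribute to the loss.
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    return tf.reduce_sum(loss_) / tf.reduce_sum(mask)

optimizer = tf.keras.optimizers.Adam(
    CustomSchedule(256), beta_1=0.9, beta_2=0.98, epsilon=1e-9)
```

The schedule ramps the learning rate up over the warmup steps and then decays it with the inverse square root of the step, which is the configuration used in the original paper; the exact warmup length used for this project may differ.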

Detailed code can be found in the notebook provided at the end of the article.


Qualitative analysis of the results shows that the transformer network could generate decent captions for the input images even with limited time and resources. Below are some captions generated by the transformer network:

More results are available in the notebook provided below. Better results could be obtained by training the model for more epochs and tuning the hyperparameters toward a larger number of learnable parameters.

Further applications

As the transformer model understands the relationship between the images and the captions remarkably well, I can intuitively see it being applied to captioning images in domains where convolutional neural networks fail to generalize, such as medical images. And since transformers excel at sequence-to-sequence problems, video captioning would also be an interesting application of this caption transformer model.

Project on Github

Social Media Links

Reach out to me for projects or queries on:


or drop an email:


