Building a Transformer-Based, CNN-Powered Image Captioning Model

David Landup

With our datasets primed and ready to go, we can define the model. Using KerasNLP, we can fairly easily implement a transformer from scratch (the hyperparameter values in this sketch are illustrative placeholders, not tuned settings):

import keras_nlp
from tensorflow import keras

# Illustrative hyperparameters - adjust for your vocabulary and caption lengths
VOCAB_SIZE, SEQ_LENGTH = 10000, 25
EMBED_DIM, NUM_HEADS, INTERMEDIATE_DIM = 512, 8, 1024

# Encoder
encoder_inputs = keras.Input(shape=(None,), dtype="int64")
x = keras_nlp.layers.TokenAndPositionEmbedding(VOCAB_SIZE, SEQ_LENGTH, EMBED_DIM)(encoder_inputs)
encoder_outputs = keras_nlp.layers.TransformerEncoder(intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS)(inputs=x)

encoder = keras.Model(encoder_inputs, encoder_outputs)

# Decoder
decoder_inputs = keras.Input(shape=(None,), dtype="int64")
encoded_seq_inputs = keras.Input(shape=(None, EMBED_DIM))
x = keras_nlp.layers.TokenAndPositionEmbedding(VOCAB_SIZE, SEQ_LENGTH, EMBED_DIM)(decoder_inputs)
x = keras_nlp.layers.TransformerDecoder(intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS)(decoder_sequence=x, encoder_sequence=encoded_seq_inputs)
x = keras.layers.Dropout(0.5)(x)
# Project each position back onto the vocabulary
decoder_outputs = keras.layers.Dense(VOCAB_SIZE, activation="softmax")(x)

decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)
# The output of the transformer is the output of the decoder,
# fed with the decoder's inputs and the encoder's output
transformer_outputs = decoder([decoder_inputs, encoder_outputs])

transformer = keras.Model([encoder_inputs, decoder_inputs], transformer_outputs)
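As a quick sanity check, assuming the illustrative hyperparameters above, we can push a dummy batch of token IDs through the assembled model and inspect the output shape:

import numpy as np

# Random token IDs standing in for tokenized source and target sequences
src = np.random.randint(0, 10000, size=(4, 25))  # (batch, sequence_length)
tgt = np.random.randint(0, 10000, size=(4, 25))
probs = transformer([src, tgt])
print(probs.shape)  # (4, 25, 10000) - a vocabulary distribution per token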

Each input is first run through a TokenAndPositionEmbedding() layer and then through either a TransformerEncoder or a TransformerDecoder. The model is put together by feeding the encoder's output into the decoder, alongside the decoder's own positionally-embedded input. This architecture is pretty much a perfect reflection of the diagram from the paper. We'll have to make a few tweaks here, though: transformers work on sequences, and our images aren't sequences, nor are the feature maps output by a ConvNet.
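Since the transformer expects a sequence of token-like vectors, one common trick, sketched below under assumed shapes, is to flatten a ConvNet's spatial feature map into a sequence, so that every spatial position becomes one "token". The EfficientNetB0 backbone and the 512-wide projection are assumptions for illustration here, not the article's final choices:

from tensorflow import keras

# Frozen ConvNet backbone (assumed here; any feature extractor works)
cnn = keras.applications.EfficientNetB0(include_top=False, input_shape=(224, 224, 3))
cnn.trainable = False

image_inputs = keras.Input(shape=(224, 224, 3))
feature_maps = cnn(image_inputs)  # (batch, 7, 7, 1280)
# Flatten (height, width, channels) -> (height*width, channels):
# each of the 7*7 = 49 spatial positions becomes one "token"
seq = keras.layers.Reshape((-1, 1280))(feature_maps)
# Project the tokens to the transformer's embedding width (EMBED_DIM)
seq = keras.layers.Dense(512)(seq)  # (batch, 49, 512)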
