LSTM-based sequence-to-sequence / encoder-decoder architecture

Sambhav Mehta
4 min read · Jul 29, 2024
Basic encoder-decoder architecture

In a seq2seq model, there are 3 important parts:
1. Encoder
2. Context vector (the encoder's hidden state at the last time step)
3. Decoder

Internal Diagram for seq2seq model

How does the encoder-decoder work?

Let’s take English-to-Hindi machine translation as an example of a seq2seq model.

English input and Hindi output

It is a supervised machine learning problem: we first feed the English words sequentially to the encoder, the decoder then predicts the output words, and we calculate a loss that is used to compute gradients and update all trainable parameters (weights and biases).

Step 1: Tokenisation of the words in the input and output columns

Tokenising the words

Step 2: Vectorisation
Convert the tokens into numbers, separately for the input and the output.

For simplicity we use the one-hot encoding (OHE) technique.
For the input, the vocabulary size is 5.

OHE for input tokens

On the output side we have all the unique tokens/words plus 2 extra tokens (start_token and end_token), so in total the output side has 6 + 2 = 8 words.
The start token is sent to the decoder as input along with the hidden state of the encoder's last time step. The end token signals the decoder to stop generating words during inference/prediction.

OHE for output tokens
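A minimal Python sketch of Steps 1 and 2, assuming a small hypothetical pair of sentence columns (the exact sentences and the output vocabulary from the figures are not reproduced here):

```python
import numpy as np

# Hypothetical parallel corpus (input column = English, output column = Hindi).
eng_sentences = ["i am happy", "you are happy"]
hin_sentences = ["main khush hoon", "tum khush ho"]

# Step 1 (tokenisation): split each sentence into words and build the vocabularies.
eng_vocab = sorted({w for s in eng_sentences for w in s.split()})                 # 5 input words
hin_vocab = ["<start>", "<end>"] + sorted({w for s in hin_sentences for w in s.split()})

eng_index = {w: i for i, w in enumerate(eng_vocab)}
hin_index = {w: i for i, w in enumerate(hin_vocab)}

# Step 2 (vectorisation): one-hot encode each token against its own vocabulary.
def one_hot(word, index):
    vec = np.zeros(len(index), dtype=np.float32)
    vec[index[word]] = 1.0
    return vec

print(one_hot("happy", eng_index))       # OHE vector over the input vocabulary
print(one_hot("<start>", hin_index))     # start token fed to the decoder at the first step
```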

Training a sequence-to-sequence model

Forward Propagation

Forward propagation state
  1. During forward propagation, we send the OHE vectors to the encoder as input at each time step.
  2. At each time step, the encoder's cell and hidden states are updated.
  3. At the last time step, we take the final hidden state of the encoder, which is passed to the decoder as the context vector along with the start token.
  4. When the decoder predicts an output, it applies softmax and picks the word with the highest probability.
Teacher forcing in decoder

5. Even if the decoder predicts a wrong word, we feed the actual word at the next time step. This technique is called teacher forcing; we do this because it makes the model converge much faster.
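A minimal PyTorch sketch of this forward pass with teacher forcing (the layer sizes, variable names, and the choice of PyTorch are assumptions for illustration, not taken from the figures above):

```python
import torch
import torch.nn as nn

INPUT_VOCAB, OUTPUT_VOCAB, HIDDEN = 5, 8, 32      # vocab sizes from the example; HIDDEN is arbitrary

encoder = nn.LSTM(input_size=INPUT_VOCAB, hidden_size=HIDDEN, batch_first=True)
decoder = nn.LSTM(input_size=OUTPUT_VOCAB, hidden_size=HIDDEN, batch_first=True)
projection = nn.Linear(HIDDEN, OUTPUT_VOCAB)      # maps decoder states to output-vocabulary scores

def forward_with_teacher_forcing(src_one_hot, tgt_one_hot):
    # src_one_hot: (batch, src_len, INPUT_VOCAB); tgt_one_hot: (batch, tgt_len, OUTPUT_VOCAB)
    # Steps 1-3: the encoder reads the OHE inputs; its final hidden/cell states form the context vector.
    _, (h, c) = encoder(src_one_hot)
    # Steps 4-5 (teacher forcing): the decoder is fed the actual target tokens, starting with the
    # start token and excluding the last position, regardless of what it predicted previously.
    dec_out, _ = decoder(tgt_one_hot[:, :-1, :], (h, c))
    return projection(dec_out)                    # softmax is applied later, inside the loss
```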

Backward Propagation:

Backward propagation through sequential data is called backpropagation through time (BPTT).

During backward propagation, we calculate the loss using categorical cross-entropy at every time step. Using this loss, we calculate the gradient, which in turn updates the weights.
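Continuing the PyTorch sketch above, the per-time-step categorical cross-entropy could be computed like this (again an illustrative assumption, not the author's exact code):

```python
criterion = nn.CrossEntropyLoss()     # categorical cross-entropy (applies the softmax internally)

def sequence_loss(logits, tgt_one_hot):
    # Compare the prediction at every time step with the true next word;
    # targets are shifted by one step so the start token is never a target.
    targets = tgt_one_hot[:, 1:, :].argmax(dim=-1)      # (batch, tgt_len - 1) class indices
    return criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
```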

Forward propagation until the training data finishes

Backpropagation goes first through the decoder, then the context vector, and finally the encoder.

Decoder Backpropagation:
For each output step, compute the gradient of the loss with respect to the decoder’s output.

  • Backpropagate through time (BPTT) to compute gradients for each parameter in the decoder.
  • Accumulate gradients over all time steps.
  • Update the decoder parameters using gradient descent or another optimization algorithm.

Context Vector Gradient:
Aggregate the gradients from the decoder back to the context vector.

Encoder Backpropagation:
Backpropagate the gradient of the context vector through the encoder layers.

  • Compute gradients for each parameter in the encoder using BPTT.
  • Update the encoder parameters using gradient descent or another optimisation algorithm.

Parameter Update:
Apply the computed gradients to the encoder and decoder parameters, updating them in the direction that minimises the loss.
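Putting the sketches together, one training step could look like this; `loss.backward()` performs BPTT automatically, sending gradients through the decoder, then the context vector, and finally the encoder, after which the optimiser applies the parameter update (the optimiser and learning rate are illustrative choices):

```python
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()) + list(projection.parameters()),
    lr=1e-3,
)

def train_step(src_one_hot, tgt_one_hot):
    optimizer.zero_grad()
    logits = forward_with_teacher_forcing(src_one_hot, tgt_one_hot)
    loss = sequence_loss(logits, tgt_one_hot)
    loss.backward()      # BPTT: gradients flow decoder -> context vector -> encoder
    optimizer.step()     # update encoder, decoder, and projection parameters
    return loss.item()
```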

Steps to improve the model:

  1. Use an embedding layer for both the encoder and the decoder. Either use pretrained embeddings such as word2vec or GloVe, or learn your own trainable embeddings during the training process.
  2. Use a deep/stacked LSTM (a sketch combining both improvements follows this list).
    a) Better able to handle long-term dependencies.
    b) Greater model capacity thanks to more trainable parameters (with regularisation such as dropout between layers to keep overfitting in check).
    c) Layered representations that capture a hierarchy: the initial layers capture word-level meaning, the middle layers sentence-level meaning, and the top layers paragraph-level meaning.
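A rough sketch of both improvements, continuing with the same PyTorch imports as above (the vocabulary size, embedding dimension, and layer count are arbitrary assumptions):

```python
class ImprovedEncoder(nn.Module):
    """Encoder with a trainable embedding layer and a stacked (2-layer) LSTM."""
    def __init__(self, vocab_size=5000, emb_dim=128, hidden=256, num_layers=2):
        super().__init__()
        # The embedding can be learned from scratch or initialised with word2vec/GloVe vectors.
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        # Dropout between the stacked layers helps counter the extra capacity.
        self.lstm = nn.LSTM(emb_dim, hidden, num_layers=num_layers,
                            batch_first=True, dropout=0.2)

    def forward(self, token_ids):
        # token_ids: (batch, src_len) integer word indices instead of one-hot vectors
        embedded = self.embedding(token_ids)
        _, (h, c) = self.lstm(embedded)
        return h, c                    # stacked hidden/cell states serve as the context vector
```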

Drawback of LSTM based Encoder-Decoder:

  1. Computational complexity and resource intensity.
  2. Alignment issues.
  3. Struggles with long sequences (roughly more than 30 words), since a single fixed-length context vector must summarise the whole input.
  4. Requires a large amount of training data.
