NMT: Encoder and Decoder with Keras

Nov 19, 2020 • 9 Minute Read

Introduction

This guide builds on Encoders and Decoders for Neural Machine Translation, which covered the different RNN models and the power of seq2seq modeling. It also covered the roles of the encoder and decoder in machine translation: they are two separate RNN models, combined to perform complex deep learning tasks.

By the end of the previous guide, you had pre-processed the data and extracted the features you need to build the model.

In this guide, you will use that data and the concepts of LSTM, encoders, and decoders to build a network that translates English to Spanish. Finally, you will use the trained model in a simple program for learning Spanish, which shows random English sentences with their Spanish translations.

Let's start with building the model.

Building the Model

The first step is to define an input sequence for the encoder. Because this is character-level translation, the input is fed to the encoder one character at a time. The decoder needs the encoder's final state as its initial state, so set return_state=True on the encoder LSTM. This returns the encoder's state at the end of the input sequence: state_h is the hidden state and state_c is the cell state.

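from tensorflow import keras  # assumed import; use the one from the previous guide if it differs

# Encoder: variable-length sequence of one-hot encoded characters
encoder_inputs = keras.Input(shape=(None, num_encoder_tokens))
encoder = keras.layers.LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)

# encoder_outputs is not used; only the final states are passed to the decoder
encoder_states = [state_h, state_c]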
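
To make state_h and state_c more concrete, here is a minimal sketch of what an LSTM computes at each time step, reduced to a single unit with scalar weights. The names lstm_step, encode, and toy_weights are hypothetical and the weight values are arbitrary; a real keras.layers.LSTM uses weight matrices that are learned during training.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def lstm_step(x, h_prev, c_prev, w):
    """One time step of a toy single-unit LSTM with a scalar input."""
    i = sigmoid(w["wi"] * x + w["ui"] * h_prev + w["bi"])    # input gate
    f = sigmoid(w["wf"] * x + w["uf"] * h_prev + w["bf"])    # forget gate
    o = sigmoid(w["wo"] * x + w["uo"] * h_prev + w["bo"])    # output gate
    g = math.tanh(w["wg"] * x + w["ug"] * h_prev + w["bg"])  # candidate value
    c = f * c_prev + i * g   # cell state: long-term memory
    h = o * math.tanh(c)     # hidden state: what the layer exposes
    return h, c


def encode(sequence, w):
    """Run the toy LSTM over a sequence and return only the final states."""
    h, c = 0.0, 0.0
    for x in sequence:
        h, c = lstm_step(x, h, c, w)
    return h, c


# Arbitrary illustrative weights (assumption, not trained values)
toy_weights = {name: 0.5 for name in (
    "wi", "ui", "bi", "wf", "uf", "bf", "wo", "uo", "bo", "wg", "ug", "bg")}

state_h, state_c = encode([1.0, 0.0, 1.0], toy_weights)
print("state_h:", state_h)
print("state_c:", state_c)
```

The two values returned by encode play the same role as the two states returned by the encoder when return_state=True: a summary of the whole input sequence.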
    

The encoder states become the initial state of the decoder. The first one-hot encoded character of decoder_input_data, the start-of-sequence (SOS) character \t, is fed to the decoder together with the final encoder state to predict the first target character.

For the decoder LSTM, both return_sequences and return_state are set to True, so the layer returns its output at every time step along with its final hidden and cell states. A Dense layer with a softmax activation is applied to the output at each time step, giving a probability distribution over the target characters. During training, the decoder receives the correct previous character as its input at each step. At inference, each predicted character is fed back in to predict the rest of the sentence.

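# Decoder: starts from the encoder's final states
decoder_inputs = keras.Input(shape=(None, num_decoder_tokens))
decoder_lstm = keras.layers.LSTM(latent_dim, return_sequences=True, return_state=True)
# The returned states are ignored during training and used only at inference
decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)
# Softmax over the target characters at every time step
decoder_dense = keras.layers.Dense(num_decoder_tokens, activation="softmax")
decoder_outputs = decoder_dense(decoder_outputs)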
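
The relationship between the decoder's input and its target can be sketched in a few lines of plain Python. This is only an illustration, assuming each target sentence is wrapped in the start character \t and the stop character \n; the function name teacher_forcing_pairs is hypothetical.

```python
def teacher_forcing_pairs(target_text):
    """Pair each decoder input character with the character it should predict."""
    decoder_input = target_text[:-1]   # everything except the final stop character
    decoder_target = target_text[1:]   # the same text, shifted one step ahead
    return list(zip(decoder_input, decoder_target))


pairs = teacher_forcing_pairs("\tHola\n")
for input_char, target_char in pairs:
    print(repr(input_char), "->", repr(target_char))
```

At every time step the decoder sees the true previous character and is trained to output the next one, which is why the target data is the input data offset by one time step.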
    

Training and Saving the Model

Now train the basic LSTM-based seq2seq model to predict decoder_target_data from encoder_input_data and decoder_input_data. Compile the model by setting the optimizer with its learning rate, decay, and beta values, the loss function, and accuracy as the performance metric. Then fit the model; validation_split=0.2 holds out 20% of the data for validation, so Keras reports both the training loss and the validation loss. Finally, use save() to save the model.

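model = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Older Keras releases also accept lr= in place of learning_rate=;
# recent releases may reject decay= and expect a learning-rate schedule instead.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.01, beta_1=0.9, beta_2=0.999, decay=0.001),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

model.fit(
    [encoder_input_data, decoder_input_data],
    decoder_target_data,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.2,  # hold out 20% of the samples for validation
)
model.save("E2S")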
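
To see what the loss and the accuracy metric measure for a single predicted character, here is a small worked sketch in plain Python. The helper names and the example scores are hypothetical; Keras computes the same quantities over whole batches of sequences.

```python
import math


def softmax(logits):
    """Turn raw scores into a probability distribution."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # subtract the max for stability
    total = sum(exps)
    return [e / total for e in exps]


def categorical_crossentropy(one_hot_target, probabilities, eps=1e-7):
    """Negative log-probability assigned to the correct class."""
    return -sum(t * math.log(max(p, eps))
                for t, p in zip(one_hot_target, probabilities))


def is_correct(one_hot_target, probabilities):
    """Accuracy for one prediction: does the most probable class match the target?"""
    return probabilities.index(max(probabilities)) == one_hot_target.index(1.0)


probs = softmax([2.0, 0.5, -1.0])   # made-up scores for three characters
target = [1.0, 0.0, 0.0]            # the first character is the correct one
loss = categorical_crossentropy(target, probs)
print("probabilities:", probs)
print("loss:", loss)
print("correct:", is_correct(target, probs))
```

The loss shrinks toward zero as the probability of the correct character approaches 1, and it grows when the model puts its probability on a wrong character.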
    
# Save a diagram of the model (plot_model needs pydot and Graphviz installed)
keras.utils.plot_model(model, to_file="modelsummary.png", show_shapes=True, show_layer_names=True)
    
print("shape encoder_input_data :", encoder_input_data.shape)
print("shape decoder_input_data :", decoder_input_data.shape)
print("shape decoder_target_data:", decoder_target_data.shape)
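
The first dimension of each shape printed above is the number of samples, and validation_split=0.2 divides that dimension. As a rough sketch of the documented Keras behavior, the validation set is taken from the last samples of the arrays, before any shuffling. The function name validation_split_indices is hypothetical, and the sample count of 10 is only for illustration.

```python
import math


def validation_split_indices(num_samples, validation_split=0.2):
    """Return the index ranges used for training and for validation."""
    split_at = int(math.floor(num_samples * (1.0 - validation_split)))
    train_indices = range(0, split_at)           # first 80% of the samples
    val_indices = range(split_at, num_samples)   # last 20% of the samples
    return train_indices, val_indices


train_idx, val_idx = validation_split_indices(10)
print("train:", list(train_idx))
print("validation:", list(val_idx))
```

Because the split is positional, data that is sorted (for example, by sentence length) can give a validation set that looks different from the training set, so shuffling the samples beforehand may be worth considering.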
    

Decode the Sentence

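To translate new sentences, build separate inference models from the trained layers. First, create the encoder model with keras.Model, using encoder_inputs as the input tensor and the encoder's hidden and cell states, state_h_enc and state_c_enc, as the output tensors.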

# Reuse the trained encoder layer; the layer index assumes the model built above
encoder_inputs = model.input[0]  # input_1
encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output  # lstm_1
encoder_states = [state_h_enc, state_c_enc]
encoder_model = keras.Model(encoder_inputs, encoder_states)
    

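Now build the model for the decoder. At inference, the decoder runs one step at a time, so its hidden and cell states become explicit inputs and outputs of the model.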

decoder_inputs = model.input[1]  # input_2
# The states from the previous step are passed in as extra inputs
decoder_state_input_h = keras.Input(shape=(latent_dim,), name="input_3")
decoder_state_input_c = keras.Input(shape=(latent_dim,), name="input_4")
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_lstm = model.layers[3]  # trained decoder LSTM
decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs
)
decoder_states = [state_h_dec, state_c_dec]
decoder_dense = model.layers[4]  # trained softmax layer
decoder_outputs = decoder_dense(decoder_outputs)
# Returns the character probabilities plus the updated states
decoder_model = keras.Model(
    [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states
)
    

Create two reverse-lookup token indexes that map integer indices back to characters, so the decoded sequence is readable.

reverse_input_char_index = {i: char for char, i in input_token_index.items()}
reverse_target_char_index = {i: char for char, i in target_token_index.items()}
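
As a small illustration of how a token index and its reverse lookup work together, the sketch below builds both for a toy vocabulary. The toy_ names and the two sample sentences are hypothetical; the real indexes come from the pre-processing in the previous guide.

```python
toy_target_texts = ["\tHola\n", "\tAdios\n"]

# Character -> index, built from every character seen in the toy data
toy_characters = sorted({ch for text in toy_target_texts for ch in text})
toy_token_index = {char: i for i, char in enumerate(toy_characters)}

# Index -> character, the reverse lookup used when decoding
toy_reverse_index = {i: char for char, i in toy_token_index.items()}

encoded = [toy_token_index[ch] for ch in "Hola"]
decoded = "".join(toy_reverse_index[i] for i in encoded)
print(encoded, "->", decoded)
```

The model only ever sees and produces indices, so the reverse index is what turns its predictions back into text.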
    

Next, create a prediction function named decode_sequence. It encodes the input sentence to get the initial decoder states, then builds a target sequence of length 1 that holds only the start character \t, which tells the decoder to begin. The loop predicts one character at a time and stops on either of two conditions: the stop character \n is predicted, or the sentence exceeds the maximum length. After each step, update the target sequence with the predicted character and update the states.

def decode_sequence(input_seq):
    # Encode the input sentence to get the initial decoder states
    states_value = encoder_model.predict(input_seq)

    # Target sequence of length 1 holding only the start character
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    target_seq[0, 0, target_token_index["\t"]] = 1.0

    stop_condition = False
    decoded_sentence = ""
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict([target_seq] + states_value)

        # Greedy sampling: take the most probable character
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Stop at the stop character or when the maximum length is exceeded
        if sampled_char == "\n" or len(decoded_sentence) > max_decoder_seq_length:
            stop_condition = True

        # Feed the sampled character back in as the next input
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.0

        states_value = [h, c]
    return decoded_sentence
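
The control flow of decode_sequence can be tried without a trained model. The sketch below keeps the same loop and the same two stop conditions but replaces the decoder with toy step functions; greedy_decode, replay_step, and never_stops are hypothetical names used only for this illustration.

```python
def greedy_decode(step_fn, start_char, stop_char, max_length):
    """Feed each predicted character back in until a stop condition is met."""
    char, state = start_char, None
    decoded = ""
    while True:
        char, state = step_fn(char, state)
        decoded += char
        # Same two conditions as decode_sequence
        if char == stop_char or len(decoded) > max_length:
            return decoded


def replay_step(prev_char, state):
    """Stand-in for the decoder: replays a fixed answer one character at a time."""
    answer = "Hola\n"
    position = 0 if state is None else state
    return answer[position % len(answer)], position + 1


def never_stops(prev_char, state):
    """Stand-in for a decoder that never predicts the stop character."""
    return "a", state


print(repr(greedy_decode(replay_step, "\t", "\n", 20)))
print(repr(greedy_decode(never_stops, "\t", "\n", 5)))
```

The second call shows why the length check matters: without it, a model that never predicts \n would loop forever.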
    

Learn Spanish

Run the cell to see a random sentence with its translation. The sentences are basic. Learning a new language is always a useful addition to your skills, and it will be helpful when you visit Spain :)

# Pick a random training sentence and translate it
i = np.random.choice(len(input_texts))
input_seq = encoder_input_data[i:i+1]
translation = decode_sequence(input_seq)
print('-')
print('Input:', input_texts[i])
print('Translation:', translation)
    

Validate the result with Google Translate.

Perfecto!!

Conclusion

In the example above, the character-by-character translation was accurate. Seq2seq models can deal with variable-length inputs, and the encoder and decoder work together: the encoder's LSTM weights are trained to learn a latent-space representation of the input text, whereas the decoder's LSTM weights are trained to generate grammatically correct sentences from it. The performance of any project depends on the model you choose and on the volume and pre-processing of the data, but hyper-parameters also play a major role in deep learning problems. You can improve the accuracy of this model by tuning the hyper-parameters or by increasing the amount of data.

Machine translation can also be performed with the GRU RNN model. It's a cousin of LSTM with fewer states: a GRU keeps only a hidden state and has no separate cell state. I would recommend that you understand the different RNN models. Learn more about GRU and the difference between the two RNNs, then select the model that gives you the best results.
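
One way to compare the two is by the number of trainable parameters in a layer. The following is a rough sketch based on the textbook formulations, in which an LSTM has four weight blocks and a GRU has three; the function names and the sizes of 70 input tokens and 256 units are assumptions for illustration. Keras's GRU layer may report a slightly larger count because its default configuration keeps an extra bias vector per block.

```python
def lstm_param_count(input_dim, units):
    """Four blocks (input, forget, and output gates plus the candidate),
    each with input weights, recurrent weights, and a bias vector."""
    return 4 * (units * (input_dim + units) + units)


def gru_param_count(input_dim, units):
    """Three blocks (update gate, reset gate, and the candidate)."""
    return 3 * (units * (input_dim + units) + units)


input_dim, units = 70, 256  # hypothetical sizes
lstm_params = lstm_param_count(input_dim, units)
gru_params = gru_param_count(input_dim, units)
print("LSTM parameters:", lstm_params)
print("GRU parameters: ", gru_params)
```

With the same number of units, the GRU has about three quarters of the parameters of the LSTM, which usually makes it somewhat faster to train.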

If you have any queries regarding this guide, feel free to ask at Codealphabet.

Gaurav Singhal

Gaurav is a Data Scientist with a strong background in computer science and mathematics. He has extensive research experience in data structures, statistical data analysis, and mathematical modeling. With a solid background in web development, he works with Python, Java, Django, HTML, Struts, Hibernate, Vaadin, web scraping, Angular, and React. His data science skills include Python, Matplotlib, TensorFlow, Pandas, NumPy, Keras, CNN, ANN, NLP, recommenders, and predictive analysis. He has built systems that use both basic machine learning algorithms and complex deep neural networks. He has worked on many data science projects, including product recommendation, user sentiment, Twitter bots, information retrieval, predictive analysis, data mining, image segmentation, SVMs, and random forests.