Seq2Seq
https://towardsdatascience.com/day-1-2-attention-seq2seq-models-65df3f49e263
A Seq2Seq model takes a sequence of items (words, letters, time series, etc.) and outputs another sequence of items. It can be used for human-machine interaction (e.g. chatbots) and for machine translation.
Business use cases: chatbot, machine translation, question answering, text summarization, text generation.
Datasets: the Cornell Movie-Dialogs Corpus, which contains 220,579 conversational exchanges between 10,292 pairs of movie characters, involving 9,035 characters from 617 movies.
The model is composed of an encoder and a decoder. The encoder captures the context of the input sequence in a hidden state vector and passes it to the decoder, which then produces the output sequence. Because the task is sequence-based, both the encoder and the decoder usually use some form of recurrent network: a plain RNN, an LSTM, or a GRU.
RNN
By design, an RNN takes two inputs at each step: the current element of the sequence and a hidden state that represents the previous inputs. Thus the output at time step t depends on the current input and on the hidden state from time step t-1, which itself summarizes all earlier inputs. This is why RNNs are well suited to sequence tasks: the sequential information is preserved in the network's hidden state and used at the next step.
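The recurrence can be sketched as a plain (Elman-style) RNN cell, h_t = tanh(W_x x_t + W_h h_{t-1} + b). This is a minimal pure-Python sketch; the weight names W_x, W_h, b and the toy sizes are hypothetical, chosen only to show how the hidden state carries information from one step to the next.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]

def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a matrix (rows x cols) by a vector of length cols."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]

def rnn_step(x_t: Vector, h_prev: Vector, W_x: Matrix, W_h: Matrix, b: Vector) -> Vector:
    """One RNN step: h_t = tanh(W_x x_t + W_h h_{t-1} + b)."""
    pre = [a + c + d for a, c, d in zip(matvec(W_x, x_t), matvec(W_h, h_prev), b)]
    return [math.tanh(p) for p in pre]

def rnn_run(xs: List[Vector], h0: Vector, W_x: Matrix, W_h: Matrix, b: Vector) -> List[Vector]:
    """Run the cell over a whole sequence and return every hidden state."""
    states = []
    h = h0
    for x_t in xs:
        h = rnn_step(x_t, h, W_x, W_h, b)  # h feeds into the next step
        states.append(h)
    return states

# Hypothetical toy setup: 2-dim inputs, 2-dim hidden state, hand-picked weights.
W_x = [[0.5, -0.3], [0.8, 0.1]]
W_h = [[0.2, 0.0], [0.0, 0.2]]
b = [0.0, 0.0]
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
states = rnn_run(xs, [0.0, 0.0], W_x, W_h, b)
print(len(states), [round(v, 3) for v in states[-1]])
```

Changing only the first input changes the last hidden state too, which is the "memory" the notes describe.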
The encoder, an RNN, reads the input sequence one element at a time and produces a final hidden state (an embedding of the whole sequence) at the end. This vector is sent to the decoder, which predicts the output sequence one element at a time; after each prediction, it uses its previous hidden state to predict the next element of the sequence.
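A minimal sketch of this encoder-decoder loop, assuming untrained random weights, a hypothetical 5-token vocabulary with BOS = 0 and EOS = 1, and greedy decoding (always picking the highest-scoring token). The encoder's final hidden state is the only information handed to the decoder.

```python
import math
import random

def matvec(m, v):
    return [sum(w * x for w, x in zip(row, v)) for row in m]

def rnn_step(x, h, W_x, W_h):
    """Plain tanh RNN step without bias, to keep the sketch short."""
    return [math.tanh(a + c) for a, c in zip(matvec(W_x, x), matvec(W_h, h))]

def one_hot(i, n):
    return [1.0 if j == i else 0.0 for j in range(n)]

def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

VOCAB, HIDDEN = 5, 4  # hypothetical toy sizes
BOS, EOS = 0, 1       # hypothetical special-token ids
rng = random.Random(0)
enc_Wx, enc_Wh = rand_matrix(HIDDEN, VOCAB, rng), rand_matrix(HIDDEN, HIDDEN, rng)
dec_Wx, dec_Wh = rand_matrix(HIDDEN, VOCAB, rng), rand_matrix(HIDDEN, HIDDEN, rng)
W_out = rand_matrix(VOCAB, HIDDEN, rng)  # hidden state -> vocabulary scores

def encode(tokens):
    h = [0.0] * HIDDEN
    for t in tokens:
        h = rnn_step(one_hot(t, VOCAB), h, enc_Wx, enc_Wh)
    return h  # the single context vector passed to the decoder

def decode(context, max_len=10):
    h, prev, out = context, BOS, []
    for _ in range(max_len):
        # The decoder uses its previous hidden state and its previous prediction.
        h = rnn_step(one_hot(prev, VOCAB), h, dec_Wx, dec_Wh)
        logits = matvec(W_out, h)
        prev = max(range(VOCAB), key=lambda i: logits[i])  # greedy choice
        if prev == EOS:
            break
        out.append(prev)
    return out

print(decode(encode([2, 3, 4])))
```

Because the weights are random, the decoded tokens are meaningless; training would fit all five weight matrices, typically with a cross-entropy loss on the target sequence.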
Drawback: the output sequence depends heavily on the context held in the encoder's final hidden state, a single fixed-size vector, which makes long sentences hard for the model to handle. For long sequences, there is a high probability that the initial context has been lost by the end of the sequence.
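A toy illustration of why early context fades, assuming a scalar tanh RNN with a hypothetical recurrent weight w = 0.5. Because tanh never amplifies differences and |w| < 1, a nudge to the first input shrinks by at least half at every later step. This sketches the intuition under these assumptions, not the behavior of a trained model; LSTMs and GRUs are designed to slow this decay, but the fixed-size final vector remains a bottleneck.

```python
import math

def final_state(first_input: float, length: int, w: float = 0.5, u: float = 1.0) -> float:
    """Scalar RNN h_t = tanh(w*h_{t-1} + u*x_t), with x_1 = first_input and x_t = 0 afterwards."""
    h = 0.0
    for t in range(length):
        x_t = first_input if t == 0 else 0.0
        h = math.tanh(w * h + u * x_t)
    return h

def influence_of_first_input(length: int, eps: float = 1e-3) -> float:
    """How far the final hidden state moves when x_1 is nudged by eps."""
    return abs(final_state(1.0 + eps, length) - final_state(1.0, length))

for n in (1, 5, 10, 20):
    print(n, influence_of_first_input(n))
```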
Solution: attention
The issue was that a single hidden state vector at the end of the encoder is not enough. With attention, the encoder instead passes on all of its hidden states, one per element of the input sequence.
The second addition that makes the model attention-based is the context vector, which is computed anew for every time step of the output sequence. At each step, the context vector is a weighted sum of the encoder hidden states.
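The weighted sum can be written as c_t = Σ_i a_{t,i} h_i, where h_i are the encoder hidden states and a_{t,i} are the attention weights for output step t. A minimal sketch with hypothetical numbers:

```python
from typing import List

def context_vector(weights: List[float], encoder_states: List[List[float]]) -> List[float]:
    """c_t = sum_i a_{t,i} * h_i, a weighted sum of encoder hidden states."""
    assert len(weights) == len(encoder_states)
    dim = len(encoder_states[0])
    c = [0.0] * dim
    for a, h in zip(weights, encoder_states):
        for k in range(dim):
            c[k] += a * h[k]
    return c

encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # hypothetical h_1..h_3
weights = [0.5, 0.25, 0.25]                           # hypothetical attention weights, sum to 1
print(context_vector(weights, encoder_states))         # → [0.75, 0.5]
```

With a one-hot weight vector the context vector is just that single encoder state; softer weights blend several.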
The context vector is concatenated with the decoder hidden state, and this new attention hidden vector is used to predict the output at that time step. Note that the attention vector is recomputed for every time step of the output sequence and takes the place of the plain hidden state vector.
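A sketch of the combination step, assuming the Luong-style formulation h̃_t = tanh(W_c [c_t ; s_t]), where s_t is the decoder hidden state. The notes only specify the concatenation, so the projection W_c and the tanh are an assumption, and all values are hypothetical.

```python
import math
from typing import List

def attention_hidden(context: List[float], dec_hidden: List[float],
                     W_c: List[List[float]]) -> List[float]:
    """Assumed Luong-style attention hidden state: tanh(W_c [c_t ; s_t])."""
    concat = context + dec_hidden  # list concatenation = vector concatenation
    return [math.tanh(sum(w * x for w, x in zip(row, concat))) for row in W_c]

context = [0.75, 0.5]     # c_t (hypothetical)
dec_hidden = [0.1, -0.2]  # s_t (hypothetical)
W_c = [[1.0, 0.0, 1.0, 0.0],
       [0.0, 1.0, 0.0, 1.0]]  # hypothetical 2x4 projection back to the hidden size
att = attention_hidden(context, dec_hidden, W_c)
print([round(a, 3) for a in att])
```

The projection keeps the attention hidden vector the same size as the decoder state, so it can feed the output layer in place of s_t.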
Attention scores: these are the output of another neural network, the alignment model, which is trained jointly with the seq2seq model. For each input, the alignment model scores how well that input (represented by its encoder hidden state) matches the previous output (represented by the previous attention hidden state). A softmax over all these scores turns them into positive weights that sum to 1; the resulting weight is the attention score for each input.
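A sketch of the scoring step, assuming a Bahdanau-style additive alignment model, score(s, h_i) = vᵀ tanh(W_q s + W_k h_i), followed by a softmax. The parameter names W_q, W_k and v are hypothetical, and they are fixed by hand here instead of being trained jointly with the seq2seq model.

```python
import math
from typing import List

def matvec(m: List[List[float]], v: List[float]) -> List[float]:
    return [sum(w * x for w, x in zip(row, v)) for row in m]

def softmax(scores: List[float]) -> List[float]:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def additive_score(query, key, W_q, W_k, v) -> float:
    """Alignment model: v^T tanh(W_q query + W_k key)."""
    hidden = [math.tanh(a + b) for a, b in zip(matvec(W_q, query), matvec(W_k, key))]
    return sum(vi * hi for vi, hi in zip(v, hidden))

def attention_weights(query, encoder_states, W_q, W_k, v) -> List[float]:
    """Score every encoder state against the query, then normalize with softmax."""
    scores = [additive_score(query, h, W_q, W_k, v) for h in encoder_states]
    return softmax(scores)

# Hypothetical hand-set parameters (a real model learns these during training).
W_q = [[1.0, 0.0], [0.0, 1.0]]
W_k = [[1.0, 0.0], [0.0, 1.0]]
v = [1.0, 1.0]
query = [1.0, 0.0]  # previous attention hidden state
encoder_states = [[1.0, 0.0], [0.0, 0.0], [-1.0, 0.0]]
weights = attention_weights(query, encoder_states, W_q, W_k, v)
print([round(x, 3) for x in weights])
```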
Hence we know which parts of the input matter most for predicting each element of the output sequence: during training, the model learns how to align elements of the output sequence with elements of the input sequence. Below is an illustrated example of a machine translation model, shown in matrix form. Each entry in the matrix is the attention score between one input element and one output element.
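The alignment matrix can be sketched by stacking one row of attention weights per output step. For brevity this swaps in simple dot-product scores instead of a separate alignment network, and the hidden states are hand-made (hypothetical) so that the first two target words align to source words in swapped order, as often happens in translation.

```python
import math
from typing import List

def softmax(scores: List[float]) -> List[float]:
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def alignment_matrix(decoder_states, encoder_states) -> List[List[float]]:
    """Row j holds the attention weights of output step j over all input positions."""
    rows = []
    for s in decoder_states:
        scores = [sum(a * b for a, b in zip(s, h)) for h in encoder_states]  # dot-product score
        rows.append(softmax(scores))
    return rows

encoder_states = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]  # one per source word
decoder_states = [[0.0, 2.0, 0.0], [2.0, 0.0, 0.0], [0.0, 0.0, 2.0]]  # one per target word
A = alignment_matrix(decoder_states, encoder_states)
for j, row in enumerate(A):
    best = max(range(len(row)), key=lambda i: row[i])
    print(f"output {j} -> input {best}", [round(a, 2) for a in row])
```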
LSTM
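A common LSTM unit is composed of a cell, an input gate, an output gate and a forget gate. The cell remembers values over arbitrary time intervals, and the three gates regulate the flow of information into and out of the cell. The number of parameters per layer depends on the sizes involved: with input size x and hidden size h, each of the three gates and the candidate cell update has an h × (x + h) weight matrix and an h-dimensional bias, giving 4(h(x + h) + h) parameters (some frameworks, such as PyTorch, keep two bias vectors per block). A figure such as 300k parameters per layer therefore holds only for particular sizes; for example, x = h = 192 gives 295,680.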
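A minimal sketch of one LSTM step using the standard gate equations (input gate i, forget gate f, output gate o, candidate g), plus the per-layer parameter count under a single-bias-per-block convention. The params dictionary layout and the helper zero_params are hypothetical conveniences for the example.

```python
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def affine(W, b, v):
    """W v + b for a list-of-rows matrix W."""
    return [sum(w * x for w, x in zip(row, v)) + bi for row, bi in zip(W, b)]

def lstm_step(x, h_prev, c_prev, params):
    """One LSTM step; params maps block name -> (W, b), W of shape hidden x (input + hidden)."""
    z = x + h_prev  # concatenate [x_t ; h_{t-1}]
    i = [sigmoid(v) for v in affine(*params["input"], z)]   # input gate
    f = [sigmoid(v) for v in affine(*params["forget"], z)]  # forget gate
    o = [sigmoid(v) for v in affine(*params["output"], z)]  # output gate
    g = [math.tanh(v) for v in affine(*params["cell"], z)]  # candidate cell values
    c = [fi * cp + ii * gi for fi, cp, ii, gi in zip(f, c_prev, i, g)]  # keep old + write new
    h = [oi * math.tanh(ci) for oi, ci in zip(o, c)]  # output gate controls what leaves the cell
    return h, c

def lstm_param_count(input_size: int, hidden_size: int) -> int:
    """Four (W, b) blocks: W is hidden x (input + hidden), b has hidden entries."""
    return 4 * (hidden_size * (input_size + hidden_size) + hidden_size)

def zero_params(input_size, hidden_size):
    """Hypothetical helper: all-zero weights and biases for every block."""
    return {name: ([[0.0] * (input_size + hidden_size) for _ in range(hidden_size)],
                   [0.0] * hidden_size)
            for name in ("input", "forget", "output", "cell")}

h, c = lstm_step([0.3], [0.0, 0.0], [1.0, 1.0], zero_params(1, 2))
print(h, c, lstm_param_count(192, 192))
```

With all weights at zero every gate outputs 0.5 and the candidate is 0, so the cell simply halves its previous content; training moves the gates away from this neutral point.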
GRU
A GRU is like an LSTM with a forget gate, but it has fewer parameters than an LSTM because it lacks an output gate: it uses only an update gate and a reset gate, and has no separate cell state. GRUs have been shown to perform better on certain smaller, less frequent datasets.
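For comparison, a sketch of one GRU step with an update gate z and a reset gate r, following the original formulation where the reset gate is applied to h_{t-1} before the matrix product. Implementations differ in such details (some apply the reset gate after the product, some swap the roles of z and 1 - z). The params layout and zero_params are hypothetical. Three weight blocks instead of four give the GRU three quarters of the LSTM's parameters for the same sizes.

```python
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def affine(W, b, v):
    return [sum(w * x for w, x in zip(row, v)) + bi for row, bi in zip(W, b)]

def gru_step(x, h_prev, params):
    """One GRU step; params maps block name -> (W, b), W of shape hidden x (input + hidden)."""
    z = [sigmoid(v) for v in affine(*params["update"], x + h_prev)]  # update gate
    r = [sigmoid(v) for v in affine(*params["reset"], x + h_prev)]   # reset gate
    reset_h = [ri * hi for ri, hi in zip(r, h_prev)]
    n = [math.tanh(v) for v in affine(*params["candidate"], x + reset_h)]  # candidate state
    # No output gate and no cell state: the new hidden state blends candidate and old state.
    return [(1 - zi) * ni + zi * hi for zi, ni, hi in zip(z, n, h_prev)]

def gru_param_count(input_size: int, hidden_size: int) -> int:
    """Three (W, b) blocks instead of the LSTM's four."""
    return 3 * (hidden_size * (input_size + hidden_size) + hidden_size)

def zero_params(input_size, hidden_size):
    """Hypothetical helper: all-zero weights and biases for every block."""
    return {name: ([[0.0] * (input_size + hidden_size) for _ in range(hidden_size)],
                   [0.0] * hidden_size)
            for name in ("update", "reset", "candidate")}

h = gru_step([0.3], [1.0, -1.0], zero_params(1, 2))
print(h, gru_param_count(192, 192))
```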