Answer Synthesis - Given a text paragraph, generate an answer.
Question Synthesis - Given a text paragraph and an answer, generate a question.
Answer Synthesis Network
Given the labelled dataset for the source domain DS, generate a labelled dataset of <word, tag> pairs such that each word in the given paragraph is assigned one of the following 4 tags:
IOBstart - if it is the starting word of an answer
IOBmid - if it is an intermediate word of an answer
IOBend - if it is the ending word of an answer
IOBnone - if it is not part of any answer
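The tagging scheme above can be sketched as a small helper that converts answer spans into per-word tags. The function name and the choice to tag a one-word answer as IOBstart are assumptions for illustration:

```python
def tag_words(num_words, answer_spans):
    """Assign one of the four IOB tags to each word position.

    answer_spans: list of (start, end) inclusive word indices of answers.
    """
    tags = ["IOBnone"] * num_words
    for start, end in answer_spans:
        tags[start] = "IOBstart"          # first word of the answer
        if end > start:
            tags[end] = "IOBend"          # last word of the answer
            for i in range(start + 1, end):
                tags[i] = "IOBmid"        # everything in between
    return tags

print(tag_words(6, [(1, 3)]))
# ['IOBnone', 'IOBstart', 'IOBmid', 'IOBend', 'IOBnone', 'IOBnone']
```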
For training, map the words to their GloVe embeddings and pass them through a Bi-LSTM. Next, pass the outputs through two FC layers followed by a softmax layer.
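A minimal PyTorch sketch of that tagger architecture, assuming pre-looked-up GloVe vectors as input; the hyperparameters and class name are illustrative, not taken from the source:

```python
import torch
import torch.nn as nn

class AnswerTagger(nn.Module):
    """Bi-LSTM over GloVe embeddings, two FC layers, softmax over the 4 tags."""

    def __init__(self, emb_dim=300, hidden=128, num_tags=4):
        super().__init__()
        self.lstm = nn.LSTM(emb_dim, hidden, bidirectional=True, batch_first=True)
        self.fc1 = nn.Linear(2 * hidden, hidden)
        self.fc2 = nn.Linear(hidden, num_tags)

    def forward(self, glove_vectors):            # (batch, seq_len, emb_dim)
        h, _ = self.lstm(glove_vectors)          # (batch, seq_len, 2*hidden)
        logits = self.fc2(torch.relu(self.fc1(h)))
        return torch.log_softmax(logits, dim=-1)  # per-word tag log-probs
```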
For the target domain DT, all consecutive word spans in which no word is tagged IOBnone are returned as candidate answers.
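The span extraction above amounts to grouping maximal runs of non-IOBnone words; a sketch (function name is hypothetical):

```python
def candidate_answers(words, tags):
    """Return maximal consecutive word spans whose tags are all non-IOBnone."""
    spans, current = [], []
    for word, tag in zip(words, tags):
        if tag != "IOBnone":
            current.append(word)       # extend the current candidate span
        elif current:
            spans.append(" ".join(current))
            current = []
    if current:                        # flush a span that ends at the last word
        spans.append(" ".join(current))
    return spans

words = ["the", "Eiffel", "Tower", "is", "in", "Paris"]
tags  = ["IOBnone", "IOBstart", "IOBend", "IOBnone", "IOBnone", "IOBstart"]
print(candidate_answers(words, tags))  # ['Eiffel Tower', 'Paris']
```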
Question Synthesis Network
Given an input paragraph and a candidate answer, the Question Synthesis network generates a question one word at a time.
Map each word in the paragraph to its GloVe embedding. To each word vector, append a ‘1’ if the word is part of the candidate answer and a ‘0’ otherwise.
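The input construction above can be sketched as follows; the zero vector for out-of-vocabulary words and the function name are assumptions:

```python
import numpy as np

def build_inputs(words, answer_span, glove, dim=50):
    """Append a 1/0 answer-membership bit to each word's GloVe vector.

    answer_span: (start, end) inclusive word indices of the candidate answer.
    glove: dict mapping word -> np.ndarray of shape (dim,).
    """
    start, end = answer_span
    rows = []
    for i, w in enumerate(words):
        vec = glove.get(w, np.zeros(dim))        # zero vector for OOV words (assumption)
        bit = 1.0 if start <= i <= end else 0.0  # 1 inside the candidate answer
        rows.append(np.append(vec, bit))
    return np.stack(rows)                        # (seq_len, dim + 1)
```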
Feed the sequence to a Bi-LSTM encoder-decoder network, where the decoder conditions on the representation generated by the encoder as well as on the question tokens generated so far. Decoding stops when the “END” token is produced.
The paragraph may contain some named entities or rare words which do not appear in the softmax vocabulary. To account for such words, a copying mechanism is also incorporated.
At each time step, a Pointer Network (CP) and a Vocabulary Predictor (VP) each produce a probability distribution over the next word, and a Latent Predictor Network decides which of the two will be used for the prediction.
At inference time, greedy decoding is used: the most likely predictor is chosen first, and then the most likely word under that predictor.
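One greedy step can be sketched as below. The function name, argument layout, and the [VP, CP] ordering of the latent-predictor probabilities are assumptions for illustration:

```python
import numpy as np

def greedy_step(latent_probs, vp_probs, cp_probs, vocab, paragraph):
    """One greedy decoding step: pick the most likely predictor,
    then the most likely word under that predictor.

    latent_probs: [P(VP), P(CP)] from the latent predictor network.
    vp_probs:     distribution over the softmax vocabulary.
    cp_probs:     distribution over paragraph positions (pointer network).
    """
    if latent_probs[0] >= latent_probs[1]:
        return vocab[int(np.argmax(vp_probs))]      # word from the softmax vocabulary
    return paragraph[int(np.argmax(cp_probs))]      # word copied from the paragraph
```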
Machine Comprehension Model
Given any MC model, first train it on the source domain DS, then fine-tune it on the artificial questions generated for DT.
Data Regularization - While fine-tuning the MC model, it is necessary to alternate between mini-batches from the source and target domains.
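The alternation can be sketched as a generator that interleaves the two batch streams, cycling over the source batches when the synthetic target set is larger (the helper name is hypothetical):

```python
from itertools import cycle

def alternate_batches(source_batches, target_batches):
    """Yield mini-batches alternating source / synthetic-target during fine-tuning."""
    for src, tgt in zip(cycle(source_batches), target_batches):
        yield src
        yield tgt

# Example: two source batches interleaved with three target batches.
print(list(alternate_batches(["s1", "s2"], ["t1", "t2", "t3"])))
# ['s1', 't1', 's2', 't2', 's1', 't3']
```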
At inference time, the fine-tuned MC model gives the distributions P(i=start) and P(i=end) (the likelihood of word i being the starting or ending word of the answer) for all words, and dynamic programming (DP) is used to find the optimal answer span.
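The optimal-span search, maximising P(i=start) * P(j=end) over all i <= j, can be done in a single linear pass by tracking the best start seen so far; a sketch (the source does not specify the exact algorithm, so this is one standard choice):

```python
def best_span(p_start, p_end):
    """Return (i, j) with i <= j maximising p_start[i] * p_end[j], in O(n)."""
    best, best_prob, argmax_start = (0, 0), 0.0, 0
    for j in range(len(p_end)):
        if p_start[j] > p_start[argmax_start]:
            argmax_start = j                      # best start position up to j
        prob = p_start[argmax_start] * p_end[j]
        if prob > best_prob:
            best_prob, best = prob, (argmax_start, j)
    return best

print(best_span([0.1, 0.6, 0.1, 0.2], [0.1, 0.2, 0.5, 0.2]))  # (1, 2)
```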
Checkpoint Averaging - Average the answer likelihoods from the different checkpointed models before running DP.
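A sketch of the averaging step, assuming each checkpoint contributes a (p_start, p_end) pair of per-word likelihood arrays:

```python
import numpy as np

def averaged_likelihoods(checkpoint_outputs):
    """Average per-word start/end likelihoods across checkpointed models
    before running the span search."""
    p_start = np.mean([ps for ps, _ in checkpoint_outputs], axis=0)
    p_end   = np.mean([pe for _, pe in checkpoint_outputs], axis=0)
    return p_start, p_end
```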
Using the synthetically generated dataset gives a ~2% improvement in F1 score (SQuAD -> NewsQA transfer). Averaging checkpointed models further improves performance to an overall 46.6% F1 score, which narrows the gap to a model trained on NewsQA itself (~52.3% F1).