LAMBADA Method: How to use Data Augmentation in NLU?
In this tutorial, I will walk you through the implementation to reproduce LAMBADA.
My previous article illustrates the basic idea of the LAMBADA method, which leverages Natural Language Generation (NLG) to augment the training set for Natural Language Understanding (NLU) tasks such as text classification.
Before you dive into the code, you may want to have a look at that article for the fundamental ideas and the overall workflow.
Step 1: Preparation
- We use DistilBERT as the classification model and GPT-2 as the text generation model. For both, we load pretrained weights and fine-tune them.
- For GPT-2, we use the Huggingface Transformers library to bootstrap a pretrained model and subsequently fine-tune it.
- To load and fine-tune DistilBERT, we use ktrain, a library that provides a high-level interface for language models and eliminates the need to worry about tokenization and other pre-processing tasks.
```python
!pip install ktrain
```
Step 2: Load Data
Then we load the data from the CSV file, which can be obtained from my repository, and split it into a training set, a validation set, and a test set.
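The loading and splitting step could look roughly like the sketch below; the file name, the column names Label/Text, and the split ratios are assumptions, so adapt them to the repository layout.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# load the labelled utterances; 'data.csv' and the 70/15/15 split are assumptions
data = pd.read_csv('data.csv')
data_train, data_rest = train_test_split(
    data, test_size=0.3, stratify=data['Label'], random_state=42)
data_valid, data_test = train_test_split(
    data_rest, test_size=0.5, stratify=data_rest['Label'], random_state=42)
```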
```python
labels = data_train['Label'].unique()
```
You can see the labels are:
```text
array(['label15', 'label7', 'label0', 'label13', 'label9', 'label8',
       'label2', 'label4', 'label1', 'label10', 'label5', 'label3',
       'label14', 'label11', 'label12', 'label6'], dtype=object)
```
Step 3: Training the Initial Intent Classifier (DistilBERT)
Initialize model and learner
```python
import ktrain
from ktrain import text
```
We download the pretrained DistilBERT model, transform the training and validation data from raw text into the format required by our model, and initialize a learner object, which ktrain uses to train the model.
```python
distil_bert = text.Transformer('distilbert-base-cased',
```
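The snippet above is cut off; the full initialization could look roughly like the sketch below. The maxlen and batch_size values and the variable names for the validation split are assumptions.

```python
import ktrain
from ktrain import text

# wrap the pretrained checkpoint; maxlen truncates/pads utterances to 80 tokens (assumption)
distil_bert = text.Transformer('distilbert-base-cased', maxlen=80, class_names=list(labels))

# turn raw texts and labels into model-ready datasets
processed_train = distil_bert.preprocess_train(
    data_train['Text'].values, data_train['Label'].values)
processed_valid = distil_bert.preprocess_test(
    data_valid['Text'].values, data_valid['Label'].values)

# build the classifier and the ktrain learner object used for training
model = distil_bert.get_classifier()
learner = ktrain.get_learner(
    model, train_data=processed_train, val_data=processed_valid, batch_size=16)
```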
Train classifier
- Train classifier for given learning rate and number of epochs.
- The number of epochs chosen depends on the size of your training data set.
- Make sure to monitor the accuracies and losses!
Now it’s time to train the model. We feed the training data to the network multiple times, as specified by the number of epochs. In the beginning, both monitored metrics should indicate an improving model with each epoch: the loss should decrease and the accuracy should increase. After training for a while, however, the validation loss will start to increase and the validation accuracy will drop. This is a sign of overfitting the training data, and it is time to stop feeding the same data to the network.
The optimal number of epochs depends on your data set, model, and training parameters. If you do not know the right number of epochs beforehand, you can train for a generous number of epochs and activate checkpoints by setting the checkpoint_folder parameter, then select the best-performing model afterwards.
```python
N_TRAINING_EPOCHS = 1
```
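A training call could then look like this; the learning rate of 5e-5 is a common choice for transformer fine-tuning, not a value prescribed here.

```python
# one-cycle training; checkpoint_folder stores the weights after each epoch
# so the best-performing model can be reloaded later
learner.fit_onecycle(5e-5, N_TRAINING_EPOCHS, checkpoint_folder='checkpoints')
```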
Evaluate trained predictor
To check the performance of our trained classifier, we use the test data in the eval.csv file.
```python
predictor = ktrain.get_predictor(learner.model, preproc=distil_bert)
```
Note that thanks to the KTrain interface we can simply feed the list of utterances to the predictor without the need to pre-process the raw strings beforehand.
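Putting it together, the evaluation might look roughly like this, assuming eval.csv uses the same Label/Text columns as the training data.

```python
import pandas as pd
from sklearn.metrics import accuracy_score

data_eval = pd.read_csv('eval.csv')

# raw strings go straight into the predictor; ktrain handles tokenization internally
predictions = predictor.predict(data_eval['Text'].tolist())
print('Accuracy: {:.2%}'.format(accuracy_score(data_eval['Label'], predictions)))
```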
Prepare model for download
```python
import datetime
```
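One possible way to version and archive the trained model, assuming the timestamped folder naming that appears later in this post; the folder and archive names are illustrative.

```python
import datetime

# save the predictor (model weights + preprocessing) into a timestamped folder
timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')
model_dir = 'models/baseline/_distilbert_{}epochs_{}'.format(N_TRAINING_EPOCHS, timestamp)
predictor.save(model_dir)

# archive the folder so it can be downloaded from the notebook environment
!zip -r -X distilbert_baseline.zip models/baseline
```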
Step 4: Fine-tune GPT-2 to generate utterances
Fine-tune GPT-2
To fine-tune GPT-2, we use a Python script made available by Huggingface on their Github repository: https://github.com/huggingface/transformers
Put the transformed dataset in the directory where this Jupyter notebook is located so that the Python script can run smoothly.
```python
utterance_file = data_train[['Label', 'Text']]
```
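The fine-tuning script expects a plain-text training file; one possible way to write it is sketched below. The exact line format (here label, comma, utterance, end-of-text token) is an assumption and must match whatever finetune_gpt.py is set up to parse.

```python
# write one training example per line: "<label>,<utterance> <|endoftext|>"
with open('utterances.txt', 'w', encoding='utf-8') as f:
    for _, row in utterance_file.iterrows():
        f.write('{},{} <|endoftext|>\n'.format(row['Label'], row['Text']))
```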
Among others, we specify the following parameters:
- the pretrained model that we want to use (gpt2-medium). Larger models typically generate better text outputs. Please note that these models require a large amount of memory during training, so make sure you pick a model that fits into your (GPU) memory.
- the number of epochs. This parameter specifies how many times the training data is fed through the network. If the number of epochs is too small, the model will not learn to generate useful utterances. If it is too large, the model will likely overfit and the variability in the generated text will be limited: the model will basically just memorize the training data.
- the batch size. This determines how many utterances are used for training in parallel. The larger the batch size, the faster the training; larger batch sizes require more memory, though.
- the block size. The block size defines an upper bound on the number of tokens considered from each training instance. Make sure that this number is large enough so that utterances are not cropped.
```python
!python finetune_gpt.py \
```
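If finetune_gpt.py is a copy of Huggingface's run_language_modeling.py script, the full call could look roughly as follows. Flag names vary slightly between Transformers versions, and the epoch, batch-size, and block-size values below are illustrative only.

```python
!python finetune_gpt.py \
    --model_type gpt2 \
    --model_name_or_path gpt2-medium \
    --train_data_file utterances.txt \
    --output_dir output \
    --do_train \
    --num_train_epochs 3 \
    --per_device_train_batch_size 2 \
    --block_size 128 \
    --line_by_line \
    --overwrite_output_dir
```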
Load and Manually Test Model
You can play around with the model, generating utterances for different intents. See how the parameters top_k and top_p influence the result.
```python
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel
```
Top-k sampling means sorting the next-token distribution by probability and zeroing out the probabilities for anything below the k-th token. It appears to improve quality by removing the tail and making it less likely to go off topic. But in some cases there really are many words we could sample from reasonably (a broad distribution), and in some cases there aren't (a narrow distribution).
To address this problem, the authors propose top-p sampling, aka nucleus sampling, in which we compute the cumulative distribution and cut off as soon as the CDF exceeds p. In the broad-distribution case, it may take the top 100 tokens to exceed top_p = 0.9. In the narrow-distribution case, we may already exceed top_p = 0.9 with just "hot" and "warm" in our sample distribution. In this way, we still avoid sampling egregiously wrong tokens, but preserve variety when the highest-scoring tokens have low confidence.
```python
input_ids = tokenizer.encode('i m trying to', return_tensors='tf')
```
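A complete sampling cell could look like this; the output directory name and the decoding parameters are assumptions, so adjust them to your fine-tuning run.

```python
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

# load the fine-tuned model from the fine-tuning output directory (assumed to be 'output')
tokenizer = GPT2Tokenizer.from_pretrained('output')
model = TFGPT2LMHeadModel.from_pretrained('output')

input_ids = tokenizer.encode('i m trying to', return_tensors='tf')
samples = model.generate(
    input_ids,
    do_sample=True,            # sample instead of greedy decoding
    max_length=50,             # maximum length of the generated sequence
    top_k=50,                  # keep only the 50 most likely next tokens
    top_p=0.9,                 # nucleus sampling: cut off once the CDF exceeds 0.9
    num_return_sequences=10,   # generate ten candidate utterances
)
for i, sample in enumerate(samples):
    print('{}: {}'.format(i, tokenizer.decode(sample, skip_special_tokens=True)))
```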
```text
0: i m trying to get a list of all the words in the wordlist and their synonyms. it s a wordlist of about 10k words. i m trying to do a word frequency plot. the word frequency plot shows that the frequency of
1: i m trying to find out if there are any books on bitcoin cash which are free for anyone to download. the author of the book, jr. b. lang, is a bitcoin cash expert and has been quoted in many publications as saying that
2: i m trying to find a way to get my iphone from to my apple apple tv via usb. my iphone is 3rd gen and apple tv 2nd gen.
3: i m trying to determine if there are two sets of rules for a particular problem that can be applied to any other problem. one set of rules is for discrete cases and the second set is for continuous cases.
4: i m trying to understand how a node can be a node in a dapp. i m trying to understand how a node can be a node in a dapp.
5: i m trying to do a pulldown on a website with a template.i m using jquery.getElementsByTagName('meta') and the following output
6: i m trying to determine the best way to store my bitcoin. my bitcoin is stored on my master wallet. however, i want to add the bch address from my wallet to my bitcoin. my master wallet only has the private key of the wallet
7: i m trying to find a way to show the number of lines in the code of a function. i m using the following code snippet from the os x man page:
8: i m trying to find the code for my samsung galaxy s3. the s3 is a s1 with an ikon gsm camera. the camera is a 1.5m lens. it also has the moto g
9: i m trying to figure out a way to find out the time and date of the most recent call. i m using the time-to-call function from the samsung s go-pro app, but the function does not work on my device
```
Prepare model for download
```python
!zip -r -X gpt-2_tuned.zip ./content/transformers/output
```
```text
adding: /content/transformers/output/ (stored 0%)
adding: /content/transformers/output/tokenizer_config.json (deflated 37%)
adding: /content/transformers/output/vocab.json (deflated 59%)
adding: /content/transformers/output/config.json (deflated 51%)
adding: /content/transformers/output/training_args.bin (deflated 44%)
adding: /content/transformers/output/merges.txt (deflated 53%)
adding: /content/transformers/output/special_tokens_map.json (deflated 52%)
adding: /content/transformers/output/checkpoint-500/ (stored 0%)
adding: /content/transformers/output/checkpoint-500/optimizer.pt (deflated 9%)
adding: /content/transformers/output/checkpoint-500/tokenizer_config.json (deflated 37%)
adding: /content/transformers/output/checkpoint-500/vocab.json (deflated 59%)
adding: /content/transformers/output/checkpoint-500/config.json (deflated 51%)
adding: /content/transformers/output/checkpoint-500/training_args.bin (deflated 44%)
adding: /content/transformers/output/checkpoint-500/merges.txt (deflated 53%)
adding: /content/transformers/output/checkpoint-500/special_tokens_map.json (deflated 52%)
adding: /content/transformers/output/checkpoint-500/scheduler.pt (deflated 49%)
adding: /content/transformers/output/checkpoint-500/pytorch_model.bin (deflated 9%)
adding: /content/transformers/output/pytorch_model.bin (deflated 9%)
```
Step 5: Generate and Filter New Utterances
We now generate the new utterances for all intents. To have a sufficiently large sample that we can choose the best utterances from, we generate 200 per intent.
```python
NUMBER_OF_GENERATED_UTTERANCES_PER_INTENT = 200
```
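A generation helper along these lines is assumed in the next cell; the prompt format must mirror the line format used for fine-tuning, and the decoding parameters are again just reasonable defaults.

```python
def generate_utterances(model, tokenizer, intent, n_samples):
    """Generate n_samples candidate utterances for one intent."""
    # prompt the model with the same "<label>," prefix it saw during fine-tuning
    input_ids = tokenizer.encode('{},'.format(intent), return_tensors='tf')
    outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=60,
        top_k=50,
        top_p=0.9,
        num_return_sequences=n_samples,
    )
    decoded = [tokenizer.decode(o, skip_special_tokens=True) for o in outputs]
    # strip the label prefix so only the generated utterance remains
    return [d.split(',', 1)[1].strip() if ',' in d else d for d in decoded]
```

In practice you may want to generate the 200 samples per intent in smaller batches to stay within GPU memory.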
Generate the result by calling the function above:
```python
labels = data_train["Label"].unique()
```
Save file:
```python
generated_utterances_df.to_csv("generated.csv", index=False)
```
Check the generated data with the old classifier
After a while the data is generated and we can have a closer look at it. First, we use our old DistilBERT classifier to predict the intent of every generated utterance. We also keep track of the prediction probability, which indicates the confidence of each individual prediction made by the model.
```python
generated_data = generated_utterances_df.assign(
```
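The scoring step could be completed roughly as follows, assuming the generated DataFrame has intent and utterance columns as in the tables below. Setting return_proba=True makes ktrain return the full class probability distribution, from which we take the arg-max class and its probability.

```python
# score every generated utterance with the baseline classifier
probs = predictor.predict(generated_data['utterance'].tolist(), return_proba=True)
classes = predictor.get_classes()

generated_data_predicted = generated_data.assign(
    predicted_intent=[classes[i] for i in probs.argmax(axis=1)],
    prediction_proba=probs.max(axis=1),
)
```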
Let’s have a look at some of the utterances for which the intent used for generation does not match the predicted intent.
```python
generated_data_predicted[generated_data_predicted['intent'] !=
                         generated_data_predicted['predicted_intent']]
```
| | intent | utterance | predicted_intent | prediction_proba |
|---|---|---|---|---|
| 0 | label12 | why is there a special category called coset … | label14 | 0.602601 |
| 1 | label12 | how can i set up qgis for different operating … | label13 | 0.918635 |
| 2 | label12 | is there a minimum required work weekly for eu… | label6 | 0.835881 |
| 5 | label12 | who was ryan about the arcania | label10 | 0.564748 |
| 7 | label12 | example of combining points data | label13 | 0.927397 |
| 8 | label12 | how can i switch between my the clock/utc-rpi … | label3 | 0.430493 |
| 9 | label12 | which mlb agent is responsible for taxonomy of… | label11 | 0.572228 |
| 10 | label12 | how can i prove that two points are continuous… | label14 | 0.950781 |
We can see that in some cases the prediction is clearly wrong. However, there are also cases where the prediction matches the utterance, but doesn’t match the intent used for generation. This indicates that our GPT-2 model is not perfect as it doesn’t generate matching utterances for an intent all the time.
To avoid training our classifier on corrupted data, we drop all utterances for which the intent used for generation does not match the predicted intent. Among the utterances that do match, we only keep the ones with the highest prediction probability scores.
Filter generated utterances
Keep only the utterances for which the prediction of the old classifier matches the intent used for generation:
```python
correctly_predicted_data = generated_data_predicted[
    generated_data_predicted['intent'] == generated_data_predicted['predicted_intent']]
```
Check for the number of unique utterances per intent:
```python
correctly_predicted_data.drop_duplicates(
```
Take the TOP_N predictions per intent according to probability and drop duplicates
We can see that for each intent, there are at least 35 mutually distinct utterances. To keep a balanced data set, we pick the top 30 utterances per intent according to the prediction probability.
```python
TOP_N = 30
```
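The selection itself can be done with a group-wise head after sorting by confidence; this is a sketch, not necessarily the exact code used in the notebook.

```python
# keep the TOP_N most confidently predicted, de-duplicated utterances per intent
top_predictions_per_intent = (
    correctly_predicted_data
    .drop_duplicates(subset=['intent', 'utterance'])
    .sort_values('prediction_proba', ascending=False)
    .groupby('intent')
    .head(TOP_N)
)
```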
```python
top_predictions_per_intent
```
| | intent | utterance | predicted_intent | prediction_proba |
|---|---|---|---|---|
| 103 | label0 | adding jquery files from within a wordpress theme | label0 | 0.810196 |
| 66 | label0 | formatting wordpress content | label0 | 0.808586 |
| 131 | label0 | add.php to front page of wordpress | label0 | 0.807943 |
| 177 | label0 | adding wordpress in post_meta subcategory | label0 | 0.804899 |
| 51 | label0 | is there a way to change how views are display… | label0 | 0.801144 |
| … | … | … | … | … |
| 36 | label9 | new player | label9 | 0.815141 |
| 147 | label9 | is there a way to get a different card with ea… | label9 | 0.772208 |
| 117 | label9 | how do i recover my lost data | label9 | 0.768557 |
| 49 | label9 | i need to save my first pakistani boy in dlc 2… | label9 | 0.757511 |
| 15 | label9 | how can i bypass game engine antiophthalmic fa… | label9 | 0.738432 |
Step 6: Train the Intent Classifier with Augmented Data
Combine old and augmented data
We now combine the generated data with the initial training data and split the enriched data set into training and validation data.
```python
data_train_aug = data_train.append(top_predictions_per_intent[['intent', 'utterance']].rename(
```
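The snippet above is truncated; an equivalent, complete version using pd.concat (which replaces the deprecated DataFrame.append) could look like this, with the split ratio being an assumption.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# rename the generated columns so they line up with the original training data
augmented = top_predictions_per_intent[['intent', 'utterance']].rename(
    columns={'intent': 'Label', 'utterance': 'Text'})
data_train_aug = pd.concat([data_train[['Label', 'Text']], augmented], ignore_index=True)

# re-split the enriched data into training and validation sets
train_aug, valid_aug = train_test_split(
    data_train_aug, test_size=0.2, stratify=data_train_aug['Label'], random_state=42)
```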
Initialise augmented model and learner
Now it’s time to train our new intent classification model. The code is similar to the one above:
```python
distil_bert_augmented = text.Transformer('distilbert-base-cased',
```
```python
processed_train_aug = distil_bert_augmented.preprocess_train(
```
Train classifier
Train classifier for given learning rate and number of epochs.
```python
N_TRAINING_EPOCHS_AUGMENTED = 5
```
Evaluate trained predictor
```python
predictor_aug = ktrain.get_predictor(
```
Accuracy: 87.85%
Prepare model for download
```python
predictor.save('models/augmented/_distilbert_aug_{}epochs_{}'.format(
```
```text
adding: models/augmented/ (stored 0%)
adding: models/augmented/_distilbert_aug_5epochs_2021-04-01-15-13-40-763970/ (stored 0%)
adding: models/augmented/_distilbert_aug_5epochs_2021-04-01-15-13-40-763970/config.json (deflated 61%)
adding: models/augmented/_distilbert_aug_5epochs_2021-04-01-15-13-40-763970/tf_model.h5 (deflated 8%)
adding: models/augmented/_distilbert_aug_5epochs_2021-04-01-15-13-40-763970/tf_model.preproc (deflated 55%)
```
LAMBADA AI: Summary
We employed the LAMBADA method to augment data used for Natural Language Understanding (NLU) tasks. We trained a GPT-2 model to generate new training utterances and utilized them as training data for our intent classification model (DistilBERT). The performance of the intent classification model improved by at least 4% in each of our tests.
Additionally, we saw that high-level libraries such as KTrain and Huggingface Transformers help to reduce the complexity of applying state-of-the-art transformer models for Natural Language Generation (NLG) and other Natural Language Processing (NLP) tasks such as classification and make these approaches broadly applicable.