Text Prediction - Customization

This tutorial shows how to customize the hyperparameters of TextPredictor.

import numpy as np
import warnings
import autogluon as ag
warnings.filterwarnings('ignore')
np.random.seed(123)

Stanford Sentiment Treebank Data

For demonstration, we use the Stanford Sentiment Treebank (SST) dataset.

from autogluon.core import TabularDataset
subsample_size = 1000  # subsample for a faster demo; try a larger value to train on more data
train_data = TabularDataset('https://autogluon-text.s3-accelerate.amazonaws.com/glue/sst/train.parquet')
test_data = TabularDataset('https://autogluon-text.s3-accelerate.amazonaws.com/glue/sst/dev.parquet')
train_data = train_data.sample(n=subsample_size, random_state=0)
train_data.head(10)
       sentence                                           label
43787  very pleasing at its best moments                  1
16159  , american chai is enough to make you put away...  0
59015  too much like an infomercial for ram dass 's l...  0
5108   a stirring visual sequence                         1
67052  cool visual backmasking                            1
35938  hard ground                                        0
49879  the striking , quietly vulnerable personality ...  1
51591  pan nalin 's exposition is beautiful and myste...  1
56780  wonderfully loopy                                  1
28518  most beautiful , evocative                         1

Configure TextPredictor

Preset Configurations

TextPredictor provides several simple preset configurations. Let’s take a look at the available presets.

from autogluon.text.text_prediction.presets import list_text_presets
list_text_presets()
['default',
 'medium_quality_faster_train',
 'high_quality',
 'best_quality',
 'multilingual']

To see the configuration differences behind the preset strings, pass verbose=True:

list_text_presets(verbose=True)
{'default': {'model.hf_text.checkpoint_name': 'google/electra-base-discriminator',
  'optimization.lr_decay': 0.9},
 'medium_quality_faster_train': {'model.hf_text.checkpoint_name': 'google/electra-small-discriminator',
  'optimization.learning_rate': 0.0004,
  'optimization.lr_decay': 0.9},
 'high_quality': {'model.hf_text.checkpoint_name': 'google/electra-base-discriminator'},
 'best_quality': {'model.hf_text.checkpoint_name': 'microsoft/deberta-v3-base',
  'optimization.lr_decay': 0.9,
  'env.per_gpu_batch_size': 2},
 'multilingual': {'model.hf_text.checkpoint_name': 'microsoft/mdeberta-v3-base',
  'optimization.top_k': 1,
  'optimization.lr_decay': 0.9,
  'env.precision': 'bf16',
  'env.per_gpu_batch_size': 4}}

The main difference between the presets lies in the choice of Hugging Face Transformers checkpoint (model.hf_text.checkpoint_name). Presets default and high_quality use the same checkpoint, google/electra-base-discriminator, but default additionally sets optimization.lr_decay to 0.9. The presets follow the rule of thumb that larger backbones tend to perform better but at a higher cost.
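To check exactly how two presets differ, you can diff the dictionaries printed by list_text_presets(verbose=True). The sketch below copies three of those presets by hand and compares them with diff_presets, a hypothetical helper written for this tutorial (it is not part of AutoGluon). A key that a preset does not set presumably falls back to AutoGluon's built-in default, whose value is not shown in the output above, so it is reported here only as "<not set>".

```python
# A subset of the presets, copied from the list_text_presets(verbose=True) output above.
PRESETS = {
    "default": {
        "model.hf_text.checkpoint_name": "google/electra-base-discriminator",
        "optimization.lr_decay": 0.9,
    },
    "high_quality": {
        "model.hf_text.checkpoint_name": "google/electra-base-discriminator",
    },
    "medium_quality_faster_train": {
        "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
        "optimization.learning_rate": 0.0004,
        "optimization.lr_decay": 0.9,
    },
}

_MISSING = "<not set>"  # placeholder for keys a preset leaves at the library default


def diff_presets(a, b):
    """Return {key: (value_in_a, value_in_b)} for every key whose values differ."""
    keys = sorted(set(a) | set(b))
    return {
        key: (a.get(key, _MISSING), b.get(key, _MISSING))
        for key in keys
        if a.get(key, _MISSING) != b.get(key, _MISSING)
    }


for key, (left, right) in diff_presets(PRESETS["default"], PRESETS["high_quality"]).items():
    print(f"{key}: default={left!r}, high_quality={right!r}")
```

Comparing default with high_quality this way shows that the only explicit difference is optimization.lr_decay.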

Let’s train a text predictor with preset medium_quality_faster_train.

from autogluon.text import TextPredictor
predictor = TextPredictor(eval_metric='acc', label='label')
predictor.fit(
    train_data=train_data,
    presets='medium_quality_faster_train',
    time_limit=60,
)
Global seed set to 123
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                         | Params
-------------------------------------------------------------------
0 | model             | HFAutoModelForTextPrediction | 13.5 M
1 | validation_metric | Accuracy                     | 0
2 | loss_func         | CrossEntropyLoss             | 0
-------------------------------------------------------------------
13.5 M    Trainable params
0         Non-trainable params
13.5 M    Total params
26.967    Total estimated model params size (MB)
Global seed set to 123
Epoch 0, global step 3: val_acc reached 0.45000 (best 0.45000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054821/epoch=0-step=3.ckpt" as top 3
Epoch 0, global step 6: val_acc reached 0.55500 (best 0.55500), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054821/epoch=0-step=6.ckpt" as top 3
Epoch 1, global step 10: val_acc reached 0.59500 (best 0.59500), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054821/epoch=1-step=10.ckpt" as top 3
Epoch 1, global step 13: val_acc reached 0.65500 (best 0.65500), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054821/epoch=1-step=13.ckpt" as top 3
Epoch 2, global step 17: val_acc reached 0.59500 (best 0.65500), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054821/epoch=2-step=17.ckpt" as top 3
Epoch 2, global step 20: val_acc reached 0.82000 (best 0.82000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054821/epoch=2-step=20.ckpt" as top 3
Epoch 3, global step 24: val_acc reached 0.80500 (best 0.82000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054821/epoch=3-step=24.ckpt" as top 3
Epoch 3, global step 27: val_acc reached 0.84000 (best 0.84000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054821/epoch=3-step=27.ckpt" as top 3
Epoch 4, global step 31: val_acc was not in top 3
Epoch 4, global step 34: val_acc reached 0.84500 (best 0.84500), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054821/epoch=4-step=34.ckpt" as top 3
Epoch 5, global step 38: val_acc reached 0.84000 (best 0.84500), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054821/epoch=5-step=38.ckpt" as top 3
Epoch 5, global step 41: val_acc was not in top 3
Epoch 6, global step 45: val_acc was not in top 3
Epoch 6, global step 48: val_acc reached 0.84500 (best 0.84500), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054821/epoch=6-step=48.ckpt" as top 3
Epoch 7, global step 52: val_acc was not in top 3
Epoch 7, global step 55: val_acc was not in top 3
Epoch 8, global step 59: val_acc was not in top 3
Epoch 8, global step 62: val_acc was not in top 3
Epoch 9, global step 66: val_acc was not in top 3
Time limit reached. Elapsed time is 0:01:00. Signaling Trainer to stop.
<autogluon.text.text_prediction.predictor.TextPredictor at 0x7f18ca4b0dc0>
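The log reports that, at each validation check, a checkpoint is saved only if its val_acc ranks among the three best seen so far ("as top 3"); otherwise it prints "was not in top 3". These messages appear to come from PyTorch Lightning's checkpointing callback. The following is a simplified, stdlib-only sketch of that bookkeeping, not the actual AutoGluon or Lightning implementation: it keeps a min-heap of the k best scores and uses a strict comparison against the worst kept score, so the real tie-breaking rule may differ. The scores are the val_acc values printed above; step 31 is logged without a value, so 0.80 is a hypothetical stand-in, and the later "not in top 3" steps are omitted.

```python
import heapq

# (global step, val_acc) pairs taken from the training log above.
# Step 31 is logged only as "not in top 3"; 0.80 is a HYPOTHETICAL stand-in value.
LOG_SCORES = [
    (3, 0.45), (6, 0.555), (10, 0.595), (13, 0.655), (17, 0.595), (20, 0.82),
    (24, 0.805), (27, 0.84), (31, 0.80), (34, 0.845), (38, 0.84), (48, 0.845),
]


def track_top_k(scores, k=3):
    """Simulate keeping only the k best checkpoints by validation score.

    Returns (decisions, top): decisions is a list of (step, saved) pairs and
    top is the final list of kept (score, step) pairs, best first.
    """
    heap = []  # min-heap of (score, step); heap[0] is the worst checkpoint kept
    decisions = []
    for step, score in scores:
        if len(heap) < k:
            heapq.heappush(heap, (score, step))
            decisions.append((step, True))
        elif score > heap[0][0]:
            # Strictly better than the worst kept checkpoint: evict it, keep this one.
            heapq.heapreplace(heap, (score, step))
            decisions.append((step, True))
        else:
            decisions.append((step, False))
    return decisions, sorted(heap, reverse=True)


decisions, top = track_top_k(LOG_SCORES, k=3)
for step, saved in decisions:
    print(f"step {step}: {'saved' if saved else 'not in top 3'}")
print("kept:", top)
```

With the logged values, every reported step is saved and only the hypothetical step 31 is rejected, which matches the pattern in the log. The multilingual preset's optimization.top_k entry may be related to how many checkpoints are kept; check the AutoMMPredictor documentation before relying on that.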

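Below we evaluate the predictions on test_data with both F1 and accuracy. If F1 is the metric you care about, set eval_metric='f1' when constructing the TextPredictor, so that the validation score used for checkpoint selection during training (val_acc in the log above) is F1 rather than accuracy.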

predictor.evaluate(test_data, metrics=['f1', 'acc'])
{'f1': 0.79498861047836, 'acc': 0.7935779816513762}
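F1 and accuracy can diverge, especially when the classes are imbalanced. As a reference for what the two numbers measure, here is a minimal stdlib sketch that computes accuracy and binary F1 (the harmonic mean of precision and recall on the positive class). It assumes label 1 is the positive class; which class AutoGluon treats as positive for 'f1' is not stated in this tutorial. The labels in the example are made up.

```python
def binary_metrics(y_true, y_pred, positive=1):
    """Return accuracy and positive-class F1 for binary labels."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)  # true positives
    fp = sum(t != positive and p == positive for t, p in pairs)  # false positives
    fn = sum(t == positive and p != positive for t, p in pairs)  # false negatives
    acc = sum(t == p for t, p in pairs) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"f1": f1, "acc": acc}


# Made-up labels: 3 positives out of 10; the classifier finds only one of them.
y_true = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
y_pred = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
print(binary_metrics(y_true, y_pred))
```

Here accuracy is 0.8 while F1 is about 0.5, because most of the few positive examples are missed. On SST the two scores are close, which is what you would expect when both classes are reasonably represented.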

Custom Hyperparameter Values

The presets provide reasonable default hyperparameters. A common workflow is to first train a model with one of the presets and then tune some hyperparameters to see whether performance can be improved further.
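One way to organize this "start from a preset, then tune" workflow is to enumerate a few candidate overrides layered on top of a preset's values. The sketch below does this with itertools.product; candidate_configs and the search-space values are hypothetical, chosen only for illustration. Training each candidate requires passing the overrides to the predictor; the AutoMMPredictor tutorial referenced in the next paragraph describes how hyperparameters are customized, so that step is not shown here.

```python
import itertools

# The medium_quality_faster_train preset, copied from list_text_presets(verbose=True).
MEDIUM_QUALITY_FASTER_TRAIN = {
    "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
    "optimization.learning_rate": 0.0004,
    "optimization.lr_decay": 0.9,
}

# HYPOTHETICAL search space: illustrative values, not recommendations.
SEARCH_SPACE = {
    "optimization.learning_rate": [0.0001, 0.0004],
    "optimization.lr_decay": [0.8, 0.9],
}


def candidate_configs(preset, search_space):
    """Yield one config per combination of search-space values, layered over the preset.

    Keys not in search_space keep the preset's value; the preset itself is not modified.
    """
    keys = sorted(search_space)
    for values in itertools.product(*(search_space[key] for key in keys)):
        config = dict(preset)  # shallow copy is enough: the values are plain scalars/strings
        config.update(zip(keys, values))
        yield config


for config in candidate_configs(MEDIUM_QUALITY_FASTER_TRAIN, SEARCH_SPACE):
    print(config)
```

Keeping the preset as the base and varying only a couple of keys makes it easy to attribute any change in validation score to the values you actually changed.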

TextPredictor is built on top of AutoMMPredictor, which has a flexible and easy-to-use configuration design. See the tutorial "AutoMMPredictor for Image, Text, and Tabular" for how to customize hyperparameters.