Text Prediction - Customization
This tutorial introduces the presets of TextPredictor and shows how to customize its hyperparameters.
import numpy as np
import warnings
import autogluon as ag
warnings.filterwarnings("ignore")
np.random.seed(123)
Stanford Sentiment Treebank Data
For demonstration, we use the Stanford Sentiment Treebank (SST) dataset.
from autogluon.core import TabularDataset
subsample_size = 1000 # subsample for faster demo, you may try specifying larger value
train_data = TabularDataset("https://autogluon-text.s3-accelerate.amazonaws.com/glue/sst/train.parquet")
test_data = TabularDataset("https://autogluon-text.s3-accelerate.amazonaws.com/glue/sst/dev.parquet")
train_data = train_data.sample(n=subsample_size, random_state=0)
train_data.head(10)
| | sentence | label |
|---|---|---|
| 43787 | very pleasing at its best moments | 1 |
| 16159 | , american chai is enough to make you put away... | 0 |
| 59015 | too much like an infomercial for ram dass 's l... | 0 |
| 5108 | a stirring visual sequence | 1 |
| 67052 | cool visual backmasking | 1 |
| 35938 | hard ground | 0 |
| 49879 | the striking , quietly vulnerable personality ... | 1 |
| 51591 | pan nalin 's exposition is beautiful and myste... | 1 |
| 56780 | wonderfully loopy | 1 |
| 28518 | most beautiful , evocative | 1 |
Configure TextPredictor
Preset Configurations
TextPredictor provides several simple preset configurations. Let’s take a look at the available presets.
from autogluon.text.text_prediction.presets import list_text_presets
list_text_presets()
['default',
'medium_quality_faster_train',
'high_quality',
'best_quality',
'multilingual']
You may be interested in the configuration differences behind the preset strings.
list_text_presets(verbose=True)
{'default': {'model.hf_text.checkpoint_name': 'google/electra-base-discriminator'},
'medium_quality_faster_train': {'model.hf_text.checkpoint_name': 'google/electra-small-discriminator',
'optimization.learning_rate': 0.0004},
'high_quality': {'model.hf_text.checkpoint_name': 'google/electra-base-discriminator'},
'best_quality': {'model.hf_text.checkpoint_name': 'microsoft/deberta-v3-base',
'env.per_gpu_batch_size': 2},
'multilingual': {'model.hf_text.checkpoint_name': 'microsoft/mdeberta-v3-base',
'optimization.top_k': 1,
'env.precision': 'bf16',
'env.per_gpu_batch_size': 4}}
The main difference between the presets lies in the choice of Hugging Face transformer checkpoint; some presets also adjust other options such as the learning rate, the per-GPU batch size, or the numeric precision. The default preset has the same configuration as the high_quality preset. The presets follow the rule of thumb that larger backbones tend to perform better but cost more to train and run.
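Because every preset is just a small dictionary of overrides, comparing two presets amounts to comparing dictionaries. The sketch below copies the output of list_text_presets(verbose=True) shown above and reports which keys differ. preset_diff and keys_overridden are hypothetical helpers written for this illustration, not AutoGluon functions, and the sketch assumes that a key a preset does not list keeps TextPredictor's base value.

```python
# Preset overrides copied from the list_text_presets(verbose=True) output above.
PRESETS = {
    "default": {"model.hf_text.checkpoint_name": "google/electra-base-discriminator"},
    "medium_quality_faster_train": {
        "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
        "optimization.learning_rate": 0.0004,
    },
    "high_quality": {"model.hf_text.checkpoint_name": "google/electra-base-discriminator"},
    "best_quality": {
        "model.hf_text.checkpoint_name": "microsoft/deberta-v3-base",
        "env.per_gpu_batch_size": 2,
    },
    "multilingual": {
        "model.hf_text.checkpoint_name": "microsoft/mdeberta-v3-base",
        "optimization.top_k": 1,
        "env.precision": "bf16",
        "env.per_gpu_batch_size": 4,
    },
}


def preset_diff(a: dict, b: dict) -> dict:
    """Return {key: (value_in_a, value_in_b)} for every key whose value differs.

    A key missing from one preset is reported as None on that side
    (assumed to mean "keep the base configuration value").
    """
    keys = sorted(set(a) | set(b))
    return {k: (a.get(k), b.get(k)) for k in keys if a.get(k) != b.get(k)}


def keys_overridden(presets: dict) -> dict:
    """Count how many presets override each hyperparameter key."""
    counts: dict = {}
    for overrides in presets.values():
        for key in overrides:
            counts[key] = counts.get(key, 0) + 1
    return counts


if __name__ == "__main__":
    print(preset_diff(PRESETS["default"], PRESETS["high_quality"]))  # {}
    print(preset_diff(PRESETS["default"], PRESETS["best_quality"]))
    print(keys_overridden(PRESETS))
```

The empty diff between default and high_quality matches the statement above, and the override counts show that the checkpoint is the only key every preset sets.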
Let’s train a text predictor with the medium_quality_faster_train preset.
from autogluon.text import TextPredictor
predictor = TextPredictor(eval_metric="acc", label="label")
predictor.fit(
train_data=train_data,
presets="medium_quality_faster_train",
time_limit=60,
)
Global seed set to 123
Auto select gpus: [0]
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
-------------------------------------------------------------------
0 | model | HFAutoModelForTextPrediction | 13.5 M
1 | validation_metric | Accuracy | 0
2 | loss_func | CrossEntropyLoss | 0
-------------------------------------------------------------------
13.5 M Trainable params
0 Non-trainable params
13.5 M Total params
26.967 Total estimated model params size (MB)
Epoch 0, global step 3: 'val_acc' reached 0.45000 (best 0.45000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211228/epoch=0-step=3.ckpt' as top 3
Epoch 0, global step 7: 'val_acc' reached 0.55500 (best 0.55500), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211228/epoch=0-step=7.ckpt' as top 3
Epoch 1, global step 10: 'val_acc' reached 0.59500 (best 0.59500), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211228/epoch=1-step=10.ckpt' as top 3
Epoch 1, global step 14: 'val_acc' reached 0.65500 (best 0.65500), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211228/epoch=1-step=14.ckpt' as top 3
Epoch 2, global step 17: 'val_acc' reached 0.59500 (best 0.65500), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211228/epoch=2-step=17.ckpt' as top 3
Epoch 2, global step 21: 'val_acc' reached 0.81000 (best 0.81000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211228/epoch=2-step=21.ckpt' as top 3
Epoch 3, global step 24: 'val_acc' reached 0.81000 (best 0.81000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211228/epoch=3-step=24.ckpt' as top 3
Epoch 3, global step 28: 'val_acc' reached 0.84000 (best 0.84000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211228/epoch=3-step=28.ckpt' as top 3
Epoch 4, global step 31: 'val_acc' was not in top 3
Epoch 4, global step 35: 'val_acc' reached 0.84000 (best 0.84000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211228/epoch=4-step=35.ckpt' as top 3
Epoch 5, global step 38: 'val_acc' reached 0.83500 (best 0.84000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211228/epoch=5-step=38.ckpt' as top 3
Epoch 5, global step 42: 'val_acc' was not in top 3
Epoch 6, global step 45: 'val_acc' reached 0.85000 (best 0.85000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211228/epoch=6-step=45.ckpt' as top 3
Epoch 6, global step 49: 'val_acc' was not in top 3
Epoch 7, global step 52: 'val_acc' was not in top 3
Epoch 7, global step 56: 'val_acc' was not in top 3
Epoch 8, global step 59: 'val_acc' was not in top 3
Epoch 8, global step 63: 'val_acc' was not in top 3
Epoch 9, global step 66: 'val_acc' was not in top 3
Time limit reached. Elapsed time is 0:01:00. Signaling Trainer to stop.
<autogluon.text.text_prediction.predictor.TextPredictor at 0x7f2e3f315af0>
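The training log records two validation checks per epoch. Each check either saves a checkpoint "as top 3" or reports that val_acc "was not in top 3", which is the message format of PyTorch Lightning's ModelCheckpoint callback when three checkpoints are retained. The replay below is a simplified illustration of that bookkeeping, not Lightning's or AutoGluon's code; it assumes a new score is kept only when it strictly beats the worst stored score, which is then evicted.

```python
def track_top_k(val_scores, k=3):
    """Replay validation scores through a max-mode top-k checkpoint tracker.

    Returns a list of (eval_index, score, kept) tuples. A score is kept if
    fewer than k checkpoints are stored, or if it is strictly greater than
    the worst stored score, which is then evicted.
    """
    stored = []  # scores of the checkpoints currently kept
    events = []
    for eval_index, score in enumerate(val_scores):
        if len(stored) < k:
            stored.append(score)
            kept = True
        else:
            worst = min(stored)
            kept = score > worst
            if kept:
                stored.remove(worst)  # evict one copy of the worst score
                stored.append(score)
        events.append((eval_index, score, kept))
    return events


if __name__ == "__main__":
    # The first eight val_acc values from the medium_quality_faster_train log above.
    logged = [0.45, 0.555, 0.595, 0.655, 0.595, 0.81, 0.81, 0.84]
    for eval_index, score, kept in track_top_k(logged):
        print(eval_index, score, "saved" if kept else "not in top 3")
```

All eight logged values come out as saved, consistent with the log; a later score that does not beat the third-best stored value would be reported as not in the top 3.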
Below we report both the f1 and acc metrics of our predictions on the test data.
predictor.evaluate(test_data, metrics=["f1", "acc"])
{'f1': 0.8080808080808082, 'acc': 0.8038990825688074}
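For a binary task such as SST, acc is the fraction of correct predictions and f1 is the harmonic mean of precision and recall for the positive class. The plain-Python sketch below spells out both definitions; treating label 1 (positive sentiment) as the positive class is an assumption made for illustration, not something this tutorial states about AutoGluon's metric implementation.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that equal the true label."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def binary_f1(y_true, y_pred, positive=1):
    """F1 for the positive class: 2 * precision * recall / (precision + recall)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)  # true positives
    fp = sum(t != positive and p == positive for t, p in pairs)  # false positives
    fn = sum(t == positive and p != positive for t, p in pairs)  # false negatives
    if tp == 0:
        return 0.0  # precision or recall is zero (or undefined), so F1 is 0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


if __name__ == "__main__":
    y_true = [1, 1, 0, 0, 1]
    y_pred = [1, 0, 0, 1, 1]
    print(accuracy(y_true, y_pred))   # 0.6
    print(binary_f1(y_true, y_pred))  # about 0.667 (precision = recall = 2/3)
```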
The pre-registered configurations provide reasonable default hyperparameters. A common workflow is to first train a model with one of the presets and then tune some hyperparameters to see if the performance can be further improved.
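One way to follow that workflow is to start from a preset's flat overrides and generate a few variants that change a single key. The sketch below assumes, as the examples in this tutorial suggest, that the dotted keys printed by list_text_presets(verbose=True) are accepted in the hyperparameters argument of fit. apply_overrides and candidate_configs are hypothetical helpers, and the learning rates are illustrative values, not recommendations.

```python
# Overrides of the medium_quality_faster_train preset, copied from the output above.
MEDIUM_PRESET = {
    "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
    "optimization.learning_rate": 0.0004,
}


def apply_overrides(base: dict, overrides: dict) -> dict:
    """Return a new flat config where keys from `overrides` replace those in `base`.

    Neither input dictionary is modified.
    """
    merged = dict(base)
    merged.update(overrides)
    return merged


def candidate_configs(base: dict, learning_rates) -> list:
    """Build one config per learning rate, keeping every other key from `base`."""
    return [apply_overrides(base, {"optimization.learning_rate": lr}) for lr in learning_rates]


if __name__ == "__main__":
    for cfg in candidate_configs(MEDIUM_PRESET, [1e-4, 4e-4, 1e-3]):
        print(cfg)
        # With AutoGluon installed, each candidate could be trained as in this tutorial:
        # predictor = TextPredictor(eval_metric="acc", label="label")
        # predictor.fit(train_data=train_data, hyperparameters=cfg, time_limit=60)
```

Keeping the base preset untouched and copying it per candidate makes it easy to compare each variant against the original preset run.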
Customize Hyperparameters
Customizing hyperparameters in TextPredictor is easy. For example, you may want to try backbones beyond those in the presets. Since TextPredictor supports loading Hugging Face transformers, you can choose any desired text backbone from the Hugging Face model zoo, e.g., prajjwal1/bert-tiny.
from autogluon.text import TextPredictor
predictor = TextPredictor(eval_metric="acc", label="label")
predictor.fit(
train_data=train_data,
hyperparameters={
"model.hf_text.checkpoint_name": "prajjwal1/bert-tiny",
},
time_limit=60,
)
Global seed set to 123
Auto select gpus: [0]
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
-------------------------------------------------------------------
0 | model | HFAutoModelForTextPrediction | 4.4 M
1 | validation_metric | Accuracy | 0
2 | loss_func | CrossEntropyLoss | 0
-------------------------------------------------------------------
4.4 M Trainable params
0 Non-trainable params
4.4 M Total params
8.772 Total estimated model params size (MB)
Epoch 0, global step 3: 'val_acc' reached 0.55000 (best 0.55000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211338/epoch=0-step=3.ckpt' as top 3
Epoch 0, global step 7: 'val_acc' reached 0.53000 (best 0.55000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211338/epoch=0-step=7.ckpt' as top 3
Epoch 1, global step 10: 'val_acc' reached 0.52000 (best 0.55000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211338/epoch=1-step=10.ckpt' as top 3
Epoch 1, global step 14: 'val_acc' reached 0.59000 (best 0.59000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211338/epoch=1-step=14.ckpt' as top 3
Epoch 2, global step 17: 'val_acc' reached 0.55500 (best 0.59000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211338/epoch=2-step=17.ckpt' as top 3
Epoch 2, global step 21: 'val_acc' reached 0.57500 (best 0.59000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211338/epoch=2-step=21.ckpt' as top 3
Epoch 3, global step 24: 'val_acc' reached 0.63000 (best 0.63000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211338/epoch=3-step=24.ckpt' as top 3
Epoch 3, global step 28: 'val_acc' reached 0.66000 (best 0.66000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211338/epoch=3-step=28.ckpt' as top 3
Epoch 4, global step 31: 'val_acc' reached 0.66000 (best 0.66000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211338/epoch=4-step=31.ckpt' as top 3
Epoch 4, global step 35: 'val_acc' reached 0.64000 (best 0.66000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211338/epoch=4-step=35.ckpt' as top 3
Epoch 5, global step 38: 'val_acc' reached 0.65500 (best 0.66000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211338/epoch=5-step=38.ckpt' as top 3
Epoch 5, global step 42: 'val_acc' reached 0.66500 (best 0.66500), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211338/epoch=5-step=42.ckpt' as top 3
Epoch 6, global step 45: 'val_acc' was not in top 3
Epoch 6, global step 49: 'val_acc' reached 0.66500 (best 0.66500), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211338/epoch=6-step=49.ckpt' as top 3
Epoch 7, global step 52: 'val_acc' reached 0.67000 (best 0.67000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211338/epoch=7-step=52.ckpt' as top 3
Epoch 7, global step 56: 'val_acc' reached 0.67000 (best 0.67000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211338/epoch=7-step=56.ckpt' as top 3
Epoch 8, global step 59: 'val_acc' reached 0.67000 (best 0.67000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211338/epoch=8-step=59.ckpt' as top 3
Epoch 8, global step 63: 'val_acc' was not in top 3
Epoch 9, global step 66: 'val_acc' was not in top 3
Epoch 9, global step 70: 'val_acc' was not in top 3
<autogluon.text.text_prediction.predictor.TextPredictor at 0x7f2e3f421220>
predictor.evaluate(test_data, metrics=["f1", "acc"])
{'f1': 0.5970149253731343, 'acc': 0.6284403669724771}
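Hyperparameter names such as model.hf_text.checkpoint_name read as paths into a nested configuration: the model section, its hf_text sub-model, and that sub-model's checkpoint_name. The sketch below does not reproduce AutoGluon's own override handling; assuming that nested layout, it only shows how a flat dictionary of dotted keys corresponds to a nested one. expand_dotted is a hypothetical helper.

```python
def expand_dotted(flat: dict) -> dict:
    """Convert {"a.b.c": v} into {"a": {"b": {"c": v}}}.

    Raises ValueError if a key is used both as a leaf value and as a
    section, e.g. {"a": 1, "a.b": 2}.
    """
    nested: dict = {}
    for dotted_key, value in flat.items():
        *sections, leaf = dotted_key.split(".")
        node = nested
        for section in sections:
            node = node.setdefault(section, {})  # descend, creating sections as needed
            if not isinstance(node, dict):
                raise ValueError(f"{dotted_key!r} conflicts with a leaf value at {section!r}")
        if isinstance(node.get(leaf), dict):
            raise ValueError(f"{dotted_key!r} conflicts with an existing section")
        node[leaf] = value
    return nested


if __name__ == "__main__":
    print(expand_dotted({
        "model.hf_text.checkpoint_name": "prajjwal1/bert-tiny",
        "env.per_gpu_batch_size": 2,
    }))
```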
TextPredictor also supports models that are not available in the Hugging Face model zoo. To use such a model, make sure it follows Hugging Face’s AutoModel, AutoConfig, and AutoTokenizer conventions, i.e., that it can be loaded with their from_pretrained methods. Let’s simulate a local model by saving a model zoo checkpoint to disk.
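import os
from transformers import AutoModel, AutoConfig, AutoTokenizer
model_key = 'prajjwal1/bert-tiny'  # checkpoint in the Hugging Face model zoo
local_path = 'custom_local_bert_tiny'  # directory that simulates a custom local model
# Load the model, its config, and its tokenizer from the model zoo...
model = AutoModel.from_pretrained(model_key)
config = AutoConfig.from_pretrained(model_key)
tokenizer = AutoTokenizer.from_pretrained(model_key)
# ...and write all three to the local directory in the standard Hugging Face format.
model.save_pretrained(local_path)
config.save_pretrained(local_path)
tokenizer.save_pretrained(local_path)
os.listdir(local_path)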
['config.json',
'pytorch_model.bin',
'tokenizer_config.json',
'special_tokens_map.json',
'vocab.txt',
'tokenizer.json']
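Before pointing TextPredictor at a local directory, it can help to check that the files the three Auto classes need are present. missing_checkpoint_files is a hypothetical, heuristic helper: the file names it looks for are assumptions based on the listing above and on common save_pretrained output, and it only checks that the files exist, not that they load.

```python
import os
import tempfile

# Assumed file names; any one file from each group is accepted.
WEIGHT_FILES = ("pytorch_model.bin", "model.safetensors")
TOKENIZER_FILES = ("tokenizer.json", "vocab.txt", "tokenizer_config.json")


def missing_checkpoint_files(path):
    """Heuristically check that `path` looks like a save_pretrained directory.

    Returns a list of human-readable problems; an empty list means the
    basic files are present. The files are not opened or validated.
    """
    if not os.path.isdir(path):
        return [f"{path} is not a directory"]
    files = set(os.listdir(path))
    problems = []
    if "config.json" not in files:
        problems.append("config.json (needed by AutoConfig)")
    if not files.intersection(WEIGHT_FILES):
        problems.append("model weights (needed by AutoModel)")
    if not files.intersection(TOKENIZER_FILES):
        problems.append("tokenizer files (needed by AutoTokenizer)")
    return problems


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as demo_dir:
        print(missing_checkpoint_files(demo_dir))  # all three groups are missing
        for name in ("config.json", "pytorch_model.bin", "vocab.txt"):
            open(os.path.join(demo_dir, name), "w").close()  # create empty placeholder files
        print(missing_checkpoint_files(demo_dir))  # []
```

Running it on custom_local_bert_tiny from the cell above should report no problems, since that listing contains config.json, pytorch_model.bin, and the tokenizer files.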
Now we can use this local model in TextPredictor by passing the directory path as the checkpoint name.
from autogluon.text import TextPredictor
predictor = TextPredictor(eval_metric="acc", label="label")
predictor.fit(
train_data=train_data,
hyperparameters={
"model.hf_text.checkpoint_name": "custom_local_bert_tiny/",
},
time_limit=60,
)
Global seed set to 123
Auto select gpus: [0]
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
-------------------------------------------------------------------
0 | model | HFAutoModelForTextPrediction | 4.4 M
1 | validation_metric | Accuracy | 0
2 | loss_func | CrossEntropyLoss | 0
-------------------------------------------------------------------
4.4 M Trainable params
0 Non-trainable params
4.4 M Total params
8.772 Total estimated model params size (MB)
Epoch 0, global step 3: 'val_acc' reached 0.55000 (best 0.55000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211405/epoch=0-step=3.ckpt' as top 3
Epoch 0, global step 7: 'val_acc' reached 0.53000 (best 0.55000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211405/epoch=0-step=7.ckpt' as top 3
Epoch 1, global step 10: 'val_acc' reached 0.52000 (best 0.55000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211405/epoch=1-step=10.ckpt' as top 3
Epoch 1, global step 14: 'val_acc' reached 0.59000 (best 0.59000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211405/epoch=1-step=14.ckpt' as top 3
Epoch 2, global step 17: 'val_acc' reached 0.55500 (best 0.59000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211405/epoch=2-step=17.ckpt' as top 3
Epoch 2, global step 21: 'val_acc' reached 0.57500 (best 0.59000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211405/epoch=2-step=21.ckpt' as top 3
Epoch 3, global step 24: 'val_acc' reached 0.63000 (best 0.63000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211405/epoch=3-step=24.ckpt' as top 3
Epoch 3, global step 28: 'val_acc' reached 0.66000 (best 0.66000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211405/epoch=3-step=28.ckpt' as top 3
Epoch 4, global step 31: 'val_acc' reached 0.66000 (best 0.66000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211405/epoch=4-step=31.ckpt' as top 3
Epoch 4, global step 35: 'val_acc' reached 0.64000 (best 0.66000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211405/epoch=4-step=35.ckpt' as top 3
Epoch 5, global step 38: 'val_acc' reached 0.65500 (best 0.66000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211405/epoch=5-step=38.ckpt' as top 3
Epoch 5, global step 42: 'val_acc' reached 0.66500 (best 0.66500), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211405/epoch=5-step=42.ckpt' as top 3
Epoch 6, global step 45: 'val_acc' was not in top 3
Epoch 6, global step 49: 'val_acc' reached 0.66500 (best 0.66500), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211405/epoch=6-step=49.ckpt' as top 3
Epoch 7, global step 52: 'val_acc' reached 0.67000 (best 0.67000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211405/epoch=7-step=52.ckpt' as top 3
Epoch 7, global step 56: 'val_acc' reached 0.67000 (best 0.67000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211405/epoch=7-step=56.ckpt' as top 3
Epoch 8, global step 59: 'val_acc' reached 0.67000 (best 0.67000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220728_211405/epoch=8-step=59.ckpt' as top 3
Epoch 8, global step 63: 'val_acc' was not in top 3
Epoch 9, global step 66: 'val_acc' was not in top 3
Epoch 9, global step 70: 'val_acc' was not in top 3
<autogluon.text.text_prediction.predictor.TextPredictor at 0x7f2f1911a8e0>
predictor.evaluate(test_data, metrics=["f1", "acc"])
{'f1': 0.5970149253731343, 'acc': 0.6284403669724771}