Text Prediction - Heterogeneous Data Types
In your applications, your text data may be mixed with other common data types, such as numerical and categorical data (which are commonly found in tabular data). The TextPrediction task in AutoGluon can train a single neural network that jointly operates on multiple feature types, including text, categorical, and numerical columns. Here we’ll again use the Semantic Textual Similarity (STS) dataset to illustrate this functionality.
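To make the idea of "one network, several column types" concrete, here is a toy late-fusion sketch in plain NumPy. Each column type gets its own encoder, and the per-column vectors are concatenated into one input for a single shared prediction head. The encoders (a hashed bag of words for text, one-hot for categories, standardization for numbers) and all helper names (`encode_text`, `encode_category`, `encode_number`, `fuse_row`) are illustrative assumptions. This is not AutoGluon's actual architecture.

```python
import zlib

import numpy as np

GENRES = ["main-captions", "main-forums", "main-news"]  # categories seen in STS

def encode_text(text, dim=16):
    """Hashed bag of words: each lowercased token increments one of `dim` buckets."""
    vec = np.zeros(dim)
    for token in text.lower().split():
        # crc32 is stable across runs, unlike Python's salted built-in hash() for str
        vec[zlib.crc32(token.encode("utf-8")) % dim] += 1.0
    return vec

def encode_category(value, categories):
    """One-hot vector over a fixed list of known categories."""
    vec = np.zeros(len(categories))
    vec[categories.index(value)] = 1.0
    return vec

def encode_number(value, mean, std):
    """Standardize a scalar so its scale is comparable to the other features."""
    return np.array([(value - mean) / std])

def fuse_row(texts, categoricals=(), numericals=(), text_dim=16):
    """Encode each column with its type-specific encoder, then concatenate."""
    parts = [encode_text(t, text_dim) for t in texts]
    parts += [encode_category(v, cats) for v, cats in categoricals]
    parts += [encode_number(v, m, s) for v, m, s in numericals]
    return np.concatenate(parts)

# Predicting `score`: two text columns plus the categorical `genre`.
x_score = fuse_row(["A plane is taking off.", "An air plane is taking off."],
                   categoricals=[("main-captions", GENRES)])
# Predicting `genre`: two text columns plus the numerical `score`.
# The mean/std here are placeholders, not statistics computed from STS.
x_genre = fuse_row(["A plane is taking off.", "An air plane is taking off."],
                   numericals=[(5.0, 2.5, 1.5)])

# A random linear head over the fused vector stands in for the shared network.
rng = np.random.default_rng(123)
weights = rng.normal(size=x_score.shape[0])
prediction = float(weights @ x_score)
print(x_score.shape, x_genre.shape)  # (35,) (33,)
```

A real model would replace these hand-written encoders with learned ones. The sketch only shows that every column type ends up in one vector that feeds one model.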
import numpy as np
import warnings
warnings.filterwarnings('ignore')
np.random.seed(123)
Load Data
from autogluon.core.utils.loaders import load_pd
train_data = load_pd.load('https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/train.parquet')
dev_data = load_pd.load('https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/dev.parquet')
train_data.head(10)
| | sentence1 | sentence2 | genre | score |
|---|---|---|---|---|
| 0 | A plane is taking off. | An air plane is taking off. | main-captions | 5.00 |
| 1 | A man is playing a large flute. | A man is playing a flute. | main-captions | 3.80 |
| 2 | A man is spreading shreded cheese on a pizza. | A man is spreading shredded cheese on an uncoo... | main-captions | 3.80 |
| 3 | Three men are playing chess. | Two men are playing chess. | main-captions | 2.60 |
| 4 | A man is playing the cello. | A man seated is playing the cello. | main-captions | 4.25 |
| 5 | Some men are fighting. | Two men are fighting. | main-captions | 4.25 |
| 6 | A man is smoking. | A man is skating. | main-captions | 0.50 |
| 7 | The man is playing the piano. | The man is playing the guitar. | main-captions | 1.60 |
| 8 | A man is playing on a guitar and singing. | A woman is playing an acoustic guitar and sing... | main-captions | 2.20 |
| 9 | A person is throwing a cat on to the ceiling. | A person throws a cat on the ceiling. | main-captions | 5.00 |
Note that the STS dataset contains two text fields (sentence1 and sentence2), one categorical field (genre), and one numerical field (score). Let’s try to predict score based on the other features: sentence1, sentence2, and genre.
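AutoGluon infers each column's type and the problem type on its own. The training log below reports sentence1/sentence2 as Text, genre as Categorical, score as Numerical, and the problem type for score as regression. The following sketch mimics that with a simple heuristic in pandas. The function names `guess_column_type` and `guess_problem_type` and their thresholds are hypothetical and are not AutoGluon's actual inference rules.

```python
import pandas as pd

def guess_column_type(series, max_categories=20, min_avg_words=3):
    """Hypothetical heuristic: numbers -> numerical; few distinct short strings -> categorical; else text."""
    if pd.api.types.is_numeric_dtype(series):
        return "numerical"
    values = series.dropna().astype(str)
    avg_words = values.str.split().str.len().mean()
    if values.nunique() <= max_categories and avg_words < min_avg_words:
        return "categorical"
    return "text"

def guess_problem_type(label_type):
    """Numerical labels -> regression; categorical labels -> classification."""
    return {"numerical": "regression", "categorical": "classification"}[label_type]

# A few complete rows copied from the table above.
df = pd.DataFrame({
    "sentence1": ["A plane is taking off.", "A man is playing a large flute.",
                  "Three men are playing chess.", "A man is smoking."],
    "sentence2": ["An air plane is taking off.", "A man is playing a flute.",
                  "Two men are playing chess.", "A man is skating."],
    "genre": ["main-captions"] * 4,
    "score": [5.00, 3.80, 2.60, 0.50],
})
types = {col: guess_column_type(df[col]) for col in df.columns}
print(types)
print(guess_problem_type(types["score"]), guess_problem_type(types["genre"]))
```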
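import autogluon.core as ag
from autogluon.text import TextPrediction as task
# Fit one network on sentence1, sentence2 (text) and genre (categorical) to predict
# the numerical score, with a 60-second time budget and one GPU per trial.
predictor_score = task.fit(train_data, label='score',
                           time_limits=60, ngpus_per_trial=1, seed=123,
                           output_directory='./ag_sts_mixed_score')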
2020-10-23 22:28:12,644 - root - INFO - All Logs will be saved to ./ag_sts_mixed_score/ag_text_prediction.log
2020-10-23 22:28:12,669 - root - INFO - Train Dataset:
2020-10-23 22:28:12,669 - root - INFO - Columns:
- Text(
name="sentence1"
#total/missing=4599/0
length, min/avg/max=16/57.73/367
)
- Text(
name="sentence2"
#total/missing=4599/0
length, min/avg/max=16/57.68/311
)
- Categorical(
name="genre"
#total/missing=4599/0
num_class (total/non_special)=4/3
categories=['main-captions', 'main-forums', 'main-news']
freq=[1594, 374, 2631]
)
- Numerical(
name="score"
#total/missing=4599/0
shape=()
)
2020-10-23 22:28:12,670 - root - INFO - Tuning Dataset:
2020-10-23 22:28:12,670 - root - INFO - Columns:
- Text(
name="sentence1"
#total/missing=1150/0
length, min/avg/max=16/57.63/272
)
- Text(
name="sentence2"
#total/missing=1150/0
length, min/avg/max=15/56.93/229
)
- Categorical(
name="genre"
#total/missing=1150/0
num_class (total/non_special)=4/3
categories=['main-captions', 'main-forums', 'main-news']
freq=[406, 76, 668]
)
- Numerical(
name="score"
#total/missing=1150/0
shape=()
)
2020-10-23 22:28:12,671 - root - INFO - Label columns=['score'], Feature columns=['sentence1', 'sentence2', 'genre'], Problem types=['regression'], Label shapes=[()]
2020-10-23 22:28:12,671 - root - INFO - Eval Metric=mse, Stop Metric=mse, Log Metrics=['mse', 'rmse', 'mae']
55%|█████▍ | 314/576 [01:01<00:51, 5.10it/s]
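The log above lists mse as both the evaluation and the stopping metric for this regression task, and it also logs rmse and mae. For reference, here are the standard definitions sketched in NumPy. The gold values come from the table above, but the predictions are made up for illustration and are not model outputs.

```python
import numpy as np

def regression_metrics(y_true, y_pred):
    """Mean squared error, its square root, and mean absolute error."""
    err = np.asarray(y_pred, dtype=float) - np.asarray(y_true, dtype=float)
    mse = float(np.mean(err ** 2))
    return {"mse": mse, "rmse": float(np.sqrt(mse)), "mae": float(np.mean(np.abs(err)))}

y_true = [5.00, 3.80, 2.60, 0.50]  # gold scores from the table above
y_pred = [4.50, 3.80, 3.60, 1.00]  # hypothetical predictions
metrics = regression_metrics(y_true, y_pred)
print(metrics)
```

Because mse squares each error, the single 1.0-point miss contributes more than the two 0.5-point misses combined. mae weights all errors linearly.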
score = predictor_score.evaluate(dev_data, metrics='spearmanr')
print('Spearman Correlation=', score['spearmanr'])
Spearman Correlation= 0.8547863115572756
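Spearman's rank correlation, used above to score the regression model, is the Pearson correlation computed on ranks instead of raw values. It therefore rewards getting the ordering of sentence pairs right, even if the predicted scores are off by a monotonic distortion. Here is a minimal sketch with pandas and NumPy, where tied values share the average of their ranks (`method="average"` in `Series.rank`). The helper name `spearman` is hypothetical.

```python
import numpy as np
import pandas as pd

def spearman(x, y):
    """Pearson correlation of average ranks (tied values share the mean of their ranks)."""
    rx = pd.Series(x, dtype=float).rank(method="average").to_numpy()
    ry = pd.Series(y, dtype=float).rank(method="average").to_numpy()
    return float(np.corrcoef(rx, ry)[0, 1])

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = x ** 3  # monotonic but non-linear
print(spearman(x, y))                  # close to 1.0: the ordering is identical
print(float(np.corrcoef(x, y)[0, 1]))  # plain Pearson is noticeably below 1 here
```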
We can also train a model that predicts genre using the other columns (sentence1, sentence2, and score) as features.
# Same call, but now genre is the label and score becomes a numerical feature.
predictor_genre = task.fit(train_data, label='genre',
                           time_limits=60, ngpus_per_trial=1, seed=123,
                           output_directory='./ag_sts_mixed_genre')
2020-10-23 22:30:51,237 - root - INFO - All Logs will be saved to ./ag_sts_mixed_genre/ag_text_prediction.log
2020-10-23 22:30:51,264 - root - INFO - Train Dataset:
2020-10-23 22:30:51,265 - root - INFO - Columns:
- Text(
name="sentence1"
#total/missing=4599/0
length, min/avg/max=16/57.45/340
)
- Text(
name="sentence2"
#total/missing=4599/0
length, min/avg/max=15/57.31/311
)
- Categorical(
name="genre"
#total/missing=4599/0
num_class (total/non_special)=3/3
categories=['main-captions', 'main-forums', 'main-news']
freq=[1615, 347, 2637]
)
- Numerical(
name="score"
#total/missing=4599/0
shape=()
)
2020-10-23 22:30:51,265 - root - INFO - Tuning Dataset:
2020-10-23 22:30:51,265 - root - INFO - Columns:
- Text(
name="sentence1"
#total/missing=1150/0
length, min/avg/max=17/58.74/367
)
- Text(
name="sentence2"
#total/missing=1150/0
length, min/avg/max=17/58.41/265
)
- Categorical(
name="genre"
#total/missing=1150/0
num_class (total/non_special)=3/3
categories=['main-captions', 'main-forums', 'main-news']
freq=[385, 103, 662]
)
- Numerical(
name="score"
#total/missing=1150/0
shape=()
)
2020-10-23 22:30:51,266 - root - INFO - Label columns=['genre'], Feature columns=['sentence1', 'sentence2', 'score'], Problem types=['classification'], Label shapes=[3]
2020-10-23 22:30:51,266 - root - INFO - Eval Metric=acc, Stop Metric=acc, Log Metrics=['acc', 'nll']
57%|█████▋ | 329/576 [01:02<00:46, 5.26it/s]
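For this classification run, the log reports acc as the evaluation and stopping metric and also logs nll. Accuracy is the fraction of examples whose highest-probability class matches the label. Negative log-likelihood is the mean of -log p(true class). Below is a NumPy sketch on hypothetical class probabilities, with columns following the category order in the log (main-captions, main-forums, main-news). The helper `accuracy_and_nll` is illustrative.

```python
import numpy as np

def accuracy_and_nll(probs, labels):
    """Accuracy of argmax predictions and mean negative log-probability of the true class."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    acc = float(np.mean(probs.argmax(axis=1) == labels))
    # Probability assigned to each example's true class; clipped to avoid log(0).
    p_true = np.clip(probs[np.arange(len(labels)), labels], 1e-12, 1.0)
    nll = float(-np.mean(np.log(p_true)))
    return acc, nll

probs = [[0.7, 0.1, 0.2],
         [0.2, 0.3, 0.5],
         [0.1, 0.1, 0.8],
         [0.6, 0.3, 0.1]]
labels = [0, 2, 1, 0]  # hypothetical gold genre indices
acc, nll = accuracy_and_nll(probs, labels)
print(acc, nll)
```

The third example is a confident mistake (0.1 on the true class), so it costs accuracy and also dominates the nll. This is why nll is a useful companion to accuracy.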
score = predictor_genre.evaluate(dev_data, metrics='acc')
print('Genre-prediction Accuracy = {}%'.format(score['acc'] * 100))
Genre-prediction Accuracy = 85.6%
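To put this accuracy in context, we can compute a trivial baseline that always predicts the most frequent genre. The category counts come from the training log above: freq=[1615, 347, 2637] on the training split and freq=[385, 103, 662] on the tuning split. This is only a rough reference point, because those counts describe AutoGluon's internal train/tuning split rather than dev_data. The helper `majority_baseline` is illustrative.

```python
def majority_baseline(freq):
    """Accuracy of always predicting the most frequent class."""
    return max(freq) / sum(freq)

train_freq = [1615, 347, 2637]  # main-captions, main-forums, main-news
tune_freq = [385, 103, 662]
print('Train split: {:.1f}%'.format(majority_baseline(train_freq) * 100))  # Train split: 57.3%
print('Tuning split: {:.1f}%'.format(majority_baseline(tune_freq) * 100))  # Tuning split: 57.6%
```

Always guessing main-news would score roughly 57% on these splits. The jointly trained model's 85.6% on the dev set is therefore well above this naive baseline, although the two numbers come from different data.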