.. _sec_textprediction_heterogeneous:

Text Prediction - Heterogeneous Data Types
==========================================

In your applications, your text data may be mixed with other common data types, such as numerical and categorical features (both frequently found in tabular data). The ``TextPrediction`` task in AutoGluon can train a single neural network that jointly operates on multiple feature types, including text, categorical, and numerical columns. Here we'll again use the Semantic Textual Similarity (STS) dataset to illustrate this functionality.

.. code:: python

    import numpy as np
    import warnings
    warnings.filterwarnings('ignore')
    np.random.seed(123)

Load Data
---------

.. code:: python

    from autogluon.core.utils.loaders import load_pd

    train_data = load_pd.load('https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/train.parquet')
    dev_data = load_pd.load('https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/dev.parquet')
    train_data.head(10)

.. raw:: html

    <table border="1" class="dataframe">
      <thead>
        <tr><th></th><th>sentence1</th><th>sentence2</th><th>genre</th><th>score</th></tr>
      </thead>
      <tbody>
        <tr><th>0</th><td>A plane is taking off.</td><td>An air plane is taking off.</td><td>main-captions</td><td>5.00</td></tr>
        <tr><th>1</th><td>A man is playing a large flute.</td><td>A man is playing a flute.</td><td>main-captions</td><td>3.80</td></tr>
        <tr><th>2</th><td>A man is spreading shreded cheese on a pizza.</td><td>A man is spreading shredded cheese on an uncoo...</td><td>main-captions</td><td>3.80</td></tr>
        <tr><th>3</th><td>Three men are playing chess.</td><td>Two men are playing chess.</td><td>main-captions</td><td>2.60</td></tr>
        <tr><th>4</th><td>A man is playing the cello.</td><td>A man seated is playing the cello.</td><td>main-captions</td><td>4.25</td></tr>
        <tr><th>5</th><td>Some men are fighting.</td><td>Two men are fighting.</td><td>main-captions</td><td>4.25</td></tr>
        <tr><th>6</th><td>A man is smoking.</td><td>A man is skating.</td><td>main-captions</td><td>0.50</td></tr>
        <tr><th>7</th><td>The man is playing the piano.</td><td>The man is playing the guitar.</td><td>main-captions</td><td>1.60</td></tr>
        <tr><th>8</th><td>A man is playing on a guitar and singing.</td><td>A woman is playing an acoustic guitar and sing...</td><td>main-captions</td><td>2.20</td></tr>
        <tr><th>9</th><td>A person is throwing a cat on to the ceiling.</td><td>A person throws a cat on the ceiling.</td><td>main-captions</td><td>5.00</td></tr>
      </tbody>
    </table>
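Before calling ``fit``, it can help to confirm how pandas typed each column, since AutoGluon infers the feature type (text, categorical, or numerical) per column. A minimal sketch using a hypothetical two-row frame that mirrors the STS columns (the real data comes from the parquet files loaded above):

```python
import pandas as pd

# Hypothetical mini-DataFrame mirroring the STS columns.
df = pd.DataFrame({
    'sentence1': ['A plane is taking off.', 'Some men are fighting.'],
    'sentence2': ['An air plane is taking off.', 'Two men are fighting.'],
    'genre': ['main-captions', 'main-captions'],
    'score': [5.00, 4.25],
})

# Text and categorical columns appear as 'object' dtype, while the
# numerical score column is 'float64'.
print(df.dtypes)
```

Columns with few distinct string values (like ``genre``) are treated as categorical, while free-form string columns (like the two sentences) are treated as text.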
Note that the STS dataset contains two text fields (``sentence1`` and ``sentence2``), one categorical field (``genre``), and one numerical field (``score``). Let's try to predict the **score** based on the other features: ``sentence1``, ``sentence2``, ``genre``.

.. code:: python

    import autogluon.core as ag
    from autogluon.text import TextPrediction as task

    predictor_score = task.fit(train_data,
                               label='score',
                               time_limits=60,
                               ngpus_per_trial=1,
                               seed=123,
                               output_directory='./ag_sts_mixed_score')

.. parsed-literal::
    :class: output

    2020-10-23 22:28:12,644 - root - INFO - All Logs will be saved to ./ag_sts_mixed_score/ag_text_prediction.log
    2020-10-23 22:28:12,669 - root - INFO - Train Dataset:
    2020-10-23 22:28:12,669 - root - INFO - Columns:
    - Text( name="sentence1" #total/missing=4599/0 length, min/avg/max=16/57.73/367 )
    - Text( name="sentence2" #total/missing=4599/0 length, min/avg/max=16/57.68/311 )
    - Categorical( name="genre" #total/missing=4599/0 num_class (total/non_special)=4/3 categories=['main-captions', 'main-forums', 'main-news'] freq=[1594, 374, 2631] )
    - Numerical( name="score" #total/missing=4599/0 shape=() )
    2020-10-23 22:28:12,670 - root - INFO - Tuning Dataset:
    2020-10-23 22:28:12,670 - root - INFO - Columns:
    - Text( name="sentence1" #total/missing=1150/0 length, min/avg/max=16/57.63/272 )
    - Text( name="sentence2" #total/missing=1150/0 length, min/avg/max=15/56.93/229 )
    - Categorical( name="genre" #total/missing=1150/0 num_class (total/non_special)=4/3 categories=['main-captions', 'main-forums', 'main-news'] freq=[406, 76, 668] )
    - Numerical( name="score" #total/missing=1150/0 shape=() )
    2020-10-23 22:28:12,671 - root - INFO - Label columns=['score'], Feature columns=['sentence1', 'sentence2', 'genre'], Problem types=['regression'], Label shapes=[()]
    2020-10-23 22:28:12,671 - root - INFO - Eval Metric=mse, Stop Metric=mse, Log Metrics=['mse', 'rmse', 'mae']

.. code:: python

    score = predictor_score.evaluate(dev_data, metrics='spearmanr')
    print('Spearman Correlation=', score['spearmanr'])

.. parsed-literal::
    :class: output

    Spearman Correlation= 0.8547863115572756

We can also train a model that predicts the **genre** using the other columns as features.

.. code:: python

    predictor_genre = task.fit(train_data,
                               label='genre',
                               time_limits=60,
                               ngpus_per_trial=1,
                               seed=123,
                               output_directory='./ag_sts_mixed_genre')

.. parsed-literal::
    :class: output

    2020-10-23 22:30:51,237 - root - INFO - All Logs will be saved to ./ag_sts_mixed_genre/ag_text_prediction.log
    2020-10-23 22:30:51,264 - root - INFO - Train Dataset:
    2020-10-23 22:30:51,265 - root - INFO - Columns:
    - Text( name="sentence1" #total/missing=4599/0 length, min/avg/max=16/57.45/340 )
    - Text( name="sentence2" #total/missing=4599/0 length, min/avg/max=15/57.31/311 )
    - Categorical( name="genre" #total/missing=4599/0 num_class (total/non_special)=3/3 categories=['main-captions', 'main-forums', 'main-news'] freq=[1615, 347, 2637] )
    - Numerical( name="score" #total/missing=4599/0 shape=() )
    2020-10-23 22:30:51,265 - root - INFO - Tuning Dataset:
    2020-10-23 22:30:51,265 - root - INFO - Columns:
    - Text( name="sentence1" #total/missing=1150/0 length, min/avg/max=17/58.74/367 )
    - Text( name="sentence2" #total/missing=1150/0 length, min/avg/max=17/58.41/265 )
    - Categorical( name="genre" #total/missing=1150/0 num_class (total/non_special)=3/3 categories=['main-captions', 'main-forums', 'main-news'] freq=[385, 103, 662] )
    - Numerical( name="score" #total/missing=1150/0 shape=() )
    2020-10-23 22:30:51,266 - root - INFO - Label columns=['genre'], Feature columns=['sentence1', 'sentence2', 'score'], Problem types=['classification'], Label shapes=[3]
    2020-10-23 22:30:51,266 - root - INFO - Eval Metric=acc, Stop Metric=acc, Log Metrics=['acc', 'nll']

.. code:: python

    score = predictor_genre.evaluate(dev_data, metrics='acc')
    print('Genre-prediction Accuracy = {}%'.format(score['acc'] * 100))

.. parsed-literal::
    :class: output

    Genre-prediction Accuracy = 85.6%
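The two metrics used above are easy to compute by hand. Spearman correlation compares the *rankings* of predictions and ground-truth scores (so it is insensitive to monotone rescaling of the predictions), and accuracy is the fraction of exact label matches. A minimal numpy-only sketch on hypothetical predictions (the rank-based formula below assumes no tied values):

```python
import numpy as np

def spearman(a, b):
    # Rank both arrays (no ties assumed), then take the Pearson
    # correlation of the ranks, which is Spearman's rho.
    ra = np.argsort(np.argsort(a)).astype(float)
    rb = np.argsort(np.argsort(b)).astype(float)
    ra -= ra.mean()
    rb -= rb.mean()
    return float((ra @ rb) / np.sqrt((ra @ ra) * (rb @ rb)))

# Hypothetical ground-truth and predicted similarity scores.
y_true = np.array([5.0, 3.8, 2.6, 4.25, 0.5])
y_pred = np.array([4.7, 3.5, 2.9, 4.1, 0.9])
rho = spearman(y_true, y_pred)  # 1.0: the predictions preserve the ordering

# Hypothetical genre labels: accuracy is the fraction of exact matches.
g_true = ['main-news', 'main-captions', 'main-forums', 'main-news']
g_pred = ['main-news', 'main-captions', 'main-news', 'main-news']
acc = sum(t == p for t, p in zip(g_true, g_pred)) / len(g_true)  # 0.75

print(rho, acc)
```

In practice you would not reimplement these; ``evaluate`` computes them for you, and ``scipy.stats.spearmanr`` handles ties correctly.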