.. _sec_textprediction_heterogeneous:

Text Prediction - Heterogeneous Data Types
==========================================

In your applications, your text data may be mixed with other common data
types like numerical data and categorical data (which are commonly found
in tabular data). The ``TextPrediction`` task in AutoGluon can train a
single neural network that jointly operates on multiple feature types,
including text, categorical, and numerical columns. Here we'll again use
the Semantic Textual Similarity dataset to illustrate this functionality.

.. code:: python

    import numpy as np
    import warnings
    warnings.filterwarnings('ignore')
    np.random.seed(123)

Load Data
---------

.. code:: python

    from autogluon.utils.tabular.utils.loaders import load_pd

    train_data = load_pd.load('https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/train.parquet')
    dev_data = load_pd.load('https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/dev.parquet')
    train_data.head(10)

.. parsed-literal::
    :class: output

    Loaded data from: https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/train.parquet | Columns = 4 / 4 | Rows = 5749 -> 5749
    Loaded data from: https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/dev.parquet | Columns = 4 / 4 | Rows = 1500 -> 1500

.. list-table::
   :header-rows: 1

   * -
     - sentence1
     - sentence2
     - genre
     - score
   * - 0
     - A plane is taking off.
     - An air plane is taking off.
     - main-captions
     - 5.00
   * - 1
     - A man is playing a large flute.
     - A man is playing a flute.
     - main-captions
     - 3.80
   * - 2
     - A man is spreading shreded cheese on a pizza.
     - A man is spreading shredded cheese on an uncoo...
     - main-captions
     - 3.80
   * - 3
     - Three men are playing chess.
     - Two men are playing chess.
     - main-captions
     - 2.60
   * - 4
     - A man is playing the cello.
     - A man seated is playing the cello.
     - main-captions
     - 4.25
   * - 5
     - Some men are fighting.
     - Two men are fighting.
     - main-captions
     - 4.25
   * - 6
     - A man is smoking.
     - A man is skating.
     - main-captions
     - 0.50
   * - 7
     - The man is playing the piano.
     - The man is playing the guitar.
     - main-captions
     - 1.60
   * - 8
     - A man is playing on a guitar and singing.
     - A woman is playing an acoustic guitar and sing...
     - main-captions
     - 2.20
   * - 9
     - A person is throwing a cat on to the ceiling.
     - A person throws a cat on the ceiling.
     - main-captions
     - 5.00

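
To make the mix of column types in this preview concrete, the following
is a minimal sketch that labels each column as text, categorical, or
numerical with a simple heuristic. The function ``guess_column_type`` and
its ``max_categories`` threshold are hypothetical names introduced for
illustration only; this is not how AutoGluon infers column types
internally. The rows are copied from the preview above.

.. code:: python

    from numbers import Number


    def guess_column_type(values, max_categories=20):
        """Hypothetical heuristic, NOT AutoGluon's own type inference."""
        non_null = [v for v in values if v is not None]
        # Columns that hold only numbers (bools excluded) are numerical.
        if all(isinstance(v, Number) and not isinstance(v, bool) for v in non_null):
            return 'numerical'
        # Strings that repeat and take few distinct values look categorical.
        n_unique = len(set(non_null))
        if n_unique <= max_categories and n_unique < len(non_null):
            return 'categorical'
        # Everything else is treated as free text.
        return 'text'


    columns = ['sentence1', 'sentence2', 'genre', 'score']
    rows = [
        ('A plane is taking off.', 'An air plane is taking off.', 'main-captions', 5.00),
        ('A man is playing a large flute.', 'A man is playing a flute.', 'main-captions', 3.80),
        ('Three men are playing chess.', 'Two men are playing chess.', 'main-captions', 2.60),
        ('A man is smoking.', 'A man is skating.', 'main-captions', 0.50),
    ]

    # Apply the heuristic column by column.
    column_types = {
        name: guess_column_type([row[i] for row in rows])
        for i, name in enumerate(columns)
    }
    print(column_types)

On these four rows the sketch labels both sentence columns as text,
``genre`` as categorical, and ``score`` as numerical. A real system would
need more robust rules, for example for short text columns with many
repeated values.
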
Note that the STS dataset contains two text fields (``sentence1`` and
``sentence2``), one categorical field (``genre``), and one numerical
field (``score``). Let's try to predict the **score** from the other
features: ``sentence1``, ``sentence2``, and ``genre``.

.. code:: python

    import autogluon as ag
    from autogluon import TextPrediction as task

    predictor_score = task.fit(train_data, label='score',
                               time_limits=60, ngpus_per_trial=1, seed=123,
                               output_directory='./ag_sts_mixed_score')

.. parsed-literal::
    :class: output

    NumPy-shape semantics has been activated in your code. This is required for creating and manipulating scalar and zero-size tensors, which were not supported in MXNet before, as in the official NumPy library. Please DO NOT manually deactivate this semantics while using `mxnet.numpy` and `mxnet.numpy_extension` modules.
    2020-10-27 21:41:40,945 - root - INFO - All Logs will be saved to ./ag_sts_mixed_score/ag_text_prediction.log
    2020-10-27 21:41:40,969 - root - INFO - Train Dataset:
    2020-10-27 21:41:40,969 - root - INFO - Columns:

    - Text(
       name="sentence1"
       #total/missing=4599/0
       length, min/avg/max=16/57.15/367
    )
    - Text(
       name="sentence2"
       #total/missing=4599/0
       length, min/avg/max=15/57.03/311
    )
    - Categorical(
       name="genre"
       #total/missing=4599/0
       num_class (total/non_special)=4/3
       categories=['main-captions', 'main-forums', 'main-news']
       freq=[1611, 361, 2627]
    )
    - Numerical(
       name="score"
       #total/missing=4599/0
       shape=()
    )

    2020-10-27 21:41:40,969 - root - INFO - Tuning Dataset:
    2020-10-27 21:41:40,970 - root - INFO - Columns:

    - Text(
       name="sentence1"
       #total/missing=1150/0
       length, min/avg/max=16/59.95/340
    )
    - Text(
       name="sentence2"
       #total/missing=1150/0
       length, min/avg/max=16/59.53/229
    )
    - Categorical(
       name="genre"
       #total/missing=1150/0
       num_class (total/non_special)=4/3
       categories=['main-captions', 'main-forums', 'main-news']
       freq=[389, 89, 672]
    )
    - Numerical(
       name="score"
       #total/missing=1150/0
       shape=()
    )

    2020-10-27 21:41:40,970 - root - INFO - Label columns=['score'], Feature columns=['sentence1', 'sentence2', 'genre'], Problem types=['regression'], Label shapes=[()]
    2020-10-27 21:41:40,971 - root - INFO - Eval Metric=mse, Stop Metric=mse, Log Metrics=['mse', 'rmse', 'mae']

.. parsed-literal::
    :class: output

    55%|█████▍ | 314/576 [01:01<00:51, 5.10it/s]

.. parsed-literal::
    :class: output

    55%|█████▍ | 314/576 [01:00<00:50, 5.16it/s]

.. code:: python

    score = predictor_score.evaluate(dev_data, metrics='spearmanr')
    print('Spearman Correlation=', score['spearmanr'])

.. parsed-literal::
    :class: output

    Spearman Correlation= 0.8571229145912008

We can also train a model that predicts the **genre** using the other
columns as features.

.. code:: python

    predictor_genre = task.fit(train_data, label='genre',
                               time_limits=60, ngpus_per_trial=1, seed=123,
                               output_directory='./ag_sts_mixed_genre')

.. parsed-literal::
    :class: output

    2020-10-27 21:44:19,637 - root - INFO - All Logs will be saved to ./ag_sts_mixed_genre/ag_text_prediction.log
    2020-10-27 21:44:19,663 - root - INFO - Train Dataset:
    2020-10-27 21:44:19,664 - root - INFO - Columns:

    - Text(
       name="sentence1"
       #total/missing=4599/0
       length, min/avg/max=16/57.71/367
    )
    - Text(
       name="sentence2"
       #total/missing=4599/0
       length, min/avg/max=15/57.49/311
    )
    - Categorical(
       name="genre"
       #total/missing=4599/0
       num_class (total/non_special)=3/3
       categories=['main-captions', 'main-forums', 'main-news']
       freq=[1604, 356, 2639]
    )
    - Numerical(
       name="score"
       #total/missing=4599/0
       shape=()
    )

    2020-10-27 21:44:19,664 - root - INFO - Tuning Dataset:
    2020-10-27 21:44:19,665 - root - INFO - Columns:

    - Text(
       name="sentence1"
       #total/missing=1150/0
       length, min/avg/max=16/57.69/340
    )
    - Text(
       name="sentence2"
       #total/missing=1150/0
       length, min/avg/max=16/57.71/229
    )
    - Categorical(
       name="genre"
       #total/missing=1150/0
       num_class (total/non_special)=3/3
       categories=['main-captions', 'main-forums', 'main-news']
       freq=[396, 94, 660]
    )
    - Numerical(
       name="score"
       #total/missing=1150/0
       shape=()
    )

    2020-10-27 21:44:19,665 - root - INFO - Label columns=['genre'], Feature columns=['sentence1', 'sentence2', 'score'], Problem types=['classification'], Label shapes=[3]
    2020-10-27 21:44:19,666 - root - INFO - Eval Metric=acc, Stop Metric=acc, Log Metrics=['acc', 'nll']

.. parsed-literal::
    :class: output

    44%|████▍ | 254/576 [00:50<01:03, 5.08it/s]
    52%|█████▏ | 299/576 [00:57<00:52, 5.24it/s]

.. parsed-literal::
    :class: output

    49%|████▉ | 284/576 [00:54<00:55, 5.23it/s]

.. code:: python

    score = predictor_genre.evaluate(dev_data, metrics='acc')
    print('Genre-prediction Accuracy = {}%'.format(score['acc'] * 100))

.. parsed-literal::
    :class: output

    Genre-prediction Accuracy = 89.66666666666666%
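
The two evaluation calls above each report a single number: a Spearman
correlation for the regression model and an accuracy for the
classification model. As a rough illustration of what these metrics
measure, here is a minimal pure-Python sketch. The helper names
(``average_ranks``, ``pearson``, ``spearman``, ``accuracy``) and the toy
predictions are made up for this example; they are not AutoGluon
functions or model outputs.

.. code:: python

    def average_ranks(values):
        """Rank values from 1 to n; tied values share the mean of their ranks."""
        order = sorted(range(len(values)), key=lambda i: values[i])
        ranks = [0.0] * len(values)
        start = 0
        while start < len(order):
            end = start
            # Extend the group while the next sorted value is tied.
            while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
                end += 1
            mean_rank = (start + end) / 2 + 1
            for k in range(start, end + 1):
                ranks[order[k]] = mean_rank
            start = end + 1
        return ranks


    def pearson(x, y):
        """Pearson correlation; assumes both inputs have non-zero variance."""
        n = len(x)
        mean_x, mean_y = sum(x) / n, sum(y) / n
        cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
        std_x = sum((a - mean_x) ** 2 for a in x) ** 0.5
        std_y = sum((b - mean_y) ** 2 for b in y) ** 0.5
        return cov / (std_x * std_y)


    def spearman(x, y):
        """Spearman correlation is the Pearson correlation of the ranks."""
        return pearson(average_ranks(x), average_ranks(y))


    def accuracy(y_true, y_pred):
        """Fraction of predictions that exactly match the true label."""
        return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


    # Toy values for illustration only (not real model predictions).
    true_scores = [5.0, 3.8, 2.6, 0.5]
    predicted_scores = [4.1, 4.3, 2.0, 1.0]
    print('Spearman:', spearman(true_scores, predicted_scores))

    true_genres = ['main-captions', 'main-news', 'main-forums', 'main-news']
    predicted_genres = ['main-captions', 'main-news', 'main-news', 'main-news']
    print('Accuracy:', accuracy(true_genres, predicted_genres))

Because Spearman correlation depends only on the ordering of the
predictions, a model can score well on it even when its predicted values
are offset from the true scores, which is one reason it is a common
choice for similarity-ranking tasks such as STS.
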