.. _sec_textprediction_heterogeneous:

Text Prediction - Heterogeneous Data Types
==========================================

In your applications, your text data may be mixed with other common data
types, such as numerical and categorical data (both commonly found in
tabular data). The ``TextPrediction`` task in AutoGluon can train a single
neural network that jointly operates on multiple feature types, including
text, categorical, and numerical columns. Here we'll again use the
`Semantic Textual Similarity `__ dataset to illustrate this
functionality.

.. code:: python

    import numpy as np
    import warnings
    warnings.filterwarnings('ignore')
    np.random.seed(123)

Load Data
---------

.. code:: python

    from autogluon.utils.tabular.utils.loaders import load_pd

    train_data = load_pd.load('https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/train.parquet')
    dev_data = load_pd.load('https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/dev.parquet')
    train_data.head(10)

.. parsed-literal::
    :class: output

    Loaded data from: https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/train.parquet | Columns = 4 / 4 | Rows = 5749 -> 5749
    Loaded data from: https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/dev.parquet | Columns = 4 / 4 | Rows = 1500 -> 1500

.. list-table::
    :header-rows: 1

    * -
      - sentence1
      - sentence2
      - genre
      - score
    * - 0
      - A plane is taking off.
      - An air plane is taking off.
      - main-captions
      - 5.00
    * - 1
      - A man is playing a large flute.
      - A man is playing a flute.
      - main-captions
      - 3.80
    * - 2
      - A man is spreading shreded cheese on a pizza.
      - A man is spreading shredded cheese on an uncoo...
      - main-captions
      - 3.80
    * - 3
      - Three men are playing chess.
      - Two men are playing chess.
      - main-captions
      - 2.60
    * - 4
      - A man is playing the cello.
      - A man seated is playing the cello.
      - main-captions
      - 4.25
    * - 5
      - Some men are fighting.
      - Two men are fighting.
      - main-captions
      - 4.25
    * - 6
      - A man is smoking.
      - A man is skating.
      - main-captions
      - 0.50
    * - 7
      - The man is playing the piano.
      - The man is playing the guitar.
      - main-captions
      - 1.60
    * - 8
      - A man is playing on a guitar and singing.
      - A woman is playing an acoustic guitar and sing...
      - main-captions
      - 2.20
    * - 9
      - A person is throwing a cat on to the ceiling.
      - A person throws a cat on the ceiling.
      - main-captions
      - 5.00
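
To make the three feature types concrete, the sketch below guesses a role
(text, categorical, or numerical) for each column, using a few rows copied
from the table above. The helper ``guess_column_role`` and its word-count
threshold are hypothetical illustrations, not AutoGluon's internal
type-inference logic, which may apply different rules.

.. code:: python

    import pandas as pd

    # Hypothetical helper: a rough heuristic for guessing a column's role.
    # NOT AutoGluon's own inference; the threshold is an illustrative assumption.
    def guess_column_role(series: pd.Series, min_avg_words: float = 3.0) -> str:
        # Numeric dtypes (int, float, bool) are treated as numerical features.
        if pd.api.types.is_numeric_dtype(series):
            return 'numerical'
        values = series.dropna().astype(str)
        # Average number of whitespace-separated tokens per entry.
        avg_words = values.str.split().str.len().mean()
        # Long entries look like free text; short ones look like category labels.
        return 'text' if avg_words >= min_avg_words else 'categorical'

    # A few rows copied from the train_data.head(10) output above.
    sample = pd.DataFrame({
        'sentence1': ['A plane is taking off.', 'A man is playing a large flute.',
                      'Three men are playing chess.', 'A man is playing the cello.'],
        'sentence2': ['An air plane is taking off.', 'A man is playing a flute.',
                      'Two men are playing chess.', 'A man seated is playing the cello.'],
        'genre': ['main-captions'] * 4,
        'score': [5.00, 3.80, 2.60, 4.25],
    })

    roles = {col: guess_column_role(sample[col]) for col in sample.columns}
    print(roles)
    # {'sentence1': 'text', 'sentence2': 'text', 'genre': 'categorical', 'score': 'numerical'}

On the full dataset the same call would be
``{col: guess_column_role(train_data[col]) for col in train_data.columns}``.
A word-count cutoff is only a rough proxy: short free-text fields or long
category names would be misclassified, which is one reason a library may
combine several signals when deciding column types.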
Note that the STS dataset contains two text fields, ``sentence1`` and
``sentence2``; one categorical field, ``genre``; and one numerical field,
``score``. Let's try to predict the **score** based on the other features:
``sentence1``, ``sentence2``, and ``genre``.

.. code:: python

    import autogluon as ag
    from autogluon import TextPrediction as task

    predictor_score = task.fit(train_data, label='score',
                               time_limits=60, ngpus_per_trial=1, seed=123,
                               output_directory='./ag_sts_mixed_score')

.. parsed-literal::
    :class: output

    NumPy-shape semantics has been activated in your code. This is required for creating and manipulating scalar and zero-size tensors, which were not supported in MXNet before, as in the official NumPy library. Please DO NOT manually deactivate this semantics while using `mxnet.numpy` and `mxnet.numpy_extension` modules.

.. parsed-literal::
    :class: output

    2020-12-08 20:42:27,127 - root - INFO - All Logs will be saved to ./ag_sts_mixed_score/ag_text_prediction.log
    2020-12-08 20:42:27,151 - root - INFO - Train Dataset:
    2020-12-08 20:42:27,152 - root - INFO - Columns:

    - Text(
       name="sentence1"
       #total/missing=4599/0
       length, min/avg/max=16/57.93/367
    )
    - Text(
       name="sentence2"
       #total/missing=4599/0
       length, min/avg/max=15/57.63/311
    )
    - Categorical(
       name="genre"
       #total/missing=4599/0
       num_class (total/non_special)=4/3
       categories=['main-captions', 'main-forums', 'main-news']
       freq=[1612, 358, 2629]
    )
    - Numerical(
       name="score"
       #total/missing=4599/0
       shape=()
    )

    2020-12-08 20:42:27,152 - root - INFO - Tuning Dataset:
    2020-12-08 20:42:27,153 - root - INFO - Columns:

    - Text(
       name="sentence1"
       #total/missing=1150/0
       length, min/avg/max=16/56.84/272
    )
    - Text(
       name="sentence2"
       #total/missing=1150/0
       length, min/avg/max=16/57.15/249
    )
    - Categorical(
       name="genre"
       #total/missing=1150/0
       num_class (total/non_special)=4/3
       categories=['main-captions', 'main-forums', 'main-news']
       freq=[388, 92, 670]
    )
    - Numerical(
       name="score"
       #total/missing=1150/0
       shape=()
    )

    2020-12-08 20:42:27,154 - root - INFO - Label columns=['score'], Feature columns=['sentence1', 'sentence2', 'genre'], Problem types=['regression'], Label shapes=[()]
    2020-12-08 20:42:27,154 - root - INFO - Eval Metric=mse, Stop Metric=mse, Log Metrics=['mse', 'rmse', 'mae']
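
The log shows that ``score`` is treated as a regression target, with ``mse``
as both the evaluation and stopping metric and ``rmse`` and ``mae`` also
logged. The minimal NumPy sketch below recalls how these three metrics are
computed. ``regression_metrics`` is a hypothetical helper, the predictions
are placeholder values rather than outputs of ``predictor_score``, and
AutoGluon's own metric code may differ in details.

.. code:: python

    import numpy as np

    # Hypothetical helper: standard definitions of the three logged metrics.
    def regression_metrics(y_true, y_pred):
        y_true = np.asarray(y_true, dtype=float)
        y_pred = np.asarray(y_pred, dtype=float)
        errors = y_pred - y_true
        mse = float(np.mean(errors ** 2))      # mean squared error
        rmse = float(np.sqrt(mse))             # same units as the score itself
        mae = float(np.mean(np.abs(errors)))   # mean absolute error
        return {'mse': mse, 'rmse': rmse, 'mae': mae}

    # Gold scores from rows 0-3 of the table above; the predictions are
    # made-up placeholders chosen to keep the arithmetic easy to follow.
    y_true = [5.00, 3.80, 3.80, 2.60]
    y_pred = [4.50, 3.80, 4.30, 3.60]

    metrics = regression_metrics(y_true, y_pred)
    print({name: round(value, 4) for name, value in metrics.items()})
    # {'mse': 0.375, 'rmse': 0.6124, 'mae': 0.5}

Because MSE squares each error, the single error of 1.0 contributes
two-thirds of the total squared error here, whereas MAE weights all errors
linearly; RMSE is reported alongside MSE because it is on the same scale as
``score``.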