Text Prediction - Multimodal Table with Text

In many applications, text data may be mixed with numeric/categorical data. AutoGluon’s TextPredictor can train a single neural network that jointly operates on multiple feature types, including text, categorical, and numerical columns. The general idea is to embed the text, categorical and numeric fields separately and fuse these features across modalities. This tutorial demonstrates such an application.

import numpy as np
import pandas as pd
import os
import warnings
warnings.filterwarnings('ignore')
np.random.seed(123)
!python3 -m pip install openpyxl
Collecting openpyxl
  Using cached openpyxl-3.0.7-py2.py3-none-any.whl (243 kB)
Collecting et-xmlfile
  Using cached et_xmlfile-1.1.0-py3-none-any.whl (4.7 kB)
Installing collected packages: et-xmlfile, openpyxl
Successfully installed et-xmlfile-1.1.0 openpyxl-3.0.7

Book Price Prediction Data

For demonstration, we use the book price prediction dataset from the MachineHack book price prediction hackathon. Our goal is to predict a book’s price given various features like its author, its synopsis, the book’s rating, etc.

!mkdir -p price_of_books
!wget https://automl-mm-bench.s3.amazonaws.com/machine_hack_competitions/predict_the_price_of_books/Data.zip -O price_of_books/Data.zip
!cd price_of_books && unzip -o Data.zip
!ls price_of_books/Participants_Data
--2021-04-29 01:10:59--  https://automl-mm-bench.s3.amazonaws.com/machine_hack_competitions/predict_the_price_of_books/Data.zip
Resolving automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)... 52.216.239.19
Connecting to automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)|52.216.239.19|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3521673 (3.4M) [application/zip]
Saving to: ‘price_of_books/Data.zip’

price_of_books/Data 100%[===================>]   3.36M  6.53MB/s    in 0.5s

2021-04-29 01:11:00 (6.53 MB/s) - ‘price_of_books/Data.zip’ saved [3521673/3521673]

Archive:  Data.zip
  inflating: Participants_Data/Data_Test.xlsx
  inflating: Participants_Data/Data_Train.xlsx
  inflating: Participants_Data/Sample_Submission.xlsx
Data_Test.xlsx      Data_Train.xlsx  Sample_Submission.xlsx
train_df = pd.read_excel(os.path.join('price_of_books', 'Participants_Data', 'Data_Train.xlsx'), engine='openpyxl')
train_df.head()
|   | Title | Author | Edition | Reviews | Ratings | Synopsis | Genre | BookCategory | Price |
|---|---|---|---|---|---|---|---|---|---|
| 0 | The Prisoner's Gold (The Hunters 3) | Chris Kuzneski | Paperback,– 10 Mar 2016 | 4.0 out of 5 stars | 8 customer reviews | THE HUNTERS return in their third brilliant no... | Action & Adventure (Books) | Action & Adventure | 220.00 |
| 1 | Guru Dutt: A Tragedy in Three Acts | Arun Khopkar | Paperback,– 7 Nov 2012 | 3.9 out of 5 stars | 14 customer reviews | A layered portrait of a troubled genius for wh... | Cinema & Broadcast (Books) | Biographies, Diaries & True Accounts | 202.93 |
| 2 | Leviathan (Penguin Classics) | Thomas Hobbes | Paperback,– 25 Feb 1982 | 4.8 out of 5 stars | 6 customer reviews | "During the time men live without a common Pow... | International Relations | Humour | 299.00 |
| 3 | A Pocket Full of Rye (Miss Marple) | Agatha Christie | Paperback,– 5 Oct 2017 | 4.1 out of 5 stars | 13 customer reviews | A handful of grain is found in the pocket of a... | Contemporary Fiction (Books) | Crime, Thriller & Mystery | 180.00 |
| 4 | LIFE 70 Years of Extraordinary Photography | Editors of Life | Hardcover,– 10 Oct 2006 | 5.0 out of 5 stars | 1 customer review | For seven decades, "Life" has been thrilling t... | Photography Textbooks | Arts, Film & Photography | 965.62 |

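We do some basic preprocessing: we convert Reviews and Ratings to numeric values and transform prices to a log scale using log(Price + 1). Note that, in this dataset, the Reviews column holds the star rating (for example, "4.0 out of 5 stars") and the Ratings column holds the number of customer reviews (for example, "8 customer reviews").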

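def preprocess(df):
    df = df.copy(deep=True)
    # Reviews holds strings like '4.0 out of 5 stars'; drop the suffix to keep the star rating
    df.loc[:, 'Reviews'] = pd.to_numeric(df['Reviews'].apply(lambda ele: ele[:-len(' out of 5 stars')]))
    # Ratings holds strings like '8 customer reviews'; drop thousands separators and the suffix.
    # A singular '1 customer review' is sliced to an empty string and ends up as NaN.
    df.loc[:, 'Ratings'] = pd.to_numeric(df['Ratings'].apply(lambda ele: ele.replace(',', '')[:-len(' customer reviews')]))
    # Use log(price + 1) as the regression target
    df.loc[:, 'Price'] = np.log(df['Price'] + 1)
    return df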
train_subsample_size = 1500  # subsample for a faster demo; try larger values
test_subsample_size = 5
train_df = preprocess(train_df)
train_data = train_df.iloc[100:].sample(train_subsample_size, random_state=123)
test_data = train_df.iloc[:100].sample(test_subsample_size, random_state=245)
train_data.head()
|   | Title | Author | Edition | Reviews | Ratings | Synopsis | Genre | BookCategory | Price |
|---|---|---|---|---|---|---|---|---|---|
| 949 | Furious Hours | Casey Cep | Paperback,– 1 Jun 2019 | 4.0 | NaN | ‘It’s been a long time since I picked up a boo... | True Accounts (Books) | Biographies, Diaries & True Accounts | 5.743003 |
| 5504 | REST API Design Rulebook | Mark Masse | Paperback,– 7 Nov 2011 | 5.0 | NaN | In todays market, where rival web services com... | Computing, Internet & Digital Media (Books) | Computing, Internet & Digital Media | 5.786897 |
| 5856 | The Atlantropa Articles: A Novel | Cody Franklin | Paperback,– Import, 1 Nov 2018 | 4.5 | 2.0 | #1 Amazon Best Seller! Dystopian Alternate His... | Action & Adventure (Books) | Romance | 6.893656 |
| 4137 | Hickory Dickory Dock (Poirot) | Agatha Christie | Paperback,– 5 Oct 2017 | 4.3 | 21.0 | There’s more than petty theft going on in a Lo... | Action & Adventure (Books) | Crime, Thriller & Mystery | 5.192957 |
| 3205 | The Stanley Kubrick Archives (Bibliotheca Univ... | Alison Castle | Hardcover,– 21 Aug 2016 | 4.6 | 3.0 | In 1968, when Stanley Kubrick was asked to com... | Cinema & Broadcast (Books) | Humour | 6.889591 |
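The preprocess function above removes fixed-length suffixes, so a singular entry such as "1 customer review" (row 4 of the raw table) is sliced down to an empty string and becomes NaN. This likely explains at least some of the NaN values in the Ratings column. As a minimal alternative sketch, assuming only the string formats visible in the raw table, the leading number can be parsed with a regular expression instead. The helper name parse_leading_number is hypothetical; it is not part of AutoGluon or of the original preprocessing.

```python
import re

# A leading number such as "4.0", "8", or "1,234" (formats assumed from the raw table).
_LEADING_NUMBER = re.compile(r"^\s*(\d[\d,]*(?:\.\d+)?)")


def parse_leading_number(text):
    """Return the number at the start of `text` as a float, or None if there is none."""
    match = _LEADING_NUMBER.match(text)
    if match is None:
        return None
    # Drop thousands separators before converting.
    return float(match.group(1).replace(",", ""))


print(parse_leading_number("4.0 out of 5 stars"))      # 4.0
print(parse_leading_number("8 customer reviews"))      # 8.0
print(parse_leading_number("1 customer review"))       # 1.0
print(parse_leading_number("1,234 customer reviews"))  # 1234.0
```

With pandas, such a helper could be applied per column, for example df['Ratings'].apply(parse_leading_number). Doing so would change the training data, so the outputs shown in this tutorial would no longer be reproduced exactly.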

Training

We can simply create a TextPredictor and call predictor.fit() to train a model that operates across all types of features. Internally, the neural network is generated automatically based on the inferred data type of each feature column. To save time, we subsample the data and train for only three minutes.

from autogluon.text import TextPredictor
time_limit = 3 * 60  # set to larger value in your applications
predictor = TextPredictor(label='Price', path='ag_text_book_price_prediction')
predictor.fit(train_data, time_limit=time_limit)
INFO:root:NumPy-shape semantics has been activated in your code. This is required for creating and manipulating scalar and zero-size tensors, which were not supported in MXNet before, as in the official NumPy library. Please DO NOT manually deactivate this semantics while using mxnet.numpy and mxnet.numpy_extension modules.
INFO:autogluon.text.text_prediction.mx.models:The GluonNLP V0 backend is used. We will use 8 cpus and 1 gpus to train each trial.
All Logs will be saved to /var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/ag_text_book_price_prediction/task0/training.log
INFO:root:Fitting and transforming the train data...
INFO:root:Done! Preprocessor saved to /var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/ag_text_book_price_prediction/task0/preprocessor.pkl
INFO:root:Process dev set...
INFO:root:Done!
INFO:root:Max length for chunking text: 480, Stochastic chunk: Train-False/Test-False, Test #repeat: 1.
INFO:root:#Total Params/Fixed Params=109338913/0
Level 15:root:Using gradient accumulation. Global batch size = 128
INFO:root:Local training results will be saved to /var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/ag_text_book_price_prediction/task0/results_local.jsonl.
Level 15:root:[Iter 1/100, Epoch 0] train loss=2.72e+00, gnorm=6.91e+01, lr=1.00e-05, #samples processed=128, #sample per second=12.05. ETA=17.52min
Level 15:root:[Iter 2/100, Epoch 0] train loss=1.86e+00, gnorm=4.09e+01, lr=2.00e-05, #samples processed=128, #sample per second=12.12. ETA=17.30min
Level 25:root:[Iter 2/100, Epoch 0] valid r2=-8.0283e-01, root_mean_squared_error=1.0801e+00, mean_absolute_error=9.0033e-01, time spent=8.979s, total time spent=0.52min. Find new best=True, Find new top-3=True
Level 15:root:[Iter 3/100, Epoch 0] train loss=2.05e+00, gnorm=4.37e+01, lr=3.00e-05, #samples processed=128, #sample per second=6.44. ETA=22.13min
Level 15:root:[Iter 4/100, Epoch 0] train loss=3.00e+00, gnorm=7.53e+01, lr=4.00e-05, #samples processed=128, #sample per second=12.85. ETA=20.41min
Level 25:root:[Iter 4/100, Epoch 0] valid r2=-5.0688e-01, root_mean_squared_error=9.8744e-01, mean_absolute_error=8.3290e-01, time spent=9.066s, total time spent=1.03min. Find new best=True, Find new top-3=True
Level 15:root:[Iter 5/100, Epoch 0] train loss=1.99e+00, gnorm=4.95e+01, lr=5.00e-05, #samples processed=128, #sample per second=6.46. ETA=22.43min
Level 15:root:[Iter 6/100, Epoch 0] train loss=1.01e+00, gnorm=1.81e+01, lr=6.00e-05, #samples processed=128, #sample per second=14.69. ETA=20.77min
Level 25:root:[Iter 6/100, Epoch 0] valid r2=-4.4972e-01, root_mean_squared_error=9.6853e-01, mean_absolute_error=7.6413e-01, time spent=9.156s, total time spent=1.50min. Find new best=True, Find new top-3=True
Level 15:root:[Iter 7/100, Epoch 0] train loss=1.60e+00, gnorm=3.10e+01, lr=7.00e-05, #samples processed=128, #sample per second=6.27. ETA=22.14min
Level 15:root:[Iter 8/100, Epoch 0] train loss=1.12e+00, gnorm=5.11e+00, lr=8.00e-05, #samples processed=128, #sample per second=13.57. ETA=20.97min
Level 25:root:[Iter 8/100, Epoch 0] valid r2=1.0843e-01, root_mean_squared_error=7.5954e-01, mean_absolute_error=6.0718e-01, time spent=9.238s, total time spent=2.01min. Find new best=True, Find new top-3=True
Level 15:root:[Iter 9/100, Epoch 0] train loss=9.84e-01, gnorm=1.89e+01, lr=9.00e-05, #samples processed=128, #sample per second=6.20. ETA=21.92min
Level 15:root:[Iter 10/100, Epoch 0] train loss=1.27e+00, gnorm=5.74e+00, lr=1.00e-04, #samples processed=128, #sample per second=12.04. ETA=21.11min
Level 25:root:[Iter 10/100, Epoch 0] valid r2=-1.7946e-02, root_mean_squared_error=8.1158e-01, mean_absolute_error=6.2267e-01, time spent=9.380s, total time spent=2.52min. Find new best=False, Find new top-3=True
Level 15:root:[Iter 11/100, Epoch 1] train loss=1.28e+00, gnorm=1.52e+01, lr=9.89e-05, #samples processed=128, #sample per second=6.39. ETA=21.67min
Level 15:root:[Iter 12/100, Epoch 1] train loss=9.78e-01, gnorm=1.08e+01, lr=9.78e-05, #samples processed=128, #sample per second=11.56. ETA=21.00min
Level 25:root:[Iter 12/100, Epoch 1] valid r2=2.5560e-01, root_mean_squared_error=6.9402e-01, mean_absolute_error=5.4284e-01, time spent=9.951s, total time spent=3.06min. Find new best=True, Find new top-3=True
INFO:root:Training completed. Auto-saving to "ag_text_book_price_prediction/". For loading the model, you can use predictor = TextPredictor.load("ag_text_book_price_prediction/")
<autogluon.text.text_prediction.predictor.predictor.TextPredictor at 0x7f0f03db7d50>

Prediction

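We can easily obtain predictions and extract data embeddings using the TextPredictor. Because the model was trained on log-transformed prices, we apply exp(x) - 1 to the predictions and to the labels to print them on the original price scale.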

predictions = predictor.predict(test_data)
print('Predictions:')
print('------------')
print(np.exp(predictions) - 1)
print()
print('True Value:')
print('------------')
print(np.exp(test_data['Price']) - 1)
Predictions:
------------
1     322.644562
31    457.390015
19    847.595947
45    553.764893
82    558.156494
Name: Price, dtype: float32

True Value:
------------
1     202.93
31    799.00
19    352.00
45    395.10
82    409.00
Name: Price, dtype: float64
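The target is log(price + 1), and the code above maps values back with exp(x) - 1. The following sketch checks that round trip with the standard library functions math.log1p and math.expm1, which compute the same two quantities and are more accurate for values close to zero. The sample prices are the five prices from the raw table shown earlier.

```python
import math

# Prices copied from the first five rows of the raw training table.
prices = [220.00, 202.93, 299.00, 180.00, 965.62]

# Forward transform used for the training target: log(price + 1).
log_prices = [math.log1p(price) for price in prices]

# Inverse transform applied to predictions: exp(x) - 1.
recovered = [math.expm1(value) for value in log_prices]

for original, restored in zip(prices, recovered):
    # The round trip recovers the original price up to floating-point error.
    assert math.isclose(original, restored, rel_tol=1e-9)

print([round(value, 4) for value in log_prices])
```

NumPy offers the same pair as np.log1p and np.expm1, which could replace np.log(x + 1) and np.exp(x) - 1 in the cells above without changing the results in any meaningful way.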
performance = predictor.evaluate(test_data)
print(performance)
0.5481857657432556
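The tutorial does not state which metric evaluate() reports here. Recomputing the root mean squared error on the log scale from the five predictions and true prices printed above gives about 0.548, which matches the printed score, so the value appears to be the RMSE of log(price + 1). The helper below is a plain-Python sketch for checking this, not AutoGluon code.

```python
import math

# Values copied from the printed predictions and true prices above (same row order).
predicted_prices = [322.644562, 457.390015, 847.595947, 553.764893, 558.156494]
true_prices = [202.93, 799.00, 352.00, 395.10, 409.00]


def rmse_on_log_scale(predicted, actual):
    """RMSE between log(predicted + 1) and log(actual + 1)."""
    squared_errors = [
        (math.log1p(p) - math.log1p(a)) ** 2 for p, a in zip(predicted, actual)
    ]
    return math.sqrt(sum(squared_errors) / len(squared_errors))


score = rmse_on_log_scale(predicted_prices, true_prices)
print(round(score, 3))  # 0.548
```

Since the error is measured on the log scale, it reflects relative rather than absolute price differences: an error of 0.548 corresponds to predictions that are typically off by a factor of roughly exp(0.548), or about 1.7.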
embeddings = predictor.extract_embedding(test_data)
print(embeddings)
[[-0.23467553  0.4187132  -0.38100958 ...  0.10046377 -0.09448481
   0.06040561]
 [-0.17090559  0.45281178 -0.427304   ... -0.00930673 -0.04129042
   0.04809796]
 [-0.18613525  0.55869526 -0.41296566 ... -0.48956588  0.4044409
   0.26346046]
 [-0.19453217  0.43630105 -0.37463886 ... -0.05238581  0.09296411
   0.32541677]
 [-0.3221671   0.55299646 -0.45222527 ...  0.07579938 -0.0347809
   0.19916777]]
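Each row of the printed array is the embedding of one test example. One common way to use such vectors is to compare examples by cosine similarity. The sketch below uses short made-up vectors, since the printed values above are truncated; with real output, rows of the embeddings array would take their place.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Hypothetical 4-dimensional embeddings, for illustration only.
book_a = [-0.23, 0.42, -0.38, 0.10]
book_b = [-0.17, 0.45, -0.43, -0.01]
book_c = [0.30, -0.40, 0.35, 0.05]

similarity_ab = cosine_similarity(book_a, book_b)
similarity_ac = cosine_similarity(book_a, book_c)

# book_b points in nearly the same direction as book_a; book_c points away from it.
print(similarity_ab > similarity_ac)  # True
```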

What’s happening inside?

Internally, we use different networks to encode the text columns, categorical columns, and numerical columns. The features generated by the individual networks are aggregated by a late-fusion aggregator. The aggregator can output either logits or score predictions. The architecture can be illustrated as follows:

https://autogluon-text-data.s3.amazonaws.com/figures/fuse-late.png

Fig. 1 Multimodal Network with Late Fusion

Here, we use the pretrained NLP backbone to extract the text features and then use two other towers to extract features from the categorical and numerical columns.
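To make the late-fusion idea concrete, here is a toy sketch in plain Python. Each modality has its own small tower, the tower outputs are concatenated, and a final linear layer produces one score. Everything in it is an assumption made for illustration: the layer sizes, the random weights, and the choice of concatenation followed by a linear head. The actual TextPredictor network uses a pretrained Transformer for the text tower, and its aggregator may be more elaborate.

```python
import random

random.seed(0)


def random_layer(n_in, n_out):
    """Create (weights, bias) for a dense layer with small random weights."""
    weights = [[random.uniform(-0.1, 0.1) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def linear(inputs, weights, bias):
    """Dense layer: one output value per weight row."""
    return [sum(w * x for w, x in zip(row, inputs)) + b for row, b in zip(weights, bias)]


def relu(values):
    return [max(0.0, v) for v in values]


# Stand-ins for the three towers; each maps its input to a 4-dimensional feature.
text_tower = random_layer(8, 4)         # input: an 8-dim pooled text vector (hypothetical)
categorical_tower = random_layer(3, 4)  # input: a one-hot vector over 3 categories
numerical_tower = random_layer(2, 4)    # input: 2 numeric columns
head = random_layer(12, 1)              # aggregator head over the concatenated features


def fuse(text_vec, cat_onehot, num_vec):
    """Run every tower separately, then concatenate the resulting features."""
    fused = []
    for tower, inputs in (
        (text_tower, text_vec),
        (categorical_tower, cat_onehot),
        (numerical_tower, num_vec),
    ):
        fused.extend(relu(linear(inputs, *tower)))
    return fused


def late_fusion_predict(text_vec, cat_onehot, num_vec):
    """Score prediction from the fused features."""
    return linear(fuse(text_vec, cat_onehot, num_vec), *head)[0]


example = ([0.5] * 8, [0.0, 1.0, 0.0], [4.3, 21.0])
print(len(fuse(*example)))           # 12
print(late_fusion_predict(*example))
```

The key property of late fusion shown here is that the modalities do not interact until after each tower has produced its own feature vector.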

In addition, to deal with multiple text fields, we separate these fields with the [SEP] token and alternate 0s and 1s as the segment IDs, which is shown as follows:

https://autogluon-text-data.s3.amazonaws.com/figures/preprocess.png

Fig. 2 Preprocessing
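The merging of several text fields can be sketched as follows. This is an illustration under stated assumptions, not the library's preprocessing code: it uses whitespace splitting in place of a real subword tokenizer, prepends a [CLS] token as in BERT-style inputs, and assigns each [SEP] token to the field it terminates. The function name merge_text_fields is hypothetical.

```python
def merge_text_fields(fields, tokenize=str.split):
    """Join text fields into one token sequence with alternating segment IDs."""
    tokens = ["[CLS]"]   # assumed BERT-style start token
    segment_ids = [0]
    for index, field in enumerate(fields):
        segment = index % 2  # alternate 0, 1, 0, 1, ... across fields
        field_tokens = tokenize(field) + ["[SEP]"]
        tokens.extend(field_tokens)
        segment_ids.extend([segment] * len(field_tokens))
    return tokens, segment_ids


tokens, segment_ids = merge_text_fields(
    ["Furious Hours", "Casey Cep", "True Accounts (Books)"]
)
print(tokens)
# ['[CLS]', 'Furious', 'Hours', '[SEP]', 'Casey', 'Cep', '[SEP]', 'True', 'Accounts', '(Books)', '[SEP]']
print(segment_ids)
# [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0]
```

Alternating the segment IDs lets a backbone that was pretrained with only two segment types still distinguish neighboring fields, even when a row has more than two text columns.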

How does this compare with TabularPredictor?

Note that TabularPredictor can also handle data tables with text, numeric, and categorical columns, but it uses an ensemble of many types of models and may featurize text. TextPredictor instead fits individual Transformer neural network models directly to the raw text (these models are also capable of handling additional numeric/categorical columns). We generally recommend TabularPredictor if your table contains mainly numeric/categorical columns and TextPredictor if your table contains mainly text columns, but it is easy to try both, and we encourage this. In fact, TabularPredictor.fit(..., hyperparameters='multimodal') will train a TextPredictor along with many tabular models and ensemble them together. Refer to the tutorial “Multimodal Data Tables: Combining BERT/Transformers and Classical Tabular Models” for more details.

Other Examples

You may go to https://github.com/awslabs/autogluon/tree/master/examples/text_prediction to explore other TextPredictor examples, including scripts to train a TextPredictor on the complete book price prediction dataset.