.. _text2text_matching:
Text-to-Text Semantic Matching with AutoMM
==========================================
Computing the similarity between two sentences/passages is a common task
in NLP, with practical applications such as web search, question
answering, document deduplication, plagiarism detection, natural
language inference, and recommendation engines. In general, text
similarity models take two sentences/passages as input and
transform them into vectors; similarity scores computed with
cosine similarity, dot product, or Euclidean distance then
measure how alike or different the two text pieces are.
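For intuition, the following minimal sketch (plain numpy, independent of
AutoMM) computes all three scores for a pair of toy embedding vectors:

.. code:: python

    import numpy as np

    # Toy vectors standing in for two encoded sentences.
    u = np.array([0.1, 0.8, 0.3])
    v = np.array([0.2, 0.7, 0.5])

    dot = np.dot(u, v)                                        # dot product
    cosine = dot / (np.linalg.norm(u) * np.linalg.norm(v))    # cosine similarity
    euclidean = np.linalg.norm(u - v)                         # Euclidean distance
    print(cosine, dot, euclidean)
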
Prepare your Data
-----------------
In this tutorial, we will demonstrate how to use AutoMM for text-to-text
semantic matching with the Stanford Natural Language Inference
(`SNLI <https://nlp.stanford.edu/projects/snli/>`__) corpus. SNLI is a
corpus containing around 570k human-written sentence pairs labeled with
*entailment*, *contradiction*, and *neutral*. It is a widely used
benchmark for evaluating the representation and inference capability of
machine learning methods. The following table contains three examples
taken from this corpus.
+--------------------------------+----------------------------------------+---------------+
| Premise                        | Hypothesis                             | Label         |
+================================+========================================+===============+
| A black race car starts up in  | A man is driving down a lonely road.   | contradiction |
| front of a crowd of people.    |                                        |               |
+--------------------------------+----------------------------------------+---------------+
| An older and younger man       | Two men are smiling and laughing at    | neutral       |
| smiling.                       | the cats playing on the floor.         |               |
+--------------------------------+----------------------------------------+---------------+
| A soccer game with multiple    | Some men are playing a sport.          | entailment    |
| males playing.                 |                                        |               |
+--------------------------------+----------------------------------------+---------------+
Here, we consider sentence pairs labeled *entailment* as positive
pairs (labeled as 1) and those labeled *contradiction* as negative
pairs (labeled as 0). Sentence pairs with a *neutral* relationship are
discarded; a sketch of this labeling scheme follows the data preview
below. The following code downloads and loads the corpus into
dataframes.
.. code:: python
from autogluon.core.utils.loaders import load_pd
import pandas as pd
snli_train = load_pd.load('https://automl-mm-bench.s3.amazonaws.com/snli/snli_train.csv', delimiter="|")
snli_test = load_pd.load('https://automl-mm-bench.s3.amazonaws.com/snli/snli_test.csv', delimiter="|")
snli_train.head()
.. parsed-literal::
    :class: output

                                                 premise                                         hypothesis  label
    0  A person on a horse jumps over a broken down a...    A person is at a diner , ordering an omelette .      0
    1  A person on a horse jumps over a broken down a...                A person is outdoors , on a horse .      1
    2              Children smiling and waving at camera                         There are children present      1
    3              Children smiling and waving at camera                              The kids are frowning      0
    4  A boy is jumping on skateboard in the middle o...                 The boy skates down the sidewalk .      0
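For illustration, binary labels like those above could be derived from
raw SNLI-style annotations as follows. This is a minimal sketch, not the
actual preprocessing script used to build the CSV files, and the
``gold_label`` column name is hypothetical:

.. code:: python

    import pandas as pd

    # Hypothetical raw SNLI-style annotations (examples from the table above).
    raw = pd.DataFrame({
        "premise": [
            "A black race car starts up in front of a crowd of people.",
            "An older and younger man smiling.",
            "A soccer game with multiple males playing.",
        ],
        "hypothesis": [
            "A man is driving down a lonely road.",
            "Two men are smiling and laughing at the cats playing on the floor.",
            "Some men are playing a sport.",
        ],
        "gold_label": ["contradiction", "neutral", "entailment"],
    })

    label_map = {"entailment": 1, "contradiction": 0}  # neutral pairs are dropped
    pairs = raw[raw["gold_label"].isin(label_map)].copy()
    pairs["label"] = pairs["gold_label"].map(label_map)
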
Train your Model
----------------
Ideally, we want to obtain a model that returns high/low scores for
positive/negative text pairs. Traditional text similarity methods,
e.g., those based on term frequency or TF-IDF vectors, only work at the
lexical level without taking the semantic aspect into account. With
AutoMM, we can easily train a model that captures the semantic
relationship between sentences. Essentially, it uses
`BERT <https://arxiv.org/abs/1810.04805>`__ to project each sentence
into a high-dimensional vector and treats the matching problem as a
classification problem, following the design of `sentence
transformers <https://www.sbert.net>`__.
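To make the bi-encoder idea concrete, here is a minimal sketch that
scores a sentence pair with the ``sentence-transformers`` library
directly. This illustrates the general approach rather than AutoMM's
internal implementation, and the ``all-MiniLM-L6-v2`` checkpoint is
just one example choice:

.. code:: python

    from sentence_transformers import SentenceTransformer, util

    # Encode both sentences with the same model, then score the pair
    # with cosine similarity.
    model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = model.encode([
        "A soccer game with multiple males playing.",
        "Some men are playing a sport.",
    ])
    print(util.cos_sim(embeddings[0], embeddings[1]))  # high score expected
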
With AutoMM, you just need to specify the query, response, and label
column names and fit the model on the training dataset without worrying
about the implementation details. Note that the labels should be binary,
and we need to specify the ``match_label``, i.e., the label value
indicating that two sentences have the same semantic meaning. In
practice, your tasks may use different labels, e.g., duplicate or not
duplicate. You may need to define the ``match_label`` according to your
specific task context.
.. code:: python
from autogluon.multimodal import MultiModalPredictor
# Initialize the model
predictor = MultiModalPredictor(
problem_type="text_similarity",
query="premise", # the column name of the first sentence
response="hypothesis", # the column name of the second sentence
label="label", # the label column name
match_label=1, # the label indicating that query and response have the same semantic meanings.
eval_metric='auc', # the evaluation metric
)
# Fit the model
predictor.fit(
train_data=snli_train,
time_limit=180,
)
.. parsed-literal::
:class: output
Global seed set to 123
No path specified. Models will be saved in: "AutogluonModels/ag-20230222_233725/"
/home/ci/autogluon/multimodal/src/autogluon/multimodal/utils/metric.py:92: UserWarning: Currently, we cannot convert the metric: auc to a metric supported in torchmetrics. Thus, we will fall-back to use accuracy for multi-class classification problems , ROC-AUC for binary classification problem, and RMSE for regression problems.
warnings.warn(
/home/ci/opt/venv/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUROC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.
warnings.warn(*args, **kwargs)
Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
-------------------------------------------------------------------
0 | query_model | HFAutoModelForTextPrediction | 33.4 M
1 | response_model | HFAutoModelForTextPrediction | 33.4 M
2 | validation_metric | AUROC | 0
3 | loss_func | ContrastiveLoss | 0
4 | miner_func | PairMarginMiner | 0
-------------------------------------------------------------------
33.4 M Trainable params
0 Non-trainable params
33.4 M Total params
66.720 Total estimated model params size (MB)
Time limit reached. Elapsed time is 0:03:00. Signaling Trainer to stop.
Epoch 0, global step 137: 'val_roc_auc' reached 0.90326 (best 0.90326), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/matching/AutogluonModels/ag-20230222_233725/epoch=0-step=137.ckpt' as top 3
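The training log above notes that no save path was specified, so the
model was saved under an automatically generated
``AutogluonModels/ag-.../`` directory. If you prefer to control where
artifacts are stored, the predictor constructor accepts a ``path``
argument (a minimal sketch; the directory name below is arbitrary):

.. code:: python

    predictor_with_path = MultiModalPredictor(
        problem_type="text_similarity",
        query="premise",
        response="hypothesis",
        label="label",
        match_label=1,
        eval_metric="auc",
        path="./snli_matcher",  # where checkpoints and configs will be saved
    )
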
Evaluate on Test Dataset
------------------------
You can evaluate the matcher on the test dataset to see how it performs
in terms of the roc_auc score:
.. code:: python
score = predictor.evaluate(snli_test)
print("evaluation score: ", score)
.. parsed-literal::
:class: output
evaluation score: {'roc_auc': 0.9120869555327099}
Predict on a New Sentence Pair
------------------------------
We create a new sentence pair with similar meaning (expected to be
predicted as :math:`1`) and make predictions using the trained model.
.. code:: python
pred_data = pd.DataFrame.from_dict({"premise":["The teacher gave his speech to an empty room."],
"hypothesis":["There was almost nobody when the professor was talking."]})
predictions = predictor.predict(pred_data)
print('Predicted label:', predictions[0])
.. parsed-literal::
:class: output
Predicted label: 1
Predict Matching Probabilities
------------------------------
We can also compute the matching probabilities of sentence pairs. The
probability under column ``1`` (the ``match_label``) is the predicted
probability that the two sentences match.
.. code:: python
probabilities = predictor.predict_proba(pred_data)
print(probabilities)
.. parsed-literal::
:class: output
0 1
0 0.207848 0.792152
Extract Embeddings
------------------
Moreover, we support extracting embeddings separately for two sentence
groups.
.. code:: python
embeddings_1 = predictor.extract_embedding({"premise":["The teacher gave his speech to an empty room."]})
print(embeddings_1.shape)
embeddings_2 = predictor.extract_embedding({"hypothesis":["There was almost nobody when the professor was talking."]})
print(embeddings_2.shape)
.. parsed-literal::
:class: output
(1, 384)
(1, 384)
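Given these embeddings, you can compute matching scores directly. The
following sketch uses plain numpy cosine similarity on the two extracted
(1, 384) arrays, assuming (as is common for such embeddings) that cosine
similarity is a sensible score:

.. code:: python

    import numpy as np

    # Cosine similarity between the two extracted embedding rows.
    cosine = np.sum(embeddings_1 * embeddings_2, axis=1) / (
        np.linalg.norm(embeddings_1, axis=1) * np.linalg.norm(embeddings_2, axis=1)
    )
    print(cosine)
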
Other Examples
--------------
You may go to `AutoMM
Examples <https://github.com/autogluon/autogluon/tree/master/examples/automm>`__
to explore other examples about AutoMM.
Customization
-------------
To learn how to customize AutoMM, please refer to
:ref:`sec_automm_customization`.