Image-Text Semantic Matching with AutoMM¶
Vision and language are two important aspects of human intelligence for understanding the real world. Image-text semantic matching, which measures the visual-semantic similarity between an image and a text, plays a critical role in bridging vision and language. A typical solution is to learn a joint space in which text and image feature vectors are aligned. Image-text matching is becoming increasingly important for various vision-and-language tasks, such as cross-modal retrieval, image captioning, text-to-image synthesis, and multimodal neural machine translation. This tutorial introduces how to apply AutoMM to the image-text matching task.
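To make the idea of an aligned joint space concrete, here is a minimal NumPy sketch of a symmetric contrastive objective of the kind used by CLIP-style models. It is an illustration only: the function names, the temperature value of 0.07, and the toy vectors are assumptions, and this tutorial does not state which loss AutoMM uses internally.

```python
import numpy as np

def l2_normalize(x):
    # Scale each row to unit length so dot products become cosine similarities.
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def log_softmax(logits, axis):
    # Numerically stable log-softmax along the given axis.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def symmetric_contrastive_loss(image_vecs, text_vecs, temperature=0.07):
    # Row i of image_vecs and row i of text_vecs are assumed to be a matching pair.
    logits = l2_normalize(image_vecs) @ l2_normalize(text_vecs).T / temperature
    diag = np.arange(len(logits))
    img_to_txt = -log_softmax(logits, axis=1)[diag, diag].mean()
    txt_to_img = -log_softmax(logits, axis=0)[diag, diag].mean()
    return (img_to_txt + txt_to_img) / 2

rng = np.random.default_rng(0)
text_vecs = rng.normal(size=(4, 8))
# Image vectors that nearly coincide with their paired text vectors.
aligned_images = text_vecs + 0.01 * rng.normal(size=(4, 8))
# Image vectors unrelated to the text vectors.
random_images = rng.normal(size=(4, 8))

loss_aligned = symmetric_contrastive_loss(aligned_images, text_vecs)
loss_random = symmetric_contrastive_loss(random_images, text_vecs)
print(loss_aligned, loss_random)
```

The loss is low when each image vector is closest to its own text vector and higher when the pairing carries no signal, which is the sense in which the two modalities are "aligned".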
import os
import warnings
from IPython.display import Image, display
import numpy as np
warnings.filterwarnings('ignore')
np.random.seed(123)
Dataset¶
In this tutorial, we use the Flickr30K dataset to demonstrate image-text matching. Flickr30K is a popular benchmark for sentence-based image description. The dataset comprises 31,783 images that capture people engaged in everyday activities and events. Each image is paired with five descriptive captions. We organized the dataset as pandas dataframes. To get started, let’s download the dataset.
import pandas as pd
from autogluon.core.utils.loaders import load_zip
download_dir = './ag_automm_tutorial_imgtxt'
zip_file = 'https://automl-mm-bench.s3.amazonaws.com/flickr30k.zip'
load_zip.unzip(zip_file, unzip_dir=download_dir)
Downloading ./ag_automm_tutorial_imgtxt/file.zip from https://automl-mm-bench.s3.amazonaws.com/flickr30k.zip...
100%|██████████| 4.38G/4.38G [02:08<00:00, 34.0MiB/s]
Next, we load the CSV files.
dataset_path = os.path.join(download_dir, 'flickr30k_processed')
train_data = pd.read_csv(f'{dataset_path}/train.csv', index_col=0)
val_data = pd.read_csv(f'{dataset_path}/val.csv', index_col=0)
test_data = pd.read_csv(f'{dataset_path}/test.csv', index_col=0)
image_col = "image"
text_col = "caption"
We also need to expand the relative image paths to use their absolute local paths.
def path_expander(path, base_folder):
    path_l = path.split(';')
    return ';'.join([os.path.abspath(os.path.join(base_folder, path)) for path in path_l])
train_data[image_col] = train_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
val_data[image_col] = val_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
test_data[image_col] = test_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
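As a quick check of what the path expansion does, the sketch below applies the same function to a made-up string holding two relative paths separated by `;`. The file names and the `demo_root` folder are hypothetical and exist only for illustration.

```python
import os

def path_expander(path, base_folder):
    # Split a ';'-separated string of relative paths and make each one absolute.
    path_l = path.split(';')
    return ';'.join([os.path.abspath(os.path.join(base_folder, p)) for p in path_l])

# Hypothetical file names, used only for illustration.
expanded = path_expander('images/a.jpg;images/b.jpg', base_folder='demo_root')
parts = expanded.split(';')
print(parts)
```

Each entry comes back as an absolute path rooted at the current working directory, and the `;` separator is preserved, so rows that list several images keep the same layout.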
Taking train_data as an example, let’s see what the data look like in the
dataframe.
train_data.head()
| | caption | image |
|---|---|---|
| 0 | Two young guys with shaggy hair look at their ... | /home/ci/autogluon/docs/_build/eval/tutorials/... |
| 1 | Two young White males are outside near many bu... | /home/ci/autogluon/docs/_build/eval/tutorials/... |
| 2 | Two men in green shirts are standing in a yard | /home/ci/autogluon/docs/_build/eval/tutorials/... |
| 3 | A man in a blue shirt standing in a garden | /home/ci/autogluon/docs/_build/eval/tutorials/... |
| 4 | Two friends enjoy time spent together | /home/ci/autogluon/docs/_build/eval/tutorials/... |
Each row is one image and text pair, implying that they match each other. Since one image corresponds to five captions in the dataset, we copy each image path five times to build the correspondences. We can visualize one image-text pair.
train_data[text_col][0]
'Two young guys with shaggy hair look at their hands while hanging out in the yard'
pil_img = Image(filename=train_data[image_col][0])
display(pil_img)
 
To perform evaluation or semantic search, we need to extract the unique
image and text items from test_data and add a label column to
test_data.
test_image_data = pd.DataFrame({image_col: test_data[image_col].unique().tolist()})
test_text_data = pd.DataFrame({text_col: test_data[text_col].unique().tolist()})
test_data_with_label = test_data.copy()
test_label_col = "relevance"
test_data_with_label[test_label_col] = [1] * len(test_data)
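The same three steps can be tried on a tiny hand-made dataframe to see what each one produces. The toy file names and captions below are invented for illustration and are not part of Flickr30K.

```python
import pandas as pd

# Toy stand-in for test_data: two images, each repeated once per caption.
toy = pd.DataFrame({
    "image": ["a.jpg", "a.jpg", "b.jpg", "b.jpg"],
    "caption": ["a dog runs", "a puppy outside", "a red car", "a parked car"],
})

# Unique items form the candidate pools used for retrieval.
toy_images = pd.DataFrame({"image": toy["image"].unique().tolist()})
toy_texts = pd.DataFrame({"caption": toy["caption"].unique().tolist()})

# Every listed (image, caption) row is a ground-truth match, so it gets label 1.
toy_with_label = toy.copy()
toy_with_label["relevance"] = [1] * len(toy)

print(len(toy_images), len(toy_texts), len(toy_with_label))
```

The image pool is smaller than the pair table because each image appears once per caption, while the labeled table keeps one row per ground-truth pair.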
Initialize Predictor¶
To initialize a predictor for image-text matching, we need to set
problem_type as image_text_similarity. query and
response refer to the two dataframe columns in which two items in
one row should match each other. You can set query=text_col and
response=image_col, or query=image_col and
response=text_col. In image-text matching, query and
response are equivalent.
from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor(
            query=text_col,
            response=image_col,
            problem_type="image_text_similarity",
            eval_metric="recall",
        )
Downloading /home/ci/autogluon/multimodal/src/autogluon/multimodal/data/templates.zip from https://automl-mm-bench.s3.amazonaws.com/few_shot/templates.zip...
By initializing the predictor for image_text_similarity, you have
loaded the pretrained CLIP backbone openai/clip-vit-base-patch32.
Directly Evaluate on Test Dataset (Zero-shot)¶
You may be interested in getting the pretrained model’s performance on your data. Let’s compute the text-to-image and image-to-text retrieval scores.
txt_to_img_scores = predictor.evaluate(
            data=test_data_with_label,
            query_data=test_text_data,
            response_data=test_image_data,
            label=test_label_col,
            cutoffs=[1, 5, 10],
        )
img_to_txt_scores = predictor.evaluate(
            data=test_data_with_label,
            query_data=test_image_data,
            response_data=test_text_data,
            label=test_label_col,
            cutoffs=[1, 5, 10],
        )
print(f"txt_to_img_scores: {txt_to_img_scores}")
print(f"img_to_txt_scores: {img_to_txt_scores}")
txt_to_img_scores: {'recall@1': 0.58964, 'recall@5': 0.83533, 'recall@10': 0.90156}
img_to_txt_scores: {'recall@1': 0.15525, 'recall@5': 0.571, 'recall@10': 0.7176}
Here we report recall, the eval_metric set when
initializing the predictor above. Each cutoff value k means that the
top k retrieved items are used to calculate the score. You may notice that the
text-to-image recalls are much higher than the image-to-text recalls.
This is because each image is paired with five texts. In image-to-text
retrieval, the upper bound of recall@1 is 20%: even when the
top-1 text is correct, it is only one of the five texts to retrieve.
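The 20% ceiling follows directly from the usual definition of recall at a cutoff. The sketch below uses that conventional definition (relevant items found in the top k, divided by all relevant items); the helper name `recall_at_k` is hypothetical, and the library's internal computation may differ in detail.

```python
def recall_at_k(ranked_ids, relevant_ids, k):
    # Fraction of the relevant items that appear among the top-k ranked items.
    top_k = set(ranked_ids[:k])
    return len(top_k & set(relevant_ids)) / len(relevant_ids)

# One image query with five matching captions (ids 0-4) among ten candidates.
relevant = [0, 1, 2, 3, 4]
perfect_ranking = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

best_r1 = recall_at_k(perfect_ranking, relevant, k=1)
best_r5 = recall_at_k(perfect_ranking, relevant, k=5)
print(best_r1, best_r5)  # 0.2 1.0
```

Even with a perfect ranking, a single retrieved caption covers only one of five relevant captions, whereas a text query has exactly one relevant image and can reach a recall@1 of 1.0.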
Finetune Predictor¶
After measuring the pretrained performance, we can finetune the model on our dataset to see whether we can get improvements. For a quick demo, here we set the time limit to 180 seconds.
predictor.fit(
            train_data=train_data,
            tuning_data=val_data,
            time_limit=180,
        )
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7faa36976790>
Evaluate the Finetuned Model on the Test Dataset¶
Now let’s evaluate the finetuned model. As before, we compute the recalls for text-to-image and image-to-text retrieval.
txt_to_img_scores = predictor.evaluate(
            data=test_data_with_label,
            query_data=test_text_data,
            response_data=test_image_data,
            label=test_label_col,
            cutoffs=[1, 5, 10],
        )
img_to_txt_scores = predictor.evaluate(
            data=test_data_with_label,
            query_data=test_image_data,
            response_data=test_text_data,
            label=test_label_col,
            cutoffs=[1, 5, 10],
        )
print(f"txt_to_img_scores: {txt_to_img_scores}")
print(f"img_to_txt_scores: {img_to_txt_scores}")
txt_to_img_scores: {'recall@1': 0.70928, 'recall@5': 0.91897, 'recall@10': 0.95708}
img_to_txt_scores: {'recall@1': 0.17145, 'recall@5': 0.6784, 'recall@10': 0.8206}
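We can observe clear improvements over the zero-shot predictor: for example, text-to-image recall@1 rises from about 0.59 to 0.71, and image-to-text recall@10 from about 0.72 to 0.82. This suggests that finetuning CLIP on our customized data may help achieve better performance.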
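To quantify the change, the reported scores can be compared metric by metric. The sketch below copies the zero-shot and finetuned numbers printed in this tutorial and computes the absolute gain for each cutoff; the dictionary layout is just one convenient way to organize them.

```python
# Scores copied from the zero-shot and finetuned evaluation outputs above.
zero_shot = {
    "txt_to_img": {"recall@1": 0.58964, "recall@5": 0.83533, "recall@10": 0.90156},
    "img_to_txt": {"recall@1": 0.15525, "recall@5": 0.571, "recall@10": 0.7176},
}
finetuned = {
    "txt_to_img": {"recall@1": 0.70928, "recall@5": 0.91897, "recall@10": 0.95708},
    "img_to_txt": {"recall@1": 0.17145, "recall@5": 0.6784, "recall@10": 0.8206},
}

# Absolute gain per retrieval direction and cutoff, rounded to five decimals.
gains = {
    direction: {
        metric: round(finetuned[direction][metric] - zero_shot[direction][metric], 5)
        for metric in zero_shot[direction]
    }
    for direction in zero_shot
}
for direction, metrics in gains.items():
    print(direction, metrics)
```

Every metric improves. The image-to-text recall@1 gain is the smallest, which is consistent with that metric being capped at 20% on this dataset.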
Predict Whether Image and Text Match¶
Whether finetuned or not, the predictor can predict whether image and text pairs match.
pred = predictor.predict(test_data.head(5))
print(pred)
0    1
1    1
2    1
3    1
4    1
dtype: int64
Predict Matching Probabilities¶
The predictor can also return the matching probabilities.
proba = predictor.predict_proba(test_data.head(5))
print(proba)
          0         1
0  0.342470  0.657530
1  0.330035  0.669965
2  0.348019  0.651981
3  0.345222  0.654778
4  0.328657  0.671343
The second column is the probability of being a match.
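One way to relate these probabilities to the 0/1 predictions shown earlier is to pick, for each row, the column with the larger probability. The sketch below does this on the values printed above; treating the prediction as a per-row argmax is an assumption made for illustration, not a statement about how the predictor decides internally.

```python
import pandas as pd

# Probabilities copied from the output above; column 1 is the match probability.
proba = pd.DataFrame({
    0: [0.342470, 0.330035, 0.348019, 0.345222, 0.328657],
    1: [0.657530, 0.669965, 0.651981, 0.654778, 0.671343],
})

# Assumption: a pair is labeled as a match when column 1 has the larger probability.
labels_from_proba = proba.idxmax(axis=1)
print(labels_from_proba.tolist())  # [1, 1, 1, 1, 1]
```

Under that assumption, all five pairs come out as matches, which agrees with the predictions printed in the previous section.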
Extract Embeddings¶
Another common use case is to extract image and text embeddings.
image_embeddings = predictor.extract_embedding({image_col: test_image_data[image_col][:5].tolist()})
print(image_embeddings.shape)
(5, 512)
text_embeddings = predictor.extract_embedding({text_col: test_text_data[text_col][:5].tolist()})
print(text_embeddings.shape)
(5, 512)
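Because both kinds of embeddings have the same width, they can be compared directly. The following sketch computes a pairwise cosine similarity matrix with NumPy. It uses random arrays of shape (5, 512) as stand-ins for the real embeddings so that it runs without the model, and the helper name `cosine_similarity_matrix` is hypothetical.

```python
import numpy as np

def cosine_similarity_matrix(a, b):
    # Normalize rows, then a single matrix product gives all pairwise cosines.
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T

# Random stand-ins with the same (5, 512) shape as the embeddings above.
rng = np.random.default_rng(0)
fake_image_embeddings = rng.normal(size=(5, 512))
fake_text_embeddings = rng.normal(size=(5, 512))

# Entry [i, j] compares text i with image j.
similarity = cosine_similarity_matrix(fake_text_embeddings, fake_image_embeddings)
print(similarity.shape)  # (5, 5)
```

With real embeddings in place of the random arrays, a larger entry would indicate a closer text-image pair.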
Semantic Search¶
We also provide an advanced utility function for semantic search. First, given one or more texts, we can search an image database for semantically similar images.
from autogluon.multimodal.utils import semantic_search
text_to_image_hits = semantic_search(
        matcher=predictor,
        query_data=test_text_data.iloc[[3]],
        response_data=test_image_data,
        top_k=5,
    )
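The result is indexed in this tutorial as a list with one entry per query, each entry being a list of hits that carry a `response_id`. The sketch below mimics that shape with a plain NumPy top-k search over cosine similarities. It is not the library's implementation: the helper `top_k_hits`, the toy vectors, and the `score` key are assumptions for illustration.

```python
import numpy as np

def top_k_hits(query_vecs, response_vecs, top_k):
    # Cosine similarity between every query and every response.
    q = query_vecs / np.linalg.norm(query_vecs, axis=1, keepdims=True)
    r = response_vecs / np.linalg.norm(response_vecs, axis=1, keepdims=True)
    scores = q @ r.T
    hits = []
    for row in scores:
        best = np.argsort(-row)[:top_k]  # indices sorted by descending score
        # "response_id" mirrors the key used in this tutorial; "score" is assumed.
        hits.append([{"response_id": int(i), "score": float(row[i])} for i in best])
    return hits

responses = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
queries = np.array([[1.0, 0.1]])
hits = top_k_hits(queries, responses, top_k=2)
print(hits[0][0]["response_id"])  # 0
```

In this layout, `hits[0]` holds the ranked results for the first query, and `hits[0][0]` is its best match.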
Let’s visualize the text query and top-1 image response.
test_text_data.iloc[[3]]
| | caption |
|---|---|
| 3 | A man in an orange hat starring at something |
pil_img = Image(filename=test_image_data[image_col][text_to_image_hits[0][0]['response_id']])
display(pil_img)
 
Similarly, given one or more images, we can retrieve texts with similar semantic meanings.
image_to_text_hits = semantic_search(
        matcher=predictor,
        query_data=test_image_data.iloc[[6]],
        response_data=test_text_data,
        top_k=5,
    )
Let’s visualize the image query and one of its top-ranked text responses (the hit at index 1 of the result list).
pil_img = Image(filename=test_image_data[image_col][6])
display(pil_img)
 
test_text_data[text_col][image_to_text_hits[0][1]['response_id']]
'Several students waiting outside an igloo'
Other Examples¶
You may go to AutoMM Examples to explore other examples about AutoMM.
Customization¶
To learn how to customize AutoMM, please refer to Customize AutoMM.
