AutoMM for Entity Extraction with Text and Image - Quick Start


We have introduced how to train an entity extraction model with text data. Here, we go a step further by integrating data of other modalities. In many real-world applications, textual data comes with data of other modalities. For example, Twitter allows you to compose tweets with text, photos, videos, and GIFs. Amazon.com uses text, images, and videos to describe its products. These auxiliary modalities can provide additional context for resolving entities. With AutoMM, you can exploit multimodal data to enhance entity extraction without worrying about the details.

import os
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

Get the Twitter Dataset

In the following example, we will demonstrate how to build a multimodal named entity recognition model with a real-world Twitter dataset. This dataset consists of scraped tweets from 2016 to 2017, and each tweet is composed of one sentence and one image. Let’s download the dataset.

download_dir = './ag_automm_tutorial_ner'
zip_file = 'https://automl-mm-bench.s3.amazonaws.com/ner/multimodal_ner.zip'
from autogluon.core.utils.loaders import load_zip
load_zip.unzip(zip_file, unzip_dir=download_dir)
Downloading ./ag_automm_tutorial_ner/file.zip from https://automl-mm-bench.s3.amazonaws.com/ner/multimodal_ner.zip...
100%|██████████| 423M/423M [00:07<00:00, 57.3MiB/s]

Next, we will load the CSV files.

dataset_path = download_dir + '/multimodal_ner'
train_data = pd.read_csv(f'{dataset_path}/twitter17_train.csv')
test_data = pd.read_csv(f'{dataset_path}/twitter17_test.csv')
label_col = 'entity_annotations'

The image column stores paths relative to the dataset folder, so we expand them to absolute paths that can be loaded during training.

image_col = 'image'
# Keep only the first image of each tweet to keep this tutorial quick
train_data[image_col] = train_data[image_col].apply(lambda ele: ele.split(';')[0])
test_data[image_col] = test_data[image_col].apply(lambda ele: ele.split(';')[0])

def path_expander(path, base_folder):
    # Prefix each ';'-separated relative path with base_folder and make it absolute
    path_l = path.split(';')
    p = ';'.join([os.path.abspath(base_folder + single_path) for single_path in path_l])
    return p

train_data[image_col] = train_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
test_data[image_col] = test_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))

train_data[image_col].iloc[0]
'/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/ag_automm_tutorial_ner/multimodal_ner/twitter2017_images/17_06_1818.jpg'
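It can be worth confirming that the expanded paths point to real files before training. The helper below is a hypothetical sketch (`missing_images` is not part of AutoGluon) that uses only the standard library to list every `;`-separated entry that is not an existing file; the demo runs against a temporary directory rather than the real dataset.

```python
import os
import tempfile

def missing_images(image_cells):
    """Return every ';'-separated path in image_cells that is not an existing file."""
    missing = []
    for cell in image_cells:
        for single_path in cell.split(';'):
            if not os.path.isfile(single_path):
                missing.append(single_path)
    return missing

# Demo on a temporary directory instead of the real dataset
with tempfile.TemporaryDirectory() as tmp:
    present = os.path.join(tmp, 'present.jpg')
    open(present, 'wb').close()  # create an empty placeholder file
    absent = os.path.join(tmp, 'absent.jpg')
    result = missing_images([present, present + ';' + absent])

print(result)  # only the absent path is reported
```

On the tutorial data, a call such as `missing_images(train_data[image_col])` should return an empty list if every image was found.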

Each row consists of the text and image of a single tweet, plus the entity_annotations column, which contains the named entity annotations for the text column. Let’s look at an example row and display the tweet’s text and picture.

example_row = train_data.iloc[0]

example_row
text_snippet           Uefa Super Cup : Real Madrid v Manchester United
image                 /home/ci/autogluon/docs/tutorials/multimodal/m...
entity_annotations    [{"entity_group": "B-MISC", "start": 0, "end":...
Name: 0, dtype: object
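The printed row suggests that `entity_annotations` holds a JSON string: a list of objects with `entity_group`, `start`, and `end` keys, where `start` and `end` are character offsets into the text. Assuming that format, the sketch below parses such a string and recovers the annotated words. `annotated_spans` is a hypothetical helper, and the sentence, tags, and offsets are made up for illustration rather than taken from the dataset.

```python
import json

def annotated_spans(text, annotation_json):
    """Return (surface_text, entity_group) pairs for a JSON list of character-offset annotations."""
    return [(text[a["start"]:a["end"]], a["entity_group"]) for a in json.loads(annotation_json)]

# Made-up example in the assumed annotation format (not from the Twitter dataset)
text = "Lionel Messi signs with Inter Miami"
annotation_json = json.dumps([
    {"entity_group": "B-PER", "start": 0, "end": 6},
    {"entity_group": "I-PER", "start": 7, "end": 12},
    {"entity_group": "B-ORG", "start": 24, "end": 29},
    {"entity_group": "I-ORG", "start": 30, "end": 35},
])

for word, group in annotated_spans(text, annotation_json):
    print(f"{word!r} -> {group}")
```

Under the same assumption, `annotated_spans(example_row['text_snippet'], example_row[label_col])` would list the annotated words of the example tweet.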

Below is the image of this tweet.

example_image = example_row[image_col]

from IPython.display import Image, display
pil_img = Image(filename=example_image, width=300)
display(pil_img)

As you can see, this photo contains the logos of the Real Madrid football club, the Manchester United football club, and the UEFA Super Cup. The key information in the tweet’s sentence is thus also encoded here in a different modality.

Training

Now let’s fit the predictor with the training data. First, set problem_type to ner. Because the annotations refer to a text column, the model must know which text column to extract entities from: when multiple text columns are present, mark that column as text_ner using the column_types parameter. Here we set a tight time budget for a quick demo.

from autogluon.multimodal import MultiModalPredictor
import uuid

label_col = "entity_annotations"
model_path = f"./tmp/{uuid.uuid4().hex}-automm_multimodal_ner"
predictor = MultiModalPredictor(problem_type="ner", label=label_col, path=model_path)
predictor.fit(
    train_data=train_data,
    column_types={"text_snippet": "text_ner"},
    time_limit=300,  # seconds
)
=================== System Info ===================
AutoGluon Version:  1.3.2b20250527
Python Version:     3.11.10
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       28.41 GB / 30.95 GB (91.8%)
Disk Space Avail:   180.22 GB / 255.99 GB (70.4%)
===================================================

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/tmp/11bf9b93e6a94605bad31fddb52e51d8-automm_multimodal_ner
    ```
INFO: Seed set to 0

Under the hood, AutoMM automatically detects the data modalities, selects the related models from the multimodal model pool, and trains the selected models. If multiple backbones are available, AutoMM appends a late-fusion model on top of them.
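To make "late fusion" concrete, here is a toy, dependency-free sketch of one common variant for token-level tasks: each modality is encoded separately, the image feature vector is concatenated onto every token's feature vector, and a shared linear layer scores each token against the tag set. This illustrates the general idea only; it is not AutoMM's fusion model, whose architecture may differ, and the random features stand in for real backbone outputs.

```python
import random

def linear(x, weight, bias):
    """Compute W x + b for a single vector x (weight is a list of rows, one per output)."""
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) + b for row, b in zip(weight, bias)]

def late_fusion_ner(token_feats, image_feat, weight, bias):
    """Append the image feature to every token feature, then score each token."""
    return [linear(tok + image_feat, weight, bias) for tok in token_feats]

random.seed(0)
num_tokens, text_dim, image_dim, num_tags = 4, 3, 2, 5

# Stand-ins for the outputs of a text backbone (one vector per token) and an image backbone
token_feats = [[random.random() for _ in range(text_dim)] for _ in range(num_tokens)]
image_feat = [random.random() for _ in range(image_dim)]

# A shared per-token classification head over the fused features
weight = [[random.uniform(-1, 1) for _ in range(text_dim + image_dim)] for _ in range(num_tags)]
bias = [0.0] * num_tags

logits = late_fusion_ner(token_feats, image_feat, weight, bias)
predicted_tags = [max(range(num_tags), key=lambda k: row[k]) for row in logits]
print(len(logits), len(logits[0]), predicted_tags)
```

Concatenating after the backbones keeps each encoder independent, which is what lets a fusion head be added on top of whichever backbones were selected.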

Evaluation

predictor.evaluate(test_data, metrics=["overall_recall", "overall_precision", "overall_f1"])
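The `overall_*` metrics are entity-level scores. Assuming they follow the usual seqeval-style convention, where a predicted entity counts as correct only if its span and type both exactly match a gold entity, they can be computed as in the sketch below. `entity_prf` is a hypothetical helper, and entities are represented as made-up `(sentence_id, start, end, type)` tuples.

```python
def entity_prf(gold, pred):
    """Micro-averaged precision, recall, and F1 over exactly matching entities."""
    gold_set, pred_set = set(gold), set(pred)
    true_positives = len(gold_set & pred_set)
    precision = true_positives / len(pred_set) if pred_set else 0.0
    recall = true_positives / len(gold_set) if gold_set else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"overall_precision": precision, "overall_recall": recall, "overall_f1": f1}

# Hypothetical entities as (sentence_id, start, end, type)
gold = [(0, 0, 6, "PER"), (0, 24, 35, "ORG"), (1, 3, 10, "LOC")]
pred = [(0, 0, 6, "PER"), (0, 24, 35, "ORG"), (1, 3, 8, "LOC"), (1, 15, 20, "MISC")]

scores = entity_prf(gold, pred)
print(scores)
```

Here two of the four predicted entities match the three gold entities exactly; the `LOC` prediction with the wrong end offset earns no credit, giving a precision of 0.5, a recall of about 0.667, and an F1 of about 0.571.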

Prediction

You can easily obtain the predictions by calling predictor.predict().

prediction_input = test_data.drop(columns=label_col).head(1)
predictions = predictor.predict(prediction_input)
print('Tweet:', prediction_input.text_snippet[0])
print('Image path:', prediction_input.image[0])
print('Predicted entities:', predictions[0])

for entity in predictions[0]:
	print(f"Word '{prediction_input.text_snippet[0][entity['start']:entity['end']]}' belongs to group: {entity['entity_group']}")
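The loop above prints one line per predicted entity. If the predicted `entity_group` values carry BIO prefixes like the training annotations (for example `B-MISC`), a multi-word name may be split across several entries. The hypothetical helper below merges each `B-X` entry with the `I-X` entries that directly follow it (separated only by whitespace) into a single entity; tags without a prefix are treated as the start of an entity. The input is made-up data in the same dictionary format.

```python
def merge_bio(entities, text):
    """Merge each B-X entry with the I-X entries that follow it into (surface_text, X) pairs."""
    merged = []
    for ent in sorted(entities, key=lambda e: e["start"]):
        prefix, _, label = ent["entity_group"].partition("-")
        if not label:  # no BIO prefix, e.g. "PER": treat as the start of an entity
            prefix, label = "B", ent["entity_group"]
        continues_previous = (
            prefix == "I"
            and merged
            and merged[-1]["label"] == label
            and text[merged[-1]["end"]:ent["start"]].strip() == ""  # only whitespace in between
        )
        if continues_previous:
            merged[-1]["end"] = ent["end"]
        else:
            merged.append({"label": label, "start": ent["start"], "end": ent["end"]})
    return [(text[m["start"]:m["end"]], m["label"]) for m in merged]

# Made-up predictions in the same dictionary format as above
text = "Lionel Messi signs with Inter Miami"
predicted = [
    {"entity_group": "B-PER", "start": 0, "end": 6},
    {"entity_group": "I-PER", "start": 7, "end": 12},
    {"entity_group": "B-ORG", "start": 24, "end": 29},
    {"entity_group": "I-ORG", "start": 30, "end": 35},
]
entities = merge_bio(predicted, text)
print(entities)
```

Applied to the tutorial output, the call would be `merge_bio(predictions[0], prediction_input.text_snippet[0])`. The merge is lenient: an `I-X` entry that does not continue an `X` entity starts a new entity instead of being dropped, which is one of several reasonable conventions.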

Reloading and Continuous Training

The trained predictor is saved automatically, and you can reload it from its path. If you are not satisfied with the current model performance, you can continue training the loaded model with new data.

new_predictor = MultiModalPredictor.load(model_path)
new_model_path = f"./tmp/{uuid.uuid4().hex}-automm_multimodal_ner_continue_train"
new_predictor.fit(train_data, time_limit=60, save_path=new_model_path)
test_score = new_predictor.evaluate(test_data, metrics=['overall_f1'])
print(test_score)

Other Examples

You may go to AutoMM Examples to explore other examples about AutoMM.

Customization

To learn how to customize AutoMM, please refer to Customize AutoMM.