AutoMM for Image Classification - Quick Start
In this quick start, we’ll use the task of image classification to illustrate how to use MultiModalPredictor. Once the data is prepared in Pandas DataFrame format, a single call to MultiModalPredictor.fit() will take care of the model training for you.
Create Image Dataset
For demonstration purposes, we use a subset of the Shopee-IET dataset from Kaggle. Each image in this data depicts a clothing item, and the corresponding label specifies its clothing category. Our subset of the data contains the following possible labels: BabyPants, BabyShirt, womencasualshoes, womenchiffontop.
We can download the dataset from a URL and load it automatically:
import warnings
warnings.filterwarnings('ignore')
import pandas as pd
from autogluon.multimodal.utils.misc import shopee_dataset
download_dir = './ag_automm_tutorial_imgcls'
train_data, test_data = shopee_dataset(download_dir)
print(train_data)
Downloading ./ag_automm_tutorial_imgcls/file.zip from https://automl-mm-bench.s3.amazonaws.com/vision_datasets/shopee.zip...
100%|██████████| 41.9M/41.9M [00:00<00:00, 43.9MiB/s]
image label
0 /home/ci/autogluon/docs/_build/eval/tutorials/... 0
1 /home/ci/autogluon/docs/_build/eval/tutorials/... 0
2 /home/ci/autogluon/docs/_build/eval/tutorials/... 0
3 /home/ci/autogluon/docs/_build/eval/tutorials/... 0
4 /home/ci/autogluon/docs/_build/eval/tutorials/... 0
.. ... ...
795 /home/ci/autogluon/docs/_build/eval/tutorials/... 3
796 /home/ci/autogluon/docs/_build/eval/tutorials/... 3
797 /home/ci/autogluon/docs/_build/eval/tutorials/... 3
798 /home/ci/autogluon/docs/_build/eval/tutorials/... 3
799 /home/ci/autogluon/docs/_build/eval/tutorials/... 3
[800 rows x 2 columns]
The training dataframe has 800 rows and 2 columns, image and label. Each row is one training sample: the path to an image file and its integer class label.
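If your own images are stored with one subfolder per class, a DataFrame in the same two-column format (an image column of file paths and a label column) can be assembled with a few lines of pandas. The sketch below assumes that folder layout and .jpg files; the helper build_image_dataframe is hypothetical and not part of AutoGluon, and assigning integer ids to sorted class names is just one possible choice.

```python
import tempfile
from pathlib import Path

import pandas as pd


def build_image_dataframe(root: str) -> pd.DataFrame:
    """Hypothetical helper: turn root/<class_name>/*.jpg into (image, label) rows.

    Class names are sorted and mapped to integer ids 0, 1, 2, ...
    """
    root_path = Path(root)
    class_names = sorted(p.name for p in root_path.iterdir() if p.is_dir())
    class_to_id = {name: idx for idx, name in enumerate(class_names)}
    records = []
    for name in class_names:
        for img_file in sorted((root_path / name).glob("*.jpg")):
            records.append({"image": str(img_file), "label": class_to_id[name]})
    return pd.DataFrame(records, columns=["image", "label"])


# Usage with a throwaway directory of empty placeholder files (not real images).
with tempfile.TemporaryDirectory() as tmp:
    for cls in ["BabyPants", "BabyShirt"]:
        (Path(tmp) / cls).mkdir()
        for i in range(2):
            (Path(tmp) / cls / f"{i}.jpg").touch()
    df = build_image_dataframe(tmp)
    labels = df["label"].tolist()
    print(df.shape)  # (4, 2)
```

For actual training, the paths must point to real image files; the placeholders here only demonstrate the table layout.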
Use AutoMM to Fit Models
Now, we fit a classifier using AutoMM as follows:
from autogluon.multimodal import MultiModalPredictor
import uuid
model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee"
predictor = MultiModalPredictor(label="label", path=model_path)
predictor.fit(
train_data=train_data,
time_limit=30, # seconds
) # you can trust the default config, e.g., we use a `swin_base_patch4_window7_224` model
Global seed set to 123
Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth" to /home/ci/.cache/torch/hub/checkpoints/swin_base_patch4_window7_224_22kto1k.pth
Auto select gpus: [0]
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
----------------------------------------------------------------------
0 | model | TimmAutoModelForImagePrediction | 86.7 M
1 | validation_metric | Accuracy | 0
2 | loss_func | CrossEntropyLoss | 0
----------------------------------------------------------------------
86.7 M Trainable params
0 Non-trainable params
86.7 M Total params
173.495 Total estimated model params size (MB)
Epoch 0, global step 2: 'val_accuracy' reached 0.32500 (best 0.32500), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/image_prediction/tmp/af1b830977d6462fb4b61b051ae461f4-automm_shopee/epoch=0-step=2.ckpt' as top 3
Epoch 0, global step 5: 'val_accuracy' reached 0.86250 (best 0.86250), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/image_prediction/tmp/af1b830977d6462fb4b61b051ae461f4-automm_shopee/epoch=0-step=5.ckpt' as top 3
Time limit reached. Elapsed time is 0:00:30. Signaling Trainer to stop.
Epoch 1, global step 6: 'val_accuracy' reached 0.90000 (best 0.90000), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/image_prediction/tmp/af1b830977d6462fb4b61b051ae461f4-automm_shopee/epoch=1-step=6.ckpt' as top 3
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7f39ccee3160>
label is the name of the column that contains the target variable to predict; it is “label” in our example. path is the directory where models and intermediate outputs are saved. We set the training time limit (time_limit) to 30 seconds for demonstration purposes; you can control the training time through this and other configuration options. To customize AutoMM, please refer to Customize AutoMM.
Evaluate on Test Dataset
You can evaluate the classifier on the test dataset to see how it performs. Its top-1 test accuracy is:
scores = predictor.evaluate(test_data, metrics=["accuracy"])
print('Top-1 test acc: %.3f' % scores["accuracy"])
Top-1 test acc: 0.963
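For single-label classification, top-1 accuracy is the fraction of samples whose predicted class equals the true class. The NumPy sketch below computes it by hand; the two label arrays are made-up toy values, not outputs of the model trained above.

```python
import numpy as np

# Toy predicted and true class ids for 8 test images (made-up values).
predicted = np.array([0, 0, 1, 2, 3, 3, 1, 2])
actual = np.array([0, 0, 1, 2, 3, 2, 1, 2])

# Elementwise comparison gives a boolean array; its mean is the accuracy.
accuracy = float((predicted == actual).mean())
print('Top-1 test acc: %.3f' % accuracy)  # Top-1 test acc: 0.875
```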
Predict on a New Image
Let’s first visualize an example image from the test set:
from IPython.display import Image, display
image_path = test_data.iloc[0]['image']  # path of the first test image
img = Image(filename=image_path)  # an IPython display object, not a PIL image
display(img)

We can then use the trained model to predict its label:
predictions = predictor.predict({'image': [image_path]})
print(predictions)
[0]
If the probabilities of all categories are needed, you can call predict_proba:
proba = predictor.predict_proba({'image': [image_path]})
print(proba)
[[0.5130405 0.27893528 0.07837059 0.12965365]]
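The predicted label and the probabilities are consistent: the label is the class with the highest probability. The sketch below reuses the probabilities printed above and assumes, as the printed output suggests, that column i holds the probability of class id i (0 to 3).

```python
import numpy as np

# Probabilities printed by predict_proba above (one row per image).
proba = np.array([[0.5130405, 0.27893528, 0.07837059, 0.12965365]])

# Each row is a distribution over the 4 classes, so it sums to about 1.
row_sums = proba.sum(axis=1)
# The most likely class is the column index of the largest probability.
predicted = proba.argmax(axis=1)
print(predicted)  # [0]
```

This matches the [0] returned by predict for the same image.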
Extract Embeddings
Extracting a representation of the whole image learned by the model is also very useful. We provide the extract_embedding function, which lets the predictor return an N-dimensional image feature, where N depends on the model (usually a vector of length 512 to 2048):
feature = predictor.extract_embedding({'image': [image_path]})
print(feature[0].shape)
(1024,)
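Embeddings like this are often compared with cosine similarity, for example to find visually similar items. The sketch below uses random stand-in vectors with the same length (1024) as the output above; in practice each vector would come from extract_embedding. Cosine similarity is one common choice, not the only one, and the function name cosine_similarity is illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-ins for embeddings of one query image and 5 gallery images.
query = rng.normal(size=1024)
gallery = rng.normal(size=(5, 1024))
gallery[3] = query + 0.1 * rng.normal(size=1024)  # make item 3 a near-duplicate


def cosine_similarity(q: np.ndarray, items: np.ndarray) -> np.ndarray:
    """Cosine similarity between vector q and each row of items."""
    q_unit = q / np.linalg.norm(q)
    items_unit = items / np.linalg.norm(items, axis=1, keepdims=True)
    return items_unit @ q_unit


scores = cosine_similarity(query, gallery)
best = int(scores.argmax())
print(best)  # 3
```

Normalizing the vectors first turns the dot product into a cosine, so the score does not depend on the vectors' magnitudes.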
Save and Load
The trained predictor is automatically saved at the end of fit(), and you can easily reload it.
Warning
MultiModalPredictor.load() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Never load data that could have come from an untrusted source, or that could have been tampered with. Only load data you trust.
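To see why this warning matters: pickle lets an object specify, through its __reduce__ method, a callable that is invoked when the data is unpickled. The harmless sketch below makes pickle.loads call the built-in eval on "6 * 7"; a malicious file could name any other callable instead. The class name Payload is illustrative only.

```python
import pickle


class Payload:
    def __reduce__(self):
        # Tell pickle to "rebuild" this object by calling eval("6 * 7").
        return (eval, ("6 * 7",))


data = pickle.dumps(Payload())
# Unpickling runs eval(...) and returns its result, not a Payload instance.
result = pickle.loads(data)
print(result)  # 42
```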
loaded_predictor = MultiModalPredictor.load(model_path)
load_proba = loaded_predictor.predict_proba({'image': [image_path]})
print(load_proba)
[[0.5130405 0.27893528 0.07837059 0.12965365]]
The predicted class probabilities are the same as before, which confirms that the reloaded predictor is the same model.
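Instead of comparing the printed numbers by eye, the check can be done programmatically with NumPy. The sketch below uses the two outputs printed above as stand-ins for proba and load_proba; with the real predictors you would pass those variables directly. The tolerances are an assumed choice that allows for tiny floating-point differences.

```python
import numpy as np

# Values printed above before and after reloading the predictor.
proba = np.array([[0.5130405, 0.27893528, 0.07837059, 0.12965365]])
load_proba = np.array([[0.5130405, 0.27893528, 0.07837059, 0.12965365]])

# allclose tolerates tiny numerical differences instead of requiring exact equality.
same = np.allclose(proba, load_proba, rtol=1e-5, atol=1e-6)
print(same)  # True
```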
Other Examples
See AutoMM Examples to explore more examples of AutoMM.
Customization
To learn how to customize AutoMM, please refer to Customize AutoMM.