AutoMMPredictor for Image, Text, and Tabular

Are you tired of switching codebases or hacking code for different data modalities (image, text, numerical, and categorical data) and tasks (classification, regression, and more)? AutoMMPredictor provides a one-stop shop for multimodal/unimodal deep learning models. This tutorial demonstrates several application scenarios.

  • Multimodal Prediction

    • CLIP

    • TIMM + Huggingface Transformers + More

  • Image Prediction

  • Text Prediction

  • Configuration Customization

  • APIs

import os
import numpy as np
import warnings
warnings.filterwarnings('ignore')
np.random.seed(123)

Dataset

For demonstration, we use the PetFinder dataset. It describes shelter animals using information from their adoption profiles. The task is to predict each animal's adoption speed (the AdoptionSpeed column), which is grouped into five categories, so this is a multi-class classification problem.

To get started, let’s download and prepare the dataset.

download_dir = './ag_automm_tutorial'
zip_file = 'https://automl-mm-bench.s3.amazonaws.com/petfinder_kaggle.zip'
from autogluon.core.utils.loaders import load_zip
load_zip.unzip(zip_file, unzip_dir=download_dir)
Downloading ./ag_automm_tutorial/file.zip from https://automl-mm-bench.s3.amazonaws.com/petfinder_kaggle.zip...
100%|██████████| 2.00G/2.00G [01:34<00:00, 21.2MiB/s]

Next, we will load the CSV files.

import pandas as pd
dataset_path = download_dir + '/petfinder_processed'
train_data = pd.read_csv(f'{dataset_path}/train.csv', index_col=0)
test_data = pd.read_csv(f'{dataset_path}/dev.csv', index_col=0)
label_col = 'AdoptionSpeed'

The image paths in the CSV files are relative to the dataset folder, so we expand them to absolute paths that can be loaded during training.

image_col = 'Images'
train_data[image_col] = train_data[image_col].apply(lambda ele: ele.split(';')[0]) # Use the first image for a quick tutorial
test_data[image_col] = test_data[image_col].apply(lambda ele: ele.split(';')[0])


def path_expander(path, base_folder):
    # Expand each ';'-separated relative path into an absolute path under base_folder
    path_l = path.split(';')
    return ';'.join([os.path.abspath(os.path.join(base_folder, p)) for p in path_l])

train_data[image_col] = train_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
test_data[image_col] = test_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
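The Images column stores one or more image paths per animal as a single semicolon-separated string. The standalone snippet below repeats the two steps above (keep the first image, then expand it) on a made-up entry so you can see what each step produces. The file names and the `/data/petfinder` base folder are hypothetical, and the expected values assume POSIX-style paths.

```python
import os


def path_expander(path, base_folder):
    # Expand each ';'-separated relative path into an absolute path.
    return ';'.join(
        os.path.abspath(os.path.join(base_folder, p)) for p in path.split(';')
    )


# Hypothetical multi-image entry and base folder, for illustration only.
raw_entry = 'train_images/abc-1.jpg;train_images/abc-2.jpg'
base_folder = '/data/petfinder'

first_only = raw_entry.split(';')[0]                 # keep only the first image
expanded = path_expander(first_only, base_folder)    # one absolute path
expanded_all = path_expander(raw_entry, base_folder) # every path made absolute

print(first_only)
print(expanded)
print(expanded_all)
```

Because `path_expander` splits and re-joins on `;`, it also works on the untruncated multi-image strings if you skip the first-image step.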

train_data[image_col].iloc[0]
'/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/ag_automm_tutorial/petfinder_processed/train_images/e4b90955c-1.jpg'

Each animal’s adoption profile includes pictures, a text description, and various tabular features such as age, breed, name, color, and more. Let’s look at an example row of data and display the text description and a picture.

example_row = train_data.iloc[47]

example_row
Type                                                             2
Name                                                         Money
Age                                                              4
Breed1                                                         266
Breed2                                                           0
Gender                                                           2
Color1                                                           1
Color2                                                           2
Color3                                                           7
MaturitySize                                                     1
FurLength                                                        2
Vaccinated                                                       2
Dewormed                                                         1
Sterilized                                                       2
Health                                                           1
Quantity                                                         1
Fee                                                              0
State                                                        41401
RescuerID                         ee7445af32acfa1dc8307a9dc7baed21
VideoAmt                                                         0
Description      My pet is a pretty beautiful kitty which has a...
PetID                                                    98c08df17
PhotoAmt                                                       2.0
AdoptionSpeed                                                    2
Images           /var/lib/jenkins/workspace/workspace/autogluon...
Name: 14845, dtype: object
example_row['Description']
'My pet is a pretty beautiful kitty which has a mixed colour soft fur. She is active and full of life. And one thing about her, she loves to eat.She always turn on me like a tiger when I was preparing the food for her.'
example_image = example_row['Images']

from IPython.display import Image, display
pil_img = Image(filename=example_image)
display(pil_img)
[Image: the example animal's profile photo]

For demonstration purposes, we sample 500 rows for training and 100 rows for testing.

train_data = train_data.sample(500, random_state=0)
test_data = test_data.sample(100, random_state=0)

Multimodal Prediction

CLIP

AutoMMPredictor supports finetuning pre-trained vision-language models, such as CLIP.

from autogluon.text.automm import AutoMMPredictor
predictor = AutoMMPredictor(label=label_col)
predictor.fit(
    train_data=train_data,
    hyperparameters={
        "model.names": ["clip"],
        "env.num_gpus": 1,
    },
    time_limit=120, # seconds
)
Global seed set to 123
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type             | Params
-------------------------------------------------------
0 | model             | CLIPForImageText | 151 M
1 | validation_metric | Accuracy         | 0
2 | loss_func         | CrossEntropyLoss | 0
-------------------------------------------------------
151 M     Trainable params
0         Non-trainable params
151 M     Total params
302.560   Total estimated model params size (MB)
Global seed set to 123
Epoch 0, global step 1: val_accuracy reached 0.27000 (best 0.27000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=0-step=1.ckpt" as top 3
Epoch 0, global step 3: val_accuracy reached 0.30000 (best 0.30000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=0-step=3.ckpt" as top 3
Epoch 1, global step 5: val_accuracy reached 0.29000 (best 0.30000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=1-step=5.ckpt" as top 3
Epoch 1, global step 7: val_accuracy was not in top 3
Epoch 2, global step 9: val_accuracy reached 0.29000 (best 0.30000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=2-step=9.ckpt" as top 3
Epoch 2, global step 11: val_accuracy reached 0.33000 (best 0.33000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=2-step=11.ckpt" as top 3
Epoch 3, global step 13: val_accuracy reached 0.33000 (best 0.33000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=3-step=13.ckpt" as top 3
Epoch 3, global step 15: val_accuracy was not in top 3
Epoch 4, global step 17: val_accuracy reached 0.33000 (best 0.33000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=4-step=17.ckpt" as top 3
Time limit reached. Elapsed time is 0:02:00. Signaling Trainer to stop.
Epoch 4, global step 18: val_accuracy was not in top 3
<autogluon.text.automm.predictor.AutoMMPredictor at 0x7fbebc0fcd60>
scores = predictor.evaluate(test_data, metrics=["accuracy"])
scores
{'accuracy': 0.33}

In this example, AutoMMPredictor finetunes CLIP with the image, text, and categorical (converted to text) data.
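The accuracy reported by `.evaluate()` is the fraction of test rows whose predicted class matches the true class. With five classes, it can help to compare it against a trivial reference, such as always predicting the most frequent training class. The sketch below computes both on hypothetical labels; it is our own illustration of the metric, not AutoGluon's implementation of `.evaluate()`.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    # Fraction of rows whose predicted label equals the true label.
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def majority_baseline(y_train, y_test):
    # Accuracy of always predicting the most frequent training label.
    majority_label, _ = Counter(y_train).most_common(1)[0]
    return accuracy(y_test, [majority_label] * len(y_test))


# Hypothetical AdoptionSpeed labels (classes 0-4), not taken from PetFinder.
y_train = [4, 2, 4, 1, 3, 4, 2, 0, 4, 1]
y_test = [4, 2, 1, 4, 3]
y_pred = [4, 2, 2, 4, 1]

print(accuracy(y_test, y_pred))            # 0.6
print(majority_baseline(y_train, y_test))  # 0.4
```

On the real data you would compute the baseline from `train_data[label_col]` and `test_data[label_col]`.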

TIMM + Huggingface Transformers + More

In addition to CLIP, AutoMMPredictor can simultaneously finetune various timm backbones and Hugging Face transformers. Moreover, by default AutoMMPredictor uses an MLP for numerical data but converts categorical data to text.

Let’s use AutoMMPredictor to train a late fusion model including CLIP, swin_small_patch4_window7_224, google/electra-small-discriminator, a numerical MLP, and a fusion MLP.

from autogluon.text.automm import AutoMMPredictor
predictor = AutoMMPredictor(label=label_col)
predictor.fit(
    train_data=train_data,
    hyperparameters={
        "model.names": ["clip", "timm_image", "hf_text", "numerical_mlp", "fusion_mlp"],
        "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
        "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
        "env.num_gpus": 1,
    },
    time_limit=120, # seconds
)
Global seed set to 123
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                | Params
----------------------------------------------------------
0 | model             | MultimodalFusionMLP | 215 M
1 | validation_metric | Accuracy            | 0
2 | loss_func         | CrossEntropyLoss    | 0
----------------------------------------------------------
215 M     Trainable params
0         Non-trainable params
215 M     Total params
430.576   Total estimated model params size (MB)
Global seed set to 123
Epoch 0, global step 1: val_accuracy reached 0.16000 (best 0.16000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053904/epoch=0-step=1.ckpt" as top 3
Epoch 0, global step 3: val_accuracy reached 0.20000 (best 0.20000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053904/epoch=0-step=3.ckpt" as top 3
Epoch 1, global step 5: val_accuracy reached 0.32000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053904/epoch=1-step=5.ckpt" as top 3
Epoch 1, global step 7: val_accuracy was not in top 3
Epoch 2, global step 9: val_accuracy reached 0.17000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053904/epoch=2-step=9.ckpt" as top 3
Time limit reached. Elapsed time is 0:02:03. Signaling Trainer to stop.
<autogluon.text.automm.predictor.AutoMMPredictor at 0x7fbe51c77190>
scores = predictor.evaluate(test_data, metrics=["accuracy"])
scores
{'accuracy': 0.43}
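In a late fusion model, each backbone processes its inputs separately and the outputs are only combined at the end. The framework-free sketch below illustrates that idea under our own simplifying assumptions: each backbone is represented by a made-up feature vector, and the fusion step concatenates them, applies one hidden layer with leaky ReLU (mirroring the hidden_sizes and activation keys of the fusion_mlp config shown later), and projects to one logit per class. The weights are random and untrained, and details such as dropout and layer normalization (also fusion_mlp settings) are omitted. This is not AutoGluon's MultimodalFusionMLP.

```python
import random


def leaky_relu(x, slope=0.01):
    # Leaky ReLU: keep positives, scale negatives by a small slope.
    return x if x > 0 else slope * x


def linear(x, weights, bias):
    # Dense layer: one weight row per output unit.
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def random_layer(in_dim, out_dim, rng):
    # Untrained weights, just to make the shapes concrete.
    weights = [[rng.uniform(-0.1, 0.1) for _ in range(in_dim)] for _ in range(out_dim)]
    return weights, [0.0] * out_dim


def late_fusion(per_model_features, hidden_layer, output_layer):
    fused = [v for feats in per_model_features for v in feats]  # concatenate
    hidden = [leaky_relu(h) for h in linear(fused, *hidden_layer)]
    return linear(hidden, *output_layer)  # one logit per class


rng = random.Random(0)
# Hypothetical per-backbone outputs (real ones are much wider).
features = {
    "clip": [0.2, -0.1, 0.4],
    "timm_image": [0.5, 0.0],
    "hf_text": [-0.3, 0.7],
    "numerical_mlp": [1.2],
}
in_dim = sum(len(f) for f in features.values())  # 8
hidden_layer = random_layer(in_dim, 4, rng)
output_layer = random_layer(4, 5, rng)  # five AdoptionSpeed classes
logits = late_fusion(features.values(), hidden_layer, output_layer)
print(len(logits))  # 5
```

The key design point is that adding or removing a backbone only changes the width of the concatenated vector; the individual backbones stay independent until the fusion layer.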

Image Prediction

If you want to use only image data, or your task only has image data, AutoMMPredictor can finetune a wide range of timm backbones, such as swin_tiny_patch4_window7_224.

from autogluon.text.automm import AutoMMPredictor
predictor = AutoMMPredictor(label=label_col)
predictor.fit(
    train_data=train_data,
    hyperparameters={
        "model.names": ["timm_image"],
        "model.timm_image.checkpoint_name": "swin_tiny_patch4_window7_224",
        "env.num_gpus": 1,
    },
    time_limit=60, # seconds
)
Global seed set to 123
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                            | Params
----------------------------------------------------------------------
0 | model             | TimmAutoModelForImagePrediction | 27.5 M
1 | validation_metric | Accuracy                        | 0
2 | loss_func         | CrossEntropyLoss                | 0
----------------------------------------------------------------------
27.5 M    Trainable params
0         Non-trainable params
27.5 M    Total params
55.046    Total estimated model params size (MB)
Global seed set to 123
Epoch 0, global step 1: val_accuracy reached 0.23000 (best 0.23000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=0-step=1.ckpt" as top 3
Epoch 0, global step 3: val_accuracy reached 0.31000 (best 0.31000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=0-step=3.ckpt" as top 3
Epoch 1, global step 5: val_accuracy reached 0.32000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=1-step=5.ckpt" as top 3
Epoch 1, global step 7: val_accuracy reached 0.35000 (best 0.35000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=1-step=7.ckpt" as top 3
Epoch 2, global step 9: val_accuracy was not in top 3
Epoch 2, global step 11: val_accuracy was not in top 3
Epoch 3, global step 13: val_accuracy was not in top 3
Epoch 3, global step 15: val_accuracy reached 0.34000 (best 0.35000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=3-step=15.ckpt" as top 3
Epoch 4, global step 17: val_accuracy reached 0.36000 (best 0.36000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=4-step=17.ckpt" as top 3
Epoch 4, global step 19: val_accuracy reached 0.36000 (best 0.36000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=4-step=19.ckpt" as top 3
Epoch 5, global step 21: val_accuracy was not in top 3
Epoch 5, global step 23: val_accuracy was not in top 3
Epoch 6, global step 25: val_accuracy was not in top 3
Epoch 6, global step 27: val_accuracy was not in top 3
Time limit reached. Elapsed time is 0:01:01. Signaling Trainer to stop.
Epoch 7, global step 28: val_accuracy was not in top 3
<autogluon.text.automm.predictor.AutoMMPredictor at 0x7fbef3b4b8e0>

Here AutoMMPredictor uses only image data because model.names includes only timm_image.

scores = predictor.evaluate(test_data, metrics=["accuracy"])
scores
{'accuracy': 0.34}

Text Prediction

Similarly, you may want to finetune only a text backbone from Hugging Face transformers, such as google/electra-small-discriminator.

from autogluon.text.automm import AutoMMPredictor
predictor = AutoMMPredictor(label=label_col)
predictor.fit(
    train_data=train_data,
    hyperparameters={
        "model.names": ["hf_text"],
        "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
        "env.num_gpus": 1,
    },
    time_limit=60, # seconds
)
Global seed set to 123
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                         | Params
-------------------------------------------------------------------
0 | model             | HFAutoModelForTextPrediction | 13.5 M
1 | validation_metric | Accuracy                     | 0
2 | loss_func         | CrossEntropyLoss             | 0
-------------------------------------------------------------------
13.5 M    Trainable params
0         Non-trainable params
13.5 M    Total params
26.969    Total estimated model params size (MB)
Global seed set to 123
Epoch 0, global step 1: val_accuracy reached 0.32000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=0-step=1.ckpt" as top 3
Epoch 0, global step 3: val_accuracy reached 0.22000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=0-step=3.ckpt" as top 3
Epoch 1, global step 5: val_accuracy reached 0.23000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=1-step=5.ckpt" as top 3
Epoch 1, global step 7: val_accuracy was not in top 3
Epoch 2, global step 9: val_accuracy was not in top 3
Epoch 2, global step 11: val_accuracy reached 0.28000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=2-step=11.ckpt" as top 3
Epoch 3, global step 13: val_accuracy reached 0.31000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=3-step=13.ckpt" as top 3
Epoch 3, global step 15: val_accuracy reached 0.31000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=3-step=15.ckpt" as top 3
Epoch 4, global step 17: val_accuracy was not in top 3
Epoch 4, global step 19: val_accuracy reached 0.32000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=4-step=19.ckpt" as top 3
Epoch 5, global step 21: val_accuracy reached 0.33000 (best 0.33000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=5-step=21.ckpt" as top 3
Epoch 5, global step 23: val_accuracy was not in top 3
Epoch 6, global step 25: val_accuracy was not in top 3
Epoch 6, global step 27: val_accuracy was not in top 3
Epoch 7, global step 29: val_accuracy was not in top 3
Epoch 7, global step 31: val_accuracy was not in top 3
Epoch 8, global step 33: val_accuracy was not in top 3
Epoch 8, global step 35: val_accuracy was not in top 3
Epoch 9, global step 37: val_accuracy was not in top 3
Epoch 9, global step 39: val_accuracy was not in top 3
Saving latest checkpoint...
<autogluon.text.automm.predictor.AutoMMPredictor at 0x7fbe52336e50>

With only hf_text in model.names, AutoMMPredictor automatically uses only text and categorical (converted to text) data.

scores = predictor.evaluate(test_data, metrics=["accuracy"])
scores
{'accuracy': 0.19}

Configuration Customization

The examples above show the flexibility of AutoMMPredictor. To customize configurations for your own tasks, it helps to know how these configurations are organized.

First, let’s see the available model presets.

from autogluon.text.automm.presets import list_model_presets, get_preset
model_presets = list_model_presets()
model_presets
['fusion_mlp_image_text_tabular']

Currently, AutoMMPredictor has only one model preset, from which we can construct the predictor’s preset.

preset = get_preset(model_presets[0])
preset
{'model': 'fusion_mlp_image_text_tabular',
 'data': 'default',
 'optimization': 'adamw',
 'environment': 'default'}

AutoMMPredictor configurations consist of four parts: model, data, optimization, and environment. You can convert the preset to configurations to see the details.

from omegaconf import OmegaConf
from autogluon.text.automm.utils import get_config
config = get_config(preset)
print(OmegaConf.to_yaml(config))
model:
  names:
  - categorical_mlp
  - numerical_mlp
  - hf_text
  - timm_image
  - clip
  - fusion_mlp
  categorical_mlp:
    hidden_size: 64
    activation: leaky_relu
    num_layers: 1
    drop_rate: 0.1
    normalization: layer_norm
    data_types:
    - categorical
  numerical_mlp:
    hidden_size: 128
    activation: leaky_relu
    num_layers: 1
    drop_rate: 0.1
    normalization: layer_norm
    data_types:
    - numerical
    merge: concat
  hf_text:
    checkpoint_name: google/electra-base-discriminator
    data_types:
    - text
    tokenizer_name: hf_auto
    max_text_len: 512
    insert_sep: true
    text_segment_num: 2
    stochastic_chunk: false
  timm_image:
    checkpoint_name: swin_base_patch4_window7_224
    mix_choice: all_logits
    data_types:
    - image
    train_transform_types:
    - resize_shorter_side
    - center_crop
    val_transform_types:
    - resize_shorter_side
    - center_crop
    image_norm: imagenet
    image_size: 224
    max_img_num_per_col: 3
  clip:
    checkpoint_name: openai/clip-vit-base-patch32
    data_types:
    - image
    - text
    train_transform_types:
    - resize_shorter_side
    - center_crop
    val_transform_types:
    - resize_shorter_side
    - center_crop
    image_norm: clip
    image_size: 224
    max_img_num_per_col: 0
    tokenizer_name: clip
    max_text_len: 77
    insert_sep: false
    text_segment_num: 1
    stochastic_chunk: false
  fusion_mlp:
    weight: 0.1
    adapt_in_features: max
    hidden_sizes:
    - 128
    activation: leaky_relu
    drop_rate: 0.1
    normalization: layer_norm
    data_types: null
data:
  image: null
  text: null
  categorical:
    minimum_cat_count: 100
    maximum_num_cat: 20
    convert_to_text: true
  numerical:
    convert_to_text: false
    scaler_with_mean: true
    scaler_with_std: true
optimization:
  optim_type: adamw
  learning_rate: 0.0001
  weight_decay: 0.001
  lr_choice: layerwise_decay
  lr_decay: 0.8
  lr_schedule: cosine_decay
  max_epochs: 10
  max_steps: -1
  warmup_steps: 0.1
  end_lr: 0
  lr_mult: 1
  patience: 10
  val_check_interval: 0.5
  top_k: 3
env:
  num_gpus: -1
  num_nodes: 1
  batch_size: 128
  per_gpu_batch_size: 8
  per_gpu_batch_size_evaluation: 64
  precision: 16
  num_workers: 2
  num_workers_evaluation: 2
  fast_dev_run: false
  deterministic: false
  auto_select_gpus: true
  strategy: ddp_spawn

The model config provides the following model types: an MLP for categorical data (categorical_mlp), an MLP for numerical data (numerical_mlp), Hugging Face transformers for text data (hf_text), timm backbones for image data (timm_image), CLIP for image+text data (clip), and an MLP that fuses the outputs of any combination of the other models (fusion_mlp). We can specify the model combination by setting model.names. Moreover, we can use model.hf_text.checkpoint_name and model.timm_image.checkpoint_name to customize the Hugging Face and timm backbones.
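To make the link between model.names and the data that gets used concrete, here is a hypothetical helper (not part of AutoGluon) that reads the data_types listed for each model in the printed config and reports which modalities a model.names selection would consume. The categorical-to-text rule encodes our reading of this tutorial: with conversion enabled, categorical columns reach the text models as text.

```python
# data_types per model, copied from the default config printed above.
MODEL_DATA_TYPES = {
    "categorical_mlp": {"categorical"},
    "numerical_mlp": {"numerical"},
    "hf_text": {"text"},
    "timm_image": {"image"},
    "clip": {"image", "text"},
    "fusion_mlp": set(),  # data_types: null; it fuses the other models' outputs
}


def modalities_used(model_names, categorical_to_text=True):
    """Hypothetical helper: which data modalities a model.names list consumes."""
    used = set()
    for name in model_names:
        used |= MODEL_DATA_TYPES[name]
    # Our reading of data.categorical.convert_to_text (true by default):
    # categorical columns reach the text models as text.
    if categorical_to_text and "text" in used:
        used.add("categorical (as text)")
    return sorted(used)


print(modalities_used(["timm_image"]))  # ['image']
print(modalities_used(["hf_text"]))     # ['categorical (as text)', 'text']
print(modalities_used(["clip", "timm_image", "hf_text", "numerical_mlp", "fusion_mlp"]))
```

The three calls correspond to the image-only, text-only, and late fusion examples earlier in this tutorial.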

The data config defines model-agnostic rules for preprocessing data. Note that AutoMMPredictor converts categorical data into text by default.
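The numerical section of the data config has scaler_with_mean and scaler_with_std flags. Their names suggest standard (z-score) scaling, in the style of scikit-learn's StandardScaler options with_mean and with_std; that interpretation is our assumption. The sketch below implements such scaling in plain Python for a single column of hypothetical values.

```python
import statistics


def standardize(values, with_mean=True, with_std=True):
    # Z-score scaling: subtract the mean, divide by the population std
    # (the std is always measured around the mean, as in StandardScaler).
    mean = statistics.fmean(values) if with_mean else 0.0
    std = statistics.pstdev(values) if with_std else 1.0
    if std == 0:
        std = 1.0  # avoid division by zero on constant columns
    return [(v - mean) / std for v in values]


ages = [2, 4, 6, 8]  # hypothetical Age values
print(standardize(ages))
print(standardize(ages, with_std=False))  # [-3.0, -1.0, 1.0, 3.0]
```

After scaling with both flags on, the column has mean 0 and standard deviation 1, which keeps numerical features on a comparable scale for the MLP.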

The optimization config holds the hyper-parameters for model training. AutoMMPredictor uses layer-wise learning rate decay, which gradually decreases the learning rate from the output end to the input end of a model.
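As a worked example with the defaults printed above (learning_rate: 0.0001, lr_decay: 0.8), one common formulation of layer-wise decay multiplies the learning rate of a layer that is d layers away from the output by lr_decay**d. The sketch assumes that formulation; how AutoMMPredictor groups parameters into layers is not shown in this tutorial.

```python
def layerwise_learning_rates(base_lr, decay, num_layers):
    # Index 0 is the layer closest to the output; layers nearer the input
    # get geometrically smaller learning rates.
    return [base_lr * decay ** depth for depth in range(num_layers)]


lrs = layerwise_learning_rates(base_lr=1e-4, decay=0.8, num_layers=4)
for depth, lr in enumerate(lrs):
    print(f"depth {depth}: lr = {lr:.3e}")
```

With these values the learning rates are 1.000e-04, 8.000e-05, 6.400e-05, and 5.120e-05, so pre-trained lower layers change more slowly than the layers near the task head.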

The env config contains the environment/machine related hyper-parameters. For example, the optimal values of per_gpu_batch_size and per_gpu_batch_size_evaluation are closely related to the GPU memory size.
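The env config lists both batch_size (128) and per_gpu_batch_size (8). A common way to reconcile the two, and our assumption here, is gradient accumulation: each device processes per_gpu_batch_size samples per pass, and gradients are accumulated over enough passes to reach roughly batch_size samples per optimizer step. The arithmetic looks like this:

```python
def accumulation_steps(batch_size, per_gpu_batch_size, num_gpus):
    # num_gpus: number of devices actually in use (at least 1).
    # Returns how many forward/backward passes to accumulate per optimizer
    # step so that about batch_size samples contribute (rounded down, min 1).
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1 in this sketch")
    samples_per_pass = per_gpu_batch_size * num_gpus
    return max(batch_size // samples_per_pass, 1)


print(accumulation_steps(128, 8, num_gpus=1))   # 16
print(accumulation_steps(128, 16, num_gpus=1))  # 8
print(accumulation_steps(128, 8, num_gpus=4))   # 4
```

Under this assumption, lowering per_gpu_batch_size to fit a smaller GPU increases the number of accumulated passes rather than shrinking the effective batch size.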

You can customize any hyper-parameter in the config via the hyperparameters argument of .fit(). To address a hyper-parameter, traverse the config from the top-level key down to the bottom-level key and join the keys with a dot (.). For example, to change the per-GPU batch size to 16, set hyperparameters={"env.per_gpu_batch_size": 16}.
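The dotted-key convention can be sketched on a plain nested dictionary: split each key on ".", walk the path from the top-level key down, and replace the leaf value. The helper apply_overrides below is hypothetical; AutoGluon manages its configs with OmegaConf (imported above), and its actual merge logic may differ. Raising on unknown keys is our design choice, to catch typos early.

```python
import copy


def apply_overrides(config, hyperparameters):
    """Return a copy of a nested dict config with dotted-key overrides applied."""
    config = copy.deepcopy(config)
    for dotted_key, value in hyperparameters.items():
        *parents, leaf = dotted_key.split(".")
        node = config
        for key in parents:
            node = node[key]  # KeyError if an intermediate key does not exist
        if leaf not in node:
            raise KeyError(f"unknown config key: {dotted_key}")
        node[leaf] = value
    return config


# A small excerpt of the default config printed above.
default_config = {
    "model": {"names": ["hf_text", "timm_image"],
              "hf_text": {"checkpoint_name": "google/electra-base-discriminator"}},
    "env": {"per_gpu_batch_size": 8, "num_gpus": -1},
}

custom = apply_overrides(default_config, {
    "env.per_gpu_batch_size": 16,
    "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
})
print(custom["env"]["per_gpu_batch_size"])          # 16
print(default_config["env"]["per_gpu_batch_size"])  # 8, original untouched
```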

APIs

Besides .fit() and .evaluate(), AutoMMPredictor provides other useful APIs, similar to those in TextPredictor and TabularPredictor. For more details, refer to the Text Prediction - Quick Start tutorial.

Given data without ground truth labels, AutoMMPredictor can make predictions.

predictions = predictor.predict(test_data.drop(columns=label_col))
predictions[:5]
1873     2
8536     2
7988     2
10127    2
14668    1
Name: AdoptionSpeed, dtype: int64

For classification tasks, we can get the probabilities of all classes.

probas = predictor.predict_proba(test_data.drop(columns=label_col))
probas[:5]
              0         1         2         3         4
1873   0.014537  0.295644  0.310817  0.122612  0.256390
8536   0.017496  0.269609  0.302648  0.154956  0.255291
7988   0.019225  0.257543  0.369263  0.139917  0.214052
10127  0.014647  0.264183  0.362841  0.116510  0.241819
14668  0.016556  0.375061  0.286553  0.089543  0.232288
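Each row of probabilities sums to 1, and for these five rows the class with the highest probability matches the label returned by .predict() above. The snippet below checks this on the values copied from the output; it is a consistency check on this particular output, not a statement of how AutoMMPredictor computes predictions internally.

```python
# The five rows printed above: index -> probabilities for classes 0..4.
probas_rows = {
    1873: [0.014537, 0.295644, 0.310817, 0.122612, 0.256390],
    8536: [0.017496, 0.269609, 0.302648, 0.154956, 0.255291],
    7988: [0.019225, 0.257543, 0.369263, 0.139917, 0.214052],
    10127: [0.014647, 0.264183, 0.362841, 0.116510, 0.241819],
    14668: [0.016556, 0.375061, 0.286553, 0.089543, 0.232288],
}

for idx, probs in probas_rows.items():
    predicted_class = max(range(len(probs)), key=probs.__getitem__)  # argmax
    print(idx, predicted_class, round(sum(probs), 4))
```

The argmax gives classes 2, 2, 2, 2, and 1, the same as the predictions printed earlier.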

Note that calling .predict_proba on a regression task raises an exception.

Extracting embeddings is easy with .extract_embedding().

embeddings = predictor.extract_embedding(test_data.drop(columns=label_col))
embeddings.shape
(100, 256)
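Each of the 100 test rows is mapped to a 256-dimensional vector. One typical use of such embeddings, shown here as our own illustration rather than an AutoMMPredictor feature, is finding similar rows with cosine similarity. The sketch uses short hypothetical vectors; with real data you would pass rows of embeddings instead.

```python
import math


def cosine_similarity(a, b):
    # cos(a, b) = (a . b) / (|a| * |b|)
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def most_similar(query, candidates):
    # Index of the candidate vector with the highest cosine similarity.
    scores = [cosine_similarity(query, c) for c in candidates]
    return max(range(len(candidates)), key=scores.__getitem__)


query = [1.0, 0.0, 1.0]
candidates = [[0.0, 1.0, 0.0], [2.0, 0.1, 1.9], [-1.0, 0.0, -1.0]]
print(most_similar(query, candidates))  # 1
```

Cosine similarity ignores vector length, so two rows count as similar when their embeddings point in the same direction.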

It is also convenient to save and load a predictor.

predictor.save('my_saved_dir')
loaded_predictor = AutoMMPredictor.load('my_saved_dir')
scores2 = loaded_predictor.evaluate(test_data, metrics=["accuracy"])
scores2
{'accuracy': 0.19}