AutoMMPredictor for Image, Text, and Tabular¶
Are you tired of switching codebases or hacking code for different data
modalities (image, text, numerical, and categorical data) and tasks
(classification, regression, and more)? AutoMMPredictor provides a
one-stop shop for multimodal and unimodal deep learning models. This
tutorial demonstrates several application scenarios.
Multimodal Prediction
CLIP
TIMM + Huggingface Transformers + More
Image Prediction
Text Prediction
Configuration Customization
APIs
import os
import numpy as np
import warnings
warnings.filterwarnings('ignore')
np.random.seed(123)
Dataset¶
For demonstration, we use the PetFinder dataset. It contains the information shown on shelter animals’ adoption profiles, and the goal is to predict how quickly each animal is adopted (the AdoptionSpeed column). The label is grouped into five categories, so this is a multi-class classification problem.
To get started, let’s download and prepare the dataset.
download_dir = './ag_automm_tutorial'
zip_file = 'https://automl-mm-bench.s3.amazonaws.com/petfinder_kaggle.zip'
from autogluon.core.utils.loaders import load_zip
load_zip.unzip(zip_file, unzip_dir=download_dir)
Downloading ./ag_automm_tutorial/file.zip from https://automl-mm-bench.s3.amazonaws.com/petfinder_kaggle.zip...
100%|██████████| 2.00G/2.00G [01:34<00:00, 21.2MiB/s]
Next, we will load the CSV files.
import pandas as pd
dataset_path = download_dir + '/petfinder_processed'
train_data = pd.read_csv(f'{dataset_path}/train.csv', index_col=0)
test_data = pd.read_csv(f'{dataset_path}/dev.csv', index_col=0)
label_col = 'AdoptionSpeed'
We need to expand the image paths to load them in training.
image_col = 'Images'
train_data[image_col] = train_data[image_col].apply(lambda ele: ele.split(';')[0]) # Use the first image for a quick tutorial
test_data[image_col] = test_data[image_col].apply(lambda ele: ele.split(';')[0])
def path_expander(path, base_folder):
    # Expand every ';'-separated relative path into an absolute path.
    path_l = path.split(';')
    return ';'.join([os.path.abspath(os.path.join(base_folder, path)) for path in path_l])
train_data[image_col] = train_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
test_data[image_col] = test_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
train_data[image_col].iloc[0]
'/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/ag_automm_tutorial/petfinder_processed/train_images/e4b90955c-1.jpg'
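To make the path handling concrete, here is a small self-contained sketch of the same expansion applied to a made-up cell value. The file names and the `data_root` folder are hypothetical and do not come from the dataset; only the ';'-separated format is taken from the code above.

```python
import os

def path_expander(path, base_folder):
    """Turn ';'-separated relative paths into ';'-separated absolute paths."""
    parts = path.split(';')
    return ';'.join(os.path.abspath(os.path.join(base_folder, p)) for p in parts)

# Hypothetical cell value holding two relative image paths.
raw_value = 'train_images/a-1.jpg;train_images/a-2.jpg'

# Keep only the first image, as the tutorial does, then expand it.
first_image = raw_value.split(';')[0]
expanded_first = path_expander(first_image, base_folder='data_root')

# Expanding the full value keeps both images, each as an absolute path.
expanded_all = path_expander(raw_value, base_folder='data_root')

print(expanded_first)
print(expanded_all.split(';'))
```

Absolute paths keep image loading independent of the working directory from which training is launched.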
Each animal’s adoption profile includes pictures, a text description, and various tabular features such as age, breed, name, color, and more. Let’s look at an example row of data and display the text description and a picture.
example_row = train_data.iloc[47]
example_row
Type 2
Name Money
Age 4
Breed1 266
Breed2 0
Gender 2
Color1 1
Color2 2
Color3 7
MaturitySize 1
FurLength 2
Vaccinated 2
Dewormed 1
Sterilized 2
Health 1
Quantity 1
Fee 0
State 41401
RescuerID ee7445af32acfa1dc8307a9dc7baed21
VideoAmt 0
Description My pet is a pretty beautiful kitty which has a...
PetID 98c08df17
PhotoAmt 2.0
AdoptionSpeed 2
Images /var/lib/jenkins/workspace/workspace/autogluon...
Name: 14845, dtype: object
example_row['Description']
'My pet is a pretty beautiful kitty which has a mixed colour soft fur. She is active and full of life. And one thing about her, she loves to eat.She always turn on me like a tiger when I was preparing the food for her.'
example_image = example_row['Images']
from IPython.display import Image, display
pil_img = Image(filename=example_image)
display(pil_img)

For demonstration purposes, we sample 500 rows for training and 100 rows for testing.
train_data = train_data.sample(500, random_state=0)
test_data = test_data.sample(100, random_state=0)
Multimodal Prediction¶
CLIP¶
AutoMMPredictor can finetune pre-trained vision-language models, such as
CLIP.
from autogluon.text.automm import AutoMMPredictor
predictor = AutoMMPredictor(label=label_col)
predictor.fit(
    train_data=train_data,
    hyperparameters={
        "model.names": ["clip"],
        "env.num_gpus": 1,
    },
    time_limit=120,  # seconds
)
Global seed set to 123
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  | Name              | Type             | Params
-------------------------------------------------------
0 | model             | CLIPForImageText | 151 M
1 | validation_metric | Accuracy         | 0
2 | loss_func         | CrossEntropyLoss | 0
-------------------------------------------------------
151 M     Trainable params
0         Non-trainable params
151 M     Total params
302.560   Total estimated model params size (MB)
Global seed set to 123
Epoch 0, global step 1: val_accuracy reached 0.27000 (best 0.27000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=0-step=1.ckpt" as top 3
Epoch 0, global step 3: val_accuracy reached 0.30000 (best 0.30000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=0-step=3.ckpt" as top 3
Epoch 1, global step 5: val_accuracy reached 0.29000 (best 0.30000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=1-step=5.ckpt" as top 3
Epoch 1, global step 7: val_accuracy was not in top 3
Epoch 2, global step 9: val_accuracy reached 0.29000 (best 0.30000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=2-step=9.ckpt" as top 3
Epoch 2, global step 11: val_accuracy reached 0.33000 (best 0.33000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=2-step=11.ckpt" as top 3
Epoch 3, global step 13: val_accuracy reached 0.33000 (best 0.33000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=3-step=13.ckpt" as top 3
Epoch 3, global step 15: val_accuracy was not in top 3
Epoch 4, global step 17: val_accuracy reached 0.33000 (best 0.33000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=4-step=17.ckpt" as top 3
Time limit reached. Elapsed time is 0:02:00. Signaling Trainer to stop.
Epoch 4, global step 18: val_accuracy was not in top 3
<autogluon.text.automm.predictor.AutoMMPredictor at 0x7fbebc0fcd60>
scores = predictor.evaluate(test_data, metrics=["accuracy"])
scores
{'accuracy': 0.33}
In this example, AutoMMPredictor finetunes CLIP with the image,
text, and categorical (converted to text) data.
TIMM + Huggingface Transformers + More¶
In addition to CLIP, AutoMMPredictor can simultaneously finetune
various timm backbones and huggingface transformers. Moreover,
AutoMMPredictor uses an MLP for numerical data but converts categorical
data to text by default.
Let’s use AutoMMPredictor to train a late fusion model including
CLIP, swin_small_patch4_window7_224, google/electra-small-discriminator,
a numerical MLP, and a fusion MLP.
from autogluon.text.automm import AutoMMPredictor
predictor = AutoMMPredictor(label=label_col)
predictor.fit(
    train_data=train_data,
    hyperparameters={
        "model.names": ["clip", "timm_image", "hf_text", "numerical_mlp", "fusion_mlp"],
        "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
        "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
        "env.num_gpus": 1,
    },
    time_limit=120,  # seconds
)
Global seed set to 123
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  | Name              | Type                | Params
----------------------------------------------------------
0 | model             | MultimodalFusionMLP | 215 M
1 | validation_metric | Accuracy            | 0
2 | loss_func         | CrossEntropyLoss    | 0
----------------------------------------------------------
215 M     Trainable params
0         Non-trainable params
215 M     Total params
430.576   Total estimated model params size (MB)
Global seed set to 123
Epoch 0, global step 1: val_accuracy reached 0.16000 (best 0.16000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053904/epoch=0-step=1.ckpt" as top 3
Epoch 0, global step 3: val_accuracy reached 0.20000 (best 0.20000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053904/epoch=0-step=3.ckpt" as top 3
Epoch 1, global step 5: val_accuracy reached 0.32000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053904/epoch=1-step=5.ckpt" as top 3
Epoch 1, global step 7: val_accuracy was not in top 3
Epoch 2, global step 9: val_accuracy reached 0.17000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053904/epoch=2-step=9.ckpt" as top 3
Time limit reached. Elapsed time is 0:02:03. Signaling Trainer to stop.
<autogluon.text.automm.predictor.AutoMMPredictor at 0x7fbe51c77190>
scores = predictor.evaluate(test_data, metrics=["accuracy"])
scores
{'accuracy': 0.43}
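The late-fusion idea used in this section can be illustrated with a toy sketch: each modality produces its own feature vector, the vectors are concatenated, and a small MLP maps the fused vector to class probabilities. This is a conceptual illustration only and does not reproduce AutoMMPredictor's fusion_mlp. The feature values, layer sizes, and random weights below are all made up, and real backbones are replaced by fixed placeholder vectors.

```python
import math
import random

random.seed(0)

def linear(x, weight, bias):
    """Dense layer: weight holds one row of coefficients per output unit."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]

def leaky_relu(x, slope=0.01):
    return [v if v > 0 else slope * v for v in x]

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def random_layer(n_in, n_out):
    """Untrained placeholder weights and zero biases."""
    weight = [[random.uniform(-0.1, 0.1) for _ in range(n_in)] for _ in range(n_out)]
    return weight, [0.0] * n_out

# Made-up per-modality feature vectors, standing in for backbone outputs.
features = {
    "image": [0.2, -0.1, 0.4],
    "text": [0.5, 0.3],
    "numerical": [1.0],
}

# Late fusion: concatenate the per-modality features, then apply an MLP head.
fused = [v for name in ("image", "text", "numerical") for v in features[name]]
hidden_w, hidden_b = random_layer(len(fused), 4)
out_w, out_b = random_layer(4, 5)  # 5 classes, as in the AdoptionSpeed label

hidden = leaky_relu(linear(fused, hidden_w, hidden_b))
probs = softmax(linear(hidden, out_w, out_b))
print(probs)
```

Because the fusion happens on features rather than raw inputs, each modality can keep its own specialized backbone.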
Image Prediction¶
If you want to use only image data, or your task has only image data,
AutoMMPredictor can help you finetune a wide range of timm backbones,
such as swin_tiny_patch4_window7_224.
from autogluon.text.automm import AutoMMPredictor
predictor = AutoMMPredictor(label=label_col)
predictor.fit(
    train_data=train_data,
    hyperparameters={
        "model.names": ["timm_image"],
        "model.timm_image.checkpoint_name": "swin_tiny_patch4_window7_224",
        "env.num_gpus": 1,
    },
    time_limit=60,  # seconds
)
Global seed set to 123
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  | Name              | Type                            | Params
----------------------------------------------------------------------
0 | model             | TimmAutoModelForImagePrediction | 27.5 M
1 | validation_metric | Accuracy                        | 0
2 | loss_func         | CrossEntropyLoss                | 0
----------------------------------------------------------------------
27.5 M    Trainable params
0         Non-trainable params
27.5 M    Total params
55.046    Total estimated model params size (MB)
Global seed set to 123
Epoch 0, global step 1: val_accuracy reached 0.23000 (best 0.23000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=0-step=1.ckpt" as top 3
Epoch 0, global step 3: val_accuracy reached 0.31000 (best 0.31000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=0-step=3.ckpt" as top 3
Epoch 1, global step 5: val_accuracy reached 0.32000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=1-step=5.ckpt" as top 3
Epoch 1, global step 7: val_accuracy reached 0.35000 (best 0.35000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=1-step=7.ckpt" as top 3
Epoch 2, global step 9: val_accuracy was not in top 3
Epoch 2, global step 11: val_accuracy was not in top 3
Epoch 3, global step 13: val_accuracy was not in top 3
Epoch 3, global step 15: val_accuracy reached 0.34000 (best 0.35000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=3-step=15.ckpt" as top 3
Epoch 4, global step 17: val_accuracy reached 0.36000 (best 0.36000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=4-step=17.ckpt" as top 3
Epoch 4, global step 19: val_accuracy reached 0.36000 (best 0.36000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=4-step=19.ckpt" as top 3
Epoch 5, global step 21: val_accuracy was not in top 3
Epoch 5, global step 23: val_accuracy was not in top 3
Epoch 6, global step 25: val_accuracy was not in top 3
Epoch 6, global step 27: val_accuracy was not in top 3
Time limit reached. Elapsed time is 0:01:01. Signaling Trainer to stop.
Epoch 7, global step 28: val_accuracy was not in top 3
<autogluon.text.automm.predictor.AutoMMPredictor at 0x7fbef3b4b8e0>
Here AutoMMPredictor uses only image data since model.names
includes only timm_image.
scores = predictor.evaluate(test_data, metrics=["accuracy"])
scores
{'accuracy': 0.34}
Text Prediction¶
Similarly, you may be interested in only finetuning the text backbones from huggingface transformers, such as google/electra-small-discriminator.
from autogluon.text.automm import AutoMMPredictor
predictor = AutoMMPredictor(label=label_col)
predictor.fit(
    train_data=train_data,
    hyperparameters={
        "model.names": ["hf_text"],
        "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
        "env.num_gpus": 1,
    },
    time_limit=60,  # seconds
)
Global seed set to 123
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  | Name              | Type                         | Params
-------------------------------------------------------------------
0 | model             | HFAutoModelForTextPrediction | 13.5 M
1 | validation_metric | Accuracy                     | 0
2 | loss_func         | CrossEntropyLoss             | 0
-------------------------------------------------------------------
13.5 M    Trainable params
0         Non-trainable params
13.5 M    Total params
26.969    Total estimated model params size (MB)
Global seed set to 123
Epoch 0, global step 1: val_accuracy reached 0.32000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=0-step=1.ckpt" as top 3
Epoch 0, global step 3: val_accuracy reached 0.22000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=0-step=3.ckpt" as top 3
Epoch 1, global step 5: val_accuracy reached 0.23000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=1-step=5.ckpt" as top 3
Epoch 1, global step 7: val_accuracy was not in top 3
Epoch 2, global step 9: val_accuracy was not in top 3
Epoch 2, global step 11: val_accuracy reached 0.28000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=2-step=11.ckpt" as top 3
Epoch 3, global step 13: val_accuracy reached 0.31000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=3-step=13.ckpt" as top 3
Epoch 3, global step 15: val_accuracy reached 0.31000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=3-step=15.ckpt" as top 3
Epoch 4, global step 17: val_accuracy was not in top 3
Epoch 4, global step 19: val_accuracy reached 0.32000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=4-step=19.ckpt" as top 3
Epoch 5, global step 21: val_accuracy reached 0.33000 (best 0.33000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=5-step=21.ckpt" as top 3
Epoch 5, global step 23: val_accuracy was not in top 3
Epoch 6, global step 25: val_accuracy was not in top 3
Epoch 6, global step 27: val_accuracy was not in top 3
Epoch 7, global step 29: val_accuracy was not in top 3
Epoch 7, global step 31: val_accuracy was not in top 3
Epoch 8, global step 33: val_accuracy was not in top 3
Epoch 8, global step 35: val_accuracy was not in top 3
Epoch 9, global step 37: val_accuracy was not in top 3
Epoch 9, global step 39: val_accuracy was not in top 3
Saving latest checkpoint...
<autogluon.text.automm.predictor.AutoMMPredictor at 0x7fbe52336e50>
With only hf_text in model.names, AutoMMPredictor
automatically uses only text and categorical (converted to text) data.
scores = predictor.evaluate(test_data, metrics=["accuracy"])
scores
{'accuracy': 0.19}
Configuration Customization¶
The above examples have shown the flexibility of AutoMMPredictor.
You may want to know how to customize configurations for your tasks.
Fortunately, AutoMMPredictor has a user-friendly configuration
design.
First, let’s see the available model presets.
from autogluon.text.automm.presets import list_model_presets, get_preset
model_presets = list_model_presets()
model_presets
['fusion_mlp_image_text_tabular']
Currently, AutoMMPredictor has only one model preset, from which we
can construct the predictor’s preset.
preset = get_preset(model_presets[0])
preset
{'model': 'fusion_mlp_image_text_tabular',
'data': 'default',
'optimization': 'adamw',
'environment': 'default'}
AutoMMPredictor configurations consist of four parts: model, data,
optimization, and environment. You can convert the
preset to configurations to see the details.
from omegaconf import OmegaConf
from autogluon.text.automm.utils import get_config
config = get_config(preset)
print(OmegaConf.to_yaml(config))
model:
  names:
  - categorical_mlp
  - numerical_mlp
  - hf_text
  - timm_image
  - clip
  - fusion_mlp
  categorical_mlp:
    hidden_size: 64
    activation: leaky_relu
    num_layers: 1
    drop_rate: 0.1
    normalization: layer_norm
    data_types:
    - categorical
  numerical_mlp:
    hidden_size: 128
    activation: leaky_relu
    num_layers: 1
    drop_rate: 0.1
    normalization: layer_norm
    data_types:
    - numerical
    merge: concat
  hf_text:
    checkpoint_name: google/electra-base-discriminator
    data_types:
    - text
    tokenizer_name: hf_auto
    max_text_len: 512
    insert_sep: true
    text_segment_num: 2
    stochastic_chunk: false
  timm_image:
    checkpoint_name: swin_base_patch4_window7_224
    mix_choice: all_logits
    data_types:
    - image
    train_transform_types:
    - resize_shorter_side
    - center_crop
    val_transform_types:
    - resize_shorter_side
    - center_crop
    image_norm: imagenet
    image_size: 224
    max_img_num_per_col: 3
  clip:
    checkpoint_name: openai/clip-vit-base-patch32
    data_types:
    - image
    - text
    train_transform_types:
    - resize_shorter_side
    - center_crop
    val_transform_types:
    - resize_shorter_side
    - center_crop
    image_norm: clip
    image_size: 224
    max_img_num_per_col: 0
    tokenizer_name: clip
    max_text_len: 77
    insert_sep: false
    text_segment_num: 1
    stochastic_chunk: false
  fusion_mlp:
    weight: 0.1
    adapt_in_features: max
    hidden_sizes:
    - 128
    activation: leaky_relu
    drop_rate: 0.1
    normalization: layer_norm
    data_types: null
data:
  image: null
  text: null
  categorical:
    minimum_cat_count: 100
    maximum_num_cat: 20
    convert_to_text: true
  numerical:
    convert_to_text: false
    scaler_with_mean: true
    scaler_with_std: true
optimization:
  optim_type: adamw
  learning_rate: 0.0001
  weight_decay: 0.001
  lr_choice: layerwise_decay
  lr_decay: 0.8
  lr_schedule: cosine_decay
  max_epochs: 10
  max_steps: -1
  warmup_steps: 0.1
  end_lr: 0
  lr_mult: 1
  patience: 10
  val_check_interval: 0.5
  top_k: 3
env:
  num_gpus: -1
  num_nodes: 1
  batch_size: 128
  per_gpu_batch_size: 8
  per_gpu_batch_size_evaluation: 64
  precision: 16
  num_workers: 2
  num_workers_evaluation: 2
  fast_dev_run: false
  deterministic: false
  auto_select_gpus: true
  strategy: ddp_spawn
The model config provides six model types: an MLP for categorical data
(categorical_mlp), an MLP for numerical data (numerical_mlp),
huggingface transformers for text data (hf_text), timm for image
data (timm_image), CLIP for image+text data (clip), and an MLP that fuses
any combination of categorical_mlp, numerical_mlp, hf_text, and
timm_image (fusion_mlp). We can specify the model combination by
setting model.names. Moreover, we can use
model.hf_text.checkpoint_name and model.timm_image.checkpoint_name
to customize the huggingface and timm backbones.
The data config defines some model-agnostic rules for preprocessing
data. Note that AutoMMPredictor converts categorical data into text
by default.
The optimization config has hyper-parameters for model training.
AutoMMPredictor uses layer-wise learning rate decay, which decreases
the learning rate gradually from the output end to the input end of a
model.
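As a rough illustration of layer-wise learning rate decay, the sketch below assigns the base learning rate to the output-side layer and multiplies it by the decay factor for each step toward the input. The values 0.0001 and 0.8 are taken from learning_rate and lr_decay in the printed configuration; the layer count of 4 is hypothetical, and how AutoMMPredictor actually groups parameters into layers is not described here, so treat this as a sketch of the general idea rather than its implementation.

```python
def layerwise_learning_rates(num_layers, base_lr=1e-4, decay=0.8):
    """Return one learning rate per layer, ordered from input to output.

    The output-side layer keeps base_lr; each step toward the input
    multiplies the rate by decay once more.
    """
    return [base_lr * decay ** (num_layers - 1 - i) for i in range(num_layers)]

# Hypothetical 4-layer model: index 0 is the input end, index 3 the output end.
rates = layerwise_learning_rates(num_layers=4)
for depth, lr in enumerate(rates):
    print(f"layer {depth}: lr = {lr:.2e}")
```

Layers near the input therefore change more slowly than layers near the output during finetuning.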
The env config contains the environment/machine related
hyper-parameters. For example, the optimal values of
per_gpu_batch_size and per_gpu_batch_size_evaluation are closely
related to the GPU memory size.
You can flexibly customize any hyper-parameter in config via the
hyperparameters argument of .fit(). To access a hyper-parameter in
config, traverse from the top-level key to the bottom-level key and
join the keys with a period (.). For example, to change the per-GPU
batch size to 16, set hyperparameters={"env.per_gpu_batch_size": 16}.
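To show how a dotted key maps onto the nested configuration, here is a minimal sketch using plain dictionaries. The helper `apply_overrides` is hypothetical and written only for illustration; it is not the function AutoMMPredictor uses internally, and the `config` dictionary is a small excerpt of the configuration printed above.

```python
import copy

def apply_overrides(config, overrides):
    """Return a copy of a nested dict with dotted-key overrides applied."""
    updated = copy.deepcopy(config)
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = updated
        for key in parents:
            node = node[key]  # raises KeyError for an unknown section
        if leaf not in node:
            raise KeyError(f"unknown hyper-parameter: {dotted_key}")
        node[leaf] = value
    return updated

# A small excerpt of the printed configuration.
config = {
    "model": {"hf_text": {"checkpoint_name": "google/electra-base-discriminator"}},
    "env": {"num_gpus": -1, "per_gpu_batch_size": 8},
}

new_config = apply_overrides(
    config,
    {
        "env.per_gpu_batch_size": 16,
        "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
    },
)
print(new_config["env"]["per_gpu_batch_size"])
print(new_config["model"]["hf_text"]["checkpoint_name"])
```

Rejecting unknown keys, as this sketch does, is one way to catch a misspelled hyper-parameter name early.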
APIs¶
Besides .fit() and .evaluate(), AutoMMPredictor also
provides other useful APIs, similar to those in TextPredictor and
TabularPredictor. For more details, refer to
Text Prediction - Quick Start.
Given data without ground truth labels, AutoMMPredictor can make
predictions.
predictions = predictor.predict(test_data.drop(columns=label_col))
predictions[:5]
1873 2
8536 2
7988 2
10127 2
14668 1
Name: AdoptionSpeed, dtype: int64
For classification tasks, we can get the probabilities of all classes.
probas = predictor.predict_proba(test_data.drop(columns=label_col))
probas[:5]
| | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|
| 1873 | 0.014537 | 0.295644 | 0.310817 | 0.122612 | 0.256390 |
| 8536 | 0.017496 | 0.269609 | 0.302648 | 0.154956 | 0.255291 |
| 7988 | 0.019225 | 0.257543 | 0.369263 | 0.139917 | 0.214052 |
| 10127 | 0.014647 | 0.264183 | 0.362841 | 0.116510 | 0.241819 |
| 14668 | 0.016556 | 0.375061 | 0.286553 | 0.089543 | 0.232288 |
Note that calling .predict_proba on a regression task raises
an exception.
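For the five rows shown, the class with the highest probability matches the label returned by .predict(). The sketch below checks this with plain Python, using the probabilities copied from the output above. It assumes the class labels correspond to the column positions 0 to 4; whether .predict() always equals the arg-max of .predict_proba() is not stated in this tutorial.

```python
# Class probabilities copied from the predict_proba output, keyed by row index.
probas = {
    1873: [0.014537, 0.295644, 0.310817, 0.122612, 0.256390],
    8536: [0.017496, 0.269609, 0.302648, 0.154956, 0.255291],
    7988: [0.019225, 0.257543, 0.369263, 0.139917, 0.214052],
    10127: [0.014647, 0.264183, 0.362841, 0.116510, 0.241819],
    14668: [0.016556, 0.375061, 0.286553, 0.089543, 0.232288],
}

def most_likely_class(row):
    """Return the position of the largest probability in the row."""
    return max(range(len(row)), key=lambda i: row[i])

predicted = {idx: most_likely_class(row) for idx, row in probas.items()}
print(predicted)
```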
Embeddings can be extracted easily via .extract_embedding().
embeddings = predictor.extract_embedding(test_data.drop(columns=label_col))
embeddings.shape
(100, 256)
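One common use of such embeddings is comparing rows by similarity. The sketch below computes cosine similarity in plain Python on made-up 4-dimensional vectors; the real embeddings here have 256 dimensions, and the helper `cosine_similarity` is hypothetical and not part of the AutoMMPredictor API.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Made-up vectors standing in for rows of the embedding matrix.
emb_a = [1.0, 0.0, 2.0, 0.0]
emb_b = [2.0, 0.0, 4.0, 0.0]  # same direction as emb_a
emb_c = [0.0, 3.0, 0.0, 1.0]  # orthogonal to emb_a

sim_ab = cosine_similarity(emb_a, emb_b)
sim_ac = cosine_similarity(emb_a, emb_c)
print(sim_ab, sim_ac)
```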
It is also convenient to save and load a predictor.
predictor.save('my_saved_dir')
loaded_predictor = AutoMMPredictor.load('my_saved_dir')
scores2 = loaded_predictor.evaluate(test_data, metrics=["accuracy"])
scores2
{'accuracy': 0.19}