.. _sec_automm_predictor:

AutoMMPredictor for Image, Text, and Tabular
============================================

Are you tired of switching codebases or hacking code for different data
modalities (image, text, numerical, and categorical data) and tasks
(classification, regression, and more)? ``AutoMMPredictor`` provides a
one-stop shop for multimodal and unimodal deep learning models. This
tutorial demonstrates several application scenarios.

-  Multimodal Prediction

   -  CLIP
   -  TIMM + Huggingface Transformers + More

-  Image Prediction
-  Text Prediction
-  Configuration Customization
-  APIs

.. code:: python

    import os
    import numpy as np
    import warnings
    warnings.filterwarnings('ignore')
    np.random.seed(123)

Dataset
-------

For demonstration, we use the PetFinder dataset. It provides the
information that appears on shelter animals' adoption profiles, and the
goal is to predict each animal's adoption rate. The rates are grouped
into five categories, which makes this a multi-class classification
problem.

To get started, let's download and prepare the dataset.

.. code:: python

    download_dir = './ag_automm_tutorial'
    zip_file = 'https://automl-mm-bench.s3.amazonaws.com/petfinder_kaggle.zip'
    from autogluon.core.utils.loaders import load_zip
    load_zip.unzip(zip_file, unzip_dir=download_dir)

.. parsed-literal::
    :class: output

    Downloading ./ag_automm_tutorial/file.zip from https://automl-mm-bench.s3.amazonaws.com/petfinder_kaggle.zip...

.. parsed-literal::
    :class: output

    100%|██████████| 2.00G/2.00G [01:34<00:00, 21.2MiB/s]

Next, we will load the CSV files.

.. code:: python

    import pandas as pd
    dataset_path = download_dir + '/petfinder_processed'
    train_data = pd.read_csv(f'{dataset_path}/train.csv', index_col=0)
    test_data = pd.read_csv(f'{dataset_path}/dev.csv', index_col=0)
    label_col = 'AdoptionSpeed'

The image column stores relative paths, so we expand them to absolute
paths before training.

.. code:: python

    image_col = 'Images'
    train_data[image_col] = train_data[image_col].apply(lambda ele: ele.split(';')[0])  # Use the first image for a quick tutorial
    test_data[image_col] = test_data[image_col].apply(lambda ele: ele.split(';')[0])

    def path_expander(path, base_folder):
        path_l = path.split(';')
        return ';'.join([os.path.abspath(os.path.join(base_folder, path)) for path in path_l])

    train_data[image_col] = train_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
    test_data[image_col] = test_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))

    train_data[image_col].iloc[0]

.. parsed-literal::
    :class: output

    '/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/ag_automm_tutorial/petfinder_processed/train_images/e4b90955c-1.jpg'

Each animal's adoption profile includes pictures, a text description,
and various tabular features such as age, breed, name, color, and more.
Let's look at an example row of data and display the text description
and a picture.

.. code:: python

    example_row = train_data.iloc[47]

    example_row

.. parsed-literal::
    :class: output

    Type                                                             2
    Name                                                         Money
    Age                                                              4
    Breed1                                                         266
    Breed2                                                           0
    Gender                                                           2
    Color1                                                           1
    Color2                                                           2
    Color3                                                           7
    MaturitySize                                                     1
    FurLength                                                        2
    Vaccinated                                                       2
    Dewormed                                                         1
    Sterilized                                                       2
    Health                                                           1
    Quantity                                                         1
    Fee                                                              0
    State                                                        41401
    RescuerID                         ee7445af32acfa1dc8307a9dc7baed21
    VideoAmt                                                         0
    Description      My pet is a pretty beautiful kitty which has a...
    PetID                                                    98c08df17
    PhotoAmt                                                       2.0
    AdoptionSpeed                                                    2
    Images           /var/lib/jenkins/workspace/workspace/autogluon...
    Name: 14845, dtype: object

.. code:: python

    example_row['Description']

.. parsed-literal::
    :class: output

    'My pet is a pretty beautiful kitty which has a mixed colour soft fur. She is active and full of life. And one thing about her, she loves to eat.She always turn on me like a tiger when I was preparing the food for her.'

.. code:: python

    example_image = example_row['Images']

    from IPython.display import Image, display
    pil_img = Image(filename=example_image)
    display(pil_img)

.. figure:: output_automm_963233_11_0.jpg

For demonstration purposes, we sample 500 rows for training and 100
rows for testing.

.. code:: python

    train_data = train_data.sample(500, random_state=0)
    test_data = test_data.sample(100, random_state=0)

Multimodal Prediction
---------------------

CLIP
~~~~

``AutoMMPredictor`` allows for finetuning pre-trained vision-language
models, such as CLIP.

.. code:: python

    from autogluon.text.automm import AutoMMPredictor
    predictor = AutoMMPredictor(label=label_col)
    predictor.fit(
        train_data=train_data,
        hyperparameters={
            "model.names": ["clip"],
            "env.num_gpus": 1,
        },
        time_limit=120, # seconds
    )

.. parsed-literal::
    :class: output

    Global seed set to 123
    Using 16bit native Automatic Mixed Precision (AMP)
    GPU available: True, used: True
    TPU available: False, using: 0 TPU cores
    IPU available: False, using: 0 IPUs
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

      | Name              | Type             | Params
    -------------------------------------------------------
    0 | model             | CLIPForImageText | 151 M
    1 | validation_metric | Accuracy         | 0
    2 | loss_func         | CrossEntropyLoss | 0
    -------------------------------------------------------
    151 M     Trainable params
    0         Non-trainable params
    151 M     Total params
    302.560   Total estimated model params size (MB)
    Global seed set to 123
    Epoch 0, global step 1: val_accuracy reached 0.27000 (best 0.27000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=0-step=1.ckpt" as top 3
    Epoch 0, global step 3: val_accuracy reached 0.30000 (best 0.30000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=0-step=3.ckpt" as top 3
    Epoch 1, global step 5: val_accuracy reached 0.29000 (best 0.30000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=1-step=5.ckpt" as top 3
    Epoch 1, global step 7: val_accuracy was not in top 3
    Epoch 2, global step 9: val_accuracy reached 0.29000 (best 0.30000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=2-step=9.ckpt" as top 3
    Epoch 2, global step 11: val_accuracy reached 0.33000 (best 0.33000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=2-step=11.ckpt" as top 3
    Epoch 3, global step 13: val_accuracy reached 0.33000 (best 0.33000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=3-step=13.ckpt" as top 3
    Epoch 3, global step 15: val_accuracy was not in top 3
    Epoch 4, global step 17: val_accuracy reached 0.33000 (best 0.33000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053632/epoch=4-step=17.ckpt" as top 3
    Time limit reached. Elapsed time is 0:02:00. Signaling Trainer to stop.
    Epoch 4, global step 18: val_accuracy was not in top 3

.. code:: python

    scores = predictor.evaluate(test_data, metrics=["accuracy"])
    scores

.. parsed-literal::
    :class: output

    {'accuracy': 0.33}

In this example, ``AutoMMPredictor`` finetunes CLIP with the image,
text, and categorical (converted to text) data.

TIMM + Huggingface Transformers + More
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In addition to CLIP, ``AutoMMPredictor`` can simultaneously finetune
various timm backbones and huggingface transformers.
Moreover, ``AutoMMPredictor`` uses an MLP for numerical data but
converts categorical data to text by default. Let's use
``AutoMMPredictor`` to train a late fusion model including CLIP,
``swin_small_patch4_window7_224``,
``google/electra-small-discriminator``, a numerical MLP, and a fusion
MLP.

.. code:: python

    from autogluon.text.automm import AutoMMPredictor
    predictor = AutoMMPredictor(label=label_col)
    predictor.fit(
        train_data=train_data,
        hyperparameters={
            "model.names": ["clip", "timm_image", "hf_text", "numerical_mlp", "fusion_mlp"],
            "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
            "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
            "env.num_gpus": 1,
        },
        time_limit=120, # seconds
    )

.. parsed-literal::
    :class: output

    Global seed set to 123
    Using 16bit native Automatic Mixed Precision (AMP)
    GPU available: True, used: True
    TPU available: False, using: 0 TPU cores
    IPU available: False, using: 0 IPUs
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

      | Name              | Type                | Params
    ----------------------------------------------------------
    0 | model             | MultimodalFusionMLP | 215 M
    1 | validation_metric | Accuracy            | 0
    2 | loss_func         | CrossEntropyLoss    | 0
    ----------------------------------------------------------
    215 M     Trainable params
    0         Non-trainable params
    215 M     Total params
    430.576   Total estimated model params size (MB)
    Global seed set to 123
    Epoch 0, global step 1: val_accuracy reached 0.16000 (best 0.16000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053904/epoch=0-step=1.ckpt" as top 3
    Epoch 0, global step 3: val_accuracy reached 0.20000 (best 0.20000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053904/epoch=0-step=3.ckpt" as top 3
    Epoch 1, global step 5: val_accuracy reached 0.32000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053904/epoch=1-step=5.ckpt" as top 3
    Epoch 1, global step 7: val_accuracy was not in top 3
    Epoch 2, global step 9: val_accuracy reached 0.17000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_053904/epoch=2-step=9.ckpt" as top 3
    Time limit reached. Elapsed time is 0:02:03. Signaling Trainer to stop.

.. code:: python

    scores = predictor.evaluate(test_data, metrics=["accuracy"])
    scores

.. parsed-literal::
    :class: output

    {'accuracy': 0.43}

Image Prediction
----------------

If you want to use only image data, or your task only has image data,
``AutoMMPredictor`` can help you finetune a wide range of timm
backbones, such as ``swin_tiny_patch4_window7_224``.

.. code:: python

    from autogluon.text.automm import AutoMMPredictor
    predictor = AutoMMPredictor(label=label_col)
    predictor.fit(
        train_data=train_data,
        hyperparameters={
            "model.names": ["timm_image"],
            "model.timm_image.checkpoint_name": "swin_tiny_patch4_window7_224",
            "env.num_gpus": 1,
        },
        time_limit=60, # seconds
    )

.. parsed-literal::
    :class: output

    Global seed set to 123
    Using 16bit native Automatic Mixed Precision (AMP)
    GPU available: True, used: True
    TPU available: False, using: 0 TPU cores
    IPU available: False, using: 0 IPUs
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

      | Name              | Type                            | Params
    ----------------------------------------------------------------------
    0 | model             | TimmAutoModelForImagePrediction | 27.5 M
    1 | validation_metric | Accuracy                        | 0
    2 | loss_func         | CrossEntropyLoss                | 0
    ----------------------------------------------------------------------
    27.5 M    Trainable params
    0         Non-trainable params
    27.5 M    Total params
    55.046    Total estimated model params size (MB)
    Global seed set to 123
    Epoch 0, global step 1: val_accuracy reached 0.23000 (best 0.23000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=0-step=1.ckpt" as top 3
    Epoch 0, global step 3: val_accuracy reached 0.31000 (best 0.31000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=0-step=3.ckpt" as top 3
    Epoch 1, global step 5: val_accuracy reached 0.32000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=1-step=5.ckpt" as top 3
    Epoch 1, global step 7: val_accuracy reached 0.35000 (best 0.35000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=1-step=7.ckpt" as top 3
    Epoch 2, global step 9: val_accuracy was not in top 3
    Epoch 2, global step 11: val_accuracy was not in top 3
    Epoch 3, global step 13: val_accuracy was not in top 3
    Epoch 3, global step 15: val_accuracy reached 0.34000 (best 0.35000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=3-step=15.ckpt" as top 3
    Epoch 4, global step 17: val_accuracy reached 0.36000 (best 0.36000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=4-step=17.ckpt" as top 3
    Epoch 4, global step 19: val_accuracy reached 0.36000 (best 0.36000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054142/epoch=4-step=19.ckpt" as top 3
    Epoch 5, global step 21: val_accuracy was not in top 3
    Epoch 5, global step 23: val_accuracy was not in top 3
    Epoch 6, global step 25: val_accuracy was not in top 3
    Epoch 6, global step 27: val_accuracy was not in top 3
    Time limit reached. Elapsed time is 0:01:01. Signaling Trainer to stop.
    Epoch 7, global step 28: val_accuracy was not in top 3

Here ``AutoMMPredictor`` uses only image data since ``model.names``
includes only ``timm_image``.

.. code:: python

    scores = predictor.evaluate(test_data, metrics=["accuracy"])
    scores

.. parsed-literal::
    :class: output

    {'accuracy': 0.34}

Text Prediction
---------------

Similarly, you may be interested in finetuning only the text backbones
from huggingface transformers, such as
``google/electra-small-discriminator``.

.. code:: python

    from autogluon.text.automm import AutoMMPredictor
    predictor = AutoMMPredictor(label=label_col)
    predictor.fit(
        train_data=train_data,
        hyperparameters={
            "model.names": ["hf_text"],
            "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
            "env.num_gpus": 1,
        },
        time_limit=60, # seconds
    )

.. parsed-literal::
    :class: output

    Global seed set to 123
    Using 16bit native Automatic Mixed Precision (AMP)
    GPU available: True, used: True
    TPU available: False, using: 0 TPU cores
    IPU available: False, using: 0 IPUs
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

      | Name              | Type                         | Params
    -------------------------------------------------------------------
    0 | model             | HFAutoModelForTextPrediction | 13.5 M
    1 | validation_metric | Accuracy                     | 0
    2 | loss_func         | CrossEntropyLoss             | 0
    -------------------------------------------------------------------
    13.5 M    Trainable params
    0         Non-trainable params
    13.5 M    Total params
    26.969    Total estimated model params size (MB)
    Global seed set to 123
    Epoch 0, global step 1: val_accuracy reached 0.32000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=0-step=1.ckpt" as top 3
    Epoch 0, global step 3: val_accuracy reached 0.22000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=0-step=3.ckpt" as top 3
    Epoch 1, global step 5: val_accuracy reached 0.23000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=1-step=5.ckpt" as top 3
    Epoch 1, global step 7: val_accuracy was not in top 3
    Epoch 2, global step 9: val_accuracy was not in top 3
    Epoch 2, global step 11: val_accuracy reached 0.28000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=2-step=11.ckpt" as top 3
    Epoch 3, global step 13: val_accuracy reached 0.31000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=3-step=13.ckpt" as top 3
    Epoch 3, global step 15: val_accuracy reached 0.31000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=3-step=15.ckpt" as top 3
    Epoch 4, global step 17: val_accuracy was not in top 3
    Epoch 4, global step 19: val_accuracy reached 0.32000 (best 0.32000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=4-step=19.ckpt" as top 3
    Epoch 5, global step 21: val_accuracy reached 0.33000 (best 0.33000), saving model to "/var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/AutogluonModels/ag-20220312_054251/epoch=5-step=21.ckpt" as top 3
    Epoch 5, global step 23: val_accuracy was not in top 3
    Epoch 6, global step 25: val_accuracy was not in top 3
    Epoch 6, global step 27: val_accuracy was not in top 3
    Epoch 7, global step 29: val_accuracy was not in top 3
    Epoch 7, global step 31: val_accuracy was not in top 3
    Epoch 8, global step 33: val_accuracy was not in top 3
    Epoch 8, global step 35: val_accuracy was not in top 3
    Epoch 9, global step 37: val_accuracy was not in top 3
    Epoch 9, global step 39: val_accuracy was not in top 3
    Saving latest checkpoint...

With only ``hf_text`` in ``model.names``, ``AutoMMPredictor``
automatically uses only text and categorical (converted to text) data.

.. code:: python

    scores = predictor.evaluate(test_data, metrics=["accuracy"])
    scores

.. parsed-literal::
    :class: output

    {'accuracy': 0.19}

Configuration Customization
---------------------------

The above examples have shown the flexibility of ``AutoMMPredictor``.
You may want to know how to customize configurations for your tasks.
Fortunately, ``AutoMMPredictor`` has a user-friendly configuration
design.

First, let's see the available model presets.

.. code:: python

    from autogluon.text.automm.presets import list_model_presets, get_preset
    model_presets = list_model_presets()
    model_presets

.. parsed-literal::
    :class: output

    ['fusion_mlp_image_text_tabular']

Currently, ``AutoMMPredictor`` has only one model preset, from which we
can construct the predictor's preset.

.. code:: python

    preset = get_preset(model_presets[0])
    preset

.. parsed-literal::
    :class: output

    {'model': 'fusion_mlp_image_text_tabular',
     'data': 'default',
     'optimization': 'adamw',
     'environment': 'default'}

``AutoMMPredictor`` configurations consist of four parts: ``model``,
``data``, ``optimization``, and ``environment``. You can convert the
preset to configurations to see the details.

.. code:: python

    from omegaconf import OmegaConf
    from autogluon.text.automm.utils import get_config
    config = get_config(preset)
    print(OmegaConf.to_yaml(config))

.. parsed-literal::
    :class: output

    model:
      names:
      - categorical_mlp
      - numerical_mlp
      - hf_text
      - timm_image
      - clip
      - fusion_mlp
      categorical_mlp:
        hidden_size: 64
        activation: leaky_relu
        num_layers: 1
        drop_rate: 0.1
        normalization: layer_norm
        data_types:
        - categorical
      numerical_mlp:
        hidden_size: 128
        activation: leaky_relu
        num_layers: 1
        drop_rate: 0.1
        normalization: layer_norm
        data_types:
        - numerical
        merge: concat
      hf_text:
        checkpoint_name: google/electra-base-discriminator
        data_types:
        - text
        tokenizer_name: hf_auto
        max_text_len: 512
        insert_sep: true
        text_segment_num: 2
        stochastic_chunk: false
      timm_image:
        checkpoint_name: swin_base_patch4_window7_224
        mix_choice: all_logits
        data_types:
        - image
        train_transform_types:
        - resize_shorter_side
        - center_crop
        val_transform_types:
        - resize_shorter_side
        - center_crop
        image_norm: imagenet
        image_size: 224
        max_img_num_per_col: 2
      clip:
        checkpoint_name: openai/clip-vit-base-patch32
        data_types:
        - image
        - text
        train_transform_types:
        - resize_shorter_side
        - center_crop
        val_transform_types:
        - resize_shorter_side
        - center_crop
        image_norm: clip
        image_size: 224
        max_img_num_per_col: 1
        tokenizer_name: clip
        max_text_len: 77
        insert_sep: false
        text_segment_num: 1
        stochastic_chunk: false
      fusion_mlp:
        weight: 0.1
        adapt_in_features: max
        hidden_sizes:
        - 128
        activation: leaky_relu
        drop_rate: 0.1
        normalization: layer_norm
        data_types: null
    data:
      image: null
      text: null
      categorical:
        minimum_cat_count: 100
        maximum_num_cat: 20
        convert_to_text: true
      numerical:
        convert_to_text: false
        scaler_with_mean: true
        scaler_with_std: true
    optimization:
      optim_type: adamw
      learning_rate: 0.0001
      weight_decay: 0.001
      lr_choice: layerwise_decay
      lr_decay: 0.8
      lr_schedule: cosine_decay
      max_epochs: 10
      max_steps: -1
      warmup_steps: 0.1
      end_lr: 0
      lr_mult: 1
      patience: 10
      val_check_interval: 0.5
      top_k: 3
    env:
      num_gpus: -1
      num_nodes: 1
      batch_size: 128
      per_gpu_batch_size: 8
      per_gpu_batch_size_evaluation: 64
      precision: 16
      num_workers: 2
      num_workers_evaluation: 2
      fast_dev_run: false
      deterministic: false
      auto_select_gpus: true
      strategy: ddp_spawn

The ``model`` config provides six model types:

-  ``categorical_mlp``: an MLP for categorical data
-  ``numerical_mlp``: an MLP for numerical data
-  ``hf_text``: huggingface transformers for text data
-  ``timm_image``: timm backbones for image data
-  ``clip``: CLIP for image+text data
-  ``fusion_mlp``: an MLP to fuse any combination of
   ``categorical_mlp``, ``numerical_mlp``, ``hf_text``, and
   ``timm_image``

We can specify the model combination by setting ``model.names``.
Moreover, we can use ``model.hf_text.checkpoint_name`` and
``model.timm_image.checkpoint_name`` to customize the huggingface and
timm backbones.

The ``data`` config defines some model-agnostic rules for preprocessing
data. Note that ``AutoMMPredictor`` converts categorical data into text
by default.

The ``optimization`` config has hyper-parameters for model training.
``AutoMMPredictor`` uses layer-wise learning rate decay, which
decreases the learning rate gradually from the output end to the input
end of a model.

The ``env`` config contains the environment/machine related
hyper-parameters. For example, the optimal values of
``per_gpu_batch_size`` and ``per_gpu_batch_size_evaluation`` are
closely related to the GPU memory size.

You can flexibly customize any hyper-parameter in ``config`` via the
``hyperparameters`` argument of ``.fit()``. To access one
hyper-parameter in ``config``, traverse from the top-level key down to
the bottom-level key and join the keys with ``.``. For example, to
change the per-GPU batch size to 16, set
``hyperparameters={"env.per_gpu_batch_size": 16}``.

APIs
----

Besides ``.fit()`` and ``.evaluate()``, ``AutoMMPredictor`` also
provides other useful APIs, similar to those in ``TextPredictor`` and
``TabularPredictor``. See :ref:`sec_textprediction_beginner` for more
details.

Given data without ground truth labels, ``AutoMMPredictor`` can make
predictions.

.. code:: python

    predictions = predictor.predict(test_data.drop(columns=label_col))
    predictions[:5]

.. parsed-literal::
    :class: output

    1873     2
    8536     2
    7988     2
    10127    2
    14668    1
    Name: AdoptionSpeed, dtype: int64

For classification tasks, we can get the probabilities of all classes.

.. code:: python

    probas = predictor.predict_proba(test_data.drop(columns=label_col))
    probas[:5]

.. raw:: html

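
    <table>
      <thead>
        <tr><th></th><th>0</th><th>1</th><th>2</th><th>3</th><th>4</th></tr>
      </thead>
      <tbody>
        <tr><th>1873</th><td>0.014537</td><td>0.295644</td><td>0.310817</td><td>0.122612</td><td>0.256390</td></tr>
        <tr><th>8536</th><td>0.017496</td><td>0.269609</td><td>0.302648</td><td>0.154956</td><td>0.255291</td></tr>
        <tr><th>7988</th><td>0.019225</td><td>0.257543</td><td>0.369263</td><td>0.139917</td><td>0.214052</td></tr>
        <tr><th>10127</th><td>0.014647</td><td>0.264183</td><td>0.362841</td><td>0.116510</td><td>0.241819</td></tr>
        <tr><th>14668</th><td>0.016556</td><td>0.375061</td><td>0.286553</td><td>0.089543</td><td>0.232288</td></tr>
      </tbody>
    </table>
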
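
The class predictions printed earlier agree with picking the most
probable class in each row of this table. The sketch below checks that
in plain Python, using the five rows copied by hand. The names
``probas_rows`` and ``most_probable_class`` are illustrative and not
part of the ``AutoMMPredictor`` API, and it is an assumption that the
predictor derives labels in exactly this way internally.

.. code:: python

    # Class probabilities copied from the table above, keyed by row index.
    probas_rows = {
        1873: [0.014537, 0.295644, 0.310817, 0.122612, 0.256390],
        8536: [0.017496, 0.269609, 0.302648, 0.154956, 0.255291],
        7988: [0.019225, 0.257543, 0.369263, 0.139917, 0.214052],
        10127: [0.014647, 0.264183, 0.362841, 0.116510, 0.241819],
        14668: [0.016556, 0.375061, 0.286553, 0.089543, 0.232288],
    }

    def most_probable_class(row):
        """Return the position of the largest probability in one row."""
        return max(range(len(row)), key=lambda i: row[i])

    # Map every row index to its most probable class label.
    predicted = {idx: most_probable_class(row) for idx, row in probas_rows.items()}
    print(predicted)

The resulting labels can be compared directly with the first five
entries returned by ``.predict()``.
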
Note that calling ``.predict_proba`` on a regression task raises an
exception.

Extracting embeddings is straightforward with ``.extract_embedding()``.

.. code:: python

    embeddings = predictor.extract_embedding(test_data.drop(columns=label_col))
    embeddings.shape

.. parsed-literal::
    :class: output

    (100, 256)

It is also convenient to save and load a predictor.

.. code:: python

    predictor.save('my_saved_dir')
    loaded_predictor = AutoMMPredictor.load('my_saved_dir')
    scores2 = loaded_predictor.evaluate(test_data, metrics=["accuracy"])
    scores2

.. parsed-literal::
    :class: output

    {'accuracy': 0.19}
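
The dotted-key convention from the Configuration Customization section
can be illustrated without any AutoGluon dependency. The sketch below
uses a hypothetical helper, ``apply_overrides`` (not an AutoGluon
function), and a small abbreviated excerpt of the configuration shown
in that section. It only demonstrates how a key such as
``env.per_gpu_batch_size`` maps onto the nested structure; it is not
how ``AutoMMPredictor`` applies ``hyperparameters`` internally.

.. code:: python

    import copy

    def apply_overrides(config, overrides):
        """Return a copy of ``config`` with dotted-key overrides applied.

        Hypothetical helper for illustration only.
        """
        updated = copy.deepcopy(config)
        for dotted_key, value in overrides.items():
            # "env.per_gpu_batch_size" -> parents ["env"], leaf "per_gpu_batch_size"
            *parents, leaf = dotted_key.split(".")
            node = updated
            for key in parents:
                node = node[key]  # KeyError if a section does not exist
            if leaf not in node:
                raise KeyError(f"Unknown hyper-parameter: {dotted_key}")
            node[leaf] = value
        return updated

    # Abbreviated excerpt of the nested configuration (not the full config).
    config_excerpt = {
        "model": {
            "names": ["hf_text", "timm_image", "fusion_mlp"],
            "hf_text": {"checkpoint_name": "google/electra-base-discriminator"},
        },
        "env": {"per_gpu_batch_size": 8, "num_gpus": -1},
    }

    new_config = apply_overrides(config_excerpt, {
        "env.per_gpu_batch_size": 16,
        "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
    })
    print(new_config["env"]["per_gpu_batch_size"])  # 16

Working on a deep copy keeps the original excerpt unchanged, so several
override sets can be tried against the same baseline.
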