Few Shot Learning with FewShotSVMPredictor#

In this tutorial we introduce a simple but effective approach to few-shot classification problems. We present the FusionSVM model, which leverages the high-quality features of foundation models and uses a simple SVM for the few-shot classification task. Specifically, we extract sample features with pretrained models and use those features to train an SVM. We demonstrate the effectiveness of this FusionSVM model on a text classification dataset and a vision classification dataset.
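
To make the recipe concrete before we dive into the API, below is a minimal sketch of the FusionSVM idea using scikit-learn alone. The random vectors are hypothetical stand-ins for the embeddings a pretrained backbone would produce; in the rest of the tutorial, AutoGluon extracts the real features for us.

```python
import numpy as np
from sklearn.preprocessing import normalize
from sklearn.svm import SVC

rng = np.random.default_rng(0)
# Hypothetical stand-ins for frozen foundation-model features:
# 4 classes x 10 shots, 768-dimensional embeddings.
features = rng.normal(size=(40, 768))
labels = np.repeat(["CCAT", "ECAT", "GCAT", "MCAT"], 10)

# Only the SVM is trained; the feature extractor stays frozen.
svm = SVC(kernel="linear")
svm.fit(normalize(features), labels)  # L2-normalized embeddings are common practice
print(svm.predict(normalize(rng.normal(size=(2, 768)))))
```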

Text Classification on MLDoc dataset#

Load Dataset#

We prepare all datasets in the format of pd.DataFrame, as many of our tutorials do. For this tutorial, we'll use a small MLDoc dataset for demonstration. It is a text classification dataset containing 4 classes, and we downsampled the training data to 10 samples per class, i.e., 10 shots. For more details regarding MLDoc, please see this link.

import pandas as pd
import os

from autogluon.core.utils.loaders import load_zip

download_dir = "./ag_automm_tutorial_fs_cls"
zip_file = "https://automl-mm-bench.s3.amazonaws.com/nlp_datasets/MLDoc-10shot-en.zip"
load_zip.unzip(zip_file, unzip_dir=download_dir)
dataset_path = os.path.join(download_dir)
train_df = pd.read_csv(f"{dataset_path}/train.csv", names=["label", "text"])
test_df = pd.read_csv(f"{dataset_path}/test.csv", names=["label", "text"])
print(train_df)
print(test_df)
Downloading ./ag_automm_tutorial_fs_cls/file.zip from https://automl-mm-bench.s3.amazonaws.com/nlp_datasets/MLDoc-10shot-en.zip...
   label                                               text
0   GCAT  b'Secretary-General Kofi Annan expressed conce...
1   CCAT  b'The health of ABB Asea Brown Boveri AG\'s Po...
2   GCAT  b'Nepali Prime Minister Lokendra Bahadur Chand...
3   CCAT  b'Integ Inc said Thursday its net loss widened...
4   GCAT  b'These are the leading stories in the Skopje ...
5   ECAT  b'Fears of a slowdown in India\'s industrial g...
6   MCAT  b'The Australian Treasury will offer a total o...
7   CCAT  b'Malaysia\'s Suria Capital Holdings Bhd and M...
8   MCAT  b'The UK gilt repo market had a quiet session ...
9   CCAT  b"Commonwealth Edison Co's (ComEd) 794 megawat...
10  GCAT  b'Police arrested 47 people on Thursday in a c...
11  GCAT  b"Army troops in the Comoros island of Anjouan...
12  ECAT  b"The House Banking Committee is considering w...
13  GCAT  b'A possible international anti-drug centre in...
14  ECAT  b'Angela Knight, economic secretary to the Bri...
15  GCAT  b'Nearly 300 people were feared dead in floods...
16  MCAT  b'The Oslo stock index fell with other Europea...
17  ECAT  b'Morgan Keegan said it won $18.540 million of...
18  CCAT  b'Britons can bank on the phone, bank on the i...
19  CCAT  b"Standard Chartered Bank and Prudential Secur...
20  CCAT  b"United Water Resources Inc said it and Lyonn...
21  ECAT  b'Tanzania on Thursday unveiled its 1997/98 bu...
22  GCAT  b'U.S. President Bill Clinton will meet Prime ...
23  CCAT  b"Pacific Century Regional Developments Ltd sa...
24  MCAT  b'The Athens bourse ended 0.65 percent lower w...
25  ECAT  b'Sri Lanka broad money supply, or M2, is seen...
26  GCAT  b'Collated results of African Nations Cup prel...
27  GCAT  b'Philippine President Fidel Ramos said on Fri...
28  MCAT  b'Shanghai copper futures ended down on heavy ...
29  CCAT  b"Goldman Sachs & Co said on Monday that David...
30  ECAT  b'Maine\'s revenues were higher than forecast ...
31  CCAT  b'Thai animal feedmillers said on Monday they ...
32  MCAT  b"Worldwide trading volume in emerging markets...
33  ECAT  b'One week ended June 25 daily avgs-millions  ...
34  ECAT  b'Algeria\'s non-energy exports reached $688 m...
35  ECAT  b'U.S. seasonally adjusted retail sales rose 1...
36  MCAT  b'The Indonesian rupiah weakened against the d...
37  MCAT  b'Brazilian stocks ended slightly higher led b...
38  MCAT  b'The price of gold hung around the psychologi...
39  MCAT  b'The won closed stronger versus the dollar on...
     label                                               text
0     CCAT  b'RJR Nabisco Holdings Corp has prevailed over...
1     ECAT  b"Britain's economy grew 0.8 percent in the fo...
2     ECAT  b'Slovenia\'s state Institute of Macroeconomic...
3     CCAT  b"Belgium's second largest bank Credit Communa...
4     GCAT  b'The IRA ordered its guerrillas to observe a ...
...    ...                                                ...
3995  CCAT  b"A consortium comprising Itochu Corp and Hanj...
3996  ECAT  b"The volume of Hong Kong's domestic exports i...
3997  ECAT  b'The Danish finance ministry said on Tuesday ...
3998  GCAT  b'A court is to investigate charges that forme...
3999  MCAT  b"German consumers of feed grains, bread rye a...

[4000 rows x 2 columns]
100%|██████████| 2.59M/2.59M [00:00<00:00, 26.4MiB/s]
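
A quick sanity check confirms the 10-shot setup: the training set has 4 classes with 10 samples each (4 × 10 = 40 rows), while the test set keeps its full 4000 rows.

```python
# Verify the few-shot statistics of the downsampled training set
print(train_df["label"].nunique())             # 4 classes
print(train_df["label"].value_counts().min())  # 10 samples (shots) per class
```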

Create the FewShotSVMPredictor#

To run the FusionSVM model, we first initialize a FewShotSVMPredictor with the following parameters.

from autogluon.multimodal.utils.few_shot_learning import FewShotSVMPredictor
hyperparameters = {
    "model.hf_text.checkpoint_name": "sentence-transformers/all-mpnet-base-v2",  # text backbone used to extract features
    "model.hf_text.pooling_mode": "mean",  # mean-pool token embeddings into one sentence vector
    "env.per_gpu_batch_size": 32,
    "env.eval_batch_size_ratio": 4,
}

import uuid
model_path = f"./tmp/{uuid.uuid4().hex}-automm_mldoc-10shot-en"
predictor = FewShotSVMPredictor(
    label="label",  # column name of the label
    hyperparameters=hyperparameters,
    eval_metric="acc",
    path=model_path  # path to save model and artifacts
)
/home/ci/autogluon/multimodal/src/autogluon/multimodal/data/utils.py:439: UserWarning: provided max length: 512 is smaller than sentence-transformers/all-mpnet-base-v2's default: 514
  warnings.warn(

Train the model#

Now we train the model with train_df.

predictor.fit(train_df)
Saving into /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/4385b40f8e4248c68d3f908038268a9d-automm_mldoc-10shot-en/svm_model.pkl

Run evaluation#

result = predictor.evaluate(test_df, metrics=["acc", "macro_f1"])
print(result)
{'acc': 0.83575, 'macro_f1': 0.8344679316932194}
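
Besides evaluate, we can inspect raw predictions for a quick sanity check; the snippet below assumes FewShotSVMPredictor.predict follows the standard AutoGluon predictor API and accepts the same DataFrame format used in fit.

```python
# Sketch: compare a handful of predictions against ground truth
preds = predictor.predict(test_df.head(5))
print(preds)                     # predicted labels (CCAT/ECAT/GCAT/MCAT)
print(test_df.head(5)["label"])  # ground truth
```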

Comparing with the standard MultiModalPredictor#

from autogluon.multimodal import MultiModalPredictor
import numpy as np
from sklearn.metrics import f1_score

hyperparameters = {
    "model.hf_text.checkpoint_name": "sentence-transformers/all-mpnet-base-v2",
    "model.hf_text.pooling_mode": "mean",
}

automm_predictor = MultiModalPredictor(
    label="label",
    problem_type="classification",
    eval_metric="acc"
)

automm_predictor.fit(
    train_data=train_df,
    presets="multilingual",
    hyperparameters=hyperparameters,
)

results, preds = automm_predictor.evaluate(test_df, return_pred=True)
test_labels = np.array(test_df["label"])
macro_f1 = f1_score(test_labels, preds, average="macro")
results["macro_f1"] = macro_f1

print(results)
{'acc': 0.659, 'macro_f1': 0.6650039441527744}
No path specified. Models will be saved in: "AutogluonModels/ag-20230302_162814/"
Detected data scarcity. Consider running using the preset 'few_shot_text_classification' for better performance.
AutoMM starts to create your model. ✨

- Model will be saved to "/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20230302_162814".

- Validation metric is "acc".

- To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20230302_162814
    ```

Enjoy your coffee, and let AutoMM do the job ☕☕☕ Learn more at https://auto.gluon.ai

/home/ci/autogluon/multimodal/src/autogluon/multimodal/data/utils.py:439: UserWarning: provided max length: 512 is smaller than sentence-transformers/all-mpnet-base-v2's default: 514
  warnings.warn(
/home/ci/autogluon/multimodal/src/autogluon/multimodal/utils/environment.py:102: UserWarning: bf16 is not supported by the GPU device / cuda version. Consider using GPU devices with versions after Amphere or upgrading cuda to be >=11.0. MultiModalPredictor is switching precision from bf16 to 32.
  warnings.warn(
/home/ci/opt/venv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1609: PossibleUserWarning: The number of training batches (8) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  rank_zero_warn(
AutoMM has created your model 🎉🎉🎉

- To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20230302_162814")
    ```

- You can open a terminal and launch Tensorboard to visualize the training log:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20230302_162814
    ```

- If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub: https://github.com/autogluon/autogluon

As you can see, the FewShotSVMPredictor (83.6% accuracy) performs much better than the standard MultiModalPredictor (65.9% accuracy) in this few-shot setting.
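
For reference, the same pipeline can be assembled by hand: embed every sample with a frozen backbone, then fit an SVM on the embeddings. Below is a rough sketch under the assumption that extract_embedding is available on MultiModalPredictor in your version (and that "feature_extraction" is a valid problem_type); FewShotSVMPredictor wraps these steps for you.

```python
from autogluon.multimodal import MultiModalPredictor
from sklearn.preprocessing import normalize
from sklearn.svm import SVC

# Frozen backbone used purely as a feature extractor (no fine-tuning).
embedder = MultiModalPredictor(
    problem_type="feature_extraction",  # assumption: supported in your version
    hyperparameters={"model.hf_text.checkpoint_name": "sentence-transformers/all-mpnet-base-v2"},
)
train_feat = normalize(embedder.extract_embedding(train_df[["text"]]))
test_feat = normalize(embedder.extract_embedding(test_df[["text"]]))

svm = SVC()
svm.fit(train_feat, train_df["label"])
print("accuracy:", (svm.predict(test_feat) == test_df["label"]).mean())
```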

Load a pretrained model#

The FewShotSVMPredictor automatically saves the model and artifacts to disk during training. You can specify the save path by setting path=<your_desired_save_path> when initializing the predictor. You can also load a pretrained FewShotSVMPredictor and perform downstream tasks with the following code:

predictor2 = FewShotSVMPredictor.load(model_path)
result2 = predictor2.evaluate(test_df, metrics=["acc", "macro_f1"])
print(result2)
{'acc': 0.83575, 'macro_f1': 0.8344679316932194}
/home/ci/autogluon/multimodal/src/autogluon/multimodal/data/utils.py:439: UserWarning: provided max length: 512 is smaller than sentence-transformers/all-mpnet-base-v2's default: 514
  warnings.warn(
Loading from ./tmp/4385b40f8e4248c68d3f908038268a9d-automm_mldoc-10shot-en/svm_model.pkl

Image Classification on Stanford Cars#

Load Dataset#

We also provide an example of using FewShotSVMPredictor on a few-shot image classification task. We use the Stanford Cars dataset for demonstration, downsampling the training set to 8 samples per class. Stanford Cars is an image classification dataset containing 196 classes. For more information regarding the dataset, please see here.

import pandas as pd
import os

from autogluon.core.utils.loaders import load_zip

download_dir = "./ag_automm_tutorial_fs_cls/stanfordcars/"
zip_file = "https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/stanfordcars.zip"
train_csv = "https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/train_8shot.csv"
test_csv = "https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/test.csv"

load_zip.unzip(zip_file, unzip_dir=download_dir)
dataset_path = os.path.join(download_dir)
Downloading ./ag_automm_tutorial_fs_cls/stanfordcars//file.zip from https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/stanfordcars.zip...
100%|██████████| 1.96G/1.96G [00:59<00:00, 33.1MiB/s]
Unzipping ./ag_automm_tutorial_fs_cls/stanfordcars//file.zip to ./ag_automm_tutorial_fs_cls/stanfordcars/
!wget https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/train_8shot.csv -O ./ag_automm_tutorial_fs_cls/stanfordcars/train.csv
!wget https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/test.csv -O ./ag_automm_tutorial_fs_cls/stanfordcars/test.csv
--2023-03-02 16:35:07--  https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/train_8shot.csv
Resolving automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)... 52.217.225.57, 52.217.234.1, 54.231.194.121, ...
Connecting to automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)|52.217.225.57|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 141918 (139K) [text/csv]
Saving to: ‘./ag_automm_tutorial_fs_cls/stanfordcars/train.csv’

./ag_automm_tutoria 100%[===================>] 138.59K  --.-KB/s    in 0.002s  

2023-03-02 16:35:07 (61.1 MB/s) - ‘./ag_automm_tutorial_fs_cls/stanfordcars/train.csv’ saved [141918/141918]

--2023-03-02 16:35:07--  https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/test.csv
Resolving automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)... 52.216.57.225, 3.5.1.140, 52.216.137.252, ...
Connecting to automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)|52.216.57.225|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 719335 (702K) [text/csv]
Saving to: ‘./ag_automm_tutorial_fs_cls/stanfordcars/test.csv’

./ag_automm_tutoria 100%[===================>] 702.48K  --.-KB/s    in 0.01s   

2023-03-02 16:35:07 (54.5 MB/s) - ‘./ag_automm_tutorial_fs_cls/stanfordcars/test.csv’ saved [719335/719335]
# The annotation CSVs carry detection-style columns that classification does not need
drop_cols = [
    "Source", "Confidence",
    "XMin", "XMax", "YMin", "YMax",
    "IsOccluded", "IsTruncated", "IsGroupOf", "IsDepiction", "IsInside",
]

def load_split(csv_name):
    df = pd.read_csv(os.path.join(download_dir, csv_name)).drop(columns=drop_cols)
    df["ImageID"] = download_dir + df["ImageID"].astype(str)  # expand to full image paths
    return df

train_df = load_split("train.csv")
test_df = load_split("test.csv")

print(os.path.exists(train_df.iloc[0]["ImageID"]))
print(train_df)
print(os.path.exists(test_df.iloc[0]["ImageID"]))
print(test_df)
True
                                                ImageID  LabelName
0     ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...        164
1     ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...          3
2     ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...        125
3     ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...         51
4     ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...        139
...                                                 ...        ...
1563  ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...        124
1564  ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...         94
1565  ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...          7
1566  ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...        174
1567  ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...        194

[1568 rows x 2 columns]
True
                                                ImageID  LabelName
0     ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...        181
1     ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...        124
2     ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...        189
3     ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...         97
4     ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...        121
...                                                 ...        ...
8036  ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...         66
8037  ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...        120
8038  ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...          8
8039  ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...         13
8040  ./ag_automm_tutorial_fs_cls/stanfordcars/stanf...        161

[8041 rows x 2 columns]
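
A similar check confirms the few-shot setup here: 196 classes with 8 training images each (196 × 8 = 1568 rows).

```python
# Verify the few-shot statistics of the downsampled training set
print(train_df["LabelName"].nunique())             # 196 classes
print(train_df["LabelName"].value_counts().min())  # 8 samples (shots) per class
```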

Create the FewShotSVMPredictor#

To run the FusionSVM model, we first initialize a FewShotSVMPredictor with the following parameters.

from autogluon.multimodal.utils.few_shot_learning import FewShotSVMPredictor
hyperparameters = {
    "model.names": ["clip"],  # use the CLIP backbone for feature extraction
    "model.clip.max_text_len": 0,  # image-only usage of CLIP: no text input
    "env.num_workers": 2,
    "model.clip.checkpoint_name": "openai/clip-vit-large-patch14-336",
    "env.eval_batch_size_ratio": 1,
}

import uuid
model_path = f"./tmp/{uuid.uuid4().hex}-automm_stanfordcars-8shot-en"
predictor = FewShotSVMPredictor(
    label="LabelName",  # column name of the label
    hyperparameters=hyperparameters,
    eval_metric="acc",
    path=model_path  # path to save model and artifacts
)
The model does not support using an image size that is different from the default size. Provided image size=224. Default size=336. Detailed model configuration=CLIPConfig {
  "_commit_hash": "ce19dc912ca5cd21c8a653c79e251e808ccabcd1",
  "_name_or_path": "openai/clip-vit-large-patch14-336",
  "architectures": [
    "CLIPModel"
  ],
  "initializer_factor": 1.0,
  "logit_scale_init_value": 2.6592,
  "model_type": "clip",
  "projection_dim": 768,
  "text_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.0,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 2,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 77,
    "min_length": 0,
    "model_type": "clip_text_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 12,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 12,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": 1,
    "prefix": null,
    "problem_type": null,
    "projection_dim": 768,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.26.1",
    "typical_p": 1.0,
    "use_bfloat16": false,
    "vocab_size": 49408
  },
  "text_config_dict": {
    "hidden_size": 768,
    "intermediate_size": 3072,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "projection_dim": 768
  },
  "torch_dtype": "float32",
  "transformers_version": null,
  "vision_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.0,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 1024,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 336,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "clip_vision_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 16,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_channels": 3,
    "num_hidden_layers": 24,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 14,
    "prefix": null,
    "problem_type": null,
    "projection_dim": 768,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.26.1",
    "typical_p": 1.0,
    "use_bfloat16": false
  },
  "vision_config_dict": {
    "hidden_size": 1024,
    "image_size": 336,
    "intermediate_size": 4096,
    "num_attention_heads": 16,
    "num_hidden_layers": 24,
    "patch_size": 14,
    "projection_dim": 768
  }
}
. We have ignored the provided image size.

Train the model#

Now we train the model with train_df.

predictor.fit(train_df)
Saving into /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/cf82299462be44e8ba488385f82165e9-automm_stanfordcars-8shot-en/svm_model.pkl

Run evaluation#

result = predictor.evaluate(test_df, metrics=["acc", "macro_f1"])
print(result)
{'acc': 0.814202213655018, 'macro_f1': 0.8138817523693703}
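
As with the text task, individual predictions can be sanity-checked; again assuming predict follows the standard AutoGluon predictor API and accepts the DataFrame of image paths.

```python
# Sketch: predicted class ids for a handful of test images
preds = predictor.predict(test_df.head(5))
print(preds)                         # predicted LabelName ids
print(test_df.head(5)["LabelName"])  # ground truth
```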

Comparing with the standard MultiModalPredictor#

from autogluon.multimodal import MultiModalPredictor
import numpy as np
from sklearn.metrics import f1_score


hyperparameters = {
    "model.names": ["timm_image"],
    "model.timm_image.checkpoint_name": "swin_base_patch4_window7_224",
    "env.per_gpu_batch_size": 8,
    "optimization.max_epochs": 10,
    "optimization.learning_rate": 1.0e-3,
    "optimization.optim_type": "adamw",
    "optimization.weight_decay": 1.0e-3,
}

automm_predictor = MultiModalPredictor(
    label="LabelName",  # column name of the label
    hyperparameters=hyperparameters,
    problem_type="classification",
    eval_metric="acc",
)
automm_predictor.fit(
    train_data=train_df,
)

results, preds = automm_predictor.evaluate(test_df, return_pred=True)
test_labels = np.array(test_df["LabelName"])
macro_f1 = f1_score(test_labels, preds, average="macro")
results["macro_f1"] = macro_f1

print(results)
{'acc': 0.20669071011068274, 'macro_f1': 0.18549266831995712}
No path specified. Models will be saved in: "AutogluonModels/ag-20230302_164500/"
AutoMM starts to create your model. ✨

- Model will be saved to "/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20230302_164500".

- Validation metric is "acc".

- To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20230302_164500
    ```

Enjoy your coffee, and let AutoMM do the job ☕☕☕ Learn more at https://auto.gluon.ai

/home/ci/opt/venv/lib/python3.8/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth" to /home/ci/.cache/torch/hub/checkpoints/swin_base_patch4_window7_224_22kto1k.pth
Start to fuse 3 checkpoints via the greedy soup algorithm.
AutoMM has created your model 🎉🎉🎉

- To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20230302_164500")
    ```

- You can open a terminal and launch Tensorboard to visualize the training log:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20230302_164500
    ```

- If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub: https://github.com/autogluon/autogluon

As you can see, the FewShotSVMPredictor (81.4% accuracy) again performs much better than the standard MultiModalPredictor (20.7% accuracy) on image classification.

Citation#

@inproceedings{SCHWENK18.658,
  author = {Holger Schwenk and Xian Li},
  title = {A Corpus for Multilingual Document Classification in Eight Languages},
  booktitle = {Proceedings of the Eleventh International Conference on Language Resources and Evaluation (LREC 2018)},
  year = {2018},
  month = {may},
  date = {7-12},
  location = {Miyazaki, Japan},
  editor = {Nicoletta Calzolari (Conference chair) and Khalid Choukri and Christopher Cieri and Thierry Declerck and Sara Goggi and Koiti Hasida and Hitoshi Isahara and Bente Maegaard and Joseph Mariani and Hélène Mazo and Asuncion Moreno and Jan Odijk and Stelios Piperidis and Takenobu Tokunaga},
  publisher = {European Language Resources Association (ELRA)},
  address = {Paris, France},
  isbn = {979-10-95546-00-9},
  language = {english}
}

@inproceedings{KrauseStarkDengFei-Fei_3DRR2013,
  title = {3D Object Representations for Fine-Grained Categorization},
  booktitle = {4th International IEEE Workshop on  3D Representation and Recognition (3dRR-13)},
  year = {2013},
  address = {Sydney, Australia},
  author = {Jonathan Krause and Michael Stark and Jia Deng and Li Fei-Fei}
}