Knowledge Distillation in AutoMM


Pretrained foundation models are becoming increasingly large. However, these models are difficult to deploy because of the limited resources available in deployment scenarios. To benefit from large models under this constraint, we can transfer the knowledge from a large-scale teacher model to a smaller student model via knowledge distillation. The small student model can then be deployed in real-world scenarios, and thanks to the teacher it often performs better than the same student finetuned without one.

In this tutorial, we show how to use MultiModalPredictor for knowledge distillation. For demonstration, we use the Question-answering NLI (QNLI) dataset, whose training split comprises 104,743 question-sentence pairs built from the SQuAD question answering dataset. We will demonstrate how to use a large model to guide the learning of a small model and improve its performance in AutoGluon.

Load Dataset

The QNLI dataset contains English sentence pairs, each made of a question and a sentence. In the label column, 0 (entailment) means that the sentence contains the answer to the question, and 1 (not_entailment) means that it does not.

import datasets
from datasets import load_dataset

datasets.logging.disable_progress_bar()

dataset = load_dataset("glue", "qnli")
dataset['train']
Dataset({
    features: ['question', 'sentence', 'label', 'idx'],
    num_rows: 104743
})
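from sklearn.model_selection import train_test_split

# Subsample 1,000 labeled training rows to keep the demo fast,
# then hold out 20% of them for validation (800 train / 200 validation rows).
train_valid_df = dataset["train"].to_pandas()[["question", "sentence", "label"]].sample(1000, random_state=123)
train_df, valid_df = train_test_split(train_valid_df, test_size=0.2, random_state=123)
# GLUE's test split is unlabeled, so 1,000 rows of QNLI's "validation" split serve as the test set.
test_df = dataset["validation"].to_pandas()[["question", "sentence", "label"]].sample(1000, random_state=123)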

Load the Teacher Model

In our example, we will directly load a teacher model with the google/bert_uncased_L-12_H-768_A-12 backbone that has been trained on QNLI and distill it into a student model with the google/bert_uncased_L-6_H-768_A-12 backbone.
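The checkpoint names encode the architecture: L is the number of Transformer layers, H the hidden size, and A the number of attention heads, so the student keeps the teacher's width but has half its depth. As a rough sketch, assuming the standard BERT configuration (30,522-token uncased vocabulary, 512 positions, 2 segment types, feed-forward size 4 * H, pooler included), the hypothetical helpers below estimate the size of each backbone. The estimates line up with the 109 M and 67.0 M parameters reported in the training log further down; AutoMM's classification head adds only a handful more.

```python
import re


def parse_bert_name(checkpoint_name: str) -> dict:
    """Extract layers (L), hidden size (H) and attention heads (A) from names
    like 'google/bert_uncased_L-12_H-768_A-12' (hypothetical helper)."""
    match = re.search(r"L-(\d+)_H-(\d+)_A-(\d+)", checkpoint_name)
    if match is None:
        raise ValueError(f"unrecognized checkpoint name: {checkpoint_name}")
    layers, hidden, heads = (int(g) for g in match.groups())
    return {"layers": layers, "hidden": hidden, "heads": heads}


def estimate_bert_params(layers: int, hidden: int, vocab_size: int = 30522,
                         max_positions: int = 512, type_vocab_size: int = 2) -> int:
    """Rough BERT encoder parameter count (embeddings + layers + pooler).
    Assumes a feed-forward size of 4 * hidden and biases on every projection.
    The number of heads does not change the count: heads only split `hidden`."""
    intermediate = 4 * hidden
    # Token, position and segment embeddings plus one LayerNorm (weight + bias).
    embeddings = (vocab_size + max_positions + type_vocab_size) * hidden + 2 * hidden
    # Q, K, V and output projections, each hidden x hidden with a bias.
    attention = 4 * (hidden * hidden + hidden)
    # Two feed-forward projections with biases.
    feed_forward = (hidden * intermediate + intermediate) + (intermediate * hidden + hidden)
    # Two LayerNorms per layer.
    layer_norms = 2 * 2 * hidden
    per_layer = attention + feed_forward + layer_norms
    pooler = hidden * hidden + hidden
    return embeddings + layers * per_layer + pooler


for name in ["google/bert_uncased_L-12_H-768_A-12", "google/bert_uncased_L-6_H-768_A-12"]:
    cfg = parse_bert_name(name)
    print(name, f"{estimate_bert_params(cfg['layers'], cfg['hidden']) / 1e6:.1f}M")
```

Under these assumptions the teacher comes out at about 109.5M parameters and the student at about 67.0M, i.e. roughly 61% of the teacher's size, with most of the savings coming from dropping six encoder layers while the embedding table stays the same.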

!wget --quiet https://automl-mm-bench.s3.amazonaws.com/unit-tests/distillation_sample_teacher.zip -O distillation_sample_teacher.zip
!unzip -q -o distillation_sample_teacher.zip -d .
from autogluon.multimodal import MultiModalPredictor

teacher_predictor = MultiModalPredictor.load("ag_distillation_sample_teacher/")
Start to upgrade the previous configuration trained by AutoMM version=0.5.3b20221108.
Load pretrained checkpoint: /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/ag_distillation_sample_teacher/model.ckpt

Distill to Student

Training the student model is straightforward: just pass the teacher_predictor argument when calling .fit(). Internally, the student is trained to fit the ground-truth labels while also matching the teacher's predictions and output features. This often lets it outperform the same student finetuned directly, without a teacher.
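The exact loss weighting AutoMM uses is not shown in this tutorial, although the training log below lists hard-label, soft-label, and output-feature loss modules among others. As a minimal, hypothetical sketch of the general idea, the objective below mixes a hard-label cross-entropy, a soft-label cross-entropy against the teacher's temperature-softened outputs (scaled by T^2, as in the formulation popularized by Hinton et al., 2015), and an MSE between output features. The temperature and weights are illustrative assumptions, not AutoMM defaults.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax with an optional temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(target_probs, student_logits, temperature=1.0):
    """Cross-entropy between a target distribution and the student's softmax."""
    probs = softmax(student_logits, temperature)
    return -sum(t * math.log(p) for t, p in zip(target_probs, probs) if t > 0)


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def distillation_loss(student_logits, teacher_logits, label,
                      student_feature, teacher_feature,
                      temperature=2.0, hard_weight=0.5, soft_weight=0.5, feature_weight=1.0):
    """Hypothetical combined distillation objective for one example."""
    # Hard-label term: ordinary cross-entropy against the ground-truth class.
    one_hot = [1.0 if i == label else 0.0 for i in range(len(student_logits))]
    hard = cross_entropy(one_hot, student_logits)
    # Soft-label term: match the teacher's softened distribution; the T^2 factor
    # keeps its gradient scale comparable to the hard-label term.
    soft_targets = softmax(teacher_logits, temperature)
    soft = cross_entropy(soft_targets, student_logits, temperature) * temperature ** 2
    # Feature term: pull the student's output features toward the teacher's.
    feature = mse(student_feature, teacher_feature)
    return hard_weight * hard + soft_weight * soft + feature_weight * feature


loss = distillation_loss(
    student_logits=[1.2, -0.3], teacher_logits=[2.0, -1.0], label=0,
    student_feature=[0.1, 0.4], teacher_feature=[0.2, 0.5],
)
print(f"combined loss: {loss:.4f}")
```

The soft-label term is what carries the teacher's "dark knowledge": even for a wrong class, the teacher's relative confidence gives the student a richer training signal than the one-hot label alone.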

student_predictor = MultiModalPredictor(label="label")
student_predictor.fit(
    train_df,
    tuning_data=valid_df,
    teacher_predictor=teacher_predictor,  # enables distillation from the loaded teacher
    hyperparameters={
        "model.hf_text.checkpoint_name": "google/bert_uncased_L-6_H-768_A-12",  # 6-layer student backbone
        "optimization.max_epochs": 2,  # kept short for the demo
    }
)
No path specified. Models will be saved in: "AutogluonModels/ag-20250107_030353"
=================== System Info ===================
AutoGluon Version:  1.2b20250107
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Tue Sep 24 10:00:37 UTC 2024
CPU Count:          8
Pytorch Version:    2.5.1+cu124
CUDA Version:       12.4
Memory Avail:       27.59 GB / 30.95 GB (89.2%)
Disk Space Avail:   168.62 GB / 255.99 GB (65.9%)
===================================================
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
2 unique label values:  [0, 1]
If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250107_030353
    ```
Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
GPU 0 Name: Tesla T4
GPU 0 Memory: 0.43GB/15.0GB (Used/Total)
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name                         | Type                         | Params | Mode 
--------------------------------------------------------------------------------------
0 | student_model                | HFAutoModelForTextPrediction | 67.0 M | train
1 | teacher_model                | HFAutoModelForTextPrediction | 109 M  | train
2 | validation_metric            | BinaryAUROC                  | 0      | train
3 | hard_label_loss_func         | CrossEntropyLoss             | 0      | train
4 | soft_label_loss_func         | CrossEntropyLoss             | 0      | train
5 | softmax_regression_loss_func | MSELoss                      | 0      | train
6 | output_feature_loss_func     | MSELoss                      | 0      | train
7 | output_feature_adaptor       | Identity                     | 0      | train
8 | rkd_loss_func                | RKDLoss                      | 0      | train
--------------------------------------------------------------------------------------
176 M     Trainable params
0         Non-trainable params
176 M     Total params
705.761   Total estimated model params size (MB)
239       Modules in train mode
120       Modules in eval mode
Epoch 0, global step 3: 'val_roc_auc' reached 0.65541 (best 0.65541), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250107_030353/epoch=0-step=3.ckpt' as top 3
Epoch 0, global step 7: 'val_roc_auc' reached 0.71088 (best 0.71088), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250107_030353/epoch=0-step=7.ckpt' as top 3
Epoch 1, global step 10: 'val_roc_auc' reached 0.71581 (best 0.71581), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250107_030353/epoch=1-step=10.ckpt' as top 3
Epoch 1, global step 14: 'val_roc_auc' reached 0.71631 (best 0.71631), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250107_030353/epoch=1-step=14.ckpt' as top 3
`Trainer.fit` stopped: `max_epochs=2` reached.
Start to fuse 3 checkpoints via the greedy soup algorithm.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250107_030353")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7fba51affb50>
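The log above mentions fusing 3 checkpoints via the greedy soup algorithm. As a simplified sketch of the general greedy-soup idea (not AutoMM's actual implementation), checkpoints are ranked by validation score, and each one is averaged into the "soup" only if doing so does not lower that score. Here a checkpoint is a hypothetical dict of scalar weights and `evaluate` is a stand-in for a validation metric such as ROC AUC, where higher is better.

```python
from typing import Callable, Dict, List

# Hypothetical representation: parameter name -> scalar weight.
Weights = Dict[str, float]


def average(weights_list: List[Weights]) -> Weights:
    """Uniformly average parameters across checkpoints."""
    keys = weights_list[0].keys()
    return {k: sum(w[k] for w in weights_list) / len(weights_list) for k in keys}


def greedy_soup(checkpoints: List[Weights], evaluate: Callable[[Weights], float]) -> Weights:
    """Greedily average checkpoints, keeping each only if validation does not drop."""
    # Start from the best single checkpoint.
    ranked = sorted(checkpoints, key=evaluate, reverse=True)
    soup = [ranked[0]]
    best_score = evaluate(ranked[0])
    for candidate in ranked[1:]:
        score = evaluate(average(soup + [candidate]))
        if score >= best_score:
            soup.append(candidate)
            best_score = score
    return average(soup)


def toy_score(weights: Weights) -> float:
    """Toy validation score that peaks at w = 1.0."""
    return -(weights["w"] - 1.0) ** 2


fused = greedy_soup([{"w": 0.8}, {"w": 1.3}, {"w": 3.0}], toy_score)
print(fused)
```

In this toy run, averaging w = 1.3 into the soup improves the score, while adding w = 3.0 would hurt it, so the result is the average of the first two checkpoints (w of about 1.05). Unlike a plain uniform average of all top checkpoints, the greedy variant never accepts a checkpoint that degrades validation performance.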
print(student_predictor.evaluate(data=test_df))
{'roc_auc': 0.8208768918708066}

More about Knowledge Distillation

To learn how to customize distillation and how it compares with direct finetuning, see the distillation examples and README in AutoMM Distillation Examples, especially the multilingual distillation example, which covers more details and customization options.

Other Examples

See AutoMM Examples to explore other AutoMM examples.

Customization

To learn how to customize AutoMM, please refer to Customize AutoMM.