Knowledge Distillation in AutoMM


Pretrained foundation models are becoming increasingly large. However, these models are difficult to deploy due to the limited computational resources available in many deployment scenarios. To benefit from large models under this constraint, you can transfer the knowledge from a large-scale teacher model to a smaller student model via knowledge distillation. The small student model can then be practically deployed in real-world scenarios, while performing better than a student trained from scratch thanks to the teacher's guidance.

In this tutorial, we introduce how to use MultiModalPredictor for knowledge distillation. For the purpose of demonstration, we use the Question-answering NLI (QNLI) dataset, which comprises 104,743 question-sentence pairs sampled from question answering datasets. We will demonstrate how to use a large model to guide the learning and improve the performance of a small model in AutoGluon.

Load Dataset

The QNLI dataset contains question-sentence pairs in English. In the label column, 0 means that the sentence contains the answer to the question (entailment), and 1 means that it does not (not entailment).

import datasets
from datasets import load_dataset

# Silence download progress bars for cleaner output.
datasets.logging.disable_progress_bar()

# Load the QNLI dataset from the GLUE benchmark.
dataset = load_dataset("glue", "qnli")
dataset['train']
from sklearn.model_selection import train_test_split

# Subsample 1,000 rows for a quick demonstration, then hold out 20% for validation.
train_valid_df = dataset["train"].to_pandas()[["question", "sentence", "label"]].sample(1000, random_state=123)
train_df, valid_df = train_test_split(train_valid_df, test_size=0.2, random_state=123)
test_df = dataset["validation"].to_pandas()[["question", "sentence", "label"]].sample(1000, random_state=123)
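
To sanity-check the splits, you can peek at a few rows and the label balance. This quick inspection is optional and not part of the distillation workflow itself.

# Optional sanity check: inspect a few rows and the label distribution.
print(train_df.head(3))
print(train_df["label"].value_counts())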

Load the Teacher Model

In our example, we will directly load a teacher model with the google/bert_uncased_L-12_H-768_A-12 backbone that has been trained on QNLI and distill it into a student model with the google/bert_uncased_L-6_H-768_A-12 backbone.

!wget --quiet https://automl-mm-bench.s3.amazonaws.com/unit-tests/distillation_sample_teacher.zip -O distillation_sample_teacher.zip
!unzip -q -o distillation_sample_teacher.zip -d .
from autogluon.multimodal import MultiModalPredictor

teacher_predictor = MultiModalPredictor.load("ag_distillation_sample_teacher/")
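
Before distilling, it is useful to record the teacher's accuracy on the test split as a reference point. This optional check uses the same evaluate API as the rest of the tutorial:

# Optional: measure the teacher's performance as an upper-bound reference.
print(teacher_predictor.evaluate(data=test_df))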

Distill to Student

Training the student model is straightforward: simply pass the teacher_predictor argument when calling .fit(). Internally, the student is trained to match the predictions and intermediate feature maps of the teacher, which typically performs better than finetuning the student directly.

student_predictor = MultiModalPredictor(label="label")
student_predictor.fit(
    train_df,
    tuning_data=valid_df,
    teacher_predictor=teacher_predictor,
    hyperparameters={
        "model.hf_text.checkpoint_name": "google/bert_uncased_L-6_H-768_A-12",
        "optim.max_epochs": 2,
    }
)
print(student_predictor.evaluate(data=test_df))
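
To quantify the benefit of distillation, you can train an identical student without passing teacher_predictor and compare the two scores. A minimal sketch, reusing the same settings as above:

# Sketch: train the same student backbone without a teacher for comparison.
scratch_predictor = MultiModalPredictor(label="label")
scratch_predictor.fit(
    train_df,
    tuning_data=valid_df,
    hyperparameters={
        "model.hf_text.checkpoint_name": "google/bert_uncased_L-6_H-768_A-12",
        "optim.max_epochs": 2,
    }
)
print(scratch_predictor.evaluate(data=test_df))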

More about Knowledge Distillation

To learn how to customize distillation and how it compares with direct finetuning, see the distillation examples and README in AutoMM Distillation Examples. In particular, the multilingual distillation example provides more details and customization options.
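
As a rough illustration of the kind of knobs those examples expose, a customized distillation run might look like the sketch below. The distiller.* keys are taken from the AutoMM distillation examples and may differ across AutoGluon versions, so treat them as assumptions and verify the exact names in the README.

# Illustrative sketch only: the distiller.* keys follow the AutoMM distillation
# examples and may vary by AutoGluon version; check the README before relying on them.
student_predictor = MultiModalPredictor(label="label")
student_predictor.fit(
    train_df,
    tuning_data=valid_df,
    teacher_predictor=teacher_predictor,
    hyperparameters={
        "model.hf_text.checkpoint_name": "google/bert_uncased_L-6_H-768_A-12",
        "optim.max_epochs": 2,
        "distiller.temperature": 5.0,        # soften the teacher's logits
        "distiller.hard_label_weight": 0.1,  # weight of the ground-truth loss
        "distiller.soft_label_weight": 1.0,  # weight of the teacher-matching loss
    }
)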

Other Examples

You may go to AutoMM Examples to explore other examples about AutoMM.

Customization

To learn how to customize AutoMM, please refer to Customize AutoMM.