Knowledge Distillation in AutoMM¶
Pretrained foundation models are becoming increasingly large. However, these models are difficult to deploy due to the limited resources available in deployment scenarios. To benefit from large models under this constraint, we can transfer the knowledge from a large-scale teacher model to a smaller student model via knowledge distillation. The small student model can then be practically deployed in real-world scenarios, while performing better than a student trained from scratch thanks to the teacher's guidance.
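At its core, distillation augments the usual hard-label loss with a term that pulls the student's predictions toward the teacher's. Below is a minimal sketch of the classic soft-label formulation in PyTorch. It is illustrative only, not AutoMM's internal distillation loss, and the temperature and weighting values are arbitrary.
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    # Soft term: KL divergence between temperature-softened distributions.
    # Scaling by T*T keeps gradient magnitudes comparable across temperatures.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    # Hard term: ordinary cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard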
In this tutorial, we introduce how to use MultiModalPredictor
for knowledge distillation. For the purpose of demonstration, we use the Question-answering NLI (QNLI) dataset,
which comprises 104,743 question-sentence pairs sampled from question answering datasets. We will demonstrate how to use a large model to guide the learning and improve the performance of a small model in AutoGluon.
Load Dataset¶
The Question-answering NLI dataset contains
question-sentence pairs in English. In the label column, 0
(entailment) means that the sentence contains the answer to the question and 1
(not entailment) means that it does not.
import datasets
from datasets import load_dataset

# Silence download progress bars for cleaner logs.
datasets.logging.disable_progress_bar()

# Load the QNLI task of the GLUE benchmark from the Hugging Face Hub.
dataset = load_dataset("glue", "qnli")
dataset['train']
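You can confirm the label encoding by inspecting the split's ClassLabel feature (the names below come from the GLUE QNLI dataset card):
# ClassLabel mapping: index 0 -> "entailment", index 1 -> "not_entailment".
print(dataset["train"].features["label"].names)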
from sklearn.model_selection import train_test_split

# Subsample 1,000 rows for a quick demo, then hold out 20% for validation.
train_valid_df = dataset["train"].to_pandas()[["question", "sentence", "label"]].sample(1000, random_state=123)
train_df, valid_df = train_test_split(train_valid_df, test_size=0.2, random_state=123)

# GLUE's test labels are hidden, so we sample the validation split as our test set.
test_df = dataset["validation"].to_pandas()[["question", "sentence", "label"]].sample(1000, random_state=123)
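Optionally, a quick sanity check on the class balance of the sampled split (purely illustrative; the exact proportions depend on the random sample):
# QNLI is roughly balanced, so both classes should appear with similar frequency.
print(train_df["label"].value_counts(normalize=True))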
Load the Teacher Model¶
In our example, we will directly load a teacher model with the google/bert_uncased_L-12_H-768_A-12 backbone that has been trained on QNLI and distill it into a student model with the google/bert_uncased_L-6_H-768_A-12 backbone.
!wget --quiet https://automl-mm-bench.s3.amazonaws.com/unit-tests/distillation_sample_teacher.zip -O distillation_sample_teacher.zip
!unzip -q -o distillation_sample_teacher.zip -d .
from autogluon.multimodal import MultiModalPredictor
teacher_predictor = MultiModalPredictor.load("ag_distillation_sample_teacher/")
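Before distilling, it can help to evaluate the teacher on the same test split to establish the score the student is chasing (the exact numbers depend on the sample drawn above):
# Baseline: how well does the teacher do on our held-out test sample?
print(teacher_predictor.evaluate(data=test_df))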
Distill to Student¶
Training the student model is straightforward: just add the teacher_predictor
argument when calling .fit().
Internally, the student is trained to match the teacher's predictions and feature maps, which can perform better than
directly fine-tuning the student on the labels alone.
student_predictor = MultiModalPredictor(label="label")
student_predictor.fit(
    train_df,
    tuning_data=valid_df,
    teacher_predictor=teacher_predictor,  # enables distillation from the loaded teacher
    hyperparameters={
        "model.hf_text.checkpoint_name": "google/bert_uncased_L-6_H-768_A-12",  # 6-layer student backbone
        "optim.max_epochs": 2,
    }
)
print(student_predictor.evaluate(data=test_df))
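To see the benefit of the teacher, you can train the same student backbone without distillation and compare the scores. This is a sketch; with only 2 epochs and 800 training rows, the gap will vary from run to run.
# Same student backbone, same budget, but no teacher_predictor.
nodistill_predictor = MultiModalPredictor(label="label")
nodistill_predictor.fit(
    train_df,
    tuning_data=valid_df,
    hyperparameters={
        "model.hf_text.checkpoint_name": "google/bert_uncased_L-6_H-768_A-12",
        "optim.max_epochs": 2,
    },
)
print(nodistill_predictor.evaluate(data=test_df))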
More about Knowledge Distillation¶
To learn how to customize distillation and how it compares with direct finetuning, see the distillation examples and README in AutoMM Distillation Examples. In particular, the multilingual distillation example provides more details and customization options.
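As a taste of what customization looks like, the distillation examples expose distiller.* hyperparameters such as the softmax temperature and the relative weights of the hard- and soft-label losses. The key names and values below follow the examples' README and are assumptions; check them against the distiller config of your installed AutoGluon version.
# Assumed distiller.* keys (verify against your AutoGluon version's distiller config).
student_predictor = MultiModalPredictor(label="label")
student_predictor.fit(
    train_df,
    tuning_data=valid_df,
    teacher_predictor=teacher_predictor,
    hyperparameters={
        "model.hf_text.checkpoint_name": "google/bert_uncased_L-6_H-768_A-12",
        "optim.max_epochs": 2,
        "distiller.temperature": 5.0,        # soften teacher/student logits
        "distiller.hard_label_weight": 0.1,  # weight of the ground-truth loss
        "distiller.soft_label_weight": 1.0,  # weight of the teacher-matching loss
    },
)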
Other Examples¶
You may go to AutoMM Examples to explore other examples about AutoMM.
Customization¶
To learn how to customize AutoMM, please refer to Customize AutoMM.