Single GPU Billion-scale Model Training via Parameter-Efficient Finetuning


As pointed out by a recent paper from the Stanford Institute for Human-Centered Artificial Intelligence, AI is undergoing a paradigm shift with the rise of "foundation models", i.e., giant models that are trained on a diverse collection of datasets, generally in a self-supervised way. These foundation models, which are at the core of AutoMM, can be easily adapted to downstream applications. However, as the size of these foundation models grows, finetuning them becomes increasingly difficult. The following figure from the Microsoft research blog illustrates the trend:

Scaling of foundation models

The goal of AutoMM is to help anyone solve machine learning problems via open source foundation models, including these giant models. To finetune such large-scale models, we adopt the recently popularized parameter-efficient finetuning technique. The idea is to either finetune a small subset of the weights in the foundation model (e.g., BitFit), or to add a tiny tunable structure on top of the fixed backbone (e.g., Prompt Tuning, LoRA, Adapter, MAM Adapter, IA^3). These techniques can effectively reduce peak memory usage and training time while maintaining performance.
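
To make the idea concrete, below is a minimal, self-contained PyTorch sketch (not AutoMM's internal implementation) of what "ia3_bias" roughly does: freeze the backbone, keep only bias terms trainable (BitFit), and attach small learnable scaling vectors to selected linear outputs (IA^3). The toy layer sizes and module names are illustrative only.

import torch
import torch.nn as nn

class IA3ScaledLinear(nn.Module):
    """Wrap a frozen nn.Linear and rescale its output with a learnable vector (IA^3-style)."""
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear
        # The only new parameters: one scale per output feature, initialized to 1.
        self.ia3_scale = nn.Parameter(torch.ones(linear.out_features))

    def forward(self, x):
        return self.linear(x) * self.ia3_scale

# A toy two-layer "backbone" standing in for a transformer feed-forward block.
backbone = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768))

# BitFit: freeze every weight, keep only the bias terms trainable.
for name, param in backbone.named_parameters():
    param.requires_grad = name.endswith("bias")

# IA^3: add a learnable scaling vector on top of one of the frozen linear layers.
backbone[0] = IA3ScaledLinear(backbone[0])

trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
total = sum(p.numel() for p in backbone.parameters())
print(f"Trainable params: {trainable} / {total} ({100 * trainable / total:.3f}%)")

Only the biases and the scaling vectors receive gradient updates; the rest of the backbone stays fixed, which is why both peak memory and training time drop.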

In this tutorial, we introduce how to apply parameter-efficient finetuning in MultiModalPredictor. We first show how to adopt the "ia3_bias" algorithm for parameter-efficient finetuning. Afterwards, we show how to combine "ia3_bias" with gradient checkpointing to finetune the XL variant of Google's FLAN-T5 on a single NVIDIA T4 GPU.

Prepare Dataset

The Cross-Lingual Amazon Product Review Sentiment dataset contains Amazon product reviews in four languages. Here, we load the English, German, and Japanese portions of the dataset. In the label column, 0 means negative sentiment and 1 means positive sentiment. For the purpose of demonstration, we downsample the training data to 1,000 samples. We will train the model on the English data and directly evaluate its performance on the German and Japanese test sets.

!wget --quiet https://automl-mm-bench.s3.amazonaws.com/multilingual-datasets/amazon_review_sentiment_cross_lingual.zip -O amazon_review_sentiment_cross_lingual.zip
!unzip -q -o amazon_review_sentiment_cross_lingual.zip -d .
import os
import shutil
os.environ["TRANSFORMERS_CACHE"] = "cache"

def clear_cache():
    if os.path.exists("cache"):
        shutil.rmtree("cache")

clear_cache()
import pandas as pd
import warnings
warnings.filterwarnings("ignore")

train_en_df = pd.read_csv("amazon_review_sentiment_cross_lingual/en_train.tsv",
                          sep="\t",
                          header=None,
                          names=["label", "text"]) \
                .sample(1000, random_state=123).reset_index(drop=True)

test_en_df = pd.read_csv("amazon_review_sentiment_cross_lingual/en_test.tsv",
                          sep="\t",
                          header=None,
                          names=["label", "text"]) \
               .sample(200, random_state=123).reset_index(drop=True)
test_de_df = pd.read_csv("amazon_review_sentiment_cross_lingual/de_test.tsv",
                          sep="\t", header=None, names=["label", "text"]) \
               .sample(200, random_state=123).reset_index(drop=True)

test_jp_df = pd.read_csv('amazon_review_sentiment_cross_lingual/jp_test.tsv',
                          sep='\t', header=None, names=['label', 'text']) \
               .sample(200, random_state=123).reset_index(drop=True)
train_en_df.head(5)
label text
0 0 This is a film that literally sees little wron...
1 0 This music is pretty intelligent, but not very...
2 0 One of the best pieces of rock ever recorded, ...
3 0 Reading the posted reviews here, is like revis...
4 1 I've just finished page 341, the last page. It...
test_jp_df.head(5)
label text
0 1 原作はビクトル・ユーゴの長編小説だが、私が子供の頃読んだのは短縮版の「ああ無情」。それでもこ...
1 1 ほかの作品のレビューにみんな書いているのに、何故この作品について書いている人が一人しかいない...
2 0 一番の問題点は青島が出ていない事でしょう。 TV番組では『芸人が出ていればバラエティだから...
3 0 昔、 りんたろう監督によるアニメ「カムイの剣」があった。 「カムイの剣」…を観た人なら本作...
4 1 以前のアルバムを聴いていないのでなんとも言えないが、クラシックなメタルを聞いてきた耳には、と...

Finetuning Multilingual Model with IA3 + BitFit

In AutoMM, to enable efficient finetuning, just set optimization.efficient_finetune to "ia3_bias".

from autogluon.multimodal import MultiModalPredictor
import uuid

model_path = f"./tmp/{uuid.uuid4().hex}-multilingual_ia3"
predictor = MultiModalPredictor(label="label",
                                path=model_path)
predictor.fit(train_en_df,
              presets="multilingual",
              hyperparameters={
                  "optimization.efficient_finetune": "ia3_bias",
                  "optimization.lr_decay": 0.9,
                  "optimization.learning_rate": 3e-03,
                  "optimization.end_lr": 3e-03,
                  "optimization.max_epochs": 2,
                  "optimization.warmup_steps": 0,
                  "env.batch_size": 32,
              })
=================== System Info ===================
AutoGluon Version:  1.2b20250107
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Tue Sep 24 10:00:37 UTC 2024
CPU Count:          8
Pytorch Version:    2.5.1+cu124
CUDA Version:       12.4
Memory Avail:       28.49 GB / 30.95 GB (92.0%)
Disk Space Avail:   184.44 GB / 255.99 GB (72.0%)
===================================================
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
2 unique label values:  [0, 1]
If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/a95ba07740c04822b076d34e25415d63-multilingual_ia3
    ```
Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
GPU 0 Name: Tesla T4
GPU 0 Memory: 0.43GB/15.0GB (Used/Total)
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name              | Type                         | Params | Mode 
---------------------------------------------------------------------------
0 | model             | HFAutoModelForTextPrediction | 278 M  | train
1 | validation_metric | BinaryAUROC                  | 0      | train
2 | loss_func         | CrossEntropyLoss             | 0      | train
---------------------------------------------------------------------------
122 K     Trainable params
278 M     Non-trainable params
278 M     Total params
1,112.955 Total estimated model params size (MB)
28        Modules in train mode
213       Modules in eval mode

The fraction of tunable parameters is tiny, well under 1% of all parameters (122K trainable out of 278M total in the summary above). Even though the model is trained purely on English data, it can achieve good performance on the test sets, including the German and Japanese ones. It obtains results comparable to full finetuning, as shown in AutoMM for Text - Multilingual Problems.

score_in_en = predictor.evaluate(test_en_df)
score_in_de = predictor.evaluate(test_de_df)
score_in_jp = predictor.evaluate(test_jp_df)
print('Score in the English Testset:', score_in_en)
print('Score in the German Testset:', score_in_de)
print('Score in the Japanese Testset:', score_in_jp)
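
Besides aggregate metrics, you can also inspect per-sample outputs with predict and predict_proba; a short example:

# Predicted labels (0 = negative, 1 = positive) and class probabilities for a few Japanese reviews.
print(predictor.predict(test_jp_df.head(3)))
print(predictor.predict_proba(test_jp_df.head(3)))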

Training FLAN-T5-XL on Single GPU

By combining gradient checkpointing and parameter-efficient finetuning, it is feasible to finetune google/flan-t5-xl, which has close to two billion parameters, with a single T4 GPU available in AWS G4 instances. To turn on gradient checkpointing, you just need to set "model.hf_text.gradient_checkpointing" to True. To accelerate training, we downsample the training data to 200 samples.
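
For intuition on what the gradient checkpointing flag buys you, here is a small plain-PyTorch sketch (independent of AutoMM, with a made-up toy block): activations inside a checkpointed segment are not cached during the forward pass but recomputed during backprop, trading extra compute for lower peak memory.

import torch
from torch.utils.checkpoint import checkpoint

# A toy feed-forward block standing in for one transformer layer.
block = torch.nn.Sequential(torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024))
x = torch.randn(8, 1024, requires_grad=True)

# Intermediate activations of `block` are recomputed in the backward pass instead of being stored.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()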

# Just to free up disk space
clear_cache()
shutil.rmtree(model_path)
from autogluon.multimodal import MultiModalPredictor

train_en_df_downsample = train_en_df.sample(200, random_state=123)

new_model_path = f"./tmp/{uuid.uuid4().hex}-multilingual_ia3_gradient_checkpoint"
predictor = MultiModalPredictor(label="label",
                                path=new_model_path)
predictor.fit(train_en_df_downsample,
              presets="multilingual",
              hyperparameters={
                  "model.hf_text.checkpoint_name": "google/flan-t5-xl",
                  "model.hf_text.gradient_checkpointing": True,
                  "model.hf_text.low_cpu_mem_usage": True,
                  "optimization.efficient_finetune": "ia3_bias",
                  "optimization.lr_decay": 0.9,
                  "optimization.learning_rate": 3e-03,
                  "optimization.end_lr": 3e-03,
                  "optimization.max_epochs": 1,
                  "optimization.warmup_steps": 0,
                  "env.batch_size": 1,
                  "env.eval_batch_size_ratio": 1
              })

Global seed set to 123
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                         | Params
-------------------------------------------------------------------
0 | model             | HFAutoModelForTextPrediction | 1.2 B 
1 | validation_metric | AUROC                        | 0     
2 | loss_func         | CrossEntropyLoss             | 0     
-------------------------------------------------------------------
203 K     Trainable params
1.2 B     Non-trainable params
1.2 B     Total params
4,894.913 Total estimated model params size (MB)
Epoch 0, global step 20: 'val_roc_auc' reached 0.88802 (best 0.88802), saving model to '/home/ubuntu/autogluon/docs/tutorials/multimodal/advanced_topics/multilingual_ia3_gradient_checkpoint/epoch=0-step=20.ckpt' as top 1
Epoch 0, global step 40: 'val_roc_auc' reached 0.94531 (best 0.94531), saving model to '/home/ubuntu/autogluon/docs/tutorials/multimodal/advanced_topics/multilingual_ia3_gradient_checkpoint/epoch=0-step=40.ckpt' as top 1
`Trainer.fit` stopped: `max_epochs=1` reached.





<autogluon.multimodal.predictor.MultiModalPredictor at 0x7fd58c4dbca0>
score_in_en = predictor.evaluate(test_en_df)
print('Score in the English Testset:', score_in_en)
Score in the English Testset: {'roc_auc': 0.931263189629183}
# Just to free up disk space
clear_cache()
shutil.rmtree(new_model_path)

Other Examples

You may go to AutoMM Examples to explore other examples about AutoMM.

Customization

To learn how to customize AutoMM, please refer to Customize AutoMM.