Single GPU Billion-scale Model Training via Parameter-Efficient Finetuning


As pointed out by a recent paper from the Stanford Institute for Human-Centered Artificial Intelligence, AI is undergoing a paradigm shift with the rise of "foundation models", i.e., giant models that are trained on a diverse collection of datasets, generally in a self-supervised way. These foundation models, which are key to AutoMM, can be easily adapted to downstream applications. However, as the size of these foundation models grows, finetuning them becomes increasingly difficult. The following figure from the Microsoft research blog demonstrates the trend:

Scaling of foundation models

The goal of AutoMM is to help anyone solve machine learning problems via open source foundation models, including these giant models. To finetune such large-scale models, we adopt the recently popularized parameter-efficient finetuning techniques. The idea is to either finetune a small subset of the weights in the foundation model (e.g., BitFit), or to add a tiny tunable structure on top of the fixed backbone (e.g., Prompt Tuning, LoRA, Adapter, MAM Adapter, IA^3). These techniques can effectively reduce peak memory usage and training time while maintaining performance.
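
To make the idea concrete, here is a minimal sketch (in plain PyTorch, not AutoMM internals) of the "bias-only plus IA^3-style scaling" recipe: the backbone weights are frozen, and only the bias terms and a tiny per-feature scaling vector receive gradients. All module and parameter names below are illustrative.

```python
import torch
import torch.nn as nn

class IA3Linear(nn.Module):
    """A linear layer whose output is rescaled by a small learnable vector (IA^3-style)."""
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear
        # One learnable scale per output feature, initialized to 1 so training starts at identity.
        self.ia3_scale = nn.Parameter(torch.ones(linear.out_features))

    def forward(self, x):
        return self.linear(x) * self.ia3_scale

# Toy "backbone": two layers, with the last one wrapped by an IA^3 scale.
backbone = nn.Sequential(nn.Linear(768, 768), nn.ReLU(), IA3Linear(nn.Linear(768, 2)))

# Freeze everything except biases (BitFit) and the IA^3 scaling vectors.
for name, param in backbone.named_parameters():
    param.requires_grad = name.endswith("bias") or "ia3_scale" in name

# Only the tiny subset of parameters with requires_grad=True is handed to the optimizer.
optimizer = torch.optim.AdamW(p for p in backbone.parameters() if p.requires_grad)
```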

In this tutorial, we introduce how to apply parameter-efficient finetuning in MultiModalPredictor. We first show how to adopt the "ia3_bias" algorithm for parameter-efficient finetuning. Afterwards, we show how to combine "ia3_bias" with gradient checkpointing to finetune the XL variant of Google's FLAN-T5 on a single NVIDIA T4 GPU.

Prepare Dataset

The Cross-Lingual Amazon Product Review Sentiment dataset contains Amazon product reviews in four languages. Here, we load the English, German, and Japanese folds of the dataset. In the label column, 0 means negative sentiment and 1 means positive sentiment. For the purpose of demonstration, we downsample the training data to 1000 samples. We will train the model on the English dataset and directly evaluate its performance on the German and Japanese test sets.

!wget --quiet https://automl-mm-bench.s3.amazonaws.com/multilingual-datasets/amazon_review_sentiment_cross_lingual.zip -O amazon_review_sentiment_cross_lingual.zip
!unzip -q -o amazon_review_sentiment_cross_lingual.zip -d .
import os
import shutil
os.environ["TRANSFORMERS_CACHE"] = "cache"

def clear_cache():
    if os.path.exists("cache"):
        shutil.rmtree("cache")

clear_cache()
import pandas as pd
import warnings
warnings.filterwarnings("ignore")

train_en_df = pd.read_csv("amazon_review_sentiment_cross_lingual/en_train.tsv",
                          sep="\t",
                          header=None,
                          names=["label", "text"]) \
                .sample(1000, random_state=123).reset_index(drop=True)

test_en_df = pd.read_csv("amazon_review_sentiment_cross_lingual/en_test.tsv",
                          sep="\t",
                          header=None,
                          names=["label", "text"]) \
               .sample(200, random_state=123).reset_index(drop=True)
test_de_df = pd.read_csv("amazon_review_sentiment_cross_lingual/de_test.tsv",
                          sep="\t", header=None, names=["label", "text"]) \
               .sample(200, random_state=123).reset_index(drop=True)

test_jp_df = pd.read_csv('amazon_review_sentiment_cross_lingual/jp_test.tsv',
                          sep='\t', header=None, names=['label', 'text']) \
               .sample(200, random_state=123).reset_index(drop=True)
train_en_df.head(5)
label text
0 0 This is a film that literally sees little wron...
1 0 This music is pretty intelligent, but not very...
2 0 One of the best pieces of rock ever recorded, ...
3 0 Reading the posted reviews here, is like revis...
4 1 I've just finished page 341, the last page. It...
test_jp_df.head(5)
label text
0 1 原作はビクトル・ユーゴの長編小説だが、私が子供の頃読んだのは短縮版の「ああ無情」。それでもこ...
1 1 ほかの作品のレビューにみんな書いているのに、何故この作品について書いている人が一人しかいない...
2 0 一番の問題点は青島が出ていない事でしょう。 TV番組では『芸人が出ていればバラエティだから...
3 0 昔、 りんたろう監督によるアニメ「カムイの剣」があった。 「カムイの剣」…を観た人なら本作...
4 1 以前のアルバムを聴いていないのでなんとも言えないが、クラシックなメタルを聞いてきた耳には、と...

Finetuning Multilingual Model with IA3 + BitFit

In AutoMM, to enable parameter-efficient finetuning, just set optim.peft to "ia3_bias".

from autogluon.multimodal import MultiModalPredictor
import uuid

model_path = f"./tmp/{uuid.uuid4().hex}-multilingual_ia3"
predictor = MultiModalPredictor(label="label",
                                path=model_path)
predictor.fit(train_en_df,
              presets="multilingual",
              hyperparameters={
                  "optim.peft": "ia3_bias",
                  "optim.lr_decay": 0.9,
                  "optim.lr": 3e-03,
                  "optim.end_lr": 3e-03,
                  "optim.max_epochs": 2,
                  "optim.warmup_steps": 0,
                  "env.batch_size": 32,
              })
=================== System Info ===================
AutoGluon Version:  1.3.1b20250509
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       28.40 GB / 30.95 GB (91.8%)
Disk Space Avail:   181.98 GB / 255.99 GB (71.1%)
===================================================
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
	2 unique label values:  [np.int64(0), np.int64(1)]
	If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/ca00ccef04614de58ca93daf636dd10c-multilingual_ia3
    ```
Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                         | Params | Mode 
---------------------------------------------------------------------------
0 | model             | HFAutoModelForTextPrediction | 278 M  | train
1 | validation_metric | BinaryAUROC                  | 0      | train
2 | loss_func         | CrossEntropyLoss             | 0      | train
---------------------------------------------------------------------------
122 K     Trainable params
278 M     Non-trainable params
278 M     Total params
1,112.955 Total estimated model params size (MB)
241       Modules in train mode
0         Modules in eval mode

Only a small fraction of the parameters are tunable (compare the trainable and total parameter counts in the model summary above). Notably, the model trained purely on English data achieves good performance on the test sets, even the German and Japanese ones, obtaining results comparable to full finetuning as in AutoMM for Text - Multilingual Problems.
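
If you want to verify this fraction yourself, the numbers reported in the Lightning summary can be reproduced with a generic helper like the one below. Accessing the model inside the predictor is an internal detail that may change across versions, so this is shown as a standalone utility for any torch.nn.Module rather than an official AutoMM API.

```python
import torch.nn as nn

def tunable_fraction(module: nn.Module) -> float:
    """Return the ratio of trainable parameters to total parameters of a torch module."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in module.parameters())
    return trainable / total

# Example with the toy PEFT backbone from the earlier sketch:
# print(f"{tunable_fraction(backbone):.4%}")
```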

score_in_en = predictor.evaluate(test_en_df)
score_in_de = predictor.evaluate(test_de_df)
score_in_jp = predictor.evaluate(test_jp_df)
print('Score in the English Testset:', score_in_en)
print('Score in the German Testset:', score_in_de)
print('Score in the Japanese Testset:', score_in_jp)
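
Besides aggregate metrics, you can also request per-row predictions with predict(). For example, on a few rows of the Japanese test set (output not shown here):

```python
# Per-example class predictions on a small slice of the Japanese test data.
predictions = predictor.predict(test_jp_df.drop(columns="label").head(3))
print(predictions)
```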

Training FLAN-T5-XL on Single GPU

By combining gradient checkpointing and parameter-efficient finetuning, it is feasible to finetune google/flan-t5-xl, which has close to two billion parameters, with a single T4 GPU available in AWS G4 instances. To turn on gradient checkpointing, you just need to set "model.hf_text.gradient_checkpointing" to True. To accelerate training, we downsample the training data to 200 samples.
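
For context, gradient checkpointing trades compute for memory: instead of storing intermediate activations for the backward pass, they are recomputed when needed. The AutoMM flag above enables this on the Hugging Face backbone; the snippet below is only a minimal illustration of the underlying mechanism with torch.utils.checkpoint, where block stands in for a transformer layer.

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x, weight):
    # Stand-in for one transformer layer.
    return torch.relu(x @ weight)

x = torch.randn(8, 512, requires_grad=True)
w = torch.randn(512, 512, requires_grad=True)

# The activations of `block` are not kept for backward; they are recomputed instead,
# lowering peak memory at the cost of extra forward computation.
y = checkpoint(block, x, w, use_reentrant=False)
y.sum().backward()
```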

# Just to free up disk space
clear_cache()
shutil.rmtree(model_path)
from autogluon.multimodal import MultiModalPredictor

train_en_df_downsample = train_en_df.sample(200, random_state=123)

new_model_path = f"./tmp/{uuid.uuid4().hex}-multilingual_ia3_gradient_checkpoint"
predictor = MultiModalPredictor(label="label",
                                path=new_model_path)
predictor.fit(train_en_df_downsample,
              presets="multilingual",
              hyperparameters={
                  "model.hf_text.checkpoint_name": "google/flan-t5-xl",
                  "model.hf_text.gradient_checkpointing": True,
                  "model.hf_text.low_cpu_mem_usage": True,
                  "optim.peft": "ia3_bias",
                  "optim.lr_decay": 0.9,
                  "optim.lr": 3e-03,
                  "optim.end_lr": 3e-03,
                  "optim.max_epochs": 1,
                  "optim.warmup_steps": 0,
                  "env.batch_size": 1,
                  "env.inference_batch_size_ratio": 1
              })

Global seed set to 123
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                         | Params
-------------------------------------------------------------------
0 | model             | HFAutoModelForTextPrediction | 1.2 B 
1 | validation_metric | AUROC                        | 0     
2 | loss_func         | CrossEntropyLoss             | 0     
-------------------------------------------------------------------
203 K     Trainable params
1.2 B     Non-trainable params
1.2 B     Total params
4,894.913 Total estimated model params size (MB)
Epoch 0, global step 20: 'val_roc_auc' reached 0.88802 (best 0.88802), saving model to '/home/ubuntu/autogluon/docs/tutorials/multimodal/advanced_topics/multilingual_ia3_gradient_checkpoint/epoch=0-step=20.ckpt' as top 1
Epoch 0, global step 40: 'val_roc_auc' reached 0.94531 (best 0.94531), saving model to '/home/ubuntu/autogluon/docs/tutorials/multimodal/advanced_topics/multilingual_ia3_gradient_checkpoint/epoch=0-step=40.ckpt' as top 1
`Trainer.fit` stopped: `max_epochs=1` reached.





<autogluon.multimodal.predictor.MultiModalPredictor at 0x7fd58c4dbca0>
score_in_en = predictor.evaluate(test_en_df)
print('Score in the English Testset:', score_in_en)
Score in the English Testset: {'roc_auc': 0.931263189629183}
# Just to free up disk space
clear_cache()
shutil.rmtree(new_model_path)

Other Examples

You may go to AutoMM Examples to explore other examples about AutoMM.

Customization

To learn how to customize AutoMM, please refer to Customize AutoMM.