Handling Class Imbalance with AutoMM - Focal Loss¶
In this tutorial, we introduce how to use focal loss with the AutoMM package for balanced training. Focal loss was first introduced in this paper and can be used to balance hard/easy samples as well as an uneven sample distribution among classes.
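For reference, the focal loss from the paper is

FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

where p_t is the predicted probability of the true class, alpha_t is the weight assigned to that class, and gamma controls how strongly easy (high-confidence) samples are down-weighted.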
Create Dataset¶
We use the Shopee dataset for demonstration in this tutorial. The Shopee dataset contains 4 classes, with 200 samples per class in the training set.
from autogluon.multimodal.utils.misc import shopee_dataset
download_dir = "./ag_automm_tutorial_imgcls_focalloss"
train_data, test_data = shopee_dataset(download_dir)
Downloading ./ag_automm_tutorial_imgcls_focalloss/file.zip from https://automl-mm-bench.s3.amazonaws.com/vision_datasets/shopee.zip...
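As a quick sanity check, you can confirm that the full training split is balanced before we make it imbalanced below. This snippet is not part of the original tutorial; it only assumes the pandas DataFrame returned by shopee_dataset:

# Each of the 4 classes should show 200 training samples.
print(train_data["label"].value_counts())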
To demonstrate the effectiveness of focal loss on imbalanced training data, we artificially downsample the Shopee training data to form an imbalanced class distribution.
import numpy as np
import pandas as pd

ds = 1
imbalanced_train_data = []
for lb in range(4):
    class_data = train_data[train_data.label == lb]
    sample_index = np.random.choice(np.arange(len(class_data)), size=int(len(class_data) * ds), replace=False)
    ds /= 3  # keep 1/3 as many samples for each successive class
    imbalanced_train_data.append(class_data.iloc[sample_index])
imbalanced_train_data = pd.concat(imbalanced_train_data)
print(imbalanced_train_data)
# Weight each class by the inverse of its sample frequency, then normalize.
weights = []
for lb in range(4):
    class_data = imbalanced_train_data[imbalanced_train_data.label == lb]
    weights.append(1 / (class_data.shape[0] / imbalanced_train_data.shape[0]))
    print(f"class {lb}: num samples {len(class_data)}")
weights = list(np.array(weights) / np.sum(weights))
print(weights)
image label
56 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
18 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
83 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
91 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
141 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
.. ... ...
623 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
788 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
658 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
702 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
676 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
[295 rows x 2 columns]
class 0: num samples 200
class 1: num samples 66
class 2: num samples 22
class 3: num samples 7
[np.float64(0.0239850482815907), np.float64(0.07268196448966878), np.float64(0.21804589346900635), np.float64(0.6852870937597342)]
Create and train MultiModalPredictor¶
Train with Focal Loss¶
We specify that the model should use focal loss by setting "optim.loss_func" to "focal_loss". There are also three optional parameters you can set.

optim.focal_loss.alpha - a list of floats giving the per-class loss weights, which can be used to balance an uneven sample distribution across classes. Note that the length of the list must match the total number of classes in the training dataset. A good way to compute alpha for each class is to use the inverse of its sample frequency, as we did above.

optim.focal_loss.gamma - a float that controls how much to focus on the hard samples. A larger value means more focus on the hard samples.

optim.focal_loss.reduction - how to aggregate the loss values. Can only take "mean" or "sum" for now.

A minimal sketch of how these parameters enter the loss computation follows.
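To make the roles of alpha, gamma, and reduction concrete, here is a minimal PyTorch sketch of a multi-class focal loss. This is an illustration of the formula above, not AutoMM's internal implementation; the function name focal_loss_sketch is ours:

import torch
import torch.nn.functional as F

def focal_loss_sketch(logits, targets, alpha, gamma=1.0, reduction="mean"):
    # log p_t: log-probability assigned to each sample's true class
    log_pt = F.log_softmax(logits, dim=-1).gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    # (1 - p_t)^gamma down-weights easy (high-confidence) samples;
    # alpha[targets] re-weights each sample by its class weight.
    loss = -alpha[targets] * (1 - pt) ** gamma * log_pt
    return loss.sum() if reduction == "sum" else loss.mean()

# Example: 3 samples, 4 classes, inverse-frequency class weights.
logits = torch.randn(3, 4)
targets = torch.tensor([0, 2, 3])
alpha = torch.tensor([0.02, 0.07, 0.22, 0.69])
print(focal_loss_sketch(logits, targets, alpha, gamma=1.0))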
import uuid
from autogluon.multimodal import MultiModalPredictor
model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee_focal"
predictor = MultiModalPredictor(label="label", problem_type="multiclass", path=model_path)
predictor.fit(
hyperparameters={
"model.mmdet_image.checkpoint_name": "swin_tiny_patch4_window7_224",
"env.num_gpus": 1,
"optim.loss_func": "focal_loss",
"optim.focal_loss.alpha": weights, # shopee dataset has 4 classes.
"optim.focal_loss.gamma": 1.0,
"optim.focal_loss.reduction": "sum",
"optim.max_epochs": 10,
},
train_data=imbalanced_train_data,
)
predictor.evaluate(test_data, metrics=["acc"])
=================== System Info ===================
AutoGluon Version: 1.3.1b20250509
Python Version: 3.11.9
Operating System: Linux
Platform Machine: x86_64
Platform Version: #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count: 8
Pytorch Version: 2.6.0+cu124
CUDA Version: 12.4
Memory Avail: 28.40 GB / 30.95 GB (91.8%)
Disk Space Avail: 166.39 GB / 255.99 GB (65.0%)
===================================================
AutoMM starts to create your model. ✨✨✨
To track the learning progress, you can open a terminal and launch Tensorboard:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/6176658ceffc4f59a78127852ce7cea7-automm_shopee_focal
```
Seed set to 0
Train without Focal Loss¶
import uuid
from autogluon.multimodal import MultiModalPredictor
model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee_non_focal"
predictor2 = MultiModalPredictor(label="label", problem_type="multiclass", path=model_path)
predictor2.fit(
hyperparameters={
"model.mmdet_image.checkpoint_name": "swin_tiny_patch4_window7_224",
"env.num_gpus": 1,
"optim.max_epochs": 10,
},
train_data=imbalanced_train_data,
)
predictor2.evaluate(test_data, metrics=["acc"])
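To compare the two models side by side, you can evaluate both on the same test set. This is a usage sketch, assuming evaluate() returns a dict keyed by the requested metric name; the exact numbers depend on your run:

acc_focal = predictor.evaluate(test_data, metrics=["acc"])["acc"]
acc_plain = predictor2.evaluate(test_data, metrics=["acc"])["acc"]
print(f"with focal loss: {acc_focal:.4f} | without: {acc_plain:.4f}")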
As we can see, the model trained with focal loss achieves much better performance than the model trained without it. When your data is imbalanced, try focal loss to see if it improves performance!
Citations¶
@misc{https://doi.org/10.48550/arxiv.1708.02002,
doi = {10.48550/ARXIV.1708.02002},
url = {https://arxiv.org/abs/1708.02002},
author = {Lin, Tsung-Yi and Goyal, Priya and Girshick, Ross and He, Kaiming and Dollár, Piotr},
keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences},
title = {Focal Loss for Dense Object Detection},
publisher = {arXiv},
year = {2017},
copyright = {arXiv.org perpetual, non-exclusive license}
}