How to use FocalLoss#
In this tutorial, we show how to use FocalLoss from the AutoMM package for balanced training.
FocalLoss was first introduced in the paper "Focal Loss for Dense Object Detection" (Lin et al., 2017; see Citations) and can be used to balance hard and easy samples as well as an uneven sample distribution among classes.
Create Dataset#
We use the Shopee dataset for demonstration in this tutorial. The Shopee dataset contains 4 classes, each with 200 samples in the training set.
from autogluon.multimodal.utils.misc import shopee_dataset
download_dir = "./ag_automm_tutorial_imgcls_focalloss"
train_data, test_data = shopee_dataset(download_dir)
Downloading ./ag_automm_tutorial_imgcls_focalloss/file.zip from https://automl-mm-bench.s3.amazonaws.com/vision_datasets/shopee.zip...
100%|██████████| 41.9M/41.9M [00:00<00:00, 46.9MiB/s]
To demonstrate the effectiveness of FocalLoss on imbalanced training data, we artificially downsample the Shopee training data to form an imbalanced distribution: each class keeps one third as many samples as the class before it.
import numpy as np
import pandas as pd

ds = 1  # fraction of the current class to keep
imbalanced_train_data = []
for lb in range(4):
    class_data = train_data[train_data.label == lb]
    sample_index = np.random.choice(np.arange(len(class_data)), size=int(len(class_data) * ds), replace=False)
    ds /= 3  # downsample 1/3 each time for each class
    imbalanced_train_data.append(class_data.iloc[sample_index])
imbalanced_train_data = pd.concat(imbalanced_train_data)
print(imbalanced_train_data)

# Per-class weights: inverse of each class's share of the samples, normalized to sum to 1.
weights = []
for lb in range(4):
    class_data = imbalanced_train_data[imbalanced_train_data.label == lb]
    weights.append(1 / (class_data.shape[0] / imbalanced_train_data.shape[0]))
    print(f"class {lb}: num samples {len(class_data)}")
weights = list(np.array(weights) / np.sum(weights))
print(weights)
image label
167 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
189 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
12 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
178 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
149 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
.. ... ...
725 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
783 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
653 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
771 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
708 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
[295 rows x 2 columns]
class 0: num samples 200
class 1: num samples 66
class 2: num samples 22
class 3: num samples 7
[0.0239850482815907, 0.07268196448966878, 0.21804589346900635, 0.6852870937597342]
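As a quick check of the arithmetic behind the printed weights, the following minimal sketch recomputes them from the class counts shown above using only the standard library. The helper name `inverse_frequency_weights` is hypothetical and is not part of AutoMM.

```python
def inverse_frequency_weights(class_counts):
    """Return weights proportional to 1 / count, normalized to sum to 1.

    Hypothetical helper for illustration; not an AutoMM function.
    """
    inverse = [1.0 / count for count in class_counts]
    total = sum(inverse)
    return [value / total for value in inverse]


# Class counts taken from the output printed above.
counts = [200, 66, 22, 7]
alpha = inverse_frequency_weights(counts)
for label, weight in enumerate(alpha):
    print(f"class {label}: weight {weight:.4f}")
```

The rarest class (class 3) receives the largest weight, so its samples contribute most to the loss per example.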
Create and train MultiModalPredictor#
Train with FocalLoss#
We tell the model to use FocalLoss by setting "optimization.loss_function" to "focal_loss".
There are also three other optional parameters you can set:
optimization.focal_loss.alpha - a list of floats giving the per-class loss weight, which can be used to balance an uneven sample distribution across classes. Note that the length of the list must match the total number of classes in the training dataset. A good way to compute alpha for each class is to use the inverse of its share of the training samples.
optimization.focal_loss.gamma - a float that controls how much to focus on hard samples. A larger value means more focus on hard samples.
optimization.focal_loss.reduction - how to aggregate the loss value. Can only be "mean" or "sum" for now.
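To make the roles of alpha, gamma, and reduction concrete, here is a minimal pure-Python sketch of the focal loss formula from Lin et al., FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), applied to softmax probabilities. It is an illustration under that assumption only; the function names are hypothetical and AutoMM's internal implementation may differ in its details.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def focal_loss(logits, target, alpha=None, gamma=2.0):
    """Focal loss for one sample: -alpha_t * (1 - p_t)**gamma * log(p_t)."""
    p_t = softmax(logits)[target]
    alpha_t = 1.0 if alpha is None else alpha[target]
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def batch_focal_loss(batch_logits, targets, alpha=None, gamma=2.0, reduction="mean"):
    """Aggregate per-sample focal losses with "mean" or "sum"."""
    losses = [focal_loss(l, t, alpha, gamma) for l, t in zip(batch_logits, targets)]
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    raise ValueError(f"unsupported reduction: {reduction}")


# An easy sample (confident and correct) versus a hard one (confident and wrong).
easy_logits = [4.0, 0.0, 0.0, 0.0]
hard_logits = [0.0, 4.0, 0.0, 0.0]
for gamma in (0.0, 1.0, 2.0):
    easy = focal_loss(easy_logits, target=0, gamma=gamma)
    hard = focal_loss(hard_logits, target=0, gamma=gamma)
    print(f"gamma={gamma}: easy={easy:.5f} hard={hard:.5f}")
```

With gamma set to 0 and no alpha, the expression reduces to ordinary cross-entropy. Increasing gamma shrinks the loss on the easy sample much faster than on the hard one, which is how the loss shifts focus toward hard samples.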
import uuid
from autogluon.multimodal import MultiModalPredictor

model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee_focal"
predictor = MultiModalPredictor(label="label", problem_type="multiclass", path=model_path)
predictor.fit(
    hyperparameters={
        "model.mmdet_image.checkpoint_name": "swin_tiny_patch4_window7_224",
        "env.num_gpus": 1,
        "optimization.loss_function": "focal_loss",
        "optimization.focal_loss.alpha": weights,  # shopee dataset has 4 classes.
        "optimization.focal_loss.gamma": 1.0,
        "optimization.focal_loss.reduction": "sum",
        "optimization.max_epochs": 10,
    },
    train_data=imbalanced_train_data,
)
predictor.evaluate(test_data, metrics=["acc"])
Global seed set to 0
AutoMM starts to create your model. ✨
- AutoGluon version is 0.8.2b20230630.
- Pytorch version is 1.13.1+cu117.
- Model will be saved to "/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/5969a023506146b7a40d7fcaea602b42-automm_shopee_focal".
- Validation metric is "accuracy".
- To track the learning progress, you can open a terminal and launch Tensorboard:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/5969a023506146b7a40d7fcaea602b42-automm_shopee_focal
```
Enjoy your coffee, and let AutoMM do the job ☕☕☕ Learn more at https://auto.gluon.ai
/home/ci/autogluon/multimodal/src/autogluon/multimodal/utils/config.py:552: UserWarning: Received loss function=focal_loss for problem=multiclass. Currently, we only support using BCE loss for regression problems and choose the loss_function automatically otherwise.
warnings.warn(
/home/ci/opt/venv/lib/python3.8/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
1 GPUs are detected, and 1 GPUs will be used.
- GPU 0 name: Tesla T4
- GPU 0 memory: 15.74GB/15.84GB (Free/Total)
CUDA version is 11.7.
Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
----------------------------------------------------------------------
0 | model | TimmAutoModelForImagePrediction | 86.7 M
1 | validation_metric | MulticlassAccuracy | 0
2 | loss_func | FocalLoss | 0
----------------------------------------------------------------------
86.7 M Trainable params
0 Non-trainable params
86.7 M Total params
173.495 Total estimated model params size (MB)
Epoch 0, global step 2: 'val_accuracy' reached 0.16949 (best 0.16949), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/5969a023506146b7a40d7fcaea602b42-automm_shopee_focal/epoch=0-step=2.ckpt' as top 3
Epoch 1, global step 4: 'val_accuracy' reached 0.57627 (best 0.57627), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/5969a023506146b7a40d7fcaea602b42-automm_shopee_focal/epoch=1-step=4.ckpt' as top 3
Epoch 2, global step 6: 'val_accuracy' reached 0.91525 (best 0.91525), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/5969a023506146b7a40d7fcaea602b42-automm_shopee_focal/epoch=2-step=6.ckpt' as top 3
Epoch 3, global step 8: 'val_accuracy' reached 0.96610 (best 0.96610), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/5969a023506146b7a40d7fcaea602b42-automm_shopee_focal/epoch=3-step=8.ckpt' as top 3
Epoch 4, global step 10: 'val_accuracy' reached 0.96610 (best 0.96610), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/5969a023506146b7a40d7fcaea602b42-automm_shopee_focal/epoch=4-step=10.ckpt' as top 3
Epoch 5, global step 12: 'val_accuracy' reached 0.98305 (best 0.98305), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/5969a023506146b7a40d7fcaea602b42-automm_shopee_focal/epoch=5-step=12.ckpt' as top 3
Epoch 6, global step 14: 'val_accuracy' reached 0.98305 (best 0.98305), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/5969a023506146b7a40d7fcaea602b42-automm_shopee_focal/epoch=6-step=14.ckpt' as top 3
Epoch 7, global step 16: 'val_accuracy' was not in top 3
Epoch 8, global step 18: 'val_accuracy' was not in top 3
Epoch 9, global step 20: 'val_accuracy' was not in top 3
`Trainer.fit` stopped: `max_epochs=10` reached.
Start to fuse 3 checkpoints via the greedy soup algorithm.
AutoMM has created your model 🎉🎉🎉
- To load the model, use the code below:
```python
from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/5969a023506146b7a40d7fcaea602b42-automm_shopee_focal")
```
- You can open a terminal and launch Tensorboard to visualize the training log:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/5969a023506146b7a40d7fcaea602b42-automm_shopee_focal
```
- If you are not satisfied with the model, try to increase the training time,
adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub: https://github.com/autogluon/autogluon
{'acc': 0.8875}
Train without FocalLoss#
import uuid
from autogluon.multimodal import MultiModalPredictor

model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee_non_focal"
predictor2 = MultiModalPredictor(label="label", problem_type="multiclass", path=model_path)
predictor2.fit(
    hyperparameters={
        "model.mmdet_image.checkpoint_name": "swin_tiny_patch4_window7_224",
        "env.num_gpus": 1,
        "optimization.max_epochs": 10,
    },
    train_data=imbalanced_train_data,
)
predictor2.evaluate(test_data, metrics=["acc"])
Global seed set to 0
AutoMM starts to create your model. ✨
- AutoGluon version is 0.8.2b20230630.
- Pytorch version is 1.13.1+cu117.
- Model will be saved to "/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/821f57b10d5a401db90b113e0806c389-automm_shopee_non_focal".
- Validation metric is "accuracy".
- To track the learning progress, you can open a terminal and launch Tensorboard:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/821f57b10d5a401db90b113e0806c389-automm_shopee_non_focal
```
Enjoy your coffee, and let AutoMM do the job ☕☕☕ Learn more at https://auto.gluon.ai
1 GPUs are detected, and 1 GPUs will be used.
- GPU 0 name: Tesla T4
- GPU 0 memory: 14.73GB/15.84GB (Free/Total)
CUDA version is 11.7.
Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
----------------------------------------------------------------------
0 | model | TimmAutoModelForImagePrediction | 86.7 M
1 | validation_metric | MulticlassAccuracy | 0
2 | loss_func | CrossEntropyLoss | 0
----------------------------------------------------------------------
86.7 M Trainable params
0 Non-trainable params
86.7 M Total params
173.495 Total estimated model params size (MB)
Epoch 0, global step 2: 'val_accuracy' reached 0.66102 (best 0.66102), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/821f57b10d5a401db90b113e0806c389-automm_shopee_non_focal/epoch=0-step=2.ckpt' as top 3
Epoch 1, global step 4: 'val_accuracy' reached 0.67797 (best 0.67797), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/821f57b10d5a401db90b113e0806c389-automm_shopee_non_focal/epoch=1-step=4.ckpt' as top 3
Epoch 2, global step 6: 'val_accuracy' reached 0.74576 (best 0.74576), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/821f57b10d5a401db90b113e0806c389-automm_shopee_non_focal/epoch=2-step=6.ckpt' as top 3
Epoch 3, global step 8: 'val_accuracy' reached 0.94915 (best 0.94915), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/821f57b10d5a401db90b113e0806c389-automm_shopee_non_focal/epoch=3-step=8.ckpt' as top 3
Epoch 4, global step 10: 'val_accuracy' reached 0.96610 (best 0.96610), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/821f57b10d5a401db90b113e0806c389-automm_shopee_non_focal/epoch=4-step=10.ckpt' as top 3
Epoch 5, global step 12: 'val_accuracy' reached 0.96610 (best 0.96610), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/821f57b10d5a401db90b113e0806c389-automm_shopee_non_focal/epoch=5-step=12.ckpt' as top 3
Epoch 6, global step 14: 'val_accuracy' reached 0.96610 (best 0.96610), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/821f57b10d5a401db90b113e0806c389-automm_shopee_non_focal/epoch=6-step=14.ckpt' as top 3
Epoch 7, global step 16: 'val_accuracy' was not in top 3
Epoch 8, global step 18: 'val_accuracy' was not in top 3
Epoch 9, global step 20: 'val_accuracy' was not in top 3
`Trainer.fit` stopped: `max_epochs=10` reached.
Start to fuse 3 checkpoints via the greedy soup algorithm.
AutoMM has created your model 🎉🎉🎉
- To load the model, use the code below:
```python
from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/821f57b10d5a401db90b113e0806c389-automm_shopee_non_focal")
```
- You can open a terminal and launch Tensorboard to visualize the training log:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/821f57b10d5a401db90b113e0806c389-automm_shopee_non_focal
```
- If you are not satisfied with the model, try to increase the training time,
adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub: https://github.com/autogluon/autogluon
{'acc': 0.7375}
As we can see, the model trained with FocalLoss achieves a much higher test accuracy (0.8875) than the model trained without it (0.7375).
When your data is imbalanced, try FocalLoss to see whether it improves performance!
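Overall accuracy can hide how the rare classes fare, so a per-class breakdown is a useful companion when comparing the two models. The sketch below is a hypothetical helper using only the standard library; the toy labels are made up for illustration and are not results from this tutorial. In practice you would pass the true labels (for example, the "label" column of test_data) and the model's predicted labels.

```python
from collections import defaultdict


def per_class_accuracy(y_true, y_pred):
    """Return {label: fraction of that label's samples predicted correctly}.

    Hypothetical helper for illustration; not an AutoMM function.
    """
    correct = defaultdict(int)
    total = defaultdict(int)
    for truth, pred in zip(y_true, y_pred):
        total[truth] += 1
        if truth == pred:
            correct[truth] += 1
    return {label: correct[label] / total[label] for label in sorted(total)}


# Toy labels (made up): a model biased toward the majority class 0.
y_true = [0, 0, 1, 1, 2, 2, 3, 3]
y_pred = [0, 0, 1, 0, 2, 0, 0, 3]
report = per_class_accuracy(y_true, y_pred)
for label, acc in report.items():
    print(f"class {label}: accuracy {acc:.2f}")
```

A model that favors the majority class shows up here as high accuracy on class 0 and lower accuracy on the downsampled classes.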
Citations#
@misc{https://doi.org/10.48550/arxiv.1708.02002,
  doi = {10.48550/ARXIV.1708.02002},
  url = {https://arxiv.org/abs/1708.02002},
  author = {Lin, Tsung-Yi and Goyal, Priya and Girshick, Ross and He, Kaiming and Dollár, Piotr},
  keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences},
  title = {Focal Loss for Dense Object Detection},
  publisher = {arXiv},
  year = {2017},
  copyright = {arXiv.org perpetual, non-exclusive license}
}