How to use FocalLoss¶
In this tutorial, we introduce how to use FocalLoss from the AutoMM package for balanced training. FocalLoss was first introduced in this paper and can be used to balance hard/easy samples as well as uneven sample distributions across classes. This tutorial demonstrates how to use FocalLoss.
Create Dataset¶
We use the Shopee dataset for demonstration in this tutorial. The Shopee dataset contains 4 classes, with 200 training samples per class.
from autogluon.multimodal.utils.misc import shopee_dataset
download_dir = "./ag_automm_tutorial_imgcls_focalloss"
train_data, test_data = shopee_dataset(download_dir)
Downloading ./ag_automm_tutorial_imgcls_focalloss/file.zip from https://automl-mm-bench.s3.amazonaws.com/vision_datasets/shopee.zip...
100%|██████████| 41.9M/41.9M [00:00<00:00, 78.3MiB/s]
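The four classes are balanced out of the box. If you want to verify this yourself, a quick pandas check (an optional step, not part of the original tutorial) looks like this:
# Each of the 4 classes should appear 200 times in the training split.
print(train_data["label"].value_counts())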
To demonstrate the effectiveness of FocalLoss on imbalanced training data, we artificially downsample the Shopee training data to form an imbalanced distribution.
import numpy as np
import pandas as pd

# Keep all of class 0, then keep 1/3 as many samples for each successive class.
ds = 1
imbalanced_train_data = []
for lb in range(4):
    class_data = train_data[train_data.label == lb]
    sample_index = np.random.choice(np.arange(len(class_data)), size=int(len(class_data) * ds), replace=False)
    ds /= 3  # downsample to 1/3 for the next class
    imbalanced_train_data.append(class_data.iloc[sample_index])
imbalanced_train_data = pd.concat(imbalanced_train_data)
print(imbalanced_train_data)

# Compute per-class weights as the inverse of each class's sample fraction, then normalize.
weights = []
for lb in range(4):
    class_data = imbalanced_train_data[imbalanced_train_data.label == lb]
    weights.append(1 / (class_data.shape[0] / imbalanced_train_data.shape[0]))
    print(f"class {lb}: num samples {len(class_data)}")
weights = list(np.array(weights) / np.sum(weights))
print(weights)
image label
128 /home/ci/autogluon/docs/_build/eval/tutorials/... 0
164 /home/ci/autogluon/docs/_build/eval/tutorials/... 0
142 /home/ci/autogluon/docs/_build/eval/tutorials/... 0
45 /home/ci/autogluon/docs/_build/eval/tutorials/... 0
155 /home/ci/autogluon/docs/_build/eval/tutorials/... 0
.. ... ...
703 /home/ci/autogluon/docs/_build/eval/tutorials/... 3
791 /home/ci/autogluon/docs/_build/eval/tutorials/... 3
667 /home/ci/autogluon/docs/_build/eval/tutorials/... 3
649 /home/ci/autogluon/docs/_build/eval/tutorials/... 3
616 /home/ci/autogluon/docs/_build/eval/tutorials/... 3
[295 rows x 2 columns]
class 0: num samples 200
class 1: num samples 66
class 2: num samples 22
class 3: num samples 7
[0.0239850482815907, 0.07268196448966878, 0.21804589346900635, 0.6852870937597342]
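As a sanity check (optional, not part of the original tutorial): the normalized weights sum to 1, and because they are inverse-frequency weights, weight * count comes out the same for every class, so class 3 (7 samples) gets the largest alpha and class 0 (200 samples) the smallest.
# The normalized per-class weights should sum to 1 ...
assert abs(sum(weights) - 1.0) < 1e-9
# ... and weight * count should be constant across classes (inverse-frequency property).
for lb, w in enumerate(weights):
    n = (imbalanced_train_data.label == lb).sum()
    print(f"class {lb}: weight={w:.4f}, weight * count={w * n:.4f}")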
Create and train MultiModalPredictor¶
Train with FocalLoss¶
We specify that the model should use FocalLoss by setting "optimization.loss_function" to "focal_loss". There are also three other optional parameters you can set; they are described below, and the sketch after this list shows how they enter the loss formula.

optimization.focal_loss.alpha - a list of floats giving the per-class loss weights, which can be used to balance an uneven sample distribution across classes. Note that the length of the list must match the total number of classes in the training dataset. A good way to compute alpha for each class is to use the inverse of its fraction of the samples.

optimization.focal_loss.gamma - a float that controls how much to focus on the hard samples. A larger value means more focus on hard samples.

optimization.focal_loss.reduction - how to aggregate the loss values. Can only take "mean" or "sum" for now.
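To make these parameters concrete, here is a minimal NumPy sketch of the focal loss from the paper, FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t). The function focal_loss_sketch below is purely illustrative and is not AutoMM's internal implementation:
import numpy as np

def focal_loss_sketch(probs, labels, alpha, gamma=1.0, reduction="mean"):
    # probs: (N, C) predicted class probabilities; labels: (N,) integer class ids;
    # alpha: list of C per-class weights (e.g. the `weights` computed above).
    p_t = probs[np.arange(len(labels)), labels]  # probability assigned to the true class
    alpha_t = np.asarray(alpha)[labels]          # per-class weight of the true class
    # (1 - p_t)**gamma shrinks the loss of easy samples (p_t close to 1), so training
    # focuses on the hard samples; alpha_t re-weights the rare classes.
    loss = -alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)
    return loss.mean() if reduction == "mean" else loss.sum()

# A confident correct prediction (p_t = 0.9) contributes far less than an uncertain one (p_t = 0.3).
probs = np.array([[0.9, 0.05, 0.03, 0.02], [0.3, 0.4, 0.2, 0.1]])
print(focal_loss_sketch(probs, np.array([0, 0]), alpha=[0.25] * 4, gamma=2.0))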
import uuid
from autogluon.multimodal import MultiModalPredictor

model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee_focal"
predictor = MultiModalPredictor(label="label", problem_type="multiclass", path=model_path)

predictor.fit(
    hyperparameters={
        "model.timm_image.checkpoint_name": "swin_tiny_patch4_window7_224",
        "env.num_gpus": 1,
        "optimization.loss_function": "focal_loss",
        "optimization.focal_loss.alpha": weights,  # per-class weights; the shopee dataset has 4 classes
        "optimization.focal_loss.gamma": 1.0,
        "optimization.focal_loss.reduction": "sum",
        "optimization.max_epochs": 10,
    },
    train_data=imbalanced_train_data,
)
predictor.evaluate(test_data, metrics=["acc"])
Global seed set to 123
AutoMM starts to create your model. ✨
- Model will be saved to "/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/d81483307de24438b658cc013c48a808-automm_shopee_focal".
- Validation metric is "accuracy".
- To track the learning progress, you can open a terminal and launch Tensorboard:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/d81483307de24438b658cc013c48a808-automm_shopee_focal
```
Enjoy your coffee, and let AutoMM do the job ☕☕☕ Learn more at https://auto.gluon.ai
/home/ci/autogluon/multimodal/src/autogluon/multimodal/utils/config.py:536: UserWarning: Received loss function=focal_loss for problem=multiclass. Currently, we only support using BCE loss for regression problems and choose the loss_function automatically otherwise.
  warnings.warn(
/home/ci/opt/venv/lib/python3.8/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                            | Params
----------------------------------------------------------------------
0 | model             | TimmAutoModelForImagePrediction | 86.7 M
1 | validation_metric | Accuracy                        | 0
2 | loss_func         | FocalLoss                       | 0
----------------------------------------------------------------------
86.7 M    Trainable params
0         Non-trainable params
86.7 M    Total params
173.495   Total estimated model params size (MB)
Epoch 0, global step 2: 'val_accuracy' reached 0.32203 (best 0.32203), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/d81483307de24438b658cc013c48a808-automm_shopee_focal/epoch=0-step=2.ckpt' as top 3
Epoch 1, global step 2: 'val_accuracy' reached 0.32203 (best 0.32203), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/d81483307de24438b658cc013c48a808-automm_shopee_focal/epoch=1-step=2.ckpt' as top 3
Epoch 1, global step 4: 'val_accuracy' reached 0.71186 (best 0.71186), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/d81483307de24438b658cc013c48a808-automm_shopee_focal/epoch=1-step=4.ckpt' as top 3
Epoch 2, global step 4: 'val_accuracy' reached 0.71186 (best 0.71186), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/d81483307de24438b658cc013c48a808-automm_shopee_focal/epoch=2-step=4.ckpt' as top 3
Epoch 2, global step 6: 'val_accuracy' reached 0.93220 (best 0.93220), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/d81483307de24438b658cc013c48a808-automm_shopee_focal/epoch=2-step=6.ckpt' as top 3
Epoch 3, global step 6: 'val_accuracy' reached 0.93220 (best 0.93220), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/d81483307de24438b658cc013c48a808-automm_shopee_focal/epoch=3-step=6.ckpt' as top 3
Epoch 3, global step 8: 'val_accuracy' reached 0.94915 (best 0.94915), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/d81483307de24438b658cc013c48a808-automm_shopee_focal/epoch=3-step=8.ckpt' as top 3
Epoch 4, global step 8: 'val_accuracy' reached 0.94915 (best 0.94915), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/d81483307de24438b658cc013c48a808-automm_shopee_focal/epoch=4-step=8.ckpt' as top 3
Epoch 4, global step 10: 'val_accuracy' reached 0.96610 (best 0.96610), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/d81483307de24438b658cc013c48a808-automm_shopee_focal/epoch=4-step=10.ckpt' as top 3
Epoch 5, global step 10: 'val_accuracy' reached 0.96610 (best 0.96610), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/d81483307de24438b658cc013c48a808-automm_shopee_focal/epoch=5-step=10.ckpt' as top 3
Epoch 5, global step 12: 'val_accuracy' reached 0.96610 (best 0.96610), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/d81483307de24438b658cc013c48a808-automm_shopee_focal/epoch=5-step=12.ckpt' as top 3
Epoch 6, global step 12: 'val_accuracy' was not in top 3
Epoch 6, global step 14: 'val_accuracy' was not in top 3
Epoch 7, global step 14: 'val_accuracy' was not in top 3
Epoch 7, global step 16: 'val_accuracy' was not in top 3
Epoch 8, global step 16: 'val_accuracy' was not in top 3
Epoch 8, global step 18: 'val_accuracy' was not in top 3
Epoch 9, global step 18: 'val_accuracy' was not in top 3
Epoch 9, global step 20: 'val_accuracy' was not in top 3
Trainer.fit stopped: max_epochs=10 reached.
Start to fuse 3 checkpoints via the greedy soup algorithm.
AutoMM has created your model 🎉🎉🎉
- To load the model, use the code below:
```python
from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/d81483307de24438b658cc013c48a808-automm_shopee_focal")
```
- You can open a terminal and launch Tensorboard to visualize the training log:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/d81483307de24438b658cc013c48a808-automm_shopee_focal
```
- If you are not satisfied with the model, try to increase the training time, adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html), or post issues on GitHub: https://github.com/autogluon/autogluon
{'acc': 0.875}
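Overall accuracy can hide how the rare classes fare. Optionally, you can compute a per-class report with scikit-learn (an assumed extra dependency, not used in the original tutorial):
from sklearn.metrics import classification_report

# Per-class precision/recall on the balanced test set shows how the rare classes perform.
y_pred = predictor.predict(test_data)
print(classification_report(test_data["label"], y_pred, digits=3))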
Train without FocalLoss¶
import uuid
from autogluon.multimodal import MultiModalPredictor

model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee_non_focal"
predictor2 = MultiModalPredictor(label="label", problem_type="multiclass", path=model_path)

predictor2.fit(
    hyperparameters={
        "model.timm_image.checkpoint_name": "swin_tiny_patch4_window7_224",
        "env.num_gpus": 1,
        "optimization.max_epochs": 10,
    },
    train_data=imbalanced_train_data,
)
predictor2.evaluate(test_data, metrics=["acc"])
Global seed set to 123
AutoMM starts to create your model. ✨
- Model will be saved to "/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/2a7bbda8d1c0470bb00081e55a65f6ad-automm_shopee_non_focal".
- Validation metric is "accuracy".
- To track the learning progress, you can open a terminal and launch Tensorboard:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/2a7bbda8d1c0470bb00081e55a65f6ad-automm_shopee_non_focal
```
Enjoy your coffee, and let AutoMM do the job ☕☕☕ Learn more at https://auto.gluon.ai
Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                            | Params
----------------------------------------------------------------------
0 | model             | TimmAutoModelForImagePrediction | 86.7 M
1 | validation_metric | Accuracy                        | 0
2 | loss_func         | CrossEntropyLoss                | 0
----------------------------------------------------------------------
86.7 M    Trainable params
0         Non-trainable params
86.7 M    Total params
173.495   Total estimated model params size (MB)
Epoch 0, global step 2: 'val_accuracy' reached 0.71186 (best 0.71186), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/2a7bbda8d1c0470bb00081e55a65f6ad-automm_shopee_non_focal/epoch=0-step=2.ckpt' as top 3
Epoch 1, global step 2: 'val_accuracy' reached 0.71186 (best 0.71186), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/2a7bbda8d1c0470bb00081e55a65f6ad-automm_shopee_non_focal/epoch=1-step=2.ckpt' as top 3
Epoch 1, global step 4: 'val_accuracy' reached 0.69492 (best 0.71186), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/2a7bbda8d1c0470bb00081e55a65f6ad-automm_shopee_non_focal/epoch=1-step=4.ckpt' as top 3
Epoch 2, global step 4: 'val_accuracy' was not in top 3
Epoch 2, global step 6: 'val_accuracy' reached 0.83051 (best 0.83051), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/2a7bbda8d1c0470bb00081e55a65f6ad-automm_shopee_non_focal/epoch=2-step=6.ckpt' as top 3
Epoch 3, global step 6: 'val_accuracy' reached 0.83051 (best 0.83051), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/2a7bbda8d1c0470bb00081e55a65f6ad-automm_shopee_non_focal/epoch=3-step=6.ckpt' as top 3
Epoch 3, global step 8: 'val_accuracy' reached 0.91525 (best 0.91525), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/2a7bbda8d1c0470bb00081e55a65f6ad-automm_shopee_non_focal/epoch=3-step=8.ckpt' as top 3
Epoch 4, global step 8: 'val_accuracy' reached 0.91525 (best 0.91525), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/2a7bbda8d1c0470bb00081e55a65f6ad-automm_shopee_non_focal/epoch=4-step=8.ckpt' as top 3
Epoch 4, global step 10: 'val_accuracy' reached 0.93220 (best 0.93220), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/2a7bbda8d1c0470bb00081e55a65f6ad-automm_shopee_non_focal/epoch=4-step=10.ckpt' as top 3
Epoch 5, global step 10: 'val_accuracy' reached 0.93220 (best 0.93220), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/2a7bbda8d1c0470bb00081e55a65f6ad-automm_shopee_non_focal/epoch=5-step=10.ckpt' as top 3
Epoch 5, global step 12: 'val_accuracy' reached 0.93220 (best 0.93220), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/2a7bbda8d1c0470bb00081e55a65f6ad-automm_shopee_non_focal/epoch=5-step=12.ckpt' as top 3
Epoch 6, global step 12: 'val_accuracy' was not in top 3
Epoch 6, global step 14: 'val_accuracy' was not in top 3
Epoch 7, global step 14: 'val_accuracy' was not in top 3
Epoch 7, global step 16: 'val_accuracy' was not in top 3
Epoch 8, global step 16: 'val_accuracy' was not in top 3
Epoch 8, global step 18: 'val_accuracy' was not in top 3
Epoch 9, global step 18: 'val_accuracy' was not in top 3
Epoch 9, global step 20: 'val_accuracy' was not in top 3
Trainer.fit stopped: max_epochs=10 reached.
Start to fuse 3 checkpoints via the greedy soup algorithm.
AutoMM has created your model 🎉🎉🎉
- To load the model, use the code below:
```python
from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/2a7bbda8d1c0470bb00081e55a65f6ad-automm_shopee_non_focal")
```
- You can open a terminal and launch Tensorboard to visualize the training log:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/advanced_topics/tmp/2a7bbda8d1c0470bb00081e55a65f6ad-automm_shopee_non_focal
```
- If you are not satisfied with the model, try to increase the training time, adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html), or post issues on GitHub: https://github.com/autogluon/autogluon
{'acc': 0.7125}
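For comparison, you can produce the same per-class report for predictor2 (again assuming scikit-learn is installed) to check whether the accuracy drop is concentrated in the heavily downsampled classes 2 and 3:
from sklearn.metrics import classification_report

y_pred2 = predictor2.predict(test_data)
print(classification_report(test_data["label"], y_pred2, digits=3))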
As we can see, the model trained with FocalLoss achieves much better performance (0.875 vs. 0.7125 test accuracy) than the model trained without it. When your data is imbalanced, try out FocalLoss to see if it improves performance!
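If the first result is not satisfying, gamma is usually the knob to tune; the original paper found gamma = 2 to work best in its experiments. A simple illustrative sweep, using only the hyperparameters shown above (each run retrains from scratch, so this is slow):
# Illustrative gamma sweep over the focal-loss focusing parameter.
for gamma in (0.5, 1.0, 2.0, 5.0):
    p = MultiModalPredictor(label="label", problem_type="multiclass")
    p.fit(
        hyperparameters={
            "model.timm_image.checkpoint_name": "swin_tiny_patch4_window7_224",
            "optimization.loss_function": "focal_loss",
            "optimization.focal_loss.alpha": weights,
            "optimization.focal_loss.gamma": gamma,
            "optimization.max_epochs": 10,
        },
        train_data=imbalanced_train_data,
    )
    print(f"gamma={gamma}:", p.evaluate(test_data, metrics=["acc"]))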
Citations¶
@misc{https://doi.org/10.48550/arxiv.1708.02002,
  doi = {10.48550/ARXIV.1708.02002},
  url = {https://arxiv.org/abs/1708.02002},
  author = {Lin, Tsung-Yi and Goyal, Priya and Girshick, Ross and He, Kaiming and Dollár, Piotr},
  keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences},
  title = {Focal Loss for Dense Object Detection},
  publisher = {arXiv},
  year = {2017},
  copyright = {arXiv.org perpetual, non-exclusive license}
}