{ "cells": [ { "cell_type": "markdown", "id": "b3aad2e0-3b9a-4761-ae52-12fb4f29143a", "metadata": {}, "source": [ "# Handling Class Imbalance with AutoMM - Focal Loss\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/autogluon/autogluon/blob/stable/docs/tutorials/multimodal/advanced_topics/focal_loss.ipynb)\n", "[![Open In SageMaker Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/autogluon/autogluon/blob/stable/docs/tutorials/multimodal/advanced_topics/focal_loss.ipynb)\n", "\n", "\n", "In this tutorial, we show how to use focal loss with the AutoMM package to train on imbalanced data.\n", "Focal loss was first introduced in [this paper](https://arxiv.org/abs/1708.02002)\n", "and can be used to balance hard and easy samples as well as an uneven sample distribution among classes." ] }, { "cell_type": "markdown", "id": "1a764799-e835-420e-b51e-81dcfbda884c", "metadata": {}, "source": [ "## Create Dataset\n", "We use the Shopee dataset for demonstration in this tutorial. It contains 4 classes, each with 200 samples in the training set." 
] }, { "cell_type": "code", "execution_count": null, "id": "aa00faab-252f-44c9-b8f7-57131aa8251c", "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "!pip install autogluon.multimodal\n" ] }, { "cell_type": "code", "execution_count": null, "id": "2c9c2134-bc31-4611-ab58-31701d8afdbe", "metadata": {}, "outputs": [], "source": [ "from autogluon.multimodal.utils.misc import shopee_dataset\n", "\n", "download_dir = \"./ag_automm_tutorial_imgcls_focalloss\"\n", "train_data, test_data = shopee_dataset(download_dir)" ] }, { "cell_type": "markdown", "id": "6965f254-5d25-4ee0-a58f-3c5b275d4ae9", "metadata": {}, "source": [ "To demonstrate the effectiveness of focal loss on imbalanced training data, we artificially downsample the Shopee\n", "training data to form an imbalanced distribution." ] }, { "cell_type": "code", "execution_count": null, "id": "918c3f82-06da-45c6-af8f-e847bbbc3e79", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "ds = 1  # fraction of samples to keep for the current class\n", "\n", "imbalanced_train_data = []\n", "for lb in range(4):\n", "    class_data = train_data[train_data.label == lb]\n", "    sample_index = np.random.choice(np.arange(len(class_data)), size=int(len(class_data) * ds), replace=False)\n", "    ds /= 3  # each subsequent class keeps 1/3 as many samples as the previous one\n", "    imbalanced_train_data.append(class_data.iloc[sample_index])\n", "imbalanced_train_data = pd.concat(imbalanced_train_data)\n", "print(imbalanced_train_data)\n", "\n", "# Per-class weights: inverse class frequency, normalized to sum to 1.\n", "weights = []\n", "for lb in range(4):\n", "    class_data = imbalanced_train_data[imbalanced_train_data.label == lb]\n", "    weights.append(1 / (class_data.shape[0] / imbalanced_train_data.shape[0]))\n", "    print(f\"class {lb}: num samples {len(class_data)}\")\n", "weights = list(np.array(weights) / np.sum(weights))\n", "print(weights)" ] }, { "cell_type": "markdown", "id": "5ee803f4-06be-4e39-8347-4fe57992f2f8", "metadata": {}, "source": [ "## Create and train `MultiModalPredictor`\n", "\n", "### Train 
with Focal Loss\n", "We tell the model to use focal loss by setting `\"optim.loss_func\"` to `\"focal_loss\"`.\n", "There are also three optional parameters you can set.\n", "\n", "`optim.focal_loss.alpha` - a list of floats giving the per-class loss weights, which can be used to balance an uneven sample distribution across classes.\n", "Note that the length of the list ***must*** match the total number of classes in the training dataset. A good way to compute `alpha` for each class is to take the inverse of its share of the training samples.\n", "\n", "`optim.focal_loss.gamma` - a float that controls how much to focus on the hard samples. A larger value means more focus on the hard samples.\n", "\n", "`optim.focal_loss.reduction` - how to aggregate the loss value. Currently only `\"mean\"` and `\"sum\"` are supported." ] }, { "cell_type": "code", "execution_count": null, "id": "bcb4bb3f-47ee-4772-bd3b-3ba00c76cd60", "metadata": {}, "outputs": [], "source": [ "import uuid\n", "from autogluon.multimodal import MultiModalPredictor\n", "\n", "model_path = f\"./tmp/{uuid.uuid4().hex}-automm_shopee_focal\"\n", "\n", "predictor = MultiModalPredictor(label=\"label\", problem_type=\"multiclass\", path=model_path)\n", "\n", "predictor.fit(\n", "    hyperparameters={\n", "        \"model.mmdet_image.checkpoint_name\": \"swin_tiny_patch4_window7_224\",\n", "        \"env.num_gpus\": 1,\n", "        \"optim.loss_func\": \"focal_loss\",\n", "        \"optim.focal_loss.alpha\": weights,  # one weight per class; the Shopee dataset has 4 classes\n", "        \"optim.focal_loss.gamma\": 1.0,\n", "        \"optim.focal_loss.reduction\": \"sum\",\n", "        \"optim.max_epochs\": 10,\n", "    },\n", "    train_data=imbalanced_train_data,\n", ")\n", "\n", "predictor.evaluate(test_data, metrics=[\"acc\"])" ] }, { "cell_type": "markdown", "id": "d0ae5688-e0f5-4a62-a36f-87ce772ea8ef", "metadata": {}, "source": [ "### Train without Focal Loss" ] }, { "cell_type": "code", "execution_count": null, "id": "fadf7437-53b7-454d-81e0-596732f355a3", "metadata": {}, "outputs": [], 
"source": [ "import uuid\n", "from autogluon.multimodal import MultiModalPredictor\n", "\n", "model_path = f\"./tmp/{uuid.uuid4().hex}-automm_shopee_non_focal\"\n", "\n", "predictor2 = MultiModalPredictor(label=\"label\", problem_type=\"multiclass\", path=model_path)\n", "\n", "predictor2.fit(\n", " hyperparameters={\n", " \"model.mmdet_image.checkpoint_name\": \"swin_tiny_patch4_window7_224\",\n", " \"env.num_gpus\": 1,\n", " \"optim.max_epochs\": 10,\n", " },\n", " train_data=imbalanced_train_data,\n", ")\n", "\n", "predictor2.evaluate(test_data, metrics=[\"acc\"])" ] }, { "cell_type": "markdown", "id": "ddb15b82-3af8-4837-8f2e-654dcc561e9d", "metadata": {}, "source": [ "As we can see that the model with focal loss is able to achieve a much better performance compared to the model without focal loss.\n", "When your data is imbalanced, try out focal loss to see if it brings improvements to the performance!\n", "\n", "## Citations\n", "\n", "```\n", "@misc{https://doi.org/10.48550/arxiv.1708.02002,\n", " doi = {10.48550/ARXIV.1708.02002},\n", " \n", " url = {https://arxiv.org/abs/1708.02002},\n", " \n", " author = {Lin, Tsung-Yi and Goyal, Priya and Girshick, Ross and He, Kaiming and Dollár, Piotr},\n", " \n", " keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},\n", " \n", " title = {Focal Loss for Dense Object Detection},\n", " \n", " publisher = {arXiv},\n", " \n", " year = {2017},\n", " \n", " copyright = {arXiv.org perpetual, non-exclusive license}\n", "}\n", "\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "d32fa55f-8a21-41dd-bcb9-b4b9922d7d61", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 }