AutoMM Detection - High Performance Finetune on COCO Format Dataset

Fig. 3: Pothole Dataset (https://automl-mm-bench.s3.amazonaws.com/object_detection/example_image/pothole144_gt.jpg)

In this section, our goal is to fast-finetune and evaluate a pretrained model on the Pothole dataset in COCO format. Pothole is a single-class (i.e., pothole) detection dataset containing 665 images with bounding box annotations. It can serve as a proof of concept (POC/POV) for road maintenance applications. See AutoMM Detection - Prepare Pothole Dataset for how to prepare the Pothole dataset.

To start, let’s import MultiModalPredictor:

from autogluon.multimodal import MultiModalPredictor

Make sure mmcv-full and mmdet are installed:

!mim install mmcv-full
!pip install mmdet
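To quickly confirm the installation, you can print the library versions:

import mmcv
import mmdet

print("mmcv:", mmcv.__version__)    # this tutorial was built with mmcv-full 1.7.x
print("mmdet:", mmdet.__version__)  # and mmdet 2.28.x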

Also import some other packages used in this tutorial:

import os
import time

from autogluon.core.utils.loaders import load_zip

We have the sample dataset ready in the cloud. Let’s download it:

zip_file = "https://automl-mm-bench.s3.amazonaws.com/object_detection/dataset/pothole.zip"
download_dir = "./pothole"

load_zip.unzip(zip_file, unzip_dir=download_dir)
data_dir = os.path.join(download_dir, "pothole")
train_path = os.path.join(data_dir, "Annotations", "usersplit_train_cocoformat.json")
val_path = os.path.join(data_dir, "Annotations", "usersplit_val_cocoformat.json")
test_path = os.path.join(data_dir, "Annotations", "usersplit_test_cocoformat.json")

When using a COCO format dataset, the input is the JSON annotation file of the corresponding dataset split. In this example, usersplit_train_cocoformat.json is the annotation file of the train split, usersplit_val_cocoformat.json is the annotation file of the validation split, and usersplit_test_cocoformat.json is the annotation file of the test split.
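If you are unsure whether an annotation file follows the COCO format, you can inspect its top-level structure: a COCO annotation file is a JSON object with "images", "annotations", and "categories" lists. A minimal check using the standard library (extra metadata keys may also be present):

import json

with open(train_path) as f:
    coco = json.load(f)

print(list(coco.keys()))                      # expect 'images', 'annotations', 'categories'
print(len(coco["images"]), "images")          # number of images in this split
print(len(coco["annotations"]), "annotations")  # number of bounding boxes
print(coco["categories"])                     # e.g., a single 'pothole' category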

We select VFNet with a ResNet-50 backbone, a Feature Pyramid Network (FPN) neck, and a 640x640 input resolution, pretrained on the COCO dataset. (The neck of an object detector refers to the additional layers between the backbone and the head; their role is to collect feature maps from different stages of the backbone.) This setting sacrifices training and inference speed and requires much more GPU memory, but delivers high performance.

We use val_metric = map, i.e., mean average precision under the COCO standard, as our validation metric. In the previous section, AutoMM Detection - Fast Finetune on COCO Format Dataset, we did not specify the validation metric, so the validation loss was used by default. Validating on the loss is much faster, but validating on mean average precision gives the best performance.

We use all available GPUs (if any):

checkpoint_name = "vfnet_r50_fpn_mdconv_c3-c5_mstrain_2x_coco"
num_gpus = -1  # use all GPUs
val_metric = "map"
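If you want to browse other pretrained detection checkpoints, openmim (used above to install mmcv-full) provides a search command; a sketch, assuming the CLI flags of the openmim version used here:

!mim search mmdet --model vfnet  # list VFNet configs/checkpoints available in mmdet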

We create the MultiModalPredictor with the selected checkpoint name, val_metric, and number of GPUs. We need to set problem_type to "object_detection" and provide a sample_data_path so the predictor can infer the categories of the dataset. Here we provide train_path; any other split of this dataset would also work.

predictor = MultiModalPredictor(
    hyperparameters={
        "model.mmdet_image.checkpoint_name": checkpoint_name,
        "env.num_gpus": num_gpus,
        "optimization.val_metric": val_metric,
    },
    problem_type="object_detection",
    sample_data_path=train_path,
)
processing vfnet_r50_fpn_mdconv_c3-c5_mstrain_2x_coco...
Successfully downloaded vfnet_r50_fpn_mdconv_c3-c5_mstrain_2x_coco_20201027pth-6879c318.pth to /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/finetune
Successfully dumped vfnet_r50_fpn_mdconv_c3-c5_mstrain_2x_coco.py to /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/finetune
processing vfnet_r50_fpn_mdconv_c3-c5_mstrain_2x_coco...
vfnet_r50_fpn_mdconv_c3-c5_mstrain_2x_coco_20201027pth-6879c318.pth exists in /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/finetune
Successfully dumped vfnet_r50_fpn_mdconv_c3-c5_mstrain_2x_coco.py to /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/finetune
load checkpoint from local path: vfnet_r50_fpn_mdconv_c3-c5_mstrain_2x_coco_20201027pth-6879c318.pth
The model and loaded state dict do not match exactly

size mismatch for bbox_head.vfnet_cls.weight: copying a param with shape torch.Size([80, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 256, 3, 3]).
size mismatch for bbox_head.vfnet_cls.bias: copying a param with shape torch.Size([80]) from checkpoint, the shape in current model is torch.Size([1]).

The size-mismatch warnings above are expected: the COCO-pretrained classification head predicts 80 classes, while our dataset has only one, so that head is reinitialized. We used 1e-4 for YOLOv3 in the previous tutorial, but set the learning rate to 5e-6 here, because larger models usually prefer smaller learning rates. Note that by default we use a two-stage learning rate during finetuning, in which the model head gets 100x the base learning rate. Applying a high learning rate only to the head layers makes the model converge faster during finetuning, and usually gives better performance as well, especially on small datasets with hundreds or thousands of images. We also set per_gpu_batch_size to 2, because this model is large and would run out of GPU memory with a larger batch size. In addition, we time the fit process to get a sense of the training speed. We set the number of epochs to 1 only for a quick demonstration; to seriously finetune the model on this dataset, set it to 20 or more.

start = time.time()
predictor.fit(
    train_path,
    hyperparameters={
        "optimization.learning_rate": 5e-6, # we use two stage and detection head has 100x lr
        "optimization.max_epochs": 1,
        "optimization.check_val_every_n_epoch": 1, # make sure there is at least one validation
        "env.per_gpu_batch_size": 2,  # decrease it when model is large
    },
)
end = time.time()
Using default root folder: ./pothole/pothole/Annotations/... Specify root=... if you feel it is wrong...
Global seed set to 123
No path specified. Models will be saved in: "AutogluonModels/ag-20230222_232821/"
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
AutoMM starts to create your model. ✨

- Model will be saved to "/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/finetune/AutogluonModels/ag-20230222_232821".

- Validation metric is "map".

- To track the learning progress, you can open a terminal and launch Tensorboard:
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/finetune/AutogluonModels/ag-20230222_232821

Enjoy your coffee, and let AutoMM do the job ☕☕☕ Learn more at https://auto.gluon.ai

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Trainer(val_check_interval=1.0) was configured so validation will run at the end of the training epoch..
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                             | Params
-----------------------------------------------------------------------
0 | model             | MMDetAutoModelForObjectDetection | 33.7 M
1 | validation_metric | MeanAveragePrecision             | 0
-----------------------------------------------------------------------
33.5 M    Trainable params
225 K     Non-trainable params
33.7 M    Total params
134.821   Total estimated model params size (MB)
/home/ci/opt/venv/lib/python3.8/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Epoch 0, global step 3: 'val_map' reached 0.00005 (best 0.00005), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/finetune/AutogluonModels/ag-20230222_232821/epoch=0-step=3.ckpt' as top 1
Trainer.fit stopped: max_epochs=1 reached.
AutoMM has created your model 🎉🎉🎉

- To load the model, use the code below:
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/finetune/AutogluonModels/ag-20230222_232821")

- You can open a terminal and launch Tensorboard to visualize the training log:
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/finetune/AutogluonModels/ag-20230222_232821

- If you are not satisfied with the model, try to increase the training time,
adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub: https://github.com/autogluon/autogluon

Print out the time; we can see that even one epoch takes a while.

print("This finetuning takes %.2f seconds." % (end - start))
This finetuning takes 164.68 seconds.
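As mentioned earlier, the two-stage learning rate is applied by default for detection. If you want to control it explicitly, or run a more serious finetuning, the relevant optimization hyperparameters can be overridden in fit; a sketch, assuming the optimization.lr_choice and optimization.lr_mult keys described in Customize AutoMM:

# A hypothetical, more serious finetuning run; key names per Customize AutoMM
predictor.fit(
    train_path,
    hyperparameters={
        "optimization.learning_rate": 5e-6,
        "optimization.lr_choice": "two_stages",  # two-stage lr instead of layerwise decay
        "optimization.lr_mult": 100,             # head lr = 100 x base lr
        "optimization.max_epochs": 20,           # 20+ epochs for serious finetuning
        "env.per_gpu_batch_size": 2,
    },
)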

To get a model with reasonable performance, you will need to finetune the model for more epochs. We set max_epochs to 50, trained a model offline, and uploaded it to AWS S3. To load it and check the result:

# Load Trained Predictor from S3
zip_file = "https://automl-mm-bench.s3.amazonaws.com/object_detection/checkpoints/pothole_AP50_718.zip"
download_dir = "./pothole_AP50_718"
load_zip.unzip(zip_file, unzip_dir=download_dir)
better_predictor = MultiModalPredictor.load("./pothole_AP50_718/AutogluonModels/ag-20221123_021130")
better_predictor.set_num_gpus(1)

# Evaluate new predictor
better_predictor.evaluate(test_path)
Downloading ./pothole_AP50_718/file.zip from https://automl-mm-bench.s3.amazonaws.com/object_detection/checkpoints/pothole_AP50_718.zip...
100%|██████████| 251M/251M [00:04<00:00, 57.4MiB/s]
Unzipping ./pothole_AP50_718/file.zip to ./pothole_AP50_718
Start to upgrade the previous configuration trained by AutoMM version=0.6.1b20221118.
Loading a model that has been trained via AutoGluon Multimodal<=0.6.2. Try to update the timm image size.
/home/ci/opt/venv/lib/python3.8/site-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator LabelEncoder from version 1.0.2 when using version 1.1.1. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
  warnings.warn(
/home/ci/opt/venv/lib/python3.8/site-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator StandardScaler from version 1.0.2 when using version 1.1.1. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
  warnings.warn(
processing vfnet_r50_fpn_mdconv_c3-c5_mstrain_2x_coco...
vfnet_r50_fpn_mdconv_c3-c5_mstrain_2x_coco_20201027pth-6879c318.pth exists in /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/finetune
Successfully dumped vfnet_r50_fpn_mdconv_c3-c5_mstrain_2x_coco.py to /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/finetune
processing vfnet_r50_fpn_mdconv_c3-c5_mstrain_2x_coco...
vfnet_r50_fpn_mdconv_c3-c5_mstrain_2x_coco_20201027pth-6879c318.pth exists in /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/finetune
Successfully dumped vfnet_r50_fpn_mdconv_c3-c5_mstrain_2x_coco.py to /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/finetune
Load pretrained checkpoint: /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/finetune/pothole_AP50_718/AutogluonModels/ag-20221123_021130/model.ckpt
Using default root folder: ./pothole/pothole/Annotations/... Specify root=... if you feel it is wrong...
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
A new predictor save path is created. This is to prevent you from overwriting the previous predictor saved here. You can check the current save path at predictor._save_path. If you still want to use this path, set resume=True
No path specified. Models will be saved in: "AutogluonModels/ag-20230222_233138/"
saving file at /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/finetune/AutogluonModels/ag-20230222_233138/object_detection_result_cache.json
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
Loading and preparing results...
DONE (t=0.01s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type bbox
DONE (t=0.25s).
Accumulating evaluation results...
DONE (t=0.05s).
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.449
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.717
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.481
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.246
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.458
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.606
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.255
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.556
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.634
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.525
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.620
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.749
{'map': 0.4493231102621059,
 'mean_average_precision': 0.4493231102621059,
 'map_50': 0.7174652483674355,
 'map_75': 0.4813991783298232,
 'map_small': 0.24554152087905134,
 'map_medium': 0.4584927016251649,
 'map_large': 0.6056972302542543,
 'mar_1': 0.25486725663716814,
 'mar_10': 0.5557522123893806,
 'mar_100': 0.6336283185840708,
 'mar_small': 0.5253521126760563,
 'mar_medium': 0.6204419889502761,
 'mar_large': 0.7494252873563219}
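evaluate returns these metrics as a dictionary, so individual values can be read programmatically:

result = better_predictor.evaluate(test_path)
print("mAP:   %.3f" % result["map"])     # COCO-standard mean average precision
print("mAP50: %.3f" % result["map_50"])  # AP at IoU=0.50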

We can get predictions on the test set:

pred = better_predictor.predict(test_path)
Using default root folder: ./pothole/pothole/Annotations/... Specify root=... if you feel it is wrong...
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
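The exact prediction format can differ across AutoGluon versions, but the returned object supports slicing (we use pred[12:13] below for visualization), so you can peek at a single sample:

print(type(pred))   # inspect the container type returned by predict
print(pred[12:13])  # one sample: image reference plus predicted boxes/scores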

Let’s also visualize the prediction result:

!pip install opencv-python
from autogluon.multimodal.utils import visualize_detection

conf_threshold = 0.25  # filter out boxes with confidence below this threshold
visualization_result_dir = "./"  # save the visualized image in the current directory
visualized = visualize_detection(
    pred=pred[12:13],  # visualize a single test image
    detection_classes=better_predictor.get_predictor_classes(),  # classes from the predictor that produced pred
    conf_threshold=conf_threshold,
    visualization_result_dir=visualization_result_dir,
)
from PIL import Image
from IPython.display import display

img = Image.fromarray(visualized[0][:, :, ::-1], 'RGB')  # convert BGR (OpenCV) to RGB
display(img)
Saved visualizations to ./
(Visualization of the predicted bounding boxes on a sample test image.)

Under this high-performance finetuning setting, training took a long time but reached mAP = 0.450 and mAP50 = 0.718! To finetune faster, see AutoMM Detection - Fast Finetune on COCO Format Dataset, where we finetune a YOLOv3 model that trains much faster but performs worse.

Other Examples

You may go to AutoMM Examples to explore other examples of AutoMM.

Customization

To learn how to customize AutoMM, please refer to Customize AutoMM.

Citation

@article{DBLP:journals/corr/abs-2008-13367,
  author    = {Haoyang Zhang and
               Ying Wang and
               Feras Dayoub and
               Niko S{\"{u}}nderhauf},
  title     = {VarifocalNet: An IoU-aware Dense Object Detector},
  journal   = {CoRR},
  volume    = {abs/2008.13367},
  year      = {2020},
  url       = {https://arxiv.org/abs/2008.13367},
  eprinttype = {arXiv},
  eprint    = {2008.13367},
  timestamp = {Wed, 16 Sep 2020 11:20:03 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/abs-2008-13367.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}