.. _sec_automm_detection_quick_start_coco:

AutoMM Detection - Quick Start on a Tiny COCO Format Dataset
============================================================

In this section, our goal is to quickly finetune a pretrained model on a
small dataset in COCO format and then evaluate it on the dataset's test
set. Both the training and test sets are in COCO format. See
:ref:`sec_automm_detection_convert_to_coco` for how to convert other
datasets to COCO format.

Setting up the imports
~~~~~~~~~~~~~~~~~~~~~~

To start, let’s import MultiModalPredictor:

.. code:: python

    from autogluon.multimodal import MultiModalPredictor

Make sure ``mmcv-full`` and ``mmdet`` are installed:

.. code:: python

    !mim install mmcv-full
    !pip install mmdet

.. parsed-literal::
    :class: output

    Looking in links: https://download.openmmlab.com/mmcv/dist/cu117/torch1.13.0/index.html
    Requirement already satisfied: mmcv-full in /home/ci/opt/venv/lib/python3.8/site-packages (1.7.1)
    Requirement already satisfied: addict in /home/ci/opt/venv/lib/python3.8/site-packages (from mmcv-full) (2.4.0)
    Requirement already satisfied: opencv-python>=3 in /home/ci/opt/venv/lib/python3.8/site-packages (from mmcv-full) (4.7.0.72)
    Requirement already satisfied: packaging in /home/ci/opt/venv/lib/python3.8/site-packages (from mmcv-full) (23.0)
    Requirement already satisfied: numpy in /home/ci/opt/venv/lib/python3.8/site-packages (from mmcv-full) (1.23.5)
    Requirement already satisfied: pyyaml in /home/ci/opt/venv/lib/python3.8/site-packages (from mmcv-full) (5.4.1)
    Requirement already satisfied: Pillow in /home/ci/opt/venv/lib/python3.8/site-packages (from mmcv-full) (9.4.0)
    Requirement already satisfied: yapf in /home/ci/opt/venv/lib/python3.8/site-packages (from mmcv-full) (0.32.0)
    Requirement already satisfied: mmdet in /home/ci/opt/venv/lib/python3.8/site-packages (2.28.1)
    Requirement already satisfied: six in /home/ci/opt/venv/lib/python3.8/site-packages (from mmdet) (1.16.0)
    Requirement already satisfied: matplotlib in /home/ci/opt/venv/lib/python3.8/site-packages (from mmdet) (3.6.3)
    Requirement already satisfied: pycocotools in /home/ci/opt/venv/lib/python3.8/site-packages (from mmdet) (2.0.6)
    Requirement already satisfied: terminaltables in /home/ci/opt/venv/lib/python3.8/site-packages (from mmdet) (3.1.10)
    Requirement already satisfied: numpy in /home/ci/opt/venv/lib/python3.8/site-packages (from mmdet) (1.23.5)
    Requirement already satisfied: scipy in /home/ci/opt/venv/lib/python3.8/site-packages (from mmdet) (1.10.1)
    Requirement already satisfied: python-dateutil>=2.7 in /home/ci/opt/venv/lib/python3.8/site-packages (from matplotlib->mmdet) (2.8.2)
    Requirement already satisfied: fonttools>=4.22.0 in /home/ci/opt/venv/lib/python3.8/site-packages (from matplotlib->mmdet) (4.38.0)
    Requirement already satisfied: packaging>=20.0 in /home/ci/opt/venv/lib/python3.8/site-packages (from matplotlib->mmdet) (23.0)
    Requirement already satisfied: pyparsing>=2.2.1 in /home/ci/opt/venv/lib/python3.8/site-packages (from matplotlib->mmdet) (3.0.9)
    Requirement already satisfied: pillow>=6.2.0 in /home/ci/opt/venv/lib/python3.8/site-packages (from matplotlib->mmdet) (9.4.0)
    Requirement already satisfied: kiwisolver>=1.0.1 in /home/ci/opt/venv/lib/python3.8/site-packages (from matplotlib->mmdet) (1.4.4)
    Requirement already satisfied: cycler>=0.10 in /home/ci/opt/venv/lib/python3.8/site-packages (from matplotlib->mmdet) (0.11.0)
    Requirement already satisfied: contourpy>=1.0.1 in /home/ci/opt/venv/lib/python3.8/site-packages (from matplotlib->mmdet) (1.0.7)

Also import some other packages that are used in this tutorial:

.. code:: python

    import os
    import time

    from autogluon.core.utils.loaders import load_zip

Downloading Data
~~~~~~~~~~~~~~~~

We have the sample dataset ready in the cloud. Let’s download it:

.. code:: python

    zip_file = "https://automl-mm-bench.s3.amazonaws.com/object_detection_dataset/tiny_motorbike_coco.zip"
    download_dir = "./tiny_motorbike_coco"

    load_zip.unzip(zip_file, unzip_dir=download_dir)
    data_dir = os.path.join(download_dir, "tiny_motorbike")
    train_path = os.path.join(data_dir, "Annotations", "trainval_cocoformat.json")
    test_path = os.path.join(data_dir, "Annotations", "test_cocoformat.json")

.. parsed-literal::
    :class: output

    Downloading ./tiny_motorbike_coco/file.zip from https://automl-mm-bench.s3.amazonaws.com/object_detection_dataset/tiny_motorbike_coco.zip...

.. parsed-literal::
    :class: output

    100%|██████████| 21.8M/21.8M [00:00<00:00, 82.9MiB/s]

When using a dataset in COCO format, the input is the JSON annotation
file of a dataset split. In this example, ``trainval_cocoformat.json`` is
the annotation file of the train-and-validate split, and
``test_cocoformat.json`` is the annotation file of the test split.

Creating the MultiModalPredictor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We select the ``"medium_quality"`` preset, which uses a YOLOX-small model
pretrained on the COCO dataset. This preset is fast to finetune and to
run inference with, and it is easy to deploy. We also provide the
``"high_quality"`` and ``"best_quality"`` presets, which give higher
performance but are slower.

.. code:: python

    presets = "medium_quality"

We create the MultiModalPredictor with the selected preset. We need to set
``problem_type`` to ``"object_detection"`` and provide a
``sample_data_path`` so that the predictor can infer the categories of the
dataset. Here we provide ``train_path``, but any other split of this
dataset works as well. We also provide a ``path`` where the predictor is
saved. If ``path`` is not specified, the predictor is saved to an
automatically generated, timestamped directory under ``AutogluonModels``.

.. code:: python

    # Init predictor
    import uuid

    model_path = f"./tmp/{uuid.uuid4().hex}-quick_start_tutorial_temp_save"

    predictor = MultiModalPredictor(
        problem_type="object_detection",
        sample_data_path=train_path,
        presets=presets,
        path=model_path,
    )

.. parsed-literal::
    :class: output

    Downloading yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth from https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth...

.. parsed-literal::
    :class: output

    load checkpoint from local path: yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth
    The model and loaded state dict do not match exactly

    size mismatch for bbox_head.multi_level_conv_cls.0.weight: copying a param with shape torch.Size([80, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([10, 128, 1, 1]).
    size mismatch for bbox_head.multi_level_conv_cls.0.bias: copying a param with shape torch.Size([80]) from checkpoint, the shape in current model is torch.Size([10]).
    size mismatch for bbox_head.multi_level_conv_cls.1.weight: copying a param with shape torch.Size([80, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([10, 128, 1, 1]).
    size mismatch for bbox_head.multi_level_conv_cls.1.bias: copying a param with shape torch.Size([80]) from checkpoint, the shape in current model is torch.Size([10]).
    size mismatch for bbox_head.multi_level_conv_cls.2.weight: copying a param with shape torch.Size([80, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([10, 128, 1, 1]).
    size mismatch for bbox_head.multi_level_conv_cls.2.bias: copying a param with shape torch.Size([80]) from checkpoint, the shape in current model is torch.Size([10]).

These size-mismatch messages are expected: the checkpoint's classification
layers were trained for the 80 COCO classes, while the model built for
this dataset has 10 output classes, as the shapes in the log show. The
mismatched weights are therefore not loaded from the checkpoint and are
learned during finetuning.

Finetuning the Model
~~~~~~~~~~~~~~~~~~~~

The learning rate, number of epochs, and batch size are included in the
presets, so there is no need to specify them.
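
To make the two-stage learning rate discussed next more concrete, here is
a minimal, framework-free sketch of how such a scheme is commonly
expressed: parameters are split into a backbone group and a head group,
and the head group gets a larger learning rate. This is an illustration,
not AutoMM's implementation. The base learning rate (``1e-4``), the use of
the ``bbox_head.`` prefix to decide which parameters count as the head
(borrowed from the parameter names in the log above), and the other
parameter names are assumptions; only the 100x multiplier comes from this
tutorial.

.. code:: python

    # Framework-free sketch of a two-stage (per-group) learning rate setup.
    # Assumptions, not AutoMM's actual code: head parameters are identified by
    # the "bbox_head." name prefix, and BASE_LR is an illustrative value.
    BASE_LR = 1e-4
    HEAD_LR_MULTIPLIER = 100  # the head uses a 100x larger learning rate


    def build_param_groups(param_names, head_prefix="bbox_head.",
                           base_lr=BASE_LR, multiplier=HEAD_LR_MULTIPLIER):
        """Split parameter names into a backbone group and a head group.

        Each group is a dict with "params" and "lr" keys, the same layout that
        PyTorch optimizers accept (there, "params" would hold tensors, not names).
        """
        backbone = [name for name in param_names if not name.startswith(head_prefix)]
        head = [name for name in param_names if name.startswith(head_prefix)]
        return [
            {"params": backbone, "lr": base_lr},
            {"params": head, "lr": base_lr * multiplier},
        ]


    # Hypothetical parameter names; only the bbox_head ones appear in the log above.
    names = [
        "backbone.stem.conv.weight",
        "neck.out_convs.0.conv.weight",
        "bbox_head.multi_level_conv_cls.0.weight",
        "bbox_head.multi_level_conv_cls.0.bias",
    ]

    for group in build_param_groups(names):
        print(f"lr={group['lr']:g}  params={group['params']}")

In a real PyTorch setup, you would collect the tensors with
``model.named_parameters()`` and pass the resulting groups to an optimizer
such as ``torch.optim.SGD``; only the grouping logic is shown here.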

Note that, by default, we use a two-stage learning rate during
finetuning: the model head uses a learning rate 100x higher than the rest
of the model. Using a high learning rate only on the head layers makes
the model converge faster during finetuning. It usually gives better
performance as well, especially on small datasets with hundreds or
thousands of images. We also time the fit process to give a sense of its
speed. We ran it on a g4dn.2xlarge EC2 instance on AWS; part of the
command output is shown below:

.. code:: python

    start = time.time()
    predictor.fit(train_path)  # Fit
    train_end = time.time()

.. parsed-literal::
    :class: output

    Using default root folder: ./tiny_motorbike_coco/tiny_motorbike/Annotations/... Specify `root=...` if you feel it is wrong...
    Global seed set to 123

.. parsed-literal::
    :class: output

    loading annotations into memory...
    Done (t=0.00s)
    creating index...
    index created!

.. parsed-literal::
    :class: output

    AutoMM starts to create your model. ✨

    - Model will be saved to "/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/quick_start/tmp/342757951a634506b89b87909fa0c7e6-quick_start_tutorial_temp_save".

    - Validation metric is "map".

    - To track the learning progress, you can open a terminal and launch Tensorboard:
        ```shell
        # Assume you have installed tensorboard
        tensorboard --logdir /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/quick_start/tmp/342757951a634506b89b87909fa0c7e6-quick_start_tutorial_temp_save
        ```

    Enjoy your coffee, and let AutoMM do the job ☕☕☕ Learn more at https://auto.gluon.ai

    GPU available: True (cuda), used: True
    TPU available: False, using: 0 TPU cores
    IPU available: False, using: 0 IPUs
    HPU available: False, using: 0 HPUs
    `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

      | Name              | Type                             | Params
    -----------------------------------------------------------------------
    0 | model             | MMDetAutoModelForObjectDetection | 8.9 M
    1 | validation_metric | MeanAveragePrecision             | 0
    -----------------------------------------------------------------------
    8.9 M     Trainable params
    0         Non-trainable params
    8.9 M     Total params
    35.765    Total estimated model params size (MB)
    /home/ci/opt/venv/lib/python3.8/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)
      return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
    Epoch 2, global step 6: 'val_map' reached 0.16740 (best 0.16740), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/quick_start/tmp/342757951a634506b89b87909fa0c7e6-quick_start_tutorial_temp_save/epoch=2-step=6.ckpt' as top 1
    Epoch 5, global step 12: 'val_map' reached 0.23535 (best 0.23535), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/quick_start/tmp/342757951a634506b89b87909fa0c7e6-quick_start_tutorial_temp_save/epoch=5-step=12.ckpt' as top 1
    Epoch 8, global step 18: 'val_map' reached 0.24509 (best 0.24509), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/quick_start/tmp/342757951a634506b89b87909fa0c7e6-quick_start_tutorial_temp_save/epoch=8-step=18.ckpt' as top 1
    Epoch 11, global step 24: 'val_map' was not in top 1
    Epoch 14, global step 30: 'val_map' was not in top 1
    Epoch 17, global step 36: 'val_map' was not in top 1
    AutoMM has created your model 🎉🎉🎉

    - To load the model, use the code below:
        ```python
        from autogluon.multimodal import MultiModalPredictor
        predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/quick_start/tmp/342757951a634506b89b87909fa0c7e6-quick_start_tutorial_temp_save")
        ```

    - You can open a terminal and launch Tensorboard to visualize the training log:
        ```shell
        # Assume you have installed tensorboard
        tensorboard --logdir /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/quick_start/tmp/342757951a634506b89b87909fa0c7e6-quick_start_tutorial_temp_save
        ```

    - If you are not satisfied with the model, try to increase the training time, adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html), or post issues on GitHub: https://github.com/autogluon/autogluon

Notice that whenever the checkpoint of the current stage is saved, the
model’s save path is printed at the end of the progress bar. In this
example, it is the ``model_path`` defined above, i.e. a directory under
``./tmp`` whose name ends in ``-quick_start_tutorial_temp_save``.

Printing the elapsed time shows that finetuning is fast:

.. code:: python

    print("This finetuning takes %.2f seconds." % (train_end - start))

.. parsed-literal::
    :class: output

    This finetuning takes 106.64 seconds.

Evaluation
~~~~~~~~~~

To evaluate the model we just trained, run the following code. The
evaluation results are shown in the command line output. The first line
is the mAP under the COCO standard, and the second line is the mAP under
the VOC standard (mAP50). For more details about these metrics, see
`COCO’s evaluation guideline `__.

Note that we use the ``"medium_quality"`` preset here to keep finetuning
fast. You can get better results on this dataset by simply using the
``"high_quality"`` or ``"best_quality"`` preset, or by customizing your
own model and hyperparameter settings (see
:ref:`sec_automm_customization`, and further examples in
:ref:`sec_automm_detection_fast_ft_coco` and
:ref:`sec_automm_detection_high_ft_coco`).

.. code:: python

    predictor.evaluate(test_path)
    eval_end = time.time()

.. parsed-literal::
    :class: output

    Using default root folder: ./tiny_motorbike_coco/tiny_motorbike/Annotations/... Specify `root=...` if you feel it is wrong...

.. parsed-literal::
    :class: output

    loading annotations into memory...
    Done (t=0.00s)
    creating index...
    index created!

.. parsed-literal::
    :class: output

    A new predictor save path is created.This is to prevent you to overwrite previous predictor saved here.You could check current save path at predictor._save_path.If you still want to use this path, set resume=True
    No path specified. Models will be saved in: "AutogluonModels/ag-20230222_233413/"

.. parsed-literal::
    :class: output

    saving file at /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/quick_start/AutogluonModels/ag-20230222_233413/object_detection_result_cache.json
    loading annotations into memory...
    Done (t=0.00s)
    creating index...
    index created!
    Loading and preparing results...
    DONE (t=0.00s)
    creating index...
    index created!
    Running per image evaluation...
    Evaluate annotation type *bbox*
    DONE (t=0.07s).
    Accumulating evaluation results...
    DONE (t=0.04s).
    Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.297
    Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.498
    Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.310
    Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.182
    Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.316
    Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.610
    Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.220
    Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.372
    Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.381
    Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.279
    Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.523
    Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.671

Print out the evaluation time:

.. code:: python

    print("The evaluation takes %.2f seconds." % (eval_end - train_end))

.. parsed-literal::
    :class: output

    The evaluation takes 1.19 seconds.

We can load a new predictor from the previous save path (``model_path``),
and we can also reset the number of GPUs to use if not all devices are
available:

.. code:: python

    # Load and reset num_gpus
    new_predictor = MultiModalPredictor.load(model_path)
    new_predictor.set_num_gpus(1)

.. parsed-literal::
    :class: output

    Load pretrained checkpoint: /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/quick_start/tmp/342757951a634506b89b87909fa0c7e6-quick_start_tutorial_temp_save/model.ckpt

Evaluating the new predictor gives exactly the same result:

.. code:: python

    # Evaluate new predictor
    new_predictor.evaluate(test_path)

.. parsed-literal::
    :class: output

    Using default root folder: ./tiny_motorbike_coco/tiny_motorbike/Annotations/... Specify `root=...` if you feel it is wrong...

.. parsed-literal::
    :class: output

    loading annotations into memory...
    Done (t=0.00s)
    creating index...
    index created!

.. parsed-literal::
    :class: output

    A new predictor save path is created.This is to prevent you to overwrite previous predictor saved here.You could check current save path at predictor._save_path.If you still want to use this path, set resume=True
    No path specified. Models will be saved in: "AutogluonModels/ag-20230222_233415/"

.. parsed-literal::
    :class: output

    saving file at /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/quick_start/AutogluonModels/ag-20230222_233415/object_detection_result_cache.json
    loading annotations into memory...
    Done (t=0.00s)
    creating index...
    index created!
    Loading and preparing results...
    DONE (t=0.00s)
    creating index...
    index created!
    Running per image evaluation...
    Evaluate annotation type *bbox*
    DONE (t=0.07s).
    Accumulating evaluation results...
    DONE (t=0.04s).
    Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.297
    Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.498
    Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.310
    Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.182
    Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.316
    Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.610
    Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.220
    Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.372
    Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.381
    Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.279
    Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.523
    Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.671

.. parsed-literal::
    :class: output

    {'map': 0.2965708693031757,
     'mean_average_precision': 0.2965708693031757,
     'map_50': 0.49750597617301967,
     'map_75': 0.3098997668751622,
     'map_small': 0.18151102699385396,
     'map_medium': 0.3164229109319443,
     'map_large': 0.6101904340383907,
     'mar_1': 0.2204355179704017,
     'mar_10': 0.3720883251115809,
     'mar_100': 0.3812614517265679,
     'mar_small': 0.2785416666666667,
     'mar_medium': 0.5231746031746032,
     'mar_large': 0.6709851551956815}

For how to set the hyperparameters and finetune the model for higher
performance, see :ref:`sec_automm_detection_high_ft_coco`.

Inference
~~~~~~~~~

Now that we have gone through model setup, finetuning, and evaluation,
this section covers inference. Specifically, we lay out the steps for
using the model to make predictions and visualizing the results.

To run inference on the entire test set, run:

.. code:: python

    pred = predictor.predict(test_path)
    print(pred)

.. parsed-literal::
    :class: output

    Using default root folder: ./tiny_motorbike_coco/tiny_motorbike/Annotations/... Specify `root=...` if you feel it is wrong...

.. parsed-literal::
    :class: output

    loading annotations into memory...
    Done (t=0.00s)
    creating index...
    index created!
                                                    image  \
    0   ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    1   ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    2   ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    3   ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    4   ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    5   ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    6   ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    7   ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    8   ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    9   ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    10  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    11  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    12  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    13  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    14  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    15  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    16  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    17  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    18  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    19  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    20  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    21  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    22  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    23  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    24  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    25  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    26  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    27  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    28  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    29  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    30  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    31  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    32  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    33  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    34  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    35  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    36  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    37  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    38  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    39  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    40  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    41  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    42  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    43  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    44  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    45  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    46  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    47  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    48  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...
    49  ./tiny_motorbike_coco/tiny_motorbike/Annotatio...

                                                   bboxes
    0   [{'class': 'bicycle', 'bbox': [159.08742, 173....
    1   [{'class': 'bus', 'bbox': [0.7754028, 162.5248...
    2   [{'class': 'car', 'bbox': [457.092, 112.7427, ...
    3   [{'class': 'motorbike', 'bbox': [15.9307, 35.5...
    4   [{'class': 'motorbike', 'bbox': [92.80351, 198...
    5   [{'class': 'car', 'bbox': [11.939812, 41.06825...
    6   [{'class': 'car', 'bbox': [228.71475, 237.9081...
    7   [{'class': 'motorbike', 'bbox': [98.754776, 14...
    8   [{'class': 'bicycle', 'bbox': [27.409649, 36.7...
    9   [{'class': 'bicycle', 'bbox': [353.5605, 2.437...
    10  [{'class': 'car', 'bbox': [486.20712, 127.6418...
    11  [{'class': 'motorbike', 'bbox': [8.630514, 206...
    12  [{'class': 'motorbike', 'bbox': [191.91798, 12...
    13  [{'class': 'motorbike', 'bbox': [48.085285, 42...
    14  [{'class': 'motorbike', 'bbox': [140.77747, 11...
    15  [{'class': 'motorbike', 'bbox': [188.9875, 176...
    16  [{'class': 'car', 'bbox': [477.8009, 262.42056...
    17  [{'class': 'car', 'bbox': [1.4333069, 263.6177...
    18  [{'class': 'motorbike', 'bbox': [80.230515, 70...
    19  [{'class': 'car', 'bbox': [1.975226, 89.24797,...
    20  [{'class': 'motorbike', 'bbox': [82.57903, 249...
    21  [{'class': 'motorbike', 'bbox': [65.93978, 170...
    22  [{'class': 'car', 'bbox': [331.807, 129.59215,...
    23  [{'class': 'motorbike', 'bbox': [15.910912, 70...
    24  [{'class': 'motorbike', 'bbox': [48.33057, 27....
    25  [{'class': 'motorbike', 'bbox': [78.24378, 131...
    26  [{'class': 'motorbike', 'bbox': [69.70891, 102...
    27  [{'class': 'car', 'bbox': [0.24610162, 4.27230...
    28  [{'class': 'motorbike', 'bbox': [26.056004, 11...
    29  [{'class': 'car', 'bbox': [306.49298, 125.6439...
    30  [{'class': 'car', 'bbox': [486.70847, 59.88590...
    31  [{'class': 'bicycle', 'bbox': [382.49078, 124....
    32  [{'class': 'motorbike', 'bbox': [113.48841, 16...
    33  [{'class': 'motorbike', 'bbox': [18.617868, 35...
    34  [{'class': 'car', 'bbox': [14.372659, 2.243641...
    35  [{'class': 'car', 'bbox': [377.02277, 43.59979...
    36  [{'class': 'motorbike', 'bbox': [4.1884418, 23...
    37  [{'class': 'motorbike', 'bbox': [10.583711, 90...
    38  [{'class': 'motorbike', 'bbox': [12.802935, 31...
    39  [{'class': 'motorbike', 'bbox': [155.0082, 31....
    40  [{'class': 'car', 'bbox': [-0.9064317, 316.148...
    41  [{'class': 'motorbike', 'bbox': [-1.077795, 0....
    42  [{'class': 'car', 'bbox': [493.76685, 115.5181...
    43  [{'class': 'bus', 'bbox': [3.4678698, 10.38836...
    44  [{'class': 'motorbike', 'bbox': [57.490456, 13...
    45  [{'class': 'motorbike', 'bbox': [171.12007, 11...
    46  [{'class': 'motorbike', 'bbox': [69.957245, 10...
    47  [{'class': 'bus', 'bbox': [290.71326, 19.16719...
    48  [{'class': 'car', 'bbox': [0.4284978, 3.182631...
    49  [{'class': 'motorbike', 'bbox': [8.759499, 209...

The output ``pred`` is a ``pandas`` ``DataFrame`` with two columns,
``image`` and ``bboxes``. In ``image``, each row contains the image path.
In ``bboxes``, each row is a list of dictionaries, each one representing
a bounding box: ``{"class": , "bbox": [x1, y1, x2, y2], "score": }``

Note that, by default, ``predictor.predict`` does not save the detection
results to a file. To run inference and save the results, run the
following:

.. code:: python

    pred = predictor.predict(test_path, save_results=True)

.. parsed-literal::
    :class: output

    Using default root folder: ./tiny_motorbike_coco/tiny_motorbike/Annotations/... Specify `root=...` if you feel it is wrong...

.. parsed-literal::
    :class: output

    loading annotations into memory...
    Done (t=0.00s)
    creating index...
    index created!

.. parsed-literal::
    :class: output

    A new predictor save path is created.This is to prevent you to overwrite previous predictor saved here.You could check current save path at predictor._save_path.If you still want to use this path, set resume=True
    No path specified. Models will be saved in: "AutogluonModels/ag-20230222_233417/"
    Saved detection results to /home/ci/autogluon/docs/_build/eval/tutorials/multimodal/object_detection/quick_start/AutogluonModels/ag-20230222_233417/result.txt

Here, we save ``pred`` into a ``.txt`` file that follows exactly the same
layout as ``pred``. You can use a predictor initialized in any way (e.g.
a finetuned predictor or a predictor with a pretrained model).

Visualizing Results
~~~~~~~~~~~~~~~~~~~

To run visualizations, ensure that you have ``opencv`` installed. If you
haven’t already, install ``opencv`` by running:

.. code:: python

    !pip install opencv-python

.. parsed-literal::
    :class: output

    Requirement already satisfied: opencv-python in /home/ci/opt/venv/lib/python3.8/site-packages (4.7.0.72)
    Requirement already satisfied: numpy>=1.17.3 in /home/ci/opt/venv/lib/python3.8/site-packages (from opencv-python) (1.23.5)

To visualize the detection bounding boxes, run the following:

.. code:: python

    from IPython.display import display
    from PIL import Image

    from autogluon.multimodal.utils import Visualizer

    conf_threshold = 0.4  # Specify a confidence threshold to filter out unwanted boxes
    image_result = pred.iloc[30]

    img_path = image_result.image  # Select an image to visualize

    visualizer = Visualizer(img_path)  # Initialize the Visualizer
    out = visualizer.draw_instance_predictions(image_result, conf_threshold=conf_threshold)  # Draw detections
    visualized = out.get_image()  # Get the visualized image

    img = Image.fromarray(visualized, "RGB")
    display(img)

.. figure:: output_quick_start_coco_f6564b_31_0.png

Testing on Your Own Image
~~~~~~~~~~~~~~~~~~~~~~~~~

You can also download an image and run inference on that single image.
The following is an example.

Download the example image:

.. code:: python

    from autogluon.multimodal import download

    image_url = "https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/detection/street_small.jpg"
    test_image = download(image_url)

.. parsed-literal::
    :class: output

    Downloading street_small.jpg from https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/detection/street_small.jpg...

Run inference:

.. code:: python

    pred_test_image = predictor.predict({"image": [test_image]})
    print(pred_test_image)

.. parsed-literal::
    :class: output

                  image                                             bboxes
    0  street_small.jpg  [{'class': 'car', 'bbox': [235.07066, 211.6830...

Other Examples
~~~~~~~~~~~~~~

See `AutoMM Examples `__ to explore other examples of AutoMM.

Customization
~~~~~~~~~~~~~

To learn how to customize AutoMM, please refer to
:ref:`sec_automm_customization`.

Citation
~~~~~~~~

::

    @article{DBLP:journals/corr/abs-2107-08430,
      author     = {Zheng Ge and Songtao Liu and Feng Wang and Zeming Li and Jian Sun},
      title      = {{YOLOX:} Exceeding {YOLO} Series in 2021},
      journal    = {CoRR},
      volume     = {abs/2107.08430},
      year       = {2021},
      url        = {https://arxiv.org/abs/2107.08430},
      eprinttype = {arXiv},
      eprint     = {2107.08430},
      timestamp  = {Tue, 05 Apr 2022 14:09:44 +0200},
      biburl     = {https://dblp.org/rec/journals/corr/abs-2107-08430.bib},
      bibsource  = {dblp computer science bibliography, https://dblp.org},
    }