{ "cells": [ { "cell_type": "markdown", "id": "2325502b", "metadata": {}, "source": [ "# Customize AutoMM\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/autogluon/autogluon/blob/stable/docs/tutorials/multimodal/advanced_topics/customization.ipynb)\n", "[![Open In SageMaker Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/autogluon/autogluon/blob/stable/docs/tutorials/multimodal/advanced_topics/customization.ipynb)\n", "\n", "AutoMM has a powerful yet easy-to-use configuration design.\n", "This tutorial walks you through the AutoMM configuration options so that you can customize its behavior. Specifically, AutoMM configurations consist of several parts:\n", "\n", "- optimization\n", "- environment\n", "- model\n", "- data\n", "- distiller\n", "\n", "## Optimization\n", "\n", "### optimization.learning_rate\n", "\n", "The base learning rate of the optimizer." ] }, { "cell_type": "markdown", "id": "0f349cfb", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.learning_rate\": 1.0e-4})\n", "# set the learning rate to 5.0e-4\n", "predictor.fit(hyperparameters={\"optimization.learning_rate\": 5.0e-4})\n", "```\n" ] }, { "cell_type": "markdown", "id": "8ea63ec5", "metadata": {}, "source": [ "### optimization.optim_type\n", "\n", "Optimizer type.\n", "\n", "- `\"sgd\"`: stochastic gradient descent with momentum.\n", "- `\"adam\"`: a stochastic gradient descent method based on adaptive estimates of first-order and second-order moments. See [this paper](https://arxiv.org/abs/1412.6980) for details.\n", "- `\"adamw\"`: improves Adam by decoupling weight decay from the gradient-based update. See [this paper](https://arxiv.org/abs/1711.05101) for details."
] }, { "cell_type": "markdown", "id": "37fcd9e2", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.optim_type\": \"adamw\"})\n", "# use the Adam optimizer\n", "predictor.fit(hyperparameters={\"optimization.optim_type\": \"adam\"})\n", "```\n" ] }, { "cell_type": "markdown", "id": "aeb40726", "metadata": {}, "source": [ "### optimization.weight_decay\n", "\n", "The weight decay coefficient used by the optimizer." ] }, { "cell_type": "markdown", "id": "852daaee", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.weight_decay\": 1.0e-3})\n", "# set weight decay to 1.0e-4\n", "predictor.fit(hyperparameters={\"optimization.weight_decay\": 1.0e-4})\n", "```\n" ] }, { "cell_type": "markdown", "id": "7006101a", "metadata": {}, "source": [ "### optimization.lr_decay\n", "\n", "With the `\"layerwise_decay\"` lr choice, later layers can have larger learning rates than earlier layers. The last/head layer\n", "has the largest learning rate, `optimization.learning_rate`. For a model with `n` layers, layer `i` has learning rate `optimization.learning_rate * optimization.lr_decay^(n-i)`. To use one uniform learning rate, set the learning rate decay to `1`." ] }, { "cell_type": "markdown", "id": "adb57179", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.lr_decay\": 0.9})\n", "# turn off learning rate decay\n", "predictor.fit(hyperparameters={\"optimization.lr_decay\": 1})\n", "```\n" ] }, { "cell_type": "markdown", "id": "59632914", "metadata": {}, "source": [ "### optimization.lr_mult\n", "\n", "With the `\"two_stages\"` lr choice, the last/head layer uses the largest learning rate, `optimization.learning_rate * optimization.lr_mult`,\n", "while all other layers use the base learning rate `optimization.learning_rate`.\n", "To use one uniform learning rate, set the learning rate multiplier to `1`."
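, "\n",
"\n",
"As a rough illustration of the two per-layer rules above (`optimization.lr_decay` for `layerwise_decay` and `optimization.lr_mult` for `two_stages`), the sketch below computes one learning rate per layer. It is a minimal sketch, not AutoMM's implementation. The function names are hypothetical, and layers are assumed to be indexed from 1 (earliest) to `n` (head).\n",
"\n",
"```python\n",
"# Hypothetical helpers, not part of the AutoMM API.\n",
"# Layers are indexed from 1 (earliest) to n (last/head layer).\n",
"\n",
"def layerwise_decay_lrs(learning_rate, lr_decay, n):\n",
"    # layer i gets learning_rate * lr_decay ** (n - i); the head layer (i = n) gets learning_rate\n",
"    return [learning_rate * lr_decay ** (n - i) for i in range(1, n + 1)]\n",
"\n",
"def two_stages_lrs(learning_rate, lr_mult, n):\n",
"    # every layer uses learning_rate except the head layer, which uses learning_rate * lr_mult\n",
"    return [learning_rate] * (n - 1) + [learning_rate * lr_mult]\n",
"\n",
"print(layerwise_decay_lrs(1.0e-4, 0.9, 3))  # the earliest layer gets the smallest rate\n",
"print(two_stages_lrs(1.0e-4, 10, 3))  # the head layer gets 10x the base rate\n",
"```"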
] }, { "cell_type": "markdown", "id": "1770634e", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.lr_mult\": 1})\n", "# use a 10x learning rate for the head layer, which takes effect with the two_stages lr choice\n", "predictor.fit(hyperparameters={\"optimization.lr_choice\": \"two_stages\", \"optimization.lr_mult\": 10})\n", "```\n" ] }, { "cell_type": "markdown", "id": "60b67198", "metadata": {}, "source": [ "### optimization.lr_choice\n", "\n", "Different layers can use different learning rates. Two strategies are supported: the `\"two_stages\"` lr choice (see the `optimization.lr_mult` section for details)\n", "and the `\"layerwise_decay\"` lr choice (see the `optimization.lr_decay` section for details).\n", "To use one uniform learning rate for all layers, set this to `\"\"`." ] }, { "cell_type": "markdown", "id": "b3dd4814", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.lr_choice\": \"layerwise_decay\"})\n", "# turn on the two-stage lr choice\n", "predictor.fit(hyperparameters={\"optimization.lr_choice\": \"two_stages\"})\n", "```\n" ] }, { "cell_type": "markdown", "id": "31355e32", "metadata": {}, "source": [ "### optimization.lr_schedule\n", "\n", "Learning rate schedule.\n", "\n", "- `\"cosine_decay\"`: the learning rate decays following a cosine curve.\n", "- `\"polynomial_decay\"`: the learning rate decays according to a polynomial function.\n", "- `\"linear_decay\"`: the learning rate decays linearly." ] }, { "cell_type": "markdown", "id": "100a6a22", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.lr_schedule\": \"cosine_decay\"})\n", "# use polynomial decay\n", "predictor.fit(hyperparameters={\"optimization.lr_schedule\": \"polynomial_decay\"})\n", "```\n" ] }, { "cell_type": "markdown", "id": "5ce9a4af", "metadata": {}, "source": [ "### optimization.max_epochs\n", "\n", "Stop training once this number of epochs is reached."
] }, { "cell_type": "markdown", "id": "30a8c032", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.max_epochs\": 10})\n", "# train for 20 epochs\n", "predictor.fit(hyperparameters={\"optimization.max_epochs\": 20})\n", "```\n" ] }, { "cell_type": "markdown", "id": "53264b7c", "metadata": {}, "source": [ "### optimization.max_steps\n", "\n", "Stop training after this number of steps. Training stops as soon as either `optimization.max_steps` or `optimization.max_epochs` is reached, whichever comes first.\n", "By default, `optimization.max_steps` is disabled by setting it to -1." ] }, { "cell_type": "markdown", "id": "2482fc3e", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.max_steps\": -1})\n", "# train for 100 steps\n", "predictor.fit(hyperparameters={\"optimization.max_steps\": 100})\n", "```\n" ] }, { "cell_type": "markdown", "id": "cd6fa991", "metadata": {}, "source": [ "### optimization.warmup_steps\n", "\n", "Warm up the learning rate from 0 to `optimization.learning_rate` over this fraction of the total training steps at the beginning of training." ] }, { "cell_type": "markdown", "id": "34d3a967", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.warmup_steps\": 0.1})\n", "# warm up the learning rate over the first 20% of steps\n", "predictor.fit(hyperparameters={\"optimization.warmup_steps\": 0.2})\n", "```\n" ] }, { "cell_type": "markdown", "id": "80452db2", "metadata": {}, "source": [ "### optimization.patience\n", "\n", "Stop training after this number of checks with no improvement. The check frequency is controlled by `optimization.val_check_interval`."
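, "\n",
"\n",
"To make the patience rule above concrete, here is a minimal sketch of patience-based early stopping. It assumes that a higher validation score is better and that each list entry is the result of one validation check. The function name is hypothetical, and the sketch ignores details such as minimum-improvement thresholds.\n",
"\n",
"```python\n",
"# Hypothetical helper, not part of the AutoMM API.\n",
"def stop_check_index(val_scores, patience):\n",
"    best = float('-inf')\n",
"    checks_without_improvement = 0\n",
"    for idx, score in enumerate(val_scores):\n",
"        if score > best:\n",
"            # improvement: remember the new best score and reset the counter\n",
"            best = score\n",
"            checks_without_improvement = 0\n",
"        else:\n",
"            checks_without_improvement += 1\n",
"            if checks_without_improvement >= patience:\n",
"                return idx  # training stops after this check\n",
"    return None  # patience never ran out\n",
"\n",
"scores = [0.50, 0.60, 0.60, 0.55, 0.58]\n",
"print(stop_check_index(scores, patience=3))  # 4\n",
"print(stop_check_index(scores, patience=5))  # None\n",
"```"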
] }, { "cell_type": "markdown", "id": "b3bcd482", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.patience\": 10})\n", "# set patience to 5 checks\n", "predictor.fit(hyperparameters={\"optimization.patience\": 5})\n", "```\n" ] }, { "cell_type": "markdown", "id": "2765c653", "metadata": {}, "source": [ "### optimization.val_check_interval\n", "\n", "How often to check the validation set within one training epoch. It can be a float or an int.\n", "\n", "- Pass a float in the range [0.0, 1.0] to check after a fraction of the training epoch.\n", "- Pass an int to check after a fixed number of training batches." ] }, { "cell_type": "markdown", "id": "2ee8c226", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.val_check_interval\": 0.5})\n", "# check the validation set 4 times per training epoch\n", "predictor.fit(hyperparameters={\"optimization.val_check_interval\": 0.25})\n", "```\n" ] }, { "cell_type": "markdown", "id": "e28553d8", "metadata": {}, "source": [ "### optimization.gradient_clip_algorithm\n", "\n", "The gradient clipping algorithm to use. Gradients can be clipped either by value or by norm." ] }, { "cell_type": "markdown", "id": "7526131f", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.gradient_clip_algorithm\": \"norm\"})\n", "# clip gradients by value\n", "predictor.fit(hyperparameters={\"optimization.gradient_clip_algorithm\": \"value\"})\n", "```\n" ] }, { "cell_type": "markdown", "id": "c29d0454", "metadata": {}, "source": [ "### optimization.gradient_clip_val\n", "\n", "The gradient clipping threshold. Depending on `optimization.gradient_clip_algorithm`, it caps either the absolute value of each gradient element or the gradient norm."
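, "\n",
"\n",
"The difference between the two clipping algorithms can be sketched on a flat list of gradient values. This is a simplified illustration that assumes all gradients have been gathered into one list. Real training clips framework tensors, and the function names here are hypothetical.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"# Hypothetical helpers, not part of the AutoMM API.\n",
"def clip_by_value(grads, clip_val):\n",
"    # clamp every element into [-clip_val, clip_val]\n",
"    return [max(-clip_val, min(clip_val, g)) for g in grads]\n",
"\n",
"def clip_by_norm(grads, clip_val):\n",
"    # rescale all elements together so that their L2 norm is at most clip_val\n",
"    total_norm = math.sqrt(sum(g * g for g in grads))\n",
"    if total_norm <= clip_val:\n",
"        return list(grads)\n",
"    scale = clip_val / total_norm\n",
"    return [g * scale for g in grads]\n",
"\n",
"grads = [3.0, -4.0, 0.5]\n",
"print(clip_by_value(grads, 1.0))  # [1.0, -1.0, 0.5]\n",
"print(clip_by_norm(grads, 1.0))  # same direction, L2 norm of about 1.0\n",
"```\n",
"\n",
"Clipping by value can change the direction of the gradient vector. Clipping by norm keeps the direction and only shrinks the magnitude."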
] }, { "cell_type": "markdown", "id": "50e90350", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.gradient_clip_val\": 1})\n", "# clip gradients at 5\n", "predictor.fit(hyperparameters={\"optimization.gradient_clip_val\": 5})\n", "```\n" ] }, { "cell_type": "markdown", "id": "02d07866", "metadata": {}, "source": [ "### optimization.track_grad_norm\n", "\n", "Track the p-norm of gradients during training. Set it to `'inf'` to track the infinity norm. If Automatic Mixed Precision (AMP) is used, the gradients are unscaled before being logged." ] }, { "cell_type": "markdown", "id": "1b60c371", "metadata": {}, "source": [ "```\n", "# default used by AutoMM (no tracking)\n", "predictor.fit(hyperparameters={\"optimization.track_grad_norm\": -1})\n", "# track the 2-norm\n", "predictor.fit(hyperparameters={\"optimization.track_grad_norm\": 2})\n", "```\n" ] }, { "cell_type": "markdown", "id": "abe87d32", "metadata": {}, "source": [ "### optimization.log_every_n_steps\n", "\n", "How often to log metrics, in training steps." ] }, { "cell_type": "markdown", "id": "4f5fe49c", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.log_every_n_steps\": 10})\n", "# log once every 50 steps\n", "predictor.fit(hyperparameters={\"optimization.log_every_n_steps\": 50})\n", "```\n" ] }, { "cell_type": "markdown", "id": "28f30c7e", "metadata": {}, "source": [ "### optimization.top_k\n", "\n", "Based on the validation score, select the top k model checkpoints for model averaging."
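, "\n",
"\n",
"The selection and uniform averaging steps can be sketched in plain Python. This is a minimal illustration, not AutoMM's checkpoint code. Each checkpoint is assumed to be a (validation score, parameters) pair in which a higher score is better, and parameters are plain floats instead of tensors. The function names are hypothetical.\n",
"\n",
"```python\n",
"# Hypothetical helpers, not part of the AutoMM API.\n",
"def top_k_checkpoints(checkpoints, k):\n",
"    # sort by validation score, best first, and keep the first k\n",
"    return sorted(checkpoints, key=lambda c: c[0], reverse=True)[:k]\n",
"\n",
"def uniform_average(checkpoints):\n",
"    # average every parameter across the selected checkpoints\n",
"    names = checkpoints[0][1].keys()\n",
"    return {name: sum(params[name] for _, params in checkpoints) / len(checkpoints) for name in names}\n",
"\n",
"ckpts = [(0.80, {'w': 1.0}), (0.85, {'w': 3.0}), (0.70, {'w': 10.0}), (0.90, {'w': 2.0})]\n",
"best = top_k_checkpoints(ckpts, 3)\n",
"print(uniform_average(best))  # {'w': 2.0}\n",
"```"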
] }, { "cell_type": "markdown", "id": "17258cf4", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.top_k\": 3})\n", "# use the top 5 checkpoints\n", "predictor.fit(hyperparameters={\"optimization.top_k\": 5})\n", "```\n" ] }, { "cell_type": "markdown", "id": "4a9f1cfc", "metadata": {}, "source": [ "### optimization.top_k_average_method\n", "\n", "The strategy used to average the top k model checkpoints.\n", "\n", "- `\"greedy_soup\"`: adds the checkpoints to the averaging pool from best to worst and stops once the performance of the averaged checkpoint decreases. See [the paper](https://arxiv.org/pdf/2203.05482.pdf) for details.\n", "- `\"uniform_soup\"`: averages all the top k checkpoints into the final checkpoint.\n", "- `\"best\"`: picks the checkpoint with the best validation performance." ] }, { "cell_type": "markdown", "id": "8236ab40", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.top_k_average_method\": \"greedy_soup\"})\n", "# average all the top k checkpoints\n", "predictor.fit(hyperparameters={\"optimization.top_k_average_method\": \"uniform_soup\"})\n", "```\n" ] }, { "cell_type": "markdown", "id": "652f15b4", "metadata": {}, "source": [ "### optimization.efficient_finetune\n", "\n", "Options for parameter-efficient finetuning. Parameter-efficient finetuning updates only a small subset of parameters instead of the whole pretrained backbone.\n", "\n", "- `\"bit_fit\"`: bias parameters only. See [this paper](https://arxiv.org/pdf/2106.10199.pdf) for details.\n", "- `\"norm_fit\"`: normalization parameters + bias parameters. See [this paper](https://arxiv.org/pdf/2003.00152.pdf) for details.\n", "- `\"lora\"`: LoRA adapters. See [this paper](https://arxiv.org/pdf/2106.09685.pdf) for details.\n",
"- `\"lora_bias\"`: LoRA adapters + bias parameters.\n", "- `\"lora_norm\"`: LoRA adapters + normalization parameters + bias parameters.\n", "- `\"ia3\"`: IA3 algorithm. See [this paper](https://arxiv.org/abs/2205.05638) for details.\n", "- `\"ia3_bias\"`: IA3 + bias parameters.\n", "- `\"ia3_norm\"`: IA3 + normalization parameters + bias parameters." ] }, { "cell_type": "markdown", "id": "1a072add", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.efficient_finetune\": None})\n", "# finetune only bias parameters\n", "predictor.fit(hyperparameters={\"optimization.efficient_finetune\": \"bit_fit\"})\n", "# finetune with IA3 + BitFit\n", "predictor.fit(hyperparameters={\"optimization.efficient_finetune\": \"ia3_bias\"})\n", "```\n" ] }, { "cell_type": "markdown", "id": "a138ce84-1462-4d67-a82e-80e829a7a57d", "metadata": {}, "source": [ "### optimization.skip_final_val\n", "\n", "Whether to skip the final validation after training is signaled to stop." ] }, { "cell_type": "markdown", "id": "6830fa7f-d6ef-4578-9efd-16923fca0918", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"optimization.skip_final_val\": False})\n", "# skip the final validation\n", "predictor.fit(hyperparameters={\"optimization.skip_final_val\": True})\n", "```\n" ] }, { "cell_type": "markdown", "id": "e3cfbd99", "metadata": {}, "source": [ "## Environment\n", "\n", "### env.num_gpus\n", "\n", "The number of GPUs to use. If set to -1, all available GPUs are counted via `env.num_gpus = torch.cuda.device_count()`."
] }, { "cell_type": "markdown", "id": "f6908008", "metadata": {}, "source": [ "```\n", "# by default, AutoMM uses all available GPUs\n", "predictor.fit(hyperparameters={\"env.num_gpus\": -1})\n", "# use only 1 GPU\n", "predictor.fit(hyperparameters={\"env.num_gpus\": 1})\n", "```\n" ] }, { "cell_type": "markdown", "id": "d5e3c075", "metadata": {}, "source": [ "### env.per_gpu_batch_size\n", "\n", "The batch size for each GPU." ] }, { "cell_type": "markdown", "id": "692d5f4f", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"env.per_gpu_batch_size\": 8})\n", "# use a batch size of 16 per GPU\n", "predictor.fit(hyperparameters={\"env.per_gpu_batch_size\": 16})\n", "```\n" ] }, { "cell_type": "markdown", "id": "3a23b8cc", "metadata": {}, "source": [ "### env.batch_size\n", "\n", "The effective batch size of each training step. If `env.batch_size` is larger than `env.per_gpu_batch_size * env.num_gpus`, gradients are accumulated until the effective `env.batch_size` is reached, and then one optimization step is performed. The number of accumulation steps is `env.batch_size // (env.per_gpu_batch_size * env.num_gpus)`." ] }, { "cell_type": "markdown", "id": "14b5a0c2", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"env.batch_size\": 128})\n", "# use a batch size of 256\n", "predictor.fit(hyperparameters={\"env.batch_size\": 256})\n", "```\n" ] }, { "cell_type": "markdown", "id": "fdf820c7", "metadata": {}, "source": [ "### env.eval_batch_size_ratio\n", "\n", "Prediction and evaluation use a larger per-GPU batch size, `env.per_gpu_batch_size * env.eval_batch_size_ratio`."
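, "\n",
"\n",
"The arithmetic behind `env.batch_size`, `env.per_gpu_batch_size`, and `env.eval_batch_size_ratio` can be sketched as follows. The helper names are hypothetical. Clamping the accumulation steps to at least 1 is an assumption made for the case where `env.batch_size` is not larger than `env.per_gpu_batch_size * env.num_gpus`.\n",
"\n",
"```python\n",
"# Hypothetical helpers, not part of the AutoMM API.\n",
"def accumulation_steps(batch_size, per_gpu_batch_size, num_gpus):\n",
"    # forward/backward passes accumulated before one optimizer step\n",
"    # (clamped to at least 1, an assumption made for this sketch)\n",
"    return max(1, batch_size // (per_gpu_batch_size * num_gpus))\n",
"\n",
"def eval_per_gpu_batch_size(per_gpu_batch_size, eval_batch_size_ratio):\n",
"    # per-GPU batch size used for prediction or evaluation\n",
"    return per_gpu_batch_size * eval_batch_size_ratio\n",
"\n",
"# defaults from this tutorial, on a single GPU\n",
"print(accumulation_steps(128, 8, 1))  # 16\n",
"print(eval_per_gpu_batch_size(8, 4))  # 32\n",
"```"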
] }, { "cell_type": "markdown", "id": "c3d8a8ca", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"env.eval_batch_size_ratio\": 4})\n", "# use 2x the per-GPU batch size during prediction or evaluation\n", "predictor.fit(hyperparameters={\"env.eval_batch_size_ratio\": 2})\n", "```\n" ] }, { "cell_type": "markdown", "id": "6d492098", "metadata": {}, "source": [ "### env.precision\n", "\n", "Supports double (`64`, `\"64\"`, `\"64-true\"`), float (`32`, `\"32\"`, `\"32-true\"`), bfloat16 (`\"bf16-mixed\"`, `\"bf16-true\"`), or float16 (`\"16-mixed\"`, `\"16-true\"`) precision training. For more details, see the [Lightning documentation](https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision).\n", "\n", "Mixed precision such as `\"16-mixed\"` combines 32-bit and 16-bit floating point to reduce the memory footprint during model training. This can improve performance, with speedups of 3x or more on modern GPUs." ] }, { "cell_type": "markdown", "id": "3aa8d934", "metadata": {}, "source": [ "### env.num_workers\n", "\n", "The number of worker processes used by the PyTorch dataloader in training. Note that more workers do not always bring a speedup, especially when `env.strategy = \"ddp_spawn\"`.\n", "For more details, see the guideline [here](https://lightning.ai/docs/pytorch/stable/accelerators/gpu_intermediate.html#distributed-data-parallel)."
] }, { "cell_type": "markdown", "id": "789bed40", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"env.num_workers\": 2})\n", "# use 4 workers in the training dataloader\n", "predictor.fit(hyperparameters={\"env.num_workers\": 4})\n", "```\n" ] }, { "cell_type": "markdown", "id": "86faccf9", "metadata": {}, "source": [ "### env.num_workers_evaluation\n", "\n", "The number of worker processes used by the PyTorch dataloader in prediction or evaluation." ] }, { "cell_type": "markdown", "id": "c4040737", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"env.num_workers_evaluation\": 2})\n", "# use 4 workers in the prediction/evaluation dataloader\n", "predictor.fit(hyperparameters={\"env.num_workers_evaluation\": 4})\n", "```\n" ] }, { "cell_type": "markdown", "id": "a4ea42b0", "metadata": {}, "source": [ "### env.strategy\n", "\n", "Distributed training mode.\n", "\n", "- `\"dp\"`: data parallel.\n", "- `\"ddp\"`: distributed data parallel (Python script based).\n", "- `\"ddp_spawn\"`: distributed data parallel (spawn based).\n", "\n", "See [here](https://lightning.ai/docs/pytorch/stable/extensions/strategy.html#selecting-a-built-in-strategy) for more details." ] }, { "cell_type": "markdown", "id": "ab1c3e6f", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"env.strategy\": \"ddp_spawn\"})\n", "# use ddp during training\n", "predictor.fit(hyperparameters={\"env.strategy\": \"ddp\"})\n", "```\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### env.accelerator\n", "\n", "Supports `\"cpu\"`, `\"gpu\"`, or `\"auto\"` (default).\n", "In auto mode, the GPU takes priority if both a CPU and a GPU are available.\n", "\n", "See [here](https://lightning.ai/docs/pytorch/stable/common/trainer.html#accelerator) for more details."
] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"env.accelerator\": \"auto\"})\n", "# use the CPU for training\n", "predictor.fit(hyperparameters={\"env.accelerator\": \"cpu\"})\n", "```\n" ] }, { "cell_type": "markdown", "source": [ "### env.compile.turn_on\n", "\n", "Whether to compile PyTorch models with [torch.compile](https://pytorch.org/docs/stable/generated/torch.compile.html) (default: `False`).\n", "Note that compiling the model takes extra time up front, so it is recommended for large models and long training runs." ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"env.compile.turn_on\": False})\n", "# turn on torch.compile\n", "predictor.fit(hyperparameters={\"env.compile.turn_on\": True})\n", "```\n" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "### env.compile.mode\n", "\n", "Can be `\"default\"`, `\"reduce-overhead\"`, `\"max-autotune\"`, or `\"max-autotune-no-cudagraphs\"`.\n", "For details, refer to [torch.compile](https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile)." ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"env.compile.mode\": \"default\"})\n", "# reduce Python overhead with CUDA graphs, useful for small batches\n", "predictor.fit(hyperparameters={\"env.compile.mode\": \"reduce-overhead\"})\n", "```\n" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "### env.compile.dynamic\n", "\n", "Whether to use dynamic shape tracing (default: `True`). For details, refer to [torch.compile](https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile)."
], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"env.compile.dynamic\": True})\n", "# assumes a static input shape across mini-batches.\n", "predictor.fit(hyperparameters={\"env.compile.dynamic\": False})\n", "```\n" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "### env.compile.backend\n", "\n", "Backend to be used when compiling the model. For details, refer to [torch.compile](https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile)." ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"env.compile.backend\": \"inductor\"})\n", "```\n" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "id": "f2b4a051", "metadata": {}, "source": [ "## Model\n", "\n", "### model.names\n", "\n", "Choose what types of models to use.\n", "\n", "- `\"hf_text\"`: the pretrained text models from [Huggingface](https://huggingface.co/).\n", "- `\"timm_image\"`: the pretrained image models from [TIMM](https://github.com/rwightman/pytorch-image-models/tree/master/timm/models).\n", "- `\"clip\"`: the pretrained CLIP models.\n", "- `\"categorical_mlp\"`: MLP for categorical data.\n", "- `\"numerical_mlp\"`: MLP for numerical data.\n", "- `\"ft_transformer\"`: [FT-Transformer](https://arxiv.org/pdf/2106.11959.pdf) for tabular (categorical and numerical) data.\n", "- `\"fusion_mlp\"`: MLP-based fusion for features from multiple backbones.\n", "- `\"fusion_transformer\"`: transformer-based fusion for features from multiple backbones.\n", "- `\"sam\"`: the pretrained Segment Anything Model from [Huggingface](https://huggingface.co/).\n", "\n", "If no data of one modality is detected, the related model types will be automatically removed in training." 
] }, { "cell_type": "markdown", "id": "8bc0c2e5", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.names\": [\"hf_text\", \"timm_image\", \"clip\", \"categorical_mlp\", \"numerical_mlp\", \"fusion_mlp\"]})\n", "# use only text models\n", "predictor.fit(hyperparameters={\"model.names\": [\"hf_text\"]})\n", "# use only image models\n", "predictor.fit(hyperparameters={\"model.names\": [\"timm_image\"]})\n", "# use only CLIP models\n", "predictor.fit(hyperparameters={\"model.names\": [\"clip\"]})\n", "```\n" ] }, { "cell_type": "markdown", "id": "f2d1c833", "metadata": {}, "source": [ "### model.hf_text.checkpoint_name\n", "\n", "Specify a text backbone supported by the Huggingface [AutoModel](https://huggingface.co/transformers/v3.0.2/model_doc/auto.html#automodel)." ] }, { "cell_type": "markdown", "id": "27360756", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.hf_text.checkpoint_name\": \"google/electra-base-discriminator\"})\n", "# choose RoBERTa base\n", "predictor.fit(hyperparameters={\"model.hf_text.checkpoint_name\": \"roberta-base\"})\n", "```\n" ] }, { "cell_type": "markdown", "id": "bff5c1bd", "metadata": {}, "source": [ "### model.hf_text.pooling_mode\n", "\n", "The feature pooling mode for transformer architectures.\n", "\n", "- `cls`: uses the feature vector of the CLS token to represent a sentence.\n", "- `mean`: averages all the token feature vectors to represent a sentence." ] }, { "cell_type": "markdown", "id": "1359b199", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.hf_text.pooling_mode\": \"cls\"})\n", "# use mean pooling\n", "predictor.fit(hyperparameters={\"model.hf_text.pooling_mode\": \"mean\"})\n", "```\n" ] }, { "cell_type": "markdown", "id": "87a0ac4c", "metadata": {}, "source": [ "### model.hf_text.tokenizer_name\n", "\n", "Choose the text tokenizer. It is recommended to use the default auto tokenizer.\n",
"\n", "- `hf_auto`: the [Huggingface auto tokenizer](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoTokenizer).\n", "- `bert`: the [BERT tokenizer](https://huggingface.co/docs/transformers/v4.21.1/en/model_doc/bert#transformers.BertTokenizer).\n", "- `electra`: the [ELECTRA tokenizer](https://huggingface.co/docs/transformers/v4.21.1/en/model_doc/electra#transformers.ElectraTokenizer).\n", "- `clip`: the [CLIP tokenizer](https://huggingface.co/docs/transformers/v4.21.1/en/model_doc/clip#transformers.CLIPTokenizer)." ] }, { "cell_type": "markdown", "id": "c50a4d84", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.hf_text.tokenizer_name\": \"hf_auto\"})\n", "# use the tokenizer of the ELECTRA model\n", "predictor.fit(hyperparameters={\"model.hf_text.tokenizer_name\": \"electra\"})\n", "```\n" ] }, { "cell_type": "markdown", "id": "8372c024", "metadata": {}, "source": [ "### model.hf_text.max_text_len\n", "\n", "Set the maximum text length. Different models may allow different maximum lengths. If `model.hf_text.max_text_len` > 0, the minimum of `model.hf_text.max_text_len` and the maximum length allowed by the model is used. Setting `model.hf_text.max_text_len` <= 0 uses the model's maximum length." ] }, { "cell_type": "markdown", "id": "1db84e23", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.hf_text.max_text_len\": 512})\n", "# use the maximum length allowed by the model\n", "predictor.fit(hyperparameters={\"model.hf_text.max_text_len\": -1})\n", "```\n" ] }, { "cell_type": "markdown", "id": "67a57f56", "metadata": {}, "source": [ "### model.hf_text.insert_sep\n", "\n", "Whether to insert the SEP token between texts from different columns of a dataframe."
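, "\n",
"\n",
"The `model.hf_text.max_text_len` rule described above reduces to a small function. Here is a minimal sketch in which `model_max_len` is a hypothetical stand-in for the maximum length the backbone allows. The function name is also hypothetical.\n",
"\n",
"```python\n",
"# Hypothetical helper, not part of the AutoMM API.\n",
"def effective_max_text_len(max_text_len, model_max_len):\n",
"    # a positive value is capped by what the model allows\n",
"    if max_text_len > 0:\n",
"        return min(max_text_len, model_max_len)\n",
"    # zero or a negative value means: use the model's maximum length\n",
"    return model_max_len\n",
"\n",
"print(effective_max_text_len(512, 512))  # 512\n",
"print(effective_max_text_len(1024, 512))  # 512\n",
"print(effective_max_text_len(-1, 256))  # 256\n",
"```"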
] }, { "cell_type": "markdown", "id": "61c8f6b9", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.hf_text.insert_sep\": True})\n", "# use no SEP token.\n", "predictor.fit(hyperparameters={\"model.hf_text.insert_sep\": False})\n", "```\n" ] }, { "cell_type": "markdown", "id": "525da692", "metadata": {}, "source": [ "### model.hf_text.text_segment_num\n", "\n", "How many text segments are used in a token sequence. Each text segment has one [token type ID](https://huggingface.co/transformers/v2.11.0/glossary.html#token-type-ids). We choose the minimum between `model.hf_text.text_segment_num` and the default used by the model." ] }, { "cell_type": "markdown", "id": "519d778a", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.hf_text.text_segment_num\": 2})\n", "# use 1 text segment\n", "predictor.fit(hyperparameters={\"model.hf_text.text_segment_num\": 1})\n", "```\n" ] }, { "cell_type": "markdown", "id": "ffe782fd", "metadata": {}, "source": [ "### model.hf_text.stochastic_chunk\n", "\n", "Whether to randomly cut a text chunk if a sample's text token number is larger than `model.hf_text.max_text_len`. If False, cut a token sequence from index 0 to the maximum allowed length. Otherwise, randomly sample a start index to cut a text chunk." ] }, { "cell_type": "markdown", "id": "71f3f97e", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.hf_text.stochastic_chunk\": False})\n", "# select a stochastic text chunk if a text sequence is over-long\n", "predictor.fit(hyperparameters={\"model.hf_text.stochastic_chunk\": True})\n", "```\n" ] }, { "cell_type": "markdown", "id": "3d63c524", "metadata": {}, "source": [ "### model.hf_text.text_aug_detect_length\n", "\n", "Perform text augmentation only when the text token number is no less than `model.hf_text.text_aug_detect_length`." 
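, "\n",
"\n",
"The two chunking modes of `model.hf_text.stochastic_chunk` described above can be sketched on a plain list of token IDs. This illustration rests on simplifying assumptions: special tokens such as CLS and SEP are ignored, and the function name is hypothetical. It is not AutoMM's actual tokenization code.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"# Hypothetical helper, not part of the AutoMM API.\n",
"def cut_tokens(token_ids, max_len, stochastic_chunk, rng=random):\n",
"    if len(token_ids) <= max_len:\n",
"        return token_ids  # short enough, nothing to cut\n",
"    if not stochastic_chunk:\n",
"        return token_ids[:max_len]  # keep tokens from index 0 up to the maximum length\n",
"    # sample a start index so that the chunk still fits inside the sequence\n",
"    start = rng.randint(0, len(token_ids) - max_len)\n",
"    return token_ids[start:start + max_len]\n",
"\n",
"tokens = list(range(10))\n",
"print(cut_tokens(tokens, 4, stochastic_chunk=False))  # [0, 1, 2, 3]\n",
"print(cut_tokens(tokens, 4, stochastic_chunk=True, rng=random.Random(0)))  # a random contiguous chunk of 4\n",
"```"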
] }, { "cell_type": "markdown", "id": "2538fbd3", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.hf_text.text_aug_detect_length\": 10})\n", "# Allow text augmentation for texts whose token number is no less than 5\n", "predictor.fit(hyperparameters={\"model.hf_text.text_aug_detect_length\": 5})\n", "```\n" ] }, { "cell_type": "markdown", "id": "c4755a0c", "metadata": {}, "source": [ "### model.hf_text.text_trivial_aug_maxscale\n", "\n", "Set the maximum percentage of text tokens to conduct data augmentation. For each text token sequence, we randomly sample a percentage in [0, `model.hf_text.text_trivial_aug_maxscale`] and one operation from four trivial augmentations, including synonym replacement, random word swap, random word deletion, and random punctuation insertion, to do text augmentation." ] }, { "cell_type": "markdown", "id": "0b44d07f", "metadata": {}, "source": [ "```\n", "# by default, AutoMM doesn't do text augmentation\n", "predictor.fit(hyperparameters={\"model.hf_text.text_trivial_aug_maxscale\": 0})\n", "# Enable trivial augmentation by setting the max scale to 0.1\n", "predictor.fit(hyperparameters={\"model.hf_text.text_trivial_aug_maxscale\": 0.1})\n", "```\n" ] }, { "cell_type": "markdown", "id": "5019b9f4", "metadata": {}, "source": [ "### model.hf_text.gradient_checkpointing\n", "\n", "Whether to turn on gradient checkpointing to reduce the memory consumption for calculating gradients. For more about gradient checkpointing, feel free to refer to [relevant tutorials](https://github.com/cybertronai/gradient-checkpointing)." 
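, "\n",
"\n",
"Here is a simplified sketch of the trivial text augmentation described above, combined with the `model.hf_text.text_aug_detect_length` check. It assumes whitespace-separated words instead of model tokens. It implements only two of the four operations, random word swap and random word deletion, and omits synonym replacement and punctuation insertion. The function name is hypothetical.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"# Hypothetical helper, not part of the AutoMM API.\n",
"def trivial_augment(text, maxscale, detect_length, rng=random):\n",
"    words = text.split()\n",
"    # skip augmentation when it is disabled or the text is too short\n",
"    if maxscale <= 0 or len(words) < detect_length:\n",
"        return text\n",
"    # sample the fraction of words to modify from [0, maxscale]\n",
"    n = int(len(words) * rng.uniform(0, maxscale))\n",
"    op = rng.choice(['swap', 'delete'])\n",
"    if op == 'swap':\n",
"        for _ in range(n):\n",
"            i, j = rng.randrange(len(words)), rng.randrange(len(words))\n",
"            words[i], words[j] = words[j], words[i]\n",
"    else:\n",
"        # delete n random words, always keeping at least one\n",
"        for _ in range(min(n, len(words) - 1)):\n",
"            words.pop(rng.randrange(len(words)))\n",
"    return ' '.join(words)\n",
"\n",
"text = 'the quick brown fox jumps over the lazy dog near the river bank'\n",
"print(trivial_augment(text, 0.1, 10, rng=random.Random(0)))\n",
"print(trivial_augment('too short', 0.1, 10))  # returned unchanged\n",
"```"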
] }, { "cell_type": "markdown", "id": "38476035", "metadata": {}, "source": [ "```\n", "# by default, AutoMM doesn't turn on gradient checkpointing\n", "predictor.fit(hyperparameters={\"model.hf_text.gradient_checkpointing\": False})\n", "# Turn on gradient checkpointing\n", "predictor.fit(hyperparameters={\"model.hf_text.gradient_checkpointing\": True})\n", "```\n" ] }, { "cell_type": "markdown", "id": "9b29a7b", "metadata": {}, "source": [ "### model.ft_transformer.checkpoint_name\n", "\n", "A local path or URL of pre-trained weights used to initialize the ft_transformer backbone." ] }, { "cell_type": "markdown", "id": "6s39392", "metadata": {}, "source": [ "```\n", "# by default, AutoMM doesn't use pre-trained weights\n", "predictor.fit(hyperparameters={\"model.ft_transformer.checkpoint_name\": None})\n", "# initialize the ft_transformer backbone from a local checkpoint\n", "predictor.fit(hyperparameters={\"model.ft_transformer.checkpoint_name\": 'my_checkpoint.ckpt'})\n", "# initialize the ft_transformer backbone from a checkpoint URL\n", "predictor.fit(hyperparameters={\"model.ft_transformer.checkpoint_name\": 'https://automl-mm-bench.s3.amazonaws.com/ft_transformer_pretrained_ckpt/iter_2k.ckpt'})\n", "```" ] }, { "cell_type": "markdown", "id": "9b3d3a7b", "metadata": {}, "source": [ "### model.ft_transformer.num_blocks\n", "\n", "Number of transformer blocks in the ft_transformer backbone." ] }, { "cell_type": "markdown", "id": "642d7392", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.ft_transformer.num_blocks\": 3})\n", "# increase the number of blocks to 5 in ft_transformer\n", "predictor.fit(hyperparameters={\"model.ft_transformer.num_blocks\": 5})\n", "```" ] }, { "cell_type": "markdown", "id": "5340d090", "metadata": {}, "source": [ "### model.ft_transformer.token_dim\n", "\n", "The dimension of the tokens produced by the categorical and numerical tokenizers in ft_transformer."
] }, { "cell_type": "markdown", "id": "780bddb0", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.ft_transformer.token_dim\": 192})\n", "# increase the token dimension to 256 in ft_transformer\n", "predictor.fit(hyperparameters={\"model.ft_transformer.token_dim\": 256})\n", "```" ] }, { "cell_type": "markdown", "id": "87348422", "metadata": {}, "source": [ "### model.ft_transformer.hidden_size\n", "\n", "The model embedding dimension of the ft_transformer backbone." ] }, { "cell_type": "markdown", "id": "996c3a0e", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.ft_transformer.hidden_size\": 192})\n", "# increase the model embedding dimension to 256 in ft_transformer\n", "predictor.fit(hyperparameters={\"model.ft_transformer.hidden_size\": 256})\n", "```" ] }, { "cell_type": "markdown", "id": "d6523568", "metadata": {}, "source": [ "### model.ft_transformer.ffn_hidden_size\n", "\n", "The hidden layer dimension of the FFN (feed-forward network) layer in [ft_transformer blocks](https://arxiv.org/pdf/2106.11959v5.pdf). In the [Transformer](https://arxiv.org/pdf/1706.03762.pdf) paper, the FFN hidden layer dimension is set to $4\\times$ the model hidden size. Here, we set it equal to the model hidden size by default." ] }, { "cell_type": "markdown", "id": "3e448822", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.ft_transformer.ffn_hidden_size\": 192})\n", "# increase the FFN hidden layer dimension to 256 in ft_transformer\n", "predictor.fit(hyperparameters={\"model.ft_transformer.ffn_hidden_size\": 256})\n", "```" ] }, { "cell_type": "markdown", "id": "ec6d7d31", "metadata": {}, "source": [ "### model.timm_image.checkpoint_name\n", "\n", "Select an image backbone from [TIMM](https://github.com/rwightman/pytorch-image-models/tree/master/timm/models)." 
] }, { "cell_type": "markdown", "id": "20aa69bb", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.timm_image.checkpoint_name\": \"swin_base_patch4_window7_224\"})\n", "# choose a ViT-Base backbone\n", "predictor.fit(hyperparameters={\"model.timm_image.checkpoint_name\": \"vit_base_patch32_224\"})\n", "```" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### model.timm_image.train_transforms\n", "\n", "Augment images in training. Supports a list of strings chosen from (`resize_to_square`, `resize_shorter_side`, `center_crop`, `random_resize_crop`, `random_horizontal_flip`, `random_vertical_flip`, `color_jitter`, `affine`, `randaug`, `trivial_augment`), or a list of callable and pickle-able transform objects, for example, [torchvision transforms](https://pytorch.org/vision/stable/transforms.html)." ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.timm_image.train_transforms\": [\"resize_shorter_side\", \"center_crop\", \"trivial_augment\"]})\n", "# use random resize crop and random horizontal flip\n", "predictor.fit(hyperparameters={\"model.timm_image.train_transforms\": [\"random_resize_crop\", \"random_horizontal_flip\"]})\n", "# or use a list of callable and pickle-able objects, e.g., torchvision transforms\n", "import torchvision\n", "predictor.fit(hyperparameters={\"model.timm_image.train_transforms\": [torchvision.transforms.RandomResizedCrop(224), torchvision.transforms.RandomHorizontalFlip()]})\n", "```" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### model.timm_image.val_transforms\n", "\n", "Transform images in validation/test/deployment. Similar to `model.timm_image.train_transforms`, it supports a list of strings or a list of callable and pickle-able objects." 
] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.timm_image.val_transforms\": [\"resize_shorter_side\", \"center_crop\"]})\n", "# resize images to squares\n", "predictor.fit(hyperparameters={\"model.timm_image.val_transforms\": [\"resize_to_square\"]})\n", "# or use a list of callable and pickle-able objects, e.g., torchvision transforms\n", "import torchvision\n", "predictor.fit(hyperparameters={\"model.timm_image.val_transforms\": [torchvision.transforms.Resize((224, 224))]})\n", "```\n" ] }, { "cell_type": "markdown", "id": "db39fec1", "metadata": {}, "source": [ "### model.mmdet_image.checkpoint_name\n", "\n", "Specify an MMDetection model supported by [MMDetection](https://mmdetection.readthedocs.io/en/latest/user_guides/inference.html). Use \"yolox_nano\", \"yolox_tiny\", \"yolox_s\", \"yolox_m\", \"yolox_l\", or \"yolox_x\" to run our modified YOLOX models, which are compatible with AutoGluon." ] }, { "cell_type": "markdown", "id": "cf07cc08", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor = MultiModalPredictor(hyperparameters={\"model.mmdet_image.checkpoint_name\": \"yolov3_mobilenetv2_8xb24-320-300e_coco\"})\n", "# choose YOLOX-L\n", "predictor = MultiModalPredictor(hyperparameters={\"model.mmdet_image.checkpoint_name\": \"yolox_l\"})\n", "# choose DINO-SwinL\n", "predictor = MultiModalPredictor(hyperparameters={\"model.mmdet_image.checkpoint_name\": \"dino-5scale_swin-l_8xb2-36e_coco\"})\n", "```" ] }, { "cell_type": "markdown", "id": "64f465e9", "metadata": {}, "source": [ "### model.mmdet_image.output_bbox_format\n", "\n", "The output bounding box format:\n", "\n", "- `\"xyxy\"`: Output [x1,y1,x2,y2]. Bounding boxes are represented by two corners: x1, y1 is the top left and x2, y2 is the bottom right. This is our default output format.\n", "- `\"xywh\"`: Output [x1,y1,w,h]. 
Bounding boxes are represented by one corner plus the width and height: x1, y1 is the top left corner, and w, h are the width and height." ] }, { "cell_type": "markdown", "id": "87be5d56", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor = MultiModalPredictor(hyperparameters={\"model.mmdet_image.output_bbox_format\": \"xyxy\"})\n", "# choose the xywh output format\n", "predictor = MultiModalPredictor(hyperparameters={\"model.mmdet_image.output_bbox_format\": \"xywh\"})\n", "```" ] }, { "cell_type": "markdown", "id": "30c7ec4d", "metadata": {}, "source": [ "### model.mmdet_image.frozen_layers\n", "\n", "The layers to freeze. Every layer whose name contains any of these substrings will be frozen." ] }, { "cell_type": "markdown", "source": [ "```\n", "# default used by AutoMM, freeze nothing and update all parameters\n", "predictor = MultiModalPredictor(hyperparameters={\"model.mmdet_image.frozen_layers\": []})\n", "# freeze the model's backbone\n", "predictor = MultiModalPredictor(hyperparameters={\"model.mmdet_image.frozen_layers\": [\"backbone\"]})\n", "# freeze the model's backbone and neck\n", "predictor = MultiModalPredictor(hyperparameters={\"model.mmdet_image.frozen_layers\": [\"backbone\", \"neck\"]})\n", "```" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "id": "56ee3f92", "metadata": {}, "source": [ "### model.sam.checkpoint_name\n", "\n", "Specify a SAM backbone supported by Hugging Face [SAM](https://huggingface.co/docs/transformers/main/model_doc/sam)." 
] }, { "cell_type": "markdown", "id": "ccbd46cb", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.sam.checkpoint_name\": \"facebook/sam-vit-huge\"})\n", "# choose SAM-Large\n", "predictor.fit(hyperparameters={\"model.sam.checkpoint_name\": \"facebook/sam-vit-large\"})\n", "# choose SAM-Base\n", "predictor.fit(hyperparameters={\"model.sam.checkpoint_name\": \"facebook/sam-vit-base\"})\n", "```" ] }, { "cell_type": "markdown", "id": "b106e2c8", "metadata": {}, "source": [ "### model.sam.train_transforms\n", "\n", "Augment images in training. Currently, only `random_horizontal_flip` is supported." ] }, { "cell_type": "markdown", "id": "11b22638", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.sam.train_transforms\": [\"random_horizontal_flip\"]})\n", "```" ] }, { "cell_type": "markdown", "id": "bc1433cd", "metadata": {}, "source": [ "### model.sam.img_transforms\n", "\n", "Process input images for semantic segmentation. Currently, only `resize_to_square` is supported." ] }, { "cell_type": "markdown", "id": "2ffc0e05", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.sam.img_transforms\": [\"resize_to_square\"]})\n", "```" ] }, { "cell_type": "markdown", "id": "78130e5c", "metadata": {}, "source": [ "### model.sam.gt_transforms\n", "\n", "Process ground truth masks for semantic segmentation. Currently, only `resize_gt_to_square` is supported." ] }, { "cell_type": "markdown", "id": "5964c3b5", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.sam.gt_transforms\": [\"resize_gt_to_square\"]})\n", "```" ] }, { "cell_type": "markdown", "id": "8ca01fcf", "metadata": {}, "source": [ "### model.sam.frozen_layers\n", "\n", "The SAM modules to freeze during training. 
" ] }, { "cell_type": "markdown", "id": "4377a293", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.sam.frozen_layers\": [\"mask_decoder.iou_prediction_head\", \"prompt_encoder\"]})\n", "```" ] }, { "cell_type": "markdown", "id": "ce7a5e4c", "metadata": {}, "source": [ "### model.sam.num_mask_tokens\n", "\n", "The number of mask proposals produced by SAM's mask decoder." ] }, { "cell_type": "markdown", "id": "9be77770", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.sam.num_mask_tokens\": 1})\n", "```" ] }, { "cell_type": "markdown", "id": "d3ef9e73", "metadata": {}, "source": [ "### model.sam.ignore_label\n", "\n", "A target value that is ignored, i.e., it contributes to neither the training loss nor the metric calculation." ] }, { "cell_type": "markdown", "id": "5d2e0373", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"model.sam.ignore_label\": 255})\n", "```" ] }, { "cell_type": "markdown", "id": "690a7766", "metadata": {}, "source": [ "## Data\n", "\n", "### data.image.missing_value_strategy\n", "\n", "How to handle missing images, i.e., images that fail to open.\n", "\n", "- `\"skip\"`: skip samples with missing images.\n", "- `\"zero\"`: replace a missing image with an all-zero image." ] }, { "cell_type": "markdown", "id": "2b5a689b", "metadata": {}, "source": [ "### data.text.normalize_text\n", "\n", "Whether to normalize text with encoding problems. If True, TextProcessor runs the text through a series of encoding and decoding steps to normalize it. 
See the [Kaggle competition example](https://github.com/autogluon/autogluon/tree/master/examples/automm/kaggle_feedback_prize) for an application of text normalization." ] }, { "cell_type": "markdown", "id": "ab6c46ad", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"data.text.normalize_text\": False})\n", "# turn on text normalization\n", "predictor.fit(hyperparameters={\"data.text.normalize_text\": True})\n", "```\n" ] }, { "cell_type": "markdown", "id": "177eb155", "metadata": {}, "source": [ "### data.categorical.convert_to_text\n", "\n", "Whether to treat categorical data as text. If True, no categorical models, e.g., `\"categorical_mlp\"` and `\"categorical_transformer\"`, will be used." ] }, { "cell_type": "markdown", "id": "267ad0a9", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"data.categorical.convert_to_text\": True})\n", "# turn off the conversion\n", "predictor.fit(hyperparameters={\"data.categorical.convert_to_text\": False})\n", "```\n" ] }, { "cell_type": "markdown", "id": "3bf8d9e6", "metadata": {}, "source": [ "### data.numerical.convert_to_text\n", "\n", "Whether to convert numerical data to text. If True, no numerical models, e.g., `\"numerical_mlp\"` and `\"numerical_transformer\"`, will be used." ] }, { "cell_type": "markdown", "id": "f158d9a0", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"data.numerical.convert_to_text\": False})\n", "# turn on the conversion\n", "predictor.fit(hyperparameters={\"data.numerical.convert_to_text\": True})\n", "```\n" ] }, { "cell_type": "markdown", "id": "daaa41c9", "metadata": {}, "source": [ "### data.numerical.scaler_with_mean\n", "\n", "If True, center the numerical data (not including the numerical labels) before scaling." 
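, "\n", "\n", "The two scaler options, `data.numerical.scaler_with_mean` and `data.numerical.scaler_with_std`, toggle the centering and unit-variance steps of standardization. The minimal NumPy sketch below mirrors the `with_mean` and `with_std` options of scikit-learn's `StandardScaler`; whether AutoMM uses that class internally is an assumption, and `scale_numerical` is a hypothetical helper.\n", "\n", "```python\n", "import numpy as np\n", "\n", "def scale_numerical(x, with_mean=True, with_std=True):\n", "    # Column-wise statistics; in practice they come from the training data.\n", "    mean = x.mean(axis=0) if with_mean else np.zeros(x.shape[1])\n", "    std = x.std(axis=0) if with_std else np.ones(x.shape[1])\n", "    std = np.where(std == 0, 1.0, std)  # leave constant columns unscaled\n", "    return (x - mean) / std\n", "\n", "x = np.array([[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]])\n", "both = scale_numerical(x)                        # centered, unit variance\n", "no_center = scale_numerical(x, with_mean=False)  # only divided by the std\n", "print(both.mean(axis=0), both.std(axis=0))\n", "```"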
] }, { "cell_type": "markdown", "id": "984abb92", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"data.numerical.scaler_with_mean\": True})\n", "# turn off centering\n", "predictor.fit(hyperparameters={\"data.numerical.scaler_with_mean\": False})\n", "```\n" ] }, { "cell_type": "markdown", "id": "589677a4", "metadata": {}, "source": [ "### data.numerical.scaler_with_std\n", "\n", "If True, scale the numerical data (not including the numerical labels) to unit variance." ] }, { "cell_type": "markdown", "id": "8cfca7db", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"data.numerical.scaler_with_std\": True})\n", "# turn off scaling\n", "predictor.fit(hyperparameters={\"data.numerical.scaler_with_std\": False})\n", "```\n" ] }, { "cell_type": "markdown", "id": "9241360b", "metadata": {}, "source": [ "### data.label.numerical_label_preprocessing\n", "\n", "How to process the numerical labels in regression tasks.\n", "\n", "- `\"standardscaler\"`: standardizes numerical labels by removing the mean and scaling to unit variance.\n", "- `\"minmaxscaler\"`: scales numerical labels to the range (0, 1)." ] }, { "cell_type": "markdown", "id": "bea7d018", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"data.label.numerical_label_preprocessing\": \"standardscaler\"})\n", "# scale numerical labels to (0, 1)\n", "predictor.fit(hyperparameters={\"data.label.numerical_label_preprocessing\": \"minmaxscaler\"})\n", "```\n" ] }, { "cell_type": "markdown", "id": "103f689a", "metadata": {}, "source": [ "### data.pos_label\n", "\n", "The positive label in a binary classification task. Specify it so that metrics that depend on which class is positive, e.g., roc_auc, average_precision, and f1, are computed correctly." 
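, "\n", "\n", "Why the positive label matters: F1 (like precision and recall) is computed with respect to one class, so choosing a different positive label changes the score. A minimal pure-Python sketch, where `f1_for_label` is a hypothetical helper and the labels `changed` and `not changed` are illustrative:\n", "\n", "```python\n", "def f1_for_label(y_true, y_pred, pos_label):\n", "    # Count true positives, false positives, and false negatives for pos_label.\n", "    pairs = list(zip(y_true, y_pred))\n", "    tp = sum(t == pos_label and p == pos_label for t, p in pairs)\n", "    fp = sum(t != pos_label and p == pos_label for t, p in pairs)\n", "    fn = sum(t == pos_label and p != pos_label for t, p in pairs)\n", "    if tp == 0:\n", "        return 0.0\n", "    precision = tp / (tp + fp)\n", "    recall = tp / (tp + fn)\n", "    return 2 * precision * recall / (precision + recall)\n", "\n", "y_true = ['changed', 'changed', 'not changed', 'not changed', 'not changed']\n", "y_pred = ['changed', 'not changed', 'not changed', 'not changed', 'changed']\n", "print(f1_for_label(y_true, y_pred, 'changed'))      # 0.5\n", "print(f1_for_label(y_true, y_pred, 'not changed'))  # 2/3, about 0.667\n", "```"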
] }, { "cell_type": "markdown", "id": "1a14b1af", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"data.pos_label\": None})\n", "# assume the labels are [\"changed\", \"not changed\"] and \"changed\" is the positive label\n", "predictor.fit(hyperparameters={\"data.pos_label\": \"changed\"})\n", "```\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### data.column_features_pooling_mode\n", "\n", "How to aggregate column features into one feature vector for a dataframe with multiple feature columns. Currently, it works only for `few_shot_classification`.\n", "\n", "- `\"concat\"`: concatenate the features of different columns into one long feature vector.\n", "- `\"mean\"`: average the column features so that the feature dimension does not grow with the number of columns." ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"data.column_features_pooling_mode\": \"concat\"})\n", "# use mean pooling\n", "predictor.fit(hyperparameters={\"data.column_features_pooling_mode\": \"mean\"})\n", "```\n" ] }, { "cell_type": "markdown", "id": "e3e3242f", "metadata": {}, "source": [ "### data.mixup.turn_on\n", "\n", "If True, use Mixup in training." ] }, { "cell_type": "markdown", "id": "0161ba46", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"data.mixup.turn_on\": False})\n", "# turn on Mixup\n", "predictor.fit(hyperparameters={\"data.mixup.turn_on\": True})\n", "```\n" ] }, { "cell_type": "markdown", "id": "fd97b924", "metadata": {}, "source": [ "### data.mixup.mixup_alpha\n", "\n", "Mixup alpha value. Mixup is active if `data.mixup.mixup_alpha` > 0." 
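, "\n", "\n", "A minimal NumPy sketch of what `data.mixup.mixup_alpha` controls, following the original Mixup formulation: a mixing ratio `lam` is drawn from Beta(alpha, alpha), and every sample and its one-hot label are blended with the sample at the mirrored batch position, similar to timm's batch mode. The helper `mixup_batch` is hypothetical, and the real implementation may differ in details such as label smoothing.\n", "\n", "```python\n", "import numpy as np\n", "\n", "def mixup_batch(x, y_onehot, alpha, rng):\n", "    # One mixing ratio for the whole batch, drawn from Beta(alpha, alpha).\n", "    lam = rng.beta(alpha, alpha)\n", "    # Blend each sample (and label) with the sample at the mirrored position.\n", "    x_mixed = lam * x + (1.0 - lam) * x[::-1]\n", "    y_mixed = lam * y_onehot + (1.0 - lam) * y_onehot[::-1]\n", "    return x_mixed, y_mixed, lam\n", "\n", "rng = np.random.default_rng(0)\n", "x = rng.normal(size=(4, 8))   # 4 samples with 8 features each\n", "y = np.eye(3)[[0, 1, 2, 1]]   # one-hot labels for 3 classes\n", "x_mixed, y_mixed, lam = mixup_batch(x, y, alpha=0.8, rng=rng)  # 0.8 is the AutoMM default\n", "print(lam, y_mixed.sum(axis=1))  # every mixed label still sums to 1\n", "```"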
] }, { "cell_type": "markdown", "id": "dd9bc14b", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"data.mixup.mixup_alpha\": 0.8})\n", "# set it to 0 to turn off Mixup (Cutmix may still be active)\n", "predictor.fit(hyperparameters={\"data.mixup.mixup_alpha\": 0.0})\n", "```\n" ] }, { "cell_type": "markdown", "id": "b48c9408", "metadata": {}, "source": [ "### data.mixup.cutmix_alpha\n", "\n", "Cutmix alpha value. Cutmix is active if `data.mixup.cutmix_alpha` > 0." ] }, { "cell_type": "markdown", "id": "9fc9b53a", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"data.mixup.cutmix_alpha\": 1.0})\n", "# set it to 0 to turn off Cutmix\n", "predictor.fit(hyperparameters={\"data.mixup.cutmix_alpha\": 0.0})\n", "```\n" ] }, { "cell_type": "markdown", "id": "b3a58751", "metadata": {}, "source": [ "### data.mixup.prob\n", "\n", "The probability of applying Mixup or Cutmix when enabled." ] }, { "cell_type": "markdown", "id": "738cdcc7", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"data.mixup.prob\": 1.0})\n", "# set probability to 0.5\n", "predictor.fit(hyperparameters={\"data.mixup.prob\": 0.5})\n", "```\n" ] }, { "cell_type": "markdown", "id": "5991c094", "metadata": {}, "source": [ "### data.mixup.switch_prob\n", "\n", "The probability of switching to Cutmix instead of Mixup when both are active." 
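, "\n", "\n", "How `data.mixup.prob`, `data.mixup.switch_prob`, and the two alpha values interact can be summarized as a per-batch decision. The sketch below is loosely modeled on the batch-mode logic of timm's `Mixup` class; the function `choose_mix_op` is hypothetical.\n", "\n", "```python\n", "import random\n", "\n", "def choose_mix_op(mixup_alpha, cutmix_alpha, prob, switch_prob, rng):\n", "    # With probability 1 - prob, the batch is left unchanged.\n", "    if rng.random() >= prob:\n", "        return 'none'\n", "    if mixup_alpha > 0 and cutmix_alpha > 0:\n", "        # Both are active: pick Cutmix with probability switch_prob.\n", "        return 'cutmix' if rng.random() < switch_prob else 'mixup'\n", "    if mixup_alpha > 0:\n", "        return 'mixup'\n", "    if cutmix_alpha > 0:\n", "        return 'cutmix'\n", "    return 'none'\n", "\n", "rng = random.Random(0)\n", "counts = {'none': 0, 'mixup': 0, 'cutmix': 0}\n", "for _ in range(10000):\n", "    # 0.8 and 1.0 are the default mixup_alpha and cutmix_alpha values\n", "    counts[choose_mix_op(0.8, 1.0, prob=1.0, switch_prob=0.5, rng=rng)] += 1\n", "print(counts)  # roughly half mixup and half cutmix\n", "```"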
] }, { "cell_type": "markdown", "id": "d24393ef", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"data.mixup.switch_prob\": 0.5})\n", "# set probability to 0.7\n", "predictor.fit(hyperparameters={\"data.mixup.switch_prob\": 0.7})\n", "```\n" ] }, { "cell_type": "markdown", "id": "ab459677", "metadata": {}, "source": [ "### data.mixup.mode\n", "\n", "How to apply the Mixup or Cutmix parameters: per `\"batch\"`, per `\"pair\"` (pair of elements), or per `\"elem\"` (element).\n", "See the [timm Mixup implementation](https://github.com/rwightman/pytorch-image-models/blob/d30685c283137b4b91ea43c4e595c964cd2cb6f0/timm/data/mixup.py#L211-L216) for more details." ] }, { "cell_type": "markdown", "id": "ada57733", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"data.mixup.mode\": \"batch\"})\n", "# use \"pair\"\n", "predictor.fit(hyperparameters={\"data.mixup.mode\": \"pair\"})\n", "```\n" ] }, { "cell_type": "markdown", "id": "737d454d", "metadata": {}, "source": [ "### data.mixup.label_smoothing\n", "\n", "The label smoothing value applied to the mixed label tensors." ] }, { "cell_type": "markdown", "id": "40f1d216", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"data.mixup.label_smoothing\": 0.1})\n", "# set it to 0.2\n", "predictor.fit(hyperparameters={\"data.mixup.label_smoothing\": 0.2})\n", "```\n" ] }, { "cell_type": "markdown", "id": "73c8a29d", "metadata": {}, "source": [ "### data.mixup.turn_off_epoch\n", "\n", "Stop applying Mixup or Cutmix once this number of epochs is reached." 
] }, { "cell_type": "markdown", "id": "a0c3715f", "metadata": {}, "source": [ "```\n", "# default used by AutoMM\n", "predictor.fit(hyperparameters={\"data.mixup.turn_off_epoch\": 5})\n", "# turn off Mixup after 7 epochs\n", "predictor.fit(hyperparameters={\"data.mixup.turn_off_epoch\": 7})\n", "```\n" ] }, { "cell_type": "markdown", "id": "4fb33f07", "metadata": {}, "source": [ "## Distiller\n", "\n", "### distiller.soft_label_loss_type\n", "\n", "The loss to compute when the teacher's output (logits) is used to supervise the student's output." ] }, { "cell_type": "markdown", "id": "f3ca2c3d", "metadata": {}, "source": [ "```\n", "# default used by AutoMM for classification\n", "predictor.fit(hyperparameters={\"distiller.soft_label_loss_type\": \"cross_entropy\"})\n", "# default used by AutoMM for regression\n", "predictor.fit(hyperparameters={\"distiller.soft_label_loss_type\": \"mse\"})\n", "```\n" ] }, { "cell_type": "markdown", "id": "c91287a0", "metadata": {}, "source": [ "### distiller.temperature\n", "\n", "Before computing the soft label loss, divide both the teacher and student logits by this temperature (teacher_logits / temperature, student_logits / temperature)." ] }, { "cell_type": "markdown", "id": "4f67e3e1", "metadata": {}, "source": [ "```\n", "# default used by AutoMM for classification\n", "predictor.fit(hyperparameters={\"distiller.temperature\": 5})\n", "# set temperature to 1\n", "predictor.fit(hyperparameters={\"distiller.temperature\": 1})\n", "```\n" ] }, { "cell_type": "markdown", "id": "2f95727c", "metadata": {}, "source": [ "### distiller.hard_label_weight\n", "\n", "Scale the student's hard label (ground truth) loss with this weight (hard_label_loss \\* hard_label_weight)." 
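, "\n", "\n", "Putting the distiller options together, below is one plausible way the weighted classification loss could be assembled, written as a minimal NumPy sketch under stated assumptions: the soft label loss is the cross entropy between the temperature-scaled teacher and student distributions, and the hard label loss is the ordinary cross entropy against the ground truth. The helper `distillation_loss` is hypothetical, and AutoMM's actual loss may apply further scaling.\n", "\n", "```python\n", "import numpy as np\n", "\n", "def log_softmax(z):\n", "    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability\n", "    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))\n", "\n", "def distillation_loss(student_logits, teacher_logits, labels,\n", "                      temperature=5.0, hard_label_weight=0.2, soft_label_weight=50.0):\n", "    # Hard label loss: cross entropy of the student against the ground truth.\n", "    log_p = log_softmax(student_logits)\n", "    hard = -log_p[np.arange(len(labels)), labels].mean()\n", "    # Soft label loss: cross entropy between temperature-scaled teacher and student.\n", "    teacher_prob = np.exp(log_softmax(teacher_logits / temperature))\n", "    soft = -(teacher_prob * log_softmax(student_logits / temperature)).sum(axis=-1).mean()\n", "    return hard_label_weight * hard + soft_label_weight * soft, hard, soft\n", "\n", "student = np.array([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])\n", "teacher = np.array([[3.0, 0.0, -2.0], [0.0, 2.5, 0.5]])\n", "labels = np.array([0, 1])\n", "# defaults follow the classification values listed in this section: 5, 0.2, and 50\n", "total, hard, soft = distillation_loss(student, teacher, labels)\n", "print(total, hard, soft)\n", "```"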
] }, { "cell_type": "markdown", "id": "5f5d5eca", "metadata": {}, "source": [ "```\n", "# default used by AutoMM for classification\n", "predictor.fit(hyperparameters={\"distiller.hard_label_weight\": 0.2})\n", "# do not scale the hard label loss\n", "predictor.fit(hyperparameters={\"distiller.hard_label_weight\": 1})\n", "```\n" ] }, { "cell_type": "markdown", "id": "0ebc0b75", "metadata": {}, "source": [ "### distiller.soft_label_weight\n", "\n", "Scale the student's soft label (teacher's output) loss with this weight (soft_label_loss \\* soft_label_weight)." ] }, { "cell_type": "markdown", "id": "8b3b90c2", "metadata": {}, "source": [ "```\n", "# default used by AutoMM for classification\n", "predictor.fit(hyperparameters={\"distiller.soft_label_weight\": 50})\n", "# do not scale the soft label loss\n", "predictor.fit(hyperparameters={\"distiller.soft_label_weight\": 1})\n", "```\n" ] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 }